PyTorch を使用したノイズ除去拡散モデルの実装

PyTorch を使用したノイズ除去拡散モデルの実装

ノイズ除去拡散確率モデル (DDPM) の仕組みを詳しく検討する前に、生成 AI の進歩、具体的には DDPM に関する基礎研究の一部を見てみましょう。

VA

VAE は、エンコーダー、確率的潜在空間、およびデコーダーを使用します。トレーニング中、エンコーダーは各画像の平均と分散を予測します。これらの値はガウス分布からサンプリングされ、デコーダーに渡されます。デコーダーでは、入力画像が出力画像と類似していることが予想されます。このプロセスでは、KL ダイバージェンスを使用して損失を計算します。 VAE の大きな利点は、多種多様な画像を生成できることです。デコーダーは、サンプリングフェーズ中にガウス分布から単純にサンプリングすることによって新しい画像を作成します。

ガン

変分オートエンコーダ (VAE) の登場からわずか 1 年後、画期的な生成モデル ファミリである敵対的生成ネットワーク (GAN) が登場しました。これは、敵対的トレーニング プロセスを伴う、ジェネレータとディスクリミネーターの 2 つのニューラル ネットワークの連携を特徴とする新しいクラスの生成モデルの始まりを示しています。ジェネレーターの目的は、ランダム ノイズから画像などの現実的なデータを生成することです。一方、ディスクリミネーターは、実際のデータと生成されたデータを区別しようとします。トレーニング フェーズ全体を通じて、ジェネレーターとディスクリミネーターは、競合学習プロセスを通じて継続的に機能を向上させます。ジェネレーターはますます説得力のあるデータを生成し、それによってディスクリミネーターを凌駕し、実際のサンプルと生成されたサンプルを区別する能力が向上します。この敵対的な相互作用は、ジェネレーターが高品質で現実的なデータを生成するときに最高潮に達します。サンプリング段階では、GAN トレーニング後に、ジェネレーターがランダムノイズを入力して新しいサンプルを生成します。このノイズを、一般的に実際の例を反映するデータに変換します。

なぜ別のモデル アーキテクチャが必要なのでしょうか?

どちらのモデルにも異なる問題があり、GAN はトレーニング セット内の画像に非常によく似たリアルな画像を生成するのに優れているのに対し、VAE はぼやけた画像を生成する傾向があるものの、多種多様な画像を作成することに優れています。しかし、既存のモデルでは、これら 2 つの機能をうまく組み合わせることができず、非常にリアルで多様性のある画像を作成できていません。この課題は研究者にとって取り組むべき大きなハードルとなります。

最初の GAN 論文から 6 年後、VAE 論文から 7 年後、画期的なモデルであるノイズ除去拡散確率モデル (DDPM) が登場しました。 DDPM は両方の長所を組み合わせ、多様でリアルな画像の作成に優れています。

この記事では、DDPM の複雑な部分を詳しく調べ、順方向と逆方向の両方のトレーニング プロセスをカバーし、サンプリングの実行方法を探ります。この調査を通して、PyTorch を使用して DDPM をゼロから構築し、完全なトレーニングを完了します。

すでにディープラーニングの基礎に精通しており、ディープラーニング コンピューター ビジョンの強固な基礎があることを前提としています。私たちはこれらの基本的な概念を紹介しません。私たちの目標は、人間が本物であると確信できる画像を生成することです。

DDPM

ノイズ除去拡散確率モデル (DDPM) は、生成モデルの分野における最先端の方法です。明示的な尤度関数に依存する従来のモデルとは異なり、DDPM は拡散プロセスを反復的にノイズ除去することによって動作します。これには、画像に徐々にノイズを追加し、そのノイズを除去しようとする作業が含まれます。基本理論は、ガウス分布などの単純な分布を一連の拡散ステップで変換することで、複雑で表現力豊かな画像データの分布が得られるという考えに基づいています。つまり、元の画像分布のサンプルをガウス分布に転送することで、プロセスを逆転させるモデルを作成できます。これにより、完全なガウス分布から始めて画像分布で終了することで、新しい画像を効果的に生成できます。

DDPM のトレーニングは、固定され学習不可能なノイズ画像を生成する順方向プロセスと、その後の逆プロセスという 2 つの基本的なステップで構成されます。逆プロセスの主な目的は、特殊な機械学習モデルを使用して画像のノイズを除去することです。

順方向拡散プロセス

前進プロセスは固定されており、学習不可能なステップですが、いくつかの事前定義された設定が必要です。設定に入る前に、まずそれがどのように機能するかを理解しましょう。

このプロセスの中心的な概念は、明確なイメージから始めることです。 「T」で表される特定のステップ サイズでは、ガウス分布に従って少量のノイズが徐々に導入されます。

画像からわかるように、ノイズはステップごとに増加しています。このノイズの数学的表現を詳しく見てみましょう。

ノイズはガウス分布からサンプリングされます。各ステップで少量のノイズを導入するために、マルコフ連鎖を使用します。現在のタイムスタンプで画像を生成するには、最後のタイムスタンプでの画像のみが必要です。ここで鍵となるのはマルコフ連鎖の概念であり、その後に続く数学的詳細にとって極めて重要です。

マルコフ連鎖は、特定の状態に移行する確率が、以前の一連のイベントではなく、現在の状態と経過した時間のみに依存する確率過程です。この機能により、ノイズ追加プロセスのモデリングが簡素化され、数学的な分析が容易になります。

ベータで示される分散パラメータは、各ステップで最小限のノイズのみを導入するために意図的に非常に小さい値に設定されています。

ステップ サイズ パラメータ「T」は、完全なノイズ イメージを生成するために必要なステップ サイズを決定します。この記事では、このパラメータは 1000 に設定されていますが、これは大きいと思われるかもしれません。データセット内の元の画像ごとに、ノイズの多い画像を 1000 枚作成することが本当に必要でしょうか? マルコフ連鎖の側面は、この問題の解決に役立ちます。次のステップを予測するには前のステップの画像のみが必要であり、各ステップで追加されるノイズは一定であるため、特定のタイムスタンプでノイズのある画像を生成することで計算を簡素化できます。再パラメータ化技術を使用すると、方程式をさらに簡素化できます。

式(3)で導入された新しいパラメータを式(2)に組み込み、式(2)を展開して結果を得る。

逆拡散プロセス

画像にノイズを導入したので、次のステップでは逆の操作を実行します。初期条件、つまり t = 0 でのノイズ除去されていない画像がわからない限り、画像を逆処理してノイズを除去することは数学的に不可能です。私たちの目標は、結果に関する情報が不足しているノイズから直接サンプリングして新しい画像を作成することです。したがって、結果を知らなくても段階的に画像のノイズを除去する方法を考案する必要があります。そこで、ディープラーニング モデルを使用してこの複雑な数学関数を近似するという解決策が生まれました。

少し数学的な背景があれば、モデルは式(5)に近似します。注目すべき点の 1 つは、モデルに分散を学習させることも可能ですが、オリジナルの DDPM 論文に固執し、分散を固定したままにすることです。

モデルのタスクは、現在のタイムスタンプと前のタイムスタンプの間に追加されたノイズの平均値を予測することです。そうすることで、ノイズを効果的に除去し、望ましい効果を得ることができます。しかし、モデルに「元の画像」から最後のタイムスタンプまでに追加されるノイズを予測させることが目標だとしたらどうでしょうか?

ノイズのない初期画像がわからない限り、逆のプロセスを数学的に実行するのは困難です。事後分散を定義することから始めましょう。

モデルのタスクは、初期画像からタイムスタンプ t で画像に追加されるノイズを予測することです。フォワードプロセスを使用すると、クリーンな画像から始めて、タイムスタンプ t のノイズの多い画像に進むという操作を実行できます。

トレーニングアルゴリズム

予測を行うために使用されるモデル アーキテクチャは U-Net であると想定します。トレーニングフェーズの目標は、データセット内の各画像に対して [0,T] の範囲のタイムスタンプをランダムに選択し、順方向拡散プロセスを計算することです。これにより、シャープで多少ノイズのある画像が生成されますが、このノイズは実際には役立ちます。次に、このモデルを使用して、逆プロセスに関する理解を活用し、画像に追加されるノイズを予測します。実際のノイズと予測されたノイズにより、教師あり機械学習の問題に入ったようです。

主な問題は、モデルをトレーニングするためにどの損失関数を使用すべきかということです。確率的潜在空間を扱っているため、Kullback-Leibler (KL) ダイバージェンスが適切な選択です。

KL ダイバージェンスは、2 つの確率分布 (この場合は、モデルによって予測された分布と期待される分布) の差を測定します。損失関数に KL ダイバージェンスを含めると、モデルが正確な予測を生成するように誘導されるだけでなく、潜在空間表現が目的の確率構造に準拠していることも保証されます。

KL ダイバージェンスは L2 損失関数として近似できるため、次の損失関数が得られます。

ついに、論文で提案されたトレーニング アルゴリズムが完成しました。

サンプリング

逆のプロセスについては説明しましたが、次にその使用方法を説明します。 T 時刻の完全にランダムな画像から開始し、逆のプロセスを T 回実行すると、最終的に時刻 0 に到達します。これは、この論文で概説した2番目のアルゴリズムを形成する。

パラメータ

beta、beta_tildes、alpha、alpha_hat など、さまざまなパラメーターがあります。これらのパラメータをどのように選択するかは不明です。しかし、この時点でわかっている唯一のパラメータは 1000 に設定されている T です。

リストされているすべてのパラメータの選択はベータによって異なります。ある意味では、ベータは各ステップで追加するノイズの量を決定します。したがって、アルゴリズムの成功を確実にするためには、ベータを慎重に選択することが重要です。他にもパラメータが多すぎるので、論文を参照してください。

元の論文の実験段階では、さまざまなサンプリング手法が検討されました。元の線形サンプリング方式の画像では、ノイズが不十分であったり、ノイズが多すぎたりしていました。この問題を解決するために、より一般的に使用される別の方法であるコサインサンプリングが採用されています。コサイン サンプリングにより、よりスムーズで一貫性のあるノイズの追加が可能になります。

モデル Pytorch 実装

ノイズ予測には U-Net アーキテクチャを利用します。U-Net を選択したのは、画像処理、空間マップと特徴マップのキャプチャ、入力と同じサイズの出力の提供に最適なアーキテクチャだからです。

タスクの複雑さと、各ステップで同じモデルを使用する必要があること(モデルは、完全にノイズの多い画像とわずかにノイズの多い画像の両方を同じ重みでノイズ除去できる必要がある)を考えると、モデルの調整が不可欠です。これには、より複雑なブロックを組み込み、正弦波埋め込みステップを介して使用されるタイムスタンプの認識を導入することが含まれます。これらの機能強化の目的は、ノイズ除去タスクにおいてモデルをエキスパートにすることです。完全なモデルの構築に進む前に、各ブロックを紹介します。

ConvNextブロック

モデルの複雑さが増すニーズを満たすために、畳み込みブロックが重要な役割を果たします。 u-net 論文の基本ブロックだけに頼るのではなく、ConvNext を組み合わせます。

入力は、画像を表す「x」と、サイズ「time_embedding_dim」の埋め込みのタイムスタンプ視覚化である「t」で構成されます。ブロックの複雑さと、入力および最後のレイヤーとの残余接続により、ブロックはプロセス全体における空間マッピングと特徴マッピングの学習において重要な役割を果たします。

 class ConvNextBlock(nn.Module): def __init__( self, in_channels, out_channels, mult=2, time_embedding_dim=None, norm=True, group=8, ): super().__init__() self.mlp = ( nn.Sequential(nn.GELU(), nn.Linear(time_embedding_dim, in_channels)) if time_embedding_dim else None ) self.in_conv = nn.Conv2d( in_channels, in_channels, 7, padding=3, groups=in_channels ) self.block = nn.Sequential( nn.GroupNorm(1, in_channels) if norm else nn.Identity(), nn.Conv2d(in_channels, out_channels * mult, 3, padding=1), nn.GELU(), nn.GroupNorm(1, out_channels * mult), nn.Conv2d(out_channels * mult, out_channels, 3, padding=1), ) self.residual_conv = ( nn.Conv2d(in_channels, out_channels, 1) if in_channels != out_channels else nn.Identity() ) def forward(self, x, time_embedding=None): h = self.in_conv(x) if self.mlp is not None and time_embedding is not None: assert self.mlp is not None, "MLP is None" h = h + rearrange(self.mlp(time_embedding), "bc -> bc 1 1") h = self.block(h) return h + self.residual_conv(x)

正弦波タイムスタンプの埋め込み

モデル内の重要なブロックの 1 つは、正弦波タイムスタンプ埋め込みブロックです。これにより、特定のタイムスタンプのエンコードで、モデルによるデコードに必要な現在の時刻に関する情報を保持できるようになります。この情報は、すべての異なるタイムスタンプに使用されます。

これは非常に古典的な実装であり、さまざまな場所に適用されています。コードを直接貼り付けます

class SinusoidalPosEmb(nn.Module): def __init__(self, dim, theta=10000): super().__init__() self.dim = dim self.theta = theta def forward(self, x): device = x.device half_dim = self.dim // 2 emb = math.log(self.theta) / (half_dim - 1) emb = torch.exp(torch.arange(half_dim, device=device) * -emb) emb = x[:, None] * emb[None, :] emb = torch.cat((emb.sin(), emb.cos()), dim=-1) return emb

ダウンサンプルとアップサンプル

class DownSample(nn.Module): def __init__(self, dim, dim_out=None): super().__init__() self.net = nn.Sequential( Rearrange("bc (h p1) (w p2) -> b (c p1 p2) hw", p1=2, p2=2), nn.Conv2d(dim * 4, default(dim_out, dim), 1), ) def forward(self, x): return self.net(x) class Upsample(nn.Module): def __init__(self, dim, dim_out=None): super().__init__() self.net = nn.Sequential( nn.Upsample(scale_factor=2, mode="nearest"), nn.Conv2d(dim, dim_out or dim, kernel_size=3, padding=1), ) def forward(self, x): return self.net(x)

時間的多層パーセプトロン

このモジュールは、指定されたタイムスタンプ t に基づいて時間表現を作成するためにこれを使用します。この多層パーセプトロン (MLP) の出力は、変更されたすべての ConvNext ブロックへの入力「t」としても機能します。

ここで、「dim」はモデルのハイパーパラメータであり、最初のブロックに必要なチャネル数を表します。これは、後続のブロックのチャネル数の基本計算として機能します。

 sinu_pos_emb = SinusoidalPosEmb(dim, theta=10000) time_dim = dim * 4 time_mlp = nn.Sequential( sinu_pos_emb, nn.Linear(dim, time_dim), nn.GELU(), nn.Linear(time_dim, time_dim), )

注意

これは、unet で使用されるオプションのコンポーネントです。注意を払うことは、学習における残りのつながりを強化するのに役立ちます。これは、残差接続によって計算される注意メカニズムと、中潜在空間と低潜在空間で計算される特徴マップを通じて、Unet の左側から取得される重要な空間情報に重点を置いています。これは ACC-UNet の論文から派生したものです。

ゲートは次のブロックのアップサンプリングされた出力を表し、xresidual は注意が適用されるレベルでの残差接続を表します。

 class BlockAttention(nn.Module): def __init__(self, gate_in_channel, residual_in_channel, scale_factor): super().__init__() self.gate_conv = nn.Conv2d(gate_in_channel, gate_in_channel, kernel_size=1, stride=1) self.residual_conv = nn.Conv2d(residual_in_channel, gate_in_channel, kernel_size=1, stride=1) self.in_conv = nn.Conv2d(gate_in_channel, 1, kernel_size=1, stride=1) self.relu = nn.ReLU() self.sigmoid = nn.Sigmoid() def forward(self, x: torch.Tensor, g: torch.Tensor) -> torch.Tensor: in_attention = self.relu(self.gate_conv(g) + self.residual_conv(x)) in_attention = self.in_conv(in_attention) in_attention = self.sigmoid(in_attention) return in_attention * x

最終統合

これまでに説明したすべてのブロック (注意ブロックを除く) を Unet に結合します。各ブロックには、1 つではなく 2 つの残余接続が含まれます。この変更は、潜在的な過剰適合の問題に対処するためのものです。

 class TwoResUNet(nn.Module): def __init__( self, dim, init_dim=None, out_dim=None, dim_mults=(1, 2, 4, 8), channels=3, sinusoidal_pos_emb_theta=10000, convnext_block_groups=8, ): super().__init__() self.channels = channels input_channels = channels self.init_dim = default(init_dim, dim) self.init_conv = nn.Conv2d(input_channels, self.init_dim, 7, padding=3) dims = [self.init_dim, *map(lambda m: dim * m, dim_mults)] in_out = list(zip(dims[:-1], dims[1:])) sinu_pos_emb = SinusoidalPosEmb(dim, theta=sinusoidal_pos_emb_theta) time_dim = dim * 4 self.time_mlp = nn.Sequential( sinu_pos_emb, nn.Linear(dim, time_dim), nn.GELU(), nn.Linear(time_dim, time_dim), ) self.downs = nn.ModuleList([]) self.ups = nn.ModuleList([]) num_resolutions = len(in_out) for ind, (dim_in, dim_out) in enumerate(in_out): is_last = ind >= (num_resolutions - 1) self.downs.append( nn.ModuleList( [ ConvNextBlock( in_channels=dim_in, out_channels=dim_in, time_embedding_dim=time_dim, group=convnext_block_groups, ), ConvNextBlock( in_channels=dim_in, out_channels=dim_in, time_embedding_dim=time_dim, group=convnext_block_groups, ), DownSample(dim_in, dim_out) if not is_last else nn.Conv2d(dim_in, dim_out, 3, padding=1), ] ) ) mid_dim = dims[-1] self.mid_block1 = ConvNextBlock(mid_dim, mid_dim, time_embedding_dim=time_dim) self.mid_block2 = ConvNextBlock(mid_dim, mid_dim, time_embedding_dim=time_dim) for ind, (dim_in, dim_out) in enumerate(reversed(in_out)): is_last = ind == (len(in_out) - 1) is_first = ind == 0 self.ups.append( nn.ModuleList( [ ConvNextBlock( in_channels=dim_out + dim_in, out_channels=dim_out, time_embedding_dim=time_dim, group=convnext_block_groups, ), ConvNextBlock( in_channels=dim_out + dim_in, out_channels=dim_out, time_embedding_dim=time_dim, group=convnext_block_groups, ), Upsample(dim_out, dim_in) if not is_last else nn.Conv2d(dim_out, dim_in, 3, padding=1) ] ) ) default_out_dim = channels self.out_dim = default(out_dim, default_out_dim) self.final_res_block = ConvNextBlock(dim * 2, dim, time_embedding_dim=time_dim) self.final_conv = nn.Conv2d(dim, self.out_dim, 1) def forward(self, x, time): b, _, h, w = x.shape x = self.init_conv(x) r = x.clone() t = self.time_mlp(time) unet_stack = [] for down1, down2, downsample in self.downs: x = down1(x, t) unet_stack.append(x) x = down2(x, t) unet_stack.append(x) x = downsample(x) x = self.mid_block1(x, t) x = self.mid_block2(x, t) for up1, up2, upsample in self.ups: x = torch.cat((x, unet_stack.pop()), dim=1) x = up1(x, t) x = torch.cat((x, unet_stack.pop()), dim=1) x = up2(x, t) x = upsample(x) x = torch.cat((x, r), dim=1) x = self.final_res_block(x, t) return self.final_conv(x)

拡散コードの実装

最後に、拡散がどのように達成されるかを紹介します。順方向、逆方向、およびサンプリング プロセスのすべての計算についてはすでに説明したので、ここではコードに焦点を当てます。

 class DiffusionModel(nn.Module): SCHEDULER_MAPPING = { "linear": linear_beta_schedule, "cosine": cosine_beta_schedule, "sigmoid": sigmoid_beta_schedule, } def __init__( self, model: nn.Module, image_size: int, *, beta_scheduler: str = "linear", timesteps: int = 1000, schedule_fn_kwargs: dict | None = None, auto_normalize: bool = True, ) -> None: super().__init__() self.model = model self.channels = self.model.channels self.image_size = image_size self.beta_scheduler_fn = self.SCHEDULER_MAPPING.get(beta_scheduler) if self.beta_scheduler_fn is None: raise ValueError(f"unknown beta schedule {beta_scheduler}") if schedule_fn_kwargs is None: schedule_fn_kwargs = {} betas = self.beta_scheduler_fn(timesteps, **schedule_fn_kwargs) alphas = 1.0 - betas alphas_cumprod = torch.cumprod(alphas, dim=0) alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.0) posterior_variance = ( betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod) ) register_buffer = lambda name, val: self.register_buffer( name, val.to(torch.float32) ) register_buffer("betas", betas) register_buffer("alphas_cumprod", alphas_cumprod) register_buffer("alphas_cumprod_prev", alphas_cumprod_prev) register_buffer("sqrt_recip_alphas", torch.sqrt(1.0 / alphas)) register_buffer("sqrt_alphas_cumprod", torch.sqrt(alphas_cumprod)) register_buffer( "sqrt_one_minus_alphas_cumprod", torch.sqrt(1.0 - alphas_cumprod) ) register_buffer("posterior_variance", posterior_variance) timesteps, *_ = betas.shape self.num_timesteps = int(timesteps) self.sampling_timesteps = timesteps self.normalize = normalize_to_neg_one_to_one if auto_normalize else identity self.unnormalize = unnormalize_to_zero_to_one if auto_normalize else identity @torch.inference_mode() def p_sample(self, x: torch.Tensor, timestamp: int) -> torch.Tensor: b, *_, device = *x.shape, x.device batched_timestamps = torch.full( (b,), timestamp, device=device, dtype=torch.long ) preds = self.model(x, batched_timestamps) betas_t = extract(self.betas, batched_timestamps, x.shape) sqrt_recip_alphas_t = extract( self.sqrt_recip_alphas, batched_timestamps, x.shape ) sqrt_one_minus_alphas_cumprod_t = extract( self.sqrt_one_minus_alphas_cumprod, batched_timestamps, x.shape ) predicted_mean = sqrt_recip_alphas_t * ( x - betas_t * preds / sqrt_one_minus_alphas_cumprod_t ) if timestamp == 0: return predicted_mean else: posterior_variance = extract( self.posterior_variance, batched_timestamps, x.shape ) noise = torch.randn_like(x) return predicted_mean + torch.sqrt(posterior_variance) * noise @torch.inference_mode() def p_sample_loop( self, shape: tuple, return_all_timesteps: bool = False ) -> torch.Tensor: batch, device = shape[0], "mps" img = torch.randn(shape, device=device) # This cause me a RunTimeError on MPS device due to MPS back out of memory # No ideas how to resolve it at this point # imgs = [img] for t in tqdm(reversed(range(0, self.num_timesteps)), total=self.num_timesteps): img = self.p_sample(img, t) # imgs.append(img) ret = img # if not return_all_timesteps else torch.stack(imgs, dim=1) ret = self.unnormalize(ret) return ret def sample( self, batch_size: int = 16, return_all_timesteps: bool = False ) -> torch.Tensor: shape = (batch_size, self.channels, self.image_size, self.image_size) return self.p_sample_loop(shape, return_all_timesteps=return_all_timesteps) def q_sample( self, x_start: torch.Tensor, t: int, noise: torch.Tensor = None ) -> torch.Tensor: if noise is None: noise = torch.randn_like(x_start) sqrt_alphas_cumprod_t = extract(self.sqrt_alphas_cumprod, t, x_start.shape) sqrt_one_minus_alphas_cumprod_t = extract( self.sqrt_one_minus_alphas_cumprod, t, x_start.shape ) return sqrt_alphas_cumprod_t * x_start + sqrt_one_minus_alphas_cumprod_t * noise def p_loss( self, x_start: torch.Tensor, t: int, noise: torch.Tensor = None, loss_type: str = "l2", ) -> torch.Tensor: if noise is None: noise = torch.randn_like(x_start) x_noised = self.q_sample(x_start, t, noise=noise) predicted_noise = self.model(x_noised, t) if loss_type == "l2": loss = F.mse_loss(noise, predicted_noise) elif loss_type == "l1": loss = F.l1_loss(noise, predicted_noise) else: raise ValueError(f"unknown loss type {loss_type}") return loss def forward(self, x: torch.Tensor) -> torch.Tensor: b, c, h, w, device, img_size = *x.shape, x.device, self.image_size assert h == w == img_size, f"image size must be {img_size}" timestamp = torch.randint(0, self.num_timesteps, (1,)).long().to(device) x = self.normalize(x) return self.p_loss(x, timestamp)

拡散プロセスはモデルのトレーニング部分です。すでにトレーニング済みのモデルを使用してサンプルを生成できるサンプリング インターフェイスが開きます。

トレーニングポイントのまとめ

トレーニング部分では、ステップごとに 16 バッチで 37,000 のトレーニング ステップを設定しました。 GPU メモリ割り当ての制限により、画像サイズは 128x128 に制限されます。モデルの重みの指数移動平均 (EMA) を使用して 1000 ステップごとにサンプルを生成し、サンプリングを平滑化し、モデルのバージョンを保存します。

トレーニングの最初の 1000 ステップでは、モデルはいくつかの特徴を捉え始めますが、まだいくつかの領域を捉えていません。ステップ 10,000 あたりで、モデルは有望な結果を生成し始め、進歩がより顕著になります。 30,000 ステップの終了時点で、結果の品質は大幅に向上しましたが、黒い画像はまだ残っています。これは、モデルに十分なサンプル タイプがなく、実際の画像のデータ分布がガウス分布に完全にマッピングされていないためです。

最終的なモデルの重みがわかれば、いくつかの画像を生成できます。 128 x 128 のサイズ制限により画像品質は制限されますが、モデルのパフォーマンスは十分に優れています。

注: この記事で使用されているデータセットは森林地形の衛星画像です。具体的な取得方法については、ソースコードの ETL セクションを参照してください。

要約する

拡散モデルに関する必要な知識を完全に紹介し、Pytorch を使用して完全に実装しました。この記事のコードは次のとおりです。

https://github.com/Camaltra/this-is-not-real-aerial-imagery/

<<: 

>>:  AIがあなたが何歳で死ぬかを予測?トランスフォーマーの「占い」がネイチャーのサブジャーナルに掲載され、事故死の予測に成功

ブログ    
ブログ    
ブログ    

推薦する

ディープインテリジェンスとは: 2021 年のディープインテリジェンスのトレンドは何ですか?

人工知能の開発は60年以上前に遡りますが、技術的な理由により、ディープラーニングの出現により再び人工...

ディープラーニングを使用してフロントエンドデザインモデルをコードに自動的に変換する方法は?

[[223504]]現在、フロントエンド開発の自動化に対する最大の障壁はコンピューティング能力です...

2023年に出現するサイバー脅威、AI、量子コンピューティング、データ汚染まで

ハッカーや詐欺師が新しいテクノロジーを入手したり、古い脆弱性を悪用する新しい方法を考え出したりするに...

MITはディープラーニングが計算限界に近づいていると警告。ネットユーザー:減速は良いことだ

MIT の調査によると、ディープラーニングは計算能力の限界に近づいているようです。 [[334431...

映画の好みを予測しますか?オートエンコーダを使用して協調フィルタリングを実装する方法

推奨システムは、協調フィルタリングを使用して、ユーザーの好み情報を収集し、特定のユーザーの興味を予測...

アリババの無人車が路上試験を開始、BATの3大巨頭が同じ舞台に集結

テンセントと百度の自動運転車はアリババを上回っており、自動運転分野でのBATの戦いがまもなく始まるか...

人工知能は、企業がエンドツーエンドのインテリジェントな自動化を実現することを促進します。

[[401604]]新型コロナウイルスによる混乱に対応するため、組織が急いでビジネスプロセスを適応...

「顔認識」時代の準備はできていますか?

[51CTO.comからのオリジナル記事] 近年、生体認証技術はますます成熟し、私たちの生活の中に...

...

...

GitHub が機械学習コードの脆弱性スキャンを無料で提供、JavaScript / TypeScript もサポート

現在、JavaScript および TypeScript リポジトリで開発およびテストが行​​われて...

ChatGPT-4、Bard、Claude-2、Copilot空間タスクの正確性の比較

大規模言語モデル (LLM) を含む生成 AI は、エンコード、空間計算、サンプル データ生成、時系...

DeepMind の新しい研究: ReST は大規模なモデルを人間の好みに合わせて調整し、オンライン RLHF よりも効果的です

過去数か月間、私たちは大規模言語モデル (LLM) が高品質のテキストを生成し、幅広い言語タスクを解...

...

AIと機械学習がIoTデータから重要な洞察を引き出す方法

過去数年間、モノのインターネットに関する議論の多くは、接続されたデバイス自体、つまりそれが何であるか...