Conditional Image Generation with Classifier-Free Guidance¶
%load_ext autoreload
%autoreload 2
import numpy as np
import matplotlib.pyplot as plt
from typing import Callable, List
from itertools import product
from tqdm.notebook import tqdm
from IPython.display import HTML
import seaborn as sns
from IPython.display import Video
import os
import pandas as pd
import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
from torch.utils.data import Dataset
from torch import Tensor
from abc import ABC, abstractmethod
from torch.nn.functional import relu
from torch.utils.data.dataloader import DataLoader
import scipy.stats as st
import torch.nn.functional as F
from torch.optim.lr_scheduler import LRScheduler
from sampling import run_backwards
Introduction. In our previous tutorial, we showed how to build a conditional image generation model: by training a classifier on noisy data, we could successfully turn an unconditional diffusion model into a conditional one. Great!
What is the problem with what we have? There are two problems with that approach: one is obvious, the other a bit more intricate:
- Training an extra classifier: as we learnt, we need to train an extra classifier on noisy data. This takes approximately as much time as training the diffusion model itself, so we essentially double our work. We can't use an off-the-shelf classifier (e.g. an ImageNet classifier) either, as we need it to be trained on noisy data.
- Adversarial generation: there is an interesting question as to what extent classifier guidance actually generates a realistic image, or merely an image that the classifier considers most likely to belong to a certain class. In a way, one might object that we essentially perform an adversarial attack on the image classifier (i.e. make it believe that the image belongs to a certain class) rather than generating a genuinely realistic image.
Main idea. The main idea behind classifier-free guidance is to train a conditional and an unconditional generative model jointly within a single network. One can then balance the conditional and unconditional parts to ensure that the generated images are highly realistic while also fulfilling the desired condition.
We again consider the task of sampling from the distribution $p(x|y)$ with diffusion models. As before, we gradually convert the data distribution $p_{0}(x|y)=p(x|y)$ into a noised distribution $p_{1}(x|y)$ over time via an SDE with $X_t\sim p_{t}(x|y)$ for all $0\leq t \leq 1$, and we want an approximation of the score $\nabla_{x_t} \log p_t(x_t|y)$ for a conditional variable $y$. But first, let's do a quick reminder of our current setup.
1. Reminder - diffusion models with SDEs¶
We again assume that we are in the scenario where we run a forward and backward SDE as in our previous tutorials. In particular, there is a forward SDE: $$ \begin{align*} dX_t = f(X_t,t)dt + g(t)dW_t \end{align*} $$ with $X_0\sim p_{\text{data}}=p_0$ and $p_{1} \approx \mathcal{N}(0,\mathbb{V}[X_1])$, and the drift coefficient is affine, i.e. $f(x,t)=a(t)x+b(t)$. As we saw, this allows us to compute closed analytical formulas for the conditional mean $m(x,t)=\mathbb{E}[X_t|X_0=x]$ and the conditional variance $v(t)=\mathbb{V}[X_t|X_0]$, which is independent of $X_0$.
In our case, we use the variance-preserving SDE (see here for an explanation), where $$\begin{align*} f(x,t)&=-x\frac{\beta_{\text{min}}+t(\beta_{\text{max}}-\beta_{\text{min}})}{2}\\ g(t)&=\sqrt{\beta_{\text{min}}+t(\beta_{\text{max}}-\beta_{\text{min}})}\\ m(x,t)&=x\exp\left(-\frac{1}{4}t^2(\beta_{\text{max}}-\beta_{\text{min}})-\frac{1}{2}t\beta_{\text{min}}\right)\\ v(t)&=1-\exp\left(-\frac{1}{2}t^2(\beta_{\text{max}}-\beta_{\text{min}})-t\beta_{\text{min}}\right) \end{align*}$$
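For concreteness, here is a minimal standalone sketch of these four quantities as plain PyTorch functions (illustrative only: the actual implementation used below is the VPSDE class imported from the accompanying sde module, and the hyperparameter values are simply the ones we pass to it in the next cell):
# Illustrative sketch of the VP-SDE quantities above, not the VPSDE class itself.
BETA_MIN, BETA_MAX = 0.01, 10.0   # assumed values, matching the VPSDE instantiation below
def beta(t):
    t = torch.as_tensor(t)
    return BETA_MIN + t * (BETA_MAX - BETA_MIN)
def f(x, t):
    return -0.5 * beta(t) * x          # drift f(x,t)
def g(t):
    return torch.sqrt(beta(t))         # diffusion coefficient g(t)
def m(x, t):
    t = torch.as_tensor(t)
    return x * torch.exp(-0.25 * t**2 * (BETA_MAX - BETA_MIN) - 0.5 * t * BETA_MIN)   # E[X_t | X_0 = x]
def v(t):
    t = torch.as_tensor(t)
    return 1.0 - torch.exp(-0.5 * t**2 * (BETA_MAX - BETA_MIN) - t * BETA_MIN)        # Var[X_t | X_0]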
Let's define the SDE
from sde import VPSDE
sde = VPSDE(T_max=1,beta_min=0.01, beta_max=10.0)
and visualize the forward SDE with MNIST:
def load_mnist():
image_size = 28
transform = transforms.Compose([transforms.Resize(image_size),\
transforms.ToTensor(),\
transforms.Normalize([0.5],[0.5])]) #Normalize to -1,1
trainset = torchvision.datasets.MNIST(root='./data', train=True,
download=True, transform=transform)
batch_size = 256
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,
shuffle=True, num_workers=2)
return image_size, trainloader, trainset
def imshow(img):
img = img / 2 + 0.5 # unnormalize
npimg = img.numpy()
plt.figure(figsize=[20, 20])
plt.imshow(np.transpose(npimg, (1, 2, 0)))
plt.show()
def visualize_forward_sde(X_0):
    n_grid_points = 8
    X_0 = torch.stack([X_0]*n_grid_points)
    time_vec = torch.linspace(0, 1.0, n_grid_points)**2  # quadratic time grid in [0,1]
    X_t, noise, score = sde.run_forward(X_0, time_vec)
imshow(torchvision.utils.make_grid(X_t.unsqueeze(1)))
image_size, trainloader, trainset = load_mnist()
X_0 = trainset[20130][0].squeeze()
visualize_forward_sde(X_0)
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
2. Classifier-free network architecture¶
Going from conditional generation to a classifier. Let's imagine that we had a conditional generative model $p_t(x_t|y)$ available. Then we could simply invert it into a classifier via Bayes' rule: \begin{align*} p_t(y|x_t) = \frac{p_t(x_t|y)}{p_t(x_t)}p_t(y) \end{align*} And if we wanted to do classifier guidance with this implicit classifier, we would compute the classifier gradient $c(x_t,t)$ by \begin{align*} c(x_t,t) = \nabla_{x_t} \log p_t(y|x_t) = \nabla_{x_t} \log p_t(x_t|y) - \nabla_{x_t} \log p_t(x_t) \end{align*} To sample from a conditional diffusion model, we would then run the reverse SDE with a combined score in which the classifier gradient is weighted by a guidance weight $w\in\mathbb{R}$: \begin{align*} \tilde{s}_w(x_t,t,y) :=& \nabla_{x_t} \log p_t(x_t|y) + w\,c(x_t,t) \\ =& (1+w)\nabla_{x_t} \log p_t(x_t|y) - w\,\nabla_{x_t} \log p_t(x_t) \end{align*} This is very intuitive: we move in the direction of the score of the conditional model, but because we scale that score up by a factor $(1+w)$, we need to subtract the score of the unconditional model. In theory, there is no better or worse $w$; in fact, $w=0$ should give equally good results. However, it turns out that in practice (for some delicate reasons I will explain later), choosing $w$ well is important to achieve high sample quality.
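As a tiny code sketch (the variable names here are mine, not from any library): given score estimates score_cond $\approx \nabla_{x_t}\log p_t(x_t|y)$ and score_uncond $\approx \nabla_{x_t}\log p_t(x_t)$, the guided score is just their weighted combination:
def guided_score(score_cond, score_uncond, w):
    # classifier-free guided score: (1+w) * conditional - w * unconditional
    # w = 0 recovers the plain conditional score.
    return (1.0 + w) * score_cond - w * score_uncond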
Training with classifier-free guidance. The idea of classifier-free guidance is to train a conditional and an unconditional generative model within a single network simultaneously and use the above equation to get the guided score. We simply introduce a special token $\varnothing$ indicating that we do not specify a class. Our variable $y$ can then either be a class label or $\varnothing$. More specifically, our denoising neural network $\epsilon_{\theta}(x_t,t,y)$ is again a rescaled score network: $$\begin{align*} \epsilon_{\theta}(x_t,t,y) = -\sqrt{v(t)}\,s_{\theta}(x_t,t,y) \end{align*}$$ and takes three inputs:
- The noised image/data point $x_t$
- A time variable $t\in [0,1]$.
- A conditional variable $y$ which is either $\varnothing$ or a class variable.
Let's define the variable y.
classes_by_index = np.arange(0,10).tolist()
token_variables = classes_by_index + [10]
token_variables
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
Let's define our neural architecture.
Essentially, we are going to use the same U-Net that we used for the unconditional generation case (see here). However, we do not only condition on time but also on our class label $y$. We do this in a similar way to the time embedding, the only difference being that here we have a discrete set of labels $y=0,\dots,9,10$, where $10$ stands for the $\varnothing$ token. We use a standard embedding matrix that maps the labels to high-dimensional vectors whose entries are parameters learnt during training.
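As a small illustration of this label-embedding idea (a sketch with assumed dimensions, not the exact layer configuration inside the U-Net below):
# Illustrative only: 11 tokens (digits 0-9 plus the ∅ token with index 10), 256-dimensional embeddings.
label_embedding = nn.Embedding(num_embeddings=11, embedding_dim=256)
y_example = torch.tensor([3, 10, 7])   # the digit 3, the ∅ token, the digit 7
y_emb = label_embedding(y_example)     # shape (3, 256); these vectors are learnt during training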
class ChannelShuffle(nn.Module):
def __init__(self,groups):
super().__init__()
self.groups=groups
def forward(self,x):
n,c,h,w=x.shape
x=x.view(n,self.groups,c//self.groups,h,w) # group
x=x.transpose(1,2).contiguous().view(n,-1,h,w) #shuffle
return x
class ConvBnSiLu(nn.Module):
def __init__(self,in_channels,out_channels,kernel_size,stride=1,padding=0):
super().__init__()
self.module=nn.Sequential(nn.Conv2d(in_channels,out_channels,kernel_size,stride=stride,padding=padding),
nn.BatchNorm2d(out_channels),
nn.SiLU(inplace=True))
def forward(self,x):
return self.module(x)
class ResidualBottleneck(nn.Module):
'''
shufflenet_v2 basic unit(https://arxiv.org/pdf/1807.11164.pdf)
'''
def __init__(self,in_channels,out_channels):
super().__init__()
self.branch1=nn.Sequential(nn.Conv2d(in_channels//2,in_channels//2,3,1,1,groups=in_channels//2),
nn.BatchNorm2d(in_channels//2),
ConvBnSiLu(in_channels//2,out_channels//2,1,1,0))
self.branch2=nn.Sequential(ConvBnSiLu(in_channels//2,in_channels//2,1,1,0),
nn.Conv2d(in_channels//2,in_channels//2,3,1,1,groups=in_channels//2),
nn.BatchNorm2d(in_channels//2),
ConvBnSiLu(in_channels//2,out_channels//2,1,1,0))
self.channel_shuffle=ChannelShuffle(2)
def forward(self,x):
x1,x2=x.chunk(2,dim=1)
x=torch.cat([self.branch1(x1),self.branch2(x2)],dim=1)
x=self.channel_shuffle(x) #shuffle two branches
return x
class ResidualDownsample(nn.Module):
'''
shufflenet_v2 unit for spatial down sampling(https://arxiv.org/pdf/1807.11164.pdf)
'''
def __init__(self,in_channels,out_channels):
super().__init__()
self.branch1=nn.Sequential(nn.Conv2d(in_channels,in_channels,3,2,1,groups=in_channels),
nn.BatchNorm2d(in_channels),
ConvBnSiLu(in_channels,out_channels//2,1,1,0))
self.branch2=nn.Sequential(ConvBnSiLu(in_channels,out_channels//2,1,1,0),
nn.Conv2d(out_channels//2,out_channels//2,3,2,1,groups=out_channels//2),
nn.BatchNorm2d(out_channels//2),
ConvBnSiLu(out_channels//2,out_channels//2,1,1,0))
self.channel_shuffle=ChannelShuffle(2)
def forward(self,x):
x=torch.cat([self.branch1(x),self.branch2(x)],dim=1)
x=self.channel_shuffle(x) #shuffle two branches
return x
class TimeMLP(nn.Module):
'''
    naively introduce timestep information to the feature maps with an MLP and add a shortcut
'''
def __init__(self,embedding_dim,hidden_dim,out_dim):
super().__init__()
self.mlp=nn.Sequential(nn.Linear(embedding_dim,hidden_dim),
nn.SiLU(),
nn.Linear(hidden_dim,out_dim))
self.act=nn.SiLU()
def forward(self,x,t):
t_emb=self.mlp(t).unsqueeze(-1).unsqueeze(-1)
x=x+t_emb
return self.act(x)
class ConditionMLP(nn.Module):
'''
    naively introduce conditional information to the feature maps with an MLP and add a shortcut
'''
def __init__(self,embedding_dim,hidden_dim,out_dim):
super().__init__()
self.mlp=nn.Sequential(nn.Linear(embedding_dim,hidden_dim),
nn.SiLU(),
nn.Linear(hidden_dim,out_dim))
self.act=nn.SiLU()
def forward(self,x,y):
y_emb=self.mlp(y).unsqueeze(-1).unsqueeze(-1)
x=x+y_emb
return self.act(x)
class EncoderBlock(nn.Module):
def __init__(self,in_channels,out_channels,time_embedding_dim, cond_embedding_dim):
super().__init__()
self.conv0=nn.Sequential(*[ResidualBottleneck(in_channels,in_channels) for i in range(3)],
ResidualBottleneck(in_channels,out_channels//2))
self.time_mlp=TimeMLP(embedding_dim=time_embedding_dim, hidden_dim=out_channels, out_dim=out_channels//2)
self.y_mlp=ConditionMLP(embedding_dim=cond_embedding_dim, hidden_dim=out_channels, out_dim=out_channels//2)
self.conv1=ResidualDownsample(out_channels//2,out_channels)
def forward(self,x, t=None, y=None):
        x_shortcut=self.conv0(x)
        x=x_shortcut
        if y is not None:
            x=self.y_mlp(x,y)       # add label conditioning
        if t is not None:
            x=self.time_mlp(x,t)    # add time conditioning on top
x=self.conv1(x)
return [x,x_shortcut]
class DecoderBlock(nn.Module):
def __init__(self,in_channels,out_channels,time_embedding_dim, cond_embedding_dim):
super().__init__()
self.upsample=nn.Upsample(scale_factor=2,mode='bilinear',align_corners=False)
self.conv0=nn.Sequential(*[ResidualBottleneck(in_channels,in_channels) for i in range(3)],
ResidualBottleneck(in_channels,in_channels//2))
self.time_mlp=TimeMLP(embedding_dim=time_embedding_dim,hidden_dim=in_channels,out_dim=in_channels//2)
self.y_mlp=ConditionMLP(embedding_dim=cond_embedding_dim,hidden_dim=in_channels,out_dim=in_channels//2)
self.conv1=ResidualBottleneck(in_channels//2,out_channels//2)
def forward(self,x,x_shortcut,t=None, y=None):
x=self.upsample(x)
x=torch.cat([x,x_shortcut],dim=1)
x=self.conv0(x)
if y is not None:
x=self.y_mlp(x,y)
if t is not None:
x=self.time_mlp(x,t)
x=self.conv1(x)
return x
class Unet(nn.Module):
'''
simple unet design without attention
'''
def __init__(self, token_variables, token_embedding_dim, timesteps, time_embedding_dim, in_channels=3, out_channels=2,base_dim=32, dim_mults=[2,4,8,16], temp: float = 20.0):
super().__init__()
self.token_variables = set(token_variables)
self.n_tokens = len(token_variables)
self.token_embedding_dim = token_embedding_dim
assert isinstance(dim_mults,(list,tuple))
assert base_dim%2==0
channels=self._cal_channels(base_dim,dim_mults)
self.init_conv=ConvBnSiLu(in_channels,base_dim,3,1,1)
self.time_embedding=nn.Linear(timesteps,time_embedding_dim)
self.token_embedding=torch.nn.Embedding(self.n_tokens, token_embedding_dim)
self.encoder_blocks=nn.ModuleList([EncoderBlock(c[0],c[1],time_embedding_dim=time_embedding_dim,cond_embedding_dim=token_embedding_dim) for c in channels])
self.decoder_blocks=nn.ModuleList([DecoderBlock(c[1],c[0],time_embedding_dim=time_embedding_dim,cond_embedding_dim=token_embedding_dim) for c in channels[::-1]])
self.mid_block=nn.Sequential(*[ResidualBottleneck(channels[-1][1],channels[-1][1]) for i in range(2)],
ResidualBottleneck(channels[-1][1],channels[-1][1]//2))
self.final_conv=nn.Conv2d(in_channels=channels[0][0]//2,out_channels=out_channels,kernel_size=1)
self.centers = nn.Parameter(torch.linspace(0,1,timesteps+1)[:-1]+0.5/timesteps,requires_grad=False)
self.temp = temp
def get_softmax(self, t):
softmax_mat = F.softmax(-self.temp*torch.abs(t[:,None]-self.centers[None,:]), dim=1)
return softmax_mat
def get_time_emb(self,t):
softmax_mat = self.get_softmax(t)
t=self.time_embedding(softmax_mat)
return t
def get_cond_emb(self,y):
return self.token_embedding(y)
def forward(self,x,t=None, y=None):
x=self.init_conv(x)
if t is not None:
t = self.get_time_emb(t)
if y is not None:
y = self.get_cond_emb(y)
encoder_shortcuts=[]
for encoder_block in self.encoder_blocks:
x,x_shortcut=encoder_block(x,t=t,y=y)
encoder_shortcuts.append(x_shortcut)
x=self.mid_block(x)
encoder_shortcuts.reverse()
for decoder_block,shortcut in zip(self.decoder_blocks,encoder_shortcuts):
x=decoder_block(x,shortcut,t=t, y=y)
x=self.final_conv(x)
return x
def _cal_channels(self,base_dim,dim_mults):
dims=[base_dim*x for x in dim_mults]
dims.insert(0,base_dim)
channels=[]
for i in range(len(dims)-1):
channels.append((dims[i],dims[i+1])) # in_channel, out_channel
return channels
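A quick aside on the time embedding above: get_softmax soft-assigns a continuous time $t\in[0,1]$ to the timesteps bin centers, so the linear time_embedding layer effectively interpolates between per-bin embeddings. A minimal standalone sketch of that soft assignment (with assumed example values):
# Standalone illustration of the soft time-binning used in Unet.get_softmax (example values only).
timesteps_demo, temp_demo = 100, 100.0
centers_demo = torch.linspace(0, 1, timesteps_demo + 1)[:-1] + 0.5 / timesteps_demo   # bin centers in (0,1)
t_demo = torch.tensor([0.25, 0.8])
weights_demo = F.softmax(-temp_demo * torch.abs(t_demo[:, None] - centers_demo[None, :]), dim=1)
# each row of weights_demo is a probability vector over the 100 bins, peaked at the bin closest to the given t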
model = Unet(token_variables=token_variables, base_dim=image_size, in_channels=1, out_channels=1, token_embedding_dim=256, time_embedding_dim=256, timesteps=100, dim_mults=[2, 4], temp=100.0)
model = torch.compile(model)
Let's see whether we did a good job and pass an example input through our network:
for idx, batch in enumerate(trainloader):
break
x = batch[0]
y = batch[1]
t = torch.rand(len(y))
print("x.shape: ", x.shape)
print("y.shape: ", y.shape)
print("t.shape: ", t.shape)
output = model(x,t=t,y=y)
print("output.shape: ", output.shape)
x.shape: torch.Size([256, 1, 28, 28]) y.shape: torch.Size([256]) t.shape: torch.Size([256]) output.shape: torch.Size([256, 1, 28, 28])
Great, we get the desired output shape (same as the shape of x)!
3. Classifier-Free Guidance Training¶
As said above, our goal is to approximate $$\begin{align*} s_{\theta}(x_t,t,y) \approx \begin{cases} \nabla_{x_t}\log p_t(x_t)\quad&\text{if } y=\varnothing\\ \nabla_{x_t}\log p_t(x_t|y)\quad&\text{if } y\neq\varnothing \end{cases} \end{align*}$$ with a single denoising model $\epsilon_{\theta}(x_t,t,y)=-\sqrt{v(t)}\,s_{\theta}(x_t,t,y)$. We can train it by standard denoising score matching, i.e. by predicting the noise that perturbed our data point. In essence, we train the conditional and the unconditional case at the same time.
Therefore, it is natural to design our training algorithm as follows:
- Sample $(x,y)\sim p_{\text{data}}(x,y)$ (i.e. randomly draw images and labels from our train data)
- With probability $p_{\text{uncond}}$, set $y=\varnothing$ (i.e. randomly drop the label)
- Sample random noise $\epsilon\sim\mathcal{N}(0,\mathbf{I})$
- Run SDE forward, i.e. set: $x_t =m(x,t) + \sqrt{v(t)}\epsilon$
- Compute loss: $$\begin{align*} L(\theta) = \|\epsilon_{\theta}(x_t,t,y)-\epsilon\|^2 \end{align*} $$
- Take gradient step on $\theta$ minimizing $L(\theta)$.
Of course, we perform the above computation on mini-batches of data.
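In code, the label-dropping step (the second bullet above) boils down to a single masked assignment. Here is a minimal sketch with made-up names (labels, p_uncond), before the full training loop below:
# Illustrative: replace each label by the ∅ token (index 10) with probability p_uncond.
p_uncond = 0.5
labels = torch.randint(0, 10, size=(8,))          # a dummy batch of class labels
drop_mask = torch.rand(len(labels)) <= p_uncond   # which samples lose their label
labels[drop_mask] = 10                            # 10 encodes the ∅ token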
from sde import ItoSDE
def train_clfree_guidance(model, sde: ItoSDE, dataloader: DataLoader, optimizer, device, n_epochs: int, print_every: int, scheduler: LRScheduler = None, p_uncond: float = 0.5, p_uncond_label: int = 10, train_score: bool = False):
model.train()
    model = model.to(device)
running_loss_list = []
lr_list = []
loss_function = nn.MSELoss(reduction='mean')
for epoch in range(n_epochs):
print(f"Epoch: {epoch}")
running_loss = 0.0
        for idx, (x_inp,target) in tqdm(enumerate(dataloader), total=len(dataloader)):
            #Randomly replace labels by the unconditional token with probability p_uncond:
            random_uncond_mask = (torch.rand(size=(len(x_inp),))<=p_uncond)
            target[random_uncond_mask] = p_uncond_label
#Zero gradients:
optimizer.zero_grad()
#Run forward samples:
X_t,noise,score,time = sde.run_forward_random_times(x_inp)
#Send to device:
            X_t = X_t.to(device)
            noise = noise.to(device)
            time = time.to(device)
            target = target.to(device)
#Predict score:
model_pred = model(x=X_t, t=time, y=target)
#Compute loss:
if train_score:
loss = loss_function(score,model_pred)
else:
loss = loss_function(noise,model_pred)
#Optimize:
loss.backward()
optimizer.step()
if scheduler is not None:
scheduler.step()
# print statistics
running_loss += loss.detach().item()
if (idx+1) % print_every == 0:
avg_loss = running_loss/print_every
running_loss_list.append(avg_loss)
running_loss = 0.0
if scheduler is not None:
print(f"Loss: {avg_loss:.4f} | {scheduler.get_lr()}")
lr_list.append(scheduler.get_lr())
else:
print(f"Loss: {avg_loss:.4f}")
return model,running_loss_list
Let's run the training. I would recommend loading the pre-trained model :)
LEARNING_RATE = 1e-2 #2e-5
WEIGHT_DECAY = 0.0
N_EPOCHS = 100
RETRAIN = False
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
TRAIN_SCORE = False
if RETRAIN:
optimizer = torch.optim.AdamW(model.parameters(),lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY,maximize=False)
scheduler=torch.optim.lr_scheduler.OneCycleLR(optimizer,LEARNING_RATE,total_steps=N_EPOCHS*len(trainloader),pct_start=0.25,anneal_strategy='cos')
model,running_loss_list = train_clfree_guidance(model, sde, trainloader, train_score=TRAIN_SCORE, optimizer=optimizer, scheduler=scheduler, device=DEVICE, n_epochs=N_EPOCHS, print_every=100)
torch.save(model.state_dict(),"20231218_mnist_diffusion_denoiser_full_training.ckpt")
else:
    model_state_dict = torch.load("20231218_mnist_diffusion_denoiser_full_training.ckpt", map_location=DEVICE)
model.load_state_dict(model_state_dict)
4. Classifier-Free Guidance Sampling¶
Finally, if we want to sample from our diffusion model, we execute the following algorithm:
Given a conditional variable $y$, a guidance weight $w$ and a number of grid points $k$:
- Initialize: $x_{1}\sim\mathcal{N}(0,\mathbf{I}_{d})$.
- For $i=1,\dots,k$: set $t_i = 1-i/k$, sample $\epsilon_i\sim\mathcal{N}(0,\mathbf{I}_{d})$ and perform a time-reversal step of the SDE with the classifier-free guidance score:
$$ S = -\frac{1}{\sqrt{v(t_{i-1})}}\left[(1+w)\,\epsilon_{\theta}(x_{t_{i-1}},t_{i-1},y)-w\,\epsilon_{\theta}(x_{t_{i-1}},t_{i-1},\varnothing)\right] $$ $$\begin{align*} x_{t_i} = x_{t_{i-1}} + \frac{g^2(t_{i-1})\,S-f(x_{t_{i-1}},t_{i-1})}{k} + \frac{g(t_{i-1})}{\sqrt{k}}\,\epsilon_i \end{align*}$$
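The helper run_backwards imported from sampling implements this loop for us; its internals are not shown in this notebook, so the following is only a rough sketch of what a single guided reverse step could look like (the function name, its arguments and the hard-coded SDE constants are illustrative assumptions, not the actual sampling API):
import math
def guided_reverse_step(x, t, y, w, dt, uncond_token=10):
    # One reverse-time Euler-Maruyama step with classifier-free guidance (illustrative sketch).
    # Assumes the Unet `model` from above; t and dt are Python floats, y an integer class label.
    n = len(x)
    t_vec = torch.full((n,), t)
    eps_cond = model(x, t=t_vec, y=torch.full((n,), y, dtype=torch.long))
    eps_uncond = model(x, t=t_vec, y=torch.full((n,), uncond_token, dtype=torch.long))
    eps_guided = (1.0 + w) * eps_cond - w * eps_uncond
    # VP-SDE quantities at time t (beta_min=0.01, beta_max=10.0 as above):
    beta_t = 0.01 + t * (10.0 - 0.01)
    v_t = 1.0 - math.exp(-0.5 * t**2 * (10.0 - 0.01) - t * 0.01)   # v(t) > 0 for t > 0, so stop slightly above t=0
    score = -eps_guided / math.sqrt(v_t)                           # guided score S
    drift = -0.5 * beta_t * x                                      # f(x,t)
    noise = torch.randn_like(x)
    return x + (beta_t * score - drift) * dt + math.sqrt(beta_t * dt) * noise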
Let's sample a few examples of each class for different guidance weights $w$:
import matplotlib.pyplot as plt
import numpy as np
from mpl_toolkits.axes_grid1 import ImageGrid
torch._dynamo.config.suppress_errors = True
#weight_list = [0.0, 0.2, 1.0, 3.0, 5.0]  # full sweep of guidance weights
weight_list = [1.0, 3.0, 5.0]
for weight in weight_list:
single_target_shape = [8,1,image_size,image_size]
output_list = []
for target in classes_by_index:
x_start = torch.clip(torch.randn(size=single_target_shape),-1.0,1.0)
output,time_grid = run_backwards(model,sde,
x_start=x_start,
n_steps=1000,
device=DEVICE,
train_score=TRAIN_SCORE,
clip_min=-10.0,
clip_max=10.0,
clfree_guidance=True,
target=int(target),
unconditional_target=10,
clfree_guidance_weight=weight)
output_list.append(output)
output_agg = torch.stack([output.transpose(1,0) for output in output_list],dim=1)
n_time_steps = output_agg.shape[0]
n_labels = output_agg.shape[1]
images_per_label = output_agg.shape[2]
time_idx = -1
fig = plt.figure(figsize=(12., 12.))
#fig, axs = plt.subplots(n_labels, images_per_label)
grid_idx = 0
    grid = ImageGrid(fig, 111,  # similar to subplot(111)
                     nrows_ncols=(n_labels, images_per_label),  # one row per label, one column per sample
                     axes_pad=0.01)
for label in range(n_labels):
for image_idx in range(images_per_label):
grid[grid_idx].imshow(output_agg[time_idx,label,image_idx].squeeze(),cmap='grey')
grid[grid_idx].set_xticks([])
grid[grid_idx].set_yticks([])
grid[grid_idx].set_ylabel(f"{label}",fontsize=16)
grid_idx += 1
plt.savefig(f"diffusion_loss_weight={weight:.2f}.png")
plt.show()
5. Visually inspecting output¶
Looking at the examples, we can make the following observations:
- For a very high guidance weight, we get very sharp images that look quite realistic. However, they are not very diverse; in a way, they all look alike.
- For a very low guidance weight, we see a large variety in the samples per class. However, some of them are of rather low quality and quite blurry.