Building diffusion models for images¶
In this tutorial, we are going to take the implementation of diffusion models that we built in the last tutorial and apply it to images. You will see that apart from a change of network architecture, there is basically no change to our previous code. Therefore, this is going to be brief, but cool :) We are going to use the MNIST dataset of handwritten digits. Essentially, we are going to build an algorithm that generates novel handwritten digits as follows:
%load_ext autoreload
%autoreload 2
import numpy as np
import matplotlib.pyplot as plt
from typing import Callable, List
from itertools import product
from tqdm.notebook import tqdm
from IPython.display import HTML
import seaborn as sns
from IPython.display import Video
import os
import pandas as pd
import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
from torch.utils.data import Dataset
from torch import Tensor
from abc import ABC, abstractmethod
from torch.nn.functional import relu
from torch.utils.data.dataloader import DataLoader
import scipy.stats as st
from sde import VPSDE
from train import train_diffusion_model
from sampling import run_backwards
1. Load Dataset¶
We load the MNIST dataset via torchvision (see here for details).
image_size = 28
classes_by_index = np.arange(0, 10).astype('str')
transform = transforms.Compose([transforms.Resize(image_size),
                                transforms.ToTensor(),
                                transforms.Normalize([0.5], [0.5])])  # normalize from [0, 1] to [-1, 1]
trainset = torchvision.datasets.MNIST(root='./data', train=True,
                                      download=True, transform=transform)
batch_size = 128
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,
                                          shuffle=True, num_workers=2)
Let's plot a few examples:
def imshow(img):
img = img / 2 + 0.5 # unnormalize
npimg = img.numpy()
plt.figure(figsize=[20, 20])
plt.imshow(np.transpose(npimg, (1, 2, 0)))
plt.show()
# get some random training images
dataiter = iter(trainloader)
images, labels = next(dataiter)
images = images[:8]
labels = labels[:8]
# show images
imshow(torchvision.utils.make_grid(images))
print(' '.join(f'{classes_by_index[labels[j]]:5s}' for j in range(8)))
7 7 7 4 4 8 7 1
2. Define SDE¶
Let's define the variance-preserving SDE:
sde = VPSDE(T_max=1,beta_min=0.01, beta_max=10.0)
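As a reminder from the last tutorial, this corresponds (assuming the linear noise schedule between $\beta_{\min}$ and $\beta_{\max}$ that we used there) to the SDE \begin{align*} dX_t = -\tfrac{1}{2}\beta(t)X_t\,dt + \sqrt{\beta(t)}\,dW_t, \qquad \beta(t)=\beta_{\min}+\tfrac{t}{T}\left(\beta_{\max}-\beta_{\min}\right), \end{align*} whose transition kernel is Gaussian and available in closed form: \begin{align*} X_t \mid X_0 \sim \mathcal{N}\!\left(e^{-\tfrac{1}{2}\int_0^t\beta(s)\,ds}\,X_0,\;\left(1-e^{-\int_0^t\beta(s)\,ds}\right)I\right). \end{align*} This is what sde.run_forward exploits below: it samples $X_t$ directly from this Gaussian and also returns the added noise and the corresponding score.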
Let's plot the forward evolution of an example MNIST image:
n_grid_points = 16
time_vec = torch.linspace(0, 1, n_grid_points)**2  # quadratic spacing: denser near t = 0
X_0 = torch.stack([trainset[23420][0].squeeze()] * n_grid_points)
X_t, noise, score = sde.run_forward(X_0, time_vec)
imshow(torchvision.utils.make_grid(X_t.unsqueeze(1)))
3. Neural network: Time-dependent U-Net¶
Next, we need to define our denoising network $\epsilon_{\theta}(x_t,t)$. As we operate on images, a natural choice is a U-Net:
A U-Net is a good choice as it maps an image $x$ (i.e. a tensor of shape $[c,h,w]$, where $c$ is the number of channels and $h,w$ are the height and width in pixels) to a tensor of the same shape using only convolutions, i.e. respecting the spatial structure of the image. Such a tensor of shape $[c,h,w]$ is usually called a multi-channel feature map, as it assigns each pixel location a $c$-dimensional feature vector.
A U-Net is a fully convolutional neural network, i.e. there is no fully connected layer but only convolutions (in addition to non-linearities and pooling layers that have no learnable parameters). It has a characteristic structure of a contracting and an expansive path connected by skip connections, giving it the U-shape above. The contracting path is simply a standard convolutional neural network without any padding, i.e. the width and height of the image tensor decrease (from 572x572 to 30x30) while the number of channels (the $c$) increases. The expansive path uses upsampling, i.e. artificially creating a higher-resolution tensor from a lower-resolution one. At each stage, the tensor is upsampled and a cropped version of the corresponding multi-channel feature map from the contracting path is concatenated to it.
It is not completely straightforward though, as we have to adapt the U-Net slightly to account for the fact that we need a time-dependent output $\epsilon_{\theta}(x_t,t)$. To do so, we add a time embedding to the overall U-Net that embeds time into a $d$-dimensional vector: \begin{align*} E(t) = A\cdot\text{softmax}(-\tau|t-t_1|,\dots, -\tau|t-t_n|)+b \end{align*} where $A$ is a learnable matrix, $b$ a learnable bias vector, $\tau$ a temperature parameter, and $t_1,\dots,t_n$ are evenly spaced time points in $[0,T]$. The softmax acts as a soft one-hot encoding of $t$, putting most weight on the grid points closest to $t$. For each block of the contracting and expansive path, we then add a fully connected neural network that maps $E(t)$ to a $c$-dimensional vector, where $c$ is the number of channels of that block. We then simply add this embedding to the feature map (broadcast over pixel locations) and apply a non-linearity.
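To build some intuition for this soft time encoding, here is a tiny standalone demonstration (the number of grid points and the temperature are chosen purely for illustration and differ from the values used in the model below):
import torch
import torch.nn.functional as F

temp = 20.0                                  # temperature tau
centers = torch.linspace(0, 1, 5)            # evenly spaced time points t_1, ..., t_n
t = torch.tensor([0.0, 0.26, 0.9])           # three example query times
# softmax over negative scaled distances: a soft one-hot encoding of t over the grid
weights = F.softmax(-temp * torch.abs(t[:, None] - centers[None, :]), dim=1)
print(weights.round(decimals=2))             # each row sums to 1 and peaks at the closest grid points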
TRAIN_SCORE = False  # if True, the network predicts the score; if False, it predicts the noise (denoiser)
import torch.nn.functional as F
The neural network implementation below was borrowed from here and modified to account for a continuous time variable.
class ChannelShuffle(nn.Module):
def __init__(self,groups):
super().__init__()
self.groups=groups
def forward(self,x):
n,c,h,w=x.shape
x=x.view(n,self.groups,c//self.groups,h,w) # group
x=x.transpose(1,2).contiguous().view(n,-1,h,w) #shuffle
return x
class ConvBnSiLu(nn.Module):
def __init__(self,in_channels,out_channels,kernel_size,stride=1,padding=0):
super().__init__()
self.module=nn.Sequential(nn.Conv2d(in_channels,out_channels,kernel_size,stride=stride,padding=padding),
nn.BatchNorm2d(out_channels),
nn.SiLU(inplace=True))
def forward(self,x):
return self.module(x)
class ResidualBottleneck(nn.Module):
'''
shufflenet_v2 basic unit(https://arxiv.org/pdf/1807.11164.pdf)
'''
def __init__(self,in_channels,out_channels):
super().__init__()
self.branch1=nn.Sequential(nn.Conv2d(in_channels//2,in_channels//2,3,1,1,groups=in_channels//2),
nn.BatchNorm2d(in_channels//2),
ConvBnSiLu(in_channels//2,out_channels//2,1,1,0))
self.branch2=nn.Sequential(ConvBnSiLu(in_channels//2,in_channels//2,1,1,0),
nn.Conv2d(in_channels//2,in_channels//2,3,1,1,groups=in_channels//2),
nn.BatchNorm2d(in_channels//2),
ConvBnSiLu(in_channels//2,out_channels//2,1,1,0))
self.channel_shuffle=ChannelShuffle(2)
def forward(self,x):
x1,x2=x.chunk(2,dim=1)
x=torch.cat([self.branch1(x1),self.branch2(x2)],dim=1)
x=self.channel_shuffle(x) #shuffle two branches
return x
class ResidualDownsample(nn.Module):
'''
shufflenet_v2 unit for spatial down sampling(https://arxiv.org/pdf/1807.11164.pdf)
'''
def __init__(self,in_channels,out_channels):
super().__init__()
self.branch1=nn.Sequential(nn.Conv2d(in_channels,in_channels,3,2,1,groups=in_channels),
nn.BatchNorm2d(in_channels),
ConvBnSiLu(in_channels,out_channels//2,1,1,0))
self.branch2=nn.Sequential(ConvBnSiLu(in_channels,out_channels//2,1,1,0),
nn.Conv2d(out_channels//2,out_channels//2,3,2,1,groups=out_channels//2),
nn.BatchNorm2d(out_channels//2),
ConvBnSiLu(out_channels//2,out_channels//2,1,1,0))
self.channel_shuffle=ChannelShuffle(2)
def forward(self,x):
x=torch.cat([self.branch1(x),self.branch2(x)],dim=1)
x=self.channel_shuffle(x) #shuffle two branches
return x
class TimeMLP(nn.Module):
'''
naive introduce timestep information to feature maps with mlp and add shortcut
'''
def __init__(self,embedding_dim,hidden_dim,out_dim):
super().__init__()
self.mlp=nn.Sequential(nn.Linear(embedding_dim,hidden_dim),
nn.SiLU(),
nn.Linear(hidden_dim,out_dim))
self.act=nn.SiLU()
def forward(self,x,t):
t_emb=self.mlp(t).unsqueeze(-1).unsqueeze(-1)
x=x+t_emb
return self.act(x)
class EncoderBlock(nn.Module):
def __init__(self,in_channels,out_channels,time_embedding_dim):
super().__init__()
self.conv0=nn.Sequential(*[ResidualBottleneck(in_channels,in_channels) for i in range(3)],
ResidualBottleneck(in_channels,out_channels//2))
self.time_mlp=TimeMLP(embedding_dim=time_embedding_dim,hidden_dim=out_channels,out_dim=out_channels//2)
self.conv1=ResidualDownsample(out_channels//2,out_channels)
def forward(self,x,t=None):
x_shortcut=self.conv0(x)
if t is not None:
x=self.time_mlp(x_shortcut,t)
x=self.conv1(x)
return [x,x_shortcut]
class DecoderBlock(nn.Module):
def __init__(self,in_channels,out_channels,time_embedding_dim):
super().__init__()
self.upsample=nn.Upsample(scale_factor=2,mode='bilinear',align_corners=False)
self.conv0=nn.Sequential(*[ResidualBottleneck(in_channels,in_channels) for i in range(3)],
ResidualBottleneck(in_channels,in_channels//2))
self.time_mlp=TimeMLP(embedding_dim=time_embedding_dim,hidden_dim=in_channels,out_dim=in_channels//2)
self.conv1=ResidualBottleneck(in_channels//2,out_channels//2)
def forward(self,x,x_shortcut,t=None):
x=self.upsample(x)
x=torch.cat([x,x_shortcut],dim=1)
x=self.conv0(x)
if t is not None:
x=self.time_mlp(x,t)
x=self.conv1(x)
return x
class Unet(nn.Module):
'''
simple unet design without attention
'''
def __init__(self,timesteps,time_embedding_dim,in_channels=3,out_channels=2,base_dim=32,dim_mults=[2,4,8,16], temp: float = 20.0):
super().__init__()
assert isinstance(dim_mults,(list,tuple))
assert base_dim%2==0
channels=self._cal_channels(base_dim,dim_mults)
self.init_conv=ConvBnSiLu(in_channels,base_dim,3,1,1)
#self.time_embedding=nn.Embedding(timesteps,time_embedding_dim)
self.time_embedding=nn.Linear(timesteps,time_embedding_dim)
self.encoder_blocks=nn.ModuleList([EncoderBlock(c[0],c[1],time_embedding_dim) for c in channels])
self.decoder_blocks=nn.ModuleList([DecoderBlock(c[1],c[0],time_embedding_dim) for c in channels[::-1]])
self.mid_block=nn.Sequential(*[ResidualBottleneck(channels[-1][1],channels[-1][1]) for i in range(2)],
ResidualBottleneck(channels[-1][1],channels[-1][1]//2))
self.final_conv=nn.Conv2d(in_channels=channels[0][0]//2,out_channels=out_channels,kernel_size=1)
self.centers = nn.Parameter(torch.linspace(0,1,timesteps+1)[:-1]+0.5/timesteps,requires_grad=False)
self.temp = temp
def get_softmax(self, t):
softmax_mat = F.softmax(-self.temp*torch.abs(t[:,None]-self.centers[None,:]), dim=1)
return softmax_mat
def get_time_emb(self,t):
softmax_mat = self.get_softmax(t)
t=self.time_embedding(softmax_mat)
return t
def forward(self,x,t=None):
x=self.init_conv(x)
if t is not None:
t = self.get_time_emb(t)
encoder_shortcuts=[]
for encoder_block in self.encoder_blocks:
x,x_shortcut=encoder_block(x,t)
encoder_shortcuts.append(x_shortcut)
x=self.mid_block(x)
encoder_shortcuts.reverse()
for decoder_block,shortcut in zip(self.decoder_blocks,encoder_shortcuts):
x=decoder_block(x,shortcut,t)
x=self.final_conv(x)
return x
def _cal_channels(self,base_dim,dim_mults):
dims=[base_dim*x for x in dim_mults]
dims.insert(0,base_dim)
channels=[]
for i in range(len(dims)-1):
channels.append((dims[i],dims[i+1])) # in_channel, out_channel
return channels
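Before we move on to training, let's quickly sanity-check that the network maps an MNIST-shaped batch to an output of the same shape (a quick smoke test; the hyperparameters mirror the ones used for training below):
# quick smoke test: the U-Net should map [B, 1, 28, 28] inputs to [B, 1, 28, 28] outputs
test_model = Unet(timesteps=100, time_embedding_dim=256, in_channels=1, out_channels=1,
                  base_dim=28, dim_mults=[2, 4], temp=100.0)
x_test = torch.randn(4, 1, 28, 28)                 # a small batch of MNIST-shaped noise
t_test = torch.rand(4)                             # one time value per image
print(test_model(x_test, t_test).shape)            # torch.Size([4, 1, 28, 28])
print(test_model.get_softmax(t_test).sum(dim=1))   # the soft time encoding sums to 1 per sample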
4. Training the U-Net¶
Next, we train the denoising network. We use the torch.compile function to compile our model and accelerate training. At this point, you will definitely need a GPU to replicate this training. If you don't have one, I also provide a trained checkpoint that you can load directly.
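As a reminder from the previous tutorial, train_diffusion_model minimizes the standard denoising objective. Ignoring device handling, loss weighting and logging, a single training step looks roughly like this (a minimal sketch, not the actual implementation in train.py; training_step is a hypothetical helper, and we assume sde.run_forward broadcasts the per-sample times over the image dimensions):
def training_step(model, sde, x_0, optimizer, train_score=False):
    t = torch.rand(x_0.shape[0])                   # sample random times in [0, T_max] (T_max = 1 here)
    x_t, noise, score = sde.run_forward(x_0, t)    # sample from the forward SDE in closed form
    target = score if train_score else noise       # regress either the score or the added noise
    loss = ((model(x_t, t) - target) ** 2).mean()  # simple MSE objective
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()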
N_TIMESTEPS = 100
n_channels = 1
model = Unet(base_dim=image_size, in_channels=n_channels, out_channels=n_channels, time_embedding_dim=256, timesteps=N_TIMESTEPS, dim_mults=[2, 4], temp=100.0)
model = torch.compile(model)
device = "cuda" if torch.cuda.is_available() else "cpu"
LEARNING_RATE = 1e-2 #2e-5
WEIGHT_DECAY = 0.0
N_EPOCHS = 100
RETRAIN = False
if RETRAIN:
    optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY, maximize=False)
    scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, LEARNING_RATE, total_steps=N_EPOCHS*len(trainloader), pct_start=0.25, anneal_strategy='cos')
    model, running_loss_list = train_diffusion_model(model, sde, trainloader, train_score=TRAIN_SCORE, optimizer=optimizer, scheduler=scheduler, device=device, n_epochs=N_EPOCHS, print_every=100)
    torch.save(model.state_dict(), "20231120_mnist_diffusion_denoiser.ckpt")
else:
    # map_location lets CPU-only machines load the GPU-trained checkpoint
    model_state_dict = torch.load("20231120_mnist_diffusion_denoiser.ckpt", map_location=device)
    model.load_state_dict(model_state_dict)
5. Model deployment¶
Finally, let's run the SDE backwards and sample from our diffusion model.
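Recall from the previous tutorial that run_backwards integrates the reverse-time SDE with an Euler-Maruyama scheme, converting the network output into a score along the way. A single step looks roughly like this (a minimal sketch; reverse_step and its arguments are illustrative and do not match the actual API in sampling.py):
def reverse_step(model, x, t, dt, beta_t, sigma_t, train_score=False):
    out = model(x, t)                               # network output at time t
    score = out if train_score else -out / sigma_t  # denoiser output -> score: s = -eps / sigma(t)
    drift = -0.5 * beta_t * x - beta_t * score      # reverse drift: f(x, t) - g(t)^2 * score
    noise = (beta_t ** 0.5) * torch.randn_like(x)   # g(t) times fresh Gaussian noise
    return x - drift * dt + noise * (dt ** 0.5)     # one Euler-Maruyama step backwards by dt > 0
Starting from x_start at time $T$ and iterating this step down to $t \approx 0$ produces the samples below.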
#torch._dynamo.config.suppress_errors = True #You might need to add that line for it to work
model = model.to(device)
x_start = torch.clip(torch.randn(size=next(iter(trainloader))[0].shape)[:64],-1.0,1.0)
output,time_grid = run_backwards(model,sde,x_start=x_start,n_steps=1000,device=device,train_score=TRAIN_SCORE, clip_min=-10.0, clip_max=10.0)
1000it [00:35, 28.29it/s]
Look below, it worked! Some very realistic-looking images!
def imshow(img):
#img = (img + 1)/2
npimg = img.numpy()
plt.figure(figsize=[20, 20])
plt.imshow(np.transpose(npimg, (1, 2, 0)))
#plt.show()
imshow(torchvision.utils.make_grid(output[:,-1]))
from PIL import Image
def make_gif():
n_images = 20
time_jumps = output.shape[1]//n_images
idx_list = [i * time_jumps for i in range(n_images)] + [output.shape[1]-1]
image_paths = []
for idx in idx_list:
imshow(torchvision.utils.make_grid(output[:,idx]))
filepath = f"mnist_gen_idx={idx}.png"
plt.savefig(filepath)
image_paths.append(filepath)
frames = [Image.open(image) for image in image_paths+[image_paths[-1]]*min(len(image_paths),15)]
frame_one = frames[0]
frame_one.save("MNIST_diffusion.gif", format="GIF", append_images=frames,
save_all=True, duration=100, loop=0)
for image_path in image_paths:
os.remove(image_path)
from IPython.display import clear_output
make_gif()
clear_output()
from IPython.display import HTML
HTML('<img src="/assets/animation/MNIST_diffusion.gif">')