Matheus Schmitz
LinkedIn
Github Portfolio
Variational Auto-Encoders (VAEs) are a widely used class of generative models. They are simple to implement and, in contrast to other generative model classes like Generative Adversarial Networks, they optimize an explicit maximum likelihood objective to train the model. Finally, their architecture makes them well-suited for unsupervised representation learning, i.e. learning low-dimensional representations of high-dimensional inputs, like images, with only self-supervised objectives (data reconstruction in the case of VAEs).
(image source: https://mlexplained.com/2017/12/28/an-intuitive-explanation-of-variational-autoencoders-vaes-part-1)
By working on this problem you will learn and practice the following steps:
Note: For faster training of the models in this assignment you can use Colab with enabled GPU support. In Colab, navigate to "Runtime" --> "Change Runtime Type" and set the "Hardware Accelerator" to "GPU".
We will perform all experiments for this problem using the MNIST dataset, a standard dataset of handwritten digits. The main benefits of this dataset are that it is small and relatively easy to model. It therefore allows for quick experimentation and serves as initial test bed in many papers.
Another benefit is that it is so widely used that PyTorch even provides functionality to automatically download it.
Let's start by downloading the data and visualizing some samples.
import matplotlib.pyplot as plt
%matplotlib inline
# for auto-reloading external modules
# see http://stackoverflow.com/questions/1907993/autoreload-of-modules-in-ipython
%load_ext autoreload
%autoreload 2
!pip install --quiet adabelief-pytorch
import torch
import torchvision
from adabelief_pytorch import AdaBelief
# this will automatically download the MNIST training set
mnist_train = torchvision.datasets.MNIST(root='./data',
train=True,
download=True,
transform=torchvision.transforms.ToTensor())
print("\n Download complete! Downloaded {} training examples!".format(len(mnist_train)))
Download complete! Downloaded 60000 training examples!
import matplotlib.pyplot as plt
import numpy as np
# Let's display some of the training samples.
sample_images = []
mnist_it = iter(mnist_train) # create simple iterator, later we will use proper DataLoader
for _ in range(5):
sample = next(mnist_it) # samples a tuple (image, label)
sample_images.append(sample[0][0].data.cpu().numpy())
fig = plt.figure(figsize = (10, 50))
ax1 = plt.subplot(111)
ax1.imshow(np.concatenate(sample_images, axis=1), cmap='gray')
plt.show()
Before implementing the full VAE, we will first implement an auto-encoder architecture. Auto-encoders feature the same encoder-decoder architecture as VAEs and therefore also learn a low-dimensional representation of the input data without supervision. In contrast to VAEs they are fully deterministic models and do not employ variational inference for optimization.
The architecture is very simple: the input image is encoded into a low-dimensional representation by a convolutional network with strided convolutions that reduce the image resolution in every layer. This representation then gets decoded back to the dimensionality of the input image by a convolutional decoder network that mirrors the architecture of the encoder and employs transposed convolutions to increase the resolution of its input in every layer. The whole model is trained by minimizing a reconstruction loss between the input and the decoded image.
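As a rough guide for choosing stride and padding (a small helper sketch, not part of the assignment code), the output resolution of a convolution with kernel $k$, stride $s$ and padding $p$ is $\lfloor (H + 2p - k)/s \rfloor + 1$, and that of a transposed convolution is $(H - 1)s - 2p + k$ (ignoring dilation and output padding). For example:
def conv_out(size, kernel, stride, padding):
    """Output resolution of a Conv2d layer (dilation = 1)."""
    return (size + 2 * padding - kernel) // stride + 1

def convT_out(size, kernel, stride, padding, output_padding=0):
    """Output resolution of a ConvTranspose2d layer (dilation = 1)."""
    return (size - 1) * stride - 2 * padding + kernel + output_padding

# encoder path: 28 -> 14 -> 7 -> 3 -> 1
print(conv_out(28, kernel=4, stride=2, padding=1),   # 14
      conv_out(14, kernel=4, stride=2, padding=1),   # 7
      conv_out(7, kernel=3, stride=2, padding=0),    # 3
      conv_out(3, kernel=3, stride=2, padding=0))    # 1
# mirrored decoder path: 1 -> 3 -> 7 -> 14 -> 28
print(convT_out(1, kernel=3, stride=2, padding=0),   # 3
      convT_out(3, kernel=3, stride=2, padding=0),   # 7
      convT_out(7, kernel=4, stride=2, padding=1),   # 14
      convT_out(14, kernel=4, stride=2, padding=1))  # 28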
Intuitively, the auto-encoder needs to compress the information contained in the input image into a much lower-dimensional representation (e.g. 28x28 = 784 pixels vs. 64 embedding dimensions for our MNIST model). This is possible because the information captured in the pixels is highly redundant: e.g., fewer than 4 bits suffice to encode which of the 10 possible digits is displayed, plus a few additional bits to capture shape and orientation. This is far less than the $28 \cdot 28 \cdot 8$ bits that the raw pixel values could theoretically carry.
Learning such a compressed representation can make downstream task learning easier. For example, learning to add two numbers based on the inferred digits is much easier than performing the task based on two piles of pixel values that depict the digits.
In the following, we will first define the architecture of encoder and decoder and then train the auto-encoder model.
import torch.nn as nn
# Let's define encoder and decoder networks
#####################################################################
# Encoder Architecture: #
# - Conv2d, hidden units: 32, output resolution: 14x14, kernel: 4 #
# - LeakyReLU #
# - Conv2d, hidden units: 64, output resolution: 7x7, kernel: 4 #
# - BatchNorm2d #
# - LeakyReLU #
# - Conv2d, hidden units: 128, output resolution: 3x3, kernel: 3 #
# - BatchNorm2d #
# - LeakyReLU #
# - Conv2d, hidden units: 256, output resolution: 1x1, kernel: 3 #
# - BatchNorm2d #
# - LeakyReLU #
# - Flatten #
# - Linear, output units: nz (= representation dimensionality) #
#####################################################################
class Encoder(nn.Module):
def __init__(self, nz):
super().__init__()
################################# TODO #########################################
# Create the network architecture using a nn.Sequential module wrapper. #
# All convolutional layers should also learn a bias. #
# HINT: use the given information to compute stride and padding #
# for each convolutional layer. Verify the shapes of intermediate layers #
# by running partial networks (with the next cell) and visualizing the #
# output shapes. #
################################################################################
self.net = nn.Sequential(
# add your network layers here
torch.nn.Conv2d(in_channels = 1, out_channels = 32, kernel_size = 4, stride=2, padding=1, bias=True),
torch.nn.LeakyReLU(),
torch.nn.Conv2d(in_channels = 32, out_channels = 64, kernel_size = 4, stride=2, padding=1, bias=True),
torch.nn.BatchNorm2d(64),
torch.nn.LeakyReLU(),
torch.nn.Conv2d(in_channels = 64, out_channels = 128, kernel_size = 3, stride=2, padding=0, bias=True),
torch.nn.BatchNorm2d(128),
torch.nn.LeakyReLU(),
torch.nn.Conv2d(in_channels = 128, out_channels = 256, kernel_size = 3, stride=2, padding=0, bias=True),
torch.nn.BatchNorm2d(256),
torch.nn.LeakyReLU(),
torch.nn.Flatten(start_dim=1, end_dim=- 1),
torch.nn.Linear(in_features = 256, out_features = nz, bias=True)
)
################################ END TODO #######################################
def forward(self, x):
return self.net(x)
#####################################################################
# Decoder Architecture (mirrors encoder architecture): #
# - Linear, output units: 256 #
# - Reshape, output shape: (256, 1, 1) #
# - BatchNorm2d #
# - LeakyReLU #
# - ConvT2d, hidden units: 128, output resolution: 3x3, kernel: 3 #
# - BatchNorm2d #
# - LeakyReLU #
# - ConvT2d, hidden units: 64, output resolution: 7x7, kernel: 3 #
# - ... #
# - ... #
# - ConvT2d, output units: 1, output resolution: 28x28, kernel: 4 #
# - Sigmoid (to limit output in range [0...1]) #
#####################################################################
class Decoder(nn.Module):
def __init__(self, nz):
super().__init__()
################################# TODO #########################################
# Create the network architecture using a nn.Sequential module wrapper. #
# Again, all (transposed) convolutional layers should also learn a bias. #
# We need to separate the initial linear layer into a separate variable since #
# nn.Sequential does not support reshaping. Instead the "Reshape" is performed #
# in the forward() function below and does not need to be added to self.net #
# HINT: use the class nn.ConvTranspose2d for the transposed convolutions. #
# Verify the shapes of intermediate layers by running partial networks #
# (using the next cell) and visualizing the output shapes. #
################################################################################
self.map = torch.nn.Linear(in_features = nz, out_features = 256, bias=True) # for initial Linear layer
self.net = nn.Sequential(
# add your network layers here
torch.nn.BatchNorm2d(256),
torch.nn.LeakyReLU(),
torch.nn.ConvTranspose2d(in_channels = 256, out_channels = 128, kernel_size = 3, stride=2, padding=0, bias=True, padding_mode='zeros'),
torch.nn.BatchNorm2d(128),
torch.nn.LeakyReLU(),
torch.nn.ConvTranspose2d(in_channels = 128, out_channels = 64, kernel_size = 3, stride=2, padding=0, bias=True, padding_mode='zeros'),
torch.nn.BatchNorm2d(64),
torch.nn.LeakyReLU(),
torch.nn.ConvTranspose2d(in_channels = 64, out_channels = 32, kernel_size = 4, stride=2, padding=1, output_padding=0, bias=True, padding_mode='zeros'),
torch.nn.BatchNorm2d(32),
torch.nn.LeakyReLU(),
torch.nn.ConvTranspose2d(in_channels = 32, out_channels = 1, kernel_size = 4, stride=2, padding=1, output_padding=0, bias=True, padding_mode='zeros'),
torch.nn.Sigmoid(), # limit outputs to the range [0, 1], as specified in the architecture above
)
################################ END TODO #######################################
def forward(self, x):
return self.net(self.map(x).reshape(-1, 256, 1, 1))
# To test your encoder/decoder, let's encode/decode some sample images
# first, make a PyTorch DataLoader object to sample data batches
batch_size = 64
nworkers = 4 # number of workers used for efficient data loading
####################################### TODO #######################################
# Create a PyTorch DataLoader object for efficiently generating training batches. #
# Make sure that the data loader automatically shuffles the training dataset. #
# HINT: The DataLoader wraps the MNIST dataset class we created earlier. #
# Use the given batch_size and number of data loading workers when creating #
# the DataLoader. #
####################################################################################
mnist_data_loader = torch.utils.data.DataLoader(mnist_train, batch_size=batch_size,
shuffle=True, num_workers=nworkers)
#################################### END TODO #######################################
# now we can run a forward pass for encoder and decoder and check the produced shapes
nz = 64 # dimensionality of the learned embedding
encoder = Encoder(nz)
decoder = Decoder(nz)
for sample_img, sample_label in mnist_data_loader:
enc = encoder(sample_img)
print("Shape of encoding vector (should be [batch_size, nz]): {}".format(enc.shape))
dec = decoder(enc)
print("Shape of decoded image (should be [batch_size, 1, 28, 28]): {}".format(dec.shape))
break
Shape of encoding vector (should be [batch_size, nz]): torch.Size([64, 64])
Shape of decoded image (should be [batch_size, 1, 28, 28]): torch.Size([64, 1, 28, 28])
Now that we have defined the encoder and decoder networks, our architecture is nearly complete. However, before we start training, we wrap encoder and decoder into an auto-encoder class for easier handling.
class AutoEncoder(nn.Module):
def __init__(self, nz):
super().__init__()
self.encoder = Encoder(nz)
self.decoder = Decoder(nz)
def forward(self, x):
return self.decoder(self.encoder(x))
def reconstruct(self, x):
"""Only used later for visualization."""
return self.forward(x)
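As a minimal sanity check (not part of the assignment; it assumes the AutoEncoder class and the nz value defined above), we can verify that a dummy batch passes through the model and comes out with the input shape:
test_ae = AutoEncoder(nz)                  # nz = 64 from the cell above
dummy_batch = torch.randn(2, 1, 28, 28)    # two fake MNIST-sized images
print(test_ae(dummy_batch).shape)          # expected: torch.Size([2, 1, 28, 28])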
After implementing the network architecture, we can now set up the training loop and run training.
epochs = 10
learning_rate = 1e-3
# build AE model
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') # use GPU if available
ae_model = AutoEncoder(nz).to(device) # transfer model to GPU if available
ae_model = ae_model.train() # set model in train mode (eg batchnorm params get updated)
# build optimizer and loss function
####################################### TODO #######################################
# Create the optimizer and loss classes. For the loss you can use a loss layer #
# from the torch.nn package. #
# HINT: We will use the Adam optimizer (learning rate given above, otherwise #
# default parameters) and MSE loss for the criterion / loss. #
# NOTE: We could also use alternative loss functions like cross entropy, depending #
# on the assumptions we are making about the output distribution. Here we #
# will use MSE loss as it is the most common choice, assuming a Gaussian #
# output distribution. #
####################################################################################
opt = AdaBelief(ae_model.parameters(), lr=learning_rate, eps=1e-12, betas=(0.9,0.999), rectify=False, print_change_log=False) # create optimizer instance
criterion = torch.nn.MSELoss() # create loss layer instance
#################################### END TODO #######################################
train_it = 0
for ep in range(epochs):
print("Run Epoch {}".format(ep))
####################################### TODO #######################################
# Implement the main training loop for the auto-encoder model. #
# HINT: Your training loop should sample batches from the data loader, run the #
# forward pass of the AE, compute the loss, perform the backward pass and #
# perform one gradient step with the optimizer. #
# HINT: Don't forget to erase old gradients before performing the backward pass. #
####################################################################################
rec_loss = 0.0
for data in mnist_data_loader:
# access data
inputs, labels = data
inputs, labels = inputs.to(device), labels.to(device)
opt.zero_grad()
# forward
outputs = ae_model(inputs)
# loss
loss = criterion(outputs, inputs)
# backward
loss.backward()
# update the weights
opt.step()
# print statistics
rec_loss += loss.item()
#################################### END TODO #####################################
if train_it % 1 == 0:
print("It {}: Reconstruction Loss: {}".format(train_it, rec_loss))
train_it += 1
print("Done!")
Weight decoupling enabled in AdaBelief
Run Epoch 0
It 0: Reconstruction Loss: 30.08385004196316
Run Epoch 1
It 1: Reconstruction Loss: 11.307898331433535
Run Epoch 2
It 2: Reconstruction Loss: 8.777552753686905
Run Epoch 3
It 3: Reconstruction Loss: 7.576860204339027
Run Epoch 4
It 4: Reconstruction Loss: 6.691348219290376
Run Epoch 5
It 5: Reconstruction Loss: 6.115002739708871
Run Epoch 6
It 6: Reconstruction Loss: 5.661501551512629
Run Epoch 7
It 7: Reconstruction Loss: 5.277036506216973
Run Epoch 8
It 8: Reconstruction Loss: 4.981823894195259
Run Epoch 9
It 9: Reconstruction Loss: 4.7335849474184215
Done!
Now that we have trained the auto-encoder, we can visualize some of its reconstructions on the test set to verify that it has converged and did not overfit. Before continuing, make sure that your auto-encoder is able to reconstruct these samples near-perfectly.
# visualize test data reconstructions
def vis_reconstruction(model):
# download MNIST test set + build Dataset object
mnist_test = torchvision.datasets.MNIST(root='./data',
train=False,
download=True,
transform=torchvision.transforms.ToTensor())
mnist_test_iter = iter(mnist_test)
model.eval() # set model in evaluation mode (eg freeze batchnorm params)
input_imgs, test_reconstructions = [], []
for _ in range(5):
input_img = np.asarray(next(mnist_test_iter)[0])
reconstruction = model.reconstruct(torch.tensor(input_img[None], device=device))
input_imgs.append(input_img[0])
test_reconstructions.append(reconstruction[0, 0].data.cpu().numpy())
fig = plt.figure(figsize = (20, 50))
ax1 = plt.subplot(111)
ax1.imshow(np.concatenate([np.concatenate(input_imgs, axis=1),
np.concatenate(test_reconstructions, axis=1)], axis=0), cmap='gray')
plt.show()
vis_reconstruction(ae_model)
To test whether the auto-encoder is useful as a generative model, we can use it like any other generative model: draw embedding samples from a prior distribution and decode them through the decoder network. We will choose a unit Gaussian prior to allow for easy comparison to the VAE later.
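As an optional aside (a small sketch following the hint in the next cell), we can verify that torch.randn indeed produces samples from a diagonal unit Gaussian by checking the empirical statistics of a large batch:
probe = torch.randn(10000, nz)                   # nz = 64 from the cells above
print(probe.mean().item(), probe.var().item())   # should be close to 0 and 1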
# we will sample N embeddings, then decode and visualize them
def vis_samples(model):
####################################### TODO #######################################
# Sample embeddings from a diagonal unit Gaussian distribution and decode them #
# using the model. #
# HINT: The sampled embeddings should have shape [batch_size, nz]. Diagonal unit #
# Gaussians have mean 0 and a covariance matrix with ones on the diagonal #
# and zeros everywhere else. #
# HINT: If you are unsure whether you sampled the correct distribution, you can #
# sample a large batch and compute the empirical mean and variance using the #
# .mean() and .var() functions. #
# HINT: You can directly use model.decoder() to decode the samples. #
####################################################################################
batch_size = next(iter(mnist_data_loader))[0].size()[0]
nz = model.decoder.map.in_features
sampled_embeddings = torch.randn((batch_size, nz)).to(device) # sample batch of embedding from prior
decoded_samples = model.decoder(sampled_embeddings) # decoder output images for sampled embeddings
#################################### END TODO ######################################
fig = plt.figure(figsize = (10, 10))
ax1 = plt.subplot(111)
ax1.imshow(torchvision.utils.make_grid(decoded_samples[:16], nrow=4, pad_value=1.)\
.data.cpu().numpy().transpose(1, 2, 0), cmap='gray')
plt.show()
vis_samples(ae_model)
/usr/local/lib/python3.7/dist-packages/torch/utils/data/dataloader.py:490: UserWarning: This DataLoader will create 4 worker processes in total. Our suggested max number of worker in current system is 2, which is smaller than what this DataLoader is going to create. Please be aware that excessive worker creation might get DataLoader running slow or even freeze, lower the worker number to avoid potential slowness/freeze if necessary. cpuset_checked)) Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Inline Question: Describe your observations, why do you think they occur? [2pt] \ (please limit your answer to <150 words) \ Answer: The compressed encoder representation is unconstrained and therefore unlikely to be normally distributed, so we are sampling from regions of the embedding space that were unused during encoding and that the decoder never learned to handle.
Variational auto-encoders use a very similar architecture to deterministic auto-encoders, but they are inherently stochastic models, i.e. we perform a stochastic sampling operation during the forward pass, leading to different outputs every time we run the network on the same input. This sampling is required to optimize the VAE objective, also known as the evidence lower bound (ELBO):
$$ \log p(x) \geq \underbrace{\mathbb{E}_{z\sim q(z\vert x)} \log p(x \vert z)}_{\text{reconstruction}} - \underbrace{D_{\text{KL}}\big(q(z \vert x), p(z)\big)}_{\text{prior divergence}} $$

Here, $D_{\text{KL}}(q, p)$ denotes the Kullback-Leibler (KL) divergence between the posterior distribution $q(z \vert x)$, i.e. the output of our encoder, and $p(z)$, the prior over the embedding variable $z$, which we can choose freely.
For simplicity, we will again choose a unit Gaussian prior. The left term is the reconstruction term we already know from training the auto-encoder. When assuming a Gaussian output distribution for both encoder $q(z \vert x)$ and decoder $p(x \vert z)$ the objective reduces to:
$$ \mathcal{L}_{\text{VAE}} = \sum_{x\sim \mathcal{D}} (x - \hat{x})^2 + \beta \cdot D_{\text{KL}}\big(\mathcal{N}(\mu_q, \sigma_q), \mathcal{N}(0, I)\big) $$

Here, $\hat{x}$ is the reconstruction output of the decoder and $\mathcal{L}_{\text{VAE}}$ is minimized during training. In comparison to the auto-encoder objective, the VAE adds a regularizing term between the output of the encoder and a chosen prior distribution, effectively forcing the encoder output to not stray too far from the prior during training. As a result, the decoder gets trained on samples that look quite similar to samples from the prior, which will hopefully allow us to generate better images when using the VAE as a generative model and actually feeding it samples from the prior (as we did for the AE above).
The coefficient $\beta$ is a scalar weighting factor that trades off between the reconstruction and regularization objectives. We will investigate the influence of this factor in our experiments below.
If you need a refresher on VAEs you can check out this tutorial paper: https://arxiv.org/abs/1606.05908
The sampling procedure inside the VAE's forward pass for obtaining a sample $z$ from the posterior distribution $q(z \vert x)$, when implemented naively, is non-differentiable. However, since $q(z\vert x)$ is parametrized with a Gaussian function, there is a simple trick to obtain a differentiable sampling operator, known as the reparametrization trick.
Instead of directly sampling $z \sim \mathcal{N}(\mu_q, \sigma_q)$ we can "separate" the network's predictions and the random sampling by computing the sample as:
$$ z = \mu_q + \sigma_q \cdot \epsilon , \quad \epsilon \sim \mathcal{N}(0, I) $$

Note that in this equation the sample $z$ is computed as a deterministic function of the network's predictions $\mu_q$ and $\sigma_q$, which allows us to propagate gradients through the sampling procedure.
Note: While in the equations above the encoder network parametrizes the standard deviation $\sigma_q$ of the Gaussian posterior distribution, in practice we usually parametrize the logarithm of the standard deviation $\log \sigma_q$ for numerical stability. Before sampling $z$ we will then exponentiate the network's output to obtain $\sigma_q$.
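To make this concrete, here is a minimal standalone sketch of reparametrized sampling with a log-standard-deviation parametrization (toy tensor shapes, not the assignment code):
import torch

mu = torch.zeros(4, 8, requires_grad=True)          # posterior mean, e.g. [batch, nz]
log_sigma = torch.zeros(4, 8, requires_grad=True)   # posterior log standard deviation

eps = torch.randn_like(mu)                           # epsilon ~ N(0, I), no gradients needed
z = mu + torch.exp(log_sigma) * eps                  # differentiable w.r.t. mu and log_sigma

z.sum().backward()                                   # gradients flow into mu and log_sigma
print(mu.grad.shape, log_sigma.grad.shape)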
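For reference, the helper function below implements the standard closed-form KL divergence between two univariate Gaussians, applied element-wise to each latent dimension (with $\sigma_i = e^{\log \sigma_i}$):

$$ D_{\text{KL}}\big(\mathcal{N}(\mu_1, \sigma_1)\,\|\,\mathcal{N}(\mu_2, \sigma_2)\big) = \log\frac{\sigma_2}{\sigma_1} + \frac{\sigma_1^2 + (\mu_1 - \mu_2)^2}{2\sigma_2^2} - \frac{1}{2} $$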
def kl_divergence(mu1, log_sigma1, mu2, log_sigma2):
"""Computes KL[p||q] between two Gaussians defined by [mu, log_sigma]."""
return (log_sigma2 - log_sigma1) + (torch.exp(log_sigma1) ** 2 + (mu1 - mu2) ** 2) \
/ (2 * torch.exp(log_sigma2) ** 2) - 0.5
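As an optional sanity check (not required by the assignment; it uses the kl_divergence helper above), the result can be compared against torch.distributions, which provides the same closed-form KL for Normal distributions:
import torch
import torch.distributions as D

mu1, log_sigma1 = torch.tensor(0.3), torch.tensor(-0.2)
mu2, log_sigma2 = torch.tensor(0.0), torch.tensor(0.0)   # unit Gaussian prior

ours = kl_divergence(mu1, log_sigma1, mu2, log_sigma2)
ref = D.kl_divergence(D.Normal(mu1, log_sigma1.exp()), D.Normal(mu2, log_sigma2.exp()))
print(ours.item(), ref.item())   # the two values should match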
class VAE(nn.Module):
def __init__(self, nz, beta=1.0):
super().__init__()
self.beta = beta # factor trading off between two loss components
####################################### TODO #######################################
# Instantiate Encoder and Decoder. #
# HINT: Remember that the encoder is now parametrizing a Gaussian distribution's #
# mean and log_sigma, so the dimensionality of the output embedding needs to #
# double. #
####################################################################################
self.encoder = Encoder(2 * nz) #2nz
self.decoder = Decoder(nz)
#################################### END TODO ######################################
def forward(self, x):
####################################### TODO #######################################
# Implement the forward pass of the VAE. #
# HINT: Your code should implement the following steps: #
# 1. encode input x, split encoding into mean and log_sigma of Gaussian #
# 2. sample z from inferred posterior distribution using #
# reparametrization trick #
# 3. decode the sampled z to obtain the reconstructed image #
####################################################################################
# encode input into posterior distribution q(z | x)
q = self.encoder(x) # output of encoder (concatenated mean and log_sigma)
# sample latent variable z with reparametrization
mean_ = q[:, :nz] # posterior mean (relies on the global nz)
log_sigma_ = q[:, nz:] # posterior log standard deviation
z = mean_ + torch.exp(log_sigma_) * torch.randn_like(log_sigma_) # reparametrization trick: batch of sampled embeddings
# compute reconstruction
reconstruction = self.decoder(z) # decoder reconstruction from embedding
#################################### END TODO ######################################
return {'q': q,
'rec': reconstruction}
def loss(self, x, outputs):
####################################### TODO #######################################
# Implement the loss computation of the VAE. #
# HINT: Your code should implement the following steps: #
# 1. compute the image reconstruction loss, similar to AE we use MSE loss #
# 2. compute the KL divergence loss between the inferred posterior #
# distribution and a unit Gaussian prior; you can use the provided #
# function above for computing the KL divergence between two Gaussians #
# parametrized by mean and log_sigma #
# HINT: Make sure to compute the KL divergence in the correct order since it is #
# not symmetric, ie. KL(p, q) != KL(q, p)! #
####################################################################################
# compute reconstruction loss
rec_loss = torch.nn.MSELoss()(x, outputs['rec'])
# compute KL divergence loss between the inferred posterior and the unit Gaussian prior
q = outputs['q']
mean_ = q[:, :nz]
log_sigma_ = q[:, nz:]
prior_mean = torch.zeros_like(mean_) # unit Gaussian prior: mean 0
prior_log_sigma = torch.zeros_like(log_sigma_) # log(sigma) = 0, i.e. sigma = 1
kl_loss = kl_divergence(mean_, log_sigma_, prior_mean, prior_log_sigma).sum() / x.size()[0] # scalar: summed KL, averaged over the batch
#################################### END TODO ######################################
# return weighted objective
return rec_loss + self.beta * kl_loss, \
{'rec_loss': rec_loss, 'kl_loss': kl_loss}
def reconstruct(self, x):
"""Use mean of posterior estimate for visualization reconstruction."""
####################################### TODO #######################################
# This function is used for visualizing reconstructions of our VAE model. To #
# obtain the maximum likelihood estimate we bypass the sampling procedure of the #
# inferred latent and instead directly use the mean of the inferred posterior. #
# HINT: encode the input image and then decode the mean of the posterior to obtain #
# the reconstruction. #
####################################################################################
encoded = self.encoder(x)
mean_ = encoded[:,:nz]
reconstruction = self.decoder(mean_)
#################################### END TODO ######################################
return reconstruction
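Before launching a full training run, a quick smoke test of the forward pass and loss can catch shape bugs early. This is a minimal sketch, not part of the assignment; it assumes the global nz and device variables defined in earlier cells (the forward pass above relies on them as well):
vae_test = VAE(nz, beta=1.0).to(device)
dummy = torch.randn(4, 1, 28, 28, device=device)      # fake batch of MNIST-sized images
out = vae_test(dummy)
total, parts = vae_test.loss(dummy, out)
print(out['rec'].shape)                                # expected: torch.Size([4, 1, 28, 28])
print(total.item(), parts['rec_loss'].item(), parts['kl_loss'].item())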
Let's start training the VAE model! We will first verify our implementation by setting $\beta = 0$.
learning_rate = 1e-3
nz = 64
####################################### TODO #######################################
# Tune the beta parameter to obtain good VAE training results. However, for the #
# initial experiments leave beta = 0 in order to verify our implementation. #
####################################################################################
epochs = 30 # using 5 epochs is sufficient for the first two experiments
# for the experiment where you tune beta, 20 epochs are appropriate
beta = 0.001
#################################### END TODO ######################################
# build VAE model
vae_model = VAE(nz, beta).to(device) # transfer model to GPU if available
vae_model = vae_model.train() # set model in train mode (eg batchnorm params get updated)
# build optimizer and loss function
####################################### TODO #######################################
# Build the optimizer for the vae_model. We will again use the Adam optimizer with #
# the given learning rate and otherwise default parameters. #
####################################################################################
opt = AdaBelief(vae_model.parameters(), lr=learning_rate, eps=1e-12, betas=(0.9,0.999), rectify=False, print_change_log=False) # create optimizer instance
#################################### END TODO ######################################
train_it = 0
rec_loss, kl_loss = [], []
for ep in range(epochs):
print("Run Epoch {}".format(ep))
####################################### TODO #######################################
# Implement the main training loop for the VAE model. #
# HINT: Your training loop should sample batches from the data loader, run the #
# forward pass of the VAE, compute the loss, perform the backward pass and #
# perform one gradient step with the optimizer. #
# HINT: Don't forget to erase old gradients before performing the backward pass. #
# HINT: This time we will use the loss() function of our model for computing the #
# training loss. It outputs the total training loss and a dict containing #
# the breakdown of reconstruction and KL loss. #
####################################################################################
for data in mnist_data_loader:
# access data
inputs, labels = data
inputs, labels = inputs.to(device), labels.to(device)
opt.zero_grad()
# forward
outputs = vae_model(inputs)
# loss
tot_loss, losses = vae_model.loss(inputs, outputs)
total_loss = tot_loss.item()
# backward
tot_loss.backward()
# update the weights
opt.step()
#################################### END TODO ####################################
rec_loss.append(losses['rec_loss']); kl_loss.append(losses['kl_loss'])
if train_it % 1 == 0:
print("It {}: Total Loss: {}, \t Rec Loss: {},\t KL Loss: {}"\
.format(train_it, total_loss, losses['rec_loss'], losses['kl_loss']))
train_it += 1
print("Done!")
Weight decoupling enabled in AdaBelief
Run Epoch 0
It 0: Total Loss: 0.04319138824939728, Rec Loss: 0.0344826765358448, KL Loss: 8.708709716796875 Run Epoch 1 It 1: Total Loss: 0.037333082407712936, Rec Loss: 0.02720918320119381, KL Loss: 10.123899459838867 Run Epoch 2 It 2: Total Loss: 0.04145604372024536, Rec Loss: 0.030026249587535858, KL Loss: 11.429794311523438 Run Epoch 3 It 3: Total Loss: 0.038921646773815155, Rec Loss: 0.026759246364235878, KL Loss: 12.162398338317871 Run Epoch 4 It 4: Total Loss: 0.036812521517276764, Rec Loss: 0.024559326469898224, KL Loss: 12.253192901611328 Run Epoch 5 It 5: Total Loss: 0.03802490234375, Rec Loss: 0.025172457098960876, KL Loss: 12.852444648742676 Run Epoch 6 It 6: Total Loss: 0.038524843752384186, Rec Loss: 0.025768162682652473, KL Loss: 12.756681442260742 Run Epoch 7 It 7: Total Loss: 0.042113546282052994, Rec Loss: 0.028856098651885986, KL Loss: 13.2574462890625 Run Epoch 8 It 8: Total Loss: 0.036508262157440186, Rec Loss: 0.023641172796487808, KL Loss: 12.867088317871094 Run Epoch 9 It 9: Total Loss: 0.038543447852134705, Rec Loss: 0.02548873797059059, KL Loss: 13.054708480834961 Run Epoch 10 It 10: Total Loss: 0.03585407882928848, Rec Loss: 0.022505730390548706, KL Loss: 13.348345756530762 Run Epoch 11 It 11: Total Loss: 0.03559671342372894, Rec Loss: 0.022123362869024277, KL Loss: 13.473349571228027 Run Epoch 12 It 12: Total Loss: 0.03708306699991226, Rec Loss: 0.02403838001191616, KL Loss: 13.044684410095215 Run Epoch 13 It 13: Total Loss: 0.03540626913309097, Rec Loss: 0.022217964753508568, KL Loss: 13.188302040100098 Run Epoch 14 It 14: Total Loss: 0.034312862902879715, Rec Loss: 0.02085701748728752, KL Loss: 13.455843925476074 Run Epoch 15 It 15: Total Loss: 0.036488451063632965, Rec Loss: 0.022558197379112244, KL Loss: 13.930252075195312 Run Epoch 16 It 16: Total Loss: 0.031690191477537155, Rec Loss: 0.018409233540296555, KL Loss: 13.28095817565918 Run Epoch 17 It 17: Total Loss: 0.031309597194194794, Rec Loss: 0.01763611100614071, KL Loss: 13.673484802246094 Run Epoch 18 It 18: Total Loss: 0.03379136696457863, Rec Loss: 0.020440923050045967, KL Loss: 13.350441932678223 Run Epoch 19 It 19: Total Loss: 0.033252038061618805, Rec Loss: 0.0199708491563797, KL Loss: 13.281187057495117 Run Epoch 20 It 20: Total Loss: 0.029497720301151276, Rec Loss: 0.016504915431141853, KL Loss: 12.992805480957031 Run Epoch 21 It 21: Total Loss: 0.03512067720293999, Rec Loss: 0.021391401067376137, KL Loss: 13.729276657104492 Run Epoch 22 It 22: Total Loss: 0.034240007400512695, Rec Loss: 0.020848488435149193, KL Loss: 13.391519546508789 Run Epoch 23 It 23: Total Loss: 0.036975931376218796, Rec Loss: 0.023083442822098732, KL Loss: 13.892486572265625 Run Epoch 24 It 24: Total Loss: 0.03841464966535568, Rec Loss: 0.023534854874014854, KL Loss: 14.87979507446289 Run Epoch 25 It 25: Total Loss: 0.03403574600815773, Rec Loss: 0.020501144230365753, KL Loss: 13.534602165222168 Run Epoch 26 It 26: Total Loss: 0.033879783004522324, Rec Loss: 0.01995890960097313, KL Loss: 13.92087173461914 Run Epoch 27 It 27: Total Loss: 0.03364895284175873, Rec Loss: 0.01941053383052349, KL Loss: 14.238418579101562 Run Epoch 28 It 28: Total Loss: 0.03266599029302597, Rec Loss: 0.018732773140072823, KL Loss: 13.93321418762207 Run Epoch 29 It 29: Total Loss: 0.033399391919374466, Rec Loss: 0.019550027325749397, KL Loss: 13.849363327026367 Done!
rec_loss_ = [x.detach().cpu().numpy().tolist() for x in rec_loss]
kl_loss_ = [x.detach().cpu().numpy().tolist() for x in kl_loss]
# log the loss training curves
fig = plt.figure(figsize = (10, 5))
ax1 = plt.subplot(121)
ax1.plot(rec_loss_)
ax1.title.set_text("Reconstruction Loss")
ax2 = plt.subplot(122)
ax2.plot(kl_loss_)
ax2.title.set_text("KL Loss")
plt.show()
Let's look at some reconstructions and decoded embedding samples!
# visualize VAE reconstructions and samples from the generative model
vis_reconstruction(vae_model)
vis_samples(vae_model)
Inline Question: What can you observe when setting $\beta = 0$? Explain your observations! [3pt] \ (please limit your answer to <150 words) \ Answer: When β = 0, the model only optimizes the reconstruction loss, so it behaves like a plain auto-encoder: reconstructions of test images look good, but images decoded from embeddings sampled from the prior are ambiguous, since nothing constrains the encoder output to match the prior.
Let's repeat the same experiment for $\beta = 10$, a very high value for the coefficient. You can modify the $\beta$ value in the cell above and rerun it (it is okay to overwrite the outputs of the previous experiment, but make sure to copy the visualizations of training curves, reconstructions and samples for $\beta = 0$ into your solution PDF before deleting them).
Inline Question: What can you observe when setting $\beta = 10$? Explain your observations! [3pt] \ (please limit your answer to <200 words) \ Answer: When β = 10, the reconstruction loss receives very little weight relative to the KL term. As a result, the generated images are blurry and unclear for both the reconstruction and the sample-decoding task. Moreover, the decoded samples all look very similar to each other, since the strongly weighted KL term pushes the posterior to collapse towards the prior.
Now we can start tuning the beta value to achieve a good result. First, describe what a "good result" would look like (focus on what you would expect for reconstructions and sample quality).
Inline Question: Characterize what properties you would expect for reconstructions (1pt) and samples (2pt) of a well-tuned VAE! [3pt] \ (please limit your answer to <200 words) \ Answer:
- For reconstructions: since the model is trained to reproduce its input, the reconstructed images should look very close to the corresponding input images.
- For samples: images decoded from the prior should be sharp and look like plausible digits from the training distribution, and different prior samples should decode to visibly different outputs, i.e. the samples should be diverse.
Now that you know what outcome we would like to obtain, try to tune $\beta$ to achieve this result.
(logarithmic search in steps of 10x will be helpful, good results can be achieved after ~20 epochs of training). It is again okay to overwrite the results of the previous $\beta=10$ experiment after copying them to the solution PDF.
Your final notebook should include the visualizations of your best-tuned VAE.
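One way to organize the search is a simple logarithmic sweep. This is a sketch only; it reuses the VAE class, mnist_data_loader, AdaBelief setup, device, nz, and the visualization helpers from the cells above, and trains each candidate for a few epochs before comparing results:
# logarithmic beta sweep (sketch; assumes the objects defined in earlier cells)
for beta_candidate in [1e-4, 1e-3, 1e-2, 1e-1, 1.0]:
    model = VAE(nz, beta_candidate).to(device).train()
    optimizer = AdaBelief(model.parameters(), lr=1e-3, eps=1e-12, betas=(0.9, 0.999),
                          rectify=False, print_change_log=False)
    for ep in range(5):                          # a few epochs per candidate suffice for a coarse comparison
        for inputs, _ in mnist_data_loader:
            inputs = inputs.to(device)
            optimizer.zero_grad()
            total_loss, _ = model.loss(inputs, model(inputs))
            total_loss.backward()
            optimizer.step()
    print("beta = {}".format(beta_candidate))
    vis_reconstruction(model)                    # compare reconstructions ...
    vis_samples(model)                           # ... and prior samples for each beta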
As mentioned in the introduction, AEs and VAEs can not only be used to generate images, but also to learn low-dimensional representations of their inputs. In this final section we investigate the representations learned by both models by interpolating in embedding space between different images: we encode two images into their low-dimensional embeddings, interpolate between these embeddings, and reconstruct the results.
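Concretely, with $N$ interpolation steps the intermediate embeddings are computed by linear interpolation between the two encodings (this matches the code below, where $N$ is N_INTER_STEPS and the original start and end images are additionally shown at the two ends):

$$ z_i = z_{\text{start}} + \frac{i}{N}\,\big(z_{\text{end}} - z_{\text{start}}\big), \quad i = 0, \dots, N-1 $$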
START_LABEL = 6
END_LABEL = 9
nz=64
def get_image_with_label(target_label):
"""Returns a random image from the training set with the requested digit."""
for img_batch, label_batch in mnist_data_loader:
for img, label in zip(img_batch, label_batch):
if label == target_label:
return img.to(device)
def interpolate_and_visualize(model, tag, start_img, end_img):
"""Encodes images and performs interpolation. Displays decodings."""
model.eval() # put model in eval mode to avoid updating batchnorm
# encode both images into embeddings (use posterior mean for interpolation)
z_start = model.encoder(start_img[None])[..., :nz]
z_end = model.encoder(end_img[None])[..., :nz]
# compute interpolated latents
N_INTER_STEPS = 5
z_inter = [z_start + i/N_INTER_STEPS * (z_end - z_start) for i in range(N_INTER_STEPS)]
# decode interpolated embeddings (as a single batch)
img_inter = model.decoder(torch.cat(z_inter))
# reshape result and display interpolation
vis_imgs = torch.cat([start_img[None], img_inter, end_img[None]])
fig = plt.figure(figsize = (10, 10))
ax1 = plt.subplot(111)
ax1.imshow(torchvision.utils.make_grid(vis_imgs, nrow=N_INTER_STEPS+2, pad_value=1.)\
.data.cpu().numpy().transpose(1, 2, 0), cmap='gray')
plt.title(tag)
plt.show()
# sample two training images with given labels
start_img = get_image_with_label(START_LABEL)
end_img = get_image_with_label(END_LABEL)
# visualize interpolations for AE and VAE models
interpolate_and_visualize(ae_model, "Auto-Encoder", start_img, end_img)
interpolate_and_visualize(vae_model, "Variational Auto-Encoder", start_img, end_img)
Repeat the experiment for different start / end labels and different samples. Describe your observations.
Inline Question: Repeat the interpolation experiment with different start / end labels and multiple samples. Describe your observations! Focus on: \
- How do AE and VAE embedding space interpolations differ? \
- How do you expect these differences to affect the usefulness of the learned representation for downstream learning? \
Answer:
- In the case of the AE, the images in the middle of the interpolation look ambiguous, whereas the VAE produces clear, recognizable digits along the whole interpolation path.
- The interpolated embeddings differ from the embeddings produced for the training data. For the AE, these intermediate embeddings fall into unused regions of the latent space, so the decoded images are unclear. For the VAE, the encoder outputs distributions (means and standard deviations) that the KL term keeps close to the prior, so the space between training encodings is covered and intermediate embeddings still decode to clear images.
- This makes the VAE representation more useful for downstream learning. Because the VAE generates outputs with characteristics similar to the training data, it can be used to synthesize new images (e.g. novel digits or faces) directly from the prior. In addition, due to the KL divergence loss, embeddings of similar outputs cluster more tightly in the VAE's latent space than in the AE's.