Variational Autoencoders: A Vanilla Implementation
Generative models, and Variational Autoencoders in particular, have gained a lot of popularity over the past few years as their use cases have grown, and they remain an active research area. The main goal of these models is to generate high-quality output data (e.g. images, text or sound) that belongs to the same distribution as the input data. Generative models have three main families: Variational Autoencoders (VAEs), Generative Adversarial Networks (GANs) and Diffusion Models. In this article we focus on the first of these, the VAE, an architecture that inspired many later ones; we demystify Variational Autoencoders and walk through a vanilla implementation. After reading this article you should know the following:
- What are Autoencoders?
- Variational Autoencoders (VAE)
- Difference between Variational Autoencoders (VAEs) and Generative Adversarial Networks (GANs)
- VAE Python Implementation (PyTorch)
What are Autoencoders?
Autoencoders were first introduced in [1]. They are neural network models whose main goal is to reconstruct the input by encoding it into a lower-dimensional representation (Z) and then decoding it back into an output with the same dimensions as the input (Figure 2). The learning process is completely unsupervised. Because they learn the data representation in an unsupervised way, autoencoders can be used in many applications such as clustering, dimensionality reduction and self-supervised learning. You can read more about the Types of Machine Learning Systems to get a better understanding.
We can formulate the loss function as a reconstruction loss (L1 or L2) (Eq. 1). The goal is to learn the functions $A: \mathbb{R}^n \to \mathbb{R}^p$ and $B: \mathbb{R}^p \to \mathbb{R}^n$, where $A$ is the encoder and $B$ is the decoder, and the loss can be defined as:

$$\arg\min_{A,B} \; \mathbb{E}\big[\Delta\big(x,\, B(A(x))\big)\big] \qquad \text{(Eq. 1)}$$
Where E is the expectation over the distribution of X and Δ is the reconstruction loss function, which measures the similarity between the input and the reconstructed output [2].
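To make this concrete, here is a minimal sketch of a plain autoencoder in PyTorch; the layer sizes, names and input dimensions (784, i.e. a flattened 28x28 image, with latent size 32) are illustrative assumptions, not taken from the referenced papers.

import torch
from torch import nn

# A (encoder): R^784 -> R^32, B (decoder): R^32 -> R^784
class PlainAutoencoder(nn.Module):
    def __init__(self, n=784, p=32):
        super().__init__()
        self.encoder = nn.Sequential(nn.Linear(n, 128), nn.ReLU(), nn.Linear(128, p))
        self.decoder = nn.Sequential(nn.Linear(p, 128), nn.ReLU(), nn.Linear(128, n))

    def forward(self, x):
        z = self.encoder(x)      # lower-dimensional representation Z
        return self.decoder(z)   # reconstruction with the same dimensions as the input

x = torch.rand(16, 784)                   # a batch of 16 flattened inputs
autoencoder = PlainAutoencoder()
x_hat = autoencoder(x)
loss = nn.functional.mse_loss(x_hat, x)   # L2 reconstruction loss (Eq. 1 with an MSE Δ)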
There are many types of autoencoders:
- Regularized autoencoders
- Sparse autoencoders
- Denoising autoencoders
- Contractive autoencoders
- Variational autoencoders
All of them are described in some detail in [2], but in this article we will focus on variational autoencoders.
Variational Autoencoders (VAE)
The main idea of the variational autoencoder is to find a latent variable Z that depends on the observed random variable X (the input), i.e. P(Z|X). The objective of the neural network in this case is to find the relationship between Z and X, which is achieved by learning the parameters of the distribution of the random variable Z (see Figure 2).
There are many differences between a variational autoencoder and a standard autoencoder, but the main one is that in a variational autoencoder you predict the mean and standard deviation of the latent variable Z instead of predicting Z directly.
To get the variable Z we use a convolutional encoder-decoder architecture (Figure 4): in the encoder step the network predicts the mean and standard deviation from which the latent variable Z is sampled.
This is achieved by optimizing the parameters of the distribution p(Z|X), and since that isn't tractable, we instead approximate it with a distribution q(Z|X) that is as close as possible to our target posterior. The KL divergence is used to measure the divergence between the posterior distribution p(Z|X) and the approximated distribution q(Z|X).
We can observe in (Fig. 4) that the model's output is the log of the variance (log σ², i.e. 2 log σ) rather than the variance itself. This guarantees that the variance is always positive: when we use the output to parameterize the normal distribution, we pass it through an exponential to cancel the log, and we multiply by 0.5 first so that the factor of 2 cancels and we recover the standard deviation (see Fig. 4).
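As a tiny illustration (the values are made up), any real-valued encoder output becomes a valid, strictly positive standard deviation after this transformation:

import torch

log_var = torch.tensor([-2.0, 0.0, 3.0])  # unconstrained outputs: log(sigma^2) = 2 * log(sigma)
std = torch.exp(0.5 * log_var)            # 0.5 cancels the factor of 2, exp cancels the log
print(std)                                # tensor([0.3679, 1.0000, 4.4817]), always positive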
Our main goal is to find an approximate distribution q(Z|X) that is as close as possible to the posterior distribution p(Z|X), so the Kullback-Leibler (KL) divergence [3] is used to measure the similarity between the two distributions (see Eq. 2). Since the posterior p(Z|X) can't be computed directly, we apply Bayes' rule, which leads us to the Evidence Lower Bound (ELBO). The ELBO is composed of two terms: the reconstruction loss, which most of the time is the Mean Squared Error (MSE) between the reconstructed output and the input, and the KL divergence term, whose goal is to make sure that the mean and standard deviation produced by the encoder match a standard normal distribution; we will see how this is implemented in the code section (last section). Maximizing the ELBO directly leads to minimizing the KL divergence term, which is our main goal, so that the approximated distribution q(Z|X) stays close to the standard normal distribution.
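For reference, the ELBO can be written as follows, with p(Z) the standard normal prior:

$$\log p(X) \;\ge\; \mathrm{ELBO} \;=\; \mathbb{E}_{q(Z|X)}\big[\log p(X|Z)\big] \;-\; \mathrm{KL}\big(q(Z|X)\,\|\,p(Z)\big)$$

The first term corresponds to the reconstruction loss and the second to the KL regularization term that we implement in the code section.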
For more details about the ELBO derivation, we recommend taking a look at this article [4].
VAE vs GANs
While our main topic is the variational autoencoder, we want to point out a few differences between VAEs and GANs.
A generative adversarial network is composed of two main parts: a generator and a discriminator. The generator's goal is to learn the input distribution in order to generate new samples in the same domain as the inputs, while the discriminator learns to distinguish real inputs from generated ones, much like a classification model in supervised learning is called a discriminative model. Training continues until the discriminator can no longer distinguish the fake images produced by the generator from the real inputs (see Figure 5).
We can conclude that there are two main differences between VAEs and GANs:
- VAE models don't have a discriminator.
- GANs sample their latent variable Z directly from a Gaussian distribution without learning the distribution's parameters (see Fig. 5).
VAE Python Implementation (PyTorch)
In this section we will implement a vanilla VAE along with the loss functions you need to train it. The code is divided into the following parts:
- Model
- Reparameterization trick
- KL divergence loss
- Reconstruction loss
- Sampling
Model
import torch
from torch import nn
from torch.nn import functional as F

# ConvTranspose2d output size: (in - 1) * stride - 2 * padding + kernel_size + output_padding
class VanillaVAE(nn.Module):  # subclassing nn.Module directly keeps this self-contained
def __init__(self,
in_channels: int,
latent_dim: int,
hidden_dims = None,
**kwargs) -> None:
super(VanillaVAE, self).__init__()
self.latent_dim = latent_dim
modules = []
if hidden_dims is None:
hidden_dims = [32, 64, 128, 256, 512]
# Build Encoder
for h_dim in hidden_dims:
modules.append(
nn.Sequential(
nn.Conv2d(in_channels, out_channels=h_dim,
kernel_size= 3, stride= 2, padding = 1),
nn.BatchNorm2d(h_dim),
nn.LeakyReLU())
)
in_channels = h_dim
self.encoder = nn.Sequential(*modules)
self.fc_mu = nn.Linear(hidden_dims[-1]*8*8, latent_dim)
self.fc_var = nn.Linear(hidden_dims[-1]*8*8, latent_dim)
# Build Decoder
modules = []
self.decoder_input = nn.Linear(latent_dim, hidden_dims[-1] * 4)
hidden_dims.reverse()
for i in range(len(hidden_dims) - 1):
modules.append(
nn.Sequential(
nn.ConvTranspose2d(hidden_dims[i],
hidden_dims[i + 1],
kernel_size=3,
stride = 2,
padding=1,
output_padding=1),
nn.BatchNorm2d(hidden_dims[i + 1]),
nn.LeakyReLU())
)
self.decoder = nn.Sequential(*modules)
self.final_layer = nn.Sequential(
nn.ConvTranspose2d(hidden_dims[-1],
hidden_dims[-1],
kernel_size=3,
stride=2,
padding=1,
output_padding=1),
nn.BatchNorm2d(hidden_dims[-1]),
nn.LeakyReLU(),
)
self.final_after = nn.Sequential(
nn.ConvTranspose2d(hidden_dims[-1],
hidden_dims[-1],
kernel_size=3,
stride=2,
padding=1,
output_padding=1),
nn.BatchNorm2d(hidden_dims[-1]),
nn.LeakyReLU(),
nn.ConvTranspose2d(hidden_dims[-1],
hidden_dims[-1],
kernel_size=3,
stride=2,
padding=1,
output_padding=1),
nn.BatchNorm2d(hidden_dims[-1]),
nn.LeakyReLU(),
)
self.final_conv=nn.Sequential(nn.Conv2d(hidden_dims[-1], out_channels= 1,
kernel_size= 3, padding= 1),nn.Tanh())
def encode(self, input):
"""
Encodes the input by passing through the encoder network
and returns the latent codes.
:param input: (Tensor) Input tensor to encoder [N x C x H x W]
:return: (Tensor) List of latent codes
"""
result = self.encoder(input)
result = torch.flatten(result, start_dim=1)
# Split the result into mu and var components
# of the latent Gaussian distribution
mu = self.fc_mu(result)
log_var = self.fc_var(result)
return [mu, log_var]
def decode(self, z):
"""
Maps the given latent codes
onto the image space.
:param z: (Tensor) [B x D]
:return: (Tensor) [B x C x H x W]
"""
result = self.decoder_input(z)
result = result.view(-1, 512, 2, 2)
result = self.decoder(result)
result = self.final_layer(result)
result=self.final_after(result)
result=self.final_conv(result)
return result
def reparameterize(self, mu, logvar):
"""
Reparameterization trick to sample from N(mu, var) using
N(0, 1).
:param mu: (Tensor) Mean of the latent Gaussian [B x D]
:param logvar: (Tensor) Log variance of the latent Gaussian [B x D]
:return: (Tensor) [B x D]
"""
std = torch.exp(0.5 * logvar)
eps = torch.randn_like(std)
return eps * std + mu
def forward(self, input, **kwargs):
mu, log_var = self.encode(input)
z = self.reparameterize(mu, log_var)
return [self.decode(z), z, mu, log_var]
def loss_function(self,
*args,
**kwargs) -> dict:
"""
Computes the VAE loss function.
KL(N(\mu, \sigma), N(0, 1)) = \log \frac{1}{\sigma} + \frac{\sigma^2 + \mu^2}{2} - \frac{1}{2}
:param args:
:param kwargs:
:return:
"""
recons = args[0]
input = args[1]
mu = args[2]
log_var = args[3]
kld_weight = kwargs['M_N'] # Account for the minibatch samples from the dataset
recons_loss =F.mse_loss(recons, input)
kld_loss = torch.mean(-0.5 * torch.sum(1 + log_var - mu ** 2 - log_var.exp(), dim = 1), dim = 0)
loss = recons_loss + kld_weight * kld_loss
return {'loss': loss, 'Reconstruction_Loss':recons_loss.detach(), 'KLD':-kld_loss.detach()}
def sample(self,
num_samples:int,
current_device: int, **kwargs) :
"""
Samples from the latent space and return the corresponding
image space map.
:param num_samples: (Int) Number of samples
:param current_device: (Int) Device to run the model
:return: (Tensor)
"""
z = torch.randn(num_samples,
self.latent_dim)
z = z.to(current_device)
samples = self.decode(z)
return samples
def generate(self, x, **kwargs) :
"""
Given an input image x, returns the reconstructed image
:param x: (Tensor) [B x C x H x W]
:return: (Tensor) [B x C x H x W]
"""
return self.forward(x)[0]
In the previous block we implemented the model with all the functions it needs, so let me explain how it can be used.
Note that we will give hints on how to use the model, but we will not implement a training pipeline.
First we initialize an object of the model class, giving it the number of input channels (3) and the dimension of the latent variable (128).
model = VanillaVAE(3, 128)
ex = torch.rand(3, 3, 256, 256)  # a batch of 3 input images of size (3, 256, 256)
x_hat, Z, mean, log_var = model(ex)  # reconstruction, latent variable, mean, log variance
x_hat.size()
###torch.Size([3, 1, 256, 256])
Z.size()
###torch.Size([3, 128])
mean.size()
###torch.Size([3, 128])
log_var.size()
###torch.Size([3, 128])
We can see that the model's forward function returns a list of four elements: the reconstructed output, the latent variable Z, the mean and the log variance.
Reparameterization trick
def reparameterize(self, mu, logvar):
"""
Reparameterization trick to sample from N(mu, var) using
N(0, 1).
:param mu: (Tensor) Mean of the latent Gaussian [B x D]
:param logvar: (Tensor) Log variance of the latent Gaussian [B x D]
:return: (Tensor) [B x D]
"""
std = torch.exp(0.5 * logvar)
eps = torch.randn_like(std)
return eps * std + mu
We can see in the code block above that the model produces the latent variable Z through the reparameterization trick, which combines the mean, the standard deviation and noise sampled from a standard normal distribution (Figure 6). The trick is needed because sampling is not a differentiable operation: by writing Z = mu + std * eps with eps ~ N(0, 1), gradients can flow back through mu and std during training.
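A quick sketch of why this matters for training, reusing the model instance from the earlier usage example (the shapes are illustrative):

mu = torch.zeros(1, 128, requires_grad=True)
log_var = torch.zeros(1, 128, requires_grad=True)
z = model.reparameterize(mu, log_var)  # z = mu + exp(0.5 * log_var) * eps, with eps ~ N(0, 1)
z.sum().backward()                     # gradients flow back through the sampling step
print(mu.grad is not None, log_var.grad is not None)  # True True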
KL divergence loss
Now let's assume you want to calculate the KL divergence between the mean and standard deviation output by the model and the parameters of a standard normal distribution:
def KL(mu, logvar):
    # KL(N(mu, sigma^2) || N(0, 1)), summed over the batch and latent dimensions
    KLD = 0.5 * torch.sum(logvar.exp() - logvar - 1 + mu.pow(2))
    return KLD

kl_loss = KL(mean, log_var)
You may be wondering why we calculate the divergence this way; the next figure illustrates exactly why.
First, remember from the ELBO that your goal is to minimize the difference between the approximated distribution q(Z|X) and the posterior distribution p(Z|X). Since it is hard to determine the true posterior p(Z|X), we replace it with the standard normal distribution, and in (Eq. 3 and 4) we can see how to derive the final formula used in the code.
In the VAE tutorial, the KL divergence between two normal distributions is given by the standard closed-form expression, and we can see how this formula simplifies to the final formula used in the code.
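As a sketch of that simplification, starting from the standard closed-form KL divergence between two Gaussians and fixing the second one to N(0, 1):

$$\mathrm{KL}\big(\mathcal{N}(\mu_1,\sigma_1^2)\,\|\,\mathcal{N}(\mu_2,\sigma_2^2)\big) \;=\; \log\frac{\sigma_2}{\sigma_1} + \frac{\sigma_1^2 + (\mu_1-\mu_2)^2}{2\sigma_2^2} - \frac{1}{2}$$

Setting $\mu_2 = 0$, $\sigma_2 = 1$ and writing $\sigma^2 = e^{\log\sigma^2}$ gives

$$\mathrm{KL}\big(\mathcal{N}(\mu,\sigma^2)\,\|\,\mathcal{N}(0,1)\big) \;=\; \frac{1}{2}\big(e^{\log\sigma^2} - \log\sigma^2 - 1 + \mu^2\big),$$

which is exactly the expression summed in the KL function above.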
Mean squared error
We can see in the code block below that the final loss is composed of the reconstruction loss (Figure 2) between the input and the reconstructed output plus the KL divergence loss.
mse = nn.MSELoss()
recon_loss = mse(x_hat, ex)        # reconstruction loss between the reconstructed output and the input (shapes must match)
final_loss = recon_loss + kl_loss  # total VAE loss: reconstruction + KL divergence
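Alternatively, the model's own loss_function bundles both terms. M_N is the weight of the KL term (typically something like batch_size / dataset_size); the value below is only an illustrative assumption, and the call expects the reconstruction and the input to have the same shape (i.e. matching channel counts):

losses = model.loss_function(x_hat, ex, mean, log_var, M_N=0.005)
losses['loss'], losses['Reconstruction_Loss'], losses['KLD']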
Sampling
In the first code block you can see that the model has a function called sample:
def sample(self,
num_samples:int,
current_device: int, **kwargs) :
"""
Samples from the latent space and return the corresponding
image space map.
:param num_samples: (Int) Number of samples
:param current_device: (Int) Device to run the model
:return: (Tensor)
"""
z = torch.randn(num_samples,
self.latent_dim)
z = z.to(current_device)
samples = self.decode(z)
return samples
You should use this function after training is finished and the weights have been saved. As you can see in the function, it works by sampling the latent variable Z directly from a standard normal distribution (one vector per image you want to generate), without using the encoder, and then passing Z through the decoder to produce the generated output; it returns only the generated samples.
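A quick usage sketch, reusing the model instance from the earlier example (the number of samples and the device are arbitrary choices):

samples = model.sample(num_samples=4, current_device='cpu')  # decode 4 random latent vectors
samples.size()
###torch.Size([4, 1, 256, 256])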
You can find the notebook for the code implementation on our Github.