# Variational Autoencoder (VAE) in PyTorch
This notebook implements a fully connected variational autoencoder (VAE) and trains it on either MNIST or FashionMNIST.

In [1]:
# Imports
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
import torchvision
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np

ModuleNotFoundError: No module named 'torch'

In [2]:
# VAE Model
class VAE(nn.Module):
    """Fully connected variational autoencoder.

    The encoder maps a flattened input to the mean and log-variance of a
    diagonal Gaussian over the latent space; the decoder maps a latent sample
    back to per-element values in (0, 1) via a sigmoid.

    Args:
        input_dim: Number of features in one flattened input
            (784 = 28 * 28 for MNIST-sized images).
        hidden_dim: Width of the single hidden layer in encoder and decoder.
        latent_dim: Dimensionality of the latent space.
    """

    def __init__(self, input_dim=784, hidden_dim=400, latent_dim=20):
        super().__init__()
        # Keep the flattened input size so forward() reshapes consistently with
        # the layer sizes instead of assuming 784.
        self.input_dim = input_dim
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc21 = nn.Linear(hidden_dim, latent_dim)  # mu
        self.fc22 = nn.Linear(hidden_dim, latent_dim)  # logvar
        self.fc3 = nn.Linear(latent_dim, hidden_dim)
        self.fc4 = nn.Linear(hidden_dim, input_dim)

    def encode(self, x):
        """Return ``(mu, logvar)`` of the approximate posterior for flat input ``x``."""
        h1 = F.relu(self.fc1(x))
        return self.fc21(h1), self.fc22(h1)

    def reparameterize(self, mu, logvar):
        """Sample ``z = mu + std * eps`` with ``eps ~ N(0, I)``.

        Writing the sample this way (reparameterization trick) keeps it
        differentiable with respect to ``mu`` and ``logvar``. Note that this
        samples in both train and eval mode.
        """
        std = torch.exp(0.5 * logvar)  # logvar = log(sigma^2) -> sigma
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z):
        """Map latent codes ``z`` to reconstructions with values in (0, 1)."""
        h3 = F.relu(self.fc3(z))
        return torch.sigmoid(self.fc4(h3))

    def forward(self, x):
        """Encode ``x``, sample a latent code, and decode it.

        Args:
            x: Tensor whose elements can be reshaped to ``(batch, input_dim)``,
                e.g. a ``(batch, 1, 28, 28)`` image batch when ``input_dim`` is 784.

        Returns:
            Tuple ``(reconstruction, mu, logvar)`` where the reconstruction has
            shape ``(batch, input_dim)``.
        """
        # reshape (rather than view) also accepts non-contiguous inputs.
        mu, logvar = self.encode(x.reshape(-1, self.input_dim))
        z = self.reparameterize(mu, logvar)
        return self.decode(z), mu, logvar

NameError: name 'nn' is not defined

In [None]:
# Loss Function
def vae_loss(recon_x, x, mu, logvar):
    """Return the VAE loss (negative ELBO), summed over the batch.

    The loss is the binary cross-entropy reconstruction term plus the KL
    divergence between N(mu, exp(logvar)) and a standard normal prior.

    Args:
        recon_x: Reconstructions with values in [0, 1], shape ``(batch, features)``.
        x: Targets with values in [0, 1] and the same number of elements as
            ``recon_x`` (e.g. an unflattened image batch).
        mu: Posterior means.
        logvar: Posterior log-variances, same shape as ``mu``.

    Returns:
        A scalar tensor. It is a sum, not a mean; divide by the number of
        samples for a per-sample value.
    """
    # Flatten the targets to match the reconstructions instead of assuming 784 features.
    BCE = F.binary_cross_entropy(recon_x, x.reshape(recon_x.shape), reduction='sum')
    # Closed-form KL(N(mu, sigma^2) || N(0, 1)), summed over all dimensions.
    KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return BCE + KLD

In [None]:
# Data Loaders
def get_dataloader(dataset_name, batch_size=128, root='./data', train=True, shuffle=True):
    """Build a DataLoader over a torchvision image dataset.

    Images are converted with ``transforms.ToTensor()``. The dataset is
    created with ``download=True``, so torchvision fetches it into ``root``
    when it is missing.

    Args:
        dataset_name: Either ``'MNIST'`` or ``'FashionMNIST'``.
        batch_size: Number of samples per batch.
        root: Directory where the dataset is stored / downloaded.
        train: Load the training split if True, otherwise the test split.
        shuffle: Passed to ``DataLoader``; whether samples are shuffled.

    Returns:
        A ``DataLoader`` yielding ``(images, labels)`` batches.

    Raises:
        ValueError: If ``dataset_name`` is not one of the supported datasets.
    """
    dataset_classes = {
        'MNIST': datasets.MNIST,
        'FashionMNIST': datasets.FashionMNIST,
    }
    try:
        dataset_cls = dataset_classes[dataset_name]
    except KeyError:
        raise ValueError(
            f"Unknown dataset: {dataset_name!r} (expected one of {sorted(dataset_classes)})"
        ) from None
    dataset = dataset_cls(root, train=train, download=True, transform=transforms.ToTensor())
    return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)

In [None]:
# Training Loop
def train(model, dataloader, optimizer, device):
    """Run one training epoch of ``model`` over ``dataloader``.

    Labels from the loader are ignored. Each batch is moved to ``device``,
    scored with ``vae_loss`` and used for one optimizer step.

    Args:
        model: A VAE whose forward returns ``(reconstruction, mu, logvar)``.
        dataloader: Yields ``(data, label)`` pairs; must expose ``.dataset``.
        optimizer: Optimizer over ``model``'s parameters.
        device: Device the batches are moved to.

    Returns:
        The average per-sample loss for the epoch (also printed).
    """
    model.train()
    total_loss = 0.0
    for data, _ in dataloader:
        data = data.to(device)
        optimizer.zero_grad()
        recon_batch, mu, logvar = model(data)
        loss = vae_loss(recon_batch, data, mu, logvar)
        loss.backward()
        optimizer.step()
        # vae_loss sums over the batch, so accumulate and divide by the
        # sample count (not the batch count) for a per-sample average.
        total_loss += loss.item()
    avg_loss = total_loss / len(dataloader.dataset)
    print(f'Average loss: {avg_loss:.4f}')
    return avg_loss

In [None]:
# Visualization
def show_reconstructions(model, dataloader, device, n=8):
    """Plot the first ``n`` images of one batch above their reconstructions.

    Switches ``model`` to eval mode (and leaves it there). Does nothing if
    ``dataloader`` yields no batches.

    Args:
        model: A VAE whose forward returns ``(reconstruction, mu, logvar)``.
        dataloader: Yields ``(images, labels)`` batches.
        device: Device the batch is moved to before the forward pass.
        n: Number of image/reconstruction pairs to display.
    """
    model.eval()
    batch = next(iter(dataloader), None)
    if batch is None:
        return
    with torch.no_grad():
        data = batch[0].to(device)
        recon, _, _ = model(data)
        # Restore the flat reconstructions to the input's image shape instead of
        # hard-coding (1, 28, 28); make_grid with nrow=n then puts the n
        # originals on the first row and their reconstructions on the second.
        comparison = torch.cat([data[:n], recon.reshape(data.shape)[:n]])
        grid = torchvision.utils.make_grid(comparison.cpu(), nrow=n)
    plt.figure(figsize=(12, 4))
    plt.imshow(grid.permute(1, 2, 0))  # CHW -> HWC for matplotlib
    plt.axis('off')
    plt.show()

In [None]:
# Run the full pipeline
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
use_cuda = torch.cuda.is_available()
device = torch.device("cuda") if use_cuda else torch.device("cpu")

# Which dataset to train on: 'MNIST' or 'FashionMNIST'.
dataset_name = 'MNIST'
dataloader = get_dataloader(dataset_name)

vae = VAE().to(device)
optimizer = optim.Adam(vae.parameters(), lr=1e-3)

num_epochs = 5
for epoch in range(1, num_epochs + 1):
    print(f'Epoch {epoch}')
    train(vae, dataloader, optimizer, device)

# Visual check: originals next to their reconstructions.
show_reconstructions(vae, dataloader, device)