
GANs: Unsupervised Deep Learning



Generative Adversarial Networks (GANs) are a class of unsupervised deep learning models used to generate synthetic data samples. A GAN consists of two neural networks: a generator and a discriminator. The generator produces synthetic data samples, while the discriminator tries to distinguish real samples from synthetic ones.

The two networks are trained adversarially: the generator tries to produce synthetic samples that are indistinguishable from real samples, while the discriminator tries to correctly classify samples as real or synthetic. Training continues until the generator's samples are difficult for the discriminator to tell apart from real ones.
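This training objective can be written as the minimax game introduced by Goodfellow et al. (2014), where D(x) is the discriminator's estimated probability that x is real and G(z) maps a noise vector z to a synthetic sample:

$$\min_G \max_D \; \mathbb{E}_{x \sim p_{\text{data}}}\big[\log D(x)\big] + \mathbb{E}_{z \sim p_z}\big[\log\big(1 - D(G(z))\big)\big]$$

The discriminator ascends this objective while the generator descends it, which is exactly the alternating update scheme implemented in the code below.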

The code example below trains a GAN on the MNIST dataset of handwritten digits using the PyTorch library. The generator produces synthetic images of handwritten digits, and the discriminator tries to tell them apart from real MNIST images. The two networks are trained in alternating steps, and the code plots a grid of synthetic images after each epoch to visualize the generator's progress.



import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import numpy as np
import matplotlib.pyplot as plt

# Define the generator network
class Generator(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(Generator, self).__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, hidden_size)
        self.fc3 = nn.Linear(hidden_size, output_size)
    
    def forward(self, x):
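        # LeakyReLU (negative slope 0.2) keeps gradients flowing for negative pre-activations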
        x = F.leaky_relu(self.fc1(x), 0.2)
        x = F.leaky_relu(self.fc2(x), 0.2)
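        # tanh squashes outputs to [-1, 1], matching the normalized pixel range of the real images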
        x = torch.tanh(self.fc3(x))
        return x

# Define the discriminator network
class Discriminator(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(Discriminator, self).__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, hidden_size)
        self.fc3 = nn.Linear(hidden_size, output_size)
    
    def forward(self, x):
        x = F.leaky_relu(self.fc1(x), 0.2)
        x = F.leaky_relu(self.fc2(x), 0.2)
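        # sigmoid maps the final score to a probability that the input is real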
        x = torch.sigmoid(self.fc3(x))
        return x

# Load the MNIST dataset and wrap it in a DataLoader so we iterate over batches of tensors
from torchvision.datasets import MNIST
from torchvision import transforms
from torch.utils.data import DataLoader

mnist = MNIST(root="data/", download=True, transform=transforms.ToTensor())
loader = DataLoader(mnist, batch_size=64, shuffle=True)

# Instantiate the networks: 100-dim noise in, 784 pixels (28x28) out
generator = Generator(input_size=100, hidden_size=256, output_size=784)
discriminator = Discriminator(input_size=784, hidden_size=256, output_size=1)

# Define the loss function and optimizers
criterion = nn.BCELoss()
d_optimizer = optim.Adam(discriminator.parameters(), lr=0.0002)
g_optimizer = optim.Adam(generator.parameters(), lr=0.0002)

# Train the GAN
for epoch in range(100):
    for i, (images, _) in enumerate(loader):
        images = images.view(images.size(0), -1)
        images = images * 2 - 1  # rescale from [0, 1] to [-1, 1] to match the generator's tanh output
        
        # Train the discriminator
        d_optimizer.zero_grad()
        real_labels = torch.ones(images.size(0), 1)
        real_output = discriminator(images)
        d_real_loss = criterion(real_output, real_labels)
        d_real_loss.backward()
        
        z = torch.randn(images.size(0), 100)
        synthetic_images = generator(z)
        synthetic_labels = torch.zeros(images.size(0), 1)
        synthetic_output = discriminator(synthetic_images.detach())  # detach so this backward pass does not update the generator
        d_synthetic_loss = criterion(synthetic_output, synthetic_labels)
        d_synthetic_loss.backward()
        d_optimizer.step()
        
        # Train the generator
        g_optimizer.zero_grad()
        z = torch.randn(images.size(0), 100)
        synthetic_images = generator(z)
        synthetic_labels = torch.ones(images.size(0), 1)
        synthetic_output = discriminator(synthetic_images)
        g_loss = criterion(synthetic_output, synthetic_labels)
        g_loss.backward()
        g_optimizer.step()
        
        if (i+1) % 100 == 0:
            print("Epoch [{}/{}], Step [{}/{}], d_real_loss: {:.4f}, d_synthetic_loss: {:.4f}, g_loss: {:.4f}"
                  .format(epoch+1, 100, i+1, len(loader), d_real_loss.item(), d_synthetic_loss.item(), g_loss.item()))
    
    # Plot some synthetic images
    with torch.no_grad():
        z = torch.randn(64, 100)
        synthetic_images = generator(z)
        # reshape to (N, 1, 28, 28) so make_grid treats each sample as a grayscale image
        synthetic_images = synthetic_images.view(synthetic_images.size(0), 1, 28, 28)
        grid = torchvision.utils.make_grid(synthetic_images, normalize=True)
        plt.imshow(np.transpose(grid.numpy(), (1, 2, 0)))
        plt.show()

This code trains a GAN to generate synthetic images of handwritten digits from the MNIST dataset. The loss function is binary cross-entropy, and both networks use Adam optimizers. Training alternates between updating the discriminator's weights and the generator's weights: the discriminator learns to classify real and synthetic images correctly, while the generator learns to produce images that fool the discriminator. A grid of synthetic images is plotted at the end of each epoch to visually inspect the generator's progress.
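Once training finishes, the discriminator can be discarded and new digits sampled from the generator alone. Here is a minimal sketch, assuming the generator trained above is still in scope (the filename "generator.pth" is just illustrative):

with torch.no_grad():
    z = torch.randn(1, 100)                    # one random latent vector
    digit = generator(z).view(28, 28)          # reshape the 784 outputs to a 28x28 image
plt.imshow(digit.numpy(), cmap="gray")
plt.show()

# Save the generator's weights for later reuse
torch.save(generator.state_dict(), "generator.pth")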

