0% found this document useful (0 votes)
13 views · 4 pages

DL 8

The document provides a Python program that implements image augmentation using Generative Adversarial Networks (GANs) with the MNIST dataset. It includes the setup of a generator and discriminator network, training loop, and saving of generated images and model weights. The program also visualizes the generated images after training.

Uploaded by

MATHAN KUMAR M
Copyright
© All Rights Reserved
We take content rights seriously. If you suspect this is your content, claim it here.
Available Formats
Download as DOCX, PDF, TXT or read online on Scribd
0% found this document useful (0 votes)
13 views · 4 pages

DL 8

The document provides a Python program that implements image augmentation using Generative Adversarial Networks (GANs) with the MNIST dataset. It includes the setup of a generator and discriminator network, training loop, and saving of generated images and model weights. The program also visualizes the generated images after training.

Uploaded by

MATHAN KUMAR M
Copyright
© All Rights Reserved
We take content rights seriously. If you suspect this is your content, claim it here.
Available Formats
Download as DOCX, PDF, TXT or read online on Scribd
You are on page 1/ 4

Ex:8

Image augmentation using GANs

Program:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torchvision.utils import save_image
import os

# Select the compute device: use CUDA when a GPU is available, else CPU.


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Hyperparameters
latent_dim = 100 # Size of the latent vector for the generator input
img_size = 28 # Image size (28x28 for MNIST)
batch_size = 64
num_epochs = 100
learning_rate = 0.0002

# Output directory for per-epoch generated image grids (no-op if it exists).


os.makedirs("gan_augmented_images", exist_ok=True)

# Preprocessing pipeline applied to every MNIST image.


transform = transforms.Compose([
transforms.Resize(img_size),
transforms.ToTensor(),
transforms.Normalize([0.5], [0.5]) # Normalize images between -1 and 1 (matches the generator's Tanh output range)
])

# Load MNIST training split (downloads to ./data on first run).
train_dataset = torchvision.datasets.MNIST(root='./data', train=True, transform=transform, download=True)
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)

# Discriminator network
class Discriminator(nn.Module):
    """MLP discriminator: maps a (N, 1, img_size, img_size) image batch to
    a probability in (0, 1) that each image is real.

    Note: the original version read the module-level global ``img_size``;
    it is now a constructor parameter with the same default (28, MNIST),
    so existing ``Discriminator()`` callers are unaffected.
    """

    def __init__(self, img_size: int = 28):
        """Build the fully-connected discriminator.

        Args:
            img_size: Side length of the square, single-channel input
                images. Defaults to 28 (MNIST).
        """
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(img_size * img_size, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 128),
            nn.LeakyReLU(0.2),
            nn.Linear(128, 1),
            nn.Sigmoid(),  # squash the logit to a probability for BCELoss
        )

    def forward(self, x):
        """Return a (N, 1) tensor of real-vs-fake probabilities."""
        # Flatten (N, C, H, W) -> (N, C*H*W) for the linear stack.
        x = x.view(x.size(0), -1)
        return self.model(x)

# Generator network
class Generator(nn.Module):
    """MLP generator: maps a (N, latent_dim) noise batch to a
    (N, 1, img_size, img_size) image batch with values in [-1, 1] (Tanh),
    matching the Normalize([0.5], [0.5]) preprocessing of the real data.

    Note: the original version read the module-level global ``img_size``;
    it is now a constructor parameter with the same default (28, MNIST),
    so existing ``Generator(latent_dim)`` callers are unaffected.
    """

    def __init__(self, latent_dim, img_size: int = 28):
        """Build the fully-connected generator.

        Args:
            latent_dim: Size of the input noise vector.
            img_size: Side length of the square, single-channel output
                images. Defaults to 28 (MNIST).
        """
        super().__init__()
        self.img_size = img_size
        self.model = nn.Sequential(
            nn.Linear(latent_dim, 128),
            nn.LeakyReLU(0.2),
            nn.Linear(128, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, img_size * img_size),
            nn.Tanh(),  # outputs in [-1, 1] to match the normalized data
        )

    def forward(self, z):
        """Return generated images shaped (N, 1, img_size, img_size)."""
        img = self.model(z)
        img = img.view(img.size(0), 1, self.img_size, self.img_size)
        return img

# Instantiate both networks and move them to the selected device.


generator = Generator(latent_dim).to(device)
discriminator = Discriminator().to(device)

# Binary cross-entropy loss (paired with the discriminator's Sigmoid output)
# and one Adam optimizer per network, so D and G are updated independently.


criterion = nn.BCELoss()
optimizer_G = optim.Adam(generator.parameters(), lr=learning_rate)
optimizer_D = optim.Adam(discriminator.parameters(), lr=learning_rate)

# Training loop: for each batch, (1) update the discriminator on real and
# generated images, then (2) update the generator to fool the discriminator.
for epoch in range(num_epochs):
    for i, (imgs, _) in enumerate(train_loader):
        real_imgs = imgs.to(device)
        batch_size = real_imgs.size(0)  # actual size; last batch may be smaller

        # Label tensors sized to the current batch.
        real_labels = torch.ones(batch_size, 1).to(device)
        fake_labels = torch.zeros(batch_size, 1).to(device)

        # ---- Train Discriminator ----
        real_outputs = discriminator(real_imgs)
        d_loss_real = criterion(real_outputs, real_labels)

        # Generate fake images; detach so this backward pass only updates D.
        z = torch.randn(batch_size, latent_dim).to(device)
        fake_imgs = generator(z)
        fake_outputs = discriminator(fake_imgs.detach())
        d_loss_fake = criterion(fake_outputs, fake_labels)

        # Total discriminator loss over real and fake halves.
        d_loss = d_loss_real + d_loss_fake
        optimizer_D.zero_grad()
        d_loss.backward()
        optimizer_D.step()

        # ---- Train Generator ----
        # G's target: have its samples classified as real by D.
        gen_labels = torch.ones(batch_size, 1).to(device)
        fake_outputs = discriminator(fake_imgs)
        g_loss = criterion(fake_outputs, gen_labels)

        optimizer_G.zero_grad()
        g_loss.backward()
        optimizer_G.step()

        # Print training progress every 100 batches.
        # NOTE: the original print was split mid-f-string across two lines
        # (a syntax error); it is reconstructed as adjacent literals here.
        if (i + 1) % 100 == 0:
            print(f"Epoch [{epoch+1}/{num_epochs}], "
                  f"Batch [{i+1}/{len(train_loader)}], "
                  f"D Loss: {d_loss.item()}, G Loss: {g_loss.item()}")

    # Save the last batch of generated images once per epoch. detach()
    # drops autograd history before serialization; the generator already
    # returns (N, 1, img_size, img_size), so the original reshape was a
    # no-op and is omitted.
    save_image(fake_imgs.detach(),
               f"gan_augmented_images/fake_images_epoch_{epoch+1}.png",
               normalize=True)

print("Training completed! Generated images saved in 'gan_augmented_images' directory.")

import matplotlib.pyplot as plt


import torchvision.transforms as transforms
from PIL import Image
import glob

# Display the image grid saved after each epoch, in epoch order.
# NOTE(review): sorted() orders lexicographically, so epoch_10 sorts before
# epoch_2 — acceptable for a demo, but not strictly chronological.


images = sorted(glob.glob("gan_augmented_images/*.png"))
for img_path in images:
    img = Image.open(img_path)
    plt.imshow(img)
    plt.title(f"Generated Image - {img_path}")
    plt.axis("off")
    plt.show()

# Persist the trained weights (state dicts only, not full module pickles).
torch.save(generator.state_dict(), "generator.pth")
torch.save(discriminator.state_dict(), "discriminator.pth")

# Re-create fresh models and reload the saved weights. weights_only=True
# restricts unpickling to tensors/primitives, avoiding arbitrary-code
# execution from untrusted checkpoint files.
generator = Generator(latent_dim).to(device)
discriminator = Discriminator().to(device)
generator.load_state_dict(torch.load("generator.pth", weights_only=True))
discriminator.load_state_dict(torch.load("discriminator.pth", weights_only=True))

Output:

<All keys matched successfully>

You might also like