
Building a Generative Adversarial Network using Keras

Last Updated : 12 Jul, 2025

Generative Adversarial Networks (GANs) are deep learning models built from two neural networks: a generator and a discriminator. The two networks are trained together in an adversarial manner.

  • The generator tries to produce fake data that resembles the real data.
  • The discriminator tries to distinguish the real data from the generator's fake data.
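
The adversarial setup described above is commonly written as a minimax objective. The formula below is the standard formulation from the GAN literature and is not stated in this article. Here D(x) is the discriminator's estimated probability that x is real, G(z) is the generator's output for noise z, and p_data and p_z are assumed names for the data and noise distributions.

```latex
\min_G \max_D \; V(D, G)
  = \mathbb{E}_{x \sim p_{\text{data}}}\left[\log D(x)\right]
  + \mathbb{E}_{z \sim p_z}\left[\log\left(1 - D(G(z))\right)\right]
```

The discriminator tries to make both terms large, while the generator tries to make the second term small by pushing D(G(z)) towards 1.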

GANs are widely used in fields like image generation, video creation and text-to-image synthesis. In this article we will build a simple GAN using Keras.

Below is the step-by-step implementation of a GAN:

1. Importing Libraries

Here we will be using numpy, matplotlib and keras.

Python
import numpy as np 
import matplotlib.pyplot as plt 
import keras 
from keras.layers import Input, Dense, Reshape, Flatten, Dropout 
from keras.layers import BatchNormalization, Activation, ZeroPadding2D 
from keras.layers import LeakyReLU 
from keras.layers import UpSampling2D, Conv2D 
from keras.models import Sequential, Model 
from keras.optimizers import Adam,SGD 

2. Loading and Preprocessing the Dataset

Here we load the CIFAR-10 dataset and keep only the images of a single class (class 8, which is the "ship" class in CIFAR-10).

  • keras.datasets.cifar10.load_data(): Loads the CIFAR-10 dataset, which has 60,000 32x32 color images in 10 classes, split into 50,000 training and 10,000 test images. Only the training split is used here.
  • X[y.flatten() == 8]: Keeps only the images whose label is 8.
Python
(X, y), (_, _) = keras.datasets.cifar10.load_data() 

X = X[y.flatten() == 8]
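
The boolean-mask filtering used above can be tried on a small stand-in array without downloading the dataset. This is a minimal sketch: `X_toy` and `y_toy` are made-up arrays that only mimic the layout of the CIFAR-10 training data (images on the first axis, labels of shape (N, 1)).

```python
import numpy as np

# Toy stand-in for CIFAR-10: 6 "images" of shape 2x2x3 and labels of shape (6, 1),
# mirroring the (N, 1) label layout of the real training labels.
X_toy = np.arange(6 * 2 * 2 * 3).reshape(6, 2, 2, 3)
y_toy = np.array([[8], [3], [8], [0], [8], [5]])

# flatten() turns the (6, 1) labels into a (6,) vector, so the comparison
# yields one boolean per image that can index the first axis of X_toy.
mask = y_toy.flatten() == 8
X_ships = X_toy[mask]

print(mask.sum())     # 3
print(X_ships.shape)  # (3, 2, 2, 3)
```

Only the three images whose label is 8 survive, and their height, width and channel axes are untouched.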

3. Defining Input Shape and Latent Dimension

Here we define the shape of the input image and the size of the latent vector.

  • image_shape: Defines the input image shape (32x32 with 3 color channels).
  • latent_dimensions: Specifies the size of the latent vector, i.e. the noise input for the generator.
Python
image_shape = (32, 32, 3)
latent_dimensions = 100
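
To make the latent vector concrete, the sketch below draws a batch of noise vectors of that size from a standard normal distribution. The batch size of 16 and the seeded `default_rng` generator are illustrative choices; the article itself calls `np.random.normal` directly.

```python
import numpy as np

latent_dimensions = 100   # same value as in the article
batch_size = 16           # illustrative value (enough for a 4x4 grid of samples)

rng = np.random.default_rng(0)  # seeded generator; the seed is an arbitrary choice
noise = rng.normal(loc=0.0, scale=1.0, size=(batch_size, latent_dimensions))

print(noise.shape)  # (16, 100)
```

Each row is one latent vector, so the generator maps a (16, 100) batch to 16 images.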

4. Building the Generator

Here we define the generator, which takes random noise as input and outputs an image.

  • Dense: A fully connected layer used to transform the latent vector into a higher-dimensional representation.
  • Reshape: Changes the shape of the output from Dense to make it suitable for convolution.
  • UpSampling2D: Upsamples the image to a higher resolution.
  • Conv2D: Convolutional layers to process the image and generate features.
  • Activation("tanh"): Squashes the pixel values of the generated image into the range [-1, 1].
Python
def build_generator(): 
    model = Sequential() 

    model.add(Dense(128 * 8 * 8, activation="relu", input_dim=latent_dimensions)) 
    model.add(Reshape((8, 8, 128))) 

    model.add(UpSampling2D()) 
    model.add(Conv2D(128, kernel_size=3, padding="same")) 
    model.add(BatchNormalization(momentum=0.78)) 
    model.add(Activation("relu")) 

    model.add(UpSampling2D()) 
    model.add(Conv2D(64, kernel_size=3, padding="same")) 
    model.add(BatchNormalization(momentum=0.78)) 
    model.add(Activation("relu")) 

    model.add(Conv2D(3, kernel_size=3, padding="same")) 
    model.add(Activation("tanh")) 

    noise = Input(shape=(latent_dimensions,)) 
    image = model(noise) 

    return Model(noise, image)
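
The way the generator grows an 8x8 feature map into a 32x32 image can be followed with plain NumPy. This is a rough sketch, not the Keras implementation: it imitates Reshape and UpSampling2D with its default settings (factor 2, nearest-neighbour interpolation) and leaves out the convolutions, which keep the spatial size because they use padding="same". The helper `upsample_nearest` is a hypothetical name.

```python
import numpy as np

# Stand-in for the Dense layer's output for one sample: 128 * 8 * 8 = 8192 values.
dense_out = np.arange(128 * 8 * 8, dtype=np.float32)

# Reshape((8, 8, 128)) reinterprets the flat vector as an 8x8 map with 128 channels.
feature_map = dense_out.reshape(8, 8, 128)

def upsample_nearest(x, factor=2):
    """Repeat every row and column `factor` times (nearest-neighbour upsampling)."""
    return np.repeat(np.repeat(x, factor, axis=0), factor, axis=1)

up1 = upsample_nearest(feature_map)  # 8x8   -> 16x16
up2 = upsample_nearest(up1)          # 16x16 -> 32x32
print(feature_map.shape, up1.shape, up2.shape)
# (8, 8, 128) (16, 16, 128) (32, 32, 128)

# tanh maps any real value into (-1, 1), which bounds the generator's output.
squashed = np.tanh(np.array([-5.0, 0.0, 5.0]))
print(squashed)
```

Two upsampling steps are therefore enough to reach the 32x32 resolution of the training images.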

5. Building the Discriminator

Here we define the discriminator, which classifies images as real or fake.

  • Conv2D: Convolutional layers used to extract features from images.
  • LeakyReLU: An activation function that allows a small slope for negative values.
  • Dropout: A regularization technique that helps prevent overfitting.
  • Flatten: Flattens the final feature maps into a 1D vector for classification.
  • Dense: Fully connected layer to classify the image as real or fake.
Python
def build_discriminator(): 
    model = Sequential() 

    model.add(Conv2D(32, kernel_size=3, strides=2, input_shape=image_shape, padding="same")) 
    model.add(LeakyReLU(alpha=0.2)) 
    model.add(Dropout(0.25)) 
    
    model.add(Conv2D(64, kernel_size=3, strides=2, padding="same")) 
    model.add(ZeroPadding2D(padding=((0,1),(0,1)))) 
    model.add(BatchNormalization(momentum=0.82)) 
    model.add(LeakyReLU(alpha=0.25)) 
    model.add(Dropout(0.25)) 
    
    model.add(Conv2D(128, kernel_size=3, strides=2, padding="same")) 
    model.add(BatchNormalization(momentum=0.82)) 
    model.add(LeakyReLU(alpha=0.2)) 
    model.add(Dropout(0.25)) 
    
    model.add(Conv2D(256, kernel_size=3, strides=1, padding="same")) 
    model.add(BatchNormalization(momentum=0.8)) 
    model.add(LeakyReLU(alpha=0.25)) 
    model.add(Dropout(0.25)) 
    
    model.add(Flatten()) 
    model.add(Dense(1, activation='sigmoid')) 

    image = Input(shape=image_shape) 
    validity = model(image) 

    return Model(image, validity)
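
It can help to trace how the spatial size shrinks through the discriminator. The sketch below assumes the usual rule for padding="same", where the output size is the input size divided by the stride and rounded up; `conv_same` is a hypothetical helper, not a Keras function.

```python
import math

def conv_same(size, stride):
    """Output size of a convolution with padding="same": ceil(size / stride)."""
    return math.ceil(size / stride)

size = 32                      # input height/width
size = conv_same(size, 2)      # Conv2D(32, strides=2)  -> 16
size = conv_same(size, 2)      # Conv2D(64, strides=2)  -> 8
size = size + 1                # ZeroPadding2D(((0, 1), (0, 1))) -> 9
size = conv_same(size, 2)      # Conv2D(128, strides=2) -> 5
size = conv_same(size, 1)      # Conv2D(256, strides=1) -> 5

flattened = size * size * 256  # Flatten() over a 5x5 map with 256 channels
print(size, flattened)         # 5 6400
```

Under these assumptions the final Dense layer receives a vector of 6,400 features per image.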

6. Displaying Generated Images

Here we define a helper that visualizes the images produced by the generator.

  • plt.subplots: Creates a grid of subplots to display multiple images.
  • 0.5 * generated_images + 0.5: Rescales the generated images back to the range [0, 1].
Python
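def display_images():
    # Relies on the global `generator` built in the combined-model step,
    # so call this function only after the generator exists
    r, c = 4, 4
    noise = np.random.normal(0, 1, (r * c, latent_dimensions))
    generated_images = generator.predict(noise)

    # Rescale from the tanh range [-1, 1] to [0, 1] for imshow
    generated_images = 0.5 * generated_images + 0.5

    fig, axs = plt.subplots(r, c)
    count = 0
    for i in range(r):
        for j in range(c):
            axs[i, j].imshow(generated_images[count])
            axs[i, j].axis('off')
            count += 1
    plt.show()
    plt.close()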
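
The rescaling `0.5 * generated_images + 0.5` undoes the `X / 127.5 - 1` scaling that is applied to the training images in the training step. A small NumPy check with a few hand-picked pixel values (chosen only for illustration) makes the round trip visible:

```python
import numpy as np

pixels = np.array([0.0, 64.0, 127.5, 255.0])  # sample raw pixel values in [0, 255]

normalized = pixels / 127.5 - 1.0      # training-time scaling to [-1, 1]
displayable = 0.5 * normalized + 0.5   # display-time scaling to [0, 1]

print(normalized.min(), normalized.max())    # -1.0 1.0
print(np.allclose(displayable, pixels / 255.0))  # True
```

So the displayed values are the original pixel intensities divided by 255, which is the floating-point range that imshow expects.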

7. Building and Compiling the Discriminator

We will build and compile the discriminator, then freeze its weights for the combined model's training.

  • Adam(0.0002, 0.5): Adam optimizer with a learning rate of 0.0002 and a beta_1 value of 0.5.
  • We use binary cross-entropy as the loss function.
  • trainable = False: Freezes the weights of the discriminator so that only the generator gets trained during the combined model's training.
Python
discriminator = build_discriminator() 
discriminator.compile(loss='binary_crossentropy', 
                    optimizer=Adam(0.0002,0.5), 
                    metrics=['accuracy']) 

discriminator.trainable = False
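
To see what the binary cross-entropy loss measures for the discriminator, here is a minimal NumPy sketch. The function `binary_crossentropy`, its clipping constant `eps` and the prediction values are assumptions made for illustration; this is not the Keras implementation.

```python
import numpy as np

def binary_crossentropy(y_true, y_pred, eps=1e-7):
    """Mean binary cross-entropy; predictions are clipped to avoid log(0)."""
    y_pred = np.clip(y_pred, eps, 1.0 - eps)
    per_sample = -(y_true * np.log(y_pred) + (1.0 - y_true) * np.log(1.0 - y_pred))
    return float(np.mean(per_sample))

# Hypothetical discriminator outputs (probability that each image is real).
pred_on_real = np.array([0.9, 0.8, 0.7])
pred_on_fake = np.array([0.2, 0.1, 0.4])

loss_real = binary_crossentropy(np.ones(3), pred_on_real)   # real images labelled 1
loss_fake = binary_crossentropy(np.zeros(3), pred_on_fake)  # fake images labelled 0
print(round(loss_real, 4), round(loss_fake, 4))
```

Both losses shrink as the discriminator assigns probabilities closer to 1 for real images and closer to 0 for generated ones.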

8. Building the Combined Model

We will create the combined GAN model by connecting the generator and the discriminator.

  • combined_network: A model that takes noise as input, generates an image and then checks if the image is real or fake using the discriminator.
Python
generator = build_generator() 

z = Input(shape=(latent_dimensions,)) 
image = generator(z) 

valid = discriminator(image) 

combined_network = Model(z, valid) 
combined_network.compile(loss='binary_crossentropy', 
                        optimizer=Adam(0.0002,0.5))
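
Because the combined model is trained against the label 1, binary cross-entropy for a single generated image reduces to -log(D(G(z))). The sketch below evaluates that expression for a few made-up discriminator outputs; `generator_loss` is a hypothetical helper used only for this illustration.

```python
import math

def generator_loss(d_of_fake):
    """Binary cross-entropy against target 1 for one prediction: -log(D(G(z)))."""
    return -math.log(d_of_fake)

for p in (0.1, 0.5, 0.9):
    print(f"D(G(z)) = {p:.1f} -> loss = {generator_loss(p):.3f}")
# D(G(z)) = 0.1 -> loss = 2.303
# D(G(z)) = 0.5 -> loss = 0.693
# D(G(z)) = 0.9 -> loss = 0.105
```

The loss falls as the discriminator is fooled more strongly, which is the signal that drives the generator's updates.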

9. Training the GAN

We will train the GAN by alternating between training the discriminator and generator.

  • train_on_batch: Trains the models on a single batch of data.
  • discriminator.train_on_batch: Trains the discriminator on real and fake images.
  • combined_network.train_on_batch: Trains the generator to produce images that can fool the discriminator.
  • We use a batch size of 32.
  • We train for 12,500 iterations, which the code calls epochs; each one processes a single random batch rather than a full pass over the data. Sample images are displayed every 2,500 iterations to show the progress.
Python
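num_epochs = 12500        # number of training iterations (one batch each)
batch_size = 32
display_interval = 2500
losses = []               # (discriminator loss/accuracy, generator loss) per iteration

# Scale pixel values from [0, 255] to [-1, 1] to match the generator's tanh output
X = (X / 127.5) - 1.

# Target labels with a small amount of uniform noise in [0, 0.05) added:
# real images are labelled close to 1, generated images close to 0
valid = np.ones((batch_size, 1))
valid += 0.05 * np.random.random(valid.shape)
fake = np.zeros((batch_size, 1))
fake += 0.05 * np.random.random(fake.shape)

for epoch in range(num_epochs):
    # Draw a random batch of real images
    index = np.random.randint(0, X.shape[0], batch_size)
    images = X[index]

    # Generate a batch of fake images from random noise
    noise = np.random.normal(0, 1, (batch_size, latent_dimensions))
    generated_images = generator.predict(noise, verbose=0)

    # Train the discriminator on the real batch and on the fake batch
    discm_loss_real = discriminator.train_on_batch(images, valid)
    discm_loss_fake = discriminator.train_on_batch(generated_images, fake)
    discm_loss = 0.5 * np.add(discm_loss_real, discm_loss_fake)

    # Train the generator through the combined model, labelling its fakes as real
    genr_loss = combined_network.train_on_batch(noise, valid)
    losses.append((discm_loss, genr_loss))

    if epoch % display_interval == 0:
        display_images()

# Show the final samples after the last iteration (epoch 12500)
display_images()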
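
The noisy target labels and the random batch selection used in the training code can be inspected on their own. This is a small sketch with a seeded `default_rng` generator (the article uses `np.random` directly) and an illustrative dataset size of 5,000 images.

```python
import numpy as np

rng = np.random.default_rng(0)  # seeded for repeatability; the seed is arbitrary
batch_size = 32

# Noisy targets as in the training code: uniform noise in [0, 0.05) is added.
valid = np.ones((batch_size, 1)) + 0.05 * rng.random((batch_size, 1))
fake = np.zeros((batch_size, 1)) + 0.05 * rng.random((batch_size, 1))
print(valid.min() >= 1.0, valid.max() <= 1.05)  # True True
print(fake.min() >= 0.0, fake.max() <= 0.05)    # True True

# Sampling a batch: random indices drawn with replacement from the dataset.
num_images = 5000  # illustrative dataset size
index = rng.integers(0, num_images, size=batch_size)
print(index.shape)  # (32,)
```

Note that the "real" targets end up slightly above 1 and the "fake" targets slightly above 0, and that the same image can appear more than once in a batch because the indices are drawn with replacement.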

Output: 4x4 grids of generated images at epochs 0, 2500, 5000, 7500, 10000 and 12500.

The samples shown at each 2,500-epoch interval become progressively more realistic. This incremental improvement shows how the generator learns to create more convincing images as training advances.

While this is a basic example, GANs can be extended with deeper and more specialized architectures. By combining a generator and a discriminator in a competitive setup, GANs enable the creation of realistic synthetic images from random noise. You can explore advanced GAN variants such as CycleGAN, StyleGAN and Conditional GANs, which are used for tasks like high-resolution image generation and style transfer.

