Building a Generative Adversarial Network using Keras

Last Updated: 12 Jul, 2025

Generative Adversarial Networks (GANs) are deep learning models built from two neural networks: a generator and a discriminator. The two networks are trained together in an adversarial setup: the generator tries to produce fake data that resembles the real data, while the discriminator tries to distinguish real data from fake. GANs have revolutionized fields like image generation, video creation and even text-to-image synthesis. In this article we will build a simple GAN using Keras.

Below is the step-by-step implementation of a GAN:

1. Importing Libraries

Here we will be using numpy, matplotlib and keras.

```python
import numpy as np
import matplotlib.pyplot as plt
import keras
from keras.layers import Input, Dense, Reshape, Flatten, Dropout
from keras.layers import BatchNormalization, Activation, ZeroPadding2D
from keras.layers import LeakyReLU
from keras.layers import UpSampling2D, Conv2D
from keras.models import Sequential, Model
from keras.optimizers import Adam, SGD
```

2. Loading and Preprocessing the Dataset

Here we load the CIFAR-10 dataset and keep only the images of a single class (class 8, "ship").

- keras.datasets.cifar10.load_data(): loads the CIFAR-10 dataset, which has 60,000 32x32 color images in 10 classes.
- X[y.flatten() == 8]: filters out only the images of class 8.

```python
(X, y), (_, _) = keras.datasets.cifar10.load_data()
X = X[y.flatten() == 8]
```

3. Defining Input Shape and Latent Dimension

This defines the shape of the input images and the size of the latent vector.

- image_shape: the input image shape (32x32 with 3 color channels).
- latent_dimensions: the size of the latent vector, i.e. the noise input for the generator.

```python
image_shape = (32, 32, 3)
latent_dimensions = 100
```

4. Building the Generator

The generator takes random noise as input and outputs an image.

- Dense: a fully connected layer that transforms the latent vector into a higher-dimensional representation.
- Reshape: turns the Dense output into a feature map suitable for convolution.
- UpSampling2D: upsamples the feature map to a higher resolution.
- Conv2D: convolutional layers that process the feature maps and generate image features.
- Activation("tanh"): ensures the pixel values of the generated image are in the range [-1, 1].

```python
def build_generator():
    model = Sequential()

    # Project the noise vector into an 8x8x128 feature map
    model.add(Dense(128 * 8 * 8, activation="relu", input_dim=latent_dimensions))
    model.add(Reshape((8, 8, 128)))

    # 8x8 -> 16x16
    model.add(UpSampling2D())
    model.add(Conv2D(128, kernel_size=3, padding="same"))
    model.add(BatchNormalization(momentum=0.78))
    model.add(Activation("relu"))

    # 16x16 -> 32x32
    model.add(UpSampling2D())
    model.add(Conv2D(64, kernel_size=3, padding="same"))
    model.add(BatchNormalization(momentum=0.78))
    model.add(Activation("relu"))

    # Final 3-channel image with pixel values in [-1, 1]
    model.add(Conv2D(3, kernel_size=3, padding="same"))
    model.add(Activation("tanh"))

    noise = Input(shape=(latent_dimensions,))
    image = model(noise)
    return Model(noise, image)
```
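Before wiring the generator into the full GAN, it is worth confirming that the upsampling path really ends at 32x32x3: the Dense output is reshaped to 8x8x128 and each UpSampling2D doubles the spatial size. A minimal sanity check, assuming the build_generator and latent_dimensions defined above (the throwaway model and the batch of 4 noise vectors are purely illustrative):

```python
# Build a throwaway generator and inspect its layers.
test_generator = build_generator()
test_generator.summary()

# Four illustrative noise vectors drawn from N(0, 1).
test_noise = np.random.normal(0, 1, (4, latent_dimensions))
test_samples = test_generator.predict(test_noise)

# Expect (4, 32, 32, 3): 8x8 -> 16x16 -> 32x32 after two UpSampling2D layers.
print(test_samples.shape)
```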
5. Building the Discriminator

The discriminator classifies images as real or fake.

- Conv2D: convolutional layers that extract features from the images.
- LeakyReLU: an activation function that allows a small slope for negative values.
- Dropout: a regularization technique that helps prevent overfitting.
- ZeroPadding2D: pads the feature map so the spatial dimensions line up for the following strided convolution.
- Flatten: flattens the feature maps into a 1D vector for classification.
- Dense: a fully connected layer that outputs the probability that the image is real.

```python
def build_discriminator():
    model = Sequential()

    model.add(Conv2D(32, kernel_size=3, strides=2,
                     input_shape=image_shape, padding="same"))
    model.add(LeakyReLU(alpha=0.2))
    model.add(Dropout(0.25))

    model.add(Conv2D(64, kernel_size=3, strides=2, padding="same"))
    model.add(ZeroPadding2D(padding=((0, 1), (0, 1))))
    model.add(BatchNormalization(momentum=0.82))
    model.add(LeakyReLU(alpha=0.25))
    model.add(Dropout(0.25))

    model.add(Conv2D(128, kernel_size=3, strides=2, padding="same"))
    model.add(BatchNormalization(momentum=0.82))
    model.add(LeakyReLU(alpha=0.2))
    model.add(Dropout(0.25))

    model.add(Conv2D(256, kernel_size=3, strides=1, padding="same"))
    model.add(BatchNormalization(momentum=0.8))
    model.add(LeakyReLU(alpha=0.25))
    model.add(Dropout(0.25))

    model.add(Flatten())
    model.add(Dense(1, activation='sigmoid'))

    image = Input(shape=image_shape)
    validity = model(image)
    return Model(image, validity)
```

6. Displaying Generated Images

This helper visualizes a 4x4 grid of images produced by the generator.

- plt.subplots: creates a grid of subplots to display multiple images.
- 0.5 * generated_images + 0.5: rescales the generated images from the tanh range [-1, 1] back to [0, 1] for display.

```python
def display_images():
    r, c = 4, 4
    noise = np.random.normal(0, 1, (r * c, latent_dimensions))
    generated_images = generator.predict(noise)

    # Rescale from [-1, 1] to [0, 1] for display
    generated_images = 0.5 * generated_images + 0.5

    fig, axs = plt.subplots(r, c)
    count = 0
    for i in range(r):
        for j in range(c):
            axs[i, j].imshow(generated_images[count])
            axs[i, j].axis('off')
            count += 1
    plt.show()
    plt.close()
```

7. Building and Compiling the Discriminator

We build and compile the discriminator, then freeze its weights so that only the generator is updated when the combined model is trained.

- Adam(0.0002, 0.5): the Adam optimizer with learning rate 0.0002 and beta_1 = 0.5.
- Binary cross-entropy is used as the loss, since real-vs-fake is a binary classification problem.
- trainable = False: freezes the discriminator's weights in any model compiled after this point. Because the discriminator itself was compiled before the flag was changed, it still trains normally on its own batches.

```python
discriminator = build_discriminator()
discriminator.compile(loss='binary_crossentropy',
                      optimizer=Adam(0.0002, 0.5),
                      metrics=['accuracy'])

# Freeze the discriminator for models compiled from here on
# (i.e. the combined model below)
discriminator.trainable = False
```

8. Building the Combined Model

We create the combined GAN model by connecting the generator to the (frozen) discriminator.

- combined_network: a model that takes noise as input, generates an image and then scores it with the discriminator.

```python
generator = build_generator()

z = Input(shape=(latent_dimensions,))
image = generator(z)
valid = discriminator(image)

combined_network = Model(z, valid)
combined_network.compile(loss='binary_crossentropy',
                         optimizer=Adam(0.0002, 0.5))
```
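A subtlety that is easy to miss: because discriminator.trainable = False was set after the discriminator was compiled, the standalone discriminator still updates its weights on its own train_on_batch calls, while the copy inside combined_network stays frozen. A quick way to check this with the models defined above (the exact weight counts depend on the architecture):

```python
# Only the generator's weights should be trainable in the combined model.
print(len(combined_network.trainable_weights))  # == len(generator.trainable_weights)
print(len(generator.trainable_weights))
print(len(discriminator.trainable_weights))     # 0 while the flag is False
```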
9. Training the GAN

We train the GAN by alternating between training the discriminator and the generator.

- train_on_batch: trains a model on a single batch of data.
- discriminator.train_on_batch: trains the discriminator on a batch of real images and a batch of fake images.
- combined_network.train_on_batch: trains the generator to produce images that fool the discriminator.
- We use a batch size of 32.
- We run 12,500 training steps (each "epoch" here processes a single batch rather than a full pass over the dataset) and display outputs every 2,500 steps to see the difference.

```python
num_epochs = 12500
batch_size = 32
display_interval = 2500
losses = []

# Scale pixel values to [-1, 1] to match the generator's tanh output
X = (X / 127.5) - 1.

# Slightly noisy labels (label smoothing), a common trick to stabilize GAN training
valid = np.ones((batch_size, 1))
valid += 0.05 * np.random.random(valid.shape)
fake = np.zeros((batch_size, 1))
fake += 0.05 * np.random.random(fake.shape)

for epoch in range(num_epochs):
    # Train the discriminator on a random real batch and a generated batch
    index = np.random.randint(0, X.shape[0], batch_size)
    images = X[index]

    noise = np.random.normal(0, 1, (batch_size, latent_dimensions))
    generated_images = generator.predict(noise)

    discm_loss_real = discriminator.train_on_batch(images, valid)
    discm_loss_fake = discriminator.train_on_batch(generated_images, fake)
    discm_loss = 0.5 * np.add(discm_loss_real, discm_loss_fake)

    # Train the generator to push the discriminator's output towards "real"
    genr_loss = combined_network.train_on_batch(noise, valid)

    if epoch % display_interval == 0:
        display_images()
```

[Generated image grids shown at epochs 0, 2500, 5000, 7500, 10000 and 12500]

We can observe that with each 2,500-step interval the quality of the generated images improves significantly. This incremental enhancement shows how the generator progressively learns to create more realistic images as training advances. While this is a basic example, GANs can be extended with deeper and more complex architectures. By combining a generator and a discriminator in a competitive setup, GANs enable the creation of realistic synthetic images from random noise. You can explore advanced GAN variants such as CycleGAN, StyleGAN and Conditional GANs, which are used for tasks like high-resolution image generation, style transfer and more.
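As a final, illustrative sketch: once the training loop finishes, the generator can be used on its own to sample and save new images, and the model can be persisted for later reuse. The file names below are examples, not part of the article:

```python
# Sample 16 new images from the trained generator.
noise = np.random.normal(0, 1, (16, latent_dimensions))
images = generator.predict(noise)
images = 0.5 * images + 0.5  # rescale from tanh range [-1, 1] to [0, 1]

# Save each image to disk (example file names).
for i, img in enumerate(images):
    plt.imsave(f"generated_{i}.png", img)

# Persist the generator itself for later sampling (example path).
generator.save("cifar_ship_generator.h5")
```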