# Data preparation steps:
#   1. Load the MNIST dataset.
#   2. Normalize the images to [-1, 1], which generally improves GAN training.
#   3. Add a channel dimension to the images.
#   4. Set the shuffle buffer size and the batch size.
import tensorflow as tf
from tensorflow.keras.datasets import mnist
import numpy as np
# In [ ]:
from tensorflow.keras import layers
def make_generator_model():
    """Build the GAN generator: map a 100-dim noise vector to a 28x28x1 image.

    The noise is projected to a 7x7x256 feature map and then upsampled with
    transposed convolutions (7x7 -> 7x7 -> 14x14 -> 28x28). The final layer
    uses ``tanh``, so pixel values lie in [-1, 1]. That is the same range the
    training images are normalized to and the range the plotting code undoes
    with ``* 127.5 + 127.5``. The output shape matches the discriminator's
    ``[28, 28, 1]`` input.

    Returns:
        An uncompiled ``tf.keras.Sequential`` model.
    """
    model = tf.keras.Sequential()

    # Project the noise vector into a small spatial feature map.
    # use_bias=False: the following BatchNormalization has its own learned
    # offset, so a bias here would be redundant.
    model.add(layers.Dense(7 * 7 * 256, use_bias=False, input_shape=(100,)))
    model.add(layers.BatchNormalization())
    model.add(layers.LeakyReLU())
    model.add(layers.Reshape((7, 7, 256)))

    # stride 1 with 'same' padding keeps 7x7; only the depth changes (256 -> 128).
    model.add(layers.Conv2DTranspose(128, (5, 5), strides=(1, 1),
                                     padding='same', use_bias=False))
    model.add(layers.BatchNormalization())
    model.add(layers.LeakyReLU())

    # stride 2 doubles the spatial size: 7x7 -> 14x14.
    model.add(layers.Conv2DTranspose(64, (5, 5), strides=(2, 2),
                                     padding='same', use_bias=False))
    model.add(layers.BatchNormalization())
    model.add(layers.LeakyReLU())

    # 14x14 -> 28x28 with a single (grayscale) channel; tanh bounds the
    # output to [-1, 1] to match the normalized real images.
    model.add(layers.Conv2DTranspose(1, (5, 5), strides=(2, 2),
                                     padding='same', use_bias=False,
                                     activation='tanh'))

    return model


generator = make_generator_model()
# In [ ]:
def make_discriminator_model():
    """Return the GAN discriminator as an uncompiled ``tf.keras.Sequential``.

    It takes ``[28, 28, 1]`` images and produces one raw, unactivated score
    (a logit) per image. No sigmoid is applied, so any binary cross-entropy
    built on top of it must use ``from_logits=True``.
    """
    stack = [
        # Strided convolution: 28x28 -> 14x14 with 64 filters.
        layers.Conv2D(64, (5, 5), strides=(2, 2), padding='same',
                      input_shape=[28, 28, 1]),
        layers.LeakyReLU(),
        # Regularization; Keras only drops units when called with training=True.
        layers.Dropout(0.3),
        layers.Flatten(),
        # Single unit, no activation -> logit.
        layers.Dense(1),
    ]
    return tf.keras.Sequential(stack)


discriminator = make_discriminator_model()
# In [ ]:
# Loss functions for both generator and discriminator.
# Discriminator outputs are treated as raw logits, hence from_logits=True.
cross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=True)


def generator_loss(fake_output):
    """Score how well the generator fooled the discriminator.

    Args:
        fake_output: Discriminator logits for a batch of generated images.

    Returns:
        Binary cross-entropy of ``fake_output`` against all-ones ("real")
        labels. It is lower when the discriminator believes the fakes are real.
    """
    return cross_entropy(tf.ones_like(fake_output), fake_output)


def discriminator_loss(real_output, fake_output):
    """Score how well the discriminator separates real images from fakes.

    Args:
        real_output: Discriminator logits for a batch of real images.
        fake_output: Discriminator logits for a batch of generated images.

    Returns:
        Sum of two cross-entropy terms: real images scored against all-ones
        labels, and generated images scored against all-zeros labels.
    """
    real_loss = cross_entropy(tf.ones_like(real_output), real_output)
    fake_loss = cross_entropy(tf.zeros_like(fake_output), fake_output)
    return real_loss + fake_loss
# In [ ]:
import time
import matplotlib.pyplot as plt
from IPython.display import clear_output
@tf.function
def train_step(images):
    """Run one adversarial update of both the generator and the discriminator.

    Args:
        images: A batch of real training images. Presumably shaped
            ``[batch, 28, 28, 1]`` and scaled to [-1, 1]; TODO confirm
            against the dataset pipeline.

    NOTE(review): this relies on module-level ``noise_dim``,
    ``generator_optimizer`` and ``discriminator_optimizer``, which are not
    defined in the visible code. It also uses the ``generator`` /
    ``discriminator`` models and the ``generator_loss`` /
    ``discriminator_loss`` helpers.
    """
    # Take the noise batch size from the real batch rather than a fixed
    # BATCH_SIZE, so a final partial batch still gets one fake per real image.
    batch_size = tf.shape(images)[0]
    noise = tf.random.normal([batch_size, noise_dim])

    # Record both forward passes so each model's loss can be differentiated
    # with respect to its own weights. training=True makes BatchNormalization
    # use batch statistics and enables Dropout.
    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
        generated_images = generator(noise, training=True)

        real_output = discriminator(images, training=True)
        fake_output = discriminator(generated_images, training=True)

        gen_loss = generator_loss(fake_output)
        disc_loss = discriminator_loss(real_output, fake_output)

    gradients_of_generator = gen_tape.gradient(
        gen_loss, generator.trainable_variables)
    gradients_of_discriminator = disc_tape.gradient(
        disc_loss, discriminator.trainable_variables)

    generator_optimizer.apply_gradients(
        zip(gradients_of_generator, generator.trainable_variables))
    discriminator_optimizer.apply_gradients(
        zip(gradients_of_discriminator, discriminator.trainable_variables))
# Plot generated samples in a 4x4 grid, save the figure as
# image_at_epoch_<epoch>.png, then display it.
# NOTE(review): `predictions` and `epoch` are not defined anywhere in the
# visible code, and indentation was lost in this export. This looks like the
# body of a per-epoch image-saving helper whose `def` line is missing, along
# with the lines that compute `predictions` and create the figure. TODO:
# restore them. Unless those names are defined elsewhere in the file, running
# this at module level raises NameError.
# NOTE(review): subplot(4, 4, ...) has only 16 cells, so this presumably
# expects at most 16 samples in `predictions`; confirm against the caller.
for i in range(predictions.shape[0]):
plt.subplot(4, 4, i + 1)
# x * 127.5 + 127.5 maps the [-1, 1] normalized range back to [0, 255];
# channel 0 is shown as a grayscale image.
plt.imshow(predictions[i, :, :, 0] * 127.5 + 127.5, cmap='gray')
plt.axis('off')
# Zero-padded epoch number (e.g. 0007) keeps the saved files sorted in order.
plt.savefig(f'image_at_epoch_{epoch:04d}.png')
plt.show()