# Training a GAN on MNIST: Load the Dataset, Normalize the Images to [-1, 1], Add a Channel Dimension, and Set Buffer and Batch Size


In [ ]:

import tensorflow as tf
from tensorflow.keras.datasets import mnist
import numpy as np

# Load the dataset
(x_train, _), (_, _) = mnist.load_data()

# Normalize the images to [-1, 1] for better performance of the GAN
x_train = (x_train.astype('float32') - 127.5) / 127.5
x_train = np.expand_dims(x_train, axis=-1)  # Add channel dimension

# Set buffer and batch size
BUFFER_SIZE = 60000
BATCH_SIZE = 256

# Create batches and shuffle the dataset
train_dataset = tf.data.Dataset.from_tensor_slices(x_train).shuffle(BUFFER_SIZE).batch(BATCH_SIZE)
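
As a quick sanity check (not part of the original notebook), you can confirm the normalization range and batch shape; this cell is purely illustrative:

In [ ]:
# Illustrative check: pixel values should lie in [-1, 1]
print(x_train.min(), x_train.max())  # expected: -1.0 1.0

# Each batch should have shape (BATCH_SIZE, 28, 28, 1)
for batch in train_dataset.take(1):
    print(batch.shape)  # expected: (256, 28, 28, 1)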

In [ ]:
from tensorflow.keras import layers

def make_generator_model():
    model = tf.keras.Sequential()
    # Project the 100-dim noise vector and reshape it to a 7x7x256 feature map
    model.add(layers.Dense(7*7*256, use_bias=False, input_shape=(100,)))
    model.add(layers.BatchNormalization())
    model.add(layers.LeakyReLU())

    model.add(layers.Reshape((7, 7, 256)))

    # Upsample 7x7 -> 7x7 -> 14x14 -> 28x28 with transposed convolutions
    model.add(layers.Conv2DTranspose(128, (5, 5), strides=(1, 1), padding='same', use_bias=False))
    model.add(layers.BatchNormalization())
    model.add(layers.LeakyReLU())

    model.add(layers.Conv2DTranspose(64, (5, 5), strides=(2, 2), padding='same', use_bias=False))
    model.add(layers.BatchNormalization())
    model.add(layers.LeakyReLU())

    # tanh keeps the output in [-1, 1], matching the normalized training images
    model.add(layers.Conv2DTranspose(1, (5, 5), strides=(2, 2), padding='same', use_bias=False, activation='tanh'))

    return model

generator = make_generator_model()
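
As a minimal smoke test (not part of the original notebook), the untrained generator should already map a 100-dim noise vector to a 28x28x1 image; the output will look like static until training:

In [ ]:
# Illustrative: generate one image from the untrained generator
noise = tf.random.normal([1, 100])
generated_image = generator(noise, training=False)
print(generated_image.shape)  # expected: (1, 28, 28, 1)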

In [ ]:
def make_discriminator_model():
    model = tf.keras.Sequential()
    # Downsample 28x28 -> 14x14 -> 7x7 with strided convolutions
    model.add(layers.Conv2D(64, (5, 5), strides=(2, 2), padding='same', input_shape=[28, 28, 1]))
    model.add(layers.LeakyReLU())
    model.add(layers.Dropout(0.3))

    model.add(layers.Conv2D(128, (5, 5), strides=(2, 2), padding='same'))
    model.add(layers.LeakyReLU())
    model.add(layers.Dropout(0.3))

    # Single unbounded logit: real vs. fake
    model.add(layers.Flatten())
    model.add(layers.Dense(1))

    return model

discriminator = make_discriminator_model()
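
Similarly, a small smoke test (reusing the illustrative generated_image from the sketch above, so run that cell first): the discriminator emits one unbounded logit per image, where positive values lean "real" and negative lean "fake":

In [ ]:
# Illustrative: score the untrained generator's output
decision = discriminator(generated_image)
print(decision)  # a single logit; its sign is meaningless before training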
In [ ]:
# Loss functions for both generator and discriminator
cross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=True)

def discriminator_loss(real_output, fake_output):
    real_loss = cross_entropy(tf.ones_like(real_output), real_output)
    fake_loss = cross_entropy(tf.zeros_like(fake_output), fake_output)
    return real_loss + fake_loss

def generator_loss(fake_output):
    return cross_entropy(tf.ones_like(fake_output), fake_output)

# Optimizers for both generator and discriminator
generator_optimizer = tf.keras.optimizers.Adam(1e-4)
discriminator_optimizer = tf.keras.optimizers.Adam(1e-4)

# Checkpoint to save models
import os
checkpoint_dir = './training_checkpoints'
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")
checkpoint = tf.train.Checkpoint(generator_optimizer=generator_optimizer,
                                 discriminator_optimizer=discriminator_optimizer,
                                 generator=generator,
                                 discriminator=discriminator)
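
The notebook only ever saves checkpoints; if you later want to resume training or reload the most recent weights, a minimal sketch (not in the original) is:

In [ ]:
# Illustrative: restore the latest checkpoint, if any exist
latest = tf.train.latest_checkpoint(checkpoint_dir)
if latest:
    checkpoint.restore(latest)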

In [ ]:
import time
import matplotlib.pyplot as plt
from IPython.display import clear_output

# Seed for generating random noise vectors
noise_dim = 100
num_examples_to_generate = 16
seed = tf.random.normal([num_examples_to_generate, noise_dim])

@tf.function
def train_step(images):
    noise = tf.random.normal([BATCH_SIZE, noise_dim])

    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
        generated_images = generator(noise, training=True)

        real_output = discriminator(images, training=True)
        fake_output = discriminator(generated_images, training=True)

        gen_loss = generator_loss(fake_output)
        disc_loss = discriminator_loss(real_output, fake_output)

    gradients_of_generator = gen_tape.gradient(gen_loss, generator.trainable_variables)
    gradients_of_discriminator = disc_tape.gradient(disc_loss, discriminator.trainable_variables)

    generator_optimizer.apply_gradients(zip(gradients_of_generator, generator.trainable_variables))
    discriminator_optimizer.apply_gradients(zip(gradients_of_discriminator, discriminator.trainable_variables))

def train(dataset, epochs):
    for epoch in range(epochs):
        start = time.time()

        for image_batch in dataset:
            train_step(image_batch)

        # Produce images for the GIF
        clear_output(wait=True)
        generate_and_save_images(generator, epoch + 1, seed)

        # Save the model every 15 epochs
        if (epoch + 1) % 15 == 0:
            checkpoint.save(file_prefix=checkpoint_prefix)

        print(f'Time for epoch {epoch + 1} is {time.time() - start:.2f} sec')

    # Generate after the final epoch
    clear_output(wait=True)
    generate_and_save_images(generator, epochs, seed)

def generate_and_save_images(model, epoch, test_input):
    # training=False so BatchNormalization runs in inference mode
    predictions = model(test_input, training=False)

    fig = plt.figure(figsize=(4, 4))

    for i in range(predictions.shape[0]):
        plt.subplot(4, 4, i + 1)
        # Rescale from [-1, 1] back to [0, 255] for display
        plt.imshow(predictions[i, :, :, 0] * 127.5 + 127.5, cmap='gray')
        plt.axis('off')

    plt.savefig(f'image_at_epoch_{epoch:04d}.png')
    plt.show()

# Set the number of epochs and train the model
EPOCHS = 50
train(train_dataset, EPOCHS)
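
The training loop's comment mentions producing images for a GIF. One way to stitch the saved per-epoch PNGs into an animation is sketched below; it assumes the third-party imageio package is installed and is not part of the original notebook:

In [ ]:
# Illustrative: combine the saved PNGs into an animated GIF
import glob
import imageio

with imageio.get_writer('dcgan.gif', mode='I') as writer:
    for filename in sorted(glob.glob('image_at_epoch_*.png')):
        writer.append_data(imageio.imread(filename))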
