def train_acgan(epochs, batch_size=128, sample_interval=50):
    # Load the dataset
    (X, y), (_, _) = mnist.load_data()

    # Configure inputs
    X = X.astype(np.float32)
    X = (X - 127.5) / 127.5
    X = np.expand_dims(X, axis=3)
    y = y.reshape(-1, 1)

    # Adversarial ground truths
    valid = np.ones((batch_size, 1))
    fake = np.zeros((batch_size, 1))

    for epoch in range(epochs):
        # Select a random batch of images
        index = np.random.randint(0, X.shape[0], batch_size)
        images = X[index]

        # Sample noise as generator input
        noise = np.random.normal(0, 1, (batch_size, latent_dim))

        # The digit labels that the generator should turn into images
        new_labels = np.random.randint(0, 10, (batch_size, 1))

        # Generate a batch of new images
        gen_images = generator.predict([noise, new_labels])
        image_labels = y[index]

        # Training the discriminator
        disc_loss_real = discriminator.train_on_batch(
            images, [valid, image_labels])
        disc_loss_fake = discriminator.train_on_batch(
            gen_images, [fake, new_labels])
        disc_loss = 0.5 * np.add(disc_loss_real, disc_loss_fake)

        # Training the generator
        gen_loss = combined.train_on_batch(
            [noise, new_labels], [valid, new_labels])
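
        # Note: train_on_batch on this two-output discriminator returns
        # [total_loss, validity_loss, label_loss, validity_acc, label_acc]
        # (assuming it was compiled with an accuracy metric), so index 3 is
        # the real/fake accuracy and index 4 the class-label accuracy.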
        # Print the accuracies
        print("%d [acc.: %.2f%%, op_acc: %.2f%%]" % (
            epoch, 100 * disc_loss[3], 100 * disc_loss[4]))

        # Display generated images every sample_interval epochs
        if epoch % sample_interval == 0:
            display_images()
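
# A minimal usage sketch, not part of the original listing: it assumes that
# numpy (np), mnist from keras.datasets, latent_dim, and the generator,
# discriminator, and combined models, as well as the display_images()
# helper, are all defined earlier. The epoch count is only illustrative.
if __name__ == '__main__':
    train_acgan(epochs=14000, batch_size=128, sample_interval=50)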