Autoencoder Slides
Autoencoder Slides
Autoencoders
In [39]:
def plotn(n, x):
    """Show the first n images of x side by side.

    Flat vectors of size 784 or 196 are reshaped back to 28x28 / 14x14;
    anything else is passed to imshow unchanged.
    """
    fig, ax = plt.subplots(1, n)
    # plt.subplots(1, 1) returns a bare Axes, not an array — normalize
    # so ax[i] works for n == 1 as well.
    axes = ax if n > 1 else [ax]
    for i, z in enumerate(x[0:n]):
        if z.size == 28 * 28:
            img = z.reshape(28, 28)
        elif z.size == 14 * 14:
            img = z.reshape(14, 14)
        else:
            img = z
        axes[i].imshow(img)
    plt.show()

plotn(5, x_train)
In [40]:
def plotidx(indices, x):
    """Show the images of x at the given indices side by side.

    NOTE(review): x[indices] uses fancy indexing — assumes x is a
    numpy array, not a plain list. TODO confirm against callers.
    """
    fig, ax = plt.subplots(1, len(indices))
    # plt.subplots(1, 1) returns a bare Axes, not an array — normalize
    # so ax[i] works when only one index is requested.
    axes = ax if len(indices) > 1 else [ax]
    for i, z in enumerate(x[indices]):
        if z.size == 28 * 28:
            img = z.reshape(28, 28)
        elif z.size == 14 * 14:
            img = z.reshape(14, 14)
        else:
            img = z
        axes[i].imshow(img)
    plt.show()

plotidx([5, 6, 7, 8, 9], x_train)
Example: Autoencoders with MNIST
In [41]:
from tensorflow.keras.layers import Input, Dense, Conv2D, MaxPooling2D, UpSampling2D, Lambda
from tensorflow.keras.models import Model
from tensorflow.keras.losses import binary_crossentropy, mse
In [43]:
# Symbolic Keras inputs: a 28x28 grayscale image for the encoder side,
# and a 4x4x8 tensor for feeding the decoder its latent representation.
input_img = Input(shape=(28,28,1))
input_rep = Input(shape=(4,4,8))
In [12]:
# Train the autoencoder to reconstruct its own input (target == input).
# NOTE(review): `autoencoder` is assembled in a cell not shown here.
autoencoder.fit(x_train, x_train,
epochs=25,
batch_size=128,
shuffle=True,
validation_data=(x_test, x_test))
Epoch 1/25
469/469 [==============================] - 7s
16ms/step - loss: 0.1455 - val_loss: 0.1335
Epoch 2/25
469/469 [==============================] - 7s
16ms/step - loss: 0.1290 - val_loss: 0.1229
Epoch 3/25
469/469 [==============================] - 7s
15ms/step - loss: 0.1205 - val_loss: 0.1163
Epoch 4/25
469/469 [==============================] - 7s
15ms/step - loss: 0.1148 - val_loss: 0.1113
Epoch 5/25
469/469 [==============================] - 7s
15ms/step - loss: 0.1107 - val_loss: 0.1083
Epoch 6/25
469/469 [==============================] - 7s
15ms/step - loss: 0.1080 - val_loss: 0.1057
Epoch 7/25
469/469 [==============================] - 7s
15ms/step - loss: 0.1061 - val_loss: 0.1042
Epoch 8/25
469/469 [==============================] - 7s
15ms/step - loss: 0.1046 - val_loss: 0.1029
Epoch 9/25
469/469 [==============================] - 7s
15ms/step - loss: 0.1035 - val_loss: 0.1019
Epoch 10/25
469/469 [==============================] - 7s
15ms/step - loss: 0.1024 - val_loss: 0.1011
Epoch 11/25
469/469 [==============================] - 7s
15ms/step - loss: 0.1014 - val_loss: 0.0999
Epoch 12/25
469/469 [==============================] - 7s
15ms/step - loss: 0.1006 - val_loss: 0.0993
Epoch 13/25
469/469 [==============================] - 7s
15ms/step - loss: 0.0999 - val_loss: 0.0986
Epoch 14/25
469/469 [==============================] - 7s
15ms/step - loss: 0.0993 - val_loss: 0.0980
Epoch 15/25
469/469 [==============================] - 7s
15ms/step - loss: 0.0987 - val_loss: 0.0976
Epoch 16/25
469/469 [==============================] - 7s
15ms/step - loss: 0.0983 - val_loss: 0.0971
Epoch 17/25
469/469 [==============================] - 7s
15ms/step - loss: 0.0978 - val_loss: 0.0966
Epoch 18/25
469/469 [==============================] - 7s
16ms/step - loss: 0.0975 - val_loss: 0.0963
Epoch 19/25
469/469 [==============================] - 8s
16ms/step - loss: 0.0971 - val_loss: 0.0962
Epoch 20/25
469/469 [==============================] - 8s
17ms/step - loss: 0.0968 - val_loss: 0.0957
Epoch 21/25
469/469 [==============================] - 8s
17ms/step - loss: 0.0965 - val_loss: 0.0954
Epoch 22/25
469/469 [==============================] - 8s
17ms/step - loss: 0.0963 - val_loss: 0.0952
Epoch 23/25
469/469 [==============================] - 8s
17ms/step - loss: 0.0960 - val_loss: 0.0950
Epoch 24/25
469/469 [==============================] - 8s
17ms/step - loss: 0.0957 - val_loss: 0.0944
Epoch 25/25
469/469 [==============================] - 8s
17ms/step - loss: 0.0955 - val_loss: 0.0945
Out[12]:
<keras.src.callbacks.History at 0x1c8c195fa60>
Example: Autoencoders with MNIST
In [13]:
# Reconstruct the first 5 test images and show originals vs. reconstructions.
# NOTE(review): this overwrites `y_test` — if class labels were stored there
# by an earlier cell, they are lost; a name like `x_test_rec` would be safer.
y_test = autoencoder.predict(x_test[0:5])
plotn(5,x_test)
plotn(5,y_test)
In [27]:
# Decode previously computed latent representations back into images.
# `encoded_imgs` and `indices` come from earlier cells not shown here.
res = decoder.predict(encoded_imgs, verbose=False)
plotn(len(indices), res)
Example: Autoencoder with MNIST
# Build noisy copies of the datasets for the denoising autoencoder.
# NOTE(review): `noisify` is defined in a cell not shown here.
x_train_noise = noisify(x_train)
x_test_noise = noisify(x_test)
plotn(5,x_train_noise)
Example: Denoising for MNIST
Epoch 1/25
469/469 [==============================] - 7s
15ms/step - loss: 0.2232 - val_loss: 0.1877
Epoch 2/25
469/469 [==============================] - 7s
16ms/step - loss: 0.1801 - val_loss: 0.1718
Epoch 3/25
469/469 [==============================] - 7s
15ms/step - loss: 0.1674 - val_loss: 0.1612
Epoch 4/25
469/469 [==============================] - 7s
15ms/step - loss: 0.1599 - val_loss: 0.1560
Epoch 5/25
469/469 [==============================] - 7s
16ms/step - loss: 0.1554 - val_loss: 0.1523
Epoch 6/25
469/469 [==============================] - 7s
16ms/step - loss: 0.1521 - val_loss: 0.1495
Epoch 7/25
469/469 [==============================] - 7s
15ms/step - loss: 0.1494 - val_loss: 0.1472
Epoch 8/25
469/469 [==============================] - 7s
16ms/step - loss: 0.1471 - val_loss: 0.1444
Epoch 9/25
469/469 [==============================] - 7s
15ms/step - loss: 0.1452 - val_loss: 0.1431
Epoch 10/25
469/469 [==============================] - 7s
15ms/step - loss: 0.1436 - val_loss: 0.1415
Epoch 11/25
469/469 [==============================] - 7s
15ms/step - loss: 0.1425 - val_loss: 0.1406
Epoch 12/25
469/469 [==============================] - 7s
15ms/step - loss: 0.1413 - val_loss: 0.1393
Epoch 13/25
469/469 [==============================] - 7s
15ms/step - loss: 0.1403 - val_loss: 0.1387
Epoch 14/25
469/469 [==============================] - 7s
15ms/step - loss: 0.1395 - val_loss: 0.1376
Epoch 15/25
469/469 [==============================] - 7s
15ms/step - loss: 0.1387 - val_loss: 0.1369
Epoch 16/25
469/469 [==============================] - 7s
16ms/step - loss: 0.1381 - val_loss: 0.1362
Epoch 17/25
469/469 [==============================] - 8s
17ms/step - loss: 0.1374 - val_loss: 0.1356
Epoch 18/25
469/469 [==============================] - 9s
19ms/step - loss: 0.1369 - val_loss: 0.1356
Epoch 19/25
469/469 [==============================] - 8s
17ms/step - loss: 0.1365 - val_loss: 0.1346
Epoch 20/25
469/469 [==============================] - 8s
17ms/step - loss: 0.1359 - val_loss: 0.1343
Epoch 21/25
469/469 [==============================] - 8s
17ms/step - loss: 0.1356 - val_loss: 0.1342
Epoch 22/25
469/469 [==============================] - 8s
17ms/step - loss: 0.1350 - val_loss: 0.1340
Epoch 23/25
469/469 [==============================] - 8s
17ms/step - loss: 0.1346 - val_loss: 0.1333
Epoch 24/25
469/469 [==============================] - 8s
17ms/step - loss: 0.1342 - val_loss: 0.1328
Epoch 25/25
469/469 [==============================] - 8s
18ms/step - loss: 0.1338 - val_loss: 0.1327
Out[40]:
<keras.src.callbacks.History at 0x1c8e1749a80>
Example: Denoising for MNIST
In [43]:
# Downsample the images with 2x2 average pooling (28x28 -> 14x14) —
# these are the low-resolution inputs for the super-resolution model.
x_train_lr = tf.keras.layers.AveragePooling2D()(x_train).numpy()
x_test_lr = tf.keras.layers.AveragePooling2D()(x_test).numpy()
plotn(5, x_train_lr)
Example: Super-resolution on MNIST
In [45]:
# Inputs for the super-resolution model: a 14x14 low-resolution image
# on the way in, a 4x4x8 latent representation for the decoder half.
input_img = Input(shape=(14, 14, 1))
input_rep = Input(shape=(4,4,8))
Epoch 1/25
469/469 [==============================] - 6s
12ms/step - loss: 0.2228 - val_loss: 0.1549
Epoch 2/25
469/469 [==============================] - 6s
12ms/step - loss: 0.1420 - val_loss: 0.1308
Epoch 3/25
469/469 [==============================] - 6s
12ms/step - loss: 0.1268 - val_loss: 0.1212
Epoch 4/25
469/469 [==============================] - 6s
12ms/step - loss: 0.1197 - val_loss: 0.1163
Epoch 5/25
469/469 [==============================] - 6s
12ms/step - loss: 0.1152 - val_loss: 0.1123
Epoch 6/25
469/469 [==============================] - 6s
12ms/step - loss: 0.1119 - val_loss: 0.1100
Epoch 7/25
469/469 [==============================] - 6s
12ms/step - loss: 0.1094 - val_loss: 0.1068
Epoch 8/25
469/469 [==============================] - 6s
13ms/step - loss: 0.1074 - val_loss: 0.1050
Epoch 9/25
469/469 [==============================] - 6s
12ms/step - loss: 0.1057 - val_loss: 0.1037
Epoch 10/25
469/469 [==============================] - 6s
12ms/step - loss: 0.1043 - val_loss: 0.1025
Epoch 11/25
469/469 [==============================] - 6s
12ms/step - loss: 0.1031 - val_loss: 0.1012
Epoch 12/25
469/469 [==============================] - 6s
12ms/step - loss: 0.1021 - val_loss: 0.1007
Epoch 13/25
469/469 [==============================] - 6s
14ms/step - loss: 0.1012 - val_loss: 0.0996
Epoch 14/25
469/469 [==============================] - 6s
13ms/step - loss: 0.1005 - val_loss: 0.0996
Epoch 15/25
469/469 [==============================] - 6s
13ms/step - loss: 0.0998 - val_loss: 0.0981
Epoch 16/25
469/469 [==============================] - 6s
12ms/step - loss: 0.0991 - val_loss: 0.0977
Epoch 17/25
469/469 [==============================] - 6s
12ms/step - loss: 0.0985 - val_loss: 0.0970
Epoch 18/25
469/469 [==============================] - 6s
12ms/step - loss: 0.0980 - val_loss: 0.0964
Epoch 19/25
469/469 [==============================] - 6s
12ms/step - loss: 0.0975 - val_loss: 0.0958
Epoch 20/25
469/469 [==============================] - 6s
12ms/step - loss: 0.0971 - val_loss: 0.0957
Epoch 21/25
469/469 [==============================] - 6s
12ms/step - loss: 0.0966 - val_loss: 0.0952
Epoch 22/25
469/469 [==============================] - 6s
13ms/step - loss: 0.0963 - val_loss: 0.0948
Epoch 23/25
469/469 [==============================] - 6s
13ms/step - loss: 0.0959 - val_loss: 0.0944
Epoch 24/25
469/469 [==============================] - 6s
13ms/step - loss: 0.0956 - val_loss: 0.0942
Epoch 25/25
469/469 [==============================] - 6s
14ms/step - loss: 0.0953 - val_loss: 0.0943
Out[46]:
<keras.src.callbacks.History at 0x1c8e3fa8fd0>
Example: Super-resolution on MNIST
In [48]:
# Upscale the first 5 low-resolution test images and compare them
# with the inputs.
y_test_lr = autoencoder.predict(x_test_lr[0:5], verbose=False)
plotn(5, x_test_lr)
plotn(5, y_test_lr)
Exercise:
Try to similarly train a super-resolution network on Fashion-MNIST.
Try to train super-resolution networks on CIFAR-10 for 2x and 4x upscaling.
Variational Auto-Encoders (VAE)
Variational Auto-Encoders (VAE)
# VAE encoder head: map a flattened 28x28 image (784 values) to the
# parameters (mean, log-sigma) of a Gaussian over the latent space.
# Eager execution is disabled — presumably because the hand-written
# VAE loss used below relies on TF1-style symbolic tensors; TODO confirm.
tf.compat.v1.disable_eager_execution()
inputs = Input(shape=(784,))
h = Dense(intermediate_dim, activation='relu')(inputs)
z_mean = Dense(latent_dim)(h)
z_log_sigma = Dense(latent_dim)(h)
In [46]:
@tf.function
def sampling(args):
    """Reparameterization trick: draw z = mu + sigma * eps with eps ~ N(0, I).

    args is a pair (z_mean, z_log_sigma); note sigma = exp(log_sigma),
    since the second tensor holds log-sigma (not log-variance).
    """
    mean, log_sigma = args
    batch_size = tf.shape(mean)[0]
    noise = tf.random.normal(shape=(batch_size, latent_dim))
    return mean + tf.exp(log_sigma) * noise

# Wrap the sampler in a Lambda layer so it participates in the Keras graph.
z = Lambda(sampling)([z_mean, z_log_sigma])
Example: VAE on MNIST
In [47]:
# Assemble the encoder, decoder and end-to-end VAE models.
# The encoder exposes [z_mean, z_log_sigma, z]; index [2] below selects
# the sampled latent vector z.
# NOTE(review): `outputs` is assigned twice — the Dense(784) output is
# immediately shadowed on the next line. A `decoder = Model(latent_inputs,
# outputs)` line (and `vae = Model(inputs, outputs)`) appears to be missing
# here, likely dropped by slide extraction; `vae_loss` is also defined
# elsewhere.
encoder = Model(inputs, [z_mean, z_log_sigma, z])
latent_inputs = Input(shape=(latent_dim,))
x = Dense(intermediate_dim, activation='relu')(latent_inputs)
outputs = Dense(784, activation='sigmoid')(x)
outputs = decoder(encoder(inputs)[2])
vae.compile(optimizer='rmsprop', loss=vae_loss)
Example: VAE on MNIST
In [49]:
# Flatten the images to 784-vectors (this VAE is dense, not convolutional)
# and train it to reconstruct its own input.
x_train_flat = x_train.reshape((len(x_train), np.prod(x_train.shape[1:])))
x_test_flat = x_test.reshape((len(x_test), np.prod(x_test.shape[1:])))
vae.fit(x_train_flat, x_train_flat,
shuffle=True,
epochs=25,
batch_size=batch_size,
validation_data=(x_test_flat, x_test_flat))
<keras.src.callbacks.History at 0x1a4fb03d690>
Example: VAE on MNIST
In [50]:
# Reconstruct the first 5 flattened test images through the VAE.
# NOTE(review): overwrites `y_test` again — rename if labels are still needed.
y_test = vae.predict(x_test_flat[0:5])
plotn(5,x_test_flat)
plotn(5,y_test)
In [51]:
# Visualize the 2-D latent space: encoder output [0] is z_mean; color each
# test point by its class label (`y_testclass`, defined in a cell not shown).
x_test_encoded = encoder.predict(x_test_flat)[0]
plt.figure(figsize=(6,6))
plt.scatter(x_test_encoded[:,0], x_test_encoded[:,1], c=y_testclass)
plt.colorbar()
plt.show()
In [52]:
def plotsample(n):
    """Decode an n x n grid of latent points in [-2, 2]^2 and plot the digits.

    All n*n latent vectors are batched into a single decoder.predict call
    instead of one predict call per grid cell (the per-cell version issues
    n^2 separate model invocations, which dominates the runtime).
    """
    dx = np.linspace(-2, 2, n)
    dy = np.linspace(-2, 2, n)
    # Row-major grid: row i holds latent vector (dx[i], dy[j]) at column j,
    # matching the original per-cell prediction order.
    grid = np.array([[xi, xj] for xi in dx for xj in dy])
    decoded = decoder.predict(grid, verbose=False)
    fig, ax = plt.subplots(n, n)
    for i in range(n):
        for j in range(n):
            ax[i, j].imshow(decoded[i * n + j].reshape(28, 28))
            ax[i, j].axis('off')
    plt.show()

plotsample(10)
In [95]:
# Tidy subplot spacing of the most recent figure and render it.
# NOTE(review): `fig` here refers to a figure created in a cell not shown.
fig.tight_layout()
plt.show()
In [ ]: