from typing import *
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from scipy.special import softmax
from keras.datasets import mnist
(X_train_mnist, y_train_mnist), (X_test_mnist, y_test_mnist) = mnist.load_data()
# Normalization
X_train_mnist.shape
mnist_digits = np.concatenate([X_train_mnist, X_test_mnist], axis=0)
mnist_digits = np.expand_dims(mnist_digits, -1).astype("float32") / 255
print("MIN:",mnist_digits.min(),"MAX:",mnist_digits.max())
MIN: 0.0 MAX: 1.0
import tensorflow as tf
from tensorflow.keras import layers
from tensorflow.keras.layers import Dense, Flatten
from tensorflow.keras import Model
class Sampling(layers.Layer):
    """Uses (z_mean, z_log_var) to sample z, the vector encoding a digit."""

    def call(self, inputs):
        z_mean, z_log_var = inputs
        batch = tf.shape(z_mean)[0]
        dim = tf.shape(z_mean)[1]
        epsilon = tf.random.normal(shape=(batch, dim))
        return z_mean + tf.exp(0.5 * z_log_var) * epsilon
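# A minimal sanity check of the reparameterization trick implemented by Sampling:
# z = z_mean + exp(0.5 * z_log_var) * epsilon with epsilon ~ N(0, I).
# The values below are hypothetical, just to show the shapes involved.
demo_mean = tf.constant([[0.0, 1.0]])
demo_log_var = tf.constant([[0.0, -2.0]])
demo_z = Sampling()([demo_mean, demo_log_var])
print(demo_z.shape)  # (1, 2): one stochastic sample per row of demo_mean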
class MyModel(Model):
    def __init__(self):
        super(MyModel, self).__init__()
        self.flatten = Flatten()
        self.conv1 = layers.Conv2D(32, 3, activation="relu", strides=2, padding="same")
        self.conv2 = layers.Conv2D(64, 3, activation="relu", strides=2, padding="same")
        self.enc_h1 = layers.Dense(32)  # z_mean head (latent dim = 32)
        self.enc_h2 = layers.Dense(32)  # z_log_var head
        self.sampling = Sampling()
        self.hidden1 = Dense(7 * 7 * 64, activation="relu")
        self.reshape = layers.Reshape((7, 7, 64))
        self.tconv1 = layers.Conv2DTranspose(64, 3, activation="relu", strides=2, padding="same")
        self.tconv2 = layers.Conv2DTranspose(32, 3, activation="relu", strides=2, padding="same")
        self.tconv3 = layers.Conv2DTranspose(1, 3, activation="sigmoid", padding="same")
        self.total_loss_tracker = tf.keras.metrics.Mean(name="total_loss")
        self.reconstruction_loss_tracker = tf.keras.metrics.Mean(
            name="reconstruction_loss")
        self.kl_loss_tracker = tf.keras.metrics.Mean(name="kl_loss")

    def encoder(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.flatten(x)
        z_mean = self.enc_h1(x)
        z_log_var = self.enc_h2(x)
        z = self.sampling([z_mean, z_log_var])
        return z_mean, z_log_var, z

    def decoder(self, x):
        x = self.hidden1(x)
        x = self.reshape(x)
        x = self.tconv1(x)
        x = self.tconv2(x)
        x = self.tconv3(x)
        return x

    @property
    def metrics(self):
        return [
            self.total_loss_tracker,
            self.reconstruction_loss_tracker,
            self.kl_loss_tracker,
        ]

    def train_step(self, data):
        with tf.GradientTape() as tape:
            z_mean, z_log_var, z = self.encoder(data)
            reconstruction = self.decoder(z)
            reconstruction_loss = tf.reduce_mean(
                tf.reduce_sum(
                    tf.keras.losses.binary_crossentropy(data, reconstruction), axis=(1, 2)
                )
            )
            kl_loss = -0.5 * (1 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var))
            kl_loss = tf.reduce_mean(tf.reduce_sum(kl_loss, axis=1))
            total_loss = reconstruction_loss + kl_loss
        grads = tape.gradient(total_loss, self.trainable_weights)
        self.optimizer.apply_gradients(zip(grads, self.trainable_weights))
        self.total_loss_tracker.update_state(total_loss)
        self.reconstruction_loss_tracker.update_state(reconstruction_loss)
        self.kl_loss_tracker.update_state(kl_loss)
        return {
            "loss": self.total_loss_tracker.result(),
            "reconstruction_loss": self.reconstruction_loss_tracker.result(),
            "kl_loss": self.kl_loss_tracker.result(),
        }
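# The kl_loss term in train_step is the closed-form KL divergence between
# N(mu, sigma^2) and N(0, 1) per latent dimension, with sigma^2 = exp(z_log_var):
# KL = -0.5 * (1 + log(sigma^2) - mu^2 - sigma^2). A quick numeric sketch with
# made-up values, checking it against the textbook form 0.5*(mu^2 + sigma^2 - log(sigma^2) - 1):
mu, log_var = 0.7, -0.3
sigma2 = np.exp(log_var)
kl_closed_form = -0.5 * (1 + log_var - mu ** 2 - sigma2)
kl_textbook = 0.5 * (mu ** 2 + sigma2 - log_var - 1)
print(np.isclose(kl_closed_form, kl_textbook))  # True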
model = MyModel()
model.compile(optimizer='adam')
history_adam = model.fit(mnist_digits, epochs=10, batch_size=128)
Epoch 1/10
547/547 [==============================] - 176s 320ms/step - loss: 216.7794 - reconstruction_loss: 143.4430 - kl_loss: 16.9783
Epoch 2/10
547/547 [==============================] - 176s 322ms/step - loss: 110.8367 - reconstruction_loss: 83.0104 - kl_loss: 26.0823
Epoch 3/10
547/547 [==============================] - 175s 320ms/step - loss: 105.7279 - reconstruction_loss: 78.8369 - kl_loss: 26.4659
Epoch 4/10
547/547 [==============================] - 175s 320ms/step - loss: 103.7990 - reconstruction_loss: 77.1852 - kl_loss: 26.4833
Epoch 5/10
151/547 [=======>......................] - ETA: 2:05 - loss: 102.7431 - reconstruction_loss: 76.2704 - kl_loss: 26.4031
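# Optional sketch: plot how the tracked losses evolve per epoch, assuming the
# fit above ran to completion and history_adam holds the returned History object.
plt.figure(figsize=(8, 4))
for key in ["loss", "reconstruction_loss", "kl_loss"]:
    plt.plot(history_adam.history[key], label=key)
plt.xlabel("epoch")
plt.ylabel("loss")
plt.legend()
plt.show()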
#https://keras.io/examples/generative/vae/
import matplotlib.pyplot as plt
def plot_latent_space(vae, n=30, figsize=15):
    # display an n*n 2D manifold of digits
    # NOTE: this assumes a 2-dimensional latent space (z_sample has two
    # coordinates); to use it with MyModel above, set the encoder Dense
    # heads (enc_h1/enc_h2) to 2 units instead of 32.
    digit_size = 28
    scale = 1.0
    figure = np.zeros((digit_size * n, digit_size * n))
    # linearly spaced coordinates corresponding to the 2D plot
    # of digit classes in the latent space
    grid_x = np.linspace(-scale, scale, n)
    grid_y = np.linspace(-scale, scale, n)[::-1]
    for i, yi in enumerate(grid_y):
        for j, xi in enumerate(grid_x):
            z_sample = np.array([[xi, yi]], dtype="float32")
            x_decoded = vae.decoder(z_sample).numpy()
            digit = x_decoded[0].reshape(digit_size, digit_size)
            figure[
                i * digit_size : (i + 1) * digit_size,
                j * digit_size : (j + 1) * digit_size,
            ] = digit

    plt.figure(figsize=(figsize, figsize))
    start_range = digit_size // 2
    end_range = n * digit_size + start_range
    pixel_range = np.arange(start_range, end_range, digit_size)
    sample_range_x = np.round(grid_x, 1)
    sample_range_y = np.round(grid_y, 1)
    plt.xticks(pixel_range, sample_range_x)
    plt.yticks(pixel_range, sample_range_y)
    plt.xlabel("z[0]")
    plt.ylabel("z[1]")
    plt.imshow(figure, cmap="Greys_r")
    plt.show()
plot_latent_space(model)
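# A complementary view (a sketch, not part of the code above): scatter the first
# two coordinates of z_mean for the test digits, coloured by label. With the
# 32-dim latent space used above this only shows a 2-D slice of the space.
z_mean_test, _, _ = model.encoder(np.expand_dims(X_test_mnist, -1).astype("float32") / 255)
z_mean_test = z_mean_test.numpy()
plt.figure(figsize=(8, 6))
plt.scatter(z_mean_test[:, 0], z_mean_test[:, 1], c=y_test_mnist, cmap="tab10", s=2)
plt.colorbar()
plt.xlabel("z[0]")
plt.ylabel("z[1]")
plt.show()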
# Toy scaled dot-product attention: Attention(Q, K, V) = softmax(Q K^T / sqrt(d_k)) V.
# Here Q, K, V are treated as three tokens with scalar features (d_k = 1),
# so the score matrix Q K^T is 3x3 and the softmax is taken row-wise.
Q = np.array([[0.8], [0.5], [0.21]])
K = np.array([[0.1], [0.62], [0.73]])
V = np.array([[0.6], [0.2], [0.9]])
dk = K.shape[-1]
attention_weights = softmax(Q @ K.T / np.sqrt(dk), axis=-1)  # (3, 3)
attention_output = attention_weights @ V                     # (3, 1)
attention_output
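# The same formula generalizes to multi-dimensional queries/keys/values.
# A sketch with 3 tokens and d_k = 4; all numbers are random and purely illustrative.
rng = np.random.default_rng(0)
Q4 = rng.random((3, 4))
K4 = rng.random((3, 4))
V4 = rng.random((3, 4))
weights4 = softmax(Q4 @ K4.T / np.sqrt(K4.shape[-1]), axis=-1)  # (3, 3) row-stochastic
output4 = weights4 @ V4                                         # (3, 4)
print(weights4.sum(axis=-1))  # each row of attention weights sums to 1
print(output4.shape)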