# imports and dataset
import numpy as np
import scipy
import matplotlib.pyplot as plt
from sklearn.neural_network import BernoulliRBM
from sklearn.model_selection import train_test_split
from tensorflow.keras.datasets import mnist
# Load MNIST and split into train/test images and labels.
# BUG FIX: the original unpacking was `(X_train, X_test), (X_test, Y_test)`,
# which first bound the *training labels* to X_test and only ended up correct
# because tuple assignment proceeds left to right — and it silently discarded
# the training labels entirely. Unpack properly and keep Y_train.
(X_train, Y_train), (X_test, Y_test) = mnist.load_data()
# Flatten each 28x28 image into a 784-vector and scale pixels to [0, 1].
X_train = X_train.reshape(-1, 784)/255
X_test = X_test.reshape(-1, 784)/255
# Binarize: BernoulliRBM models binary visible units, so threshold at 0.2.
X_train = np.where(X_train > 0.2, 1, 0)
X_test = np.where(X_test > 0.2, 1, 0)
# Hold out 20% of the training set for validation; random_state=0 for
# reproducibility, consistent with the RBM's random_state below.
X_train, X_val = train_test_split(
    X_train, test_size=1/5, random_state=0)
# Preview the first ten binarized training digits on a 2x5 grid,
# with ticks and grid lines suppressed on every panel.
plt.figure(figsize=(8, 4))
for position, digit in enumerate(X_train[:10], start=1):
    axis = plt.subplot(2, 5, position)
    axis.set_xticks([])
    axis.set_yticks([])
    axis.grid(False)
    axis.imshow(digit.reshape(28, 28), cmap='Greys')
plt.tight_layout()
# Train a Bernoulli RBM with 80 hidden units on the binarized digits.
rbm = BernoulliRBM(
    n_components=80,
    learning_rate=0.01,
    batch_size=20,
    n_iter=60,
    verbose=True,
    random_state=0,
)
rbm.fit(X_train)
# Pseudo-likelihood is a cheap proxy for the intractable true likelihood;
# comparing train vs. validation gives a rough over-fitting check.
train_score = rbm.score_samples(X_train).mean()
val_score = rbm.score_samples(X_val).mean()
print("Training set Pseudo-Likelihood =", train_score)
print("Validation set Pseudo-Likelihood =", val_score)
# Corrupt a fixed test digit by flipping 50 random pixels.
X_pick = X_test[23]
# BUG FIX: np.random.choice defaults to replace=True, so duplicate draws
# meant fewer than 50 *distinct* pixels were actually flipped. replace=False
# guarantees exactly 50 unique pixel indices.
pick = np.random.choice(28 * 28, 50, replace=False)
x_noisy = np.copy(X_pick)
x_noisy[pick] = ((X_pick[pick] + 1) % 2)  # flip 0 <-> 1 at the chosen pixels
# De-noise the corrupted digit: run a Gibbs chain from it and accumulate
# the successive samples with exponentially decaying weights, so earlier
# (closer to the input) samples count more than later ones.
k_iter = 12  # number of Gibbs sampling iterations
alpha = 0.9  # decay factor for the sample weights
b = rbm.gibbs(x_noisy)
x_final = b.astype(float)  # weight 1 for the first sample
for step in range(1, k_iter + 1):
    b = rbm.gibbs(b)  # one Gibbs step: sample hiddens, then visibles
    x_final += (alpha ** step) * b.astype(float)
# Binarize the weighted sum again: pixels above half the maximum become 1.
cutoff = 0.5 * np.max(x_final)
x_final = np.where(x_final > cutoff, 1, 0)
# Show the original, corrupted, and de-noised digits side by side.
fig, ax = plt.subplots(1, 3, figsize=(10, 3))
panels = [
    (X_pick, 'Original'),
    (x_noisy, 'Corrupted'),
    (x_final, 'De-Noised'),
]
for axis, (image, title) in zip(ax, panels):
    axis.imshow(image.reshape(28, 28), cmap='Greys')
    axis.set_title(title)
    axis.set_xticks([])
    axis.set_yticks([])
plt.show()