import numpy as np
import pandas as pd
import scipy.linalg as linalg
import rdkit.Chem as Chem
from rdkit.Chem import AllChem
import torch
import matplotlib.pyplot as plt
def generate_graph(smile):
    """Convert a SMILES string into a simple molecular graph.

    Parameters
    ----------
    smile : str
        SMILES string for one molecule.

    Returns
    -------
    tuple
        ``(atom_matrix, adj_matrix, n_atoms)`` where ``atom_matrix`` is an
        ``(n_atoms, 10)`` one-hot element-feature matrix, ``adj_matrix`` is
        the adjacency matrix augmented with self-loops, and ``n_atoms`` is
        the atom count after adding explicit hydrogens.

    Raises
    ------
    ValueError
        If RDKit cannot parse the SMILES string.
    """
    def get_atom_hash(atomic_number):
        """Map an atomic number to its one-hot index; unknown elements share index 9."""
        atomic_number_list = [6, 7, 8, 9, 15, 16, 17, 35, 53]
        if atomic_number in atomic_number_list:
            return atomic_number_list.index(atomic_number)
        else:
            return len(atomic_number_list)

    def encode_atom(atom):
        """Generate the one-hot feature vector for an atom in a molecule."""
        atom_matrix = np.zeros(10)
        atom_matrix[get_atom_hash(atom.GetAtomicNum())] = 1
        return atom_matrix

    m = AllChem.MolFromSmiles(smile)
    if m is None:
        # MolFromSmiles returns None on a parse failure; fail loudly here
        # instead of with a confusing AttributeError further down.
        raise ValueError(f"Could not parse SMILES string: {smile!r}")
    m = AllChem.AddHs(m)
    # NOTE(review): charges are stored on the atoms but never read by
    # encode_atom — kept for compatibility with downstream users of m.
    AllChem.ComputeGasteigerCharges(m)
    atom_matrix = np.array([encode_atom(atom) for atom in m.GetAtoms()])
    # Augmented adjacency: self-loops let each atom keep its own features
    # under the graph-convolution operator.
    adj_matrix = AllChem.GetAdjacencyMatrix(m) + np.identity(m.GetNumAtoms())
    return atom_matrix, adj_matrix, m.GetNumAtoms()
# Load the Delaney (ESOL) solubility dataset and sanity-check graph
# generation on its first molecule (amygdalin).
esol_data = pd.read_csv('delaney-processed.csv', sep=',')
first_smiles = esol_data['smiles'][0]
print(first_smiles)
print(generate_graph(first_smiles)[2])
OCC3OC(OCC2OC(OC(C#N)c1ccccc1)C(O)C(O)C2O)C(O)C(O)C3O
59
n_samples = 10
feature_matrices = []
adj_matrices = []
atom_counts = []  # atoms per molecule, for membership/mask construction
for i in range(n_samples):
    feature_matrix, adj_matrix, n_atoms = generate_graph(esol_data['smiles'][i])
    feature_matrices.append(feature_matrix)
    adj_matrices.append(adj_matrix)
    atom_counts.append(n_atoms)
# Row -> molecule index for every atom in the stacked feature matrix.
# Built once with np.repeat; the original re-concatenated the whole array
# on every loop iteration (quadratic in the number of atoms).
molecule_membership = np.repeat(np.arange(n_samples), atom_counts)
big_feature_matrix = np.concatenate(feature_matrices, axis = 0)
print(np.shape(big_feature_matrix))
# Block-diagonal super-graph: molecules stay disconnected from each other.
big_adj_matrix = linalg.block_diag(*adj_matrices)
print(np.shape(big_adj_matrix))
# masks[i] is a 0/1 row selecting the atoms of molecule i (used for pooling).
masks = torch.tensor(np.array([np.where(molecule_membership == i, 1, 0) for i in range(n_samples)])).double()
print(np.shape(masks))
(315, 10)
(315, 315)
torch.Size([10, 315])
# Symmetrically normalised graph operator: D^{-1/2} (A + I) D^{-1/2}.
# The self-loops added in generate_graph guarantee every degree >= 1,
# so the inverse square root is always finite.
inv_sqrt_degree = np.sqrt(1 / np.sum(big_adj_matrix, axis=1))
D = torch.tensor(np.diag(inv_sqrt_degree))
A = torch.tensor(big_adj_matrix)
norm_graph_laplacian = D @ A @ D
class SolNet(torch.nn.Module):
    """Graph convolutional network for molecule-level solubility prediction.

    Each convolution computes ``x -> act(L @ x @ W)`` over the (fixed)
    normalised graph operator; after the convolutions, atom features are
    pooled per molecule via the mask matrix and mapped linearly to the
    output.

    Parameters
    ----------
    laplacian : torch.Tensor
        Normalised (augmented) adjacency operator, shape (n_atoms, n_atoms).
    molecule_masks : torch.Tensor
        0/1 pooling matrix, shape (n_molecules, n_atoms).
    input_size, internal_size, output_size : int
        Feature widths of the input, hidden and output layers.
    num_conv_layers : int
        Number of graph-convolution layers.
    activation_function : callable
        Elementwise nonlinearity applied after each convolution.
    """

    def __init__(self, laplacian, molecule_masks, input_size = 10, internal_size = 10, output_size = 1, num_conv_layers = 3, activation_function=torch.nn.functional.relu):
        super().__init__()
        # Bug fix: num_conv_layers was previously accepted but ignored
        # (exactly three conv layers W0..W2 were hard-coded). Build the
        # requested number of layers; the default of 3 preserves the
        # original architecture.
        widths = [input_size] + [internal_size] * num_conv_layers
        self.conv_weights = torch.nn.ParameterList(
            torch.nn.Parameter(torch.randn((widths[k], widths[k + 1])))
            for k in range(num_conv_layers)
        )
        self.readout_weight = torch.nn.Parameter(torch.randn((internal_size, output_size)))
        self.laplacian = laplacian
        self.molecule_masks = molecule_masks
        self.activation_function = activation_function

    def forward(self, x):
        """Compute per-molecule predictions from atom features x (n_atoms, input_size)."""
        for W in self.conv_weights:
            x = self.activation_function(self.laplacian @ x @ W)
        x = self.molecule_masks @ x  # aggregation of atom-level features
        return x @ self.readout_weight
## Create Targets
# Measured log-solubility (the regression target) for the first n_samples molecules.
targets = torch.tensor(esol_data['measured log solubility in mols per litre'][:n_samples])
# Only train on 2/3 of points
labelled_indices = np.random.choice(np.arange(n_samples), int(2*n_samples/3), replace = False)
# The remaining 1/3 is held out for evaluation.
unlabelled_indices = np.setdiff1d(np.arange(n_samples), labelled_indices)
model = SolNet(norm_graph_laplacian, masks)
model.double()  # match the float64 dtype of the numpy-derived tensors
num_epochs = 200
learning_rate = 1E-2
opt = torch.optim.Adam(model.parameters(), lr=learning_rate)
all_train_loss = []
all_test_loss = []
for epoch in range(num_epochs):
    # Full-batch forward pass over the block-diagonal "super graph";
    # one prediction per molecule.
    pred = model(big_feature_matrix).squeeze()
    train_loss = torch.mean((pred[labelled_indices] - targets[labelled_indices])**2)
    opt.zero_grad()
    train_loss.backward()
    opt.step()
    # Log plain floats: appending the loss tensors themselves (as the
    # original did) keeps every epoch's autograd graph alive in memory.
    all_train_loss.append(train_loss.item())
    with torch.no_grad():
        test_loss = torch.mean((pred[unlabelled_indices] - targets[unlabelled_indices])**2)
    all_test_loss.append(test_loss.item())
plt.plot(all_train_loss, label="Train")
plt.plot(all_test_loss, label="Test")
plt.legend()  # bug fix: the original referenced plt.legend without calling it
print(all_train_loss[-1], all_test_loss[-1])
tensor(2.5780e-07, dtype=torch.float64, grad_fn=<MeanBackward0>) tensor(46.2258, dtype=torch.float64, grad_fn=<MeanBackward0>)
class MPNet(torch.nn.Module):
    """Single-round message-passing network over one molecular graph.

    Each atom receives act([x_i, x_j] @ W0) summed over its neighbours j
    (including itself, via the self-loops in the augmented adjacency),
    updates its features with W1, and the molecule prediction is a linear
    readout of the summed atom features.

    Parameters
    ----------
    n_features : int
        Width of the per-atom feature vectors.
    activation_function : callable
        Elementwise nonlinearity for messages and updates.
    """

    def __init__(self, n_features, activation_function=torch.nn.functional.relu):
        super().__init__()
        self.W0 = torch.nn.Parameter(torch.randn(n_features * 2, n_features))
        self.W1 = torch.nn.Parameter(torch.randn(n_features * 2, n_features))
        self.W2 = torch.nn.Parameter(torch.randn(n_features, 1))
        self.n_features = n_features
        self.activation_function = activation_function

    def forward(self, feats, adj):
        """Predict a scalar (shape (1,)) from atom features and 0/1 adjacency.

        Vectorised replacement for the original per-atom-pair Python loop:
        all n*n pair messages are computed in one batched matmul and then
        masked by the adjacency. Equivalent because adjacency entries are
        exactly 0 or 1.
        """
        adj = torch.as_tensor(adj, dtype=feats.dtype)  # accept numpy or torch
        n = feats.shape[0]
        # Ordered pairs (i, j): receiver features alongside sender features.
        receivers = feats.unsqueeze(1).expand(n, n, self.n_features)
        senders = feats.unsqueeze(0).expand(n, n, self.n_features)
        pair_messages = self.activation_function(
            torch.cat((receivers, senders), dim=-1) @ self.W0)
        # Sum each atom's messages over its neighbourhood.
        messages = (adj.unsqueeze(-1) * pair_messages).sum(dim=1)
        feats = self.activation_function(torch.hstack((feats, messages)) @ self.W1)
        return torch.sum(feats, axis=0) @ self.W2
### Train your model here (You can use a very similar approach to the one used in previous practicals) ###
# Note that training here can be slow if your module is not sufficiently vectorized. It might
# be helpful to reduce n_samples to make sure your architecture is working properly before trying to apply
# to larger data.
model = MPNet(10)
model.double()  # float64 to match the numpy-derived graph tensors
num_epochs = 100
learning_rate = 1E-2
opt = torch.optim.Adam(model.parameters(), lr=learning_rate)
all_train_loss = []  # per-epoch mean train loss
all_test_loss = []   # per-epoch mean held-out loss
# Hoist graph construction out of the training loop: the original called
# generate_graph twice per molecule on every epoch, so RDKit parsing —
# not the model — dominated the runtime.
graphs = {
    idx: tuple(torch.tensor(part) for part in generate_graph(esol_data['smiles'][idx])[:2])
    for idx in np.concatenate((labelled_indices, unlabelled_indices))
}
for epoch in range(num_epochs):
    print(epoch)
    epoch_loss = 0
    for i in labelled_indices:
        feats, adj = graphs[i]
        pred = model(feats, adj)
        loss = torch.mean((pred - targets[i])**2)
        epoch_loss += loss
    opt.zero_grad()
    epoch_loss.backward()
    opt.step()
    # Log plain floats so the history lists don't retain autograd graphs.
    all_train_loss.append(epoch_loss.item() / len(labelled_indices))
    print(epoch_loss.item() / len(labelled_indices))
    epoch_test_loss = 0
    with torch.no_grad():
        for j in unlabelled_indices:
            feats, adj = graphs[j]
            pred = model(feats, adj)
            epoch_test_loss += torch.mean((pred - targets[j])**2)
    all_test_loss.append(epoch_test_loss.item() / len(unlabelled_indices))
    print(epoch_test_loss.item() / len(unlabelled_indices))
# Plot both loss curves and report the final epoch's values.
for history, curve_label in ((all_train_loss, "Train"), (all_test_loss, "Test")):
    plt.plot(history, label=curve_label)
plt.legend()
print(all_train_loss[-1], all_test_loss[-1])
0
tensor(18933.4322, dtype=torch.float64, grad_fn=<DivBackward0>)
tensor(23638.4051, dtype=torch.float64)
1
tensor(9920.8200, dtype=torch.float64, grad_fn=<DivBackward0>)
tensor(10711.3745, dtype=torch.float64)
2
tensor(3820.1634, dtype=torch.float64, grad_fn=<DivBackward0>)
tensor(3056.6433, dtype=torch.float64)
3
tensor(728.8314, dtype=torch.float64, grad_fn=<DivBackward0>)
tensor(197.6520, dtype=torch.float64)
4
tensor(204.8037, dtype=torch.float64, grad_fn=<DivBackward0>)
tensor(781.3509, dtype=torch.float64)
5
tensor(1412.3053, dtype=torch.float64, grad_fn=<DivBackward0>)
tensor(2653.4424, dtype=torch.float64)
6
tensor(3041.7109, dtype=torch.float64, grad_fn=<DivBackward0>)
tensor(3957.0769, dtype=torch.float64)
7
tensor(4036.3742, dtype=torch.float64, grad_fn=<DivBackward0>)
tensor(4025.2443, dtype=torch.float64)
8
tensor(4071.8748, dtype=torch.float64, grad_fn=<DivBackward0>)
tensor(3104.9462, dtype=torch.float64)
9
tensor(3357.0449, dtype=torch.float64, grad_fn=<DivBackward0>)
tensor(1792.5787, dtype=torch.float64)
10
tensor(2293.4148, dtype=torch.float64, grad_fn=<DivBackward0>)
tensor(670.0776, dtype=torch.float64)
11
tensor(1272.0062, dtype=torch.float64, grad_fn=<DivBackward0>)
tensor(109.3485, dtype=torch.float64)
12
tensor(515.2457, dtype=torch.float64, grad_fn=<DivBackward0>)
tensor(235.5336, dtype=torch.float64)
13
tensor(134.1525, dtype=torch.float64, grad_fn=<DivBackward0>)
tensor(928.7951, dtype=torch.float64)
14
tensor(112.0001, dtype=torch.float64, grad_fn=<DivBackward0>)
tensor(1946.2232, dtype=torch.float64)
15
tensor(333.4745, dtype=torch.float64, grad_fn=<DivBackward0>)
tensor(3015.9923, dtype=torch.float64)
16
tensor(660.6422, dtype=torch.float64, grad_fn=<DivBackward0>)
tensor(3877.2021, dtype=torch.float64)
17
tensor(959.6258, dtype=torch.float64, grad_fn=<DivBackward0>)
tensor(4352.1483, dtype=torch.float64)
18
tensor(1134.7690, dtype=torch.float64, grad_fn=<DivBackward0>)
tensor(4375.0634, dtype=torch.float64)
19
tensor(1145.5468, dtype=torch.float64, grad_fn=<DivBackward0>)
tensor(3986.9178, dtype=torch.float64)
20
tensor(1004.8428, dtype=torch.float64, grad_fn=<DivBackward0>)
tensor(3305.3100, dtype=torch.float64)
21
tensor(764.0931, dtype=torch.float64, grad_fn=<DivBackward0>)
tensor(2483.8169, dtype=torch.float64)
22
tensor(491.9625, dtype=torch.float64, grad_fn=<DivBackward0>)
tensor(1673.0082, dtype=torch.float64)
23
tensor(254.3924, dtype=torch.float64, grad_fn=<DivBackward0>)
tensor(989.6831, dtype=torch.float64)
24
tensor(98.4088, dtype=torch.float64, grad_fn=<DivBackward0>)
tensor(498.5373, dtype=torch.float64)
25
tensor(42.9029, dtype=torch.float64, grad_fn=<DivBackward0>)
tensor(208.2595, dtype=torch.float64)
26
tensor(77.4178, dtype=torch.float64, grad_fn=<DivBackward0>)
tensor(81.7358, dtype=torch.float64)
27
tensor(168.7565, dtype=torch.float64, grad_fn=<DivBackward0>)
tensor(56.7906, dtype=torch.float64)
28
tensor(273.3267, dtype=torch.float64, grad_fn=<DivBackward0>)
tensor(70.6458, dtype=torch.float64)
29
tensor(351.1928, dtype=torch.float64, grad_fn=<DivBackward0>)
tensor(80.0270, dtype=torch.float64)
30
tensor(377.4589, dtype=torch.float64, grad_fn=<DivBackward0>)
tensor(70.8689, dtype=torch.float64)
31
tensor(347.3225, dtype=torch.float64, grad_fn=<DivBackward0>)
tensor(56.0181, dtype=torch.float64)
32
tensor(274.1980, dtype=torch.float64, grad_fn=<DivBackward0>)
tensor(63.7998, dtype=torch.float64)
33
tensor(182.5879, dtype=torch.float64, grad_fn=<DivBackward0>)
tensor(122.8731, dtype=torch.float64)
34
tensor(98.7992, dtype=torch.float64, grad_fn=<DivBackward0>)
tensor(248.9609, dtype=torch.float64)
35
tensor(42.7038, dtype=torch.float64, grad_fn=<DivBackward0>)
tensor(437.4782, dtype=torch.float64)
36
tensor(22.7839, dtype=torch.float64, grad_fn=<DivBackward0>)
tensor(663.6899, dtype=torch.float64)
37
tensor(35.3935, dtype=torch.float64, grad_fn=<DivBackward0>)
tensor(889.5531, dtype=torch.float64)
38
tensor(67.8090, dtype=torch.float64, grad_fn=<DivBackward0>)
tensor(1074.4538, dtype=torch.float64)
39
tensor(103.5877, dtype=torch.float64, grad_fn=<DivBackward0>)
tensor(1186.1346, dtype=torch.float64)
40
tensor(128.2485, dtype=torch.float64, grad_fn=<DivBackward0>)
tensor(1208.4582, dtype=torch.float64)
41
tensor(133.4613, dtype=torch.float64, grad_fn=<DivBackward0>)
tensor(1144.0456, dtype=torch.float64)
42
tensor(118.6598, dtype=torch.float64, grad_fn=<DivBackward0>)
tensor(1011.6538, dtype=torch.float64)
43
tensor(89.9702, dtype=torch.float64, grad_fn=<DivBackward0>)
tensor(839.7407, dtype=torch.float64)
44
tensor(57.1818, dtype=torch.float64, grad_fn=<DivBackward0>)
tensor(658.5559, dtype=torch.float64)
45
tensor(30.0240, dtype=torch.float64, grad_fn=<DivBackward0>)
tensor(493.1534, dtype=torch.float64)
46
tensor(14.9676, dtype=torch.float64, grad_fn=<DivBackward0>)
tensor(359.1912, dtype=torch.float64)
47
tensor(13.5525, dtype=torch.float64, grad_fn=<DivBackward0>)
tensor(262.1543, dtype=torch.float64)
48
tensor(22.6027, dtype=torch.float64, grad_fn=<DivBackward0>)
tensor(199.6528, dtype=torch.float64)
49
tensor(35.9808, dtype=torch.float64, grad_fn=<DivBackward0>)
tensor(165.4151, dtype=torch.float64)
50
tensor(47.1433, dtype=torch.float64, grad_fn=<DivBackward0>)
tensor(153.2302, dtype=torch.float64)
51
tensor(51.4864, dtype=torch.float64, grad_fn=<DivBackward0>)
tensor(159.3886, dtype=torch.float64)
52
tensor(47.6688, dtype=torch.float64, grad_fn=<DivBackward0>)
tensor(182.9792, dtype=torch.float64)
53
tensor(37.5703, dtype=torch.float64, grad_fn=<DivBackward0>)
tensor(224.3592, dtype=torch.float64)
54
tensor(25.0971, dtype=torch.float64, grad_fn=<DivBackward0>)
tensor(282.7964, dtype=torch.float64)
55
tensor(14.4636, dtype=torch.float64, grad_fn=<DivBackward0>)
tensor(354.5874, dtype=torch.float64)
56
tensor(8.4934, dtype=torch.float64, grad_fn=<DivBackward0>)
tensor(432.4050, dtype=torch.float64)
57
tensor(7.9294, dtype=torch.float64, grad_fn=<DivBackward0>)
tensor(506.2545, dtype=torch.float64)
58
tensor(11.4532, dtype=torch.float64, grad_fn=<DivBackward0>)
tensor(565.6372, dtype=torch.float64)
59
tensor(16.5037, dtype=torch.float64, grad_fn=<DivBackward0>)
tensor(602.0992, dtype=torch.float64)
60
tensor(20.4334, dtype=torch.float64, grad_fn=<DivBackward0>)
tensor(611.2716, dtype=torch.float64)
61
tensor(21.5109, dtype=torch.float64, grad_fn=<DivBackward0>)
tensor(593.6971, dtype=torch.float64)
62
tensor(19.3957, dtype=torch.float64, grad_fn=<DivBackward0>)
tensor(554.3540, dtype=torch.float64)
63
tensor(15.0610, dtype=torch.float64, grad_fn=<DivBackward0>)
tensor(501.1021, dtype=torch.float64)
64
tensor(10.2256, dtype=torch.float64, grad_fn=<DivBackward0>)
tensor(442.6537, dtype=torch.float64)
65
tensor(6.5365, dtype=torch.float64, grad_fn=<DivBackward0>)
tensor(386.7516, dtype=torch.float64)
66
tensor(4.9445, dtype=torch.float64, grad_fn=<DivBackward0>)
tensor(339.0314, dtype=torch.float64)
67
tensor(5.4392, dtype=torch.float64, grad_fn=<DivBackward0>)
tensor(302.7300, dtype=torch.float64)
68
tensor(7.2045, dtype=torch.float64, grad_fn=<DivBackward0>)
tensor(279.0739, dtype=torch.float64)
69
tensor(9.0725, dtype=torch.float64, grad_fn=<DivBackward0>)
tensor(267.9772, dtype=torch.float64)
70
tensor(10.0506, dtype=torch.float64, grad_fn=<DivBackward0>)
tensor(268.6576, dtype=torch.float64)
71
tensor(9.6958, dtype=torch.float64, grad_fn=<DivBackward0>)
tensor(279.9281, dtype=torch.float64)
72
tensor(8.2055, dtype=torch.float64, grad_fn=<DivBackward0>)
tensor(300.1474, dtype=torch.float64)
73
tensor(6.2290, dtype=torch.float64, grad_fn=<DivBackward0>)
tensor(327.0052, dtype=torch.float64)
74
tensor(4.5216, dtype=torch.float64, grad_fn=<DivBackward0>)
tensor(357.3948, dtype=torch.float64)
75
tensor(3.6093, dtype=torch.float64, grad_fn=<DivBackward0>)
tensor(387.5677, dtype=torch.float64)
76
tensor(3.6047, dtype=torch.float64, grad_fn=<DivBackward0>)
tensor(413.6117, dtype=torch.float64)
77
tensor(4.2273, dtype=torch.float64, grad_fn=<DivBackward0>)
tensor(432.1324, dtype=torch.float64)
78
tensor(4.9874, dtype=torch.float64, grad_fn=<DivBackward0>)
tensor(440.9127, dtype=torch.float64)
79
tensor(5.4300, dtype=torch.float64, grad_fn=<DivBackward0>)
tensor(439.3291, dtype=torch.float64)
80
tensor(5.3236, dtype=torch.float64, grad_fn=<DivBackward0>)
tensor(428.3936, dtype=torch.float64)
81
tensor(4.7225, dtype=torch.float64, grad_fn=<DivBackward0>)
tensor(410.4254, dtype=torch.float64)
82
tensor(3.8945, dtype=torch.float64, grad_fn=<DivBackward0>)
tensor(388.4817, dtype=torch.float64)
83
tensor(3.1672, dtype=torch.float64, grad_fn=<DivBackward0>)
tensor(365.7319, dtype=torch.float64)
84
tensor(2.7754, dtype=torch.float64, grad_fn=<DivBackward0>)
tensor(344.9389, dtype=torch.float64)
85
tensor(2.7637, dtype=torch.float64, grad_fn=<DivBackward0>)
tensor(328.1711, dtype=torch.float64)
86
tensor(3.0066, dtype=torch.float64, grad_fn=<DivBackward0>)
tensor(316.7016, dtype=torch.float64)
87
tensor(3.2886, dtype=torch.float64, grad_fn=<DivBackward0>)
tensor(311.0602, dtype=torch.float64)
88
tensor(3.4187, dtype=torch.float64, grad_fn=<DivBackward0>)
tensor(311.1431, dtype=torch.float64)
89
tensor(3.3100, dtype=torch.float64, grad_fn=<DivBackward0>)
tensor(316.3060, dtype=torch.float64)
90
tensor(3.0037, dtype=torch.float64, grad_fn=<DivBackward0>)
tensor(325.4473, dtype=torch.float64)
91
tensor(2.6282, dtype=torch.float64, grad_fn=<DivBackward0>)
tensor(337.0902, dtype=torch.float64)
92
tensor(2.3241, dtype=torch.float64, grad_fn=<DivBackward0>)
tensor(349.5107, dtype=torch.float64)
93
tensor(2.1764, dtype=torch.float64, grad_fn=<DivBackward0>)
tensor(360.9323, dtype=torch.float64)
94
tensor(2.1840, dtype=torch.float64, grad_fn=<DivBackward0>)
tensor(369.7706, dtype=torch.float64)
95
tensor(2.2746, dtype=torch.float64, grad_fn=<DivBackward0>)
tensor(374.8765, dtype=torch.float64)
96
tensor(2.3510, dtype=torch.float64, grad_fn=<DivBackward0>)
tensor(375.7127, dtype=torch.float64)
97
tensor(2.3415, dtype=torch.float64, grad_fn=<DivBackward0>)
tensor(372.4134, dtype=torch.float64)
98
tensor(2.2295, dtype=torch.float64, grad_fn=<DivBackward0>)
tensor(365.7121, dtype=torch.float64)
99
tensor(2.0522, dtype=torch.float64, grad_fn=<DivBackward0>)
tensor(356.7630, dtype=torch.float64)
tensor(2.0522, dtype=torch.float64, grad_fn=<DivBackward0>) tensor(356.7630, dtype=torch.float64)