import numpy as np
import pandas as pd
import rdkit
from rdkit.Chem import Draw, Lipinski, Crippen, Descriptors
from rdkit.Chem.Draw import rdMolDraw2D
from IPython.display import SVG
# Parse an example molecule from its SMILES string into an RDKit Mol object.
mol = rdkit.Chem.MolFromSmiles('Cn1cnc2n(C)c(=O)n(C)c(=O)c12')
# The following code draws the molecule in 2D - we won't worry about this utility too much moving forward.
# The important idea is that we now have a Mol object that describes one of our favorite molecules.
# Render the molecule into a 350x300 SVG drawing.
d2d = rdMolDraw2D.MolDraw2DSVG(350,300)
d2d.DrawMolecule(mol)
d2d.FinishDrawing()
SVG(d2d.GetDrawingText())  # bare expression: presumably displayed inline as a notebook cell result; no effect in a plain script
# We can iterate over all atoms in the molecule to generate certain atom-level properties.
# Note that the order in which the atoms are indexed depends on the SMILES string used to generate the molecule...
# Luckily there are relatively few cases where we'll need to know individual atom indices.
for atom in mol.GetAtoms():
    print(atom.GetAtomicNum(), atom.GetHybridization(), atom.GetFormalCharge())
# Notice that there are no hydrogens included by default. If we want to add hydrogens, we must do so explicitly:
mol_with_Hs = rdkit.Chem.AddHs(mol)
print("With added hydrogens:")
for atom in mol_with_Hs.GetAtoms():
    print(atom.GetAtomicNum(), atom.GetHybridization(), atom.GetFormalCharge())
# We can also calculate whole-molecule properties: here the atom, bond and heavy-atom counts
# of the hydrogen-augmented molecule.
print(mol_with_Hs.GetNumAtoms())
print(mol_with_Hs.GetNumBonds())
print(mol_with_Hs.GetNumHeavyAtoms())
# Implement your encode_molecule() function here:
def encode_molecule(SMILES_string):
    """Given a SMILES string return a list of molecular encodings

    The encoding is a fixed-order list of RDKit Lipinski descriptors:
    FractionCSP3, NOCount, NumAliphaticRings, NumHDonors, NumHAcceptors,
    NumRotatableBonds, NumHeteroatoms, HeavyAtomCount, RingCount, NHOHCount,
    NumAliphaticCarbocycles, NumAliphaticHeterocycles, NumAromaticCarbocycles,
    NumAromaticHeterocycles, NumAromaticRings.

    Arguments:
        SMILES_string: A string representing the SMILES string of the molecule
    Returns:
        mol_encoding: A list of molecule features, one value per descriptor above
    Raises:
        ValueError: if RDKit cannot parse SMILES_string
    """
    mol = rdkit.Chem.MolFromSmiles(SMILES_string)
    if mol is None:
        # MolFromSmiles signals a parse failure by returning None instead of raising;
        # fail here with a clear message rather than deep inside a descriptor call.
        raise ValueError(f"Could not parse SMILES string: {SMILES_string!r}")
    # Each descriptor appears exactly once (NumRotatableBonds was previously listed twice,
    # which duplicated a feature and made it twice as likely to be visible to a tree node).
    descriptor_functions = [
        Lipinski.FractionCSP3,
        Lipinski.NOCount,
        Lipinski.NumAliphaticRings,
        Lipinski.NumHDonors,
        Lipinski.NumHAcceptors,
        Lipinski.NumRotatableBonds,
        Lipinski.NumHeteroatoms,
        Lipinski.HeavyAtomCount,
        Lipinski.RingCount,
        Lipinski.NHOHCount,
        Lipinski.NumAliphaticCarbocycles,
        Lipinski.NumAliphaticHeterocycles,
        Lipinski.NumAromaticCarbocycles,
        Lipinski.NumAromaticHeterocycles,
        Lipinski.NumAromaticRings,
    ]
    mol_encoding = [descriptor(mol) for descriptor in descriptor_functions]
    return mol_encoding
# Sanity-check the encoder on the example molecule used above.
print(encode_molecule('Cn1cnc2n(C)c(=O)n(C)c(=O)c12'))
# Load the Delaney solubility data set; the relative path assumes the CSV is in the working directory.
delaney_processed = pd.read_csv('delaney-processed.csv', sep=',')
delaney_processed  # bare expression: presumably displays the DataFrame when run as a notebook cell
# Encode every molecule: molecule_features gets one row of descriptors per SMILES string.
smile_strings = delaney_processed["smiles"].to_list()
molecule_features = np.array([encode_molecule(smile_string) for smile_string in smile_strings])
# Regression targets: the measured log solubility column, as a NumPy array aligned with molecule_features.
esol = delaney_processed['measured log solubility in mols per litre'].to_numpy()
print(molecule_features.shape) # Some sanity checks
print(esol.shape)
num_data_points = len(esol) # This is a useful variable to have around
class Node():
    """Defines a single Node in the decision tree. Note that initializing a Node on a set of data and targets will grow an entire tree based on that data
    Attributes:
        min_size: The minimum size of a split data set that will spawn a child node. Recommend 6 (i.e. splits of size < 6 return a concrete value)
        feature_index: An int indicating the feature index containing the attributes upon which the split is decided
        threshold: A float indicating the threshold for splitting
        left_output: Used when a test point falls at or below the threshold in its feature index; either a
            terminal prediction (a mean of training targets) or a child Node
        right_output: Used when a test point falls above the threshold in its feature index; either a
            terminal prediction (a mean of training targets) or a child Node
    """
    def __init__(self, min_size, data, targets):
        """Grow the (sub)tree rooted at this node.

        Args:
            min_size: Minimum split size that may spawn a child node
            data: 2-D array, one row per sample and one column per feature
            targets: 1-D array of target values aligned with the rows of data
        """
        self.min_size = min_size
        self.feature_index, self.threshold, self.left_output, self.right_output = self.grow_tree_from_data(data, targets)

    def SDR(self, left_targets, right_targets):
        """ Calculates the standard deviation reduction caused by the splitting of a data set.
            This is calculated as std(all data) - sum(p(split_data)*std(split_data)).
            Returns 0 if left_targets or right_targets is empty (i.e. has length 0)
            Args:
                left_targets, right_targets: The split data
            Returns:
                SDR: The standard deviation reduction, or 0 if either side is empty
        """
        if len(left_targets) == 0 or len(right_targets) == 0:
            return 0
        all_targets = np.hstack((left_targets, right_targets))
        p_left = len(left_targets) / len(all_targets)
        SDR = np.std(all_targets) - p_left * np.std(left_targets) - (1 - p_left) * np.std(right_targets)
        return SDR

    def grow_tree_from_data(self, data, targets):
        """ Grows a random decision tree by choosing this node's split and spawning child nodes, if necessary.
            Args:
                data, targets: the attributes and targets of the data passed to the node
            Returns:
                (feature_index, threshold, left_output, right_output) for this node, where each output is
                either a terminal prediction or a child Node
            Raises:
                ValueError: if data contains no rows
        """
        if len(targets) == 0:
            raise ValueError("Cannot grow a tree from an empty data set")
        n_features = np.shape(data)[1]
        # Randomly choose n/3 feature indices to be visible to this node. At least one is always
        # visible so data sets with fewer than three features can still be split.
        n_visible = max(1, int(n_features / 3))
        visible_indices = np.random.choice(np.arange(n_features), size=n_visible, replace=False)
        # Start keeping track of split performance
        best_SDR = None
        best_index = None
        best_threshold = None
        # Systematically try every possible split on the visible indices and store the best result (as measured by SDR)
        for index in visible_indices:
            column = data[:, index]
            # Duplicate values give identical splits, so each distinct value is tried once. Visiting the
            # values in order of first appearance keeps the tie-breaking of a full row-by-row scan,
            # because only a strictly better SDR replaces the current best.
            _, first_positions = np.unique(column, return_index=True)
            for value in column[np.sort(first_positions)]:
                trial_SDR = self.SDR(targets[column <= value], targets[column > value])
                if best_SDR is None or best_SDR < trial_SDR:
                    best_SDR = trial_SDR
                    best_index = index
                    best_threshold = value
        # See what the data looks like after the optimal split
        best_column = data[:, best_index]
        left_mask = best_column <= best_threshold
        right_mask = best_column > best_threshold
        best_left_data, best_left_targets = data[left_mask], targets[left_mask]
        best_right_data, best_right_targets = data[right_mask], targets[right_mask]
        # Return the mean of the targets if the resulting split data is small enough (or the split does not
        # reduce the standard deviation); otherwise generate a new node to split the data further
        if len(best_left_targets) == 0:
            left_output = np.mean(best_right_targets)  # No split has occurred
        elif len(best_left_targets) < self.min_size or best_SDR == 0:
            left_output = np.mean(best_left_targets)
        else:
            left_output = Node(self.min_size, best_left_data, best_left_targets)
        if len(best_right_targets) == 0:
            right_output = np.mean(best_left_targets)  # No split has occurred
        elif len(best_right_targets) < self.min_size or best_SDR == 0:
            right_output = np.mean(best_right_targets)
        else:
            right_output = Node(self.min_size, best_right_data, best_right_targets)

        return best_index, best_threshold, left_output, right_output

    def predict(self, data_point):
        """Predicts the target value of a data point passed to this node
            Args:
                data_point: The data point passed to this node
            Returns:
                The predicted target value, either from this node or from an eventual terminal child node
        """
        if data_point[self.feature_index] <= self.threshold:
            branch = self.left_output
        else:
            branch = self.right_output
        # Child Nodes recurse; anything else is a terminal prediction. Testing for Node (rather than
        # float) also handles terminal means of other numeric types such as np.float32.
        if isinstance(branch, Node):
            return branch.predict(data_point)
        return branch
## Write a function for bootstrap sampling
def take_bootstrap_sample(data, targets):
    """Given a data set, takes a sample of len(data) data points, with replacement
        Args:
            data: The input data points (rows are samples)
            targets: The target values of the input data, aligned with the rows of data
        Returns:
            selected_indices: The array of row indices selected for the bootstrapped sample (may repeat)
            data[selected_indices]: The bootstrapped data sample
            targets[selected_indices]: The bootstrapped target values
    """
    n_points = np.shape(data)[0]
    # A bootstrap sample has the same size as the original data set. The previous
    # size of n/3 contradicted this contract and shrank every tree's training set.
    selected_indices = np.random.choice(np.arange(n_points), size=n_points, replace=True)
    return selected_indices, data[selected_indices], targets[selected_indices]
# Implement Random Tree
class RandomTree():
    """A single regression tree grown with the CART procedure on a bootstrap resample.

        The full data set is handed to the constructor; the tree itself only ever sees the
        bootstrapped rows.
        Attributes:
            selected_indices: Row indices of the full data set drawn into the bootstrap sample
            data: The bootstrapped data rows the tree was grown on
            targets: The bootstrapped targets the tree was grown on
            root_node: The root Node of the grown tree
    """
    def __init__(self, min_size, data, targets):
        # Resample first, then grow the entire tree from the resampled rows.
        bootstrap = take_bootstrap_sample(data, targets)
        self.selected_indices, self.data, self.targets = bootstrap
        self.root_node = Node(min_size, self.data, self.targets)

    def predict(self, data_point):
        """Return the tree's prediction for one data point.
            Args:
                data_point: The feature vector to predict
            Returns:
                The value produced by routing data_point from the root to a leaf
        """
        return self.root_node.predict(data_point)
#Check out the results on the first 20 ESOL samples
# Grow one random tree (min_size 6) on the full feature matrix and print (prediction, measured) pairs.
# NOTE(review): the tree's bootstrap sample likely contains many of these rows, so this is not a held-out check.
sample_tree = RandomTree(6, molecule_features, esol)
for i in range(20):
    print(sample_tree.predict(molecule_features[i]), esol[i])
## Plant a random forest here:
class RandomForest():
    """A forest built out of CART trees, each grown on its own bootstrap sample.
        Attributes:
            n_trees: The number of trees in the forest
            data: The dataset upon which the forest is grown
            targets: The target values of the data
            tree_list: A list for holding each tree object in the forest
    """
    def __init__(self, n_trees, min_size, data, targets):
        """Grow n_trees random trees on data/targets, printing progress as each tree is planted."""
        self.n_trees = n_trees
        self.data = data
        self.targets = targets
        self.tree_list = []
        print("Planting Trees...")
        for i in range(self.n_trees):  # Grow a tree n_trees times and store it in tree_list
            print(i)
            self.tree_list.append(RandomTree(min_size, data, targets))

    def predict_point(self, data_point):
        """ Predicts the target value of a point by averaging the prediction from each tree in the forest
            Args:
                data_point: The data point whose target value is to be predicted
            Returns:
                mean_prediction: The mean prediction by all the trees
        """
        predictions = [tree.predict(data_point) for tree in self.tree_list]
        mean_prediction = sum(predictions) / self.n_trees
        return mean_prediction

    def calculate_out_of_bag_error(self, data, targets):
        """Calculates the out-of-bag (OOB) error for the random forest.

            Each data point is predicted by averaging only the trees that did not see it in their
            bootstrapped training set. The OOB error is the mean squared error of those predictions,
            taken over every point that at least one tree left out.
            Args:
                data, targets: The data and targets upon which the forest was trained (row indices must
                    match the trees' selected_indices)
            Returns:
                obg_error: The out-of-bag mean squared error
            Raises:
                ValueError: if every tree saw every data point, so no out-of-bag prediction exists
        """
        n_points = np.shape(data)[0]
        targets = np.asarray(targets)
        prediction_sums = np.zeros(n_points)
        prediction_counts = np.zeros(n_points, dtype=int)
        for tree in self.tree_list:
            # Mark the rows this tree trained on; every other row is out-of-bag for it.
            in_bag = np.zeros(n_points, dtype=bool)
            in_bag[tree.selected_indices] = True
            for point_index in np.flatnonzero(~in_bag):
                prediction_sums[point_index] += tree.predict(data[point_index])
                prediction_counts[point_index] += 1
        has_oob = prediction_counts > 0
        if not np.any(has_oob):
            raise ValueError("No out-of-bag predictions: every tree saw every data point")
        oob_predictions = prediction_sums[has_oob] / prediction_counts[has_oob]
        obg_error = np.mean(np.square(oob_predictions - targets[has_oob]))
        return obg_error
# Grow a 100-tree forest (min_size 6) on the full feature matrix.
my_forest = RandomForest(100, 6, molecule_features, esol)
# Check out your results here: (forest prediction, measured value) for the first 20 molecules
for i in range(20):
    print(my_forest.predict_point(molecule_features[i]), esol[i])
# Out-of-bag error, computed by RandomForest.calculate_out_of_bag_error on the training data itself.
print(my_forest.calculate_out_of_bag_error(molecule_features, esol))
def empirical_encoding(SMILES_string):
    """Given a SMILES string return an estimate of log(Solubility) based on empirical model in Delaney paper

    log(S) = 0.16 - 0.63*clogP - 0.0062*MW + 0.066*RB - 0.74*AP, where clogP is the Crippen logP,
    MW the (average) molecular weight, RB the number of rotatable bonds and AP the fraction of heavy
    atoms that are aromatic.

    Arguments:
        SMILES_string: A string representing the SMILES string of the molecule
    Returns:
        esol: The estimated log(Solubility) based on Delaney Paper
    Raises:
        ValueError: if the SMILES string cannot be parsed or the molecule has no heavy atoms
    """
    mol = rdkit.Chem.MolFromSmiles(SMILES_string)
    if mol is None:
        # MolFromSmiles returns None on a parse failure instead of raising.
        raise ValueError(f"Could not parse SMILES string: {SMILES_string!r}")
    heavy_atom_count = Lipinski.HeavyAtomCount(mol)
    if heavy_atom_count == 0:
        # The aromatic proportion is undefined without heavy atoms.
        raise ValueError(f"Molecule has no heavy atoms: {SMILES_string!r}")
    clogp = Crippen.MolLogP(mol)
    # Delaney's equation uses the ordinary (average) molecular weight, not the monoisotopic ExactMolWt.
    mol_wt = Descriptors.MolWt(mol)
    rot_bonds = Lipinski.NumRotatableBonds(mol)
    aromatic_proportion = len(list(mol.GetAromaticAtoms())) / heavy_atom_count
    log_solubility = 0.16 - 0.63 * clogp - 0.0062 * mol_wt + 0.066 * rot_bonds - 0.74 * aromatic_proportion
    return log_solubility
# Baseline: mean squared error of Delaney's empirical equation (empirical_encoding) against the measured values.
smile_strings = delaney_processed["smiles"].to_list()
esol_predictions = np.array([empirical_encoding(smile_string) for smile_string in smile_strings])
print(np.mean(np.square(esol_predictions - esol)))