import numpy as np
import pandas as pd
import rdkit
from rdkit.Chem import Draw, Lipinski, Crippen, Descriptors
from rdkit.Chem.Draw import rdMolDraw2D
from IPython.display import SVG
mol = rdkit.Chem.MolFromSmiles('Cn1cnc2n(C)c(=O)n(C)c(=O)c12')
# The following code draws the molecule in 2D - we won't worry about this utility too much moving forward.
# The important idea is that we now have a Mol object that describes one of our favorite molecules.
d2d = rdMolDraw2D.MolDraw2DSVG(350,300)
d2d.DrawMolecule(mol)
d2d.FinishDrawing()
# This bare expression only renders when it is the last expression of a notebook cell;
# in a plain Python script the SVG object is created and immediately discarded.
SVG(d2d.GetDrawingText())
# We can iterate over all atoms in the molecule to generate certain atom-level properties.
# Note that the order in which the atoms are indexed depends on the SMILES string used to generate the molecule...
# Luckily there are relatively few cases where we'll need to know individual atom indices.
for atom in mol.GetAtoms():
    print(atom.GetAtomicNum(), atom.GetHybridization(), atom.GetFormalCharge())
# Notice that there are no hydrogens included by default. If we want to add hydrogens, we must do so explicitly:
mol_with_Hs = rdkit.Chem.AddHs(mol)
print("With added hydrogens:")
for atom in mol_with_Hs.GetAtoms():
    print(atom.GetAtomicNum(), atom.GetHybridization(), atom.GetFormalCharge())
# We can also calculate molecular properties
print(mol_with_Hs.GetNumAtoms())       # total atom count, now including the added hydrogens
print(mol_with_Hs.GetNumBonds())       # total bond count
print(mol_with_Hs.GetNumHeavyAtoms())  # non-hydrogen atom count
# Implement your encode_molecule() function here:
def encode_molecule(SMILES_string):
    """Given a SMILES string return a vector of molecular descriptors.

    Arguments:
        SMILES_string: A string representing the SMILES string of the molecule
    Returns:
        mol_encoding: A 1-D numpy array with one entry per descriptor, in this order:
            heavy atom count, NH/OH count, N/O count, H-bond acceptors, H-bond donors,
            aliphatic heterocycles, aromatic heterocycles, fraction of sp3 carbons,
            aliphatic carbocycles, rotatable bonds
    Raises:
        ValueError: if RDKit cannot parse SMILES_string
    """
    mol = rdkit.Chem.MolFromSmiles(SMILES_string)
    # RDKit reports a parse failure by returning None instead of raising; fail loudly
    # here rather than with an obscure error inside the first descriptor call.
    if mol is None:
        raise ValueError(f"Could not parse SMILES string: {SMILES_string!r}")
    # Order matters: each position in the output array is one feature column.
    # Append further descriptor functions here to extend the encoding.
    descriptors = (
        rdkit.Chem.Lipinski.HeavyAtomCount,
        rdkit.Chem.Lipinski.NHOHCount,
        rdkit.Chem.Lipinski.NOCount,
        rdkit.Chem.Lipinski.NumHAcceptors,
        rdkit.Chem.Lipinski.NumHDonors,
        rdkit.Chem.Lipinski.NumAliphaticHeterocycles,
        rdkit.Chem.Lipinski.NumAromaticHeterocycles,
        rdkit.Chem.Lipinski.FractionCSP3,
        rdkit.Chem.Lipinski.NumAliphaticCarbocycles,
        rdkit.Chem.Lipinski.NumRotatableBonds,
    )
    mol_encoding = np.array([descriptor(mol) for descriptor in descriptors])
    return mol_encoding
print(encode_molecule('Cn1cnc2n(C)c(=O)n(C)c(=O)c12'))
# Load the Delaney (ESOL) solubility data; the CSV is expected in the working directory.
delaney_processed = pd.read_csv('delaney-processed.csv', sep=',')
delaney_processed  # bare expression: displays the DataFrame only when run as a notebook cell
smile_strings = delaney_processed["smiles"].to_list()
# Feature matrix: one row per molecule, one column per descriptor returned by encode_molecule.
molecule_features = np.array([encode_molecule(smile_string) for smile_string in smile_strings])
# Regression targets: measured log solubility, aligned row-for-row with molecule_features.
esol = delaney_processed['measured log solubility in mols per litre'].to_numpy()
print(molecule_features.shape) # Some sanity checks
print(esol.shape)
num_data_points = len(esol) # This is a useful variable to have around
class Node():
    """Defines a single Node in the decision tree. Note that initializing a Node on a set of data and targets will grow an entire tree based on that data
    Attributes:
        min_size: The minimum size of a split data set that will spawn a child node. Recommend 6 (i.e. splits of size < 6 return a concrete value)
        feature_index: An int indicating the feature index containing the attributes upon which the split is decided
        threshold: A float indicating the threshold for splitting
        left_output: The output of the node if a test point falls at or below the threshold in its feature index
            (either a child Node or a scalar leaf value)
        right_output: The output of the node if a test_point falls above the threshold in its feature index
            (either a child Node or a scalar leaf value)
    """
    def __init__(self, min_size, data, targets):
        self.min_size = min_size
        self.feature_index, self.threshold, self.left_output, self.right_output = self.grow_tree_from_data(data, targets)

    def SDR(self, left_targets, right_targets):
        """ Calculates the standard deviation reduction caused by the splitting of a data set.
            This is calculated as std(all data) - sum(p(split_data)*std(split_data)).
            Returns 0 if left_targets or right_targets is empty (i.e. has length 0)
            Args:
                left_targets, right_targets: The split data
            Returns:
                SDR: The standard deviation reduction, or 0 if left_targets or right_targets is empty
        """
        if len(left_targets) == 0 or len(right_targets) == 0:
            return 0

        all_targets = np.concatenate([left_targets, right_targets])
        all_stdev = np.std(all_targets)
        SDR = all_stdev - (1 / len(all_targets)) * (len(left_targets) * np.std(left_targets) + len(right_targets) * np.std(right_targets))
        return SDR

    def grow_tree_from_data(self, data, targets):
        """ Chooses the best split for this node on a random subset of features and spawns child nodes, if necessary.
            Args:
                data: 2-D array, one row per data point and one column per feature
                targets: 1-D array of target values aligned with the rows of data
            Returns:
                A tuple (feature_index, threshold, left_output, right_output); each output is either
                a child Node or the mean of the targets that fall on that side of the split
        """
        n_features = np.shape(data)[1]
        # Randomly choose a third of the features to be visible to this node. At least one
        # feature must be visible: with fewer than 3 features int(n/3) is 0, no split would be
        # evaluated and best_index would stay None.
        n_visible = max(1, int(n_features / 3))
        visible_indices = np.random.choice(np.arange(n_features), size=n_visible, replace=False)
        # Start keeping track of split performance
        best_SDR = None
        best_index = None
        best_threshold = None
        # Systematically try every possible split on the visible indices and store the best result (as measured by SDR)
        for index in visible_indices:
            column = data[:, index]
            for value in column:
                trial_SDR = self.SDR(targets[column <= value], targets[column > value])
                if best_SDR is None or best_SDR < trial_SDR:
                    best_SDR = trial_SDR
                    best_index = index
                    best_threshold = value
        # See what the data looks like after the optimal split
        left_mask = data[:, best_index] <= best_threshold
        right_mask = data[:, best_index] > best_threshold
        best_left_data = data[left_mask]
        best_right_data = data[right_mask]
        best_left_targets = targets[left_mask]
        best_right_targets = targets[right_mask]
        # Return the mean of the targets if the resulting split data is small enough; otherwise generate a new node to split the data further
        if len(best_left_targets) == 0:
            left_output = np.mean(best_right_targets) # No split has occurred
        elif len(best_left_targets) < self.min_size or best_SDR == 0:
            left_output = np.mean(best_left_targets)
        else:
            left_output = Node(self.min_size, best_left_data, best_left_targets)
        if len(best_right_targets) == 0:
            right_output = np.mean(best_left_targets) # No split has occurred
        elif len(best_right_targets) < self.min_size or best_SDR == 0:
            right_output = np.mean(best_right_targets)
        else:
            right_output = Node(self.min_size, best_right_data, best_right_targets)

        return best_index, best_threshold, left_output, right_output

    def predict(self, data_point):
        """Predicts the target value of a data point passed to this node
            Args:
                data_point: The data point passed to this node
            Returns:
                The predicted target value, either from this node or from an eventual terminal child node
        """
        if data_point[self.feature_index] <= self.threshold:
            branch = self.left_output
        else:
            branch = self.right_output
        # Child Nodes recurse; anything else is a leaf value. Checking for Node (rather than
        # float) also handles leaves such as np.float32, which is not a float subclass.
        if isinstance(branch, Node):
            return branch.predict(data_point)
        return branch
## Write a function for bootstrap sampling
def take_bootstrap_sample(data, targets):
    """Draw a bootstrap sample: len(data) rows chosen uniformly at random, with replacement.

        Args:
            data: The input data points (must support indexing by an integer array)
            targets: The target values, aligned row-for-row with data
        Returns:
            A tuple (selected_indices, sampled_data, sampled_targets), where the last two
            are data and targets indexed by selected_indices.
    """
    n_points = len(data)
    candidate_rows = np.arange(n_points)
    picks = np.random.choice(candidate_rows, size=n_points, replace=True)
    sampled_data = data[picks]
    sampled_targets = targets[picks]
    return picks, sampled_data, sampled_targets
# Implement Random Tree
class RandomTree():
    """A random regression tree grown with the CART algorithm on a bootstrap resample.

        The full dataset is passed in; the tree is grown only on the rows drawn by the bootstrap.
        Attributes:
            selected_indices: Row indices (into the full dataset) drawn for this tree's bootstrap sample
            data: The bootstrapped data the tree was grown on
            targets: The bootstrapped targets the tree was grown on
            root_node: The root Node of the grown tree
    """
    def __init__(self, min_size, data, targets):
        # Remember which rows were drawn so out-of-bag rows can be identified later.
        bootstrap = take_bootstrap_sample(data, targets)
        self.selected_indices, self.data, self.targets = bootstrap
        self.root_node = Node(min_size, self.data, self.targets)

    def predict(self, data_point):
        """Return the tree's prediction for a single data point by delegating to the root node."""
        return self.root_node.predict(data_point)
#Check out the results on the first 20 ESOL samples
# Grows one tree (min_size=6) on a bootstrap sample of the full dataset and prints
# (prediction, measured log solubility) pairs. These rows were available for training,
# so this is a sanity check rather than a held-out evaluation.
sample_tree = RandomTree(6, molecule_features, esol)
for i in range(20):
    print(sample_tree.predict(molecule_features[i]), esol[i])
## Plant a random forest here:
class RandomForest():
    """A forest built out of CART trees.
        Attributes:
            n_trees: The number of trees in the forest
            data: The dataset upon which the forest is grown
            targets: The target values of the data
            tree_list: A list for holding each tree object in the forest
    """
    def __init__(self, n_trees, min_size, data, targets):
        self.n_trees = n_trees
        self.data = data
        self.targets = targets
        self.tree_list = []
        print("Planting Trees...")
        for i in range(self.n_trees): # Grow a tree n_trees times and store it in tree_list
            print(i)
            self.tree_list.append(RandomTree(min_size, data, targets))

    def predict_point(self, data_point):
        """ Predicts the target value of a point by averaging the prediction from each tree in the forest
            Args:
                data_point: The data point whose target value is to be predicted
            Returns:
                mean_prediction: The mean prediction by all the trees
        """
        votes = [tree.predict(data_point) for tree in self.tree_list]
        mean_prediction = np.mean(np.array(votes))
        return mean_prediction

    def calculate_out_of_bag_error(self, data, targets):
        """Calculates the out-of-bag error for the random forest. This is the mean-squared-error for each data point as predicted only by the trees
            in the forest who did not see that data point in their bootstrapped training set.
            Points that were drawn by every tree have no out-of-bag prediction and are left out of the average.
            Args:
                data, targets: The data and targets upon which the forest was trained
            Returns:
                obg_error: The out-of-bag mean squared error (nan if no point has an out-of-bag prediction)
        """
        # Build each tree's in-bag index set once: O(1) membership tests instead of
        # scanning the tree's index array for every data point.
        in_bag_sets = [set(np.asarray(tree.selected_indices).tolist()) for tree in self.tree_list]
        squared_errors = []
        for idx, (data_point, target) in enumerate(zip(data, targets)):
            votes = [
                tree.predict(data_point)
                for tree, in_bag in zip(self.tree_list, in_bag_sets)
                if idx not in in_bag
            ]
            if votes:  # skip points that every tree saw during training
                squared_errors.append((np.mean(votes) - target) ** 2)
        if not squared_errors:
            return np.nan
        obg_error = np.mean(squared_errors)
        return obg_error
# Grow a forest of 100 trees, each with min_size=6, on the full feature matrix.
my_forest = RandomForest(100, 6, molecule_features, esol)
# Check out your results here
# (forest prediction, measured value) for the first 20 molecules; these are training points.
for i in range(20):
    print(my_forest.predict_point(molecule_features[i]), esol[i])
# Out-of-bag mean squared error: each point is scored only by trees that did not draw it.
print(my_forest.calculate_out_of_bag_error(molecule_features, esol))
def empirical_encoding(SMILES_string):
    """Given a SMILES string return an estimate of log(Solubility) based on the empirical model in the Delaney paper.

    logS = 0.16 - 0.63*cLogP - 0.0062*MW + 0.066*RB - 0.74*AP,
    where AP is the aromatic proportion (aromatic atoms / heavy atoms).

    Arguments:
        SMILES_string: A string representing the SMILES string of the molecule
    Returns:
        esol: The estimated log(Solubility) based on the Delaney paper
    Raises:
        ValueError: if RDKit cannot parse SMILES_string
    """
    mol = rdkit.Chem.MolFromSmiles(SMILES_string)
    # RDKit reports a parse failure by returning None instead of raising.
    if mol is None:
        raise ValueError(f"Could not parse SMILES string: {SMILES_string!r}")
    heavy_atoms = mol.GetNumHeavyAtoms()
    # Aromatic proportion; a molecule with no heavy atoms has no aromatic atoms either,
    # so use 0 rather than dividing by zero.
    aromaticProp = len(list(mol.GetAromaticAtoms())) / heavy_atoms if heavy_atoms else 0.0
    esol = (0.16
            - 0.63 * rdkit.Chem.Crippen.MolLogP(mol)
            - 0.0062 * rdkit.Chem.Descriptors.MolWt(mol)
            + 0.066 * rdkit.Chem.Lipinski.NumRotatableBonds(mol)
            - 0.74 * aromaticProp)
    return esol
# Sanity check of the empirical model on a single molecule.
print(empirical_encoding('OCC3OC(OCC2OC(OC(C#N)c1ccccc1)C(O)C(O)C2O)C(O)C(O)C3O'))
# Recomputes the same SMILES list that was built earlier from the "smiles" column.
smile_strings = delaney_processed["smiles"].to_list()
esol_predictions = np.array([empirical_encoding(smile_string) for smile_string in smile_strings])
# Mean squared error of the empirical-equation estimates against the measured log solubility.
print(np.mean(np.square(esol_predictions - esol)))