import numpy as np
import numpy.random as random
from numpy.fft import fft
from scipy.io import wavfile
import matplotlib.pyplot as plt
import seaborn as sns
import os
%matplotlib inline
# Apply seaborn's default plot theme to the matplotlib figures created below.
sns.set()
# Apply the theme again, this time with font elements scaled by 1.5.
sns.set(font_scale=1.5)

# Directory holding the recordings. A file belongs to digit d when its
# name starts with "<d>_".
data_dir = './recordings/'
# Digits of interest (0 to 9); edit this list to load a different set.
digits = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
# Per-digit containers: the sample arrays and the file each one came from.
signals = {digit: [] for digit in digits}
file_names = {digit: [] for digit in digits}

# Walk the directory once and store every recording under the digit whose
# prefix matches its file name.
for filename in os.listdir(data_dir):
    for d in digits:
        if not filename.startswith(str(d) + '_'):
            continue
        # wavfile.read returns (sample_rate, samples); only the samples are kept.
        wav = wavfile.read(data_dir + filename)[1]
        # Keep one-dimensional sample arrays only; arrays with more
        # dimensions are skipped.
        if wav.ndim >= 2:
            continue
        signals[d].append(wav)
        file_names[d].append(filename)

# Length (in samples) of the longest recording that was loaded.
N = max(len(v) for d in digits for v in signals[d])

# Build a random 80/20 train/test split over the recording indices 0..99.
ix = np.arange(100)
random.shuffle(ix)  # shuffles the index array in place
# The first 80 shuffled indices are the training entries, the last 20 the
# test entries.
ix_train, ix_test = np.split(ix, [80])
# Number of entries in the training set.
num_instances = ix_train.size

# Compute the average spectral magnitude of each spoken digit over the
# training set and scale it to unit energy.
# The recordings are taken to be sampled at 8 kHz; the band of interest
# extends up to 1.5 kHz.
Ts = 1.0/8000  # sampling period in seconds
# Number of bins of an N-point DFT whose frequency lies below 1.5 kHz.
ix_cut = int(np.ceil(1500*Ts*N))
# NOTE(review): ix_cut is not applied here. The spectra are kept at full
# length N because the plotting code further down applies fftshift to
# mean_transforms, which needs the complete two-sided spectrum. Truncating
# to [:ix_cut] would require updating that code as well.

# Placeholder dictionaries, keyed by digit.
test1 = {}  # squared average spectral magnitude
test2 = {}  # energy of the average spectral magnitude (sum of squares)
test3 = {}
test4 = {}
# Complex DFT of each training recording, keyed by training index. It is
# refilled for every digit, so afterwards it holds the last digit only.
transforms = {}
# Average spectral magnitude per digit.
avg_spec_mag = {}
# Average spectral magnitude per digit, normalized to unit energy.
mean_transforms = {}
average_spectra = {}

for d in digits:
    # Transform every *training* recording of this digit, zero-padded to
    # the common length N so that all spectra share the same bins.
    transforms.clear()
    for i in ix_train:
        transforms[int(i)] = np.fft.fft(signals[d][i], n=N)
    # Average the magnitudes rather than the complex values: summing the
    # complex transforms first lets the phases of different recordings
    # cancel each other and underestimates the spectrum.
    magnitudes = np.abs(np.array(list(transforms.values())))
    avg_spec_mag[d] = magnitudes.mean(axis=0)
    test1[d] = np.square(avg_spec_mag[d])
    test2[d] = np.sum(test1[d])
    # Normalize so that the squared entries of the mean spectrum sum to 1.
    mean_transforms[d] = avg_spec_mag[d] / np.sqrt(test2[d])

# Plot the normalized average spectral magnitude of each digit against
# frequency in Hz, with 0 Hz at the center of the x axis.
centered_transforms = {}  # fftshift-ed mean spectra, keyed by digit
for d in digits:
    spectrum = mean_transforms[d]
    # Frequency (in Hz) of every DFT bin for sampling period Ts. The
    # frequencies are shifted the same way as the spectrum so each value
    # stays paired with its own bin; without an explicit x axis matplotlib
    # plots against the sample index, which is not a frequency.
    freqs = np.fft.fftshift(np.fft.fftfreq(len(spectrum), Ts))
    # Move the zero-frequency component to the center of the spectrum.
    centered_transforms[d] = np.fft.fftshift(spectrum)
    # One figure per digit.
    plt.plot(freqs, centered_transforms[d])
    plt.title("Average Spectral Magnitude of Digit {}".format(d))
    plt.xlabel("Frequency (Hz)")
    plt.show()

# Classifier based on the mean spectrum of each digit.
def mean_classifier(x):
    """Return the digit whose mean spectrum best matches sample *x*.

    The magnitude of the N-point DFT of *x* is compared, by inner product,
    with the normalized mean spectrum of every digit in ``digits``.

    Args:
        x: one-dimensional array of audio samples.

    Returns:
        The digit (a key of ``mean_transforms``) with the largest inner
        product.
    """
    spectrum = np.abs(np.fft.fft(x, n=N))
    scores = {digit: np.inner(spectrum, mean_transforms[digit])
              for digit in digits}
    return max(scores, key=scores.get)


# Inner product of each digit's average spectral magnitude with its own
# normalized mean spectrum. This compares a digit with itself only; it is
# kept so the existing printed summary remains available.
for d in digits:
    test3[d] = np.inner(np.abs(avg_spec_mag[d]), np.abs(mean_transforms[d]))
max_similarity = max(test3, key=test3.get)
print(test3)
print("Digit {} has the highest similarity".format(max_similarity))

# Classify the test recordings of every digit and report the fraction
# that is assigned to the correct digit.
num_correct = 0
num_tested = 0
for d in digits:
    for i in ix_test:
        num_correct += int(mean_classifier(signals[d][i]) == d)
        num_tested += 1
print("Test accuracy: {:.1%}".format(num_correct / num_tested))

# Write answer for Q3b here

# Code 3b Here
# I reused code from 3a above and increased the range of numbers accounted for in the list 'digits'

# Write answer for Q4 here
# The accuracy gets better when I classify more digits. For instance, when I run the program on
# digits 0-5 only, the similarity value (while still the highest of all the digits) is lower than
# when I run it on digits 0-9.

# Write your answer here