import numpy as np
import numpy.random as random
from numpy.fft import fft
from scipy.io import wavfile
import matplotlib.pyplot as plt
import seaborn as sns
import os
%matplotlib inline
# Apply seaborn's default theme with fonts scaled by 1.5 for every figure
# below. One call is sufficient: each sns.set(...) call re-applies the whole
# theme, so a preceding bare sns.set() would simply be overwritten.
sns.set(font_scale=1.5)

data_dir = './recordings/'
# determine digits of interest (0 to 9)
digits = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]  # change here to load more digits
# signals[d] collects the waveforms of digit d; file_names[d] collects the
# matching file names, in the same order
signals = {d: [] for d in digits}
file_names = {d: [] for d in digits}

# Import files. os.listdir() returns entries in arbitrary order, so the
# listing is sorted to give every recording a reproducible position.
for filename in sorted(os.listdir(data_dir)):
    for d in digits:
        if not filename.startswith(str(d) + '_'):
            continue
        # wavfile.read returns (sample_rate, samples); only the samples are kept
        wav = wavfile.read(os.path.join(data_dir, filename))[1]
        # keep 1-D sample arrays only (presumably this filters out
        # multi-channel recordings, which are silently skipped)
        if wav.ndim < 2:
            signals[d].append(wav)
            file_names[d].append(filename)
        # the '<digit>_' prefix can match only one entry of digits
        break

# find maximum of vector length
lengths = [len(v) for d in digits for v in signals[d]]
if not lengths:
    raise ValueError(
        "no usable recordings for digits {} found in {!r}".format(digits, data_dir)
    )
N = max(lengths)

# pad the vectors with trailing 0s so that every signal has length N
for d in digits:
    signals[d] = [
        np.pad(v, (0, N - len(v)), 'constant', constant_values=(0, 0))
        for v in signals[d]
    ]

print(N)

# next we split our dataset in train and test
# we will use a 80/20 random split.
# The same index split is applied to every digit, so it may only address
# positions that exist for the digit with the fewest recordings. At most the
# first 100 recordings per digit are used.
max_per_digit = 100
n_per_digit = min([max_per_digit] + [len(signals[d]) for d in digits])
if n_per_digit < 2:
    raise ValueError(
        "need at least 2 recordings per digit to build a train/test split, "
        "got {}".format(n_per_digit)
    )
# integer arithmetic keeps the 80% boundary exact (100 -> 80)
n_train = (4 * n_per_digit) // 5

# create train/test split
# NOTE: no random seed is set, so the split differs from run to run
ix = np.arange(n_per_digit)
random.shuffle(ix)
# select train entries
ix_train = ix[:n_train]
# select test entries
ix_test = ix[n_train:]

# Next we compute the average spectrum of each spoken digit in the training
# set, looking only at the band below 1.5 kHz.
# The sampling rate is taken to be 8 kHz.
Ts = 1.0/8000  # interval between samples, in seconds
# number of leading FFT bins kept for a length-N signal (the bins below 1.5 kHz)
ix_cut = int(np.ceil(1500*Ts*N))

# Code Solution to Q1 Here
# transforms[d] is a list holding abs(FFT) of every training recording of
# digit d, in the order given by ix_train (the training index itself is not
# stored alongside).
transforms = {d: [] for d in digits}


def fft_digits(d):
    """Append the FFT magnitude of each training recording of digit d to transforms[d]."""
    transforms[d].extend(np.abs(fft(signals[d][i])) for i in ix_train)


for d in digits:
    fft_digits(d)

# avg[d] is the element-wise mean, over the training recordings of digit d,
# of the spectral magnitudes truncated to the first ix_cut bins.
avg = {}
for d in digits:
    low_band = [spectrum[:ix_cut] for spectrum in transforms[d]]
    avg[d] = np.mean(low_band, axis=0)

# mean_transforms[d] is avg[d] scaled to unit Euclidean norm.
mean_transforms = {d: avg[d] / np.linalg.norm(avg[d]) for d in digits}

# In this next part, plot the average spectral magnitude of each digit.

# Code Solution to Q2 here
# All digits are drawn on one shared figure, one labelled curve per digit.
# FFT bin k of a length-N signal sampled every Ts seconds lies at k/(N*Ts) Hz,
# so the frequency axis is built from the bin indices instead of being
# stretched to end at exactly 1500 Hz.
for d in digits:
    y = mean_transforms[d]
    x = np.arange(len(y)) / (N * Ts)
    plt.plot(x, y, label=str(d))
plt.xlabel("Hz")
plt.legend(title="digit")
# plt.grid() with no arguments toggles the grid; with a theme that already
# shows a grid (seaborn's default does) that would hide it, so ask for it
# explicitly.
plt.grid(True)
plt.show()

# classifier function
# receives a vector, computes the inner product of its spectrum with the
# average spectrum of every digit, and returns the best-matching digit
def mean_classifier(x, templates=None, cutoff=None):
    """Return the label whose template spectrum best matches the signal x.

    The magnitude of the first `cutoff` FFT bins of x is compared, by inner
    product, with every template; the label with the largest inner product
    wins. Ties go to the label that comes first in `templates`.

    Args:
        x: 1-D sample vector to classify.
        templates: mapping label -> template spectrum with `cutoff` entries.
            Defaults to the module-level `mean_transforms`.
        cutoff: number of leading FFT bins to compare. Defaults to the
            module-level `ix_cut`.

    Returns:
        The key of `templates` with the largest inner product.

    Raises:
        ValueError: if `templates` is empty.
    """
    if templates is None:
        templates = mean_transforms
    if cutoff is None:
        cutoff = ix_cut
    if not templates:
        raise ValueError("no template spectra to compare against")

    fft_x = np.abs(fft(x)[:cutoff])
    # A single arg-max pass always returns a label; searching for the maximum
    # value and then testing for equality can fall through and return None
    # when the similarities are NaN.
    return max(templates, key=lambda d: np.inner(fft_x, templates[d]))

# Code Q3a Here

# Write answer for Q3b here

# Code 3b Here
# accuracy[d] is the fraction of the held-out recordings of digit d (those
# indexed by ix_test) that mean_classifier labels as d.
accuracy = {}
for d in digits:
    hits = sum(1 for i in ix_test if mean_classifier(signals[d][i]) == d)
    accuracy[d] = hits / len(ix_test)
print("percent correct for each digit")
print(accuracy)

# Write answer for Q4 here
# Now the accuracy becomes way worse since it is now a 10 way classifier
#report a list of accuracies for each digit

#Q4 Explanation:
#With digits = [0,1,2,3,4,5,6,7,8,9] set at the top of the file, the
#output of Code 3b above is the accuracy for each digit. The digits list
#can be changed to any subset of digits to test, but the general trend is
#that overall accuracy drops as more digits are added. This is because
#the classifier becomes a 10-way classifier instead of the earlier
#two-way classifier, so each sample is compared against more candidate
#digits and has more chances to be confused with one of them. Some
#digits nevertheless have a consistently higher accuracy: their average
#spectra are more distinct from those of the other digits, which makes
#them easier to classify correctly.

# Write your answer here