import numpy as np
import numpy.random as random
from numpy.fft import fft
from scipy.io import wavfile
import matplotlib.pyplot as plt
import seaborn as sns
import os
# Render matplotlib figures inline (IPython/Jupyter magic; not valid outside a notebook).
%matplotlib inline
# Apply seaborn's default theme, then re-apply it with larger fonts.
sns.set()
sns.set(font_scale=1.5)

data_dir = './recordings/'
# Digits of interest (0 to 9); shrink this list to load fewer digits.
digits = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
# signals[d]: list of mono waveforms for digit d; file_names[d]: matching file names.
signals = {d: [] for d in digits}
file_names = {d: [] for d in digits}
# Import files: recordings are named "<digit>_<speaker>_<index>.wav".
for filename in os.listdir(data_dir):
    for d in digits:
        if filename.startswith(str(d) + '_'):
            # wavfile.read returns (sample_rate, data); keep only the samples.
            wav = wavfile.read(os.path.join(data_dir, filename))[1]
            # Keep mono recordings only (stereo data would be 2-D).
            if wav.ndim < 2:
                signals[d].append(wav)
                file_names[d].append(filename)
            # A name cannot start with two different "<d>_" prefixes.
            break

# Zero-pad every signal to the maximum length N so all vectors match.
N = max(len(v) for d in digits for v in signals[d])
for d in digits:
    signals[d] = [np.pad(s, (0, N - len(s)), 'constant') if len(s) < N else s
                  for s in signals[d]]

# Build an 80/20 random train/test split over sample indices 0..99.
# permutation(100) is equivalent to shuffling arange(100) in place.
ix = random.permutation(100)
ix_train, ix_test = ix[:80], ix[80:]

# ----------------------------------------------------------------------
# Q1: average spectral magnitude of each spoken digit over the training
# set. Sampling rate is 8 kHz; only the band up to 1.5 kHz is kept.
# ----------------------------------------------------------------------
Ts = 1.0 / 8000                        # sampling period (s)
ix_cut = int(np.ceil(1500 * Ts * N))   # FFT bin index of the 1.5 kHz cutoff

# transforms[d][j]: magnitude spectrum (truncated at 1.5 kHz) of the j-th
# training sample of digit d, using the orthonormal FFT convention.
transforms = {d: [] for d in digits}
# mean_transforms[d]: unit-norm average magnitude spectrum of digit d.
mean_transforms = {}
for d in digits:
    # BUGFIX: spectra are now truncated at ix_cut (1.5 kHz) as required —
    # ix_cut was previously computed but never applied — and the redundant
    # single-element list wrapper around each transform is removed.
    transforms[d] = [np.abs(fft(signals[d][i], n=N, norm='ortho'))[:ix_cut]
                     for i in ix_train]
    avg = np.mean(transforms[d], axis=0)
    mean_transforms[d] = avg / np.linalg.norm(avg)

# ----------------------------------------------------------------------
# Q2: plot the average spectral magnitude of each digit.
# ----------------------------------------------------------------------
# Bin k of an N-point FFT corresponds to k / (N * Ts) Hz, so the kept band
# spans 0 .. ~1.5 kHz. (The previous axis wrongly ran 0-4 kHz over all N bins.)
freqs = np.arange(ix_cut) / (N * Ts)
for d in digits:
    plt.plot(freqs, mean_transforms[d])
    plt.title('Digit {}'.format(d))
    plt.xlabel('Frequency (Hz)')
    plt.ylabel('Amplitude')
    plt.grid()
    plt.show()

# classifier function
def mean_classifier(x):
    """Return the digit whose mean training spectrum best matches x.

    Similarity is the inner product between the unit-normalized magnitude
    spectrum of x (truncated at 1.5 kHz, matching the training transforms)
    and each digit's mean transform. Ties go to the smallest digit.

    Input: sample x (1-D signal vector; padded/truncated to length N).
    """
    # Same band and normalization as the training-set transforms.
    xft = np.abs(fft(x, n=N, norm='ortho'))[:ix_cut]
    xft = xft / np.linalg.norm(xft)
    similar = {d: float(np.inner(xft, mean_transforms[d])) for d in digits}
    most_sim = max(similar.values())
    for d in digits:  # first digit attaining the maximum wins
        if similar[d] == most_sim:
            return d

# ----------------------------------------------------------------------
# Q3b: per-digit accuracy on the held-out test split.
# ----------------------------------------------------------------------
percent = {}
for d in digits:
    # Count test samples of digit d that the classifier assigns to d.
    correct = sum(1 for i in ix_test if mean_classifier(signals[d][i]) == d)
    percent[d] = correct / len(ix_test)  # fraction classified correctly
print(percent)

# Write answer for Q4 here
# Generally, the accuracy of the classifier decreases as the number of
# digits grows to 10: with more classes there are more ways to misclassify
# each signal, and a test signal is more likely to be incidentally similar
# to the mean spectrum of some other digit.

# Code Q4 here

# Write your answer here