import numpy as np
import matplotlib.pyplot as plt
class KMeans:
    """K-means clustering using squared Euclidean distance.

    Parameters
    ----------
    k : int
        Number of clusters.
    max_iters : int
        Upper bound on the number of assign/update iterations.
    centroids : sequence or ndarray of shape (k, n_features), optional
        Starting centroids. When omitted (or empty), ``predict`` picks ``k``
        distinct samples at random as the starting centroids.
    plot_steps : object, optional
        Only its truthiness is used: when truthy, ``predict`` stops right
        after the first assignment step.
        NOTE(review): with the default of 100 this means the centroids are
        never updated unless the caller passes a falsy value -- confirm that
        this is the intended behaviour.
    """

    def __init__(self, k, max_iters, centroids=None, plot_steps=100):
        self.k = k
        self.max_iters = max_iters
        self.plot_steps = plot_steps
        # One list of sample indices per cluster.
        self.clusters = [[] for _ in range(self.k)]
        # A fresh list per instance; a mutable default would be shared.
        self.centroids = [] if centroids is None else centroids

    def predict(self, x):
        """Cluster the rows of ``x`` and return one label per row.

        ``x`` must be a 2-D array of shape (n_samples, n_features). The
        result is a float array of length n_samples holding the index of the
        cluster each sample was assigned to. ``plot`` is called once before
        returning.
        """
        self.x = x
        self.n_samples, self.n_features = x.shape

        # Initialize centroids. A length check is used because ``== []`` is
        # not a reliable emptiness test when the centroids are an ndarray.
        if len(self.centroids) == 0:
            chosen = np.random.choice(self.n_samples, self.k, replace=False)
            self.centroids = [self.x[i] for i in chosen]

        for _ in range(self.max_iters):
            # Update clusters.
            self.clusters = self._createClusters(self.centroids)
            if self.plot_steps:
                break

            # Update centroids.
            centroids_old = self.centroids
            self.centroids = self._getNewCentroids(self.clusters)

            # Stop once the centroids have (almost) stopped moving.
            if self._isConverged(centroids_old, self.centroids):
                break

        self.plot()
        return self._getClusterLabels(self.clusters)

    def _createClusters(self, centroids):
        """Group sample indices by the centroid closest to each sample."""
        clusters = [[] for _ in range(self.k)]
        for idx, sample in enumerate(self.x):
            clusters[self._closestCentroid(sample, centroids)].append(idx)
        return clusters

    def _closestCentroid(self, sample, centroids):
        """Return the index of the centroid nearest to ``sample``.

        Ties go to the lowest index, which is what ``np.argmin`` returns.
        """
        distances = [self.distanceFunction(sample, point) for point in centroids]
        return np.argmin(distances)

    def _getNewCentroids(self, clusters):
        """Return a (k, n_features) array holding the mean of each cluster.

        A cluster without members keeps its current centroid. Averaging an
        empty selection would produce NaN, and a NaN distance would then be
        picked by ``np.argmin`` on the next assignment step.
        """
        centroids = np.zeros((self.k, self.n_features))
        for i, cluster in enumerate(clusters):
            if len(cluster) == 0:
                centroids[i] = self.centroids[i]
            else:
                centroids[i] = np.mean(self.x[cluster], axis=0)
        return centroids

    def _isConverged(self, centroids_old, centroids_new):
        """Return True when the summed centroid movement is below 0.01."""
        moved = sum(
            self.distanceFunction(old, new)
            for old, new in zip(centroids_old[:self.k], centroids_new[:self.k])
        )
        return moved < 0.01

    def _getClusterLabels(self, clusters):
        """Convert per-cluster index lists into a per-sample label array."""
        labels = np.empty(self.n_samples)
        for label, cluster in enumerate(clusters):
            for sample_idx in cluster:
                labels[sample_idx] = label
        return labels

    def distanceFunction(self, x1, x2):
        """Return the squared Euclidean distance between ``x1`` and ``x2``."""
        return np.sum((x1 - x2) ** 2)

    def plot(self):
        """Scatter-plot every cluster and mark each centroid with a black x."""
        fig, ax = plt.subplots(figsize=(12, 8))
        for members in self.clusters:
            # Transpose so each feature becomes one positional argument.
            point = self.x[members].T
            ax.scatter(*point)
        for point in self.centroids:
            ax.scatter(*point, marker="x", color="black", linewidth=2)
        plt.show()
# Demo run: partition nine 2-D points into three clusters.
clusters = 3

# Each row is one sample, given as (feature 1, feature 2).
x = np.array([
    [3, 9],
    [5, 7],
    [6, 4],
    [2, 8],
    [7, 3],
    [2, 6],
    [4, 3],
    [3, 8],
    [5, 8],
])

# Explicit starting centroids, one row per cluster, instead of random picks.
centroids = np.array([
    [3, 9],
    [5, 7],
    [6, 4],
])

# plot_steps=False lets the loop run until convergence or max_iters.
K = KMeans(k=clusters, max_iters=50, centroids=centroids, plot_steps=False)
K.predict(x)