Generate Color Palette From Images Using Unsupervised Learning (K-Means Clustering)
Import Libraries
import pandas as pd
import numpy as np
import requests
from urllib.request import urlretrieve
from io import BytesIO
from PIL import Image, ImageColor
import colorsys
import requests
from sklearn.cluster import KMeans
import matplotlib.pyplot as plt
plt.style.use('ggplot')
import plotly.express as px
import seaborn as sns
sns.set_theme()
Load Image
Img_URL
# Download the image bytes. The timeout stops the cell from hanging forever on
# an unresponsive host, and raise_for_status() surfaces HTTP errors (404, 500,
# ...) here instead of letting PIL fail on an HTML error page.
response = requests.get(Img_URL, timeout=30)
response.raise_for_status()
# Decode the downloaded bytes in memory; nothing is written to disk.
img = Image.open(BytesIO(response.content))
img
Get RGB Values
# Number of channels (bands) in the image as loaded: 3 for RGB, 4 for RGBA,
# 1 for grayscale or palette images.
n_dims = len(img.getbands())
n_dims  # RGB -> has 3 channels

# Normalise the image to 3-channel RGB so that every pixel unpacks into
# exactly (R, G, B) below, whatever mode the file was stored in.
if "A" in img.getbands():
    # Flatten transparency onto a white background, using the alpha band as
    # the paste mask.
    rgba_img = img.convert("RGBA")
    temp_img = Image.new("RGB", rgba_img.size, (255, 255, 255))
    temp_img.paste(rgba_img, mask=rgba_img.split()[3])  # 3 is the alpha channel
    img = temp_img
elif img.mode != "RGB":
    # Grayscale, palette, CMYK, ... -> RGB.
    img = img.convert("RGB")

# The image is RGB at this point, so the pixel table always has 3 columns.
# (Reusing the pre-conversion channel count here would break for RGBA input.)
r, g, b = np.array(img).reshape(-1, 3).T
Store the RGB Values of All Pixels in a DataFrame
# One row per pixel, one column per colour channel.
df_rgb = pd.DataFrame({"R": r, "G": g, "B": b})
# Down-sample to at most 100000 pixels to keep clustering fast. The cap is
# bounded by the pixel count because sampling without replacement raises
# ValueError when n exceeds the number of rows (i.e. for small images).
df_rgb = df_rgb.sample(n=min(len(df_rgb), 100000))
df_rgb
Clustering
Set Palette Size
The number of clusters is the palette size.
# Number of colours to extract; passed to KMeans as n_clusters.
palette_size = 6
Train KMeans Model
# Configure the clusterer: one cluster per palette colour, k-means++ seeding,
# 100 restarts, and a fixed seed so repeated runs give the same centroids.
kmeans_model = KMeans(
    n_clusters=palette_size,
    init='k-means++',
    n_init=100,
    random_state=0,
)
# fit() returns the estimator itself, so the name keeps pointing at the
# trained model.
kmeans_model = kmeans_model.fit(df_rgb)
Fetch Cluster Centers
The cluster centers are the palette colours in RGB Format.
# Cluster centres are float means of the member pixels. Round to the nearest
# integer channel value (a plain int cast would truncate, e.g. 127.9 -> 127)
# and convert to a list of [R, G, B] lists.
palette = np.rint(kmeans_model.cluster_centers_).astype(int).tolist()
palette
Palette Visualization & Application
Sort Palette
# Order the palette in place by its (hue, saturation, value) tuple, so colours
# with similar hue end up next to each other.
palette.sort(key=lambda rgb: colorsys.rgb_to_hsv(*rgb))
palette
Visualize Palette
# Build a one-row image with one pixel per palette colour; the leading axis
# is the single row.
img_rgb = np.asarray(palette, dtype=np.uint8)[np.newaxis, ...]
fig = px.imshow(img_rgb)
# The axes carry no information for a colour strip, so hide both.
fig.update_yaxes(visible=False)
fig.update_xaxes(visible=False)
fig.show()
Use Palette in Seaborn
Convert RGB to Hex Format
def rgb_to_hex(r, g, b):
    """Return the ``#rrggbb`` hex string for an RGB colour.

    Each channel is rounded to the nearest integer and clamped to 0-255, so
    the result is always a valid 7-character hex code. Without this, a value
    such as 256 or -1 would produce a malformed string, and a float would
    raise because the ``x`` format code only accepts integers.

    Args:
        r: Red channel value.
        g: Green channel value.
        b: Blue channel value.

    Returns:
        Lower-case hex colour string, e.g. ``'#ff0080'``.
    """
    channels = [min(255, max(0, int(round(value)))) for value in (r, g, b)]
    return '#{:02x}{:02x}{:02x}'.format(*channels)
# Hex code for every palette entry, in palette order.
list_hex = [
    rgb_to_hex(colour[0], colour[1], colour[2])
    for colour in palette
]
list_hex
Set the Palette
# Make the extracted colours seaborn's default colour cycle for later plots.
sns.set_palette(list_hex)
# With no arguments, color_palette() returns the current default palette,
# which confirms the change took effect.
sns.color_palette()
Output Comparison
# Show the source image, then the extracted palette, for a visual comparison.
display(img)
sns.color_palette(list_hex)
Cluster Analysis
# Assign every sampled pixel to its nearest cluster centre. Only the three
# colour columns are passed, so this cell can be re-run after the extra
# columns below have been added to df_rgb.
clusters = kmeans_model.predict(df_rgb[['R', 'G', 'B']])
df_rgb['n_cluster'] = clusters
df_rgb[df_rgb['n_cluster'] == 0].head()

# Hex code per pixel. Columns are read by label: integer keys on a
# label-indexed row (x[0], x[1], x[2]) rely on deprecated positional fallback.
df_rgb['hex_code'] = [
    rgb_to_hex(red, green, blue)
    for red, green, blue in zip(df_rgb['R'], df_rgb['G'], df_rgb['B'])
]

# Show the first ten member colours of each cluster as a swatch strip.
for each_cluster in range(palette_size):
    print('Cluster ', each_cluster)
    cluster_hex = df_rgb.loc[df_rgb['n_cluster'] == each_cluster, 'hex_code']
    display(sns.color_palette(list(cluster_hex.head(10))))