# import libraries and K-Means function
import numpy as np
import matplotlib.pyplot as plt
from pandas import DataFrame
import pandas as pd
import seaborn as sns
from sklearn.cluster import KMeans
# Load the cereal data set into a dataframe and take a first look at it.
df = pd.read_csv('cereal.csv')
df.head()
# Keep only the features used below by dropping every other column.
unused_features = ['type', 'calories', 'protein', 'fat', 'sodium', 'fiber',
                   'potass', 'vitamins', 'shelf', 'weight', 'cups']
df = df.drop(columns=unused_features)
df.head()
# Use the non-null counts reported by .info() to spot any NaN entries.
df.info()
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 77 entries, 0 to 76
Data columns (total 5 columns):
# Column Non-Null Count Dtype
--- ------ -------------- -----
0 name 77 non-null object
1 mfr 77 non-null object
2 carbo 77 non-null float64
3 sugars 77 non-null int64
4 rating 77 non-null float64
dtypes: float64(2), int64(1), object(2)
memory usage: 3.1+ KB
# Check to see if any cereals have negative values for sugars or carbohydrates.
# If 2 or fewer cereals are affected, drop those data instances; otherwise
# replace the negative values with 0 so that no further rows are lost.
# NOTE(review): negative amounts are presumably placeholders for missing
# data - confirm against the data set's documentation.
x = df[(df.sugars < 0) | (df.carbo < 0)]
if len(x) <= 2:
    # Few enough affected rows: keep only rows where both values are non-negative.
    df = df[(df.sugars >= 0) & (df.carbo >= 0)]
else:
    # Too many affected rows to drop: clip the negative values to 0 instead.
    df[['sugars', 'carbo']] = df[['sugars', 'carbo']].clip(lower=0)
df.head()
# Use a white background with grid lines for every seaborn plot that follows.
sns.set_style('whitegrid')
# Bar chart of how many products each manufacturer code has.
sns.countplot(data=df, x='mfr')
# There is only one cereal from American Home Food Products (code 'A'), so we
# drop that data sample.
#
# Type df[df.mfr == 'A'] to see all data from this manufacturer.
# Reassign df to the filtered result, then check that the row was deleted.
#
# .copy() makes the filtered frame independent of its parent, so the column
# assignment below unambiguously modifies df itself (no SettingWithCopyWarning).
df = df[df.mfr != 'A'].copy()
# For plots we would like the full manufacturer name instead of just "N" or "Q":
# map each code to its name with .map, then drop the 'mfr' column.
df['Manufacturer'] = df.mfr.map({
    'A': 'American Home Food Products',
    'G': 'General Mills',
    'K': 'Kelloggs',
    'N': 'Nabisco',
    'P': 'Post',
    'Q': 'Quaker Oats',
    'R': 'Ralston Purina'
})
df = df.drop(['mfr'], axis=1)
df.head()
# Plot to see which manufacturer has the highest rated cereals.
sns.displot(x='rating', hue='Manufacturer', multiple='stack', data=df)
# Which cereal is rated highest?
# idxmax() returns an index *label*, so look the row up with .loc. The
# positional .iloc would pick the wrong row once earlier rows have been dropped
# and labels no longer coincide with positions.
print(df.loc[df.rating.idxmax()])
name All-Bran with Extra Fiber
carbo 8.0
sugars 0
rating 93.704912
Manufacturer Kelloggs
Name: 3, dtype: object
# Look at sugars per brand by plotting.
sns.displot(x='sugars', hue='Manufacturer', multiple='stack', data=df)
# Get data for clustering. KMeans expects a 2-D (n_samples, n_features) array,
# so reshape the sugars column into a single feature column.
sugars = df.sugars.to_numpy().reshape(-1, 1)
# Form the model, fit the data and print out the cluster centers.
# n_init is pinned explicitly because its default changed across scikit-learn
# releases (10 -> 'auto'). 10 restarts is also what 'auto' means for init='random'.
km = KMeans(n_clusters=3, init='random', n_init=10, random_state=0)
# fit() returns the fitted estimator itself, not the cluster labels.
y_km = km.fit(sugars)
print(km.cluster_centers_)
[[ 1.95833333]
[ 6.84 ]
[12.03846154]]
# Store each cereal's sugar cluster label in a new 'sugar_clusters' column.
df['sugar_clusters'] = km.labels_
df.head()
# Stacked counts of cereals per sugar cluster, split by manufacturer.
sns.displot(data=df, x='sugar_clusters', hue='Manufacturer', multiple='stack')
name carbo sugars rating Manufacturer \
0 100% Bran 5.0 6 68.402973 Nabisco
1 100% Natural Bran 8.0 8 33.983679 Quaker Oats
2 All-Bran 7.0 5 59.425505 Kelloggs
3 All-Bran with Extra Fiber 8.0 0 93.704912 Kelloggs
4 Almond Delight 14.0 8 34.384843 Ralston Purina
.. ... ... ... ... ...
72 Triples 21.0 3 39.106174 General Mills
73 Trix 13.0 12 27.753301 General Mills
74 Wheat Chex 17.0 3 49.787445 Ralston Purina
75 Wheaties 17.0 3 51.592193 General Mills
76 Wheaties Honey Gold 16.0 8 36.187559 General Mills
sugar_clusters Cluster names
0 1 Medium
1 1 Medium
2 1 Medium
3 0 Low
4 1 Medium
.. ... ...
72 0 Low
73 2 High
74 0 Low
75 0 Low
76 1 Medium
[75 rows x 7 columns]
# Determine which cluster number corresponds to the lowest, middle and highest
# sugar level, and create a new column in the dataframe using .map.
#
# The label numbers KMeans assigns are arbitrary and can change with the random
# initialisation or library version. Rank the clusters by their fitted centers
# instead of hard-coding label -> level. `km` is the sugars model fitted earlier.
sugar_order = np.argsort(km.cluster_centers_.ravel())
df['Cluster names'] = df['sugar_clusters'].map({
    int(label): level
    for label, level in zip(sugar_order, ['Low', 'Medium', 'High'])
})
#
# How are cereals distributed among the 3 levels?
sns.displot(x='Cluster names', hue='Manufacturer', multiple='stack', data=df)
# They are roughly evenly distributed among the levels.
# Which cereals have the highest sugar levels?
print(df[df['Cluster names'] == "High"].name)
5 Apple Cinnamon Cheerios
6 Apple Jacks
10 Cap'n'Crunch
14 Cocoa Puffs
17 Corn Pops
18 Count Chocula
22 Crispy Wheat & Raisins
24 Froot Loops
25 Frosted Flakes
27 Fruit & Fibre Dates; Walnuts; and Oats
28 Fruitful Bran
29 Fruity Pebbles
30 Golden Crisp
35 Honey Graham Ohs
36 Honey Nut Cheerios
37 Honey-comb
42 Lucky Charms
44 Muesli Raisins; Dates; & Almonds
45 Muesli Raisins; Peaches; & Pecans
46 Mueslix Crispy Blend
51 Oatmeal Raisin Crisp
52 Post Nat. Raisin Bran
58 Raisin Bran
66 Smacks
70 Total Raisin Bran
73 Trix
Name: name, dtype: object
# List the cereals that fall into the lowest sugar cluster.
low_sugar = df['Cluster names'] == "Low"
print(df.loc[low_sugar, 'name'])
3 All-Bran with Extra Fiber
11 Cheerios
15 Corn Chex
16 Corn Flakes
20 Cream of Wheat (Quick)
21 Crispix
33 Grape-Nuts
34 Great Grains Pecan
40 Kix
50 Nutri-grain Wheat
53 Product 19
54 Puffed Rice
55 Puffed Wheat
61 Rice Chex
62 Rice Krispies
63 Shredded Wheat
64 Shredded Wheat 'n'Bran
65 Shredded Wheat spoon size
67 Special K
69 Total Corn Flakes
71 Total Whole Grain
72 Triples
74 Wheat Chex
75 Wheaties
Name: name, dtype: object
# Look up which sugar cluster a particular cereal (e.g. Apple Jacks,
# Froot Loops) falls into.
my_cereal = 'Apple Jacks'
cereal_cluster = df.loc[df.name == my_cereal, 'Cluster names']
print(f"The data instance and sugar cluster for {my_cereal} is ", cereal_cluster)
The data instance and sugar cluster for Apple Jacks is 6 High
Name: Cluster names, dtype: object
# Cluster the carbohydrate values into three groups, the same way as sugars.
# KMeans expects a 2-D (n_samples, n_features) array, hence the reshape.
carbo = df.carbo.to_numpy().reshape(-1, 1)
# n_init is pinned explicitly because its default changed across scikit-learn
# releases (10 -> 'auto'). 10 restarts is also what 'auto' means for init='random'.
km = KMeans(n_clusters=3, init='random', n_init=10, random_state=0)
# fit() returns the fitted estimator itself, not the cluster labels.
y_km = km.fit(carbo)
print(km.cluster_centers_)
[[15.59677419]
[11.08333333]
[20.92857143]]
# Store each cereal's raw carbohydrate cluster label.
df['carbo_labels'] = km.labels_
# KMeans label numbers are arbitrary and can change with the random
# initialisation or library version. Rank the clusters by their fitted centers
# instead of hard-coding which label is Low/Medium/High.
# `km` here is the carbohydrate model fitted just above.
carbo_order = np.argsort(km.cluster_centers_.ravel())
df['Carbohydrate cluster'] = df['carbo_labels'].map({
    int(label): level
    for label, level in zip(carbo_order, ['Low', 'Medium', 'High'])
})
df.head()
# Show the cereals whose carbohydrate and sugar levels sit at opposite extremes.
print('High carbs, low sugar')
print(df[(df['Carbohydrate cluster'] == 'High') & (df['Cluster names'] == 'Low')].name)
print('Low carbs, high sugar')
print(df[(df['Carbohydrate cluster'] == 'Low') & (df['Cluster names'] == 'High')].name)
High carbs, low sugar
15 Corn Chex
16 Corn Flakes
20 Cream of Wheat (Quick)
21 Crispix
40 Kix
53 Product 19
61 Rice Chex
62 Rice Krispies
64 Shredded Wheat 'n'Bran
65 Shredded Wheat spoon size
69 Total Corn Flakes
72 Triples
Name: name, dtype: object
Low carbs, high sugar
5 Apple Cinnamon Cheerios
6 Apple Jacks
10 Cap'n'Crunch
14 Cocoa Puffs
17 Corn Pops
18 Count Chocula
22 Crispy Wheat & Raisins
24 Froot Loops
27 Fruit & Fibre Dates; Walnuts; and Oats
29 Fruity Pebbles
30 Golden Crisp
35 Honey Graham Ohs
36 Honey Nut Cheerios
42 Lucky Charms
52 Post Nat. Raisin Bran
66 Smacks
73 Trix
Name: name, dtype: object
# Stacked counts per raw carbohydrate label, then per named carbohydrate level.
for plot_column in ('carbo_labels', 'Carbohydrate cluster'):
    sns.displot(x=plot_column, hue='Manufacturer', multiple='stack', data=df)
# There are roughly equal numbers of low and medium carbohydrate cereals,
# with comparatively few high carbohydrate cereals.
for heading, carb_level in (("High carbohydrates", "High"),
                            ("Low carbohydrates", "Low")):
    print(heading)
    print(df[df['Carbohydrate cluster'] == carb_level].name)
High carbohydrates
15 Corn Chex
16 Corn Flakes
20 Cream of Wheat (Quick)
21 Crispix
39 Just Right Fruit & Nut
40 Kix
49 Nutri-Grain Almond-Raisin
53 Product 19
61 Rice Chex
62 Rice Krispies
64 Shredded Wheat 'n'Bran
65 Shredded Wheat spoon size
69 Total Corn Flakes
72 Triples
Name: name, dtype: object
Low carbohydrates
0 100% Bran
1 100% Natural Bran
2 All-Bran
3 All-Bran with Extra Fiber
5 Apple Cinnamon Cheerios
6 Apple Jacks
9 Bran Flakes
10 Cap'n'Crunch
12 Cinnamon Toast Crunch
13 Clusters
14 Cocoa Puffs
17 Corn Pops
18 Count Chocula
19 Cracklin' Oat Bran
22 Crispy Wheat & Raisins
24 Froot Loops
27 Fruit & Fibre Dates; Walnuts; and Oats
29 Fruity Pebbles
30 Golden Crisp
34 Great Grains Pecan
35 Honey Graham Ohs
36 Honey Nut Cheerios
41 Life
42 Lucky Charms
52 Post Nat. Raisin Bran
54 Puffed Rice
55 Puffed Wheat
59 Raisin Nut Bran
66 Smacks
73 Trix
Name: name, dtype: object
# Look up which carbohydrate cluster a particular cereal (e.g. Apple Jacks,
# Froot Loops) falls into.
my_cereal = 'Apple Jacks'
carb_cluster = df.loc[df.name == my_cereal, 'Carbohydrate cluster']
print(f"The data instance and carbohydrate cluster for {my_cereal} is ", carb_cluster)
The data instance and carbohydrate cluster for Apple Jacks is 6 Low
Name: Carbohydrate cluster, dtype: object