import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from tabulate import tabulate
from collections import Counter
import warnings
warnings.filterwarnings('ignore')
# The raw Iris file has no header row; header=None keeps the first record
# from being consumed as column names.
ds = pd.read_csv('iris.data', header=None)

# Four numeric measurements followed by the species label.
ds.columns = [
    'sepal length',
    'sepal width',
    'petal length',
    'petal width',
    'class',
]

# Notebook preview of the first rows (has no effect when run as a script).
ds.head()
# ds.describe() can be used here instead for summary statistics.
def _class_counts(data, feature, label_col, labels, interval):
    """Count the samples of each class whose feature value lies in *interval*.

    The result is an integer array ordered like *labels*; classes that do not
    occur in the interval get a count of 0.
    """
    in_interval = data[feature].between(interval[0], interval[1])
    seen = Counter(data.loc[in_interval, label_col])
    return np.array([seen[label] for label in labels])


def _pair_chi2(count_0, count_1):
    """Return the chi-square statistic of two adjacent intervals.

    *count_0* and *count_1* are the per-class counts of the two intervals,
    i.e. the two rows of a 2 x k contingency table.
    """
    count_total = count_0 + count_1  # column sums
    total = count_total.sum()
    # A class absent from both intervals has an expected count of 0 and
    # produces a 0/0 term. That term is defined as 0, so the warning is
    # silenced locally and the resulting NaN is mapped to 0.
    with np.errstate(divide='ignore', invalid='ignore'):
        # Expected count: E = row_sum * column_sum / total
        expected_0 = count_0.sum() * count_total / total
        expected_1 = count_1.sum() * count_total / total
        chi2 = ((count_0 - expected_0) ** 2 / expected_0
                + (count_1 - expected_1) ** 2 / expected_1)
    return sum(np.nan_to_num(chi2))


def chi_calculation(data, feature, threshold_intervals, label_col='class'):
    """Discretize *feature* by repeatedly merging adjacent intervals.

    Steps:
    1. Start with one interval ``[v, v]`` per distinct feature value.
    2. Compute the chi-square statistic of every pair of adjacent intervals
       and let ``re_intervals`` merge the pair with the smallest value.
    3. Repeat step 2 until at most *threshold_intervals* intervals remain.

    Parameters:
        data: DataFrame holding the feature column and the label column.
        feature: name of the numeric column to discretize.
        threshold_intervals: maximum number of intervals to keep.
        label_col: name of the class-label column (default ``'class'``).

    Returns:
        A list of ``[low, high]`` intervals (inclusive bounds) in ascending
        order. Bounds are plain Python scalars.

    Side effects:
        Prints the chi-square value of the last merge, if any merge happened.
    """
    # np.unique sorts and de-duplicates; tolist() yields plain Python
    # scalars so the intervals print the same under every NumPy version.
    distinct_vals = np.unique(data[feature]).tolist()
    labels = np.unique(data[label_col])

    intervals = [[value, value] for value in distinct_vals]

    min_chi = None
    # A single interval has no neighbour to merge with, hence the floor of 1.
    while len(intervals) > max(threshold_intervals, 1):
        # Count each interval once per pass, then score the adjacent pairs.
        counts = [
            _class_counts(data, feature, label_col, labels, interval)
            for interval in intervals
        ]
        chi = [
            _pair_chi2(left, right)
            for left, right in zip(counts, counts[1:])
        ]
        min_chi = min(chi)
        intervals = re_intervals(intervals, chi, min_chi)

    # Nothing to report when the data already fits within the threshold.
    if min_chi is not None:
        print('min chi is {}'.format(str(min_chi)))
    return intervals
def re_intervals(intervals, chi, min_chi):
    """Merge the pair of adjacent intervals with the smallest chi-square.

    Parameters:
        intervals: list of ``[low, high]`` intervals in ascending order.
        chi: chi-square values, where ``chi[i]`` belongs to the pair
            ``intervals[i]`` / ``intervals[i + 1]``.
        min_chi: the value in *chi* that selects the pair to merge. If it
            occurs more than once, the first occurrence is used.

    Returns:
        A new list in which the selected pair is replaced by one interval
        spanning both; all other intervals are carried over unchanged.

    Raises:
        ValueError: if *min_chi* does not occur in *chi*.
    """
    # Convert to an array first: comparing a plain list with a Python float
    # yields a single False instead of an element-wise result.
    matches = np.flatnonzero(np.asarray(chi) == min_chi)
    if matches.size == 0:
        raise ValueError('min_chi {!r} does not occur in chi'.format(min_chi))
    merge_at = int(matches[0])

    bounds = intervals[merge_at] + intervals[merge_at + 1]
    merged = [min(bounds), max(bounds)]
    return intervals[:merge_at] + [merged] + intervals[merge_at + 2:]
# Maximum number of intervals to keep for each feature.
threshold_intervals = 6

# Discretize every column except the last one, which holds the class label.
for feature in ds.columns[:-1]:
    print('Interval for {}'.format(feature))
    intervals = chi_calculation(
        data=ds,
        feature=feature,
        threshold_intervals=threshold_intervals,
    )
    # Wrap the interval list in a one-cell table so it is drawn inside a box.
    print(tabulate([[intervals]], tablefmt='fancy_grid'))
Interval for sepal length
min chi is 5.065681444991789
╒══════════════════════════════════════════════════════════════════════════╕
│ [[4.3, 4.8], [4.9, 4.9], [5.0, 5.4], [5.5, 5.7], [5.8, 7.0], [7.1, 7.9]] │
╘══════════════════════════════════════════════════════════════════════════╛
Interval for sepal width
min chi is 1.44
╒══════════════════════════════════════════════════════════════════════════╕
│ [[2.0, 2.2], [2.3, 2.4], [2.5, 2.8], [2.9, 2.9], [3.0, 3.3], [3.4, 4.4]] │
╘══════════════════════════════════════════════════════════════════════════╛
Interval for petal length
min chi is 1.0666666666666667
╒══════════════════════════════════════════════════════════════════════════╕
│ [[1.0, 1.9], [3.0, 4.4], [4.5, 4.7], [4.8, 4.9], [5.0, 5.1], [5.2, 6.9]] │
╘══════════════════════════════════════════════════════════════════════════╛
Interval for petal width
min chi is 0.24000000000000005
╒══════════════════════════════════════════════════════════════════════════╕
│ [[0.1, 0.6], [1.0, 1.3], [1.4, 1.6], [1.7, 1.7], [1.8, 1.8], [1.9, 2.5]] │
╘══════════════════════════════════════════════════════════════════════════╛