from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier, export_graphviz
iris = load_iris()
# Keep only the petal features (petal length and petal width).
X = iris.data[:, 2:]
y = iris.target

# Depth-2 tree over the two petal features.
# random_state is fixed on purpose: sklearn randomly permutes the features at
# every split, so when two candidate splits give the same impurity decrease the
# winner depends on the seed. Petal length and petal width can tie at the root
# on this data, which made the exported diagram non-reproducible.
tree_clf = DecisionTreeClassifier(max_depth=2, random_state=42)
tree_clf.fit(X, y)

# Write the fitted tree as a Graphviz DOT file. The feature and class names
# label the nodes; rounded/filled only change how the rendered image looks.
export_graphviz(
    tree_clf,
    out_file='iris_tree.dot',
    feature_names=iris.feature_names[2:],
    class_names=iris.target_names,
    rounded=True,
    filled=True,
)
# Convert the DOT file written above into a PNG with the Graphviz `dot`
# command-line tool, then delete the intermediate DOT file.
# Requires Graphviz on PATH: check_call raises if `dot` cannot be started or
# exits with a non-zero status, and in that case os.remove is never reached.
# NOTE(review): the PNG is named 'iris_dot.png' while the DOT file uses the
# 'iris_tree' stem -- confirm the output name is intended.
import os
from subprocess import check_call

check_call([
    'dot',
    '-Tpng',
    'iris_tree.dot',
    '-o',
    'iris_dot.png',
])
os.remove('iris_tree.dot')
# Query the fitted tree for one sample: petal length 5, petal width 1.5
# (same column order as X). Prints the per-class probabilities, then the
# predicted class index.
print(tree_clf.predict_proba([[5, 1.5]]))
print(tree_clf.predict([[5, 1.5]]))
# Output recorded from an earlier run. It is kept as comments because, pasted
# as bare code, "[[0. 0.90740741 ...]]" is not valid Python and made the whole
# script fail with a SyntaxError:
# [[0. 0.90740741 0.09259259]]
# [1]
# Hand the petal data and the DecisionTreeClassifier class (not an instance)
# to the project-local plotting helper, along with max_depth=2.
# NOTE(review): decision_boundary is not shown here. Presumably it builds the
# estimator from the class plus the keyword arguments, fits it on X, y, and
# plots the decision regions -- confirm against its definition, including
# whether extra kwargs such as random_state would be forwarded.
from decision_boundary import plot_decision_boundaries

plot_decision_boundaries(
    X,
    y,
    DecisionTreeClassifier,
    max_depth=2,
)