我从“scikit learn/decision tree/export graphviz”的决策树中得到了错误的类名。程序如下所示:
import matplotlib.pyplot as plt
import matplotlib.image as img
import pydot
from sklearn import tree
digital_table = [[0, 0], [0, 1], [1, 0], [1, 1]]
digital_label = ['zero', 'one', 'two', 'three']
digital_name = ['idx-1', 'idx-2']
digital_tree = tree.DecisionTreeClassifier()
digital_tree.fit(digital_table, digital_label)
with open("digital.dot", 'w') as f:
f = tree.export_graphviz(digital_tree,
feature_names=digital_name,
class_names=digital_label,
filled=True, rounded=True,
out_file=f)
(graph,) = pydot.graph_from_dot_file("digital.dot")
graph.write_png("digital.png")
plt.imshow(img.imread('digital.png'))
plt.show()
产出如下:
问题在于叶中显示的类名。例如,如果idx-1为1,idx-2为1,则绿色框应标记为“三”。但是,图像显示标签为“一”。有人能发表你的意见吗?
当您使用DecisionTreeClassifier时,应该将类标签更改为数字,如0,1,2
然后使用:
classe_names = decision_tree_classifier.classes_
它将按升序为您提供类的标签。然后按相同的顺序指定class_标签。它可以是字符串。
在将类标签传递到export\u graphviz