提问者:小点点

使用类名的graphviz树节点的颜色


展开前面的问题:更改使用export graphviz创建的决策树图的颜色

我如何根据主导类(虹膜物种)而不是二元区分来给树的节点着色?这需要虹膜的组合。目标名称、描述类的字符串和iris。目标,全班同学。

import pydotplus
from sklearn.datasets import load_iris
from sklearn import tree
import collections

clf = tree.DecisionTreeClassifier(random_state=42)
iris = load_iris()

clf = clf.fit(iris.data, iris.target)

dot_data = tree.export_graphviz(clf, out_file=None,
                                feature_names=iris.feature_names,
                                class_names=iris.target_names,
                                filled=True, rounded=True,
                                special_characters=True)
graph = pydotplus.graph_from_dot_data(dot_data)
nodes = graph.get_node_list()
edges = graph.get_edge_list()

colors = ('brown', 'forestgreen')
edges = collections.defaultdict(list)

for edge in graph.get_edge_list():
    edges[edge.get_source()].append(int(edge.get_destination()))

for edge in edges:
    edges[edge].sort()    
    for i in range(2):
        dest = graph.get_node(str(edges[edge][i]))[0]
        dest.set_fillcolor(colors[i])

graph.write_png('tree.png')

共2个答案

匿名用户

示例中的代码看起来非常熟悉,因此易于修改:)

对于每个节点Graphviz告诉我们从每个组中我们有多少个样本,即如果它是混合种群或树来决定。我们可以提取这些信息并使用它来获得颜色。

values = [int(ii) for ii in node.get_label().split('value = [')[1].split(']')[0].split(',')]

或者,您可以将GraphViz节点映射回skLearning节点:

values = clf.tree_.value[int(node.get_name())][0]

我们只有3个类,所以每个类都有自己的颜色(红、绿、蓝),混合种群根据它们的分布得到混合颜色。

values = [int(255 * v / sum(values)) for v in values]
color = '#{:02x}{:02x}{:02x}'.format(values[0], values[1], values[2])

我们现在可以很好地看到分离,二等舱越绿,我们拥有的二等舱就越多,蓝色和三等舱也是如此。

import pydotplus
from sklearn.datasets import load_iris
from sklearn import tree

clf = tree.DecisionTreeClassifier(random_state=42)
iris = load_iris()

clf = clf.fit(iris.data, iris.target)

dot_data = tree.export_graphviz(clf,
                                feature_names=iris.feature_names,
                                out_file=None,
                                filled=True,
                                rounded=True,
                                special_characters=True)
graph = pydotplus.graph_from_dot_data(dot_data)
nodes = graph.get_node_list()

for node in nodes:
    if node.get_label():
        values = [int(ii) for ii in node.get_label().split('value = [')[1].split(']')[0].split(',')]
        values = [int(255 * v / sum(values)) for v in values]
        color = '#{:02x}{:02x}{:02x}'.format(values[0], values[1], values[2])
        node.set_fillcolor(color)

graph.write_png('colored_tree.png')

3个以上类的通用解决方案,仅为最终节点着色。

colors =  ('lightblue', 'lightyellow', 'forestgreen', 'lightred', 'white')

for node in nodes:
    if node.get_name() not in ('node', 'edge'):
        values = clf.tree_.value[int(node.get_name())][0]
        #color only nodes where only one class is present
        if max(values) == sum(values):    
            node.set_fillcolor(colors[numpy.argmax(values)])
        #mixed nodes get the default color
        else:
            node.set_fillcolor(colors[-1])

匿名用户

伟大的答案,伙计们。只是为了补充马克西米利安·彼得斯的回答。识别特定颜色的叶节点的另一件事是检查split_criteria(阈值)值。由于叶节点没有子节点,因此也没有拆分标准。

https://github.com/scikit-learn/scikit-learn/blob/a24c8b464d094d2c468a16ea9f8bf8d42d949f84/sklearn/tree/_tree.pyx
TREE_UNDEFINED = -2 
thresholds = clf.tree_.threshold
for node in nodes:
    if node.get_name() not in ('node', 'edge'):
        value = clf.tree_.value[int(node.get_name())][0]
        # color only nodes where only one class is present or if it is a leaf 
        # node
        if max(values) == sum(values) or 
            thresholds[int(node.get_name())] == TREE_UNDEFINED:    
                node.set_fillcolor(colors[numpy.argmax(value)])
        # mixed nodes get the default color
        else:
            node.set_fillcolor(colors[-1])

与问题不完全相关,但添加更多信息以防对其他人有帮助。继续理解基于树的分类器的决策树桩这一想法,Skater增加了使用树代理汇总所有形式的基于树的模型的支持。看看这里的例子。

https://github.com/datascienceinc/Skater/blob/master/examples/rule_list_notebooks/explanation_using_tree_surrogate.ipynb