我试图建立一个决策树使用gridsearch和管道,但我得到一个错误,当我试图导出图像使用图形。我在网上找了找,什么也没找到;一个潜在的问题是如果我不使用best_estimator_
实例,但在这种情况下我使用了。
除导出图形部分外,所有操作都正常(获得准确性和其他指标)。
def TreeOpt(X, y):
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=1)
std_scl = StandardScaler()
dec_tree = tree.DecisionTreeClassifier()
pipe = Pipeline(steps=[('std_slc', std_scl),
('dec_tree', dec_tree)])
criterion = ['gini', 'entropy']
max_depth = list(range(1,15))
parameters = dict(dec_tree__criterion=criterion,
dec_tree__max_depth=max_depth)
tree_gs = GridSearchCV(pipe, parameters)
tree_gs.fit(X_train, y_train)
export_graphviz(
tree_gs.best_estimator_,
out_file=("dec_tree.dot"),
feature_names=None,
class_names=None,
filled=True)
但我明白了
<ipython-input-2-bb91ec6ba0d9> in <module>
37 filled=True)
38
---> 39 DecTreeOptimizer(X = df.drop(['quality'], axis=1), y = df.quality)
40
<ipython-input-2-bb91ec6ba0d9> in DecTreeOptimizer(X, y)
30 print("Best score: " + str(tree_GS.best_score_))
31
---> 32 export_graphviz(
33 tree_GS.best_estimator_,
34 out_file=("dec_tree.dot"),
~\AppData\Local\Programs\Python\Python39\lib\site-packages\sklearn\utils\validation.py in inner_f(*args, **kwargs)
61 extra_args = len(args) - len(all_args)
62 if extra_args <= 0:
---> 63 return f(*args, **kwargs)
64
65 # extra_args > 0
~\AppData\Local\Programs\Python\Python39\lib\site-packages\sklearn\tree\_export.py in export_graphviz(decision_tree, out_file, max_depth, feature_names, class_names, label, filled, leaves_parallel, impurity, node_ids, proportion, rotate, rounded, special_characters, precision)
767 """
768
--> 769 check_is_fitted(decision_tree)
770 own_file = False
771 return_string = False
~\AppData\Local\Programs\Python\Python39\lib\site-packages\sklearn\utils\validation.py in inner_f(*args, **kwargs)
61 extra_args = len(args) - len(all_args)
62 if extra_args <= 0:
---> 63 return f(*args, **kwargs)
64
65 # extra_args > 0
~\AppData\Local\Programs\Python\Python39\lib\site-packages\sklearn\utils\validation.py in check_is_fitted(estimator, attributes, msg, all_or_any)
1096
1097 if not attrs:
-> 1098 raise NotFittedError(msg % {'name': type(estimator).__name__})
1099
1100
NotFittedError: This Pipeline instance is not fitted yet. Call 'fit' with appropriate arguments before using this estimator.```
经过长时间的搜索,终于在这里找到了答案:使用pipeline和GridsearchCV绘制最佳决策树
best_estimator_属性返回管道而不是对象,所以我只能这样查询:best_estimator_[1](然后我发现我的代码还有很多其他问题,但这是第2部分)。
我会把这个留在这里,以防其他人需要它。干杯!