我有以下代码:
def plot_learning_curve(estimator, X, y, para, axes=None, cv=None,
                        n_jobs=None, train_sizes=np.linspace(.1, 1.0, 5)):
    """Plot mean training and validation error curves with +/- 1 std bands.

    The scores returned by ``learning_curve`` are converted to errors as
    ``1 - score`` and then averaged along axis 1 (one column per CV split,
    per the scikit-learn documentation).

    Parameters
    ----------
    estimator, X, y
        Forwarded unchanged to ``learning_curve``.
    para
        Value interpolated into the plot title.
    axes
        Matplotlib Axes to draw on. When ``None`` a new 5x5 inch figure with
        a single Axes is created.
    cv, n_jobs, train_sizes
        Forwarded unchanged to ``learning_curve``.

    Returns
    -------
    The ``plt`` module, so the caller can call ``show()`` / ``savefig()``.
    """
    # Only create a figure when the caller did not supply an Axes; the
    # previous code overwrote the ``axes`` argument unconditionally.
    if axes is None:
        _, axes = plt.subplots(1, 1, figsize=(5, 5))

    # NOTE(review): the title hard-codes "cv=10" regardless of the ``cv``
    # argument actually passed -- confirm this matches the callers.
    axes.set_title(f'Curvas de error de aprendizaje (cv=10) para: {para}')
    axes.set_xlabel("Nº Ejemplos de entrenamiento")
    # NOTE(review): the label says "1 - F1" but no ``scoring`` is passed to
    # learning_curve below, so the metric is presumably the estimator's
    # default scorer -- verify.
    axes.set_ylabel("Error: 1 - F1")

    # Fit times were requested but never used, so they are no longer asked for.
    train_sizes, train_scores, test_scores = learning_curve(
        estimator, X, y, cv=cv, n_jobs=n_jobs, train_sizes=train_sizes)

    # Convert scores to errors once, then aggregate per training size.
    train_errors = 1.0 - train_scores
    test_errors = 1.0 - test_scores
    train_scores_mean = np.mean(train_errors, axis=1)
    train_scores_std = np.std(train_errors, axis=1)
    test_scores_mean = np.mean(test_errors, axis=1)
    test_scores_std = np.std(test_errors, axis=1)

    # Plot learning curve
    axes.grid()
    # Shaded band = mean +/- one standard deviation. When the std is tiny
    # compared with the y-range, the band collapses onto the line and is
    # effectively invisible.
    axes.fill_between(train_sizes, train_scores_mean - train_scores_std,
                      train_scores_mean + train_scores_std, alpha=0.1,
                      color="r")
    # Debug output of the training error statistics (kept from the original).
    print(train_scores_mean)
    print(train_scores_std)
    axes.fill_between(train_sizes, test_scores_mean - test_scores_std,
                      test_scores_mean + test_scores_std, alpha=0.1,
                      color="g")
    axes.plot(train_sizes, train_scores_mean, 'o-', color="r",
              label="Error de entrenamiento medio")
    axes.plot(train_sizes, test_scores_mean, 'o-', color="g",
              label="Error de validación medio")
    axes.legend(loc="best")
    return plt
我可以绘制以下图像:
然而,正如你所看到的,红线周围并没有像绿线那样出现填充区域。
但数据是:
[0. 0. 0.00037093 0.0053362 0.01481688]
[0. 0. 0.000383 0.000422 0.00081868]
错误在哪里?
您可能需要添加一条命令,告诉 matplotlib 为第二条曲线使用单独的 y 轴刻度(即通过 twinx 创建第二个坐标轴)。
axes.plot(train_sizes, train_scores_mean, 'o-', color="r",
label="Error de entrenamiento medio")
axes2 = axes.twinx()
axes2.plot(train_sizes, test_scores_mean, 'o-', color="g",
label="Error de validación medio")
axes2.legend(loc="best")
return plt
https://matplotlib.org/gallery/api/two_scales.html