Asked by: 小点点

Why is no range shown around the second line?


I have the following code:

import numpy as np
import matplotlib.pyplot as plt
from sklearn.model_selection import learning_curve


def plot_learning_curve(estimator, X, y, para, axes=None, cv=None,
                        n_jobs=None, train_sizes=np.linspace(.1, 1.0, 5)):
    # Create a new figure only if the caller did not pass an Axes
    if axes is None:
        _, axes = plt.subplots(1, 1, figsize=(5, 5))

    axes.set_title(f'Curvas de error de aprendizaje (cv=10) para: {para}')
    axes.set_xlabel("Nº Ejemplos de entrenamiento")
    axes.set_ylabel("Error: 1 - F1")

    # Note: without a `scoring` argument, learning_curve uses the estimator's
    # default score (accuracy for classifiers), not F1
    train_sizes, train_scores, test_scores, fit_times, _ = \
        learning_curve(estimator, X, y, cv=cv, n_jobs=n_jobs,
                       train_sizes=train_sizes,
                       return_times=True)
    # Convert scores to errors (1 - score), then aggregate over the CV folds
    train_scores_mean = np.mean(1 - train_scores, axis=1)
    train_scores_std = np.std(1 - train_scores, axis=1)
    test_scores_mean = np.mean(1 - test_scores, axis=1)
    test_scores_std = np.std(1 - test_scores, axis=1)

    # Plot learning curve
    axes.grid()
    # Shaded band: mean +/- one standard deviation of the training error
    axes.fill_between(train_sizes, train_scores_mean - train_scores_std,
                      train_scores_mean + train_scores_std, alpha=0.1,
                      color="r")

    print(train_scores_mean)
    print(train_scores_std)

    # Shaded band: mean +/- one standard deviation of the validation error
    axes.fill_between(train_sizes, test_scores_mean - test_scores_std,
                      test_scores_mean + test_scores_std, alpha=0.1,
                      color="g")

    axes.plot(train_sizes, train_scores_mean, 'o-', color="r",
              label="Error de entrenamiento medio")
    axes.plot(train_sizes, test_scores_mean, 'o-', color="g",
              label="Error de validación medio")
    axes.legend(loc="best")

    return plt

With it I can draw the following plot:

However, as you can see, the red line does not have a shaded area around it the way the green line does.

But the printed data (train_scores_mean, then train_scores_std) is:

[0.         0.         0.00037093 0.0053362  0.01481688]
[0.         0.         0.000383   0.000422   0.00081868]
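As a quick check, the two printed arrays can be run through the same arithmetic that fill_between uses for the red band. This is a minimal sketch using only the numbers above; it assumes the first printed array is train_scores_mean and the second is train_scores_std, matching the order of the two print calls in the code.

```python
import numpy as np

# Values printed by the code above (assumed order: mean first, then std)
train_scores_mean = np.array([0., 0., 0.00037093, 0.0053362, 0.01481688])
train_scores_std = np.array([0., 0., 0.000383, 0.000422, 0.00081868])

# fill_between shades from (mean - std) to (mean + std),
# so the height of the red band at each point is 2 * std
lower = train_scores_mean - train_scores_std
upper = train_scores_mean + train_scores_std
band_height = upper - lower

print(band_height)
print(band_height.max())
```

The tallest red band is about 0.0016 (twice the largest std, 0.00081868), so how visible it is depends on how large that is relative to the y-axis range of the plot.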

Where is the error?


1 answer

Anonymous user

You may need to add a command that tells the library there should be a second y-axis scale:

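axes.plot(train_sizes, train_scores_mean, 'o-', color="r",
          label="Error de entrenamiento medio")
# Create a second y-axis that shares the x-axis but has its own scale
axes2 = axes.twinx()
axes2.plot(train_sizes, test_scores_mean, 'o-', color="g",
           label="Error de validación medio")
# Merge the legend entries of both axes into a single legend
lines1, labels1 = axes.get_legend_handles_labels()
lines2, labels2 = axes2.get_legend_handles_labels()
axes2.legend(lines1 + lines2, labels1 + labels2, loc="best")

return plt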

https://matplotlib.org/gallery/api/two_scales.html
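To make the twinx suggestion concrete, here is a minimal sketch that draws each mean ± std band on its own y-axis, so each axis autoscales to its own data. The helper name plot_two_scales is hypothetical (not part of matplotlib), and the Agg backend line is only there so the sketch runs without a display; drop it for interactive use.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend; remove for interactive plotting
import matplotlib.pyplot as plt
import numpy as np


def plot_two_scales(ax, x, left_mean, left_std, right_mean, right_std):
    """Hypothetical helper: draw two mean +/- std curves, each on its own y-axis.

    The left series goes on `ax`; the right series goes on a twin Axes
    created with ax.twinx(), which shares the x-axis but not the y-axis.
    """
    x = np.asarray(x)
    left_mean, left_std = np.asarray(left_mean), np.asarray(left_std)
    right_mean, right_std = np.asarray(right_mean), np.asarray(right_std)

    # Left y-axis: red band and line (training error)
    ax.fill_between(x, left_mean - left_std, left_mean + left_std,
                    alpha=0.1, color="r")
    ax.plot(x, left_mean, "o-", color="r",
            label="Error de entrenamiento medio")
    ax.tick_params(axis="y", labelcolor="r")

    # Right y-axis: independent scale for the green band and line
    ax2 = ax.twinx()
    ax2.fill_between(x, right_mean - right_std, right_mean + right_std,
                     alpha=0.1, color="g")
    ax2.plot(x, right_mean, "o-", color="g",
             label="Error de validación medio")
    ax2.tick_params(axis="y", labelcolor="g")

    # One legend listing the labelled lines from both axes
    h1, l1 = ax.get_legend_handles_labels()
    h2, l2 = ax2.get_legend_handles_labels()
    ax2.legend(h1 + h2, l1 + l2, loc="best")
    return ax2
```

Inside plot_learning_curve it could replace the two fill_between and two plot calls, e.g. plot_two_scales(axes, train_sizes, train_scores_mean, train_scores_std, test_scores_mean, test_scores_std). The trade-off is that the two curves are no longer on a common scale, so their heights cannot be compared directly.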