Plotting Learning Curves

A function to plot learning curves for classifiers. Learning curves are useful for diagnosing whether a model suffers from over-fitting (high variance) or under-fitting (high bias). The function can be imported via

from mlxtend.plotting import plot_learning_curves

Example 1

from mlxtend.plotting import plot_learning_curves
import matplotlib.pyplot as plt
from mlxtend.data import iris_data
from mlxtend.preprocessing import shuffle_arrays_unison
from sklearn.neighbors import KNeighborsClassifier

# Loading some example data
X, y = iris_data()
X, y = shuffle_arrays_unison(arrays=[X, y], random_seed=123)
X_train, X_test = X[:100], X[100:]
y_train, y_test = y[:100], y[100:]

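# Initializing a k-nearest neighbors classifier
clf = KNeighborsClassifier(n_neighbors=5)

# Plotting the learning curves for the training and test sets
plot_learning_curves(X_train, y_train, X_test, y_test, clf)
plt.show()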
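
The idea behind the plot in Example 1 can be sketched in plain Python: refit the classifier on growing portions of the training data and record the error on that portion and on the held-out test set. This is only an illustration under stated assumptions, not mlxtend's implementation. The names NearestCentroid, error_rate and learning_curve_errors are hypothetical, and the choice of training fractions is an assumption.

```python
# Illustrative sketch only: NOT mlxtend's implementation.
# NearestCentroid, error_rate and learning_curve_errors are hypothetical names.


class NearestCentroid:
    """Tiny stand-in classifier exposing .fit and .predict."""

    def fit(self, X, y):
        sums, counts = {}, {}
        for row, label in zip(X, y):
            acc = sums.setdefault(label, [0.0] * len(row))
            for i, value in enumerate(row):
                acc[i] += value
            counts[label] = counts.get(label, 0) + 1
        # Per-class mean of every feature
        self.centroids_ = {
            label: [s / counts[label] for s in acc]
            for label, acc in sums.items()
        }
        return self

    def predict(self, X):
        def squared_distance(a, b):
            return sum((u - v) ** 2 for u, v in zip(a, b))

        # Assign each row to the class with the closest centroid
        return [
            min(self.centroids_,
                key=lambda label: squared_distance(row, self.centroids_[label]))
            for row in X
        ]


def error_rate(y_true, y_pred):
    """Fraction of wrongly predicted labels."""
    return sum(t != p for t, p in zip(y_true, y_pred)) / len(y_true)


def learning_curve_errors(X_train, y_train, X_test, y_test, clf,
                          fractions=(0.2, 0.4, 0.6, 0.8, 1.0)):
    """Refit clf on growing prefixes of the training data.

    The fractions are an assumption made for this sketch.
    """
    training_error, test_error = [], []
    for frac in fractions:
        n = max(1, int(round(frac * len(X_train))))
        clf.fit(X_train[:n], y_train[:n])
        training_error.append(error_rate(y_train[:n], clf.predict(X_train[:n])))
        test_error.append(error_rate(y_test, clf.predict(X_test)))
    return training_error, test_error


# Toy one-feature data with two well-separated classes
X_train = [[0.0], [1.0], [0.2], [0.9], [0.1], [1.1], [0.3], [0.8], [0.0], [1.2]]
y_train = [0, 1, 0, 1, 0, 1, 0, 1, 0, 1]
X_test = [[0.1], [1.0], [0.2], [0.95]]
y_test = [0, 1, 0, 1]

training_error, test_error = learning_curve_errors(
    X_train, y_train, X_test, y_test, NearestCentroid())
print(training_error)
print(test_error)
```

Plotting the two returned lists against the training set size gives a learning curve. Any object with .fit and .predict methods could be passed in place of the toy classifier.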

API

plot_learning_curves(X_train, y_train, X_test, y_test, clf, train_marker='o', test_marker='^', scoring='misclassification error', suppress_plot=False, print_model=True, style='fivethirtyeight', legend_loc='best')

Plots learning curves of a classifier.

Parameters

  • X_train : array-like, shape = [n_samples, n_features]

    Feature matrix of the training dataset.

  • y_train : array-like, shape = [n_samples]

    True class labels of the training dataset.

  • X_test : array-like, shape = [n_samples, n_features]

    Feature matrix of the test dataset.

  • y_test : array-like, shape = [n_samples]

    True class labels of the test dataset.

  • clf : Classifier object. Must have .fit and .predict methods.

  • train_marker : str (default: 'o')

    Marker for the training set line plot.

  • test_marker : str (default: '^')

    Marker for the test set line plot.

  • scoring : str (default: 'misclassification error')

    Scoring metric. Besides 'misclassification error', the following metrics (from scikit-learn) are accepted: {'accuracy', 'average_precision', 'f1_micro', 'f1_macro', 'f1_weighted', 'f1_samples', 'log_loss', 'precision', 'recall', 'roc_auc', 'adjusted_rand_score', 'mean_absolute_error', 'mean_squared_error', 'median_absolute_error', 'r2'}

  • suppress_plot : bool (default: False)

    Suppress matplotlib plots if True. Recommended for testing purposes.

  • print_model : bool (default: True)

    Print model parameters in plot title if True.

  • style : str (default: 'fivethirtyeight')

    Matplotlib style to use for the plot.

  • legend_loc : str (default: 'best')

    Where to place the plot legend: {'best', 'upper left', 'upper right', 'lower left', 'lower right'}
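
For the scoring parameter, the default 'misclassification error' is read here as the fraction of wrongly predicted labels, i.e. one minus the accuracy. That reading is an assumption, since the parameter description does not define the term. A minimal sketch of the relationship, using hypothetical helper names:

```python
# Sketch of the assumed relationship between the two metrics.
# accuracy and misclassification_error are hypothetical helpers,
# not functions provided by mlxtend.


def accuracy(y_true, y_pred):
    """Fraction of correctly predicted labels."""
    correct = sum(1 for t, p in zip(y_true, y_pred) if t == p)
    return correct / len(y_true)


def misclassification_error(y_true, y_pred):
    """Fraction of wrongly predicted labels (assumed to be 1 - accuracy)."""
    return 1.0 - accuracy(y_true, y_pred)


y_true = [0, 1, 1, 0]
y_pred = [0, 1, 0, 0]  # one of four predictions is wrong

print(accuracy(y_true, y_pred))
print(misclassification_error(y_true, y_pred))
```

With an error-type metric, lower curves are better; with 'accuracy', higher curves are better, so the same plot reads upside down depending on the scoring chosen.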

Returns

  • errors : (training_error, test_error), tuple of lists
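
The returned tuple can also be inspected without looking at the plot. The following rough heuristic is a sketch only: summarize_learning_curve is a hypothetical helper, the thresholds are arbitrary illustration values rather than mlxtend defaults, and it assumes an error-type scoring where lower is better.

```python
# Hypothetical helper for reading (training_error, test_error).
# The thresholds are arbitrary illustration values, not mlxtend defaults.


def summarize_learning_curve(training_error, test_error,
                             gap_tol=0.05, high_error=0.2):
    """Rough reading of the final points of the two curves."""
    final_train = training_error[-1]
    final_test = test_error[-1]
    gap = final_test - final_train
    if gap > gap_tol:
        # Test error stays well above training error
        return "high variance (possible over-fitting)"
    if final_train > high_error:
        # Both curves converge, but at a high error level
        return "high bias (possible under-fitting)"
    return "no obvious problem"


# Made-up error lists, standing in for a returned errors tuple
overfit = ([0.00, 0.02, 0.03], [0.30, 0.25, 0.20])
underfit = ([0.35, 0.33, 0.32], [0.40, 0.36, 0.34])

print(summarize_learning_curve(*overfit))
print(summarize_learning_curve(*underfit))
```

A large remaining gap between the curves suggests high variance, while two curves that meet at a high error suggest high bias. Suitable thresholds depend on the problem.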

Examples

For usage examples, please see http://rasbt.github.io/mlxtend/user_guide/plotting/learning_curves/