plot_confusion_matrix: Visualize confusion matrices

Utility function for visualizing confusion matrices via matplotlib

from mlxtend.plotting import plot_confusion_matrix

Overview

Confusion Matrix

For more information on confusion matrices, please see mlxtend.evaluate.confusion_matrix.
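The matrix passed as conf_mat is a table of counts. The sketch below builds one with NumPy only; it is not mlxtend's implementation. The helper confusion_matrix_sketch is hypothetical, and the sketch assumes integer labels from 0 to n_classes - 1 and the common convention that rows are true labels and columns are predicted labels:

```python
import numpy as np


def confusion_matrix_sketch(y_target, y_predicted, n_classes):
    """Hypothetical helper: count (true, predicted) label pairs.

    Rows index the true label and columns index the predicted label.
    Labels are assumed to be integers in the range 0 .. n_classes - 1.
    """
    y_target = np.asarray(y_target)
    y_predicted = np.asarray(y_predicted)
    mat = np.zeros((n_classes, n_classes), dtype=int)
    # np.add.at accumulates correctly even when a (row, col) pair repeats,
    # unlike plain fancy-index assignment with +=
    np.add.at(mat, (y_target, y_predicted), 1)
    return mat


y_true = [0, 0, 0, 0, 0, 1, 1, 1]
y_pred = [0, 0, 0, 0, 1, 0, 1, 1]

conf_mat = confusion_matrix_sketch(y_true, y_pred, n_classes=2)
print(conf_mat)
```

The resulting array has the same values as binary1 in Example 1 and can be passed directly as conf_mat. In practice, mlxtend.evaluate.confusion_matrix does this counting for you.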


Example 1 - Binary

from mlxtend.plotting import plot_confusion_matrix
import matplotlib.pyplot as plt
import numpy as np

binary1 = np.array([[4, 1],
                    [1, 2]])

fig, ax = plot_confusion_matrix(conf_mat=binary1)
plt.show()


binary2 = np.array([[21, 1],
                    [3, 1]])

fig, ax = plot_confusion_matrix(conf_mat=binary2, figsize=(2, 2))
plt.show()


Example 2 - Binary absolute and relative with colorbar

binary = np.array([[4, 1],
                   [1, 2]])

fig, ax = plot_confusion_matrix(conf_mat=binary,
                                show_absolute=True,
                                show_normed=True,
                                colorbar=True)
plt.show()


Example 3 - Multiclass relative

multiclass = np.array([[2, 1, 0, 0],
                       [1, 2, 0, 0],
                       [0, 0, 1, 0],
                       [0, 0, 0, 1]])

fig, ax = plot_confusion_matrix(conf_mat=multiclass,
                                colorbar=True,
                                show_absolute=False,
                                show_normed=True)
plt.show()

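With show_normed=True, the API section below describes the displayed coefficients as per-class proportions. A minimal NumPy sketch of that normalization, assuming each row (true class) is divided by its row sum and that every class has at least one example (an empty row would cause a division by zero):

```python
import numpy as np

multiclass = np.array([[2, 1, 0, 0],
                       [1, 2, 0, 0],
                       [0, 0, 1, 0],
                       [0, 0, 0, 1]])

# Number of examples per true class; keepdims=True keeps a column vector
# so that broadcasting divides each row by its own sum
row_sums = multiclass.sum(axis=1, keepdims=True)
normed = multiclass / row_sums

# Two decimals, similar to how relative values are usually annotated
print(np.round(normed, 2))
```

Each row of normed sums to 1, and the diagonal (2/3, 2/3, 1, 1) is the fraction of each class that is assigned the correct label.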

Example 4 - Add Class Names

multiclass = np.array([[2, 1, 0, 0],
                       [1, 2, 0, 0],
                       [0, 0, 1, 0],
                       [0, 0, 0, 1]])

class_names = ['class a', 'class b', 'class c', 'class d']

fig, ax = plot_confusion_matrix(conf_mat=multiclass,
                                colorbar=True,
                                show_absolute=False,
                                show_normed=True,
                                class_names=class_names)
plt.show()


Example 5 - Changing Color Maps and Font Color

An alternative Matplotlib colormap can be chosen via the cmap argument. A list of available colormaps can be found here: https://matplotlib.org/stable/tutorials/colors/colormaps.html

multiclass = np.array([[2, 1, 0, 0],
                       [1, 2, 0, 0],
                       [0, 0, 1, 0],
                       [0, 0, 0, 1]])

fig, ax = plot_confusion_matrix(conf_mat=multiclass,
                                colorbar=True,
                                cmap='summer')

plt.show()


As the resulting plot shows, the default font color threshold may not produce readable labels for certain colormaps. By default, the label of a cell is drawn in white if its value is larger than 0.5 times the maximum cell value, and in black if it is equal to or smaller than 0.5 times the maximum cell value.

To draw all cell labels in white, set fontcolor_threshold to a negative number. To draw all of them in black, choose a threshold equal to or greater than 1:

fig, ax = plot_confusion_matrix(conf_mat=multiclass,
                                colorbar=True,
                                fontcolor_threshold=1,
                                cmap='summer')

plt.show()

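The font color rule described above can be written out directly. The following sketch applies it to raw counts. font_colors is a hypothetical helper, not part of mlxtend, and the sketch assumes nonnegative counts with a positive maximum:

```python
import numpy as np


def font_colors(conf_mat, fontcolor_threshold=0.5):
    """Hypothetical helper: pick 'white' or 'black' for every cell.

    A cell is white if its value exceeds fontcolor_threshold * max(conf_mat),
    and black otherwise.
    """
    conf_mat = np.asarray(conf_mat)
    cutoff = fontcolor_threshold * conf_mat.max()
    return np.where(conf_mat > cutoff, "white", "black")


binary1 = np.array([[4, 1],
                    [1, 2]])

print(font_colors(binary1))       # default 0.5: cutoff is 2.0, only the 4 is white
print(font_colors(binary1, -1))   # negative threshold: every cell is white
print(font_colors(binary1, 1))    # threshold of 1: no cell exceeds the max, all black
```

Note that a cell equal to the cutoff (the 2 in binary1) stays black, because the comparison is strictly "larger than".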

Example 6 - Normalizing Colormaps to Highlight Off-Diagonals

Suppose we have the following confusion matrix for a high-accuracy classifier:

class_dict = {0: 'airplane',
              1: 'automobile',
              2: 'bird',
              3: 'cat',
              4: 'deer',
              5: 'dog',
              6: 'frog'}

cmat = np.array([[972, 0, 1, 1, 1, 1, 3],
                 [0, 1123, 3, 1, 0, 1, 2],
                 [2, 0, 1025, 0, 0, 0, 1],
                 [0, 0, 0, 1005, 0, 2, 0],
                 [0, 1, 1, 0, 967, 0, 4],
                 [0, 0, 0, 6, 0, 881, 3],
                 [2, 3, 0, 1, 3, 4, 941]])

fig, ax = plot_confusion_matrix(
    conf_mat=cmat,
    class_names=class_dict.values(),
)
plt.show()


It can be hard to spot the cells where the model makes mistakes. With a log-normalized colormap, these off-diagonal mistakes become easier to see at a glance:

from matplotlib.colors import LogNorm

fig, ax = plot_confusion_matrix(
    conf_mat=cmat,
    class_names=class_dict.values(),
    norm_colormap=LogNorm()
)
plt.show()

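To see why the log scale helps, the sketch below reproduces in NumPy the kind of mapping that matplotlib.colors.LogNorm performs, under these assumptions: cells with value 0 are masked (their logarithm is undefined), vmin and vmax are the smallest positive and the largest entries, and the default linear scale runs from 0 to the maximum. It is an illustration, not the code plot_confusion_matrix runs:

```python
import numpy as np

cmat = np.array([[972, 0, 1, 1, 1, 1, 3],
                 [0, 1123, 3, 1, 0, 1, 2],
                 [2, 0, 1025, 0, 0, 0, 1],
                 [0, 0, 0, 1005, 0, 2, 0],
                 [0, 1, 1, 0, 967, 0, 4],
                 [0, 0, 0, 6, 0, 881, 3],
                 [2, 3, 0, 1, 3, 4, 941]])

# Linear scaling: position of each cell on the colormap between 0 and the max
linear = cmat / cmat.max()

# Log scaling: mask non-positive cells, then scale log values into [0, 1]
positive = np.ma.masked_less_equal(cmat, 0)
vmin, vmax = positive.min(), positive.max()
log_scaled = (np.ma.log(positive) - np.log(vmin)) / (np.log(vmax) - np.log(vmin))

# Compare a small misclassification count (row 0, column 6, value 3)
print(round(float(linear[0, 6]), 4))
print(round(float(log_scaled[0, 6]), 3))
```

On the linear scale, a count of 3 sits at about 0.3% of the colormap range, while on the log scale it sits at roughly 16%, which is why the off-diagonal errors stand out.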

API

plot_confusion_matrix(conf_mat, hide_spines=False, hide_ticks=False, figsize=None, cmap=None, colorbar=False, show_absolute=True, show_normed=False, class_names=None, figure=None, axis=None, fontcolor_threshold=0.5)

Plot a confusion matrix via matplotlib.

Parameters

  • conf_mat : array-like, shape = [n_classes, n_classes]

    Confusion matrix, e.g., as returned by mlxtend.evaluate.confusion_matrix.

  • hide_spines : bool (default: False)

    Hides axis spines if True.

  • hide_ticks : bool (default: False)

    Hides axis ticks if True.

  • figsize : tuple (default: (2.5, 2.5))

    Width and height of the figure in inches.

  • cmap : matplotlib colormap (default: None)

    Uses matplotlib.pyplot.cm.Blues if None.

  • colorbar : bool (default: False)

    Shows a colorbar if True.

  • show_absolute : bool (default: True)

    Shows absolute confusion matrix coefficients if True. At least one of show_absolute or show_normed must be True.

  • show_normed : bool (default: False)

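    Shows normed confusion matrix coefficients if True. Each row (true class) is divided by its sum, so each coefficient gives the proportion of examples of that class that are assigned to the label in the corresponding column; the diagonal entries are the proportions of examples per class that are assigned the correct label. At least one of show_absolute or show_normed must be True.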

  • class_names : array-like, shape = [n_classes] (default: None)

    List of class names. If not None, ticks will be set to these values.

  • figure : None or Matplotlib figure (default: None)

    If None, a new figure is created.

  • axis : None or Matplotlib figure axis (default: None)

    If None, a new axis is created.

  • fontcolor_threshold : float (default: 0.5)

    Sets the threshold for choosing black or white font colors for the cell labels. By default, the label of a cell is drawn in white if its value is larger than 0.5 times the maximum cell value, and in black if it is equal to or smaller than 0.5 times the maximum cell value.

Returns

  • fig, ax : matplotlib.pyplot subplot objects

    Figure and axis elements of the subplot.

Examples

For usage examples, please see https://rasbt.github.io/mlxtend/user_guide/plotting/plot_confusion_matrix/