plot_confusion_matrix: Visualize confusion matrices
Utility function for visualizing confusion matrices via matplotlib
from mlxtend.plotting import plot_confusion_matrix
Overview
Confusion Matrix
For more information on confusion matrices, please see mlxtend.evaluate.confusion_matrix.
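The `conf_mat` argument used throughout this page is a square array of counts. As a minimal sketch, assuming integer class labels and the convention that rows index true labels and columns index predicted labels, such an array can be built with NumPy alone. The helper name `count_confusions` is hypothetical and not part of mlxtend:

```python
import numpy as np


def count_confusions(y_true, y_pred, n_classes):
    """Hypothetical helper: count (true, predicted) label pairs.

    Assumes integer labels in range(n_classes); rows are true labels,
    columns are predicted labels.
    """
    mat = np.zeros((n_classes, n_classes), dtype=int)
    # np.add.at accumulates correctly when an index pair repeats
    np.add.at(mat, (np.asarray(y_true), np.asarray(y_pred)), 1)
    return mat


y_true = [0, 0, 0, 0, 0, 1, 1, 1]
y_pred = [0, 0, 0, 0, 1, 0, 1, 1]
conf_mat = count_confusions(y_true, y_pred, n_classes=2)
print(conf_mat)
```

With these labels the result has the same counts as the `binary1` array in Example 1.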
Example 1 - Binary
from mlxtend.plotting import plot_confusion_matrix
import matplotlib.pyplot as plt
import numpy as np
binary1 = np.array([[4, 1],
                    [1, 2]])
fig, ax = plot_confusion_matrix(conf_mat=binary1)
plt.show()
binary2 = np.array([[21, 1],
                    [3, 1]])
fig, ax = plot_confusion_matrix(conf_mat=binary2, figsize=(2, 2))
plt.show()
Example 2 - Binary absolute and relative with colorbar
binary = np.array([[4, 1],
                   [1, 2]])
fig, ax = plot_confusion_matrix(conf_mat=binary,
                                show_absolute=True,
                                show_normed=True,
                                colorbar=True)
plt.show()
Example 3 - Multiclass relative
multiclass = np.array([[2, 1, 0, 0],
                       [1, 2, 0, 0],
                       [0, 0, 1, 0],
                       [0, 0, 0, 1]])
fig, ax = plot_confusion_matrix(conf_mat=multiclass,
                                colorbar=True,
                                show_absolute=False,
                                show_normed=True)
plt.show()
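To make the relative values concrete, here is a minimal sketch of how the normed coefficients could be computed. It assumes the normalization divides each row (one row per true class) by its row total; this is a NumPy illustration, not mlxtend's own code:

```python
import numpy as np

multiclass = np.array([[2, 1, 0, 0],
                       [1, 2, 0, 0],
                       [0, 0, 1, 0],
                       [0, 0, 0, 1]])

# Assumption: each row is divided by its own sum, so every row adds up to 1
row_totals = multiclass.sum(axis=1, keepdims=True)
normed = multiclass / row_totals
print(np.round(normed, 2))
```

Under this assumption, each diagonal entry is the fraction of that class's examples that received the correct label.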
Example 4 - Add Class Names
multiclass = np.array([[2, 1, 0, 0],
                       [1, 2, 0, 0],
                       [0, 0, 1, 0],
                       [0, 0, 0, 1]])
class_names = ['class a', 'class b', 'class c', 'class d']
fig, ax = plot_confusion_matrix(conf_mat=multiclass,
                                colorbar=True,
                                show_absolute=False,
                                show_normed=True,
                                class_names=class_names)
plt.show()
Example 5 - Changing Color Maps and Font Color
Any Matplotlib colormap can be used in place of the default via the cmap argument. A list of colormaps can be found here: https://matplotlib.org/stable/tutorials/colors/colormaps.html
multiclass = np.array([[2, 1, 0, 0],
                       [1, 2, 0, 0],
                       [0, 0, 1, 0],
                       [0, 0, 0, 1]])
fig, ax = plot_confusion_matrix(conf_mat=multiclass,
                                colorbar=True,
                                cmap='summer')
plt.show()
As shown above, the default font color threshold may not work well for certain color maps. By default, the text in every cell whose value is larger than 0.5 times the maximum cell value is drawn in white, and the text in every cell whose value is equal to or smaller than 0.5 times the maximum cell value is drawn in black.
If you want all font colors to be white, set the threshold to a negative number. If you want all font colors to be black, choose a threshold equal to or greater than 1.
fig, ax = plot_confusion_matrix(conf_mat=multiclass,
                                colorbar=True,
                                fontcolor_threshold=1,
                                cmap='summer')
plt.show()
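The threshold rule described above can be sketched in a few lines of NumPy. The helper `font_colors` is hypothetical and only mirrors the rule as stated in this section (white if a cell exceeds the threshold times the maximum cell value, black otherwise); it is not taken from the library source:

```python
import numpy as np


def font_colors(conf_mat, fontcolor_threshold=0.5):
    """Hypothetical helper: pick a font color per cell from the stated rule."""
    cutoff = fontcolor_threshold * conf_mat.max()
    # Strictly larger than the cutoff gives white; equal or smaller gives black
    return np.where(conf_mat > cutoff, "white", "black")


multiclass = np.array([[2, 1, 0, 0],
                       [1, 2, 0, 0],
                       [0, 0, 1, 0],
                       [0, 0, 0, 1]])

default = font_colors(multiclass)                             # threshold 0.5
all_black = font_colors(multiclass, fontcolor_threshold=1)    # nothing exceeds the maximum
all_white = font_colors(multiclass, fontcolor_threshold=-1)   # every count exceeds a negative cutoff
print(default)
```

Because the comparison is strict, a threshold of 1 can never be exceeded, and any negative threshold is exceeded by every non-negative count.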
Example 6 - Normalizing Colormaps to Highlight Off-Diagonals
Suppose we have the following confusion matrix for a high-accuracy classifier:
class_dict = {0: 'airplane',
              1: 'automobile',
              2: 'bird',
              3: 'cat',
              4: 'deer',
              5: 'dog',
              6: 'frog'}
cmat = np.array([[972, 0, 1, 1, 1, 1, 3],
                 [0, 1123, 3, 1, 0, 1, 2],
                 [2, 0, 1025, 0, 0, 0, 1],
                 [0, 0, 0, 1005, 0, 2, 0],
                 [0, 1, 1, 0, 967, 0, 4],
                 [0, 0, 0, 6, 0, 881, 3],
                 [2, 3, 0, 1, 3, 4, 941]])
fig, ax = plot_confusion_matrix(
    conf_mat=cmat,
    class_names=class_dict.values(),
)
It can be hard to notice the cells where the model makes mistakes. With a log-normalized colormap, these off-diagonal mistakes become easier to see at a glance:
import matplotlib
fig, ax = plot_confusion_matrix(
    conf_mat=cmat,
    class_names=class_dict.values(),
    norm_colormap=matplotlib.colors.LogNorm()
)
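To see why a logarithmic scale helps, the sketch below compares where a few counts from `cmat` land on a 0-to-1 color scale under linear and logarithmic scaling. It assumes `vmin=1` (the smallest nonzero count) and `vmax=1123` (the largest count); the two helper functions are hypothetical and written out with NumPy for illustration:

```python
import numpy as np


def linear_position(value, vmin, vmax):
    """Position of value on a linear 0..1 color scale."""
    return (value - vmin) / (vmax - vmin)


def log_position(value, vmin, vmax):
    """Position of value on a logarithmic 0..1 color scale."""
    return (np.log(value) - np.log(vmin)) / (np.log(vmax) - np.log(vmin))


vmin, vmax = 1, 1123  # assumed limits: smallest nonzero and largest count in cmat
for count in (1, 6, 881, 1123):
    print(count,
          round(linear_position(count, vmin, vmax), 3),
          round(log_position(count, vmin, vmax), 3))
```

On the linear scale an off-diagonal count of 6 sits almost at the bottom of the range, while on the log scale it moves roughly a quarter of the way up, which is what makes it visible. Zero counts have no logarithm, so they cannot be placed on a log scale at all.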
API
plot_confusion_matrix(conf_mat, hide_spines=False, hide_ticks=False, figsize=None, cmap=None, colorbar=False, show_absolute=True, show_normed=False, class_names=None, figure=None, axis=None, fontcolor_threshold=0.5)
Plot a confusion matrix via matplotlib.
Parameters
- conf_mat : array-like, shape = [n_classes, n_classes]
  Confusion matrix from evaluate.confusion_matrix.
- hide_spines : bool (default: False)
  Hides axis spines if True.
- hide_ticks : bool (default: False)
  Hides axis ticks if True.
- figsize : tuple (default: (2.5, 2.5))
  Width and height of the figure.
- cmap : matplotlib colormap (default: None)
  Uses matplotlib.pyplot.cm.Blues if None.
- colorbar : bool (default: False)
  Shows a colorbar if True.
- show_absolute : bool (default: True)
  Shows absolute confusion matrix coefficients if True. At least one of show_absolute or show_normed must be True.
- show_normed : bool (default: False)
  Shows normed confusion matrix coefficients if True. The normed confusion matrix coefficients give the proportion of training examples per class that are assigned the correct label. At least one of show_absolute or show_normed must be True.
- class_names : array-like, shape = [n_classes] (default: None)
  List of class names. If not None, ticks will be set to these values.
- figure : None or Matplotlib figure (default: None)
  If None, a new figure is created.
- axis : None or Matplotlib figure axis (default: None)
  If None, a new axis is created.
- fontcolor_threshold : float (default: 0.5)
  Sets a threshold for choosing black and white font colors for the cells. By default, the text in every cell whose value is larger than 0.5 times the maximum cell value is drawn in white, and the text in every cell whose value is equal to or smaller than 0.5 times the maximum cell value is drawn in black.
Returns
- fig, ax : matplotlib.pyplot subplot objects
  Figure and axis elements of the subplot.
Examples
For usage examples, please see https://rasbt.github.io/mlxtend/user_guide/plotting/plot_confusion_matrix/