category_scatter: Create a scatterplot with categories in different colors

A function to quickly produce a scatter plot colored by categories from a pandas DataFrame or NumPy ndarray object.

from mlxtend.plotting import category_scatter

Overview

Example 1 - Category Scatter from Pandas DataFrames

import pandas as pd
from io import StringIO

csvfile = """label,x,y
class1,10.0,8.04
class1,10.5,7.30
class2,8.3,5.5
class2,8.1,5.9
class3,3.5,3.5
class3,3.8,5.1"""

df = pd.read_csv(StringIO(csvfile))
df
    label     x     y
0  class1  10.0  8.04
1  class1  10.5  7.30
2  class2   8.3  5.50
3  class2   8.1  5.90
4  class3   3.5  3.50
5  class3   3.8  5.10

The following code plots the data, with the categories determined by the unique values in the label column passed as label_col. The x and y arguments are simply the names of the DataFrame columns that we want to plot.

import matplotlib.pyplot as plt
from mlxtend.plotting import category_scatter

fig = category_scatter(x='x', y='y', label_col='label', 
                       data=df, legend_loc='upper left')
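
To make the role of label_col more concrete, the following minimal sketch groups the rows of the same DataFrame by the unique values in the label column, which yields one set of (x, y) points per category. It only illustrates the idea with pandas groupby and is not taken from the mlxtend source; the name groups is introduced here purely for illustration.

```python
import pandas as pd
from io import StringIO

csvfile = """label,x,y
class1,10.0,8.04
class1,10.5,7.30
class2,8.3,5.5
class2,8.1,5.9
class3,3.5,3.5
class3,3.8,5.1"""

df = pd.read_csv(StringIO(csvfile))

# Hypothetical helper structure (not part of mlxtend):
# one list of (x, y) points per unique value in the label column.
groups = {
    label: list(zip(sub["x"], sub["y"]))
    for label, sub in df.groupby("label")
}

for label, points in groups.items():
    print(label, len(points), "points")
```

Each key of groups corresponds to one category that would receive its own color and marker in the scatter plot.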


Example 2 - Category Scatter from NumPy Arrays

import numpy as np
from io import BytesIO

csvfile = """1,10.0,8.04
1,10.5,7.30
2,8.3,5.5
2,8.1,5.9
3,3.5,3.5
3,3.8,5.1"""

ary = np.genfromtxt(BytesIO(csvfile.encode()), delimiter=',')
ary
array([[  1.  ,  10.  ,   8.04],
       [  1.  ,  10.5 ,   7.3 ],
       [  2.  ,   8.3 ,   5.5 ],
       [  2.  ,   8.1 ,   5.9 ],
       [  3.  ,   3.5 ,   3.5 ],
       [  3.  ,   3.8 ,   5.1 ]])

Now, suppose that the first column represents the labels, and that the second and third columns represent the x and y values, respectively.

import matplotlib.pyplot as plt
from mlxtend.plotting import category_scatter

fig = category_scatter(x=1, y=2, label_col=0,
                       data=ary, legend_loc='upper left')
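
For ndarray input, the integer arguments act as column indices. The sketch below shows one way the rows could be split by the unique values in the label column, using np.unique and a boolean mask. It is an illustration under that assumption rather than the mlxtend implementation, and the name subsets is hypothetical.

```python
import numpy as np

ary = np.array([[1, 10.0, 8.04],
                [1, 10.5, 7.30],
                [2, 8.3, 5.5],
                [2, 8.1, 5.9],
                [3, 3.5, 3.5],
                [3, 3.8, 5.1]])

# Column indices, matching the call category_scatter(x=1, y=2, label_col=0, ...)
label_col, x, y = 0, 1, 2

# Hypothetical helper structure (not part of mlxtend):
# map each unique label to the (x, y) columns of its rows.
subsets = {}
for label in np.unique(ary[:, label_col]):
    mask = ary[:, label_col] == label       # rows belonging to this label
    subsets[float(label)] = ary[mask][:, [x, y]]

for label, points in subsets.items():
    print(label, points.shape)
```

Note that np.genfromtxt produces a float array, so the labels in this example are the floats 1.0, 2.0, and 3.0 rather than integers.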


API

category_scatter(x, y, label_col, data, markers='sxo^v', colors=('blue', 'green', 'red', 'purple', 'gray', 'cyan'), alpha=0.7, markersize=20.0, legend_loc='best')

Scatter plot that shows categories in different colors and marker styles.

Parameters

  • x : str or int

    DataFrame column name of the x-axis values or integer for the numpy ndarray column index.

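  • y : str or int

    DataFrame column name of the y-axis values or integer for the numpy ndarray column index.

  • label_col : str or int

    DataFrame column name of the category labels or integer for the numpy ndarray column index.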

  • data : Pandas DataFrame object or NumPy ndarray.

  • markers : str

    Markers that are cycled through the label category.

  • colors : tuple

    Colors that are cycled through the label category.

  • alpha : float (default: 0.7)

    Parameter to control the transparency.

  • markersize : float (default: 20.0)

    Parameter to control the marker size.

  • legend_loc : str (default: 'best')

    Location of the plot legend: {best, upper left, upper right, lower left, lower right}. No legend is shown if legend_loc=False.
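
The markers and colors parameters are described as being cycled through the label categories. One plausible reading is that the styles repeat from the beginning once they are exhausted. The sketch below illustrates this with itertools.cycle and the default values from the signature; the seven labels and the name styles are made up for the example, and the internal mlxtend logic may differ.

```python
from itertools import cycle

# Default values from the category_scatter signature
markers = 'sxo^v'
colors = ('blue', 'green', 'red', 'purple', 'gray', 'cyan')

# Hypothetical categories: more labels than markers or colors
labels = ['class%d' % i for i in range(1, 8)]

# Pair each label with the next marker and color, wrapping around as needed.
styles = {
    label: (marker, color)
    for label, marker, color in zip(labels, cycle(markers), cycle(colors))
}

for label, (marker, color) in styles.items():
    print(label, marker, color)
```

Under this assumption, a sixth category reuses the first marker and a seventh category reuses the first color, so passing longer markers and colors sequences helps keep many categories visually distinct.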

Returns

  • fig : matplotlib.pyplot figure object
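
Because the return value is described as a matplotlib figure object, the usual figure methods, such as savefig, should be available on it. The following sketch assumes this and, so that it runs without mlxtend installed, builds a stand-in figure directly with matplotlib instead of calling category_scatter. The plotted points and the png_bytes name are illustrative only.

```python
import io

import matplotlib
matplotlib.use("Agg")  # non-interactive backend, so no display is required
import matplotlib.pyplot as plt

# Stand-in for the figure returned by category_scatter(...);
# this is an assumption for illustration, not the mlxtend code.
fig = plt.figure()
ax = fig.add_subplot(111)
ax.scatter([10.0, 10.5], [8.04, 7.30],
           marker='s', color='blue', alpha=0.7, label='class1')
ax.set_xlabel('x')
ax.set_ylabel('y')
ax.legend(loc='upper left')

# Write the figure to an in-memory PNG; a file path works the same way.
buf = io.BytesIO()
fig.savefig(buf, format='png', dpi=150)
png_bytes = buf.getvalue()
plt.close(fig)
```

In an interactive session, plt.show() would display the figure instead of saving it.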

Examples

For usage examples, please see https://rasbt.github.io/mlxtend/user_guide/plotting/category_scatter/