Matplotlib Reference
However, matplotlib is also a massive library, and getting a plot to look just
right is often achieved through trial and error. Using one-liners to generate
basic plots in matplotlib is fairly simple, but skillfully commanding the
remaining 98% of the library can be daunting.
This article assumes the user knows a tiny bit of NumPy. We’ll mainly use
the numpy.random module to generate “toy” data, drawing samples from
different statistical distributions.
If you don’t already have matplotlib installed, see here for a walkthrough
before proceeding.
The library itself is huge, at something like 70,000 total lines of code.
Matplotlib is home to several different interfaces (ways of constructing
a figure) and capable of interacting with a handful of different backends.
(Backends deal with the process of how charts are actually rendered,
not just structured internally.)
While it is comprehensive, some of matplotlib’s own public
documentation is seriously out-of-date. The library is still evolving, and
many older examples floating around online may take 70% fewer lines
of code in their modern version.
So, before we get to any glitzy examples, it’s useful to grasp the core concepts
of matplotlib’s design.
One relevant feature of MATLAB is its global style. The Python concept of
importing is not heavily used in MATLAB, and most of MATLAB’s functions are
readily available to the user at the top level.
Knowing that matplotlib has its roots in MATLAB helps to explain why pylab
exists. pylab is a module within the matplotlib library that was built to mimic
MATLAB’s global style. It exists only to bring a number of functions and
classes from both NumPy and matplotlib into the namespace, making for an
easy transition for former MATLAB users who were not used to
needing import statements.
Ex-MATLAB converts (who are all fine people, I promise!) liked this
functionality, because with from pylab import *, they could simply
call plot() or array() directly, as they would in MATLAB.
The issue here may be apparent to some Python users: using from pylab
import * in a session or script is generally bad practice. Matplotlib now
directly advises against this in its own tutorials:
“[pylab] still exists for historical reasons, but it is highly advised not to use. It
pollutes namespaces with functions that will shadow Python built-ins and can
lead to hard-to-track bugs. To get IPython integration without imports the use
of the %matplotlib magic is preferred.” [Source]
Internally, there are a ton of potentially conflicting imports being masked
within the short pylab source. In fact, using ipython --pylab (from the
terminal/command line) or %pylab (from IPython/Jupyter tools) simply
calls from pylab import * under the hood.
The bottom line is that matplotlib has abandoned this convenience module and
now explicitly recommends against using pylab, bringing things more in line
with one of Python’s key notions: explicit is better than implicit.
Without the need for pylab, we can usually get away with just one canonical
import:
>>>
>>> import matplotlib.pyplot as plt
While we’re at it, let’s also import NumPy, which we’ll use for generating data
later on, and call np.random.seed() to make examples with (pseudo)random
data reproducible:
>>>
>>> import numpy as np
>>> np.random.seed(444)
A Figure object is the outermost container for a matplotlib graphic, which can
contain multiple Axes objects. One source of confusion is the name:
an Axes actually translates into what we think of as an individual plot or graph
(rather than the plural of “axis,” as we might expect).
You can think of the Figure object as a box-like container holding one or
more Axes (actual plots). Below the Axes in the hierarchy are smaller objects
such as tick marks, individual lines, legends, and text boxes. Almost every
“element” of a chart is its own manipulable Python object, all the way down to
the ticks and labels:
Here’s an illustration of this hierarchy in action. Don’t worry if you’re not
completely familiar with this notation, which we’ll cover later on:
>>>
>>> fig, _ = plt.subplots()
>>> type(fig)
<class 'matplotlib.figure.Figure'>
Above, we created two variables with plt.subplots(). The first is a top-
level Figure object. The second is a “throwaway” variable that we don’t need
just yet, denoted with an underscore. Using attribute notation, it is easy to
traverse down the figure hierarchy and see the first tick of the y axis of the
first Axes object:
>>>
>>> one_tick = fig.axes[0].yaxis.get_major_ticks()[0]
>>> type(one_tick)
<class 'matplotlib.axis.YTick'>
Above, fig (a Figure class instance) has multiple Axes (a list, for which we take
the first element). Each Axes has a yaxis and xaxis, each of which has a
collection of “major ticks,” and we grab the first one.
Almost all functions from pyplot, such as plt.plot(), are implicitly either
referring to an existing current Figure and current Axes, or creating them
anew if none exist. Hidden in the matplotlib docs is this helpful snippet:
“[With pyplot], simple functions are used to add plot elements (lines, images,
text, etc.) to the current axes in the current figure.” [emphasis added]
Hardcore ex-MATLAB users may choose to word this by saying something
like, “plt.plot() is a state-machine interface that implicitly tracks the current
figure!” In English, this means that:
The stateful interface makes its calls with plt.plot() and other top-level
pyplot functions. There is only ever one Figure or Axes that you’re
manipulating at a given time, and you don’t need to explicitly refer to it.
Modifying the underlying objects directly is the object-oriented
approach. We usually do this by calling methods of an Axes object, which
is the object that represents a plot itself.
Tying these together, most of the functions from pyplot also exist as methods
of the matplotlib.axes.Axes class.
This is easier to see by peeking under the hood. plt.plot() can be boiled down
to five or so lines of code:
>>>
>>> # matplotlib/pyplot.py
>>> def plot(*args, **kwargs):
...     """An abridged version of plt.plot()."""
...     ax = plt.gca()
...     return ax.plot(*args, **kwargs)
pyplot is home to a batch of functions that are really just wrappers around
matplotlib’s object-oriented interface. For example, with plt.title(), there
are corresponding setter and getter methods within the OO
approach, ax.set_title() and ax.get_title(). (Use of getters and setters tends
to be more popular in languages such as Java but is a key feature of
matplotlib’s OO approach.)
Similarly, if you take a few moments to look at the source for top-level
functions like plt.grid(), plt.legend(), and plt.ylabel(), you’ll notice that all
of them follow the same structure of delegating to the current Axes
with gca() and then calling some method of the current Axes. (This is the
underlying object-oriented approach!)
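To make the delegation concrete, here is a minimal sketch that draws with the stateful interface and then inspects the result through the object-oriented one. It assumes the non-interactive Agg backend so that it runs without a display; the data, title, and variable names are arbitrary choices for illustration.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless backend, no window needed
import matplotlib.pyplot as plt

plt.close("all")

# Stateful calls: pyplot creates a Figure and an Axes implicitly.
lines = plt.plot([1, 2, 3], [2, 4, 8])
plt.title("Stateful title")
plt.ylabel("value")

# Object-oriented view of the very same objects.
ax = plt.gca()
fig = plt.gcf()

same_axes = lines[0].axes is ax        # the line landed on the current Axes
title_seen_by_ax = ax.get_title()      # set by plt.title(), read via the getter
ylabel_seen_by_ax = ax.get_ylabel()    # set by plt.ylabel()
n_axes = len(fig.axes)
```

Everything that the pyplot calls changed is reachable from ax, which is why the two interfaces can be mixed, although sticking to one of them in a given script tends to be easier to follow.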
The prescribed way to create a Figure with a single Axes under the OO
approach is (not too intuitively) with plt.subplots(). This is really the only
time that the OO approach uses pyplot, to create a Figure and Axes:
>>>
>>> fig, ax = plt.subplots()
Above, we took advantage of iterable unpacking to assign a separate variable
to each of the two results of plt.subplots(). Notice that we didn’t pass
arguments to subplots() here. The default call is subplots(nrows=1, ncols=1).
Consequently, ax is a single AxesSubplot object (recent matplotlib releases
report the class as matplotlib.axes._axes.Axes instead):
>>>
>>> type(ax)
<class 'matplotlib.axes._subplots.AxesSubplot'>
We can call its instance methods to manipulate the plot similarly to how we
call pyplot’s functions. Let’s illustrate with a stacked area graph of three time
series:
>>>
>>> rng = np.arange(50)
>>> rnd = np.random.randint(0, 10, size=(3, rng.size))
>>> yrs = 1950 + rng
After creating three random time series, we defined one Figure (fig)
containing one Axes (a plot, ax).
We call methods of ax directly to create a stacked area chart and to add
a legend, title, and y-axis label. Under the object-oriented approach, it’s
clear that all of these are attributes of ax.
tight_layout() applies to the Figure object as a whole to clean up
whitespace padding.
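The steps described above could look roughly like the following sketch. The series labels, title, and y-axis label are placeholders rather than values from the article, and the Agg backend is assumed so that the code runs without a display.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless backend
import matplotlib.pyplot as plt
import numpy as np

np.random.seed(444)
rng = np.arange(50)
rnd = np.random.randint(0, 10, size=(3, rng.size))
yrs = 1950 + rng

fig, ax = plt.subplots(figsize=(5, 3))

# Each row of (rng + rnd) is one series; stackplot() stacks them vertically.
# The labels and titles below are hypothetical placeholders.
polys = ax.stackplot(yrs, rng + rnd,
                     labels=["Series A", "Series B", "Series C"])
ax.set_title("Three stacked time series")
ax.legend(loc="upper left")
ax.set_ylabel("Total")
ax.set_xlim(yrs[0], yrs[-1])

# tight_layout() belongs to the Figure, not to an individual Axes.
fig.tight_layout()
```

Note that every call except the last one is a method of ax, while the layout adjustment is a method of fig.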
Let’s look at an example with multiple subplots (Axes) within one Figure,
plotting two correlated arrays that are drawn from the discrete uniform
distribution:
>>>
>>> x = np.random.randint(low=1, high=11, size=50)
>>> y = x + np.random.randint(1, 5, size=x.size)
>>> data = np.column_stack((x, y))
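The identity check that follows relies on a Figure holding two Axes named ax1 and ax2. One way to build such a figure is sketched below; the chart types, colors, and titles are illustrative assumptions, and the Agg backend is assumed so that no window is required.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless backend
import matplotlib.pyplot as plt
import numpy as np

np.random.seed(444)
x = np.random.randint(low=1, high=11, size=50)
y = x + np.random.randint(1, 5, size=x.size)
data = np.column_stack((x, y))

# One row, two columns: subplots() returns the Figure and both Axes.
fig, (ax1, ax2) = plt.subplots(nrows=1, ncols=2, figsize=(8, 4))

# Left panel: scatter plot of x against y.
ax1.scatter(x=x, y=y, marker="o", c="r", edgecolor="b")
ax1.set_title("Scatter: $x$ versus $y$")
ax1.set_xlabel("$x$")
ax1.set_ylabel("$y$")

# Right panel: side-by-side histograms of the two columns of data.
ax2.hist(data, bins=np.arange(data.min(), data.max()), label=("x", "y"))
ax2.legend(loc=(0.65, 0.8))
ax2.set_title("Frequencies of $x$ and $y$")
ax2.yaxis.tick_right()

fig.tight_layout()
```

Because nrows=1 and ncols=2, the second return value is a one-dimensional array of two Axes, which is why it can be unpacked directly into (ax1, ax2).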
>>>
>>> (fig.axes[0] is ax1, fig.axes[1] is ax2)
(True, True)
(fig.axes is lowercase, not uppercase. There’s no denying the terminology is a
bit confusing.)
Taking this one step further, we could alternatively create a figure that holds a
2x2 grid of Axes objects:
>>>
>>> fig, ax = plt.subplots(nrows=2, ncols=2, figsize=(7, 7))
Now, what is ax? It’s no longer a single Axes, but a two-dimensional NumPy
array of them:
>>>
>>> type(ax)
numpy.ndarray
>>> ax
array([[<matplotlib.axes._subplots.AxesSubplot object at 0x1106daf98>,
<matplotlib.axes._subplots.AxesSubplot object at 0x113045c88>],
[<matplotlib.axes._subplots.AxesSubplot object at 0x11d573cf8>,
<matplotlib.axes._subplots.AxesSubplot object at 0x1130117f0>]],
dtype=object)
>>> ax.shape
(2, 2)
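This is reaffirmed by the docstring of plt.subplots(), which describes the
second return value as either a single Axes object or an array of Axes
objects. Because plotting methods belong to the individual Axes rather than
to the NumPy array that holds them, a common approach is to flatten the array
and unpack it: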
>>>
>>> fig, ax = plt.subplots(nrows=2, ncols=2, figsize=(7, 7))
>>> ax1, ax2, ax3, ax4 = ax.flatten() # flatten a 2d NumPy array to 1d
We could’ve also done this with ((ax1, ax2), (ax3, ax4)) = ax, but the first
approach tends to be more flexible.
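As a quick check, the sketch below compares the two unpacking styles and shows one reason flattening can be more flexible: the same loop works for any grid shape. The names a, b, c, d and the "Axes N" titles are made up for the example, and the Agg backend is an assumption.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless backend
import matplotlib.pyplot as plt

fig, ax = plt.subplots(nrows=2, ncols=2, figsize=(7, 7))

# Approach 1: flatten the 2x2 array, then unpack in row-major order.
ax1, ax2, ax3, ax4 = ax.flatten()

# Approach 2: nested unpacking that mirrors the grid shape exactly.
((a, b), (c, d)) = ax

# Both approaches hand back the very same Axes objects.
same_objects = (ax1 is a, ax2 is b, ax3 is c, ax4 is d)

# With a flat view, one loop covers every Axes regardless of grid shape.
for i, axes in enumerate(ax.flat, start=1):
    axes.set_title(f"Axes {i}")
```

The nested form has to be rewritten whenever nrows or ncols changes, whereas the flattened form only needs the right number of names (or none at all, when looping).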
>>>
>>> from io import BytesIO
>>> import tarfile
>>> from urllib.request import urlopen
>>>
>>> y = housing[:, -1]
>>> pop, age = housing[:, [4, 7]].T
Next let’s define a “helper function” that places a text box inside of a plot and
acts as an “in-plot title”:
>>>
>>> def add_titlebox(ax, text):
...     ax.text(.55, .8, text,
...             horizontalalignment='center',
...             transform=ax.transAxes,
...             bbox=dict(facecolor='white', alpha=0.6),
...             fontsize=12.5)
...     return ax
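To see what the helper does in isolation, here is a small sketch that applies it to a throwaway histogram. The demo data and the title text are invented for illustration, and the Agg backend is assumed.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless backend
import matplotlib.pyplot as plt


def add_titlebox(ax, text):
    """Place `text` in a semi-transparent box inside `ax`."""
    ax.text(.55, .8, text,
            horizontalalignment='center',
            transform=ax.transAxes,   # x and y are fractions of the Axes
            bbox=dict(facecolor='white', alpha=0.6),
            fontsize=12.5)
    return ax


fig, ax = plt.subplots()
ax.hist([1, 2, 2, 3, 3, 3], bins=3)    # arbitrary demo data
add_titlebox(ax, 'Histogram: demo data')

# The text box is now an ordinary Text object owned by the Axes.
box = ax.texts[0]
```

Passing transform=ax.transAxes is what keeps the box at 55% of the width and 80% of the height of the plot, no matter what the data limits turn out to be.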
We’re ready to do some plotting. Matplotlib’s gridspec module allows for more
subplot customization. pyplot’s subplot2grid() interacts with this module
nicely. Let’s say we want to create a layout like this:
Above, what we actually have is a 3x2 grid. ax1 is twice the height and width
of ax2/ax3, meaning that it takes up two columns and two rows.
The second argument to subplot2grid() is the (row, column) location of the
Axes within the grid:
>>>
>>> gridsize = (3, 2)
>>> fig = plt.figure(figsize=(12, 8))
>>> ax1 = plt.subplot2grid(gridsize, (0, 0), colspan=2, rowspan=2)
>>> ax2 = plt.subplot2grid(gridsize, (2, 0))
>>> ax3 = plt.subplot2grid(gridsize, (2, 1))
Now, we can proceed as normal, modifying each Axes individually:
>>>
>>> ax1.set_title('Home value as a function of home age & area population',
...               fontsize=14)
>>> sctr = ax1.scatter(x=age, y=pop, c=y, cmap='RdYlGn')
>>> plt.colorbar(sctr, ax=ax1, format='$%d')
>>> ax1.set_yscale('log')
>>> ax2.hist(age, bins='auto')
>>> ax3.hist(pop, bins='auto', log=True)
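The layout code above depends on the age, pop, and y arrays from the housing data. For readers who only want to try the grid layout, the sketch below substitutes synthetic arrays; these random stand-ins are an assumption made for the example and say nothing about the real dataset. The Agg backend is assumed as well.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless backend
import matplotlib.pyplot as plt
import numpy as np

# Synthetic stand-ins (assumption): NOT the real housing data.
rs = np.random.RandomState(444)
age = rs.randint(1, 53, size=500).astype(float)
pop = rs.lognormal(mean=7.0, sigma=0.7, size=500)
y = rs.uniform(15_000, 500_000, size=500)

gridsize = (3, 2)  # 3 rows, 2 columns
fig = plt.figure(figsize=(12, 8))
ax1 = plt.subplot2grid(gridsize, (0, 0), colspan=2, rowspan=2)  # top block
ax2 = plt.subplot2grid(gridsize, (2, 0))                        # bottom left
ax3 = plt.subplot2grid(gridsize, (2, 1))                        # bottom right

ax1.set_title('Home value as a function of home age & area population',
              fontsize=14)
sctr = ax1.scatter(x=age, y=pop, c=y, cmap='RdYlGn')
plt.colorbar(sctr, ax=ax1, format='$%d')  # the colorbar gets its own Axes
ax1.set_yscale('log')
ax2.hist(age, bins='auto')
ax3.hist(pop, bins='auto', log=True)
```

After this runs, fig.axes holds four entries rather than three, because the colorbar lives in an Axes of its own.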
>>>
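>>> fig1, ax1 = plt.subplots()
>>> id(fig1)
4525567840
>>> fig2, ax2 = plt.subplots()
>>> plt.gcf() is fig2  # the newest figure becomes the current one
True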
After the above routine, the current figure is fig2, the most recently created
figure. However, both figures are still hanging around in memory, each with a
corresponding ID number (1-indexed, in MATLAB style):
>>>
>>> plt.get_fignums()
[1, 2]
A useful way to get all of the Figures themselves is with a mapping
of plt.figure() to each of these integers:
>>>
>>> def get_all_figures():
...     return [plt.figure(i) for i in plt.get_fignums()]
>>> get_all_figures()
[<matplotlib.figure.Figure at 0x10dbeaf60>,
<matplotlib.figure.Figure at 0x1234cb6d8>]
Be cognizant of this if running a script where you’re creating a group of
figures. You’ll want to explicitly close each of them after use to avoid
a MemoryError. By itself, plt.close() closes the current
figure, plt.close(num) closes the figure number num,
and plt.close('all') closes all the figure windows:
>>>
>>> plt.close('all')
>>> get_all_figures()
[]
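The whole lifecycle can be traced in a few lines. The sketch below is one way to do so, assuming the Agg backend; the variable names that record each intermediate state are invented for the example.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless backend
import matplotlib.pyplot as plt

plt.close("all")                           # start from a clean slate

fig1, ax1 = plt.subplots()
current_after_first = plt.gcf() is fig1

fig2, ax2 = plt.subplots()
current_after_second = plt.gcf() is fig2   # newest figure becomes current

open_before = plt.get_fignums()            # pyplot still tracks both figures

plt.close(fig1)                            # close one figure by reference
open_after_one = plt.get_fignums()

plt.close("all")                           # release everything pyplot tracks
open_after_all = plt.get_fignums()
```

In a loop that produces many figures, calling plt.close(fig) at the end of each iteration keeps pyplot's internal registry from growing.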
Methods that get heavy use are imshow() and matshow(), with the latter being a
wrapper around the former. These are useful anytime that a raw numerical
array can be visualized as a colored grid.
First, let’s create two distinct grids with some fancy NumPy indexing:
>>>
>>> x = np.diag(np.arange(2, 12))[::-1]
>>> x[np.diag_indices_from(x[::-1])] = np.arange(2, 12)
>>> x2 = np.arange(x.size).reshape(x.shape)
Next, we can map these to their image representations. In this specific case,
we toggle “off” all axis labels and ticks by using a dictionary comprehension
and passing the result to ax.tick_params():
>>>
>>> sides = ('left', 'right', 'top', 'bottom')
>>> nolabels = {s: False for s in sides}
>>> nolabels.update({'label%s' % s: False for s in sides})
>>> print(nolabels)
{'left': False, 'right': False, 'top': False, 'bottom': False,
 'labelleft': False, 'labelright': False, 'labeltop': False,
 'labelbottom': False}
Then, we can use a context manager to disable the grid, and call matshow() on
each Axes. Lastly, we need to put the colorbar in what is technically a new
Axes within fig. For this, we can use a bit of an esoteric function from deep
within matplotlib:
>>>
>>> from mpl_toolkits.axes_grid1.axes_divider import make_axes_locatable
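Putting these pieces together might look like the sketch below. The colormap, figure size, divider settings, and title are illustrative assumptions, and the Agg backend is assumed so that the code runs without a display.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless backend
import matplotlib.pyplot as plt
import numpy as np
from mpl_toolkits.axes_grid1.axes_divider import make_axes_locatable

x = np.diag(np.arange(2, 12))[::-1]
x[np.diag_indices_from(x[::-1])] = np.arange(2, 12)
x2 = np.arange(x.size).reshape(x.shape)

sides = ('left', 'right', 'top', 'bottom')
nolabels = {s: False for s in sides}
nolabels.update({'label%s' % s: False for s in sides})

# rc_context() applies the setting only inside the with block.
with plt.rc_context(rc={'axes.grid': False}):
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(8, 4))
    ax1.matshow(x)
    img2 = ax2.matshow(x2, cmap='RdYlGn_r')
    for ax in (ax1, ax2):
        ax.tick_params(axis='both', which='both', **nolabels)

    # Carve a narrow Axes out of ax2's right edge to hold the colorbar.
    divider = make_axes_locatable(ax2)
    cax = divider.append_axes('right', size='5%', pad=0)
    cbar = fig.colorbar(img2, cax=cax)
    fig.suptitle('Heatmaps with Axes.matshow', fontsize=16)
```

The divider shrinks ax2 slightly and places cax flush against it, so the colorbar matches the height of the image instead of the height of the whole figure.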
That is, the plot() method on pandas’ Series and DataFrame is a wrapper
around plt.plot(). One convenience provided, for example, is that if the
DataFrame’s Index consists of dates, gcf().autofmt_xdate() is called internally
by pandas to get the current Figure and nicely auto-format the x-axis.
We can prove this “chain” of function calls with a bit of introspection. First,
let’s construct a plain-vanilla pandas Series, assuming we’re starting out in a
fresh interpreter session:
>>>
>>> import pandas as pd
>>> s = pd.Series(range(5), index=list('abcde'))
>>> ax = s.plot()
>>> type(ax)
<class 'matplotlib.axes._subplots.AxesSubplot'>
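A short sketch of the same introspection, written as a script: it checks that the object returned by the pandas plot() method is the current matplotlib Axes. The Series contents are arbitrary, and the Agg backend is an assumption.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless backend
import matplotlib.pyplot as plt
import pandas as pd

plt.close("all")

s = pd.Series(range(5), index=list("abcde"))
ax = s.plot()    # pandas creates the Figure and Axes through pyplot

is_current_axes = plt.gca() is ax                   # same object pyplot tracks
is_mpl_axes = isinstance(ax, matplotlib.axes.Axes)  # an ordinary Axes
n_lines = len(ax.lines)                             # the Series became one line
```

Since ax is an ordinary Axes, every method covered earlier (set_title(), legend(), tick_params(), and so on) can be applied to a plot that pandas created.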
>>>
>>> import pandas as pd
>>> import matplotlib.transforms as mtransforms
>>> url = 'https://fred.stlouisfed.org/graph/fredgraph.csv?id=VIXCLS'
>>> # squeeze('columns') turns the one-column DataFrame into a Series
>>> vix = pd.read_csv(url, index_col=0, parse_dates=True,
...                   na_values='.').squeeze('columns').dropna()
>>> ma = vix.rolling('90D').mean()
>>> state = pd.cut(ma, bins=[-np.inf, 14, 18, 24, np.inf],
...                labels=range(4))
Pandas also comes with a smattering of more advanced plots (which
could take up an entire tutorial all on their own). However, all of these, like
their simpler counterparts, rely on matplotlib machinery internally.
Wrapping Up
As shown by some of the examples above, there’s no getting around the fact
that matplotlib can be a technical, syntax-heavy library. Creating a production-
ready chart sometimes requires a half hour of Googling and combining a
hodgepodge of lines in order to fine-tune a plot.
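Matplotlib offers two ways to configure style in a uniform way across different
plots:
1. By customizing a matplotlibrc file
2. By changing your configuration parameters interactively, or from a script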
Quick Tip: GitHub is a great place to keep configuration files. I keep mine here.
Just make sure that they don’t contain personally identifiable or private
information, such as passwords or SSH private keys!
Alternatively, you can change your configuration parameters interactively
(Option #2 above). When you import matplotlib.pyplot as plt, you get access
to an rcParams object that resembles a Python dictionary of settings. All of the
module objects starting with “rc” are a means to interact with your plot styles
and settings:
>>>
>>> [attr for attr in dir(plt) if attr.startswith('rc')]
['rc', 'rcParams', 'rcParamsDefault', 'rc_context', 'rcdefaults']
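Of these:
plt.rcdefaults() restores the rc parameters from matplotlib’s internal
defaults, which are listed in plt.rcParamsDefault.
plt.rc() sets parameters interactively.
plt.rcParams is a mutable, dictionary-like object that lets you read and
change settings directly.
plt.rc_context() applies settings temporarily, for the duration of a with
block.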
With plt.rc() and plt.rcParams, these two syntaxes are equivalent for
adjusting settings:
>>>
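>>> plt.rc('lines', linewidth=2, color='r') # Syntax 1
>>> plt.rcParams['lines.linewidth'] = 2 # Syntax 2
>>> plt.rcParams['lines.color'] = 'r'
A style is a predefined set of such settings. To view the styles that are
available, use:
>>>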
>>> plt.style.available
['seaborn-dark', 'seaborn-darkgrid', 'seaborn-ticks', 'fivethirtyeight',
'seaborn-whitegrid', 'classic', '_classic_test', 'fast', 'seaborn-talk',
'seaborn-dark-palette', 'seaborn-bright', 'seaborn-pastel', 'grayscale',
'seaborn-notebook', 'ggplot', 'seaborn-colorblind', 'seaborn-muted',
'seaborn', 'Solarize_Light2', 'seaborn-paper', 'bmh', 'seaborn-white',
'dark_background', 'seaborn-poster', 'seaborn-deep']
To set a style, make this call:
>>>
>>> plt.style.use('fivethirtyeight')
Your plots will now take on a new look:
For inspiration, matplotlib keeps some style sheet displays for reference as
well.
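Settings and styles do not have to be changed globally. As a hedged sketch, the code below scopes an rc override and a style sheet to with blocks, so that the previous configuration comes back afterwards. The choice of lines.linewidth and of the ggplot style is arbitrary, and the Agg backend is assumed.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless backend
import matplotlib.pyplot as plt

default_width = plt.rcParams['lines.linewidth']

# Temporary override: the setting reverts when the with block exits.
with plt.rc_context({'lines.linewidth': default_width + 3}):
    fig, ax = plt.subplots()
    (thick_line,) = ax.plot([0, 1], [0, 1])
    width_inside = thick_line.get_linewidth()

width_after = plt.rcParams['lines.linewidth']

# A style sheet can be scoped the same way with plt.style.context().
with plt.style.context('ggplot'):
    grid_inside = plt.rcParams['axes.grid']  # ggplot switches the grid on
```

This pattern is handy when a single figure in a script needs a different look, since nothing has to be reset by hand.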
While interactive mode is off by default, you can check its status
with plt.rcParams['interactive'] or plt.isinteractive(), and toggle it on and
off with plt.ion() and plt.ioff(), respectively:
>>>
>>> plt.rcParams['interactive'] # or: plt.isinteractive()
True
>>>
>>> plt.ioff()
>>> plt.rcParams['interactive']
False
In some code examples, you may notice the presence of plt.show() at the end
of a chunk of code. The main purpose of plt.show(), as the name implies, is to
actually “show” (open) the figure when you’re running with interactive mode
turned off. In other words:
If interactive mode is on, you don’t need plt.show(), and images will
automatically pop-up and be updated as you reference them.
If interactive mode is off, you’ll need plt.show() to display a figure
and plt.draw() to update a plot.
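A script can check which of the two cases it is in before deciding whether to call plt.show(). The sketch below toggles the mode and records the state each time; the variable names are invented for the example, and the Agg backend is assumed so that nothing is actually displayed.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless backend
import matplotlib.pyplot as plt

plt.ioff()
off_state = plt.isinteractive()   # interactive mode is now off

plt.ion()
on_state = plt.isinteractive()    # interactive mode is now on

plt.ioff()                        # leave it off, as scripts usually expect

# isinteractive() and the rcParams entry report the same setting.
agrees_with_rc = plt.isinteractive() == plt.rcParams['interactive']
```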
Below, we make sure that interactive mode is off, which requires that we
call plt.show() after building the plot itself:
>>>
>>> plt.ioff()
>>> x = np.arange(-4, 5)
>>> y1 = x ** 2
>>> y2 = 10 / (x ** 2 + 1)
>>> fig, ax = plt.subplots()
>>> ax.plot(x, y1, 'rx', x, y2, 'b+', linestyle='solid')
>>> ax.fill_between(x, y1, y2, where=y2>y1, interpolate=True,
... color='green', alpha=0.3)
>>> lgnd = ax.legend(['y1', 'y2'], loc='upper center', shadow=True)
>>> lgnd.get_frame().set_facecolor('#ffb19a')
>>> plt.show()
Notably, interactive mode has nothing to do with what IDE you’re using, or
whether you’ve enabled inline plotting with something like
jupyter notebook --matplotlib inline or %matplotlib.