Handson-Ml - Tools - Matplotlib - Ipynb at 265099f9 Ageron - Handson-Ml GitHub
ageron fixed typo in tools_matplotlib.ipynb Latest commit 32571a7 on Mar 4, 2016 History
Tools - matplotlib
This notebook demonstrates how to use the matplotlib library to plot beautiful graphs.
Table of Contents
1 Plotting your first graph
2 Line style and color
3 Saving a figure
4 Subplots
5 Multiple figures
6 Pyplot's state machine: implicit vs explicit
7 Pylab vs Pyplot vs Matplotlib
8 Drawing text
9 Legends
10 Non linear scales
11 Ticks and tickers
12 Polar projection
13 3D projection
14 Scatter plot
15 Lines
16 Histograms
17 Images
18 Animations
19 Saving animations to video files
20 What next?
Matplotlib can output graphs using various backend graphics libraries, such as Tk, wxPython, etc. When
running python using the command line, the graphs are typically shown in a separate window. In a Jupyter
notebook, we can simply output the graphs within the notebook itself by running the %matplotlib inline
magic command.
Yep, it's as simple as calling the plot function with some data, and then calling the show function!
If the plot function is given one array of data, it will use it as the coordinates on the vertical axis, and it will
just use each data point's index in the array as the horizontal coordinate. You can also provide two arrays:
one for the horizontal axis x, and the second for the vertical axis y:
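As a minimal sketch of the two call forms described above (the data values are illustrative assumptions, not taken from the notebook):

```python
import matplotlib.pyplot as plt

# One array: the values go on the vertical axis, their indices on the horizontal axis.
plt.figure()
line_a, = plt.plot([1, 2, 4, 9])

# Two arrays: the first holds the x coordinates, the second the y coordinates.
plt.figure()
line_b, = plt.plot([-3, -2, 5, 0], [1, 6, 4, 3])

# When running as a script, call plt.show() to display the figures.
```

In the first call the horizontal coordinates are generated for you (0, 1, 2, 3); in the second they are exactly the values you passed.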
The axes automatically match the extent of the data. We would like to give the graph a bit more room, so
let's call the axis function to change the extent of each axis [xmin, xmax, ymin, ymax].
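A short sketch of that call, assuming the same kind of small illustrative dataset (the limits chosen here are arbitrary):

```python
import matplotlib.pyplot as plt

plt.figure()
plt.plot([-3, -2, 5, 0], [1, 6, 4, 3])
# Widen the view: [xmin, xmax, ymin, ymax]
plt.axis([-4, 6, 0, 7])
# Called without arguments, axis() returns the current limits.
limits = plt.axis()
```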
Now, let's plot a mathematical function. We use NumPy's linspace function to create an array x
containing 500 floats ranging from -2 to 2, then we create a second array y computed as the square of x
(to learn about NumPy, read the NumPy tutorial (tools_numpy.ipynb)).
In [7]: import numpy as np
x = np.linspace(-2, 2, 500)
y = x**2
plt.plot(x, y)
plt.show()
That's a bit dry, let's add a title, and x and y labels, and draw a grid.
In [8]: plt.plot(x, y)
plt.title("Square function")
plt.xlabel("x")
plt.ylabel("y = x**2")
plt.grid(True)
plt.show()
Line style and color
By default, matplotlib draws a line between consecutive points.
In [9]: plt.plot([0, 100, 100, 0, 0, 100, 50, 0, 100], [0, 0, 100, 100, 0, 100, 130, 100, 0])
plt.axis([-10, 110, -10, 140])
plt.show()
You can pass a 3rd argument to change the line's style and color. For example "g--" means "green
dashed line".
In [10]: plt.plot([0, 100, 100, 0, 0, 100, 50, 0, 100], [0, 0, 100, 100, 0, 100, 130, 100, 0], "g--")
plt.axis([-10, 110, -10, 140])
plt.show()
You can plot multiple lines on one graph very simply: just pass x1, y1, [style1], x2, y2, [style2], ...
For example:
In [11]: plt.plot([0, 100, 100, 0, 0], [0, 0, 100, 100, 0], "r-", [0, 100, 50, 0,
100], [0, 100, 130, 100, 0], "g--")
plt.axis([-10, 110, -10, 140])
plt.show()
In [12]: plt.plot([0, 100, 100, 0, 0], [0, 0, 100, 100, 0], "r-")
plt.plot([0, 100, 50, 0, 100], [0, 100, 130, 100, 0], "g--")
plt.axis([-10, 110, -10, 140])
plt.show()
You can also draw simple points instead of lines. Here's an example with green dashes, red dotted line
and blue triangles. Check out the documentation
(http://matplotlib.org/api/pyplot_api.html#matplotlib.pyplot.plot) for the full list of style & color options.
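The three styles named above can be sketched with format strings as follows (the plotted functions and the range of x are assumptions chosen for illustration):

```python
import numpy as np
import matplotlib.pyplot as plt

x = np.linspace(-1.4, 1.4, 30)
plt.figure()
# "g--" = green dashed line, "r:" = red dotted line, "b^" = blue triangle markers (no line)
lines = plt.plot(x, x, "g--", x, x**2, "r:", x, x**3, "b^")
```

A format string that contains only a marker, such as "b^", draws the points without connecting them.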
The plot function returns a list of Line2D objects (one for each line). You can set extra attributes on these
lines, such as the line width, the dash style or the alpha level. See the full list of attributes in the
documentation (http://matplotlib.org/users/pyplot_tutorial.html#controlling-line-properties).
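For instance, a minimal sketch that keeps the returned Line2D objects and adjusts them afterwards (the data and attribute values are illustrative assumptions):

```python
import numpy as np
import matplotlib.pyplot as plt

x = np.linspace(-1.4, 1.4, 30)
plt.figure()
# plot returns one Line2D per line drawn
line1, line2 = plt.plot(x, x, "g--", x, x**2, "r:")
line1.set_linewidth(3.0)          # thicker line
line1.set_dash_capstyle("round")  # rounded dash ends
line2.set_alpha(0.2)              # mostly transparent
```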
Saving a figure
Saving a figure to disk is as simple as calling savefig
(http://matplotlib.org/api/pyplot_api.html#matplotlib.pyplot.savefig) with the name of the file (or a file
object). The available image formats depend on the graphics backend you use.
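Here is a small sketch of the file-object form. Using io.BytesIO as the target is an assumption made so that the example writes nothing to disk; passing a file name such as "my_square_function.png" works the same way.

```python
import io
import numpy as np
import matplotlib.pyplot as plt

x = np.linspace(-2, 2, 500)
plt.figure()
plt.plot(x, x**2)

# With a file object there is no extension to infer the format from,
# so pass it explicitly.
buffer = io.BytesIO()
plt.savefig(buffer, format="png")
png_bytes = buffer.getvalue()
```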
Subplots
A matplotlib figure may contain multiple subplots. These subplots are organized in a grid. To create a
subplot, just call the subplot function, and specify the number of rows and columns in the figure, and the
index of the subplot you want to draw on (starting from 1, then left to right, and top to bottom). Note that
pyplot keeps track of the currently active subplot (which you can get a reference to by calling plt.gca()),
so when you call the plot function, it draws on the active subplot.
Note that subplot(223) is a shorthand for subplot(2, 2, 3).
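A minimal sketch of a 2x2 grid built this way (the plotted functions are illustrative assumptions):

```python
import numpy as np
import matplotlib.pyplot as plt

x = np.linspace(-1.4, 1.4, 30)
plt.figure()
plt.subplot(2, 2, 1)  # 2 rows, 2 columns, 1st subplot = top left
plt.plot(x, x)
plt.subplot(2, 2, 2)  # top right
plt.plot(x, x**2)
plt.subplot(2, 2, 3)  # bottom left
plt.plot(x, x**3)
ax_last = plt.subplot(224)  # shorthand for subplot(2, 2, 4) = bottom right
plt.plot(x, x**4)
```

Each call to subplot makes that subplot the active one, so the plot call that follows draws on it.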
It is easy to create subplots that span across multiple grid cells like so:
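One way to do this, sketched under the assumption that mixing two grid shapes in one figure is acceptable: the top row uses a 2x2 grid, while the bottom subplot is addressed in a 2x1 grid and therefore spans the full width.

```python
import numpy as np
import matplotlib.pyplot as plt

x = np.linspace(-1.4, 1.4, 30)
plt.figure()
ax_top_left = plt.subplot(2, 2, 1)  # top left cell of a 2x2 grid
plt.plot(x, x)
plt.subplot(2, 2, 2)                # top right cell
plt.plot(x, x**2)
ax_bottom = plt.subplot(2, 1, 2)    # 2 rows, 1 column: 2nd subplot = whole bottom row
plt.plot(x, x**3)
```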
If you need more complex subplot positioning, you can use subplot2grid instead of subplot. You
specify the number of rows and columns in the grid, then your subplot's position in that grid (top-left =
(0,0)), and optionally how many rows and/or columns it spans. For example:
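A sketch of subplot2grid on a 3x3 grid (the layout itself is an illustrative assumption):

```python
import matplotlib.pyplot as plt

plt.figure()
# (grid shape), (row, column) of the top-left cell, then optional spans
ax1 = plt.subplot2grid((3, 3), (0, 0), rowspan=2, colspan=2)  # 2x2 block, top left
ax2 = plt.subplot2grid((3, 3), (0, 2))                        # single cell, top right
ax3 = plt.subplot2grid((3, 3), (1, 2), rowspan=2)             # right column, lower two rows
ax4 = plt.subplot2grid((3, 3), (2, 0), colspan=2)             # bottom row, left two columns
```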
If you need even more flexibility in subplot positioning, check out the GridSpec documentation
(http://matplotlib.org/users/gridspec.html).
Multiple figures
It is also possible to draw multiple figures. Each figure may contain one or more subplots. By default,
matplotlib creates figure(1) automatically. When you switch figure, pyplot keeps track of the currently
active figure (which you can get a reference to by calling plt.gcf()), and the active subplot of that figure
becomes the current subplot.
plt.figure(1)
plt.subplot(211)
plt.plot(x, x**2)
plt.title("Square and Cube")
plt.subplot(212)
plt.plot(x, x**3)
plt.show()
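Pyplot's state machine: implicit vs explicit
So far we have relied on pyplot's state machine, which keeps track of the currently active figure and
subplot. But when you are writing a program, explicit is better than implicit. Explicit code is usually easier
to debug and maintain, and if you don't believe me just read the 2nd rule in the Zen of Python: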
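The Zen of Python ships with the interpreter as the this module. The sketch below prints its second aphorism; it assumes CPython, where the text is stored ROT13-encoded in this.s (an implementation detail, not a documented API).

```python
import codecs
import this  # importing the module prints the full Zen of Python

# Decode the stored text and skip the title line and blank lines.
zen_lines = codecs.decode(this.s, "rot13").strip().splitlines()
aphorisms = [line for line in zen_lines[1:] if line.strip()]
second_rule = aphorisms[1]
print(second_rule)
```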
Fortunately, Pyplot allows you to ignore the state machine entirely, so you can write beautifully explicit
code. Simply call the subplots function and use the figure object and the list of axes objects that are
returned. No more magic! For example:
fig2, ax = plt.subplots(1, 1)
ax.plot(x, x**2)
plt.show()
For consistency, we will continue to use pyplot's state machine in the rest of this tutorial, but we
recommend using the object-oriented interface in your programs.
Pylab vs Pyplot vs Matplotlib
Pyplot provides a number of tools to plot graphs, including the state-machine interface to the underlying
object-oriented plotting library.
Pylab is a convenience module that imports matplotlib.pyplot and NumPy in a single name space. You will
find many examples using pylab, but it is no longer recommended (because explicit imports are better than
implicit ones).
Drawing text
You can call text to add text at any location in the graph. Just specify the horizontal and vertical
coordinates and the text, and optionally some extra attributes. Any text in matplotlib may contain TeX
equation expressions, see the documentation (http://matplotlib.org/users/mathtext.html) for more details.
plt.show()
Note: ha is an alias for horizontalalignment
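A minimal sketch of the text function and the ha attribute. The curve, the coordinates and the "Beautiful point" label are illustrative assumptions.

```python
import numpy as np
import matplotlib.pyplot as plt

x = np.linspace(-1.5, 1.5, 30)
px = 0.8
py = px**2

plt.figure()
plt.plot(x, x**2, "b-", px, py, "ro")
# Coordinates are in data units; TeX goes between dollar signs.
title_text = plt.text(0, 1.5, r"Square function: $y = x^2$",
                      fontsize=14, color="blue", ha="center")
# ha="right" puts the end of the string at the given x coordinate.
point_text = plt.text(px - 0.08, py, "Beautiful point", ha="right", weight="heavy")
```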
It is quite frequent to annotate elements of a graph, such as the beautiful point above. The annotate
function makes this easy: just indicate the location of the point of interest, and the position of the text, plus
optionally some extra attributes for the text and the arrow.
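A sketch of annotate with the point of interest, the text position and a few optional attributes (all values are illustrative assumptions):

```python
import numpy as np
import matplotlib.pyplot as plt

x = np.linspace(-1.5, 1.5, 30)
px = 0.8
py = px**2

plt.figure()
plt.plot(x, x**2, px, py, "ro")
note = plt.annotate(
    "Beautiful point",
    xy=(px, py),                      # the point of interest
    xytext=(px - 1.3, py + 0.5),      # where the text is placed
    color="green", weight="heavy", fontsize=14,
    arrowprops={"facecolor": "lightgreen"},  # arrow from the text to the point
)
```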
You can also add a bounding box around your text by using the bbox attribute:
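For example, a sketch with a rounded box (the box style and colors are illustrative assumptions):

```python
import numpy as np
import matplotlib.pyplot as plt

x = np.linspace(-1.5, 1.5, 30)
px = 0.8
py = px**2

plt.figure()
plt.plot(x, x**2, px, py, "ro")
# ec = edge color, fc = fill color, lw = line width of the box
bbox_props = dict(boxstyle="round,pad=0.5", ec="black", fc="lightyellow", lw=1)
boxed = plt.text(px - 0.2, py, "Beautiful point", bbox=bbox_props, ha="right")
```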
plt.show()
Just for fun, if you want an xkcd (http://xkcd.com)-style plot, just draw within a with plt.xkcd() section:
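A small sketch of that pattern. The plotted curve is an illustrative assumption, and the rcParams lookups are only there to show that the style applies inside the block.

```python
import numpy as np
import matplotlib.pyplot as plt

x = np.linspace(-1.5, 1.5, 30)
with plt.xkcd():
    # Inside the block, rcParams are switched to a hand-drawn look.
    sketch_params = plt.rcParams["path.sketch"]
    plt.figure()
    plt.plot(x, x**2)
    plt.title("Square function")
# On exit, the previous settings are restored.
restored_params = plt.rcParams["path.sketch"]
```

If the xkcd fonts are not installed, matplotlib falls back to another font and emits warnings when the figure is rendered.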
plt.show()
Legends
The simplest way to add a legend is to set a label on all lines, then just call the legend function.
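A minimal sketch (the two curves and their labels are illustrative assumptions):

```python
import numpy as np
import matplotlib.pyplot as plt

x = np.linspace(-1.4, 1.4, 50)
plt.figure()
plt.plot(x, x**2, "r--", label="Square function")
plt.plot(x, x**3, "g-", label="Cube function")
legend = plt.legend(loc="best")  # "best" lets matplotlib pick the least crowded corner
plt.grid(True)
```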
Non linear scales
plt.figure(1)
plt.plot(x, y)
plt.yscale('linear')
plt.title('linear')
plt.grid(True)
plt.figure(2)
plt.plot(x, y)
plt.yscale('log')
plt.title('log')
plt.grid(True)
plt.figure(3)
plt.plot(x, y)
plt.yscale('logit')
plt.title('logit')
plt.grid(True)
plt.figure(4)
plt.plot(x, y - y.mean())
plt.yscale('symlog', linthresh=0.05)  # this argument was named linthreshy before matplotlib 3.3
plt.title('symlog')
plt.grid(True)
plt.show()
Ticks and tickers
The axes have little marks called "ticks". To be precise, "ticks" are the locations of the marks (e.g. (-1, 0,
1)), "tick lines" are the small lines drawn at those locations, "tick labels" are the labels drawn next to the
tick lines, and "tickers" are objects that are capable of deciding where to place ticks. The default tickers
typically do a pretty good job at placing ~5 to 8 ticks at a reasonable distance from one another.
But sometimes you need more control (e.g. there are too many tick labels on the logit graph above).
Fortunately, matplotlib gives you full control over ticks. You can even activate minor ticks.
plt.figure(1, figsize=(15,10))
plt.subplot(131)
plt.plot(x, x**3)
plt.grid(True)
plt.title("Default ticks")
ax = plt.subplot(132)
plt.plot(x, x**3)
ax.xaxis.set_ticks(np.arange(-2, 2, 1))
plt.grid(True)
plt.title("Manual ticks on the x-axis")
ax = plt.subplot(133)
plt.plot(x, x**3)
plt.minorticks_on()
ax.tick_params(axis='x', which='minor', bottom=False)  # recent matplotlib versions expect a boolean, not 'off'
ax.xaxis.set_ticks([-2, 0, 1, 2])
ax.yaxis.set_ticks(np.arange(-5, 5, 1))
ax.yaxis.set_ticklabels(["min", -4, -3, -2, -1, 0, 1, 2, 3, "max"])
plt.title("Manual ticks and tick labels\n(plus minor ticks) on the y-axis")
plt.grid(True)
plt.show()
Polar projection
Drawing a polar graph is as easy as setting the projection attribute to "polar" when creating the
subplot.
In [29]: radius = 1
theta = np.linspace(0, 2*np.pi*radius, 1000)
plt.subplot(111, projection='polar')
plt.plot(theta, np.sin(5*theta), "g-")
plt.plot(theta, 0.5*np.cos(20*theta), "b-")
plt.show()
3D projection
Plotting 3D graphs is quite straightforward. You need to import Axes3D, which registers the "3d"
projection. Then create a subplot setting the projection to "3d". This returns an Axes3DSubplot object,
which you can use to call plot_surface, giving x, y, and z coordinates, plus optional attributes.
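from mpl_toolkits.mplot3d import Axes3D  # registers the "3d" projection
x = np.linspace(-5, 5, 50)
y = np.linspace(-5, 5, 50)
X, Y = np.meshgrid(x, y)
R = np.sqrt(X**2 + Y**2)
Z = np.sin(R)
# The rest of this cell is missing from this copy; a minimal completion (assumption):
subplot3d = plt.subplot(111, projection='3d')
surface = subplot3d.plot_surface(X, Y, Z, rstride=1, cstride=1, cmap='coolwarm', linewidth=0.1)
plt.show()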
Scatter plot
To draw a scatter plot, simply provide the x and y coordinates of the points.
And as usual there are a number of other attributes you can set, such as the fill and edge colors and the
alpha level.
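A sketch of a scatter plot with a few of those attributes set (the random data, seed, sizes and colors are all illustrative assumptions):

```python
import numpy as np
import matplotlib.pyplot as plt

rng = np.random.default_rng(0)  # seeded only to make the sketch reproducible
x, y = rng.random((2, 100))     # 100 points with coordinates in [0, 1)

plt.figure()
# s = marker size, c = fill color, edgecolors = outline color, alpha = opacity
points = plt.scatter(x, y, s=40, c="yellow", edgecolors="blue", alpha=0.5)
```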
plt.grid(True)
plt.show()
Lines
You can draw lines simply using the plot function, as we have done so far. However, it is often convenient
to create a utility function that plots a (seemingly) infinite line across the graph, given a slope and an
intercept. You can also use the hlines and vlines functions that plot horizontal and vertical line
segments. For example:
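from numpy.random import randn
# plot_line is not defined in this copy of the notebook; a minimal reconstruction (assumption):
def plot_line(axis, slope, intercept, **kwargs):
    xmin, xmax = axis.get_xlim()  # draw from one edge of the current view to the other
    plt.plot([xmin, xmax], [xmin*slope + intercept, xmax*slope + intercept], **kwargs)
x = randn(1000)
y = 0.5*x + 5 + randn(1000)*2
plt.axis([-2.5, 2.5, -5, 15])
plt.scatter(x, y, alpha=0.2)
plt.plot(1, 0, "ro")
plt.vlines(1, -5, 0, color="red")
plt.hlines(0, -2.5, 1, color="red")
plot_line(axis=plt.gca(), slope=0.5, intercept=5, color="magenta")
plt.grid(True)
plt.show()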
Histograms
In [36]: data = [1, 1.1, 1.8, 2, 2.1, 3.2, 3, 3, 3, 3]
plt.subplot(211)
plt.hist(data, bins = 10, rwidth=0.8)
plt.subplot(212)
plt.hist(data, bins = [1, 1.5, 2, 2.5, 3], rwidth=0.95)
plt.xlabel("Value")
plt.ylabel("Frequency")
plt.show()
plt.xlabel("Value")
plt.ylabel("Frequency")
plt.legend()
plt.grid(True)
plt.show()
Images
Reading, generating and plotting images in matplotlib is quite straightforward.
To read an image, just import the matplotlib.image module, and call its imread function, passing it the
file name (or file object). This returns the image data, as a NumPy array. Let's try this with the
my_square_function.png image we saved earlier.
import matplotlib.image as mpimg
img = mpimg.imread('my_square_function.png')
print(img.shape, img.dtype)
We have loaded a 288x432 image. Each pixel is represented by a 4-element array: red, green, blue, and
alpha levels, stored as 32-bit floats between 0 and 1. Now all we need to do is to call imshow:
In [39]: plt.imshow(img)
plt.show()
Tadaaa! You may want to hide the axes when you are displaying an image:
In [40]: plt.imshow(img)
plt.axis('off')
plt.show()
[[ 0 1 2 ..., 97 98 99]
[ 100 101 102 ..., 197 198 199]
[ 200 201 202 ..., 297 298 299]
...,
[9700 9701 9702 ..., 9797 9798 9799]
[9800 9801 9802 ..., 9897 9898 9899]
[9900 9901 9902 ..., 9997 9998 9999]]
As we did not provide RGB levels, the imshow function automatically maps values to a color gradient. In
matplotlib versions before 2.0, the default gradient (jet) goes from blue (for low values) to red (for high
values); since 2.0 the default is viridis, which goes from dark purple to yellow. Either way, you can select
another color map. For example:
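A sketch that rebuilds a 100x100 array like the one printed above and displays it with a different color map ("hot" is an illustrative choice):

```python
import numpy as np
import matplotlib.pyplot as plt

# 100x100 array holding 0..9999, row by row
img = np.arange(100 * 100).reshape(100, 100)

plt.figure()
image = plt.imshow(img, cmap="hot")  # black for low values, through red and yellow, to white
```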
You can also generate an RGB image directly:
Since the img array is quite small (20x30), when the imshow function displays it, it grows the image to
the figure's size. With bilinear interpolation (https://en.wikipedia.org/wiki/Bilinear_interpolation), which
was the default in matplotlib 1.x, the added pixels are filled by blending neighbouring pixels, which is why
the edges look blurry. Later releases changed the default, so it is safest to select the interpolation
algorithm explicitly, such as copying the color of the nearest pixel:
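A sketch that generates a small 20x30 RGB image and displays it with nearest-pixel interpolation (the colored squares are an illustrative assumption):

```python
import numpy as np
import matplotlib.pyplot as plt

# height 20, width 30, 3 channels (red, green, blue) as floats in [0, 1]
img = np.zeros((20, 30, 3))
img[:5, :5] = [1, 0, 0]        # red square
img[5:10, 5:10] = [0, 1, 0]    # green square
img[10:15, 10:15] = [0, 0, 1]  # blue square

plt.figure()
# "nearest" copies the closest source pixel, so edges stay sharp
image = plt.imshow(img, interpolation="nearest")
```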
Animations
Although matplotlib is mostly used to generate images, it is also capable of displaying animations,
depending on the backend you use. In a Jupyter notebook, we need to use the nbagg backend to use
interactive matplotlib features, including animations. We also need to import matplotlib.animation.
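As a rough sketch of what an animation looks like in code, assuming a simple sine wave whose phase shifts on every frame (the update_line helper and all values are hypothetical, not from the notebook):

```python
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import animation

x = np.linspace(-1, 1, 100)
fig = plt.figure()
line, = plt.plot(x, np.sin(x * 10))

def update_line(frame, x, line):
    # Called once per frame: shift the phase and update the existing line.
    line.set_ydata(np.sin(x * 10 + frame / 10))
    return (line,)

# 100 frames, 50 ms apart; keep a reference to the animation so it is not garbage collected.
anim = animation.FuncAnimation(fig, update_line, frames=100, fargs=(x, line), interval=50)
```

The animation only plays once the figure is displayed by an interactive backend or saved to a file.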