Customizing Colorbars
Plot legends identify discrete labels of discrete points. For continuous labels based on the color of points, lines, or regions, a labeled colorbar can be a great tool. In Matplotlib, a colorbar is a separate axes that can provide a key for the meaning of colors in a plot. We’ll start by setting up the notebook for plotting and importing the functions we will use:
- import matplotlib.pyplot as plt
- import numpy as np
- plt.style.use('classic')
- %matplotlib inline
The simplest colorbar can be created with the plt.colorbar function (Figure 4-49):
- x = np.linspace(0, 10, 1000)
- I = np.sin(x) * np.cos(x[:, np.newaxis])
- plt.imshow(I)
- plt.colorbar();
Figure 4-49. A simple colorbar legend
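The same figure can also be built with the object-oriented interface. The sketch below assumes the non-interactive Agg backend so it runs as a plain script. fig.colorbar(), the colorbar's mappable attribute, and set_label() are standard Matplotlib calls, while the label text is an arbitrary illustration.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend, so no notebook is required
import matplotlib.pyplot as plt
import numpy as np

x = np.linspace(0, 10, 1000)
I = np.sin(x) * np.cos(x[:, np.newaxis])

fig, ax = plt.subplots()
im = ax.imshow(I)                 # the "mappable" that carries the color scale
cb = fig.colorbar(im, ax=ax)      # colorbar tied to that mappable
cb.set_label('value of I')        # label drawn along the colorbar

print(cb.mappable is im)          # the colorbar follows the image's norm and cmap
print(im.get_clim())              # color limits default to the data range of I
```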
We’ll now discuss a few ideas for customizing these colorbars and using them effectively in various situations.
Customizing Colorbars
We can specify the colormap using the cmap argument to the plotting function that is creating the visualization (Figure 4-50):
- plt.imshow(I, cmap='gray');
Figure 4-50. A grayscale colormap
All the available colormaps are in the plt.cm namespace; using IPython’s tab-completion feature will give you a full list of built-in possibilities:
- plt.cm.<TAB>
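Outside IPython, one way to get the same list is pyplot's colormaps() function, which returns the names of the registered colormaps. A minimal sketch, assuming the Agg backend; filtering out the "_r" (reversed) names is only for readability.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend
import matplotlib.pyplot as plt

# Names of every registered colormap
names = plt.colormaps()
print(len(names))

# Built-in maps come with a reversed twin whose name ends in "_r"; hide those
base_names = [n for n in names if not n.endswith('_r')]
print(base_names[:10])
```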
Choosing the colormap
A full treatment of color choice within visualization is beyond the scope of this book, but for entertaining reading on this subject and others, see the article “Ten Simple Rules for Better Figures”. Matplotlib’s online documentation also has an interesting discussion of colormap choice.
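Broadly, you should be aware of three different categories of colormaps:
Sequential colormaps: These consist of one continuous sequence of colors (e.g., binary or viridis).
Divergent colormaps: These usually contain two distinct colors, which show positive and negative deviations from a mean (e.g., RdBu or PuOr).
Qualitative colormaps: These mix colors with no particular sequence (e.g., rainbow or jet).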
The jet colormap, which was the default in Matplotlib prior to version 2.0, is an example of a qualitative colormap. Its status as the default was quite unfortunate, because qualitative maps are often a poor choice for representing quantitative data. Among the problems is the fact that qualitative maps usually do not display any uniform progression in brightness as the scale increases.
We can see this by converting the jet colorbar into black and white (Figure 4-51):
- from matplotlib.colors import LinearSegmentedColormap
- def grayscale_cmap(cmap):
-     """Return a grayscale version of the given colormap"""
-     cmap = plt.get_cmap(cmap)
-     colors = cmap(np.arange(cmap.N))
-     # convert RGBA to perceived grayscale luminance
-     # cf. http://alienryderflex.com/hsp.html
-     RGB_weight = [0.299, 0.587, 0.114]
-     luminance = np.sqrt(np.dot(colors[:, :3] ** 2, RGB_weight))
-     colors[:, :3] = luminance[:, np.newaxis]
-     return LinearSegmentedColormap.from_list(cmap.name + "_gray", colors, cmap.N)
- def view_colormap(cmap):
-     """Plot a colormap with its grayscale equivalent"""
-     cmap = plt.get_cmap(cmap)
-     colors = cmap(np.arange(cmap.N))
-     cmap = grayscale_cmap(cmap)
-     grayscale = cmap(np.arange(cmap.N))
-     fig, ax = plt.subplots(2, figsize=(6, 2),
-                            subplot_kw=dict(xticks=[], yticks=[]))
-     ax[0].imshow([colors], extent=[0, 10, 0, 1])
-     ax[1].imshow([grayscale], extent=[0, 10, 0, 1])
- view_colormap('jet')
Figure 4-51. The jet colormap and its uneven luminance scale
Notice the bright stripes in the grayscale image. Even in full color, this uneven brightness means that the eye will be drawn to certain portions of the color range, which will potentially emphasize unimportant parts of the dataset. It’s better to use a colormap such as viridis (the default as of Matplotlib 2.0), which is specifically constructed to have an even brightness variation across the range. Thus, it not only plays well with our color perception, but also will translate well to grayscale printing (Figure 4-52):
- view_colormap('viridis')
Figure 4-52. The viridis colormap and its even luminance scale
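The grayscale comparison can also be made numeric. This sketch reuses the HSP luminance weights from grayscale_cmap (the luminance() helper is our own, not a Matplotlib function) and reports where each colormap is brightest: jet peaks somewhere in the middle of its range, while viridis keeps getting brighter toward its top end. It assumes the Agg backend.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend
import matplotlib.pyplot as plt
import numpy as np

def luminance(cmap_name):
    """Hypothetical helper: perceived luminance of every entry in a colormap."""
    cmap = plt.get_cmap(cmap_name)
    rgb = cmap(np.arange(cmap.N))[:, :3]           # drop the alpha channel
    weights = np.array([0.299, 0.587, 0.114])      # same weights as grayscale_cmap
    return np.sqrt(rgb ** 2 @ weights)

for name in ['jet', 'viridis']:
    lum = luminance(name)
    peak = int(np.argmax(lum))
    print(f"{name}: start={lum[0]:.2f} end={lum[-1]:.2f} "
          f"brightest entry={peak} of {len(lum)}")
```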
If you favor rainbow schemes, another good option for continuous data is the cubehelix colormap (Figure 4-53):
- view_colormap('cubehelix')
Figure 4-53. The cubehelix colormap and its luminance
For other situations, such as showing positive and negative deviations from some mean, dual-color colorbars such as RdBu (short for Red-Blue) can be useful. However, as you can see in Figure 4-54, it’s important to note that the positive-negative information will be lost upon translation to grayscale!
- view_colormap('RdBu')
Figure 4-54. The RdBu (Red-Blue) colormap and its luminance
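A diverging map like RdBu is most informative when its near-white midpoint lines up with the value you care about, typically zero. One simple way to do that, sketched below with hypothetical, mostly positive data and the Agg backend, is to pass symmetric vmin and vmax values to imshow so that zero normalizes to exactly 0.5, the center of the colormap.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend
import matplotlib.pyplot as plt
import numpy as np

x = np.linspace(0, 10, 100)
data = np.sin(x) * np.cos(x[:, np.newaxis]) + 0.5   # hypothetical, skewed positive

# Symmetric limits around zero keep the sign information centered
limit = np.abs(data).max()
fig, ax = plt.subplots()
im = ax.imshow(data, cmap='RdBu', vmin=-limit, vmax=limit)
fig.colorbar(im, ax=ax)

print(im.get_clim())         # (-limit, limit)
print(float(im.norm(0.0)))   # 0.5: zero maps to the middle of RdBu
```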
We’ll see examples of using some of these colormaps as we continue. For a more principled approach to colors in Python, you can refer to the tools and documentation within the Seaborn library (see “Visualization with Seaborn” on page 311).
Color limits and extensions
Matplotlib allows for a large range of colorbar customization. The colorbar is drawn on its own plt.Axes instance (available as the ax attribute of the colorbar object that plt.colorbar() returns), so all of the axes and tick formatting tricks we’ve learned are applicable. The colorbar has some interesting flexibility; for example, we can narrow the color limits and indicate the out-of-bounds values with a triangular arrow at the top and bottom by setting the extend property. This might come in handy, for example, if you’re displaying an image that is subject to noise (Figure 4-55):
- # make noise in 1% of the image pixels
- speckles = (np.random.random(I.shape) < 0.01)
- I[speckles] = np.random.normal(0, 3, np.count_nonzero(speckles))
- plt.figure(figsize=(10, 3.5))
- plt.subplot(1, 2, 1)
- plt.imshow(I, cmap='RdBu')
- plt.colorbar()
- plt.subplot(1, 2, 2)
- plt.imshow(I, cmap='RdBu')
- plt.colorbar(extend='both')
- plt.clim(-1, 1);
Figure 4-55. Specifying colormap extensions
Notice that in the left panel, the default color limits respond to the noisy pixels, and the range of the noise completely washes out the pattern we are interested in. In the right panel, we manually set the color limits, and add extensions to indicate values that are above or below those limits. The result is a much more useful visualization of our data.
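Because the colorbar is an object with its own axes, the right-hand panel can also be written in the object-oriented style. This sketch assumes the Agg backend and uses a seeded NumPy Generator in place of np.random; passing vmin and vmax to imshow is used here as an equivalent of plt.clim(-1, 1), and the label text is illustrative.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend
import matplotlib.pyplot as plt
from matplotlib.axes import Axes
import numpy as np

rng = np.random.default_rng(0)
x = np.linspace(0, 10, 200)
I = np.sin(x) * np.cos(x[:, np.newaxis])
speckles = rng.random(I.shape) < 0.01              # about 1% noisy pixels
I[speckles] = rng.normal(0, 3, np.count_nonzero(speckles))

fig, ax = plt.subplots()
im = ax.imshow(I, cmap='RdBu', vmin=-1, vmax=1)   # same effect as plt.clim(-1, 1)
cb = fig.colorbar(im, ax=ax, extend='both')      # triangles mark out-of-range values
cb.set_label('signal')
cb.ax.tick_params(labelsize=8)                   # ordinary Axes methods work on cb.ax
```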
Discrete colorbars
Colormaps are by default continuous, but sometimes you’d like to represent discrete values. The easiest way to do this is to use the plt.get_cmap() function (the older plt.cm.get_cmap() spelling has been removed from recent Matplotlib releases) and pass the name of a suitable colormap along with the number of desired bins (Figure 4-56):
- plt.imshow(I, cmap=plt.get_cmap('Blues', 6))
- plt.colorbar()
- plt.clim(-1, 1);
Figure 4-56. A discretized colormap
The discrete version of a colormap can be used just like any other colormap.
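To see what "discrete" means concretely, the sketch below inspects a 6-bin version of Blues, assuming the Agg backend. The bins line is our own restatement of roughly how a Colormap maps normalized floats in [0, 1] to table entries, included for illustration rather than copied from Matplotlib's source.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend
import matplotlib.pyplot as plt
import numpy as np

cmap6 = plt.get_cmap('Blues', 6)   # 'Blues' resampled to 6 colors
print(cmap6.N)                     # 6

# Normalized values fall into equal-width bins: [0, 1/6) -> bin 0, and so on
values = np.array([0.0, 0.1, 0.2, 0.5, 0.99])
bins = np.minimum((values * cmap6.N).astype(int), cmap6.N - 1)
print(bins)                        # [0 0 1 3 5]

# 0.0 and 0.1 share a bin, so they get the same color
print(np.allclose(cmap6(0.0), cmap6(0.1)))   # True
```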
Example: Handwritten Digits
For an example of where this might be useful, let’s look at an interesting visualization of some handwritten digits data. This data is included in Scikit-Learn, and consists of nearly 2,000 8×8 thumbnails showing various handwritten digits. For now, let’s start by loading the digits data and visualizing several of the example images with plt.imshow() (Figure 4-57):
- # load images of the digits 0 through 5 and visualize several of them
- from sklearn.datasets import load_digits
- digits = load_digits(n_class=6)
- fig, ax = plt.subplots(8, 8, figsize=(6, 6))
- for i, axi in enumerate(ax.flat):
-     axi.imshow(digits.images[i], cmap='binary')
-     axi.set(xticks=[], yticks=[])
Figure 4-57. Sample of handwritten digit data
Because each digit is defined by the grayscale values of its 64 pixels, we can consider each digit to be a point lying in 64-dimensional space: each dimension represents the brightness of one pixel. But visualizing relationships in such high-dimensional spaces can be extremely difficult. One way to approach this is to use a dimensionality reduction technique such as manifold learning to reduce the dimensionality of the data while maintaining the relationships of interest. Dimensionality reduction is an example of unsupervised machine learning, and we will discuss it in more detail in “What Is Machine Learning?” on page 332.
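The 64-dimensional view is just a row-by-row flattening of each 8×8 image; in Scikit-Learn, digits.data holds the same pixels as digits.images in this flattened form. The sketch uses hypothetical random images instead of the real dataset so that it runs without Scikit-Learn.

```python
import numpy as np

# Hypothetical stand-in for digits.images: five random 8x8 grayscale images
rng = np.random.default_rng(42)
images = rng.integers(0, 17, size=(5, 8, 8))   # pixel values 0..16, like the digits

# Flatten each image row by row into a 64-dimensional vector
data = images.reshape(len(images), -1)
print(data.shape)                              # (5, 64)

# Pixel (row r, column c) becomes dimension 8 * r + c
r, c = 3, 5
print(data[0, 8 * r + c] == images[0, r, c])   # True
```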
Deferring the discussion of these details, let’s take a look at a two-dimensional manifold learning projection of this digits data (Figure 4-58; see “In-Depth: Manifold Learning” on page 445 for details):
- # project the digits into 2 dimensions using Isomap
- from sklearn.manifold import Isomap
- iso = Isomap(n_components=2)
- projection = iso.fit_transform(digits.data)
- # plot the results
- plt.scatter(projection[:, 0], projection[:, 1], lw=0.1,
-             c=digits.target, cmap=plt.get_cmap('cubehelix', 6))
- plt.colorbar(ticks=range(6), label='digit value')
- plt.clim(-0.5, 5.5)
Figure 4-58. Manifold embedding of handwritten digit pixels
The projection also gives us some interesting insights into the relationships within the dataset: for example, the ranges of 5 and 3 nearly overlap in this projection, indicating that some handwritten fives and threes are difficult to distinguish, and therefore more likely to be confused by an automated classification algorithm. Other values, like 0 and 1, are more distantly separated, and therefore much less likely to be confused. This observation agrees with our intuition, because 5 and 3 look much more similar than do 0 and 1.
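The plt.clim(-0.5, 5.5) call is what lines each integer label up with one color bin: with six bins spread over a range of width 6, label k normalizes to (k + 0.5)/6, the center of bin k. The sketch below checks this with hypothetical random points and labels standing in for the Isomap projection, so it needs neither Scikit-Learn nor a notebook (it assumes the Agg backend).

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend
import matplotlib.pyplot as plt
import numpy as np

# Hypothetical 2D points and integer labels in place of the digits projection
rng = np.random.default_rng(1)
labels = rng.integers(0, 6, size=300)
points = rng.normal(size=(300, 2)) + labels[:, np.newaxis]

fig, ax = plt.subplots()
sc = ax.scatter(points[:, 0], points[:, 1], c=labels, lw=0.1,
                cmap=plt.get_cmap('cubehelix', 6))
sc.set_clim(-0.5, 5.5)     # six bins of width 1, centered on 0..5
fig.colorbar(sc, ax=ax, ticks=range(6), label='digit value')

# Label k lands in the middle of bin k
for k in range(6):
    print(k, round(float(sc.norm(k)), 4), int(float(sc.norm(k)) * 6))
```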
We’ll return to manifold learning and digit classification in Chapter 5.
Multiple Subplots
Sometimes it is helpful to compare different views of data side by side. To this end, Matplotlib has the concept of subplots: groups of smaller axes that can exist together within a single figure. These subplots might be insets, grids of plots, or other more complicated layouts. In this section, we’ll explore four routines for creating subplots in Matplotlib. We’ll start by setting up the notebook for plotting and importing the functions we will use:
- %matplotlib inline
- import matplotlib.pyplot as plt
- plt.style.use('seaborn-v0_8-white')  # named 'seaborn-white' before Matplotlib 3.6
- import numpy as np
The most basic method of creating an axes is to use the plt.axes function. As we’ve seen previously, by default this creates a standard axes object that fills the entire figure. plt.axes also takes an optional argument: a list of four numbers, [left, bottom, width, height], given in the figure coordinate system, which ranges from 0 at the bottom left of the figure to 1 at the top right of the figure.
For example, we might create an inset axes at the top-right corner of another axes by setting the x and y position to 0.65 (that is, starting at 65% of the width and 65% of the height of the figure) and the x and y extents to 0.2 (that is, the size of the axes is 20% of the width and 20% of the height of the figure). Figure 4-59 shows the result of this code:
- ax1 = plt.axes() # standard axes
- ax2 = plt.axes([0.65, 0.65, 0.2, 0.2])
Figure 4-59. Example of an inset axes
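A quick way to confirm the order of the four numbers is to read the position back. This sketch uses a deliberately asymmetric, hypothetical rectangle so that left and bottom cannot be mixed up, and assumes the Agg backend.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend
import matplotlib.pyplot as plt
import numpy as np

fig = plt.figure()
ax1 = plt.axes()                          # standard axes
ax2 = plt.axes([0.6, 0.2, 0.25, 0.3])     # [left, bottom, width, height]

pos = ax2.get_position()                  # Bbox in figure coordinates
print(np.round(pos.bounds, 3))            # left, bottom, width, height
print(round(pos.x1, 3), round(pos.y1, 3)) # right edge 0.85, top edge 0.5
```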
The equivalent of this command within the object-oriented interface is fig.add_axes(). Let’s use this to create two vertically stacked axes (Figure 4-60):
- fig = plt.figure()
- ax1 = fig.add_axes([0.1, 0.5, 0.8, 0.4], xticklabels=[], ylim=(-1.2, 1.2))
- ax2 = fig.add_axes([0.1, 0.1, 0.8, 0.4], ylim=(-1.2, 1.2))
- x = np.linspace(0, 10)
- ax1.plot(np.sin(x))
- ax2.plot(np.cos(x));
Figure 4-60. Vertically stacked axes example
We now have two axes (the top with no tick labels) that are just touching: the bottom of the upper panel (at position 0.5) matches the top of the lower panel (at position 0.1 + 0.4).
plt.subplot: Simple Grids of Subplots
Aligned columns or rows of subplots are a common enough need that Matplotlib has several convenience routines that make them easy to create. The lowest level of these is plt.subplot(), which creates a single subplot within a grid. This command takes three integer arguments: the number of rows, the number of columns, and the index of the plot to be created in this scheme. The index starts at 1 and runs row by row from the upper left to the bottom right (Figure 4-61):
- for i in range(1, 7):
-     plt.subplot(2, 3, i)
-     plt.text(0.5, 0.5, str((2, 3, i)), fontsize=18, ha='center')
Figure 4-61. A plt.subplot() example
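Because plt.subplot() counts from 1, index i corresponds to row (i - 1) // ncols and column (i - 1) % ncols in 0-based terms. The index_to_row_col() helper below is hypothetical, written only to make that arithmetic explicit; the text in each panel lets you compare it against where Matplotlib actually places the axes. It assumes the Agg backend.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend
import matplotlib.pyplot as plt

def index_to_row_col(index, ncols):
    """Hypothetical helper: 1-based plt.subplot index -> 0-based (row, col)."""
    return divmod(index - 1, ncols)

fig = plt.figure()
axes = {}
for i in range(1, 7):
    axes[i] = fig.add_subplot(2, 3, i)
    row, col = index_to_row_col(i, ncols=3)
    axes[i].text(0.5, 0.5, f"{i} -> ({row}, {col})", fontsize=14, ha='center')

print(index_to_row_col(4, ncols=3))   # (1, 0): first plot of the second row
```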
The command plt.subplots_adjust can be used to adjust the spacing between these plots. The following code (the result of which is shown in Figure 4-62) uses the equivalent object-oriented command, fig.add_subplot():
- fig = plt.figure()
- fig.subplots_adjust(hspace=0.4, wspace=0.4)
- for i in range(1, 7):
-     ax = fig.add_subplot(2, 3, i)
-     ax.text(0.5, 0.5, str((2, 3, i)), fontsize=18, ha='center')
Figure 4-62. plt.subplot() with adjusted margins
We’ve used the hspace and wspace arguments of plt.subplots_adjust, which specify the spacing along the height and width of the figure, in units of the subplot size (in this case, the space is 40% of the subplot width and height).
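The spacing units can be checked with a little arithmetic. With three columns of width w and two gaps of 0.4w sharing a horizontal span of 0.8 (left=0.1 and right=0.9, set explicitly here so the numbers do not depend on rcParams), 3w + 2(0.4w) = 0.8, so each axes is 0.8 / 3.8 ≈ 0.211 of the figure width. A minimal sketch, assuming the Agg backend:

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend
import matplotlib.pyplot as plt

fig = plt.figure()
# Explicit margins so the arithmetic does not depend on rcParams defaults
fig.subplots_adjust(left=0.1, right=0.9, wspace=0.4, hspace=0.4)
axes = [fig.add_subplot(2, 3, i) for i in range(1, 7)]

# 3 * w + 2 * (0.4 * w) = 0.8  ->  w = 0.8 / 3.8
w_expected = 0.8 / 3.8
w_actual = axes[0].get_position().width
gap = axes[1].get_position().x0 - axes[0].get_position().x1
print(round(w_expected, 4), round(w_actual, 4), round(gap / w_actual, 2))
```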
plt.subplots: The Whole Grid in One Go
The approach just described can become quite tedious when you’re creating a large grid of subplots, especially if you’d like to hide the x- and y-axis labels on the inner plots. For this purpose, plt.subplots() is the easier tool to use (note the s at the end of subplots). Rather than creating a single subplot, this function creates a full grid of subplots in a single line, returning them in a NumPy array. The arguments are the number of rows and number of columns, along with optional keywords sharex and sharey, which allow you to specify the relationships between different axes.
Here we’ll create a 2×3 grid of subplots, where all axes in the same row share their y-axis scale, and all axes in the same column share their x-axis scale (Figure 4-63):
- fig, ax = plt.subplots(2, 3, sharex='col', sharey='row')
Figure 4-63. Shared x and y axis in plt.subplots()
Note that by specifying sharex and sharey, we’ve automatically removed inner labels on the grid to make the plot cleaner. The resulting grid of axes instances is returned within a NumPy array, allowing for convenient specification of the desired axes using standard array indexing notation (Figure 4-64):
- # axes are in a two-dimensional array, indexed by [row, col]
- for i in range(2):      # row
-     for j in range(3):  # col
-         ax[i, j].text(0.5, 0.5, str((i, j)), fontsize=18, ha='center')
Figure 4-64. Identifying plots in a subplot grid
In comparison to plt.subplot(), plt.subplots() is more consistent with Python’s conventional 0-based indexing.
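One practical consequence of getting a NumPy array back is that its shape depends on the grid. The sketch below, assuming the Agg backend, shows that a single row comes back as a 1D array unless squeeze=False is passed, and that sharex='col' links axes within a column only.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend
import matplotlib.pyplot as plt

fig, ax = plt.subplots(2, 3, sharex='col', sharey='row')
print(ax.shape)      # (2, 3): index as ax[row, col], starting at 0

# A single row comes back as a 1D array, which changes the indexing...
fig, row = plt.subplots(1, 3)
print(row.shape)     # (3,)

# ...unless squeeze=False, which always returns a 2D array
fig, grid = plt.subplots(1, 3, squeeze=False)
print(grid.shape)    # (1, 3)
```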
plt.GridSpec: More Complicated Arrangements
To go beyond a regular grid to subplots that span multiple rows and columns, plt.GridSpec() is the best tool. The plt.GridSpec() object does not create a plot by itself; it is simply a convenient interface that is recognized by the plt.subplot() command. For example, a gridspec for a grid of two rows and three columns with some specified width and height space looks like this:
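- grid = plt.GridSpec(2, 3, wspace=0.4, hspace=0.3)
From this we can specify subplot locations and extents using the familiar Python slicing syntax (Figure 4-65):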
- plt.subplot(grid[0, 0])
- plt.subplot(grid[0, 1:])
- plt.subplot(grid[1, :2])
- plt.subplot(grid[1, 2]);
Figure 4-65. Irregular subplots with plt.GridSpec
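To check how slices map onto the grid, GridSpec.get_grid_positions() returns the figure-coordinate edges of every cell; a spanning subplot should start at the left edge of its first column and end at the right edge of its last. A sketch, assuming the Agg backend:

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend
import matplotlib.pyplot as plt
import numpy as np

fig = plt.figure()
grid = plt.GridSpec(2, 3, wspace=0.4, hspace=0.3)
wide_top = fig.add_subplot(grid[0, 1:])      # row 0, columns 1 and 2
wide_bottom = fig.add_subplot(grid[1, :2])   # row 1, columns 0 and 1

# Figure-coordinate edges of every cell in the 2x3 grid
bottoms, tops, lefts, rights = grid.get_grid_positions(fig)

top_pos = wide_top.get_position()
bottom_pos = wide_bottom.get_position()
print(np.isclose(top_pos.x0, lefts[1]), np.isclose(top_pos.x1, rights[2]))
print(np.isclose(bottom_pos.x0, lefts[0]), np.isclose(bottom_pos.x1, rights[1]))
```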
This type of flexible grid alignment has a wide range of uses. I most often use it when creating multi-axes histogram plots like the one shown here (Figure 4-66):
- # Create some normally distributed data
- mean = [0, 0]
- cov = [[1, 1], [1, 2]]
- x, y = np.random.multivariate_normal(mean, cov, 3000).T
- # Set up the axes with gridspec
- fig = plt.figure(figsize=(6, 6))
- grid = plt.GridSpec(4, 4, hspace=0.2, wspace=0.2)
- main_ax = fig.add_subplot(grid[:-1, 1:])
- y_hist = fig.add_subplot(grid[:-1, 0], xticklabels=[], sharey=main_ax)
- x_hist = fig.add_subplot(grid[-1, 1:], yticklabels=[], sharex=main_ax)
- # scatter points on the main axes
- main_ax.plot(x, y, 'ok', markersize=3, alpha=0.2)
- # histogram on the attached axes
- x_hist.hist(x, 40, histtype='stepfilled',
-             orientation='vertical', color='gray')
- x_hist.invert_yaxis()
- y_hist.hist(y, 40, histtype='stepfilled', orientation='horizontal', color='gray')
- y_hist.invert_xaxis()
Figure 4-66. Visualizing multidimensional distributions with plt.GridSpec
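The sharex and sharey arguments are what keep the marginal histograms aligned with the scatter plot: changing the main axes' limits changes theirs too. This condensed version of the code above (with a seeded NumPy Generator, and without the tick-label and axis-inversion details) demonstrates that, assuming the Agg backend.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend
import matplotlib.pyplot as plt
import numpy as np

rng = np.random.default_rng(0)
x, y = rng.multivariate_normal([0, 0], [[1, 1], [1, 2]], 3000).T

fig = plt.figure(figsize=(6, 6))
grid = plt.GridSpec(4, 4, hspace=0.2, wspace=0.2)
main_ax = fig.add_subplot(grid[:-1, 1:])
y_hist = fig.add_subplot(grid[:-1, 0], sharey=main_ax)   # shares the y-axis
x_hist = fig.add_subplot(grid[-1, 1:], sharex=main_ax)   # shares the x-axis

main_ax.plot(x, y, 'ok', markersize=3, alpha=0.2)
x_hist.hist(x, 40, histtype='stepfilled', color='gray')
y_hist.hist(y, 40, histtype='stepfilled', orientation='horizontal', color='gray')

# Zooming the main axes also zooms the attached histograms
main_ax.set_xlim(-2, 2)
main_ax.set_ylim(-3, 3)
print(x_hist.get_xlim(), y_hist.get_ylim())
```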
This type of distribution plotted alongside its margins is common enough that it has its own plotting API in the Seaborn package; see “Visualization with Seaborn” on page 311 for more details.