In this workshop, we will cover using Matplotlib to create data visualizations.
By now, you have already seen Matplotlib in action in the NumPy and Pandas workshops. This workshop serves as a more structured introduction to Matplotlib.
Specifically, we'll be focusing on matplotlib.pyplot.
Windows: Start Button -> "Anaconda Prompt"
Ubuntu / MacOS: conda should be in your path
Activate the environment
conda activate mldds01
Matplotlib should already be installed. If not, install it:
conda install matplotlib
We will be practicing matplotlib concepts on this dataset.
Note: on Windows you may wish to rename the unzipped folder to something shorter.
import matplotlib
matplotlib?
import matplotlib.pyplot as plt
plt?
We'll use pandas.read_csv to read the data.
import pandas as pd
# Use pandas to read the CSV file into a pandas.DataFrame,
# parsing the dates for the `epi_week` column,
# setting the 0th column as the index
# renaming the columns (one of them `no._of_cases` is problematic with Python)
df = pd.read_csv('/tmp/polyclinic-attendance/average-daily-polyclinic-attendances-for-selected-diseases.csv',
parse_dates=['epi_week'],
names=['epi_week', 'disease', 'cases'], header=0,
index_col=0)
df.head(5)
Uh oh, the date format is still a string. Let's double-check its type.
df.index
Hmm, looks like this date format isn't recognized.
We'll need to supply a custom date parser.
# create the parser
from datetime import datetime

def parse_date(date):
    """Parses a yyyy-WNN date string

    Args:
        date: a date string in the yyyy-WNN format
    Returns:
        A datetime.datetime for the Sunday of that week
    """
    # https://stackoverflow.com/questions/17087314/get-date-from-week-number
    # pd.datetime is not available in recent pandas releases,
    # so use the standard library datetime directly
    return datetime.strptime(date + '-0', '%Y-W%W-%w')

def parse_dates(dates):
    """Parses a list of dates

    Args:
        dates: a list of date strings in the yyyy-WNN format
    Returns:
        A list of datetime.datetime
    """
    return [parse_date(d) for d in dates]
# test the parser
parse_dates(['2012-W01', '2012-W52'])
df = pd.read_csv('/tmp/polyclinic-attendance/average-daily-polyclinic-attendances-for-selected-diseases.csv',
parse_dates=['epi_week'], date_parser=parse_dates,
names=['epi_week', 'disease', 'cases'], header=0,
index_col=0)
df.head(5)
df.index
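Newer pandas releases (2.0 onwards) mark the date_parser argument as deprecated. One alternative is to read the column as plain strings and convert the index afterwards. The sketch below shows the idea on a small inline CSV. The CSV text and the names sample and parse_week are made up for illustration and are not part of the workshop dataset.

```python
import io
from datetime import datetime

import pandas as pd

# Hypothetical stand-in for the real CSV file; the values are made up.
csv_text = """epi_week,disease,no._of_cases
2012-W01,Acute Diarrhoea,10
2012-W01,Chickenpox,3
2012-W02,Acute Diarrhoea,12
"""

def parse_week(week):
    """Converts a yyyy-WNN string to the Sunday of that week."""
    return datetime.strptime(week + '-0', '%Y-W%W-%w')

# Read the dates as plain strings first (no parse_dates / date_parser)
sample = pd.read_csv(io.StringIO(csv_text),
                     names=['epi_week', 'disease', 'cases'], header=0,
                     index_col=0)

# Convert the string index to a DatetimeIndex after loading
sample.index = pd.DatetimeIndex([parse_week(w) for w in sample.index],
                                name='epi_week')
print(sample.index)
```

The result should be the same kind of DatetimeIndex as before; only the point at which the conversion happens changes.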
Let's try to plot these graphs:
Before we begin, here's a generic workflow for creating a plot.
import matplotlib.pyplot as plt
# create subplots lined up as 1 row and 2 columns
# 20 x 10 "figure units"
# ax1 and ax2 are the axes for each of the subplot
fig, (ax1, ax2) = plt.subplots(nrows=1, ncols=2,
figsize=(20, 10))
# get pandas DataFrames
df1 = ...
df2 = ...
# plot the DataFrames
df1.plot(ax=ax1)
ax1.set(title='The left plot',
ylabel='the y-axis',
xlabel='the x-axis')
df2.plot(ax=ax2)
ax2.set(title='The right plot',
ylabel='the y-axis',
xlabel='the x-axis')
The workflow can be adapted to create any number of plots.
For example, to create 1 plot:
fig, (ax) = plt.subplots(figsize=(20, 10)) # default nrows=1, ncols=1
To create 2 rows of plots:
fig, (ax1, ax2) = plt.subplots(nrows=2, ncols=1,
figsize=(20, 10))
# ax1 is the top row
# ax2 is the bottom row
To plot multiple graphs in the same plot:
fig, (ax) = plt.subplots(figsize=(20, 10))
df1.plot(ax=ax)
df2.plot(ax=ax)
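To try the workflow end to end before touching the real data, here is a minimal runnable sketch. The two DataFrames are filled with made-up numbers, and the axis labels are placeholders.

```python
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

# Made-up data standing in for df1 and df2
dates = pd.date_range('2015-01-04', periods=8, freq='W')
df1 = pd.DataFrame({'cases': np.arange(8) * 10}, index=dates)
df2 = pd.DataFrame({'cases': np.arange(8)[::-1] * 5}, index=dates)

# 1 row, 2 columns of subplots
fig, (ax1, ax2) = plt.subplots(nrows=1, ncols=2, figsize=(20, 10))

# left plot, with the grid turned on
df1.plot(ax=ax1)
ax1.set(title='The left plot', ylabel='cases', xlabel='week')
ax1.grid()

# right plot
df2.plot(ax=ax2)
ax2.set(title='The right plot', ylabel='cases', xlabel='week')
```

In a notebook the figure is displayed when the cell finishes. In a plain script you would call plt.show() or fig.savefig(...) at the end.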
To plot this, we need to sum up the cases for each date.
A pandas DataFrame already provides a plot() method that returns a matplotlib AxesSubplot. It's just a raw plot using the column names as the default axis labels.
df.groupby(df.index)['cases'].sum().plot(marker='o')
Use the workflow to create a customized plot:
You can add more customization options, such as ax.grid() to turn on the grid.
See: https://matplotlib.org/api/axes_api.html#appearance
# Your code here
Based on the plot, it looks like something serious happened in early 2015.
Let's find out what type of cases contributed to this spike, by plotting a line per type.
First, we need to know what types of diseases there are.
# find the columns
df.columns
# find unique values for the `disease` column
df.disease.unique()
# we can get the Series for number of cases for one disease
diarrhoea_cases = df.loc[df.disease == 'Acute Diarrhoea', 'cases']
diarrhoea_cases
# List comprehension will give us a list of Series
cases_per_disease = [
df[df.disease == d]['cases']
for d in df.disease.unique()
]
cases_per_disease
Plot each Series in cases_per_disease as a line on the SAME plot.
ax.legend(df.disease.unique())
# Your code here
So the majority of the cases in 2015 are due to "Acute Upper Respiratory Tract infections."
A search of the internet reveals that there was a serious haze around 2015, but the spike was still too large to seem normal.
Just to confirm that we didn't plot things incorrectly, let's inspect the data around the first week of 2015:
# using pandas' datetime helpers
reference_date = pd.to_datetime('2015-01-01')
start_date = reference_date - pd.DateOffset(weeks=1)
end_date = reference_date + pd.DateOffset(weeks=3)
# row_index, col_index
df.loc[start_date:end_date, :]
Hmm, there are duplicate entries for 2015-01-11. There could be some double-counting here.
df.loc[pd.to_datetime('2015-01-11'), :]
At this point, we would inspect the raw CSV, confirm things, and contact the data source owner to figure out whether this is expected.
The insight here is that plotting can reveal hidden issues in the data. Before spending time creating a model, it's a good idea to plot and check that the plots make sense.
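Besides spotting the problem on a plot, the duplicates can also be listed directly. The sketch below does this on a small made-up DataFrame with the same shape as ours (date index, disease and cases columns). The names toy, flat and dupes are only for illustration.

```python
import pandas as pd

# Made-up rows; disease 'A' appears twice on 2015-01-11
toy = pd.DataFrame(
    {'disease': ['A', 'B', 'A', 'B', 'A'],
     'cases': [10, 3, 12, 4, 12]},
    index=pd.to_datetime(['2015-01-04', '2015-01-04',
                          '2015-01-11', '2015-01-11', '2015-01-11']))
toy.index.name = 'epi_week'

# Move the date into a column so it can be compared together with `disease`
flat = toy.reset_index()

# keep=False marks every member of a duplicated (date, disease) pair
dupes = flat[flat.duplicated(subset=['epi_week', 'disease'], keep=False)]
print(dupes)
```

An empty result means each (date, disease) pair appears only once.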
One way to fix the data is to replace each duplicate entry with the average value divided by the number of duplicate entries, so that the entries for that date still add up to the average.
# compute the mean cases per disease on that date
# (selecting the `cases` column gives a Series of scalar means)
date = pd.to_datetime('2015-01-11')
means = df.loc[date].groupby('disease')['cases'].mean()
print(means.values)

# for each disease on that date, replace the cases with the mean
#
# Note: Recall that loc must be used in order to modify the actual DataFrame
# otherwise you'll get a warning about modifying a copy of the DataFrame
for d, mean in means.items():
    # loc[row_selector, column_selector]
    # the division by 2 is so that their combined sum is the average value
    df.loc[(df.index == date) & (df.disease == d), 'cases'] = mean / 2

df.loc[date, :]
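To see why the mean is divided by 2, here is the same fix applied to a tiny made-up DataFrame where every disease has exactly two entries on the date. After the fix, the per-disease total for that date equals the mean of the two original values. The numbers and the name toy are invented for illustration.

```python
import pandas as pd

date = pd.to_datetime('2015-01-11')

# Made-up duplicates: two rows per disease on the same date
toy = pd.DataFrame(
    {'disease': ['A', 'A', 'B', 'B'],
     'cases': [100.0, 140.0, 20.0, 30.0]},
    index=[date] * 4)

# mean per disease: A -> 120.0, B -> 25.0
means = toy.loc[date].groupby('disease')['cases'].mean()

# replace each of the two duplicates with half of the mean
for d, mean in means.items():
    toy.loc[(toy.index == date) & (toy.disease == d), 'cases'] = mean / 2

# summing per disease now gives back the mean
totals = toy.groupby('disease')['cases'].sum()
print(totals)
```

If a date had three duplicates instead of two, the divisor would need to be 3 for the same reasoning to hold.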
Let's replot our graphs with the updated DataFrame:
fig, ax = plt.subplots(figsize=(20, 10))
# plot one line per disease, summing the cases for each date
for d in df.disease.unique():
    cases = df.loc[df.disease == d, 'cases']
    cases.groupby(cases.index).sum().plot(ax=ax)
ax.set(title='Polyclinic Cases for Selected Diseases',
ylabel='Number of cases',
xlabel='Week of the year')
ax.grid()
ax.legend(df.disease.unique())
In our final plot, we'll do a bar chart.
This should be similar to what we did in the pandas workshop.
Create a pandas.pivot_table using index=df.index.weekofyear, disease as columns, and cases as values.
Hint: refer to the final exercise in the pandas worksheet on how to set up the pivot_table, or try pd.pivot_table? for help.
Note: DatetimeIndex.weekofyear is not available in pandas 2.0 and later; df.index.isocalendar().week is the replacement there.
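If you get stuck, the sketch below shows the general shape of a pivot table followed by a bar chart, on a small made-up DataFrame rather than the workshop data. It assumes a pandas version that provides DatetimeIndex.isocalendar(), and it uses aggfunc='sum' as one possible choice. The names toy and table are only for illustration.

```python
import pandas as pd
import matplotlib.pyplot as plt

# Made-up rows with the same layout as the workshop DataFrame
toy = pd.DataFrame(
    {'disease': ['A', 'B', 'A', 'B'],
     'cases': [10, 3, 12, 4]},
    index=pd.to_datetime(['2015-01-11', '2015-01-11',
                          '2015-01-18', '2015-01-18']))

# Store the ISO week number of each row in its own column
toy = toy.assign(week=toy.index.isocalendar().week.astype(int).to_numpy())

# One row per week, one column per disease
table = pd.pivot_table(toy, index='week', columns='disease',
                       values='cases', aggfunc='sum')

# One group of bars per week, one bar per disease
ax = table.plot(kind='bar')
ax.set(title='Cases per week (toy data)', ylabel='Number of cases')
```

Note that the ISO week number can differ from the week number in the original yyyy-WNN strings, so treat the week labels as approximate.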
# Your code here