matplotlib

(image: matplotlib.org)

Topics

  • Basic plotting routines
  • Visualizing continuous data
  • Visualizing categorical data
  • Basic plot customizations

Workshop: Matplotlib and Data Visualization

In this workshop, we will cover using Matplotlib to create data visualizations.

By now, you have already seen Matplotlib in action in the NumPy and Pandas workshops. This workshop serves as a more structured introduction to Matplotlib.

Specifically, we'll be focusing on matplotlib.pyplot.

Installation

Windows: Start Button -> "Anaconda Prompt"

Ubuntu / macOS: conda should already be on your PATH

Activate the environment

conda activate mldds01

Matplotlib should already be installed. If not, install it:

conda install matplotlib

Average Daily Polyclinic Attendances for Selected Diseases

We will be practicing matplotlib concepts on this dataset.

Download Instructions

  1. Go to https://data.gov.sg/dataset/average-daily-polyclinic-attendances-selected-diseases
  2. Click on the Download button
  3. Unzip and extract the .csv file. Note the path for use below

Note: on Windows you may wish to rename the unzipped folder to something shorter.

In [ ]:
import matplotlib
matplotlib?
In [ ]:
import matplotlib.pyplot as plt
plt?

Read the data

We'll use pandas.read_csv to read the data

In [ ]:
import pandas as pd

# Use pandas to read the CSV file into a pandas.DataFrame,
#   parsing the dates for the `epi_week` column,
#   setting the 0th column as the index
#   renaming the columns (one of them `no._of_cases` is problematic with Python)

df = pd.read_csv('/tmp/polyclinic-attendance/average-daily-polyclinic-attendances-for-selected-diseases.csv',
                 parse_dates=['epi_week'],
                 names=['epi_week', 'disease', 'cases'], header=0,
                 index_col=0)
df.head(5)

Uh oh, the dates are still strings. Let's double-check the type of the index.

In [ ]:
df.index

Hmm, looks like this date format isn't recognized.

We'll need to supply a custom date parser.

In [ ]:
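# create the parser
from datetime import datetime

def parse_date(date):
    """Parses a yyyy-WNN date string
    Args:
        date: a date string in the yyyy-WNN format
    Returns:
        A datetime.datetime for the Sunday that ends week NN
        (%W counts weeks starting on Monday)
    """
    # https://stackoverflow.com/questions/17087314/get-date-from-week-number
    # pd.datetime was removed in pandas 2.0, so use the standard library
    return datetime.strptime(date + '-0', '%Y-W%W-%w')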

def parse_dates(dates):
    """Parses a list of dates
    Args:
        dates: a list of dates
    Returns:
        A list of datetime.datetime
    """
    return [parse_date(d) for d in dates]

# test the parser
parse_dates(['2012-W01', '2012-W52'])

Re-read the CSV with custom date parser

In [ ]:
# Note: `date_parser` is deprecated as of pandas 2.0 and emits a warning there
df = pd.read_csv('/tmp/polyclinic-attendance/average-daily-polyclinic-attendances-for-selected-diseases.csv',
                 parse_dates=['epi_week'], date_parser=parse_dates,
                 names=['epi_week', 'disease', 'cases'], header=0,
                 index_col=0)

df.head(5)
In [ ]:
df.index

Plots

Let's try to plot these graphs:

  1. Line plot showing total number of cases over time
  2. Overlaid line plots showing number of cases per type, over time
  3. Bar chart showing distribution of types of cases over time

Plot workflow

Before we begin, here's a generic workflow for creating a plot.

import matplotlib.pyplot as plt

# create subplots lined up as 1 row and 2 columns
# figsize is (width, height) in inches: 20 x 10
# ax1 and ax2 are the axes for each of the subplots
fig, (ax1, ax2) = plt.subplots(nrows=1, ncols=2,
                               figsize=(20, 10))


# get pandas DataFrames
df1 = ...
df2 = ...

# plot the DataFrames
df1.plot(ax=ax1)
ax1.set(title='The left plot',
        ylabel='the y-axis',
        xlabel='the x-axis')

df2.plot(ax=ax2)
ax2.set(title='The right plot',
        ylabel='the y-axis',
        xlabel='the x-axis')

The workflow can be adapted to create any number of plots.

For example, to create 1 plot:

fig, ax = plt.subplots(figsize=(20, 10))  # default nrows=1, ncols=1

To create 2 rows of plots:

fig, (ax1, ax2) = plt.subplots(nrows=2, ncols=1,
                               figsize=(20, 10))

# ax1 is the top row
# ax2 is the bottom row

To plot multiple graphs in the same plot:

fig, ax = plt.subplots(figsize=(20, 10))

df1.plot(ax=ax)
df2.plot(ax=ax)
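
As a minimal, self-contained sketch of this workflow, the snippet below overlays two toy DataFrames on a single axes. The data, the column names (`series_a`, `series_b`) and the labels are made up for illustration and are not part of the workshop dataset.

```python
import matplotlib.pyplot as plt
import pandas as pd

# Hypothetical toy data: two weekly series sharing the same date index
dates = pd.date_range('2020-01-05', periods=4, freq='W')
df1 = pd.DataFrame({'series_a': [1, 3, 2, 4]}, index=dates)
df2 = pd.DataFrame({'series_b': [2, 2, 3, 1]}, index=dates)

fig, ax = plt.subplots(figsize=(10, 5))

# Passing the same ax to both calls draws both lines on one plot
df1.plot(ax=ax)
df2.plot(ax=ax)

ax.set(title='Two series on one axes',
       xlabel='Date',
       ylabel='Value')
ax.grid()
```

Because both calls receive the same `ax`, pandas adds each line to the existing plot instead of creating a new figure.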

Line plot showing total number of cases over time

To plot this, we need to sum up the cases for each date.

pandas DataFrame and Series objects provide a plot() method that returns a matplotlib Axes. Called with no further customization, it draws a raw plot that uses the index name as the default x-axis label.

In [ ]:
df.groupby(df.index)['cases'].sum().plot(marker='o')
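
To see what the groupby-and-sum step does before plotting, here is a small sketch on made-up data. The frame `toy` only imitates the shape of `df` (date index, `disease` and `cases` columns); its values are hypothetical.

```python
import pandas as pd

# Hypothetical toy data in the same shape as df: a date index,
# a `disease` column and a `cases` column
toy = pd.DataFrame(
    {'disease': ['A', 'B', 'A', 'B'],
     'cases': [10, 5, 20, 15]},
    index=pd.to_datetime(['2020-01-05', '2020-01-05',
                          '2020-01-12', '2020-01-12']))

# Rows sharing an index value are grouped and their cases summed,
# giving one total per date
totals = toy.groupby(toy.index)['cases'].sum()
print(totals)
```

The result is a Series with one row per date, which is why a single line appears when it is plotted.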

Exercise: Plot Customization

Use the workflow to create a customized plot:

  1. Make the plot bigger by setting figsize
  2. Set the x-axis label to 'Year'
  3. Set the y-axis label to 'Number of cases'
  4. Set the title to 'Polyclinic Cases for Selected Diseases'

You can add more customization options, such as ax.grid() to turn on the grid. See: https://matplotlib.org/api/axes_api.html#appearance

In [ ]:
# Your code here

Overlaid line plots showing number of cases per type, over time

Based on the plot, it looks like something serious happened in early 2015.

Let's find out what type of cases contributed to this spike, by plotting a line per type.

First, we need to know what types of diseases there are.

In [ ]:
# find the columns
df.columns
In [ ]:
# find unique values for the `disease` column
df.disease.unique()
In [ ]:
# we can get the Series for number of cases for one disease
diarrhoea_cases = df.loc[df.disease == 'Acute Diarrhoea', 'cases']

diarrhoea_cases
In [ ]:
# List comprehension will give us a list of Series
cases_per_disease = [
    df.loc[df.disease == d, 'cases']
    for d in df.disease.unique()
]

cases_per_disease

Exercise: Multi-line Plots

Plot each Series in cases_per_disease as a line on the SAME plot.

  1. Plot each line on the same axis. You can skip marker='o' if the plot looks too dense.
  2. Make the plot bigger by setting figsize
  3. Set the x-axis label to 'Year', the y-axis label to 'Number of cases', and the title to 'Polyclinic Cases for Selected Diseases'
  4. Set the legend using ax.legend(df.disease.unique())
In [ ]:
# Your code here

So the majority of the cases in 2015 are due to "Acute Upper Respiratory Tract infections."

A search of the internet reveals that there was a serious haze around 2015, but the spike was still too large to seem normal.

Just to confirm that we didn't plot things incorrectly, let's inspect the data around the first week of 2015:

In [ ]:
# using pandas' datetime helpers
reference_date = pd.to_datetime('2015-01-01')
start_date = reference_date - pd.DateOffset(weeks=1)
end_date = reference_date + pd.DateOffset(weeks=3)

# row_index, col_index
df.loc[start_date:end_date, :]

Hmm, there are duplicate entries for 2015-01-11. There could be some double-counting here.

In [ ]:
df.loc[pd.to_datetime('2015-01-11'), :]

At this point, we would inspect the raw CSV, confirm things, and contact the data source owner to figure out whether this is expected.
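
One thing worth ruling out while inspecting the raw CSV is the date parser itself. With `%W`, a week number that runs past the end of a year rolls over into the next year, so two different week labels can parse to the same date. The sketch below uses only the standard library and the same format string as `parse_date`; whether the raw CSV actually contains a `2014-W53` row is an assumption to verify against the file.

```python
from datetime import datetime

def parse_week(label):
    # Same format string as parse_date: Sunday (%w = 0) of week %W
    return datetime.strptime(label + '-0', '%Y-W%W-%w')

last_week_2014 = parse_week('2014-W53')
first_week_2015 = parse_week('2015-W01')

# Both labels resolve to the same calendar date
print(last_week_2014, first_week_2015)
print(last_week_2014 == first_week_2015)
```

If both labels are present in the file, the duplicates would be two distinct reporting weeks that collide after parsing, rather than double-counted rows.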

Visualization can spot data abnormalities

The insight here is that plotting can reveal hidden issues in the data. Before spending time creating a model, it's a good idea to plot and check that the plots make sense.
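
A numeric check can complement the visual one. As a rough sketch on made-up data, the snippet below flags rows whose (date, disease) pair occurs more than once. The frame `toy` and its values are hypothetical; only the column layout mirrors `df`.

```python
import pandas as pd

# Hypothetical toy data: 2020-01-12 appears twice for disease 'A'
toy = pd.DataFrame(
    {'disease': ['A', 'A', 'A', 'B'],
     'cases': [10, 20, 22, 5]},
    index=pd.to_datetime(['2020-01-05', '2020-01-12',
                          '2020-01-12', '2020-01-05']))
toy.index.name = 'epi_week'

# Move the index into a column, then flag every row whose
# (date, disease) pair occurs more than once
rows = toy.reset_index()
dupes = rows[rows.duplicated(subset=['epi_week', 'disease'], keep=False)]
print(dupes)
```

Passing `keep=False` marks all members of a duplicated group, not only the later ones, which makes the affected rows easier to compare.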

Fixing invalid rows

One way to fix the data is to replace each duplicate entry with the average of the duplicates divided by the number of duplicates, so that the entries for that date still sum to the average.
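
The arithmetic can be sketched on made-up numbers first. Here two hypothetical entries of 100 and 120 for the same disease and date have a mean of 110; each is replaced by 110 / 2 = 55, so they still sum to 110. This version uses `groupby(...).transform`, which is one possible way to do it and not necessarily the approach taken in the cells that follow.

```python
import pandas as pd

# Hypothetical toy data: disease 'A' is duplicated, 'B' is not
toy = pd.DataFrame(
    {'disease': ['A', 'A', 'B'],
     'cases': [100.0, 120.0, 30.0]},
    index=pd.to_datetime(['2020-01-12', '2020-01-12', '2020-01-12']))

# Group rows by (date, disease); transform returns one value per original row
groups = toy.groupby([toy.index, toy['disease']])['cases']
mean = groups.transform('mean')
count = groups.transform('count')

# Each entry becomes mean / count, so each group sums to its mean
toy['cases'] = (mean / count).to_numpy()

print(toy)
print(toy.groupby('disease')['cases'].sum())
```

Rows that are not duplicated have a count of 1 and keep their original value.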

In [ ]:
# compute the mean cases
date = pd.to_datetime('2015-01-11')
means = df.loc[date].groupby('disease').mean()

print(means.values.flatten())
In [ ]:
# for each disease on that date, replace the cases with the mean
#
# Note: Recall that loc must be used in order to modify the actual DataFrame
# otherwise you'll get a warning about modifying a copy of the DataFrame

for d, mean in means['cases'].items():
    # loc[row_selector, column_selector]
    # there are 2 entries per disease on this date, so dividing the mean
    # by 2 makes their combined sum equal the average value
    df.loc[(df.index == date) & (df.disease == d), 'cases'] = mean / 2
    
df.loc[date, :]

Let's replot our graphs with the updated DataFrame:

In [ ]:
fig, ax = plt.subplots(figsize=(20, 10))

# plot one line per disease, summing any entries that share a date
for d in df.disease.unique():
    cases = df.loc[df.disease == d, 'cases']
    cases.groupby(cases.index).sum().plot(ax=ax)

ax.set(title='Polyclinic Cases for Selected Diseases',
       ylabel='Number of cases',
       xlabel='Year')
ax.grid()
ax.legend(df.disease.unique())

Bar chart showing distribution of types of cases over time

In our final plot, we'll do a bar chart.

This should be similar to what we did in the pandas workshop.

Exercise: Pivot table and Bar chart

  1. Create a pandas.pivot_table using index=df.index.weekofyear (removed in pandas 2.0; use df.index.isocalendar().week there), disease as columns, and cases as values
  2. Plot the pivot table as a stacked bar chart
  3. Customize the bar chart to your liking. For example, x-axis 'Week of year', y-axis 'Case distribution'

Hint: refer to the final exercise in the pandas worksheet on how to set up the pivot_table, or try pd.pivot_table? for help.
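
If the pivot-then-plot pattern is unfamiliar, the sketch below shows its general shape on made-up data. The frame `toy` and its ready-made `week` column are hypothetical; adapting the pattern to `df` and its date index is left as the exercise.

```python
import matplotlib.pyplot as plt
import pandas as pd

# Hypothetical toy data with a ready-made `week` column
toy = pd.DataFrame({
    'week': [1, 1, 2, 2],
    'disease': ['A', 'B', 'A', 'B'],
    'cases': [10, 5, 20, 15]})

# Rows become weeks, columns become diseases, cells hold the mean cases
# (mean is the default aggfunc of pivot_table)
pivot = pd.pivot_table(toy, index='week', columns='disease', values='cases')

fig, ax = plt.subplots(figsize=(10, 5))

# stacked=True piles the disease columns on top of each other per week
pivot.plot(kind='bar', stacked=True, ax=ax)
ax.set(xlabel='Week of year', ylabel='Case distribution')
```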

In [ ]:
# Your code here