
Visualization

There are many libraries for plotting and visualizing data. The one we focus on here is Matplotlib, a comprehensive library for creating static, animated, and interactive visualizations in Python.

More in-depth information is available in the official Matplotlib documentation and cheatsheets.

It is standard practice to import Matplotlib into the plt namespace:

import matplotlib.pyplot as plt

For the examples on this page we also need Polars, its `col` expression function, and the `MultipleLocator` class from Matplotlib:

import polars as pl
from matplotlib.ticker import MultipleLocator
from polars import col

We will be working with the patient data from the previous chapter:

patients = pl.read_csv("../../data/hosp/patients.csv.gz")

Let’s start by exploring the distribution of age in the data with a histogram. First we create a figure with one or more axes (areas for plotting), then we draw the histogram on the axes.

fig, ax = plt.subplots()
ax.hist(patients["anchor_age"])
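`ax.hist` does more than draw: it also returns the count in each bin, the bin edges, and the drawn bar artists. A minimal sketch, using hypothetical synthetic ages in place of the `patients` data:

```python
import matplotlib.pyplot as plt
import numpy as np

# Hypothetical stand-in for patients["anchor_age"]: 100 random ages
rng = np.random.default_rng(seed=0)
ages = rng.integers(18, 92, size=100)

fig, ax = plt.subplots()
# hist returns the count per bin, the bin edges and the bar artists
counts, edges, bars = ax.hist(ages)

print(len(counts))  # 10, the default number of bins
print(len(edges))   # 11, one more edge than there are bins
```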

Great! Now let’s increase the number of bins.

fig, ax = plt.subplots()
ax.hist(patients["anchor_age"], bins=20)
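Besides a number of bins, `bins` also accepts a sequence of bin edges, which gives bins of a fixed width. A sketch with hypothetical ages and 5-year bins; the range 15 to 95 is an assumption chosen to match the axis limits used further down:

```python
import matplotlib.pyplot as plt
import numpy as np

rng = np.random.default_rng(seed=0)
ages = rng.integers(18, 92, size=100)  # hypothetical ages

# Explicit edges 15, 20, ..., 95 give sixteen 5-year-wide bins
edges = np.arange(15, 96, 5)

fig, ax = plt.subplots()
counts, used_edges, _ = ax.hist(ages, bins=edges)
```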

It looks OK, but it could definitely be prettier. Let’s change the color and make the bars stand out with a thin white border.

fig, ax = plt.subplots()
ax.hist(patients["anchor_age"], bins=20, facecolor="#3b82f6", edgecolor="white")

Much better! Continuing from the previous code, let’s set the limits and labels of the x and y axes. We will also place a major tick at every multiple of 2 on the y axis.

ax.set_xlim(15, 95)
ax.set_ylim(0, 20)
ax.set_xlabel("Age in years")
ax.set_ylabel("Number of subjects")
ax.yaxis.set_major_locator(MultipleLocator(2))
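To check which ticks `MultipleLocator(2)` produces, we can apply it to an empty axes with the same y limits and read the tick positions back. A small sketch; the filtering step is there because the locator may also return a tick just outside the visible range:

```python
import matplotlib.pyplot as plt
from matplotlib.ticker import MultipleLocator

fig, ax = plt.subplots()
ax.set_ylim(0, 20)
ax.yaxis.set_major_locator(MultipleLocator(2))

# Keep only the ticks that fall inside the visible range
ticks = [t for t in ax.get_yticks() if 0 <= t <= 20]
```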

Finally, let’s add a grid behind the bars to make the height of each bin easier to read.

ax.grid(color="#f3f4f6")
ax.set_axisbelow(True)  # Draw the grid below the bars
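For reuse, the styling steps above can be bundled into one function. A sketch under the assumption that the same styling suits every age histogram; `plot_age_histogram` is a hypothetical helper name and the data are synthetic:

```python
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.ticker import MultipleLocator


def plot_age_histogram(ages, ax):
    """Draw a styled age histogram on ax (hypothetical helper)."""
    ax.hist(ages, bins=20, facecolor="#3b82f6", edgecolor="white")
    ax.set_xlim(15, 95)
    ax.set_ylim(0, 20)
    ax.set_xlabel("Age in years")
    ax.set_ylabel("Number of subjects")
    ax.yaxis.set_major_locator(MultipleLocator(2))
    ax.grid(color="#f3f4f6")
    ax.set_axisbelow(True)  # draw the grid behind the bars
    return ax


rng = np.random.default_rng(seed=0)
fig, ax = plt.subplots()
plot_age_histogram(rng.integers(18, 92, size=100), ax)  # hypothetical ages
```

Passing `ax` in, rather than creating the figure inside the function, lets the same helper draw into any panel of a multi-panel figure.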

Now, isn’t that a beautiful figure!

If we want to show multiple plots on the same axes, we can draw them one after the other on the same ax object. We should also add a legend for clarity.

fig, ax = plt.subplots()
males = patients.filter(col("gender").eq("M"))["anchor_age"]
females = patients.filter(col("gender").eq("F"))["anchor_age"]
ax.hist(males, facecolor="#3b82f6", alpha=0.5, label="Male")
ax.hist(females, facecolor="#ef4444", alpha=0.5, label="Female")
ax.legend()
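When `bins` is a number, `ax.hist` computes the edges from each dataset’s own range, so the male and female bars above need not line up. One way to align them, sketched here with hypothetical synthetic data, is to compute a single set of edges with NumPy’s `np.histogram_bin_edges` and pass it to both calls:

```python
import matplotlib.pyplot as plt
import numpy as np

rng = np.random.default_rng(seed=1)
males = rng.integers(20, 90, size=50)    # hypothetical ages
females = rng.integers(25, 85, size=50)  # hypothetical ages

# One set of 10 equal-width bins spanning both groups
edges = np.histogram_bin_edges(np.concatenate([males, females]), bins=10)

fig, ax = plt.subplots()
_, m_edges, _ = ax.hist(
    males, bins=edges, facecolor="#3b82f6", alpha=0.5, label="Male"
)
_, f_edges, _ = ax.hist(
    females, bins=edges, facecolor="#ef4444", alpha=0.5, label="Female"
)
ax.legend()
```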

If we instead want to show each group in its own panel, we can tell the subplots function that we want two rows and one column of axes, then plot each group on its own axes.

fig, axs = plt.subplots(2, 1, constrained_layout=True)
axs[0].hist(males, facecolor="#3b82f6", alpha=0.5)
axs[0].set_title("Male")
axs[1].hist(females, facecolor="#ef4444", alpha=0.5)
axs[1].set_title("Female")
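Stacked panels are easier to compare when they use the same x scale. `plt.subplots` accepts `sharex=True` to link the x limits of all panels; a sketch with hypothetical data:

```python
import matplotlib.pyplot as plt
import numpy as np

rng = np.random.default_rng(seed=1)
males = rng.integers(20, 90, size=50)    # hypothetical ages
females = rng.integers(25, 85, size=50)  # hypothetical ages

# sharex=True links the x limits of both panels
fig, axs = plt.subplots(2, 1, sharex=True, constrained_layout=True)
axs[0].hist(males, facecolor="#3b82f6")
axs[0].set_title("Male")
axs[1].hist(females, facecolor="#ef4444")
axs[1].set_title("Female")
axs[1].set_xlabel("Age in years")  # only the bottom panel gets a label
```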

Finally, we can save the figure for use in a manuscript (change the file path as needed). Many file formats are supported (e.g. PNG, SVG, PDF, TIFF).

fig.savefig("../../assets/img/visualization/figure-Pyw3e0.svg")
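`savefig` infers the format from the file extension. For raster formats such as PNG, `dpi` sets the resolution, and `bbox_inches="tight"` trims surplus whitespace around the axes. A sketch that writes to a temporary directory; the file names are placeholders:

```python
import tempfile
from pathlib import Path

import matplotlib.pyplot as plt

fig, ax = plt.subplots()
ax.plot([0, 1, 2, 3], [0.5, 12, 15, 3])

out_dir = Path(tempfile.mkdtemp())
# Vector format: scales without losing quality
fig.savefig(out_dir / "figure.svg")
# Raster format: resolution set by dpi, whitespace trimmed
fig.savefig(out_dir / "figure.png", dpi=300, bbox_inches="tight")

print(sorted(p.name for p in out_dir.iterdir()))  # ['figure.png', 'figure.svg']
```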

Below are basic examples of different types of plots: line, bar, histogram, box, and scatter.

fig, ax = plt.subplots()
ax.plot([0, 1, 2, 3], [0.5, 12, 15, 3])
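`ax.plot` draws a line through the points in the order given and returns a list of `Line2D` objects. A format string such as `"o-"` adds a circle marker at each point; the axis labels and legend label here are placeholders:

```python
import matplotlib.pyplot as plt

fig, ax = plt.subplots()
# "o-" means: circle markers joined by a solid line
(line,) = ax.plot([0, 1, 2, 3], [0.5, 12, 15, 3], "o-", label="Example series")
ax.set_xlabel("x")
ax.set_ylabel("y")
ax.legend()
```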

fig, ax = plt.subplots()
ax.bar(*patients["gender"].value_counts())
ax.set_xlabel("Sex")
ax.set_ylabel("n")
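In the bar example above, `value_counts()` returns a two-column DataFrame (each category and its count), and `*` unpacks it column by column into the `x` and `height` arguments of `ax.bar`. The same idea without Polars, sketched with a hypothetical list of sex codes and `collections.Counter`:

```python
from collections import Counter

import matplotlib.pyplot as plt

# Hypothetical sex codes standing in for patients["gender"]
genders = ["M", "F", "F", "M", "F", "M", "F"]

# Split sorted (category, count) pairs into two parallel tuples
categories, counts = zip(*sorted(Counter(genders).items()))

fig, ax = plt.subplots()
bars = ax.bar(categories, counts, color="#3b82f6")
ax.set_xlabel("Sex")
ax.set_ylabel("n")

print(categories, counts)  # ('F', 'M') (4, 3)
```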

fig, ax = plt.subplots()
ax.hist(patients["anchor_age"])

fig, ax = plt.subplots()
# One row per sex, with all ages of that sex collected into a list
ages_by_sex = patients.group_by("gender", maintain_order=True).agg("anchor_age")
ax.boxplot(ages_by_sex["anchor_age"], tick_labels=ages_by_sex["gender"].to_list())
ax.set_xlabel("Sex")
ax.set_ylabel("anchor_age")
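`ax.boxplot` takes one sequence of values per box and returns a dictionary of the drawn artists, keyed by part (`"boxes"`, `"medians"`, `"whiskers"`, and so on). A sketch with hypothetical ages that reads each median back from the plot. The tick labels are set with `set_xticks` here, which avoids depending on the boxplot label argument (named `tick_labels` since Matplotlib 3.9):

```python
import matplotlib.pyplot as plt
import numpy as np

# Hypothetical ages for two groups
ages_f = [34, 45, 52, 61, 70]
ages_m = [29, 41, 48, 55, 83]

fig, ax = plt.subplots()
parts = ax.boxplot([ages_f, ages_m])
# Boxes are drawn at x = 1, 2, ...
ax.set_xticks([1, 2], ["F", "M"])
ax.set_xlabel("Sex")
ax.set_ylabel("Age in years")

# Each median is a horizontal line at the height of the group median
medians = [line.get_ydata()[0] for line in parts["medians"]]
expected = [np.median(ages_f), np.median(ages_m)]
```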

fig, ax = plt.subplots()
ax.scatter(patients["anchor_age"], patients["anchor_year"])
ax.set_xlabel("anchor_age")
ax.set_ylabel("anchor_year")
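With many subjects, scatter markers overlap. Smaller, semi-transparent markers (`s` sets the marker area in points squared, `alpha` the opacity) make dense regions easier to see. A sketch with hypothetical values; `get_offsets` returns the plotted coordinates:

```python
import matplotlib.pyplot as plt
import numpy as np

rng = np.random.default_rng(seed=2)
ages = rng.integers(18, 92, size=500)       # hypothetical anchor_age
years = rng.integers(2110, 2190, size=500)  # hypothetical anchor_year

fig, ax = plt.subplots()
points = ax.scatter(ages, years, s=10, alpha=0.3, color="#3b82f6")
ax.set_xlabel("anchor_age")
ax.set_ylabel("anchor_year")

print(points.get_offsets().shape)  # (500, 2): one (x, y) row per point
```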

We have already seen a figure with two panels when we stratified the age histogram by sex above. Here we look in more detail at how to create multiple panels within one figure.

fig, axs = plt.subplots(2, 2, constrained_layout=True)
# Top left: histogram of age
axs[0, 0].set_title("Histogram")
axs[0, 0].hist(patients["anchor_age"])
axs[0, 0].set_xlabel("Age in years")
axs[0, 0].set_ylabel("Number of subjects")
# Top right: number of subjects per sex
axs[0, 1].set_title("Bar")
axs[0, 1].bar(*patients["gender"].value_counts())
axs[0, 1].set_xlabel("Sex")
axs[0, 1].set_ylabel("n")
# Bottom left: age distribution per sex
ages_by_sex = patients.group_by("gender", maintain_order=True).agg("anchor_age")
axs[1, 0].set_title("Box")
axs[1, 0].boxplot(
    ages_by_sex["anchor_age"], tick_labels=ages_by_sex["gender"].to_list()
)
axs[1, 0].set_xlabel("Sex")
axs[1, 0].set_ylabel("anchor_age")
# Bottom right: age against anchor year
axs[1, 1].set_title("Scatter")
axs[1, 1].scatter(patients["anchor_age"], patients["anchor_year"])
axs[1, 1].set_xlabel("anchor_age")
axs[1, 1].set_ylabel("anchor_year")
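When panels share most of their styling, looping over the flattened axes array avoids repeating the same calls. A sketch with hypothetical data; `axs.flat` walks the 2×2 grid row by row:

```python
import matplotlib.pyplot as plt
import numpy as np

rng = np.random.default_rng(seed=3)
# Hypothetical data for four panels
datasets = [rng.normal(50, 15, size=100) for _ in range(4)]
titles = ["A", "B", "C", "D"]

fig, axs = plt.subplots(2, 2, constrained_layout=True)
for ax, data, title in zip(axs.flat, datasets, titles):
    ax.hist(data, facecolor="#3b82f6", edgecolor="white")
    ax.set_title(title)
    ax.set_xlabel("Age in years")

print([ax.get_title() for ax in axs.flat])  # ['A', 'B', 'C', 'D']
```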