Plotting functionality¶
Here we will have a look at the different options for finetuning pairplots and marginal_plots.
Lets first draw some samples from the posterior used in a previous tutorial.
import torch
from toy_posterior_for_07_cc import ExamplePosterior
from sbi.analysis import pairplot
posterior = ExamplePosterior()
posterior_samples = posterior.sample((100,))
We will start with the default plot and gradually make it prettier
_ = pairplot(
posterior_samples,
)

Customisation¶
The pairplots are split into three regions, the diagonal (diag) and the upper and lower off-diagonal regions(upper and lower). We can pass separate arguments (e.g. hist, kde, scatter) for each region, as well as corresponding style keywords in a dictionary (by using e.g. upper_kwargs). For overall figure stylisation one can use fig_kwargs.
To get a closer look at the potential options, have a look at the following dataclasses.
* FigOptions dataclass for figure stylisation.
* ContourOffDiagOptions,HistOffDiagOptions, KdeOffDiagOptions, PlotOffDiagOptions, ScatterOffDiagOptions dataclasses for styling the upper and lower off-diagonal regions.
* HistDiagOptions, KdeDiagOptions, ScatterDiagOptions for styling the diagonal region.
You can find the dataclasses in analysis/plotting_classes.py.
As illustrated below, we can directly use any matplotlib keywords (such as cmap for images) by passing them in the mpl_kwargs entry of upper_kwargs or diag_kwargs.
Migration Note¶
Previously you would pass nested dictionaries to diag_kwargs, upper_kwargs, lower_kwargs, and fig_kwargs arguments. This is still supported for backward compatability, but we recommend using the dataclasses listed above for clarity and autocompletion.
Let’s now make a scatter plot for the upper diagonal, a histogram for the diagonal, and pass the respective dataclasses for both.
from sbi.analysis.plotting_classes import HistDiagOptions, ScatterOffDiagOptions
_ = pairplot(
posterior_samples,
limits=[[-3, 3] * 3],
figsize=(5, 5),
diag="hist",
upper="scatter",
diag_kwargs=HistDiagOptions(
mpl_kwargs={
"color": 'tab:blue',
"histtype": "bar",
"bins": 10,
"edgecolor": 'white',
"linewidth": 1,
"alpha": 0.6,
"fill": True,
}
),
upper_kwargs=ScatterOffDiagOptions(mpl_kwargs={"color": 'tab:blue', "s": 20, "alpha": 0.8}),
labels=[r"$\theta_1$", r"$\theta_2$", r"$\theta_3$"],
)

Compare two sets of samples¶
By passing a list of samples, we can plot two sets of samples on top of each other.
# draw two different subsets of samples to plot
posterior_samples1 = posterior.sample((20,))
posterior_samples2 = posterior.sample((20,))
_ = pairplot(
[posterior_samples1, posterior_samples2],
limits=[[-3, 3] * 3],
figsize=(5, 5),
diag=["hist", "hist"],
upper=["scatter", "scatter"],
diag_kwargs=HistDiagOptions(
mpl_kwargs={
"bins": 10,
"edgecolor": 'white',
"linewidth": 1,
"alpha": 0.6,
"histtype": "bar",
"fill": True,
}
),
upper_kwargs=ScatterOffDiagOptions(mpl_kwargs={"s": 50, "alpha": 0.8}),
labels=[r"$\theta_1$", r"$\theta_2$", r"$\theta_3$"],
)

Multi-layered plots¶
We can use the same functionality to make a multi-layered plot using the same set of samples, e.g. a kernel-density estimate on top of a scatter plot.
from sbi.analysis.plotting_classes import FigOptions
_ = pairplot(
[posterior_samples, posterior_samples],
limits=[[-3, 3] * 3],
figsize=(5, 5),
diag=["hist", None],
upper=["scatter", "contour"],
diag_kwargs=HistDiagOptions(
mpl_kwargs= {
"bins": 10,
"color": 'tab:blue',
"edgecolor": 'white',
"linewidth": 1,
"alpha": 0.6,
"histtype": "bar",
"fill": True,
},
),
upper_kwargs=[
ScatterOffDiagOptions(
mpl_kwargs={"color": 'tab:blue', "s": 20, "alpha": 0.8},
),
ScatterOffDiagOptions(mpl_kwargs={"cmap": 'Blues_r', "alpha": 0.8, "colors": None}),
],
labels=[r"$\theta_1$", r"$\theta_2$", r"$\theta_3$"],
fig_kwargs=FigOptions(despine=dict(offset=0)),
)

Lower diagonal¶
We can add something in the lower off-diagonal as well.
from sbi.analysis.plotting_classes import KdeOffDiagOptions
_ = pairplot(
[posterior_samples, posterior_samples],
limits=[[-3, 3] * 3],
figsize=(5, 5),
diag=["hist", None],
upper=["scatter", "contour"],
lower=["kde", None],
diag_kwargs=HistDiagOptions(
mpl_kwargs={
"bins": 10,
"color": 'tab:blue',
"edgecolor": 'white',
"linewidth": 1,
"alpha": 0.6,
"histtype": "bar",
"fill": True,
}
),
upper_kwargs=[
ScatterOffDiagOptions(mpl_kwargs={"color": 'tab:blue', "s": 20, "alpha": 0.8}),
ScatterOffDiagOptions(mpl_kwargs={"cmap": 'Blues_r', "alpha": 0.8, "colors": None}),
],
lower_kwargs=KdeOffDiagOptions(mpl_kwargs={"cmap": "Blues_r"}),
labels=[r"$\theta_1$", r"$\theta_2$", r"$\theta_3$"],
)

Adding observed data¶
We can also add points, e.g., our observed data \(\theta_o\) to the plot.
# fake observed data:
theta_o = torch.ones(1, 3)
_ = pairplot(
[posterior_samples, posterior_samples],
limits=[[-3, 3] * 3],
figsize=(5, 5),
diag=["hist", None],
upper=["scatter", "contour"],
diag_kwargs=HistDiagOptions(
mpl_kwargs={
"bins": 10,
"color": 'tab:blue',
"edgecolor": 'white',
"linewidth": 1,
"alpha": 0.6,
"histtype": "bar",
"fill": True,
}
),
upper_kwargs=[
ScatterOffDiagOptions(mpl_kwargs={"color": 'tab:blue', "s": 20, "alpha": 0.8}),
ScatterOffDiagOptions(mpl_kwargs={"cmap": 'Blues_r', "alpha": 0.8, "colors": None}),
],
labels=[r"$\theta_1$", r"$\theta_2$", r"$\theta_3$"],
points=theta_o,
fig_kwargs=FigOptions(
points_labels=[r"$\theta_o$"],
legend=True,
points_colors=["purple"],
points_offdiag=dict(marker="+", markersize=20),
despine=dict(offset=0),
),
)

Subsetting the plot¶
For high-dimensional posteriors, we might only want to visualise a subset of the marginals. This can be done by passing a list of entries to plot to the subset argument of the pairplot function.
_ = pairplot(
[posterior_samples, posterior_samples],
limits=[[-3, 3] * 3],
figsize=(5, 5),
subset=[0, 2],
diag=["hist", None],
upper=["scatter", "contour"],
diag_kwargs=HistDiagOptions(
mpl_kwargs={
"bins": 10,
"color": 'tab:blue',
"edgecolor": 'white',
"linewidth": 1,
"alpha": 0.6,
"histtype": "bar",
"fill": True,
}
),
upper_kwargs=[
ScatterOffDiagOptions(mpl_kwargs={"color": 'tab:blue', "s": 20, "alpha": 0.8}),
ScatterOffDiagOptions(mpl_kwargs={"cmap": 'Blues_r', "alpha": 0.8, "colors": None}),
],
labels=[r"$\theta_1$", r"$\theta_2$", r"$\theta_3$"],
points=theta_o,
fig_kwargs=FigOptions(
points_labels=[r"$\theta_o$"],
legend=True,
points_colors=["purple"],
points_offdiag=dict(marker="+", markersize=20),
despine=dict(offset=0),
),
)

Plot just the marginals¶
1D Marginals can also be visualised using the marginal_plot function
from sbi.analysis import marginal_plot
# plot posterior samples
_ = marginal_plot(
[posterior_samples, posterior_samples],
limits=[[-3, 3] * 3],
subset=[0, 1],
diag=["hist", None],
diag_kwargs=HistDiagOptions(
mpl_kwargs={
"bins": 10,
"color": 'tab:blue',
"edgecolor": 'white',
"linewidth": 1,
"alpha": 0.6,
"histtype": "bar",
"fill": True,
}
),
labels=[r"$\theta_1$", r"$\theta_2$", r"$\theta_3$"],
points=[torch.ones(1, 3)],
figsize=(4, 2),
fig_kwargs=FigOptions(
points_labels=[r"$\theta_o$"],
legend=True,
points_colors=["purple"],
points_offdiag=dict(marker="+", markersize=20),
despine=dict(offset=0),
),
)
