Posterior Predictive Checks (PPC) in SBI¶
A common safety check performed as part of inference is the Posterior Predictive Check (PPC). A PPC compares data \(x_{\text{pp}}\) generated from parameters \(\theta_{\text{posterior}}\) sampled from the posterior with the observed data \(x_o\). The general idea is that, if the inference is correct, the generated data \(x_{\text{pp}}\) should “look similar” to the observed data \(x_o\). Said differently, \(x_o\) should lie within the support of \(x_{\text{pp}}\).
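Formally, \(x_{\text{pp}}\) are samples from the posterior predictive distribution, which averages the simulator over the posterior:
\[
p(x_{\text{pp}} \mid x_o) = \int p(x_{\text{pp}} \mid \theta) \, p(\theta \mid x_o) \, d\theta .
\]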
A PPC usually shouldn’t be used as a validation metric. Nonetheless, it is a good starting point for diagnosing inference and can provide an intuition about any bias introduced by the inference: does \(x_{\text{pp}}\) systematically differ from \(x_o\)?
Main syntax¶
from sbi.analysis import pairplot
# A PPC is performed after we have trained our neural posterior
posterior.set_default_x(x_o)
# We draw theta samples from the posterior
posterior_samples = posterior.sample((5_000,))
# We use the posterior theta samples to simulate data x; the simulation
# itself is not in the scope of sbi
x_pp = simulator(posterior_samples)
# We check whether the observed data falls within the support of the generated data
_ = pairplot(
samples=x_pp,
points=x_o
)
Performing a PPC on a toy example¶
Below we walk through a Posterior Predictive Check (PPC) on a toy example:
from sbi.analysis import pairplot
import torch
_ = torch.manual_seed(0)
We work on an inference problem over three parameters, which can be tackled with any of the techniques implemented in sbi. In this tutorial, we load a dummy posterior:
from toy_posterior_for_07_cc import ExamplePosterior
posterior = ExamplePosterior()
Let us say that we are observing the data point \(x_o\):
D = 5 # simulator output was 5-dimensional
x_o = torch.ones(1, D)
posterior.set_default_x(x_o)
The posterior can be used to draw \(\theta_{\text{posterior}}\) samples:
posterior_samples = posterior.sample((5_000,))
fig, ax = pairplot(
samples=posterior_samples,
limits=torch.tensor([[-2.5, 2.5]] * 3),
offdiag=["kde"],
diag=["kde"],
figsize=(5, 5),
labels=[rf"$\theta_{d}$" for d in range(3)],
)
Now we can use our simulator to generate data \(x_{\text{pp}}\), using the posterior samples \(\theta_{\text{posterior}}\) as input parameters. Note that the simulation part is not in the scope of sbi, so any simulator (including a non-Python one) can be used at this stage. In our case, we’ll use a dummy simulator:
def dummy_simulator(posterior_samples: torch.Tensor, *args, **kwargs) -> torch.Tensor:
    # This toy simulator ignores the parameter values; it only uses their
    # number to decide how many data points to draw
    sample_size = posterior_samples.shape[0]
    scale = 1.0
    # Random per-dimension shift, then Gumbel-distributed data centered near x_o
    shift = torch.distributions.Gumbel(loc=torch.zeros(D), scale=scale / 2).sample()
    return torch.distributions.Gumbel(loc=x_o[0] + shift, scale=scale).sample(
        (sample_size,)
    )
x_pp = dummy_simulator(posterior_samples)
Plotting \(x_o\) against \(x_{\text{pp}}\), we perform a PPC that serves as a sanity check. In this case, the check indicates that \(x_o\) falls right within the support of \(x_{\text{pp}}\), which should make the experimenter rather confident about the estimated posterior:
_ = pairplot(
samples=x_pp,
points=x_o[0],
limits=torch.tensor([[-2.0, 5.0]] * 5),
points_colors="red",
figsize=(8, 8),
offdiag="scatter",
scatter_offdiag=dict(marker=".", s=5),
points_offdiag=dict(marker="+", markersize=20),
labels=[rf"$x_{d}$" for d in range(D)],
)
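Beyond the visual inspection, one can complement the PPC with a simple quantitative check. The snippet below is a minimal sketch (not part of the sbi API): for each data dimension it computes the empirical quantile of \(x_o\) under the \(x_{\text{pp}}\) samples; values very close to 0 or 1 would flag dimensions where \(x_o\) lies in the tails of \(x_{\text{pp}}\).
# Per-dimension empirical quantile of x_o under the predictive samples x_pp;
# a sketch of a quantitative companion to the visual check above
quantiles = (x_pp < x_o[0]).float().mean(dim=0)
print(quantiles)  # values near 0 or 1 indicate x_o sits in the tails of x_pp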
In contrast, \(x_o\) falling well outside the support of \(x_{\text{pp}}\) is indicative of a failure to estimate the correct posterior. Here we simulate such a failure mode:
error_shift = -2.0 * torch.ones(1, 5)
_ = pairplot(
samples=x_pp,
points=x_o[0] + error_shift,
limits=torch.tensor([[-2.0, 5.0]] * 5),
points_colors="red",
figsize=(8, 8),
offdiag="scatter",
scatter_offdiag=dict(marker=".", s=5),
points_offdiag=dict(marker="+", markersize=20),
labels=[rf"$x_{d}$" for d in range(D)],
)
A typical way to investigate this issue would be to run a *prior* predictive check, applying the same plotting strategy but drawing \(\theta\) from the prior instead of the posterior. **The support of \(x_{\text{pp}}\) should be larger and should contain \(x_o\).** If this check is successful, the “blame” can then be shifted to the inference itself (method used, convergence of density estimators, number of sequential rounds, etc.).
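For illustration, such a prior predictive check could look as follows. This is a minimal sketch, not a fixed sbi recipe: prior is assumed to be an object exposing a .sample() method (e.g., a torch distribution, as commonly used for priors in sbi), and we reuse the dummy simulator from above.
# `prior` is an assumption here: any object with a .sample() method,
# e.g. a torch distribution used as a prior
prior_samples = prior.sample((5_000,))
x_prior_pp = dummy_simulator(prior_samples)

# Same plotting strategy as for the PPC: x_o should fall within the
# (typically wider) support of the prior predictive samples
_ = pairplot(
    samples=x_prior_pp,
    points=x_o[0],
    limits=torch.tensor([[-2.0, 5.0]] * D),
    points_colors="red",
    figsize=(8, 8),
    offdiag="scatter",
    scatter_offdiag=dict(marker=".", s=5),
    points_offdiag=dict(marker="+", markersize=20),
    labels=[rf"$x_{d}$" for d in range(D)],
)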