SBI for decision-making models

In a previous tutorial, we showed how to use SBI with trial-based iid data. Such scenarios can arise, for example, in models of perceptual decision making. In addition to trial-based iid data points, these models often come with mixed data types and varying experimental conditions. Here, we show how sbi can be used to perform inference in such models with the MNLE method.

Note: you can find the original version of this notebook in the sbi repository under tutorials/Example_01_DecisionMakingModel.ipynb.

Trial-based SBI with mixed data types

In some cases, models with trial-based data additionally return mixed data types, e.g., continuous and discrete data. For example, most computational models of decision-making have continuous reaction times and discrete choices as output.

This can induce a problem when performing trial-based SBI that relies on learning a neural likelihood: it is challenging for most density estimators to handle continuous and discrete data at the same time. However, there is a recent SBI method for solving this problem, called Mixed Neural Likelihood Estimation (MNLE). It works just like NLE, but with mixed data types. The trick is that it learns two separate density estimators, one for the discrete part of the data and one for the continuous part, and combines the two to obtain the final neural likelihood. Crucially, the continuous density estimator is trained conditioned on the output of the discrete one, such that statistical dependencies between the discrete and continuous data (e.g., between choices and reaction times) are modeled as well. The interested reader is referred to the original paper available here.
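
Schematically, the mixed likelihood is thus factorized into a discrete and a continuous part, with the continuous part conditioned on the discrete one:

\[
p(rt, c \mid \theta) = p(c \mid \theta) \; p(rt \mid c, \theta)
\]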

MNLE was recently added to sbi (see this PR and also issue) and follows the same API as SNLE.

In this tutorial we will show how to apply MNLE to mixed data, and how to deal with varying experimental conditions.

Toy problem for MNLE

To illustrate MNLE, we set up a toy simulator that outputs mixed data and for which we know the likelihood, such that we can obtain reference posterior samples via MCMC.

Simulator: To simulate mixed data we do the following:

  • Sample reaction time from inverse Gamma
  • Sample choices from Binomial
  • Return reaction time \(rt \in (0, \infty)\) and choice index \(c \in \{0, 1\}\)
\[
c \sim \text{Binomial}(\rho) \\
rt \sim \text{InverseGamma}(\alpha=2, \beta)
\]

Prior: The priors of the two parameters \(\rho\) and \(\beta\) are independent. We define a Beta prior over the probability parameter \(\rho\) of the Binomial used in the simulator, and a Gamma prior over the parameter \(\beta\), which is passed as the rate of the InverseGamma used in the simulator:

\[
p(\beta, \rho) = p(\beta)\,p(\rho), \qquad
p(\beta) = \text{Gamma}(1, 0.5), \qquad
p(\rho) = \text{Beta}(2, 2)
\]

Because the InverseGamma and the Binomial likelihoods are well-defined, we can perform MCMC on this problem and obtain reference-posterior samples.

import matplotlib.pyplot as plt
import torch
from pyro.distributions import InverseGamma
from torch import Tensor
from torch.distributions import Beta, Binomial, Gamma

from sbi.analysis import pairplot
from sbi.inference import MNLE, MCMCPosterior
from sbi.inference.potentials.likelihood_based_potential import LikelihoodBasedPotential
from sbi.neural_nets import likelihood_nn
from sbi.utils import BoxUniform, MultipleIndependent, mcmc_transform
from sbi.utils.metrics import c2st


from example_01_utils import BinomialGammaPotential

# Toy simulator for mixed data
def mixed_simulator(theta: Tensor, concentration_scaling: float = 1.0):
    """Returns a sample from a mixed distribution given parameters theta.

    Args:
        theta: batch of parameters, shape (batch_size, 2).
        concentration_scaling: scaling factor for the concentration parameter
            of the InverseGamma distribution; mimics an experimental condition.

    """
    beta, rho = theta[:, :1], theta[:, 1:]

    choices = Binomial(probs=rho).sample()
    rts = InverseGamma(
        concentration=concentration_scaling * torch.ones_like(beta), rate=beta
    ).sample()

    return torch.cat((rts, choices), dim=1)

# Define independent prior.
prior = MultipleIndependent(
    [
        Gamma(torch.tensor([1.0]), torch.tensor([0.5])),
        Beta(torch.tensor([2.0]), torch.tensor([2.0])),
    ],
    validate_args=False,
)

Obtain reference-posterior samples via analytical likelihood and MCMC

torch.manual_seed(42)
num_trials = 10
num_samples = 1000
theta_o = prior.sample((1,))
x_o = mixed_simulator(theta_o.repeat(num_trials, 1))
mcmc_kwargs = dict(
    num_chains=20,
    warmup_steps=50,
    method="slice_np_vectorized",
    init_strategy="proposal",
    thin=1,
)

true_posterior = MCMCPosterior(
    potential_fn=BinomialGammaPotential(prior, x_o),
    proposal=prior,
    theta_transform=mcmc_transform(prior, enable_transform=True),
    **mcmc_kwargs,
)
true_samples = true_posterior.sample((num_samples,))
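
The BinomialGammaPotential used above is provided by example_01_utils and is not shown in this notebook. For intuition, here is a minimal, purely illustrative sketch of what such an analytical potential computes: the unnormalized log-posterior, i.e., the iid Binomial and InverseGamma log-likelihoods summed over trials plus the prior log-probability. The name analytical_log_potential and its signature are hypothetical, not the actual implementation.

def analytical_log_potential(theta: Tensor, x: Tensor, concentration_scaling: float = 1.0) -> Tensor:
    """Illustrative sketch of an analytical potential for the toy problem.

    Args:
        theta: batch of parameters, shape (batch_size, 2) with columns (beta, rho).
        x: batch of iid trials, shape (num_trials, 2) with columns (rt, choice).
    """
    beta, rho = theta[:, :1], theta[:, 1:]
    rts, choices = x[:, 0], x[:, 1]  # each of shape (num_trials,)

    # Trial-wise log-likelihoods, broadcast to shape (batch_size, num_trials).
    log_probs_choices = Binomial(probs=rho).log_prob(choices)
    log_probs_rts = InverseGamma(
        concentration=concentration_scaling * torch.ones_like(beta), rate=beta
    ).log_prob(rts)

    # Sum the log-likelihoods over iid trials and add the log-probability of the
    # `prior` defined above to obtain the unnormalized log-posterior.
    return (log_probs_choices + log_probs_rts).sum(dim=-1) + prior.log_prob(theta)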

Train MNLE and generate samples via MCMC

# Training data
num_simulations = 10000
# For training the MNLE emulator we need to define a proposal distribution;
# the prior is a good choice.
proposal = prior
theta = proposal.sample((num_simulations,))
x = mixed_simulator(theta)

# Train MNLE and obtain MCMC-based posterior.
trainer = MNLE()
estimator = trainer.append_simulations(theta, x).train()
 Neural network successfully converged after 75 epochs.
# Build posterior from the trained estimator and prior.
mnle_posterior = trainer.build_posterior(prior=prior)

mnle_samples = mnle_posterior.sample((num_samples,), x=x_o, **mcmc_kwargs)

Compare MNLE and reference posterior

# Plot them in one pairplot as contours (obtained via KDE on the samples).
fig, ax = pairplot(
    [
        prior.sample((1000,)),
        true_samples,
        mnle_samples,
    ],
    points=theta_o,
    diag="kde",
    upper="contour",
    upper_kwargs=dict(levels=[0.95]),
    diag_kwargs=dict(bins=100),
    fig_kwargs=dict(
        points_offdiag=dict(marker="*", markersize=10),
        points_colors=["k"],
    ),
    labels=[r"$\beta$", r"$\rho$"],
    figsize=(6, 6),
)

plt.sca(ax[1, 1])
plt.legend(
    ["Prior", "Reference", "MNLE", r"$\theta_o$"],
    frameon=False,
    fontsize=12,
);

[Figure: pairplot comparing samples from the prior, the reference posterior, and the MNLE posterior (10 trials).]

We see that the inferred MNLE posterior closely matches the reference posterior, and that both are quite different from the prior.

Because MNLE training is amortized, we can obtain another posterior given a different observation, potentially with a different number of trials, just by running MCMC again (without re-training MNLE):

Repeat inference with different x_o that contains more trials

num_trials = 50
x_o = mixed_simulator(theta_o.repeat(num_trials, 1))
true_samples = true_posterior.sample((num_samples,), x=x_o, **mcmc_kwargs)
mnle_samples = mnle_posterior.sample((num_samples,), x=x_o, **mcmc_kwargs)
# Plot them in one pairplot as contours (obtained via KDE on the samples).
fig, ax = pairplot(
    [
        prior.sample((1000,)),
        true_samples,
        mnle_samples,
    ],
    points=theta_o,
    diag="kde",
    upper="contour",
    diag_kwargs=dict(bins=100),
    upper_kwargs=dict(levels=[0.95]),
    fig_kwargs=dict(
        points_offdiag=dict(marker="*", markersize=10),
        points_colors=["k"],
    ),
    labels=[r"$\beta$", r"$\rho$"],
    figsize=(6, 6),
)

plt.sca(ax[1, 1])
plt.legend(
    ["Prior", "Reference", "MNLE", r"$\theta_o$"],
    frameon=False,
    fontsize=12,
);

[Figure: pairplot comparing samples from the prior, the reference posterior, and the MNLE posterior (50 trials).]

print("c2st between true and MNLE posterior:", c2st(true_samples, mnle_samples).item())
c2st between true and MNLE posterior: 0.5155000000000001

Again, we can see that the posteriors match nicely. In addition, we observe that the posterior's (epistemic) uncertainty decreases as we increase the number of trials.

Note: MNLE is trained on single-trial data. Theoretically, density estimation is perfectly accurate only in the limit of infinite training data, so training with a finite number of simulations naturally induces a small bias in the density estimator. As we observed above, this bias is so small that we don't really notice it, e.g., the c2st scores were close to 0.5. However, when we increase the number of trials in x_o dramatically (on the order of 1000s), the small per-trial bias can accumulate over the trials and inference with MNLE can become less accurate.
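
To see why, note that for trial-based data the trial-wise log-likelihoods are summed, so a small per-trial approximation error \(\varepsilon_i(\theta)\) of the density estimator accumulates with the number of trials \(N\):

\[
\log \hat{p}(x_o \mid \theta) = \sum_{i=1}^{N} \log \hat{p}(x_i \mid \theta) = \log p(x_o \mid \theta) + \sum_{i=1}^{N} \varepsilon_i(\theta)
\]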

MNLE with experimental conditions

In perceptual decision-making research, it is common to design experiments with varying experimental conditions, e.g., to vary the difficulty of the task. During parameter inference, it can be beneficial to incorporate these experimental conditions.

In MNLE, we are learning an emulator that should be able to generate synthetic experimental data, i.e., reaction times and choices, under different experimental conditions. Thus, to make MNLE work with experimental conditions, we need to include them in the training process, i.e., treat them like auxiliary parameters of the simulator:

# Define a proposal that contains both the priors for the parameters and a
# uniform prior over the experimental condition.
proposal = MultipleIndependent(
    [
        Gamma(torch.tensor([1.0]), torch.tensor([0.5])),
        Beta(torch.tensor([2.0]), torch.tensor([2.0])),
        BoxUniform(torch.tensor([0.0]), torch.tensor([1.0])),
    ],
    validate_args=False,
)

# Define a simulator wrapper in which the experimental conditions are contained
# in theta and passed to the simulator.
def sim_wrapper(theta_and_conditions):
    # simulate with experiment conditions
    return mixed_simulator(
        # we assume the first two parameters are beta and rho
        theta=theta_and_conditions[:, :2],
        # the third column is the experimental condition (concentration scaling)
        concentration_scaling=theta_and_conditions[:, 2:],
    )

# Simulated data
num_simulations = 10000
num_samples = 1000
theta = proposal.sample((num_simulations,))
x = sim_wrapper(theta)
assert x.shape == (num_simulations, 2)

# simulate observed data and define ground truth parameters
num_trials = 10
# draw one ground truth parameter
theta_o = proposal.sample((1,))[:, :2]
# draw num_trials many different conditions
conditions = proposal.sample((num_trials,))[:, 2:]
# Theta is repeated for each trial, conditions are different for each trial.
theta_and_conditions_o = torch.cat((theta_o.repeat(num_trials, 1), conditions), dim=1)
x_o = sim_wrapper(theta_and_conditions_o)

Obtain ground truth posterior via MCMC

We obtain a ground-truth posterior via MCMC by using the analytical Binomial-Gamma likelihood as before.

For that, we first define the actual prior, i.e., the distribution over the parameters we want to infer (not the proposal, which additionally contains the uniform distribution over experimental conditions).

Additionally, we pass the entire batch of i.i.d. data x_o and the matching batch of i.i.d. conditions.

prior = MultipleIndependent(
    [
        Gamma(torch.tensor([1.0]), torch.tensor([0.5])),
        Beta(torch.tensor([2.0]), torch.tensor([2.0])),
    ],
    validate_args=False,
)
prior_transform = mcmc_transform(prior)

# We can now use the BinomialGammaPotential to obtain a ground-truth
# posterior via MCMC.
true_posterior_samples = MCMCPosterior(
    BinomialGammaPotential(
        prior,
        x_o,
        concentration_scaling=conditions,
    ),
    theta_transform=prior_transform,
    proposal=prior,
    **mcmc_kwargs,
).sample((num_samples,), show_progress_bars=True)

Train MNLE including experimental conditions

Next, we use the combined parameters and conditions (theta) and the corresponding simulated data to train MNLE.

estimator_builder = likelihood_nn(model="mnle", z_score_x=None)  # we don't want to z-score the binary data.
trainer = MNLE(proposal, estimator_builder)
estimator = trainer.append_simulations(theta, x).train()
 Neural network successfully converged after 75 epochs.

Construct conditional potential function

We now have an emulator for the extended simulator, i.e., the one that takes both the model parameters and the experimental condition as input.

To obtain posterior samples conditioned on a particular experimental condition (and on x_o), we need to construct a corresponding potential function that returns the log-likelihood of the model parameters conditioned on that experimental condition.
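
Concretely, for i.i.d. trials with per-trial experimental conditions \(e_i\), the conditioned potential effectively evaluates the learned likelihood trial by trial and sums the log-likelihoods over trials:

\[
\log p(x_o \mid \theta, e_{1:N}) = \sum_{i=1}^{N} \log p(rt_i, c_i \mid \theta, e_i)
\]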

# First, we define the potential function for the complete, unconditional MNLE-likelihood
potential_fn = LikelihoodBasedPotential(estimator, proposal)
# Then, we condition on the experimental conditions.
conditioned_potential_fn = potential_fn.condition_on_theta(
    conditions,  # pass only the conditions, must match the batch of iid data in x_o
    dims_global_theta=[0, 1]  # pass the dimensions in the original theta that correspond to beta and rho
)

# Using this potential function, we can now obtain conditional samples.
mnle_posterior = MCMCPosterior(
    potential_fn=conditioned_potential_fn,  # pass the conditioned potential function
    theta_transform=prior_transform,
    proposal=prior,  # pass the prior, not the proposal.
    **mcmc_kwargs
)
conditional_samples = mnle_posterior.sample((num_samples,), x=x_o)
# Finally, we can compare the ground truth conditional posterior with the
# MNLE-conditional posterior.
fig, ax = pairplot(
    [
        prior.sample((1000,)),
        true_posterior_samples,
        conditional_samples,
    ],
    points=theta_o,
    diag="kde",
    upper="contour",
    diag_kwargs=dict(bins=100),
    upper_kwargs=dict(levels=[0.95]),
    fig_kwargs=dict(
        points_offdiag=dict(marker="*", markersize=10),
        points_colors=["k"],
    ),
    labels=[r"$\beta$", r"$\rho$"],
    figsize=(6, 6),
)

plt.sca(ax[1, 1])
plt.legend(
    ["Prior", "Reference", "MNLE", r"$\theta_o$"],
    frameon=False,
    fontsize=12,
);
print("c2st between true and MNLE posterior:", c2st(true_posterior_samples, conditional_samples).item())
c2st between true and MNLE posterior: 0.551

[Figure: pairplot comparing samples from the prior, the reference posterior, and the MNLE posterior conditioned on the experimental conditions.]

The posteriors match closely, showing that we can indeed condition the trained MNLE likelihood post hoc on different experimental conditions.

Inference with multiple subjects, trials, and conditions

Note that we can also do inference for multiple x_os (e.g., subjects) with varying numbers of trials and experimental conditions - all without retraining the MNLE.

torch.manual_seed(42)
num_subjects = 3
num_trials = [10, 20, 30]
# draw one ground-truth parameter per subject
theta_o = proposal.sample((num_subjects,))[:, :2]
# Note that the trial conditions do not need to be the same across subjects.

# Simulate observed data for all subjects and trials.
x_os = []
conditions = []
for i in range(num_subjects):
    conditions.append(proposal.sample((num_trials[i],))[:, 2:])
    # Theta is repeated for each trial, conditions are different for each trial.
    theta_and_condition = torch.cat((theta_o[i].repeat(num_trials[i], 1), conditions[i]), dim=-1)
    x_os.append(sim_wrapper(theta_and_condition))

# Loop over subjects (vectorized inference over batched x and batched conditions is not supported yet).
posterior_samples = []
for idx in range(num_subjects):
    # condition the potential
    conditioned_potential_fn = potential_fn.condition_on_theta(
        conditions[idx],
        dims_global_theta=[0, 1]
    )

    # pass potential to sampler
    mnle_posterior = MCMCPosterior(
        potential_fn=conditioned_potential_fn,  # pass the conditioned potential function
        theta_transform=prior_transform,
        proposal=prior,  # pass the prior, not the proposal.
        **mcmc_kwargs
    )
    posterior_samples.append(mnle_posterior.sample((num_samples,), x=x_os[idx], show_progress_bars=True))
# Plot the prior and all three subject posteriors in one pairplot.

fig, ax = pairplot(
    [prior.sample((1000,))] + posterior_samples,
    # points=theta_o,
    diag="kde",
    upper="contour",
    diag_kwargs=dict(bins=100),
    upper_kwargs=dict(levels=[0.95]),
    fig_kwargs=dict(
        points_offdiag=dict(marker="*", markersize=10),
        points_colors=["k"],
    ),
    labels=[r"$\beta$", r"$\rho$"],
    figsize=(10, 10),
)

plt.sca(ax[1, 1])
plt.legend(
    ["prior"] + [f"Subject {idx+1}" for idx in range(num_subjects)],
    frameon=False,
    fontsize=12,
);

[Figure: pairplot of the prior and the three subject posteriors.]

Note how the posteriors become narrower with an increasing number of trials (subject 1: 10 trials vs. subject 3: 30 trials).