Customizing the density estimator¶
sbi lets you specify a custom density estimator for each of the implemented methods. For all options, see the API reference.
Changing the type of density estimator¶
One option is to use one of the preconfigured density estimators by passing a string to the density_estimator keyword argument of the inference object (NPE or NLE), e.g., "maf" for a Masked Autoregressive Flow or "nsf" for a Neural Spline Flow, each with default hyperparameters.
import torch
from sbi.inference import NPE, NRE
from sbi.utils import BoxUniform
prior = BoxUniform(torch.zeros(2), torch.ones(2))
inference = NPE(prior=prior, density_estimator="maf")
In the case of NRE, the argument is called classifier:
inference = NRE(prior=prior, classifier="resnet")
Changing hyperparameters of density estimators¶
Alternatively, you can use a set of utility functions to configure a density estimator yourself, e.g., to use a MAF with hyperparameters chosen for your problem at hand.
Here, because we want to use NPE, we specify a neural network targeting the posterior (using the utility function posterior_nn). In this example, we will create a neural spline flow ('nsf') with 60 hidden units and 3 transform layers:
# For NLE: likelihood_nn(). For NRE: classifier_nn()
from sbi.neural_nets import posterior_nn
density_estimator_build_fun = posterior_nn(
model="nsf", hidden_features=60, num_transforms=3
)
inference = NPE(prior=prior, density_estimator=density_estimator_build_fun)
It is also possible to pass an embedding_net to posterior_nn() to automatically learn summary statistics from high-dimensional simulation outputs. You can find a more detailed tutorial on this in 04_embedding_networks.
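As a rough illustration of what such an embedding net might look like, the sketch below defines a small fully connected PyTorch module that compresses a simulation output into a handful of learned summary features. The class name SummaryNet, the input dimension of 100, and the 8 output features are hypothetical choices made for this example only; they do not come from sbi.

```python
import torch
from torch import nn


class SummaryNet(nn.Module):
    """Hypothetical embedding net: maps a 100-dim x to 8 summary features."""

    def __init__(self, x_dim: int = 100, num_summaries: int = 8):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(x_dim, 64),
            nn.ReLU(),
            nn.Linear(64, num_summaries),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x has shape (batch, x_dim); the output has shape (batch, num_summaries).
        return self.net(x)


embedding_net = SummaryNet()
x_batch = torch.randn(5, 100)  # stand-in for five simulation outputs
summaries = embedding_net(x_batch)
summary_shape = tuple(summaries.shape)
```

A module like this would then be handed to posterior_nn() through its embedding_net argument, so that its weights are trained jointly with the density estimator.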
Building new density estimators from scratch¶
Finally, it is also possible to implement your own density estimator from scratch, e.g., to include embedding nets that preprocess the data, or to use a density estimator architecture of your choice.
For this, the density_estimator argument needs to be a function that takes batches of theta and x as arguments and then constructs the density estimator once the first set of simulations has been generated. Our factory functions in sbi/neural_nets/factory.py return such a function.
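The factory pattern described above can be sketched in plain PyTorch as a closure: hyperparameters are fixed when the factory is called, while the network is only built once batches of theta and x are available and their dimensions can be read off. The names make_build_fn and build_fn are hypothetical, and the returned network is a placeholder rather than a real density estimator.

```python
import torch
from torch import nn


def make_build_fn(hidden_features: int = 32):
    """Hypothetical factory: fixes hyperparameters now, defers construction."""

    def build_fn(batch_theta: torch.Tensor, batch_x: torch.Tensor) -> nn.Module:
        # Dimensions are only known once simulations exist.
        theta_dim = batch_theta.shape[1]
        x_dim = batch_x.shape[1]
        # Placeholder network (NOT a density estimator): a real build
        # function would return a DensityEstimator subclass here.
        return nn.Sequential(
            nn.Linear(x_dim, hidden_features),
            nn.ReLU(),
            nn.Linear(hidden_features, theta_dim),
        )

    return build_fn


build_fn = make_build_fn(hidden_features=16)
# Stand-in batches: 10 simulations, 2 parameters, 5 observed dimensions.
net = build_fn(torch.zeros(10, 2), torch.zeros(10, 5))
out_shape = tuple(net(torch.zeros(10, 5)).shape)
```

Deferring construction this way means the same build function can be reused across problems with different parameter and data dimensions.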
The returned density_estimator object needs to be a subclass of DensityEstimator, which requires implementing three methods:
log_prob(input, condition, **kwargs): Return the log probabilities of the inputs given a condition or multiple (i.e., batched) conditions.
loss(input, condition, **kwargs): Return the loss for training the density estimator.
sample(sample_shape, condition, **kwargs): Return samples from the density estimator.
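To make the three methods above concrete, here is a minimal sketch of a conditional diagonal Gaussian written in plain PyTorch. It is an illustration under stated assumptions, not sbi's implementation: the class GaussianEstimator is hypothetical, it subclasses nn.Module rather than sbi's DensityEstimator so that the block runs without sbi, and it uses simple (batch, dim) tensor shapes that may differ from the shape conventions sbi expects.

```python
import math

import torch
from torch import nn


class GaussianEstimator(nn.Module):
    """Sketch of a conditional diagonal Gaussian q(input | condition).

    Assumption: in real use this would subclass sbi's DensityEstimator
    instead of nn.Module; shapes here are plain (batch, dim).
    """

    def __init__(self, input_dim: int, condition_dim: int, hidden: int = 32):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(condition_dim, hidden),
            nn.ReLU(),
            nn.Linear(hidden, 2 * input_dim),
        )

    def _params(self, condition):
        # Split the network output into a mean and a log standard deviation.
        mean, log_std = self.net(condition).chunk(2, dim=-1)
        return mean, log_std

    def log_prob(self, input, condition, **kwargs):
        mean, log_std = self._params(condition)
        z = (input - mean) / log_std.exp()
        per_dim = -0.5 * z**2 - log_std - 0.5 * math.log(2 * math.pi)
        return per_dim.sum(dim=-1)

    def loss(self, input, condition, **kwargs):
        # Negative log-likelihood, one value per batch element.
        return -self.log_prob(input, condition)

    def sample(self, sample_shape, condition, **kwargs):
        mean, log_std = self._params(condition)
        eps = torch.randn(*sample_shape, *mean.shape)
        return mean + log_std.exp() * eps


theta = torch.randn(4, 2)  # stand-in parameter batch
x = torch.randn(4, 3)  # stand-in data batch
estimator = GaussianEstimator(input_dim=2, condition_dim=3)
log_probs = estimator.log_prob(theta, x)  # shape (4,)
samples = estimator.sample((3,), x)  # shape (3, 4, 2)
```

A Gaussian is used here only because its log density and sampling are easy to write down; any architecture that provides the same three methods would fit the same slot.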
See more information on the Reference API page.