Can I use a custom prior with sbi?
`sbi` works with torch distributions only, so we recommend using those whenever possible. For example, if you are used to using `scipy.stats` distributions as priors, we recommend switching to the corresponding `torch.distributions`; most common distributions are implemented there.
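If, for instance, you have been using a `scipy.stats` gamma prior, a minimal sketch of the torch counterpart could look like this (the parameter values are placeholders, not a recommendation):

```python
import torch
from torch.distributions import Gamma

# Torch counterpart of a scipy.stats gamma prior; parameter values are illustrative.
prior = Gamma(concentration=torch.tensor([2.0]), rate=torch.tensor([1.0]))

theta = prior.sample((100,))   # shape (100, 1)
log_p = prior.log_prob(theta)  # shape (100, 1)
```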
In case you want to use a custom prior that is not in the set of common distributions, that is possible as well: you need to write a prior class that mimics the behaviour of a `torch.distributions.Distribution` class. `sbi` will then wrap this class to make it a fully functional torch `Distribution`.
Essentially, the class needs two methods:

- `.sample(sample_shape)`, where `sample_shape` is a shape tuple, e.g., `(n,)`, which returns a batch of `n` samples, e.g., of shape `(n, 2)` for a two-dimensional prior.
- `.log_prob(value)`, which returns the "log probs" of parameters under the prior, e.g., for a batch of `n` parameters with shape `(n, ndims)` it should return a log-probs array of shape `(n,)`.
For `sbi` > 0.17.2, this could look like the following:
```python
import torch
from torch import Tensor

from sbi.utils import BoxUniform


class CustomUniformPrior:
    """User-defined numpy uniform prior.

    Custom prior with user-defined valid .sample and .log_prob methods.
    """

    def __init__(self, lower: Tensor, upper: Tensor, return_numpy: bool = False):
        self.lower = lower
        self.upper = upper
        self.dist = BoxUniform(lower, upper)
        self.return_numpy = return_numpy

    def sample(self, sample_shape=torch.Size([])):
        samples = self.dist.sample(sample_shape)
        return samples.numpy() if self.return_numpy else samples

    def log_prob(self, values):
        if self.return_numpy:
            values = torch.as_tensor(values)
        log_probs = self.dist.log_prob(values)
        return log_probs.numpy() if self.return_numpy else log_probs
```
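For illustration, the class above can already be used like a prior on its own (a quick sanity check; the shapes follow the conventions described above):

```python
# Illustrative check of the custom prior defined above.
custom_prior = CustomUniformPrior(torch.zeros(2), torch.ones(2))

theta = custom_prior.sample((5,))     # shape (5, 2)
log_p = custom_prior.log_prob(theta)  # shape (5,)
```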
Once you have such a class, you can wrap it into a `Distribution` using the `process_prior` function `sbi` provides:
```python
from sbi.utils import process_prior

custom_prior = CustomUniformPrior(torch.zeros(2), torch.ones(2))
prior, *_ = process_prior(custom_prior)  # Keeping only the first return.

# use this wrapped prior in sbi...
```
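For example, the wrapped prior can then be passed to an inference method as usual. The following is a minimal sketch, assuming the standard `SNPE` workflow and a hypothetical toy simulator (not part of `sbi`):

```python
import torch

from sbi.inference import SNPE

# Hypothetical toy simulator, for illustration only.
def simulator(theta: torch.Tensor) -> torch.Tensor:
    return theta + 0.1 * torch.randn_like(theta)

# `prior` is the wrapped prior from the snippet above.
theta = prior.sample((1000,))
x = simulator(theta)

inference = SNPE(prior=prior)
density_estimator = inference.append_simulations(theta, x).train()
posterior = inference.build_posterior(density_estimator)
```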
In `sbi` it is sometimes required to check the support of the prior, e.g., when the prior support is bounded and one wants to reject samples from the posterior density estimator that lie outside the prior support. In torch `Distributions` this is handled automatically; however, when using a custom prior it is not. Thus, if your prior has bounded support (like the one above), it makes sense to pass the bounds to the wrapper function such that `sbi` can pass them on to torch `Distributions`:
```python
from sbi.utils import process_prior

custom_prior = CustomUniformPrior(torch.zeros(2), torch.ones(2))
prior, *_ = process_prior(
    custom_prior,
    custom_prior_wrapper_kwargs=dict(
        lower_bound=torch.zeros(2), upper_bound=torch.ones(2)
    ),
)

# use this wrapped prior in sbi...
```
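As a quick, illustrative sanity check, samples drawn from the wrapped prior should now lie within the bounds that were passed:

```python
# Samples from the wrapped prior should respect the bounds passed above.
theta = prior.sample((100,))
assert torch.all(theta >= 0.0) and torch.all(theta <= 1.0)
```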
Note that in `custom_prior_wrapper_kwargs` you can pass additional arguments for the wrapper, e.g., `validate_args` or `arg_constraints`; see the torch `Distribution` documentation for more details.
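For instance, to switch off torch's argument validation in the wrapper, you could forward `validate_args` alongside the bounds (a sketch; whether this is appropriate depends on your use case):

```python
# Forward additional keyword arguments to the wrapper via custom_prior_wrapper_kwargs.
prior, *_ = process_prior(
    custom_prior,
    custom_prior_wrapper_kwargs=dict(
        lower_bound=torch.zeros(2),
        upper_bound=torch.ones(2),
        validate_args=False,
    ),
)
```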
If you are running `sbi` < 0.17.2 and use `SNLE`, the code above will produce a `NotImplementedError` (see #581). In this case, you need to update to a newer version of `sbi` or use `SNPE` instead.