Skip to content

Posteriors

DirectPosterior

Bases: NeuralPosterior

Posterior based on neural networks that directly estimate the posterior (NPE).

NPE trains a neural network to directly approximate the posterior distribution. However, for bounded priors, the neural network can have leakage: it puts non-zero mass in regions where the prior is zero. The DirectPosterior class wraps the trained network to deal with these cases.

Specifically, this class offers the following functionality:

  • correct the calculation of the log probability such that it compensates for the leakage.
  • reject samples that lie outside of the prior bounds.

This class cannot be used in combination with NLE or NRE.

Source code in sbi/inference/posteriors/direct_posterior.py
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
class DirectPosterior(NeuralPosterior):
    r"""Posterior based on neural networks that directly estimate the posterior (NPE).

    NPE trains a neural network to directly approximate the posterior distribution.
    However, for bounded priors, the neural network can have leakage: it puts non-zero
    mass in regions where the prior is zero. The `DirectPosterior` class wraps the
    trained network to deal with these cases.

    Specifically, this class offers the following functionality:

    - correct the calculation of the log probability such that it compensates for the
      leakage.
    - reject samples that lie outside of the prior bounds.

    This class cannot be used in combination with NLE or NRE.
    """

    def __init__(
        self,
        posterior_estimator: ConditionalDensityEstimator,
        prior: Distribution,
        max_sampling_batch_size: int = 10_000,
        device: Optional[Union[str, torch.device]] = None,
        x_shape: Optional[torch.Size] = None,
        enable_transform: bool = True,
    ):
        """
        Args:
            prior: Prior distribution with `.log_prob()` and `.sample()`.
            posterior_estimator: The trained neural posterior.
            max_sampling_batch_size: Default batch size of samples being drawn from
                the proposal at every iteration. Used by `.sample()` and
                `.sample_batched()` whenever they are called without an explicit
                `max_sampling_batch_size`.
            device: Training device, e.g., "cpu", "cuda" or "cuda:0". If None,
                `potential_fn.device` is used.
            x_shape: Deprecated, should not be passed.
            enable_transform: Whether to transform parameters to unconstrained space
                during MAP optimization. When False, an identity transform will be
                returned for `theta_transform`.
        """
        # Because `DirectPosterior` does not take the `potential_fn` as input, it
        # builds it itself. The `potential_fn` and `theta_transform` are used only for
        # obtaining the MAP.
        check_prior(prior)
        self.enable_transform = enable_transform
        self.x_shape = x_shape
        potential_fn, theta_transform = posterior_estimator_based_potential(
            posterior_estimator,
            prior,
            x_o=None,
            enable_transform=enable_transform,
        )

        super().__init__(
            potential_fn=potential_fn,
            theta_transform=theta_transform,
            device=device,
            x_shape=x_shape,
        )

        self.device = device
        self.prior = prior
        self.posterior_estimator = posterior_estimator

        self.max_sampling_batch_size = max_sampling_batch_size
        # Cached leakage correction factor for `self.default_x`; computed lazily by
        # `leakage_correction()`.
        self._leakage_density_correction_factor = None

        self._purpose = """It samples the posterior network and rejects samples that
            lie outside of the prior bounds."""

    def to(self, device: Union[str, torch.device]) -> None:
        """Move posterior_estimator, prior and x_o to device.

        Changes the device attribute, re-instantiates the
        posterior, and resets the default x.

        Args:
            device: device where to move the posterior to.

        Raises:
            ValueError: If the prior or the posterior estimator does not provide a
                `.to(device)` method.
        """
        self.device = device
        if hasattr(self.prior, "to"):
            self.prior.to(device)  # type: ignore
        else:
            raise ValueError("""Prior has no attribute to(device).""")
        if hasattr(self.posterior_estimator, "to"):
            self.posterior_estimator.to(device)
        else:
            raise ValueError("""Posterior estimator has no attribute to(device).""")

        # Rebuild potential and transform on the new device.
        potential_fn, theta_transform = posterior_estimator_based_potential(
            self.posterior_estimator,
            self.prior,
            x_o=None,
            enable_transform=self.enable_transform,
        )
        x_o = None
        if hasattr(self, "_x") and (self._x is not None):
            x_o = self._x.to(device)

        super().__init__(
            potential_fn=potential_fn,
            theta_transform=theta_transform,
            device=device,
            x_shape=self.x_shape,
        )
        # super().__init__ erases the self._x, so we need to set it again
        if x_o is not None:
            self.set_default_x(x_o)

    def sample(
        self,
        sample_shape: Shape = torch.Size(),
        x: Optional[Tensor] = None,
        max_sampling_batch_size: Optional[int] = None,
        show_progress_bars: bool = True,
        reject_outside_prior: bool = True,
        max_sampling_time: Optional[float] = None,
        return_partial_on_timeout: bool = False,
    ) -> Tensor:
        r"""Draw samples from the approximate posterior distribution $p(\theta|x)$.

        Args:
            sample_shape: Desired shape of samples that are drawn from posterior. If
                sample_shape is multidimensional we simply draw `sample_shape.numel()`
                samples and then reshape into the desired shape.
            x: Conditioning observation $x_o$. If not provided, uses the default `x`
                set via `.set_default_x()`.
            max_sampling_batch_size: Maximum batch size for rejection sampling. If
                None (default), falls back to the value passed at initialization
                (`self.max_sampling_batch_size`).
            show_progress_bars: Whether to show sampling progress monitor.
            reject_outside_prior: If True (default), rejection sampling is used to
                ensure samples lie within the prior support. If False, samples are drawn
                directly from the neural density estimator without rejection, which is
                faster but may include samples outside the prior support.
            max_sampling_time: Optional maximum allowed sampling time in seconds.
                If exceeded, sampling is aborted and a RuntimeError is raised. Only
                applies when `reject_outside_prior=True` (no effect otherwise since
                direct sampling is fast).
            return_partial_on_timeout: If True and `max_sampling_time` is exceeded,
                return the samples collected so far instead of raising a RuntimeError.
                A warning will be issued. Only applies when `reject_outside_prior=True`
                (default).
        """
        num_samples = torch.Size(sample_shape).numel()
        x = self._x_else_default_x(x)
        x = reshape_to_batch_event(
            x, event_shape=self.posterior_estimator.condition_shape
        )
        if x.shape[0] > 1:
            raise ValueError(
                ".sample() supports only `batchsize == 1`. If you intend "
                "to sample multiple observations, use `.sample_batched()`. "
                "If you intend to sample i.i.d. observations, set up the "
                "posterior density estimator with an appropriate permutation "
                "invariant embedding net."
            )

        # `None` means: use the instance-level default set at `__init__`.
        max_sampling_batch_size = (
            self.max_sampling_batch_size
            if max_sampling_batch_size is None
            else max_sampling_batch_size
        )

        if reject_outside_prior:
            # Normal rejection behavior.
            samples = rejection.accept_reject_sample(
                proposal=self.posterior_estimator.sample,
                accept_reject_fn=lambda theta: within_support(self.prior, theta),
                num_samples=num_samples,
                show_progress_bars=show_progress_bars,
                max_sampling_batch_size=max_sampling_batch_size,
                proposal_sampling_kwargs={"condition": x},
                alternative_method="build_posterior(..., sample_with='mcmc')",
                max_sampling_time=max_sampling_time,
                return_partial_on_timeout=return_partial_on_timeout,
            )[0]
        else:
            # Bypass rejection sampling entirely.
            samples = self.posterior_estimator.sample(
                torch.Size([num_samples]),
                condition=x,
            )
            warn_if_outside_prior_support(self.prior, samples[:, 0])

        return samples[:, 0]  # Remove batch dimension.

    def sample_batched(
        self,
        sample_shape: Shape,
        x: Tensor,
        max_sampling_batch_size: Optional[int] = None,
        show_progress_bars: bool = True,
        reject_outside_prior: bool = True,
        max_sampling_time: Optional[float] = None,
        return_partial_on_timeout: bool = False,
    ) -> Tensor:
        r"""Draw samples from the posteriors for a batch of different xs.

        Given a batch of observations `[x_1, ..., x_B]`, this method samples from
        posteriors $p(\theta|x_1), \ldots, p(\theta|x_B)$ in a vectorized manner.

        Args:
            sample_shape: Desired shape of samples that are drawn from the posterior
                given every observation.
            x: A batch of observations, of shape `(batch_dim, event_shape_x)`.
                `batch_dim` corresponds to the number of observations to be drawn.
            max_sampling_batch_size: Maximum batch size for rejection sampling. If
                None (default), falls back to the value passed at initialization
                (`self.max_sampling_batch_size`).
            show_progress_bars: Whether to show sampling progress monitor.
            reject_outside_prior: If True (default), rejection sampling is used to
                ensure samples lie within the prior support. If False, samples are drawn
                directly from the neural density estimator without rejection, which is
                faster but may include samples outside the prior support.
            max_sampling_time: Optional maximum allowed sampling time in seconds.
                If exceeded, sampling is aborted and a RuntimeError is raised. Only
                applies when `reject_outside_prior=True`.
            return_partial_on_timeout: If True and `max_sampling_time` is exceeded,
                return the samples collected so far instead of raising a RuntimeError.
                A warning will be issued. Only applies when `reject_outside_prior=True`.

        Returns:
            Samples from the posteriors of shape (*sample_shape, B, *input_shape)
        """
        num_samples = torch.Size(sample_shape).numel()
        condition_shape = self.posterior_estimator.condition_shape
        x = reshape_to_batch_event(x, event_shape=condition_shape)
        num_xos = x.shape[0]

        # throw warning if num_x * num_samples is too large
        if num_xos * num_samples > 2**21:  # 2 million-ish
            warnings.warn(
                f"Note that for batched sampling, the direct posterior sampling "
                f"generates {num_xos} * {num_samples} = {num_xos * num_samples} "
                "samples. This can be slow and memory-intensive. Consider "
                "reducing the number of samples or batch size.",
                stacklevel=2,
            )

        # `None` means: use the instance-level default set at `__init__`.
        max_sampling_batch_size = (
            self.max_sampling_batch_size
            if max_sampling_batch_size is None
            else max_sampling_batch_size
        )

        # Adjust max_sampling_batch_size to avoid excessive memory usage
        if max_sampling_batch_size * num_xos > 100_000:
            capped = max(1, 100_000 // num_xos)
            warnings.warn(
                f"Capping max_sampling_batch_size from {max_sampling_batch_size} "
                f"to {capped} to avoid excessive memory usage.",
                stacklevel=2,
            )
            max_sampling_batch_size = capped

        if reject_outside_prior:
            # Normal rejection behavior.
            samples = rejection.accept_reject_sample(
                proposal=self.posterior_estimator.sample,
                accept_reject_fn=lambda theta: within_support(self.prior, theta),
                num_samples=num_samples,
                show_progress_bars=show_progress_bars,
                max_sampling_batch_size=max_sampling_batch_size,
                proposal_sampling_kwargs={"condition": x},
                alternative_method="build_posterior(..., sample_with='mcmc')",
                max_sampling_time=max_sampling_time,
                return_partial_on_timeout=return_partial_on_timeout,
            )[0]
        else:
            # Bypass rejection sampling entirely.
            samples = self.posterior_estimator.sample(
                torch.Size([num_samples]),
                condition=x,
            )
            warn_if_outside_prior_support(self.prior, samples)

        return samples

    def log_prob(
        self,
        theta: Tensor,
        x: Optional[Tensor] = None,
        norm_posterior: bool = True,
        track_gradients: bool = False,
        leakage_correction_params: Optional[dict] = None,
    ) -> Tensor:
        r"""Returns the log-probability of the posterior $p(\theta|x)$.

        Args:
            theta: Parameters $\theta$.
            x: Conditioning observation $x_o$. If not provided, uses the default `x`
                set via `.set_default_x()`.
            norm_posterior: Whether to enforce a normalized posterior density.
                Renormalization of the posterior is useful when some
                probability falls out or leaks out of the prescribed prior support.
                The normalizing factor is calculated via rejection sampling, so if you
                need speedier but unnormalized log posterior estimates set here
                `norm_posterior=False`. The returned log posterior is set to
                -∞ outside of the prior support regardless of this setting.
            track_gradients: Whether the returned tensor supports tracking gradients.
                This can be helpful for e.g. sensitivity analysis, but increases memory
                consumption.
            leakage_correction_params: A `dict` of keyword arguments to override the
                default values of `leakage_correction()`. Possible options are:
                `num_rejection_samples`, `force_update`, `show_progress_bars`, and
                `rejection_sampling_batch_size`.
                These parameters only have an effect if `norm_posterior=True`.

        Returns:
            `(len(θ),)`-shaped log posterior probability $\log p(\theta|x)$ for θ in the
            support of the prior, -∞ (corresponding to 0 probability) outside.
        """
        x = self._x_else_default_x(x)

        theta = ensure_theta_batched(torch.as_tensor(theta))
        theta_density_estimator = reshape_to_sample_batch_event(
            theta, theta.shape[1:], leading_is_sample=True
        )
        x_density_estimator = reshape_to_batch_event(
            x, event_shape=self.posterior_estimator.condition_shape
        )
        if x_density_estimator.shape[0] > 1:
            raise ValueError(
                ".log_prob() supports only `batchsize == 1`. If you intend "
                "to evaluate given multiple observations, use `.log_prob_batched()`. "
                "If you intend to evaluate given i.i.d. observations, set up the "
                "posterior density estimator with an appropriate permutation "
                "invariant embedding net."
            )

        self.posterior_estimator.eval()

        with torch.set_grad_enabled(track_gradients):
            # Evaluate on device, move back to cpu for comparison with prior.
            unnorm_log_prob = self.posterior_estimator.log_prob(
                theta_density_estimator, condition=x_density_estimator
            )
            # `log_prob` supports only a single observation (i.e. `batchsize==1`).
            # We now remove this additional dimension.
            unnorm_log_prob = unnorm_log_prob.squeeze(dim=1)

            # Force probability to be zero outside prior support.
            in_prior_support = within_support(self.prior, theta)

            masked_log_prob = torch.where(
                in_prior_support,
                unnorm_log_prob,
                torch.tensor(float("-inf"), dtype=torch.float32, device=self._device),
            )

            if leakage_correction_params is None:
                leakage_correction_params = dict()  # use defaults
            log_factor = (
                log(self.leakage_correction(x=x, **leakage_correction_params))
                if norm_posterior
                else 0
            )

            return masked_log_prob - log_factor

    def log_prob_batched(
        self,
        theta: Tensor,
        x: Tensor,
        norm_posterior: bool = True,
        track_gradients: bool = False,
        leakage_correction_params: Optional[dict] = None,
    ) -> Tensor:
        r"""Evaluate the log-probabilities for a batch of observations.

        Given a batch of observations `[x_1, ..., x_B]` and a batch of parameters
        $[\theta_1, ..., \theta_B]$, this function evaluates the log-probabilities
        of the posteriors $p(\theta_1|x_1)$, ..., $p(\theta_B|x_B)$ in a batched
        (i.e. vectorized) manner.

        Args:
            theta: Batch of parameters $\theta$ of shape
                `(*sample_shape, batch_dim, *theta_shape)`.
            x: Batch of observations $x$ of shape
                `(batch_dim, *condition_shape)`.
            norm_posterior: Whether to enforce a normalized posterior density.
                Renormalization of the posterior is useful when some
                probability falls out or leaks out of the prescribed prior support.
                The normalizing factor is calculated via rejection sampling, so if you
                need speedier but unnormalized log posterior estimates set here
                `norm_posterior=False`. The returned log posterior is set to
                -∞ outside of the prior support regardless of this setting.
            track_gradients: Whether the returned tensor supports tracking gradients.
                This can be helpful for e.g. sensitivity analysis, but increases memory
                consumption.
            leakage_correction_params: A `dict` of keyword arguments to override the
                default values of `leakage_correction()`. Possible options are:
                `num_rejection_samples`, `force_update`, `show_progress_bars`, and
                `rejection_sampling_batch_size`.
                These parameters only have an effect if `norm_posterior=True`.

        Returns:
            `(len(θ), B)`-shaped log posterior probability $\log p(\theta|x)$ for θ
            in the support of the prior, -∞ (corresponding to 0 probability) outside.
        """

        theta = ensure_theta_batched(torch.as_tensor(theta))
        event_shape = self.posterior_estimator.input_shape
        theta_density_estimator = reshape_to_sample_batch_event(
            theta, event_shape, leading_is_sample=True
        )
        x_density_estimator = reshape_to_batch_event(
            x, event_shape=self.posterior_estimator.condition_shape
        )

        self.posterior_estimator.eval()

        with torch.set_grad_enabled(track_gradients):
            # Evaluate on device, move back to cpu for comparison with prior.
            unnorm_log_prob = self.posterior_estimator.log_prob(
                theta_density_estimator, condition=x_density_estimator
            )

            # Force probability to be zero outside prior support.
            in_prior_support = within_support(self.prior, theta)

            masked_log_prob = torch.where(
                in_prior_support,
                unnorm_log_prob,
                torch.tensor(float("-inf"), dtype=torch.float32, device=self._device),
            )

            if leakage_correction_params is None:
                leakage_correction_params = dict()  # use defaults
            log_factor = (
                log(self.leakage_correction(x=x, **leakage_correction_params))
                if norm_posterior
                else 0
            )

            return masked_log_prob - log_factor

    @torch.no_grad()
    def leakage_correction(
        self,
        x: Tensor,
        num_rejection_samples: int = 10_000,
        force_update: bool = False,
        show_progress_bars: bool = False,
        rejection_sampling_batch_size: int = 10_000,
    ) -> Tensor:
        r"""Return leakage correction factor for a leaky posterior density estimate.

        The factor is estimated from the acceptance probability during rejection
        sampling from the posterior.

        This is to avoid re-estimating the acceptance probability from scratch
        whenever `log_prob` is called and `norm_posterior=True`. Here, it
        is estimated only once for `self.default_x` and saved for later. We
        re-evaluate only whenever a new `x` is passed.

        Arguments:
            x: Conditioning observation at which to estimate the acceptance rate. If
                it matches `self.default_x`, the cached factor is reused.
            num_rejection_samples: Number of samples used to estimate correction factor.
            force_update: Whether to re-estimate (and re-cache) the factor for
                `self.default_x` even if a cached value exists.
            show_progress_bars: Whether to show a progress bar during sampling.
            rejection_sampling_batch_size: Batch size for rejection sampling.

        Returns:
            Saved or newly-estimated correction factor (as a scalar `Tensor`).
        """

        def acceptance_at(x: Tensor) -> Tensor:
            # `reshape_to_batch_event` brings `x` into the `(batch, *event)` layout
            # expected by the density estimator; index [1] of the returned tuple is
            # the acceptance rate.
            return rejection.accept_reject_sample(
                proposal=self.posterior_estimator.sample,
                accept_reject_fn=lambda theta: within_support(self.prior, theta),
                num_samples=num_rejection_samples,
                show_progress_bars=show_progress_bars,
                sample_for_correction_factor=True,
                max_sampling_batch_size=rejection_sampling_batch_size,
                proposal_sampling_kwargs={
                    "condition": reshape_to_batch_event(
                        x, event_shape=self.posterior_estimator.condition_shape
                    )
                },
            )[1]

        # Check if the provided x matches the default x (short-circuit on identity).
        is_new_x = self.default_x is None or (
            x is not self.default_x and (x != self.default_x).any()
        )

        not_saved_at_default_x = self._leakage_density_correction_factor is None

        if is_new_x:  # Calculate at x; don't save.
            return acceptance_at(x)
        elif not_saved_at_default_x or force_update:  # Calculate at default_x; save.
            assert self.default_x is not None
            self._leakage_density_correction_factor = acceptance_at(self.default_x)

        return self._leakage_density_correction_factor  # type: ignore

    def map(
        self,
        x: Optional[Tensor] = None,
        num_iter: int = 1_000,
        num_to_optimize: int = 100,
        learning_rate: float = 0.01,
        init_method: Union[str, Tensor] = "posterior",
        num_init_samples: int = 1_000,
        save_best_every: int = 10,
        show_progress_bars: bool = False,
        force_update: bool = False,
    ) -> Tensor:
        r"""Returns the maximum-a-posteriori estimate (MAP).

        The method can be interrupted (Ctrl-C) when the user sees that the
        log-probability converges. The best estimate will be saved in `self._map` and
        can be accessed with `self.map()`. The MAP is obtained by running gradient
        ascent from a given number of starting positions (samples from the posterior
        with the highest log-probability). After the optimization is done, we select the
        parameter set that has the highest log-probability after the optimization.

        Warning: The default values used by this function are not well-tested. They
        might require hand-tuning for the problem at hand.

        For developers: if the prior is a `BoxUniform`, we carry out the optimization
        in unbounded space and transform the result back into bounded space.

        Args:
            x: Deprecated - use `.set_default_x()` prior to `.map()`.
            num_iter: Number of optimization steps that the algorithm takes
                to find the MAP.
            learning_rate: Learning rate of the optimizer.
            init_method: How to select the starting parameters for the optimization. If
                it is a string, it can be either [`posterior`, `prior`], which samples
                the respective distribution `num_init_samples` times. If it is a
                tensor, the tensor will be used as init locations.
            num_init_samples: Draw this number of samples from the posterior and
                evaluate the log-probability of all of them.
            num_to_optimize: From the drawn `num_init_samples`, use the
                `num_to_optimize` with highest log-probability as the initial points
                for the optimization.
            save_best_every: The best log-probability is computed, saved in the
                `map`-attribute, and printed every `save_best_every`-th iteration.
                Computing the best log-probability creates a significant overhead
                (thus, the default is `10`.)
            show_progress_bars: Whether to show a progressbar during sampling from the
                posterior.
            force_update: Whether to re-calculate the MAP when x is unchanged and
                have a cached value.

        Returns:
            The MAP estimate.
        """
        return super().map(
            x=x,
            num_iter=num_iter,
            num_to_optimize=num_to_optimize,
            learning_rate=learning_rate,
            init_method=init_method,
            num_init_samples=num_init_samples,
            save_best_every=save_best_every,
            show_progress_bars=show_progress_bars,
            force_update=force_update,
        )

__init__(posterior_estimator, prior, max_sampling_batch_size=10000, device=None, x_shape=None, enable_transform=True)

Parameters:

Name Type Description Default
prior Distribution

Prior distribution with .log_prob() and .sample().

required
posterior_estimator ConditionalDensityEstimator

The trained neural posterior.

required
max_sampling_batch_size int

Batchsize of samples being drawn from the proposal at every iteration.

10000
device Optional[Union[str, device]]

Training device, e.g., “cpu”, “cuda” or “cuda:0”. If None, potential_fn.device is used.

None
x_shape Optional[Size]

Deprecated, should not be passed.

None
enable_transform bool

Whether to transform parameters to unconstrained space during MAP optimization. When False, an identity transform will be returned for theta_transform.

True
Source code in sbi/inference/posteriors/direct_posterior.py
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
def __init__(
    self,
    posterior_estimator: ConditionalDensityEstimator,
    prior: Distribution,
    max_sampling_batch_size: int = 10_000,
    device: Optional[Union[str, torch.device]] = None,
    x_shape: Optional[torch.Size] = None,
    enable_transform: bool = True,
):
    """
    Args:
        prior: Prior distribution with `.log_prob()` and `.sample()`.
        posterior_estimator: The trained neural posterior.
        max_sampling_batch_size: Batchsize of samples being drawn from
            the proposal at every iteration.
        device: Training device, e.g., "cpu", "cuda" or "cuda:0". If None,
            `potential_fn.device` is used.
        x_shape: Deprecated, should not be passed.
        enable_transform: Whether to transform parameters to unconstrained space
            during MAP optimization. When False, an identity transform will be
            returned for `theta_transform`.
    """
    # Because `DirectPosterior` does not take the `potential_fn` as input, it
    # builds it itself. The `potential_fn` and `theta_transform` are used only for
    # obtaining the MAP.
    check_prior(prior)
    self.enable_transform = enable_transform
    self.x_shape = x_shape
    # Build the potential without a fixed observation (`x_o=None`); the
    # observation is supplied later, e.g. via `.set_default_x()`.
    potential_fn, theta_transform = posterior_estimator_based_potential(
        posterior_estimator,
        prior,
        x_o=None,
        enable_transform=enable_transform,
    )

    super().__init__(
        potential_fn=potential_fn,
        theta_transform=theta_transform,
        device=device,
        x_shape=x_shape,
    )

    self.device = device
    self.prior = prior
    self.posterior_estimator = posterior_estimator

    self.max_sampling_batch_size = max_sampling_batch_size
    # Lazily-populated cache for the leakage correction factor at `default_x`.
    self._leakage_density_correction_factor = None

    self._purpose = """It samples the posterior network and rejects samples that
        lie outside of the prior bounds."""

leakage_correction(x, num_rejection_samples=10000, force_update=False, show_progress_bars=False, rejection_sampling_batch_size=10000)

Return leakage correction factor for a leaky posterior density estimate.

The factor is estimated from the acceptance probability during rejection sampling from the posterior.

This is to avoid re-estimating the acceptance probability from scratch whenever log_prob is called and norm_posterior=True. Here, it is estimated only once for self.default_x and saved for later. We re-evaluate only whenever a new x is passed.

Parameters:

Name Type Description Default
num_rejection_samples int

Number of samples used to estimate correction factor.

10000
show_progress_bars bool

Whether to show a progress bar during sampling.

False
rejection_sampling_batch_size int

Batch size for rejection sampling.

10000

Returns:

Type Description
Tensor

Saved or newly-estimated correction factor (as a scalar Tensor).

Source code in sbi/inference/posteriors/direct_posterior.py
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
@torch.no_grad()
def leakage_correction(
    self,
    x: Tensor,
    num_rejection_samples: int = 10_000,
    force_update: bool = False,
    show_progress_bars: bool = False,
    rejection_sampling_batch_size: int = 10_000,
) -> Tensor:
    r"""Return leakage correction factor for a leaky posterior density estimate.

    The factor is the acceptance probability observed while rejection-sampling
    from the posterior, i.e. the fraction of posterior mass that lies inside
    the prior support.

    To avoid re-estimating the acceptance probability from scratch every time
    `log_prob` is called with `norm_posterior=True`, the factor is estimated
    once for `self.default_x` and cached. It is re-estimated only when a new
    `x` is passed or when `force_update=True`.

    Args:
        x: Conditioning observation at which to estimate the correction.
        num_rejection_samples: Number of samples used to estimate correction
            factor.
        force_update: Whether to re-estimate the cached factor at
            `self.default_x` even if one is already stored.
        show_progress_bars: Whether to show a progress bar during sampling.
        rejection_sampling_batch_size: Batch size for rejection sampling.

    Returns:
        Cached or newly-estimated correction factor (as a scalar `Tensor`).
    """

    def estimate_acceptance(obs: Tensor) -> Tensor:
        condition = reshape_to_batch_event(
            obs, event_shape=self.posterior_estimator.condition_shape
        )
        # Index [1] picks the acceptance rate returned alongside the samples.
        return rejection.accept_reject_sample(
            proposal=self.posterior_estimator.sample,
            accept_reject_fn=lambda theta: within_support(self.prior, theta),
            num_samples=num_rejection_samples,
            show_progress_bars=show_progress_bars,
            sample_for_correction_factor=True,
            max_sampling_batch_size=rejection_sampling_batch_size,
            proposal_sampling_kwargs={"condition": condition},
        )[1]

    default = self.default_x
    # `x is not default` short-circuits the elementwise comparison when the
    # very same tensor object was passed.
    if default is None or (x is not default and (x != default).any()):
        # Non-default x: estimate on the fly, do not cache.
        return estimate_acceptance(x)

    if self._leakage_density_correction_factor is None or force_update:
        assert default is not None
        self._leakage_density_correction_factor = estimate_acceptance(default)

    return self._leakage_density_correction_factor  # type: ignore

log_prob(theta, x=None, norm_posterior=True, track_gradients=False, leakage_correction_params=None)

Returns the log-probability of the posterior \(p(\theta|x)\).

Parameters:

Name Type Description Default
theta Tensor

Parameters \(\theta\).

required
norm_posterior bool

Whether to enforce a normalized posterior density. Renormalization of the posterior is useful when some probability falls out or leaks out of the prescribed prior support. The normalizing factor is calculated via rejection sampling, so if you need speedier but unnormalized log posterior estimates set here norm_posterior=False. The returned log posterior is set to -∞ outside of the prior support regardless of this setting.

True
track_gradients bool

Whether the returned tensor supports tracking gradients. This can be helpful for e.g. sensitivity analysis, but increases memory consumption.

False
leakage_correction_params Optional[dict]

A dict of keyword arguments to override the default values of leakage_correction(). Possible options are: num_rejection_samples, force_update, show_progress_bars, and rejection_sampling_batch_size. These parameters only have an effect if norm_posterior=True.

None

Returns:

Type Description
Tensor

(len(θ),)-shaped log posterior probability \(\log p(\theta|x)\) for θ in the

Tensor

support of the prior, -∞ (corresponding to 0 probability) outside.

Source code in sbi/inference/posteriors/direct_posterior.py
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
def log_prob(
    self,
    theta: Tensor,
    x: Optional[Tensor] = None,
    norm_posterior: bool = True,
    track_gradients: bool = False,
    leakage_correction_params: Optional[dict] = None,
) -> Tensor:
    r"""Returns the log-probability of the posterior $p(\theta|x)$.

    Args:
        theta: Parameters $\theta$.
        x: Conditioning observation. If None, the default `x` set via
            `.set_default_x()` is used.
        norm_posterior: Whether to enforce a normalized posterior density.
            Renormalizing is useful when some probability falls out or leaks
            out of the prescribed prior support. The normalizing factor is
            estimated via rejection sampling, so set `norm_posterior=False`
            for speedier but unnormalized log posterior estimates. The
            returned log posterior is -∞ outside of the prior support
            regardless of this setting.
        track_gradients: Whether the returned tensor supports tracking
            gradients. This can be helpful for e.g. sensitivity analysis, but
            increases memory consumption.
        leakage_correction_params: A `dict` of keyword arguments overriding
            the default values of `leakage_correction()`. Possible options
            are: `num_rejection_samples`, `force_update`,
            `show_progress_bars`, and `rejection_sampling_batch_size`. These
            parameters only have an effect if `norm_posterior=True`.

    Returns:
        `(len(θ),)`-shaped log posterior probability $\log p(\theta|x)$ for θ
        in the support of the prior, -∞ (corresponding to 0 probability)
        outside.
    """
    x = self._x_else_default_x(x)

    theta = ensure_theta_batched(torch.as_tensor(theta))
    theta_batched = reshape_to_sample_batch_event(
        theta, theta.shape[1:], leading_is_sample=True
    )
    condition = reshape_to_batch_event(
        x, event_shape=self.posterior_estimator.condition_shape
    )
    if condition.shape[0] > 1:
        raise ValueError(
            ".log_prob() supports only `batchsize == 1`. If you intend "
            "to evaluate given multiple observations, use `.log_prob_batched()`. "
            "If you intend to evaluate given i.i.d. observations, set up the "
            "posterior density estimator with an appropriate permutation "
            "invariant embedding net."
        )

    self.posterior_estimator.eval()

    with torch.set_grad_enabled(track_gradients):
        # The estimator returns shape (num_theta, 1); the second dimension is
        # the observation batch, guaranteed above to be 1 — drop it.
        unnorm_log_prob = self.posterior_estimator.log_prob(
            theta_batched, condition=condition
        ).squeeze(dim=1)

        # Assign zero probability (-inf in log space) outside prior support.
        masked_log_prob = torch.where(
            within_support(self.prior, theta),
            unnorm_log_prob,
            torch.tensor(float("-inf"), dtype=torch.float32, device=self._device),
        )

        log_factor = (
            log(self.leakage_correction(x=x, **(leakage_correction_params or dict())))
            if norm_posterior
            else 0
        )

        return masked_log_prob - log_factor

log_prob_batched(theta, x, norm_posterior=True, track_gradients=False, leakage_correction_params=None)

Given a batch of observations [x_1, …, x_B] and a batch of parameters [\(\theta_1\), …, \(\theta_B\)], this function evaluates the log-probabilities of the posteriors \(p(\theta_1|x_1)\), …, \(p(\theta_B|x_B)\) in a batched (i.e. vectorized) manner.

Parameters:

Name Type Description Default
theta Tensor

Batch of parameters \(\theta\) of shape (*sample_shape, batch_dim, *theta_shape).

required
x Tensor

Batch of observations \(x\) of shape (batch_dim, *condition_shape).

required
norm_posterior bool

Whether to enforce a normalized posterior density. Renormalization of the posterior is useful when some probability falls out or leaks out of the prescribed prior support. The normalizing factor is calculated via rejection sampling, so if you need speedier but unnormalized log posterior estimates set here norm_posterior=False. The returned log posterior is set to -∞ outside of the prior support regardless of this setting.

True
track_gradients bool

Whether the returned tensor supports tracking gradients. This can be helpful for e.g. sensitivity analysis, but increases memory consumption.

False
leakage_correction_params Optional[dict]

A dict of keyword arguments to override the default values of leakage_correction(). Possible options are: num_rejection_samples, force_update, show_progress_bars, and rejection_sampling_batch_size. These parameters only have an effect if norm_posterior=True.

None

Returns:

Type Description
Tensor

(len(θ), B)-shaped log posterior probability \(\log p(\theta|x)\) for θ in the support of the prior, -∞ (corresponding to 0 probability) outside.

Source code in sbi/inference/posteriors/direct_posterior.py
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
def log_prob_batched(
    self,
    theta: Tensor,
    x: Tensor,
    norm_posterior: bool = True,
    track_gradients: bool = False,
    leakage_correction_params: Optional[dict] = None,
) -> Tensor:
    r"""Evaluate posterior log-probabilities for a batch of observations.

    Given a batch of observations `[x_1, ..., x_B]` and a batch of parameters
    `[theta_1, ..., theta_B]`, this function evaluates the log-probabilities
    of the posteriors $p(\theta_1|x_1)$, ..., $p(\theta_B|x_B)$ in a batched
    (i.e. vectorized) manner.

    Args:
        theta: Batch of parameters $\theta$ of shape
            `(*sample_shape, batch_dim, *theta_shape)`.
        x: Batch of observations $x$ of shape `(batch_dim, *condition_shape)`.
        norm_posterior: Whether to enforce a normalized posterior density.
            Renormalizing is useful when some probability falls out or leaks
            out of the prescribed prior support. The normalizing factor is
            estimated via rejection sampling, so set `norm_posterior=False`
            for speedier but unnormalized log posterior estimates. The
            returned log posterior is -∞ outside of the prior support
            regardless of this setting.
        track_gradients: Whether the returned tensor supports tracking
            gradients. This can be helpful for e.g. sensitivity analysis, but
            increases memory consumption.
        leakage_correction_params: A `dict` of keyword arguments overriding
            the default values of `leakage_correction()`. Possible options
            are: `num_rejection_samples`, `force_update`,
            `show_progress_bars`, and `rejection_sampling_batch_size`. These
            parameters only have an effect if `norm_posterior=True`.

    Returns:
        `(len(θ), B)`-shaped log posterior probability $\log p(\theta|x)$ for
        θ in the support of the prior, -∞ (corresponding to 0 probability)
        outside.
    """

    theta = ensure_theta_batched(torch.as_tensor(theta))
    theta_batched = reshape_to_sample_batch_event(
        theta, self.posterior_estimator.input_shape, leading_is_sample=True
    )
    condition = reshape_to_batch_event(
        x, event_shape=self.posterior_estimator.condition_shape
    )

    self.posterior_estimator.eval()

    with torch.set_grad_enabled(track_gradients):
        unnorm_log_prob = self.posterior_estimator.log_prob(
            theta_batched, condition=condition
        )

        # Assign zero probability (-inf in log space) outside prior support.
        masked_log_prob = torch.where(
            within_support(self.prior, theta),
            unnorm_log_prob,
            torch.tensor(float("-inf"), dtype=torch.float32, device=self._device),
        )

        log_factor = (
            log(self.leakage_correction(x=x, **(leakage_correction_params or dict())))
            if norm_posterior
            else 0
        )

        return masked_log_prob - log_factor

map(x=None, num_iter=1000, num_to_optimize=100, learning_rate=0.01, init_method='posterior', num_init_samples=1000, save_best_every=10, show_progress_bars=False, force_update=False)

Returns the maximum-a-posteriori estimate (MAP).

The method can be interrupted (Ctrl-C) when the user sees that the log-probability converges. The best estimate will be saved in self._map and can be accessed with self.map(). The MAP is obtained by running gradient ascent from a given number of starting positions (samples from the posterior with the highest log-probability). After the optimization is done, we select the parameter set that has the highest log-probability after the optimization.

Warning: The default values used by this function are not well-tested. They might require hand-tuning for the problem at hand.

For developers: if the prior is a BoxUniform, we carry out the optimization in unbounded space and transform the result back into bounded space.

Parameters:

Name Type Description Default
x Optional[Tensor]

Deprecated - use .set_default_x() prior to .map().

None
num_iter int

Number of optimization steps that the algorithm takes to find the MAP.

1000
learning_rate float

Learning rate of the optimizer.

0.01
init_method Union[str, Tensor]

How to select the starting parameters for the optimization. If it is a string, it can be either [posterior, prior], which samples the respective distribution num_init_samples times. If it is a tensor, the tensor will be used as init locations.

'posterior'
num_init_samples int

Draw this number of samples from the posterior and evaluate the log-probability of all of them.

1000
num_to_optimize int

From the drawn num_init_samples, use the num_to_optimize with highest log-probability as the initial points for the optimization.

100
save_best_every int

The best log-probability is computed, saved in the map-attribute, and printed every save_best_every-th iteration. Computing the best log-probability creates a significant overhead (thus, the default is 10.)

10
show_progress_bars bool

Whether to show a progressbar during sampling from the posterior.

False
force_update bool

Whether to re-calculate the MAP when x is unchanged and have a cached value.

False
log_prob_kwargs

Will be empty for SNLE and SNRE. Will contain {‘norm_posterior’: True} for SNPE.

required

Returns:

Type Description
Tensor

The MAP estimate.

Source code in sbi/inference/posteriors/direct_posterior.py
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
def map(
    self,
    x: Optional[Tensor] = None,
    num_iter: int = 1_000,
    num_to_optimize: int = 100,
    learning_rate: float = 0.01,
    init_method: Union[str, Tensor] = "posterior",
    num_init_samples: int = 1_000,
    save_best_every: int = 10,
    show_progress_bars: bool = False,
    force_update: bool = False,
) -> Tensor:
    r"""Returns the maximum-a-posteriori estimate (MAP).

    The method can be interrupted (Ctrl-C) when the user sees that the
    log-probability converges. The best estimate will be saved in `self._map` and
    can be accessed with `self.map()`. The MAP is obtained by running gradient
    ascent from a given number of starting positions (samples from the posterior
    with the highest log-probability). After the optimization is done, we select the
    parameter set that has the highest log-probability after the optimization.

    Warning: The default values used by this function are not well-tested. They
    might require hand-tuning for the problem at hand.

    For developers: if the prior is a `BoxUniform`, we carry out the optimization
    in unbounded space and transform the result back into bounded space.

    Args:
        x: Deprecated - use `.set_default_x()` prior to `.map()`.
        num_iter: Number of optimization steps that the algorithm takes
            to find the MAP.
        num_to_optimize: From the drawn `num_init_samples`, use the
            `num_to_optimize` with highest log-probability as the initial points
            for the optimization.
        learning_rate: Learning rate of the optimizer.
        init_method: How to select the starting parameters for the optimization. If
            it is a string, it can be either [`posterior`, `prior`], which samples
            the respective distribution `num_init_samples` times. If it is a
            tensor, the tensor will be used as init locations.
        num_init_samples: Draw this number of samples from the posterior and
            evaluate the log-probability of all of them.
        save_best_every: The best log-probability is computed, saved in the
            `map`-attribute, and printed every `save_best_every`-th iteration.
            Computing the best log-probability creates a significant overhead
            (thus, the default is `10`.)
        show_progress_bars: Whether to show a progressbar during sampling from the
            posterior.
        force_update: Whether to re-calculate the MAP when x is unchanged and
            have a cached value.

    Returns:
        The MAP estimate.
    """
    # Pure delegation: the optimization itself lives in the base class.
    return super().map(
        x=x,
        num_iter=num_iter,
        num_to_optimize=num_to_optimize,
        learning_rate=learning_rate,
        init_method=init_method,
        num_init_samples=num_init_samples,
        save_best_every=save_best_every,
        show_progress_bars=show_progress_bars,
        force_update=force_update,
    )

sample(sample_shape=torch.Size(), x=None, max_sampling_batch_size=10000, show_progress_bars=True, reject_outside_prior=True, max_sampling_time=None, return_partial_on_timeout=False)

Draw samples from the approximate posterior distribution \(p(\theta|x)\).

Parameters:

Name Type Description Default
sample_shape Shape

Desired shape of samples that are drawn from posterior. If sample_shape is multidimensional we simply draw sample_shape.numel() samples and then reshape into the desired shape.

Size()
x Optional[Tensor]

Conditioning observation \(x_o\). If not provided, uses the default x set via .set_default_x().

None
max_sampling_batch_size int

Maximum batch size for rejection sampling.

10000
show_progress_bars bool

Whether to show sampling progress monitor.

True
reject_outside_prior bool

If True (default), rejection sampling is used to ensure samples lie within the prior support. If False, samples are drawn directly from the neural density estimator without rejection, which is faster but may include samples outside the prior support.

True
max_sampling_time Optional[float]

Optional maximum allowed sampling time in seconds. If exceeded, sampling is aborted and a RuntimeError is raised. Only applies when reject_outside_prior=True (no effect otherwise since direct sampling is fast).

None
return_partial_on_timeout bool

If True and max_sampling_time is exceeded, return the samples collected so far instead of raising a RuntimeError. A warning will be issued. Only applies when reject_outside_prior=True (default).

False
Source code in sbi/inference/posteriors/direct_posterior.py
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
def sample(
    self,
    sample_shape: Shape = torch.Size(),
    x: Optional[Tensor] = None,
    max_sampling_batch_size: int = 10_000,
    show_progress_bars: bool = True,
    reject_outside_prior: bool = True,
    max_sampling_time: Optional[float] = None,
    return_partial_on_timeout: bool = False,
) -> Tensor:
    r"""Draw samples from the approximate posterior distribution $p(\theta|x)$.

    Args:
        sample_shape: Desired shape of samples that are drawn from posterior.
            For a multidimensional `sample_shape`, `sample_shape.numel()`
            samples are drawn and then reshaped into the desired shape.
        x: Conditioning observation $x_o$. If not provided, uses the default
            `x` set via `.set_default_x()`.
        max_sampling_batch_size: Maximum batch size for rejection sampling.
        show_progress_bars: Whether to show sampling progress monitor.
        reject_outside_prior: If True (default), rejection sampling is used to
            ensure samples lie within the prior support. If False, samples are
            drawn directly from the neural density estimator without
            rejection, which is faster but may include samples outside the
            prior support.
        max_sampling_time: Optional maximum allowed sampling time in seconds.
            If exceeded, sampling is aborted and a RuntimeError is raised.
            Only applies when `reject_outside_prior=True` (no effect otherwise
            since direct sampling is fast).
        return_partial_on_timeout: If True and `max_sampling_time` is
            exceeded, return the samples collected so far instead of raising a
            RuntimeError. A warning will be issued. Only applies when
            `reject_outside_prior=True` (default).
    """
    x = self._x_else_default_x(x)
    condition = reshape_to_batch_event(
        x, event_shape=self.posterior_estimator.condition_shape
    )
    if condition.shape[0] > 1:
        raise ValueError(
            ".sample() supports only `batchsize == 1`. If you intend "
            "to sample multiple observations, use `.sample_batched()`. "
            "If you intend to sample i.i.d. observations, set up the "
            "posterior density estimator with an appropriate permutation "
            "invariant embedding net."
        )

    num_samples = torch.Size(sample_shape).numel()
    if max_sampling_batch_size is None:
        max_sampling_batch_size = self.max_sampling_batch_size

    if not reject_outside_prior:
        # Bypass rejection sampling entirely; only warn about leakage.
        samples = self.posterior_estimator.sample(
            torch.Size([num_samples]),
            condition=condition,
        )
        warn_if_outside_prior_support(self.prior, samples[:, 0])
    else:
        # Default behavior: reject samples that fall outside the prior.
        samples = rejection.accept_reject_sample(
            proposal=self.posterior_estimator.sample,
            accept_reject_fn=lambda theta: within_support(self.prior, theta),
            num_samples=num_samples,
            show_progress_bars=show_progress_bars,
            max_sampling_batch_size=max_sampling_batch_size,
            proposal_sampling_kwargs={"condition": condition},
            alternative_method="build_posterior(..., sample_with='mcmc')",
            max_sampling_time=max_sampling_time,
            return_partial_on_timeout=return_partial_on_timeout,
        )[0]

    return samples[:, 0]  # Drop the observation-batch dimension (size 1).

sample_batched(sample_shape, x, max_sampling_batch_size=10000, show_progress_bars=True, reject_outside_prior=True, max_sampling_time=None, return_partial_on_timeout=False)

Draw samples from the posteriors for a batch of different xs.

Given a batch of observations [x_1, ..., x_B], this method samples from posteriors \(p(\theta|x_1), \ldots, p(\theta|x_B)\) in a vectorized manner.

Parameters:

Name Type Description Default
sample_shape Shape

Desired shape of samples that are drawn from the posterior given every observation.

required
x Tensor

A batch of observations, of shape (batch_dim, event_shape_x). batch_dim corresponds to the number of observations to be drawn.

required
max_sampling_batch_size int

Maximum batch size for rejection sampling.

10000
show_progress_bars bool

Whether to show sampling progress monitor.

True
reject_outside_prior bool

If True (default), rejection sampling is used to ensure samples lie within the prior support. If False, samples are drawn directly from the neural density estimator without rejection, which is faster but may include samples outside the prior support.

True
max_sampling_time Optional[float]

Optional maximum allowed sampling time in seconds. If exceeded, sampling is aborted and a RuntimeError is raised. Only applies when reject_outside_prior=True.

None
return_partial_on_timeout bool

If True and max_sampling_time is exceeded, return the samples collected so far instead of raising a RuntimeError. A warning will be issued. Only applies when reject_outside_prior=True.

False

Returns:

Type Description
Tensor

Samples from the posteriors of shape (*sample_shape, B, *input_shape)

Source code in sbi/inference/posteriors/direct_posterior.py
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
def sample_batched(
    self,
    sample_shape: Shape,
    x: Tensor,
    max_sampling_batch_size: int = 10_000,
    show_progress_bars: bool = True,
    reject_outside_prior: bool = True,
    max_sampling_time: Optional[float] = None,
    return_partial_on_timeout: bool = False,
) -> Tensor:
    r"""Draw samples from the posteriors for a batch of different xs.

    Given a batch of observations `[x_1, ..., x_B]`, this method samples from
    posteriors $p(\theta|x_1), \ldots, p(\theta|x_B)$ in a vectorized manner.

    Args:
        sample_shape: Desired shape of samples that are drawn from the
            posterior given every observation.
        x: A batch of observations, of shape `(batch_dim, event_shape_x)`.
            `batch_dim` corresponds to the number of observations to be drawn.
        max_sampling_batch_size: Maximum batch size for rejection sampling.
        show_progress_bars: Whether to show sampling progress monitor.
        reject_outside_prior: If True (default), rejection sampling is used to
            ensure samples lie within the prior support. If False, samples are
            drawn directly from the neural density estimator without
            rejection, which is faster but may include samples outside the
            prior support.
        max_sampling_time: Optional maximum allowed sampling time in seconds.
            If exceeded, sampling is aborted and a RuntimeError is raised.
            Only applies when `reject_outside_prior=True`.
        return_partial_on_timeout: If True and `max_sampling_time` is
            exceeded, return the samples collected so far instead of raising a
            RuntimeError. A warning will be issued. Only applies when
            `reject_outside_prior=True`.

    Returns:
        Samples from the posteriors of shape (*sample_shape, B, *input_shape)
    """
    num_samples = torch.Size(sample_shape).numel()
    x = reshape_to_batch_event(
        x, event_shape=self.posterior_estimator.condition_shape
    )
    num_xos = x.shape[0]

    # Batched sampling draws num_samples per observation; warn once the total
    # exceeds ~2 million samples.
    if num_xos * num_samples > 2**21:
        warnings.warn(
            f"Note that for batched sampling, the direct posterior sampling "
            f"generates {num_xos} * {num_samples} = {num_xos * num_samples} "
            "samples. This can be slow and memory-intensive. Consider "
            "reducing the number of samples or batch size.",
            stacklevel=2,
        )

    if max_sampling_batch_size is None:
        max_sampling_batch_size = self.max_sampling_batch_size

    # Cap the per-round batch so that batch_size * num_xos stays below 100k.
    if max_sampling_batch_size * num_xos > 100_000:
        capped = max(1, 100_000 // num_xos)
        warnings.warn(
            f"Capping max_sampling_batch_size from {max_sampling_batch_size} "
            f"to {capped} to avoid excessive memory usage.",
            stacklevel=2,
        )
        max_sampling_batch_size = capped

    if not reject_outside_prior:
        # Bypass rejection sampling entirely; only warn about leakage.
        samples = self.posterior_estimator.sample(
            torch.Size([num_samples]),
            condition=x,
        )
        warn_if_outside_prior_support(self.prior, samples)
    else:
        # Default behavior: reject samples that fall outside the prior.
        samples = rejection.accept_reject_sample(
            proposal=self.posterior_estimator.sample,
            accept_reject_fn=lambda theta: within_support(self.prior, theta),
            num_samples=num_samples,
            show_progress_bars=show_progress_bars,
            max_sampling_batch_size=max_sampling_batch_size,
            proposal_sampling_kwargs={"condition": x},
            alternative_method="build_posterior(..., sample_with='mcmc')",
            max_sampling_time=max_sampling_time,
            return_partial_on_timeout=return_partial_on_timeout,
        )[0]

    return samples

to(device)

Move posterior_estimator, prior and x_o to device.

Changes the device attribute, re-instantiates the posterior, and resets the default x.

Parameters:

Name Type Description Default
device Union[str, device]

device where to move the posterior to.

required
Source code in sbi/inference/posteriors/direct_posterior.py
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
def to(self, device: Union[str, torch.device]) -> None:
    """Move posterior_estimator, prior and x_o to device.

    Updates the device attribute, re-instantiates the posterior on the new
    device, and restores the default x if one was set.

    Args:
        device: device where to move the posterior to.
    """
    self.device = device

    # Both the prior and the estimator must support `.to`; fail loudly otherwise.
    if not hasattr(self.prior, "to"):
        raise ValueError("""Prior has no attribute to(device).""")
    self.prior.to(device)  # type: ignore
    if not hasattr(self.posterior_estimator, "to"):
        raise ValueError("""Posterior estimator has no attribute to(device).""")
    self.posterior_estimator.to(device)

    # Remember the default observation before re-initialization wipes it.
    cached_x = self._x.to(device) if getattr(self, "_x", None) is not None else None

    potential_fn, theta_transform = posterior_estimator_based_potential(
        self.posterior_estimator,
        self.prior,
        x_o=None,
        enable_transform=self.enable_transform,
    )

    super().__init__(
        potential_fn=potential_fn,
        theta_transform=theta_transform,
        device=device,
        x_shape=self.x_shape,
    )
    # Re-initialization erased `self._x`; restore the cached default x.
    if cached_x is not None:
        self.set_default_x(cached_x)

ImportanceSamplingPosterior

Bases: NeuralPosterior

Provides importance sampling to sample from the posterior.

SNLE or SNRE train neural networks to approximate the likelihood(-ratios). ImportanceSamplingPosterior allows one to estimate the posterior log-probability by estimating the normalization constant with importance sampling. It also allows one to perform importance sampling (with .sample()) and to draw approximate samples with sampling-importance-resampling (SIR) (with .sir_sample())

Source code in sbi/inference/posteriors/importance_posterior.py
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
class ImportanceSamplingPosterior(NeuralPosterior):
    r"""Provides importance sampling to sample from the posterior.

    SNLE or SNRE train neural networks to approximate the likelihood(-ratios).
    `ImportanceSamplingPosterior` allows to estimate the posterior log-probability by
    estimating the normalization constant with importance sampling. It also allows to
    perform importance sampling (with `.sample()`) and to draw approximate samples with
    sampling-importance-resampling (SIR) (with `.sir_sample()`)
    """

    def __init__(
        self,
        potential_fn: Union[Callable, BasePotential],
        proposal: Any,
        theta_transform: Optional[TorchTransform] = None,
        method: Literal["sir", "importance"] = "sir",
        oversampling_factor: int = 32,
        max_sampling_batch_size: int = 10_000,
        device: Optional[Union[str, torch.device]] = None,
        x_shape: Optional[torch.Size] = None,
    ):
        """
        Args:
            potential_fn: The potential function from which to draw samples. Must be a
                `BasePotential` or a `Callable` which takes `theta` and `x_o` as inputs.
            proposal: The proposal distribution.
            theta_transform: Transformation that is applied to parameters. Is not used
                during but only when calling `.map()`.
            method: Either of [`sir`|`importance`]. This sets the behavior of the
                `.sample()` method. With `sir`, approximate posterior samples are
                generated with sampling importance resampling (SIR). With
                `importance`, the `.sample()` method returns a tuple of samples and
                corresponding importance weights.
            oversampling_factor: Number of proposed samples from which only one is
                selected based on its importance weight.
            max_sampling_batch_size: The batch size of samples being drawn from the
                proposal at every iteration.
            device: Device on which to sample, e.g., "cpu", "cuda" or "cuda:0". If
                None, `potential_fn.device` is used.
            x_shape: Deprecated, should not be passed.
        """
        super().__init__(
            potential_fn,
            theta_transform=theta_transform,
            device=device,
            x_shape=x_shape,
        )

        self.proposal = proposal
        # Cached normalization constant for the default x; filled lazily by
        # `estimate_normalization_constant()`.
        self._normalization_constant = None
        self.method = method
        self.theta_transform = theta_transform

        self.oversampling_factor = oversampling_factor
        self.max_sampling_batch_size = max_sampling_batch_size

        self._purpose = (
            "It provides sampling-importance resampling (SIR) to .sample() from the "
            "posterior and can evaluate the _unnormalized_ posterior density with "
            ".log_prob()."
        )
        self.x_shape = x_shape

    def to(self, device: Union[str, torch.device]) -> None:
        """
        Move the potential, the proposal and x_o to a new device.

        It also reinstantiates the posterior with the new device.

        Args:
            device: Device on which to move the posterior to.
        """
        self.device = device
        self.potential_fn.to(device)  # type: ignore
        self.proposal.to(device)
        # Preserve the default observation: `super().__init__` below erases it.
        x_o = None
        if hasattr(self, "_x") and (self._x is not None):
            x_o = self._x.to(device)

        # Rebuild the parameter transform so it is constructed on the new device.
        self.theta_transform = mcmc_transform(self.proposal, device=device)
        super().__init__(
            self.potential_fn,
            theta_transform=self.theta_transform,
            device=device,
            x_shape=self.x_shape,
        )
        # super().__init__ erases the self._x, so we need to set it again
        if x_o is not None:
            self.set_default_x(x_o)

    def log_prob(
        self,
        theta: Tensor,
        x: Optional[Tensor] = None,
        track_gradients: bool = False,
        normalization_constant_params: Optional[dict] = None,
    ) -> Tensor:
        r"""Returns the log-probability of theta under the posterior.

        The normalization constant is estimated with importance sampling.

        Args:
            theta: Parameters $\theta$.
            x: Conditioning observation. If None, the default `x` set via
                `.set_default_x()` is used.
            track_gradients: Whether the returned tensor supports tracking gradients.
                This can be helpful for e.g. sensitivity analysis, but increases memory
                consumption.
            normalization_constant_params: Parameters passed on to
                `estimate_normalization_constant()`.

        Returns:
            `len($\theta$)`-shaped log-probability.
        """
        x = self._x_else_default_x(x)
        # Condition the potential on x before evaluating it and before the
        # normalization-constant estimate below (which relies on this `set_x`).
        self.potential_fn.set_x(x)

        theta = ensure_theta_batched(torch.as_tensor(theta))

        with torch.set_grad_enabled(track_gradients):
            potential_values = self.potential_fn(
                theta.to(self._device), track_gradients=track_gradients
            )

            if normalization_constant_params is None:
                normalization_constant_params = dict()  # use defaults
            normalization_constant = self.estimate_normalization_constant(
                x, **normalization_constant_params
            )

            return (potential_values - torch.log(normalization_constant)).to(
                self._device
            )

    @torch.no_grad()
    def estimate_normalization_constant(
        self, x: Tensor, num_samples: int = 10_000, force_update: bool = False
    ) -> Tensor:
        """Returns the normalization constant via importance sampling.

        Args:
            x: Conditioning observation; compared against the default `x` to decide
                whether the cached constant can be reused.
            num_samples: Number of importance samples used for the estimate.
            force_update: Whether to re-calculate the normalization constant when x is
                unchanged and a cached value exists.
        """
        # NOTE(review): assumes `self.potential_fn` was already conditioned on `x`
        # (e.g. via `set_x` in `log_prob`) — confirm for direct callers.
        # Check if the provided x matches the default x (short-circuit on identity).
        is_new_x = self.default_x is None or (
            x is not self.default_x and (x != self.default_x).any()
        )

        not_saved_at_default_x = self._normalization_constant is None

        if is_new_x:  # Calculate at x; don't save.
            _, log_importance_weights = importance_sample(
                self.potential_fn,
                proposal=self.proposal,
                num_samples=num_samples,
            )
            return torch.mean(torch.exp(log_importance_weights))
        elif not_saved_at_default_x or force_update:  # Calculate at default_x; save.
            assert self.default_x is not None
            _, log_importance_weights = importance_sample(
                self.potential_fn,
                proposal=self.proposal,
                num_samples=num_samples,
            )
            self._normalization_constant = torch.mean(torch.exp(log_importance_weights))

        return self._normalization_constant.to(self._device)  # type: ignore

    def sample(
        self,
        sample_shape: Shape = torch.Size(),
        x: Optional[Tensor] = None,
        method: Optional[str] = None,
        oversampling_factor: int = 32,
        max_sampling_batch_size: int = 10_000,
        show_progress_bars: bool = False,
    ) -> Union[Tensor, Tuple[Tensor, Tensor]]:
        r"""Draw samples from the approximate posterior distribution $p(\theta|x)$.

        Args:
            sample_shape: Shape of samples that are drawn from posterior.
            x: Conditioning observation $x_o$. If not provided, uses the default `x`
                set via `.set_default_x()`.
            method: Either of [`sir`|`importance`]. This sets the behavior of the
                `.sample()` method. With `sir`, approximate posterior samples are
                generated with sampling importance resampling (SIR). With
                `importance`, the `.sample()` method returns a tuple of samples and
                corresponding importance weights.
            oversampling_factor: Number of proposed samples from which only one is
                selected based on its importance weight.
            max_sampling_batch_size: The batch size of samples being drawn from the
                proposal at every iteration.
            show_progress_bars: Whether to show a progressbar during sampling.
        """
        # NOTE(review): the int defaults here (32 / 10_000) mean the instance
        # attributes `self.oversampling_factor` / `self.max_sampling_batch_size` are
        # only used if a caller explicitly passes None — confirm this is intended.
        method = self.method if method is None else method

        self.potential_fn.set_x(self._x_else_default_x(x))

        if method == "sir":
            return self._sir_sample(
                sample_shape,
                oversampling_factor=oversampling_factor,
                max_sampling_batch_size=max_sampling_batch_size,
                show_progress_bars=show_progress_bars,
            )
        elif method == "importance":
            return self._importance_sample(sample_shape)
        else:
            # Unknown method string.
            raise NameError

    def sample_batched(
        self,
        sample_shape: Shape,
        x: Tensor,
        max_sampling_batch_size: int = 10000,
        show_progress_bars: bool = True,
    ) -> Tensor:
        """Batched sampling is not supported for this posterior.

        Raises:
            NotImplementedError: Always; sample per-observation in a loop instead.
        """
        raise NotImplementedError(
            "Batched sampling is not implemented for ImportanceSamplingPosterior. \
           Alternatively you can use `sample` in a loop \
           [posterior.sample(theta, x_o) for x_o in x]."
        )

    def _importance_sample(
        self,
        sample_shape: Shape = torch.Size(),
        show_progress_bars: bool = False,
    ) -> Tuple[Tensor, Tensor]:
        """Returns samples from the proposal and log of their importance weights.

        Args:
            sample_shape: Desired shape of samples that are drawn from posterior.
            show_progress_bars: Whether to show sampling progress monitor.

        Returns:
            Samples and logarithm of corresponding importance weights.
        """
        num_samples = torch.Size(sample_shape).numel()
        samples, log_importance_weights = importance_sample(
            self.potential_fn,
            proposal=self.proposal,
            num_samples=num_samples,
            show_progress_bars=show_progress_bars,
        )

        # NOTE(review): samples are reshaped to `sample_shape`, but the weights are
        # returned flat — confirm callers expect this asymmetry.
        samples = samples.reshape((*sample_shape, -1)).to(self._device)
        return samples, log_importance_weights.to(self._device)

    def _sir_sample(
        self,
        sample_shape: Shape = torch.Size(),
        oversampling_factor: int = 32,
        max_sampling_batch_size: int = 10_000,
        show_progress_bars: bool = False,
    ):
        r"""Returns approximate samples from posterior $p(\theta|x)$ via SIR.

        Args:
            sample_shape: Desired shape of samples that are drawn from posterior. If
                sample_shape is multidimensional we simply draw `sample_shape.numel()`
                samples and then reshape into the desired shape.
            oversampling_factor: Number of proposed samples from which only one is
                selected based on its importance weight.
            max_sampling_batch_size: The batch size of samples being drawn from
                the proposal at every iteration.
            show_progress_bars: Whether to show sampling progress monitor.

        Returns:
            Samples from posterior.
        """
        # Replace arguments that were not passed with their default.
        # NOTE(review): since the parameter defaults are ints, these None-checks only
        # take effect if a caller passes None explicitly.
        oversampling_factor = (
            self.oversampling_factor
            if oversampling_factor is None
            else oversampling_factor
        )
        max_sampling_batch_size = (
            self.max_sampling_batch_size
            if max_sampling_batch_size is None
            else max_sampling_batch_size
        )

        num_samples = torch.Size(sample_shape).numel()
        samples = sampling_importance_resampling(
            self.potential_fn,
            proposal=self.proposal,
            num_samples=num_samples,
            num_candidate_samples=oversampling_factor,
            show_progress_bars=show_progress_bars,
            max_sampling_batch_size=max_sampling_batch_size,
            device=self._device,
        )

        return samples.reshape((*sample_shape, -1)).to(self._device)

    def map(
        self,
        x: Optional[Tensor] = None,
        num_iter: int = 1_000,
        num_to_optimize: int = 100,
        learning_rate: float = 0.01,
        init_method: Union[str, Tensor] = "proposal",
        num_init_samples: int = 1_000,
        save_best_every: int = 10,
        show_progress_bars: bool = False,
        force_update: bool = False,
    ) -> Tensor:
        r"""Returns the maximum-a-posteriori estimate (MAP).

        The method can be interrupted (Ctrl-C) when the user sees that the
        log-probability converges. The best estimate will be saved in `self._map` and
        can be accessed with `self.map()`. The MAP is obtained by running gradient
        ascent from a given number of starting positions (samples from the posterior
        with the highest log-probability). After the optimization is done, we select the
        parameter set that has the highest log-probability after the optimization.

        Warning: The default values used by this function are not well-tested. They
        might require hand-tuning for the problem at hand.

        For developers: if the prior is a `BoxUniform`, we carry out the optimization
        in unbounded space and transform the result back into bounded space.

        Args:
            x: Deprecated - use `.set_default_x()` prior to `.map()`.
            num_iter: Number of optimization steps that the algorithm takes
                to find the MAP.
            learning_rate: Learning rate of the optimizer.
            init_method: How to select the starting parameters for the optimization. If
                it is a string, it can be either [`posterior`, `prior`], which samples
                the respective distribution `num_init_samples` times. If it is a
                tensor, the tensor will be used as init locations.
            num_init_samples: Draw this number of samples from the posterior and
                evaluate the log-probability of all of them.
            num_to_optimize: From the drawn `num_init_samples`, use the
                `num_to_optimize` with highest log-probability as the initial points
                for the optimization.
            save_best_every: The best log-probability is computed, saved in the
                `map`-attribute, and printed every `save_best_every`-th iteration.
                Computing the best log-probability creates a significant overhead
                (thus, the default is `10`.)
            show_progress_bars: Whether to show a progressbar during sampling from the
                posterior.
            force_update: Whether to re-calculate the MAP when x is unchanged and
                have a cached value.

        Returns:
            The MAP estimate.
        """
        return super().map(
            x=x,
            num_iter=num_iter,
            num_to_optimize=num_to_optimize,
            learning_rate=learning_rate,
            init_method=init_method,
            num_init_samples=num_init_samples,
            save_best_every=save_best_every,
            show_progress_bars=show_progress_bars,
            force_update=force_update,
        )

__init__(potential_fn, proposal, theta_transform=None, method='sir', oversampling_factor=32, max_sampling_batch_size=10000, device=None, x_shape=None)

Parameters:

Name Type Description Default
potential_fn Union[Callable, BasePotential]

The potential function from which to draw samples. Must be a BasePotential or a Callable which takes theta and x_o as inputs.

required
proposal Any

The proposal distribution.

required
theta_transform Optional[TorchTransform]

Transformation that is applied to parameters. Is not used during but only when calling .map().

None
method Literal['sir', 'importance']

Either of [sir|importance]. This sets the behavior of the .sample() method. With sir, approximate posterior samples are generated with sampling importance resampling (SIR). With importance, the .sample() method returns a tuple of samples and corresponding importance weights.

'sir'
oversampling_factor int

Number of proposed samples from which only one is selected based on its importance weight.

32
max_sampling_batch_size int

The batch size of samples being drawn from the proposal at every iteration.

10000
device Optional[Union[str, device]]

Device on which to sample, e.g., “cpu”, “cuda” or “cuda:0”. If None, potential_fn.device is used.

None
x_shape Optional[Size]

Deprecated, should not be passed.

None
Source code in sbi/inference/posteriors/importance_posterior.py
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
def __init__(
    self,
    potential_fn: Union[Callable, BasePotential],
    proposal: Any,
    theta_transform: Optional[TorchTransform] = None,
    method: Literal["sir", "importance"] = "sir",
    oversampling_factor: int = 32,
    max_sampling_batch_size: int = 10_000,
    device: Optional[Union[str, torch.device]] = None,
    x_shape: Optional[torch.Size] = None,
):
    """
    Args:
        potential_fn: The potential function from which to draw samples. Must be a
            `BasePotential` or a `Callable` which takes `theta` and `x_o` as inputs.
        proposal: The proposal distribution.
        theta_transform: Transformation that is applied to parameters. Is not used
            during but only when calling `.map()`.
        method: Either of [`sir`|`importance`]. This sets the behavior of the
            `.sample()` method. With `sir`, approximate posterior samples are
            generated with sampling importance resampling (SIR). With
            `importance`, the `.sample()` method returns a tuple of samples and
            corresponding importance weights.
        oversampling_factor: Number of proposed samples from which only one is
            selected based on its importance weight.
        max_sampling_batch_size: The batch size of samples being drawn from the
            proposal at every iteration.
        device: Device on which to sample, e.g., "cpu", "cuda" or "cuda:0". If
            None, `potential_fn.device` is used.
        x_shape: Deprecated, should not be passed.
    """
    super().__init__(
        potential_fn,
        theta_transform=theta_transform,
        device=device,
        x_shape=x_shape,
    )

    self.proposal = proposal
    # Cached normalization constant for the default x; filled lazily by
    # `estimate_normalization_constant()`.
    self._normalization_constant = None
    self.method = method
    self.theta_transform = theta_transform

    self.oversampling_factor = oversampling_factor
    self.max_sampling_batch_size = max_sampling_batch_size

    # Human-readable description used by the posterior's repr/help text.
    self._purpose = (
        "It provides sampling-importance resampling (SIR) to .sample() from the "
        "posterior and can evaluate the _unnormalized_ posterior density with "
        ".log_prob()."
    )
    self.x_shape = x_shape

estimate_normalization_constant(x, num_samples=10000, force_update=False)

Returns the normalization constant via importance sampling.

Parameters:

Name Type Description Default
num_samples int

Number of importance samples used for the estimate.

10000
force_update bool

Whether to re-calculate the normalization constant when x is unchanged and a cached value exists.

False
Source code in sbi/inference/posteriors/importance_posterior.py
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
@torch.no_grad()
def estimate_normalization_constant(
    self, x: Tensor, num_samples: int = 10_000, force_update: bool = False
) -> Tensor:
    """Returns the normalization constant via importance sampling.

    Args:
        x: Conditioning observation; compared against the default `x` to decide
            whether the cached constant can be reused.
        num_samples: Number of importance samples used for the estimate.
        force_update: Whether to re-calculate the normalization constant when x is
            unchanged and a cached value exists.
    """
    # NOTE(review): assumes `self.potential_fn` was already conditioned on `x`
    # (e.g. via `set_x` in `log_prob`) — confirm for direct callers.
    # Check if the provided x matches the default x (short-circuit on identity).
    is_new_x = self.default_x is None or (
        x is not self.default_x and (x != self.default_x).any()
    )

    not_saved_at_default_x = self._normalization_constant is None

    if is_new_x:  # Calculate at x; don't save.
        _, log_importance_weights = importance_sample(
            self.potential_fn,
            proposal=self.proposal,
            num_samples=num_samples,
        )
        return torch.mean(torch.exp(log_importance_weights))
    elif not_saved_at_default_x or force_update:  # Calculate at default_x; save.
        assert self.default_x is not None
        _, log_importance_weights = importance_sample(
            self.potential_fn,
            proposal=self.proposal,
            num_samples=num_samples,
        )
        self._normalization_constant = torch.mean(torch.exp(log_importance_weights))

    return self._normalization_constant.to(self._device)  # type: ignore

log_prob(theta, x=None, track_gradients=False, normalization_constant_params=None)

Returns the log-probability of theta under the posterior.

The normalization constant is estimated with importance sampling.

Parameters:

Name Type Description Default
theta Tensor

Parameters \(\theta\).

required
track_gradients bool

Whether the returned tensor supports tracking gradients. This can be helpful for e.g. sensitivity analysis, but increases memory consumption.

False
normalization_constant_params Optional[dict]

Parameters passed on to estimate_normalization_constant().

None

Returns:

Type Description
Tensor

len($\theta$)-shaped log-probability.

Source code in sbi/inference/posteriors/importance_posterior.py
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
def log_prob(
    self,
    theta: Tensor,
    x: Optional[Tensor] = None,
    track_gradients: bool = False,
    normalization_constant_params: Optional[dict] = None,
) -> Tensor:
    r"""Returns the log-probability of theta under the posterior.

    The normalization constant is estimated with importance sampling.

    Args:
        theta: Parameters $\theta$.
        x: Conditioning observation. If None, the default `x` set via
            `.set_default_x()` is used.
        track_gradients: Whether the returned tensor supports tracking gradients.
            This can be helpful for e.g. sensitivity analysis, but increases memory
            consumption.
        normalization_constant_params: Parameters passed on to
            `estimate_normalization_constant()`.

    Returns:
        `len($\theta$)`-shaped log-probability.
    """
    x = self._x_else_default_x(x)
    # Condition the potential on x before evaluating it and before the
    # normalization-constant estimate below (which relies on this `set_x`).
    self.potential_fn.set_x(x)

    theta = ensure_theta_batched(torch.as_tensor(theta))

    with torch.set_grad_enabled(track_gradients):
        potential_values = self.potential_fn(
            theta.to(self._device), track_gradients=track_gradients
        )

        if normalization_constant_params is None:
            normalization_constant_params = dict()  # use defaults
        normalization_constant = self.estimate_normalization_constant(
            x, **normalization_constant_params
        )

        return (potential_values - torch.log(normalization_constant)).to(
            self._device
        )

map(x=None, num_iter=1000, num_to_optimize=100, learning_rate=0.01, init_method='proposal', num_init_samples=1000, save_best_every=10, show_progress_bars=False, force_update=False)

Returns the maximum-a-posteriori estimate (MAP).

The method can be interrupted (Ctrl-C) when the user sees that the log-probability converges. The best estimate will be saved in self._map and can be accessed with self.map(). The MAP is obtained by running gradient ascent from a given number of starting positions (samples from the posterior with the highest log-probability). After the optimization is done, we select the parameter set that has the highest log-probability after the optimization.

Warning: The default values used by this function are not well-tested. They might require hand-tuning for the problem at hand.

For developers: if the prior is a BoxUniform, we carry out the optimization in unbounded space and transform the result back into bounded space.

Parameters:

Name Type Description Default
x Optional[Tensor]

Deprecated - use .set_default_x() prior to .map().

None
num_iter int

Number of optimization steps that the algorithm takes to find the MAP.

1000
learning_rate float

Learning rate of the optimizer.

0.01
init_method Union[str, Tensor]

How to select the starting parameters for the optimization. If it is a string, it can be either [posterior, prior], which samples the respective distribution num_init_samples times. If it is a tensor, the tensor will be used as init locations.

'proposal'
num_init_samples int

Draw this number of samples from the posterior and evaluate the log-probability of all of them.

1000
num_to_optimize int

From the drawn num_init_samples, use the num_to_optimize with highest log-probability as the initial points for the optimization.

100
save_best_every int

The best log-probability is computed, saved in the map-attribute, and printed every save_best_every-th iteration. Computing the best log-probability creates a significant overhead (thus, the default is 10.)

10
show_progress_bars bool

Whether to show a progressbar during sampling from the posterior.

False
force_update bool

Whether to re-calculate the MAP when x is unchanged and have a cached value.

False
log_prob_kwargs

Will be empty for SNLE and SNRE. Will contain {‘norm_posterior’: True} for SNPE.

required

Returns:

Type Description
Tensor

The MAP estimate.

Source code in sbi/inference/posteriors/importance_posterior.py
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
def map(
    self,
    x: Optional[Tensor] = None,
    num_iter: int = 1_000,
    num_to_optimize: int = 100,
    learning_rate: float = 0.01,
    init_method: Union[str, Tensor] = "proposal",
    num_init_samples: int = 1_000,
    save_best_every: int = 10,
    show_progress_bars: bool = False,
    force_update: bool = False,
) -> Tensor:
    r"""Returns the maximum-a-posteriori estimate (MAP).

    The method can be interrupted (Ctrl-C) when the user sees that the
    log-probability converges. The best estimate will be saved in `self._map` and
    can be accessed with `self.map()`. The MAP is obtained by running gradient
    ascent from a given number of starting positions (samples from the posterior
    with the highest log-probability). After the optimization is done, we select the
    parameter set that has the highest log-probability after the optimization.

    Warning: The default values used by this function are not well-tested. They
    might require hand-tuning for the problem at hand.

    For developers: if the prior is a `BoxUniform`, we carry out the optimization
    in unbounded space and transform the result back into bounded space.

    Args:
        x: Deprecated - use `.set_default_x()` prior to `.map()`.
        num_iter: Number of optimization steps that the algorithm takes
            to find the MAP.
        num_to_optimize: From the drawn `num_init_samples`, use the
            `num_to_optimize` with highest log-probability as the initial points
            for the optimization.
        learning_rate: Learning rate of the optimizer.
        init_method: How to select the starting parameters for the optimization. If
            it is a string, it can be either [`posterior`, `prior`], which samples
            the respective distribution `num_init_samples` times. If it is a
            tensor, the tensor will be used as init locations.
        num_init_samples: Draw this number of samples from the posterior and
            evaluate the log-probability of all of them.
        save_best_every: The best log-probability is computed, saved in the
            `map`-attribute, and printed every `save_best_every`-th iteration.
            Computing the best log-probability creates a significant overhead
            (thus, the default is `10`.)
        show_progress_bars: Whether to show a progressbar during sampling from the
            posterior.
        force_update: Whether to re-calculate the MAP when x is unchanged and
            have a cached value.

    Returns:
        The MAP estimate.
    """
    # Delegate to the shared gradient-ascent implementation on the base class;
    # this override only exists to specialize the docstring for this posterior.
    return super().map(
        x=x,
        num_iter=num_iter,
        num_to_optimize=num_to_optimize,
        learning_rate=learning_rate,
        init_method=init_method,
        num_init_samples=num_init_samples,
        save_best_every=save_best_every,
        show_progress_bars=show_progress_bars,
        force_update=force_update,
    )

sample(sample_shape=torch.Size(), x=None, method=None, oversampling_factor=32, max_sampling_batch_size=10000, show_progress_bars=False)

Draw samples from the approximate posterior distribution \(p(\theta|x)\).

Parameters:

Name Type Description Default
sample_shape Shape

Shape of samples that are drawn from posterior.

Size()
x Optional[Tensor]

Conditioning observation \(x_o\). If not provided, uses the default x set via .set_default_x().

None
method Optional[str]

Either of [sir|importance]. This sets the behavior of the .sample() method. With sir, approximate posterior samples are generated with sampling importance resampling (SIR). With importance, the .sample() method returns a tuple of samples and corresponding importance weights.

None
oversampling_factor int

Number of proposed samples from which only one is selected based on its importance weight.

32
max_sampling_batch_size int

The batch size of samples being drawn from the proposal at every iteration.

10000
show_progress_bars bool

Whether to show a progressbar during sampling.

False
Source code in sbi/inference/posteriors/importance_posterior.py
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
def sample(
    self,
    sample_shape: Shape = torch.Size(),
    x: Optional[Tensor] = None,
    method: Optional[str] = None,
    oversampling_factor: int = 32,
    max_sampling_batch_size: int = 10_000,
    show_progress_bars: bool = False,
) -> Union[Tensor, Tuple[Tensor, Tensor]]:
    r"""Draw samples from the approximate posterior distribution $p(\theta|x)$.

    Args:
        sample_shape: Shape of samples that are drawn from posterior.
        x: Conditioning observation $x_o$. If not provided, uses the default `x`
            set via `.set_default_x()`.
        method: Either of [`sir`|`importance`]. This sets the behavior of the
            `.sample()` method. With `sir`, approximate posterior samples are
            generated with sampling importance resampling (SIR). With
            `importance`, the `.sample()` method returns a tuple of samples and
            corresponding importance weights.
        oversampling_factor: Number of proposed samples from which only one is
            selected based on its importance weight.
        max_sampling_batch_size: The batch size of samples being drawn from the
            proposal at every iteration.
        show_progress_bars: Whether to show a progressbar during sampling.

    Returns:
        Posterior samples for `sir`, or a tuple of samples and corresponding
        importance weights for `importance`.

    Raises:
        NameError: If `method` is neither `sir` nor `importance`.
    """
    # Fall back to the method chosen at initialization if none is passed.
    method = self.method if method is None else method

    self.potential_fn.set_x(self._x_else_default_x(x))

    if method == "sir":
        return self._sir_sample(
            sample_shape,
            oversampling_factor=oversampling_factor,
            max_sampling_batch_size=max_sampling_batch_size,
            show_progress_bars=show_progress_bars,
        )
    elif method == "importance":
        return self._importance_sample(sample_shape)
    else:
        # Match the error style of `MCMCPosterior.sample`, which names the
        # offending method instead of raising a bare `NameError`.
        raise NameError(f"The sampling method {method} is not implemented!")

to(device)

Move the potential, the proposal and x_o to a new device.

It also reinstantiates the posterior with the new device.

Parameters:

Name Type Description Default
device Union[str, device]

Device on which to move the posterior to.

required
Source code in sbi/inference/posteriors/importance_posterior.py
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
def to(self, device: Union[str, torch.device]) -> None:
    """
    Move the potential, the proposal and x_o to a new device.

    It also reinstantiates the posterior with the new device.

    Args:
        device: Device on which to move the posterior to.
    """
    self.device = device
    self.potential_fn.to(device)  # type: ignore
    self.proposal.to(device)

    # Preserve the default observation (if set) before re-initializing,
    # because `super().__init__` below resets `self._x`.
    cached_x = None
    if getattr(self, "_x", None) is not None:
        cached_x = self._x.to(device)

    self.theta_transform = mcmc_transform(self.proposal, device=device)
    super().__init__(
        self.potential_fn,
        theta_transform=self.theta_transform,
        device=device,
        x_shape=self.x_shape,
    )
    # Restore the default observation that `super().__init__` erased.
    if cached_x is not None:
        self.set_default_x(cached_x)

MCMCPosterior

Bases: NeuralPosterior

Provides MCMC to sample from the posterior.

SNLE or SNRE train neural networks to approximate the likelihood(-ratios). MCMCPosterior allows to sample from the posterior with MCMC.

Source code in sbi/inference/posteriors/mcmc_posterior.py
  39
  40
  41
  42
  43
  44
  45
  46
  47
  48
  49
  50
  51
  52
  53
  54
  55
  56
  57
  58
  59
  60
  61
  62
  63
  64
  65
  66
  67
  68
  69
  70
  71
  72
  73
  74
  75
  76
  77
  78
  79
  80
  81
  82
  83
  84
  85
  86
  87
  88
  89
  90
  91
  92
  93
  94
  95
  96
  97
  98
  99
 100
 101
 102
 103
 104
 105
 106
 107
 108
 109
 110
 111
 112
 113
 114
 115
 116
 117
 118
 119
 120
 121
 122
 123
 124
 125
 126
 127
 128
 129
 130
 131
 132
 133
 134
 135
 136
 137
 138
 139
 140
 141
 142
 143
 144
 145
 146
 147
 148
 149
 150
 151
 152
 153
 154
 155
 156
 157
 158
 159
 160
 161
 162
 163
 164
 165
 166
 167
 168
 169
 170
 171
 172
 173
 174
 175
 176
 177
 178
 179
 180
 181
 182
 183
 184
 185
 186
 187
 188
 189
 190
 191
 192
 193
 194
 195
 196
 197
 198
 199
 200
 201
 202
 203
 204
 205
 206
 207
 208
 209
 210
 211
 212
 213
 214
 215
 216
 217
 218
 219
 220
 221
 222
 223
 224
 225
 226
 227
 228
 229
 230
 231
 232
 233
 234
 235
 236
 237
 238
 239
 240
 241
 242
 243
 244
 245
 246
 247
 248
 249
 250
 251
 252
 253
 254
 255
 256
 257
 258
 259
 260
 261
 262
 263
 264
 265
 266
 267
 268
 269
 270
 271
 272
 273
 274
 275
 276
 277
 278
 279
 280
 281
 282
 283
 284
 285
 286
 287
 288
 289
 290
 291
 292
 293
 294
 295
 296
 297
 298
 299
 300
 301
 302
 303
 304
 305
 306
 307
 308
 309
 310
 311
 312
 313
 314
 315
 316
 317
 318
 319
 320
 321
 322
 323
 324
 325
 326
 327
 328
 329
 330
 331
 332
 333
 334
 335
 336
 337
 338
 339
 340
 341
 342
 343
 344
 345
 346
 347
 348
 349
 350
 351
 352
 353
 354
 355
 356
 357
 358
 359
 360
 361
 362
 363
 364
 365
 366
 367
 368
 369
 370
 371
 372
 373
 374
 375
 376
 377
 378
 379
 380
 381
 382
 383
 384
 385
 386
 387
 388
 389
 390
 391
 392
 393
 394
 395
 396
 397
 398
 399
 400
 401
 402
 403
 404
 405
 406
 407
 408
 409
 410
 411
 412
 413
 414
 415
 416
 417
 418
 419
 420
 421
 422
 423
 424
 425
 426
 427
 428
 429
 430
 431
 432
 433
 434
 435
 436
 437
 438
 439
 440
 441
 442
 443
 444
 445
 446
 447
 448
 449
 450
 451
 452
 453
 454
 455
 456
 457
 458
 459
 460
 461
 462
 463
 464
 465
 466
 467
 468
 469
 470
 471
 472
 473
 474
 475
 476
 477
 478
 479
 480
 481
 482
 483
 484
 485
 486
 487
 488
 489
 490
 491
 492
 493
 494
 495
 496
 497
 498
 499
 500
 501
 502
 503
 504
 505
 506
 507
 508
 509
 510
 511
 512
 513
 514
 515
 516
 517
 518
 519
 520
 521
 522
 523
 524
 525
 526
 527
 528
 529
 530
 531
 532
 533
 534
 535
 536
 537
 538
 539
 540
 541
 542
 543
 544
 545
 546
 547
 548
 549
 550
 551
 552
 553
 554
 555
 556
 557
 558
 559
 560
 561
 562
 563
 564
 565
 566
 567
 568
 569
 570
 571
 572
 573
 574
 575
 576
 577
 578
 579
 580
 581
 582
 583
 584
 585
 586
 587
 588
 589
 590
 591
 592
 593
 594
 595
 596
 597
 598
 599
 600
 601
 602
 603
 604
 605
 606
 607
 608
 609
 610
 611
 612
 613
 614
 615
 616
 617
 618
 619
 620
 621
 622
 623
 624
 625
 626
 627
 628
 629
 630
 631
 632
 633
 634
 635
 636
 637
 638
 639
 640
 641
 642
 643
 644
 645
 646
 647
 648
 649
 650
 651
 652
 653
 654
 655
 656
 657
 658
 659
 660
 661
 662
 663
 664
 665
 666
 667
 668
 669
 670
 671
 672
 673
 674
 675
 676
 677
 678
 679
 680
 681
 682
 683
 684
 685
 686
 687
 688
 689
 690
 691
 692
 693
 694
 695
 696
 697
 698
 699
 700
 701
 702
 703
 704
 705
 706
 707
 708
 709
 710
 711
 712
 713
 714
 715
 716
 717
 718
 719
 720
 721
 722
 723
 724
 725
 726
 727
 728
 729
 730
 731
 732
 733
 734
 735
 736
 737
 738
 739
 740
 741
 742
 743
 744
 745
 746
 747
 748
 749
 750
 751
 752
 753
 754
 755
 756
 757
 758
 759
 760
 761
 762
 763
 764
 765
 766
 767
 768
 769
 770
 771
 772
 773
 774
 775
 776
 777
 778
 779
 780
 781
 782
 783
 784
 785
 786
 787
 788
 789
 790
 791
 792
 793
 794
 795
 796
 797
 798
 799
 800
 801
 802
 803
 804
 805
 806
 807
 808
 809
 810
 811
 812
 813
 814
 815
 816
 817
 818
 819
 820
 821
 822
 823
 824
 825
 826
 827
 828
 829
 830
 831
 832
 833
 834
 835
 836
 837
 838
 839
 840
 841
 842
 843
 844
 845
 846
 847
 848
 849
 850
 851
 852
 853
 854
 855
 856
 857
 858
 859
 860
 861
 862
 863
 864
 865
 866
 867
 868
 869
 870
 871
 872
 873
 874
 875
 876
 877
 878
 879
 880
 881
 882
 883
 884
 885
 886
 887
 888
 889
 890
 891
 892
 893
 894
 895
 896
 897
 898
 899
 900
 901
 902
 903
 904
 905
 906
 907
 908
 909
 910
 911
 912
 913
 914
 915
 916
 917
 918
 919
 920
 921
 922
 923
 924
 925
 926
 927
 928
 929
 930
 931
 932
 933
 934
 935
 936
 937
 938
 939
 940
 941
 942
 943
 944
 945
 946
 947
 948
 949
 950
 951
 952
 953
 954
 955
 956
 957
 958
 959
 960
 961
 962
 963
 964
 965
 966
 967
 968
 969
 970
 971
 972
 973
 974
 975
 976
 977
 978
 979
 980
 981
 982
 983
 984
 985
 986
 987
 988
 989
 990
 991
 992
 993
 994
 995
 996
 997
 998
 999
1000
1001
1002
1003
1004
1005
1006
1007
1008
1009
1010
1011
1012
1013
1014
1015
1016
1017
1018
1019
1020
class MCMCPosterior(NeuralPosterior):
    r"""Provides MCMC to sample from the posterior.

    SNLE or SNRE train neural networks to approximate the likelihood(-ratios).
    `MCMCPosterior` allows to sample from the posterior with MCMC.
    """

    def __init__(
        self,
        potential_fn: Union[Callable, BasePotential],
        proposal: Any,
        theta_transform: Optional[TorchTransform] = None,
        method: Literal[
            "slice_np",
            "slice_np_vectorized",
            "hmc_pyro",
            "nuts_pyro",
            "slice_pymc",
            "hmc_pymc",
            "nuts_pymc",
        ] = "slice_np_vectorized",
        thin: int = -1,
        warmup_steps: int = 200,
        num_chains: int = 20,
        init_strategy: Literal["proposal", "sir", "resample"] = "resample",
        init_strategy_parameters: Optional[Dict[str, Any]] = None,
        init_strategy_num_candidates: Optional[int] = None,
        num_workers: int = 1,
        mp_context: Literal["fork", "spawn"] = "spawn",
        device: Optional[Union[str, torch.device]] = None,
        x_shape: Optional[torch.Size] = None,
    ):
        """
        Args:
            potential_fn: The potential function from which to draw samples. Must be a
                `BasePotential` or a `Callable` which takes `theta` and `x_o` as inputs.
            proposal: Proposal distribution that is used to initialize the MCMC chain.
            theta_transform: Transformation that will be applied during sampling.
                Allows to perform MCMC in unconstrained space.
            method: Method used for MCMC sampling, one of `slice_np`,
                `slice_np_vectorized`, `hmc_pyro`, `nuts_pyro`, `slice_pymc`,
                `hmc_pymc`, `nuts_pymc`. `slice_np` is a custom
                numpy implementation of slice sampling. `slice_np_vectorized` is
                identical to `slice_np`, but if `num_chains>1`, the chains are
                vectorized for `slice_np_vectorized` whereas they are run sequentially
                for `slice_np`. The samplers ending on `_pyro` are using Pyro, and
                likewise the samplers ending on `_pymc` are using PyMC.
            thin: The thinning factor for the chain, default 1 (no thinning).
            warmup_steps: The initial number of samples to discard.
            num_chains: The number of chains. Should generally be at most
                `num_workers - 1`.
            init_strategy: The initialisation strategy for chains; `proposal` will draw
                init locations from `proposal`, whereas `sir` will use Sequential-
                Importance-Resampling (SIR). SIR initially samples
                `init_strategy_num_candidates` from the `proposal`, evaluates all of
                them under the `potential_fn` and `proposal`, and then resamples the
                initial locations with weights proportional to `exp(potential_fn -
                proposal.log_prob`. `resample` is the same as `sir` but
                uses `exp(potential_fn)` as weights.
            init_strategy_parameters: Dictionary of keyword arguments passed to the
                init strategy, e.g., for `init_strategy=sir` this could be
                `num_candidate_samples`, i.e., the number of candidates to find init
                locations (internal default is `1000`), or `device`.
            init_strategy_num_candidates: Number of candidates to find init
                 locations in `init_strategy=sir` (deprecated, use
                 init_strategy_parameters instead).
            num_workers: number of cpu cores used to parallelize mcmc
            mp_context: Multiprocessing start method, either `"fork"` or `"spawn"`
                (default), used by Pyro and PyMC samplers. `"fork"` can be significantly
                faster than `"spawn"` but is only supported on POSIX-based systems
                (e.g. Linux and macOS, not Windows).
            device: Training device, e.g., "cpu", "cuda" or "cuda:0". If None,
                `potential_fn.device` is used.
            x_shape: Deprecated, should not be passed.
        """
        # Legacy method name: redirect `slice` to the numpy-based slice sampler.
        if method == "slice":
            warn(
                "The Pyro-based slice sampler is deprecated, and the method `slice` "
                "has been changed to `slice_np`, i.e., the custom "
                "numpy-based slice sampler.",
                DeprecationWarning,
                stacklevel=2,
            )
            method = "slice_np"

        thin = _process_thin_default(thin)

        super().__init__(
            potential_fn,
            theta_transform=theta_transform,
            device=device,
            x_shape=x_shape,
        )

        self.proposal = proposal
        self.method = method
        self.thin = thin
        self.warmup_steps = warmup_steps
        self.num_chains = num_chains
        self.init_strategy = init_strategy
        self.init_strategy_parameters = init_strategy_parameters or {}
        self.num_workers = num_workers
        self.mp_context = mp_context
        self._posterior_sampler = None
        # Hardcode parameter name to reduce clutter kwargs.
        self.param_name = "theta"
        self.x_shape = x_shape

        if init_strategy_num_candidates is not None:
            # NOTE: This must NOT be an f-string. In an f-string,
            # `{'num_candidate_samples': 1000}` is parsed as replacement field
            # `'num_candidate_samples'` with format spec `1000`, which pads the
            # string to width 1000 instead of showing the literal dict.
            warn(
                "Passing `init_strategy_num_candidates` is deprecated as of sbi "
                "v0.19.0. Instead, use e.g., `init_strategy_parameters "
                "={'num_candidate_samples': 1000}`",
                stacklevel=2,
            )
            self.init_strategy_parameters["num_candidate_samples"] = (
                init_strategy_num_candidates
            )

        self.potential_ = self._prepare_potential(method)

        self._purpose = (
            "It provides MCMC to .sample() from the posterior and "
            "can evaluate the _unnormalized_ posterior density with .log_prob()."
        )

    def to(self, device: Union[str, torch.device]) -> None:
        """Moves potential_fn, proposal, x_o and theta_transform to the

        specified device. Reinstantiates the posterior and resets the default x_o.

        Args:
            device: Device to move the posterior to.
        """
        self.device = device
        self.potential_fn.to(device)  # type: ignore
        self.proposal.to(device)

        # Keep the default observation (if any) around: re-running
        # `super().__init__` below wipes `self._x`.
        cached_x = None
        if getattr(self, "_x", None) is not None:
            cached_x = self._x.to(device)

        self.theta_transform = mcmc_transform(self.proposal, device=device)

        super().__init__(
            self.potential_fn,
            theta_transform=self.theta_transform,
            device=device,
            x_shape=self.x_shape,
        )
        # Restore the default observation that `super().__init__` erased.
        if cached_x is not None:
            self.set_default_x(cached_x)
        # Rebuild the prepared potential so it references the moved potential_fn.
        self.potential_ = self._prepare_potential(self.method)

    @property
    def mcmc_method(self) -> str:
        """Returns MCMC method.

        NOTE(review): `self._mcmc_method` is only assigned by
        `set_mcmc_method()` (or the property setter); reading this property
        before either has been called raises `AttributeError`.
        """
        return self._mcmc_method

    @mcmc_method.setter
    def mcmc_method(self, method: str) -> None:
        """Property-style alias for `set_mcmc_method` (return value discarded)."""
        self.set_mcmc_method(method)

    @property
    def posterior_sampler(self):
        """Returns sampler created by `sample`.

        NOTE(review): `self._posterior_sampler` is initialized to `None` in
        `__init__`; presumably one of the MCMC helpers invoked by `sample()`
        populates it — confirm in the sampler helpers (not visible here).
        """
        return self._posterior_sampler

    def set_mcmc_method(self, method: str) -> "NeuralPosterior":
        """Set the MCMC sampling method and return `self` for chaining.

        Args:
            method: Name of the MCMC method to use.

        Returns:
            This `NeuralPosterior`, enabling chainable calls.
        """
        # Stored on the private attribute backing the `mcmc_method` property.
        self._mcmc_method = method
        return self

    def log_prob(
        self, theta: Tensor, x: Optional[Tensor] = None, track_gradients: bool = False
    ) -> Tensor:
        r"""Returns the log-probability of theta under the posterior.

        Args:
            theta: Parameters $\theta$.
            x: Conditioning observation $x_o$. If not provided, the default `x`
                set via `.set_default_x()` is used.
            track_gradients: Whether the returned tensor supports tracking gradients.
                This can be helpful for e.g. sensitivity analysis, but increases memory
                consumption.

        Returns:
            `len($\theta$)`-shaped log-probability.
        """
        # Both warnings are kept: this posterior can only evaluate an
        # unnormalized density, so `.potential()` is the preferred API.
        warn(
            "`.log_prob()` is deprecated for methods that can only evaluate the "
            "log-probability up to a normalizing constant. Use `.potential()` instead.",
            stacklevel=2,
        )
        warn("The log-probability is unnormalized!", stacklevel=2)

        self.potential_fn.set_x(self._x_else_default_x(x))

        # Guarantee a leading batch dimension before evaluating the potential.
        batched_theta = ensure_theta_batched(torch.as_tensor(theta))
        return self.potential_fn(
            batched_theta.to(self._device), track_gradients=track_gradients
        )

    def sample(
        self,
        sample_shape: Shape = torch.Size(),
        x: Optional[Tensor] = None,
        method: Optional[str] = None,
        thin: Optional[int] = None,
        warmup_steps: Optional[int] = None,
        num_chains: Optional[int] = None,
        init_strategy: Optional[str] = None,
        init_strategy_parameters: Optional[Dict[str, Any]] = None,
        num_workers: Optional[int] = None,
        mp_context: Optional[str] = None,
        show_progress_bars: bool = True,
    ) -> Tensor:
        r"""Draw samples from the approximate posterior distribution $p(\theta|x)$.

        Args:
            sample_shape: Desired shape of samples that are drawn from posterior. If
                sample_shape is multidimensional we simply draw `sample_shape.numel()`
                samples and then reshape into the desired shape.
            x: Conditioning observation $x_o$. If not provided, uses the default `x`
                set via `.set_default_x()`.
            method: MCMC method to use. One of `slice_np`, `slice_np_vectorized`,
                `hmc_pyro`, `nuts_pyro`, `slice_pymc`, `hmc_pymc`, `nuts_pymc`.
                If not provided, uses the method specified at initialization.
            thin: Thinning factor for the chain. If not provided, uses the value
                specified at initialization.
            warmup_steps: Number of warmup steps to discard. If not provided, uses
                the value specified at initialization.
            num_chains: Number of MCMC chains to run. If not provided, uses the
                value specified at initialization.
            init_strategy: Initialization strategy for chains (`proposal`, `sir`,
                or `resample`). If not provided, uses the value specified at
                initialization.
            init_strategy_parameters: Parameters for the initialization strategy.
                If not provided, uses the value specified at initialization.
            num_workers: Number of CPU cores for parallelization. If not provided,
                uses the value specified at initialization.
            mp_context: Multiprocessing context (`fork` or `spawn`). If not provided,
                uses the value specified at initialization.
            show_progress_bars: Whether to show sampling progress monitor.

        Returns:
            Samples from posterior.

        Raises:
            NameError: If `method` is not one of the implemented samplers.
        """

        self.potential_fn.set_x(self._x_else_default_x(x))

        # Replace arguments that were not passed with their default.
        method = self.method if method is None else method
        thin = self.thin if thin is None else thin
        warmup_steps = self.warmup_steps if warmup_steps is None else warmup_steps
        num_chains = self.num_chains if num_chains is None else num_chains
        init_strategy = self.init_strategy if init_strategy is None else init_strategy
        num_workers = self.num_workers if num_workers is None else num_workers
        mp_context = self.mp_context if mp_context is None else mp_context
        init_strategy_parameters = (
            self.init_strategy_parameters
            if init_strategy_parameters is None
            else init_strategy_parameters
        )
        # Re-prepare the potential in case `method` differs from the one used
        # at initialization (different samplers wrap the potential differently).
        self.potential_ = self._prepare_potential(method)  # type: ignore

        initial_params = self._get_initial_params(
            init_strategy,  # type: ignore
            num_chains,  # type: ignore
            num_workers,
            show_progress_bars,
            **init_strategy_parameters,
        )
        # Total number of samples to draw across all chains.
        num_samples = torch.Size(sample_shape).numel()

        # Gradient-based samplers (HMC/NUTS) need gradients of the potential;
        # slice samplers do not.
        track_gradients = method in ("hmc_pyro", "nuts_pyro", "hmc_pymc", "nuts_pymc")
        with torch.set_grad_enabled(track_gradients):
            # Dispatch to the backend implementing the chosen method.
            if method in ("slice_np", "slice_np_vectorized"):
                transformed_samples = self._slice_np_mcmc(
                    num_samples=num_samples,
                    potential_function=self.potential_,
                    initial_params=initial_params,
                    thin=thin,  # type: ignore
                    warmup_steps=warmup_steps,  # type: ignore
                    vectorized=(method == "slice_np_vectorized"),
                    interchangeable_chains=True,
                    num_workers=num_workers,
                    show_progress_bars=show_progress_bars,
                )
            elif method in ("hmc_pyro", "nuts_pyro"):
                transformed_samples = self._pyro_mcmc(
                    num_samples=num_samples,
                    potential_function=self.potential_,
                    initial_params=initial_params,
                    mcmc_method=method,  # type: ignore
                    thin=thin,  # type: ignore
                    warmup_steps=warmup_steps,  # type: ignore
                    num_chains=num_chains,
                    show_progress_bars=show_progress_bars,
                    mp_context=mp_context,
                )
            elif method in ("hmc_pymc", "nuts_pymc", "slice_pymc"):
                transformed_samples = self._pymc_mcmc(
                    num_samples=num_samples,
                    potential_function=self.potential_,
                    initial_params=initial_params,
                    mcmc_method=method,  # type: ignore
                    thin=thin,  # type: ignore
                    warmup_steps=warmup_steps,  # type: ignore
                    num_chains=num_chains,
                    show_progress_bars=show_progress_bars,
                    mp_context=mp_context,
                )
            else:
                raise NameError(f"The sampling method {method} is not implemented!")

        # MCMC ran in unconstrained space; map samples back to parameter space.
        samples = self.theta_transform.inv(transformed_samples)
        # NOTE: Currently MCMCPosteriors will require a single dimension for the
        # parameter dimension. With recent ConditionalDensity(Ratio) estimators, we
        # can have multiple dimensions for the parameter dimension.
        samples = samples.reshape((*sample_shape, -1))  # type: ignore

        return samples

    def sample_batched(
        self,
        sample_shape: Shape,
        x: Tensor,
        method: Optional[str] = None,
        thin: Optional[int] = None,
        warmup_steps: Optional[int] = None,
        num_chains: Optional[int] = None,
        init_strategy: Optional[str] = None,
        init_strategy_parameters: Optional[Dict[str, Any]] = None,
        num_workers: Optional[int] = None,
        mp_context: Optional[str] = None,
        show_progress_bars: bool = True,
    ) -> Tensor:
        r"""Draw samples from the posteriors for a batch of different xs.

        Given a batch of observations `[x_1, ..., x_B]`, this method samples from
        posteriors $p(\theta|x_1), \ldots, p(\theta|x_B)$ in a vectorized manner.

        Check the `__init__()` method for a description of all arguments as well as
        their default values.

        Args:
            sample_shape: Desired shape of samples that are drawn from the posterior
                given every observation.
            x: A batch of observations, of shape `(batch_dim, event_shape_x)`.
                `batch_dim` corresponds to the number of observations to be
                drawn.
            method: Method used for MCMC sampling, e.g., "slice_np_vectorized".
            thin: The thinning factor for the chain, default 1 (no thinning).
            warmup_steps: The initial number of samples to discard.
            num_chains: The number of chains used for each `x` passed in the batch.
            init_strategy: The initialisation strategy for chains.
            init_strategy_parameters: Dictionary of keyword arguments passed to
                the init strategy.
            num_workers: number of cpu cores used to parallelize initial
                parameter generation and mcmc sampling.
            mp_context: Multiprocessing start method, either `"fork"` or `"spawn"`
            show_progress_bars: Whether to show sampling progress monitor.

        Returns:
            Samples from the posteriors of shape (*sample_shape, B, *input_shape)
        """

        # Replace arguments that were not passed with their default.
        method = self.method if method is None else method
        thin = self.thin if thin is None else thin
        warmup_steps = self.warmup_steps if warmup_steps is None else warmup_steps
        num_chains = self.num_chains if num_chains is None else num_chains
        init_strategy = self.init_strategy if init_strategy is None else init_strategy
        num_workers = self.num_workers if num_workers is None else num_workers
        mp_context = self.mp_context if mp_context is None else mp_context
        init_strategy_parameters = (
            self.init_strategy_parameters
            if init_strategy_parameters is None
            else init_strategy_parameters
        )

        assert method == "slice_np_vectorized", (
            "Batched sampling only supported for vectorized samplers!"
        )

        # warn if num_chains is larger than num requested samples
        if num_chains > torch.Size(sample_shape).numel():
            warnings.warn(
                "The passed number of MCMC chains is larger than the number of "
                f"requested samples: {num_chains} > {torch.Size(sample_shape).numel()},"
                f" resetting it to {torch.Size(sample_shape).numel()}.",
                stacklevel=2,
            )
            num_chains = torch.Size(sample_shape).numel()

        # custom shape handling to make sure to match the batch size of x and theta
        # without unnecessary combinations.
        if len(x.shape) == 1:
            x = x.unsqueeze(0)
        batch_size = x.shape[0]

        x = reshape_to_batch_event(x, event_shape=x.shape[1:])

        # For batched sampling, we want `num_chains` for each observation in the batch.
        # Here we repeat the observations ABC -> AAABBBCCC, so that the chains are
        # in the order of the observations.
        x_ = x.repeat_interleave(num_chains, dim=0)

        self.potential_fn.set_x(x_, x_is_iid=False)
        self.potential_ = self._prepare_potential(method)  # type: ignore

        # For each observation in the batch, we have num_chains independent chains.
        num_chains_extended = batch_size * num_chains
        if num_chains_extended > 100:
            warnings.warn(
                "Note that for batched sampling, we use num_chains many chains "
                "for each x in the batch. With the given settings, this results "
                f"in a large number of chains ({num_chains_extended}), which can "
                "be slow and memory-intensive for vectorized MCMC. Consider "
                "reducing the number of chains or batch size.",
                stacklevel=2,
            )
        init_strategy_parameters["num_return_samples"] = num_chains_extended
        initial_params = self._get_initial_params_batched(
            x,
            init_strategy,  # type: ignore
            num_chains,  # type: ignore
            num_workers,
            show_progress_bars,
            **init_strategy_parameters,
        )
        # We need num_samples from each posterior in the batch
        num_samples = torch.Size(sample_shape).numel() * batch_size

        with torch.set_grad_enabled(False):
            transformed_samples = self._slice_np_mcmc(
                num_samples=num_samples,
                potential_function=self.potential_,
                initial_params=initial_params,
                thin=thin,  # type: ignore
                warmup_steps=warmup_steps,  # type: ignore
                vectorized=(method == "slice_np_vectorized"),
                interchangeable_chains=False,
                num_workers=num_workers,
                show_progress_bars=show_progress_bars,
            )

        # (num_chains_extended, samples_per_chain, *input_shape)
        samples_per_chain: Tensor = self.theta_transform.inv(transformed_samples)  # type: ignore
        dim_theta = samples_per_chain.shape[-1]
        # We need to collect samples for each x from the respective chains.
        # However, using samples.reshape(*sample_shape, batch_size, dim_theta)
        # does not combine the samples in the right order, since this mixes
        # samples that belong to different `x`. The following permute is a
        # workaround to reshape the samples in the right order.
        samples_per_x = samples_per_chain.reshape((
            batch_size,
            # We are flattening the sample shape here using -1 because we might have
            # generated more samples than requested (more chains, or multiple of
            # chains not matching sample_shape)
            -1,
            dim_theta,
        )).permute(1, 0, -1)

        # Shape is now (-1, batch_size, dim_theta)
        # We can now select the number of requested samples
        samples = samples_per_x[: torch.Size(sample_shape).numel()]
        # and reshape into (*sample_shape, batch_size, dim_theta)
        samples = samples.reshape((*sample_shape, batch_size, dim_theta))
        return samples

    def _build_mcmc_init_fn(
        self,
        proposal: Any,
        potential_fn: Callable,
        transform: torch_tf.Transform,
        init_strategy: str,
        **kwargs,
    ) -> Callable:
        """Return function that, when called, creates an initial parameter set for MCMC.

        Args:
            proposal: Proposal distribution.
            potential_fn: Potential function that the candidate samples are weighted
                with.
            init_strategy: Specifies the initialization method. Either of
                [`proposal`|`sir`|`resample`|`latest_sample`].
            kwargs: Passed on to init function. This way, init specific keywords can
                be set through `mcmc_parameters`. Unused arguments will be absorbed by
                the intitialization method.

        Returns: Initialization function.
        """
        if init_strategy == "proposal" or init_strategy == "prior":
            if init_strategy == "prior":
                warn(
                    "You set `init_strategy=prior`. As of sbi v0.18.0, this is "
                    "deprecated and it will be removed in a future release. Use "
                    "`init_strategy=proposal` instead.",
                    stacklevel=2,
                )
            return lambda: proposal_init(proposal, transform=transform, **kwargs)
        elif init_strategy == "sir":
            warn(
                "As of sbi v0.19.0, the behavior of the SIR initialization for MCMC "
                "has changed. If you wish to restore the behavior of sbi v0.18.0, set "
                "`init_strategy='resample'.`",
                stacklevel=2,
            )
            return lambda: sir_init(
                proposal, potential_fn, transform=transform, **kwargs
            )
        elif init_strategy == "resample":
            return lambda: resample_given_potential_fn(
                proposal, potential_fn, transform=transform, **kwargs
            )
        elif init_strategy == "latest_sample":
            latest_sample = IterateParameters(self._mcmc_init_params, **kwargs)
            return latest_sample
        else:
            raise NotImplementedError

    def _get_initial_params(
        self,
        init_strategy: str,
        num_chains: int,
        num_workers: int,
        show_progress_bars: bool,
        **kwargs,
    ) -> Tensor:
        """Return initial parameters for MCMC obtained with given init strategy.

        Parallelizes across CPU cores only for resample and SIR.

        Args:
            init_strategy: Specifies the initialization method. Either of
                [`proposal`|`sir`|`resample`|`latest_sample`].
            num_chains: Number of MCMC chains; one init is generated per chain.
            num_workers: Number of CPU cores used for parallelization.
            show_progress_bars: Whether to show progress bars for SIR init.
            kwargs: Passed on to `_build_mcmc_init_fn`.

        Returns:
            Tensor: initial parameters, one for each chain
        """
        # Build init function
        init_fn = self._build_mcmc_init_fn(
            self.proposal,
            self.potential_fn,
            transform=self.theta_transform,
            init_strategy=init_strategy,  # type: ignore
            **kwargs,
        )

        progress_desc = (
            f"Generating {num_chains} MCMC inits via {init_strategy} strategy"
        )

        # Parallelize inits for resampling only.
        if num_workers > 1 and init_strategy in ("resample", "sir"):

            def seeded_init_fn(seed):
                torch.manual_seed(seed)
                return init_fn()

            seeds = torch.randint(high=2**31, size=(num_chains,))

            # Generate initial params parallelized over num_workers.
            init_generator = Parallel(return_as="generator", n_jobs=num_workers)(
                delayed(seeded_init_fn)(seed) for seed in seeds
            )
            chain_inits = list(
                tqdm(
                    init_generator,
                    total=len(seeds),
                    desc=progress_desc,
                    disable=not show_progress_bars,
                )
            )
        else:
            chain_inits = [
                init_fn()
                for _ in tqdm(
                    range(num_chains),
                    desc=progress_desc,
                    disable=not show_progress_bars,
                )
            ]

        initial_params = torch.cat(chain_inits)  # type: ignore
        assert initial_params.shape[0] == num_chains, "Initial params shape mismatch."
        return initial_params

    def _get_initial_params_batched(
        self,
        x: torch.Tensor,
        init_strategy: str,
        num_chains_per_x: int,
        num_workers: int,
        show_progress_bars: bool,
        **kwargs,
    ) -> Tensor:
        """Return initial MCMC parameters for each observation in a batch of `x`,
           obtained with the given init strategy.

        Parallelizes across CPU cores only for resample and SIR.

        Args:
            x: Batch of observations to create different initial parameters for.
            init_strategy: Specifies the initialization method. Either of
                [`proposal`|`sir`|`resample`|`latest_sample`].
            num_chains_per_x: Number of MCMC chains per observation; one init is
                generated for each chain of each observation.
            num_workers: Number of CPU cores used for parallelization.
            show_progress_bars: Whether to show progress bars for SIR init.
            kwargs: Passed on to `_build_mcmc_init_fn`.

        Returns:
            Tensor: initial parameters, one for each chain
        """
        # Work on a copy so that setting x per observation below does not
        # mutate the state of `self.potential_fn`.
        potential_ = deepcopy(self.potential_fn)
        init_fn = self._build_mcmc_init_fn(
            self.proposal,
            potential_fn=potential_,
            transform=self.theta_transform,
            init_strategy=init_strategy,  # type: ignore
            **kwargs,
        )
        # Parallelize inits for resampling or sir.
        use_parallel = num_workers > 1 and init_strategy in ("resample", "sir")

        initial_params = []
        for xi in x:
            # Condition the (copied) potential on the current observation.
            potential_.set_x(xi)

            if use_parallel:

                def seeded_init_fn(seed):
                    torch.manual_seed(seed)
                    return init_fn()

                seeds = torch.randint(high=2**31, size=(num_chains_per_x,))

                # Generate initial params parallelized over num_workers.
                initial_params.extend(
                    tqdm(
                        Parallel(return_as="generator", n_jobs=num_workers)(
                            delayed(seeded_init_fn)(seed) for seed in seeds
                        ),
                        total=len(seeds),
                        desc=f"""Generating {num_chains_per_x} MCMC inits with
                                {num_workers} workers.""",
                        disable=not show_progress_bars,
                    )
                )
            else:
                initial_params.extend(init_fn() for _ in range(num_chains_per_x))

        return torch.cat(initial_params)

    def _slice_np_mcmc(
        self,
        num_samples: int,
        potential_function: Callable,
        initial_params: Tensor,
        thin: int,
        warmup_steps: int,
        vectorized: bool = False,
        interchangeable_chains: bool = True,
        num_workers: int = 1,
        init_width: Union[float, ndarray] = 0.01,
        show_progress_bars: bool = True,
    ) -> Tensor:
        """Custom implementation of slice sampling using Numpy.

        Args:
            num_samples: Desired number of samples.
            potential_function: A callable **class**.
            initial_params: Initial parameters for MCMC chain.
            thin: Thinning (subsampling) factor, default 1 (no thinning).
            warmup_steps: Initial number of samples to discard.
            vectorized: Whether to use a vectorized implementation of the
                `SliceSampler`.
            interchangeable_chains: Whether chains are interchangeable, i.e., whether
                we can mix samples between chains.
            num_workers: Number of CPU cores to use.
            init_width: Initial width of brackets.
            show_progress_bars: Whether to show a progressbar during sampling;
                can only be turned off for vectorized sampler.

        Returns:
            Tensor of shape (num_samples, shape_of_single_theta) if
            `interchangeable_chains=True`; otherwise samples are kept per chain
            with shape (num_chains, samples_per_chain, shape_of_single_theta).
        """

        num_chains, dim_samples = initial_params.shape

        # Serial sampler runs chains one after another (optionally over workers);
        # the vectorized sampler advances all chains in lockstep.
        if not vectorized:
            SliceSamplerMultiChain = SliceSamplerSerial
        else:
            SliceSamplerMultiChain = SliceSamplerVectorized

        def multi_obs_potential(params):
            # Params are of shape (num_chains * num_obs, event).
            all_potentials = potential_function(params)  # Shape: (num_chains, num_obs)
            return all_potentials.flatten()

        posterior_sampler = SliceSamplerMultiChain(
            init_params=tensor2numpy(initial_params),
            log_prob_fn=multi_obs_potential,
            num_chains=num_chains,
            thin=thin,
            verbose=show_progress_bars,
            num_workers=num_workers,
            init_width=init_width,
        )
        # NOTE(review): the counts passed to `run` are scaled by `thin` while the
        # returned array is indexed in thinned units (`warmup_steps` is cut
        # below) — presumably `run` takes raw step counts and returns thinned
        # samples; confirm against the SliceSampler implementation.
        warmup_ = warmup_steps * thin
        num_samples_ = ceil((num_samples * thin) / num_chains)
        # Run mcmc including warmup
        samples = posterior_sampler.run(warmup_ + num_samples_)
        samples = samples[:, warmup_steps:, :]  # discard warmup steps
        samples = torch.from_numpy(samples)  # chains x samples x dim

        # Save posterior sampler.
        self._posterior_sampler = posterior_sampler

        # Save sample as potential next init (if init_strategy == 'latest_sample').
        self._mcmc_init_params = samples[:, -1, :].reshape(num_chains, dim_samples)

        # Update: If chains are interchangeable, return concatenated samples. Otherwise
        # return samples per chain.
        if interchangeable_chains:
            # Collect samples from all chains.
            samples = samples.reshape(-1, dim_samples)[:num_samples]

        return samples.type(torch.float32).to(self._device)

    def _pyro_mcmc(
        self,
        num_samples: int,
        potential_function: Callable,
        initial_params: Tensor,
        mcmc_method: str = "nuts_pyro",
        thin: int = -1,
        warmup_steps: int = 200,
        num_chains: Optional[int] = 1,
        show_progress_bars: bool = True,
        mp_context: str = "spawn",
    ) -> Tensor:
        r"""Return samples obtained using Pyro's HMC or NUTS sampler.

        Args:
            num_samples: Desired number of samples.
            potential_function: A callable **class**. A class, but not a function,
                is picklable for Pyro MCMC to use it across chains in parallel,
                even when the potential function requires evaluating a neural network.
            initial_params: Initial parameters for MCMC chain.
            mcmc_method: Pyro MCMC method to use, either `"hmc_pyro"` or
                `"nuts_pyro"` (default).
            thin: Thinning (subsampling) factor; `-1` resolves to the package
                default (no thinning).
            warmup_steps: Initial number of samples to discard.
            num_chains: Whether to sample in parallel. If None, use all but one CPU.
            show_progress_bars: Whether to show a progressbar during sampling.
            mp_context: Multiprocessing start method used by Pyro for parallel
                chains.

        Returns:
            Tensor of shape (num_samples, shape_of_single_theta).
        """
        thin = _process_thin_default(thin)
        if num_chains is None:
            num_chains = mp.cpu_count() - 1
        # Resolve the Pyro kernel class from the method name.
        kernel_cls = dict(hmc_pyro=HMC, nuts_pyro=NUTS)[mcmc_method]

        sampler = MCMC(
            kernel=kernel_cls(potential_fn=potential_function),
            num_samples=ceil((thin * num_samples) / num_chains),
            warmup_steps=warmup_steps,
            initial_params={self.param_name: initial_params},
            num_chains=num_chains,
            mp_context=mp_context,
            disable_progbar=not show_progress_bars,
            transforms={},
        )
        sampler.run()
        dim_theta = initial_params.shape[1]
        samples = next(iter(sampler.get_samples().values())).reshape(-1, dim_theta)

        # Save posterior sampler.
        self._posterior_sampler = sampler

        # Apply thinning and truncate to the requested number of samples.
        return samples[::thin][:num_samples].detach()

    def _pymc_mcmc(
        self,
        num_samples: int,
        potential_function: Callable,
        initial_params: Tensor,
        mcmc_method: str = "nuts_pymc",
        thin: int = -1,
        warmup_steps: int = 200,
        num_chains: Optional[int] = 1,
        show_progress_bars: bool = True,
        mp_context: str = "spawn",
    ) -> Tensor:
        r"""Return samples obtained using PyMC's HMC, NUTS or slice samplers.

        Args:
            num_samples: Desired number of samples.
            potential_function: A callable **class**. A class, but not a function,
                is picklable for PyMC MCMC to use it across chains in parallel,
                even when the potential function requires evaluating a neural network.
            initial_params: Initial parameters for MCMC chain.
            mcmc_method: PyMC MCMC method to use, either `"hmc_pymc"`,
                `"slice_pymc"`, or `"nuts_pymc"` (default).
            thin: Thinning (subsampling) factor; `-1` resolves to the package
                default (no thinning).
            warmup_steps: Initial number of samples to discard.
            num_chains: Whether to sample in parallel. If None, use all but one CPU.
            show_progress_bars: Whether to show a progressbar during sampling.
            mp_context: Multiprocessing start method used by PyMC for parallel
                chains.

        Returns:
            Tensor of shape (num_samples, shape_of_single_theta).
        """
        thin = _process_thin_default(thin)
        if num_chains is None:
            num_chains = mp.cpu_count() - 1
        # Resolve the PyMC step-kernel identifier from the method name.
        step_kernel = {"slice_pymc": "slice", "hmc_pymc": "hmc", "nuts_pymc": "nuts"}[
            mcmc_method
        ]

        sampler = PyMCSampler(
            potential_fn=potential_function,
            step=step_kernel,
            initvals=tensor2numpy(initial_params),
            draws=ceil((thin * num_samples) / num_chains),
            tune=warmup_steps,
            chains=num_chains,
            mp_ctx=mp_context,
            progressbar=show_progress_bars,
            param_name=self.param_name,
            device=self._device,
        )
        raw_samples = sampler.run()
        samples = torch.from_numpy(raw_samples).to(
            dtype=torch.float32, device=self._device
        )
        samples = samples.reshape(-1, initial_params.shape[1])

        # Save posterior sampler.
        self._posterior_sampler = sampler

        # Apply thinning and truncate to the requested number of samples.
        return samples[::thin][:num_samples]

    def _prepare_potential(self, method: str) -> Callable:
        """Combines potential and transform and takes care of gradients and pyro.

        Args:
            method: Which MCMC method to use.

        Returns:
            A potential function that is ready to be used in MCMC.
        """
        if method in ("hmc_pyro", "nuts_pyro"):
            track_gradients = True
            pyro = True
        elif method in ("hmc_pymc", "nuts_pymc"):
            track_gradients = True
            pyro = False
        elif method in ("slice_np", "slice_np_vectorized", "slice_pymc"):
            track_gradients = False
            pyro = False
        else:
            if "hmc" in method or "nuts" in method:
                warn(
                    "The kwargs 'hmc' and 'nuts' are deprecated. Use 'hmc_pyro', "
                    "'nuts_pyro', 'hmc_pymc', or 'nuts_pymc' instead.",
                    DeprecationWarning,
                    stacklevel=2,
                )
            raise NotImplementedError(f"MCMC method {method} is not implemented.")

        prepared_potential = partial(
            transformed_potential,
            potential_fn=self.potential_fn,
            theta_transform=self.theta_transform,
            device=self._device,
            track_gradients=track_gradients,
        )
        if pyro:
            prepared_potential = partial(
                pyro_potential_wrapper, potential=prepared_potential
            )

        return prepared_potential

    def map(
        self,
        x: Optional[Tensor] = None,
        num_iter: int = 1_000,
        num_to_optimize: int = 100,
        learning_rate: float = 0.01,
        init_method: Union[str, Tensor] = "proposal",
        num_init_samples: int = 1_000,
        save_best_every: int = 10,
        show_progress_bars: bool = False,
        force_update: bool = False,
    ) -> Tensor:
        r"""Returns the maximum-a-posteriori estimate (MAP).

        The method can be interrupted (Ctrl-C) when the user sees that the
        log-probability converges. The best estimate will be saved in `self._map` and
        can be accessed with `self.map()`. The MAP is obtained by running gradient
        ascent from a given number of starting positions (samples from the posterior
        with the highest log-probability). After the optimization is done, we select the
        parameter set that has the highest log-probability after the optimization.

        This method only forwards all arguments to the parent-class implementation.

        Warning: The default values used by this function are not well-tested. They
        might require hand-tuning for the problem at hand.

        For developers: if the prior is a `BoxUniform`, we carry out the optimization
        in unbounded space and transform the result back into bounded space.

        Args:
            x: Deprecated - use `.set_default_x()` prior to `.map()`.
            num_iter: Number of optimization steps that the algorithm takes
                to find the MAP.
            learning_rate: Learning rate of the optimizer.
            init_method: How to select the starting parameters for the optimization. If
                it is a string, it can be either [`posterior`, `prior`], which samples
                the respective distribution `num_init_samples` times. If it is a
                tensor, the tensor will be used as init locations.
            num_init_samples: Draw this number of samples from the posterior and
                evaluate the log-probability of all of them.
            num_to_optimize: From the drawn `num_init_samples`, use the
                `num_to_optimize` with highest log-probability as the initial points
                for the optimization.
            save_best_every: The best log-probability is computed, saved in the
                `map`-attribute, and printed every `save_best_every`-th iteration.
                Computing the best log-probability creates a significant overhead
                (thus, the default is `10`.)
            show_progress_bars: Whether to show a progressbar during sampling from
                the posterior.
            force_update: Whether to re-calculate the MAP when x is unchanged and
                have a cached value.

        Returns:
            The MAP estimate.
        """
        return super().map(
            x=x,
            num_iter=num_iter,
            num_to_optimize=num_to_optimize,
            learning_rate=learning_rate,
            init_method=init_method,
            num_init_samples=num_init_samples,
            save_best_every=save_best_every,
            show_progress_bars=show_progress_bars,
            force_update=force_update,
        )

    def __getstate__(self) -> Dict:
        """Get state of MCMCPosterior.

        Removes the posterior sampler from the state, as it may not be picklable.

        Returns:
            Dict: State of MCMCPosterior.
        """
        state = self.__dict__.copy()
        state["_posterior_sampler"] = None

        return state

mcmc_method property writable

Returns MCMC method.

posterior_sampler property

Returns sampler created by sample.

__getstate__()

Get state of MCMCPosterior.

Removes the posterior sampler from the state, as it may not be picklable.

Returns:

Name Type Description
Dict Dict

State of MCMCPosterior.

Source code in sbi/inference/posteriors/mcmc_posterior.py
1009
1010
1011
1012
1013
1014
1015
1016
1017
1018
1019
1020
def __getstate__(self) -> Dict:
    """Get state of MCMCPosterior.

    Removes the posterior sampler from the state, as it may not be picklable.

    Returns:
        Dict: State of MCMCPosterior.
    """
    state = self.__dict__.copy()
    # Drop the live sampler object before pickling; it may not be picklable.
    state["_posterior_sampler"] = None

    return state

__init__(potential_fn, proposal, theta_transform=None, method='slice_np_vectorized', thin=-1, warmup_steps=200, num_chains=20, init_strategy='resample', init_strategy_parameters=None, init_strategy_num_candidates=None, num_workers=1, mp_context='spawn', device=None, x_shape=None)

Parameters:

Name Type Description Default
potential_fn Union[Callable, BasePotential]

The potential function from which to draw samples. Must be a BasePotential or a Callable which takes theta and x_o as inputs.

required
proposal Any

Proposal distribution that is used to initialize the MCMC chain.

required
theta_transform Optional[TorchTransform]

Transformation that will be applied during sampling. Allows to perform MCMC in unconstrained space.

None
method Literal['slice_np', 'slice_np_vectorized', 'hmc_pyro', 'nuts_pyro', 'slice_pymc', 'hmc_pymc', 'nuts_pymc']

Method used for MCMC sampling, one of slice_np, slice_np_vectorized, hmc_pyro, nuts_pyro, slice_pymc, hmc_pymc, nuts_pymc. slice_np is a custom numpy implementation of slice sampling. slice_np_vectorized is identical to slice_np, but if num_chains>1, the chains are vectorized for slice_np_vectorized whereas they are run sequentially for slice_np. The samplers ending on _pyro are using Pyro, and likewise the samplers ending on _pymc are using PyMC.

'slice_np_vectorized'
thin int

The thinning factor for the chain, default 1 (no thinning).

-1
warmup_steps int

The initial number of samples to discard.

200
num_chains int

The number of chains. Should generally be at most num_workers - 1.

20
init_strategy Literal['proposal', 'sir', 'resample']

The initialisation strategy for chains; proposal will draw init locations from proposal, whereas sir will use Sequential- Importance-Resampling (SIR). SIR initially samples init_strategy_num_candidates from the proposal, evaluates all of them under the potential_fn and proposal, and then resamples the initial locations with weights proportional to exp(potential_fn - proposal.log_prob). resample is the same as sir but uses exp(potential_fn) as weights.

'resample'
init_strategy_parameters Optional[Dict[str, Any]]

Dictionary of keyword arguments passed to the init strategy, e.g., for init_strategy=sir this could be num_candidate_samples, i.e., the number of candidates to find init locations (internal default is 1000), or device.

None
init_strategy_num_candidates Optional[int]

Number of candidates to find init locations in init_strategy=sir (deprecated, use init_strategy_parameters instead).

None
num_workers int

number of cpu cores used to parallelize mcmc

1
mp_context Literal['fork', 'spawn']

Multiprocessing start method, either "fork" or "spawn" (default), used by Pyro and PyMC samplers. "fork" can be significantly faster than "spawn" but is only supported on POSIX-based systems (e.g. Linux and macOS, not Windows).

'spawn'
device Optional[Union[str, device]]

Training device, e.g., “cpu”, “cuda” or “cuda:0”. If None, potential_fn.device is used.

None
x_shape Optional[Size]

Deprecated, should not be passed.

None
Source code in sbi/inference/posteriors/mcmc_posterior.py
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
def __init__(
    self,
    potential_fn: Union[Callable, BasePotential],
    proposal: Any,
    theta_transform: Optional[TorchTransform] = None,
    method: Literal[
        "slice_np",
        "slice_np_vectorized",
        "hmc_pyro",
        "nuts_pyro",
        "slice_pymc",
        "hmc_pymc",
        "nuts_pymc",
    ] = "slice_np_vectorized",
    thin: int = -1,
    warmup_steps: int = 200,
    num_chains: int = 20,
    init_strategy: Literal["proposal", "sir", "resample"] = "resample",
    init_strategy_parameters: Optional[Dict[str, Any]] = None,
    init_strategy_num_candidates: Optional[int] = None,
    num_workers: int = 1,
    mp_context: Literal["fork", "spawn"] = "spawn",
    device: Optional[Union[str, torch.device]] = None,
    x_shape: Optional[torch.Size] = None,
):
    """
    Args:
        potential_fn: The potential function from which to draw samples. Must be a
            `BasePotential` or a `Callable` which takes `theta` and `x_o` as inputs.
        proposal: Proposal distribution that is used to initialize the MCMC chain.
        theta_transform: Transformation that will be applied during sampling.
            Allows to perform MCMC in unconstrained space.
        method: Method used for MCMC sampling, one of `slice_np`,
            `slice_np_vectorized`, `hmc_pyro`, `nuts_pyro`, `slice_pymc`,
            `hmc_pymc`, `nuts_pymc`. `slice_np` is a custom
            numpy implementation of slice sampling. `slice_np_vectorized` is
            identical to `slice_np`, but if `num_chains>1`, the chains are
            vectorized for `slice_np_vectorized` whereas they are run sequentially
            for `slice_np`. The samplers ending on `_pyro` are using Pyro, and
            likewise the samplers ending on `_pymc` are using PyMC.
        thin: The thinning factor for the chain, default 1 (no thinning).
        warmup_steps: The initial number of samples to discard.
        num_chains: The number of chains. Should generally be at most
            `num_workers - 1`.
        init_strategy: The initialisation strategy for chains; `proposal` will draw
            init locations from `proposal`, whereas `sir` will use Sequential-
            Importance-Resampling (SIR). SIR initially samples
            `init_strategy_num_candidates` from the `proposal`, evaluates all of
            them under the `potential_fn` and `proposal`, and then resamples the
            initial locations with weights proportional to `exp(potential_fn -
            proposal.log_prob)`. `resample` is the same as `sir` but
            uses `exp(potential_fn)` as weights.
        init_strategy_parameters: Dictionary of keyword arguments passed to the
            init strategy, e.g., for `init_strategy=sir` this could be
            `num_candidate_samples`, i.e., the number of candidates to find init
            locations (internal default is `1000`), or `device`.
        init_strategy_num_candidates: Number of candidates to find init
             locations in `init_strategy=sir` (deprecated, use
             init_strategy_parameters instead).
        num_workers: number of cpu cores used to parallelize mcmc
        mp_context: Multiprocessing start method, either `"fork"` or `"spawn"`
            (default), used by Pyro and PyMC samplers. `"fork"` can be significantly
            faster than `"spawn"` but is only supported on POSIX-based systems
            (e.g. Linux and macOS, not Windows).
        device: Training device, e.g., "cpu", "cuda" or "cuda:0". If None,
            `potential_fn.device` is used.
        x_shape: Deprecated, should not be passed.
    """
    # Runtime check is broader than the `method` Literal on purpose: old user
    # code may still pass the deprecated "slice".
    if method == "slice":
        warn(
            "The Pyro-based slice sampler is deprecated, and the method `slice` "
            "has been changed to `slice_np`, i.e., the custom "
            "numpy-based slice sampler.",
            DeprecationWarning,
            stacklevel=2,
        )
        method = "slice_np"

    thin = _process_thin_default(thin)

    super().__init__(
        potential_fn,
        theta_transform=theta_transform,
        device=device,
        x_shape=x_shape,
    )

    self.proposal = proposal
    self.method = method
    self.thin = thin
    self.warmup_steps = warmup_steps
    self.num_chains = num_chains
    self.init_strategy = init_strategy
    self.init_strategy_parameters = init_strategy_parameters or {}
    self.num_workers = num_workers
    self.mp_context = mp_context
    self._posterior_sampler = None
    # Hardcode parameter name to reduce clutter kwargs.
    self.param_name = "theta"
    self.x_shape = x_shape

    if init_strategy_num_candidates is not None:
        # NOTE: This must be a plain (non-f) string. With an `f`-prefix, the
        # braces in `{'num_candidate_samples': 1000}` are parsed as a
        # replacement field with a bogus format spec instead of being shown
        # literally in the warning message.
        warn(
            "Passing `init_strategy_num_candidates` is deprecated as of sbi "
            "v0.19.0. Instead, use e.g., `init_strategy_parameters "
            "={'num_candidate_samples': 1000}`",
            stacklevel=2,
        )
        self.init_strategy_parameters["num_candidate_samples"] = (
            init_strategy_num_candidates
        )

    self.potential_ = self._prepare_potential(method)

    self._purpose = (
        "It provides MCMC to .sample() from the posterior and "
        "can evaluate the _unnormalized_ posterior density with .log_prob()."
    )

log_prob(theta, x=None, track_gradients=False)

Returns the log-probability of theta under the posterior.

Parameters:

Name Type Description Default
theta Tensor

Parameters \(\theta\).

required
track_gradients bool

Whether the returned tensor supports tracking gradients. This can be helpful for e.g. sensitivity analysis, but increases memory consumption.

False

Returns:

Type Description
Tensor

len($\theta$)-shaped log-probability.

Source code in sbi/inference/posteriors/mcmc_posterior.py
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
def log_prob(
    self, theta: Tensor, x: Optional[Tensor] = None, track_gradients: bool = False
) -> Tensor:
    r"""Returns the log-probability of theta under the posterior.

    Note: for MCMC-based posteriors this is only the *unnormalized*
    log-probability (the potential), hence the deprecation warnings below.

    Args:
        theta: Parameters $\theta$.
        x: Conditioning observation. If `None`, falls back to the default `x`.
        track_gradients: Whether the returned tensor supports tracking gradients.
            This can be helpful for e.g. sensitivity analysis, but increases memory
            consumption.

    Returns:
        `len($\theta$)`-shaped log-probability.
    """
    warn(
        "`.log_prob()` is deprecated for methods that can only evaluate the "
        "log-probability up to a normalizing constant. Use `.potential()` instead.",
        stacklevel=2,
    )
    warn("The log-probability is unnormalized!", stacklevel=2)

    # Condition the potential function on the given (or default) observation.
    self.potential_fn.set_x(self._x_else_default_x(x))

    # Guarantee a leading batch dimension and evaluate on the posterior's device.
    batched_theta = ensure_theta_batched(torch.as_tensor(theta)).to(self._device)
    return self.potential_fn(batched_theta, track_gradients=track_gradients)

map(x=None, num_iter=1000, num_to_optimize=100, learning_rate=0.01, init_method='proposal', num_init_samples=1000, save_best_every=10, show_progress_bars=False, force_update=False)

Returns the maximum-a-posteriori estimate (MAP).

The method can be interrupted (Ctrl-C) when the user sees that the log-probability converges. The best estimate will be saved in self._map and can be accessed with self.map(). The MAP is obtained by running gradient ascent from a given number of starting positions (samples from the posterior with the highest log-probability). After the optimization is done, we select the parameter set that has the highest log-probability after the optimization.

Warning: The default values used by this function are not well-tested. They might require hand-tuning for the problem at hand.

For developers: if the prior is a BoxUniform, we carry out the optimization in unbounded space and transform the result back into bounded space.

Parameters:

Name Type Description Default
x Optional[Tensor]

Deprecated - use .set_default_x() prior to .map().

None
num_iter int

Number of optimization steps that the algorithm takes to find the MAP.

1000
learning_rate float

Learning rate of the optimizer.

0.01
init_method Union[str, Tensor]

How to select the starting parameters for the optimization. If it is a string, it can be either [posterior, prior], which samples the respective distribution num_init_samples times. If it is a tensor, the tensor will be used as init locations.

'proposal'
num_init_samples int

Draw this number of samples from the posterior and evaluate the log-probability of all of them.

1000
num_to_optimize int

From the drawn num_init_samples, use the num_to_optimize with highest log-probability as the initial points for the optimization.

100
save_best_every int

The best log-probability is computed, saved in the map-attribute, and printed every save_best_every-th iteration. Computing the best log-probability creates a significant overhead (thus, the default is 10.)

10
show_progress_bars bool

Whether to show a progressbar during sampling from the posterior.

False
force_update bool

Whether to re-calculate the MAP when x is unchanged and have a cached value.

False
log_prob_kwargs

Will be empty for SNLE and SNRE. Will contain {'norm_posterior': True} for SNPE.

required

Returns:

Type Description
Tensor

The MAP estimate.

Source code in sbi/inference/posteriors/mcmc_posterior.py
 942
 943
 944
 945
 946
 947
 948
 949
 950
 951
 952
 953
 954
 955
 956
 957
 958
 959
 960
 961
 962
 963
 964
 965
 966
 967
 968
 969
 970
 971
 972
 973
 974
 975
 976
 977
 978
 979
 980
 981
 982
 983
 984
 985
 986
 987
 988
 989
 990
 991
 992
 993
 994
 995
 996
 997
 998
 999
1000
1001
1002
1003
1004
1005
1006
1007
def map(
    self,
    x: Optional[Tensor] = None,
    num_iter: int = 1_000,
    num_to_optimize: int = 100,
    learning_rate: float = 0.01,
    init_method: Union[str, Tensor] = "proposal",
    num_init_samples: int = 1_000,
    save_best_every: int = 10,
    show_progress_bars: bool = False,
    force_update: bool = False,
) -> Tensor:
    r"""Returns the maximum-a-posteriori estimate (MAP).

    The method can be interrupted (Ctrl-C) when the user sees that the
    log-probability converges. The best estimate will be saved in `self._map` and
    can be accessed with `self.map()`. The MAP is obtained by running gradient
    ascent from a given number of starting positions (samples from the posterior
    with the highest log-probability). After the optimization is done, we select the
    parameter set that has the highest log-probability after the optimization.

    Warning: The default values used by this function are not well-tested. They
    might require hand-tuning for the problem at hand.

    For developers: if the prior is a `BoxUniform`, we carry out the optimization
    in unbounded space and transform the result back into bounded space.

    Args:
        x: Deprecated - use `.set_default_x()` prior to `.map()`.
        num_iter: Number of optimization steps that the algorithm takes
            to find the MAP.
        learning_rate: Learning rate of the optimizer.
        init_method: How to select the starting parameters for the optimization. If
            it is a string, it can be either [`posterior`, `prior`], which samples
            the respective distribution `num_init_samples` times. If it is a
            tensor, the tensor will be used as init locations.
        num_init_samples: Draw this number of samples from the posterior and
            evaluate the log-probability of all of them.
        num_to_optimize: From the drawn `num_init_samples`, use the
            `num_to_optimize` with highest log-probability as the initial points
            for the optimization.
        save_best_every: The best log-probability is computed, saved in the
            `map`-attribute, and printed every `save_best_every`-th iteration.
            Computing the best log-probability creates a significant overhead
            (thus, the default is `10`.)
        show_progress_bars: Whether to show a progressbar during sampling from
            the posterior.
        force_update: Whether to re-calculate the MAP when x is unchanged and
            have a cached value.

    Returns:
        The MAP estimate.
    """
    # Pure delegation: the optimization itself lives in the base class; this
    # override only exists to document the MCMC-specific defaults.
    return super().map(
        x=x,
        num_iter=num_iter,
        num_to_optimize=num_to_optimize,
        learning_rate=learning_rate,
        init_method=init_method,
        num_init_samples=num_init_samples,
        save_best_every=save_best_every,
        show_progress_bars=show_progress_bars,
        force_update=force_update,
    )

sample(sample_shape=torch.Size(), x=None, method=None, thin=None, warmup_steps=None, num_chains=None, init_strategy=None, init_strategy_parameters=None, num_workers=None, mp_context=None, show_progress_bars=True)

Draw samples from the approximate posterior distribution \(p(\theta|x)\).

Parameters:

Name Type Description Default
sample_shape Shape

Desired shape of samples that are drawn from posterior. If sample_shape is multidimensional we simply draw sample_shape.numel() samples and then reshape into the desired shape.

Size()
x Optional[Tensor]

Conditioning observation \(x_o\). If not provided, uses the default x set via .set_default_x().

None
method Optional[str]

MCMC method to use. One of slice_np, slice_np_vectorized, hmc_pyro, nuts_pyro, slice_pymc, hmc_pymc, nuts_pymc. If not provided, uses the method specified at initialization.

None
thin Optional[int]

Thinning factor for the chain. If not provided, uses the value specified at initialization.

None
warmup_steps Optional[int]

Number of warmup steps to discard. If not provided, uses the value specified at initialization.

None
num_chains Optional[int]

Number of MCMC chains to run. If not provided, uses the value specified at initialization.

None
init_strategy Optional[str]

Initialization strategy for chains (proposal, sir, or resample). If not provided, uses the value specified at initialization.

None
init_strategy_parameters Optional[Dict[str, Any]]

Parameters for the initialization strategy. If not provided, uses the value specified at initialization.

None
num_workers Optional[int]

Number of CPU cores for parallelization. If not provided, uses the value specified at initialization.

None
mp_context Optional[str]

Multiprocessing context (fork or spawn). If not provided, uses the value specified at initialization.

None
show_progress_bars bool

Whether to show sampling progress monitor.

True

Returns:

Type Description
Tensor

Samples from posterior.

Source code in sbi/inference/posteriors/mcmc_posterior.py
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
def sample(
    self,
    sample_shape: Shape = torch.Size(),
    x: Optional[Tensor] = None,
    method: Optional[str] = None,
    thin: Optional[int] = None,
    warmup_steps: Optional[int] = None,
    num_chains: Optional[int] = None,
    init_strategy: Optional[str] = None,
    init_strategy_parameters: Optional[Dict[str, Any]] = None,
    num_workers: Optional[int] = None,
    mp_context: Optional[str] = None,
    show_progress_bars: bool = True,
) -> Tensor:
    r"""Draw samples from the approximate posterior distribution $p(\theta|x)$.

    Args:
        sample_shape: Desired shape of samples that are drawn from posterior. If
            sample_shape is multidimensional we simply draw `sample_shape.numel()`
            samples and then reshape into the desired shape.
        x: Conditioning observation $x_o$. If not provided, uses the default `x`
            set via `.set_default_x()`.
        method: MCMC method to use. One of `slice_np`, `slice_np_vectorized`,
            `hmc_pyro`, `nuts_pyro`, `slice_pymc`, `hmc_pymc`, `nuts_pymc`.
            If not provided, uses the method specified at initialization.
        thin: Thinning factor for the chain. If not provided, uses the value
            specified at initialization.
        warmup_steps: Number of warmup steps to discard. If not provided, uses
            the value specified at initialization.
        num_chains: Number of MCMC chains to run. If not provided, uses the
            value specified at initialization.
        init_strategy: Initialization strategy for chains (`proposal`, `sir`,
            or `resample`). If not provided, uses the value specified at
            initialization.
        init_strategy_parameters: Parameters for the initialization strategy.
            If not provided, uses the value specified at initialization.
        num_workers: Number of CPU cores for parallelization. If not provided,
            uses the value specified at initialization.
        mp_context: Multiprocessing context (`fork` or `spawn`). If not provided,
            uses the value specified at initialization.
        show_progress_bars: Whether to show sampling progress monitor.

    Returns:
        Samples from posterior.
    """

    # Condition the potential function on the given observation (or on the
    # default x set via `.set_default_x()`).
    self.potential_fn.set_x(self._x_else_default_x(x))

    # Replace arguments that were not passed with their default.
    method = self.method if method is None else method
    thin = self.thin if thin is None else thin
    warmup_steps = self.warmup_steps if warmup_steps is None else warmup_steps
    num_chains = self.num_chains if num_chains is None else num_chains
    init_strategy = self.init_strategy if init_strategy is None else init_strategy
    num_workers = self.num_workers if num_workers is None else num_workers
    mp_context = self.mp_context if mp_context is None else mp_context
    init_strategy_parameters = (
        self.init_strategy_parameters
        if init_strategy_parameters is None
        else init_strategy_parameters
    )
    # Rebuild the prepared potential for the (possibly overridden) method.
    self.potential_ = self._prepare_potential(method)  # type: ignore

    initial_params = self._get_initial_params(
        init_strategy,  # type: ignore
        num_chains,  # type: ignore
        num_workers,
        show_progress_bars,
        **init_strategy_parameters,
    )
    num_samples = torch.Size(sample_shape).numel()

    # Only the gradient-based samplers (HMC/NUTS) need autograd on the
    # potential; the slice samplers run with gradients disabled.
    track_gradients = method in ("hmc_pyro", "nuts_pyro", "hmc_pymc", "nuts_pymc")
    with torch.set_grad_enabled(track_gradients):
        # Dispatch to the backend that implements the requested method.
        if method in ("slice_np", "slice_np_vectorized"):
            transformed_samples = self._slice_np_mcmc(
                num_samples=num_samples,
                potential_function=self.potential_,
                initial_params=initial_params,
                thin=thin,  # type: ignore
                warmup_steps=warmup_steps,  # type: ignore
                vectorized=(method == "slice_np_vectorized"),
                interchangeable_chains=True,
                num_workers=num_workers,
                show_progress_bars=show_progress_bars,
            )
        elif method in ("hmc_pyro", "nuts_pyro"):
            transformed_samples = self._pyro_mcmc(
                num_samples=num_samples,
                potential_function=self.potential_,
                initial_params=initial_params,
                mcmc_method=method,  # type: ignore
                thin=thin,  # type: ignore
                warmup_steps=warmup_steps,  # type: ignore
                num_chains=num_chains,
                show_progress_bars=show_progress_bars,
                mp_context=mp_context,
            )
        elif method in ("hmc_pymc", "nuts_pymc", "slice_pymc"):
            transformed_samples = self._pymc_mcmc(
                num_samples=num_samples,
                potential_function=self.potential_,
                initial_params=initial_params,
                mcmc_method=method,  # type: ignore
                thin=thin,  # type: ignore
                warmup_steps=warmup_steps,  # type: ignore
                num_chains=num_chains,
                show_progress_bars=show_progress_bars,
                mp_context=mp_context,
            )
        else:
            raise NameError(f"The sampling method {method} is not implemented!")

    # Undo the parameter transform that was applied for MCMC sampling.
    samples = self.theta_transform.inv(transformed_samples)
    # NOTE: Currently MCMCPosteriors will require a single dimension for the
    # parameter dimension. With recent ConditionalDensity(Ratio) estimators, we
    # can have multiple dimensions for the parameter dimension.
    samples = samples.reshape((*sample_shape, -1))  # type: ignore

    return samples

sample_batched(sample_shape, x, method=None, thin=None, warmup_steps=None, num_chains=None, init_strategy=None, init_strategy_parameters=None, num_workers=None, mp_context=None, show_progress_bars=True)

Draw samples from the posteriors for a batch of different xs.

Given a batch of observations [x_1, ..., x_B], this method samples from posteriors \(p(\theta|x_1), \ldots, p(\theta|x_B)\) in a vectorized manner.

Check the __init__() method for a description of all arguments as well as their default values.

Parameters:

Name Type Description Default
sample_shape Shape

Desired shape of samples that are drawn from the posterior given every observation.

required
x Tensor

A batch of observations, of shape (batch_dim, event_shape_x). batch_dim corresponds to the number of observations to be drawn.

required
method Optional[str]

Method used for MCMC sampling, e.g., "slice_np_vectorized".

None
thin Optional[int]

The thinning factor for the chain, default 1 (no thinning).

None
warmup_steps Optional[int]

The initial number of samples to discard.

None
num_chains Optional[int]

The number of chains used for each x passed in the batch.

None
init_strategy Optional[str]

The initialisation strategy for chains.

None
init_strategy_parameters Optional[Dict[str, Any]]

Dictionary of keyword arguments passed to the init strategy.

None
num_workers Optional[int]

number of cpu cores used to parallelize initial parameter generation and mcmc sampling.

None
mp_context Optional[str]

Multiprocessing start method, either "fork" or "spawn"

None
show_progress_bars bool

Whether to show sampling progress monitor.

True

Returns:

Type Description
Tensor

Samples from the posteriors of shape (*sample_shape, B, *input_shape)

Source code in sbi/inference/posteriors/mcmc_posterior.py
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
def sample_batched(
    self,
    sample_shape: Shape,
    x: Tensor,
    method: Optional[str] = None,
    thin: Optional[int] = None,
    warmup_steps: Optional[int] = None,
    num_chains: Optional[int] = None,
    init_strategy: Optional[str] = None,
    init_strategy_parameters: Optional[Dict[str, Any]] = None,
    num_workers: Optional[int] = None,
    mp_context: Optional[str] = None,
    show_progress_bars: bool = True,
) -> Tensor:
    r"""Draw samples from the posteriors for a batch of different xs.

    Given a batch of observations `[x_1, ..., x_B]`, this method samples from
    posteriors $p(\theta|x_1), \ldots, p(\theta|x_B)$ in a vectorized manner.

    Check the `__init__()` method for a description of all arguments as well as
    their default values.

    Args:
        sample_shape: Desired shape of samples that are drawn from the posterior
            given every observation.
        x: A batch of observations, of shape `(batch_dim, event_shape_x)`.
            `batch_dim` corresponds to the number of observations to be
            drawn.
        method: Method used for MCMC sampling, e.g., "slice_np_vectorized".
        thin: The thinning factor for the chain, default 1 (no thinning).
        warmup_steps: The initial number of samples to discard.
        num_chains: The number of chains used for each `x` passed in the batch.
        init_strategy: The initialisation strategy for chains.
        init_strategy_parameters: Dictionary of keyword arguments passed to
            the init strategy.
        num_workers: number of cpu cores used to parallelize initial
            parameter generation and mcmc sampling.
        mp_context: Multiprocessing start method, either `"fork"` or `"spawn"`
        show_progress_bars: Whether to show sampling progress monitor.

    Returns:
        Samples from the posteriors of shape (*sample_shape, B, *input_shape)

    Raises:
        ValueError: If `method` is not `"slice_np_vectorized"`, the only
            sampler that supports batched sampling.
    """

    # Replace arguments that were not passed with their default.
    method = self.method if method is None else method
    thin = self.thin if thin is None else thin
    warmup_steps = self.warmup_steps if warmup_steps is None else warmup_steps
    num_chains = self.num_chains if num_chains is None else num_chains
    init_strategy = self.init_strategy if init_strategy is None else init_strategy
    num_workers = self.num_workers if num_workers is None else num_workers
    mp_context = self.mp_context if mp_context is None else mp_context
    # Shallow-copy so that injecting `num_return_samples` below does not
    # mutate `self.init_strategy_parameters` (or the caller's dict), which
    # would leak into subsequent `.sample()` / `.sample_batched()` calls.
    init_strategy_parameters = dict(
        self.init_strategy_parameters
        if init_strategy_parameters is None
        else init_strategy_parameters
    )

    # Validate explicitly rather than via `assert`, which is stripped when
    # Python runs with `-O`.
    if method != "slice_np_vectorized":
        raise ValueError("Batched sampling only supported for vectorized samplers!")

    # warn if num_chains is larger than num requested samples
    if num_chains > torch.Size(sample_shape).numel():
        warnings.warn(
            "The passed number of MCMC chains is larger than the number of "
            f"requested samples: {num_chains} > {torch.Size(sample_shape).numel()},"
            f" resetting it to {torch.Size(sample_shape).numel()}.",
            stacklevel=2,
        )
        num_chains = torch.Size(sample_shape).numel()

    # custom shape handling to make sure to match the batch size of x and theta
    # without unnecessary combinations.
    if len(x.shape) == 1:
        x = x.unsqueeze(0)
    batch_size = x.shape[0]

    x = reshape_to_batch_event(x, event_shape=x.shape[1:])

    # For batched sampling, we want `num_chains` for each observation in the batch.
    # Here we repeat the observations ABC -> AAABBBCCC, so that the chains are
    # in the order of the observations.
    x_ = x.repeat_interleave(num_chains, dim=0)

    self.potential_fn.set_x(x_, x_is_iid=False)
    self.potential_ = self._prepare_potential(method)  # type: ignore

    # For each observation in the batch, we have num_chains independent chains.
    num_chains_extended = batch_size * num_chains
    if num_chains_extended > 100:
        warnings.warn(
            "Note that for batched sampling, we use num_chains many chains "
            "for each x in the batch. With the given settings, this results "
            f"in a large number of chains ({num_chains_extended}), which can "
            "be slow and memory-intensive for vectorized MCMC. Consider "
            "reducing the number of chains or batch size.",
            stacklevel=2,
        )
    init_strategy_parameters["num_return_samples"] = num_chains_extended
    initial_params = self._get_initial_params_batched(
        x,
        init_strategy,  # type: ignore
        num_chains,  # type: ignore
        num_workers,
        show_progress_bars,
        **init_strategy_parameters,
    )
    # We need num_samples from each posterior in the batch
    num_samples = torch.Size(sample_shape).numel() * batch_size

    with torch.set_grad_enabled(False):
        transformed_samples = self._slice_np_mcmc(
            num_samples=num_samples,
            potential_function=self.potential_,
            initial_params=initial_params,
            thin=thin,  # type: ignore
            warmup_steps=warmup_steps,  # type: ignore
            vectorized=(method == "slice_np_vectorized"),
            interchangeable_chains=False,
            num_workers=num_workers,
            show_progress_bars=show_progress_bars,
        )

    # (num_chains_extended, samples_per_chain, *input_shape)
    samples_per_chain: Tensor = self.theta_transform.inv(transformed_samples)  # type: ignore
    dim_theta = samples_per_chain.shape[-1]
    # We need to collect samples for each x from the respective chains.
    # However, using samples.reshape(*sample_shape, batch_size, dim_theta)
    # does not combine the samples in the right order, since this mixes
    # samples that belong to different `x`. The following permute is a
    # workaround to reshape the samples in the right order.
    samples_per_x = samples_per_chain.reshape((
        batch_size,
        # We are flattening the sample shape here using -1 because we might have
        # generated more samples than requested (more chains, or multiple of
        # chains not matching sample_shape)
        -1,
        dim_theta,
    )).permute(1, 0, -1)

    # Shape is now (-1, batch_size, dim_theta)
    # We can now select the number of requested samples
    samples = samples_per_x[: torch.Size(sample_shape).numel()]
    # and reshape into (*sample_shape, batch_size, dim_theta)
    samples = samples.reshape((*sample_shape, batch_size, dim_theta))
    return samples

set_mcmc_method(method)

Sets the sampling method for MCMC and returns NeuralPosterior.

Parameters:

Name Type Description Default
method str

Method to use.

required

Returns:

Type Description
NeuralPosterior

NeuralPosterior for chainable calls.

Source code in sbi/inference/posteriors/mcmc_posterior.py
208
209
210
211
212
213
214
215
216
217
218
def set_mcmc_method(self, method: str) -> "NeuralPosterior":
    """Sets the sampling method for MCMC and returns `NeuralPosterior`.

    Args:
        method: Method to use.

    Returns:
        `NeuralPosterior` for chainable calls.
    """
    # NOTE(review): this writes `self._mcmc_method`, but `.sample()` reads
    # `self.method` (set in `__init__`) — confirm `_mcmc_method` is still
    # consumed anywhere, otherwise this setter has no effect on sampling.
    self._mcmc_method = method
    return self

to(device)

Moves potential_fn, proposal, x_o and theta_transform to the

specified device. Reinstantiates the posterior and resets the default x_o.

Parameters:

Name Type Description Default
device Union[str, device]

Device to move the posterior to.

required
Source code in sbi/inference/posteriors/mcmc_posterior.py
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
def to(self, device: Union[str, torch.device]) -> None:
    """Moves potential_fn, proposal, x_o and theta_transform to the

    specified device. Reinstantiates the posterior and resets the default x_o.

    Args:
        device: Device to move the posterior to.
    """
    self.device = device
    self.potential_fn.to(device)  # type: ignore
    self.proposal.to(device)
    # Stash the default observation (if set): `super().__init__` below
    # resets `self._x`, so it must be restored afterwards.
    x_o = None
    if hasattr(self, "_x") and (self._x is not None):
        x_o = self._x.to(device)

    # Rebuild the parameter transform so its tensors live on the new device.
    self.theta_transform = mcmc_transform(self.proposal, device=device)

    super().__init__(
        self.potential_fn,
        theta_transform=self.theta_transform,
        device=device,
        x_shape=self.x_shape,
    )
    # super().__init__ erases the self._x, so we need to set it again
    if x_o is not None:
        self.set_default_x(x_o)
    # Re-create the prepared potential for the current sampling method.
    self.potential_ = self._prepare_potential(self.method)

RejectionPosterior

Bases: NeuralPosterior

Provides rejection sampling to sample from the posterior.

SNLE or SNRE train neural networks to approximate the likelihood(-ratios). RejectionPosterior allows to sample from the posterior with rejection sampling.

Source code in sbi/inference/posteriors/rejection_posterior.py
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
class RejectionPosterior(NeuralPosterior):
    r"""Provides rejection sampling to sample from the posterior.

    SNLE or SNRE train neural networks to approximate the likelihood(-ratios).
    `RejectionPosterior` allows sampling from the posterior with rejection
    sampling.
    """

    def __init__(
        self,
        potential_fn: Union[BasePotential, CustomPotential],
        proposal: Any,
        theta_transform: Optional[TorchTransform] = None,
        max_sampling_batch_size: int = 10_000,
        num_samples_to_find_max: int = 10_000,
        num_iter_to_find_max: int = 100,
        m: float = 1.2,
        device: Optional[Union[str, torch.device]] = None,
        x_shape: Optional[torch.Size] = None,
    ):
        """
        Args:
            potential_fn: The potential function from which to draw samples. Must be a
                `BasePotential` or a `CustomPotential`.
            proposal: The proposal distribution.
            theta_transform: Transformation that is applied to parameters. Is not used
                during sampling but only when calling `.map()`.
            max_sampling_batch_size: The batch size of samples being drawn from
                the proposal at every iteration.
            num_samples_to_find_max: The number of samples that are used to find the
                maximum of the `potential_fn / proposal` ratio.
            num_iter_to_find_max: The number of gradient ascent iterations to find the
                maximum of the `potential_fn / proposal` ratio.
            m: Multiplier to the `potential_fn / proposal` ratio.
            device: Training device, e.g., "cpu", "cuda" or "cuda:0". If None,
                `potential_fn.device` is used.
            x_shape: Deprecated, should not be passed.
        """
        super().__init__(
            potential_fn,
            theta_transform=theta_transform,
            device=device,
            x_shape=x_shape,
        )

        # Rejection-sampling configuration; each can be overridden per-call
        # in `.sample()`.
        self.proposal = proposal
        self.max_sampling_batch_size = max_sampling_batch_size
        self.num_samples_to_find_max = num_samples_to_find_max
        self.num_iter_to_find_max = num_iter_to_find_max
        self.m = m
        self.x_shape = x_shape

        self._purpose = (
            "It provides rejection sampling to .sample() from the posterior and "
            "can evaluate the _unnormalized_ posterior density with .log_prob()."
        )

    def to(self, device: Union[str, torch.device]) -> None:
        """
        Move potential function, proposal and x_o to the device.

        This method reinstantiates the posterior and resets the default x_o.

        Args:
            device: The device to move the posterior to.
        """
        self.device = device
        self.potential_fn.to(device)  # type: ignore
        self.proposal.to(device)
        x_o = None
        # Capture the default observation (if any) before re-initialization
        # below, which would otherwise lose it.
        if hasattr(self, "_x") and (self._x is not None):
            x_o = self._x.to(device)

        # Rebuild the parameter transform on the target device.
        self.theta_transform = mcmc_transform(self.proposal, device=device)
        super().__init__(
            self.potential_fn,
            theta_transform=self.theta_transform,
            device=device,
            x_shape=self.x_shape,
        )
        # super().__init__ erases the self._x, so we need to set it again
        if x_o is not None:
            self.set_default_x(x_o)

    def log_prob(
        self, theta: Tensor, x: Optional[Tensor] = None, track_gradients: bool = False
    ) -> Tensor:
        r"""Returns the log-probability of theta under the posterior.

        Args:
            theta: Parameters $\theta$.
            x: Conditioning observation $x_o$. If not provided, uses the default
                `x` set via `.set_default_x()`.
            track_gradients: Whether the returned tensor supports tracking gradients.
                This can be helpful for e.g. sensitivity analysis, but increases memory
                consumption.

        Returns:
            `len($\theta$)`-shaped log-probability.
        """
        warn(
            "`.log_prob()` is deprecated for methods that can only evaluate the "
            "log-probability up to a normalizing constant. Use `.potential()` instead.",
            stacklevel=2,
        )
        warn("The log-probability is unnormalized!", stacklevel=2)

        self.potential_fn.set_x(self._x_else_default_x(x))

        theta = ensure_theta_batched(torch.as_tensor(theta))
        return self.potential_fn(
            theta.to(self._device), track_gradients=track_gradients
        )

    def sample(
        self,
        sample_shape: Shape = torch.Size(),
        x: Optional[Tensor] = None,
        max_sampling_batch_size: Optional[int] = None,
        num_samples_to_find_max: Optional[int] = None,
        num_iter_to_find_max: Optional[int] = None,
        m: Optional[float] = None,
        show_progress_bars: bool = True,
        reject_outside_prior: bool = True,
        max_sampling_time: Optional[float] = None,
        return_partial_on_timeout: bool = False,
    ) -> Tensor:
        r"""Draw samples from the approximate posterior via rejection sampling.

        Args:
            sample_shape: Desired shape of samples that are drawn from posterior. If
                sample_shape is multidimensional we simply draw `sample_shape.numel()`
                samples and then reshape into the desired shape.
            x: Conditioning observation $x_o$. If not provided, uses the default `x`
                set via `.set_default_x()`.
            max_sampling_batch_size: Maximum batch size for rejection sampling.
                If not provided, uses the value specified at initialization.
            num_samples_to_find_max: Number of samples to find the maximum of the
                potential function. If not provided, uses the value from initialization.
            num_iter_to_find_max: Number of optimization iterations to find the
                maximum. If not provided, uses the value from initialization.
            m: Multiplier for the proposal distribution. If not provided, uses the
                value from initialization.
            show_progress_bars: Whether to show sampling progress monitor.
            reject_outside_prior: If True (default), rejection sampling is used to
                ensure samples lie within the prior support. If False, samples are drawn
                directly from the proposal without rejection, which is faster but may
                include samples outside the prior support.
            max_sampling_time: Optional maximum allowed sampling time in seconds.
                If exceeded, sampling is aborted and a RuntimeError is raised. Only
                applies when `reject_outside_prior=True` (no effect otherwise since
                direct sampling from the proposal is fast).
            return_partial_on_timeout: If True and `max_sampling_time` is exceeded,
                return the samples collected so far instead of raising a RuntimeError.
                A warning will be issued. Only applies when `reject_outside_prior=True`
                (default).

        Returns:
            Samples from posterior.
        """
        num_samples = torch.Size(sample_shape).numel()
        self.potential_fn.set_x(self._x_else_default_x(x))

        potential = partial(self.potential_fn, track_gradients=True)

        # Replace arguments that were not passed with their default.
        max_sampling_batch_size = (
            self.max_sampling_batch_size
            if max_sampling_batch_size is None
            else max_sampling_batch_size
        )
        num_samples_to_find_max = (
            self.num_samples_to_find_max
            if num_samples_to_find_max is None
            else num_samples_to_find_max
        )
        num_iter_to_find_max = (
            self.num_iter_to_find_max
            if num_iter_to_find_max is None
            else num_iter_to_find_max
        )
        m = self.m if m is None else m

        if reject_outside_prior:
            samples, _ = rejection_sample(
                potential,
                proposal=self.proposal,
                num_samples=num_samples,
                show_progress_bars=show_progress_bars,
                warn_acceptance=0.01,
                max_sampling_batch_size=max_sampling_batch_size,
                num_samples_to_find_max=num_samples_to_find_max,
                num_iter_to_find_max=num_iter_to_find_max,
                m=m,
                max_sampling_time=max_sampling_time,
                return_partial_on_timeout=return_partial_on_timeout,
                device=self._device,
            )
        else:
            # Bypass rejection sampling entirely.
            samples = self.proposal.sample((num_samples,))
            warn(
                "Samples drawn with reject_outside_prior=False are taken directly "
                "from the proposal without rejection sampling. These samples may lie "
                "outside the prior support, which could lead to incorrect inference.",
                stacklevel=2,
            )

        return samples.reshape((*sample_shape, -1))

    def sample_batched(
        self,
        sample_shape: Shape,
        x: Tensor,
        max_sampling_batch_size: int = 10000,
        show_progress_bars: bool = True,
    ) -> Tensor:
        """Not supported for `RejectionPosterior`; always raises."""
        raise NotImplementedError(
            "Batched sampling is not implemented for RejectionPosterior. \
            Alternatively you can use `sample` in a loop \
            [posterior.sample(theta, x_o) for x_o in x]."
        )

    def map(
        self,
        x: Optional[Tensor] = None,
        num_iter: int = 1_000,
        num_to_optimize: int = 100,
        learning_rate: float = 0.01,
        init_method: Union[str, Tensor] = "proposal",
        num_init_samples: int = 1_000,
        save_best_every: int = 10,
        show_progress_bars: bool = False,
        force_update: bool = False,
    ) -> Tensor:
        r"""Returns the maximum-a-posteriori estimate (MAP).

        The method can be interrupted (Ctrl-C) when the user sees that the
        log-probability converges. The best estimate will be saved in `self._map` and
        can be accessed with `self.map()`. The MAP is obtained by running gradient
        ascent from a given number of starting positions (samples from the posterior
        with the highest log-probability). After the optimization is done, we select the
        parameter set that has the highest log-probability after the optimization.

        Warning: The default values used by this function are not well-tested. They
        might require hand-tuning for the problem at hand.

        For developers: if the prior is a `BoxUniform`, we carry out the optimization
        in unbounded space and transform the result back into bounded space.

        Args:
            x: Deprecated - use `.set_default_x()` prior to `.map()`.
            num_iter: Number of optimization steps that the algorithm takes
                to find the MAP.
            learning_rate: Learning rate of the optimizer.
            init_method: How to select the starting parameters for the optimization. If
                it is a string, it can be either [`posterior`, `prior`], which samples
                the respective distribution `num_init_samples` times. If it is a
                tensor, the tensor will be used as init locations.
            num_init_samples: Draw this number of samples from the posterior and
                evaluate the log-probability of all of them.
            num_to_optimize: From the drawn `num_init_samples`, use the
                `num_to_optimize` with highest log-probability as the initial points
                for the optimization.
            save_best_every: The best log-probability is computed, saved in the
                `map`-attribute, and printed every `save_best_every`-th iteration.
                Computing the best log-probability creates a significant overhead
                (thus, the default is `10`.)
            show_progress_bars: Whether to show a progressbar during sampling from
                the posterior.
            force_update: Whether to re-calculate the MAP when x is unchanged and
                have a cached value.

        Returns:
            The MAP estimate.
        """
        return super().map(
            x=x,
            num_iter=num_iter,
            num_to_optimize=num_to_optimize,
            learning_rate=learning_rate,
            init_method=init_method,
            num_init_samples=num_init_samples,
            save_best_every=save_best_every,
            show_progress_bars=show_progress_bars,
            force_update=force_update,
        )

__init__(potential_fn, proposal, theta_transform=None, max_sampling_batch_size=10000, num_samples_to_find_max=10000, num_iter_to_find_max=100, m=1.2, device=None, x_shape=None)

Parameters:

Name Type Description Default
potential_fn Union[BasePotential, CustomPotential]

The potential function from which to draw samples. Must be a BasePotential or a CustomPotential.

required
proposal Any

The proposal distribution.

required
theta_transform Optional[TorchTransform]

Transformation that is applied to parameters. It is not used during sampling but only when calling .map().

None
max_sampling_batch_size int

The batchsize of samples being drawn from the proposal at every iteration.

10000
num_samples_to_find_max int

The number of samples that are used to find the maximum of the potential_fn / proposal ratio.

10000
num_iter_to_find_max int

The number of gradient ascent iterations to find the maximum of the potential_fn / proposal ratio.

100
m float

Multiplier to the potential_fn / proposal ratio.

1.2
device Optional[Union[str, device]]

Training device, e.g., “cpu”, “cuda” or “cuda:0”. If None, potential_fn.device is used.

None
x_shape Optional[Size]

Deprecated, should not be passed.

None
Source code in sbi/inference/posteriors/rejection_posterior.py
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
def __init__(
    self,
    potential_fn: Union[BasePotential, CustomPotential],
    proposal: Any,
    theta_transform: Optional[TorchTransform] = None,
    max_sampling_batch_size: int = 10_000,
    num_samples_to_find_max: int = 10_000,
    num_iter_to_find_max: int = 100,
    m: float = 1.2,
    device: Optional[Union[str, torch.device]] = None,
    x_shape: Optional[torch.Size] = None,
):
    """Initialize a rejection-sampling posterior.

    Args:
        potential_fn: Potential function to draw samples from; must be a
            `BasePotential` or a `CustomPotential`.
        proposal: Proposal distribution used for rejection sampling.
        theta_transform: Parameter transformation; only used by `.map()`,
            not during sampling.
        max_sampling_batch_size: Batch size drawn from the proposal at
            each rejection-sampling iteration.
        num_samples_to_find_max: Number of samples used to locate the
            maximum of the `potential_fn / proposal` ratio.
        num_iter_to_find_max: Number of gradient ascent steps used to
            locate the maximum of the `potential_fn / proposal` ratio.
        m: Multiplier applied to the `potential_fn / proposal` ratio.
        device: Device such as "cpu", "cuda" or "cuda:0"; defaults to
            `potential_fn.device` when None.
        x_shape: Deprecated, should not be passed.
    """
    super().__init__(
        potential_fn,
        theta_transform=theta_transform,
        device=device,
        x_shape=x_shape,
    )

    # Store the proposal and the rejection-sampling hyperparameters.
    self.proposal = proposal
    self.m = m
    self.max_sampling_batch_size = max_sampling_batch_size
    self.num_iter_to_find_max = num_iter_to_find_max
    self.num_samples_to_find_max = num_samples_to_find_max
    self.x_shape = x_shape

    self._purpose = (
        "It provides rejection sampling to .sample() from the posterior and "
        "can evaluate the _unnormalized_ posterior density with .log_prob()."
    )

log_prob(theta, x=None, track_gradients=False)

Returns the log-probability of theta under the posterior.

Parameters:

Name Type Description Default
theta Tensor

Parameters \(\theta\).

required
track_gradients bool

Whether the returned tensor supports tracking gradients. This can be helpful for e.g. sensitivity analysis, but increases memory consumption.

False

Returns:

Type Description
Tensor

len($\theta$)-shaped log-probability.

Source code in sbi/inference/posteriors/rejection_posterior.py
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
def log_prob(
    self, theta: Tensor, x: Optional[Tensor] = None, track_gradients: bool = False
) -> Tensor:
    r"""Return the (unnormalized) log-probability of theta under the posterior.

    Args:
        theta: Parameters $\theta$.
        x: Conditioning observation; falls back to the default `x` when None.
        track_gradients: Whether the returned tensor supports gradient
            tracking. Useful e.g. for sensitivity analysis, at the cost of
            higher memory consumption.

    Returns:
        `len($\theta$)`-shaped log-probability.
    """
    # Deprecation + normalization warnings, as this density is unnormalized.
    warn(
        "`.log_prob()` is deprecated for methods that can only evaluate the "
        "log-probability up to a normalizing constant. Use `.potential()` instead.",
        stacklevel=2,
    )
    warn("The log-probability is unnormalized!", stacklevel=2)

    self.potential_fn.set_x(self._x_else_default_x(x))

    batched_theta = ensure_theta_batched(torch.as_tensor(theta))
    log_probs = self.potential_fn(
        batched_theta.to(self._device), track_gradients=track_gradients
    )
    return log_probs

map(x=None, num_iter=1000, num_to_optimize=100, learning_rate=0.01, init_method='proposal', num_init_samples=1000, save_best_every=10, show_progress_bars=False, force_update=False)

Returns the maximum-a-posteriori estimate (MAP).

The method can be interrupted (Ctrl-C) when the user sees that the log-probability converges. The best estimate will be saved in self._map and can be accessed with self.map(). The MAP is obtained by running gradient ascent from a given number of starting positions (samples from the posterior with the highest log-probability). After the optimization is done, we select the parameter set that has the highest log-probability after the optimization.

Warning: The default values used by this function are not well-tested. They might require hand-tuning for the problem at hand.

For developers: if the prior is a BoxUniform, we carry out the optimization in unbounded space and transform the result back into bounded space.

Parameters:

Name Type Description Default
x Optional[Tensor]

Deprecated - use .set_default_x() prior to .map().

None
num_iter int

Number of optimization steps that the algorithm takes to find the MAP.

1000
learning_rate float

Learning rate of the optimizer.

0.01
init_method Union[str, Tensor]

How to select the starting parameters for the optimization. If it is a string, it can be either [posterior, prior], which samples the respective distribution num_init_samples times. If it is a tensor, the tensor will be used as init locations.

'proposal'
num_init_samples int

Draw this number of samples from the posterior and evaluate the log-probability of all of them.

1000
num_to_optimize int

From the drawn num_init_samples, use the num_to_optimize with highest log-probability as the initial points for the optimization.

100
save_best_every int

The best log-probability is computed, saved in the map-attribute, and printed every save_best_every-th iteration. Computing the best log-probability creates a significant overhead (thus, the default is 10.)

10
show_progress_bars bool

Whether to show a progressbar during sampling from the posterior.

False
force_update bool

Whether to re-calculate the MAP when x is unchanged and have a cached value.

False
log_prob_kwargs

Will be empty for SNLE and SNRE. Will contain {‘norm_posterior’: True} for SNPE.

required

Returns:

Type Description
Tensor

The MAP estimate.

Source code in sbi/inference/posteriors/rejection_posterior.py
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
def map(
    self,
    x: Optional[Tensor] = None,
    num_iter: int = 1_000,
    num_to_optimize: int = 100,
    learning_rate: float = 0.01,
    init_method: Union[str, Tensor] = "proposal",
    num_init_samples: int = 1_000,
    save_best_every: int = 10,
    show_progress_bars: bool = False,
    force_update: bool = False,
) -> Tensor:
    r"""Returns the maximum-a-posteriori estimate (MAP).

    The MAP is found by gradient ascent from a number of starting points
    (the posterior samples with highest log-probability); afterwards, the
    parameter set with the highest log-probability is returned. The run can
    be interrupted (Ctrl-C) once the log-probability has converged; the best
    estimate so far is cached in `self._map` and returned by later `.map()`
    calls.

    Warning: The default values used by this function are not well-tested.
    They might require hand-tuning for the problem at hand.

    For developers: if the prior is a `BoxUniform`, the optimization is done
    in unbounded space and the result is transformed back to bounded space.

    Args:
        x: Deprecated - use `.set_default_x()` prior to `.map()`.
        num_iter: Number of optimization steps used to find the MAP.
        num_to_optimize: From the drawn `num_init_samples`, use the
            `num_to_optimize` with highest log-probability as the initial
            points for the optimization.
        learning_rate: Learning rate of the optimizer.
        init_method: How the starting parameters are selected. A string
            (`posterior` or `prior`) samples that distribution
            `num_init_samples` times; a tensor is used directly as the init
            locations.
        num_init_samples: Number of posterior samples drawn and evaluated to
            pick the starting points.
        save_best_every: Interval (in iterations) at which the best
            log-probability is computed, stored in the `map`-attribute, and
            printed. Computing it is expensive, hence the default of `10`.
        show_progress_bars: Whether to show a progressbar while sampling
            from the posterior.
        force_update: Whether to re-calculate the MAP even when x is
            unchanged and a cached value exists.

    Returns:
        The MAP estimate.
    """
    # Pure delegation: the generic MAP optimization lives in the base class.
    map_kwargs = {
        "x": x,
        "num_iter": num_iter,
        "num_to_optimize": num_to_optimize,
        "learning_rate": learning_rate,
        "init_method": init_method,
        "num_init_samples": num_init_samples,
        "save_best_every": save_best_every,
        "show_progress_bars": show_progress_bars,
        "force_update": force_update,
    }
    return super().map(**map_kwargs)

sample(sample_shape=torch.Size(), x=None, max_sampling_batch_size=None, num_samples_to_find_max=None, num_iter_to_find_max=None, m=None, show_progress_bars=True, reject_outside_prior=True, max_sampling_time=None, return_partial_on_timeout=False)

Draw samples from the approximate posterior via rejection sampling.

Parameters:

Name Type Description Default
sample_shape Shape

Desired shape of samples that are drawn from posterior. If sample_shape is multidimensional we simply draw sample_shape.numel() samples and then reshape into the desired shape.

Size()
x Optional[Tensor]

Conditioning observation \(x_o\). If not provided, uses the default x set via .set_default_x().

None
max_sampling_batch_size Optional[int]

Maximum batch size for rejection sampling. If not provided, uses the value specified at initialization.

None
num_samples_to_find_max Optional[int]

Number of samples to find the maximum of the potential function. If not provided, uses the value from initialization.

None
num_iter_to_find_max Optional[int]

Number of optimization iterations to find the maximum. If not provided, uses the value from initialization.

None
m Optional[float]

Multiplier for the proposal distribution. If not provided, uses the value from initialization.

None
show_progress_bars bool

Whether to show sampling progress monitor.

True
reject_outside_prior bool

If True (default), rejection sampling is used to ensure samples lie within the prior support. If False, samples are drawn directly from the proposal without rejection, which is faster but may include samples outside the prior support.

True
max_sampling_time Optional[float]

Optional maximum allowed sampling time in seconds. If exceeded, sampling is aborted and a RuntimeError is raised. Only applies when reject_outside_prior=True (no effect otherwise since direct sampling from the proposal is fast).

None
return_partial_on_timeout bool

If True and max_sampling_time is exceeded, return the samples collected so far instead of raising a RuntimeError. A warning will be issued. Only applies when reject_outside_prior=True (default).

False

Returns:

Type Description

Samples from posterior.

Source code in sbi/inference/posteriors/rejection_posterior.py
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
def sample(
    self,
    sample_shape: Shape = torch.Size(),
    x: Optional[Tensor] = None,
    max_sampling_batch_size: Optional[int] = None,
    num_samples_to_find_max: Optional[int] = None,
    num_iter_to_find_max: Optional[int] = None,
    m: Optional[float] = None,
    show_progress_bars: bool = True,
    reject_outside_prior: bool = True,
    max_sampling_time: Optional[float] = None,
    return_partial_on_timeout: bool = False,
):
    r"""Draw samples from the approximate posterior via rejection sampling.

    Args:
        sample_shape: Desired output shape. For a multidimensional shape,
            `sample_shape.numel()` samples are drawn and then reshaped.
        x: Conditioning observation $x_o$; falls back to the default `x`
            set via `.set_default_x()` when None.
        max_sampling_batch_size: Maximum rejection-sampling batch size;
            defaults to the value given at initialization.
        num_samples_to_find_max: Number of samples used to find the maximum
            of the potential function; defaults to the init value.
        num_iter_to_find_max: Number of optimization iterations used to find
            that maximum; defaults to the init value.
        m: Multiplier for the proposal distribution; defaults to the init
            value.
        show_progress_bars: Whether to show a sampling progress monitor.
        reject_outside_prior: If True (default), rejection sampling keeps
            only samples within the prior support. If False, samples come
            straight from the proposal — faster, but possibly outside the
            prior support.
        max_sampling_time: Optional time budget in seconds; exceeding it
            aborts sampling with a RuntimeError. Only relevant when
            `reject_outside_prior=True` (direct proposal sampling is fast).
        return_partial_on_timeout: If True and `max_sampling_time` is
            exceeded, return the samples gathered so far (with a warning)
            instead of raising. Only relevant when
            `reject_outside_prior=True` (default).

    Returns:
        Samples from posterior.
    """
    self.potential_fn.set_x(self._x_else_default_x(x))
    num_samples = torch.Size(sample_shape).numel()

    # Any argument left as None falls back to the value configured at init.
    if max_sampling_batch_size is None:
        max_sampling_batch_size = self.max_sampling_batch_size
    if num_samples_to_find_max is None:
        num_samples_to_find_max = self.num_samples_to_find_max
    if num_iter_to_find_max is None:
        num_iter_to_find_max = self.num_iter_to_find_max
    if m is None:
        m = self.m

    potential = partial(self.potential_fn, track_gradients=True)

    if not reject_outside_prior:
        # Skip rejection entirely and draw straight from the proposal.
        samples = self.proposal.sample((num_samples,))
        warn(
            "Samples drawn with reject_outside_prior=False are taken directly "
            "from the proposal without rejection sampling. These samples may lie "
            "outside the prior support, which could lead to incorrect inference.",
            stacklevel=2,
        )
    else:
        samples, _ = rejection_sample(
            potential,
            proposal=self.proposal,
            num_samples=num_samples,
            show_progress_bars=show_progress_bars,
            warn_acceptance=0.01,
            max_sampling_batch_size=max_sampling_batch_size,
            num_samples_to_find_max=num_samples_to_find_max,
            num_iter_to_find_max=num_iter_to_find_max,
            m=m,
            max_sampling_time=max_sampling_time,
            return_partial_on_timeout=return_partial_on_timeout,
            device=self._device,
        )

    return samples.reshape((*sample_shape, -1))

to(device)

Move potential function, proposal and x_o to the device.

This method reinstantiates the posterior and resets the default x_o

Parameters:

Name Type Description Default
device Union[str, device]

The device to move the posterior to.

required
Source code in sbi/inference/posteriors/rejection_posterior.py
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
def to(self, device: Union[str, torch.device]) -> None:
    """
    Move the potential function, proposal and default observation to `device`.

    This method re-initializes the posterior on the new device and restores
    the previously set default observation `x_o` (if any), because the
    base-class constructor clears it.

    Args:
        device: The device to move the posterior to (e.g. "cpu", "cuda:0").
    """
    self.device = device
    self.potential_fn.to(device)  # type: ignore
    self.proposal.to(device)
    # Stash the default observation: super().__init__ below resets self._x.
    x_o = None
    if hasattr(self, "_x") and (self._x is not None):
        x_o = self._x.to(device)

    # The transform depends on the proposal's device, so rebuild it.
    self.theta_transform = mcmc_transform(self.proposal, device=device)
    super().__init__(
        self.potential_fn,
        theta_transform=self.theta_transform,
        device=device,
        x_shape=self.x_shape,
    )
    # super().__init__ erases the self._x, so we need to set it again
    if x_o is not None:
        self.set_default_x(x_o)

VectorFieldPosterior

Bases: NeuralPosterior

Posterior based on flow- or score-matching estimators.

This posterior samples from the vector field model - typically a score-based or a flow matching model - given the vector_field_estimator and rejects samples that lie outside of the prior bounds.

The posterior is defined by a vector field estimator and a prior. The vector field estimator defines a continuous transformation from a base distribution to the approximated posterior distribution. Sampling is done by running either an ordinary differential equation (ODE) or a stochastic differential equation (SDE) defined by the vector field estimator with the starting points sampled from the base distribution.

Log probabilities are obtained by calling the potential function, which in turn uses the ODE to compute the log-probability.

Source code in sbi/inference/posteriors/vector_field_posterior.py
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
class VectorFieldPosterior(NeuralPosterior):
    r"""Posterior based on flow- or score-matching estimators.

    This posterior samples from the vector field model - typically a score-based or a
    flow matching model - given the `vector_field_estimator` and rejects samples that
    lie outside of the prior bounds.

    The posterior is defined by a vector field estimator and a prior. The vector field
    estimator defines a continuous transformation from a base distribution to the
    approximated posterior distribution. Sampling is done by running either
    an ordinary differential equation (ODE) or a stochastic differential equation
    (SDE) defined by the vector field estimator with the starting points sampled from
    the base distribution.

    Log probabilities are obtained by calling the potential function, which in turn uses
    the ODE to compute the log-probability.
    """

    def __init__(
        self,
        vector_field_estimator: ConditionalVectorFieldEstimator,
        prior: Distribution,  # type: ignore
        max_sampling_batch_size: int = 10_000,
        device: Optional[Union[str, torch.device]] = None,
        enable_transform: bool = True,
        sample_with: Literal["ode", "sde"] = "sde",
        **kwargs,
    ):
        """
        Args:
            prior: Prior distribution with `.log_prob()` and `.sample()`.
            vector_field_estimator: The trained vector field estimator.
            max_sampling_batch_size: Batchsize of samples being drawn from
                the proposal at every iteration.
            device: Training device, e.g., "cpu", "cuda" or "cuda:0". If None,
                `potential_fn.device` is used.
            enable_transform: Whether to transform parameters to unconstrained space
                during MAP optimization. When False, an identity transform will be
                returned for `theta_transform`. True is not supported yet.
            sample_with: Whether to sample from the posterior using the ODE-based
                sampler or the SDE-based sampler.
            **kwargs: Additional keyword arguments passed to
                `VectorFieldBasedPotential`.
        """

        check_prior(prior)
        potential_fn, theta_transform = vector_field_estimator_based_potential(
            vector_field_estimator,
            prior,
            x_o=None,
            enable_transform=enable_transform,
            **kwargs,
        )
        super().__init__(
            potential_fn=potential_fn,
            theta_transform=theta_transform,
            device=device,
        )
        # Set the potential function type.
        self.potential_fn: VectorFieldBasedPotential = potential_fn

        self.prior = prior
        self.enable_transform = enable_transform
        self.vector_field_estimator = vector_field_estimator
        self.device = device

        self.sample_with = sample_with
        assert self.sample_with in [
            "ode",
            "sde",
        ], f"sample_with must be 'ode' or 'sde', but is {self.sample_with}."
        self.max_sampling_batch_size = max_sampling_batch_size

        self._purpose = """It samples from the vector field model given the \
            vector_field_estimator."""

    def to(self, device: Union[str, torch.device]) -> None:
        """Move posterior to device.

        Args:
            device: device where to move the posterior to.
        """
        self.device = device
        if hasattr(self.prior, "to"):
            self.prior.to(device)  # type: ignore
        else:
            raise ValueError("""Prior has no attribute to(device).""")
        if hasattr(self.vector_field_estimator, "to"):
            self.vector_field_estimator.to(device)
        else:
            raise ValueError("""Posterior estimator has no attribute to(device).""")

        potential_fn, theta_transform = vector_field_estimator_based_potential(
            self.vector_field_estimator,
            self.prior,
            x_o=None,
            enable_transform=self.enable_transform,
        )
        # Stash the default observation: super().__init__ below resets self._x.
        x_o = None
        if hasattr(self, "_x") and (self._x is not None):
            x_o = self._x.to(device)
        super().__init__(
            potential_fn=potential_fn,
            theta_transform=theta_transform,
            device=device,
        )
        # super().__init__ erases the self._x, so we need to set it again
        if x_o is not None:
            self.set_default_x(x_o)

        self.potential_fn: VectorFieldBasedPotential = potential_fn

    def sample(
        self,
        sample_shape: Shape = torch.Size(),
        x: Optional[Tensor] = None,
        predictor: Union[str, Predictor] = "euler_maruyama",
        corrector: Optional[Union[str, Corrector]] = None,
        predictor_params: Optional[Dict] = None,
        corrector_params: Optional[Dict] = None,
        steps: int = 500,
        ts: Optional[Tensor] = None,
        iid_method: Optional[
            Literal["fnpe", "gauss", "auto_gauss", "jac_gauss"]
        ] = None,
        iid_params: Optional[Dict] = None,
        max_sampling_batch_size: int = 10_000,
        sample_with: Optional[str] = None,
        show_progress_bars: bool = True,
        reject_outside_prior: bool = True,
        max_sampling_time: Optional[float] = None,
        return_partial_on_timeout: bool = False,
    ) -> Tensor:
        r"""Return samples from posterior distribution $p(\theta|x)$.

        Args:
            sample_shape: Shape of the samples to be drawn.
            x: Observed data $x_o$ to condition on. If None, the default $x_o$
                set via `.set_default_x()` is used.
            predictor: The predictor for the vector field sampler. Can be a string or
                a custom predictor following the API in `sbi.samplers.score.predictors`.
                Currently, only `euler_maruyama` is implemented.
            corrector: The corrector for the vector field sampler. Either of
                [None].
            predictor_params: Additional parameters passed to predictor.
            corrector_params: Additional parameters passed to corrector.
            steps: Number of steps to take for the Euler-Maruyama method.
                If `sample_with` is "ode", this is ignored.
            ts: Time points at which to evaluate the vector field process. If None, a
                linear grid between t_max and t_min is used. If `sample_with` is "ode",
                this is ignored.
            iid_method: Which method to use for computing the score in the iid setting.
                We currently support "fnpe", "gauss", "auto_gauss", "jac_gauss". The
                fnpe method is simple and generally applicable. However, it can become
                inaccurate already for quite a few iid samples (as it is based on
                heuristic approximations), and should be used at best only with a
                `corrector`. The "gauss" methods are more accurate, by aiming for an
                efficient approximation of the correct marginal score in the iid case.
                This however requires estimating some hyperparameters, which is done
                in a systematic way in the "auto_gauss" (initial overhead) and
                "jac_gauss" (iterative jacobian computations are expensive). We default
                to "auto_gauss" for these reasons. Note that in order to use the iid
                method, the vector field estimator must support it and have
                SCORE_DEFINED and MARGINALS_DEFINED class attributes set to True.
            iid_params: Additional parameters passed to the iid method. See the specific
                `IIDScoreFunction` child class for details.
            max_sampling_batch_size: Maximum batch size for sampling.
            sample_with: Sampling method to use - 'ode' or 'sde'. Note that in order to
                use the 'sde' sampling method, the vector field estimator must support
                it and have the SCORE_DEFINED class attribute set to True.
            show_progress_bars: Whether to show a progress bar during sampling.
            reject_outside_prior: If True (default), rejection sampling is used to
                ensure samples lie within the prior support. If False, samples are drawn
                directly from the ODE/SDE sampler without rejection, which is faster but
                may include samples outside the prior support.
            max_sampling_time: Optional maximum allowed sampling time in seconds.
                If exceeded, sampling is aborted and a RuntimeError is raised. Only
                applies when `reject_outside_prior=True` (no effect otherwise since
                direct sampling does not use rejection).
            return_partial_on_timeout: If True and `max_sampling_time` is exceeded,
                return the samples collected so far instead of raising a RuntimeError.
                A warning will be issued. Only applies when `reject_outside_prior=True`
                (default).
        """

        if sample_with is None:
            sample_with = self.sample_with

        x = self._x_else_default_x(x)
        x = reshape_to_batch_event(x, self.vector_field_estimator.condition_shape)
        # More than one condition in the batch dimension means iid observations.
        is_iid = x.shape[0] > 1
        self.potential_fn.set_x(
            x,
            x_is_iid=is_iid,
            iid_method=iid_method or self.potential_fn.iid_method,
            iid_params=iid_params,
        )

        num_samples = torch.Size(sample_shape).numel()

        if sample_with == "ode":
            if reject_outside_prior:
                samples, _ = rejection.accept_reject_sample(
                    proposal=self.sample_via_ode,
                    accept_reject_fn=lambda theta: within_support(self.prior, theta),
                    num_samples=num_samples,
                    show_progress_bars=show_progress_bars,
                    max_sampling_batch_size=max_sampling_batch_size,
                    max_sampling_time=max_sampling_time,
                    return_partial_on_timeout=return_partial_on_timeout,
                )
            else:
                # Bypass rejection sampling entirely.
                samples = self.sample_via_ode(torch.Size([num_samples]))
        elif sample_with == "sde":
            proposal_sampling_kwargs = {
                "predictor": predictor,
                "corrector": corrector,
                "predictor_params": predictor_params,
                "corrector_params": corrector_params,
                "steps": steps,
                "ts": ts,
                "max_sampling_batch_size": max_sampling_batch_size,
                "show_progress_bars": show_progress_bars,
            }
            if reject_outside_prior:
                samples, _ = rejection.accept_reject_sample(
                    proposal=self._sample_via_diffusion,
                    accept_reject_fn=lambda theta: within_support(self.prior, theta),
                    num_samples=num_samples,
                    show_progress_bars=show_progress_bars,
                    max_sampling_batch_size=max_sampling_batch_size,
                    proposal_sampling_kwargs=proposal_sampling_kwargs,
                    max_sampling_time=max_sampling_time,
                    return_partial_on_timeout=return_partial_on_timeout,
                )
            else:
                # Bypass rejection sampling entirely.
                samples = self._sample_via_diffusion(
                    (num_samples,),
                    **proposal_sampling_kwargs,
                )
        else:
            raise ValueError(
                f"Expected sample_with to be 'ode' or 'sde', but got {sample_with}."
            )

        if not reject_outside_prior:
            warn_if_outside_prior_support(self.prior, samples)

        samples = samples.reshape(
            sample_shape + self.vector_field_estimator.input_shape
        )
        return samples

    def _sample_via_diffusion(
        self,
        sample_shape: Shape = torch.Size(),
        predictor: Union[str, Predictor] = "euler_maruyama",
        corrector: Optional[Union[str, Corrector]] = None,
        predictor_params: Optional[Dict] = None,
        corrector_params: Optional[Dict] = None,
        steps: int = 500,
        ts: Optional[Tensor] = None,
        max_sampling_batch_size: int = 10_000,
        show_progress_bars: bool = True,
        save_intermediate: bool = False,
        **kwargs,
    ) -> Tensor:
        r"""Return samples from posterior distribution $p(\theta|x)$.

        NOTE: this method can be unsupported for some vector field estimators, e.g.,
        if the vector field estimator was trained with a custom flow matching routine
        for which the corresponding score is not defined.

        Args:
            sample_shape: Shape of the samples to be drawn.
            predictor: The predictor for the diffusion-based sampler. Can be a string or
                a custom predictor following the API in `sbi.samplers.score.predictors`.
                Currently, only `euler_maruyama` is implemented.
            corrector: The corrector for the diffusion-based sampler. Either of
                [None].
            steps: Number of steps to take for the Euler-Maruyama method.
            ts: Time points at which to evaluate the diffusion process. If None,
                uses the solve_schedule() specific to the estimator.
            max_sampling_batch_size: Maximum batch size for sampling.
            show_progress_bars: Whether to show a progress bar during sampling.
            save_intermediate: Whether to save intermediate results of the diffusion
                process. If True, the returned tensor has shape
                `(*sample_shape, steps, *input_shape)`.
        """

        if not self.vector_field_estimator.SCORE_DEFINED:
            raise ValueError(
                "The vector field estimator does not support the 'sde' sampling method."
            )

        total_samples_needed = torch.Size(sample_shape).numel()

        # Determine effective batch size for sampling
        effective_batch_size = (
            self.max_sampling_batch_size
            if max_sampling_batch_size is None
            else max_sampling_batch_size
        )
        # Ensure we don't use larger batches than total samples needed
        effective_batch_size = min(effective_batch_size, total_samples_needed)

        if ts is None:
            ts = self.vector_field_estimator.solve_schedule(steps)
        ts = ts.to(self.device)

        # Initialize the diffusion sampler
        diffuser = Diffuser(
            self.potential_fn,
            predictor=predictor,
            corrector=corrector,
            predictor_params=predictor_params,
            corrector_params=corrector_params,
        )

        # Calculate how many batches we need
        num_batches = math.ceil(total_samples_needed / effective_batch_size)

        # Generate samples in batches
        all_samples = []
        samples_generated = 0

        for _ in range(num_batches):
            # Calculate how many samples to generate in this batch
            remaining_samples = total_samples_needed - samples_generated
            current_batch_size = min(effective_batch_size, remaining_samples)

            # Generate samples for this batch
            batch_samples = diffuser.run(
                num_samples=current_batch_size,
                ts=ts,
                show_progress_bars=show_progress_bars,
                save_intermediate=save_intermediate,
            )

            all_samples.append(batch_samples)
            samples_generated += current_batch_size

        # Concatenate all batches and ensure we return exactly the requested number
        samples = torch.cat(all_samples, dim=0)[:total_samples_needed]

        if torch.isnan(samples).all():
            raise RuntimeError(
                "All samples NaN after diffusion sampling. "
                "This may indicate numerical instability in the vector field."
            )

        return samples

    def sample_via_ode(
        self,
        sample_shape: Shape = torch.Size(),
        **kwargs,
    ) -> Tensor:
        r"""
        Return samples from posterior distribution with probability flow ODE.

        This builds the probability flow ODE and then samples from the corresponding
        flow.

        Args:
            sample_shape: The shape of the samples to be returned.
            **kwargs: Additional keyword arguments for the ODE solver that
                depend on the used ODE backend.

        Returns:
            Samples from the approximated posterior distribution
                :math:`\theta \sim p(\theta|x)`.
        """
        num_samples = torch.Size(sample_shape).numel()

        samples = self.potential_fn.neural_ode(self.potential_fn.x_o, **kwargs).sample(
            torch.Size((num_samples,))
        )

        return samples

    def log_prob(
        self,
        theta: Tensor,
        x: Optional[Tensor] = None,
        track_gradients: bool = False,
        ode_kwargs: Optional[Dict] = None,
    ) -> Tensor:
        r"""Returns the log-probability of the posterior $p(\theta|x)$.

        This requires building and evaluating the probability flow ODE.

        Args:
            theta: Parameters $\theta$.
            x: Observed data $x_o$. If None, the default $x_o$ is used.
            track_gradients: Whether the returned tensor supports tracking gradients.
                This can be helpful for e.g. sensitivity analysis, but increases memory
                consumption.
            ode_kwargs: Additional keyword arguments for the ODE solver.

        Returns:
            `(len(θ),)`-shaped log posterior probability $\log p(\theta|x)$ for θ in the
            support of the prior, -∞ (corresponding to 0 probability) outside.
        """
        x = self._x_else_default_x(x)
        x = reshape_to_batch_event(x, self.vector_field_estimator.condition_shape)
        is_iid = x.shape[0] > 1
        self.potential_fn.set_x(x, x_is_iid=is_iid, **(ode_kwargs or {}))

        theta = ensure_theta_batched(torch.as_tensor(theta))
        return self.potential_fn(
            theta.to(self._device),
            track_gradients=track_gradients,
        )

    def sample_batched(
        self,
        sample_shape: torch.Size,
        x: Tensor,
        predictor: Union[str, Predictor] = "euler_maruyama",
        corrector: Optional[Union[str, Corrector]] = None,
        predictor_params: Optional[Dict] = None,
        corrector_params: Optional[Dict] = None,
        steps: int = 500,
        ts: Optional[Tensor] = None,
        max_sampling_batch_size: int = 10000,
        show_progress_bars: bool = True,
        reject_outside_prior: bool = True,
        max_sampling_time: Optional[float] = None,
        return_partial_on_timeout: bool = False,
    ) -> Tensor:
        r"""Given a batch of observations [x_1, ..., x_B] this function samples from
        posteriors $p(\theta|x_1)$, ... ,$p(\theta|x_B)$, in a batched (i.e. vectorized)
        manner.

        Args:
            sample_shape: Desired shape of samples that are drawn from the posterior
                given every observation.
            x: A batch of observations, of shape `(batch_dim, event_shape_x)`.
                `batch_dim` corresponds to the number of observations to be
                drawn.
            predictor: The predictor for the diffusion-based sampler. Can be a string or
                a custom predictor following the API in `sbi.samplers.score.predictors`.
                Currently, only `euler_maruyama` is implemented.
            corrector: The corrector for the diffusion-based sampler.
            predictor_params: Additional parameters passed to predictor.
            corrector_params: Additional parameters passed to corrector.
            steps: Number of steps to take for the Euler-Maruyama method.
            ts: Time points at which to evaluate the diffusion process. If None, a
                linear grid between t_max and t_min is used.
            max_sampling_batch_size: Maximum batch size for sampling.
            show_progress_bars: Whether to show sampling progress monitor.
            reject_outside_prior: If True (default), rejection sampling is used to
                ensure samples lie within the prior support. If False, samples are drawn
                directly from the ODE/SDE sampler without rejection, which is faster but
                may include samples outside the prior support.
            max_sampling_time: Optional maximum allowed sampling time in seconds.
                If exceeded, sampling is aborted and a RuntimeError is raised. Only
                applies when `reject_outside_prior=True`.
            return_partial_on_timeout: If True and `max_sampling_time` is exceeded,
                return the samples collected so far instead of raising a RuntimeError.
                A warning will be issued. Only applies when `reject_outside_prior=True`.

        Returns:
            Samples from the posteriors of shape (*sample_shape, B, *input_shape)
        """
        num_samples = torch.Size(sample_shape).numel()
        x = reshape_to_batch_event(x, self.vector_field_estimator.condition_shape)
        condition_dim = len(self.vector_field_estimator.condition_shape)
        batch_shape = x.shape[:-condition_dim]
        batch_size = batch_shape.numel()
        self.potential_fn.set_x(x)

        max_sampling_batch_size = (
            self.max_sampling_batch_size
            if max_sampling_batch_size is None
            else max_sampling_batch_size
        )

        # Adjust max_sampling_batch_size to avoid excessive memory usage
        if max_sampling_batch_size * batch_size > 100_000:
            capped = max(1, 100_000 // batch_size)
            warnings.warn(
                f"Capping max_sampling_batch_size from {max_sampling_batch_size} "
                f"to {capped} to avoid excessive memory usage.",
                stacklevel=2,
            )
            max_sampling_batch_size = capped

        if self.sample_with == "ode":
            if reject_outside_prior:
                samples, _ = rejection.accept_reject_sample(
                    proposal=self.sample_via_ode,
                    accept_reject_fn=lambda theta: within_support(self.prior, theta),
                    num_samples=num_samples,
                    num_xos=batch_size,
                    show_progress_bars=show_progress_bars,
                    max_sampling_batch_size=max_sampling_batch_size,
                    max_sampling_time=max_sampling_time,
                    return_partial_on_timeout=return_partial_on_timeout,
                )
            else:
                # Bypass rejection sampling.
                samples = self.sample_via_ode(torch.Size([num_samples]))
            samples = samples.reshape(
                sample_shape + batch_shape + self.vector_field_estimator.input_shape
            )
        elif self.sample_with == "sde":
            proposal_sampling_kwargs = {
                "predictor": predictor,
                "corrector": corrector,
                "predictor_params": predictor_params,
                "corrector_params": corrector_params,
                "steps": steps,
                "ts": ts,
                "max_sampling_batch_size": max_sampling_batch_size,
                "show_progress_bars": show_progress_bars,
            }
            if reject_outside_prior:
                samples, _ = rejection.accept_reject_sample(
                    proposal=self._sample_via_diffusion,
                    accept_reject_fn=lambda theta: within_support(self.prior, theta),
                    num_samples=num_samples,
                    num_xos=batch_size,
                    show_progress_bars=show_progress_bars,
                    max_sampling_batch_size=max_sampling_batch_size,
                    proposal_sampling_kwargs=proposal_sampling_kwargs,
                    max_sampling_time=max_sampling_time,
                    return_partial_on_timeout=return_partial_on_timeout,
                )
            else:
                # Bypass rejection sampling.
                samples = self._sample_via_diffusion(
                    (num_samples,), **proposal_sampling_kwargs
                )
            samples = samples.reshape(
                sample_shape + batch_shape + self.vector_field_estimator.input_shape
            )
        else:
            # Guard against an invalid `sample_with` (otherwise `samples` would be
            # undefined below, raising a confusing NameError). Mirrors `sample()`.
            raise ValueError(
                f"Expected sample_with to be 'ode' or 'sde', "
                f"but got {self.sample_with}."
            )

        if not reject_outside_prior:
            warn_if_outside_prior_support(self.prior, samples)

        return samples

    def map(
        self,
        x: Optional[Tensor] = None,
        num_iter: int = 1000,
        num_to_optimize: int = 1000,
        learning_rate: float = 0.01,
        init_method: Union[str, Tensor] = "posterior",
        num_init_samples: int = 1000,
        save_best_every: int = 1000,
        show_progress_bars: bool = False,
        force_update: bool = False,
    ) -> Tensor:
        r"""Returns the maximum-a-posteriori estimate (MAP).

        The method can be interrupted (Ctrl-C) when the user sees that the
        log-probability converges. The best estimate will be saved in `self._map` and
        can be accessed with `self.map()`. The MAP is obtained by running gradient
        ascent from a given number of starting positions (samples from the posterior
        with the highest log-probability). After the optimization is done, we select the
        parameter set that has the highest log-probability after the optimization.

        Warning: The default values used by this function are not well-tested. They
        might require hand-tuning for the problem at hand.

        For developers: if the prior is a `BoxUniform`, we carry out the optimization
        in unbounded space and transform the result back into bounded space.

        Args:
            x: Deprecated - use `.set_default_x()` prior to `.map()`.
            num_iter: Number of optimization steps that the algorithm takes
                to find the MAP.
            num_to_optimize: From the drawn `num_init_samples`, use the
                `num_to_optimize` with highest log-probability as the initial points
                for the optimization.
            learning_rate: Learning rate of the optimizer.
            init_method: How to select the starting parameters for the optimization. If
                it is a string, it can be either [`posterior`, `prior`], which samples
                the respective distribution `num_init_samples` times. If it is a
                tensor, the tensor will be used as init locations.
            num_init_samples: Draw this number of samples from the posterior and
                evaluate the log-probability of all of them.
            save_best_every: The best log-probability is computed, saved in the
                `map`-attribute, and printed every `save_best_every`-th iteration.
                Computing the best log-probability creates a significant overhead
                (thus, the default is `1000`.)
            show_progress_bars: Whether to show a progressbar during sampling from
                the posterior.
            force_update: Whether to re-calculate the MAP when x is unchanged and
                have a cached value.

        Returns:
            The MAP estimate.
        """
        if x is not None:
            raise ValueError(
                "Passing `x` directly to `.map()` has been deprecated."
                "Use `.self_default_x()` to set `x`, and then run `.map()` "
            )

        if self.default_x is None:
            raise ValueError(
                "Default `x` has not been set."
                "To set the default, use the `.set_default_x()` method."
            )

        if self._map is None or force_update:
            # rebuild coarse flow fast for MAP optimization.
            self.potential_fn.set_x(self.default_x, atol=1e-2, rtol=1e-3, exact=True)
            callable_potential_fn = CallableDifferentiablePotentialFunction(
                self.potential_fn
            )
            if init_method == "posterior":
                inits = self.sample((num_init_samples,))
            elif init_method == "prior":
                # Fixed: previously checked for "proposal" and accessed
                # `self.proposal`, which this class never defines (only
                # `self.prior` is set in `__init__`), so that branch could
                # only raise AttributeError. The docstring advertises
                # [`posterior`, `prior`].
                inits = self.prior.sample((num_init_samples,))  # type: ignore
            elif isinstance(init_method, Tensor):
                inits = init_method
            else:
                raise ValueError(
                    f"Invalid init_method: {init_method!r}. Expected 'posterior', "
                    "'prior', or a Tensor of initial locations."
                )

            self._map = gradient_ascent(
                potential_fn=callable_potential_fn,
                inits=inits,
                theta_transform=self.theta_transform,
                num_iter=num_iter,
                num_to_optimize=num_to_optimize,
                learning_rate=learning_rate,
                save_best_every=save_best_every,
                show_progress_bars=show_progress_bars,
            )[0]

        return self._map

__init__(vector_field_estimator, prior, max_sampling_batch_size=10000, device=None, enable_transform=True, sample_with='sde', **kwargs)

Parameters:

Name Type Description Default
prior Distribution

Prior distribution with .log_prob() and .sample().

required
vector_field_estimator ConditionalVectorFieldEstimator

The trained vector field estimator.

required
max_sampling_batch_size int

Batchsize of samples being drawn from the proposal at every iteration.

10000
device Optional[Union[str, device]]

Training device, e.g., “cpu”, “cuda” or “cuda:0”. If None, potential_fn.device is used.

None
enable_transform bool

Whether to transform parameters to unconstrained space during MAP optimization. When False, an identity transform will be returned for theta_transform. True is not supported yet.

True
sample_with Literal['ode', 'sde']

Whether to sample from the posterior using the ODE-based sampler or the SDE-based sampler.

'sde'
**kwargs

Additional keyword arguments passed to VectorFieldBasedPotential.

{}
Source code in sbi/inference/posteriors/vector_field_posterior.py
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
def __init__(
    self,
    vector_field_estimator: ConditionalVectorFieldEstimator,
    prior: Distribution,  # type: ignore
    max_sampling_batch_size: int = 10_000,
    device: Optional[Union[str, torch.device]] = None,
    enable_transform: bool = True,
    sample_with: Literal["ode", "sde"] = "sde",
    **kwargs,
):
    """
    Args:
        prior: Prior distribution with `.log_prob()` and `.sample()`.
        vector_field_estimator: The trained vector field estimator.
        max_sampling_batch_size: Batchsize of samples being drawn from
            the proposal at every iteration.
        device: Training device, e.g., "cpu", "cuda" or "cuda:0". If None,
            `potential_fn.device` is used.
        enable_transform: Whether to transform parameters to unconstrained space
            during MAP optimization. When False, an identity transform will be
            returned for `theta_transform`. True is not supported yet.
        sample_with: Whether to sample from the posterior using the ODE-based
            sampler or the SDE-based sampler.
        **kwargs: Additional keyword arguments passed to
            `VectorFieldBasedPotential`.

    Raises:
        ValueError: If `sample_with` is neither "ode" nor "sde".
    """
    # Validate user input eagerly and with a real exception: `assert` is
    # stripped under `python -O` and must not guard user-supplied arguments.
    if sample_with not in ("ode", "sde"):
        raise ValueError(
            f"sample_with must be 'ode' or 'sde', but is {sample_with}."
        )

    check_prior(prior)
    # Build the potential (and the parameter transform used for MAP
    # optimization) without a default observation; `x_o` is set later.
    potential_fn, theta_transform = vector_field_estimator_based_potential(
        vector_field_estimator,
        prior,
        x_o=None,
        enable_transform=enable_transform,
        **kwargs,
    )
    super().__init__(
        potential_fn=potential_fn,
        theta_transform=theta_transform,
        device=device,
    )
    # Narrow the attribute type for static checkers.
    self.potential_fn: VectorFieldBasedPotential = potential_fn

    self.prior = prior
    self.enable_transform = enable_transform
    self.vector_field_estimator = vector_field_estimator
    self.device = device

    self.sample_with = sample_with
    self.max_sampling_batch_size = max_sampling_batch_size

    self._purpose = """It samples from the vector field model given the \
        vector_field_estimator."""

log_prob(theta, x=None, track_gradients=False, ode_kwargs=None)

Returns the log-probability of the posterior \(p(\theta|x)\).

This requires building and evaluating the probability flow ODE.

Parameters:

Name Type Description Default
theta Tensor

Parameters \(\theta\).

required
x Optional[Tensor]

Observed data \(x_o\). If None, the default \(x_o\) is used.

None
track_gradients bool

Whether the returned tensor supports tracking gradients. This can be helpful for e.g. sensitivity analysis, but increases memory consumption.

False
ode_kwargs Optional[Dict]

Additional keyword arguments for the ODE solver.

None

Returns:

Type Description
Tensor

(len(θ),)-shaped log posterior probability \(\log p(\theta|x)\) for θ in the

Tensor

support of the prior, -∞ (corresponding to 0 probability) outside.

Source code in sbi/inference/posteriors/vector_field_posterior.py
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
def log_prob(
    self,
    theta: Tensor,
    x: Optional[Tensor] = None,
    track_gradients: bool = False,
    ode_kwargs: Optional[Dict] = None,
) -> Tensor:
    r"""Returns the log-probability of the posterior $p(\theta|x)$.

    This requires building and evaluating the probability flow ODE.

    Args:
        theta: Parameters $\theta$.
        x: Observed data $x_o$. If None, the default $x_o$ is used.
        track_gradients: Whether the returned tensor supports tracking gradients.
            This can be helpful for e.g. sensitivity analysis, but increases memory
            consumption.
        ode_kwargs: Additional keyword arguments for the ODE solver.

    Returns:
        `(len(θ),)`-shaped log posterior probability $\log p(\theta|x)$ for θ in the
        support of the prior, -∞ (corresponding to 0 probability) outside.
    """
    observation = self._x_else_default_x(x)
    observation = reshape_to_batch_event(
        observation, self.vector_field_estimator.condition_shape
    )
    # A leading batch dimension larger than one is treated as iid data.
    self.potential_fn.set_x(
        observation,
        x_is_iid=observation.shape[0] > 1,
        **(ode_kwargs if ode_kwargs is not None else {}),
    )

    batched_theta = ensure_theta_batched(torch.as_tensor(theta))
    return self.potential_fn(
        batched_theta.to(self._device), track_gradients=track_gradients
    )

map(x=None, num_iter=1000, num_to_optimize=1000, learning_rate=0.01, init_method='posterior', num_init_samples=1000, save_best_every=1000, show_progress_bars=False, force_update=False)

Returns the maximum-a-posteriori estimate (MAP).

The method can be interrupted (Ctrl-C) when the user sees that the log-probability converges. The best estimate will be saved in self._map and can be accessed with self.map(). The MAP is obtained by running gradient ascent from a given number of starting positions (samples from the posterior with the highest log-probability). After the optimization is done, we select the parameter set that has the highest log-probability after the optimization.

Warning: The default values used by this function are not well-tested. They might require hand-tuning for the problem at hand.

For developers: if the prior is a BoxUniform, we carry out the optimization in unbounded space and transform the result back into bounded space.

Parameters:

Name Type Description Default
x Optional[Tensor]

Deprecated - use .set_default_x() prior to .map().

None
num_iter int

Number of optimization steps that the algorithm takes to find the MAP.

1000
num_to_optimize int

From the drawn num_init_samples, use the num_to_optimize with highest log-probability as the initial points for the optimization.

1000
learning_rate float

Learning rate of the optimizer.

0.01
init_method Union[str, Tensor]

How to select the starting parameters for the optimization. If it is a string, it can be either [posterior, prior], which samples the respective distribution num_init_samples times. If it is a tensor, the tensor will be used as init locations.

'posterior'
num_init_samples int

Draw this number of samples from the posterior and evaluate the log-probability of all of them.

1000
save_best_every int

The best log-probability is computed, saved in the map-attribute, and printed every save_best_every-th iteration. Computing the best log-probability creates a significant overhead (thus, the default is 1000).

1000
show_progress_bars bool

Whether to show a progressbar during sampling from the posterior.

False
force_update bool

Whether to re-calculate the MAP when x is unchanged and have a cached value.

False

Returns:

Type Description
Tensor

The MAP estimate.

Source code in sbi/inference/posteriors/vector_field_posterior.py
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
def map(
    self,
    x: Optional[Tensor] = None,
    num_iter: int = 1000,
    num_to_optimize: int = 1000,
    learning_rate: float = 0.01,
    init_method: Union[str, Tensor] = "posterior",
    num_init_samples: int = 1000,
    save_best_every: int = 1000,
    show_progress_bars: bool = False,
    force_update: bool = False,
) -> Tensor:
    r"""Returns the maximum-a-posteriori estimate (MAP).

    The method can be interrupted (Ctrl-C) when the user sees that the
    log-probability converges. The best estimate will be saved in `self._map` and
    can be accessed with `self.map()`. The MAP is obtained by running gradient
    ascent from a given number of starting positions (samples from the posterior
    with the highest log-probability). After the optimization is done, we select the
    parameter set that has the highest log-probability after the optimization.

    Warning: The default values used by this function are not well-tested. They
    might require hand-tuning for the problem at hand.

    For developers: if the prior is a `BoxUniform`, we carry out the optimization
    in unbounded space and transform the result back into bounded space.

    Args:
        x: Deprecated - use `.set_default_x()` prior to `.map()`.
        num_iter: Number of optimization steps that the algorithm takes
            to find the MAP.
        num_to_optimize: From the drawn `num_init_samples`, use the
            `num_to_optimize` with highest log-probability as the initial points
            for the optimization.
        learning_rate: Learning rate of the optimizer.
        init_method: How to select the starting parameters for the optimization. If
            it is a string, it can be either [`posterior`, `prior`], which samples
            the respective distribution `num_init_samples` times. If it is a
            tensor, the tensor will be used as init locations.
        num_init_samples: Draw this number of samples from the posterior and
            evaluate the log-probability of all of them.
        save_best_every: The best log-probability is computed, saved in the
            `map`-attribute, and printed every `save_best_every`-th iteration.
            Computing the best log-probability creates a significant overhead
            (thus, the default is `1000`.)
        show_progress_bars: Whether to show a progressbar during sampling from
            the posterior.
        force_update: Whether to re-calculate the MAP when x is unchanged and
            have a cached value.

    Returns:
        The MAP estimate.

    Raises:
        ValueError: If `x` is passed directly, if no default `x` has been set,
            or if `init_method` is neither a valid string nor a tensor.
    """
    if x is not None:
        # NOTE: the referenced setter is `.set_default_x()` (previously this
        # message pointed to a non-existent `.self_default_x()`).
        raise ValueError(
            "Passing `x` directly to `.map()` has been deprecated."
            "Use `.set_default_x()` to set `x`, and then run `.map()` "
        )

    if self.default_x is None:
        raise ValueError(
            "Default `x` has not been set."
            "To set the default, use the `.set_default_x()` method."
        )

    if self._map is None or force_update:
        # rebuild coarse flow fast for MAP optimization.
        self.potential_fn.set_x(self.default_x, atol=1e-2, rtol=1e-3, exact=True)
        callable_potential_fn = CallableDifferentiablePotentialFunction(
            self.potential_fn
        )
        if init_method == "posterior":
            inits = self.sample((num_init_samples,))
        elif init_method == "proposal":
            inits = self.proposal.sample((num_init_samples,))  # type: ignore
        elif isinstance(init_method, Tensor):
            inits = init_method
        else:
            # Fail with an actionable message instead of a bare ValueError.
            raise ValueError(
                "`init_method` must be 'posterior', 'prior', or a Tensor of "
                f"initial locations, but got {init_method!r}."
            )

        self._map = gradient_ascent(
            potential_fn=callable_potential_fn,
            inits=inits,
            theta_transform=self.theta_transform,
            num_iter=num_iter,
            num_to_optimize=num_to_optimize,
            learning_rate=learning_rate,
            save_best_every=save_best_every,
            show_progress_bars=show_progress_bars,
        )[0]

    return self._map

sample(sample_shape=torch.Size(), x=None, predictor='euler_maruyama', corrector=None, predictor_params=None, corrector_params=None, steps=500, ts=None, iid_method=None, iid_params=None, max_sampling_batch_size=10000, sample_with=None, show_progress_bars=True, reject_outside_prior=True, max_sampling_time=None, return_partial_on_timeout=False)

Return samples from posterior distribution \(p(\theta|x)\).

Parameters:

Name Type Description Default
sample_shape Shape

Shape of the samples to be drawn.

Size()
predictor Union[str, Predictor]

The predictor for the vector field sampler. Can be a string or a custom predictor following the API in sbi.samplers.score.predictors. Currently, only euler_maruyama is implemented.

'euler_maruyama'
corrector Optional[Union[str, Corrector]]

The corrector for the vector field sampler. Either of [None].

None
predictor_params Optional[Dict]

Additional parameters passed to predictor.

None
corrector_params Optional[Dict]

Additional parameters passed to corrector.

None
steps int

Number of steps to take for the Euler-Maruyama method. If sample_with is “ode”, this is ignored.

500
ts Optional[Tensor]

Time points at which to evaluate the vector field process. If None, a linear grid between t_max and t_min is used. If sample_with is “ode”, this is ignored.

None
iid_method Optional[Literal['fnpe', 'gauss', 'auto_gauss', 'jac_gauss']]

Which method to use for computing the score in the iid setting. We currently support “fnpe”, “gauss”, “auto_gauss”, “jac_gauss”. The fnpe method is simple and generally applicable. However, it can become inaccurate already for quite a few iid samples (as it is based on heuristic approximations), and should be used at best only with a corrector. The “gauss” methods are more accurate, by aiming for an efficient approximation of the correct marginal score in the iid case. This however requires estimating some hyperparameters, which is done in a systematic way in the “auto_gauss” (initial overhead) and “jac_gauss” (iterative jacobian computations are expensive). We default to “auto_gauss” for these reasons. Note that in order to use the iid method, the vector field estimator must support it and have SCORE_DEFINED and MARGINALS_DEFINED class attributes set to True.

None
iid_params Optional[Dict]

Additional parameters passed to the iid method. See the specific IIDScoreFunction child class for details.

None
max_sampling_batch_size int

Maximum batch size for sampling.

10000
sample_with Optional[str]

Sampling method to use - ‘ode’ or ‘sde’. Note that in order to use the ‘sde’ sampling method, the vector field estimator must support it and have the SCORE_DEFINED class attribute set to True.

None
show_progress_bars bool

Whether to show a progress bar during sampling.

True
reject_outside_prior bool

If True (default), rejection sampling is used to ensure samples lie within the prior support. If False, samples are drawn directly from the ODE/SDE sampler without rejection, which is faster but may include samples outside the prior support.

True
max_sampling_time Optional[float]

Optional maximum allowed sampling time in seconds. If exceeded, sampling is aborted and a RuntimeError is raised. Only applies when reject_outside_prior=True (no effect otherwise since direct sampling does not use rejection).

None
return_partial_on_timeout bool

If True and max_sampling_time is exceeded, return the samples collected so far instead of raising a RuntimeError. A warning will be issued. Only applies when reject_outside_prior=True (default).

False
Source code in sbi/inference/posteriors/vector_field_posterior.py
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
def sample(
    self,
    sample_shape: Shape = torch.Size(),
    x: Optional[Tensor] = None,
    predictor: Union[str, Predictor] = "euler_maruyama",
    corrector: Optional[Union[str, Corrector]] = None,
    predictor_params: Optional[Dict] = None,
    corrector_params: Optional[Dict] = None,
    steps: int = 500,
    ts: Optional[Tensor] = None,
    iid_method: Optional[
        Literal["fnpe", "gauss", "auto_gauss", "jac_gauss"]
    ] = None,
    iid_params: Optional[Dict] = None,
    max_sampling_batch_size: int = 10_000,
    sample_with: Optional[str] = None,
    show_progress_bars: bool = True,
    reject_outside_prior: bool = True,
    max_sampling_time: Optional[float] = None,
    return_partial_on_timeout: bool = False,
) -> Tensor:
    r"""Return samples from posterior distribution $p(\theta|x)$.

    Args:
        sample_shape: Shape of the samples to be drawn.
        x: Observed data $x_o$. If None, the default $x_o$ set via
            `.set_default_x()` is used.
        predictor: The predictor for the vector field sampler. Can be a string or
            a custom predictor following the API in `sbi.samplers.score.predictors`.
            Currently, only `euler_maruyama` is implemented.
        corrector: The corrector for the vector field sampler. Either of
            [None].
        predictor_params: Additional parameters passed to predictor.
        corrector_params: Additional parameters passed to corrector.
        steps: Number of steps to take for the Euler-Maruyama method.
            If `sample_with` is "ode", this is ignored.
        ts: Time points at which to evaluate the vector field process. If None, a
            linear grid between t_max and t_min is used. If `sample_with` is "ode",
            this is ignored.
        iid_method: Which method to use for computing the score in the iid setting.
            We currently support "fnpe", "gauss", "auto_gauss", "jac_gauss". The
            fnpe method is simple and generally applicable. However, it can become
            inaccurate already for quite a few iid samples (as it is based on
            heuristic approximations), and should be used at best only with a
            `corrector`. The "gauss" methods are more accurate, by aiming for an
            efficient approximation of the correct marginal score in the iid case.
            This however requires estimating some hyperparameters, which is done
            in a systematic way in the "auto_gauss" (initial overhead) and
            "jac_gauss" (iterative jacobian computations are expensive). We
            default to "auto_gauss" for these reasons. Note that in order to use
            the iid method, the vector field estimator must support it and have
            SCORE_DEFINED and MARGINALS_DEFINED class attributes set to True.
        iid_params: Additional parameters passed to the iid method. See the specific
            `IIDScoreFunction` child class for details.
        max_sampling_batch_size: Maximum batch size for sampling.
        sample_with: Sampling method to use - 'ode' or 'sde'. If None, falls back
            to `self.sample_with`. Note that in order to use the 'sde' sampling
            method, the vector field estimator must support it and have the
            SCORE_DEFINED class attribute set to True.
        show_progress_bars: Whether to show a progress bar during sampling.
        reject_outside_prior: If True (default), rejection sampling is used to
            ensure samples lie within the prior support. If False, samples are drawn
            directly from the ODE/SDE sampler without rejection, which is faster but
            may include samples outside the prior support.
        max_sampling_time: Optional maximum allowed sampling time in seconds.
            If exceeded, sampling is aborted and a RuntimeError is raised. Only
            applies when `reject_outside_prior=True` (no effect otherwise since
            direct sampling does not use rejection).
        return_partial_on_timeout: If True and `max_sampling_time` is exceeded,
            return the samples collected so far instead of raising a RuntimeError.
            A warning will be issued. Only applies when `reject_outside_prior=True`
            (default).

    Returns:
        Samples of shape `(*sample_shape, *input_shape)`.

    Raises:
        ValueError: If `sample_with` resolves to neither 'ode' nor 'sde'.
    """

    # Per-call override of the sampler backend; fall back to the instance
    # default chosen at construction time.
    if sample_with is None:
        sample_with = self.sample_with

    x = self._x_else_default_x(x)
    x = reshape_to_batch_event(x, self.vector_field_estimator.condition_shape)
    # A leading batch dimension > 1 is interpreted as iid observations.
    is_iid = x.shape[0] > 1
    self.potential_fn.set_x(
        x,
        x_is_iid=is_iid,
        iid_method=iid_method or self.potential_fn.iid_method,
        iid_params=iid_params,
    )

    # Total number of draws; samples are reshaped to `sample_shape` at the end.
    num_samples = torch.Size(sample_shape).numel()

    if sample_with == "ode":
        if reject_outside_prior:
            samples, _ = rejection.accept_reject_sample(
                proposal=self.sample_via_ode,
                accept_reject_fn=lambda theta: within_support(self.prior, theta),
                num_samples=num_samples,
                show_progress_bars=show_progress_bars,
                max_sampling_batch_size=max_sampling_batch_size,
                max_sampling_time=max_sampling_time,
                return_partial_on_timeout=return_partial_on_timeout,
            )
        else:
            # Bypass rejection sampling entirely.
            samples = self.sample_via_ode(torch.Size([num_samples]))
    elif sample_with == "sde":
        proposal_sampling_kwargs = {
            "predictor": predictor,
            "corrector": corrector,
            "predictor_params": predictor_params,
            "corrector_params": corrector_params,
            "steps": steps,
            "ts": ts,
            "max_sampling_batch_size": max_sampling_batch_size,
            "show_progress_bars": show_progress_bars,
        }
        if reject_outside_prior:
            samples, _ = rejection.accept_reject_sample(
                proposal=self._sample_via_diffusion,
                accept_reject_fn=lambda theta: within_support(self.prior, theta),
                num_samples=num_samples,
                show_progress_bars=show_progress_bars,
                max_sampling_batch_size=max_sampling_batch_size,
                proposal_sampling_kwargs=proposal_sampling_kwargs,
                max_sampling_time=max_sampling_time,
                return_partial_on_timeout=return_partial_on_timeout,
            )
        else:
            # Bypass rejection sampling entirely.
            samples = self._sample_via_diffusion(
                (num_samples,),
                **proposal_sampling_kwargs,
            )
    else:
        raise ValueError(
            f"Expected sample_with to be 'ode' or 'sde', but got {sample_with}."
        )

    # Without rejection, leakage outside the prior support is possible; at
    # least warn the user if it occurred.
    if not reject_outside_prior:
        warn_if_outside_prior_support(self.prior, samples)

    samples = samples.reshape(
        sample_shape + self.vector_field_estimator.input_shape
    )
    return samples

sample_batched(sample_shape, x, predictor='euler_maruyama', corrector=None, predictor_params=None, corrector_params=None, steps=500, ts=None, max_sampling_batch_size=10000, show_progress_bars=True, reject_outside_prior=True, max_sampling_time=None, return_partial_on_timeout=False)

Given a batch of observations [x_1, …, x_B] this function samples from posteriors \(p(\theta|x_1)\), … ,\(p(\theta|x_B)\), in a batched (i.e. vectorized) manner.

Parameters:

Name Type Description Default
sample_shape Size

Desired shape of samples that are drawn from the posterior given every observation.

required
x Tensor

A batch of observations, of shape (batch_dim, event_shape_x). batch_dim corresponds to the number of observations to be drawn.

required
predictor Union[str, Predictor]

The predictor for the diffusion-based sampler. Can be a string or a custom predictor following the API in sbi.samplers.score.predictors. Currently, only euler_maruyama is implemented.

'euler_maruyama'
corrector Optional[Union[str, Corrector]]

The corrector for the diffusion-based sampler.

None
predictor_params Optional[Dict]

Additional parameters passed to predictor.

None
corrector_params Optional[Dict]

Additional parameters passed to corrector.

None
steps int

Number of steps to take for the Euler-Maruyama method.

500
ts Optional[Tensor]

Time points at which to evaluate the diffusion process. If None, a linear grid between t_max and t_min is used.

None
max_sampling_batch_size int

Maximum batch size for sampling.

10000
show_progress_bars bool

Whether to show sampling progress monitor.

True
reject_outside_prior bool

If True (default), rejection sampling is used to ensure samples lie within the prior support. If False, samples are drawn directly from the ODE/SDE sampler without rejection, which is faster but may include samples outside the prior support.

True
max_sampling_time Optional[float]

Optional maximum allowed sampling time in seconds. If exceeded, sampling is aborted and a RuntimeError is raised. Only applies when reject_outside_prior=True.

None
return_partial_on_timeout bool

If True and max_sampling_time is exceeded, return the samples collected so far instead of raising a RuntimeError. A warning will be issued. Only applies when reject_outside_prior=True.

False

Returns:

Type Description
Tensor

Samples from the posteriors of shape (*sample_shape, B, *input_shape)

Source code in sbi/inference/posteriors/vector_field_posterior.py
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
def sample_batched(
    self,
    sample_shape: torch.Size,
    x: Tensor,
    predictor: Union[str, Predictor] = "euler_maruyama",
    corrector: Optional[Union[str, Corrector]] = None,
    predictor_params: Optional[Dict] = None,
    corrector_params: Optional[Dict] = None,
    steps: int = 500,
    ts: Optional[Tensor] = None,
    max_sampling_batch_size: Optional[int] = None,
    show_progress_bars: bool = True,
    reject_outside_prior: bool = True,
    max_sampling_time: Optional[float] = None,
    return_partial_on_timeout: bool = False,
) -> Tensor:
    r"""Given a batch of observations [x_1, ..., x_B] this function samples from
    posteriors $p(\theta|x_1)$, ... ,$p(\theta|x_B)$, in a batched (i.e. vectorized)
    manner.

    Args:
        sample_shape: Desired shape of samples that are drawn from the posterior
            given every observation.
        x: A batch of observations, of shape `(batch_dim, event_shape_x)`.
            `batch_dim` corresponds to the number of observations to be
            drawn.
        predictor: The predictor for the diffusion-based sampler. Can be a string or
            a custom predictor following the API in `sbi.samplers.score.predictors`.
            Currently, only `euler_maruyama` is implemented.
        corrector: The corrector for the diffusion-based sampler.
        predictor_params: Additional parameters passed to predictor.
        corrector_params: Additional parameters passed to corrector.
        steps: Number of steps to take for the Euler-Maruyama method.
        ts: Time points at which to evaluate the diffusion process. If None, a
            linear grid between t_max and t_min is used.
        max_sampling_batch_size: Maximum batch size for sampling. If None,
            falls back to `self.max_sampling_batch_size`.
        show_progress_bars: Whether to show sampling progress monitor.
        reject_outside_prior: If True (default), rejection sampling is used to
            ensure samples lie within the prior support. If False, samples are drawn
            directly from the ODE/SDE sampler without rejection, which is faster but
            may include samples outside the prior support.
        max_sampling_time: Optional maximum allowed sampling time in seconds.
            If exceeded, sampling is aborted and a RuntimeError is raised. Only
            applies when `reject_outside_prior=True`.
        return_partial_on_timeout: If True and `max_sampling_time` is exceeded,
            return the samples collected so far instead of raising a RuntimeError.
            A warning will be issued. Only applies when `reject_outside_prior=True`.

    Returns:
        Samples from the posteriors of shape (*sample_shape, B, *input_shape)

    Raises:
        ValueError: If `self.sample_with` is neither 'ode' nor 'sde'.
    """
    num_samples = torch.Size(sample_shape).numel()
    x = reshape_to_batch_event(x, self.vector_field_estimator.condition_shape)
    condition_dim = len(self.vector_field_estimator.condition_shape)
    batch_shape = x.shape[:-condition_dim]
    batch_size = batch_shape.numel()
    self.potential_fn.set_x(x)

    # Fall back to the instance-level default. (The parameter previously
    # defaulted to `10000`, which made this fallback dead code; a `None`
    # default activates it while keeping the effective default of 10_000.)
    max_sampling_batch_size = (
        self.max_sampling_batch_size
        if max_sampling_batch_size is None
        else max_sampling_batch_size
    )

    # Cap total parallel draws (samples x observations) to bound memory usage.
    if max_sampling_batch_size * batch_size > 100_000:
        capped = max(1, 100_000 // batch_size)
        warnings.warn(
            f"Capping max_sampling_batch_size from {max_sampling_batch_size} "
            f"to {capped} to avoid excessive memory usage.",
            stacklevel=2,
        )
        max_sampling_batch_size = capped

    if self.sample_with == "ode":
        if reject_outside_prior:
            samples, _ = rejection.accept_reject_sample(
                proposal=self.sample_via_ode,
                accept_reject_fn=lambda theta: within_support(self.prior, theta),
                num_samples=num_samples,
                num_xos=batch_size,
                show_progress_bars=show_progress_bars,
                max_sampling_batch_size=max_sampling_batch_size,
                max_sampling_time=max_sampling_time,
                return_partial_on_timeout=return_partial_on_timeout,
            )
        else:
            # Bypass rejection sampling.
            samples = self.sample_via_ode(torch.Size([num_samples]))
        samples = samples.reshape(
            sample_shape + batch_shape + self.vector_field_estimator.input_shape
        )
    elif self.sample_with == "sde":
        proposal_sampling_kwargs = {
            "predictor": predictor,
            "corrector": corrector,
            "predictor_params": predictor_params,
            "corrector_params": corrector_params,
            "steps": steps,
            "ts": ts,
            "max_sampling_batch_size": max_sampling_batch_size,
            "show_progress_bars": show_progress_bars,
        }
        if reject_outside_prior:
            samples, _ = rejection.accept_reject_sample(
                proposal=self._sample_via_diffusion,
                accept_reject_fn=lambda theta: within_support(self.prior, theta),
                num_samples=num_samples,
                num_xos=batch_size,
                show_progress_bars=show_progress_bars,
                max_sampling_batch_size=max_sampling_batch_size,
                proposal_sampling_kwargs=proposal_sampling_kwargs,
                max_sampling_time=max_sampling_time,
                return_partial_on_timeout=return_partial_on_timeout,
            )
        else:
            # Bypass rejection sampling.
            samples = self._sample_via_diffusion(
                (num_samples,), **proposal_sampling_kwargs
            )
        samples = samples.reshape(
            sample_shape + batch_shape + self.vector_field_estimator.input_shape
        )
    else:
        # Mirror `.sample()`: fail loudly instead of hitting an
        # UnboundLocalError on `samples` below.
        raise ValueError(
            f"Expected sample_with to be 'ode' or 'sde', "
            f"but got {self.sample_with}."
        )

    if not reject_outside_prior:
        warn_if_outside_prior_support(self.prior, samples)

    return samples

sample_via_ode(sample_shape=torch.Size(), **kwargs)

Return samples from posterior distribution with probability flow ODE.

This builds the probability flow ODE and then samples from the corresponding flow.

Parameters:

Name Type Description Default
sample_shape Shape

The shape of the samples to be returned.

Size()
**kwargs

Additional keyword arguments for the ODE solver that depend on the used ODE backend.

{}

Returns:

Type Description
Tensor

Samples from the approximated posterior distribution :math:\theta \sim p(\theta|x).

Source code in sbi/inference/posteriors/vector_field_posterior.py
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
def sample_via_ode(
    self,
    sample_shape: Shape = torch.Size(),
    **kwargs,
) -> Tensor:
    r"""Draw posterior samples by integrating the probability flow ODE.

    The probability flow ODE corresponding to the vector field estimator is
    built via ``self.potential_fn.neural_ode`` and samples are drawn from the
    resulting flow, conditioned on the default observation ``x_o``.

    Args:
        sample_shape: The shape of the samples to be returned.
        **kwargs: Additional keyword arguments for the ODE solver that
            depend on the used ODE backend.

    Returns:
        Samples from the approximated posterior distribution
            :math:`\theta \sim p(\theta|x)`.
    """
    total = torch.Size(sample_shape).numel()
    ode_flow = self.potential_fn.neural_ode(self.potential_fn.x_o, **kwargs)
    return ode_flow.sample(torch.Size((total,)))

to(device)

Move posterior to device.

Parameters:

Name Type Description Default
device Union[str, device]

device where to move the posterior to.

required
Source code in sbi/inference/posteriors/vector_field_posterior.py
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
def to(self, device: Union[str, torch.device]) -> None:
    """Move posterior to device.

    Moves the prior and the vector field estimator to ``device``, then
    rebuilds the potential function and theta transform on the new device.
    The default observation ``self._x`` (if set) is preserved across the
    rebuild.

    Args:
        device: device where to move the posterior to.

    Raises:
        ValueError: If the prior or the vector field estimator does not
            implement a ``to(device)`` method.
    """
    self.device = device
    if hasattr(self.prior, "to"):
        self.prior.to(device)  # type: ignore
    else:
        raise ValueError("""Prior has no attribute to(device).""")
    if hasattr(self.vector_field_estimator, "to"):
        self.vector_field_estimator.to(device)
    else:
        raise ValueError("""Posterior estimator has no attribute to(device).""")

    # Rebuild potential/transform on the new device; x_o is re-attached below.
    potential_fn, theta_transform = vector_field_estimator_based_potential(
        self.vector_field_estimator,
        self.prior,
        x_o=None,
        enable_transform=self.enable_transform,
    )
    # Stash the default observation before re-initialization wipes it.
    x_o = None
    if hasattr(self, "_x") and (self._x is not None):
        x_o = self._x.to(device)
    super().__init__(
        potential_fn=potential_fn,
        theta_transform=theta_transform,
        device=device,
    )
    # super().__init__ erases the self._x, so we need to set it again
    if x_o is not None:
        self.set_default_x(x_o)

    self.potential_fn: VectorFieldBasedPotential = potential_fn

VIPosterior

Bases: NeuralPosterior

Provides VI (Variational Inference) to sample from the posterior.

SNLE or SNRE train neural networks to approximate the likelihood (or likelihood ratios). VIPosterior allows learning a tractable variational posterior :math:`q(\theta)` which approximates the true posterior :math:`p(\theta|x_o)`. After this second training stage, we can produce approximate posterior samples by sampling from :math:`q` at no additional cost.

For additional information, see [1]_ and [2]_.

References

.. [1] Glöckler, M., Deistler, M., & Macke, J. (2022). Variational methods for simulation-based inference. https://openreview.net/forum?id=kZ0UYdhqkNY

.. [2] Wiqvist, S., Frellsen, J., & Picchini, U. (2021). Sequential Neural Posterior and Likelihood Approximation. https://arxiv.org/abs/2102.06522

Source code in sbi/inference/posteriors/vi_posterior.py
  60
  61
  62
  63
  64
  65
  66
  67
  68
  69
  70
  71
  72
  73
  74
  75
  76
  77
  78
  79
  80
  81
  82
  83
  84
  85
  86
  87
  88
  89
  90
  91
  92
  93
  94
  95
  96
  97
  98
  99
 100
 101
 102
 103
 104
 105
 106
 107
 108
 109
 110
 111
 112
 113
 114
 115
 116
 117
 118
 119
 120
 121
 122
 123
 124
 125
 126
 127
 128
 129
 130
 131
 132
 133
 134
 135
 136
 137
 138
 139
 140
 141
 142
 143
 144
 145
 146
 147
 148
 149
 150
 151
 152
 153
 154
 155
 156
 157
 158
 159
 160
 161
 162
 163
 164
 165
 166
 167
 168
 169
 170
 171
 172
 173
 174
 175
 176
 177
 178
 179
 180
 181
 182
 183
 184
 185
 186
 187
 188
 189
 190
 191
 192
 193
 194
 195
 196
 197
 198
 199
 200
 201
 202
 203
 204
 205
 206
 207
 208
 209
 210
 211
 212
 213
 214
 215
 216
 217
 218
 219
 220
 221
 222
 223
 224
 225
 226
 227
 228
 229
 230
 231
 232
 233
 234
 235
 236
 237
 238
 239
 240
 241
 242
 243
 244
 245
 246
 247
 248
 249
 250
 251
 252
 253
 254
 255
 256
 257
 258
 259
 260
 261
 262
 263
 264
 265
 266
 267
 268
 269
 270
 271
 272
 273
 274
 275
 276
 277
 278
 279
 280
 281
 282
 283
 284
 285
 286
 287
 288
 289
 290
 291
 292
 293
 294
 295
 296
 297
 298
 299
 300
 301
 302
 303
 304
 305
 306
 307
 308
 309
 310
 311
 312
 313
 314
 315
 316
 317
 318
 319
 320
 321
 322
 323
 324
 325
 326
 327
 328
 329
 330
 331
 332
 333
 334
 335
 336
 337
 338
 339
 340
 341
 342
 343
 344
 345
 346
 347
 348
 349
 350
 351
 352
 353
 354
 355
 356
 357
 358
 359
 360
 361
 362
 363
 364
 365
 366
 367
 368
 369
 370
 371
 372
 373
 374
 375
 376
 377
 378
 379
 380
 381
 382
 383
 384
 385
 386
 387
 388
 389
 390
 391
 392
 393
 394
 395
 396
 397
 398
 399
 400
 401
 402
 403
 404
 405
 406
 407
 408
 409
 410
 411
 412
 413
 414
 415
 416
 417
 418
 419
 420
 421
 422
 423
 424
 425
 426
 427
 428
 429
 430
 431
 432
 433
 434
 435
 436
 437
 438
 439
 440
 441
 442
 443
 444
 445
 446
 447
 448
 449
 450
 451
 452
 453
 454
 455
 456
 457
 458
 459
 460
 461
 462
 463
 464
 465
 466
 467
 468
 469
 470
 471
 472
 473
 474
 475
 476
 477
 478
 479
 480
 481
 482
 483
 484
 485
 486
 487
 488
 489
 490
 491
 492
 493
 494
 495
 496
 497
 498
 499
 500
 501
 502
 503
 504
 505
 506
 507
 508
 509
 510
 511
 512
 513
 514
 515
 516
 517
 518
 519
 520
 521
 522
 523
 524
 525
 526
 527
 528
 529
 530
 531
 532
 533
 534
 535
 536
 537
 538
 539
 540
 541
 542
 543
 544
 545
 546
 547
 548
 549
 550
 551
 552
 553
 554
 555
 556
 557
 558
 559
 560
 561
 562
 563
 564
 565
 566
 567
 568
 569
 570
 571
 572
 573
 574
 575
 576
 577
 578
 579
 580
 581
 582
 583
 584
 585
 586
 587
 588
 589
 590
 591
 592
 593
 594
 595
 596
 597
 598
 599
 600
 601
 602
 603
 604
 605
 606
 607
 608
 609
 610
 611
 612
 613
 614
 615
 616
 617
 618
 619
 620
 621
 622
 623
 624
 625
 626
 627
 628
 629
 630
 631
 632
 633
 634
 635
 636
 637
 638
 639
 640
 641
 642
 643
 644
 645
 646
 647
 648
 649
 650
 651
 652
 653
 654
 655
 656
 657
 658
 659
 660
 661
 662
 663
 664
 665
 666
 667
 668
 669
 670
 671
 672
 673
 674
 675
 676
 677
 678
 679
 680
 681
 682
 683
 684
 685
 686
 687
 688
 689
 690
 691
 692
 693
 694
 695
 696
 697
 698
 699
 700
 701
 702
 703
 704
 705
 706
 707
 708
 709
 710
 711
 712
 713
 714
 715
 716
 717
 718
 719
 720
 721
 722
 723
 724
 725
 726
 727
 728
 729
 730
 731
 732
 733
 734
 735
 736
 737
 738
 739
 740
 741
 742
 743
 744
 745
 746
 747
 748
 749
 750
 751
 752
 753
 754
 755
 756
 757
 758
 759
 760
 761
 762
 763
 764
 765
 766
 767
 768
 769
 770
 771
 772
 773
 774
 775
 776
 777
 778
 779
 780
 781
 782
 783
 784
 785
 786
 787
 788
 789
 790
 791
 792
 793
 794
 795
 796
 797
 798
 799
 800
 801
 802
 803
 804
 805
 806
 807
 808
 809
 810
 811
 812
 813
 814
 815
 816
 817
 818
 819
 820
 821
 822
 823
 824
 825
 826
 827
 828
 829
 830
 831
 832
 833
 834
 835
 836
 837
 838
 839
 840
 841
 842
 843
 844
 845
 846
 847
 848
 849
 850
 851
 852
 853
 854
 855
 856
 857
 858
 859
 860
 861
 862
 863
 864
 865
 866
 867
 868
 869
 870
 871
 872
 873
 874
 875
 876
 877
 878
 879
 880
 881
 882
 883
 884
 885
 886
 887
 888
 889
 890
 891
 892
 893
 894
 895
 896
 897
 898
 899
 900
 901
 902
 903
 904
 905
 906
 907
 908
 909
 910
 911
 912
 913
 914
 915
 916
 917
 918
 919
 920
 921
 922
 923
 924
 925
 926
 927
 928
 929
 930
 931
 932
 933
 934
 935
 936
 937
 938
 939
 940
 941
 942
 943
 944
 945
 946
 947
 948
 949
 950
 951
 952
 953
 954
 955
 956
 957
 958
 959
 960
 961
 962
 963
 964
 965
 966
 967
 968
 969
 970
 971
 972
 973
 974
 975
 976
 977
 978
 979
 980
 981
 982
 983
 984
 985
 986
 987
 988
 989
 990
 991
 992
 993
 994
 995
 996
 997
 998
 999
1000
1001
1002
1003
1004
1005
1006
1007
1008
1009
1010
1011
1012
1013
1014
1015
1016
1017
1018
1019
1020
1021
1022
1023
1024
1025
1026
1027
1028
1029
1030
1031
1032
1033
1034
1035
1036
1037
1038
1039
1040
1041
1042
1043
1044
1045
1046
1047
1048
1049
1050
1051
1052
1053
1054
1055
1056
1057
1058
1059
1060
1061
1062
1063
1064
1065
1066
1067
1068
1069
1070
1071
1072
1073
1074
1075
1076
1077
1078
1079
1080
1081
1082
1083
1084
1085
1086
1087
1088
1089
1090
1091
1092
1093
1094
1095
1096
1097
1098
1099
1100
1101
1102
1103
1104
1105
1106
1107
1108
1109
1110
1111
1112
1113
1114
1115
1116
1117
1118
1119
1120
1121
1122
1123
1124
1125
1126
1127
1128
1129
1130
1131
1132
1133
1134
1135
1136
1137
1138
1139
1140
1141
1142
1143
1144
1145
1146
1147
1148
1149
1150
1151
1152
1153
1154
1155
1156
1157
1158
1159
1160
1161
1162
1163
1164
1165
1166
1167
1168
1169
1170
1171
1172
1173
1174
1175
1176
1177
1178
1179
1180
1181
1182
1183
1184
1185
1186
1187
1188
1189
1190
1191
1192
1193
1194
1195
1196
1197
1198
1199
1200
1201
1202
1203
1204
1205
1206
1207
1208
1209
1210
1211
1212
1213
1214
1215
1216
1217
1218
1219
1220
1221
1222
1223
1224
1225
1226
1227
1228
1229
1230
1231
1232
1233
1234
1235
1236
1237
1238
1239
1240
1241
1242
1243
1244
1245
1246
1247
1248
1249
1250
1251
1252
1253
1254
1255
1256
1257
1258
1259
1260
1261
1262
1263
1264
1265
1266
1267
1268
1269
1270
1271
1272
1273
1274
1275
1276
1277
1278
1279
1280
1281
1282
1283
1284
1285
1286
1287
1288
1289
1290
1291
1292
1293
1294
1295
1296
1297
1298
1299
1300
1301
1302
1303
1304
1305
1306
1307
1308
1309
1310
1311
1312
1313
1314
1315
1316
1317
1318
1319
1320
1321
1322
1323
1324
1325
1326
1327
1328
1329
1330
1331
1332
1333
1334
1335
1336
1337
1338
1339
1340
1341
1342
1343
1344
1345
1346
1347
1348
1349
1350
1351
1352
1353
1354
1355
1356
1357
1358
1359
1360
1361
1362
1363
1364
1365
1366
1367
1368
1369
1370
1371
1372
1373
1374
1375
1376
1377
1378
1379
1380
1381
1382
1383
1384
1385
1386
1387
1388
1389
1390
1391
1392
class VIPosterior(NeuralPosterior):
    r"""Provides VI (Variational Inference) to sample from the posterior.

    SNLE or SNRE train neural networks to approximate the likelihood (or likelihood
    ratios). ``VIPosterior`` allows learning a tractable variational posterior
    :math:`q(\theta)` which approximates the true posterior
    :math:`p(\theta|x_o)`. After this second training stage, we can produce
    approximate posterior samples by sampling from :math:`q` at no additional cost.

    For additional information, see [1]_ and [2]_.

    References
    ----------

    .. [1] Glöckler, M., Deistler, M., & Macke, J. (2022).
        Variational methods for simulation-based inference.
        https://openreview.net/forum?id=kZ0UYdhqkNY

    .. [2] Wiqvist, S., Frellsen, J., & Picchini, U. (2021).
        Sequential Neural Posterior and Likelihood Approximation.
        https://arxiv.org/abs/2102.06522
    """

    def __init__(
        self,
        potential_fn: Union[BasePotential, CustomPotential],
        prior: Optional[TorchDistribution] = None,  # type: ignore
        q: QType = "maf",
        theta_transform: Optional[TorchTransform] = None,
        vi_method: Literal["rKL", "fKL", "IW", "alpha"] = "rKL",
        device: Union[str, torch.device] = "cpu",
        x_shape: Optional[torch.Size] = None,
        parameters: Optional[Iterable] = None,
        modules: Optional[Iterable] = None,
        num_transforms: int = 5,
        hidden_features: int = 50,
        z_score_theta: Literal["none", "independent", "structured"] = "independent",
        z_score_x: Literal["none", "independent", "structured"] = "independent",
    ):
        """
        Args:
            potential_fn: The potential function from which to draw samples. Must be a
                `BasePotential` or a `CustomPotential`.
            prior: This is the prior distribution. Note that this is only
                used to check/construct the variational distribution or within some
                quality metrics. Please make sure that this matches with the prior
                within the potential_fn. If `None` is given, we will try to infer it
                from potential_fn or q, if this fails we raise an Error.
            q: Variational distribution, either string, `Distribution`, or a
                `VIPosterior` object. This specifies a parametric class of distribution
                over which the best possible posterior approximation is searched. For
                string input, we support normalizing flows [maf, nsf, naf, unaf, nice,
                sospf, gf] via Zuko, and Gaussian families [gaussian, gaussian_diag].
                Note: For 1D problems, prefer "gf" (mixture of Gaussians) or "gaussian"
                as autoregressive flows may be unstable.
                You can also specify your own variational family by passing a
                `torch.distributions.Distribution`. Additionally, we allow a `Callable`
                with signature `(event_shape: torch.Size, link_transform:
                TorchTransform, device: str) -> Distribution` for custom flow
                configurations. The
                callable should return a distribution with `sample()` and `log_prob()`
                methods. If q is already a `VIPosterior`, then the arguments will be
                copied from it (relevant for multi-round training).
            theta_transform: Maps from prior support to unconstrained space. The
                inverse is used here to ensure that the posterior support is equal to
                that of the prior.
            vi_method: This specifies the variational methods which are used to fit q to
                the posterior. We currently support [rKL, fKL, IW, alpha]. Note that
                some of the divergences are `mode seeking` i.e. they underestimate
                variance and collapse on multimodal targets (`rKL`, `alpha` for alpha >
                1) and some are `mass covering` i.e. they overestimate variance but
                typically cover all modes (`fKL`, `IW`, `alpha` for alpha < 1).
            device: Training device, e.g., `cpu`, `cuda` or `cuda:0`. We will ensure
                that all other objects are also on this device.
            x_shape: Deprecated, should not be passed.
            parameters: List of parameters of the variational posterior. This is only
                required for user-defined q i.e. if q does not have a `parameters`
                attribute.
            modules: List of modules of the variational posterior. This is only
                required for user-defined q i.e. if q does not have a `modules`
                attribute.
            num_transforms: Number of transforms in the normalizing flow. Used for
                both single-x VI (when q is a string flow type) and amortized VI.
            hidden_features: Hidden layer size in flow networks. Used for both
                single-x VI and amortized VI.
            z_score_theta: Method for z-scoring θ (parameters). One of "none",
                "independent", "structured". Used for both single-x VI and amortized
                VI. Use "structured" for parameters with correlations.
            z_score_x: Method for z-scoring x (conditioning observation). One of
                "none", "independent", "structured". Only used for amortized VI
                (train_amortized). Use "structured" for structured data like images.
        """
        super().__init__(potential_fn, theta_transform, device, x_shape=x_shape)

        # Especially the prior may be on another device -> move it...
        self._device = device
        self.theta_transform = theta_transform
        self.x_shape = x_shape
        self.potential_fn.device = device
        self.potential_fn.to(device)

        # Get prior and previous builds
        # Resolution order: explicit argument > potential_fn.prior > the prior
        # of a VIPosterior passed as q. Otherwise we cannot proceed.
        if prior is not None:
            self._prior = prior
        elif hasattr(self.potential_fn, "prior") and isinstance(
            self.potential_fn.prior, Distribution
        ):
            self._prior = self.potential_fn.prior
        elif isinstance(q, VIPosterior) and isinstance(q._prior, Distribution):
            self._prior = q._prior
        else:
            raise ValueError(
                "We could not find a suitable prior distribution within `potential_fn` "
                "or `q` (if a VIPosterior is given). Please explicitly specify a prior."
            )

        self._prior = move_distribution_to_device(self._prior, device)
        self._optimizer = None

        # Mode tracking: None (not trained), "single_x", or "amortized"
        self._mode: Optional[Literal["single_x", "amortized"]] = None

        # Amortized mode: conditional flow q(θ|x)
        self._amortized_q: Optional[ConditionalDensityEstimator] = None

        # Defaults used when (re)building variational flows later on.
        self._num_transforms: int = num_transforms
        self._hidden_features: int = hidden_features
        self._z_score_theta: Literal["none", "independent", "structured"] = (
            z_score_theta
        )
        self._z_score_x: Literal["none", "independent", "structured"] = z_score_x

        # In contrast to MCMC we want to project into constrained space.
        if theta_transform is None:
            self.link_transform = mcmc_transform(self._prior, device=device).inv
        else:
            self.link_transform = theta_transform.inv

        if parameters is None:
            parameters = []
        if modules is None:
            modules = []
        # This will set the variational distribution and VI method
        self.set_q(
            q,
            parameters=parameters,
            modules=modules,
        )
        self.set_vi_method(vi_method)

        self._purpose = (
            "It provides Variational inference to .sample() from the posterior and "
            "can evaluate the _normalized_ posterior density with .log_prob()."
        )

    def to(self, device: Union[str, torch.device]) -> "VIPosterior":
        """Transfer every component of this posterior to ``device``.

        Args:
            device: The device to move the posterior to.

        Returns:
            self for method chaining.
        """
        self._device = device

        # The potential carries the prior, x_o, and the estimator with it.
        self.potential_fn.to(device)  # type: ignore
        self._prior = move_distribution_to_device(self._prior, device)

        # Recreate the link transform on the target device (mirrors __init__).
        transform = self.theta_transform
        if transform is None:
            self.link_transform = mcmc_transform(self._prior, device=device).inv
        else:
            self.link_transform = transform.inv

        # Transfer any cached tensors.
        for attr in ("_x", "_map", "_trained_on"):
            value = getattr(self, attr)
            if value is not None:
                setattr(self, attr, value.to(device))

        # Transfer the variational distribution(s); refresh the cached
        # link_transform on q if it keeps one.
        if hasattr(self, "_q"):
            if hasattr(self._q, "to"):
                self._q.to(device)  # type: ignore[union-attr]
            if hasattr(self._q, "_link_transform"):
                self._q._link_transform = self.link_transform  # type: ignore[union-attr]
        if self._amortized_q is not None:
            self._amortized_q.to(device)

        return self

    def _build_unconditional_flow(
        self,
        flow_type: str,
        num_transforms: Optional[int] = None,
        hidden_features: Optional[int] = None,
        z_score_theta: Optional[Literal["none", "independent", "structured"]] = None,
    ) -> TransformedZukoFlow:
        """Construct an unconditional Zuko flow for single-x variational inference.

        The raw flow operates in unconstrained space; it is wrapped in a
        ``TransformedZukoFlow`` using ``self.link_transform`` so that samples
        and log-probs are expressed in the constrained (prior-support) space,
        with the transformation Jacobian accounted for.

        Args:
            flow_type: Type of flow, one of ["maf", "nsf", "naf", "unaf",
                "nice", "sospf", "gf"]. For "gaussian"/"gaussian_diag", use
                LearnableGaussian.
            num_transforms: Number of flow transforms. If None, uses instance
                default.
            hidden_features: Number of hidden features per layer. If None,
                uses instance default.
            z_score_theta: Method for z-scoring theta. If None, uses instance
                default. Use "structured" for correlated parameters.

        Returns:
            TransformedZukoFlow: The constructed flow wrapped with
            link_transform, moved to ``self._device``.

        Raises:
            ValueError: If flow_type is not supported.
        """
        # Resolve None arguments against the instance-level defaults.
        num_transforms = (
            self._num_transforms if num_transforms is None else num_transforms
        )
        hidden_features = (
            self._hidden_features if hidden_features is None else hidden_features
        )
        z_score_theta = (
            self._z_score_theta if z_score_theta is None else z_score_theta
        )

        if flow_type not in _ZUKO_FLOW_TYPES:
            raise ValueError(
                f"Unknown flow type '{flow_type}'. "
                f"Supported types: {sorted(_ZUKO_FLOW_TYPES)} + "
                f"['gaussian', 'gaussian_diag']."
            )

        event_shape = self._prior.event_shape
        prior_dim = event_shape[0] if event_shape else 1

        # 1D autoregressive flows are known to be fragile for VI; GF is
        # mixture-based and therefore exempt from the warning.
        if prior_dim == 1 and flow_type != "gf":
            warnings.warn(
                f"Using {flow_type.upper()} flow for 1D parameter space. "
                f"Autoregressive normalizing flows may be unstable for 1D VI "
                f"optimization. Consider using q='gaussian' or q='gf' for 1D.",
                UserWarning,
                stacklevel=3,
            )

        # Draw prior samples and map them into unconstrained space
        # (link_transform.forward maps unconstrained -> constrained, so we
        # apply its inverse). The batch drives dimensionality inference and
        # z-scoring statistics.
        with torch.no_grad():
            unconstrained = self.link_transform.inv(self._prior.sample((1000,)))
            assert isinstance(unconstrained, Tensor)  # type narrowing

        raw_flow = build_zuko_unconditional_flow(
            which_nf=flow_type.upper(),
            batch_x=unconstrained,
            z_score_x=z_score_theta,  # theta z-scoring goes into Zuko's x slot
            hidden_features=hidden_features,
            num_transforms=num_transforms,
        )

        # Wrap so that sampling/log_prob happen in constrained space,
        # matching the prior's support.
        wrapped = TransformedZukoFlow(
            flow=raw_flow.to(self._device),
            link_transform=self.link_transform,
        )
        return wrapped.to(self._device)

    def _build_conditional_flow(
        self,
        theta: Tensor,
        x: Tensor,
        flow_type: Union[ZukoFlowType, str] = ZukoFlowType.NSF,
        num_transforms: int = 2,
        hidden_features: int = 32,
        z_score_theta: Literal["none", "independent", "structured"] = "independent",
        z_score_x: Literal["none", "independent", "structured"] = "independent",
    ) -> ConditionalDensityEstimator:
        """Construct a conditional Zuko flow q(θ|x) for amortized VI.

        Args:
            theta: Batch of θ values used for z-scoring, shape
                (batch_size, θ_dim).
            x: Batch of x values used for z-scoring, shape (batch_size, x_dim).
            flow_type: Flow family, given as a ZukoFlowType enum or its
                string name.
            num_transforms: Number of flow transforms.
            hidden_features: Number of hidden features per layer.
            z_score_theta: z-scoring mode for θ (the modeled variable).
            z_score_x: z-scoring mode for x (the conditioning variable);
                use "structured" for structured data like images.

        Returns:
            ConditionalDensityEstimator: The conditional flow q(θ|x), moved to
            ``self._device``.

        Raises:
            ValueError: If flow_type is not a known flow family.
        """
        # Normalize a string spec into the ZukoFlowType enum.
        if isinstance(flow_type, str):
            try:
                flow_type = ZukoFlowType[flow_type.upper()]
            except KeyError as e:
                raise ValueError(
                    f"Unknown flow type '{flow_type}'. "
                    f"Supported types: {[t.name for t in ZukoFlowType]}."
                ) from e

        # Zuko's builder names the modeled variable "x" and the condition "y",
        # hence the apparent swap of theta/x and the z-score arguments below.
        flow = build_zuko_flow(
            flow_type.value.upper(),
            batch_x=theta,
            batch_y=x,
            z_score_x=z_score_theta,
            z_score_y=z_score_x,
            num_transforms=num_transforms,
            hidden_features=hidden_features,
        )
        return flow.to(self._device)

    @property
    def q(
        self,
    ) -> Union[
        Distribution, ZukoUnconditionalFlow, TransformedZukoFlow, LearnableGaussian
    ]:
        """Returns the variational posterior.

        This is the distribution object currently used to approximate the
        posterior; assign via the setter or :meth:`set_q`.
        """
        return self._q

    @q.setter
    def q(self, q: QType) -> None:
        """Sets the variational distribution.

        Delegates to :meth:`set_q` with default parameters/modules. If the
        distribution does not admit access through `parameters` and `modules`
        function, please use `set_q` to explicitly specify the parameters and
        modules.
        """
        self.set_q(q)

    def set_q(
        self,
        q: QType,
        parameters: Optional[Iterable] = None,
        modules: Optional[Iterable] = None,
    ) -> None:
        """Defines the variational family.

        You can specify over which parameters/modules we optimize. This is required for
        custom distributions which e.g. do not inherit nn.Modules or has the function
        `parameters` or `modules` to give direct access to trainable parameters.
        Further, you can pass a function, which constructs a variational distribution
        if called.

        Args:
            q: Variational distribution, either string, distribution, or a VIPosterior
                object. This specifies a parametric class of distribution over which
                the best possible posterior approximation is searched. For string input,
                we support normalizing flows [maf, nsf, naf, unaf, nice, sospf] via
                Zuko, and simple Gaussian families [gaussian, gaussian_diag] via pure
                PyTorch. You can also specify your own variational family by passing a
                `parameterized` distribution object i.e. a torch.distributions
                Distribution with methods `parameters` returning an iterable of all
                parameters (you can pass them within the parameters/modules attribute).
                Additionally, we allow a `Callable` with signature
                `(event_shape: torch.Size, link_transform: TorchTransform, device: str)
                -> Distribution`, which builds a custom distribution. If q is already
                a `VIPosterior`, then the arguments will be copied from it (relevant
                for multi-round training).

                Note: For 1D parameter spaces, autoregressive normalizing flows
                may be unstable. Consider using `q='gaussian'` or `q='gf'` for 1D.
            parameters: List of parameters associated with the distribution object.
            modules: List of modules associated with the distribution object.

        """
        if parameters is None:
            parameters = []
        if modules is None:
            modules = []
        # Remember the raw arguments so q can be rebuilt (e.g. on deepcopy).
        self._q_arg = (q, parameters, modules)
        _flow_types = (ZukoUnconditionalFlow, TransformedZukoFlow, LearnableGaussian)
        if isinstance(q, _flow_types):
            # Flow/Gaussian passed directly (e.g., from _q_build_fn during retrain)
            make_object_deepcopy_compatible(q)
            self._trained_on = None
        elif isinstance(q, Distribution):
            # User-provided torch Distribution: adapt it to the prior support.
            q = adapt_variational_distribution(
                q,
                self._prior,
                self.link_transform,
                parameters=parameters,
                modules=modules,
            )
            make_object_deepcopy_compatible(q)
            # Cache a deepcopy so retraining always restarts from this state.
            self_custom_q_init_cache = deepcopy(q)
            self._q_build_fn = lambda *args, **kwargs: self_custom_q_init_cache
            self._trained_on = None
            self._zuko_flow_type = None
        elif isinstance(q, (str, Callable)):
            if isinstance(q, str):
                if q in _ZUKO_FLOW_TYPES:
                    # Named Zuko flow family; default-arg trick (ft=q) binds the
                    # current value so the rebuild lambda is not late-bound.
                    q_flow = self._build_unconditional_flow(q)
                    self._zuko_flow_type = q
                    self._q_build_fn = lambda *args, ft=q, **kwargs: (
                        self._build_unconditional_flow(ft)
                    )
                    q = q_flow
                elif q in ("gaussian", "gaussian_diag"):
                    self._zuko_flow_type = None
                    full_cov = q == "gaussian"
                    dim = self._prior.event_shape[0]
                    q_dist = LearnableGaussian(
                        dim=dim,
                        full_covariance=full_cov,
                        link_transform=self.link_transform,
                        device=self._device,
                    )
                    # Default args (fc, d) again avoid late binding in the
                    # rebuild closure.
                    self._q_build_fn = lambda *args, fc=full_cov, d=dim, **kwargs: (
                        LearnableGaussian(
                            dim=d,
                            full_covariance=fc,
                            link_transform=self.link_transform,
                            device=self._device,
                        )
                    )
                    q = q_dist
                else:
                    supported = sorted(_ZUKO_FLOW_TYPES) + ["gaussian", "gaussian_diag"]
                    raise ValueError(
                        f"Unknown variational family '{q}'. "
                        f"Supported options: {supported}"
                    )
            else:
                # Callable provided - use as-is
                self._zuko_flow_type = None
                self._q_build_fn = q
                q = self._q_build_fn(
                    self._prior.event_shape,
                    self.link_transform,
                    device=self._device,
                )
            make_object_deepcopy_compatible(q)
            self._trained_on = None
        elif isinstance(q, VIPosterior):
            # Copy all training-relevant state from the source posterior
            # (multi-round training) and deepcopy its distribution.
            self._q_build_fn = q._q_build_fn
            self._trained_on = q._trained_on
            self._mode = getattr(q, "_mode", None)  # Copy mode from source
            self._zuko_flow_type = getattr(q, "_zuko_flow_type", None)
            self.vi_method = q.vi_method  # type: ignore
            self._prior = q._prior
            self._x = q._x
            self._q_arg = q._q_arg
            make_object_deepcopy_compatible(q.q)
            q = deepcopy(q.q)
            # Move copied q to self's device (source may be on a different device).
            if hasattr(q, "to"):
                q.to(self._device)  # type: ignore[union-attr]
        # Validate the variational distribution
        if isinstance(q, _flow_types):
            pass  # These are validated during construction
        elif isinstance(q, Distribution):
            check_variational_distribution(q, self._prior)
        else:
            raise ValueError(
                f"Variational distribution must be a Distribution, got {type(q)}. "
                "Please create an issue on github https://github.com/mackelab/sbi/issues"
            )
        self._q = q

    @property
    def vi_method(self) -> str:
        """Variational inference method e.g. one of [rKL, fKL, IW, alpha]."""
        # Backed by `_vi_method`, which is set via `set_vi_method` (also the setter).
        return self._vi_method

    @vi_method.setter
    def vi_method(self, method: str) -> None:
        """Set the variational inference method; see `set_vi_method`."""
        # Delegates so that the optimizer builder stays in sync with the method name.
        self.set_vi_method(method)

    def set_vi_method(self, method: str) -> "VIPosterior":
        """Select the divergence objective used for variational optimization.

        Args:
            method: One of [rKL, fKL, IW, alpha].

        Returns:
            `VIPosterior` for chainable calls.
        """
        # Resolve the optimizer builder first so both attributes update together.
        builder = get_VI_method(method)
        self._vi_method = method
        self._optimizer_builder = builder
        return self

    def sample(
        self,
        sample_shape: Shape = torch.Size(),
        x: Optional[Tensor] = None,
        show_progress_bars: bool = True,
    ) -> Tensor:
        r"""Draw samples from the variational posterior distribution $p(\theta|x)$.

        In single-x mode (after `train()`), samples come from the unconditional
        flow q(θ) fitted to the trained observation. In amortized mode (after
        `train_amortized()`), samples come from the conditional flow q(θ|x) for
        the provided observation(s).

        Args:
            sample_shape: Shape of the sample batch to draw from the posterior.
            x: Conditioning observation. Single-x mode requires it to match the
                trained x_o (or be None to use the default). Amortized mode
                accepts any observation; batched observations have shape
                (batch_size, x_dim).
            show_progress_bars: Ignored for `VIPosterior` since sampling from
                the variational distribution is fast. Kept for API consistency.

        Returns:
            Samples of shape (*sample_shape, θ_dim) for a single observation,
            or (*sample_shape, batch_size, θ_dim) for batched observations in
            amortized mode.

        Raises:
            ValueError: If the requirements of the current mode are not met.
        """
        observation = self._x_else_default_x(x)

        if self._mode != "amortized":
            # Single-x mode: the unconditional flow q(θ) is only valid for the
            # exact observation it was trained on.
            if self._trained_on is None or (observation != self._trained_on).any():
                raise ValueError(
                    f"The variational posterior was not fit on the specified "
                    f"observation {observation}. Please train using posterior.train()."
                )
            draws = self.q.sample(torch.Size(sample_shape))
            return draws.reshape((*sample_shape, draws.shape[-1]))

        # Amortized mode: sample from the conditional flow q(θ|x).
        if observation is None:
            raise ValueError(
                "x is required for amortized mode. Provide an observation or "
                "set a default x with set_default_x()."
            )
        observation = atleast_2d_float32_tensor(observation).to(self._device)
        assert self._amortized_q is not None
        # Flow output shape: (*sample_shape, batch_size, θ_dim).
        draws = self._amortized_q.sample(torch.Size(sample_shape), condition=observation)
        # Mirror the base posterior behavior: drop a singleton x-batch dimension.
        if observation.shape[0] == 1:
            draws = draws.squeeze(-2)
        return draws

    def sample_batched(
        self,
        sample_shape: Shape,
        x: Tensor,
        max_sampling_batch_size: int = 10000,
        show_progress_bars: bool = True,
    ) -> Tensor:
        """Sample from the posterior for a batch of observations.

        Amortized mode handles this efficiently: all observations are pushed
        through the conditional flow in parallel. Single-x mode cannot support
        it, because its unconditional flow was fit to one specific x_o.

        Args:
            sample_shape: Number of samples to draw per observation.
            x: Batch of observations (num_obs, x_dim).
            max_sampling_batch_size: Ignored in amortized mode (no chunking).
            show_progress_bars: Ignored in amortized mode.

        Returns:
            Samples of shape (*sample_shape, num_obs, θ_dim).

        Raises:
            NotImplementedError: If called in single-x mode.
        """
        if self._mode != "amortized":
            raise NotImplementedError(
                "Batched sampling is not implemented for single-x VI mode. "
                "Use train_amortized() to train an amortized posterior, or "
                "call sample() in a loop: [posterior.sample(shape, x_o) for x_o in x]."
            )
        # The conditional flow already accepts batched x, so defer to sample().
        return self.sample(sample_shape, x=x, show_progress_bars=show_progress_bars)

    def log_prob(
        self,
        theta: Tensor,
        x: Optional[Tensor] = None,
        track_gradients: bool = False,
    ) -> Tensor:
        r"""Evaluate the log-probability of `theta` under the variational posterior.

        Single-x mode evaluates log q(θ); amortized mode evaluates log q(θ|x).

        Args:
            theta: Parameters to evaluate, shape (batch_theta, θ_dim).
            x: Observation. In single-x mode it must match the trained x_o (or
                be None). In amortized mode it is required and may be any
                observation: shape (1, x_dim) or (x_dim,) for a single x,
                (batch_x, x_dim) for a batch.
            track_gradients: If True, the returned tensor supports gradient
                tracking (useful e.g. for sensitivity analysis, at the cost of
                memory).

        Returns:
            Log-probabilities of shape (batch,), where batch is batch_theta if
            x has batch size 1 (x broadcast), batch_x if theta has batch size 1
            (theta broadcast), or the common size when batch_theta == batch_x
            (paired evaluation).

        Raises:
            ValueError: If mode requirements are not met or batch sizes are
                incompatible.
        """
        with torch.set_grad_enabled(track_gradients):
            theta = ensure_theta_batched(torch.as_tensor(theta)).to(self._device)
            x = self._x_else_default_x(x)

            if self._mode != "amortized":
                # Single-x mode: q(θ) is only meaningful for the trained x_o.
                if self._trained_on is None or (x != self._trained_on).any():
                    raise ValueError(
                        f"The variational posterior was not fit on the specified "
                        f"observation {x}. Please train using posterior.train()."
                    )
                return self.q.log_prob(theta)

            # Amortized mode: evaluate log q(θ|x).
            if x is None:
                raise ValueError(
                    "x is required for amortized mode. Provide an observation or "
                    "set a default x with set_default_x()."
                )
            x = atleast_2d_float32_tensor(x).to(self._device)
            assert self._amortized_q is not None

            # Broadcast theta/x against each other when one has batch size 1.
            n_theta, n_x = theta.shape[0], x.shape[0]
            if n_theta != n_x:
                if n_x == 1:
                    x = x.expand(n_theta, -1)
                elif n_theta == 1:
                    theta = theta.expand(n_x, -1)
                else:
                    raise ValueError(
                        f"Batch sizes of theta ({n_theta}) and x ({n_x}) "
                        f"are incompatible. They must be equal, or one must be 1."
                    )

            # The flow expects (sample_dim, batch_dim, *event_shape): prepend a
            # singleton sample dimension and strip it from the result.
            log_probs = self._amortized_q.log_prob(theta.unsqueeze(0), condition=x)
            return log_probs.squeeze(0)

    def train(
        self,
        x: Optional[TorchTensor] = None,
        n_particles: int = 256,
        learning_rate: float = 1e-3,
        gamma: float = 0.999,
        max_num_iters: int = 2000,
        min_num_iters: int = 10,
        clip_value: float = 10.0,
        warm_up_rounds: int = 100,
        retrain_from_scratch: bool = False,
        reset_optimizer: bool = False,
        show_progress_bar: bool = True,
        check_for_convergence: bool = True,
        quality_control: bool = True,
        quality_control_metric: str = "psis",
        **kwargs,
    ) -> "VIPosterior":
        """This method trains the variational posterior for a single observation.

        Args:
            x: The observation, optional, defaults to self._x.
            n_particles: Number of samples to approximate expectations within the
                variational bounds. The larger the more accurate are gradient
                estimates, but the computational cost per iteration increases.
            learning_rate: Learning rate of the optimizer.
            gamma: Learning rate decay per iteration. We use an exponential decay
                scheduler.
            max_num_iters: Maximum number of iterations.
            min_num_iters: Minimum number of iterations.
            clip_value: Gradient clipping value, decreasing may help if you see invalid
                values.
            warm_up_rounds: Initialize the posterior as the prior.
            retrain_from_scratch: Retrain the variational distributions from scratch.
            reset_optimizer: Reset the divergence optimizer
            show_progress_bar: If any progress report should be displayed.
            quality_control: If False quality control is skipped.
            quality_control_metric: Which metric to use for evaluating the quality.
            kwargs: Hyperparameters check corresponding `DivergenceOptimizer` for detail
                eps: Determines sensitivity of convergence check.
                retain_graph: Boolean which decides whether to retain the computation
                    graph. This may be required for some `exotic` user-specified q's.
                optimizer: A PyTorch Optimizer class e.g. Adam or SGD. See
                    `DivergenceOptimizer` for details.
                scheduler: A PyTorch learning rate scheduler. See
                    `DivergenceOptimizer` for details.
                alpha: Only used if vi_method=`alpha`. Determines the alpha divergence.
                K: Only used if vi_method=`IW`. Determines the number of importance
                    weighted particles.
                stick_the_landing: If one should use the STL estimator (only for rKL,
                    IW, alpha).
                dreg: If one should use the DREG estimator (only for rKL, IW, alpha).
                weight_transform: Callable applied to importance weights (only for fKL)
        Returns:
            VIPosterior: `VIPosterior` (can be used to chain calls).

        Raises:
            ValueError: If hyperparameters are invalid.
        """
        # Validate hyperparameters
        if n_particles <= 0:
            raise ValueError(f"n_particles must be positive, got {n_particles}")
        if learning_rate <= 0:
            raise ValueError(f"learning_rate must be positive, got {learning_rate}")
        if not 0 < gamma <= 1:
            raise ValueError(f"gamma must be in (0, 1], got {gamma}")
        if max_num_iters <= 0:
            raise ValueError(f"max_num_iters must be positive, got {max_num_iters}")
        if min_num_iters < 0:
            raise ValueError(f"min_num_iters must be non-negative, got {min_num_iters}")
        if clip_value <= 0:
            raise ValueError(f"clip_value must be positive, got {clip_value}")

        # Update optimizer with current arguments.
        # NOTE: the full locals() dict is forwarded, so the optimizer picks up
        # hyperparameters *by local variable name* (learning_rate, gamma, ...).
        if self._optimizer is not None:
            self._optimizer.update({**locals(), **kwargs})

        # Init q and the optimizer if necessary
        if retrain_from_scratch:
            # NOTE(review): `_q_build_fn` is invoked elsewhere with
            # (event_shape, link_transform, device=...); calling it with no
            # arguments here may fail for user-supplied callables — confirm.
            self.q = self._q_build_fn()  # type: ignore
            self._optimizer = self._optimizer_builder(
                self.potential_fn,
                self.q,
                lr=learning_rate,
                clip_value=clip_value,
                gamma=gamma,
                n_particles=n_particles,
                prior=self._prior,
                **kwargs,
            )

        # Rebuild the optimizer when it is missing, explicitly reset, or its
        # class no longer matches the currently selected VI method.
        if (
            reset_optimizer
            or self._optimizer is None
            or not isinstance(self._optimizer, self._optimizer_builder)
        ):
            self._optimizer = self._optimizer_builder(
                self.potential_fn,
                self.q,
                lr=learning_rate,
                clip_value=clip_value,
                gamma=gamma,
                n_particles=n_particles,
                prior=self._prior,
                **kwargs,
            )

        # Check context
        x = atleast_2d_float32_tensor(self._x_else_default_x(x)).to(  # type: ignore
            self._device
        )
        if not torch.isfinite(x).all():
            raise ValueError("x contains NaN or Inf values.")

        # Skip the warm-up below when q was already fit to this exact x.
        already_trained = self._trained_on is not None and (x == self._trained_on).all()

        # Optimize
        optimizer = self._optimizer
        optimizer.to(self._device)
        optimizer.reset_loss_stats()

        if show_progress_bar:
            iters = tqdm(range(max_num_iters))
        else:
            iters = range(max_num_iters)

        # Warmup before training: initializes q near the prior for stability.
        if reset_optimizer or (not optimizer.warm_up_was_done and not already_trained):
            if show_progress_bar:
                iters.set_description(  # type: ignore
                    "Warmup phase, this may take a few seconds..."
                )
            optimizer.warm_up(warm_up_rounds)

        for i in iters:
            optimizer.step(x)
            mean_loss, std_loss = optimizer.get_loss_stats()
            # Update progress bar
            if show_progress_bar:
                assert isinstance(iters, tqdm)
                iters.set_description(  # type: ignore
                    f"Loss: {np.round(float(mean_loss), 2)}, "
                    f"Std: {np.round(float(std_loss), 2)}"
                )
            # Check for convergence
            if check_for_convergence and i > min_num_iters and optimizer.converged():
                if show_progress_bar:
                    print(f"\nConverged with loss: {np.round(float(mean_loss), 2)}")
                break
        # Training finished: record the observation q was fit to and make
        # single-x mode authoritative (the amortized flow, if any, is dropped).
        self._trained_on = x
        if self._mode == "amortized":
            warnings.warn(
                "Switching from amortized to single-x mode. "
                "The previously trained amortized model will be discarded.",
                UserWarning,
                stacklevel=2,
            )
            self._amortized_q = None
        self._mode = "single_x"

        # Evaluate quality
        if quality_control:
            try:
                self.evaluate(quality_control_metric=quality_control_metric)
            except Exception as e:
                print(
                    f"Quality control showed a low quality of the variational "
                    f"posterior. We are automatically retraining the variational "
                    f"posterior from scratch with a smaller learning rate. "
                    f"Alternatively, if you want to skip quality control, please "
                    f"retrain with `VIPosterior.train(..., quality_control=False)`. "
                    f"\nThe error that occured is: {e}"
                )
                # NOTE(review): this recursive retrain keeps quality_control=True
                # (the default), so repeated failures recurse again with an
                # ever-smaller learning rate — confirm this is intended.
                self.train(
                    learning_rate=learning_rate * 0.1,
                    retrain_from_scratch=True,
                    reset_optimizer=True,
                )

        return self

    def train_amortized(
        self,
        theta: Tensor,
        x: Tensor,
        n_particles: int = 128,
        learning_rate: float = 1e-3,
        gamma: float = 0.999,
        max_num_iters: int = 500,
        clip_value: float = 5.0,
        batch_size: int = 64,
        validation_fraction: float = 0.1,
        validation_batch_size: Optional[int] = None,
        validation_n_particles: Optional[int] = None,
        stop_after_iters: int = 20,
        show_progress_bar: bool = True,
        retrain_from_scratch: bool = False,
        flow_type: Optional[Union[ZukoFlowType, str]] = None,
        num_transforms: Optional[int] = None,
        hidden_features: Optional[int] = None,
        z_score_theta: Optional[Literal["none", "independent", "structured"]] = None,
        z_score_x: Optional[Literal["none", "independent", "structured"]] = None,
        params: Optional["VIPosteriorParameters"] = None,
    ) -> "VIPosterior":
        """Train a conditional flow q(θ|x) for amortized variational inference.

        This allows sampling from q(θ|x) for any observation x without retraining.
        Uses the ELBO (Evidence Lower Bound) objective with early stopping based on
        validation loss.

        Args:
            theta: Training θ values from simulations (num_sims, θ_dim).
            x: Training x values from simulations (num_sims, x_dim).
            n_particles: Number of samples to estimate ELBO per x.
            learning_rate: Learning rate for Adam optimizer.
            gamma: Learning rate decay per iteration.
            max_num_iters: Maximum training iterations.
            clip_value: Gradient clipping threshold.
            batch_size: Number of x values per training batch.
            validation_fraction: Fraction of data to use for validation.
            validation_batch_size: Batch size for validation loss. Defaults to
                `batch_size`.
            validation_n_particles: Number of particles for validation loss.
                Defaults to `n_particles`.
            stop_after_iters: Stop training after this many iterations without
                improvement in validation loss.
            show_progress_bar: Whether to show progress.
            retrain_from_scratch: If True, rebuild the flow from scratch.
            flow_type: Flow architecture for the variational distribution.
                Use ZukoFlowType.NSF, ZukoFlowType.MAF, etc., or a string.
                If None, uses value from params or instance default.
            num_transforms: Number of transforms in the flow. If None, uses value
                from params or instance default.
            hidden_features: Hidden layer size in the flow. If None, uses value
                from params or instance default.
            z_score_theta: Method for z-scoring θ (the parameters being modeled).
                One of "none", "independent", "structured". If None, uses value
                from params or instance default.
            z_score_x: Method for z-scoring x (the conditioning variable).
                One of "none", "independent", "structured". Use "structured" for
                structured data like images with spatial correlations. If None,
                uses value from params or instance default.
            params: Optional VIPosteriorParameters dataclass. Values are used as
                fallbacks when explicit arguments are None. Priority order:
                explicit args > params > instance attributes (from __init__).

        Returns:
            self for method chaining.
        """
        # Resolve parameters: explicit args > params dataclass > instance attrs
        if params is not None:
            # Amortized VI only supports string flow types (not VIPosterior or Callable)
            if not isinstance(params.q, str):
                raise ValueError(
                    "train_amortized() only supports string flow types "
                    f"(e.g., 'nsf', 'maf'), not {type(params.q).__name__}. "
                    "Use set_q() to pass custom distributions for single-x VI."
                )
            if flow_type is None:
                flow_type = params.q
            if num_transforms is None:
                num_transforms = params.num_transforms
            if hidden_features is None:
                hidden_features = params.hidden_features
            if z_score_theta is None:
                z_score_theta = params.z_score_theta
            if z_score_x is None:
                z_score_x = params.z_score_x

        # Fall back to instance attributes (set in __init__ from VIPosteriorParameters)
        # NOTE(review): unlike the other parameters below, flow_type falls back
        # to a hardcoded ZukoFlowType.NSF rather than an instance attribute —
        # confirm whether an instance-level default was intended.
        if flow_type is None:
            flow_type = ZukoFlowType.NSF
        if num_transforms is None:
            num_transforms = self._num_transforms
        if hidden_features is None:
            hidden_features = self._hidden_features
        if z_score_theta is None:
            z_score_theta = self._z_score_theta
        if z_score_x is None:
            z_score_x = self._z_score_x

        theta = atleast_2d_float32_tensor(theta).to(self._device)
        x = atleast_2d_float32_tensor(x).to(self._device)

        # Validate inputs
        if theta.shape[0] != x.shape[0]:
            raise ValueError(
                f"Batch size mismatch: theta has {theta.shape[0]} samples, "
                f"x has {x.shape[0]} samples. They must match."
            )
        if len(theta) == 0:
            raise ValueError("Training data cannot be empty.")
        if not torch.isfinite(theta).all():
            raise ValueError("theta contains NaN or Inf values.")
        if not torch.isfinite(x).all():
            raise ValueError("x contains NaN or Inf values.")

        # Validate theta dimension matches prior
        prior_event_shape = self._prior.event_shape
        if len(prior_event_shape) > 0:
            expected_theta_dim = prior_event_shape[0]
            if theta.shape[1] != expected_theta_dim:
                raise ValueError(
                    f"theta dimension {theta.shape[1]} does not match prior "
                    f"event shape {expected_theta_dim}."
                )

        # Validate hyperparameters
        if not 0 < validation_fraction < 1:
            raise ValueError(
                f"validation_fraction must be in (0, 1), got {validation_fraction}"
            )
        if n_particles <= 0:
            raise ValueError(f"n_particles must be positive, got {n_particles}")
        if batch_size <= 0:
            raise ValueError(f"batch_size must be positive, got {batch_size}")

        # Validate flow_type early to fail fast
        if isinstance(flow_type, str):
            try:
                flow_type = ZukoFlowType[flow_type.upper()]
            except KeyError:
                raise ValueError(
                    f"Unknown flow type '{flow_type}'. "
                    f"Supported types: {[t.name for t in ZukoFlowType]}."
                ) from None

        # Validation settings default to their training counterparts.
        if validation_batch_size is None:
            validation_batch_size = batch_size
        if validation_n_particles is None:
            validation_n_particles = n_particles

        if validation_batch_size <= 0:
            raise ValueError(
                f"validation_batch_size must be positive, got {validation_batch_size}"
            )
        if validation_n_particles <= 0:
            raise ValueError(
                f"validation_n_particles must be positive, got {validation_n_particles}"
            )

        # Split into training and validation sets
        num_examples = len(theta)
        num_val = int(validation_fraction * num_examples)
        num_train = num_examples - num_val

        if num_val == 0:
            raise ValueError(
                "Validation set is empty. Increase validation_fraction or provide more "
                "training data."
            )
        if num_train < batch_size:
            raise ValueError(
                f"Training set size ({num_train}) is smaller than batch_size "
                f"({batch_size}). Reduce validation_fraction or batch_size."
            )

        # Random split (on-device) into train / validation indices.
        permuted_indices = torch.randperm(num_examples, device=self._device)
        train_indices = permuted_indices[:num_train]
        val_indices = permuted_indices[num_train:]

        theta_train, x_train = theta[train_indices], x[train_indices]
        x_val = x[val_indices]  # Only x needed for validation (θ sampled from q)

        # Subsample the validation set per iteration only when it is larger
        # than validation_batch_size.
        use_val_subset = validation_batch_size < x_val.shape[0]

        # Build or rebuild the conditional flow (z-score on training data only)
        if self._amortized_q is None or retrain_from_scratch:
            self._amortized_q = self._build_conditional_flow(
                theta_train,
                x_train,
                flow_type=flow_type,
                num_transforms=num_transforms,
                hidden_features=hidden_features,
                z_score_theta=z_score_theta,
                z_score_x=z_score_x,
            )

        # Ensure potential_fn is on the correct device for amortized training
        self.potential_fn.to(self._device)

        # Setup optimizer
        optimizer = Adam(self._amortized_q.parameters(), lr=learning_rate)
        scheduler = ExponentialLR(optimizer, gamma=gamma)

        # Training loop with validation-based early stopping
        best_val_loss = float("inf")
        iters_since_improvement = 0
        # Snapshot of the best weights seen so far, restored after training.
        best_state_dict = deepcopy(self._amortized_q.state_dict())

        if show_progress_bar:
            iters = tqdm(range(max_num_iters), desc="Amortized VI (ELBO)")
        else:
            iters = range(max_num_iters)

        for iteration in iters:
            # Training step
            self._amortized_q.train()
            optimizer.zero_grad()

            # Sample batch from training set
            idx = torch.randint(0, num_train, (batch_size,), device=self._device)
            x_batch = x_train[idx]

            train_loss = self._compute_amortized_elbo_loss(x_batch, n_particles)

            # Fail loudly on NaN/Inf losses rather than silently diverging.
            if not torch.isfinite(train_loss):
                raise RuntimeError(
                    f"Training loss became non-finite at iteration {iteration}: "
                    f"{train_loss.item()}. This indicates numerical instability. Try:\n"
                    f"  - Reducing learning_rate (currently {learning_rate})\n"
                    f"  - Reducing n_particles (currently {n_particles})\n"
                    f"  - Checking your potential_fn for numerical issues"
                )

            train_loss.backward()
            nn.utils.clip_grad_norm_(self._amortized_q.parameters(), clip_value)
            optimizer.step()
            scheduler.step()

            # Compute validation loss
            self._amortized_q.eval()
            with torch.no_grad():
                if use_val_subset:
                    val_idx = torch.randperm(x_val.shape[0], device=self._device)[
                        :validation_batch_size
                    ]
                    x_val_batch = x_val[val_idx]
                else:
                    x_val_batch = x_val
                val_loss = self._compute_amortized_elbo_loss(
                    x_val_batch, validation_n_particles
                ).item()

            # Check for improvement
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                iters_since_improvement = 0
                best_state_dict = deepcopy(self._amortized_q.state_dict())
            else:
                iters_since_improvement += 1

            if show_progress_bar:
                assert isinstance(iters, tqdm)
                iters.set_postfix({
                    "train": f"{train_loss.item():.3f}",
                    "val": f"{val_loss:.3f}",
                })

            # Early stopping
            if iters_since_improvement >= stop_after_iters:
                if show_progress_bar:
                    print(f"\nConverged at iteration {iteration}")
                break

        # Restore best model
        self._amortized_q.load_state_dict(best_state_dict)
        self._amortized_q.eval()
        # Switch mode; the single-x q (if any) is kept but no longer used.
        if self._mode == "single_x":
            warnings.warn(
                "Switching from single-x to amortized mode. "
                "The previously trained single-x model will not be usable.",
                UserWarning,
                stacklevel=2,
            )
        self._mode = "amortized"

        return self

    def _compute_amortized_elbo_loss(self, x_batch: Tensor, n_particles: int) -> Tensor:
        """Compute negative ELBO loss for a batch of x values.

        Args:
            x_batch: Batch of observations (batch_size, x_dim).
            n_particles: Number of θ samples per x.

        Returns:
            Negative ELBO (scalar tensor).
        """
        assert self._amortized_q is not None, "q must be built before computing ELBO"
        batch_size = x_batch.shape[0]

        # Reparameterized samples from q(θ|x) with their log probabilities
        # theta_samples shape: (n_particles, batch_size, θ_dim)
        # log_q shape: (n_particles, batch_size)
        theta_samples, log_q = self._amortized_q.sample_and_log_prob(
            torch.Size((n_particles,)), condition=x_batch
        )

        # Vectorized evaluation of potential log p(θ|x) for all (θ, x) pairs
        # Flatten: (n_particles, batch_size, θ_dim) -> (n_particles * batch_size, θ_dim)
        theta_dim = theta_samples.shape[-1]
        theta_flat = theta_samples.reshape(n_particles * batch_size, theta_dim)

        # Tile x to match: (batch_size, x_dim) -> (n_particles * batch_size, x_dim)
        # Each block of batch_size rows corresponds to one particle.
        x_expanded = x_batch.repeat(n_particles, 1)

        # Set x_o for batched evaluation (x_is_iid=False: each θ paired with its x)
        self.potential_fn.set_x(x_expanded, x_is_iid=False)
        log_potential_flat = self.potential_fn(theta_flat)

        # Reshape: (n_particles * batch_size,) -> (n_particles, batch_size)
        log_potential = log_potential_flat.reshape(n_particles, batch_size)

        # ELBO = E_q[log p(θ|x) - log q(θ|x)]
        elbo = (log_potential - log_q).mean()
        return -elbo

    def evaluate(self, quality_control_metric: str = "psis", N: int = int(5e4)) -> None:
        """Assess and print the quality of the variational posterior.

        Supported metrics: `psis` inspects the tails of importance weights
        (heavy tails indicate a poor fit), while `prop` checks whether q is
        proportional to the potential function.

        NOTE: In our experience `prop` is sensitive to distinguish ``good``
        from ``ok``, whereas `psis` is more sensitive in distinguishing
        `very bad` from `ok`.

        Args:
            quality_control_metric: The metric of choice; currently one of
                [psis, prop, prop_prior].
            N: Number of samples used to evaluate the metric.
        """
        metric_fn, message = get_quality_metric(quality_control_metric)
        score = round(float(metric_fn(self, N=N)), 3)
        print(f"Quality Score: {score} " + message)

    def map(
        self,
        x: Optional[TorchTensor] = None,
        num_iter: int = 1_000,
        num_to_optimize: int = 100,
        learning_rate: float = 0.01,
        init_method: Union[str, TorchTensor] = "proposal",
        num_init_samples: int = 10_000,
        save_best_every: int = 10,
        show_progress_bars: bool = False,
        force_update: bool = False,
    ) -> Tensor:
        r"""Returns the maximum-a-posteriori estimate (MAP).

        The method can be interrupted (Ctrl-C) when the user sees that the
        log-probability converges. The best estimate will be saved in `self._map` and
        can be accessed with `self.map()`. The MAP is obtained by running gradient
        ascent from a given number of starting positions (samples from the posterior
        with the highest log-probability). After the optimization is done, we select the
        parameter set that has the highest log-probability after the optimization.

        Warning: The default values used by this function are not well-tested. They
        might require hand-tuning for the problem at hand.

        For developers: if the prior is a `BoxUniform`, we carry out the optimization
        in unbounded space and transform the result back into bounded space.

        Args:
            x: Deprecated - use `.set_default_x()` prior to `.map()`.
            num_iter: Number of optimization steps that the algorithm takes
                to find the MAP.
            learning_rate: Learning rate of the optimizer.
            init_method: How to select the starting parameters for the optimization.
                If it is a string, it can be one of [`proposal`, `posterior`,
                `prior`], which samples the respective distribution
                `num_init_samples` times (for VI, `proposal` draws from the
                variational distribution `q`). If it is a tensor, the tensor will
                be used as init locations.
            num_init_samples: Draw this number of samples from the posterior and
                evaluate the log-probability of all of them.
            num_to_optimize: From the drawn `num_init_samples`, use the
                `num_to_optimize` with highest log-probability as the initial points
                for the optimization.
            save_best_every: The best log-probability is computed, saved in the
                `map`-attribute, and printed every `save_best_every`-th iteration.
                Computing the best log-probability creates a significant overhead
                (thus, the default is `10`.)
            show_progress_bars: Whether to show a progressbar during sampling from
                the posterior.
            force_update: Whether to re-calculate the MAP when x is unchanged and
                have a cached value.

        Returns:
            The MAP estimate.
        """
        # The base-class MAP optimizer samples initial points from
        # `self.proposal`; for VI this is the fitted variational distribution.
        self.proposal = self.q
        return super().map(
            x=x,
            num_iter=num_iter,
            num_to_optimize=num_to_optimize,
            learning_rate=learning_rate,
            init_method=init_method,
            num_init_samples=num_init_samples,
            save_best_every=save_best_every,
            show_progress_bars=show_progress_bars,
            force_update=force_update,
        )

    def __deepcopy__(self, memo: Optional[Dict] = None) -> "VIPosterior":
        """Support `copy.deepcopy` for this object.

        The default deep-copy machinery goes through `__getstate__` and
        `__setstate__`, which we override to enable pickling; those overrides
        are incompatible with deep copying, hence this custom implementation.

        Args:
            memo (Optional[Dict], optional): Deep copy internal memo. Defaults to None.

        Returns:
            VIPosterior: Deep copy of the VIPosterior.
        """
        memo = {} if memo is None else memo

        # Build a fresh, uninitialized instance of the same class.
        klass = self.__class__
        clone = klass.__new__(klass)
        # Register the clone before copying so self-references resolve to it.
        memo[id(self)] = clone
        # Deep-copy every attribute onto the new instance.
        for attr_name, attr_value in self.__dict__.items():
            setattr(clone, attr_name, copy.deepcopy(attr_value, memo))
        return clone

    def __getstate__(self) -> Dict:
        """Support pickling of this object.

        Some attributes do not support pickle protocols (e.g. due to local
        functions), so they are cleared on `self` before the state is captured.

        Returns:
            Dict: All attributes of the VIPosterior.
        """
        # Neutralize everything that cannot be pickled; these assignments are
        # independent of each other.
        self._q_build_fn = None
        self._optimizer = None
        self.__deepcopy__ = None  # type: ignore
        self._q.__deepcopy__ = None  # type: ignore
        return self.__dict__.copy()

    def __setstate__(self, state_dict: Dict):
        """Restore the object when unpickling.

        Re-creates the attributes that `__getstate__` removed and ensures the
        object remains deep-copy compatible.

        Args:
            state_dict: Given state dictionary, we will restore the object from it.
        """
        self.__dict__ = state_dict
        # Keep the trained variational distribution; `set_q` below would
        # otherwise rebuild it from scratch.
        preserved_q = deepcopy(self._q)
        self.set_q(*self._q_arg)
        self._q = preserved_q
        # Re-attach the deep-copy shims stripped for pickling.
        make_object_deepcopy_compatible(self)
        make_object_deepcopy_compatible(self.q)
        # The amortized conditional estimator needs the same treatment.
        if self._mode == "amortized" and self._amortized_q is not None:
            make_object_deepcopy_compatible(self._amortized_q)

q property writable

Returns the variational posterior.

vi_method property writable

Variational inference method e.g. one of [rKL, fKL, IW, alpha].

__deepcopy__(memo=None)

This method is called when using copy.deepcopy on the object.

It defines how the object is copied. We need to overwrite this method, since the default implementation uses `__getstate__` and `__setstate__`, which we overwrite to enable pickling (and in particular the necessary modifications are incompatible with deep copying).

Parameters:

Name Type Description Default
memo Optional[Dict]

Deep copy internal memo. Defaults to None.

None

Returns:

Name Type Description
VIPosterior VIPosterior

Deep copy of the VIPosterior.

Source code in sbi/inference/posteriors/vi_posterior.py
1331
1332
1333
1334
1335
1336
1337
1338
1339
1340
1341
1342
1343
1344
1345
1346
1347
1348
1349
1350
1351
1352
1353
1354
1355
1356
def __deepcopy__(self, memo: Optional[Dict] = None) -> "VIPosterior":
    """This method is called when using `copy.deepcopy` on the object.

    It defines how the object is copied. We need to overwrite this method, since the
    default implementation does use __getstate__ and __setstate__ which we overwrite
    to enable pickling (and in particular the necessary modifications are
    incompatible deep copying).

    Args:
        memo (Optional[Dict], optional): Deep copy internal memo. Defaults to None.

    Returns:
        VIPosterior: Deep copy of the VIPosterior.
    """
    if memo is None:
        memo = {}

    # Create a new instance of the class
    cls = self.__class__
    result = cls.__new__(cls)
    # Add to memo
    memo[id(self)] = result
    # Copy attributes
    for k, v in self.__dict__.items():
        setattr(result, k, copy.deepcopy(v, memo))
    return result

__getstate__()

This method is called when pickling the object.

It defines what is pickled. We need to overwrite this method, since some parts do not support pickle protocols (e.g. due to local functions).

Returns:

Name Type Description
Dict Dict

All attributes of the VIPosterior.

Source code in sbi/inference/posteriors/vi_posterior.py
1358
1359
1360
1361
1362
1363
1364
1365
1366
1367
1368
1369
1370
1371
1372
def __getstate__(self) -> Dict:
    """This method is called when pickling the object.

    It defines what is pickled. We need to overwrite this method, since some parts
    do not support pickle protocols (e.g. due to local functions).

    Returns:
        Dict: All attributes of the VIPosterior.
    """
    self._optimizer = None
    self.__deepcopy__ = None  # type: ignore
    self._q_build_fn = None
    self._q.__deepcopy__ = None  # type: ignore
    state = self.__dict__.copy()
    return state

__init__(potential_fn, prior=None, q='maf', theta_transform=None, vi_method='rKL', device='cpu', x_shape=None, parameters=None, modules=None, num_transforms=5, hidden_features=50, z_score_theta='independent', z_score_x='independent')

Parameters:

Name Type Description Default
potential_fn Union[BasePotential, CustomPotential]

The potential function from which to draw samples. Must be a BasePotential or a CustomPotential.

required
prior Optional[TorchDistribution]

This is the prior distribution. Note that this is only used to check/construct the variational distribution or within some quality metrics. Please make sure that this matches with the prior within the potential_fn. If None is given, we will try to infer it from potential_fn or q, if this fails we raise an Error.

None
q QType

Variational distribution, either string, Distribution, or a VIPosterior object. This specifies a parametric class of distribution over which the best possible posterior approximation is searched. For string input, we support normalizing flows [maf, nsf, naf, unaf, nice, sospf, gf] via Zuko, and Gaussian families [gaussian, gaussian_diag]. Note: For 1D problems, prefer “gf” (mixture of Gaussians) or “gaussian” as autoregressive flows may be unstable. You can also specify your own variational family by passing a torch.distributions.Distribution. Additionally, we allow a Callable with signature (event_shape: torch.Size, link_transform: TorchTransform, device: str) -> Distribution for custom flow configurations. The callable should return a distribution with sample() and log_prob() methods. If q is already a VIPosterior, then the arguments will be copied from it (relevant for multi-round training).

'maf'
theta_transform Optional[TorchTransform]

Maps from prior support to unconstrained space. The inverse is used here to ensure that the posterior support is equal to that of the prior.

None
vi_method Literal['rKL', 'fKL', 'IW', 'alpha']

This specifies the variational methods which are used to fit q to the posterior. We currently support [rKL, fKL, IW, alpha]. Note that some of the divergences are mode seeking i.e. they underestimate variance and collapse on multimodal targets (rKL, alpha for alpha > 1) and some are mass covering i.e. they overestimate variance but typically cover all modes (fKL, IW, alpha for alpha < 1).

'rKL'
device Union[str, device]

Training device, e.g., cpu, cuda or cuda:0. We will ensure that all other objects are also on this device.

'cpu'
x_shape Optional[Size]

Deprecated, should not be passed.

None
parameters Optional[Iterable]

List of parameters of the variational posterior. This is only required for user-defined q i.e. if q does not have a parameters attribute.

None
modules Optional[Iterable]

List of modules of the variational posterior. This is only required for user-defined q i.e. if q does not have a modules attribute.

None
num_transforms int

Number of transforms in the normalizing flow. Used for both single-x VI (when q is a string flow type) and amortized VI.

5
hidden_features int

Hidden layer size in flow networks. Used for both single-x VI and amortized VI.

50
z_score_theta Literal['none', 'independent', 'structured']

Method for z-scoring θ (parameters). One of “none”, “independent”, “structured”. Used for both single-x VI and amortized VI. Use “structured” for parameters with correlations.

'independent'
z_score_x Literal['none', 'independent', 'structured']

Method for z-scoring x (conditioning observation). One of “none”, “independent”, “structured”. Only used for amortized VI (train_amortized). Use “structured” for structured data like images.

'independent'
Source code in sbi/inference/posteriors/vi_posterior.py
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
def __init__(
    self,
    potential_fn: Union[BasePotential, CustomPotential],
    prior: Optional[TorchDistribution] = None,  # type: ignore
    q: QType = "maf",
    theta_transform: Optional[TorchTransform] = None,
    vi_method: Literal["rKL", "fKL", "IW", "alpha"] = "rKL",
    device: Union[str, torch.device] = "cpu",
    x_shape: Optional[torch.Size] = None,
    parameters: Optional[Iterable] = None,
    modules: Optional[Iterable] = None,
    num_transforms: int = 5,
    hidden_features: int = 50,
    z_score_theta: Literal["none", "independent", "structured"] = "independent",
    z_score_x: Literal["none", "independent", "structured"] = "independent",
):
    """
    Args:
        potential_fn: The potential function from which to draw samples. Must be a
            `BasePotential` or a `CustomPotential`.
        prior: This is the prior distribution. Note that this is only
            used to check/construct the variational distribution or within some
            quality metrics. Please make sure that this matches with the prior
            within the potential_fn. If `None` is given, we will try to infer it
            from potential_fn or q, if this fails we raise an Error.
        q: Variational distribution, either string, `Distribution`, or a
            `VIPosterior` object. This specifies a parametric class of distribution
            over which the best possible posterior approximation is searched. For
            string input, we support normalizing flows [maf, nsf, naf, unaf, nice,
            sospf, gf] via Zuko, and Gaussian families [gaussian, gaussian_diag].
            Note: For 1D problems, prefer "gf" (mixture of Gaussians) or "gaussian"
            as autoregressive flows may be unstable.
            You can also specify your own variational family by passing a
            `torch.distributions.Distribution`. Additionally, we allow a `Callable`
            with signature `(event_shape: torch.Size, link_transform:
            TorchTransform, device: str) -> Distribution` for custom flow
            configurations. The
            callable should return a distribution with `sample()` and `log_prob()`
            methods. If q is already a `VIPosterior`, then the arguments will be
            copied from it (relevant for multi-round training).
        theta_transform: Maps from prior support to unconstrained space. The
            inverse is used here to ensure that the posterior support is equal to
            that of the prior.
        vi_method: This specifies the variational methods which are used to fit q to
            the posterior. We currently support [rKL, fKL, IW, alpha]. Note that
            some of the divergences are `mode seeking` i.e. they underestimate
            variance and collapse on multimodal targets (`rKL`, `alpha` for alpha >
            1) and some are `mass covering` i.e. they overestimate variance but
            typically cover all modes (`fKL`, `IW`, `alpha` for alpha < 1).
        device: Training device, e.g., `cpu`, `cuda` or `cuda:0`. We will ensure
            that all other objects are also on this device.
        x_shape: Deprecated, should not be passed.
        parameters: List of parameters of the variational posterior. This is only
            required for user-defined q i.e. if q does not have a `parameters`
            attribute.
        modules: List of modules of the variational posterior. This is only
            required for user-defined q i.e. if q does not have a `modules`
            attribute.
        num_transforms: Number of transforms in the normalizing flow. Used for
            both single-x VI (when q is a string flow type) and amortized VI.
        hidden_features: Hidden layer size in flow networks. Used for both
            single-x VI and amortized VI.
        z_score_theta: Method for z-scoring θ (parameters). One of "none",
            "independent", "structured". Used for both single-x VI and amortized
            VI. Use "structured" for parameters with correlations.
        z_score_x: Method for z-scoring x (conditioning observation). One of
            "none", "independent", "structured". Only used for amortized VI
            (train_amortized). Use "structured" for structured data like images.
    """
    super().__init__(potential_fn, theta_transform, device, x_shape=x_shape)

    # Especially the prior may be on another device -> move it...
    self._device = device
    self.theta_transform = theta_transform
    self.x_shape = x_shape
    self.potential_fn.device = device
    self.potential_fn.to(device)

    # Get prior and previous builds
    if prior is not None:
        self._prior = prior
    elif hasattr(self.potential_fn, "prior") and isinstance(
        self.potential_fn.prior, Distribution
    ):
        self._prior = self.potential_fn.prior
    elif isinstance(q, VIPosterior) and isinstance(q._prior, Distribution):
        self._prior = q._prior
    else:
        raise ValueError(
            "We could not find a suitable prior distribution within `potential_fn` "
            "or `q` (if a VIPosterior is given). Please explicitly specify a prior."
        )

    self._prior = move_distribution_to_device(self._prior, device)
    self._optimizer = None

    # Mode tracking: None (not trained), "single_x", or "amortized"
    self._mode: Optional[Literal["single_x", "amortized"]] = None

    # Amortized mode: conditional flow q(θ|x)
    self._amortized_q: Optional[ConditionalDensityEstimator] = None

    self._num_transforms: int = num_transforms
    self._hidden_features: int = hidden_features
    self._z_score_theta: Literal["none", "independent", "structured"] = (
        z_score_theta
    )
    self._z_score_x: Literal["none", "independent", "structured"] = z_score_x

    # In contrast to MCMC we want to project into constrained space.
    if theta_transform is None:
        self.link_transform = mcmc_transform(self._prior, device=device).inv
    else:
        self.link_transform = theta_transform.inv

    if parameters is None:
        parameters = []
    if modules is None:
        modules = []
    # This will set the variational distribution and VI method
    self.set_q(
        q,
        parameters=parameters,
        modules=modules,
    )
    self.set_vi_method(vi_method)

    self._purpose = (
        "It provides Variational inference to .sample() from the posterior and "
        "can evaluate the _normalized_ posterior density with .log_prob()."
    )

__setstate__(state_dict)

This method is called when unpickling the object.

Especially, we need to restore the removed attributes and ensure that the object e.g. remains deep copy compatible.

Parameters:

Name Type Description Default
state_dict Dict

Given state dictionary, we will restore the object from it.

required
Source code in sbi/inference/posteriors/vi_posterior.py
1374
1375
1376
1377
1378
1379
1380
1381
1382
1383
1384
1385
1386
1387
1388
1389
1390
1391
1392
def __setstate__(self, state_dict: Dict):
    """This method is called when unpickling the object.

    Especially, we need to restore the removed attributes and ensure that the object
    e.g. remains deep copy compatible.

    Args:
        state_dict: Given state dictionary, we will restore the object from it.
    """
    self.__dict__ = state_dict
    q = deepcopy(self._q)
    # Restore removed attributes
    self.set_q(*self._q_arg)
    self._q = q
    make_object_deepcopy_compatible(self)
    make_object_deepcopy_compatible(self.q)
    # Handle amortized mode
    if self._mode == "amortized" and self._amortized_q is not None:
        make_object_deepcopy_compatible(self._amortized_q)

evaluate(quality_control_metric='psis', N=int(50000.0))

This function will evaluate the quality of the variational posterior distribution. We currently support two different metrics of type psis, which checks the quality based on the tails of the importance weights (there should not be many weights with large values), or prop, which checks the proportionality between q and potential_fn.

NOTE: In our experience prop is sensitive enough to distinguish good from ok, whereas psis is more sensitive in distinguishing very bad from ok.

Parameters:

Name Type Description Default
quality_control_metric str

The metric of choice, we currently support [psis, prop, prop_prior].

'psis'
N int

Number of samples which is used to evaluate the metric.

int(50000.0)
Source code in sbi/inference/posteriors/vi_posterior.py
1242
1243
1244
1245
1246
1247
1248
1249
1250
1251
1252
1253
1254
1255
1256
1257
1258
1259
1260
1261
def evaluate(self, quality_control_metric: str = "psis", N: int = int(5e4)) -> None:
    """This function will evaluate the quality of the variational posterior
    distribution. We currently support two different metrics of type `psis`, which
    checks the quality based on the tails of importance weights (there should not be
    much with a large one), or `prop` which checks the proportionality between q
    and potential_fn.

    NOTE: In our experience `prop` is sensitive to distinguish ``good`` from ``ok``
    whereas `psis` is more sensitive in distinguishing `very bad` from `ok`.

    Args:
        quality_control_metric: The metric of choice, we currently support [psis,
            prop, prop_prior].
        N: Number of samples which is used to evaluate the metric.
    """
    quality_control_fn, quality_control_msg = get_quality_metric(
        quality_control_metric
    )
    metric = round(float(quality_control_fn(self, N=N)), 3)
    print(f"Quality Score: {metric} " + quality_control_msg)

log_prob(theta, x=None, track_gradients=False)

Returns the log-probability of theta under the variational posterior.

For single-x mode: returns log q(θ). For amortized mode: returns log q(θ|x).

Parameters:

Name Type Description Default
theta Tensor

Parameters to evaluate, shape (batch_theta, θ_dim).

required
x Optional[Tensor]

Observation. In single-x mode, must match trained x_o (or be None). In amortized mode, required and can be any observation. For single x, shape (1, x_dim) or (x_dim,). For batched x, shape (batch_x, x_dim).

None
track_gradients bool

Whether the returned tensor supports tracking gradients. This can be helpful for e.g. sensitivity analysis but increases memory consumption.

False

Returns:

Type Description
Tensor

Log-probability of shape (batch,) where batch is:

Tensor
  • batch_theta if x has batch size 1 (broadcast x)
Tensor
  • batch_x if theta has batch size 1 (broadcast theta)
Tensor
  • batch_theta if batch_theta == batch_x (paired evaluation)

Raises:

Type Description
ValueError

If mode requirements are not met or batch sizes incompatible.

Source code in sbi/inference/posteriors/vi_posterior.py
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
def log_prob(
    self,
    theta: Tensor,
    x: Optional[Tensor] = None,
    track_gradients: bool = False,
) -> Tensor:
    r"""Returns the log-probability of theta under the variational posterior.

    For single-x mode: returns log q(θ).
    For amortized mode: returns log q(θ|x).

    Args:
        theta: Parameters to evaluate, shape (batch_theta, θ_dim).
        x: Observation. In single-x mode, must match trained x_o (or be None).
            In amortized mode, required and can be any observation.
            For single x, shape (1, x_dim) or (x_dim,).
            For batched x, shape (batch_x, x_dim).
        track_gradients: Whether the returned tensor supports tracking gradients.
            This can be helpful for e.g. sensitivity analysis but increases memory
            consumption.

    Returns:
        Log-probability of shape (batch,) where batch is:
        - batch_theta if x has batch size 1 (broadcast x)
        - batch_x if theta has batch size 1 (broadcast theta)
        - batch_theta if batch_theta == batch_x (paired evaluation)

    Raises:
        ValueError: If mode requirements are not met or batch sizes incompatible.
    """
    with torch.set_grad_enabled(track_gradients):
        theta = ensure_theta_batched(torch.as_tensor(theta)).to(self._device)

        if self._mode == "amortized":
            # Amortized mode: evaluate log q(θ|x)
            x = self._x_else_default_x(x)
            if x is None:
                raise ValueError(
                    "x is required for amortized mode. Provide an observation or "
                    "set a default x with set_default_x()."
                )
            x = atleast_2d_float32_tensor(x).to(self._device)
            assert self._amortized_q is not None

            # Handle broadcasting between theta and x
            batch_theta = theta.shape[0]
            batch_x = x.shape[0]

            if batch_theta != batch_x:
                if batch_x == 1:
                    # Broadcast x to match theta
                    x = x.expand(batch_theta, -1)
                elif batch_theta == 1:
                    # Broadcast theta to match x
                    theta = theta.expand(batch_x, -1)
                else:
                    raise ValueError(
                        f"Batch sizes of theta ({batch_theta}) and x ({batch_x}) "
                        f"are incompatible. They must be equal, or one must be 1."
                    )

            # ZukoFlow expects input shape (sample_dim, batch_dim, *event_shape)
            # Add sample dimension, compute log_prob, then squeeze back
            theta_with_sample_dim = theta.unsqueeze(0)
            log_probs = self._amortized_q.log_prob(
                theta_with_sample_dim, condition=x
            )
            return log_probs.squeeze(0)
        else:
            # Single-x mode: evaluate log q(θ)
            x = self._x_else_default_x(x)
            if self._trained_on is None or (x != self._trained_on).any():
                raise ValueError(
                    f"The variational posterior was not fit on the specified "
                    f"observation {x}. Please train using posterior.train()."
                )
            return self.q.log_prob(theta)

map(x=None, num_iter=1000, num_to_optimize=100, learning_rate=0.01, init_method='proposal', num_init_samples=10000, save_best_every=10, show_progress_bars=False, force_update=False)

Returns the maximum-a-posteriori estimate (MAP).

The method can be interrupted (Ctrl-C) when the user sees that the log-probability converges. The best estimate will be saved in self._map and can be accessed with self.map(). The MAP is obtained by running gradient ascent from a given number of starting positions (samples from the posterior with the highest log-probability). After the optimization is done, we select the parameter set that has the highest log-probability after the optimization.

Warning: The default values used by this function are not well-tested. They might require hand-tuning for the problem at hand.

For developers: if the prior is a BoxUniform, we carry out the optimization in unbounded space and transform the result back into bounded space.

Parameters:

Name Type Description Default
x Optional[TorchTensor]

Deprecated - use .set_default_x() prior to .map().

None
num_iter int

Number of optimization steps that the algorithm takes to find the MAP.

1000
learning_rate float

Learning rate of the optimizer.

0.01
init_method Union[str, TorchTensor]

How to select the starting parameters for the optimization. If it is a string, it can be either [posterior, prior], which samples the respective distribution num_init_samples times. If it is a tensor, the tensor will be used as init locations.

'proposal'
num_init_samples int

Draw this number of samples from the posterior and evaluate the log-probability of all of them.

10000
num_to_optimize int

From the drawn num_init_samples, use the num_to_optimize with highest log-probability as the initial points for the optimization.

100
save_best_every int

The best log-probability is computed, saved in the map-attribute, and printed every save_best_every-th iteration. Computing the best log-probability creates a significant overhead (thus, the default is 10.)

10
show_progress_bars bool

Whether to show a progressbar during sampling from the posterior.

False
force_update bool

Whether to re-calculate the MAP when x is unchanged and have a cached value.

False
log_prob_kwargs

Will be empty for SNLE and SNRE. Will contain {‘norm_posterior’: True} for SNPE.

required

Returns:

Type Description
Tensor

The MAP estimate.

Source code in sbi/inference/posteriors/vi_posterior.py
1263
1264
1265
1266
1267
1268
1269
1270
1271
1272
1273
1274
1275
1276
1277
1278
1279
1280
1281
1282
1283
1284
1285
1286
1287
1288
1289
1290
1291
1292
1293
1294
1295
1296
1297
1298
1299
1300
1301
1302
1303
1304
1305
1306
1307
1308
1309
1310
1311
1312
1313
1314
1315
1316
1317
1318
1319
1320
1321
1322
1323
1324
1325
1326
1327
1328
1329
def map(
    self,
    x: Optional[TorchTensor] = None,
    num_iter: int = 1_000,
    num_to_optimize: int = 100,
    learning_rate: float = 0.01,
    init_method: Union[str, TorchTensor] = "proposal",
    num_init_samples: int = 10_000,
    save_best_every: int = 10,
    show_progress_bars: bool = False,
    force_update: bool = False,
) -> Tensor:
    r"""Returns the maximum-a-posteriori estimate (MAP).

    The method can be interrupted (Ctrl-C) when the user sees that the
    log-probability converges. The best estimate will be saved in `self._map` and
    can be accessed with `self.map()`. The MAP is obtained by running gradient
    ascent from a given number of starting positions (samples from the posterior
    with the highest log-probability). After the optimization is done, we select the
    parameter set that has the highest log-probability after the optimization.

    Warning: The default values used by this function are not well-tested. They
    might require hand-tuning for the problem at hand.

    For developers: if the prior is a `BoxUniform`, we carry out the optimization
    in unbounded space and transform the result back into bounded space.

    Args:
        x: Deprecated - use `.set_default_x()` prior to `.map()`.
        num_iter: Number of optimization steps that the algorithm takes
            to find the MAP.
        learning_rate: Learning rate of the optimizer.
        init_method: How to select the starting parameters for the optimization. If
            it is a string, it can be either [`posterior`, `prior`], which samples
            the respective distribution `num_init_samples` times. If it is a
            tensor, the tensor will be used as init locations.
            NOTE(review): the default is `"proposal"`, which is not in the list
            above - presumably handled by the base class via `self.proposal`
            (set below); confirm against the base `map()` implementation.
        num_init_samples: Draw this number of samples from the posterior and
            evaluate the log-probability of all of them.
        num_to_optimize: From the drawn `num_init_samples`, use the
            `num_to_optimize` with highest log-probability as the initial points
            for the optimization.
        save_best_every: The best log-probability is computed, saved in the
            `map`-attribute, and printed every `save_best_every`-th iteration.
            Computing the best log-probability creates a significant overhead
            (thus, the default is `10`.)
        show_progress_bars: Whether to show a progressbar during sampling from
            the posterior.
        force_update: Whether to re-calculate the MAP when x is unchanged and
            have a cached value.

    Returns:
        The MAP estimate.
    """
    # Point the proposal at the trained variational distribution so that
    # "proposal" initialization draws its starting candidates from q.
    self.proposal = self.q
    return super().map(
        x=x,
        num_iter=num_iter,
        num_to_optimize=num_to_optimize,
        learning_rate=learning_rate,
        init_method=init_method,
        num_init_samples=num_init_samples,
        save_best_every=save_best_every,
        show_progress_bars=show_progress_bars,
        force_update=force_update,
    )

sample(sample_shape=torch.Size(), x=None, show_progress_bars=True)

Draw samples from the variational posterior distribution \(p(\theta|x)\).

For single-x mode (trained via train()): samples from q(θ) trained on x_o. For amortized mode (trained via train_amortized()): samples from q(θ|x).

Parameters:

Name Type Description Default
sample_shape Shape

Desired shape of samples that are drawn from the posterior.

Size()
x Optional[Tensor]

Conditioning observation. In single-x mode, must match trained x_o (or be None to use default). In amortized mode, required and can be any observation. For batched observations, shape should be (batch_size, x_dim).

None
show_progress_bars bool

Unused for VIPosterior since sampling from the variational distribution is fast. Included for API consistency.

True

Returns:

Type Description
Tensor

Samples from posterior with shape (*sample_shape, θ_dim) for single x,

Tensor

or (*sample_shape, batch_size, θ_dim) for batched observations in

Tensor

amortized mode.

Raises:

Type Description
ValueError

If mode requirements are not met.

Source code in sbi/inference/posteriors/vi_posterior.py
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
def sample(
    self,
    sample_shape: Shape = torch.Size(),
    x: Optional[Tensor] = None,
    show_progress_bars: bool = True,
) -> Tensor:
    r"""Draw samples from the variational posterior distribution $p(\theta|x)$.

    For single-x mode (trained via `train()`): samples from q(θ) trained on x_o.
    For amortized mode (trained via `train_amortized()`): samples from q(θ|x).

    Args:
        sample_shape: Desired shape of samples that are drawn from the posterior.
        x: Conditioning observation. In single-x mode, must match trained x_o
            (or be None to use default). In amortized mode, required and can be
            any observation. For batched observations, shape should be
            (batch_size, x_dim).
        show_progress_bars: Unused for `VIPosterior` since sampling from the
            variational distribution is fast. Included for API consistency.

    Returns:
        Samples from posterior with shape (*sample_shape, θ_dim) for single x,
        or (*sample_shape, batch_size, θ_dim) for batched observations in
        amortized mode.

    Raises:
        ValueError: If mode requirements are not met.
    """
    if self._mode != "amortized":
        # Single-x mode: the unconditional flow q(θ) was fit to one specific
        # observation, so only that observation is valid here.
        x = self._x_else_default_x(x)
        if self._trained_on is None or (x != self._trained_on).any():
            raise ValueError(
                f"The variational posterior was not fit on the specified "
                f"observation {x}. Please train using posterior.train()."
            )
        theta = self.q.sample(torch.Size(sample_shape))
        return theta.reshape((*sample_shape, theta.shape[-1]))

    # Amortized mode: draw from the conditional flow q(θ|x).
    x = self._x_else_default_x(x)
    if x is None:
        raise ValueError(
            "x is required for amortized mode. Provide an observation or "
            "set a default x with set_default_x()."
        )
    x = atleast_2d_float32_tensor(x).to(self._device)
    assert self._amortized_q is not None
    # The conditional flow returns shape (*sample_shape, batch_size, θ_dim).
    theta = self._amortized_q.sample(torch.Size(sample_shape), condition=x)
    # Match base posterior behavior: drop a singleton x batch dimension.
    if x.shape[0] == 1:
        theta = theta.squeeze(-2)
    return theta

sample_batched(sample_shape, x, max_sampling_batch_size=10000, show_progress_bars=True)

Sample from posterior for a batch of observations.

In amortized mode, this is efficient as all x values are processed in parallel through the conditional flow.

In single-x mode, this raises NotImplementedError since the unconditional flow is trained for a specific x_o.

Parameters:

Name Type Description Default
sample_shape Shape

Number of samples per observation.

required
x Tensor

Batch of observations (num_obs, x_dim).

required
max_sampling_batch_size int

Unused for amortized mode (no batching needed).

10000
show_progress_bars bool

Unused for amortized mode.

True

Returns:

Type Description
Tensor

Samples of shape (*sample_shape, num_obs, θ_dim).

Raises:

Type Description
NotImplementedError

If called in single-x mode.

Source code in sbi/inference/posteriors/vi_posterior.py
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
def sample_batched(
    self,
    sample_shape: Shape,
    x: Tensor,
    max_sampling_batch_size: int = 10000,
    show_progress_bars: bool = True,
) -> Tensor:
    """Sample from posterior for a batch of observations.

    In amortized mode, this is efficient as all x values are processed in
    parallel through the conditional flow.

    In single-x mode, this raises NotImplementedError since the unconditional
    flow is trained for a specific x_o.

    Args:
        sample_shape: Number of samples per observation.
        x: Batch of observations (num_obs, x_dim).
        max_sampling_batch_size: Unused for amortized mode (no batching needed).
        show_progress_bars: Unused for amortized mode.

    Returns:
        Samples of shape (*sample_shape, num_obs, θ_dim).

    Raises:
        NotImplementedError: If called in single-x mode.
    """
    # Guard clause: only the amortized conditional flow supports batched x.
    if self._mode != "amortized":
        raise NotImplementedError(
            "Batched sampling is not implemented for single-x VI mode. "
            "Use train_amortized() to train an amortized posterior, or "
            "call sample() in a loop: [posterior.sample(shape, x_o) for x_o in x]."
        )
    # sample() already accepts a batch of observations in amortized mode.
    return self.sample(sample_shape, x=x, show_progress_bars=show_progress_bars)

set_q(q, parameters=None, modules=None)

Defines the variational family.

You can specify over which parameters/modules we optimize. This is required for custom distributions which e.g. do not inherit nn.Modules or has the function parameters or modules to give direct access to trainable parameters. Further, you can pass a function, which constructs a variational distribution if called.

Parameters:

Name Type Description Default
q QType

Variational distribution, either string, distribution, or a VIPosterior object. This specifies a parametric class of distribution over which the best possible posterior approximation is searched. For string input, we support normalizing flows [maf, nsf, naf, unaf, nice, sospf] via Zuko, and simple Gaussian families [gaussian, gaussian_diag] via pure PyTorch. You can also specify your own variational family by passing a parameterized distribution object i.e. a torch.distributions Distribution with methods parameters returning an iterable of all parameters (you can pass them within the parameters/modules attribute). Additionally, we allow a Callable with signature (event_shape: torch.Size, link_transform: TorchTransform, device: str) -> Distribution, which builds a custom distribution. If q is already a VIPosterior, then the arguments will be copied from it (relevant for multi-round training).

Note: For 1D parameter spaces, autoregressive normalizing flows may be unstable. Consider using q='gaussian' or q='gf' for 1D.

required
parameters Optional[Iterable]

List of parameters associated with the distribution object.

None
modules Optional[Iterable]

List of modules associated with the distribution object.

None
Source code in sbi/inference/posteriors/vi_posterior.py
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
def set_q(
    self,
    q: QType,
    parameters: Optional[Iterable] = None,
    modules: Optional[Iterable] = None,
) -> None:
    """Defines the variational family.

    You can specify over which parameters/modules we optimize. This is required for
    custom distributions which e.g. do not inherit nn.Modules or has the function
    `parameters` or `modules` to give direct access to trainable parameters.
    Further, you can pass a function, which constructs a variational distribution
    if called.

    Args:
        q: Variational distribution, either string, distribution, or a VIPosterior
            object. This specifies a parametric class of distribution over which
            the best possible posterior approximation is searched. For string input,
            we support normalizing flows [maf, nsf, naf, unaf, nice, sospf] via
            Zuko, and simple Gaussian families [gaussian, gaussian_diag] via pure
            PyTorch. You can also specify your own variational family by passing a
            `parameterized` distribution object i.e. a torch.distributions
            Distribution with methods `parameters` returning an iterable of all
            parameters (you can pass them within the parameters/modules attribute).
            Additionally, we allow a `Callable` with signature
            `(event_shape: torch.Size, link_transform: TorchTransform, device: str)
            -> Distribution`, which builds a custom distribution. If q is already
            a `VIPosterior`, then the arguments will be copied from it (relevant
            for multi-round training).

            Note: For 1D parameter spaces, autoregressive normalizing flows
            may be unstable. Consider using `q='gaussian'` or `q='gf'` for 1D.
        parameters: List of parameters associated with the distribution object.
        modules: List of modules associated with the distribution object.

    """
    if parameters is None:
        parameters = []
    if modules is None:
        modules = []
    # Remember the raw constructor arguments so q can be rebuilt or copied
    # later (this attribute is copied verbatim in the VIPosterior branch below).
    self._q_arg = (q, parameters, modules)
    _flow_types = (ZukoUnconditionalFlow, TransformedZukoFlow, LearnableGaussian)
    if isinstance(q, _flow_types):
        # Flow/Gaussian passed directly (e.g., from _q_build_fn during retrain)
        make_object_deepcopy_compatible(q)
        self._trained_on = None
    elif isinstance(q, Distribution):
        # Custom torch Distribution: adapt it to the prior/link transform and
        # register the user-supplied trainable parameters/modules.
        q = adapt_variational_distribution(
            q,
            self._prior,
            self.link_transform,
            parameters=parameters,
            modules=modules,
        )
        make_object_deepcopy_compatible(q)
        # NOTE(review): `self_custom_q_init_cache` is a local variable, not an
        # attribute - the name suggests `self._custom_q_init_cache` was intended.
        # Behavior is still correct because the lambda below closes over it.
        self_custom_q_init_cache = deepcopy(q)
        self._q_build_fn = lambda *args, **kwargs: self_custom_q_init_cache
        self._trained_on = None
        self._zuko_flow_type = None
    elif isinstance(q, (str, Callable)):
        if isinstance(q, str):
            if q in _ZUKO_FLOW_TYPES:
                # Named Zuko flow: build now, and keep a builder for retraining.
                # `ft=q` binds the flow name at definition time (avoids
                # late-binding on the reassigned local `q`).
                q_flow = self._build_unconditional_flow(q)
                self._zuko_flow_type = q
                self._q_build_fn = lambda *args, ft=q, **kwargs: (
                    self._build_unconditional_flow(ft)
                )
                q = q_flow
            elif q in ("gaussian", "gaussian_diag"):
                self._zuko_flow_type = None
                full_cov = q == "gaussian"
                dim = self._prior.event_shape[0]
                q_dist = LearnableGaussian(
                    dim=dim,
                    full_covariance=full_cov,
                    link_transform=self.link_transform,
                    device=self._device,
                )
                # Same default-argument trick: capture full_cov/dim by value.
                self._q_build_fn = lambda *args, fc=full_cov, d=dim, **kwargs: (
                    LearnableGaussian(
                        dim=d,
                        full_covariance=fc,
                        link_transform=self.link_transform,
                        device=self._device,
                    )
                )
                q = q_dist
            else:
                supported = sorted(_ZUKO_FLOW_TYPES) + ["gaussian", "gaussian_diag"]
                raise ValueError(
                    f"Unknown variational family '{q}'. "
                    f"Supported options: {supported}"
                )
        else:
            # Callable provided - use as-is
            self._zuko_flow_type = None
            self._q_build_fn = q
            q = self._q_build_fn(
                self._prior.event_shape,
                self.link_transform,
                device=self._device,
            )
        make_object_deepcopy_compatible(q)
        self._trained_on = None
    elif isinstance(q, VIPosterior):
        # Copy the variational state from another VIPosterior (multi-round
        # training): builder, training status, mode, method, prior, and x.
        self._q_build_fn = q._q_build_fn
        self._trained_on = q._trained_on
        self._mode = getattr(q, "_mode", None)  # Copy mode from source
        self._zuko_flow_type = getattr(q, "_zuko_flow_type", None)
        self.vi_method = q.vi_method  # type: ignore
        self._prior = q._prior
        self._x = q._x
        self._q_arg = q._q_arg
        make_object_deepcopy_compatible(q.q)
        q = deepcopy(q.q)
        # Move copied q to self's device (source may be on a different device).
        if hasattr(q, "to"):
            q.to(self._device)  # type: ignore[union-attr]
    # Validate the variational distribution
    if isinstance(q, _flow_types):
        pass  # These are validated during construction
    elif isinstance(q, Distribution):
        check_variational_distribution(q, self._prior)
    else:
        raise ValueError(
            f"Variational distribution must be a Distribution, got {type(q)}. "
            "Please create an issue on github https://github.com/mackelab/sbi/issues"
        )
    self._q = q

set_vi_method(method)

Sets variational inference method.

Parameters:

Name Type Description Default
method str

One of [rKL, fKL, IW, alpha].

required

Returns:

Type Description
VIPosterior

VIPosterior for chainable calls.

Source code in sbi/inference/posteriors/vi_posterior.py
549
550
551
552
553
554
555
556
557
558
559
560
def set_vi_method(self, method: str) -> "VIPosterior":
    """Select the variational inference objective.

    Args:
        method: One of [rKL, fKL, IW, alpha].

    Returns:
        `VIPosterior` for chainable calls.
    """
    # Record the chosen objective, then resolve the matching divergence
    # optimizer builder for it.
    self._vi_method = method
    self._optimizer_builder = get_VI_method(method)
    return self

to(device)

Move all components to the given device.

Parameters:

Name Type Description Default
device Union[str, device]

The device to move the posterior to.

required

Returns:

Type Description
VIPosterior

self for method chaining.

Source code in sbi/inference/posteriors/vi_posterior.py
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
def to(self, device: Union[str, torch.device]) -> "VIPosterior":
    """Move all components to the given device.

    Moves, in order: the potential function (and with it prior, x_o and
    estimator), the prior itself, the link transform, cached tensors
    (`_x`, `_map`, `_trained_on`), and the variational distribution(s).

    Args:
        device: The device to move the posterior to.

    Returns:
        self for method chaining.
    """
    self._device = device

    # Move potential (which moves prior, x_o, and estimator).
    self.potential_fn.to(device)  # type: ignore
    self._prior = move_distribution_to_device(self._prior, device)

    # Rebuild link_transform on new device (same logic as __init__).
    if self.theta_transform is None:
        self.link_transform = mcmc_transform(self._prior, device=device).inv
    else:
        self.link_transform = self.theta_transform.inv

    # Move cached tensors.
    if self._x is not None:
        self._x = self._x.to(device)
    if self._map is not None:
        self._map = self._map.to(device)
    if self._trained_on is not None:
        self._trained_on = self._trained_on.to(device)

    # Move variational distributions.
    # hasattr guard: `to()` may be called before `set_q` has assigned `_q`,
    # and custom q objects are not required to implement `.to()`.
    if hasattr(self, "_q") and hasattr(self._q, "to"):
        self._q.to(device)  # type: ignore[union-attr]
    # Update link_transform reference on q if it caches one.
    if hasattr(self, "_q") and hasattr(self._q, "_link_transform"):
        self._q._link_transform = self.link_transform  # type: ignore[union-attr]
    if self._amortized_q is not None:
        self._amortized_q.to(device)

    return self

train(x=None, n_particles=256, learning_rate=0.001, gamma=0.999, max_num_iters=2000, min_num_iters=10, clip_value=10.0, warm_up_rounds=100, retrain_from_scratch=False, reset_optimizer=False, show_progress_bar=True, check_for_convergence=True, quality_control=True, quality_control_metric='psis', **kwargs)

This method trains the variational posterior for a single observation.

Parameters:

Name Type Description Default
x Optional[TorchTensor]

The observation, optional, defaults to self._x.

None
n_particles int

Number of samples to approximate expectations within the variational bounds. The larger the more accurate are gradient estimates, but the computational cost per iteration increases.

256
learning_rate float

Learning rate of the optimizer.

0.001
gamma float

Learning rate decay per iteration. We use an exponential decay scheduler.

0.999
max_num_iters int

Maximum number of iterations.

2000
min_num_iters int

Minimum number of iterations.

10
clip_value float

Gradient clipping value, decreasing may help if you see invalid values.

10.0
warm_up_rounds int

Initialize the posterior as the prior.

100
retrain_from_scratch bool

Retrain the variational distributions from scratch.

False
reset_optimizer bool

Reset the divergence optimizer

False
show_progress_bar bool

If any progress report should be displayed.

True
quality_control bool

If False quality control is skipped.

True
quality_control_metric str

Which metric to use for evaluating the quality.

'psis'
kwargs

Hyperparameters check corresponding DivergenceOptimizer for detail eps: Determines sensitivity of convergence check. retain_graph: Boolean which decides whether to retain the computation graph. This may be required for some exotic user-specified q’s. optimizer: A PyTorch Optimizer class e.g. Adam or SGD. See DivergenceOptimizer for details. scheduler: A PyTorch learning rate scheduler. See DivergenceOptimizer for details. alpha: Only used if vi_method=alpha. Determines the alpha divergence. K: Only used if vi_method=IW. Determines the number of importance weighted particles. stick_the_landing: If one should use the STL estimator (only for rKL, IW, alpha). dreg: If one should use the DREG estimator (only for rKL, IW, alpha). weight_transform: Callable applied to importance weights (only for fKL)

{}

Returns: VIPosterior: VIPosterior (can be used to chain calls).

Raises:

Type Description
ValueError

If hyperparameters are invalid.

Source code in sbi/inference/posteriors/vi_posterior.py
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
def train(
    self,
    x: Optional[TorchTensor] = None,
    n_particles: int = 256,
    learning_rate: float = 1e-3,
    gamma: float = 0.999,
    max_num_iters: int = 2000,
    min_num_iters: int = 10,
    clip_value: float = 10.0,
    warm_up_rounds: int = 100,
    retrain_from_scratch: bool = False,
    reset_optimizer: bool = False,
    show_progress_bar: bool = True,
    check_for_convergence: bool = True,
    quality_control: bool = True,
    quality_control_metric: str = "psis",
    **kwargs,
) -> "VIPosterior":
    """This method trains the variational posterior for a single observation.

    Args:
        x: The observation, optional, defaults to self._x.
        n_particles: Number of samples to approximate expectations within the
            variational bounds. The larger the more accurate are gradient
            estimates, but the computational cost per iteration increases.
        learning_rate: Learning rate of the optimizer.
        gamma: Learning rate decay per iteration. We use an exponential decay
            scheduler.
        max_num_iters: Maximum number of iterations.
        min_num_iters: Minimum number of iterations.
        clip_value: Gradient clipping value, decreasing may help if you see invalid
            values.
        warm_up_rounds: Initialize the posterior as the prior.
        retrain_from_scratch: Retrain the variational distributions from scratch.
        reset_optimizer: Reset the divergence optimizer
        show_progress_bar: If any progress report should be displayed.
        check_for_convergence: If True, training stops early once the optimizer
            reports convergence (checked after `min_num_iters` iterations).
        quality_control: If False quality control is skipped.
        quality_control_metric: Which metric to use for evaluating the quality.
        kwargs: Hyperparameters check corresponding `DivergenceOptimizer` for detail
            eps: Determines sensitivity of convergence check.
            retain_graph: Boolean which decides whether to retain the computation
                graph. This may be required for some `exotic` user-specified q's.
            optimizer: A PyTorch Optimizer class e.g. Adam or SGD. See
                `DivergenceOptimizer` for details.
            scheduler: A PyTorch learning rate scheduler. See
                `DivergenceOptimizer` for details.
            alpha: Only used if vi_method=`alpha`. Determines the alpha divergence.
            K: Only used if vi_method=`IW`. Determines the number of importance
                weighted particles.
            stick_the_landing: If one should use the STL estimator (only for rKL,
                IW, alpha).
            dreg: If one should use the DREG estimator (only for rKL, IW, alpha).
            weight_transform: Callable applied to importance weights (only for fKL)
    Returns:
        VIPosterior: `VIPosterior` (can be used to chain calls).

    Raises:
        ValueError: If hyperparameters are invalid.
    """
    # Validate hyperparameters
    if n_particles <= 0:
        raise ValueError(f"n_particles must be positive, got {n_particles}")
    if learning_rate <= 0:
        raise ValueError(f"learning_rate must be positive, got {learning_rate}")
    if not 0 < gamma <= 1:
        raise ValueError(f"gamma must be in (0, 1], got {gamma}")
    if max_num_iters <= 0:
        raise ValueError(f"max_num_iters must be positive, got {max_num_iters}")
    if min_num_iters < 0:
        raise ValueError(f"min_num_iters must be non-negative, got {min_num_iters}")
    if clip_value <= 0:
        raise ValueError(f"clip_value must be positive, got {clip_value}")

    # Update optimizer with current arguments.
    # NOTE: `locals()` passes every argument above (plus `self`) to the
    # optimizer's update hook; kwargs take precedence on key collisions.
    if self._optimizer is not None:
        self._optimizer.update({**locals(), **kwargs})

    # Init q and the optimizer if necessary
    if retrain_from_scratch:
        self.q = self._q_build_fn()  # type: ignore
        self._optimizer = self._optimizer_builder(
            self.potential_fn,
            self.q,
            lr=learning_rate,
            clip_value=clip_value,
            gamma=gamma,
            n_particles=n_particles,
            prior=self._prior,
            **kwargs,
        )

    # Rebuild the optimizer when requested, when none exists yet, or when
    # the VI method changed (the existing optimizer is of the wrong class).
    if (
        reset_optimizer
        or self._optimizer is None
        or not isinstance(self._optimizer, self._optimizer_builder)
    ):
        self._optimizer = self._optimizer_builder(
            self.potential_fn,
            self.q,
            lr=learning_rate,
            clip_value=clip_value,
            gamma=gamma,
            n_particles=n_particles,
            prior=self._prior,
            **kwargs,
        )

    # Check context
    x = atleast_2d_float32_tensor(self._x_else_default_x(x)).to(  # type: ignore
        self._device
    )
    if not torch.isfinite(x).all():
        raise ValueError("x contains NaN or Inf values.")

    already_trained = self._trained_on is not None and (x == self._trained_on).all()

    # Optimize
    optimizer = self._optimizer
    optimizer.to(self._device)
    optimizer.reset_loss_stats()

    if show_progress_bar:
        iters = tqdm(range(max_num_iters))
    else:
        iters = range(max_num_iters)

    # Warmup before training
    if reset_optimizer or (not optimizer.warm_up_was_done and not already_trained):
        if show_progress_bar:
            iters.set_description(  # type: ignore
                "Warmup phase, this may take a few seconds..."
            )
        optimizer.warm_up(warm_up_rounds)

    for i in iters:
        optimizer.step(x)
        mean_loss, std_loss = optimizer.get_loss_stats()
        # Update progress bar
        if show_progress_bar:
            assert isinstance(iters, tqdm)
            iters.set_description(  # type: ignore
                f"Loss: {np.round(float(mean_loss), 2)}, "
                f"Std: {np.round(float(std_loss), 2)}"
            )
        # Check for convergence
        if check_for_convergence and i > min_num_iters and optimizer.converged():
            if show_progress_bar:
                print(f"\nConverged with loss: {np.round(float(mean_loss), 2)}")
            break
    # Training finished:
    self._trained_on = x
    if self._mode == "amortized":
        warnings.warn(
            "Switching from amortized to single-x mode. "
            "The previously trained amortized model will be discarded.",
            UserWarning,
            stacklevel=2,
        )
        self._amortized_q = None
    self._mode = "single_x"

    # Evaluate quality
    if quality_control:
        try:
            self.evaluate(quality_control_metric=quality_control_metric)
        except Exception as e:
            # Best-effort fallback: retry once from scratch with a 10x smaller
            # learning rate instead of failing hard on poor quality.
            print(
                f"Quality control showed a low quality of the variational "
                f"posterior. We are automatically retraining the variational "
                f"posterior from scratch with a smaller learning rate. "
                f"Alternatively, if you want to skip quality control, please "
                f"retrain with `VIPosterior.train(..., quality_control=False)`. "
                f"\nThe error that occurred is: {e}"
            )
            self.train(
                learning_rate=learning_rate * 0.1,
                retrain_from_scratch=True,
                reset_optimizer=True,
            )

    return self

train_amortized(theta, x, n_particles=128, learning_rate=0.001, gamma=0.999, max_num_iters=500, clip_value=5.0, batch_size=64, validation_fraction=0.1, validation_batch_size=None, validation_n_particles=None, stop_after_iters=20, show_progress_bar=True, retrain_from_scratch=False, flow_type=None, num_transforms=None, hidden_features=None, z_score_theta=None, z_score_x=None, params=None)

Train a conditional flow q(θ|x) for amortized variational inference.

This allows sampling from q(θ|x) for any observation x without retraining. Uses the ELBO (Evidence Lower Bound) objective with early stopping based on validation loss.

Parameters:

Name Type Description Default
theta Tensor

Training θ values from simulations (num_sims, θ_dim).

required
x Tensor

Training x values from simulations (num_sims, x_dim).

required
n_particles int

Number of samples to estimate ELBO per x.

128
learning_rate float

Learning rate for Adam optimizer.

0.001
gamma float

Learning rate decay per iteration.

0.999
max_num_iters int

Maximum training iterations.

500
clip_value float

Gradient clipping threshold.

5.0
batch_size int

Number of x values per training batch.

64
validation_fraction float

Fraction of data to use for validation.

0.1
validation_batch_size Optional[int]

Batch size for validation loss. Defaults to batch_size.

None
validation_n_particles Optional[int]

Number of particles for validation loss. Defaults to n_particles.

None
stop_after_iters int

Stop training after this many iterations without improvement in validation loss.

20
show_progress_bar bool

Whether to show progress.

True
retrain_from_scratch bool

If True, rebuild the flow from scratch.

False
flow_type Optional[Union[ZukoFlowType, str]]

Flow architecture for the variational distribution. Use ZukoFlowType.NSF, ZukoFlowType.MAF, etc., or a string. If None, uses value from params or instance default.

None
num_transforms Optional[int]

Number of transforms in the flow. If None, uses value from params or instance default.

None
hidden_features Optional[int]

Hidden layer size in the flow. If None, uses value from params or instance default.

None
z_score_theta Optional[Literal['none', 'independent', 'structured']]

Method for z-scoring θ (the parameters being modeled). One of “none”, “independent”, “structured”. If None, uses value from params or instance default.

None
z_score_x Optional[Literal['none', 'independent', 'structured']]

Method for z-scoring x (the conditioning variable). One of “none”, “independent”, “structured”. Use “structured” for structured data like images with spatial correlations. If None, uses value from params or instance default.

None
params Optional[VIPosteriorParameters]

Optional VIPosteriorParameters dataclass. Values are used as fallbacks when explicit arguments are None. Priority order: explicit args > params > instance attributes (from init).

None

Returns:

Type Description
VIPosterior

self for method chaining.

Source code in sbi/inference/posteriors/vi_posterior.py
 914
 915
 916
 917
 918
 919
 920
 921
 922
 923
 924
 925
 926
 927
 928
 929
 930
 931
 932
 933
 934
 935
 936
 937
 938
 939
 940
 941
 942
 943
 944
 945
 946
 947
 948
 949
 950
 951
 952
 953
 954
 955
 956
 957
 958
 959
 960
 961
 962
 963
 964
 965
 966
 967
 968
 969
 970
 971
 972
 973
 974
 975
 976
 977
 978
 979
 980
 981
 982
 983
 984
 985
 986
 987
 988
 989
 990
 991
 992
 993
 994
 995
 996
 997
 998
 999
1000
1001
1002
1003
1004
1005
1006
1007
1008
1009
1010
1011
1012
1013
1014
1015
1016
1017
1018
1019
1020
1021
1022
1023
1024
1025
1026
1027
1028
1029
1030
1031
1032
1033
1034
1035
1036
1037
1038
1039
1040
1041
1042
1043
1044
1045
1046
1047
1048
1049
1050
1051
1052
1053
1054
1055
1056
1057
1058
1059
1060
1061
1062
1063
1064
1065
1066
1067
1068
1069
1070
1071
1072
1073
1074
1075
1076
1077
1078
1079
1080
1081
1082
1083
1084
1085
1086
1087
1088
1089
1090
1091
1092
1093
1094
1095
1096
1097
1098
1099
1100
1101
1102
1103
1104
1105
1106
1107
1108
1109
1110
1111
1112
1113
1114
1115
1116
1117
1118
1119
1120
1121
1122
1123
1124
1125
1126
1127
1128
1129
1130
1131
1132
1133
1134
1135
1136
1137
1138
1139
1140
1141
1142
1143
1144
1145
1146
1147
1148
1149
1150
1151
1152
1153
1154
1155
1156
1157
1158
1159
1160
1161
1162
1163
1164
1165
1166
1167
1168
1169
1170
1171
1172
1173
1174
1175
1176
1177
1178
1179
1180
1181
1182
1183
1184
1185
1186
1187
1188
1189
1190
1191
1192
1193
1194
1195
1196
1197
1198
1199
1200
def train_amortized(
    self,
    theta: Tensor,
    x: Tensor,
    n_particles: int = 128,
    learning_rate: float = 1e-3,
    gamma: float = 0.999,
    max_num_iters: int = 500,
    clip_value: float = 5.0,
    batch_size: int = 64,
    validation_fraction: float = 0.1,
    validation_batch_size: Optional[int] = None,
    validation_n_particles: Optional[int] = None,
    stop_after_iters: int = 20,
    show_progress_bar: bool = True,
    retrain_from_scratch: bool = False,
    flow_type: Optional[Union[ZukoFlowType, str]] = None,
    num_transforms: Optional[int] = None,
    hidden_features: Optional[int] = None,
    z_score_theta: Optional[Literal["none", "independent", "structured"]] = None,
    z_score_x: Optional[Literal["none", "independent", "structured"]] = None,
    params: Optional["VIPosteriorParameters"] = None,
) -> "VIPosterior":
    """Train a conditional flow q(θ|x) for amortized variational inference.

    This allows sampling from q(θ|x) for any observation x without retraining.
    Uses the ELBO (Evidence Lower Bound) objective with early stopping based on
    validation loss.

    Args:
        theta: Training θ values from simulations (num_sims, θ_dim).
        x: Training x values from simulations (num_sims, x_dim).
        n_particles: Number of samples to estimate ELBO per x.
        learning_rate: Learning rate for Adam optimizer.
        gamma: Learning rate decay per iteration.
        max_num_iters: Maximum training iterations.
        clip_value: Gradient clipping threshold.
        batch_size: Number of x values per training batch.
        validation_fraction: Fraction of data to use for validation.
        validation_batch_size: Batch size for validation loss. Defaults to
            `batch_size`.
        validation_n_particles: Number of particles for validation loss.
            Defaults to `n_particles`.
        stop_after_iters: Stop training after this many iterations without
            improvement in validation loss.
        show_progress_bar: Whether to show progress.
        retrain_from_scratch: If True, rebuild the flow from scratch.
        flow_type: Flow architecture for the variational distribution.
            Use ZukoFlowType.NSF, ZukoFlowType.MAF, etc., or a string.
            If None, uses value from params or instance default.
        num_transforms: Number of transforms in the flow. If None, uses value
            from params or instance default.
        hidden_features: Hidden layer size in the flow. If None, uses value
            from params or instance default.
        z_score_theta: Method for z-scoring θ (the parameters being modeled).
            One of "none", "independent", "structured". If None, uses value
            from params or instance default.
        z_score_x: Method for z-scoring x (the conditioning variable).
            One of "none", "independent", "structured". Use "structured" for
            structured data like images with spatial correlations. If None,
            uses value from params or instance default.
        params: Optional VIPosteriorParameters dataclass. Values are used as
            fallbacks when explicit arguments are None. Priority order:
            explicit args > params > instance attributes (from __init__).

    Returns:
        self for method chaining.
    """
    # Resolve parameters: explicit args > params dataclass > instance attrs
    if params is not None:
        # Amortized VI only supports string flow types (not VIPosterior or Callable)
        if not isinstance(params.q, str):
            raise ValueError(
                "train_amortized() only supports string flow types "
                f"(e.g., 'nsf', 'maf'), not {type(params.q).__name__}. "
                "Use set_q() to pass custom distributions for single-x VI."
            )
        # Each field only fills in when the caller did not pass it explicitly.
        if flow_type is None:
            flow_type = params.q
        if num_transforms is None:
            num_transforms = params.num_transforms
        if hidden_features is None:
            hidden_features = params.hidden_features
        if z_score_theta is None:
            z_score_theta = params.z_score_theta
        if z_score_x is None:
            z_score_x = params.z_score_x

    # Fall back to instance attributes (set in __init__ from VIPosteriorParameters)
    # NOTE: flow_type alone falls back to a hard-coded NSF default rather than
    # an instance attribute.
    if flow_type is None:
        flow_type = ZukoFlowType.NSF
    if num_transforms is None:
        num_transforms = self._num_transforms
    if hidden_features is None:
        hidden_features = self._hidden_features
    if z_score_theta is None:
        z_score_theta = self._z_score_theta
    if z_score_x is None:
        z_score_x = self._z_score_x

    # Coerce inputs to 2D float32 tensors and move them to the training device.
    theta = atleast_2d_float32_tensor(theta).to(self._device)
    x = atleast_2d_float32_tensor(x).to(self._device)

    # Validate inputs
    if theta.shape[0] != x.shape[0]:
        raise ValueError(
            f"Batch size mismatch: theta has {theta.shape[0]} samples, "
            f"x has {x.shape[0]} samples. They must match."
        )
    if len(theta) == 0:
        raise ValueError("Training data cannot be empty.")
    if not torch.isfinite(theta).all():
        raise ValueError("theta contains NaN or Inf values.")
    if not torch.isfinite(x).all():
        raise ValueError("x contains NaN or Inf values.")

    # Validate theta dimension matches prior
    # (skipped for priors with scalar event shape, e.g. 1D distributions).
    prior_event_shape = self._prior.event_shape
    if len(prior_event_shape) > 0:
        expected_theta_dim = prior_event_shape[0]
        if theta.shape[1] != expected_theta_dim:
            raise ValueError(
                f"theta dimension {theta.shape[1]} does not match prior "
                f"event shape {expected_theta_dim}."
            )

    # Validate hyperparameters
    if not 0 < validation_fraction < 1:
        raise ValueError(
            f"validation_fraction must be in (0, 1), got {validation_fraction}"
        )
    if n_particles <= 0:
        raise ValueError(f"n_particles must be positive, got {n_particles}")
    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size}")

    # Validate flow_type early to fail fast
    # (before any expensive work; string names map onto the ZukoFlowType enum).
    if isinstance(flow_type, str):
        try:
            flow_type = ZukoFlowType[flow_type.upper()]
        except KeyError:
            raise ValueError(
                f"Unknown flow type '{flow_type}'. "
                f"Supported types: {[t.name for t in ZukoFlowType]}."
            ) from None

    # Validation defaults mirror the training hyperparameters unless overridden.
    if validation_batch_size is None:
        validation_batch_size = batch_size
    if validation_n_particles is None:
        validation_n_particles = n_particles

    if validation_batch_size <= 0:
        raise ValueError(
            f"validation_batch_size must be positive, got {validation_batch_size}"
        )
    if validation_n_particles <= 0:
        raise ValueError(
            f"validation_n_particles must be positive, got {validation_n_particles}"
        )

    # Split into training and validation sets
    num_examples = len(theta)
    num_val = int(validation_fraction * num_examples)
    num_train = num_examples - num_val

    if num_val == 0:
        raise ValueError(
            "Validation set is empty. Increase validation_fraction or provide more "
            "training data."
        )
    if num_train < batch_size:
        raise ValueError(
            f"Training set size ({num_train}) is smaller than batch_size "
            f"({batch_size}). Reduce validation_fraction or batch_size."
        )

    # Random shuffle so the train/val split is not order-dependent.
    permuted_indices = torch.randperm(num_examples, device=self._device)
    train_indices = permuted_indices[:num_train]
    val_indices = permuted_indices[num_train:]

    theta_train, x_train = theta[train_indices], x[train_indices]
    x_val = x[val_indices]  # Only x needed for validation (θ sampled from q)

    # Only subsample validation x when the val set exceeds the requested batch.
    use_val_subset = validation_batch_size < x_val.shape[0]

    # Build or rebuild the conditional flow (z-score on training data only)
    # An existing flow is reused unless retrain_from_scratch is set; in that
    # case the resolved architecture args above are ignored for the reused flow.
    if self._amortized_q is None or retrain_from_scratch:
        self._amortized_q = self._build_conditional_flow(
            theta_train,
            x_train,
            flow_type=flow_type,
            num_transforms=num_transforms,
            hidden_features=hidden_features,
            z_score_theta=z_score_theta,
            z_score_x=z_score_x,
        )

    # Ensure potential_fn is on the correct device for amortized training
    self.potential_fn.to(self._device)

    # Setup optimizer
    optimizer = Adam(self._amortized_q.parameters(), lr=learning_rate)
    scheduler = ExponentialLR(optimizer, gamma=gamma)

    # Training loop with validation-based early stopping
    best_val_loss = float("inf")
    iters_since_improvement = 0
    # Snapshot initial weights so a valid state can always be restored,
    # even if training never improves on the first validation loss.
    best_state_dict = deepcopy(self._amortized_q.state_dict())

    if show_progress_bar:
        iters = tqdm(range(max_num_iters), desc="Amortized VI (ELBO)")
    else:
        iters = range(max_num_iters)

    for iteration in iters:
        # Training step
        self._amortized_q.train()
        optimizer.zero_grad()

        # Sample batch from training set (with replacement — randint, not randperm)
        idx = torch.randint(0, num_train, (batch_size,), device=self._device)
        x_batch = x_train[idx]

        train_loss = self._compute_amortized_elbo_loss(x_batch, n_particles)

        # Fail loudly on NaN/Inf loss instead of silently corrupting weights.
        if not torch.isfinite(train_loss):
            raise RuntimeError(
                f"Training loss became non-finite at iteration {iteration}: "
                f"{train_loss.item()}. This indicates numerical instability. Try:\n"
                f"  - Reducing learning_rate (currently {learning_rate})\n"
                f"  - Reducing n_particles (currently {n_particles})\n"
                f"  - Checking your potential_fn for numerical issues"
            )

        train_loss.backward()
        # Clip gradient norm to stabilize ELBO training.
        nn.utils.clip_grad_norm_(self._amortized_q.parameters(), clip_value)
        optimizer.step()
        # LR decays by `gamma` every iteration (not per epoch).
        scheduler.step()

        # Compute validation loss
        self._amortized_q.eval()
        with torch.no_grad():
            if use_val_subset:
                # Fresh random subset of validation x each iteration.
                val_idx = torch.randperm(x_val.shape[0], device=self._device)[
                    :validation_batch_size
                ]
                x_val_batch = x_val[val_idx]
            else:
                x_val_batch = x_val
            val_loss = self._compute_amortized_elbo_loss(
                x_val_batch, validation_n_particles
            ).item()

        # Check for improvement
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            iters_since_improvement = 0
            # deepcopy: state_dict tensors would otherwise alias live weights.
            best_state_dict = deepcopy(self._amortized_q.state_dict())
        else:
            iters_since_improvement += 1

        if show_progress_bar:
            assert isinstance(iters, tqdm)
            iters.set_postfix({
                "train": f"{train_loss.item():.3f}",
                "val": f"{val_loss:.3f}",
            })

        # Early stopping
        if iters_since_improvement >= stop_after_iters:
            if show_progress_bar:
                print(f"\nConverged at iteration {iteration}")
            break

    # Restore best model
    self._amortized_q.load_state_dict(best_state_dict)
    self._amortized_q.eval()
    # Switching modes invalidates any previously trained single-x posterior.
    if self._mode == "single_x":
        warnings.warn(
            "Switching from single-x to amortized mode. "
            "The previously trained single-x model will not be usable.",
            UserWarning,
            stacklevel=2,
        )
    self._mode = "amortized"

    return self