lampe.inference.amnre

Arbitrary marginal neural ratio estimation (AMNRE) components.

The principle of AMNRE is to introduce, as input to the classifier, a binary mask \(b \in \{0, 1\}^D\) indicating a subset of parameters \(\theta_b = (\theta_i: b_i = 1)\) of interest. Intuitively, this allows the classifier to distinguish subspaces and to learn a different ratio for each of them. Formally, the classifier network takes the form \(d_\phi(\theta_b, x, b)\) and the optimization problem becomes

\[\arg\min_\phi \frac{1}{2} \mathbb{E}_{p(\theta, x) P(b)} \big[ \ell(d_\phi(\theta_b, x, b)) \big] + \frac{1}{2} \mathbb{E}_{p(\theta)p(x) P(b)} \big[ \ell(1 - d_\phi(\theta_b, x, b)) \big],\]

where \(\ell(p) = -\log p\) is the negative log-likelihood and \(P(b)\) is a binary mask distribution. In this context, the Bayes optimal classifier is

\[d(\theta_b, x, b) = \frac{p(\theta_b, x)}{p(\theta_b, x) + p(\theta_b) p(x)} = \frac{r(\theta_b, x)}{1 + r(\theta_b, x)} .\]

Therefore, a classifier network trained for AMNRE gives access to an estimator \(\log r_\phi(\theta_b, x, b)\) of all marginal likelihood-to-evidence (LTE) log-ratios \(\log r(\theta_b, x)\).
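The relation between the optimal classifier and the ratio, \(d = r / (1 + r)\), is equivalent to \(d = \sigma(\log r)\), with \(\sigma\) the logistic sigmoid, which is why a network that outputs a log-ratio can be trained as a classifier. The snippet below is a minimal numerical check of that identity using only the standard library; the function names are hypothetical and not part of lampe.

```python
import math


def classifier_from_log_ratio(log_r: float) -> float:
    """Maps a log-ratio to a classifier output, d = r / (1 + r)."""
    r = math.exp(log_r)
    return r / (1.0 + r)


def log_ratio_from_classifier(d: float) -> float:
    """Inverts d = r / (1 + r), i.e. log r = log d - log(1 - d)."""
    return math.log(d) - math.log1p(-d)


log_r = 1.5
d = classifier_from_log_ratio(log_r)

# The same value obtained through the logistic sigmoid.
d_sigmoid = 1.0 / (1.0 + math.exp(-log_r))

# Going back from the classifier output recovers the log-ratio.
recovered = log_ratio_from_classifier(d)
```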

References

Arbitrary Marginal Neural Ratio Estimation for Simulation-based Inference (Rozet et al., 2021)

Classes

AMNRE

Creates an arbitrary marginal neural ratio estimation (AMNRE) network.

AMNRELoss

Creates a module that calculates the cross-entropy loss for an AMNRE network.

Descriptions

class lampe.inference.amnre.AMNRE(theta_dim, x_dim, *args, **kwargs)

Creates an arbitrary marginal neural ratio estimation (AMNRE) network.

Parameters:
  • theta_dim (int) – The dimensionality \(D\) of the parameter space.

  • x_dim (int) – The dimensionality \(L\) of the observation space.

  • args – Positional arguments passed to lampe.inference.nre.NRE.

  • kwargs – Keyword arguments passed to lampe.inference.nre.NRE.

forward(theta, x, b)
Parameters:
  • theta (Tensor) – The parameters \(\theta\), with shape \((*, D)\), or a subset \(\theta_b\).

  • x (Tensor) – The observation \(x\), with shape \((*, L)\).

  • b (BoolTensor) – A binary mask \(b\), with shape \((*, D)\).

Returns:

The log-ratio \(\log r_\phi(\theta_b, x, b)\), with shape \((*,)\).

Return type:

Tensor
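To illustrate the call signature of forward(theta, x, b), here is a minimal sketch of a mask-conditioned log-ratio network written in plain PyTorch. It is not lampe's implementation: ToyMaskedRatioNet, its hidden size, and the way the mask is encoded are assumptions made for illustration, and the sketch only handles the case where theta has the full shape \((*, D)\), not a pre-selected subset \(\theta_b\).

```python
import torch
import torch.nn as nn


class ToyMaskedRatioNet(nn.Module):
    """Hypothetical stand-in for a mask-conditioned log-ratio network."""

    def __init__(self, theta_dim: int, x_dim: int, hidden: int = 64):
        super().__init__()
        # Inputs: masked theta (D values), observation x (L values), mask b (D values).
        self.net = nn.Sequential(
            nn.Linear(2 * theta_dim + x_dim, hidden),
            nn.ReLU(),
            nn.Linear(hidden, 1),
        )

    def forward(self, theta, x, b):
        # Zero out the parameters that are not of interest (theta ⊙ b).
        theta = theta * b
        # Pass the mask itself so the network knows which subspace is queried.
        features = torch.cat([theta, x, b.to(theta.dtype)], dim=-1)
        return self.net(features).squeeze(-1)  # shape (*,)


D, L, N = 3, 5, 4
net = ToyMaskedRatioNet(D, L)

theta = torch.randn(N, D)
x = torch.randn(N, L)
b = torch.tensor([True, False, True]).expand(N, D)  # select theta_1 and theta_3

log_r = net(theta, x, b)  # one log-ratio per (theta, x) pair
```

Because the masked-out entries are zeroed before entering the network, the output of this sketch does not depend on the parameters for which \(b_i = 0\).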

class lampe.inference.amnre.AMNRELoss(estimator, mask_dist)

Creates a module that calculates the cross-entropy loss for an AMNRE network.

Given a batch of \(N \geq 2\) pairs \((\theta_i, x_i)\), the module returns

\[l = \frac{1}{2N} \sum_{i = 1}^N \big[ \ell(d_\phi(\theta_i \odot b_i, x_i, b_i)) + \ell(1 - d_\phi(\theta_{i+1} \odot b_i, x_i, b_i)) \big]\]

where the binary masks \(b_i\) are sampled from a distribution \(P(b)\) and the index \(i + 1\) is taken cyclically, so that \(\theta_{N+1} = \theta_1\).

Parameters:
  • estimator (Module) – A log-ratio network \(\log r_\phi(\theta, x, b)\).

  • mask_dist (Distribution) – A binary mask distribution \(P(b)\).

forward(theta, x)
Parameters:
  • theta (Tensor) – The parameters \(\theta\), with shape \((N, D)\).

  • x (Tensor) – The observation \(x\), with shape \((N, L)\).

Returns:

The scalar loss \(l\).

Return type:

Tensor
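The loss above can be sketched in plain PyTorch as follows. This is an illustrative reading of the formula, not lampe's source: it assumes \(\ell(p) = -\log p\) and \(d_\phi = \sigma(\log r_\phi)\), uses an independent Bernoulli distribution as an example of \(P(b)\), and obtains \(\theta_{i+1}\) by cyclically shifting the batch. The names amnre_loss_sketch, mask_prob, and zero_log_ratio are hypothetical.

```python
import torch
import torch.nn.functional as F


def amnre_loss_sketch(log_ratio, theta, x, mask_prob=0.5):
    """Cross-entropy loss over joint pairs and shifted (marginal) pairs."""
    # Sample one binary mask b_i per pair; Bernoulli is an assumed choice of P(b).
    b = torch.bernoulli(torch.full_like(theta, mask_prob)).bool()

    # theta_{i+1}: shift the batch by one, wrapping the last index to the first.
    theta_prime = torch.roll(theta, shifts=-1, dims=0)

    log_r = log_ratio(theta * b, x, b)              # pairs from p(theta, x)
    log_r_prime = log_ratio(theta_prime * b, x, b)  # pairs from p(theta) p(x)

    # With d = sigmoid(log r): l(d) = -log sigmoid(log r)
    # and l(1 - d) = -log sigmoid(-log r).
    l1 = -F.logsigmoid(log_r).mean()
    l0 = -F.logsigmoid(-log_r_prime).mean()

    return (l1 + l0) / 2


def zero_log_ratio(theta, x, b):
    # Placeholder estimator that always predicts log r = 0, i.e. d = 1/2.
    return theta.new_zeros(theta.shape[:-1])


theta, x = torch.randn(8, 3), torch.randn(8, 5)
loss = amnre_loss_sketch(zero_log_ratio, theta, x)
```

With the placeholder estimator, every classifier output is \(1/2\), so the loss equals \(\log 2\), the value of an uninformative classifier. A trained estimator should fall below it.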