lampe.inference.amnre
Arbitrary marginal neural ratio estimation (AMNRE) components.
The principle of AMNRE is to introduce, as input to the classifier, a binary mask \(b \in \{0, 1\}^D\) indicating a subset of parameters \(\theta_b = (\theta_i: b_i = 1)\) of interest. Intuitively, this allows the classifier to distinguish subspaces and to learn a different ratio for each of them. Formally, the classifier network takes the form \(d_\phi(\theta_b, x, b)\) and the optimization problem becomes
\[\arg\min_\phi \mathbb{E}_{p(\theta, x) P(b)} \big[ \ell(d_\phi(\theta_b, x, b)) \big] + \mathbb{E}_{p(\theta) p(x) P(b)} \big[ \ell(1 - d_\phi(\theta_b, x, b)) \big]\]
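where \(P(b)\) is a binary mask distribution and \(\ell(p) = -\log p\) is the negative log-likelihood. In this context, the Bayes optimal classifier is
\[d(\theta_b, x, b) = \frac{p(\theta_b, x)}{p(\theta_b, x) + p(\theta_b) \, p(x)} = \sigma(\log r(\theta_b, x))\]
where \(\sigma\) is the sigmoid function.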
Therefore, a classifier network trained for AMNRE gives access to an estimator \(\log r_\phi(\theta_b, x, b)\) of all marginal likelihood-to-evidence (LTE) log-ratios \(\log r(\theta_b, x)\), where \(r(\theta_b, x) = p(x | \theta_b) / p(x)\).
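The claim that the Bayes optimal classifier recovers every marginal log-ratio can be checked numerically. The sketch below uses a small, hypothetical joint distribution over two binary parameters and one binary observation (\(D = 2\), \(L = 1\)); the table `JOINT` and every helper name are illustrative assumptions, not part of `lampe`. For each non-empty mask it computes the optimal classifier as \(p(\theta_b, x) / (p(\theta_b, x) + p(\theta_b) p(x))\), summing out the unselected parameters, and asserts that it matches \(\sigma(\log r(\theta_b, x))\).

```python
import itertools
import math

# Hypothetical toy joint p(theta1, theta2, x) over binary variables.
# Keys are (theta1, theta2, x); the probabilities sum to 1.
JOINT = {
    (0, 0, 0): 0.20, (0, 0, 1): 0.05,
    (0, 1, 0): 0.10, (0, 1, 1): 0.15,
    (1, 0, 0): 0.05, (1, 0, 1): 0.20,
    (1, 1, 0): 0.15, (1, 1, 1): 0.10,
}


def subset(theta, b):
    """theta_b = (theta_i : b_i = 1)."""
    return tuple(t for t, bi in zip(theta, b) if bi)


def marginal_joint(theta_b, x, b):
    """p(theta_b, x): sum out the components of theta with b_i = 0."""
    return sum(p for (*theta, xi), p in JOINT.items()
               if xi == x and subset(theta, b) == theta_b)


def marginal_theta(theta_b, b):
    """p(theta_b): additionally sum out x."""
    return sum(marginal_joint(theta_b, x, b) for x in (0, 1))


def marginal_x(x):
    """p(x): sum out all parameters."""
    return sum(p for (*_, xi), p in JOINT.items() if xi == x)


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def bayes_optimal(theta_b, x, b):
    joint = marginal_joint(theta_b, x, b)
    product = marginal_theta(theta_b, b) * marginal_x(x)
    return joint / (joint + product)


def log_ratio(theta_b, x, b):
    """log r(theta_b, x) = log p(theta_b, x) - log p(theta_b) p(x)."""
    return math.log(marginal_joint(theta_b, x, b) / (marginal_theta(theta_b, b) * marginal_x(x)))


# d(theta_b, x, b) == sigmoid(log r(theta_b, x)) for every non-empty mask.
for b in [(1, 0), (0, 1), (1, 1)]:
    for theta in itertools.product((0, 1), repeat=2):
        for x in (0, 1):
            tb = subset(theta, b)
            assert math.isclose(bayes_optimal(tb, x, b), sigmoid(log_ratio(tb, x, b)))
```

For example, with \(b = (1, 0)\), \(\theta_b = (0)\) and \(x = 0\), the marginals are \(p(\theta_b, x) = 0.3\), \(p(\theta_b) = 0.5\) and \(p(x) = 0.5\), so the optimal classifier outputs \(0.3 / 0.55 = 6/11\).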
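Classes
AMNRE – Creates an arbitrary marginal neural ratio estimation (AMNRE) network.
AMNRELoss – Creates a module that calculates the cross-entropy loss for an AMNRE network.
Descriptions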
- class lampe.inference.amnre.AMNRE(theta_dim, x_dim, *args, **kwargs)
Creates an arbitrary marginal neural ratio estimation (AMNRE) network.
- Parameters:
theta_dim (int) – The dimensionality \(D\) of the parameter space.
x_dim (int) – The dimensionality \(L\) of the observation space.
args – Positional arguments passed to lampe.inference.nre.NRE.
kwargs – Keyword arguments passed to lampe.inference.nre.NRE.
- forward(theta, x, b)
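- Parameters:
theta (Tensor) – The parameters \(\theta\), with shape \((*, D)\).
x (Tensor) – The observation \(x\), with shape \((*, L)\).
b (Tensor) – A binary mask \(b\), with shape \((*, D)\).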
- Returns:
The log-ratio \(\log r_\phi(\theta_b, x, b)\), with shape \((*,)\).
- Return type:
Tensor
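One way to give a classifier access to the mask is to concatenate the masked parameters \(\theta \odot b\), the mask \(b\) itself, and the observation \(x\) into a single input. The sketch below assumes that layout and uses a hypothetical single-sample linear model, `ToyLinearLogRatio`, in place of the NRE network; `amnre_input` and `ToyLinearLogRatio` are illustrative names, not `lampe`'s implementation, which operates on batched tensors of shape \((*, D)\) and \((*, L)\).

```python
import math
import random


def amnre_input(theta, x, b):
    """Hypothetical input layout [theta * b, b, x] for a mask-conditioned classifier."""
    assert len(theta) == len(b), "theta and b must both have length D"
    masked = [t * bi for t, bi in zip(theta, b)]  # theta ⊙ b: zero out unselected parameters
    return masked + [float(bi) for bi in b] + list(x)


class ToyLinearLogRatio:
    """Hypothetical stand-in for log r_phi(theta_b, x, b): a linear model on amnre_input."""

    def __init__(self, theta_dim, x_dim, seed=0):
        rng = random.Random(seed)
        # One weight per input feature: D masked parameters, D mask bits, L observations.
        self.weights = [rng.gauss(0.0, 1.0) for _ in range(2 * theta_dim + x_dim)]
        self.bias = 0.0

    def __call__(self, theta, x, b):
        z = amnre_input(theta, x, b)
        return sum(w * zi for w, zi in zip(self.weights, z)) + self.bias


net = ToyLinearLogRatio(theta_dim=3, x_dim=2)
theta, x = [0.5, -1.2, 2.0], [0.3, 0.7]
b = [1, 0, 1]  # selects theta_b = (theta_1, theta_3)
log_r = net(theta, x, b)
d = 1.0 / (1.0 + math.exp(-log_r))  # classifier output d_phi = sigmoid(log r_phi)

# Changing a masked-out component (theta_2) leaves the output unchanged.
assert math.isclose(net([0.5, 99.0, 2.0], x, b), log_r)
```

Passing \(b\) explicitly matters because \(\theta \odot b\) alone cannot distinguish a masked-out component from a selected one whose value happens to be exactly zero.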
- class lampe.inference.amnre.AMNRELoss(estimator, mask_dist)
Creates a module that calculates the cross-entropy loss for an AMNRE network.
Given a batch of \(N \geq 2\) pairs \((\theta_i, x_i)\), the module returns
\[l = \frac{1}{2N} \sum_{i = 1}^N \Big[ \ell(d_\phi(\theta_i \odot b_i, x_i, b_i)) + \ell(1 - d_\phi(\theta_{i+1} \odot b_i, x_i, b_i)) \Big]\]
where the binary masks \(b_i\) are sampled from a distribution \(P(b)\), \(\theta_{N+1} = \theta_1\), and \(d_\phi = \sigma(\log r_\phi)\).
- Parameters:
estimator (Module) – A log-ratio network \(\log r_\phi(\theta, x, b)\).
mask_dist (Distribution) – A binary mask distribution \(P(b)\).
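The loss above can be sketched without a deep-learning framework. The snippet below is a hypothetical plain-Python rendering of the formula: `log_ratio` stands in for the estimator, `sample_mask` is an assumed \(P(b)\) (uniform over non-empty masks), and indices wrap so that \(\theta_{N+1} = \theta_1\). To stay numerically stable it uses \(\ell(\sigma(z)) = \log(1 + e^{-z})\) and \(\ell(1 - \sigma(z)) = \log(1 + e^{z})\), with \(\ell(p) = -\log p\). None of these names belong to `lampe`'s API.

```python
import math
import random


def softplus(z):
    # Numerically stable log(1 + exp(z)).
    return max(z, 0.0) + math.log1p(math.exp(-abs(z)))


def sample_mask(dim, rng):
    """Hypothetical P(b): uniform over the non-empty binary masks of length dim."""
    while True:
        b = [rng.random() < 0.5 for _ in range(dim)]
        if any(b):  # reject the all-zero mask, which selects no parameter
            return b


def amnre_loss(log_ratio, thetas, xs, masks):
    """Sketch of 1/(2N) sum_i [l(d(theta_i*b_i, x_i, b_i)) + l(1 - d(theta_{i+1}*b_i, x_i, b_i))]."""
    n = len(thetas)
    assert n >= 2 and len(xs) == n and len(masks) == n
    total = 0.0
    for i in range(n):
        b = masks[i]
        theta_joint = [t * bi for t, bi in zip(thetas[i], b)]  # matched pair (theta_i, x_i)
        theta_shift = [t * bi for t, bi in zip(thetas[(i + 1) % n], b)]  # mismatched pair, wraps at N
        total += softplus(-log_ratio(theta_joint, xs[i], b))  # l(sigmoid(log r))
        total += softplus(log_ratio(theta_shift, xs[i], b))  # l(1 - sigmoid(log r))
    return total / (2 * n)


def zero_log_ratio(theta, x, b):
    # Hypothetical estimator that always predicts log r = 0, i.e. d = 1/2.
    return 0.0


rng = random.Random(0)
thetas = [[rng.gauss(0.0, 1.0) for _ in range(3)] for _ in range(4)]
xs = [[th[0] + rng.gauss(0.0, 0.1)] for th in thetas]
masks = [sample_mask(3, rng) for _ in thetas]
loss = amnre_loss(zero_log_ratio, thetas, xs, masks)  # equals log(2) up to rounding
```

With an estimator that always returns \(\log r_\phi = 0\), every term equals \(\log 2\), so the loss is \(\log 2\) for any batch; a trained estimator should drive it lower by separating matched from mismatched pairs.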