lampe.inference.cnre

Contrastive neural ratio estimation (CNRE) components.

The principle of contrastive neural ratio estimation (CNRE) is to predict whether or not a set \(\Theta = \{\theta^1, \dots, \theta^K\}\) contains the parameters that generated an observation \(x\). The elements of \(\Theta\) are drawn independently from the prior \(p(\theta)\), and the element \(\theta^k\) that generates the observation \(x \sim p(x | \theta^k)\) is chosen uniformly within \(\Theta\), such that

\[\begin{split}p(\Theta, x) & = p(\Theta) \, p(x | \Theta) \\ & = p(\Theta) \frac{1}{K} \sum_{k = 1}^K p(x | \theta^k) \\ & = p(\Theta) \, p(x) \frac{1}{K} \sum_{k = 1}^K r(\theta^k, x)\end{split}\]

where \(r(\theta, x) = \frac{p(x | \theta)}{p(x)}\) is the likelihood-to-evidence (LTE) ratio. The task is to discriminate between pairs \((\Theta, x)\) for which \(\Theta\) either does or does not contain the nominal parameters of \(x\), similar to the original NRE optimization problem. For this task, the Bayes-optimal classifier has the decision function

\[d(\Theta, x) = \frac{p(\Theta, x)}{p(\Theta, x) + \frac{1}{\gamma} p(\Theta) p(x)} = \frac{\sum_{k = 1}^K r(\theta^k, x)}{\frac{K}{\gamma} + \sum_{k = 1}^K r(\theta^k, x)} \, ,\]

where \(\gamma \in \mathbb{R}^+\) is the odds of \(\Theta\) containing versus not containing the nominal parameters. Consequently, the classifier \(d_\phi(\Theta, x)\) can equivalently be parametrized, and trained, as a composition of ratios \(r_\phi(\theta^k, x)\).
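As a sketch of this construction, the snippet below samples \((\Theta, x)\) from the generative process and evaluates \(d(\Theta, x)\) from log-ratios. It assumes a hypothetical toy model that is not part of lampe: a standard normal prior and a Gaussian simulator \(x \sim \mathcal{N}(\theta, 0.5^2)\), for which the exact ratio is \(r(\theta, x) = \mathcal{N}(x; \theta, 0.5^2) / \mathcal{N}(x; 0, 1 + 0.5^2)\). Rewriting the decision function as \(d(\Theta, x) = \operatorname{sigmoid}\big(\log \sum_k r(\theta^k, x) - \log(K / \gamma)\big)\) lets it be computed with a log-sum-exp, which avoids overflow when some ratios are large.

```python
import math
import random

SIGMA = 0.5  # hypothetical noise level of the toy simulator (assumption)


def log_normal(x: float, mean: float, var: float) -> float:
    """Log-density of N(mean, var) evaluated at x."""
    return -0.5 * (math.log(2 * math.pi * var) + (x - mean) ** 2 / var)


def log_ratio(theta: float, x: float) -> float:
    """Exact log r(theta, x) = log p(x | theta) - log p(x) for the toy model."""
    return log_normal(x, theta, SIGMA**2) - log_normal(x, 0.0, 1.0 + SIGMA**2)


def sample_joint(K: int, rng: random.Random):
    """Draw (Theta, x) ~ p(Theta, x); also return the index k of the nominal parameters."""
    Theta = [rng.gauss(0.0, 1.0) for _ in range(K)]  # theta^k ~ p(theta), i.i.d.
    k = rng.randrange(K)                              # nominal element, chosen uniformly
    x = rng.gauss(Theta[k], SIGMA)                    # x ~ p(x | theta^k)
    return Theta, x, k


def logsumexp(values):
    m = max(values)
    return m + math.log(sum(math.exp(v - m) for v in values))


def sigmoid(z: float) -> float:
    """Numerically stable logistic function."""
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def decision(Theta, x: float, gamma: float = 1.0) -> float:
    """d(Theta, x) = sum_k r / (K / gamma + sum_k r), computed from log-ratios."""
    K = len(Theta)
    logit = logsumexp([log_ratio(t, x) for t in Theta]) - math.log(K / gamma)
    return sigmoid(logit)


rng = random.Random(0)
Theta, x, k = sample_joint(4, rng)
print(decision(Theta, x, gamma=1.0))
```

In a trained model, `log_ratio` would be replaced by the network \(\log r_\phi(\theta, x)\); the exact ratio is used here only so that the toy example is self-contained.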

Note

The quantity \(d_\phi(\Theta, x)\) corresponds to \(q_\phi(y \neq 0 | \Theta, x)\) or \(1 - q_\phi(y = 0 | \Theta, x)\) in the notation of Miller et al. (2022).
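One way to see this correspondence is to view the classifier as a \((K + 1)\)-way softmax over the logits \(\log(K / \gamma), \log r(\theta^1, x), \dots, \log r(\theta^K, x)\), where class \(y = 0\) means that \(\Theta\) does not contain the nominal parameters. This softmax form is a reading of the note, not lampe code; the sketch below checks numerically, for hypothetical log-ratio values, that \(1 - q(y = 0 | \Theta, x)\) equals \(d(\Theta, x)\).

```python
import math


def softmax(logits: list[float]) -> list[float]:
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def q_classes(log_r: list[float], gamma: float) -> list[float]:
    """q(y | Theta, x) for y = 0 (no nominal parameters) and y = 1..K (assumed form)."""
    K = len(log_r)
    return softmax([math.log(K / gamma)] + log_r)


def d_direct(log_r: list[float], gamma: float) -> float:
    """d(Theta, x) = sum_k r_k / (K / gamma + sum_k r_k)."""
    K = len(log_r)
    s = sum(math.exp(v) for v in log_r)
    return s / (K / gamma + s)


log_r = [0.3, -1.2, 2.0]  # hypothetical values of log r(theta^k, x), with K = 3
gamma = 2.0
q = q_classes(log_r, gamma)
print(1.0 - q[0], d_direct(log_r, gamma))  # equal up to floating-point rounding
```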

References

Contrastive Neural Ratio Estimation (Miller et al., 2022)

Classes

CNRELoss

Creates a module that calculates the contrastive cross-entropy loss for an NRE network.

BCNRELoss

Creates a module that calculates the balanced contrastive cross-entropy loss for an NRE network.

Descriptions

class lampe.inference.cnre.CNRELoss(estimator, cardinality=2, gamma=1.0)

Creates a module that calculates the contrastive cross-entropy loss for an NRE network.

Given a batch of \(N \geq 2K\) pairs \((\theta_i, x_i)\), the module returns

\[l = \frac{1}{N} \sum_{i = 1}^N \left[ \frac{\gamma}{\gamma + 1} \ell(d_\phi(\Theta_i, x_i)) + \frac{1}{\gamma + 1} \ell(1 - d_\phi(\Theta_{i+K}, x_i)) \right]\]

where \(\ell(p) = -\log p\) is the negative log-likelihood and \(\Theta_i = \{\theta_i, \dots, \theta_{i+K-1}\}\).
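A minimal pure-Python sketch of this formula, not lampe's implementation: the hypothetical callable `log_r` stands in for the estimator and returns a scalar \(\log r_\phi(\theta, x)\), and indices beyond \(N\) are assumed to wrap around the batch, which the formula leaves implicit. Under that cyclic convention, \(N \geq 2K\) guarantees \(\theta_i \notin \Theta_{i+K}\), so the second term only sees sets that do not contain the nominal parameters of \(x_i\). Both terms are computed from the logit of \(d_\phi\), since \(\ell(d) = -\log \operatorname{sigmoid}(z)\) and \(\ell(1 - d) = -\log \operatorname{sigmoid}(-z)\).

```python
import math
from typing import Callable, Sequence


def logsumexp(values: Sequence[float]) -> float:
    m = max(values)
    return m + math.log(sum(math.exp(v - m) for v in values))


def log_sigmoid(z: float) -> float:
    """Numerically stable log(1 / (1 + exp(-z)))."""
    return -math.log1p(math.exp(-z)) if z >= 0 else z - math.log1p(math.exp(z))


def cnre_loss(
    log_r: Callable[[float, float], float],  # hypothetical stand-in for the estimator
    theta: Sequence[float],
    x: Sequence[float],
    K: int = 2,
    gamma: float = 1.0,
) -> float:
    N = len(theta)
    if N < 2 * K:
        raise ValueError("the batch must contain N >= 2K pairs")

    def logit_d(start: int, xi: float) -> float:
        # log-odds of d(Theta_start, x_i), Theta_start = {theta_start, ..., theta_(start+K-1)};
        # indices wrap around the batch (assumption)
        lr = [log_r(theta[(start + j) % N], xi) for j in range(K)]
        return logsumexp(lr) - math.log(K / gamma)

    total = 0.0
    for i in range(N):
        z_pos = logit_d(i, x[i])      # Theta_i contains theta_i
        z_neg = logit_d(i + K, x[i])  # Theta_{i+K} does not
        total += gamma / (gamma + 1) * -log_sigmoid(z_pos)  # l(d)
        total += 1 / (gamma + 1) * -log_sigmoid(-z_neg)     # l(1 - d)
    return total / N


theta = [0.1, -0.4, 1.3, 0.7]
x = [0.2, -0.5, 1.1, 0.9]
# With log r = 0 and gamma = 1, every d equals 1/2, so the loss is log 2.
print(cnre_loss(lambda t, xi: 0.0, theta, x, K=2, gamma=1.0))  # ≈ 0.6931 (log 2)
```

More generally, a constant \(\log r_\phi = 0\) gives \(d_\phi = \gamma / (\gamma + 1)\) everywhere, which is the chance-level baseline a trained estimator should improve upon.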

References

Contrastive Neural Ratio Estimation (Miller et al., 2022)
Parameters:
  • estimator (Module) – A log-ratio network \(\log r_\phi(\theta, x)\).

  • cardinality (int) – The cardinality \(K\) of \(\Theta\).

  • gamma (float) – The odds \(\gamma\) of \(\Theta\) containing versus not containing the nominal parameters.

forward(theta, x)
Parameters:
  • theta (Tensor) – The parameters \(\theta\), with shape \((N, D)\).

  • x (Tensor) – The observation \(x\), with shape \((N, L)\).

Returns:

The scalar loss \(l\).

Return type:

Tensor

class lampe.inference.cnre.BCNRELoss(estimator, cardinality=2, gamma=1.0, lmbda=100.0)

Creates a module that calculates the balanced contrastive cross-entropy loss for an NRE network.

Given a batch of \(N \geq 2K\) pairs \((\theta_i, x_i)\), the module returns

\[\begin{split}l & = \frac{1}{N} \sum_{i = 1}^N \left[ \frac{\gamma}{\gamma + 1} \ell(d_\phi(\Theta_i, x_i)) + \frac{1}{\gamma + 1} \ell(1 - d_\phi(\Theta_{i+K}, x_i)) \right] \\ & + \lambda \left(1 - \frac{1}{N} \sum_{i = 1}^N \left[ d_\phi(\Theta_i, x_i) + d_\phi(\Theta_{i+K}, x_i) \right] \right)^2\end{split}\]

where \(\ell(p) = -\log p\) is the negative log-likelihood and \(\Theta_i = \{\theta_i, \dots, \theta_{i+K-1}\}\).
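A minimal pure-Python sketch of the balanced loss, not lampe's implementation. As before, the hypothetical callable `log_r` stands in for the estimator, and indices beyond \(N\) are assumed to wrap around the batch. The first part is the contrastive cross-entropy; the second part penalizes the squared deviation of the batch average of \(d_\phi(\Theta_i, x_i) + d_\phi(\Theta_{i+K}, x_i)\) from 1.

```python
import math
from typing import Callable, Sequence


def logsumexp(values: Sequence[float]) -> float:
    m = max(values)
    return m + math.log(sum(math.exp(v - m) for v in values))


def log_sigmoid(z: float) -> float:
    """Numerically stable log(1 / (1 + exp(-z)))."""
    return -math.log1p(math.exp(-z)) if z >= 0 else z - math.log1p(math.exp(z))


def bcnre_loss(
    log_r: Callable[[float, float], float],  # hypothetical stand-in for the estimator
    theta: Sequence[float],
    x: Sequence[float],
    K: int = 2,
    gamma: float = 1.0,
    lmbda: float = 100.0,
) -> float:
    N = len(theta)
    if N < 2 * K:
        raise ValueError("the batch must contain N >= 2K pairs")

    def logit_d(start: int, xi: float) -> float:
        # log-odds of d(Theta_start, x_i); indices wrap around the batch (assumption)
        lr = [log_r(theta[(start + j) % N], xi) for j in range(K)]
        return logsumexp(lr) - math.log(K / gamma)

    ce, balance = 0.0, 0.0
    for i in range(N):
        z_pos, z_neg = logit_d(i, x[i]), logit_d(i + K, x[i])
        # contrastive cross-entropy: l(d(Theta_i, x_i)) and l(1 - d(Theta_{i+K}, x_i))
        ce += -gamma / (gamma + 1) * log_sigmoid(z_pos) - log_sigmoid(-z_neg) / (gamma + 1)
        # d(Theta_i, x_i) + d(Theta_{i+K}, x_i), with d = sigmoid(z)
        balance += math.exp(log_sigmoid(z_pos)) + math.exp(log_sigmoid(z_neg))
    return ce / N + lmbda * (1.0 - balance / N) ** 2


theta = [0.1, -0.4, 1.3, 0.7]
x = [0.2, -0.5, 1.1, 0.9]
# A constant estimator with log r = 2 always leans toward "contains"
# and is heavily penalized by the balancing term.
print(bcnre_loss(lambda t, xi: 2.0, theta, x, K=2, gamma=1.0, lmbda=100.0))
```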

References

Balancing Simulation-based Inference for Conservative Posteriors (Delaunoy et al., 2023)
Parameters:
  • estimator (Module) – A log-ratio network \(\log r_\phi(\theta, x)\).

  • cardinality (int) – The cardinality \(K\) of \(\Theta\).

  • gamma (float) – The odds \(\gamma\) of \(\Theta\) containing versus not containing the nominal parameters.

  • lmbda (float) – The weight \(\lambda\) controlling the strength of the balancing condition.
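For \(\gamma = 1\), the exact classifier is compatible with the balancing condition: then \(d(\Theta, x) = p(\Theta, x) / (p(\Theta, x) + p(\Theta) p(x))\), so \(\mathbb{E}_{p(\Theta, x)}[d] + \mathbb{E}_{p(\Theta) p(x)}[d] = \int p(\Theta, x) \, \mathrm{d}\Theta \, \mathrm{d}x = 1\), and the quantity inside the square weighted by \(\lambda\) has zero expectation. The Monte Carlo sketch below checks this on a hypothetical Gaussian toy model (standard normal prior, \(x \sim \mathcal{N}(\theta, 0.5^2)\)) whose exact ratio is known; it is an illustration, not part of lampe.

```python
import math
import random

SIGMA = 0.5  # hypothetical noise level of the toy simulator (assumption)


def log_normal(x: float, mean: float, var: float) -> float:
    return -0.5 * (math.log(2 * math.pi * var) + (x - mean) ** 2 / var)


def log_ratio(theta: float, x: float) -> float:
    """Exact log r(theta, x) for theta ~ N(0, 1) and x | theta ~ N(theta, SIGMA^2)."""
    return log_normal(x, theta, SIGMA**2) - log_normal(x, 0.0, 1.0 + SIGMA**2)


def exact_d(Theta, x: float) -> float:
    """Bayes-optimal d(Theta, x) for gamma = 1: sum_k r / (K + sum_k r)."""
    s = sum(math.exp(log_ratio(t, x)) for t in Theta)
    return s / (len(Theta) + s)


def balance_estimate(K: int = 2, M: int = 20_000, seed: int = 0) -> float:
    """Monte Carlo estimate of E_{p(Theta, x)}[d] + E_{p(Theta) p(x)}[d]."""
    rng = random.Random(seed)
    total = 0.0
    for _ in range(M):
        Theta = [rng.gauss(0.0, 1.0) for _ in range(K)]
        x = rng.gauss(Theta[rng.randrange(K)], SIGMA)  # (Theta, x) ~ p(Theta, x)
        Theta_marg = [rng.gauss(0.0, 1.0) for _ in range(K)]  # (Theta_marg, x) ~ p(Theta) p(x)
        total += exact_d(Theta, x) + exact_d(Theta_marg, x)
    return total / M


print(balance_estimate())  # close to 1, up to Monte Carlo error
```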

forward(theta, x)
Parameters:
  • theta (Tensor) – The parameters \(\theta\), with shape \((N, D)\).

  • x (Tensor) – The observation \(x\), with shape \((N, L)\).

Returns:

The scalar loss \(l\).

Return type:

Tensor