1. Simulators and datasets

In lampe, a simulator can be any Python callable that takes (a vector of) parameters \(\theta\) as input, in the form of a NumPy array or a PyTorch tensor, and returns a stochastic observation \(x\) as output. As such, a simulator implicitly defines a likelihood \(p(x | \theta)\). The prior distribution \(p(\theta)\) must be a PyTorch distribution implementing the sample and log_prob methods. Together, the prior and the simulator form a joint distribution \(p(\theta, x) = p(\theta) p(x | \theta)\) from which parameter-observation pairs \((\theta, x)\) can be drawn.

This notebook walks you through setting up a prior and a simulator, sampling pairs from the joint, saving pairs to disk and, finally, loading pairs from disk.

%matplotlib inline

import lampe
import torch
import zuko

from itertools import chain, islice
from tqdm import tqdm

1.1. Prior

In Bayesian inference problems, it is quite common to have a uniform prior distribution \(p(\theta)\). We recommend using the BoxUniform class from the zuko package to create a multivariate uniform distribution between a lower and an upper bound.

For instance, let’s say \(\theta\) is uniform over the domain \([-1, 1]^3\).

LOWER = -torch.ones(3)
UPPER = torch.ones(3)

prior = zuko.distributions.BoxUniform(LOWER, UPPER)
prior
BoxUniform(low: torch.Size([3]), high: torch.Size([3]))

Using the sample method, parameters \(\theta\) can be sampled from the newly created prior distribution. The log-density can also be evaluated for given parameters using the log_prob method.

theta = prior.sample()
theta
tensor([-0.7064,  0.0166, -0.3484])
prior.log_prob(theta)
tensor(-2.0794)
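
As a sanity check on the value above, and assuming the box-uniform density factorizes over independent dimensions, the log-density is constant for any \(\theta\) inside \([-1, 1]^3\). Writing \(l_i\) and \(u_i\) for the lower and upper bounds (symbols introduced here for illustration):

\[ \log p(\theta) = -\sum_{i=1}^{3} \log(u_i - l_i) = -3 \log 2 \approx -2.0794 \]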

1.2. Simulator

As mentioned previously, the only constraints for the simulator are that the vector of parameters \(\theta\) and the observation \(x\) must either be NumPy arrays or PyTorch tensors. The callable is otherwise unconstrained, which allows for a large variety of simulators, including existing ones and ones implemented in programming languages other than Python.

Let’s consider the following generative process as our simulator:

\[\begin{split} x = \begin{pmatrix} \theta_1 + \theta_2 \times \theta_3 \\ \theta_1 \times \theta_2 + \theta_3 \end{pmatrix} + 0.05 \times \varepsilon \end{split}\]

where \(\varepsilon\) is a 2-d standard Gaussian random variable.

def simulator(theta: torch.Tensor) -> torch.Tensor:
    x = torch.stack([
        theta[..., 0] + theta[..., 1] * theta[..., 2],
        theta[..., 0] * theta[..., 1] + theta[..., 2],
    ], dim=-1)

    return x + 0.05 * torch.randn_like(x)

theta = prior.sample()
x = simulator(theta)

print(theta, x, sep='\n')
tensor([ 0.9121, -0.5785,  0.7630])
tensor([0.4579, 0.3166])
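
To connect the printed sample to the generative process, the sketch below recomputes the noise-free part of the simulator by hand, using the values copied from the output above, and expresses the residuals in units of the noise scale 0.05. The variable names are illustrative only.

```python
# Values copied from the printed output above.
theta = (0.9121, -0.5785, 0.7630)
x_obs = (0.4579, 0.3166)

# Noise-free part of the generative process.
mean = (
    theta[0] + theta[1] * theta[2],
    theta[0] * theta[1] + theta[2],
)

# Residuals in units of the noise standard deviation (0.05).
z = tuple((xo - m) / 0.05 for xo, m in zip(x_obs, mean))

print(mean)  # approximately (0.4707, 0.2354)
print(z)     # approximately (-0.26, 1.62)
```

Both residuals are within about two standard deviations, which is what one would expect from Gaussian noise of scale 0.05.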

This implementation is vectorized, which means that it can process a batch of parameters at once.

theta = prior.sample((4,))
x = simulator(theta)

print(theta, x, sep='\n')
tensor([[ 0.9290, -0.1757, -0.7820],
        [ 0.5494, -0.7134, -0.3816],
        [-0.2231, -0.6708,  0.0089],
        [-0.3692,  0.8460,  0.4389]])
tensor([[ 1.1476, -1.0494],
        [ 0.9043, -0.7716],
        [-0.3210,  0.2096],
        [-0.0251,  0.1417]])

1.3. Sampling

Given a prior and a simulator, the JointLoader class provided by the lampe.data module creates an iterable DataLoader of batched parameter-observation pairs \((\theta, x) \sim p(\theta, x)\).

loader = lampe.data.JointLoader(prior, simulator, batch_size=256)

for theta, x in tqdm(islice(loader, 256), total=256):
    pass
100%|██████████| 256/256 [00:03<00:00, 84.28it/s]

If the simulator takes a NumPy array as input, it should be indicated with numpy=True. Additionally, if the simulator is vectorized, it is recommended to indicate it with vectorized=True, as it can improve sampling performance significantly.

loader = lampe.data.JointLoader(prior, simulator, batch_size=256, vectorized=True)

for theta, x in tqdm(islice(loader, 256), total=256):
    pass
100%|██████████| 256/256 [00:00<00:00, 10390.68it/s]
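
For the NumPy case, a simulator might look like the following sketch. It is a NumPy counterpart of the simulator above; the name simulator_np is hypothetical, and the parameters are drawn with NumPy here only to keep the snippet self-contained.

```python
import numpy as np

def simulator_np(theta: np.ndarray) -> np.ndarray:
    """NumPy counterpart of the simulator above (hypothetical name)."""
    x = np.stack([
        theta[..., 0] + theta[..., 1] * theta[..., 2],
        theta[..., 0] * theta[..., 1] + theta[..., 2],
    ], axis=-1)

    # Additive Gaussian noise with standard deviation 0.05.
    return x + 0.05 * np.random.standard_normal(x.shape)

# A batch of 4 parameter vectors, uniform over [-1, 1]^3.
theta = np.random.uniform(-1.0, 1.0, size=(4, 3))
x = simulator_np(theta)

print(x.shape)  # (4, 2)
```

Since this function is also vectorized, it would presumably be passed to JointLoader with both numpy=True and vectorized=True.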

1.4. Loading in memory

If the simulator is fast or inexpensive, it is reasonable to generate pairs \((\theta, x)\) on demand. Otherwise, the pairs have to be generated ahead of time. The lampe.data module provides the JointDataset class to interact with in-memory pairs \((\theta, x)\). This is ideal when your data fits in RAM.

theta = prior.sample((1024,))
x = simulator(theta)

dataset = lampe.data.JointDataset(theta, x)

print(*dataset[42], sep='\n')
tensor([ 0.6757, -0.3787,  0.9581])
tensor([0.3410, 0.7707])
for theta, x in tqdm(dataset):
    pass
100%|██████████| 1024/1024 [00:00<00:00, 282991.85it/s]

JointDataset can be wrapped in a DataLoader to enable batching and shuffling.
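
A minimal sketch of such wrapping is given below. To stay runnable with PyTorch alone, it uses torch.utils.data.TensorDataset as a stand-in for JointDataset, assuming the latter behaves like a standard map-style dataset of \((\theta, x)\) pairs. The observations are random placeholders, not simulator outputs.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Stand-in for lampe.data.JointDataset(theta, x) (assumption: map-style dataset).
theta = torch.rand(1024, 3) * 2 - 1  # uniform over [-1, 1]^3
x = torch.randn(1024, 2)             # placeholder observations
dataset = TensorDataset(theta, x)

# Batches of 64 pairs, reshuffled at every epoch.
loader = DataLoader(dataset, batch_size=64, shuffle=True)

theta_batch, x_batch = next(iter(loader))
print(theta_batch.shape, x_batch.shape)
```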

1.5. Saving on disk

If your data does not fit in RAM or you need to reuse it later, you may want to store it on disk. The HDF5 file format is commonly used for this purpose, as it was specifically designed to hold large amounts of numerical data. The lampe.data module provides the H5Dataset class to help load and store pairs \((\theta, x)\) in HDF5 files. The H5Dataset.store function takes an iterable of batched pairs \((\theta, x)\) as input and stores them into a new HDF5 file. The iterable can be a precomputed list, a custom generator or even a JointLoader instance.

pairs = []

for _ in range(256):
    theta = prior.sample((256,))
    x = simulator(theta)

    pairs.append((theta, x))

lampe.data.H5Dataset.store(pairs, 'data_0.h5', size=2**16)
100%|██████████| 65536/65536 [00:00<00:00, 518158.51sample/s]
def generate():
    while True:
        theta = prior.sample((256,))
        x = simulator(theta)

        yield theta, x

lampe.data.H5Dataset.store(generate(), 'data_1.h5', size=2**16)
100%|██████████| 65536/65536 [00:00<00:00, 375452.15sample/s]
lampe.data.H5Dataset.store(loader, 'data_2.h5', size=2**16)
100%|██████████| 65536/65536 [00:00<00:00, 369042.89sample/s]
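
Note that generate() never terminates on its own, so the size argument is what bounds the number of stored pairs. The pure-Python sketch below illustrates the idea of consuming batches from an endless stream until a target size is reached. It is not lampe's implementation; the helper take and the truncation of the last batch are assumptions made for illustration.

```python
import random

def generate(batch_size=4):
    """Endless stream of batches of (theta, x) pairs; pure-Python stand-in."""
    while True:
        thetas = [random.uniform(-1, 1) for _ in range(batch_size)]
        xs = [t + 0.05 * random.gauss(0, 1) for t in thetas]
        yield thetas, xs

def take(batches, size):
    """Collect pairs from batches until exactly `size` pairs are gathered."""
    thetas, xs = [], []
    for theta, x in batches:
        room = size - len(thetas)   # number of pairs still missing
        thetas.extend(theta[:room])
        xs.extend(x[:room])
        if len(thetas) >= size:
            break                   # stop pulling from the endless stream
    return thetas, xs

thetas, xs = take(generate(), size=10)
print(len(thetas), len(xs))  # 10 10
```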

1.6. Loading from disk

Now that the pairs are stored, we need to load them. The H5Dataset class creates an IterableDataset of pairs \((\theta, x)\) from an HDF5 file. The pairs are dynamically loaded, meaning they are read from disk on demand instead of being cached in memory. This allows for very large datasets that do not even fit in memory.

dataset = lampe.data.H5Dataset('data_0.h5')

for i in tqdm(range(len(dataset))):
    theta, x = dataset[i]
100%|██████████| 65536/65536 [00:53<00:00, 1227.49it/s]

However, as it can be slow to read pairs from disk one by one, H5Dataset implements a custom __iter__ method which loads pairs in chunks to reduce the number of disk reads.

for theta, x in tqdm(dataset):
    pass
100%|██████████| 65536/65536 [00:00<00:00, 188557.83it/s]
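
The gap between the two timings above comes down to the number of read operations. The sketch below illustrates this with a hypothetical CountingStore class that stands in for an on-disk array and counts read calls. It is an illustration of chunked iteration in general, not of lampe's internals, and the chunk size of 256 is arbitrary.

```python
class CountingStore:
    """Hypothetical stand-in for an on-disk array that counts read calls."""

    def __init__(self, data):
        self.data = list(data)
        self.reads = 0

    def read(self, start, stop):
        self.reads += 1
        return self.data[start:stop]

def iter_one_by_one(store, n):
    # One read per element.
    for i in range(n):
        yield store.read(i, i + 1)[0]

def iter_by_chunks(store, n, chunk_size):
    # One read per chunk; elements are then served from memory.
    for start in range(0, n, chunk_size):
        yield from store.read(start, min(start + chunk_size, n))

slow = CountingStore(range(1000))
fast = CountingStore(range(1000))

# Both strategies yield the same elements in the same order.
assert list(iter_one_by_one(slow, 1000)) == list(iter_by_chunks(fast, 1000, 256))

print(slow.reads, fast.reads)  # 1000 4
```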

Alternatively, if your data fits in memory, you can load it at once with the to_memory method, which returns a JointDataset.

dataset = lampe.data.H5Dataset('data_0.h5').to_memory()

1.6.1. Batching and shuffling

As DataLoader cannot shuffle the elements of an IterableDataset instance, H5Dataset provides a shuffle argument (disabled by default) to shuffle the pairs when iterating. A batch_size argument is also available.

dataset = lampe.data.H5Dataset('data_0.h5', batch_size=4, shuffle=True)
for theta, x in dataset:
    print(theta, x, sep='\n')
    break
tensor([[-0.5102, -0.1597,  0.4225],
        [ 0.6263, -0.5490, -0.5424],
        [ 0.3914,  0.8860,  0.7201],
        [-0.6238,  0.9327, -0.6588]])
tensor([[-0.5276,  0.5434],
        [ 0.8239, -0.9316],
        [ 0.9018,  1.0367],
        [-1.3076, -1.2636]])
for theta, x in dataset:
    print(theta, x, sep='\n')
    break
tensor([[ 0.3951, -0.0103, -0.9181],
        [-0.5488, -0.9719, -0.5787],
        [-0.2422, -0.0027,  0.8907],
        [ 0.6688,  0.7993, -0.6854]])
tensor([[ 0.2745, -0.8634],
        [-0.0400, -0.0773],
        [-0.2196,  0.9108],
        [ 0.1359, -0.1790]])

1.7. Merging datasets

H5Dataset can only load data from a single HDF5 file, but it is easy to aggregate data from multiple sources, such as data generated on several machines or at different times.

datasets = [
    lampe.data.H5Dataset('data_0.h5', batch_size=512),
    lampe.data.H5Dataset('data_1.h5', batch_size=256),
    lampe.data.H5Dataset('data_2.h5', batch_size=128),
]

lampe.data.H5Dataset.store(
    pairs=chain(*datasets),
    file='data_all.h5',
    size=sum(map(len, datasets)),
)
100%|██████████| 196608/196608 [00:00<00:00, 206486.69it/s]
dataset = lampe.data.H5Dataset('data_all.h5')

for theta, x in tqdm(dataset):
    pass
100%|██████████| 196608/196608 [00:00<00:00, 242486.34it/s]
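
The three source datasets above are read with different batch sizes, yet the merged file holds 3 × 65536 = 196608 pairs. The pure-Python sketch below, with a hypothetical batches helper standing in for batched datasets, shows that chaining iterables with different batch sizes preserves the total number of samples.

```python
from itertools import chain

def batches(n, batch_size):
    """Yield lists of consecutive integers, standing in for batched pairs."""
    for start in range(0, n, batch_size):
        yield list(range(start, min(start + batch_size, n)))

# Three sources of 2**16 samples each, batched differently.
sources = [batches(2**16, 512), batches(2**16, 256), batches(2**16, 128)]

# The chained stream yields batches of varying length,
# but the total sample count is unchanged.
total = sum(len(batch) for batch in chain(*sources))
print(total)  # 196608
```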