# lampe.data

Datasets and data loaders.

## Classes
- `JointLoader` – Creates an infinite data loader of batched pairs \((\theta, x)\) generated by a prior distribution \(p(\theta)\) and a simulator.
- `JointDataset` – Creates an in-memory dataset of pairs \((\theta, x)\).
- `H5Dataset` – Creates an iterable dataset of pairs \((\theta, x)\) from an HDF5 file.
## Descriptions
- class lampe.data.JointLoader(prior, simulator, batch_size=256, vectorized=False, numpy=False, **kwargs)
Creates an infinite data loader of batched pairs \((\theta, x)\) generated by a prior distribution \(p(\theta)\) and a simulator.
The simulator is a stochastic function taking (a vector of) parameters \(\theta\), in the form of a NumPy array or a PyTorch tensor, as input and returning an observation \(x\) as output, which implicitly defines a likelihood distribution \(p(x | \theta)\). Together with the prior, they form a joint distribution \(p(\theta, x) = p(\theta) p(x | \theta)\) from which pairs \((\theta, x)\) are independently drawn.
  - Parameters
    - prior (Distribution) – A prior distribution \(p(\theta)\).
    - simulator (Callable) – A callable simulator.
    - batch_size (Optional[int]) – The batch size of the generated pairs.
    - vectorized (bool) – Whether the simulator accepts batched inputs or not.
    - numpy (bool) – Whether the simulator requires NumPy or PyTorch inputs.
    - kwargs – Keyword arguments passed to `torch.utils.data.DataLoader`.
Example
>>> loader = JointLoader(prior, simulator, numpy=True, num_workers=4)
>>> for theta, x in loader:
...     theta, x = theta.cuda(), x.cuda()
...     something(theta, x)
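As a concrete illustration of the simulator contract in the `numpy=True`, non-vectorized case, here is a hypothetical toy setup: a function that takes a single parameter vector \(\theta\) and returns a noisy observation \(x\) of the same shape. The cosine transform and the noise scale are invented for this sketch.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical toy simulator: a deterministic transform of theta plus
# Gaussian noise, which implicitly defines a likelihood p(x | theta).
def simulator(theta):
    return np.cos(theta) + 0.1 * rng.standard_normal(theta.shape)

theta = rng.standard_normal(5)  # stand-in for a draw from the prior p(theta)
x = simulator(theta)            # an observation with the same shape as theta
```

With `vectorized=True`, the same function would instead receive a batch of parameter vectors and be expected to return a batch of observations.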
- class lampe.data.JointDataset(theta, x, batch_size=None, shuffle=False)
Creates an in-memory dataset of pairs \((\theta, x)\).
  `JointDataset` supports indexing and slicing, but also implements a custom `__iter__` method which supports batching and shuffling.

  - Parameters
    - theta (Tensor) – A tensor of sets of parameters \(\theta\).
    - x (Tensor) – A tensor of observations \(x\).
    - batch_size (Optional[int]) – The batch size of the yielded pairs.
    - shuffle (bool) – Whether the pairs are shuffled or not when iterating.
Example
>>> dataset = JointDataset(theta, x, batch_size=256, shuffle=True)
>>> theta, x = dataset[42:69]
>>> theta.shape
torch.Size([27, 5])
>>> for theta, x in dataset:
...     theta, x = theta.cuda(), x.cuda()
...     something(theta, x)
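The batched, shuffled iteration can be pictured with plain NumPy. The helper below (`iterate`, a name invented for this sketch) mimics the semantics described above; it is an illustration, not `JointDataset`'s actual implementation.

```python
import numpy as np

# Sketch of batched, shuffled iteration over in-memory pairs: shuffle a
# global index order once, then yield consecutive slices of it as batches.
def iterate(theta, x, batch_size=None, shuffle=False, seed=0):
    order = np.arange(len(theta))
    if shuffle:
        np.random.default_rng(seed).shuffle(order)
    step = batch_size or len(order)
    for i in range(0, len(order), step):
        idx = order[i:i + step]
        yield theta[idx], x[idx]

theta = np.arange(10).reshape(10, 1).astype(np.float32)
x = 2 * theta
batches = list(iterate(theta, x, batch_size=4, shuffle=True))
```

Note that the last batch is smaller when the batch size does not divide the number of pairs, and that each \(\theta\) stays aligned with its \(x\).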
- class lampe.data.H5Dataset(file, batch_size=None, chunk_size=256, chunk_step=256, shuffle=False)
Creates an iterable dataset of pairs \((\theta, x)\) from an HDF5 file.
  As it can be slow to load pairs from disk one by one, `H5Dataset` implements a custom `__iter__` method that loads several contiguous chunks of pairs at once and shuffles their concatenation before yielding the pairs. `H5Dataset` also implements the `__len__` and `__getitem__` methods for convenience.

  Important

  When using `H5Dataset` with `torch.utils.data.DataLoader`, it is recommended to disable the loader's automatic batching and provide the batch size to the dataset instead, as this significantly improves loading performance. In fact, unless you use several workers or memory pinning, it is not even necessary to wrap the dataset in a `torch.utils.data.DataLoader`.

  - Parameters
    - file (Union[str, Path]) – An HDF5 file containing pairs \((\theta, x)\).
    - batch_size (Optional[int]) – The batch size of the yielded pairs.
    - chunk_size (int) – The size of the contiguous chunks.
    - chunk_step (int) – The step between two loaded chunks.
    - shuffle (bool) – Whether the pairs are shuffled or not when iterating.
Example
>>> dataset = H5Dataset('data.h5', batch_size=256, shuffle=True)
>>> theta, x = dataset[0]
>>> theta
tensor([-0.1215, -1.3641,  0.7233, -1.2150, -1.9263])
>>> for theta, x in dataset:
...     theta, x = theta.cuda(), x.cuda()
...     something(theta, x)
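The chunked loading scheme can be sketched at the index level: contiguous chunks of `chunk_size` indices are read from disk, a few of them are concatenated, and the concatenation is shuffled before its pairs are yielded. The helper name and the group size of 4 chunks are invented for this illustration; this is not `H5Dataset`'s actual code.

```python
import numpy as np

# Index-level sketch of chunked, shuffled reading: contiguous index ranges
# stand in for chunks of pairs read from the HDF5 file.
def chunked_indices(n, chunk_size, chunk_step, group=4, seed=0):
    rng = np.random.default_rng(seed)
    starts = np.arange(0, n, chunk_step)
    for i in range(0, len(starts), group):
        # concatenate a group of contiguous chunks...
        idx = np.concatenate(
            [np.arange(s, min(s + chunk_size, n)) for s in starts[i:i + group]]
        )
        rng.shuffle(idx)  # ...and shuffle the concatenation before yielding
        yield idx

order = np.concatenate(list(chunked_indices(10, chunk_size=3, chunk_step=3)))
```

Shuffling within groups of chunks trades perfect global shuffling for sequential disk reads, which is what makes iteration fast.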
  - to_memory()

    Loads all pairs in memory and returns them as a `JointDataset`.

    Example

    >>> dataset = H5Dataset('data.h5').to_memory()
  - static store(pairs, file, size, overwrite=False, dtype=<class 'numpy.float32'>, **meta)
Creates an HDF5 file containing pairs \((\theta, x)\).
    The sets of parameters \(\theta\) are stored in a collection named `'theta'` and the observations in a collection named `'x'`.

    - Parameters
      - pairs (Iterable[Tuple[ndarray, ndarray]]) – An iterable over batched pairs \((\theta, x)\).
      - file (Union[str, Path]) – An HDF5 filename to store pairs in.
      - size (int) – The number of pairs to store.
      - overwrite (bool) – Whether to overwrite existing files or not. If `False` and the file already exists, the function raises an error.
      - dtype (dtype) – The data type to store pairs in.
      - meta – Metadata to store in the file.
Example
>>> loader = JointLoader(prior, simulator, batch_size=16)
>>> H5Dataset.store(loader, 'data.h5', 4096)
100%|██████████| 4096/4096 [01:35<00:00, 42.69pair/s]
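Since `store` accepts any iterable over batched pairs, the source does not have to be a `JointLoader`. A hypothetical generator of NumPy batches, with shapes invented for this sketch, would also qualify:

```python
import numpy as np

rng = np.random.default_rng(1)

# Hypothetical source of batched pairs for H5Dataset.store: each item is
# a (theta, x) tuple of NumPy arrays with a leading batch dimension.
def pairs(batch_size=16):
    while True:
        theta = rng.standard_normal((batch_size, 5))
        x = np.cos(theta) + 0.1 * rng.standard_normal((batch_size, 5))
        yield theta, x

# e.g. H5Dataset.store(pairs(), 'data.h5', 4096)  (call not executed here)
```

`store` reads `size` pairs in total from the iterable, so an infinite generator like this one is fine.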