6. Embedding and GPU acceleration#

This tutorial demonstrates how to train an embedding network to extract informative features from observations and use these features for neural posterior estimation. Training and inference are performed on GPU.

%matplotlib inline

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import zuko

from tqdm import tqdm

from lampe.data import JointLoader, H5Dataset
from lampe.diagnostics import expected_coverage_mc
from lampe.inference import NPE, NPELoss
from lampe.plots import nice_rc, corner, mark_point, coverage_plot
from lampe.utils import GDStep
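Since training and inference run on GPU, it is worth making sure that CUDA is available before going any further:

assert torch.cuda.is_available(), 'this tutorial requires a CUDA-capable GPU'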

6.1. Simulator#

For this tutorial, we use a simulator that procedurally generates distorted Worley noise images controlled by three parameters: the expected number of cells \(\lambda\), the distortion scale \(\sigma\), and the order \(p\) of the distance norm.

LABELS = [r'$\lambda$', r'$\sigma$', r'$p$']
LOWER = torch.tensor([10.0, 0.1, 1.0])
UPPER = torch.tensor([50.0, 0.5, 2.0])

prior = zuko.distributions.BoxUniform(LOWER, UPPER)

class WorleyNoise:
    def __init__(self):
        # 64 x 64 grid of coordinates in the unit square
        self.domain = torch.cartesian_prod(
            torch.linspace(0, 1, 64),
            torch.linspace(0, 1, 64),
        )

    def __call__(self, theta: torch.Tensor) -> torch.Tensor:
        lmbda, sigma, p = theta

        mu = torch.rand(2)  # center of the distortion
        cores = torch.rand(np.random.poisson(lmbda), 2)  # cell cores, Poisson(lambda) of them

        # Distortion: push points away from mu, weighted by a Gaussian bump of width sigma
        x = self.domain
        x = x + torch.exp(-(x - mu).square().sum(dim=-1, keepdim=True) / sigma**2) * (x - mu)

        # Cells: p-norm distance to the nearest core, normalized to [0, 1]
        x = torch.cdist(x, cores, p=p)
        x = torch.min(x, dim=-1).values
        x = x / x.max()

        return x.reshape(64, 64)

simulator = WorleyNoise()

theta = prior.sample()
x = simulator(theta)

print('theta:', theta.tolist())
print('x:')

plt.imshow(x)
plt.show()
theta: [41.51055145263672, 0.21925213932991028, 1.9552865028381348]
x:
[Figure: the simulated observation \(x\), a 64×64 distorted Worley noise image.]

6.2. Datasets#

We store (on disk) \(2^{16}\) pairs \((\theta, x) \sim p(\theta, x)\) for training, \(2^{12}\) pairs for validation and \(2^{10}\) pairs for testing.

loader = JointLoader(prior, simulator, batch_size=2**4, vectorized=False)

H5Dataset.store(loader, 'train.h5', size=2**16, overwrite=True)
H5Dataset.store(loader, 'valid.h5', size=2**12, overwrite=True)
H5Dataset.store(loader, 'test.h5', size=2**10, overwrite=True)

trainset = H5Dataset('train.h5', batch_size=256, shuffle=True)
validset = H5Dataset('valid.h5', batch_size=256)
testset = H5Dataset('test.h5')
100%|██████████| 65536/65536 [05:57<00:00, 183.51pair/s]
100%|██████████| 4096/4096 [00:24<00:00, 167.55pair/s]
100%|██████████| 1024/1024 [00:05<00:00, 192.07pair/s]

It is usually recommended to preprocess the inputs of neural networks so that their mean and variance are close to zero and one, respectively. In our case, the observation \(x\) is already normalized to \([0, 1]\), so only the parameters \(\theta\) need to be rescaled.

def preprocess(theta: torch.Tensor) -> torch.Tensor:
    return 2 * (theta - LOWER) / (UPPER - LOWER) - 1  # [LOWER, UPPER] -> [-1, 1]

def postprocess(theta: torch.Tensor) -> torch.Tensor:
    return (theta + 1) / 2 * (UPPER - LOWER) + LOWER  # [-1, 1] -> [LOWER, UPPER]
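As a quick sanity check, the two transformations should be inverses of each other:

theta = prior.sample()

assert torch.allclose(postprocess(preprocess(theta)), theta)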

6.3. Architecture#

Because the observations are images, they cannot be fed directly to an NPE module, which only accepts vector-shaped observations. Instead, we rely on a convolutional neural network to extract a vector of informative features from the observation and feed this vector to the density estimator. In our case, the latter is a neural spline flow (NSF) with rational-quadratic splines, borrowed from the zuko package. Additionally, the parameters are rescaled to the \([-1, 1]\) interval, which improves convergence.

class ResBlock(nn.Sequential):
    def __init__(self, features: int):
        super().__init__(
            nn.Conv2d(features, features, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(features, features, kernel_size=3, padding=1),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x + super().forward(x)  # residual (skip) connection

class NPEWithEmbedding(nn.Module):
    def __init__(self):
        super().__init__()

        # NSF-based posterior estimator over 3 parameters, conditioned on 8 features
        self.npe = NPE(3, 8, build=zuko.flows.NSF, hidden_features=[128] * 3, activation=nn.ELU)

        # CNN mapping a (64, 64) image to a vector of 8 features
        self.embedding = nn.Sequential(
            nn.Unflatten(-2, (1, 64)),  # add a channel dimension: (64, 64) -> (1, 64, 64)
            nn.Conv2d(1, 16, kernel_size=3, padding=1),
            ResBlock(16),
            ResBlock(16),
            nn.AvgPool2d(2, 2),  # 64 -> 32
            ResBlock(16),
            ResBlock(16),
            nn.AvgPool2d(2, 2),  # 32 -> 16
            ResBlock(16),
            ResBlock(16),
            nn.AvgPool2d(2, 2),  # 16 -> 8
            ResBlock(16),
            ResBlock(16),
            nn.AvgPool2d(2, 2),  # 8 -> 4
            nn.Flatten(-3),  # (16, 4, 4) -> 256
            nn.Linear(16 * 4 * 4, 8),
        )

    def forward(self, theta: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
        return self.npe(theta, self.embedding(x))

    def flow(self, x: torch.Tensor):  # -> Distribution
        return self.npe.flow(self.embedding(x))
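As an optional sanity check, one can run the untrained network on a dummy batch to verify the shapes (the batch size of 5 is arbitrary):

estimator = NPEWithEmbedding()

x_dummy = torch.randn(5, 64, 64)
theta_dummy = preprocess(prior.sample((5,)))

print(estimator.embedding(x_dummy).shape)  # torch.Size([5, 8])
print(estimator(theta_dummy, x_dummy).shape)  # torch.Size([5]), one log-density per pair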

6.4. Training#

Using an embedding with NPE does not change the training routine. Here, we monitor the loss of the estimator on a validation set distinct from the training set, which helps detect overfitting. Training is performed on GPU, which requires transferring each batch to the GPU before the corresponding step.

estimator = NPEWithEmbedding().cuda()
loss = NPELoss(estimator)
optimizer = optim.AdamW(estimator.parameters(), lr=1e-3)
step = GDStep(optimizer, clip=1.0)  # optimization step with gradient clipping

with tqdm(range(32), unit='epoch') as tq:
    for epoch in tq:
        estimator.train()

        losses = torch.stack([
            step(loss(preprocess(theta).cuda(), x.cuda()))
            for theta, x in trainset
        ])

        estimator.eval()

        with torch.no_grad():
            val_losses = torch.stack([
                loss(preprocess(theta).cuda(), x.cuda())
                for theta, x in validset
            ])

        tq.set_postfix(loss=losses.mean().item(), val_loss=val_losses.mean().item())
100%|██████████| 32/32 [11:56<00:00, 22.38s/epoch, loss=-2.95, val_loss=-3.08]
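Training takes a while, so it may be worth saving the trained weights to disk with a standard PyTorch checkpoint (the file name estimator.pth is an arbitrary choice):

torch.save(estimator.state_dict(), 'estimator.pth')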

6.5. Inference#

theta_star, x_star = testset[0]

with torch.no_grad():
    samples = estimator.flow(x_star.cuda()).sample((2**16,)).cpu()
    samples = postprocess(samples)

plt.rcParams.update(nice_rc(latex=True))  # nicer plot settings

fig = corner(
    samples,
    smooth=2,
    domain=(LOWER, UPPER),
    labels=LABELS,
    legend=r'$p_\phi(\theta | x^*)$',
    figsize=(4.8, 4.8),
)

mark_point(fig, theta_star)
[Figure: corner plot of the posterior samples \(p_\phi(\theta | x^*)\), with the nominal parameters \(\theta^*\) marked.]
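The flow returned by estimator.flow is a regular distribution object, so posterior log-densities can also be evaluated directly. For instance, for the nominal parameters \(\theta^*\), which must be preprocessed as during training:

with torch.no_grad():
    log_p = estimator.flow(x_star.cuda()).log_prob(preprocess(theta_star).cuda())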

6.6. Expected coverage#

Finally, we evaluate the calibration of the estimator by computing its expected coverage over the test set, that is, the frequency at which the nominal parameters fall within credible regions of increasing level.

levels, coverages = expected_coverage_mc(
    posterior=estimator.flow,
    pairs=((preprocess(theta), x) for theta, x in testset),
    device='cuda',
)

fig = coverage_plot(levels, coverages, legend='NPE')
1024pair [00:33, 30.72pair/s]
[Figure: expected coverage plot of the NPE estimator.]

We notice that the estimator is slightly overconfident, meaning that the nominal parameters do not fall within the predicted credible regions as often as they should. This problem could be mitigated with more training data.
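For instance, one could simulate a larger training set with the same JointLoader and retrain the estimator. A sketch, where the size \(2^{18}\) and the file name train-large.h5 are arbitrary choices:

H5Dataset.store(loader, 'train-large.h5', size=2**18, overwrite=True)

trainset = H5Dataset('train-large.h5', batch_size=256, shuffle=True)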