Pyro中的Zuko

该文件包含在Pyro管道中使用基于Zuko的归一化流的帮助程序。

可以在 tutorial/svi_flow_guide.ipynbtutorial/vae_flow_prior.ipynb 找到相关的教程。

class ZukoToPyro(dist: torch.distributions.distribution.Distribution)[source]

将Zuko分布包装为Pyro分布。

如果 dist 有一个 rsample_and_log_prob 方法,比如 Zuko 的流,它将在采样时使用,而不是 rsample。返回的对数密度将被缓存以供后续评分使用。

Parameters

dist (torch.distributions.Distribution) – 一个分布实例。

flow = zuko.flows.MAF(features=5)

# flow() is a torch.distributions.Distribution

dist = flow()
x = dist.sample((2, 3))
log_p = dist.log_prob(x)

# ZukoToPyro(flow()) is a pyro.distributions.Distribution

dist = ZukoToPyro(flow())
x = dist((2, 3))
log_p = dist.log_prob(x)

with pyro.plate("data", 42):
    z = pyro.sample("z", dist)