Pyro中的Zuko¶
该文件包含在Pyro管道中使用基于Zuko的归一化流的帮助程序。
可以在 tutorial/svi_flow_guide.ipynb 和 tutorial/vae_flow_prior.ipynb 找到相关的教程。
- class ZukoToPyro(dist: torch.distributions.distribution.Distribution)[source]¶
将Zuko分布包装为Pyro分布。
如果
dist
有一个rsample_and_log_prob
方法,比如 Zuko 的流,它将在采样时使用,而不是rsample
。返回的对数密度将被缓存以供后续评分使用。- Parameters
dist (torch.distributions.Distribution) – 一个分布实例。
flow = zuko.flows.MAF(features=5) # flow() is a torch.distributions.Distribution dist = flow() x = dist.sample((2, 3)) log_p = dist.log_prob(x) # ZukoToPyro(flow()) is a pyro.distributions.Distribution dist = ZukoToPyro(flow()) x = dist((2, 3)) log_p = dist.log_prob(x) with pyro.plate("data", 42): z = pyro.sample("z", dist)