因果效应VAE¶
该模块实现了因果效应变分自编码器 [1],展示了多项创新,包括:
一个用于因果效应推断的生成模型,包含隐藏的混杂因素;
一个模型和指南,带有双神经网络,以允许不平衡处理;以及
一个自定义的训练损失函数,包括ELBO项和额外的项,这些额外的项用于训练指南以能够回答反事实查询。
主要接口是CEVAE
类,但用户可以通过使用组件Model
、Guide
、TraceCausalEffect_ELBO
和实用工具来自定义。
参考文献
- [1] C. Louizos, U. Shalit, J. Mooij, D. Sontag, R. Zemel, M. Welling (2017).
- Causal Effect Inference with Deep Latent-Variable Models.
CEVAE 类¶
- class CEVAE(feature_dim, outcome_dist='bernoulli', latent_dim=20, hidden_dim=200, num_layers=3, num_samples=100)[source]¶
基础类:
torch.nn.modules.module.Module
实现因果效应VAE [1]的主类。这假设了一个图形模型
digraph { Z [pos="1,2!",style=filled]; X [pos="2,1!"]; y [pos="1,0!"]; t [pos="0,1!"]; Z -> X; Z -> t; Z -> y; t -> y; }其中 t 是一个二元处理变量,y 是一个结果,Z 是一个未观察到的混杂因素,X 是隐藏混杂因素 Z 的噪声函数。
示例:
cevae = CEVAE(feature_dim=5) cevae.fit(x_train, t_train, y_train) ite = cevae.ite(x_test) # individual treatment effect ate = ite.mean() # average treatment effect
- Variables
- Parameters
- fit(x, t, y, num_epochs=100, batch_size=100, learning_rate=0.001, learning_rate_decay=0.1, weight_decay=0.0001, log_every=100)[source]¶
使用
SVI
进行训练,并使用TraceCausalEffect_ELBO
损失函数。- Parameters
x (Tensor) –
t (Tensor) –
y (Tensor) –
num_epochs (int) – 训练轮数。默认为100。
batch_size (int) – 批量大小。默认为 100。
learning_rate (float) – 学习率。默认为1e-3。
learning_rate_decay (float) – 所有周期的学习率衰减; 每步的衰减率将取决于批量大小和周期数, 因此初始学习率将为
learning_rate
,最终学习率将为learning_rate * learning_rate_decay
。 默认值为 0.1。weight_decay (float) – 权重衰减。默认为1e-4。
log_every (int) – 每这么多步记录一次损失。如果为零,则不记录损失。默认为100。
- Returns
历元损失列表
- ite(x, num_samples=None, batch_size=None)[source]¶
计算一批数据
x
的个体治疗效果。\[ITE(x) = \mathbb E\bigl[ \mathbf y \mid \mathbf X=x, do(\mathbf t=1) \bigr] - \mathbb E\bigl[ \mathbf y \mid \mathbf X=x, do(\mathbf t=0) \bigr]\]这具有复杂度
O(len(x) * num_samples ** 2)
。
- to_script_module()[source]¶
使用
torch.jit.trace_module()
编译此模块,假设self已经适应数据。- Returns
一个带有
ite()
方法的self的追踪版本。- Return type
CEVAE组件¶
- class Model(config)[source]¶
-
具有潜在混杂因素
z
和二元处理t
的因果模型的生成模型:z ~ p(z) # latent confounder x ~ p(x|z) # partial noisy observation of z t ~ p(t|z) # treatment, whose application is biased by z y ~ p(y|t,z) # outcome
这些分布中的每一个都由神经网络定义。
y
分布由一对不相交的神经网络定义,分别定义p(y|t=0,z)
和p(y|t=1,z)
;这允许高度不平衡的处理。- Parameters
config (dict) – 一个字典,指定
feature_dim
,latent_dim
,hidden_dim
,num_layers
, 和outcome_dist
。
- class Guide(config)[source]¶
-
用于因果效应估计的推理模型,包含潜在混杂因素
z
和二元处理t
:t ~ q(t|x) # treatment y ~ q(y|t,x) # outcome z ~ q(z|y,t,x) # latent confounder, an embedding
这些分布中的每一个都是由神经网络定义的。
y
和z
分布由定义p(-|t=0,...)
和p(-|t=1,...)
的不相交的神经网络对定义;这允许高度不平衡的处理。- Parameters
config (dict) – 一个字典,指定
feature_dim
,latent_dim
,hidden_dim
,num_layers
, 和outcome_dist
。
- class TraceCausalEffect_ELBO(num_particles=1, max_plate_nesting=inf, max_iarange_nesting=None, vectorize_particles=False, strict_enumeration_warning=True, ignore_jit_warnings=False, jit_options=None, retain_graph=None, tail_adaptive_beta=- 1.0)[source]¶
基础类:
pyro.infer.trace_elbo.Trace_ELBO
用于训练
CEVAE
的损失函数。 根据[1],CEVAE的目标(最大化)是:-loss = ELBO + log q(t|x) + log q(y|t,x)
实用工具¶
- class FullyConnected(sizes, final_activation=None)[源代码]¶
基础类:
torch.nn.modules.container.Sequential
具有ELU激活的全连接多层网络。
- class DistributionNet(*args, **kwargs)[source]¶
基础类:
torch.nn.modules.module.Module
分布网络的基础类。
- class BernoulliNet(sizes)[source]¶
基础类:
pyro.contrib.cevae.DistributionNet
FullyConnected
网络输出一个单一的logits
值。这用于表示一个单一伯努利随机变量的条件概率分布,该变量以
sizes[0]
大小的实数值为条件,例如:net = BernoulliNet([3, 4]) z = torch.randn(3) logits, = net(z) t = net.make_dist(logits).sample()
- class ExponentialNet(sizes)[source]¶
基础类:
pyro.contrib.cevae.DistributionNet
FullyConnected
网络输出一个受约束的rate
。这用于表示单个正态随机变量的条件概率分布,该变量以
sizes[0]
大小的实数值为条件,例如:net = ExponentialNet([3, 4]) x = torch.randn(3) rate, = net(x) y = net.make_dist(rate).sample()
- class LaplaceNet(sizes)[source]¶
基础类:
pyro.contrib.cevae.DistributionNet
FullyConnected
网络输出一个受约束的loc,scale
对。这用于表示单个拉普拉斯随机变量在
sizes[0]
大小的实数值条件下的条件概率分布,例如:net = LaplaceNet([3, 4]) x = torch.randn(3) loc, scale = net(x) y = net.make_dist(loc, scale).sample()
- class NormalNet(sizes)[source]¶
基础类:
pyro.contrib.cevae.DistributionNet
FullyConnected
网络输出一个受约束的loc,scale
对。这用于表示单个正态随机变量的条件概率分布,该变量以
sizes[0]
大小的实数值为条件,例如:net = NormalNet([3, 4]) x = torch.randn(3) loc, scale = net(x) y = net.make_dist(loc, scale).sample()
- class StudentTNet(sizes)[source]¶
基础类:
pyro.contrib.cevae.DistributionNet
FullyConnected
网络输出一个受约束的df,loc,scale
三元组,其中共享的df > 1
。这用于表示单个学生t随机变量的条件概率分布,该变量以
sizes[0]
大小的实数值为条件,例如:net = StudentTNet([3, 4]) x = torch.randn(3) df, loc, scale = net(x) y = net.make_dist(df, loc, scale).sample()
- class DiagNormalNet(sizes)[源代码]¶
基础类:
torch.nn.modules.module.Module
FullyConnected
网络输出一个受约束的loc,scale
对。这用于表示一个
sizes[-1]
大小的对角正态随机变量的条件概率分布,该变量以sizes[0]
大小的实数值为条件,例如:net = DiagNormalNet([3, 4, 5]) z = torch.randn(3) loc, scale = net(z) x = dist.Normal(loc, scale).sample()
这是用于潜在
z
分布和预白化的x
特征,并保守地裁剪loc
和scale
值。