dowhy.causal_prediction.models 包#

子模块#

dowhy.causal_prediction.models.networks 模块#

The MNIST_MLP architecture is borrowed from OoD-Bench:
@inproceedings{ye2022ood,

标题={OoD-Bench: 量化与理解分布外泛化的两个维度}, 作者={叶南扬, 李凯灿, 白浩月, 于润鹏, 洪兰青, 周峰伟, 李振国, 朱军}, 会议={CVPR}, 年份={2022}

}

dowhy.causal_prediction.models.networks.Classifier(in_features, out_features, is_nonlinear=False)[source]#
class dowhy.causal_prediction.models.networks.ContextNet(input_shape)[source]#

基础类:Module

初始化内部模块状态,由nn.Module和ScriptModule共享。

forward(x)[来源]#

定义每次调用时执行的计算。

应该由所有子类覆盖。

注意

尽管前向传递的配方需要在此函数内定义,但之后应该调用Module实例而不是这个,因为前者负责运行已注册的钩子,而后者则默默地忽略它们。

training: bool#
class dowhy.causal_prediction.models.networks.Identity[source]#

基础类:Module

一个身份层

初始化内部模块状态,由nn.Module和ScriptModule共享。

forward(x)[来源]#

定义每次调用时执行的计算。

应该由所有子类覆盖。

注意

尽管前向传递的配方需要在此函数内定义,但之后应该调用Module实例而不是这个,因为前者负责运行已注册的钩子,而后者则默默地忽略它们。

training: bool#
class dowhy.causal_prediction.models.networks.MLP(n_inputs, n_outputs, mlp_width, mlp_depth, mlp_dropout)[源代码]#

基础类:Module

只是一个多层感知器

初始化内部模块状态,由nn.Module和ScriptModule共享。

forward(x)[来源]#

定义每次调用时执行的计算。

应该由所有子类覆盖。

注意

尽管前向传递的配方需要在此函数内定义,但之后应该调用Module实例而不是这个,因为前者负责运行已注册的钩子,而后者则默默地忽略它们。

training: bool#
class dowhy.causal_prediction.models.networks.MNIST_CNN(input_shape)[源代码]#

基础类:Module

为MNIST手动调整的架构。 到目前为止我注意到的这个架构的奇怪之处: - 在特征中的均值池化后添加一个线性层会损害性能

RotatedMNIST-100 的泛化能力严重不足。

初始化内部模块状态,由nn.Module和ScriptModule共享。

forward(x)[来源]#

定义每次调用时执行的计算。

应该由所有子类覆盖。

注意

尽管前向传递的配方需要在此函数内定义,但之后应该调用Module实例而不是这个,因为前者负责运行已注册的钩子,而后者则默默地忽略它们。

n_outputs = 128#
training: bool#
class dowhy.causal_prediction.models.networks.MNIST_MLP(input_shape)[来源]#

基础类:Module

初始化内部模块状态,由nn.Module和ScriptModule共享。

forward(x)[来源]#

定义每次调用时执行的计算。

应该由所有子类覆盖。

注意

尽管前向传递的配方需要在此函数内定义,但之后应该调用Module实例而不是这个,因为前者负责运行已注册的钩子,而后者则默默地忽略它们。

training: bool#
class dowhy.causal_prediction.models.networks.ResNet(input_shape, resnet18=True, resnet_dropout=0.0)[源代码]#

基础类:Module

ResNet 去掉了 softmax 并冻结了 batchnorm

初始化内部模块状态,由nn.Module和ScriptModule共享。

forward(x)[来源]#

将x编码为大小为n_outputs的特征向量。

freeze_bn()[source]#
train(mode=True)[source]#

覆盖默认的train()以冻结BN参数

training: bool#

模块内容#