dowhy.causal_prediction.datasets 包#
子模块#
dowhy.causal_prediction.datasets.base_dataset 模块#
- MultipleDomainDataset class in this file is borrowed from DomainBed: facebookresearch/DomainBed
- @inproceedings{gulrajani2021in,
标题={寻找丢失的领域泛化}, 作者={Ishaan Gulrajani 和 David Lopez-Paz}, 会议={国际学习表示会议}, 年份={2021},
}
dowhy.causal_prediction.datasets.mnist 模块#
- class dowhy.causal_prediction.datasets.mnist.MNISTCausalAttribute(root, download=True)[source]#
-
MNISTCausalAttribute 数据集的类。
- Parameters:
root – 可以找到数据的目录(如果不存在,则应下载到该目录)。
download – 二进制标志,指示是否应下载数据
- Returns:
MultipleDomainDataset 类的一个实例
- CHECKPOINT_FREQ = 500#
- ENVIRONMENTS = ['+90%', '+80%', '-90%', '-90%']#
- INPUT_SHAPE = (2, 14, 14)#
- N_STEPS = 5001#
- class dowhy.causal_prediction.datasets.mnist.MNISTCausalIndAttribute(root, download=True)[source]#
-
MNISTIndAttribute 数据集的类。
- Parameters:
root – 可以找到数据的目录(如果不存在,则应下载到该目录)。
download – 二进制标志,指示是否应下载数据
- Returns:
MultipleDomainDataset 类的一个实例
- CHECKPOINT_FREQ = 500#
- ENVIRONMENTS = ['+90%, 15', '+80%, 16', '-90%, 90', '-90%, 90']#
- INPUT_SHAPE = (2, 14, 14)#
- N_STEPS = 5001#
- color_dataset(images, labels, environment)[source]#
将MNIST数据集进行转换,以引入属性(颜色)和标签之间的相关性。 标签Y和颜色之间存在直接的因果关系。
- Parameters:
images – 旋转的MNIST图像
labels – 原始的MNIST标签
environment – 颜色与标签之间的相关性值
- Returns:
转换后的图像、标签和属性(颜色)
- color_rot_dataset(images, labels, environment, env_id, angle)[source]#
通过以下方式转换MNIST数据集:(i) 对图像应用旋转,然后(ii) 在属性(颜色)和标签之间引入相关性。 属性(旋转角度)与标签Y无关;标签Y和颜色之间存在直接因果关系。
- Parameters:
images – 原始的MNIST图像
labels – 原始的MNIST标签
environment – 颜色与标签之间的相关性值
angle – 用于图像变换的旋转角度值
- Returns:
包含转换后的图像、标签和属性(颜色、角度)的TensorDataset
- class dowhy.causal_prediction.datasets.mnist.MNISTIndAttribute(root, download=True)[source]#
-
MNISTIndAttribute 数据集的类。
- Parameters:
root – 可以找到数据的目录(如果不存在,则应下载到该目录)。
download – 二进制标志,指示是否应下载数据
- Returns:
MultipleDomainDataset 类的一个实例
- CHECKPOINT_FREQ = 500#
- ENVIRONMENTS = ['15', '60', '90', '90']#
- INPUT_SHAPE = (1, 14, 14)#
- N_STEPS = 5001#