dowhy.causal_prediction.datasets 包#

子模块#

dowhy.causal_prediction.datasets.base_dataset 模块#

MultipleDomainDataset class in this file is borrowed from DomainBed: facebookresearch/DomainBed
@inproceedings{gulrajani2021in,

标题={寻找丢失的领域泛化}, 作者={Ishaan Gulrajani 和 David Lopez-Paz}, 会议={国际学习表示会议}, 年份={2021},

}

class dowhy.causal_prediction.datasets.base_dataset.MultipleDomainDataset[来源]#

基础类:object

CHECKPOINT_FREQ = 100#
ENVIRONMENTS = None#
INPUT_SHAPE = None#
N_STEPS = 5001#
N_WORKERS = 8#

dowhy.causal_prediction.datasets.mnist 模块#

class dowhy.causal_prediction.datasets.mnist.MNISTCausalAttribute(root, download=True)[source]#

基础类:MultipleDomainDataset

MNISTCausalAttribute 数据集的类。

Parameters:
  • root – 可以找到数据的目录(如果不存在,则应下载到该目录)。

  • download – 二进制标志,指示是否应下载数据

Returns:

MultipleDomainDataset 类的一个实例

CHECKPOINT_FREQ = 500#
ENVIRONMENTS = ['+90%', '+80%', '-90%', '-90%']#
INPUT_SHAPE = (2, 14, 14)#
N_STEPS = 5001#
color_dataset(images, labels, environment)[source]#

将MNIST数据集进行转换,以引入属性(颜色)和标签之间的相关性。 标签Y和颜色之间存在直接的因果关系。

Parameters:
  • images – 原始的MNIST图像

  • labels – 原始的MNIST标签

  • environment – 颜色与标签之间的相关性值

Returns:

包含转换后的图像、标签和属性(颜色)的TensorDataset

torch_bernoulli_(p, size)[来源]#
torch_xor_(a, b)[来源]#
class dowhy.causal_prediction.datasets.mnist.MNISTCausalIndAttribute(root, download=True)[source]#

基础类:MultipleDomainDataset

MNISTIndAttribute 数据集的类。

Parameters:
  • root – 可以找到数据的目录(如果不存在,则应下载到该目录)。

  • download – 二进制标志,指示是否应下载数据

Returns:

MultipleDomainDataset 类的一个实例

CHECKPOINT_FREQ = 500#
ENVIRONMENTS = ['+90%, 15', '+80%, 16', '-90%, 90', '-90%, 90']#
INPUT_SHAPE = (2, 14, 14)#
N_STEPS = 5001#
color_dataset(images, labels, environment)[source]#

将MNIST数据集进行转换,以引入属性(颜色)和标签之间的相关性。 标签Y和颜色之间存在直接的因果关系。

Parameters:
  • images – 旋转的MNIST图像

  • labels – 原始的MNIST标签

  • environment – 颜色与标签之间的相关性值

Returns:

转换后的图像、标签和属性(颜色)

color_rot_dataset(images, labels, environment, env_id, angle)[source]#

通过以下方式转换MNIST数据集:(i) 对图像应用旋转,然后(ii) 在属性(颜色)和标签之间引入相关性。 属性(旋转角度)与标签Y无关;标签Y和颜色之间存在直接因果关系。

Parameters:
  • images – 原始的MNIST图像

  • labels – 原始的MNIST标签

  • environment – 颜色与标签之间的相关性值

  • angle – 用于图像变换的旋转角度值

Returns:

包含转换后的图像、标签和属性(颜色、角度)的TensorDataset

rotate_dataset(images, angle)[source]#

通过对图像应用旋转来转换MNIST数据集。 属性(旋转角度)与标签Y无关。

Parameters:
  • images – 原始的MNIST图像

  • angle – 用于图像变换的旋转角度值

Returns:

转换后的图像

torch_bernoulli_(p, size)[来源]#
torch_xor_(a, b)[source]#
class dowhy.causal_prediction.datasets.mnist.MNISTIndAttribute(root, download=True)[source]#

基础类:MultipleDomainDataset

MNISTIndAttribute 数据集的类。

Parameters:
  • root – 可以找到数据的目录(如果不存在,则应下载到该目录)。

  • download – 二进制标志,指示是否应下载数据

Returns:

MultipleDomainDataset 类的一个实例

CHECKPOINT_FREQ = 500#
ENVIRONMENTS = ['15', '60', '90', '90']#
INPUT_SHAPE = (1, 14, 14)#
N_STEPS = 5001#
rotate_dataset(images, labels, env_id, angle)[source]#

通过对图像应用旋转来转换MNIST数据集。 属性(旋转角度)与标签Y无关。

Parameters:
  • images – 原始的MNIST图像

  • labels – 原始的MNIST标签

  • angle – 用于图像变换的旋转角度值

Returns:

包含转换后的图像、标签和属性(角度)的TensorDataset

torch_bernoulli_(p, size)[来源]#
torch_xor_(a, b)[来源]#

模块内容#