dowhy.causal_prediction.dataloaders 包#

子模块#

dowhy.causal_prediction.dataloaders.fast_data_loader 模块#

class dowhy.causal_prediction.dataloaders.fast_data_loader.FastDataLoader(dataset, batch_size, num_workers)[source]#

基础类:object

DataLoader 包装器,通过不在每个周期重新生成工作进程,略微提高了速度。

class dowhy.causal_prediction.dataloaders.fast_data_loader.InfiniteDataLoader(dataset, weights, batch_size, num_workers)[来源]#

基础类:object

dowhy.causal_prediction.dataloaders.get_data_loader 模块#

dowhy.causal_prediction.dataloaders.get_data_loader.get_eval_loader(dataset, envs, batch_size, class_balanced=False)[source]#

返回评估数据加载器(测试/验证)。

Parameters:
  • dataset – 包含环境列表的数据集类

  • envs – 包含数据集中验证/测试域索引的列表

  • batch_size – 用于数据加载器的批量大小值

  • class_balanced – 二进制标志,指示是否在类别之间进行平衡采样

Returns:

数据加载器列表

dowhy.causal_prediction.dataloaders.get_data_loader.get_loaders(dataset, train_envs, batch_size, val_envs=None, test_envs=None, class_balanced=False, holdout_fraction=0.2, trial_seed=0)[来源]#

返回训练、验证和测试数据加载器。

Parameters:
  • dataset – 包含环境列表的数据集类

  • train_envs – 包含数据集中训练域索引的列表

  • batch_size – 用于数据加载器的批量大小值

  • val_envs – 包含数据集中验证域索引的列表。如果为None,则使用训练数据的一部分(holdout_fraction)来创建验证集。

  • test_envs – 包含数据集中测试域索引的列表

  • class_balanced – 二进制标志,指示是否在类别之间进行平衡采样

  • holdout_fraction – 用于创建验证域的训练数据的比例。当val_envs为None时使用。

  • trial_seed – 用于从训练数据生成验证分割的种子。当val_envs为None时使用。

Returns:

数据加载器列表的字典格式 {‘train_loaders’: [train_dataloader_1, train_dataloader_2, ….],

’val_loaders’: [val_dataloader_1, val_dataloader_2, ….], ‘test_loaders’: [test_dataloader_1, test_dataloader_2, ….]

}

dowhy.causal_prediction.dataloaders.get_data_loader.get_train_eval_loader(dataset, envs, batch_size, class_balanced, holdout_fraction, trial_seed)[source]#

返回训练和验证数据加载器。

Parameters:
  • dataset – 包含环境列表的数据集类

  • envs – 包含数据集中训练域索引的列表

  • batch_size – 用于数据加载器的批量大小值

  • class_balanced – 二进制标志,指示是否在类别之间进行平衡采样

  • holdout_fraction – 用于创建验证域的训练数据的比例

  • trial_seed – 用于从训练数据生成验证分割的种子

Returns:

两个数据加载器列表,分别用于训练(train_loaders)和验证(val_loaders)

dowhy.causal_prediction.dataloaders.get_data_loader.get_train_loader(dataset, envs, batch_size, class_balanced=False)[source]#

返回训练数据加载器。

Parameters:
  • dataset – 包含环境列表的数据集类

  • envs – 包含数据集中训练域索引的列表

  • batch_size – 用于数据加载器的批量大小值

  • class_balanced – 二进制标志,指示是否在类别之间进行平衡采样

Returns:

数据加载器列表

dowhy.causal_prediction.dataloaders.misc 模块#

杂项辅助函数

dowhy.causal_prediction.dataloaders.misc.make_weights_for_balanced_classes(dataset)[来源]#
dowhy.causal_prediction.dataloaders.misc.seed_hash(*args)[source]#

从所有参数中派生一个整数哈希,用作随机种子。

dowhy.causal_prediction.dataloaders.misc.split_dataset(dataset, n, seed=0)[来源]#

返回一对数据集,对应于给定数据集的随机分割,第一个数据集包含n个数据点,其余的在最后一个数据集中,使用给定的随机种子

模块内容#