分组采样器#

class pytorch_forecasting.data.samplers.GroupedSampler(sampler: Sampler, batch_size: int = 64, shuffle: bool = False, drop_last: bool = False)[来源]#

基础: Sampler

以分组的方式随机抽取小批量样本。

这意味着来自不同组的项目总是一起抽样。 这是一个抽象类。 实现get_groups()方法,该方法创建要进行抽样的组。

初始化。

Parameters:
  • sampler (采样器可迭代对象) – 基础采样器。可以是任何可迭代对象

  • drop_last (bool) – 如果最后一个小批量的样本数小于 batch_size,是否从一组中删除最后一个小批量。默认为 False。

  • shuffle (bool) – 是否打乱数据集。默认为 False。

  • batch_size (int, 可选) – 一个小批量中的样本数量。这实际上是样本的最大数量。由于小批量是按预测时间分组的,因此批量大小可能会小于最大值。默认为64。

Inherited-members:

方法

construct_batch_groups(groups)

构建可以被抽样的批次索引

get_groups(sampler)

创建可以进行采样的组。

construct_batch_groups(groups)[来源]#

构建可以被抽样的批次索引

get_groups(sampler: Sampler)[来源]#

创建可以进行采样的组。

Parameters:

采样器 (Sampler) – 将具有属性 data_source,其类型为 TimeSeriesDataSet。

Returns:

类似字典的对象,数据源索引作为值,组名称作为键

Return type:

字典类型