时间同步批采样器#

class pytorch_forecasting.data.samplers.TimeSynchronizedBatchSampler(sampler: Sampler, batch_size: int = 64, shuffle: bool = False, drop_last: bool = False)[来源]#

基础: GroupedSampler

随机地以时间同步的方式对小批量样本进行抽样。

时间同步意味着第一个解码器样本的时间索引在整个批次中对齐。 该采样器不支持数据集中缺失值。

初始化。

Parameters:
  • sampler (采样器可迭代对象) – 基础采样器。可以是任何可迭代对象

  • drop_last (bool) – 如果最后一个小批量的样本数小于 batch_size,是否从一组中删除最后一个小批量。默认为 False。

  • shuffle (bool) – 是否打乱数据集。默认为 False。

  • batch_size (int, 可选) – 一个小批量中的样本数量。这实际上是样本的最大数量。由于小批量是按预测时间分组的,因此批量大小可能会小于最大值。默认为64。

Inherited-members:

方法

construct_batch_groups(groups)

构建可以被抽样的批次索引

get_groups(sampler)

创建可以进行采样的组。

get_groups(sampler: Sampler)[来源]#

创建可以进行采样的组。

Parameters:

采样器 (Sampler) – 将具有属性 data_source,其类型为 TimeSeriesDataSet。

Returns:

类似字典的对象,数据源索引作为值,组名称作为键

Return type:

字典类型