时间同步批采样器#
- class pytorch_forecasting.data.samplers.TimeSynchronizedBatchSampler(sampler: Sampler, batch_size: int = 64, shuffle: bool = False, drop_last: bool = False)[来源]#
基础:
GroupedSampler
随机地以时间同步的方式对小批量样本进行抽样。
时间同步意味着第一个解码器样本的时间索引在整个批次中对齐。 该采样器不支持数据集中缺失值。
初始化。
- Parameters:
sampler (采样器 或 可迭代对象) – 基础采样器。可以是任何可迭代对象
drop_last (bool) – 如果最后一个小批量的样本数小于 batch_size,是否从一组中删除最后一个小批量。默认为 False。
shuffle (bool) – 是否打乱数据集。默认为 False。
batch_size (int, 可选) – 一个小批量中的样本数量。这实际上是样本的最大数量。由于小批量是按预测时间分组的,因此批量大小可能会小于最大值。默认为64。
- Inherited-members:
方法
construct_batch_groups
(groups)构建可以被抽样的批次索引
get_groups
(sampler)创建可以进行采样的组。