分组采样器#
- class pytorch_forecasting.data.samplers.GroupedSampler(sampler: Sampler, batch_size: int = 64, shuffle: bool = False, drop_last: bool = False)[来源]#
基础:
Sampler
以分组的方式随机抽取小批量样本。
这意味着来自不同组的项目总是一起抽样。 这是一个抽象类。 实现
get_groups()
方法,该方法创建要进行抽样的组。初始化。
- Parameters:
sampler (采样器 或 可迭代对象) – 基础采样器。可以是任何可迭代对象
drop_last (bool) – 如果最后一个小批量的样本数小于 batch_size,是否从一组中删除最后一个小批量。默认为 False。
shuffle (bool) – 是否打乱数据集。默认为 False。
batch_size (int, 可选) – 一个小批量中的样本数量。这实际上是样本的最大数量。由于小批量是按预测时间分组的,因此批量大小可能会小于最大值。默认为64。
- Inherited-members:
方法
construct_batch_groups
(groups)构建可以被抽样的批次索引
get_groups
(sampler)创建可以进行采样的组。