paddlespeech.t2s.datasets.sampler 模块
- class paddlespeech.t2s.datasets.sampler.ErnieSATSampler(dataset, batch_size, num_replicas=None, rank=None, shuffle=False, drop_last=False)[来源]
基础:
BatchSampler限制数据加载到数据集的一个子集的采样器。 在这种情况下,每个进程可以传递一个 DistributedBatchSampler 实例 作为 DataLoader 的采样器,并加载对它独占的原始数据集的子集。 .. 注意:
Dataset is assumed to be of constant size.
- Args:
- dataset(paddle.io.Dataset): this could be a paddle.io.Dataset implement
或其他实现了 __len__ 的python对象,以便BatchSampler获取数据源的样本数量。
batch_size(int): mini-batch 中的样本索引数量。
num_replicas(int, optional): 分布式训练中的进程数量。如果
num_replicas为 None,num_replicas将会从paddle.distributed.ParallenEnv获取。默认值为 None。- rank(int, optional): the rank of the current process among
num_replicas 进程。如果
rank是 None,则rank从paddle.distributed.ParallenEnv中获取。默认为 None。- shuffle(bool): whther to shuffle indices order before genrating
批处理索引。默认为假。
- drop_last(bool): whether drop the last incomplete batch dataset size
不能被批量大小整除。默认值为假
- Examples:
方法
set_epoch(epoch)设置纪元号。当
shuffle=True时,此数字用作随机数的种子。默认情况下,用户不能设置此值,所有副本(工作者)在每个纪元中使用不同的随机排序。如果在每个纪元中设置相同的数字,则此采样器将在所有纪元中产生相同的排序。参数: epoch (int): 纪元号。示例: .. code-block:: python.- set_epoch(epoch)[来源]
设置纪元编号。当
shuffle=True时,此编号用作随机数的种子。默认情况下,用户可能无法设置此选项,所有副本(工作线程)在每个纪元使用不同的随机排序。如果在每个纪元设置相同的数字,则此采样器将在所有纪元中产生相同的排序。
参数:epoch (int): 纪元号。
- Examples:
import numpy as np from paddle.io import Dataset, DistributedBatchSampler # init with dataset class RandomDataset(Dataset): def __init__(self, num_samples): self.num_samples = num_samples def __getitem__(self, idx): image = np.random.random([784]).astype('float32') label = np.random.randint(0, 9, (1, )).astype('int64') return image, label def __len__(self): return self.num_samples dataset = RandomDataset(100) sampler = DistributedBatchSampler(dataset, batch_size=64) for epoch in range(10): sampler.set_epoch(epoch)