paddlespeech.t2s.datasets.sampler 模块

class paddlespeech.t2s.datasets.sampler.ErnieSATSampler(dataset, batch_size, num_replicas=None, rank=None, shuffle=False, drop_last=False)[来源]

基础: BatchSampler

限制数据加载到数据集的一个子集的采样器。 在这种情况下,每个进程可以传递一个 DistributedBatchSampler 实例 作为 DataLoader 的采样器,并加载对它独占的原始数据集的子集。 .. 注意:

Dataset is assumed to be of constant size.
Args:
dataset(paddle.io.Dataset): this could be a paddle.io.Dataset implement

或其他实现了 __len__ 的python对象,以便BatchSampler获取数据源的样本数量。

batch_size(int): mini-batch 中的样本索引数量。
num_replicas(int, optional): 分布式训练中的进程数量。

如果 num_replicas 为 None, num_replicas 将会从 paddle.distributed.ParallenEnv 获取。默认值为 None。

rank(int, optional): the rank of the current process among num_replicas

进程。如果 rank 是 None,则 rankpaddle.distributed.ParallenEnv 中获取。默认为 None。

shuffle(bool): whther to shuffle indices order before genrating

批处理索引。默认为假。

drop_last(bool): whether drop the last incomplete batch dataset size

不能被批量大小整除。默认值为假

Examples:

方法

set_epoch(epoch)

设置纪元号。当 shuffle=True 时,此数字用作随机数的种子。默认情况下,用户不能设置此值,所有副本(工作者)在每个纪元中使用不同的随机排序。如果在每个纪元中设置相同的数字,则此采样器将在所有纪元中产生相同的排序。参数: epoch (int): 纪元号。示例: .. code-block:: python.

set_epoch(epoch)[来源]

设置纪元编号。当 shuffle=True 时,此编号用作随机数的种子。默认情况下,用户可能无法设置此选项,所有副本(工作线程)在每个纪元使用不同的随机排序。如果在每个纪元设置相同的数字,则此采样器将在所有纪元中产生相同的排序。
参数:

epoch (int): 纪元号。

Examples:
import numpy as np

from paddle.io import Dataset, DistributedBatchSampler

# init with dataset
class RandomDataset(Dataset):
    def __init__(self, num_samples):
        self.num_samples = num_samples

    def __getitem__(self, idx):
        image = np.random.random([784]).astype('float32')
        label = np.random.randint(0, 9, (1, )).astype('int64')
        return image, label

    def __len__(self):
        return self.num_samples

dataset = RandomDataset(100)
sampler = DistributedBatchSampler(dataset, batch_size=64)

for epoch in range(10):
    sampler.set_epoch(epoch)