paddlespeech.s2t.io.batchfy模块

paddlespeech.s2t.io.batchfy.make_batchset(data, batch_size=0, max_length_in=inf, max_length_out=inf, num_batches=0, min_batch_size=1, shortest_first=False, batch_sort_key='input', count='auto', batch_bins=0, batch_frames_in=0, batch_frames_out=0, batch_frames_inout=0, iaxis=0, oaxis=0)[来源]

从json字典中创建批量集

如果utts具有"category"值,

>>> data = [{'category': 'A', 'input': ..., 'utt':'utt1'},
...         {'category': 'B', 'input': ..., 'utt':'utt2'},
...         {'category': 'B', 'input': ..., 'utt':'utt3'},
...         {'category': 'A', 'input': ..., 'utt':'utt4'}]
>>> make_batchset(data, batchsize=2, ...)
[[('utt1', ...), ('utt4', ...)], [('utt2', ...), ('utt3': ...)]]

请注意,如果任何 utts 没有 "category",执行方式与 batchfy_by_{count} 相同

Parameters:
  • data (List[Dict[str, Any]]) -- 从 data.json 加载的字典

  • batch_size (int) -- 一次小批量中序列的最大数量。

  • batch_bins (int) -- mini批次中最大箱子的数量(帧 x 维度)。

  • batch_frames_in (int) -- 每个小批量中的最大输入帧数。

  • batch_frames_out (int) -- 每个小批量中最大输出帧数。

  • batch_frames_out -- 小批量中输入+输出帧的最大数量。

  • count (str) -- 计算批处理最大大小的策略。
    有关选项,请参见 io.batchfy.BATCH_COUNT_CHOICES

  • max_length_in (int) -- 决定自适应批量大小的输入最大长度

  • max_length_out (int) -- 决定自适应批量大小的输出最大长度

  • num_batches (int) -- # 使用的批次数(用于调试)

  • min_batch_size (int) -- 最小批量大小(用于多GPU)

  • shortest_first (bool) -- 如果为真,按样本从最短到最长排序,否则反向排序

  • batch_sort_key (str) -- 在创建小批量数据之前如何排序数据 ["input", "output", "shuffle"]

  • swap_io (bool) -- 如果为 True,使用 "input" 作为输出,"output" 作为输入在 data 字典中

  • mt (bool) -- 如果为 True,则使用 "output" 的 0 轴作为输出,"output" 的 1 轴作为输入,在 data 字典中

  • iaxis (int) -- 用于访问输入的维度 (对于ASR和TTS,iaxis=0,对于MT,iaxis="1".)

  • oaxis (int) -- 访问输出的维度(对于ASR,TTS,MT,oaxis=0,保留用于未来研究,-1 意味着所有轴。)

Returns:

List[List[Tuple[str, dict]]] 批次列表