paddlespeech.s2t.training.extensions.plot 模块
- class paddlespeech.s2t.training.extensions.plot.PlotAttentionReport(att_vis_fn, data, outdir, converter, transform, device, reverse=False, ikey='input', iaxis=0, okey='output', oaxis=0, subsampling_factor=1)[来源]
基础:
Extension绘制注意力报告。
- Args:
- att_vis_fn (espnet.nets.*_backend.e2e_asr.E2E.calculate_all_attentions):
注意力可视化的功能。
data (list[tuple(str, dict[str, list[Any]])]): 列表 json utt 键项。 outdir (str): 保存图形的目录。 converter (espnet.asr.*_backend.asr.CustomConverter):
转换数据的函数。
device (int | torch.device): 设备。
reverse (bool): 如果为 True,输入和输出长度将被反转。
ikey (str): 用于访问输入的密钥(用于ASR/ST ikey="input",用于MT ikey="output"。)
- iaxis (int): Dimension to access input
(对于ASR/ST iaxis=0,对于MT iaxis=1.)
- okey (str): Key to access output
(对于 ASR/ST okey="input",MT okay="output".)
- oaxis (int): Dimension to access output
(对于ASR/ST oaxis=0,对于MT oaxis=0.)
subsampling_factor (int): 编码器中的下采样因子
- Attributes:
default_name扩展的默认名称,默认的类名。
- name
方法
__call__(trainer)绘制并保存att_ws矩阵的图像文件。
draw_attention_plot(att_w)绘制 att_w 矩阵。
draw_han_plot(att_w)绘制层次注意力的 att_w 矩阵。
finalize(trainer)训练完成时执行的操作。
返回注意力权重。
initialize(trainer)执行一次以获取正确的培训师状态的操作。
log_attentions(logger, step)将att_ws矩阵的图像文件添加到tensorboard。
on_error(trainer, exc, tb)处理训练过程中引发的错误,然后再进行最终处理。
trim_attention_weight(uttid, att_w)根据self.reverse变换注意力矩阵。
- class paddlespeech.s2t.training.extensions.plot.PlotCTCReport(ctc_vis_fn, data, outdir, converter, transform, device, reverse=False, ikey='input', iaxis=0, okey='output', oaxis=0, subsampling_factor=1)[来源]
基础:
Extension绘制CTC报告。
- Args:
- ctc_vis_fn (espnet.nets.*_backend.e2e_asr.E2E.calculate_all_ctc_probs):
CTC可视化的功能。
data (list[tuple(str, dict[str, list[Any]])]): 列表 json utt 键项。 outdir (str): 保存图形的目录。 converter (espnet.asr.*_backend.asr.CustomConverter):
转换数据的函数。
device (int | torch.device): 设备。reverse (bool): 如果为 True,输入和输出长度将被反转。ikey (str): 用于访问输入的密钥
(用于ASR/ST ikey="input",用于MT ikey="output"。)
- iaxis (int): Dimension to access input
(对于ASR/ST iaxis=0,对于MT iaxis=1.)
- okey (str): Key to access output
(对于 ASR/ST okey="input",MT okay="output".)
- oaxis (int): Dimension to access output
(对于ASR/ST oaxis=0,对于MT oaxis=0.)
subsampling_factor (int):编码器中的下采样因子
- Attributes:
default_name扩展的默认名称,默认的类名。
- name
方法
__call__(trainer)绘制并保存ctc概率的图像文件。
draw_ctc_plot(ctc_prob)绘制ctc_prob矩阵。
finalize(trainer)训练完成时执行的操作。
返回 CTC 概率。
initialize(trainer)执行一次以获取正确的培训师状态的操作。
log_ctc_probs(logger, step)将ctc概率的图像文件添加到tensorboard中。
on_error(trainer, exc, tb)处理训练过程中引发的错误,然后再进行最终处理。
trim_ctc_prob(uttid, prob)根据输入长度修剪CTC后验。