paddlespeech.s2t.training.extensions.plot 模块

class paddlespeech.s2t.training.extensions.plot.PlotAttentionReport(att_vis_fn, data, outdir, converter, transform, device, reverse=False, ikey='input', iaxis=0, okey='output', oaxis=0, subsampling_factor=1)[来源]

基础: Extension

绘制注意力报告。

Args:
att_vis_fn (espnet.nets.*_backend.e2e_asr.E2E.calculate_all_attentions):

注意力可视化的功能。

data (list[tuple(str, dict[str, list[Any]])]): 列表 json utt 键项。 outdir (str): 保存图形的目录。 converter (espnet.asr.*_backend.asr.CustomConverter):

转换数据的函数。

device (int | torch.device): 设备。
reverse (bool): 如果为 True,输入和输出长度将被反转。
ikey (str): 用于访问输入的密钥

(用于ASR/ST ikey="input",用于MT ikey="output"。)

iaxis (int): Dimension to access input

(对于ASR/ST iaxis=0,对于MT iaxis=1.)

okey (str): Key to access output

(对于 ASR/ST okey="input",MT okay="output".)

oaxis (int): Dimension to access output

(对于ASR/ST oaxis=0,对于MT oaxis=0.)

subsampling_factor (int): 编码器中的下采样因子

Attributes:
default_name

扩展的默认名称,默认的类名。

name

方法

__call__(trainer)

绘制并保存att_ws矩阵的图像文件。

draw_attention_plot(att_w)

绘制 att_w 矩阵。

draw_han_plot(att_w)

绘制层次注意力的 att_w 矩阵。

finalize(trainer)

训练完成时执行的操作。

get_attention_weights()

返回注意力权重。

initialize(trainer)

执行一次以获取正确的培训师状态的操作。

log_attentions(logger, step)

将att_ws矩阵的图像文件添加到tensorboard。

on_error(trainer, exc, tb)

处理训练过程中引发的错误,然后再进行最终处理。

trim_attention_weight(uttid, att_w)

根据self.reverse变换注意力矩阵。

draw_attention_plot(att_w)[来源]

绘制 att_w 矩阵。

Returns:

matplotlib.pyplot: 带有注意力矩阵图像的pyplot对象。

draw_han_plot(att_w)[来源]

绘制层次注意力的 att_w 矩阵。

Returns:

matplotlib.pyplot: 带有注意力矩阵图像的pyplot对象。

get_attention_weights()[来源]

返回注意力权重。

Returns:
numpy.ndarray: attention weights. float. Its shape would be

与后端不同。
* pytorch-> 1) 多头情况 => (B, H, Lmax, Tmax), 2)

其他情况 => (B, Lmax, Tmax).

  • chainer-> (B, Lmax, Tmax)

log_attentions(logger, step)[来源]

将att_ws矩阵的图像文件添加到tensorboard。

trim_attention_weight(uttid, att_w)[来源]

转换注意力矩阵以考虑自我反向。

class paddlespeech.s2t.training.extensions.plot.PlotCTCReport(ctc_vis_fn, data, outdir, converter, transform, device, reverse=False, ikey='input', iaxis=0, okey='output', oaxis=0, subsampling_factor=1)[来源]

基础: Extension

绘制CTC报告。

Args:
ctc_vis_fn (espnet.nets.*_backend.e2e_asr.E2E.calculate_all_ctc_probs):

CTC可视化的功能。

data (list[tuple(str, dict[str, list[Any]])]): 列表 json utt 键项。 outdir (str): 保存图形的目录。 converter (espnet.asr.*_backend.asr.CustomConverter):

转换数据的函数。

device (int | torch.device): 设备。reverse (bool): 如果为 True,输入和输出长度将被反转。ikey (str): 用于访问输入的密钥

(用于ASR/ST ikey="input",用于MT ikey="output"。)

iaxis (int): Dimension to access input

(对于ASR/ST iaxis=0,对于MT iaxis=1.)

okey (str): Key to access output

(对于 ASR/ST okey="input",MT okay="output".)

oaxis (int): Dimension to access output

(对于ASR/ST oaxis=0,对于MT oaxis=0.)

subsampling_factor (int):编码器中的下采样因子

Attributes:
default_name

扩展的默认名称,默认的类名。

name

方法

__call__(trainer)

绘制并保存ctc概率的图像文件。

draw_ctc_plot(ctc_prob)

绘制ctc_prob矩阵。

finalize(trainer)

训练完成时执行的操作。

get_ctc_probs()

返回 CTC 概率。

initialize(trainer)

执行一次以获取正确的培训师状态的操作。

log_ctc_probs(logger, step)

将ctc概率的图像文件添加到tensorboard中。

on_error(trainer, exc, tb)

处理训练过程中引发的错误,然后再进行最终处理。

trim_ctc_prob(uttid, prob)

根据输入长度修剪CTC后验。

draw_ctc_plot(ctc_prob)[来源]

绘制 ctc_prob 矩阵。

Returns:

matplotlib.pyplot: 带有CTC概率矩阵图像的pyplot对象。

get_ctc_probs()[来源]

返回 CTC 概率。

Returns:
numpy.ndarray: CTC probs. float. Its shape would be

与后端不同。 (B, Tmax, vocab).

log_ctc_probs(logger, step)[来源]

将ctc概率的图像文件添加到tensorboard。

trim_ctc_prob(uttid, prob)[来源]

根据输入长度修剪CTC后验。