paddlespeech.t2s.models.transformer_tts.transformer_tts_updater 模块
- class paddlespeech.t2s.models.transformer_tts.transformer_tts_updater.TransformerTTSEvaluator(model: Layer, dataloader: DataLoader, init_state=None, use_masking: bool = False, use_weighted_masking: bool = False, output_dir: Optional[Path] = None, bce_pos_weight: float = 5.0, loss_type: str = 'L1', use_guided_attn_loss: bool = True, modules_applied_guided_attn: Sequence[str] = 'encoder-decoder', guided_attn_loss_sigma: float = 0.4, guided_attn_loss_lambda: float = 1.0)[来源]
-
- Attributes:
- name
方法
__call__([trainer])扩展的主要操作。
finalize(trainer)训练完成时执行的操作。
initialize(trainer)执行一次以获取正确的培训师状态的操作。
on_error(trainer, exc, tb)处理训练过程中引发的错误,然后再进行最终处理。
评估
evaluate_core
- class paddlespeech.t2s.models.transformer_tts.transformer_tts_updater.TransformerTTSUpdater(model: Layer, optimizer: Optimizer, dataloader: DataLoader, init_state=None, use_masking: bool = False, use_weighted_masking: bool = False, output_dir: Optional[Path] = None, bce_pos_weight: float = 5.0, loss_type: str = 'L1', use_guided_attn_loss: bool = True, modules_applied_guided_attn: Sequence[str] = 'encoder-decoder', guided_attn_loss_sigma: float = 0.4, guided_attn_loss_lambda: float = 1.0)[来源]
基础:
StandardUpdater- Attributes:
updates_per_epoch每个时代的更新者数量,由数据加载器的长度决定。
方法
new_epoch()开始一个新的周期。
read_batch()从数据加载器中读取一批,当数据耗尽时自动续订。
set_state_dict(state_dict)为更新器设置状态字典。
state_dict()一个Updater的状态字典,包括模型、优化器和更新器状态。
update_core(batch)训练步骤的简单案例。
加载
保存
更新