speechbrain.utils.epoch_loop 模块

实现一个可检查点的epoch计数器(循环),可选择集成早停功能。

Authors
  • 阿库·柔赫 2020

  • 达维德·博拉 2021

摘要

类:

EpochCounter

一个可以保存和恢复其状态的时期计数器。

EpochCounterWithStopper

一个可以保存和恢复其状态的epoch计数器,通过跟踪目标指标集成早期停止器。

参考

class speechbrain.utils.epoch_loop.EpochCounter(limit)[source]

基础类:object

一个可以保存和恢复其状态的时期计数器。

将此用作迭代器进行epochs。 请注意,此迭代器为您提供从[1 … limit]的数字,而不是 [0 … limit-1] 如 range(limit) 会提供的。

Parameters:

limit (int) – 最大轮次数

Example

>>> from speechbrain.utils.checkpoints import Checkpointer
>>> tmpdir = getfixture('tmpdir')
>>> epoch_counter = EpochCounter(10)
>>> recoverer = Checkpointer(tmpdir, {"epoch": epoch_counter})
>>> recoverer.recover_if_possible()
>>> # Now after recovery,
>>> # the epoch starts from where it left off!
>>> for epoch in epoch_counter:
...     # Run training...
...     ckpt = recoverer.save_checkpoint()
class speechbrain.utils.epoch_loop.EpochCounterWithStopper(limit, limit_to_stop, limit_warmup, direction)[source]

基础类: EpochCounter

一个可以保存和恢复其状态的时期计数器,通过跟踪目标指标集成早期停止器。

Parameters:
  • limit (int) – 最大轮次数

  • limit_to_stop (int) – 性能没有改善的最大连续周期数

  • limit_warmup (int) – 在开始检查早停之前需要等待的epoch数

  • direction ("max""min") – 优化目标指标的方向

Example

>>> limit = 10
>>> limit_to_stop = 5
>>> limit_warmup = 2
>>> direction = "min"
>>> epoch_counter = EpochCounterWithStopper(limit, limit_to_stop, limit_warmup, direction)
>>> for epoch in epoch_counter:
...     # Run training...
...     # Track a validation metric, (insert calculation here)
...     current_valid_metric = 0
...     # Update epoch counter so that we stop at the appropriate time
...     epoch_counter.update_metric(current_valid_metric)
...     print(epoch)
1
2
3
4
5
6
7
8
__next__()[source]

如果我们已经达到条件,则停止迭代。

update_metric(current_metric)[source]

更新状态以反映相关指标的最新值。

注意:每个验证循环应仅调用一次。

Parameters:

current_metric (float) – 用于做出停止决策的指标。