灵活的权重检查点

该模块包含用于决定何时写入和清除检查点的方法。

警告

虽然这个模块提供了一种灵活且模块化的方式来描述所需的检查点行为,但它目前只存储模型的权重(更准确地说,是它的torch.nn.Module.state_dict())。因此,它尚未取代常规检查点中描述的完整训练循环检查点机制。

它由两个主要组件组成:检查点计划决定是否在给定的时期写入检查点。 如果我们有多个检查点,我们可以使用多个保留策略来决定保留哪些检查点以及 哪些要丢弃。对于这两者,我们提供了一组基本规则,以及通过联合组合它们的方式。 这些应该足以轻松建模大多数所需的检查点行为。

示例

下面你可以找到一些如何在训练管道中使用它们的示例。 如果你想在实际训练前检查(静态)检查点计划的行为, 你可以查看 pykeen.checkpoints.final_checkpoints()pykeen.checkpoints.simulate_checkpoints()

为了减少必要的导入次数,示例都使用字典/字符串来指定组件,而不是传递类或实际实例。 您可以在使用解析器中找到关于解析的更多信息。 调度组件的解析器是pykeen.checkpoints.schedule.schedule_resolver, 而保管组件的解析器是pykeen.checkpoints.keeper_resolver

示例 1

"""Write a checkpoint every 10 steps and keep them all."""

from pykeen.pipeline import pipeline

result = pipeline(
    dataset="nations",
    model="mure",
    training_kwargs=dict(
        num_epochs=100,
        callbacks="checkpoint",
        # create one checkpoint every 10 epochs
        callbacks_kwargs=dict(
            schedule="every",
            schedule_kwargs=dict(
                frequency=10,
            ),
        ),
    ),
)

示例 2

"""Write a checkpoint at epoch 1, 7, and 10 and keep them all."""

from pykeen.pipeline import pipeline

result = pipeline(
    dataset="nations",
    model="mure",
    training_kwargs=dict(
        num_epochs=10,
        callbacks="checkpoint",
        # create checkpoints at epoch 1, 7, and 10
        callbacks_kwargs=dict(
            schedule="explicit",
            schedule_kwargs=dict(steps=(1, 7, 10)),
        ),
    ),
)

示例 3

"""Write a checkpoint avery 5 epochs, but also at epoch 7."""

from pykeen.pipeline import pipeline

result = pipeline(
    dataset="nations",
    model="mure",
    training_kwargs=dict(
        num_epochs=10,
        callbacks="checkpoint",
        callbacks_kwargs=dict(
            schedule="union",
            # create checkpoints every 5 epochs, and at epoch 7
            schedule_kwargs=dict(bases=["every", "explicit"], bases_kwargs=[dict(frequency=5), dict(steps=[7])]),
        ),
    ),
)

示例 4

"""Write a checkpoint whenever a metric improves (here, just the training loss)."""

from pykeen.checkpoints import MetricSelection
from pykeen.pipeline import pipeline
from pykeen.trackers import tracker_resolver

# create a default result tracker (or use a proper one)
result_tracker = tracker_resolver.make(None)
result = pipeline(
    dataset="nations",
    model="mure",
    training_kwargs=dict(
        num_epochs=10,
        callbacks="checkpoint",
        callbacks_kwargs=dict(
            schedule="best",
            schedule_kwargs=dict(
                result_tracker=result_tracker,
                # in this example, we just use the training loss
                metric_selection=MetricSelection(
                    metric="loss",
                    maximize=False,
                ),
            ),
        ),
    ),
    # Important: use the same result tracker instance as in the checkpoint callback
    result_tracker=result_tracker,
)

示例 5

"""Write a checkpoint every 10 steps, but keep only the last one and one every 50 steps."""

from pykeen.pipeline import pipeline

result = pipeline(
    dataset="nations",
    model="mure",
    training_kwargs=dict(
        num_epochs=100,
        callbacks="checkpoint",
        # create one checkpoint every 10 epochs
        callbacks_kwargs=dict(
            schedule="every",
            schedule_kwargs=dict(
                frequency=10,
            ),
            keeper="union",
            keeper_kwargs=dict(
                bases=["modulo", "last"],
                bases_kwargs=[dict(divisor=50), None],
            ),
        ),
    ),
)

函数

save_model(model, file)

将模型保存到文件中。

simulate_checkpoints([num_epochs, schedule, ...])

模拟一个检查点计划并打印有关检查点的信息。

final_checkpoints([num_epochs, schedule, ...])

模拟一个检查点计划并返回保留检查点的时期集合。

CheckpointSchedule()

检查点计划的接口。

EveryCheckpointSchedule([frequency])

\(n\)步创建一个检查点。

ExplicitCheckpointSchedule(steps)

为明确选择的步骤创建一个检查点。

BestCheckpointSchedule(result_tracker, ...)

每当指标改善时创建一个检查点。

UnionCheckpointSchedule(bases[, bases_kwargs])

每当基础计划之一需要时,创建一个检查点。

CheckpointKeeper()

一个检查点清理接口。

LastCheckpointKeeper([keep])

保留最后\(n\)个检查点。

ModuloCheckpointKeeper([divisor])

如果步骤可以被一个数字整除,则保留检查点。

ExplicitCheckpointKeeper(keep)

在明确的步骤中保留检查点。

BestCheckpointKeeper(result_tracker, ...)

保留达到指标最佳值的步骤的检查点。

UnionCheckpointKeeper(bases[, bases_kwargs])

在满足其中一个条件时保留一个检查点。

MetricSelection(metric[, prefix, maximize])

选择要监控的指标。

类继承图

Inheritance diagram of pykeen.checkpoints.schedule.CheckpointSchedule, pykeen.checkpoints.schedule.EveryCheckpointSchedule, pykeen.checkpoints.schedule.ExplicitCheckpointSchedule, pykeen.checkpoints.schedule.BestCheckpointSchedule, pykeen.checkpoints.schedule.UnionCheckpointSchedule, pykeen.checkpoints.keeper.CheckpointKeeper, pykeen.checkpoints.keeper.LastCheckpointKeeper, pykeen.checkpoints.keeper.ModuloCheckpointKeeper, pykeen.checkpoints.keeper.ExplicitCheckpointKeeper, pykeen.checkpoints.keeper.BestCheckpointKeeper, pykeen.checkpoints.keeper.UnionCheckpointKeeper, pykeen.checkpoints.utils.MetricSelection