Shortcuts

worker.learner

learner_hook

请参考 ding/worker/learner/learner_hook.py 以获取使用说明

钩子

class ding.worker.learner.learner_hook.Hook(name: str, priority: float, **kwargs)[源代码]
Overview:

钩子的抽象类。

Interfaces:

__init__, __call__

Property:

名称, 优先级

__init__(name: str, priority: float, **kwargs) None[source]
Overview:

钩子的初始化方法。设置名称和优先级。

Arguments:
  • 名称 (str): 钩子的名称

  • 优先级 (float): 在call_hook调用序列中使用的优先级。值越低,优先级越高。

LearnerHook

class ding.worker.learner.learner_hook.LearnerHook(*args, position: str, **kwargs)[source]
Overview:

用于Learner中的钩子的抽象类。

Interfaces:

__init__

Property:

名称, 优先级, 位置

注意

子类应实现 self.__call__

__init__(*args, position: str, **kwargs) None[source]
Overview:

初始化 LearnerHook。

Arguments:
  • 位置 (str): 在学习者中调用钩子的位置。必须在 [‘before_run’, ‘after_run’, ‘before_iter’, ‘after_iter’] 中。

LoadCkptHook

class ding.worker.learner.learner_hook.LoadCkptHook(*args, ext_args: EasyDict = {}, **kwargs)[source]
Overview:

加载检查点的钩子

Interfaces:

__init__, __call__

Property:

名称, 优先级, 位置

__call__(engine: BaseLearner) None[来源]
Overview:

加载检查点到学习器。检查点信息包括策略状态字典和迭代次数。

Arguments:
  • engine (BaseLearner): 要加载检查点的BaseLearner。

__init__(*args, ext_args: EasyDict = {}, **kwargs) None[source]
Overview:

初始化 LoadCkptHook。

Arguments:
  • ext_args (EasyDict): 扩展参数。使用 ext_args.freq 来设置 load_ckpt_freq

SaveCkptHook

class ding.worker.learner.learner_hook.SaveCkptHook(*args, ext_args: EasyDict = {}, **kwargs)[source]
Overview:

保存检查点的钩子

Interfaces:

__init__, __call__

Property:

名称, 优先级, 位置

__call__(engine: BaseLearner) None[source]
Overview:

在相应的路径中保存检查点。 检查点信息包括策略状态字典和迭代次数。

Arguments:
  • 引擎 (BaseLearner): 需要保存检查点的BaseLearner

__init__(*args, ext_args: EasyDict = {}, **kwargs) None[source]
Overview:

初始化 SaveCkptHook

Arguments:
  • ext_args (EasyDict): 扩展参数,使用 ext_args.freq 来设置 save_ckpt_freq

日志显示钩子

class ding.worker.learner.learner_hook.LogShowHook(*args, ext_args: EasyDict = {}, **kwargs)[source]
Overview:

显示日志的钩子

Interfaces:

__init__, __call__

Property:

名称, 优先级, 位置

__call__(engine: BaseLearner) None[source]
Overview:

显示日志,如果等级为0并且在间隔迭代时更新记录和tb_logger, 清除所有学习者的日志缓冲区,无论等级如何

Arguments:
  • 引擎 (BaseLearner): 基础学习器

__init__(*args, ext_args: EasyDict = {}, **kwargs) None[source]
Overview:

初始化 LogShowHook

Arguments:
  • ext_args (EasyDict): 扩展参数,使用 ext_args.freq 来设置频率

LogReduceHook

class ding.worker.learner.learner_hook.LogReduceHook(*args, ext_args: EasyDict = {}, **kwargs)[source]
Overview:

用于减少分布式(多GPU)日志的钩子

Interfaces:

__init__, __call__

Property:

名称, 优先级, 位置

__call__(engine: BaseLearner) None[source]
Overview:

减少来自分布式(多GPU)学习器的日志

Arguments:
  • 引擎 (BaseLearner): 基础学习器

__init__(*args, ext_args: EasyDict = {}, **kwargs) None[source]
Overview:

初始化 LogReduceHook

Arguments:
  • ext_args (EasyDict): 扩展参数,使用 ext_args.freq 来设置 log_reduce_freq

注册学习者钩子

Overview:

向hook_mapping添加一个新的LearnerHook类,这样你就可以使用build_learner_hook_by_cfg构建一个实例。

Arguments:
  • 名称 (str): 注册钩子的名称

  • hook_type (type): 你实现的注册hook_type,用于实现LearnerHook

Examples:
>>> class HookToRegister(LearnerHook):
>>>     def __init__(*args, **kargs):
>>>         ...
>>>         ...
>>>     def __call__(*args, **kargs):
>>>         ...
>>>         ...
>>> ...
>>> register_learner_hook('name_of_hook', HookToRegister)
>>> ...
>>> hooks = build_learner_hook_by_cfg(cfg)

通过配置构建学习者钩子

Overview:

通过配置在hook_mapping中构建学习者钩子。 此函数通常用于根据cfg初始化hooks, 而add_learner_hook()通常用于将现有的LearnerHook添加到hooks中。

Arguments:
  • cfg (EasyDict): 配置字典。应该像 {'hook': xxx} 这样。

Returns:
  • 钩子 (Dict[str, List[Hook]): 键应该在 ['before_run', 'after_run', 'before_iter', 'after_iter'] 中,每个值应该是一个包含该位置所有钩子的列表。

Note:

值越低意味着优先级越高。

merge_hooks

Overview:

合并两个具有相同键的钩子字典,每个值都通过稳定的方法按钩子优先级排序。

Arguments:
  • hooks1 (Dict[str, List[Hook]): 要合并的hooks1。

  • hooks2 (Dict[str, List[Hook]): 要合并的hooks2。

Returns:
  • new_hooks (Dict[str, List[Hook]): 新的合并钩子字典。

Note:

此合并函数使用稳定排序方法,不会干扰相同优先级的钩子。

基础学习器

请参考 ding/worker/learner/base_learner.py 以获取使用说明

基础学习器

class ding.worker.learner.base_learner.BaseLearner(cfg: EasyDict, policy: namedtuple = None, tb_logger: SummaryWriter | None = None, dist_info: Tuple[int, int] = None, exp_name: str | None = 'default_experiment', instance_name: str | None = 'learner')[source]
Overview:

策略学习的基础类。

Interface:

训练, 调用钩子, 注册钩子, 保存检查点, 开始, 设置数据加载器, 关闭

Property:

学习信息, 优先级信息, 上次迭代, 训练迭代, 排名, 世界大小, 策略 监控器, 日志缓冲区, 日志记录器, tb日志记录器, 检查点名称, 实验名称, 实例名称

__init__(cfg: EasyDict, policy: namedtuple = None, tb_logger: SummaryWriter | None = None, dist_info: Tuple[int, int] = None, exp_name: str | None = 'default_experiment', instance_name: str | None = 'learner') None[来源]
Overview:

初始化方法,根据配置构建常见的学习者组件,例如钩子、包装器等。

Arguments:
  • cfg (EasyDict): 学习器配置,你可以参考 cls.config 获取详细信息。

  • 策略 (namedtuple): 学习模式下的策略函数集合。策略也可以在运行时初始化。

  • tb_logger (SummaryWriter): Tensorboard 摘要写入器。

  • dist_info (Tuple[int, int]): 多GPU分布式训练信息。

  • exp_name (str): 实验名称,用于指示输出目录。

  • instance_name (str): 实例名称,应在不同的学习者中唯一。

Notes:

如果你想在同步CUDA模式下进行调试,请在__init__的开头添加以下代码。

os.environ['CUDA_LAUNCH_BLOCKING'] = "1"  # for debug async CUDA
_setup_hook() None[source]
Overview:

为base_learner设置钩子。钩子是在base_learner中特定时间点实现某些功能的方式。你可以参考learner_hook.py

_setup_wrapper() None[source]
Overview:

使用 _time_wrapper 来获取 train_time

Note:

data_time 被包裹在 setup_dataloader 中。

call_hook(name: str) None[source]
Overview:

根据位置名称调用相应的钩子插件。

Arguments:
  • name (str): 钩子在哪个位置调用,应在 [‘before_run’, ‘after_run’, ‘before_iter’, ‘after_iter’] 中。

close() None[来源]
Overview:

[仅在并行模式下使用] 关闭相关资源,例如数据加载器、TensorBoard记录器等。

register_hook(hook: LearnerHook) None[source]
Overview:

添加一个新的学习器钩子。

Arguments:
  • hook (LearnerHook): 要添加的钩子。

save_checkpoint(ckpt_name: str | None = None) None[source]
Overview:

直接调用 save_ckpt_after_run 钩子来保存检查点。

Note:

必须确保“save_ckpt_after_run”在“after_run”钩子中注册。 此方法在以下情况下调用:

  • auto_checkpoint (torch_utils/checkpoint_helper.py),该功能设计用于在异常发生时保存检查点。

  • serial_pipeline (entry/serial_entry.py). 用于在达到新的最高回合回报时保存检查点。

setup_dataloader() None[source]
Overview:

[仅在并行模式下使用] 设置学习者的数据加载器。

注意

只有在并行模式下,我们才会使用属性 get_data_dataloader 从文件系统中获取数据; 而在串行版本中,我们可以直接从内存中获取数据。

在并行模式下,get_dataLearnerCommHelper 设置,并且应该是可调用的。 如果不需要,用户不需要了解相关细节。

train(data: dict, envstep: int = -1, policy_kwargs: dict | None = None) None[source]
Overview:

给定训练数据,实现一次迭代的网络更新并更新相关变量。 学习者的API用于串行输入。 也在每次迭代训练时在start中调用。

Arguments:
  • 数据 (dict): 从回放缓冲区检索的训练数据。

注意

_policy 必须在调用此方法之前设置。

_policy.forward 方法包含:前向传播、反向传播、梯度同步(如果在多GPU模式下)和参数更新。

before_iterafter_iter 钩子在开始和结束时被调用。

create_learner

Overview:

给定键(learner_name),如果在learner_mapping的值中,则创建一个新的学习者实例,否则抛出KeyError。换句话说,派生学习者必须先注册,然后才能调用create_learner来获取实例。

Arguments:
  • cfg (EasyDict): 学习器配置。必要的键:[learner.import_module, learner.learner_type]。

Returns:
  • 学习者 (BaseLearner): 创建的新学习者,应该是learner_mapping值之一的实例。