worker.learner¶
learner_hook¶
请参考 ding/worker/learner/learner_hook.py 以获取使用说明
钩子¶
LearnerHook¶
LoadCkptHook¶
- class ding.worker.learner.learner_hook.LoadCkptHook(*args, ext_args: EasyDict = {}, **kwargs)[source]¶
- Overview:
加载检查点的钩子
- Interfaces:
__init__, __call__
- Property:
名称, 优先级, 位置
- __call__(engine: BaseLearner) None[来源]¶
- Overview:
加载检查点到学习器。检查点信息包括策略状态字典和迭代次数。
- Arguments:
engine (
BaseLearner): 要加载检查点的BaseLearner。
SaveCkptHook¶
- class ding.worker.learner.learner_hook.SaveCkptHook(*args, ext_args: EasyDict = {}, **kwargs)[source]¶
- Overview:
保存检查点的钩子
- Interfaces:
__init__, __call__
- Property:
名称, 优先级, 位置
- __call__(engine: BaseLearner) None[source]¶
- Overview:
在相应的路径中保存检查点。 检查点信息包括策略状态字典和迭代次数。
- Arguments:
引擎 (
BaseLearner): 需要保存检查点的BaseLearner
日志显示钩子¶
- class ding.worker.learner.learner_hook.LogShowHook(*args, ext_args: EasyDict = {}, **kwargs)[source]¶
- Overview:
显示日志的钩子
- Interfaces:
__init__, __call__
- Property:
名称, 优先级, 位置
- __call__(engine: BaseLearner) None[source]¶
- Overview:
显示日志,如果等级为0并且在间隔迭代时更新记录和tb_logger, 清除所有学习者的日志缓冲区,无论等级如何
- Arguments:
引擎 (
BaseLearner): 基础学习器
LogReduceHook¶
- class ding.worker.learner.learner_hook.LogReduceHook(*args, ext_args: EasyDict = {}, **kwargs)[source]¶
- Overview:
用于减少分布式(多GPU)日志的钩子
- Interfaces:
__init__, __call__
- Property:
名称, 优先级, 位置
- __call__(engine: BaseLearner) None[source]¶
- Overview:
减少来自分布式(多GPU)学习器的日志
- Arguments:
引擎 (
BaseLearner): 基础学习器
注册学习者钩子¶
- Overview:
向hook_mapping添加一个新的LearnerHook类,这样你就可以使用build_learner_hook_by_cfg构建一个实例。
- Arguments:
名称 (
str): 注册钩子的名称hook_type (
type): 你实现的注册hook_type,用于实现LearnerHook
- Examples:
>>> class HookToRegister(LearnerHook): >>> def __init__(*args, **kargs): >>> ... >>> ... >>> def __call__(*args, **kargs): >>> ... >>> ... >>> ... >>> register_learner_hook('name_of_hook', HookToRegister) >>> ... >>> hooks = build_learner_hook_by_cfg(cfg)
通过配置构建学习者钩子¶
- Overview:
通过配置在hook_mapping中构建学习者钩子。 此函数通常用于根据cfg初始化
hooks, 而add_learner_hook()通常用于将现有的LearnerHook添加到hooks中。- Arguments:
cfg (
EasyDict): 配置字典。应该像 {'hook': xxx} 这样。
- Returns:
钩子 (
Dict[str, List[Hook]): 键应该在 ['before_run', 'after_run', 'before_iter', 'after_iter'] 中,每个值应该是一个包含该位置所有钩子的列表。
- Note:
值越低意味着优先级越高。
merge_hooks¶
- Overview:
合并两个具有相同键的钩子字典,每个值都通过稳定的方法按钩子优先级排序。
- Arguments:
hooks1 (
Dict[str, List[Hook]): 要合并的hooks1。hooks2 (
Dict[str, List[Hook]): 要合并的hooks2。
- Returns:
new_hooks (
Dict[str, List[Hook]): 新的合并钩子字典。
- Note:
此合并函数使用稳定排序方法,不会干扰相同优先级的钩子。
基础学习器¶
请参考 ding/worker/learner/base_learner.py 以获取使用说明
基础学习器¶
- class ding.worker.learner.base_learner.BaseLearner(cfg: EasyDict, policy: namedtuple = None, tb_logger: SummaryWriter | None = None, dist_info: Tuple[int, int] = None, exp_name: str | None = 'default_experiment', instance_name: str | None = 'learner')[source]¶
- Overview:
策略学习的基础类。
- Interface:
训练, 调用钩子, 注册钩子, 保存检查点, 开始, 设置数据加载器, 关闭
- Property:
学习信息, 优先级信息, 上次迭代, 训练迭代, 排名, 世界大小, 策略 监控器, 日志缓冲区, 日志记录器, tb日志记录器, 检查点名称, 实验名称, 实例名称
- __init__(cfg: EasyDict, policy: namedtuple = None, tb_logger: SummaryWriter | None = None, dist_info: Tuple[int, int] = None, exp_name: str | None = 'default_experiment', instance_name: str | None = 'learner') None[来源]¶
- Overview:
初始化方法,根据配置构建常见的学习者组件,例如钩子、包装器等。
- Arguments:
cfg (
EasyDict): 学习器配置,你可以参考 cls.config 获取详细信息。策略 (
namedtuple): 学习模式下的策略函数集合。策略也可以在运行时初始化。tb_logger (
SummaryWriter): Tensorboard 摘要写入器。dist_info (
Tuple[int, int]): 多GPU分布式训练信息。exp_name (
str): 实验名称,用于指示输出目录。instance_name (
str): 实例名称,应在不同的学习者中唯一。
- Notes:
如果你想在同步CUDA模式下进行调试,请在
__init__的开头添加以下代码。os.environ['CUDA_LAUNCH_BLOCKING'] = "1" # for debug async CUDA
- _setup_hook() None[source]¶
- Overview:
为base_learner设置钩子。钩子是在base_learner中特定时间点实现某些功能的方式。你可以参考
learner_hook.py。
- _setup_wrapper() None[source]¶
- Overview:
使用
_time_wrapper来获取train_time。- Note:
data_time被包裹在setup_dataloader中。
- call_hook(name: str) None[source]¶
- Overview:
根据位置名称调用相应的钩子插件。
- Arguments:
name (
str): 钩子在哪个位置调用,应在 [‘before_run’, ‘after_run’, ‘before_iter’, ‘after_iter’] 中。
- register_hook(hook: LearnerHook) None[source]¶
- Overview:
添加一个新的学习器钩子。
- Arguments:
hook (
LearnerHook): 要添加的钩子。
- save_checkpoint(ckpt_name: str | None = None) None[source]¶
- Overview:
直接调用
save_ckpt_after_run钩子来保存检查点。- Note:
必须确保“save_ckpt_after_run”在“after_run”钩子中注册。 此方法在以下情况下调用:
auto_checkpoint(torch_utils/checkpoint_helper.py),该功能设计用于在异常发生时保存检查点。serial_pipeline(entry/serial_entry.py). 用于在达到新的最高回合回报时保存检查点。
- setup_dataloader() None[source]¶
- Overview:
[仅在并行模式下使用] 设置学习者的数据加载器。
注意
只有在并行模式下,我们才会使用属性
get_data和_dataloader从文件系统中获取数据; 而在串行版本中,我们可以直接从内存中获取数据。在并行模式下,
get_data由LearnerCommHelper设置,并且应该是可调用的。 如果不需要,用户不需要了解相关细节。
- train(data: dict, envstep: int = -1, policy_kwargs: dict | None = None) None[source]¶
- Overview:
给定训练数据,实现一次迭代的网络更新并更新相关变量。 学习者的API用于串行输入。 也在每次迭代训练时在
start中调用。- Arguments:
数据 (
dict): 从回放缓冲区检索的训练数据。
注意
_policy必须在调用此方法之前设置。_policy.forward方法包含:前向传播、反向传播、梯度同步(如果在多GPU模式下)和参数更新。before_iter和after_iter钩子在开始和结束时被调用。
create_learner¶
- Overview:
给定键(learner_name),如果在learner_mapping的值中,则创建一个新的学习者实例,否则抛出KeyError。换句话说,派生学习者必须先注册,然后才能调用
create_learner来获取实例。- Arguments:
cfg (
EasyDict): 学习器配置。必要的键:[learner.import_module, learner.learner_type]。
- Returns:
学习者 (
BaseLearner): 创建的新学习者,应该是learner_mapping值之一的实例。