Shortcuts

torchtune.training

检查点

torchtune 提供了检查点工具,以便在训练过程中无缝切换检查点格式,并实现与生态系统其他部分的互操作性。有关检查点的全面概述,请参阅 检查点深入解析

FullModelHFCheckpointer

读取和写入HF格式检查点的检查点器。

FullModelMetaCheckpointer

读取和写入Meta格式检查点的检查点器。

FullModelTorchTuneCheckpointer

检查点读取器,以与torchtune兼容的格式读取和写入检查点。

ModelType

ModelType 被检查点用于区分不同的模型架构。

FormattedCheckpointFiles

这个类提供了一种更简洁的方式来表示格式为file_{i}_of_{n_files}.pth的文件名列表。

update_state_dict_for_classifier

验证分类器模型加载检查点时的状态字典。

降低精度

用于在降低精度设置下工作的实用程序。

get_dtype

获取与给定精度字符串对应的 torch.dtype。

set_default_dtype

上下文管理器用于设置torch的默认数据类型。

validate_expected_param_dtype

验证所有输入参数是否具有预期的数据类型。

get_quantizer_mode

给定一个量化器对象,返回一个指定量化类型的字符串。

分布式

用于启用和进行分布式训练的实用工具。

init_distributed

初始化torch.distributed所需的进程组。

is_distributed

检查是否设置了初始化torch.distributed所需的所有环境变量,并且分布式已正确安装。

get_world_size_and_rank

获取当前世界大小(即默认进程组中的总排名数)和当前进程的排名号的函数。

gather_cpu_state_dict

将分片状态字典转换为CPU上的完整状态字典 仅在rank0上返回非空结果以避免CPU内存峰值

内存管理

用于减少训练期间内存消耗的工具。

apply_selective_activation_checkpointing

用于设置激活检查点并包装模型以进行检查点的实用程序。

set_activation_checkpointing

用于对传入模型应用激活检查点的实用工具。

OptimizerInBackwardWrapper

一个用于在反向运行时优化器的检查点保存和加载的简单类。

create_optim_in_bwd_wrapper

为在反向传播中运行的优化器步骤创建一个包装器。

register_optim_in_bwd_hooks

为在反向传播中运行的优化器步骤注册钩子。

调度器

用于控制训练过程中学习率的实用工具。

get_cosine_schedule_with_warmup

创建一个学习率调度,使学习率从0.0线性增加到num_warmup_steps,然后在剩余的num_training_steps-num_warmup_steps上按余弦调度减少到0.0(假设num_cycles = 0.5)。

get_lr

Full_finetune_distributed 和 full_finetune_single_device 假设所有优化器具有相同的学习率(LR),这里用于验证所有学习率是否相同,如果相同则返回 True。

指标记录

各种日志记录工具。

metric_logging.CometLogger

用于与Comet一起使用的日志记录器 (https://www.comet.com/site/).

metric_logging.WandBLogger

用于Weights and Biases应用程序的Logger(https://wandb.ai/)。

metric_logging.TensorBoardLogger

用于与PyTorch的TensorBoard实现一起使用的日志记录器(https://pytorch.org/docs/stable/tensorboard.html)。

metric_logging.StdoutLogger

记录器输出到标准输出。

metric_logging.DiskLogger

记录器到磁盘。

性能与性能分析

torchtune 提供了用于分析和调试微调任务内存和性能的工具。

get_memory_stats

计算传入设备的内存摘要。

log_memory_stats

将包含内存统计信息的字典记录到日志记录器中。

setup_torch_profiler

设置profile并返回带有设置后更新的分析器配置。

杂项

get_unmasked_sequence_lengths

返回每个批次元素的序列长度,不包括被屏蔽的标记。

set_seed

设置常用库中伪随机数生成器种子的函数。