update_state_dict_for_classifier¶
- torchtune.training.update_state_dict_for_classifier(state_dict: Dict[str, Tensor], model_named_parameters: Iterable[Tuple[str, Parameter]], force_override: bool = False)[source]¶
验证分类器模型加载检查点的状态字典。 在调用
model.load_state_dict(state_dict)之前使用。 如果output.weight的形状不匹配,此函数将覆盖状态字典中的output.weight,以加载模型中的output.weight。您可能还希望覆盖此行为,例如,如果您的检查点和模型的num_classes相同。具体来说,当从基础语言模型的检查点微调分类器模型时,该模型具有形状为
[vocab_dim, embed_dim]的output.weight,我们将状态字典中的output.weight覆盖为模型中随机初始化的[num_classes, embed_dim]权重。这是就地完成的。- Parameters:
state_dict (Dict[str, torch.Tensor]) – 要加载到分类器模型中的状态字典。
model_named_parameters (Iterable[Tuple[str, torch.nn.Parameter]]) – 来自
model.named_parameters()的模型命名参数。force_override (bool) – 是否用模型的
output.weight替换state_dict中的output.weight,即使形状匹配。
注释
output.bias如果存在于state_dict中,将被忽略- This function will always replace the
output.weightinstate_dict, 如果
output.weight != model.output.weight。
- This function will always replace the
- Raises:
AssertionError – 如果
state_dict不包含output.weight。AssertionError – 如果
model_named_parameters不包含output.weight。