Shortcuts

update_state_dict_for_classifier

torchtune.training.update_state_dict_for_classifier(state_dict: Dict[str, Tensor], model_named_parameters: Iterable[Tuple[str, Parameter]], force_override: bool = False)[source]

验证分类器模型加载检查点的状态字典。 在调用model.load_state_dict(state_dict)之前使用。 如果output.weight的形状不匹配,此函数将覆盖状态字典中的output.weight,以加载模型中的output.weight。您可能还希望覆盖此行为,例如,如果您的检查点和模型的num_classes相同。

具体来说,当从基础语言模型的检查点微调分类器模型时,该模型具有形状为[vocab_dim, embed_dim]output.weight,我们将状态字典中的output.weight覆盖为模型中随机初始化的[num_classes, embed_dim]权重。这是就地完成的。

Parameters:
  • state_dict (Dict[str, torch.Tensor]) – 要加载到分类器模型中的状态字典。

  • model_named_parameters (Iterable[Tuple[str, torch.nn.Parameter]]) – 来自 model.named_parameters() 的模型命名参数。

  • force_override (bool) – 是否用模型的output.weight替换state_dict中的output.weight,即使形状匹配。

注释

  • output.bias 如果存在于 state_dict 中,将被忽略

  • This function will always replace the output.weight in state_dict,

    如果 output.weight != model.output.weight

Raises:
  • AssertionError – 如果 state_dict 不包含 output.weight

  • AssertionError – 如果 model_named_parameters 不包含 output.weight