register_fusion_module¶
- torchtune.modules.model_fusion.register_fusion_module(module: Module)[source]¶
将方法fusion_params添加到nn.Module中,将所有模块参数标记为融合参数。这可以用于一个层或整个模型,以结合两个或多个预训练模型。
例如,您可能希望在编码器上添加一个投影头,以学习从预训练编码到解码器嵌入空间的投影。这在深度融合和早期融合模型中都很典型。
示例
>>> projection_head = FeedForward(...) >>> register_fusion_module(projection_head)) >>> encoder = nn.Sequential(clip_vit_224(), projection_head)
- Parameters:
模块 (nn.Module) – 要添加fusion_params方法的模块