concat
支持追踪连接操作的模块。
请注意,此实现假设concat操作/符号是依赖符号中唯一可搜索的符号。这使我们能够简化部分依赖关系,否则这些依赖关系将因多个concat链接在一起而产生。
这里有一个小的例外:每当一个concat依赖于另一个concat时,我们可以禁用独立的concat,并通过这种方式简化表示。
然而,真正的链接连接无法处理,例如,
torch.cat([x1,x2], dim=1) + torch.cat([y1, y2], dim=1),
因为没有办法禁用其中一个而不禁用另一个。
类
存储用于连接的符号输入的有序列表的符号。 |
|
用于处理concat特定跟踪逻辑的节点。 |
- class ConcatNodeProcessor
基础:
NodeProcessor用于处理concat特定跟踪逻辑的节点。
- __init__(*args, **kwargs)
初始化。
- Return type:
无
- is_special_node(node, target)
返回节点是否为连接节点。
- Parameters:
节点 (Node) –
target (Module | Callable) –
- Return type:
bool
- post_process()
恢复到原始符号。
- Return type:
无
- process(node, id, input_nodes)
使所有用于连接的输入都可搜索。
- Parameters:
节点 (Node) –
id (int) –
input_nodes (List[Node]) –
- Return type:
无
- reset()
重置状态。
- Return type:
无
- class ConcatSymbol
基础:
Symbol存储用于连接的符号输入的有序列表的符号。
- class Input
基础:
Symbol特殊符号,用于表示ConcatSymbol的输入。
这个符号是常规Symbol的增强版本,用于处理与concat操作的交互,并用于猴子补丁原始符号。
- __init__(*args, **kwargs)
构造函数。
- property concat_sym: ConcatSymbol
返回连接符号。
- __init__(symbols, cl_type=CLType.NONE, elastic_dims=None)
从输入符号初始化Symbol。
- disable(_memo=None)
禁用所有符号,包括输入符号。
我们通过虚假地将输入符号添加到依赖列表中来处理它们。请注意,依赖列表最终会被清除——所以这样做是可以的。
- Parameters:
_memo (Set[Symbol] | None) –
- Return type:
无
- property input_syms: List[输入 | ConcatSymbol]
返回符号。
- property is_constant: bool
返回指示符,判断符号是否为常量。
与常规符号不同,常规符号的此属性是根据手动设置的标志确定的,而concat的is_constant属性是根据所有输入符号是否都是常量来设置的。
- property is_searchable: bool
返回指示符号是否可搜索。
与常规符号不同,常规符号的此属性是根据手动设置的标志确定的,而concat的is_searchable属性是根据是否有任何输入符号是可搜索的来设置的。