符号
用于描述在常见torch模块中找到的符号的工具。
类
一个符号参数( |
|
一个用于保存模型符号表示的类。 |
|
一个简单的类,用于保存有关给定模块符号性质的相关信息。 |
- class SymInfo
基础类:
object一个简单的类,用于保存有关给定模块符号性质的相关信息。
- __init__(is_shape_preserving=False, **kwargs)
使用给定的符号信息初始化实例。
- Parameters:
is_shape_preserving (bool) –
kwargs (Symbol) –
- property is_shape_preserving: bool
返回指示模块是否保持形状。
- class SymMap
基础类:
object一个用于保存模型符号表示的类。
- __init__(model)
使用所需的模块进行初始化。
- Return type:
无
- add_sym_info(key, sym_info)
手动添加模型的模块的sym_info。
- Parameters:
key (模块) –
sym_info (SymInfo) –
- Return type:
无
- is_shape_preserving(key)
返回符号模块是否保持形状。
- Parameters:
key (模块) –
- Return type:
bool
- named_modules()
生成名称(来自 self._mod_to_name)和相关的模块。
- Return type:
生成器[元组[字符串, 模块], 无, 无]
- named_symbols(key=None, free=None, dynamic=None, searchable=None, constant=None)
生成所有符号模块或特定模块中符号的名称和符号。
- Parameters:
key (Module | None) – 从中获取符号的模块。如果未提供,则递归遍历所有模块。
free (bool | None) – 是否包含自由符号。
dynamic (bool | None) – 是否包含动态符号。
searchable (bool | None) – 是否包含可搜索的符号。
constant (bool | None) – 是否包含常量符号。
- Yields:
(name, Symbol) – 包含名称和符号的元组。
- Return type:
生成器[元组[字符串, 符号], 无, 无]
默认行为是迭代自由、动态、可搜索或常量符号。相应地设置参数以仅迭代某些符号。当
free、dynamic、searchable或constant设置为True时,仅迭代该类型的符号。如果free、dynamic、searchable或constant设置为False,则跳过该类型的符号。
- prune()
通过移除仅包含常量符号的模块来修剪地图。
- Return type:
无
- classmethod register(nn_cls, is_explicit_leaf=True)
使用此功能注册一个函数,该函数定义了给定nn模块的符号。
- Parameters:
nn_cls (Type[Module] | List[Type[Module]]) – 注册函数的nn模块类。
is_explicit_leaf (bool) – 模块是否为显式叶子节点,即在跟踪期间是否应将其视为叶子节点。
- Returns:
一个装饰器,用于为给定的nn模块类注册给定的函数。
- Return type:
下面展示了一个注册模块符号信息的示例:
@SymMap.register(nn.Linear) def get_linear_sym_info(mod: nn.Linear) -> SymInfo: in_features = Symbol(cl_type=Symbol.CLType.INCOMING, elastic_dims={-1}) out_features = Symbol( is_searchable=True, cl_type=Symbol.CLType.OUTGOING, elastic_dims={-1} ) return SymInfo(in_features=in_features, out_features=out_features)
- set_symbol(mod, name, symbol)
从给定的模块中设置具有给定名称的符号。
- Parameters:
mod (模块) –
name (str) –
symbol (Symbol) –
- Return type:
无
- classmethod unregister(nn_cls)
取消注册先前已注册的模块。
如果模块未注册,则抛出 KeyError。
- Parameters:
nn_cls (类型[模块]) –
- Return type:
无
- class Symbol
基础类:
object一个符号参数(
Symbol)属于SymModule。Symbol的一个例子可能是卷积的kernel_size。
请注意,一个符号可以有以下状态(互斥的): - 自由:符号未绑定到任何值 - 可搜索:符号是自由的并且可以被搜索 - 常量:符号的值不能被更改 - 动态:符号的值由其父符号决定
此外,一个符号可以展示与其跨层重要性相关的属性: - incoming: 该符号依赖于模块的输入张量 - outgoing: 模块的输出张量依赖于该符号 - none: 该符号不具有跨层重要性(仅影响模块的内部)
基于这些基本属性,我们定义了一些有用的复合属性: - is_cross_layer: 符号是传入还是传出 - is_dangling: 符号是自由的且跨层
- __init__(is_searchable=False, is_sortable=True, cl_type=CLType.NONE, elastic_dims=None)
使用与追踪相关的信息初始化Symbol。
- Parameters:
is_searchable (bool) –
is_sortable (bool) –
cl_type (CLType) –
elastic_dims (Set[int] | None) –
- disable(_memo=None)
通过DFS禁用符号并将其与其整个依赖树一起标记为常量。
在此调用之后,
is_constant == True。- Parameters:
_memo (Set[Symbol] | None) –
- Return type:
无
- property elastic_dims: Set[int]
返回引用此符号的张量维度集合。
请注意,这里指的是从层传入或传出的张量的维度, 不是模块的参数。例如,对于Conv2d层,这指的是传入/传出张量的“NCHW”中的“C”。
这必须为传入/传出的符号定义,对于所有其他情况必须为空。
还要注意,这是一组维度,尽管每个Symbol实际上只有一个维度可以是弹性的。使用一组而不是单个张量的额外灵活性是为了能够以不同的索引符号描述相同的维度(例如,对于Conv2d,{1,-3}中的1和-3都指的是“NCHW”中的“C”维度)。
- property is_constant: bool
返回指示符,判断符号是否为常量。
- property is_cross_layer: bool
返回指示符,判断符号是否为跨层。
- property is_dangling: bool
返回指示符,判断符号是否为悬挂(跨层且自由)。
- property is_dynamic: bool
返回指示符,判断符号是否为动态。
- property is_free: bool
返回指示符,判断符号是否空闲。
- property is_incoming: bool
返回指示符,判断符号是否为跨层传入。
- property is_outgoing: bool
返回指示符,判断符号是否为跨层传出。
- property is_searchable: bool
返回指示符号是否可搜索。
- property is_sortable: bool
返回指示符,表示依赖树中的符号是否可排序。