符号

用于描述在常见torch模块中找到的符号的工具。

Symbol

一个符号参数(Symbol)属于SymModule

SymMap

一个用于保存模型符号表示的类。

SymInfo

一个简单的类,用于保存有关给定模块符号性质的相关信息。

class SymInfo

基础类:object

一个简单的类,用于保存有关给定模块符号性质的相关信息。

SymDict

Dict[str, Symbol] 的别名

__init__(is_shape_preserving=False, **kwargs)

使用给定的符号信息初始化实例。

Parameters:
  • is_shape_preserving (bool) –

  • kwargs (Symbol) –

property is_shape_preserving: bool

返回指示模块是否保持形状。

class SymMap

基础类:object

一个用于保存模型符号表示的类。

SymRegisterFunc

Callable 的别名 [[Module], SymInfo]

__init__(model)

使用所需的模块进行初始化。

Return type:

add_sym_info(key, sym_info)

手动添加模型的模块的sym_info。

Parameters:
  • key (模块) –

  • sym_info (SymInfo) –

Return type:

get_symbol(mod, name)

从给定的模块中获取具有给定名称的符号。

Parameters:
  • mod (模块) –

  • name (str) –

Return type:

符号

is_shape_preserving(key)

返回符号模块是否保持形状。

Parameters:

key (模块) –

Return type:

bool

items()

返回字典上的迭代器。

Return type:

生成器[元组[模块, 字典[字符串, 符号]], , ]

named_modules()

生成名称(来自 self._mod_to_name)和相关的模块。

Return type:

生成器[元组[字符串, 模块], , ]

named_sym_dicts()

生成名称(来自 self._mod_to_name)和相关的符号模块。

Return type:

生成器[元组[字符串, 字典[字符串, 符号]], , ]

named_symbols(key=None, free=None, dynamic=None, searchable=None, constant=None)

生成所有符号模块或特定模块中符号的名称和符号。

Parameters:
  • key (Module | None) – 从中获取符号的模块。如果未提供,则递归遍历所有模块。

  • free (bool | None) – 是否包含自由符号。

  • dynamic (bool | None) – 是否包含动态符号。

  • searchable (bool | None) – 是否包含可搜索的符号。

  • constant (bool | None) – 是否包含常量符号。

Yields:

(name, Symbol) – 包含名称和符号的元组。

Return type:

生成器[元组[字符串, 符号], , ]

默认行为是迭代自由、动态、可搜索或常量符号。相应地设置参数以仅迭代某些符号。当freedynamicsearchableconstant设置为True时,仅迭代该类型的符号。如果freedynamicsearchableconstant设置为False,则跳过该类型的符号。

pop(key)

从字典中移除给定的模块并返回其符号表示。

Parameters:

key (模块) –

Return type:

字典[字符串, 符号]

prune()

通过移除仅包含常量符号的模块来修剪地图。

Return type:

classmethod register(nn_cls, is_explicit_leaf=True)

使用此功能注册一个函数,该函数定义了给定nn模块的符号。

Parameters:
  • nn_cls (Type[Module] | List[Type[Module]]) – 注册函数的nn模块类。

  • is_explicit_leaf (bool) – 模块是否为显式叶子节点,即在跟踪期间是否应将其视为叶子节点。

Returns:

一个装饰器,用于为给定的nn模块类注册给定的函数。

Return type:

可调用[[可调用[[模块], SymInfo]], 可调用[[模块], SymInfo]]

下面展示了一个注册模块符号信息的示例:

@SymMap.register(nn.Linear)
def get_linear_sym_info(mod: nn.Linear) -> SymInfo:
    in_features = Symbol(cl_type=Symbol.CLType.INCOMING, elastic_dims={-1})
    out_features = Symbol(
        is_searchable=True, cl_type=Symbol.CLType.OUTGOING, elastic_dims={-1}
    )
    return SymInfo(in_features=in_features, out_features=out_features)
set_symbol(mod, name, symbol)

从给定的模块中设置具有给定名称的符号。

Parameters:
  • mod (模块) –

  • name (str) –

  • symbol (Symbol) –

Return type:

classmethod unregister(nn_cls)

取消注册先前已注册的模块。

如果模块未注册,则抛出 KeyError。

Parameters:

nn_cls (类型[模块]) –

Return type:

class Symbol

基础类:object

一个符号参数(Symbol)属于SymModule

Symbol的一个例子可能是卷积的kernel_size。

请注意,一个符号可以有以下状态(互斥的): - 自由:符号未绑定到任何值 - 可搜索:符号是自由的并且可以被搜索 - 常量:符号的值不能被更改 - 动态:符号的值由其父符号决定

此外,一个符号可以展示与其跨层重要性相关的属性: - incoming: 该符号依赖于模块的输入张量 - outgoing: 模块的输出张量依赖于该符号 - none: 该符号不具有跨层重要性(仅影响模块的内部)

基于这些基本属性,我们定义了一些有用的复合属性: - is_cross_layer: 符号是传入还是传出 - is_dangling: 符号是自由的且跨层

class CLType

基础:Enum

符号的跨层类型。

INCOMING = 2
NONE = 1
OUTGOING = 3
__init__(is_searchable=False, is_sortable=True, cl_type=CLType.NONE, elastic_dims=None)

使用与追踪相关的信息初始化Symbol。

Parameters:
  • is_searchable (bool) –

  • is_sortable (bool) –

  • cl_type (CLType) –

  • elastic_dims (Set[int] | None) –

property cl_type: CLType

返回符号的跨层类型。

disable(_memo=None)

通过DFS禁用符号并将其与其整个依赖树一起标记为常量。

在此调用之后,is_constant == True

Parameters:

_memo (Set[Symbol] | None) –

Return type:

property elastic_dims: Set[int]

返回引用此符号的张量维度集合。

请注意,这里指的是从层传入或传出的张量的维度, 不是模块的参数。例如,对于Conv2d层,这指的是传入/传出张量的“NCHW”中的“C”。

这必须为传入/传出的符号定义,对于所有其他情况必须为空。

还要注意,这是一组维度,尽管每个Symbol实际上只有一个维度可以是弹性的。使用一组而不是单个张量的额外灵活性是为了能够以不同的索引符号描述相同的维度(例如,对于Conv2d,{1,-3}中的1和-3都指的是“NCHW”中的“C”维度)。

property is_constant: bool

返回指示符,判断符号是否为常量。

property is_cross_layer: bool

返回指示符,判断符号是否为跨层。

property is_dangling: bool

返回指示符,判断符号是否为悬挂(跨层且自由)。

property is_dynamic: bool

返回指示符,判断符号是否为动态。

property is_free: bool

返回指示符,判断符号是否空闲。

property is_incoming: bool

返回指示符,判断符号是否为跨层传入。

property is_outgoing: bool

返回指示符,判断符号是否为跨层传出。

property is_searchable: bool

返回指示符号是否可搜索。

property is_sortable: bool

返回指示符,表示依赖树中的符号是否可排序。

注册一个父符号,即让这个符号依赖于父符号。

Parameters:

sp_parent (Symbol) –

Return type:

property parent: Symbol | None

返回父符号。