实用工具
神经网络组件的实用工具。
- exception ShapeError(shape: Sequence[int], reference: Sequence[int])[source]
形状不匹配的错误。
初始化错误。
- adjacency_tensor_to_stacked_matrix(num_relations: int, num_entities: int, source: Tensor, target: Tensor, edge_type: Tensor, edge_weights: Tensor | None = None, horizontal: bool = True) Tensor[source]
堆叠邻接矩阵,如[thanapalasingam2021]中所述。
此方法将形状为(num_entities, num_relations, num_entities)的(稀疏)邻接张量重新排列为形状为(num_entities, num_relations * num_entities)(水平堆叠)或(num_entities * num_relations, num_entities)(垂直堆叠)的稀疏邻接矩阵。因此,我们可以通过一次稀疏矩阵乘法(以及一些额外的预处理和/或后处理)来执行R-GCN的关系特定消息传递。
- Parameters:
- Returns:
形状:(num_entities * num_relations, num_entities) 或 (num_entities, num_entities * num_relations) 堆叠的邻接矩阵
- Return type:
- apply_optional_bn(x: Tensor, batch_norm: BatchNorm1d | None = None) Tensor[source]
应用可选的批量归一化。
支持多个批次维度。
- Parameters:
x (Tensor) – 形状:
(..., d)`输入张量。batch_norm (BatchNorm1d | None) – 一个可选的批量归一化层。
- Returns:
形状:
(..., d)`归一化后的张量。- Return type:
- safe_diagonal(matrix: Tensor) Tensor[source]
从潜在的稀疏矩阵中提取对角线。
注意
这是一个临时的解决方案,因为
torch.diagonal()不适用于稀疏张量
- use_horizontal_stacking(input_dim: int, output_dim: int) bool[源代码]
根据输入和输出维度确定堆叠方向。
垂直堆叠方法适用于低维输入和高维输出,因为首先进行的是向低维的投影。而水平堆叠方法则适用于高维输入和低维输出,因为最后进行的是向高维的投影。
四元数的实用工具。
- multiplication_table() Tensor[source]
创建四元数基乘法表。
- Returns:
形状: (4, 4, 4) 基元素乘积表。
- Return type:
..参见:: https://en.wikipedia.org/wiki/Quaternion#Multiplication_of_basis_elements