TypedLinear
- class dgl.nn.pytorch.TypedLinear(in_size, out_size, num_types, regularizer=None, num_bases=None)[source]
Bases:
Module
根据类型进行线性变换。
对于输入批次中的每个样本 \(x \in X\),应用线性变换 \(xW_t\),其中 \(t\) 是 \(x\) 的类型。
该模块支持两种正则化方法(基础分解和块对角分解),这些方法由“使用图卷积网络建模关系数据”提出。
基础正则化通过以下方式分解 \(W_t\):
\[W_t^{(l)} = \sum_{b=1}^B a_{tb}^{(l)}V_b^{(l)}\]其中 \(B\) 是基数的数量,\(V_b^{(l)}\) 与系数 \(a_{tb}^{(l)}\) 线性组合。
块对角分解正则化将 \(W_t\) 分解为 \(B\) 个块对角矩阵。我们将 \(B\) 称为基的数量:
\[W_t^{(l)} = \oplus_{b=1}^B Q_{tb}^{(l)}\]其中 \(B\) 是基数的数量,\(Q_{tb}^{(l)}\) 是具有形状 \(R^{(d^{(l+1)}/B)\times(d^{l}/B)}\) 的块基数。
- Parameters:
示例
无正则化。
>>> from dgl.nn import TypedLinear >>> import torch >>> >>> x = torch.randn(100, 32) >>> x_type = torch.randint(0, 5, (100,)) >>> m = TypedLinear(32, 64, 5) >>> y = m(x, x_type) >>> print(y.shape) torch.Size([100, 64])
基于正则化
>>> x = torch.randn(100, 32) >>> x_type = torch.randint(0, 5, (100,)) >>> m = TypedLinear(32, 64, 5, regularizer='basis', num_bases=4) >>> y = m(x, x_type) >>> print(y.shape) torch.Size([100, 64])