双线性¶
- class torch.nn.Bilinear(in1_features, in2_features, out_features, bias=True, device=None, dtype=None)[源代码]¶
对传入的数据应用双线性变换:.
- Parameters
- Shape:
输入1: 其中 和 表示包括无在内的任意数量的附加维度。输入的所有维度(除了最后一个维度)应该相同。
输入2: 其中 。
输出: 其中 并且除了最后一个维度之外,其他维度与输入的形状相同。
- Variables
权重 (torch.Tensor) – 模块的可学习权重,形状为 . 这些值从 初始化,其中
偏置 – 模块的可学习偏置,形状为 。 如果
bias
为True
,则这些值从 中初始化,其中
示例:
>>> m = nn.Bilinear(20, 30, 40) >>> input1 = torch.randn(128, 20) >>> input2 = torch.randn(128, 30) >>> output = m(input1, input2) >>> print(output.size()) torch.Size([128, 40])