线性¶
- class torch.nn.Linear(in_features, out_features, bias=True, device=None, dtype=None)[源代码]¶
对传入的数据应用线性变换:。
此模块支持 TensorFloat32。
在某些 ROCm 设备上,当使用 float16 输入时,此模块将在反向传播中使用 不同的精度。
- Parameters
- Shape:
输入: 其中 表示包括无在内的任意数量的维度,并且 。
输出: 其中除了最后一个维度外,其他维度与输入的形状相同,并且 。
- Variables
权重 (torch.Tensor) – 模块的可学习权重,形状为 。这些值从 初始化,其中
偏差 – 模块的可学习偏差,形状为 。 如果
bias
为True
,则这些值从 中初始化,其中
示例:
>>> m = nn.Linear(20, 30) >>> input = torch.randn(128, 20) >>> output = m(input) >>> print(output.size()) torch.Size([128, 30])