torch.bmm¶
- torch.bmm(input, mat2, *, out=None) 张量¶
对存储在
input和mat2中的矩阵执行批量矩阵-矩阵乘积。input和mat2必须是每个包含相同数量矩阵的3维张量。如果
input是一个 张量,mat2是一个 张量,out将是一个 张量。此操作符支持 TensorFloat32。
在某些 ROCm 设备上,当使用 float16 输入时,此模块将在反向传播中使用 不同的精度。
注意
此函数不执行广播。 要进行广播矩阵乘法,请参阅
torch.matmul()。示例:
>>> input = torch.randn(10, 3, 4) >>> mat2 = torch.randn(10, 4, 5) >>> res = torch.bmm(input, mat2) >>> res.size() torch.Size([10, 3, 5])