余弦相似度¶
- class torch.nn.CosineSimilarity(dim=1, eps=1e-08)[源代码]¶
返回和之间的余弦相似度,沿dim计算。
- Shape:
输入1: 其中 D 位于位置 dim
- Input2: , same number of dimensions as x1, matching x1 size at dimension dim,
并且在其他维度上与 x1 可广播。
输出:
- Examples::
>>> input1 = torch.randn(100, 128) >>> input2 = torch.randn(100, 128) >>> cos = nn.CosineSimilarity(dim=1, eps=1e-6) >>> output = cos(input1, input2)