# Tensor 类通过 python_tensor.cpp 添加到此模块中
from typing import Optional, Tuple, List, Union, Any
import torch
from torch._C import _add_docstr, _sparse # type: ignore[attr-defined]
from torch import Tensor
# 半结构化稀疏支持
from .semi_structured import (
SparseSemiStructuredTensor,
SparseSemiStructuredTensorCUSPARSELT,
SparseSemiStructuredTensorCUTLASS,
to_sparse_semi_structured
)
# 支持 TorchScript 和 MyPy 的变通方法:
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from torch.types import _dtype as DType
DimOrDims = Optional[Union[int, Tuple[int], List[int]]]
else:
# JIT 不理解 Union 和 torch.dtype
DType = int
DimOrDims = Optional[Tuple[int]]
__all__ = [
'addmm',
'check_sparse_tensor_invariants',
'mm',
'sum',
'softmax',
'log_softmax',
'SparseSemiStructuredTensor',
'SparseSemiStructuredTensorCUTLASS',
'SparseSemiStructuredTensorCUSPARSELT',
'to_sparse_semi_structured',
'as_sparse_gradcheck',
]
addmm = _add_docstr(_sparse._sparse_addmm, r"""
sparse.addmm(mat, mat1, mat2, *, beta=1., alpha=1.) -> Tensor
此函数在正向过程中与 :func:`torch.addmm` 完全相同,
但它支持稀疏 COO 矩阵 :attr:`mat1` 的反向传播。
当 :attr:`mat1` 是 COO 张量时,它必须具有 `sparse_dim = 2`。
当输入是 COO 张量时,此函数还支持两个输入的反向传播。
支持 CSR 和 COO 存储格式。
.. 注意::
此函数不支持对 CSR 矩阵计算导数。
参数:
mat (Tensor): 要添加的密集矩阵
mat1 (Tensor): 要相乘的稀疏矩阵
mat2 (Tensor): 要相乘的密集矩阵
beta (Number, 可选): :attr:`mat` 的乘数 (:math:`\beta`)
alpha (Number, 可选): :math:`mat1 @ mat2` 的乘数 (:math:`\alpha`)
""")
mm = _add_docstr(_sparse._sparse_mm, r"""
对稀疏矩阵 :attr:`mat1` 和 (稀疏或密集) 矩阵 :attr:`mat2` 进行矩阵乘法。
类似于 :func:`torch.mm`,如果 :attr:`mat1` 是
:math:`(n \times m)` 张量,:attr:`mat2` 是 :math:`(m \times p)` 张量,输出将是
:math:`(n \times p)` 张量。
当 :attr:`mat1` 是 COO 张量时,它必须具有 `sparse_dim = 2`。
当输入是 COO 张量时,此函数还支持两个输入的反向传播。
支持 CSR 和 COO 存储格式。
.. 注意::
此函数不支持对 CSR 矩阵计算导数。
此函数还额外接受一个可选的 :attr:`reduce` 参数,允许
指定可选的归约操作,数学上执行以下操作:
.. math::
z_{ij} = \bigoplus_{k = 0}^{K - 1} x_{ik} y_{kj}
其中 :math:`\bigoplus` 定义归约运算符。:attr:`reduce` 仅在
CPU 设备上的 CSR 存储格式中实现。
参数:
mat1 (Tensor): 要相乘的第一个稀疏矩阵
mat2 (Tensor): 要相乘的第二个矩阵,可以是稀疏或密集
reduce (str, 可选): 应用于非唯一索引的归约操作
(:obj:`"sum"`, :obj:`"mean"`, :obj:`"amax"`, :obj:`"amin"`)。默认 :obj:`"sum"`。
形状:
此函数的输出张量格式如下:
- 稀疏 x 稀疏 -> 稀疏
- 稀疏 x 密集 -> 密集
示例::
>>> a = torch.tensor([[1., 0, 2], [0, 3, 0]]).to_sparse().requires_grad_()
>>> a
tensor(indices=tensor([[0, 0, 1],
[0, 2, 1]]),
values=tensor([1., 2., 3.]),
size=(2, 3), nnz=3, layout=torch.sparse_coo, requires_grad=True)
>>> b = torch.tensor([[0, 1.], [2, 0], [0, 0]], requires_grad=True)
>>> b
tensor([[0., 1.],
[2., 0.],
[0., 0.]], requires_grad=True)
>>> y = torch.sparse.mm(a, b)
>>> y
tensor([[0., 1.],
[6., 0.]], grad_fn=)
>>> y.sum().backward()
>>> a.grad
tensor(indices=tensor([[0, 0, 1],
[0, 2, 1]]),
values=tensor([1., 0., 2.]),
size=(2, 3), nnz=3, layout=torch.sparse_coo)
>>> c = a.detach().to_sparse_csr()
>>> c
tensor(crow_indices=tensor([0, 2, 3]),
col_indices=tensor([0, 2, 1]),
values=tensor([1., 2., 3.]), size=(2, 3), nnz=3,
layout=torch.sparse_csr)
>>> y1 = torch.sparse.mm(c, b, 'sum')
>>> y1
tensor([[0., 1.],
[6., 0.]], grad_fn=)
>>> y2 = torch.sparse.mm(c, b, 'max')
>>> y2
tensor([[0., 1.],
[6., 0.]], grad_fn=)
""")
sampled_addmm = _add_docstr(_sparse.sparse_sampled_addmm, r"""
sparse.sampled_addmm(input, mat1, mat2, *, beta=1., alpha=1., out=None) -> Tensor
根据 :attr:`input` 的稀疏模式,对密集矩阵 :attr:`mat1` 和 :attr:`mat2` 进行矩阵乘法。
矩阵 :attr:`input` 被添加到最终结果中。
数学上执行以下操作:
.. math::
\text{out} = \alpha\ (\text{mat1} \mathbin{@} \text{mat2})*\text{spy}(\text{input}) + \beta\ \text{input}
其中 :math:`\text{spy}(\text{input})` 是 :attr:`input` 的稀疏模式矩阵,:attr:`alpha`
和 :attr:`beta` 是缩放因子。
:math:`\text{spy}(\text{input})` 在 :attr:`input` 具有非零值的位置为