Shortcuts

torch.sparse 的源代码

# Tensor 类通过 python_tensor.cpp 添加到此模块中
from typing import Optional, Tuple, List, Union, Any

import torch
from torch._C import _add_docstr, _sparse  # type: ignore[attr-defined]
from torch import Tensor

# 半结构化稀疏支持
from .semi_structured import (
    SparseSemiStructuredTensor,
    SparseSemiStructuredTensorCUSPARSELT,
    SparseSemiStructuredTensorCUTLASS,
    to_sparse_semi_structured
)

# 支持 TorchScript 和 MyPy 的变通方法:
from typing import TYPE_CHECKING
if TYPE_CHECKING:
    from torch.types import _dtype as DType
    DimOrDims = Optional[Union[int, Tuple[int], List[int]]]
else:
    # JIT 不理解 Union 和 torch.dtype
    DType = int
    DimOrDims = Optional[Tuple[int]]


__all__ = [
    'addmm',
    'check_sparse_tensor_invariants',
    'mm',
    'sum',
    'softmax',
    'log_softmax',
    'SparseSemiStructuredTensor',
    'SparseSemiStructuredTensorCUTLASS',
    'SparseSemiStructuredTensorCUSPARSELT',
    'to_sparse_semi_structured',
    'as_sparse_gradcheck',
]

addmm = _add_docstr(_sparse._sparse_addmm, r"""
sparse.addmm(mat, mat1, mat2, *, beta=1., alpha=1.) -> Tensor

此函数在正向过程中与 :func:`torch.addmm` 完全相同,
但它支持稀疏 COO 矩阵 :attr:`mat1` 的反向传播。
当 :attr:`mat1` 是 COO 张量时,它必须具有 `sparse_dim = 2`。
当输入是 COO 张量时,此函数还支持两个输入的反向传播。

支持 CSR 和 COO 存储格式。

.. 注意::
    此函数不支持对 CSR 矩阵计算导数。

参数:
    mat (Tensor): 要添加的密集矩阵
    mat1 (Tensor): 要相乘的稀疏矩阵
    mat2 (Tensor): 要相乘的密集矩阵
    beta (Number, 可选): :attr:`mat` 的乘数 (:math:`\beta`)
    alpha (Number, 可选): :math:`mat1 @ mat2` 的乘数 (:math:`\alpha`)
""")


mm = _add_docstr(_sparse._sparse_mm, r"""
    对稀疏矩阵 :attr:`mat1` 和 (稀疏或密集) 矩阵 :attr:`mat2` 进行矩阵乘法。
    类似于 :func:`torch.mm`,如果 :attr:`mat1` 是
    :math:`(n \times m)` 张量,:attr:`mat2` 是 :math:`(m \times p)` 张量,输出将是
    :math:`(n \times p)` 张量。
    当 :attr:`mat1` 是 COO 张量时,它必须具有 `sparse_dim = 2`。
    当输入是 COO 张量时,此函数还支持两个输入的反向传播。

    支持 CSR 和 COO 存储格式。

.. 注意::
    此函数不支持对 CSR 矩阵计算导数。

    此函数还额外接受一个可选的 :attr:`reduce` 参数,允许
    指定可选的归约操作,数学上执行以下操作:

.. math::

    z_{ij} = \bigoplus_{k = 0}^{K - 1} x_{ik} y_{kj}

其中 :math:`\bigoplus` 定义归约运算符。:attr:`reduce` 仅在
CPU 设备上的 CSR 存储格式中实现。

参数:
    mat1 (Tensor): 要相乘的第一个稀疏矩阵
    mat2 (Tensor): 要相乘的第二个矩阵,可以是稀疏或密集
    reduce (str, 可选): 应用于非唯一索引的归约操作
        (:obj:`"sum"`, :obj:`"mean"`, :obj:`"amax"`, :obj:`"amin"`)。默认 :obj:`"sum"`。

形状:
    此函数的输出张量格式如下:
    - 稀疏 x 稀疏 -> 稀疏
    - 稀疏 x 密集 -> 密集

示例::

    >>> a = torch.tensor([[1., 0, 2], [0, 3, 0]]).to_sparse().requires_grad_()
    >>> a
    tensor(indices=tensor([[0, 0, 1],
                           [0, 2, 1]]),
           values=tensor([1., 2., 3.]),
           size=(2, 3), nnz=3, layout=torch.sparse_coo, requires_grad=True)
    >>> b = torch.tensor([[0, 1.], [2, 0], [0, 0]], requires_grad=True)
    >>> b
    tensor([[0., 1.],
            [2., 0.],
            [0., 0.]], requires_grad=True)
    >>> y = torch.sparse.mm(a, b)
    >>> y
    tensor([[0., 1.],
            [6., 0.]], grad_fn=)
    >>> y.sum().backward()
    >>> a.grad
    tensor(indices=tensor([[0, 0, 1],
                           [0, 2, 1]]),
           values=tensor([1., 0., 2.]),
           size=(2, 3), nnz=3, layout=torch.sparse_coo)
    >>> c = a.detach().to_sparse_csr()
    >>> c
    tensor(crow_indices=tensor([0, 2, 3]),
           col_indices=tensor([0, 2, 1]),
           values=tensor([1., 2., 3.]), size=(2, 3), nnz=3,
           layout=torch.sparse_csr)
    >>> y1 = torch.sparse.mm(c, b, 'sum')
    >>> y1
    tensor([[0., 1.],
            [6., 0.]], grad_fn=)
    >>> y2 = torch.sparse.mm(c, b, 'max')
    >>> y2
    tensor([[0., 1.],
            [6., 0.]], grad_fn=)
""")


sampled_addmm = _add_docstr(_sparse.sparse_sampled_addmm, r"""
sparse.sampled_addmm(input, mat1, mat2, *, beta=1., alpha=1., out=None) -> Tensor

根据 :attr:`input` 的稀疏模式,对密集矩阵 :attr:`mat1` 和 :attr:`mat2` 进行矩阵乘法。
矩阵 :attr:`input` 被添加到最终结果中。

数学上执行以下操作:

.. math::

    \text{out} = \alpha\ (\text{mat1} \mathbin{@} \text{mat2})*\text{spy}(\text{input}) + \beta\ \text{input}

其中 :math:`\text{spy}(\text{input})` 是 :attr:`input` 的稀疏模式矩阵,:attr:`alpha`
和 :attr:`beta` 是缩放因子。
:math:`\text{spy}(\text{input})` 在 :attr:`input` 具有非零值的位置为