Shortcuts

torch.overrides

此模块提供了各种辅助函数,用于 __torch_function__ 协议。有关 __torch_function__ 协议的更多详细信息,请参阅 扩展 torch Python API

函数

torch.overrides.get_ignored_functions()[源代码]

返回不能被 __torch_function__ 覆盖的公共函数。

Returns

在 torch API 中公开可用但不能通过 __torch_function__ 重写的一组函数。这主要是因为这些函数的参数都不是张量或类似张量的对象。

Return type

设置[可调用]

示例

>>> torch.Tensor.as_subclass  torch.overrides.get_ignored_functions()
True
>>> torch.add  torch.overrides.get_ignored_functions()
False
torch.overrides.get_overridable_functions()[源代码]

列出可通过 __torch_function__ 重写的功能

Returns

一个映射包含可覆盖函数的命名空间到该命名空间中可被覆盖的函数的字典。

Return type

Dict[任意, List[可调用]]

torch.overrides.resolve_name(f)[源代码]

获取传递给 __torch_function__ 的函数的可读字符串名称

Parameters

f (可调用对象) – 要解析名称的函数。

Returns

函数的名称;如果进行求值,它应该返回输入的函数。

Return type

str

torch.overrides.get_testing_overrides()[源代码]

返回一个包含所有可覆盖函数的虚拟覆盖的字典

Returns

一个字典,将PyTorch API中可覆盖的函数映射到具有与实际函数相同签名的lambda函数,并且无条件返回-1。这些lambda函数对于测试定义了__torch_function__的类型的API覆盖率非常有用。

Return type

字典[可调用, 可调用]

示例

>>> import inspect
>>> my_add = torch.overrides.get_testing_overrides()[torch.add]
>>> inspect.signature(my_add)
<签名 (input, other, out=None)>
torch.overrides.handle_torch_function(public_api, relevant_args, *args, **kwargs)[源代码]

实现一个函数,用于检查 __torch_function__ 的重写。

请参阅 C++ 实现中与此函数等效的 torch::autograd::handle_torch_function。

Parameters
  • public_api (函数) – 由公共 torch API 暴露的函数,最初调用方式为 public_api(*args, **kwargs),现在正在检查其参数。

  • relevant_args (可迭代对象) – 用于检查 __torch_function__ 方法的参数的可迭代对象。

  • args (tuple) – 最初传递给 public_api 的任意位置参数。

  • kwargs (元组) – 最初传递给 public_api 的任意关键字参数。

Returns

调用 implementation__torch_function__ 方法的结果,视情况而定。

Return type

object

:引发 TypeError : 如果没有找到实现。:

示例

>>> def func(a):
...     if has_torch_function_unary(a):
...         return handle_torch_function(func, (a,), a)
...     return a + 0
torch.overrides.has_torch_function()

检查可迭代对象的元素中是否存在__torch_function__实现,或者是否启用了__torch_function__模式。将精确的TensorParameter视为不可分派的。使用此方法来保护对handle_torch_function()的调用;不要使用它来测试某物是否类似于Tensor,请改用is_tensor_like()。 :param relevant_args: 要检查__torch_function__方法的可迭代对象或参数。 :type relevant_args: 可迭代对象

Returns

如果 relevant_args 中的任何元素具有 __torch_function__ 实现,则为 True,否则为 False。

Return type

bool

另请参阅

torch.is_tensor_like

检查某物是否为类似张量的对象,包括精确的 Tensor

torch.overrides.is_tensor_like(inp)[源代码]

如果传入的输入是类似张量的,则返回True

目前,当输入类型的对象上存在__torch_function__属性时,就会发生这种情况。

示例

张量的子类通常是一个类似张量的对象。

>>> class SubTensor(torch.Tensor): ...
>>> is_tensor_like(SubTensor([0]))
True

内置或用户定义的类型通常不是类似张量的。

>>> is_tensor_like(6)

>>> is_tensor_like(None)

>>> class NotATensor: ...
>>> is_tensor_like(NotATensor())

但是,可以通过实现 __torch_function__ 使其具有类似张量的行为。

>>> class TensorLike:
...     @classmethod
...     def __torch_function__(cls, func, types, args, kwargs):
...         return -1
>>> is_tensor_like(TensorLike())
True
torch.overrides.is_tensor_method_or_property(func)[源代码]

如果传入的函数是torch.Tensor所属的方法或属性的处理器,则返回True,如传入__torch_function__

注意

对于属性,必须传入它们的 __get__ 方法。

这可能出于以下原因而需要:

  1. 方法/属性有时不包含__module__槽。

  2. 他们要求第一个传入的参数是 torch.Tensor 的一个实例。

示例

>>> is_tensor_method_or_property(torch.Tensor.add)

>>> is_tensor_method_or_property(torch.add)

Return type

bool

torch.overrides.wrap_torch_function(dispatcher)[源代码]

使用与 __torch_function__ 相关的功能包装给定的函数。

Parameters

调度器 (可调用对象) – 一个返回传入函数的类似张量可迭代对象的可调用对象。

注意

这个装饰器可能会降低代码的性能。通常情况下,将代码表达为一系列支持 __torch_function__ 的函数就足够了。如果你发现自己处于这种罕见的情况,例如当你包装一个低级库并且还需要它适用于类似张量的对象时,那么这个函数是可用的。

示例

>>> def dispatcher(a): # 必须与func具有相同的签名
...     return (a,)
>>> @torch.overrides.wrap_torch_function(dispatcher)
>>> def func(a): # 这将使func可由__torch_function__调度
...     return a + 0
优云智算