torch._tensor_str 的源代码
import contextlib
import dataclasses
import math
import textwrap
from typing import Any, Dict, Optional
import torch
from torch import inf
@dataclasses.dataclass
class __PrinterOptions:
precision: int = 4
threshold: float = 1000
edgeitems: int = 3
linewidth: int = 80
sci_mode: Optional[bool] = None
PRINT_OPTS = __PrinterOptions()
# 我们可以使用 **kwargs,但这将提供更好的文档
[docs]def set_printoptions(
precision=None,
threshold=None,
edgeitems=None,
linewidth=None,
profile=None,
sci_mode=None,
):
r"""设置打印选项。项目无耻地从 NumPy 中获取
参数:
precision: 浮点输出的精度位数 (默认 = 4)。
threshold: 触发摘要而不是完整 `repr` 的数组元素总数 (默认 = 1000)。
edgeitems: 每个维度开头和结尾的摘要中的数组项数 (默认 = 3)。
linewidth: 插入换行符的每行字符数 (默认 = 80)。阈值矩阵将忽略此参数。
profile: 合理的默认值用于漂亮打印。可以使用上述任何选项进行覆盖。(可以是 `default`, `short`, `full` 中的任何一个)
sci_mode: 启用 (True) 或禁用 (False) 科学记数法。如果指定为 None (默认),则值由 `torch._tensor_str._Formatter` 定义。该值由框架自动选择。
示例::
>>> # 限制元素的精度
>>> torch.set_printoptions(precision=2)
>>> torch.tensor([1.12345])
tensor([1.12])
>>> # 限制显示的元素数量
>>> torch.set_printoptions(threshold=5)
>>> torch.arange(10)
tensor([0, 1, 2, ..., 7, 8, 9])
>>> # 恢复默认值
>>> torch.set_printoptions(profile='default')
>>> torch.tensor([1.12345])
tensor([1.1235])
>>> torch.arange(10)
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
"""
if profile is not None:
if profile == "default":
PRINT_OPTS.precision = 4
PRINT_OPTS.threshold = 1000
PRINT_OPTS.edgeitems = 3
PRINT_OPTS.linewidth = 80
elif profile == "short":
PRINT_OPTS.precision = 2
PRINT_OPTS.threshold = 1000
PRINT_OPTS.edgeitems = 2
PRINT_OPTS.linewidth = 80
elif profile == "full":
PRINT_OPTS.precision = 4
PRINT_OPTS.threshold = inf
PRINT_OPTS.edgeitems = 3
PRINT_OPTS.linewidth = 80
if precision is not None:
PRINT_OPTS.precision = precision
if threshold is not None:
PRINT_OPTS.threshold = threshold
if edgeitems is not None:
PRINT_OPTS.edgeitems = edgeitems
if linewidth is not None:
PRINT_OPTS.linewidth = linewidth
PRINT_OPTS.sci_mode = sci_mode
def get_printoptions() -> Dict[str, Any]:
r"""获取当前的打印选项,作为可以传递给 set_printoptions() 的字典 ``**kwargs``。
"""
return dataclasses.asdict(PRINT_OPTS)
@contextlib.contextmanager
def printoptions(**kwargs):
r"""临时更改打印选项的上下文管理器。接受的参数与 :func:`set_printoptions` 相同。"""
old_kwargs = get_printoptions()
set_printoptions(**kwargs)
try:
yield
finally:
set_printoptions(**old_kwargs)
def tensor_totype(t):
dtype = torch.float if t.is_mps else torch.double
return t.to(dtype=dtype)
class _Formatter:
def __init__(self, tensor):
self.floating_dtype = tensor.dtype.is_floating_point
self.int_mode = True
self.sci_mode = False
self.max_width = 1
with torch.no_grad<span class