Shortcuts

torch._tensor_str 的源代码

import contextlib
import dataclasses
import math
import textwrap
from typing import Any, Dict, Optional

import torch
from torch import inf


@dataclasses.dataclass
class __PrinterOptions:
    precision: int = 4
    threshold: float = 1000
    edgeitems: int = 3
    linewidth: int = 80
    sci_mode: Optional[bool] = None


PRINT_OPTS = __PrinterOptions()


# 我们可以使用 **kwargs,但这将提供更好的文档
[docs]def set_printoptions( precision=None, threshold=None, edgeitems=None, linewidth=None, profile=None, sci_mode=None, ): r"""设置打印选项。项目无耻地从 NumPy 中获取 参数: precision: 浮点输出的精度位数 (默认 = 4)。 threshold: 触发摘要而不是完整 `repr` 的数组元素总数 (默认 = 1000)。 edgeitems: 每个维度开头和结尾的摘要中的数组项数 (默认 = 3)。 linewidth: 插入换行符的每行字符数 (默认 = 80)。阈值矩阵将忽略此参数。 profile: 合理的默认值用于漂亮打印。可以使用上述任何选项进行覆盖。(可以是 `default`, `short`, `full` 中的任何一个) sci_mode: 启用 (True) 或禁用 (False) 科学记数法。如果指定为 None (默认),则值由 `torch._tensor_str._Formatter` 定义。该值由框架自动选择。 示例:: >>> # 限制元素的精度 >>> torch.set_printoptions(precision=2) >>> torch.tensor([1.12345]) tensor([1.12]) >>> # 限制显示的元素数量 >>> torch.set_printoptions(threshold=5) >>> torch.arange(10) tensor([0, 1, 2, ..., 7, 8, 9]) >>> # 恢复默认值 >>> torch.set_printoptions(profile='default') >>> torch.tensor([1.12345]) tensor([1.1235]) >>> torch.arange(10) tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9]) """ if profile is not None: if profile == "default": PRINT_OPTS.precision = 4 PRINT_OPTS.threshold = 1000 PRINT_OPTS.edgeitems = 3 PRINT_OPTS.linewidth = 80 elif profile == "short": PRINT_OPTS.precision = 2 PRINT_OPTS.threshold = 1000 PRINT_OPTS.edgeitems = 2 PRINT_OPTS.linewidth = 80 elif profile == "full": PRINT_OPTS.precision = 4 PRINT_OPTS.threshold = inf PRINT_OPTS.edgeitems = 3 PRINT_OPTS.linewidth = 80 if precision is not None: PRINT_OPTS.precision = precision if threshold is not None: PRINT_OPTS.threshold = threshold if edgeitems is not None: PRINT_OPTS.edgeitems = edgeitems if linewidth is not None: PRINT_OPTS.linewidth = linewidth PRINT_OPTS.sci_mode = sci_mode
def get_printoptions() -> Dict[str, Any]: r"""获取当前的打印选项,作为可以传递给 set_printoptions() 的字典 ``**kwargs``。 """ return dataclasses.asdict(PRINT_OPTS) @contextlib.contextmanager def printoptions(**kwargs): r"""临时更改打印选项的上下文管理器。接受的参数与 :func:`set_printoptions` 相同。""" old_kwargs = get_printoptions() set_printoptions(**kwargs) try: yield finally: set_printoptions(**old_kwargs) def tensor_totype(t): dtype = torch.float if t.is_mps else torch.double return t.to(dtype=dtype) class _Formatter: def __init__(self, tensor): self.floating_dtype = tensor.dtype.is_floating_point self.int_mode = True self.sci_mode = False self.max_width = 1 with torch.no_grad<span class