Shortcuts

torch.utils.tensorboard.writer 的源代码

"""提供一个API,用于将协议缓冲区写入事件文件,供TensorBoard用于可视化。"""

import os
import time
from typing import List, Optional, Union, TYPE_CHECKING

import torch

if TYPE_CHECKING:
    from matplotlib.figure import Figure
from tensorboard.compat import tf
from tensorboard.compat.proto import event_pb2
from tensorboard.compat.proto.event_pb2 import Event, SessionLog
from tensorboard.plugins.projector.projector_config_pb2 import ProjectorConfig
from tensorboard.summary.writer.event_file_writer import EventFileWriter

from ._convert_np import make_np
from ._embedding import get_embedding_info, make_mat, make_sprite, make_tsv, write_pbtxt
from ._onnx_graph import load_onnx_graph
from ._pytorch_graph import graph
from ._utils import figure_to_image
from .summary import (
    audio,
    custom_scalars,
    histogram,
    histogram_raw,
    hparams,
    image,
    image_boxes,
    mesh,
    pr_curve,
    pr_curve_raw,
    scalar,
    tensor_proto,
    text,
    video,
)

__all__ = ["FileWriter", "SummaryWriter"]


class FileWriter:
    """将协议缓冲区写入事件文件,供TensorBoard使用。

    `FileWriter`类提供了一种机制,用于在给定目录中创建事件文件,并向其中添加摘要和事件。该类异步更新文件内容。这允许训练程序直接从训练循环调用方法,而不会减慢训练速度。
    """

    def __init__(self, log_dir, max_queue=10, flush_secs=120, filename_suffix=""):
        """创建一个`FileWriter`和一个事件文件。

        在构造时,writer会在`log_dir`中创建一个新的事件文件。构造函数的其他参数控制异步写入事件文件的方式。

        参数:
          log_dir: 字符串。事件文件将写入的目录。
          max_queue: 整数。在调用'add'方法强制刷新到磁盘之前,待处理事件和摘要的队列大小。默认是10个项目。
          flush_secs: 数字。将待处理事件和摘要刷新到磁盘的频率,以秒为单位。默认是每两分钟。
          filename_suffix: 字符串。在log_dir目录中添加到所有事件文件名的后缀。有关文件名构造的更多详细信息,请参见tensorboard.summary.writer.event_file_writer.EventFileWriter。
        """
        # 有时会传入PosixPath,我们需要在所有情况下将其强制转换为字符串
        # TODO: 如果我们实际上是自己传入PosixPath,未来可能会删除这个
        log_dir = str(log_dir)
        self.event_writer = EventFileWriter(
            log_dir, max_queue, flush_secs, filename_suffix
        )

    def get_logdir(self):
        """返回事件文件将写入的目录。"""
        return self.event_writer.get_logdir()

    def add_event(self, event, step=None, walltime=None):
        """将事件添加到事件文件。

        参数:
          event: 一个`Event`协议缓冲区。
          step: 数字。可选的训练过程的全局步数值,与事件一起记录。
          walltime: 浮点数。可选的默认墙时间(当前时间)覆盖(从time.time()),以秒为单位,表示事件发生的时间。
        """
        event.wall_time = time.time() if walltime is None else walltime
        if step is not None:
            # 确保step从numpy或其他格式转换
            # 因为protobuf版本可能不支持某些转换
            event.step = int(step)
        self.event_writer.add_event(event)

    def add_summary(self, summary, global_step=None, walltime=None):
        """将`Summary`协议缓冲区添加到事件文件。

        此方法将提供的摘要包装在`Event`协议缓冲区中,并将其添加到事件文件。

        参数:
          summary: 一个`Summary`协议缓冲区。
          global_step: 数字。可选的训练过程的全局步数值,与摘要一起记录。
          walltime: 浮点数。可选的默认墙时间(当前时间)覆盖(从time.time()),以秒为单位,表示事件发生的时间。
        """
        event = event_pb2.Event(summary=summary)
        self.add_event(event, global_step, walltime)

    def add_graph(self, graph_profile, walltime=None):
        """将`Graph`和步统计协议缓冲区添加到事件文件。

        参数:
          graph_profile: 一个`Graph`和步统计协议缓冲区。
          walltime: 浮点数。可选的默认墙时间(当前时间)覆盖(从time.time()),以秒为单位,表示事件发生的时间。
        """
        graph = graph_profile[0]
        stepstats = graph_profile[1]
        event = event_pb2.Event(graph_def=graph.SerializeToString())
        self.add_event(event, None, walltime)

        trm = event_pb2.TaggedRunMetadata(
            tag="step1", run_metadata
优云智算