torch.utils.tensorboard.writer 的源代码
"""提供一个API,用于将协议缓冲区写入事件文件,供TensorBoard用于可视化。"""
import os
import time
from typing import List, Optional, Union, TYPE_CHECKING
import torch
if TYPE_CHECKING:
from matplotlib.figure import Figure
from tensorboard.compat import tf
from tensorboard.compat.proto import event_pb2
from tensorboard.compat.proto.event_pb2 import Event, SessionLog
from tensorboard.plugins.projector.projector_config_pb2 import ProjectorConfig
from tensorboard.summary.writer.event_file_writer import EventFileWriter
from ._convert_np import make_np
from ._embedding import get_embedding_info, make_mat, make_sprite, make_tsv, write_pbtxt
from ._onnx_graph import load_onnx_graph
from ._pytorch_graph import graph
from ._utils import figure_to_image
from .summary import (
audio,
custom_scalars,
histogram,
histogram_raw,
hparams,
image,
image_boxes,
mesh,
pr_curve,
pr_curve_raw,
scalar,
tensor_proto,
text,
video,
)
__all__ = ["FileWriter", "SummaryWriter"]
class FileWriter:
"""将协议缓冲区写入事件文件,供TensorBoard使用。
`FileWriter`类提供了一种机制,用于在给定目录中创建事件文件,并向其中添加摘要和事件。该类异步更新文件内容。这允许训练程序直接从训练循环调用方法,而不会减慢训练速度。
"""
def __init__(self, log_dir, max_queue=10, flush_secs=120, filename_suffix=""):
"""创建一个`FileWriter`和一个事件文件。
在构造时,writer会在`log_dir`中创建一个新的事件文件。构造函数的其他参数控制异步写入事件文件的方式。
参数:
log_dir: 字符串。事件文件将写入的目录。
max_queue: 整数。在调用'add'方法强制刷新到磁盘之前,待处理事件和摘要的队列大小。默认是10个项目。
flush_secs: 数字。将待处理事件和摘要刷新到磁盘的频率,以秒为单位。默认是每两分钟。
filename_suffix: 字符串。在log_dir目录中添加到所有事件文件名的后缀。有关文件名构造的更多详细信息,请参见tensorboard.summary.writer.event_file_writer.EventFileWriter。
"""
# 有时会传入PosixPath,我们需要在所有情况下将其强制转换为字符串
# TODO: 如果我们实际上是自己传入PosixPath,未来可能会删除这个
log_dir = str(log_dir)
self.event_writer = EventFileWriter(
log_dir, max_queue, flush_secs, filename_suffix
)
def get_logdir(self):
"""返回事件文件将写入的目录。"""
return self.event_writer.get_logdir()
def add_event(self, event, step=None, walltime=None):
"""将事件添加到事件文件。
参数:
event: 一个`Event`协议缓冲区。
step: 数字。可选的训练过程的全局步数值,与事件一起记录。
walltime: 浮点数。可选的默认墙时间(当前时间)覆盖(从time.time()),以秒为单位,表示事件发生的时间。
"""
event.wall_time = time.time() if walltime is None else walltime
if step is not None:
# 确保step从numpy或其他格式转换
# 因为protobuf版本可能不支持某些转换
event.step = int(step)
self.event_writer.add_event(event)
def add_summary(self, summary, global_step=None, walltime=None):
"""将`Summary`协议缓冲区添加到事件文件。
此方法将提供的摘要包装在`Event`协议缓冲区中,并将其添加到事件文件。
参数:
summary: 一个`Summary`协议缓冲区。
global_step: 数字。可选的训练过程的全局步数值,与摘要一起记录。
walltime: 浮点数。可选的默认墙时间(当前时间)覆盖(从time.time()),以秒为单位,表示事件发生的时间。
"""
event = event_pb2.Event(summary=summary)
self.add_event(event, global_step, walltime)
def add_graph(self, graph_profile, walltime=None):
"""将`Graph`和步统计协议缓冲区添加到事件文件。
参数:
graph_profile: 一个`Graph`和步统计协议缓冲区。
walltime: 浮点数。可选的默认墙时间(当前时间)覆盖(从time.time()),以秒为单位,表示事件发生的时间。
"""
graph = graph_profile[0]
stepstats = graph_profile[1]
event = event_pb2.Event(graph_def=graph.SerializeToString())
self.add_event(event, None, walltime)
trm = event_pb2.TaggedRunMetadata(
tag="step1", run_metadata