Source code for torch.autograd.profiler_util
import bisect
import itertools
import math
from collections import defaultdict, namedtuple
from operator import attrgetter
from typing import Any, Dict, List, Optional, Tuple
import torch
from torch.autograd import DeviceType
__all__ = [
"EventList",
"FormattedTimesMixin",
"Interval",
"Kernel",
"FunctionEvent",
"FunctionEventAvg",
"StringTable",
"MemRecordsAcc",
]
class EventList(list):
"""一个包含事件的列表(用于漂亮打印)。"""
def __init__(self, *args, **kwargs):
use_cuda = kwargs.pop("use_cuda", True)
use_device = kwargs.pop("use_device", None)
profile_memory = kwargs.pop("profile_memory", False)
with_flops = kwargs.pop("with_flops", False)
super().__init__(*args, **kwargs)
self._use_cuda = use_cuda
self._use_device = use_device
self._profile_memory = profile_memory
self._tree_built = False
self._with_flops = with_flops
def _build_tree(self):
self._populate_cpu_children()
self._remove_dup_nodes()
self._set_backward_stacktraces()
self._tree_built = True
def __str__(self):
return self.table()
def _remove_dup_nodes(self):
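        # Collapse redundant nesting: if an event's parent has the same name
        # and that parent has no other children, drop the child and move its
        # children and kernels up to the parent. Repeat until no such pair
        # remains.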
while True:
to_delete = set()
for idx in range(len(self)):
if (
self[idx].cpu_parent is not None
and self[idx].cpu_parent.name == self[idx].name
and len(self[idx].cpu_parent.cpu_children) == 1
):
self[idx].cpu_parent.cpu_children = self[idx].cpu_children
                    self[idx].cpu_parent.kernels = self[idx].kernels  # lift the kernels up
for ch in self[idx].cpu_children:
ch.cpu_parent = self[idx].cpu_parent
to_delete.add(idx)
if len(to_delete) == 0:
break
new_evts = [ev for ind, ev in enumerate(self) if ind not in to_delete]
self.clear()
self.extend(new_evts)
def _populate_cpu_children(self):
"""将子事件填充到每个底层FunctionEvent对象中。
如果[s1, e1)在[s2, e2)内,则一个事件是另一个事件的子事件。其中
s1和e1是子事件区间的开始和结束。而
s2和e2是父事件区间的开始和结束
示例:在事件列表[[0, 10], [1, 3], [3, 4]]中,[0, 10]
将是另外两个区间的父事件。
如果由于任何原因两个区间仅部分相交,此函数
将不会记录它们之间的父子关系。
"""
        # Some events can be async (i.e. start and end on different threads);
        # since it is generally undefined how to attribute child ranges to
        # async ranges, we do not use them when computing nested ranges and stats.
sync_events = [
evt
for evt in self
if not evt.is_async and evt.device_type == DeviceType.CPU
]
events = sorted(
sync_events,
key=attrgetter("thread"),
)
        # Group by both thread and node_id, so that events that happen to have
        # the same thread_id but come from different nodes are not incorrectly
        # grouped together.
threads = itertools.groupby(
            events,