可视化

PyTorch中,如果张量的属性requires_gradTrue,当我们使用该张量进行任何操作时,计算图将被创建。 计算图的实现类似于链表——Tensors是节点,它们通过属性gran_fn连接。 PyTorchViz是一个Python包,它使用Graphviz作为后端来绘制计算图。 TorchOpt使用PyTorchViz作为蓝图,并在支持其所有功能的前提下,提供了更易于使用的可视化功能。


用法

让我们从一个简单的乘法计算图开始。 我们声明了变量 x,并设置了标志 requires_grad=True,然后计算 y = 2 * x。接着我们可视化了 y 的计算图。

我们提供了函数 make_dot(),它接受一个张量作为输入。 可视化代码如下所示:

from IPython.display import display
import torch
import torchopt


x = torch.tensor(1.0, requires_grad=True)
y = 2 * x
display(torchopt.visual.make_dot(y))
../_images/visualization-fig1.svg

图中显示y通过乘法边连接。 y的梯度将通过乘法反向函数流动,然后累积在x上。 请注意,我们传递了一个字典用于添加节点标签。

为了向计算图添加辅助注释,我们可以传递一个字典作为参数 paramsmake_dot()。 键是将在计算图中显示的注释,值是需要被注释的张量。 因此,上面的代码可以修改如下:

from IPython.display import display
import torch
import torchopt


x = torch.tensor(1.0, requires_grad=True)
y = 2 * x
display(torchopt.visual.make_dot(y, params={'x': x, 'y': y}))

然后让我们绘制一个神经网络。 请注意,我们可以传递由方法named_parameters返回的生成器来添加节点标签。

from IPython.display import display
import torch
from torch import nn
import torchopt


class Net(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.fc = nn.Linear(dim, 1, bias=True)

    def forward(self, x):
        return self.fc(x)


dim = 5
batch_size = 2
net = Net(dim)
xs = torch.ones((batch_size, dim))
ys = torch.ones((batch_size, 1))
pred = net(xs)
loss = F.mse_loss(pred, ys)

display(torchopt.visual.make_dot(loss, params=(net.named_parameters(), {'loss': loss})))
../_images/visualization-fig2.svg

元学习算法的计算图将会更加复杂。 我们的可视化工具允许用户将提取的网络状态作为输入,以便更好地进行可视化。

from IPython.display import display
import torch
from torch import nn
import torchopt

class MetaNet(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.fc = nn.Linear(dim, 1, bias=True)

    def forward(self, x, meta_param):
        return self.fc(x) + meta_param


dim = 5
batch_size = 2
net = MetaNet(dim)

xs = torch.ones((batch_size, dim))
ys = torch.ones((batch_size, 1))

optimizer = torchopt.MetaSGD(net, lr=1e-3)
meta_param = torch.tensor(1.0, requires_grad=True)

# Set enable_visual
net_state_0 = torchopt.extract_state_dict(net, enable_visual=True, visual_prefix='step0.')

pred = net(xs, meta_param)
loss = F.mse_loss(pred, ys)
optimizer.step(loss)

# Set enable_visual
net_state_1 = torchopt.extract_state_dict(net, enable_visual=True, visual_prefix='step1.')

pred = net(xs, meta_param)
loss = F.mse_loss(pred, torch.ones_like(pred))

# Draw computation graph
display(
    torchopt.visual.make_dot(
        loss, [net_state_0, net_state_1, {'meta_param': meta_param, 'loss': loss}]
    )
)
../_images/visualization-fig3.svg

笔记本教程

查看笔记本教程在Visualization