可视化
在PyTorch中,如果张量的属性requires_grad为True,当我们使用该张量进行任何操作时,计算图将被创建。
计算图的实现类似于链表——Tensors是节点,它们通过属性gran_fn连接。
PyTorchViz是一个Python包,它使用Graphviz作为后端来绘制计算图。
TorchOpt使用PyTorchViz作为蓝图,并在支持其所有功能的前提下,提供了更易于使用的可视化功能。
用法
让我们从一个简单的乘法计算图开始。
我们声明了变量 x,并设置了标志 requires_grad=True,然后计算 y = 2 * x。接着我们可视化了 y 的计算图。
我们提供了函数 make_dot(),它接受一个张量作为输入。
可视化代码如下所示:
from IPython.display import display
import torch
import torchopt
x = torch.tensor(1.0, requires_grad=True)
y = 2 * x
display(torchopt.visual.make_dot(y))
图中显示y通过乘法边连接。
y的梯度将通过乘法反向函数流动,然后累积在x上。
请注意,我们传递了一个字典用于添加节点标签。
为了向计算图添加辅助注释,我们可以传递一个字典作为参数 params 给 make_dot()。
键是将在计算图中显示的注释,值是需要被注释的张量。
因此,上面的代码可以修改如下:
from IPython.display import display
import torch
import torchopt
x = torch.tensor(1.0, requires_grad=True)
y = 2 * x
display(torchopt.visual.make_dot(y, params={'x': x, 'y': y}))
然后让我们绘制一个神经网络。
请注意,我们可以传递由方法named_parameters返回的生成器来添加节点标签。
from IPython.display import display
import torch
from torch import nn
import torchopt
class Net(nn.Module):
def __init__(self, dim):
super().__init__()
self.fc = nn.Linear(dim, 1, bias=True)
def forward(self, x):
return self.fc(x)
dim = 5
batch_size = 2
net = Net(dim)
xs = torch.ones((batch_size, dim))
ys = torch.ones((batch_size, 1))
pred = net(xs)
loss = F.mse_loss(pred, ys)
display(torchopt.visual.make_dot(loss, params=(net.named_parameters(), {'loss': loss})))
元学习算法的计算图将会更加复杂。 我们的可视化工具允许用户将提取的网络状态作为输入,以便更好地进行可视化。
from IPython.display import display
import torch
from torch import nn
import torchopt
class MetaNet(nn.Module):
def __init__(self, dim):
super().__init__()
self.fc = nn.Linear(dim, 1, bias=True)
def forward(self, x, meta_param):
return self.fc(x) + meta_param
dim = 5
batch_size = 2
net = MetaNet(dim)
xs = torch.ones((batch_size, dim))
ys = torch.ones((batch_size, 1))
optimizer = torchopt.MetaSGD(net, lr=1e-3)
meta_param = torch.tensor(1.0, requires_grad=True)
# Set enable_visual
net_state_0 = torchopt.extract_state_dict(net, enable_visual=True, visual_prefix='step0.')
pred = net(xs, meta_param)
loss = F.mse_loss(pred, ys)
optimizer.step(loss)
# Set enable_visual
net_state_1 = torchopt.extract_state_dict(net, enable_visual=True, visual_prefix='step1.')
pred = net(xs, meta_param)
loss = F.mse_loss(pred, torch.ones_like(pred))
# Draw computation graph
display(
torchopt.visual.make_dot(
loss, [net_state_0, net_state_1, {'meta_param': meta_param, 'loss': loss}]
)
)
笔记本教程
查看笔记本教程在Visualization。