注意
Go to the end to download the full example code
Transformer 作为图神经网络
作者: 叶梓豪, 周金晶, 郭启鹏, 甘泉, 张政
警告
本教程旨在通过代码作为解释手段,深入理解论文内容。因此,实现并未针对运行效率进行优化。如需推荐的实现,请参考官方示例。
在本教程中,您将了解Transformer模型的简化实现。 您可以看到最重要的设计要点的亮点。例如,这里 只有单头注意力。完整的代码可以在 这里找到。
整体结构与研究论文Annotated Transformer中的结构相似。
Transformer模型,作为CNN/RNN架构的替代品用于序列建模,是在研究论文中引入的:Attention is All You Need。它改进了机器翻译和自然语言推理任务的最新技术(GPT)。最近关于使用大规模语料库预训练Transformer的工作(BERT)支持它能够学习高质量的语义表示。
Transformer的有趣之处在于其广泛使用了注意力机制。注意力的经典应用来自于机器翻译模型,其中输出标记会关注所有输入标记。
Transformer 还在解码器和编码器中应用了自注意力。这个过程迫使单词相互关联并组合在一起,无论它们在序列中的位置如何。这与基于RNN的模型不同,在RNN模型中,单词(在源句子中)是沿着链组合的,这被认为过于受限。
Transformer的注意力层
在Transformer的注意力层中,对于每个节点,模块学习为其传入边分配权重。对于节点对\((i, j)\)(从\(i\)到\(j\)),其中节点\(x_i, x_j \in \mathbb{R}^n\),它们的连接分数定义如下:
其中 \(W_q, W_k, W_v \in \mathbb{R}^{n\times d_k}\) 将表示 \(x\) 分别映射到“查询”、“键”和“值”空间。
还有其他实现评分函数的可能性。点积衡量了给定查询 \(q_j\) 和键 \(k_i\) 的相似性:如果 \(j\) 需要存储在 \(i\) 中的信息,位置 \(j\) 的查询向量 \(q_j\) 应该接近位置 \(i\) 的键向量 \(k_i\)。
然后使用该分数来计算传入值的总和,这些值在边的权重上进行了归一化,并存储在\(\textrm{wv}\)中。然后对\(\textrm{wv}\)应用一个仿射层以获得输出\(o\):
多头注意力层
在Transformer中,注意力是多头的。一个头非常类似于卷积网络中的一个通道。多头注意力由多个注意力头组成,其中每个头指的是一个单一的注意力模块。所有头的\(\textrm{wv}^{(i)}\)被连接并通过一个仿射层映射到输出\(o\):
下面的代码封装了多头注意力机制的必要组件,并提供了两个接口。
get
将状态 'x' 映射到查询、键和值,这是后续步骤(propagate_attention
)所必需的。get_o
将注意力后的更新值映射到输出 \(o\) 以进行后处理。
class MultiHeadAttention(nn.Module):
"Multi-Head Attention"
def __init__(self, h, dim_model):
"h: number of heads; dim_model: hidden dimension"
super(MultiHeadAttention, self).__init__()
self.d_k = dim_model // h
self.h = h
# W_q, W_k, W_v, W_o
self.linears = clones(nn.Linear(dim_model, dim_model), 4)
def get(self, x, fields='qkv'):
"Return a dict of queries / keys / values."
batch_size = x.shape[0]
ret = {}
if 'q' in fields:
ret['q'] = self.linears[0](x).view(batch_size, self.h, self.d_k)
if 'k' in fields:
ret['k'] = self.linears[1](x).view(batch_size, self.h, self.d_k)
if 'v' in fields:
ret['v'] = self.linears[2](x).view(batch_size, self.h, self.d_k)
return ret
def get_o(self, x):
"get output of the multi-head attention"
batch_size = x.shape[0]
return self.linears[3](x.view(batch_size, -1))
DGL如何用图神经网络实现Transformer
通过将注意力视为图中的边,并在边上采用消息传递来引导适当的处理,你可以获得Transformer的不同视角。
图结构
通过将源句子和目标句子的标记映射到节点来构建图。完整的Transformer图由三个子图组成:
源语言图。这是一个完整的图,每个标记\(s_i\)可以关注任何其他标记\(s_j\)(包括自环)。
目标语言图。这个图是半完整的,因为\(t_i\)只关注\(t_j\)如果\(i > j\)(输出标记不能依赖于未来的单词)。
跨语言图。这是一个二分图,其中每个源标记\(s_i\)到每个目标标记\(t_j\)都有一条边,意味着每个目标标记都可以关注源标记。
完整的图片看起来像这样:
在数据集准备阶段预先构建图形。
消息传递
一旦你定义了图结构,接下来就定义消息传递的计算。
假设你已经计算了所有的查询 \(q_i\),键 \(k_i\) 和值 \(v_i\)。对于每个节点 \(i\)(无论 它是源标记还是目标标记),你可以将 注意力计算分解为两个步骤:
消息计算: 计算注意力分数 \(\mathrm{score}_{ij}\) 在 \(i\) 和所有节点 \(j\) 之间,通过取 \(q_i\) 和 \(k_j\) 之间的缩放点积。从 \(j\) 发送到 \(i\) 的消息将包括分数 \(\mathrm{score}_{ij}\) 和 值 \(v_j\)。
消息聚合: 根据分数 \(\mathrm{score}_{ij}\) 从所有 \(j\) 聚合值 \(v_j\)。
简单实现
消息计算
计算 score
并将源节点的 v
发送到目的地的邮箱
def message_func(edges):
return {'score': ((edges.src['k'] * edges.dst['q'])
.sum(-1, keepdim=True)),
'v': edges.src['v']}
消息聚合
对所有入边进行归一化并加权求和以获取输出
import torch as th
import torch.nn.functional as F
def reduce_func(nodes, d_k=64):
v = nodes.mailbox['v']
att = F.softmax(nodes.mailbox['score'] / th.sqrt(d_k), 1)
return {'dx': (att * v).sum(1)}
在特定边缘执行
import functools.partial as partial
def naive_propagate_attention(self, g, eids):
g.send_and_recv(eids, message_func, partial(reduce_func, d_k=self.d_k))
使用内置函数加速
为了加速消息传递过程,使用DGL的内置函数,包括:
fn.src_mul_egdes(src_field, edges_field, out_field)
将源节点的属性和边的属性相乘,并将结果发送到目标节点的邮箱,由out_field
键控。fn.copy_e(edges_field, out_field)
将边的属性复制到目标节点的邮箱中。fn.sum(edges_field, out_field)
对边的属性进行求和,并将聚合结果发送到目标节点的邮箱。
在这里,你将那些内置函数组装成propagate_attention
,这也是最终实现中的主要图操作函数。为了加速它,将softmax
操作分解为以下步骤。回想一下,每个头有两个阶段。
通过将源节点的
k
和目标节点的q
相乘来计算注意力分数g.apply_edges(src_dot_dst('k', 'q', 'score'), eids)
在所有目标节点的入边上进行缩放Softmax
步骤1:使用比例归一化常数对分数进行指数化
g.apply_edges(scaled_exp('score', np.sqrt(self.d_k)))
\[\textrm{score}_{ij}\leftarrow\exp{\left(\frac{\textrm{score}_{ij}}{ \sqrt{d_k}}\right)}\]
步骤2:获取每个节点入边上由“scores”加权的关联节点的“values”;获取每个节点入边上的“scores”之和以进行归一化。请注意,这里的\(\textrm{wv}\)未归一化。
msg: fn.u_mul_e('v', 'score', 'v'), reduce: fn.sum('v', 'wv')
\[\textrm{wv}_j=\sum_{i=1}^{N} \textrm{score}_{ij} \cdot v_i\]msg: fn.copy_e('score', 'score'), reduce: fn.sum('score', 'z')
\[\textrm{z}_j=\sum_{i=1}^{N} \textrm{score}_{ij}\]
\(\textrm{wv}\) 的归一化留给后处理。
def src_dot_dst(src_field, dst_field, out_field):
def func(edges):
return {out_field: (edges.src[src_field] * edges.dst[dst_field]).sum(-1, keepdim=True)}
return func
def scaled_exp(field, scale_constant):
def func(edges):
# clamp for softmax numerical stability
return {field: th.exp((edges.data[field] / scale_constant).clamp(-5, 5))}
return func
def propagate_attention(self, g, eids):
# Compute attention score
g.apply_edges(src_dot_dst('k', 'q', 'score'), eids)
g.apply_edges(scaled_exp('score', np.sqrt(self.d_k)))
# Update node state
g.send_and_recv(eids,
[fn.u_mul_e('v', 'score', 'v'), fn.copy_e('score', 'score')],
[fn.sum('v', 'wv'), fn.sum('score', 'z')])
预处理和后处理
在Transformer中,数据需要在propagate_attention
函数之前和之后进行预处理和后处理。
预处理 预处理函数 pre_func
首先
对节点表示进行归一化,然后将它们映射到一组
查询、键和值,以自注意力为例:
后处理 后处理函数 post_funcs
完成
对应于变压器一层的整个计算:1.
归一化 \(\textrm{wv}\) 并获取多头注意力层的输出
\(o\)。
添加残差连接:
在\(x\)上应用两层位置前馈层,然后添加残差连接:
\[x \leftarrow x + \textrm{LayerNorm}(\textrm{FFN}(x))\]其中 \(\textrm{FFN}\) 指的是前馈函数。
class Encoder(nn.Module):
def __init__(self, layer, N):
super(Encoder, self).__init__()
self.N = N
self.layers = clones(layer, N)
self.norm = LayerNorm(layer.size)
def pre_func(self, i, fields='qkv'):
layer = self.layers[i]
def func(nodes):
x = nodes.data['x']
norm_x = layer.sublayer[0].norm(x)
return layer.self_attn.get(norm_x, fields=fields)
return func
def post_func(self, i):
layer = self.layers[i]
def func(nodes):
x, wv, z = nodes.data['x'], nodes.data['wv'], nodes.data['z']
o = layer.self_attn.get_o(wv / z)
x = x + layer.sublayer[0].dropout(o)
x = layer.sublayer[1](x, layer.feed_forward)
return {'x': x if i < self.N - 1 else self.norm(x)}
return func
class Decoder(nn.Module):
def __init__(self, layer, N):
super(Decoder, self).__init__()
self.N = N
self.layers = clones(layer, N)
self.norm = LayerNorm(layer.size)
def pre_func(self, i, fields='qkv', l=0):
layer = self.layers[i]
def func(nodes):
x = nodes.data['x']
if fields == 'kv':
norm_x = x # In enc-dec attention, x has already been normalized.
else:
norm_x = layer.sublayer[l].norm(x)
return layer.self_attn.get(norm_x, fields)
return func
def post_func(self, i, l=0):
layer = self.layers[i]
def func(nodes):
x, wv, z = nodes.data['x'], nodes.data['wv'], nodes.data['z']
o = layer.self_attn.get_o(wv / z)
x = x + layer.sublayer[l].dropout(o)
if l == 1:
x = layer.sublayer[2](x, layer.feed_forward)
return {'x': x if i < self.N - 1 else self.norm(x)}
return func
这完成了Transformer中一层编码器和解码器的所有过程。
注意
子层连接部分与原始论文有些不同。然而,这个实现与The Annotated Transformer 和 OpenNMT相同。
Transformer图的主类
Transformer的处理流程可以看作是在完整图中进行的两阶段消息传递(适当地添加预处理和后处理):1)编码器中的自注意力,2)解码器中的自注意力,然后是编码器和解码器之间的交叉注意力,如下图所示。
class Transformer(nn.Module):
def __init__(self, encoder, decoder, src_embed, tgt_embed, pos_enc, generator, h, d_k):
super(Transformer, self).__init__()
self.encoder, self.decoder = encoder, decoder
self.src_embed, self.tgt_embed = src_embed, tgt_embed
self.pos_enc = pos_enc
self.generator = generator
self.h, self.d_k = h, d_k
def propagate_attention(self, g, eids):
# Compute attention score
g.apply_edges(src_dot_dst('k', 'q', 'score'), eids)
g.apply_edges(scaled_exp('score', np.sqrt(self.d_k)))
# Send weighted values to target nodes
g.send_and_recv(eids,
[fn.u_mul_e('v', 'score', 'v'), fn.copy_e('score', 'score')],
[fn.sum('v', 'wv'), fn.sum('score', 'z')])
def update_graph(self, g, eids, pre_pairs, post_pairs):
"Update the node states and edge states of the graph."
# Pre-compute queries and key-value pairs.
for pre_func, nids in pre_pairs:
g.apply_nodes(pre_func, nids)
self.propagate_attention(g, eids)
# Further calculation after attention mechanism
for post_func, nids in post_pairs:
g.apply_nodes(post_func, nids)
def forward(self, graph):
g = graph.g
nids, eids = graph.nids, graph.eids
# Word Embedding and Position Embedding
src_embed, src_pos = self.src_embed(graph.src[0]), self.pos_enc(graph.src[1])
tgt_embed, tgt_pos = self.tgt_embed(graph.tgt[0]), self.pos_enc(graph.tgt[1])
g.nodes[nids['enc']].data['x'] = self.pos_enc.dropout(src_embed + src_pos)
g.nodes[nids['dec']].data['x'] = self.pos_enc.dropout(tgt_embed + tgt_pos)
for i in range(self.encoder.N):
# Step 1: Encoder Self-attention
pre_func = self.encoder.pre_func(i, 'qkv')
post_func = self.encoder.post_func(i)
nodes, edges = nids['enc'], eids['ee']
self.update_graph(g, edges, [(pre_func, nodes)], [(post_func, nodes)])
for i in range(self.decoder.N):
# Step 2: Dncoder Self-attention
pre_func = self.decoder.pre_func(i, 'qkv')
post_func = self.decoder.post_func(i)
nodes, edges = nids['dec'], eids['dd']
self.update_graph(g, edges, [(pre_func, nodes)], [(post_func, nodes)])
# Step 3: Encoder-Decoder attention
pre_q = self.decoder.pre_func(i, 'q', 1)
pre_kv = self.decoder.pre_func(i, 'kv', 1)
post_func = self.decoder.post_func(i, 1)
nodes_e, nodes_d, edges = nids['enc'], nids['dec'], eids['ed']
self.update_graph(g, edges, [(pre_q, nodes_d), (pre_kv, nodes_e)], [(post_func, nodes_d)])
return self.generator(g.ndata['x'][nids['dec']])
注意
通过调用update_graph
函数,您可以使用几乎相同的代码在任何子图上创建自己的Transformer。这种灵活性使我们能够发现新的稀疏结构(参见这里提到的局部注意力)。请注意,在此实现中,您不使用掩码或填充,这使得逻辑更加清晰并节省内存。代价是实现速度较慢。
训练
本教程不涵盖原始论文中提到的其他几种技术,如标签平滑和Noam优化。有关这些模块的详细描述,请阅读哈佛NLP团队撰写的The Annotated Transformer。
任务和数据集
Transformer 是一个适用于各种 NLP 任务的通用框架。本教程重点介绍序列到序列学习:这是一个典型的案例,用于说明其工作原理。
至于数据集,有两个示例任务:复制和排序,以及两个真实世界的翻译任务:multi30k en-de任务和wmt14 en-de任务。
复制数据集: 将输入序列复制到输出。(训练/验证/测试: 9000, 1000, 1000)
排序数据集: 将输入序列排序后输出。(训练/验证/测试: 9000, 1000, 1000)
Multi30k 英德翻译,将句子从英语翻译成德语。 (训练/验证/测试:29000, 1000, 1000)
WMT14 en-de,将句子从英文翻译成德文。 (训练/验证/测试:4500966/3000/3003)
注意
使用wmt14进行训练需要多GPU支持,目前不可用。欢迎贡献!
图构建
批处理 这与处理 Tree-LSTM 的方式类似。预先构建一个图池,包括所有可能的输入长度和输出长度的组合。然后对于批次中的每个样本,调用 dgl.batch
将它们大小的图批量处理成一个单一的大图。
你可以将创建图池和构建BatchedGraph的过程封装在dataset.GraphPool
和dataset.TranslationDataset
中。
graph_pool = GraphPool()
data_iter = dataset(graph_pool, mode='train', batch_size=1, devices=devices)
for graph in data_iter:
print(graph.nids['enc']) # encoder node ids
print(graph.nids['dec']) # decoder node ids
print(graph.eids['ee']) # encoder-encoder edge ids
print(graph.eids['ed']) # encoder-decoder edge ids
print(graph.eids['dd']) # decoder-decoder edge ids
print(graph.src[0]) # Input word index list
print(graph.src[1]) # Input positions
print(graph.tgt[0]) # Output word index list
print(graph.tgt[1]) # Ouptut positions
break
输出:
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8], device='cuda:0')
tensor([ 9, 10, 11, 12, 13, 14, 15, 16, 17, 18], device='cuda:0')
tensor([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17,
18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35,
36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53,
54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71,
72, 73, 74, 75, 76, 77, 78, 79, 80], device='cuda:0')
tensor([ 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94,
95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108,
109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122,
123, 124, 125, 126, 127, 128, 129, 130, 131, 132, 133, 134, 135, 136,
137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, 149, 150,
151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164,
165, 166, 167, 168, 169, 170], device='cuda:0')
tensor([171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184,
185, 186, 187, 188, 189, 190, 191, 192, 193, 194, 195, 196, 197, 198,
199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211, 212,
213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225],
device='cuda:0')
tensor([28, 25, 7, 26, 6, 4, 5, 9, 18], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8], device='cuda:0')
tensor([ 0, 28, 25, 7, 26, 6, 4, 5, 9, 18], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
将所有内容整合在一起
在复制任务上训练一个单头、单层、128维的transformer。将其他参数设置为默认值。
本教程不包含推理模块。它需要波束搜索。有关完整实现,请参见GitHub仓库。
from tqdm.auto import tqdm
import torch as th
import numpy as np
from loss import LabelSmoothing, SimpleLossCompute
from modules import make_model
from optims import NoamOpt
from dgl.contrib.transformer import get_dataset, GraphPool
def run_epoch(data_iter, model, loss_compute, is_train=True):
for i, g in tqdm(enumerate(data_iter)):
with th.set_grad_enabled(is_train):
output = model(g)
loss = loss_compute(output, g.tgt_y, g.n_tokens)
print('average loss: {}'.format(loss_compute.avg_loss))
print('accuracy: {}'.format(loss_compute.accuracy))
N = 1
batch_size = 128
devices = ['cuda' if th.cuda.is_available() else 'cpu']
dataset = get_dataset("copy")
V = dataset.vocab_size
criterion = LabelSmoothing(V, padding_idx=dataset.pad_id, smoothing=0.1)
dim_model = 128
# Create model
model = make_model(V, V, N=N, dim_model=128, dim_ff=128, h=1)
# Sharing weights between Encoder & Decoder
model.src_embed.lut.weight = model.tgt_embed.lut.weight
model.generator.proj.weight = model.tgt_embed.lut.weight
model, criterion = model.to(devices[0]), criterion.to(devices[0])
model_opt = NoamOpt(dim_model, 1, 400,
th.optim.Adam(model.parameters(), lr=1e-3, betas=(0.9, 0.98), eps=1e-9))
loss_compute = SimpleLossCompute
att_maps = []
for epoch in range(4):
train_iter = dataset(graph_pool, mode='train', batch_size=batch_size, devices=devices)
valid_iter = dataset(graph_pool, mode='valid', batch_size=batch_size, devices=devices)
print('Epoch: {} Training...'.format(epoch))
model.train(True)
run_epoch(train_iter, model,
loss_compute(criterion, model_opt), is_train=True)
print('Epoch: {} Evaluating...'.format(epoch))
model.att_weight_map = None
model.eval()
run_epoch(valid_iter, model,
loss_compute(criterion, None), is_train=False)
att_maps.append(model.att_weight_map)
可视化
训练后,您可以可视化Transformer在复制任务上生成的注意力。
src_seq = dataset.get_seq_by_id(VIZ_IDX, mode='valid', field='src')
tgt_seq = dataset.get_seq_by_id(VIZ_IDX, mode='valid', field='tgt')[:-1]
# visualize head 0 of encoder-decoder attention
att_animation(att_maps, 'e2d', src_seq, tgt_seq, 0)
从图中你可以看到解码器节点逐渐学会关注输入序列中的相应节点,这是预期的行为。
Multi-head attention
除了在玩具任务上训练的单头注意力的注意力。我们还可视化了在多-30k数据集上训练的单层Transformer网络的编码器自注意力、解码器自注意力和编码器-解码器注意力的注意力分数。
从可视化中你可以看到不同头的多样性,这是你所期望的。不同的头学习单词对之间的不同关系。
编码器自注意力
编码器-解码器注意力 目标序列中的大多数词会关注源序列中的相关词,例如:当生成“See”(在德语中)时,几个注意力头会关注“lake”;当生成“Eisfischerhütte”时,几个注意力头会关注“ice”。
解码器自注意力 大多数词会关注它们前面的几个词。
自适应通用变压器
谷歌最近的一篇研究论文,Universal Transformer,是一个展示update_graph
如何适应更复杂更新规则的例子。
通用Transformer被提出来解决普通Transformer在计算上不具备通用性的问题,通过在Transformer中引入递归:
通用Transformer的基本思想是通过在表示上应用Transformer层,在每一步循环中反复修订序列中所有符号的表示。
与普通的Transformer相比,Universal Transformer在其层之间共享权重,并且不固定递归时间(这意味着Transformer中的层数)。
进一步的优化采用了一种自适应计算时间(ACT)机制,允许模型动态调整序列中每个位置的表示被修订的次数(以下称为步骤)。该模型也被称为自适应通用变换器(AUT)。
在AUT中,您维护一个活动节点列表。在每一步\(t\)中,我们通过以下方式计算列表中所有节点的停止概率:\(h (0
然后动态决定哪些节点仍然活跃。一个节点在时间 \(T\) 停止,当且仅当 \(\sum_{t=1}^{T-1} h_t < 1 - \varepsilon \leq \sum_{t=1}^{T}h_t\)。停止的节点将从列表中移除。该过程继续进行,直到列表为空或达到预定义的最大步骤。从 DGL 的角度来看,这意味着“活跃”图随着时间的推移变得越来越稀疏。
节点的最终状态 \(s_i\) 是 \(x_i^t\) 通过 \(h_i^t\) 的加权平均值:
在DGL中,通过在仍然活跃的节点和与这些节点相关的边上调用update_graph
来实现一个算法。以下代码展示了DGL中的Universal Transformer类:
class UTransformer(nn.Module):
"Universal Transformer(https://arxiv.org/pdf/1807.03819.pdf) with ACT(https://arxiv.org/pdf/1603.08983.pdf)."
MAX_DEPTH = 8
thres = 0.99
act_loss_weight = 0.01
def __init__(self, encoder, decoder, src_embed, tgt_embed, pos_enc, time_enc, generator, h, d_k):
super(UTransformer, self).__init__()
self.encoder, self.decoder = encoder, decoder
self.src_embed, self.tgt_embed = src_embed, tgt_embed
self.pos_enc, self.time_enc = pos_enc, time_enc
self.halt_enc = HaltingUnit(h * d_k)
self.halt_dec = HaltingUnit(h * d_k)
self.generator = generator
self.h, self.d_k = h, d_k
def step_forward(self, nodes):
# add positional encoding and time encoding, increment step by one
x = nodes.data['x']
step = nodes.data['step']
pos = nodes.data['pos']
return {'x': self.pos_enc.dropout(x + self.pos_enc(pos.view(-1)) + self.time_enc(step.view(-1))),
'step': step + 1}
def halt_and_accum(self, name, end=False):
"field: 'enc' or 'dec'"
halt = self.halt_enc if name == 'enc' else self.halt_dec
thres = self.thres
def func(nodes):
p = halt(nodes.data['x'])
sum_p = nodes.data['sum_p'] + p
active = (sum_p < thres) & (1 - end)
_continue = active.float()
r = nodes.data['r'] * (1 - _continue) + (1 - sum_p) * _continue
s = nodes.data['s'] + ((1 - _continue) * r + _continue * p) * nodes.data['x']
return {'p': p, 'sum_p': sum_p, 'r': r, 's': s, 'active': active}
return func
def propagate_attention(self, g, eids):
# Compute attention score
g.apply_edges(src_dot_dst('k', 'q', 'score'), eids)
g.apply_edges(scaled_exp('score', np.sqrt(self.d_k)), eids)
# Send weighted values to target nodes
g.send_and_recv(eids,
[fn.u_mul_e('v', 'score', 'v'), fn.copy_e('score', 'score')],
[fn.sum('v', 'wv'), fn.sum('score', 'z')])
def update_graph(self, g, eids, pre_pairs, post_pairs):
"Update the node states and edge states of the graph."
# Pre-compute queries and key-value pairs.
for pre_func, nids in pre_pairs:
g.apply_nodes(pre_func, nids)
self.propagate_attention(g, eids)
# Further calculation after attention mechanism
for post_func, nids in post_pairs:
g.apply_nodes(post_func, nids)
def forward(self, graph):
g = graph.g
N, E = graph.n_nodes, graph.n_edges
nids, eids = graph.nids, graph.eids
# embed & pos
g.nodes[nids['enc']].data['x'] = self.src_embed(graph.src[0])
g.nodes[nids['dec']].data['x'] = self.tgt_embed(graph.tgt[0])
g.nodes[nids['enc']].data['pos'] = graph.src[1]
g.nodes[nids['dec']].data['pos'] = graph.tgt[1]
# init step
device = next(self.parameters()).device
g.ndata['s'] = th.zeros(N, self.h * self.d_k, dtype=th.float, device=device) # accumulated state
g.ndata['p'] = th.zeros(N, 1, dtype=th.float, device=device) # halting prob
g.ndata['r'] = th.ones(N, 1, dtype=th.float, device=device) # remainder
g.ndata['sum_p'] = th.zeros(N, 1, dtype=th.float, device=device) # sum of pondering values
g.ndata['step'] = th.zeros(N, 1, dtype=th.long, device=device) # step
g.ndata['active'] = th.ones(N, 1, dtype=th.uint8, device=device) # active
for step in range(self.MAX_DEPTH):
pre_func = self.encoder.pre_func('qkv')
post_func = self.encoder.post_func()
nodes = g.filter_nodes(lambda v: v.data['active'].view(-1), nids['enc'])
if len(nodes) == 0: break
edges = g.filter_edges(lambda e: e.dst['active'].view(-1), eids['ee'])
end = step == self.MAX_DEPTH - 1
self.update_graph(g, edges,
[(self.step_forward, nodes), (pre_func, nodes)],
[(post_func, nodes), (self.halt_and_accum('enc', end), nodes)])
g.nodes[nids['enc']].data['x'] = self.encoder.norm(g.nodes[nids['enc']].data['s'])
for step in range(self.MAX_DEPTH):
pre_func = self.decoder.pre_func('qkv')
post_func = self.decoder.post_func()
nodes = g.filter_nodes(lambda v: v.data['active'].view(-1), nids['dec'])
if len(nodes) == 0: break
edges = g.filter_edges(lambda e: e.dst['active'].view(-1), eids['dd'])
self.update_graph(g, edges,
[(self.step_forward, nodes), (pre_func, nodes)],
[(post_func, nodes)])
pre_q = self.decoder.pre_func('q', 1)
pre_kv = self.decoder.pre_func('kv', 1)
post_func = self.decoder.post_func(1)
nodes_e = nids['enc']
edges = g.filter_edges(lambda e: e.dst['active'].view(-1), eids['ed'])
end = step == self.MAX_DEPTH - 1
self.update_graph(g, edges,
[(pre_q, nodes), (pre_kv, nodes_e)],
[(post_func, nodes), (self.halt_and_accum('dec', end), nodes)])
g.nodes[nids['dec']].data['x'] = self.decoder.norm(g.nodes[nids['dec']].data['s'])
act_loss = th.mean(g.ndata['r']) # ACT loss
return self.generator(g.ndata['x'][nids['dec']]), act_loss * self.act_loss_weight
调用 filter_nodes
和 filter_edge
来查找仍然活跃的节点/边:
注意
filter_nodes()
接受一个谓词和一个节点ID列表/张量作为输入,然后返回满足给定谓词的节点ID张量。filter_edges()
接受一个谓词 和一个边ID列表/张量作为输入,然后返回满足给定谓词的边ID张量。
有关完整实现,请参见GitHub repo。
下图显示了自适应计算时间的效果。句子的不同位置被修改了不同的次数。
你也可以在AUT在排序任务(达到99.7%的准确率)的训练过程中,可视化节点上步骤分布的动态变化,这展示了AUT如何在训练过程中学习减少递归步骤。
注意
由于存在许多依赖项,笔记本本身无法执行。
下载 7_transformer.py,
并将 Python 脚本复制到目录 examples/pytorch/transformer
然后运行 python 7_transformer.py
以查看其工作原理。
Total running time of the script: (0 minutes 0.000 seconds)