Tree-LSTM in DGL

Authors: Zihao Ye, Qipeng Guo, Minjie Wang, Jake Zhao, Zheng Zhang

Warning

This tutorial aims to provide insight into the paper, with code as a means of explanation. The implementation is therefore NOT optimized for running efficiency. For a recommended implementation, please refer to the official examples.

import os

In this tutorial, you learn how to use a Tree-LSTM network for sentiment analysis. Tree-LSTM is a generalization of long short-term memory (LSTM) networks to tree-structured network topologies.

The Tree-LSTM structure was first introduced by Tai et al. in an ACL 2015 paper: Improved Semantic Representations From Tree-Structured Long Short-Term Memory Networks. The core idea is to introduce syntactic information into language tasks by extending the chain-structured LSTM to a tree-structured LSTM. Dependency-tree and constituency-tree techniques are used to obtain a "latent tree".

The challenge in training Tree-LSTMs is batching, a standard technique in machine learning for accelerating optimization. However, since trees generally have different shapes, parallelization is not straightforward. DGL offers an alternative: pool all the trees into one single graph, then induce the message passing over them, guided by the structure of each tree.

Task and the dataset

The steps here use the Stanford Sentiment Treebank, provided in dgl.data. The dataset provides fine-grained, tree-level sentiment annotations. There are five classes: very negative, negative, neutral, positive, and very positive, which indicate the sentiment of the current subtree. Non-leaf nodes in a constituency tree do not contain words, so a special PAD_WORD token is used to represent them. During training and inference, their embeddings are masked to all-zero.

The figure shows one sample of the SST dataset, a constituency parse tree whose nodes are labeled with sentiment. To speed things up, build a tiny set with five sentences and take a look at the first one.

from collections import namedtuple

os.environ["DGLBACKEND"] = "pytorch"
import dgl
from dgl.data.tree import SSTDataset


SSTBatch = namedtuple("SSTBatch", ["graph", "mask", "wordid", "label"])

# Each sample in the dataset is a constituency tree. The leaf nodes
# represent words. The word is an int value stored in the "x" field.
# The non-leaf nodes have a special word PAD_WORD. The sentiment
# label is stored in the "y" feature field.
trainset = SSTDataset(mode="tiny")  # the "tiny" set has only five trees
tiny_sst = [tr for tr in trainset]
num_vocabs = trainset.vocab_size
num_classes = trainset.num_classes

vocab = trainset.vocab  # vocabulary dict: key -> id
inv_vocab = {
    v: k for k, v in vocab.items()
}  # inverted vocabulary dict: id -> word

a_tree = tiny_sst[0]
for token in a_tree.ndata["x"].tolist():
    if token != trainset.PAD_WORD:
        print(inv_vocab[token], end=" ")
import matplotlib.pyplot as plt
the rock is destined to be the 21st century 's new `` conan '' and that he 's going to make a splash even greater than arnold schwarzenegger , jean-claud van damme or steven segal .
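
Besides the word IDs in the "x" field and the labels in "y", each tree also carries a "mask" field marking the word (leaf) nodes; the model below relies on it to zero out the embeddings of PAD_WORD nodes. Here is a quick sketch (added for illustration, reusing a_tree from above) to inspect it:

# Leaf nodes carry words (mask = 1); internal constituency nodes carry PAD_WORD (mask = 0).
print(a_tree.ndata["mask"].sum().item(), "word nodes out of", a_tree.num_nodes(), "nodes")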

Step 1: Batching

Add all the trees to one graph, using the batch() API.

import networkx as nx

graph = dgl.batch(tiny_sst)


def plot_tree(g):
    # this plot requires pygraphviz package
    pos = nx.nx_agraph.graphviz_layout(g, prog="dot")
    nx.draw(
        g,
        pos,
        with_labels=False,
        node_size=10,
        node_color=[[0.5, 0.5, 0.5]],
        arrowsize=4,
    )
    plt.show()


plot_tree(graph.to_networkx())
(Figure: the five constituency trees of the tiny SST set batched into a single graph, as drawn by plot_tree.)

You can read more about the definition of batch() in the note below, or skip ahead to the next step.

Note

Definition: batch() unions a list of \(B\) DGLGraphs and returns a DGLGraph of batch size \(B\).

- The union includes all the nodes, edges, and their features. The order of nodes, edges, and features is preserved.

    - Given that graph \(\mathcal{G}_i\) has \(V_i\) nodes, node ID \(j\) in graph \(\mathcal{G}_i\) corresponds to node ID \(j + \sum_{k=1}^{i-1} V_k\) in the batched graph.

    - Therefore, performing feature transformation and message passing on the batched graph is equivalent to doing those on all the DGLGraph constituents in parallel.

- Duplicate references to the same graph are treated as deep copies; the nodes, edges, and features are duplicated, and mutation on one reference does not affect the others.

- The batched graph keeps track of the meta information of its constituents, so it can be unbatch()ed into a list of DGLGraphs.
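
To see this bookkeeping in action, here is a short sketch (added for illustration, not part of the original tutorial) that reuses graph and tiny_sst from above:

# Number of nodes contributed by each of the five trees.
print(graph.batch_num_nodes())
# The batched graph has exactly the sum of the constituents' node counts.
print(graph.num_nodes(), sum(t.num_nodes() for t in tiny_sst))
# unbatch() recovers the individual trees.
trees = dgl.unbatch(graph)
print(len(trees), trees[0].num_nodes() == tiny_sst[0].num_nodes())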

Step 2: Tree-LSTM cell with message-passing APIs

Researchers have proposed two types of Tree-LSTM: the Child-Sum Tree-LSTM and the \(N\)-ary Tree-LSTM. In this tutorial you focus on applying a Binary Tree-LSTM to binarized constituency trees. This application is also known as the Constituency Tree-LSTM. Use PyTorch as the backend framework to set up the network.

In the \(N\)-ary Tree-LSTM, each unit at node \(j\) maintains a hidden representation \(h_j\) and a memory cell \(c_j\). The unit \(j\) takes the input vector \(x_j\) and the hidden representations of its child units, \(h_{jl}, 1\leq l\leq N\), as input, then updates its new hidden representation \(h_j\) and memory cell \(c_j\) by:

\[\begin{split}i_j & = & \sigma\left(W^{(i)}x_j + \sum_{l=1}^{N}U^{(i)}_l h_{jl} + b^{(i)}\right), & (1)\\ f_{jk} & = & \sigma\left(W^{(f)}x_j + \sum_{l=1}^{N}U_{kl}^{(f)} h_{jl} + b^{(f)} \right), & (2)\\ o_j & = & \sigma\left(W^{(o)}x_j + \sum_{l=1}^{N}U_{l}^{(o)} h_{jl} + b^{(o)} \right), & (3) \\ u_j & = & \textrm{tanh}\left(W^{(u)}x_j + \sum_{l=1}^{N} U_l^{(u)}h_{jl} + b^{(u)} \right), & (4)\\ c_j & = & i_j \odot u_j + \sum_{l=1}^{N} f_{jl} \odot c_{jl}, &(5) \\ h_j & = & o_j \cdot \textrm{tanh}(c_j), &(6) \\\end{split}\]
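
Before mapping these equations to DGL, here is a minimal sketch (added for illustration only) that evaluates equations (1) through (6) directly for a single node with \(N = 2\) children, using randomly initialized weights; every tensor name below is hypothetical.

import torch as th

x_size, h_size, N = 4, 5, 2
x_j = th.randn(x_size)                              # input vector of node j
h_child = [th.randn(h_size) for _ in range(N)]      # h_{j1}, h_{j2}
c_child = [th.randn(h_size) for _ in range(N)]      # c_{j1}, c_{j2}

# Hypothetical weights: W maps the input, U maps the children's hidden states.
W = {k: th.randn(h_size, x_size) for k in "ifou"}
U = {k: [th.randn(h_size, h_size) for _ in range(N)] for k in "iou"}
U_f = [[th.randn(h_size, h_size) for _ in range(N)] for _ in range(N)]
b = {k: th.zeros(h_size) for k in "ifou"}

i_j = th.sigmoid(W["i"] @ x_j + sum(U["i"][l] @ h_child[l] for l in range(N)) + b["i"])  # (1)
f_j = [th.sigmoid(W["f"] @ x_j + sum(U_f[k][l] @ h_child[l] for l in range(N)) + b["f"])
       for k in range(N)]                                                                # (2)
o_j = th.sigmoid(W["o"] @ x_j + sum(U["o"][l] @ h_child[l] for l in range(N)) + b["o"])  # (3)
u_j = th.tanh(W["u"] @ x_j + sum(U["u"][l] @ h_child[l] for l in range(N)) + b["u"])     # (4)
c_j = i_j * u_j + sum(f_j[l] * c_child[l] for l in range(N))                             # (5)
h_j = o_j * th.tanh(c_j)                                                                 # (6)

The TreeLSTMCell below computes the same quantities, but concatenates the children's hidden states so that U_iou and U_f each act through a single matrix multiplication, and it splits the work across the three message-passing phases.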

It can be decomposed into three phases: message_func, reduce_func, and apply_node_func.

Note

apply_node_func is a new node user-defined function that has not been introduced before. In apply_node_func, a user specifies what to do with node features, regardless of edge features and messages. In the Tree-LSTM case, apply_node_func is a must, since there exist (leaf) nodes with \(0\) incoming edges, which would not be updated via reduce_func.

import torch as th
import torch.nn as nn


class TreeLSTMCell(nn.Module):
    def __init__(self, x_size, h_size):
        super(TreeLSTMCell, self).__init__()
        self.W_iou = nn.Linear(x_size, 3 * h_size, bias=False)
        self.U_iou = nn.Linear(2 * h_size, 3 * h_size, bias=False)
        self.b_iou = nn.Parameter(th.zeros(1, 3 * h_size))
        self.U_f = nn.Linear(2 * h_size, 2 * h_size)

    def message_func(self, edges):
        return {"h": edges.src["h"], "c": edges.src["c"]}

    def reduce_func(self, nodes):
        # concatenate h_jl for equation (1), (2), (3), (4)
        h_cat = nodes.mailbox["h"].view(nodes.mailbox["h"].size(0), -1)
        # equation (2)
        f = th.sigmoid(self.U_f(h_cat)).view(*nodes.mailbox["h"].size())
        # second term of equation (5)
        c = th.sum(f * nodes.mailbox["c"], 1)
        return {"iou": self.U_iou(h_cat), "c": c}

    def apply_node_func(self, nodes):
        # equation (1), (3), (4)
        iou = nodes.data["iou"] + self.b_iou
        i, o, u = th.chunk(iou, 3, 1)
        i, o, u = th.sigmoid(i), th.sigmoid(o), th.tanh(u)
        # equation (5)
        c = i * u + nodes.data["c"]
        # equation (6)
        h = o * th.tanh(c)
        return {"h": h, "c": c}

Step 3: Define traversal

After defining the message-passing functions, induce the right order to trigger them. This is a significant departure from models such as GCN, where all nodes pull messages from their upstream neighbors simultaneously.

In the case of Tree-LSTM, messages start at the leaves of the tree and propagate upward, being processed along the way, until they reach the roots.

DGL defines a generator to perform the topological sort; each item is a tensor recording the nodes of one level, from the bottom level up to the roots. You can appreciate the degree of parallelism by inspecting the difference between the following:

# to heterogeneous graph
trv_a_tree = dgl.graph(a_tree.edges())
print("Traversing one tree:")
print(dgl.topological_nodes_generator(trv_a_tree))

# to heterogeneous graph
trv_graph = dgl.graph(graph.edges())
print("Traversing many trees at the same time:")
print(dgl.topological_nodes_generator(trv_graph))
Traversing one tree:
(tensor([ 2,  3,  6,  8, 13, 15, 17, 19, 22, 23, 25, 27, 28, 29, 30, 32, 34, 36,
        38, 40, 43, 46, 47, 49, 50, 52, 58, 59, 60, 62, 64, 65, 66, 68, 69, 70]), tensor([ 1, 21, 26, 45, 48, 57, 63, 67]), tensor([24, 44, 56, 61]), tensor([20, 42, 55]), tensor([18, 54]), tensor([16, 53]), tensor([14, 51]), tensor([12, 41]), tensor([11, 39]), tensor([10, 37]), tensor([35]), tensor([33]), tensor([31]), tensor([9]), tensor([7]), tensor([5]), tensor([4]), tensor([0]))
Traversing many trees at the same time:
(tensor([  2,   3,   6,   8,  13,  15,  17,  19,  22,  23,  25,  27,  28,  29,
         30,  32,  34,  36,  38,  40,  43,  46,  47,  49,  50,  52,  58,  59,
         60,  62,  64,  65,  66,  68,  69,  70,  74,  76,  78,  79,  82,  83,
         85,  88,  90,  92,  93,  95,  96, 100, 102, 103, 105, 109, 110, 112,
        113, 117, 118, 119, 121, 125, 127, 129, 130, 132, 133, 135, 138, 140,
        141, 142, 143, 150, 152, 153, 155, 158, 159, 161, 162, 164, 168, 170,
        171, 174, 175, 178, 179, 182, 184, 185, 187, 189, 190, 191, 192, 195,
        197, 198, 200, 202, 205, 208, 210, 212, 213, 214, 216, 218, 219, 220,
        223, 225, 227, 229, 230, 232, 235, 237, 240, 242, 244, 246, 248, 249,
        251, 253, 255, 256, 257, 259, 262, 263, 267, 269, 270, 271, 272]), tensor([  1,  21,  26,  45,  48,  57,  63,  67,  77,  81,  91,  94, 101, 108,
        111, 116, 128, 131, 139, 151, 157, 160, 169, 173, 177, 183, 188, 196,
        211, 217, 228, 247, 254, 261, 268]), tensor([ 24,  44,  56,  61,  75,  89,  99, 107, 115, 126, 137, 149, 156, 167,
        181, 186, 194, 209, 215, 226, 245, 252, 266]), tensor([ 20,  42,  55,  73,  87, 124, 136, 154, 180, 207, 224, 243, 250, 265]), tensor([ 18,  54,  86, 123, 134, 148, 176, 206, 222, 241, 264]), tensor([ 16,  53,  84, 122, 172, 204, 239, 260]), tensor([ 14,  51,  80, 120, 166, 203, 238, 258]), tensor([ 12,  41,  72, 114, 165, 201, 236]), tensor([ 11,  39, 106, 163, 199, 234]), tensor([ 10,  37, 104, 147, 193, 233]), tensor([ 35,  98, 146, 231]), tensor([ 33,  97, 145, 221]), tensor([ 31,  71, 144]), tensor([9]), tensor([7]), tensor([5]), tensor([4]), tensor([0]))
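
The batched traversal takes no more steps than the deepest tree, but each step covers many more nodes, which is where the parallelism comes from. A small sketch (added for illustration, reusing trv_a_tree and trv_graph from above) that prints only the number and sizes of the frontiers:

single = dgl.topological_nodes_generator(trv_a_tree)
batched = dgl.topological_nodes_generator(trv_graph)
print(len(single), [len(frontier) for frontier in single])     # one tree
print(len(batched), [len(frontier) for frontier in batched])   # all five trees together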

Call prop_nodes() to trigger the message passing:

import dgl.function as fn
import torch as th

trv_graph.ndata["a"] = th.ones(graph.num_nodes(), 1)
traversal_order = dgl.topological_nodes_generator(trv_graph)
trv_graph.prop_nodes(
    traversal_order,
    message_func=fn.copy_u("a", "a"),
    reduce_func=fn.sum("a", "a"),
)

# the following is a syntax sugar that does the same
# dgl.prop_nodes_topo(graph)

Note

Before calling prop_nodes(), specify a message_func and a reduce_func in advance. In the example, the built-in copy-from-source function serves as the message function and the built-in sum function as the reduce function, for demonstration.

Putting it all together

Here is the complete code that specifies the Tree-LSTM class.

class TreeLSTM(nn.Module):
    def __init__(
        self,
        num_vocabs,
        x_size,
        h_size,
        num_classes,
        dropout,
        pretrained_emb=None,
    ):
        super(TreeLSTM, self).__init__()
        self.x_size = x_size
        self.embedding = nn.Embedding(num_vocabs, x_size)
        if pretrained_emb is not None:
            print("Using glove")
            self.embedding.weight.data.copy_(pretrained_emb)
            self.embedding.weight.requires_grad = True
        self.dropout = nn.Dropout(dropout)
        self.linear = nn.Linear(h_size, num_classes)
        self.cell = TreeLSTMCell(x_size, h_size)

    def forward(self, batch, h, c):
        """Compute tree-lstm prediction given a batch.

        Parameters
        ----------
        batch : dgl.data.SSTBatch
            The data batch.
        h : Tensor
            Initial hidden state.
        c : Tensor
            Initial cell state.

        Returns
        -------
        logits : Tensor
            The prediction of each node.
        """
        g = batch.graph
        # to heterogeneous graph
        g = dgl.graph(g.edges())
        # feed embedding
        embeds = self.embedding(batch.wordid * batch.mask)
        g.ndata["iou"] = self.cell.W_iou(
            self.dropout(embeds)
        ) * batch.mask.float().unsqueeze(-1)
        g.ndata["h"] = h
        g.ndata["c"] = c
        # propagate
        dgl.prop_nodes_topo(
            g,
            message_func=self.cell.message_func,
            reduce_func=self.cell.reduce_func,
            apply_node_func=self.cell.apply_node_func,
        )
        # compute logits
        h = self.dropout(g.ndata.pop("h"))
        logits = self.linear(h)
        return logits


import torch.nn.functional as F

Main loop

Finally, you can write a training paradigm in PyTorch.

from torch.utils.data import DataLoader

device = th.device("cpu")
# hyper parameters
x_size = 256
h_size = 256
dropout = 0.5
lr = 0.05
weight_decay = 1e-4
epochs = 10

# create the model
model = TreeLSTM(
    trainset.vocab_size, x_size, h_size, trainset.num_classes, dropout
)
print(model)

# create the optimizer
optimizer = th.optim.Adagrad(
    model.parameters(), lr=lr, weight_decay=weight_decay
)


def batcher(dev):
    def batcher_dev(batch):
        batch_trees = dgl.batch(batch)
        return SSTBatch(
            graph=batch_trees,
            mask=batch_trees.ndata["mask"].to(device),
            wordid=batch_trees.ndata["x"].to(device),
            label=batch_trees.ndata["y"].to(device),
        )

    return batcher_dev


train_loader = DataLoader(
    dataset=tiny_sst,
    batch_size=5,
    collate_fn=batcher(device),
    shuffle=False,
    num_workers=0,
)

# training loop
for epoch in range(epochs):
    for step, batch in enumerate(train_loader):
        g = batch.graph
        n = g.num_nodes()
        h = th.zeros((n, h_size))
        c = th.zeros((n, h_size))
        logits = model(batch, h, c)
        logp = F.log_softmax(logits, 1)
        loss = F.nll_loss(logp, batch.label, reduction="sum")
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        pred = th.argmax(logits, 1)
        acc = float(th.sum(th.eq(batch.label, pred))) / len(batch.label)
        print(
            "Epoch {:05d} | Step {:05d} | Loss {:.4f} | Acc {:.4f} |".format(
                epoch, step, loss.item(), acc
            )
        )
TreeLSTM(
  (embedding): Embedding(19536, 256)
  (dropout): Dropout(p=0.5, inplace=False)
  (linear): Linear(in_features=256, out_features=5, bias=True)
  (cell): TreeLSTMCell(
    (W_iou): Linear(in_features=256, out_features=768, bias=False)
    (U_iou): Linear(in_features=512, out_features=768, bias=False)
    (U_f): Linear(in_features=512, out_features=512, bias=True)
  )
)
/home/ubuntu/prod-doc/readthedocs.org/user_builds/dgl/checkouts/latest/python/dgl/core.py:82: DGLWarning: The input graph for the user-defined edge function does not contain valid edges
  dgl_warning(
Epoch 00000 | Step 00000 | Loss 439.9816 | Acc 0.2015 |
Epoch 00001 | Step 00000 | Loss 226.3700 | Acc 0.7179 |
Epoch 00002 | Step 00000 | Loss 542.1960 | Acc 0.6190 |
Epoch 00003 | Step 00000 | Loss 474.2100 | Acc 0.7766 |
Epoch 00004 | Step 00000 | Loss 326.0594 | Acc 0.6264 |
Epoch 00005 | Step 00000 | Loss 176.7779 | Acc 0.8352 |
Epoch 00006 | Step 00000 | Loss 116.4029 | Acc 0.8571 |
Epoch 00007 | Step 00000 | Loss 114.9060 | Acc 0.8828 |
Epoch 00008 | Step 00000 | Loss 74.3830 | Acc 0.9267 |
Epoch 00009 | Step 00000 | Loss 65.8115 | Acc 0.9231 |
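
The accuracy above is computed over all nodes in the batch. SST results are often reported for the root nodes only, that is, the sentiment of the whole sentence. Here is a possible sketch for that (added for illustration; it assumes the roots are exactly the nodes with no outgoing edge, since edges point from child to parent), reusing g, pred, and batch from the last iteration of the loop:

# Sentence-level (root-only) accuracy on the tiny set, one root per tree.
root_ids = (g.out_degrees() == 0).nonzero(as_tuple=True)[0]
root_acc = float(th.sum(th.eq(batch.label[root_ids], pred[root_ids]))) / len(root_ids)
print("Root accuracy: {:.4f}".format(root_acc))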

To train the model on the full dataset with different settings (such as CPU or GPU), refer to the PyTorch example. There is also an implementation of the Child-Sum Tree-LSTM.

Total running time of the script: (0 minutes 2.759 seconds)
