胶囊网络

作者: Jinjing Zhou, Jake Zhao, Zheng Zhang, Jinyang Li

在本教程中,您将学习如何用图来描述一个更经典的模型。这种方法提供了一个不同的视角。本教程描述了如何为胶囊网络实现一个胶囊模型。

警告

本教程旨在通过代码作为解释手段,深入理解论文内容。因此,实现并未针对运行效率进行优化。如需推荐的实现,请参考官方示例

Capsule的关键概念

Capsule模型提供了两个关键思想:更丰富的表示和动态路由。

更丰富的表示 – 在经典的卷积网络中,一个标量值表示给定特征的激活。相比之下,胶囊输出一个向量。向量的长度表示特征存在的概率。向量的方向表示特征的各种属性(如姿态、变形、纹理等)。

image0

动态路由 – 胶囊的输出会根据其预测与上层父胶囊的预测一致性程度,发送到上层的某些父胶囊。这种基于一致性的动态路由泛化了最大池化的静态路由。

在训练过程中,路由是迭代完成的。每次迭代根据观察到的协议调整胶囊之间的路由权重。这种方式类似于k-means算法或竞争学习

在本教程中,您将看到胶囊的动态路由算法如何自然地表达为图算法。该实现改编自Cedric Chee,仅替换了路由层。此版本实现了相似的速度和准确性。

模型实现

步骤1:设置和图形初始化

两层胶囊之间的连接形成了一个有向二分图,如下图所示。

image1

每个节点 \(j\) 都与特征 \(v_j\) 相关联,表示其胶囊的输出。每条边都与特征 \(b_{ij}\)\(\hat{u}_{j|i}\) 相关联。\(b_{ij}\) 决定了路由权重,而 \(\hat{u}_{j|i}\) 表示胶囊 \(i\)\(j\) 的预测。

这是我们如何设置图并初始化节点和边特征的。

import os

os.environ["DGLBACKEND"] = "pytorch"
import dgl
import matplotlib.pyplot as plt
import numpy as np
import torch as th
import torch.nn as nn
import torch.nn.functional as F


def init_graph(in_nodes, out_nodes, f_size):
    u = np.repeat(np.arange(in_nodes), out_nodes)
    v = np.tile(np.arange(in_nodes, in_nodes + out_nodes), in_nodes)
    g = dgl.DGLGraph((u, v))
    # init states
    g.ndata["v"] = th.zeros(in_nodes + out_nodes, f_size)
    g.edata["b"] = th.zeros(in_nodes * out_nodes, 1)
    return g

步骤2:定义消息传递函数

这是Capsule路由算法的伪代码。

image2 在类 DGLRoutingLayer 中实现伪代码的第4-7行,步骤如下:

  1. 计算耦合系数。

    • 系数是输入胶囊所有出边的softmax。 \(\textbf{c}_{i,j} = \text{softmax}(\textbf{b}_{i,j})\).

  2. 计算所有胶囊内的加权和。

    • 胶囊的输出等于其内部胶囊的加权和 \(s_j=\sum_i c_{ij}\hat{u}_{j|i}\)

  3. 压缩输出。

    • 将Capsule的输出向量长度压缩到范围(0,1),以便它可以表示(某些特征存在的)概率。

    • \(v_j=\text{squash}(s_j)=\frac{||s_j||^2}{1+||s_j||^2}\frac{s_j}{||s_j||}\)

  4. 根据一致性的程度更新权重。

    • 标量积 \(\hat{u}_{j|i}\cdot v_j\) 可以被视为胶囊 \(i\)\(j\) 的匹配程度。它用于更新 \(b_{ij}=b_{ij}+\hat{u}_{j|i}\cdot v_j\)

import dgl.function as fn


class DGLRoutingLayer(nn.Module):
    def __init__(self, in_nodes, out_nodes, f_size):
        super(DGLRoutingLayer, self).__init__()
        self.g = init_graph(in_nodes, out_nodes, f_size)
        self.in_nodes = in_nodes
        self.out_nodes = out_nodes
        self.in_indx = list(range(in_nodes))
        self.out_indx = list(range(in_nodes, in_nodes + out_nodes))

    def forward(self, u_hat, routing_num=1):
        self.g.edata["u_hat"] = u_hat

        for r in range(routing_num):
            # step 1 (line 4): normalize over out edges
            edges_b = self.g.edata["b"].view(self.in_nodes, self.out_nodes)
            self.g.edata["c"] = F.softmax(edges_b, dim=1).view(-1, 1)
            self.g.edata["c u_hat"] = self.g.edata["c"] * self.g.edata["u_hat"]

            # Execute step 1 & 2
            self.g.update_all(fn.copy_e("c u_hat", "m"), fn.sum("m", "s"))

            # step 3 (line 6)
            self.g.nodes[self.out_indx].data["v"] = self.squash(
                self.g.nodes[self.out_indx].data["s"], dim=1
            )

            # step 4 (line 7)
            v = th.cat(
                [self.g.nodes[self.out_indx].data["v"]] * self.in_nodes, dim=0
            )
            self.g.edata["b"] = self.g.edata["b"] + (
                self.g.edata["u_hat"] * v
            ).sum(dim=1, keepdim=True)

    @staticmethod
    def squash(s, dim=1):
        sq = th.sum(s**2, dim=dim, keepdim=True)
        s_norm = th.sqrt(sq)
        s = (sq / (1.0 + sq)) * (s / s_norm)
        return s

步骤3:测试

创建一个简单的20x10胶囊层。

in_nodes = 20
out_nodes = 10
f_size = 4
u_hat = th.randn(in_nodes * out_nodes, f_size)
routing = DGLRoutingLayer(in_nodes, out_nodes, f_size)
/home/ubuntu/prod-doc/readthedocs.org/user_builds/dgl/checkouts/latest/python/dgl/heterograph.py:92: DGLWarning: Recommend creating graphs by `dgl.graph(data)` instead of `dgl.DGLGraph(data)`.
  dgl_warning(

您可以通过监控耦合系数的熵来可视化胶囊网络的行为。它们应该开始时较高,然后随着权重逐渐集中在较少的边缘上而下降。

entropy_list = []
dist_list = []

for i in range(10):
    routing(u_hat)
    dist_matrix = routing.g.edata["c"].view(in_nodes, out_nodes)
    entropy = (-dist_matrix * th.log(dist_matrix)).sum(dim=1)
    entropy_list.append(entropy.data.numpy())
    dist_list.append(dist_matrix.data.numpy())
stds = np.std(entropy_list, axis=1)
means = np.mean(entropy_list, axis=1)
plt.errorbar(np.arange(len(entropy_list)), means, stds, marker="o")
plt.ylabel("Entropy of Weight Distribution")
plt.xlabel("Number of Routing")
plt.xticks(np.arange(len(entropy_list)))
plt.close()

image3

或者,我们也可以观察直方图的演变。

import matplotlib.animation as animation
import seaborn as sns

fig = plt.figure(dpi=150)
fig.clf()
ax = fig.subplots()


def dist_animate(i):
    ax.cla()
    sns.distplot(dist_list[i].reshape(-1), kde=False, ax=ax)
    ax.set_xlabel("Weight Distribution Histogram")
    ax.set_title("Routing: %d" % (i))


ani = animation.FuncAnimation(
    fig, dist_animate, frames=len(entropy_list), interval=500
)
plt.close()

image4

您可以监控较低级别的Capsules如何逐渐附加到较高级别的Capsules之一。

import networkx as nx
from networkx.algorithms import bipartite

g = routing.g.to_networkx()
X, Y = bipartite.sets(g)
height_in = 10
height_out = height_in * 0.8
height_in_y = np.linspace(0, height_in, in_nodes)
height_out_y = np.linspace((height_in - height_out) / 2, height_out, out_nodes)
pos = dict()

fig2 = plt.figure(figsize=(8, 3), dpi=150)
fig2.clf()
ax = fig2.subplots()
pos.update(
    (n, (i, 1)) for i, n in zip(height_in_y, X)
)  # put nodes from X at x=1
pos.update(
    (n, (i, 2)) for i, n in zip(height_out_y, Y)
)  # put nodes from Y at x=2


def weight_animate(i):
    ax.cla()
    ax.axis("off")
    ax.set_title("Routing: %d  " % i)
    dm = dist_list[i]
    nx.draw_networkx_nodes(
        g, pos, nodelist=range(in_nodes), node_color="r", node_size=100, ax=ax
    )
    nx.draw_networkx_nodes(
        g,
        pos,
        nodelist=range(in_nodes, in_nodes + out_nodes),
        node_color="b",
        node_size=100,
        ax=ax,
    )
    for edge in g.edges():
        nx.draw_networkx_edges(
            g,
            pos,
            edgelist=[edge],
            width=dm[edge[0], edge[1] - in_nodes] * 1.5,
            ax=ax,
        )


ani2 = animation.FuncAnimation(
    fig2, weight_animate, frames=len(dist_list), interval=500
)
plt.close()

image5

此可视化的完整代码可在 GitHub上找到。完整的 在MNIST上训练的代码也可以在GitHub上找到。

脚本的总运行时间: (0 分钟 0.607 秒)

Gallery generated by Sphinx-Gallery