2. 快速入门¶
本页面将引导您完成使用Norse提高效率的初始步骤。 我们将介绍如何
处理神经元状态
使用Norse无需考虑时间因素
使用Norse处理循环任务
随时间使用Norse
如果您对脉冲神经网络(SNNs)完全不了解,我们建议您浏览我们介绍该主题的页面:脉冲信号入门。
2.1. 运行现成代码¶
如果您想快速入门,我们推荐使用Jupyter Notebook教程合集。这些教程可以直接在Google Colab上在线运行。
此外,我们提供了一系列任务,您可以在安装Norse后立即运行。 其中最常见的实验之一是MNIST分类任务。 Norse实现了与现代非脉冲网络相当的性能:
python -m norse.task.mnist
请参阅运行任务获取更多任务及如何运行它们的详细信息。
2.2. 构建带状态的神经网络¶
如果你想使用Norse构建自己的模型,需要了解神经元包含状态这一特性。 实际上这意味着Norse中的神经元会输出两个内容:脉冲张量和神经元状态。 Norse会在初始阶段为你初始化所有必要的状态,但你需要 持续传递这个状态。 如果不这样做,状态将始终为零,神经元永远不会发放脉冲,你的神经元将 永远处于死亡状态!
import torch
import norse
cell = norse.torch.LIFCell()
data = torch.ones(1)
spikes, state = cell(data)
下次调用该单元格时,需要传入该状态。否则将得到完全相同的输出
spikes, state = cell(data, state)
注意: 如果你在寻找灵感,这与PyTorch的RNN模块类似。
2.3. 使用带时间参数的Norse神经元¶
与PyTorch的Sequential类似, Norse的神经元模型可以在网络中串联。 遗憾的是,由于状态的原因,这与PyTorch自身的RNNs一样无法直接实现。 为此,Norse提供了SequentialState模块, 用于将有状态的模块连接起来:
import torch
import norse
model = SequentialState(
torch.nn.Linear(10, 5),
norse.torch.LIFCell(),
torch.nn.Linear(5, 1)
)
data = torch.ones(8, 10) # (batch, input)
out, state = model(data) # (8, 1) output shape
2.4. 在Norse中使用循环¶
所有神经元模块都包含Cell和RecurrentCell类。
上文中应用的Cell类仅作为神经元的前馈激活功能,
而RecurrentCell还包含线性和循环权重(我们同时对输入和循环脉冲进行加权)。
因此,我们需要告知模块它需要采用的形状,
因为我们必须初始化权重以匹配所需的输入/输出形状。
RecurrentCell 类可以直接使用,并能轻松集成到上述代码中。请注意,我们保持了相同的输入/输出形状,但这可以很容易地调整:
import torch
import norse
model = SequentialState(
torch.nn.Linear(10, 5),
norse.torch.LIFRecurrentCell(5, 5),
torch.nn.Linear(5, 1)
)
data = torch.ones(8, 10) # (batch, input)
out, state = model(data) # (8, 1) output shape
2.5. 时间序列中使用Norse¶
上述``XCell``遵循PyTorch的抽象概念,其中单元是"简单"的激活函数,仅应用一次。 然而,神经元存在于时间中,需要至少几个时间步长的输入才能发生有趣的事情(比如脉冲)。
上述网络(不含时间的那一个)在加入时间维度后依然能完美运行,您可以通过简单的for循环将其封装。不过,也可以选择让每个模块随时间独立运行。
在Norse中,我们通过从模型中移除Cell后缀来建模这个时间维度。因此,时序中的LIFCell将简称为LIF。同理,时序中的LIFRecurrentCell将简称为LIFRecurrent。
常规的Torch模块也需要在时间维度上运行。为此,我们添加了一个模块来提升PyTorch模块至时间域(即简单地在每个时间步运行一次)。
综上所述,我们得到以下内容:
import torch
import norse
model = SequentialState(
norse.Lift(torch.nn.Linear(10, 5)),
norse.LSNNRecurrent(5, 5),
norse.Lift(torch.nn.Linear(5, 1))
)
data = torch.ones(100, 8, 10) # (time, batch, input)
out, state = model(data)