PyTorch-Ignite PyTorch-Ignite

如何基于前向或后向传递创建自定义事件

本指南演示了如何创建依赖于计算的损失和反向传播的自定义事件

在这个例子中,我们将在MNIST数据集上使用ResNet18模型。基础代码与入门指南中使用的相同。

基本设置

import pandas as pd

import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from torchvision.models import resnet18
from torchvision.transforms import Compose, Normalize, ToTensor

from ignite.engine import Engine, EventEnum, Events, create_supervised_evaluator
from ignite.metrics import Accuracy, Loss
from ignite.handlers import Timer
from ignite.contrib.handlers import BasicTimeProfiler, HandlersTimeProfiler
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()

        self.model = resnet18(num_classes=10)
        self.model.conv1 = nn.Conv2d(1, 64, kernel_size=3, padding=1, bias=False)

    def forward(self, x):
        return self.model(x)


model = Net().to(device)

data_transform = Compose([ToTensor(), Normalize((0.1307,), (0.3081,))])

train_loader = DataLoader(
    MNIST(download=True, root=".", transform=data_transform, train=True),
    batch_size=128,
    shuffle=True,
)

optimizer = torch.optim.RMSprop(model.parameters(), lr=0.005)
criterion = nn.CrossEntropyLoss()

创建自定义事件

首先,让我们基于反向传播创建一些自定义事件。所有用户定义的自定义事件都应继承自基类 EventEnum

class BackpropEvents(EventEnum):
    BACKWARD_STARTED = 'backward_started'
    BACKWARD_COMPLETED = 'backward_completed'
    OPTIM_STEP_COMPLETED = 'optim_step_completed'

创建 trainer

然后我们定义train_step函数以应用于所有批次。在此过程中,我们使用 fire_event来执行与该特定事件相关的所有处理程序。

def train_step(engine, batch):
    model.train()
    optimizer.zero_grad()
    x, y = batch[0].to(device), batch[1].to(device)
    y_pred = model(x)
    loss = criterion(y_pred, y)
    
    engine.fire_event(BackpropEvents.BACKWARD_STARTED)
    loss.backward()
    engine.fire_event(BackpropEvents.BACKWARD_COMPLETED)

    optimizer.step()
    engine.fire_event(BackpropEvents.OPTIM_STEP_COMPLETED)

    return loss.item()


trainer = Engine(train_step)

trainer中注册自定义事件

最后,为了确保我们的事件可以被触发,我们在trainer中使用register_events来注册它们。

trainer.register_events(*BackpropEvents)

将处理程序附加到自定义事件

现在我们可以轻松地附加处理程序,以便在触发特定事件(如BACKWARD_COMPLETED)时执行。

@trainer.on(BackpropEvents.BACKWARD_COMPLETED)
def function_before_backprop(engine):
    print(f"Iter[{engine.state.iteration}] Function fired after backward pass")

最后,你可以运行trainer进行一些周期。

trainer.run(train_loader, max_epochs=3)

你也可以查看TBPTT Trainer的源代码以获取详细解释。