PyTorch-Ignite PyTorch-Ignite

如何在训练期间切换数据提供者

在这个例子中,我们将看到如何在训练过程中使用set_data()轻松切换数据提供者。

基本设置

必需的依赖项

!pip install pytorch-ignite

导入

from ignite.engine import Engine, Events

数据提供者

data1 = [1, 2, 3]
data2 = [11, 12, 13]

创建虚拟的 trainer

让我们创建一个虚拟的 train_step,它将打印当前的迭代和数据批次。

def train_step(engine, batch):
    print(f"Iter[{engine.state.iteration}] Current datapoint = ", batch)

trainer = Engine(train_step)

附加处理程序以切换数据

现在我们必须决定何时切换数据提供者。可以是在一个epoch、迭代或自定义事件之后。下面,我们将在特定的迭代之后切换数据。然后我们附加一个处理器到trainer,它将在switch_iteration之后执行一次,并使用set_data(),以便当:

  • 迭代 <= switch_iteration,批次来自 data1
  • 迭代 > switch_iteration, 批次来自 data2
switch_iteration = 5


@trainer.on(Events.ITERATION_COMPLETED(once=switch_iteration))
def switch_dataloader():
    print("<------- Switch Data ------->")
    trainer.set_data(data2)

最后我们运行trainer进行一些周期。

trainer.run(data1, max_epochs=5)
Iter[1] Current datapoint =  1
Iter[2] Current datapoint =  2
Iter[3] Current datapoint =  3
Iter[4] Current datapoint =  1
Iter[5] Current datapoint =  2
<------- Switch Data ------->
Iter[6] Current datapoint =  11
Iter[7] Current datapoint =  12
Iter[8] Current datapoint =  13
Iter[9] Current datapoint =  11
Iter[10] Current datapoint =  12
Iter[11] Current datapoint =  13
Iter[12] Current datapoint =  11
Iter[13] Current datapoint =  12
Iter[14] Current datapoint =  13
Iter[15] Current datapoint =  11





State:
	iteration: 15
	epoch: 5
	epoch_length: 3
	max_epochs: 5
	output: <class 'NoneType'>
	batch: 11
	metrics: <class 'dict'>
	dataloader: <class 'list'>
	seed: <class 'NoneType'>
	times: <class 'dict'>