如何在训练期间切换数据提供者
在这个例子中,我们将看到如何在训练过程中使用set_data()轻松切换数据提供者。
基本设置
必需的依赖项
!pip install pytorch-ignite
导入
from ignite.engine import Engine, Events
数据提供者
data1 = [1, 2, 3]
data2 = [11, 12, 13]
创建虚拟的 trainer
让我们创建一个虚拟的 train_step,它将打印当前的迭代和数据批次。
def train_step(engine, batch):
print(f"Iter[{engine.state.iteration}] Current datapoint = ", batch)
trainer = Engine(train_step)
附加处理程序以切换数据
现在我们必须决定何时切换数据提供者。可以是在一个epoch、迭代或自定义事件之后。下面,我们将在特定的迭代之后切换数据。然后我们附加一个处理器到trainer,它将在switch_iteration之后执行一次,并使用set_data(),以便当:
- 迭代 <=
switch_iteration,批次来自data1 - 迭代 >
switch_iteration, 批次来自data2
switch_iteration = 5
@trainer.on(Events.ITERATION_COMPLETED(once=switch_iteration))
def switch_dataloader():
print("<------- Switch Data ------->")
trainer.set_data(data2)
最后我们运行trainer进行一些周期。
trainer.run(data1, max_epochs=5)
Iter[1] Current datapoint = 1
Iter[2] Current datapoint = 2
Iter[3] Current datapoint = 3
Iter[4] Current datapoint = 1
Iter[5] Current datapoint = 2
<------- Switch Data ------->
Iter[6] Current datapoint = 11
Iter[7] Current datapoint = 12
Iter[8] Current datapoint = 13
Iter[9] Current datapoint = 11
Iter[10] Current datapoint = 12
Iter[11] Current datapoint = 13
Iter[12] Current datapoint = 11
Iter[13] Current datapoint = 12
Iter[14] Current datapoint = 13
Iter[15] Current datapoint = 11
State:
iteration: 15
epoch: 5
epoch_length: 3
max_epochs: 5
output: <class 'NoneType'>
batch: 11
metrics: <class 'dict'>
dataloader: <class 'list'>
seed: <class 'NoneType'>
times: <class 'dict'>