如何使用数据迭代器
当训练或验证的数据提供者是一个迭代器(无限或有限,已知或未知大小)时,以下是一些如何设置训练器或评估器的基本示例。
用于训练的无限迭代器
让我们使用一个无限数据迭代器作为训练数据流
import torch
from ignite.engine import Engine, Events
# Fix the RNG seed so the printed batch norms are reproducible across runs.
torch.manual_seed(12)

def infinite_iterator(batch_size):
    """Yield random image-like batches of shape (batch_size, 3, 32, 32) forever."""
    while True:
        yield torch.rand(batch_size, 3, 32, 32)
def train_step(trainer, batch):
    """Dummy training step: report progress and the batch norm (no model update)."""
    s = trainer.state
    # Exact reporting format used throughout this guide's transcripts.
    print(f"{s.epoch}/{s.max_epochs} : {s.iteration} - {batch.norm():.3f}")
# Build an engine that applies ``train_step`` to every incoming batch.
trainer = Engine(train_step)
# The data stream is infinite, so epoch_length is required to delimit an
# epoch: 5 iterations per epoch for 3 epochs => 15 iterations in total.
trainer.run(infinite_iterator(4), epoch_length=5, max_epochs=3)
1/3 : 1 - 63.862
1/3 : 2 - 64.042
1/3 : 3 - 63.936
1/3 : 4 - 64.141
1/3 : 5 - 64.767
2/3 : 6 - 63.791
2/3 : 7 - 64.565
2/3 : 8 - 63.602
2/3 : 9 - 63.995
2/3 : 10 - 63.943
3/3 : 11 - 63.831
3/3 : 12 - 64.276
3/3 : 13 - 64.148
3/3 : 14 - 63.920
3/3 : 15 - 64.226
State:
iteration: 15
epoch: 3
epoch_length: 5
max_epochs: 3
output: <class 'NoneType'>
batch: <class 'torch.Tensor'>
metrics: <class 'dict'>
dataloader: <class 'generator'>
seed: <class 'NoneType'>
times: <class 'dict'>
如果我们不指定epoch_length,我们可以通过调用terminate()来显式停止训练。在这种情况下,只会定义一个单一的epoch。
import torch
from ignite.engine import Engine, Events
# Seed the RNG so this run prints the same norms as the previous example.
torch.manual_seed(12)

def infinite_iterator(batch_size):
    """Endless stream of random (batch_size, 3, 32, 32) tensors."""
    while True:
        yield torch.rand(batch_size, 3, 32, 32)
def train_step(trainer, batch):
    """Dummy training step: print the engine's progress and the batch norm."""
    s = trainer.state
    print(f"{s.epoch}/{s.max_epochs} : {s.iteration} - {batch.norm():.3f}")
trainer = Engine(train_step)

@trainer.on(Events.ITERATION_COMPLETED(once=15))
def stop_training():
    # Without ``epoch_length`` the run over an infinite iterator would never
    # end on its own: stop it explicitly after the 15th iteration. Only a
    # single epoch is defined in this case.
    trainer.terminate()

trainer.run(infinite_iterator(4))
1/1 : 1 - 63.862
1/1 : 2 - 64.042
1/1 : 3 - 63.936
1/1 : 4 - 64.141
1/1 : 5 - 64.767
1/1 : 6 - 63.791
1/1 : 7 - 64.565
1/1 : 8 - 63.602
1/1 : 9 - 63.995
1/1 : 10 - 63.943
1/1 : 11 - 63.831
1/1 : 12 - 64.276
1/1 : 13 - 64.148
1/1 : 14 - 63.920
1/1 : 15 - 64.226
State:
iteration: 15
epoch: 1
epoch_length: <class 'NoneType'>
max_epochs: 1
output: <class 'NoneType'>
batch: <class 'torch.Tensor'>
metrics: <class 'dict'>
dataloader: <class 'generator'>
seed: <class 'NoneType'>
times: <class 'dict'>
相同的代码可以用于验证模型。
长度未知的有限迭代器
让我们使用一个有限的数据迭代器,但其长度对用户来说是未知的。在训练的情况下,我们希望多次遍历数据流,因此当数据迭代器耗尽时需要重新启动它。在代码中,我们没有指定epoch_length,它将自动确定。
import torch
from ignite.engine import Engine, Events
# Seeding is kept for parity with the other examples (no randomness below).
torch.manual_seed(12)

def finite_unk_size_data_iter():
    """A finite stream (11 integers) whose length is unknown to the consumer."""
    yield from range(11)
def train_step(trainer, batch):
    """Dummy training step: print progress and the scalar batch value."""
    s = trainer.state
    print(f"{s.epoch}/{s.max_epochs} : {s.iteration} - {batch:.3f}")
trainer = Engine(train_step)

def restart_iter():
    # The generator is exhausted after 11 items; swap in a fresh one so the
    # engine can keep iterating on the following epochs. epoch_length is not
    # given: it is determined automatically on the first exhaustion.
    trainer.state.dataloader = finite_unk_size_data_iter()

# Re-arm the dataloader every time the engine drains it.
trainer.add_event_handler(Events.DATALOADER_STOP_ITERATION, restart_iter)

data_iter = finite_unk_size_data_iter()
trainer.run(data_iter, max_epochs=5)
1/5 : 1 - 0.000
1/5 : 2 - 1.000
1/5 : 3 - 2.000
1/5 : 4 - 3.000
1/5 : 5 - 4.000
1/5 : 6 - 5.000
1/5 : 7 - 6.000
1/5 : 8 - 7.000
1/5 : 9 - 8.000
1/5 : 10 - 9.000
1/5 : 11 - 10.000
2/5 : 12 - 0.000
2/5 : 13 - 1.000
2/5 : 14 - 2.000
2/5 : 15 - 3.000
2/5 : 16 - 4.000
2/5 : 17 - 5.000
2/5 : 18 - 6.000
2/5 : 19 - 7.000
2/5 : 20 - 8.000
2/5 : 21 - 9.000
2/5 : 22 - 10.000
3/5 : 23 - 0.000
3/5 : 24 - 1.000
3/5 : 25 - 2.000
3/5 : 26 - 3.000
3/5 : 27 - 4.000
3/5 : 28 - 5.000
3/5 : 29 - 6.000
3/5 : 30 - 7.000
3/5 : 31 - 8.000
3/5 : 32 - 9.000
3/5 : 33 - 10.000
4/5 : 34 - 0.000
4/5 : 35 - 1.000
4/5 : 36 - 2.000
4/5 : 37 - 3.000
4/5 : 38 - 4.000
4/5 : 39 - 5.000
4/5 : 40 - 6.000
4/5 : 41 - 7.000
4/5 : 42 - 8.000
4/5 : 43 - 9.000
4/5 : 44 - 10.000
5/5 : 45 - 0.000
5/5 : 46 - 1.000
5/5 : 47 - 2.000
5/5 : 48 - 3.000
5/5 : 49 - 4.000
5/5 : 50 - 5.000
5/5 : 51 - 6.000
5/5 : 52 - 7.000
5/5 : 53 - 8.000
5/5 : 54 - 9.000
5/5 : 55 - 10.000
State:
iteration: 55
epoch: 5
epoch_length: 11
max_epochs: 5
output: <class 'NoneType'>
batch: 10
metrics: <class 'dict'>
dataloader: <class 'generator'>
seed: <class 'NoneType'>
times: <class 'dict'>
在验证的情况下,代码很简单
import torch
from ignite.engine import Engine, Events
# Seeding kept for parity with the training examples (no randomness below).
torch.manual_seed(12)

def finite_unk_size_data_iter():
    """Finite stream of 11 integers; its length is unknown to the consumer."""
    yield from range(11)
def val_step(evaluator, batch):
    """Dummy validation step: print the evaluator's progress and the batch."""
    s = evaluator.state
    print(f"{s.epoch}/{s.max_epochs} : {s.iteration} - {batch:.3f}")
# Build the evaluator from the validation step.
evaluator = Engine(val_step)
# Validation needs a single pass only, so no restart handler is required:
# the run ends when the iterator is exhausted and epoch_length is recorded
# automatically (11 here).
data_iter = finite_unk_size_data_iter()
evaluator.run(data_iter)
1/1 : 1 - 0.000
1/1 : 2 - 1.000
1/1 : 3 - 2.000
1/1 : 4 - 3.000
1/1 : 5 - 4.000
1/1 : 6 - 5.000
1/1 : 7 - 6.000
1/1 : 8 - 7.000
1/1 : 9 - 8.000
1/1 : 10 - 9.000
1/1 : 11 - 10.000
State:
iteration: 11
epoch: 1
epoch_length: 11
max_epochs: 1
output: <class 'NoneType'>
batch: <class 'NoneType'>
metrics: <class 'dict'>
dataloader: <class 'generator'>
seed: <class 'NoneType'>
times: <class 'dict'>
已知长度的有限迭代器
让我们使用一个已知大小的有限数据迭代器进行训练或验证。如果我们需要重新启动数据迭代器,我们可以像在未知大小的情况下一样,通过在@trainer.on(Events.DATALOADER_STOP_ITERATION)上附加重启处理程序来实现,但在这里我们将在迭代时显式地执行此操作:
import torch
from ignite.engine import Engine, Events
# Seeding kept for parity with the other examples (no randomness below).
torch.manual_seed(12)

# The stream length is known up front in this example.
size = 11

def finite_size_data_iter(size):
    """Yield the integers 0 .. size-1 as a finite data stream."""
    yield from range(size)
def train_step(trainer, batch):
    """Dummy training step: print progress and the scalar batch value."""
    s = trainer.state
    print(f"{s.epoch}/{s.max_epochs} : {s.iteration} - {batch:.3f}")
trainer = Engine(train_step)

@trainer.on(Events.ITERATION_COMPLETED(every=size))
def restart_iter():
    # After every ``size`` iterations the generator is exhausted; install a
    # fresh one explicitly so the next epoch has data to consume (instead of
    # reacting to Events.DATALOADER_STOP_ITERATION as in the previous example).
    trainer.state.dataloader = finite_size_data_iter(size)

trainer.run(finite_size_data_iter(size), max_epochs=5)
1/5 : 1 - 0.000
1/5 : 2 - 1.000
1/5 : 3 - 2.000
1/5 : 4 - 3.000
1/5 : 5 - 4.000
1/5 : 6 - 5.000
1/5 : 7 - 6.000
1/5 : 8 - 7.000
1/5 : 9 - 8.000
1/5 : 10 - 9.000
1/5 : 11 - 10.000
2/5 : 12 - 0.000
2/5 : 13 - 1.000
2/5 : 14 - 2.000
2/5 : 15 - 3.000
2/5 : 16 - 4.000
2/5 : 17 - 5.000
2/5 : 18 - 6.000
2/5 : 19 - 7.000
2/5 : 20 - 8.000
2/5 : 21 - 9.000
2/5 : 22 - 10.000
3/5 : 23 - 0.000
3/5 : 24 - 1.000
3/5 : 25 - 2.000
3/5 : 26 - 3.000
3/5 : 27 - 4.000
3/5 : 28 - 5.000
3/5 : 29 - 6.000
3/5 : 30 - 7.000
3/5 : 31 - 8.000
3/5 : 32 - 9.000
3/5 : 33 - 10.000
4/5 : 34 - 0.000
4/5 : 35 - 1.000
4/5 : 36 - 2.000
4/5 : 37 - 3.000
4/5 : 38 - 4.000
4/5 : 39 - 5.000
4/5 : 40 - 6.000
4/5 : 41 - 7.000
4/5 : 42 - 8.000
4/5 : 43 - 9.000
4/5 : 44 - 10.000
5/5 : 45 - 0.000
5/5 : 46 - 1.000
5/5 : 47 - 2.000
5/5 : 48 - 3.000
5/5 : 49 - 4.000
5/5 : 50 - 5.000
5/5 : 51 - 6.000
5/5 : 52 - 7.000
5/5 : 53 - 8.000
5/5 : 54 - 9.000
5/5 : 55 - 10.000
State:
iteration: 55
epoch: 5
epoch_length: 11
max_epochs: 5
output: <class 'NoneType'>
batch: 10
metrics: <class 'dict'>
dataloader: <class 'generator'>
seed: <class 'NoneType'>
times: <class 'dict'>
在验证的情况下,代码很简单
import torch
from ignite.engine import Engine, Events
# Seeding kept for parity with the training examples (no randomness below).
torch.manual_seed(12)

# Known length of the validation stream.
size = 11

def finite_size_data_iter(size):
    """Finite stream yielding the integers 0 .. size-1."""
    yield from range(size)
def val_step(evaluator, batch):
    """Dummy validation step: print the evaluator's progress and the batch."""
    s = evaluator.state
    print(f"{s.epoch}/{s.max_epochs} : {s.iteration} - {batch:.3f}")
# Build the evaluator from the validation step.
evaluator = Engine(val_step)
# A single pass over the finite iterator is enough for validation; no
# restart handler is needed since the iterator is consumed exactly once.
data_iter = finite_size_data_iter(size)
evaluator.run(data_iter)
1/1 : 1 - 0.000
1/1 : 2 - 1.000
1/1 : 3 - 2.000
1/1 : 4 - 3.000
1/1 : 5 - 4.000
1/1 : 6 - 5.000
1/1 : 7 - 6.000
1/1 : 8 - 7.000
1/1 : 9 - 8.000
1/1 : 10 - 9.000
1/1 : 11 - 10.000
State:
iteration: 11
epoch: 1
epoch_length: 11
max_epochs: 1
output: <class 'NoneType'>
batch: <class 'NoneType'>
metrics: <class 'dict'>
dataloader: <class 'generator'>
seed: <class 'NoneType'>
times: <class 'dict'>