How to switch data provider during training
In this example, we will see how to switch the data provider during training using set_data().
Basic Setup
Required Dependencies
!pip install pytorch-ignite
Import
from ignite.engine import Engine, Events
Data Providers
data1 = [1, 2, 3]
data2 = [11, 12, 13]
Create dummy trainer
Let’s create a dummy train_step which will print the current iteration and batch of data.
def train_step(engine, batch):
    print(f"Iter[{engine.state.iteration}] Current datapoint = ", batch)

trainer = Engine(train_step)
Attach handler to switch data
Now we have to decide when to switch the data provider. It can be after an epoch, after an iteration, or on some custom condition. Below, we switch the data after a specific iteration. We attach a handler to trainer that is executed once, when iteration switch_iteration completes, and calls set_data() so that when:
- iteration <= switch_iteration, batch is from data1
- iteration > switch_iteration, batch is from data2
switch_iteration = 5
@trainer.on(Events.ITERATION_COMPLETED(once=switch_iteration))
def switch_dataloader():
    print("<------- Switch Data ------->")
    trainer.set_data(data2)
And finally we run the trainer for some epochs.
trainer.run(data1, max_epochs=5)
Iter[1] Current datapoint = 1
Iter[2] Current datapoint = 2
Iter[3] Current datapoint = 3
Iter[4] Current datapoint = 1
Iter[5] Current datapoint = 2
<------- Switch Data ------->
Iter[6] Current datapoint = 11
Iter[7] Current datapoint = 12
Iter[8] Current datapoint = 13
Iter[9] Current datapoint = 11
Iter[10] Current datapoint = 12
Iter[11] Current datapoint = 13
Iter[12] Current datapoint = 11
Iter[13] Current datapoint = 12
Iter[14] Current datapoint = 13
Iter[15] Current datapoint = 11
State:
iteration: 15
epoch: 5
epoch_length: 3
max_epochs: 5
output: <class 'NoneType'>
batch: 11
metrics: <class 'dict'>
dataloader: <class 'list'>
seed: <class 'NoneType'>
times: <class 'dict'>