How to load a checkpoint and resume training
In this example, we will use a ResNet18 model on the MNIST dataset. The base code is the same as in the Getting Started Guide.
Required Dependencies
!pip install pytorch-ignite -q
Basic Setup
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from torchvision.models import resnet18
from torchvision.transforms import Compose, Normalize, ToTensor
from ignite.engine import Events, create_supervised_trainer, create_supervised_evaluator
from ignite.metrics import Accuracy, Loss
from ignite.handlers import Checkpoint, global_step_from_engine
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.model = resnet18(num_classes=10)
        # MNIST images are single-channel, so replace the default 3-channel first conv layer
        self.model.conv1 = nn.Conv2d(1, 64, kernel_size=3, padding=1, bias=False)

    def forward(self, x):
        return self.model(x)
model = Net().to(device)
data_transform = Compose([ToTensor(), Normalize((0.1307,), (0.3081,))])
train_loader = DataLoader(
    MNIST(download=True, root=".", transform=data_transform, train=True),
    batch_size=128,
    shuffle=True,
)

val_loader = DataLoader(
    MNIST(download=True, root=".", transform=data_transform, train=False),
    batch_size=256,
    shuffle=False,
)
optimizer = torch.optim.RMSprop(model.parameters(), lr=0.005)
criterion = nn.CrossEntropyLoss()
trainer = create_supervised_trainer(model, optimizer, criterion, device=device)
evaluator = create_supervised_evaluator(
    model, metrics={"accuracy": Accuracy(), "loss": Loss(criterion)}, device=device
)
@trainer.on(Events.EPOCH_COMPLETED)
def log_validation_results(trainer):
    evaluator.run(val_loader)
    metrics = evaluator.state.metrics
    print(
        f"Validation Results - Epoch[{trainer.state.epoch}] Avg accuracy: {metrics['accuracy']:.2f} Avg loss: {metrics['loss']:.2f}"
    )
Checkpoint
We can use Checkpoint() as shown below to save the latest model after each epoch completes. The to_save dictionary also includes the state of the optimizer and trainer, so that we can load this checkpoint later and resume training.
to_save = {'model': model, 'optimizer': optimizer, 'trainer': trainer}
checkpoint_dir = "checkpoints/"
checkpoint = Checkpoint(
    to_save,
    checkpoint_dir,
    n_saved=1,
    global_step_transform=global_step_from_engine(trainer),
)
evaluator.add_event_handler(Events.COMPLETED, checkpoint)
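This configuration keeps only the latest checkpoint. If we would rather keep the model that scores best on validation, Checkpoint also accepts a score_function and score_name. A minimal sketch (best_checkpoint and the "best" prefix are illustrative names; it assumes the evaluator computes the "accuracy" metric as configured above):

best_checkpoint = Checkpoint(
    to_save,
    checkpoint_dir,
    n_saved=1,
    filename_prefix="best",
    # Score each checkpoint by validation accuracy from the evaluator's metrics
    score_function=lambda engine: engine.state.metrics["accuracy"],
    score_name="accuracy",
    global_step_transform=global_step_from_engine(trainer),
)
evaluator.add_event_handler(Events.COMPLETED, best_checkpoint)

With this variant, a new file is written only when the validation accuracy improves on the best score kept so far, and the score is embedded in the filename.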
Start Training
Finally, we start the engine on the training dataset and run it for 2 epochs:
trainer.run(train_loader, max_epochs=2)
Validation Results - Epoch[1] Avg accuracy: 0.96 Avg loss: 0.16
Validation Results - Epoch[2] Avg accuracy: 0.98 Avg loss: 0.07
State:
iteration: 938
epoch: 2
epoch_length: 469
max_epochs: 2
output: 0.026344267651438713
batch: <class 'list'>
metrics: <class 'dict'>
dataloader: <class 'torch.utils.data.dataloader.DataLoader'>
seed: <class 'NoneType'>
times: <class 'dict'>
Load Checkpoint
Now let's assume we have reset our model, optimizer, and trainer, for example because training was interrupted and we are starting from a fresh process. After instantiating these objects again, we need to resume training from the checkpoint that we saved.
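In a fresh process, the reset amounts to repeating the setup from above. A minimal sketch, reusing the Net class, criterion, and device already defined:

model = Net().to(device)
optimizer = torch.optim.RMSprop(model.parameters(), lr=0.005)
trainer = create_supervised_trainer(model, optimizer, criterion, device=device)

# to_save must be rebuilt so it references the freshly created objects
to_save = {'model': model, 'optimizer': optimizer, 'trainer': trainer}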
!ls checkpoints
checkpoint_2.pt
We can use load_objects() to apply the state stored in the checkpoint to the objects in to_save.
checkpoint_fp = checkpoint_dir + "checkpoint_2.pt"
checkpoint = torch.load(checkpoint_fp, map_location=device)
Checkpoint.load_objects(to_load=to_save, checkpoint=checkpoint)
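As a quick sanity check (not part of the original recipe), we can inspect the restored trainer state before resuming; given checkpoint_2.pt above, it should reflect the end of epoch 2:

# Expected for the run above: epoch 2, iteration 938
print(trainer.state.epoch, trainer.state.iteration)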
Resume Training
Since the restored trainer state ends at epoch 2, calling run with max_epochs=4 trains for two more epochs:
trainer.run(train_loader, max_epochs=4)
Validation Results - Epoch[3] Avg accuracy: 0.99 Avg loss: 0.04
Validation Results - Epoch[4] Avg accuracy: 0.98 Avg loss: 0.06
State:
iteration: 1876
epoch: 4
epoch_length: 469
max_epochs: 4
output: 0.0412273071706295
batch: <class 'list'>
metrics: <class 'dict'>
dataloader: <class 'torch.utils.data.dataloader.DataLoader'>
seed: <class 'NoneType'>
times: <class 'dict'>