ModelCheckpoint#

class ignite.handlers.checkpoint.ModelCheckpoint(dirname, filename_prefix, score_function=None, score_name=None, n_saved=1, atomic=True, require_empty=True, create_dir=True, global_step_transform=None, include_self=False, **kwargs)[source]#

The ModelCheckpoint handler can be used to periodically save objects to disk only. To store checkpoints in another storage type, please consider Checkpoint.

This handler expects two arguments:

  • an Engine object

  • a dict mapping names (str) to objects that should be saved to disk.

See Examples for further details.

Warning

The behaviour of this class changed in v0.3.0.

The internal counter that was used to indicate the number of save actions has been removed. Its value, step_number, used to appear in the filename, e.g. {filename_prefix}_{name}_{step_number}.pt. step_number is now replaced by the current engine's epoch if score_function is specified, and by the current iteration otherwise.

A single .pt file is created instead of multiple files.

Parameters
  • dirname (str) – Directory path where objects will be saved.

  • filename_prefix (str) – Prefix for the file names to which objects will be saved. See Notes of Checkpoint for more details.

  • score_function (Optional[Callable]) – if not None, it should be a function taking a single argument, an Engine object, and returning a score (float). Objects with the highest scores will be retained (see the score-based sketch after the examples below).

  • score_name (Optional[str]) – if score_function is not None, its value can be stored in the filename using score_name. See Notes for more details.

  • n_saved (Optional[int]) – Number of objects that should be kept on disk. Older files will be removed. If set to None, all objects are kept.

  • atomic (bool) – If True, objects are first serialized to a temporary file and then moved to the final destination, so files are guaranteed not to be damaged (for example, if an exception occurs during saving).

  • require_empty (bool) – If True, will raise exception if there are any files starting with filename_prefix in the directory dirname.

  • create_dir (bool) – If True, will create directory dirname if it does not exist.

  • global_step_transform (Optional[Callable]) – global step transform function to output a desired global step. The function takes (engine, event_name) as input and should return an integer. Default is None: the global step is taken from the attached engine. If provided, the function's output is used as global_step. To take the global step from another engine, please use global_step_from_engine() (see the sketch after the examples below).

  • include_self (bool) – Whether to include the state_dict of this object in the checkpoint. If True, then there must not be another object in to_save with key checkpointer.

  • kwargs (Any) – Accepted keyword arguments for torch.save or xm.save in DiskSaver.

Changed in version 0.4.2: Accept kwargs for torch.save or xm.save

Examples

>>> import os
>>> from ignite.engine import Engine, Events
>>> from ignite.handlers import ModelCheckpoint
>>> from torch import nn
>>> trainer = Engine(lambda engine, batch: None)
>>> handler = ModelCheckpoint('/tmp/models', 'myprefix', n_saved=2, create_dir=True)
>>> model = nn.Linear(3, 3)
>>> trainer.add_event_handler(Events.EPOCH_COMPLETED(every=2), handler, {'mymodel': model})
>>> trainer.run([0, 1, 2, 3, 4], max_epochs=6)
>>> os.listdir('/tmp/models')
['myprefix_mymodel_20.pt', 'myprefix_mymodel_30.pt']
>>> handler.last_checkpoint
'/tmp/models/myprefix_mymodel_30.pt'
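
The example above saves purely on a schedule. To retain the best models according to a validation metric instead, pass score_function (and optionally score_name). A minimal sketch, assuming an evaluator engine that populates engine.state.metrics["accuracy"]; the engines, metric name, and paths here are illustrative:

from ignite.engine import Engine, Events
from ignite.handlers import ModelCheckpoint, global_step_from_engine
from torch import nn

trainer = Engine(lambda engine, batch: None)
# Assumed to populate state.metrics["accuracy"] during validation runs.
evaluator = Engine(lambda engine, batch: None)
model = nn.Linear(3, 3)

handler = ModelCheckpoint(
    '/tmp/best_models', 'myprefix', n_saved=2, create_dir=True,
    # Keep the two checkpoints with the highest validation accuracy.
    score_function=lambda engine: engine.state.metrics['accuracy'],
    score_name='val_acc',
    # Name files after the trainer's epoch rather than the evaluator's
    # own counter, e.g. myprefix_mymodel_3_val_acc=0.7500.pt.
    global_step_transform=global_step_from_engine(trainer),
)
evaluator.add_event_handler(Events.COMPLETED, handler, {'mymodel': model})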

Methods

get_default_score_fn

Helper method to get default score function based on the metric name.

load_objects

Helper method to apply load_state_dict on the objects from to_load using states from checkpoint.

load_state_dict

Method replaces the internal state of the class with the provided state dict data.

reset

Method to reset saved checkpoint names.

setup_filename_pattern

Helper method to get the default filename pattern for a checkpoint.

state_dict

Method returns a state dict with the saved items: a list of (priority, filename) pairs.

class Item(priority, filename)#

Create new instance of Item(priority, filename)

Parameters
  • priority (int) –

  • filename (str) –

count(value, /)#

Return number of occurrences of value.

property filename#

Alias for field number 1

index(value, start=0, stop=9223372036854775807, /)#

Return first index of value.

Raises ValueError if the value is not present.

property priority#

Alias for field number 0

static get_default_score_fn(metric_name, score_sign=1.0)#

Helper method to get default score function based on the metric name.

Parameters
  • metric_name (str) – metric name to get the value from engine.state.metrics. The engine is the one to which the Checkpoint handler is added.

  • score_sign (float) – sign of the score: 1.0 or -1.0. For error-like metrics, where smaller is better, a negative score sign should be used (objects with larger scores are retained). Default, 1.0.

Return type

Callable

Examples:

from ignite.handlers import Checkpoint

best_acc_score = Checkpoint.get_default_score_fn("accuracy")

best_model_handler = Checkpoint(
    to_save, save_handler, score_name="val_accuracy", score_function=best_acc_score
)
evaluator.add_event_handler(Events.COMPLETED, best_model_handler)

Usage with error-like metric:

from ignite.handlers import Checkpoint

neg_loss_score = Checkpoint.get_default_score_fn("loss", -1.0)

best_model_handler = Checkpoint(
    to_save, save_handler, score_name="val_neg_loss", score_function=neg_loss_score
)
evaluator.add_event_handler(Events.COMPLETED, best_model_handler)

New in version 0.4.3.

static load_objects(to_load, checkpoint, **kwargs)#

Helper method to apply load_state_dict on the objects from to_load using states from checkpoint.

Examples:

import torch
from ignite.engine import Engine, Events
from ignite.handlers import ModelCheckpoint, Checkpoint

trainer = Engine(lambda engine, batch: None)
handler = ModelCheckpoint('/tmp/models', 'myprefix', n_saved=None, create_dir=True)
model = torch.nn.Linear(3, 3)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
to_save = {"weights": model, "optimizer": optimizer}
trainer.add_event_handler(Events.EPOCH_COMPLETED(every=2), handler, to_save)
trainer.run(torch.randn(10, 1), 5)

# Restore the saved objects from the checkpoint written at iteration 40.
to_load = to_save
checkpoint_fp = "/tmp/models/myprefix_checkpoint_40.pt"
checkpoint = torch.load(checkpoint_fp)
Checkpoint.load_objects(to_load=to_load, checkpoint=checkpoint)

Note

If to_load contains objects of type torch DistributedDataParallel or DataParallel, the method load_state_dict will be applied to their internal wrapped model (obj.module).

Parameters
  • to_load (Mapping) – a dictionary with objects, e.g. {"model": model, "optimizer": optimizer, ...}

  • checkpoint (Mapping) – a dictionary with state_dicts to load, e.g. {"model": model_state_dict, "optimizer": opt_state_dict}. If to_load contains a single key, then checkpoint can directly contain the corresponding state_dict.

  • kwargs (Any) – Keyword arguments accepted by nn.Module.load_state_dict(). Passing strict=False enables loading only part of a pretrained model (useful, for example, in transfer learning; see the sketch below).

Return type

None
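
For instance, kwargs can forward strict=False to nn.Module.load_state_dict() so that only the overlapping subset of weights is loaded. A minimal sketch; the checkpoint path and the "weights" key are illustrative, matching the example above:

import torch
from ignite.handlers import Checkpoint

model = torch.nn.Linear(3, 3)
to_load = {"weights": model}

checkpoint = torch.load("/tmp/models/myprefix_checkpoint_40.pt")

# strict=False is forwarded to nn.Module.load_state_dict, so missing or
# unexpected keys are tolerated -- useful when fine-tuning a model whose
# head differs from the pretrained one.
Checkpoint.load_objects(to_load=to_load, checkpoint=checkpoint, strict=False)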

load_state_dict(state_dict)#

Method replaces the internal state of the class with the provided state dict data.

Parameters

state_dict (Mapping) – a dict with a "saved" key whose value is a list of (priority, filename) pairs.

Return type

None

reset()#

Method to reset saved checkpoint names.

Use this method if the engine will independently run multiple times:

from ignite.handlers import Checkpoint

trainer = ...
checkpointer = Checkpoint(...)

trainer.add_event_handler(Events.COMPLETED, checkpointer)
trainer.add_event_handler(Events.STARTED, checkpointer.reset)

# fold 0
trainer.run(data0, max_epochs=max_epochs)
print("Last checkpoint:", checkpointer.last_checkpoint)

# fold 1
trainer.run(data1, max_epochs=max_epochs)
print("Last checkpoint:", checkpointer.last_checkpoint)

New in version 0.4.3.

Return type

None

static setup_filename_pattern(with_prefix=True, with_score=True, with_score_name=True, with_global_step=True)#

Helper method to get the default filename pattern for a checkpoint.

Parameters
  • with_prefix (bool) – If True, the filename_prefix is added to the filename pattern: {filename_prefix}_{name}.... Default, True.

  • with_score (bool) – If True, score is added to the filename pattern: ..._{score}.{ext}. Default, True. At least one of with_score and with_global_step should be True.

  • with_score_name (bool) – If True, score_name is added to the filename pattern: ..._{score_name}={score}.{ext}. If activated, argument with_score should also be True, otherwise an error is raised. Default, True.

  • with_global_step (bool) – If True, {global_step} is added to the filename pattern: ...{name}_{global_step}.... At least one of with_score and with_global_step should be True.

Return type

str

Example

from ignite.handlers import Checkpoint

filename_pattern = Checkpoint.setup_filename_pattern()

print(filename_pattern)
> "{filename_prefix}_{name}_{global_step}_{score_name}={score}.{ext}"

New in version 0.4.3.

state_dict()#

Method returns a state dict with the saved items: a list of (priority, filename) pairs. Can be used to save the internal state of the class.

Return type

OrderedDict[str, List[Tuple[int, str]]]
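
Together with load_state_dict(), this allows the handler's bookkeeping to survive a restart, so the n_saved rotation continues where it left off. A minimal sketch; note that passing include_self=True to the constructor stores the same state automatically under the key "checkpointer":

from ignite.handlers import ModelCheckpoint

handler = ModelCheckpoint('/tmp/models', 'myprefix', n_saved=2, create_dir=True)
# ... attach to an engine and run; some checkpoints get saved ...

saved = handler.state_dict()  # OrderedDict with a "saved" list of (priority, filename) pairs

# Later, e.g. in a fresh process; require_empty=False because the directory
# already contains checkpoints from the previous run.
new_handler = ModelCheckpoint('/tmp/models', 'myprefix', n_saved=2,
                              create_dir=True, require_empty=False)
new_handler.load_state_dict(saved)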