ModelCheckpoint#

class ignite.handlers.checkpoint.ModelCheckpoint(dirname, filename_prefix, score_function=None, score_name=None, n_saved=1, atomic=True, require_empty=True, create_dir=True, global_step_transform=None, include_self=False, **kwargs)[source]#

ModelCheckpoint handler can be used to periodically save objects to disk only. If checkpoints need to be stored in another storage type, please consider Checkpoint.

This handler expects two arguments:

- an Engine object
- a dict mapping names (str) to objects that should be saved to disk.

See Examples for further details.
Warning

The behaviour of this class has changed since v0.3.0:

- There is no longer an internal counter of save actions. The step_number visible in the filename, e.g. {filename_prefix}_{name}_{step_number}.pt, is now the current engine's epoch if score_function is specified, and the current iteration otherwise.
- A single pt file is created instead of multiple files.
- Parameters
  - dirname (str) – Directory path where objects will be saved.
  - filename_prefix (str) – Prefix for the file names to which objects will be saved. See Notes of Checkpoint for more details.
  - score_function (Optional[Callable]) – If not None, it should be a function taking a single argument, an Engine object, and returning a score (float). Objects with the highest scores will be retained.
  - score_name (Optional[str]) – If score_function is not None, it is possible to store its value using score_name. See Notes for more details.
  - n_saved (Optional[int]) – Number of objects that should be kept on disk. Older files will be removed. If set to None, all objects are kept.
  - atomic (bool) – If True, objects are serialized to a temporary file and then moved to the final destination, so that files are guaranteed not to be damaged (for example, if an exception occurs during saving).
  - require_empty (bool) – If True, will raise an exception if there are any files starting with filename_prefix in the directory dirname.
  - create_dir (bool) – If True, will create the directory dirname if it does not exist.
  - global_step_transform (Optional[Callable]) – Global step transform function to output a desired global step. Input of the function is (engine, event_name); output should be an integer. Default is None, meaning the global step is taken from the attached engine. If provided, the function output is used as global_step. To set up the global step from another engine, please use global_step_from_engine().
  - include_self (bool) – Whether to include the state_dict of this object in the checkpoint. If True, there must not be another object in to_save with the key checkpointer.
  - kwargs (Any) – Accepted keyword arguments for torch.save or xm.save in DiskSaver.
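The atomic=True behaviour described above can be sketched with the standard temp-file-then-rename pattern; atomic_save below is an illustrative helper, not ignite's actual implementation:

```python
import os
import tempfile

def atomic_save(data: bytes, path: str) -> None:
    # Serialize to a temporary file in the same directory, then atomically
    # rename it over the final destination (os.replace is atomic on POSIX),
    # so an existing checkpoint is never left half-written.
    dirname = os.path.dirname(path) or "."
    fd, tmp_path = tempfile.mkstemp(dir=dirname)
    try:
        with os.fdopen(fd, "wb") as f:
            f.write(data)
        os.replace(tmp_path, path)
    except BaseException:
        # On any failure, discard the temporary file and re-raise.
        os.remove(tmp_path)
        raise
```

If the process crashes during f.write, only the temporary file is affected and a previously saved checkpoint at path stays intact.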
Changed in version 0.4.2: Accept kwargs for torch.save or xm.save.

Examples

>>> import os
>>> from ignite.engine import Engine, Events
>>> from ignite.handlers import ModelCheckpoint
>>> from torch import nn
>>> trainer = Engine(lambda engine, batch: None)
>>> handler = ModelCheckpoint('/tmp/models', 'myprefix', n_saved=2, create_dir=True)
>>> model = nn.Linear(3, 3)
>>> trainer.add_event_handler(Events.EPOCH_COMPLETED(every=2), handler, {'mymodel': model})
>>> trainer.run([0, 1, 2, 3, 4], max_epochs=6)
>>> os.listdir('/tmp/models')
['myprefix_mymodel_20.pt', 'myprefix_mymodel_30.pt']
>>> handler.last_checkpoint
['/tmp/models/myprefix_mymodel_30.pt']
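The interaction of score_function and n_saved described above can be sketched with plain (priority, filename) pairs; retain_best here is a hypothetical helper for illustration, not part of ignite's API:

```python
def retain_best(saved, new_item, n_saved=2):
    # saved and new_item are (priority, filename) pairs; keep only the
    # n_saved entries with the highest priority, mirroring what the
    # handler is described to do when score_function is set.
    saved = sorted(saved + [new_item], key=lambda pair: pair[0], reverse=True)
    return saved[:n_saved]

saved = []
for epoch, acc in enumerate([0.2, 0.5, 0.4, 0.9], start=1):
    saved = retain_best(saved, (acc, f"myprefix_mymodel_{epoch}.pt"))
print(saved)  # only the two best-scoring checkpoints survive
```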
Methods

- get_default_score_fn – Helper method to get the default score function based on the metric name.
- load_objects – Helper method to apply load_state_dict on the objects from to_load using states from checkpoint.
- load_state_dict – Method replaces the internal state of the class with the provided state dict data.
- reset – Method to reset saved checkpoint names.
- setup_filename_pattern – Helper method to get the default filename pattern for a checkpoint.
- state_dict – Method returns a state dict with saved items: a list of (priority, filename) pairs.

class Item(priority, filename)#

Create new instance of Item(priority, filename)

- count(value, /)# – Return number of occurrences of value.
- property filename# – Alias for field number 1
- index(value, start=0, stop=9223372036854775807, /)# – Return first index of value. Raises ValueError if the value is not present.
- property priority# – Alias for field number 0
static get_default_score_fn(metric_name, score_sign=1.0)#

Helper method to get the default score function based on the metric name.

- Parameters
  - metric_name (str) – Metric name to get the value from engine.state.metrics. The engine is the one to which the Checkpoint handler is added.
  - score_sign (float) – Sign of the score: 1.0 or -1.0. For error-like metrics, e.g. where smaller is better, a negative score sign should be used (objects with larger scores are retained). Default, 1.0.
- Return type
  Callable

Examples:
from ignite.handlers import Checkpoint

best_acc_score = Checkpoint.get_default_score_fn("accuracy")

best_model_handler = Checkpoint(
    to_save, save_handler, score_name="val_accuracy", score_function=best_acc_score
)
evaluator.add_event_handler(Events.COMPLETED, best_model_handler)
Usage with error-like metric:
from ignite.handlers import Checkpoint

neg_loss_score = Checkpoint.get_default_score_fn("loss", -1.0)

best_model_handler = Checkpoint(
    to_save, save_handler, score_name="val_neg_loss", score_function=neg_loss_score
)
evaluator.add_event_handler(Events.COMPLETED, best_model_handler)
New in version 0.4.3.
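Conceptually, the returned callable is assumed to compute score_sign * engine.state.metrics[metric_name]. The stand-in below (default_score_fn, a hypothetical re-implementation, not ignite's code) illustrates this without ignite:

```python
from types import SimpleNamespace

def default_score_fn(metric_name, score_sign=1.0):
    # Simplified re-implementation for illustration only.
    def score_fn(engine):
        return score_sign * engine.state.metrics[metric_name]
    return score_fn

# A SimpleNamespace stands in for an Engine with state.metrics populated.
engine = SimpleNamespace(state=SimpleNamespace(metrics={"accuracy": 0.91, "loss": 0.4}))
print(default_score_fn("accuracy")(engine))    # 0.91
print(default_score_fn("loss", -1.0)(engine))  # -0.4
```

With score_sign=-1.0, a smaller loss yields a larger score, which is why "objects with larger scores are retained" still keeps the best models.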
static load_objects(to_load, checkpoint, **kwargs)#

Helper method to apply load_state_dict on the objects from to_load using states from checkpoint.

Examples:
import torch
from ignite.engine import Engine, Events
from ignite.handlers import ModelCheckpoint, Checkpoint

trainer = Engine(lambda engine, batch: None)
handler = ModelCheckpoint('/tmp/models', 'myprefix', n_saved=None, create_dir=True)
model = torch.nn.Linear(3, 3)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
to_save = {"weights": model, "optimizer": optimizer}
trainer.add_event_handler(Events.EPOCH_COMPLETED(every=2), handler, to_save)
trainer.run(torch.randn(10, 1), 5)

to_load = to_save
checkpoint_fp = "/tmp/models/myprefix_checkpoint_40.pth"
checkpoint = torch.load(checkpoint_fp)
Checkpoint.load_objects(to_load=to_load, checkpoint=checkpoint)
Note

If to_load contains objects of type torch DistributedDataParallel or DataParallel, the method load_state_dict will be applied to their internal wrapped model (obj.module).

- Parameters
  - to_load (Mapping) – a dictionary with objects, e.g. {"model": model, "optimizer": optimizer, ...}
  - checkpoint (Mapping) – a dictionary with state_dicts to load, e.g. {"model": model_state_dict, "optimizer": opt_state_dict}. If to_load contains a single key, then checkpoint can directly contain the corresponding state_dict.
  - kwargs (Any) – Keyword arguments accepted for nn.Module.load_state_dict(). Passing strict=False enables the user to load part of the pretrained model (useful, for example, in transfer learning).
- Return type
  None
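The effect of passing strict=False can be sketched with plain dicts standing in for real state_dicts; load_partial is a hypothetical illustration of the key matching involved, not torch's implementation:

```python
def load_partial(target_state, checkpoint_state, strict=True):
    # strict=True requires the key sets to match exactly; strict=False
    # loads the intersection and silently skips the rest.
    missing = set(target_state) - set(checkpoint_state)
    unexpected = set(checkpoint_state) - set(target_state)
    if strict and (missing or unexpected):
        raise RuntimeError(f"missing={missing}, unexpected={unexpected}")
    for key in target_state.keys() & checkpoint_state.keys():
        target_state[key] = checkpoint_state[key]
    return target_state

# Transfer-learning-like scenario: the checkpoint has no head weights.
backbone = {"layer1.weight": 0.0, "head.weight": 0.0}
pretrained = {"layer1.weight": 1.0}
load_partial(backbone, pretrained, strict=False)
print(backbone["layer1.weight"], backbone["head.weight"])  # 1.0 0.0
```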
load_state_dict(state_dict)#

Method replaces the internal state of the class with the provided state dict data.

- Parameters
  - state_dict (Mapping) – a dict with a "saved" key and a list of (priority, filename) pairs as values.
- Return type
  None
reset()#

Method to reset saved checkpoint names.

Use this method if the engine will independently run multiple times:

from ignite.handlers import Checkpoint

trainer = ...
checkpointer = Checkpoint(...)

trainer.add_event_handler(Events.COMPLETED, checkpointer)
trainer.add_event_handler(Events.STARTED, checkpointer.reset)

# fold 0
trainer.run(data0, max_epochs=max_epochs)
print("Last checkpoint:", checkpointer.last_checkpoint)

# fold 1
trainer.run(data1, max_epochs=max_epochs)
print("Last checkpoint:", checkpointer.last_checkpoint)

New in version 0.4.3.

- Return type
  None
static setup_filename_pattern(with_prefix=True, with_score=True, with_score_name=True, with_global_step=True)#

Helper method to get the default filename pattern for a checkpoint.

- Parameters
  - with_prefix (bool) – If True, the filename_prefix is added to the filename pattern: {filename_prefix}_{name}.... Default, True.
  - with_score (bool) – If True, score is added to the filename pattern: ..._{score}.{ext}. Default, True. At least one of with_score and with_global_step should be True.
  - with_score_name (bool) – If True, score_name is added to the filename pattern: ..._{score_name}={score}.{ext}. If activated, the argument with_score should also be True, otherwise an error is raised. Default, True.
  - with_global_step (bool) – If True, {global_step} is added to the filename pattern: ...{name}_{global_step}.... At least one of with_score and with_global_step should be True.
- Return type
  str
Example

from ignite.handlers import Checkpoint

filename_pattern = Checkpoint.setup_filename_pattern()

print(filename_pattern)
> "{filename_prefix}_{name}_{global_step}_{score_name}={score}.{ext}"
New in version 0.4.3.
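The returned pattern is a plain str.format template; filling it with made-up values (all names and numbers below are illustrative) yields a concrete checkpoint filename:

```python
# Default pattern as printed in the example above.
pattern = "{filename_prefix}_{name}_{global_step}_{score_name}={score}.{ext}"

# Substitute illustrative values; the score is pre-formatted as a string.
filename = pattern.format(
    filename_prefix="myprefix", name="mymodel",
    global_step=40, score_name="val_accuracy", score="0.8700", ext="pt",
)
print(filename)  # myprefix_mymodel_40_val_accuracy=0.8700.pt
```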