EMAHandler
class ignite.handlers.ema_handler.EMAHandler(model, momentum=0.0002, momentum_warmup=None, warmup_iters=None)

Exponential moving average (EMA) handler that can be used to compute a smoothed version of a model. The EMA model is updated as follows:

$$\theta_{\text{EMA}}^{t} = (1 - m)\,\theta_{\text{EMA}}^{t-1} + m\,\theta^{t}$$

where $\theta_{\text{EMA}}^{t}$ and $\theta^{t}$ are the EMA weights and the online model weights at the $t$-th iteration, respectively, and $m$ is the update momentum. The handler allows the momentum to be warmed up linearly at the beginning of training, when the training process is not yet stable. The current momentum can be retrieved from Engine.state.ema_momentum (more generally, from the Engine.state attribute named by the name argument of attach).

Parameters
model (torch.nn.modules.module.Module) – the online model for which an EMA model will be computed. If model is DataParallel or DistributedDataParallel, the EMA smoothing will be applied to model.module.
momentum (float) – the update momentum after the warmup phase; should be a float in the range (0, 1).
momentum_warmup (Optional[float]) – the initial update momentum during the warmup phase; the value should be smaller than momentum. The momentum will increase linearly from this value to momentum in warmup_iters iterations. If None, no warmup will be performed.
warmup_iters (Optional[int]) – number of warmup iterations. If None, no warmup will be performed.
Return type
None
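The update rule $\theta_{\text{EMA}}^{t} = (1 - m)\,\theta_{\text{EMA}}^{t-1} + m\,\theta^{t}$ can be illustrated with a framework-free sketch. It assumes the weights are held in a plain dict of floats standing in for a model's state_dict; the helper name ema_update is hypothetical and not part of ignite's API.

```python
from typing import Dict


def ema_update(ema_weights: Dict[str, float], online_weights: Dict[str, float], momentum: float) -> None:
    """Hypothetical helper: apply theta_ema <- (1 - m) * theta_ema + m * theta in place."""
    # The documented valid range for the momentum is (0, 1).
    if not 0.0 < momentum < 1.0:
        raise ValueError(f"momentum should be in (0, 1), got {momentum}")
    for key, online_value in online_weights.items():
        ema_weights[key] = (1.0 - momentum) * ema_weights[key] + momentum * online_value


# Toy example: the online weight jumps from 0.0 to 1.0 and stays there.
ema = {"w": 0.0}
online = {"w": 1.0}
for _ in range(3):
    ema_update(ema, online, momentum=0.5)
print(ema["w"])  # 0.875
```

Unrolling the recursion gives $\theta_{\text{EMA}}^{t} = (1-m)^{t}\,\theta_{\text{EMA}}^{0} + \sum_{k=1}^{t} m\,(1-m)^{t-k}\,\theta^{k}$, so older online weights are discounted geometrically; a small momentum such as the default 0.0002 makes the EMA weights change slowly.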
ema_model – the exponential moving averaged model.
model – the online model that is tracked by EMAHandler. It is model.module if model in the initialization method is an instance of DistributedDataParallel.
momentum – the update momentum after the warmup phase.
momentum_warmup – the initial update momentum.
warmup_iters – number of warmup iterations.
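The linear warmup controlled by momentum_warmup and warmup_iters can be sketched as a pure function of the iteration count. This is a minimal sketch under stated assumptions: the schedule interpolates linearly, reaches momentum at iteration warmup_iters and stays constant afterwards, and iterations are 1-based. The helper name momentum_at is hypothetical, and ignite's own implementation may count iterations differently.

```python
from typing import Optional


def momentum_at(
    iteration: int,
    momentum: float,
    momentum_warmup: Optional[float] = None,
    warmup_iters: Optional[int] = None,
) -> float:
    """Hypothetical helper: linearly increase the momentum from momentum_warmup to momentum.

    Assumes 1-based iterations: iteration 1 uses momentum_warmup, and
    iteration warmup_iters (and any later one) uses momentum.
    """
    # As documented, no warmup is performed if either warmup argument is None.
    if momentum_warmup is None or warmup_iters is None:
        return momentum
    iteration = max(iteration, 1)
    if warmup_iters <= 1 or iteration >= warmup_iters:
        return momentum
    progress = (iteration - 1) / (warmup_iters - 1)
    return momentum_warmup + progress * (momentum - momentum_warmup)


# Same settings as the first example below: 0.0001 -> 0.0002 over 10000 iterations.
for it in (1, 5000, 10000, 20000):
    print(it, momentum_at(it, momentum=0.0002, momentum_warmup=0.0001, warmup_iters=10000))
```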
Note
The EMA model is already in eval mode. If the model passed in the arguments is an nn.Module or DistributedDataParallel, the EMA model is an nn.Module on the same device as the online model. If the model is an nn.DataParallel, the EMA model is also an nn.DataParallel.
Note
It is recommended to initialize and use an EMA handler in the following steps:
1. Initialize model (nn.Module or DistributedDataParallel) and ema_handler (EMAHandler).
2. Build trainer (ignite.engine.Engine).
3. Resume from checkpoint for model and ema_handler.ema_model.
4. Attach ema_handler to trainer.
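One plausible reason for restoring both model and ema_handler.ema_model, assuming (as the Note's wording suggests) that the handler creates ema_model as an independent copy of the online model when it is constructed: loading a checkpoint into model afterwards does not touch that copy. The framework-free sketch below uses hypothetical ToyModel and ToyEMAHandler classes as stand-ins for an nn.Module and EMAHandler; it is not ignite's implementation.

```python
import copy
from typing import Dict


class ToyModel:
    """Hypothetical stand-in for an nn.Module: a dict of weights plus a training flag."""

    def __init__(self, weights: Dict[str, float]) -> None:
        self.weights = dict(weights)
        self.training = True

    def eval(self) -> "ToyModel":
        self.training = False
        return self

    def load_state_dict(self, state: Dict[str, float]) -> None:
        self.weights = dict(state)


class ToyEMAHandler:
    """Hypothetical stand-in for EMAHandler: copies the model once, at construction."""

    def __init__(self, model: ToyModel) -> None:
        self.model = model
        # Assumption: the EMA model is an independent deep copy, put in eval mode right away.
        self.ema_model = copy.deepcopy(model).eval()


model = ToyModel({"w": 0.0})
ema_handler = ToyEMAHandler(model)

# "Resume" only the online model from a (toy) checkpoint.
checkpoint = {"model": {"w": 1.0}, "ema_model": {"w": 0.9}}
model.load_state_dict(checkpoint["model"])
print(ema_handler.ema_model.weights)  # {'w': 0.0}: the EMA copy is stale

# Restoring the EMA model explicitly brings it in sync with the checkpoint.
ema_handler.ema_model.load_state_dict(checkpoint["ema_model"])
print(ema_handler.ema_model.weights, ema_handler.ema_model.training)
```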
Examples
    import torch
    import torch.nn as nn
    from ignite.engine import Engine, Events
    from ignite.handlers import Checkpoint, DiskSaver
    from ignite.handlers.ema_handler import EMAHandler

    # train_step_fn, resume_from, get_val_step_fn and val_data_loader are user-defined
    device = torch.device("cuda:0")
    model = nn.Linear(2, 1).to(device)
    # linearly warm up the momentum from 0.0001 to 0.0002 over 10000 iterations
    ema_handler = EMAHandler(
        model, momentum=0.0002, momentum_warmup=0.0001, warmup_iters=10000
    )
    # get the EMA model, which is an instance of nn.Module
    ema_model = ema_handler.ema_model

    trainer = Engine(train_step_fn)
    to_load = {"model": model, "ema_model": ema_model, "trainer": trainer}
    if resume_from is not None:
        Checkpoint.load_objects(to_load, checkpoint=resume_from)

    # update the EMA model every 5 iterations
    ema_handler.attach(trainer, name="ema_momentum", event=Events.ITERATION_COMPLETED(every=5))

    # add other handlers
    to_save = to_load
    ckpt_handler = Checkpoint(to_save, DiskSaver(...), ...)
    trainer.add_event_handler(Events.EPOCH_COMPLETED, ckpt_handler)

    # the current momentum can be retrieved from engine.state;
    # the attribute name is the `name` parameter used in the attach function
    @trainer.on(Events.ITERATION_COMPLETED)
    def print_ema_momentum(engine):
        print(f"current momentum: {engine.state.ema_momentum}")

    # use the EMA model for validation
    val_step_fn = get_val_step_fn(ema_model)
    evaluator = Engine(val_step_fn)

    @trainer.on(Events.EPOCH_COMPLETED)
    def run_validation(engine):
        evaluator.run(val_data_loader)

    trainer.run(...)
The following example shows how to attach two handlers to the same trainer:
    from ignite.engine import Engine, Events
    from ignite.handlers.ema_handler import EMAHandler

    generator = build_generator(...)
    discriminator = build_discriminator(...)

    gen_handler = EMAHandler(generator)
    disc_handler = EMAHandler(discriminator)

    step_fn = get_step_fn(...)
    engine = Engine(step_fn)

    # update the EMA model of the generator every iteration
    gen_handler.attach(engine, "gen_ema_momentum", event=Events.ITERATION_COMPLETED)
    # update the EMA model of the discriminator every 2 iterations
    disc_handler.attach(engine, "disc_ema_momentum", event=Events.ITERATION_COMPLETED(every=2))

    @engine.on(Events.ITERATION_COMPLETED)
    def print_ema_momentum(engine):
        print(f"current momentum for generator: {engine.state.gen_ema_momentum}")
        print(f"current momentum for discriminator: {engine.state.disc_ema_momentum}")

    engine.run(...)
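Why each handler needs its own name, and how an every=N filter changes how often each EMA model is updated, can be sketched without ignite. MiniEngine and MiniEMAHandler below are hypothetical, heavily simplified stand-ins for Engine and EMAHandler; only the idea of storing the momentum under a per-handler attribute of engine.state comes from the documentation, and the duplicate-name check is this sketch's own choice.

```python
from types import SimpleNamespace
from typing import Callable, List, Tuple


class MiniEngine:
    """Hypothetical engine: counts iterations and fires filtered handlers."""

    def __init__(self) -> None:
        self.state = SimpleNamespace(iteration=0)
        self._handlers: List[Tuple[int, Callable[["MiniEngine"], None]]] = []

    def on_iteration_completed(self, handler: Callable[["MiniEngine"], None], every: int = 1) -> None:
        self._handlers.append((every, handler))

    def run(self, num_iters: int) -> None:
        for _ in range(num_iters):
            self.state.iteration += 1
            for every, handler in self._handlers:
                # Mimic an ITERATION_COMPLETED(every=N) event filter.
                if self.state.iteration % every == 0:
                    handler(self)


class MiniEMAHandler:
    """Hypothetical stand-in for EMAHandler that only counts its updates."""

    def __init__(self, momentum: float) -> None:
        self.momentum = momentum
        self.num_updates = 0

    def attach(self, engine: MiniEngine, name: str, every: int = 1) -> None:
        # A unique name keeps two handlers from overwriting each other's state attribute.
        if hasattr(engine.state, name):
            raise ValueError(f"engine.state already has an attribute named {name!r}")
        setattr(engine.state, name, self.momentum)

        def _update(engine: MiniEngine) -> None:
            self.num_updates += 1  # the EMA weight update would happen here
            setattr(engine.state, name, self.momentum)

        engine.on_iteration_completed(_update, every=every)


engine = MiniEngine()
gen_handler = MiniEMAHandler(momentum=0.0002)
disc_handler = MiniEMAHandler(momentum=0.0002)
gen_handler.attach(engine, "gen_ema_momentum")
disc_handler.attach(engine, "disc_ema_momentum", every=2)
engine.run(10)
print(gen_handler.num_updates, disc_handler.num_updates)  # 10 5
```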
New in version 0.5.0.
Methods
attach – Attach the handler to engine.
attach(engine, name='ema_momentum', event=<Events.ITERATION_COMPLETED: 'iteration_completed'>)
Attach the handler to engine. After the handler is attached, Engine.state gets a new attribute named name, from which the current momentum can be retrieved while the engine runs.
Parameters
engine (ignite.engine.engine.Engine) – trainer to which the handler will be attached.
name (str) – attribute name for retrieving the EMA momentum from Engine.state. It should be unique, since a trainer can have multiple EMA handlers.
event (Union[str, ignite.engine.events.Events, ignite.engine.events.CallableEventWithFilter, ignite.engine.events.EventsList]) – event on which the EMA momentum and the EMA model are updated.
Return type