Engine

class ignite.engine.engine.Engine(process_function)
- Runs a given process_function over each batch of a dataset, emitting events as it goes.
- Parameters
- process_function (Callable) – A function that receives a handle to the engine and the current batch at each iteration, and returns data to be stored in the engine’s state (as engine.state.output).
state
- Object used to pass internal and user-defined state between event handlers. It is created with the engine, and its attributes (e.g. state.iteration, state.epoch, etc.) are reset on every run().
last_event_name
- The last event name triggered by the engine.
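A minimal pure-Python sketch of the loop described above, to show how process_function, state and last_event_name relate. ToyEngine and ToyState are hypothetical names, not ignite classes; the sketch records event names as plain strings and has no handler registration, so it illustrates the idea under those assumptions rather than ignite's implementation.

```python
from dataclasses import dataclass
from typing import Any, Callable, Iterable, List, Optional


@dataclass
class ToyState:
    """Hypothetical stand-in for engine.state (not ignite's State class)."""
    iteration: int = 0
    epoch: int = 0
    max_epochs: int = 1
    batch: Any = None
    output: Any = None


class ToyEngine:
    """Hypothetical, simplified engine loop; illustrates the idea only."""

    def __init__(self, process_function: Callable[["ToyEngine", Any], Any]) -> None:
        self._process_function = process_function
        self.state = ToyState()
        self.last_event_name: Optional[str] = None
        self.fired: List[str] = []  # toy log of every event name fired

    def _fire(self, name: str) -> None:
        self.last_event_name = name
        self.fired.append(name)

    def run(self, data: Iterable, max_epochs: int = 1) -> ToyState:
        self.state = ToyState(max_epochs=max_epochs)  # fresh state on each run
        self._fire("STARTED")
        for _ in range(max_epochs):
            self.state.epoch += 1
            self._fire("EPOCH_STARTED")
            for batch in data:  # data must allow repeated iteration, e.g. a list
                self.state.iteration += 1
                self.state.batch = batch
                self._fire("ITERATION_STARTED")
                # the return value of process_function is stored as state.output
                self.state.output = self._process_function(self, batch)
                self._fire("ITERATION_COMPLETED")
            self._fire("EPOCH_COMPLETED")
        self._fire("COMPLETED")
        return self.state


# Usage: a process function that just doubles each batch.
engine = ToyEngine(lambda engine, batch: batch * 2)
state = engine.run([1, 2, 3], max_epochs=2)
print(state.epoch, state.iteration, state.output, engine.last_event_name)
# prints: 2 6 6 COMPLETED
```

In this toy, the value returned for the last batch ends up in state.output, which is the attribute the trainer example reads as engine.state.output.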
 - Examples
- Create a basic trainer
```python
from ignite.engine import Engine, Events

# model, optimizer, criterion and data_loader are assumed to be defined elsewhere
def update_model(engine, batch):
    inputs, targets = batch
    optimizer.zero_grad()
    outputs = model(inputs)
    loss = criterion(outputs, targets)
    loss.backward()
    optimizer.step()
    return loss.item()

trainer = Engine(update_model)

@trainer.on(Events.ITERATION_COMPLETED(every=100))
def log_training(engine):
    batch_loss = engine.state.output
    lr = optimizer.param_groups[0]['lr']
    e = engine.state.epoch
    n = engine.state.max_epochs
    i = engine.state.iteration
    print(f"Epoch {e}/{n} : {i} - batch loss: {batch_loss}, lr: {lr}")

trainer.run(data_loader, max_epochs=5)
```
```
> Epoch 1/5 : 100 - batch loss: 0.10874069479016124, lr: 0.01
> ...
> Epoch 2/5 : 1700 - batch loss: 0.4217900575859437, lr: 0.01
```
- Create a basic evaluator to compute metrics
```python
import torch
from ignite.engine import Engine
from ignite.metrics import Accuracy

# model, prepare_batch, device, non_blocking and val_dataloader are assumed to be defined elsewhere
def predict_on_batch(engine, batch):
    model.eval()
    with torch.no_grad():
        x, y = prepare_batch(batch, device=device, non_blocking=non_blocking)
        y_pred = model(x)
    return y_pred, y

evaluator = Engine(predict_on_batch)
Accuracy().attach(evaluator, "val_acc")
evaluator.run(val_dataloader)
```
- Compute image mean/std on the training dataset
```python
import torch
from ignite.engine import Engine
from ignite.metrics import Average

# train_loader is assumed to yield dicts with an "image" tensor of shape (B, C, ...)
def compute_mean_std(engine, batch):
    b, c, *_ = batch['image'].shape
    data = batch['image'].reshape(b, c, -1).to(dtype=torch.float64)
    # keep the batch dimension (shape (b, c)): Average sums over dim 0
    # and counts b samples, giving a per-image average
    mean = torch.mean(data, dim=-1)
    mean2 = torch.mean(data ** 2, dim=-1)
    return {"mean": mean, "mean^2": mean2}

compute_engine = Engine(compute_mean_std)
img_mean = Average(output_transform=lambda output: output['mean'])
img_mean.attach(compute_engine, 'mean')
img_mean2 = Average(output_transform=lambda output: output['mean^2'])
img_mean2.attach(compute_engine, 'mean2')
state = compute_engine.run(train_loader)
# per-channel std = sqrt(E[x^2] - E[x]^2)
state.metrics['std'] = torch.sqrt(state.metrics['mean2'] - state.metrics['mean'] ** 2)
mean = state.metrics['mean'].tolist()
std = state.metrics['std'].tolist()
```
- Resume the engine’s run from a state.
The user can load a state_dict and run the engine starting from the loaded state:
```python
# Restore from an epoch
state_dict = {"epoch": 3, "max_epochs": 100, "epoch_length": len(data_loader)}
# or an iteration
# state_dict = {"iteration": 500, "max_epochs": 100, "epoch_length": len(data_loader)}

trainer = Engine(...)
trainer.load_state_dict(state_dict)
trainer.run(data)
```
- Methods
- add_event_handler: Add an event handler to be executed when the specified event is fired.
- fire_event: Execute all the handlers associated with the given event.
- has_event_handler: Check if the specified event has the specified handler.
- load_state_dict: Sets up the engine from state_dict.
- on: Decorator shortcut for add_event_handler.
- register_events: Add events that can be fired.
- remove_event_handler: Remove the given handler from the registered handlers of the engine.
- run: Runs the process_function over the passed data.
- set_data: Method to set data.
- state_dict: Returns a dictionary containing the engine’s state: “epoch_length”, “max_epochs” and “iteration”, as well as other state values defined by engine.state_dict_user_keys.
- terminate: Sends a terminate signal to the engine, so that it terminates the run completely after the current iteration.
- terminate_epoch: Sends a terminate signal to the engine, so that it terminates the current epoch after the current iteration.
add_event_handler(event_name, handler, *args, **kwargs)
- Add an event handler to be executed when the specified event is fired.
- Parameters
- event_name (Any) – An event or a list of events to attach the handler to. Valid events are from Events or any event_name added by register_events().
- handler (Callable) – the callable event handler that should be invoked. There are no restrictions on its signature. The first argument can optionally be engine, the Engine object the handler is bound to.
- args (Any) – optional args to be passed to handler.
- kwargs (Any) – optional keyword args to be passed to handler.
 
- Returns
- RemovableEventHandle, which can be used to remove the handler.
- Return type
- RemovableEventHandle
- Note
- Other arguments can be passed to the handler in addition to the *args and **kwargs passed here, for example during EXCEPTION_RAISED.
- Example usage:
```python
from ignite.engine import Engine, Events

engine = Engine(process_function)

def print_epoch(engine):
    print(f"Epoch: {engine.state.epoch}")

engine.add_event_handler(Events.EPOCH_COMPLETED, print_epoch)

events_list = Events.EPOCH_COMPLETED | Events.COMPLETED

def execute_something():
    # do something not related to engine
    pass

engine.add_event_handler(events_list, execute_something)
```
- Note
- Since v0.3.0, events are more flexible and allow passing an event filter to the Engine. See Events for more details.
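The Note above mentions event filters, and the trainer example in the Examples uses Events.ITERATION_COMPLETED(every=100). Assuming every=n triggers when the event counter (here state.iteration) is divisible by n, and once=n triggers only when it equals n, the selection logic can be sketched in plain Python. make_filter and triggered are hypothetical helpers, not ignite API.

```python
from typing import Callable, List, Optional


def make_filter(every: Optional[int] = None, once: Optional[int] = None) -> Callable[[int], bool]:
    """Hypothetical: build a predicate over an event counter value."""
    if (every is None) == (once is None):
        raise ValueError("Provide exactly one of 'every' or 'once'")
    if every is not None:
        if every <= 0:
            raise ValueError("'every' must be a positive integer")
        return lambda value: value % every == 0  # fire on multiples of `every`
    return lambda value: value == once  # fire a single time


def triggered(counter_values: List[int], predicate: Callable[[int], bool]) -> List[int]:
    """Return the counter values on which the handler would run."""
    return [v for v in counter_values if predicate(v)]


iterations = list(range(1, 351))  # counter values 1..350
print(triggered(iterations, make_filter(every=100)))  # [100, 200, 300]
print(triggered(iterations, make_filter(once=250)))   # [250]
```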
fire_event(event_name)
- Execute all the handlers associated with the given event.
- This method executes all handlers associated with the event event_name. It is the method used in run() to call the core events found in Events.
- Custom events can be fired if they have been registered beforehand with register_events(). The engine state attribute should be used to exchange “dynamic” data between process_function and handlers.
- This method is called automatically for core events. If no custom events are used in the engine, there is no need for the user to call it.
- Parameters
- event_name (Any) – event for which the handlers should be executed. Valid events are from Events or any event_name added by register_events().
- Return type
- None
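A toy model of the registration and firing flow described for fire_event: custom events must be registered before handlers are attached, and firing an event calls its handlers, which can share data through a state attribute. ToyDispatcher is hypothetical; raising ValueError for an unregistered event and running handlers in insertion order are choices of this sketch, not documented guarantees.

```python
from collections import defaultdict
from types import SimpleNamespace
from typing import Any, Callable, DefaultDict, List, Set


class ToyDispatcher:
    """Hypothetical, simplified model of event registration and firing."""

    def __init__(self) -> None:
        self.state = SimpleNamespace()  # shared "dynamic" data for handlers
        self._allowed: Set[Any] = {"EPOCH_COMPLETED", "COMPLETED"}  # toy core events
        self._handlers: DefaultDict[Any, List[Callable]] = defaultdict(list)

    def register_events(self, *names: Any) -> None:
        self._allowed.update(names)

    def add_event_handler(self, name: Any, handler: Callable) -> None:
        if name not in self._allowed:
            raise ValueError(f"{name!r} is not a registered event")
        self._handlers[name].append(handler)

    def fire_event(self, name: Any) -> None:
        # in this toy, handlers run in the order they were added
        for handler in self._handlers[name]:
            handler(self)


d = ToyDispatcher()
d.register_events("bwd_event")
d.state.calls = []
d.add_event_handler("bwd_event", lambda e: e.state.calls.append("first"))
d.add_event_handler("bwd_event", lambda e: e.state.calls.append("second"))
d.fire_event("bwd_event")
print(d.state.calls)  # ['first', 'second']
```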
has_event_handler(handler, event_name=None)
- Check if the specified event has the specified handler.
- Parameters
- handler (Callable) – the callable event handler.
- event_name (Optional[Any]) – The event the handler is attached to. Set this to None to search all events.
 
- Return type
- bool
load_state_dict(state_dict)
- Sets up the engine from state_dict.
- The state dictionary should contain the keys iteration or epoch, max_epochs, and epoch_length. If engine.state_dict_user_keys contains keys, they should also be present in the state dictionary. Iteration and epoch values are 0-based: the first iteration or epoch is zero.
- This method does not remove any custom attributes added by the user.
- Parameters
- state_dict (collections.abc.Mapping) – a dict with parameters 
- Return type
- None
```python
# Restore from the 4th epoch
state_dict = {"epoch": 3, "max_epochs": 100, "epoch_length": len(data_loader)}
# or the 500th iteration
# state_dict = {"iteration": 499, "max_epochs": 100, "epoch_length": len(data_loader)}

trainer = Engine(...)
trainer.load_state_dict(state_dict)
trainer.run(data)
```
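Because iteration and epoch values in the state dictionary are 0-based, "epoch": 3 means three epochs are complete and the run resumes with the 4th. The sketch below assumes the missing counter is derived as iteration = epoch * epoch_length, or epoch = iteration // epoch_length; resume_position and the epoch length of 200 are hypothetical.

```python
from typing import Mapping, Tuple


def resume_position(state_dict: Mapping[str, int]) -> Tuple[int, int]:
    """Hypothetical helper: return the (epoch, iteration) counters, both 0-based,
    implied by a state_dict holding either "epoch" or "iteration"."""
    epoch_length = state_dict["epoch_length"]
    if "iteration" in state_dict:
        iteration = state_dict["iteration"]
        return iteration // epoch_length, iteration  # number of completed epochs
    if "epoch" in state_dict:
        epoch = state_dict["epoch"]
        return epoch, epoch * epoch_length  # iterations done before this epoch
    raise ValueError('state_dict needs an "epoch" or an "iteration" key')


# Assume 200 batches per epoch (hypothetical number).
print(resume_position({"epoch": 3, "max_epochs": 100, "epoch_length": 200}))        # (3, 600)
print(resume_position({"iteration": 499, "max_epochs": 100, "epoch_length": 200}))  # (2, 499)
```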
on(event_name, *args, **kwargs)
- Decorator shortcut for add_event_handler.
- Parameters
- event_name (Any) – An event to attach the handler to. Valid events are from Events or any event_name added by register_events().
- args (Any) – optional args to be passed to handler. 
- kwargs (Any) – optional keyword args to be passed to handler. 
 
- Return type
- Callable 
 - Example usage:
```python
from ignite.engine import Engine, Events

# process_function is assumed to be defined elsewhere
engine = Engine(process_function)

@engine.on(Events.EPOCH_COMPLETED)
def print_epoch():
    print(f"Epoch: {engine.state.epoch}")

@engine.on(Events.EPOCH_COMPLETED | Events.COMPLETED)
def execute_something():
    # do something not related to engine
    pass
```
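The decorator is described as a shortcut for add_event_handler. A hypothetical minimal version shows the usual pattern: on() returns a decorator that forwards the function and any extra args to add_event_handler, then returns the function unchanged so it can still be called directly. Unlike the real Engine, this toy never passes the engine itself to handlers.

```python
from collections import defaultdict
from typing import Any, Callable, DefaultDict, List, Tuple


class ToyEngine:
    """Hypothetical engine with only the handler-registration parts."""

    def __init__(self) -> None:
        self.handlers: DefaultDict[Any, List[Tuple[Callable, tuple, dict]]] = defaultdict(list)

    def add_event_handler(self, event_name: Any, handler: Callable, *args: Any, **kwargs: Any) -> None:
        self.handlers[event_name].append((handler, args, kwargs))

    def on(self, event_name: Any, *args: Any, **kwargs: Any) -> Callable[[Callable], Callable]:
        # The decorator only forwards to add_event_handler and returns f unchanged.
        def decorator(f: Callable) -> Callable:
            self.add_event_handler(event_name, f, *args, **kwargs)
            return f
        return decorator

    def fire_event(self, event_name: Any) -> None:
        for handler, args, kwargs in self.handlers[event_name]:
            handler(*args, **kwargs)


engine = ToyEngine()
log: List[str] = []

@engine.on("EPOCH_COMPLETED", "epoch done")
def record(message: str) -> None:
    log.append(message)

engine.fire_event("EPOCH_COMPLETED")
print(log)  # ['epoch done']
```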
register_events(*event_names, event_to_attr=None)
- Add events that can be fired.
- Registering an event lets the user trigger it at any point, which makes the run() loop even more configurable.
- By default, the events from Events are registered.
- Parameters
- event_names (Union[List[str], List[ignite.engine.events.EventEnum]]) – Defines the name of the event being supported. New events can be a str or an object derived from EventEnum. See the example below.
- event_to_attr (Optional[dict]) – A dictionary to map an event to a state attribute. 
 
- Return type
- None
- Example usage:
```python
from ignite.engine import Engine, Events, EventEnum

class CustomEvents(EventEnum):
    FOO_EVENT = "foo_event"
    BAR_EVENT = "bar_event"

def process_function(e, batch):
    # ...
    trainer.fire_event("bwd_event")
    loss.backward()
    # ...
    trainer.fire_event("opt_event")
    optimizer.step()

trainer = Engine(process_function)
trainer.register_events(*CustomEvents)
trainer.register_events("bwd_event", "opt_event")

@trainer.on(Events.EPOCH_COMPLETED)
def trigger_custom_event():
    # required(...) stands for a user-defined condition
    if required(...):
        trainer.fire_event(CustomEvents.FOO_EVENT)
    else:
        trainer.fire_event(CustomEvents.BAR_EVENT)

@trainer.on(CustomEvents.FOO_EVENT)
def do_foo_op():
    ...

@trainer.on(CustomEvents.BAR_EVENT)
def do_bar_op():
    ...
```
- Example with State Attribute:
```python
from ignite.engine import Engine, EventEnum

class TBPTT_Events(EventEnum):
    TIME_ITERATION_STARTED = "time_iteration_started"
    TIME_ITERATION_COMPLETED = "time_iteration_completed"

TBPTT_event_to_attr = {
    TBPTT_Events.TIME_ITERATION_STARTED: 'time_iteration',
    TBPTT_Events.TIME_ITERATION_COMPLETED: 'time_iteration'
}

engine = Engine(process_function)
engine.register_events(*TBPTT_Events, event_to_attr=TBPTT_event_to_attr)
engine.run(data)
# engine.state contains an attribute time_iteration, accessible as engine.state.time_iteration
```
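A toy model of what event_to_attr adds: the engine remembers which state attribute belongs to each custom event, and that attribute exists on state after registration. ToyEngine is hypothetical, a str-valued Enum stands in for EventEnum, and initializing the attribute to 0 is an assumption of this sketch rather than documented behavior.

```python
from enum import Enum
from types import SimpleNamespace
from typing import Any, Dict, Optional, Set


class TBPTTEvents(str, Enum):
    # Stand-in for an EventEnum subclass; a plain str-valued Enum is used here.
    TIME_ITERATION_STARTED = "time_iteration_started"
    TIME_ITERATION_COMPLETED = "time_iteration_completed"


class ToyEngine:
    """Hypothetical model of register_events(..., event_to_attr=...)."""

    def __init__(self) -> None:
        self.state = SimpleNamespace(iteration=0, epoch=0)
        self.allowed: Set[Any] = set()
        self.event_to_attr: Dict[Any, str] = {}

    def register_events(self, *event_names: Any, event_to_attr: Optional[Dict[Any, str]] = None) -> None:
        self.allowed.update(event_names)
        if event_to_attr:
            self.event_to_attr.update(event_to_attr)
            # In this toy, each mapped attribute is created on state, starting at 0.
            for attr in event_to_attr.values():
                if not hasattr(self.state, attr):
                    setattr(self.state, attr, 0)


engine = ToyEngine()
mapping = {
    TBPTTEvents.TIME_ITERATION_STARTED: "time_iteration",
    TBPTTEvents.TIME_ITERATION_COMPLETED: "time_iteration",
}
engine.register_events(*TBPTTEvents, event_to_attr=mapping)
print(engine.state.time_iteration)                          # 0
print(TBPTTEvents.TIME_ITERATION_STARTED in engine.allowed)  # True
```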
remove_event_handler(handler, event_name)
- Remove the given handler from the registered handlers of the engine.
- Parameters
- handler (Callable) – the callable event handler that should be removed 
- event_name (Any) – The event the handler attached to. 
 
- Return type
- None
run(data, max_epochs=None, max_iters=None, epoch_length=None)
- Runs the process_function over the passed data.
- The engine has a state, and the following logic is applied in this function:
- At the first call, a new state is defined by max_epochs, max_iters, and epoch_length, if provided. A timer for total and per-epoch time is initialized when Events.STARTED is handled.
- If the state is already defined such that there are iterations left to run until max_epochs, and no input arguments are provided, the state is kept and used in the function.
- If the state is defined and the engine is “done” (no iterations left to run until max_epochs), a new state is defined.
- If the state is defined and the engine is NOT “done”, then the input arguments, if provided, override the defined state.
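The four rules above can be read as a small decision function. This sketch is hypothetical: it considers only the max_epochs argument and simplifies “done” to state.epoch >= state.max_epochs, whereas the real run() also takes max_iters and epoch_length into account.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class RunState:
    """Hypothetical snapshot of the state fields used by this sketch."""
    epoch: int
    max_epochs: int


def plan_run(state: Optional[RunState], max_epochs: Optional[int] = None) -> str:
    """Toy model of the four documented cases; returns a short label."""
    if state is None:
        return "new state"  # first call: state built from the arguments
    is_done = state.epoch >= state.max_epochs  # simplification of "done"
    if is_done:
        return "new state"  # finished engine: start over
    if max_epochs is None:
        return "keep state"  # unfinished engine, no arguments: resume as is
    return "override state"  # unfinished engine, arguments override the state


print(plan_run(None, max_epochs=5))                              # new state
print(plan_run(RunState(epoch=2, max_epochs=5)))                 # keep state
print(plan_run(RunState(epoch=5, max_epochs=5)))                 # new state
print(plan_run(RunState(epoch=2, max_epochs=5), max_epochs=10))  # override state
```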
 - Parameters
- data (Iterable) – Collection of batches allowing repeated iteration (e.g., list or DataLoader). 
- max_epochs (Optional[int]) – Max epochs to run for (default: None). If a new state should be created (first run or run again from an ended engine), its default value is 1. If the run is resuming from a state, the provided max_epochs will be taken into account and should be larger than engine.state.max_epochs.
- epoch_length (Optional[int]) – Number of iterations to count as one epoch. By default, it is set to len(data) when data has a length. If data is an iterator and epoch_length is not set, then it will be automatically determined as the iteration on which the data iterator raises StopIteration. This argument should not change if the run is resuming from a state.
- max_iters (Optional[int]) – Number of iterations to run for. max_iters and max_epochs are mutually exclusive; only one of the two arguments should be provided. 
 
- Returns
- output state. 
- Return type
- State
- Note
- The user can dynamically preprocess the input batch at ITERATION_STARTED and store the output batch in engine.state.batch. The latter is then passed as usual to process_function as an argument:
```python
trainer = ...

@trainer.on(Events.ITERATION_STARTED)
def switch_batch(engine):
    engine.state.batch = preprocess_batch(engine.state.batch)
```
- To restart the training from the beginning, the user can reset max_epochs to None:
```python
# ...
trainer.run(train_loader, max_epochs=5)

# Reset model weights etc. and restart the training
trainer.state.max_epochs = None
trainer.run(train_loader, max_epochs=2)
```
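Two rules from the run() parameters can be made concrete: max_iters and max_epochs are mutually exclusive, and for an iterator without epoch_length the epoch length is the iteration at which StopIteration is raised. total_iterations and infer_epoch_length are hypothetical helpers written for illustration only; they do not call ignite.

```python
from typing import Iterator, Optional


def total_iterations(epoch_length: int, max_epochs: Optional[int] = None,
                     max_iters: Optional[int] = None) -> int:
    """Hypothetical helper: how many iterations a fresh run would execute."""
    if max_epochs is not None and max_iters is not None:
        raise ValueError("max_epochs and max_iters are mutually exclusive")
    if max_iters is not None:
        return max_iters
    # a new state defaults to 1 epoch, as documented for max_epochs
    return epoch_length * (1 if max_epochs is None else max_epochs)


def infer_epoch_length(data: Iterator) -> int:
    """Count batches until the iterator raises StopIteration (consumes it)."""
    count = 0
    for _ in data:  # the for loop ends when StopIteration is raised
        count += 1
    return count


print(total_iterations(epoch_length=100, max_epochs=5))   # 500
print(total_iterations(epoch_length=100, max_iters=250))  # 250
print(infer_epoch_length(iter(range(7))))                 # 7
```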
set_data(data)
- Method to set data. After calling this method, the next batch passed to process_function comes from the newly provided data. Note that the epoch length is not modified.
- Parameters
- data (Union[Iterable, torch.utils.data.dataloader.DataLoader]) – Collection of batches allowing repeated iteration (e.g., list or DataLoader). 
- Return type
- None
- Example usage:
- The user can switch the data provider during training:
```python
data1 = ...
data2 = ...

switch_iteration = 5000

def train_step(e, batch):
    # when iteration <= switch_iteration
    # batch is from data1
    # when iteration > switch_iteration
    # batch is from data2
    ...

trainer = Engine(train_step)

@trainer.on(Events.ITERATION_COMPLETED(once=switch_iteration))
def switch_dataloader():
    trainer.set_data(data2)

trainer.run(data1, max_epochs=100)
```
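A hypothetical stand-alone model of the set_data semantics: after the switch, the next batch comes from the new data, while the epoch length keeps its original value. ToySwitchableSource is not an ignite class, and switching at iteration 2 mirrors the once=switch_iteration handler above on a smaller scale.

```python
from typing import Any, Iterable, Iterator, List


class ToySwitchableSource:
    """Hypothetical model of set_data(): the next batch comes from the new data,
    while the epoch length stays fixed."""

    def __init__(self, data: Iterable, epoch_length: int) -> None:
        self._it: Iterator = iter(data)
        self.epoch_length = epoch_length

    def set_data(self, data: Iterable) -> None:
        self._it = iter(data)  # epoch_length is deliberately left unchanged

    def next_batch(self) -> Any:
        return next(self._it)


src = ToySwitchableSource(["a1", "a2", "a3", "a4"], epoch_length=4)
seen: List[Any] = []
for iteration in range(1, 5):
    seen.append(src.next_batch())
    if iteration == 2:  # analogous to Events.ITERATION_COMPLETED(once=2)
        src.set_data(["b1", "b2", "b3"])
print(seen, src.epoch_length)  # ['a1', 'a2', 'b1', 'b2'] 4
```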
 
state_dict()
- Returns a dictionary containing the engine’s state: “epoch_length”, “max_epochs” and “iteration”, as well as other state values defined by engine.state_dict_user_keys.
```python
import torch
from ignite.engine import Engine, Events

engine = Engine(...)

engine.state_dict_user_keys.append("alpha")
engine.state_dict_user_keys.append("beta")
...

@engine.on(Events.STARTED)
def init_user_value(_):
    engine.state.alpha = 0.1
    engine.state.beta = 1.0

@engine.on(Events.COMPLETED)
def save_engine(_):
    state_dict = engine.state_dict()
    assert "alpha" in state_dict and "beta" in state_dict
    torch.save(state_dict, "/tmp/engine.pt")
```
- Returns
- a dictionary containing engine’s state 
- Return type
- OrderedDict
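A sketch of the documented contents of the returned dictionary, assuming plain attribute lookups on the state: the keys epoch_length, max_epochs and iteration, followed by any names in state_dict_user_keys. toy_state_dict is a hypothetical helper; epoch is not among the documented keys, so it is left out here.

```python
from collections import OrderedDict
from types import SimpleNamespace
from typing import Any, List


def toy_state_dict(state: Any, user_keys: List[str]) -> "OrderedDict[str, Any]":
    """Hypothetical helper mirroring the documented contents of state_dict():
    epoch_length, max_epochs, iteration, plus the keys in state_dict_user_keys."""
    keys = ["epoch_length", "max_epochs", "iteration"] + list(user_keys)
    return OrderedDict((k, getattr(state, k)) for k in keys)


state = SimpleNamespace(epoch_length=200, max_epochs=10, iteration=1234, epoch=6,
                        alpha=0.1, beta=1.0)
sd = toy_state_dict(state, ["alpha", "beta"])
print(list(sd))  # ['epoch_length', 'max_epochs', 'iteration', 'alpha', 'beta']
```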