CosineAnnealingScheduler#

class ignite.handlers.param_scheduler.CosineAnnealingScheduler(optimizer, param_name, start_value, end_value, cycle_size, cycle_mult=1.0, start_value_mult=1.0, end_value_mult=1.0, save_history=False, param_group_index=None)[source]#

Anneals 'start_value' to 'end_value' over each cycle.
The annealing takes the form of the first half of a cosine wave (as suggested in [Smith17]).
- Parameters
optimizer (torch.optim.optimizer.Optimizer) – torch optimizer or any object with attribute param_groups as a sequence.
param_name (str) – name of optimizer's parameter to update.
start_value (float) – value at start of cycle.
end_value (float) – value at the end of the cycle.
cycle_size (int) – length of cycle.
cycle_mult (float) – ratio by which to change the cycle_size at the end of each cycle (default=1.0).
start_value_mult (float) – ratio by which to change the start value at the end of each cycle (default=1.0).
end_value_mult (float) – ratio by which to change the end value at the end of each cycle (default=1.0).
save_history (bool) – whether to log the parameter values to engine.state.param_history (default=False).
param_group_index (Optional[int]) – optimizer’s parameters group to use.
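The three *_mult ratios above compound at each cycle boundary. A plain-Python sketch of how they evolve (cycle_schedule is a hypothetical helper for illustration, not part of the ignite API):

```python
def cycle_schedule(cycle_size, start_value, end_value,
                   cycle_mult=1.0, start_value_mult=1.0, end_value_mult=1.0,
                   n_cycles=3):
    """Return the (cycle_size, start_value, end_value) used for each cycle,
    applying the *_mult ratios at every cycle boundary."""
    out = []
    for _ in range(n_cycles):
        out.append((cycle_size, start_value, end_value))
        cycle_size = int(cycle_size * cycle_mult)
        start_value *= start_value_mult
        end_value *= end_value_mult
    return out

# With cycle_mult=2.0, each cycle is twice as long as the previous one:
# cycle sizes 10, 20, 40, ...
schedule = cycle_schedule(10, 1e-1, 1e-3, cycle_mult=2.0)
```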
Note
If the scheduler is bound to an ‘ITERATION_*’ event, ‘cycle_size’ should usually be the number of batches in an epoch.
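Within a cycle, the value follows the first half of a cosine wave from start_value down to end_value. A plain-Python sketch of that formula (an illustration of the annealing shape, not ignite's exact internal code):

```python
import math

def cosine_anneal(start_value, end_value, cycle_size, t):
    """Annealed value at event index t within a cycle of length cycle_size,
    following the first half of a cosine wave."""
    progress = (t % cycle_size) / cycle_size
    return start_value + (end_value - start_value) / 2 * (1 - math.cos(math.pi * progress))

# Starts at start_value and decays smoothly toward end_value over the cycle.
values = [cosine_anneal(1e-1, 1e-3, 10, t) for t in range(10)]
```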
Examples:

    from ignite.handlers.param_scheduler import CosineAnnealingScheduler

    scheduler = CosineAnnealingScheduler(optimizer, 'lr', 1e-1, 1e-3, len(train_loader))
    # Anneals the learning rate from 1e-1 to 1e-3 over the course of 1 epoch.
    trainer.add_event_handler(Events.ITERATION_STARTED, scheduler)
    from ignite.handlers.param_scheduler import CosineAnnealingScheduler
    from ignite.handlers.param_scheduler import LinearCyclicalScheduler

    optimizer = SGD(
        [
            {"params": model.base.parameters(), 'lr': 0.001},
            {"params": model.fc.parameters(), 'lr': 0.01},
        ]
    )

    scheduler1 = LinearCyclicalScheduler(optimizer, 'lr', 1e-7, 1e-5, len(train_loader),
                                         param_group_index=0)
    trainer.add_event_handler(Events.ITERATION_STARTED, scheduler1, "lr (base)")

    scheduler2 = CosineAnnealingScheduler(optimizer, 'lr', 1e-5, 1e-3, len(train_loader),
                                          param_group_index=1)
    trainer.add_event_handler(Events.ITERATION_STARTED, scheduler2, "lr (fc)")
- Smith17
Smith, Leslie N. "Cyclical learning rates for training neural networks." 2017 IEEE Winter Conference on Applications of Computer Vision (WACV). IEEE, 2017.
New in version 0.5.1.
Methods

get_param – Method to get current optimizer's parameter value.
load_state_dict – Copies parameters from state_dict into this ParamScheduler.
plot_values – Method to plot simulated scheduled values during num_events events.
simulate_values – Method to simulate scheduled values during num_events events.
state_dict – Returns a dictionary containing a whole state of ParamScheduler.
load_state_dict(state_dict)#

Copies parameters from state_dict into this ParamScheduler.

- Parameters
state_dict (Mapping) – a dict containing parameters.
- Return type
None
classmethod plot_values(num_events, **scheduler_kwargs)#

Method to plot simulated scheduled values during num_events events.

This method requires the matplotlib package to be installed:
pip install matplotlib
- Parameters
num_events (int) – number of events during the simulation.
scheduler_kwargs (Mapping) – parameter scheduler configuration kwargs.
- Returns
matplotlib.lines.Line2D
- Return type
Any
Examples
    import matplotlib.pylab as plt

    plt.figure(figsize=(10, 7))

    LinearCyclicalScheduler.plot_values(num_events=50, param_name='lr',
                                        start_value=1e-1, end_value=1e-3,
                                        cycle_size=10)
classmethod simulate_values(num_events, **scheduler_kwargs)#

Method to simulate scheduled values during num_events events.
- Parameters
num_events (int) – number of events during the simulation.
scheduler_kwargs (Any) – parameter scheduler configuration kwargs.
- Returns
event_index, value
- Return type
List[List[int]]
Examples:
    import numpy as np
    import matplotlib.pyplot as plt

    lr_values = np.array(LinearCyclicalScheduler.simulate_values(num_events=50, param_name='lr',
                                                                 start_value=1e-1, end_value=1e-3,
                                                                 cycle_size=10))
    plt.plot(lr_values[:, 0], lr_values[:, 1], label="learning rate")
    plt.xlabel("events")
    plt.ylabel("values")
    plt.legend()