Source code for torch.optim.optimizer

from collections import defaultdict

import torch
from copy import deepcopy
from itertools import chain
from torch.autograd import Variable

class _RequiredParameter(object):
    """Singleton type of the ``required`` sentinel.

    Optimizer subclasses place ``required`` in their ``defaults`` dict for
    hyperparameters that have no default value; ``Optimizer.__init__`` then
    checks each parameter group with ``default is required`` and raises if the
    group does not supply the option itself.
    """

    def __repr__(self):
        # A stable, readable repr instead of ``<object object at 0x...>``,
        # which changes from run to run and is meaningless when printed.
        return "<required parameter>"


required = _RequiredParameter()


class Optimizer(object):
    """Base class for all optimizers.

    Arguments:
        params (iterable): an iterable of :class:`Variable` s or
            :class:`dict` s. Specifies what Variables should be optimized.
        defaults: (dict): a dict containing default values of optimization
            options (used when a parameter group doesn't specify them).
    """

    def __init__(self, params, defaults):
        # A single Variable/tensor must be wrapped in an iterable by the caller.
        if isinstance(params, Variable) or torch.is_tensor(params):
            raise TypeError("params argument given to the optimizer should be "
                            "an iterable of Variables or dicts, but got " +
                            torch.typename(params))

        # Per-parameter optimizer state; missing entries default to {}.
        self.state = defaultdict(dict)
        self.param_groups = list(params)
        if len(self.param_groups) == 0:
            raise ValueError("optimizer got an empty parameter list")
        # A flat iterable of parameters becomes a single parameter group.
        if not isinstance(self.param_groups[0], dict):
            self.param_groups = [{'params': self.param_groups}]

        # Normalize every group's 'params' to a list and make sure no
        # parameter is shared between groups.
        param_set = set()
        for group in self.param_groups:
            if isinstance(group['params'], Variable):
                group['params'] = [group['params']]
            else:
                group['params'] = list(group['params'])
            group_set = set(group['params'])
            if not param_set.isdisjoint(group_set):
                raise ValueError("some parameters appear in more than one "
                                 "parameter group")
            param_set.update(group_set)

        # Fill in defaults; options whose default is the ``required``
        # sentinel must be given explicitly by every group.
        for name, default in defaults.items():
            for i, group in enumerate(self.param_groups):
                if default is required and name not in group:
                    raise ValueError("parameter group " + str(i) + " didn't "
                                     "specify a value of required optimization parameter " +
                                     name)
                else:
                    group.setdefault(name, default)

        # Only leaf Variables that require gradients can be optimized.
        for group in self.param_groups:
            for param in group['params']:
                if not isinstance(param, Variable):
                    raise TypeError("optimizer can only optimize Variables, "
                                    "but one of the params is " + torch.typename(param))
                if not param.requires_grad:
                    raise ValueError("optimizing a parameter that doesn't "
                                     "require gradients")
                if not param.is_leaf:
                    raise ValueError("can't optimize a non-leaf Variable")

    def __getstate__(self):
        # Pickle only the optimizer's state and its parameter groups.
        return {
            'state': self.state,
            'param_groups': self.param_groups,
        }

    def __setstate__(self, state):
        # Restore attributes from a dict shaped like __getstate__'s output.
        self.__dict__.update(state)

    def state_dict(self):
        """Returns the state of the optimizer as a :class:`dict`.

        It contains two entries:

        * state - a dict holding current optimization state. Its content
            differs between optimizer classes. Keys that are Variables are
            replaced by their ``id()``.
        * param_groups - a list containing all parameter groups; each group
            is a dict whose ``'params'`` entry holds parameter ids instead
            of Variables.

        The per-parameter state values are returned by reference, not copied.
        """
        # Save ids instead of Variables
        def pack_group(group):
            packed = {k: v for k, v in group.items() if k != 'params'}
            packed['params'] = [id(p) for p in group['params']]
            return packed
        param_groups = [pack_group(g) for g in self.param_groups]
        # Remap state to use ids as keys
        packed_state = {(id(k) if isinstance(k, Variable) else k): v
                        for k, v in self.state.items()}
        return {
            'state': packed_state,
            'param_groups': param_groups,
        }

    def load_state_dict(self, state_dict):
        """Loads the optimizer state.

        Arguments:
            state_dict (dict): optimizer state. Should be an object returned
                from a call to :meth:`state_dict`.

        Raises:
            ValueError: if ``state_dict`` has a different number of parameter
                groups, or any group has a different number of parameters,
                than this optimizer.
        """
        # deepcopy, to be consistent with module API
        state_dict = deepcopy(state_dict)
        # Validate the state_dict
        groups = self.param_groups
        saved_groups = state_dict['param_groups']

        if len(groups) != len(saved_groups):
            raise ValueError("loaded state dict has a different number of "
                             "parameter groups")
        param_lens = (len(g['params']) for g in groups)
        saved_lens = (len(g['params']) for g in saved_groups)
        if any(p_len != s_len for p_len, s_len in zip(param_lens, saved_lens)):
            raise ValueError("loaded state dict contains a parameter group "
                             "that doesn't match the size of optimizer's group")

        # Map saved parameter ids to this optimizer's live parameters,
        # pairing them positionally (group by group, parameter by parameter).
        id_map = {old_id: p for old_id, p in
                  zip(chain(*(g['params'] for g in saved_groups)),
                      chain(*(g['params'] for g in groups)))}
        # Rebuild the state as a defaultdict(dict), matching __init__. A plain
        # dict would make ``self.state[p]`` raise KeyError for any parameter
        # with no saved state (e.g. a state_dict taken before the first step).
        state = defaultdict(dict)
        for k, v in state_dict['state'].items():
            state[id_map.get(k, k)] = v

        # Update parameter groups, setting their 'params' value
        def update_group(group, new_group):
            new_group['params'] = group['params']
            return new_group
        param_groups = [
            update_group(g, ng) for g, ng in zip(groups, saved_groups)]
        self.__setstate__({'state': state, 'param_groups': param_groups})

    def zero_grad(self):
        """Clears the gradients of all optimized :class:`Variable` s."""
        for group in self.param_groups:
            for p in group['params']:
                if p.grad is not None:
                    if p.grad.volatile:
                        p.grad.data.zero_()
                    else:
                        # Replace (rather than zero in place) a non-volatile
                        # grad with a fresh zero-filled Variable of the same
                        # size; presumably so any graph referencing the old
                        # grad is left untouched.
                        data = p.grad.data
                        p.grad = Variable(data.new().resize_as_(data).zero_())

    def step(self, closure=None):
        """Performs a single optimization step (parameter update).

        Arguments:
            closure (callable): A closure that reevaluates the model and
                returns the loss. Optional for most optimizers.
        """
        raise NotImplementedError