This is a PyTorch implementation of the popular optimizer Adam from the paper *Adam: A Method for Stochastic Optimization*.
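The Adam update is

$$
\begin{aligned}
m_t &\leftarrow \beta_1 m_{t-1} + (1 - \beta_1) \cdot g_t \\
v_t &\leftarrow \beta_2 v_{t-1} + (1 - \beta_2) \cdot g_t^2 \\
\hat{m}_t &\leftarrow \frac{m_t}{1-\beta_1^t} \\
\hat{v}_t &\leftarrow \frac{v_t}{1-\beta_2^t} \\
\theta_t &\leftarrow \theta_{t-1} - \alpha \cdot \frac{\hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon}
\end{aligned}
$$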
where $g_t$ is the gradient with respect to the parameters $\theta_{t-1}$ at step $t$, and $\alpha$, $\beta_1$, $\beta_2$ and $\epsilon$ are scalar hyper-parameters. $m_t$ and $v_t$ are the first and second order moments, and $\hat{m}_t$ and $\hat{v}_t$ are the bias-corrected moments. $\epsilon$ prevents division by zero, but it also acts as a hyper-parameter that works against variance in the gradients.
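To make the update concrete, here is a minimal scalar sketch of one Adam step in plain Python. It follows the update rule for $m_t$, $v_t$, $\hat{m}_t$, $\hat{v}_t$ and $\theta_t$ term by term. The helper `adam_step` and its signature are hypothetical illustrations and not part of labml_nn. The implementation further down works on PyTorch tensors instead.

```python
import math


def adam_step(theta: float, grad: float, m: float, v: float, t: int,
              lr: float = 1e-3, beta1: float = 0.9, beta2: float = 0.999,
              eps: float = 1e-8):
    """One Adam step on a scalar parameter (hypothetical helper).

    `t` is the 1-based step count; returns (new_theta, new_m, new_v).
    """
    # First and second moment estimates (exponential moving averages)
    m = beta1 * m + (1 - beta1) * grad
    v = beta2 * v + (1 - beta2) * grad * grad
    # Bias-corrected moments
    m_hat = m / (1 - beta1 ** t)
    v_hat = v / (1 - beta2 ** t)
    # Parameter update
    theta = theta - lr * m_hat / (math.sqrt(v_hat) + eps)
    return theta, m, v


# Minimize f(theta) = theta^2, whose gradient is 2 * theta
theta, m, v = 1.0, 0.0, 0.0
for t in range(1, 4):
    theta, m, v = adam_step(theta, 2 * theta, m, v, t)
    print(t, theta)
```

On the first step $\hat{m}_1 = g_1$ and $\hat{v}_1 = g_1^2$. The parameter therefore moves by almost exactly $\alpha$, whatever the scale of the gradient.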
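The effective step taken, assuming $\epsilon = 0$, is

$$\Delta_t = \alpha \cdot \frac{\hat{m}_t}{\sqrt{\hat{v}_t}}$$

This is bounded by

$$\vert \Delta_t \vert \le \alpha \cdot \frac{1-\beta_1}{\sqrt{1-\beta_2}}$$

when $1-\beta_1 \gt \sqrt{1-\beta_2}$, and by

$$\vert \Delta_t \vert \le \alpha$$

otherwise. In most common scenarios,

$$\vert \Delta_t \vert \approx \alpha$$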
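The bound can be checked numerically. The sketch below uses plain Python with the hypothetical helpers `effective_step` and `step_bound`. It replays a gradient sequence through the moment updates with $\epsilon = 0$ and compares $\vert\Delta_t\vert$ against $\alpha (1-\beta_1)/\sqrt{1-\beta_2}$, which is about $3.16\alpha$ for the defaults $\beta_1 = 0.9$, $\beta_2 = 0.999$. Two cases are shown:

* A constant gradient gives $\alpha$ (up to rounding).
* A gradient that is zero for 999 steps and then spikes to 1 gives a larger step, which still stays under the bound.

```python
import math


def effective_step(grads, lr=1e-3, beta1=0.9, beta2=0.999):
    """|Delta_t| = lr * |m_hat| / sqrt(v_hat) after feeding `grads`, with eps = 0.

    `grads` must contain at least one non-zero gradient.
    """
    if not grads:
        raise ValueError("need at least one gradient")
    m = v = 0.0
    for g in grads:
        m = beta1 * m + (1 - beta1) * g
        v = beta2 * v + (1 - beta2) * g * g
    t = len(grads)
    m_hat = m / (1 - beta1 ** t)
    v_hat = v / (1 - beta2 ** t)
    return lr * abs(m_hat) / math.sqrt(v_hat)


def step_bound(lr=1e-3, beta1=0.9, beta2=0.999):
    """Upper bound on |Delta_t| from the Adam paper's analysis."""
    if 1 - beta1 > math.sqrt(1 - beta2):
        return lr * (1 - beta1) / math.sqrt(1 - beta2)
    return lr


constant = effective_step([0.5] * 100)        # close to lr
spike = effective_step([0.0] * 999 + [1.0])   # sparse gradient: larger step
print(constant, spike, step_bound())
```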
```python
import math
from typing import Dict, Any, Tuple, Optional

import torch
from labml import tracker
from torch import nn

from labml_nn.optimizers import GenericAdaptiveOptimizer, WeightDecay
```
We extend the class `GenericAdaptiveOptimizer` defined in `__init__.py` to implement the Adam optimizer.

```python
class Adam(GenericAdaptiveOptimizer):
```
* `params` is the list of parameters
* `lr` is the learning rate $\alpha$
* `betas` is a tuple of ($\beta_1$, $\beta_2$)
* `eps` is $\hat{\epsilon}$ or $\epsilon$ based on `optimized_update`
* `weight_decay` is an instance of class `WeightDecay` defined in `__init__.py`
* `optimized_update` is a flag that selects whether to apply the bias correction of the second moment after adding $\epsilon$, using the optimized computation
* `defaults` is a dictionary of defaults for group values. This is useful when you want to extend the class `Adam`.

```python
    def __init__(self, params,
                 lr: float = 1e-3, betas: Tuple[float, float] = (0.9, 0.999), eps: float = 1e-16,
                 weight_decay: WeightDecay = WeightDecay(),
                 optimized_update: bool = True,
                 defaults: Optional[Dict[str, Any]] = None):
        defaults = {} if defaults is None else defaults
        defaults.update(weight_decay.defaults())
        super().__init__(params, defaults, lr, betas, eps)

        self.weight_decay = weight_decay
        self.optimized_update = optimized_update
```
Initialize the state of a parameter.

* `state` is the optimizer state of the parameter (tensor)
* `group` stores optimizer attributes of the parameter group
* `param` is the parameter tensor $\theta_{t-1}$

```python
    def init_state(self, state: Dict[str, Any], group: Dict[str, Any], param: nn.Parameter):
        # This is the number of optimizer steps taken on the parameter, $t$
        state['step'] = 0
        # Exponential moving average of gradients, $m_t$
        state['exp_avg'] = torch.zeros_like(param, memory_format=torch.preserve_format)
        # Exponential moving average of squared gradient values, $v_t$
        state['exp_avg_sq'] = torch.zeros_like(param, memory_format=torch.preserve_format)
```
Calculate $m_t$ and $v_t$.

* `state` is the optimizer state of the parameter (tensor)
* `group` stores optimizer attributes of the parameter group
* `grad` is the current gradient tensor $g_t$ for the parameter $\theta_{t-1}$

```python
    def get_mv(self, state: Dict[str, Any], group: Dict[str, Any], grad: torch.Tensor):
        # Get $\beta_1$ and $\beta_2$
        beta1, beta2 = group['betas']

        # Get $m_{t-1}$ and $v_{t-1}$
        m, v = state['exp_avg'], state['exp_avg_sq']

        # In-place calculation of $m_t$
        # $m_t \leftarrow \beta_1 m_{t-1} + (1 - \beta_1) \cdot g_t$
        m.mul_(beta1).add_(grad, alpha=1 - beta1)
        # In-place calculation of $v_t$
        # $v_t \leftarrow \beta_2 v_{t-1} + (1 - \beta_2) \cdot g_t^2$
        v.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)

        return m, v
```
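Because $m_0$ and $v_0$ are initialized to zero in `init_state`, these moving averages are biased toward zero in the early steps. For a constant gradient $g$, unrolling the recurrences gives $m_t = (1-\beta_1^t)\, g$ and $v_t = (1-\beta_2^t)\, g^2$. The plain-Python sketch below checks this and shows that dividing by $1-\beta_1^t$ and $1-\beta_2^t$ removes the bias. It uses a hypothetical `moments` helper that mirrors `get_mv` on floats instead of tensors.

```python
def moments(grads, beta1=0.9, beta2=0.999):
    """Run the m_t / v_t recurrences from a zero initial state; return final (m, v)."""
    m = v = 0.0
    for g in grads:
        m = beta1 * m + (1 - beta1) * g        # m_t = b1 * m_{t-1} + (1 - b1) * g_t
        v = beta2 * v + (1 - beta2) * g * g    # v_t = b2 * v_{t-1} + (1 - b2) * g_t^2
    return m, v


g, t = 3.0, 10
m, v = moments([g] * t)
print(m, v)                        # both below g and g^2: biased toward zero
print(m / (1 - 0.9 ** t))          # bias-corrected m_hat, ~3.0
print(v / (1 - 0.999 ** t))        # bias-corrected v_hat, ~9.0
```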
This returns the modified learning rate based on the state. For Adam this is just the specified learning rate for the parameter group, $\alpha$.
```python
    def get_lr(self, state: Dict[str, Any], group: Dict[str, Any]):
        return group['lr']
```
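* `state` is the optimizer state of the parameter (tensor)
* `group` stores optimizer attributes of the parameter group
* `param` is the parameter tensor $\theta_{t-1}$
* `m` and `v` are the uncorrected first and second moments $m_t$ and $v_t$

This computes the following:

$$\theta_t \leftarrow \theta_{t-1} - \alpha \cdot \frac{\hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon}$$

Since $\alpha$, $\beta_1$, $\beta_2$ and $\epsilon$ are scalars while the others are tensors, we rearrange this calculation. Folding the bias corrections into a single scalar step size saves element-wise tensor operations:

$$
\begin{aligned}
\theta_t &\leftarrow \theta_{t-1} - \alpha \cdot \frac{\hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon} \\
\theta_t &\leftarrow \theta_{t-1} - \alpha \cdot \frac{\frac{m_t}{1-\beta_1^t}}{\frac{\sqrt{v_t}}{\sqrt{1-\beta_2^t}} + \epsilon} \\
\theta_t &\leftarrow \theta_{t-1} - \alpha \frac{\sqrt{1-\beta_2^t}}{1-\beta_1^t} \cdot \frac{m_t}{\sqrt{v_t} + \hat{\epsilon}}
\end{aligned}
$$

Here $\hat{\epsilon} = \epsilon \sqrt{1-\beta_2^t}$ is what we should specify as the hyper-parameter.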
```python
    def adam_update(self, state: Dict[str, Any], group: Dict[str, Any], param: torch.nn.Parameter,
                    m: torch.Tensor, v: torch.Tensor):
        # Get $\beta_1$ and $\beta_2$
        beta1, beta2 = group['betas']
        # Bias correction term for $\hat{m}_t$, $1 - \beta_1^t$
        bias_correction1 = 1 - beta1 ** state['step']
        # Bias correction term for $\hat{v}_t$, $1 - \beta_2^t$
        bias_correction2 = 1 - beta2 ** state['step']

        # Get learning rate
        lr = self.get_lr(state, group)

        # Whether to optimize the computation
        if self.optimized_update:
            # $\sqrt{v_t} + \hat{\epsilon}$
            denominator = v.sqrt().add_(group['eps'])
            # $\alpha \frac{\sqrt{1-\beta_2^t}}{1-\beta_1^t}$
            step_size = lr * math.sqrt(bias_correction2) / bias_correction1
            # $\theta_t \leftarrow \theta_{t-1} - \alpha \frac{\sqrt{1-\beta_2^t}}{1-\beta_1^t} \cdot \frac{m_t}{\sqrt{v_t} + \hat{\epsilon}}$
            param.data.addcdiv_(m, denominator, value=-step_size)
        else:
            # Computation without optimization
            # $\frac{\sqrt{v_t}}{\sqrt{1-\beta_2^t}} + \epsilon$
            denominator = (v.sqrt() / math.sqrt(bias_correction2)).add_(group['eps'])
            # $\frac{\alpha}{1-\beta_1^t}$
            step_size = lr / bias_correction1
            # $\theta_t \leftarrow \theta_{t-1} - \alpha \cdot \frac{\hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon}$
            param.data.addcdiv_(m, denominator, value=-step_size)
```
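The two branches of `adam_update` compute the same update when $\hat{\epsilon} = \epsilon \sqrt{1-\beta_2^t}$. Multiplying the numerator and denominator of $\hat{m}_t / (\sqrt{\hat{v}_t} + \epsilon)$ by $\sqrt{1-\beta_2^t}$ moves the bias correction of $v_t$ into the scalar step size. The sketch below checks this on scalars in plain Python. `update_plain` and `update_optimized` are hypothetical names that mirror the `else` and `if` branches. If the same `eps` value is passed to both branches, they differ: the optimized branch then behaves like a larger $\epsilon$ that shrinks as $t$ grows.

```python
import math


def update_plain(m, v, t, lr, beta1, beta2, eps):
    """Step size times direction from the unoptimized branch."""
    denominator = math.sqrt(v) / math.sqrt(1 - beta2 ** t) + eps
    step_size = lr / (1 - beta1 ** t)
    return step_size * m / denominator


def update_optimized(m, v, t, lr, beta1, beta2, eps_hat):
    """Same quantity with the v_t bias correction folded into step_size."""
    denominator = math.sqrt(v) + eps_hat
    step_size = lr * math.sqrt(1 - beta2 ** t) / (1 - beta1 ** t)
    return step_size * m / denominator


m, v, t = 0.02, 1e-6, 5
lr, beta1, beta2, eps = 1e-3, 0.9, 0.999, 1e-3
eps_hat = eps * math.sqrt(1 - beta2 ** t)      # the matching epsilon-hat
a = update_plain(m, v, t, lr, beta1, beta2, eps)
b = update_optimized(m, v, t, lr, beta1, beta2, eps_hat)
print(a, b)
```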
Take an update step for a given parameter tensor.

* `state` is the optimizer state of the parameter (tensor)
* `group` stores optimizer attributes of the parameter group
* `grad` is the current gradient tensor $g_t$ for the parameter $\theta_{t-1}$
* `param` is the parameter tensor $\theta_{t-1}$

```python
    def step_param(self, state: Dict[str, Any], group: Dict[str, Any], grad: torch.Tensor, param: torch.nn.Parameter):
        # Calculate weight decay
        grad = self.weight_decay(param, grad, group)

        # Get $m_t$ and $v_t$
        m, v = self.get_mv(state, group, grad)

        # Increment $t$ the number of optimizer steps
        state['step'] += 1

        # Perform Adam update
        self.adam_update(state, group, param, m, v)
```
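Putting the pieces together, the sketch below mirrors the order of operations in `step_param` on a single scalar in plain Python:

1. Apply weight decay to the gradient.
2. Update $m_t$ and $v_t$.
3. Increment $t$.
4. Apply the optimized Adam update.

The `ScalarAdam` class is hypothetical. Weight decay here is assumed to be plain coupled L2, which adds $\lambda \theta_{t-1}$ to the gradient; this may not match the `WeightDecay` class in `__init__.py`.

```python
import math


class ScalarAdam:
    """Hypothetical scalar mirror of Adam.step_param (not part of labml_nn)."""

    def __init__(self, lr=1e-3, betas=(0.9, 0.999), eps=1e-16, weight_decay=0.0):
        self.lr, self.betas, self.eps = lr, betas, eps
        self.weight_decay = weight_decay  # assumed coupled L2 coefficient
        # Same state as init_state: t, m_t and v_t all start at zero
        self.state = {'step': 0, 'exp_avg': 0.0, 'exp_avg_sq': 0.0}

    def step_param(self, grad, param):
        state = self.state
        beta1, beta2 = self.betas
        # Weight decay (assumption: plain L2 term added to the gradient)
        grad = grad + self.weight_decay * param
        # Update m_t and v_t
        state['exp_avg'] = beta1 * state['exp_avg'] + (1 - beta1) * grad
        state['exp_avg_sq'] = beta2 * state['exp_avg_sq'] + (1 - beta2) * grad * grad
        # Increment t before the update so that 1 - beta^t is never zero
        state['step'] += 1
        t = state['step']
        # Optimized update using epsilon-hat
        denominator = math.sqrt(state['exp_avg_sq']) + self.eps
        step_size = self.lr * math.sqrt(1 - beta2 ** t) / (1 - beta1 ** t)
        return param - step_size * state['exp_avg'] / denominator


opt = ScalarAdam(lr=0.01)
theta = 1.0
for _ in range(10):
    theta = opt.step_param(2 * theta, theta)   # gradient of theta^2
print(opt.state['step'], theta)
```

The step counter is incremented before the update. If it were still $0$, both bias-correction terms $1-\beta_1^t$ and $1-\beta_2^t$ would be zero.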