Adam Optimizer

This is an implementation of the popular Adam optimizer from the paper Adam: A Method for Stochastic Optimization.

The Adam update is,

$$\begin{align}
m_t &\leftarrow \beta_1 m_{t-1} + (1 - \beta_1) \cdot g_t \\
v_t &\leftarrow \beta_2 v_{t-1} + (1 - \beta_2) \cdot g_t^2 \\
\hat{m}_t &\leftarrow \frac{m_t}{1 - \beta_1^t} \\
\hat{v}_t &\leftarrow \frac{v_t}{1 - \beta_2^t} \\
\theta_t &\leftarrow \theta_{t-1} - \alpha \cdot \frac{\hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon}
\end{align}$$

where $\alpha$, $\beta_1$, $\beta_2$ and $\epsilon$ are scalar hyper-parameters. $m_t$ and $v_t$ are the first and second order moments. $\hat{m}_t$ and $\hat{v}_t$ are the bias-corrected moments. $\epsilon$ is used as a fix for division by zero, but it also acts as a hyper-parameter that works against variance in the gradients.

The effective step taken, assuming $\epsilon = 0$, is

$$\Delta t = \alpha \cdot \frac{\hat{m}_t}{\sqrt{\hat{v}_t}}$$

This is bounded by

$$\vert \Delta t \vert \le \alpha \cdot \frac{1 - \beta_1}{\sqrt{1 - \beta_2}}$$

when $1-\beta_1 \gt \sqrt{1-\beta_2}$, and by $\vert \Delta t \vert \le \alpha$ otherwise. In most common scenarios, $\vert \Delta t \vert \lessapprox \alpha$.
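As a concrete illustration (not part of the implementation below), here is a minimal plain-Python sketch of the update on a single scalar parameter; the gradient values and hyper-parameters are made up and only serve to show that the effective step stays around $\alpha$:

import math

alpha, beta1, beta2, eps = 1e-3, 0.9, 0.999, 1e-8  # common default hyper-parameters
theta = 0.5      # parameter theta_0
m, v = 0.0, 0.0  # first and second moment estimates

for t, g in enumerate([10.0, 8.0, 12.0], start=1):  # made-up gradients g_t
    m = beta1 * m + (1 - beta1) * g        # m_t
    v = beta2 * v + (1 - beta2) * g * g    # v_t
    m_hat = m / (1 - beta1 ** t)           # bias-corrected first moment
    v_hat = v / (1 - beta2 ** t)           # bias-corrected second moment
    step = alpha * m_hat / (math.sqrt(v_hat) + eps)
    theta -= step
    # |step| stays close to alpha even though the raw gradients are ~10
    print(t, step, theta)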

import math
from typing import Dict, Any, Tuple, Optional

import torch
from labml import tracker
from torch import nn

from labml_nn.optimizers import GenericAdaptiveOptimizer, WeightDecay

Adam Optimizer

We extend the class GenericAdaptiveOptimizer defined in __init__.py to implement the Adam optimizer.

class Adam(GenericAdaptiveOptimizer):

Initialize the optimizer

  • params is the list of parameters
  • lr is the learning rate $\alpha$
  • betas is a tuple of ($\beta_1$, $\beta_2$)
  • eps is $\hat{\epsilon}$ or $\epsilon$ based on optimized_update
  • weight_decay is an instance of class WeightDecay defined in __init__.py
  • optimized_update is a flag that decides whether to optimize the bias correction of the second moment by doing it after adding $\epsilon$
  • defaults is a dictionary of default values for parameter group attributes. This is useful when you want to extend the class Adam.
    def __init__(self, params,
                 lr: float = 1e-3, betas: Tuple[float, float] = (0.9, 0.999), eps: float = 1e-16,
                 weight_decay: WeightDecay = WeightDecay(),
                 optimized_update: bool = True,
                 defaults: Optional[Dict[str, Any]] = None):
        defaults = {} if defaults is None else defaults
        defaults.update(weight_decay.defaults())
        super().__init__(params, defaults, lr, betas, eps)

        self.weight_decay = weight_decay
        self.optimized_update = optimized_update
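For reference, constructing the optimizer looks like the following; the module path labml_nn.optimizers.adam is an assumption based on the imports above, and the model is a placeholder:

from torch import nn

from labml_nn.optimizers import WeightDecay
from labml_nn.optimizers.adam import Adam  # assumed module path for this file

model = nn.Linear(10, 2)  # placeholder model
optimizer = Adam(model.parameters(), lr=1e-3, betas=(0.9, 0.999), eps=1e-16,
                 weight_decay=WeightDecay(), optimized_update=True)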

Initialize a parameter state

  • state is the optimizer state of the parameter (tensor)
  • group stores optimizer attributes of the parameter group
  • param is the parameter tensor $\theta_{t-1}$
    def init_state(self, state: Dict[str, Any], group: Dict[str, Any], param: nn.Parameter):

This is the number of optimizer steps taken on the parameter, $t$

        state['step'] = 0

Exponential moving average of gradients, $m_t$

        state['exp_avg'] = torch.zeros_like(param, memory_format=torch.preserve_format)

Exponential moving average of squared gradient values, $v_t$

        state['exp_avg_sq'] = torch.zeros_like(param, memory_format=torch.preserve_format)

Calculate $m_t$ and $v_t$

  • state is the optimizer state of the parameter (tensor)
  • group stores optimizer attributes of the parameter group
  • grad is the current gradient tensor $g_t$ for the parameter $\theta_{t-1}$
    def get_mv(self, state: Dict[str, Any], group: Dict[str, Any], grad: torch.Tensor):

Get $\beta_1$ and $\beta_2$

        beta1, beta2 = group['betas']

Get $m_{t-1}$ and $v_{t-1}$

        m, v = state['exp_avg'], state['exp_avg_sq']

In-place calculation of $m_t$

        m.mul_(beta1).add_(grad, alpha=1 - beta1)

In-place calculation of $v_t$

        v.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)

        return m, v
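To see where the bias correction terms used below come from: with a constant gradient $g$, the recursions above give $m_t = (1-\beta_1^t) \cdot g$ and $v_t = (1-\beta_2^t) \cdot g^2$, so dividing by $1-\beta_1^t$ and $1-\beta_2^t$ removes the bias introduced by the zero initialization. A small standalone check with made-up values, independent of the class:

beta1, beta2, g = 0.9, 0.999, 3.0
m, v = 0.0, 0.0

for t in range(1, 100):
    m = beta1 * m + (1 - beta1) * g      # same recursion as in get_mv
    v = beta2 * v + (1 - beta2) * g * g
    # Bias-corrected estimates recover g and g^2 exactly for a constant gradient
    assert abs(m / (1 - beta1 ** t) - g) < 1e-9
    assert abs(v / (1 - beta2 ** t) - g * g) < 1e-9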

Get learning rate

This returns the modified learning rate based on the state. For Adam this is just the specified learning rate for the parameter group, $\alpha$.

    def get_lr(self, state: Dict[str, Any], group: Dict[str, Any]):
        return group['lr']
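Because the learning rate comes from this single method, subclasses can implement schedules by overriding it alone. A hypothetical sketch of a linear warmup variant (illustrative only, not part of this file):

class AdamWithWarmup(Adam):
    """Adam with a linear learning-rate warmup (illustrative sketch)."""

    def __init__(self, params, warmup_steps: int = 1_000, **kwargs):
        self.warmup_steps = warmup_steps
        super().__init__(params, **kwargs)

    def get_lr(self, state: Dict[str, Any], group: Dict[str, Any]):
        # Scale the group's learning rate linearly over the first
        # `warmup_steps` optimizer steps taken on this parameter.
        warmup_factor = min(1.0, state['step'] / self.warmup_steps)
        return group['lr'] * warmup_factor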

Do the Adam parameter update

  • state is the optimizer state of the parameter (tensor)
  • group stores optimizer attributes of the parameter group
  • param is the parameter tensor $\theta_{t-1}$
  • m and v are the uncorrected first and second moments $m_t$ and $v_t$.

This computes the following

$$\theta_t \leftarrow \theta_{t-1} - \alpha \cdot \frac{\hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon}$$

Since $\alpha$, $\beta_1$, $\beta_2$ and $\epsilon$ are scalars and the others are tensors, we modify this calculation to optimize the computation.

$$\begin{align}
\theta_t &\leftarrow \theta_{t-1} - \alpha \cdot \frac{\hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon} \\
&= \theta_{t-1} - \alpha \cdot \frac{m_t / (1-\beta_1^t)}{\sqrt{v_t / (1-\beta_2^t)} + \epsilon} \\
&= \theta_{t-1} - \alpha \cdot \frac{\sqrt{1-\beta_2^t}}{1-\beta_1^t} \cdot \frac{m_t}{\sqrt{v_t} + \hat{\epsilon}}
\end{align}$$

where $\hat{\epsilon} = \sqrt{1-\beta_2^t} \cdot \epsilon$ is what we should specify as the hyper-parameter.

    def adam_update(self, state: Dict[str, Any], group: Dict[str, Any], param: torch.nn.Parameter,
                    m: torch.Tensor, v: torch.Tensor):

Get $\beta_1$ and $\beta_2$

        beta1, beta2 = group['betas']

Bias correction term for $\hat{m}_t$, $1 - \beta_1^t$

        bias_correction1 = 1 - beta1 ** state['step']

Bias correction term for $\hat{v}_t$, $1 - \beta_2^t$

        bias_correction2 = 1 - beta2 ** state['step']

Get learning rate

        lr = self.get_lr(state, group)

Whether to optimize the computation

        if self.optimized_update:

$\sqrt{v_t} + \hat{\epsilon}$

            denominator = v.sqrt().add_(group['eps'])

$\alpha \frac{\sqrt{1-\beta_2^t}}{1-\beta_1^t}$

            step_size = lr * math.sqrt(bias_correction2) / bias_correction1

$\theta_t \leftarrow \theta_{t-1} - \alpha \frac{\sqrt{1-\beta_2^t}}{1-\beta_1^t} \cdot \frac{m_t}{\sqrt{v_t} + \hat{\epsilon}}$

            param.data.addcdiv_(m, denominator, value=-step_size)

Computation without optimization

        else:

$\frac{\sqrt{v_t}}{\sqrt{1-\beta_2^t}} + \epsilon$

            denominator = (v.sqrt() / math.sqrt(bias_correction2)).add_(group['eps'])

$\frac{\alpha}{1-\beta_1^t}$

            step_size = lr / bias_correction1

$\theta_t \leftarrow \theta_{t-1} - \alpha \cdot \frac{\hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon}$

            param.data.addcdiv_(m, denominator, value=-step_size)
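As a sanity check on the derivation above, the two branches compute (numerically) the same update when the optimized path uses $\hat{\epsilon} = \sqrt{1-\beta_2^t} \cdot \epsilon$; in the class itself, group['eps'] is simply treated as $\hat{\epsilon}$. A small standalone torch sketch with made-up values:

import math
import torch

beta1, beta2, eps, lr, t = 0.9, 0.999, 1e-8, 1e-3, 10
m, v = torch.randn(5), torch.rand(5)

bc1, bc2 = 1 - beta1 ** t, 1 - beta2 ** t

# Optimized form: fold the second-moment bias correction into the step size
eps_hat = math.sqrt(bc2) * eps
optimized = lr * math.sqrt(bc2) / bc1 * m / (v.sqrt() + eps_hat)

# Plain form: apply both bias corrections explicitly
plain = lr / bc1 * m / (v.sqrt() / math.sqrt(bc2) + eps)

assert torch.allclose(optimized, plain)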

Take an update step for a given parameter tensor

  • state is the optimizer state of the parameter (tensor)
  • group stores optimizer attributes of the parameter group
  • grad is the current gradient tensor $g_t$ for the parameter $\theta_{t-1}$
  • param is the parameter tensor $\theta_{t-1}$
    def step_param(self, state: Dict[str, Any], group: Dict[str, Any], grad: torch.Tensor, param: torch.nn.Parameter):

Calculate weight decay

        grad = self.weight_decay(param, grad, group)

Get $m_t$ and $v_t$

        m, v = self.get_mv(state, group, grad)

Increment $t$ the number of optimizer steps

        state['step'] += 1

Perform Adam update

        self.adam_update(state, group, param, m, v)
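Finally, a minimal end-to-end usage sketch, assuming this file lives at labml_nn/optimizers/adam.py and that GenericAdaptiveOptimizer follows the standard torch.optim.Optimizer interface (zero_grad / step); the model and data are placeholders:

import torch
from torch import nn
from torch.nn import functional as F

from labml_nn.optimizers.adam import Adam  # assumed module path

model = nn.Sequential(nn.Linear(4, 16), nn.ReLU(), nn.Linear(16, 1))
optimizer = Adam(model.parameters(), lr=1e-3)

x, y = torch.randn(32, 4), torch.randn(32, 1)  # dummy batch
for _ in range(100):
    optimizer.zero_grad()
    loss = F.mse_loss(model(x), y)
    loss.backward()
    optimizer.step()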