This is an implementation of the paper On the Convergence of Adam and Beyond.
We implement this as an extension to our Adam optimizer implementation. The implementation itself is really small since it’s very similar to Adam.
We also have an implementation of the synthetic example described in the paper where Adam fails to converge.
from typing import Dict

import torch
from torch import nn

from labml_nn.optimizers import WeightDecay
from labml_nn.optimizers.adam import Adam
This class extends the Adam optimizer defined in `adam.py`. The Adam optimizer in turn extends the class `GenericAdaptiveOptimizer` defined in `__init__.py`.
class AMSGrad(Adam):
* `params` is the list of parameters
* `lr` is the learning rate $\alpha$
* `betas` is a tuple of ($\beta_1$, $\beta_2$)
* `eps` is $\hat{\epsilon}$ or $\epsilon$ based on `optimized_update`
* `weight_decay` is an instance of class `WeightDecay` defined in `__init__.py`
* `amsgrad` is a flag indicating whether to use AMSGrad or fall back to plain Adam
* `defaults` is a dictionary of default values for parameter groups. This is useful when you want to extend the class `Adam`.
    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-16,
                 weight_decay: WeightDecay = WeightDecay(),
                 optimized_update: bool = True,
                 amsgrad=True, defaults=None):
        defaults = {} if defaults is None else defaults
        defaults.update(dict(amsgrad=amsgrad))

        super().__init__(params, lr, betas, eps, weight_decay, optimized_update, defaults)
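The `defaults` handling above follows a simple pattern: each class in the hierarchy adds its own hyper-parameters to a shared dictionary before calling its parent. Here is a minimal sketch of that pattern in plain Python. The classes `BaseOptimizerSketch` and `ChildOptimizerSketch` and the `use_max` flag are hypothetical, made up for illustration, and are not part of `labml_nn`.

```python
class BaseOptimizerSketch:
    """Hypothetical stand-in for a base optimizer that stores per-group defaults."""

    def __init__(self, lr=1e-3, defaults=None):
        defaults = {} if defaults is None else defaults
        # The base class adds its own hyper-parameters to whatever subclasses passed in
        defaults.update(dict(lr=lr))
        self.defaults = defaults


class ChildOptimizerSketch(BaseOptimizerSketch):
    """Hypothetical subclass that contributes one extra default, similar to `amsgrad`."""

    def __init__(self, lr=1e-3, use_max=True, defaults=None):
        # Using `None` rather than `{}` as the default avoids sharing one dict between instances
        defaults = {} if defaults is None else defaults
        defaults.update(dict(use_max=use_max))
        super().__init__(lr=lr, defaults=defaults)


opt = ChildOptimizerSketch(lr=1e-2, use_max=False)
print(opt.defaults)
```

A further subclass could pass its own `defaults` dictionary down the same way, which is presumably why the constructor accepts it as an argument.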
* `state` is the optimizer state of the parameter (tensor)
* `group` stores optimizer attributes of the parameter group
* `param` is the parameter tensor $\theta_{t-1}$
    def init_state(self, state: Dict[str, any], group: Dict[str, any], param: nn.Parameter):
Call `init_state` of the Adam optimizer which we are extending
        super().init_state(state, group, param)
If the `amsgrad` flag is `True` for this parameter group, we maintain the maximum of the exponential moving average of the squared gradient
        if group['amsgrad']:
            state['max_exp_avg_sq'] = torch.zeros_like(param, memory_format=torch.preserve_format)
* `state` is the optimizer state of the parameter (tensor)
* `group` stores optimizer attributes of the parameter group
* `grad` is the current gradient tensor $g_t$ for the parameter $\theta_{t-1}$
    def get_mv(self, state: Dict[str, any], group: Dict[str, any], grad: torch.Tensor):
Get $m_t$ and $v_t$ from Adam
        m, v = super().get_mv(state, group, grad)
If this parameter group is using `amsgrad`
        if group['amsgrad']:
Get $\max(v_1, v_2, \dots, v_{t-1})$.
🗒 The paper uses the notation $\hat{v}_t$ for this, which we don’t use here because it could be confused with Adam’s use of the same notation for the bias-corrected exponential moving average.
            v_max = state['max_exp_avg_sq']
Calculate $\max(v_1, v_2, \dots, v_{t-1}, v_t)$.
🤔 I feel you should be taking / maintaining the max of the bias-corrected exponential moving average of the squared gradient. But this is also how it’s implemented in PyTorch. I guess it doesn’t really matter, since bias correction only increases the value and only makes an actual difference during the first few steps of training.
            torch.maximum(v_max, v, out=v_max)

            return m, v_max
        else:
Fall back to Adam if the parameter group is not using `amsgrad`
            return m, v
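To make the role of the running maximum concrete, here is a minimal scalar sketch of one AMSGrad step in plain Python. It is not the code of this class: the function name `amsgrad_scalar_step`, the layout of the `state` dictionary, and the default hyper-parameters are assumptions for illustration. It uses the textbook bias-corrected update rather than the `optimized_update` variant.

```python
import math


def amsgrad_scalar_step(theta, grad, state, lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8):
    """One AMSGrad update for a single scalar parameter (illustrative sketch only)."""
    state["t"] += 1
    t = state["t"]
    # Exponential moving averages of the gradient and the squared gradient, as in Adam
    state["m"] = beta1 * state["m"] + (1 - beta1) * grad
    state["v"] = beta2 * state["v"] + (1 - beta2) * grad * grad
    # AMSGrad: keep the running maximum of the uncorrected v_t
    state["v_max"] = max(state["v_max"], state["v"])
    # Bias corrections are applied after taking the maximum
    bias1 = 1 - beta1 ** t
    bias2 = 1 - beta2 ** t
    denominator = math.sqrt(state["v_max"] / bias2) + eps
    return theta - lr * (state["m"] / bias1) / denominator


state = {"t": 0, "m": 0.0, "v": 0.0, "v_max": 0.0}
theta = 0.0
# One large gradient followed by a few small ones
for g in [1010.0] + [-10.0] * 5:
    theta = amsgrad_scalar_step(theta, g, state, lr=1e-2, beta2=0.99)
    print(f"v={state['v']:.2f}  v_max={state['v_max']:.2f}  theta={theta:.5f}")
```

After the large gradient, `v` starts to decay while `v_max` stays where it was, so the denominator of the update does not shrink again.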
This is the synthetic experiment described in the paper, which shows a scenario where Adam fails to converge.
The paper (and Adam) formulates the optimization problem as minimizing the expected value of a function, $\mathbb{E}[f(\theta)]$, with respect to the parameters $\theta$. In the stochastic training setting we do not get hold of the function $f$ itself; that is, when you are optimizing a neural network, $f$ would be the function on the entire dataset. What we actually evaluate is a mini-batch, so the function we see is a realization of the stochastic $f$. This is why we talk about an expected value. Let the function realizations be $f_1, f_2, \dots, f_T$ for each time step of training.
We measure the performance of the optimizer as the regret,
$$R(T) = \sum_{t=1}^T \big[ f_t(\theta_t) - f_t(\theta^*) \big]$$
where $\theta_t$ is the parameters at time step $t$, and $\theta^*$ is the optimal parameters that minimize $\mathbb{E}[f(\theta)]$.
Now let’s define the synthetic problem,
$$f_t(x) = \begin{cases} 1010 x, & \text{for } t \bmod 101 = 1 \\ -10 x, & \text{otherwise} \end{cases}$$
where $-1 \le x \le +1$. The optimal solution is $x = -1$.
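As a quick sanity check of the problem definition, the following plain-Python sketch writes out $f_t$, its gradient, and the regret $R(T)$, and sums the gradients over one period of 101 steps. The helper names (`f`, `grad_f`, `total_regret`) are illustrative and not part of the implementation. Here $t$ is counted from 1 to match the sum in $R(T)$.

```python
def f(t: int, x: float) -> float:
    """Synthetic loss f_t(x): 1010x when t mod 101 == 1, otherwise -10x."""
    return 1010.0 * x if t % 101 == 1 else -10.0 * x


def grad_f(t: int) -> float:
    """Gradient of f_t with respect to x (constant in x because f_t is linear)."""
    return 1010.0 if t % 101 == 1 else -10.0


def total_regret(xs, x_star=-1.0):
    """R(T) = sum over t of [f_t(x_t) - f_t(x*)] for iterates xs = [x_1, ..., x_T]."""
    return sum(f(t, x) - f(t, x_star) for t, x in enumerate(xs, start=1))


# One full period has one large positive gradient and 100 small negative ones
period_gradient_sum = sum(grad_f(t) for t in range(1, 102))
print(period_gradient_sum)

# Regret over one period when staying at x = -1, and when staying at x = +1
print(total_regret([-1.0] * 101), total_regret([1.0] * 101))
```

The gradients over a period sum to $1010 - 100 \times 10 = 10 > 0$, so on average the loss increases with $x$ and the constrained minimum is at $x = -1$, even though 100 out of every 101 gradients push $x$ toward $+1$.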
The following code runs Adam and AMSGrad on this problem.
def _synthetic_experiment(is_adam: bool):
Define $x$ parameter
    x = nn.Parameter(torch.tensor([.0]))
Optimal, $x^* = -1$
    x_star = nn.Parameter(torch.tensor([-1]), requires_grad=False)

    def func(t: int, x_: nn.Parameter):
        if t % 101 == 1:
            return (1010 * x_).sum()
        else:
            return (-10 * x_).sum()
Initialize the relevant optimizer
    if is_adam:
        optimizer = Adam([x], lr=1e-2, betas=(0.9, 0.99))
    else:
        optimizer = AMSGrad([x], lr=1e-2, betas=(0.9, 0.99))
$R(T)$
    total_regret = 0

    from labml import monit, tracker, experiment
Create experiment to record results
    with experiment.record(name='synthetic', comment='Adam' if is_adam else 'AMSGrad'):
Run for $10^7$ steps
        for step in monit.loop(10_000_000):
$f_t(\theta_t) - f_t(\theta^*)$
            regret = func(step, x) - func(step, x_star)
$R(T) = \sum_{t=1}^T \big[ f_t(\theta_t) - f_t(\theta^*) \big]$
            total_regret += regret.item()
Track results every 1,000 steps
            if (step + 1) % 1000 == 0:
                tracker.save(loss=regret, x=x, regret=total_regret / (step + 1))
Calculate gradients
            regret.backward()
Optimize
            optimizer.step()
Clear gradients
            optimizer.zero_grad()
Make sure $-1 \le x \le +1$
            x.data.clamp_(-1., +1.)
if __name__ == '__main__':
Run the synthetic experiment with Adam. Here are the results. You can see that Adam converges to $x = +1$.
    _synthetic_experiment(True)
Run the synthetic experiment with AMSGrad. Here are the results. You can see that AMSGrad converges to the true optimum $x = -1$.
    _synthetic_experiment(False)