未验证 提交 23941e43 编写于 作者: Y Yibing Liu 提交者: GitHub

Update lamb optimizer (#18333)

* Update lamb optimizer

test=develop, test=document_preview

* Regenerate api spec

test=develop, test=document_preview
上级 135a59ed
......@@ -861,7 +861,7 @@ paddle.fluid.optimizer.DGCMomentumOptimizer.backward (ArgSpec(args=['self', 'los
paddle.fluid.optimizer.DGCMomentumOptimizer.get_opti_var_name_list (ArgSpec(args=['self'], varargs=None, keywords=None, defaults=None), ('document', '6adf97f83acf6453d4a6a4b1070f3754'))
paddle.fluid.optimizer.DGCMomentumOptimizer.load (ArgSpec(args=['self', 'stat_dict'], varargs=None, keywords=None, defaults=None), ('document', '649a92cf7f1ea28666fd00c4ea01acde'))
paddle.fluid.optimizer.DGCMomentumOptimizer.minimize (ArgSpec(args=['self', 'loss', 'startup_program', 'parameter_list', 'no_grad_set', 'grad_clip'], varargs=None, keywords=None, defaults=(None, None, None, None)), ('document', 'b15cffad0903fc81af77a0580ceb2a9b'))
paddle.fluid.optimizer.LambOptimizer.__init__ (ArgSpec(args=['self', 'learning_rate', 'lamb_weight_decay', 'beta1', 'beta2', 'epsilon', 'regularization', 'name'], varargs=None, keywords=None, defaults=(0.001, 0.01, 0.9, 0.999, 1e-06, None, None)), ('document', '6adf97f83acf6453d4a6a4b1070f3754'))
paddle.fluid.optimizer.LambOptimizer.__init__ (ArgSpec(args=['self', 'learning_rate', 'lamb_weight_decay', 'beta1', 'beta2', 'epsilon', 'regularization', 'exclude_from_weight_decay_fn', 'name'], varargs=None, keywords=None, defaults=(0.001, 0.01, 0.9, 0.999, 1e-06, None, None, None)), ('document', '6adf97f83acf6453d4a6a4b1070f3754'))
paddle.fluid.optimizer.LambOptimizer.apply_gradients (ArgSpec(args=['self', 'params_grads'], varargs=None, keywords=None, defaults=None), ('document', 'bfe7305918552aaecfdaa22411dbe871'))
paddle.fluid.optimizer.LambOptimizer.apply_optimize (ArgSpec(args=['self', 'loss', 'startup_program', 'params_grads'], varargs=None, keywords=None, defaults=None), ('document', '5c46d1926a40f1f873ffe9f37ac89dae'))
paddle.fluid.optimizer.LambOptimizer.backward (ArgSpec(args=['self', 'loss', 'startup_program', 'parameter_list', 'no_grad_set', 'callbacks'], varargs=None, keywords=None, defaults=(None, None, None, None)), ('document', 'ba3a113d0229ff7bc9d39bda0a6d947f'))
......
......@@ -60,23 +60,13 @@ correction. For more information, please refer to https://arxiv.org/abs/1904.009
The updating of parameters follows:
$$
m_t^l &= \beta_1 m_{t - 1}^l + (1 - \beta_1)g_t^l \\
m_t &= \beta_1 m_{t - 1}+ (1 - \beta_1)g_t \\
v_t^l &= \beta_2 v_{t - 1}^l + (1 - \beta_2)g_t^l \odot g_t^l \\
v_t &= \beta_2 v_{t - 1} + (1 - \beta_2)g_t^2 \\
\widehat{m}_t^l &= m_t^l/(1 - \beta_1^t) \\
r_t &= \frac{m_t}{\sqrt{v_t}+\epsilon} \\
\widehat{v}_t^l &= v_t^l/(1 - \beta_2^t) \\
r_1 &= \left \| w_{t-1}^l \right \|_2 \\
r_2 &= \left \| \frac{\widehat{m}_t^l}{\sqrt{\widehat{v}_t^l+\epsilon}} + \lambda w_{t-1}^l \right \|_2 \\
r &= r_1 / r_2 \\
\eta^l &= r \times \eta \\
w_t^l &= w_{t-1}^l -\eta ^l \times (\frac{\widehat{m}_t^l}{\sqrt{\widehat{v}_t^l+\epsilon}} + \lambda w_{t-1}^l)
w_t &= w_{t-1} -\eta_t \frac{\left \| w_{t-1}\right \|}{\left \| r_t + \lambda w_{t-1}\right \|} (r_t + \lambda w_{t-1})
$$
where $m$ is the 1st moment, and $v$ the 2nd moment, $\eta$ the
......
......@@ -66,19 +66,14 @@ struct LambMomentUpdateFunctor {
T g = grad_[i];
T mom1 = moment1_[i];
T mom2 = moment2_[i];
T beta1_pow = *beta1_pow_;
T beta2_pow = *beta2_pow_;
T p = param_[i];
mom1 = beta1_ * mom1 + (1 - beta1_) * g;
mom2 = beta2_ * mom2 + (1 - beta2_) * g * g;
T mom1_h = mom1 / (1 - beta1_pow);
T mom2_h = mom2 / (1 - beta2_pow);
moment1_out_[i] = mom1;
moment2_out_[i] = mom2;
trust_ratio_div_[i] = mom1_h / sqrt(mom2_h + epsilon_) + weight_decay_ * p;
trust_ratio_div_[i] = mom1 / (sqrt(mom2) + epsilon_) + weight_decay_ * p;
}
};
......@@ -130,19 +125,14 @@ struct SparseLambMomentUpdateFunctor {
// The following code is same as dense
T mom1 = moment1_[i];
T mom2 = moment2_[i];
T beta1_pow = *beta1_pow_;
T beta2_pow = *beta2_pow_;
T p = param_[i];
mom1 = beta1_ * mom1 + (1 - beta1_) * g;
mom2 = beta2_ * mom2 + (1 - beta2_) * g * g;
T mom1_h = mom1 / (1 - beta1_pow);
T mom2_h = mom2 / (1 - beta2_pow);
moment1_out_[i] = mom1;
moment2_out_[i] = mom2;
trust_ratio_div_[i] = mom1_h / sqrt(mom2_h + epsilon_) + weight_decay_ * p;
trust_ratio_div_[i] = mom1 / (sqrt(mom2) + epsilon_) + weight_decay_ * p;
}
inline HOSTDEVICE void operator()(size_t i) const {
......
......@@ -2091,30 +2091,20 @@ class LambOptimizer(AdamOptimizer):
LAMB Optimizer is designed to scale up the batch size of training without losing
accuracy, which supports adaptive element-wise updating and accurate layer-wise
correction. For more information, please refer to `Reducing BERT Pre-Training
Time from 3 Days to 76 Minutes <https://arxiv.org/abs/1904.00962>`_ .
correction. For more information, please refer to `Large Batch Optimization for
Deep Learning: Training BERT in 76 minutes <https://arxiv.org/abs/1904.00962>`_ .
The updating of parameters follows:
.. math::
m_t^l & = \\beta_1 m_{t - 1}^l + (1 - \\beta_1)g_t^l
m_t &= \\beta_1 m_{t - 1}+ (1 - \\beta_1)g_t \\
v_t^l & = \\beta_2 v_{t - 1}^l + (1 - \\beta_2)g_t^l \odot g_t^l
v_t &= \\beta_2 v_{t - 1} + (1 - \\beta_2)g_t^2 \\
\\widehat{m}_t^l & = m_t^l/(1 - \\beta_1^t)
r_t &= \\frac{m_t}{\\sqrt{v_t}+\\epsilon} \\
\\widehat{v}_t^l & = v_t^l/(1 - \\beta_2^t)
r_1 & = \\left \| w_{t-1}^l \\right \|_2
r_2 & = \\left \| \\frac{\\widehat{m}_t^l}{\\sqrt{\\widehat{v}_t^l+\\epsilon}} + \\lambda w_{t-1}^l \\right \|_2
r & = r_1 / r_2
\\eta^l & = r \\times \\eta
w_t^l & = w_{t-1}^l -\\eta ^l \\times (\\frac{\\widehat{m}_t^l}{\\sqrt{\\widehat{v}_t^l+\\epsilon}} + \\lambda w_{t-1}^l)
w_t &= w_{t-1} -\\eta_t \\frac{\\left \| w_{t-1}\\right \|}{\\left \| r_t + \\lambda w_{t-1}\\right \|} (r_t + \\lambda w_{t-1})
where :math:`m` is the 1st moment, and :math:`v` the 2nd moment, :math:`\\eta` the
......@@ -2128,8 +2118,10 @@ class LambOptimizer(AdamOptimizer):
beta1 (float): The exponential decay rate for the 1st moment estimates.
beta2 (float): The exponential decay rate for the 2nd moment estimates.
epsilon (float): A small float value for numerical stability.
regularization: A Regularizer, such as
regularization (Regularizer): A Regularizer, such as
fluid.regularizer.L1DecayRegularizer.
exclude_from_weight_decay_fn (function): Exclude a parameter from weight
decay when **exclude_from_weight_decay_fn(parameter)** returns true.
name (str|None): An optional name prefix.
Examples:
......@@ -2141,11 +2133,16 @@ class LambOptimizer(AdamOptimizer):
hidden = fluid.layers.fc(input=data, size=10)
cost = fluid.layers.mean(hidden)
optimizer = fluid.optimizer.Lamb(learning_rate=0.002)
def exclude_fn(param):
return param.name.endswith('.b_0')
optimizer = fluid.optimizer.Lamb(learning_rate=0.002,
exclude_from_weight_decay_fn=exclude_fn)
optimizer.minimize(cost)
"""
_moment1_acc_str = "moment1"
_moment2_acc_str = "moment2"
# these two not used in op temporarily
_beta1_pow_acc_str = "beta1_pow_acc"
_beta2_pow_acc_str = "beta2_pow_acc"
......@@ -2156,6 +2153,7 @@ class LambOptimizer(AdamOptimizer):
beta2=0.999,
epsilon=1e-6,
regularization=None,
exclude_from_weight_decay_fn=None,
name=None):
assert learning_rate is not None
assert lamb_weight_decay is not None
......@@ -2171,6 +2169,7 @@ class LambOptimizer(AdamOptimizer):
name=name)
self.type = "lamb"
self._weight_decay = lamb_weight_decay
self._exclude_from_weight_decay_fn = exclude_from_weight_decay_fn
def _append_optimize_op(self, block, param_and_grad):
assert isinstance(block, framework.Block)
......@@ -2184,6 +2183,12 @@ class LambOptimizer(AdamOptimizer):
beta2_pow_acc = self._get_accumulator(self._beta2_pow_acc_str,
param_and_grad[0])
if self._exclude_from_weight_decay_fn is not None \
and self._exclude_from_weight_decay_fn(param_and_grad[0]):
weight_decay = 0.0
else:
weight_decay = self._weight_decay
# create the lamb optimize op
lamb_op = block.append_op(
type=self.type,
......@@ -2205,7 +2210,7 @@ class LambOptimizer(AdamOptimizer):
"beta1": self._beta1,
"beta2": self._beta2,
"epsilon": self._epsilon,
"weight_decay": self._weight_decay
"weight_decay": weight_decay
},
stop_gradient=True)
......
......@@ -140,15 +140,12 @@ def lamb_step(inputs, attributes):
moment1_out = beta1 * moment1 + (1 - beta1) * grad
moment2_out = beta2 * moment2 + (1 - beta2) * np.square(grad)
mom1_tmp = moment1_out / (1 - beta1_pow)
mom2_tmp = moment2_out / (1 - beta2_pow)
r_1 = np.linalg.norm(param)
r_2 = np.linalg.norm(mom1_tmp / np.sqrt(mom2_tmp + epsilon) + weight_decay *
param)
r_2 = np.linalg.norm(moment1_out / (np.sqrt(moment2_out) + epsilon) +
weight_decay * param)
lr_t = lr * r_1 / r_2
param_out = param - lr_t * (mom1_tmp / np.sqrt(mom2_tmp + epsilon) +
param_out = param - lr_t * (moment1_out / (np.sqrt(moment2_out) + epsilon) +
weight_decay * param)
return param_out, moment1_out, moment2_out
......@@ -190,16 +187,13 @@ def lamb_step_sparse(inputs, attributes, height, rows, row_numel, np_grad):
1 - beta2) * np.square(update_value)
def update_param():
mom1_tmp = moment1_out / (1 - beta1_pow)
mom2_tmp = moment2_out / (1 - beta2_pow)
r_1 = np.linalg.norm(param)
r_2 = np.linalg.norm(mom1_tmp / np.sqrt(mom2_tmp + epsilon) +
r_2 = np.linalg.norm(moment1_out / (np.sqrt(moment2_out) + epsilon) +
weight_decay * param)
lr_t = lr * r_1 / r_2
param_out = param - lr_t * (mom1_tmp / np.sqrt(mom2_tmp + epsilon) +
weight_decay * param)
param_out = param - lr_t * (moment1_out / (
np.sqrt(moment2_out) + epsilon) + weight_decay * param)
for row_id in range(param_out.shape[0]):
update_value = np.zeros(np_grad[0].shape).astype("float32")
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册