Unverified · Commit 880fb833 authored by Yibing Liu, committed by GitHub

[cherry-pick] Update lamb optimizer (#18333) (#18380)

* Update lamb optimizer (#18333)

* Update lamb optimizer

* Regenerate api spec

test=release/1.5

* Give an experimental warning

test=release/1.5
Parent 5b103c24
@@ -861,7 +861,7 @@ paddle.fluid.optimizer.DGCMomentumOptimizer.backward (ArgSpec(args=['self', 'los
 paddle.fluid.optimizer.DGCMomentumOptimizer.get_opti_var_name_list (ArgSpec(args=['self'], varargs=None, keywords=None, defaults=None), ('document', '6adf97f83acf6453d4a6a4b1070f3754'))
 paddle.fluid.optimizer.DGCMomentumOptimizer.load (ArgSpec(args=['self', 'stat_dict'], varargs=None, keywords=None, defaults=None), ('document', '649a92cf7f1ea28666fd00c4ea01acde'))
 paddle.fluid.optimizer.DGCMomentumOptimizer.minimize (ArgSpec(args=['self', 'loss', 'startup_program', 'parameter_list', 'no_grad_set', 'grad_clip'], varargs=None, keywords=None, defaults=(None, None, None, None)), ('document', 'b15cffad0903fc81af77a0580ceb2a9b'))
-paddle.fluid.optimizer.LambOptimizer.__init__ (ArgSpec(args=['self', 'learning_rate', 'lamb_weight_decay', 'beta1', 'beta2', 'epsilon', 'regularization', 'name'], varargs=None, keywords=None, defaults=(0.001, 0.01, 0.9, 0.999, 1e-06, None, None)), ('document', '6adf97f83acf6453d4a6a4b1070f3754'))
+paddle.fluid.optimizer.LambOptimizer.__init__ (ArgSpec(args=['self', 'learning_rate', 'lamb_weight_decay', 'beta1', 'beta2', 'epsilon', 'regularization', 'exclude_from_weight_decay_fn', 'name'], varargs=None, keywords=None, defaults=(0.001, 0.01, 0.9, 0.999, 1e-06, None, None, None)), ('document', '6adf97f83acf6453d4a6a4b1070f3754'))
 paddle.fluid.optimizer.LambOptimizer.apply_gradients (ArgSpec(args=['self', 'params_grads'], varargs=None, keywords=None, defaults=None), ('document', 'bfe7305918552aaecfdaa22411dbe871'))
 paddle.fluid.optimizer.LambOptimizer.apply_optimize (ArgSpec(args=['self', 'loss', 'startup_program', 'params_grads'], varargs=None, keywords=None, defaults=None), ('document', '5c46d1926a40f1f873ffe9f37ac89dae'))
 paddle.fluid.optimizer.LambOptimizer.backward (ArgSpec(args=['self', 'loss', 'startup_program', 'parameter_list', 'no_grad_set', 'callbacks'], varargs=None, keywords=None, defaults=(None, None, None, None)), ('document', 'ba3a113d0229ff7bc9d39bda0a6d947f'))
......
@@ -60,23 +60,13 @@ correction. For more information, please refer to https://arxiv.org/abs/1904.009
 The updating of parameters follows:
 $$
-m_t^l &= \beta_1 m_{t - 1}^l + (1 - \beta_1)g_t^l \\
-v_t^l &= \beta_2 v_{t - 1}^l + (1 - \beta_2)g_t^l \odot g_t^l \\
-\widehat{m}_t^l &= m_t^l/(1 - \beta_1^t) \\
-\widehat{v}_t^l &= v_t^l/(1 - \beta_2^t) \\
-r_1 &= \left \| w_{t-1}^l \right \|_2 \\
-r_2 &= \left \| \frac{\widehat{m}_t^l}{\sqrt{\widehat{v}_t^l+\epsilon}} + \lambda w_{t-1}^l \right \|_2 \\
-r &= r_1 / r_2 \\
-\eta^l &= r \times \eta \\
-w_t^l &= w_{t-1}^l -\eta ^l \times (\frac{\widehat{m}_t^l}{\sqrt{\widehat{v}_t^l+\epsilon}} + \lambda w_{t-1}^l)
+m_t &= \beta_1 m_{t - 1} + (1 - \beta_1)g_t \\
+v_t &= \beta_2 v_{t - 1} + (1 - \beta_2)g_t^2 \\
+r_t &= \frac{m_t}{\sqrt{v_t}+\epsilon} \\
+w_t &= w_{t-1} - \eta_t \frac{\left \| w_{t-1}\right \|}{\left \| r_t + \lambda w_{t-1}\right \|} (r_t + \lambda w_{t-1})
 $$
 where $m$ is the 1st moment, and $v$ the 2nd moment, $\eta$ the
......
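For reference, the simplified update above (Adam-style moment estimates without bias correction, followed by the layer-wise trust ratio) can be written in a few lines of NumPy. This is a minimal sketch for a single parameter tensor; the function name `lamb_update` and variable names `m`, `v`, `r`, `update` are illustrative and are not Paddle's internal identifiers.

```python
import numpy as np

def lamb_update(param, grad, m, v, lr,
                beta1=0.9, beta2=0.999, epsilon=1e-6, weight_decay=0.01):
    """One LAMB step following the simplified formulas (no bias correction)."""
    m = beta1 * m + (1 - beta1) * grad               # m_t
    v = beta2 * v + (1 - beta2) * grad * grad        # v_t
    r = m / (np.sqrt(v) + epsilon)                   # r_t
    update = r + weight_decay * param                # r_t + lambda * w_{t-1}
    # layer-wise trust ratio: ||w_{t-1}|| / ||r_t + lambda * w_{t-1}||
    trust_ratio = np.linalg.norm(param) / np.linalg.norm(update)
    param = param - lr * trust_ratio * update        # w_t
    return param, m, v
```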
@@ -66,19 +66,14 @@ struct LambMomentUpdateFunctor {
     T g = grad_[i];
     T mom1 = moment1_[i];
     T mom2 = moment2_[i];
-    T beta1_pow = *beta1_pow_;
-    T beta2_pow = *beta2_pow_;
     T p = param_[i];
     mom1 = beta1_ * mom1 + (1 - beta1_) * g;
     mom2 = beta2_ * mom2 + (1 - beta2_) * g * g;
-    T mom1_h = mom1 / (1 - beta1_pow);
-    T mom2_h = mom2 / (1 - beta2_pow);
     moment1_out_[i] = mom1;
     moment2_out_[i] = mom2;
-    trust_ratio_div_[i] = mom1_h / sqrt(mom2_h + epsilon_) + weight_decay_ * p;
+    trust_ratio_div_[i] = mom1 / (sqrt(mom2) + epsilon_) + weight_decay_ * p;
   }
 };
@@ -130,19 +125,14 @@ struct SparseLambMomentUpdateFunctor {
     // The following code is same as dense
     T mom1 = moment1_[i];
     T mom2 = moment2_[i];
-    T beta1_pow = *beta1_pow_;
-    T beta2_pow = *beta2_pow_;
     T p = param_[i];
     mom1 = beta1_ * mom1 + (1 - beta1_) * g;
     mom2 = beta2_ * mom2 + (1 - beta2_) * g * g;
-    T mom1_h = mom1 / (1 - beta1_pow);
-    T mom2_h = mom2 / (1 - beta2_pow);
     moment1_out_[i] = mom1;
     moment2_out_[i] = mom2;
-    trust_ratio_div_[i] = mom1_h / sqrt(mom2_h + epsilon_) + weight_decay_ * p;
+    trust_ratio_div_[i] = mom1 / (sqrt(mom2) + epsilon_) + weight_decay_ * p;
   }
   inline HOSTDEVICE void operator()(size_t i) const {
......
@@ -2077,30 +2077,20 @@ class LambOptimizer(AdamOptimizer):
     LAMB Optimizer is designed to scale up the batch size of training without losing
     accuracy, which supports adaptive element-wise updating and accurate layer-wise
-    correction. For more information, please refer to `Reducing BERT Pre-Training
-    Time from 3 Days to 76 Minutes <https://arxiv.org/abs/1904.00962>`_ .
+    correction. For more information, please refer to `Large Batch Optimization for
+    Deep Learning: Training BERT in 76 minutes <https://arxiv.org/abs/1904.00962>`_ .
     The updating of parameters follows:
     .. math::
-        m_t^l & = \\beta_1 m_{t - 1}^l + (1 - \\beta_1)g_t^l
-        v_t^l & = \\beta_2 v_{t - 1}^l + (1 - \\beta_2)g_t^l \odot g_t^l
-        \\widehat{m}_t^l & = m_t^l/(1 - \\beta_1^t)
-        \\widehat{v}_t^l & = v_t^l/(1 - \\beta_2^t)
-        r_1 & = \\left \| w_{t-1}^l \\right \|_2
-        r_2 & = \\left \| \\frac{\\widehat{m}_t^l}{\\sqrt{\\widehat{v}_t^l+\\epsilon}} + \\lambda w_{t-1}^l \\right \|_2
-        r & = r_1 / r_2
-        \\eta^l & = r \\times \\eta
-        w_t^l & = w_{t-1}^l -\\eta ^l \\times (\\frac{\\widehat{m}_t^l}{\\sqrt{\\widehat{v}_t^l+\\epsilon}} + \\lambda w_{t-1}^l)
+        m_t &= \\beta_1 m_{t - 1} + (1 - \\beta_1)g_t \\
+        v_t &= \\beta_2 v_{t - 1} + (1 - \\beta_2)g_t^2 \\
+        r_t &= \\frac{m_t}{\\sqrt{v_t}+\\epsilon} \\
+        w_t &= w_{t-1} - \\eta_t \\frac{\\left \| w_{t-1}\\right \|}{\\left \| r_t + \\lambda w_{t-1}\\right \|} (r_t + \\lambda w_{t-1})
     where :math:`m` is the 1st moment, and :math:`v` the 2nd moment, :math:`\\eta` the
@@ -2114,8 +2104,10 @@ class LambOptimizer(AdamOptimizer):
         beta1 (float): The exponential decay rate for the 1st moment estimates.
         beta2 (float): The exponential decay rate for the 2nd moment estimates.
         epsilon (float): A small float value for numerical stability.
-        regularization: A Regularizer, such as
+        regularization (Regularizer): A Regularizer, such as
             fluid.regularizer.L1DecayRegularizer.
+        exclude_from_weight_decay_fn (function): Exclude a parameter from weight
+            decay when **exclude_from_weight_decay_fn(parameter)** returns true.
         name (str|None): An optional name prefix.
     Examples:
@@ -2127,11 +2119,16 @@ class LambOptimizer(AdamOptimizer):
             hidden = fluid.layers.fc(input=data, size=10)
             cost = fluid.layers.mean(hidden)
-            optimizer = fluid.optimizer.Lamb(learning_rate=0.002)
+            def exclude_fn(param):
+                return param.name.endswith('.b_0')
+
+            optimizer = fluid.optimizer.Lamb(learning_rate=0.002,
+                                             exclude_from_weight_decay_fn=exclude_fn)
             optimizer.minimize(cost)
     """
     _moment1_acc_str = "moment1"
     _moment2_acc_str = "moment2"
+    # these two not used in op temporarily
     _beta1_pow_acc_str = "beta1_pow_acc"
     _beta2_pow_acc_str = "beta2_pow_acc"
@@ -2142,6 +2139,7 @@ class LambOptimizer(AdamOptimizer):
                  beta2=0.999,
                  epsilon=1e-6,
                  regularization=None,
+                 exclude_from_weight_decay_fn=None,
                  name=None):
         assert learning_rate is not None
         assert lamb_weight_decay is not None
@@ -2157,6 +2155,10 @@ class LambOptimizer(AdamOptimizer):
             name=name)
         self.type = "lamb"
         self._weight_decay = lamb_weight_decay
+        self._exclude_from_weight_decay_fn = exclude_from_weight_decay_fn
+        print(
+            "WARNING: The LAMB optimizer doesn't have official implementation "
+            "yet and is still in experimental.")

     def _append_optimize_op(self, block, param_and_grad):
         assert isinstance(block, framework.Block)
@@ -2170,6 +2172,12 @@ class LambOptimizer(AdamOptimizer):
         beta2_pow_acc = self._get_accumulator(self._beta2_pow_acc_str,
                                               param_and_grad[0])
+        if self._exclude_from_weight_decay_fn is not None \
+            and self._exclude_from_weight_decay_fn(param_and_grad[0]):
+            weight_decay = 0.0
+        else:
+            weight_decay = self._weight_decay
+
         # create the lamb optimize op
         lamb_op = block.append_op(
             type=self.type,
@@ -2191,7 +2199,7 @@ class LambOptimizer(AdamOptimizer):
                 "beta1": self._beta1,
                 "beta2": self._beta2,
                 "epsilon": self._epsilon,
-                "weight_decay": self._weight_decay
+                "weight_decay": weight_decay
             },
             stop_gradient=True)
......
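The per-parameter gate added in `_append_optimize_op` simply zeroes the weight-decay coefficient for any parameter the callback selects. Below is a hedged, standalone illustration of that behaviour; the `FakeParam` class and the parameter names are made up for this sketch, while in Paddle the callback receives the framework's own `Parameter` objects.

```python
class FakeParam:
    """Stand-in for a Paddle Parameter; only the name attribute is used here."""
    def __init__(self, name):
        self.name = name

def exclude_fn(param):
    # Same predicate as the docstring example: skip weight decay for biases.
    return param.name.endswith('.b_0')

lamb_weight_decay = 0.01
for p in [FakeParam('fc_0.w_0'), FakeParam('fc_0.b_0')]:
    # Mirrors the added branch: excluded parameters get weight_decay = 0.0.
    weight_decay = 0.0 if exclude_fn(p) else lamb_weight_decay
    print(p.name, '->', weight_decay)   # fc_0.w_0 -> 0.01, fc_0.b_0 -> 0.0
```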
@@ -140,15 +140,12 @@ def lamb_step(inputs, attributes):
     moment1_out = beta1 * moment1 + (1 - beta1) * grad
     moment2_out = beta2 * moment2 + (1 - beta2) * np.square(grad)
-    mom1_tmp = moment1_out / (1 - beta1_pow)
-    mom2_tmp = moment2_out / (1 - beta2_pow)
     r_1 = np.linalg.norm(param)
-    r_2 = np.linalg.norm(mom1_tmp / np.sqrt(mom2_tmp + epsilon) + weight_decay *
-                         param)
+    r_2 = np.linalg.norm(moment1_out / (np.sqrt(moment2_out) + epsilon) +
+                         weight_decay * param)
     lr_t = lr * r_1 / r_2
-    param_out = param - lr_t * (mom1_tmp / np.sqrt(mom2_tmp + epsilon) +
+    param_out = param - lr_t * (moment1_out / (np.sqrt(moment2_out) + epsilon) +
                                 weight_decay * param)
     return param_out, moment1_out, moment2_out
@@ -190,16 +187,13 @@ def lamb_step_sparse(inputs, attributes, height, rows, row_numel, np_grad):
                                                1 - beta2) * np.square(update_value)
     def update_param():
-        mom1_tmp = moment1_out / (1 - beta1_pow)
-        mom2_tmp = moment2_out / (1 - beta2_pow)
         r_1 = np.linalg.norm(param)
-        r_2 = np.linalg.norm(mom1_tmp / np.sqrt(mom2_tmp + epsilon) +
+        r_2 = np.linalg.norm(moment1_out / (np.sqrt(moment2_out) + epsilon) +
                              weight_decay * param)
         lr_t = lr * r_1 / r_2
-        param_out = param - lr_t * (mom1_tmp / np.sqrt(mom2_tmp + epsilon) +
-                                    weight_decay * param)
+        param_out = param - lr_t * (moment1_out / (
+            np.sqrt(moment2_out) + epsilon) + weight_decay * param)
     for row_id in range(param_out.shape[0]):
         update_value = np.zeros(np_grad[0].shape).astype("float32")
......