From 884011a46a48456f2829a8614373de1303747218 Mon Sep 17 00:00:00 2001 From: zhaoyingli <86812880+lili0826@users.noreply.github.com> Date: Wed, 1 Sep 2021 10:08:38 +0800 Subject: [PATCH] reverse xpu adamw to the combination of ops version. (#35286) --- python/paddle/optimizer/adamw.py | 69 +++++++++++++++++++++++++++++++- 1 file changed, 67 insertions(+), 2 deletions(-) diff --git a/python/paddle/optimizer/adamw.py b/python/paddle/optimizer/adamw.py index e6ec91dc41..158d087096 100644 --- a/python/paddle/optimizer/adamw.py +++ b/python/paddle/optimizer/adamw.py @@ -162,6 +162,7 @@ class AdamW(Adam): self._params_name = set() self._apply_decay_param_fun = apply_decay_param_fun self._coeff = coeff + self._lr_to_coeff = dict() super(AdamW, self).__init__( learning_rate=learning_rate, @@ -177,6 +178,9 @@ class AdamW(Adam): self.type = "adamw" + if core.is_compiled_with_xpu(): + self.type = "adam" + # Use _auxiliary_vars together with _set_auxiliary_var/_get_auxiliary_var to achieve that. self._auxiliary_vars = dict() @@ -189,7 +193,63 @@ class AdamW(Adam): else: return None + def _append_decoupled_weight_decay(self, block, param_and_grad): + """ + Add decoupled weight decay op. + parameter = parameter - parameter * coeff * lr + Args: + block: block in which variable is to be created + param_and_grad: (parameters, gradients) pairs, + the parameters need to decay. + Raises: + Exception: The type of coeff and parameter is not consistent. + """ + if isinstance(param_and_grad, dict): + param_and_grad = self._update_param_group(param_and_grad) + param, grad = param_and_grad + + if self._apply_decay_param_fun is not None \ + and not self._apply_decay_param_fun(param.name): + return + + if isinstance(self._learning_rate, float): + learning_rate = self._learning_rate + else: + # NOTE. We add this function to the _append_optimize_op(), + # for we must make sure _create_param_lr() be called after + # optimizer._create_global_learning_rate(). + learning_rate = self._create_param_lr(param_and_grad) + + with block.program._optimized_guard( + [param, grad]), framework.name_scope('weight decay'): + self._params_name.add(param.name) + + # If it has been calculated, the result will be reused. + # NOTE(wangxi): In dygraph mode, apply_gradient will be executed + # every step, so need clear _lr_to_coeff every step, + # we do this in _create_optimization_pass + decay_coeff = self._lr_to_coeff.get(learning_rate, None) + if decay_coeff is None: + # NOTE(wangxi): for pipeline to set device:all + with paddle.static.device_guard(None): + decay_coeff = 1.0 - learning_rate * self._coeff + self._lr_to_coeff[learning_rate] = decay_coeff + + find_master = (self._multi_precision and + param.dtype == core.VarDesc.VarType.FP16) + if find_master: + master_weight = self._master_weights[param.name] + scaled_param = master_weight * decay_coeff + paddle.fluid.layers.assign( + input=scaled_param, output=master_weight) + else: + scaled_param = param * decay_coeff + paddle.fluid.layers.assign(input=scaled_param, output=param) + def _append_optimize_op(self, block, param_and_grad): + if paddle.is_compiled_with_xpu(): + self._append_decoupled_weight_decay(block, param_and_grad) + return super(AdamW, self)._append_optimize_op(block, param_and_grad) assert isinstance(block, framework.Block) if isinstance(param_and_grad, dict): @@ -201,8 +261,6 @@ class AdamW(Adam): if self._apply_decay_param_fun is not None \ and not self._apply_decay_param_fun(param.name): with_decay = False - else: - self._params_name.add(param.name) moment1 = self._get_accumulator(self._moment1_acc_str, param_and_grad[0]) @@ -291,6 +349,13 @@ class AdamW(Adam): return adamw_op + def _create_optimization_pass(self, parameters_and_grads): + optimize_ops = super( + AdamW, self)._create_optimization_pass(parameters_and_grads) + # In dygraph mode, clear _lr_to_coeff after applied gradient + self._lr_to_coeff = dict() + return optimize_ops + def __str__(self): return " ".join(["Weight Decay, params:", ",".join(self._params_name)]) -- GitLab