From a3bc3bcd4854057079f2f9447d8872c25ed3af28 Mon Sep 17 00:00:00 2001
From: Guo Sheng
Date: Mon, 16 Nov 2020 14:32:58 +0800
Subject: [PATCH] Fix scaled_params append error in AdamW. (#28633)

Fix no_grad setting in AdamW.
test=develop
---
 python/paddle/optimizer/adamw.py | 9 ++++++---
 1 file changed, 6 insertions(+), 3 deletions(-)

diff --git a/python/paddle/optimizer/adamw.py b/python/paddle/optimizer/adamw.py
index 2cf3881d046..0ffff675903 100644
--- a/python/paddle/optimizer/adamw.py
+++ b/python/paddle/optimizer/adamw.py
@@ -15,6 +15,7 @@
 from .optimizer import Optimizer
 from .adam import Adam
 from ..fluid import framework
+from ..fluid.dygraph import base as imperative_base
 import paddle
 from paddle.fluid.dygraph.parallel import apply_collective_grads
 
@@ -171,13 +172,14 @@ class AdamW(Adam):
                 learning_rate = self._learning_rate()
             with param.block.program._optimized_guard(
                 [param, grad]), framework.name_scope('weight decay'):
+                scaled_params.append(
+                    (param, grad, param * self._coeff * learning_rate))
                 if param.name not in self._params_name:
-                    scaled_params.append(
-                        (param, grad, param * self._coeff * learning_rate))
                     self._params_name.add(param.name)
                 param = param * self._coeff
         return scaled_params
 
+    @imperative_base.no_grad
     def minimize(self,
                  loss,
                  startup_program=None,
@@ -207,6 +209,7 @@ class AdamW(Adam):
         return optimize_ops, params_grads
 
     @framework.dygraph_only
+    @imperative_base.no_grad
     def step(self):
         if paddle.distributed.get_world_size() > 1:
             apply_collective_grads(self._parameter_list)
@@ -227,7 +230,7 @@ class AdamW(Adam):
                 [param, grad]), framework.name_scope('weight decay'):
                 updated_param = paddle.fluid.layers.elementwise_sub(
                     x=param, y=scaled_param)
-                param.set_value(updated_param.numpy())
+                paddle.fluid.layers.assign(input=updated_param, output=param)
         self._apply_optimize(
             loss=None, startup_program=None, params_grads=params_grads)
--
GitLab
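
A minimal usage sketch of the code path this patch touches, assuming a Paddle 2.0-style dygraph setup; the model, shapes, and coefficient below are illustrative and not taken from the patch. It exercises AdamW.step(), where the decoupled weight-decay update now runs under no_grad and the scaled parameter is written back in place.

import paddle

# Illustrative values, not part of the patch.
coeff = 0.01  # weight-decay coefficient, i.e. AdamW's self._coeff

linear = paddle.nn.Linear(4, 4)
opt = paddle.optimizer.AdamW(
    learning_rate=0.001,
    parameters=linear.parameters(),
    weight_decay=coeff)

x = paddle.rand([2, 4])
loss = paddle.mean(linear(x))
loss.backward()

# step() is now decorated with @imperative_base.no_grad, so the weight-decay
# arithmetic (param - lr * coeff * param) does not extend the autograd graph,
# and the result is assigned back to the parameter instead of going through
# set_value(updated_param.numpy()).
opt.step()
opt.clear_grad()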