未验证 提交 a3bc3bcd 编写于 作者: G Guo Sheng 提交者: GitHub

Fix scaled_params append error in AdamW. (#28633)

Fix no_grad setting in AdamW.
test=develop
上级 c4d22c84
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
from .optimizer import Optimizer from .optimizer import Optimizer
from .adam import Adam from .adam import Adam
from ..fluid import framework from ..fluid import framework
from ..fluid.dygraph import base as imperative_base
import paddle import paddle
from paddle.fluid.dygraph.parallel import apply_collective_grads from paddle.fluid.dygraph.parallel import apply_collective_grads
...@@ -171,13 +172,14 @@ class AdamW(Adam): ...@@ -171,13 +172,14 @@ class AdamW(Adam):
learning_rate = self._learning_rate() learning_rate = self._learning_rate()
with param.block.program._optimized_guard( with param.block.program._optimized_guard(
[param, grad]), framework.name_scope('weight decay'): [param, grad]), framework.name_scope('weight decay'):
scaled_params.append(
(param, grad, param * self._coeff * learning_rate))
if param.name not in self._params_name: if param.name not in self._params_name:
scaled_params.append(
(param, grad, param * self._coeff * learning_rate))
self._params_name.add(param.name) self._params_name.add(param.name)
param = param * self._coeff param = param * self._coeff
return scaled_params return scaled_params
@imperative_base.no_grad
def minimize(self, def minimize(self,
loss, loss,
startup_program=None, startup_program=None,
...@@ -207,6 +209,7 @@ class AdamW(Adam): ...@@ -207,6 +209,7 @@ class AdamW(Adam):
return optimize_ops, params_grads return optimize_ops, params_grads
@framework.dygraph_only @framework.dygraph_only
@imperative_base.no_grad
def step(self): def step(self):
if paddle.distributed.get_world_size() > 1: if paddle.distributed.get_world_size() > 1:
apply_collective_grads(self._parameter_list) apply_collective_grads(self._parameter_list)
...@@ -227,7 +230,7 @@ class AdamW(Adam): ...@@ -227,7 +230,7 @@ class AdamW(Adam):
[param, grad]), framework.name_scope('weight decay'): [param, grad]), framework.name_scope('weight decay'):
updated_param = paddle.fluid.layers.elementwise_sub( updated_param = paddle.fluid.layers.elementwise_sub(
x=param, y=scaled_param) x=param, y=scaled_param)
param.set_value(updated_param.numpy()) paddle.fluid.layers.assign(input=updated_param, output=param)
self._apply_optimize( self._apply_optimize(
loss=None, startup_program=None, params_grads=params_grads) loss=None, startup_program=None, params_grads=params_grads)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册