Unverified commit 619c62bb, authored by WangXi, committed by GitHub

fix adamw apply gradient (#30130)

Parent: 7564d43b
@@ -29,10 +29,12 @@ class TestAdamWOp(unittest.TestCase):
             parameters=linear.parameters(),
             apply_decay_param_fun=lambda name: True,
             weight_decay=0.01)
-        out = linear(a)
-        out.backward()
-        adam.step()
-        adam.clear_gradients()
+        for _ in range(2):
+            out = linear(a)
+            out.backward()
+            adam.step()
+            adam.clear_gradients()
 
     def test_adamw_op_coverage(self):
         paddle.disable_static()
...
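Note: a minimal standalone sketch of the scenario the updated test exercises, namely running more than one forward/backward/step/clear_gradients cycle. The layer size, input shape, and learning rate below are illustrative assumptions; only the optimizer arguments and the two-iteration loop mirror the diff above.

import paddle

paddle.disable_static()

# Hypothetical layer and input; the test's actual shapes are not shown in this hunk.
linear = paddle.nn.Linear(13, 5)
a = paddle.rand([2, 13])

adam = paddle.optimizer.AdamW(
    learning_rate=0.01,  # illustrative value
    parameters=linear.parameters(),
    apply_decay_param_fun=lambda name: True,
    weight_decay=0.01)

# The fix is exercised by repeating the update loop, not by a single step.
for _ in range(2):
    out = linear(a)
    out.backward()
    adam.step()
    adam.clear_gradients()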
@@ -16,6 +16,7 @@ from .optimizer import Optimizer
 from ..fluid import core
 from ..fluid import framework
 from ..fluid.framework import Variable
+from ..fluid.dygraph import base as imperative_base
 import paddle
@@ -247,6 +248,7 @@ class Adam(Optimizer):
         return adam_op
 
+    @imperative_base.no_grad
     @framework.dygraph_only
     def step(self):
         """
...
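Note: the only functional change in this file is the added @imperative_base.no_grad decorator on Adam.step(). Conceptually it keeps the optimizer's in-place parameter updates out of gradient recording. The snippet below is a rough sketch of that idea, not Paddle's implementation; it assumes paddle.no_grad is usable as a context manager, as in Paddle 2.x.

import paddle

# Conceptual sketch only: run a parameter-update callable without tracing
# its operations for autograd, roughly what the decorator provides for step().
def run_update_without_tracing(update_fn):
    with paddle.no_grad():
        update_fn()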
@@ -129,6 +129,7 @@ class AdamW(Adam):
         self._params_name = set()
         self._apply_decay_param_fun = apply_decay_param_fun
         self._coeff = coeff
+        self._lr_to_coeff = dict()
         super(AdamW, self).__init__(
             learning_rate=learning_rate,
             parameters=parameters,
@@ -139,96 +140,48 @@ class AdamW(Adam):
             name=name,
             lazy_mode=lazy_mode)
 
-    def _scale_parameters(self, params_and_grads):
+    def _append_decoupled_weight_decay(self, block, param_and_grad):
         """
-        Adds weight decay ops.
-            scaled_parameter = parameter * coeff
+        Add decoupled weight decay op.
+            parameter = parameter - parameter * coeff * lr
 
         Args:
-            params_and_grads: A list of (parameters, gradients) pairs,
-                the parameters need to decay.
+            block: block in which variable is to be created
+            param_and_grad: (parameters, gradients) pairs,
+                the parameters need to decay.
         Raises:
             Exception: The type of coeff and parameter is not consistent.
         """
-        scaled_params = []
-        for param, grad in params_and_grads:
-            # If no gradient then we don't need to do anything
-            if grad is None:
-                continue
-            if self._apply_decay_param_fun is not None \
-                    and not self._apply_decay_param_fun(param.name):
-                continue
-
-            if isinstance(self._coeff, float):
-                assert param.dtype is not paddle.fluid.core.VarDesc.VarType.FP32, \
-                    "the type of coeff(float) and parameter(%s) is not consistent."%(self._coeff.dtype)
-            else:
-                assert self._coeff.dtype == param.dtype, \
-                    "the type of coeff(%s) and parameter(%s) is not consistent."%(self._coeff.dtype, param.dtype)
-            if isinstance(self._learning_rate, float):
-                learning_rate = self._learning_rate
-            else:
-                learning_rate = self._learning_rate()
-            with param.block.program._optimized_guard(
-                    [param, grad]), framework.name_scope('weight decay'):
-                scaled_params.append(
-                    (param, grad, param * self._coeff * learning_rate))
-            if param.name not in self._params_name:
-                self._params_name.add(param.name)
-                param = param * self._coeff
-        return scaled_params
+        param, grad = param_and_grad
+
+        if self._apply_decay_param_fun is not None \
+                and not self._apply_decay_param_fun(param.name):
+            return
+
+        if isinstance(self._learning_rate, float):
+            learning_rate = self._learning_rate
+        else:
+            # NOTE. We add this function to the _append_optimize_op(),
+            # for we must make sure _create_param_lr() be called after
+            # optimizer._create_global_learning_rate().
+            learning_rate = self._create_param_lr(param_and_grad)
+
+        with block.program._optimized_guard(
+                [param, grad]), framework.name_scope('weight decay'):
+            self._params_name.add(param.name)
+
+            # If it has been calculated, the result will be reused
+            decay_coeff = self._lr_to_coeff.get(learning_rate, None)
+            if decay_coeff is None:
+                decay_coeff = 1.0 - learning_rate * self._coeff
+                self._lr_to_coeff[learning_rate] = decay_coeff
+
+            scaled_param = param * decay_coeff
+            paddle.fluid.layers.assign(input=scaled_param, output=param)
 
-    @imperative_base.no_grad
-    def minimize(self,
-                 loss,
-                 startup_program=None,
-                 parameters=None,
-                 no_grad_set=None):
-        parameters = parameters if parameters \
-            else self._parameter_list
-
-        params_grads = self.backward(
-            loss=loss,
-            startup_program=startup_program,
-            parameters=parameters,
-            no_grad_set=no_grad_set)
-        scaled_params = self._scale_parameters(params_grads)
-        for p_grad_sgrad in scaled_params:
-            param, grad, scaled_param = p_grad_sgrad
-            with param.block.program._optimized_guard(
-                    [param, grad]), framework.name_scope('weight decay'):
-                updated_param = paddle.fluid.layers.elementwise_sub(
-                    x=param, y=scaled_param)
-                paddle.fluid.layers.assign(input=updated_param, output=param)
-
-        optimize_ops = self._apply_optimize(
-            loss=loss,
-            params_grads=params_grads,
-            startup_program=startup_program)
-        return optimize_ops, params_grads
-
-    @framework.dygraph_only
-    @imperative_base.no_grad
-    def step(self):
-        params_grads = []
-        for param in self._parameter_list:
-            if not param.trainable:
-                continue
-            if param._grad_ivar() is not None:
-                grad_var = param._grad_ivar()
-                params_grads.append((param, grad_var))
-
-        scaled_params = self._scale_parameters(params_grads)
-        for p_grad_sgrad in scaled_params:
-            param, grad, scaled_param = p_grad_sgrad
-            with param.block.program._optimized_guard(
-                    [param, grad]), framework.name_scope('weight decay'):
-                updated_param = paddle.fluid.layers.elementwise_sub(
-                    x=param, y=scaled_param)
-                paddle.fluid.layers.assign(input=updated_param, output=param)
-        self._apply_optimize(
-            loss=None, startup_program=None, params_grads=params_grads)
+    def _append_optimize_op(self, block, param_and_grad):
+        self._append_decoupled_weight_decay(block, param_and_grad)
+        return super(AdamW, self)._append_optimize_op(block, param_and_grad)
 
     def __str__(self):
         return " ".join(["Weight Decay, params:", ",".join(self._params_name)])