Unverified commit 77bae9a4, authored by Guoxia Wang, committed by GitHub

fix the bug of adamw which set the attribute in param group not working (#43013)

* fix the bug of adamw which set the attribute in param group not working

* fix undefined variable

* fix api example typo

* add unittest

* fix unittest typo
Parent 81622708
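Context for the fix: before this change, options placed inside a parameter-group dict (such as `beta1`, `beta2` or `weight_decay`) were silently ignored, because the old `_update_param_group` only read back the legacy `coeff` key. A minimal sketch of the usage this commit is meant to make work, mirroring the unit test added below (values are illustrative):

import paddle

linear = paddle.nn.Linear(5, 5)
opt = paddle.optimizer.AdamW(
    parameters=[{
        'params': linear.parameters(),
        'weight_decay': 0.001,   # group-level override, now honoured
        'beta1': 0.1,
        'beta2': 0.99
    }],
    weight_decay=0.01)           # default for groups that do not override it

loss = paddle.mean(linear(paddle.randn((5, 5))))
loss.backward()
opt.step()                       # the group's beta1/beta2/weight_decay are applied here
opt.clear_grad()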
@@ -271,6 +271,115 @@ class TestAdamWOpGroup(TestAdamWOp):
             adam.clear_gradients()


+class TestAdamWOpMultiPrecison(unittest.TestCase):
+    def _test_adamw_op_dygraph_place_amp(self, place, use_amp=False):
+        paddle.disable_static()
+        paddle.seed(10)
+        paddle.set_device(place)
+
+        input = paddle.randn((5, 5))
+
+        model = paddle.nn.Linear(5, 5)
+
+        optimizer = paddle.optimizer.AdamW(
+            parameters=[{
+                'params': model.parameters(),
+                'weight_decay': 0.001,
+                'beta1': 0.1,
+                'beta2': 0.99
+            }],
+            multi_precision=use_amp)
+
+        for idx in range(2):
+            if place == 'gpu' and use_amp == True:
+                model = paddle.amp.decorate(models=model, level='O2')
+                scaler = paddle.amp.GradScaler(init_loss_scaling=1024)
+
+            if place == 'gpu' and use_amp == True:
+                with paddle.amp.auto_cast(level='O2'):
+                    output = model(input)
+                    loss = paddle.mean(output)
+                scaled = scaler.scale(loss)
+                scaled.backward()
+                scaler.step(optimizer)
+                optimizer.clear_grad()
+            else:
+                output = model(input)
+                loss = paddle.mean(output)
+                loss.backward()
+                optimizer.step()
+                optimizer.clear_grad()
+
+    def _get_places(self):
+        places = ['cpu']
+        if paddle.is_compiled_with_cuda():
+            places.append('gpu')
+        return places
+
+    def test_main(self):
+        for place in self._get_places():
+            use_amp_list = [True, False]
+            for use_amp in use_amp_list:
+                self._test_adamw_op_dygraph_place_amp(place, use_amp)
+
+
+class TestAdamWOpError(unittest.TestCase):
+    def test_api_errors(self):
+        def test_weight_decay_dtype():
+            linear = paddle.nn.Linear(13, 5)
+            adam = paddle.optimizer.AdamW(
+                learning_rate=0.01,
+                parameters=linear.parameters(),
+                weight_decay=1)
+
+        def test_parameters_dtype1():
+            adam = paddle.optimizer.AdamW(
+                learning_rate=0.01,
+                parameters=paddle.randn((5, 5)),
+                weight_decay=0.1)
+
+        def test_parameters_dtype2():
+            linear = paddle.nn.Linear(13, 5)
+            adam = paddle.optimizer.AdamW(
+                learning_rate=0.01,
+                parameters={'params': linear.parameters()},
+                weight_decay=0.1)
+
+        def test_parameters_dtype3():
+            adam = paddle.optimizer.AdamW(
+                learning_rate=0.01, parameters=None, weight_decay=0.1)
+
+        def test_parameters_dtype4():
+            linear = paddle.nn.Linear(13, 5)
+            adam = paddle.optimizer.AdamW(
+                learning_rate=0.01,
+                parameters={'params': set(linear.parameters())},
+                weight_decay=0.1)
+
+        def test_learning_rate_dtype():
+            linear = paddle.nn.Linear(13, 5)
+            adam = paddle.optimizer.AdamW(
+                learning_rate=1,
+                parameters=linear.parameters(),
+                weight_decay=0.1)
+
+        def test_grad_clip_dtype():
+            linear = paddle.nn.Linear(13, 5)
+            adam = paddle.optimizer.AdamW(
+                learning_rate=0.01,
+                parameters=linear.parameters(),
+                weight_decay=0.1,
+                grad_clip=0.1)
+
+        self.assertRaises(TypeError, test_weight_decay_dtype)
+        self.assertRaises(TypeError, test_parameters_dtype1)
+        self.assertRaises(TypeError, test_parameters_dtype2)
+        self.assertRaises(AttributeError, test_parameters_dtype3)
+        self.assertRaises(TypeError, test_parameters_dtype4)
+        self.assertRaises(TypeError, test_learning_rate_dtype)
+        self.assertRaises(TypeError, test_grad_clip_dtype)
+
+
 class TestAdamWOpGroupWithLR(TestAdamWOp):
     def test_adamw_op_dygraph(self):
         paddle.disable_static()
...
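The error cases above correspond to the argument validation that the rewritten constructor now performs itself, since AdamW no longer inherits it from Adam (see the second file below). A quick sketch of the rejected inputs, with expected exception types in comments (messages paraphrased, not quoted from the source):

import paddle

linear = paddle.nn.Linear(13, 5)

# parameters must be an iterable of Tensors or a list of group dicts,
# so a bare Tensor is rejected
try:
    paddle.optimizer.AdamW(parameters=paddle.randn((5, 5)))
except TypeError as e:
    print('rejected Tensor:', type(e).__name__)

# a bare dict is rejected as well; wrap group dicts in a list instead
try:
    paddle.optimizer.AdamW(parameters={'params': linear.parameters()})
except TypeError as e:
    print('rejected dict:', type(e).__name__)

# grad_clip must be a GradientClipBase subclass, not a float
try:
    paddle.optimizer.AdamW(parameters=linear.parameters(), grad_clip=0.1)
except TypeError as e:
    print('rejected grad_clip:', type(e).__name__)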
@@ -12,11 +12,17 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import warnings
+from collections import defaultdict
 from .optimizer import Optimizer
-from .adam import Adam
+from .lr import LRScheduler
 from ..fluid import core
 from ..fluid import framework
-from ..fluid.framework import Variable
+from ..fluid.framework import Variable, Parameter
+from ..fluid import unique_name
+from ..fluid import layers
+from ..fluid.layer_helper import LayerHelper
+from ..fluid.clip import GradientClipBase
 from ..fluid.dygraph import base as imperative_base
 from collections.abc import Callable
 from .. import _C_ops
@@ -25,7 +31,7 @@ import paddle

 __all__ = []

-class AdamW(Adam):
+class AdamW(Optimizer):
     r"""
     The AdamW optimizer is implemented based on the AdamW Optimization
     in paper `DECOUPLED WEIGHT DECAY REGULARIZATION <https://arxiv.org/pdf/1711.05101.pdf>`_.
@@ -102,14 +108,14 @@ class AdamW(Adam):
             beta1 = paddle.to_tensor([0.9], dtype="float32")
             beta2 = paddle.to_tensor([0.99], dtype="float32")

-            adam = paddle.optimizer.AdamW(learning_rate=0.1,
+            opt = paddle.optimizer.AdamW(learning_rate=0.1,
                     parameters=linear.parameters(),
                     beta1=beta1,
                     beta2=beta2,
                     weight_decay=0.01)
             out.backward()
-            adam.step()
-            adam.clear_grad()
+            opt.step()
+            opt.clear_grad()

             #Note that the learning_rate of linear_2 is 0.01.
@@ -119,7 +125,7 @@ class AdamW(Adam):
             out = linear_1(inp)
             out = linear_2(out)
             loss = paddle.mean(out)
-            adam = paddle.optimizer.AdamW(
+            opt = paddle.optimizer.AdamW(
                 learning_rate=0.1,
                 parameters=[{
                     'params': linear_1.parameters()
@@ -132,11 +138,16 @@ class AdamW(Adam):
                 weight_decay=0.01,
                 beta1=0.9)
             out.backward()
-            adam.step()
-            adam.clear_grad()
+            opt.step()
+            opt.clear_grad()
     """

+    _moment1_acc_str = "moment1"
+    _moment2_acc_str = "moment2"
+    _beta1_pow_acc_str = "beta1_pow_acc"
+    _beta2_pow_acc_str = "beta2_pow_acc"
+
     def __init__(self,
                  learning_rate=0.001,
                  beta1=0.9,
@@ -160,37 +171,108 @@ class AdamW(Adam):
             raise ValueError("Invaild value of beta2, expect beta2 in [0,1).")
         if not 0 <= epsilon:
             raise ValueError("Invaild value of epsilon, expect epsilon >= 0.")
-        coeff = weight_decay
-        if not isinstance(coeff, float) and \
-                not isinstance(coeff, framework.Variable):
-            raise TypeError("coeff should be float or Tensor.")
-        self._params_name = set()
-        self._apply_decay_param_fun = apply_decay_param_fun
-        self._coeff = coeff
-        self._lr_to_coeff = dict()
+        if not isinstance(weight_decay, float) and \
+                not isinstance(weight_decay, framework.Variable):
+            raise TypeError("weight_decay should be float or Tensor.")
         if lr_ratio is not None:
             assert isinstance(lr_ratio, Callable)
             if not core.is_compiled_with_cuda():
                 raise NotImplementedError(
                     "'lr_ratio' is unimplemented in CPU, XPU and NPU")
-        self._lr_ratio = lr_ratio

-        super(AdamW, self).__init__(
-            learning_rate=learning_rate,
-            parameters=parameters,
-            beta1=beta1,
-            beta2=beta2,
-            epsilon=epsilon,
-            grad_clip=grad_clip,
-            name=name,
-            lazy_mode=lazy_mode,
-            multi_precision=multi_precision)
-        self._default_dict = {'coeff': coeff}
+        if parameters is not None:
+            # paddle.Tensor is also iterable, so here we don't check whether
+            # the input is iterable, if the input is paddle.Tensor, the
+            # list(paddle.Tensor) will be a error value
+            if isinstance(parameters, (paddle.Tensor, core.eager.Tensor)):
+                raise TypeError(
+                    "`parameters` argument given to the optimizer should be "
+                    "an iterable of paddle Tensors, but got argument type is `{}`.".
+                    format(type(parameters)))
+            if isinstance(parameters, dict):
+                raise TypeError(
+                    "`parameters` argument should not get dict type, "
+                    "if parameter groups is needed, please set `parameters`"
+                    " as list of dict")
+            self._parameter_list = list(parameters)
+        else:
+            self._parameter_list = None
+
+        self._name = name
+        if framework._non_static_mode():
+            if self._parameter_list is None:
+                raise AttributeError(
+                    "parameters argument given to the Optimizer should not be None in dygraph mode."
+                )
+
+        if not isinstance(learning_rate, (float, LRScheduler)):
+            raise TypeError(
+                "learning rate should be float or LRScheduler, got %s here" %
+                type(learning_rate))
+        if grad_clip is not None:
+            if not isinstance(grad_clip, GradientClipBase):
+                raise TypeError(
+                    "'grad_clip' should be an instance of GradientClipBase's derived class"
+                )
+
+        self._dtype = None
+        # Infer the dtype form parameter
+        if self._parameter_list:
+            if isinstance(self._parameter_list[0], dict):
+                for param_group in self._parameter_list:
+                    assert 'params' in param_group, \
+                        'params should be set in parameters if parameter groups are optimized in different options'
+                self._dtype = self._parameter_list[0]['params'][0].dtype
+            else:
+                self._dtype = self._parameter_list[0].dtype
+
+        # each program should have a independent learning rate
+        # program -> tensor(learning_rate)
+        self._learning_rate_map = dict()
+        # Dictionary of accumulators. Some optimizer subclasses need to
+        # allocate and manage extra tensors associated with the parameters
+        # to train. These tensors are called accumulators.
+        # {accum_name : { paramter_name : accumulator_for_parameter, ...}, ...}
+        self._accumulators = defaultdict(lambda: dict())
+        self.helper = None
+        self._opti_name_list = []
+        self._accumulators_holder = {}
+        self._param_device_map = dict()
+        self.clear_gradients = self.clear_grad

         self.type = "adamw"
+        self._learning_rate = learning_rate
+        self._params_name = set()
+        self._apply_decay_param_fun = apply_decay_param_fun
+        self._weight_decay = weight_decay
+        self._grad_clip = grad_clip
+        self._lr_ratio = lr_ratio
+        self._beta1 = beta1
+        self._beta2 = beta2
+        self._epsilon = epsilon
+        self._lazy_mode = lazy_mode
+        self._multi_precision = multi_precision
+        self._master_weights = {}
+
+        self._default_dict = {
+            'weight_decay': weight_decay,
+            'beta1': beta1,
+            'beta2': beta2,
+            'epsilon': epsilon,
+            'lazy_mode': lazy_mode,
+            'grad_clip': grad_clip
+        }
+
+        self._param_groups = []
+        if self._parameter_list and isinstance(self._parameter_list[0], dict):
+            for param_group in self._parameter_list:
+                self._add_param_group(param_group.copy())
+        else:
+            self._param_groups = self._parameter_list

-        # Use _auxiliary_vars together with _set_auxiliary_var/_get_auxiliary_var to achieve that.
-        self._auxiliary_vars = dict()
+        self._use_multi_tensor = None
+        self.regularization = None
+        self._auxiliary_vars = {}

     def _set_auxiliary_var(self, key, val):
         self._auxiliary_vars[key] = val
@@ -201,58 +283,128 @@ class AdamW(Adam):
         else:
             return None

-    def _append_decoupled_weight_decay(self, block, param_and_grad):
-        """
-        Add decoupled weight decay op.
-            parameter = parameter - parameter * coeff * lr
-        Args:
-            block: block in which variable is to be created
-            param_and_grad: (parameters, gradients) pairs,
-                the parameters need to decay.
-        Raises:
-            Exception: The type of coeff and parameter is not consistent.
-        """
-        if isinstance(param_and_grad, dict):
-            param_and_grad = self._update_param_group(param_and_grad)
-        param, grad = param_and_grad
-
-        if self._apply_decay_param_fun is not None \
-                and not self._apply_decay_param_fun(param.name):
-            return
-
-        if isinstance(self._learning_rate, float):
-            learning_rate = self._learning_rate
-        else:
-            # NOTE. We add this function to the _append_optimize_op(),
-            # for we must make sure _create_param_lr() be called after
-            # optimizer._create_global_learning_rate().
-            learning_rate = self._create_param_lr(param_and_grad)
-
-        with block.program._optimized_guard(
-                [param, grad]), framework.name_scope('weight decay'):
-            self._params_name.add(param.name)
-
-            # If it has been calculated, the result will be reused.
-            # NOTE(wangxi): In dygraph mode, apply_gradient will be executed
-            # every step, so need clear _lr_to_coeff every step,
-            # we do this in _create_optimization_pass
-            decay_coeff = self._lr_to_coeff.get(learning_rate, None)
-            if decay_coeff is None:
-                # NOTE(wangxi): for pipeline to set device:all
-                with paddle.static.device_guard(None):
-                    decay_coeff = 1.0 - learning_rate * self._coeff
-                    self._lr_to_coeff[learning_rate] = decay_coeff
-
-            find_master = (self._multi_precision and
-                           param.dtype == core.VarDesc.VarType.FP16)
-            if find_master:
-                master_weight = self._master_weights[param.name]
-                scaled_param = master_weight * decay_coeff
-                paddle.fluid.layers.assign(
-                    input=scaled_param, output=master_weight)
-            else:
-                scaled_param = param * decay_coeff
-                paddle.fluid.layers.assign(input=scaled_param, output=param)
+    def _add_param_group(self, param_group):
+        """
+        Add a param group to parameter_list.
+
+        Args:
+            param_group (dict): The group of Tensors to be optimzed with
+            different optimization options.
+        """
+        params = param_group['params']
+        if isinstance(params, Parameter):
+            param_group['params'] = [params]
+        elif isinstance(params, set):
+            raise TypeError(
+                "optimizer parameters should be in ordered collections,"
+                "but received set, please use list instead.")
+        else:
+            param_group['params'] = list(params)
+
+        # Update optimization options for each groups
+        for k, v in self._default_dict.items():
+            param_group.setdefault(k, v)
+
+        param_set = set()
+        for group in self._param_groups:
+            param_set.update(set(group['params']))
+
+        if not param_set.isdisjoint(set(param_group['params'])):
+            raise ValueError(
+                "some parameters appear in more than one parameter group")
+
+        for param in param_group['params']:
+            param.optimize_attr['learning_rate'] = param_group.get(
+                'learning_rate', 1.)
+
+        self._param_groups.append(param_group)
+
+    def _create_master_weight(self, param):
+        if param.name in self._master_weights:
+            var = self._master_weights[param.name]
+        else:
+            assert isinstance(self.helper, LayerHelper)
+
+            var_name = param.name + "_fp32_master"
+            var_name = unique_name.generate(var_name)
+            var = layers.create_global_var(
+                name=var_name,
+                shape=param.shape,
+                value=0,
+                dtype='float32',
+                persistable=True)
+            block = self.helper.startup_program.global_block()
+            block.append_op(
+                type="cast",
+                inputs={"X": [param]},
+                outputs={"Out": [var]},
+                attrs={
+                    "in_dtype": param.dtype,
+                    "out_dtype": core.VarDesc.VarType.FP32
+                })
+            self._master_weights[param.name] = var
+        return var
+
+    def _get_accumulator(self, name, param):
+        """Utility function to fetch an accumulator for a parameter
+        Args:
+            name: name of the accumulator
+            param: parameter variable for which accumulator is to be fetched
+        Returns:
+            accumulator variable for the parameter
+        """
+        if self._name is not None:
+            name = self._name + "_" + name
+        find_master = self._multi_precision and param.dtype == core.VarDesc.VarType.FP16
+        target_param = self._master_weights[
+            param.name] if find_master else param
+        target_name = target_param.name
+        if (name not in self._accumulators or
+                target_name not in self._accumulators[name]):
+            raise Exception("Accumulator {} does not exist for parameter {}".
+                            format(name, target_name))
+        return self._accumulators[name][target_name]
+
+    def _add_moments_pows(self, p):
+        acc_dtype = p.dtype
+        if acc_dtype == core.VarDesc.VarType.FP16:
+            acc_dtype = core.VarDesc.VarType.FP32
+        self._add_accumulator(self._moment1_acc_str, p, dtype=acc_dtype)
+        self._add_accumulator(self._moment2_acc_str, p, dtype=acc_dtype)
+        self._add_accumulator(
+            name=self._beta1_pow_acc_str,
+            param=p,
+            dtype=acc_dtype,
+            fill_value=0.9 if isinstance(self._beta1, Variable) \
+                else self._beta1,
+            shape=[1],
+            type=core.VarDesc.VarType.LOD_TENSOR, device='cpu')
+        self._add_accumulator(
+            name=self._beta2_pow_acc_str,
+            param=p,
+            dtype=acc_dtype,
+            fill_value=0.999 if isinstance(self._beta2, Variable) \
+                else self._beta2,
+            shape=[1],
+            type=core.VarDesc.VarType.LOD_TENSOR, device='cpu')
+
+    def _create_accumulators(self, block, parameters):
+        assert isinstance(block, framework.Block)
+        if isinstance(parameters, dict):
+            parameters = self._update_param_group(parameters)
+
+        # Create accumulator tensors for first and second moments
+        for p in parameters:
+            if self._multi_precision and p.dtype == core.VarDesc.VarType.FP16:
+                master_p = self._create_master_weight(p)
+                self._add_moments_pows(master_p)
+                continue
+            if p.dtype == core.VarDesc.VarType.FP16 and not self._multi_precision:
+                warnings.warn(
+                    "Accumulating with FP16 in optimizer can lead to poor accuracy or slow convergence."
+                    "Consider using multi_precision=True option of the Adam optimizer."
+                )
+            self._add_moments_pows(p)

     def _append_optimize_op(self, block, param_and_grad):
         assert isinstance(block, framework.Block)
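A note on `_add_param_group` above: every option missing from a group dict is filled in from `_default_dict` via `setdefault`, and a `set` of parameters is rejected because ordering matters. A small sketch of the resolved groups (it peeks at the private `_param_groups` attribute purely for illustration; the expected values follow from the code above):

import paddle

linear_1 = paddle.nn.Linear(4, 4)
linear_2 = paddle.nn.Linear(4, 4)
opt = paddle.optimizer.AdamW(
    learning_rate=0.1,
    parameters=[{
        'params': linear_1.parameters()        # inherits the defaults below
    }, {
        'params': linear_2.parameters(),
        'weight_decay': 0.0,                   # per-group override
        'beta1': 0.8
    }],
    weight_decay=0.01,
    beta1=0.9)

for group in opt._param_groups:
    print(group['weight_decay'], group['beta1'])
# expected under this sketch: 0.01 0.9 for the first group, 0.0 0.8 for the second

# passing an unordered collection is rejected by _add_param_group:
#   paddle.optimizer.AdamW(parameters=[{'params': set(linear_1.parameters())}])
#   -> TypeError("optimizer parameters should be in ordered collections, ...")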
@@ -295,8 +447,9 @@ class AdamW(Adam):
             _, _, _, _, _, _ = _C_ops.final_state_adamw(
                 param_and_grad[0], param_and_grad[1], lr, moment1, moment2,
                 beta1_pow_acc, beta2_pow_acc, master_weight, found_inf,
-                _beta1, _beta2, self._epsilon, lr_ratio_, self._coeff,
-                with_decay, self._lazy_mode, 1000, find_master, False)
+                _beta1, _beta2, self._epsilon, lr_ratio_,
+                self._weight_decay, with_decay, self._lazy_mode, 1000,
+                find_master, False)
         else:
             _, _, _, _, _, _ = _C_ops.adamw(
                 param_and_grad[0], param_and_grad[1], lr, moment1, moment2,
@@ -306,8 +459,8 @@ class AdamW(Adam):
                 'lazy_mode', self._lazy_mode,
                 'min_row_size_to_use_multithread', 1000, 'beta1', _beta1,
                 'beta2', _beta2, "with_decay", with_decay, 'coeff',
-                self._coeff, 'multi_precision', find_master, 'lr_ratio',
-                lr_ratio_)
+                self._weight_decay, 'multi_precision', find_master,
+                'lr_ratio', lr_ratio_)

             return None

         inputs = {
@@ -338,7 +491,7 @@ class AdamW(Adam):
             "min_row_size_to_use_multithread": 1000,
             "multi_precision": find_master,
             "with_decay": with_decay,
-            "coeff": self._coeff,
+            "coeff": self._weight_decay,
             "lr_ratio": 1.
             if self._lr_ratio is None else self._lr_ratio(param_and_grad[0])
         }
@@ -369,17 +522,96 @@ class AdamW(Adam):

         return adamw_op

-    def _create_optimization_pass(self, parameters_and_grads):
-        optimize_ops = super(
-            AdamW, self)._create_optimization_pass(parameters_and_grads)
-        # In dygraph mode, clear _lr_to_coeff after applied gradient
-        self._lr_to_coeff = dict()
-        return optimize_ops
-
     def __str__(self):
         return " ".join(["Weight Decay, params:", ",".join(self._params_name)])

+    @imperative_base.no_grad
+    @framework.dygraph_only
+    def step(self):
+        """
+        Execute the optimizer and update parameters once.
+
+        Returns:
+            None
+
+        Examples:
+            .. code-block:: python
+
+                import paddle
+
+                a = paddle.rand([2,13], dtype="float32")
+                linear = paddle.nn.Linear(13, 5)
+                # This can be any optimizer supported by dygraph.
+                opt = paddle.optimizer.AdamW(learning_rate = 0.01,
+                                             parameters = linear.parameters())
+                out = linear(a)
+                out.backward()
+                opt.step()
+                opt.clear_grad()
+        """
+        if not isinstance(self._parameter_list[0], dict):
+            params_grads = []
+            for param in self._parameter_list:
+                if param.stop_gradient:
+                    continue
+                if param._grad_ivar() is not None:
+                    grad_var = param._grad_ivar()
+                    if framework.in_dygraph_mode():
+                        if hasattr(grad_var, "is_selected_rows"
+                                   ) and grad_var.is_selected_rows(
+                                   ) and self.regularization is not None:
+                            raise RuntimeError(
+                                "AdamW don't support weight_decay with sparse parameters, please set it to None."
+                            )
+                    else:
+                        if hasattr(grad_var,
+                                   "_is_sparse") and grad_var._is_sparse(
+                                   ) and self.regularization is not None:
+                            raise RuntimeError(
+                                "AdamW don't support weight_decay with sparse parameters, please set it to None."
+                            )
+                    params_grads.append((param, grad_var))
+
+            optimize_ops = self._apply_optimize(
+                loss=None, startup_program=None, params_grads=params_grads)
+        else:
+            # optimize parameters in groups
+            for param_group in self._param_groups:
+                params_grads = defaultdict(lambda: list())
+                for param in param_group['params']:
+                    if param.stop_gradient:
+                        continue
+                    if param._grad_ivar() is not None:
+                        grad_var = param._grad_ivar()
+                        if framework.in_dygraph_mode():
+                            if hasattr(grad_var, "is_selected_rows"
+                                       ) and grad_var.is_selected_rows(
+                                       ) and self.regularization is not None:
+                                raise RuntimeError(
+                                    "AdamW don't support weight_decay with sparse parameters, please set it to None."
+                                )
+                        else:
+                            if hasattr(grad_var,
+                                       "_is_sparse") and grad_var._is_sparse(
+                                       ) and self.regularization is not None:
+                                raise RuntimeError(
+                                    "AdamW don't support weight_decay with sparse parameters, please set it to None."
+                                )
+                        params_grads['params'].append((param, grad_var))
+                params_grads.update(
+                    {k: v
+                     for k, v in param_group.items() if k != 'params'})
+                self._apply_optimize(
+                    loss=None, startup_program=None, params_grads=params_grads)
+
     def _update_param_group(self, parameters):
-        self._coeff = parameters.get('coeff', self._default_dict['coeff'])
+        self._beta1 = parameters.get('beta1', self._default_dict['beta1'])
+        self._beta2 = parameters.get('beta2', self._default_dict['beta2'])
+        self._epsilon = parameters.get('epsilon', self._default_dict['epsilon'])
+        self._lazy_mode = parameters.get('lazy_mode',
+                                         self._default_dict['lazy_mode'])
+        self._weight_decay = parameters.get('weight_decay',
+                                            self._default_dict['weight_decay'])
         parameters = parameters.get('params')
         return parameters
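Finally, the `multi_precision` path exercised by `TestAdamWOpMultiPrecison` relies on `_create_master_weight` and `_add_moments_pows` above to keep FP32 master copies and FP32 moments for FP16 parameters. A stripped-down dygraph AMP loop of the same shape as the test, assuming a CUDA build (on CPU the plain FP32 branch of the test applies):

import paddle

paddle.set_device('gpu')                            # assumption: CUDA build available
model = paddle.nn.Linear(5, 5)
opt = paddle.optimizer.AdamW(parameters=model.parameters(),
                             weight_decay=0.01,
                             multi_precision=True)  # keep FP32 master weights
model = paddle.amp.decorate(models=model, level='O2')
scaler = paddle.amp.GradScaler(init_loss_scaling=1024)

with paddle.amp.auto_cast(level='O2'):
    loss = paddle.mean(model(paddle.randn((5, 5))))
scaler.scale(loss).backward()
scaler.step(opt)                                    # unscales, then runs the AdamW update
opt.clear_grad()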