From 77bae9a45b4870006b1f3b12ee9ffdc319864a89 Mon Sep 17 00:00:00 2001 From: Guoxia Wang Date: Wed, 1 Jun 2022 11:36:24 +0800 Subject: [PATCH] fix the bug of adamw which set the attribute in param group not working (#43013) * fix the bug of adamw which set the attribute in param group not working * fix undefined variable * fix api example typo * add unittest * fix unittest typo --- .../fluid/tests/unittests/test_adamw_op.py | 109 +++++ python/paddle/optimizer/adamw.py | 412 ++++++++++++++---- 2 files changed, 431 insertions(+), 90 deletions(-) diff --git a/python/paddle/fluid/tests/unittests/test_adamw_op.py b/python/paddle/fluid/tests/unittests/test_adamw_op.py index 3e2f112e96..225bd35a8e 100644 --- a/python/paddle/fluid/tests/unittests/test_adamw_op.py +++ b/python/paddle/fluid/tests/unittests/test_adamw_op.py @@ -271,6 +271,115 @@ class TestAdamWOpGroup(TestAdamWOp): adam.clear_gradients() +class TestAdamWOpMultiPrecison(unittest.TestCase): + def _test_adamw_op_dygraph_place_amp(self, place, use_amp=False): + paddle.disable_static() + paddle.seed(10) + paddle.set_device(place) + + input = paddle.randn((5, 5)) + + model = paddle.nn.Linear(5, 5) + + optimizer = paddle.optimizer.AdamW( + parameters=[{ + 'params': model.parameters(), + 'weight_decay': 0.001, + 'beta1': 0.1, + 'beta2': 0.99 + }], + multi_precision=use_amp) + + for idx in range(2): + if place == 'gpu' and use_amp == True: + model = paddle.amp.decorate(models=model, level='O2') + scaler = paddle.amp.GradScaler(init_loss_scaling=1024) + + if place == 'gpu' and use_amp == True: + with paddle.amp.auto_cast(level='O2'): + output = model(input) + loss = paddle.mean(output) + scaled = scaler.scale(loss) + scaled.backward() + scaler.step(optimizer) + optimizer.clear_grad() + else: + output = model(input) + loss = paddle.mean(output) + loss.backward() + optimizer.step() + optimizer.clear_grad() + + def _get_places(self): + places = ['cpu'] + if paddle.is_compiled_with_cuda(): + places.append('gpu') + return places + + def test_main(self): + for place in self._get_places(): + use_amp_list = [True, False] + for use_amp in use_amp_list: + self._test_adamw_op_dygraph_place_amp(place, use_amp) + + +class TestAdamWOpError(unittest.TestCase): + def test_api_errors(self): + def test_weight_decay_dtype(): + linear = paddle.nn.Linear(13, 5) + adam = paddle.optimizer.AdamW( + learning_rate=0.01, + parameters=linear.parameters(), + weight_decay=1) + + def test_parameters_dtype1(): + adam = paddle.optimizer.AdamW( + learning_rate=0.01, + parameters=paddle.randn((5, 5)), + weight_decay=0.1) + + def test_parameters_dtype2(): + linear = paddle.nn.Linear(13, 5) + adam = paddle.optimizer.AdamW( + learning_rate=0.01, + parameters={'params': linear.parameters()}, + weight_decay=0.1) + + def test_parameters_dtype3(): + adam = paddle.optimizer.AdamW( + learning_rate=0.01, parameters=None, weight_decay=0.1) + + def test_parameters_dtype4(): + linear = paddle.nn.Linear(13, 5) + adam = paddle.optimizer.AdamW( + learning_rate=0.01, + parameters={'params': set(linear.parameters())}, + weight_decay=0.1) + + def test_learning_rate_dtype(): + linear = paddle.nn.Linear(13, 5) + adam = paddle.optimizer.AdamW( + learning_rate=1, + parameters=linear.parameters(), + weight_decay=0.1) + + def test_grad_clip_dtype(): + linear = paddle.nn.Linear(13, 5) + adam = paddle.optimizer.AdamW( + learning_rate=0.01, + parameters=linear.parameters(), + weight_decay=0.1, + grad_clip=0.1) + + self.assertRaises(TypeError, test_weight_decay_dtype) + self.assertRaises(TypeError, test_parameters_dtype1) + self.assertRaises(TypeError, test_parameters_dtype2) + self.assertRaises(AttributeError, test_parameters_dtype3) + self.assertRaises(TypeError, test_parameters_dtype4) + self.assertRaises(TypeError, test_learning_rate_dtype) + self.assertRaises(TypeError, test_grad_clip_dtype) + + class TestAdamWOpGroupWithLR(TestAdamWOp): def test_adamw_op_dygraph(self): paddle.disable_static() diff --git a/python/paddle/optimizer/adamw.py b/python/paddle/optimizer/adamw.py index 0fa49745a9..0b61f3cb9a 100644 --- a/python/paddle/optimizer/adamw.py +++ b/python/paddle/optimizer/adamw.py @@ -12,11 +12,17 @@ # See the License for the specific language governing permissions and # limitations under the License. +import warnings +from collections import defaultdict from .optimizer import Optimizer -from .adam import Adam +from .lr import LRScheduler from ..fluid import core from ..fluid import framework -from ..fluid.framework import Variable +from ..fluid.framework import Variable, Parameter +from ..fluid import unique_name +from ..fluid import layers +from ..fluid.layer_helper import LayerHelper +from ..fluid.clip import GradientClipBase from ..fluid.dygraph import base as imperative_base from collections.abc import Callable from .. import _C_ops @@ -25,7 +31,7 @@ import paddle __all__ = [] -class AdamW(Adam): +class AdamW(Optimizer): r""" The AdamW optimizer is implemented based on the AdamW Optimization in paper `DECOUPLED WEIGHT DECAY REGULARIZATION `_. @@ -102,14 +108,14 @@ class AdamW(Adam): beta1 = paddle.to_tensor([0.9], dtype="float32") beta2 = paddle.to_tensor([0.99], dtype="float32") - adam = paddle.optimizer.AdamW(learning_rate=0.1, + opt = paddle.optimizer.AdamW(learning_rate=0.1, parameters=linear.parameters(), beta1=beta1, beta2=beta2, weight_decay=0.01) out.backward() - adam.step() - adam.clear_grad() + opt.step() + opt.clear_grad() #Note that the learning_rate of linear_2 is 0.01. @@ -119,7 +125,7 @@ class AdamW(Adam): out = linear_1(inp) out = linear_2(out) loss = paddle.mean(out) - adam = paddle.optimizer.AdamW( + opt = paddle.optimizer.AdamW( learning_rate=0.1, parameters=[{ 'params': linear_1.parameters() @@ -132,11 +138,16 @@ class AdamW(Adam): weight_decay=0.01, beta1=0.9) out.backward() - adam.step() - adam.clear_grad() + opt.step() + opt.clear_grad() """ + _moment1_acc_str = "moment1" + _moment2_acc_str = "moment2" + _beta1_pow_acc_str = "beta1_pow_acc" + _beta2_pow_acc_str = "beta2_pow_acc" + def __init__(self, learning_rate=0.001, beta1=0.9, @@ -160,37 +171,108 @@ class AdamW(Adam): raise ValueError("Invaild value of beta2, expect beta2 in [0,1).") if not 0 <= epsilon: raise ValueError("Invaild value of epsilon, expect epsilon >= 0.") - coeff = weight_decay - if not isinstance(coeff, float) and \ - not isinstance(coeff, framework.Variable): - raise TypeError("coeff should be float or Tensor.") - self._params_name = set() - self._apply_decay_param_fun = apply_decay_param_fun - self._coeff = coeff - self._lr_to_coeff = dict() + if not isinstance(weight_decay, float) and \ + not isinstance(weight_decay, framework.Variable): + raise TypeError("weight_decay should be float or Tensor.") if lr_ratio is not None: assert isinstance(lr_ratio, Callable) if not core.is_compiled_with_cuda(): raise NotImplementedError( "'lr_ratio' is unimplemented in CPU, XPU and NPU") - self._lr_ratio = lr_ratio - super(AdamW, self).__init__( - learning_rate=learning_rate, - parameters=parameters, - beta1=beta1, - beta2=beta2, - epsilon=epsilon, - grad_clip=grad_clip, - name=name, - lazy_mode=lazy_mode, - multi_precision=multi_precision) - self._default_dict = {'coeff': coeff} + if parameters is not None: + # paddle.Tensor is also iterable, so here we don't check whether + # the input is iterable, if the input is paddle.Tensor, the + # list(paddle.Tensor) will be a error value + if isinstance(parameters, (paddle.Tensor, core.eager.Tensor)): + raise TypeError( + "`parameters` argument given to the optimizer should be " + "an iterable of paddle Tensors, but got argument type is `{}`.". + format(type(parameters))) + if isinstance(parameters, dict): + raise TypeError( + "`parameters` argument should not get dict type, " + "if parameter groups is needed, please set `parameters`" + " as list of dict") + self._parameter_list = list(parameters) + else: + self._parameter_list = None + + self._name = name + if framework._non_static_mode(): + if self._parameter_list is None: + raise AttributeError( + "parameters argument given to the Optimizer should not be None in dygraph mode." + ) + + if not isinstance(learning_rate, (float, LRScheduler)): + raise TypeError( + "learning rate should be float or LRScheduler, got %s here" % + type(learning_rate)) + if grad_clip is not None: + if not isinstance(grad_clip, GradientClipBase): + raise TypeError( + "'grad_clip' should be an instance of GradientClipBase's derived class" + ) + + self._dtype = None + # Infer the dtype form parameter + if self._parameter_list: + if isinstance(self._parameter_list[0], dict): + for param_group in self._parameter_list: + assert 'params' in param_group, \ + 'params should be set in parameters if parameter groups are optimized in different options' + self._dtype = self._parameter_list[0]['params'][0].dtype + else: + self._dtype = self._parameter_list[0].dtype + + # each program should have a independent learning rate + # program -> tensor(learning_rate) + self._learning_rate_map = dict() + # Dictionary of accumulators. Some optimizer subclasses need to + # allocate and manage extra tensors associated with the parameters + # to train. These tensors are called accumulators. + # {accum_name : { paramter_name : accumulator_for_parameter, ...}, ...} + self._accumulators = defaultdict(lambda: dict()) + self.helper = None + self._opti_name_list = [] + self._accumulators_holder = {} + self._param_device_map = dict() + self.clear_gradients = self.clear_grad self.type = "adamw" + self._learning_rate = learning_rate + self._params_name = set() + self._apply_decay_param_fun = apply_decay_param_fun + self._weight_decay = weight_decay + self._grad_clip = grad_clip + self._lr_ratio = lr_ratio + self._beta1 = beta1 + self._beta2 = beta2 + self._epsilon = epsilon + self._lazy_mode = lazy_mode + self._multi_precision = multi_precision + self._master_weights = {} + + self._default_dict = { + 'weight_decay': weight_decay, + 'beta1': beta1, + 'beta2': beta2, + 'epsilon': epsilon, + 'lazy_mode': lazy_mode, + 'grad_clip': grad_clip + } + + self._param_groups = [] + if self._parameter_list and isinstance(self._parameter_list[0], dict): + for param_group in self._parameter_list: + self._add_param_group(param_group.copy()) + else: + self._param_groups = self._parameter_list - # Use _auxiliary_vars together with _set_auxiliary_var/_get_auxiliary_var to achieve that. - self._auxiliary_vars = dict() + self._use_multi_tensor = None + self.regularization = None + self._auxiliary_vars = {} def _set_auxiliary_var(self, key, val): self._auxiliary_vars[key] = val @@ -201,58 +283,128 @@ class AdamW(Adam): else: return None - def _append_decoupled_weight_decay(self, block, param_and_grad): + def _add_param_group(self, param_group): """ - Add decoupled weight decay op. - parameter = parameter - parameter * coeff * lr + Add a param group to parameter_list. + Args: - block: block in which variable is to be created - param_and_grad: (parameters, gradients) pairs, - the parameters need to decay. - Raises: - Exception: The type of coeff and parameter is not consistent. + param_group (dict): The group of Tensors to be optimzed with + different optimization options. """ - if isinstance(param_and_grad, dict): - param_and_grad = self._update_param_group(param_and_grad) - param, grad = param_and_grad + params = param_group['params'] + if isinstance(params, Parameter): + param_group['params'] = [params] + elif isinstance(params, set): + raise TypeError( + "optimizer parameters should be in ordered collections," + "but received set, please use list instead.") + else: + param_group['params'] = list(params) - if self._apply_decay_param_fun is not None \ - and not self._apply_decay_param_fun(param.name): - return + # Update optimization options for each groups + for k, v in self._default_dict.items(): + param_group.setdefault(k, v) + + param_set = set() + for group in self._param_groups: + param_set.update(set(group['params'])) + + if not param_set.isdisjoint(set(param_group['params'])): + raise ValueError( + "some parameters appear in more than one parameter group") - if isinstance(self._learning_rate, float): - learning_rate = self._learning_rate + for param in param_group['params']: + param.optimize_attr['learning_rate'] = param_group.get( + 'learning_rate', 1.) + + self._param_groups.append(param_group) + + def _create_master_weight(self, param): + if param.name in self._master_weights: + var = self._master_weights[param.name] else: - # NOTE. We add this function to the _append_optimize_op(), - # for we must make sure _create_param_lr() be called after - # optimizer._create_global_learning_rate(). - learning_rate = self._create_param_lr(param_and_grad) - - with block.program._optimized_guard( - [param, grad]), framework.name_scope('weight decay'): - self._params_name.add(param.name) - - # If it has been calculated, the result will be reused. - # NOTE(wangxi): In dygraph mode, apply_gradient will be executed - # every step, so need clear _lr_to_coeff every step, - # we do this in _create_optimization_pass - decay_coeff = self._lr_to_coeff.get(learning_rate, None) - if decay_coeff is None: - # NOTE(wangxi): for pipeline to set device:all - with paddle.static.device_guard(None): - decay_coeff = 1.0 - learning_rate * self._coeff - self._lr_to_coeff[learning_rate] = decay_coeff - - find_master = (self._multi_precision and - param.dtype == core.VarDesc.VarType.FP16) - if find_master: - master_weight = self._master_weights[param.name] - scaled_param = master_weight * decay_coeff - paddle.fluid.layers.assign( - input=scaled_param, output=master_weight) - else: - scaled_param = param * decay_coeff - paddle.fluid.layers.assign(input=scaled_param, output=param) + assert isinstance(self.helper, LayerHelper) + + var_name = param.name + "_fp32_master" + var_name = unique_name.generate(var_name) + var = layers.create_global_var( + name=var_name, + shape=param.shape, + value=0, + dtype='float32', + persistable=True) + block = self.helper.startup_program.global_block() + block.append_op( + type="cast", + inputs={"X": [param]}, + outputs={"Out": [var]}, + attrs={ + "in_dtype": param.dtype, + "out_dtype": core.VarDesc.VarType.FP32 + }) + self._master_weights[param.name] = var + return var + + def _get_accumulator(self, name, param): + """Utility function to fetch an accumulator for a parameter + Args: + name: name of the accumulator + param: parameter variable for which accumulator is to be fetched + Returns: + accumulator variable for the parameter + """ + if self._name is not None: + name = self._name + "_" + name + find_master = self._multi_precision and param.dtype == core.VarDesc.VarType.FP16 + target_param = self._master_weights[ + param.name] if find_master else param + target_name = target_param.name + if (name not in self._accumulators or + target_name not in self._accumulators[name]): + raise Exception("Accumulator {} does not exist for parameter {}". + format(name, target_name)) + return self._accumulators[name][target_name] + + def _add_moments_pows(self, p): + acc_dtype = p.dtype + if acc_dtype == core.VarDesc.VarType.FP16: + acc_dtype = core.VarDesc.VarType.FP32 + self._add_accumulator(self._moment1_acc_str, p, dtype=acc_dtype) + self._add_accumulator(self._moment2_acc_str, p, dtype=acc_dtype) + self._add_accumulator( + name=self._beta1_pow_acc_str, + param=p, + dtype=acc_dtype, + fill_value=0.9 if isinstance(self._beta1, Variable) \ + else self._beta1, + shape=[1], + type=core.VarDesc.VarType.LOD_TENSOR, device='cpu') + self._add_accumulator( + name=self._beta2_pow_acc_str, + param=p, + dtype=acc_dtype, + fill_value=0.999 if isinstance(self._beta2, Variable) \ + else self._beta2, + shape=[1], + type=core.VarDesc.VarType.LOD_TENSOR, device='cpu') + + def _create_accumulators(self, block, parameters): + assert isinstance(block, framework.Block) + if isinstance(parameters, dict): + parameters = self._update_param_group(parameters) + + # Create accumulator tensors for first and second moments + for p in parameters: + if self._multi_precision and p.dtype == core.VarDesc.VarType.FP16: + master_p = self._create_master_weight(p) + self._add_moments_pows(master_p) + continue + if p.dtype == core.VarDesc.VarType.FP16 and not self._multi_precision: + warnings.warn( + "Accumulating with FP16 in optimizer can lead to poor accuracy or slow convergence." + "Consider using multi_precision=True option of the Adam optimizer." + ) + self._add_moments_pows(p) def _append_optimize_op(self, block, param_and_grad): assert isinstance(block, framework.Block) @@ -295,8 +447,9 @@ class AdamW(Adam): _, _, _, _, _, _ = _C_ops.final_state_adamw( param_and_grad[0], param_and_grad[1], lr, moment1, moment2, beta1_pow_acc, beta2_pow_acc, master_weight, found_inf, - _beta1, _beta2, self._epsilon, lr_ratio_, self._coeff, - with_decay, self._lazy_mode, 1000, find_master, False) + _beta1, _beta2, self._epsilon, lr_ratio_, + self._weight_decay, with_decay, self._lazy_mode, 1000, + find_master, False) else: _, _, _, _, _, _ = _C_ops.adamw( param_and_grad[0], param_and_grad[1], lr, moment1, moment2, @@ -306,8 +459,8 @@ class AdamW(Adam): 'lazy_mode', self._lazy_mode, 'min_row_size_to_use_multithread', 1000, 'beta1', _beta1, 'beta2', _beta2, "with_decay", with_decay, 'coeff', - self._coeff, 'multi_precision', find_master, 'lr_ratio', - lr_ratio_) + self._weight_decay, 'multi_precision', find_master, + 'lr_ratio', lr_ratio_) return None inputs = { @@ -338,7 +491,7 @@ class AdamW(Adam): "min_row_size_to_use_multithread": 1000, "multi_precision": find_master, "with_decay": with_decay, - "coeff": self._coeff, + "coeff": self._weight_decay, "lr_ratio": 1. if self._lr_ratio is None else self._lr_ratio(param_and_grad[0]) } @@ -369,17 +522,96 @@ class AdamW(Adam): return adamw_op - def _create_optimization_pass(self, parameters_and_grads): - optimize_ops = super( - AdamW, self)._create_optimization_pass(parameters_and_grads) - # In dygraph mode, clear _lr_to_coeff after applied gradient - self._lr_to_coeff = dict() - return optimize_ops - def __str__(self): return " ".join(["Weight Decay, params:", ",".join(self._params_name)]) + @imperative_base.no_grad + @framework.dygraph_only + def step(self): + """ + Execute the optimizer and update parameters once. + + Returns: + None + + Examples: + .. code-block:: python + + import paddle + + a = paddle.rand([2,13], dtype="float32") + linear = paddle.nn.Linear(13, 5) + # This can be any optimizer supported by dygraph. + opt = paddle.optimizer.AdamW(learning_rate = 0.01, + parameters = linear.parameters()) + out = linear(a) + out.backward() + opt.step() + opt.clear_grad() + """ + if not isinstance(self._parameter_list[0], dict): + params_grads = [] + for param in self._parameter_list: + if param.stop_gradient: + continue + if param._grad_ivar() is not None: + grad_var = param._grad_ivar() + if framework.in_dygraph_mode(): + if hasattr(grad_var, "is_selected_rows" + ) and grad_var.is_selected_rows( + ) and self.regularization is not None: + raise RuntimeError( + "AdamW don't support weight_decay with sparse parameters, please set it to None." + ) + else: + if hasattr(grad_var, + "_is_sparse") and grad_var._is_sparse( + ) and self.regularization is not None: + raise RuntimeError( + "AdamW don't support weight_decay with sparse parameters, please set it to None." + ) + params_grads.append((param, grad_var)) + + optimize_ops = self._apply_optimize( + loss=None, startup_program=None, params_grads=params_grads) + else: + # optimize parameters in groups + for param_group in self._param_groups: + params_grads = defaultdict(lambda: list()) + for param in param_group['params']: + if param.stop_gradient: + continue + if param._grad_ivar() is not None: + grad_var = param._grad_ivar() + if framework.in_dygraph_mode(): + if hasattr(grad_var, "is_selected_rows" + ) and grad_var.is_selected_rows( + ) and self.regularization is not None: + raise RuntimeError( + "AdamW don't support weight_decay with sparse parameters, please set it to None." + ) + else: + if hasattr(grad_var, + "_is_sparse") and grad_var._is_sparse( + ) and self.regularization is not None: + raise RuntimeError( + "AdamW don't support weight_decay with sparse parameters, please set it to None." + ) + params_grads['params'].append((param, grad_var)) + params_grads.update( + {k: v + for k, v in param_group.items() if k != 'params'}) + self._apply_optimize( + loss=None, startup_program=None, params_grads=params_grads) + def _update_param_group(self, parameters): - self._coeff = parameters.get('coeff', self._default_dict['coeff']) + self._beta1 = parameters.get('beta1', self._default_dict['beta1']) + self._beta2 = parameters.get('beta2', self._default_dict['beta2']) + self._epsilon = parameters.get('epsilon', self._default_dict['epsilon']) + self._lazy_mode = parameters.get('lazy_mode', + self._default_dict['lazy_mode']) + self._weight_decay = parameters.get('weight_decay', + self._default_dict['weight_decay']) parameters = parameters.get('params') + return parameters -- GitLab