From 1e232e27fc18d51c084d992ac51563e6dea4bb52 Mon Sep 17 00:00:00 2001 From: wanghuancoder Date: Wed, 15 Mar 2023 10:26:18 +0800 Subject: [PATCH] refine amp scaler (#51340) * refine _found_inf --- python/paddle/amp/grad_scaler.py | 56 ++++++++++++++----- .../ascend/ascend_optimizer.py | 4 ++ .../hybrid_parallel_gradscaler.py | 13 +++-- .../meta_optimizers/meta_optimizer_base.py | 4 ++ .../group_sharded_optimizer_stage2.py | 4 ++ .../sharding/group_sharded_utils.py | 27 +++++---- python/paddle/distributed/fleet/scaler.py | 16 ++++-- python/paddle/fluid/optimizer.py | 17 ++++-- python/paddle/incubate/optimizer/lookahead.py | 4 ++ python/paddle/optimizer/adadelta.py | 4 ++ python/paddle/optimizer/adagrad.py | 4 ++ python/paddle/optimizer/adam.py | 45 +++++++++------ python/paddle/optimizer/adamax.py | 4 ++ python/paddle/optimizer/adamw.py | 8 ++- python/paddle/optimizer/lamb.py | 8 ++- python/paddle/optimizer/momentum.py | 41 +++++++++----- python/paddle/optimizer/optimizer.py | 56 +++++++++++-------- python/paddle/optimizer/rmsprop.py | 5 ++ python/paddle/optimizer/sgd.py | 3 + 19 files changed, 223 insertions(+), 100 deletions(-) diff --git a/python/paddle/amp/grad_scaler.py b/python/paddle/amp/grad_scaler.py index 85e6f6efc6b..96c3c671262 100644 --- a/python/paddle/amp/grad_scaler.py +++ b/python/paddle/amp/grad_scaler.py @@ -18,7 +18,7 @@ from enum import Enum import numpy as np -from paddle import _legacy_C_ops +from paddle import _C_ops, _legacy_C_ops from paddle.fluid import core, in_dygraph_mode from paddle.fluid.data_feeder import check_type from paddle.fluid.dygraph import to_variable @@ -131,6 +131,9 @@ class AmpScaler: self._use_dynamic_loss_scaling = use_dynamic_loss_scaling self._found_inf = to_variable(np.array([0]).astype(np.bool_)) + self._temp_found_inf_value_false = to_variable( + np.array([0]).astype(np.bool_) + ) self._temp_found_inf_fp16 = to_variable( np.array([0]).astype(np.bool_) ) @@ -228,11 +231,16 @@ class AmpScaler: optimize_ops, params_grads = (None, None) - if self._found_inf: - self._cache_founf_inf = True - else: + if hasattr(optimizer, "_set_auxiliary_var"): + optimizer._set_auxiliary_var('found_inf', self._found_inf) optimize_ops, params_grads = optimizer.minimize(*args, **kwargs) - self._cache_founf_inf = False + self._cache_founf_inf = optimizer._get_auxiliary_var('found_inf') + else: + if self._found_inf: + self._cache_founf_inf = True + else: + optimize_ops, params_grads = optimizer.minimize(*args, **kwargs) + self._cache_founf_inf = False if self._use_dynamic_loss_scaling: # uopdate the scale @@ -318,6 +326,7 @@ class AmpScaler: for param in param_grads if param.dtype == core.VarDesc.VarType.FP32 ] + self._found_inf = self._temp_found_inf_value_false if core.is_compiled_with_npu(): float_status = _legacy_C_ops.alloc_float_status() _legacy_C_ops.clear_float_status(float_status, float_status) @@ -330,6 +339,9 @@ class AmpScaler: param_grads_fp16, self._temp_found_inf_fp16, ) + self._found_inf = _C_ops.bitwise_or( + self._found_inf, self._temp_found_inf_fp16 + ) if len(param_grads_bf16): _legacy_C_ops.check_finite_and_unscale( param_grads_bf16, @@ -338,6 +350,9 @@ class AmpScaler: param_grads_bf16, self._temp_found_inf_bf16, ) + self._found_inf = _C_ops.bitwise_or( + self._found_inf, self._temp_found_inf_bf16 + ) if len(param_grads_fp32): _legacy_C_ops.check_finite_and_unscale( param_grads_fp32, @@ -346,6 +361,9 @@ class AmpScaler: param_grads_fp32, self._temp_found_inf_fp32, ) + self._found_inf = _C_ops.bitwise_or( + self._found_inf, 
self._temp_found_inf_fp32 + ) else: if len(param_grads_fp16): _legacy_C_ops.check_finite_and_unscale( @@ -354,6 +372,9 @@ class AmpScaler: param_grads_fp16, self._temp_found_inf_fp16, ) + self._found_inf = _C_ops.bitwise_or( + self._found_inf, self._temp_found_inf_fp16 + ) if len(param_grads_bf16): _legacy_C_ops.check_finite_and_unscale( param_grads_bf16, @@ -361,6 +382,9 @@ class AmpScaler: param_grads_bf16, self._temp_found_inf_bf16, ) + self._found_inf = _C_ops.bitwise_or( + self._found_inf, self._temp_found_inf_bf16 + ) if len(param_grads_fp32): _legacy_C_ops.check_finite_and_unscale( param_grads_fp32, @@ -368,12 +392,9 @@ class AmpScaler: param_grads_fp32, self._temp_found_inf_fp32, ) - - self._found_inf = ( - self._temp_found_inf_fp16 - or self._temp_found_inf_bf16 - or self._temp_found_inf_fp32 - ) + self._found_inf = _C_ops.bitwise_or( + self._found_inf, self._temp_found_inf_fp32 + ) optimizer_state["state"] = OptimizerState.UNSCALED @@ -761,11 +782,16 @@ class GradScaler(AmpScaler): if optimizer_state["state"] is OptimizerState.INIT: self._unscale(optimizer) - if self._found_inf: - self._cache_founf_inf = True - else: + if hasattr(optimizer, "_set_auxiliary_var"): + optimizer._set_auxiliary_var('found_inf', self._found_inf) optimizer.step() - self._cache_founf_inf = False + self._cache_founf_inf = optimizer._get_auxiliary_var('found_inf') + else: + if self._found_inf: + self._cache_founf_inf = True + else: + optimizer.step() + self._cache_founf_inf = False optimizer_state["state"] = OptimizerState.STEPPED diff --git a/python/paddle/distributed/fleet/meta_optimizers/ascend/ascend_optimizer.py b/python/paddle/distributed/fleet/meta_optimizers/ascend/ascend_optimizer.py index 6db42eb47be..b7d22882c82 100644 --- a/python/paddle/distributed/fleet/meta_optimizers/ascend/ascend_optimizer.py +++ b/python/paddle/distributed/fleet/meta_optimizers/ascend/ascend_optimizer.py @@ -236,6 +236,10 @@ class AscendOptimizer(Optimizer): ret_list.append(var) return ret_list + def _set_auxiliary_var(self, key, val): + super()._set_auxiliary_var(key, val) + self.inner_opt._set_auxiliary_var(key, val) + def minimize( self, loss, diff --git a/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/hybrid_parallel_gradscaler.py b/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/hybrid_parallel_gradscaler.py index 144dc8b6586..4924d523ded 100644 --- a/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/hybrid_parallel_gradscaler.py +++ b/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/hybrid_parallel_gradscaler.py @@ -41,11 +41,16 @@ class HybridParallelGradScaler: optimize_ops, params_grads = (None, None) - if self._found_inf: - self._cache_founf_inf = True - else: + if hasattr(optimizer, "_set_auxiliary_var"): + optimizer._set_auxiliary_var('found_inf', self._found_inf) optimize_ops, params_grads = optimizer.minimize(*args, **kwargs) - self._cache_founf_inf = False + self._cache_founf_inf = optimizer._get_auxiliary_var('found_inf') + else: + if self._found_inf: + self._cache_founf_inf = True + else: + optimize_ops, params_grads = optimizer.minimize(*args, **kwargs) + self._cache_founf_inf = False if self._use_dynamic_loss_scaling: self._update() diff --git a/python/paddle/distributed/fleet/meta_optimizers/meta_optimizer_base.py b/python/paddle/distributed/fleet/meta_optimizers/meta_optimizer_base.py index 87085a322c3..9a7660ebd7d 100644 --- a/python/paddle/distributed/fleet/meta_optimizers/meta_optimizer_base.py +++ 
b/python/paddle/distributed/fleet/meta_optimizers/meta_optimizer_base.py @@ -25,6 +25,10 @@ class MetaOptimizerBase(Optimizer): self.meta_optimizers_white_list = [] self.meta_optimizers_black_list = [] + def _set_auxiliary_var(self, key, val): + super()._set_auxiliary_var(key, val) + self.inner_opt._set_auxiliary_var(key, val) + def _set_basic_info( self, loss, role_maker, user_defined_optimizer, user_defined_strategy ): diff --git a/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_optimizer_stage2.py b/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_optimizer_stage2.py index 00ec12a523f..639bdf79ac9 100644 --- a/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_optimizer_stage2.py +++ b/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_optimizer_stage2.py @@ -203,6 +203,10 @@ class GroupShardedOptimizerStage2(Optimizer): # Update optimizer parameters and adjust parameter storage and use according to rank. self._update_opt_status() + def _set_auxiliary_var(self, key, val): + super()._set_auxiliary_var(key, val) + self._optim._set_auxiliary_var(key, val) + @paddle.autograd.no_grad() def _sync_params_and_buffers(self): """ diff --git a/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_utils.py b/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_utils.py index 5acdd51af90..27c508b279f 100644 --- a/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_utils.py +++ b/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_utils.py @@ -19,10 +19,10 @@ from types import MethodType import numpy as np import paddle -from paddle import _legacy_C_ops +from paddle import _C_ops, _legacy_C_ops from paddle.common_ops_import import dygraph_only +from paddle.fluid import core from paddle.fluid.dygraph import to_variable -from paddle.framework import core from paddle.nn import clip @@ -262,6 +262,7 @@ def GroupShardedScaler(scaler): 0 if device == "cpu" else int(paddle.get_device().split(":")[1]) ) + self._found_inf = self._temp_found_inf_value_false with device_guard(dev_id, device): if len(param_grads_bfp16): _legacy_C_ops.check_finite_and_unscale( @@ -270,6 +271,9 @@ def GroupShardedScaler(scaler): param_grads_bfp16, temp_found_inf_bfp16, ) + self._found_inf = _C_ops.bitwise_or( + self._found_inf, temp_found_inf_bfp16 + ) if len(param_grads_fp16): _legacy_C_ops.check_finite_and_unscale( param_grads_fp16, @@ -277,6 +281,9 @@ def GroupShardedScaler(scaler): param_grads_fp16, temp_found_inf_fp16, ) + self._found_inf = _C_ops.bitwise_or( + self._found_inf, temp_found_inf_fp16 + ) if len(param_grads_fp32): _legacy_C_ops.check_finite_and_unscale( param_grads_fp32, @@ -284,21 +291,17 @@ def GroupShardedScaler(scaler): param_grads_fp32, temp_found_inf_fp32, ) + self._found_inf = _C_ops.bitwise_or( + self._found_inf, temp_found_inf_fp32 + ) - self._found_inf = ( - 1 - if temp_found_inf_bfp16 - or temp_found_inf_fp16 - or temp_found_inf_fp32 - else 0 - ) - is_found_inf = paddle.to_tensor([self._found_inf], dtype="int32") + self._found_inf = self._found_inf.cast("int32") paddle.distributed.all_reduce( - is_found_inf, op=paddle.distributed.ReduceOp.SUM, group=None + self._found_inf, op=paddle.distributed.ReduceOp.MAX, group=None ) - self._found_inf = is_found_inf.numpy()[0] + self._found_inf = self._found_inf.cast("bool") scaler._unscale = MethodType(unscale_method, scaler) return scaler diff --git a/python/paddle/distributed/fleet/scaler.py 
b/python/paddle/distributed/fleet/scaler.py index 003265a8612..e210e495bee 100755 --- a/python/paddle/distributed/fleet/scaler.py +++ b/python/paddle/distributed/fleet/scaler.py @@ -17,7 +17,7 @@ from types import MethodType import numpy as np import paddle -from paddle import _legacy_C_ops +from paddle import _C_ops, _legacy_C_ops from paddle.distributed import fleet from paddle.fluid.dygraph import to_variable from paddle.framework import core @@ -66,6 +66,7 @@ def distributed_scaler(scaler): ] temp_found_inf_fp16 = to_variable(np.array([0]).astype(np.bool_)) temp_found_inf_fp32 = to_variable(np.array([0]).astype(np.bool_)) + self._found_inf = self._temp_found_inf_value_false if len(param_grads_fp16): _legacy_C_ops.check_finite_and_unscale( param_grads_fp16, @@ -73,6 +74,9 @@ def distributed_scaler(scaler): param_grads_fp16, temp_found_inf_fp16, ) + self._found_inf = _C_ops.bitwise_or( + self._found_inf, temp_found_inf_fp16 + ) if len(param_grads_fp32): _legacy_C_ops.check_finite_and_unscale( param_grads_fp32, @@ -80,17 +84,19 @@ def distributed_scaler(scaler): param_grads_fp32, temp_found_inf_fp32, ) + self._found_inf = _C_ops.bitwise_or( + self._found_inf, temp_found_inf_fp32 + ) - self._found_inf = 1 if temp_found_inf_fp16 or temp_found_inf_fp32 else 0 - is_found_inf = paddle.to_tensor([self._found_inf], dtype="int32") + self._found_inf = self._found_inf.cast("int32") # TODO(shenliang03) Since dp allreduce in the optimizer is # after the gradscaler, check_finite needs to synchronize global # information. In the future, we should use check_group to speed. paddle.distributed.all_reduce( - is_found_inf, op=paddle.distributed.ReduceOp.MAX, group=None + self._found_inf, op=paddle.distributed.ReduceOp.MAX, group=None ) - self._found_inf = is_found_inf.numpy()[0] + self._found_inf = self._found_inf.cast("bool") # Only data_parallel doesn't need to modify scaler fleet_env = fleet.fleet diff --git a/python/paddle/fluid/optimizer.py b/python/paddle/fluid/optimizer.py index 35b677f49ef..e1051db52b4 100755 --- a/python/paddle/fluid/optimizer.py +++ b/python/paddle/fluid/optimizer.py @@ -970,11 +970,18 @@ class Optimizer: self._create_global_learning_rate() if in_dygraph_mode(): - for param_and_grad in parameters_and_grads: - if param_and_grad[1] is None: - continue - if param_and_grad[0].trainable is True: - self._append_optimize_op(target_block, param_and_grad) + found_inf = self._get_auxiliary_var('found_inf') + if found_inf: + if isinstance(found_inf, core.eager.Tensor): + self._set_auxiliary_var('found_inf', True) + else: + if isinstance(found_inf, core.eager.Tensor): + self._set_auxiliary_var('found_inf', False) + for param_and_grad in parameters_and_grads: + if param_and_grad[1] is None: + continue + if param_and_grad[0].trainable is True: + self._append_optimize_op(target_block, param_and_grad) else: for param_and_grad in parameters_and_grads: if param_and_grad[1] is None: diff --git a/python/paddle/incubate/optimizer/lookahead.py b/python/paddle/incubate/optimizer/lookahead.py index b1ad5f3ecb0..bfa08c40556 100644 --- a/python/paddle/incubate/optimizer/lookahead.py +++ b/python/paddle/incubate/optimizer/lookahead.py @@ -144,6 +144,10 @@ class LookAhead(Optimizer): self._global_step_var = None self._k_var = None + def _set_auxiliary_var(self, key, val): + super()._set_auxiliary_var(key, val) + self.inner_optimizer._set_auxiliary_var(key, val) + @framework.dygraph_only @imperative_base.no_grad def step(self): diff --git a/python/paddle/optimizer/adadelta.py 
b/python/paddle/optimizer/adadelta.py index d0d8d917e70..1cdb61f698e 100644 --- a/python/paddle/optimizer/adadelta.py +++ b/python/paddle/optimizer/adadelta.py @@ -149,12 +149,15 @@ class Adadelta(Optimizer): parameters = parameters.get('params') for p in parameters: + if p.name in self._already_create_accumulater: + continue if self._multi_precision and self._is_dtype_fp16_or_bf16(p.dtype): master_p = self._create_master_weight(p) self._add_accumulator(self._avg_squared_grad_acc_str, master_p) self._add_accumulator( self._avg_squared_update_acc_str, master_p ) + self._already_create_accumulater.add(p.name) continue if ( self._is_dtype_fp16_or_bf16(p.dtype) @@ -166,6 +169,7 @@ class Adadelta(Optimizer): ) self._add_accumulator(self._avg_squared_grad_acc_str, p) self._add_accumulator(self._avg_squared_update_acc_str, p) + self._already_create_accumulater.add(p.name) def _append_optimize_op(self, block, param_and_grad): if isinstance(param_and_grad, dict): diff --git a/python/paddle/optimizer/adagrad.py b/python/paddle/optimizer/adagrad.py index 1052eefb22c..3d2935c7407 100644 --- a/python/paddle/optimizer/adagrad.py +++ b/python/paddle/optimizer/adagrad.py @@ -142,9 +142,12 @@ class Adagrad(Optimizer): parameters = self._update_param_group(parameters) for p in parameters: + if p.name in self._already_create_accumulater: + continue if self._multi_precision and self._is_dtype_fp16_or_bf16(p.dtype): master_p = self._create_master_weight(p) self._add_accumulator(self._moment_acc_str, master_p) + self._already_create_accumulater.add(p.name) continue if ( self._is_dtype_fp16_or_bf16(p.dtype) @@ -159,6 +162,7 @@ class Adagrad(Optimizer): p, fill_value=self.initial_accumulator_value, ) + self._already_create_accumulater.add(p.name) def _append_optimize_op(self, block, param_and_grad): assert isinstance(block, framework.Block) diff --git a/python/paddle/optimizer/adam.py b/python/paddle/optimizer/adam.py index ff1ff74e398..3fd3677c317 100644 --- a/python/paddle/optimizer/adam.py +++ b/python/paddle/optimizer/adam.py @@ -260,9 +260,12 @@ class Adam(Optimizer): # Create accumulator tensors for first and second moments for p in parameters: + if p.name in self._already_create_accumulater: + continue if self._multi_precision and self._is_dtype_fp16_or_bf16(p.dtype): master_p = self._create_master_weight(p) self._add_moments_pows(master_p) + self._already_create_accumulater.add(p.name) continue if ( self._is_dtype_fp16_or_bf16(p.dtype) @@ -273,6 +276,7 @@ class Adam(Optimizer): "Consider using multi_precision=True option of the Adam optimizer." 
) self._add_moments_pows(p) + self._already_create_accumulater.add(p.name) def _append_optimize_op(self, block, param_and_grad): assert isinstance(block, framework.Block) @@ -303,8 +307,6 @@ class Adam(Optimizer): # create the adam optimize op if framework.in_dygraph_mode(): - found_inf = self._get_auxiliary_var('found_inf') - _beta1 = ( self._beta1 if not isinstance(self._beta1, Variable) @@ -325,7 +327,7 @@ class Adam(Optimizer): beta1_pow_acc, beta2_pow_acc, master_weight, - found_inf, + None, _beta1, _beta2, self._epsilon, @@ -636,21 +638,28 @@ class Adam(Optimizer): if master_weight is not None else None ) - _, _, _, _, _, _ = _C_ops.merged_adam_( - self._param_dict[key][param_group_idx], - grad_dict[key], - lr_dict[key], - self._moment1_dict[key][param_group_idx], - self._moment2_dict[key][param_group_idx], - self._beta1_pow_acc_dict[key][param_group_idx], - self._beta2_pow_acc_dict[key][param_group_idx], - master_weight, - _beta1, - _beta2, - self._epsilon, - find_master, - False, - ) + found_inf = self._get_auxiliary_var('found_inf') + if found_inf: + if isinstance(found_inf, core.eager.Tensor): + self._set_auxiliary_var('found_inf', True) + else: + if isinstance(found_inf, core.eager.Tensor): + self._set_auxiliary_var('found_inf', False) + _, _, _, _, _, _ = _C_ops.merged_adam_( + self._param_dict[key][param_group_idx], + grad_dict[key], + lr_dict[key], + self._moment1_dict[key][param_group_idx], + self._moment2_dict[key][param_group_idx], + self._beta1_pow_acc_dict[key][param_group_idx], + self._beta2_pow_acc_dict[key][param_group_idx], + master_weight, + _beta1, + _beta2, + self._epsilon, + find_master, + False, + ) else: inputs = { "Param": self._param_dict[key][param_group_idx], diff --git a/python/paddle/optimizer/adamax.py b/python/paddle/optimizer/adamax.py index e7a7c6d0d2e..bc33c392a4c 100644 --- a/python/paddle/optimizer/adamax.py +++ b/python/paddle/optimizer/adamax.py @@ -195,9 +195,12 @@ class Adamax(Optimizer): # Create accumulator tensors for first moment and infinity norm for p in parameters: + if p.name in self._already_create_accumulater: + continue if self._multi_precision and self._is_dtype_fp16_or_bf16(p.dtype): master_p = self._create_master_weight(p) self._add_moments_pows(master_p) + self._already_create_accumulater.add(p.name) continue if ( self._is_dtype_fp16_or_bf16(p.dtype) @@ -208,6 +211,7 @@ class Adamax(Optimizer): "Consider using multi_precision=True option of the Adam optimizer." 
) self._add_moments_pows(p) + self._already_create_accumulater.add(p.name) def _append_optimize_op(self, block, param_and_grad): assert isinstance(block, framework.Block) diff --git a/python/paddle/optimizer/adamw.py b/python/paddle/optimizer/adamw.py index 8177f43ee58..0233ad1c972 100644 --- a/python/paddle/optimizer/adamw.py +++ b/python/paddle/optimizer/adamw.py @@ -283,6 +283,7 @@ class AdamW(Optimizer): self._use_multi_tensor = None self.regularization = None self._auxiliary_vars = {} + self._already_create_accumulater = set() def _set_auxiliary_var(self, key, val): self._auxiliary_vars[key] = val @@ -368,9 +369,12 @@ class AdamW(Optimizer): # Create accumulator tensors for first and second moments for p in parameters: + if p.name in self._already_create_accumulater: + continue if self._multi_precision and self._is_dtype_fp16_or_bf16(p.dtype): master_p = self._create_master_weight(p) self._add_moments_pows(master_p) + self._already_create_accumulater.add(p.name) continue if ( self._is_dtype_fp16_or_bf16(p.dtype) @@ -381,6 +385,7 @@ class AdamW(Optimizer): "Consider using multi_precision=True option of the Adam optimizer." ) self._add_moments_pows(p) + self._already_create_accumulater.add(p.name) def _append_optimize_op(self, block, param_and_grad): assert isinstance(block, framework.Block) @@ -437,7 +442,6 @@ class AdamW(Optimizer): else self._beta2.numpy().item(0) ) - found_inf = self._get_auxiliary_var('found_inf') _, _, _, _, _, _ = _C_ops.adamw_( param_and_grad[0], param_and_grad[1], @@ -447,7 +451,7 @@ class AdamW(Optimizer): beta1_pow_acc, beta2_pow_acc, master_weight, - found_inf, + None, _beta1, _beta2, self._epsilon, diff --git a/python/paddle/optimizer/lamb.py b/python/paddle/optimizer/lamb.py index e7aeede370d..ea01765f170 100644 --- a/python/paddle/optimizer/lamb.py +++ b/python/paddle/optimizer/lamb.py @@ -159,11 +159,15 @@ class Lamb(Optimizer): # Create accumulator tensors for first and second moments for p in parameters: + if p.name in self._already_create_accumulater: + continue if self._multi_precision and self._is_dtype_fp16_or_bf16(p.dtype): master_p = self._create_master_weight(p) self._add_moments_pows(master_p) + self._already_create_accumulater.add(p.name) else: self._add_moments_pows(p) + self._already_create_accumulater.add(p.name) def _add_moments_pows(self, p): acc_dtype = p.dtype @@ -233,7 +237,6 @@ class Lamb(Optimizer): self._used_master_weights[p_name] = master_weight.name else: master_weight = None - found_inf = self._get_auxiliary_var('found_inf') if framework.in_dygraph_mode(): _C_ops.lamb_( @@ -245,7 +248,7 @@ class Lamb(Optimizer): beta1_pow_acc, beta2_pow_acc, master_weight, - found_inf, + None, weight_decay, self._beta1, self._beta2, @@ -283,6 +286,7 @@ class Lamb(Optimizer): inputs["MasterParam"] = master_weight outputs["MasterParamOut"] = master_weight + found_inf = self._get_auxiliary_var('found_inf') if found_inf: inputs["SkipUpdate"] = found_inf diff --git a/python/paddle/optimizer/momentum.py b/python/paddle/optimizer/momentum.py index cf14efb8525..65bd7836aaf 100644 --- a/python/paddle/optimizer/momentum.py +++ b/python/paddle/optimizer/momentum.py @@ -211,9 +211,12 @@ class Momentum(Optimizer): parameters = self._update_param_group(parameters) for p in parameters: + if p.name in self._already_create_accumulater: + continue if self._multi_precision and self._is_dtype_fp16_or_bf16(p.dtype): master_p = self._create_master_weight(p) self._add_accumulator(self._velocity_acc_str, master_p) + self._already_create_accumulater.add(p.name) continue 
if ( self._is_dtype_fp16_or_bf16(p.dtype) @@ -224,6 +227,7 @@ class Momentum(Optimizer): "Consider using multi_precision=True option of the Momentum optimizer." ) self._add_accumulator(self._velocity_acc_str, p) + self._already_create_accumulater.add(p.name) def _create_regularization_of_grad(self, param, grad, regularization=None): """Create and add backward regularization Operators @@ -472,19 +476,30 @@ class Momentum(Optimizer): ) if in_dygraph_mode(): - _, _, _ = _C_ops.merged_momentum_( - self._param_dict[key][param_group_idx], - grad_dict[key], - self._velocity_dict[key][param_group_idx], - lr_dict[key], - master_weight, - self._momentum, - self._use_nesterov, - self._regularization_method_dict[key][param_group_idx], - self._regularization_coeff_dict[key][param_group_idx], - find_master, - self._rescale_grad, - ) + found_inf = self._get_auxiliary_var('found_inf') + if found_inf: + if isinstance(found_inf, core.eager.Tensor): + self._set_auxiliary_var('found_inf', True) + else: + if isinstance(found_inf, core.eager.Tensor): + self._set_auxiliary_var('found_inf', False) + _, _, _ = _C_ops.merged_momentum_( + self._param_dict[key][param_group_idx], + grad_dict[key], + self._velocity_dict[key][param_group_idx], + lr_dict[key], + master_weight, + self._momentum, + self._use_nesterov, + self._regularization_method_dict[key][ + param_group_idx + ], + self._regularization_coeff_dict[key][ + param_group_idx + ], + find_master, + self._rescale_grad, + ) else: inputs = { "Param": self._param_dict[key][param_group_idx], diff --git a/python/paddle/optimizer/optimizer.py b/python/paddle/optimizer/optimizer.py index 14ca85c9503..50b9858dea7 100644 --- a/python/paddle/optimizer/optimizer.py +++ b/python/paddle/optimizer/optimizer.py @@ -275,6 +275,7 @@ class Optimizer: self._param_dict = self._create_multi_tensor_dict() self._auxiliary_vars = {} + self._already_create_accumulater = set() def _set_auxiliary_var(self, key, val): self._auxiliary_vars[key] = val @@ -979,31 +980,38 @@ class Optimizer: self._create_accumulators(target_block, params_acc_dict) if framework._non_static_mode(): - if isinstance(parameters_and_grads, list): - for param_and_grad in parameters_and_grads: - if param_and_grad[1] is None: - continue - if param_and_grad[0].stop_gradient is False: - self._append_optimize_op( - target_block, param_and_grad - ) + found_inf = self._get_auxiliary_var('found_inf') + if found_inf: + if isinstance(found_inf, core.eager.Tensor): + self._set_auxiliary_var('found_inf', True) else: - for param_and_grad in parameters_and_grads['params']: - if param_and_grad[1] is None: - continue - if param_and_grad[0].stop_gradient is False: - param_grad_dict = dict() - param_grad_dict['params'] = param_and_grad - param_grad_dict.update( - { - k: v - for k, v in parameters_and_grads.items() - if k != 'params' - } - ) - self._append_optimize_op( - target_block, param_grad_dict - ) + if isinstance(found_inf, core.eager.Tensor): + self._set_auxiliary_var('found_inf', False) + if isinstance(parameters_and_grads, list): + for param_and_grad in parameters_and_grads: + if param_and_grad[1] is None: + continue + if param_and_grad[0].stop_gradient is False: + self._append_optimize_op( + target_block, param_and_grad + ) + else: + for param_and_grad in parameters_and_grads['params']: + if param_and_grad[1] is None: + continue + if param_and_grad[0].stop_gradient is False: + param_grad_dict = dict() + param_grad_dict['params'] = param_and_grad + param_grad_dict.update( + { + k: v + for k, v in parameters_and_grads.items() + 
if k != 'params' + } + ) + self._append_optimize_op( + target_block, param_grad_dict + ) else: for param_and_grad in parameters_and_grads: if param_and_grad[1] is None: diff --git a/python/paddle/optimizer/rmsprop.py b/python/paddle/optimizer/rmsprop.py index 266e771647d..7efcd51af42 100644 --- a/python/paddle/optimizer/rmsprop.py +++ b/python/paddle/optimizer/rmsprop.py @@ -203,11 +203,15 @@ class RMSProp(Optimizer): parameters = parameters.get('params') for p in parameters: + if p.name in self._already_create_accumulater: + continue + if self._multi_precision and self._is_dtype_fp16_or_bf16(p.dtype): master_p = self._create_master_weight(p) self._add_accumulator(self._momentum_acc_str, master_p) self._add_accumulator(self._mean_square_acc_str, master_p) self._add_accumulator(self._mean_grad_acc_str, master_p) + self._already_create_accumulater.add(p.name) continue if ( self._is_dtype_fp16_or_bf16(p.dtype) @@ -220,6 +224,7 @@ class RMSProp(Optimizer): self._add_accumulator(self._momentum_acc_str, p) self._add_accumulator(self._mean_square_acc_str, p) self._add_accumulator(self._mean_grad_acc_str, p) + self._already_create_accumulater.add(p.name) def _append_optimize_op(self, block, param_and_grad): if not isinstance(block, framework.Block): diff --git a/python/paddle/optimizer/sgd.py b/python/paddle/optimizer/sgd.py index ffb091ae5a5..8bea047c42d 100644 --- a/python/paddle/optimizer/sgd.py +++ b/python/paddle/optimizer/sgd.py @@ -99,8 +99,11 @@ class SGD(Optimizer): # Create accumulator tensors for first and second moments for p in parameters: + if p.name in self._already_create_accumulater: + continue if self._multi_precision and self._is_dtype_fp16_or_bf16(p.dtype): master_p = self._create_master_weight(p) + self._already_create_accumulater.add(p.name) continue if ( self._is_dtype_fp16_or_bf16(p.dtype) -- GitLab
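
Note on the pattern this patch introduces: instead of branching on self._found_inf in Python (which forces a device-to-host copy every step), the scaler now hands the found_inf tensor to the optimizer via _set_auxiliary_var('found_inf', ...), lets the optimizer itself decide whether to skip the parameter update, and reads the outcome back with _get_auxiliary_var('found_inf'); the per-dtype overflow flags produced by check_finite_and_unscale are folded on-device with _C_ops.bitwise_or, and the distributed scalers all-reduce that flag (cast to int32, ReduceOp.MAX) rather than a host-side integer. The sketch below is a minimal, framework-free illustration of that handshake; the ToyOptimizer/ToyScaler classes and the plain Python booleans standing in for device tensors are assumptions for demonstration only, not Paddle's actual implementation.

# Minimal sketch of the found_inf "auxiliary variable" handshake added by
# this patch. ToyOptimizer/ToyScaler are hypothetical; Paddle's real classes
# live in python/paddle/optimizer/optimizer.py and
# python/paddle/amp/grad_scaler.py.


class ToyOptimizer:
    def __init__(self):
        self._auxiliary_vars = {}
        self.stepped = False

    def _set_auxiliary_var(self, key, val):
        self._auxiliary_vars[key] = val

    def _get_auxiliary_var(self, key):
        return self._auxiliary_vars.get(key, None)

    def step(self):
        # The optimizer, not the scaler, decides whether to skip the update,
        # so the scaler never has to pull found_inf back to the host.
        found_inf = self._get_auxiliary_var('found_inf')
        if found_inf:
            self._set_auxiliary_var('found_inf', True)
            return
        self._set_auxiliary_var('found_inf', False)
        self.stepped = True  # parameter updates would be applied here


class ToyScaler:
    def __init__(self):
        self._found_inf = False       # stand-in for a bool tensor on device
        self._cache_found_inf = None  # Paddle's attribute is spelled _cache_founf_inf

    def unscale(self, overflow_flags):
        # Mirrors the bitwise_or accumulation across the fp16/bf16/fp32
        # gradient groups: reset, then OR in every per-group flag.
        self._found_inf = False
        for flag in overflow_flags:
            self._found_inf = self._found_inf | flag

    def step(self, optimizer):
        if hasattr(optimizer, "_set_auxiliary_var"):
            # New path: pass the flag through, step unconditionally, read back.
            optimizer._set_auxiliary_var('found_inf', self._found_inf)
            optimizer.step()
            self._cache_found_inf = optimizer._get_auxiliary_var('found_inf')
        else:
            # Legacy path kept by the patch for optimizers without the hook.
            if self._found_inf:
                self._cache_found_inf = True
            else:
                optimizer.step()
                self._cache_found_inf = False


scaler, opt = ToyScaler(), ToyOptimizer()
scaler.unscale([False, False, True])   # one dtype group overflowed
scaler.step(opt)
assert opt.stepped is False and scaler._cache_found_inf is True

In the real patch the per-group flags are tensors, so the accumulation with bitwise_or and the final all_reduce both stay on the accelerator; the wrapped meta/sharded optimizers simply forward _set_auxiliary_var to their inner optimizer so the flag reaches whichever optimizer actually runs the update.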