diff --git a/python/paddle/amp/grad_scaler.py b/python/paddle/amp/grad_scaler.py
index b6a38e0e28589b072814cf0c7f494cb684d26e6f..db1ad4d5774aad83c8513346873898fd1f3aa883 100644
--- a/python/paddle/amp/grad_scaler.py
+++ b/python/paddle/amp/grad_scaler.py
@@ -228,9 +228,16 @@ class AmpScaler:
         optimize_ops, params_grads = (None, None)
 
-        optimizer._set_auxiliary_var('found_inf', self._found_inf)
-        optimize_ops, params_grads = optimizer.minimize(*args, **kwargs)
-        self._cache_founf_inf = optimizer._get_auxiliary_var('found_inf')
+        if hasattr(optimizer, "_set_auxiliary_var"):
+            optimizer._set_auxiliary_var('found_inf', self._found_inf)
+            optimize_ops, params_grads = optimizer.minimize(*args, **kwargs)
+            self._cache_founf_inf = optimizer._get_auxiliary_var('found_inf')
+        else:
+            if self._found_inf:
+                self._cache_founf_inf = True
+            else:
+                optimize_ops, params_grads = optimizer.minimize(*args, **kwargs)
+                self._cache_founf_inf = False
 
         if self._use_dynamic_loss_scaling:
             # uopdate the scale
@@ -771,9 +778,16 @@ class GradScaler(AmpScaler):
         if optimizer_state["state"] is OptimizerState.INIT:
             self._unscale(optimizer)
 
-        optimizer._set_auxiliary_var('found_inf', self._found_inf)
-        optimizer.step()
-        self._cache_founf_inf = optimizer._get_auxiliary_var('found_inf')
+        if hasattr(optimizer, "_set_auxiliary_var"):
+            optimizer._set_auxiliary_var('found_inf', self._found_inf)
+            optimizer.step()
+            self._cache_founf_inf = optimizer._get_auxiliary_var('found_inf')
+        else:
+            if self._found_inf:
+                self._cache_founf_inf = True
+            else:
+                optimizer.step()
+                self._cache_founf_inf = False
 
         optimizer_state["state"] = OptimizerState.STEPPED
 
diff --git a/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/hybrid_parallel_gradscaler.py b/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/hybrid_parallel_gradscaler.py
index c12843f106562c7167978cbccdda8101d198a61f..4924d523ded05a45e8e9e25e980b7606ff45a048 100644
--- a/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/hybrid_parallel_gradscaler.py
+++ b/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/hybrid_parallel_gradscaler.py
@@ -41,9 +41,16 @@ class HybridParallelGradScaler:
         optimize_ops, params_grads = (None, None)
 
-        optimizer._set_auxiliary_var('found_inf', self._found_inf)
-        optimize_ops, params_grads = optimizer.minimize(*args, **kwargs)
-        self._cache_founf_inf = optimizer._get_auxiliary_var('found_inf')
+        if hasattr(optimizer, "_set_auxiliary_var"):
+            optimizer._set_auxiliary_var('found_inf', self._found_inf)
+            optimize_ops, params_grads = optimizer.minimize(*args, **kwargs)
+            self._cache_founf_inf = optimizer._get_auxiliary_var('found_inf')
+        else:
+            if self._found_inf:
+                self._cache_founf_inf = True
+            else:
+                optimize_ops, params_grads = optimizer.minimize(*args, **kwargs)
+                self._cache_founf_inf = False
 
         if self._use_dynamic_loss_scaling:
             self._update()
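
All three hunks apply the same guard: the found_inf auxiliary-variable handshake (_set_auxiliary_var / _get_auxiliary_var) is only used when the optimizer actually exposes it; otherwise the scaler falls back to its own self._found_inf flag and skips minimize()/step() entirely when an inf/nan was detected. The snippet below is a minimal, self-contained sketch of that fallback control flow, not Paddle code; CustomOptimizer and ToyScaler are hypothetical names used only for illustration, and the misspelled _cache_founf_inf attribute is kept to match the diff.

class CustomOptimizer:
    """Stand-in for an optimizer that lacks _set_auxiliary_var/_get_auxiliary_var."""

    def step(self):
        print("parameters updated")


class ToyScaler:
    """Hypothetical scaler reproducing only the branch added in this diff."""

    def __init__(self, found_inf):
        self._found_inf = found_inf
        self._cache_founf_inf = None  # attribute name (with typo) matches the diff

    def step(self, optimizer):
        if hasattr(optimizer, "_set_auxiliary_var"):
            # Paddle-style optimizers: hand found_inf to the optimizer and read it back.
            optimizer._set_auxiliary_var('found_inf', self._found_inf)
            optimizer.step()
            self._cache_founf_inf = optimizer._get_auxiliary_var('found_inf')
        else:
            # Fallback: the scaler itself decides whether the update may run.
            if self._found_inf:
                self._cache_founf_inf = True  # inf/nan found -> skip the step
            else:
                optimizer.step()
                self._cache_founf_inf = False


ToyScaler(found_inf=False).step(CustomOptimizer())  # prints "parameters updated"
ToyScaler(found_inf=True).step(CustomOptimizer())   # step is skipped, nothing printed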