From 64573f9fc7cc5badd1e8b461bd02aa6a43c14be2 Mon Sep 17 00:00:00 2001
From: wanghuancoder
Date: Fri, 3 Feb 2023 10:51:01 +0800
Subject: [PATCH] fix found_inf bug for custom optimizer (#50158)

---
 python/paddle/amp/grad_scaler.py  | 26 ++++++++++++++-----
 .../hybrid_parallel_gradscaler.py | 13 +++++++---
 2 files changed, 30 insertions(+), 9 deletions(-)

diff --git a/python/paddle/amp/grad_scaler.py b/python/paddle/amp/grad_scaler.py
index b6a38e0e28..db1ad4d577 100644
--- a/python/paddle/amp/grad_scaler.py
+++ b/python/paddle/amp/grad_scaler.py
@@ -228,9 +228,16 @@ class AmpScaler:
 
         optimize_ops, params_grads = (None, None)
 
-        optimizer._set_auxiliary_var('found_inf', self._found_inf)
-        optimize_ops, params_grads = optimizer.minimize(*args, **kwargs)
-        self._cache_founf_inf = optimizer._get_auxiliary_var('found_inf')
+        if hasattr(optimizer, "_set_auxiliary_var"):
+            optimizer._set_auxiliary_var('found_inf', self._found_inf)
+            optimize_ops, params_grads = optimizer.minimize(*args, **kwargs)
+            self._cache_founf_inf = optimizer._get_auxiliary_var('found_inf')
+        else:
+            if self._found_inf:
+                self._cache_founf_inf = True
+            else:
+                optimize_ops, params_grads = optimizer.minimize(*args, **kwargs)
+                self._cache_founf_inf = False
 
         if self._use_dynamic_loss_scaling:
             # uopdate the scale
@@ -771,9 +778,16 @@ class GradScaler(AmpScaler):
         if optimizer_state["state"] is OptimizerState.INIT:
             self._unscale(optimizer)
 
-        optimizer._set_auxiliary_var('found_inf', self._found_inf)
-        optimizer.step()
-        self._cache_founf_inf = optimizer._get_auxiliary_var('found_inf')
+        if hasattr(optimizer, "_set_auxiliary_var"):
+            optimizer._set_auxiliary_var('found_inf', self._found_inf)
+            optimizer.step()
+            self._cache_founf_inf = optimizer._get_auxiliary_var('found_inf')
+        else:
+            if self._found_inf:
+                self._cache_founf_inf = True
+            else:
+                optimizer.step()
+                self._cache_founf_inf = False
 
         optimizer_state["state"] = OptimizerState.STEPPED
 
diff --git a/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/hybrid_parallel_gradscaler.py b/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/hybrid_parallel_gradscaler.py
index c12843f106..4924d523de 100644
--- a/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/hybrid_parallel_gradscaler.py
+++ b/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/hybrid_parallel_gradscaler.py
@@ -41,9 +41,16 @@ class HybridParallelGradScaler:
 
         optimize_ops, params_grads = (None, None)
 
-        optimizer._set_auxiliary_var('found_inf', self._found_inf)
-        optimize_ops, params_grads = optimizer.minimize(*args, **kwargs)
-        self._cache_founf_inf = optimizer._get_auxiliary_var('found_inf')
+        if hasattr(optimizer, "_set_auxiliary_var"):
+            optimizer._set_auxiliary_var('found_inf', self._found_inf)
+            optimize_ops, params_grads = optimizer.minimize(*args, **kwargs)
+            self._cache_founf_inf = optimizer._get_auxiliary_var('found_inf')
+        else:
+            if self._found_inf:
+                self._cache_founf_inf = True
+            else:
+                optimize_ops, params_grads = optimizer.minimize(*args, **kwargs)
+                self._cache_founf_inf = False
 
         if self._use_dynamic_loss_scaling:
             self._update()
-- 
GitLab
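
Below is a minimal sketch (not part of the patch) of the scenario the new hasattr() fallback handles. It assumes a hand-rolled optimizer, here called NaiveSGD (a hypothetical name), that does not inherit from paddle.optimizer.Optimizer and therefore has no _set_auxiliary_var/_get_auxiliary_var; it is also assumed to expose a _parameter_list attribute, which the scaler's unscale pass reads to find the gradients. With this patch, scaler.step() takes the else branch: it skips the parameter update when an inf/nan gradient was found and otherwise calls the optimizer's own step().

    # Sketch only; NaiveSGD and the shapes/hyperparameters below are illustrative.
    import paddle


    class NaiveSGD:
        """Hypothetical custom optimizer: only step() and clear_grad()."""

        def __init__(self, parameters, lr=0.01):
            self._parameter_list = list(parameters)  # read by GradScaler during unscale
            self._lr = lr

        def step(self):
            # Plain SGD update, performed outside autograd tracking.
            with paddle.no_grad():
                for p in self._parameter_list:
                    if p.grad is not None:
                        p.set_value(p - self._lr * p.grad)

        def clear_grad(self):
            for p in self._parameter_list:
                p.clear_gradient()


    model = paddle.nn.Linear(4, 4)
    opt = NaiveSGD(model.parameters())
    scaler = paddle.amp.GradScaler(init_loss_scaling=1024.0)

    x = paddle.rand([2, 4])
    with paddle.amp.auto_cast():
        loss = model(x).mean()

    scaler.scale(loss).backward()
    # Before this fix, scaler.step(opt) raised AttributeError because NaiveSGD
    # lacks _set_auxiliary_var; after it, the fallback branch either skips the
    # update (inf/nan found) or calls opt.step() directly.
    scaler.step(opt)
    scaler.update()
    opt.clear_grad()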