From cc2b466295b95e24d6e1eb007bba733d2d512a7b Mon Sep 17 00:00:00 2001
From: zhangbo9674 <82555433+zhangbo9674@users.noreply.github.com>
Date: Thu, 2 Dec 2021 19:14:24 +0800
Subject: [PATCH] refine found_inf of loss_scaler (#37770)

---
 python/paddle/fluid/dygraph/amp/loss_scaler.py | 17 ++++++++++++-----
 1 file changed, 12 insertions(+), 5 deletions(-)

diff --git a/python/paddle/fluid/dygraph/amp/loss_scaler.py b/python/paddle/fluid/dygraph/amp/loss_scaler.py
index 432b178ea67..f7c2d6be574 100644
--- a/python/paddle/fluid/dygraph/amp/loss_scaler.py
+++ b/python/paddle/fluid/dygraph/amp/loss_scaler.py
@@ -127,6 +127,10 @@ class AmpScaler(object):
         self._use_dynamic_loss_scaling = use_dynamic_loss_scaling
 
         self._found_inf = to_variable(np.array([0]).astype(np.bool))
+        self._temp_found_inf_fp16 = to_variable(
+            np.array([0]).astype(np.bool))
+        self._temp_found_inf_fp32 = to_variable(
+            np.array([0]).astype(np.bool))
         self._scale = to_variable(
             np.array([self._init_loss_scaling]).astype(np.float32))
         self._cache_founf_inf = None
@@ -282,17 +286,20 @@ class AmpScaler(object):
                 ) and (param._grad_ivar().dtype == core.VarDesc.VarType.FP32
                        )
         ]
-        temp_found_inf_fp16 = to_variable(np.array([0]).astype(np.bool))
-        temp_found_inf_fp32 = to_variable(np.array([0]).astype(np.bool))
         if len(param_grads_fp16):
             _C_ops.check_finite_and_unscale(param_grads_fp16, self._scale,
                                             param_grads_fp16,
-                                            temp_found_inf_fp16)
+                                            self._temp_found_inf_fp16)
         if len(param_grads_fp32):
             _C_ops.check_finite_and_unscale(param_grads_fp32, self._scale,
                                             param_grads_fp32,
-                                            temp_found_inf_fp32)
-        self._found_inf = temp_found_inf_fp16 or temp_found_inf_fp32
+                                            self._temp_found_inf_fp32)
+        if len(param_grads_fp16) and len(param_grads_fp32):
+            self._found_inf = self._temp_found_inf_fp16 or self._temp_found_inf_fp32
+        elif len(param_grads_fp16):
+            self._found_inf = self._temp_found_inf_fp16
+        else:
+            self._found_inf = self._temp_found_inf_fp32
 
         optimizer_state["state"] = OptimizerState.UNSCALED
 
-- 
GitLab
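
The diff above makes two changes to AmpScaler. First, the temporary
found_inf tensors for the FP16 and FP32 gradient groups move from
_unscale() into __init__, so they are allocated once per scaler rather
than once per training step. Second, self._found_inf is now chosen by
explicit branches on which gradient groups exist, instead of always
evaluating "temp_found_inf_fp16 or temp_found_inf_fp32". Below is a
minimal sketch of the new selection logic in plain Python; the function
combine_found_inf and the has_* flags are hypothetical stand-ins for
illustration, not Paddle APIs.

    def combine_found_inf(found_inf_fp16, found_inf_fp32,
                          has_fp16_grads, has_fp32_grads):
        """Pick the found_inf flag(s) that check_finite_and_unscale
        actually wrote, given which gradient groups were checked."""
        if has_fp16_grads and has_fp32_grads:
            # "x or y" first calls bool(x); on a dygraph Variable that
            # reads the value back from the device, so only this branch
            # pays for the tensor-truthiness check.
            return found_inf_fp16 or found_inf_fp32
        elif has_fp16_grads:
            return found_inf_fp16
        else:
            return found_inf_fp32

A plausible reading of the design choice: the branch conditions use the
host-side list lengths (len(param_grads_fp16), len(param_grads_fp32)),
so when a model only has gradients in one precision the scaler no longer
forces a device read just to combine two flags, one of which is known to
be untouched.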