diff --git a/python/paddle/fluid/dygraph/amp/loss_scaler.py b/python/paddle/fluid/dygraph/amp/loss_scaler.py
index 432b178ea6706608f5eb029662ef15d6d7517b63..f7c2d6be574c4e17fa5ce6fa44ca4ecc55a5eb95 100644
--- a/python/paddle/fluid/dygraph/amp/loss_scaler.py
+++ b/python/paddle/fluid/dygraph/amp/loss_scaler.py
@@ -127,6 +127,10 @@ class AmpScaler(object):
         self._use_dynamic_loss_scaling = use_dynamic_loss_scaling

         self._found_inf = to_variable(np.array([0]).astype(np.bool))
+        self._temp_found_inf_fp16 = to_variable(
+            np.array([0]).astype(np.bool))
+        self._temp_found_inf_fp32 = to_variable(
+            np.array([0]).astype(np.bool))
         self._scale = to_variable(
             np.array([self._init_loss_scaling]).astype(np.float32))
         self._cache_founf_inf = None
@@ -282,17 +286,20 @@ class AmpScaler(object):
                 ) and (param._grad_ivar().dtype == core.VarDesc.VarType.FP32
                        )
         ]
-        temp_found_inf_fp16 = to_variable(np.array([0]).astype(np.bool))
-        temp_found_inf_fp32 = to_variable(np.array([0]).astype(np.bool))
         if len(param_grads_fp16):
             _C_ops.check_finite_and_unscale(param_grads_fp16, self._scale,
                                             param_grads_fp16,
-                                            temp_found_inf_fp16)
+                                            self._temp_found_inf_fp16)
         if len(param_grads_fp32):
             _C_ops.check_finite_and_unscale(param_grads_fp32, self._scale,
                                             param_grads_fp32,
-                                            temp_found_inf_fp32)
-        self._found_inf = temp_found_inf_fp16 or temp_found_inf_fp32
+                                            self._temp_found_inf_fp32)
+        if len(param_grads_fp16) and len(param_grads_fp32):
+            self._found_inf = self._temp_found_inf_fp16 or self._temp_found_inf_fp32
+        elif len(param_grads_fp16):
+            self._found_inf = self._temp_found_inf_fp16
+        else:
+            self._found_inf = self._temp_found_inf_fp32

         optimizer_state["state"] = OptimizerState.UNSCALED
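
The sketch below shows the dygraph AMP training loop that exercises this code path, following the usage pattern in the class docstring of this file; the Conv2D model, SGD optimizer, and data shapes are illustrative assumptions, not part of the patch. Each `scaler.minimize()` call runs `_unscale()`, which after this change reuses the pre-allocated `self._temp_found_inf_fp16` / `self._temp_found_inf_fp32` variables every iteration instead of allocating fresh temporaries.

```python
# Minimal usage sketch (assumed setup, not from this patch).
import numpy as np
import paddle.fluid as fluid

data = np.random.uniform(-1, 1, [10, 3, 32, 32]).astype('float32')
with fluid.dygraph.guard():
    model = fluid.dygraph.Conv2D(3, 2, 3)
    optimizer = fluid.optimizer.SGDOptimizer(
        learning_rate=0.01, parameter_list=model.parameters())
    scaler = fluid.dygraph.AmpScaler(init_loss_scaling=1024)
    data = fluid.dygraph.to_variable(data)
    with fluid.dygraph.amp_guard():
        conv = model(data)
        loss = fluid.layers.reduce_mean(conv)
        scaled = scaler.scale(loss)         # multiply the loss by the loss scale
        scaled.backward()                   # produce scaled gradients
        scaler.minimize(optimizer, scaled)  # unscale grads, check inf/nan, then step
```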