From 5e1a20bfcdd6e4b08ad5aa90b7e82dabe8f7793b Mon Sep 17 00:00:00 2001 From: Zhang Zheng <32410583+ZzSean@users.noreply.github.com> Date: Wed, 17 Aug 2022 09:54:29 +0800 Subject: [PATCH] Optimize performance of amp (#45188) --- python/paddle/fluid/dygraph/amp/loss_scaler.py | 17 +++++------------ 1 file changed, 5 insertions(+), 12 deletions(-) diff --git a/python/paddle/fluid/dygraph/amp/loss_scaler.py b/python/paddle/fluid/dygraph/amp/loss_scaler.py index e1ae4ad9bc5..2ce45086474 100644 --- a/python/paddle/fluid/dygraph/amp/loss_scaler.py +++ b/python/paddle/fluid/dygraph/amp/loss_scaler.py @@ -278,14 +278,12 @@ class AmpScaler(object): if param._grad_ivar() is not None ] param_grads_fp16 = [ - param._grad_ivar() for param in optimizer._parameter_list - if (param._grad_ivar() is not None) and ( - param._grad_ivar().dtype == core.VarDesc.VarType.FP16) + param for param in param_grads + if param.dtype == core.VarDesc.VarType.FP16 ] param_grads_fp32 = [ - param._grad_ivar() for param in optimizer._parameter_list - if (param._grad_ivar() is not None) and ( - param._grad_ivar().dtype == core.VarDesc.VarType.FP32) + param for param in param_grads + if param.dtype == core.VarDesc.VarType.FP32 ] if core.is_compiled_with_npu(): float_status = _C_ops.alloc_float_status() @@ -309,12 +307,7 @@ class AmpScaler(object): param_grads_fp32, self._temp_found_inf_fp32) - if len(param_grads_fp16) and len(param_grads_fp32): - self._found_inf = self._temp_found_inf_fp16 or self._temp_found_inf_fp32 - elif len(param_grads_fp16): - self._found_inf = self._temp_found_inf_fp16 - else: - self._found_inf = self._temp_found_inf_fp32 + self._found_inf = self._temp_found_inf_fp16 or self._temp_found_inf_fp32 optimizer_state["state"] = OptimizerState.UNSCALED -- GitLab