未验证 提交 5e1a20bf 编写于 作者: Z Zhang Zheng 提交者: GitHub

Optimize performance of amp (#45188)

上级 b4f67757
......@@ -278,14 +278,12 @@ class AmpScaler(object):
if param._grad_ivar() is not None
]
param_grads_fp16 = [
param._grad_ivar() for param in optimizer._parameter_list
if (param._grad_ivar() is not None) and (
param._grad_ivar().dtype == core.VarDesc.VarType.FP16)
param for param in param_grads
if param.dtype == core.VarDesc.VarType.FP16
]
param_grads_fp32 = [
param._grad_ivar() for param in optimizer._parameter_list
if (param._grad_ivar() is not None) and (
param._grad_ivar().dtype == core.VarDesc.VarType.FP32)
param for param in param_grads
if param.dtype == core.VarDesc.VarType.FP32
]
if core.is_compiled_with_npu():
float_status = _C_ops.alloc_float_status()
......@@ -309,12 +307,7 @@ class AmpScaler(object):
param_grads_fp32,
self._temp_found_inf_fp32)
if len(param_grads_fp16) and len(param_grads_fp32):
self._found_inf = self._temp_found_inf_fp16 or self._temp_found_inf_fp32
elif len(param_grads_fp16):
self._found_inf = self._temp_found_inf_fp16
else:
self._found_inf = self._temp_found_inf_fp32
optimizer_state["state"] = OptimizerState.UNSCALED
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册