Unverified · Commit 5e1a20bf authored by Zhang Zheng, committed by GitHub

Optimize performance of amp (#45188)

Parent b4f67757
@@ -278,14 +278,12 @@ class AmpScaler(object):
                 if param._grad_ivar() is not None
             ]
             param_grads_fp16 = [
-                param._grad_ivar() for param in optimizer._parameter_list
-                if (param._grad_ivar() is not None) and (
-                    param._grad_ivar().dtype == core.VarDesc.VarType.FP16)
+                param for param in param_grads
+                if param.dtype == core.VarDesc.VarType.FP16
             ]
             param_grads_fp32 = [
-                param._grad_ivar() for param in optimizer._parameter_list
-                if (param._grad_ivar() is not None) and (
-                    param._grad_ivar().dtype == core.VarDesc.VarType.FP32)
+                param for param in param_grads
+                if param.dtype == core.VarDesc.VarType.FP32
             ]
         if core.is_compiled_with_npu():
             float_status = _C_ops.alloc_float_status()
@@ -309,12 +307,7 @@ class AmpScaler(object):
                 param_grads_fp32,
                 self._temp_found_inf_fp32)
-            if len(param_grads_fp16) and len(param_grads_fp32):
-                self._found_inf = self._temp_found_inf_fp16 or self._temp_found_inf_fp32
-            elif len(param_grads_fp16):
-                self._found_inf = self._temp_found_inf_fp16
-            else:
-                self._found_inf = self._temp_found_inf_fp32
+            self._found_inf = self._temp_found_inf_fp16 or self._temp_found_inf_fp32
             optimizer_state["state"] = OptimizerState.UNSCALED
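
Why this is faster: the old code walked optimizer._parameter_list three times and called param._grad_ivar() several times per parameter (once per comprehension, again inside each dtype check), while the new code materializes the gradients once into param_grads and splits that cached list with cheap dtype comparisons. The second hunk collapses the three-way branch into a plain `or`; this is sound assuming each _temp_found_inf_* flag keeps a default false value when its precision bucket is empty, which the surrounding code suggests. Below is a minimal, framework-free sketch of the pattern; Grad, Param, and the "fp16"/"fp32" strings are illustrative stand-ins, not Paddle APIs.

from dataclasses import dataclass

@dataclass
class Grad:
    dtype: str   # stand-in for core.VarDesc.VarType.FP16 / FP32
    value: float

class Param:
    """Illustrative stand-in for a Paddle parameter, not a Paddle API."""

    def __init__(self, grad):
        self._grad = grad

    def _grad_ivar(self):
        # In Paddle this accessor is comparatively expensive; the commit
        # avoids invoking it once more for every dtype bucket.
        return self._grad

params = [Param(Grad("fp16", 1.0)), Param(Grad("fp32", 2.0)), Param(None)]

# One pass to fetch each gradient, mirroring the post-commit code:
param_grads = [p._grad_ivar() for p in params if p._grad_ivar() is not None]
# The dtype buckets then filter the cached list instead of re-fetching:
param_grads_fp16 = [g for g in param_grads if g.dtype == "fp16"]
param_grads_fp32 = [g for g in param_grads if g.dtype == "fp32"]

assert [g.value for g in param_grads_fp16] == [1.0]
assert [g.value for g in param_grads_fp32] == [2.0]

The same trade-off applies generally: when a per-element accessor is non-trivial, build the filtered list once and derive the sub-lists from it, rather than repeating the accessor in every comprehension.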