未验证 提交 cc2b4662 编写于 作者: Z zhangbo9674 提交者: GitHub

refine found_inf of loss_scaler (#37770)

上级 08a2d0ba
@@ -127,6 +127,10 @@ class AmpScaler(object):
         self._use_dynamic_loss_scaling = use_dynamic_loss_scaling
         self._found_inf = to_variable(np.array([0]).astype(np.bool))
+        self._temp_found_inf_fp16 = to_variable(
+            np.array([0]).astype(np.bool))
+        self._temp_found_inf_fp32 = to_variable(
+            np.array([0]).astype(np.bool))
         self._scale = to_variable(
             np.array([self._init_loss_scaling]).astype(np.float32))
         self._cache_founf_inf = None
@@ -282,17 +286,20 @@ class AmpScaler(object):
                 ) and (param._grad_ivar().dtype == core.VarDesc.VarType.FP32
                        )
             ]
-        temp_found_inf_fp16 = to_variable(np.array([0]).astype(np.bool))
-        temp_found_inf_fp32 = to_variable(np.array([0]).astype(np.bool))
         if len(param_grads_fp16):
             _C_ops.check_finite_and_unscale(param_grads_fp16, self._scale,
                                             param_grads_fp16,
-                                            temp_found_inf_fp16)
+                                            self._temp_found_inf_fp16)
         if len(param_grads_fp32):
             _C_ops.check_finite_and_unscale(param_grads_fp32, self._scale,
                                             param_grads_fp32,
-                                            temp_found_inf_fp32)
-        self._found_inf = temp_found_inf_fp16 or temp_found_inf_fp32
+                                            self._temp_found_inf_fp32)
+        if len(param_grads_fp16) and len(param_grads_fp32):
+            self._found_inf = self._temp_found_inf_fp16 or self._temp_found_inf_fp32
+        elif len(param_grads_fp16):
+            self._found_inf = self._temp_found_inf_fp16
+        else:
+            self._found_inf = self._temp_found_inf_fp32
         optimizer_state["state"] = OptimizerState.UNSCALED
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册