Unverified commit 64573f9f authored by W wanghuancoder, committed by GitHub

fix found_inf bug for custom optimizer (#50158)

Parent 80310541
......@@ -228,9 +228,16 @@ class AmpScaler:
             optimize_ops, params_grads = (None, None)
-        optimizer._set_auxiliary_var('found_inf', self._found_inf)
-        optimize_ops, params_grads = optimizer.minimize(*args, **kwargs)
-        self._cache_founf_inf = optimizer._get_auxiliary_var('found_inf')
+        if hasattr(optimizer, "_set_auxiliary_var"):
+            optimizer._set_auxiliary_var('found_inf', self._found_inf)
+            optimize_ops, params_grads = optimizer.minimize(*args, **kwargs)
+            self._cache_founf_inf = optimizer._get_auxiliary_var('found_inf')
+        else:
+            if self._found_inf:
+                self._cache_founf_inf = True
+            else:
+                optimize_ops, params_grads = optimizer.minimize(*args, **kwargs)
+                self._cache_founf_inf = False
         if self._use_dynamic_loss_scaling:
             # update the scale
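The same guard is applied in all three scalers in this commit. Below is a minimal, self-contained sketch of the dispatch (not Paddle source; the names BuiltinLikeOptimizer, CustomOptimizer, and scaler_step are illustrative): optimizers that implement the private auxiliary-variable hooks keep receiving found_inf through them, while optimizers without the hooks fall back to the scaler skipping the step itself.

# Illustrative reduction of the guarded logic introduced by this commit.
class BuiltinLikeOptimizer:
    """Mimics a Paddle built-in optimizer that implements the private hooks."""
    def __init__(self):
        self._aux = {}
        self.stepped = False

    def _set_auxiliary_var(self, key, val):
        self._aux[key] = val

    def _get_auxiliary_var(self, key):
        return self._aux.get(key)

    def step(self):
        # A real built-in optimizer skips the parameter update internally
        # when the 'found_inf' auxiliary variable is set.
        self.stepped = not self._aux.get('found_inf', False)

class CustomOptimizer:
    """A user-defined optimizer exposing only a public step() method."""
    def __init__(self):
        self.stepped = False

    def step(self):
        self.stepped = True

def scaler_step(optimizer, found_inf):
    """The dispatch added by this commit, reduced to its essentials.

    Returns the value the scaler would cache as _cache_founf_inf.
    """
    if hasattr(optimizer, "_set_auxiliary_var"):
        optimizer._set_auxiliary_var('found_inf', found_inf)
        optimizer.step()
        return optimizer._get_auxiliary_var('found_inf')
    # Custom optimizers have no hook, so the scaler skips the step itself.
    if found_inf:
        return True
    optimizer.step()
    return False

assert scaler_step(BuiltinLikeOptimizer(), found_inf=False) is False
assert scaler_step(CustomOptimizer(), found_inf=True) is True  # step skipped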
......@@ -771,9 +778,16 @@ class GradScaler(AmpScaler):
         if optimizer_state["state"] is OptimizerState.INIT:
             self._unscale(optimizer)
-        optimizer._set_auxiliary_var('found_inf', self._found_inf)
-        optimizer.step()
-        self._cache_founf_inf = optimizer._get_auxiliary_var('found_inf')
+        if hasattr(optimizer, "_set_auxiliary_var"):
+            optimizer._set_auxiliary_var('found_inf', self._found_inf)
+            optimizer.step()
+            self._cache_founf_inf = optimizer._get_auxiliary_var('found_inf')
+        else:
+            if self._found_inf:
+                self._cache_founf_inf = True
+            else:
+                optimizer.step()
+                self._cache_founf_inf = False
         optimizer_state["state"] = OptimizerState.STEPPED
......
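For the GradScaler path above, a hedged end-to-end sketch of why the change matters. CustomSGD below is a hypothetical user-defined optimizer that offers only step() and clear_grad(); it additionally exposes a _parameter_list attribute, which the scaler's unscale pass is assumed to read (as it does for Paddle's built-in optimizers). Before this commit, scaler.step(opt) would fail on the missing _set_auxiliary_var hook; with it, the step is simply skipped whenever inf/nan gradients were detected.

import paddle

class CustomSGD:
    """Hypothetical plain-SGD optimizer without Paddle's private hooks."""
    def __init__(self, parameters, lr=0.1):
        # Assumption: GradScaler's unscale pass reads _parameter_list
        # to locate the gradients it needs to unscale and check.
        self._parameter_list = list(parameters)
        self._lr = lr

    @paddle.no_grad()
    def step(self):
        for p in self._parameter_list:
            if p.grad is not None:
                p.set_value(p - self._lr * p.grad)

    def clear_grad(self):
        for p in self._parameter_list:
            p.clear_gradient()

model = paddle.nn.Linear(4, 4)
opt = CustomSGD(model.parameters())
scaler = paddle.amp.GradScaler(init_loss_scaling=1024.0)

x = paddle.randn([2, 4])
with paddle.amp.auto_cast():
    loss = model(x).mean()
scaler.scale(loss).backward()
scaler.step(opt)    # raised AttributeError before this fix
scaler.update()
opt.clear_grad()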
......@@ -41,9 +41,16 @@ class HybridParallelGradScaler:
             optimize_ops, params_grads = (None, None)
-        optimizer._set_auxiliary_var('found_inf', self._found_inf)
-        optimize_ops, params_grads = optimizer.minimize(*args, **kwargs)
-        self._cache_founf_inf = optimizer._get_auxiliary_var('found_inf')
+        if hasattr(optimizer, "_set_auxiliary_var"):
+            optimizer._set_auxiliary_var('found_inf', self._found_inf)
+            optimize_ops, params_grads = optimizer.minimize(*args, **kwargs)
+            self._cache_founf_inf = optimizer._get_auxiliary_var('found_inf')
+        else:
+            if self._found_inf:
+                self._cache_founf_inf = True
+            else:
+                optimize_ops, params_grads = optimizer.minimize(*args, **kwargs)
+                self._cache_founf_inf = False
         if self._use_dynamic_loss_scaling:
             self._update()
......