Commit 8aeee2bf authored by Weilong Wu, committed by GitHub

Revert "fix found_inf bug for custom optimizer (#50158)"

This reverts commit 64573f9f.
Parent a1e96e47
@@ -228,16 +228,9 @@ class AmpScaler:
         optimize_ops, params_grads = (None, None)
-        if hasattr(optimizer, "_set_auxiliary_var"):
-            optimizer._set_auxiliary_var('found_inf', self._found_inf)
-            optimize_ops, params_grads = optimizer.minimize(*args, **kwargs)
-            self._cache_founf_inf = optimizer._get_auxiliary_var('found_inf')
-        else:
-            if self._found_inf:
-                self._cache_founf_inf = True
-            else:
-                optimize_ops, params_grads = optimizer.minimize(*args, **kwargs)
-                self._cache_founf_inf = False
+        optimizer._set_auxiliary_var('found_inf', self._found_inf)
+        optimize_ops, params_grads = optimizer.minimize(*args, **kwargs)
+        self._cache_founf_inf = optimizer._get_auxiliary_var('found_inf')
         if self._use_dynamic_loss_scaling:
             # update the scale
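The restored lines in each hunk assume the optimizer exposes the auxiliary-variable hooks: the scaler writes `found_inf` before calling `minimize()`/`step()` and reads the flag back into `self._cache_founf_inf` afterwards, leaving it to the optimizer to skip the parameter update when an overflow was detected. A minimal, self-contained sketch of that handshake (with a hypothetical `ToyOptimizer`, not Paddle's actual class) looks roughly like this:

```python
# Sketch of the found_inf handshake the restored code relies on.
# ToyOptimizer is an illustrative stand-in, not Paddle's Optimizer.
class ToyOptimizer:
    def __init__(self):
        self._auxiliary_vars = {}

    def _set_auxiliary_var(self, key, val):
        self._auxiliary_vars[key] = val

    def _get_auxiliary_var(self, key):
        return self._auxiliary_vars.get(key, None)

    def minimize(self, *args, **kwargs):
        # A cooperating optimizer checks the flag and skips the parameter
        # update when inf/nan gradients were reported by the scaler.
        if self._get_auxiliary_var('found_inf'):
            return None, None
        # ... apply the parameter update here ...
        return [], []


opt = ToyOptimizer()
found_inf = False  # normally computed by the scaler while unscaling gradients

# Mirrors the three restored lines in AmpScaler.minimize():
opt._set_auxiliary_var('found_inf', found_inf)
optimize_ops, params_grads = opt.minimize()
cache_found_inf = opt._get_auxiliary_var('found_inf')
```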
@@ -778,16 +771,9 @@ class GradScaler(AmpScaler):
         if optimizer_state["state"] is OptimizerState.INIT:
             self._unscale(optimizer)
-        if hasattr(optimizer, "_set_auxiliary_var"):
-            optimizer._set_auxiliary_var('found_inf', self._found_inf)
-            optimizer.step()
-            self._cache_founf_inf = optimizer._get_auxiliary_var('found_inf')
-        else:
-            if self._found_inf:
-                self._cache_founf_inf = True
-            else:
-                optimizer.step()
-                self._cache_founf_inf = False
+        optimizer._set_auxiliary_var('found_inf', self._found_inf)
+        optimizer.step()
+        self._cache_founf_inf = optimizer._get_auxiliary_var('found_inf')
         optimizer_state["state"] = OptimizerState.STEPPED
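For context, `GradScaler.step()` above is what a dygraph AMP training loop calls after backpropagating a scaled loss. A rough usage sketch, with an illustrative model and optimizer:

```python
import paddle

# Hypothetical model and data; only the scale/step/update sequence matters here.
model = paddle.nn.Linear(10, 10)
optimizer = paddle.optimizer.SGD(learning_rate=0.01, parameters=model.parameters())
scaler = paddle.amp.GradScaler(init_loss_scaling=1024)

data = paddle.rand([4, 10])
with paddle.amp.auto_cast():
    loss = model(data).mean()

scaled = scaler.scale(loss)   # scale the loss to avoid fp16 gradient underflow
scaled.backward()             # gradients are scaled as well
scaler.step(optimizer)        # unscales grads, then runs the path shown above
scaler.update()               # adjusts the loss scaling for the next iteration
optimizer.clear_grad()
```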
@@ -41,16 +41,9 @@ class HybridParallelGradScaler:
         optimize_ops, params_grads = (None, None)
-        if hasattr(optimizer, "_set_auxiliary_var"):
-            optimizer._set_auxiliary_var('found_inf', self._found_inf)
-            optimize_ops, params_grads = optimizer.minimize(*args, **kwargs)
-            self._cache_founf_inf = optimizer._get_auxiliary_var('found_inf')
-        else:
-            if self._found_inf:
-                self._cache_founf_inf = True
-            else:
-                optimize_ops, params_grads = optimizer.minimize(*args, **kwargs)
-                self._cache_founf_inf = False
+        optimizer._set_auxiliary_var('found_inf', self._found_inf)
+        optimize_ops, params_grads = optimizer.minimize(*args, **kwargs)
+        self._cache_founf_inf = optimizer._get_auxiliary_var('found_inf')
         if self._use_dynamic_loss_scaling:
             self._update()
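All three restored paths call `_set_auxiliary_var` unconditionally, which is exactly the case the reverted commit (#50158) had guarded with `hasattr` for custom optimizers. A custom optimizer that is not derived from Paddle's `Optimizer` base would therefore need to provide the two hooks itself. The sketch below shows only that hook contract; the class name and update rule are illustrative, and a real optimizer would still need whatever else the scaler touches (for example, its parameter list for gradient unscaling):

```python
import paddle

# Hedged sketch: a hand-rolled optimizer that adds the two auxiliary-var hooks
# so the unconditional calls in the restored code do not raise AttributeError.
class MyCustomOptimizer:
    def __init__(self, parameters, lr=0.01):
        self._parameters = list(parameters)
        self._lr = lr
        self._auxiliary_vars = {}

    # The two hooks the scaler now calls unconditionally.
    def _set_auxiliary_var(self, key, val):
        self._auxiliary_vars[key] = val

    def _get_auxiliary_var(self, key):
        return self._auxiliary_vars.get(key, None)

    @paddle.no_grad()
    def step(self):
        # Skip the update entirely if the scaler reported inf/nan gradients.
        if self._get_auxiliary_var('found_inf'):
            return
        for p in self._parameters:
            if p.grad is not None:
                p.set_value(p - self._lr * p.grad)

    def clear_grad(self):
        for p in self._parameters:
            p.clear_gradient()
```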