提交 8aeee2bf 编写于 作者: W Weilong Wu 提交者: GitHub

Revert "fix found_inf bug for custom optimizer (#50158)"

This reverts commit 64573f9f.
上级 a1e96e47
...@@ -228,16 +228,9 @@ class AmpScaler: ...@@ -228,16 +228,9 @@ class AmpScaler:
optimize_ops, params_grads = (None, None) optimize_ops, params_grads = (None, None)
if hasattr(optimizer, "_set_auxiliary_var"): optimizer._set_auxiliary_var('found_inf', self._found_inf)
optimizer._set_auxiliary_var('found_inf', self._found_inf) optimize_ops, params_grads = optimizer.minimize(*args, **kwargs)
optimize_ops, params_grads = optimizer.minimize(*args, **kwargs) self._cache_founf_inf = optimizer._get_auxiliary_var('found_inf')
self._cache_founf_inf = optimizer._get_auxiliary_var('found_inf')
else:
if self._found_inf:
self._cache_founf_inf = True
else:
optimize_ops, params_grads = optimizer.minimize(*args, **kwargs)
self._cache_founf_inf = False
if self._use_dynamic_loss_scaling: if self._use_dynamic_loss_scaling:
# uopdate the scale # uopdate the scale
...@@ -778,16 +771,9 @@ class GradScaler(AmpScaler): ...@@ -778,16 +771,9 @@ class GradScaler(AmpScaler):
if optimizer_state["state"] is OptimizerState.INIT: if optimizer_state["state"] is OptimizerState.INIT:
self._unscale(optimizer) self._unscale(optimizer)
if hasattr(optimizer, "_set_auxiliary_var"): optimizer._set_auxiliary_var('found_inf', self._found_inf)
optimizer._set_auxiliary_var('found_inf', self._found_inf) optimizer.step()
optimizer.step() self._cache_founf_inf = optimizer._get_auxiliary_var('found_inf')
self._cache_founf_inf = optimizer._get_auxiliary_var('found_inf')
else:
if self._found_inf:
self._cache_founf_inf = True
else:
optimizer.step()
self._cache_founf_inf = False
optimizer_state["state"] = OptimizerState.STEPPED optimizer_state["state"] = OptimizerState.STEPPED
......
...@@ -41,16 +41,9 @@ class HybridParallelGradScaler: ...@@ -41,16 +41,9 @@ class HybridParallelGradScaler:
optimize_ops, params_grads = (None, None) optimize_ops, params_grads = (None, None)
if hasattr(optimizer, "_set_auxiliary_var"): optimizer._set_auxiliary_var('found_inf', self._found_inf)
optimizer._set_auxiliary_var('found_inf', self._found_inf) optimize_ops, params_grads = optimizer.minimize(*args, **kwargs)
optimize_ops, params_grads = optimizer.minimize(*args, **kwargs) self._cache_founf_inf = optimizer._get_auxiliary_var('found_inf')
self._cache_founf_inf = optimizer._get_auxiliary_var('found_inf')
else:
if self._found_inf:
self._cache_founf_inf = True
else:
optimize_ops, params_grads = optimizer.minimize(*args, **kwargs)
self._cache_founf_inf = False
if self._use_dynamic_loss_scaling: if self._use_dynamic_loss_scaling:
self._update() self._update()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册