Unverified commit 382e9a06, authored by wanghuancoder, committed by GitHub

refine amp scaler found_inf (#49864)

* refine _found_inf
Parent 320958eb
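In short: instead of branching on a Python-side `self._found_inf` flag, the scaler now hands the flag to the optimizer as an auxiliary variable before calling `minimize()`/`step()` and reads the result back afterwards; the optimizer decides whether to skip the update. A minimal sketch of that handshake, assuming the `_set_auxiliary_var`/`_get_auxiliary_var` helpers behave as shown in this diff (the `TinyScaler` class is illustrative, not Paddle's implementation):

import paddle


class TinyScaler:
    """Illustrative stand-in for AmpScaler.minimize(); not Paddle's code."""

    def __init__(self):
        # The overflow flag lives on-device as a bool tensor.
        self._found_inf = paddle.to_tensor(False)
        self._cache_founf_inf = None  # attribute name kept as in the diff

    def minimize(self, optimizer, *args, **kwargs):
        # Hand the flag to the optimizer; the optimizer skips the update itself.
        optimizer._set_auxiliary_var('found_inf', self._found_inf)
        results = optimizer.minimize(*args, **kwargs)
        # Read back what the optimizer saw (a plain bool in dygraph mode).
        self._cache_founf_inf = optimizer._get_auxiliary_var('found_inf')
        return results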
@@ -18,7 +18,7 @@ from enum import Enum
 import numpy as np
-from paddle import _legacy_C_ops
+from paddle import _C_ops, _legacy_C_ops
 from paddle.fluid import core, in_dygraph_mode
 from paddle.fluid.data_feeder import check_type
 from paddle.fluid.dygraph import to_variable
@@ -228,11 +228,9 @@ class AmpScaler:
         optimize_ops, params_grads = (None, None)
-        if self._found_inf:
-            self._cache_founf_inf = True
-        else:
-            optimize_ops, params_grads = optimizer.minimize(*args, **kwargs)
-            self._cache_founf_inf = False
+        optimizer._set_auxiliary_var('found_inf', self._found_inf)
+        optimize_ops, params_grads = optimizer.minimize(*args, **kwargs)
+        self._cache_founf_inf = optimizer._get_auxiliary_var('found_inf')
         if self._use_dynamic_loss_scaling:
             # uopdate the scale
@@ -330,6 +328,9 @@ class AmpScaler:
                     param_grads_fp16,
                     self._temp_found_inf_fp16,
                 )
+                self._found_inf = _C_ops.bitwise_or(
+                    self._found_inf, self._temp_found_inf_fp16
+                )
             if len(param_grads_bf16):
                 _legacy_C_ops.check_finite_and_unscale(
                     param_grads_bf16,
@@ -338,6 +339,9 @@ class AmpScaler:
                     param_grads_bf16,
                     self._temp_found_inf_bf16,
                 )
+                self._found_inf = _C_ops.bitwise_or(
+                    self._found_inf, self._temp_found_inf_bf16
+                )
             if len(param_grads_fp32):
                 _legacy_C_ops.check_finite_and_unscale(
                     param_grads_fp32,
@@ -346,6 +350,9 @@ class AmpScaler:
                     param_grads_fp32,
                     self._temp_found_inf_fp32,
                 )
+                self._found_inf = _C_ops.bitwise_or(
+                    self._found_inf, self._temp_found_inf_fp32
+                )
         else:
             if len(param_grads_fp16):
                 _legacy_C_ops.check_finite_and_unscale(
@@ -354,6 +361,9 @@ class AmpScaler:
                     param_grads_fp16,
                     self._temp_found_inf_fp16,
                 )
+                self._found_inf = _C_ops.bitwise_or(
+                    self._found_inf, self._temp_found_inf_fp16
+                )
             if len(param_grads_bf16):
                 _legacy_C_ops.check_finite_and_unscale(
                     param_grads_bf16,
@@ -361,6 +371,9 @@ class AmpScaler:
                     param_grads_bf16,
                     self._temp_found_inf_bf16,
                 )
+                self._found_inf = _C_ops.bitwise_or(
+                    self._found_inf, self._temp_found_inf_bf16
+                )
             if len(param_grads_fp32):
                 _legacy_C_ops.check_finite_and_unscale(
                     param_grads_fp32,
@@ -368,11 +381,8 @@ class AmpScaler:
                     param_grads_fp32,
                     self._temp_found_inf_fp32,
                 )
-        self._found_inf = (
-            self._temp_found_inf_fp16
-            or self._temp_found_inf_bf16
-            or self._temp_found_inf_fp32
-        )
+                self._found_inf = _C_ops.bitwise_or(
+                    self._found_inf, self._temp_found_inf_fp32
+                )
         optimizer_state["state"] = OptimizerState.UNSCALED
@@ -761,11 +771,9 @@ class GradScaler(AmpScaler):
         if optimizer_state["state"] is OptimizerState.INIT:
             self._unscale(optimizer)
-        if self._found_inf:
-            self._cache_founf_inf = True
-        else:
-            optimizer.step()
-            self._cache_founf_inf = False
+        optimizer._set_auxiliary_var('found_inf', self._found_inf)
+        optimizer.step()
+        self._cache_founf_inf = optimizer._get_auxiliary_var('found_inf')
         optimizer_state["state"] = OptimizerState.STEPPED
...
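Within `_unscale`, the per-dtype overflow flags (`self._temp_found_inf_fp16/_bf16/_fp32`) are now folded into `self._found_inf` with a tensor-level OR right after each `check_finite_and_unscale` call, rather than one Python `or` expression at the end, so the flag stays a device tensor. A rough equivalent using the public `paddle.bitwise_or` in place of the internal `_C_ops` call (the tensor values here are placeholders):

import paddle

# Placeholder per-dtype flags such as check_finite_and_unscale would produce.
found_inf = paddle.to_tensor(False)
temp_found_inf_fp16 = paddle.to_tensor(False)
temp_found_inf_fp32 = paddle.to_tensor(True)

# Tensor-level OR keeps the accumulation on device; a Python `or` would
# evaluate the tensors' truth values and force a host synchronization.
found_inf = paddle.bitwise_or(found_inf, temp_found_inf_fp16)
found_inf = paddle.bitwise_or(found_inf, temp_found_inf_fp32)

print(bool(found_inf))  # True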
@@ -41,11 +41,9 @@ class HybridParallelGradScaler:
         optimize_ops, params_grads = (None, None)
-        if self._found_inf:
-            self._cache_founf_inf = True
-        else:
-            optimize_ops, params_grads = optimizer.minimize(*args, **kwargs)
-            self._cache_founf_inf = False
+        optimizer._set_auxiliary_var('found_inf', self._found_inf)
+        optimize_ops, params_grads = optimizer.minimize(*args, **kwargs)
+        self._cache_founf_inf = optimizer._get_auxiliary_var('found_inf')
         if self._use_dynamic_loss_scaling:
             self._update()
...
@@ -19,10 +19,10 @@ from types import MethodType
 import numpy as np
 import paddle
-from paddle import _legacy_C_ops
+from paddle import _C_ops, _legacy_C_ops
 from paddle.common_ops_import import dygraph_only
-from paddle.fluid import core
 from paddle.fluid.dygraph import to_variable
+from paddle.framework import core
 from paddle.nn import clip
@@ -231,6 +231,9 @@ def GroupShardedScaler(scaler):
                param_grads_fp16,
                temp_found_inf_fp16,
            )
+            self._found_inf = _C_ops.bitwise_or(
+                self._found_inf, temp_found_inf_fp16
+            )
        if len(param_grads_fp32):
            _legacy_C_ops.check_finite_and_unscale(
                param_grads_fp32,
@@ -238,15 +241,17 @@ def GroupShardedScaler(scaler):
                param_grads_fp32,
                temp_found_inf_fp32,
            )
+            self._found_inf = _C_ops.bitwise_or(
+                self._found_inf, temp_found_inf_fp32
+            )
-        self._found_inf = 1 if temp_found_inf_fp16 or temp_found_inf_fp32 else 0
-        is_found_inf = paddle.to_tensor([self._found_inf], dtype="int32")
+        self._found_inf = self._found_inf.cast("int32")
         paddle.distributed.all_reduce(
-            is_found_inf, op=paddle.distributed.ReduceOp.SUM, group=None
+            self._found_inf, op=paddle.distributed.ReduceOp.MAX, group=None
         )
-        self._found_inf = is_found_inf.numpy()[0]
+        self._found_inf = self._found_inf.cast("bool")
     scaler._unscale = MethodType(unscale_method, scaler)
     return scaler
...
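In the sharded scaler the cross-rank synchronization also changes: rather than copying the flag to the host, wrapping it in a fresh int32 tensor, and SUM-reducing it, the `found_inf` tensor itself is cast to int32, all-reduced with MAX, and cast back to bool. A sketch of that pattern with public APIs (assumes `paddle.distributed.init_parallel_env()` has already been called; `sync_found_inf` is a hypothetical helper name):

import paddle
import paddle.distributed as dist


def sync_found_inf(found_inf):
    """All-reduce a per-rank bool overflow flag so every rank skips together."""
    flag = found_inf.cast("int32")               # bool -> int32 for the collective
    dist.all_reduce(flag, op=dist.ReduceOp.MAX)  # any rank that saw inf/nan wins
    return flag.cast("bool")


# Per-rank usage, once dist.init_parallel_env() has been called:
# scaler._found_inf = sync_found_inf(scaler._found_inf)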
@@ -17,7 +17,7 @@ from types import MethodType
 import numpy as np
 import paddle
-from paddle import _legacy_C_ops
+from paddle import _C_ops, _legacy_C_ops
 from paddle.distributed import fleet
 from paddle.fluid.dygraph import to_variable
 from paddle.framework import core
@@ -73,6 +73,9 @@ def distributed_scaler(scaler):
                param_grads_fp16,
                temp_found_inf_fp16,
            )
+            self._found_inf = _C_ops.bitwise_or(
+                self._found_inf, temp_found_inf_fp16
+            )
        if len(param_grads_fp32):
            _legacy_C_ops.check_finite_and_unscale(
                param_grads_fp32,
@@ -80,17 +83,19 @@ def distributed_scaler(scaler):
                param_grads_fp32,
                temp_found_inf_fp32,
            )
+            self._found_inf = _C_ops.bitwise_or(
+                self._found_inf, temp_found_inf_fp32
+            )
-        self._found_inf = 1 if temp_found_inf_fp16 or temp_found_inf_fp32 else 0
-        is_found_inf = paddle.to_tensor([self._found_inf], dtype="int32")
+        self._found_inf = self._found_inf.cast("int32")
         # TODO(shenliang03) Since dp allreduce in the optimizer is
         # after the gradscaler, check_finite needs to synchronize global
         # information. In the future, we should use check_group to speed.
         paddle.distributed.all_reduce(
-            is_found_inf, op=paddle.distributed.ReduceOp.MAX, group=None
+            self._found_inf, op=paddle.distributed.ReduceOp.MAX, group=None
         )
-        self._found_inf = is_found_inf.numpy()[0]
+        self._found_inf = self._found_inf.cast("bool")
     # Only data_parallel doesn't need to modify scaler
     fleet_env = fleet.fleet
...
@@ -893,6 +893,13 @@ class Optimizer:
         self._create_global_learning_rate()
         if in_dygraph_mode():
+            found_inf = self._get_auxiliary_var('found_inf')
+            if found_inf:
+                if isinstance(found_inf, core.eager.Tensor):
+                    self._set_auxiliary_var('found_inf', True)
+            else:
+                if isinstance(found_inf, core.eager.Tensor):
+                    self._set_auxiliary_var('found_inf', False)
             for param_and_grad in parameters_and_grads:
                 if param_and_grad[1] is None:
                     continue
...
@@ -360,8 +360,6 @@ class Adam(Optimizer):
         # create the adam optimize op
         if framework.in_dygraph_mode():
-            found_inf = self._get_auxiliary_var('found_inf')
             _beta1 = (
                 self._beta1
                 if not isinstance(self._beta1, Variable)
@@ -382,7 +380,7 @@ class Adam(Optimizer):
                 beta1_pow_acc,
                 beta2_pow_acc,
                 master_weight,
-                found_inf,
+                None,
                 _beta1,
                 _beta2,
                 self._epsilon,
@@ -693,6 +691,13 @@ class Adam(Optimizer):
                    if master_weight is not None
                    else None
                )
+                found_inf = self._get_auxiliary_var('found_inf')
+                if found_inf:
+                    if isinstance(found_inf, core.eager.Tensor):
+                        self._set_auxiliary_var('found_inf', True)
+                else:
+                    if isinstance(found_inf, core.eager.Tensor):
+                        self._set_auxiliary_var('found_inf', False)
                _, _, _, _, _, _ = _C_ops.merged_adam_(
                    self._param_dict[key][param_group_idx],
                    grad_dict[key],
...
@@ -491,7 +491,6 @@ class AdamW(Optimizer):
                 else self._beta2.numpy().item(0)
             )
-            found_inf = self._get_auxiliary_var('found_inf')
             _, _, _, _, _, _ = _C_ops.adamw_(
                 param_and_grad[0],
                 param_and_grad[1],
@@ -501,7 +500,7 @@ class AdamW(Optimizer):
                 beta1_pow_acc,
                 beta2_pow_acc,
                 master_weight,
-                found_inf,
+                None,
                 _beta1,
                 _beta2,
                 self._epsilon,
...
@@ -293,7 +293,6 @@ class Lamb(Optimizer):
             self._used_master_weights[p_name] = master_weight.name
         else:
             master_weight = None
-        found_inf = self._get_auxiliary_var('found_inf')
         if framework.in_dygraph_mode():
             _C_ops.lamb_(
@@ -305,7 +304,7 @@ class Lamb(Optimizer):
                 beta1_pow_acc,
                 beta2_pow_acc,
                 master_weight,
-                found_inf,
+                None,
                 weight_decay,
                 self._beta1,
                 self._beta2,
@@ -343,6 +342,7 @@ class Lamb(Optimizer):
             inputs["MasterParam"] = master_weight
             outputs["MasterParamOut"] = master_weight
+        found_inf = self._get_auxiliary_var('found_inf')
         if found_inf:
             inputs["SkipUpdate"] = found_inf
...
@@ -530,6 +530,13 @@ class Momentum(Optimizer):
            )
            if in_dygraph_mode():
+                found_inf = self._get_auxiliary_var('found_inf')
+                if found_inf:
+                    if isinstance(found_inf, core.eager.Tensor):
+                        self._set_auxiliary_var('found_inf', True)
+                else:
+                    if isinstance(found_inf, core.eager.Tensor):
+                        self._set_auxiliary_var('found_inf', False)
                _, _, _ = _C_ops.merged_momentum_(
                    self._param_dict[key][param_group_idx],
                    grad_dict[key],
@@ -538,8 +545,12 @@ class Momentum(Optimizer):
                    master_weight,
                    self._momentum,
                    self._use_nesterov,
-                    self._regularization_method_dict[key][param_group_idx],
-                    self._regularization_coeff_dict[key][param_group_idx],
+                    self._regularization_method_dict[key][
+                        param_group_idx
+                    ],
+                    self._regularization_coeff_dict[key][
+                        param_group_idx
+                    ],
                    find_master,
                    self._rescale_grad,
                )
...
@@ -920,6 +920,13 @@ class Optimizer:
             self._create_accumulators(target_block, params_acc_dict)
         if framework._non_static_mode():
+            found_inf = self._get_auxiliary_var('found_inf')
+            if found_inf:
+                if isinstance(found_inf, core.eager.Tensor):
+                    self._set_auxiliary_var('found_inf', True)
+            else:
+                if isinstance(found_inf, core.eager.Tensor):
+                    self._set_auxiliary_var('found_inf', False)
             if isinstance(parameters_and_grads, list):
                 for param_and_grad in parameters_and_grads:
                     if param_and_grad[1] is None:
...
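On the optimizer side, the dygraph paths stop threading a `found_inf` tensor into each kernel call (`adam_`, `adamw_`, and `lamb_` now receive `None` for that input); instead, before building the update ops, the optimizer inspects the `found_inf` auxiliary variable once and, if it arrived as an eager Tensor from the scaler, replaces it with a plain Python bool. A condensed sketch of that normalization step (the helper name is hypothetical; it only mirrors the block added in this diff):

from paddle.framework import core


def _normalize_found_inf(optimizer):
    # Read the flag the scaler stored before minimize()/step() was called.
    found_inf = optimizer._get_auxiliary_var('found_inf')
    if isinstance(found_inf, core.eager.Tensor):
        # One host sync here converts the device tensor to a Python bool,
        # so the per-parameter kernel calls below don't each need the tensor.
        optimizer._set_auxiliary_var('found_inf', bool(found_inf))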