未验证 提交 1e232e27 编写于 作者: W wanghuancoder 提交者: GitHub

refine amp scaler (#51340)

* refine _found_inf
上级 870c0837
...@@ -18,7 +18,7 @@ from enum import Enum ...@@ -18,7 +18,7 @@ from enum import Enum
import numpy as np import numpy as np
from paddle import _legacy_C_ops from paddle import _C_ops, _legacy_C_ops
from paddle.fluid import core, in_dygraph_mode from paddle.fluid import core, in_dygraph_mode
from paddle.fluid.data_feeder import check_type from paddle.fluid.data_feeder import check_type
from paddle.fluid.dygraph import to_variable from paddle.fluid.dygraph import to_variable
...@@ -131,6 +131,9 @@ class AmpScaler: ...@@ -131,6 +131,9 @@ class AmpScaler:
self._use_dynamic_loss_scaling = use_dynamic_loss_scaling self._use_dynamic_loss_scaling = use_dynamic_loss_scaling
self._found_inf = to_variable(np.array([0]).astype(np.bool_)) self._found_inf = to_variable(np.array([0]).astype(np.bool_))
self._temp_found_inf_value_false = to_variable(
np.array([0]).astype(np.bool_)
)
self._temp_found_inf_fp16 = to_variable( self._temp_found_inf_fp16 = to_variable(
np.array([0]).astype(np.bool_) np.array([0]).astype(np.bool_)
) )
...@@ -228,11 +231,16 @@ class AmpScaler: ...@@ -228,11 +231,16 @@ class AmpScaler:
optimize_ops, params_grads = (None, None) optimize_ops, params_grads = (None, None)
if self._found_inf: if hasattr(optimizer, "_set_auxiliary_var"):
self._cache_founf_inf = True optimizer._set_auxiliary_var('found_inf', self._found_inf)
else:
optimize_ops, params_grads = optimizer.minimize(*args, **kwargs) optimize_ops, params_grads = optimizer.minimize(*args, **kwargs)
self._cache_founf_inf = False self._cache_founf_inf = optimizer._get_auxiliary_var('found_inf')
else:
if self._found_inf:
self._cache_founf_inf = True
else:
optimize_ops, params_grads = optimizer.minimize(*args, **kwargs)
self._cache_founf_inf = False
if self._use_dynamic_loss_scaling: if self._use_dynamic_loss_scaling:
# uopdate the scale # uopdate the scale
...@@ -318,6 +326,7 @@ class AmpScaler: ...@@ -318,6 +326,7 @@ class AmpScaler:
for param in param_grads for param in param_grads
if param.dtype == core.VarDesc.VarType.FP32 if param.dtype == core.VarDesc.VarType.FP32
] ]
self._found_inf = self._temp_found_inf_value_false
if core.is_compiled_with_npu(): if core.is_compiled_with_npu():
float_status = _legacy_C_ops.alloc_float_status() float_status = _legacy_C_ops.alloc_float_status()
_legacy_C_ops.clear_float_status(float_status, float_status) _legacy_C_ops.clear_float_status(float_status, float_status)
...@@ -330,6 +339,9 @@ class AmpScaler: ...@@ -330,6 +339,9 @@ class AmpScaler:
param_grads_fp16, param_grads_fp16,
self._temp_found_inf_fp16, self._temp_found_inf_fp16,
) )
self._found_inf = _C_ops.bitwise_or(
self._found_inf, self._temp_found_inf_fp16
)
if len(param_grads_bf16): if len(param_grads_bf16):
_legacy_C_ops.check_finite_and_unscale( _legacy_C_ops.check_finite_and_unscale(
param_grads_bf16, param_grads_bf16,
...@@ -338,6 +350,9 @@ class AmpScaler: ...@@ -338,6 +350,9 @@ class AmpScaler:
param_grads_bf16, param_grads_bf16,
self._temp_found_inf_bf16, self._temp_found_inf_bf16,
) )
self._found_inf = _C_ops.bitwise_or(
self._found_inf, self._temp_found_inf_bf16
)
if len(param_grads_fp32): if len(param_grads_fp32):
_legacy_C_ops.check_finite_and_unscale( _legacy_C_ops.check_finite_and_unscale(
param_grads_fp32, param_grads_fp32,
...@@ -346,6 +361,9 @@ class AmpScaler: ...@@ -346,6 +361,9 @@ class AmpScaler:
param_grads_fp32, param_grads_fp32,
self._temp_found_inf_fp32, self._temp_found_inf_fp32,
) )
self._found_inf = _C_ops.bitwise_or(
self._found_inf, self._temp_found_inf_fp32
)
else: else:
if len(param_grads_fp16): if len(param_grads_fp16):
_legacy_C_ops.check_finite_and_unscale( _legacy_C_ops.check_finite_and_unscale(
...@@ -354,6 +372,9 @@ class AmpScaler: ...@@ -354,6 +372,9 @@ class AmpScaler:
param_grads_fp16, param_grads_fp16,
self._temp_found_inf_fp16, self._temp_found_inf_fp16,
) )
self._found_inf = _C_ops.bitwise_or(
self._found_inf, self._temp_found_inf_fp16
)
if len(param_grads_bf16): if len(param_grads_bf16):
_legacy_C_ops.check_finite_and_unscale( _legacy_C_ops.check_finite_and_unscale(
param_grads_bf16, param_grads_bf16,
...@@ -361,6 +382,9 @@ class AmpScaler: ...@@ -361,6 +382,9 @@ class AmpScaler:
param_grads_bf16, param_grads_bf16,
self._temp_found_inf_bf16, self._temp_found_inf_bf16,
) )
self._found_inf = _C_ops.bitwise_or(
self._found_inf, self._temp_found_inf_bf16
)
if len(param_grads_fp32): if len(param_grads_fp32):
_legacy_C_ops.check_finite_and_unscale( _legacy_C_ops.check_finite_and_unscale(
param_grads_fp32, param_grads_fp32,
...@@ -368,12 +392,9 @@ class AmpScaler: ...@@ -368,12 +392,9 @@ class AmpScaler:
param_grads_fp32, param_grads_fp32,
self._temp_found_inf_fp32, self._temp_found_inf_fp32,
) )
self._found_inf = _C_ops.bitwise_or(
self._found_inf = ( self._found_inf, self._temp_found_inf_fp32
self._temp_found_inf_fp16 )
or self._temp_found_inf_bf16
or self._temp_found_inf_fp32
)
optimizer_state["state"] = OptimizerState.UNSCALED optimizer_state["state"] = OptimizerState.UNSCALED
...@@ -761,11 +782,16 @@ class GradScaler(AmpScaler): ...@@ -761,11 +782,16 @@ class GradScaler(AmpScaler):
if optimizer_state["state"] is OptimizerState.INIT: if optimizer_state["state"] is OptimizerState.INIT:
self._unscale(optimizer) self._unscale(optimizer)
if self._found_inf: if hasattr(optimizer, "_set_auxiliary_var"):
self._cache_founf_inf = True optimizer._set_auxiliary_var('found_inf', self._found_inf)
else:
optimizer.step() optimizer.step()
self._cache_founf_inf = False self._cache_founf_inf = optimizer._get_auxiliary_var('found_inf')
else:
if self._found_inf:
self._cache_founf_inf = True
else:
optimizer.step()
self._cache_founf_inf = False
optimizer_state["state"] = OptimizerState.STEPPED optimizer_state["state"] = OptimizerState.STEPPED
......
...@@ -236,6 +236,10 @@ class AscendOptimizer(Optimizer): ...@@ -236,6 +236,10 @@ class AscendOptimizer(Optimizer):
ret_list.append(var) ret_list.append(var)
return ret_list return ret_list
def _set_auxiliary_var(self, key, val):
super()._set_auxiliary_var(key, val)
self.inner_opt._set_auxiliary_var(key, val)
def minimize( def minimize(
self, self,
loss, loss,
......
...@@ -41,11 +41,16 @@ class HybridParallelGradScaler: ...@@ -41,11 +41,16 @@ class HybridParallelGradScaler:
optimize_ops, params_grads = (None, None) optimize_ops, params_grads = (None, None)
if self._found_inf: if hasattr(optimizer, "_set_auxiliary_var"):
self._cache_founf_inf = True optimizer._set_auxiliary_var('found_inf', self._found_inf)
else:
optimize_ops, params_grads = optimizer.minimize(*args, **kwargs) optimize_ops, params_grads = optimizer.minimize(*args, **kwargs)
self._cache_founf_inf = False self._cache_founf_inf = optimizer._get_auxiliary_var('found_inf')
else:
if self._found_inf:
self._cache_founf_inf = True
else:
optimize_ops, params_grads = optimizer.minimize(*args, **kwargs)
self._cache_founf_inf = False
if self._use_dynamic_loss_scaling: if self._use_dynamic_loss_scaling:
self._update() self._update()
......
...@@ -25,6 +25,10 @@ class MetaOptimizerBase(Optimizer): ...@@ -25,6 +25,10 @@ class MetaOptimizerBase(Optimizer):
self.meta_optimizers_white_list = [] self.meta_optimizers_white_list = []
self.meta_optimizers_black_list = [] self.meta_optimizers_black_list = []
def _set_auxiliary_var(self, key, val):
super()._set_auxiliary_var(key, val)
self.inner_opt._set_auxiliary_var(key, val)
def _set_basic_info( def _set_basic_info(
self, loss, role_maker, user_defined_optimizer, user_defined_strategy self, loss, role_maker, user_defined_optimizer, user_defined_strategy
): ):
......
...@@ -203,6 +203,10 @@ class GroupShardedOptimizerStage2(Optimizer): ...@@ -203,6 +203,10 @@ class GroupShardedOptimizerStage2(Optimizer):
# Update optimizer parameters and adjust parameter storage and use according to rank. # Update optimizer parameters and adjust parameter storage and use according to rank.
self._update_opt_status() self._update_opt_status()
def _set_auxiliary_var(self, key, val):
super()._set_auxiliary_var(key, val)
self._optim._set_auxiliary_var(key, val)
@paddle.autograd.no_grad() @paddle.autograd.no_grad()
def _sync_params_and_buffers(self): def _sync_params_and_buffers(self):
""" """
......
...@@ -19,10 +19,10 @@ from types import MethodType ...@@ -19,10 +19,10 @@ from types import MethodType
import numpy as np import numpy as np
import paddle import paddle
from paddle import _legacy_C_ops from paddle import _C_ops, _legacy_C_ops
from paddle.common_ops_import import dygraph_only from paddle.common_ops_import import dygraph_only
from paddle.fluid import core
from paddle.fluid.dygraph import to_variable from paddle.fluid.dygraph import to_variable
from paddle.framework import core
from paddle.nn import clip from paddle.nn import clip
...@@ -262,6 +262,7 @@ def GroupShardedScaler(scaler): ...@@ -262,6 +262,7 @@ def GroupShardedScaler(scaler):
0 if device == "cpu" else int(paddle.get_device().split(":")[1]) 0 if device == "cpu" else int(paddle.get_device().split(":")[1])
) )
self._found_inf = self._temp_found_inf_value_false
with device_guard(dev_id, device): with device_guard(dev_id, device):
if len(param_grads_bfp16): if len(param_grads_bfp16):
_legacy_C_ops.check_finite_and_unscale( _legacy_C_ops.check_finite_and_unscale(
...@@ -270,6 +271,9 @@ def GroupShardedScaler(scaler): ...@@ -270,6 +271,9 @@ def GroupShardedScaler(scaler):
param_grads_bfp16, param_grads_bfp16,
temp_found_inf_bfp16, temp_found_inf_bfp16,
) )
self._found_inf = _C_ops.bitwise_or(
self._found_inf, temp_found_inf_bfp16
)
if len(param_grads_fp16): if len(param_grads_fp16):
_legacy_C_ops.check_finite_and_unscale( _legacy_C_ops.check_finite_and_unscale(
param_grads_fp16, param_grads_fp16,
...@@ -277,6 +281,9 @@ def GroupShardedScaler(scaler): ...@@ -277,6 +281,9 @@ def GroupShardedScaler(scaler):
param_grads_fp16, param_grads_fp16,
temp_found_inf_fp16, temp_found_inf_fp16,
) )
self._found_inf = _C_ops.bitwise_or(
self._found_inf, temp_found_inf_fp16
)
if len(param_grads_fp32): if len(param_grads_fp32):
_legacy_C_ops.check_finite_and_unscale( _legacy_C_ops.check_finite_and_unscale(
param_grads_fp32, param_grads_fp32,
...@@ -284,21 +291,17 @@ def GroupShardedScaler(scaler): ...@@ -284,21 +291,17 @@ def GroupShardedScaler(scaler):
param_grads_fp32, param_grads_fp32,
temp_found_inf_fp32, temp_found_inf_fp32,
) )
self._found_inf = _C_ops.bitwise_or(
self._found_inf, temp_found_inf_fp32
)
self._found_inf = ( self._found_inf = self._found_inf.cast("int32")
1
if temp_found_inf_bfp16
or temp_found_inf_fp16
or temp_found_inf_fp32
else 0
)
is_found_inf = paddle.to_tensor([self._found_inf], dtype="int32")
paddle.distributed.all_reduce( paddle.distributed.all_reduce(
is_found_inf, op=paddle.distributed.ReduceOp.SUM, group=None self._found_inf, op=paddle.distributed.ReduceOp.MAX, group=None
) )
self._found_inf = is_found_inf.numpy()[0] self._found_inf = self._found_inf.cast("bool")
scaler._unscale = MethodType(unscale_method, scaler) scaler._unscale = MethodType(unscale_method, scaler)
return scaler return scaler
......
...@@ -17,7 +17,7 @@ from types import MethodType ...@@ -17,7 +17,7 @@ from types import MethodType
import numpy as np import numpy as np
import paddle import paddle
from paddle import _legacy_C_ops from paddle import _C_ops, _legacy_C_ops
from paddle.distributed import fleet from paddle.distributed import fleet
from paddle.fluid.dygraph import to_variable from paddle.fluid.dygraph import to_variable
from paddle.framework import core from paddle.framework import core
...@@ -66,6 +66,7 @@ def distributed_scaler(scaler): ...@@ -66,6 +66,7 @@ def distributed_scaler(scaler):
] ]
temp_found_inf_fp16 = to_variable(np.array([0]).astype(np.bool_)) temp_found_inf_fp16 = to_variable(np.array([0]).astype(np.bool_))
temp_found_inf_fp32 = to_variable(np.array([0]).astype(np.bool_)) temp_found_inf_fp32 = to_variable(np.array([0]).astype(np.bool_))
self._found_inf = self._temp_found_inf_value_false
if len(param_grads_fp16): if len(param_grads_fp16):
_legacy_C_ops.check_finite_and_unscale( _legacy_C_ops.check_finite_and_unscale(
param_grads_fp16, param_grads_fp16,
...@@ -73,6 +74,9 @@ def distributed_scaler(scaler): ...@@ -73,6 +74,9 @@ def distributed_scaler(scaler):
param_grads_fp16, param_grads_fp16,
temp_found_inf_fp16, temp_found_inf_fp16,
) )
self._found_inf = _C_ops.bitwise_or(
self._found_inf, temp_found_inf_fp16
)
if len(param_grads_fp32): if len(param_grads_fp32):
_legacy_C_ops.check_finite_and_unscale( _legacy_C_ops.check_finite_and_unscale(
param_grads_fp32, param_grads_fp32,
...@@ -80,17 +84,19 @@ def distributed_scaler(scaler): ...@@ -80,17 +84,19 @@ def distributed_scaler(scaler):
param_grads_fp32, param_grads_fp32,
temp_found_inf_fp32, temp_found_inf_fp32,
) )
self._found_inf = _C_ops.bitwise_or(
self._found_inf, temp_found_inf_fp32
)
self._found_inf = 1 if temp_found_inf_fp16 or temp_found_inf_fp32 else 0 self._found_inf = self._found_inf.cast("int32")
is_found_inf = paddle.to_tensor([self._found_inf], dtype="int32")
# TODO(shenliang03) Since dp allreduce in the optimizer is # TODO(shenliang03) Since dp allreduce in the optimizer is
# after the gradscaler, check_finite needs to synchronize global # after the gradscaler, check_finite needs to synchronize global
# information. In the future, we should use check_group to speed. # information. In the future, we should use check_group to speed.
paddle.distributed.all_reduce( paddle.distributed.all_reduce(
is_found_inf, op=paddle.distributed.ReduceOp.MAX, group=None self._found_inf, op=paddle.distributed.ReduceOp.MAX, group=None
) )
self._found_inf = is_found_inf.numpy()[0] self._found_inf = self._found_inf.cast("bool")
# Only data_parallel doesn't need to modify scaler # Only data_parallel doesn't need to modify scaler
fleet_env = fleet.fleet fleet_env = fleet.fleet
......
...@@ -970,11 +970,18 @@ class Optimizer: ...@@ -970,11 +970,18 @@ class Optimizer:
self._create_global_learning_rate() self._create_global_learning_rate()
if in_dygraph_mode(): if in_dygraph_mode():
for param_and_grad in parameters_and_grads: found_inf = self._get_auxiliary_var('found_inf')
if param_and_grad[1] is None: if found_inf:
continue if isinstance(found_inf, core.eager.Tensor):
if param_and_grad[0].trainable is True: self._set_auxiliary_var('found_inf', True)
self._append_optimize_op(target_block, param_and_grad) else:
if isinstance(found_inf, core.eager.Tensor):
self._set_auxiliary_var('found_inf', False)
for param_and_grad in parameters_and_grads:
if param_and_grad[1] is None:
continue
if param_and_grad[0].trainable is True:
self._append_optimize_op(target_block, param_and_grad)
else: else:
for param_and_grad in parameters_and_grads: for param_and_grad in parameters_and_grads:
if param_and_grad[1] is None: if param_and_grad[1] is None:
......
...@@ -144,6 +144,10 @@ class LookAhead(Optimizer): ...@@ -144,6 +144,10 @@ class LookAhead(Optimizer):
self._global_step_var = None self._global_step_var = None
self._k_var = None self._k_var = None
def _set_auxiliary_var(self, key, val):
super()._set_auxiliary_var(key, val)
self.inner_optimizer._set_auxiliary_var(key, val)
@framework.dygraph_only @framework.dygraph_only
@imperative_base.no_grad @imperative_base.no_grad
def step(self): def step(self):
......
...@@ -149,12 +149,15 @@ class Adadelta(Optimizer): ...@@ -149,12 +149,15 @@ class Adadelta(Optimizer):
parameters = parameters.get('params') parameters = parameters.get('params')
for p in parameters: for p in parameters:
if p.name in self._already_create_accumulater:
continue
if self._multi_precision and self._is_dtype_fp16_or_bf16(p.dtype): if self._multi_precision and self._is_dtype_fp16_or_bf16(p.dtype):
master_p = self._create_master_weight(p) master_p = self._create_master_weight(p)
self._add_accumulator(self._avg_squared_grad_acc_str, master_p) self._add_accumulator(self._avg_squared_grad_acc_str, master_p)
self._add_accumulator( self._add_accumulator(
self._avg_squared_update_acc_str, master_p self._avg_squared_update_acc_str, master_p
) )
self._already_create_accumulater.add(p.name)
continue continue
if ( if (
self._is_dtype_fp16_or_bf16(p.dtype) self._is_dtype_fp16_or_bf16(p.dtype)
...@@ -166,6 +169,7 @@ class Adadelta(Optimizer): ...@@ -166,6 +169,7 @@ class Adadelta(Optimizer):
) )
self._add_accumulator(self._avg_squared_grad_acc_str, p) self._add_accumulator(self._avg_squared_grad_acc_str, p)
self._add_accumulator(self._avg_squared_update_acc_str, p) self._add_accumulator(self._avg_squared_update_acc_str, p)
self._already_create_accumulater.add(p.name)
def _append_optimize_op(self, block, param_and_grad): def _append_optimize_op(self, block, param_and_grad):
if isinstance(param_and_grad, dict): if isinstance(param_and_grad, dict):
......
...@@ -142,9 +142,12 @@ class Adagrad(Optimizer): ...@@ -142,9 +142,12 @@ class Adagrad(Optimizer):
parameters = self._update_param_group(parameters) parameters = self._update_param_group(parameters)
for p in parameters: for p in parameters:
if p.name in self._already_create_accumulater:
continue
if self._multi_precision and self._is_dtype_fp16_or_bf16(p.dtype): if self._multi_precision and self._is_dtype_fp16_or_bf16(p.dtype):
master_p = self._create_master_weight(p) master_p = self._create_master_weight(p)
self._add_accumulator(self._moment_acc_str, master_p) self._add_accumulator(self._moment_acc_str, master_p)
self._already_create_accumulater.add(p.name)
continue continue
if ( if (
self._is_dtype_fp16_or_bf16(p.dtype) self._is_dtype_fp16_or_bf16(p.dtype)
...@@ -159,6 +162,7 @@ class Adagrad(Optimizer): ...@@ -159,6 +162,7 @@ class Adagrad(Optimizer):
p, p,
fill_value=self.initial_accumulator_value, fill_value=self.initial_accumulator_value,
) )
self._already_create_accumulater.add(p.name)
def _append_optimize_op(self, block, param_and_grad): def _append_optimize_op(self, block, param_and_grad):
assert isinstance(block, framework.Block) assert isinstance(block, framework.Block)
......
...@@ -260,9 +260,12 @@ class Adam(Optimizer): ...@@ -260,9 +260,12 @@ class Adam(Optimizer):
# Create accumulator tensors for first and second moments # Create accumulator tensors for first and second moments
for p in parameters: for p in parameters:
if p.name in self._already_create_accumulater:
continue
if self._multi_precision and self._is_dtype_fp16_or_bf16(p.dtype): if self._multi_precision and self._is_dtype_fp16_or_bf16(p.dtype):
master_p = self._create_master_weight(p) master_p = self._create_master_weight(p)
self._add_moments_pows(master_p) self._add_moments_pows(master_p)
self._already_create_accumulater.add(p.name)
continue continue
if ( if (
self._is_dtype_fp16_or_bf16(p.dtype) self._is_dtype_fp16_or_bf16(p.dtype)
...@@ -273,6 +276,7 @@ class Adam(Optimizer): ...@@ -273,6 +276,7 @@ class Adam(Optimizer):
"Consider using multi_precision=True option of the Adam optimizer." "Consider using multi_precision=True option of the Adam optimizer."
) )
self._add_moments_pows(p) self._add_moments_pows(p)
self._already_create_accumulater.add(p.name)
def _append_optimize_op(self, block, param_and_grad): def _append_optimize_op(self, block, param_and_grad):
assert isinstance(block, framework.Block) assert isinstance(block, framework.Block)
...@@ -303,8 +307,6 @@ class Adam(Optimizer): ...@@ -303,8 +307,6 @@ class Adam(Optimizer):
# create the adam optimize op # create the adam optimize op
if framework.in_dygraph_mode(): if framework.in_dygraph_mode():
found_inf = self._get_auxiliary_var('found_inf')
_beta1 = ( _beta1 = (
self._beta1 self._beta1
if not isinstance(self._beta1, Variable) if not isinstance(self._beta1, Variable)
...@@ -325,7 +327,7 @@ class Adam(Optimizer): ...@@ -325,7 +327,7 @@ class Adam(Optimizer):
beta1_pow_acc, beta1_pow_acc,
beta2_pow_acc, beta2_pow_acc,
master_weight, master_weight,
found_inf, None,
_beta1, _beta1,
_beta2, _beta2,
self._epsilon, self._epsilon,
...@@ -636,21 +638,28 @@ class Adam(Optimizer): ...@@ -636,21 +638,28 @@ class Adam(Optimizer):
if master_weight is not None if master_weight is not None
else None else None
) )
_, _, _, _, _, _ = _C_ops.merged_adam_( found_inf = self._get_auxiliary_var('found_inf')
self._param_dict[key][param_group_idx], if found_inf:
grad_dict[key], if isinstance(found_inf, core.eager.Tensor):
lr_dict[key], self._set_auxiliary_var('found_inf', True)
self._moment1_dict[key][param_group_idx], else:
self._moment2_dict[key][param_group_idx], if isinstance(found_inf, core.eager.Tensor):
self._beta1_pow_acc_dict[key][param_group_idx], self._set_auxiliary_var('found_inf', False)
self._beta2_pow_acc_dict[key][param_group_idx], _, _, _, _, _, _ = _C_ops.merged_adam_(
master_weight, self._param_dict[key][param_group_idx],
_beta1, grad_dict[key],
_beta2, lr_dict[key],
self._epsilon, self._moment1_dict[key][param_group_idx],
find_master, self._moment2_dict[key][param_group_idx],
False, self._beta1_pow_acc_dict[key][param_group_idx],
) self._beta2_pow_acc_dict[key][param_group_idx],
master_weight,
_beta1,
_beta2,
self._epsilon,
find_master,
False,
)
else: else:
inputs = { inputs = {
"Param": self._param_dict[key][param_group_idx], "Param": self._param_dict[key][param_group_idx],
......
...@@ -195,9 +195,12 @@ class Adamax(Optimizer): ...@@ -195,9 +195,12 @@ class Adamax(Optimizer):
# Create accumulator tensors for first moment and infinity norm # Create accumulator tensors for first moment and infinity norm
for p in parameters: for p in parameters:
if p.name in self._already_create_accumulater:
continue
if self._multi_precision and self._is_dtype_fp16_or_bf16(p.dtype): if self._multi_precision and self._is_dtype_fp16_or_bf16(p.dtype):
master_p = self._create_master_weight(p) master_p = self._create_master_weight(p)
self._add_moments_pows(master_p) self._add_moments_pows(master_p)
self._already_create_accumulater.add(p.name)
continue continue
if ( if (
self._is_dtype_fp16_or_bf16(p.dtype) self._is_dtype_fp16_or_bf16(p.dtype)
...@@ -208,6 +211,7 @@ class Adamax(Optimizer): ...@@ -208,6 +211,7 @@ class Adamax(Optimizer):
"Consider using multi_precision=True option of the Adam optimizer." "Consider using multi_precision=True option of the Adam optimizer."
) )
self._add_moments_pows(p) self._add_moments_pows(p)
self._already_create_accumulater.add(p.name)
def _append_optimize_op(self, block, param_and_grad): def _append_optimize_op(self, block, param_and_grad):
assert isinstance(block, framework.Block) assert isinstance(block, framework.Block)
......
...@@ -283,6 +283,7 @@ class AdamW(Optimizer): ...@@ -283,6 +283,7 @@ class AdamW(Optimizer):
self._use_multi_tensor = None self._use_multi_tensor = None
self.regularization = None self.regularization = None
self._auxiliary_vars = {} self._auxiliary_vars = {}
self._already_create_accumulater = set()
def _set_auxiliary_var(self, key, val): def _set_auxiliary_var(self, key, val):
self._auxiliary_vars[key] = val self._auxiliary_vars[key] = val
...@@ -368,9 +369,12 @@ class AdamW(Optimizer): ...@@ -368,9 +369,12 @@ class AdamW(Optimizer):
# Create accumulator tensors for first and second moments # Create accumulator tensors for first and second moments
for p in parameters: for p in parameters:
if p.name in self._already_create_accumulater:
continue
if self._multi_precision and self._is_dtype_fp16_or_bf16(p.dtype): if self._multi_precision and self._is_dtype_fp16_or_bf16(p.dtype):
master_p = self._create_master_weight(p) master_p = self._create_master_weight(p)
self._add_moments_pows(master_p) self._add_moments_pows(master_p)
self._already_create_accumulater.add(p.name)
continue continue
if ( if (
self._is_dtype_fp16_or_bf16(p.dtype) self._is_dtype_fp16_or_bf16(p.dtype)
...@@ -381,6 +385,7 @@ class AdamW(Optimizer): ...@@ -381,6 +385,7 @@ class AdamW(Optimizer):
"Consider using multi_precision=True option of the Adam optimizer." "Consider using multi_precision=True option of the Adam optimizer."
) )
self._add_moments_pows(p) self._add_moments_pows(p)
self._already_create_accumulater.add(p.name)
def _append_optimize_op(self, block, param_and_grad): def _append_optimize_op(self, block, param_and_grad):
assert isinstance(block, framework.Block) assert isinstance(block, framework.Block)
...@@ -437,7 +442,6 @@ class AdamW(Optimizer): ...@@ -437,7 +442,6 @@ class AdamW(Optimizer):
else self._beta2.numpy().item(0) else self._beta2.numpy().item(0)
) )
found_inf = self._get_auxiliary_var('found_inf')
_, _, _, _, _, _ = _C_ops.adamw_( _, _, _, _, _, _ = _C_ops.adamw_(
param_and_grad[0], param_and_grad[0],
param_and_grad[1], param_and_grad[1],
...@@ -447,7 +451,7 @@ class AdamW(Optimizer): ...@@ -447,7 +451,7 @@ class AdamW(Optimizer):
beta1_pow_acc, beta1_pow_acc,
beta2_pow_acc, beta2_pow_acc,
master_weight, master_weight,
found_inf, None,
_beta1, _beta1,
_beta2, _beta2,
self._epsilon, self._epsilon,
......
...@@ -159,11 +159,15 @@ class Lamb(Optimizer): ...@@ -159,11 +159,15 @@ class Lamb(Optimizer):
# Create accumulator tensors for first and second moments # Create accumulator tensors for first and second moments
for p in parameters: for p in parameters:
if p.name in self._already_create_accumulater:
continue
if self._multi_precision and self._is_dtype_fp16_or_bf16(p.dtype): if self._multi_precision and self._is_dtype_fp16_or_bf16(p.dtype):
master_p = self._create_master_weight(p) master_p = self._create_master_weight(p)
self._add_moments_pows(master_p) self._add_moments_pows(master_p)
self._already_create_accumulater.add(p.name)
else: else:
self._add_moments_pows(p) self._add_moments_pows(p)
self._already_create_accumulater.add(p.name)
def _add_moments_pows(self, p): def _add_moments_pows(self, p):
acc_dtype = p.dtype acc_dtype = p.dtype
...@@ -233,7 +237,6 @@ class Lamb(Optimizer): ...@@ -233,7 +237,6 @@ class Lamb(Optimizer):
self._used_master_weights[p_name] = master_weight.name self._used_master_weights[p_name] = master_weight.name
else: else:
master_weight = None master_weight = None
found_inf = self._get_auxiliary_var('found_inf')
if framework.in_dygraph_mode(): if framework.in_dygraph_mode():
_C_ops.lamb_( _C_ops.lamb_(
...@@ -245,7 +248,7 @@ class Lamb(Optimizer): ...@@ -245,7 +248,7 @@ class Lamb(Optimizer):
beta1_pow_acc, beta1_pow_acc,
beta2_pow_acc, beta2_pow_acc,
master_weight, master_weight,
found_inf, None,
weight_decay, weight_decay,
self._beta1, self._beta1,
self._beta2, self._beta2,
...@@ -283,6 +286,7 @@ class Lamb(Optimizer): ...@@ -283,6 +286,7 @@ class Lamb(Optimizer):
inputs["MasterParam"] = master_weight inputs["MasterParam"] = master_weight
outputs["MasterParamOut"] = master_weight outputs["MasterParamOut"] = master_weight
found_inf = self._get_auxiliary_var('found_inf')
if found_inf: if found_inf:
inputs["SkipUpdate"] = found_inf inputs["SkipUpdate"] = found_inf
......
...@@ -211,9 +211,12 @@ class Momentum(Optimizer): ...@@ -211,9 +211,12 @@ class Momentum(Optimizer):
parameters = self._update_param_group(parameters) parameters = self._update_param_group(parameters)
for p in parameters: for p in parameters:
if p.name in self._already_create_accumulater:
continue
if self._multi_precision and self._is_dtype_fp16_or_bf16(p.dtype): if self._multi_precision and self._is_dtype_fp16_or_bf16(p.dtype):
master_p = self._create_master_weight(p) master_p = self._create_master_weight(p)
self._add_accumulator(self._velocity_acc_str, master_p) self._add_accumulator(self._velocity_acc_str, master_p)
self._already_create_accumulater.add(p.name)
continue continue
if ( if (
self._is_dtype_fp16_or_bf16(p.dtype) self._is_dtype_fp16_or_bf16(p.dtype)
...@@ -224,6 +227,7 @@ class Momentum(Optimizer): ...@@ -224,6 +227,7 @@ class Momentum(Optimizer):
"Consider using multi_precision=True option of the Momentum optimizer." "Consider using multi_precision=True option of the Momentum optimizer."
) )
self._add_accumulator(self._velocity_acc_str, p) self._add_accumulator(self._velocity_acc_str, p)
self._already_create_accumulater.add(p.name)
def _create_regularization_of_grad(self, param, grad, regularization=None): def _create_regularization_of_grad(self, param, grad, regularization=None):
"""Create and add backward regularization Operators """Create and add backward regularization Operators
...@@ -472,19 +476,30 @@ class Momentum(Optimizer): ...@@ -472,19 +476,30 @@ class Momentum(Optimizer):
) )
if in_dygraph_mode(): if in_dygraph_mode():
_, _, _ = _C_ops.merged_momentum_( found_inf = self._get_auxiliary_var('found_inf')
self._param_dict[key][param_group_idx], if found_inf:
grad_dict[key], if isinstance(found_inf, core.eager.Tensor):
self._velocity_dict[key][param_group_idx], self._set_auxiliary_var('found_inf', True)
lr_dict[key], else:
master_weight, if isinstance(found_inf, core.eager.Tensor):
self._momentum, self._set_auxiliary_var('found_inf', False)
self._use_nesterov, _, _, _ = _C_ops.merged_momentum_(
self._regularization_method_dict[key][param_group_idx], self._param_dict[key][param_group_idx],
self._regularization_coeff_dict[key][param_group_idx], grad_dict[key],
find_master, self._velocity_dict[key][param_group_idx],
self._rescale_grad, lr_dict[key],
) master_weight,
self._momentum,
self._use_nesterov,
self._regularization_method_dict[key][
param_group_idx
],
self._regularization_coeff_dict[key][
param_group_idx
],
find_master,
self._rescale_grad,
)
else: else:
inputs = { inputs = {
"Param": self._param_dict[key][param_group_idx], "Param": self._param_dict[key][param_group_idx],
......
...@@ -275,6 +275,7 @@ class Optimizer: ...@@ -275,6 +275,7 @@ class Optimizer:
self._param_dict = self._create_multi_tensor_dict() self._param_dict = self._create_multi_tensor_dict()
self._auxiliary_vars = {} self._auxiliary_vars = {}
self._already_create_accumulater = set()
def _set_auxiliary_var(self, key, val): def _set_auxiliary_var(self, key, val):
self._auxiliary_vars[key] = val self._auxiliary_vars[key] = val
...@@ -979,31 +980,38 @@ class Optimizer: ...@@ -979,31 +980,38 @@ class Optimizer:
self._create_accumulators(target_block, params_acc_dict) self._create_accumulators(target_block, params_acc_dict)
if framework._non_static_mode(): if framework._non_static_mode():
if isinstance(parameters_and_grads, list): found_inf = self._get_auxiliary_var('found_inf')
for param_and_grad in parameters_and_grads: if found_inf:
if param_and_grad[1] is None: if isinstance(found_inf, core.eager.Tensor):
continue self._set_auxiliary_var('found_inf', True)
if param_and_grad[0].stop_gradient is False:
self._append_optimize_op(
target_block, param_and_grad
)
else: else:
for param_and_grad in parameters_and_grads['params']: if isinstance(found_inf, core.eager.Tensor):
if param_and_grad[1] is None: self._set_auxiliary_var('found_inf', False)
continue if isinstance(parameters_and_grads, list):
if param_and_grad[0].stop_gradient is False: for param_and_grad in parameters_and_grads:
param_grad_dict = dict() if param_and_grad[1] is None:
param_grad_dict['params'] = param_and_grad continue
param_grad_dict.update( if param_and_grad[0].stop_gradient is False:
{ self._append_optimize_op(
k: v target_block, param_and_grad
for k, v in parameters_and_grads.items() )
if k != 'params' else:
} for param_and_grad in parameters_and_grads['params']:
) if param_and_grad[1] is None:
self._append_optimize_op( continue
target_block, param_grad_dict if param_and_grad[0].stop_gradient is False:
) param_grad_dict = dict()
param_grad_dict['params'] = param_and_grad
param_grad_dict.update(
{
k: v
for k, v in parameters_and_grads.items()
if k != 'params'
}
)
self._append_optimize_op(
target_block, param_grad_dict
)
else: else:
for param_and_grad in parameters_and_grads: for param_and_grad in parameters_and_grads:
if param_and_grad[1] is None: if param_and_grad[1] is None:
......
...@@ -203,11 +203,15 @@ class RMSProp(Optimizer): ...@@ -203,11 +203,15 @@ class RMSProp(Optimizer):
parameters = parameters.get('params') parameters = parameters.get('params')
for p in parameters: for p in parameters:
if p.name in self._already_create_accumulater:
continue
if self._multi_precision and self._is_dtype_fp16_or_bf16(p.dtype): if self._multi_precision and self._is_dtype_fp16_or_bf16(p.dtype):
master_p = self._create_master_weight(p) master_p = self._create_master_weight(p)
self._add_accumulator(self._momentum_acc_str, master_p) self._add_accumulator(self._momentum_acc_str, master_p)
self._add_accumulator(self._mean_square_acc_str, master_p) self._add_accumulator(self._mean_square_acc_str, master_p)
self._add_accumulator(self._mean_grad_acc_str, master_p) self._add_accumulator(self._mean_grad_acc_str, master_p)
self._already_create_accumulater.add(p.name)
continue continue
if ( if (
self._is_dtype_fp16_or_bf16(p.dtype) self._is_dtype_fp16_or_bf16(p.dtype)
...@@ -220,6 +224,7 @@ class RMSProp(Optimizer): ...@@ -220,6 +224,7 @@ class RMSProp(Optimizer):
self._add_accumulator(self._momentum_acc_str, p) self._add_accumulator(self._momentum_acc_str, p)
self._add_accumulator(self._mean_square_acc_str, p) self._add_accumulator(self._mean_square_acc_str, p)
self._add_accumulator(self._mean_grad_acc_str, p) self._add_accumulator(self._mean_grad_acc_str, p)
self._already_create_accumulater.add(p.name)
def _append_optimize_op(self, block, param_and_grad): def _append_optimize_op(self, block, param_and_grad):
if not isinstance(block, framework.Block): if not isinstance(block, framework.Block):
......
...@@ -99,8 +99,11 @@ class SGD(Optimizer): ...@@ -99,8 +99,11 @@ class SGD(Optimizer):
# Create accumulator tensors for first and second moments # Create accumulator tensors for first and second moments
for p in parameters: for p in parameters:
if p.name in self._already_create_accumulater:
continue
if self._multi_precision and self._is_dtype_fp16_or_bf16(p.dtype): if self._multi_precision and self._is_dtype_fp16_or_bf16(p.dtype):
master_p = self._create_master_weight(p) master_p = self._create_master_weight(p)
self._already_create_accumulater.add(p.name)
continue continue
if ( if (
self._is_dtype_fp16_or_bf16(p.dtype) self._is_dtype_fp16_or_bf16(p.dtype)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册