Unverified commit 6959eae5, authored by Yiqun Liu and committed by GitHub

Unify the static amp codes of fp16 and bf16. Reimplement #52694 in release/2.4. (#52697)

Parent: d1e8b1e2
@@ -15,7 +15,7 @@
from __future__ import print_function

from . import decorator
from .decorator import decorate, amp_decorate
from . import fp16_lists
from .fp16_lists import *
from . import fp16_utils
...
@@ -27,8 +27,8 @@ def check_finite_and_unscale(x, scale, name=None, float_status=None):
    $$Out = X / scale$$

    If any tensor in X contains Inf or Nan, the Out will generate an indicator.
    FoundInfinite will be 1 (True), and Out will not be scaled. In this case, the data of
    Out should not be used, and its data may not be deterministic.
    Otherwise, FoundInfinite will be 0 (False).

    Args:
@@ -38,75 +38,98 @@ def check_finite_and_unscale(x, scale, name=None, float_status=None):
    """
    check_type(x, 'x', (tuple, list), 'check_finite_and_unscale')
    for e in x:
        check_variable_and_dtype(
            e,
            "x",
            ['float16', 'float32', 'float64', 'uint16'],
            'check_finite_and_unscale',
        )

    helper = LayerHelper("check_finite_and_unscale", **locals())
    found_inf = helper.create_variable_for_type_inference(dtype='bool')

    inputs = {'X': x, 'Scale': scale}
    if core.is_compiled_with_npu():
        check_variable_and_dtype(
            float_status,
            "float_status",
            ['float16', 'float32'],
            'check_finite_and_unscale',
        )
        inputs['FloatStatus'] = float_status
    outputs = {'Out': x, 'FoundInfinite': found_inf}
    helper.append_op(
        type='check_finite_and_unscale', inputs=inputs, outputs=outputs
    )

    return x, found_inf
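For reference, the unscale-and-check behaviour documented above can be sketched in plain NumPy as below; check_finite_and_unscale_ref is a hypothetical stand-in for the operator and is not part of the Paddle API.

import numpy as np

def check_finite_and_unscale_ref(grads, scale):
    # Divide every gradient by the loss-scaling factor and report whether any
    # element is Inf/NaN; if so, the unscaled values should be discarded.
    found_inf = any(not np.all(np.isfinite(g)) for g in grads)
    return [g / scale for g in grads], found_inf

grads = [np.array([1.0, 2.0], dtype=np.float32), np.array([np.inf], dtype=np.float32)]
_, found_inf = check_finite_and_unscale_ref(grads, np.float32(1024.0))
print(found_inf)  # True: this optimizer step should be skipped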
def update_loss_scaling(
    x,
    found_inf,
    prev_loss_scaling,
    num_good_steps,
    num_bad_steps,
    incr_every_n_steps,
    decr_every_n_nan_or_inf,
    incr_ratio,
    decr_ratio,
    stop_update=False,
    name=None,
):
    """
    Update loss scaling according to overall gradients. If all gradients are
    finite after incr_every_n_steps, loss scaling will increase by incr_ratio.
    Otherwise, loss scaling will decrease by decr_ratio after
    decr_every_n_nan_or_inf accumulated steps at which some gradients are infinite.

    Args:
        x(list|tuple): The input tensors of update_loss_scaling operator.
        found_inf (Variable): A boolean variable indicates whether
                              there is any infinite gradient.
        prev_loss_scaling (Variable): Previous loss scaling.
        num_good_steps (Variable): A variable accumulates good steps in which
                                   all gradients are finite.
        num_bad_steps (Variable): A variable accumulates bad steps in which
                                  some gradients are infinite.
        incr_every_n_steps (int): A variable represents increasing loss
                                  scaling every n consecutive steps with
                                  finite gradients.
        decr_every_n_nan_or_inf (int): A variable represents decreasing
                                       loss scaling every n accumulated
                                       steps with nan or inf gradients.
        incr_ratio(float): The multiplier to use when increasing the loss
                           scaling.
        decr_ratio(float): The less-than-one-multiplier to use when decreasing
                           loss scaling.
    """
    check_variable_and_dtype(
        prev_loss_scaling,
        "prev_loss_scaling",
        ['float32', 'float64'],
        "update_loss_scaling",
    )
    check_type(x, 'x', (tuple, list), 'update_loss_scaling')
    for e in x:
        check_variable_and_dtype(
            e,
            "x",
            ['float16', 'float32', 'float64', 'uint16'],
            'update_loss_scaling',
        )
        if (
            e.dtype == core.VarDesc.VarType.FP16
            or e.dtype == core.VarDesc.VarType.BF16
        ):
            assert (
                prev_loss_scaling.dtype == core.VarDesc.VarType.FP32
            ), "The dtype of prev_loss_scaling should be float32 when the dtype of x is float16."
        else:
            assert (
                prev_loss_scaling.dtype == e.dtype
            ), "The dtype of prev_loss_scaling should be equal to the dtype of x."

    helper = LayerHelper("update_loss_scaling", **locals())
@@ -115,14 +138,14 @@ def update_loss_scaling(x,
        'FoundInfinite': found_inf,
        'PrevLossScaling': prev_loss_scaling,
        'InGoodSteps': num_good_steps,
        'InBadSteps': num_bad_steps,
    }

    outputs = {
        'Out': x,
        'LossScaling': prev_loss_scaling,
        'OutGoodSteps': num_good_steps,
        'OutBadSteps': num_bad_steps,
    }

    attrs = {
@@ -137,9 +160,8 @@ def update_loss_scaling(x,
    else:
        attrs['stop_update'] = stop_update

    helper.append_op(
        type='update_loss_scaling', inputs=inputs, outputs=outputs, attrs=attrs
    )

    return x
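The loss-scaling policy documented above can be summarized by the following plain-Python sketch; update_loss_scaling_ref is a simplified, hypothetical reference, not the in-graph operator, and it omits device, dtype, and boundary handling.

def update_loss_scaling_ref(
    found_inf,
    loss_scaling,
    num_good_steps,
    num_bad_steps,
    incr_every_n_steps=1000,
    decr_every_n_nan_or_inf=2,
    incr_ratio=2.0,
    decr_ratio=0.8,
):
    if found_inf:
        # Bad step: reset the good-step counter, shrink the scale after
        # decr_every_n_nan_or_inf accumulated bad steps.
        num_good_steps, num_bad_steps = 0, num_bad_steps + 1
        if num_bad_steps >= decr_every_n_nan_or_inf:
            loss_scaling *= decr_ratio
            num_bad_steps = 0
    else:
        # Good step: reset the bad-step counter, grow the scale after
        # incr_every_n_steps consecutive finite steps.
        num_good_steps, num_bad_steps = num_good_steps + 1, 0
        if num_good_steps >= incr_every_n_steps:
            loss_scaling *= incr_ratio
            num_good_steps = 0
    return loss_scaling, num_good_steps, num_bad_steps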
@@ -36,11 +36,11 @@ __all__ = ["decorate"]
class OptimizerWithMixedPrecision(object):
    """
    Optimizer with mixed-precision (MP) training. This is a wrapper of a common
    optimizer, plus the support of mixed-precision pre-training. The object
    of this class almost has the same behavior as the common optimizer, with the
    methods `minimize()`, `backward()`, `apply_gradients()` implemented.
    Additionally, it enables the MP training automatically, i.e., the creation
    and maintenance of master parameters, scaling of loss, etc.

    Args:
@@ -48,14 +48,14 @@ class OptimizerWithMixedPrecision(object):
        amp_lists (CustomOpLists): A CustomOpLists object.
        init_loss_scaling (float): The initial loss scaling factor.
        use_dynamic_loss_scaling (bool): Whether to use dynamic loss scaling.
        incr_every_n_steps(int): Increases loss scaling every n consecutive
                                 steps with finite gradients.
        decr_every_n_nan_or_inf(int): Decreases loss scaling every n
                                      accumulated steps with nan or
                                      inf gradients.
        incr_ratio(float): The multiplier to use when increasing the loss
                           scaling.
        decr_ratio(float): The less-than-one-multiplier to use when decreasing
                           the loss scaling.
        use_pure_fp16(bool): Whether to use the pure fp16 training. Default False.
        use_fp16_guard(bool): Whether to use `fp16_guard` when constructing the program.
@@ -63,10 +63,20 @@ class OptimizerWithMixedPrecision(object):
    """

    def __init__(
        self,
        optimizer,
        amp_lists,
        init_loss_scaling,
        use_dynamic_loss_scaling,
        incr_every_n_steps,
        decr_every_n_nan_or_inf,
        incr_ratio,
        decr_ratio,
        use_pure_fp16,
        use_fp16_guard,
        use_bf16=False,
    ):
        self._optimizer = optimizer
        self._amp_lists = amp_lists
        self._param_grads = None
@@ -77,11 +87,23 @@ class OptimizerWithMixedPrecision(object):
        self._loss_scaling = None
        self._init_loss_scaling = init_loss_scaling
        self._use_dynamic_loss_scaling = use_dynamic_loss_scaling
        if use_bf16:
            if use_dynamic_loss_scaling:
                self._use_dynamic_loss_scaling = False
                self._init_loss_scaling = 1.0
                warnings.warn(
                    "Dynamic loss scaling for bfloat16 amp training is disabled, and the init_loss_scaling is changed to 1.0 automatically by PaddlePaddle."
                )
            self._amp_dtype = core.VarDesc.VarType.BF16
        else:
            self._amp_dtype = core.VarDesc.VarType.FP16
        self._learning_rate = optimizer._learning_rate
        self._learning_rate_map = optimizer._learning_rate_map
        self._use_pure_fp16 = use_pure_fp16
        self._use_fp16_guard = use_fp16_guard
        self._to_fp16_var_names = None
        self._use_bf16 = use_bf16
        if self._use_dynamic_loss_scaling:
            self._incr_every_n_steps = incr_every_n_steps
            self._decr_every_n_nan_or_inf = decr_every_n_nan_or_inf
@@ -97,9 +119,10 @@ class OptimizerWithMixedPrecision(object):
        self._is_distributed = flag

    def get_loss_scaling(self):
        """Return the real-time loss scaling factor."""
        assert (
            self._loss_scaling is not None
        ), 'Please call minimize() before calling get_loss_scaling().'
        return self._loss_scaling
    def get_scaled_loss(self):
@@ -117,7 +140,8 @@ class OptimizerWithMixedPrecision(object):
            shape=[1],
            value=self._init_loss_scaling,
            dtype='float32',
            persistable=True,
        )

        if self._use_dynamic_loss_scaling:
            self._num_good_steps = layers.create_global_var(
@@ -125,37 +149,43 @@ class OptimizerWithMixedPrecision(object):
                shape=[1],
                value=0,
                dtype='int32',
                persistable=True,
            )
            self._num_bad_steps = layers.create_global_var(
                name=unique_name.generate("num_bad_steps"),
                shape=[1],
                value=0,
                dtype='int32',
                persistable=True,
            )

        # Ensure the data type of learning rate vars is float32 (same as the
        # master parameter dtype)
        if isinstance(self._optimizer._learning_rate, float):
            self._optimizer._learning_rate_map[
                default_main_program()
            ] = layers.create_global_var(
                name=unique_name.generate("learning_rate"),
                shape=[1],
                value=float(self._optimizer._learning_rate),
                dtype='float32',
                persistable=True,
            )

    def backward(
        self,
        loss,
        startup_program=None,
        parameter_list=None,
        no_grad_set=None,
        callbacks=None,
    ):
        """
        Backward propagation or auto differentiation for gradients' computation.

        Args:
            loss (Variable): The loss Variable to minimize.
            startup_program (Program|None): The startup Program for initializing
                                            parameters in `parameter_list`.
            parameter_list (list|None): A list of Variables to update.
            no_grad_set (set|None): A set of Variables that should be ignored.
@@ -163,7 +193,7 @@ class OptimizerWithMixedPrecision(object):
                backward operator for one parameter.

        Returns:
            A list of (param, grad), which is a tuple of a parameter and its
            gradient respectively, and the scaled loss.
        """
        train_program = loss.block.program
@@ -171,9 +201,9 @@ class OptimizerWithMixedPrecision(object):
        # NOTE(zhiqiu): _float_status is only used for NPU.
        if core.is_compiled_with_npu():
            float_status = paddle.static.data(
                name="float_status", shape=[8], dtype='float32'
            )
            self._train_program.global_block().append_op(
                type="alloc_float_status",
                outputs={"FloatStatus": float_status},
@@ -192,9 +222,15 @@ class OptimizerWithMixedPrecision(object):
            if self._use_pure_fp16:
                self._to_fp16_var_names = cast_model_to_fp16(
                    self._train_program,
                    self._amp_lists,
                    self._use_fp16_guard,
                    self._amp_dtype,
                )
            else:
                rewrite_program(
                    self._train_program, self._amp_lists, self._amp_dtype
                )

            if loss.dtype != core.VarDesc.VarType.FP32:
                loss = loss.astype('float32')
@@ -205,10 +241,13 @@ class OptimizerWithMixedPrecision(object):
            else:
                self._scaled_loss = loss

            params_grads = self._optimizer.backward(
                self._scaled_loss,
                startup_program,
                parameter_list,
                no_grad_set,
                callbacks,
            )
        if self._supports_check_nan_inf():
            self._add_cast_ops_to_startup_program(startup_program)
        return params_grads
@@ -216,8 +255,11 @@ class OptimizerWithMixedPrecision(object):
    def _add_cast_ops_to_startup_program(self, startup_program):
        names = list(self._to_fp16_var_names) if self._to_fp16_var_names else []
        names.sort()
        startup_program = (
            default_startup_program()
            if startup_program is None
            else startup_program
        )
        block = startup_program.global_block()
        param_names = [p.name for p in block.all_parameters()]
        for name in names:
@@ -225,28 +267,28 @@ class OptimizerWithMixedPrecision(object):
                continue

            tmp = block.create_var(dtype=core.VarDesc.VarType.FP32)
            block.append_op(
                type='assign', inputs={'X': [name]}, outputs={'Out': [tmp]}
            )
            block.append_op(
                type='cast',
                inputs={'X': [tmp]},
                outputs={'Out': [name]},
                attrs={
                    'in_dtype': core.VarDesc.VarType.FP32,
                    'out_dtype': self._amp_dtype,
                },
            )
        self._to_fp16_var_names = None

    def amp_init(
        self, place, scope=None, test_program=None, use_fp16_test=False
    ):
        """
        Init the amp training, such as cast fp32 parameters to fp16 type.

        Args:
            place(CUDAPlace): place is used to initialize
                              fp16 parameters with fp32 values.
            scope(Scope): The scope is used to find fp32 parameters.
            test_program(Program): The program is used for testing.
@@ -273,7 +315,7 @@ class OptimizerWithMixedPrecision(object):
                    loss = paddle.mean(hidden)

                    # 2) Create the optimizer and set `multi_precision` to True.
                    # Setting `multi_precision` to True can avoid the poor accuracy
                    # or the slow convergence in a way.
                    optimizer = paddle.optimizer.Momentum(learning_rate=0.01, multi_precision=True)

                    # 3) These ops in `custom_black_list` will keep in the float32 computation type.
                    amp_list = paddle.static.amp.CustomOpLists(
@@ -293,30 +335,40 @@ class OptimizerWithMixedPrecision(object):
                    # 5) Use `amp_init` after FP32 parameters initialization(such as `exe.run(startup_program)`).
                    # If you want to perform the testing process, you should pass `test_program` into `amp_init`.
                    optimizer.amp_init(place, scope=paddle.static.global_scope())

                if paddle.is_compiled_with_cuda() and len(paddle.static.cuda_places()) > 0:
                    run_example_code()
        """
        assert (
            self._train_program is not None
        ), "Please call the minimize method first."

        if self._use_pure_fp16:
            cast_parameters_to_fp16(
                place,
                self._train_program,
                scope,
                self._to_fp16_var_names,
                self._amp_dtype,
            )

        if test_program is not None:
            if self._use_pure_fp16:
                cast_model_to_fp16(
                    test_program,
                    self._amp_lists,
                    self._use_fp16_guard,
                    self._amp_dtype,
                )
            elif use_fp16_test:
                rewrite_program(test_program, self._amp_lists, self._amp_dtype)
    def apply_gradients(self, params_grads):
        """
        Check scaled gradients to determine whether to update loss scaling and update
        parameters by their scaled gradients.

        Args:
            params_grads (list): A list of params and scaled grads.

        Returns:
            A list of optimize operators.
        """
@@ -327,7 +379,10 @@ class OptimizerWithMixedPrecision(object):
        # When not using dynamic loss scaling and the init loss scaling value is equal to 1.0,
        # the model can be optimized.
        if (
            not self._use_dynamic_loss_scaling
            and self._init_loss_scaling == 1.0
        ):
            return self._optimizer.apply_gradients(params_grads)

        if self._supports_check_nan_inf():
@@ -338,7 +393,10 @@ class OptimizerWithMixedPrecision(object):
            return optimize_ops

        found_inf = self._check_finite_and_unscale(params_grads)
        if (
            self._use_dynamic_loss_scaling
            and self._amp_dtype == core.VarDesc.VarType.FP16
        ):
            self._add_dynamic_loss_scaling(params_grads, found_inf)

        # Pass found_inf to adam, to skip update for not only param, but also momentum and beta_pow
@@ -346,13 +404,16 @@ class OptimizerWithMixedPrecision(object):
        real_optimizer = self._optimizer
        while hasattr(real_optimizer, "inner_opt"):
            real_optimizer = real_optimizer.inner_opt
        if isinstance(
            real_optimizer,
            (paddle.fluid.optimizer.Adam, paddle.optimizer.AdamW),
        ):
            # NOTE(zhiqiu): Since found_inf needs to be on cpu in adam op, we
            # copy it in advance to avoid multiple time copies.
            with self._train_program._optimized_guard([]):
                found_inf = paddle.tensor.creation._memcpy(
                    found_inf, paddle.CPUPlace()
                )
            real_optimizer._set_auxiliary_var('found_inf', found_inf)
        elif hasattr(real_optimizer, "_set_auxiliary_var"):
            real_optimizer._set_auxiliary_var('found_inf', found_inf)
@@ -362,9 +423,10 @@ class OptimizerWithMixedPrecision(object):
    def _split_grads(self, params_grads):
        grads = [g for _, g in params_grads]
        fp32_grads = [g for g in grads if g.dtype == core.VarDesc.VarType.FP32]
        fp16_grads = [g for g in grads if g.dtype == self._amp_dtype]
        assert len(fp32_grads) + len(fp16_grads) == len(
            grads
        ), "Data types of all grads must be either fp16/bf16 or fp32."
        return grads, fp32_grads, fp16_grads
    def _check_finite_and_unscale(self, params_grads):
@@ -380,7 +442,8 @@
                        grads,
                        self._loss_scaling,
                        name="find_infinite_scale",
                        float_status=self._float_status,
                    )
                found_infs.append(found_inf)
            else:
                for p, g in params_grads:
@@ -391,7 +454,8 @@
                            ],
                            self._loss_scaling,
                            name="find_infinite_scale",
                            float_status=self._float_status,
                        )
                        found_infs.append(found_inf)
        elif self._use_pure_fp16:
            if fp32_grads:
@@ -400,7 +464,8 @@
                        fp32_grads,
                        self._loss_scaling,
                        name="find_infinite_scale_fp32",
                        float_status=self._float_status,
                    )
                found_infs.append(fp32_found_inf)
            if fp16_grads:
                with self._train_program._optimized_guard(fp16_grads):
@@ -408,7 +473,8 @@
                        fp16_grads,
                        self._loss_scaling,
                        name="find_infinite_scale_fp16",
                        float_status=self._float_status,
                    )
                found_infs.append(fp16_found_inf)
        else:
            with self._train_program._optimized_guard(grads):
@@ -416,7 +482,8 @@
                    grads,
                    self._loss_scaling,
                    name="find_infinite_scale",
                    float_status=self._float_status,
                )

        if self._is_distributed or self._use_pure_fp16:
            with self._train_program._optimized_guard([]):
@@ -439,7 +506,8 @@
                    self._incr_ratio,
                    self._decr_ratio,
                    stop_update=self._optimizer._get_stop_update_var(),
                    name="update_loss_scaling",
                )
            return

        grads, fp32_grads, fp16_grads = self._split_grads(params_grads)
@@ -447,42 +515,48 @@
            stop_update = False
            with self._train_program._optimized_guard([]):
                if fp32_grads:
                    update_loss_scaling(
                        fp32_grads,
                        found_inf,
                        self._loss_scaling,
                        self._num_good_steps,
                        self._num_bad_steps,
                        self._incr_every_n_steps,
                        self._decr_every_n_nan_or_inf,
                        self._incr_ratio,
                        self._decr_ratio,
                        stop_update=stop_update,
                        name="update_loss_scaling_fp32",
                    )
                    stop_update = True
                if fp16_grads:
                    update_loss_scaling(
                        fp16_grads,
                        found_inf,
                        self._loss_scaling,
                        self._num_good_steps,
                        self._num_bad_steps,
                        self._incr_every_n_steps,
                        self._decr_every_n_nan_or_inf,
                        self._incr_ratio,
                        self._decr_ratio,
                        stop_update=stop_update,
                        name="update_loss_scaling_fp16",
                    )
        else:
            with self._train_program._optimized_guard([]):
                update_loss_scaling(
                    grads,
                    found_inf,
                    self._loss_scaling,
                    self._num_good_steps,
                    self._num_bad_steps,
                    self._incr_every_n_steps,
                    self._decr_every_n_nan_or_inf,
                    self._incr_ratio,
                    self._decr_ratio,
                    name="update_loss_scaling",
                )
    def apply_optimize(self, loss, startup_program, params_grads):
        program = loss.block.program
@@ -490,11 +564,9 @@
            optimize_ops = self.apply_gradients(params_grads)
        return optimize_ops

    def minimize(
        self, loss, startup_program=None, parameter_list=None, no_grad_set=None
    ):
        """
        Perform optimization by minimizing the given loss.
@@ -511,48 +583,55 @@
        """
        opt_dict = self._optimizer.__class__.__dict__
        if 'minimize' in opt_dict and isinstance(
            opt_dict['minimize'], types.FunctionType
        ):
            warnings.warn(
                "The decorated optimizer has its own `minimize` method, but it will not be executed."
            )

        scaled_params_grads = self.backward(
            loss,
            startup_program=startup_program,
            parameter_list=parameter_list,
            no_grad_set=no_grad_set,
        )

        optimize_ops = self.apply_optimize(
            loss, startup_program, scaled_params_grads
        )

        return optimize_ops, scaled_params_grads
def decorate(
    optimizer,
    amp_lists=None,
    init_loss_scaling=2**15,
    incr_every_n_steps=1000,
    decr_every_n_nan_or_inf=2,
    incr_ratio=2.0,
    decr_ratio=0.8,
    use_dynamic_loss_scaling=True,
    use_pure_fp16=False,
    use_fp16_guard=None,
    use_bf16=False,
):
    """
    Decorate the given optimizer to adapt to the mixed-precision training.

    Args:
        optimizer(Optimizer): A common Optimizer.
        amp_lists (CustomOpLists): A CustomOpLists object.
        init_loss_scaling(float): The initial loss scaling factor.
        incr_every_n_steps(int): Increases loss scaling every n consecutive
                                 steps with finite gradients.
        decr_every_n_nan_or_inf(int): Decreases loss scaling every n
                                      accumulated steps with nan or
                                      inf gradients.
        incr_ratio(float): The multiplier to use when increasing the loss
                           scaling.
        decr_ratio(float): The less-than-one-multiplier to use when decreasing
                           the loss scaling.
        use_dynamic_loss_scaling(bool): Whether to use dynamic loss scaling.
        use_pure_fp16(bool): Whether to use the pure fp16 training. Default False.
@@ -560,11 +639,11 @@ def decorate(optimizer,
                             Default None, which means that its value equals to `use_pure_fp16`.

    Returns:
        An optimizer acting like a normal one but with mixed-precision training
        enabled.

    Examples 1:
        .. code-block:: python

            # black&white list based strategy example
            import paddle
@@ -604,7 +683,7 @@ def decorate(optimizer,
                loss = paddle.mean(hidden)

                # 2) Create the optimizer and set `multi_precision` to True.
                # Setting `multi_precision` to True can avoid the poor accuracy
                # or the slow convergence in a way.
                optimizer = paddle.optimizer.Momentum(learning_rate=0.01, multi_precision=True)

                # 3) These ops in `custom_black_list` will keep in the float32 computation type.
                amp_list = paddle.static.amp.CustomOpLists(
@@ -624,19 +703,82 @@ def decorate(optimizer,
                # 5) Use `amp_init` after FP32 parameters initialization(such as `exe.run(startup_program)`).
                # If you want to perform the testing process, you should pass `test_program` into `amp_init`.
                optimizer.amp_init(place, scope=paddle.static.global_scope())

            if paddle.is_compiled_with_cuda() and len(paddle.static.cuda_places()) > 0:
                run_example_code()
    """
    dtype = "bfloat16" if use_bf16 else "float16"
    if amp_lists is None:
        amp_lists = AutoMixedPrecisionLists(dtype=dtype)

    if use_fp16_guard is None:
        use_fp16_guard = use_pure_fp16

    mp_optimizer = OptimizerWithMixedPrecision(
        optimizer,
        amp_lists,
        init_loss_scaling,
        use_dynamic_loss_scaling,
        incr_every_n_steps,
        decr_every_n_nan_or_inf,
        incr_ratio,
        decr_ratio,
        use_pure_fp16,
        use_fp16_guard,
        use_bf16,
    )

    return mp_optimizer
def amp_decorate(
    optimizer,
    amp_lists=None,
    level='O1',
    dtype='float16',
    init_loss_scaling=2**15,
    incr_every_n_steps=1000,
    decr_every_n_nan_or_inf=2,
    incr_ratio=2.0,
    decr_ratio=0.8,
    use_dynamic_loss_scaling=True,
    use_amp_guard=False,
):
    """
    Decorate the given optimizer to adapt to the mixed-precision training.
    """
    # check amp_dtype: float16 or bfloat16
    dtype = dtype.lower()
    if not (dtype in ['float16', 'bfloat16']):
        raise ValueError(
            "If enable AMP, dtype should be 'float16' or 'bfloat16'."
        )

    if amp_lists is None:
        amp_lists = AutoMixedPrecisionLists(dtype=dtype)

    # check amp_level: O0-O2
    level = level.upper()
    if not (level in ['O0', 'O1', 'O2']):
        raise ValueError(
            "level should be O0, O1 or O2. O0 represents fp32 train mode, O1 represents AMP train mode, O2 represents pure fp16/bf16 train mode."
        )

    use_pure_fp16 = level == "O2"
    use_fp16_guard = use_amp_guard
    use_bf16 = dtype == "bfloat16"

    mp_optimizer = OptimizerWithMixedPrecision(
        optimizer,
        amp_lists,
        init_loss_scaling,
        use_dynamic_loss_scaling,
        incr_every_n_steps,
        decr_every_n_nan_or_inf,
        incr_ratio,
        decr_ratio,
        use_pure_fp16,
        use_fp16_guard,
        use_bf16,
    )

    return mp_optimizer
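A rough end-to-end usage sketch of the decorated optimizer in static mode follows. It assumes a build that already contains this change and that the module is exposed as paddle.static.amp, as in the docstring examples above; amp_decorate(optimizer, level='O2', dtype='bfloat16') would be the equivalent entry point added in this commit.

import paddle

paddle.enable_static()
main_prog, startup_prog = paddle.static.Program(), paddle.static.Program()
with paddle.static.program_guard(main_prog, startup_prog):
    x = paddle.static.data(name='x', shape=[None, 16], dtype='float32')
    hidden = paddle.static.nn.fc(x, size=8)
    loss = paddle.mean(hidden)

    optimizer = paddle.optimizer.Momentum(learning_rate=0.01, multi_precision=True)
    # With use_bf16=True, dynamic loss scaling is disabled and init_loss_scaling
    # is reset to 1.0 (see OptimizerWithMixedPrecision.__init__ above).
    mp_optimizer = paddle.static.amp.decorate(
        optimizer, use_pure_fp16=True, use_fp16_guard=False, use_bf16=True
    )
    mp_optimizer.minimize(loss)

place = paddle.CUDAPlace(0) if paddle.is_compiled_with_cuda() else paddle.CPUPlace()
exe = paddle.static.Executor(place)
exe.run(startup_prog)
# Cast the fp32 parameters only after they have been initialized.
mp_optimizer.amp_init(place, scope=paddle.static.global_scope())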
@@ -13,16 +13,47 @@
# limitations under the License.

import copy

from ... import core

__all__ = ["CustomOpLists", "AutoMixedPrecisionLists"]

# lookup_table fp16 is slower than fp32, though fp16 is supported.
_extra_unsupported_list = {
    'lookup_table',
    'lookup_table_v2',
    'scatter',
    'scatter_grad',
}


def _get_unsupported_list(dtype):
    if dtype == "float16":
        amp_dtype = core.VarDesc.VarType.FP16
    elif dtype == "bfloat16":
        amp_dtype = core.VarDesc.VarType.BF16
    else:
        raise ValueError(
            "If enable AMP, dtype should be 'float16' or 'bfloat16'."
        )

    # The set of ops that don't support fp16 calculation
    # lookup_table fp16 is slower than fp32, though fp16 is supported.
    _sys_unsupported_list = []
    # _sys_unsupported_bf16_list = []
    if core.is_compiled_with_xpu():
        _, _, _sys_unsupported_list = core.op_supported_infos('XPU', amp_dtype)
    elif core.is_compiled_with_npu():
        _, _, _sys_unsupported_list = core.op_supported_infos('NPU', amp_dtype)
    elif core.is_compiled_with_mlu():
        _, _, _sys_unsupported_list = core.op_supported_infos('MLU', amp_dtype)
    else:
        _, _, _sys_unsupported_list = core.op_supported_infos('GPU', amp_dtype)

    unsupported_list = _extra_unsupported_list | _sys_unsupported_list
    return unsupported_list
class AutoMixedPrecisionLists(object):
    """
    AutoMixedPrecisionLists is a class for black/white list. It can update
@@ -36,16 +67,20 @@ class AutoMixedPrecisionLists(object):
        custom_black_varnames (set): Users' custom black variables' names.
    """

    def __init__(
        self,
        custom_white_list=None,
        custom_black_list=None,
        custom_black_varnames=None,
        dtype="float16",
    ):
        self._custom_white_list = custom_white_list
        self._custom_black_list = custom_black_list
        self.amp_dtype = dtype
        self.white_list = copy.copy(white_list)
        self.black_list = copy.copy(black_list)
        self.gray_list = copy.copy(gray_list)
        self.unsupported_list = copy.copy(_get_unsupported_list(self.amp_dtype))
        self.black_varnames = copy.copy(custom_black_varnames)
        self._update_list()
@@ -56,8 +91,9 @@ class AutoMixedPrecisionLists(object):
        if self._custom_white_list and self._custom_black_list:
            for op_name in self._custom_white_list:
                if op_name in self._custom_black_list:
                    raise ValueError(
                        "Custom white list overlap " "custom black list"
                    )
        if self._custom_white_list:
            for op_name in self._custom_white_list:
                if op_name in self.black_list:
@@ -65,7 +101,7 @@ class AutoMixedPrecisionLists(object):
                elif op_name in self.gray_list:
                    self.gray_list.remove(op_name)
                self.white_list.add(op_name)
                if op_name in _extra_unsupported_list:
                    self.unsupported_list.remove(op_name)
        if self._custom_black_list:
            for op_name in self._custom_black_list:
@@ -170,22 +206,4 @@ gray_list = {
    'fused_multi_transformer',
}

CustomOpLists = AutoMixedPrecisionLists
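As a small, hedged illustration of the constructor above (assuming the class is exposed as paddle.static.amp.CustomOpLists, as in the docstring examples elsewhere in this commit), the same list class now serves both dtypes and queries the unsupported-op set for the requested dtype at construction time:

import paddle

# Ops in custom_black_list are forced to stay in float32.
bf16_list = paddle.static.amp.CustomOpLists(
    custom_black_list=['pool2d'], dtype="bfloat16"
)
print('pool2d' in bf16_list.black_list)               # True
print('lookup_table' in bf16_list.unsupported_list)   # True, from _extra_unsupported_list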
@@ -14,6 +14,7 @@
from __future__ import print_function

import paddle
from ... import core
from ... import framework
from ... import layers
@@ -27,13 +28,14 @@ import numpy as np

__all__ = ["fp16_guard", "cast_model_to_fp16", "cast_parameters_to_fp16"]

_logger = get_logger(
    __name__, logging.INFO, fmt='%(asctime)s-%(levelname)s: %(message)s'
)

_valid_types = [
    core.VarDesc.VarType.LOD_TENSOR,
    core.VarDesc.VarType.SELECTED_ROWS,
    core.VarDesc.VarType.LOD_TENSOR_ARRAY,
]

_fp16_guard_pattern = "__use_fp16__"
@@ -75,7 +77,9 @@ def _dtype_to_str(dtype):
    Args:
        dtype (VarType): Variable type.
    """
    if dtype in [core.VarDesc.VarType.FP16, core.VarDesc.VarType.BF16]:
        # TODO(Xreki): change the returned str to "bf16" for BF16 data type.
        # Currently too many codes use "cast_fp16" as key.
        return 'fp16'
    else:
        return 'fp32'
@@ -108,7 +112,12 @@ def _keep_fp32_input(op, in_name):
        return in_name not in {'X', 'FilterX', 'Z', 'FilterZ'}
    if op_type in ['fused_attention', 'fused_feedforward']:
        return in_name in {
            'LnScale',
            'LnBias',
            'Ln2Scale',
            'Ln2Bias',
            "Ln1Scale",
            "Ln1Bias",
        }
    if op_type == 'fused_multi_transformer':
        return in_name in {'LnScale', 'LnBias', 'FFNLnScale', 'FFNLnBias'}
@@ -125,8 +134,12 @@ def _keep_fp32_output(op, out_name):
        return out_name not in {'Y', 'ConvX', 'ConvZ'}
    if op_type in ['fused_attention', 'fused_feedforward']:
        return out_name in {
            'LnMean',
            'LnVariance',
            'Ln2Mean',
            'Ln2Variance',
            'Ln1Mean',
            'Ln1Variance',
        }
    return False
@@ -149,7 +162,8 @@ def _insert_cast_op(block, op, idx, src_dtype, dest_dtype):
    for in_name in op.input_names:
        if src_dtype == core.VarDesc.VarType.FP32 and _keep_fp32_input(
            op, in_name
        ):
            continue
        for in_var_name in op.input(in_name):
            in_var = block._find_var_recursive(in_var_name)
@@ -165,11 +179,15 @@ def _insert_cast_op(block, op, idx, src_dtype, dest_dtype):
                    # set cast_op device to `all`, can reduce send cast_var.
                    # TODO: need remove this after we unified the dynamic
                    # and static pipeline interface.
                    if (
                        src_dtype == core.VarDesc.VarType.FP32
                        and in_var.stop_gradient
                    ):
                        prev_op = None
                        if in_var.op is op:
                            prev_op = find_true_prev_op(
                                block.ops, op, in_var_name
                            )
                        elif in_var.op is not None:
                            prev_op = in_var.op
@@ -177,33 +195,40 @@ def _insert_cast_op(block, op, idx, src_dtype, dest_dtype):
                        if prev_op is not None:
                            prev_op_device = prev_op.attr('op_device')
                            if (
                                prev_op_device is not None
                                and 'all' in prev_op_device
                            ):
                                op_device = prev_op_device

                    out_var = block.create_var(
                        name=cast_name,
                        dtype=dest_dtype,
                        persistable=False,
                        stop_gradient=in_var.stop_gradient,
                    )

                    block._insert_op_without_sync(
                        idx,
                        type="cast",
                        inputs={"X": in_var},
                        outputs={"Out": out_var},
                        attrs={
                            "in_dtype": in_var.dtype,
                            "out_dtype": out_var.dtype,
                            "op_device": op_device,
                            "op_role": op.attr("op_role"),
                        },
                    )
                    num_cast_ops += 1
                _rename_arg(op, in_var.name, out_var.name)
            else:
                if op.has_attr('in_dtype'):
                    op._set_attr('in_dtype', dest_dtype)
    if src_dtype == core.VarDesc.VarType.FP32 and dest_dtype in [
        core.VarDesc.VarType.FP16,
        core.VarDesc.VarType.BF16,
    ]:
        for out_name in op.output_names:
            if _keep_fp32_output(op, out_name):
                continue
@@ -212,41 +237,48 @@ def _insert_cast_op(block, op, idx, src_dtype, dest_dtype):
                if out_var.type not in _valid_types:
                    continue
                if out_var.dtype == core.VarDesc.VarType.FP32:
                    out_var.desc.set_dtype(dest_dtype)
                    if op.has_attr('out_dtype'):
                        op._set_attr('out_dtype', dest_dtype)
    return num_cast_ops
def _insert_cast_post_op(
    block, op, idx, src_dtype, dest_dtype, target_name, op_var_rename_map
):
    num_cast_ops = 0

    target_var = block.var(target_name)
    if target_var.type not in _valid_types or target_var.dtype == dest_dtype:
        return num_cast_ops

    assert (
        target_var.dtype == src_dtype
    ), "The real dtype({}) is not equal to the src dtype({})".format(
        _dtype_to_str(target_var.dtype), _dtype_to_str(src_dtype)
    )

    cast_name = target_var.name + '.cast_' + _dtype_to_str(dest_dtype)
    cast_var = block.vars.get(cast_name)
    if cast_var is None or cast_var.dtype != dest_dtype:
        cast_var = block.create_var(
            name=cast_name,
            dtype=dest_dtype,
            persistable=False,
            stop_gradient=target_var.stop_gradient,
        )
        block._insert_op(
            idx,
            type="cast",
            inputs={"X": target_var},
            outputs={"Out": cast_var},
            attrs={
                "in_dtype": target_var.dtype,
                "out_dtype": cast_var.dtype,
                "op_device": op.attr("op_device"),
                "op_role": op.attr("op_role"),
            },
        )
        num_cast_ops += 1
        op_var_rename_map[block.idx][target_var.name] = cast_var.name
@@ -272,8 +304,10 @@ def find_true_prev_op(ops, cur_op, var_name):
                    prev_op.append(op)
    if prev_op:
        if not len(prev_op) == 1:
            raise ValueError(
                "There must be only one previous op "
                "that outputs {0} variable".format(var_name)
            )
        else:
            return prev_op[0]
    return None
@@ -315,8 +349,7 @@ def find_true_post_op(ops, cur_op, var_name, search_all=False):

def find_op_index(block_desc, cur_op_desc):
    """ """
    for idx in range(block_desc.op_size()):
        if cur_op_desc == block_desc.op(idx):
            return idx
@@ -350,8 +383,9 @@ def _need_keep_fp32(op, unsupported_op_list, use_fp16_guard):
        return True

    if use_fp16_guard:
        if op.has_attr("op_namescope") and (
            _fp16_guard_pattern in op.attr("op_namescope")
        ):
            # op in fp16 guard
            return False
        else:
@@ -388,7 +422,12 @@ def fp16_guard():
        yield
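A hedged sketch of how fp16_guard interacts with _need_keep_fp32 above (assuming the guard is exposed as paddle.static.amp.fp16_guard): when use_fp16_guard is True, only ops created inside the guard's "__use_fp16__" name scope may run in the low-precision dtype; everything else is kept in float32.

import paddle

paddle.enable_static()
main_prog = paddle.static.Program()
with paddle.static.program_guard(main_prog):
    x = paddle.static.data(name='x', shape=[None, 1, 28, 28], dtype='float32')
    with paddle.static.amp.fp16_guard():
        # Ops built here carry the "__use_fp16__" namescope attribute and may be cast.
        conv = paddle.static.nn.conv2d(x, num_filters=6, filter_size=3)
    # Ops built outside the guard stay in float32 when use_fp16_guard=True.
    pooled = paddle.nn.functional.max_pool2d(conv, kernel_size=2)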
def cast_model_to_fp16(program, amp_lists=None, use_fp16_guard=True): def cast_model_to_fp16(
program,
amp_lists=None,
use_fp16_guard=True,
dest_type=core.VarDesc.VarType.FP16,
):
""" """
Traverse all ops in the whole model and set their inputs and outputs Traverse all ops in the whole model and set their inputs and outputs
to the fp16 data type. This function will do some special process for to the fp16 data type. This function will do some special process for
...@@ -399,6 +438,7 @@ def cast_model_to_fp16(program, amp_lists=None, use_fp16_guard=True): ...@@ -399,6 +438,7 @@ def cast_model_to_fp16(program, amp_lists=None, use_fp16_guard=True):
amp_lists (AutoMixedPrecisionLists): An AutoMixedPrecisionLists object. amp_lists (AutoMixedPrecisionLists): An AutoMixedPrecisionLists object.
use_fp16_guard(bool): Determine whether to use `fp16_guard` when use_fp16_guard(bool): Determine whether to use `fp16_guard` when
constructing the program. Default True. constructing the program. Default True.
dest_type(core.VarDesc.VarType): the cast type. such as core.VarDesc.VarType.FP16 and core.VarDesc.VarType.BF16.
""" """
    if amp_lists is None:
...@@ -421,7 +461,8 @@ def cast_model_to_fp16(program, amp_lists=None, use_fp16_guard=True):
            for in_name in op.input_names:
                # for ipu, all inputs must be converted to fp16
                if not core.is_compiled_with_ipu() and _keep_fp32_input(
                    op, in_name
                ):
                    continue
                for in_var_name in op.input(in_name):
                    in_var = None
...@@ -429,29 +470,36 @@ def cast_model_to_fp16(program, amp_lists=None, use_fp16_guard=True):
                        in_var = block.var(in_var_name)
                    except ValueError as e:
                        _logger.debug(
                            "-- {}, try to get it in the global block --".format(
                                e
                            )
                        )
                        in_var = global_block.var(in_var_name)
                        if in_var is not None:
                            _logger.debug(
                                "-- var {} is got in the global block --".format(
                                    in_var_name
                                )
                            )
                    if in_var is None or in_var.type not in _valid_types:
                        continue
                    if in_var.dtype == core.VarDesc.VarType.FP32:
                        in_var.desc.set_dtype(dest_type)
                        to_fp16_var_names.add(in_var_name)
                    _logger.debug(
                        "-- op type: {}, in var name: {}, in var dtype: {} --".format(
                            op.type, in_var_name, in_var.dtype
                        )
                    )
            for out_name in op.output_names:
                # for ipu, all outputs must be converted to fp16
                if not core.is_compiled_with_ipu() and _keep_fp32_output(
                    op, out_name
                ):
                    continue
                for out_var_name in op.output(out_name):
                    out_var = None
...@@ -459,32 +507,35 @@ def cast_model_to_fp16(program, amp_lists=None, use_fp16_guard=True):
                        out_var = block.var(out_var_name)
                    except ValueError as e:
                        _logger.debug(
                            "-- {}, try to get it in the global block --".format(
                                e
                            )
                        )
                        out_var = global_block.var(out_var_name)
                        if out_var is not None:
                            _logger.debug(
                                "-- var {} is got in the global block --".format(
                                    out_var_name
                                )
                            )
                    if out_var is None or out_var.type not in _valid_types:
                        continue
                    if out_var.dtype == core.VarDesc.VarType.FP32:
                        out_var.desc.set_dtype(dest_type)
                    _logger.debug(
                        "-- op type: {}, out var name: {}, out var dtype: {} --".format(
                            op.type, out_var_name, out_var.dtype
                        )
                    )
            for attr_name in ['in_dtype', 'out_dtype', 'dtype']:
                if (
                    op.has_attr(attr_name)
                    and op.attr(attr_name) == core.VarDesc.VarType.FP32
                ):
                    op._set_attr(attr_name, dest_type)
    # process ops in keep_fp32_ops
    op_var_rename_map = [
...@@ -497,25 +548,29 @@ def cast_model_to_fp16(program, amp_lists=None, use_fp16_guard=True):
            op = ops[idx]
            num_cast_ops = 0
            if op in keep_fp32_ops:
                pre_cast_num = _insert_cast_op(
                    block, op, idx, dest_type, core.VarDesc.VarType.FP32
                )
                num_cast_ops += pre_cast_num
                for out_var_name in op.output_arg_names:
                    out_var = block.vars.get(out_var_name)
                    if out_var is None or out_var.type not in _valid_types:
                        continue
                    if out_var.dtype == dest_type:
                        out_var.desc.set_dtype(core.VarDesc.VarType.FP32)
                        post_ops = find_true_post_op(ops, op, out_var_name)
                        for post_op in post_ops:
                            if post_op in keep_fp32_ops:
                                continue
                            post_cast_num = _insert_cast_post_op(
                                block,
                                op,
                                idx + pre_cast_num + 1,
                                core.VarDesc.VarType.FP32,
                                dest_type,
                                out_var_name,
                                op_var_rename_map,
                            )
                            num_cast_ops += post_cast_num
            idx += num_cast_ops + 1
...@@ -523,7 +578,22 @@ def cast_model_to_fp16(program, amp_lists=None, use_fp16_guard=True):
    return to_fp16_var_names
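A minimal sketch of how the extended cast_model_to_fp16 above might be driven to retarget a static program to bf16. It is not part of this commit; the tiny fc network and the variable names are assumptions for illustration only.

# Hedged sketch: rewrite a hand-built fp32 static program to bf16.
import paddle
from paddle.fluid import core
from paddle.fluid.contrib.mixed_precision.fp16_utils import cast_model_to_fp16

paddle.enable_static()
main_program = paddle.static.Program()
startup_program = paddle.static.Program()
with paddle.static.program_guard(main_program, startup_program):
    x = paddle.static.data(name='x', shape=[None, 16], dtype='float32')
    out = paddle.static.nn.fc(x, size=8)  # stand-in for a real model

# The program is rewritten in place; the returned set records which
# variables were switched to the low-precision dtype.
to_bf16_var_names = cast_model_to_fp16(
    main_program,
    amp_lists=None,
    use_fp16_guard=False,
    dest_type=core.VarDesc.VarType.BF16,
)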
def _convert_float_to_bfloat16(place, fp32_array):
    paddle.disable_static()
    framework._set_expected_place(place)
    fp32_tensor = paddle.to_tensor(fp32_array)
    bf16_array = paddle.cast(fp32_tensor, paddle.bfloat16).numpy()
    paddle.enable_static()
    return bf16_array
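The helper above leans on dygraph for the float32-to-bfloat16 conversion because numpy has no native bfloat16 dtype. A small illustration, assuming a dygraph session and assuming that bf16 tensors convert to numpy as raw uint16 bit patterns:

# Hedged illustration of the round trip done by _convert_float_to_bfloat16.
import numpy as np
import paddle

data = np.random.rand(2, 3).astype('float32')
bf16_array = paddle.cast(paddle.to_tensor(data), paddle.bfloat16).numpy()
# Expected to be uint16 (the bf16 payload); this is an assumption to verify.
print(bf16_array.dtype, bf16_array.shape)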
def cast_parameters_to_fp16(
    place,
    program,
    scope=None,
    to_fp16_var_names=None,
    dest_type=core.VarDesc.VarType.FP16,
):
    """
    Traverse all parameters in the whole model and set them to the FP16 data type.
    Whereas, this function will keep parameters of batchnorms in FP32.
...@@ -535,6 +605,7 @@ def cast_parameters_to_fp16(place, program, scope=None, to_fp16_var_names=None):
        to_fp16_var_names(set|list, optional): The data types of vars in `to_fp16_var_names`
                                               will be set to FP16. Usually, it is the returned
                                               value of `cast_model_to_fp16` API.
        dest_type(core.VarDesc.VarType): the cast type. such as core.VarDesc.VarType.FP16 and core.VarDesc.VarType.BF16.
    """
    all_parameters = []
    for block in program.blocks:
...@@ -544,13 +615,22 @@ def cast_parameters_to_fp16(place, program, scope=None, to_fp16_var_names=None):
    var_scope = scope if scope else global_scope()
    for param in all_parameters:
        if param.name in fp16_var_names:
            _logger.debug(
                "---- cast {} to fp16/bf16 dtype ----".format(param.name)
            )
            if var_scope.find_var(param.name):
                param_t = var_scope.find_var(param.name).get_tensor()
                data = np.array(param_t)
                if dest_type == core.VarDesc.VarType.BF16:
                    bf16_data = _convert_float_to_bfloat16(place, data)
                    param_t.set(bf16_data, place)
                else:
                    param_t.set(np.float16(data), place)
            else:
                _logger.warning(f"Cannot find {param.name}")
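A hedged continuation of the earlier sketch: once the startup program has initialized the fp32 parameters, the variables recorded by cast_model_to_fp16 can be cast to bf16 storage. The names place, main_program, startup_program and to_bf16_var_names are carried over from that sketch and are assumptions.

# Hedged sketch: cast initialized parameters to bf16 in the executor scope.
import paddle
from paddle.fluid import core
from paddle.fluid.contrib.mixed_precision.fp16_utils import (
    cast_parameters_to_fp16,
)

place = (
    paddle.CUDAPlace(0) if paddle.is_compiled_with_cuda() else paddle.CPUPlace()
)
exe = paddle.static.Executor(place)
exe.run(startup_program)
cast_parameters_to_fp16(
    place,
    main_program,
    to_fp16_var_names=to_bf16_var_names,
    dest_type=core.VarDesc.VarType.BF16,
)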
def rewrite_program(main_prog, amp_lists, dest_type=core.VarDesc.VarType.FP16):
    """
    Traverse all ops in current block and insert cast op according to
    which set current op belongs to.
...@@ -569,6 +649,7 @@ def rewrite_program(main_prog, amp_lists):
    Args:
        main_prog (Program): The main program for training.
        dest_type(core.VarDesc.VarType): the cast type. such as core.VarDesc.VarType.FP16 and core.VarDesc.VarType.BF16.
    """
    block = main_prog.global_block()
    block._sync_with_cpp()
...@@ -585,7 +666,8 @@ def rewrite_program(main_prog, amp_lists):
            continue
        if amp_lists.black_varnames is not None and _is_in_black_varnames(
            op, amp_lists
        ):
            black_op_set.add(op)
            continue
...@@ -611,11 +693,15 @@ def rewrite_program(main_prog, amp_lists):
                    else:
                        prev_op = in_var.op
                    # if it's one of inputs
                    if (
                        prev_op in black_op_set
                        or prev_op.type in amp_lists.black_list
                    ):
                        is_black_op = True
                    elif (
                        prev_op in white_op_set
                        or prev_op.type in amp_lists.white_list
                    ):
                        is_white_op = True
            if is_black_op:
                black_op_set.add(op)
...@@ -633,13 +719,13 @@ def rewrite_program(main_prog, amp_lists):
        op = ops[idx]
        num_cast_ops = 0
        if op in black_op_set:
            num_cast_ops = _insert_cast_op(
                block, op, idx, dest_type, core.VarDesc.VarType.FP32
            )
        elif op in white_op_set:
            num_cast_ops = _insert_cast_op(
                block, op, idx, core.VarDesc.VarType.FP32, dest_type
            )
        else:
            pass
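For the O1-style path, rewrite_program only moves white-list ops to the low-precision dtype and inserts casts at black/white boundaries. A minimal sketch, assuming main_prog is an fp32 static Program built as in the earlier sketch and that the default op lists are acceptable:

# Hedged sketch of the O1 rewrite with a custom black list.
from paddle.fluid import core
from paddle.fluid.contrib.mixed_precision.decorator import (
    AutoMixedPrecisionLists,
)
from paddle.fluid.contrib.mixed_precision.fp16_utils import rewrite_program

amp_lists = AutoMixedPrecisionLists(custom_black_list={'softmax'})
rewrite_program(main_prog, amp_lists, dest_type=core.VarDesc.VarType.FP16)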
...@@ -670,13 +756,16 @@ def update_role_var_grad(main_prog, params_grads):
            if role & int(BACKWARD) and op.has_attr('op_role_var'):
                op._remove_attr("op_role_var")
            else:
                raise ValueError(
                    "The cast op {0} must be in BACKWARD role "
                    "and have op_role_var attr.".format(op)
                )

            fp16_grad_name = op.input(op.input_names[0])[0]
            op_for_fp16_grad = find_true_prev_op(block.ops, op, fp16_grad_name)
            op_role_var_attr_name = (
                core.op_proto_and_checker_maker.kOpRoleVarAttrName()
            )
            attr_val = [p.name, fp16_grad_name]
            if op_for_fp16_grad.has_attr(op_role_var_attr_name):
                attr_val.extend(op_for_fp16_grad.attr(op_role_var_attr_name))
...@@ -690,18 +779,22 @@ def update_role_var_grad(main_prog, params_grads):
                continue
            post_ops = find_true_post_op(block.ops, op, g.name)
            if post_ops:
                raise ValueError(
                    "The cast op {0}'s output should not be"
                    "used by a non-optimize op, however, it"
                    "is used by {1}".format(op, post_ops[0])
                )
            # add new op in the python and cpp at the same time
            new_op_desc = block.desc.append_op()
            new_op_desc.copy_from(op.desc)
            new_op = framework.Operator(
                block=block,
                desc=new_op_desc,
                type=None,
                inputs=None,
                outputs=None,
                attrs=None,
            )
            block.ops.append(new_op)
            op_idx = find_op_index(block.desc, op.desc)
            if op_idx == -1:
......
...@@ -18,19 +18,32 @@ import six

import paddle
from paddle.fluid import framework, backward, core, program_guard
from paddle.fluid.executor import (
    _is_enable_standalone_executor,
    _is_dy2st_enable_standalone_executor,
)
from paddle.fluid.dygraph import layers
from paddle.fluid.dygraph.base import switch_to_static_graph
from paddle.fluid.dygraph.dygraph_to_static import logging_utils
from paddle.fluid.dygraph.dygraph_to_static.return_transformer import (
    RETURN_NO_VALUE_MAGIC_NUM,
)
from paddle.fluid.layers.utils import flatten
from paddle.fluid.layers.utils import pack_sequence_as
from paddle.fluid.layers.utils import _hash_with_id
from paddle.fluid.compiler import BuildStrategy
from paddle.fluid.framework import _apply_pass
from paddle.fluid.contrib.mixed_precision.decorator import (
    AutoMixedPrecisionLists,
)
from paddle.fluid.contrib.mixed_precision.fp16_utils import (
    rewrite_program,
    cast_model_to_fp16,
)
from paddle.fluid.dygraph.amp.auto_cast import (
    _in_amp_guard,
    _in_pure_fp16_guard,
)
import paddle.compat as cpt
from paddle import _C_ops, _legacy_C_ops
...@@ -64,7 +77,8 @@ class NestSequence(object): ...@@ -64,7 +77,8 @@ class NestSequence(object):
var_ids = [] var_ids = []
for idx, var in enumerate(self.__input_list): for idx, var in enumerate(self.__input_list):
if isinstance( if isinstance(
var, (framework.Variable, core.VarBase, core.eager.Tensor)): var, (framework.Variable, core.VarBase, core.eager.Tensor)
):
var_ids.append(idx) var_ids.append(idx)
return var_ids return var_ids
...@@ -77,15 +91,17 @@ class NestSequence(object): ...@@ -77,15 +91,17 @@ class NestSequence(object):
warning_types = set() warning_types = set()
for var in self.__input_list: for var in self.__input_list:
if not isinstance( if not isinstance(
var, var, (framework.Variable, core.VarBase, core.eager.Tensor)
(framework.Variable, core.VarBase, core.eager.Tensor)): ):
warning_types.add(type(var)) warning_types.add(type(var))
if warning_types: if warning_types:
logging_utils.warn( logging_utils.warn(
"Output of traced function contains non-tensor type values: {}. " "Output of traced function contains non-tensor type values: {}. "
"Currently, We don't support to update them while training and will return " "Currently, We don't support to update them while training and will return "
"what we first saw. Please try to return them as tensor.". "what we first saw. Please try to return them as tensor.".format(
format(list(warning_types))) list(warning_types)
)
)
@property @property
def var_ids(self): def var_ids(self):
...@@ -139,12 +155,9 @@ class PartialProgramLayer:
        Layer: A Layer object that run all ops internally in static mode.
    """

    def __init__(
        self, main_program, inputs, outputs, parameters=None, **kwargs
    ):
        super(PartialProgramLayer, self).__init__()
        self._inputs = NestSequence(inputs)
        self._outputs = NestSequence(outputs, need_check=True)
...@@ -160,14 +173,18 @@ class PartialProgramLayer:
        # Set default mode to train
        self.training = True

        amp_dtype, custom_white_list, custom_black_list = None, None, None
        tracer = framework._dygraph_tracer()
        if tracer:
            custom_white_list, custom_black_list = tracer._get_amp_op_list()
            amp_dtype = tracer._amp_dtype
        if amp_dtype is not None and amp_dtype in ['float16', 'bfloat16']:
            # For AMP training
            self._amp_list = AutoMixedPrecisionLists(
                custom_white_list=custom_white_list,
                custom_black_list=custom_black_list,
                dtype=amp_dtype,
            )

        # program_id -> list(scope)
        self._scope_cache = {}
return self._origin_main_program.clone(for_test=is_infer_mode) return self._origin_main_program.clone(for_test=is_infer_mode)
else: else:
train_program = self._append_backward_desc( train_program = self._append_backward_desc(
self._origin_main_program) self._origin_main_program
)
# Note: Only set grad type once after initializing train program. So we put it here. # Note: Only set grad type once after initializing train program. So we put it here.
self._set_grad_type(self._params, train_program) self._set_grad_type(self._params, train_program)
return train_program return train_program
...@@ -223,16 +241,18 @@ class PartialProgramLayer: ...@@ -223,16 +241,18 @@ class PartialProgramLayer:
@switch_to_static_graph @switch_to_static_graph
def _create_pure_fp16_program(self, is_infer_mode=False): def _create_pure_fp16_program(self, is_infer_mode=False):
pure_fp16_program = self._origin_main_program.clone( pure_fp16_program = self._origin_main_program.clone(
for_test=is_infer_mode) for_test=is_infer_mode
)
with program_guard(pure_fp16_program): with program_guard(pure_fp16_program):
cast_model_to_fp16(pure_fp16_program, cast_model_to_fp16(
self._amp_list, pure_fp16_program, self._amp_list, use_fp16_guard=False
use_fp16_guard=False) )
if is_infer_mode: if is_infer_mode:
return pure_fp16_program return pure_fp16_program
else: else:
train_pure_fp16_program = self._append_backward_desc( train_pure_fp16_program = self._append_backward_desc(
pure_fp16_program) pure_fp16_program
)
self._set_grad_type(self._params, train_pure_fp16_program) self._set_grad_type(self._params, train_pure_fp16_program)
return train_pure_fp16_program return train_pure_fp16_program
...@@ -240,23 +260,27 @@ class PartialProgramLayer: ...@@ -240,23 +260,27 @@ class PartialProgramLayer:
def _create_forward_backward_train_program(self): def _create_forward_backward_train_program(self):
whole_program = self._create_program() whole_program = self._create_program()
forward_end_op_index = self._infer_program.desc.block(0).op_size() forward_end_op_index = self._infer_program.desc.block(0).op_size()
return self._get_forward_backward_program_form(whole_program, return self._get_forward_backward_program_form(
forward_end_op_index) whole_program, forward_end_op_index
)
@switch_to_static_graph @switch_to_static_graph
def _create_forward_backward_train_amp_program(self): def _create_forward_backward_train_amp_program(self):
whole_program = self._create_amp_program() whole_program = self._create_amp_program()
forward_end_op_index = self._infer_amp_program.desc.block(0).op_size() forward_end_op_index = self._infer_amp_program.desc.block(0).op_size()
return self._get_forward_backward_program_form(whole_program, return self._get_forward_backward_program_form(
forward_end_op_index) whole_program, forward_end_op_index
)
@switch_to_static_graph @switch_to_static_graph
def _create_forward_backward_train_pure_fp16_program(self): def _create_forward_backward_train_pure_fp16_program(self):
whole_program = self._create_pure_fp16_program() whole_program = self._create_pure_fp16_program()
forward_end_op_index = self._infer_pure_fp16_program.desc.block( forward_end_op_index = self._infer_pure_fp16_program.desc.block(
0).op_size() 0
return self._get_forward_backward_program_form(whole_program, ).op_size()
forward_end_op_index) return self._get_forward_backward_program_form(
whole_program, forward_end_op_index
)
@LazyInitialized @LazyInitialized
def _train_program(self): def _train_program(self):
...@@ -352,8 +376,9 @@ class PartialProgramLayer: ...@@ -352,8 +376,9 @@ class PartialProgramLayer:
@LazyInitialized @LazyInitialized
def _train_program_id(self): def _train_program_id(self):
program_id = _hash_with_id(self._train_program, self) program_id = _hash_with_id(self._train_program, self)
core._set_cached_executor_build_strategy(program_id, core._set_cached_executor_build_strategy(
self._build_strategy) program_id, self._build_strategy
)
return program_id return program_id
@LazyInitialized @LazyInitialized
...@@ -363,8 +388,9 @@ class PartialProgramLayer: ...@@ -363,8 +388,9 @@ class PartialProgramLayer:
@LazyInitialized @LazyInitialized
def _train_amp_program_id(self): def _train_amp_program_id(self):
program_id = _hash_with_id(self._train_amp_program, self) program_id = _hash_with_id(self._train_amp_program, self)
core._set_cached_executor_build_strategy(program_id, core._set_cached_executor_build_strategy(
self._build_strategy) program_id, self._build_strategy
)
return program_id return program_id
@LazyInitialized @LazyInitialized
...@@ -374,8 +400,9 @@ class PartialProgramLayer: ...@@ -374,8 +400,9 @@ class PartialProgramLayer:
@LazyInitialized @LazyInitialized
def _train_pure_fp16_program_id(self): def _train_pure_fp16_program_id(self):
program_id = _hash_with_id(self._train_pure_fp16_program, self) program_id = _hash_with_id(self._train_pure_fp16_program, self)
core._set_cached_executor_build_strategy(program_id, core._set_cached_executor_build_strategy(
self._build_strategy) program_id, self._build_strategy
)
return program_id return program_id
@LazyInitialized @LazyInitialized
...@@ -411,8 +438,9 @@ class PartialProgramLayer: ...@@ -411,8 +438,9 @@ class PartialProgramLayer:
return main_program return main_program
def prepare_gradient_aggregation(self, start_idx, main_program, def prepare_gradient_aggregation(
target_program): self, start_idx, main_program, target_program
):
""" """
Why we need add gradient aggregation operation ? Why we need add gradient aggregation operation ?
In some cases, if non leaf nodes are used as output, gradient overwriting will occur, such as In some cases, if non leaf nodes are used as output, gradient overwriting will occur, such as
...@@ -420,7 +448,7 @@ class PartialProgramLayer: ...@@ -420,7 +448,7 @@ class PartialProgramLayer:
x = 2 * in # <---- x is a non-leaf node in program. x = 2 * in # <---- x is a non-leaf node in program.
y = x + 3 y = x + 3
return x, y return x, y
loss = forward(in)[0].sum() loss = forward(in)[0].sum()
loss.backward() # <----- x@grad will be overwrited by elementwise_add_grad Op loss.backward() # <----- x@grad will be overwrited by elementwise_add_grad Op
""" """
...@@ -430,8 +458,8 @@ class PartialProgramLayer: ...@@ -430,8 +458,8 @@ class PartialProgramLayer:
if exist a op whose inputs is var, then return True if exist a op whose inputs is var, then return True
""" """
if not isinstance(var, framework.Variable) or var.type not in [ if not isinstance(var, framework.Variable) or var.type not in [
core.VarDesc.VarType.LOD_TENSOR, core.VarDesc.VarType.LOD_TENSOR,
core.VarDesc.VarType.SELECTED_ROWS core.VarDesc.VarType.SELECTED_ROWS,
]: ]:
return False return False
if var.dtype not in [paddle.float32, paddle.float64]: if var.dtype not in [paddle.float32, paddle.float64]:
...@@ -448,20 +476,28 @@ class PartialProgramLayer: ...@@ -448,20 +476,28 @@ class PartialProgramLayer:
new_grad_name = var.name + suffix + "@GRAD" new_grad_name = var.name + suffix + "@GRAD"
finded_ops = list( finded_ops = list(
filter( filter(
lambda x: x[0] >= start_idx and any([ lambda x: x[0] >= start_idx
out_arg == var_grad_name and any(
for out_arg in x[1].output_arg_names [
]), enumerate(target_program.block(0).ops))) out_arg == var_grad_name
for out_arg in x[1].output_arg_names
]
),
enumerate(target_program.block(0).ops),
)
)
# len(finded_ops) may equals zero when stop_gradient works. # len(finded_ops) may equals zero when stop_gradient works.
# len(finded_ops) may > 1, because we may have fill_constant op. # len(finded_ops) may > 1, because we may have fill_constant op.
if len(finded_ops) == 0: if len(finded_ops) == 0:
return None return None
# step1: create a new var named var.name@GRAD # step1: create a new var named var.name@GRAD
target_program.block(0).create_var(name=new_grad_name, target_program.block(0).create_var(
type=var.type, name=new_grad_name,
dtype=var.dtype, type=var.type,
shape=var.shape) dtype=var.dtype,
shape=var.shape,
)
# step2: rename the var.name@GRAD to var.name@GRAD@dy2static # step2: rename the var.name@GRAD to var.name@GRAD@dy2static
for idx, op in finded_ops: for idx, op in finded_ops:
op._rename_input(var_grad_name, new_grad_name) op._rename_input(var_grad_name, new_grad_name)
...@@ -472,11 +508,13 @@ class PartialProgramLayer: ...@@ -472,11 +508,13 @@ class PartialProgramLayer:
finded_ops[-1][0] + 1, finded_ops[-1][0] + 1,
type='sum', type='sum',
inputs={'X': [var_grad_name, new_grad_name]}, inputs={'X': [var_grad_name, new_grad_name]},
outputs={"Out": var_grad_name}) outputs={"Out": var_grad_name},
)
return None return None
to_processed_vars = list( to_processed_vars = list(
filter(_need_aggregation, self._outputs.tolist())) filter(_need_aggregation, self._outputs.tolist())
)
for _var in to_processed_vars: for _var in to_processed_vars:
_insert_aggregation_ops_for_var(target_program, _var) _insert_aggregation_ops_for_var(target_program, _var)
...@@ -492,8 +530,9 @@ class PartialProgramLayer: ...@@ -492,8 +530,9 @@ class PartialProgramLayer:
if targets and self._params: if targets and self._params:
backward.gradients(targets=targets, inputs=[]) backward.gradients(targets=targets, inputs=[])
start_idx = len( start_idx = len(main_program.block(0).ops) + 2 * len(
main_program.block(0).ops) + 2 * len(self._outputs.tolist()) self._outputs.tolist()
)
self.prepare_gradient_aggregation(start_idx, main_program, program) self.prepare_gradient_aggregation(start_idx, main_program, program)
...@@ -512,7 +551,10 @@ class PartialProgramLayer: ...@@ -512,7 +551,10 @@ class PartialProgramLayer:
found_param = False found_param = False
for block in program.blocks: for block in program.blocks:
for op in block.ops: for op in block.ops:
if param.name in op.input_arg_names or param.name in op.output_arg_names: if (
param.name in op.input_arg_names
or param.name in op.output_arg_names
):
required_params.append(param) required_params.append(param)
found_param = True found_param = True
break break
...@@ -529,15 +571,21 @@ class PartialProgramLayer: ...@@ -529,15 +571,21 @@ class PartialProgramLayer:
var_desc = block.vars[name].desc var_desc = block.vars[name].desc
var_base = None var_base = None
if not framework._in_eager_mode_: if not framework._in_eager_mode_:
var_base = core.VarBase(var_desc.dtype(), var_base = core.VarBase(
var_desc.shape(), var_desc.dtype(),
var_desc.name(), var_desc.shape(),
var_desc.type(), False) var_desc.name(),
var_desc.type(),
False,
)
else: else:
var_base = core.eager.Tensor(var_desc.dtype(), var_base = core.eager.Tensor(
var_desc.shape(), var_desc.dtype(),
var_desc.name(), var_desc.shape(),
var_desc.type(), False) var_desc.name(),
var_desc.type(),
False,
)
double_grads.append(var_base) double_grads.append(var_base)
return self._valid_vars(double_grads) return self._valid_vars(double_grads)
...@@ -557,36 +605,62 @@ class PartialProgramLayer: ...@@ -557,36 +605,62 @@ class PartialProgramLayer:
attrs = [ attrs = [
'global_block', 'global_block',
self.program.desc.block(0), 'start_op_index', 0, 'end_op_index', self.program.desc.block(0),
self._get_end_op_index(), 'is_test', not self.training, 'start_op_index',
'program_id', self.program_id 0,
'end_op_index',
self._get_end_op_index(),
'is_test',
not self.training,
'program_id',
self.program_id,
] ]
if self._cuda_graph_capture_mode: if self._cuda_graph_capture_mode:
attrs.extend( attrs.extend(
('cuda_graph_capture_mode', self._cuda_graph_capture_mode, (
'cuda_graph_pool_id', self._cuda_graph_pool_id)) 'cuda_graph_capture_mode',
self._cuda_graph_capture_mode,
use_interpretorcore = _is_enable_standalone_executor( 'cuda_graph_pool_id',
) and _is_dy2st_enable_standalone_executor() self._cuda_graph_pool_id,
)
)
use_interpretorcore = (
_is_enable_standalone_executor()
and _is_dy2st_enable_standalone_executor()
)
attrs.extend(('use_interpretorcore', use_interpretorcore)) attrs.extend(('use_interpretorcore', use_interpretorcore))
if use_interpretorcore: if use_interpretorcore:
attrs.extend( attrs.extend(
('forward_global_block', self.forward_program.desc.block(0), (
'backward_global_block', self.backward_program.desc.block(0))) 'forward_global_block',
self.forward_program.desc.block(0),
'backward_global_block',
self.backward_program.desc.block(0),
)
)
_legacy_C_ops.run_program( _legacy_C_ops.run_program(
self._valid_vars(in_vars), self._valid_vars(self._params), self._valid_vars(in_vars),
self._valid_vars(self._params),
self._valid_vars(out_vars), self._valid_vars(out_vars),
self._create_scope_vec(program_id=self.program_id, self._create_scope_vec(
use_scope_cache=True), program_id=self.program_id, use_scope_cache=True
self._double_grads, self._cuda_graph_vec, *attrs) ),
self._double_grads,
self._cuda_graph_vec,
*attrs
)
else: else:
_legacy_C_ops.run_program(self._valid_vars(in_vars), _legacy_C_ops.run_program(
self._valid_vars(self._params), self._valid_vars(in_vars),
self._valid_vars(out_vars), self._valid_vars(self._params),
self._create_scope_vec(), self._valid_vars(out_vars),
self._double_grads, self._cuda_graph_vec, self._create_scope_vec(),
*attrs) self._double_grads,
self._cuda_graph_vec,
*attrs
)
restored_nest_out = self._restore_out(out_vars) restored_nest_out = self._restore_out(out_vars)
return self._remove_no_value(restored_nest_out) return self._remove_no_value(restored_nest_out)
...@@ -594,9 +668,11 @@ class PartialProgramLayer:
        if _in_pure_fp16_guard():
            for i, var in enumerate(in_vars):
                name = var.name
                if (
                    self.program.global_block().has_var(name)
                    and self.program.global_block().var(name).dtype
                    == paddle.float16
                ):
                    in_vars[i] = var.astype('float16')
                    in_vars[i].name = name
...@@ -627,25 +703,32 @@ class PartialProgramLayer: ...@@ -627,25 +703,32 @@ class PartialProgramLayer:
return self._infer_program return self._infer_program
@switch_to_static_graph @switch_to_static_graph
def _get_forward_backward_program_form(self, whole_program, def _get_forward_backward_program_form(
forward_end_op_index): self, whole_program, forward_end_op_index
):
forward_builded_program = add_build_strategy_for( forward_builded_program = add_build_strategy_for(
whole_program, 0, forward_end_op_index, self._build_strategy) whole_program, 0, forward_end_op_index, self._build_strategy
)
backward_start_op_index = forward_end_op_index + 2 * len( backward_start_op_index = forward_end_op_index + 2 * len(
self._outputs.var_ids) self._outputs.var_ids
)
backward_end_op_index = whole_program.desc.block(0).op_size() backward_end_op_index = whole_program.desc.block(0).op_size()
backward_builded_program = add_build_strategy_for( backward_builded_program = add_build_strategy_for(
whole_program, backward_start_op_index, backward_end_op_index, whole_program,
self._build_strategy) backward_start_op_index,
self._apply_inplace_pass(forward_builded_program, backward_end_op_index,
backward_builded_program) self._build_strategy,
)
self._apply_inplace_pass(
forward_builded_program, backward_builded_program
)
return [forward_builded_program, backward_builded_program] return [forward_builded_program, backward_builded_program]
def _apply_inplace_pass(self, forward_program, backward_program): def _apply_inplace_pass(self, forward_program, backward_program):
attr_types = { attr_types = {
"use_cuda": "bool", "use_cuda": "bool",
"mem_opt_skip_vars": "list[str]", "mem_opt_skip_vars": "list[str]",
"for_partial_block": "bool" "for_partial_block": "bool",
} }
empty_startup_program = paddle.static.Program() empty_startup_program = paddle.static.Program()
use_cuda = True if core.is_compiled_with_cuda() else False use_cuda = True if core.is_compiled_with_cuda() else False
...@@ -667,22 +750,33 @@ class PartialProgramLayer: ...@@ -667,22 +750,33 @@ class PartialProgramLayer:
forward_mem_opt_skip_vars.append(var.desc.name()) forward_mem_opt_skip_vars.append(var.desc.name())
backward_mem_opt_skip_vars.append(var.desc.name()) backward_mem_opt_skip_vars.append(var.desc.name())
for var_name in core.parse_safe_eager_deletion_skip_vars( for var_name in core.parse_safe_eager_deletion_skip_vars(
backward_program.desc): backward_program.desc
):
forward_mem_opt_skip_vars.append(var_name) forward_mem_opt_skip_vars.append(var_name)
attrs = { attrs = {
"use_cuda": use_cuda, "use_cuda": use_cuda,
"mem_opt_skip_vars": forward_mem_opt_skip_vars, "mem_opt_skip_vars": forward_mem_opt_skip_vars,
"for_partial_block": True "for_partial_block": True,
} }
_apply_pass(forward_program, empty_startup_program, _apply_pass(
"buffer_shared_inplace_pass", attrs, attr_types) forward_program,
empty_startup_program,
"buffer_shared_inplace_pass",
attrs,
attr_types,
)
attrs = { attrs = {
"use_cuda": use_cuda, "use_cuda": use_cuda,
"mem_opt_skip_vars": backward_mem_opt_skip_vars, "mem_opt_skip_vars": backward_mem_opt_skip_vars,
"for_partial_block": True "for_partial_block": True,
} }
_apply_pass(backward_program, empty_startup_program, _apply_pass(
"buffer_shared_inplace_pass", attrs, attr_types) backward_program,
empty_startup_program,
"buffer_shared_inplace_pass",
attrs,
attr_types,
)
def _prepare(self, inputs): def _prepare(self, inputs):
""" """
...@@ -698,23 +792,28 @@ class PartialProgramLayer: ...@@ -698,23 +792,28 @@ class PartialProgramLayer:
if isinstance(value, np.ndarray): if isinstance(value, np.ndarray):
var = None var = None
if not framework._in_eager_mode_: if not framework._in_eager_mode_:
var = core.VarBase(value=value, var = core.VarBase(
name=self._inputs[i].desc.name(), value=value,
persistable=False, name=self._inputs[i].desc.name(),
place=expected_place, persistable=False,
zero_copy=True) place=expected_place,
zero_copy=True,
)
else: else:
var = core.eager.Tensor(value=value, var = core.eager.Tensor(
name=self._inputs[i].desc.name(), value=value,
persistable=False, name=self._inputs[i].desc.name(),
place=expected_place, persistable=False,
zero_copy=True) place=expected_place,
zero_copy=True,
)
elif isinstance(value, (core.VarBase, core.eager.Tensor)): elif isinstance(value, (core.VarBase, core.eager.Tensor)):
# NOTE(Aurelius84): If var is on CPUPlace, it will be transformed multi times # NOTE(Aurelius84): If var is on CPUPlace, it will be transformed multi times
# into CUDAPlace when it's as input of multi Ops. so we move it in advance # into CUDAPlace when it's as input of multi Ops. so we move it in advance
# to avoid this problem. # to avoid this problem.
if value.stop_gradient and not value.place._equals( if value.stop_gradient and not value.place._equals(
expected_place): expected_place
):
var = value._copy_to(expected_place, False) var = value._copy_to(expected_place, False)
var.stop_gradient = True var.stop_gradient = True
else: else:
...@@ -737,12 +836,21 @@ class PartialProgramLayer: ...@@ -737,12 +836,21 @@ class PartialProgramLayer:
return out_varbase_map[var_desc.name()] return out_varbase_map[var_desc.name()]
if not framework._in_eager_mode_: if not framework._in_eager_mode_:
var_base = core.VarBase(var_desc.dtype(), var_desc.shape(), var_base = core.VarBase(
var_desc.name(), var_desc.type(), False) var_desc.dtype(),
var_desc.shape(),
var_desc.name(),
var_desc.type(),
False,
)
else: else:
var_base = core.eager.Tensor(var_desc.dtype(), var_desc.shape(), var_base = core.eager.Tensor(
var_desc.name(), var_desc.type(), var_desc.dtype(),
False) var_desc.shape(),
var_desc.name(),
var_desc.type(),
False,
)
var_base.stop_gradient = var.stop_gradient var_base.stop_gradient = var.stop_gradient
out_varbase_map[var_desc.name()] = var_base out_varbase_map[var_desc.name()] = var_base
return var_base return var_base
...@@ -755,20 +863,30 @@ class PartialProgramLayer: ...@@ -755,20 +863,30 @@ class PartialProgramLayer:
def _create_scope_vec(self, program_id=None, use_scope_cache=False): def _create_scope_vec(self, program_id=None, use_scope_cache=False):
# Hold forward variables # Hold forward variables
tmp_scope_vec = None tmp_scope_vec = None
inner_scope = self._get_scope(program_id=program_id, inner_scope = self._get_scope(
use_scope_cache=use_scope_cache) program_id=program_id, use_scope_cache=use_scope_cache
)
if not framework._in_eager_mode_: if not framework._in_eager_mode_:
tmp_scope_vec = core.VarBase(core.VarDesc.VarType.FP32, [], tmp_scope_vec = core.VarBase(
"program_out_scope", core.VarDesc.VarType.FP32,
core.VarDesc.VarType.STEP_SCOPES, True) [],
"program_out_scope",
core.VarDesc.VarType.STEP_SCOPES,
True,
)
tmp_scope_vec.value().set_scope(inner_scope) tmp_scope_vec.value().set_scope(inner_scope)
else: else:
tmp_scope_vec = [inner_scope] tmp_scope_vec = [inner_scope]
return tmp_scope_vec return tmp_scope_vec
def _create_cuda_graph_vec(self): def _create_cuda_graph_vec(self):
var = core.VarBase(core.VarDesc.VarType.FP32, [], "cuda_graph", var = core.VarBase(
core.VarDesc.VarType.RAW, True) core.VarDesc.VarType.FP32,
[],
"cuda_graph",
core.VarDesc.VarType.RAW,
True,
)
var.stop_gradient = True var.stop_gradient = True
return var return var
...@@ -791,8 +909,9 @@ class PartialProgramLayer: ...@@ -791,8 +909,9 @@ class PartialProgramLayer:
return main_program.clone(for_test=True) return main_program.clone(for_test=True)
def _is_no_value(self, var): def _is_no_value(self, var):
if isinstance(var, if isinstance(var, (core.VarBase, core.eager.Tensor)) and var.shape == [
(core.VarBase, core.eager.Tensor)) and var.shape == [1]: 1
]:
# NOTE: .numpy() will insert MemcpySync operation, it hits performance. # NOTE: .numpy() will insert MemcpySync operation, it hits performance.
if var.numpy()[0] == RETURN_NO_VALUE_MAGIC_NUM: if var.numpy()[0] == RETURN_NO_VALUE_MAGIC_NUM:
return True return True
...@@ -808,13 +927,14 @@ class PartialProgramLayer: ...@@ -808,13 +927,14 @@ class PartialProgramLayer:
return out_vars return out_vars
elif isinstance(out_vars, (tuple, list)): elif isinstance(out_vars, (tuple, list)):
if isinstance(out_vars, tuple): if isinstance(out_vars, tuple):
res = tuple(var for var in out_vars res = tuple(
if not self._is_no_value(var)) var for var in out_vars if not self._is_no_value(var)
)
else: else:
# isinstance(out_vars, list) # isinstance(out_vars, list)
res = [var for var in out_vars if not self._is_no_value(var)] res = [var for var in out_vars if not self._is_no_value(var)]
has_removed = (len(out_vars) > len(res)) has_removed = len(out_vars) > len(res)
# len(out_vars) > len(res) means we have removed var. This is # len(out_vars) > len(res) means we have removed var. This is
# preventing out_vars is empty or just one element at the beginning # preventing out_vars is empty or just one element at the beginning
if len(res) == 0 and has_removed: if len(res) == 0 and has_removed:
...@@ -835,7 +955,8 @@ class PartialProgramLayer: ...@@ -835,7 +955,8 @@ class PartialProgramLayer:
for param in params: for param in params:
grad_name = param.name + core.grad_var_suffix() grad_name = param.name + core.grad_var_suffix()
grad_var = train_program.desc.block(0).find_var( grad_var = train_program.desc.block(0).find_var(
cpt.to_bytes(grad_name)) cpt.to_bytes(grad_name)
)
# NOTE: cannot find var desc maybe no problem, such as in batch_norm # NOTE: cannot find var desc maybe no problem, such as in batch_norm
if grad_var is None: if grad_var is None:
continue continue
...@@ -864,15 +985,18 @@ class PartialProgramLayer: ...@@ -864,15 +985,18 @@ class PartialProgramLayer:
if not isinstance(self._params, (list, tuple)): if not isinstance(self._params, (list, tuple)):
raise TypeError( raise TypeError(
"Type of self._params in PartialProgramLayer should be list or tuple, but received %s." "Type of self._params in PartialProgramLayer should be list or tuple, but received %s."
% type(self._params)) % type(self._params)
)
param_and_buffer_names_set = set() param_and_buffer_names_set = set()
for i, var in enumerate(self._params): for i, var in enumerate(self._params):
# self._params constains parameters and buffers with persistable=True. # self._params constains parameters and buffers with persistable=True.
if not isinstance(var, (core.VarBase, core.eager.Tensor)): if not isinstance(var, (core.VarBase, core.eager.Tensor)):
raise TypeError( raise TypeError(
'Type of self._params[{}] in PartialProgramLayer should be Parameter or Variable, but received {}.' 'Type of self._params[{}] in PartialProgramLayer should be Parameter or Variable, but received {}.'.format(
.format(i, type(var))) i, type(var)
)
)
param_and_buffer_names_set.add(var.name) param_and_buffer_names_set.add(var.name)
for block in main_program.blocks: for block in main_program.blocks:
...@@ -886,7 +1010,8 @@ class PartialProgramLayer: ...@@ -886,7 +1010,8 @@ class PartialProgramLayer:
"\n\tRevise suggestion: " "\n\tRevise suggestion: "
"\n\t\t1. Please ensure all your sublayers are inheritted from nn.Layer." "\n\t\t1. Please ensure all your sublayers are inheritted from nn.Layer."
"\n\t\t2. Please use nn.ParameterList and nn.LayerList as container instead of using a native Python container such as List" "\n\t\t2. Please use nn.ParameterList and nn.LayerList as container instead of using a native Python container such as List"
% name) % name
)
def _valid_vars(self, vars): def _valid_vars(self, vars):
""" """
...@@ -903,13 +1028,23 @@ def _create_fake_var(): ...@@ -903,13 +1028,23 @@ def _create_fake_var():
""" """
if not framework._in_eager_mode_: if not framework._in_eager_mode_:
return [ return [
core.VarBase(core.VarDesc.VarType.FP32, [], "Fake_var", core.VarBase(
core.VarDesc.VarType.RAW, False) core.VarDesc.VarType.FP32,
[],
"Fake_var",
core.VarDesc.VarType.RAW,
False,
)
] ]
else: else:
return [ return [
core.eager.Tensor(core.VarDesc.VarType.FP32, [], "Fake_var", core.eager.Tensor(
core.VarDesc.VarType.RAW, False) core.VarDesc.VarType.FP32,
[],
"Fake_var",
core.VarDesc.VarType.RAW,
False,
)
] ]
...@@ -918,23 +1053,27 @@ def partial_program_from(concrete_program): ...@@ -918,23 +1053,27 @@ def partial_program_from(concrete_program):
if inputs and isinstance(inputs[0], layers.Layer): if inputs and isinstance(inputs[0], layers.Layer):
inputs = inputs[1:] inputs = inputs[1:]
return PartialProgramLayer(concrete_program.main_program, inputs, return PartialProgramLayer(
concrete_program.outputs, concrete_program.main_program,
concrete_program.parameters, inputs,
**concrete_program.kwargs) concrete_program.outputs,
concrete_program.parameters,
**concrete_program.kwargs
)
@switch_to_static_graph @switch_to_static_graph
def add_build_strategy_for(program, def add_build_strategy_for(
start_op_index, program, start_op_index, end_op_index, build_strategy=None
end_op_index, ):
build_strategy=None): if start_op_index < end_op_index:
if (start_op_index < end_op_index):
compiled_program = paddle.static.CompiledProgram( compiled_program = paddle.static.CompiledProgram(
core.Graph(program.desc, start_op_index, end_op_index), core.Graph(program.desc, start_op_index, end_op_index),
build_strategy=build_strategy) build_strategy=build_strategy,
compiled_program._compile(core.Scope(), )
framework._current_expected_place()) compiled_program._compile(
core.Scope(), framework._current_expected_place()
)
ir_graph = framework.IrGraph(compiled_program._graph) ir_graph = framework.IrGraph(compiled_program._graph)
builded_program = ir_graph.to_program() builded_program = ir_graph.to_program()
if hasattr(compiled_program._program, 'lr_sheduler'): if hasattr(compiled_program._program, 'lr_sheduler'):
......