From 6959eae53a58b29ffca4efc848c55238734456e2 Mon Sep 17 00:00:00 2001 From: Yiqun Liu Date: Wed, 12 Apr 2023 11:17:46 +0800 Subject: [PATCH] Unify the static amp codes of fp16 and bf16. Reimplement #52694 in release/2.4. (#52697) --- .../fluid/contrib/mixed_precision/__init__.py | 2 +- .../fluid/contrib/mixed_precision/amp_nn.py | 114 +++-- .../contrib/mixed_precision/decorator.py | 460 ++++++++++++------ .../contrib/mixed_precision/fp16_lists.py | 74 +-- .../contrib/mixed_precision/fp16_utils.py | 329 ++++++++----- .../dygraph_to_static/partial_program.py | 447 +++++++++++------ 6 files changed, 920 insertions(+), 506 deletions(-) diff --git a/python/paddle/fluid/contrib/mixed_precision/__init__.py b/python/paddle/fluid/contrib/mixed_precision/__init__.py index 1dd5015ec80..52e645f3715 100644 --- a/python/paddle/fluid/contrib/mixed_precision/__init__.py +++ b/python/paddle/fluid/contrib/mixed_precision/__init__.py @@ -15,7 +15,7 @@ from __future__ import print_function from . import decorator -from .decorator import * +from .decorator import decorate, amp_decorate from . import fp16_lists from .fp16_lists import * from . import fp16_utils diff --git a/python/paddle/fluid/contrib/mixed_precision/amp_nn.py b/python/paddle/fluid/contrib/mixed_precision/amp_nn.py index 62b98e75ea1..26a318ca670 100644 --- a/python/paddle/fluid/contrib/mixed_precision/amp_nn.py +++ b/python/paddle/fluid/contrib/mixed_precision/amp_nn.py @@ -27,8 +27,8 @@ def check_finite_and_unscale(x, scale, name=None, float_status=None): $$Out = X / scale$$ If any tensor in X contains Inf or Nan, the Out will generate a indicator. - FoundInfinite will be 1 (True), and Out will not be scaled. In this case, the data of - Out should not be used, and its data may not be deterministic. + FoundInfinite will be 1 (True), and Out will not be scaled. In this case, the data of + Out should not be used, and its data may not be deterministic. Otherwise, FoundInfinite will be 0 (False). 
Args: @@ -38,75 +38,98 @@ def check_finite_and_unscale(x, scale, name=None, float_status=None): """ check_type(x, 'x', (tuple, list), 'check_finite_and_unscale') for e in x: - check_variable_and_dtype(e, "x", ['float16', 'float32', 'float64'], - 'check_finite_and_unscale') + check_variable_and_dtype( + e, + "x", + ['float16', 'float32', 'float64', 'uint16'], + 'check_finite_and_unscale', + ) helper = LayerHelper("check_finite_and_unscale", **locals()) found_inf = helper.create_variable_for_type_inference(dtype='bool') inputs = {'X': x, 'Scale': scale} if core.is_compiled_with_npu(): - check_variable_and_dtype(float_status, "float_status", - ['float16', 'float32'], - 'check_finite_and_unscale') + check_variable_and_dtype( + float_status, + "float_status", + ['float16', 'float32'], + 'check_finite_and_unscale', + ) inputs['FloatStatus'] = float_status outputs = {'Out': x, 'FoundInfinite': found_inf} - helper.append_op(type='check_finite_and_unscale', - inputs=inputs, - outputs=outputs) + helper.append_op( + type='check_finite_and_unscale', inputs=inputs, outputs=outputs + ) return x, found_inf -def update_loss_scaling(x, - found_inf, - prev_loss_scaling, - num_good_steps, - num_bad_steps, - incr_every_n_steps, - decr_every_n_nan_or_inf, - incr_ratio, - decr_ratio, - stop_update=False, - name=None): +def update_loss_scaling( + x, + found_inf, + prev_loss_scaling, + num_good_steps, + num_bad_steps, + incr_every_n_steps, + decr_every_n_nan_or_inf, + incr_ratio, + decr_ratio, + stop_update=False, + name=None, +): """ - Update loss scaling according to overall gradients. If all gradients is - finite after incr_every_n_steps, loss scaling will increase by incr_ratio. + Update loss scaling according to overall gradients. If all gradients is + finite after incr_every_n_steps, loss scaling will increase by incr_ratio. Otherwise, loss scaling will decrease by decr_ratio after decr_every_n_nan_or_inf steps and each step some gradients are infinite. Args: x(list|tuple): The input tensors of update_loss_scaling operator. - found_inf (Variable): A boolean variable indicates whether + found_inf (Variable): A boolean variable indicates whether there is any infinite gradient. prev_loss_scaling (Variable): Previous loss scaling. - num_good_steps (Variable): A variable accumulates good steps in which + num_good_steps (Variable): A variable accumulates good steps in which all gradients are finite. - num_bad_steps (Variable): A variable accumulates bad steps in which + num_bad_steps (Variable): A variable accumulates bad steps in which some gradients are infinite. - incr_every_n_steps (int): A variable represents increasing loss - scaling every n consecutive steps with + incr_every_n_steps (int): A variable represents increasing loss + scaling every n consecutive steps with finite gradients. - decr_every_n_nan_or_inf (int): A variable represents decreasing - loss scaling every n accumulated + decr_every_n_nan_or_inf (int): A variable represents decreasing + loss scaling every n accumulated steps with nan or inf gradients. - incr_ratio(float): The multiplier to use when increasing the loss + incr_ratio(float): The multiplier to use when increasing the loss scaling. - decr_ratio(float): The less-than-one-multiplier to use when decreasing + decr_ratio(float): The less-than-one-multiplier to use when decreasing loss scaling. 
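For reference, the dynamic loss-scaling rule that these arguments drive can be sketched in plain Python. This is a minimal illustration of the counter bookkeeping only; the names are illustrative, and the real update_loss_scaling op additionally writes the unscaled gradients in x back in place and guards the scale against underflow.

def loss_scaling_sketch(found_inf, scale, num_good, num_bad,
                        incr_every_n_steps, decr_every_n_nan_or_inf,
                        incr_ratio, decr_ratio):
    # Any inf/nan gradient resets the good-step counter and counts one bad step.
    if found_inf:
        num_good, num_bad = 0, num_bad + 1
        if num_bad == decr_every_n_nan_or_inf:
            # Shrink the scale after too many consecutive-ish bad steps.
            scale, num_bad = scale * decr_ratio, 0
    else:
        num_good, num_bad = num_good + 1, 0
        if num_good == incr_every_n_steps:
            # Grow the scale after enough consecutive finite steps.
            scale, num_good = scale * incr_ratio, 0
    return scale, num_good, num_bad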
""" - check_variable_and_dtype(prev_loss_scaling, "prev_loss_scaling", - ['float32', 'float64'], "update_loss_scaling") + check_variable_and_dtype( + prev_loss_scaling, + "prev_loss_scaling", + ['float32', 'float64'], + "update_loss_scaling", + ) check_type(x, 'x', (tuple, list), 'update_loss_scaling') for e in x: - check_variable_and_dtype(e, "x", ['float16', 'float32', 'float64'], - 'update_loss_scaling') - if e.dtype == core.VarDesc.VarType.FP16: - assert prev_loss_scaling.dtype == core.VarDesc.VarType.FP32, \ - "The dtype of prev_loss_scaling should be float32 when the dtype of x is float16." + check_variable_and_dtype( + e, + "x", + ['float16', 'float32', 'float64', 'uint16'], + 'update_loss_scaling', + ) + if ( + e.dtype == core.VarDesc.VarType.FP16 + or e.dtype == core.VarDesc.VarType.BF16 + ): + assert ( + prev_loss_scaling.dtype == core.VarDesc.VarType.FP32 + ), "The dtype of prev_loss_scaling should be float32 when the dtype of x is float16." else: - assert prev_loss_scaling.dtype == e.dtype, "The dtype of prev_loss_scaling should be equal to the dtype of x." + assert ( + prev_loss_scaling.dtype == e.dtype + ), "The dtype of prev_loss_scaling should be equal to the dtype of x." helper = LayerHelper("update_loss_scaling", **locals()) @@ -115,14 +138,14 @@ def update_loss_scaling(x, 'FoundInfinite': found_inf, 'PrevLossScaling': prev_loss_scaling, 'InGoodSteps': num_good_steps, - 'InBadSteps': num_bad_steps + 'InBadSteps': num_bad_steps, } outputs = { 'Out': x, 'LossScaling': prev_loss_scaling, 'OutGoodSteps': num_good_steps, - 'OutBadSteps': num_bad_steps + 'OutBadSteps': num_bad_steps, } attrs = { @@ -137,9 +160,8 @@ def update_loss_scaling(x, else: attrs['stop_update'] = stop_update - helper.append_op(type='update_loss_scaling', - inputs=inputs, - outputs=outputs, - attrs=attrs) + helper.append_op( + type='update_loss_scaling', inputs=inputs, outputs=outputs, attrs=attrs + ) return x diff --git a/python/paddle/fluid/contrib/mixed_precision/decorator.py b/python/paddle/fluid/contrib/mixed_precision/decorator.py index 787a4e90a0f..aea65060624 100644 --- a/python/paddle/fluid/contrib/mixed_precision/decorator.py +++ b/python/paddle/fluid/contrib/mixed_precision/decorator.py @@ -36,11 +36,11 @@ __all__ = ["decorate"] class OptimizerWithMixedPrecision(object): """ - Optimizer with mixed-precision (MP) training. This is a wrapper of a common + Optimizer with mixed-precision (MP) training. This is a wrapper of a common optimizer, plus the support of mixed-precision pre-training. The object - of this class almost has the same behavior as the common optimizer, with the - methods `minimize()`, `backward()`, `apply_gradients()` implemented. - Additionally, it enables the MP training automatically, i.e, the creation + of this class almost has the same behavior as the common optimizer, with the + methods `minimize()`, `backward()`, `apply_gradients()` implemented. + Additionally, it enables the MP training automatically, i.e, the creation and maintenance of master parameters, scaling of loss, etc. Args: @@ -48,14 +48,14 @@ class OptimizerWithMixedPrecision(object): amp_lists (CustomOpLists): An CustomOpLists object. init_loss_scaling (float): The initial loss scaling factor. use_dynamic_loss_scaling (bool): Whether to use dynamic loss scaling. - incr_every_n_steps(int): Increases loss scaling every n consecutive + incr_every_n_steps(int): Increases loss scaling every n consecutive steps with finite gradients. 
- decr_every_n_nan_or_inf(int): Decreases loss scaling every n - accumulated steps with nan or + decr_every_n_nan_or_inf(int): Decreases loss scaling every n + accumulated steps with nan or inf gradients. - incr_ratio(float): The multiplier to use when increasing the loss + incr_ratio(float): The multiplier to use when increasing the loss scaling. - decr_ratio(float): The less-than-one-multiplier to use when decreasing + decr_ratio(float): The less-than-one-multiplier to use when decreasing the loss scaling. use_pure_fp16(bool): Whether to use the pure fp16 training. Default False. use_fp16_guard(bool): Whether to use `fp16_guard` when constructing the program. @@ -63,10 +63,20 @@ class OptimizerWithMixedPrecision(object): """ - def __init__(self, optimizer, amp_lists, init_loss_scaling, - use_dynamic_loss_scaling, incr_every_n_steps, - decr_every_n_nan_or_inf, incr_ratio, decr_ratio, use_pure_fp16, - use_fp16_guard): + def __init__( + self, + optimizer, + amp_lists, + init_loss_scaling, + use_dynamic_loss_scaling, + incr_every_n_steps, + decr_every_n_nan_or_inf, + incr_ratio, + decr_ratio, + use_pure_fp16, + use_fp16_guard, + use_bf16=False, + ): self._optimizer = optimizer self._amp_lists = amp_lists self._param_grads = None @@ -77,11 +87,23 @@ class OptimizerWithMixedPrecision(object): self._loss_scaling = None self._init_loss_scaling = init_loss_scaling self._use_dynamic_loss_scaling = use_dynamic_loss_scaling + if use_bf16: + if use_dynamic_loss_scaling: + self._use_dynamic_loss_scaling = False + self._init_loss_scaling = 1.0 + warnings.warn( + "Dynamic loss scaling for bfloat16 amp training is disabled, and the init_loss_scaling is changed to 1.0 automatically by PaddlePaddle." + ) + self._amp_dtype = core.VarDesc.VarType.BF16 + else: + self._amp_dtype = core.VarDesc.VarType.FP16 + self._learning_rate = optimizer._learning_rate self._learning_rate_map = optimizer._learning_rate_map self._use_pure_fp16 = use_pure_fp16 self._use_fp16_guard = use_fp16_guard self._to_fp16_var_names = None + self._use_bf16 = use_bf16 if self._use_dynamic_loss_scaling: self._incr_every_n_steps = incr_every_n_steps self._decr_every_n_nan_or_inf = decr_every_n_nan_or_inf @@ -97,9 +119,10 @@ class OptimizerWithMixedPrecision(object): self._is_distributed = flag def get_loss_scaling(self): - """Return the real-time loss scaling factor. - """ - assert self._loss_scaling is not None, 'Please call minimize() before calling get_loss_scaling().' + """Return the real-time loss scaling factor.""" + assert ( + self._loss_scaling is not None + ), 'Please call minimize() before calling get_loss_scaling().' 
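Because bfloat16 keeps float32's 8-bit exponent range, gradient overflow is far rarer than with float16, which is why the constructor disables dynamic loss scaling and pins the scale to 1.0 when use_bf16 is set. A minimal sketch of that decision, with illustrative names:

def resolve_amp_settings(use_bf16, use_dynamic_loss_scaling, init_loss_scaling):
    # bf16 shares fp32's exponent range, so loss scaling brings no benefit;
    # fall back to a constant scale of 1.0 and skip the scaling counters.
    if use_bf16:
        return 'bfloat16', False, 1.0
    return 'float16', use_dynamic_loss_scaling, init_loss_scaling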
return self._loss_scaling def get_scaled_loss(self): @@ -117,7 +140,8 @@ class OptimizerWithMixedPrecision(object): shape=[1], value=self._init_loss_scaling, dtype='float32', - persistable=True) + persistable=True, + ) if self._use_dynamic_loss_scaling: self._num_good_steps = layers.create_global_var( @@ -125,37 +149,43 @@ class OptimizerWithMixedPrecision(object): shape=[1], value=0, dtype='int32', - persistable=True) + persistable=True, + ) self._num_bad_steps = layers.create_global_var( name=unique_name.generate("num_bad_steps"), shape=[1], value=0, dtype='int32', - persistable=True) + persistable=True, + ) # Ensure the data type of learning rate vars is float32 (same as the # master parameter dtype) if isinstance(self._optimizer._learning_rate, float): - self._optimizer._learning_rate_map[default_main_program()] = \ - layers.create_global_var( - name=unique_name.generate("learning_rate"), - shape=[1], - value=float(self._optimizer._learning_rate), - dtype='float32', - persistable=True) - - def backward(self, - loss, - startup_program=None, - parameter_list=None, - no_grad_set=None, - callbacks=None): + self._optimizer._learning_rate_map[ + default_main_program() + ] = layers.create_global_var( + name=unique_name.generate("learning_rate"), + shape=[1], + value=float(self._optimizer._learning_rate), + dtype='float32', + persistable=True, + ) + + def backward( + self, + loss, + startup_program=None, + parameter_list=None, + no_grad_set=None, + callbacks=None, + ): """ Backward propagation or auto differentiation for gradients' computation. Args: loss (Variable): The loss Variable to minimize. - startup_program (Program|None): The startup Program for initializing + startup_program (Program|None): The startup Program for initializing parameters in `parameter_list`. parameter_list (list|None): A list of Variables to update. no_grad_set (set|None): A set of Variables should be ignored. @@ -163,7 +193,7 @@ class OptimizerWithMixedPrecision(object): backward operator for one parameter. Returns: - A list of (param, grad), which is a tuple of a parameter and its + A list of (param, grad), which is a tuple of a parameter and its gradient respectively, and the scaled loss. """ train_program = loss.block.program @@ -171,9 +201,9 @@ class OptimizerWithMixedPrecision(object): # NOTE(zhiqiu): _float_status is only used for NPU. 
if core.is_compiled_with_npu(): - float_status = paddle.static.data(name="float_status", - shape=[8], - dtype='float32') + float_status = paddle.static.data( + name="float_status", shape=[8], dtype='float32' + ) self._train_program.global_block().append_op( type="alloc_float_status", outputs={"FloatStatus": float_status}, @@ -192,9 +222,15 @@ class OptimizerWithMixedPrecision(object): if self._use_pure_fp16: self._to_fp16_var_names = cast_model_to_fp16( - self._train_program, self._amp_lists, self._use_fp16_guard) + self._train_program, + self._amp_lists, + self._use_fp16_guard, + self._amp_dtype, + ) else: - rewrite_program(self._train_program, self._amp_lists) + rewrite_program( + self._train_program, self._amp_lists, self._amp_dtype + ) if loss.dtype != core.VarDesc.VarType.FP32: loss = loss.astype('float32') @@ -205,10 +241,13 @@ class OptimizerWithMixedPrecision(object): else: self._scaled_loss = loss - params_grads = self._optimizer.backward(self._scaled_loss, - startup_program, - parameter_list, no_grad_set, - callbacks) + params_grads = self._optimizer.backward( + self._scaled_loss, + startup_program, + parameter_list, + no_grad_set, + callbacks, + ) if self._supports_check_nan_inf(): self._add_cast_ops_to_startup_program(startup_program) return params_grads @@ -216,8 +255,11 @@ class OptimizerWithMixedPrecision(object): def _add_cast_ops_to_startup_program(self, startup_program): names = list(self._to_fp16_var_names) if self._to_fp16_var_names else [] names.sort() - startup_program = default_startup_program( - ) if startup_program is None else startup_program + startup_program = ( + default_startup_program() + if startup_program is None + else startup_program + ) block = startup_program.global_block() param_names = [p.name for p in block.all_parameters()] for name in names: @@ -225,28 +267,28 @@ class OptimizerWithMixedPrecision(object): continue tmp = block.create_var(dtype=core.VarDesc.VarType.FP32) - block.append_op(type='assign', - inputs={'X': [name]}, - outputs={'Out': [tmp]}) - block.append_op(type='cast', - inputs={'X': [tmp]}, - outputs={'Out': [name]}, - attrs={ - 'in_dtype': core.VarDesc.VarType.FP32, - 'out_dtype': core.VarDesc.VarType.FP16, - }) + block.append_op( + type='assign', inputs={'X': [name]}, outputs={'Out': [tmp]} + ) + block.append_op( + type='cast', + inputs={'X': [tmp]}, + outputs={'Out': [name]}, + attrs={ + 'in_dtype': core.VarDesc.VarType.FP32, + 'out_dtype': self._amp_dtype, + }, + ) self._to_fp16_var_names = None - def amp_init(self, - place, - scope=None, - test_program=None, - use_fp16_test=False): + def amp_init( + self, place, scope=None, test_program=None, use_fp16_test=False + ): """ Init the amp training, such as cast fp32 parameters to fp16 type. - + Args: - place(CUDAPlace): place is used to initialize + place(CUDAPlace): place is used to initialize fp16 parameters with fp32 values. scope(Scope): The scope is used to find fp32 parameters. test_program(Program): The program is used for testing. @@ -273,7 +315,7 @@ class OptimizerWithMixedPrecision(object): loss = paddle.mean(hidden) # 2) Create the optimizer and set `multi_precision` to True. # Setting `multi_precision` to True can avoid the poor accuracy - # or the slow convergence in a way. + # or the slow convergence in a way. optimizer = paddle.optimizer.Momentum(learning_rate=0.01, multi_precision=True) # 3) These ops in `custom_black_list` will keep in the float32 computation type. 
amp_list = paddle.static.amp.CustomOpLists( @@ -293,30 +335,40 @@ class OptimizerWithMixedPrecision(object): # 5) Use `amp_init` after FP32 parameters initialization(such as `exe.run(startup_program)`). # If you want to perform the testing process, you should pass `test_program` into `amp_init`. optimizer.amp_init(place, scope=paddle.static.global_scope()) - + if paddle.is_compiled_with_cuda() and len(paddle.static.cuda_places()) > 0: - run_example_code() + run_example_code() """ - assert self._train_program is not None, \ - "Please call the minimize method first." + assert ( + self._train_program is not None + ), "Please call the minimize method first." if self._use_pure_fp16: - cast_parameters_to_fp16(place, self._train_program, scope, - self._to_fp16_var_names) + cast_parameters_to_fp16( + place, + self._train_program, + scope, + self._to_fp16_var_names, + self._amp_dtype, + ) if test_program is not None: if self._use_pure_fp16: - cast_model_to_fp16(test_program, self._amp_lists, - self._use_fp16_guard) + cast_model_to_fp16( + test_program, + self._amp_lists, + self._use_fp16_guard, + self._amp_dtype, + ) elif use_fp16_test: - rewrite_program(test_program, self._amp_lists) + rewrite_program(test_program, self._amp_lists, self._amp_dtype) def apply_gradients(self, params_grads): """ - Check scaled gradients to determine whether to update loss scaling and update + Check scaled gradients to determine whether to update loss scaling and update parameters by their scaled gradients. - + Args: params_grads (list): A list of params and scaled grads. - + Returns: A list of optimize operators. """ @@ -327,7 +379,10 @@ class OptimizerWithMixedPrecision(object): # When not using dynamic loss scaling and the init loss scaling value is equal to 1.0, # the model can be optimized. - if not self._use_dynamic_loss_scaling and self._init_loss_scaling == 1.0: + if ( + not self._use_dynamic_loss_scaling + and self._init_loss_scaling == 1.0 + ): return self._optimizer.apply_gradients(params_grads) if self._supports_check_nan_inf(): @@ -338,7 +393,10 @@ class OptimizerWithMixedPrecision(object): return optimize_ops found_inf = self._check_finite_and_unscale(params_grads) - if self._use_dynamic_loss_scaling: + if ( + self._use_dynamic_loss_scaling + and self._amp_dtype == core.VarDesc.VarType.FP16 + ): self._add_dynamic_loss_scaling(params_grads, found_inf) # Pass found_inf to adam, to skip update for not only param, but also momentum and beta_pow @@ -346,13 +404,16 @@ class OptimizerWithMixedPrecision(object): real_optimizer = self._optimizer while hasattr(real_optimizer, "inner_opt"): real_optimizer = real_optimizer.inner_opt - if isinstance(real_optimizer, - (paddle.fluid.optimizer.Adam, paddle.optimizer.AdamW)): + if isinstance( + real_optimizer, + (paddle.fluid.optimizer.Adam, paddle.optimizer.AdamW), + ): # NOTE(zhiqiu): Since found_inf needs to be on cpu in adam op, we # copy it in advance to avoid multiple time copies. 
with self._train_program._optimized_guard([]): found_inf = paddle.tensor.creation._memcpy( - found_inf, paddle.CPUPlace()) + found_inf, paddle.CPUPlace() + ) real_optimizer._set_auxiliary_var('found_inf', found_inf) elif hasattr(real_optimizer, "_set_auxiliary_var"): real_optimizer._set_auxiliary_var('found_inf', found_inf) @@ -362,9 +423,10 @@ class OptimizerWithMixedPrecision(object): def _split_grads(self, params_grads): grads = [g for _, g in params_grads] fp32_grads = [g for g in grads if g.dtype == core.VarDesc.VarType.FP32] - fp16_grads = [g for g in grads if g.dtype == core.VarDesc.VarType.FP16] - assert len(fp32_grads) + len(fp16_grads) == len(grads), \ - "Data types of all grads must be either fp16 or fp32." + fp16_grads = [g for g in grads if g.dtype == self._amp_dtype] + assert len(fp32_grads) + len(fp16_grads) == len( + grads + ), "Data types of all grads must be either fp16/bf16 or fp32." return grads, fp32_grads, fp16_grads def _check_finite_and_unscale(self, params_grads): @@ -380,7 +442,8 @@ class OptimizerWithMixedPrecision(object): grads, self._loss_scaling, name="find_infinite_scale", - float_status=self._float_status) + float_status=self._float_status, + ) found_infs.append(found_inf) else: for p, g in params_grads: @@ -391,7 +454,8 @@ class OptimizerWithMixedPrecision(object): ], self._loss_scaling, name="find_infinite_scale", - float_status=self._float_status) + float_status=self._float_status, + ) found_infs.append(found_inf) elif self._use_pure_fp16: if fp32_grads: @@ -400,7 +464,8 @@ class OptimizerWithMixedPrecision(object): fp32_grads, self._loss_scaling, name="find_infinite_scale_fp32", - float_status=self._float_status) + float_status=self._float_status, + ) found_infs.append(fp32_found_inf) if fp16_grads: with self._train_program._optimized_guard(fp16_grads): @@ -408,7 +473,8 @@ class OptimizerWithMixedPrecision(object): fp16_grads, self._loss_scaling, name="find_infinite_scale_fp16", - float_status=self._float_status) + float_status=self._float_status, + ) found_infs.append(fp16_found_inf) else: with self._train_program._optimized_guard(grads): @@ -416,7 +482,8 @@ class OptimizerWithMixedPrecision(object): grads, self._loss_scaling, name="find_infinite_scale", - float_status=self._float_status) + float_status=self._float_status, + ) if self._is_distributed or self._use_pure_fp16: with self._train_program._optimized_guard([]): @@ -439,7 +506,8 @@ class OptimizerWithMixedPrecision(object): self._incr_ratio, self._decr_ratio, stop_update=self._optimizer._get_stop_update_var(), - name="update_loss_scaling") + name="update_loss_scaling", + ) return grads, fp32_grads, fp16_grads = self._split_grads(params_grads) @@ -447,42 +515,48 @@ class OptimizerWithMixedPrecision(object): stop_update = False with self._train_program._optimized_guard([]): if fp32_grads: - update_loss_scaling(fp32_grads, - found_inf, - self._loss_scaling, - self._num_good_steps, - self._num_bad_steps, - self._incr_every_n_steps, - self._decr_every_n_nan_or_inf, - self._incr_ratio, - self._decr_ratio, - stop_update=stop_update, - name="update_loss_scaling_fp32") + update_loss_scaling( + fp32_grads, + found_inf, + self._loss_scaling, + self._num_good_steps, + self._num_bad_steps, + self._incr_every_n_steps, + self._decr_every_n_nan_or_inf, + self._incr_ratio, + self._decr_ratio, + stop_update=stop_update, + name="update_loss_scaling_fp32", + ) stop_update = True if fp16_grads: - update_loss_scaling(fp16_grads, - found_inf, - self._loss_scaling, - self._num_good_steps, - self._num_bad_steps, - 
self._incr_every_n_steps, - self._decr_every_n_nan_or_inf, - self._incr_ratio, - self._decr_ratio, - stop_update=stop_update, - name="update_loss_scaling_fp16") + update_loss_scaling( + fp16_grads, + found_inf, + self._loss_scaling, + self._num_good_steps, + self._num_bad_steps, + self._incr_every_n_steps, + self._decr_every_n_nan_or_inf, + self._incr_ratio, + self._decr_ratio, + stop_update=stop_update, + name="update_loss_scaling_fp16", + ) else: with self._train_program._optimized_guard([]): - update_loss_scaling(grads, - found_inf, - self._loss_scaling, - self._num_good_steps, - self._num_bad_steps, - self._incr_every_n_steps, - self._decr_every_n_nan_or_inf, - self._incr_ratio, - self._decr_ratio, - name="update_loss_scaling") + update_loss_scaling( + grads, + found_inf, + self._loss_scaling, + self._num_good_steps, + self._num_bad_steps, + self._incr_every_n_steps, + self._decr_every_n_nan_or_inf, + self._incr_ratio, + self._decr_ratio, + name="update_loss_scaling", + ) def apply_optimize(self, loss, startup_program, params_grads): program = loss.block.program @@ -490,11 +564,9 @@ class OptimizerWithMixedPrecision(object): optimize_ops = self.apply_gradients(params_grads) return optimize_ops - def minimize(self, - loss, - startup_program=None, - parameter_list=None, - no_grad_set=None): + def minimize( + self, loss, startup_program=None, parameter_list=None, no_grad_set=None + ): """ Perform optimization by minimizing the given loss. @@ -511,48 +583,55 @@ class OptimizerWithMixedPrecision(object): """ opt_dict = self._optimizer.__class__.__dict__ - if 'minimize' in opt_dict and isinstance(opt_dict['minimize'], - types.FunctionType): + if 'minimize' in opt_dict and isinstance( + opt_dict['minimize'], types.FunctionType + ): warnings.warn( "The decorated optimizer has its own `minimize` method, but it will not be executed." ) - scaled_params_grads = self.backward(loss, - startup_program=startup_program, - parameter_list=parameter_list, - no_grad_set=no_grad_set) + scaled_params_grads = self.backward( + loss, + startup_program=startup_program, + parameter_list=parameter_list, + no_grad_set=no_grad_set, + ) - optimize_ops = self.apply_optimize(loss, startup_program, - scaled_params_grads) + optimize_ops = self.apply_optimize( + loss, startup_program, scaled_params_grads + ) return optimize_ops, scaled_params_grads -def decorate(optimizer, - amp_lists=None, - init_loss_scaling=2**15, - incr_every_n_steps=1000, - decr_every_n_nan_or_inf=2, - incr_ratio=2.0, - decr_ratio=0.8, - use_dynamic_loss_scaling=True, - use_pure_fp16=False, - use_fp16_guard=None): - """ +def decorate( + optimizer, + amp_lists=None, + init_loss_scaling=2**15, + incr_every_n_steps=1000, + decr_every_n_nan_or_inf=2, + incr_ratio=2.0, + decr_ratio=0.8, + use_dynamic_loss_scaling=True, + use_pure_fp16=False, + use_fp16_guard=None, + use_bf16=False, +): + """ Decorate the given optimizer to adapt to the mixed-precision training. Args: optimizer(Optimizer): A common Optimizer. amp_lists (CustomOpLists): An CustomOpLists object. init_loss_scaling(float): The initial loss scaling factor. - incr_every_n_steps(int): Increases loss scaling every n consecutive + incr_every_n_steps(int): Increases loss scaling every n consecutive steps with finite gradients. - decr_every_n_nan_or_inf(int): Decreases loss scaling every n - accumulated steps with nan or + decr_every_n_nan_or_inf(int): Decreases loss scaling every n + accumulated steps with nan or inf gradients. 
- incr_ratio(float): The multiplier to use when increasing the loss + incr_ratio(float): The multiplier to use when increasing the loss scaling. - decr_ratio(float): The less-than-one-multiplier to use when decreasing + decr_ratio(float): The less-than-one-multiplier to use when decreasing the loss scaling. use_dynamic_loss_scaling(bool): Whether to use dynamic loss scaling. use_pure_fp16(bool): Whether to use the pure fp16 training. Default False. @@ -560,11 +639,11 @@ def decorate(optimizer, Default None, which means that its value equals to `use_pure_fp16`. Returns: - An optimizer acting like a normal one but with mixed-precision training + An optimizer acting like a normal one but with mixed-precision training enabled. Examples 1: - .. code-block:: python + .. code-block:: python # black&white list based strategy example import paddle @@ -604,7 +683,7 @@ def decorate(optimizer, loss = paddle.mean(hidden) # 2) Create the optimizer and set `multi_precision` to True. # Setting `multi_precision` to True can avoid the poor accuracy - # or the slow convergence in a way. + # or the slow convergence in a way. optimizer = paddle.optimizer.Momentum(learning_rate=0.01, multi_precision=True) # 3) These ops in `custom_black_list` will keep in the float32 computation type. amp_list = paddle.static.amp.CustomOpLists( @@ -624,19 +703,82 @@ def decorate(optimizer, # 5) Use `amp_init` after FP32 parameters initialization(such as `exe.run(startup_program)`). # If you want to perform the testing process, you should pass `test_program` into `amp_init`. optimizer.amp_init(place, scope=paddle.static.global_scope()) - + if paddle.is_compiled_with_cuda() and len(paddle.static.cuda_places()) > 0: run_example_code() """ + dtype = "bfloat16" if use_bf16 else "float16" if amp_lists is None: - amp_lists = AutoMixedPrecisionLists() + amp_lists = AutoMixedPrecisionLists(dtype=dtype) if use_fp16_guard is None: use_fp16_guard = use_pure_fp16 mp_optimizer = OptimizerWithMixedPrecision( - optimizer, amp_lists, init_loss_scaling, use_dynamic_loss_scaling, - incr_every_n_steps, decr_every_n_nan_or_inf, incr_ratio, decr_ratio, - use_pure_fp16, use_fp16_guard) + optimizer, + amp_lists, + init_loss_scaling, + use_dynamic_loss_scaling, + incr_every_n_steps, + decr_every_n_nan_or_inf, + incr_ratio, + decr_ratio, + use_pure_fp16, + use_fp16_guard, + use_bf16, + ) + + return mp_optimizer + + +def amp_decorate( + optimizer, + amp_lists=None, + level='O1', + dtype='float16', + init_loss_scaling=2**15, + incr_every_n_steps=1000, + decr_every_n_nan_or_inf=2, + incr_ratio=2.0, + decr_ratio=0.8, + use_dynamic_loss_scaling=True, + use_amp_guard=False, +): + """ + Decorate the given optimizer to adapt to the mixed-precision training. + """ + # check amp_dtype: float16 or bfloat16 + dtype = dtype.lower() + if not (dtype in ['float16', 'bfloat16']): + raise ValueError( + "If enable AMP, dtype should be 'float16' or 'bfloat16'." + ) + + if amp_lists is None: + amp_lists = AutoMixedPrecisionLists(dtype=dtype) + + # check amp_level: O0-O2 + level = level.upper() + if not (level in ['O0', 'O1', 'O2']): + raise ValueError( + "level should be O0, O1 or O2. O0 represents fp32 train mode, O1 represents AMP train mode, O2 represents pure fp16/bf16 train mode." 
+ ) + + use_pure_fp16 = level == "O2" + use_fp16_guard = use_amp_guard + use_bf16 = dtype == "bfloat16" + mp_optimizer = OptimizerWithMixedPrecision( + optimizer, + amp_lists, + init_loss_scaling, + use_dynamic_loss_scaling, + incr_every_n_steps, + decr_every_n_nan_or_inf, + incr_ratio, + decr_ratio, + use_pure_fp16, + use_fp16_guard, + use_bf16, + ) return mp_optimizer diff --git a/python/paddle/fluid/contrib/mixed_precision/fp16_lists.py b/python/paddle/fluid/contrib/mixed_precision/fp16_lists.py index b2767b1dd1c..e9728ed6c16 100644 --- a/python/paddle/fluid/contrib/mixed_precision/fp16_lists.py +++ b/python/paddle/fluid/contrib/mixed_precision/fp16_lists.py @@ -13,16 +13,47 @@ # limitations under the License. import copy + from ... import core __all__ = ["CustomOpLists", "AutoMixedPrecisionLists"] # lookup_table fp16 is slower than fp32, though fp16 is supported. -_extra_unsupported_fp16_list = { - 'lookup_table', 'lookup_table_v2', 'scatter', 'scatter_grad' +_extra_unsupported_list = { + 'lookup_table', + 'lookup_table_v2', + 'scatter', + 'scatter_grad', } +def _get_unsupported_list(dtype): + if dtype == "float16": + amp_dtype = core.VarDesc.VarType.FP16 + elif dtype == "bfloat16": + amp_dtype = core.VarDesc.VarType.BF16 + else: + raise ValueError( + "If enable AMP, dtype should be 'float16' or 'bfloat16'." + ) + + # The set of ops that don't support fp16 calculation + # lookup_table fp16 is slower than fp32, though fp16 is supported. + _sys_unsupported_list = [] + # _sys_unsupported_bf16_list = [] + if core.is_compiled_with_xpu(): + _, _, _sys_unsupported_list = core.op_supported_infos('XPU', amp_dtype) + elif core.is_compiled_with_npu(): + _, _, _sys_unsupported_list = core.op_supported_infos('NPU', amp_dtype) + elif core.is_compiled_with_mlu(): + _, _, _sys_unsupported_list = core.op_supported_infos('MLU', amp_dtype) + else: + _, _, _sys_unsupported_list = core.op_supported_infos('GPU', amp_dtype) + + unsupported_list = _extra_unsupported_list | _sys_unsupported_list + return unsupported_list + + class AutoMixedPrecisionLists(object): """ AutoMixedPrecisionLists is a class for black/white list. It can update @@ -36,16 +67,20 @@ class AutoMixedPrecisionLists(object): custom_black_varnames (set): Users' custom black varibles' names. 
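With amp_decorate in place, a static-graph pure-bf16 (O2) run can be set up as in the hedged sketch below. The network, shapes, and the CUDA place are illustrative; the import path follows the package __init__ change at the top of this patch, and bf16 kernel support on the target device is assumed.

import paddle
from paddle.fluid.contrib.mixed_precision import amp_decorate

paddle.enable_static()
main_prog, startup_prog = paddle.static.Program(), paddle.static.Program()
with paddle.static.program_guard(main_prog, startup_prog):
    x = paddle.static.data(name='x', shape=[None, 32], dtype='float32')
    hidden = paddle.static.nn.fc(x=x, size=64)
    loss = paddle.mean(hidden)
    # multi_precision=True keeps fp32 master weights, as in the fp16 examples.
    opt = paddle.optimizer.Momentum(learning_rate=0.01, multi_precision=True)
    # O2 + bfloat16: the wrapper casts the program to bf16 and, per the
    # constructor above, uses a constant loss scale of 1.0.
    opt = amp_decorate(opt, level='O2', dtype='bfloat16')
    opt.minimize(loss)

place = paddle.CUDAPlace(0)
exe = paddle.static.Executor(place)
exe.run(startup_prog)
# Convert the fp32 parameters in the scope after initialization.
opt.amp_init(place, scope=paddle.static.global_scope())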
""" - def __init__(self, - custom_white_list=None, - custom_black_list=None, - custom_black_varnames=None): + def __init__( + self, + custom_white_list=None, + custom_black_list=None, + custom_black_varnames=None, + dtype="float16", + ): self._custom_white_list = custom_white_list self._custom_black_list = custom_black_list + self.amp_dtype = dtype self.white_list = copy.copy(white_list) self.black_list = copy.copy(black_list) self.gray_list = copy.copy(gray_list) - self.unsupported_list = copy.copy(unsupported_fp16_list) + self.unsupported_list = copy.copy(_get_unsupported_list(self.amp_dtype)) self.black_varnames = copy.copy(custom_black_varnames) self._update_list() @@ -56,8 +91,9 @@ class AutoMixedPrecisionLists(object): if self._custom_white_list and self._custom_black_list: for op_name in self._custom_white_list: if op_name in self._custom_black_list: - raise ValueError("Custom white list overlap " - "custom black list") + raise ValueError( + "Custom white list overlap " "custom black list" + ) if self._custom_white_list: for op_name in self._custom_white_list: if op_name in self.black_list: @@ -65,7 +101,7 @@ class AutoMixedPrecisionLists(object): elif op_name in self.gray_list: self.gray_list.remove(op_name) self.white_list.add(op_name) - if op_name in _extra_unsupported_fp16_list: + if op_name in _extra_unsupported_list: self.unsupported_list.remove(op_name) if self._custom_black_list: for op_name in self._custom_black_list: @@ -170,22 +206,4 @@ gray_list = { 'fused_multi_transformer', } -# The set of ops that don't support fp16 calculation -# lookup_table fp16 is slower than fp32, though fp16 is supported. -_sys_unsupported_fp16_list = [] -if core.is_compiled_with_xpu(): - _, _, _sys_unsupported_fp16_list = core.op_supported_infos( - 'XPU', core.VarDesc.VarType.FP16) -elif core.is_compiled_with_npu(): - _, _, _sys_unsupported_fp16_list = core.op_supported_infos( - 'NPU', core.VarDesc.VarType.FP16) -elif core.is_compiled_with_mlu(): - _, _, _sys_unsupported_fp16_list = core.op_supported_infos( - 'MLU', core.VarDesc.VarType.FP16) -else: - _, _, _sys_unsupported_fp16_list = core.op_supported_infos( - 'GPU', core.VarDesc.VarType.FP16) - -unsupported_fp16_list = _extra_unsupported_fp16_list | _sys_unsupported_fp16_list - CustomOpLists = AutoMixedPrecisionLists diff --git a/python/paddle/fluid/contrib/mixed_precision/fp16_utils.py b/python/paddle/fluid/contrib/mixed_precision/fp16_utils.py index b23c94c7e49..5b6fa74ed64 100644 --- a/python/paddle/fluid/contrib/mixed_precision/fp16_utils.py +++ b/python/paddle/fluid/contrib/mixed_precision/fp16_utils.py @@ -14,6 +14,7 @@ from __future__ import print_function +import paddle from ... import core from ... import framework from ... import layers @@ -27,13 +28,14 @@ import numpy as np __all__ = ["fp16_guard", "cast_model_to_fp16", "cast_parameters_to_fp16"] -_logger = get_logger(__name__, - logging.INFO, - fmt='%(asctime)s-%(levelname)s: %(message)s') +_logger = get_logger( + __name__, logging.INFO, fmt='%(asctime)s-%(levelname)s: %(message)s' +) _valid_types = [ - core.VarDesc.VarType.LOD_TENSOR, core.VarDesc.VarType.SELECTED_ROWS, - core.VarDesc.VarType.LOD_TENSOR_ARRAY + core.VarDesc.VarType.LOD_TENSOR, + core.VarDesc.VarType.SELECTED_ROWS, + core.VarDesc.VarType.LOD_TENSOR_ARRAY, ] _fp16_guard_pattern = "__use_fp16__" @@ -75,7 +77,9 @@ def _dtype_to_str(dtype): Args: dtype (VarType): Variable type. 
""" - if dtype == core.VarDesc.VarType.FP16: + if dtype in [core.VarDesc.VarType.FP16, core.VarDesc.VarType.BF16]: + # TODO(Xreki): change the returned str to "bf16" for BF16 data type. + # Currently too many codes use "cast_fp16" as key. return 'fp16' else: return 'fp32' @@ -108,7 +112,12 @@ def _keep_fp32_input(op, in_name): return in_name not in {'X', 'FilterX', 'Z', 'FilterZ'} if op_type in ['fused_attention', 'fused_feedforward']: return in_name in { - 'LnScale', 'LnBias', 'Ln2Scale', 'Ln2Bias', "Ln1Scale", "Ln1Bias" + 'LnScale', + 'LnBias', + 'Ln2Scale', + 'Ln2Bias', + "Ln1Scale", + "Ln1Bias", } if op_type == 'fused_multi_transformer': return in_name in {'LnScale', 'LnBias', 'FFNLnScale', 'FFNLnBias'} @@ -125,8 +134,12 @@ def _keep_fp32_output(op, out_name): return out_name not in {'Y', 'ConvX', 'ConvZ'} if op_type in ['fused_attention', 'fused_feedforward']: return out_name in { - 'LnMean', 'LnVariance', 'Ln2Mean', 'Ln2Variance', 'Ln1Mean', - 'Ln1Variance' + 'LnMean', + 'LnVariance', + 'Ln2Mean', + 'Ln2Variance', + 'Ln1Mean', + 'Ln1Variance', } return False @@ -149,7 +162,8 @@ def _insert_cast_op(block, op, idx, src_dtype, dest_dtype): for in_name in op.input_names: if src_dtype == core.VarDesc.VarType.FP32 and _keep_fp32_input( - op, in_name): + op, in_name + ): continue for in_var_name in op.input(in_name): in_var = block._find_var_recursive(in_var_name) @@ -165,11 +179,15 @@ def _insert_cast_op(block, op, idx, src_dtype, dest_dtype): # set cast_op device to `all`, can reduce send cast_var. # TODO: need remove this after we unified the dynamic # and static pipeline interface. - if src_dtype == core.VarDesc.VarType.FP32 and in_var.stop_gradient: + if ( + src_dtype == core.VarDesc.VarType.FP32 + and in_var.stop_gradient + ): prev_op = None if in_var.op is op: - prev_op = find_true_prev_op(block.ops, op, - in_var_name) + prev_op = find_true_prev_op( + block.ops, op, in_var_name + ) elif in_var.op is not None: prev_op = in_var.op @@ -177,33 +195,40 @@ def _insert_cast_op(block, op, idx, src_dtype, dest_dtype): if prev_op is not None: prev_op_device = prev_op.attr('op_device') - if prev_op_device is not None and 'all' in prev_op_device: + if ( + prev_op_device is not None + and 'all' in prev_op_device + ): op_device = prev_op_device out_var = block.create_var( name=cast_name, dtype=dest_dtype, persistable=False, - stop_gradient=in_var.stop_gradient) - - block._insert_op_without_sync(idx, - type="cast", - inputs={"X": in_var}, - outputs={"Out": out_var}, - attrs={ - "in_dtype": in_var.dtype, - "out_dtype": - out_var.dtype, - "op_device": op_device, - "op_role": - op.attr("op_role"), - }) + stop_gradient=in_var.stop_gradient, + ) + + block._insert_op_without_sync( + idx, + type="cast", + inputs={"X": in_var}, + outputs={"Out": out_var}, + attrs={ + "in_dtype": in_var.dtype, + "out_dtype": out_var.dtype, + "op_device": op_device, + "op_role": op.attr("op_role"), + }, + ) num_cast_ops += 1 _rename_arg(op, in_var.name, out_var.name) else: if op.has_attr('in_dtype'): op._set_attr('in_dtype', dest_dtype) - if src_dtype == core.VarDesc.VarType.FP32 and dest_dtype == core.VarDesc.VarType.FP16: + if src_dtype == core.VarDesc.VarType.FP32 and dest_dtype in [ + core.VarDesc.VarType.FP16, + core.VarDesc.VarType.BF16, + ]: for out_name in op.output_names: if _keep_fp32_output(op, out_name): continue @@ -212,41 +237,48 @@ def _insert_cast_op(block, op, idx, src_dtype, dest_dtype): if out_var.type not in _valid_types: continue if out_var.dtype == core.VarDesc.VarType.FP32: - 
out_var.desc.set_dtype(core.VarDesc.VarType.FP16) + out_var.desc.set_dtype(dest_dtype) if op.has_attr('out_dtype'): - op._set_attr('out_dtype', core.VarDesc.VarType.FP16) + op._set_attr('out_dtype', dest_dtype) return num_cast_ops -def _insert_cast_post_op(block, op, idx, src_dtype, dest_dtype, target_name, - op_var_rename_map): +def _insert_cast_post_op( + block, op, idx, src_dtype, dest_dtype, target_name, op_var_rename_map +): num_cast_ops = 0 target_var = block.var(target_name) if target_var.type not in _valid_types or target_var.dtype == dest_dtype: return num_cast_ops - assert target_var.dtype == src_dtype, \ - "The real dtype({}) is not equal to the src dtype({})".format( - _dtype_to_str(target_var.dtype), _dtype_to_str(src_dtype)) + assert ( + target_var.dtype == src_dtype + ), "The real dtype({}) is not equal to the src dtype({})".format( + _dtype_to_str(target_var.dtype), _dtype_to_str(src_dtype) + ) cast_name = target_var.name + '.cast_' + _dtype_to_str(dest_dtype) cast_var = block.vars.get(cast_name) if cast_var is None or cast_var.dtype != dest_dtype: - cast_var = block.create_var(name=cast_name, - dtype=dest_dtype, - persistable=False, - stop_gradient=target_var.stop_gradient) - block._insert_op(idx, - type="cast", - inputs={"X": target_var}, - outputs={"Out": cast_var}, - attrs={ - "in_dtype": target_var.dtype, - "out_dtype": cast_var.dtype, - "op_device": op.attr("op_device"), - "op_role": op.attr("op_role"), - }) + cast_var = block.create_var( + name=cast_name, + dtype=dest_dtype, + persistable=False, + stop_gradient=target_var.stop_gradient, + ) + block._insert_op( + idx, + type="cast", + inputs={"X": target_var}, + outputs={"Out": cast_var}, + attrs={ + "in_dtype": target_var.dtype, + "out_dtype": cast_var.dtype, + "op_device": op.attr("op_device"), + "op_role": op.attr("op_role"), + }, + ) num_cast_ops += 1 op_var_rename_map[block.idx][target_var.name] = cast_var.name @@ -272,8 +304,10 @@ def find_true_prev_op(ops, cur_op, var_name): prev_op.append(op) if prev_op: if not len(prev_op) == 1: - raise ValueError("There must be only one previous op " - "that outputs {0} variable".format(var_name)) + raise ValueError( + "There must be only one previous op " + "that outputs {0} variable".format(var_name) + ) else: return prev_op[0] return None @@ -315,8 +349,7 @@ def find_true_post_op(ops, cur_op, var_name, search_all=False): def find_op_index(block_desc, cur_op_desc): - """ - """ + """ """ for idx in range(block_desc.op_size()): if cur_op_desc == block_desc.op(idx): return idx @@ -350,8 +383,9 @@ def _need_keep_fp32(op, unsupported_op_list, use_fp16_guard): return True if use_fp16_guard: - if op.has_attr("op_namescope") and \ - (_fp16_guard_pattern in op.attr("op_namescope")): + if op.has_attr("op_namescope") and ( + _fp16_guard_pattern in op.attr("op_namescope") + ): # op in fp16 guard return False else: @@ -388,7 +422,12 @@ def fp16_guard(): yield -def cast_model_to_fp16(program, amp_lists=None, use_fp16_guard=True): +def cast_model_to_fp16( + program, + amp_lists=None, + use_fp16_guard=True, + dest_type=core.VarDesc.VarType.FP16, +): """ Traverse all ops in the whole model and set their inputs and outputs to the fp16 data type. This function will do some special process for @@ -399,6 +438,7 @@ def cast_model_to_fp16(program, amp_lists=None, use_fp16_guard=True): amp_lists (AutoMixedPrecisionLists): An AutoMixedPrecisionLists object. use_fp16_guard(bool): Determine whether to use `fp16_guard` when constructing the program. Default True. 
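cast_model_to_fp16 can also be driven directly with a bf16 destination type, mirroring how the decorator calls it. A hedged sketch; the toy program is illustrative and the amp lists are built for bfloat16 so that unsupported ops are resolved for the right dtype:

import paddle
from paddle.fluid import core
from paddle.fluid.contrib.mixed_precision.fp16_lists import AutoMixedPrecisionLists
from paddle.fluid.contrib.mixed_precision.fp16_utils import cast_model_to_fp16

paddle.enable_static()
prog = paddle.static.Program()
with paddle.static.program_guard(prog):
    x = paddle.static.data(name='x', shape=[None, 16], dtype='float32')
    y = paddle.static.nn.fc(x=x, size=8)

# Rewrite the program in place to bf16; the returned names are the fp32 vars
# whose dtype was switched, which cast_parameters_to_fp16 later converts in
# the scope.
amp_lists = AutoMixedPrecisionLists(dtype='bfloat16')
to_cast_names = cast_model_to_fp16(
    prog,
    amp_lists,
    use_fp16_guard=False,
    dest_type=core.VarDesc.VarType.BF16,
)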
+ dest_type(core.VarDesc.VarType): the cast type. such as core.VarDesc.VarType.FP16 and core.VarDesc.VarType.BF16. """ if amp_lists is None: @@ -421,7 +461,8 @@ def cast_model_to_fp16(program, amp_lists=None, use_fp16_guard=True): for in_name in op.input_names: # for ipu, all inputs must be converted to fp16 if not core.is_compiled_with_ipu() and _keep_fp32_input( - op, in_name): + op, in_name + ): continue for in_var_name in op.input(in_name): in_var = None @@ -429,29 +470,36 @@ def cast_model_to_fp16(program, amp_lists=None, use_fp16_guard=True): in_var = block.var(in_var_name) except ValueError as e: _logger.debug( - "-- {}, try to get it in the global block --". - format(e)) + "-- {}, try to get it in the global block --".format( + e + ) + ) in_var = global_block.var(in_var_name) if in_var is not None: _logger.debug( - "-- var {} is got in the global block --". - format(in_var_name)) + "-- var {} is got in the global block --".format( + in_var_name + ) + ) if in_var is None or in_var.type not in _valid_types: continue if in_var.dtype == core.VarDesc.VarType.FP32: - in_var.desc.set_dtype(core.VarDesc.VarType.FP16) + in_var.desc.set_dtype(dest_type) to_fp16_var_names.add(in_var_name) _logger.debug( - "-- op type: {}, in var name: {}, in var dtype: {} --". - format(op.type, in_var_name, in_var.dtype)) + "-- op type: {}, in var name: {}, in var dtype: {} --".format( + op.type, in_var_name, in_var.dtype + ) + ) for out_name in op.output_names: # for ipu, all outputs must be converted to fp16 if not core.is_compiled_with_ipu() and _keep_fp32_output( - op, out_name): + op, out_name + ): continue for out_var_name in op.output(out_name): out_var = None @@ -459,32 +507,35 @@ def cast_model_to_fp16(program, amp_lists=None, use_fp16_guard=True): out_var = block.var(out_var_name) except ValueError as e: _logger.debug( - "-- {}, try to get it in the global block --". - format(e)) + "-- {}, try to get it in the global block --".format( + e + ) + ) out_var = global_block.var(out_var_name) if out_var is not None: _logger.debug( - "-- var {} is got in the global block --". 
- format(out_var_name)) + "-- var {} is got in the global block --".format( + out_var_name + ) + ) if out_var is None or out_var.type not in _valid_types: continue if out_var.dtype == core.VarDesc.VarType.FP32: - out_var.desc.set_dtype(core.VarDesc.VarType.FP16) + out_var.desc.set_dtype(dest_type) _logger.debug( - "-- op type: {}, out var name: {}, out var dtype: {} --" - .format(op.type, out_var_name, out_var.dtype)) - if op.has_attr('in_dtype') and op.attr( - 'in_dtype') == core.VarDesc.VarType.FP32: - op._set_attr('in_dtype', core.VarDesc.VarType.FP16) - if op.has_attr('out_dtype') and op.attr( - 'out_dtype') == core.VarDesc.VarType.FP32: - op._set_attr('out_dtype', core.VarDesc.VarType.FP16) - if op.has_attr('dtype') and op.attr( - 'dtype') == core.VarDesc.VarType.FP32: - op._set_attr('dtype', core.VarDesc.VarType.FP16) + "-- op type: {}, out var name: {}, out var dtype: {} --".format( + op.type, out_var_name, out_var.dtype + ) + ) + for attr_name in ['in_dtype', 'out_dtype', 'dtype']: + if ( + op.has_attr(attr_name) + and op.attr(attr_name) == core.VarDesc.VarType.FP32 + ): + op._set_attr(attr_name, dest_type) # process ops in keep_fp32_ops op_var_rename_map = [ @@ -497,25 +548,29 @@ def cast_model_to_fp16(program, amp_lists=None, use_fp16_guard=True): op = ops[idx] num_cast_ops = 0 if op in keep_fp32_ops: - pre_cast_num = _insert_cast_op(block, op, idx, - core.VarDesc.VarType.FP16, - core.VarDesc.VarType.FP32) + pre_cast_num = _insert_cast_op( + block, op, idx, dest_type, core.VarDesc.VarType.FP32 + ) num_cast_ops += pre_cast_num for out_var_name in op.output_arg_names: out_var = block.vars.get(out_var_name) if out_var is None or out_var.type not in _valid_types: continue - if out_var.dtype == core.VarDesc.VarType.FP16: + if out_var.dtype == dest_type: out_var.desc.set_dtype(core.VarDesc.VarType.FP32) post_ops = find_true_post_op(ops, op, out_var_name) for post_op in post_ops: if post_op in keep_fp32_ops: continue post_cast_num = _insert_cast_post_op( - block, op, idx + pre_cast_num + 1, + block, + op, + idx + pre_cast_num + 1, core.VarDesc.VarType.FP32, - core.VarDesc.VarType.FP16, out_var_name, - op_var_rename_map) + dest_type, + out_var_name, + op_var_rename_map, + ) num_cast_ops += post_cast_num idx += num_cast_ops + 1 @@ -523,7 +578,22 @@ def cast_model_to_fp16(program, amp_lists=None, use_fp16_guard=True): return to_fp16_var_names -def cast_parameters_to_fp16(place, program, scope=None, to_fp16_var_names=None): +def _convert_float_to_bfloat16(place, fp32_array): + paddle.disable_static() + framework._set_expected_place(place) + fp32_tensor = paddle.to_tensor(fp32_array) + bf16_array = paddle.cast(fp32_tensor, paddle.bfloat16).numpy() + paddle.enable_static() + return bf16_array + + +def cast_parameters_to_fp16( + place, + program, + scope=None, + to_fp16_var_names=None, + dest_type=core.VarDesc.VarType.FP16, +): """ Traverse all parameters in the whole model and set them to the FP16 data type. Whereas, this function will keep parameters of batchnorms in FP32. @@ -535,6 +605,7 @@ def cast_parameters_to_fp16(place, program, scope=None, to_fp16_var_names=None): to_fp16_var_names(set|list, optional): The data types of vars in `to_fp16_var_names` will be set to FP16. Usually, it is the returned value of `cast_model_to_fp16` API. + dest_type(core.VarDesc.VarType): the cast type. such as core.VarDesc.VarType.FP16 and core.VarDesc.VarType.BF16. 
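The helper above round-trips through a dygraph cast because NumPy has no native bfloat16 dtype; the resulting array holds bf16 bit patterns as uint16, which is also why 'uint16' (Paddle's NumPy-facing alias for bf16) appears in the dtype check lists added earlier in this patch. A hedged illustration of that representation; plain truncation is shown here, whereas the real cast kernel may round to nearest.

import numpy as np

fp32 = np.array([1.0, 3.14159, -2.5e-3], dtype=np.float32)
# bfloat16 keeps the top 16 bits of the float32 pattern: the sign bit, the
# full 8-bit exponent, and a 7-bit mantissa.
bf16_bits = (fp32.view(np.uint32) >> 16).astype(np.uint16)
# Padding the low 16 bits with zeros recovers a float32 that matches the
# original to roughly 3 significant decimal digits.
restored = (bf16_bits.astype(np.uint32) << 16).view(np.float32)
print(bf16_bits.dtype, restored.dtype)  # uint16 float32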
""" all_parameters = [] for block in program.blocks: @@ -544,13 +615,22 @@ def cast_parameters_to_fp16(place, program, scope=None, to_fp16_var_names=None): var_scope = scope if scope else global_scope() for param in all_parameters: if param.name in fp16_var_names: - _logger.debug("---- cast {} to fp16 dtype ----".format(param.name)) - param_t = var_scope.find_var(param.name).get_tensor() - data = np.array(param_t) - param_t.set(np.float16(data), place) + _logger.debug( + "---- cast {} to fp16/bf16 dtype ----".format(param.name) + ) + if var_scope.find_var(param.name): + param_t = var_scope.find_var(param.name).get_tensor() + data = np.array(param_t) + if dest_type == core.VarDesc.VarType.BF16: + bf16_data = _convert_float_to_bfloat16(place, data) + param_t.set(bf16_data, place) + else: + param_t.set(np.float16(data), place) + else: + _logger.warning(f"Cannot find {param.name}") -def rewrite_program(main_prog, amp_lists): +def rewrite_program(main_prog, amp_lists, dest_type=core.VarDesc.VarType.FP16): """ Traverse all ops in current block and insert cast op according to which set current op belongs to. @@ -569,6 +649,7 @@ def rewrite_program(main_prog, amp_lists): Args: main_prog (Program): The main program for training. + dest_type(core.VarDesc.VarType): the cast type. such as core.VarDesc.VarType.FP16 and core.VarDesc.VarType.BF16. """ block = main_prog.global_block() block._sync_with_cpp() @@ -585,7 +666,8 @@ def rewrite_program(main_prog, amp_lists): continue if amp_lists.black_varnames is not None and _is_in_black_varnames( - op, amp_lists): + op, amp_lists + ): black_op_set.add(op) continue @@ -611,11 +693,15 @@ def rewrite_program(main_prog, amp_lists): else: prev_op = in_var.op # if it's one of inputs - if prev_op in black_op_set or \ - prev_op.type in amp_lists.black_list: + if ( + prev_op in black_op_set + or prev_op.type in amp_lists.black_list + ): is_black_op = True - elif prev_op in white_op_set or \ - prev_op.type in amp_lists.white_list: + elif ( + prev_op in white_op_set + or prev_op.type in amp_lists.white_list + ): is_white_op = True if is_black_op: black_op_set.add(op) @@ -633,13 +719,13 @@ def rewrite_program(main_prog, amp_lists): op = ops[idx] num_cast_ops = 0 if op in black_op_set: - num_cast_ops = _insert_cast_op(block, op, idx, - core.VarDesc.VarType.FP16, - core.VarDesc.VarType.FP32) + num_cast_ops = _insert_cast_op( + block, op, idx, dest_type, core.VarDesc.VarType.FP32 + ) elif op in white_op_set: - num_cast_ops = _insert_cast_op(block, op, idx, - core.VarDesc.VarType.FP32, - core.VarDesc.VarType.FP16) + num_cast_ops = _insert_cast_op( + block, op, idx, core.VarDesc.VarType.FP32, dest_type + ) else: pass @@ -670,13 +756,16 @@ def update_role_var_grad(main_prog, params_grads): if role & int(BACKWARD) and op.has_attr('op_role_var'): op._remove_attr("op_role_var") else: - raise ValueError("The cast op {0} must be in BACKWARD role " - "and have op_role_var attr.".format(op)) + raise ValueError( + "The cast op {0} must be in BACKWARD role " + "and have op_role_var attr.".format(op) + ) fp16_grad_name = op.input(op.input_names[0])[0] op_for_fp16_grad = find_true_prev_op(block.ops, op, fp16_grad_name) - op_role_var_attr_name = \ + op_role_var_attr_name = ( core.op_proto_and_checker_maker.kOpRoleVarAttrName() + ) attr_val = [p.name, fp16_grad_name] if op_for_fp16_grad.has_attr(op_role_var_attr_name): attr_val.extend(op_for_fp16_grad.attr(op_role_var_attr_name)) @@ -690,18 +779,22 @@ def update_role_var_grad(main_prog, params_grads): continue post_ops = 
find_true_post_op(block.ops, op, g.name) if post_ops: - raise ValueError("The cast op {0}'s output should not be" - "used by a non-optimize op, however, it" - "is used by {1}".format(op, post_ops[0])) + raise ValueError( + "The cast op {0}'s output should not be" + "used by a non-optimize op, however, it" + "is used by {1}".format(op, post_ops[0]) + ) # add new op in the python and cpp at the same time new_op_desc = block.desc.append_op() new_op_desc.copy_from(op.desc) - new_op = framework.Operator(block=block, - desc=new_op_desc, - type=None, - inputs=None, - outputs=None, - attrs=None) + new_op = framework.Operator( + block=block, + desc=new_op_desc, + type=None, + inputs=None, + outputs=None, + attrs=None, + ) block.ops.append(new_op) op_idx = find_op_index(block.desc, op.desc) if op_idx == -1: diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/partial_program.py b/python/paddle/fluid/dygraph/dygraph_to_static/partial_program.py index 28053f00be9..d0a76bd6391 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/partial_program.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/partial_program.py @@ -18,19 +18,32 @@ import six import paddle from paddle.fluid import framework, backward, core, program_guard -from paddle.fluid.executor import _is_enable_standalone_executor, _is_dy2st_enable_standalone_executor +from paddle.fluid.executor import ( + _is_enable_standalone_executor, + _is_dy2st_enable_standalone_executor, +) from paddle.fluid.dygraph import layers from paddle.fluid.dygraph.base import switch_to_static_graph from paddle.fluid.dygraph.dygraph_to_static import logging_utils -from paddle.fluid.dygraph.dygraph_to_static.return_transformer import RETURN_NO_VALUE_MAGIC_NUM +from paddle.fluid.dygraph.dygraph_to_static.return_transformer import ( + RETURN_NO_VALUE_MAGIC_NUM, +) from paddle.fluid.layers.utils import flatten from paddle.fluid.layers.utils import pack_sequence_as from paddle.fluid.layers.utils import _hash_with_id from paddle.fluid.compiler import BuildStrategy from paddle.fluid.framework import _apply_pass -from paddle.fluid.contrib.mixed_precision.decorator import AutoMixedPrecisionLists -from paddle.fluid.contrib.mixed_precision.fp16_utils import rewrite_program, cast_model_to_fp16 -from paddle.fluid.dygraph.amp.auto_cast import _in_amp_guard, _in_pure_fp16_guard +from paddle.fluid.contrib.mixed_precision.decorator import ( + AutoMixedPrecisionLists, +) +from paddle.fluid.contrib.mixed_precision.fp16_utils import ( + rewrite_program, + cast_model_to_fp16, +) +from paddle.fluid.dygraph.amp.auto_cast import ( + _in_amp_guard, + _in_pure_fp16_guard, +) import paddle.compat as cpt from paddle import _C_ops, _legacy_C_ops @@ -64,7 +77,8 @@ class NestSequence(object): var_ids = [] for idx, var in enumerate(self.__input_list): if isinstance( - var, (framework.Variable, core.VarBase, core.eager.Tensor)): + var, (framework.Variable, core.VarBase, core.eager.Tensor) + ): var_ids.append(idx) return var_ids @@ -77,15 +91,17 @@ class NestSequence(object): warning_types = set() for var in self.__input_list: if not isinstance( - var, - (framework.Variable, core.VarBase, core.eager.Tensor)): + var, (framework.Variable, core.VarBase, core.eager.Tensor) + ): warning_types.add(type(var)) if warning_types: logging_utils.warn( "Output of traced function contains non-tensor type values: {}. " "Currently, We don't support to update them while training and will return " - "what we first saw. Please try to return them as tensor.". 
- format(list(warning_types))) + "what we first saw. Please try to return them as tensor.".format( + list(warning_types) + ) + ) @property def var_ids(self): @@ -139,12 +155,9 @@ class PartialProgramLayer: Layer: A Layer object that run all ops internally in static mode. """ - def __init__(self, - main_program, - inputs, - outputs, - parameters=None, - **kwargs): + def __init__( + self, main_program, inputs, outputs, parameters=None, **kwargs + ): super(PartialProgramLayer, self).__init__() self._inputs = NestSequence(inputs) self._outputs = NestSequence(outputs, need_check=True) @@ -160,14 +173,18 @@ class PartialProgramLayer: # Set default mode to train self.training = True - custom_white_list, custom_black_list = None, None + amp_dtype, custom_white_list, custom_black_list = None, None, None tracer = framework._dygraph_tracer() if tracer: custom_white_list, custom_black_list = tracer._get_amp_op_list() - # For AMP training - self._amp_list = AutoMixedPrecisionLists( - custom_white_list=custom_white_list, - custom_black_list=custom_black_list) + amp_dtype = tracer._amp_dtype + if amp_dtype is not None and amp_dtype in ['float16', 'bfloat16']: + # For AMP training + self._amp_list = AutoMixedPrecisionLists( + custom_white_list=custom_white_list, + custom_black_list=custom_black_list, + dtype=amp_dtype, + ) # program_id -> list(scope) self._scope_cache = {} @@ -203,7 +220,8 @@ class PartialProgramLayer: return self._origin_main_program.clone(for_test=is_infer_mode) else: train_program = self._append_backward_desc( - self._origin_main_program) + self._origin_main_program + ) # Note: Only set grad type once after initializing train program. So we put it here. self._set_grad_type(self._params, train_program) return train_program @@ -223,16 +241,18 @@ class PartialProgramLayer: @switch_to_static_graph def _create_pure_fp16_program(self, is_infer_mode=False): pure_fp16_program = self._origin_main_program.clone( - for_test=is_infer_mode) + for_test=is_infer_mode + ) with program_guard(pure_fp16_program): - cast_model_to_fp16(pure_fp16_program, - self._amp_list, - use_fp16_guard=False) + cast_model_to_fp16( + pure_fp16_program, self._amp_list, use_fp16_guard=False + ) if is_infer_mode: return pure_fp16_program else: train_pure_fp16_program = self._append_backward_desc( - pure_fp16_program) + pure_fp16_program + ) self._set_grad_type(self._params, train_pure_fp16_program) return train_pure_fp16_program @@ -240,23 +260,27 @@ class PartialProgramLayer: def _create_forward_backward_train_program(self): whole_program = self._create_program() forward_end_op_index = self._infer_program.desc.block(0).op_size() - return self._get_forward_backward_program_form(whole_program, - forward_end_op_index) + return self._get_forward_backward_program_form( + whole_program, forward_end_op_index + ) @switch_to_static_graph def _create_forward_backward_train_amp_program(self): whole_program = self._create_amp_program() forward_end_op_index = self._infer_amp_program.desc.block(0).op_size() - return self._get_forward_backward_program_form(whole_program, - forward_end_op_index) + return self._get_forward_backward_program_form( + whole_program, forward_end_op_index + ) @switch_to_static_graph def _create_forward_backward_train_pure_fp16_program(self): whole_program = self._create_pure_fp16_program() forward_end_op_index = self._infer_pure_fp16_program.desc.block( - 0).op_size() - return self._get_forward_backward_program_form(whole_program, - forward_end_op_index) + 0 + ).op_size() + return 
self._get_forward_backward_program_form( + whole_program, forward_end_op_index + ) @LazyInitialized def _train_program(self): @@ -352,8 +376,9 @@ class PartialProgramLayer: @LazyInitialized def _train_program_id(self): program_id = _hash_with_id(self._train_program, self) - core._set_cached_executor_build_strategy(program_id, - self._build_strategy) + core._set_cached_executor_build_strategy( + program_id, self._build_strategy + ) return program_id @LazyInitialized @@ -363,8 +388,9 @@ class PartialProgramLayer: @LazyInitialized def _train_amp_program_id(self): program_id = _hash_with_id(self._train_amp_program, self) - core._set_cached_executor_build_strategy(program_id, - self._build_strategy) + core._set_cached_executor_build_strategy( + program_id, self._build_strategy + ) return program_id @LazyInitialized @@ -374,8 +400,9 @@ class PartialProgramLayer: @LazyInitialized def _train_pure_fp16_program_id(self): program_id = _hash_with_id(self._train_pure_fp16_program, self) - core._set_cached_executor_build_strategy(program_id, - self._build_strategy) + core._set_cached_executor_build_strategy( + program_id, self._build_strategy + ) return program_id @LazyInitialized @@ -411,8 +438,9 @@ class PartialProgramLayer: return main_program - def prepare_gradient_aggregation(self, start_idx, main_program, - target_program): + def prepare_gradient_aggregation( + self, start_idx, main_program, target_program + ): """ Why we need add gradient aggregation operation ? In some cases, if non leaf nodes are used as output, gradient overwriting will occur, such as @@ -420,7 +448,7 @@ class PartialProgramLayer: x = 2 * in # <---- x is a non-leaf node in program. y = x + 3 return x, y - + loss = forward(in)[0].sum() loss.backward() # <----- x@grad will be overwrited by elementwise_add_grad Op """ @@ -430,8 +458,8 @@ class PartialProgramLayer: if exist a op whose inputs is var, then return True """ if not isinstance(var, framework.Variable) or var.type not in [ - core.VarDesc.VarType.LOD_TENSOR, - core.VarDesc.VarType.SELECTED_ROWS + core.VarDesc.VarType.LOD_TENSOR, + core.VarDesc.VarType.SELECTED_ROWS, ]: return False if var.dtype not in [paddle.float32, paddle.float64]: @@ -448,20 +476,28 @@ class PartialProgramLayer: new_grad_name = var.name + suffix + "@GRAD" finded_ops = list( filter( - lambda x: x[0] >= start_idx and any([ - out_arg == var_grad_name - for out_arg in x[1].output_arg_names - ]), enumerate(target_program.block(0).ops))) + lambda x: x[0] >= start_idx + and any( + [ + out_arg == var_grad_name + for out_arg in x[1].output_arg_names + ] + ), + enumerate(target_program.block(0).ops), + ) + ) # len(finded_ops) may equals zero when stop_gradient works. # len(finded_ops) may > 1, because we may have fill_constant op. 
if len(finded_ops) == 0: return None # step1: create a new var named var.name@GRAD - target_program.block(0).create_var(name=new_grad_name, - type=var.type, - dtype=var.dtype, - shape=var.shape) + target_program.block(0).create_var( + name=new_grad_name, + type=var.type, + dtype=var.dtype, + shape=var.shape, + ) # step2: rename the var.name@GRAD to var.name@GRAD@dy2static for idx, op in finded_ops: op._rename_input(var_grad_name, new_grad_name) @@ -472,11 +508,13 @@ class PartialProgramLayer: finded_ops[-1][0] + 1, type='sum', inputs={'X': [var_grad_name, new_grad_name]}, - outputs={"Out": var_grad_name}) + outputs={"Out": var_grad_name}, + ) return None to_processed_vars = list( - filter(_need_aggregation, self._outputs.tolist())) + filter(_need_aggregation, self._outputs.tolist()) + ) for _var in to_processed_vars: _insert_aggregation_ops_for_var(target_program, _var) @@ -492,8 +530,9 @@ class PartialProgramLayer: if targets and self._params: backward.gradients(targets=targets, inputs=[]) - start_idx = len( - main_program.block(0).ops) + 2 * len(self._outputs.tolist()) + start_idx = len(main_program.block(0).ops) + 2 * len( + self._outputs.tolist() + ) self.prepare_gradient_aggregation(start_idx, main_program, program) @@ -512,7 +551,10 @@ class PartialProgramLayer: found_param = False for block in program.blocks: for op in block.ops: - if param.name in op.input_arg_names or param.name in op.output_arg_names: + if ( + param.name in op.input_arg_names + or param.name in op.output_arg_names + ): required_params.append(param) found_param = True break @@ -529,15 +571,21 @@ class PartialProgramLayer: var_desc = block.vars[name].desc var_base = None if not framework._in_eager_mode_: - var_base = core.VarBase(var_desc.dtype(), - var_desc.shape(), - var_desc.name(), - var_desc.type(), False) + var_base = core.VarBase( + var_desc.dtype(), + var_desc.shape(), + var_desc.name(), + var_desc.type(), + False, + ) else: - var_base = core.eager.Tensor(var_desc.dtype(), - var_desc.shape(), - var_desc.name(), - var_desc.type(), False) + var_base = core.eager.Tensor( + var_desc.dtype(), + var_desc.shape(), + var_desc.name(), + var_desc.type(), + False, + ) double_grads.append(var_base) return self._valid_vars(double_grads) @@ -557,36 +605,62 @@ class PartialProgramLayer: attrs = [ 'global_block', - self.program.desc.block(0), 'start_op_index', 0, 'end_op_index', - self._get_end_op_index(), 'is_test', not self.training, - 'program_id', self.program_id + self.program.desc.block(0), + 'start_op_index', + 0, + 'end_op_index', + self._get_end_op_index(), + 'is_test', + not self.training, + 'program_id', + self.program_id, ] if self._cuda_graph_capture_mode: attrs.extend( - ('cuda_graph_capture_mode', self._cuda_graph_capture_mode, - 'cuda_graph_pool_id', self._cuda_graph_pool_id)) - - use_interpretorcore = _is_enable_standalone_executor( - ) and _is_dy2st_enable_standalone_executor() + ( + 'cuda_graph_capture_mode', + self._cuda_graph_capture_mode, + 'cuda_graph_pool_id', + self._cuda_graph_pool_id, + ) + ) + + use_interpretorcore = ( + _is_enable_standalone_executor() + and _is_dy2st_enable_standalone_executor() + ) attrs.extend(('use_interpretorcore', use_interpretorcore)) if use_interpretorcore: attrs.extend( - ('forward_global_block', self.forward_program.desc.block(0), - 'backward_global_block', self.backward_program.desc.block(0))) + ( + 'forward_global_block', + self.forward_program.desc.block(0), + 'backward_global_block', + self.backward_program.desc.block(0), + ) + ) _legacy_C_ops.run_program( - 
self._valid_vars(in_vars), self._valid_vars(self._params), + self._valid_vars(in_vars), + self._valid_vars(self._params), self._valid_vars(out_vars), - self._create_scope_vec(program_id=self.program_id, - use_scope_cache=True), - self._double_grads, self._cuda_graph_vec, *attrs) + self._create_scope_vec( + program_id=self.program_id, use_scope_cache=True + ), + self._double_grads, + self._cuda_graph_vec, + *attrs + ) else: - _legacy_C_ops.run_program(self._valid_vars(in_vars), - self._valid_vars(self._params), - self._valid_vars(out_vars), - self._create_scope_vec(), - self._double_grads, self._cuda_graph_vec, - *attrs) + _legacy_C_ops.run_program( + self._valid_vars(in_vars), + self._valid_vars(self._params), + self._valid_vars(out_vars), + self._create_scope_vec(), + self._double_grads, + self._cuda_graph_vec, + *attrs + ) restored_nest_out = self._restore_out(out_vars) return self._remove_no_value(restored_nest_out) @@ -594,9 +668,11 @@ class PartialProgramLayer: if _in_pure_fp16_guard(): for i, var in enumerate(in_vars): name = var.name - if (self.program.global_block().has_var(name) - and self.program.global_block().var(name).dtype - == paddle.float16): + if ( + self.program.global_block().has_var(name) + and self.program.global_block().var(name).dtype + == paddle.float16 + ): in_vars[i] = var.astype('float16') in_vars[i].name = name @@ -627,25 +703,32 @@ class PartialProgramLayer: return self._infer_program @switch_to_static_graph - def _get_forward_backward_program_form(self, whole_program, - forward_end_op_index): + def _get_forward_backward_program_form( + self, whole_program, forward_end_op_index + ): forward_builded_program = add_build_strategy_for( - whole_program, 0, forward_end_op_index, self._build_strategy) + whole_program, 0, forward_end_op_index, self._build_strategy + ) backward_start_op_index = forward_end_op_index + 2 * len( - self._outputs.var_ids) + self._outputs.var_ids + ) backward_end_op_index = whole_program.desc.block(0).op_size() backward_builded_program = add_build_strategy_for( - whole_program, backward_start_op_index, backward_end_op_index, - self._build_strategy) - self._apply_inplace_pass(forward_builded_program, - backward_builded_program) + whole_program, + backward_start_op_index, + backward_end_op_index, + self._build_strategy, + ) + self._apply_inplace_pass( + forward_builded_program, backward_builded_program + ) return [forward_builded_program, backward_builded_program] def _apply_inplace_pass(self, forward_program, backward_program): attr_types = { "use_cuda": "bool", "mem_opt_skip_vars": "list[str]", - "for_partial_block": "bool" + "for_partial_block": "bool", } empty_startup_program = paddle.static.Program() use_cuda = True if core.is_compiled_with_cuda() else False @@ -667,22 +750,33 @@ class PartialProgramLayer: forward_mem_opt_skip_vars.append(var.desc.name()) backward_mem_opt_skip_vars.append(var.desc.name()) for var_name in core.parse_safe_eager_deletion_skip_vars( - backward_program.desc): + backward_program.desc + ): forward_mem_opt_skip_vars.append(var_name) attrs = { "use_cuda": use_cuda, "mem_opt_skip_vars": forward_mem_opt_skip_vars, - "for_partial_block": True + "for_partial_block": True, } - _apply_pass(forward_program, empty_startup_program, - "buffer_shared_inplace_pass", attrs, attr_types) + _apply_pass( + forward_program, + empty_startup_program, + "buffer_shared_inplace_pass", + attrs, + attr_types, + ) attrs = { "use_cuda": use_cuda, "mem_opt_skip_vars": backward_mem_opt_skip_vars, - "for_partial_block": True + 
"for_partial_block": True, } - _apply_pass(backward_program, empty_startup_program, - "buffer_shared_inplace_pass", attrs, attr_types) + _apply_pass( + backward_program, + empty_startup_program, + "buffer_shared_inplace_pass", + attrs, + attr_types, + ) def _prepare(self, inputs): """ @@ -698,23 +792,28 @@ class PartialProgramLayer: if isinstance(value, np.ndarray): var = None if not framework._in_eager_mode_: - var = core.VarBase(value=value, - name=self._inputs[i].desc.name(), - persistable=False, - place=expected_place, - zero_copy=True) + var = core.VarBase( + value=value, + name=self._inputs[i].desc.name(), + persistable=False, + place=expected_place, + zero_copy=True, + ) else: - var = core.eager.Tensor(value=value, - name=self._inputs[i].desc.name(), - persistable=False, - place=expected_place, - zero_copy=True) + var = core.eager.Tensor( + value=value, + name=self._inputs[i].desc.name(), + persistable=False, + place=expected_place, + zero_copy=True, + ) elif isinstance(value, (core.VarBase, core.eager.Tensor)): # NOTE(Aurelius84): If var is on CPUPlace, it will be transformed multi times # into CUDAPlace when it's as input of multi Ops. so we move it in advance # to avoid this problem. if value.stop_gradient and not value.place._equals( - expected_place): + expected_place + ): var = value._copy_to(expected_place, False) var.stop_gradient = True else: @@ -737,12 +836,21 @@ class PartialProgramLayer: return out_varbase_map[var_desc.name()] if not framework._in_eager_mode_: - var_base = core.VarBase(var_desc.dtype(), var_desc.shape(), - var_desc.name(), var_desc.type(), False) + var_base = core.VarBase( + var_desc.dtype(), + var_desc.shape(), + var_desc.name(), + var_desc.type(), + False, + ) else: - var_base = core.eager.Tensor(var_desc.dtype(), var_desc.shape(), - var_desc.name(), var_desc.type(), - False) + var_base = core.eager.Tensor( + var_desc.dtype(), + var_desc.shape(), + var_desc.name(), + var_desc.type(), + False, + ) var_base.stop_gradient = var.stop_gradient out_varbase_map[var_desc.name()] = var_base return var_base @@ -755,20 +863,30 @@ class PartialProgramLayer: def _create_scope_vec(self, program_id=None, use_scope_cache=False): # Hold forward variables tmp_scope_vec = None - inner_scope = self._get_scope(program_id=program_id, - use_scope_cache=use_scope_cache) + inner_scope = self._get_scope( + program_id=program_id, use_scope_cache=use_scope_cache + ) if not framework._in_eager_mode_: - tmp_scope_vec = core.VarBase(core.VarDesc.VarType.FP32, [], - "program_out_scope", - core.VarDesc.VarType.STEP_SCOPES, True) + tmp_scope_vec = core.VarBase( + core.VarDesc.VarType.FP32, + [], + "program_out_scope", + core.VarDesc.VarType.STEP_SCOPES, + True, + ) tmp_scope_vec.value().set_scope(inner_scope) else: tmp_scope_vec = [inner_scope] return tmp_scope_vec def _create_cuda_graph_vec(self): - var = core.VarBase(core.VarDesc.VarType.FP32, [], "cuda_graph", - core.VarDesc.VarType.RAW, True) + var = core.VarBase( + core.VarDesc.VarType.FP32, + [], + "cuda_graph", + core.VarDesc.VarType.RAW, + True, + ) var.stop_gradient = True return var @@ -791,8 +909,9 @@ class PartialProgramLayer: return main_program.clone(for_test=True) def _is_no_value(self, var): - if isinstance(var, - (core.VarBase, core.eager.Tensor)) and var.shape == [1]: + if isinstance(var, (core.VarBase, core.eager.Tensor)) and var.shape == [ + 1 + ]: # NOTE: .numpy() will insert MemcpySync operation, it hits performance. 
if var.numpy()[0] == RETURN_NO_VALUE_MAGIC_NUM: return True @@ -808,13 +927,14 @@ class PartialProgramLayer: return out_vars elif isinstance(out_vars, (tuple, list)): if isinstance(out_vars, tuple): - res = tuple(var for var in out_vars - if not self._is_no_value(var)) + res = tuple( + var for var in out_vars if not self._is_no_value(var) + ) else: # isinstance(out_vars, list) res = [var for var in out_vars if not self._is_no_value(var)] - has_removed = (len(out_vars) > len(res)) + has_removed = len(out_vars) > len(res) # len(out_vars) > len(res) means we have removed var. This is # preventing out_vars is empty or just one element at the beginning if len(res) == 0 and has_removed: @@ -835,7 +955,8 @@ class PartialProgramLayer: for param in params: grad_name = param.name + core.grad_var_suffix() grad_var = train_program.desc.block(0).find_var( - cpt.to_bytes(grad_name)) + cpt.to_bytes(grad_name) + ) # NOTE: cannot find var desc maybe no problem, such as in batch_norm if grad_var is None: continue @@ -864,15 +985,18 @@ class PartialProgramLayer: if not isinstance(self._params, (list, tuple)): raise TypeError( "Type of self._params in PartialProgramLayer should be list or tuple, but received %s." - % type(self._params)) + % type(self._params) + ) param_and_buffer_names_set = set() for i, var in enumerate(self._params): # self._params constains parameters and buffers with persistable=True. if not isinstance(var, (core.VarBase, core.eager.Tensor)): raise TypeError( - 'Type of self._params[{}] in PartialProgramLayer should be Parameter or Variable, but received {}.' - .format(i, type(var))) + 'Type of self._params[{}] in PartialProgramLayer should be Parameter or Variable, but received {}.'.format( + i, type(var) + ) + ) param_and_buffer_names_set.add(var.name) for block in main_program.blocks: @@ -886,7 +1010,8 @@ class PartialProgramLayer: "\n\tRevise suggestion: " "\n\t\t1. Please ensure all your sublayers are inheritted from nn.Layer." "\n\t\t2. 
Please use nn.ParameterList and nn.LayerList as container instead of using a native Python container such as List" - % name) + % name + ) def _valid_vars(self, vars): """ @@ -903,13 +1028,23 @@ def _create_fake_var(): """ if not framework._in_eager_mode_: return [ - core.VarBase(core.VarDesc.VarType.FP32, [], "Fake_var", - core.VarDesc.VarType.RAW, False) + core.VarBase( + core.VarDesc.VarType.FP32, + [], + "Fake_var", + core.VarDesc.VarType.RAW, + False, + ) ] else: return [ - core.eager.Tensor(core.VarDesc.VarType.FP32, [], "Fake_var", - core.VarDesc.VarType.RAW, False) + core.eager.Tensor( + core.VarDesc.VarType.FP32, + [], + "Fake_var", + core.VarDesc.VarType.RAW, + False, + ) ] @@ -918,23 +1053,27 @@ def partial_program_from(concrete_program): if inputs and isinstance(inputs[0], layers.Layer): inputs = inputs[1:] - return PartialProgramLayer(concrete_program.main_program, inputs, - concrete_program.outputs, - concrete_program.parameters, - **concrete_program.kwargs) + return PartialProgramLayer( + concrete_program.main_program, + inputs, + concrete_program.outputs, + concrete_program.parameters, + **concrete_program.kwargs + ) @switch_to_static_graph -def add_build_strategy_for(program, - start_op_index, - end_op_index, - build_strategy=None): - if (start_op_index < end_op_index): +def add_build_strategy_for( + program, start_op_index, end_op_index, build_strategy=None +): + if start_op_index < end_op_index: compiled_program = paddle.static.CompiledProgram( core.Graph(program.desc, start_op_index, end_op_index), - build_strategy=build_strategy) - compiled_program._compile(core.Scope(), - framework._current_expected_place()) + build_strategy=build_strategy, + ) + compiled_program._compile( + core.Scope(), framework._current_expected_place() + ) ir_graph = framework.IrGraph(compiled_program._graph) builded_program = ir_graph.to_program() if hasattr(compiled_program._program, 'lr_sheduler'): -- GitLab
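The dtype-aware branch this patch adds to PartialProgramLayer.__init__ (building the AMP op lists only when the dygraph tracer reports an AMP dtype of 'float16' or 'bfloat16') can be read in isolation as the sketch below. The helper name build_amp_list_from_tracer is hypothetical and not part of Paddle; it merely restates the patched logic under the assumption that the private tracer attributes used here (_get_amp_op_list, _amp_dtype) behave as shown in the diff.

    # Hypothetical helper (not part of Paddle) mirroring the dtype-aware
    # AMP-list selection added to PartialProgramLayer.__init__ in this patch.
    from paddle.fluid import framework
    from paddle.fluid.contrib.mixed_precision.decorator import (
        AutoMixedPrecisionLists,
    )

    def build_amp_list_from_tracer():
        """Return an AutoMixedPrecisionLists for the active tracer, or None."""
        tracer = framework._dygraph_tracer()
        if tracer is None:
            return None
        custom_white_list, custom_black_list = tracer._get_amp_op_list()
        amp_dtype = tracer._amp_dtype
        # Only fp16/bf16 AMP tracing needs an op list; otherwise skip it,
        # exactly as the patched __init__ does.
        if amp_dtype not in ('float16', 'bfloat16'):
            return None
        return AutoMixedPrecisionLists(
            custom_white_list=custom_white_list,
            custom_black_list=custom_black_list,
            dtype=amp_dtype,
        )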
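The prepare_gradient_aggregation docstring above describes the case it guards against: a non-leaf tensor that is both a program output and an input of a later op, so its @GRAD receives two contributions. The dygraph snippet below is only a minimal illustration of that graph shape; the overwrite itself happens in the generated static program, where elementwise_add_grad would clobber x@GRAD unless the renaming and the inserted sum op accumulate both contributions.

    import paddle

    def forward(inp):
        x = 2 * inp   # non-leaf tensor that is also returned as an output
        y = x + 3     # x is consumed again by elementwise_add
        return x, y

    inp = paddle.ones([2, 2])
    inp.stop_gradient = False
    loss = forward(inp)[0].sum()
    loss.backward()
    print(inp.grad)
    # In the dy2st partial program, prepare_gradient_aggregation renames the
    # gradient written by the later grad op and appends a sum op so both
    # contributions are accumulated into x@GRAD instead of overwritten.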
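End to end, the payoff of unifying the fp16 and bf16 static AMP paths is that a to_static Layer can now be traced under a bfloat16 autocast guard and reach the dtype-aware program creation above. The snippet below is a usage sketch, not a test from this patch; it assumes the release/2.4 paddle.amp.auto_cast signature accepts a dtype argument and that the current device provides bfloat16 kernels.

    import paddle

    net = paddle.jit.to_static(paddle.nn.Linear(4, 4))
    x = paddle.randn([2, 4])

    # O1 autocast with bfloat16; the traced PartialProgramLayer should take the
    # _create_amp_program path with an AutoMixedPrecisionLists built for bf16.
    with paddle.amp.auto_cast(level='O1', dtype='bfloat16'):
        out = net(x)
    out.mean().backward()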