diff --git a/python/paddle/static/amp/debugging.py b/python/paddle/static/amp/debugging.py
index 28abe84c39b2e6bcf3f7f4f4199b887fb44fd3dd..5a894495d98f5e679e288d248d10744075af4d5e 100644
--- a/python/paddle/static/amp/debugging.py
+++ b/python/paddle/static/amp/debugging.py
@@ -13,8 +13,14 @@
 # limitations under the License.
 
 import copy
+import logging
 
 import paddle
+from paddle.fluid.log_helper import get_logger
+
+_logger = get_logger(
+    __name__, logging.INFO, fmt='%(asctime)s-%(levelname)s: %(message)s'
+)
 
 
 class OperatorStatsUnit:
@@ -76,7 +82,7 @@ def _get_var_dtype_from_block(block, op, arg_name, is_input):
         var = block._var_recursive(var_name)
         return var.dtype
     except:
-        print(
+        _logger.warning(
             "Operator < {} > gets {} < {} : {} > error!".format(
                 op.type, "input" if is_input else "output", arg_name, var_name
             )
@@ -99,7 +105,7 @@ def _extract_compute_dtype(op, block):
                 if _is_floating_point(compute_dtype) and _is_floating_point(
                     var_dtype
                 ):
-                    print(
+                    _logger.warning(
                         "Operator < {} > has different input data types, input_names = {}, output_names = {}.".format(
                             op.type, op.input_names, op.output_names
                         )
@@ -125,7 +131,7 @@
                 if _is_floating_point(compute_dtype) and _is_floating_point(
                     var_dtype
                 ):
-                    print(
+                    _logger.warning(
                         "Operator < {} > has different input / output data types, input_names = {}, output_names = {}.".format(
                             op.type, op.input_names, op.output_names
                         )
@@ -145,6 +151,15 @@
 
 
 def _get_op_stats_list(program):
+    def _is_special_ops_with_input_x(op_type):
+        # Operators that have an input X and whose inputs may have different dtypes.
+        special_op_list = ['cast', 'batch_norm', 'instance_norm', 'layer_norm']
+        if op_type in special_op_list:
+            return True
+        if op_type.replace("_grad", "") in special_op_list:
+            return True
+        return False
+
     op_stats_list = []
     for block in program.blocks:
         block_op_stats_dict = {}
@@ -161,13 +176,7 @@
                 'create_double_buffer_reader',
             ]:
                 compute_dtype = None
-            elif op.type in [
-                'cast',
-                'layer_norm',
-                'layer_norm_grad',
-                'batch_norm',
-                'batch_norm_grad',
-            ]:
+            elif _is_special_ops_with_input_x(op.type):
                 # Not check the input and output dtype difference for this operators.
                 compute_dtype = _get_var_dtype_from_block(block, op, 'X', True)
             elif "Param" in op.input_names:
@@ -183,6 +192,78 @@
 
 
 def collect_operator_stats(program=None, print_subblocks=False):
+    """
+    Collect the number of operators for different data types by parsing the
+    program. The statistical data are categorized according to four data
+    types, namely float32, float16, bfloat16 and others.
+
+    Args:
+        program(Program, optional): The program to parse. Default is None, meaning the default main_program will be parsed.
+        print_subblocks(bool, optional): Whether to print the operator stats for each subblock. Default is False.
+
+    Examples:
+
+        .. code-block:: python
+
+            import paddle
+
+            paddle.enable_static()
+
+            class SimpleConvNet(paddle.nn.Layer):
+                def __init__(self):
+                    super().__init__()
+                    self.conv = paddle.nn.Conv2D(in_channels=1, out_channels=6, kernel_size=3)
+                    self.linear = paddle.nn.Linear(in_features=26, out_features=10)
+
+                def forward(self, x):
+                    out = self.conv(x)
+                    out = paddle.nn.functional.relu(out)
+                    out = self.linear(out)
+                    out = paddle.nn.functional.softmax(out)
+                    return out
+
+            main_program = paddle.static.Program()
+            startup_program = paddle.static.Program()
+            with paddle.utils.unique_name.guard():
+                with paddle.static.program_guard(main_program, startup_program):
+                    model = SimpleConvNet()
+                    x = paddle.static.data(
+                        name='input', shape=[None, 1, 28, 28], dtype='float32'
+                    )
+                    out = model(x)
+                    loss = paddle.mean(out)
+                    optimizer = paddle.optimizer.AdamW()
+                    optimizer = paddle.static.amp.decorate(optimizer)
+                    optimizer.minimize(loss)
+            paddle.static.amp.debugging.collect_operator_stats(main_program)
+            # <------------------------------------------------------ op list of all blocks ------------------------------------------------------->
+            # <------------------------------------------------------------- op list -------------------------------------------------------------->
+            # <--------------- Op Name ---------------- | -- FP16 Calls --- | -- BF16 Calls --- | --- FP32 Calls--- | -- Other Calls -->
+            #    adamw                                  |  0                |  0                |  4                |  0
+            #    cast                                   |  5                |  0                |  6                |  0
+            #    check_finite_and_unscale               |  0                |  0                |  1                |  0
+            #    conv2d                                 |  1                |  0                |  0                |  0
+            #    conv2d_grad                            |  1                |  0                |  0                |  0
+            #    elementwise_add                        |  2                |  0                |  0                |  0
+            #    elementwise_add_grad                   |  2                |  0                |  0                |  0
+            #    elementwise_mul                        |  0                |  0                |  1                |  0
+            #    elementwise_mul_grad                   |  0                |  0                |  1                |  0
+            #    fill_constant                          |  0                |  0                |  1                |  0
+            #    matmul_v2                              |  1                |  0                |  0                |  0
+            #    matmul_v2_grad                         |  1                |  0                |  0                |  0
+            #    memcpy                                 |  0                |  0                |  0                |  1
+            #    reduce_mean                            |  0                |  0                |  1                |  0
+            #    reduce_mean_grad                       |  0                |  0                |  1                |  0
+            #    relu                                   |  1                |  0                |  0                |  0
+            #    relu_grad                              |  1                |  0                |  0                |  0
+            #    reshape2                               |  0                |  0                |  1                |  0
+            #    reshape2_grad                          |  0                |  0                |  1                |  0
+            #    softmax                                |  0                |  0                |  1                |  0
+            #    softmax_grad                           |  0                |  0                |  1                |  0
+            #    update_loss_scaling                    |  0                |  0                |  1                |  0
+            # <----------------------------------------------------------- op count: 22 ----------------------------------------------------------->
+    """
+
     def _convert_to_list(op_stats_unit_dict):
         for key, value in op_stats_unit_dict.items():
             op_stats_unit_dict[key] = value.convert_to_list()
diff --git a/python/paddle/static/amp/decorator.py b/python/paddle/static/amp/decorator.py
index 1ef13e5bcd0c5ca936f9b25936d3f69bd44d89ca..fc0aaac92bf46c5bd6ab9a1174a9589537bd044f 100644
--- a/python/paddle/static/amp/decorator.py
+++ b/python/paddle/static/amp/decorator.py
@@ -34,6 +34,21 @@ from .fp16_utils import (
 from .function_overload import FunctionType, overload
 
 
+def _set_multi_precision(optimizer, multi_precision):
+    if not isinstance(
+        optimizer,
+        (paddle.optimizer.Optimizer, paddle.fluid.optimizer.Optimizer),
+    ):
+        raise RuntimeError(
+            "Current AMP training level is O2, optimizer is expected to be paddle.optimizer.Optimizer or paddle.fluid.optimizer.Optimizer, but received {}.".format(
+                type(optimizer)
+            )
+        )
+
+    if multi_precision and hasattr(optimizer, "_multi_precision"):
+        optimizer._multi_precision = multi_precision
+
+
 class OptimizerWithMixedPrecision:
     """
     Optimizer with mixed-precision (MP) training. This is a wrapper of a common
@@ -767,22 +782,96 @@ def decorate(
     amp_lists=None,
     level='O1',
     dtype='float16',
+    master_weight=None,
     init_loss_scaling=2**15,
     incr_every_n_steps=1000,
     decr_every_n_nan_or_inf=2,
     incr_ratio=2.0,
     decr_ratio=0.8,
-    use_dynamic_loss_scaling=True,
+    use_dynamic_loss_scaling=None,
     use_amp_guard=False,
     use_promote=False,
 ):
     """
     Decorate the given optimizer to adapt to the mixed-precision training.
-    """
-    amp_dtype = check_amp_dtype(dtype)
-    if amp_lists is None:
-        amp_lists = AutoMixedPrecisionLists(dtype=amp_dtype)
+    Args:
+        optimizer(Optimizer): A common Optimizer.
+        amp_lists(CustomOpLists, optional): A CustomOpLists object. The default
+            white_list and black_list will be used for AMP training when it is
+            not set. Default is None.
+        level(str, optional): Auto mixed precision level. Accepted values are
+            "O1" and "O2": O1 represents mixed precision, where the input data type
+            of each operator will be casted according to the white_list and black_list;
+            O2 represents pure FP16 / BF16 training, where all operator parameters
+            and input data will be casted to FP16 / BF16, except for operators in
+            the black_list, operators without FP16 / BF16 kernels, and batch_norm. Default is O1.
+        dtype(str, optional): Whether to use 'float16' or 'bfloat16'. Default is 'float16'.
+        master_weight(bool, optional): For level='O2', whether to use multi-precision
+            during weight updating. If master_weight is None, the optimizer will use
+            multi-precision in O2 level. Default is None.
+        init_loss_scaling(float, optional): The initial loss scaling factor.
+            Default is 32768.
+        incr_every_n_steps(int, optional): Increases loss scaling every n
+            consecutive steps with finite gradients. Default is 1000.
+        decr_every_n_nan_or_inf(int, optional): Decreases loss scaling every n
+            accumulated steps with nan or inf gradients. Default is 2.
+        incr_ratio(float, optional): The multiplier to use when increasing the
+            loss scaling. Default is 2.
+        decr_ratio(float, optional): The less-than-one multiplier to use when
+            decreasing the loss scaling. Default is 0.8.
+        use_dynamic_loss_scaling(bool, optional): Whether to use dynamic loss
+            scaling. Default is None, which means True for float16 and False
+            for bfloat16.
+
+    Returns:
+        An optimizer acting like a normal one but with mixed-precision training enabled.
+
+    Examples:
+
+        .. code-block:: python
+
+            import paddle
+
+            paddle.enable_static()
+
+            class SimpleConvNet(paddle.nn.Layer):
+                def __init__(self):
+                    super().__init__()
+                    self.conv = paddle.nn.Conv2D(in_channels=1, out_channels=6, kernel_size=3)
+                    self.linear = paddle.nn.Linear(in_features=26, out_features=10)
+
+                def forward(self, x):
+                    out = self.conv(x)
+                    out = paddle.nn.functional.relu(out)
+                    out = self.linear(out)
+                    out = paddle.nn.functional.softmax(out)
+                    return out
+
+            main_program = paddle.static.Program()
+            startup_program = paddle.static.Program()
+            with paddle.utils.unique_name.guard():
+                with paddle.static.program_guard(main_program, startup_program):
+                    model = SimpleConvNet()
+                    x = paddle.static.data(
+                        name='input', shape=[None, 1, 28, 28], dtype='float32'
+                    )
+                    out = model(x)
+                    loss = paddle.mean(out)
+                    optimizer = paddle.optimizer.AdamW()
+                    optimizer = paddle.static.amp.decorate(optimizer, level="O2", dtype="float16")
+                    optimizer.minimize(loss)
+
+            if paddle.is_compiled_with_cuda() and len(paddle.static.cuda_places()) > 0:
+                place = paddle.CUDAPlace(0)
+                exe = paddle.static.Executor(place)
+                exe.run(startup_program)
+
+                # Call `amp_init` after the FP32 parameters have been initialized, such as by `exe.run(startup_program)`,
+                # to convert the FP32 parameters to low precision (FP16 / BF16).
+                optimizer.amp_init(place, scope=paddle.static.global_scope())
+
+    """
 
     # check amp_level: O0-O2
     level = level.upper()
     if not (level in ['O0', 'O1', 'O2']):
@@ -790,6 +879,18 @@ def decorate(
             "level should be O0, O1 or O2. O0 represents fp32 train mode, O1 represents AMP train mode, O2 represents pure fp16/bf16 train mode."
         )
 
+    amp_dtype = check_amp_dtype(dtype)
+    if amp_lists is None:
+        amp_lists = AutoMixedPrecisionLists(dtype=amp_dtype)
+
+    if use_dynamic_loss_scaling is None:
+        use_dynamic_loss_scaling = dtype == "float16"
+
+    if optimizer is not None:
+        # support master_weight
+        multi_precision = not (master_weight is False)
+        _set_multi_precision(optimizer, multi_precision)
+
     mp_optimizer = OptimizerWithMixedPrecision(
         optimizer,
         amp_lists,
diff --git a/test/amp/amp_base_models.py b/test/amp/amp_base_models.py
index 23c4b018b67d5394e8066dc63c9901e4d68c7616..8b63b2391c0200cbe19463e16092be6fcb39a5a6 100644
--- a/test/amp/amp_base_models.py
+++ b/test/amp/amp_base_models.py
@@ -42,7 +42,6 @@ def _build_optimizer(
         beta2=0.836,
         epsilon=1e-4,
         weight_decay=0.01,
-        multi_precision=True,
     )
     if use_amp:
         optimizer = paddle.static.amp.decorate(
diff --git a/test/amp/test_model_cast_to_bf16.py b/test/amp/test_model_cast_to_bf16.py
index 1a58a2905ec66983b75af7b9f286674e422ec19f..3002b623b18af24ee3630ab674b562eb2900ae19 100644
--- a/test/amp/test_model_cast_to_bf16.py
+++ b/test/amp/test_model_cast_to_bf16.py
@@ -221,11 +221,29 @@ class TestModelCastBF16(unittest.TestCase):
 
 
 class TestProgramBF16(AmpTestBase):
+    def _check_optimizer(self, program, expected_num_mp):
+        optimizers = []
+        for block in program.blocks:
+            for op in block.ops:
+                if "Param" in op.input_names and "Grad" in op.input_names:
+                    optimizers.append(op)
+
+        actual_num_mp = 0
+        for op in optimizers:
+            if op.has_attr("multi_precision") and op.attr("multi_precision"):
+                actual_num_mp += 1
+        self.assertEqual(
+            actual_num_mp,
+            expected_num_mp,
+            f"The number of optimizers with multi_precision = True is expected to be {expected_num_mp}, but received {actual_num_mp}.",
+        )
+
     def test_amp_bf16_o1(self):
         main_program, startup_program = build_embedding_model(
             True, "bfloat16", "O1"
         )
         self.assertEqual(main_program.num_blocks, 1)
+        self._check_optimizer(main_program, 0)
 
         amp.debugging.collect_operator_stats(main_program)
         op_stats_list = amp.debugging._get_op_stats_list(main_program)
@@ -255,6 +273,11 @@ class TestProgramBF16(AmpTestBase):
             "squared_l2_norm": 2,
             "adamw": 2,
         }
+        self._check_optimizer(
+            main_program,
+            expected_bf16_calls["matmul_v2"]
+            + expected_bf16_calls["elementwise_add"],
+        )
         self._check_op_calls(op_stats_list[0], expected_bf16_calls)
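
Reviewer note (not part of the patch): the sketch below is a minimal, standalone illustration of how the new defaults added to `decorate` are resolved. The helper name `resolve_amp_defaults` is hypothetical; it only mirrors the `use_dynamic_loss_scaling` and `master_weight` handling introduced above.

```python
# Hypothetical helper -- mirrors the default handling added to paddle.static.amp.decorate.
def resolve_amp_defaults(dtype, master_weight=None, use_dynamic_loss_scaling=None):
    # Dynamic loss scaling defaults to True for float16 and False for bfloat16.
    if use_dynamic_loss_scaling is None:
        use_dynamic_loss_scaling = dtype == "float16"
    # Master weights (multi_precision) stay on unless master_weight is explicitly False.
    multi_precision = not (master_weight is False)
    return use_dynamic_loss_scaling, multi_precision


print(resolve_amp_defaults("float16"))                       # (True, True)
print(resolve_amp_defaults("bfloat16"))                      # (False, True)
print(resolve_amp_defaults("float16", master_weight=False))  # (True, False)
```

With these defaults, callers that never pass `master_weight` keep multi-precision enabled in O2, and bfloat16 training does not use dynamic loss scaling unless it is explicitly requested.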