From c2c3bd43c6d6ebb88d0da961ae618ab6e6c400de Mon Sep 17 00:00:00 2001
From: niuliling123 <51102941+niuliling123@users.noreply.github.com>
Date: Tue, 16 May 2023 16:09:25 +0800
Subject: [PATCH] [AMP] support OD level for static (#53768)

---
 python/paddle/static/amp/decorator.py  | 44 +++++++++-------
 python/paddle/static/amp/fp16_lists.py | 16 +++---
 python/paddle/static/amp/fp16_utils.py | 10 +++-
 test/amp/test_amp_api.py               | 70 +++++++++++++++++++++++++-
 4 files changed, 113 insertions(+), 27 deletions(-)

diff --git a/python/paddle/static/amp/decorator.py b/python/paddle/static/amp/decorator.py
index 43dc849925a..dca2a4e024c 100644
--- a/python/paddle/static/amp/decorator.py
+++ b/python/paddle/static/amp/decorator.py
@@ -61,12 +61,11 @@ class OptimizerWithMixedPrecision:
     Args:
         optimizer (Optimizer): A common Optimizer object.
         amp_lists (AutoMixedPrecisionLists): An AutoMixedPrecisionLists object.
-        level(str): Auto mixed precision level. Accepted values are
-            "O1" and "O2": O1 represent mixed precision, the input data type
-            of each operator will be casted by white_list and black_list;
-            O2 represent Pure fp16 or bf16, all operators parameters and input
-            data will be casted to fp16 or bf16, except operators in black_list,
-            don't support fp16 or bf16 kernel and batch_norm.
+        level(str): Auto mixed precision level. Accepted values are "O1", "O2" and "OD": at the O1 level, operators in
+            the white list compute in float16/bfloat16 and operators in the black list compute in float32; at the O2
+            level, the model's parameters are cast to float16/bfloat16 by `decorate`, operators whose inputs are all
+            float16/bfloat16 run in float16/bfloat16, and operators with any float32 input run in float32; at the OD
+            level, only operators in the default white list compute in float16/bfloat16.
         dtype(str): Whether to use 'float16' or 'bfloat16'.
         init_loss_scaling (float): The initial loss scaling factor.
         use_dynamic_loss_scaling (bool): Whether to use dynamic loss scaling.
@@ -123,6 +122,7 @@ class OptimizerWithMixedPrecision:
         self._learning_rate = optimizer._learning_rate
         self._learning_rate_map = optimizer._learning_rate_map
         self._use_pure_fp16 = level == "O2"
+        self._amp_level = level
         self._use_fp16_guard = use_amp_guard
         self._to_fp16_var_names = None
         if self._use_dynamic_loss_scaling:
@@ -241,7 +241,7 @@ class OptimizerWithMixedPrecision:
                 self._amp_lists,
                 use_fp16_guard=False,
                 dest_type=self._amp_vartype,
-                level='O1',
+                level=self._amp_level,
                 use_promote=self.use_promote,
             )

@@ -380,7 +380,7 @@ class OptimizerWithMixedPrecision:
                 self._amp_lists,
                 use_fp16_guard=False,
                 dest_type=self._amp_vartype,
-                level='O1',
+                level=self._amp_level,
                 use_promote=self.use_promote,
             )

@@ -773,12 +773,11 @@ def decorate(
         amp_lists(CustomOpLists, optional): An CustomOpLists object. The
             default white_list and black_list will be used for AMP training
             when it is not set. Default is None.
-        level(str, optional): Auto mixed precision level. Accepted values are
-            "O1" and "O2": O1 represent mixed precision, the input data type of
-            each operator will be casted by white_list and black_list;
-            O2 represent pure FP16 / BF16 training, all operators parameters
-            and input data will be casted to FP16 / BF16, except operators in
-            black_list, don't support FP16 / BF16 kernel and batch_norm. Default is O1.
+        level(str, optional): Auto mixed precision level. Accepted values are "O1", "O2" and "OD": at the O1 level,
+            operators in the white list compute in float16/bfloat16 and operators in the black list compute in float32;
+            at the O2 level, the model's parameters are cast to float16/bfloat16 by `decorate`, operators whose inputs
+            are all float16/bfloat16 run in float16/bfloat16, and operators with any float32 input run in float32; at
+            the OD level, only operators in the default white list compute in float16/bfloat16, and all others compute in float32. Default is O1.
         dtype(str, optional): Whether to use 'float16' or 'bfloat16'. Default is 'float16'.
         master_weight(bool, optinal): For level='O2', whether to use multi-precision
             during weight updating. If master_weight is None, in O2 level optimizer
@@ -847,15 +846,22 @@ def decorate(
     """
    # check amp_level: O0-O2
    level = level.upper()
-    if not (level in ['O0', 'O1', 'O2']):
-        raise ValueError(
-            "level should be O0, O1 or O2. O0 represents fp32 train mode, O1 represents AMP train mode, O2 represents pure fp16/bf16 train mode."
-        )
+    if not (level in ['O0', 'OD', 'O1', 'O2']):
+        raise ValueError("level should be O0, OD, O1 or O2.")
    amp_dtype = check_amp_dtype(dtype)
-    if amp_lists is None:
+    if amp_lists is None or level == 'OD':
        amp_lists = AutoMixedPrecisionLists(dtype=amp_dtype)

+    if level == 'OD':
+        if amp_lists is not None:
+            warnings.warn(
+                "If the Amp level is set to OD, the amp list will not be used."
+            )
+
+        amp_lists.white_list = {"conv2d", "matmul_v2"}
+        amp_lists.black_list = amp_lists.all_list - amp_lists.white_list
+
    if use_dynamic_loss_scaling is None:
        use_dynamic_loss_scaling = dtype == "float16"
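Purely as an illustration of what the decorator.py change enables (a hedged sketch, not taken from the patch: the toy network, shapes and hyperparameters below are made up), decorating a static-graph optimizer with the new OD level might look like this:

    import paddle

    paddle.enable_static()

    main_prog = paddle.static.Program()
    startup_prog = paddle.static.Program()
    with paddle.static.program_guard(main_prog, startup_prog):
        x = paddle.static.data(name='x', shape=[None, 1, 28, 28], dtype='float32')
        conv = paddle.static.nn.conv2d(input=x, num_filters=6, filter_size=3)
        loss = paddle.mean(conv)

        optimizer = paddle.optimizer.SGD(learning_rate=1e-3)
        # With level='OD', only the default white-list ops (conv2d, matmul_v2)
        # are expected to run in float16; every other op stays in float32.
        optimizer = paddle.static.amp.decorate(
            optimizer,
            init_loss_scaling=128.0,
            use_dynamic_loss_scaling=True,
            level='OD',
        )
        optimizer.minimize(loss)

Passing a level outside {'O0', 'OD', 'O1', 'O2'} now raises the shortened ValueError shown above.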
Accepted values are "O1", "O2" and "OD": At the O1 level, operators in the white list + will use float16/bfloat16 inputs for calculations, and operators in the black list will use float32 inputs for calculations. At the O2 + level, model's parameters will be casted to float16/bfloat16 by using `decorator`, and operators that have all float16/bfloat16 inputs + will be converted to float16/bfloat16, and that have any float32 input will be converted to float32. For the OD level, operators in + default white list will compute in float16/bfloat16, and the others will compute in float32. Default is O1. dtype(str, optional): Whether to use 'float16' or 'bfloat16'. Default is 'float16'. master_weight(bool, optinal): For level='O2', whether to use multi-precision during weight updating. If master_weight is None, in O2 level optimizer @@ -847,15 +846,22 @@ def decorate( """ # check amp_level: O0-O2 level = level.upper() - if not (level in ['O0', 'O1', 'O2']): - raise ValueError( - "level should be O0, O1 or O2. O0 represents fp32 train mode, O1 represents AMP train mode, O2 represents pure fp16/bf16 train mode." - ) + if not (level in ['O0', 'OD', 'O1', 'O2']): + raise ValueError("level should be O0, OD, O1 or O2.") amp_dtype = check_amp_dtype(dtype) - if amp_lists is None: + if amp_lists is None or level == 'OD': amp_lists = AutoMixedPrecisionLists(dtype=amp_dtype) + if level == 'OD': + if amp_lists is not None: + warnings.warn( + "If the Amp level is set to OD, the amp list will not be used." + ) + + amp_lists.white_list = {"conv2d", "matmul_v2"} + amp_lists.black_list = amp_lists.all_list - amp_lists.white_list + if use_dynamic_loss_scaling is None: use_dynamic_loss_scaling = dtype == "float16" diff --git a/python/paddle/static/amp/fp16_lists.py b/python/paddle/static/amp/fp16_lists.py index dcf13c0847d..96ad079879a 100644 --- a/python/paddle/static/amp/fp16_lists.py +++ b/python/paddle/static/amp/fp16_lists.py @@ -99,7 +99,7 @@ def _get_sys_unsupported_list(dtype): device = 'XPU' else: device = 'GPU' - _, _, sys_unsupported_list = core.op_supported_infos(device, var_type) + all_ops, _, sys_unsupported_list = core.op_supported_infos(device, var_type) # sys_unsupported_list will include the following ops. supported_fp16_list = { @@ -114,13 +114,13 @@ def _get_sys_unsupported_list(dtype): } sys_unsupported_list -= supported_fp16_list - return device, sys_unsupported_list + return device, sys_unsupported_list, all_ops def _get_unsupported_list(dtype): # The set of ops that don't support fp16 calculation - _, _sys_unsupported_list = _get_sys_unsupported_list(dtype) - return _sys_unsupported_list + _, _sys_unsupported_list, _sys_all_list = _get_sys_unsupported_list(dtype) + return _sys_unsupported_list, _sys_all_list # The three sets listed below are changed dynamiclly. 
diff --git a/python/paddle/static/amp/fp16_utils.py b/python/paddle/static/amp/fp16_utils.py
index 0ba4db12a78..342a2618258 100644
--- a/python/paddle/static/amp/fp16_utils.py
+++ b/python/paddle/static/amp/fp16_utils.py
@@ -426,7 +426,7 @@ def set_var_dst_dtype(

 def set_param_dtype(program, dtype, amp_lists, use_fp16_guard, level):
     keep_fp32_var_names = set()
-    if level == "O1":
+    if level == "O1" or level == "OD":
         return keep_fp32_var_names
     all_parameters = []
     for block in program.blocks:
@@ -618,6 +618,14 @@ def cast_model_to_fp16(
     if level == 'O2':
         amp_lists.black_list = amp_lists.black_list - black_list

+    if level == 'OD':
+        if amp_lists is not None:
+            dtype = get_low_precision_dtypestr(dest_type)
+            amp_lists = AutoMixedPrecisionLists(dtype)
+
+        amp_lists.white_list = {"conv2d", "matmul_v2"}
+        amp_lists.black_list = amp_lists.all_list - amp_lists.white_list
+
     global_block = program.global_block()
     keep_fp32_ops = set()
     keep_fp16_ops = set()
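The early return added to set_param_dtype is why OD behaves like O1 for weights: parameters stay float32 and only white-list operators cast their inputs. A small hypothetical helper (not part of Paddle) that summarizes this decision after the patch:

    # Only pure fp16/bf16 training ("O2") rewrites parameter dtypes;
    # "O0", "O1" and "OD" all keep parameters in float32.
    def params_cast_to_low_precision(level: str) -> bool:
        return level == "O2"

    for level in ("O0", "OD", "O1", "O2"):
        action = "cast parameters" if params_cast_to_low_precision(level) else "keep float32 parameters"
        print(level, "->", action)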
"relu": 0, + "matmul_v2": 1, + "softmax": 0, + "reduce_mean": 0, + "adamw": 0, + } + self.check_results( + True, + 'float16', + 'OD', + use_promote=True, + expected_op_calls=expected_fp16_calls, + ) + paddle.disable_static() + + class TestGradScaler(AmpTestBase): def test_amp_grad_scaler(self): model = paddle.nn.Conv2D(3, 2, 3) -- GitLab