From db407bf0531ea8f3deb183c110a6e79cc2389344 Mon Sep 17 00:00:00 2001 From: Yiqun Liu Date: Tue, 16 May 2023 13:27:15 +0800 Subject: [PATCH] [AMP] Allow to switch whether to use promote strategy to choose kernel for O2 training. (#53742) * Allow to switch whether to use promote strategy to choose kernel for O2 training. * Fix comparing error and add unittest. --- paddle/fluid/eager/amp_utils.h | 37 ++++-- paddle/fluid/eager/api/utils/global_utils.h | 3 + paddle/fluid/imperative/tracer.cc | 2 + paddle/fluid/imperative/tracer.h | 8 ++ paddle/fluid/pybind/imperative.cc | 3 + python/paddle/amp/auto_cast.py | 14 ++- test/amp/amp_base_models.py | 69 ++++++++++-- test/amp/test_amp_api.py | 12 +- test/amp/test_amp_o2_embedding_model.py | 1 + test/amp/test_amp_promote.py | 118 +++++++++++++++++++- test/amp/test_model_cast_to_bf16.py | 1 + 11 files changed, 234 insertions(+), 34 deletions(-) diff --git a/paddle/fluid/eager/amp_utils.h b/paddle/fluid/eager/amp_utils.h index 4bfb2fdb4f6..bfa58512eb2 100644 --- a/paddle/fluid/eager/amp_utils.h +++ b/paddle/fluid/eager/amp_utils.h @@ -120,18 +120,35 @@ inline phi::DataType GetAmpDestDtype( egr::Controller::Instance().GetCurrentTracer()->GetAmpPhiDtype(); auto dst_type = amp_setting_dtype; - if (paddle::imperative::AmpOperators::Instance().GetMutableAllowOps()->count( - op_name)) { - dst_type = amp_setting_dtype; - } else if (paddle::imperative::AmpOperators::Instance() - .GetMutableBlockOps() - ->count(op_name)) { - dst_type = phi::DataType::FLOAT32; - } else { - if (amp_level == paddle::imperative::AmpLevel::OD) { + bool use_promote = true; + if (amp_level == paddle::imperative::AmpLevel::O2) { + use_promote = + egr::Controller::Instance().GetCurrentTracer()->GetUsePromote(); + } + + if (use_promote) { + if (paddle::imperative::AmpOperators::Instance() + .GetMutableAllowOps() + ->count(op_name)) { + dst_type = amp_setting_dtype; + } else if (paddle::imperative::AmpOperators::Instance() + .GetMutableBlockOps() + ->count(op_name)) { dst_type = phi::DataType::FLOAT32; } else { - dst_type = GetPromoteType(op_name, amp_tensors_vector, amp_setting_dtype); + if (amp_level == paddle::imperative::AmpLevel::OD) { + dst_type = phi::DataType::FLOAT32; + } else { + dst_type = + GetPromoteType(op_name, amp_tensors_vector, amp_setting_dtype); + } + } + } else { + // use_promote can be set to false only for O2 training. 
+ if (paddle::imperative::AmpOperators::Instance() + .GetMutableBlockOps() + ->count(op_name)) { + dst_type = phi::DataType::FLOAT32; } } diff --git a/paddle/fluid/eager/api/utils/global_utils.h b/paddle/fluid/eager/api/utils/global_utils.h index ba79a63648d..75eb28a6af9 100644 --- a/paddle/fluid/eager/api/utils/global_utils.h +++ b/paddle/fluid/eager/api/utils/global_utils.h @@ -58,6 +58,9 @@ class Controller { return tracer_->GetAmpLevel(); } + void SetUsePromote(bool use_promote) { tracer_->SetUsePromote(use_promote); } + bool GetUsePromote() const { return tracer_->GetUsePromote(); } + bool UseLayoutAutoTune() { bool use_autotune = false; #if defined(PADDLE_WITH_CUDA) diff --git a/paddle/fluid/imperative/tracer.cc b/paddle/fluid/imperative/tracer.cc index 2d4e6622c05..ccb58d32022 100644 --- a/paddle/fluid/imperative/tracer.cc +++ b/paddle/fluid/imperative/tracer.cc @@ -44,6 +44,8 @@ thread_local bool Tracer::enable_program_desc_tracing_ = false; thread_local bool Tracer::has_grad_ = true; +thread_local bool Tracer::use_promote_ = true; + thread_local bool Tracer::use_layout_autotune_ = false; thread_local AmpLevel Tracer::amp_level_ = AmpLevel::O0; diff --git a/paddle/fluid/imperative/tracer.h b/paddle/fluid/imperative/tracer.h index 7355cec776e..87e0edd036f 100644 --- a/paddle/fluid/imperative/tracer.h +++ b/paddle/fluid/imperative/tracer.h @@ -156,6 +156,13 @@ class Tracer { void SetHasGrad(bool has_grad) { has_grad_ = has_grad; } + void SetUsePromote(bool use_promote) { + VLOG(4) << "set use_promote to " << use_promote; + use_promote_ = use_promote; + } + + bool GetUsePromote() const { return use_promote_; } + void SetAmpLevel(AmpLevel level) { VLOG(4) << "set amp_level to " << static_cast(level); amp_level_ = level; @@ -220,6 +227,7 @@ class Tracer { static thread_local bool enable_program_desc_tracing_; static thread_local bool use_layout_autotune_; static thread_local bool has_grad_; + static thread_local bool use_promote_; static thread_local AmpLevel amp_level_; static thread_local phi::DataType amp_dtype_; }; diff --git a/paddle/fluid/pybind/imperative.cc b/paddle/fluid/pybind/imperative.cc index 5bbd66fd09c..67c8afbafbe 100644 --- a/paddle/fluid/pybind/imperative.cc +++ b/paddle/fluid/pybind/imperative.cc @@ -2156,6 +2156,9 @@ void BindImperative(py::module *m_ptr) { .def_property("_enable_program_desc_tracing", &imperative::Tracer::IsProgramDescTracingEnabled, &imperative::Tracer::SetEnableProgramDescTracing) + .def_property("_use_promote", + &imperative::Tracer::GetUsePromote, + &imperative::Tracer::SetUsePromote) .def_property("_amp_level", &imperative::Tracer::GetAmpLevel, &imperative::Tracer::SetAmpLevel) diff --git a/python/paddle/amp/auto_cast.py b/python/paddle/amp/auto_cast.py index 2cb0af3530e..c31b0bbef8e 100644 --- a/python/paddle/amp/auto_cast.py +++ b/python/paddle/amp/auto_cast.py @@ -274,6 +274,7 @@ def amp_guard( custom_black_list=None, level='O1', dtype='float16', + use_promote=True, ): """ Create a context which enables auto-mixed-precision(AMP) of operators executed in dynamic graph mode. 
@@ -438,6 +439,11 @@ def amp_guard( original_amp_dtype = tracer._amp_dtype tracer._amp_dtype = amp_dtype + # switch promote + if amp_level == AMP_LEVEL.O2: + original_use_promote = tracer._use_promote + tracer._use_promote = use_promote + # restore status try: yield @@ -448,6 +454,8 @@ def amp_guard( tracer._set_amp_op_list(original_white_list, original_black_list) # set_flags(original_flags) tracer._amp_dtype = original_amp_dtype + if amp_level == AMP_LEVEL.O2: + tracer._use_promote = original_use_promote class StateDictHook: @@ -641,6 +649,7 @@ def auto_cast( custom_black_list=None, level='O1', dtype='float16', + use_promote=True, ): """ Create a context which enables auto-mixed-precision(AMP) of operators executed in dynamic graph mode. @@ -663,6 +672,7 @@ def auto_cast( will be converted to float16/bfloat16, and that have any float32 input will be converted to float32. For the OD level, operators in default white list will compute in float16/bfloat16, and the others will compute in float32. Default is O1. dtype(str, optional): Whether to use 'float16' or 'bfloat16'. Default is 'float16'. + use_promote(bool, optional): Whether to promotes to fp32 when op has any float32 inputs. It is only supported when amp level is O2. Default is True. Examples: @@ -696,7 +706,9 @@ def auto_cast( print(d.dtype) # paddle.float16 """ - return amp_guard(enable, custom_white_list, custom_black_list, level, dtype) + return amp_guard( + enable, custom_white_list, custom_black_list, level, dtype, use_promote + ) def decorate( diff --git a/test/amp/amp_base_models.py b/test/amp/amp_base_models.py index 8b63b2391c0..aabf9e82f7b 100644 --- a/test/amp/amp_base_models.py +++ b/test/amp/amp_base_models.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+import copy import unittest import numpy as np @@ -19,6 +20,7 @@ import numpy as np import paddle from paddle import nn from paddle.fluid import core +from paddle.fluid.framework import _non_static_mode _fixed_add_param = np.random.random(size=[16, 16]).astype("float32") @@ -30,20 +32,27 @@ def _build_optimizer( amp_lists=None, use_grad_clip=False, use_promote=False, + model=None, ): if use_grad_clip: grad_clip = paddle.nn.ClipGradByGlobalNorm(clip_norm=1.0) else: grad_clip = None + if _non_static_mode(): + assert model is not None + parameters = model.parameters() + else: + parameters = None optimizer = paddle.optimizer.AdamW( learning_rate=0.01, + parameters=parameters, grad_clip=grad_clip, beta1=0.78, beta2=0.836, epsilon=1e-4, weight_decay=0.01, ) - if use_amp: + if not _non_static_mode() and use_amp: optimizer = paddle.static.amp.decorate( optimizer, amp_lists, @@ -118,7 +127,7 @@ class SimpleConvNet(nn.Layer): def forward(self, x): out = self.conv(x) - out = nn.functional.relu(out) + out = nn.functional.relu(out.cast("float32")) out = out.flatten(start_axis=1, stop_axis=3) out = self.linear(out) out = nn.functional.softmax(out) @@ -128,6 +137,22 @@ class SimpleConvNet(nn.Layer): def build_conv_model( use_amp, amp_dtype="float16", amp_level="O1", use_promote=False ): + if _non_static_mode(): + model = SimpleConvNet() + optimizer = _build_optimizer(use_amp=False, model=model) + if use_amp and amp_dtype == "float16": + scaler = paddle.amp.GradScaler() + else: + scaler = None + if use_amp and amp_level == "O2": + model, optimizer = paddle.amp.decorate( + models=model, + optimizers=optimizer, + level=amp_level, + dtype=amp_dtype, + ) + return model, optimizer, scaler + main_program = paddle.static.Program() startup_program = paddle.static.Program() with paddle.utils.unique_name.guard(): @@ -237,19 +262,36 @@ class AmpTestBase(unittest.TestCase): self.amp_level = None def _check_op_calls( - self, op_stats_dict, expected_bf16_calls={}, expected_fp16_calls={} + self, + op_stats_dict, + expected_bf16_calls={}, + expected_fp16_calls={}, + debug_info=None, ): - for op_type, value in expected_bf16_calls.items(): + def _extract_op_call(op_calls_str, pos): + return int(copy.copy(op_calls_str).split(",")[pos]) + + for op_type, expected_value in expected_bf16_calls.items(): + # print(f"[BF16] op_type={op_type}, value={value}") + if isinstance(op_stats_dict[op_type], str): + actual_value = _extract_op_call(op_stats_dict[op_type], 1) + else: + actual_value = op_stats_dict[op_type].bf16_calls self.assertEqual( - op_stats_dict[op_type].bf16_calls, - value, - f"The number of bf16 calls of operator < {op_type} > is expected to be {value}, but recieved {op_stats_dict[op_type].bf16_calls}.", + actual_value, + expected_value, + f"[{debug_info}] The number of bf16 calls of operator < {op_type} > is expected to be {expected_value}, but recieved {actual_value}.", ) - for op_type, value in expected_fp16_calls.items(): + for op_type, expected_value in expected_fp16_calls.items(): + # print(f"[FP16] op_type={op_type}, value={value}") + if isinstance(op_stats_dict[op_type], str): + actual_value = _extract_op_call(op_stats_dict[op_type], 0) + else: + actual_value = op_stats_dict[op_type].fp16_calls self.assertEqual( - op_stats_dict[op_type].fp16_calls, - value, - f"The number of fp16 calls of operator < {op_type} > is expected to be {value}, but recieved {op_stats_dict[op_type].fp16_calls}.", + actual_value, + expected_value, + f"[debug_info] The number of fp16 calls of operator < {op_type} > is expected to be 
{expected_value}, but recieved {actual_value}.", ) def run_program( @@ -263,6 +305,7 @@ class AmpTestBase(unittest.TestCase): exe, x_np, max_iters, + dtype, level, ): losses = [] @@ -277,6 +320,8 @@ class AmpTestBase(unittest.TestCase): feed={feed_vars[0].name: x_np}, fetch_list=fetch_vars, ) - print(f"-- [BF16 {level}] iter={iter_id}, loss={results[0]}") + print( + f"-- [AMP {dtype} {level}] iter={iter_id}, loss={results[0]}" + ) losses.append(results[0]) return losses diff --git a/test/amp/test_amp_api.py b/test/amp/test_amp_api.py index 7d397c70432..4120dcce939 100644 --- a/test/amp/test_amp_api.py +++ b/test/amp/test_amp_api.py @@ -20,15 +20,17 @@ import paddle class TestAutoCast(AmpTestBase): - def test_amp_OD_level(self): - conv = paddle.nn.Conv2D( + def setUp(self): + self._conv = paddle.nn.Conv2D( in_channels=1, out_channels=6, kernel_size=3, bias_attr=False ) - linear = paddle.nn.Linear(in_features=4, out_features=4) + self._linear = paddle.nn.Linear(in_features=4, out_features=4) + + def test_amp_OD_level(self): with paddle.amp.auto_cast(level='OD'): - out1 = conv(paddle.rand(shape=[1, 1, 6, 6], dtype='float32')) + out1 = self._conv(paddle.rand(shape=[1, 1, 6, 6], dtype='float32')) out2 = out1 + paddle.rand(shape=out1.shape, dtype='float16') - out3 = linear(out2) + out3 = self._linear(out2) self.assertEqual(out1.dtype, paddle.float16) self.assertEqual(out2.dtype, paddle.float32) diff --git a/test/amp/test_amp_o2_embedding_model.py b/test/amp/test_amp_o2_embedding_model.py index b1af4bde8d3..237ca1120d6 100644 --- a/test/amp/test_amp_o2_embedding_model.py +++ b/test/amp/test_amp_o2_embedding_model.py @@ -131,6 +131,7 @@ class TestUnittedEmbedding(AmpTestBase): exe, x_np, max_iters, + "float16", level, ) return losses diff --git a/test/amp/test_amp_promote.py b/test/amp/test_amp_promote.py index e75799b39c3..9f8395df413 100644 --- a/test/amp/test_amp_promote.py +++ b/test/amp/test_amp_promote.py @@ -20,13 +20,12 @@ from amp_base_models import AmpTestBase, build_conv_model import paddle from paddle.static import amp -paddle.enable_static() - -class TestAMPPromote(AmpTestBase): +class TestStaticAmpPromoteStats(AmpTestBase): def check_promote_results( - self, use_amp, dtype, level, use_promote, expected_op_calls + self, use_amp, dtype, level, use_promote, expected_op_calls, debug_info ): + paddle.enable_static() ( main_program, startup_program, @@ -40,7 +39,9 @@ class TestAMPPromote(AmpTestBase): op_stats_list = amp.debugging._get_op_stats_list(main_program) self._check_op_calls( - op_stats_list[0], expected_fp16_calls=expected_op_calls + op_stats_list[0], + expected_fp16_calls=expected_op_calls, + debug_info=debug_info, ) place = paddle.CUDAPlace(0) @@ -58,8 +59,10 @@ class TestAMPPromote(AmpTestBase): exe, x_fp32, max_iters, + dtype, level, ) + paddle.disable_static() def test_static_amp_o1(self): expected_fp16_calls = { @@ -77,13 +80,14 @@ class TestAMPPromote(AmpTestBase): 'O1', use_promote=True, expected_op_calls=expected_fp16_calls, + debug_info="TestStaticAmpPromoteStats/test_static_amp_o1", ) def test_static_amp_o2(self): expected_fp16_calls = { "conv2d": 1, "elementwise_add": 2, - "relu": 1, + "relu": 0, "matmul_v2": 1, "softmax": 1, "reduce_mean": 1, @@ -95,7 +99,109 @@ class TestAMPPromote(AmpTestBase): 'O2', use_promote=True, expected_op_calls=expected_fp16_calls, + debug_info="TestStaticAmpPromoteStats/test_static_amp_o2", + ) + + +class TestEagerAmpPromoteStats(AmpTestBase): + def check_promote_results( + self, dtype, level, use_promote, expected_op_calls, debug_info + 
): + model, optimizer, scaler = build_conv_model( + use_amp=True, + amp_dtype=dtype, + amp_level=level, + use_promote=use_promote, + ) + model.train() + + paddle.amp.debugging.enable_operator_stats_collection() + with paddle.amp.auto_cast( + enable=True, dtype=dtype, level=level, use_promote=use_promote + ): + x = paddle.rand(shape=[1, 1, 6, 6], dtype='float32') + out = model(x) + loss = paddle.mean(out) + scaled = scaler.scale(loss) + scaled.backward() + scaler.minimize(optimizer, scaled) + optimizer.clear_grad() + paddle.amp.debugging.disable_operator_stats_collection() + op_stats = paddle.fluid.core.get_low_precision_op_list() + + self._check_op_calls( + op_stats, + expected_fp16_calls=expected_op_calls, + debug_info=debug_info, + ) + + def test_o2_promote_on(self): + expected_fp16_calls = { + "conv2d": 1, + "elementwise_add": 2, + "relu": 0, + "matmul_v2": 1, + "softmax": 1, + "reduce_mean": 1, + "adamw_": 4, + } + self.check_promote_results( + 'float16', + 'O2', + use_promote=True, + expected_op_calls=expected_fp16_calls, + debug_info="TestEagerAmpPromoteStats/test_o2_promote_on", + ) + + def test_o2_promote_off(self): + expected_fp16_calls = { + "conv2d": 1, + "elementwise_add": 2, + "relu": 1, + "matmul_v2": 1, + "softmax": 1, + "reduce_mean": 1, + "adamw_": 4, + } + self.check_promote_results( + 'float16', + 'O2', + use_promote=False, + expected_op_calls=expected_fp16_calls, + debug_info="TestEagerAmpPromoteStats/test_o2_promote_off", + ) + + +class TestEagerAmpPromoteSimple(AmpTestBase): + def setUp(self): + self._conv = paddle.nn.Conv2D( + in_channels=1, out_channels=6, kernel_size=3, bias_attr=False ) + self._linear = paddle.nn.Linear(in_features=4, out_features=4) + + def test_o2_use_promote_on(self): + with paddle.amp.auto_cast(level='O2'): + x = paddle.rand(shape=[1, 1, 6, 6], dtype='float32') + conv_out = self._conv(x) + y = paddle.rand(shape=conv_out.shape, dtype='float16') + add_out = conv_out + y + linear_out = self._linear(add_out) + + self.assertEqual(conv_out.dtype, paddle.float16) + self.assertEqual(add_out.dtype, paddle.float16) + self.assertEqual(linear_out.dtype, paddle.float32) + + def test_o2_use_promote_off(self): + with paddle.amp.auto_cast(level='O2', use_promote=False): + x = paddle.rand(shape=[1, 1, 6, 6], dtype='float32') + conv_out = self._conv(x) + y = paddle.rand(shape=conv_out.shape, dtype='float16') + add_out = conv_out + y + linear_out = self._linear(add_out) + + self.assertEqual(conv_out.dtype, paddle.float16) + self.assertEqual(add_out.dtype, paddle.float16) + self.assertEqual(linear_out.dtype, paddle.float16) if __name__ == '__main__': diff --git a/test/amp/test_model_cast_to_bf16.py b/test/amp/test_model_cast_to_bf16.py index 7e4de2630d4..373a5ed06ac 100644 --- a/test/amp/test_model_cast_to_bf16.py +++ b/test/amp/test_model_cast_to_bf16.py @@ -310,6 +310,7 @@ class TestStaticBF16(AmpTestBase): exe, x_np, max_iters, + "bfloat16", level, ) return losses -- GitLab
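A minimal usage sketch of the new `use_promote` switch (assuming a CUDA build of Paddle with float16 kernels; the layers, shapes, and dtype outcomes restate the assertions of `TestEagerAmpPromoteSimple` above):

import paddle

conv = paddle.nn.Conv2D(
    in_channels=1, out_channels=6, kernel_size=3, bias_attr=False
)
linear = paddle.nn.Linear(in_features=4, out_features=4)

# Default O2 behavior: promote is on, so an op that sees a float32 input
# (here the undecorated float32 linear weight) falls back to a float32 kernel.
with paddle.amp.auto_cast(level='O2'):
    x = paddle.rand(shape=[1, 1, 6, 6], dtype='float32')
    conv_out = conv(x)  # paddle.float16 (conv2d is in the allow list)
    add_out = conv_out + paddle.rand(shape=conv_out.shape, dtype='float16')  # paddle.float16
    linear_out = linear(add_out)  # paddle.float32 (promoted)

# With promote switched off, O2 keeps float16 for such ops; only the block
# list still forces float32.
with paddle.amp.auto_cast(level='O2', use_promote=False):
    x = paddle.rand(shape=[1, 1, 6, 6], dtype='float32')
    conv_out = conv(x)  # paddle.float16
    add_out = conv_out + paddle.rand(shape=conv_out.shape, dtype='float16')  # paddle.float16
    linear_out = linear(add_out)  # paddle.float16 (no promotion)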
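To check which kernels actually ran, the operator-stats helpers exercised by `TestEagerAmpPromoteStats` can be wrapped around one training step. The sketch below makes the same CUDA/float16 assumptions; `TinyNet` and its layer sizes are an illustrative stand-in for the test's `SimpleConvNet`:

import paddle


class TinyNet(paddle.nn.Layer):
    def __init__(self):
        super().__init__()
        self.conv = paddle.nn.Conv2D(
            in_channels=1, out_channels=6, kernel_size=3, bias_attr=False
        )
        self.linear = paddle.nn.Linear(in_features=96, out_features=4)

    def forward(self, x):
        out = self.conv(x)
        # The explicit float32 cast makes relu a promotion candidate.
        out = paddle.nn.functional.relu(out.cast("float32"))
        out = out.flatten(start_axis=1, stop_axis=3)
        return self.linear(out)


model = TinyNet()
optimizer = paddle.optimizer.AdamW(
    learning_rate=0.01, parameters=model.parameters()
)
model, optimizer = paddle.amp.decorate(
    models=model, optimizers=optimizer, level='O2', dtype='float16'
)
scaler = paddle.amp.GradScaler()
model.train()

paddle.amp.debugging.enable_operator_stats_collection()
with paddle.amp.auto_cast(level='O2', use_promote=False):
    x = paddle.rand(shape=[1, 1, 6, 6], dtype='float32')
    loss = paddle.mean(model(x))
    scaled = scaler.scale(loss)
    scaled.backward()
    scaler.minimize(optimizer, scaled)
    optimizer.clear_grad()
paddle.amp.debugging.disable_operator_stats_collection()

# Each entry maps an op type to its call counts; the _check_op_calls helper
# above shows how the float16 count is extracted. With use_promote=False,
# "relu" reports one float16 call; with the default use_promote=True it
# reports zero, because its float32 input promotes it to a float32 kernel.
op_stats = paddle.fluid.core.get_low_precision_op_list()
print(op_stats["relu"])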