From ca8e21a685d5f2d7f3b21ceddee13d45253452fa Mon Sep 17 00:00:00 2001
From: Zhang Ting
Date: Tue, 14 Mar 2023 10:57:10 +0800
Subject: [PATCH] polish the amp code (#51020)

---
 paddle/fluid/eager/amp_utils.h           | 112 +++++++----------------
 paddle/fluid/imperative/amp_auto_cast.cc |  16 ++++
 paddle/fluid/imperative/amp_auto_cast.h  |   3 +
 paddle/fluid/imperative/tracer.h         |   2 +
 4 files changed, 53 insertions(+), 80 deletions(-)

diff --git a/paddle/fluid/eager/amp_utils.h b/paddle/fluid/eager/amp_utils.h
index 1579b884a20..dd94b1e10a2 100644
--- a/paddle/fluid/eager/amp_utils.h
+++ b/paddle/fluid/eager/amp_utils.h
@@ -122,90 +122,42 @@ inline paddle::experimental::DataType GetAmpDestDtype(
     const std::string& op_name,
     const paddle::small_vector<std::vector<paddle::Tensor>,
                                kSlotSmallVectorSize>& amp_tensors_vector) {
-  auto amp_dtype =
-      egr::Controller::Instance().GetCurrentTracer()->GetAmpDtype();
   auto amp_level = egr::Controller::Instance().GetAMPLevel();
-  VLOG(6) << "AMP GetAmpDestDtype:"
-          << " op(" << op_name << ") amp_dtype(" << amp_dtype << ") amp_level("
-          << static_cast<int>(amp_level) << ").";
-  auto return_amp_type = paddle::experimental::DataType::FLOAT16;
-
-  if (amp_dtype == "float16") {
-    if (amp_level == paddle::imperative::AmpLevel::O1) {
-      if (paddle::imperative::AmpOperators::Instance()
-              .GetMutableAllowOps()
-              ->count(op_name)) {
-        return_amp_type = paddle::experimental::DataType::FLOAT16;
-      } else if (paddle::imperative::AmpOperators::Instance()
-                     .GetMutableBlockOps()
-                     ->count(op_name) ||
-                 paddle::imperative::AmpOperators::Instance()
-                     .GetMutableUnsupportedFp16Ops()
-                     ->count(op_name)) {
-        return_amp_type = paddle::experimental::DataType::FLOAT32;
-      } else {
-        auto dst_type = GetPromoteType(op_name,
-                                       amp_tensors_vector,
-                                       paddle::experimental::DataType::FLOAT16);
-        if (dst_type == paddle::experimental::DataType::FLOAT16 &&
-            paddle::imperative::AmpOperators::Instance()
-                .GetMutableUnsupportedFp16Ops()
-                ->count(op_name)) {
-          dst_type = paddle::experimental::DataType::FLOAT32;
-        }
-        return_amp_type = dst_type;
-      }
-    } else if (amp_level == paddle::imperative::AmpLevel::O2) {
-      auto dst_type = paddle::experimental::DataType::FLOAT16;
-      if (paddle::imperative::AmpOperators::Instance()
-              .GetMutableUnsupportedFp16Ops()
-              ->count(op_name) ||
-          paddle::imperative::AmpOperators::Instance()
-              .GetMutableBlockOps()
-              ->count(op_name)) {
-        dst_type = paddle::experimental::DataType::FLOAT32;
-      }
-      return_amp_type = dst_type;
+  auto amp_setting_dtype =
+      egr::Controller::Instance().GetCurrentTracer()->GetAmpPhiDtype();
+  auto dst_type = amp_setting_dtype;
+  if (amp_level == paddle::imperative::AmpLevel::O1) {
+    if (paddle::imperative::AmpOperators::Instance()
+            .GetMutableAllowOps()
+            ->count(op_name)) {
+      dst_type = amp_setting_dtype;
+    } else if (paddle::imperative::AmpOperators::Instance()
+                   .GetMutableBlockOps()
+                   ->count(op_name)) {
+      dst_type = paddle::experimental::DataType::FLOAT32;
+    } else {
+      dst_type = GetPromoteType(op_name, amp_tensors_vector, amp_setting_dtype);
     }
-  } else if (amp_dtype == "bfloat16") {
-    if (amp_level == paddle::imperative::AmpLevel::O1) {
-      if (paddle::imperative::AmpOperators::Instance()
-              .GetMutableAllowOps()
-              ->count(op_name)) {
-        return_amp_type = paddle::experimental::DataType::BFLOAT16;
-      } else if (paddle::imperative::AmpOperators::Instance()
-                     .GetMutableBlockOps()
-                     ->count(op_name)) {
-        return_amp_type = paddle::experimental::DataType::FLOAT32;
-      } else {
-        auto dst_type =
-            GetPromoteType(op_name,
-                           amp_tensors_vector,
-                           paddle::experimental::DataType::BFLOAT16);
-        if (dst_type == paddle::experimental::DataType::BFLOAT16 &&
-            paddle::imperative::AmpOperators::Instance()
-                .GetMutableUnsupportedBf16Ops()
-                ->count(op_name)) {
-          dst_type = paddle::experimental::DataType::FLOAT32;
-        }
-        return_amp_type = dst_type;
-      }
-    } else if (amp_level == paddle::imperative::AmpLevel::O2) {
-      auto dst_type = paddle::experimental::DataType::BFLOAT16;
-      if (paddle::imperative::AmpOperators::Instance()
-              .GetMutableUnsupportedBf16Ops()
-              ->count(op_name) ||
-          paddle::imperative::AmpOperators::Instance()
-              .GetMutableBlockOps()
-              ->count(op_name)) {
-        dst_type = paddle::experimental::DataType::FLOAT32;
-      }
-      return_amp_type = dst_type;
+  } else if (amp_level == paddle::imperative::AmpLevel::O2) {
+    if (paddle::imperative::AmpOperators::Instance()
+            .GetMutableBlockOps()
+            ->count(op_name)) {
+      dst_type = paddle::experimental::DataType::FLOAT32;
     }
-  } else {
-    return_amp_type = paddle::experimental::DataType::FLOAT32;
   }
-  return GetDtypeWithPlace(op_name, amp_tensors_vector, return_amp_type);
+
+  if (dst_type == amp_setting_dtype &&
+      (paddle::imperative::AmpOperators::Instance()
+           .GetMutableUnsupportedOps(amp_setting_dtype)
+           ->count(op_name))) {
+    dst_type = paddle::experimental::DataType::FLOAT32;
+  }
+
+  dst_type = GetDtypeWithPlace(op_name, amp_tensors_vector, dst_type);
+  VLOG(6) << "AMP GetAmpDestDtype:"
+          << " op(" << op_name << ") amp_dtype(" << dst_type << ") amp_level("
+          << static_cast<int>(amp_level) << ").";
+  return dst_type;
 }
 
 }  // namespace egr
diff --git a/paddle/fluid/imperative/amp_auto_cast.cc b/paddle/fluid/imperative/amp_auto_cast.cc
index 40e3b12cc4b..48b51849d46 100644
--- a/paddle/fluid/imperative/amp_auto_cast.cc
+++ b/paddle/fluid/imperative/amp_auto_cast.cc
@@ -200,6 +200,22 @@ AmpOperators::GetMutableBlockOps() {
   return block_ops_;
 }
 
+std::shared_ptr<std::unordered_set<std::string>>
+AmpOperators::GetMutableUnsupportedOps(
+    const paddle::experimental::DataType& data_type) {
+  PADDLE_ENFORCE_EQ(
+      data_type == paddle::experimental::DataType::FLOAT16 ||
+          data_type == paddle::experimental::DataType::BFLOAT16,
+      true,
+      phi::errors::InvalidArgument(
+          "The data_type mismatch. It should be FLOAT16 or BFLOAT16."));
+  if (data_type == paddle::experimental::DataType::FLOAT16) {
+    return unsupported_fp16_ops_;
+  } else {
+    return unsupported_bf16_ops_;
+  }
+}
+
 std::shared_ptr<std::unordered_set<std::string>>
 AmpOperators::GetMutableUnsupportedFp16Ops() {
   return unsupported_fp16_ops_;
diff --git a/paddle/fluid/imperative/amp_auto_cast.h b/paddle/fluid/imperative/amp_auto_cast.h
index 3bee2308603..e39190d5c64 100644
--- a/paddle/fluid/imperative/amp_auto_cast.h
+++ b/paddle/fluid/imperative/amp_auto_cast.h
@@ -54,6 +54,9 @@ class AmpOperators {
 
   std::shared_ptr<std::unordered_set<std::string>> GetMutableBlockOps();
 
+  std::shared_ptr<std::unordered_set<std::string>> GetMutableUnsupportedOps(
+      const paddle::experimental::DataType& data_type);
+
   std::shared_ptr<std::unordered_set<std::string>>
   GetMutableUnsupportedFp16Ops();
 
diff --git a/paddle/fluid/imperative/tracer.h b/paddle/fluid/imperative/tracer.h
index 2831e007d94..943505955e8 100644
--- a/paddle/fluid/imperative/tracer.h
+++ b/paddle/fluid/imperative/tracer.h
@@ -184,6 +184,8 @@ class Tracer {
     }
   }
 
+  phi::DataType GetAmpPhiDtype() const { return amp_dtype_; }
+
   void DisableLayoutAutoTune() { use_layout_autotune_ = false; }
 
   void EnableLayoutAutoTune() { use_layout_autotune_ = true; }
-- 
GitLab
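
Reviewer note: the patch above collapses the separate "float16"/"bfloat16" string
branches in GetAmpDestDtype into one path by (a) reading the AMP dtype from the
tracer as a phi::DataType (GetAmpPhiDtype) and (b) adding a single accessor,
GetMutableUnsupportedOps(data_type), that selects the matching unsupported-op
set. The standalone sketch below illustrates that dtype-keyed lookup; DataType,
AmpOperators, and the op names here are simplified stand-ins for the real
Paddle types, not the library's API.

// Standalone sketch of the dtype-keyed lookup added by this patch.
#include <cassert>
#include <iostream>
#include <memory>
#include <string>
#include <unordered_set>

enum class DataType { FLOAT16, BFLOAT16, FLOAT32 };

class AmpOperators {
 public:
  AmpOperators()
      : unsupported_fp16_ops_(
            std::make_shared<std::unordered_set<std::string>>(
                std::unordered_set<std::string>{"lookup_table"})),
        unsupported_bf16_ops_(
            std::make_shared<std::unordered_set<std::string>>(
                std::unordered_set<std::string>{"scatter"})) {}

  // One accessor keyed by dtype replaces the two per-dtype getters
  // (GetMutableUnsupportedFp16Ops / GetMutableUnsupportedBf16Ops).
  std::shared_ptr<std::unordered_set<std::string>> GetMutableUnsupportedOps(
      DataType data_type) {
    assert(data_type == DataType::FLOAT16 || data_type == DataType::BFLOAT16);
    return data_type == DataType::FLOAT16 ? unsupported_fp16_ops_
                                          : unsupported_bf16_ops_;
  }

 private:
  std::shared_ptr<std::unordered_set<std::string>> unsupported_fp16_ops_;
  std::shared_ptr<std::unordered_set<std::string>> unsupported_bf16_ops_;
};

int main() {
  AmpOperators ops;
  // The caller no longer branches on "float16" vs "bfloat16" strings; the
  // dtype configured on the tracer selects the matching unsupported-op set.
  const DataType amp_setting_dtype = DataType::FLOAT16;
  DataType dst_type = amp_setting_dtype;
  if (ops.GetMutableUnsupportedOps(amp_setting_dtype)->count("lookup_table")) {
    dst_type = DataType::FLOAT32;  // unsupported op falls back to fp32
  }
  std::cout << (dst_type == DataType::FLOAT32 ? "fp32 fallback"
                                              : "low precision")
            << std::endl;
  return 0;
}

Keying the set on the tracer's dtype means adding a future low-precision dtype
only requires extending one accessor, rather than duplicating the whole
O1/O2 decision tree again.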