diff --git a/paddle/fluid/eager/amp_utils.h b/paddle/fluid/eager/amp_utils.h
index 1579b884a209765356d95398369138a0f8a08260..dd94b1e10a232d929e651cdd6649bc6d4c29d60f 100644
--- a/paddle/fluid/eager/amp_utils.h
+++ b/paddle/fluid/eager/amp_utils.h
@@ -122,90 +122,42 @@ inline paddle::experimental::DataType GetAmpDestDtype(
     const std::string& op_name,
     const paddle::small_vector<std::vector<paddle::experimental::Tensor>,
                                kSlotSmallVectorSize>& amp_tensors_vector) {
-  auto amp_dtype =
-      egr::Controller::Instance().GetCurrentTracer()->GetAmpDtype();
   auto amp_level = egr::Controller::Instance().GetAMPLevel();
-  VLOG(6) << "AMP GetAmpDestDtype:"
-          << " op(" << op_name << ") amp_dtype(" << amp_dtype << ") amp_level("
-          << static_cast<int>(amp_level) << ").";
-  auto return_amp_type = paddle::experimental::DataType::FLOAT16;
-
-  if (amp_dtype == "float16") {
-    if (amp_level == paddle::imperative::AmpLevel::O1) {
-      if (paddle::imperative::AmpOperators::Instance()
-              .GetMutableAllowOps()
-              ->count(op_name)) {
-        return_amp_type = paddle::experimental::DataType::FLOAT16;
-      } else if (paddle::imperative::AmpOperators::Instance()
-                     .GetMutableBlockOps()
-                     ->count(op_name) ||
-                 paddle::imperative::AmpOperators::Instance()
-                     .GetMutableUnsupportedFp16Ops()
-                     ->count(op_name)) {
-        return_amp_type = paddle::experimental::DataType::FLOAT32;
-      } else {
-        auto dst_type = GetPromoteType(op_name,
-                                       amp_tensors_vector,
-                                       paddle::experimental::DataType::FLOAT16);
-        if (dst_type == paddle::experimental::DataType::FLOAT16 &&
-            paddle::imperative::AmpOperators::Instance()
-                .GetMutableUnsupportedFp16Ops()
-                ->count(op_name)) {
-          dst_type = paddle::experimental::DataType::FLOAT32;
-        }
-        return_amp_type = dst_type;
-      }
-    } else if (amp_level == paddle::imperative::AmpLevel::O2) {
-      auto dst_type = paddle::experimental::DataType::FLOAT16;
-      if (paddle::imperative::AmpOperators::Instance()
-              .GetMutableUnsupportedFp16Ops()
-              ->count(op_name) ||
-          paddle::imperative::AmpOperators::Instance()
-              .GetMutableBlockOps()
-              ->count(op_name)) {
-        dst_type = paddle::experimental::DataType::FLOAT32;
-      }
-      return_amp_type = dst_type;
+  auto amp_setting_dtype =
+      egr::Controller::Instance().GetCurrentTracer()->GetAmpPhiDtype();
+  auto dst_type = amp_setting_dtype;
+  if (amp_level == paddle::imperative::AmpLevel::O1) {
+    if (paddle::imperative::AmpOperators::Instance()
+            .GetMutableAllowOps()
+            ->count(op_name)) {
+      dst_type = amp_setting_dtype;
+    } else if (paddle::imperative::AmpOperators::Instance()
+                   .GetMutableBlockOps()
+                   ->count(op_name)) {
+      dst_type = paddle::experimental::DataType::FLOAT32;
+    } else {
+      dst_type = GetPromoteType(op_name, amp_tensors_vector, amp_setting_dtype);
     }
-  } else if (amp_dtype == "bfloat16") {
-    if (amp_level == paddle::imperative::AmpLevel::O1) {
-      if (paddle::imperative::AmpOperators::Instance()
-              .GetMutableAllowOps()
-              ->count(op_name)) {
-        return_amp_type = paddle::experimental::DataType::BFLOAT16;
-      } else if (paddle::imperative::AmpOperators::Instance()
-                     .GetMutableBlockOps()
-                     ->count(op_name)) {
-        return_amp_type = paddle::experimental::DataType::FLOAT32;
-      } else {
-        auto dst_type =
-            GetPromoteType(op_name,
-                           amp_tensors_vector,
-                           paddle::experimental::DataType::BFLOAT16);
-        if (dst_type == paddle::experimental::DataType::BFLOAT16 &&
-            paddle::imperative::AmpOperators::Instance()
-                .GetMutableUnsupportedBf16Ops()
-                ->count(op_name)) {
-          dst_type = paddle::experimental::DataType::FLOAT32;
-        }
-        return_amp_type = dst_type;
-      }
-    } else if (amp_level == paddle::imperative::AmpLevel::O2) {
-      auto dst_type = paddle::experimental::DataType::BFLOAT16;
-      if (paddle::imperative::AmpOperators::Instance()
-              .GetMutableUnsupportedBf16Ops()
-              ->count(op_name) ||
-          paddle::imperative::AmpOperators::Instance()
-              .GetMutableBlockOps()
-              ->count(op_name)) {
-        dst_type = paddle::experimental::DataType::FLOAT32;
-      }
-      return_amp_type = dst_type;
+  } else if (amp_level == paddle::imperative::AmpLevel::O2) {
+    if (paddle::imperative::AmpOperators::Instance()
+            .GetMutableBlockOps()
+            ->count(op_name)) {
+      dst_type = paddle::experimental::DataType::FLOAT32;
     }
-  } else {
-    return_amp_type = paddle::experimental::DataType::FLOAT32;
   }
-  return GetDtypeWithPlace(op_name, amp_tensors_vector, return_amp_type);
+
+  if (dst_type == amp_setting_dtype &&
+      (paddle::imperative::AmpOperators::Instance()
+           .GetMutableUnsupportedOps(amp_setting_dtype)
+           ->count(op_name))) {
+    dst_type = paddle::experimental::DataType::FLOAT32;
+  }
+
+  dst_type = GetDtypeWithPlace(op_name, amp_tensors_vector, dst_type);
+  VLOG(6) << "AMP GetAmpDestDtype:"
+          << " op(" << op_name << ") amp_dtype(" << dst_type << ") amp_level("
+          << static_cast<int>(amp_level) << ").";
+  return dst_type;
 }
 
 }  // namespace egr
diff --git a/paddle/fluid/imperative/amp_auto_cast.cc b/paddle/fluid/imperative/amp_auto_cast.cc
index 40e3b12cc4b9311a393ff6280552150184fdde97..48b51849d462a9bc425df3d5b5aaceac637fd528 100644
--- a/paddle/fluid/imperative/amp_auto_cast.cc
+++ b/paddle/fluid/imperative/amp_auto_cast.cc
@@ -200,6 +200,22 @@ AmpOperators::GetMutableBlockOps() {
   return block_ops_;
 }
 
+std::shared_ptr<std::unordered_set<std::string>>
+AmpOperators::GetMutableUnsupportedOps(
+    const paddle::experimental::DataType& data_type) {
+  PADDLE_ENFORCE_EQ(
+      data_type == paddle::experimental::DataType::FLOAT16 ||
+          data_type == paddle::experimental::DataType::BFLOAT16,
+      true,
+      phi::errors::InvalidArgument(
+          "The data_type mismatch. It should be FLOAT16 or BFLOAT16."));
+  if (data_type == paddle::experimental::DataType::FLOAT16) {
+    return unsupported_fp16_ops_;
+  } else {
+    return unsupported_bf16_ops_;
+  }
+}
+
 std::shared_ptr<std::unordered_set<std::string>>
 AmpOperators::GetMutableUnsupportedFp16Ops() {
   return unsupported_fp16_ops_;
diff --git a/paddle/fluid/imperative/amp_auto_cast.h b/paddle/fluid/imperative/amp_auto_cast.h
index 3bee2308603f9ed0bf98acab0a3c700ea8b737ca..e39190d5c64770ddb49534a7dd55219e7454ab6e 100644
--- a/paddle/fluid/imperative/amp_auto_cast.h
+++ b/paddle/fluid/imperative/amp_auto_cast.h
@@ -54,6 +54,9 @@ class AmpOperators {
 
   std::shared_ptr<std::unordered_set<std::string>> GetMutableBlockOps();
 
+  std::shared_ptr<std::unordered_set<std::string>> GetMutableUnsupportedOps(
+      const paddle::experimental::DataType& data_type);
+
   std::shared_ptr<std::unordered_set<std::string>>
   GetMutableUnsupportedFp16Ops();
 
diff --git a/paddle/fluid/imperative/tracer.h b/paddle/fluid/imperative/tracer.h
index 2831e007d94c46828727093490a30b21a4ac8934..943505955e83847110a12486cecef6ac287a7cae 100644
--- a/paddle/fluid/imperative/tracer.h
+++ b/paddle/fluid/imperative/tracer.h
@@ -184,6 +184,8 @@ class Tracer {
     }
   }
 
+  phi::DataType GetAmpPhiDtype() const { return amp_dtype_; }
+
   void DisableLayoutAutoTune() { use_layout_autotune_ = false; }
 
   void EnableLayoutAutoTune() { use_layout_autotune_ = true; }