diff --git a/paddle/fluid/eager/amp_utils.h b/paddle/fluid/eager/amp_utils.h
index 115811f6a3d8e4ace1e39a769e3259448fc4c766..7b8071ee6015c5e080f651d02af9693cae2078e0 100644
--- a/paddle/fluid/eager/amp_utils.h
+++ b/paddle/fluid/eager/amp_utils.h
@@ -85,6 +85,42 @@ static inline paddle::experimental::DataType GetPromoteType(
   return dst_type;
 }
 
+inline paddle::experimental::DataType GetDtypeWithPlace(
+    const std::string& op_name,
+    const paddle::small_vector<std::vector<paddle::experimental::Tensor>,
+                               kSlotSmallVectorSize>& amp_tensors_vector,
+    const paddle::experimental::DataType amp_dtype) {
+  if (amp_dtype == paddle::experimental::DataType::FLOAT32) {
+    return amp_dtype;
+  }
+  bool is_right_place = false;
+  for (const auto& tensors : amp_tensors_vector) {
+    for (const auto& tensor : tensors) {
+      auto place = tensor.place();
+      is_right_place = (paddle::platform::is_gpu_place(place) ||
+                        paddle::platform::is_cuda_pinned_place(place) ||
+                        paddle::platform::is_xpu_place(place) ||
+                        paddle::platform::is_mlu_place(place) ||
+                        paddle::platform::is_npu_place(place) ||
+                        paddle::platform::is_npu_pinned_place(place) ||
+                        paddle::platform::is_custom_place(place));
+      if (is_right_place) {
+        break;
+      }
+    }
+    if (is_right_place) {
+      break;
+    }
+  }
+
+  if (!is_right_place) {
+    VLOG(6) << "Change " << op_name << "'s AMP type from " << amp_dtype
+            << " to FP32";
+    return paddle::experimental::DataType::FLOAT32;
+  }
+  return amp_dtype;
+}
+
 inline paddle::experimental::DataType GetAmpDestDtype(
     const std::string& op_name,
     const paddle::small_vector<std::vector<paddle::experimental::Tensor>,
@@ -95,19 +131,21 @@ inline paddle::experimental::DataType GetAmpDestDtype(
   VLOG(6) << "AMP GetAmpDestDtype:"
           << " op(" << op_name << ") amp_dtype(" << amp_dtype << ") amp_level("
           << static_cast<int>(amp_level) << ").";
+  auto return_amp_type = paddle::experimental::DataType::FLOAT16;
+
   if (amp_dtype == "float16") {
     if (amp_level == paddle::imperative::AmpLevel::O1) {
       if (paddle::imperative::AmpOperators::Instance()
               .GetMutableAllowOps()
               ->count(op_name)) {
-        return paddle::experimental::DataType::FLOAT16;
+        return_amp_type = paddle::experimental::DataType::FLOAT16;
       } else if (paddle::imperative::AmpOperators::Instance()
                      .GetMutableBlockOps()
                      ->count(op_name) ||
                  paddle::imperative::AmpOperators::Instance()
                      .GetMutableUnsupportedFp16Ops()
                      ->count(op_name)) {
-        return paddle::experimental::DataType::FLOAT32;
+        return_amp_type = paddle::experimental::DataType::FLOAT32;
       } else {
         auto dst_type = GetPromoteType(op_name,
                                        amp_tensors_vector,
@@ -118,7 +156,7 @@ inline paddle::experimental::DataType GetAmpDestDtype(
                 ->count(op_name)) {
           dst_type = paddle::experimental::DataType::FLOAT32;
         }
-        return dst_type;
+        return_amp_type = dst_type;
       }
     } else if (amp_level == paddle::imperative::AmpLevel::O2) {
       auto dst_type = paddle::experimental::DataType::FLOAT16;
@@ -130,18 +168,18 @@ inline paddle::experimental::DataType GetAmpDestDtype(
               ->count(op_name)) {
         dst_type = paddle::experimental::DataType::FLOAT32;
       }
-      return dst_type;
+      return_amp_type = dst_type;
     }
   } else if (amp_dtype == "bfloat16") {
     if (amp_level == paddle::imperative::AmpLevel::O1) {
       if (paddle::imperative::AmpOperators::Instance()
               .GetMutableAllowOps()
               ->count(op_name)) {
-        return paddle::experimental::DataType::BFLOAT16;
+        return_amp_type = paddle::experimental::DataType::BFLOAT16;
       } else if (paddle::imperative::AmpOperators::Instance()
                      .GetMutableBlockOps()
                      ->count(op_name)) {
-        return paddle::experimental::DataType::FLOAT32;
+        return_amp_type = paddle::experimental::DataType::FLOAT32;
       } else {
         auto dst_type =
             GetPromoteType(op_name,
@@ -153,7 +191,7 @@ inline paddle::experimental::DataType GetAmpDestDtype(
                 ->count(op_name)) {
           dst_type = paddle::experimental::DataType::FLOAT32;
         }
-        return dst_type;
+        return_amp_type = dst_type;
       }
     } else if (amp_level == paddle::imperative::AmpLevel::O2) {
       auto dst_type = paddle::experimental::DataType::BFLOAT16;
@@ -165,10 +203,12 @@ inline paddle::experimental::DataType GetAmpDestDtype(
               ->count(op_name)) {
         dst_type = paddle::experimental::DataType::FLOAT32;
       }
-      return dst_type;
+      return_amp_type = dst_type;
     }
+  } else {
+    return_amp_type = paddle::experimental::DataType::FLOAT32;
   }
-  return paddle::experimental::DataType::FLOAT32;
+  return GetDtypeWithPlace(op_name, amp_tensors_vector, return_amp_type);
 }
 
 }  // namespace egr
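To make the intent of GetDtypeWithPlace concrete: the helper keeps a low-precision AMP dtype only when at least one input tensor lives on an AMP-capable device (GPU/XPU/MLU/NPU/custom in the real check), and otherwise falls back to FP32. Below is a minimal standalone sketch of that rule; the Place, DataType, and Tensor names are simplified stand-ins, not Paddle's real types.

#include <iostream>
#include <vector>

enum class Place { kCPU, kGPU, kXPU, kCustom };
enum class DataType { kFLOAT32, kFLOAT16, kBFLOAT16 };

struct Tensor {
  Place place;
};

// True for the device kinds the patch treats as AMP-capable.
bool IsAmpCapablePlace(Place p) { return p != Place::kCPU; }

// Keep amp_dtype only if some input sits on an AMP-capable place.
DataType GetDtypeWithPlace(const std::vector<std::vector<Tensor>>& inputs,
                           DataType amp_dtype) {
  if (amp_dtype == DataType::kFLOAT32) {
    return amp_dtype;  // FP32 never needs the place-based downgrade.
  }
  for (const auto& tensors : inputs) {
    for (const auto& t : tensors) {
      if (IsAmpCapablePlace(t.place)) {
        return amp_dtype;  // Any AMP-capable input keeps the AMP dtype.
      }
    }
  }
  return DataType::kFLOAT32;  // CPU-only inputs: run the op in FP32.
}

int main() {
  std::vector<std::vector<Tensor>> cpu_only = {{{Place::kCPU}}};
  std::vector<std::vector<Tensor>> mixed = {{{Place::kCPU}, {Place::kGPU}}};
  // CPU-only inputs are downgraded to FP32; one GPU input keeps FP16.
  std::cout << (GetDtypeWithPlace(cpu_only, DataType::kFLOAT16) ==
                DataType::kFLOAT32)
            << (GetDtypeWithPlace(mixed, DataType::kFLOAT16) ==
                DataType::kFLOAT16)
            << "\n";  // prints "11"
}

The early return inside the loop expresses the same "any AMP-capable input wins" rule that the patch implements with the is_right_place flag and the two breaks.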
diff --git a/paddle/fluid/eager/eager_amp_auto_cast.h b/paddle/fluid/eager/eager_amp_auto_cast.h
index ea3e53b972d99516de463f37df31a98d471d84a0..e80daf69d9c9e5fde3fc4744cd351c5042f569f0 100644
--- a/paddle/fluid/eager/eager_amp_auto_cast.h
+++ b/paddle/fluid/eager/eager_amp_auto_cast.h
@@ -22,14 +22,19 @@ static inline bool NeedCast(const paddle::experimental::Tensor& tensor,
                             const paddle::experimental::DataType& dst_dtype) {
   auto place = tensor.place();
   auto data_type = tensor.dtype();
+  // Except for the CPU check, these conditions should stay consistent with
+  // the place checks in amp_utils.h
   if (paddle::platform::is_gpu_place(place) ||
       paddle::platform::is_cuda_pinned_place(place) ||
       paddle::platform::is_xpu_place(place) ||
       paddle::platform::is_mlu_place(place) ||
       paddle::platform::is_npu_place(place) ||
       paddle::platform::is_npu_pinned_place(place) ||
-      paddle::platform::is_custom_place(place)) {
+      paddle::platform::is_custom_place(place) ||
+      paddle::platform::is_cpu_place(place)) {
     // CudaPinndePlace is added for varbase created by dataloader
+    // CPU place handles mixed-place inputs, e.g. when input1 is on CPU and
+    // input2 is on GPU
     if ((data_type == paddle::experimental::DataType::FLOAT32 ||
          data_type == paddle::experimental::DataType::FLOAT16 ||
          data_type == paddle::experimental::DataType::BFLOAT16) &&
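On the NeedCast side, adding is_cpu_place means a floating-point CPU input can now be auto-cast under AMP instead of silently keeping its dtype while its GPU siblings are cast. A minimal sketch of the decision, again with simplified stand-in types rather than Paddle's real API:

#include <iostream>

enum class Place { kCPU, kGPU, kCUDAPinned };
enum class DataType { kFLOAT32, kFLOAT16, kBFLOAT16, kINT64 };

struct Tensor {
  Place place;
  DataType dtype;
};

// Cast when the tensor sits on a place AMP cares about (now including CPU),
// holds a floating-point dtype, and that dtype differs from the destination.
bool NeedCast(const Tensor& t, DataType dst_dtype) {
  bool place_ok = t.place == Place::kGPU || t.place == Place::kCUDAPinned ||
                  t.place == Place::kCPU;
  bool float_dtype = t.dtype == DataType::kFLOAT32 ||
                     t.dtype == DataType::kFLOAT16 ||
                     t.dtype == DataType::kBFLOAT16;
  return place_ok && float_dtype && t.dtype != dst_dtype;
}

int main() {
  Tensor cpu_fp32{Place::kCPU, DataType::kFLOAT32};
  Tensor gpu_fp16{Place::kGPU, DataType::kFLOAT16};
  Tensor gpu_i64{Place::kGPU, DataType::kINT64};
  std::cout << NeedCast(cpu_fp32, DataType::kFLOAT16)  // 1: cast CPU fp32
            << NeedCast(gpu_fp16, DataType::kFLOAT16)  // 0: already fp16
            << NeedCast(gpu_i64, DataType::kFLOAT16)   // 0: not a float dtype
            << "\n";  // prints "100"
}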