Unverified  Commit 3d5faa88, authored by niuliling123, committed by GitHub

Add Cpu tensor cast when amp_type isn't float32 (#50401)

Parent bf38175e
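
This commit adds a place check to eager-mode AMP: if none of an op's input tensors lives on an accelerator place (GPU, CUDA-pinned, XPU, MLU, NPU, NPU-pinned, or a custom device), the chosen low-precision dtype is dropped back to FLOAT32; symmetrically, NeedCast now also accepts CPU tensors, so a CPU input mixed with GPU inputs gets cast along with the rest. Below is a minimal standalone sketch of the new fallback rule; the Place and DataType enums and the flattened place list are simplified stand-ins for the real Paddle types, not the actual API.

// Standalone sketch of the GetDtypeWithPlace fallback rule. Place and
// DataType below are simplified stand-ins for the real Paddle types.
#include <iostream>
#include <string>
#include <vector>

enum class Place { kCPU, kGPU, kXPU, kCustom };
enum class DataType { kFLOAT32, kFLOAT16, kBFLOAT16 };

DataType GetDtypeWithPlaceSketch(const std::string& op_name,
                                 const std::vector<Place>& input_places,
                                 DataType amp_dtype) {
  if (amp_dtype == DataType::kFLOAT32) {
    return amp_dtype;  // nothing to do: FP32 is already the fallback
  }
  // Keep the low-precision dtype only if some input sits on an accelerator.
  bool is_right_place = false;
  for (Place p : input_places) {
    if (p != Place::kCPU) {  // models the is_gpu_place/is_xpu_place/... chain
      is_right_place = true;
      break;
    }
  }
  if (!is_right_place) {
    std::cout << "Change " << op_name << "'s AMP type to FP32\n";
    return DataType::kFLOAT32;
  }
  return amp_dtype;
}

int main() {
  // All inputs on CPU: the FP16 choice is overridden back to FP32.
  GetDtypeWithPlaceSketch("matmul", {Place::kCPU, Place::kCPU},
                          DataType::kFLOAT16);
  // One GPU input: FP16 survives, and NeedCast will cast the CPU input too.
  auto d = GetDtypeWithPlaceSketch("matmul", {Place::kCPU, Place::kGPU},
                                   DataType::kFLOAT16);
  std::cout << (d == DataType::kFLOAT16 ? "kept FP16\n" : "fell back to FP32\n");
}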
@@ -85,6 +85,39 @@ static inline paddle::experimental::DataType GetPromoteType(
   return dst_type;
 }
 
+inline paddle::experimental::DataType GetDtypeWithPlace(
+    const std::string& op_name,
+    const paddle::small_vector<std::vector<paddle::experimental::Tensor>,
+                               kSlotSmallVectorSize>& amp_tensors_vector,
+    const paddle::experimental::DataType amp_dtype) {
+  if (amp_dtype == paddle::experimental::DataType::FLOAT32) {
+    return amp_dtype;
+  }
+  bool is_right_place = false;
+  for (const auto& tensors : amp_tensors_vector) {
+    for (const auto& tensor : tensors) {
+      auto place = tensor.place();
+      is_right_place = (paddle::platform::is_gpu_place(place) ||
+                        paddle::platform::is_cuda_pinned_place(place) ||
+                        paddle::platform::is_xpu_place(place) ||
+                        paddle::platform::is_mlu_place(place) ||
+                        paddle::platform::is_npu_place(place) ||
+                        paddle::platform::is_npu_pinned_place(place) ||
+                        paddle::platform::is_custom_place(place));
+      if (is_right_place) {
+        break;
+      }
+    }
+  }
+
+  if (!is_right_place) {
+    VLOG(6) << "Change " << op_name << "'s AMP type from " << amp_dtype
+            << " to FP32";
+    return paddle::experimental::DataType::FLOAT32;
+  }
+  return amp_dtype;
+}
+
 inline paddle::experimental::DataType GetAmpDestDtype(
     const std::string& op_name,
     const paddle::small_vector<std::vector<paddle::experimental::Tensor>,
@@ -95,19 +128,21 @@ inline paddle::experimental::DataType GetAmpDestDtype(
   VLOG(6) << "AMP GetAmpDestDtype:"
           << " op(" << op_name << ") amp_dtype(" << amp_dtype << ") amp_level("
           << static_cast<int>(amp_level) << ").";
+  auto return_amp_type = paddle::experimental::DataType::FLOAT16;
+
   if (amp_dtype == "float16") {
     if (amp_level == paddle::imperative::AmpLevel::O1) {
       if (paddle::imperative::AmpOperators::Instance()
               .GetMutableAllowOps()
               ->count(op_name)) {
-        return paddle::experimental::DataType::FLOAT16;
+        return_amp_type = paddle::experimental::DataType::FLOAT16;
       } else if (paddle::imperative::AmpOperators::Instance()
                      .GetMutableBlockOps()
                      ->count(op_name) ||
                  paddle::imperative::AmpOperators::Instance()
                      .GetMutableUnsupportedFp16Ops()
                      ->count(op_name)) {
-        return paddle::experimental::DataType::FLOAT32;
+        return_amp_type = paddle::experimental::DataType::FLOAT32;
       } else {
         auto dst_type = GetPromoteType(op_name,
                                        amp_tensors_vector,
@@ -118,7 +153,7 @@ inline paddle::experimental::DataType GetAmpDestDtype(
             ->count(op_name)) {
           dst_type = paddle::experimental::DataType::FLOAT32;
         }
-        return dst_type;
+        return_amp_type = dst_type;
       }
     } else if (amp_level == paddle::imperative::AmpLevel::O2) {
       auto dst_type = paddle::experimental::DataType::FLOAT16;
@@ -130,18 +165,18 @@ inline paddle::experimental::DataType GetAmpDestDtype(
           ->count(op_name)) {
         dst_type = paddle::experimental::DataType::FLOAT32;
       }
-      return dst_type;
+      return_amp_type = dst_type;
     }
   } else if (amp_dtype == "bfloat16") {
     if (amp_level == paddle::imperative::AmpLevel::O1) {
       if (paddle::imperative::AmpOperators::Instance()
               .GetMutableAllowOps()
               ->count(op_name)) {
-        return paddle::experimental::DataType::BFLOAT16;
+        return_amp_type = paddle::experimental::DataType::BFLOAT16;
       } else if (paddle::imperative::AmpOperators::Instance()
                      .GetMutableBlockOps()
                      ->count(op_name)) {
-        return paddle::experimental::DataType::FLOAT32;
+        return_amp_type = paddle::experimental::DataType::FLOAT32;
       } else {
         auto dst_type =
             GetPromoteType(op_name,
@@ -153,7 +188,7 @@ inline paddle::experimental::DataType GetAmpDestDtype(
             ->count(op_name)) {
           dst_type = paddle::experimental::DataType::FLOAT32;
         }
-        return dst_type;
+        return_amp_type = dst_type;
       }
     } else if (amp_level == paddle::imperative::AmpLevel::O2) {
       auto dst_type = paddle::experimental::DataType::BFLOAT16;
@@ -165,10 +200,12 @@ inline paddle::experimental::DataType GetAmpDestDtype(
           ->count(op_name)) {
         dst_type = paddle::experimental::DataType::FLOAT32;
       }
-      return dst_type;
+      return_amp_type = dst_type;
     }
+  } else {
+    return_amp_type = paddle::experimental::DataType::FLOAT32;
   }
-  return paddle::experimental::DataType::FLOAT32;
+  return GetDtypeWithPlace(op_name, amp_tensors_vector, return_amp_type);
 }
 
 }  // namespace egr
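
The edits to GetAmpDestDtype above are a mechanical single-exit refactor: each early `return` becomes an assignment to `return_amp_type`, so the place-based fallback runs on every path. A toy model of the pattern follows; the op sets, enums, and stub are hypothetical stand-ins, not Paddle's AmpOperators registry.

// Toy model of the single-exit refactor in GetAmpDestDtype; the op sets and
// enums are hypothetical stand-ins, not Paddle's actual AmpOperators API.
#include <set>
#include <string>

enum class DataType { kFLOAT32, kFLOAT16 };

// Stands in for the place filter applied at the single exit point.
DataType GetDtypeWithPlaceSketch(DataType d) { return d; }

DataType GetAmpDestDtypeSketch(const std::string& op_name,
                               const std::set<std::string>& allow_ops,
                               const std::set<std::string>& block_ops) {
  DataType return_amp_type = DataType::kFLOAT16;
  if (allow_ops.count(op_name)) {
    return_amp_type = DataType::kFLOAT16;  // was: return FLOAT16;
  } else if (block_ops.count(op_name)) {
    return_amp_type = DataType::kFLOAT32;  // was: return FLOAT32;
  }
  // Single exit: the place-based fallback now covers every branch.
  return GetDtypeWithPlaceSketch(return_amp_type);
}

int main() {
  GetAmpDestDtypeSketch("matmul", {"matmul"}, {"softmax"});
}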
@@ -22,14 +22,19 @@ static inline bool NeedCast(const paddle::experimental::Tensor& tensor,
                             const paddle::experimental::DataType& dst_dtype) {
   auto place = tensor.place();
   auto data_type = tensor.dtype();
+  // Except for the CPU check, the conditions here should stay consistent
+  // with the judgment in amp_utils.h
   if (paddle::platform::is_gpu_place(place) ||
       paddle::platform::is_cuda_pinned_place(place) ||
       paddle::platform::is_xpu_place(place) ||
       paddle::platform::is_mlu_place(place) ||
       paddle::platform::is_npu_place(place) ||
       paddle::platform::is_npu_pinned_place(place) ||
-      paddle::platform::is_custom_place(place)) {
+      paddle::platform::is_custom_place(place) ||
+      paddle::platform::is_cpu_place(place)) {
     // CudaPinnedPlace is added for varbase created by dataloader
+    // CPU place is for tensors with mixed placements, e.g. when input1 is
+    // on CPU and input2 is on GPU
     if ((data_type == paddle::experimental::DataType::FLOAT32 ||
          data_type == paddle::experimental::DataType::FLOAT16 ||
          data_type == paddle::experimental::DataType::BFLOAT16) &&
......
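
With is_cpu_place admitted, NeedCast can now answer "yes" for a CPU tensor whose dtype differs from the op's AMP dtype. The hunk above is truncated, so the sketch below assumes the elided tail of the condition requires data_type to differ from dst_dtype; that assumption, plus the simplified Place and DataType stand-ins, is flagged in the comments.

// Sketch of the widened NeedCast check. The diff is truncated, so the final
// `data_type != dst_dtype` comparison is an assumption about the elided part;
// Place/DataType are simplified stand-ins for the real Paddle types.
#include <iostream>

enum class Place { kCPU, kGPU };
enum class DataType { kFLOAT32, kFLOAT16, kBFLOAT16, kINT32 };

bool NeedCastSketch(Place place, DataType data_type, DataType dst_dtype) {
  // CPU is now an accepted place, alongside GPU/XPU/... in the real code.
  bool place_ok = (place == Place::kGPU || place == Place::kCPU);
  bool floating = data_type == DataType::kFLOAT32 ||
                  data_type == DataType::kFLOAT16 ||
                  data_type == DataType::kBFLOAT16;
  return place_ok && floating && data_type != dst_dtype;
}

int main() {
  // A float32 CPU tensor feeding an FP16 op now requests a cast
  // (before this commit the CPU place made NeedCast return false).
  std::cout << std::boolalpha
            << NeedCastSketch(Place::kCPU, DataType::kFLOAT32,
                              DataType::kFLOAT16)
            << "\n";  // prints: true
}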