Unverified commit 3d5faa88, authored by niuliling123, committed by GitHub

Add Cpu tensor cast when amp_type isn't float32 (#50401)

Parent bf38175e
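In short: this change makes the eager-mode AMP dtype decision place-aware. A new helper, GetDtypeWithPlace, demotes the AMP dtype to FP32 when none of the op's input tensors lives on an accelerator place; GetAmpDestDtype is refactored so that every branch flows through that helper; and NeedCast now also accepts CPU tensors, so they can be cast when an op mixes CPU and GPU inputs.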
@@ -85,6 +85,39 @@ static inline paddle::experimental::DataType GetPromoteType(
   return dst_type;
 }
 
+inline paddle::experimental::DataType GetDtypeWithPlace(
+    const std::string& op_name,
+    const paddle::small_vector<std::vector<paddle::experimental::Tensor>,
+                               kSlotSmallVectorSize>& amp_tensors_vector,
+    const paddle::experimental::DataType amp_dtype) {
+  if (amp_dtype == paddle::experimental::DataType::FLOAT32) {
+    return amp_dtype;
+  }
+  bool is_right_place = false;
+  for (const auto& tensors : amp_tensors_vector) {
+    for (const auto& tensor : tensors) {
+      auto place = tensor.place();
+      is_right_place = (paddle::platform::is_gpu_place(place) ||
+                        paddle::platform::is_cuda_pinned_place(place) ||
+                        paddle::platform::is_xpu_place(place) ||
+                        paddle::platform::is_mlu_place(place) ||
+                        paddle::platform::is_npu_place(place) ||
+                        paddle::platform::is_npu_pinned_place(place) ||
+                        paddle::platform::is_custom_place(place));
+      if (is_right_place) {
+        break;
+      }
+    }
+    // Break out of the outer loop as well; otherwise a later all-CPU
+    // tensor slot would overwrite a positive result from an earlier one.
+    if (is_right_place) {
+      break;
+    }
+  }
+  if (!is_right_place) {
+    VLOG(6) << "Change " << op_name << "'s AMP type from " << amp_dtype
+            << " to FP32";
+    return paddle::experimental::DataType::FLOAT32;
+  }
+  return amp_dtype;
+}
 inline paddle::experimental::DataType GetAmpDestDtype(
     const std::string& op_name,
     const paddle::small_vector<std::vector<paddle::experimental::Tensor>,
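To make the fallback rule above concrete, here is a minimal, self-contained sketch of the same logic with simplified stand-in types. The `Place`/`DataType` enums and `GetDtypeWithPlaceSketch` below are illustrative assumptions, not Paddle's real API:

```cpp
#include <iostream>
#include <vector>

// Simplified stand-ins for paddle::platform places and
// paddle::experimental::DataType; illustrative only.
enum class Place { kCPU, kGPU, kXPU, kCustom };
enum class DataType { kFLOAT32, kFLOAT16, kBFLOAT16 };

// Same idea as GetDtypeWithPlace: keep the low-precision AMP dtype only if
// at least one input tensor lives on an accelerator; otherwise fall back
// to float32, where CPU kernel coverage is guaranteed.
DataType GetDtypeWithPlaceSketch(const std::vector<Place>& input_places,
                                 DataType amp_dtype) {
  if (amp_dtype == DataType::kFLOAT32) {
    return amp_dtype;  // float32 is always safe; nothing to decide.
  }
  for (Place p : input_places) {
    if (p != Place::kCPU) {  // The real code enumerates GPU/XPU/MLU/NPU/...
      return amp_dtype;      // Found an accelerator input: keep fp16/bf16.
    }
  }
  return DataType::kFLOAT32;  // All inputs on CPU: demote to float32.
}

int main() {
  bool all_cpu_demoted =
      GetDtypeWithPlaceSketch({Place::kCPU, Place::kCPU},
                              DataType::kFLOAT16) == DataType::kFLOAT32;
  bool mixed_kept =
      GetDtypeWithPlaceSketch({Place::kCPU, Place::kGPU},
                              DataType::kFLOAT16) == DataType::kFLOAT16;
  std::cout << all_cpu_demoted << " " << mixed_kept << "\n";  // prints: 1 1
}
```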
@@ -95,19 +128,21 @@ inline paddle::experimental::DataType GetAmpDestDtype(
   VLOG(6) << "AMP GetAmpDestDtype:"
           << " op(" << op_name << ") amp_dtype(" << amp_dtype << ") amp_level("
           << static_cast<int>(amp_level) << ").";
+  auto return_amp_type = paddle::experimental::DataType::FLOAT16;
+
   if (amp_dtype == "float16") {
     if (amp_level == paddle::imperative::AmpLevel::O1) {
       if (paddle::imperative::AmpOperators::Instance()
               .GetMutableAllowOps()
               ->count(op_name)) {
-        return paddle::experimental::DataType::FLOAT16;
+        return_amp_type = paddle::experimental::DataType::FLOAT16;
       } else if (paddle::imperative::AmpOperators::Instance()
                      .GetMutableBlockOps()
                      ->count(op_name) ||
                  paddle::imperative::AmpOperators::Instance()
                      .GetMutableUnsupportedFp16Ops()
                      ->count(op_name)) {
-        return paddle::experimental::DataType::FLOAT32;
+        return_amp_type = paddle::experimental::DataType::FLOAT32;
       } else {
         auto dst_type = GetPromoteType(op_name,
                                        amp_tensors_vector,
@@ -118,7 +153,7 @@ inline paddle::experimental::DataType GetAmpDestDtype(
                 ->count(op_name)) {
           dst_type = paddle::experimental::DataType::FLOAT32;
         }
-        return dst_type;
+        return_amp_type = dst_type;
       }
     } else if (amp_level == paddle::imperative::AmpLevel::O2) {
       auto dst_type = paddle::experimental::DataType::FLOAT16;
@@ -130,18 +165,18 @@ inline paddle::experimental::DataType GetAmpDestDtype(
               ->count(op_name)) {
         dst_type = paddle::experimental::DataType::FLOAT32;
       }
-      return dst_type;
+      return_amp_type = dst_type;
     }
   } else if (amp_dtype == "bfloat16") {
     if (amp_level == paddle::imperative::AmpLevel::O1) {
       if (paddle::imperative::AmpOperators::Instance()
               .GetMutableAllowOps()
               ->count(op_name)) {
-        return paddle::experimental::DataType::BFLOAT16;
+        return_amp_type = paddle::experimental::DataType::BFLOAT16;
       } else if (paddle::imperative::AmpOperators::Instance()
                      .GetMutableBlockOps()
                      ->count(op_name)) {
-        return paddle::experimental::DataType::FLOAT32;
+        return_amp_type = paddle::experimental::DataType::FLOAT32;
       } else {
         auto dst_type =
             GetPromoteType(op_name,
@@ -153,7 +188,7 @@ inline paddle::experimental::DataType GetAmpDestDtype(
                 ->count(op_name)) {
           dst_type = paddle::experimental::DataType::FLOAT32;
         }
-        return dst_type;
+        return_amp_type = dst_type;
       }
     } else if (amp_level == paddle::imperative::AmpLevel::O2) {
       auto dst_type = paddle::experimental::DataType::BFLOAT16;
@@ -165,10 +200,12 @@ inline paddle::experimental::DataType GetAmpDestDtype(
               ->count(op_name)) {
         dst_type = paddle::experimental::DataType::FLOAT32;
       }
-      return dst_type;
+      return_amp_type = dst_type;
     }
+  } else {
+    return_amp_type = paddle::experimental::DataType::FLOAT32;
   }
-  return paddle::experimental::DataType::FLOAT32;
+  return GetDtypeWithPlace(op_name, amp_tensors_vector, return_amp_type);
 }
 
 }  // namespace egr
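The hunks above also change the control flow of GetAmpDestDtype: every early `return` becomes an assignment to a single local, `return_amp_type`, so that the new place-aware check runs exactly once on every path. A tiny sketch of that refactoring pattern, with hypothetical names standing in for the real branches:

```cpp
#include <iostream>

enum class DataType { kFLOAT32, kFLOAT16 };

// Stands in for GetDtypeWithPlace: a post-processing step that must see the
// result of *every* branch.
DataType PostProcess(DataType d, bool inputs_on_accelerator) {
  return inputs_on_accelerator ? d : DataType::kFLOAT32;
}

// Before the change, each branch returned immediately, so a final check like
// PostProcess could never be applied uniformly:
//   if (allow_fp16) return DataType::kFLOAT16;
//   if (block_fp16) return DataType::kFLOAT32;
// After the change, branches assign into one local and there is a single
// exit point where the post-processing runs.
DataType Decide(bool allow_fp16, bool block_fp16,
                bool inputs_on_accelerator) {
  DataType result = DataType::kFLOAT16;
  if (allow_fp16) {
    result = DataType::kFLOAT16;
  } else if (block_fp16) {
    result = DataType::kFLOAT32;
  }
  return PostProcess(result, inputs_on_accelerator);
}

int main() {
  // Even an op on the fp16 allow list is demoted when no input is on an
  // accelerator.
  std::cout << (Decide(true, false, false) == DataType::kFLOAT32) << "\n";
}
```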
@@ -22,14 +22,19 @@ static inline bool NeedCast(const paddle::experimental::Tensor& tensor,
                             const paddle::experimental::DataType& dst_dtype) {
   auto place = tensor.place();
   auto data_type = tensor.dtype();
+  // Except for the CPU check, these conditions should be kept consistent
+  // with the corresponding checks in amp_utils.h.
   if (paddle::platform::is_gpu_place(place) ||
       paddle::platform::is_cuda_pinned_place(place) ||
       paddle::platform::is_xpu_place(place) ||
       paddle::platform::is_mlu_place(place) ||
       paddle::platform::is_npu_place(place) ||
       paddle::platform::is_npu_pinned_place(place) ||
-      paddle::platform::is_custom_place(place)) {
+      paddle::platform::is_custom_place(place) ||
+      paddle::platform::is_cpu_place(place)) {
     // CudaPinnedPlace is added for varbase created by dataloader
+    // CPU place is accepted for ops whose inputs live on different places,
+    // e.g. when input1 is on the CPU and input2 is on the GPU.
     if ((data_type == paddle::experimental::DataType::FLOAT32 ||
          data_type == paddle::experimental::DataType::FLOAT16 ||
          data_type == paddle::experimental::DataType::BFLOAT16) &&
......
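Taken together with the NeedCast change, a floating-point CPU tensor is no longer skipped by the cast path. A reduced sketch of the predicate, using the same simplified stand-in types as above rather than Paddle's actual signatures:

```cpp
#include <iostream>

// Simplified stand-ins; the real code uses paddle::platform place checks
// and paddle::experimental::DataType.
enum class Place { kCPU, kGPU, kCUDAPinned, kCustom };
enum class DataType { kFLOAT32, kFLOAT16, kBFLOAT16, kINT64 };

// Mirrors the updated NeedCast: cast only tensors that sit on a supported
// place (now including the CPU) and hold a floating-point dtype different
// from the destination dtype.
bool NeedCastSketch(Place place, DataType dtype, DataType dst_dtype) {
  bool supported_place =
      place == Place::kGPU || place == Place::kCUDAPinned ||
      place == Place::kCustom || place == Place::kCPU;  // CPU newly accepted
  bool floating = dtype == DataType::kFLOAT32 ||
                  dtype == DataType::kFLOAT16 ||
                  dtype == DataType::kBFLOAT16;
  return supported_place && floating && dtype != dst_dtype;
}

int main() {
  // A float32 CPU input feeding an op that runs in float16 (because another
  // input is on the GPU) is now cast instead of being silently left alone.
  std::cout << NeedCastSketch(Place::kCPU, DataType::kFLOAT32,
                              DataType::kFLOAT16)
            << "\n";  // prints 1
  // Integer tensors are never cast.
  std::cout << NeedCastSketch(Place::kGPU, DataType::kINT64,
                              DataType::kFLOAT16)
            << "\n";  // prints 0
}
```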