未验证 提交 ca8e21a6 编写于 作者: Z Zhang Ting 提交者: GitHub

polish the amp code (#51020)

上级 871d2d36
...@@ -122,90 +122,42 @@ inline paddle::experimental::DataType GetAmpDestDtype( ...@@ -122,90 +122,42 @@ inline paddle::experimental::DataType GetAmpDestDtype(
const std::string& op_name, const std::string& op_name,
const paddle::small_vector<std::vector<paddle::Tensor>, const paddle::small_vector<std::vector<paddle::Tensor>,
kSlotSmallVectorSize>& amp_tensors_vector) { kSlotSmallVectorSize>& amp_tensors_vector) {
auto amp_dtype =
egr::Controller::Instance().GetCurrentTracer()->GetAmpDtype();
auto amp_level = egr::Controller::Instance().GetAMPLevel(); auto amp_level = egr::Controller::Instance().GetAMPLevel();
VLOG(6) << "AMP GetAmpDestDtype:" auto amp_setting_dtype =
<< " op(" << op_name << ") amp_dtype(" << amp_dtype << ") amp_level(" egr::Controller::Instance().GetCurrentTracer()->GetAmpPhiDtype();
<< static_cast<int>(amp_level) << ")."; auto dst_type = amp_setting_dtype;
auto return_amp_type = paddle::experimental::DataType::FLOAT16;
if (amp_dtype == "float16") {
if (amp_level == paddle::imperative::AmpLevel::O1) { if (amp_level == paddle::imperative::AmpLevel::O1) {
if (paddle::imperative::AmpOperators::Instance() if (paddle::imperative::AmpOperators::Instance()
.GetMutableAllowOps() .GetMutableAllowOps()
->count(op_name)) { ->count(op_name)) {
return_amp_type = paddle::experimental::DataType::FLOAT16; dst_type = amp_setting_dtype;
} else if (paddle::imperative::AmpOperators::Instance() } else if (paddle::imperative::AmpOperators::Instance()
.GetMutableBlockOps()
->count(op_name) ||
paddle::imperative::AmpOperators::Instance()
.GetMutableUnsupportedFp16Ops()
->count(op_name)) {
return_amp_type = paddle::experimental::DataType::FLOAT32;
} else {
auto dst_type = GetPromoteType(op_name,
amp_tensors_vector,
paddle::experimental::DataType::FLOAT16);
if (dst_type == paddle::experimental::DataType::FLOAT16 &&
paddle::imperative::AmpOperators::Instance()
.GetMutableUnsupportedFp16Ops()
->count(op_name)) {
dst_type = paddle::experimental::DataType::FLOAT32;
}
return_amp_type = dst_type;
}
} else if (amp_level == paddle::imperative::AmpLevel::O2) {
auto dst_type = paddle::experimental::DataType::FLOAT16;
if (paddle::imperative::AmpOperators::Instance()
.GetMutableUnsupportedFp16Ops()
->count(op_name) ||
paddle::imperative::AmpOperators::Instance()
.GetMutableBlockOps() .GetMutableBlockOps()
->count(op_name)) { ->count(op_name)) {
dst_type = paddle::experimental::DataType::FLOAT32; dst_type = paddle::experimental::DataType::FLOAT32;
}
return_amp_type = dst_type;
}
} else if (amp_dtype == "bfloat16") {
if (amp_level == paddle::imperative::AmpLevel::O1) {
if (paddle::imperative::AmpOperators::Instance()
.GetMutableAllowOps()
->count(op_name)) {
return_amp_type = paddle::experimental::DataType::BFLOAT16;
} else if (paddle::imperative::AmpOperators::Instance()
.GetMutableBlockOps()
->count(op_name)) {
return_amp_type = paddle::experimental::DataType::FLOAT32;
} else { } else {
auto dst_type = dst_type = GetPromoteType(op_name, amp_tensors_vector, amp_setting_dtype);
GetPromoteType(op_name,
amp_tensors_vector,
paddle::experimental::DataType::BFLOAT16);
if (dst_type == paddle::experimental::DataType::BFLOAT16 &&
paddle::imperative::AmpOperators::Instance()
.GetMutableUnsupportedBf16Ops()
->count(op_name)) {
dst_type = paddle::experimental::DataType::FLOAT32;
}
return_amp_type = dst_type;
} }
} else if (amp_level == paddle::imperative::AmpLevel::O2) { } else if (amp_level == paddle::imperative::AmpLevel::O2) {
auto dst_type = paddle::experimental::DataType::BFLOAT16;
if (paddle::imperative::AmpOperators::Instance() if (paddle::imperative::AmpOperators::Instance()
.GetMutableUnsupportedBf16Ops()
->count(op_name) ||
paddle::imperative::AmpOperators::Instance()
.GetMutableBlockOps() .GetMutableBlockOps()
->count(op_name)) { ->count(op_name)) {
dst_type = paddle::experimental::DataType::FLOAT32; dst_type = paddle::experimental::DataType::FLOAT32;
} }
return_amp_type = dst_type;
} }
} else {
return_amp_type = paddle::experimental::DataType::FLOAT32; if (dst_type == amp_setting_dtype &&
(paddle::imperative::AmpOperators::Instance()
.GetMutableUnsupportedOps(amp_setting_dtype)
->count(op_name))) {
dst_type = paddle::experimental::DataType::FLOAT32;
} }
return GetDtypeWithPlace(op_name, amp_tensors_vector, return_amp_type);
dst_type = GetDtypeWithPlace(op_name, amp_tensors_vector, dst_type);
VLOG(6) << "AMP GetAmpDestDtype:"
<< " op(" << op_name << ") amp_dtype(" << dst_type << ") amp_level("
<< static_cast<int>(amp_level) << ").";
return dst_type;
} }
} // namespace egr } // namespace egr
...@@ -200,6 +200,22 @@ AmpOperators::GetMutableBlockOps() { ...@@ -200,6 +200,22 @@ AmpOperators::GetMutableBlockOps() {
return block_ops_; return block_ops_;
} }
std::shared_ptr<std::unordered_set<std::string>>
AmpOperators::GetMutableUnsupportedOps(
const paddle::experimental::DataType& data_type) {
PADDLE_ENFORCE_EQ(
data_type == paddle::experimental::DataType::FLOAT16 ||
data_type == paddle::experimental::DataType::BFLOAT16,
true,
phi::errors::InvalidArgument(
"The data_type mismatch. It should be FLOAT16 or BFLOAT16."));
if (data_type == paddle::experimental::DataType::FLOAT16) {
return unsupported_fp16_ops_;
} else {
return unsupported_bf16_ops_;
}
}
std::shared_ptr<std::unordered_set<std::string>> std::shared_ptr<std::unordered_set<std::string>>
AmpOperators::GetMutableUnsupportedFp16Ops() { AmpOperators::GetMutableUnsupportedFp16Ops() {
return unsupported_fp16_ops_; return unsupported_fp16_ops_;
......
...@@ -54,6 +54,9 @@ class AmpOperators { ...@@ -54,6 +54,9 @@ class AmpOperators {
std::shared_ptr<std::unordered_set<std::string>> GetMutableBlockOps(); std::shared_ptr<std::unordered_set<std::string>> GetMutableBlockOps();
std::shared_ptr<std::unordered_set<std::string>> GetMutableUnsupportedOps(
const paddle::experimental::DataType& data_type);
std::shared_ptr<std::unordered_set<std::string>> std::shared_ptr<std::unordered_set<std::string>>
GetMutableUnsupportedFp16Ops(); GetMutableUnsupportedFp16Ops();
......
...@@ -184,6 +184,8 @@ class Tracer { ...@@ -184,6 +184,8 @@ class Tracer {
} }
} }
phi::DataType GetAmpPhiDtype() const { return amp_dtype_; }
void DisableLayoutAutoTune() { use_layout_autotune_ = false; } void DisableLayoutAutoTune() { use_layout_autotune_ = false; }
void EnableLayoutAutoTune() { use_layout_autotune_ = true; } void EnableLayoutAutoTune() { use_layout_autotune_ = true; }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册