Unverified · Commit 4ee3815e authored by Zhang Ting, committed by GitHub

[AMP] fix bf16 amp training error (#54571)

Parent 4277f61f
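For context, this is the kind of bf16 AMP training run the fix targets: the diffs below show that the cast exemptions for norm-op parameters were previously applied only under float16, so bfloat16 autocast hit them with the generic cast path. A minimal sketch (the model, sizes, and optimizer are illustrative, not taken from the PR, and bf16 autocast assumes a device/build that supports it):

```python
import paddle

# Tiny model whose BatchNorm parameters (Scale/Bias/Mean/Variance) should
# stay float32 even when activations are autocast to bfloat16.
model = paddle.nn.Sequential(paddle.nn.Linear(4, 4), paddle.nn.BatchNorm1D(4))
opt = paddle.optimizer.SGD(learning_rate=0.1, parameters=model.parameters())

x = paddle.randn([8, 4])
with paddle.amp.auto_cast(dtype='bfloat16', level='O1'):
    loss = model(x).mean()
loss.backward()
opt.step()
```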
@@ -69,15 +69,16 @@ inline paddle::Tensor AmpAutoCast(const std::string& input_name,
   VLOG(6) << "AMP AmpAutoCasts:"
           << " input(" << input_name << ") dst_dtype("
           << phi::DataTypeToString(dst_dtype) << ").";
-  if (dst_dtype == phi::DataType::FLOAT16) {
-    if (op_name == "run_program") {
-      return input;
-    }
   if ((op_name == "batch_norm" || op_name == "layer_norm" ||
        op_name == "sync_batch_norm") &&
       input_name != "X") {
     return input;
   }
+  if (dst_dtype == phi::DataType::FLOAT16) {
+    if (op_name == "run_program") {
+      return input;
+    }
   if ((op_name == "fused_attention" || op_name == "fused_feedforward")) {
     if (input_name == "LnScale" || input_name == "LnBias" ||
         input_name == "Ln2Scale" || input_name == "Ln2Bias" ||
@@ -86,6 +87,7 @@ inline paddle::Tensor AmpAutoCast(const std::string& input_name,
       }
     }
   }
   if (NeedCast(input, dst_dtype)) {
     paddle::framework::AttributeMap cast_attrs = {
         {"in_dtype", paddle::framework::TransToProtoVarType(input.dtype())},
...
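The hunk above moves the norm-op input check out of the FLOAT16-only branch, so it now applies regardless of the destination dtype, bfloat16 included. A hand-written Python mirror of the resulting control flow (a sketch of the logic, not a Paddle API; `cast` is a caller-supplied stand-in for the cast op):

```python
def amp_auto_cast(input, input_name, dst_dtype, op_name, cast):
    # After the fix: parameter-like inputs of norm ops are returned uncast
    # for ANY dst_dtype, so bf16 autocast no longer downcasts them.
    if (op_name in ("batch_norm", "layer_norm", "sync_batch_norm")
            and input_name != "X"):
        return input
    # This exemption remains float16-only, as in the C++ above.
    if dst_dtype == "float16" and op_name == "run_program":
        return input
    if input.dtype == dst_dtype:  # NeedCast(input, dst_dtype) is false
        return input
    return cast(input, dst_dtype)
```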
@@ -26,19 +26,24 @@ static inline phi::DataType GetPromoteType(
         kSlotSmallVectorSize>& amp_tensors_vector,
     const phi::DataType& amp_dtype) {
   auto dst_type = amp_dtype;
-  if (egr::Controller::Instance().GetCurrentTracer()->GetAmpDtype() ==
-      "float16") {
-    if (op_name == "batch_norm" || op_name == "layer_norm" ||
-        op_name == "sync_batch_norm") {
-      if (amp_tensors_vector[0][0].dtype() == phi::DataType::FLOAT32) {
-        dst_type = phi::DataType::FLOAT32;
-      }
-    } else if (op_name == "fused_attention") {
+  // only consider the dtype of input(X).
+  if (op_name == "batch_norm" || op_name == "layer_norm" ||
+      op_name == "sync_batch_norm" ||
+      op_name == "moving_average_abs_max_scale") {
+    if (amp_tensors_vector[0][0].dtype() == phi::DataType::FLOAT32) {
+      dst_type = phi::DataType::FLOAT32;
+    }
+    return dst_type;
+  }
+  if (egr::Controller::Instance().GetCurrentTracer()->GetAmpDtype() ==
+      "float16") {
+    if (op_name == "fused_attention") {
       for (size_t i = 0; i < amp_tensors_vector.size(); i++) {
         if (i != 3 || i != 4 || i != 9 || i != 10) {
           if (amp_tensors_vector[i][0].dtype() == phi::DataType::FLOAT32) {
             dst_type = phi::DataType::FLOAT32;
-            break;
+            return dst_type;
           }
         }
       }
@@ -47,21 +52,13 @@ static inline phi::DataType GetPromoteType(
         if (i != 7 || i != 8 || i != 9 || i != 10) {
           if (amp_tensors_vector[i][0].dtype() == phi::DataType::FLOAT32) {
             dst_type = phi::DataType::FLOAT32;
-            break;
+            return dst_type;
           }
         }
       }
-    } else {
-      for (const auto& tensors : amp_tensors_vector) {
-        for (const auto& tensor : tensors) {
-          if (tensor.dtype() == phi::DataType::FLOAT32) {
-            dst_type = tensor.dtype();
-            break;
-          }
-        }
-      }
     }
   }
-  } else {
   for (const auto& tensors : amp_tensors_vector) {
     for (const auto& tensor : tensors) {
       if (tensor.dtype() == phi::DataType::FLOAT32) {
@@ -70,14 +67,7 @@ static inline phi::DataType GetPromoteType(
         }
       }
     }
-  }
-  // NOTE(juncai): moving_average_abs_max_scale only consider the dtype of
-  // input(X)
-  if (op_name == "moving_average_abs_max_scale") {
-    if (amp_tensors_vector[0][0].dtype() == phi::DataType::FLOAT16) {
-      dst_type = phi::DataType::FLOAT16;
-    }
-  }
   return dst_type;
 }
...
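After this change, GetPromoteType decides the norm-like ops up front, independently of the configured AMP dtype, and otherwise falls through to the generic any-float32-input-promotes rule. A condensed Python mirror of the new decision flow (the fused_attention/fused_feedforward index exemptions are elided; names are illustrative):

```python
import paddle

def get_promote_type(op_name, amp_tensors_vector, amp_dtype):
    # Norm-like ops only consider input X (slot [0][0]); this now applies
    # to bfloat16 as well as float16, and moving_average_abs_max_scale
    # joins the list.
    if op_name in ("batch_norm", "layer_norm", "sync_batch_norm",
                   "moving_average_abs_max_scale"):
        if amp_tensors_vector[0][0].dtype == paddle.float32:
            return paddle.float32
        return amp_dtype
    # Generic promotion: any float32 input promotes the whole op to float32.
    for tensors in amp_tensors_vector:
        for tensor in tensors:
            if tensor.dtype == paddle.float32:
                return paddle.float32
    return amp_dtype
```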
@@ -89,15 +89,16 @@ inline paddle::Tensor EagerAmpAutoCast(const std::string& input_name,
   VLOG(6) << "AMP AmpAutoCasts:"
           << " input(" << egr::EagerUtils::TensorStr(input) << " to dst_dtype("
           << phi::DataTypeToString(dst_dtype) << ").";
-  if (dst_dtype == phi::DataType::FLOAT16) {
-    if (op_name == "run_program") {
-      return input;
-    }
   if ((op_name == "batch_norm" || op_name == "layer_norm" ||
        op_name == "sync_batch_norm") &&
       input_name != "x") {
     return input;
   }
+  if (dst_dtype == phi::DataType::FLOAT16) {
+    if (op_name == "run_program") {
+      return input;
+    }
   if ((op_name == "fused_attention" || op_name == "fused_feedforward")) {
     if (input_name == "LnScale" || input_name == "LnBias" ||
         input_name == "Ln2Scale" || input_name == "Ln2Bias" ||
...
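A quick, illustrative way to observe the intended behavior after the fix (the expected dtypes are inferred from the hunks above and may vary with device and bf16 support):

```python
import paddle

bn = paddle.nn.BatchNorm1D(4)
x = paddle.randn([8, 4])  # float32 input X
with paddle.amp.auto_cast(dtype='bfloat16', level='O1'):
    y = bn(x)
# Norm parameters should remain float32 rather than being cast to bfloat16.
print(bn.weight.dtype)  # expected: paddle.float32
print(y.dtype)          # expected: paddle.float32, since input X is float32
```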