未验证 提交 e93e48ef 编写于 作者: Z Zhang Ting 提交者: GitHub

[AMP] fix bf16 amp training error (#54571) (#54643)

fix bf16 amp training error
cherry pick #54571
上级 6b778b97
......@@ -69,15 +69,16 @@ inline paddle::Tensor AmpAutoCast(const std::string& input_name,
VLOG(6) << "AMP AmpAutoCasts:"
<< " input(" << input_name << ") dst_dtype("
<< phi::DataTypeToString(dst_dtype) << ").";
if ((op_name == "batch_norm" || op_name == "layer_norm" ||
op_name == "sync_batch_norm") &&
input_name != "X") {
return input;
}
if (dst_dtype == phi::DataType::FLOAT16) {
if (op_name == "run_program") {
return input;
}
if ((op_name == "batch_norm" || op_name == "layer_norm" ||
op_name == "sync_batch_norm") &&
input_name != "X") {
return input;
}
if ((op_name == "fused_attention" || op_name == "fused_feedforward")) {
if (input_name == "LnScale" || input_name == "LnBias" ||
input_name == "Ln2Scale" || input_name == "Ln2Bias" ||
......@@ -86,6 +87,7 @@ inline paddle::Tensor AmpAutoCast(const std::string& input_name,
}
}
}
if (NeedCast(input, dst_dtype)) {
paddle::framework::AttributeMap cast_attrs = {
{"in_dtype", paddle::framework::TransToProtoVarType(input.dtype())},
......
......@@ -26,19 +26,24 @@ static inline phi::DataType GetPromoteType(
kSlotSmallVectorSize>& amp_tensors_vector,
const phi::DataType& amp_dtype) {
auto dst_type = amp_dtype;
// only consider the dtype of input(X).
if (op_name == "batch_norm" || op_name == "layer_norm" ||
op_name == "sync_batch_norm" ||
op_name == "moving_average_abs_max_scale") {
if (amp_tensors_vector[0][0].dtype() == phi::DataType::FLOAT32) {
dst_type = phi::DataType::FLOAT32;
}
return dst_type;
}
if (egr::Controller::Instance().GetCurrentTracer()->GetAmpDtype() ==
"float16") {
if (op_name == "batch_norm" || op_name == "layer_norm" ||
op_name == "sync_batch_norm") {
if (amp_tensors_vector[0][0].dtype() == phi::DataType::FLOAT32) {
dst_type = phi::DataType::FLOAT32;
}
} else if (op_name == "fused_attention") {
if (op_name == "fused_attention") {
for (size_t i = 0; i < amp_tensors_vector.size(); i++) {
if (i != 3 || i != 4 || i != 9 || i != 10) {
if (amp_tensors_vector[i][0].dtype() == phi::DataType::FLOAT32) {
dst_type = phi::DataType::FLOAT32;
break;
return dst_type;
}
}
}
......@@ -47,37 +52,22 @@ static inline phi::DataType GetPromoteType(
if (i != 7 || i != 8 || i != 9 || i != 10) {
if (amp_tensors_vector[i][0].dtype() == phi::DataType::FLOAT32) {
dst_type = phi::DataType::FLOAT32;
break;
}
}
}
} else {
for (const auto& tensors : amp_tensors_vector) {
for (const auto& tensor : tensors) {
if (tensor.dtype() == phi::DataType::FLOAT32) {
dst_type = tensor.dtype();
break;
return dst_type;
}
}
}
}
} else {
for (const auto& tensors : amp_tensors_vector) {
for (const auto& tensor : tensors) {
if (tensor.dtype() == phi::DataType::FLOAT32) {
dst_type = tensor.dtype();
break;
}
}
}
}
// NOTE(juncai): moving_average_abs_max_scale only consider the dtype of
// input(X)
if (op_name == "moving_average_abs_max_scale") {
if (amp_tensors_vector[0][0].dtype() == phi::DataType::FLOAT16) {
dst_type = phi::DataType::FLOAT16;
for (const auto& tensors : amp_tensors_vector) {
for (const auto& tensor : tensors) {
if (tensor.dtype() == phi::DataType::FLOAT32) {
dst_type = tensor.dtype();
break;
}
}
}
return dst_type;
}
......
......@@ -89,15 +89,16 @@ inline paddle::Tensor EagerAmpAutoCast(const std::string& input_name,
VLOG(6) << "AMP AmpAutoCasts:"
<< " input(" << egr::EagerUtils::TensorStr(input) << " to dst_dtype("
<< phi::DataTypeToString(dst_dtype) << ").";
if ((op_name == "batch_norm" || op_name == "layer_norm" ||
op_name == "sync_batch_norm") &&
input_name != "x") {
return input;
}
if (dst_dtype == phi::DataType::FLOAT16) {
if (op_name == "run_program") {
return input;
}
if ((op_name == "batch_norm" || op_name == "layer_norm" ||
op_name == "sync_batch_norm") &&
input_name != "x") {
return input;
}
if ((op_name == "fused_attention" || op_name == "fused_feedforward")) {
if (input_name == "LnScale" || input_name == "LnBias" ||
input_name == "Ln2Scale" || input_name == "Ln2Bias" ||
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册