diff --git a/paddle/fluid/eager/amp_auto_cast.h b/paddle/fluid/eager/amp_auto_cast.h
index 3f02d68b2aa2e911c7b35cde834136645f90c451..c9cf3e2ee282398d75cf25b7e70825c2d398195f 100644
--- a/paddle/fluid/eager/amp_auto_cast.h
+++ b/paddle/fluid/eager/amp_auto_cast.h
@@ -69,15 +69,16 @@ inline paddle::Tensor AmpAutoCast(const std::string& input_name,
   VLOG(6) << "AMP AmpAutoCasts:"
           << " input(" << input_name << ") dst_dtype("
           << phi::DataTypeToString(dst_dtype) << ").";
+
+  if ((op_name == "batch_norm" || op_name == "layer_norm" ||
+       op_name == "sync_batch_norm") &&
+      input_name != "X") {
+    return input;
+  }
   if (dst_dtype == phi::DataType::FLOAT16) {
     if (op_name == "run_program") {
       return input;
     }
-    if ((op_name == "batch_norm" || op_name == "layer_norm" ||
-         op_name == "sync_batch_norm") &&
-        input_name != "X") {
-      return input;
-    }
     if ((op_name == "fused_attention" || op_name == "fused_feedforward")) {
       if (input_name == "LnScale" || input_name == "LnBias" ||
           input_name == "Ln2Scale" || input_name == "Ln2Bias" ||
@@ -86,6 +87,7 @@ inline paddle::Tensor AmpAutoCast(const std::string& input_name,
       }
     }
   }
+
   if (NeedCast(input, dst_dtype)) {
     paddle::framework::AttributeMap cast_attrs = {
         {"in_dtype", paddle::framework::TransToProtoVarType(input.dtype())},
diff --git a/paddle/fluid/eager/amp_utils.h b/paddle/fluid/eager/amp_utils.h
index bfa58512eb238b97be757df338c9564761581938..d77293444e6f5021eb24bb1d3f0cb092ba315d6a 100644
--- a/paddle/fluid/eager/amp_utils.h
+++ b/paddle/fluid/eager/amp_utils.h
@@ -26,19 +26,24 @@ static inline phi::DataType GetPromoteType(
         kSlotSmallVectorSize>& amp_tensors_vector,
     const phi::DataType& amp_dtype) {
   auto dst_type = amp_dtype;
+  // only consider the dtype of input(X).
+  if (op_name == "batch_norm" || op_name == "layer_norm" ||
+      op_name == "sync_batch_norm" ||
+      op_name == "moving_average_abs_max_scale") {
+    if (amp_tensors_vector[0][0].dtype() == phi::DataType::FLOAT32) {
+      dst_type = phi::DataType::FLOAT32;
+    }
+    return dst_type;
+  }
+
   if (egr::Controller::Instance().GetCurrentTracer()->GetAmpDtype() ==
       "float16") {
-    if (op_name == "batch_norm" || op_name == "layer_norm" ||
-        op_name == "sync_batch_norm") {
-      if (amp_tensors_vector[0][0].dtype() == phi::DataType::FLOAT32) {
-        dst_type = phi::DataType::FLOAT32;
-      }
-    } else if (op_name == "fused_attention") {
+    if (op_name == "fused_attention") {
       for (size_t i = 0; i < amp_tensors_vector.size(); i++) {
         if (i != 3 || i != 4 || i != 9 || i != 10) {
           if (amp_tensors_vector[i][0].dtype() == phi::DataType::FLOAT32) {
             dst_type = phi::DataType::FLOAT32;
-            break;
+            return dst_type;
           }
         }
       }
@@ -47,37 +52,22 @@ static inline phi::DataType GetPromoteType(
         if (i != 7 || i != 8 || i != 9 || i != 10) {
           if (amp_tensors_vector[i][0].dtype() == phi::DataType::FLOAT32) {
             dst_type = phi::DataType::FLOAT32;
-            break;
-          }
-        }
-      }
-    } else {
-      for (const auto& tensors : amp_tensors_vector) {
-        for (const auto& tensor : tensors) {
-          if (tensor.dtype() == phi::DataType::FLOAT32) {
-            dst_type = tensor.dtype();
-            break;
+            return dst_type;
           }
         }
       }
     }
-  } else {
-    for (const auto& tensors : amp_tensors_vector) {
-      for (const auto& tensor : tensors) {
-        if (tensor.dtype() == phi::DataType::FLOAT32) {
-          dst_type = tensor.dtype();
-          break;
-        }
-      }
-    }
   }
-  // NOTE(juncai): moving_average_abs_max_scale only consider the dtype of
-  // input(X)
-  if (op_name == "moving_average_abs_max_scale") {
-    if (amp_tensors_vector[0][0].dtype() == phi::DataType::FLOAT16) {
-      dst_type = phi::DataType::FLOAT16;
+
+  for (const auto& tensors : amp_tensors_vector) {
+    for (const auto& tensor : tensors) {
+      if (tensor.dtype() == phi::DataType::FLOAT32) {
+        dst_type = tensor.dtype();
+        break;
+      }
     }
   }
+
   return dst_type;
 }
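Note on the amp_utils.h hunks above: GetPromoteType now short-circuits for batch_norm, layer_norm, sync_batch_norm, and moving_average_abs_max_scale, deciding promotion solely from input X before the generic any-FP32-input scan, and that rule now applies for any AMP dtype rather than only under float16. Below is a minimal standalone sketch of that decision rule, not the Paddle code itself: the DataType enum, the function name, and the signature are stand-ins for phi::DataType and the real GetPromoteType.

```cpp
#include <iostream>
#include <string>
#include <vector>

// Stand-in for phi::DataType; only the two dtypes the rule distinguishes.
enum class DataType { FLOAT16, FLOAT32 };

// Sketch of the promotion rule after this patch. `inputs` mirrors
// amp_tensors_vector: one dtype list per input slot, slot 0 being X.
DataType GetPromoteTypeSketch(const std::string& op_name,
                              const std::vector<std::vector<DataType>>& inputs,
                              DataType amp_dtype) {
  // Norm-like ops (and moving_average_abs_max_scale) decide purely on
  // input X: promote to FP32 only when X itself is FP32.
  if (op_name == "batch_norm" || op_name == "layer_norm" ||
      op_name == "sync_batch_norm" ||
      op_name == "moving_average_abs_max_scale") {
    return inputs[0][0] == DataType::FLOAT32 ? DataType::FLOAT32 : amp_dtype;
  }
  // Generic fallback: any FP32 input promotes the whole op to FP32.
  for (const auto& slot : inputs) {
    for (DataType t : slot) {
      if (t == DataType::FLOAT32) return DataType::FLOAT32;
    }
  }
  return amp_dtype;
}

int main() {
  // batch_norm with FP16 X but FP32 Scale/Bias now stays FP16, where the
  // generic scan would have promoted it to FP32.
  std::vector<std::vector<DataType>> inputs = {
      {DataType::FLOAT16},   // X
      {DataType::FLOAT32},   // Scale
      {DataType::FLOAT32}};  // Bias
  DataType dst =
      GetPromoteTypeSketch("batch_norm", inputs, DataType::FLOAT16);
  std::cout << (dst == DataType::FLOAT16 ? "FLOAT16" : "FLOAT32") << "\n";
}
```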
diff --git a/paddle/fluid/eager/eager_amp_auto_cast.h b/paddle/fluid/eager/eager_amp_auto_cast.h
index 8753a95069fc542a6841cde3adc14885030dd50f..a612a84d2ae1c673a25b8e70677836b918a61389 100644
--- a/paddle/fluid/eager/eager_amp_auto_cast.h
+++ b/paddle/fluid/eager/eager_amp_auto_cast.h
@@ -89,15 +89,16 @@ inline paddle::Tensor EagerAmpAutoCast(const std::string& input_name,
   VLOG(6) << "AMP AmpAutoCasts:"
           << " input(" << egr::EagerUtils::TensorStr(input) << " to dst_dtype("
           << phi::DataTypeToString(dst_dtype) << ").";
+  if ((op_name == "batch_norm" || op_name == "layer_norm" ||
+       op_name == "sync_batch_norm") &&
+      input_name != "x") {
+    return input;
+  }
+
   if (dst_dtype == phi::DataType::FLOAT16) {
     if (op_name == "run_program") {
       return input;
     }
-    if ((op_name == "batch_norm" || op_name == "layer_norm" ||
-         op_name == "sync_batch_norm") &&
-        input_name != "x") {
-      return input;
-    }
     if ((op_name == "fused_attention" || op_name == "fused_feedforward")) {
       if (input_name == "LnScale" || input_name == "LnBias" ||
           input_name == "Ln2Scale" || input_name == "Ln2Bias" ||