diff --git a/paddle/fluid/framework/ir/auto_mixed_precision_pass.cc b/paddle/fluid/framework/ir/auto_mixed_precision_pass.cc index 44b41a89700de9fcbe23c036b52562b65de4d2a8..3d66ed788c6a944dfb2afd083513349a11d89c48 100644 --- a/paddle/fluid/framework/ir/auto_mixed_precision_pass.cc +++ b/paddle/fluid/framework/ir/auto_mixed_precision_pass.cc @@ -78,11 +78,11 @@ inline bool VarNodeHasDtype(Node* var_node) { (type == VarType::VOCAB); } -inline bool IsFloatType(VarType::Type type) { +inline bool IsFP32AndFP64(VarType::Type type) { return (type == VarType::FP64) || (type == VarType::FP32); } -inline bool IsHalfType(VarType::Type type) { +inline bool IsFP16AndBFP16(VarType::Type type) { return (type == VarType::FP16) || (type == VarType::BF16); } @@ -159,26 +159,16 @@ bool OpSupportPrecision(const std::string& op_type, // The set of ops that support fp16 calculation and are considered // numerically-dangerous, slower and whose effects may also be observed in // downstream ops. +// ref to python/paddle/fluid/contrib/mixed_precision/fp16_lists.py void AutoMixedPrecisionPass::SetDefaultBlacklist() const { black_list_.insert({ // numerically-dangerous - "acos", - "asin", - "cosh", - "tan", "exp", - "expm1", "square", "log", - "log2", - "log10", - "log1p", - "logsumexp", "mean", - "rsqrt", "sum", "cos_sim", - "softmax", "softmax_with_cross_entropy", "sigmoid_cross_entropy_with_logits", "c_softmax_with_cross_entropy", @@ -272,6 +262,9 @@ void AutoMixedPrecisionPass::ApplyImpl(Graph* graph) const { VLOG(4) << "InsertCastOp done"; RestoreOpOriginType(); VLOG(4) << "RestoreOpOriginType done"; + LOG(INFO) << "The number of ops run at low precision [" + << op_run_low_precision_.size() << "/" << op_original_type_.size() + << "]"; } void AutoMixedPrecisionPass::SetOpUniqueType() const { @@ -315,11 +308,26 @@ void AutoMixedPrecisionPass::ProcessOpWithDtypeAttr() const { for (const auto& nodes : all_op_nodes_) { for (auto* op_node : nodes) { auto op_type = op_node->Op()->Type(); + + if (op_node->Op()->HasAttr("in_dtype")) { + auto* var_node = op_node->inputs[0]; + auto* real_var_node = real_vars_[var_node->Var()->Name()]; + if (IsFP16AndBFP16(real_var_node->Var()->GetDataType())) { + op_node->Op()->SetAttr( + "in_dtype", + static_cast(framework::TransToProtoVarType(low_precision_))); + op_node->Op()->Flush(); + VLOG(4) << "process op with in_dtype attr: " << op_type << " ( " + << static_cast(real_var_node->Var()->GetDataType()) + << " --->" << static_cast(low_precision_) << " )"; + } + } + if (op_run_low_precision_.count(op_type) == 0) continue; if (op_node->Op()->HasAttr("dtype")) { auto dtype = op_node->Op()->GetAttrIfExists("dtype"); - if (IsFloatType(static_cast(dtype))) { + if (IsFP32AndFP64(static_cast(dtype))) { op_node->Op()->SetAttr( "dtype", static_cast(framework::TransToProtoVarType(low_precision_))); @@ -327,10 +335,9 @@ void AutoMixedPrecisionPass::ProcessOpWithDtypeAttr() const { VLOG(4) << "process op with dtype attr: " << op_type << " ( " << dtype << " --->" << static_cast(low_precision_) << " )"; } - } - if (op_node->Op()->HasAttr("out_dtype")) { + } else if (op_node->Op()->HasAttr("out_dtype")) { auto out_dtype = op_node->Op()->GetAttrIfExists("out_dtype"); - if (IsFloatType(static_cast(out_dtype))) { + if (IsFP32AndFP64(static_cast(out_dtype))) { op_node->Op()->SetAttr( "out_dtype", static_cast(framework::TransToProtoVarType(low_precision_))); @@ -359,37 +366,55 @@ void AutoMixedPrecisionPass::GetOpPrecision() const { if (op_node->Op()->HasAttr("dtype")) { auto dtype = op_node->Op()->GetAttrIfExists("dtype"); - support_low_precision = support_low_precision && - IsFloatType(static_cast(dtype)); + support_low_precision = + support_low_precision && + IsFP32AndFP64(static_cast(dtype)); } else if (op_node->Op()->HasAttr("out_dtype")) { auto out_dtype = op_node->Op()->GetAttrIfExists("out_dtype"); support_low_precision = support_low_precision && - IsFloatType(static_cast(out_dtype)); - } else { - // if op's input var and output var is not dense tensor, the op should - // not run at low precision. - for (auto* in_var_node : op_node->inputs) { - CHECK_EQ(in_var_node->IsVar(), true); - auto* real_in_var_node = real_vars_[in_var_node->Var()->Name()]; - if (real_in_var_node->Var()->Persistable()) continue; + IsFP32AndFP64(static_cast(out_dtype)); + } + // If scale op's "scale" and "bias" attr value exceed the range of fp16 + // and bf16, it cannot run at low precision. + if (GetOpOriginalType(op_node->Op()->Type()) == "scale") { + auto scale = op_node->Op()->GetAttrIfExists("scale"); + auto bias = op_node->Op()->GetAttrIfExists("bias"); + if (low_precision_ == phi::DataType::FLOAT16) { support_low_precision = support_low_precision && - (real_in_var_node->Var()->GetType() == VarType::LOD_TENSOR); - } - - for (auto* out_var_node : op_node->outputs) { - CHECK_EQ(out_var_node->IsVar(), true); - auto* real_out_var_node = real_vars_[out_var_node->Var()->Name()]; - if (real_out_var_node->Var()->Persistable()) continue; - + phi::dtype::isfinite(static_cast(scale)) && + phi::dtype::isfinite(static_cast(bias)); + } else if (low_precision_ == phi::DataType::BFLOAT16) { support_low_precision = support_low_precision && - (real_out_var_node->Var()->GetType() == VarType::LOD_TENSOR); + phi::dtype::isfinite(static_cast(scale)) && + phi::dtype::isfinite(static_cast(bias)); } } + // if op's input var and output var is not dense tensor, the op should + // not run at low precision. + for (auto* in_var_node : op_node->inputs) { + CHECK_EQ(in_var_node->IsVar(), true); + auto* real_in_var_node = real_vars_[in_var_node->Var()->Name()]; + if (real_in_var_node->Var()->Persistable()) continue; + + support_low_precision = + support_low_precision && + (real_in_var_node->Var()->GetType() == VarType::LOD_TENSOR); + } + for (auto* out_var_node : op_node->outputs) { + CHECK_EQ(out_var_node->IsVar(), true); + auto* real_out_var_node = real_vars_[out_var_node->Var()->Name()]; + if (real_out_var_node->Var()->Persistable()) continue; + + support_low_precision = + support_low_precision && + (real_out_var_node->Var()->GetType() == VarType::LOD_TENSOR); + } + if (support_low_precision) { op_run_low_precision_.insert(op_type); VLOG(4) << "support precision: " << op_type << " run at low precision"; @@ -439,7 +464,7 @@ void AutoMixedPrecisionPass::UpdateOpPrecision() const { } // when op_1 only support cpu kernel. if op_2's intput var is op_1's - // output var, then op_2 should not run half. + // output var, then op_2 should not run at low precision. if (GetOpOriginalType(op_type) != "feed" && !GpuKernelSupportPrecision(GetOpOriginalType(op_type), phi::DataType::FLOAT32)) { @@ -597,7 +622,7 @@ void AutoMixedPrecisionPass::SetVarPrecision() const { auto* real_in_var_node = real_vars_[in_var_node->Var()->Name()]; auto in_var_name = real_in_var_node->Var()->Name(); - if (!IsFloatType(real_in_var_node->Var()->GetDataType())) continue; + if (!IsFP32AndFP64(real_in_var_node->Var()->GetDataType())) continue; if (!VarNodeHasDtype(real_in_var_node)) continue; if (InputVarsNotConvert(op_node, in_var_name)) continue; @@ -616,7 +641,7 @@ void AutoMixedPrecisionPass::SetVarPrecision() const { auto* real_out_var_node = real_vars_[out_var_node->Var()->Name()]; auto out_var_name = real_out_var_node->Var()->Name(); - if (!IsFloatType(real_out_var_node->Var()->GetDataType())) continue; + if (!IsFP32AndFP64(real_out_var_node->Var()->GetDataType())) continue; if (!VarNodeHasDtype(real_out_var_node)) continue; if (OutputVarsNotConvert(op_node, out_var_name)) continue; @@ -656,7 +681,7 @@ void AutoMixedPrecisionPass::ConvertWeightsData() const { auto var_names = scope->LocalVarNames(); for (const auto& var_name : var_names) { if (vars_convert_to_low_precision_.count(var_name)) { - VLOG(4) << var_name << "'s data type was convert to half"; + VLOG(4) << var_name << "'s data type was convert to low precision"; auto* var = scope->FindLocalVar(var_name); CHECK_EQ(var->IsType(), true); @@ -683,16 +708,18 @@ void AutoMixedPrecisionPass::ConvertWeightsData() const { } } } else if (low_precision_ == phi::DataType::BFLOAT16) { - auto* half_data = + auto* low_precision_data = low_precision_tensor.mutable_data( phi::CPUPlace{}); for (int64_t i = 0; i < origin_tensor->numel(); i++) { if (origin_tensor->dtype() == phi::DataType::FLOAT64) { auto* origin_data = origin_tensor->data(); - half_data[i] = static_cast(origin_data[i]); + low_precision_data[i] = + static_cast(origin_data[i]); } else if (origin_tensor->dtype() == phi::DataType::FLOAT32) { auto* origin_data = origin_tensor->data(); - half_data[i] = static_cast(origin_data[i]); + low_precision_data[i] = + static_cast(origin_data[i]); } } } @@ -732,25 +759,44 @@ void AutoMixedPrecisionPass::InsertCastOp() const { VLOG(4) << "process var: " << real_in_var_node->Var()->Name() << " with type " << in_var_type; - if (IsFloatType(in_var_type) && op_run_low_precision_.count(op_type)) { - DoInsertCastOp(subgraphes_[i], - in_var_node, - op_node, - in_var_type, - framework::TransToProtoVarType(low_precision_), - block_desc, - &suffix, - &cache); - } else if (IsHalfType(in_var_type) && + if (IsFP32AndFP64(in_var_type) && + op_run_low_precision_.count(op_type)) { + auto to_type = framework::TransToProtoVarType(low_precision_); + auto* prev_op = + in_var_node->inputs.empty() ? nullptr : in_var_node->inputs[0]; + if (prev_op && GetOpOriginalType(prev_op->Op()->Type()) == "cast") { + in_var_node->Var()->SetDataType(to_type); + prev_op->Op()->SetAttr("out_dtype", static_cast(to_type)); + prev_op->Op()->Flush(); + } else { + DoInsertCastOp(subgraphes_[i], + in_var_node, + op_node, + in_var_type, + to_type, + block_desc, + &suffix, + &cache); + } + } else if (IsFP16AndBFP16(in_var_type) && op_run_low_precision_.count(op_type) == 0) { - DoInsertCastOp(subgraphes_[i], - in_var_node, - op_node, - in_var_type, - VarType::FP32, - block_desc, - &suffix, - &cache); + auto to_type = VarType::FP32; + auto* prev_op = + in_var_node->inputs.empty() ? nullptr : in_var_node->inputs[0]; + if (prev_op && GetOpOriginalType(prev_op->Op()->Type()) == "cast") { + in_var_node->Var()->SetDataType(to_type); + prev_op->Op()->SetAttr("out_dtype", static_cast(to_type)); + prev_op->Op()->Flush(); + } else { + DoInsertCastOp(subgraphes_[i], + in_var_node, + op_node, + in_var_type, + to_type, + block_desc, + &suffix, + &cache); + } } } diff --git a/paddle/fluid/inference/tests/api/gpu_ernie_half_test.cc b/paddle/fluid/inference/tests/api/gpu_ernie_half_test.cc index 6354ee47a18f689a41a3c28d542cb8861c2fc1a0..6b83e89a4447d35a035681f8855c6c6eb229cb88 100644 --- a/paddle/fluid/inference/tests/api/gpu_ernie_half_test.cc +++ b/paddle/fluid/inference/tests/api/gpu_ernie_half_test.cc @@ -164,7 +164,7 @@ TEST(Ernie_gpu_fp16_no_ir, compare_results) { } float *result = reinterpret_cast(output.data.data()); for (size_t j = 0; j < outputs_size; ++j) { - EXPECT_NEAR(ref[i * outputs_size + j], result[j], 5e-2); + EXPECT_NEAR(ref[i * outputs_size + j], result[j], 8e-3); } } } @@ -175,8 +175,6 @@ TEST(Ernie_gpu_fp16_with_ir, compare_results) { config.SetModel(FLAGS_infer_model); config.EnableUseGpu(512, 0, paddle_infer::PrecisionType::kHalf); config.SwitchIrOptim(true); - // The fc_fuse_pass has diff, which will be repaired later. - config.pass_builder()->DeletePass("fc_fuse_pass"); // There is a problem with the model itself, which has nothing to do with // constant_folding_pass. config.pass_builder()->DeletePass("constant_folding_pass"); @@ -206,7 +204,7 @@ TEST(Ernie_gpu_fp16_with_ir, compare_results) { } float *result = reinterpret_cast(output.data.data()); for (size_t j = 0; j < outputs_size; ++j) { - EXPECT_NEAR(ref[i * outputs_size + j], result[j], 5e-2); + EXPECT_NEAR(ref[i * outputs_size + j], result[j], 2e-2); } } } @@ -243,7 +241,7 @@ TEST(Ernie_gpu_bf16_no_ir, compare_results) { } float *result = reinterpret_cast(output.data.data()); for (size_t j = 0; j < outputs_size; ++j) { - EXPECT_NEAR(ref[i * outputs_size + j], result[j], 7e-2); + EXPECT_NEAR(ref[i * outputs_size + j], result[j], 1e-2); } } } @@ -254,8 +252,6 @@ TEST(Ernie_gpu_bf16_with_ir, compare_results) { config.SetModel(FLAGS_infer_model); config.EnableUseGpu(512, 0, paddle_infer::PrecisionType::kBf16); config.SwitchIrOptim(true); - // The fc_fuse_pass has diff, which will be repaired later. - config.pass_builder()->DeletePass("fc_fuse_pass"); // There is a problem with the model itself, which has nothing to do with // constant_folding_pass. config.pass_builder()->DeletePass("constant_folding_pass"); @@ -285,7 +281,7 @@ TEST(Ernie_gpu_bf16_with_ir, compare_results) { } float *result = reinterpret_cast(output.data.data()); for (size_t j = 0; j < outputs_size; ++j) { - EXPECT_NEAR(ref[i * outputs_size + j], result[j], 7e-2); + EXPECT_NEAR(ref[i * outputs_size + j], result[j], 5e-3); } } } diff --git a/paddle/fluid/operators/fused/fused_fc_elementwise_layernorm_op.cu b/paddle/fluid/operators/fused/fused_fc_elementwise_layernorm_op.cu index 758fb8a23f8f92145d987fd39ffd45f64a8dad27..87811b61306d9eaf565c6e39e8a5dfa94163b49b 100644 --- a/paddle/fluid/operators/fused/fused_fc_elementwise_layernorm_op.cu +++ b/paddle/fluid/operators/fused/fused_fc_elementwise_layernorm_op.cu @@ -223,13 +223,7 @@ __global__ void InplaceAddReluAddLayerNormKernel(const float16* y_data, // For layer_norm, reduce to calculate mean and std sum_i += static_cast(tmp_3); -#if defined(PADDLE_WITH_CUDA) && __CUDA_ARCH__ >= 530 - square_sum_i += static_cast(__hmul(tmp_3, tmp_3)); -#elif defined(PADDLE_WITH_CUDA) square_sum_i += static_cast(tmp_3) * static_cast(tmp_3); -#else - square_sum_i += static_cast(tmp_3 * tmp_3); -#endif } auto pair = BlockReduce(temp_storage) .Reduce(PairForLayerNorm(sum_i, square_sum_i), @@ -282,9 +276,9 @@ __global__ void InplaceAddReluAddLayerNormKernel(const float16* y_data, half tmp_0 = __hdiv(__hsub(save_ptr[save_index], mean_i), std_i); half tmp_1 = scale ? __hmul(scale[j], tmp_0) : tmp_0; #else - half tmp_0 = static_cast(static_cast(save_ptr[save_index]) + - static_cast(mean_i) / - static_cast(std_i)); + half tmp_0 = static_cast(static_cast(save_ptr[save_index]) - + static_cast(mean_i) / + static_cast(std_i)); half tmp_1 = scale ? static_cast(static_cast(scale[j]) * static_cast(tmp_0)) : tmp_0; diff --git a/paddle/phi/kernels/funcs/fc_functor.cu b/paddle/phi/kernels/funcs/fc_functor.cu index 6015266dde9e7af0ca8523ec97e661efb507317b..a7d4535d6df1adc603790a27ecb4f91ea479cd93 100644 --- a/paddle/phi/kernels/funcs/fc_functor.cu +++ b/paddle/phi/kernels/funcs/fc_functor.cu @@ -149,7 +149,7 @@ __global__ void bias_relu_v2(const int num, #if __CUDA_ARCH__ >= 800 packed_val = __hmax2(__half2(0, 0), packed_val); #elif __CUDA_ARCH__ >= 530 - packed_val = __hmul2(__hgt2(__half2(0, 0), packed_val), packed_val); + packed_val = __hmul2(__hgt2(packed_val, __half2(0, 0)), packed_val); #else packed_val.x = static_cast(static_cast(packed_val.x) > 0) * static_cast(packed_val.x);