From c80d01bfe88828c3c7ac0304f89f9ce04de12de2 Mon Sep 17 00:00:00 2001
From: Wilber
Date: Thu, 8 Sep 2022 23:02:21 +0800
Subject: [PATCH] fix select fp16 kernel. (#45882)

---
 .../passes/convert_to_mixed_precision.cc | 20 +++++++++++++++++++
 1 file changed, 20 insertions(+)

diff --git a/paddle/fluid/inference/analysis/passes/convert_to_mixed_precision.cc b/paddle/fluid/inference/analysis/passes/convert_to_mixed_precision.cc
index 87750d713c6..b49ad4c145d 100644
--- a/paddle/fluid/inference/analysis/passes/convert_to_mixed_precision.cc
+++ b/paddle/fluid/inference/analysis/passes/convert_to_mixed_precision.cc
@@ -677,6 +677,26 @@ void ConvertTensorDtype(
         OpSupportPrecision(op_type, backend, tensor_dtype, blacklist);
     VLOG(2) << " support low precision " << support_precision;
 
+    // If the op has no float input, we will not choose the low precision
+    // kernel.
+    {
+      bool has_float_input{false};
+      for (auto in_node : op_node->inputs) {
+        auto* real_node =
+            GetRealNode(graphes, block_idx, in_node, vars_in_multi_block_map);
+        if (real_node->Var()->GetDataType() == proto::VarType::FP16 ||
+            real_node->Var()->GetDataType() == proto::VarType::FP32 ||
+            real_node->Var()->GetDataType() == proto::VarType::FP64 ||
+            real_node->Var()->GetDataType() == proto::VarType::BF16) {
+          has_float_input = true;
+          break;
+        }
+      }
+      if (!has_float_input) {
+        support_precision = false;
+        VLOG(2) << " op doesn't have float input, just skip.";
+      }
+    }
+
     if (support_precision) {
       HandleSpecialOps(op_node->Op());
       ++num_low_precision;
--
GitLab