diff --git a/paddle/fluid/inference/analysis/passes/convert_to_mixed_precision.cc b/paddle/fluid/inference/analysis/passes/convert_to_mixed_precision.cc index 87750d713c6d459c5e62f6a85687da0b4d07a7e0..b49ad4c145d5552cfb7327a953eafac97c02ffbd 100644 --- a/paddle/fluid/inference/analysis/passes/convert_to_mixed_precision.cc +++ b/paddle/fluid/inference/analysis/passes/convert_to_mixed_precision.cc @@ -677,6 +677,26 @@ void ConvertTensorDtype( OpSupportPrecision(op_type, backend, tensor_dtype, blacklist); VLOG(2) << " support low precision " << support_precision; + // If the op has no float input, we will not choose the low precision kernel. + { + bool has_float_input{false}; + for (auto in_node : op_node->inputs) { + auto* real_node = + GetRealNode(graphes, block_idx, in_node, vars_in_multi_block_map); + if (real_node->Var()->GetDataType() == proto::VarType::FP16 || + real_node->Var()->GetDataType() == proto::VarType::FP32 || + real_node->Var()->GetDataType() == proto::VarType::FP64 || + real_node->Var()->GetDataType() == proto::VarType::BF16) { + has_float_input = true; + break; + } + } + if (!has_float_input) { + support_precision = false; + VLOG(2) << " op doesn't has float input, just skip."; + } + } + if (support_precision) { HandleSpecialOps(op_node->Op()); ++num_low_precision;