From 624ffdf2da35b5ecab5340bb90a63960bf452680 Mon Sep 17 00:00:00 2001
From: Yuanle Liu
Date: Mon, 7 Nov 2022 12:54:22 +0800
Subject: [PATCH] [Paddle inference] fix mixed precision (#47654)

---
 .../passes/convert_to_mixed_precision.cc      | 33 +++++++++----------
 1 file changed, 16 insertions(+), 17 deletions(-)

diff --git a/paddle/fluid/inference/analysis/passes/convert_to_mixed_precision.cc b/paddle/fluid/inference/analysis/passes/convert_to_mixed_precision.cc
index 9d0e6ecf49a..1f7b83504d0 100644
--- a/paddle/fluid/inference/analysis/passes/convert_to_mixed_precision.cc
+++ b/paddle/fluid/inference/analysis/passes/convert_to_mixed_precision.cc
@@ -110,7 +110,14 @@ class ConvertToMixedPrecisionPass {
         keep_io_types_(keep_io_types),
         black_list_(black_list),
         place_(paddle::CPUPlace()),
-        executor_(place_) {}
+        executor_(place_) {
+    black_list_.insert("assign");
+    black_list_.insert("fill_constant");
+    black_list_.insert("assign_value");
+    black_list_.insert("eye");
+    black_list_.insert("fill_any_like");
+    black_list_.insert("fill_constant_batch_size_like");
+  }
 
   void Run();
 
@@ -587,10 +594,10 @@ void ConvertToMixedPrecisionPass::ConvertTensorDtype(BlockID block_idx) {
     bool support_precision =
         OpSupportPrecision(op_type, backend_, mixed_precision_, black_list_);
 
-    // If the op has no input and output of float type, we will not choose the
+    // If the op has no input of float type, we will not choose the
     // low precision kernel.
     {
-      bool has_float_input_and_output{false};
+      bool has_float_input{false};
       for (auto* in_node : op_node->inputs) {
         if (!in_node->IsVar()) continue;
         auto* real_node = GetRealVarNode(block_idx, in_node);
@@ -598,22 +605,12 @@ void ConvertToMixedPrecisionPass::ConvertTensorDtype(BlockID block_idx) {
             real_node->Var()->GetDataType() == VarType::FP32 ||
             real_node->Var()->GetDataType() == VarType::FP64 ||
             real_node->Var()->GetDataType() == VarType::BF16) {
-          has_float_input_and_output = true;
-          break;
-        }
-      }
-      for (auto* out_node : op_node->outputs) {
-        if (!out_node->IsVar()) continue;
-        auto* real_node = GetRealVarNode(block_idx, out_node);
-        if (real_node->Var()->GetDataType() == VarType::FP16 ||
-            real_node->Var()->GetDataType() == VarType::FP32 ||
-            real_node->Var()->GetDataType() == VarType::FP64 ||
-            real_node->Var()->GetDataType() == VarType::BF16) {
-          has_float_input_and_output = true;
+          has_float_input = true;
           break;
         }
       }
-      if (!has_float_input_and_output) {
+
+      if (!has_float_input) {
         support_precision = false;
         VLOG(2) << " op doesn't has float input and output, just skip.";
       }
@@ -727,7 +724,9 @@ void ConvertToMixedPrecisionPass::ConvertTensorDtype(BlockID block_idx) {
 
       if (vars_in_multi_block_with_pair_.count(real_node->Name()) &&
          vars_in_multi_block_with_pair_.at(real_node->Name()).second ==
-              block_idx) {
+              block_idx &&
+          vars_in_multi_block_with_pair_.at(real_node->Name()).first ==
+              VarType::Type()) {
         vars_in_multi_block_with_pair_.at(real_node->Name()).first =
             real_node->Var()->GetDataType();
       }
-- 
GitLab