From 8eaff62d795a0804993a470eb549ebed8b8d41cd Mon Sep 17 00:00:00 2001
From: Sylwester Fraczek
Date: Mon, 10 Oct 2022 10:08:57 +0200
Subject: [PATCH] add function FindInputNameByVarName (#46759)

* Add methods that find input or output name by var name

* kind of bugfix - initialize variables

* ci fix

* review fixed
---
 .../framework/ir/mkldnn/cpu_quantize_pass.cc  |  2 +-
 .../ir/mkldnn/cpu_quantize_squash_pass.cc     | 28 ++++++++-----------
 .../ir/mkldnn/scale_matmul_fuse_pass.cc       |  7 ++---
 paddle/fluid/platform/mkldnn_helper.h         | 18 ++++++++++++
 4 files changed, 33 insertions(+), 22 deletions(-)

diff --git a/paddle/fluid/framework/ir/mkldnn/cpu_quantize_pass.cc b/paddle/fluid/framework/ir/mkldnn/cpu_quantize_pass.cc
index 92351d5067f..f40eeea1736 100644
--- a/paddle/fluid/framework/ir/mkldnn/cpu_quantize_pass.cc
+++ b/paddle/fluid/framework/ir/mkldnn/cpu_quantize_pass.cc
@@ -873,7 +873,7 @@ void CPUQuantizePass::QuantizeElementwise(
 
     auto x_name = elementwise_op->Op()->Input("X");
     auto y_name = elementwise_op->Op()->Input("Y");
-    Node *elementwise_x, *elementwise_y;
+    Node *elementwise_x{nullptr}, *elementwise_y{nullptr};
 
     for (auto& input : elementwise_op->inputs) {
       if (input->Name() == x_name[0]) elementwise_x = input;
diff --git a/paddle/fluid/framework/ir/mkldnn/cpu_quantize_squash_pass.cc b/paddle/fluid/framework/ir/mkldnn/cpu_quantize_squash_pass.cc
index 7c23976d3c6..933d60b0a27 100644
--- a/paddle/fluid/framework/ir/mkldnn/cpu_quantize_squash_pass.cc
+++ b/paddle/fluid/framework/ir/mkldnn/cpu_quantize_squash_pass.cc
@@ -19,6 +19,7 @@
 #include <vector>
 
 #include "paddle/fluid/platform/enforce.h"
+#include "paddle/fluid/platform/mkldnn_helper.h"
 #include "paddle/fluid/string/pretty_log.h"
 
 namespace paddle {
@@ -262,10 +263,8 @@ void CPUQuantizeSquashPass::OpRequantSquash(Graph* graph) const {
     GET_IR_NODE_FROM_SUBGRAPH(requant_out, requant_out, op_requant_pattern);
 
     if (requant_in->outputs.size() == 1) {
-      std::string any_op_output_name;
-      for (auto name : any_op->Op()->OutputNames())
-        for (auto output_name : any_op->Op()->Output(name))
-          if (output_name == requant_in->Name()) any_op_output_name = name;
+      std::string any_op_output_name =
+          FindOutputNameByVarName(any_op->Op(), requant_in->Name());
 
       PADDLE_ENFORCE_NE(
           any_op_output_name.empty(),
@@ -308,10 +307,8 @@ void CPUQuantizeSquashPass::RequantOpSquash(Graph* graph) const {
     GET_IR_NODE_FROM_SUBGRAPH(any_op, any_op, requant_op_pattern);
 
     if (requant_out->outputs.size() == 1) {
-      std::string any_op_input_name;
-      for (auto name : any_op->Op()->InputNames())
-        for (auto input_name : any_op->Op()->Input(name))
-          if (input_name == requant_out->Name()) any_op_input_name = name;
+      std::string any_op_input_name =
+          FindInputNameByVarName(any_op->Op(), requant_out->Name());
 
       PADDLE_ENFORCE_NE(any_op_input_name.empty(),
                         true,
@@ -374,10 +371,8 @@ void CPUQuantizeSquashPass::OpDequantSquash(Graph* graph) const {
         return;
       }
       // Find the name of the output linking any_op to dequant_in
-      std::string output_name;
-      for (auto name : any_op->Op()->OutputNames())
-        for (auto out_name : any_op->Op()->Output(name))
-          if (out_name == dequant_in->Name()) output_name = name;
+      std::string output_name =
+          FindOutputNameByVarName(any_op->Op(), dequant_in->Name());
 
       if (output_name.empty()) return;
 
@@ -430,17 +425,16 @@ void CPUQuantizeSquashPass::MultipleQuantizeSquash(Graph* graph) const {
           quant_op->Op()->GetAttrIfExists<float>("Shift") == shift) {
         auto quant_out = quant_op->outputs[0];
         auto last_op = quant_out->outputs[0];
+        auto last_op_op = last_op->Op();
 
-        std::string last_op_input_name;
-        for (auto name : last_op->Op()->InputNames())
-          for (auto input_name : last_op->Op()->Input(name))
-            if (input_name == quant_out->Name()) last_op_input_name = name;
+        std::string last_op_input_name =
+            FindInputNameByVarName(last_op_op, quant_out->Name());
 
         PADDLE_ENFORCE_NE(
             last_op_input_name.empty(),
             true,
             platform::errors::NotFound("Operator after quantize operator(%s) "
-                                       "should has quantize output as input.",
+                                       "should have quantize output as input.",
                                        quant_out->Name()));
 
         // update the next operator input,
diff --git a/paddle/fluid/framework/ir/mkldnn/scale_matmul_fuse_pass.cc b/paddle/fluid/framework/ir/mkldnn/scale_matmul_fuse_pass.cc
index 20b9698bda3..a968af26bd2 100644
--- a/paddle/fluid/framework/ir/mkldnn/scale_matmul_fuse_pass.cc
+++ b/paddle/fluid/framework/ir/mkldnn/scale_matmul_fuse_pass.cc
@@ -19,6 +19,7 @@ limitations under the License. */
 
 #include "paddle/fluid/framework/ir/graph_pattern_detector.h"
 #include "paddle/fluid/framework/op_version_registry.h"
+#include "paddle/fluid/platform/mkldnn_helper.h"
 #include "paddle/fluid/string/pretty_log.h"
 
 namespace paddle {
@@ -110,10 +111,8 @@ void ScaleMatmulFusePass::ApplyImpl(ir::Graph* graph) const {
                           "Scale(%f) of scale op should have positive value.",
                           scale_scale));
 
-    std::string matmul_op_input_name;
-    for (auto name : matmul_op->Op()->InputNames())
-      for (auto input_name : matmul_op->Op()->Input(name))
-        if (input_name == scale_out->Name()) matmul_op_input_name = name;
+    std::string matmul_op_input_name =
+        FindInputNameByVarName(matmul_op->Op(), scale_out->Name());
 
     PADDLE_ENFORCE_NE(
         matmul_op_input_name.empty(),
diff --git a/paddle/fluid/platform/mkldnn_helper.h b/paddle/fluid/platform/mkldnn_helper.h
index 07f5f3408a3..bba813b7f26 100644
--- a/paddle/fluid/platform/mkldnn_helper.h
+++ b/paddle/fluid/platform/mkldnn_helper.h
@@ -632,4 +632,22 @@ bool constexpr is_int8() {
 }
 
 }  // namespace platform
+
+inline std::string FindInputNameByVarName(framework::OpDesc* op,
+                                          const std::string& searched_name) {
+  std::string ret;
+  for (const auto& name : op->InputNames())
+    for (const auto& input_name : op->Input(name))
+      if (input_name == searched_name) ret = name;
+  return ret;
+}
+
+inline std::string FindOutputNameByVarName(framework::OpDesc* op,
+                                           const std::string& searched_name) {
+  std::string ret;
+  for (const auto& name : op->OutputNames())
+    for (const auto& output_name : op->Output(name))
+      if (output_name == searched_name) ret = name;
+  return ret;
+}
 }  // namespace paddle
--
GitLab
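
For reference, below is a minimal standalone sketch of the lookup that the two new helpers in mkldnn_helper.h perform: given a variable name, find the operator parameter ("X", "Y", ...) it is bound to. The FakeOpDesc struct is a hypothetical stand-in that only mimics the InputNames()/Input() interface of framework::OpDesc so the snippet compiles outside Paddle; the variable names used in main() are made up for illustration.

// Standalone sketch of the FindInputNameByVarName lookup.
// FakeOpDesc is a hypothetical stand-in for framework::OpDesc;
// the real helper in the patch takes a framework::OpDesc*.
#include <iostream>
#include <map>
#include <string>
#include <vector>

struct FakeOpDesc {
  // parameter name -> variable names bound to that parameter
  std::map<std::string, std::vector<std::string>> inputs;

  std::vector<std::string> InputNames() const {
    std::vector<std::string> names;
    for (const auto& kv : inputs) names.push_back(kv.first);
    return names;
  }
  const std::vector<std::string>& Input(const std::string& name) const {
    return inputs.at(name);
  }
};

// Same shape as the helper added in the patch: return the parameter name
// whose variable list contains searched_name, or "" if none does.
std::string FindInputNameByVarName(const FakeOpDesc& op,
                                   const std::string& searched_name) {
  std::string ret;
  for (const auto& name : op.InputNames())
    for (const auto& input_name : op.Input(name))
      if (input_name == searched_name) ret = name;
  return ret;
}

int main() {
  FakeOpDesc matmul{{{"X", {"scale_out_var"}}, {"Y", {"weights_var"}}}};
  // Prints "X": scale_out_var is bound to the "X" input of the op.
  std::cout << FindInputNameByVarName(matmul, "scale_out_var") << "\n";
  // Prints an empty line: no input references this variable.
  std::cout << FindInputNameByVarName(matmul, "missing_var") << "\n";
  return 0;
}

As in the patch, the lookup keeps the last matching parameter name rather than breaking early, and it returns an empty string when nothing matches, which is what the callers rely on through their PADDLE_ENFORCE_NE(...empty(), true, ...) checks and the `if (output_name.empty()) return;` early exit.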