From 9aa89b99e1c13f4d5a82dda60a9eb4de242a3388 Mon Sep 17 00:00:00 2001 From: "joanna.wozna.intel" Date: Tue, 21 Jun 2022 07:04:45 +0200 Subject: [PATCH] Correct elementwise quantization (#43693) --- .../framework/ir/graph_pattern_detector.cc | 14 +++++++++++ .../framework/ir/graph_pattern_detector.h | 13 +++++++++++ .../framework/ir/mkldnn/cpu_quantize_pass.cc | 23 +++++++++++-------- 3 files changed, 41 insertions(+), 9 deletions(-) diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.cc b/paddle/fluid/framework/ir/graph_pattern_detector.cc index ca5a82708c..27444eca5d 100644 --- a/paddle/fluid/framework/ir/graph_pattern_detector.cc +++ b/paddle/fluid/framework/ir/graph_pattern_detector.cc @@ -2078,6 +2078,20 @@ PDNode *patterns::Elementwise::operator()(PDNode *x_var, PDNode *y_var, return out_var; } +PDNode *patterns::ElementwiseOp::operator()( + const std::string elementwise_type) { + auto elementwise_op = + pattern->NewNode(elementwise_op_repr())->assert_is_op(elementwise_type); + + auto out_var = pattern->NewNode(elementwise_out_repr()) + ->AsOutput() + ->assert_is_op_output(elementwise_type, "Out"); + + elementwise_op->LinksTo({out_var}); + + return out_var; +} + PDNode *patterns::ResidualElementwise::operator()( PDNode *op_var, PDNode *residual_var, const std::string elementwise_type, bool as_x) { diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.h b/paddle/fluid/framework/ir/graph_pattern_detector.h index 3c6b6ce94e..49d928c419 100644 --- a/paddle/fluid/framework/ir/graph_pattern_detector.h +++ b/paddle/fluid/framework/ir/graph_pattern_detector.h @@ -1072,6 +1072,19 @@ struct Elementwise : public PatternBase { PATTERN_DECL_NODE(elementwise_out); }; +// Elementwise ops +// Forward pass for element-wise operators +// elementwise_out is the result of the operator +struct ElementwiseOp : public PatternBase { + ElementwiseOp(PDPattern* pattern, const std::string& name_scope) + : PatternBase(pattern, name_scope, "elementwise") {} + + PDNode* operator()(const std::string elementwise_type); + + PATTERN_DECL_NODE(elementwise_op); + PATTERN_DECL_NODE(elementwise_out); +}; + // Residual Elementwise ops // This pattern allows operator output to be X or Y // and residual data Y or X, based on as_x flag diff --git a/paddle/fluid/framework/ir/mkldnn/cpu_quantize_pass.cc b/paddle/fluid/framework/ir/mkldnn/cpu_quantize_pass.cc index 452212664e..35e65eb1dd 100644 --- a/paddle/fluid/framework/ir/mkldnn/cpu_quantize_pass.cc +++ b/paddle/fluid/framework/ir/mkldnn/cpu_quantize_pass.cc @@ -858,12 +858,9 @@ void CPUQuantizePass::QuantizeElementwise( Graph* graph, const std::string elementwise_type) const { GraphPatternDetector gpd; auto pattern = gpd.mutable_pattern(); - patterns::Elementwise elementwise_pattern{pattern, name_scope_}; + patterns::ElementwiseOp elementwise_pattern{pattern, name_scope_}; - elementwise_pattern( - pattern->NewNode(elementwise_pattern.elementwise_x_repr()), - pattern->NewNode(elementwise_pattern.elementwise_y_repr()), - elementwise_type); + elementwise_pattern(elementwise_type); int quantize_elementwise_count = 0; auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, @@ -878,10 +875,18 @@ void CPUQuantizePass::QuantizeElementwise( return; } - GET_IR_NODE_FROM_SUBGRAPH(elementwise_x, elementwise_x, - elementwise_pattern); - GET_IR_NODE_FROM_SUBGRAPH(elementwise_y, elementwise_y, - elementwise_pattern); + auto x_name = elementwise_op->Op()->Input("X"); + auto y_name = elementwise_op->Op()->Input("Y"); + Node *elementwise_x, *elementwise_y; + + for (auto& input : elementwise_op->inputs) { + if (input->Name() == x_name[0]) elementwise_x = input; + if (input->Name() == y_name[0]) elementwise_y = input; + } + if (!elementwise_x || !elementwise_y) { + return; + } + GET_IR_NODE_FROM_SUBGRAPH(elementwise_out, elementwise_out, elementwise_pattern); -- GitLab