diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.cc b/paddle/fluid/framework/ir/graph_pattern_detector.cc index ca5a82708c554bfd7a991b9d204b9ce80e9a570f..27444eca5d856d03bdccf468d8cdee8c1b969b8b 100644 --- a/paddle/fluid/framework/ir/graph_pattern_detector.cc +++ b/paddle/fluid/framework/ir/graph_pattern_detector.cc @@ -2078,6 +2078,20 @@ PDNode *patterns::Elementwise::operator()(PDNode *x_var, PDNode *y_var, return out_var; } +PDNode *patterns::ElementwiseOp::operator()( + const std::string elementwise_type) { + auto elementwise_op = + pattern->NewNode(elementwise_op_repr())->assert_is_op(elementwise_type); + + auto out_var = pattern->NewNode(elementwise_out_repr()) + ->AsOutput() + ->assert_is_op_output(elementwise_type, "Out"); + + elementwise_op->LinksTo({out_var}); + + return out_var; +} + PDNode *patterns::ResidualElementwise::operator()( PDNode *op_var, PDNode *residual_var, const std::string elementwise_type, bool as_x) { diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.h b/paddle/fluid/framework/ir/graph_pattern_detector.h index 3c6b6ce94e23f3501f56d467ace80455eab0e9aa..49d928c41901540c948c3b72f09c66e775c0e9ca 100644 --- a/paddle/fluid/framework/ir/graph_pattern_detector.h +++ b/paddle/fluid/framework/ir/graph_pattern_detector.h @@ -1072,6 +1072,19 @@ struct Elementwise : public PatternBase { PATTERN_DECL_NODE(elementwise_out); }; +// Elementwise ops +// Forward pass for element-wise operators +// elementwise_out is the result of the operator +struct ElementwiseOp : public PatternBase { + ElementwiseOp(PDPattern* pattern, const std::string& name_scope) + : PatternBase(pattern, name_scope, "elementwise") {} + + PDNode* operator()(const std::string elementwise_type); + + PATTERN_DECL_NODE(elementwise_op); + PATTERN_DECL_NODE(elementwise_out); +}; + // Residual Elementwise ops // This pattern allows operator output to be X or Y // and residual data Y or X, based on as_x flag diff --git a/paddle/fluid/framework/ir/mkldnn/cpu_quantize_pass.cc b/paddle/fluid/framework/ir/mkldnn/cpu_quantize_pass.cc index 452212664ec93999e0168d81df2c83e6783ec6fc..35e65eb1ddfef8edde511d602956af2bd7b75aa1 100644 --- a/paddle/fluid/framework/ir/mkldnn/cpu_quantize_pass.cc +++ b/paddle/fluid/framework/ir/mkldnn/cpu_quantize_pass.cc @@ -858,12 +858,9 @@ void CPUQuantizePass::QuantizeElementwise( Graph* graph, const std::string elementwise_type) const { GraphPatternDetector gpd; auto pattern = gpd.mutable_pattern(); - patterns::Elementwise elementwise_pattern{pattern, name_scope_}; + patterns::ElementwiseOp elementwise_pattern{pattern, name_scope_}; - elementwise_pattern( - pattern->NewNode(elementwise_pattern.elementwise_x_repr()), - pattern->NewNode(elementwise_pattern.elementwise_y_repr()), - elementwise_type); + elementwise_pattern(elementwise_type); int quantize_elementwise_count = 0; auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, @@ -878,10 +875,18 @@ void CPUQuantizePass::QuantizeElementwise( return; } - GET_IR_NODE_FROM_SUBGRAPH(elementwise_x, elementwise_x, - elementwise_pattern); - GET_IR_NODE_FROM_SUBGRAPH(elementwise_y, elementwise_y, - elementwise_pattern); + auto x_name = elementwise_op->Op()->Input("X"); + auto y_name = elementwise_op->Op()->Input("Y"); + Node *elementwise_x, *elementwise_y; + + for (auto& input : elementwise_op->inputs) { + if (input->Name() == x_name[0]) elementwise_x = input; + if (input->Name() == y_name[0]) elementwise_y = input; + } + if (!elementwise_x || !elementwise_y) { + return; + } + GET_IR_NODE_FROM_SUBGRAPH(elementwise_out, elementwise_out, elementwise_pattern);