diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.cc b/paddle/fluid/framework/ir/graph_pattern_detector.cc index 8eb1b64a2763a2b1c91250c1940871b14f996d9b..7951514f8d3bb2c9c4d2d2316ef1de392158b52f 100644 --- a/paddle/fluid/framework/ir/graph_pattern_detector.cc +++ b/paddle/fluid/framework/ir/graph_pattern_detector.cc @@ -2069,6 +2069,20 @@ PDNode *patterns::Elementwise::operator()(PDNode *x_var, PDNode *y_var, return out_var; } +PDNode *patterns::ElementwiseOp::operator()( + const std::string elementwise_type) { + auto elementwise_op = + pattern->NewNode(elementwise_op_repr())->assert_is_op(elementwise_type); + + auto out_var = pattern->NewNode(elementwise_out_repr()) + ->AsOutput() + ->assert_is_op_output(elementwise_type, "Out"); + + elementwise_op->LinksTo({out_var}); + + return out_var; +} + PDNode *patterns::ResidualElementwise::operator()( PDNode *op_var, PDNode *residual_var, const std::string elementwise_type, bool as_x) { diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.h b/paddle/fluid/framework/ir/graph_pattern_detector.h index 463b50ac48a8fef8c977585bf4f1eabafc32b74e..45b191ea2299cb044a43ce92c5bf2d88daf54b1c 100644 --- a/paddle/fluid/framework/ir/graph_pattern_detector.h +++ b/paddle/fluid/framework/ir/graph_pattern_detector.h @@ -1071,6 +1071,19 @@ struct Elementwise : public PatternBase { PATTERN_DECL_NODE(elementwise_out); }; +// Elementwise ops +// Forward pass for element-wise operators +// elementwise_out is the result of the operator +struct ElementwiseOp : public PatternBase { + ElementwiseOp(PDPattern* pattern, const std::string& name_scope) + : PatternBase(pattern, name_scope, "elementwise") {} + + PDNode* operator()(const std::string elementwise_type); + + PATTERN_DECL_NODE(elementwise_op); + PATTERN_DECL_NODE(elementwise_out); +}; + // Residual Elementwise ops // This pattern allows operator output to be X or Y // and residual data Y or X, based on as_x flag diff --git a/paddle/fluid/framework/ir/mkldnn/cpu_quantize_pass.cc b/paddle/fluid/framework/ir/mkldnn/cpu_quantize_pass.cc index 410dfbd68028627c7b6266a2c0dac00af614adaf..57b262060c5dd47d396a24d6b98ba1a5dccfdfaa 100644 --- a/paddle/fluid/framework/ir/mkldnn/cpu_quantize_pass.cc +++ b/paddle/fluid/framework/ir/mkldnn/cpu_quantize_pass.cc @@ -819,12 +819,9 @@ void CPUQuantizePass::QuantizeElementwise( Graph* graph, const std::string elementwise_type) const { GraphPatternDetector gpd; auto pattern = gpd.mutable_pattern(); - patterns::Elementwise elementwise_pattern{pattern, name_scope_}; + patterns::ElementwiseOp elementwise_pattern{pattern, name_scope_}; - elementwise_pattern( - pattern->NewNode(elementwise_pattern.elementwise_x_repr()), - pattern->NewNode(elementwise_pattern.elementwise_y_repr()), - elementwise_type); + elementwise_pattern(elementwise_type); int quantize_elementwise_count = 0; auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, @@ -839,10 +836,18 @@ void CPUQuantizePass::QuantizeElementwise( return; } - GET_IR_NODE_FROM_SUBGRAPH(elementwise_x, elementwise_x, - elementwise_pattern); - GET_IR_NODE_FROM_SUBGRAPH(elementwise_y, elementwise_y, - elementwise_pattern); + auto x_name = elementwise_op->Op()->Input("X"); + auto y_name = elementwise_op->Op()->Input("Y"); + Node *elementwise_x, *elementwise_y; + + for (auto& input : elementwise_op->inputs) { + if (input->Name() == x_name[0]) elementwise_x = input; + if (input->Name() == y_name[0]) elementwise_y = input; + } + if (!elementwise_x || !elementwise_y) { + return; + } + GET_IR_NODE_FROM_SUBGRAPH(elementwise_out, elementwise_out, elementwise_pattern); diff --git a/paddle/fluid/framework/ir/mkldnn/quant_dequant_mkldnn_pass.cc b/paddle/fluid/framework/ir/mkldnn/quant_dequant_mkldnn_pass.cc index 808d043a4b226c8b49945775ef3987a180fdc029..28363f8f167609db9f893f94e8b78bb9fa34dd69 100644 --- a/paddle/fluid/framework/ir/mkldnn/quant_dequant_mkldnn_pass.cc +++ b/paddle/fluid/framework/ir/mkldnn/quant_dequant_mkldnn_pass.cc @@ -354,10 +354,9 @@ bool QuantDequantMkldnnPass::IsInt8Weight( auto* op_desc = op_node->Op(); auto var_name = op_desc->Input(weight_name)[0]; auto* var = scope->FindVar(var_name); - PADDLE_ENFORCE_NOT_NULL( - var, platform::errors::NotFound( - "The input persistable [%s] var of [%s] op is not found.", - var_name, op_desc->Type())); + if (var == nullptr) { + return false; + } auto* weight_tensor = var->GetMutable(); auto* weight_data = weight_tensor->data(); bool is_int8 = true;