From a7e0cdea61c5ec4576c609262588f19ed2430061 Mon Sep 17 00:00:00 2001 From: lidanqing Date: Thu, 23 Jun 2022 10:14:52 +0800 Subject: [PATCH] [cherry-pick] release/2.3 elementwise_mul and matmul mkldnn fix (#43725) * Correct elementwise quantization (#43693) * [Bug fix] Do not quantize weights Y when matmul X and Y both other ops outputs (#43297) * fix some matmul that X and Y both other ops outputs, do not dequantize the Y. * fix CI format * fix according to review Co-authored-by: joanna.wozna.intel --- .../framework/ir/graph_pattern_detector.cc | 14 +++++++++++ .../framework/ir/graph_pattern_detector.h | 13 +++++++++++ .../framework/ir/mkldnn/cpu_quantize_pass.cc | 23 +++++++++++-------- .../ir/mkldnn/quant_dequant_mkldnn_pass.cc | 7 +++--- 4 files changed, 44 insertions(+), 13 deletions(-) diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.cc b/paddle/fluid/framework/ir/graph_pattern_detector.cc index 8eb1b64a276..7951514f8d3 100644 --- a/paddle/fluid/framework/ir/graph_pattern_detector.cc +++ b/paddle/fluid/framework/ir/graph_pattern_detector.cc @@ -2069,6 +2069,20 @@ PDNode *patterns::Elementwise::operator()(PDNode *x_var, PDNode *y_var, return out_var; } +PDNode *patterns::ElementwiseOp::operator()( + const std::string elementwise_type) { + auto elementwise_op = + pattern->NewNode(elementwise_op_repr())->assert_is_op(elementwise_type); + + auto out_var = pattern->NewNode(elementwise_out_repr()) + ->AsOutput() + ->assert_is_op_output(elementwise_type, "Out"); + + elementwise_op->LinksTo({out_var}); + + return out_var; +} + PDNode *patterns::ResidualElementwise::operator()( PDNode *op_var, PDNode *residual_var, const std::string elementwise_type, bool as_x) { diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.h b/paddle/fluid/framework/ir/graph_pattern_detector.h index 463b50ac48a..45b191ea229 100644 --- a/paddle/fluid/framework/ir/graph_pattern_detector.h +++ b/paddle/fluid/framework/ir/graph_pattern_detector.h @@ -1071,6 +1071,19 @@ struct Elementwise : public PatternBase { PATTERN_DECL_NODE(elementwise_out); }; +// Elementwise ops +// Forward pass for element-wise operators +// elementwise_out is the result of the operator +struct ElementwiseOp : public PatternBase { + ElementwiseOp(PDPattern* pattern, const std::string& name_scope) + : PatternBase(pattern, name_scope, "elementwise") {} + + PDNode* operator()(const std::string elementwise_type); + + PATTERN_DECL_NODE(elementwise_op); + PATTERN_DECL_NODE(elementwise_out); +}; + // Residual Elementwise ops // This pattern allows operator output to be X or Y // and residual data Y or X, based on as_x flag diff --git a/paddle/fluid/framework/ir/mkldnn/cpu_quantize_pass.cc b/paddle/fluid/framework/ir/mkldnn/cpu_quantize_pass.cc index 410dfbd6802..57b262060c5 100644 --- a/paddle/fluid/framework/ir/mkldnn/cpu_quantize_pass.cc +++ b/paddle/fluid/framework/ir/mkldnn/cpu_quantize_pass.cc @@ -819,12 +819,9 @@ void CPUQuantizePass::QuantizeElementwise( Graph* graph, const std::string elementwise_type) const { GraphPatternDetector gpd; auto pattern = gpd.mutable_pattern(); - patterns::Elementwise elementwise_pattern{pattern, name_scope_}; + patterns::ElementwiseOp elementwise_pattern{pattern, name_scope_}; - elementwise_pattern( - pattern->NewNode(elementwise_pattern.elementwise_x_repr()), - pattern->NewNode(elementwise_pattern.elementwise_y_repr()), - elementwise_type); + elementwise_pattern(elementwise_type); int quantize_elementwise_count = 0; auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, @@ -839,10 +836,18 @@ void CPUQuantizePass::QuantizeElementwise( return; } - GET_IR_NODE_FROM_SUBGRAPH(elementwise_x, elementwise_x, - elementwise_pattern); - GET_IR_NODE_FROM_SUBGRAPH(elementwise_y, elementwise_y, - elementwise_pattern); + auto x_name = elementwise_op->Op()->Input("X"); + auto y_name = elementwise_op->Op()->Input("Y"); + Node *elementwise_x, *elementwise_y; + + for (auto& input : elementwise_op->inputs) { + if (input->Name() == x_name[0]) elementwise_x = input; + if (input->Name() == y_name[0]) elementwise_y = input; + } + if (!elementwise_x || !elementwise_y) { + return; + } + GET_IR_NODE_FROM_SUBGRAPH(elementwise_out, elementwise_out, elementwise_pattern); diff --git a/paddle/fluid/framework/ir/mkldnn/quant_dequant_mkldnn_pass.cc b/paddle/fluid/framework/ir/mkldnn/quant_dequant_mkldnn_pass.cc index 808d043a4b2..28363f8f167 100644 --- a/paddle/fluid/framework/ir/mkldnn/quant_dequant_mkldnn_pass.cc +++ b/paddle/fluid/framework/ir/mkldnn/quant_dequant_mkldnn_pass.cc @@ -354,10 +354,9 @@ bool QuantDequantMkldnnPass::IsInt8Weight( auto* op_desc = op_node->Op(); auto var_name = op_desc->Input(weight_name)[0]; auto* var = scope->FindVar(var_name); - PADDLE_ENFORCE_NOT_NULL( - var, platform::errors::NotFound( - "The input persistable [%s] var of [%s] op is not found.", - var_name, op_desc->Type())); + if (var == nullptr) { + return false; + } auto* weight_tensor = var->GetMutable(); auto* weight_data = weight_tensor->data(); bool is_int8 = true; -- GitLab