diff --git a/paddle/fluid/framework/ir/cpu_quantize_pass.cc b/paddle/fluid/framework/ir/cpu_quantize_pass.cc index edfaf47f018a61d72aa3764185f2c185722b553f..ed80f9cae347cfb2bf23859daea2f1f47dba599b 100644 --- a/paddle/fluid/framework/ir/cpu_quantize_pass.cc +++ b/paddle/fluid/framework/ir/cpu_quantize_pass.cc @@ -224,8 +224,8 @@ std::unique_ptr CPUQuantizePass::ApplyImpl( PADDLE_ENFORCE(param_scope()); + QuantizeConv(graph.get(), false /* with_residual_data */); QuantizeConv(graph.get(), true /* with_residual_data */); - QuantizeConv(graph.get()); QuantizePool(graph.get()); return graph; diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.cc b/paddle/fluid/framework/ir/graph_pattern_detector.cc index b653e5a521eeb81d1ac3cb5cca1dc86025837ecd..d0d72127f08f4a83cca5daed57ae6d72c33ae1e3 100644 --- a/paddle/fluid/framework/ir/graph_pattern_detector.cc +++ b/paddle/fluid/framework/ir/graph_pattern_detector.cc @@ -599,10 +599,19 @@ bool VarLinksToOp(Node *node, const std::string &op_type) { bool IsNthInput(Node *var, Node *op, const std::string &argument, size_t nth) { PADDLE_ENFORCE(var->IsVar()); PADDLE_ENFORCE(op->IsOp()); - if (op->Op()->Input(argument).size() <= nth) return false; + if (!HasInput(op, argument) || op->Op()->Input(argument).size() <= nth) + return false; return var->Name() == op->Op()->Input(argument)[nth]; } +bool HasInput(Node *op, const std::string &argument) { + PADDLE_ENFORCE(op->IsOp()); + auto const &names = op->Op()->InputNames(); + if (std::find(names.begin(), names.end(), argument) == names.end()) + return false; + return true; +} + bool IsNthOutput(Node *var, Node *op, const std::string &argument, size_t nth) { PADDLE_ENFORCE(var->IsVar()); PADDLE_ENFORCE(op->IsOp()); @@ -1082,8 +1091,15 @@ PDNode *patterns::Conv::operator()() { PDNode *patterns::ConvResidual::operator()(bool with_residual_data) { auto conv_op = pattern->NewNode(conv_op_repr())->assert_is_op("conv2d"); - if (!with_residual_data) - conv_op->assert_op_attr("fuse_residual_connection", false); + if (!with_residual_data) { + conv_op->assert_more([&](Node *x) { + auto node_names = x->Op()->InputNames(); + if (!HasInput(x, "ResidualData") || + x->Op()->Input("ResidualData").size() == 0) + return true; + return false; + }); + } auto input_var = pattern->NewNode(conv_input_repr()) ->AsInput() diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.h b/paddle/fluid/framework/ir/graph_pattern_detector.h index fc30b5b21c580afdede64421bb4a1f4174bbad03..bac23b651305419a5bcc4fc1efacb721e6e5d0ad 100644 --- a/paddle/fluid/framework/ir/graph_pattern_detector.h +++ b/paddle/fluid/framework/ir/graph_pattern_detector.h @@ -305,6 +305,9 @@ bool VarLinksFromOp(Node* node, const std::string& op_type); // Check whether a var node is a op node's nth input. bool IsNthInput(Node* var, Node* op, const std::string& argument, size_t nth); +// Check whether the op node has input of given name. +bool HasInput(Node* op, const std::string& argument); + // Tell whether a var node is a op node's nth output. bool IsNthOutput(Node* var, Node* op, const std::string& argument, size_t nth);