From 104a9f1e270e9b42efcad0d24217a76475f7faf2 Mon Sep 17 00:00:00 2001 From: Wojciech Uss Date: Wed, 20 Mar 2019 03:06:48 +0100 Subject: [PATCH] fix pattern maching conv2d with(out) ResidualData test=develop --- .../fluid/framework/ir/cpu_quantize_pass.cc | 2 +- .../framework/ir/graph_pattern_detector.cc | 22 ++++++++++++++++--- .../framework/ir/graph_pattern_detector.h | 3 +++ 3 files changed, 23 insertions(+), 4 deletions(-) diff --git a/paddle/fluid/framework/ir/cpu_quantize_pass.cc b/paddle/fluid/framework/ir/cpu_quantize_pass.cc index edfaf47f01..ed80f9cae3 100644 --- a/paddle/fluid/framework/ir/cpu_quantize_pass.cc +++ b/paddle/fluid/framework/ir/cpu_quantize_pass.cc @@ -224,8 +224,8 @@ std::unique_ptr CPUQuantizePass::ApplyImpl( PADDLE_ENFORCE(param_scope()); + QuantizeConv(graph.get(), false /* with_residual_data */); QuantizeConv(graph.get(), true /* with_residual_data */); - QuantizeConv(graph.get()); QuantizePool(graph.get()); return graph; diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.cc b/paddle/fluid/framework/ir/graph_pattern_detector.cc index b653e5a521..d0d72127f0 100644 --- a/paddle/fluid/framework/ir/graph_pattern_detector.cc +++ b/paddle/fluid/framework/ir/graph_pattern_detector.cc @@ -599,10 +599,19 @@ bool VarLinksToOp(Node *node, const std::string &op_type) { bool IsNthInput(Node *var, Node *op, const std::string &argument, size_t nth) { PADDLE_ENFORCE(var->IsVar()); PADDLE_ENFORCE(op->IsOp()); - if (op->Op()->Input(argument).size() <= nth) return false; + if (!HasInput(op, argument) || op->Op()->Input(argument).size() <= nth) + return false; return var->Name() == op->Op()->Input(argument)[nth]; } +bool HasInput(Node *op, const std::string &argument) { + PADDLE_ENFORCE(op->IsOp()); + auto const &names = op->Op()->InputNames(); + if (std::find(names.begin(), names.end(), argument) == names.end()) + return false; + return true; +} + bool IsNthOutput(Node *var, Node *op, const std::string &argument, size_t nth) { PADDLE_ENFORCE(var->IsVar()); PADDLE_ENFORCE(op->IsOp()); @@ -1082,8 +1091,15 @@ PDNode *patterns::Conv::operator()() { PDNode *patterns::ConvResidual::operator()(bool with_residual_data) { auto conv_op = pattern->NewNode(conv_op_repr())->assert_is_op("conv2d"); - if (!with_residual_data) - conv_op->assert_op_attr("fuse_residual_connection", false); + if (!with_residual_data) { + conv_op->assert_more([&](Node *x) { + auto node_names = x->Op()->InputNames(); + if (!HasInput(x, "ResidualData") || + x->Op()->Input("ResidualData").size() == 0) + return true; + return false; + }); + } auto input_var = pattern->NewNode(conv_input_repr()) ->AsInput() diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.h b/paddle/fluid/framework/ir/graph_pattern_detector.h index fc30b5b21c..bac23b6513 100644 --- a/paddle/fluid/framework/ir/graph_pattern_detector.h +++ b/paddle/fluid/framework/ir/graph_pattern_detector.h @@ -305,6 +305,9 @@ bool VarLinksFromOp(Node* node, const std::string& op_type); // Check whether a var node is a op node's nth input. bool IsNthInput(Node* var, Node* op, const std::string& argument, size_t nth); +// Check whether the op node has input of given name. +bool HasInput(Node* op, const std::string& argument); + // Tell whether a var node is a op node's nth output. bool IsNthOutput(Node* var, Node* op, const std::string& argument, size_t nth); -- GitLab