提交 104a9f1e 编写于 作者: W Wojciech Uss

fix pattern maching conv2d with(out) ResidualData

test=develop
上级 a3b8028d
......@@ -224,8 +224,8 @@ std::unique_ptr<ir::Graph> CPUQuantizePass::ApplyImpl(
PADDLE_ENFORCE(param_scope());
QuantizeConv(graph.get(), false /* with_residual_data */);
QuantizeConv(graph.get(), true /* with_residual_data */);
QuantizeConv(graph.get());
QuantizePool(graph.get());
return graph;
......
......@@ -599,10 +599,19 @@ bool VarLinksToOp(Node *node, const std::string &op_type) {
bool IsNthInput(Node *var, Node *op, const std::string &argument, size_t nth) {
PADDLE_ENFORCE(var->IsVar());
PADDLE_ENFORCE(op->IsOp());
if (op->Op()->Input(argument).size() <= nth) return false;
if (!HasInput(op, argument) || op->Op()->Input(argument).size() <= nth)
return false;
return var->Name() == op->Op()->Input(argument)[nth];
}
bool HasInput(Node *op, const std::string &argument) {
PADDLE_ENFORCE(op->IsOp());
auto const &names = op->Op()->InputNames();
if (std::find(names.begin(), names.end(), argument) == names.end())
return false;
return true;
}
bool IsNthOutput(Node *var, Node *op, const std::string &argument, size_t nth) {
PADDLE_ENFORCE(var->IsVar());
PADDLE_ENFORCE(op->IsOp());
......@@ -1082,8 +1091,15 @@ PDNode *patterns::Conv::operator()() {
PDNode *patterns::ConvResidual::operator()(bool with_residual_data) {
auto conv_op = pattern->NewNode(conv_op_repr())->assert_is_op("conv2d");
if (!with_residual_data)
conv_op->assert_op_attr("fuse_residual_connection", false);
if (!with_residual_data) {
conv_op->assert_more([&](Node *x) {
auto node_names = x->Op()->InputNames();
if (!HasInput(x, "ResidualData") ||
x->Op()->Input("ResidualData").size() == 0)
return true;
return false;
});
}
auto input_var = pattern->NewNode(conv_input_repr())
->AsInput()
......
......@@ -305,6 +305,9 @@ bool VarLinksFromOp(Node* node, const std::string& op_type);
// Check whether a var node is a op node's nth input.
bool IsNthInput(Node* var, Node* op, const std::string& argument, size_t nth);
// Check whether the op node has input of given name.
bool HasInput(Node* op, const std::string& argument);
// Tell whether a var node is a op node's nth output.
bool IsNthOutput(Node* var, Node* op, const std::string& argument, size_t nth);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册