未验证 提交 9a058591 编写于 作者: T Tao Luo 提交者: GitHub

Merge pull request #16322 from wojtuss/wojtuss/fix_cpu_quantize_pass

fix pattern maching conv2d with(out) ResidualData
...@@ -224,8 +224,8 @@ std::unique_ptr<ir::Graph> CPUQuantizePass::ApplyImpl( ...@@ -224,8 +224,8 @@ std::unique_ptr<ir::Graph> CPUQuantizePass::ApplyImpl(
PADDLE_ENFORCE(param_scope()); PADDLE_ENFORCE(param_scope());
QuantizeConv(graph.get(), false /* with_residual_data */);
QuantizeConv(graph.get(), true /* with_residual_data */); QuantizeConv(graph.get(), true /* with_residual_data */);
QuantizeConv(graph.get());
QuantizePool(graph.get()); QuantizePool(graph.get());
return graph; return graph;
......
...@@ -599,10 +599,19 @@ bool VarLinksToOp(Node *node, const std::string &op_type) { ...@@ -599,10 +599,19 @@ bool VarLinksToOp(Node *node, const std::string &op_type) {
bool IsNthInput(Node *var, Node *op, const std::string &argument, size_t nth) { bool IsNthInput(Node *var, Node *op, const std::string &argument, size_t nth) {
PADDLE_ENFORCE(var->IsVar()); PADDLE_ENFORCE(var->IsVar());
PADDLE_ENFORCE(op->IsOp()); PADDLE_ENFORCE(op->IsOp());
if (op->Op()->Input(argument).size() <= nth) return false; if (!HasInput(op, argument) || op->Op()->Input(argument).size() <= nth)
return false;
return var->Name() == op->Op()->Input(argument)[nth]; return var->Name() == op->Op()->Input(argument)[nth];
} }
bool HasInput(Node *op, const std::string &argument) {
PADDLE_ENFORCE(op->IsOp());
auto const &names = op->Op()->InputNames();
if (std::find(names.begin(), names.end(), argument) == names.end())
return false;
return true;
}
bool IsNthOutput(Node *var, Node *op, const std::string &argument, size_t nth) { bool IsNthOutput(Node *var, Node *op, const std::string &argument, size_t nth) {
PADDLE_ENFORCE(var->IsVar()); PADDLE_ENFORCE(var->IsVar());
PADDLE_ENFORCE(op->IsOp()); PADDLE_ENFORCE(op->IsOp());
...@@ -1082,8 +1091,15 @@ PDNode *patterns::Conv::operator()() { ...@@ -1082,8 +1091,15 @@ PDNode *patterns::Conv::operator()() {
PDNode *patterns::ConvResidual::operator()(bool with_residual_data) { PDNode *patterns::ConvResidual::operator()(bool with_residual_data) {
auto conv_op = pattern->NewNode(conv_op_repr())->assert_is_op("conv2d"); auto conv_op = pattern->NewNode(conv_op_repr())->assert_is_op("conv2d");
if (!with_residual_data) if (!with_residual_data) {
conv_op->assert_op_attr("fuse_residual_connection", false); conv_op->assert_more([&](Node *x) {
auto node_names = x->Op()->InputNames();
if (!HasInput(x, "ResidualData") ||
x->Op()->Input("ResidualData").size() == 0)
return true;
return false;
});
}
auto input_var = pattern->NewNode(conv_input_repr()) auto input_var = pattern->NewNode(conv_input_repr())
->AsInput() ->AsInput()
......
...@@ -305,6 +305,9 @@ bool VarLinksFromOp(Node* node, const std::string& op_type); ...@@ -305,6 +305,9 @@ bool VarLinksFromOp(Node* node, const std::string& op_type);
// Check whether a var node is a op node's nth input. // Check whether a var node is a op node's nth input.
bool IsNthInput(Node* var, Node* op, const std::string& argument, size_t nth); bool IsNthInput(Node* var, Node* op, const std::string& argument, size_t nth);
// Check whether the op node has input of given name.
bool HasInput(Node* op, const std::string& argument);
// Tell whether a var node is a op node's nth output. // Tell whether a var node is a op node's nth output.
bool IsNthOutput(Node* var, Node* op, const std::string& argument, size_t nth); bool IsNthOutput(Node* var, Node* op, const std::string& argument, size_t nth);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册