Unverified commit d0000082, authored by Tao Luo, committed by GitHub

Merge pull request #13552 from sfraczek/sfraczek/conv-relu-update

little update to conv relu fuse pass (fix)
@@ -26,8 +26,6 @@ std::unique_ptr<ir::Graph> ConvReLUFusePass::ApplyImpl(
   PADDLE_ENFORCE(graph.get());
   FusePassBase::Init("conv_relu_mkldnn_fuse", graph.get());
 
-  std::unordered_set<Node*> nodes2delete;
-
   GraphPatternDetector gpd;
   auto* conv_input = gpd.mutable_pattern()
                          ->NewNode("conv_relu_mkldnn_fuse/conv_input")
@@ -42,36 +40,20 @@ std::unique_ptr<ir::Graph> ConvReLUFusePass::ApplyImpl(
                      Graph* g) {
     VLOG(4) << "handle ConvReLU fuse";
     GET_IR_NODE_FROM_SUBGRAPH(conv_weight, conv_weight,
                               conv_relu_pattern);  // Filter
-    GET_IR_NODE_FROM_SUBGRAPH(conv_bias, conv_bias, conv_relu_pattern);  // Bias
     GET_IR_NODE_FROM_SUBGRAPH(conv_out, conv_out, conv_relu_pattern);    // tmp
     GET_IR_NODE_FROM_SUBGRAPH(conv, conv, conv_relu_pattern);        // CONV op
     GET_IR_NODE_FROM_SUBGRAPH(relu_out, relu_out, conv_relu_pattern);  // Out
     GET_IR_NODE_FROM_SUBGRAPH(relu, relu, conv_relu_pattern);          // ReLU op
 
-    // Create an ConvReLU Node.
-    OpDesc desc;
-    std::string conv_relu_i_in = subgraph.at(conv_input)->Name();
-    std::string conv_relu_w_in = conv_weight->Name();
-    std::string conv_relu_b_in = conv_bias->Name();
-    std::string conv_relu_out = relu_out->Name();
-    desc.SetInput("Input", std::vector<std::string>({conv_relu_i_in}));
-    desc.SetInput("Filter", std::vector<std::string>({conv_relu_w_in}));
-    desc.SetInput("Bias", std::vector<std::string>({conv_relu_b_in}));
-    desc.SetOutput("Output", std::vector<std::string>({conv_relu_out}));
-    desc.SetType("conv2d");
-    for (auto& attr : conv->Op()->GetAttrMap()) {
-      desc.SetAttr(attr.first, attr.second);
-    }
-    desc.SetAttr("fuse_relu", true);
-    auto conv_relu_node = g->CreateOpNode(&desc);  // OpDesc will be copied.
-    GraphSafeRemoveNodes(graph.get(), {conv, relu, conv_out});
+    // Transform Conv node into ConvReLU node.
+    OpDesc* desc = conv->Op();
+    desc->SetOutput("Output", std::vector<std::string>({relu_out->Name()}));
+    desc->SetAttr("fuse_relu", true);
+    GraphSafeRemoveNodes(graph.get(), {relu, conv_out});
 
     PADDLE_ENFORCE(subgraph.count(conv_input));
-    IR_NODE_LINK_TO(subgraph.at(conv_input), conv_relu_node);
-    IR_NODE_LINK_TO(conv_weight, conv_relu_node);
-    IR_NODE_LINK_TO(conv_bias, conv_relu_node);
-    IR_NODE_LINK_TO(conv_relu_node, relu_out);
+    IR_NODE_LINK_TO(conv, relu_out);
 
     found_conv_relu_count++;
   };
...
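The net effect of the hunk above: instead of building a fresh conv2d OpDesc, copying every attribute across, and deleting the original conv node, the handler now rewrites the matched conv op in place, retargeting its Output to relu_out, setting fuse_relu, and removing only the relu op and the intermediate conv_out variable. Driving the pass looks roughly like the sketch below; FuseConvReLU is a hypothetical wrapper, and the registry name assumes the pass is registered as conv_relu_mkldnn_fuse_pass.

```cpp
// Sketch: applying the fuse pass to an ir::Graph. FuseConvReLU is a
// hypothetical wrapper; "conv_relu_mkldnn_fuse_pass" is assumed to be the
// name under which this pass is registered.
#include <memory>
#include <utility>

#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/pass.h"

using paddle::framework::ir::Graph;
using paddle::framework::ir::PassRegistry;

std::unique_ptr<Graph> FuseConvReLU(std::unique_ptr<Graph> graph) {
  auto pass = PassRegistry::Instance().Get("conv_relu_mkldnn_fuse_pass");
  return pass->Apply(std::move(graph));  // invokes ApplyImpl shown above
}
```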
@@ -85,16 +85,13 @@ TEST(ConvReLUFusePass, basic) {
   for (auto* node : graph->Nodes()) {
     if (node->IsOp() && node->Op()->Type() == "conv2d") {
-      if (node->Op()->HasAttr("use_mkldnn")) {
-        bool use_mkldnn = boost::get<bool>(node->Op()->GetAttr("use_mkldnn"));
-        if (use_mkldnn) {
-          if (node->Op()->HasAttr("fuse_relu")) {
-            bool fuse_relu = boost::get<bool>(node->Op()->GetAttr("fuse_relu"));
-            if (fuse_relu) {
-              ++conv_relu_count;
-            }
-          }
-        }
-      }
+      auto* op = node->Op();
+      ASSERT_TRUE(op->HasAttr("use_mkldnn"));
+      EXPECT_TRUE(boost::get<bool>(op->GetAttr("use_mkldnn")));
+      ASSERT_TRUE(op->HasAttr("fuse_relu"));
+      bool fuse_relu = boost::get<bool>(op->GetAttr("fuse_relu"));
+      if (fuse_relu) {
+        ++conv_relu_count;
+      }
     }
   }
...
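The old test walked a ladder of nested ifs, so a conv2d node missing use_mkldnn or fuse_relu was silently skipped; the rewrite turns a missing attribute into a test failure. It leans on gtest's standard split between fatal and non-fatal assertions, illustrated by this stand-alone snippet:

```cpp
#include <gtest/gtest.h>

// Illustrative only. ASSERT_TRUE aborts the current test function on
// failure, so code after it may assume the condition held (here: the
// attribute exists before GetAttr dereferences it). EXPECT_TRUE records
// the failure but keeps running, so the remaining nodes still get checked.
TEST(AssertVsExpect, Demo) {
  ASSERT_TRUE(1 + 1 == 2);  // fatal: guards everything below
  EXPECT_TRUE(2 + 2 == 4);  // non-fatal: logs and moves on
}
```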
@@ -638,11 +638,6 @@ PDNode *patterns::ConvReLU::operator()(
                               ->AsInput()
                               ->assert_is_persistable_var()
                               ->assert_is_op_input("conv2d", "Filter");
-  // Bias
-  auto *conv_bias_var = pattern->NewNode(conv_bias_repr())
-                            ->AsInput()
-                            ->assert_is_persistable_var()
-                            ->assert_is_op_input("conv2d", "Bias");
   // intermediate variable, will be removed in the IR after fuse.
   auto *conv_out_var = pattern->NewNode(conv_out_repr())
                            ->AsIntermediate()
@@ -653,8 +648,7 @@ PDNode *patterns::ConvReLU::operator()(
                            ->AsOutput()
                            ->assert_is_op_output("relu");
-  conv_op->LinksFrom({conv_input, conv_weight_var, conv_bias_var})
-      .LinksTo({conv_out_var});
+  conv_op->LinksFrom({conv_input, conv_weight_var}).LinksTo({conv_out_var});
   relu_op->LinksFrom({conv_out_var}).LinksTo({relu_out_var});
 
   return relu_out_var;
 }
...
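With the Bias node removed, the pattern is a plain conv2d → relu chain over one input and one persistable filter, so convolutions match regardless of how their bias is wired. A rough sketch of how such a pattern is instantiated and run, mirroring the ApplyImpl code above; the helper function name is invented for illustration:

```cpp
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"

// Hypothetical helper mirroring ApplyImpl: build the trimmed pattern,
// then let the detector call `handler` once per conv2d->relu match.
void RunConvReLUPattern(paddle::framework::ir::Graph* graph) {
  using namespace paddle::framework::ir;  // sketch only

  GraphPatternDetector gpd;
  auto* conv_input = gpd.mutable_pattern()
                         ->NewNode("conv_relu_mkldnn_fuse/conv_input")
                         ->AsInput()
                         ->assert_is_op_input("conv2d", "Input");
  patterns::ConvReLU conv_relu_pattern(gpd.mutable_pattern(),
                                       "conv_relu_mkldnn_fuse");
  conv_relu_pattern(conv_input);  // links conv2d -> relu, returns relu_out

  auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
                     Graph* g) {
    // fetch named nodes with GET_IR_NODE_FROM_SUBGRAPH, as in ApplyImpl
  };
  gpd(graph, handler);  // invokes the handler once per match
}
```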
@@ -379,7 +379,7 @@ struct PatternBase {
 // op: conv + relu
 // named nodes:
 // conv_input, conv_weight,
-// conv_bias, conv_out, conv,
+// conv_out, conv,
 // relu_out, relu
 struct ConvReLU : public PatternBase {
   ConvReLU(PDPattern* pattern, const std::string& name_scope)
@@ -392,7 +392,6 @@ struct ConvReLU : public PatternBase {
   PATTERN_DECL_NODE(relu);
   // declare variable node's name
   PATTERN_DECL_NODE(conv_weight);
-  PATTERN_DECL_NODE(conv_bias);
   PATTERN_DECL_NODE(conv_out);
   PATTERN_DECL_NODE(relu_out);
 };
...