Commit 07a62ddc authored by Tomasz Patejko

MKLDNN conv + elementwise_add fusion: inputs in pass modified. Support for new conv parameter. UTs corrected
Parent 56528531
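For readers skimming the diff below: after this change the fused conv2d op receives the elementwise_add operand that is not the convolution output as an extra input named "ElementwiseParameter", writes its result to the original convolution output variable, and downstream consumers of the old elementwise_add output are re-linked to that variable. The snippet below is a minimal sketch of the op description the pass emits, condensed from the fuse_conv lambda in this patch; the free-function name and the header path are illustrative and not part of the commit.

#include <string>
#include <vector>

#include "paddle/fluid/framework/op_desc.h"

// Sketch only: builds the fused conv2d description corresponding to fuse_conv.
void BuildFusedConvDesc(paddle::framework::OpDesc* op_desc,
                        const std::string& input, const std::string& filter,
                        const std::string& residual, const std::string& output) {
  op_desc->SetType("conv2d");
  op_desc->SetInput("Input", {input});
  op_desc->SetInput("Filter", {filter});
  // The residual tensor (elementwise_add's X operand) becomes a conv input,
  // so the MKLDNN kernel can perform the addition itself.
  op_desc->SetInput("ElementwiseParameter", {residual});
  op_desc->SetOutput("Output", {output});
  op_desc->SetAttr("use_mkldnn", true);
  op_desc->SetAttr("fuse_eltwise", true);
}

Because the pattern now binds the convolution output to elementwise_add's Y input, the unit tests below swap the operand order accordingly, e.g. SetOp(&prog, "elementwise_add", {"c", "b"}, {"d"}).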
@@ -73,19 +73,19 @@ struct ElementwiseAdd {
   auto elementwise_add_op = pattern->new_node(op_name())
                                 ->assert_is_op("elementwise_add");
 
-  auto y_var = pattern->new_node(y_name())
-                   ->assert_is_op_input(op_name(),
-                                        y_name());
+  auto x_var = pattern->new_node(x_name())
+                   ->assert_is_op_input(op_name(),
+                                        x_name());
 
   conv_output->assert_is_op_input(op_name(),
-                                  x_name());
+                                  y_name());
 
   auto out_var = pattern->new_node(out_name())
                      ->AsOutput()
                      ->assert_is_op_output(op_name(),
                                            out_name());
 
-  elementwise_add_op->LinksFrom({y_var, conv_output});
+  elementwise_add_op->LinksFrom({x_var, conv_output});
   elementwise_add_op->LinksTo({out_var});
 
   return out_var;
@@ -139,13 +139,14 @@ graph_ptr ConvElementwiseAddMKLDNNFusePass::ApplyImpl(graph_ptr graph) const {
   conv_output->AsIntermediate();
 
-  auto fuse_conv = [](Graph* g, Node* conv_input, Node* conv_filter, Node* y) {
+  auto fuse_conv = [](Graph* g, Node* conv_input, Node* conv_filter, Node* conv_output, Node* elementwise_add_x) {
     OpDesc op_desc;
     op_desc.SetType("conv2d");
 
     op_desc.SetInput("Input", {conv_input->Name()});
     op_desc.SetInput("Filter", {conv_filter->Name()});
-    op_desc.SetOutput("Output", {y->Name()});
+    op_desc.SetInput("ElementwiseParameter", {elementwise_add_x->Name()});
+    op_desc.SetOutput("Output", {conv_output->Name()});
 
     op_desc.SetAttr("use_mkldnn", true);
     op_desc.SetAttr("fuse_eltwise", true);
@@ -154,7 +155,7 @@ graph_ptr ConvElementwiseAddMKLDNNFusePass::ApplyImpl(graph_ptr graph) const {
     patterns::LinkNodes(conv_input, fused_conv_op);
     patterns::LinkNodes(conv_filter, fused_conv_op);
-    patterns::LinkNodes(fused_conv_op, y);
+    patterns::LinkNodes(fused_conv_op, conv_output);
   };
 
   auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, Graph* g) {
@@ -169,14 +170,14 @@ graph_ptr ConvElementwiseAddMKLDNNFusePass::ApplyImpl(graph_ptr graph) const {
     auto elementwise_add_op = patterns::GetNodeFromSubgraph(subgraph, pattern_ptr,
                                                             elementwise_add_pattern.op_name());
-    auto elementwise_add_y = patterns::GetNodeFromSubgraph(subgraph, pattern_ptr,
-                                                           elementwise_add_pattern.y_name());
+    auto elementwise_add_x = patterns::GetNodeFromSubgraph(subgraph, pattern_ptr,
+                                                           elementwise_add_pattern.x_name());
     auto elementwise_add_out = patterns::GetNodeFromSubgraph(subgraph, pattern_ptr,
                                                              elementwise_add_pattern.out_name());
 
-    fuse_conv(g, conv_input, conv_filter, elementwise_add_y);
-    patterns::CorrectGraphEdges(g, elementwise_add_out, elementwise_add_y);
-    GraphSafeRemoveNodes(g, {conv_output, elementwise_add_out, conv_op, elementwise_add_op});
+    fuse_conv(g, conv_input, conv_filter, conv_output, elementwise_add_x);
+    patterns::CorrectGraphEdges(g, elementwise_add_out, conv_output);
+    GraphSafeRemoveNodes(g, {elementwise_add_out, conv_op, elementwise_add_op});
   };
 
   gpd(graph.get(), handler);
...
@@ -8,6 +8,9 @@ namespace paddle {
 namespace framework {
 namespace ir {
 
+constexpr int nodes_removed = 3;
+constexpr int nodes_added = 1;
+
 void SetOp(ProgramDesc* prog, const std::string& type,
            const std::vector<std::string>& inputs,
            const std::vector<std::string>& outputs) {
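The two constants reflect the bookkeeping of the handler above: each fusion removes three nodes (the original conv2d op, the elementwise_add op, and the elementwise_add output variable) and adds one node (the fused conv2d op). A self-contained restatement, purely for illustration and not part of the test file:

constexpr int nodes_removed = 3;  // conv_op, elementwise_add_op, elementwise_add_out
constexpr int nodes_added = 1;    // the newly created fused conv2d op
// Every test below therefore expects:
//   original_nodes_num - nodes_removed + nodes_added == current_nodes_num
static_assert(nodes_removed - nodes_added == 2,
              "each fusion shrinks the graph by two nodes");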
@@ -93,7 +96,7 @@ TEST(ConvElementwiseAddMKLDNNFusePass, ConvolutionWithElementwiseAddRelu) {
     }
 
     SetOp(&prog, "conv2d", {"a", "weights"}, {"b"});
-    SetOp(&prog, "elementwise_add", {"b", "c"}, {"d"});
+    SetOp(&prog, "elementwise_add", {"c", "b"}, {"d"});
     SetOp(&prog, "relu", {"d"}, {"e"});
 
     return prog;
@@ -113,7 +116,7 @@ TEST(ConvElementwiseAddMKLDNNFusePass, ConvolutionWithElementwiseAddRelu) {
   EXPECT_TRUE(is_reachable(graph)("a", "relu"));
 
-  EXPECT_EQ(original_nodes_num - 4 + 1, current_nodes_num);
+  EXPECT_EQ(original_nodes_num - nodes_removed + nodes_added, current_nodes_num);
 
   // Assert conv_relu op in newly generated graph
   int conv_count = 0;
   int elementwise_add_count = 0;
@@ -143,7 +146,7 @@ TEST(ConvElementwiseAddMKLDNNFusePass, ConvolutionElementwiseAdd) {
     }
 
     SetOp(&prog, "conv2d", {"a", "weights"}, {"b"});
-    SetOp(&prog, "elementwise_add", {"b", "c"}, {"d"});
+    SetOp(&prog, "elementwise_add", {"c", "b"}, {"d"});
 
     return prog;
   };
@@ -161,7 +164,7 @@ TEST(ConvElementwiseAddMKLDNNFusePass, ConvolutionElementwiseAdd) {
   EXPECT_FALSE(is_reachable(graph)("a", "d"));
 
-  EXPECT_EQ(original_nodes_num - 4 + 1, current_nodes_num);
+  EXPECT_EQ(original_nodes_num - nodes_removed + nodes_added, current_nodes_num);
 
   // Assert conv_relu op in newly generated graph
   int conv_count = 0;
   int elementwise_add_count = 0;
@@ -192,7 +195,7 @@ TEST(ConvElementwiseAddMKLDNNFusePass, SigmoidConvolutionAddElementwiseRelu) {
     SetOp(&prog, "sigmoid", {"a"}, {"b"});
     SetOp(&prog, "conv2d", {"b", "weights"}, {"c"});
-    SetOp(&prog, "elementwise_add", {"c", "d"}, {"e"});
+    SetOp(&prog, "elementwise_add", {"d", "c"}, {"e"});
     SetOp(&prog, "relu", {"e"}, {"f"});
 
     return prog;
@@ -212,7 +215,7 @@ TEST(ConvElementwiseAddMKLDNNFusePass, SigmoidConvolutionAddElementwiseRelu) {
   EXPECT_TRUE(is_reachable(graph)("a", "f"));
 
-  EXPECT_EQ(original_nodes_num - 4 + 1, current_nodes_num);
+  EXPECT_EQ(original_nodes_num - nodes_removed + nodes_added, current_nodes_num);
 
   // Assert conv_relu op in newly generated graph
   int conv_count = 0;
   int elementwise_add_count = 0;
...