Commit 07a62ddc authored by Tomasz Patejko

MKLDNN conv + elementwise_add fusion: inputs in pass modified. Support for new conv parameter. UTs corrected.

Parent 56528531
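In this revision the matcher binds the convolution output to elementwise_add's Y input, and the pattern's X variable becomes the tensor handed to the fused convolution as the new "ElementwiseParameter" input (hence the swapped operand order, e.g. {"c", "b"}, in the updated unit tests). For orientation only, the self-contained sketch below models the op description that fuse_conv builds after this change; ToyOpDesc and the tensor names are stand-ins rather than PaddlePaddle's framework::OpDesc API, while the string keys and boolean attributes are taken directly from the diff.

#include <iostream>
#include <map>
#include <string>
#include <vector>

// ToyOpDesc is a stand-in for framework::OpDesc, used only to show the shape
// of the fused operator that replaces the conv2d + elementwise_add pair.
struct ToyOpDesc {
  std::string type;
  std::map<std::string, std::vector<std::string>> inputs;
  std::map<std::string, std::vector<std::string>> outputs;
  std::map<std::string, bool> attrs;
};

int main() {
  ToyOpDesc fused_conv;
  fused_conv.type = "conv2d";
  fused_conv.inputs["Input"] = {"conv_input"};    // original conv input
  fused_conv.inputs["Filter"] = {"conv_filter"};  // convolution weights
  // New in this commit: elementwise_add's X operand travels with the conv op
  // so the fused MKL-DNN kernel can add it to the convolution result.
  fused_conv.inputs["ElementwiseParameter"] = {"elementwise_add_x"};
  fused_conv.outputs["Output"] = {"conv_output"};
  fused_conv.attrs["use_mkldnn"] = true;
  fused_conv.attrs["fuse_eltwise"] = true;

  for (const auto& in : fused_conv.inputs) {
    std::cout << in.first << " <- " << in.second.front() << "\n";
  }
  return 0;
}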
@@ -73,19 +73,19 @@ struct ElementwiseAdd {
     auto elementwise_add_op = pattern->new_node(op_name())
                                   ->assert_is_op("elementwise_add");
-    auto y_var = pattern->new_node(y_name())
+    auto x_var = pattern->new_node(x_name())
                      ->assert_is_op_input(op_name(),
-                                          y_name());
+                                          x_name());
     conv_output->assert_is_op_input(op_name(),
-                                    x_name());
+                                    y_name());
     auto out_var = pattern->new_node(out_name())
                        ->AsOutput()
                        ->assert_is_op_output(op_name(),
                                              out_name());
-    elementwise_add_op->LinksFrom({y_var, conv_output});
+    elementwise_add_op->LinksFrom({x_var, conv_output});
     elementwise_add_op->LinksTo({out_var});
     return out_var;
@@ -139,13 +139,14 @@ graph_ptr ConvElementwiseAddMKLDNNFusePass::ApplyImpl(graph_ptr graph) const {
   conv_output->AsIntermediate();
-  auto fuse_conv = [](Graph* g, Node* conv_input, Node* conv_filter, Node* y) {
+  auto fuse_conv = [](Graph* g, Node* conv_input, Node* conv_filter, Node* conv_output, Node* elementwise_add_x) {
     OpDesc op_desc;
     op_desc.SetType("conv2d");
     op_desc.SetInput("Input", {conv_input->Name()});
     op_desc.SetInput("Filter", {conv_filter->Name()});
-    op_desc.SetOutput("Output", {y->Name()});
+    op_desc.SetInput("ElementwiseParameter", {elementwise_add_x->Name()});
+    op_desc.SetOutput("Output", {conv_output->Name()});
     op_desc.SetAttr("use_mkldnn", true);
     op_desc.SetAttr("fuse_eltwise", true);
@@ -154,7 +155,7 @@ graph_ptr ConvElementwiseAddMKLDNNFusePass::ApplyImpl(graph_ptr graph) const {
     patterns::LinkNodes(conv_input, fused_conv_op);
     patterns::LinkNodes(conv_filter, fused_conv_op);
-    patterns::LinkNodes(fused_conv_op, y);
+    patterns::LinkNodes(fused_conv_op, conv_output);
   };
   auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, Graph* g) {
@@ -169,14 +170,14 @@ graph_ptr ConvElementwiseAddMKLDNNFusePass::ApplyImpl(graph_ptr graph) const {
     auto elementwise_add_op = patterns::GetNodeFromSubgraph(subgraph, pattern_ptr,
                                                             elementwise_add_pattern.op_name());
-    auto elementwise_add_y = patterns::GetNodeFromSubgraph(subgraph, pattern_ptr,
-                                                           elementwise_add_pattern.y_name());
+    auto elementwise_add_x = patterns::GetNodeFromSubgraph(subgraph, pattern_ptr,
+                                                           elementwise_add_pattern.x_name());
     auto elementwise_add_out = patterns::GetNodeFromSubgraph(subgraph, pattern_ptr,
                                                              elementwise_add_pattern.out_name());
-    fuse_conv(g, conv_input, conv_filter, elementwise_add_y);
-    patterns::CorrectGraphEdges(g, elementwise_add_out, elementwise_add_y);
-    GraphSafeRemoveNodes(g, {conv_output, elementwise_add_out, conv_op, elementwise_add_op});
+    fuse_conv(g, conv_input, conv_filter, conv_output, elementwise_add_x);
+    patterns::CorrectGraphEdges(g, elementwise_add_out, conv_output);
+    GraphSafeRemoveNodes(g, {elementwise_add_out, conv_op, elementwise_add_op});
   };
   gpd(graph.get(), handler);
......
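Taken together, the handler above swaps the matched subgraph for a single fused node: fuse_conv creates the new conv2d, patterns::CorrectGraphEdges points every consumer of elementwise_add_out at conv_output, and GraphSafeRemoveNodes drops the original conv op, the elementwise_add op, and elementwise_add_out; that is three nodes removed and one added, which appears to be what the nodes_removed/nodes_added constants introduced in the test file below encode. The self-contained sketch here uses toy types, not PaddlePaddle's graph classes, and illustrates only the edge-correction idea.

#include <iostream>
#include <string>
#include <vector>

// Toy operator node: a name plus the variable names it reads.
struct ToyOp {
  std::string name;
  std::vector<std::string> inputs;
};

// Rough analogue of the edge correction done by CorrectGraphEdges: every op
// that consumed `from` (the removed elementwise_add output) is rewired to
// consume `to` (the fused convolution's output) instead.
void RewireConsumers(std::vector<ToyOp>* ops, const std::string& from,
                     const std::string& to) {
  for (auto& op : *ops) {
    for (auto& input : op.inputs) {
      if (input == from) input = to;
    }
  }
}

int main() {
  // Before fusion, relu reads the elementwise_add output "d"; after fusion it
  // reads the fused conv output "b" (names follow the first unit test below).
  std::vector<ToyOp> ops = {{"relu", {"d"}}};
  RewireConsumers(&ops, "d", "b");
  std::cout << ops[0].name << " now reads " << ops[0].inputs[0] << "\n";
  return 0;
}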
@@ -8,6 +8,9 @@ namespace paddle {
 namespace framework {
 namespace ir {
+constexpr int nodes_removed = 3;
+constexpr int nodes_added = 1;
 void SetOp(ProgramDesc* prog, const std::string& type,
            const std::vector<std::string>& inputs,
            const std::vector<std::string>& outputs) {
@@ -93,7 +96,7 @@ TEST(ConvElementwiseAddMKLDNNFusePass, ConvolutionWithElementwiseAddRelu) {
   }
   SetOp(&prog, "conv2d", {"a", "weights"}, {"b"});
-  SetOp(&prog, "elementwise_add", {"b", "c"}, {"d"});
+  SetOp(&prog, "elementwise_add", {"c", "b"}, {"d"});
   SetOp(&prog, "relu", {"d"}, {"e"});
   return prog;
@@ -113,7 +116,7 @@ TEST(ConvElementwiseAddMKLDNNFusePass, ConvolutionWithElementwiseAddRelu) {
   EXPECT_TRUE(is_reachable(graph)("a", "relu"));
-  EXPECT_EQ(original_nodes_num - 4 + 1, current_nodes_num);
+  EXPECT_EQ(original_nodes_num - nodes_removed + nodes_added, current_nodes_num);
   // Assert conv_relu op in newly generated graph
   int conv_count = 0;
   int elementwise_add_count = 0;
@@ -143,7 +146,7 @@ TEST(ConvElementwiseAddMKLDNNFusePass, ConvolutionElementwiseAdd) {
   }
   SetOp(&prog, "conv2d", {"a", "weights"}, {"b"});
-  SetOp(&prog, "elementwise_add", {"b", "c"}, {"d"});
+  SetOp(&prog, "elementwise_add", {"c", "b"}, {"d"});
   return prog;
 };
@@ -161,7 +164,7 @@ TEST(ConvElementwiseAddMKLDNNFusePass, ConvolutionElementwiseAdd) {
   EXPECT_FALSE(is_reachable(graph)("a", "d"));
-  EXPECT_EQ(original_nodes_num - 4 + 1, current_nodes_num);
+  EXPECT_EQ(original_nodes_num - nodes_removed + nodes_added, current_nodes_num);
   // Assert conv_relu op in newly generated graph
   int conv_count = 0;
   int elementwise_add_count = 0;
@@ -192,7 +195,7 @@ TEST(ConvElementwiseAddMKLDNNFusePass, SigmoidConvolutionAddElementwiseRelu) {
   SetOp(&prog, "sigmoid", {"a"}, {"b"});
   SetOp(&prog, "conv2d", {"b", "weights"}, {"c"});
-  SetOp(&prog, "elementwise_add", {"c", "d"}, {"e"});
+  SetOp(&prog, "elementwise_add", {"d", "c"}, {"e"});
   SetOp(&prog, "relu", {"e"}, {"f"});
   return prog;
EXPECT_TRUE(is_reachable(graph)("a", "f"));
EXPECT_EQ(original_nodes_num - 4 + 1, current_nodes_num);
EXPECT_EQ(original_nodes_num - nodes_removed + nodes_added, current_nodes_num);
// Assert conv_relu op in newly generated graph
int conv_count = 0;
int elementwise_add_count = 0;
......