Commit 16eaaf3f authored by Tomasz Patejko

MKLDNN conv + elementwise_add fusion: added one more UT, found and corrected bugs in pass

Parent 604bad08
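The pass rewrites the subgraph `Output = conv2d(Input, Filter)` followed by `Out = elementwise_add(X = Output, Y)` into a single MKLDNN convolution that accumulates into `Y`'s buffer (hence the `fuse_sum` attribute set below, which corresponds to MKLDNN's sum post-op). A minimal standalone sketch of why that rewrite is numerically sound — a toy scalar "convolution" for illustration only, not Paddle or MKLDNN code:

```cpp
#include <cassert>
#include <vector>

// Toy "convolution": scalar multiply, standing in for conv2d.
std::vector<float> conv(const std::vector<float>& x, float w) {
  std::vector<float> y(x.size());
  for (size_t i = 0; i < x.size(); ++i) y[i] = w * x[i];
  return y;
}

// Fused form: accumulate conv results on top of the residual data,
// the way a sum post-op accumulates into its destination memory.
void conv_fuse_sum(const std::vector<float>& x, float w,
                   std::vector<float>* dst) {
  for (size_t i = 0; i < x.size(); ++i) (*dst)[i] += w * x[i];
}

int main() {
  std::vector<float> x{1, 2, 3}, z{10, 20, 30};

  // Unfused: conv, then a separate elementwise add.
  auto y = conv(x, 2.0f);
  std::vector<float> unfused(x.size());
  for (size_t i = 0; i < x.size(); ++i) unfused[i] = y[i] + z[i];

  // Fused: the residual input becomes the destination buffer.
  std::vector<float> fused = z;
  conv_fuse_sum(x, 2.0f, &fused);

  assert(fused == unfused);  // one pass, no separate add op, same result
}
```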
@@ -45,17 +45,13 @@ struct Conv {
                        ->assert_is_op("conv2d");
     auto input_var = pattern->new_node(input_name())
                          ->AsInput()
-                         ->assert_is_op_input(conv_name());
+                         ->assert_is_op_input(conv_name(), input_name());
     auto filter_var = pattern->new_node(filter_name())
                           ->AsInput()
                           ->assert_is_persistable_var()
-                          ->assert_is_op_input(conv_name());
+                          ->assert_is_op_input(conv_name(), filter_name());
     auto output_var = pattern->new_node(output_name())
                           ->AsOutput()
-                          ->assert_is_op_output(conv_name());
+                          ->assert_is_op_output(conv_name(), output_name());
     conv_op->LinksFrom({input_var, filter_var});
     conv_op->LinksTo({output_var});
@@ -77,19 +73,13 @@ struct ElementwiseAdd {
                                   ->assert_is_op("elementwise_add");
     auto y_var = pattern->new_node(y_name())
                      ->AsInput()
-                     ->assert_is_op_input(elementwise_add_name());
+                     ->assert_is_op_input(elementwise_add_name(), y_name());
-    conv_output->assert_is_op_input(pattern->node_name(elementwise_add_name()),
-                                    pattern->node_name(x_name()));
-    //  auto y_var = pattern->NewNode(y_name())
-    //                   ->AsInput()
-    //                   ->assert_is_op_input(elementwise_add_name());
+    conv_output->assert_is_op_input(elementwise_add_name(), x_name());
     auto out_var = pattern->new_node(out_name())
                        ->AsOutput()
-                       ->assert_is_op_output(
-                           pattern->node_name(elementwise_add_name()));
+                       ->assert_is_op_output(elementwise_add_name(), out_name());
     elementwise_add_op->LinksFrom({y_var, conv_output});
     elementwise_add_op->LinksTo({out_var});
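Both hunks above fix the same class of bug: the one-argument `assert_is_op_input(op)` form only requires that a variable feeds *some* input of the op, while the two-argument form also pins down *which* formal argument it binds to. The pattern needs the stronger check so that the convolution's output is matched specifically as elementwise_add's `X` input and the residual tensor as `Y`. A standalone sketch of the distinction, with toy types for illustration (not Paddle's actual `PDNode` API):

```cpp
#include <cassert>
#include <map>
#include <string>
#include <vector>

// Toy stand-in for an op with named input slots.
struct Op {
  std::string type;
  // Maps a formal argument name ("X", "Y", "Input", ...) to variable names.
  std::map<std::string, std::vector<std::string>> inputs;
};

// One-argument check: the variable feeds *some* input slot of the op.
bool is_op_input(const Op& op, const std::string& var) {
  for (const auto& slot : op.inputs)
    for (const auto& v : slot.second)
      if (v == var) return true;
  return false;
}

// Two-argument check: the variable feeds the *named* input slot.
bool is_op_input(const Op& op, const std::string& var,
                 const std::string& arg) {
  auto it = op.inputs.find(arg);
  if (it == op.inputs.end()) return false;
  for (const auto& v : it->second)
    if (v == var) return true;
  return false;
}

int main() {
  Op add{"elementwise_add", {{"X", {"conv_out"}}, {"Y", {"residual"}}}};
  assert(is_op_input(add, "residual"));        // true: feeds some slot
  assert(!is_op_input(add, "residual", "X"));  // false: wrong slot
  assert(is_op_input(add, "conv_out", "X"));   // conv output must be X
}
```

With only the weaker check, a graph where the convolution's output arrived as `Y` (and the residual as `X`) would also match, and the fusion would rewire the wrong operands.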
@@ -118,16 +108,16 @@ graph_ptr ConvElementwiseAddMKLDNNFusePass::ApplyImpl(graph_ptr graph) const {
   GraphPatternDetector gpd;
   auto pattern = gpd.mutable_pattern();
   auto pattern_ptr = std::make_shared<patterns::Pattern>(pattern, name_scope_);
 
   patterns::Conv conv_pattern;
   auto conv_output = conv_pattern(pattern_ptr)();
-  conv_output->AsIntermediate();
 
   patterns::ElementwiseAdd elementwise_add_pattern;
   elementwise_add_pattern(pattern_ptr)(conv_output);
+  conv_output->AsIntermediate();
 
   auto link_nodes_to = [](Node* a, Node* b) {
     a->outputs.push_back(b);
     b->inputs.push_back(a);
@@ -139,7 +129,7 @@ graph_ptr ConvElementwiseAddMKLDNNFusePass::ApplyImpl(graph_ptr graph) const {
     op_desc.SetInput("Input", {conv_input->Name()});
     op_desc.SetInput("Filter", {conv_filter->Name()});
-    op_desc.SetOutput("Ouput", {y->Name()});
+    op_desc.SetOutput("Output", {y->Name()});
 
     op_desc.SetAttr("fuse_sum", true);
@@ -155,16 +145,17 @@ graph_ptr ConvElementwiseAddMKLDNNFusePass::ApplyImpl(graph_ptr graph) const {
   };
 
   auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, Graph* g) {
-    auto elementwise_add_x = node_from_subgraph(subgraph, pattern_ptr, elementwise_add_pattern.x_name());
-    auto elementwise_add_y = node_from_subgraph(subgraph, pattern_ptr, elementwise_add_pattern.y_name());
-    auto elementwise_add_out = node_from_subgraph(subgraph, pattern_ptr, elementwise_add_pattern.out_name());
-    auto conv_filter = node_from_subgraph(subgraph, pattern_ptr, conv_pattern.filter_name());
+    auto conv_op = node_from_subgraph(subgraph, pattern_ptr, conv_pattern.conv_name());
     auto conv_input = node_from_subgraph(subgraph, pattern_ptr, conv_pattern.input_name());
+    auto conv_filter = node_from_subgraph(subgraph, pattern_ptr, conv_pattern.filter_name());
     auto conv_output = node_from_subgraph(subgraph, pattern_ptr, conv_pattern.output_name());
+    auto elementwise_add_op = node_from_subgraph(subgraph, pattern_ptr, elementwise_add_pattern.elementwise_add_name());
+    auto elementwise_add_y = node_from_subgraph(subgraph, pattern_ptr, elementwise_add_pattern.y_name());
+    auto elementwise_add_out = node_from_subgraph(subgraph, pattern_ptr, elementwise_add_pattern.out_name());
 
     fuse_conv(g, conv_input, conv_filter, elementwise_add_y);
-    remove_unused_nodes(g, {elementwise_add_x, conv_output, elementwise_add_out});
+    remove_unused_nodes(g, {conv_output, elementwise_add_out, conv_op, elementwise_add_op});
   };
 
   gpd(graph.get(), handler);
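The reworked handler also fixes *which* nodes are deleted after fusion: the two original op nodes (`conv_op`, `elementwise_add_op`) must go, while `elementwise_add_x` appears to be the same node as `conv_output` (the pattern binds the conv output to `X`), so listing both was redundant. The surrounding control flow is the usual detector/callback idiom: the detector owns the graph traversal and invokes the handler once per matched subgraph. A toy sketch of that idiom with hypothetical types (not the real `GraphPatternDetector`):

```cpp
#include <functional>
#include <iostream>
#include <vector>

// A matched subgraph, reduced to two op ids for illustration.
struct Match {
  int conv_op;
  int elementwise_add_op;
};

// Toy detector: pretends pattern matching found two occurrences.
struct Detector {
  std::vector<Match> matches{{0, 1}, {2, 3}};
  void operator()(const std::function<void(const Match&)>& handler) const {
    for (const auto& m : matches) handler(m);
  }
};

int main() {
  Detector gpd;
  gpd([](const Match& m) {
    // Here the real pass builds the fused op and removes the matched nodes.
    std::cout << "fuse conv op " << m.conv_op << " with add op "
              << m.elementwise_add_op << "\n";
  });
}
```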
......
@@ -16,7 +16,7 @@ void SetOp(ProgramDesc* prog, const std::string& type,
     op->SetAttr("use_mkldnn", true);
     op->SetInput("Input", {inputs[0]});
     op->SetInput("Filter", {inputs[1]});
-    op->SetInput("Output", {outputs});
+    op->SetOutput("Output", outputs);
   } else if (type == "elementwise_add") {
     op->SetInput("X", {inputs[0]});
     op->SetInput("Y", {inputs[1]});
@@ -24,10 +24,11 @@ void SetOp(ProgramDesc* prog, const std::string& type,
   }
 }
 
-ProgramDesc BuildProgramDesc() {
+TEST(ConvElementwiseAddMKLDNNFusePass, ConvolutionWithElementwiseAddWithOps) {
+  auto build_program_desc = [&]() -> ProgramDesc {
     ProgramDesc prog;
     for (auto& v :
-         std::vector<std::string>({"a", "b", "c", "d", "weights", "f", "g"})) {
+         std::vector<std::string>({"a", "b", "weights", "c", "d", "e", "f", "g"})) {
       auto* var = prog.MutableBlock(0)->Var(v);
       var->SetType(proto::VarType::LOD_TENSOR);
       if (v == "weights" || v == "bias") {
@@ -37,41 +38,105 @@ ProgramDesc BuildProgramDesc() {
     SetOp(&prog, "OP0", {"a"}, {"b"});
     SetOp(&prog, "OP1", {"c"}, {"d"});
-    SetOp(&prog, "conv2d", {"d", "weights"}, {"f"});
-    SetOp(&prog, "elemenwise_add", {"d", "f"}, {"g"});
+    SetOp(&prog, "conv2d", {"b", "weights"}, {"e"});
+    SetOp(&prog, "elementwise_add", {"e", "d"}, {"f"});
+    SetOp(&prog, "OP3", {"f"}, {"g"});
 
     return prog;
+  };
+
+  auto prog = build_program_desc();
+  std::unique_ptr<ir::Graph> graph(new ir::Graph(prog));
+  auto pass = PassRegistry::Instance().Get("conv_elementwise_add_mkldnn_fuse_pass");
+
+  int original_nodes_num = graph->Nodes().size();
+  graph = pass->Apply(std::move(graph));
+  int current_nodes_num = graph->Nodes().size();
+
+  EXPECT_EQ(original_nodes_num - 4 + 1, current_nodes_num);
+
+  // Assert the fused conv op in the newly generated graph
+  int conv_count = 0;
+  int elementwise_add_count = 0;
+
+  for (auto* node : graph->Nodes()) {
+    if (node->IsOp() && node->Op()->Type() == "conv2d") {
+      ++conv_count;
+    }
+    if (node->IsOp() && node->Op()->Type() == "elementwise_add") {
+      ++elementwise_add_count;
+    }
+    /*
+    if (node->Op()->HasAttr("use_mkldnn")) {
+      bool use_mkldnn = boost::get<bool>(node->Op()->GetAttr("use_mkldnn"));
+      if (use_mkldnn) {
+        if (node->Op()->HasAttr("fuse_sum")) {
+          // bool fuse_sum = boost::get<bool>(node->Op()->GetAttr("fuse_sum"));
+          if (fuse_sum) {
+            ++conv_elementwise_add_count;
+          }
+        }
+      }
+    }
+    */
+  }
+  EXPECT_EQ(conv_count, 1);
+  EXPECT_EQ(elementwise_add_count, 0);
+}
-TEST(ConvElementwiseAddMKLDNNFusePass, basic) {
-  auto prog = BuildProgramDesc();
+TEST(ConvElementwiseAddMKLDNNFusePass, OnlyConvolutionElementwiseAdd) {
+  auto build_program_desc = [&]() -> ProgramDesc {
+    ProgramDesc prog;
+    for (auto& v : std::vector<std::string>({"a", "b", "weights"})) {
+      auto* var = prog.MutableBlock(0)->Var(v);
+      var->SetType(proto::VarType::LOD_TENSOR);
+      if (v == "weights" || v == "bias") {
+        var->SetPersistable(true);
+      }
+    }
+
+    SetOp(&prog, "conv2d", {"a", "weights"}, {"b"});
+    SetOp(&prog, "elementwise_add", {"b", "c"}, {"d"});
+
+    return prog;
+  };
+
+  auto prog = build_program_desc();
   std::unique_ptr<ir::Graph> graph(new ir::Graph(prog));
   auto pass = PassRegistry::Instance().Get("conv_elementwise_add_mkldnn_fuse_pass");
   int original_nodes_num = graph->Nodes().size();
   graph = pass->Apply(std::move(graph));
   int current_nodes_num = graph->Nodes().size();
 
-  EXPECT_EQ(original_nodes_num - 2, current_nodes_num);
+  EXPECT_EQ(original_nodes_num - 4 + 1, current_nodes_num);
 
-  // Assert conv_relu op in newly generated graph
-  int conv_elementwise_add_count = 0;
+  // Assert the fused conv op in the newly generated graph
+  int conv_count = 0;
+  int elementwise_add_count = 0;
 
   for (auto* node : graph->Nodes()) {
     if (node->IsOp() && node->Op()->Type() == "conv2d") {
+      ++conv_count;
+    }
+    if (node->IsOp() && node->Op()->Type() == "elementwise_add") {
+      ++elementwise_add_count;
+    }
+    /*
       if (node->Op()->HasAttr("use_mkldnn")) {
         bool use_mkldnn = boost::get<bool>(node->Op()->GetAttr("use_mkldnn"));
         if (use_mkldnn) {
-          // TODO tpatejko: it is commented because convolution does not support this attribute
-          if (true/*node->Op()->HasAttr("fuse_sum")*/) {
+          if (node->Op()->HasAttr("fuse_sum")) {
            // bool fuse_sum = boost::get<bool>(node->Op()->GetAttr("fuse_sum"));
-            if (true /*fuse_sum*/) {
+            if (fuse_sum) {
              ++conv_elementwise_add_count;
            }
          }
        }
      }
+    */
   }
-  EXPECT_EQ(conv_elementwise_add_count, 1);
+  EXPECT_EQ(conv_count, 1);
+  EXPECT_EQ(elementwise_add_count, 0);
 }
 }  // namespace ir
......
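A note on the node arithmetic asserted in both tests: per match, the handler removes four graph nodes — the `conv2d` op, the `elementwise_add` op, the intermediate convolution output variable, and the `elementwise_add` output variable, exactly the set passed to `remove_unused_nodes` — and `fuse_conv` inserts one fused `conv2d` op node, hence `current_nodes_num = original_nodes_num - 4 + 1`. The old expectation `original_nodes_num - 2` reflected the earlier, buggy removal set, which did not delete the matched op nodes.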