提交 dbc4fcd7 编写于 作者: T Tomasz Patejko

MKLDNN residual connections fuse pass: unit tests enabled and added

上级 42240893
...@@ -40,7 +40,7 @@ void SetOp(ProgramDesc* prog, const std::string& type, ...@@ -40,7 +40,7 @@ void SetOp(ProgramDesc* prog, const std::string& type,
op->SetOutput(output.first, {output.second}); op->SetOutput(output.first, {output.second});
} }
struct IsReachable { struct TestIsReachable {
using func = std::function<bool(const std::string&, const std::string&)>; using func = std::function<bool(const std::string&, const std::string&)>;
auto operator()(const std::unique_ptr<ir::Graph>& graph) -> func { auto operator()(const std::unique_ptr<ir::Graph>& graph) -> func {
...@@ -89,7 +89,9 @@ struct IsReachable { ...@@ -89,7 +89,9 @@ struct IsReachable {
} }
}; };
void AssertOpsCount(const std::unique_ptr<ir::Graph>& graph) { void AssertOpsCount(const std::unique_ptr<ir::Graph>& graph,
int expected_conv_count,
int expected_elementwise_add_count = 0) {
int conv_count = 0; int conv_count = 0;
int elementwise_add_count = 0; int elementwise_add_count = 0;
...@@ -101,8 +103,8 @@ void AssertOpsCount(const std::unique_ptr<ir::Graph>& graph) { ...@@ -101,8 +103,8 @@ void AssertOpsCount(const std::unique_ptr<ir::Graph>& graph) {
++elementwise_add_count; ++elementwise_add_count;
} }
} }
EXPECT_EQ(conv_count, 1); EXPECT_EQ(conv_count, expected_conv_count);
EXPECT_EQ(elementwise_add_count, 0); EXPECT_EQ(elementwise_add_count, expected_elementwise_add_count);
} }
ProgramDesc BuildProgramDesc(const std::vector<std::string>& transient_vars, ProgramDesc BuildProgramDesc(const std::vector<std::string>& transient_vars,
...@@ -127,22 +129,13 @@ ProgramDesc BuildProgramDesc(const std::vector<std::string>& transient_vars, ...@@ -127,22 +129,13 @@ ProgramDesc BuildProgramDesc(const std::vector<std::string>& transient_vars,
return prog; return prog;
} }
} // namespace
TEST(ConvElementwiseAddMKLDNNFusePass, ConvolutionWithElementwiseAddRelu) {
auto prog =
BuildProgramDesc({"a", "b", "c", "d", "e", "f"}, {"bias", "weights"});
SetOp(&prog, "conv2d",
{{"Input", "a"}, {"Bias", "bias"}, {"Filter", "weights"}},
{"Output", "b"});
SetOp(&prog, "elementwise_add", {{"X", "b"}, {"Y", "c"}}, {"Out", "d"});
SetOp(&prog, "relu", {{"X", "d"}}, {"Out", "e"});
std::unique_ptr<ir::Graph> graph(new ir::Graph(prog)); void RunPassAndAssert(ProgramDesc* prog, const std::string& from,
const std::string& to, int expected_conv_num) {
std::unique_ptr<ir::Graph> graph(new ir::Graph(*prog));
IsReachable is_reachable; TestIsReachable is_reachable;
EXPECT_TRUE(is_reachable(graph)("a", "relu")); EXPECT_TRUE(is_reachable(graph)(from, to));
auto pass = auto pass =
PassRegistry::Instance().Get("conv_elementwise_add_mkldnn_fuse_pass"); PassRegistry::Instance().Get("conv_elementwise_add_mkldnn_fuse_pass");
...@@ -150,82 +143,87 @@ TEST(ConvElementwiseAddMKLDNNFusePass, ConvolutionWithElementwiseAddRelu) { ...@@ -150,82 +143,87 @@ TEST(ConvElementwiseAddMKLDNNFusePass, ConvolutionWithElementwiseAddRelu) {
graph = pass->Apply(std::move(graph)); graph = pass->Apply(std::move(graph));
int current_nodes_num = graph->Nodes().size(); int current_nodes_num = graph->Nodes().size();
EXPECT_TRUE(is_reachable(graph)("a", "relu")); EXPECT_TRUE(is_reachable(graph)(from, to));
EXPECT_EQ(original_nodes_num - nodes_removed + nodes_added, EXPECT_EQ(original_nodes_num - nodes_removed + nodes_added,
current_nodes_num); current_nodes_num);
AssertOpsCount(graph); AssertOpsCount(graph, expected_conv_num);
} }
} // namespace
TEST(ConvElementwiseAddMKLDNNFusePass, TEST(ConvElementwiseAddMKLDNNFusePass, ConvolutionAsYWithElementwiseAddRelu) {
ConvolutionWithElementwiseAddReluNoBias) { auto prog = BuildProgramDesc({"a", "b", "c", "d", "e"}, {"bias", "weights"});
auto prog = BuildProgramDesc({"a", "b", "c", "d", "e"}, {"weights"});
SetOp(&prog, "conv2d", {{"Input", "a"}, {"Filter", "weights"}},
{"Output", "b"});
SetOp(&prog, "elementwise_add", {{"X", "b"}, {"Y", "c"}}, {"Out", "d"});
SetOp(&prog, "relu", {{"X", "d"}}, {"Out", "e"});
std::unique_ptr<ir::Graph> graph(new ir::Graph(prog));
IsReachable is_reachable; SetOp(&prog, "sigmoid", {{"X", "a"}}, {"Out", "b"});
SetOp(&prog, "conv2d",
{{"Input", "b"}, {"Bias", "bias"}, {"Filter", "weights"}},
{"Output", "c"});
EXPECT_TRUE(is_reachable(graph)("a", "relu")); SetOp(&prog, "elementwise_add", {{"X", "a"}, {"Y", "c"}}, {"Out", "d"});
SetOp(&prog, "relu", {{"X", "d"}}, {"Out", "e"});
auto pass = RunPassAndAssert(&prog, "a", "relu", 1);
PassRegistry::Instance().Get("conv_elementwise_add_mkldnn_fuse_pass"); }
int original_nodes_num = graph->Nodes().size();
graph = pass->Apply(std::move(graph));
int current_nodes_num = graph->Nodes().size();
EXPECT_TRUE(is_reachable(graph)("a", "relu")); TEST(ConvElementwiseAddMKLDNNFusePass,
ConvolutionAsYWithElementwiseAddReluNoBias) {
auto prog = BuildProgramDesc({"a", "b", "c", "d", "e"}, {"weights"});
EXPECT_EQ(original_nodes_num - nodes_removed + nodes_added, SetOp(&prog, "sigmoid", {{"X", "a"}}, {"Out", "b"});
current_nodes_num); SetOp(&prog, "conv2d", {{"Input", "b"}, {"Filter", "weights"}},
{"Output", "c"});
SetOp(&prog, "elementwise_add", {{"X", "a"}, {"Y", "c"}}, {"Out", "d"});
SetOp(&prog, "relu", {{"X", "d"}}, {"Out", "e"});
AssertOpsCount(graph); RunPassAndAssert(&prog, "a", "relu", 1);
} }
TEST(ConvElementwiseAddMKLDNNFusePass, ConvolutionElementwiseAdd) { TEST(ConvElementwiseAddMKLDNNFusePass, ConvolutionAsXWithElementwiseAddRelu) {
auto prog = BuildProgramDesc({"a", "b", "c", "d"}, {"bias", "weights"}); auto prog = BuildProgramDesc({"a", "b", "c", "d", "e"}, {"bias", "weights"});
SetOp(&prog, "sigmoid", {{"X", "a"}}, {"Out", "b"});
SetOp(&prog, "conv2d", SetOp(&prog, "conv2d",
{{"Input", "a"}, {"Bias", "bias"}, {"Filter", "weights"}}, {{"Input", "b"}, {"Bias", "bias"}, {"Filter", "weights"}},
{"Output", "b"}); {"Output", "c"});
SetOp(&prog, "elementwise_add", {{"X", "b"}, {"Y", "c"}}, {"Out", "d"});
std::unique_ptr<ir::Graph> graph(new ir::Graph(prog)); SetOp(&prog, "elementwise_add", {{"X", "c"}, {"Y", "a"}}, {"Out", "d"});
SetOp(&prog, "relu", {{"X", "d"}}, {"Out", "e"});
IsReachable is_reachable; RunPassAndAssert(&prog, "a", "relu", 1);
EXPECT_TRUE(is_reachable(graph)("a", "d")); }
auto pass = TEST(ConvElementwiseAddMKLDNNFusePass,
PassRegistry::Instance().Get("conv_elementwise_add_mkldnn_fuse_pass"); ConvolutionAsXWithElementwiseAddReluNoBias) {
int original_nodes_num = graph->Nodes().size(); auto prog = BuildProgramDesc({"a", "b", "c", "d", "e"}, {"weights"});
graph = pass->Apply(std::move(graph));
int current_nodes_num = graph->Nodes().size();
EXPECT_FALSE(is_reachable(graph)("a", "d")); SetOp(&prog, "sigmoid", {{"X", "a"}}, {"Out", "b"});
SetOp(&prog, "conv2d", {{"Input", "b"}, {"Filter", "weights"}},
{"Output", "c"});
SetOp(&prog, "elementwise_add", {{"X", "c"}, {"Y", "a"}}, {"Out", "d"});
SetOp(&prog, "relu", {{"X", "d"}}, {"Out", "e"});
EXPECT_EQ(original_nodes_num - nodes_removed + nodes_added, RunPassAndAssert(&prog, "a", "relu", 1);
current_nodes_num);
AssertOpsCount(graph);
} }
TEST(ConvElementwiseAddMKLDNNFusePass, SigmoidConvolutionAddElementwiseRelu) { TEST(ConvElementwiseAddMKLDNNFusePass, NoFusion) {
auto prog = auto prog =
BuildProgramDesc({"a", "b", "c", "d", "e", "f"}, {"bias", "weights"}); BuildProgramDesc({"a", "b", "c", "d", "e", "f", "g"}, {"weights"});
SetOp(&prog, "sigmoid", {{"X", "a"}}, {"Out", "b"}); SetOp(&prog, "sigmoid", {{"X", "a"}}, {"Out", "b"});
SetOp(&prog, "conv2d", SetOp(&prog, "conv2d", {{"Input", "b"}, {"Filter", "weights"}},
{{"Input", "b"}, {"Bias", "bias"}, {"Filter", "weights"}},
{"Output", "c"}); {"Output", "c"});
SetOp(&prog, "elementwise_add", {{"X", "c"}, {"Y", "d"}}, {"Out", "e"});
SetOp(&prog, "relu", {{"X", "e"}}, {"Out", "f"});
std::unique_ptr<ir::Graph> graph(new ir::Graph(prog)); SetOp(&prog, "conv2d", {{"Input", "d"}, {"Filter", "weights"}},
{"Output", "e"});
IsReachable is_reachable; SetOp(&prog, "elementwise_add", {{"X", "c"}, {"Y", "e"}}, {"Out", "f"});
SetOp(&prog, "relu", {{"X", "f"}}, {"Out", "g"});
EXPECT_TRUE(is_reachable(graph)("a", "f")); std::unique_ptr<ir::Graph> graph(new ir::Graph(prog));
TestIsReachable is_reachable;
EXPECT_TRUE(is_reachable(graph)("a", "g"));
auto pass = auto pass =
PassRegistry::Instance().Get("conv_elementwise_add_mkldnn_fuse_pass"); PassRegistry::Instance().Get("conv_elementwise_add_mkldnn_fuse_pass");
...@@ -233,11 +231,10 @@ TEST(ConvElementwiseAddMKLDNNFusePass, SigmoidConvolutionAddElementwiseRelu) { ...@@ -233,11 +231,10 @@ TEST(ConvElementwiseAddMKLDNNFusePass, SigmoidConvolutionAddElementwiseRelu) {
graph = pass->Apply(std::move(graph)); graph = pass->Apply(std::move(graph));
int current_nodes_num = graph->Nodes().size(); int current_nodes_num = graph->Nodes().size();
EXPECT_TRUE(is_reachable(graph)("a", "f")); EXPECT_TRUE(is_reachable(graph)("a", "g"));
EXPECT_EQ(original_nodes_num, current_nodes_num);
EXPECT_EQ(original_nodes_num - nodes_removed + nodes_added, AssertOpsCount(graph, 2, 1);
current_nodes_num);
AssertOpsCount(graph);
} }
} // namespace ir } // namespace ir
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册