提交 d3a66473 编写于 作者: W Wojciech Uss 提交者: Tao Luo

improve placement pass tests code coverage (#22197)

上级 f5262865
...@@ -27,11 +27,11 @@ namespace ir { ...@@ -27,11 +27,11 @@ namespace ir {
*/ */
class CUDNNPlacementPass : public PlacementPassBase { class CUDNNPlacementPass : public PlacementPassBase {
private: private:
const std::string GetPlacementName() const { return "cuDNN"; } const std::string GetPlacementName() const override { return "cuDNN"; }
const std::string GetAttrName() const { return "use_cudnn"; } const std::string GetAttrName() const override { return "use_cudnn"; }
const std::unordered_set<std::string> GetOpTypesList() const { const std::unordered_set<std::string> GetOpTypesList() const override {
return Get<std::unordered_set<std::string>>("cudnn_enabled_op_types"); return Get<std::unordered_set<std::string>>("cudnn_enabled_op_types");
} }
}; };
......
...@@ -22,7 +22,9 @@ namespace paddle { ...@@ -22,7 +22,9 @@ namespace paddle {
namespace framework { namespace framework {
namespace ir { namespace ir {
void RegisterOpKernel() { class PlacementPassTest {
private:
void RegisterOpKernel() {
static bool is_registered = false; static bool is_registered = false;
if (!is_registered) { if (!is_registered) {
auto& all_kernels = OperatorWithKernel::AllOpKernels(); auto& all_kernels = OperatorWithKernel::AllOpKernels();
...@@ -47,9 +49,10 @@ void RegisterOpKernel() { ...@@ -47,9 +49,10 @@ void RegisterOpKernel() {
is_registered = true; is_registered = true;
} }
} }
void MainTest(std::initializer_list<std::string> cudnn_enabled_op_types, public:
void MainTest(std::initializer_list<std::string> cudnn_enabled_op_types,
unsigned expected_use_cudnn_true_count) { unsigned expected_use_cudnn_true_count) {
// operator use_cudnn // operator use_cudnn
// -------------------------------------------------- // --------------------------------------------------
...@@ -94,22 +97,33 @@ void MainTest(std::initializer_list<std::string> cudnn_enabled_op_types, ...@@ -94,22 +97,33 @@ void MainTest(std::initializer_list<std::string> cudnn_enabled_op_types,
} }
EXPECT_EQ(use_cudnn_true_count, expected_use_cudnn_true_count); EXPECT_EQ(use_cudnn_true_count, expected_use_cudnn_true_count);
} }
void PlacementNameTest() {
auto pass = PassRegistry::Instance().Get("cudnn_placement_pass");
EXPECT_EQ(static_cast<PlacementPassBase*>(pass.get())->GetPlacementName(),
"cuDNN");
}
};
TEST(CUDNNPlacementPass, enable_conv2d) { TEST(CUDNNPlacementPass, enable_conv2d) {
// 1 conv2d // 1 conv2d
MainTest({"conv2d"}, 1); PlacementPassTest().MainTest({"conv2d"}, 1);
} }
TEST(CUDNNPlacementPass, enable_relu_pool) { TEST(CUDNNPlacementPass, enable_relu_pool) {
// 1 conv2d + 1 pool2d // 1 conv2d + 1 pool2d
MainTest({"conv2d", "pool2d"}, 2); PlacementPassTest().MainTest({"conv2d", "pool2d"}, 2);
} }
TEST(CUDNNPlacementPass, enable_all) { TEST(CUDNNPlacementPass, enable_all) {
// 1 conv2d + 1 pool2d // 1 conv2d + 1 pool2d
// depthwise_conv2d doesnot have CUDNN kernel. // depthwise_conv2d doesnot have CUDNN kernel.
MainTest({}, 2); PlacementPassTest().MainTest({}, 2);
}
TEST(CUDNNPlacementPass, placement_name) {
PlacementPassTest().PlacementNameTest();
} }
} // namespace ir } // namespace ir
......
...@@ -27,11 +27,11 @@ namespace ir { ...@@ -27,11 +27,11 @@ namespace ir {
*/ */
class MKLDNNPlacementPass : public PlacementPassBase { class MKLDNNPlacementPass : public PlacementPassBase {
private: private:
const std::string GetPlacementName() const { return "MKLDNN"; } const std::string GetPlacementName() const override { return "MKLDNN"; }
const std::string GetAttrName() const { return "use_mkldnn"; } const std::string GetAttrName() const override { return "use_mkldnn"; }
const std::unordered_set<std::string> GetOpTypesList() const { const std::unordered_set<std::string> GetOpTypesList() const override {
return Get<std::unordered_set<std::string>>("mkldnn_enabled_op_types"); return Get<std::unordered_set<std::string>>("mkldnn_enabled_op_types");
} }
}; };
......
...@@ -21,14 +21,18 @@ namespace paddle { ...@@ -21,14 +21,18 @@ namespace paddle {
namespace framework { namespace framework {
namespace ir { namespace ir {
void SetOp(ProgramDesc* prog, const std::string& type, const std::string& name, class PlacementPassTest {
const std::vector<std::string>& inputs, private:
const std::vector<std::string>& outputs, boost::tribool use_mkldnn) { void SetOp(ProgramDesc* prog, const std::string& type,
const std::string& name, const std::vector<std::string>& inputs,
const std::vector<std::string>& outputs,
boost::tribool use_mkldnn) {
auto* op = prog->MutableBlock(0)->AppendOp(); auto* op = prog->MutableBlock(0)->AppendOp();
op->SetType(type); op->SetType(type);
if (!boost::indeterminate(use_mkldnn)) op->SetAttr("use_mkldnn", use_mkldnn); if (!boost::indeterminate(use_mkldnn))
op->SetAttr("use_mkldnn", use_mkldnn);
if (type == "conv2d") { if (type == "conv2d") {
op->SetAttr("name", name); op->SetAttr("name", name);
...@@ -46,17 +50,17 @@ void SetOp(ProgramDesc* prog, const std::string& type, const std::string& name, ...@@ -46,17 +50,17 @@ void SetOp(ProgramDesc* prog, const std::string& type, const std::string& name,
FAIL() << "Unexpected operator type."; FAIL() << "Unexpected operator type.";
} }
op->SetOutput("Out", {outputs[0]}); op->SetOutput("Out", {outputs[0]});
} }
// operator use_mkldnn // operator use_mkldnn
// --------------------------------------- // ---------------------------------------
// (a,b)->concat->c none // (a,b)->concat->c none
// (c,weights,bias)->conv->f none // (c,weights,bias)->conv->f none
// f->relu->g false // f->relu->g false
// g->pool->h false // g->pool->h false
// (h,weights2,bias2)->conv->k true // (h,weights2,bias2)->conv->k true
// k->relu->l true // k->relu->l true
ProgramDesc BuildProgramDesc() { ProgramDesc BuildProgramDesc() {
ProgramDesc prog; ProgramDesc prog;
for (auto& v : for (auto& v :
...@@ -85,15 +89,17 @@ ProgramDesc BuildProgramDesc() { ...@@ -85,15 +89,17 @@ ProgramDesc BuildProgramDesc() {
std::vector<std::string>({"l"}), true); std::vector<std::string>({"l"}), true);
return prog; return prog;
} }
void MainTest(std::initializer_list<std::string> mkldnn_enabled_op_types, public:
void MainTest(std::initializer_list<std::string> mkldnn_enabled_op_types,
unsigned expected_use_mkldnn_true_count) { unsigned expected_use_mkldnn_true_count) {
auto prog = BuildProgramDesc(); auto prog = BuildProgramDesc();
std::unique_ptr<ir::Graph> graph(new ir::Graph(prog)); std::unique_ptr<ir::Graph> graph(new ir::Graph(prog));
auto pass = PassRegistry::Instance().Get("mkldnn_placement_pass"); auto pass = PassRegistry::Instance().Get("mkldnn_placement_pass");
pass->Set("mkldnn_enabled_op_types", pass->Set("mkldnn_enabled_op_types",
new std::unordered_set<std::string>(mkldnn_enabled_op_types)); new std::unordered_set<std::string>(mkldnn_enabled_op_types));
...@@ -112,21 +118,32 @@ void MainTest(std::initializer_list<std::string> mkldnn_enabled_op_types, ...@@ -112,21 +118,32 @@ void MainTest(std::initializer_list<std::string> mkldnn_enabled_op_types,
} }
EXPECT_EQ(use_mkldnn_true_count, expected_use_mkldnn_true_count); EXPECT_EQ(use_mkldnn_true_count, expected_use_mkldnn_true_count);
} }
void PlacementNameTest() {
auto pass = PassRegistry::Instance().Get("mkldnn_placement_pass");
EXPECT_EQ(static_cast<PlacementPassBase*>(pass.get())->GetPlacementName(),
"MKLDNN");
}
};
TEST(MKLDNNPlacementPass, enable_conv_relu) { TEST(MKLDNNPlacementPass, enable_conv_relu) {
// 1 conv (1 conv is always true) + 2 relu (1 relu is always true) + 0 pool // 1 conv (1 conv is always true) + 2 relu (1 relu is always true) + 0 pool
MainTest({"conv2d", "relu"}, 3); PlacementPassTest().MainTest({"conv2d", "relu"}, 3);
} }
TEST(MKLDNNPlacementPass, enable_relu_pool) { TEST(MKLDNNPlacementPass, enable_relu_pool) {
// 1 conv (1 conv is always true) + 2 relu (1 relu is always true) + 1 pool // 1 conv (1 conv is always true) + 2 relu (1 relu is always true) + 1 pool
MainTest({"relu", "pool2d"}, 4); PlacementPassTest().MainTest({"relu", "pool2d"}, 4);
} }
TEST(MKLDNNPlacementPass, enable_all) { TEST(MKLDNNPlacementPass, enable_all) {
// 1 conv (1 conv is always true) + 2 relu (1 relu is always true) + 1 pool // 1 conv (1 conv is always true) + 2 relu (1 relu is always true) + 1 pool
MainTest({}, 4); PlacementPassTest().MainTest({}, 4);
}
TEST(MKLDNNPlacementPass, placement_name) {
PlacementPassTest().PlacementNameTest();
} }
} // namespace ir } // namespace ir
......
...@@ -35,6 +35,10 @@ class PlacementPassBase : public Pass { ...@@ -35,6 +35,10 @@ class PlacementPassBase : public Pass {
private: private:
bool IsSupport(const std::string& op_type) const; bool IsSupport(const std::string& op_type) const;
#if PADDLE_WITH_TESTING
friend class PlacementPassTest;
#endif
}; };
} // namespace ir } // namespace ir
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册