提交 d3a66473 编写于 作者: W Wojciech Uss 提交者: Tao Luo

improve placement pass tests code coverage (#22197)

上级 f5262865
......@@ -27,11 +27,11 @@ namespace ir {
*/
class CUDNNPlacementPass : public PlacementPassBase {
private:
const std::string GetPlacementName() const { return "cuDNN"; }
const std::string GetPlacementName() const override { return "cuDNN"; }
const std::string GetAttrName() const { return "use_cudnn"; }
const std::string GetAttrName() const override { return "use_cudnn"; }
const std::unordered_set<std::string> GetOpTypesList() const {
const std::unordered_set<std::string> GetOpTypesList() const override {
return Get<std::unordered_set<std::string>>("cudnn_enabled_op_types");
}
};
......
......@@ -22,7 +22,9 @@ namespace paddle {
namespace framework {
namespace ir {
void RegisterOpKernel() {
class PlacementPassTest {
private:
void RegisterOpKernel() {
static bool is_registered = false;
if (!is_registered) {
auto& all_kernels = OperatorWithKernel::AllOpKernels();
......@@ -47,9 +49,10 @@ void RegisterOpKernel() {
is_registered = true;
}
}
}
void MainTest(std::initializer_list<std::string> cudnn_enabled_op_types,
public:
void MainTest(std::initializer_list<std::string> cudnn_enabled_op_types,
unsigned expected_use_cudnn_true_count) {
// operator use_cudnn
// --------------------------------------------------
......@@ -94,22 +97,33 @@ void MainTest(std::initializer_list<std::string> cudnn_enabled_op_types,
}
EXPECT_EQ(use_cudnn_true_count, expected_use_cudnn_true_count);
}
}
void PlacementNameTest() {
auto pass = PassRegistry::Instance().Get("cudnn_placement_pass");
EXPECT_EQ(static_cast<PlacementPassBase*>(pass.get())->GetPlacementName(),
"cuDNN");
}
};
TEST(CUDNNPlacementPass, enable_conv2d) {
// 1 conv2d
MainTest({"conv2d"}, 1);
PlacementPassTest().MainTest({"conv2d"}, 1);
}
TEST(CUDNNPlacementPass, enable_relu_pool) {
// 1 conv2d + 1 pool2d
MainTest({"conv2d", "pool2d"}, 2);
PlacementPassTest().MainTest({"conv2d", "pool2d"}, 2);
}
TEST(CUDNNPlacementPass, enable_all) {
// 1 conv2d + 1 pool2d
// depthwise_conv2d doesnot have CUDNN kernel.
MainTest({}, 2);
PlacementPassTest().MainTest({}, 2);
}
TEST(CUDNNPlacementPass, placement_name) {
PlacementPassTest().PlacementNameTest();
}
} // namespace ir
......
......@@ -27,11 +27,11 @@ namespace ir {
*/
class MKLDNNPlacementPass : public PlacementPassBase {
private:
const std::string GetPlacementName() const { return "MKLDNN"; }
const std::string GetPlacementName() const override { return "MKLDNN"; }
const std::string GetAttrName() const { return "use_mkldnn"; }
const std::string GetAttrName() const override { return "use_mkldnn"; }
const std::unordered_set<std::string> GetOpTypesList() const {
const std::unordered_set<std::string> GetOpTypesList() const override {
return Get<std::unordered_set<std::string>>("mkldnn_enabled_op_types");
}
};
......
......@@ -21,14 +21,18 @@ namespace paddle {
namespace framework {
namespace ir {
void SetOp(ProgramDesc* prog, const std::string& type, const std::string& name,
const std::vector<std::string>& inputs,
const std::vector<std::string>& outputs, boost::tribool use_mkldnn) {
class PlacementPassTest {
private:
void SetOp(ProgramDesc* prog, const std::string& type,
const std::string& name, const std::vector<std::string>& inputs,
const std::vector<std::string>& outputs,
boost::tribool use_mkldnn) {
auto* op = prog->MutableBlock(0)->AppendOp();
op->SetType(type);
if (!boost::indeterminate(use_mkldnn)) op->SetAttr("use_mkldnn", use_mkldnn);
if (!boost::indeterminate(use_mkldnn))
op->SetAttr("use_mkldnn", use_mkldnn);
if (type == "conv2d") {
op->SetAttr("name", name);
......@@ -46,17 +50,17 @@ void SetOp(ProgramDesc* prog, const std::string& type, const std::string& name,
FAIL() << "Unexpected operator type.";
}
op->SetOutput("Out", {outputs[0]});
}
}
// operator use_mkldnn
// ---------------------------------------
// (a,b)->concat->c none
// (c,weights,bias)->conv->f none
// f->relu->g false
// g->pool->h false
// (h,weights2,bias2)->conv->k true
// k->relu->l true
ProgramDesc BuildProgramDesc() {
// operator use_mkldnn
// ---------------------------------------
// (a,b)->concat->c none
// (c,weights,bias)->conv->f none
// f->relu->g false
// g->pool->h false
// (h,weights2,bias2)->conv->k true
// k->relu->l true
ProgramDesc BuildProgramDesc() {
ProgramDesc prog;
for (auto& v :
......@@ -85,15 +89,17 @@ ProgramDesc BuildProgramDesc() {
std::vector<std::string>({"l"}), true);
return prog;
}
}
void MainTest(std::initializer_list<std::string> mkldnn_enabled_op_types,
public:
void MainTest(std::initializer_list<std::string> mkldnn_enabled_op_types,
unsigned expected_use_mkldnn_true_count) {
auto prog = BuildProgramDesc();
std::unique_ptr<ir::Graph> graph(new ir::Graph(prog));
auto pass = PassRegistry::Instance().Get("mkldnn_placement_pass");
pass->Set("mkldnn_enabled_op_types",
new std::unordered_set<std::string>(mkldnn_enabled_op_types));
......@@ -112,21 +118,32 @@ void MainTest(std::initializer_list<std::string> mkldnn_enabled_op_types,
}
EXPECT_EQ(use_mkldnn_true_count, expected_use_mkldnn_true_count);
}
}
void PlacementNameTest() {
auto pass = PassRegistry::Instance().Get("mkldnn_placement_pass");
EXPECT_EQ(static_cast<PlacementPassBase*>(pass.get())->GetPlacementName(),
"MKLDNN");
}
};
TEST(MKLDNNPlacementPass, enable_conv_relu) {
// 1 conv (1 conv is always true) + 2 relu (1 relu is always true) + 0 pool
MainTest({"conv2d", "relu"}, 3);
PlacementPassTest().MainTest({"conv2d", "relu"}, 3);
}
TEST(MKLDNNPlacementPass, enable_relu_pool) {
// 1 conv (1 conv is always true) + 2 relu (1 relu is always true) + 1 pool
MainTest({"relu", "pool2d"}, 4);
PlacementPassTest().MainTest({"relu", "pool2d"}, 4);
}
TEST(MKLDNNPlacementPass, enable_all) {
// 1 conv (1 conv is always true) + 2 relu (1 relu is always true) + 1 pool
MainTest({}, 4);
PlacementPassTest().MainTest({}, 4);
}
TEST(MKLDNNPlacementPass, placement_name) {
PlacementPassTest().PlacementNameTest();
}
} // namespace ir
......
......@@ -35,6 +35,10 @@ class PlacementPassBase : public Pass {
private:
bool IsSupport(const std::string& op_type) const;
#if PADDLE_WITH_TESTING
friend class PlacementPassTest;
#endif
};
} // namespace ir
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册