Commit d3a66473, authored by Wojciech Uss, committed by Tao Luo

improve placement pass tests code coverage (#22197)

Parent f5262865
CUDNNPlacementPass header:

@@ -27,11 +27,11 @@ namespace ir {
  */
 class CUDNNPlacementPass : public PlacementPassBase {
  private:
-  const std::string GetPlacementName() const { return "cuDNN"; }
-  const std::string GetAttrName() const { return "use_cudnn"; }
-  const std::unordered_set<std::string> GetOpTypesList() const {
+  const std::string GetPlacementName() const override { return "cuDNN"; }
+  const std::string GetAttrName() const override { return "use_cudnn"; }
+  const std::unordered_set<std::string> GetOpTypesList() const override {
     return Get<std::unordered_set<std::string>>("cudnn_enabled_op_types");
   }
 };
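Note: the only change in this hunk is adding `override` to the private virtuals, and the benefit is worth making concrete. Without `override`, a signature mismatch silently declares a new method that hides, rather than overrides, the base-class virtual; with it, the mismatch fails to compile. A standalone sketch in plain C++ (not Paddle code; `Base`, `WithOverride`, and `SilentHider` are illustrative names):

#include <cassert>
#include <string>

struct Base {
  virtual ~Base() = default;
  virtual const std::string GetPlacementName() const { return "base"; }
};

struct WithOverride : Base {
  // Signature matches exactly; `override` makes the compiler verify that.
  const std::string GetPlacementName() const override { return "cuDNN"; }
};

struct SilentHider : Base {
  // Missing `const`: this declares a brand-new method that hides, rather than
  // overrides, Base::GetPlacementName. Writing `override` here would turn the
  // silent bug into a compile-time error.
  const std::string GetPlacementName() { return "never reached via Base*"; }
};

int main() {
  WithOverride w;
  SilentHider s;
  const Base* pw = &w;
  const Base* ps = &s;
  assert(pw->GetPlacementName() == "cuDNN");  // dispatches to the override
  assert(ps->GetPlacementName() == "base");   // the hider is never dispatched
  return 0;
}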
CUDNN placement pass tester:

@@ -22,94 +22,108 @@ namespace paddle {
 namespace framework {
 namespace ir {

-void RegisterOpKernel() {
-  static bool is_registered = false;
-  if (!is_registered) {
-    auto& all_kernels = OperatorWithKernel::AllOpKernels();
-
-    platform::CUDAPlace place = platform::CUDAPlace(0);
-    OpKernelType plain_kernel_type =
-        OpKernelType(proto::VarType::FP32, place, DataLayout::kAnyLayout,
-                     LibraryType::kPlain);
-    OpKernelType cudnn_kernel_type =
-        OpKernelType(proto::VarType::FP32, place, DataLayout::kAnyLayout,
-                     LibraryType::kCUDNN);
-
-    auto fake_kernel_func = [](const ExecutionContext&) -> void {
-      static int num_calls = 0;
-      num_calls++;
-    };
-
-    all_kernels["conv2d"][cudnn_kernel_type] = fake_kernel_func;
-    all_kernels["pool2d"][cudnn_kernel_type] = fake_kernel_func;
-    all_kernels["depthwise_conv2d"][plain_kernel_type] = fake_kernel_func;
-    all_kernels["relu"][plain_kernel_type] = fake_kernel_func;
-
-    is_registered = true;
-  }
-}
-
-void MainTest(std::initializer_list<std::string> cudnn_enabled_op_types,
-              unsigned expected_use_cudnn_true_count) {
-  // operator                                 use_cudnn
-  // --------------------------------------------------
-  // (a,b)->concat->c                         -
-  // (c,weights,bias)->conv2d->f              false
-  // f->relu->g                               -
-  // g->pool2d->h                             false
-  // (h,weights2,bias2)->depthwise_conv2d->k  false
-  // k->relu->l                               -
-  Layers layers;
-  VarDesc* a = layers.data("a");
-  VarDesc* b = layers.data("b");
-  VarDesc* c = layers.concat(std::vector<VarDesc*>({a, b}));
-  VarDesc* weights_0 = layers.data("weights_0");
-  VarDesc* bias_0 = layers.data("bias_0");
-  VarDesc* f = layers.conv2d(c, weights_0, bias_0, false);
-  VarDesc* g = layers.relu(f);
-  VarDesc* h = layers.pool2d(g, false);
-  VarDesc* weights_1 = layers.data("weights_1");
-  VarDesc* bias_1 = layers.data("bias_1");
-  VarDesc* k = layers.depthwise_conv2d(h, weights_1, bias_1, false);
-  layers.relu(k);
-
-  RegisterOpKernel();
-
-  std::unique_ptr<ir::Graph> graph(new ir::Graph(layers.main_program()));
-  auto pass = PassRegistry::Instance().Get("cudnn_placement_pass");
-  pass->Set("cudnn_enabled_op_types",
-            new std::unordered_set<std::string>(cudnn_enabled_op_types));
-
-  graph.reset(pass->Apply(graph.release()));
-
-  unsigned use_cudnn_true_count = 0;
-  for (auto* node : graph->Nodes()) {
-    if (node->IsOp() && node->Op()) {
-      auto* op = node->Op();
-      if (op->HasAttr("use_cudnn") &&
-          boost::get<bool>(op->GetAttr("use_cudnn"))) {
-        ++use_cudnn_true_count;
-      }
-    }
-  }
-  EXPECT_EQ(use_cudnn_true_count, expected_use_cudnn_true_count);
-}
+class PlacementPassTest {
+ private:
+  void RegisterOpKernel() {
+    static bool is_registered = false;
+    if (!is_registered) {
+      auto& all_kernels = OperatorWithKernel::AllOpKernels();
+
+      platform::CUDAPlace place = platform::CUDAPlace(0);
+      OpKernelType plain_kernel_type =
+          OpKernelType(proto::VarType::FP32, place, DataLayout::kAnyLayout,
+                       LibraryType::kPlain);
+      OpKernelType cudnn_kernel_type =
+          OpKernelType(proto::VarType::FP32, place, DataLayout::kAnyLayout,
+                       LibraryType::kCUDNN);
+
+      auto fake_kernel_func = [](const ExecutionContext&) -> void {
+        static int num_calls = 0;
+        num_calls++;
+      };
+
+      all_kernels["conv2d"][cudnn_kernel_type] = fake_kernel_func;
+      all_kernels["pool2d"][cudnn_kernel_type] = fake_kernel_func;
+      all_kernels["depthwise_conv2d"][plain_kernel_type] = fake_kernel_func;
+      all_kernels["relu"][plain_kernel_type] = fake_kernel_func;
+
+      is_registered = true;
+    }
+  }
+
+ public:
+  void MainTest(std::initializer_list<std::string> cudnn_enabled_op_types,
+                unsigned expected_use_cudnn_true_count) {
+    // operator                                 use_cudnn
+    // --------------------------------------------------
+    // (a,b)->concat->c                         -
+    // (c,weights,bias)->conv2d->f              false
+    // f->relu->g                               -
+    // g->pool2d->h                             false
+    // (h,weights2,bias2)->depthwise_conv2d->k  false
+    // k->relu->l                               -
+    Layers layers;
+    VarDesc* a = layers.data("a");
+    VarDesc* b = layers.data("b");
+    VarDesc* c = layers.concat(std::vector<VarDesc*>({a, b}));
+    VarDesc* weights_0 = layers.data("weights_0");
+    VarDesc* bias_0 = layers.data("bias_0");
+    VarDesc* f = layers.conv2d(c, weights_0, bias_0, false);
+    VarDesc* g = layers.relu(f);
+    VarDesc* h = layers.pool2d(g, false);
+    VarDesc* weights_1 = layers.data("weights_1");
+    VarDesc* bias_1 = layers.data("bias_1");
+    VarDesc* k = layers.depthwise_conv2d(h, weights_1, bias_1, false);
+    layers.relu(k);
+
+    RegisterOpKernel();
+
+    std::unique_ptr<ir::Graph> graph(new ir::Graph(layers.main_program()));
+    auto pass = PassRegistry::Instance().Get("cudnn_placement_pass");
+    pass->Set("cudnn_enabled_op_types",
+              new std::unordered_set<std::string>(cudnn_enabled_op_types));
+
+    graph.reset(pass->Apply(graph.release()));
+
+    unsigned use_cudnn_true_count = 0;
+    for (auto* node : graph->Nodes()) {
+      if (node->IsOp() && node->Op()) {
+        auto* op = node->Op();
+        if (op->HasAttr("use_cudnn") &&
+            boost::get<bool>(op->GetAttr("use_cudnn"))) {
+          ++use_cudnn_true_count;
+        }
+      }
+    }
+    EXPECT_EQ(use_cudnn_true_count, expected_use_cudnn_true_count);
+  }
+
+  void PlacementNameTest() {
+    auto pass = PassRegistry::Instance().Get("cudnn_placement_pass");
+    EXPECT_EQ(static_cast<PlacementPassBase*>(pass.get())->GetPlacementName(),
+              "cuDNN");
+  }
+};

 TEST(CUDNNPlacementPass, enable_conv2d) {
   // 1 conv2d
-  MainTest({"conv2d"}, 1);
+  PlacementPassTest().MainTest({"conv2d"}, 1);
 }

 TEST(CUDNNPlacementPass, enable_relu_pool) {
   // 1 conv2d + 1 pool2d
-  MainTest({"conv2d", "pool2d"}, 2);
+  PlacementPassTest().MainTest({"conv2d", "pool2d"}, 2);
 }

 TEST(CUDNNPlacementPass, enable_all) {
   // 1 conv2d + 1 pool2d
   // depthwise_conv2d does not have a CUDNN kernel.
-  MainTest({}, 2);
+  PlacementPassTest().MainTest({}, 2);
+}
+
+TEST(CUDNNPlacementPass, placement_name) {
+  PlacementPassTest().PlacementNameTest();
 }

 }  // namespace ir
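Note: the expected counts encode the rule the pass applies: an op is flipped to use_cudnn=true only when its type appears in `cudnn_enabled_op_types` (an empty set meaning "all supported types") and a cuDNN kernel is actually registered for it, which is why `relu` and `depthwise_conv2d` are never counted. A minimal distillation of that rule as inferred from the expectations above; the names `ShouldUseCudnn` and `cudnn_kernels` are hypothetical, since the real decision is made inside PlacementPassBase, outside this diff:

#include <cassert>
#include <string>
#include <unordered_set>

// Hypothetical distillation of the placement rule exercised by MainTest.
bool ShouldUseCudnn(const std::string& op_type,
                    const std::unordered_set<std::string>& enabled_op_types,
                    const std::unordered_set<std::string>& cudnn_kernels) {
  // An empty enabled set means "enable every op type the pass supports".
  bool enabled = enabled_op_types.empty() || enabled_op_types.count(op_type) > 0;
  // Placement only makes sense if a cuDNN kernel exists for the op.
  return enabled && cudnn_kernels.count(op_type) > 0;
}

int main() {
  // Kernels registered by RegisterOpKernel above: only conv2d and pool2d
  // receive cuDNN kernels; depthwise_conv2d and relu receive plain ones.
  std::unordered_set<std::string> cudnn_kernels{"conv2d", "pool2d"};

  // Mirrors TEST(CUDNNPlacementPass, enable_conv2d): expected count 1.
  assert(ShouldUseCudnn("conv2d", {"conv2d"}, cudnn_kernels));
  assert(!ShouldUseCudnn("pool2d", {"conv2d"}, cudnn_kernels));

  // Mirrors TEST(CUDNNPlacementPass, enable_all): the empty set enables all,
  // but depthwise_conv2d still fails the kernel check, so the count stays 2.
  assert(ShouldUseCudnn("pool2d", {}, cudnn_kernels));
  assert(!ShouldUseCudnn("depthwise_conv2d", {}, cudnn_kernels));
  return 0;
}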
MKLDNNPlacementPass header:

@@ -27,11 +27,11 @@ namespace ir {
  */
 class MKLDNNPlacementPass : public PlacementPassBase {
  private:
-  const std::string GetPlacementName() const { return "MKLDNN"; }
-  const std::string GetAttrName() const { return "use_mkldnn"; }
-  const std::unordered_set<std::string> GetOpTypesList() const {
+  const std::string GetPlacementName() const override { return "MKLDNN"; }
+  const std::string GetAttrName() const override { return "use_mkldnn"; }
+  const std::unordered_set<std::string> GetOpTypesList() const override {
     return Get<std::unordered_set<std::string>>("mkldnn_enabled_op_types");
   }
 };
MKLDNN placement pass tester:

@@ -21,112 +21,129 @@ namespace paddle {
 namespace framework {
 namespace ir {

-void SetOp(ProgramDesc* prog, const std::string& type, const std::string& name,
-           const std::vector<std::string>& inputs,
-           const std::vector<std::string>& outputs, boost::tribool use_mkldnn) {
-  auto* op = prog->MutableBlock(0)->AppendOp();
-
-  op->SetType(type);
-
-  if (!boost::indeterminate(use_mkldnn)) op->SetAttr("use_mkldnn", use_mkldnn);
-
-  if (type == "conv2d") {
-    op->SetAttr("name", name);
-    op->SetInput("Input", {inputs[0]});
-    op->SetInput("Filter", {inputs[1]});
-    op->SetInput("Bias", {inputs[2]});
-  } else if (type == "relu") {
-    op->SetInput("X", inputs);
-  } else if (type == "concat") {
-    op->SetAttr("axis", 1);
-    op->SetInput("X", {inputs[0], inputs[1]});
-  } else if (type == "pool2d") {
-    op->SetInput("X", {inputs[0]});
-  } else {
-    FAIL() << "Unexpected operator type.";
-  }
-
-  op->SetOutput("Out", {outputs[0]});
-}
-
-// operator                      use_mkldnn
-// ---------------------------------------
-// (a,b)->concat->c              none
-// (c,weights,bias)->conv->f     none
-// f->relu->g                    false
-// g->pool->h                    false
-// (h,weights2,bias2)->conv->k   true
-// k->relu->l                    true
-ProgramDesc BuildProgramDesc() {
-  ProgramDesc prog;
-
-  for (auto& v :
-       std::vector<std::string>({"a", "b", "c", "weights", "bias", "f", "g",
-                                 "h", "weights2", "bias2", "k", "l"})) {
-    auto* var = prog.MutableBlock(0)->Var(v);
-    var->SetType(proto::VarType::SELECTED_ROWS);
-    if (v == "weights" || v == "bias") {
-      var->SetPersistable(true);
-    }
-  }
-
-  SetOp(&prog, "concat", "concat1", std::vector<std::string>({"a", "b"}),
-        std::vector<std::string>({"c"}), boost::indeterminate);
-  SetOp(&prog, "conv2d", "conv1",
-        std::vector<std::string>({"c", "weights", "bias"}),
-        std::vector<std::string>({"f"}), boost::indeterminate);
-  SetOp(&prog, "relu", "relu1", std::vector<std::string>({"f"}),
-        std::vector<std::string>({"g"}), false);
-  SetOp(&prog, "pool2d", "pool1", std::vector<std::string>({"g"}),
-        std::vector<std::string>({"h"}), false);
-  SetOp(&prog, "conv2d", "conv2",
-        std::vector<std::string>({"h", "weights2", "bias2"}),
-        std::vector<std::string>({"k"}), true);
-  SetOp(&prog, "relu", "relu2", std::vector<std::string>({"k"}),
-        std::vector<std::string>({"l"}), true);
-
-  return prog;
-}
-
-void MainTest(std::initializer_list<std::string> mkldnn_enabled_op_types,
-              unsigned expected_use_mkldnn_true_count) {
-  auto prog = BuildProgramDesc();
-
-  std::unique_ptr<ir::Graph> graph(new ir::Graph(prog));
-
-  auto pass = PassRegistry::Instance().Get("mkldnn_placement_pass");
-  pass->Set("mkldnn_enabled_op_types",
-            new std::unordered_set<std::string>(mkldnn_enabled_op_types));
-
-  graph.reset(pass->Apply(graph.release()));
-
-  unsigned use_mkldnn_true_count = 0;
-  for (auto* node : graph->Nodes()) {
-    if (node->IsOp()) {
-      auto* op = node->Op();
-      if (op->HasAttr("use_mkldnn") &&
-          boost::get<bool>(op->GetAttr("use_mkldnn"))) {
-        ++use_mkldnn_true_count;
-      }
-    }
-  }
-  EXPECT_EQ(use_mkldnn_true_count, expected_use_mkldnn_true_count);
-}
+class PlacementPassTest {
+ private:
+  void SetOp(ProgramDesc* prog, const std::string& type,
+             const std::string& name, const std::vector<std::string>& inputs,
+             const std::vector<std::string>& outputs,
+             boost::tribool use_mkldnn) {
+    auto* op = prog->MutableBlock(0)->AppendOp();
+
+    op->SetType(type);
+
+    if (!boost::indeterminate(use_mkldnn))
+      op->SetAttr("use_mkldnn", use_mkldnn);
+
+    if (type == "conv2d") {
+      op->SetAttr("name", name);
+      op->SetInput("Input", {inputs[0]});
+      op->SetInput("Filter", {inputs[1]});
+      op->SetInput("Bias", {inputs[2]});
+    } else if (type == "relu") {
+      op->SetInput("X", inputs);
+    } else if (type == "concat") {
+      op->SetAttr("axis", 1);
+      op->SetInput("X", {inputs[0], inputs[1]});
+    } else if (type == "pool2d") {
+      op->SetInput("X", {inputs[0]});
+    } else {
+      FAIL() << "Unexpected operator type.";
+    }
+
+    op->SetOutput("Out", {outputs[0]});
+  }
+
+  // operator                      use_mkldnn
+  // ---------------------------------------
+  // (a,b)->concat->c              none
+  // (c,weights,bias)->conv->f     none
+  // f->relu->g                    false
+  // g->pool->h                    false
+  // (h,weights2,bias2)->conv->k   true
+  // k->relu->l                    true
+  ProgramDesc BuildProgramDesc() {
+    ProgramDesc prog;
+
+    for (auto& v :
+         std::vector<std::string>({"a", "b", "c", "weights", "bias", "f", "g",
+                                   "h", "weights2", "bias2", "k", "l"})) {
+      auto* var = prog.MutableBlock(0)->Var(v);
+      var->SetType(proto::VarType::SELECTED_ROWS);
+      if (v == "weights" || v == "bias") {
+        var->SetPersistable(true);
+      }
+    }
+
+    SetOp(&prog, "concat", "concat1", std::vector<std::string>({"a", "b"}),
+          std::vector<std::string>({"c"}), boost::indeterminate);
+    SetOp(&prog, "conv2d", "conv1",
+          std::vector<std::string>({"c", "weights", "bias"}),
+          std::vector<std::string>({"f"}), boost::indeterminate);
+    SetOp(&prog, "relu", "relu1", std::vector<std::string>({"f"}),
+          std::vector<std::string>({"g"}), false);
+    SetOp(&prog, "pool2d", "pool1", std::vector<std::string>({"g"}),
+          std::vector<std::string>({"h"}), false);
+    SetOp(&prog, "conv2d", "conv2",
+          std::vector<std::string>({"h", "weights2", "bias2"}),
+          std::vector<std::string>({"k"}), true);
+    SetOp(&prog, "relu", "relu2", std::vector<std::string>({"k"}),
+          std::vector<std::string>({"l"}), true);
+
+    return prog;
+  }
+
+ public:
+  void MainTest(std::initializer_list<std::string> mkldnn_enabled_op_types,
+                unsigned expected_use_mkldnn_true_count) {
+    auto prog = BuildProgramDesc();
+
+    std::unique_ptr<ir::Graph> graph(new ir::Graph(prog));
+
+    auto pass = PassRegistry::Instance().Get("mkldnn_placement_pass");
+    pass->Set("mkldnn_enabled_op_types",
+              new std::unordered_set<std::string>(mkldnn_enabled_op_types));
+
+    graph.reset(pass->Apply(graph.release()));
+
+    unsigned use_mkldnn_true_count = 0;
+    for (auto* node : graph->Nodes()) {
+      if (node->IsOp()) {
+        auto* op = node->Op();
+        if (op->HasAttr("use_mkldnn") &&
+            boost::get<bool>(op->GetAttr("use_mkldnn"))) {
+          ++use_mkldnn_true_count;
+        }
+      }
+    }
+    EXPECT_EQ(use_mkldnn_true_count, expected_use_mkldnn_true_count);
+  }
+
+  void PlacementNameTest() {
+    auto pass = PassRegistry::Instance().Get("mkldnn_placement_pass");
+    EXPECT_EQ(static_cast<PlacementPassBase*>(pass.get())->GetPlacementName(),
+              "MKLDNN");
+  }
+};

 TEST(MKLDNNPlacementPass, enable_conv_relu) {
   // 1 conv (1 conv is always true) + 2 relu (1 relu is always true) + 0 pool
-  MainTest({"conv2d", "relu"}, 3);
+  PlacementPassTest().MainTest({"conv2d", "relu"}, 3);
 }

 TEST(MKLDNNPlacementPass, enable_relu_pool) {
   // 1 conv (1 conv is always true) + 2 relu (1 relu is always true) + 1 pool
-  MainTest({"relu", "pool2d"}, 4);
+  PlacementPassTest().MainTest({"relu", "pool2d"}, 4);
 }

 TEST(MKLDNNPlacementPass, enable_all) {
   // 1 conv (1 conv is always true) + 2 relu (1 relu is always true) + 1 pool
-  MainTest({}, 4);
+  PlacementPassTest().MainTest({}, 4);
+}
+
+TEST(MKLDNNPlacementPass, placement_name) {
+  PlacementPassTest().PlacementNameTest();
 }

 }  // namespace ir
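Note: the tester models the three use_mkldnn states from the table in BuildProgramDesc ("none", false, true) with boost::tribool, where boost::indeterminate stands for "attribute absent"; SetOp only calls SetAttr for the two determinate states. A minimal sketch of that pattern, assuming only Boost.Tribool (the printed strings are illustrative):

#include <boost/logic/tribool.hpp>

#include <iostream>

int main() {
  // "none" in the table: the op would get no use_mkldnn attribute at all.
  boost::tribool use_mkldnn = boost::indeterminate;
  if (!boost::indeterminate(use_mkldnn)) {
    std::cout << "SetAttr would be called\n";
  } else {
    std::cout << "attribute left unset\n";  // this branch runs
  }

  // A determinate state: the attribute would now be set to true.
  use_mkldnn = true;
  if (!boost::indeterminate(use_mkldnn)) {
    std::cout << "SetAttr(\"use_mkldnn\", "
              << (use_mkldnn ? "true" : "false") << ")\n";
  }
  return 0;
}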
PlacementPassBase header:

@@ -35,6 +35,10 @@ class PlacementPassBase : public Pass {
  private:
   bool IsSupport(const std::string& op_type) const;
+
+#if PADDLE_WITH_TESTING
+  friend class PlacementPassTest;
+#endif
 };

 }  // namespace ir
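Note: this friend declaration, compiled in only for test builds, is what lets PlacementNameTest call the private GetPlacementName() without widening the pass's public API. A standalone sketch of the mechanism; the class names mirror the patch, but the bodies are illustrative only:

#include <cassert>
#include <string>

class PlacementPassBase {
 private:
  virtual const std::string GetPlacementName() const { return "cuDNN"; }
  friend class PlacementPassTest;  // grants the tester private access

 public:
  virtual ~PlacementPassBase() = default;
};

class PlacementPassTest {
 public:
  void PlacementNameTest() {
    PlacementPassBase pass;
    // Legal only because of the friend declaration above.
    assert(pass.GetPlacementName() == "cuDNN");
  }
};

int main() {
  PlacementPassTest().PlacementNameTest();
  return 0;
}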