Commit d3a66473 authored by Wojciech Uss, committed by Tao Luo

improve placement pass tests code coverage (#22197)

Parent f5262865
@@ -27,11 +27,11 @@ namespace ir {
  */
 class CUDNNPlacementPass : public PlacementPassBase {
  private:
-  const std::string GetPlacementName() const { return "cuDNN"; }
+  const std::string GetPlacementName() const override { return "cuDNN"; }

-  const std::string GetAttrName() const { return "use_cudnn"; }
+  const std::string GetAttrName() const override { return "use_cudnn"; }

-  const std::unordered_set<std::string> GetOpTypesList() const {
+  const std::unordered_set<std::string> GetOpTypesList() const override {
     return Get<std::unordered_set<std::string>>("cudnn_enabled_op_types");
   }
 };
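The three edits in this hunk are the same mechanical hardening: these accessors are virtual in PlacementPassBase (otherwise `override` would not compile), and adding `override` makes the compiler verify that each signature really overrides the base-class virtual instead of silently declaring an unrelated member. A minimal sketch of the failure mode this guards against (`Base`/`Derived` are illustrative names, not Paddle types):

#include <string>

struct Base {
  virtual ~Base() = default;
  virtual const std::string GetAttrName() const { return ""; }
};

struct Derived : Base {
  // Typo: the trailing `const` is missing, so this declares a *new* member
  // instead of overriding; calls through Base* still reach Base's version.
  const std::string GetAttrName() { return "use_cudnn"; }

  // Marked `override`, the same typo becomes a hard compile error instead:
  //   const std::string GetAttrName() override { return "use_cudnn"; }
};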
@@ -22,94 +22,108 @@ namespace paddle {
 namespace framework {
 namespace ir {

-void RegisterOpKernel() {
-  static bool is_registered = false;
-  if (!is_registered) {
-    auto& all_kernels = OperatorWithKernel::AllOpKernels();
-    platform::CUDAPlace place = platform::CUDAPlace(0);
-    OpKernelType plain_kernel_type =
-        OpKernelType(proto::VarType::FP32, place, DataLayout::kAnyLayout,
-                     LibraryType::kPlain);
-    OpKernelType cudnn_kernel_type =
-        OpKernelType(proto::VarType::FP32, place, DataLayout::kAnyLayout,
-                     LibraryType::kCUDNN);
-    auto fake_kernel_func = [](const ExecutionContext&) -> void {
-      static int num_calls = 0;
-      num_calls++;
-    };
-    all_kernels["conv2d"][cudnn_kernel_type] = fake_kernel_func;
-    all_kernels["pool2d"][cudnn_kernel_type] = fake_kernel_func;
-    all_kernels["depthwise_conv2d"][plain_kernel_type] = fake_kernel_func;
-    all_kernels["relu"][plain_kernel_type] = fake_kernel_func;
-    is_registered = true;
+class PlacementPassTest {
+ private:
+  void RegisterOpKernel() {
+    static bool is_registered = false;
+    if (!is_registered) {
+      auto& all_kernels = OperatorWithKernel::AllOpKernels();
+      platform::CUDAPlace place = platform::CUDAPlace(0);
+      OpKernelType plain_kernel_type =
+          OpKernelType(proto::VarType::FP32, place, DataLayout::kAnyLayout,
+                       LibraryType::kPlain);
+      OpKernelType cudnn_kernel_type =
+          OpKernelType(proto::VarType::FP32, place, DataLayout::kAnyLayout,
+                       LibraryType::kCUDNN);
+      auto fake_kernel_func = [](const ExecutionContext&) -> void {
+        static int num_calls = 0;
+        num_calls++;
+      };
+      all_kernels["conv2d"][cudnn_kernel_type] = fake_kernel_func;
+      all_kernels["pool2d"][cudnn_kernel_type] = fake_kernel_func;
+      all_kernels["depthwise_conv2d"][plain_kernel_type] = fake_kernel_func;
+      all_kernels["relu"][plain_kernel_type] = fake_kernel_func;
+      is_registered = true;
+    }
   }
-}

-void MainTest(std::initializer_list<std::string> cudnn_enabled_op_types,
-              unsigned expected_use_cudnn_true_count) {
-  // operator                                 use_cudnn
-  // --------------------------------------------------
-  // (a,b)->concat->c                         -
-  // (c,weights,bias)->conv2d->f              false
-  // f->relu->g                               -
-  // g->pool2d->h                             false
-  // (h,weights2,bias2)->depthwise_conv2d->k  false
-  // k->relu->l                               -
-  Layers layers;
-  VarDesc* a = layers.data("a");
-  VarDesc* b = layers.data("b");
-  VarDesc* c = layers.concat(std::vector<VarDesc*>({a, b}));
-  VarDesc* weights_0 = layers.data("weights_0");
-  VarDesc* bias_0 = layers.data("bias_0");
-  VarDesc* f = layers.conv2d(c, weights_0, bias_0, false);
-  VarDesc* g = layers.relu(f);
-  VarDesc* h = layers.pool2d(g, false);
-  VarDesc* weights_1 = layers.data("weights_1");
-  VarDesc* bias_1 = layers.data("bias_1");
-  VarDesc* k = layers.depthwise_conv2d(h, weights_1, bias_1, false);
-  layers.relu(k);
-  RegisterOpKernel();
-  std::unique_ptr<ir::Graph> graph(new ir::Graph(layers.main_program()));
-  auto pass = PassRegistry::Instance().Get("cudnn_placement_pass");
-  pass->Set("cudnn_enabled_op_types",
-            new std::unordered_set<std::string>(cudnn_enabled_op_types));
-  graph.reset(pass->Apply(graph.release()));
-  unsigned use_cudnn_true_count = 0;
-  for (auto* node : graph->Nodes()) {
-    if (node->IsOp() && node->Op()) {
-      auto* op = node->Op();
-      if (op->HasAttr("use_cudnn") &&
-          boost::get<bool>(op->GetAttr("use_cudnn"))) {
-        ++use_cudnn_true_count;
+ public:
+  void MainTest(std::initializer_list<std::string> cudnn_enabled_op_types,
+                unsigned expected_use_cudnn_true_count) {
+    // operator                                 use_cudnn
+    // --------------------------------------------------
+    // (a,b)->concat->c                         -
+    // (c,weights,bias)->conv2d->f              false
+    // f->relu->g                               -
+    // g->pool2d->h                             false
+    // (h,weights2,bias2)->depthwise_conv2d->k  false
+    // k->relu->l                               -
+    Layers layers;
+    VarDesc* a = layers.data("a");
+    VarDesc* b = layers.data("b");
+    VarDesc* c = layers.concat(std::vector<VarDesc*>({a, b}));
+    VarDesc* weights_0 = layers.data("weights_0");
+    VarDesc* bias_0 = layers.data("bias_0");
+    VarDesc* f = layers.conv2d(c, weights_0, bias_0, false);
+    VarDesc* g = layers.relu(f);
+    VarDesc* h = layers.pool2d(g, false);
+    VarDesc* weights_1 = layers.data("weights_1");
+    VarDesc* bias_1 = layers.data("bias_1");
+    VarDesc* k = layers.depthwise_conv2d(h, weights_1, bias_1, false);
+    layers.relu(k);
+
+    RegisterOpKernel();
+
+    std::unique_ptr<ir::Graph> graph(new ir::Graph(layers.main_program()));
+    auto pass = PassRegistry::Instance().Get("cudnn_placement_pass");
+    pass->Set("cudnn_enabled_op_types",
+              new std::unordered_set<std::string>(cudnn_enabled_op_types));
+    graph.reset(pass->Apply(graph.release()));
+
+    unsigned use_cudnn_true_count = 0;
+    for (auto* node : graph->Nodes()) {
+      if (node->IsOp() && node->Op()) {
+        auto* op = node->Op();
+        if (op->HasAttr("use_cudnn") &&
+            boost::get<bool>(op->GetAttr("use_cudnn"))) {
+          ++use_cudnn_true_count;
+        }
       }
     }
+    EXPECT_EQ(use_cudnn_true_count, expected_use_cudnn_true_count);
   }
-  EXPECT_EQ(use_cudnn_true_count, expected_use_cudnn_true_count);
-}
+
+  void PlacementNameTest() {
+    auto pass = PassRegistry::Instance().Get("cudnn_placement_pass");
+    EXPECT_EQ(static_cast<PlacementPassBase*>(pass.get())->GetPlacementName(),
+              "cuDNN");
+  }
+};

 TEST(CUDNNPlacementPass, enable_conv2d) {
   // 1 conv2d
-  MainTest({"conv2d"}, 1);
+  PlacementPassTest().MainTest({"conv2d"}, 1);
 }

 TEST(CUDNNPlacementPass, enable_relu_pool) {
   // 1 conv2d + 1 pool2d
-  MainTest({"conv2d", "pool2d"}, 2);
+  PlacementPassTest().MainTest({"conv2d", "pool2d"}, 2);
 }

 TEST(CUDNNPlacementPass, enable_all) {
   // 1 conv2d + 1 pool2d
   // depthwise_conv2d does not have a CUDNN kernel.
-  MainTest({}, 2);
+  PlacementPassTest().MainTest({}, 2);
 }

+TEST(CUDNNPlacementPass, placement_name) {
+  PlacementPassTest().PlacementNameTest();
+}
+
 }  // namespace ir
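A note on the mechanism this tester exercises: the cuDNN placement pass only flips `use_cudnn` on operators that actually have a kernel registered under `LibraryType::kCUDNN`, which is why `RegisterOpKernel` installs no-op lambdas into `OperatorWithKernel::AllOpKernels()` instead of linking real GPU kernels, and why `depthwise_conv2d` (registered with a plain kernel only) stays `false` in every test. A reduced, self-contained model of that presence check (the types below are illustrative stand-ins, not Paddle's):

#include <functional>
#include <map>
#include <string>

// Stand-ins for Paddle's kernel registry: op type -> library tag -> kernel.
enum class Library { kPlain, kCUDNN };
using KernelFn = std::function<void()>;
using KernelRegistry = std::map<std::string, std::map<Library, KernelFn>>;

// The placement pass only cares whether an entry *exists*, so a no-op
// lambda is enough to make an op look cuDNN-capable in a test.
bool HasCudnnKernel(const KernelRegistry& reg, const std::string& op) {
  auto it = reg.find(op);
  return it != reg.end() && it->second.count(Library::kCUDNN) > 0;
}

int main() {
  KernelRegistry reg;
  reg["conv2d"][Library::kCUDNN] = [] {};            // fake cuDNN kernel
  reg["depthwise_conv2d"][Library::kPlain] = [] {};  // plain kernel only
  bool a = HasCudnnKernel(reg, "conv2d");            // true: can be placed
  bool b = HasCudnnKernel(reg, "depthwise_conv2d");  // false: stays false
  return (a && !b) ? 0 : 1;
}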
@@ -27,11 +27,11 @@ namespace ir {
  */
 class MKLDNNPlacementPass : public PlacementPassBase {
  private:
-  const std::string GetPlacementName() const { return "MKLDNN"; }
+  const std::string GetPlacementName() const override { return "MKLDNN"; }

-  const std::string GetAttrName() const { return "use_mkldnn"; }
+  const std::string GetAttrName() const override { return "use_mkldnn"; }

-  const std::unordered_set<std::string> GetOpTypesList() const {
+  const std::unordered_set<std::string> GetOpTypesList() const override {
     return Get<std::unordered_set<std::string>>("mkldnn_enabled_op_types");
   }
 };
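The MKLDNN header gets the identical treatment, which highlights how thin these subclasses are: the base class owns the graph traversal and the `IsSupport` check, and each concrete pass supplies only a display name, an attribute name, and an op-type list. A reduced sketch of that template-method split (simplified signatures, not the actual Pass interface):

#include <string>
#include <unordered_set>

// Simplified shape of the placement-pass hierarchy: the base class drives
// the algorithm and defers the three varying pieces to private virtuals.
class PlacementPassBaseSketch {
 public:
  virtual ~PlacementPassBaseSketch() = default;
  void ApplyTo(const std::unordered_set<std::string>& ops_in_graph) {
    for (const auto& op : ops_in_graph) {
      if (GetOpTypesList().count(op)) {
        // ... set the GetAttrName() attribute to true on this op ...
      }
    }
  }

 private:
  virtual const std::string GetPlacementName() const = 0;
  virtual const std::string GetAttrName() const = 0;
  virtual const std::unordered_set<std::string> GetOpTypesList() const = 0;
};

class MKLDNNPlacementPassSketch : public PlacementPassBaseSketch {
 private:
  const std::string GetPlacementName() const override { return "MKLDNN"; }
  const std::string GetAttrName() const override { return "use_mkldnn"; }
  const std::unordered_set<std::string> GetOpTypesList() const override {
    return {"conv2d", "pool2d", "relu"};
  }
};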
@@ -21,112 +21,129 @@ namespace paddle {
 namespace framework {
 namespace ir {

-void SetOp(ProgramDesc* prog, const std::string& type, const std::string& name,
-           const std::vector<std::string>& inputs,
-           const std::vector<std::string>& outputs, boost::tribool use_mkldnn) {
-  auto* op = prog->MutableBlock(0)->AppendOp();
-  op->SetType(type);
-  if (!boost::indeterminate(use_mkldnn)) op->SetAttr("use_mkldnn", use_mkldnn);
-  if (type == "conv2d") {
-    op->SetAttr("name", name);
-    op->SetInput("Input", {inputs[0]});
-    op->SetInput("Filter", {inputs[1]});
-    op->SetInput("Bias", {inputs[2]});
-  } else if (type == "relu") {
-    op->SetInput("X", inputs);
-  } else if (type == "concat") {
-    op->SetAttr("axis", 1);
-    op->SetInput("X", {inputs[0], inputs[1]});
-  } else if (type == "pool2d") {
-    op->SetInput("X", {inputs[0]});
-  } else {
-    FAIL() << "Unexpected operator type.";
+class PlacementPassTest {
+ private:
+  void SetOp(ProgramDesc* prog, const std::string& type,
+             const std::string& name, const std::vector<std::string>& inputs,
+             const std::vector<std::string>& outputs,
+             boost::tribool use_mkldnn) {
+    auto* op = prog->MutableBlock(0)->AppendOp();
+    op->SetType(type);
+    if (!boost::indeterminate(use_mkldnn))
+      op->SetAttr("use_mkldnn", use_mkldnn);
+    if (type == "conv2d") {
+      op->SetAttr("name", name);
+      op->SetInput("Input", {inputs[0]});
+      op->SetInput("Filter", {inputs[1]});
+      op->SetInput("Bias", {inputs[2]});
+    } else if (type == "relu") {
+      op->SetInput("X", inputs);
+    } else if (type == "concat") {
+      op->SetAttr("axis", 1);
+      op->SetInput("X", {inputs[0], inputs[1]});
+    } else if (type == "pool2d") {
+      op->SetInput("X", {inputs[0]});
+    } else {
+      FAIL() << "Unexpected operator type.";
+    }
+    op->SetOutput("Out", {outputs[0]});
   }
-  op->SetOutput("Out", {outputs[0]});
-}

-// operator                      use_mkldnn
-// ---------------------------------------
-// (a,b)->concat->c              none
-// (c,weights,bias)->conv->f     none
-// f->relu->g                    false
-// g->pool->h                    false
-// (h,weights2,bias2)->conv->k   true
-// k->relu->l                    true
-ProgramDesc BuildProgramDesc() {
-  ProgramDesc prog;
-  for (auto& v :
-       std::vector<std::string>({"a", "b", "c", "weights", "bias", "f", "g",
-                                 "h", "weights2", "bias2", "k", "l"})) {
-    auto* var = prog.MutableBlock(0)->Var(v);
-    var->SetType(proto::VarType::SELECTED_ROWS);
-    if (v == "weights" || v == "bias") {
-      var->SetPersistable(true);
+  // operator                      use_mkldnn
+  // ---------------------------------------
+  // (a,b)->concat->c              none
+  // (c,weights,bias)->conv->f     none
+  // f->relu->g                    false
+  // g->pool->h                    false
+  // (h,weights2,bias2)->conv->k   true
+  // k->relu->l                    true
+  ProgramDesc BuildProgramDesc() {
+    ProgramDesc prog;
+    for (auto& v :
+         std::vector<std::string>({"a", "b", "c", "weights", "bias", "f", "g",
+                                   "h", "weights2", "bias2", "k", "l"})) {
+      auto* var = prog.MutableBlock(0)->Var(v);
+      var->SetType(proto::VarType::SELECTED_ROWS);
+      if (v == "weights" || v == "bias") {
+        var->SetPersistable(true);
+      }
     }
+
+    SetOp(&prog, "concat", "concat1", std::vector<std::string>({"a", "b"}),
+          std::vector<std::string>({"c"}), boost::indeterminate);
+    SetOp(&prog, "conv2d", "conv1",
+          std::vector<std::string>({"c", "weights", "bias"}),
+          std::vector<std::string>({"f"}), boost::indeterminate);
+    SetOp(&prog, "relu", "relu1", std::vector<std::string>({"f"}),
+          std::vector<std::string>({"g"}), false);
+    SetOp(&prog, "pool2d", "pool1", std::vector<std::string>({"g"}),
+          std::vector<std::string>({"h"}), false);
+    SetOp(&prog, "conv2d", "conv2",
+          std::vector<std::string>({"h", "weights2", "bias2"}),
+          std::vector<std::string>({"k"}), true);
+    SetOp(&prog, "relu", "relu2", std::vector<std::string>({"k"}),
+          std::vector<std::string>({"l"}), true);
+
+    return prog;
   }
-  SetOp(&prog, "concat", "concat1", std::vector<std::string>({"a", "b"}),
-        std::vector<std::string>({"c"}), boost::indeterminate);
-  SetOp(&prog, "conv2d", "conv1",
-        std::vector<std::string>({"c", "weights", "bias"}),
-        std::vector<std::string>({"f"}), boost::indeterminate);
-  SetOp(&prog, "relu", "relu1", std::vector<std::string>({"f"}),
-        std::vector<std::string>({"g"}), false);
-  SetOp(&prog, "pool2d", "pool1", std::vector<std::string>({"g"}),
-        std::vector<std::string>({"h"}), false);
-  SetOp(&prog, "conv2d", "conv2",
-        std::vector<std::string>({"h", "weights2", "bias2"}),
-        std::vector<std::string>({"k"}), true);
-  SetOp(&prog, "relu", "relu2", std::vector<std::string>({"k"}),
-        std::vector<std::string>({"l"}), true);
-  return prog;
-}

-void MainTest(std::initializer_list<std::string> mkldnn_enabled_op_types,
-              unsigned expected_use_mkldnn_true_count) {
-  auto prog = BuildProgramDesc();
+ public:
+  void MainTest(std::initializer_list<std::string> mkldnn_enabled_op_types,
+                unsigned expected_use_mkldnn_true_count) {
+    auto prog = BuildProgramDesc();

-  std::unique_ptr<ir::Graph> graph(new ir::Graph(prog));
+    std::unique_ptr<ir::Graph> graph(new ir::Graph(prog));

-  auto pass = PassRegistry::Instance().Get("mkldnn_placement_pass");
+    auto pass = PassRegistry::Instance().Get("mkldnn_placement_pass");

-  pass->Set("mkldnn_enabled_op_types",
-            new std::unordered_set<std::string>(mkldnn_enabled_op_types));
+    pass->Set("mkldnn_enabled_op_types",
+              new std::unordered_set<std::string>(mkldnn_enabled_op_types));

-  graph.reset(pass->Apply(graph.release()));
+    graph.reset(pass->Apply(graph.release()));

-  unsigned use_mkldnn_true_count = 0;
+    unsigned use_mkldnn_true_count = 0;

-  for (auto* node : graph->Nodes()) {
-    if (node->IsOp()) {
-      auto* op = node->Op();
-      if (op->HasAttr("use_mkldnn") &&
-          boost::get<bool>(op->GetAttr("use_mkldnn"))) {
-        ++use_mkldnn_true_count;
+    for (auto* node : graph->Nodes()) {
+      if (node->IsOp()) {
+        auto* op = node->Op();
+        if (op->HasAttr("use_mkldnn") &&
+            boost::get<bool>(op->GetAttr("use_mkldnn"))) {
+          ++use_mkldnn_true_count;
+        }
       }
     }
+
+    EXPECT_EQ(use_mkldnn_true_count, expected_use_mkldnn_true_count);
   }
-  EXPECT_EQ(use_mkldnn_true_count, expected_use_mkldnn_true_count);
-}
+
+  void PlacementNameTest() {
+    auto pass = PassRegistry::Instance().Get("mkldnn_placement_pass");
+    EXPECT_EQ(static_cast<PlacementPassBase*>(pass.get())->GetPlacementName(),
+              "MKLDNN");
+  }
+};

 TEST(MKLDNNPlacementPass, enable_conv_relu) {
   // 1 conv (1 conv is always true) + 2 relu (1 relu is always true) + 0 pool
-  MainTest({"conv2d", "relu"}, 3);
+  PlacementPassTest().MainTest({"conv2d", "relu"}, 3);
 }

 TEST(MKLDNNPlacementPass, enable_relu_pool) {
   // 1 conv (1 conv is always true) + 2 relu (1 relu is always true) + 1 pool
-  MainTest({"relu", "pool2d"}, 4);
+  PlacementPassTest().MainTest({"relu", "pool2d"}, 4);
 }

 TEST(MKLDNNPlacementPass, enable_all) {
   // 1 conv (1 conv is always true) + 2 relu (1 relu is always true) + 1 pool
-  MainTest({}, 4);
+  PlacementPassTest().MainTest({}, 4);
 }

+TEST(MKLDNNPlacementPass, placement_name) {
+  PlacementPassTest().PlacementNameTest();
+}
+
 }  // namespace ir
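`SetOp` takes `use_mkldnn` as a `boost::tribool` rather than a `bool` so the fixture can model three distinct starting states for an operator: attribute not set at all (`boost::indeterminate`, used for `concat1` and `conv1`), explicitly `false` (`relu1`, `pool1`), and explicitly `true` (`conv2`, `relu2`). Only the non-indeterminate states produce a `use_mkldnn` attribute on the op. A minimal standalone sketch of the three-way branching (Boost.Tribool only, no Paddle types):

#include <boost/logic/tribool.hpp>
#include <iostream>

// Mirrors the branch in SetOp above: indeterminate means "leave the
// attribute off entirely", which is distinct from an explicit false.
void DescribeFlag(boost::tribool use_mkldnn) {
  if (boost::indeterminate(use_mkldnn)) {
    std::cout << "attribute omitted\n";   // concat1, conv1
  } else if (use_mkldnn) {
    std::cout << "use_mkldnn = true\n";   // conv2, relu2
  } else {
    std::cout << "use_mkldnn = false\n";  // relu1, pool1
  }
}

int main() {
  DescribeFlag(boost::indeterminate);
  DescribeFlag(true);
  DescribeFlag(false);
}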
@@ -35,6 +35,10 @@ class PlacementPassBase : public Pass {

  private:
   bool IsSupport(const std::string& op_type) const;
+
+#if PADDLE_WITH_TESTING
+  friend class PlacementPassTest;
+#endif
 };

 }  // namespace ir
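This last hunk is the only change outside the tests: `PlacementPassBase` befriends the test helper, and only in testing builds, so the private virtuals stay sealed in a regular build while `PlacementNameTest` can still call `GetPlacementName()` directly. A reduced sketch of the pattern (the macro and class names mirror the hunk; everything else is illustrative):

#include <cassert>
#include <string>

#define WITH_TESTING 1  // stand-in for PADDLE_WITH_TESTING

class PlacementPassBase {
 public:
  virtual ~PlacementPassBase() = default;

 private:
  virtual const std::string GetPlacementName() const = 0;

#if WITH_TESTING
  // Grants the test access to the private virtual, but only in test builds.
  friend class PlacementPassTest;
#endif
};

class CudnnPassSketch : public PlacementPassBase {
  const std::string GetPlacementName() const override { return "cuDNN"; }
};

#if WITH_TESTING
class PlacementPassTest {
 public:
  void PlacementNameTest() {
    CudnnPassSketch pass;
    // Legal only because PlacementPassBase befriends this class; access is
    // checked against the static type, so the private override still runs.
    assert(static_cast<PlacementPassBase&>(pass).GetPlacementName() == "cuDNN");
  }
};
#endif

int main() { PlacementPassTest().PlacementNameTest(); }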