未验证 提交 b4b16946 编写于 作者: T Tao Luo 提交者: GitHub

add fc_mkldnn_pass in compare_mkldnn (#17712)

test=develop
上级 70a887af
...@@ -184,6 +184,7 @@ void compare(bool use_mkldnn = false) { ...@@ -184,6 +184,7 @@ void compare(bool use_mkldnn = false) {
SetConfig(&cfg); SetConfig(&cfg);
if (use_mkldnn) { if (use_mkldnn) {
cfg.EnableMKLDNN(); cfg.EnableMKLDNN();
cfg.pass_builder()->AppendPass("fc_mkldnn_pass");
} }
std::vector<std::vector<PaddleTensor>> inputs; std::vector<std::vector<PaddleTensor>> inputs;
......
...@@ -252,6 +252,7 @@ void compare(bool use_mkldnn = false) { ...@@ -252,6 +252,7 @@ void compare(bool use_mkldnn = false) {
std::unordered_set<std::string> op_list = {"softmax", "elementwise_add", std::unordered_set<std::string> op_list = {"softmax", "elementwise_add",
"relu"}; "relu"};
cfg.SetMKLDNNOp(op_list); cfg.SetMKLDNNOp(op_list);
cfg.pass_builder()->AppendPass("fc_mkldnn_pass");
} }
std::vector<std::vector<PaddleTensor>> input_slots_all; std::vector<std::vector<PaddleTensor>> input_slots_all;
......
...@@ -147,6 +147,7 @@ void compare(bool use_mkldnn = false) { ...@@ -147,6 +147,7 @@ void compare(bool use_mkldnn = false) {
if (use_mkldnn) { if (use_mkldnn) {
cfg.EnableMKLDNN(); cfg.EnableMKLDNN();
cfg.pass_builder()->AppendPass("fc_mkldnn_pass");
} }
std::vector<std::vector<PaddleTensor>> input_slots_all; std::vector<std::vector<PaddleTensor>> input_slots_all;
......
...@@ -220,6 +220,7 @@ void compare(bool use_mkldnn = false) { ...@@ -220,6 +220,7 @@ void compare(bool use_mkldnn = false) {
SetConfig(&cfg); SetConfig(&cfg);
if (use_mkldnn) { if (use_mkldnn) {
cfg.EnableMKLDNN(); cfg.EnableMKLDNN();
cfg.pass_builder()->AppendPass("fc_mkldnn_pass");
} }
std::vector<std::vector<PaddleTensor>> input_slots_all; std::vector<std::vector<PaddleTensor>> input_slots_all;
......
...@@ -133,6 +133,7 @@ void compare(bool use_mkldnn = false) { ...@@ -133,6 +133,7 @@ void compare(bool use_mkldnn = false) {
SetConfig(&cfg); SetConfig(&cfg);
if (use_mkldnn) { if (use_mkldnn) {
cfg.EnableMKLDNN(); cfg.EnableMKLDNN();
cfg.pass_builder()->AppendPass("fc_mkldnn_pass");
} }
std::vector<std::vector<PaddleTensor>> input_slots_all; std::vector<std::vector<PaddleTensor>> input_slots_all;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册