Unverified commit 68ec0a6f, authored by Tao Luo, committed by GitHub

make parallel_executor support FLAGS_use_mkldnn (#17341)

* make parallel_executor support FLAGS_use_mkldnn

test=develop

* add a warning when mkldnn_enabled_op_types_ is set in a non-MKLDNN environment

test=develop
Parent 08635993
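
This change teaches the ParallelExecutor pass builder to honor the global FLAGS_use_mkldnn flag by appending mkldnn_placement_pass in MKL-DNN builds. A minimal usage sketch, assuming a CPU build of Paddle compiled with MKL-DNN; the toy network and variable names below are illustrative and not part of this patch:

```python
import os

# FLAGS_use_mkldnn is read when paddle.fluid bootstraps, so set it in the
# environment before the import.
os.environ["FLAGS_use_mkldnn"] = "1"

import numpy as np
import paddle.fluid as fluid

# A toy regression network; any fluid program would do.
x = fluid.layers.data(name="x", shape=[13], dtype="float32")
y = fluid.layers.data(name="y", shape=[1], dtype="float32")
pred = fluid.layers.fc(input=x, size=1)
loss = fluid.layers.mean(fluid.layers.square_error_cost(input=pred, label=y))
fluid.optimizer.SGD(learning_rate=0.01).minimize(loss)

exe = fluid.Executor(fluid.CPUPlace())
exe.run(fluid.default_startup_program())

# With FLAGS_use_mkldnn=1, the pass builder below appends mkldnn_placement_pass.
build_strategy = fluid.BuildStrategy()
pe = fluid.ParallelExecutor(use_cuda=False,
                            loss_name=loss.name,
                            build_strategy=build_strategy)

feed = {"x": np.random.rand(8, 13).astype("float32"),
        "y": np.random.rand(8, 1).astype("float32")}
loss_val, = pe.run(feed=feed, fetch_list=[loss.name])
```
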
@@ -16,6 +16,7 @@ limitations under the License. */
#include <glog/logging.h>
#include <memory>
#include <unordered_set>
#include <utility>
#include "paddle/fluid/framework/details/reduce_op_handle.h"
#include "paddle/fluid/framework/ir/graph.h"
@@ -26,6 +27,8 @@ limitations under the License. */
#include "paddle/fluid/framework/ir/multi_devices_graph_pass/multi_devices_graph_pass.h"
#include "paddle/fluid/framework/ir/multi_devices_graph_pass/multi_devices_graph_print_pass.h"
DECLARE_bool(use_mkldnn);
namespace paddle {
namespace framework {
namespace details {
@@ -55,6 +58,22 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
// Note(zcd): record_skip_memory_opt_vars_pass should be the first pass.
AppendPass("record_skip_memory_opt_vars_pass");
#ifdef PADDLE_WITH_MKLDNN
if (FLAGS_use_mkldnn) {
VLOG(5) << "Add mkldnn_placement_pass";
AppendPass("mkldnn_placement_pass");
} else if (!strategy_.mkldnn_enabled_op_types_.empty()) {
LOG(WARNING)
<< "mkldnn_enabled_op_types specify the operator type list to "
"use MKLDNN acceleration. It is null in default, means "
"that all the operators supported by MKLDNN will be "
"accelerated. And it should not be set when "
"FLAGS_use_mkldnn=false.";
}
#else
PADDLE_ENFORCE(!FLAGS_use_mkldnn,
"Please compile with MKLDNN first to use MKLDNN");
#endif
if (strategy_.enable_sequential_execution_) {
VLOG(5) << "Add sequential_execution_pass";
AppendPass("sequential_execution_pass");
@@ -313,6 +332,9 @@ ir::Graph *BuildStrategy::Apply(ir::Graph *graph,
} else if (pass->Type() == "inplace_pass") {
pass->Erase(ir::kUseCuda);
pass->Set<bool>(ir::kUseCuda, new bool(use_cuda));
} else if (pass->Type() == "mkldnn_placement_pass") {
pass->Set("mkldnn_enabled_op_types",
new std::unordered_set<std::string>(mkldnn_enabled_op_types_));
}
VLOG(3) << "Start Apply Pass " << pass->Type();
graph = pass->Apply(graph);
@@ -351,3 +373,6 @@ USE_PASS(fuse_all_reduce_op_pass);
USE_PASS(runtime_context_cache_pass);
USE_PASS(expected_kernel_cache_pass);
USE_PASS(record_skip_memory_opt_vars_pass);
#ifdef PADDLE_WITH_MKLDNN
USE_PASS(mkldnn_placement_pass);
#endif
@@ -16,6 +16,7 @@
#include <memory>
#include <string>
#include <unordered_set>
#include <utility>
#include <vector>
#include "paddle/fluid/framework/ir/pass_builder.h"
@@ -109,6 +110,7 @@ struct BuildStrategy {
bool cache_runtime_context_{false};
bool cache_expected_kernel_{true};
std::unordered_set<std::string> mkldnn_enabled_op_types_;
// NOTE:
// Before you add new options, think if it's a general strategy that works
...
@@ -1498,6 +1498,15 @@ All parameter, weight, gradient are variables in Paddle.
"cache_expected_kernel",
[](const BuildStrategy &self) { return self.cache_expected_kernel_; },
[](BuildStrategy &self, bool b) { self.cache_expected_kernel_ = b; })
.def_property(
"mkldnn_enabled_op_types",
[](const BuildStrategy &self) {
return self.mkldnn_enabled_op_types_;
},
[](BuildStrategy &self,
const std::unordered_set<std::string> &mkldnn_enabled_op_types) {
self.mkldnn_enabled_op_types_ = mkldnn_enabled_op_types;
})
.def("_finalize_strategy_and_create_passes", .def("_finalize_strategy_and_create_passes",
[](BuildStrategy &self) -> std::shared_ptr<ir::PassBuilder> { [](BuildStrategy &self) -> std::shared_ptr<ir::PassBuilder> {
return self.CreatePassesFromStrategy(true); return self.CreatePassesFromStrategy(true);
......
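
The pybind addition exposes the new C++ field as an ordinary read/write property on BuildStrategy, converting between a Python set and std::unordered_set<std::string>. A quick round-trip check, assuming a Paddle build that contains this patch:

```python
import paddle.fluid as fluid

bs = fluid.BuildStrategy()
print(bs.mkldnn_enabled_op_types)        # set() -- empty by default
bs.mkldnn_enabled_op_types = {"conv2d"}
print(bs.mkldnn_enabled_op_types)        # {'conv2d'}
```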