diff --git a/paddle/fluid/framework/details/build_strategy.cc b/paddle/fluid/framework/details/build_strategy.cc
index 845fdf511e455509ff3e871084c17163c90c674a..e55b354f1931c219bdacc768bfc2839c6d2ea3cc 100644
--- a/paddle/fluid/framework/details/build_strategy.cc
+++ b/paddle/fluid/framework/details/build_strategy.cc
@@ -16,6 +16,7 @@ limitations under the License. */
 
 #include <glog/logging.h>
 #include <memory>
+#include <unordered_set>
 #include <utility>
 #include "paddle/fluid/framework/details/reduce_op_handle.h"
 #include "paddle/fluid/framework/ir/graph.h"
@@ -26,6 +27,8 @@
 #include "paddle/fluid/framework/ir/multi_devices_graph_pass/multi_devices_graph_pass.h"
 #include "paddle/fluid/framework/ir/multi_devices_graph_pass/multi_devices_graph_print_pass.h"
 
+DECLARE_bool(use_mkldnn);
+
 namespace paddle {
 namespace framework {
 namespace details {
@@ -55,6 +58,22 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
     // Note(zcd): record_skip_memory_opt_vars_pass should be the first pass.
     AppendPass("record_skip_memory_opt_vars_pass");
 
+#ifdef PADDLE_WITH_MKLDNN
+    if (FLAGS_use_mkldnn) {
+      VLOG(5) << "Add mkldnn_placement_pass";
+      AppendPass("mkldnn_placement_pass");
+    } else if (!strategy_.mkldnn_enabled_op_types_.empty()) {
+      LOG(WARNING)
+          << "mkldnn_enabled_op_types specify the operator type list to "
+             "use MKLDNN acceleration. It is null in default, means "
+             "that all the operators supported by MKLDNN will be "
+             "accelerated. And it should not be set when "
+             "FLAGS_use_mkldnn=false.";
+    }
+#else
+    PADDLE_ENFORCE(!FLAGS_use_mkldnn,
+                   "Please compile with MKLDNN first to use MKLDNN");
+#endif
     if (strategy_.enable_sequential_execution_) {
       VLOG(5) << "Add sequential_execution_pass";
       AppendPass("sequential_execution_pass");
@@ -313,6 +332,9 @@ ir::Graph *BuildStrategy::Apply(ir::Graph *graph,
     } else if (pass->Type() == "inplace_pass") {
       pass->Erase(ir::kUseCuda);
       pass->Set(ir::kUseCuda, new bool(use_cuda));
+    } else if (pass->Type() == "mkldnn_placement_pass") {
+      pass->Set("mkldnn_enabled_op_types",
+                new std::unordered_set<std::string>(mkldnn_enabled_op_types_));
     }
     VLOG(3) << "Start Apply Pass " << pass->Type();
     graph = pass->Apply(graph);
@@ -351,3 +373,6 @@ USE_PASS(fuse_all_reduce_op_pass);
 USE_PASS(runtime_context_cache_pass);
 USE_PASS(expected_kernel_cache_pass);
 USE_PASS(record_skip_memory_opt_vars_pass);
+#ifdef PADDLE_WITH_MKLDNN
+USE_PASS(mkldnn_placement_pass);
+#endif
diff --git a/paddle/fluid/framework/details/build_strategy.h b/paddle/fluid/framework/details/build_strategy.h
index b1601cfbcd5e9c66f1bbecd1f6fe10bc279cea26..38cc00a185547a1430b4f998430ea3a2d02f8c91 100644
--- a/paddle/fluid/framework/details/build_strategy.h
+++ b/paddle/fluid/framework/details/build_strategy.h
@@ -16,6 +16,7 @@
 
 #include <memory>
 #include <string>
+#include <unordered_set>
 #include <utility>
 #include <vector>
 #include "paddle/fluid/framework/ir/pass_builder.h"
@@ -109,6 +110,7 @@ struct BuildStrategy {
 
   bool cache_runtime_context_{false};
   bool cache_expected_kernel_{true};
+  std::unordered_set<std::string> mkldnn_enabled_op_types_;
 
   // NOTE:
   // Before you add new options, think if it's a general strategy that works
diff --git a/paddle/fluid/pybind/pybind.cc b/paddle/fluid/pybind/pybind.cc
index 43322b796b3f63b2d9ce0eb7fb034ac52ad796f9..e0ab615bb5db8a40ef632bf45b603ad25ece37f8 100644
--- a/paddle/fluid/pybind/pybind.cc
+++ b/paddle/fluid/pybind/pybind.cc
@@ -1498,6 +1498,15 @@ All parameter, weight, gradient are variables in Paddle.
           "cache_expected_kernel",
           [](const BuildStrategy &self) { return self.cache_expected_kernel_; },
           [](BuildStrategy &self, bool b) { self.cache_expected_kernel_ = b; })
+      .def_property(
+          "mkldnn_enabled_op_types",
+          [](const BuildStrategy &self) {
+            return self.mkldnn_enabled_op_types_;
+          },
+          [](BuildStrategy &self,
+             const std::unordered_set<std::string> &mkldnn_enabled_op_types) {
+            self.mkldnn_enabled_op_types_ = mkldnn_enabled_op_types;
+          })
       .def("_finalize_strategy_and_create_passes",
            [](BuildStrategy &self) -> std::shared_ptr<ir::PassBuilder> {
              return self.CreatePassesFromStrategy(true);