diff --git a/paddle/fluid/framework/executor.cc b/paddle/fluid/framework/executor.cc index 6ec225ed983055dad1380cf484fa02526fb3f6cc..c35c1138df90ca3eb4f6496475371a6a975a6942 100644 --- a/paddle/fluid/framework/executor.cc +++ b/paddle/fluid/framework/executor.cc @@ -671,9 +671,8 @@ void Executor::EnableMKLDNN(const ProgramDesc& program) { for (size_t bid = 0; bid < program.Size(); ++bid) { auto* block = const_cast(program).MutableBlock(bid); for (auto* op : block->AllOps()) { - if (op->HasAttr("use_mkldnn")) { + if (FoundOneDNNKernel(op) || FoundPhiOneDNNKernel(op)) op->SetAttr("use_mkldnn", true); - } } } #else diff --git a/paddle/fluid/framework/ir/cudnn_placement_pass.cc b/paddle/fluid/framework/ir/cudnn_placement_pass.cc index 420e8ee83adbc2935d84c009cfb88589d02bc29c..6b7293acf94284461266790219c1c3d6725e0ede 100644 --- a/paddle/fluid/framework/ir/cudnn_placement_pass.cc +++ b/paddle/fluid/framework/ir/cudnn_placement_pass.cc @@ -13,6 +13,36 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/fluid/framework/ir/cudnn_placement_pass.h" +#include "paddle/fluid/framework/operator.h" + +namespace paddle { +namespace framework { +namespace ir { + +bool CUDNNPlacementPass::IsSupport(const Node* op) const { + std::string attr_name = GetAttrName(); + + if (!(op->Op()->HasAttr(attr_name) || op->Op()->HasProtoAttr(attr_name))) + return false; + + auto& all_kernels = OperatorWithKernel::AllOpKernels(); + auto it = all_kernels.find(op->Op()->Type()); + if (it == all_kernels.end()) { + // All control operators don't have kernel. + return false; + } + for (auto& kernel_pair : it->second) { + if (platform::is_gpu_place(kernel_pair.first.place_) && + (kernel_pair.first.library_type_ == LibraryType::kCUDNN)) { + return true; + } + } + return false; +} + +} // namespace ir +} // namespace framework +} // namespace paddle REGISTER_PASS(cudnn_placement_pass, paddle::framework::ir::CUDNNPlacementPass) .RequirePassAttr("cudnn_enabled_op_types"); diff --git a/paddle/fluid/framework/ir/cudnn_placement_pass.h b/paddle/fluid/framework/ir/cudnn_placement_pass.h index 8d84c2bf707956c4a00454a6dc66efcb42bec816..afd32cfa721e209b9ba9d9f97f9600dbf725f165 100644 --- a/paddle/fluid/framework/ir/cudnn_placement_pass.h +++ b/paddle/fluid/framework/ir/cudnn_placement_pass.h @@ -27,6 +27,9 @@ namespace ir { * Specifies which operators should use cuDNN. */ class CUDNNPlacementPass : public PlacementPassBase { + protected: + bool IsSupport(const Node* op) const override; + private: const std::string GetPlacementName() const override { return "cuDNN"; } diff --git a/paddle/fluid/framework/ir/mkldnn/mkldnn_fc_rnn_fuse_pass_tester.cc b/paddle/fluid/framework/ir/mkldnn/mkldnn_fc_rnn_fuse_pass_tester.cc index ba9ccb5daa6677f4b435c077d5bae1991830ea54..05e46db50afd308685530adb3ef286745fd27f16 100644 --- a/paddle/fluid/framework/ir/mkldnn/mkldnn_fc_rnn_fuse_pass_tester.cc +++ b/paddle/fluid/framework/ir/mkldnn/mkldnn_fc_rnn_fuse_pass_tester.cc @@ -40,6 +40,7 @@ void TestFcRNNFusePass(const std::string& pass_name, "__param_scope__", (pass_name == "fc_gru_fuse_pass" ? fc_gru_test::CreateParamScope() : fc_lstm_test::CreateParamScope())); + RegisterOpKernel({"mul", "elementwise_add"}); graph.reset(mkldnn_placement_pass_->Apply(graph.release())); auto check_num_mkldnn_nodes = [&](const std::unique_ptr& graph) { diff --git a/paddle/fluid/framework/ir/mkldnn/mkldnn_placement_pass.cc b/paddle/fluid/framework/ir/mkldnn/mkldnn_placement_pass.cc index 6032f38b0cffd8627c547a08e5f5b657decf89df..83b06102d21ff85821ac4834c1a6605dfc440b75 100644 --- a/paddle/fluid/framework/ir/mkldnn/mkldnn_placement_pass.cc +++ b/paddle/fluid/framework/ir/mkldnn/mkldnn_placement_pass.cc @@ -13,6 +13,69 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/fluid/framework/ir/mkldnn/mkldnn_placement_pass.h" +#include "paddle/fluid/framework/operator.h" + +namespace paddle { +namespace framework { +namespace ir { + +inline bool FoundOneDNNKernelWithCorrectDataType( + const framework::ir::Node* op) { + const auto op_type = op->Op()->Type(); + auto& all_kernels = framework::OperatorWithKernel::AllOpKernels(); + auto it = all_kernels.find(op_type); + if (it != all_kernels.end()) { + for (auto& kernel_pair : it->second) { + if (platform::is_cpu_place(kernel_pair.first.place_) && + (kernel_pair.first.library_type_ == + framework::LibraryType::kMKLDNN)) { + if (op->inputs.size() > 0) { + if (op->inputs[0]->IsVar() && + op->inputs[0]->Var()->Name() != "feed" && + kernel_pair.first.data_type_ == + op->inputs[0]->Var()->GetDataType()) + return true; + } else { + return true; + } + } + } + } + return false; +} + +inline bool FoundPhiOneDNNKernelWithCorrectDataType( + const framework::ir::Node* op) { + auto op_type = op->Op()->Type(); + auto phi_kernels = phi::KernelFactory::Instance().SelectKernelMap( + phi::TransToPhiKernelName(op_type)); + + for (auto& kernel_pair : phi_kernels) { + if (kernel_pair.first.backend() == phi::Backend::ONEDNN) { + if (op->inputs.size() > 0) { + if (op->inputs[0]->IsVar() && op->inputs[0]->Var()->Name() != "feed" && + kernel_pair.first.dtype() == + framework::TransToPhiDataType( + op->inputs[0]->Var()->GetDataType())) + return true; + } else { + return true; + } + } + } + return false; +} + +bool MKLDNNPlacementPass::IsSupport(const Node* op) const { + if (FoundOneDNNKernelWithCorrectDataType(op) || + FoundPhiOneDNNKernelWithCorrectDataType(op)) + return true; + return false; +} + +} // namespace ir +} // namespace framework +} // namespace paddle REGISTER_PASS(mkldnn_placement_pass, paddle::framework::ir::MKLDNNPlacementPass) .RequirePassAttr("mkldnn_enabled_op_types"); diff --git a/paddle/fluid/framework/ir/mkldnn/mkldnn_placement_pass.h b/paddle/fluid/framework/ir/mkldnn/mkldnn_placement_pass.h index ca56a8900ca4f7edac7be095a0968555bf628124..5fc1dbd24f18ef9ce2ebb70feb5df36d6c5b7fd2 100644 --- a/paddle/fluid/framework/ir/mkldnn/mkldnn_placement_pass.h +++ b/paddle/fluid/framework/ir/mkldnn/mkldnn_placement_pass.h @@ -27,6 +27,9 @@ namespace ir { * Specifies which operators should use MKLDNN. */ class MKLDNNPlacementPass : public PlacementPassBase { + protected: + bool IsSupport(const Node* op) const override; + private: const std::string GetPlacementName() const override { return "MKLDNN"; } diff --git a/paddle/fluid/framework/ir/mkldnn/mkldnn_placement_pass_tester.cc b/paddle/fluid/framework/ir/mkldnn/mkldnn_placement_pass_tester.cc index 79b70e39aaf753c3336ab907ce45c242d17faf29..b7697252a67c4a3838a40a4d43cab826a537eeec 100644 --- a/paddle/fluid/framework/ir/mkldnn/mkldnn_placement_pass_tester.cc +++ b/paddle/fluid/framework/ir/mkldnn/mkldnn_placement_pass_tester.cc @@ -14,9 +14,9 @@ #include -#include "paddle/utils/tribool.h" - #include "paddle/fluid/framework/ir/mkldnn/mkldnn_placement_pass.h" +#include "paddle/fluid/framework/ir/pass_tester_helper.h" +#include "paddle/utils/tribool.h" namespace paddle { namespace framework { @@ -80,6 +80,7 @@ class PlacementPassTest { "l"})) { auto* var = prog.MutableBlock(0)->Var(v); var->SetType(proto::VarType::SELECTED_ROWS); + var->SetDataType(framework::proto::VarType::FP32); if (v == "weights" || v == "bias") { var->SetPersistable(true); } @@ -129,7 +130,7 @@ class PlacementPassTest { void MainTest(std::initializer_list mkldnn_enabled_op_types, unsigned expected_use_mkldnn_true_count) { auto prog = BuildProgramDesc(); - + RegisterOpKernel({"conv2d", "pool2d", "concat", "relu"}); std::unique_ptr graph(new ir::Graph(prog)); auto pass = PassRegistry::Instance().Get("mkldnn_placement_pass"); @@ -162,8 +163,8 @@ class PlacementPassTest { }; TEST(MKLDNNPlacementPass, enable_conv_relu) { - // 1 conv (1 conv is always true) + 2 relu (1 relu is always true) + 0 pool - PlacementPassTest().MainTest({"conv2d", "relu"}, 3); + // 2 conv (1 conv is always true) + 2 relu (1 relu is always true) + 0 pool + PlacementPassTest().MainTest({"conv2d", "relu"}, 4); } TEST(MKLDNNPlacementPass, enable_relu_pool) { @@ -172,8 +173,9 @@ TEST(MKLDNNPlacementPass, enable_relu_pool) { } TEST(MKLDNNPlacementPass, enable_all) { - // 1 conv (1 conv is always true) + 2 relu (1 relu is always true) + 1 pool - PlacementPassTest().MainTest({}, 4); + // 2 conv (1 conv is always true) + 2 relu (1 relu is always true) + 1 pool + + // 1 concat + PlacementPassTest().MainTest({}, 6); } TEST(MKLDNNPlacementPass, placement_name) { diff --git a/paddle/fluid/framework/ir/pass_tester_helper.h b/paddle/fluid/framework/ir/pass_tester_helper.h index 589cf4d0d192d9cea54d2535bdfe45d5a022bd67..dc423d9d17dac174d67ce18c93a850a1177e8519 100644 --- a/paddle/fluid/framework/ir/pass_tester_helper.h +++ b/paddle/fluid/framework/ir/pass_tester_helper.h @@ -933,6 +933,21 @@ static int GetNumOpNodes(const std::unique_ptr& graph, return num_nodes; } +static void RegisterOpKernel(std::vector&& op_types) { + auto& all_kernels = OperatorWithKernel::AllOpKernels(); + + platform::CPUPlace place = platform::CPUPlace(); + OpKernelType mkldnn_kernel_type = OpKernelType(proto::VarType::FP32, + place, + DataLayout::kAnyLayout, + LibraryType::kMKLDNN); + + auto fake_kernel_func = [](const ExecutionContext&) -> void {}; + + for (auto& op_name : op_types) + all_kernels[op_name][mkldnn_kernel_type] = fake_kernel_func; +} + } // namespace ir } // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/ir/placement_pass_base.cc b/paddle/fluid/framework/ir/placement_pass_base.cc index ad062dd735a1d5cd5db99b2ffb3c91f99e20c21c..ccf2bf22ab57bd0e54f10648f978b5c6ad39524f 100644 --- a/paddle/fluid/framework/ir/placement_pass_base.cc +++ b/paddle/fluid/framework/ir/placement_pass_base.cc @@ -32,13 +32,10 @@ void PlacementPassBase::ApplyImpl(ir::Graph* graph) const { for (const Node* n : graph->Nodes()) { if (n->IsOp()) { auto* op = n->Op(); - if ((op->HasAttr(attr_name) || op->HasProtoAttr(attr_name)) && - IsSupport(op->Type())) { - if (op_types_list.empty() && IsDefaultOpTypes(op->Type())) { - op->SetAttr(attr_name, true); - } else if (std::find(op_types_list.begin(), - op_types_list.end(), - n->Name()) != op_types_list.end()) { + if (IsSupport(n)) { + if (op_types_list.empty() || + std::find(op_types_list.begin(), op_types_list.end(), n->Name()) != + op_types_list.end()) { op->SetAttr(attr_name, true); } } @@ -46,53 +43,6 @@ void PlacementPassBase::ApplyImpl(ir::Graph* graph) const { } } -bool PlacementPassBase::IsSupport(const std::string& op_type) const { - if (GetAttrName() == "use_cudnn") { - auto& all_kernels = OperatorWithKernel::AllOpKernels(); - auto it = all_kernels.find(op_type); - if (it == all_kernels.end()) { - // All control operators don't have kernel. - return false; - } - for (auto& kernel_pair : it->second) { - if (platform::is_gpu_place(kernel_pair.first.place_) && - (kernel_pair.first.library_type_ == LibraryType::kCUDNN)) { - return true; - } - } - } else if (GetAttrName() == "use_mkldnn") { - // This ops have use_mkldnn attr, but not support for now. - const std::vector op_types = { - "trilinear_interp", "bicubic_interp", "linear_interp"}; - return std::find(op_types.begin(), op_types.end(), op_type) == - op_types.end(); - } - return false; -} - -bool PlacementPassBase::IsDefaultOpTypes(const std::string& op_type) const { - if (GetAttrName() == "use_cudnn") { - return true; - } else if (GetAttrName() == "use_mkldnn") { - // For interpolate ops, there's a little difference between Paddle and - // MKLDNN. - // If run MKLDNN interpolate ops, manual set AnalysisConfig and apply - // the corresponding pass. - const std::vector not_default_op_types = {"bilinear_interp", - "nearest_interp", - "trilinear_interp", - "bicubic_interp", - "linear_interp", - "bilinear_interp_v2", - "linear_interp_v2"}; - bool is_interpolate_op = std::find(not_default_op_types.begin(), - not_default_op_types.end(), - op_type) != not_default_op_types.end(); - return !is_interpolate_op; - } - return false; -} - } // namespace ir } // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/ir/placement_pass_base.h b/paddle/fluid/framework/ir/placement_pass_base.h index 6927c031dcca38a1e1fd9153963b7646b0dbd32d..5254ca976507bd78723195d51b55dc3438aaa6b1 100644 --- a/paddle/fluid/framework/ir/placement_pass_base.h +++ b/paddle/fluid/framework/ir/placement_pass_base.h @@ -35,10 +35,7 @@ class PlacementPassBase : public Pass { virtual const std::string GetPlacementName() const = 0; virtual const std::string GetAttrName() const = 0; virtual const std::unordered_set GetOpTypesList() const = 0; - - private: - bool IsSupport(const std::string& op_type) const; - bool IsDefaultOpTypes(const std::string& op_type) const; + virtual bool IsSupport(const Node* op) const = 0; #if PADDLE_WITH_TESTING friend class PlacementPassTest; diff --git a/paddle/fluid/inference/api/CMakeLists.txt b/paddle/fluid/inference/api/CMakeLists.txt index 6c6c18a88cb6fa1531960eef1b166e63bf4ee7d7..c3eee6888a7d8ff30ab19cace59c7c725237c50a 100755 --- a/paddle/fluid/inference/api/CMakeLists.txt +++ b/paddle/fluid/inference/api/CMakeLists.txt @@ -115,17 +115,15 @@ endif() if(WITH_TESTING) if(NOT APPLE AND NOT WIN32) - if(WITH_GPU) - inference_base_test( - test_api_impl - SRCS - api_impl_tester.cc - DEPS - paddle_inference_shared - ARGS - --word2vec_dirname=${WORD2VEC_MODEL_DIR} - --book_dirname=${IMG_CLS_RESNET_INSTALL_DIR}) - endif() + inference_base_test( + test_api_impl + SRCS + api_impl_tester.cc + DEPS + paddle_inference_shared + ARGS + --word2vec_dirname=${WORD2VEC_MODEL_DIR} + --book_dirname=${IMG_CLS_RESNET_INSTALL_DIR}) elseif(WIN32) inference_base_test( test_api_impl @@ -137,7 +135,6 @@ if(WITH_TESTING) --word2vec_dirname=${WORD2VEC_MODEL_DIR} --book_dirname=${IMG_CLS_RESNET_INSTALL_DIR}) endif() - endif() if(NOT APPLE AND NOT WIN32) diff --git a/paddle/fluid/inference/api/api_impl_tester.cc b/paddle/fluid/inference/api/api_impl_tester.cc index 4993a17bc2b92998088af75d65ebc884915f9e62..67dc193feed09fe34213d42774842e768f670faa 100644 --- a/paddle/fluid/inference/api/api_impl_tester.cc +++ b/paddle/fluid/inference/api/api_impl_tester.cc @@ -331,6 +331,18 @@ TEST(inference_api_native, image_classification_gpu) { // } #endif +#ifdef PADDLE_WITH_MKLDNN +TEST(inference_api_native, image_classification_cpu_onednn) { + FLAGS_use_mkldnn = true; + MainImageClassification(paddle::PaddlePlace::kCPU); +} + +TEST(inference_api_native, word2vec_cpu_onednn) { + FLAGS_use_mkldnn = true; + MainWord2Vec(paddle::PaddlePlace::kCPU); +} +#endif + TEST(PassBuilder, Delete) { AnalysisConfig config; config.DisableGpu(); diff --git a/paddle/fluid/platform/mkldnn_helper.h b/paddle/fluid/platform/mkldnn_helper.h index 86a8a24e84673ca78682527f0e88b743b8120a10..60eccdbb80c193b32d8f25c3bed0092448f345c2 100644 --- a/paddle/fluid/platform/mkldnn_helper.h +++ b/paddle/fluid/platform/mkldnn_helper.h @@ -140,4 +140,32 @@ inline std::string FindOutputNameByVarName(framework::OpDesc* op, if (output_name == searched_name) ret = name; return ret; } + +inline bool FoundOneDNNKernel(const framework::OpDesc* op) { + auto op_type = op->Type(); + auto& all_kernels = framework::OperatorWithKernel::AllOpKernels(); + auto it = all_kernels.find(op_type); + if (it != all_kernels.end()) { + for (auto& kernel_pair : it->second) { + if (platform::is_cpu_place(kernel_pair.first.place_) && + (kernel_pair.first.library_type_ == + framework::LibraryType::kMKLDNN)) { + return true; + } + } + } + return false; +} + +inline bool FoundPhiOneDNNKernel(const framework::OpDesc* op) { + auto op_type = op->Type(); + auto phi_kernels = phi::KernelFactory::Instance().SelectKernelMap( + phi::TransToPhiKernelName(op_type)); + + for (auto& kernel_pair : phi_kernels) + if (kernel_pair.first.backend() == phi::Backend::ONEDNN) return true; + + return false; +} + } // namespace paddle