diff --git a/paddle/fluid/framework/naive_executor.cc b/paddle/fluid/framework/naive_executor.cc index 53d39513f3686cea59e2d56ff62eec9869f3b2de..ba10687d65cfbbac89cfc76879c8b202ebd03229 100644 --- a/paddle/fluid/framework/naive_executor.cc +++ b/paddle/fluid/framework/naive_executor.cc @@ -146,5 +146,22 @@ void NaiveExecutor::CleanFeedFetchOps() { ops_.swap(ops); } +void NaiveExecutor::EnableMKLDNN(const ProgramDesc &program) { +#ifdef PADDLE_WITH_MKLDNN + VLOG(3) << "use_mkldnn=True"; + for (size_t block_id = 0; block_id < program.Size(); ++block_id) { + auto *block = const_cast(program).MutableBlock(block_id); + for (auto *op : block->AllOps()) { + if (op->HasAttr("use_mkldnn")) { + op->SetAttr("use_mkldnn", true); + } + } + } +#else + LOG(WARNING) + << "'MKLDNN' is not supported, Please re-compile with WITH_MKLDNN option"; +#endif +} + } // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/naive_executor.h b/paddle/fluid/framework/naive_executor.h index 9355e9e36a6358aa91553dca35aaf1b658516a0a..9374f3f4a35cc0f90e5b2d6e8b397784b8eae123 100644 --- a/paddle/fluid/framework/naive_executor.h +++ b/paddle/fluid/framework/naive_executor.h @@ -14,6 +14,8 @@ #pragma once +#include +#include #include "paddle/fluid/framework/operator.h" #include "paddle/fluid/framework/program_desc.h" #include "paddle/fluid/framework/scope.h" @@ -46,6 +48,8 @@ class NaiveExecutor { void CleanFeedFetchOps(); + void EnableMKLDNN(const ProgramDesc& program); + protected: void CreateVariables(const ProgramDesc& desc, Scope* scope, int block_id); diff --git a/paddle/fluid/inference/api/analysis_predictor.cc b/paddle/fluid/inference/api/analysis_predictor.cc index a153433d29b6fef7abdbf7b7b446bad40c1d71e6..3bc6af5241c41bd805699121d614d431d46d863f 100644 --- a/paddle/fluid/inference/api/analysis_predictor.cc +++ b/paddle/fluid/inference/api/analysis_predictor.cc @@ -71,6 +71,11 @@ bool AnalysisPredictor::Init( } else { inference_program_ = program; } + + if (config_._use_mkldnn) { + executor_->EnableMKLDNN(*inference_program_); + } + executor_->Prepare(scope_.get(), *inference_program_, 0, config_.use_feed_fetch_ops); @@ -92,6 +97,7 @@ bool AnalysisPredictor::Run(const std::vector &inputs, LOG(ERROR) << "fail to set feed"; return false; } + // Run the inference program // if share variables, we need not create variables executor_->Run(); diff --git a/paddle/fluid/inference/tests/api/analyzer_vis_tester.cc b/paddle/fluid/inference/tests/api/analyzer_vis_tester.cc index a2e86305b85dd893f578e97e0105fec828916fb4..305b8bfe158150d5dfd8bdaee2c0a89afe264de4 100644 --- a/paddle/fluid/inference/tests/api/analyzer_vis_tester.cc +++ b/paddle/fluid/inference/tests/api/analyzer_vis_tester.cc @@ -61,8 +61,6 @@ void SetConfig(AnalysisConfig *cfg) { cfg->ir_passes.push_back("fc_gru_fuse_pass"); #ifdef PADDLE_WITH_MKLDNN cfg->_use_mkldnn = true; - // disable mkldnn fuse since it should have some bugs - cfg->ir_passes.push_back("conv_relu_mkldnn_fuse_pass"); #endif }