From 09d9d77a8fe658694f5b9075fca6146f3b655ebf Mon Sep 17 00:00:00 2001 From: Michal Gallus Date: Mon, 1 Oct 2018 14:28:53 +0200 Subject: [PATCH] Enable MKLDNN in Naive Executor test=develop --- paddle/fluid/framework/naive_executor.cc | 17 +++++++++++++++++ paddle/fluid/framework/naive_executor.h | 4 ++++ .../fluid/inference/api/analysis_predictor.cc | 6 ++++++ .../inference/tests/api/analyzer_vis_tester.cc | 2 -- 4 files changed, 27 insertions(+), 2 deletions(-) diff --git a/paddle/fluid/framework/naive_executor.cc b/paddle/fluid/framework/naive_executor.cc index 53d39513f36..ba10687d65c 100644 --- a/paddle/fluid/framework/naive_executor.cc +++ b/paddle/fluid/framework/naive_executor.cc @@ -146,5 +146,22 @@ void NaiveExecutor::CleanFeedFetchOps() { ops_.swap(ops); } +void NaiveExecutor::EnableMKLDNN(const ProgramDesc &program) { +#ifdef PADDLE_WITH_MKLDNN + VLOG(3) << "use_mkldnn=True"; + for (size_t block_id = 0; block_id < program.Size(); ++block_id) { + auto *block = const_cast(program).MutableBlock(block_id); + for (auto *op : block->AllOps()) { + if (op->HasAttr("use_mkldnn")) { + op->SetAttr("use_mkldnn", true); + } + } + } +#else + LOG(WARNING) + << "'MKLDNN' is not supported, Please re-compile with WITH_MKLDNN option"; +#endif +} + } // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/naive_executor.h b/paddle/fluid/framework/naive_executor.h index 9355e9e36a6..9374f3f4a35 100644 --- a/paddle/fluid/framework/naive_executor.h +++ b/paddle/fluid/framework/naive_executor.h @@ -14,6 +14,8 @@ #pragma once +#include +#include #include "paddle/fluid/framework/operator.h" #include "paddle/fluid/framework/program_desc.h" #include "paddle/fluid/framework/scope.h" @@ -46,6 +48,8 @@ class NaiveExecutor { void CleanFeedFetchOps(); + void EnableMKLDNN(const ProgramDesc& program); + protected: void CreateVariables(const ProgramDesc& desc, Scope* scope, int block_id); diff --git a/paddle/fluid/inference/api/analysis_predictor.cc b/paddle/fluid/inference/api/analysis_predictor.cc index a153433d29b..3bc6af5241c 100644 --- a/paddle/fluid/inference/api/analysis_predictor.cc +++ b/paddle/fluid/inference/api/analysis_predictor.cc @@ -71,6 +71,11 @@ bool AnalysisPredictor::Init( } else { inference_program_ = program; } + + if (config_._use_mkldnn) { + executor_->EnableMKLDNN(*inference_program_); + } + executor_->Prepare(scope_.get(), *inference_program_, 0, config_.use_feed_fetch_ops); @@ -92,6 +97,7 @@ bool AnalysisPredictor::Run(const std::vector &inputs, LOG(ERROR) << "fail to set feed"; return false; } + // Run the inference program // if share variables, we need not create variables executor_->Run(); diff --git a/paddle/fluid/inference/tests/api/analyzer_vis_tester.cc b/paddle/fluid/inference/tests/api/analyzer_vis_tester.cc index a2e86305b85..305b8bfe158 100644 --- a/paddle/fluid/inference/tests/api/analyzer_vis_tester.cc +++ b/paddle/fluid/inference/tests/api/analyzer_vis_tester.cc @@ -61,8 +61,6 @@ void SetConfig(AnalysisConfig *cfg) { cfg->ir_passes.push_back("fc_gru_fuse_pass"); #ifdef PADDLE_WITH_MKLDNN cfg->_use_mkldnn = true; - // disable mkldnn fuse since it should have some bugs - cfg->ir_passes.push_back("conv_relu_mkldnn_fuse_pass"); #endif } -- GitLab