diff --git a/CMakeLists.txt b/CMakeLists.txt index 030bd19b3fd2f561a847bbc4613e5d2030812a92..5ca9bc51ccd7f30eb170a2ae5fdd5f55e0e74364 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -117,13 +117,14 @@ else() endif() set(WITH_MKLML ${WITH_MKL}) -if (WITH_MKL AND AVX2_FOUND) - set(WITH_MKLDNN ON) -else() - message(STATUS "Do not have AVX2 intrinsics and disabled MKL-DNN") - set(WITH_MKLDNN OFF) +if (NOT DEFINED WITH_MKLDNN) + if (WITH_MKL AND AVX2_FOUND) + set(WITH_MKLDNN ON) + else() + message(STATUS "Do not have AVX2 intrinsics and disabled MKL-DNN") + set(WITH_MKLDNN OFF) + endif() endif() - ######################################################################################## include(external/mklml) # download mklml package diff --git a/paddle/fluid/inference/tests/book/test_inference_image_classification.cc b/paddle/fluid/inference/tests/book/test_inference_image_classification.cc index 60c761c5281e2f535aab0200c93fb738addcdb87..987da18116cc6f4902bd66ae317f2470a8bc5057 100644 --- a/paddle/fluid/inference/tests/book/test_inference_image_classification.cc +++ b/paddle/fluid/inference/tests/book/test_inference_image_classification.cc @@ -21,6 +21,7 @@ DEFINE_string(fp16_dirname, "", "Directory of the float16 inference model."); DEFINE_int32(batch_size, 1, "Batch size of input data"); DEFINE_int32(repeat, 1, "Running the inference program repeat times"); DEFINE_bool(skip_cpu, false, "Skip the cpu test"); +DEFINE_bool(use_mkldnn, false, "Use MKLDNN to run inference"); TEST(inference, image_classification) { if (FLAGS_dirname.empty() || FLAGS_batch_size < 1 || FLAGS_repeat < 1) { @@ -58,8 +59,10 @@ TEST(inference, image_classification) { // Run inference on CPU LOG(INFO) << "--- CPU Runs: ---"; LOG(INFO) << "Batch size is " << FLAGS_batch_size; + LOG(INFO) << "FLAGS_use_mkldnn: " << FLAGS_use_mkldnn; TestInference( - dirname, cpu_feeds, cpu_fetchs1, FLAGS_repeat, is_combined); + dirname, cpu_feeds, cpu_fetchs1, FLAGS_repeat, is_combined, + FLAGS_use_mkldnn); LOG(INFO) << output1.dims(); } diff --git a/paddle/fluid/inference/tests/test_helper.h b/paddle/fluid/inference/tests/test_helper.h index cc1589514aab3b973b4909159748bc4223cdce46..01b8dc0be662da22fe15a79cd9abfe5fa92c9577 100644 --- a/paddle/fluid/inference/tests/test_helper.h +++ b/paddle/fluid/inference/tests/test_helper.h @@ -133,11 +133,24 @@ std::vector> GetFeedTargetShapes( return feed_target_shapes; } +void EnableMKLDNN( + const std::unique_ptr& program) { + for (size_t bid = 0; bid < program->Size(); ++bid) { + auto* block = program->MutableBlock(bid); + for (auto* op : block->AllOps()) { + if (op->HasAttr("use_mkldnn")) { + op->SetAttr("use_mkldnn", true); + } + } + } +} + template void TestInference(const std::string& dirname, const std::vector& cpu_feeds, const std::vector& cpu_fetchs, - const int repeat = 1, const bool is_combined = false) { + const int repeat = 1, const bool is_combined = false, + const bool use_mkldnn = false) { // 1. Define place, executor, scope auto place = Place(); auto executor = paddle::framework::Executor(place); @@ -169,6 +182,9 @@ void TestInference(const std::string& dirname, "init_program", paddle::platform::DeviceContextPool::Instance().Get(place)); inference_program = InitProgram(&executor, scope, dirname, is_combined); + if (use_mkldnn) { + EnableMKLDNN(inference_program); + } } // Disable the profiler and print the timing information paddle::platform::DisableProfiler(paddle::platform::EventSortingKey::kDefault,