From 661826a70a8bb6e0d6279b23431dfa5d021e03e6 Mon Sep 17 00:00:00 2001
From: tensor-tang
Date: Wed, 16 May 2018 17:35:51 +0800
Subject: [PATCH] enable MKLDNN inference test

---
 .../test_inference_image_classification.cc |  5 ++++-
 paddle/fluid/inference/tests/test_helper.h | 18 +++++++++++++++++-
 2 files changed, 21 insertions(+), 2 deletions(-)

diff --git a/paddle/fluid/inference/tests/book/test_inference_image_classification.cc b/paddle/fluid/inference/tests/book/test_inference_image_classification.cc
index 60c761c5281..987da18116c 100644
--- a/paddle/fluid/inference/tests/book/test_inference_image_classification.cc
+++ b/paddle/fluid/inference/tests/book/test_inference_image_classification.cc
@@ -21,6 +21,7 @@ DEFINE_string(fp16_dirname, "", "Directory of the float16 inference model.");
 DEFINE_int32(batch_size, 1, "Batch size of input data");
 DEFINE_int32(repeat, 1, "Running the inference program repeat times");
 DEFINE_bool(skip_cpu, false, "Skip the cpu test");
+DEFINE_bool(use_mkldnn, false, "Use MKLDNN to run inference");
 
 TEST(inference, image_classification) {
   if (FLAGS_dirname.empty() || FLAGS_batch_size < 1 || FLAGS_repeat < 1) {
@@ -58,8 +59,10 @@ TEST(inference, image_classification) {
     // Run inference on CPU
     LOG(INFO) << "--- CPU Runs: ---";
     LOG(INFO) << "Batch size is " << FLAGS_batch_size;
+    LOG(INFO) << "FLAGS_use_mkldnn: " << FLAGS_use_mkldnn;
     TestInference<paddle::platform::CPUPlace, false, true>(
-        dirname, cpu_feeds, cpu_fetchs1, FLAGS_repeat, is_combined);
+        dirname, cpu_feeds, cpu_fetchs1, FLAGS_repeat, is_combined,
+        FLAGS_use_mkldnn);
     LOG(INFO) << output1.dims();
   }
 
diff --git a/paddle/fluid/inference/tests/test_helper.h b/paddle/fluid/inference/tests/test_helper.h
index b02e5c99f00..61334a10dea 100644
--- a/paddle/fluid/inference/tests/test_helper.h
+++ b/paddle/fluid/inference/tests/test_helper.h
@@ -133,11 +133,24 @@ std::vector<std::vector<int64_t>> GetFeedTargetShapes(
   return feed_target_shapes;
 }
 
+void EnableMKLDNN(
+    const std::unique_ptr<paddle::framework::ProgramDesc>& program) {
+  for (size_t bid = 0; bid < program->Size(); ++bid) {
+    auto* block = program->MutableBlock(bid);
+    for (auto* op : block->AllOps()) {
+      if (op->HasAttr("use_mkldnn")) {
+        op->SetAttr("use_mkldnn", true);
+      }
+    }
+  }
+}
+
 template <typename Place, bool CreateVars = true, bool PrepareContext = false>
 void TestInference(const std::string& dirname,
                    const std::vector<paddle::framework::LoDTensor*>& cpu_feeds,
                    const std::vector<paddle::framework::LoDTensor*>& cpu_fetchs,
-                   const int repeat = 1, const bool is_combined = false) {
+                   const int repeat = 1, const bool is_combined = false,
+                   const bool use_mkldnn = false) {
   // 1. Define place, executor, scope
   auto place = Place();
   auto executor = paddle::framework::Executor(place);
@@ -169,6 +182,9 @@ void TestInference(const std::string& dirname,
         "init_program",
         paddle::platform::DeviceContextPool::Instance().Get(place));
     inference_program = InitProgram(&executor, scope, dirname, is_combined);
+    if (use_mkldnn) {
+      EnableMKLDNN(inference_program);
+    }
   }
   // Disable the profiler and print the timing information
   paddle::platform::DisableProfiler(paddle::platform::EventSortingKey::kDefault,
-- 
GitLab
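
A hypothetical way to exercise the new flag once the test is built (the
binary path and model directory below are assumptions; both depend on the
local build layout and test data):

  ./test_inference_image_classification \
      --dirname=/path/to/image_classification.inference.model \
      --batch_size=1 --repeat=10 --use_mkldnn=true

Note that EnableMKLDNN only flips the attribute on operators that already
declare a use_mkldnn attribute, so operators without an MKLDNN kernel keep
running their default CPU kernels.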