From 93e9905482599fef6ea5cf925429f6786f8d2808 Mon Sep 17 00:00:00 2001 From: Liu Yiqun Date: Mon, 9 Apr 2018 03:58:10 +0000 Subject: [PATCH] Add unittest for calling CreateVariables manually. --- .../book/test_inference_image_classification.cc | 8 ++++---- paddle/fluid/inference/tests/test_helper.h | 16 +++++++++------- 2 files changed, 13 insertions(+), 11 deletions(-) diff --git a/paddle/fluid/inference/tests/book/test_inference_image_classification.cc b/paddle/fluid/inference/tests/book/test_inference_image_classification.cc index a6b6c3f82..ca2077d07 100644 --- a/paddle/fluid/inference/tests/book/test_inference_image_classification.cc +++ b/paddle/fluid/inference/tests/book/test_inference_image_classification.cc @@ -46,8 +46,8 @@ TEST(inference, image_classification) { // Run inference on CPU LOG(INFO) << "--- CPU Runs: ---"; - TestInference(dirname, cpu_feeds, cpu_fetchs1, - FLAGS_repeat); + TestInference(dirname, cpu_feeds, + cpu_fetchs1, FLAGS_repeat); LOG(INFO) << output1.dims(); #ifdef PADDLE_WITH_CUDA @@ -57,8 +57,8 @@ TEST(inference, image_classification) { // Run inference on CUDA GPU LOG(INFO) << "--- GPU Runs: ---"; - TestInference(dirname, cpu_feeds, cpu_fetchs2, - FLAGS_repeat); + TestInference(dirname, cpu_feeds, + cpu_fetchs2, FLAGS_repeat); LOG(INFO) << output2.dims(); CheckError(output1, output2); diff --git a/paddle/fluid/inference/tests/test_helper.h b/paddle/fluid/inference/tests/test_helper.h index 972dc894b..aae34ceda 100644 --- a/paddle/fluid/inference/tests/test_helper.h +++ b/paddle/fluid/inference/tests/test_helper.h @@ -88,7 +88,7 @@ void CheckError(const paddle::framework::LoDTensor& output1, EXPECT_EQ(count, 0U) << "There are " << count << " different elements."; } -template +template void TestInference(const std::string& dirname, const std::vector& cpu_feeds, const std::vector& cpu_fetchs, @@ -166,14 +166,16 @@ void TestInference(const std::string& dirname, // 6. Run the inference program { - const bool create_vars = false; - if (!create_vars) { + if (!CreateVars) { + // If users don't want to create and destroy variables every time they + // run, they need to set `create_vars` to false and manually call + // `CreateVariables` before running. executor.CreateVariables(*inference_program, scope, 0); } // Ignore the profiling results of the first run - executor.Run( - *inference_program, scope, feed_targets, fetch_targets, create_vars); + executor.Run(*inference_program, scope, feed_targets, fetch_targets, + CreateVars); // Enable the profiler paddle::platform::EnableProfiler(state); @@ -184,8 +186,8 @@ void TestInference(const std::string& dirname, "run_inference", paddle::platform::DeviceContextPool::Instance().Get(place)); - executor.Run( - *inference_program, scope, feed_targets, fetch_targets, create_vars); + executor.Run(*inference_program, scope, feed_targets, fetch_targets, + CreateVars); } // Disable the profiler and print the timing information -- GitLab