diff --git a/paddle/fluid/inference/tests/book/test_inference_image_classification.cc b/paddle/fluid/inference/tests/book/test_inference_image_classification.cc
index a6b6c3f828f4c6f59fca42e4c3d9580d6c136524..ca2077d07411d2cd6095e0dc2a874af0890145c5 100644
--- a/paddle/fluid/inference/tests/book/test_inference_image_classification.cc
+++ b/paddle/fluid/inference/tests/book/test_inference_image_classification.cc
@@ -46,8 +46,8 @@ TEST(inference, image_classification) {
 
   // Run inference on CPU
   LOG(INFO) << "--- CPU Runs: ---";
-  TestInference<paddle::platform::CPUPlace>(dirname, cpu_feeds, cpu_fetchs1,
-                                            FLAGS_repeat);
+  TestInference<paddle::platform::CPUPlace, false>(dirname, cpu_feeds,
+                                                   cpu_fetchs1, FLAGS_repeat);
   LOG(INFO) << output1.dims();
 
 #ifdef PADDLE_WITH_CUDA
@@ -57,8 +57,8 @@ TEST(inference, image_classification) {
 
   // Run inference on CUDA GPU
   LOG(INFO) << "--- GPU Runs: ---";
-  TestInference<paddle::platform::CUDAPlace>(dirname, cpu_feeds, cpu_fetchs2,
-                                             FLAGS_repeat);
+  TestInference<paddle::platform::CUDAPlace, false>(dirname, cpu_feeds,
+                                                    cpu_fetchs2, FLAGS_repeat);
   LOG(INFO) << output2.dims();
 
   CheckError(output1, output2);
diff --git a/paddle/fluid/inference/tests/test_helper.h b/paddle/fluid/inference/tests/test_helper.h
index 972dc894b2ebf5daf55b62287ea05235af6acf63..aae34ceda07fea6e881cf61b3755ec45d1d6f2dc 100644
--- a/paddle/fluid/inference/tests/test_helper.h
+++ b/paddle/fluid/inference/tests/test_helper.h
@@ -88,7 +88,7 @@ void CheckError(const paddle::framework::LoDTensor& output1,
   EXPECT_EQ(count, 0U) << "There are " << count << " different elements.";
 }
 
-template <typename Place>
+template <typename Place, bool CreateVars = true>
 void TestInference(const std::string& dirname,
                    const std::vector<paddle::framework::LoDTensor*>& cpu_feeds,
                    const std::vector<paddle::framework::LoDTensor*>& cpu_fetchs,
@@ -166,14 +166,16 @@ void TestInference(const std::string& dirname,
 
   // 6. Run the inference program
   {
-    const bool create_vars = false;
-    if (!create_vars) {
+    if (!CreateVars) {
+      // If users don't want to create and destroy variables every time they
+      // run, they need to set `create_vars` to false and manually call
+      // `CreateVariables` before running.
       executor.CreateVariables(*inference_program, scope, 0);
     }
 
     // Ignore the profiling results of the first run
-    executor.Run(
-        *inference_program, scope, feed_targets, fetch_targets, create_vars);
+    executor.Run(*inference_program, scope, feed_targets, fetch_targets,
+                 CreateVars);
 
     // Enable the profiler
     paddle::platform::EnableProfiler(state);
@@ -184,8 +186,8 @@ void TestInference(const std::string& dirname,
         "run_inference",
         paddle::platform::DeviceContextPool::Instance().Get(place));
 
-    executor.Run(
-        *inference_program, scope, feed_targets, fetch_targets, create_vars);
+    executor.Run(*inference_program, scope, feed_targets, fetch_targets,
+                 CreateVars);
   }
 
   // Disable the profiler and print the timing information
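
Note on the new `CreateVars` template parameter (a sketch of the usage pattern, not part of the patch itself): with the default `CreateVars = true`, `executor.Run` creates and destroys the scope's variables on every call; instantiating the helper with `false` makes it call `executor.CreateVariables` once up front, and each subsequent `Run` reuses those variables, which is what the repeated profiling runs above rely on. The two call styles side by side:

    // Default (CreateVars = true): variables are created and destroyed
    // inside every executor.Run call -- convenient, but slower when the
    // same program is run FLAGS_repeat times.
    TestInference<paddle::platform::CPUPlace>(dirname, cpu_feeds, cpu_fetchs1,
                                              FLAGS_repeat);

    // CreateVars = false: the helper calls executor.CreateVariables(...)
    // once before the loop, and every executor.Run(...) reuses the
    // existing variables.
    TestInference<paddle::platform::CPUPlace, false>(dirname, cpu_feeds,
                                                     cpu_fetchs1, FLAGS_repeat);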