提交 93e99054 编写于 作者: L Liu Yiqun

Add unittest for calling CreateVariables manually.

上级 a9855e4a
...@@ -46,8 +46,8 @@ TEST(inference, image_classification) { ...@@ -46,8 +46,8 @@ TEST(inference, image_classification) {
// Run inference on CPU // Run inference on CPU
LOG(INFO) << "--- CPU Runs: ---"; LOG(INFO) << "--- CPU Runs: ---";
TestInference<paddle::platform::CPUPlace>(dirname, cpu_feeds, cpu_fetchs1, TestInference<paddle::platform::CPUPlace, false>(dirname, cpu_feeds,
FLAGS_repeat); cpu_fetchs1, FLAGS_repeat);
LOG(INFO) << output1.dims(); LOG(INFO) << output1.dims();
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
...@@ -57,8 +57,8 @@ TEST(inference, image_classification) { ...@@ -57,8 +57,8 @@ TEST(inference, image_classification) {
// Run inference on CUDA GPU // Run inference on CUDA GPU
LOG(INFO) << "--- GPU Runs: ---"; LOG(INFO) << "--- GPU Runs: ---";
TestInference<paddle::platform::CUDAPlace>(dirname, cpu_feeds, cpu_fetchs2, TestInference<paddle::platform::CUDAPlace, false>(dirname, cpu_feeds,
FLAGS_repeat); cpu_fetchs2, FLAGS_repeat);
LOG(INFO) << output2.dims(); LOG(INFO) << output2.dims();
CheckError<float>(output1, output2); CheckError<float>(output1, output2);
......
...@@ -88,7 +88,7 @@ void CheckError(const paddle::framework::LoDTensor& output1, ...@@ -88,7 +88,7 @@ void CheckError(const paddle::framework::LoDTensor& output1,
EXPECT_EQ(count, 0U) << "There are " << count << " different elements."; EXPECT_EQ(count, 0U) << "There are " << count << " different elements.";
} }
template <typename Place> template <typename Place, bool CreateVars = true>
void TestInference(const std::string& dirname, void TestInference(const std::string& dirname,
const std::vector<paddle::framework::LoDTensor*>& cpu_feeds, const std::vector<paddle::framework::LoDTensor*>& cpu_feeds,
const std::vector<paddle::framework::LoDTensor*>& cpu_fetchs, const std::vector<paddle::framework::LoDTensor*>& cpu_fetchs,
...@@ -166,14 +166,16 @@ void TestInference(const std::string& dirname, ...@@ -166,14 +166,16 @@ void TestInference(const std::string& dirname,
// 6. Run the inference program // 6. Run the inference program
{ {
const bool create_vars = false; if (!CreateVars) {
if (!create_vars) { // If users don't want to create and destroy variables every time they
// run, they need to set `create_vars` to false and manually call
// `CreateVariables` before running.
executor.CreateVariables(*inference_program, scope, 0); executor.CreateVariables(*inference_program, scope, 0);
} }
// Ignore the profiling results of the first run // Ignore the profiling results of the first run
executor.Run( executor.Run(*inference_program, scope, feed_targets, fetch_targets,
*inference_program, scope, feed_targets, fetch_targets, create_vars); CreateVars);
// Enable the profiler // Enable the profiler
paddle::platform::EnableProfiler(state); paddle::platform::EnableProfiler(state);
...@@ -184,8 +186,8 @@ void TestInference(const std::string& dirname, ...@@ -184,8 +186,8 @@ void TestInference(const std::string& dirname,
"run_inference", "run_inference",
paddle::platform::DeviceContextPool::Instance().Get(place)); paddle::platform::DeviceContextPool::Instance().Get(place));
executor.Run( executor.Run(*inference_program, scope, feed_targets, fetch_targets,
*inference_program, scope, feed_targets, fetch_targets, create_vars); CreateVars);
} }
// Disable the profiler and print the timing information // Disable the profiler and print the timing information
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册