提交 93e99054 编写于 作者: L Liu Yiqun

Add unittest for calling CreateVariables manually.

上级 a9855e4a
......@@ -46,8 +46,8 @@ TEST(inference, image_classification) {
// Run inference on CPU
LOG(INFO) << "--- CPU Runs: ---";
TestInference<paddle::platform::CPUPlace>(dirname, cpu_feeds, cpu_fetchs1,
FLAGS_repeat);
TestInference<paddle::platform::CPUPlace, false>(dirname, cpu_feeds,
cpu_fetchs1, FLAGS_repeat);
LOG(INFO) << output1.dims();
#ifdef PADDLE_WITH_CUDA
......@@ -57,8 +57,8 @@ TEST(inference, image_classification) {
// Run inference on CUDA GPU
LOG(INFO) << "--- GPU Runs: ---";
TestInference<paddle::platform::CUDAPlace>(dirname, cpu_feeds, cpu_fetchs2,
FLAGS_repeat);
TestInference<paddle::platform::CUDAPlace, false>(dirname, cpu_feeds,
cpu_fetchs2, FLAGS_repeat);
LOG(INFO) << output2.dims();
CheckError<float>(output1, output2);
......
......@@ -88,7 +88,7 @@ void CheckError(const paddle::framework::LoDTensor& output1,
EXPECT_EQ(count, 0U) << "There are " << count << " different elements.";
}
template <typename Place>
template <typename Place, bool CreateVars = true>
void TestInference(const std::string& dirname,
const std::vector<paddle::framework::LoDTensor*>& cpu_feeds,
const std::vector<paddle::framework::LoDTensor*>& cpu_fetchs,
......@@ -166,14 +166,16 @@ void TestInference(const std::string& dirname,
// 6. Run the inference program
{
const bool create_vars = false;
if (!create_vars) {
if (!CreateVars) {
// If users don't want to create and destroy variables every time they
// run, they need to set `create_vars` to false and manually call
// `CreateVariables` before running.
executor.CreateVariables(*inference_program, scope, 0);
}
// Ignore the profiling results of the first run
executor.Run(
*inference_program, scope, feed_targets, fetch_targets, create_vars);
executor.Run(*inference_program, scope, feed_targets, fetch_targets,
CreateVars);
// Enable the profiler
paddle::platform::EnableProfiler(state);
......@@ -184,8 +186,8 @@ void TestInference(const std::string& dirname,
"run_inference",
paddle::platform::DeviceContextPool::Instance().Get(place));
executor.Run(
*inference_program, scope, feed_targets, fetch_targets, create_vars);
executor.Run(*inference_program, scope, feed_targets, fetch_targets,
CreateVars);
}
// Disable the profiler and print the timing information
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册