diff --git a/paddle/inference/tests/book/test_inference_recognize_digits.cc b/paddle/inference/tests/book/test_inference_recognize_digits.cc index 0dfaf9a0eec63c2ecdef60626e0077f43b955842..de15167ac3a465544b31b967b03040b2cf205cac 100644 --- a/paddle/inference/tests/book/test_inference_recognize_digits.cc +++ b/paddle/inference/tests/book/test_inference_recognize_digits.cc @@ -1,4 +1,4 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -21,19 +21,12 @@ limitations under the License. */ DEFINE_string(dirname, "", "Directory of the inference model."); -TEST(recognize_digits, CPU) { - if (FLAGS_dirname.empty()) { - LOG(FATAL) << "Usage: ./example --dirname=path/to/your/model"; - } - - std::cout << "FLAGS_dirname: " << FLAGS_dirname << std::endl; - std::string dirname = FLAGS_dirname; - - // 0. Initialize all the devices - paddle::framework::InitDevices(); - +template +void TestInference(const std::string& dirname, + const std::vector& cpu_feeds, + std::vector& cpu_fetchs) { // 1. Define place, executor and scope - auto place = paddle::platform::CPUPlace(); + auto place = Place(); auto executor = paddle::framework::Executor(place); auto* scope = new paddle::framework::Scope(); @@ -49,37 +42,77 @@ TEST(recognize_digits, CPU) { // 4. Prepare inputs std::map feed_targets; - paddle::framework::LoDTensor input; - srand(time(0)); - float* input_ptr = - input.mutable_data({1, 28, 28}, paddle::platform::CPUPlace()); - for (int i = 0; i < 784; ++i) { - input_ptr[i] = rand() / (static_cast(RAND_MAX)); + for (size_t i = 0; i < feed_target_names.size(); ++i) { + // Please make sure that cpu_feeds[i] is right for feed_target_names[i] + feed_targets[feed_target_names[i]] = cpu_feeds[i]; } - feed_targets[feed_target_names[0]] = &input; // 5. Define Tensor to get the outputs std::map fetch_targets; - paddle::framework::LoDTensor output; - fetch_targets[fetch_target_names[0]] = &output; + for (size_t i = 0; i < fetch_target_names.size(); ++i) { + fetch_targets[fetch_target_names[i]] = cpu_fetchs[i]; + } // 6. Run the inference program executor.Run(*inference_program, scope, feed_targets, fetch_targets); - // 7. Use the output as your expect. - LOG(INFO) << output.dims(); - std::stringstream ss; - ss << "result:"; - float* output_ptr = output.data(); - for (int j = 0; j < output.numel(); ++j) { - ss << " " << output_ptr[j]; - } - LOG(INFO) << ss.str(); - delete scope; delete engine; } +TEST(inference, recognize_digits) { + if (FLAGS_dirname.empty()) { + LOG(FATAL) << "Usage: ./example --dirname=path/to/your/model"; + } + + LOG(INFO) << "FLAGS_dirname: " << FLAGS_dirname << std::endl; + + // 0. Initialize all the devices + paddle::framework::InitDevices(); + + paddle::framework::LoDTensor input; + srand(time(0)); + float* input_ptr = + input.mutable_data({1, 28, 28}, paddle::platform::CPUPlace()); + for (int i = 0; i < 784; ++i) { + input_ptr[i] = rand() / (static_cast(RAND_MAX)); + } + std::vector cpu_feeds; + cpu_feeds.push_back(&input); + + paddle::framework::LoDTensor output1; + std::vector cpu_fetchs1; + cpu_fetchs1.push_back(&output1); + + // Run inference on CPU + TestInference( + FLAGS_dirname, cpu_feeds, cpu_fetchs1); + LOG(INFO) << output1.dims(); + +#ifdef PADDLE_WITH_CUDA + paddle::framework::LoDTensor output2; + std::vector cpu_fetchs2; + cpu_fetchs2.push_back(&output2); + + // Run inference on CUDA GPU + TestInference( + FLAGS_dirname, cpu_feeds, cpu_fetchs2); + LOG(INFO) << output2.dims(); + + EXPECT_EQ(output1.dims(), output2.dims()); + EXPECT_EQ(output1.numel(), output2.numel()); + + float err = 1E-3; + int count = 0; + for (int64_t i = 0; i < output1.numel(); ++i) { + if (fabs(output1.data()[i] - output2.data()[i]) > err) { + count++; + } + } + EXPECT_EQ(count, 0) << "There are " << count << " different elements."; +#endif +} + int main(int argc, char** argv) { google::ParseCommandLineFlags(&argc, &argv, false); testing::InitGoogleTest(&argc, argv);