From 250206d1cfeafa74c353ed167b6b5852f8ccec3e Mon Sep 17 00:00:00 2001 From: Liu Yiqun Date: Tue, 23 Jan 2018 10:44:28 +0000 Subject: [PATCH] Change the example of inference to a unittest. --- paddle/inference/CMakeLists.txt | 17 ++-------- paddle/inference/tests/book/CMakeLists.txt | 13 +++++++ .../test_inference_recognize_digits_mlp.cc} | 34 +++++++++---------- 3 files changed, 31 insertions(+), 33 deletions(-) create mode 100644 paddle/inference/tests/book/CMakeLists.txt rename paddle/inference/{example.cc => tests/book/test_inference_recognize_digits_mlp.cc} (72%) diff --git a/paddle/inference/CMakeLists.txt b/paddle/inference/CMakeLists.txt index ae4d3fd2f5..fedf9e4cb8 100644 --- a/paddle/inference/CMakeLists.txt +++ b/paddle/inference/CMakeLists.txt @@ -24,19 +24,6 @@ if(NOT WITH_C_API AND WITH_FLUID) install(TARGETS paddle_fluid_shared DESTINATION lib) endif() -add_executable(example example.cc) -if(APPLE) - set(OPTIONAL_LINK_FLAGS) - if("${CMAKE_CXX_COMPILER_ID}" STREQUAL "Clang" OR "${CMAKE_CXX_COMPILER_ID}" STREQUAL "AppleClang") - set(OPTIONAL_LINK_FLAGS "-undefined dynamic_lookup") - endif() - target_link_libraries(example - -Wl,-force_load paddle_fluid - ${OPTIONAL_LINK_FLAGS} - ${PTOOLS_LIB}) -else() - target_link_libraries(example - -Wl,--start-group -Wl,--whole-archive paddle_fluid - -Wl,--no-whole-archive -Wl,--end-group - ${PTOOLS_LIB}) +if(WITH_TESTING) + add_subdirectory(tests/book) endif() diff --git a/paddle/inference/tests/book/CMakeLists.txt b/paddle/inference/tests/book/CMakeLists.txt new file mode 100644 index 0000000000..31e6796fdb --- /dev/null +++ b/paddle/inference/tests/book/CMakeLists.txt @@ -0,0 +1,13 @@ +add_executable(test_inference_recognize_digits_mlp test_inference_recognize_digits_mlp.cc) +target_circle_link_libraries( + test_inference_recognize_digits_mlp + ARCHIVE_START + paddle_fluid + ARCHIVE_END + gtest + gflags) +add_test( + NAME test_inference_recognize_digits_mlp + COMMAND test_inference_recognize_digits_mlp + --dirname=${PADDLE_SOURCE_DIR}/python/paddle/v2/fluid/tests/book/recognize_digits_mlp.inference.model + WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}) diff --git a/paddle/inference/example.cc b/paddle/inference/tests/book/test_inference_recognize_digits_mlp.cc similarity index 72% rename from paddle/inference/example.cc rename to paddle/inference/tests/book/test_inference_recognize_digits_mlp.cc index 0c18b45624..e96af21344 100644 --- a/paddle/inference/example.cc +++ b/paddle/inference/tests/book/test_inference_recognize_digits_mlp.cc @@ -12,20 +12,17 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ +#include #include -#include +#include #include "gflags/gflags.h" #include "paddle/inference/inference.h" DEFINE_string(dirname, "", "Directory of the inference model."); -int main(int argc, char** argv) { - google::ParseCommandLineFlags(&argc, &argv, true); +TEST(inference, recognize_digits_mlp) { if (FLAGS_dirname.empty()) { - // Example: - // ./example --dirname=recognize_digits_mlp.inference.model - std::cout << "Usage: ./example --dirname=path/to/your/model" << std::endl; - exit(1); + LOG(FATAL) << "Usage: ./example --dirname=path/to/your/model"; } std::cout << "FLAGS_dirname: " << FLAGS_dirname << std::endl; @@ -48,20 +45,21 @@ int main(int argc, char** argv) { engine->Execute(feeds, fetchs); for (size_t i = 0; i < fetchs.size(); ++i) { - auto dims_i = fetchs[i].dims(); - std::cout << "dims_i:"; - for (int j = 0; j < dims_i.size(); ++j) { - std::cout << " " << dims_i[j]; - } - std::cout << std::endl; - std::cout << "result:"; + LOG(INFO) << fetchs[i].dims(); + std::stringstream ss; + ss << "result:"; float* output_ptr = fetchs[i].data(); - for (int j = 0; j < paddle::framework::product(dims_i); ++j) { - std::cout << " " << output_ptr[j]; + for (int j = 0; j < fetchs[i].numel(); ++j) { + ss << " " << output_ptr[j]; } - std::cout << std::endl; + LOG(INFO) << ss.str(); } delete engine; - return 0; +} + +int main(int argc, char** argv) { + google::ParseCommandLineFlags(&argc, &argv, false); + testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); } -- GitLab