From 6ac2e079b37d0b3fd7165362f1749437aea8df5a Mon Sep 17 00:00:00 2001 From: Liu Yiqun <liuyiqun01@baidu.com> Date: Mon, 29 Jan 2018 09:53:19 +0000 Subject: [PATCH] Enable whole-archive flag in cc_test and use cc_test to rewrite the CMakeLists.txt of inference unittest. --- cmake/generic.cmake | 14 ++++++++++---- paddle/inference/tests/book/CMakeLists.txt | 17 ++++------------- .../book/test_inference_recognize_digits.cc | 15 +++++---------- paddle/testing/paddle_gtest_main.cc | 4 +++- .../fluid/tests/book/test_recognize_digits.py | 6 ++---- 5 files changed, 24 insertions(+), 32 deletions(-) diff --git a/cmake/generic.cmake b/cmake/generic.cmake index 585db019d5..18770fe286 100644 --- a/cmake/generic.cmake +++ b/cmake/generic.cmake @@ -224,12 +224,18 @@ function(cc_test TARGET_NAME) if(WITH_TESTING) set(options "") set(oneValueArgs "") - set(multiValueArgs SRCS DEPS) + set(multiValueArgs SRCS DEPS ARGS) cmake_parse_arguments(cc_test "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN}) add_executable(${TARGET_NAME} ${cc_test_SRCS}) - target_link_libraries(${TARGET_NAME} ${cc_test_DEPS} paddle_gtest_main paddle_memory gtest gflags) + # Support linking flags: --whole-archive (Linux) / -force_load (MacOS) + target_circle_link_libraries(${TARGET_NAME} ${cc_test_DEPS} paddle_gtest_main paddle_memory gtest gflags) + if("${cc_test_DEPS}" MATCHES "ARCHIVE_START") + list(REMOVE_ITEM cc_test_DEPS ARCHIVE_START ARCHIVE_END) + endif() add_dependencies(${TARGET_NAME} ${cc_test_DEPS} paddle_gtest_main paddle_memory gtest gflags) - add_test(NAME ${TARGET_NAME} COMMAND ${TARGET_NAME} WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}) + add_test(NAME ${TARGET_NAME} + COMMAND ${TARGET_NAME} ${cc_test_ARGS} + WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}) endif() endfunction(cc_test) @@ -457,7 +463,7 @@ endfunction() function(py_test TARGET_NAME) if(WITH_TESTING) - set(options STATIC static SHARED shared) + set(options "") set(oneValueArgs "") set(multiValueArgs SRCS DEPS ARGS) cmake_parse_arguments(py_test "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN}) diff --git a/paddle/inference/tests/book/CMakeLists.txt b/paddle/inference/tests/book/CMakeLists.txt index 78083cc218..d3798fb8fd 100644 --- a/paddle/inference/tests/book/CMakeLists.txt +++ b/paddle/inference/tests/book/CMakeLists.txt @@ -1,16 +1,7 @@ set(PYTHON_TESTS_DIR ${PADDLE_SOURCE_DIR}/python/paddle/v2/fluid/tests) -add_executable(test_inference_recognize_digits test_inference_recognize_digits.cc) -target_circle_link_libraries( - test_inference_recognize_digits - ARCHIVE_START - paddle_fluid - ARCHIVE_END - gtest - gflags) -add_test( - NAME test_inference_recognize_digits_mlp - COMMAND test_inference_recognize_digits - --dirname=${PYTHON_TESTS_DIR}/book/recognize_digits_mlp.inference.model - WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}) +cc_test(test_inference_recognize_digits_mlp + SRCS test_inference_recognize_digits.cc + DEPS ARCHIVE_START paddle_fluid ARCHIVE_END + ARGS --dirname=${PYTHON_TESTS_DIR}/book/recognize_digits_mlp.inference.model) set_tests_properties(test_inference_recognize_digits_mlp PROPERTIES DEPENDS test_recognize_digits_mlp_cpu) diff --git a/paddle/inference/tests/book/test_inference_recognize_digits.cc b/paddle/inference/tests/book/test_inference_recognize_digits.cc index de15167ac3..45fbfe27a7 100644 --- a/paddle/inference/tests/book/test_inference_recognize_digits.cc +++ b/paddle/inference/tests/book/test_inference_recognize_digits.cc @@ -66,9 +66,10 @@ TEST(inference, recognize_digits) { } LOG(INFO) << "FLAGS_dirname: " << FLAGS_dirname << std::endl; + std::string dirname = FLAGS_dirname; - // 0. Initialize all the devices - paddle::framework::InitDevices(); + // 0. Call `paddle::framework::InitDevices()` initialize all the devices + // In unittests, this is done in paddle/testing/paddle_gtest_main.cc paddle::framework::LoDTensor input; srand(time(0)); @@ -86,7 +87,7 @@ TEST(inference, recognize_digits) { // Run inference on CPU TestInference<paddle::platform::CPUPlace, float>( - FLAGS_dirname, cpu_feeds, cpu_fetchs1); + dirname, cpu_feeds, cpu_fetchs1); LOG(INFO) << output1.dims(); #ifdef PADDLE_WITH_CUDA @@ -96,7 +97,7 @@ TEST(inference, recognize_digits) { // Run inference on CUDA GPU TestInference<paddle::platform::CUDAPlace, float>( - FLAGS_dirname, cpu_feeds, cpu_fetchs2); + dirname, cpu_feeds, cpu_fetchs2); LOG(INFO) << output2.dims(); EXPECT_EQ(output1.dims(), output2.dims()); @@ -112,9 +113,3 @@ TEST(inference, recognize_digits) { EXPECT_EQ(count, 0) << "There are " << count << " different elements."; #endif } - -int main(int argc, char** argv) { - google::ParseCommandLineFlags(&argc, &argv, false); - testing::InitGoogleTest(&argc, argv); - return RUN_ALL_TESTS(); -} diff --git a/paddle/testing/paddle_gtest_main.cc b/paddle/testing/paddle_gtest_main.cc index a7fb50ee41..a2f21e37e4 100644 --- a/paddle/testing/paddle_gtest_main.cc +++ b/paddle/testing/paddle_gtest_main.cc @@ -22,7 +22,9 @@ limitations under the License. */ int main(int argc, char** argv) { std::vector<char*> new_argv; std::string gflags_env; - new_argv.push_back(argv[0]); + for (int i = 0; i < argc; ++i) { + new_argv.push_back(argv[i]); + } #ifdef PADDLE_WITH_CUDA new_argv.push_back( strdup("--tryfromenv=fraction_of_gpu_memory_to_use,use_pinned_memory")); diff --git a/python/paddle/v2/fluid/tests/book/test_recognize_digits.py b/python/paddle/v2/fluid/tests/book/test_recognize_digits.py index d6e4675a24..b4b6020f58 100644 --- a/python/paddle/v2/fluid/tests/book/test_recognize_digits.py +++ b/python/paddle/v2/fluid/tests/book/test_recognize_digits.py @@ -163,10 +163,8 @@ def infer(args, save_dirname=None): [inference_program, feed_target_names, fetch_targets] = fluid.io.load_inference_model(save_dirname, exe) - if args.nn_type == 'mlp': - tensor_img = numpy.random.rand(1, 28, 28).astype("float32") - else: - tensor_img = numpy.random.rand(1, 1, 28, 28).astype("float32") + # The input's dimension of conv should be 4-D or 5-D. + tensor_img = numpy.random.rand(1, 1, 28, 28).astype("float32") # Construct feed as a dictionary of {feed_target_name: feed_target_data} # and results will contain a list of data corresponding to fetch_targets. -- GitLab