From 6ac2e079b37d0b3fd7165362f1749437aea8df5a Mon Sep 17 00:00:00 2001
From: Liu Yiqun <liuyiqun01@baidu.com>
Date: Mon, 29 Jan 2018 09:53:19 +0000
Subject: [PATCH] Enable the whole-archive linking flag in cc_test and use
 cc_test to rewrite the CMakeLists.txt of the inference unit test.

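The cc_test function now accepts ARCHIVE_START/ARCHIVE_END markers in DEPS,
which target_circle_link_libraries turns into whole-archive linking
(--whole-archive on Linux, -force_load on macOS), plus an ARGS list that is
forwarded to add_test. The markers are stripped from DEPS before
add_dependencies because they are not real targets. With this, the
hand-written add_executable/add_test rules for the inference unit test are
replaced by a single cc_test call.

paddle_gtest_main now forwards the complete argv instead of only argv[0], so
flags such as --dirname passed via ARGS reach the test binary, and
test_inference_recognize_digits.cc no longer needs its own main(). py_test
drops its unused STATIC/SHARED options, and test_recognize_digits.py feeds a
4-D input for both network types, since a conv input must be 4-D or 5-D.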
---
 cmake/generic.cmake                             | 14 ++++++++++----
 paddle/inference/tests/book/CMakeLists.txt      | 17 ++++-------------
 .../book/test_inference_recognize_digits.cc     | 15 +++++----------
 paddle/testing/paddle_gtest_main.cc             |  4 +++-
 .../fluid/tests/book/test_recognize_digits.py   |  6 ++----
 5 files changed, 24 insertions(+), 32 deletions(-)
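
With this change, a test that needs whole-archive linking and runtime flags
can be declared in a single call. A minimal, hypothetical sketch (the target
name and model path below are illustrative, not part of this patch):

    cc_test(test_foo_inference
        SRCS test_foo_inference.cc
        DEPS ARCHIVE_START paddle_fluid ARCHIVE_END
        ARGS --dirname=${SOME_MODEL_DIR}/foo.inference.model)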

diff --git a/cmake/generic.cmake b/cmake/generic.cmake
index 585db019d5..18770fe286 100644
--- a/cmake/generic.cmake
+++ b/cmake/generic.cmake
@@ -224,12 +224,18 @@ function(cc_test TARGET_NAME)
   if(WITH_TESTING)
     set(options "")
     set(oneValueArgs "")
-    set(multiValueArgs SRCS DEPS)
+    set(multiValueArgs SRCS DEPS ARGS)
     cmake_parse_arguments(cc_test "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})
     add_executable(${TARGET_NAME} ${cc_test_SRCS})
-    target_link_libraries(${TARGET_NAME} ${cc_test_DEPS} paddle_gtest_main paddle_memory gtest gflags)
+    # Support whole-archive linking: --whole-archive (Linux) / -force_load (macOS)
+    target_circle_link_libraries(${TARGET_NAME} ${cc_test_DEPS} paddle_gtest_main paddle_memory gtest gflags)
+    if("${cc_test_DEPS}" MATCHES "ARCHIVE_START")
+      list(REMOVE_ITEM cc_test_DEPS ARCHIVE_START ARCHIVE_END)
+    endif()
     add_dependencies(${TARGET_NAME} ${cc_test_DEPS} paddle_gtest_main paddle_memory gtest gflags)
-    add_test(NAME ${TARGET_NAME} COMMAND ${TARGET_NAME} WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR})
+    add_test(NAME ${TARGET_NAME}
+             COMMAND ${TARGET_NAME} ${cc_test_ARGS}
+             WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR})
   endif()
 endfunction(cc_test)
 
@@ -457,7 +463,7 @@ endfunction()
 
 function(py_test TARGET_NAME)
   if(WITH_TESTING)
-    set(options STATIC static SHARED shared)
+    set(options "")
     set(oneValueArgs "")
     set(multiValueArgs SRCS DEPS ARGS)
     cmake_parse_arguments(py_test "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})
diff --git a/paddle/inference/tests/book/CMakeLists.txt b/paddle/inference/tests/book/CMakeLists.txt
index 78083cc218..d3798fb8fd 100644
--- a/paddle/inference/tests/book/CMakeLists.txt
+++ b/paddle/inference/tests/book/CMakeLists.txt
@@ -1,16 +1,7 @@
 set(PYTHON_TESTS_DIR ${PADDLE_SOURCE_DIR}/python/paddle/v2/fluid/tests)
-add_executable(test_inference_recognize_digits test_inference_recognize_digits.cc)
-target_circle_link_libraries(
-    test_inference_recognize_digits
-    ARCHIVE_START
-    paddle_fluid
-    ARCHIVE_END
-    gtest
-    gflags)
-add_test(
-    NAME test_inference_recognize_digits_mlp
-    COMMAND test_inference_recognize_digits
-            --dirname=${PYTHON_TESTS_DIR}/book/recognize_digits_mlp.inference.model
-    WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR})
+cc_test(test_inference_recognize_digits_mlp
+    SRCS test_inference_recognize_digits.cc
+    DEPS ARCHIVE_START paddle_fluid ARCHIVE_END
+    ARGS --dirname=${PYTHON_TESTS_DIR}/book/recognize_digits_mlp.inference.model)
 set_tests_properties(test_inference_recognize_digits_mlp
     PROPERTIES DEPENDS test_recognize_digits_mlp_cpu)
diff --git a/paddle/inference/tests/book/test_inference_recognize_digits.cc b/paddle/inference/tests/book/test_inference_recognize_digits.cc
index de15167ac3..45fbfe27a7 100644
--- a/paddle/inference/tests/book/test_inference_recognize_digits.cc
+++ b/paddle/inference/tests/book/test_inference_recognize_digits.cc
@@ -66,9 +66,10 @@ TEST(inference, recognize_digits) {
   }
 
   LOG(INFO) << "FLAGS_dirname: " << FLAGS_dirname << std::endl;
+  std::string dirname = FLAGS_dirname;
 
-  // 0. Initialize all the devices
-  paddle::framework::InitDevices();
+  // 0. Call paddle::framework::InitDevices() to initialize all the devices.
+  // In unittests, this is done in paddle/testing/paddle_gtest_main.cc.
 
   paddle::framework::LoDTensor input;
   srand(time(0));
@@ -86,7 +87,7 @@ TEST(inference, recognize_digits) {
 
   // Run inference on CPU
   TestInference<paddle::platform::CPUPlace, float>(
-      FLAGS_dirname, cpu_feeds, cpu_fetchs1);
+      dirname, cpu_feeds, cpu_fetchs1);
   LOG(INFO) << output1.dims();
 
 #ifdef PADDLE_WITH_CUDA
@@ -96,7 +97,7 @@ TEST(inference, recognize_digits) {
 
   // Run inference on CUDA GPU
   TestInference<paddle::platform::CUDAPlace, float>(
-      FLAGS_dirname, cpu_feeds, cpu_fetchs2);
+      dirname, cpu_feeds, cpu_fetchs2);
   LOG(INFO) << output2.dims();
 
   EXPECT_EQ(output1.dims(), output2.dims());
@@ -112,9 +113,3 @@ TEST(inference, recognize_digits) {
   EXPECT_EQ(count, 0) << "There are " << count << " different elements.";
 #endif
 }
-
-int main(int argc, char** argv) {
-  google::ParseCommandLineFlags(&argc, &argv, false);
-  testing::InitGoogleTest(&argc, argv);
-  return RUN_ALL_TESTS();
-}
diff --git a/paddle/testing/paddle_gtest_main.cc b/paddle/testing/paddle_gtest_main.cc
index a7fb50ee41..a2f21e37e4 100644
--- a/paddle/testing/paddle_gtest_main.cc
+++ b/paddle/testing/paddle_gtest_main.cc
@@ -22,7 +22,9 @@ limitations under the License. */
 int main(int argc, char** argv) {
   std::vector<char*> new_argv;
   std::string gflags_env;
-  new_argv.push_back(argv[0]);
+  for (int i = 0; i < argc; ++i) {
+    new_argv.push_back(argv[i]);
+  }
 #ifdef PADDLE_WITH_CUDA
   new_argv.push_back(
       strdup("--tryfromenv=fraction_of_gpu_memory_to_use,use_pinned_memory"));
diff --git a/python/paddle/v2/fluid/tests/book/test_recognize_digits.py b/python/paddle/v2/fluid/tests/book/test_recognize_digits.py
index d6e4675a24..b4b6020f58 100644
--- a/python/paddle/v2/fluid/tests/book/test_recognize_digits.py
+++ b/python/paddle/v2/fluid/tests/book/test_recognize_digits.py
@@ -163,10 +163,8 @@ def infer(args, save_dirname=None):
     [inference_program, feed_target_names,
      fetch_targets] = fluid.io.load_inference_model(save_dirname, exe)
 
-    if args.nn_type == 'mlp':
-        tensor_img = numpy.random.rand(1, 28, 28).astype("float32")
-    else:
-        tensor_img = numpy.random.rand(1, 1, 28, 28).astype("float32")
+    # The input of a conv layer must be a 4-D or 5-D tensor.
+    tensor_img = numpy.random.rand(1, 1, 28, 28).astype("float32")
 
     # Construct feed as a dictionary of {feed_target_name: feed_target_data}
     # and results will contain a list of data corresponding to fetch_targets.
-- 
GitLab