提交 6ac2e079 编写于 作者: Liu Yiqun

Enable whole-archive flag in cc_test and use cc_test to rewrite the CMakeLists.txt of inference unittest.

Enable whole-archive flag in cc_test and use cc_test to rewrite the CMakeLists.txt of inference unittest.
上级 eca58a62
......@@ -224,12 +224,18 @@ function(cc_test TARGET_NAME)
if(WITH_TESTING)
set(options "")
set(oneValueArgs "")
set(multiValueArgs SRCS DEPS)
set(multiValueArgs SRCS DEPS ARGS)
cmake_parse_arguments(cc_test "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})
add_executable(${TARGET_NAME} ${cc_test_SRCS})
target_link_libraries(${TARGET_NAME} ${cc_test_DEPS} paddle_gtest_main paddle_memory gtest gflags)
# Support linking flags: --whole-archive (Linux) / -force_load (MacOS)
target_circle_link_libraries(${TARGET_NAME} ${cc_test_DEPS} paddle_gtest_main paddle_memory gtest gflags)
if("${cc_test_DEPS}" MATCHES "ARCHIVE_START")
list(REMOVE_ITEM cc_test_DEPS ARCHIVE_START ARCHIVE_END)
endif()
add_dependencies(${TARGET_NAME} ${cc_test_DEPS} paddle_gtest_main paddle_memory gtest gflags)
add_test(NAME ${TARGET_NAME} COMMAND ${TARGET_NAME} WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR})
add_test(NAME ${TARGET_NAME}
COMMAND ${TARGET_NAME} ${cc_test_ARGS}
WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR})
endif()
endfunction(cc_test)
......@@ -457,7 +463,7 @@ endfunction()
function(py_test TARGET_NAME)
if(WITH_TESTING)
set(options STATIC static SHARED shared)
set(options "")
set(oneValueArgs "")
set(multiValueArgs SRCS DEPS ARGS)
cmake_parse_arguments(py_test "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})
......
set(PYTHON_TESTS_DIR ${PADDLE_SOURCE_DIR}/python/paddle/v2/fluid/tests)
add_executable(test_inference_recognize_digits test_inference_recognize_digits.cc)
target_circle_link_libraries(
test_inference_recognize_digits
ARCHIVE_START
paddle_fluid
ARCHIVE_END
gtest
gflags)
add_test(
NAME test_inference_recognize_digits_mlp
COMMAND test_inference_recognize_digits
--dirname=${PYTHON_TESTS_DIR}/book/recognize_digits_mlp.inference.model
WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR})
cc_test(test_inference_recognize_digits_mlp
SRCS test_inference_recognize_digits.cc
DEPS ARCHIVE_START paddle_fluid ARCHIVE_END
ARGS --dirname=${PYTHON_TESTS_DIR}/book/recognize_digits_mlp.inference.model)
set_tests_properties(test_inference_recognize_digits_mlp
PROPERTIES DEPENDS test_recognize_digits_mlp_cpu)
......@@ -66,9 +66,10 @@ TEST(inference, recognize_digits) {
}
LOG(INFO) << "FLAGS_dirname: " << FLAGS_dirname << std::endl;
std::string dirname = FLAGS_dirname;
// 0. Initialize all the devices
paddle::framework::InitDevices();
// 0. Call `paddle::framework::InitDevices()` initialize all the devices
// In unittests, this is done in paddle/testing/paddle_gtest_main.cc
paddle::framework::LoDTensor input;
srand(time(0));
......@@ -86,7 +87,7 @@ TEST(inference, recognize_digits) {
// Run inference on CPU
TestInference<paddle::platform::CPUPlace, float>(
FLAGS_dirname, cpu_feeds, cpu_fetchs1);
dirname, cpu_feeds, cpu_fetchs1);
LOG(INFO) << output1.dims();
#ifdef PADDLE_WITH_CUDA
......@@ -96,7 +97,7 @@ TEST(inference, recognize_digits) {
// Run inference on CUDA GPU
TestInference<paddle::platform::CUDAPlace, float>(
FLAGS_dirname, cpu_feeds, cpu_fetchs2);
dirname, cpu_feeds, cpu_fetchs2);
LOG(INFO) << output2.dims();
EXPECT_EQ(output1.dims(), output2.dims());
......@@ -112,9 +113,3 @@ TEST(inference, recognize_digits) {
EXPECT_EQ(count, 0) << "There are " << count << " different elements.";
#endif
}
int main(int argc, char** argv) {
google::ParseCommandLineFlags(&argc, &argv, false);
testing::InitGoogleTest(&argc, argv);
return RUN_ALL_TESTS();
}
......@@ -22,7 +22,9 @@ limitations under the License. */
int main(int argc, char** argv) {
std::vector<char*> new_argv;
std::string gflags_env;
new_argv.push_back(argv[0]);
for (int i = 0; i < argc; ++i) {
new_argv.push_back(argv[i]);
}
#ifdef PADDLE_WITH_CUDA
new_argv.push_back(
strdup("--tryfromenv=fraction_of_gpu_memory_to_use,use_pinned_memory"));
......
......@@ -163,10 +163,8 @@ def infer(args, save_dirname=None):
[inference_program, feed_target_names,
fetch_targets] = fluid.io.load_inference_model(save_dirname, exe)
if args.nn_type == 'mlp':
tensor_img = numpy.random.rand(1, 28, 28).astype("float32")
else:
tensor_img = numpy.random.rand(1, 1, 28, 28).astype("float32")
# The input's dimension of conv should be 4-D or 5-D.
tensor_img = numpy.random.rand(1, 1, 28, 28).astype("float32")
# Construct feed as a dictionary of {feed_target_name: feed_target_data}
# and results will contain a list of data corresponding to fetch_targets.
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册