From b8d2a021f0fa6e750dac4857ba7603a7b4fd440e Mon Sep 17 00:00:00 2001 From: Qi Li Date: Sat, 10 Oct 2020 12:17:06 +0800 Subject: [PATCH] fix ut error of test_recognize_digits, test=develop (#27791) --- paddle/fluid/train/CMakeLists.txt | 43 +++++++------------ .../train/test_train_recognize_digits.cc | 13 +++--- 2 files changed, 24 insertions(+), 32 deletions(-) diff --git a/paddle/fluid/train/CMakeLists.txt b/paddle/fluid/train/CMakeLists.txt index d587081fba..ad4bc20f9f 100644 --- a/paddle/fluid/train/CMakeLists.txt +++ b/paddle/fluid/train/CMakeLists.txt @@ -4,37 +4,26 @@ function(train_test TARGET_NAME) set(multiValueArgs ARGS) cmake_parse_arguments(train_test "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN}) - set(arg_list "") - if(train_test_ARGS) - foreach(arg ${train_test_ARGS}) - list(APPEND arg_list "_${arg}") - endforeach() + if (NOT APPLE AND NOT WIN32) + cc_test(test_train_${TARGET_NAME} + SRCS test_train_${TARGET_NAME}.cc + DEPS paddle_fluid_shared + ARGS --dirname=${PYTHON_TESTS_DIR}/book/) else() - list(APPEND arg_list "_") + cc_test(test_train_${TARGET_NAME}${arg} + SRCS test_train_${TARGET_NAME}.cc + DEPS paddle_fluid_api + ARGS --dirname=${PYTHON_TESTS_DIR}/book/) + endif() + set_tests_properties(test_train_${TARGET_NAME} + PROPERTIES FIXTURES_REQUIRED test_${TARGET_NAME}_infer_model) + if(NOT WIN32 AND NOT APPLE) + set_tests_properties(test_train_${TARGET_NAME} + PROPERTIES TIMEOUT 150) endif() - foreach(arg ${arg_list}) - string(REGEX REPLACE "^_$" "" arg "${arg}") - if (NOT APPLE AND NOT WIN32) - cc_test(test_train_${TARGET_NAME}${arg} - SRCS test_train_${TARGET_NAME}.cc - DEPS paddle_fluid_shared - ARGS --dirname=${PYTHON_TESTS_DIR}/book/${TARGET_NAME}${arg}.train.model/) - else() - cc_test(test_train_${TARGET_NAME}${arg} - SRCS test_train_${TARGET_NAME}.cc - DEPS paddle_fluid_api - ARGS --dirname=${PYTHON_TESTS_DIR}/book/${TARGET_NAME}${arg}.train.model/) - endif() - set_tests_properties(test_train_${TARGET_NAME}${arg} - PROPERTIES FIXTURES_REQUIRED test_${TARGET_NAME}_infer_model) - if(NOT WIN32 AND NOT APPLE) - set_tests_properties(test_train_${TARGET_NAME}${arg} - PROPERTIES TIMEOUT 150) - endif() - endforeach() endfunction(train_test) if(WITH_TESTING) - train_test(recognize_digits ARGS mlp conv) + train_test(recognize_digits) endif() diff --git a/paddle/fluid/train/test_train_recognize_digits.cc b/paddle/fluid/train/test_train_recognize_digits.cc index e7b698e1a3..fb993439bb 100644 --- a/paddle/fluid/train/test_train_recognize_digits.cc +++ b/paddle/fluid/train/test_train_recognize_digits.cc @@ -32,16 +32,15 @@ DEFINE_string(dirname, "", "Directory of the train model."); namespace paddle { -void Train() { - CHECK(!FLAGS_dirname.empty()); +void Train(std::string model_dir) { framework::InitDevices(false); const auto cpu_place = platform::CPUPlace(); framework::Executor executor(cpu_place); framework::Scope scope; auto train_program = inference::Load( - &executor, &scope, FLAGS_dirname + "__model_combined__.main_program", - FLAGS_dirname + "__params_combined__"); + &executor, &scope, model_dir + "__model_combined__.main_program", + model_dir + "__params_combined__"); std::string loss_name = ""; for (auto op_desc : train_program->Block(0).AllOps()) { @@ -87,6 +86,10 @@ void Train() { EXPECT_LT(last_loss, first_loss); } -TEST(train, recognize_digits) { Train(); } +TEST(train, recognize_digits) { + CHECK(!FLAGS_dirname.empty()); + Train(FLAGS_dirname + "recognize_digits_mlp.train.model/"); + Train(FLAGS_dirname + "recognize_digits_conv.train.model/"); +} } // namespace paddle -- GitLab