未验证 提交 b8d2a021 编写于 作者: Q Qi Li 提交者: GitHub

fix ut error of test_recognize_digits, test=develop (#27791)

上级 c4b1faa4
...@@ -4,37 +4,26 @@ function(train_test TARGET_NAME) ...@@ -4,37 +4,26 @@ function(train_test TARGET_NAME)
set(multiValueArgs ARGS) set(multiValueArgs ARGS)
cmake_parse_arguments(train_test "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN}) cmake_parse_arguments(train_test "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})
set(arg_list "") if (NOT APPLE AND NOT WIN32)
if(train_test_ARGS) cc_test(test_train_${TARGET_NAME}
foreach(arg ${train_test_ARGS}) SRCS test_train_${TARGET_NAME}.cc
list(APPEND arg_list "_${arg}") DEPS paddle_fluid_shared
endforeach() ARGS --dirname=${PYTHON_TESTS_DIR}/book/)
else() else()
list(APPEND arg_list "_") cc_test(test_train_${TARGET_NAME}${arg}
SRCS test_train_${TARGET_NAME}.cc
DEPS paddle_fluid_api
ARGS --dirname=${PYTHON_TESTS_DIR}/book/)
endif()
set_tests_properties(test_train_${TARGET_NAME}
PROPERTIES FIXTURES_REQUIRED test_${TARGET_NAME}_infer_model)
if(NOT WIN32 AND NOT APPLE)
set_tests_properties(test_train_${TARGET_NAME}
PROPERTIES TIMEOUT 150)
endif() endif()
foreach(arg ${arg_list})
string(REGEX REPLACE "^_$" "" arg "${arg}")
if (NOT APPLE AND NOT WIN32)
cc_test(test_train_${TARGET_NAME}${arg}
SRCS test_train_${TARGET_NAME}.cc
DEPS paddle_fluid_shared
ARGS --dirname=${PYTHON_TESTS_DIR}/book/${TARGET_NAME}${arg}.train.model/)
else()
cc_test(test_train_${TARGET_NAME}${arg}
SRCS test_train_${TARGET_NAME}.cc
DEPS paddle_fluid_api
ARGS --dirname=${PYTHON_TESTS_DIR}/book/${TARGET_NAME}${arg}.train.model/)
endif()
set_tests_properties(test_train_${TARGET_NAME}${arg}
PROPERTIES FIXTURES_REQUIRED test_${TARGET_NAME}_infer_model)
if(NOT WIN32 AND NOT APPLE)
set_tests_properties(test_train_${TARGET_NAME}${arg}
PROPERTIES TIMEOUT 150)
endif()
endforeach()
endfunction(train_test) endfunction(train_test)
if(WITH_TESTING) if(WITH_TESTING)
train_test(recognize_digits ARGS mlp conv) train_test(recognize_digits)
endif() endif()
...@@ -32,16 +32,15 @@ DEFINE_string(dirname, "", "Directory of the train model."); ...@@ -32,16 +32,15 @@ DEFINE_string(dirname, "", "Directory of the train model.");
namespace paddle { namespace paddle {
void Train() { void Train(std::string model_dir) {
CHECK(!FLAGS_dirname.empty());
framework::InitDevices(false); framework::InitDevices(false);
const auto cpu_place = platform::CPUPlace(); const auto cpu_place = platform::CPUPlace();
framework::Executor executor(cpu_place); framework::Executor executor(cpu_place);
framework::Scope scope; framework::Scope scope;
auto train_program = inference::Load( auto train_program = inference::Load(
&executor, &scope, FLAGS_dirname + "__model_combined__.main_program", &executor, &scope, model_dir + "__model_combined__.main_program",
FLAGS_dirname + "__params_combined__"); model_dir + "__params_combined__");
std::string loss_name = ""; std::string loss_name = "";
for (auto op_desc : train_program->Block(0).AllOps()) { for (auto op_desc : train_program->Block(0).AllOps()) {
...@@ -87,6 +86,10 @@ void Train() { ...@@ -87,6 +86,10 @@ void Train() {
EXPECT_LT(last_loss, first_loss); EXPECT_LT(last_loss, first_loss);
} }
TEST(train, recognize_digits) { Train(); } TEST(train, recognize_digits) {
CHECK(!FLAGS_dirname.empty());
Train(FLAGS_dirname + "recognize_digits_mlp.train.model/");
Train(FLAGS_dirname + "recognize_digits_conv.train.model/");
}
} // namespace paddle } // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册