未验证 提交 b8d2a021 编写于 作者: Q Qi Li 提交者: GitHub

fix ut error of test_recognize_digits, test=develop (#27791)

上级 c4b1faa4
......@@ -4,37 +4,26 @@ function(train_test TARGET_NAME)
set(multiValueArgs ARGS)
cmake_parse_arguments(train_test "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})
set(arg_list "")
if(train_test_ARGS)
foreach(arg ${train_test_ARGS})
list(APPEND arg_list "_${arg}")
endforeach()
else()
list(APPEND arg_list "_")
endif()
foreach(arg ${arg_list})
string(REGEX REPLACE "^_$" "" arg "${arg}")
if (NOT APPLE AND NOT WIN32)
cc_test(test_train_${TARGET_NAME}${arg}
cc_test(test_train_${TARGET_NAME}
SRCS test_train_${TARGET_NAME}.cc
DEPS paddle_fluid_shared
ARGS --dirname=${PYTHON_TESTS_DIR}/book/${TARGET_NAME}${arg}.train.model/)
ARGS --dirname=${PYTHON_TESTS_DIR}/book/)
else()
cc_test(test_train_${TARGET_NAME}${arg}
SRCS test_train_${TARGET_NAME}.cc
DEPS paddle_fluid_api
ARGS --dirname=${PYTHON_TESTS_DIR}/book/${TARGET_NAME}${arg}.train.model/)
ARGS --dirname=${PYTHON_TESTS_DIR}/book/)
endif()
set_tests_properties(test_train_${TARGET_NAME}${arg}
set_tests_properties(test_train_${TARGET_NAME}
PROPERTIES FIXTURES_REQUIRED test_${TARGET_NAME}_infer_model)
if(NOT WIN32 AND NOT APPLE)
set_tests_properties(test_train_${TARGET_NAME}${arg}
set_tests_properties(test_train_${TARGET_NAME}
PROPERTIES TIMEOUT 150)
endif()
endforeach()
endfunction(train_test)
if(WITH_TESTING)
train_test(recognize_digits ARGS mlp conv)
train_test(recognize_digits)
endif()
......@@ -32,16 +32,15 @@ DEFINE_string(dirname, "", "Directory of the train model.");
namespace paddle {
void Train() {
CHECK(!FLAGS_dirname.empty());
void Train(std::string model_dir) {
framework::InitDevices(false);
const auto cpu_place = platform::CPUPlace();
framework::Executor executor(cpu_place);
framework::Scope scope;
auto train_program = inference::Load(
&executor, &scope, FLAGS_dirname + "__model_combined__.main_program",
FLAGS_dirname + "__params_combined__");
&executor, &scope, model_dir + "__model_combined__.main_program",
model_dir + "__params_combined__");
std::string loss_name = "";
for (auto op_desc : train_program->Block(0).AllOps()) {
......@@ -87,6 +86,10 @@ void Train() {
EXPECT_LT(last_loss, first_loss);
}
TEST(train, recognize_digits) { Train(); }
TEST(train, recognize_digits) {
CHECK(!FLAGS_dirname.empty());
Train(FLAGS_dirname + "recognize_digits_mlp.train.model/");
Train(FLAGS_dirname + "recognize_digits_conv.train.model/");
}
} // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册