From 8f2e556e65ceaf4e530bcbe0055b33b236c73e17 Mon Sep 17 00:00:00 2001 From: ZhenWang Date: Thu, 29 Nov 2018 19:33:47 +0800 Subject: [PATCH] support the small dam model. test=develop --- paddle/fluid/inference/tests/api/CMakeLists.txt | 4 +++- paddle/fluid/inference/tests/api/analyzer_dam_tester.cc | 9 ++++----- 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/paddle/fluid/inference/tests/api/CMakeLists.txt b/paddle/fluid/inference/tests/api/CMakeLists.txt index 7dc88d9dd..b54693d94 100644 --- a/paddle/fluid/inference/tests/api/CMakeLists.txt +++ b/paddle/fluid/inference/tests/api/CMakeLists.txt @@ -48,7 +48,9 @@ inference_analysis_api_test(test_analyzer_rnn2 ${RNN2_INSTALL_DIR} analyzer_rnn2 # DAM set(DAM_INSTALL_DIR "${INFERENCE_DEMO_INSTALL_DIR}/dam") -download_model_and_data(${DAM_INSTALL_DIR} "DAM_model.tar.gz" "DAM_data.txt.tar.gz") +# For the normal DAM model +# download_model_and_data(${DAM_INSTALL_DIR} "DAM_model.tar.gz" "DAM_data.txt.tar.gz") +download_model_and_data(${DAM_INSTALL_DIR} "small_dam_model.tar.gz" "small_dam_data.txt.tar.gz") inference_analysis_api_test(test_analyzer_dam ${DAM_INSTALL_DIR} analyzer_dam_tester.cc) # chinese_ner diff --git a/paddle/fluid/inference/tests/api/analyzer_dam_tester.cc b/paddle/fluid/inference/tests/api/analyzer_dam_tester.cc index b369cba5c..b5a68538a 100644 --- a/paddle/fluid/inference/tests/api/analyzer_dam_tester.cc +++ b/paddle/fluid/inference/tests/api/analyzer_dam_tester.cc @@ -17,7 +17,7 @@ namespace paddle { namespace inference { using contrib::AnalysisConfig; -#define MAX_TURN_NUM 9 +#define MAX_TURN_NUM 1 #define MAX_TURN_LEN 50 static std::vector result_data; @@ -148,8 +148,7 @@ void PrepareInputs(std::vector *input_slots, DataRecord *data, } void SetConfig(contrib::AnalysisConfig *cfg) { - cfg->prog_file = FLAGS_infer_model + "/__model__"; - cfg->param_file = FLAGS_infer_model + "/param"; + cfg->model_dir = FLAGS_infer_model; cfg->use_gpu = false; cfg->device = 0; cfg->specify_input_name = true; @@ -202,8 +201,8 @@ TEST(Analyzer_dam, fuse_statis) { auto fuse_statis = GetFuseStatis( static_cast(predictor.get()), &num_ops); ASSERT_TRUE(fuse_statis.count("fc_fuse")); - EXPECT_EQ(fuse_statis.at("fc_fuse"), 317); - EXPECT_EQ(num_ops, 2020); + EXPECT_EQ(fuse_statis.at("fc_fuse"), 45); + EXPECT_EQ(num_ops, 292); } // Compare result of NativeConfig and AnalysisConfig -- GitLab