diff --git a/paddle/fluid/inference/tests/api/CMakeLists.txt b/paddle/fluid/inference/tests/api/CMakeLists.txt index 7dc88d9dd052c59aaa59b7802ee5a38ea9d89bc6..b54693d94804b5fced743dae6ebf6072aef6e402 100644 --- a/paddle/fluid/inference/tests/api/CMakeLists.txt +++ b/paddle/fluid/inference/tests/api/CMakeLists.txt @@ -48,7 +48,9 @@ inference_analysis_api_test(test_analyzer_rnn2 ${RNN2_INSTALL_DIR} analyzer_rnn2 # DAM set(DAM_INSTALL_DIR "${INFERENCE_DEMO_INSTALL_DIR}/dam") -download_model_and_data(${DAM_INSTALL_DIR} "DAM_model.tar.gz" "DAM_data.txt.tar.gz") +# For the normal DAM model +# download_model_and_data(${DAM_INSTALL_DIR} "DAM_model.tar.gz" "DAM_data.txt.tar.gz") +download_model_and_data(${DAM_INSTALL_DIR} "small_dam_model.tar.gz" "small_dam_data.txt.tar.gz") inference_analysis_api_test(test_analyzer_dam ${DAM_INSTALL_DIR} analyzer_dam_tester.cc) # chinese_ner diff --git a/paddle/fluid/inference/tests/api/analyzer_dam_tester.cc b/paddle/fluid/inference/tests/api/analyzer_dam_tester.cc index b369cba5c8b3f8aadd1123d6b7345fad6e47bd0f..b5a68538ab2703d229199a910c3f4b1ffcd8b9c1 100644 --- a/paddle/fluid/inference/tests/api/analyzer_dam_tester.cc +++ b/paddle/fluid/inference/tests/api/analyzer_dam_tester.cc @@ -17,7 +17,7 @@ namespace paddle { namespace inference { using contrib::AnalysisConfig; -#define MAX_TURN_NUM 9 +#define MAX_TURN_NUM 1 #define MAX_TURN_LEN 50 static std::vector result_data; @@ -148,8 +148,7 @@ void PrepareInputs(std::vector *input_slots, DataRecord *data, } void SetConfig(contrib::AnalysisConfig *cfg) { - cfg->prog_file = FLAGS_infer_model + "/__model__"; - cfg->param_file = FLAGS_infer_model + "/param"; + cfg->model_dir = FLAGS_infer_model; cfg->use_gpu = false; cfg->device = 0; cfg->specify_input_name = true; @@ -202,8 +201,8 @@ TEST(Analyzer_dam, fuse_statis) { auto fuse_statis = GetFuseStatis( static_cast(predictor.get()), &num_ops); ASSERT_TRUE(fuse_statis.count("fc_fuse")); - EXPECT_EQ(fuse_statis.at("fc_fuse"), 317); - EXPECT_EQ(num_ops, 2020); + EXPECT_EQ(fuse_statis.at("fc_fuse"), 45); + EXPECT_EQ(num_ops, 292); } // Compare result of NativeConfig and AnalysisConfig