提交 8f2e556e 编写于 作者: Z ZhenWang

support the small dam model. test=develop

上级 20120d9c
...@@ -48,7 +48,9 @@ inference_analysis_api_test(test_analyzer_rnn2 ${RNN2_INSTALL_DIR} analyzer_rnn2 ...@@ -48,7 +48,9 @@ inference_analysis_api_test(test_analyzer_rnn2 ${RNN2_INSTALL_DIR} analyzer_rnn2
# DAM # DAM
set(DAM_INSTALL_DIR "${INFERENCE_DEMO_INSTALL_DIR}/dam") set(DAM_INSTALL_DIR "${INFERENCE_DEMO_INSTALL_DIR}/dam")
download_model_and_data(${DAM_INSTALL_DIR} "DAM_model.tar.gz" "DAM_data.txt.tar.gz") # For the normal DAM model
# download_model_and_data(${DAM_INSTALL_DIR} "DAM_model.tar.gz" "DAM_data.txt.tar.gz")
download_model_and_data(${DAM_INSTALL_DIR} "small_dam_model.tar.gz" "small_dam_data.txt.tar.gz")
inference_analysis_api_test(test_analyzer_dam ${DAM_INSTALL_DIR} analyzer_dam_tester.cc) inference_analysis_api_test(test_analyzer_dam ${DAM_INSTALL_DIR} analyzer_dam_tester.cc)
# chinese_ner # chinese_ner
......
...@@ -17,7 +17,7 @@ ...@@ -17,7 +17,7 @@
namespace paddle { namespace paddle {
namespace inference { namespace inference {
using contrib::AnalysisConfig; using contrib::AnalysisConfig;
#define MAX_TURN_NUM 9 #define MAX_TURN_NUM 1
#define MAX_TURN_LEN 50 #define MAX_TURN_LEN 50
static std::vector<float> result_data; static std::vector<float> result_data;
...@@ -148,8 +148,7 @@ void PrepareInputs(std::vector<PaddleTensor> *input_slots, DataRecord *data, ...@@ -148,8 +148,7 @@ void PrepareInputs(std::vector<PaddleTensor> *input_slots, DataRecord *data,
} }
void SetConfig(contrib::AnalysisConfig *cfg) { void SetConfig(contrib::AnalysisConfig *cfg) {
cfg->prog_file = FLAGS_infer_model + "/__model__"; cfg->model_dir = FLAGS_infer_model;
cfg->param_file = FLAGS_infer_model + "/param";
cfg->use_gpu = false; cfg->use_gpu = false;
cfg->device = 0; cfg->device = 0;
cfg->specify_input_name = true; cfg->specify_input_name = true;
...@@ -202,8 +201,8 @@ TEST(Analyzer_dam, fuse_statis) { ...@@ -202,8 +201,8 @@ TEST(Analyzer_dam, fuse_statis) {
auto fuse_statis = GetFuseStatis( auto fuse_statis = GetFuseStatis(
static_cast<AnalysisPredictor *>(predictor.get()), &num_ops); static_cast<AnalysisPredictor *>(predictor.get()), &num_ops);
ASSERT_TRUE(fuse_statis.count("fc_fuse")); ASSERT_TRUE(fuse_statis.count("fc_fuse"));
EXPECT_EQ(fuse_statis.at("fc_fuse"), 317); EXPECT_EQ(fuse_statis.at("fc_fuse"), 45);
EXPECT_EQ(num_ops, 2020); EXPECT_EQ(num_ops, 292);
} }
// Compare result of NativeConfig and AnalysisConfig // Compare result of NativeConfig and AnalysisConfig
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册