提交 8f2e556e 编写于 作者: Z ZhenWang

support the small dam model. test=develop

上级 20120d9c
......@@ -48,7 +48,9 @@ inference_analysis_api_test(test_analyzer_rnn2 ${RNN2_INSTALL_DIR} analyzer_rnn2
# DAM
set(DAM_INSTALL_DIR "${INFERENCE_DEMO_INSTALL_DIR}/dam")
download_model_and_data(${DAM_INSTALL_DIR} "DAM_model.tar.gz" "DAM_data.txt.tar.gz")
# For the normal DAM model
# download_model_and_data(${DAM_INSTALL_DIR} "DAM_model.tar.gz" "DAM_data.txt.tar.gz")
download_model_and_data(${DAM_INSTALL_DIR} "small_dam_model.tar.gz" "small_dam_data.txt.tar.gz")
inference_analysis_api_test(test_analyzer_dam ${DAM_INSTALL_DIR} analyzer_dam_tester.cc)
# chinese_ner
......
......@@ -17,7 +17,7 @@
namespace paddle {
namespace inference {
using contrib::AnalysisConfig;
#define MAX_TURN_NUM 9
#define MAX_TURN_NUM 1
#define MAX_TURN_LEN 50
static std::vector<float> result_data;
......@@ -148,8 +148,7 @@ void PrepareInputs(std::vector<PaddleTensor> *input_slots, DataRecord *data,
}
void SetConfig(contrib::AnalysisConfig *cfg) {
cfg->prog_file = FLAGS_infer_model + "/__model__";
cfg->param_file = FLAGS_infer_model + "/param";
cfg->model_dir = FLAGS_infer_model;
cfg->use_gpu = false;
cfg->device = 0;
cfg->specify_input_name = true;
......@@ -202,8 +201,8 @@ TEST(Analyzer_dam, fuse_statis) {
auto fuse_statis = GetFuseStatis(
static_cast<AnalysisPredictor *>(predictor.get()), &num_ops);
ASSERT_TRUE(fuse_statis.count("fc_fuse"));
EXPECT_EQ(fuse_statis.at("fc_fuse"), 317);
EXPECT_EQ(num_ops, 2020);
EXPECT_EQ(fuse_statis.at("fc_fuse"), 45);
EXPECT_EQ(num_ops, 292);
}
// Compare result of NativeConfig and AnalysisConfig
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册