Unverified commit 245b1f05, authored by Tao Luo, committed by GitHub

Merge pull request #15570 from luotao1/bert

fix compiler error, use len20 dataset for bert
......
@@ -128,10 +128,10 @@ inference_analysis_api_test_with_fake_data(test_analyzer_resnet50
 inference_analysis_api_test_with_fake_data(test_analyzer_mobilenet_depthwise_conv
         "${INFERENCE_DEMO_INSTALL_DIR}/mobilenet_depthwise_conv" analyzer_resnet50_tester.cc "mobilenet_model.tar.gz" SERIAL)
-# bert
-set(BERT_INSTALL_DIR "${INFERENCE_DEMO_INSTALL_DIR}/bert")
-download_model_and_data(${BERT_INSTALL_DIR} "bert_model.tar.gz" "bert_data.txt.tar.gz")
-inference_analysis_api_test(test_analyzer_bert ${BERT_INSTALL_DIR} analyzer_bert_tester.cc)
+# bert, max_len=20
+set(BERT_INSTALL_DIR "${INFERENCE_DEMO_INSTALL_DIR}/bert20")
+download_model_and_data(${BERT_INSTALL_DIR} "bert_model.tar.gz" "bert_data_len20.txt.tar.gz")
+inference_analysis_api_test(test_analyzer_bert ${BERT_INSTALL_DIR} analyzer_bert_tester.cc SERIAL)
 # anakin
 if (WITH_ANAKIN AND WITH_MKL) # only needed in CI
......
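For anyone re-running the new `bert20` target locally: the tester is driven by command-line flags rather than hard-coded paths (`FLAGS_infer_model` appears in the hunks below). A minimal, self-contained sketch of that flag plumbing follows, assuming a companion `infer_data` flag as in the other analyzer testers; this is an illustration, not the actual Paddle test harness.

```cpp
// Sketch only: mirrors the gflags-style flags the tester appears to consume.
// FLAGS_infer_model is visible in the diff below; infer_data is an assumption.
#include <gflags/gflags.h>
#include <iostream>
#include <string>

DEFINE_string(infer_model, "", "Directory containing the downloaded BERT model.");
DEFINE_string(infer_data, "", "Path to the input data file (e.g. the len-20 dataset).");

int main(int argc, char *argv[]) {
  gflags::ParseCommandLineFlags(&argc, &argv, /*remove_flags=*/true);
  std::cout << "model dir: " << FLAGS_infer_model << "\n"
            << "data file: " << FLAGS_infer_data << std::endl;
  return 0;
}
```

Running the built test binary with `--infer_model=...` and `--infer_data=...` pointing at the contents of `BERT_INSTALL_DIR` (exact layout depends on the downloaded archives) reproduces what the CMake macro wires up on CI.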
......
@@ -18,7 +18,6 @@ namespace paddle {
 namespace inference {
 using paddle::PaddleTensor;
-using paddle::contrib::AnalysisConfig;
 
 template <typename T>
 void GetValueFromStream(std::stringstream *ss, T *t) {
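The hunk above stops at the helper's signature. A plausible minimal body, assuming `GetValueFromStream` simply extracts one whitespace-delimited value with `operator>>` (a sketch of the likely intent, not necessarily the file's exact code):

```cpp
#include <sstream>
#include <string>

// Generic extraction helper: read one whitespace-delimited token from *ss into *t.
template <typename T>
void GetValueFromStream(std::stringstream *ss, T *t) {
  (*ss) >> (*t);
}

// Example: parsing "128 0.5" into an int and a float.
// std::stringstream ss("128 0.5"); int n; float x;
// GetValueFromStream(&ss, &n); GetValueFromStream(&ss, &x);
```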
......
@@ -158,12 +157,10 @@ bool LoadInputData(std::vector<std::vector<paddle::PaddleTensor>> *inputs) {
   return true;
 }
 
-void SetConfig(contrib::AnalysisConfig *config) {
-  config->SetModel(FLAGS_infer_model);
-}
+void SetConfig(AnalysisConfig *config) { config->SetModel(FLAGS_infer_model); }
 
 void profile(bool use_mkldnn = false) {
-  contrib::AnalysisConfig config;
+  AnalysisConfig config;
   SetConfig(&config);
 
   if (use_mkldnn) {
......
@@ -213,17 +210,14 @@ TEST(Analyzer_bert, compare_mkldnn) { compare(true /* use_mkldnn */); }
 #endif
 // Compare Deterministic result
-// TODO(luotao): Since each unit-test on CI only have 10 minutes, cancel this to
-// decrease the CI time.
-// TEST(Analyzer_bert, compare_determine) {
-//   AnalysisConfig cfg;
-//   SetConfig(&cfg);
-//
-//   std::vector<std::vector<PaddleTensor>> inputs;
-//   LoadInputData(&inputs);
-//   CompareDeterministic(reinterpret_cast<const PaddlePredictor::Config
-//   *>(&cfg),
-//                        inputs);
-// }
+TEST(Analyzer_bert, compare_determine) {
+  AnalysisConfig cfg;
+  SetConfig(&cfg);
+
+  std::vector<std::vector<PaddleTensor>> inputs;
+  LoadInputData(&inputs);
+  CompareDeterministic(reinterpret_cast<const PaddlePredictor::Config *>(&cfg),
+                       inputs);
+}
 
 }  // namespace inference
 }  // namespace paddle
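The re-enabled `compare_determine` test relies on `CompareDeterministic` from the shared tester helpers. Roughly, such a check runs identical inputs through the predictor several times and requires identical outputs; the stripped-down sketch below only illustrates that comparison logic (the runner type and function name are illustrative, not Paddle's helper):

```cpp
#include <functional>
#include <vector>

// Sketch of a determinism check: feed the same batches to `run` repeatedly and
// verify the outputs never change between repetitions.
bool OutputsAreDeterministic(
    const std::function<std::vector<float>(const std::vector<float> &)> &run,
    const std::vector<std::vector<float>> &inputs, int num_repeats = 3) {
  for (const auto &input : inputs) {
    const std::vector<float> reference = run(input);
    for (int i = 1; i < num_repeats; ++i) {
      if (run(input) != reference) return false;  // output drifted across runs
    }
  }
  return true;
}
```

In the real test the "runner" is presumably a predictor built from the `AnalysisConfig` above, and the compared outputs are `PaddleTensor` buffers rather than plain float vectors.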
......
@@ -20,7 +20,6 @@ namespace paddle {
 namespace inference {
 using namespace framework;  // NOLINT
-using namespace contrib;    // NOLINT
 
 struct DataRecord {
   std::vector<std::vector<std::vector<float>>> link_step_data_all;
......