diff --git a/paddle/fluid/inference/tests/api/CMakeLists.txt b/paddle/fluid/inference/tests/api/CMakeLists.txt index 423c39813f05af0d6aaade184914e6777c9b8a83..fa2e19bc4c3b80f58e5d6c63eef4e7592b0cbfea 100644 --- a/paddle/fluid/inference/tests/api/CMakeLists.txt +++ b/paddle/fluid/inference/tests/api/CMakeLists.txt @@ -115,6 +115,11 @@ if (NOT EXISTS ${MOBILENET_INSTALL_DIR}) endif() inference_analysis_api_test_with_refer_result(test_analyzer_mobilenet_transpose ${MOBILENET_INSTALL_DIR} analyzer_vis_tester.cc SERIAL) +# bert +set(BERT_INSTALL_DIR "${INFERENCE_DEMO_INSTALL_DIR}/bert") +download_model_and_data(${BERT_INSTALL_DIR} "bert_model.tar.gz" "bert_data.txt.tar.gz") +inference_analysis_api_test(test_analyzer_bert ${BERT_INSTALL_DIR} analyzer_bert_tester.cc) + # resnet50 inference_analysis_api_test_with_fake_data(test_analyzer_resnet50 "${INFERENCE_DEMO_INSTALL_DIR}/resnet50" analyzer_resnet50_tester.cc "resnet50_model.tar.gz" SERIAL) diff --git a/paddle/fluid/inference/tests/api/analyzer_bert_tester.cc b/paddle/fluid/inference/tests/api/analyzer_bert_tester.cc new file mode 100644 index 0000000000000000000000000000000000000000..709d51388d9d98b5642cdb716f29f2e0a622651e --- /dev/null +++ b/paddle/fluid/inference/tests/api/analyzer_bert_tester.cc @@ -0,0 +1,217 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include +#include +#include +#include +#include +#include "paddle/fluid/inference/api/paddle_inference_api.h" + +DEFINE_int32(repeat, 1, "repeat"); + +namespace paddle { +namespace inference { + +using paddle::PaddleTensor; +using paddle::contrib::AnalysisConfig; + +template +void GetValueFromStream(std::stringstream *ss, T *t) { + (*ss) >> (*t); +} + +template <> +void GetValueFromStream(std::stringstream *ss, std::string *t) { + *t = ss->str(); +} + +// Split string to vector +template +void Split(const std::string &line, char sep, std::vector *v) { + std::stringstream ss; + T t; + for (auto c : line) { + if (c != sep) { + ss << c; + } else { + GetValueFromStream(&ss, &t); + v->push_back(std::move(t)); + ss.str({}); + ss.clear(); + } + } + + if (!ss.str().empty()) { + GetValueFromStream(&ss, &t); + v->push_back(std::move(t)); + ss.str({}); + ss.clear(); + } +} + +template +constexpr paddle::PaddleDType GetPaddleDType(); + +template <> +constexpr paddle::PaddleDType GetPaddleDType() { + return paddle::PaddleDType::INT64; +} + +template <> +constexpr paddle::PaddleDType GetPaddleDType() { + return paddle::PaddleDType::FLOAT32; +} + +// Parse tensor from string +template +bool ParseTensor(const std::string &field, paddle::PaddleTensor *tensor) { + std::vector data; + Split(field, ':', &data); + if (data.size() < 2) return false; + + std::string shape_str = data[0]; + + std::vector shape; + Split(shape_str, ' ', &shape); + + std::string mat_str = data[1]; + + std::vector mat; + Split(mat_str, ' ', &mat); + + tensor->shape = shape; + auto size = + std::accumulate(shape.begin(), shape.end(), 1, std::multiplies()) * + sizeof(T); + tensor->data.Resize(size); + std::copy(mat.begin(), mat.end(), static_cast(tensor->data.data())); + tensor->dtype = GetPaddleDType(); + + return true; +} + +// Parse input tensors from string +bool ParseLine(const std::string &line, + std::vector *tensors) { + std::vector fields; + Split(line, ';', &fields); + + if (fields.size() < 5) return false; + + tensors->clear(); + tensors->reserve(5); + + int i = 0; + // src_id + paddle::PaddleTensor src_id; + ParseTensor(fields[i++], &src_id); + tensors->push_back(src_id); + + // pos_id + paddle::PaddleTensor pos_id; + ParseTensor(fields[i++], &pos_id); + tensors->push_back(pos_id); + + // segment_id + paddle::PaddleTensor segment_id; + ParseTensor(fields[i++], &segment_id); + tensors->push_back(segment_id); + + // self_attention_bias + paddle::PaddleTensor self_attention_bias; + ParseTensor(fields[i++], &self_attention_bias); + tensors->push_back(self_attention_bias); + + // next_segment_index + paddle::PaddleTensor next_segment_index; + ParseTensor(fields[i++], &next_segment_index); + tensors->push_back(next_segment_index); + + return true; +} + +// Print outputs to log +void PrintOutputs(const std::vector &outputs) { + LOG(INFO) << "example_id\tcontradiction\tentailment\tneutral"; + + for (size_t i = 0; i < outputs.front().data.length(); i += 3) { + LOG(INFO) << (i / 3) << "\t" + << static_cast(outputs.front().data.data())[i] << "\t" + << static_cast(outputs.front().data.data())[i + 1] + << "\t" + << static_cast(outputs.front().data.data())[i + 2]; + } +} + +bool LoadInputData(std::vector> *inputs) { + if (FLAGS_infer_data.empty()) { + LOG(ERROR) << "please set input data path"; + return false; + } + + std::ifstream fin(FLAGS_infer_data); + std::string line; + + int lineno = 0; + while (std::getline(fin, line)) { + std::vector feed_data; + if (!ParseLine(line, &feed_data)) { + LOG(ERROR) << "Parse line[" << lineno << "] error!"; + } else { + inputs->push_back(std::move(feed_data)); + } + } + + return true; +} + +void SetConfig(contrib::AnalysisConfig *config) { + config->SetModel(FLAGS_infer_model); +} + +void profile(bool use_mkldnn = false) { + contrib::AnalysisConfig config; + SetConfig(&config); + + if (use_mkldnn) { + config.EnableMKLDNN(); + } + + std::vector outputs; + std::vector> inputs; + LoadInputData(&inputs); + TestPrediction(reinterpret_cast(&config), + inputs, &outputs, FLAGS_num_threads); +} + +void compare(bool use_mkldnn = false) { + AnalysisConfig config; + SetConfig(&config); + + std::vector> inputs; + LoadInputData(&inputs); + CompareNativeAndAnalysis( + reinterpret_cast(&config), inputs); +} + +TEST(Analyzer_bert, profile) { profile(); } +#ifdef PADDLE_WITH_MKLDNN +TEST(Analyzer_bert, profile_mkldnn) { profile(true); } +#endif +} // namespace inference +} // namespace paddle