Unverified commit 6de0a18d, authored by Yan Chunwei, committed by GitHub

Refine/text classification support data (#13256)

Parent: 11b22883
@@ -100,12 +100,17 @@ inference_analysis_test(test_analyzer_lac SRCS analyzer_lac_tester.cc
 set(TEXT_CLASSIFICATION_MODEL_URL "http://paddle-inference-dist.bj.bcebos.com/text-classification-Senta.tar.gz")
+set(TEXT_CLASSIFICATION_DATA_URL "http://paddle-inference-dist.bj.bcebos.com/text_classification_data.txt.tar.gz")
 set(TEXT_CLASSIFICATION_INSTALL_DIR "${THIRD_PARTY_PATH}/inference_demo/text_classification" CACHE PATH "Text Classification model and data root." FORCE)
 if (NOT EXISTS ${TEXT_CLASSIFICATION_INSTALL_DIR} AND WITH_TESTING AND WITH_INFERENCE)
   inference_download_and_uncompress(${TEXT_CLASSIFICATION_INSTALL_DIR} ${TEXT_CLASSIFICATION_MODEL_URL} "text-classification-Senta.tar.gz")
+  inference_download_and_uncompress(${TEXT_CLASSIFICATION_INSTALL_DIR} ${TEXT_CLASSIFICATION_DATA_URL} "text_classification_data.txt.tar.gz")
 endif()
 inference_analysis_test(test_text_classification SRCS analyzer_text_classification_tester.cc
   EXTRA_DEPS paddle_inference_api paddle_fluid_api analysis_predictor
-  ARGS --infer_model=${TEXT_CLASSIFICATION_INSTALL_DIR}/text-classification-Senta)
+  ARGS --infer_model=${TEXT_CLASSIFICATION_INSTALL_DIR}/text-classification-Senta
+       --infer_data=${TEXT_CLASSIFICATION_INSTALL_DIR}/data.txt
+       --topn=1  # Just run top 1 batch.
+       )
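
A note on the new test input: judging from the reader code below, the downloaded data file holds one text sample per line, encoded as space-separated int64 word ids (NextBatch parses each line with inference::split_to_int64). A minimal sketch of a file in that format follows; the word ids are invented placeholders, not real Senta vocabulary entries.

// Sketch: emit a two-line data file in the format the tester consumes.
// The word ids are made-up placeholders for illustration only.
#include <fstream>

int main() {
  std::ofstream out("data.txt");
  out << "283 62 14 7\n";  // sample 1: a sequence of four word ids
  out << "9 1024 33\n";    // sample 2: a sequence of three word ids
  return 0;
}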
@@ -16,8 +16,10 @@
 #include <gflags/gflags.h>
 #include <glog/logging.h>  // use glog instead of PADDLE_ENFORCE to avoid importing other paddle header files.
 #include <gtest/gtest.h>
+#include <fstream>
 #include "paddle/fluid/framework/ir/pass.h"
 #include "paddle/fluid/inference/analysis/ut_helper.h"
+#include "paddle/fluid/inference/api/helper.h"
 #include "paddle/fluid/inference/api/paddle_inference_api.h"
 #include "paddle/fluid/inference/api/paddle_inference_pass.h"
 #include "paddle/fluid/inference/api/timer.h"
@@ -26,6 +28,7 @@ DEFINE_string(infer_model, "", "Directory of the inference model.");
 DEFINE_string(infer_data, "", "Path of the dataset.");
 DEFINE_int32(batch_size, 1, "batch size.");
 DEFINE_int32(repeat, 1, "How many times to repeat run.");
+DEFINE_int32(topn, -1, "Run top n batches of data to save time");
 
 namespace paddle {
@@ -45,41 +48,67 @@ void PrintTime(const double latency, const int bs, const int repeat) {
   LOG(INFO) << "=====================================";
 }
 
-void Main(int batch_size) {
-  // Three sequence inputs.
-  std::vector<PaddleTensor> input_slots(1);
-  // one batch starts
-  // data --
-  int64_t data0[] = {0, 1, 2};
-  for (auto &input : input_slots) {
-    input.data.Reset(data0, sizeof(data0));
-    input.shape = std::vector<int>({3, 1});
-    // dtype --
-    input.dtype = PaddleDType::INT64;
-    // LoD --
-    input.lod = std::vector<std::vector<size_t>>({{0, 3}});
-  }
+struct DataReader {
+  DataReader(const std::string &path) : file(new std::ifstream(path)) {}
+
+  bool NextBatch(PaddleTensor *tensor, int batch_size) {
+    PADDLE_ENFORCE_EQ(batch_size, 1);
+    std::string line;
+    tensor->lod.clear();
+    tensor->lod.emplace_back(std::vector<size_t>({0}));
+    std::vector<int64_t> data;
+
+    for (int i = 0; i < batch_size; i++) {
+      if (!std::getline(*file, line)) return false;
+      inference::split_to_int64(line, ' ', &data);
+    }
+    tensor->lod.front().push_back(data.size());
+    tensor->data.Resize(data.size() * sizeof(int64_t));
+    memcpy(tensor->data.data(), data.data(), data.size() * sizeof(int64_t));
+    tensor->shape.clear();
+    tensor->shape.push_back(data.size());
+    tensor->shape.push_back(1);
+    return true;
+  }
+
+  std::unique_ptr<std::ifstream> file;
+};
+
+void Main(int batch_size) {
   // shape --
   // Create Predictor --
   AnalysisConfig config;
   config.model_dir = FLAGS_infer_model;
   config.use_gpu = false;
   config.enable_ir_optim = true;
-  config.ir_passes.push_back("fc_lstm_fuse_pass");
 
   auto predictor =
       CreatePaddlePredictor<AnalysisConfig, PaddleEngineKind::kAnalysis>(
           config);
+
+  std::vector<PaddleTensor> input_slots(1);
+  // one batch starts
+  // data --
+  auto &input = input_slots[0];
+  input.dtype = PaddleDType::INT64;
 
   inference::Timer timer;
   double sum = 0;
   std::vector<PaddleTensor> output_slots;
-  for (int i = 0; i < FLAGS_repeat; i++) {
-    timer.tic();
-    CHECK(predictor->Run(input_slots, &output_slots));
-    sum += timer.toc();
+
+  int num_batches = 0;
+  for (int t = 0; t < FLAGS_repeat; t++) {
+    DataReader reader(FLAGS_infer_data);
+    while (reader.NextBatch(&input, FLAGS_batch_size)) {
+      if (FLAGS_topn > 0 && num_batches > FLAGS_topn) break;
+      timer.tic();
+      CHECK(predictor->Run(input_slots, &output_slots));
+      sum += timer.toc();
+      ++num_batches;
+    }
   }
-  PrintTime(sum, batch_size, FLAGS_repeat);
+
+  PrintTime(sum, batch_size, num_batches);
+
   // Get output
   LOG(INFO) << "get outputs " << output_slots.size();
...
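
For context on the LoD bookkeeping in NextBatch above: Paddle packs variable-length sequences into one flat tensor, with lod recording cumulative sequence offsets, so lod = {{0, n}} and shape = {n, 1} describe a single sequence of n word ids. A self-contained sketch of the same packing for a multi-sequence batch, using plain STL types in place of PaddleTensor:

#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
  // Two sequences of word ids, lengths 3 and 2.
  const std::vector<std::vector<int64_t>> batch = {{0, 1, 2}, {7, 8}};

  std::vector<int64_t> flat;      // packed data; logical shape {5, 1}
  std::vector<size_t> lod = {0};  // cumulative offsets, one LoD level
  for (const auto &seq : batch) {
    flat.insert(flat.end(), seq.begin(), seq.end());
    lod.push_back(flat.size());
  }
  // lod is now {0, 3, 5}: sequence i occupies rows [lod[i], lod[i+1]).
  for (size_t off : lod) std::printf("%zu ", off);
  std::printf("\n");
  return 0;
}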
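
One behavioral detail worth flagging in the new timing loop: the guard breaks only once num_batches exceeds FLAGS_topn, and the counter is incremented after each run, so --topn=1 actually times two batches, despite the "Just run top 1 batch" comment in the CMake ARGS. A small standalone check of that arithmetic (not the committed code, just the same condition in isolation):

#include <cstdio>

int main() {
  const int topn = 1;  // mirrors --topn=1 from the CMake ARGS above
  int num_batches = 0, runs = 0;
  for (int i = 0; i < 10; i++) {
    if (topn > 0 && num_batches > topn) break;  // committed condition
    ++runs;
    ++num_batches;
  }
  std::printf("batches timed: %d\n", runs);  // prints 2, not 1
  return 0;
}

If exactly topn batches were intended, comparing with num_batches >= FLAGS_topn would stop after one.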