From b65e947d404c33f4be997e5ae9321e4f0c84c19d Mon Sep 17 00:00:00 2001
From: xulongteng
Date: Wed, 18 Sep 2019 18:58:39 +0800
Subject: [PATCH] bert client and OP

---
 demo-client/src/bert_service.cpp    | 68 ++++++++++++++++++++---
 demo-serving/op/bert_service_op.cpp | 85 +++++++++++++++++++++++++++--
 2 files changed, 140 insertions(+), 13 deletions(-)

diff --git a/demo-client/src/bert_service.cpp b/demo-client/src/bert_service.cpp
index b519a432..c0b89deb 100644
--- a/demo-client/src/bert_service.cpp
+++ b/demo-client/src/bert_service.cpp
@@ -31,7 +31,7 @@ using baidu::paddle_serving::predictor::bert_service::BertResInstance;
 using baidu::paddle_serving::predictor::bert_service::BertReqInstance;
 using baidu::paddle_serving::predictor::bert_service::Embedding_values;
 
-int batch_size = 1;
+int batch_size = 49;
 int max_seq_len = 82;
 int layer_num = 12;
 int emb_size = 768;
@@ -95,7 +95,54 @@ int create_req(Request* req,
 }
 */
 
+int create_req(Request* req,
+               const std::vector<std::string>& data_list,
+               int data_index,
+               int batch_size) {
+  // add data
+  // avoid out of boundary
+  int cur_index = data_index;
+  if (cur_index >= data_list.size()) {
+    cur_index = cur_index % data_list.size();
+  }
+
+  std::vector<std::string> feature_list = split(data_list[cur_index], ";");
+
+  std::vector<std::string> src_field = split(feature_list[0], ":");
+  std::vector<std::string> src_ids = split(src_field[1], " ");
+
+  std::vector<std::string> pos_field = split(feature_list[1], ":");
+  std::vector<std::string> pos_ids = split(pos_field[1], " ");
+
+  std::vector<std::string> sent_field = split(feature_list[2], ":");
+  std::vector<std::string> sent_ids = split(sent_field[1], " ");
+
+  std::vector<std::string> mask_field = split(feature_list[3], ":");
+  std::vector<std::string> input_mask = split(mask_field[1], " ");
+
+  std::vector<int> shape;
+  std::vector<std::string> shapes = split(src_field[0], " ");
+  for (auto x : shapes) {
+    shape.push_back(std::stoi(x));
+  }
+
+  for (int i = 0; i < batch_size && i < shape[0]; ++i) {
+    BertReqInstance* ins = req->add_instances();
+    if (!ins) {
+      LOG(ERROR) << "Failed create req instance";
+      return -1;
+    }
+    for (int fi = 0; fi < max_seq_len; fi++) {
+      ins->add_token_ids(std::stoi(src_ids[i * max_seq_len + fi]));
+      ins->add_position_ids(std::stoi(pos_ids[i * max_seq_len + fi]));
+      ins->add_sentence_type_ids(std::stoi(sent_ids[i * max_seq_len + fi]));
+      ins->add_input_masks(std::stof(input_mask[i * max_seq_len + fi]));
+    }
+  }
+  return 0;
+}
+#if 0
 int create_req(Request* req,
                const std::vector<std::string>& data_list,
                int data_index,
                int batch_size) {
@@ -120,11 +167,11 @@ int create_req(Request* req,
   std::vector<std::string> seg_list = split(feature_list[3], " ");
   std::vector<std::string> mask_list = split(feature_list[4], " ");
   for (int fi = 0; fi < max_seq_len; fi++) {
-    if (fi < token_list.size()) {
-      ins->add_token_ids(std::stoi(token_list[fi]));
-      ins->add_sentence_type_ids(std::stoll(seg_list[fi]));
-      ins->add_position_ids(std::stoll(pos_list[fi]));
-      ins->add_input_masks(std::stof(mask_list[fi]));
+    if (fi < std::stoi(shape_list[1])) {
+      ins->add_token_ids(std::stoi(token_list[fi + (i * max_seq_len)]));
+      ins->add_sentence_type_ids(std::stoll(seg_list[fi + (i * max_seq_len)]));
+      ins->add_position_ids(std::stoll(pos_list[fi + (i * max_seq_len)]));
+      ins->add_input_masks(std::stof(mask_list[fi + (i * max_seq_len)]));
     } else {
       ins->add_token_ids(0);
       ins->add_sentence_type_ids(0);
@@ -135,6 +182,7 @@ int create_req(Request* req,
   }
   return 0;
 }
+#endif
 
 void print_res(const Request& req,
                const Response& res,
@@ -184,11 +232,17 @@ void thread_worker(PredictorApi* api,
     }
     g_concurrency++;
     LOG(INFO) << "Current concurrency " << g_concurrency.load();
+#if 0
     int data_index = turns * batch_size;
     if (create_req(&req, data_list, data_index, batch_size) != 0) {
       return;
     }
-    if (predictor->inference(&req, &res) != 0) {
+#else
+    if (create_req(&req, data_list, turns, batch_size) != 0) {
+      return;
+    }
+#endif
+    if (predictor->inference(&req, &res) != 0) {
       LOG(ERROR) << "failed call predictor with req:" << req.ShortDebugString();
       return;
     }
diff --git a/demo-serving/op/bert_service_op.cpp b/demo-serving/op/bert_service_op.cpp
index 5e0ba8ed..3ca21e64 100644
--- a/demo-serving/op/bert_service_op.cpp
+++ b/demo-serving/op/bert_service_op.cpp
@@ -17,6 +17,9 @@
 #include
 #include "predictor/framework/infer.h"
 #include "predictor/framework/memory.h"
+#if 1
+#include <sstream>
+#endif
 namespace baidu {
 namespace paddle_serving {
 namespace serving {
@@ -28,7 +31,7 @@ using baidu::paddle_serving::predictor::bert_service::BertReqInstance;
 using baidu::paddle_serving::predictor::bert_service::Request;
 using baidu::paddle_serving::predictor::bert_service::Embedding_values;
 
-const uint32_t MAX_SEQ_LEN = 64;
+const uint32_t MAX_SEQ_LEN = 82;
 const bool POOLING = true;
 const int LAYER_NUM = 12;
 const int EMB_SIZE = 768;
@@ -105,24 +108,51 @@ int BertServiceOp::inference() {
     index += MAX_SEQ_LEN;
   }
 
+#if 0
+  int64_t *src_data = static_cast<int64_t *>(src_ids.data.data());
+  std::ostringstream oss;
+  oss << "src_ids: ";
+  for (int i = 0; i < MAX_SEQ_LEN * batch_size; ++i) {
+    oss << src_data[i] << " ";
+  }
+  LOG(INFO) << oss.str();
+
+#endif
   in->push_back(src_ids);
   in->push_back(pos_ids);
   in->push_back(seg_ids);
   in->push_back(input_masks);
 
   TensorVector *out = butil::get_object<TensorVector>();
+// TensorVector out;
+/*
   if (!out) {
     LOG(ERROR) << "Failed get tls output object";
     return -1;
   }
-
+*/
   LOG(INFO) << "batch_size : " << batch_size;
-  LOG(INFO) << "MAX_SEQ_LEN : " << (*in)[0].shape[1];
-  float* example = (float*)(*in)[3].data.data();
-  for(uint32_t i = 0; i < MAX_SEQ_LEN; i++){
-    LOG(INFO) << *(example + i);
+  for (int j = 0; j < 3; j++) {
+    LOG(INFO) << "name : " << (*in)[j].name << " shape : " << (*in)[j].shape[0]
+              << " " << (*in)[j].shape[1] << " " << (*in)[j].shape[2];
+    int64_t* example = (int64_t*)(*in)[j].data.data();
+    std::ostringstream oss;
+    for (uint32_t i = MAX_SEQ_LEN * (batch_size - 1); i < MAX_SEQ_LEN * batch_size; i++) {
+      oss << *(example + i) << " ";
+    }
+    LOG(INFO) << "data : " << oss.str();
   }
+  for (int j = 3; j < 4; j++) {
+    LOG(INFO) << "name : " << (*in)[j].name << " shape : " << (*in)[j].shape[0]
+              << " " << (*in)[j].shape[1] << " " << (*in)[j].shape[2];
+    float* example = (float*)(*in)[j].data.data();
+    std::ostringstream oss;
+    for (uint32_t i = MAX_SEQ_LEN * (batch_size - 1); i < MAX_SEQ_LEN * batch_size; i++) {
+      oss << *(example + i) << " ";
+    }
+    LOG(INFO) << "data : " << oss.str();
+  }
 
   if (predictor::InferManager::instance().infer(
       BERT_MODEL_NAME, in, out, batch_size)) {
@@ -130,6 +160,13 @@
     return -1;
   }
 
+/*
+  paddle::NativeConfig config;
+  config.model_dir = "./data/model/paddle/fluid/bert";
+  auto predictor = CreatePaddlePredictor(config);
+  predictor->Run(*in, &out);
+*/
+#if 0
   //  float *out_data = static_cast<float *>(out->at(0).data.data());
   LOG(INFO) << "check point";
   /*
@@ -160,6 +197,42 @@
   out->clear();
   butil::return_object(out);
   */
+
+#else
+  float *out_data = static_cast<float *>(out->at(0).data.data());
+  std::ostringstream oss;
+  oss << "Shape: [";
+
+  for (auto x : out->at(0).shape) {
+    oss << x << " ";
+  }
+  oss << "]";
+
+  LOG(INFO) << oss.str();
+
+  // Output shape is [batch_size x 3]
+  for (uint32_t bi = 0; bi < batch_size; bi++) {
+    BertResInstance *res_instance = res->add_instances();
+    std::ostringstream oss;
+    oss << "Sample " << bi << " [";
+    oss << out_data[bi * 3 + 0] << " "
+        << out_data[bi * 3 + 1] << " "
+        << out_data[bi * 3 + 2] << "]";
+    LOG(INFO) << oss.str();
+  }
+
+  for (size_t i = 0; i < in->size(); ++i) {
+    (*in)[i].shape.clear();
+  }
+  in->clear();
+  butil::return_object(in);
+
+  for (size_t i = 0; i < out->size(); ++i) {
+    (*out)[i].shape.clear();
+  }
+  out->clear();
+  butil::return_object(out);
+#endif
 
   return 0;
 }
-- 
GitLab
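A minimal sketch, not part of the patch above: it reconstructs the input line format that the new client-side create_req() appears to assume, namely four ";"-separated fields (src_ids, pos_ids, sent_ids, input_mask), each stored as "<shape>:<space-separated values>" with the values flattened as batch_size * max_seq_len. The split() helper below is a stand-in for the demo client's function of the same name, and the sizes and ids are toy values.

#include <iostream>
#include <sstream>
#include <string>
#include <vector>

// Stand-in for the demo client's split(): cut `s` on every occurrence of `delim`.
static std::vector<std::string> split(const std::string& s, const std::string& delim) {
  std::vector<std::string> out;
  size_t start = 0, pos;
  while ((pos = s.find(delim, start)) != std::string::npos) {
    out.push_back(s.substr(start, pos - start));
    start = pos + delim.size();
  }
  out.push_back(s.substr(start));
  return out;
}

int main() {
  const int batch_size = 2, max_seq_len = 3;  // toy sizes; the patch uses 49 and 82
  // One data line: "<shape>:<values>;..." for src, pos, sent and mask, each field
  // holding batch_size * max_seq_len values.
  std::string line =
      "2 3:101 7 102 101 8 102;"  // src_ids (token ids)
      "2 3:0 1 2 0 1 2;"          // pos_ids
      "2 3:0 0 0 0 0 0;"          // sent_ids
      "2 3:1 1 1 1 1 1";          // input_mask

  std::vector<std::string> feature_list = split(line, ";");
  std::vector<std::string> src_field = split(feature_list[0], ":");
  std::vector<std::string> src_ids = split(src_field[1], " ");
  std::vector<std::string> shapes = split(src_field[0], " ");
  int data_batch = std::stoi(shapes[0]);  // create_req caps the batch at shape[0]

  // Same indexing as create_req: instance i, position fi -> i * max_seq_len + fi.
  for (int i = 0; i < batch_size && i < data_batch; ++i) {
    std::ostringstream oss;
    for (int fi = 0; fi < max_seq_len; fi++) {
      oss << std::stoi(src_ids[i * max_seq_len + fi]) << " ";
    }
    std::cout << "instance " << i << " token_ids: " << oss.str() << "\n";
  }
  return 0;
}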
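Another minimal sketch, also not part of the patch: the debug loops added in BertServiceOp::inference() log only the last instance of each input tensor, i.e. the flat index range [MAX_SEQ_LEN * (batch_size - 1), MAX_SEQ_LEN * batch_size). Plain buffers stand in for the paddle TensorVector here, and the sizes are toy values.

#include <cstdint>
#include <iostream>
#include <sstream>
#include <vector>

int main() {
  const uint32_t MAX_SEQ_LEN = 4;  // toy value; the patch raises the real constant to 82
  const uint32_t batch_size = 3;

  // Fake flattened [batch_size x MAX_SEQ_LEN] int64 buffer, like src_ids.data.
  std::vector<int64_t> example(batch_size * MAX_SEQ_LEN);
  for (size_t i = 0; i < example.size(); ++i) {
    example[i] = static_cast<int64_t>(i);
  }

  // Mirror of the added logging loop: walk only the slice owned by the last instance.
  std::ostringstream oss;
  for (uint32_t i = MAX_SEQ_LEN * (batch_size - 1); i < MAX_SEQ_LEN * batch_size; i++) {
    oss << example[i] << " ";
  }
  std::cout << "data : " << oss.str() << "\n";  // prints "data : 8 9 10 11 "
  return 0;
}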