Commit b65e947d authored by xulongteng

bert client and OP

Parent a8985b1b
@@ -31,7 +31,7 @@ using baidu::paddle_serving::predictor::bert_service::BertResInstance;
using baidu::paddle_serving::predictor::bert_service::BertReqInstance;
using baidu::paddle_serving::predictor::bert_service::Embedding_values;
-int batch_size = 1;
+int batch_size = 49;
int max_seq_len = 82;
int layer_num = 12;
int emb_size = 768;
@@ -95,7 +95,54 @@ int create_req(Request* req,
}
*/
int create_req(Request* req,
const std::vector<std::string>& data_list,
int data_index,
int batch_size) {
// add data
// avoid out of boundary
int cur_index = data_index;
if (cur_index >= data_list.size()) {
cur_index = cur_index % data_list.size();
}
std::vector<std::string> feature_list = split(data_list[cur_index], ";");
std::vector<std::string> src_field = split(feature_list[0], ":");
std::vector<std::string> src_ids = split(src_field[1], " ");
std::vector<std::string> pos_field = split(feature_list[1], ":");
std::vector<std::string> pos_ids = split(pos_field[1], " ");
std::vector<std::string> sent_field = split(feature_list[2], ":");
std::vector<std::string> sent_ids = split(sent_field[1], " ");
std::vector<std::string> mask_field = split(feature_list[3], ":");
std::vector<std::string> input_mask = split(mask_field[1], " ");
std::vector<int> shape;
std::vector<std::string> shapes = split(src_field[0], " ");
for (auto x: shapes) {
shape.push_back(std::stoi(x));
}
for (int i = 0; i < batch_size && i < shape[0]; ++i) {
BertReqInstance* ins = req->add_instances();
if (!ins) {
LOG(ERROR) << "Failed create req instance";
return -1;
}
for (int fi = 0; fi < max_seq_len; fi++) {
ins->add_token_ids(std::stoi(src_ids[i * max_seq_len + fi]));
ins->add_position_ids(std::stoi(pos_ids[i * max_seq_len + fi]));
ins->add_sentence_type_ids(std::stoi(sent_ids[i * max_seq_len + fi]));
ins->add_input_masks(std::stof(input_mask[i * max_seq_len + fi]));
}
}
return 0;
}
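For readers following the new create_req: each entry of data_list is expected to hold four ":"-separated fields joined by ";", with the left half of the first field carrying the "batch seq_len" shape that drives the loop bounds. A hypothetical record (values invented, and seq_len shortened from 82 to 4 for readability) would look like:

1 4:101 2054 3793 102;pos:0 1 2 3;sent:0 0 0 0;mask:1 1 1 1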
#if 0
int create_req(Request* req,
const std::vector<std::string>& data_list,
int data_index,
@@ -120,11 +167,11 @@ int create_req(Request* req,
std::vector<std::string> seg_list = split(feature_list[3], " ");
std::vector<std::string> mask_list = split(feature_list[4], " ");
for (int fi = 0; fi < max_seq_len; fi++) {
-if (fi < token_list.size()) {
-ins->add_token_ids(std::stoi(token_list[fi]));
-ins->add_sentence_type_ids(std::stoll(seg_list[fi]));
-ins->add_position_ids(std::stoll(pos_list[fi]));
-ins->add_input_masks(std::stof(mask_list[fi]));
+if (fi < std::stoi(shape_list[1])) {
+ins->add_token_ids(std::stoi(token_list[fi + (i * max_seq_len)]));
+ins->add_sentence_type_ids(std::stoll(seg_list[fi + (i * max_seq_len)]));
+ins->add_position_ids(std::stoll(pos_list[fi + (i * max_seq_len)]));
+ins->add_input_masks(std::stof(mask_list[fi + (i * max_seq_len)]));
} else {
ins->add_token_ids(0);
ins->add_sentence_type_ids(0);
@@ -135,6 +182,7 @@ int create_req(Request* req,
}
return 0;
}
#endif
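The split helper used throughout the client is not part of this diff; a minimal sketch consistent with how it is called above (exact-delimiter split, tokens returned in order) might be:

#include <string>
#include <vector>

// Hypothetical stand-in for the client's split(): cuts `str` on every
// occurrence of the full delimiter `sep`. The id fields above are assumed
// to be single-space separated; repeated spaces would yield empty tokens
// and make std::stoi throw.
std::vector<std::string> split(const std::string& str, const std::string& sep) {
  std::vector<std::string> pieces;
  size_t start = 0;
  size_t pos;
  while ((pos = str.find(sep, start)) != std::string::npos) {
    pieces.push_back(str.substr(start, pos - start));
    start = pos + sep.size();
  }
  pieces.push_back(str.substr(start));
  return pieces;
}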
void print_res(const Request& req,
const Response& res,
@@ -184,10 +232,16 @@ void thread_worker(PredictorApi* api,
}
g_concurrency++;
LOG(INFO) << "Current concurrency " << g_concurrency.load();
#if 0
int data_index = turns * batch_size;
if (create_req(&req, data_list, data_index, batch_size) != 0) {
return;
}
#else
if (create_req(&req, data_list, turns, batch_size) != 0) {
return;
}
#endif
if (predictor->inference(&req, &res) != 0) {
LOG(ERROR) << "failed call predictor with req:" << req.ShortDebugString();
return;
@@ -17,6 +17,9 @@
#include <string>
#include "predictor/framework/infer.h"
#include "predictor/framework/memory.h"
#if 1
#include <sstream>
#endif
namespace baidu {
namespace paddle_serving {
namespace serving {
@@ -28,7 +31,7 @@ using baidu::paddle_serving::predictor::bert_service::BertReqInstance;
using baidu::paddle_serving::predictor::bert_service::Request;
using baidu::paddle_serving::predictor::bert_service::Embedding_values;
-const uint32_t MAX_SEQ_LEN = 64;
+const uint32_t MAX_SEQ_LEN = 82;
const bool POOLING = true;
const int LAYER_NUM = 12;
const int EMB_SIZE = 768;
@@ -105,24 +108,51 @@ int BertServiceOp::inference() {
index += MAX_SEQ_LEN;
}
#if 0
int64_t *src_data = static_cast<int64_t *>(src_ids.data.data());
std::ostringstream oss;
oss << "src_ids: ";
for (int i = 0; i < MAX_SEQ_LEN * batch_size; ++i) {
oss << src_data[i] << " ";
}
LOG(INFO) << oss.str();
#endif
in->push_back(src_ids);
in->push_back(pos_ids);
in->push_back(seg_ids);
in->push_back(input_masks);
TensorVector *out = butil::get_object<TensorVector>();
// TensorVector out;
/*
if (!out) {
LOG(ERROR) << "Failed get tls output object";
return -1;
}
*/
LOG(INFO) << "batch_size : " << batch_size;
LOG(INFO) << "MAX_SEQ_LEN : " << (*in)[0].shape[1];
-float* example = (float*)(*in)[3].data.data();
-for(uint32_t i = 0; i < MAX_SEQ_LEN; i++){
-LOG(INFO) << *(example + i);
for (int j = 0; j < 3; j ++) {
LOG(INFO) << "name : " << (*in)[j].name << " shape : " << (*in)[j].shape[0]
<< " " << (*in)[j].shape[1] << " " << (*in)[j].shape[2];
int64_t* example = (int64_t*)(*in)[j].data.data();
std::ostringstream oss;
for(uint32_t i = MAX_SEQ_LEN * (batch_size - 1); i < MAX_SEQ_LEN * batch_size; i++){
oss << *(example + i);
}
LOG(INFO) << "data : " << oss.str();
}
for (int j =3; j < 4; j++) {
LOG(INFO) << "name : " << (*in)[j].name << " shape : " << (*in)[j].shape[0]
<< " " << (*in)[j].shape[1] << " " << (*in)[j].shape[2];
float* example = (float*)(*in)[j].data.data();
std::ostringstream oss;
for(uint32_t i = MAX_SEQ_LEN * (batch_size - 1); i < MAX_SEQ_LEN * batch_size; i++){
oss << *(example + i);
}
LOG(INFO) << "data : " << oss.str();
}
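The two debug loops above differ only in element type and which tensors they cover; if this logging is kept, they could share one helper. A sketch, assuming TensorVector holds paddle::PaddleTensor entries laid out as [batch, MAX_SEQ_LEN] (the helper name is invented):

// Hypothetical helper: log name, leading shape dims, and the last sample
// of one input tensor.
template <typename T>
void log_last_sample(paddle::PaddleTensor& t, uint32_t batch_size) {
  LOG(INFO) << "name : " << t.name << " shape : " << t.shape[0] << " "
            << t.shape[1];
  T* data = static_cast<T*>(t.data.data());
  std::ostringstream oss;
  for (uint32_t i = MAX_SEQ_LEN * (batch_size - 1);
       i < MAX_SEQ_LEN * batch_size; ++i) {
    oss << data[i] << " ";
  }
  LOG(INFO) << "data : " << oss.str();
}

With it, the loops reduce to log_last_sample<int64_t>((*in)[j], batch_size) for the three id tensors and log_last_sample<float>((*in)[3], batch_size) for the input mask.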
if (predictor::InferManager::instance().infer(
BERT_MODEL_NAME, in, out, batch_size)) {
@@ -130,6 +160,13 @@ int BertServiceOp::inference() {
return -1;
}
/*
paddle::NativeConfig config;
config.model_dir = "./data/model/paddle/fluid/bert";
auto predictor = CreatePaddlePredictor(config);
predictor->Run(*in, &out);
*/
#if 0
// float *out_data = static_cast<float *>(out->at(0).data.data());
LOG(INFO) << "check point";
/*
@@ -160,6 +197,42 @@ int BertServiceOp::inference() {
out->clear();
butil::return_object<TensorVector>(out);
*/
#else
float *out_data = static_cast<float *>(out->at(0).data.data());
std::ostringstream oss;
oss << "Shape: [";
for (auto x: out->at(0).shape) {
oss << x << " ";
}
oss << "]";
LOG(INFO) << oss.str();
// Output shape is [batch_size x 3]
for (uint32_t bi = 0; bi < batch_size; bi++) {
BertResInstance *res_instance = res->add_instances();
std::ostringstream oss;
oss << "Sample " << bi << " [";
oss << out_data[bi * 3 + 0] << " "
<< out_data[bi * 3 + 1] << " "
<< out_data[bi * 3 + 2] << "]";
LOG(INFO) << oss.str();
}
for (size_t i = 0; i < in->size(); ++i) {
(*in)[i].shape.clear();
}
in->clear();
butil::return_object<TensorVector>(in);
for (size_t i = 0; i < out->size(); ++i) {
(*out)[i].shape.clear();
}
out->clear();
butil::return_object<TensorVector>(out);
#endif
return 0;
}
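Since the #else branch above only logs the [batch_size x 3] scores and adds empty BertResInstances, a client would post-process the flat score array itself. A hedged sketch of picking the top class per sample (function name invented):

// Hypothetical client-side post-processing for a flat [batch_size x 3]
// score array: returns the index of the largest of the three scores.
int argmax3(const float* scores) {
  int best = 0;
  for (int k = 1; k < 3; ++k) {
    if (scores[k] > scores[best]) best = k;
  }
  return best;
}
// usage: int label = argmax3(out_data + bi * 3);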