提交 b18eeabc 编写于 作者: X xulongteng

debug client

上级 3409a058
...@@ -32,7 +32,7 @@ using baidu::paddle_serving::predictor::bert_service::BertReqInstance; ...@@ -32,7 +32,7 @@ using baidu::paddle_serving::predictor::bert_service::BertReqInstance;
using baidu::paddle_serving::predictor::bert_service::Embedding_values; using baidu::paddle_serving::predictor::bert_service::Embedding_values;
int batch_size = 1; int batch_size = 1;
int max_seq_len = 512; int max_seq_len = 82;
int layer_num = 12; int layer_num = 12;
int emb_size = 768; int emb_size = 768;
int thread_num = 1; int thread_num = 1;
...@@ -55,7 +55,7 @@ std::vector<std::string> split(const std::string& str, ...@@ -55,7 +55,7 @@ std::vector<std::string> split(const std::string& str,
} }
return res; return res;
} }
/*
int create_req(Request* req, int create_req(Request* req,
const std::vector<std::string>& data_list, const std::vector<std::string>& data_list,
int data_index, int data_index,
...@@ -93,6 +93,48 @@ int create_req(Request* req, ...@@ -93,6 +93,48 @@ int create_req(Request* req,
} }
return 0; return 0;
} }
*/
int create_req(Request* req,
const std::vector<std::string>& data_list,
int data_index,
int batch_size) {
for (int i = 0; i < batch_size; ++i) {
BertReqInstance* ins = req->add_instances();
if (!ins) {
LOG(ERROR) << "Failed create req instance";
return -1;
}
// add data
// avoid out of boundary
int cur_index = data_index + i;
if (cur_index >= data_list.size()) {
cur_index = cur_index % data_list.size();
}
std::vector<std::string> feature_list = split(data_list[cur_index], ":");
std::vector<std::string> shape_list = split(feature_list[0]," ");
std::vector<std::string> token_list = split(feature_list[1], " ");
std::vector<std::string> pos_list = split(feature_list[2], " ");
std::vector<std::string> seg_list = split(feature_list[3], " ");
std::vector<std::string> mask_list = split(feature_list[4], " ");
for (int fi = 0; fi < max_seq_len; fi++) {
if (fi < token_list.size()) {
ins->add_token_ids(std::stoi(token_list[fi]));
ins->add_sentence_type_ids(std::stoll(seg_list[fi]));
ins->add_position_ids(std::stoll(pos_list[fi]));
ins->add_input_masks(std::stof(mask_list[fi]));
} else {
ins->add_token_ids(0);
ins->add_sentence_type_ids(0);
ins->add_position_ids(0);
ins->add_input_masks(0.0);
}
}
}
return 0;
}
void print_res(const Request& req, void print_res(const Request& req,
const Response& res, const Response& res,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册