提交 b18eeabc 编写于 作者: X xulongteng

debug client

上级 3409a058
......@@ -32,7 +32,7 @@ using baidu::paddle_serving::predictor::bert_service::BertReqInstance;
using baidu::paddle_serving::predictor::bert_service::Embedding_values;
int batch_size = 1;
int max_seq_len = 512;
int max_seq_len = 82;
int layer_num = 12;
int emb_size = 768;
int thread_num = 1;
......@@ -55,7 +55,7 @@ std::vector<std::string> split(const std::string& str,
}
return res;
}
/*
int create_req(Request* req,
const std::vector<std::string>& data_list,
int data_index,
......@@ -93,6 +93,48 @@ int create_req(Request* req,
}
return 0;
}
*/
int create_req(Request* req,
const std::vector<std::string>& data_list,
int data_index,
int batch_size) {
for (int i = 0; i < batch_size; ++i) {
BertReqInstance* ins = req->add_instances();
if (!ins) {
LOG(ERROR) << "Failed create req instance";
return -1;
}
// add data
// avoid out of boundary
int cur_index = data_index + i;
if (cur_index >= data_list.size()) {
cur_index = cur_index % data_list.size();
}
std::vector<std::string> feature_list = split(data_list[cur_index], ":");
std::vector<std::string> shape_list = split(feature_list[0]," ");
std::vector<std::string> token_list = split(feature_list[1], " ");
std::vector<std::string> pos_list = split(feature_list[2], " ");
std::vector<std::string> seg_list = split(feature_list[3], " ");
std::vector<std::string> mask_list = split(feature_list[4], " ");
for (int fi = 0; fi < max_seq_len; fi++) {
if (fi < token_list.size()) {
ins->add_token_ids(std::stoi(token_list[fi]));
ins->add_sentence_type_ids(std::stoll(seg_list[fi]));
ins->add_position_ids(std::stoll(pos_list[fi]));
ins->add_input_masks(std::stof(mask_list[fi]));
} else {
ins->add_token_ids(0);
ins->add_sentence_type_ids(0);
ins->add_position_ids(0);
ins->add_input_masks(0.0);
}
}
}
return 0;
}
void print_res(const Request& req,
const Response& res,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册