diff --git a/demo-client/src/bert_service.cpp b/demo-client/src/bert_service.cpp
index 3c6d75da31b1503677ac8d95cdfae9baee5764f9..b519a432923bac21478061d467b70aac9cede4d7 100644
--- a/demo-client/src/bert_service.cpp
+++ b/demo-client/src/bert_service.cpp
@@ -32,7 +32,7 @@ using baidu::paddle_serving::predictor::bert_service::BertReqInstance;
 using baidu::paddle_serving::predictor::bert_service::Embedding_values;
 
 int batch_size = 1;
-int max_seq_len = 512;
+int max_seq_len = 82;
 int layer_num = 12;
 int emb_size = 768;
 int thread_num = 1;
@@ -55,7 +55,7 @@ std::vector<std::string> split(const std::string& str,
   }
   return res;
 }
-
+/*
 int create_req(Request* req,
                const std::vector<std::string>& data_list,
                int data_index,
@@ -93,6 +93,48 @@ int create_req(Request* req,
   }
   return 0;
 }
+*/
+
+
+int create_req(Request* req,
+               const std::vector<std::string>& data_list,
+               int data_index,
+               int batch_size) {
+  for (int i = 0; i < batch_size; ++i) {
+    BertReqInstance* ins = req->add_instances();
+    if (!ins) {
+      LOG(ERROR) << "Failed create req instance";
+      return -1;
+    }
+    // add data
+    // avoid out of boundary
+    int cur_index = data_index + i;
+    if (cur_index >= data_list.size()) {
+      cur_index = cur_index % data_list.size();
+    }
+
+    std::vector<std::string> feature_list = split(data_list[cur_index], ":");
+    std::vector<std::string> shape_list = split(feature_list[0], " ");
+    std::vector<std::string> token_list = split(feature_list[1], " ");
+    std::vector<std::string> pos_list = split(feature_list[2], " ");
+    std::vector<std::string> seg_list = split(feature_list[3], " ");
+    std::vector<std::string> mask_list = split(feature_list[4], " ");
+    for (int fi = 0; fi < max_seq_len; fi++) {
+      if (fi < token_list.size()) {
+        ins->add_token_ids(std::stoi(token_list[fi]));
+        ins->add_sentence_type_ids(std::stoll(seg_list[fi]));
+        ins->add_position_ids(std::stoll(pos_list[fi]));
+        ins->add_input_masks(std::stof(mask_list[fi]));
+      } else {
+        ins->add_token_ids(0);
+        ins->add_sentence_type_ids(0);
+        ins->add_position_ids(0);
+        ins->add_input_masks(0.0);
+      }
+    }
+  }
+  return 0;
+}
 
 void print_res(const Request& req,
                const Response& res,
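
Note (not part of the patch): the rewritten create_req() expects each line of the input data file to carry five colon-separated fields (shape, token ids, position ids, segment ids, input masks), each field holding space-separated values, and it zero-pads or truncates every example to max_seq_len. The sketch below is a minimal, self-contained illustration of that assumed layout; the sample line and its ids are hypothetical and not taken from the patch.

// Standalone sketch of the line format the new create_req() assumes:
// five ":"-separated fields (shape, token ids, position ids, segment ids,
// input masks), each holding space-separated values. The ids below are
// hypothetical, chosen only for illustration.
#include <iostream>
#include <sstream>
#include <string>
#include <vector>

int main() {
  const std::string line = "1 3:101 2769 102:0 1 2:0 0 0:1 1 1";

  // Split on ':' the same way create_req() does via split(data, ":").
  std::vector<std::string> fields;
  std::stringstream ss(line);
  std::string field;
  while (std::getline(ss, field, ':')) {
    fields.push_back(field);
  }

  // Expect exactly five fields; create_req() indexes feature_list[0..4].
  std::cout << "fields: " << fields.size() << std::endl;
  for (const std::string& f : fields) {
    std::cout << "  [" << f << "]" << std::endl;
  }
  return 0;
}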