提交 3df4b7e8 编写于 作者: M MRXLT

add feed var customize for bert server

上级 6bc6cc47
......@@ -93,9 +93,9 @@ int create_req(Request* req,
ins->add_input_masks(0.0);
}
}
ins->set_max_seq_len(max_seq_len);
ins->set_emb_size(emb_size);
}
req->set_max_seq_len(max_seq_len);
req->set_emb_size(emb_size);
return 0;
}
......
......@@ -28,6 +28,21 @@ using baidu::paddle_serving::predictor::bert_service::BertReqInstance;
using baidu::paddle_serving::predictor::bert_service::Request;
using baidu::paddle_serving::predictor::bert_service::EmbeddingValues;
std::vector<std::string> split(const std::string &str,
const std::string &pattern) {
std::vector<std::string> res;
if (str == "") return res;
std::string strs = str + pattern;
size_t pos = strs.find(pattern);
while (pos != strs.npos) {
std::string temp = strs.substr(0, pos);
res.push_back(temp);
strs = strs.substr(pos + 1, strs.size());
pos = strs.find(pattern);
}
return res;
}
int BertServiceOp::inference() {
timeval op_start;
gettimeofday(&op_start, NULL);
......@@ -43,17 +58,27 @@ int BertServiceOp::inference() {
return 0;
}
const int64_t MAX_SEQ_LEN = req->instances(0).max_seq_len();
const int64_t EMB_SIZE = req->instances(0).emb_size();
const int64_t MAX_SEQ_LEN = req->max_seq_len();
const int64_t EMB_SIZE = req->emb_size();
paddle::PaddleTensor src_ids;
paddle::PaddleTensor pos_ids;
paddle::PaddleTensor seg_ids;
paddle::PaddleTensor input_masks;
if (req->has_feed_var_names()) {
// support paddlehub model
std::vector<std::string> feed_list = split(req->feed_var_names(), ";");
src_ids.name = feed_list[0];
pos_ids.name = feed_list[1];
seg_ids.name = feed_list[2];
input_masks.name = feed_list[3];
} else {
src_ids.name = std::string("src_ids");
pos_ids.name = std::string("pos_ids");
seg_ids.name = std::string("sent_ids");
input_masks.name = std::string("input_mask");
}
src_ids.dtype = paddle::PaddleDType::INT64;
src_ids.shape = {batch_size, MAX_SEQ_LEN, 1};
......
......@@ -25,11 +25,15 @@ message BertReqInstance {
repeated int64 sentence_type_ids = 2;
repeated int64 position_ids = 3;
repeated float input_masks = 4;
optional int64 max_seq_len = 5;
optional int64 emb_size = 6;
};
message Request { repeated BertReqInstance instances = 1; };
message Request {
repeated BertReqInstance instances = 1;
optional int64 max_seq_len = 2;
optional int64 emb_size = 3;
optional string feed_var_names = 4;
optional string fetch_var_names = 5;
};
message EmbeddingValues { repeated float values = 1; };
......
......@@ -25,11 +25,15 @@ message BertReqInstance {
repeated int64 sentence_type_ids = 2;
repeated int64 position_ids = 3;
repeated float input_masks = 4;
optional int64 max_seq_len = 5;
optional int64 emb_size = 6;
};
message Request { repeated BertReqInstance instances = 1; };
message Request {
repeated BertReqInstance instances = 1;
optional int64 max_seq_len = 2;
optional int64 emb_size = 3;
optional string feed_var_names = 4;
optional string fetch_var_names = 5;
};
message EmbeddingValues { repeated float values = 1; };
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册