From ddd34f8ade1e9077d53b99bdebf4f4cbdb69552e Mon Sep 17 00:00:00 2001 From: MRXLT Date: Fri, 8 Nov 2019 19:45:53 +0800 Subject: [PATCH] add feed var customize for bert server --- demo-client/src/bert_service.cpp | 4 +-- demo-serving/op/bert_service_op.cpp | 37 ++++++++++++++++++++++----- demo-serving/proto/bert_service.proto | 10 +++++--- sdk-cpp/proto/bert_service.proto | 10 +++++--- 4 files changed, 47 insertions(+), 14 deletions(-) diff --git a/demo-client/src/bert_service.cpp b/demo-client/src/bert_service.cpp index 59daad75..1910b184 100644 --- a/demo-client/src/bert_service.cpp +++ b/demo-client/src/bert_service.cpp @@ -93,9 +93,9 @@ int create_req(Request* req, ins->add_input_masks(0.0); } } - ins->set_max_seq_len(max_seq_len); - ins->set_emb_size(emb_size); } + req->set_max_seq_len(max_seq_len); + req->set_emb_size(emb_size); return 0; } diff --git a/demo-serving/op/bert_service_op.cpp b/demo-serving/op/bert_service_op.cpp index 6791d4b9..831ab855 100644 --- a/demo-serving/op/bert_service_op.cpp +++ b/demo-serving/op/bert_service_op.cpp @@ -28,6 +28,21 @@ using baidu::paddle_serving::predictor::bert_service::BertReqInstance; using baidu::paddle_serving::predictor::bert_service::Request; using baidu::paddle_serving::predictor::bert_service::EmbeddingValues; +std::vector split(const std::string &str, + const std::string &pattern) { + std::vector res; + if (str == "") return res; + std::string strs = str + pattern; + size_t pos = strs.find(pattern); + while (pos != strs.npos) { + std::string temp = strs.substr(0, pos); + res.push_back(temp); + strs = strs.substr(pos + 1, strs.size()); + pos = strs.find(pattern); + } + return res; +} + int BertServiceOp::inference() { timeval op_start; gettimeofday(&op_start, NULL); @@ -43,17 +58,27 @@ int BertServiceOp::inference() { return 0; } - const int64_t MAX_SEQ_LEN = req->instances(0).max_seq_len(); - const int64_t EMB_SIZE = req->instances(0).emb_size(); + const int64_t MAX_SEQ_LEN = req->max_seq_len(); + const int64_t EMB_SIZE = req->emb_size(); paddle::PaddleTensor src_ids; paddle::PaddleTensor pos_ids; paddle::PaddleTensor seg_ids; paddle::PaddleTensor input_masks; - src_ids.name = std::string("src_ids"); - pos_ids.name = std::string("pos_ids"); - seg_ids.name = std::string("sent_ids"); - input_masks.name = std::string("input_mask"); + + if (req->has_feed_var_names()) { + // support paddlehub model + std::vector feed_list = split(req->feed_var_names(), ";"); + src_ids.name = feed_list[0]; + pos_ids.name = feed_list[1]; + seg_ids.name = feed_list[2]; + input_masks.name = feed_list[3]; + } else { + src_ids.name = std::string("src_ids"); + pos_ids.name = std::string("pos_ids"); + seg_ids.name = std::string("sent_ids"); + input_masks.name = std::string("input_mask"); + } src_ids.dtype = paddle::PaddleDType::INT64; src_ids.shape = {batch_size, MAX_SEQ_LEN, 1}; diff --git a/demo-serving/proto/bert_service.proto b/demo-serving/proto/bert_service.proto index 73ee7ed4..8b19c5ef 100644 --- a/demo-serving/proto/bert_service.proto +++ b/demo-serving/proto/bert_service.proto @@ -25,11 +25,15 @@ message BertReqInstance { repeated int64 sentence_type_ids = 2; repeated int64 position_ids = 3; repeated float input_masks = 4; - optional int64 max_seq_len = 5; - optional int64 emb_size = 6; }; -message Request { repeated BertReqInstance instances = 1; }; +message Request { + repeated BertReqInstance instances = 1; + optional int64 max_seq_len = 2; + optional int64 emb_size = 3; + optional string feed_var_names = 4; + optional string fetch_var_names = 5; +}; message EmbeddingValues { repeated float values = 1; }; diff --git a/sdk-cpp/proto/bert_service.proto b/sdk-cpp/proto/bert_service.proto index 693386af..e12e7674 100644 --- a/sdk-cpp/proto/bert_service.proto +++ b/sdk-cpp/proto/bert_service.proto @@ -25,11 +25,15 @@ message BertReqInstance { repeated int64 sentence_type_ids = 2; repeated int64 position_ids = 3; repeated float input_masks = 4; - optional int64 max_seq_len = 5; - optional int64 emb_size = 6; }; -message Request { repeated BertReqInstance instances = 1; }; +message Request { + repeated BertReqInstance instances = 1; + optional int64 max_seq_len = 2; + optional int64 emb_size = 3; + optional string feed_var_names = 4; + optional string fetch_var_names = 5; +}; message EmbeddingValues { repeated float values = 1; }; -- GitLab