diff --git a/demo-client/src/bert_service.cpp b/demo-client/src/bert_service.cpp
index 59daad75069e1ed4b8a85006d5b137aaf55de105..1910b18421bf99d4b4f1e197fcbaf9d07cb52976 100644
--- a/demo-client/src/bert_service.cpp
+++ b/demo-client/src/bert_service.cpp
@@ -93,9 +93,9 @@ int create_req(Request* req,
         ins->add_input_masks(0.0);
       }
     }
-    ins->set_max_seq_len(max_seq_len);
-    ins->set_emb_size(emb_size);
   }
+  req->set_max_seq_len(max_seq_len);
+  req->set_emb_size(emb_size);
   return 0;
 }
 
diff --git a/demo-serving/op/bert_service_op.cpp b/demo-serving/op/bert_service_op.cpp
index 6791d4b9ec46e2a5ece0ab585225aea5f909e75c..831ab855c5004238994528462b3402d820847419 100644
--- a/demo-serving/op/bert_service_op.cpp
+++ b/demo-serving/op/bert_service_op.cpp
@@ -28,6 +28,25 @@ using baidu::paddle_serving::predictor::bert_service::BertReqInstance;
 using baidu::paddle_serving::predictor::bert_service::Request;
 using baidu::paddle_serving::predictor::bert_service::EmbeddingValues;
 
+// Splits `str` on every occurrence of `pattern` and returns the pieces in
+// order; empty pieces between adjacent delimiters are kept. Returns an empty
+// vector when either `str` or `pattern` is empty.
+std::vector<std::string> split(const std::string &str,
+                               const std::string &pattern) {
+  std::vector<std::string> res;
+  if (str.empty() || pattern.empty()) return res;
+  size_t start = 0;
+  size_t pos = str.find(pattern, start);
+  while (pos != std::string::npos) {
+    res.push_back(str.substr(start, pos - start));
+    // Skip the whole delimiter, not just its first character.
+    start = pos + pattern.size();
+    pos = str.find(pattern, start);
+  }
+  res.push_back(str.substr(start));
+  return res;
+}
+
 int BertServiceOp::inference() {
   timeval op_start;
   gettimeofday(&op_start, NULL);
@@ -43,17 +62,32 @@ int BertServiceOp::inference() {
     return 0;
   }
 
-  const int64_t MAX_SEQ_LEN = req->instances(0).max_seq_len();
-  const int64_t EMB_SIZE = req->instances(0).emb_size();
+  const int64_t MAX_SEQ_LEN = req->max_seq_len();
+  const int64_t EMB_SIZE = req->emb_size();
 
   paddle::PaddleTensor src_ids;
   paddle::PaddleTensor pos_ids;
   paddle::PaddleTensor seg_ids;
   paddle::PaddleTensor input_masks;
-  src_ids.name = std::string("src_ids");
-  pos_ids.name = std::string("pos_ids");
-  seg_ids.name = std::string("sent_ids");
-  input_masks.name = std::string("input_mask");
+
+  if (req->has_feed_var_names()) {
+    // support paddlehub model
+    std::vector<std::string> feed_list = split(req->feed_var_names(), ";");
+    if (feed_list.size() < 4) {
+      // Malformed feed_var_names: bail out rather than index past the end of
+      // feed_list. NOTE(review): non-zero is assumed to signal failure here.
+      return -1;
+    }
+    src_ids.name = feed_list[0];
+    pos_ids.name = feed_list[1];
+    seg_ids.name = feed_list[2];
+    input_masks.name = feed_list[3];
+  } else {
+    src_ids.name = std::string("src_ids");
+    pos_ids.name = std::string("pos_ids");
+    seg_ids.name = std::string("sent_ids");
+    input_masks.name = std::string("input_mask");
+  }
 
   src_ids.dtype = paddle::PaddleDType::INT64;
   src_ids.shape = {batch_size, MAX_SEQ_LEN, 1};
diff --git a/demo-serving/proto/bert_service.proto b/demo-serving/proto/bert_service.proto
index 73ee7ed4958cc26777097fc16ee2f4dbb6c06879..8b19c5efc683a6fa4447e00cd4700b771102a0d5 100644
--- a/demo-serving/proto/bert_service.proto
+++ b/demo-serving/proto/bert_service.proto
@@ -25,11 +25,15 @@ message BertReqInstance {
   repeated int64 sentence_type_ids = 2;
   repeated int64 position_ids = 3;
   repeated float input_masks = 4;
-  optional int64 max_seq_len = 5;
-  optional int64 emb_size = 6;
 };
 
-message Request { repeated BertReqInstance instances = 1; };
+message Request {
+  repeated BertReqInstance instances = 1;
+  optional int64 max_seq_len = 2;
+  optional int64 emb_size = 3;
+  optional string feed_var_names = 4;
+  optional string fetch_var_names = 5;
+};
 
 message EmbeddingValues { repeated float values = 1; };
 
diff --git a/sdk-cpp/proto/bert_service.proto b/sdk-cpp/proto/bert_service.proto
index 693386afb188ebfc9631e9102a2a7eb27223851e..e12e7674343bc25379cf9ca65ebde4e7775af028 100644
--- a/sdk-cpp/proto/bert_service.proto
+++ b/sdk-cpp/proto/bert_service.proto
@@ -25,11 +25,15 @@ message BertReqInstance {
   repeated int64 sentence_type_ids = 2;
   repeated int64 position_ids = 3;
   repeated float input_masks = 4;
-  optional int64 max_seq_len = 5;
-  optional int64 emb_size = 6;
 };
 
-message Request { repeated BertReqInstance instances = 1; };
+message Request {
+  repeated BertReqInstance instances = 1;
+  optional int64 max_seq_len = 2;
+  optional int64 emb_size = 3;
+  optional string feed_var_names = 4;
+  optional string fetch_var_names = 5;
+};
 
 message EmbeddingValues { repeated float values = 1; };
 