diff --git a/demo-serving/op/bert_service_op.cpp b/demo-serving/op/bert_service_op.cpp
index e4618dc7de989536397af4135f9c2e5f6e5164a2..f4713782bf54e894efdf391c05231d77accf7019 100644
--- a/demo-serving/op/bert_service_op.cpp
+++ b/demo-serving/op/bert_service_op.cpp
@@ -31,7 +31,7 @@ using baidu::paddle_serving::predictor::bert_service::Embedding_values;
 extern int64_t MAX_SEQ_LEN = 128;
 const bool POOLING = true;
 const int LAYER_NUM = 12;
-const int EMB_SIZE = 768;
+extern int EMB_SIZE = 768;
 
 int BertServiceOp::inference() {
   timeval op_start;
@@ -49,6 +49,7 @@ int BertServiceOp::inference() {
   }
 
   MAX_SEQ_LEN = req->instances(0).max_seq_len();
+  if (req->instances(0).has_emb_size()) EMB_SIZE = req->instances(0).emb_size();
 
   paddle::PaddleTensor src_ids;
   paddle::PaddleTensor pos_ids;
diff --git a/demo-serving/proto/bert_service.proto b/demo-serving/proto/bert_service.proto
index ce1ceeb5ee440c5fd6c8ff2e573d46c30fb4e8ff..3eba501317c2533ad8c02d79d5eca2dc6682a65e 100644
--- a/demo-serving/proto/bert_service.proto
+++ b/demo-serving/proto/bert_service.proto
@@ -26,6 +26,7 @@ message BertReqInstance {
   repeated int64 position_ids = 3;
   repeated float input_masks = 4;
   required int64 max_seq_len = 5;
+  optional int64 emb_size = 6;
 };
 
 message Request { repeated BertReqInstance instances = 1; };
@@ -34,7 +34,11 @@ message Embedding_values { repeated float values = 1; };
 
 message BertResInstance { repeated Embedding_values instances = 1; };
 
-message Response { repeated BertResInstance instances = 1; };
+message Response {
+  repeated BertResInstance instances = 1;
+  optional int64 op_time = 2;
+  optional int64 infer_time = 3;
+};
 
 service BertService {
   rpc inference(Request) returns (Response);