diff --git a/demo-client/src/bert_service.cpp b/demo-client/src/bert_service.cpp
index f272078698e9921c0fad0b65666589d4c8c6a2e6..ac6b661285918384cc18002064bc85dfcdca727b 100644
--- a/demo-client/src/bert_service.cpp
+++ b/demo-client/src/bert_service.cpp
@@ -30,7 +30,7 @@ using baidu::paddle_serving::predictor::bert_service::Request;
 using baidu::paddle_serving::predictor::bert_service::Response;
 using baidu::paddle_serving::predictor::bert_service::BertResInstance;
 using baidu::paddle_serving::predictor::bert_service::BertReqInstance;
-using baidu::paddle_serving::predictor::bert_service::Embedding_values;
+using baidu::paddle_serving::predictor::bert_service::EmbeddingValues;
 
 extern int batch_size = 1;
 extern int max_seq_len = 128;
@@ -108,7 +108,7 @@ void print_res(const Request& req,
     std::ostringstream oss;
     oss << "[";
     for (uint32_t bi = 0; bi < res_ins.instances_size(); bi++) {
-      const Embedding_values& emb_ins = res_ins.instances(bi);
+      const EmbeddingValues& emb_ins = res_ins.instances(bi);
       oss << "[";
       for (uint32_t ei = 0; ei < emb_ins.values_size(); ei++) {
         oss << emb_ins.values(ei) << " ";
diff --git a/demo-serving/conf/model_toolkit.prototxt b/demo-serving/conf/model_toolkit.prototxt
index c3f618d104461b6aeef13173aaf482e0538f50af..fae40a9dd0b4c19225cedd8bad42aea43bd1c5d7 100644
--- a/demo-serving/conf/model_toolkit.prototxt
+++ b/demo-serving/conf/model_toolkit.prototxt
@@ -45,4 +45,5 @@ engines {
   runtime_thread_num: 0
   batch_infer_size: 0
   enable_batch_align: 0
+  enable_memory_optimization: true
 }
diff --git a/demo-serving/op/bert_service_op.cpp b/demo-serving/op/bert_service_op.cpp
index f34d113a59fc79a7badaaf75f3b839e3200c89f6..6791d4b9ec46e2a5ece0ab585225aea5f909e75c 100644
--- a/demo-serving/op/bert_service_op.cpp
+++ b/demo-serving/op/bert_service_op.cpp
@@ -26,12 +26,7 @@ using baidu::paddle_serving::predictor::bert_service::BertResInstance;
 using baidu::paddle_serving::predictor::bert_service::Response;
 using baidu::paddle_serving::predictor::bert_service::BertReqInstance;
 using baidu::paddle_serving::predictor::bert_service::Request;
-using baidu::paddle_serving::predictor::bert_service::Embedding_values;
-
-extern int64_t MAX_SEQ_LEN = 128;
-const bool POOLING = true;
-const int LAYER_NUM = 12;
-extern int EMB_SIZE = 768;
+using baidu::paddle_serving::predictor::bert_service::EmbeddingValues;
 
 int BertServiceOp::inference() {
   timeval op_start;
@@ -48,8 +43,8 @@ int BertServiceOp::inference() {
     return 0;
   }
 
-  MAX_SEQ_LEN = req->instances(0).max_seq_len();
-  EMB_SIZE = req->instances(0).emb_size();
+  const int64_t MAX_SEQ_LEN = req->instances(0).max_seq_len();
+  const int64_t EMB_SIZE = req->instances(0).emb_size();
 
   paddle::PaddleTensor src_ids;
   paddle::PaddleTensor pos_ids;
@@ -99,7 +94,6 @@ int BertServiceOp::inference() {
     memcpy(src_data,
           req_instance.token_ids().data(),
           sizeof(int64_t) * MAX_SEQ_LEN);
-#if 1
     memcpy(pos_data,
           req_instance.position_ids().data(),
           sizeof(int64_t) * MAX_SEQ_LEN);
@@ -109,7 +103,6 @@ int BertServiceOp::inference() {
     memcpy(input_masks_data,
           req_instance.input_masks().data(),
           sizeof(float) * MAX_SEQ_LEN);
-#endif
     index += MAX_SEQ_LEN;
   }
 
@@ -151,30 +144,14 @@ int BertServiceOp::inference() {
   uint64_t infer_time =
       (infer_end.tv_sec * 1000 + infer_end.tv_usec / 1000 -
       (infer_start.tv_sec * 1000 + infer_start.tv_usec / 1000));
-#if 0
-  LOG(INFO) << "batch_size : " << out->at(0).shape[0]
-            << " seq_len : " << out->at(0).shape[1]
-            << " emb_size : " << out->at(0).shape[2];
-
-  float *out_data = reinterpret_cast<float *>(out->at(0).data.data());
-  for (uint32_t bi = 0; bi < batch_size; bi++) {
-    BertResInstance *res_instance = res->add_instances();
-    for (uint32_t si = 0; si < MAX_SEQ_LEN; si++) {
-      Embedding_values *emb_instance = res_instance->add_instances();
-      for (uint32_t ei = 0; ei < EMB_SIZE; ei++) {
-        uint32_t index = bi * MAX_SEQ_LEN * EMB_SIZE + si * EMB_SIZE + ei;
-        emb_instance->add_values(out_data[index]);
-      }
-    }
-  }
-#else
+
   LOG(INFO) << "batch_size : " << out->at(0).shape[0]
             << " emb_size : " << out->at(0).shape[1];
   float *out_data = reinterpret_cast<float *>(out->at(0).data.data());
   for (uint32_t bi = 0; bi < batch_size; bi++) {
     BertResInstance *res_instance = res->add_instances();
     for (uint32_t si = 0; si < 1; si++) {
-      Embedding_values *emb_instance = res_instance->add_instances();
+      EmbeddingValues *emb_instance = res_instance->add_instances();
       for (uint32_t ei = 0; ei < EMB_SIZE; ei++) {
         uint32_t index = bi * EMB_SIZE + ei;
         emb_instance->add_values(out_data[index]);
@@ -189,7 +166,7 @@ int BertServiceOp::inference() {
 
   res->set_op_time(op_time);
   res->set_infer_time(infer_time);
-#endif
+
   for (size_t i = 0; i < in->size(); ++i) {
     (*in)[i].shape.clear();
   }
diff --git a/demo-serving/op/bert_service_op.h b/demo-serving/op/bert_service_op.h
index ddc38b5612dfa570bc2d2669b8c35099b75515c2..080e33e44a2a9e7c8f84855c8d4e911101c329d0 100644
--- a/demo-serving/op/bert_service_op.h
+++ b/demo-serving/op/bert_service_op.h
@@ -21,7 +21,7 @@
 #include "paddle/fluid/inference/api/paddle_inference_api.h"
 #endif
 #else
-#include "./paddle_inference_api.h"
+#include "paddle_inference_api.h"  // NOLINT
 #endif
 
 #include "demo-serving/bert_service.pb.h"
diff --git a/demo-serving/proto/bert_service.proto b/demo-serving/proto/bert_service.proto
index 4ba3c536d5e31492742a7ae43ee1e3bb13a3db14..73ee7ed4958cc26777097fc16ee2f4dbb6c06879 100644
--- a/demo-serving/proto/bert_service.proto
+++ b/demo-serving/proto/bert_service.proto
@@ -31,9 +31,9 @@ message BertReqInstance {
 
 message Request { repeated BertReqInstance instances = 1; };
 
-message Embedding_values { repeated float values = 1; };
+message EmbeddingValues { repeated float values = 1; };
 
-message BertResInstance { repeated Embedding_values instances = 1; };
+message BertResInstance { repeated EmbeddingValues instances = 1; };
 
 message Response {
   repeated BertResInstance instances = 1;
diff --git a/sdk-cpp/proto/bert_service.proto b/sdk-cpp/proto/bert_service.proto
index 01168560e1d1fb698e326a966aeb553decbf538c..693386afb188ebfc9631e9102a2a7eb27223851e 100644
--- a/sdk-cpp/proto/bert_service.proto
+++ b/sdk-cpp/proto/bert_service.proto
@@ -31,9 +31,9 @@ message BertReqInstance {
 
 message Request { repeated BertReqInstance instances = 1; };
 
-message Embedding_values { repeated float values = 1; };
+message EmbeddingValues { repeated float values = 1; };
 
-message BertResInstance { repeated Embedding_values instances = 1; };
+message BertResInstance { repeated EmbeddingValues instances = 1; };
 
 message Response {
   repeated BertResInstance instances = 1;