Commit 5104ccd7 authored by xulongteng

Merge remote-tracking branch 'refs/remotes/origin/bert' into bert

@@ -30,7 +30,7 @@ using baidu::paddle_serving::predictor::bert_service::Request;
 using baidu::paddle_serving::predictor::bert_service::Response;
 using baidu::paddle_serving::predictor::bert_service::BertResInstance;
 using baidu::paddle_serving::predictor::bert_service::BertReqInstance;
-using baidu::paddle_serving::predictor::bert_service::Embedding_values;
+using baidu::paddle_serving::predictor::bert_service::EmbeddingValues;
 extern int batch_size = 1;
 extern int max_seq_len = 128;
@@ -108,7 +108,7 @@ void print_res(const Request& req,
   std::ostringstream oss;
   oss << "[";
   for (uint32_t bi = 0; bi < res_ins.instances_size(); bi++) {
-    const Embedding_values& emb_ins = res_ins.instances(bi);
+    const EmbeddingValues& emb_ins = res_ins.instances(bi);
     oss << "[";
     for (uint32_t ei = 0; ei < emb_ins.values_size(); ei++) {
       oss << emb_ins.values(ei) << " ";
...
@@ -45,4 +45,5 @@ engines {
   runtime_thread_num: 0
   batch_infer_size: 0
   enable_batch_align: 0
+  enable_memory_optimization: true
 }
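Review note: the new `enable_memory_optimization: true` field switches on memory reuse in the engine this config block describes. For orientation only, a minimal sketch of what I assume is the corresponding switch in the raw Paddle inference API, `AnalysisConfig::EnableMemoryOptim()`; the model directory and the exact mapping from the serving config field are assumptions, not taken from this commit:

```cpp
#include <memory>
#include "paddle/fluid/inference/api/paddle_inference_api.h"

std::unique_ptr<paddle::PaddlePredictor> make_predictor() {
  paddle::AnalysisConfig config;
  config.SetModel("./data/model/bert");  // hypothetical model directory
  config.EnableMemoryOptim();  // assumed counterpart of enable_memory_optimization: true
  return paddle::CreatePaddlePredictor(config);
}
```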
@@ -26,12 +26,7 @@ using baidu::paddle_serving::predictor::bert_service::BertResInstance;
 using baidu::paddle_serving::predictor::bert_service::Response;
 using baidu::paddle_serving::predictor::bert_service::BertReqInstance;
 using baidu::paddle_serving::predictor::bert_service::Request;
-using baidu::paddle_serving::predictor::bert_service::Embedding_values;
-extern int64_t MAX_SEQ_LEN = 128;
-const bool POOLING = true;
-const int LAYER_NUM = 12;
-extern int EMB_SIZE = 768;
+using baidu::paddle_serving::predictor::bert_service::EmbeddingValues;
 int BertServiceOp::inference() {
   timeval op_start;
@@ -48,8 +43,8 @@ int BertServiceOp::inference() {
     return 0;
   }
-  MAX_SEQ_LEN = req->instances(0).max_seq_len();
-  EMB_SIZE = req->instances(0).emb_size();
+  const int64_t MAX_SEQ_LEN = req->instances(0).max_seq_len();
+  const int64_t EMB_SIZE = req->instances(0).emb_size();
   paddle::PaddleTensor src_ids;
   paddle::PaddleTensor pos_ids;
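Review note: replacing the mutable file-scope globals with per-request `const` locals matters because `inference()` can run on several serving threads at once; two concurrent requests with different `max_seq_len` values could otherwise race on the shared variable. A minimal sketch of the hazard the old design allowed (names reused from the diff, the threading harness is illustrative):

```cpp
#include <cstdint>
#include <thread>

int64_t MAX_SEQ_LEN = 128;  // old design: one global shared by all workers

void handle_request(int64_t seq_len_from_request) {
  MAX_SEQ_LEN = seq_len_from_request;  // data race: another worker may write here too
  // ... later reads of MAX_SEQ_LEN may observe the *other* request's value,
  // so memcpy sizes can disagree with the tensor shapes built from it.
}

int main() {
  std::thread a(handle_request, 128), b(handle_request, 64);
  a.join();
  b.join();
}
```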
@@ -99,7 +94,6 @@ int BertServiceOp::inference() {
     memcpy(src_data,
            req_instance.token_ids().data(),
            sizeof(int64_t) * MAX_SEQ_LEN);
-#if 1
     memcpy(pos_data,
            req_instance.position_ids().data(),
            sizeof(int64_t) * MAX_SEQ_LEN);
@@ -109,7 +103,6 @@ int BertServiceOp::inference() {
     memcpy(input_masks_data,
            req_instance.input_masks().data(),
            sizeof(float) * MAX_SEQ_LEN);
-#endif
     index += MAX_SEQ_LEN;
   }
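Review note: for readers following the `memcpy` loop, each input tensor is a flat buffer of `batch_size * MAX_SEQ_LEN` elements, and `index` advances one row per request instance. A self-contained sketch of that layout (the `fill_rows` helper is hypothetical, and it assumes each instance is already padded to `max_seq_len`, as the original code does):

```cpp
#include <cstdint>
#include <cstring>
#include <vector>

// Copy each instance's max_seq_len ids into one contiguous row of a
// [batch_size, max_seq_len] buffer, mirroring the loop in the diff.
void fill_rows(std::vector<int64_t>* buf,
               const std::vector<std::vector<int64_t>>& instances,
               int64_t max_seq_len) {
  buf->resize(instances.size() * max_seq_len);
  size_t index = 0;
  for (const auto& ids : instances) {
    std::memcpy(buf->data() + index, ids.data(),
                sizeof(int64_t) * max_seq_len);
    index += max_seq_len;  // advance one row per request instance
  }
}
```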
@@ -151,30 +144,14 @@ int BertServiceOp::inference() {
   uint64_t infer_time =
       (infer_end.tv_sec * 1000 + infer_end.tv_usec / 1000 -
        (infer_start.tv_sec * 1000 + infer_start.tv_usec / 1000));
-#if 0
-  LOG(INFO) << "batch_size : " << out->at(0).shape[0]
-            << " seq_len : " << out->at(0).shape[1]
-            << " emb_size : " << out->at(0).shape[2];
-  float *out_data = reinterpret_cast<float *>(out->at(0).data.data());
-  for (uint32_t bi = 0; bi < batch_size; bi++) {
-    BertResInstance *res_instance = res->add_instances();
-    for (uint32_t si = 0; si < MAX_SEQ_LEN; si++) {
-      Embedding_values *emb_instance = res_instance->add_instances();
-      for (uint32_t ei = 0; ei < EMB_SIZE; ei++) {
-        uint32_t index = bi * MAX_SEQ_LEN * EMB_SIZE + si * EMB_SIZE + ei;
-        emb_instance->add_values(out_data[index]);
-      }
-    }
-  }
-#else
   LOG(INFO) << "batch_size : " << out->at(0).shape[0]
             << " emb_size : " << out->at(0).shape[1];
   float *out_data = reinterpret_cast<float *>(out->at(0).data.data());
   for (uint32_t bi = 0; bi < batch_size; bi++) {
     BertResInstance *res_instance = res->add_instances();
     for (uint32_t si = 0; si < 1; si++) {
-      Embedding_values *emb_instance = res_instance->add_instances();
+      EmbeddingValues *emb_instance = res_instance->add_instances();
       for (uint32_t ei = 0; ei < EMB_SIZE; ei++) {
         uint32_t index = bi * EMB_SIZE + ei;
         emb_instance->add_values(out_data[index]);
@@ -189,7 +166,7 @@ int BertServiceOp::inference() {
   res->set_op_time(op_time);
   res->set_infer_time(infer_time);
-#endif
   for (size_t i = 0; i < in->size(); ++i) {
     (*in)[i].shape.clear();
   }
...
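Review note: the deleted `#if 0` branch copied per-token embeddings out of a `[batch_size, MAX_SEQ_LEN, EMB_SIZE]` tensor; the surviving branch assumes the model already returns a pooled `[batch_size, EMB_SIZE]` output, so the flat offset collapses from `bi * MAX_SEQ_LEN * EMB_SIZE + si * EMB_SIZE + ei` to `bi * EMB_SIZE + ei`. A small sketch of the two indexings side by side:

```cpp
#include <cstdint>

// Flat offset of token embedding (bi, si, ei) in a row-major
// [batch, seq, emb] tensor (the removed #if 0 branch).
inline uint32_t per_token_index(uint32_t bi, uint32_t si, uint32_t ei,
                                uint32_t max_seq_len, uint32_t emb_size) {
  return bi * max_seq_len * emb_size + si * emb_size + ei;
}

// Flat offset of (bi, ei) in a pooled [batch, emb] tensor (the kept branch).
inline uint32_t pooled_index(uint32_t bi, uint32_t ei, uint32_t emb_size) {
  return bi * emb_size + ei;
}
```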
@@ -21,7 +21,7 @@
 #include "paddle/fluid/inference/api/paddle_inference_api.h"
 #endif
 #else
-#include "./paddle_inference_api.h"
+#include "paddle_inference_api.h"  // NOLINT
 #endif
 #include "demo-serving/bert_service.pb.h"
...
@@ -31,9 +31,9 @@ message BertReqInstance {
 message Request { repeated BertReqInstance instances = 1; };

-message Embedding_values { repeated float values = 1; };
+message EmbeddingValues { repeated float values = 1; };

-message BertResInstance { repeated Embedding_values instances = 1; };
+message BertResInstance { repeated EmbeddingValues instances = 1; };

 message Response {
   repeated BertResInstance instances = 1;
...
@@ -31,9 +31,9 @@ message BertReqInstance {
 message Request { repeated BertReqInstance instances = 1; };

-message Embedding_values { repeated float values = 1; };
+message EmbeddingValues { repeated float values = 1; };

-message BertResInstance { repeated Embedding_values instances = 1; };
+message BertResInstance { repeated EmbeddingValues instances = 1; };

 message Response {
   repeated BertResInstance instances = 1;
...
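Review note: both copies of the proto rename `Embedding_values` to `EmbeddingValues`, matching the protobuf style guide's CamelCase convention for message names. The generated C++ class name changes with it, which is why every `using` declaration and call site above had to change in the same commit. A hedged sketch of filling the renamed message on the server side, mirroring the response loop in the diff (`pooled` and `emb_size` are illustrative parameters):

```cpp
#include "demo-serving/bert_service.pb.h"

using baidu::paddle_serving::predictor::bert_service::BertResInstance;
using baidu::paddle_serving::predictor::bert_service::EmbeddingValues;
using baidu::paddle_serving::predictor::bert_service::Response;

// Append one instance's pooled embedding vector to the response.
void add_pooled_embedding(Response* res, const float* pooled, int emb_size) {
  BertResInstance* res_instance = res->add_instances();
  EmbeddingValues* emb_instance = res_instance->add_instances();
  for (int ei = 0; ei < emb_size; ++ei) {
    emb_instance->add_values(pooled[ei]);
  }
}
```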