Commit 5104ccd7 authored by xulongteng

Merge remote-tracking branch 'refs/remotes/origin/bert' into bert

@@ -30,7 +30,7 @@ using baidu::paddle_serving::predictor::bert_service::Request;
 using baidu::paddle_serving::predictor::bert_service::Response;
 using baidu::paddle_serving::predictor::bert_service::BertResInstance;
 using baidu::paddle_serving::predictor::bert_service::BertReqInstance;
-using baidu::paddle_serving::predictor::bert_service::Embedding_values;
+using baidu::paddle_serving::predictor::bert_service::EmbeddingValues;

 extern int batch_size = 1;
 extern int max_seq_len = 128;
@@ -108,7 +108,7 @@ void print_res(const Request& req,
   std::ostringstream oss;
   oss << "[";
   for (uint32_t bi = 0; bi < res_ins.instances_size(); bi++) {
-    const Embedding_values& emb_ins = res_ins.instances(bi);
+    const EmbeddingValues& emb_ins = res_ins.instances(bi);
     oss << "[";
     for (uint32_t ei = 0; ei < emb_ins.values_size(); ei++) {
       oss << emb_ins.values(ei) << " ";
...
@@ -45,4 +45,5 @@ engines {
   runtime_thread_num: 0
   batch_infer_size: 0
   enable_batch_align: 0
+  enable_memory_optimization: true
 }
@@ -26,12 +26,7 @@ using baidu::paddle_serving::predictor::bert_service::BertResInstance;
 using baidu::paddle_serving::predictor::bert_service::Response;
 using baidu::paddle_serving::predictor::bert_service::BertReqInstance;
 using baidu::paddle_serving::predictor::bert_service::Request;
-using baidu::paddle_serving::predictor::bert_service::Embedding_values;
-
-extern int64_t MAX_SEQ_LEN = 128;
-const bool POOLING = true;
-const int LAYER_NUM = 12;
-extern int EMB_SIZE = 768;
+using baidu::paddle_serving::predictor::bert_service::EmbeddingValues;

 int BertServiceOp::inference() {
   timeval op_start;
@@ -48,8 +43,8 @@ int BertServiceOp::inference() {
     return 0;
   }

-  MAX_SEQ_LEN = req->instances(0).max_seq_len();
-  EMB_SIZE = req->instances(0).emb_size();
+  const int64_t MAX_SEQ_LEN = req->instances(0).max_seq_len();
+  const int64_t EMB_SIZE = req->instances(0).emb_size();

   paddle::PaddleTensor src_ids;
   paddle::PaddleTensor pos_ids;
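This hunk is the core fix: MAX_SEQ_LEN and EMB_SIZE used to be mutable extern globals, so two in-flight requests with different shapes could stomp on each other; now each inference() call derives them from its own request. A minimal sketch of the idea as a hypothetical helper (the feed name and shape layout are assumptions, not from this commit):

#include <vector>
#include "paddle_inference_api.h"          // header the op includes (see below)
#include "demo-serving/bert_service.pb.h"  // generated from bert_service.proto

using baidu::paddle_serving::predictor::bert_service::Request;

// Hypothetical helper: size the src_ids tensor from per-request dims instead
// of mutating shared globals, so concurrent requests cannot race.
paddle::PaddleTensor make_src_ids(const Request& req) {
  const int batch_size = req.instances_size();
  const int64_t MAX_SEQ_LEN = req.instances(0).max_seq_len();

  paddle::PaddleTensor src_ids;
  src_ids.name = "src_ids";  // illustrative feed name
  src_ids.dtype = paddle::PaddleDType::INT64;
  src_ids.shape =
      std::vector<int>{batch_size, static_cast<int>(MAX_SEQ_LEN), 1};  // assumed layout
  src_ids.data.Resize(sizeof(int64_t) * batch_size * MAX_SEQ_LEN);
  return src_ids;
}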
@@ -99,7 +94,6 @@ int BertServiceOp::inference() {
     memcpy(src_data,
            req_instance.token_ids().data(),
            sizeof(int64_t) * MAX_SEQ_LEN);
-#if 1
     memcpy(pos_data,
            req_instance.position_ids().data(),
            sizeof(int64_t) * MAX_SEQ_LEN);
@@ -109,7 +103,6 @@ int BertServiceOp::inference() {
     memcpy(input_masks_data,
            req_instance.input_masks().data(),
            sizeof(float) * MAX_SEQ_LEN);
-#endif
     index += MAX_SEQ_LEN;
   }
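The copy loop above (now free of the no-op `#if 1`/`#endif` guards) packs every instance into one flat input buffer, advancing `index` by MAX_SEQ_LEN per instance. A standalone toy sketch of that packing, with invented values:

#include <cstdint>
#include <cstring>
#include <vector>

int main() {
  const int64_t MAX_SEQ_LEN = 4;  // toy value; the real one comes from the request
  const std::vector<std::vector<int64_t>> token_ids = {{1, 2, 3, 0},
                                                       {4, 5, 0, 0}};
  // One flat [batch_size * MAX_SEQ_LEN] buffer, as fed to the src_ids tensor.
  std::vector<int64_t> src(token_ids.size() * MAX_SEQ_LEN);

  uint32_t index = 0;
  for (const auto& ids : token_ids) {
    std::memcpy(src.data() + index, ids.data(), sizeof(int64_t) * MAX_SEQ_LEN);
    index += MAX_SEQ_LEN;  // same advance as in the op's copy loop
  }
  return 0;
}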
......@@ -151,30 +144,14 @@ int BertServiceOp::inference() {
uint64_t infer_time =
(infer_end.tv_sec * 1000 + infer_end.tv_usec / 1000 -
(infer_start.tv_sec * 1000 + infer_start.tv_usec / 1000));
#if 0
LOG(INFO) << "batch_size : " << out->at(0).shape[0]
<< " seq_len : " << out->at(0).shape[1]
<< " emb_size : " << out->at(0).shape[2];
float *out_data = reinterpret_cast<float *>(out->at(0).data.data());
for (uint32_t bi = 0; bi < batch_size; bi++) {
BertResInstance *res_instance = res->add_instances();
for (uint32_t si = 0; si < MAX_SEQ_LEN; si++) {
Embedding_values *emb_instance = res_instance->add_instances();
for (uint32_t ei = 0; ei < EMB_SIZE; ei++) {
uint32_t index = bi * MAX_SEQ_LEN * EMB_SIZE + si * EMB_SIZE + ei;
emb_instance->add_values(out_data[index]);
}
}
}
#else
LOG(INFO) << "batch_size : " << out->at(0).shape[0]
<< " emb_size : " << out->at(0).shape[1];
float *out_data = reinterpret_cast<float *>(out->at(0).data.data());
for (uint32_t bi = 0; bi < batch_size; bi++) {
BertResInstance *res_instance = res->add_instances();
for (uint32_t si = 0; si < 1; si++) {
Embedding_values *emb_instance = res_instance->add_instances();
EmbeddingValues *emb_instance = res_instance->add_instances();
for (uint32_t ei = 0; ei < EMB_SIZE; ei++) {
uint32_t index = bi * EMB_SIZE + ei;
emb_instance->add_values(out_data[index]);
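With the dead `#if 0` branch (per-token [batch, seq_len, emb] output) deleted, only the pooled path survives: the fetched tensor is [batch_size, emb_size] and each instance gets exactly one embedding row. A standalone toy sketch of the surviving index math, with invented dims and data:

#include <cstdint>
#include <vector>

int main() {
  const uint32_t batch_size = 2;  // toy dims; the op reads them from the output shape
  const uint32_t EMB_SIZE = 3;
  const std::vector<float> out_data = {0.f, 1.f, 2.f, 3.f, 4.f, 5.f};

  // Each instance contributes one pooled row of EMB_SIZE floats.
  std::vector<std::vector<float>> instances;
  for (uint32_t bi = 0; bi < batch_size; bi++) {
    std::vector<float> row;
    for (uint32_t ei = 0; ei < EMB_SIZE; ei++) {
      row.push_back(out_data[bi * EMB_SIZE + ei]);  // same index math as the op
    }
    instances.push_back(row);  // stands in for res_instance->add_instances()
  }
  return 0;
}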
@@ -189,7 +166,7 @@ int BertServiceOp::inference() {
   res->set_op_time(op_time);
   res->set_infer_time(infer_time);
-#endif

   for (size_t i = 0; i < in->size(); ++i) {
     (*in)[i].shape.clear();
   }
...
@@ -21,7 +21,7 @@
 #include "paddle/fluid/inference/api/paddle_inference_api.h"
 #endif
 #else
-#include "./paddle_inference_api.h"
+#include "paddle_inference_api.h"  // NOLINT
 #endif

 #include "demo-serving/bert_service.pb.h"
...
@@ -31,9 +31,9 @@ message BertReqInstance {

 message Request { repeated BertReqInstance instances = 1; };

-message Embedding_values { repeated float values = 1; };
+message EmbeddingValues { repeated float values = 1; };

-message BertResInstance { repeated Embedding_values instances = 1; };
+message BertResInstance { repeated EmbeddingValues instances = 1; };

 message Response {
   repeated BertResInstance instances = 1;
...
@@ -31,9 +31,9 @@ message BertReqInstance {

 message Request { repeated BertReqInstance instances = 1; };

-message Embedding_values { repeated float values = 1; };
+message EmbeddingValues { repeated float values = 1; };

-message BertResInstance { repeated Embedding_values instances = 1; };
+message BertResInstance { repeated EmbeddingValues instances = 1; };

 message Response {
   repeated BertResInstance instances = 1;
...
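The same rename lands in both copies of bert_service.proto. For reference, a sketch of client-side code reading the renamed message; the helper below is hypothetical and assumes the generated bert_service.pb.h is on the include path:

#include "demo-serving/bert_service.pb.h"

using baidu::paddle_serving::predictor::bert_service::Response;

// Hypothetical helper: return the first embedding value of the first
// instance, showing the EmbeddingValues name after the rename.
float first_value(const Response& res) {
  const auto& res_ins = res.instances(0);  // BertResInstance
  const auto& emb = res_ins.instances(0);  // EmbeddingValues (was Embedding_values)
  return emb.values_size() > 0 ? emb.values(0) : 0.0f;
}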