提交 097d5d82 · 作者:xulongteng

change bert service

上级 024b3063
...@@ -28,7 +28,7 @@ using baidu::paddle_serving::predictor::bert_service::BertReqInstance; ...@@ -28,7 +28,7 @@ using baidu::paddle_serving::predictor::bert_service::BertReqInstance;
using baidu::paddle_serving::predictor::bert_service::Request; using baidu::paddle_serving::predictor::bert_service::Request;
using baidu::paddle_serving::predictor::bert_service::Embedding_values; using baidu::paddle_serving::predictor::bert_service::Embedding_values;
const uint32_t MAX_SEQ_LEN = 512; const uint32_t MAX_SEQ_LEN = 64;
const bool POOLING = true; const bool POOLING = true;
const int LAYER_NUM = 12; const int LAYER_NUM = 12;
const int EMB_SIZE = 768; const int EMB_SIZE = 768;
...@@ -75,10 +75,10 @@ int BertServiceOp::inference() { ...@@ -75,10 +75,10 @@ int BertServiceOp::inference() {
for (uint32_t i = 0; i < batch_size; i++) { for (uint32_t i = 0; i < batch_size; i++) {
lod_set[0].push_back(i * MAX_SEQ_LEN); lod_set[0].push_back(i * MAX_SEQ_LEN);
} }
src_ids.lod = lod_set; //src_ids.lod = lod_set;
pos_ids.lod = lod_set; //pos_ids.lod = lod_set;
seg_ids.lod = lod_set; //seg_ids.lod = lod_set;
input_masks.lod = lod_set; //input_masks.lod = lod_set;
uint32_t index = 0; uint32_t index = 0;
for (uint32_t i = 0; i < batch_size; i++) { for (uint32_t i = 0; i < batch_size; i++) {
...@@ -92,17 +92,17 @@ int BertServiceOp::inference() { ...@@ -92,17 +92,17 @@ int BertServiceOp::inference() {
memcpy(src_data, memcpy(src_data,
req_instance.token_ids().data(), req_instance.token_ids().data(),
sizeof(int64_t) * req_instance.token_ids_size()); sizeof(int64_t) * MAX_SEQ_LEN);
memcpy(pos_data, memcpy(pos_data,
req_instance.position_ids().data(), req_instance.position_ids().data(),
sizeof(int64_t) * req_instance.position_ids_size()); sizeof(int64_t) * MAX_SEQ_LEN);
memcpy(seg_data, memcpy(seg_data,
req_instance.sentence_type_ids().data(), req_instance.sentence_type_ids().data(),
sizeof(int64_t) * req_instance.sentence_type_ids_size()); sizeof(int64_t) * MAX_SEQ_LEN);
memcpy(input_masks_data, memcpy(input_masks_data,
req_instance.input_masks().data(), req_instance.input_masks().data(),
sizeof(float) * req_instance.input_masks_size()); sizeof(float) * MAX_SEQ_LEN);
index += req_instance.input_masks_size(); index += MAX_SEQ_LEN;
} }
in->push_back(src_ids); in->push_back(src_ids);
...@@ -115,14 +115,14 @@ int BertServiceOp::inference() { ...@@ -115,14 +115,14 @@ int BertServiceOp::inference() {
LOG(ERROR) << "Failed get tls output object"; LOG(ERROR) << "Failed get tls output object";
return -1; return -1;
} }
/*
LOG(INFO) << "batch_size : " << batch_size; LOG(INFO) << "batch_size : " << batch_size;
LOG(INFO) << "MAX_SEQ_LEN : " << req->instances(0).input_masks_size(); LOG(INFO) << "MAX_SEQ_LEN : " << (*in)[0].shape[1];
int64_t* example = (int64_t*)(*in)[2].data.data(); float* example = (float*)(*in)[3].data.data();
for(uint32_t i = 0; i < MAX_SEQ_LEN; i++){ for(uint32_t i = 0; i < MAX_SEQ_LEN; i++){
LOG(INFO) << *(example + i); LOG(INFO) << *(example + i);
} }
*/
if (predictor::InferManager::instance().infer( if (predictor::InferManager::instance().infer(
BERT_MODEL_NAME, in, out, batch_size)) { BERT_MODEL_NAME, in, out, batch_size)) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册