提交 097d5d82 编写于 作者: X xulongteng

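Change BERT service: reduce MAX_SEQ_LEN from 512 to 64, copy fixed-length (MAX_SEQ_LEN) inputs per instance, and stop assigning LoD to the input tensors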

上级 024b3063
......@@ -28,7 +28,7 @@ using baidu::paddle_serving::predictor::bert_service::BertReqInstance;
using baidu::paddle_serving::predictor::bert_service::Request;
using baidu::paddle_serving::predictor::bert_service::Embedding_values;
const uint32_t MAX_SEQ_LEN = 512;
const uint32_t MAX_SEQ_LEN = 64;
const bool POOLING = true;
const int LAYER_NUM = 12;
const int EMB_SIZE = 768;
......@@ -75,10 +75,10 @@ int BertServiceOp::inference() {
for (uint32_t i = 0; i < batch_size; i++) {
lod_set[0].push_back(i * MAX_SEQ_LEN);
}
src_ids.lod = lod_set;
pos_ids.lod = lod_set;
seg_ids.lod = lod_set;
input_masks.lod = lod_set;
//src_ids.lod = lod_set;
//pos_ids.lod = lod_set;
//seg_ids.lod = lod_set;
//input_masks.lod = lod_set;
uint32_t index = 0;
for (uint32_t i = 0; i < batch_size; i++) {
......@@ -92,17 +92,17 @@ int BertServiceOp::inference() {
memcpy(src_data,
req_instance.token_ids().data(),
sizeof(int64_t) * req_instance.token_ids_size());
sizeof(int64_t) * MAX_SEQ_LEN);
memcpy(pos_data,
req_instance.position_ids().data(),
sizeof(int64_t) * req_instance.position_ids_size());
sizeof(int64_t) * MAX_SEQ_LEN);
memcpy(seg_data,
req_instance.sentence_type_ids().data(),
sizeof(int64_t) * req_instance.sentence_type_ids_size());
sizeof(int64_t) * MAX_SEQ_LEN);
memcpy(input_masks_data,
req_instance.input_masks().data(),
sizeof(float) * req_instance.input_masks_size());
index += req_instance.input_masks_size();
sizeof(float) * MAX_SEQ_LEN);
index += MAX_SEQ_LEN;
}
in->push_back(src_ids);
......@@ -115,14 +115,14 @@ int BertServiceOp::inference() {
LOG(ERROR) << "Failed get tls output object";
return -1;
}
/*
LOG(INFO) << "batch_size : " << batch_size;
LOG(INFO) << "MAX_SEQ_LEN : " << req->instances(0).input_masks_size();
int64_t* example = (int64_t*)(*in)[2].data.data();
LOG(INFO) << "MAX_SEQ_LEN : " << (*in)[0].shape[1];
float* example = (float*)(*in)[3].data.data();
for(uint32_t i = 0; i < MAX_SEQ_LEN; i++){
LOG(INFO) << *(example + i);
}
*/
if (predictor::InferManager::instance().infer(
BERT_MODEL_NAME, in, out, batch_size)) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册