Commit e1b83acf authored by xulongteng

add max_seq_len parameter

Parent 062353e2
@@ -22,6 +22,7 @@
#include "sdk-cpp/bert_service.pb.h"
#include "sdk-cpp/include/common.h"
#include "sdk-cpp/include/predictor_sdk.h"
#include "data_pre.h"
using baidu::paddle_serving::sdk_cpp::Predictor;
using baidu::paddle_serving::sdk_cpp::PredictorApi;
@@ -31,31 +32,17 @@ using baidu::paddle_serving::predictor::bert_service::BertResInstance;
using baidu::paddle_serving::predictor::bert_service::BertReqInstance;
using baidu::paddle_serving::predictor::bert_service::Embedding_values;
int batch_size = 1;
int max_seq_len = 82;
int layer_num = 12;
int emb_size = 768;
int thread_num = 1;
extern int batch_size = 1;
extern int max_seq_len = 128;
extern int layer_num = 12;
extern int emb_size = 768;
extern int thread_num = 1;
std::atomic<int> g_concurrency(0);
std::vector<std::vector<int>> response_time;
const char* data_filename = "./data/bert/demo_wiki_train";
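The new globals are defined with `extern` plus an initializer, which keeps the definitions in this file while the matching declarations presumably live in the newly included data_pre.h (most compilers warn that `extern` is redundant on an initialized definition). The header itself is not shown in this diff; a minimal sketch of what it would have to declare, assuming only these tunables, is:

// data_pre.h -- hypothetical sketch; the real header is not part of this diff
#pragma once

extern int batch_size;   // overridable via argv[2]
extern int max_seq_len;  // overridable via argv[3]
extern int layer_num;
extern int emb_size;
extern int thread_num;   // overridable via argv[1]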
std::vector<std::string> split(const std::string& str,
                               const std::string& pattern) {
  std::vector<std::string> res;
  if (str == "") return res;
  std::string strs = str + pattern;
  size_t pos = strs.find(pattern);
  while (pos != strs.npos) {
    std::string temp = strs.substr(0, pos);
    res.push_back(temp);
    // skip the full pattern, not just one character, so that
    // multi-character delimiters are handled correctly
    strs = strs.substr(pos + pattern.size());
    pos = strs.find(pattern);
  }
  return res;
}
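For reference, appending the pattern to `str` before scanning guarantees the final token is emitted without a special case; a quick hypothetical call:

std::vector<std::string> fields = split("hello world foo", " ");
// fields == {"hello", "world", "foo"}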
/*
#if 1
int create_req(Request* req,
               const std::vector<std::string>& data_list,
               int data_index,
@@ -90,10 +77,11 @@ int create_req(Request* req,
        ins->add_input_masks(0.0);
      }
    }
    ins->set_max_seq_len(max_seq_len);
  }
  return 0;
}
*/
#else
int create_req(Request* req,
@@ -136,6 +124,8 @@ int create_req(Request* req,
  return 0;
}
#endif
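Replacing the `/* ... */` pair with `#if 1`/`#else`/`#endif` is the conventional way to toggle between the two `create_req` variants: block comments cannot nest, so they break on comments inside the disabled body, while a preprocessor conditional tolerates them and switches implementations by flipping the `1` to `0`.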
void print_res(const Request& req,
               const Response& res,
               std::string route_tag,
@@ -256,6 +246,11 @@ int main(int argc, char** argv) {
  PredictorApi api;
  // parse the optional overrides before thread_num is used below;
  // all three arguments must be supplied together
  if (argc > 3) {
    thread_num = std::stoi(argv[1]);
    batch_size = std::stoi(argv[2]);
    max_seq_len = std::stoi(argv[3]);
  }
  response_time.resize(thread_num);
  int server_concurrency = thread_num;
  // log setup
#ifdef BCLOUD
  logging::LoggingSettings settings;
......
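With the new argument handling, the client demo can be launched as, say, `./bert_client 4 32 128` (binary name illustrative) to run 4 threads at batch size 32 with a 128-token sequence cap; the arguments are positional and must all be supplied together.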
@@ -28,7 +28,7 @@ using baidu::paddle_serving::predictor::bert_service::BertReqInstance;
using baidu::paddle_serving::predictor::bert_service::Request;
using baidu::paddle_serving::predictor::bert_service::Embedding_values;
const uint32_t MAX_SEQ_LEN = 64;
extern int64_t MAX_SEQ_LEN = 128;
const bool POOLING = true;
const int LAYER_NUM = 12;
const int EMB_SIZE = 768;
@@ -45,6 +45,8 @@ int BertServiceOp::inference() {
    return 0;
  }
  // take max_seq_len from the first instance; every instance in the
  // batch is assumed to share the same value
  MAX_SEQ_LEN = req->instances(0).max_seq_len();
  paddle::PaddleTensor src_ids;
  paddle::PaddleTensor pos_ids;
  paddle::PaddleTensor seg_ids;
@@ -93,6 +95,7 @@ int BertServiceOp::inference() {
    memcpy(src_data,
           req_instance.token_ids().data(),
           sizeof(int64_t) * MAX_SEQ_LEN);
#if 1
    memcpy(pos_data,
           req_instance.position_ids().data(),
           sizeof(int64_t) * MAX_SEQ_LEN);
@@ -102,6 +105,7 @@ int BertServiceOp::inference() {
    memcpy(input_masks_data,
           req_instance.input_masks().data(),
           sizeof(float) * MAX_SEQ_LEN);
#endif
    index += MAX_SEQ_LEN;
  }
@@ -116,12 +120,11 @@ int BertServiceOp::inference() {
    return -1;
  }
  LOG(INFO) << "batch_size : " << batch_size;
  LOG(INFO) << "MAX_SEQ_LEN : " << (*in)[0].shape[1];
  /*
  float* example = (float*)(*in)[3].data.data();
  for (uint32_t i = 0; i < MAX_SEQ_LEN; i++) {
    LOG(INFO) << *(example + i);
  }
  */
if (predictor::InferManager::instance().infer(
@@ -130,13 +133,12 @@ int BertServiceOp::inference() {
    return -1;
  }
  // float *out_data = static_cast<float *>(out->at(0).data.data());
  LOG(INFO) << "check point";
  /*
#if 0
  LOG(INFO) << "batch_size : " << out->at(0).shape[0]
            << " seq_len : " << out->at(0).shape[1]
            << " emb_size : " << out->at(0).shape[2];
  float *out_data = (float *)out->at(0).data.data();
  for (uint32_t bi = 0; bi < batch_size; bi++) {
    BertResInstance *res_instance = res->add_instances();
    for (uint32_t si = 0; si < MAX_SEQ_LEN; si++) {
@@ -147,7 +149,22 @@
      }
    }
  }
#else
  LOG(INFO) << "batch_size : " << out->at(0).shape[0]
            << " emb_size : " << out->at(0).shape[1];
  float *out_data = (float *)out->at(0).data.data();
  for (uint32_t bi = 0; bi < batch_size; bi++) {
    BertResInstance *res_instance = res->add_instances();
    // POOLING is enabled, so the predictor returns one pooled embedding
    // per instance; only a single "sequence position" is copied out
    for (uint32_t si = 0; si < 1; si++) {
      Embedding_values *emb_instance = res_instance->add_instances();
      for (uint32_t ei = 0; ei < EMB_SIZE; ei++) {
        // pooled output is laid out as [batch_size, EMB_SIZE], so the
        // batch stride is EMB_SIZE, not MAX_SEQ_LEN * EMB_SIZE
        uint32_t index = bi * EMB_SIZE + ei;
        emb_instance->add_values(out_data[index]);
      }
    }
  }
#endif
  for (size_t i = 0; i < in->size(); ++i) {
    (*in)[i].shape.clear();
  }
@@ -159,7 +176,6 @@ int BertServiceOp::inference() {
  }
  out->clear();
  butil::return_object<TensorVector>(out);
  */
  return 0;
}
......
@@ -25,6 +25,7 @@ message BertReqInstance {
  repeated int64 sentence_type_ids = 2;
  repeated int64 position_ids = 3;
  repeated float input_masks = 4;
  required int64 max_seq_len = 5;
};
message Request { repeated BertReqInstance instances = 1; };
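Note that in proto2, adding a `required` field to an existing message is a wire-compatibility hazard: messages produced by clients built against the old schema fail the initialization check because the field is absent. An `optional` field with a default value is the usual upgrade-safe choice; here it works because the client and server schemas are updated in the same commit.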
......
@@ -200,6 +200,7 @@ class FluidGpuAnalysisDirCore : public FluidFamilyCore {
    analysis_config.EnableUseGpu(100, FLAGS_gpuid);
    analysis_config.SwitchSpecifyInputNames(true);
    analysis_config.SetCpuMathLibraryNumThreads(1);
    analysis_config.SwitchIrOptim(true);
    if (params.enable_memory_optimization()) {
      analysis_config.EnableMemoryOptim(params.static_optimization(),
......
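`SwitchIrOptim(true)` enables the Paddle inference engine's IR graph-optimization passes (constant folding, operator fusion, and similar rewrites) before execution. It is independent of the memory optimization toggled just below and typically improves latency at the cost of a longer analysis phase at model-load time.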
@@ -25,6 +25,7 @@ message BertReqInstance {
  repeated int64 sentence_type_ids = 2;
  repeated int64 position_ids = 3;
  repeated float input_masks = 4;
  required int64 max_seq_len = 5;
};
message Request { repeated BertReqInstance instances = 1; };
......