提交 b6f66ca4 编写于 作者: X xulongteng

Merge remote-tracking branch 'refs/remotes/origin/bert' into bert

......@@ -19,11 +19,10 @@
#include <fstream>
#include <string>
#include <thread> //NOLINT
#include "./data_pre.h"
#include "sdk-cpp/bert_service.pb.h"
#include "sdk-cpp/include/common.h"
#include "sdk-cpp/include/predictor_sdk.h"
#include "data_pre.h"
using baidu::paddle_serving::sdk_cpp::Predictor;
using baidu::paddle_serving::sdk_cpp::PredictorApi;
using baidu::paddle_serving::predictor::bert_service::Request;
......@@ -83,7 +82,6 @@ int create_req(Request* req,
}
#else
int create_req(Request* req,
const std::vector<std::string>& data_list,
int data_index,
......@@ -102,7 +100,7 @@ int create_req(Request* req,
}
std::vector<std::string> feature_list = split(data_list[cur_index], ":");
std::vector<std::string> shape_list = split(feature_list[0]," ");
std::vector<std::string> shape_list = split(feature_list[0], " ");
std::vector<std::string> token_list = split(feature_list[1], " ");
std::vector<std::string> pos_list = split(feature_list[2], " ");
std::vector<std::string> seg_list = split(feature_list[3], " ");
......@@ -203,6 +201,7 @@ void calc_time(int server_concurrency, int batch_size) {
LOG(INFO) << "Total request : " << (time_list.size());
LOG(INFO) << "Batch size : " << batch_size;
LOG(INFO) << "Max concurrency : " << server_concurrency;
LOG(INFO) << "Max Seq Len : " << max_seq_len;
float total_time = 0;
float max_time = 0;
float min_time = 1000000;
......@@ -240,17 +239,22 @@ void calc_time(int server_concurrency, int batch_size) {
LOG(INFO) << "99.9 percent ms: " << time_list[percent_pos_999];
} else {
LOG(INFO) << "N/A";
LOG(INFO) << "N/A";
LOG(INFO) << "N/A";
LOG(INFO) << "N/A";
LOG(INFO) << "N/A";
LOG(INFO) << "N/A";
}
}
int main(int argc, char** argv) {
PredictorApi api;
response_time.resize(thread_num);
int server_concurrency = thread_num;
if (argc > 1) {
thread_num = std::stoi(argv[1]);
batch_size = std::stoi(argv[2]);
max_seq_len = std::stoi(argv[3]);
}
response_time.resize(thread_num);
int server_concurrency = thread_num;
// log set
#ifdef BCLOUD
logging::LoggingSettings settings;
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef SERVING_DEMO_CLIENT_SRC_DATA_PRE_H_
#define SERVING_DEMO_CLIENT_SRC_DATA_PRE_H_
#include <sys/stat.h>
#include <map>
#include <string>
#include <vector>
std::vector<std::string> split(const std::string& str,
const std::string& pattern) {
std::vector<std::string> res;
if (str == "") return res;
std::string strs = str + pattern;
size_t pos = strs.find(pattern);
while (pos != strs.npos) {
std::string temp = strs.substr(0, pos);
res.push_back(temp);
strs = strs.substr(pos + 1, strs.size());
pos = strs.find(pattern);
}
return res;
}
std::map<std::string, uint64_t> dict;
char* dict_filename = "./data/bert/vocab.txt";
int make_dict() {
std::ifstream dict_file(dict_filename);
if (!dict_file) {
std::cout << "read dict failed" << std::endl;
return -1;
}
std::string line;
uint64_t index = 0;
while (getline(dict_file, line)) {
dict[line] = 0;
index += 1;
}
return 0;
}
class BertData {
public:
int gen_data(std::string line) {
std::vector<std::string> data_list;
data_list = split(line, " ");
tokenization(data_list);
return 0;
}
int tokenization(std::vector<std::string> data_list) {}
private:
std::vector<uint64_t> token_list;
std::vector<uint64_t> seg_list;
std::vector<uint64_t> pos_list;
std::vector<float> input_masks;
};
#endif // SERVING_DEMO_CLIENT_SRC_DATA_PRE_H_
engines {
name: "image_classification_resnet"
type: "FLUID_CPU_ANALYSIS_DIR"
reloadable_meta: "./data/model/paddle/fluid_time_file"
reloadable_type: "timestamp_ne"
model_data_path: "./data/model/paddle/fluid/SE_ResNeXt50_32x4d"
runtime_thread_num: 0
batch_infer_size: 0
enable_batch_align: 0
enable_memory_optimization: true
static_optimization: false
force_update_static_cache: false
}
engines {
name: "text_classification_bow"
type: "FLUID_CPU_ANALYSIS_DIR"
reloadable_meta: "./data/model/paddle/fluid_time_file"
reloadable_type: "timestamp_ne"
model_data_path: "./data/model/paddle/fluid/text_classification_lstm"
runtime_thread_num: 0
batch_infer_size: 0
enable_batch_align: 0
}
engines {
name: "ctr_prediction"
type: "FLUID_CPU_ANALYSIS_DIR"
reloadable_meta: "./data/model/paddle/fluid_time_file"
reloadable_type: "timestamp_ne"
model_data_path: "./data/model/paddle/fluid/ctr_prediction"
runtime_thread_num: 0
batch_infer_size: 0
enable_batch_align: 0
sparse_param_service_type: REMOTE
sparse_param_service_table_name: "dict"
}
engines {
name: "bert"
type: "FLUID_CPU_ANALYSIS_DIR"
reloadable_meta: "./data/model/paddle/fluid_time_file"
reloadable_type: "timestamp_ne"
model_data_path: "./data/model/paddle/fluid/bert"
runtime_thread_num: 0
batch_infer_size: 0
enable_batch_align: 0
}
services {
name: "BuiltinDenseFormatService"
workflows: "workflow1"
}
services {
name: "BuiltinSparseFormatService"
workflows: "workflow2"
}
services {
name: "BuiltinTestEchoService"
workflows: "workflow3"
}
services {
name: "ImageClassifyService"
workflows: "workflow4"
}
services {
name: "BuiltinFluidService"
workflows: "workflow5"
}
services {
name: "TextClassificationService"
workflows: "workflow6"
}
services {
name: "EchoKVDBService"
workflows: "workflow7"
}
services {
name: "CTRPredictionService"
workflows: "workflow8"
}
services {
name: "BertService"
workflows: "workflow9"
}
......@@ -77,10 +77,10 @@ int BertServiceOp::inference() {
for (uint32_t i = 0; i < batch_size; i++) {
lod_set[0].push_back(i * MAX_SEQ_LEN);
}
//src_ids.lod = lod_set;
//pos_ids.lod = lod_set;
//seg_ids.lod = lod_set;
//input_masks.lod = lod_set;
// src_ids.lod = lod_set;
// pos_ids.lod = lod_set;
// seg_ids.lod = lod_set;
// input_masks.lod = lod_set;
uint32_t index = 0;
for (uint32_t i = 0; i < batch_size; i++) {
......@@ -120,12 +120,11 @@ int BertServiceOp::inference() {
return -1;
}
/*
/*
float* example = (float*)(*in)[3].data.data();
for(uint32_t i = 0; i < MAX_SEQ_LEN; i++){
LOG(INFO) << *(example + i);
*/
*/
if (predictor::InferManager::instance().infer(
BERT_MODEL_NAME, in, out, batch_size)) {
......@@ -138,7 +137,7 @@ int BertServiceOp::inference() {
<< " seq_len : " << out->at(0).shape[1]
<< " emb_size : " << out->at(0).shape[2];
float *out_data = (float*) out->at(0).data.data();
float *out_data = reinterpret_cast<float *>out->at(0).data.data();
for (uint32_t bi = 0; bi < batch_size; bi++) {
BertResInstance *res_instance = res->add_instances();
for (uint32_t si = 0; si < MAX_SEQ_LEN; si++) {
......@@ -152,13 +151,13 @@ int BertServiceOp::inference() {
#else
LOG(INFO) << "batch_size : " << out->at(0).shape[0]
<< " emb_size : " << out->at(0).shape[1];
float *out_data = (float*) out->at(0).data.data();
float *out_data = reinterpret_cast<float *> out->at(0).data.data();
for (uint32_t bi = 0; bi < batch_size; bi++) {
BertResInstance *res_instance = res->add_instances();
for (uint32_t si = 0; si < 1; si++) {
Embedding_values *emb_instance = res_instance->add_instances();
for (uint32_t ei = 0; ei < EMB_SIZE; ei++) {
uint32_t index = bi * MAX_SEQ_LEN * EMB_SIZE + si * EMB_SIZE + ei;
uint32_t index = bi * EMB_SIZE + ei;
emb_instance->add_values(out_data[index]);
}
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册