提交 7c4d218d 编写于 作者: X xulongteng

Merge branch 'bert' of https://github.com/MRXLT/Serving into bert

......@@ -31,7 +31,7 @@ set(DYNAMIC_LIB
-lcurl
-lssl
-lcrypto
# ${CURL_LIB}
${CURL_LIB}
)
target_link_libraries(cube-builder ${DYNAMIC_LIB})
......
此差异已折叠。
......@@ -17,9 +17,10 @@
#include <unistd.h>
#include <cstdlib>
#include <fstream>
#include <iostream>
#include <string>
#include <thread> //NOLINT
#include "./data_pre.h"
#include <vector>
#include "sdk-cpp/bert_service.pb.h"
#include "sdk-cpp/include/common.h"
#include "sdk-cpp/include/predictor_sdk.h"
......@@ -41,9 +42,23 @@ extern int max_turn = 1000;
std::atomic<int> g_concurrency(0);
std::vector<std::vector<int>> response_time;
std::vector<std::vector<int>> infer_time;
char* data_filename = "./data/bert/demo_wiki_train";
char* data_filename = "./data/bert/demo_wiki_data";
std::vector<std::string> split(const std::string& str,
const std::string& pattern) {
std::vector<std::string> res;
if (str == "") return res;
std::string strs = str + pattern;
size_t pos = strs.find(pattern);
while (pos != strs.npos) {
std::string temp = strs.substr(0, pos);
res.push_back(temp);
strs = strs.substr(pos + 1, strs.size());
pos = strs.find(pattern);
}
return res;
}
#if 1
int create_req(Request* req,
const std::vector<std::string>& data_list,
int data_index,
......@@ -82,49 +97,6 @@ int create_req(Request* req,
}
return 0;
}
#else
int create_req(Request* req,
const std::vector<std::string>& data_list,
int data_index,
int batch_size) {
for (int i = 0; i < batch_size; ++i) {
BertReqInstance* ins = req->add_instances();
if (!ins) {
LOG(ERROR) << "Failed create req instance";
return -1;
}
// add data
// avoid out of boundary
int cur_index = data_index + i;
if (cur_index >= data_list.size()) {
cur_index = cur_index % data_list.size();
}
std::vector<std::string> feature_list = split(data_list[cur_index], ":");
std::vector<std::string> shape_list = split(feature_list[0], " ");
std::vector<std::string> token_list = split(feature_list[1], " ");
std::vector<std::string> pos_list = split(feature_list[2], " ");
std::vector<std::string> seg_list = split(feature_list[3], " ");
std::vector<std::string> mask_list = split(feature_list[4], " ");
for (int fi = 0; fi < max_seq_len; fi++) {
if (fi < token_list.size()) {
ins->add_token_ids(std::stoi(token_list[fi]));
ins->add_sentence_type_ids(std::stoll(seg_list[fi]));
ins->add_position_ids(std::stoll(pos_list[fi]));
ins->add_input_masks(std::stof(mask_list[fi]));
} else {
ins->add_token_ids(0);
ins->add_sentence_type_ids(0);
ins->add_position_ids(0);
ins->add_input_masks(0.0);
}
}
}
return 0;
}
#endif
void print_res(const Request& req,
const Response& res,
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef SERVING_DEMO_CLIENT_SRC_DATA_PRE_H_
#define SERVING_DEMO_CLIENT_SRC_DATA_PRE_H_
#include <sys/stat.h>
#include <iostream>
#include <map>
#include <string>
#include <vector>
std::vector<std::string> split(const std::string& str,
const std::string& pattern) {
std::vector<std::string> res;
if (str == "") return res;
std::string strs = str + pattern;
size_t pos = strs.find(pattern);
while (pos != strs.npos) {
std::string temp = strs.substr(0, pos);
res.push_back(temp);
strs = strs.substr(pos + 1, strs.size());
pos = strs.find(pattern);
}
return res;
}
std::map<std::string, uint64_t> dict;
char* dict_filename = "./data/bert/vocab.txt";
int make_dict() {
std::ifstream dict_file(dict_filename);
if (!dict_file) {
std::cout << "read dict failed" << std::endl;
return -1;
}
std::string line;
uint64_t index = 0;
while (getline(dict_file, line)) {
dict[line] = 0;
index += 1;
}
return 0;
}
class BertData {
public:
int gen_data(std::string line) {
std::vector<std::string> data_list;
data_list = split(line, " ");
tokenization(data_list);
return 0;
}
int tokenization(std::vector<std::string> data_list) {}
private:
std::vector<uint64_t> token_list;
std::vector<uint64_t> seg_list;
std::vector<uint64_t> pos_list;
std::vector<float> input_masks;
};
#endif // SERVING_DEMO_CLIENT_SRC_DATA_PRE_H_
......@@ -24,6 +24,18 @@ if (NOT EXISTS
${CMAKE_CURRENT_LIST_DIR}/data/model/paddle/fluid)
endif()
if (NOT EXISTS
${CMAKE_CURRENT_LIST_DIR}/data/model/paddle/fluid/bert_cased_L-12_H-768_A-12)
execute_process(COMMAND wget --no-check-certificate
https://paddle-serving.bj.bcebos.com/data/bert/bert_cased_L-12_H-768_A-12.tar.gz
--output-document
${CMAKE_CURRENT_LIST_DIR}/data/model/paddle/fluid/bert_cased_L-12_H-768_A-12.tar.gz)
execute_process(COMMAND ${CMAKE_COMMAND} -E tar xzf
"${CMAKE_CURRENT_LIST_DIR}/data/model/paddle/fluid/bert_cased_L-12_H-768_A-12.tar.gz"
WORKING_DIRECTORY
${CMAKE_CURRENT_LIST_DIR}/data/model/paddle/fluid)
endif()
include_directories(SYSTEM ${CMAKE_CURRENT_LIST_DIR}/../kvdb/include)
include(op/CMakeLists.txt)
......
......@@ -41,7 +41,7 @@ engines {
type: "FLUID_CPU_ANALYSIS_DIR"
reloadable_meta: "./data/model/paddle/fluid_time_file"
reloadable_type: "timestamp_ne"
model_data_path: "./data/model/paddle/fluid/bert"
model_data_path: "./data/model/paddle/fluid/bert_cased_L-12_H-768_A-12"
runtime_thread_num: 0
batch_infer_size: 0
enable_batch_align: 0
......
services {
name: "BuiltinDenseFormatService"
workflows: "workflow1"
}
services {
name: "BuiltinSparseFormatService"
workflows: "workflow2"
}
services {
name: "BuiltinTestEchoService"
workflows: "workflow3"
}
services {
name: "ImageClassifyService"
workflows: "workflow4"
}
services {
name: "BuiltinFluidService"
workflows: "workflow5"
}
services {
name: "TextClassificationService"
workflows: "workflow6"
}
services {
name: "EchoKVDBService"
workflows: "workflow7"
}
services {
name: "CTRPredictionService"
workflows: "workflow8"
}
services {
name: "BertService"
workflows: "workflow9"
......
......@@ -17,7 +17,9 @@
#include <pthread.h>
#include <fstream>
#include <map>
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "configure/include/configure_parser.h"
#include "configure/inferencer_configure.pb.h"
......@@ -199,7 +201,6 @@ class FluidGpuAnalysisDirCore : public FluidFamilyCore {
analysis_config.EnableUseGpu(100, FLAGS_gpuid);
analysis_config.SwitchSpecifyInputNames(true);
analysis_config.SetCpuMathLibraryNumThreads(1);
analysis_config.SwitchIrOptim(true);
if (params.enable_memory_optimization()) {
analysis_config.EnableMemoryOptim();
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册