未验证 提交 494c756d 编写于 作者: M MRXLT 提交者: GitHub

Merge pull request #99 from MRXLT/bert

Add dockerfile for compile & change bert proto for paddlehub & cube cli support multi thread
FROM centos:centos6.10
RUN export http_proxy=http://172.19.56.199:3128 \
&& export https_proxy=http://172.19.56.199:3128 \
&& yum -y install wget \
&& wget http://people.centos.org/tru/devtools-2/devtools-2.repo -O /etc/yum.repos.d/devtoolset-2.repo \
&& yum -y install devtoolset-2-gcc devtoolset-2-gcc-c++ devtoolset-2-binutils \
&& source /opt/rh/devtoolset-2/enable \
&& echo "source /opt/rh/devtoolset-2/enable" >> /etc/profile \
&& yum -y install git openssl-devel curl-devel bzip2-devel \
&& wget https://cmake.org/files/v3.5/cmake-3.5.2.tar.gz \
&& tar xvf cmake-3.5.2.tar.gz \
&& cd cmake-3.5.2 \
&& ./bootstrap --prefix=/usr \
&& make \
&& make install \
&& cd .. \
&& rm -r cmake-3.5.2* \
&& wget https://dl.google.com/go/go1.12.12.linux-amd64.tar.gz \
&& tar -xzvf go1.12.12.linux-amd64.tar.gz \
&& mv go /usr/local/go \
&& rm go1.12.12.linux-amd64.tar.gz \
&& echo "export GOROOT=/usr/local/go" >> /root/.bashrc \
&& echo "export GOPATH=$HOME/go" >> /root/.bashrc \
&& echo "export PATH=$PATH:/usr/local/go/bin" >> /root/.bashrc
FROM paddlepaddle/paddle_manylinux_devel:cuda9.0_cudnn7
RUN yum -y install git openssl-devel curl-devel bzip2-devel \
&& wget https://dl.google.com/go/go1.12.12.linux-amd64.tar.gz \
&& tar -xzvf go1.12.12.linux-amd64.tar.gz \
&& rm -rf /usr/local/go \
&& mv go /usr/local/go \
&& rm go1.12.12.linux-amd64.tar.gz \
&& echo "GOROOT=/usr/local/go" >> /root/.bashrc \
&& echo "GOPATH=$HOME/go" >> /root/.bashrc \
&& echo "PATH=$PATH:$GOROOT/bin" >> /root/.bashrc
...@@ -14,9 +14,9 @@ ...@@ -14,9 +14,9 @@
#include <gflags/gflags.h> #include <gflags/gflags.h>
#include <atomic> #include <atomic>
#include <fstream>
#include <thread> //NOLINT
#include "cube/cube-api/include/cube_api.h" #include "cube/cube-api/include/cube_api.h"
#define TIME_FLAG(flag) \ #define TIME_FLAG(flag) \
struct timeval flag; \ struct timeval flag; \
gettimeofday(&(flag), NULL); gettimeofday(&(flag), NULL);
...@@ -28,6 +28,11 @@ DEFINE_uint64(batch, 500, "batch size"); ...@@ -28,6 +28,11 @@ DEFINE_uint64(batch, 500, "batch size");
DEFINE_int32(timeout, 200, "timeout in ms"); DEFINE_int32(timeout, 200, "timeout in ms");
DEFINE_int32(retry, 3, "retry times"); DEFINE_int32(retry, 3, "retry times");
DEFINE_bool(print_output, false, "print output flag"); DEFINE_bool(print_output, false, "print output flag");
DEFINE_int32(thread_num, 1, "thread num");
std::atomic<int> g_concurrency(0);
std::vector<uint64_t> time_list;
std::vector<uint64_t> request_list;
namespace { namespace {
inline uint64_t time_diff(const struct timeval& start_time, inline uint64_t time_diff(const struct timeval& start_time,
...@@ -53,7 +58,7 @@ std::string string_to_hex(const std::string& input) { ...@@ -53,7 +58,7 @@ std::string string_to_hex(const std::string& input) {
return output; return output;
} }
int run(int argc, char** argv) { int run(int argc, char** argv, int thread_id) {
google::ParseCommandLineFlags(&argc, &argv, true); google::ParseCommandLineFlags(&argc, &argv, true);
CubeAPI* cube = CubeAPI::instance(); CubeAPI* cube = CubeAPI::instance();
...@@ -62,13 +67,13 @@ int run(int argc, char** argv) { ...@@ -62,13 +67,13 @@ int run(int argc, char** argv) {
LOG(ERROR) << "init cube api failed err=" << ret; LOG(ERROR) << "init cube api failed err=" << ret;
return ret; return ret;
} }
/*
FILE* key_file = fopen(FLAGS_keys.c_str(), "r"); FILE* key_file = fopen(FLAGS_keys.c_str(), "r");
if (key_file == NULL) { if (key_file == NULL) {
LOG(ERROR) << "open key file [" << FLAGS_keys << "] failed"; LOG(ERROR) << "open key file [" << FLAGS_keys << "] failed";
return -1; return -1;
} }
*/
std::atomic<uint64_t> seek_counter(0); std::atomic<uint64_t> seek_counter(0);
std::atomic<uint64_t> seek_cost_total(0); std::atomic<uint64_t> seek_cost_total(0);
uint64_t seek_cost_max = 0; uint64_t seek_cost_max = 0;
...@@ -78,14 +83,32 @@ int run(int argc, char** argv) { ...@@ -78,14 +83,32 @@ int run(int argc, char** argv) {
std::vector<uint64_t> keys; std::vector<uint64_t> keys;
std::vector<CubeValue> values; std::vector<CubeValue> values;
while (fgets(buffer, 1024, key_file)) { std::string line;
uint64_t key = strtoul(buffer, NULL, 10); std::vector<int64_t> key_list;
keys.push_back(key); std::ifstream key_file(FLAGS_keys.c_str());
while (getline(key_file, line)) {
key_list.push_back(std::stoll(line));
}
uint64_t file_size = key_list.size();
uint64_t index = 0;
uint64_t request = 0;
while (g_concurrency.load() >= FLAGS_thread_num) {
}
g_concurrency++;
while (index < file_size) {
// uint64_t key = strtoul(buffer, NULL, 10);
keys.push_back(key_list[index]);
index += 1;
int ret = 0; int ret = 0;
if (keys.size() > FLAGS_batch) { if (keys.size() > FLAGS_batch) {
TIME_FLAG(seek_start); TIME_FLAG(seek_start);
ret = cube->seek(FLAGS_dict, keys, &values); ret = cube->seek(FLAGS_dict, keys, &values);
TIME_FLAG(seek_end); TIME_FLAG(seek_end);
request += 1;
if (ret != 0) { if (ret != 0) {
LOG(WARNING) << "cube seek failed"; LOG(WARNING) << "cube seek failed";
} else if (FLAGS_print_output) { } else if (FLAGS_print_output) {
...@@ -110,37 +133,40 @@ int run(int argc, char** argv) { ...@@ -110,37 +133,40 @@ int run(int argc, char** argv) {
values.clear(); values.clear();
} }
} }
/*
if (keys.size() > 0) {
int ret = 0;
values.resize(keys.size());
TIME_FLAG(seek_start);
ret = cube->seek(FLAGS_dict, keys, &values);
TIME_FLAG(seek_end);
if (ret != 0) {
LOG(WARNING) << "cube seek failed";
} else if (FLAGS_print_output) {
for (size_t i = 0; i < keys.size(); ++i) {
fprintf(stdout,
"key:%lu value:%s\n",
keys[i],
string_to_hex(values[i].buff).c_str());
}
}
if (keys.size() > 0) { ++seek_counter;
int ret = 0; uint64_t seek_cost = time_diff(seek_start, seek_end);
values.resize(keys.size()); seek_cost_total += seek_cost;
TIME_FLAG(seek_start); if (seek_cost > seek_cost_max) {
ret = cube->seek(FLAGS_dict, keys, &values); seek_cost_max = seek_cost;
TIME_FLAG(seek_end); }
if (ret != 0) { if (seek_cost < seek_cost_min) {
LOG(WARNING) << "cube seek failed"; seek_cost_min = seek_cost;
} else if (FLAGS_print_output) {
for (size_t i = 0; i < keys.size(); ++i) {
fprintf(stdout,
"key:%lu value:%s\n",
keys[i],
string_to_hex(values[i].buff).c_str());
} }
} }
*/
g_concurrency--;
++seek_counter; // fclose(key_file);
uint64_t seek_cost = time_diff(seek_start, seek_end);
seek_cost_total += seek_cost;
if (seek_cost > seek_cost_max) {
seek_cost_max = seek_cost;
}
if (seek_cost < seek_cost_min) {
seek_cost_min = seek_cost;
}
}
fclose(key_file);
ret = cube->destroy(); // ret = cube->destroy();
if (ret != 0) { if (ret != 0) {
LOG(WARNING) << "destroy cube api failed err=" << ret; LOG(WARNING) << "destroy cube api failed err=" << ret;
} }
...@@ -150,10 +176,50 @@ int run(int argc, char** argv) { ...@@ -150,10 +176,50 @@ int run(int argc, char** argv) {
LOG(INFO) << "seek cost max = " << seek_cost_max; LOG(INFO) << "seek cost max = " << seek_cost_max;
LOG(INFO) << "seek cost min = " << seek_cost_min; LOG(INFO) << "seek cost min = " << seek_cost_min;
time_list[thread_id] = seek_cost_avg;
request_list[thread_id] = request;
return 0; return 0;
} }
int run_m(int argc, char** argv) {
google::ParseCommandLineFlags(&argc, &argv, true);
int thread_num = FLAGS_thread_num;
request_list.resize(thread_num);
time_list.resize(thread_num);
std::vector<std::thread*> thread_pool;
for (int i = 0; i < thread_num; i++) {
thread_pool.push_back(new std::thread(run, argc, argv, i));
}
for (int i = 0; i < thread_num; i++) {
thread_pool[i]->join();
delete thread_pool[i];
}
uint64_t sum_time = 0;
uint64_t max_time = 0;
uint64_t min_time = 1000000;
uint64_t request_num = 0;
for (int i = 0; i < thread_num; i++) {
sum_time += time_list[i];
if (time_list[i] > max_time) {
max_time = time_list[i];
}
if (time_list[i] < min_time) {
min_time = time_list[i];
}
request_num += request_list[i];
}
uint64_t mean_time = sum_time / thread_num;
LOG(INFO) << thread_num << " thread seek cost"
<< " avg = " << std::to_string(mean_time)
<< " max = " << std::to_string(max_time)
<< " min = " << std::to_string(min_time);
LOG(INFO) << " total_request = " << std::to_string(request_num)
<< " speed = " << std::to_string(1000000 * thread_num / mean_time)
<< " query per second";
}
} // namespace mcube } // namespace mcube
} // namespace rec } // namespace rec
int main(int argc, char** argv) { return ::rec::mcube::run(argc, argv); } int main(int argc, char** argv) { return ::rec::mcube::run_m(argc, argv); }
...@@ -93,9 +93,9 @@ int create_req(Request* req, ...@@ -93,9 +93,9 @@ int create_req(Request* req,
ins->add_input_masks(0.0); ins->add_input_masks(0.0);
} }
} }
ins->set_max_seq_len(max_seq_len);
ins->set_emb_size(emb_size);
} }
req->set_max_seq_len(max_seq_len);
req->set_emb_size(emb_size);
return 0; return 0;
} }
...@@ -118,8 +118,8 @@ void print_res(const Request& req, ...@@ -118,8 +118,8 @@ void print_res(const Request& req,
oss << "]\n"; oss << "]\n";
LOG(INFO) << "Receive : " << oss.str(); LOG(INFO) << "Receive : " << oss.str();
} }
LOG(INFO) << "Succ call predictor[ctr_prediction_service], the tag is: " LOG(INFO) << "Succ call predictor[bert_service], the tag is: " << route_tag
<< route_tag << ", elapse_ms: " << elapse_ms; << ", elapse_ms: " << elapse_ms;
} }
void thread_worker(PredictorApi* api, void thread_worker(PredictorApi* api,
......
...@@ -28,6 +28,21 @@ using baidu::paddle_serving::predictor::bert_service::BertReqInstance; ...@@ -28,6 +28,21 @@ using baidu::paddle_serving::predictor::bert_service::BertReqInstance;
using baidu::paddle_serving::predictor::bert_service::Request; using baidu::paddle_serving::predictor::bert_service::Request;
using baidu::paddle_serving::predictor::bert_service::EmbeddingValues; using baidu::paddle_serving::predictor::bert_service::EmbeddingValues;
std::vector<std::string> split(const std::string &str,
const std::string &pattern) {
std::vector<std::string> res;
if (str == "") return res;
std::string strs = str + pattern;
size_t pos = strs.find(pattern);
while (pos != strs.npos) {
std::string temp = strs.substr(0, pos);
res.push_back(temp);
strs = strs.substr(pos + 1, strs.size());
pos = strs.find(pattern);
}
return res;
}
int BertServiceOp::inference() { int BertServiceOp::inference() {
timeval op_start; timeval op_start;
gettimeofday(&op_start, NULL); gettimeofday(&op_start, NULL);
...@@ -43,17 +58,27 @@ int BertServiceOp::inference() { ...@@ -43,17 +58,27 @@ int BertServiceOp::inference() {
return 0; return 0;
} }
const int64_t MAX_SEQ_LEN = req->instances(0).max_seq_len(); const int64_t MAX_SEQ_LEN = req->max_seq_len();
const int64_t EMB_SIZE = req->instances(0).emb_size(); const int64_t EMB_SIZE = req->emb_size();
paddle::PaddleTensor src_ids; paddle::PaddleTensor src_ids;
paddle::PaddleTensor pos_ids; paddle::PaddleTensor pos_ids;
paddle::PaddleTensor seg_ids; paddle::PaddleTensor seg_ids;
paddle::PaddleTensor input_masks; paddle::PaddleTensor input_masks;
src_ids.name = std::string("src_ids");
pos_ids.name = std::string("pos_ids"); if (req->has_feed_var_names()) {
seg_ids.name = std::string("sent_ids"); // support paddlehub model
input_masks.name = std::string("input_mask"); std::vector<std::string> feed_list = split(req->feed_var_names(), ";");
src_ids.name = feed_list[0];
pos_ids.name = feed_list[1];
seg_ids.name = feed_list[2];
input_masks.name = feed_list[3];
} else {
src_ids.name = std::string("src_ids");
pos_ids.name = std::string("pos_ids");
seg_ids.name = std::string("sent_ids");
input_masks.name = std::string("input_mask");
}
src_ids.dtype = paddle::PaddleDType::INT64; src_ids.dtype = paddle::PaddleDType::INT64;
src_ids.shape = {batch_size, MAX_SEQ_LEN, 1}; src_ids.shape = {batch_size, MAX_SEQ_LEN, 1};
......
...@@ -25,11 +25,15 @@ message BertReqInstance { ...@@ -25,11 +25,15 @@ message BertReqInstance {
repeated int64 sentence_type_ids = 2; repeated int64 sentence_type_ids = 2;
repeated int64 position_ids = 3; repeated int64 position_ids = 3;
repeated float input_masks = 4; repeated float input_masks = 4;
optional int64 max_seq_len = 5;
optional int64 emb_size = 6;
}; };
message Request { repeated BertReqInstance instances = 1; }; message Request {
repeated BertReqInstance instances = 1;
optional int64 max_seq_len = 2;
optional int64 emb_size = 3;
optional string feed_var_names = 4;
optional string fetch_var_names = 5;
};
message EmbeddingValues { repeated float values = 1; }; message EmbeddingValues { repeated float values = 1; };
......
# 使用Docker编译Paddle Serving
## Docker编译环境要求
+ 开发机上已安装Docker
+ 编译GPU版本需要安装nvidia-docker
[CPU版本Dockerfile](../Dockerfile)
[GPU版本Dockerfile](../Dockerfile.gpu)
## 使用方法
### 构建Docker镜像
建立新目录,复制Dockerfile内容到该目录下Dockerfile文件
执行
```bash
docker build -t serving_compile:cpu .
```
或者
```bash
docker build -t serving_compile:cuda9 .
```
## 进入Docker
CPU版本请执行
```bash
docker run -it serving_compile:cpu bash
```
GPU版本请执行
```bash
docker run -it --runtime=nvidia -it serving_compile:cuda9 bash
```
## 预编译文件可执行环境列表
| Docker预编译版本可运行环境 |
| -------------------------- |
| Centos6 |
| Centos7 |
| Ubuntu16.04 |
| Ubuntu 18.04 |
| GPU Docker预编译版本支持的CUDA版本 |
| ---------------------------------- |
| cuda8_cudnn7 |
| cuda9_cudnn7 |
| cuda10_cudnn7 |
**备注:**
+ 若执行预编译版本出现找不到libcrypto.so.10、libssl.so.10的情况,可以将Docker环境中的/usr/lib64/libssl.so.10与/usr/lib64/libcrypto.so.10复制到可执行文件所在目录。
+ CPU预编译版本仅可在CPU机器上执行,GPU预编译版本仅可在GPU机器上执行
...@@ -16,6 +16,8 @@ openssl & openssl-devel ...@@ -16,6 +16,8 @@ openssl & openssl-devel
## 编译 ## 编译
推荐使用Docker编译Paddle Serving, [Docker编译使用说明](./DOCKER.md)
```shell ```shell
$ git clone https://github.com/PaddlePaddle/serving.git $ git clone https://github.com/PaddlePaddle/serving.git
$ cd serving $ cd serving
......
...@@ -25,11 +25,15 @@ message BertReqInstance { ...@@ -25,11 +25,15 @@ message BertReqInstance {
repeated int64 sentence_type_ids = 2; repeated int64 sentence_type_ids = 2;
repeated int64 position_ids = 3; repeated int64 position_ids = 3;
repeated float input_masks = 4; repeated float input_masks = 4;
optional int64 max_seq_len = 5;
optional int64 emb_size = 6;
}; };
message Request { repeated BertReqInstance instances = 1; }; message Request {
repeated BertReqInstance instances = 1;
optional int64 max_seq_len = 2;
optional int64 emb_size = 3;
optional string feed_var_names = 4;
optional string fetch_var_names = 5;
};
message EmbeddingValues { repeated float values = 1; }; message EmbeddingValues { repeated float values = 1; };
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册