提交 2e75c394 编写于 作者: D Dong Daxiang 提交者: GitHub

Merge pull request #225 from MRXLT/general-server-bert-v1

updated readme && benchmark
## 语义理解预测服务
示例中采用BERT模型进行语义理解预测,将文本表示为向量的形式,可以用来做进一步的分析和预测。
### 获取模型
示例中采用[Paddlehub](https://github.com/PaddlePaddle/PaddleHub)中的[BERT中文模型](https://www.paddlepaddle.org.cn/hubdetail?name=bert_chinese_L-12_H-768_A-12&en_category=SemanticModel)
执行
```
python prepare_model.py
```
生成server端配置文件与模型文件,存放在serving_server_model文件夹
生成client端配置文件,存放在serving_client_conf文件夹
### 启动预测服务
执行
```
python bert_server.py serving_server_model 9292 #启动cpu预测服务
```
或者
```
python bert_gpu_server.py serving_server_model 9292 0 #在gpu 0上启动gpu预测服务
```
### 执行预测
执行
```
sh get_data.sh
```
获取中文样例数据
执行
```
head data-c.txt | python bert_client.py
```
将预测样例数据中的前十条样例,并将向量表示打印到标准输出。
### Benchmark
模型:bert_chinese_L-12_H-768_A-12
设备:GPU V100 * 1
环境:CUDA 9.2,cudnn 7.1.4
测试中将样例数据中的1W个样本复制为10W个样本,每个client线程发送线程数分之一个样本,batch size为1,max_seq_len为20,时间单位为秒.
在client线程数为4时,预测速度可以达到432样本每秒。
由于单张GPU内部只能串行计算,client线程增多只能减少GPU的空闲时间,因此在线程数达到4之后,线程数增多对预测速度没有提升。
| client thread num | prepro | client infer | op0 | op1 | op2 | postpro | total |
| ------------------ | ------ | ------------ | ----- | ------ | ---- | ------- | ------ |
| 1 | 3.05 | 290.54 | 0.37 | 239.15 | 6.43 | 0.71 | 365.63 |
| 4 | 0.85 | 213.66 | 0.091 | 200.39 | 1.62 | 0.2 | 231.45 |
| 8 | 0.42 | 223.12 | 0.043 | 110.99 | 0.8 | 0.098 | 232.05 |
| 12 | 0.32 | 225.26 | 0.029 | 73.87 | 0.53 | 0.078 | 231.45 |
| 16 | 0.23 | 227.26 | 0.022 | 55.61 | 0.4 | 0.056 | 231.9 |
总耗时变化规律如下:
![bert benchmark](../../../doc/bert-benchmark-batch-size-1.png)
......@@ -17,7 +17,7 @@ from paddle_serving_client import Client
from paddle_serving_client.metric import auc
from paddle_serving_client.utils import MultiThreadRunner
import time
from test_bert_client import BertService
from bert_client import BertService
def predict(thr_id, resource):
......@@ -55,7 +55,7 @@ if __name__ == '__main__':
thread_num = sys.argv[3]
resource = {}
resource["conf_file"] = conf_file
resource["server_endpoint"] = ["127.0.0.1:9293"]
resource["server_endpoint"] = ["127.0.0.1:9292"]
resource["filelist"] = [data_file]
resource["thread_num"] = int(thread_num)
......
......@@ -17,7 +17,7 @@ from paddle_serving_client import Client
from paddle_serving_client.metric import auc
from paddle_serving_client.utils import MultiThreadRunner
import time
from test_bert_client import BertService
from bert_client import BertService
def predict(thr_id, resource, batch_size):
......
# coding:utf-8
import os
import sys
import numpy as np
import paddlehub as hub
import ujson
import random
import time
from paddlehub.common.logger import logger
import socket
from paddle_serving_client import Client
......@@ -20,29 +22,22 @@ if is_py3:
class BertService():
def __init__(self,
profile=False,
max_seq_len=128,
model_name="bert_uncased_L-12_H-768_A-12",
show_ids=False,
do_lower_case=True,
process_id=0,
retry=3,
load_balance='round_robin'):
retry=3):
self.process_id = process_id
self.reader_flag = False
self.batch_size = 0
self.max_seq_len = max_seq_len
self.profile = profile
self.model_name = model_name
self.show_ids = show_ids
self.do_lower_case = do_lower_case
self.con_list = []
self.con_index = 0
self.load_balance = load_balance
self.server_list = []
self.serving_list = []
self.feed_var_names = ''
self.retry = retry
self.profile = True if ("FLAGS_profile_client" in os.environ and
os.environ["FLAGS_profile_client"]) else False
module = hub.Module(name=self.model_name)
inputs, outputs, program = module.context(
......@@ -51,7 +46,6 @@ class BertService():
position_ids = inputs["position_ids"]
segment_ids = inputs["segment_ids"]
input_mask = inputs["input_mask"]
self.feed_var_names = input_ids.name + ';' + position_ids.name + ';' + segment_ids.name + ';' + input_mask.name
self.reader = hub.reader.ClassifyReader(
vocab_path=module.get_vocab_path(),
dataset=None,
......@@ -69,6 +63,7 @@ class BertService():
data_generator = self.reader.data_generator(
batch_size=self.batch_size, phase='predict', data=text)
result = []
prepro_start = time.time()
for run_step, batch in enumerate(data_generator(), start=1):
token_list = batch[0][0].reshape(-1).tolist()
pos_list = batch[0][1].reshape(-1).tolist()
......@@ -81,6 +76,11 @@ class BertService():
"segment_ids": sent_list,
"input_mask": mask_list
}
prepro_end = time.time()
if self.profile:
print("PROFILE\tbert_pre_0:{} bert_pre_1:{}".format(
int(round(prepro_start * 1000000)),
int(round(prepro_end * 1000000))))
fetch_map = self.client.predict(feed=feed, fetch=fetch)
return fetch_map
......@@ -90,6 +90,7 @@ class BertService():
data_generator = self.reader.data_generator(
batch_size=self.batch_size, phase='predict', data=text)
result = []
prepro_start = time.time()
for run_step, batch in enumerate(data_generator(), start=1):
token_list = batch[0][0].reshape(-1).tolist()
pos_list = batch[0][1].reshape(-1).tolist()
......@@ -108,6 +109,11 @@ class BertService():
mask_list[si * self.max_seq_len:(si + 1) * self.max_seq_len]
}
feed_batch.append(feed)
prepro_end = time.time()
if self.profile:
print("PROFILE\tbert_pre_0:{} bert_pre_1:{}".format(
int(round(prepro_start * 1000000)),
int(round(prepro_end * 1000000))))
fetch_map_batch = self.client.batch_predict(
feed_batch=feed_batch, fetch=fetch)
return fetch_map_batch
......@@ -116,11 +122,11 @@ class BertService():
def test():
bc = BertService(
model_name='bert_uncased_L-12_H-768_A-12',
model_name='bert_chinese_L-12_H-768_A-12',
max_seq_len=20,
show_ids=False,
do_lower_case=True)
server_addr = ["127.0.0.1:9293"]
server_addr = ["127.0.0.1:9292"]
config_file = './serving_client_conf/serving_client_conf.prototxt'
fetch = ["pooled_output"]
bc.load_client(config_file, server_addr)
......@@ -133,8 +139,7 @@ def test():
result = bc.run_batch_general(batch, fetch)
batch = []
for r in result:
for e in r["pooled_output"]:
print(e)
print(r)
if __name__ == '__main__':
......
......@@ -36,5 +36,7 @@ server.set_gpuid(1)
server.load_model_config(sys.argv[1])
port = int(sys.argv[2])
gpuid = sys.argv[3]
server.set_gpuid(gpuid)
server.prepare_server(workdir="work_dir1", port=port, device="gpu")
server.run_server()
......@@ -21,46 +21,10 @@ cat test.data | python test_client_batch.py inference.conf 4 > result
模型 :IMDB-CNN
测试中,client共发送2500条测试样本,图中数据为单个线程的耗时,时间单位为秒
server thread num :4
| client thread num | prepro | client infer | op0 | op1 | op2 | postpro | total |
| ------------------ | ------ | ------------ | ------ | ----- | ------ | ------- | ----- |
| 1 | 0.99 | 27.39 | 0.085 | 19.92 | 0.046 | 0.032 | 29.84 |
| 4 | 0.22 | 7.66 | 0.021 | 4.93 | 0.011 | 0.0082 | 8.28 |
| 8 | 0.1 | 6.66 | 0.01 | 2.42 | 0.0038 | 0.0046 | 6.95 |
| 12 | 0.074 | 6.87 | 0.0069 | 1.61 | 0.0059 | 0.0032 | 7.07 |
| 16 | 0.056 | 7.01 | 0.0053 | 1.23 | 0.0029 | 0.0026 | 7.17 |
| 20 | 0.045 | 7.02 | 0.0042 | 0.97 | 0.0023 | 0.002 | 7.15 |
| 24 | 0.039 | 7.012 | 0.0034 | 0.8 | 0.0019 | 0.0016 | 7.12 |
server thread num : 8
| client thread num | prepro | client infer | op0 | op1 | op2 | postpro | total |
| ------------------ | ------ | ------------ | ------ | ----- | ------ | ------- | ----- |
| 1 | 1.02 | 28.9 | 0.096 | 20.64 | 0.047 | 0.036 | 31.51 |
| 4 | 0.22 | 7.83 | 0.021 | 5.08 | 0.012 | 0.01 | 8.45 |
| 8 | 0.11 | 4.44 | 0.01 | 2.5 | 0.0059 | 0.0051 | 4.73 |
| 12 | 0.074 | 4.11 | 0.0069 | 1.65 | 0.0039 | 0.0029 | 4.31 |
| 16 | 0.057 | 4.2 | 0.0052 | 1.24 | 0.0029 | 0.0024 | 4.35 |
| 20 | 0.046 | 4.05 | 0.0043 | 1.01 | 0.0024 | 0.0021 | 4.18 |
| 24 | 0.038 | 4.02 | 0.0034 | 0.81 | 0.0019 | 0.0015 | 4.13 |
server thread num : 12
| client thread num | prepro | client infer | op0 | op1 | op2 | postpro | total |
| ------------------ | ------ | ------------ | ------ | ----- | ------ | ------- | ----- |
| 1 | 1.02 | 29.47 | 0.098 | 20.95 | 0.048 | 0.038 | 31.96 |
| 4 | 0.21 | 7.36 | 0.022 | 5.01 | 0.011 | 0.0081 | 7.95 |
| 8 | 0.11 | 4.52 | 0.011 | 2.58 | 0.0061 | 0.0051 | 4.83 |
| 12 | 0.072 | 3.25 | 0.0076 | 1.72 | 0.0042 | 0.0038 | 3.45 |
| 16 | 0.059 | 3.93 | 0.0055 | 1.26 | 0.0029 | 0.0023 | 4.1 |
| 20 | 0.047 | 3.79 | 0.0044 | 1.01 | 0.0024 | 0.0021 | 3.92 |
| 24 | 0.041 | 3.76 | 0.0036 | 0.83 | 0.0019 | 0.0017 | 3.87 |
server thread num : 16
测试中,client共发送25000条测试样本,图中数据为单个线程的耗时,时间单位为秒。可以看出,client端多线程的预测速度相比单线程有明显提升,在16线程时预测速度是单线程的8.7倍。
| client thread num | prepro | client infer | op0 | op1 | op2 | postpro | total |
| ------------------ | ------ | ------------ | ------ | ----- | ------ | ------- | ----- |
| 1 | 1.09 | 28.79 | 0.094 | 20.59 | 0.047 | 0.034 | 31.41 |
......@@ -71,26 +35,6 @@ server thread num : 16
| 20 | 0.049 | 3.77 | 0.0047 | 1.03 | 0.0025 | 0.0022 | 3.91 |
| 24 | 0.041 | 3.86 | 0.0039 | 0.85 | 0.002 | 0.0017 | 3.98 |
server thread num : 20
| client thread num | prepro | client infer | op0 | op1 | op2 | postpro | total |
| ------------------ | ------ | ------------ | ------ | ----- | ------ | ------- | ----- |
| 1 | 1.03 | 28.42 | 0.085 | 20.47 | 0.046 | 0.037 | 30.98 |
| 4 | 0.22 | 7.94 | 0.022 | 5.33 | 0.012 | 0.011 | 8.53 |
| 8 | 0.11 | 4.54 | 0.01 | 2.58 | 0.006 | 0.0046 | 4.84 |
| 12 | 0.079 | 4.54 | 0.0076 | 1.78 | 0.0042 | 0.0039 | 4.76 |
| 16 | 0.059 | 3.41 | 0.0057 | 1.33 | 0.0032 | 0.0027 | 3.58 |
| 20 | 0.051 | 4.33 | 0.0047 | 1.06 | 0.0025 | 0.0023 | 4.48 |
| 24 | 0.043 | 4.51 | 0.004 | 0.88 | 0.0021 | 0.0018 | 4.63 |
server thread num :24
预测总耗时变化规律如下:
| client thread num | prepro | client infer | op0 | op1 | op2 | postpro | total |
| ------------------ | ------ | ------------ | ------ | ---- | ------ | ------- | ----- |
| 1 | 0.93 | 29.28 | 0.099 | 20.5 | 0.048 | 0.028 | 31.61 |
| 4 | 0.22 | 7.72 | 0.023 | 4.98 | 0.011 | 0.0095 | 8.33 |
| 8 | 0.11 | 4.77 | 0.012 | 2.65 | 0.0062 | 0.0049 | 5.09 |
| 12 | 0.081 | 4.22 | 0.0078 | 1.77 | 0.0042 | 0.0033 | 4.44 |
| 16 | 0.062 | 4.21 | 0.0061 | 1.34 | 0.0032 | 0.0026 | 4.39 |
| 20 | 0.5 | 3.58 | 0.005 | 1.07 | 0.0026 | 0.0023 | 3.72 |
| 24 | 0.043 | 4.27 | 0.0042 | 0.89 | 0.0022 | 0.0018 | 4.4 |
![total cost](../../../doc/imdb-benchmark-server-16.png)
## Timeline工具使用
serving框架中内置了预测服务中各阶段时间打点的功能,通过环境变量来控制是否开启。
```
export FLAGS_profile_client=1 #开启client端各阶段时间打点
export FLAGS_profile_server=1 #开启server端各阶段时间打点
```
开启该功能后,client端在预测的过程中会将对应的日志信息打印到标准输出。
为了更直观地展现各阶段的耗时,提供脚本对日志文件做进一步的分析处理。
使用时先将client的输出保存到文件,以profile为例。
```
python show_profile.py profile ${thread_num}
```
脚本将计算各阶段的耗时,并除以线程数做平均,打印到标准输出。
```
python timeline_trace.py profile trace
```
脚本将日志中的时间打点信息转换成json格式保存到trace文件,trace文件可以通过chrome浏览器的tracing功能进行可视化。
具体操作:打开chrome浏览器,在地址栏输入chrome://tracing/,跳转至tracing页面,点击load按钮,打开保存的trace文件,即可将预测服务的各阶段时间信息可视化。
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册