Commit 25084c89 authored by M MRXLT

refine bert demo

Parent 9c71c83e
@@ -12,29 +12,43 @@ python prepare_model.py
Generates the server-side configuration files and model files, stored in the serving_server_model folder.
Generates the client-side configuration file, stored in the serving_client_conf folder.
### Get the vocabulary and sample data
```
sh get_data.sh
```
The script downloads the Chinese vocabulary file vocab.txt and the Chinese sample data file data-c.txt.
### Start the RPC prediction service
Run
```
python -m paddle_serving_server.serve --model serving_server_model/ --port 9292 # start the CPU prediction service
```
or
```
python -m paddle_serving_server_gpu.serve --model serving_server_model/ --port 9292 --gpu_ids 0 # start the GPU prediction service on GPU 0
```
### Run prediction
Run
```
python bert_rpc_client.py --thread 4
```
This starts the client, which reads the data in data-c.txt and sends prediction requests. The --thread argument controls the number of client processes; when prediction finishes, the elapsed time of each process is printed. The server address can be changed in the script.
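For reference, the RPC client boils down to a few `paddle_serving_client` calls. Below is a minimal sketch assembled from the benchmark.py code further down; the file paths are the ones produced by prepare_model.py and get_data.sh, and `bert_reader.BertReader` is the tokenizer helper shipped with this demo.
```
from paddle_serving_client import Client
from bert_reader import BertReader

# Tokenize each sentence and send it to the RPC server started above.
reader = BertReader(vocab_file="vocab.txt", max_seq_len=20)
client = Client()
client.load_client_config("serving_client_conf/serving_client_conf.prototxt")
client.connect(["127.0.0.1:9292"])

with open("data-c.txt") as fin:
    for line in fin:
        feed_dict = reader.process(line.strip())
        result = client.predict(feed=feed_dict, fetch=["pooled_output"])
```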
### Start the HTTP prediction service
```
export CUDA_VISIBLE_DEVICES=0,1
```
The environment variable selects the GPUs used by the GPU prediction service; this example uses the two GPUs with indices 0 and 1.
```
python bert_web_service.py serving_server_model/ 9292 # start the GPU prediction service
```
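The web service reads this variable itself; a condensed sketch of the GPU selection logic (see the bert_web_service.py hunk further down), with the fallback value added here only for illustration:
```
import os

# CUDA_VISIBLE_DEVICES, e.g. "0,1", is split into a list of GPU ids that the
# service then passes to bert_service.set_gpus().
gpu_ids = os.environ.get("CUDA_VISIBLE_DEVICES", "0")
gpus = [int(x) for x in gpu_ids.split(",")]
print(gpus)
```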
### Run prediction
Run
```
curl -H "Content-Type:application/json" -X POST -d '{"words": "hello", "fetch":["pooled_output"]}' http://127.0.0.1:9292/bert/prediction
```
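The same request can be sent from Python; below is a sketch with the requests library, mirroring the HTTP branch of benchmark.py further down, and assuming the web service is listening on 127.0.0.1:9292.
```
import json
import requests

# POST one sentence to the HTTP service and fetch the pooled_output vector.
data = {"words": "hello", "fetch": ["pooled_output"]}
r = requests.post(
    "http://127.0.0.1:9292/bert/prediction",
    data=json.dumps(data),
    headers={"Content-Type": "application/json"})
print(r.text)  # JSON response containing the fetched pooled_output
```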
### Benchmark
@@ -33,27 +33,40 @@ args = benchmark_args()
def single_func(idx, resource):
    fin = open("data-c.txt")
    dataset = []
    for line in fin:
        dataset.append(line.strip())
    if args.request == "rpc":
        reader = BertReader(vocab_file="vocab.txt", max_seq_len=20)
        config_file = './serving_client_conf/serving_client_conf.prototxt'
        fetch = ["pooled_output"]
        client = Client()
        client.load_client_config(args.model)
        client.connect([resource["endpoint"][idx % len(resource["endpoint"])]])
        start = time.time()
        for i in range(1000):
            if args.batch_size == 1:
                feed_dict = reader.process(dataset[i])
                result = client.predict(feed=feed_dict, fetch=fetch)
            elif args.batch_size > 1:
                feed_batch = []
                for bi in range(args.batch_size):
                    feed_batch.append(reader.process(dataset[i]))
                result = client.batch_predict(
                    feed_batch=feed_batch, fetch=fetch)
            else:
                print("unsupport batch size {}".format(args.batch_size))
        end = time.time()
    elif args.request == "http":
        start = time.time()
        header = {"Content-Type": "application/json"}
        for i in range(1000):
            dict_data = {"words": dataset[i], "fetch": ["pooled_output"]}
            r = requests.post(
                'http://{}/bert/prediction'.format(resource["endpoint"][
                    idx % len(resource["endpoint"])]),
                data=json.dumps(dict_data),
                headers=header)
        end = time.time()
@@ -62,10 +75,13 @@ def single_func(idx, resource):
if __name__ == '__main__':
    multi_thread_runner = MultiThreadRunner()
    endpoint_list = ["127.0.0.1:9292"]
    #endpoint_list = endpoint_list + endpoint_list + endpoint_list
    result = multi_thread_runner.run(single_func, args.thread,
                                     {"endpoint": endpoint_list})
    #result = single_func(0, {"endpoint": endpoint_list})
    avg_cost = 0
    for i in range(args.thread):
        avg_cost += result[0][i]
    avg_cost = avg_cost / args.thread
    print("average total cost {} s.".format(avg_cost))
rm profile_log
#for thread_num in 1 2 4 8 16
for thread_num in 1 2
do
    #for batch_size in 1 2 4 8 16 32 64 128 256 512
    for batch_size in 1 2
    do
        $PYTHONROOT/bin/python benchmark.py --thread $thread_num --batch_size $batch_size --model serving_client_conf/serving_client_conf.prototxt --request rpc > profile 2>&1
        echo "========================================"
        echo "batch size : $batch_size" >> profile_log
        $PYTHONROOT/bin/python ../util/show_profile.py profile $thread_num >> profile_log
        tail -n 1 profile >> profile_log
    done
done
@@ -12,6 +12,8 @@ import socket
from paddle_serving_client import Client
from paddle_serving_client.utils import MultiThreadRunner
from paddle_serving_client.utils import benchmark_args
from bert_reader import BertReader

args = benchmark_args()

_ver = sys.version_info
@@ -43,7 +45,7 @@ class BertService():
        self.pid = os.getpid()
        self.profile = True if ("FLAGS_profile_client" in os.environ and
                                os.environ["FLAGS_profile_client"]) else False
        '''
        module = hub.Module(name=self.model_name)
        inputs, outputs, program = module.context(
            trainable=True, max_seq_len=self.max_seq_len)
@@ -56,6 +58,8 @@ class BertService():
            dataset=None,
            max_seq_len=self.max_seq_len,
            do_lower_case=self.do_lower_case)
        '''
        self.reader = BertReader(vocab_file="vocab.txt", max_seq_len=20)
        self.reader_flag = True
    def load_client(self, config_file, server_addr):
@@ -64,11 +68,14 @@ class BertService():
        self.client.connect(server_addr)

    def run_general(self, text, fetch):
        '''
        self.batch_size = len(text)
        data_generator = self.reader.data_generator(
            batch_size=self.batch_size, phase='predict', data=text)
        '''
        result = []
        prepro_start = time.time()
        '''
        for run_step, batch in enumerate(data_generator(), start=1):
            token_list = batch[0][0].reshape(-1).tolist()
            pos_list = batch[0][1].reshape(-1).tolist()
@@ -82,47 +89,56 @@ class BertService():
                "input_mask": mask_list
            }
            prepro_end = time.time()
        '''
        feed = self.reader.process(text)
        if self.profile:
            print("PROFILE\tpid:{}\tbert_pre_0:{} bert_pre_1:{}".format(
                self.pid,
                int(round(prepro_start * 1000000)),
                int(round(prepro_end * 1000000))))
        fetch_map = self.client.predict(feed=feed, fetch=fetch)
        return fetch_map
    def run_batch_general(self, text, fetch):
        self.batch_size = len(text)
        '''
        data_generator = self.reader.data_generator(
            batch_size=self.batch_size, phase='predict', data=text)
        '''
        result = []
        prepro_start = time.time()
        '''
        for run_step, batch in enumerate(data_generator(), start=1):
            token_list = batch[0][0].reshape(-1).tolist()
            pos_list = batch[0][1].reshape(-1).tolist()
            sent_list = batch[0][2].reshape(-1).tolist()
            mask_list = batch[0][3].reshape(-1).tolist()
        '''
        feed_batch = []
        for si in range(self.batch_size):
            '''
            feed = {
                "input_ids": token_list[si * self.max_seq_len:(si + 1) *
                                        self.max_seq_len],
                "position_ids":
                pos_list[si * self.max_seq_len:(si + 1) * self.max_seq_len],
                "segment_ids": sent_list[si * self.max_seq_len:(si + 1) *
                                         self.max_seq_len],
                "input_mask":
                mask_list[si * self.max_seq_len:(si + 1) * self.max_seq_len]
            }
            '''
            feed = self.reader.process(text[si])
            feed_batch.append(feed)
        prepro_end = time.time()
        if self.profile:
            print("PROFILE\tpid:{}\tbert_pre_0:{} bert_pre_1:{}".format(
                self.pid,
                int(round(prepro_start * 1000000)),
                int(round(prepro_end * 1000000))))
        fetch_map_batch = self.client.batch_predict(
            feed_batch=feed_batch, fetch=fetch)
        return fetch_map_batch
@@ -134,22 +150,32 @@ def single_func(idx, resource):
        do_lower_case=True)
    config_file = './serving_client_conf/serving_client_conf.prototxt'
    fetch = ["pooled_output"]
    server_addr = [resource["endpoint"][idx % len(resource["endpoint"])]]
    bc.load_client(config_file, server_addr)
    batch_size = 1
    use_batch = False if batch_size == 1 else True
    feed_batch = []
    start = time.time()
    fin = open("data-c.txt")
    for line in fin:
        if not use_batch:
            result = bc.run_general(line.strip(), fetch)
        else:
            if len(feed_batch) == batch_size:
                result = bc.run_batch_general(feed_batch, fetch)
                feed_batch = []
            else:
                feed_batch.append(line.strip())
    if use_batch and len(feed_batch) > 0:
        result = bc.run_batch_general(feed_batch, fetch)
        feed_batch = []
    end = time.time()
    return [[end - start]]
if __name__ == '__main__':
    multi_thread_runner = MultiThreadRunner()
    result = multi_thread_runner.run(single_func, args.thread,
                                     {"endpoint": ["127.0.0.1:9292"]})
    print("time cost for each thread {}".format(result))
@@ -34,5 +34,6 @@ bert_service.load_model_config(sys.argv[1])
gpu_ids = os.environ["CUDA_VISIBLE_DEVICES"]
gpus = [int(x) for x in gpu_ids.split(",")]
bert_service.set_gpus(gpus)
bert_service.prepare_server(
    workdir="workdir", port=int(sys.argv[2]), device="gpu")
bert_service.run_server()
wget https://paddle-serving.bj.bcebos.com/bert_example/data-c.txt --no-check-certificate wget https://paddle-serving.bj.bcebos.com/bert_example/data-c.txt --no-check-certificate
wget https://paddle-serving.bj.bcebos.com/bert_example/vocab.txt --no-check-certificate
@@ -29,9 +29,9 @@ with open(profile_file) as f:
    for line in f.readlines():
        line = line.strip().split("\t")
        if line[0] == "PROFILE":
            prase(line[2])

print("thread num {}".format(thread_num))
for name in time_dict:
    print("{} cost {} s in each thread ".format(name, time_dict[name] / (
        1000000.0 * float(thread_num))))
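The switch from line[1] to line[2] matches the PROFILE format printed by bert_client.py above: after splitting on tabs, field 1 is the pid and field 2 holds the timing pair. A small illustration of the layout (the timestamps are made-up example values):
```
# Field layout of a PROFILE line as parsed above; timestamps are made up.
line = "PROFILE\tpid:1234\tbert_pre_0:1588000000000000 bert_pre_1:1588000000005000"
fields = line.strip().split("\t")
# fields[0] == "PROFILE", fields[1] == "pid:1234",
# fields[2] == "bert_pre_0:... bert_pre_1:..."  -> this is what prase() receives
print(fields[2])
```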
@@ -31,6 +31,7 @@ def benchmark_args():
help="endpoint of server") help="endpoint of server")
parser.add_argument( parser.add_argument(
"--request", type=str, default="rpc", help="mode of service") "--request", type=str, default="rpc", help="mode of service")
parser.add_argument("--batch_size", type=int, default=1, help="batch size")
return parser.parse_args() return parser.parse_args()
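With the new --batch_size flag, the benchmark can be invoked as in benchmark.sh above; the thread and batch values below are only illustrative:
```
python benchmark.py --thread 2 --batch_size 2 --model serving_client_conf/serving_client_conf.prototxt --request rpc
```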