提交 25084c89 编写于 作者: M MRXLT

refine bert demo

上级 9c71c83e
......@@ -12,29 +12,43 @@ python prepare_model.py
生成server端配置文件与模型文件,存放在serving_server_model文件夹
生成client端配置文件,存放在serving_client_conf文件夹
### 启动预测服务
### 获取词典和样例数据
```
sh get_data.sh
```
脚本将下载中文词典vocab.txt和中文样例数据data-c.txt
### 启动RPC预测服务
执行
```
python bert_server.py serving_server_model 9292 #启动cpu预测服务
python -m paddle_serving_server.serve --model serving_server_model/ --port 9292 #启动cpu预测服务
```
或者
```
python bert_gpu_server.py serving_server_model 9292 0 #在gpu 0上启动gpu预测服务
python -m paddle_serving_server_gpu.serve --model serving_server_model/ --port 9292 --gpu_ids 0 #在gpu 0上启动gpu预测服务
```
### 执行预测
执行
```
sh get_data.sh
python bert_rpc_client.py --thread 4
```
获取中文样例数据
启动client读取data-c.txt中的数据进行预测,--thread参数控制client的进程数,预测结束后会打印出每个进程的耗时,server端的地址在脚本中修改。
### 启动HTTP预测服务
```
export CUDA_VISIBLE_DEVICES=0,1
```
通过环境变量指定gpu预测服务使用的gpu,示例中指定索引为0和1的两块gpu
```
python bert_web_service.py serving_server_model/ 9292 #启动gpu预测服务
```
### 执行预测
执行
```
head data-c.txt | python bert_client.py
curl -H "Content-Type:application/json" -X POST -d '{"words": "hello", "fetch":["pooled_output"]}' http://127.0.0.1:9292/bert/prediction
```
将预测样例数据中的前十条样例,并将向量表示打印到标准输出。
### Benchmark
......
......@@ -33,27 +33,40 @@ args = benchmark_args()
def single_func(idx, resource):
fin = open("data-c.txt")
dataset = []
for line in fin:
dataset.append(line.strip())
if args.request == "rpc":
reader = BertReader(vocab_file="vocab.txt", max_seq_len=20)
config_file = './serving_client_conf/serving_client_conf.prototxt'
fetch = ["pooled_output"]
client = Client()
client.load_client_config(args.model)
client.connect([resource["endpoint"][idx % 4]])
client.connect([resource["endpoint"][idx % len(resource["endpoint"])]])
start = time.time()
for line in fin:
feed_dict = reader.process(line)
result = client.predict(feed=feed_dict, fetch=fetch)
for i in range(1000):
if args.batch_size == 1:
feed_dict = reader.process(dataset[i])
result = client.predict(feed=feed_dict, fetch=fetch)
elif args.batch_size > 1:
feed_batch = []
for bi in range(args.batch_size):
feed_batch.append(reader.process(dataset[i]))
result = client.batch_predict(
feed_batch=feed_batch, fetch=fetch)
else:
print("unsupport batch size {}".format(args.batch_size))
end = time.time()
elif args.request == "http":
start = time.time()
header = {"Content-Type": "application/json"}
for line in fin:
#dict_data = {"words": "this is for output ", "fetch": ["pooled_output"]}
dict_data = {"words": line, "fetch": ["pooled_output"]}
for i in range(1000):
dict_data = {"words": dataset[i], "fetch": ["pooled_output"]}
r = requests.post(
'http://{}/bert/prediction'.format(resource["endpoint"][0]),
'http://{}/bert/prediction'.format(resource["endpoint"][
idx % len(resource["endpoint"])]),
data=json.dumps(dict_data),
headers=header)
end = time.time()
......@@ -62,10 +75,13 @@ def single_func(idx, resource):
if __name__ == '__main__':
multi_thread_runner = MultiThreadRunner()
endpoint_list = [
"127.0.0.1:9494", "127.0.0.1:9495", "127.0.0.1:9496", "127.0.0.1:9497"
]
endpoint_list = ["127.0.0.1:9292"]
#endpoint_list = endpoint_list + endpoint_list + endpoint_list
#result = multi_thread_runner.run(single_func, args.thread, {"endpoint":endpoint_list})
result = single_func(0, {"endpoint": endpoint_list})
print(result)
result = multi_thread_runner.run(single_func, args.thread,
{"endpoint": endpoint_list})
#result = single_func(0, {"endpoint": endpoint_list})
avg_cost = 0
for i in range(args.thread):
avg_cost += result[0][i]
avg_cost = avg_cost / args.thread
print("average total cost {} s.".format(avg_cost))
rm profile_log
for thread_num in 1 4 8 12 16 20 24
#for thread_num in 1 2 4 8 16
for thread_num in 1 2
do
$PYTHONROOT/bin/python benchmark.py serving_client_conf/serving_client_conf.prototxt data.txt $thread_num $batch_size > profile 2>&1
#for batch_size in 1 2 4 8 16 32 64 128 256 512
for batch_size in 1 2
do
$PYTHONROOT/bin/python benchmark.py --thread $thread_num --batch_size $batch_size --model serving_client_conf/serving_client_conf.prototxt --request rpc > profile 2>&1
echo "========================================"
echo "batch size : $batch_size" >> profile_log
$PYTHONROOT/bin/python ../util/show_profile.py profile $thread_num >> profile_log
tail -n 1 profile >> profile_log
done
done
......@@ -12,6 +12,8 @@ import socket
from paddle_serving_client import Client
from paddle_serving_client.utils import MultiThreadRunner
from paddle_serving_client.utils import benchmark_args
from bert_reader import BertReader
args = benchmark_args()
_ver = sys.version_info
......@@ -43,7 +45,7 @@ class BertService():
self.pid = os.getpid()
self.profile = True if ("FLAGS_profile_client" in os.environ and
os.environ["FLAGS_profile_client"]) else False
'''
module = hub.Module(name=self.model_name)
inputs, outputs, program = module.context(
trainable=True, max_seq_len=self.max_seq_len)
......@@ -56,6 +58,8 @@ class BertService():
dataset=None,
max_seq_len=self.max_seq_len,
do_lower_case=self.do_lower_case)
'''
self.reader = BertReader(vocab_file="vocab.txt", max_seq_len=20)
self.reader_flag = True
def load_client(self, config_file, server_addr):
......@@ -64,11 +68,14 @@ class BertService():
self.client.connect(server_addr)
def run_general(self, text, fetch):
'''
self.batch_size = len(text)
data_generator = self.reader.data_generator(
batch_size=self.batch_size, phase='predict', data=text)
'''
result = []
prepro_start = time.time()
'''
for run_step, batch in enumerate(data_generator(), start=1):
token_list = batch[0][0].reshape(-1).tolist()
pos_list = batch[0][1].reshape(-1).tolist()
......@@ -82,47 +89,56 @@ class BertService():
"input_mask": mask_list
}
prepro_end = time.time()
if self.profile:
print("PROFILE\tpid:{}\tbert_pre_0:{} bert_pre_1:{}".format(
self.pid,
int(round(prepro_start * 1000000)),
int(round(prepro_end * 1000000))))
fetch_map = self.client.predict(feed=feed, fetch=fetch)
'''
feed = self.reader.process(text)
if self.profile:
print("PROFILE\tpid:{}\tbert_pre_0:{} bert_pre_1:{}".format(
self.pid,
int(round(prepro_start * 1000000)),
int(round(prepro_end * 1000000))))
fetch_map = self.client.predict(feed=feed, fetch=fetch)
return fetch_map
def run_batch_general(self, text, fetch):
self.batch_size = len(text)
'''
data_generator = self.reader.data_generator(
batch_size=self.batch_size, phase='predict', data=text)
'''
result = []
prepro_start = time.time()
'''
for run_step, batch in enumerate(data_generator(), start=1):
token_list = batch[0][0].reshape(-1).tolist()
pos_list = batch[0][1].reshape(-1).tolist()
sent_list = batch[0][2].reshape(-1).tolist()
mask_list = batch[0][3].reshape(-1).tolist()
feed_batch = []
for si in range(self.batch_size):
feed = {
"input_ids": token_list[si * self.max_seq_len:(si + 1) *
self.max_seq_len],
"position_ids":
pos_list[si * self.max_seq_len:(si + 1) * self.max_seq_len],
"segment_ids": sent_list[si * self.max_seq_len:(si + 1) *
self.max_seq_len],
"input_mask":
mask_list[si * self.max_seq_len:(si + 1) * self.max_seq_len]
}
feed_batch.append(feed)
prepro_end = time.time()
if self.profile:
print("PROFILE\tpid:{}\tbert_pre_0:{} bert_pre_1:{}".format(
self.pid,
int(round(prepro_start * 1000000)),
int(round(prepro_end * 1000000))))
fetch_map_batch = self.client.batch_predict(
feed_batch=feed_batch, fetch=fetch)
'''
feed_batch = []
for si in range(self.batch_size):
'''
feed = {
"input_ids": token_list[si * self.max_seq_len:(si + 1) *
self.max_seq_len],
"position_ids":
pos_list[si * self.max_seq_len:(si + 1) * self.max_seq_len],
"segment_ids": sent_list[si * self.max_seq_len:(si + 1) *
self.max_seq_len],
"input_mask":
mask_list[si * self.max_seq_len:(si + 1) * self.max_seq_len]
}
'''
feed = self.reader.process(text[si])
feed_batch.append(feed)
prepro_end = time.time()
if self.profile:
print("PROFILE\tpid:{}\tbert_pre_0:{} bert_pre_1:{}".format(
self.pid,
int(round(prepro_start * 1000000)),
int(round(prepro_end * 1000000))))
fetch_map_batch = self.client.batch_predict(
feed_batch=feed_batch, fetch=fetch)
return fetch_map_batch
......@@ -134,22 +150,32 @@ def single_func(idx, resource):
do_lower_case=True)
config_file = './serving_client_conf/serving_client_conf.prototxt'
fetch = ["pooled_output"]
server_addr = [resource["endpoint"][idx]]
server_addr = [resource["endpoint"][idx % len(resource["endpoint"])]]
bc.load_client(config_file, server_addr)
batch_size = 1
use_batch = False if batch_size == 1 else True
feed_batch = []
start = time.time()
fin = open("data-c.txt")
for line in fin:
result = bc.run_general([[line.strip()]], fetch)
if not use_batch:
result = bc.run_general(line.strip(), fetch)
else:
if len(feed_batch) == batch_size:
result = bc.run_batch_general(feed_batch, fetch)
feed_batch = []
else:
feed_batch.append(line.strip())
if use_batch and len(feed_batch) > 0:
result = bc.run_batch_general(feed_batch, fetch)
feed_batch = []
end = time.time()
return [[end - start]]
if __name__ == '__main__':
multi_thread_runner = MultiThreadRunner()
result = multi_thread_runner.run(single_func, args.thread, {
"endpoint": [
"127.0.0.1:9494", "127.0.0.1:9495", "127.0.0.1:9496",
"127.0.0.1:9497"
]
})
result = multi_thread_runner.run(single_func, args.thread,
{"endpoint": ["127.0.0.1:9292"]})
print("time cost for each thread {}".format(result))
......@@ -34,5 +34,6 @@ bert_service.load_model_config(sys.argv[1])
gpu_ids = os.environ["CUDA_VISIBLE_DEVICES"]
gpus = [int(x) for x in gpu_ids.split(",")]
bert_service.set_gpus(gpus)
bert_service.prepare_server(workdir="workdir", port=9494, device="gpu")
bert_service.prepare_server(
workdir="workdir", port=int(sys.argv[2]), device="gpu")
bert_service.run_server()
wget https://paddle-serving.bj.bcebos.com/bert_example/data-c.txt --no-check-certificate
wget https://paddle-serving.bj.bcebos.com/bert_example/vocab.txt --no-check-certificate
......@@ -29,9 +29,9 @@ with open(profile_file) as f:
for line in f.readlines():
line = line.strip().split("\t")
if line[0] == "PROFILE":
prase(line[1])
prase(line[2])
print("thread num {}".format(thread_num))
for name in time_dict:
print("{} cost {} s per thread ".format(name, time_dict[name] / (
print("{} cost {} s in each thread ".format(name, time_dict[name] / (
1000000.0 * float(thread_num))))
......@@ -31,6 +31,7 @@ def benchmark_args():
help="endpoint of server")
parser.add_argument(
"--request", type=str, default="rpc", help="mode of service")
parser.add_argument("--batch_size", type=int, default=1, help="batch size")
return parser.parse_args()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册