提交 54230dca 编写于 作者: W wangjiawei04

fix bert

上级 22c3047c
...@@ -30,6 +30,7 @@ class BertService(WebService): ...@@ -30,6 +30,7 @@ class BertService(WebService):
def preprocess(self, feed=[], fetch=[]): def preprocess(self, feed=[], fetch=[]):
feed_res = [] feed_res = []
is_batch = True is_batch = True
print(feed)
for ins in feed: for ins in feed:
feed_dict = self.reader.process(ins["words"].encode("utf-8")) feed_dict = self.reader.process(ins["words"].encode("utf-8"))
for key in feed_dict.keys(): for key in feed_dict.keys():
......
import sys
import os
import yaml
import requests
import time
import json
import ast
from paddle_serving_client.utils import MultiThreadRunner
from paddle_serving_client.utils import benchmark_args, show_latency
def parse_benchmark(filein, fileout):
    """Strip per-call timing entries from a benchmark YAML profile.

    Reads *filein* as YAML, deletes every key under ``res["DAG"]`` whose
    name contains the substring ``"call"``, and writes the pruned mapping
    to *fileout* in block style.

    Args:
        filein: path of the raw benchmark log (YAML mapping with a "DAG" key).
        fileout: path the cleaned YAML is written to (overwritten).
    """
    with open(filein, "r") as fin:
        # safe_load: plain yaml.load() without a Loader is deprecated and
        # can construct arbitrary Python objects from untrusted input.
        res = yaml.safe_load(fin)
    # Collect keys first, delete second — never mutate a dict while
    # iterating over it.
    del_list = [key for key in res["DAG"] if "call" in key]
    for key in del_list:
        del res["DAG"][key]
    with open(fileout, "w") as fout:
        yaml.dump(res, fout, default_flow_style=False)
def run_http(idx, batch_size):
    """Worker body for the HTTP benchmark: POST batches of lines from
    ``data-c.txt`` to the local bert service until the file is exhausted
    or 40 seconds have elapsed.

    Request body shape:
        {"feed": [{"words": "hello"}], "fetch": ["pooled_output"]}

    Args:
        idx: worker index (used only for the startup log line).
        batch_size: number of input lines packed into each request.

    Returns:
        ``[[elapsed_seconds]]`` — the nested-list shape MultiThreadRunner
        expects for latency aggregation.
    """
    print("start thread ({})".format(idx))
    url = "http://127.0.0.1:9292/bert/prediction"
    with open("data-c.txt", 'r') as fin:
        # Timer starts once the input file is open; the whole file is read
        # up front so request pacing is not skewed by disk I/O.
        start = time.time()
        lines = fin.readlines()
    start_idx = 0
    while start_idx < len(lines):
        end_idx = min(len(lines), start_idx + batch_size)
        feed_lst = [{"words": lines[i]} for i in range(start_idx, end_idx)]
        data = {"feed": feed_lst, "fetch": ["pooled_output"]}
        # Fire-and-forget: the response body is not inspected — this
        # loop only measures throughput/latency of the service.
        requests.post(
            url=url,
            data=json.dumps(data),
            headers={"Content-Type": "application/json"})
        start_idx += batch_size
        # Hard 40 s budget per worker, checked after every batch.
        if time.time() - start > 40:
            break
    end = time.time()
    return [[end - start]]
def multithread_http(thread, batch_size):
    """Fan :func:`run_http` out across *thread* concurrent workers.

    Each worker receives *batch_size*; per-worker timings are collected
    by MultiThreadRunner (the aggregate result is not consumed here).
    """
    runner = MultiThreadRunner()
    runner.run(run_http, thread, batch_size)
if __name__ == "__main__":
    # CLI dispatch:
    #   run  <threads> <batch_size>  -> drive the HTTP benchmark
    #   dump <infile> <outfile>      -> post-process a benchmark log
    mode = sys.argv[1]
    if mode == "run":
        multithread_http(int(sys.argv[2]), int(sys.argv[3]))
    elif mode == "dump":
        parse_benchmark(sys.argv[2], sys.argv[3])
modelname="bert"
# HTTP benchmark driver: sweep thread counts and batch sizes against the
# bert web service, recording CPU/GPU utilization and latency per run
# into profile_log_$modelname.
ps -ef | grep web_service | awk '{print $2}' | xargs kill -9
sleep 3
rm -rf profile_log_$modelname
for thread_num in 1 8 16
do
    for batch_size in 1 10 100
    do
        python3.7 bert_web_service.py bert_seq128_model/ 9292 &
        sleep 3
        echo "----Bert thread num: $thread_num batch size: $batch_size mode:http ----" >>profile_log_$modelname
        # Sample GPU memory / utilization every 100 ms in the background.
        nvidia-smi --id=2 --query-compute-apps=used_memory --format=csv -lms 100 > gpu_use.log 2>&1 &
        nvidia-smi --id=2 --query-gpu=utilization.gpu --format=csv -lms 100 > gpu_utilization.log 2>&1 &
        # printf interprets \n portably; bash's builtin `echo` does not
        # without -e, which left cpu_utilization.py as one broken line.
        printf "import psutil\ncpu_utilization=psutil.cpu_percent(1,False)\nprint('CPU_UTILIZATION:', cpu_utilization)\n" > cpu_utilization.py
        python3.7 new_benchmark.py run $thread_num $batch_size
        python3.7 cpu_utilization.py >>profile_log_$modelname
        ps -ef | grep web_service | awk '{print $2}' | xargs kill -9
        python3.7 new_benchmark.py dump benchmark.log benchmark.tmp
        mv benchmark.tmp benchmark.log
        # $1 is the numeric column of the nvidia-smi CSV samples. The
        # previous `$modelname` was never shell-expanded inside single
        # quotes, so awk read it as an unset variable (0) and compared
        # whole lines ($0) as strings.
        awk 'BEGIN {max = 0} {if(NR>1){if ($1 > max) max=$1}} END {print "MAX_GPU_MEMORY:", max}' gpu_use.log >> profile_log_$modelname
        awk 'BEGIN {max = 0} {if(NR>1){if ($1 > max) max=$1}} END {print "GPU_UTILIZATION:", max}' gpu_utilization.log >> profile_log_$modelname
        cat benchmark.log >> profile_log_$modelname
    done
done
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册