提交 d1a1d3f4 编写于 作者: M MRXLT

add http request for bert benchmark

上级 e57b8cd2
...@@ -19,6 +19,8 @@ from __future__ import unicode_literals, absolute_import ...@@ -19,6 +19,8 @@ from __future__ import unicode_literals, absolute_import
import os import os
import sys import sys
import time import time
import json
import requests
from paddle_serving_client import Client from paddle_serving_client import Client
from paddle_serving_client.utils import MultiThreadRunner from paddle_serving_client.utils import MultiThreadRunner
from paddle_serving_client.utils import benchmark_args, show_latency from paddle_serving_client.utils import benchmark_args, show_latency
...@@ -72,7 +74,39 @@ def single_func(idx, resource): ...@@ -72,7 +74,39 @@ def single_func(idx, resource):
print("unsupport batch size {}".format(args.batch_size)) print("unsupport batch size {}".format(args.batch_size))
elif args.request == "http": elif args.request == "http":
raise ("not implemented") reader = ChineseBertReader({"max_seq_len": 128})
fetch = ["pooled_output"]
server = "http://" + resource["endpoint"][idx % len(resource[
"endpoint"])] + "/bert/prediction"
start = time.time()
for i in range(turns):
if args.batch_size >= 1:
l_start = time.time()
feed_batch = []
b_start = time.time()
for bi in range(args.batch_size):
feed_batch.append({"words": dataset[bi]})
req = json.dumps({"feed": feed_batch, "fetch": fetch})
b_end = time.time()
if profile_flags:
sys.stderr.write(
"PROFILE\tpid:{}\tbert_pre_0:{} bert_pre_1:{}\n".format(
os.getpid(),
int(round(b_start * 1000000)),
int(round(b_end * 1000000))))
result = requests.post(
server,
data=req,
headers={"Content-Type": "application/json"})
l_end = time.time()
if latency_flags:
latency_list.append(l_end * 1000 - l_start * 1000)
else:
print("unsupport batch size {}".format(args.batch_size))
else:
raise ValueError("not implemented {} request".format(args.request))
end = time.time() end = time.time()
if latency_flags: if latency_flags:
return [[end - start], latency_list] return [[end - start], latency_list]
...@@ -82,9 +116,7 @@ def single_func(idx, resource): ...@@ -82,9 +116,7 @@ def single_func(idx, resource):
if __name__ == '__main__': if __name__ == '__main__':
multi_thread_runner = MultiThreadRunner() multi_thread_runner = MultiThreadRunner()
endpoint_list = [ endpoint_list = ["127.0.0.1:9292"]
"127.0.0.1:9292", "127.0.0.1:9293", "127.0.0.1:9294", "127.0.0.1:9295"
]
turns = 10 turns = 10
start = time.time() start = time.time()
result = multi_thread_runner.run( result = multi_thread_runner.run(
......
...@@ -73,7 +73,7 @@ def single_func(idx, resource): ...@@ -73,7 +73,7 @@ def single_func(idx, resource):
print("unsupport batch size {}".format(args.batch_size)) print("unsupport batch size {}".format(args.batch_size))
elif args.request == "http": elif args.request == "http":
py_version = 2 py_version = sys.version_info[0]
server = "http://" + resource["endpoint"][idx % len(resource[ server = "http://" + resource["endpoint"][idx % len(resource[
"endpoint"])] + "/image/prediction" "endpoint"])] + "/image/prediction"
start = time.time() start = time.time()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册