Commit c107c3c4 authored by MRXLT, committed by GitHub

Merge branch 'develop' into 0.3.0-bug-fix

@@ -19,6 +19,8 @@ from __future__ import unicode_literals, absolute_import
 import os
 import sys
 import time
+import json
+import requests
 from paddle_serving_client import Client
 from paddle_serving_client.utils import MultiThreadRunner
 from paddle_serving_client.utils import benchmark_args, show_latency
@@ -72,7 +74,39 @@ def single_func(idx, resource):
         print("unsupport batch size {}".format(args.batch_size))
     elif args.request == "http":
-        raise ("not implemented")
+        reader = ChineseBertReader({"max_seq_len": 128})
+        fetch = ["pooled_output"]
+        server = "http://" + resource["endpoint"][idx % len(resource[
+            "endpoint"])] + "/bert/prediction"
+        start = time.time()
+        for i in range(turns):
+            if args.batch_size >= 1:
+                l_start = time.time()
+                feed_batch = []
+                b_start = time.time()
+                for bi in range(args.batch_size):
+                    feed_batch.append({"words": dataset[bi]})
+                req = json.dumps({"feed": feed_batch, "fetch": fetch})
+                b_end = time.time()
+                if profile_flags:
+                    sys.stderr.write(
+                        "PROFILE\tpid:{}\tbert_pre_0:{} bert_pre_1:{}\n".format(
+                            os.getpid(),
+                            int(round(b_start * 1000000)),
+                            int(round(b_end * 1000000))))
+                result = requests.post(
+                    server,
+                    data=req,
+                    headers={"Content-Type": "application/json"})
+                l_end = time.time()
+                if latency_flags:
+                    latency_list.append(l_end * 1000 - l_start * 1000)
+            else:
+                print("unsupport batch size {}".format(args.batch_size))
+    else:
+        raise ValueError("not implemented {} request".format(args.request))
     end = time.time()
     if latency_flags:
         return [[end - start], latency_list]
@@ -82,9 +116,7 @@ def single_func(idx, resource):
 if __name__ == '__main__':
     multi_thread_runner = MultiThreadRunner()
-    endpoint_list = [
-        "127.0.0.1:9292", "127.0.0.1:9293", "127.0.0.1:9294", "127.0.0.1:9295"
-    ]
+    endpoint_list = ["127.0.0.1:9292"]
     turns = 10
     start = time.time()
     result = multi_thread_runner.run(
...
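For context on the hunk above: the HTTP branch it adds serializes a batch of feed dicts into a JSON body of the form {"feed": [...], "fetch": [...]} and posts it to the service's /bert/prediction route. Below is a minimal standalone sketch of that same request, assuming a BERT web service is already listening on 127.0.0.1:9292 and that raw text is accepted in the "words" field; both are assumptions for illustration, not guaranteed by this diff.

import json
import requests

# Endpoint shape used by the benchmark: http://<host:port>/bert/prediction
# (assumed host/port for this sketch).
server = "http://127.0.0.1:9292/bert/prediction"

# One feed dict per sample in the batch; "pooled_output" is the tensor
# the benchmark fetches back from the server.
feed_batch = [{"words": "hello world"}]
req = json.dumps({"feed": feed_batch, "fetch": ["pooled_output"]})

result = requests.post(
    server, data=req, headers={"Content-Type": "application/json"})
print(result.status_code, result.text)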
@@ -14,15 +14,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import os
 import sys
-import numpy as np
-import paddlehub as hub
-import ujson
-import random
-import time
-from paddlehub.common.logger import logger
-import socket
 from paddle_serving_client import Client
 from paddle_serving_client.utils import benchmark_args
 from paddle_serving_app.reader import ChineseBertReader
...
@@ -73,7 +73,7 @@ def single_func(idx, resource):
         print("unsupport batch size {}".format(args.batch_size))
     elif args.request == "http":
-        py_version = 2
+        py_version = sys.version_info[0]
         server = "http://" + resource["endpoint"][idx % len(resource[
             "endpoint"])] + "/image/prediction"
         start = time.time()
...
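The one-line change in this hunk replaces a hard-coded py_version = 2 with runtime detection, so the image benchmark's HTTP path behaves correctly under both interpreters. A small sketch of the idiom (the branch bodies are illustrative only, not from this diff):

import sys

# sys.version_info is a named tuple; index 0 is the major version,
# so py_version is 2 under Python 2 and 3 under Python 3.
py_version = sys.version_info[0]
if py_version == 2:
    print("running under Python 2")
else:
    print("running under Python 3")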