提交 eee694a9 编写于 作者: M MRXLT

fix image benchmark

上级 737de755
...@@ -39,8 +39,8 @@ def single_func(idx, resource): ...@@ -39,8 +39,8 @@ def single_func(idx, resource):
client.connect([resource["endpoint"][idx % len(resource["endpoint"])]]) client.connect([resource["endpoint"][idx % len(resource["endpoint"])]])
start = time.time() start = time.time()
for i in range(1000): for i in range(100):
img = reader.process_image(img_list[i]).reshape(-1) img = reader.process_image(img_list[i])
fetch_map = client.predict(feed={"image": img}, fetch=["score"]) fetch_map = client.predict(feed={"image": img}, fetch=["score"])
end = time.time() end = time.time()
return [[end - start]] return [[end - start]]
...@@ -49,7 +49,7 @@ def single_func(idx, resource): ...@@ -49,7 +49,7 @@ def single_func(idx, resource):
if __name__ == "__main__": if __name__ == "__main__":
multi_thread_runner = MultiThreadRunner() multi_thread_runner = MultiThreadRunner()
endpoint_list = ["127.0.0.1:9393"] endpoint_list = ["127.0.0.1:9292"]
#card_num = 4 #card_num = 4
#for i in range(args.thread): #for i in range(args.thread):
# endpoint_list.append("127.0.0.1:{}".format(9295 + i % card_num)) # endpoint_list.append("127.0.0.1:{}".format(9295 + i % card_num))
......
...@@ -24,6 +24,7 @@ from paddle_serving_client.utils import MultiThreadRunner ...@@ -24,6 +24,7 @@ from paddle_serving_client.utils import MultiThreadRunner
from paddle_serving_client.utils import benchmark_args from paddle_serving_client.utils import benchmark_args
import requests import requests
import json import json
import base64
from image_reader import ImageReader from image_reader import ImageReader
args = benchmark_args() args = benchmark_args()
...@@ -36,6 +37,10 @@ def single_func(idx, resource): ...@@ -36,6 +37,10 @@ def single_func(idx, resource):
img_list = [] img_list = []
for i in range(1000): for i in range(1000):
img_list.append(open("./image_data/n01440764/" + file_list[i]).read()) img_list.append(open("./image_data/n01440764/" + file_list[i]).read())
profile_flags = False
if "FLAGS_profile_client" in os.environ and os.environ[
"FLAGS_profile_client"]:
profile_flags = True
if args.request == "rpc": if args.request == "rpc":
reader = ImageReader() reader = ImageReader()
fetch = ["score"] fetch = ["score"]
...@@ -46,23 +51,43 @@ def single_func(idx, resource): ...@@ -46,23 +51,43 @@ def single_func(idx, resource):
for i in range(1000): for i in range(1000):
if args.batch_size >= 1: if args.batch_size >= 1:
feed_batch = [] feed_batch = []
i_start = time.time()
for bi in range(args.batch_size): for bi in range(args.batch_size):
img = reader.process_image(img_list[i]) img = reader.process_image(img_list[i])
img = img.reshape(-1)
feed_batch.append({"image": img}) feed_batch.append({"image": img})
i_end = time.time()
if profile_flags:
print("PROFILE\tpid:{}\timage_pre_0:{} image_pre_1:{}".
format(os.getpid(),
int(round(i_start * 1000000)),
int(round(i_end * 1000000))))
result = client.predict(feed=feed_batch, fetch=fetch) result = client.predict(feed=feed_batch, fetch=fetch)
else: else:
print("unsupport batch size {}".format(args.batch_size)) print("unsupport batch size {}".format(args.batch_size))
elif args.request == "http": elif args.request == "http":
raise ("no batch predict for http") py_version = 2
server = "http://" + resource["endpoint"][idx % len(resource[
"endpoint"])] + "/image/prediction"
start = time.time()
for i in range(1000):
if py_version == 2:
image = base64.b64encode(
open("./image_data/n01440764/" + file_list[i]).read())
else:
image = base64.b64encode(open(image_path, "rb").read()).decode(
"utf-8")
req = json.dumps({"feed": [{"image": image}], "fetch": ["score"]})
r = requests.post(
server, data=req, headers={"Content-Type": "application/json"})
end = time.time() end = time.time()
return [[end - start]] return [[end - start]]
if __name__ == '__main__': if __name__ == '__main__':
multi_thread_runner = MultiThreadRunner() multi_thread_runner = MultiThreadRunner()
endpoint_list = ["127.0.0.1:9393"] endpoint_list = ["127.0.0.1:9292"]
#endpoint_list = endpoint_list + endpoint_list + endpoint_list #endpoint_list = endpoint_list + endpoint_list + endpoint_list
result = multi_thread_runner.run(single_func, args.thread, result = multi_thread_runner.run(single_func, args.thread,
{"endpoint": endpoint_list}) {"endpoint": endpoint_list})
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册