提交 a369de86 编写于 作者: M MRXLT

refine imagenet benchmark script

上级 9b74a9be
......@@ -2,7 +2,7 @@
示例中采用ResNet50_vd模型执行imagenet 1000分类任务。
### 模型及配置文件获取
### 获取模型配置文件和样例数据
```
sh get_model.sh
```
......
......@@ -18,23 +18,28 @@ from paddle_serving_client import Client
from paddle_serving_client.utils import MultiThreadRunner
from paddle_serving_client.utils import benchmark_args
import time
import os
args = benchmark_args()
def single_func(idx, resource):
file_list = []
for file_name in os.listdir("./image_data/n01440764"):
file_list.append(file_name)
img_list = []
for i in range(1000):
img_list.append(open("./image_data/n01440764/" + file_list[i]).read())
if args.request == "rpc":
reader = ImageReader()
fetch = ["score"]
client = Client()
client.load_client_config(args.model)
client.connect([resource["endpoint"][idx % 4]])
client.connect([resource["endpoint"][idx % len(resource["endpoint"])]])
start = time.time()
for i in range(1000):
with open("./data/n01440764_10026.JPEG") as f:
img = f.read()
img = reader.process_image(img).reshape(-1)
img = reader.process_image(img_list[i]).reshape(-1)
fetch_map = client.predict(feed={"image": img}, fetch=["score"])
end = time.time()
return [[end - start]]
......@@ -43,10 +48,14 @@ def single_func(idx, resource):
if __name__ == "__main__":
multi_thread_runner = MultiThreadRunner()
endpoint_list = []
card_num = 4
for i in range(args.thread):
endpoint_list.append("127.0.0.1:{}".format(9295 + i % card_num))
endpoint_list = ["127.0.0.1:9393"]
#card_num = 4
#for i in range(args.thread):
# endpoint_list.append("127.0.0.1:{}".format(9295 + i % card_num))
result = multi_thread_runner.run(single_func, args.thread,
{"endpoint": endpoint_list})
print(result)
avg_cost = 0
for i in range(args.thread):
avg_cost += result[0][i]
avg_cost = avg_cost / args.thread
print("average total cost {} s.".format(avg_cost))
......@@ -30,6 +30,12 @@ args = benchmark_args()
def single_func(idx, resource):
file_list = []
for file_name in os.listdir("./image_data/n01440764"):
file_list.append(file_name)
img_list = []
for i in range(1000):
img_list.append(open("./image_data/n01440764/" + file_list[i]).read())
if args.request == "rpc":
reader = ImageReader()
fetch = ["score"]
......@@ -37,13 +43,11 @@ def single_func(idx, resource):
client.load_client_config(args.model)
client.connect([resource["endpoint"][idx % len(resource["endpoint"])]])
start = time.time()
with open("./data/n01440764_10026.JPEG") as f:
raw_img = f.read()
for i in range(1000):
if args.batch_size >= 1:
feed_batch = []
for bi in range(args.batch_size):
img = reader.process_image(raw_img)
img = reader.process_image(img_list[i])
img = img.reshape(-1)
feed_batch.append({"image": img})
result = client.batch_predict(
......
wget --no-check-certificate https://paddle-serving.bj.bcebos.com/imagenet-example/conf_and_model.tar.gz
tar -xzvf conf_and_model.tar.gz
wget --no-check-certificate https://paddle-serving.bj.bcebos.com/imagenet-example/ResNet50_vd.tar.gz
tar -xzvf ResNet50_vd.tar.gz
wget --no-check-certificate https://paddle-serving.bj.bcebos.com/imagenet-example/ResNet101_vd.tar.gz
tar -xzvf ResNet101_vd.tar.gz
wget --no-check-certificate https://paddle-serving.bj.bcebos.com/imagenet-example/image_data.tar.gz
tar -xzvf imgae_data.tar.gz
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册