提交 04251a91 编写于 作者: H HexToString

fix benchmark.py

上级 e204f230
...@@ -25,19 +25,24 @@ args = benchmark_args() ...@@ -25,19 +25,24 @@ args = benchmark_args()
def single_func(idx, resource): def single_func(idx, resource):
if args.request == "rpc":
client = Client()
client.load_client_config(args.model)
client.connect([args.endpoint])
train_reader = paddle.batch( train_reader = paddle.batch(
paddle.reader.shuffle( paddle.reader.shuffle(
paddle.dataset.uci_housing.train(), buf_size=500), paddle.dataset.uci_housing.train(), buf_size=500),
batch_size=1) batch_size=1)
total_number = sum(1 for _ in train_reader())
if args.request == "rpc":
client = Client()
client.load_client_config(args.model)
client.connect([args.endpoint])
start = time.time() start = time.time()
for data in train_reader(): for data in train_reader():
fetch_map = client.predict(feed={"x": data[0][0]}, fetch=["price"],batch=True) #new_data = np.zeros((1, 13)).astype("float32")
#new_data[0] = data[0][0]
#fetch_map = client.predict(feed={"x": new_data}, fetch=["price"], batch=True)
fetch_map = client.predict(feed={"x": data[0][0]}, fetch=["price"])
end = time.time() end = time.time()
return [[end - start]] return [[end - start], [total_number]]
elif args.request == "http": elif args.request == "http":
train_reader = paddle.batch( train_reader = paddle.batch(
paddle.reader.shuffle( paddle.reader.shuffle(
...@@ -49,7 +54,7 @@ def single_func(idx, resource): ...@@ -49,7 +54,7 @@ def single_func(idx, resource):
'http://{}/uci/prediction'.format(args.endpoint), 'http://{}/uci/prediction'.format(args.endpoint),
data={"x": data[0]}) data={"x": data[0]})
end = time.time() end = time.time()
return [[end - start]] return [[end - start], [total_number]]
multi_thread_runner = MultiThreadRunner() multi_thread_runner = MultiThreadRunner()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册