From 04251a913c1f550847317b75a60366dbf9a7cc05 Mon Sep 17 00:00:00 2001 From: HexToString <506181616@qq.com> Date: Fri, 16 Apr 2021 08:44:13 +0000 Subject: [PATCH] fix benchmark.py --- python/examples/fit_a_line/benchmark.py | 19 ++++++++++++------- 1 file changed, 12 insertions(+), 7 deletions(-) diff --git a/python/examples/fit_a_line/benchmark.py b/python/examples/fit_a_line/benchmark.py index 5cc094dc..b1550b2f 100644 --- a/python/examples/fit_a_line/benchmark.py +++ b/python/examples/fit_a_line/benchmark.py @@ -25,19 +25,24 @@ args = benchmark_args() def single_func(idx, resource): + train_reader = paddle.batch( + paddle.reader.shuffle( + paddle.dataset.uci_housing.train(), buf_size=500), + batch_size=1) + total_number = sum(1 for _ in train_reader()) + if args.request == "rpc": client = Client() client.load_client_config(args.model) client.connect([args.endpoint]) - train_reader = paddle.batch( - paddle.reader.shuffle( - paddle.dataset.uci_housing.train(), buf_size=500), - batch_size=1) start = time.time() for data in train_reader(): - fetch_map = client.predict(feed={"x": data[0][0]}, fetch=["price"],batch=True) + #new_data = np.zeros((1, 13)).astype("float32") + #new_data[0] = data[0][0] + #fetch_map = client.predict(feed={"x": new_data}, fetch=["price"], batch=True) + fetch_map = client.predict(feed={"x": data[0][0]}, fetch=["price"]) end = time.time() - return [[end - start]] + return [[end - start], [total_number]] elif args.request == "http": train_reader = paddle.batch( paddle.reader.shuffle( @@ -49,7 +54,7 @@ def single_func(idx, resource): 'http://{}/uci/prediction'.format(args.endpoint), data={"x": data[0]}) end = time.time() - return [[end - start]] + return [[end - start], [total_number]] multi_thread_runner = MultiThreadRunner() -- GitLab