diff --git a/python/examples/fit_a_line/benchmark.py b/python/examples/fit_a_line/benchmark.py index 5cc094dc17f2b4519e076b34d4368f0f33b8f3eb..b1550b2ff5d40aa10e3f415c82235c06a4508012 100644 --- a/python/examples/fit_a_line/benchmark.py +++ b/python/examples/fit_a_line/benchmark.py @@ -25,19 +25,24 @@ args = benchmark_args() def single_func(idx, resource): + train_reader = paddle.batch( + paddle.reader.shuffle( + paddle.dataset.uci_housing.train(), buf_size=500), + batch_size=1) + total_number = sum(1 for _ in train_reader()) + if args.request == "rpc": client = Client() client.load_client_config(args.model) client.connect([args.endpoint]) - train_reader = paddle.batch( - paddle.reader.shuffle( - paddle.dataset.uci_housing.train(), buf_size=500), - batch_size=1) start = time.time() for data in train_reader(): - fetch_map = client.predict(feed={"x": data[0][0]}, fetch=["price"],batch=True) + #new_data = np.zeros((1, 13)).astype("float32") + #new_data[0] = data[0][0] + #fetch_map = client.predict(feed={"x": new_data}, fetch=["price"], batch=True) + fetch_map = client.predict(feed={"x": data[0][0]}, fetch=["price"]) end = time.time() - return [[end - start]] + return [[end - start], [total_number]] elif args.request == "http": train_reader = paddle.batch( paddle.reader.shuffle( @@ -49,7 +54,7 @@ def single_func(idx, resource): 'http://{}/uci/prediction'.format(args.endpoint), data={"x": data[0]}) end = time.time() - return [[end - start]] + return [[end - start], [total_number]] multi_thread_runner = MultiThreadRunner()