From f29deb9fa10870b447ec6201c2e54770b272bef0 Mon Sep 17 00:00:00 2001 From: MRXLT Date: Thu, 26 Mar 2020 20:37:55 +0800 Subject: [PATCH] add profile for preprocess --- python/examples/bert/benchmark_batch.py | 19 ++++++++++++++----- 1 file changed, 14 insertions(+), 5 deletions(-) diff --git a/python/examples/bert/benchmark_batch.py b/python/examples/bert/benchmark_batch.py index e0f67714..872799e6 100644 --- a/python/examples/bert/benchmark_batch.py +++ b/python/examples/bert/benchmark_batch.py @@ -35,19 +35,28 @@ def single_func(idx, resource): dataset = [] for line in fin: dataset.append(line.strip()) + profile_flags = False + if os.environ["FLAGS_profile_client"]: + profile_flags = True if args.request == "rpc": reader = BertReader(vocab_file="vocab.txt", max_seq_len=20) fetch = ["pooled_output"] client = Client() client.load_client_config(args.model) client.connect([resource["endpoint"][idx % len(resource["endpoint"])]]) - feed_batch = [] - for bi in range(args.batch_size): - feed_batch.append(reader.process(dataset[bi])) - start = time.time() for i in range(1000): if args.batch_size >= 1: + feed_batch = [] + b_start = time.time() + for bi in range(args.batch_size): + feed_batch.append(reader.process(dataset[bi])) + b_end = time.time() + if profile_flags: + print("PROFILE\tpid:{}\tbert+pre_0:{} bert_pre_1:{}".format( + os.getpid(), + int(round(b_start * 1000000)), + int(round(b_end * 1000000)))) result = client.batch_predict( feed_batch=feed_batch, fetch=fetch) else: @@ -62,7 +71,7 @@ def single_func(idx, resource): if __name__ == '__main__': multi_thread_runner = MultiThreadRunner() endpoint_list = [ - "127.0.0.1:9295", "127.0.0.1:9296", "127.0.0.1:9297", "127.0.0.1:9298" + "127.0.0.1:9292", "127.0.0.1:9293", "127.0.0.1:9294", "127.0.0.1:9295" ] result = multi_thread_runner.run(single_func, args.thread, {"endpoint": endpoint_list}) -- GitLab