提交 f29deb9f 编写于 作者: M MRXLT

add profile for preprocess

上级 376ca0ba
...@@ -35,19 +35,28 @@ def single_func(idx, resource): ...@@ -35,19 +35,28 @@ def single_func(idx, resource):
dataset = [] dataset = []
for line in fin: for line in fin:
dataset.append(line.strip()) dataset.append(line.strip())
profile_flags = False
if os.environ["FLAGS_profile_client"]:
profile_flags = True
if args.request == "rpc": if args.request == "rpc":
reader = BertReader(vocab_file="vocab.txt", max_seq_len=20) reader = BertReader(vocab_file="vocab.txt", max_seq_len=20)
fetch = ["pooled_output"] fetch = ["pooled_output"]
client = Client() client = Client()
client.load_client_config(args.model) client.load_client_config(args.model)
client.connect([resource["endpoint"][idx % len(resource["endpoint"])]]) client.connect([resource["endpoint"][idx % len(resource["endpoint"])]])
feed_batch = []
for bi in range(args.batch_size):
feed_batch.append(reader.process(dataset[bi]))
start = time.time() start = time.time()
for i in range(1000): for i in range(1000):
if args.batch_size >= 1: if args.batch_size >= 1:
feed_batch = []
b_start = time.time()
for bi in range(args.batch_size):
feed_batch.append(reader.process(dataset[bi]))
b_end = time.time()
if profile_flags:
print("PROFILE\tpid:{}\tbert+pre_0:{} bert_pre_1:{}".format(
os.getpid(),
int(round(b_start * 1000000)),
int(round(b_end * 1000000))))
result = client.batch_predict( result = client.batch_predict(
feed_batch=feed_batch, fetch=fetch) feed_batch=feed_batch, fetch=fetch)
else: else:
...@@ -62,7 +71,7 @@ def single_func(idx, resource): ...@@ -62,7 +71,7 @@ def single_func(idx, resource):
if __name__ == '__main__': if __name__ == '__main__':
multi_thread_runner = MultiThreadRunner() multi_thread_runner = MultiThreadRunner()
endpoint_list = [ endpoint_list = [
"127.0.0.1:9295", "127.0.0.1:9296", "127.0.0.1:9297", "127.0.0.1:9298" "127.0.0.1:9292", "127.0.0.1:9293", "127.0.0.1:9294", "127.0.0.1:9295"
] ]
result = multi_thread_runner.run(single_func, args.thread, result = multi_thread_runner.run(single_func, args.thread,
{"endpoint": endpoint_list}) {"endpoint": endpoint_list})
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册