From 509d9bdfa5bad2ce00882d634763ac5f9c8b6ce4 Mon Sep 17 00:00:00 2001 From: guru4elephant Date: Fri, 13 Mar 2020 00:08:05 +0800 Subject: [PATCH] refine bert benchmark and batch benchmark and bert client --- python/examples/bert/benchmark.py | 1 - python/examples/bert/benchmark_batch.py | 101 ++++++++------- python/examples/bert/bert_client.py | 165 ++++-------------------- 3 files changed, 79 insertions(+), 188 deletions(-) diff --git a/python/examples/bert/benchmark.py b/python/examples/bert/benchmark.py index 7839f0a1..cefd772b 100644 --- a/python/examples/bert/benchmark.py +++ b/python/examples/bert/benchmark.py @@ -35,7 +35,6 @@ def single_func(idx, resource): fin = open("data-c.txt") if args.request == "rpc": reader = BertReader(vocab_file="vocab.txt", max_seq_len=128) - config_file = './serving_client_conf/serving_client_conf.prototxt' fetch = ["pooled_output"] client = Client() client.load_client_config(args.model) diff --git a/python/examples/bert/benchmark_batch.py b/python/examples/bert/benchmark_batch.py index d8d31013..0993b1a8 100644 --- a/python/examples/bert/benchmark_batch.py +++ b/python/examples/bert/benchmark_batch.py @@ -1,3 +1,5 @@ +# -*- coding: utf-8 -*- +# # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -11,61 +13,66 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +# pylint: disable=doc-string-missing +from __future__ import unicode_literals, absolute_import +import os import sys +import time from paddle_serving_client import Client -from paddle_serving_client.metric import auc from paddle_serving_client.utils import MultiThreadRunner -import time -from bert_client import BertService +from paddle_serving_client.utils import benchmark_args +from batching import pad_batch_data +import tokenization +import requests +import json +from bert_reader import BertReader +args = benchmark_args() -def predict(thr_id, resource, batch_size): - bc = BertService( - model_name="bert_chinese_L-12_H-768_A-12", - max_seq_len=20, - do_lower_case=True) - bc.load_client(resource["conf_file"], resource["server_endpoint"]) - thread_num = resource["thread_num"] - file_list = resource["filelist"] - line_id = 0 - result = [] - label_list = [] - dataset = [] - for fn in file_list: - fin = open(fn) - for line in fin: - if line_id % thread_num == thr_id - 1: - dataset.append(line.strip()) - line_id += 1 - fin.close() +batch_size = 24 - start = time.time() - fetch = ["pooled_output"] - batch = [] - for inst in dataset: - if len(batch) < batch_size: - batch.append([inst]) - else: - fetch_map_batch = bc.run_batch_general(batch, fetch) - batch = [] - result.append(fetch_map_batch) - end = time.time() - return [result, label_list, [end - start]] +def single_func(idx, resource): + fin = open("data-c.txt") + if args.request == "rpc": + reader = BertReader(vocab_file="vocab.txt", max_seq_len=128) + fetch = ["pooled_output"] + client = Client() + client.load_client_config(args.model) + client.connect([resource["endpoint"][idx % 4]]) -if __name__ == '__main__': - conf_file = sys.argv[1] - data_file = sys.argv[2] - thread_num = sys.argv[3] - batch_size = sys.ragv[4] - resource = {} - resource["conf_file"] = conf_file - resource["server_endpoint"] = ["127.0.0.1:9293"] - resource["filelist"] = [data_file] - resource["thread_num"] = int(thread_num) + start = time.time() + idx = 0 + batch_data = [] + for line in fin: + feed_dict = reader.process(line) + batch_data.append(feed_dict) + idx += 1 + if idx % batch_size == 0: + result = client.batch_predict( + feed_batch=batch_data, fetch=fetch) + batch_data = [] + end = time.time() + elif args.request == "http": + header = {"Content-Type": "application/json"} + for line in fin: + dict_data = {"words": line, "fetch": ["pooled_output"]} + r = requests.post( + 'http://{}/bert/prediction'.format(resource["endpoint"][0]), + data=json.dumps(dict_data), + headers=header) + end = time.time() + return [[end - start]] - thread_runner = MultiThreadRunner() - result = thread_runner.run(predict, int(sys.argv[3]), resource, batch_size) - print("total time {} s".format(sum(result[-1]) / len(result[-1]))) +if __name__ == '__main__': + multi_thread_runner = MultiThreadRunner() + endpoint_list = [] + card_num = 4 + for i in range(args.thread): + endpoint_list.append("127.0.0.1:{}".format(9494 + i % card_num)) + print(endpoint_list) + result = multi_thread_runner.run(single_func, args.thread, + {"endpoint": endpoint_list}) + print(result) diff --git a/python/examples/bert/bert_client.py b/python/examples/bert/bert_client.py index 343a6e01..91323bc1 100644 --- a/python/examples/bert/bert_client.py +++ b/python/examples/bert/bert_client.py @@ -1,5 +1,19 @@ # coding:utf-8 # pylint: disable=doc-string-missing +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import os import sys import numpy as np @@ -10,146 +24,17 @@ import time from paddlehub.common.logger import logger import socket from paddle_serving_client import Client -from paddle_serving_client.utils import MultiThreadRunner from paddle_serving_client.utils import benchmark_args args = benchmark_args() -_ver = sys.version_info -is_py2 = (_ver[0] == 2) -is_py3 = (_ver[0] == 3) - -if is_py2: - import httplib -if is_py3: - import http.client as httplib - - -class BertService(): - def __init__(self, - max_seq_len=128, - model_name="bert_uncased_L-12_H-768_A-12", - show_ids=False, - do_lower_case=True, - process_id=0, - retry=3): - self.process_id = process_id - self.reader_flag = False - self.batch_size = 0 - self.max_seq_len = max_seq_len - self.model_name = model_name - self.show_ids = show_ids - self.do_lower_case = do_lower_case - self.retry = retry - self.pid = os.getpid() - self.profile = True if ("FLAGS_profile_client" in os.environ and - os.environ["FLAGS_profile_client"]) else False - - module = hub.Module(name=self.model_name) - inputs, outputs, program = module.context( - trainable=True, max_seq_len=self.max_seq_len) - input_ids = inputs["input_ids"] - position_ids = inputs["position_ids"] - segment_ids = inputs["segment_ids"] - input_mask = inputs["input_mask"] - self.reader = hub.reader.ClassifyReader( - vocab_path=module.get_vocab_path(), - dataset=None, - max_seq_len=self.max_seq_len, - do_lower_case=self.do_lower_case) - self.reader_flag = True - - def load_client(self, config_file, server_addr): - self.client = Client() - self.client.load_client_config(config_file) - self.client.connect(server_addr) - - def run_general(self, text, fetch): - self.batch_size = len(text) - data_generator = self.reader.data_generator( - batch_size=self.batch_size, phase='predict', data=text) - result = [] - prepro_start = time.time() - for run_step, batch in enumerate(data_generator(), start=1): - token_list = batch[0][0].reshape(-1).tolist() - pos_list = batch[0][1].reshape(-1).tolist() - sent_list = batch[0][2].reshape(-1).tolist() - mask_list = batch[0][3].reshape(-1).tolist() - for si in range(self.batch_size): - feed = { - "input_ids": token_list, - "position_ids": pos_list, - "segment_ids": sent_list, - "input_mask": mask_list - } - prepro_end = time.time() - if self.profile: - print("PROFILE\tpid:{}\tbert_pre_0:{} bert_pre_1:{}".format( - self.pid, - int(round(prepro_start * 1000000)), - int(round(prepro_end * 1000000)))) - fetch_map = self.client.predict(feed=feed, fetch=fetch) - - return fetch_map - - def run_batch_general(self, text, fetch): - self.batch_size = len(text) - data_generator = self.reader.data_generator( - batch_size=self.batch_size, phase='predict', data=text) - result = [] - prepro_start = time.time() - for run_step, batch in enumerate(data_generator(), start=1): - token_list = batch[0][0].reshape(-1).tolist() - pos_list = batch[0][1].reshape(-1).tolist() - sent_list = batch[0][2].reshape(-1).tolist() - mask_list = batch[0][3].reshape(-1).tolist() - feed_batch = [] - for si in range(self.batch_size): - feed = { - "input_ids": token_list[si * self.max_seq_len:(si + 1) * - self.max_seq_len], - "position_ids": - pos_list[si * self.max_seq_len:(si + 1) * self.max_seq_len], - "segment_ids": sent_list[si * self.max_seq_len:(si + 1) * - self.max_seq_len], - "input_mask": - mask_list[si * self.max_seq_len:(si + 1) * self.max_seq_len] - } - feed_batch.append(feed) - prepro_end = time.time() - if self.profile: - print("PROFILE\tpid:{}\tbert_pre_0:{} bert_pre_1:{}".format( - self.pid, - int(round(prepro_start * 1000000)), - int(round(prepro_end * 1000000)))) - fetch_map_batch = self.client.batch_predict( - feed_batch=feed_batch, fetch=fetch) - return fetch_map_batch - - -def single_func(idx, resource): - bc = BertService( - model_name='bert_chinese_L-12_H-768_A-12', - max_seq_len=20, - show_ids=False, - do_lower_case=True) - config_file = './serving_client_conf/serving_client_conf.prototxt' - fetch = ["pooled_output"] - server_addr = [resource["endpoint"][idx]] - bc.load_client(config_file, server_addr) - batch_size = 1 - start = time.time() - fin = open("data-c.txt") - for line in fin: - result = bc.run_general([[line.strip()]], fetch) - end = time.time() - return [[end - start]] - - -if __name__ == '__main__': - multi_thread_runner = MultiThreadRunner() - result = multi_thread_runner.run(single_func, args.thread, { - "endpoint": [ - "127.0.0.1:9494", "127.0.0.1:9495", "127.0.0.1:9496", - "127.0.0.1:9497" - ] - }) +fin = open("data-c.txt") +reader = BertReader(vocab_file="vocab.txt", max_seq_len=128) +fetch = ["pooled_output"] +endpoint_list = ["127.0.0.1:9494"] +client = Client() +client.load_client_config(args.model) +client.connect(endpoint_list) + +for line in fin: + feed_dict = reader.process(line) + result = client.predict(feed=feed_dict, fetch=fetch) -- GitLab