提交 509d9bdf 编写于 作者: G guru4elephant

refine bert benchmark and batch benchmark and bert client

上级 a5b0d4f4
......@@ -35,7 +35,6 @@ def single_func(idx, resource):
fin = open("data-c.txt")
if args.request == "rpc":
reader = BertReader(vocab_file="vocab.txt", max_seq_len=128)
config_file = './serving_client_conf/serving_client_conf.prototxt'
fetch = ["pooled_output"]
client = Client()
client.load_client_config(args.model)
......
# -*- coding: utf-8 -*-
#
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
......@@ -11,61 +13,66 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# pylint: disable=doc-string-missing
from __future__ import unicode_literals, absolute_import
import os
import sys
import time
from paddle_serving_client import Client
from paddle_serving_client.metric import auc
from paddle_serving_client.utils import MultiThreadRunner
import time
from bert_client import BertService
from paddle_serving_client.utils import benchmark_args
from batching import pad_batch_data
import tokenization
import requests
import json
from bert_reader import BertReader
args = benchmark_args()
def predict(thr_id, resource, batch_size):
bc = BertService(
model_name="bert_chinese_L-12_H-768_A-12",
max_seq_len=20,
do_lower_case=True)
bc.load_client(resource["conf_file"], resource["server_endpoint"])
thread_num = resource["thread_num"]
file_list = resource["filelist"]
line_id = 0
result = []
label_list = []
dataset = []
for fn in file_list:
fin = open(fn)
for line in fin:
if line_id % thread_num == thr_id - 1:
dataset.append(line.strip())
line_id += 1
fin.close()
batch_size = 24
start = time.time()
def single_func(idx, resource):
fin = open("data-c.txt")
if args.request == "rpc":
reader = BertReader(vocab_file="vocab.txt", max_seq_len=128)
fetch = ["pooled_output"]
batch = []
for inst in dataset:
if len(batch) < batch_size:
batch.append([inst])
else:
fetch_map_batch = bc.run_batch_general(batch, fetch)
batch = []
result.append(fetch_map_batch)
client = Client()
client.load_client_config(args.model)
client.connect([resource["endpoint"][idx % 4]])
start = time.time()
idx = 0
batch_data = []
for line in fin:
feed_dict = reader.process(line)
batch_data.append(feed_dict)
idx += 1
if idx % batch_size == 0:
result = client.batch_predict(
feed_batch=batch_data, fetch=fetch)
batch_data = []
end = time.time()
elif args.request == "http":
header = {"Content-Type": "application/json"}
for line in fin:
dict_data = {"words": line, "fetch": ["pooled_output"]}
r = requests.post(
'http://{}/bert/prediction'.format(resource["endpoint"][0]),
data=json.dumps(dict_data),
headers=header)
end = time.time()
return [result, label_list, [end - start]]
return [[end - start]]
if __name__ == '__main__':
conf_file = sys.argv[1]
data_file = sys.argv[2]
thread_num = sys.argv[3]
batch_size = sys.ragv[4]
resource = {}
resource["conf_file"] = conf_file
resource["server_endpoint"] = ["127.0.0.1:9293"]
resource["filelist"] = [data_file]
resource["thread_num"] = int(thread_num)
thread_runner = MultiThreadRunner()
result = thread_runner.run(predict, int(sys.argv[3]), resource, batch_size)
print("total time {} s".format(sum(result[-1]) / len(result[-1])))
multi_thread_runner = MultiThreadRunner()
endpoint_list = []
card_num = 4
for i in range(args.thread):
endpoint_list.append("127.0.0.1:{}".format(9494 + i % card_num))
print(endpoint_list)
result = multi_thread_runner.run(single_func, args.thread,
{"endpoint": endpoint_list})
print(result)
# coding:utf-8
# pylint: disable=doc-string-missing
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
import numpy as np
......@@ -10,146 +24,17 @@ import time
from paddlehub.common.logger import logger
import socket
from paddle_serving_client import Client
from paddle_serving_client.utils import MultiThreadRunner
from paddle_serving_client.utils import benchmark_args
args = benchmark_args()
_ver = sys.version_info
is_py2 = (_ver[0] == 2)
is_py3 = (_ver[0] == 3)
if is_py2:
import httplib
if is_py3:
import http.client as httplib
class BertService():
def __init__(self,
max_seq_len=128,
model_name="bert_uncased_L-12_H-768_A-12",
show_ids=False,
do_lower_case=True,
process_id=0,
retry=3):
self.process_id = process_id
self.reader_flag = False
self.batch_size = 0
self.max_seq_len = max_seq_len
self.model_name = model_name
self.show_ids = show_ids
self.do_lower_case = do_lower_case
self.retry = retry
self.pid = os.getpid()
self.profile = True if ("FLAGS_profile_client" in os.environ and
os.environ["FLAGS_profile_client"]) else False
module = hub.Module(name=self.model_name)
inputs, outputs, program = module.context(
trainable=True, max_seq_len=self.max_seq_len)
input_ids = inputs["input_ids"]
position_ids = inputs["position_ids"]
segment_ids = inputs["segment_ids"]
input_mask = inputs["input_mask"]
self.reader = hub.reader.ClassifyReader(
vocab_path=module.get_vocab_path(),
dataset=None,
max_seq_len=self.max_seq_len,
do_lower_case=self.do_lower_case)
self.reader_flag = True
def load_client(self, config_file, server_addr):
self.client = Client()
self.client.load_client_config(config_file)
self.client.connect(server_addr)
def run_general(self, text, fetch):
self.batch_size = len(text)
data_generator = self.reader.data_generator(
batch_size=self.batch_size, phase='predict', data=text)
result = []
prepro_start = time.time()
for run_step, batch in enumerate(data_generator(), start=1):
token_list = batch[0][0].reshape(-1).tolist()
pos_list = batch[0][1].reshape(-1).tolist()
sent_list = batch[0][2].reshape(-1).tolist()
mask_list = batch[0][3].reshape(-1).tolist()
for si in range(self.batch_size):
feed = {
"input_ids": token_list,
"position_ids": pos_list,
"segment_ids": sent_list,
"input_mask": mask_list
}
prepro_end = time.time()
if self.profile:
print("PROFILE\tpid:{}\tbert_pre_0:{} bert_pre_1:{}".format(
self.pid,
int(round(prepro_start * 1000000)),
int(round(prepro_end * 1000000))))
fetch_map = self.client.predict(feed=feed, fetch=fetch)
return fetch_map
def run_batch_general(self, text, fetch):
self.batch_size = len(text)
data_generator = self.reader.data_generator(
batch_size=self.batch_size, phase='predict', data=text)
result = []
prepro_start = time.time()
for run_step, batch in enumerate(data_generator(), start=1):
token_list = batch[0][0].reshape(-1).tolist()
pos_list = batch[0][1].reshape(-1).tolist()
sent_list = batch[0][2].reshape(-1).tolist()
mask_list = batch[0][3].reshape(-1).tolist()
feed_batch = []
for si in range(self.batch_size):
feed = {
"input_ids": token_list[si * self.max_seq_len:(si + 1) *
self.max_seq_len],
"position_ids":
pos_list[si * self.max_seq_len:(si + 1) * self.max_seq_len],
"segment_ids": sent_list[si * self.max_seq_len:(si + 1) *
self.max_seq_len],
"input_mask":
mask_list[si * self.max_seq_len:(si + 1) * self.max_seq_len]
}
feed_batch.append(feed)
prepro_end = time.time()
if self.profile:
print("PROFILE\tpid:{}\tbert_pre_0:{} bert_pre_1:{}".format(
self.pid,
int(round(prepro_start * 1000000)),
int(round(prepro_end * 1000000))))
fetch_map_batch = self.client.batch_predict(
feed_batch=feed_batch, fetch=fetch)
return fetch_map_batch
def single_func(idx, resource):
bc = BertService(
model_name='bert_chinese_L-12_H-768_A-12',
max_seq_len=20,
show_ids=False,
do_lower_case=True)
config_file = './serving_client_conf/serving_client_conf.prototxt'
fetch = ["pooled_output"]
server_addr = [resource["endpoint"][idx]]
bc.load_client(config_file, server_addr)
batch_size = 1
start = time.time()
fin = open("data-c.txt")
for line in fin:
result = bc.run_general([[line.strip()]], fetch)
end = time.time()
return [[end - start]]
if __name__ == '__main__':
multi_thread_runner = MultiThreadRunner()
result = multi_thread_runner.run(single_func, args.thread, {
"endpoint": [
"127.0.0.1:9494", "127.0.0.1:9495", "127.0.0.1:9496",
"127.0.0.1:9497"
]
})
fin = open("data-c.txt")
reader = BertReader(vocab_file="vocab.txt", max_seq_len=128)
fetch = ["pooled_output"]
endpoint_list = ["127.0.0.1:9494"]
client = Client()
client.load_client_config(args.model)
client.connect(endpoint_list)
for line in fin:
feed_dict = reader.process(line)
result = client.predict(feed=feed_dict, fetch=fetch)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册