Commit 67b5dbca authored by Dong Daxiang, committed by GitHub

Merge pull request #279 from guru4elephant/refine_bert_example

Refine bert example
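In short, the merge touches four files: the single-request BERT benchmark moves to max_seq_len=128 and spreads its threads round-robin across four GPU-backed endpoints; the batch benchmark is rewritten around BertReader with batched RPC calls; the paddlehub-based BertService client class is deleted in favor of a minimal BertReader loop; and the multi-card GPU launcher no longer crashes when CUDA_VISIBLE_DEVICES is unset and now passes the previously missing args argument to the single-card fallback.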
Changed file 1 of 4: the single-request BERT benchmark client.

@@ -34,8 +34,7 @@ args = benchmark_args()
 def single_func(idx, resource):
     fin = open("data-c.txt")
     if args.request == "rpc":
-        reader = BertReader(vocab_file="vocab.txt", max_seq_len=20)
-        config_file = './serving_client_conf/serving_client_conf.prototxt'
+        reader = BertReader(vocab_file="vocab.txt", max_seq_len=128)
         fetch = ["pooled_output"]
         client = Client()
         client.load_client_config(args.model)
@@ -50,7 +49,6 @@ def single_func(idx, resource):
         start = time.time()
         header = {"Content-Type": "application/json"}
         for line in fin:
-            #dict_data = {"words": "this is for output ", "fetch": ["pooled_output"]}
            dict_data = {"words": line, "fetch": ["pooled_output"]}
             r = requests.post(
                 'http://{}/bert/prediction'.format(resource["endpoint"][0]),
@@ -62,10 +60,11 @@ def single_func(idx, resource):

 if __name__ == '__main__':
     multi_thread_runner = MultiThreadRunner()
-    endpoint_list = [
-        "127.0.0.1:9494", "127.0.0.1:9495", "127.0.0.1:9496", "127.0.0.1:9497"
-    ]
-    #endpoint_list = endpoint_list + endpoint_list + endpoint_list
-    #result = multi_thread_runner.run(single_func, args.thread, {"endpoint":endpoint_list})
-    result = single_func(0, {"endpoint": endpoint_list})
+    endpoint_list = []
+    card_num = 4
+    for i in range(args.thread):
+        endpoint_list.append("127.0.0.1:{}".format(9494 + i % card_num))
+    print(endpoint_list)
+    result = multi_thread_runner.run(single_func, args.thread,
+                                     {"endpoint": endpoint_list})
     print(result)
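Everything the benchmark needs from MultiThreadRunner is visible in this hunk: run() starts args.thread workers, hands each one its index plus the shared resource dict, and merges the [[end - start]] lists the workers return. A minimal sketch of that contract, assuming a plain threading implementation; ThreadRunnerSketch is a hypothetical stand-in, not the actual paddle_serving_client.utils code:

    import threading

    class ThreadRunnerSketch(object):
        # Hypothetical stand-in for paddle_serving_client.utils.MultiThreadRunner,
        # illustrating only the call contract the benchmark relies on.
        def run(self, func, thread_num, resource):
            results = [None] * thread_num

            def worker(idx):
                # each worker receives its thread index, which single_func
                # uses to pick resource["endpoint"][idx % card_num]
                results[idx] = func(idx, resource)

            threads = [
                threading.Thread(target=worker, args=(i, ))
                for i in range(thread_num)
            ]
            for t in threads:
                t.start()
            for t in threads:
                t.join()
            # merge the per-thread [[elapsed]] return values field-wise, so the
            # caller gets one aggregated list per returned field
            return [
                sum((r[i] for r in results), []) for i in range(len(results[0]))
            ]

With card_num = 4 and, say, 8 threads, the endpoint list above becomes ports 9494 through 9497 repeated twice, so two client threads share each GPU-backed server.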
Changed file 2 of 4: the batch benchmark client, rewritten from the BertService-based predict() to a BertReader pipeline with batched RPC.

@@ -1,3 +1,5 @@
+# -*- coding: utf-8 -*-
+#
 # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -11,61 +13,66 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+# pylint: disable=doc-string-missing
+
+from __future__ import unicode_literals, absolute_import
+import os
 import sys
+import time
 from paddle_serving_client import Client
-from paddle_serving_client.metric import auc
 from paddle_serving_client.utils import MultiThreadRunner
-import time
-from bert_client import BertService
+from paddle_serving_client.utils import benchmark_args
+from batching import pad_batch_data
+import tokenization
+import requests
+import json
+from bert_reader import BertReader
+
+args = benchmark_args()
+
+batch_size = 24

-def predict(thr_id, resource, batch_size):
-    bc = BertService(
-        model_name="bert_chinese_L-12_H-768_A-12",
-        max_seq_len=20,
-        do_lower_case=True)
-    bc.load_client(resource["conf_file"], resource["server_endpoint"])
-    thread_num = resource["thread_num"]
-    file_list = resource["filelist"]
-    line_id = 0
-    result = []
-    label_list = []
-    dataset = []
-    for fn in file_list:
-        fin = open(fn)
-        for line in fin:
-            if line_id % thread_num == thr_id - 1:
-                dataset.append(line.strip())
-            line_id += 1
-        fin.close()
-
-    start = time.time()
-    fetch = ["pooled_output"]
-    batch = []
-    for inst in dataset:
-        if len(batch) < batch_size:
-            batch.append([inst])
-        else:
-            fetch_map_batch = bc.run_batch_general(batch, fetch)
-            batch = []
-            result.append(fetch_map_batch)
-    end = time.time()
-    return [result, label_list, [end - start]]
+
+def single_func(idx, resource):
+    fin = open("data-c.txt")
+    if args.request == "rpc":
+        reader = BertReader(vocab_file="vocab.txt", max_seq_len=128)
+        fetch = ["pooled_output"]
+        client = Client()
+        client.load_client_config(args.model)
+        client.connect([resource["endpoint"][idx % 4]])
+
+        start = time.time()
+        idx = 0
+        batch_data = []
+        for line in fin:
+            feed_dict = reader.process(line)
+            batch_data.append(feed_dict)
+            idx += 1
+            if idx % batch_size == 0:
+                result = client.batch_predict(
+                    feed_batch=batch_data, fetch=fetch)
+                batch_data = []
+        end = time.time()
+    elif args.request == "http":
+        header = {"Content-Type": "application/json"}
+        for line in fin:
+            dict_data = {"words": line, "fetch": ["pooled_output"]}
+            r = requests.post(
+                'http://{}/bert/prediction'.format(resource["endpoint"][0]),
+                data=json.dumps(dict_data),
+                headers=header)
+        end = time.time()
+    return [[end - start]]

 if __name__ == '__main__':
-    conf_file = sys.argv[1]
-    data_file = sys.argv[2]
-    thread_num = sys.argv[3]
-    batch_size = sys.ragv[4]
-    resource = {}
-    resource["conf_file"] = conf_file
-    resource["server_endpoint"] = ["127.0.0.1:9293"]
-    resource["filelist"] = [data_file]
-    resource["thread_num"] = int(thread_num)
-    thread_runner = MultiThreadRunner()
-    result = thread_runner.run(predict, int(sys.argv[3]), resource, batch_size)
-    print("total time {} s".format(sum(result[-1]) / len(result[-1])))
+    multi_thread_runner = MultiThreadRunner()
+    endpoint_list = []
+    card_num = 4
+    for i in range(args.thread):
+        endpoint_list.append("127.0.0.1:{}".format(9494 + i % card_num))
+    print(endpoint_list)
+    result = multi_thread_runner.run(single_func, args.thread,
+                                     {"endpoint": endpoint_list})
+    print(result)
Changed file 3 of 4: bert_client.py, the module the old batch benchmark imported BertService from. The paddlehub-based class is deleted in favor of a minimal BertReader client.

@@ -1,5 +1,18 @@
 # coding:utf-8
 # pylint: disable=doc-string-missing
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
 import os
 import sys
 import numpy as np
@@ -10,146 +24,17 @@ import time
 from paddlehub.common.logger import logger
 import socket
 from paddle_serving_client import Client
-from paddle_serving_client.utils import MultiThreadRunner
 from paddle_serving_client.utils import benchmark_args

 args = benchmark_args()
-_ver = sys.version_info
-is_py2 = (_ver[0] == 2)
-is_py3 = (_ver[0] == 3)
-
-if is_py2:
-    import httplib
-if is_py3:
-    import http.client as httplib
-
-
-class BertService():
-    def __init__(self,
-                 max_seq_len=128,
-                 model_name="bert_uncased_L-12_H-768_A-12",
-                 show_ids=False,
-                 do_lower_case=True,
-                 process_id=0,
-                 retry=3):
-        self.process_id = process_id
-        self.reader_flag = False
-        self.batch_size = 0
-        self.max_seq_len = max_seq_len
-        self.model_name = model_name
-        self.show_ids = show_ids
-        self.do_lower_case = do_lower_case
-        self.retry = retry
-        self.pid = os.getpid()
-        self.profile = True if ("FLAGS_profile_client" in os.environ and
-                                os.environ["FLAGS_profile_client"]) else False
-
-        module = hub.Module(name=self.model_name)
-        inputs, outputs, program = module.context(
-            trainable=True, max_seq_len=self.max_seq_len)
-        input_ids = inputs["input_ids"]
-        position_ids = inputs["position_ids"]
-        segment_ids = inputs["segment_ids"]
-        input_mask = inputs["input_mask"]
-        self.reader = hub.reader.ClassifyReader(
-            vocab_path=module.get_vocab_path(),
-            dataset=None,
-            max_seq_len=self.max_seq_len,
-            do_lower_case=self.do_lower_case)
-        self.reader_flag = True
-
-    def load_client(self, config_file, server_addr):
-        self.client = Client()
-        self.client.load_client_config(config_file)
-        self.client.connect(server_addr)
-
-    def run_general(self, text, fetch):
-        self.batch_size = len(text)
-        data_generator = self.reader.data_generator(
-            batch_size=self.batch_size, phase='predict', data=text)
-        result = []
-        prepro_start = time.time()
-        for run_step, batch in enumerate(data_generator(), start=1):
-            token_list = batch[0][0].reshape(-1).tolist()
-            pos_list = batch[0][1].reshape(-1).tolist()
-            sent_list = batch[0][2].reshape(-1).tolist()
-            mask_list = batch[0][3].reshape(-1).tolist()
-            for si in range(self.batch_size):
-                feed = {
-                    "input_ids": token_list,
-                    "position_ids": pos_list,
-                    "segment_ids": sent_list,
-                    "input_mask": mask_list
-                }
-            prepro_end = time.time()
-            if self.profile:
-                print("PROFILE\tpid:{}\tbert_pre_0:{} bert_pre_1:{}".format(
-                    self.pid,
-                    int(round(prepro_start * 1000000)),
-                    int(round(prepro_end * 1000000))))
-            fetch_map = self.client.predict(feed=feed, fetch=fetch)
-        return fetch_map
-
-    def run_batch_general(self, text, fetch):
-        self.batch_size = len(text)
-        data_generator = self.reader.data_generator(
-            batch_size=self.batch_size, phase='predict', data=text)
-        result = []
-        prepro_start = time.time()
-        for run_step, batch in enumerate(data_generator(), start=1):
-            token_list = batch[0][0].reshape(-1).tolist()
-            pos_list = batch[0][1].reshape(-1).tolist()
-            sent_list = batch[0][2].reshape(-1).tolist()
-            mask_list = batch[0][3].reshape(-1).tolist()
-            feed_batch = []
-            for si in range(self.batch_size):
-                feed = {
-                    "input_ids": token_list[si * self.max_seq_len:(si + 1) *
-                                            self.max_seq_len],
-                    "position_ids":
-                    pos_list[si * self.max_seq_len:(si + 1) * self.max_seq_len],
-                    "segment_ids": sent_list[si * self.max_seq_len:(si + 1) *
-                                             self.max_seq_len],
-                    "input_mask":
-                    mask_list[si * self.max_seq_len:(si + 1) * self.max_seq_len]
-                }
-                feed_batch.append(feed)
-            prepro_end = time.time()
-            if self.profile:
-                print("PROFILE\tpid:{}\tbert_pre_0:{} bert_pre_1:{}".format(
-                    self.pid,
-                    int(round(prepro_start * 1000000)),
-                    int(round(prepro_end * 1000000))))
-            fetch_map_batch = self.client.batch_predict(
-                feed_batch=feed_batch, fetch=fetch)
-        return fetch_map_batch
-
-
-def single_func(idx, resource):
-    bc = BertService(
-        model_name='bert_chinese_L-12_H-768_A-12',
-        max_seq_len=20,
-        show_ids=False,
-        do_lower_case=True)
-    config_file = './serving_client_conf/serving_client_conf.prototxt'
-    fetch = ["pooled_output"]
-    server_addr = [resource["endpoint"][idx]]
-    bc.load_client(config_file, server_addr)
-    batch_size = 1
-    start = time.time()
-    fin = open("data-c.txt")
-    for line in fin:
-        result = bc.run_general([[line.strip()]], fetch)
-    end = time.time()
-    return [[end - start]]
-
-
-if __name__ == '__main__':
-    multi_thread_runner = MultiThreadRunner()
-    result = multi_thread_runner.run(single_func, args.thread, {
-        "endpoint": [
-            "127.0.0.1:9494", "127.0.0.1:9495", "127.0.0.1:9496",
-            "127.0.0.1:9497"
-        ]
-    })
+fin = open("data-c.txt")
+reader = BertReader(vocab_file="vocab.txt", max_seq_len=128)
+fetch = ["pooled_output"]
+endpoint_list = ["127.0.0.1:9494"]
+client = Client()
+client.load_client_config(args.model)
+client.connect(endpoint_list)
+
+for line in fin:
+    feed_dict = reader.process(line)
+    result = client.predict(feed=feed_dict, fetch=fetch)
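Most of the deleted class is plumbing to undo ClassifyReader's batching: the reader returns each field flattened to batch_size * max_seq_len values, and run_batch_general recovers per-example feeds with [si * max_seq_len:(si + 1) * max_seq_len] slices (run_general skips the slicing and feeds the whole flattened list, which only works because its callers pass one example at a time). A toy illustration of the slicing, with invented token values and a shrunken max_seq_len:

    # Toy illustration of the per-example slicing in the deleted
    # run_batch_general; the real code used max_seq_len of 20 or 128.
    max_seq_len = 4
    batch_size = 2
    token_list = [101, 7592, 102, 0,    # example 0, padded to max_seq_len
                  101, 2088, 999, 102]  # example 1
    feed_batch = [{
        "input_ids": token_list[si * max_seq_len:(si + 1) * max_seq_len]
    } for si in range(batch_size)]
    assert feed_batch[0]["input_ids"] == [101, 7592, 102, 0]
    assert feed_batch[1]["input_ids"] == [101, 2088, 999, 102]

The replacement bert_client.py pushes all of this into BertReader.process, which hands back one ready-to-send feed dict per input line.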
Changed file 4 of 4: the GPU serving launcher (the serve script in paddle_serving_server_gpu).

@@ -17,8 +17,8 @@ Usage:
 Example:
     python -m paddle_serving_server.serve --model ./serving_server_model --port 9292
 """
-import os
 import argparse
+import os
 from multiprocessing import Pool, Process
 from paddle_serving_server_gpu import serve_args

@@ -64,12 +64,14 @@ def start_gpu_card_model(gpuid, args):  # pylint: disable=doc-string-missing
 def start_multi_card(args):  # pylint: disable=doc-string-missing
     gpus = ""
     if args.gpu_ids == "":
-        import os
-        gpus = os.environ["CUDA_VISIBLE_DEVICES"]
+        if "CUDA_VISIBLE_DEVICES" in os.environ:
+            gpus = os.environ["CUDA_VISIBLE_DEVICES"]
+        else:
+            gpus = []
     else:
         gpus = args.gpu_ids.split(",")
     if len(gpus) <= 0:
-        start_gpu_card_model(-1)
+        start_gpu_card_model(-1, args)
     else:
         gpu_processes = []
         for i, gpu_id in enumerate(gpus):
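The diff is truncated just as the per-card loop begins. Given the Process import at the top of the file and the gpu_processes list, the launcher presumably continues by forking one start_gpu_card_model worker per visible card, along the lines of this sketch (an assumption-based illustration, not the file's verbatim continuation):

    from multiprocessing import Process

    def launch_per_card(gpus, args):
        # hypothetical continuation: fork one serving process per GPU id,
        # reusing the start_gpu_card_model worker defined earlier in serve.py
        gpu_processes = []
        for i, gpu_id in enumerate(gpus):
            p = Process(target=start_gpu_card_model, args=(gpu_id, args))
            gpu_processes.append(p)
        for p in gpu_processes:
            p.start()
        for p in gpu_processes:
            p.join()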