提交 3fc08321 编写于 作者: M MRXLT

add bert demo

上级 29d36668
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
from paddle_serving_client import Client
from paddle_serving_client.metric import auc
from paddle_serving_client.utils import MultiThreadRunner
import time
from test_bert_client import BertService
def predict(thr_id, resource):
bc = BertService(
model_name="bert_chinese_L-12_H-768_A-12",
max_seq_len=20,
do_lower_case=True)
bc.load_client(resource["conf_file"], resource["server_endpoint"])
thread_num = resource["thread_num"]
file_list = resource["filelist"]
line_id = 0
result = []
label_list = []
dataset = []
for fn in file_list:
fin = open(fn)
for line in fin:
if line_id % thread_num == thr_id - 1:
dataset.append(line.strip())
line_id += 1
fin.close()
start = time.time()
fetch = ["pooled_output"]
for inst in dataset:
fetch_map = bc.run_general([[inst]], fetch)
result.append(fetch_map["pooled_output"])
end = time.time()
return [result, label_list, [end - start]]
if __name__ == '__main__':
conf_file = sys.argv[1]
data_file = sys.argv[2]
thread_num = sys.argv[3]
resource = {}
resource["conf_file"] = conf_file
resource["server_endpoint"] = ["127.0.0.1:9293"]
resource["filelist"] = [data_file]
resource["thread_num"] = int(thread_num)
thread_runner = MultiThreadRunner()
result = thread_runner.run(predict, int(sys.argv[3]), resource)
print("total time {} s".format(sum(result[-1]) / len(result[-1])))
rm profile_log
for thread_num in 1 4 8 12 16 20 24
do
$PYTHONROOT/bin/python benchmark.py serving_client_conf/serving_client_conf.prototxt data.txt $thread_num $batch_size > profile 2>&1
$PYTHONROOT/bin/python ../imdb/show_profile.py profile $thread_num >> profile_log
tail -n 1 profile >> profile_log
done
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
from paddle_serving_client import Client
from paddle_serving_client.metric import auc
from paddle_serving_client.utils import MultiThreadRunner
import time
from test_bert_client import BertService
def predict(thr_id, resource, batch_size):
bc = BertService(
model_name="bert_chinese_L-12_H-768_A-12",
max_seq_len=20,
do_lower_case=True)
bc.load_client(resource["conf_file"], resource["server_endpoint"])
thread_num = resource["thread_num"]
file_list = resource["filelist"]
line_id = 0
result = []
label_list = []
dataset = []
for fn in file_list:
fin = open(fn)
for line in fin:
if line_id % thread_num == thr_id - 1:
dataset.append(line.strip())
line_id += 1
fin.close()
start = time.time()
fetch = ["pooled_output"]
batch = []
for inst in dataset:
if len(batch) < batch_size:
batch.append([inst])
else:
fetch_map_batch = bc.run_batch_general(batch, fetch)
batch = []
result.append(fetch_map_batch)
end = time.time()
return [result, label_list, [end - start]]
if __name__ == '__main__':
conf_file = sys.argv[1]
data_file = sys.argv[2]
thread_num = sys.argv[3]
batch_size = sys.ragv[4]
resource = {}
resource["conf_file"] = conf_file
resource["server_endpoint"] = ["127.0.0.1:9293"]
resource["filelist"] = [data_file]
resource["thread_num"] = int(thread_num)
thread_runner = MultiThreadRunner()
result = thread_runner.run(predict, int(sys.argv[3]), resource, batch_size)
print("total time {} s".format(sum(result[-1]) / len(result[-1])))
rm profile_log
thread_num=1
for batch_size in 1 4 8 16 32 64 128 256
do
$PYTHONROOT/bin/python benchmark_batch.py serving_client_conf/serving_client_conf.prototxt data.txt $thread_num $batch_size > profile 2>&1
$PYTHONROOT/bin/python ../imdb/show_profile.py profile $thread_num >> profile_log
tail -n 1 profile >> profile_log
done
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddlehub as hub
import paddle.fluid as fluid
import paddle_serving_client.io as serving_io
model_name = "bert_chinese_L-12_H-768_A-12"
module = hub.Module(model_name)
inputs, outputs, program = module.context(trainable=True, max_seq_len=20)
place = fluid.core_avx.CPUPlace()
exe = fluid.Executor(place)
input_ids = inputs["input_ids"]
position_ids = inputs["position_ids"]
segment_ids = inputs["segment_ids"]
input_mask = inputs["input_mask"]
pooled_output = outputs["pooled_output"]
sequence_output = outputs["sequence_output"]
feed_var_names = [
input_ids.name, position_ids.name, segment_ids.name, input_mask.name
]
target_vars = [pooled_output, sequence_output]
serving_io.save_model("serving_server_model", "serving_client_conf", {
"input_ids": input_ids,
"position_ids": position_ids,
"segment_ids": segment_ids,
"input_mask": input_mask,
}, {"pooled_output": pooled_output,
"sequence_output": sequence_output}, program)
# coding:utf-8
import sys
import numpy as np
import paddlehub as hub
import ujson
import random
from paddlehub.common.logger import logger
import socket
from paddle_serving_client import Client
_ver = sys.version_info
is_py2 = (_ver[0] == 2)
is_py3 = (_ver[0] == 3)
if is_py2:
import httplib
if is_py3:
import http.client as httplib
class BertService():
def __init__(self,
profile=False,
max_seq_len=128,
model_name="bert_uncased_L-12_H-768_A-12",
show_ids=False,
do_lower_case=True,
process_id=0,
retry=3,
load_balance='round_robin'):
self.process_id = process_id
self.reader_flag = False
self.batch_size = 0
self.max_seq_len = max_seq_len
self.profile = profile
self.model_name = model_name
self.show_ids = show_ids
self.do_lower_case = do_lower_case
self.con_list = []
self.con_index = 0
self.load_balance = load_balance
self.server_list = []
self.serving_list = []
self.feed_var_names = ''
self.retry = retry
module = hub.Module(name=self.model_name)
inputs, outputs, program = module.context(
trainable=True, max_seq_len=self.max_seq_len)
input_ids = inputs["input_ids"]
position_ids = inputs["position_ids"]
segment_ids = inputs["segment_ids"]
input_mask = inputs["input_mask"]
self.feed_var_names = input_ids.name + ';' + position_ids.name + ';' + segment_ids.name + ';' + input_mask.name
self.reader = hub.reader.ClassifyReader(
vocab_path=module.get_vocab_path(),
dataset=None,
max_seq_len=self.max_seq_len,
do_lower_case=self.do_lower_case)
self.reader_flag = True
def load_client(self, config_file, server_addr):
self.client = Client()
self.client.load_client_config(config_file)
self.client.connect(server_addr)
def run_general(self, text, fetch):
self.batch_size = len(text)
data_generator = self.reader.data_generator(
batch_size=self.batch_size, phase='predict', data=text)
result = []
for run_step, batch in enumerate(data_generator(), start=1):
token_list = batch[0][0].reshape(-1).tolist()
pos_list = batch[0][1].reshape(-1).tolist()
sent_list = batch[0][2].reshape(-1).tolist()
mask_list = batch[0][3].reshape(-1).tolist()
for si in range(self.batch_size):
feed = {
"input_ids": token_list,
"position_ids": pos_list,
"segment_ids": sent_list,
"input_mask": mask_list
}
fetch_map = self.client.predict(feed=feed, fetch=fetch)
return fetch_map
def run_batch_general(self, text, fetch):
self.batch_size = len(text)
data_generator = self.reader.data_generator(
batch_size=self.batch_size, phase='predict', data=text)
result = []
for run_step, batch in enumerate(data_generator(), start=1):
token_list = batch[0][0].reshape(-1).tolist()
pos_list = batch[0][1].reshape(-1).tolist()
sent_list = batch[0][2].reshape(-1).tolist()
mask_list = batch[0][3].reshape(-1).tolist()
feed_batch = []
for si in range(self.batch_size):
feed = {
"input_ids": token_list[si * self.max_seq_len:(si + 1) *
self.max_seq_len],
"position_ids":
pos_list[si * self.max_seq_len:(si + 1) * self.max_seq_len],
"segment_ids": sent_list[si * self.max_seq_len:(si + 1) *
self.max_seq_len],
"input_mask":
mask_list[si * self.max_seq_len:(si + 1) * self.max_seq_len]
}
feed_batch.append(feed)
fetch_map_batch = self.client.batch_predict(
feed_batch=feed_batch, fetch=fetch)
return fetch_map_batch
def test():
bc = BertService(
model_name='bert_uncased_L-12_H-768_A-12',
max_seq_len=20,
show_ids=False,
do_lower_case=True)
server_addr = ["127.0.0.1:9293"]
config_file = './serving_client_conf/serving_client_conf.prototxt'
fetch = ["pooled_output"]
bc.load_client(config_file, server_addr)
batch_size = 4
batch = []
for line in sys.stdin:
if len(batch) < batch_size:
batch.append([line.strip()])
else:
result = bc.run_batch_general(batch, fetch)
batch = []
for r in result:
for e in r["pooled_output"]:
print(e)
if __name__ == '__main__':
test()
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
from paddle_serving_server_gpu import OpMaker
from paddle_serving_server_gpu import OpSeqMaker
from paddle_serving_server_gpu import Server
op_maker = OpMaker()
read_op = op_maker.create('general_reader')
general_infer_op = op_maker.create('general_infer')
general_response_op = op_maker.create('general_response')
op_seq_maker = OpSeqMaker()
op_seq_maker.add_op(read_op)
op_seq_maker.add_op(general_infer_op)
op_seq_maker.add_op(general_response_op)
server = Server()
server.set_op_sequence(op_seq_maker.get_op_sequence())
server.set_num_threads(8)
server.set_memory_optimize(True)
server.set_gpuid(1)
server.load_model_config(sys.argv[1])
port = int(sys.argv[2])
server.prepare_server(workdir="work_dir1", port=port, device="gpu")
server.run_server()
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
from paddle_serving_server import OpMaker
from paddle_serving_server import OpSeqMaker
from paddle_serving_server import Server
op_maker = OpMaker()
read_op = op_maker.create('general_reader')
general_infer_op = op_maker.create('general_infer')
general_response_op = op_maker.create('general_response')
op_seq_maker = OpSeqMaker()
op_seq_maker.add_op(read_op)
op_seq_maker.add_op(general_infer_op)
op_seq_maker.add_op(general_response_op)
server = Server()
server.set_op_sequence(op_seq_maker.get_op_sequence())
server.set_num_threads(4)
server.set_local_bin(
"~/github/Serving/build_server/core/general-server/serving")
server.load_model_config(sys.argv[1])
port = int(sys.argv[2])
server.prepare_server(workdir="work_dir1", port=port, device="cpu")
server.run_server()
......@@ -43,15 +43,13 @@ def predict(thr_id, resource):
start = time.time()
fetch = ["acc", "cost", "prediction"]
infer_time_list = []
for inst in dataset:
fetch_map = client.predict(feed=inst, fetch=fetch, profile=True)
fetch_map = client.predict(feed=inst, fetch=fetch)
prob.append(fetch_map["prediction"][1])
label_list.append(label[0])
infer_time_list.append(fetch_map["infer_time"])
end = time.time()
client.release()
return [prob, label_list, [sum(infer_time_list)], [end - start]]
return [prob, label_list, [end - start]]
if __name__ == '__main__':
......@@ -59,14 +57,11 @@ if __name__ == '__main__':
data_file = sys.argv[2]
resource = {}
resource["conf_file"] = conf_file
resource["server_endpoint"] = ["127.0.0.1:9292"]
resource["server_endpoint"] = ["127.0.0.1:9293"]
resource["filelist"] = [data_file]
resource["thread_num"] = int(sys.argv[3])
thread_runner = MultiThreadRunner()
result = thread_runner.run(predict, int(sys.argv[3]), resource)
print("thread num {}\ttotal time {}".format(sys.argv[
3], sum(result[-1]) / len(result[-1])))
print("thread num {}\ttotal time {}".format(sys.argv[
3], sum(result[2]) / 1000.0 / 1000.0 / len(result[2])))
print("total time {} s".format(sum(result[-1]) / len(result[-1])))
#coding=utf-8
import sys
import collections
profile_file = sys.argv[1]
thread_num = sys.argv[2]
time_dict = collections.OrderedDict()
def prase(line):
profile_list = line.split(" ")
num = len(profile_list)
for idx in range(num / 2):
profile_0_list = profile_list[idx * 2].split(":")
profile_1_list = profile_list[idx * 2 + 1].split(":")
if len(profile_0_list[0].split("_")) == 2:
name = profile_0_list[0].split("_")[0]
else:
name = profile_0_list[0].split("_")[0] + "_" + profile_0_list[
0].split("_")[1]
cost = long(profile_1_list[1]) - long(profile_0_list[1])
if name not in time_dict:
time_dict[name] = cost
else:
time_dict[name] += cost
with open(profile_file) as f:
for line in f.readlines():
line = line.strip().split("\t")
if line[0] == "PROFILE":
prase(line[1])
print("thread num {}".format(thread_num))
for name in time_dict:
print("{} cost {} s per thread ".format(name, time_dict[name] / (
1000000.0 * float(thread_num))))
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册