Commit ccf35cf4 authored by MRXLT

add batch benchmark script

Parent 88fc9f14
......@@ -18,6 +18,7 @@ from paddle_serving_client.metric import auc
from paddle_serving_client.utils import MultiThreadRunner
import time
def predict(thr_id, resource):
client = Client()
client.load_client_config(resource["conf_file"])
......@@ -44,7 +45,7 @@ def predict(thr_id, resource):
fetch = ["acc", "cost", "prediction"]
infer_time_list = []
for inst in dataset:
fetch_map = client.predict(feed=inst, fetch=fetch, debug=True)
fetch_map = client.predict(feed=inst, fetch=fetch, profile=True)
prob.append(fetch_map["prediction"][1])
label_list.append(label[0])
infer_time_list.append(fetch_map["infer_time"])
......@@ -52,6 +53,7 @@ def predict(thr_id, resource):
client.release()
return [prob, label_list, [sum(infer_time_list)], [end - start]]
if __name__ == '__main__':
conf_file = sys.argv[1]
data_file = sys.argv[2]
......@@ -64,5 +66,7 @@ if __name__ == '__main__':
thread_runner = MultiThreadRunner()
result = thread_runner.run(predict, int(sys.argv[3]), resource)
print("{}\t{}".format(sys.argv[3], sum(result[-1]) / len(result[-1])))
print("{}\t{}".format(sys.argv[3], sum(result[2]) / 1000.0 / 1000.0 / len(result[2])))
print("thread num {}\ttotal time {}".format(sys.argv[
3], sum(result[-1]) / len(result[-1])))
print("thread num {}\ttotal time {}".format(sys.argv[
3], sum(result[2]) / 1000.0 / 1000.0 / len(result[2])))
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
from paddle_serving_client import Client
from paddle_serving_client.metric import auc
from paddle_serving_client.utils import MultiThreadRunner
import time
def predict(thr_id, resource):
    client = Client()
    client.load_client_config(resource["conf_file"])
    client.connect(resource["server_endpoint"])
    thread_num = resource["thread_num"]
    file_list = resource["filelist"]
    line_id = 0
    prob = []
    label_list = []
    dataset = []
    for fn in file_list:
        fin = open(fn)
        for line in fin:
            if line_id % thread_num == thr_id - 1:
                group = line.strip().split()
                words = [int(x) for x in group[1:int(group[0])]]
                label = [int(group[-1])]
                feed = {"words": words, "label": label}
                dataset.append(feed)
            line_id += 1
        fin.close()
    start = time.time()
    fetch = ["acc", "cost", "prediction"]
    infer_time_list = []
    counter = 0
    feed_list = []
    for inst in dataset:
        counter += 1
        feed_list.append(inst)
        if counter == resource["batch_size"]:
            fetch_map_batch, infer_time = client.batch_predict(
                feed_batch=feed_list, fetch=fetch, profile=True)
            #prob.append(fetch_map["prediction"][1])
            #label_list.append(label[0])
            infer_time_list.append(infer_time)
            counter = 0
            feed_list = []
    if counter != 0:
        fetch_map_batch, infer_time = client.batch_predict(
            feed_batch=feed_list, fetch=fetch, profile=True)
        infer_time_list.append(infer_time)
    end = time.time()
    client.release()
    return [prob, label_list, [sum(infer_time_list)], [end - start]]


if __name__ == '__main__':
    conf_file = sys.argv[1]
    data_file = sys.argv[2]
    resource = {}
    resource["conf_file"] = conf_file
    resource["server_endpoint"] = ["127.0.0.1:9292"]
    resource["filelist"] = [data_file]
    resource["thread_num"] = int(sys.argv[3])
    resource["batch_size"] = int(sys.argv[4])
    thread_runner = MultiThreadRunner()
    result = thread_runner.run(predict, int(sys.argv[3]), resource)
    print("thread num {}\tbatch size {}\ttotal time {}".format(sys.argv[
        3], resource["batch_size"], sum(result[-1]) / len(result[-1])))
    print("thread num {}\tbatch size {}\tinfer time {}".format(
        sys.argv[3], resource["batch_size"],
        sum(result[2]) / 1000.0 / 1000.0 / len(result[2])))
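The new benchmark_batch.py above takes four positional arguments -- client config file, data file, thread count, and batch size -- and each thread groups its share of the dataset into batches before calling batch_predict. Below is a minimal single-batch sketch of the same client API for reference; the config path, endpoint, and feed values are placeholders for illustration, not part of the commit.

from paddle_serving_client import Client

client = Client()
client.load_client_config("serving_client_conf.prototxt")  # placeholder config path
client.connect(["127.0.0.1:9292"])

feed_batch = [{"words": [8, 233, 52, 601], "label": [0]},
              {"words": [7, 150, 26], "label": [1]}]
fetch = ["acc", "cost", "prediction"]

# With profile=True, batch_predict returns the per-sample fetch maps plus the
# server-reported inference time; the benchmark divides it by 1e6, which
# suggests the value is in microseconds.
fetch_map_batch, infer_time = client.batch_predict(
    feed_batch=feed_batch, fetch=fetch, profile=True)
print(fetch_map_batch, infer_time)
client.release()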
......@@ -22,7 +22,7 @@ import time
def batch_predict(batch_size=4):
client = Client()
client.load_client_config(conf_file)
client.connect(["127.0.0.1:8010"])
client.connect(["127.0.0.1:9292"])
start = time.time()
fetch = ["acc", "cost", "prediction"]
feed_batch = []
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
from paddle_serving_server import OpMaker
......@@ -7,14 +21,18 @@ from paddle_serving_server import Server
op_maker = OpMaker()
read_op = op_maker.create('general_reader')
general_infer_op = op_maker.create('general_infer')
general_response_op = op_maker.create('general_response')
op_seq_maker = OpSeqMaker()
op_seq_maker.add_op(read_op)
op_seq_maker.add_op(general_infer_op)
op_seq_maker.add_op(general_response_op)
server = Server()
server.set_vlog_level(3)
server.set_op_sequence(op_seq_maker.get_op_sequence())
server.set_num_threads(12)
server.set_num_threads(4)
server.load_model_config(sys.argv[1])
port = int(sys.argv[2])
server.prepare_server(workdir="work_dir1", port=port, device="cpu")
......
......@@ -104,7 +104,8 @@ class Client(object):
predictor_sdk = SDKConfig()
predictor_sdk.set_server_endpoints(endpoints)
sdk_desc = predictor_sdk.gen_desc()
self.client_handle_.create_predictor_by_desc(sdk_desc.SerializeToString())
self.client_handle_.create_predictor_by_desc(sdk_desc.SerializeToString(
))
def get_feed_names(self):
return self.feed_names_
......@@ -112,7 +113,7 @@ class Client(object):
def get_fetch_names(self):
return self.fetch_names_
def predict(self, feed={}, fetch=[], debug=False):
def predict(self, feed={}, fetch=[], profile=False):
int_slot = []
float_slot = []
int_feed_names = []
......@@ -142,12 +143,12 @@ class Client(object):
for i, name in enumerate(fetch_names):
result_map[name] = result[i]
if debug:
if profile:
result_map["infer_time"] = result[-1][0]
return result_map
def batch_predict(self, feed_batch=[], fetch=[], debug=False):
def batch_predict(self, feed_batch=[], fetch=[], profile=False):
int_slot_batch = []
float_slot_batch = []
int_feed_names = []
......@@ -189,7 +190,7 @@ class Client(object):
infer_time = result_batch[-1][0][0]
if debug:
if profile:
return result_map_batch, infer_time
else:
return result_map_batch
......
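This hunk renames the debug keyword of predict and batch_predict to profile. A hedged sketch of the single-request form, with placeholder config path, endpoint, and feed values: when profile=True, the returned dict carries an extra "infer_time" entry, which is what benchmark.py reads.

from paddle_serving_client import Client

client = Client()
client.load_client_config("serving_client_conf.prototxt")  # placeholder path
client.connect(["127.0.0.1:9292"])

fetch_map = client.predict(
    feed={"words": [8, 233, 52, 601], "label": [0]},
    fetch=["acc", "cost", "prediction"],
    profile=True)
# "infer_time" is only added to the result when profile=True
print(fetch_map["prediction"], fetch_map["infer_time"])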
......@@ -24,13 +24,13 @@ from version import serving_server_version
class OpMaker(object):
def __init__(self):
self.op_dict = {
"general_infer":"GeneralInferOp",
"general_reader":"GeneralReaderOp",
"general_response":"GeneralResponseOp",
"general_text_reader":"GeneralTextReaderOp",
"general_text_response":"GeneralTextResponseOp",
"general_single_kv":"GeneralSingleKVOp",
"general_dist_kv":"GeneralDistKVOp"
"general_infer": "GeneralInferOp",
"general_reader": "GeneralReaderOp",
"general_response": "GeneralResponseOp",
"general_text_reader": "GeneralTextReaderOp",
"general_text_response": "GeneralTextResponseOp",
"general_single_kv": "GeneralSingleKVOp",
"general_dist_kv": "GeneralDistKVOp"
}
# currently, inputs and outputs are not used
......@@ -96,6 +96,9 @@ class Server(object):
def set_port(self, port):
self.port = port
def set_vlog_level(self, vlog_level):
self.vlog_level = vlog_level
def set_reload_interval(self, interval):
self.reload_interval_s = interval
......@@ -252,7 +255,8 @@ class Server(object):
"-resource_file {} " \
"-workflow_path {} " \
"-workflow_file {} " \
"-bthread_concurrency {} ".format(
"-bthread_concurrency {} " \
"-v {} ".format(
self.bin_path,
self.workdir,
self.infer_service_fn,
......@@ -264,5 +268,6 @@ class Server(object):
self.resource_fn,
self.workdir,
self.workflow_fn,
self.num_threads,)
self.num_threads,
self.vlog_level)
os.system(command)
......@@ -24,13 +24,13 @@ from version import serving_server_version
class OpMaker(object):
def __init__(self):
self.op_dict = {
"general_infer":"GeneralInferOp",
"general_reader":"GeneralReaderOp",
"general_response":"GeneralResponseOp",
"general_text_reader":"GeneralTextReaderOp",
"general_text_response":"GeneralTextResponseOp",
"general_single_kv":"GeneralSingleKVOp",
"general_dist_kv":"GeneralDistKVOp"
"general_infer": "GeneralInferOp",
"general_reader": "GeneralReaderOp",
"general_response": "GeneralResponseOp",
"general_text_reader": "GeneralTextReaderOp",
"general_text_response": "GeneralTextResponseOp",
"general_single_kv": "GeneralSingleKVOp",
"general_dist_kv": "GeneralDistKVOp"
}
# currently, inputs and outputs are not used
......@@ -96,6 +96,9 @@ class Server(object):
def set_port(self, port):
self.port = port
def set_vlog_level(self, vlog_level):
self.vlog_level = vlog_level
def set_reload_interval(self, interval):
self.reload_interval_s = interval
......@@ -105,6 +108,9 @@ class Server(object):
def set_memory_optimize(self, flag=False):
self.memory_optimization = flag
def set_gpuid(self, gpuid=0):
self.gpuid = gpuid
def _prepare_engine(self, model_config_path, device):
if self.model_toolkit_conf == None:
self.model_toolkit_conf = server_sdk.ModelToolkitConf()
......@@ -232,7 +238,10 @@ class Server(object):
"-resource_path {} " \
"-resource_file {} " \
"-workflow_path {} " \
"-workflow_file {} ".format(
"-workflow_file {} " \
"-bthread_concurrency {} " \
"-gpuid {} " \
"-v {} ".format(
self.bin_path,
self.workdir,
self.infer_service_fn,
......@@ -243,5 +252,8 @@ class Server(object):
self.workdir,
self.resource_fn,
self.workdir,
self.workflow_fn)
self.workflow_fn,
self.num_threads,
self.gpuid,
self.vlog_level)
os.system(command)
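The GPU variant of the server module gains a set_gpuid setter and now forwards -gpuid and -v to the serving binary. A minimal launch sketch, assuming the module is importable as paddle_serving_server_gpu and exposes the same OpMaker/OpSeqMaker/Server interface as the CPU package shown earlier in this commit; the model path, port, and run_server call are assumptions, not taken from this diff.

import sys
from paddle_serving_server_gpu import OpMaker, OpSeqMaker, Server  # assumed module name

op_maker = OpMaker()
read_op = op_maker.create('general_reader')
general_infer_op = op_maker.create('general_infer')
general_response_op = op_maker.create('general_response')

op_seq_maker = OpSeqMaker()
op_seq_maker.add_op(read_op)
op_seq_maker.add_op(general_infer_op)
op_seq_maker.add_op(general_response_op)

server = Server()
server.set_op_sequence(op_seq_maker.get_op_sequence())
server.set_num_threads(4)
server.set_gpuid(0)        # new setter in this commit; emitted as "-gpuid 0"
server.set_vlog_level(3)   # emitted as "-v 3"
server.load_model_config(sys.argv[1])
server.prepare_server(workdir="work_dir1", port=9292, device="gpu")
server.run_server()        # assumed launcher method, not shown in this diff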