提交 4cdf6dd5 编写于 作者: G guru4elephant

add benchmark scripts for imdb

上级 5fe3587f
......@@ -13,55 +13,45 @@
# limitations under the License.
import sys
import time
import requests
from imdb_reader import IMDBDataset
from paddle_serving_client import Client
from paddle_serving_client.metric import auc
from paddle_serving_client.utils import MultiThreadRunner
import time
from paddle_serving_client.utils import benchmark_args
args = benchmark_args()
def predict(thr_id, resource):
client = Client()
client.load_client_config(resource["conf_file"])
client.connect(resource["server_endpoint"])
thread_num = resource["thread_num"]
file_list = resource["filelist"]
line_id = 0
prob = []
label_list = []
dataset = []
for fn in file_list:
fin = open(fn)
for line in fin:
if line_id % thread_num == thr_id - 1:
group = line.strip().split()
words = [int(x) for x in group[1:int(group[0])]]
label = [int(group[-1])]
feed = {"words": words, "label": label}
dataset.append(feed)
line_id += 1
fin.close()
def single_func(idx, resource):
imdb_dataset = IMDBDataset()
imdb_dataset.load_resource(args.vocab)
filelist_fn = args.filelist
filelist = []
start = time.time()
fetch = ["acc", "cost", "prediction"]
for inst in dataset:
fetch_map = client.predict(feed=inst, fetch=fetch)
prob.append(fetch_map["prediction"][1])
label_list.append(label[0])
with open(filelist_fn) as fin:
for line in fin:
filelist.append(line.strip())
filelist = filelist[idx::args.thread]
if args.request == "rpc":
client = Client()
client.load_client_config(args.model)
client.connect([args.endpoint])
for fn in filelist:
fin = open(fn)
for line in fin:
word_ids, label = imdb_dataset.get_words_and_label(line)
fetch_map = client.predict(feed={"words": word_ids},
fetch=["prediction"])
elif args.request == "http":
for fn in filelist:
fin = open(fn)
for line in fin:
word_ids, label = imdb_dataset.get_words_and_label(line)
r = requests.post("http://{}/imdb/prediction".format(args.endpoint),
data={"words": word_ids})
end = time.time()
client.release()
return [prob, label_list, [end - start]]
if __name__ == '__main__':
conf_file = sys.argv[1]
data_file = sys.argv[2]
resource = {}
resource["conf_file"] = conf_file
resource["server_endpoint"] = ["127.0.0.1:9293"]
resource["filelist"] = [data_file]
resource["thread_num"] = int(sys.argv[3])
thread_runner = MultiThreadRunner()
result = thread_runner.run(predict, int(sys.argv[3]), resource)
return [[end - start]]
print("total time {} s".format(sum(result[-1]) / len(result[-1])))
multi_thread_runner = MultiThreadRunner()
result = multi_thread_runner.run(single_func, args.thread, {})
print(result)
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from paddle_serving_client import Client
from imdb_reader import IMDBDataset
import sys
client = Client()
client.load_client_config(sys.argv[1])
client.connect(["127.0.0.1:9393"])
# you can define any english sentence or dataset here
# This example reuses imdb reader in training, you
# can define your own data preprocessing easily.
imdb_dataset = IMDBDataset()
imdb_dataset.load_resource(sys.argv[2])
for line in sys.stdin:
group = line.strip().split()
words = [int(x) for x in group[1:int(group[0]) + 1]]
label = [int(group[-1])]
feed = {"words": words, "label": label}
word_ids, label = imdb_dataset.get_words_and_label(line)
feed = {"words": word_ids, "label": label}
fetch = ["acc", "cost", "prediction"]
fetch_map = client.predict(feed=feed, fetch=fetch)
print("{} {}".format(fetch_map["prediction"][1], label[0]))
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from paddle_serving_client import Client
import sys
import subprocess
from multiprocessing import Pool
import time
def predict(p_id, p_size, data_list):
client = Client()
client.load_client_config(conf_file)
client.connect(["127.0.0.1:8010"])
result = []
for line in data_list:
group = line.strip().split()
words = [int(x) for x in group[1:int(group[0])]]
label = [int(group[-1])]
feed = {"words": words, "label": label}
fetch = ["acc", "cost", "prediction"]
fetch_map = client.predict(feed=feed, fetch=fetch)
#print("{} {}".format(fetch_map["prediction"][1], label[0]))
result.append([fetch_map["prediction"][1], label[0]])
return result
def predict_multi_thread(p_num):
data_list = []
with open(data_file) as f:
for line in f.readlines():
data_list.append(line)
start = time.time()
p = Pool(p_num)
p_size = len(data_list) / p_num
result_list = []
for i in range(p_num):
result_list.append(
p.apply_async(predict,
[i, p_size, data_list[i * p_size:(i + 1) * p_size]]))
p.close()
p.join()
for i in range(p_num):
result = result_list[i].get()
for j in result:
print("{} {}".format(j[0], j[1]))
cost = time.time() - start
print("{} threads cost {}".format(p_num, cost))
if __name__ == '__main__':
conf_file = sys.argv[1]
data_file = sys.argv[2]
p_num = int(sys.argv[3])
predict_multi_thread(p_num)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册