Commit 46498409 authored by: M MRXLT

refine imdb demo

Parent de96b47a
@@ -37,26 +37,39 @@ def single_func(idx, resource):
client.load_client_config(args.model)
client.connect([args.endpoint])
for i in range(1000):
if args.batch_size == 1:
word_ids, label = imdb_dataset.get_words_and_label(line)
fetch_map = client.predict(
feed={"words": word_ids}, fetch=["prediction"])
if args.batch_size >= 1:
feed_batch = []
for bi in range(args.batch_size):
word_ids, label = imdb_dataset.get_words_and_label(dataset[
bi])
feed_batch.append({"words": word_ids})
result = client.predict(feed=feed_batch, fetch=["prediction"])
if result is None:
raise ("predict failed.")
else:
print("unsupport batch size {}".format(args.batch_size))
elif args.request == "http":
for fn in filelist:
fin = open(fn)
for line in fin:
word_ids, label = imdb_dataset.get_words_and_label(line)
r = requests.post(
"http://{}/imdb/prediction".format(args.endpoint),
data={"words": word_ids,
"fetch": ["prediction"]})
if args.batch_size >= 1:
feed_batch = []
for bi in range(args.batch_size):
feed_batch.append({"words": dataset[bi]})
r = requests.post(
"http://{}/imdb/prediction".format(args.endpoint),
json={"feed": feed_batch,
"fetch": ["prediction"]})
if r.status_code != 200:
print("HTTP status code is not 200")
raise ValueError("predict failed.")
else:
print("unsupport batch size {}".format(args.batch_size))
end = time.time()
return [[end - start]]
multi_thread_runner = MultiThreadRunner()
result = multi_thread_runner.run(single_func, args.thread, {})
print(result)
avg_cost = 0
for cost in result[0]:
avg_cost += cost
print("total cost {} s of each thread".format(avg_cost / args.thread))
rm profile_log
for thread_num in 1 2 4 8 16
do
$PYTHONROOT/bin/python benchmark.py --thread $thread_num --model imdbo_bow_client_conf/serving_client_conf.prototxt --request rpc > profile 2>&1
for batch_size in 1 2 4 8 16 32 64 128 256 512
do
$PYTHONROOT/bin/python benchmark.py --thread $thread_num --batch_size $batch_size --model imdb_bow_client_conf/serving_client_conf.prototxt --request rpc > profile 2>&1
echo "========================================"
echo "batch size : $batch_size" >> profile_log
$PYTHONROOT/bin/python ../util/show_profile.py profile $thread_num >> profile_log
tail -n 1 profile >> profile_log
done
done
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# pylint: disable=doc-string-missing
import sys
import time
import requests
from imdb_reader import IMDBDataset
from paddle_serving_client import Client
from paddle_serving_client.utils import MultiThreadRunner
from paddle_serving_client.utils import benchmark_args
args = benchmark_args()
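# single_func runs in each benchmark thread: it loads the IMDB vocabulary,
# reads test samples from ./test_data/part-0, builds batches of size
# args.batch_size, and times the requests sent over RPC or HTTP.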
def single_func(idx, resource):
imdb_dataset = IMDBDataset()
imdb_dataset.load_resource("./imdb.vocab")
dataset = []
with open("./test_data/part-0") as fin:
for line in fin:
dataset.append(line.strip())
start = time.time()
if args.request == "rpc":
client = Client()
client.load_client_config(args.model)
client.connect([args.endpoint])
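# each thread issues 1000 batched requests against the RPC endpoint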
for i in range(1000):
if args.batch_size >= 1:
feed_batch = []
for bi in range(args.batch_size):
word_ids, label = imdb_dataset.get_words_and_label(dataset[
bi])
feed_batch.append({"words": word_ids})
result = client.predict(feed=feed_batch, fetch=["prediction"])
if result is None:
raise ("predict failed.")
else:
print("unsupport batch size {}".format(args.batch_size))
elif args.request == "http":
if args.batch_size >= 1:
feed_batch = []
for bi in range(args.batch_size):
feed_batch.append({"words": dataset[bi]})
r = requests.post(
"http://{}/imdb/prediction".format(args.endpoint),
json={"feed": feed_batch,
"fetch": ["prediction"]})
if r.status_code != 200:
print("HTTP status code is not 200")
raise ValueError("predict failed.")
else:
print("unsupport batch size {}".format(args.batch_size))
end = time.time()
return [[end - start]]
multi_thread_runner = MultiThreadRunner()
result = multi_thread_runner.run(single_func, args.thread, {})
avg_cost = 0
for cost in result[0]:
avg_cost += cost
print("total cost {} s of each thread".format(avg_cost / args.thread))
rm profile_log
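# sweep thread count and batch size; per-run timing summaries are appended to profile_log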
for thread_num in 1 2 4 8 16
do
for batch_size in 1 2 4 8 16 32 64 128 256 512
do
$PYTHONROOT/bin/python benchmark_batch.py --thread $thread_num --batch_size $batch_size --model imdb_bow_client_conf/serving_client_conf.prototxt --request rpc > profile 2>&1
echo "========================================"
echo "batch size : $batch_size" >> profile_log
$PYTHONROOT/bin/python ../util/show_profile.py profile $thread_num >> profile_log
tail -n 1 profile >> profile_log
done
done
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# pylint: disable=doc-string-missing
from paddle_serving_client import Client
import sys
import subprocess
from multiprocessing import Pool
import time
def batch_predict(batch_size=4):
client = Client()
client.load_client_config(conf_file)
client.connect(["127.0.0.1:9292"])
fetch = ["acc", "cost", "prediction"]
feed_batch = []
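# each stdin line is expected as: <token count> <word ids ...> <label>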
for line in sys.stdin:
group = line.strip().split()
words = [int(x) for x in group[1:int(group[0])]]
label = [int(group[-1])]
feed = {"words": words, "label": label}
feed_batch.append(feed)
if len(feed_batch) == batch_size:
fetch_batch = client.batch_predict(
feed_batch=feed_batch, fetch=fetch)
for i in range(batch_size):
print("{} {}".format(fetch_batch[i]["prediction"][1],
feed_batch[i]["label"][0]))
feed_batch = []
if len(feed_batch) > 0:
fetch_batch = client.batch_predict(feed_batch=feed_batch, fetch=fetch)
for i in range(len(feed_batch)):
print("{} {}".format(fetch_batch[i]["prediction"][1], feed_batch[i][
"label"][0]))
if __name__ == '__main__':
conf_file = sys.argv[1]
batch_size = int(sys.argv[2])
batch_predict(batch_size)
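A hypothetical invocation of this batch client, assuming the script is saved as test_client_batch.py and reusing the client config and test data paths from the benchmark scripts above (the server is expected at 127.0.0.1:9292 as hardcoded in batch_predict):

cat test_data/part-0 | python test_client_batch.py imdb_bow_client_conf/serving_client_conf.prototxt 4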