未验证 提交 16309001 编写于 作者: M MRXLT 提交者: GitHub

Merge pull request #248 from guru4elephant/refine_imdb_benchmark

Refine imdb benchmark
......@@ -13,55 +13,45 @@
# limitations under the License.
import sys
import time
import requests
from imdb_reader import IMDBDataset
from paddle_serving_client import Client
from paddle_serving_client.metric import auc
from paddle_serving_client.utils import MultiThreadRunner
import time
from paddle_serving_client.utils import benchmark_args
args = benchmark_args()
def predict(thr_id, resource):
client = Client()
client.load_client_config(resource["conf_file"])
client.connect(resource["server_endpoint"])
thread_num = resource["thread_num"]
file_list = resource["filelist"]
line_id = 0
prob = []
label_list = []
dataset = []
for fn in file_list:
fin = open(fn)
for line in fin:
if line_id % thread_num == thr_id - 1:
group = line.strip().split()
words = [int(x) for x in group[1:int(group[0])]]
label = [int(group[-1])]
feed = {"words": words, "label": label}
dataset.append(feed)
line_id += 1
fin.close()
def single_func(idx, resource):
imdb_dataset = IMDBDataset()
imdb_dataset.load_resource(args.vocab)
filelist_fn = args.filelist
filelist = []
start = time.time()
fetch = ["acc", "cost", "prediction"]
for inst in dataset:
fetch_map = client.predict(feed=inst, fetch=fetch)
prob.append(fetch_map["prediction"][1])
label_list.append(label[0])
with open(filelist_fn) as fin:
for line in fin:
filelist.append(line.strip())
filelist = filelist[idx::args.thread]
if args.request == "rpc":
client = Client()
client.load_client_config(args.model)
client.connect([args.endpoint])
for fn in filelist:
fin = open(fn)
for line in fin:
word_ids, label = imdb_dataset.get_words_and_label(line)
fetch_map = client.predict(feed={"words": word_ids},
fetch=["prediction"])
elif args.request == "http":
for fn in filelist:
fin = open(fn)
for line in fin:
word_ids, label = imdb_dataset.get_words_and_label(line)
r = requests.post("http://{}/imdb/prediction".format(args.endpoint),
data={"words": word_ids})
end = time.time()
client.release()
return [prob, label_list, [end - start]]
if __name__ == '__main__':
conf_file = sys.argv[1]
data_file = sys.argv[2]
resource = {}
resource["conf_file"] = conf_file
resource["server_endpoint"] = ["127.0.0.1:9293"]
resource["filelist"] = [data_file]
resource["thread_num"] = int(sys.argv[3])
thread_runner = MultiThreadRunner()
result = thread_runner.run(predict, int(sys.argv[3]), resource)
return [[end - start]]
print("total time {} s".format(sum(result[-1]) / len(result[-1])))
multi_thread_runner = MultiThreadRunner()
result = multi_thread_runner.run(single_func, args.thread, {})
print(result)
wget --no-check-certificate https://fleet.bj.bcebos.com/text_classification_data.tar.gz
wget --no-check-certificate https://paddle-serving.bj.bcebos.com/imdb-demo/imdb_model.tar.gz
tar -zxvf text_classification_data.tar.gz
#wget --no-check-certificate https://paddle-serving.bj.bcebos.com/imdb-demo%2Fimdb.tar.gz
#tar -xzf imdb-demo%2Fimdb.tar.gz
tar -zxvf imdb_model.tar.gz
......@@ -30,6 +30,14 @@ class IMDBDataset(dg.MultiSlotDataGenerator):
self._pattern = re.compile(r'(;|,|\.|\?|!|\s|\(|\))')
self.return_value = ("words", [1, 2, 3, 4, 5, 6]), ("label", [0])
def get_words_only(self, line):
sent = line.lower().replace("<br />", " ").strip()
words = [x for x in self._pattern.split(sent) if x and x != " "]
feas = [
self._vocab[x] if x in self._vocab else self._unk_id for x in words
]
return feas
def get_words_and_label(self, line):
send = '|'.join(line.split('|')[:-1]).lower().replace("<br />",
" ").strip()
......
wget https://paddle-serving.bj.bcebos.com/imdb-demo%2Fimdb_service.tar.gz
wget https://paddle-serving.bj.bcebos.com/imdb-demo/imdb_service.tar.gz
tar -xzf imdb_service.tar.gz
wget --no-check-certificate https://fleet.bj.bcebos.com/text_classification_data.tar.gz
tar -zxvf text_classification_data.tar.gz
......
......@@ -49,8 +49,9 @@ if __name__ == "__main__":
dataset.set_batch_size(128)
dataset.set_filelist(filelist)
dataset.set_thread(10)
from nets import bow_net
avg_cost, acc, prediction = bow_net(data, label, dict_dim)
from nets import lstm_net
model_name = "imdb_lstm"
avg_cost, acc, prediction = lstm_net(data, label, dict_dim)
optimizer = fluid.optimizer.SGD(learning_rate=0.01)
optimizer.minimize(avg_cost)
......@@ -65,6 +66,7 @@ if __name__ == "__main__":
program=fluid.default_main_program(), dataset=dataset, debug=False)
logger.info("TRAIN --> pass: {}".format(i))
if i == 5:
serving_io.save_model("imdb_model", "imdb_client_conf",
serving_io.save_model("{}_model".format(model_name),
"{}_client_conf".format(model_name),
{"words": data}, {"prediction": prediction},
fluid.default_main_program())
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from paddle_serving_client import Client
from imdb_reader import IMDBDataset
import sys
client = Client()
client.load_client_config(sys.argv[1])
client.connect(["127.0.0.1:9393"])
# you can define any english sentence or dataset here
# This example reuses imdb reader in training, you
# can define your own data preprocessing easily.
imdb_dataset = IMDBDataset()
imdb_dataset.load_resource(sys.argv[2])
for line in sys.stdin:
group = line.strip().split()
words = [int(x) for x in group[1:int(group[0]) + 1]]
label = [int(group[-1])]
feed = {"words": words, "label": label}
word_ids, label = imdb_dataset.get_words_and_label(line)
feed = {"words": word_ids, "label": label}
fetch = ["acc", "cost", "prediction"]
fetch_map = client.predict(feed=feed, fetch=fetch)
print("{} {}".format(fetch_map["prediction"][1], label[0]))
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from paddle_serving_client import Client
import sys
import subprocess
from multiprocessing import Pool
import time
def predict(p_id, p_size, data_list):
client = Client()
client.load_client_config(conf_file)
client.connect(["127.0.0.1:8010"])
result = []
for line in data_list:
group = line.strip().split()
words = [int(x) for x in group[1:int(group[0])]]
label = [int(group[-1])]
feed = {"words": words, "label": label}
fetch = ["acc", "cost", "prediction"]
fetch_map = client.predict(feed=feed, fetch=fetch)
#print("{} {}".format(fetch_map["prediction"][1], label[0]))
result.append([fetch_map["prediction"][1], label[0]])
return result
def predict_multi_thread(p_num):
data_list = []
with open(data_file) as f:
for line in f.readlines():
data_list.append(line)
start = time.time()
p = Pool(p_num)
p_size = len(data_list) / p_num
result_list = []
for i in range(p_num):
result_list.append(
p.apply_async(predict,
[i, p_size, data_list[i * p_size:(i + 1) * p_size]]))
p.close()
p.join()
for i in range(p_num):
result = result_list[i].get()
for j in result:
print("{} {}".format(j[0], j[1]))
cost = time.time() - start
print("{} threads cost {}".format(p_num, cost))
if __name__ == '__main__':
conf_file = sys.argv[1]
data_file = sys.argv[2]
p_num = int(sys.argv[3])
predict_multi_thread(p_num)
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
from paddle_serving_server_gpu import OpMaker
from paddle_serving_server_gpu import OpSeqMaker
from paddle_serving_server_gpu import Server
op_maker = OpMaker()
read_op = op_maker.create('general_reader')
general_infer_op = op_maker.create('general_infer')
op_seq_maker = OpSeqMaker()
op_seq_maker.add_op(read_op)
op_seq_maker.add_op(general_infer_op)
server = Server()
server.set_op_sequence(op_seq_maker.get_op_sequence())
server.set_num_threads(12)
server.load_model_config(sys.argv[1])
port = int(sys.argv[2])
server.prepare_server(workdir="work_dir1", port=port, device="gpu")
server.run_server()
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
from paddle_serving_server import OpMaker
from paddle_serving_server import OpSeqMaker
from paddle_serving_server import Server
op_maker = OpMaker()
read_op = op_maker.create('general_reader')
general_infer_op = op_maker.create('general_infer')
general_response_op = op_maker.create('general_response')
op_seq_maker = OpSeqMaker()
op_seq_maker.add_op(read_op)
op_seq_maker.add_op(general_infer_op)
op_seq_maker.add_op(general_response_op)
server = Server()
server.set_op_sequence(op_seq_maker.get_op_sequence())
server.set_num_threads(4)
server.load_model_config(sys.argv[1])
port = int(sys.argv[2])
server.prepare_server(workdir="work_dir1", port=port, device="cpu")
server.run_server()
......@@ -11,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#!flask/bin/python
from paddle_serving_server.web_service import WebService
from imdb_reader import IMDBDataset
import sys
......@@ -27,7 +26,7 @@ class IMDBService(WebService):
if "words" not in feed:
exit(-1)
res_feed = {}
res_feed["words"] = self.dataset.get_words_and_label(feed["words"])[0]
res_feed["words"] = self.dataset.get_words_only(feed["words"])[0]
return res_feed, fetch
imdb_service = IMDBService(name="imdb")
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册