提交 3a6bf1f0 编写于 作者: 走神的阿圆's avatar 走神的阿圆 提交者: wuzewu

fix serving (#204)

上级 92df1e21
...@@ -20,7 +20,6 @@ import time ...@@ -20,7 +20,6 @@ import time
import os import os
import base64 import base64
import logging import logging
import cv2
import multiprocessing as mp import multiprocessing as mp
from multiprocessing.managers import BaseManager from multiprocessing.managers import BaseManager
import random import random
...@@ -93,7 +92,6 @@ def predict_cv(input_data, module_name, batch_size=1): ...@@ -93,7 +92,6 @@ def predict_cv(input_data, module_name, batch_size=1):
filename_list = [] filename_list = []
for index in range(len(input_data)): for index in range(len(input_data)):
filename_list.append(input_data[index][3]) filename_list.append(input_data[index][3])
cv2.imread(input_data[index][3])
input_images = {"image": filename_list} input_images = {"image": filename_list}
module = ImageModelService.get_module(module_name) module = ImageModelService.get_module(module_name)
method_name = module.desc.attr.map.data['default_signature'].s method_name = module.desc.attr.map.data['default_signature'].s
...@@ -130,31 +128,35 @@ def predict_cv(input_data, module_name, batch_size=1): ...@@ -130,31 +128,35 @@ def predict_cv(input_data, module_name, batch_size=1):
def worker(): def worker():
global batch_size_list, name_list, queue_name_list, cv_module global batch_size_list, name_list, queue_name_list, cv_module
latest_num = random.randrange(0, len(queue_name_list)) latest_num = random.randrange(0, len(queue_name_list))
try:
while True: while True:
time.sleep(0.01) time.sleep(0.01)
for index in range(len(queue_name_list)): for index in range(len(queue_name_list)):
while queues_dict[queue_name_list[latest_num]].empty() is not True: while queues_dict[
input_data = [] queue_name_list[latest_num]].empty() is not True:
lock.acquire() input_data = []
try: lock.acquire()
batch = queues_dict[ try:
queue_name_list[latest_num]].get_attribute("maxsize") batch = queues_dict[
for index2 in range(batch): queue_name_list[latest_num]].get_attribute(
if queues_dict[ "maxsize")
queue_name_list[latest_num]].empty() is True: for index2 in range(batch):
break if queues_dict[queue_name_list[latest_num]].empty(
input_data.append( ) is True:
queues_dict[queue_name_list[latest_num]].get()) break
finally: input_data.append(
lock.release() queues_dict[queue_name_list[latest_num]].get())
if len(input_data) != 0: finally:
choose_module_category(input_data, lock.release()
queue_name_list[latest_num], if len(input_data) != 0:
batch_size_list[latest_num]) choose_module_category(input_data,
else: queue_name_list[latest_num],
pass batch_size_list[latest_num])
latest_num = (latest_num + 1) % len(queue_name_list) else:
pass
latest_num = (latest_num + 1) % len(queue_name_list)
except KeyboardInterrupt:
print("Process %s is end." % (os.getpid()))
def init_pool(l): def init_pool(l):
...@@ -168,7 +170,7 @@ def create_app(): ...@@ -168,7 +170,7 @@ def create_app():
gunicorn_logger = logging.getLogger('gunicorn.error') gunicorn_logger = logging.getLogger('gunicorn.error')
app_instance.logger.handlers = gunicorn_logger.handlers app_instance.logger.handlers = gunicorn_logger.handlers
app_instance.logger.setLevel(gunicorn_logger.level) app_instance.logger.setLevel(gunicorn_logger.level)
global queues_dict global queues_dict, pool
lock = mp.Lock() lock = mp.Lock()
pool = mp.Pool( pool = mp.Pool(
processes=(mp.cpu_count() - 1), processes=(mp.cpu_count() - 1),
...@@ -310,6 +312,9 @@ def run(is_use_gpu=False, configs=None, port=8888): ...@@ -310,6 +312,9 @@ def run(is_use_gpu=False, configs=None, port=8888):
return return
my_app = create_app() my_app = create_app()
my_app.run(host="0.0.0.0", port=port, debug=False) my_app.run(host="0.0.0.0", port=port, debug=False)
pool.close()
pool.join()
print("PaddleHub-Serving has been stopped.")
if __name__ == "__main__": if __name__ == "__main__":
......
...@@ -52,6 +52,10 @@ setup( ...@@ -52,6 +52,10 @@ setup(
author_email='paddle-dev@baidu.com', author_email='paddle-dev@baidu.com',
install_requires=REQUIRED_PACKAGES, install_requires=REQUIRED_PACKAGES,
packages=find_packages(), packages=find_packages(),
data_files=[('paddlehub/serving/templates', [
'paddlehub/serving/templates/serving_config.json',
'paddlehub/serving/templates/main.html'
])],
# PyPI package information. # PyPI package information.
classifiers=[ classifiers=[
'Development Status :: 4 - Beta', 'Development Status :: 4 - Beta',
......
...@@ -140,7 +140,7 @@ if __name__ == '__main__': ...@@ -140,7 +140,7 @@ if __name__ == '__main__':
if is_path_valid(args.saved_params_dir) and os.path.exists(best_model_dir): if is_path_valid(args.saved_params_dir) and os.path.exists(best_model_dir):
shutil.copytree(best_model_dir, args.saved_params_dir) shutil.copytree(best_model_dir, args.saved_params_dir)
shutil.rmtree(config.checkpoint_dir) shutil.rmtree(config.checkpoint_dir)
# acc on dev will be used by auto finetune # acc on dev will be used by auto finetune
print("AutoFinetuneEval"+"\t"+str(float(eval_avg_score["acc"]))) print("AutoFinetuneEval"+"\t"+str(float(eval_avg_score["acc"])))
``` ```
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册