diff --git a/paddlehub/serving/app_single.py b/paddlehub/serving/app_single.py index 7eec330ad488790a6c81c465a12b1eb85a0b3e46..2285fb897b001be0976e02c87b5c561c82cec750 100644 --- a/paddlehub/serving/app_single.py +++ b/paddlehub/serving/app_single.py @@ -12,11 +12,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and from flask import Flask, request, render_template -from paddlehub.serving.model_service.text_model_service import TextModelService -from paddlehub.serving.model_service.image_model_service import ImageModelService +from paddlehub.serving.model_service.model_manage import default_module_manager from paddlehub.common import utils -# from model_service.text_model_service import TextModelService -# from model_service.image_model_service import ImageModelService import time import os import base64 @@ -269,7 +266,7 @@ def create_app(): file_name = req_id + "_" + item.filename item.save(file_name) file_name_list.append(file_name) - module = ImageModelService.get_module(module_name) + module = default_module_manager.get_module(module_name) predict_func_name = cv_module_method.get(module_name, "") if predict_func_name != "": predict_func = eval(predict_func_name) @@ -297,7 +294,7 @@ def create_app(): file_name = req_id + "_" + file.filename files[file_key].append(file_name) file.save(file_name) - module = TextModelService.get_module(module_name) + module = default_module_manager.get_module(module_name) results = predict_nlp( module=module, input_text=inputs, @@ -321,6 +318,7 @@ def config_with_file(configs): elif item["category"] == "NLP": nlp_module.append(item["module"]) batch_size_dict.update({item["module"]: item["batch_size"]}) + default_module_manager.load_module([item["module"]]) def run(is_use_gpu=False, configs=None, port=8866, timeout=60): diff --git a/paddlehub/serving/model_service/model_manage.py b/paddlehub/serving/model_service/model_manage.py new file mode 100644 index 0000000000000000000000000000000000000000..b18c43a4adbcc2dd98e6a273e223bbe68c6e1569 --- /dev/null +++ b/paddlehub/serving/model_service/model_manage.py @@ -0,0 +1,34 @@ +# coding: utf-8 +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License" +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import paddlehub as hub + + +class ModuleManager(object): + def __init__(self): + self.modules = {} + + def load_module(self, modules=[]): + for name in modules: + self.modules.update({name: hub.Module(name)}) + print("Loading %s successful." % name) + + def get_module(self, name): + if name in self.modules.keys(): + return self.modules[name] + else: + return hub.Module(name) + + +default_module_manager = ModuleManager()