diff --git a/paddlehub/serving/app_single.py b/paddlehub/serving/app_single.py index 21195348baa0318619c4fe704d2631d353e799bf..eaa407921c0f7a940da2f93b5702d959b1a5405e 100644 --- a/paddlehub/serving/app_single.py +++ b/paddlehub/serving/app_single.py @@ -23,39 +23,7 @@ import base64 import logging -def get_img_output(module, base64_head, results): - if module.type.startswith("CV"): - if "semantic-segmentation" in module.type: - output_file = results[0].get("processed", None) - if output_file is not None and os.path.exists(output_file): - with open(output_file, "rb") as fp: - output_img_base64 = base64.b64encode(fp.read()) - os.remove(output_file) - results = { - "desc": - "Here is result.", - "output_img": - base64_head + "," + str(output_img_base64).replace( - "b'", "").replace("'", "") - } - return {"result": results} - elif "object-detection" in module.type: - output_file = os.path.join("./output", results[0]["path"]) - if output_file is not None and os.path.exists(output_file): - with open(output_file, "rb") as fp: - output_img_base64 = base64.b64encode(fp.read()) - os.remove(output_file) - results = { - "desc": - str(results[0]["data"]), - "output_img": - base64_head + "," + str(output_img_base64).replace( - "b'", "").replace("'", "") - } - return {"result": results} - - -def predict_sentiment_analysis(module, input_text): +def predict_sentiment_analysis(module, input_text, extra=None): global use_gpu method_name = module.desc.attr.map.data['default_signature'].s predict_method = getattr(module, method_name) @@ -68,7 +36,7 @@ def predict_sentiment_analysis(module, input_text): return results -def predict_pretrained_model(module, input_text): +def predict_pretrained_model(module, input_text, extra=None): global use_gpu method_name = module.desc.attr.map.data['default_signature'].s predict_method = getattr(module, method_name) @@ -80,16 +48,18 @@ def predict_pretrained_model(module, input_text): return results -def predict_lexical_analysis(module, input_text, extra=None): +def predict_lexical_analysis(module, input_text, extra=[]): global use_gpu method_name = module.desc.attr.map.data['default_signature'].s predict_method = getattr(module, method_name) + data = {"text": input_text} try: - if extra is None: - data = {"text": input_text} + if extra is []: + results = predict_method(data=data, use_gpu=use_gpu) else: - data = {"text": input_text} - results = predict_method(data=data, use_gpu=use_gpu) + user_dict = extra[0] + results = predict_method( + data=data, user_dict=user_dict, use_gpu=use_gpu) except Exception as err: return {"result": "Please check data format!"} return results @@ -231,10 +201,8 @@ def create_app(): @app_instance.route("/predict/image/", methods=["POST"]) def predict_image(module_name): - # 稍后保存的文件名用id+源文件名的形式以避免冲突 req_id = request.data.get("id") global use_gpu - # 这里是一个base64的列表 img_base64 = request.form.getlist("input_img") file_name_list = [] if img_base64 != "": @@ -256,11 +224,7 @@ def create_app(): file_name = req_id + "_" + item.filename item.save(file_name) file_name_list.append(file_name) - # 到这里就把所有原始文件和文件名列表都保存了 - # 文件名列表可用于预测 - # 获取模型 module = ImageModelService.get_module(module_name) - # 根据模型种类寻找具体预测方法,即根据名字定函数 module_type = module.type.split("/")[-1].replace("-", "_").lower() predict_func = eval("predict_" + module_type) results = predict_func(module, file_name_list) @@ -269,13 +233,19 @@ def create_app(): @app_instance.route("/predict/text/", methods=["POST"]) def predict_text(module_name): + req_id = request.data.get("id") global use_gpu - # 应该是一个列表 data = request.form.getlist("input_text") + file = request.files.getlist("user_dict") module = TextModelService.get_module(module_name) module_type = module.type.split("/")[-1].replace("-", "_").lower() predict_func = eval("predict_" + module_type) - results = predict_func(module, data) + file_list = [] + for item in file: + file_path = req_id + "_" + item.filename + file_list.append(file_path) + item.save(file_path) + results = predict_func(module, data, file_list) return {"results": results} return app_instance