提交 a4a8245b 编写于 作者: S shenyuhan

add single_app()

上级 60c527d8
...@@ -23,39 +23,7 @@ import base64 ...@@ -23,39 +23,7 @@ import base64
import logging import logging
def get_img_output(module, base64_head, results): def predict_sentiment_analysis(module, input_text, extra=None):
if module.type.startswith("CV"):
if "semantic-segmentation" in module.type:
output_file = results[0].get("processed", None)
if output_file is not None and os.path.exists(output_file):
with open(output_file, "rb") as fp:
output_img_base64 = base64.b64encode(fp.read())
os.remove(output_file)
results = {
"desc":
"Here is result.",
"output_img":
base64_head + "," + str(output_img_base64).replace(
"b'", "").replace("'", "")
}
return {"result": results}
elif "object-detection" in module.type:
output_file = os.path.join("./output", results[0]["path"])
if output_file is not None and os.path.exists(output_file):
with open(output_file, "rb") as fp:
output_img_base64 = base64.b64encode(fp.read())
os.remove(output_file)
results = {
"desc":
str(results[0]["data"]),
"output_img":
base64_head + "," + str(output_img_base64).replace(
"b'", "").replace("'", "")
}
return {"result": results}
def predict_sentiment_analysis(module, input_text):
global use_gpu global use_gpu
method_name = module.desc.attr.map.data['default_signature'].s method_name = module.desc.attr.map.data['default_signature'].s
predict_method = getattr(module, method_name) predict_method = getattr(module, method_name)
...@@ -68,7 +36,7 @@ def predict_sentiment_analysis(module, input_text): ...@@ -68,7 +36,7 @@ def predict_sentiment_analysis(module, input_text):
return results return results
def predict_pretrained_model(module, input_text): def predict_pretrained_model(module, input_text, extra=None):
global use_gpu global use_gpu
method_name = module.desc.attr.map.data['default_signature'].s method_name = module.desc.attr.map.data['default_signature'].s
predict_method = getattr(module, method_name) predict_method = getattr(module, method_name)
...@@ -80,16 +48,18 @@ def predict_pretrained_model(module, input_text): ...@@ -80,16 +48,18 @@ def predict_pretrained_model(module, input_text):
return results return results
def predict_lexical_analysis(module, input_text, extra=None): def predict_lexical_analysis(module, input_text, extra=[]):
global use_gpu global use_gpu
method_name = module.desc.attr.map.data['default_signature'].s method_name = module.desc.attr.map.data['default_signature'].s
predict_method = getattr(module, method_name) predict_method = getattr(module, method_name)
data = {"text": input_text}
try: try:
if extra is None: if extra is []:
data = {"text": input_text} results = predict_method(data=data, use_gpu=use_gpu)
else: else:
data = {"text": input_text} user_dict = extra[0]
results = predict_method(data=data, use_gpu=use_gpu) results = predict_method(
data=data, user_dict=user_dict, use_gpu=use_gpu)
except Exception as err: except Exception as err:
return {"result": "Please check data format!"} return {"result": "Please check data format!"}
return results return results
...@@ -231,10 +201,8 @@ def create_app(): ...@@ -231,10 +201,8 @@ def create_app():
@app_instance.route("/predict/image/<module_name>", methods=["POST"]) @app_instance.route("/predict/image/<module_name>", methods=["POST"])
def predict_image(module_name): def predict_image(module_name):
# 稍后保存的文件名用id+源文件名的形式以避免冲突
req_id = request.data.get("id") req_id = request.data.get("id")
global use_gpu global use_gpu
# 这里是一个base64的列表
img_base64 = request.form.getlist("input_img") img_base64 = request.form.getlist("input_img")
file_name_list = [] file_name_list = []
if img_base64 != "": if img_base64 != "":
...@@ -256,11 +224,7 @@ def create_app(): ...@@ -256,11 +224,7 @@ def create_app():
file_name = req_id + "_" + item.filename file_name = req_id + "_" + item.filename
item.save(file_name) item.save(file_name)
file_name_list.append(file_name) file_name_list.append(file_name)
# 到这里就把所有原始文件和文件名列表都保存了
# 文件名列表可用于预测
# 获取模型
module = ImageModelService.get_module(module_name) module = ImageModelService.get_module(module_name)
# 根据模型种类寻找具体预测方法,即根据名字定函数
module_type = module.type.split("/")[-1].replace("-", "_").lower() module_type = module.type.split("/")[-1].replace("-", "_").lower()
predict_func = eval("predict_" + module_type) predict_func = eval("predict_" + module_type)
results = predict_func(module, file_name_list) results = predict_func(module, file_name_list)
...@@ -269,13 +233,19 @@ def create_app(): ...@@ -269,13 +233,19 @@ def create_app():
@app_instance.route("/predict/text/<module_name>", methods=["POST"]) @app_instance.route("/predict/text/<module_name>", methods=["POST"])
def predict_text(module_name): def predict_text(module_name):
req_id = request.data.get("id")
global use_gpu global use_gpu
# 应该是一个列表
data = request.form.getlist("input_text") data = request.form.getlist("input_text")
file = request.files.getlist("user_dict")
module = TextModelService.get_module(module_name) module = TextModelService.get_module(module_name)
module_type = module.type.split("/")[-1].replace("-", "_").lower() module_type = module.type.split("/")[-1].replace("-", "_").lower()
predict_func = eval("predict_" + module_type) predict_func = eval("predict_" + module_type)
results = predict_func(module, data) file_list = []
for item in file:
file_path = req_id + "_" + item.filename
file_list.append(file_path)
item.save(file_path)
results = predict_func(module, data, file_list)
return {"results": results} return {"results": results}
return app_instance return app_instance
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册