提交 489f4e23 编写于 作者: S shenyuhan

Unified nlp interface.

上级 8a44068e
......@@ -3,7 +3,7 @@ import requests
import json
if __name__ == "__main__":
file_list = ["cat.jpg", "flower.jpg"]
file_list = ["../img/cat.jpg", "../img/flower.jpg"]
files = [("image", (open(item, "rb"))) for item in file_list]
url = "http://127.0.0.1:8866/predict/image/vgg11_imagenet"
r = requests.post(url=url, files=files)
......
......@@ -6,8 +6,7 @@ if __name__ == "__main__":
text_list = ["今天是个好日子", "天气预报说今天要下雨"]
text = {"text": text_list}
# 将用户自定义词典文件发送到预测接口即可
with open("dict.txt", "rb") as fp:
file = {"user_dict": fp.read()}
file = {"user_dict": open("dict.txt", "rb")}
url = "http://127.0.0.1:8866/predict/text/lac"
r = requests.post(url=url, files=file, data=text)
......
......@@ -22,17 +22,6 @@ import os
import base64
import logging
nlp_module_method = {
"lac": "predict_lexical_analysis",
"simnet_bow": "predict_sentiment_analysis",
"lm_lstm": "predict_pretrained_model",
"senta_lstm": "predict_pretrained_model",
"senta_gru": "predict_pretrained_model",
"senta_cnn": "predict_pretrained_model",
"senta_bow": "predict_pretrained_model",
"senta_bilstm": "predict_pretrained_model",
"emotion_detection_textcnn": "predict_pretrained_model"
}
cv_module_method = {
"vgg19_imagenet": "predict_classification",
"vgg16_imagenet": "predict_classification",
......@@ -65,63 +54,33 @@ cv_module_method = {
}
def predict_sentiment_analysis(module, input_text, batch_size, extra=None):
global use_gpu
method_name = module.desc.attr.map.data['default_signature'].s
predict_method = getattr(module, method_name)
try:
data = input_text[0]
data.update(input_text[1])
results = predict_method(
data=data, use_gpu=use_gpu, batch_size=batch_size)
except Exception as err:
curr = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(time.time()))
print(curr, " - ", err)
return {"result": "Please check data format!"}
return results
def predict_pretrained_model(module, input_text, batch_size, extra=None):
global use_gpu
def predict_nlp(module, input_text, req_id, batch_size, extra=None):
method_name = module.desc.attr.map.data['default_signature'].s
predict_method = getattr(module, method_name)
try:
data = {"text": input_text}
results = predict_method(
data=data, use_gpu=use_gpu, batch_size=batch_size)
except Exception as err:
curr = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(time.time()))
print(curr, " - ", err)
return {"result": "Please check data format!"}
return results
def predict_lexical_analysis(module, input_text, batch_size, extra=[]):
global use_gpu
method_name = module.desc.attr.map.data['default_signature'].s
predict_method = getattr(module, method_name)
data = {"text": input_text}
try:
if extra == []:
results = predict_method(
data=data, use_gpu=use_gpu, batch_size=batch_size)
else:
user_dict = extra[0]
results = predict_method(
data = input_text
if module.name == "lac" and extra.get("user_dict", []) != []:
res = predict_method(
data=data,
user_dict=user_dict,
user_dict=extra.get("user_dict", [])[0],
use_gpu=use_gpu,
batch_size=batch_size)
for path in extra:
os.remove(path)
else:
res = predict_method(
data=data, use_gpu=use_gpu, batch_size=batch_size)
except Exception as err:
curr = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(time.time()))
print(curr, " - ", err)
return {"result": "Please check data format!"}
return results
return {"results": "Please check data format!"}
finally:
user_dict = extra.get("user_dict", [])
for item in user_dict:
if os.path.exists(item):
os.remove(item)
return {"results": res}
def predict_classification(module, input_img, batch_size):
def predict_classification(module, input_img, id, batch_size, extra={}):
global use_gpu
method_name = module.desc.attr.map.data['default_signature'].s
predict_method = getattr(module, method_name)
......@@ -133,31 +92,35 @@ def predict_classification(module, input_img, batch_size):
curr = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(time.time()))
print(curr, " - ", err)
return {"result": "Please check data format!"}
finally:
for item in input_img["image"]:
if os.path.exists(item):
os.remove(item)
return results
def predict_gan(module, input_img, id, batch_size, extra={}):
# special
output_folder = module.name.split("_")[0] + "_" + "output"
global use_gpu
method_name = module.desc.attr.map.data['default_signature'].s
predict_method = getattr(module, method_name)
try:
extra.update({"image": input_img})
input_img = {"image": input_img}
results = predict_method(
data=input_img, use_gpu=use_gpu, batch_size=batch_size)
data=extra, use_gpu=use_gpu, batch_size=batch_size)
except Exception as err:
curr = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(time.time()))
print(curr, " - ", err)
return {"result": "Please check data format!"}
finally:
base64_list = []
results_pack = []
input_img = input_img.get("image", [])
for index in range(len(input_img)):
# special
item = input_img[index]
with open(os.path.join(output_folder, item), "rb") as fp:
# special
output_file = results[index].split(" ")[-1]
with open(output_file, "rb") as fp:
b_head = "data:image/" + item.split(".")[-1] + ";base64"
b_body = base64.b64encode(fp.read())
b_body = str(b_body).replace("b'", "").replace("'", "")
......@@ -168,11 +131,11 @@ def predict_gan(module, input_img, id, batch_size, extra={}):
results[index].update({"base64": b_img})
results_pack.append(results[index])
os.remove(item)
os.remove(os.path.join(output_folder, item))
os.remove(output_file)
return results_pack
def predict_object_detection(module, input_img, id, batch_size):
def predict_object_detection(module, input_img, id, batch_size, extra={}):
output_folder = "output"
global use_gpu
method_name = module.desc.attr.map.data['default_signature'].s
......@@ -185,6 +148,7 @@ def predict_object_detection(module, input_img, id, batch_size):
curr = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(time.time()))
print(curr, " - ", err)
return {"result": "Please check data format!"}
finally:
base64_list = []
results_pack = []
input_img = input_img.get("image", [])
......@@ -205,8 +169,7 @@ def predict_object_detection(module, input_img, id, batch_size):
return results_pack
def predict_semantic_segmentation(module, input_img, id, batch_size):
# special
def predict_semantic_segmentation(module, input_img, id, batch_size, extra={}):
output_folder = module.name.split("_")[-1] + "_" + "output"
global use_gpu
method_name = module.desc.attr.map.data['default_signature'].s
......@@ -219,6 +182,7 @@ def predict_semantic_segmentation(module, input_img, id, batch_size):
curr = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(time.time()))
print(curr, " - ", err)
return {"result": "Please check data format!"}
finally:
base64_list = []
results_pack = []
input_img = input_img.get("image", [])
......@@ -227,7 +191,6 @@ def predict_semantic_segmentation(module, input_img, id, batch_size):
item = input_img[index]
output_file_path = ""
with open(results[index]["processed"], "rb") as fp:
# special
b_head = "data:image/png;base64"
b_body = base64.b64encode(fp.read())
b_body = str(b_body).replace("b'", "").replace("'", "")
......@@ -236,8 +199,8 @@ def predict_semantic_segmentation(module, input_img, id, batch_size):
output_file_path = results[index]["processed"]
results[index]["origin"] = results[index]["origin"].replace(
id + "_", "")
results[index]["processed"] = results[index]["processed"].replace(
id + "_", "")
results[index]["processed"] = results[index][
"processed"].replace(id + "_", "")
results[index].update({"base64": b_img})
results_pack.append(results[index])
os.remove(item)
......@@ -274,14 +237,18 @@ def create_app():
module_info.update({"cv_module": [{"Choose...": "Choose..."}]})
for item in cv_module:
module_info["cv_module"].append({item: item})
module_info.update({"Choose...": [{"请先选择分类": "Choose..."}]})
return {"module_info": module_info}
@app_instance.route("/predict/image/<module_name>", methods=["POST"])
def predict_image(module_name):
if request.path.split("/")[-1] not in cv_module:
return {"error": "Module {} is not available.".format(module_name)}
req_id = request.data.get("id")
global use_gpu, batch_size_dict
img_base64 = request.form.getlist("image")
extra_info = {}
for item in list(request.form.keys()):
extra_info.update({item: request.form.getlist(item)})
file_name_list = []
if img_base64 != []:
for item in img_base64:
......@@ -310,36 +277,34 @@ def create_app():
module_type = module.type.split("/")[-1].replace("-", "_").lower()
predict_func = eval("predict_" + module_type)
batch_size = batch_size_dict.get(module_name, 1)
results = predict_func(module, file_name_list, req_id, batch_size)
results = predict_func(module, file_name_list, req_id, batch_size,
extra_info)
r = {"results": str(results)}
return r
@app_instance.route("/predict/text/<module_name>", methods=["POST"])
def predict_text(module_name):
if request.path.split("/")[-1] not in nlp_module:
return {"error": "Module {} is not available.".format(module_name)}
req_id = request.data.get("id")
global use_gpu
if module_name == "simnet_bow":
text_1 = request.form.getlist("text_1")
text_2 = request.form.getlist("text_2")
data = [{"text_1": text_1}, {"text_2": text_2}]
else:
data = request.form.getlist("text")
file = request.files.getlist("user_dict")
inputs = {}
for item in list(request.form.keys()):
inputs.update({item: request.form.getlist(item)})
files = {}
for file_key in list(request.files.keys()):
files[file_key] = []
for file in request.files.getlist(file_key):
file_name = req_id + "_" + file.filename
files[file_key].append(file_name)
file.save(file_name)
module = TextModelService.get_module(module_name)
predict_func_name = nlp_module_method.get(module_name, "")
if predict_func_name != "":
predict_func = eval(predict_func_name)
else:
module_type = module.type.split("/")[-1].replace("-", "_").lower()
predict_func = eval("predict_" + module_type)
file_list = []
for item in file:
file_path = req_id + "_" + item.filename
file_list.append(file_path)
item.save(file_path)
batch_size = batch_size_dict.get(module_name, 1)
results = predict_func(module, data, batch_size, file_list)
return {"results": results}
results = predict_nlp(
module=module,
input_text=inputs,
req_id=req_id,
batch_size=batch_size_dict.get(module_name, 1),
extra=files)
return results
return app_instance
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册