提交 489f4e23 编写于 作者: S shenyuhan

Unified nlp interface.

上级 8a44068e
...@@ -3,7 +3,7 @@ import requests ...@@ -3,7 +3,7 @@ import requests
import json import json
if __name__ == "__main__": if __name__ == "__main__":
file_list = ["cat.jpg", "flower.jpg"] file_list = ["../img/cat.jpg", "../img/flower.jpg"]
files = [("image", (open(item, "rb"))) for item in file_list] files = [("image", (open(item, "rb"))) for item in file_list]
url = "http://127.0.0.1:8866/predict/image/vgg11_imagenet" url = "http://127.0.0.1:8866/predict/image/vgg11_imagenet"
r = requests.post(url=url, files=files) r = requests.post(url=url, files=files)
......
...@@ -6,8 +6,7 @@ if __name__ == "__main__": ...@@ -6,8 +6,7 @@ if __name__ == "__main__":
text_list = ["今天是个好日子", "天气预报说今天要下雨"] text_list = ["今天是个好日子", "天气预报说今天要下雨"]
text = {"text": text_list} text = {"text": text_list}
# 将用户自定义词典文件发送到预测接口即可 # 将用户自定义词典文件发送到预测接口即可
with open("dict.txt", "rb") as fp: file = {"user_dict": open("dict.txt", "rb")}
file = {"user_dict": fp.read()}
url = "http://127.0.0.1:8866/predict/text/lac" url = "http://127.0.0.1:8866/predict/text/lac"
r = requests.post(url=url, files=file, data=text) r = requests.post(url=url, files=file, data=text)
......
...@@ -22,17 +22,6 @@ import os ...@@ -22,17 +22,6 @@ import os
import base64 import base64
import logging import logging
nlp_module_method = {
"lac": "predict_lexical_analysis",
"simnet_bow": "predict_sentiment_analysis",
"lm_lstm": "predict_pretrained_model",
"senta_lstm": "predict_pretrained_model",
"senta_gru": "predict_pretrained_model",
"senta_cnn": "predict_pretrained_model",
"senta_bow": "predict_pretrained_model",
"senta_bilstm": "predict_pretrained_model",
"emotion_detection_textcnn": "predict_pretrained_model"
}
cv_module_method = { cv_module_method = {
"vgg19_imagenet": "predict_classification", "vgg19_imagenet": "predict_classification",
"vgg16_imagenet": "predict_classification", "vgg16_imagenet": "predict_classification",
...@@ -65,63 +54,33 @@ cv_module_method = { ...@@ -65,63 +54,33 @@ cv_module_method = {
} }
def predict_sentiment_analysis(module, input_text, batch_size, extra=None): def predict_nlp(module, input_text, req_id, batch_size, extra=None):
global use_gpu
method_name = module.desc.attr.map.data['default_signature'].s
predict_method = getattr(module, method_name)
try:
data = input_text[0]
data.update(input_text[1])
results = predict_method(
data=data, use_gpu=use_gpu, batch_size=batch_size)
except Exception as err:
curr = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(time.time()))
print(curr, " - ", err)
return {"result": "Please check data format!"}
return results
def predict_pretrained_model(module, input_text, batch_size, extra=None):
global use_gpu
method_name = module.desc.attr.map.data['default_signature'].s method_name = module.desc.attr.map.data['default_signature'].s
predict_method = getattr(module, method_name) predict_method = getattr(module, method_name)
try: try:
data = {"text": input_text} data = input_text
results = predict_method( if module.name == "lac" and extra.get("user_dict", []) != []:
data=data, use_gpu=use_gpu, batch_size=batch_size) res = predict_method(
except Exception as err:
curr = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(time.time()))
print(curr, " - ", err)
return {"result": "Please check data format!"}
return results
def predict_lexical_analysis(module, input_text, batch_size, extra=[]):
global use_gpu
method_name = module.desc.attr.map.data['default_signature'].s
predict_method = getattr(module, method_name)
data = {"text": input_text}
try:
if extra == []:
results = predict_method(
data=data, use_gpu=use_gpu, batch_size=batch_size)
else:
user_dict = extra[0]
results = predict_method(
data=data, data=data,
user_dict=user_dict, user_dict=extra.get("user_dict", [])[0],
use_gpu=use_gpu, use_gpu=use_gpu,
batch_size=batch_size) batch_size=batch_size)
for path in extra: else:
os.remove(path) res = predict_method(
data=data, use_gpu=use_gpu, batch_size=batch_size)
except Exception as err: except Exception as err:
curr = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(time.time())) curr = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(time.time()))
print(curr, " - ", err) print(curr, " - ", err)
return {"result": "Please check data format!"} return {"results": "Please check data format!"}
return results finally:
user_dict = extra.get("user_dict", [])
for item in user_dict:
if os.path.exists(item):
os.remove(item)
return {"results": res}
def predict_classification(module, input_img, batch_size): def predict_classification(module, input_img, id, batch_size, extra={}):
global use_gpu global use_gpu
method_name = module.desc.attr.map.data['default_signature'].s method_name = module.desc.attr.map.data['default_signature'].s
predict_method = getattr(module, method_name) predict_method = getattr(module, method_name)
...@@ -133,31 +92,35 @@ def predict_classification(module, input_img, batch_size): ...@@ -133,31 +92,35 @@ def predict_classification(module, input_img, batch_size):
curr = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(time.time())) curr = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(time.time()))
print(curr, " - ", err) print(curr, " - ", err)
return {"result": "Please check data format!"} return {"result": "Please check data format!"}
finally:
for item in input_img["image"]:
if os.path.exists(item):
os.remove(item)
return results return results
def predict_gan(module, input_img, id, batch_size, extra={}): def predict_gan(module, input_img, id, batch_size, extra={}):
# special
output_folder = module.name.split("_")[0] + "_" + "output" output_folder = module.name.split("_")[0] + "_" + "output"
global use_gpu global use_gpu
method_name = module.desc.attr.map.data['default_signature'].s method_name = module.desc.attr.map.data['default_signature'].s
predict_method = getattr(module, method_name) predict_method = getattr(module, method_name)
try: try:
extra.update({"image": input_img})
input_img = {"image": input_img} input_img = {"image": input_img}
results = predict_method( results = predict_method(
data=input_img, use_gpu=use_gpu, batch_size=batch_size) data=extra, use_gpu=use_gpu, batch_size=batch_size)
except Exception as err: except Exception as err:
curr = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(time.time())) curr = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(time.time()))
print(curr, " - ", err) print(curr, " - ", err)
return {"result": "Please check data format!"} return {"result": "Please check data format!"}
finally:
base64_list = [] base64_list = []
results_pack = [] results_pack = []
input_img = input_img.get("image", []) input_img = input_img.get("image", [])
for index in range(len(input_img)): for index in range(len(input_img)):
# special
item = input_img[index] item = input_img[index]
with open(os.path.join(output_folder, item), "rb") as fp: output_file = results[index].split(" ")[-1]
# special with open(output_file, "rb") as fp:
b_head = "data:image/" + item.split(".")[-1] + ";base64" b_head = "data:image/" + item.split(".")[-1] + ";base64"
b_body = base64.b64encode(fp.read()) b_body = base64.b64encode(fp.read())
b_body = str(b_body).replace("b'", "").replace("'", "") b_body = str(b_body).replace("b'", "").replace("'", "")
...@@ -168,11 +131,11 @@ def predict_gan(module, input_img, id, batch_size, extra={}): ...@@ -168,11 +131,11 @@ def predict_gan(module, input_img, id, batch_size, extra={}):
results[index].update({"base64": b_img}) results[index].update({"base64": b_img})
results_pack.append(results[index]) results_pack.append(results[index])
os.remove(item) os.remove(item)
os.remove(os.path.join(output_folder, item)) os.remove(output_file)
return results_pack return results_pack
def predict_object_detection(module, input_img, id, batch_size): def predict_object_detection(module, input_img, id, batch_size, extra={}):
output_folder = "output" output_folder = "output"
global use_gpu global use_gpu
method_name = module.desc.attr.map.data['default_signature'].s method_name = module.desc.attr.map.data['default_signature'].s
...@@ -185,6 +148,7 @@ def predict_object_detection(module, input_img, id, batch_size): ...@@ -185,6 +148,7 @@ def predict_object_detection(module, input_img, id, batch_size):
curr = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(time.time())) curr = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(time.time()))
print(curr, " - ", err) print(curr, " - ", err)
return {"result": "Please check data format!"} return {"result": "Please check data format!"}
finally:
base64_list = [] base64_list = []
results_pack = [] results_pack = []
input_img = input_img.get("image", []) input_img = input_img.get("image", [])
...@@ -205,8 +169,7 @@ def predict_object_detection(module, input_img, id, batch_size): ...@@ -205,8 +169,7 @@ def predict_object_detection(module, input_img, id, batch_size):
return results_pack return results_pack
def predict_semantic_segmentation(module, input_img, id, batch_size): def predict_semantic_segmentation(module, input_img, id, batch_size, extra={}):
# special
output_folder = module.name.split("_")[-1] + "_" + "output" output_folder = module.name.split("_")[-1] + "_" + "output"
global use_gpu global use_gpu
method_name = module.desc.attr.map.data['default_signature'].s method_name = module.desc.attr.map.data['default_signature'].s
...@@ -219,6 +182,7 @@ def predict_semantic_segmentation(module, input_img, id, batch_size): ...@@ -219,6 +182,7 @@ def predict_semantic_segmentation(module, input_img, id, batch_size):
curr = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(time.time())) curr = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(time.time()))
print(curr, " - ", err) print(curr, " - ", err)
return {"result": "Please check data format!"} return {"result": "Please check data format!"}
finally:
base64_list = [] base64_list = []
results_pack = [] results_pack = []
input_img = input_img.get("image", []) input_img = input_img.get("image", [])
...@@ -227,7 +191,6 @@ def predict_semantic_segmentation(module, input_img, id, batch_size): ...@@ -227,7 +191,6 @@ def predict_semantic_segmentation(module, input_img, id, batch_size):
item = input_img[index] item = input_img[index]
output_file_path = "" output_file_path = ""
with open(results[index]["processed"], "rb") as fp: with open(results[index]["processed"], "rb") as fp:
# special
b_head = "data:image/png;base64" b_head = "data:image/png;base64"
b_body = base64.b64encode(fp.read()) b_body = base64.b64encode(fp.read())
b_body = str(b_body).replace("b'", "").replace("'", "") b_body = str(b_body).replace("b'", "").replace("'", "")
...@@ -236,8 +199,8 @@ def predict_semantic_segmentation(module, input_img, id, batch_size): ...@@ -236,8 +199,8 @@ def predict_semantic_segmentation(module, input_img, id, batch_size):
output_file_path = results[index]["processed"] output_file_path = results[index]["processed"]
results[index]["origin"] = results[index]["origin"].replace( results[index]["origin"] = results[index]["origin"].replace(
id + "_", "") id + "_", "")
results[index]["processed"] = results[index]["processed"].replace( results[index]["processed"] = results[index][
id + "_", "") "processed"].replace(id + "_", "")
results[index].update({"base64": b_img}) results[index].update({"base64": b_img})
results_pack.append(results[index]) results_pack.append(results[index])
os.remove(item) os.remove(item)
...@@ -274,14 +237,18 @@ def create_app(): ...@@ -274,14 +237,18 @@ def create_app():
module_info.update({"cv_module": [{"Choose...": "Choose..."}]}) module_info.update({"cv_module": [{"Choose...": "Choose..."}]})
for item in cv_module: for item in cv_module:
module_info["cv_module"].append({item: item}) module_info["cv_module"].append({item: item})
module_info.update({"Choose...": [{"请先选择分类": "Choose..."}]})
return {"module_info": module_info} return {"module_info": module_info}
@app_instance.route("/predict/image/<module_name>", methods=["POST"]) @app_instance.route("/predict/image/<module_name>", methods=["POST"])
def predict_image(module_name): def predict_image(module_name):
if request.path.split("/")[-1] not in cv_module:
return {"error": "Module {} is not available.".format(module_name)}
req_id = request.data.get("id") req_id = request.data.get("id")
global use_gpu, batch_size_dict global use_gpu, batch_size_dict
img_base64 = request.form.getlist("image") img_base64 = request.form.getlist("image")
extra_info = {}
for item in list(request.form.keys()):
extra_info.update({item: request.form.getlist(item)})
file_name_list = [] file_name_list = []
if img_base64 != []: if img_base64 != []:
for item in img_base64: for item in img_base64:
...@@ -310,36 +277,34 @@ def create_app(): ...@@ -310,36 +277,34 @@ def create_app():
module_type = module.type.split("/")[-1].replace("-", "_").lower() module_type = module.type.split("/")[-1].replace("-", "_").lower()
predict_func = eval("predict_" + module_type) predict_func = eval("predict_" + module_type)
batch_size = batch_size_dict.get(module_name, 1) batch_size = batch_size_dict.get(module_name, 1)
results = predict_func(module, file_name_list, req_id, batch_size) results = predict_func(module, file_name_list, req_id, batch_size,
extra_info)
r = {"results": str(results)} r = {"results": str(results)}
return r return r
@app_instance.route("/predict/text/<module_name>", methods=["POST"]) @app_instance.route("/predict/text/<module_name>", methods=["POST"])
def predict_text(module_name): def predict_text(module_name):
if request.path.split("/")[-1] not in nlp_module:
return {"error": "Module {} is not available.".format(module_name)}
req_id = request.data.get("id") req_id = request.data.get("id")
global use_gpu inputs = {}
if module_name == "simnet_bow": for item in list(request.form.keys()):
text_1 = request.form.getlist("text_1") inputs.update({item: request.form.getlist(item)})
text_2 = request.form.getlist("text_2") files = {}
data = [{"text_1": text_1}, {"text_2": text_2}] for file_key in list(request.files.keys()):
else: files[file_key] = []
data = request.form.getlist("text") for file in request.files.getlist(file_key):
file = request.files.getlist("user_dict") file_name = req_id + "_" + file.filename
files[file_key].append(file_name)
file.save(file_name)
module = TextModelService.get_module(module_name) module = TextModelService.get_module(module_name)
predict_func_name = nlp_module_method.get(module_name, "") results = predict_nlp(
if predict_func_name != "": module=module,
predict_func = eval(predict_func_name) input_text=inputs,
else: req_id=req_id,
module_type = module.type.split("/")[-1].replace("-", "_").lower() batch_size=batch_size_dict.get(module_name, 1),
predict_func = eval("predict_" + module_type) extra=files)
file_list = [] return results
for item in file:
file_path = req_id + "_" + item.filename
file_list.append(file_path)
item.save(file_path)
batch_size = batch_size_dict.get(module_name, 1)
results = predict_func(module, data, batch_size, file_list)
return {"results": results}
return app_instance return app_instance
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册