Commit 61fe956c authored by 走神的阿圆, committed by wuzewu

1. Add a configurable batch_size; 2. Strip the req_id_ prefix from every output file (#213)

* 1. Add a configurable batch_size; 2. Strip the req_id_ prefix from every output file
Parent 66fb66c7
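For context, the change to config_with_file below assumes each config entry now carries a batch_size next to its module and category. A minimal sketch of such a config follows; only the "module", "category" and "batch_size" keys come from the diff, while the concrete module names and values are illustrative assumptions.

# Hypothetical entries passed to config_with_file(); module names and
# values below are illustrative, not part of this commit.
configs = [
    {"module": "lac", "category": "NLP", "batch_size": 20},
    {"module": "yolov3_darknet53_coco2017", "category": "CV", "batch_size": 2},
]

# With the new code, config_with_file(configs) would leave behind:
#   nlp_module      = ["lac"]
#   cv_module       = ["yolov3_darknet53_coco2017"]
#   batch_size_dict = {"lac": 20, "yolov3_darknet53_coco2017": 2}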
@@ -65,14 +65,15 @@ cv_module_method = {
 }


-def predict_sentiment_analysis(module, input_text, extra=None):
+def predict_sentiment_analysis(module, input_text, batch_size, extra=None):
     global use_gpu
     method_name = module.desc.attr.map.data['default_signature'].s
     predict_method = getattr(module, method_name)
     try:
         data = input_text[0]
         data.update(input_text[1])
-        results = predict_method(data=data, use_gpu=use_gpu)
+        results = predict_method(
+            data=data, use_gpu=use_gpu, batch_size=batch_size)
     except Exception as err:
         curr = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(time.time()))
         print(curr, " - ", err)
@@ -80,13 +81,14 @@ def predict_sentiment_analysis(module, input_text, extra=None):
     return results


-def predict_pretrained_model(module, input_text, extra=None):
+def predict_pretrained_model(module, input_text, batch_size, extra=None):
     global use_gpu
     method_name = module.desc.attr.map.data['default_signature'].s
     predict_method = getattr(module, method_name)
     try:
         data = {"text": input_text}
-        results = predict_method(data=data, use_gpu=use_gpu)
+        results = predict_method(
+            data=data, use_gpu=use_gpu, batch_size=batch_size)
     except Exception as err:
         curr = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(time.time()))
         print(curr, " - ", err)
@@ -94,18 +96,22 @@ def predict_pretrained_model(module, input_text, extra=None):
     return results


-def predict_lexical_analysis(module, input_text, extra=[]):
+def predict_lexical_analysis(module, input_text, batch_size, extra=[]):
     global use_gpu
     method_name = module.desc.attr.map.data['default_signature'].s
     predict_method = getattr(module, method_name)
     data = {"text": input_text}
     try:
         if extra == []:
-            results = predict_method(data=data, use_gpu=use_gpu)
+            results = predict_method(
+                data=data, use_gpu=use_gpu, batch_size=batch_size)
         else:
             user_dict = extra[0]
             results = predict_method(
-                data=data, user_dict=user_dict, use_gpu=use_gpu)
+                data=data,
+                user_dict=user_dict,
+                use_gpu=use_gpu,
+                batch_size=batch_size)
             for path in extra:
                 os.remove(path)
     except Exception as err:
@@ -115,13 +121,14 @@ def predict_lexical_analysis(module, input_text, extra=[]):
     return results


-def predict_classification(module, input_img):
+def predict_classification(module, input_img, batch_size):
     global use_gpu
     method_name = module.desc.attr.map.data['default_signature'].s
     predict_method = getattr(module, method_name)
     try:
         input_img = {"image": input_img}
-        results = predict_method(data=input_img, use_gpu=use_gpu)
+        results = predict_method(
+            data=input_img, use_gpu=use_gpu, batch_size=batch_size)
     except Exception as err:
         curr = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(time.time()))
         print(curr, " - ", err)
@@ -129,7 +136,7 @@ def predict_classification(module, input_img):
     return results


-def predict_gan(module, input_img, extra={}):
+def predict_gan(module, input_img, id, batch_size, extra={}):
     # special
     output_folder = module.name.split("_")[0] + "_" + "output"
     global use_gpu
@@ -137,7 +144,8 @@ def predict_gan(module, input_img, extra={}):
     predict_method = getattr(module, method_name)
     try:
         input_img = {"image": input_img}
-        results = predict_method(data=input_img, use_gpu=use_gpu)
+        results = predict_method(
+            data=input_img, use_gpu=use_gpu, batch_size=batch_size)
     except Exception as err:
         curr = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(time.time()))
         print(curr, " - ", err)
@@ -155,6 +163,7 @@ def predict_gan(module, input_img, extra={}):
         b_body = str(b_body).replace("b'", "").replace("'", "")
         b_img = b_head + "," + b_body
         base64_list.append(b_img)
+        results[index] = results[index].replace(id + "_", "")
         results[index] = {"path": results[index]}
         results[index].update({"base64": b_img})
         results_pack.append(results[index])
@@ -163,14 +172,15 @@ def predict_gan(module, input_img, extra={}):
     return results_pack


-def predict_object_detection(module, input_img):
+def predict_object_detection(module, input_img, id, batch_size):
     output_folder = "output"
     global use_gpu
     method_name = module.desc.attr.map.data['default_signature'].s
     predict_method = getattr(module, method_name)
     try:
         input_img = {"image": input_img}
-        results = predict_method(data=input_img, use_gpu=use_gpu)
+        results = predict_method(
+            data=input_img, use_gpu=use_gpu, batch_size=batch_size)
     except Exception as err:
         curr = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(time.time()))
         print(curr, " - ", err)
@@ -186,6 +196,8 @@ def predict_object_detection(module, input_img):
         b_body = str(b_body).replace("b'", "").replace("'", "")
         b_img = b_head + "," + b_body
         base64_list.append(b_img)
+        results[index]["path"] = results[index]["path"].replace(
+            id + "_", "")
         results[index].update({"base64": b_img})
         results_pack.append(results[index])
         os.remove(item)
@@ -193,7 +205,7 @@ def predict_object_detection(module, input_img):
     return results_pack


-def predict_semantic_segmentation(module, input_img):
+def predict_semantic_segmentation(module, input_img, id, batch_size):
     # special
     output_folder = module.name.split("_")[-1] + "_" + "output"
     global use_gpu
@@ -201,7 +213,8 @@ def predict_semantic_segmentation(module, input_img):
     predict_method = getattr(module, method_name)
     try:
         input_img = {"image": input_img}
-        results = predict_method(data=input_img, use_gpu=use_gpu)
+        results = predict_method(
+            data=input_img, use_gpu=use_gpu, batch_size=batch_size)
     except Exception as err:
         curr = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(time.time()))
         print(curr, " - ", err)
@@ -219,6 +232,10 @@ def predict_semantic_segmentation(module, input_img):
         b_body = str(b_body).replace("b'", "").replace("'", "")
         b_img = b_head + "," + b_body
         base64_list.append(b_img)
+        results[index]["origin"] = results[index]["origin"].replace(
+            id + "_", "")
+        results[index]["processed"] = results[index]["processed"].replace(
+            id + "_", "")
         results[index].update({"base64": b_img})
         results_pack.append(results[index])
         os.remove(item)
@@ -260,7 +277,7 @@ def create_app():
     @app_instance.route("/predict/image/<module_name>", methods=["POST"])
     def predict_image(module_name):
         req_id = request.data.get("id")
-        global use_gpu
+        global use_gpu, batch_size_dict
         img_base64 = request.form.getlist("image")
         file_name_list = []
         if img_base64 != []:
@@ -289,7 +306,8 @@ def create_app():
         else:
             module_type = module.type.split("/")[-1].replace("-", "_").lower()
         predict_func = eval("predict_" + module_type)
-        results = predict_func(module, file_name_list)
+        batch_size = batch_size_dict.get(module_name, 1)
+        results = predict_func(module, file_name_list, req_id, batch_size)
         r = {"results": str(results)}
         return r

@@ -316,22 +334,25 @@ def create_app():
             file_path = req_id + "_" + item.filename
             file_list.append(file_path)
             item.save(file_path)
-        results = predict_func(module, data, file_list)
+        batch_size = batch_size_dict.get(module_name, 1)
+        results = predict_func(module, data, batch_size, file_list)
         return {"results": results}

     return app_instance


 def config_with_file(configs):
-    global nlp_module, cv_module
+    global nlp_module, cv_module, batch_size_dict
     nlp_module = []
     cv_module = []
+    batch_size_dict = {}
     for item in configs:
         print(item)
         if item["category"] == "CV":
             cv_module.append(item["module"])
         elif item["category"] == "NLP":
             nlp_module.append(item["module"])
+        batch_size_dict.update({item["module"]: item["batch_size"]})


 def run(is_use_gpu=False, configs=None, port=8866, timeout=60):
......
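For reference, a request against the /predict/image/<module_name> route touched above could look roughly like the sketch below. The "image" form field and the default port 8866 appear in the diff; the host, module name, payload encoding, and response handling are assumptions.

# Hypothetical client call; the server reads base64 images from the "image"
# form field and now looks up the batch size for the module in batch_size_dict.
import base64

import requests

with open("test.jpg", "rb") as f:  # illustrative local image
    img_b64 = base64.b64encode(f.read()).decode("utf8")

resp = requests.post(
    "http://127.0.0.1:8866/predict/image/yolov3_darknet53_coco2017",
    data={"image": img_b64})  # repeat the field to send a batch of images
print(resp.text)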