提交 61fe956c 编写于 作者: 走神的阿圆's avatar 走神的阿圆 提交者: wuzewu

1.add config batch_size; 2.delete req_id_ for every file (#213)

* 1.add config batch_size; 2.delete req_id_ for every file
上级 66fb66c7
......@@ -65,14 +65,15 @@ cv_module_method = {
}
def predict_sentiment_analysis(module, input_text, extra=None):
def predict_sentiment_analysis(module, input_text, batch_size, extra=None):
global use_gpu
method_name = module.desc.attr.map.data['default_signature'].s
predict_method = getattr(module, method_name)
try:
data = input_text[0]
data.update(input_text[1])
results = predict_method(data=data, use_gpu=use_gpu)
results = predict_method(
data=data, use_gpu=use_gpu, batch_size=batch_size)
except Exception as err:
curr = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(time.time()))
print(curr, " - ", err)
......@@ -80,13 +81,14 @@ def predict_sentiment_analysis(module, input_text, extra=None):
return results
def predict_pretrained_model(module, input_text, extra=None):
def predict_pretrained_model(module, input_text, batch_size, extra=None):
global use_gpu
method_name = module.desc.attr.map.data['default_signature'].s
predict_method = getattr(module, method_name)
try:
data = {"text": input_text}
results = predict_method(data=data, use_gpu=use_gpu)
results = predict_method(
data=data, use_gpu=use_gpu, batch_size=batch_size)
except Exception as err:
curr = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(time.time()))
print(curr, " - ", err)
......@@ -94,18 +96,22 @@ def predict_pretrained_model(module, input_text, extra=None):
return results
def predict_lexical_analysis(module, input_text, extra=[]):
def predict_lexical_analysis(module, input_text, batch_size, extra=[]):
global use_gpu
method_name = module.desc.attr.map.data['default_signature'].s
predict_method = getattr(module, method_name)
data = {"text": input_text}
try:
if extra == []:
results = predict_method(data=data, use_gpu=use_gpu)
results = predict_method(
data=data, use_gpu=use_gpu, batch_size=batch_size)
else:
user_dict = extra[0]
results = predict_method(
data=data, user_dict=user_dict, use_gpu=use_gpu)
data=data,
user_dict=user_dict,
use_gpu=use_gpu,
batch_size=batch_size)
for path in extra:
os.remove(path)
except Exception as err:
......@@ -115,13 +121,14 @@ def predict_lexical_analysis(module, input_text, extra=[]):
return results
def predict_classification(module, input_img):
def predict_classification(module, input_img, batch_size):
global use_gpu
method_name = module.desc.attr.map.data['default_signature'].s
predict_method = getattr(module, method_name)
try:
input_img = {"image": input_img}
results = predict_method(data=input_img, use_gpu=use_gpu)
results = predict_method(
data=input_img, use_gpu=use_gpu, batch_size=batch_size)
except Exception as err:
curr = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(time.time()))
print(curr, " - ", err)
......@@ -129,7 +136,7 @@ def predict_classification(module, input_img):
return results
def predict_gan(module, input_img, extra={}):
def predict_gan(module, input_img, id, batch_size, extra={}):
# special
output_folder = module.name.split("_")[0] + "_" + "output"
global use_gpu
......@@ -137,7 +144,8 @@ def predict_gan(module, input_img, extra={}):
predict_method = getattr(module, method_name)
try:
input_img = {"image": input_img}
results = predict_method(data=input_img, use_gpu=use_gpu)
results = predict_method(
data=input_img, use_gpu=use_gpu, batch_size=batch_size)
except Exception as err:
curr = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(time.time()))
print(curr, " - ", err)
......@@ -155,6 +163,7 @@ def predict_gan(module, input_img, extra={}):
b_body = str(b_body).replace("b'", "").replace("'", "")
b_img = b_head + "," + b_body
base64_list.append(b_img)
results[index] = results[index].replace(id + "_", "")
results[index] = {"path": results[index]}
results[index].update({"base64": b_img})
results_pack.append(results[index])
......@@ -163,14 +172,15 @@ def predict_gan(module, input_img, extra={}):
return results_pack
def predict_object_detection(module, input_img):
def predict_object_detection(module, input_img, id, batch_size):
output_folder = "output"
global use_gpu
method_name = module.desc.attr.map.data['default_signature'].s
predict_method = getattr(module, method_name)
try:
input_img = {"image": input_img}
results = predict_method(data=input_img, use_gpu=use_gpu)
results = predict_method(
data=input_img, use_gpu=use_gpu, batch_size=batch_size)
except Exception as err:
curr = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(time.time()))
print(curr, " - ", err)
......@@ -186,6 +196,8 @@ def predict_object_detection(module, input_img):
b_body = str(b_body).replace("b'", "").replace("'", "")
b_img = b_head + "," + b_body
base64_list.append(b_img)
results[index]["path"] = results[index]["path"].replace(
id + "_", "")
results[index].update({"base64": b_img})
results_pack.append(results[index])
os.remove(item)
......@@ -193,7 +205,7 @@ def predict_object_detection(module, input_img):
return results_pack
def predict_semantic_segmentation(module, input_img):
def predict_semantic_segmentation(module, input_img, id, batch_size):
# special
output_folder = module.name.split("_")[-1] + "_" + "output"
global use_gpu
......@@ -201,7 +213,8 @@ def predict_semantic_segmentation(module, input_img):
predict_method = getattr(module, method_name)
try:
input_img = {"image": input_img}
results = predict_method(data=input_img, use_gpu=use_gpu)
results = predict_method(
data=input_img, use_gpu=use_gpu, batch_size=batch_size)
except Exception as err:
curr = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(time.time()))
print(curr, " - ", err)
......@@ -219,6 +232,10 @@ def predict_semantic_segmentation(module, input_img):
b_body = str(b_body).replace("b'", "").replace("'", "")
b_img = b_head + "," + b_body
base64_list.append(b_img)
results[index]["origin"] = results[index]["origin"].replace(
id + "_", "")
results[index]["processed"] = results[index]["processed"].replace(
id + "_", "")
results[index].update({"base64": b_img})
results_pack.append(results[index])
os.remove(item)
......@@ -260,7 +277,7 @@ def create_app():
@app_instance.route("/predict/image/<module_name>", methods=["POST"])
def predict_image(module_name):
req_id = request.data.get("id")
global use_gpu
global use_gpu, batch_size_dict
img_base64 = request.form.getlist("image")
file_name_list = []
if img_base64 != []:
......@@ -289,7 +306,8 @@ def create_app():
else:
module_type = module.type.split("/")[-1].replace("-", "_").lower()
predict_func = eval("predict_" + module_type)
results = predict_func(module, file_name_list)
batch_size = batch_size_dict.get(module_name, 1)
results = predict_func(module, file_name_list, req_id, batch_size)
r = {"results": str(results)}
return r
......@@ -316,22 +334,25 @@ def create_app():
file_path = req_id + "_" + item.filename
file_list.append(file_path)
item.save(file_path)
results = predict_func(module, data, file_list)
batch_size = batch_size_dict.get(module_name, 1)
results = predict_func(module, data, batch_size, file_list)
return {"results": results}
return app_instance
def config_with_file(configs):
global nlp_module, cv_module
global nlp_module, cv_module, batch_size_dict
nlp_module = []
cv_module = []
batch_size_dict = {}
for item in configs:
print(item)
if item["category"] == "CV":
cv_module.append(item["module"])
elif item["category"] == "NLP":
nlp_module.append(item["module"])
batch_size_dict.update({item["module"]: item["batch_size"]})
def run(is_use_gpu=False, configs=None, port=8866, timeout=60):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册