未验证 提交 381424bd 编写于 作者: B Bin Long 提交者: GitHub

Merge pull request #374 from ShenYuhan/fix_mask_module_2

fix mask-serving for no-output.
......@@ -65,10 +65,14 @@ def base64s_to_cvmats(base64s):
return base64s
def handle_mask_results(results):
def handle_mask_results(results, data_len):
result = []
if len(results) <= 0:
return results
if len(results) <= 0 and data_len != 0:
return [{
"data": "No face.",
"id": i,
"path": ""
} for i in range(1, data_len + 1)]
_id = results[0]["id"]
_item = {
"data": [],
......@@ -87,6 +91,15 @@ def handle_mask_results(results):
"id": item.get("id", _id)
}
result.append(_item)
for index in range(1, data_len + 1):
if index > len(result):
result.append({"data": "No face.", "id": index, "path": ""})
elif result[index - 1]["id"] != index:
result.insert(index - 1, {
"data": "No face.",
"id": index,
"path": ""
})
return result
......
......@@ -18,7 +18,7 @@ import time
import os
import base64
import logging
import shutil
import glob
cv_module_method = {
"vgg19_imagenet": "predict_classification",
......@@ -140,6 +140,8 @@ def predict_mask(module, input_img, id, batch_size, extra=None, r_img=True):
global use_gpu
method_name = module.desc.attr.map.data['default_signature'].s
predict_method = getattr(module, method_name)
data_len = len(input_img) if input_img is not None else 0
results = []
try:
data = {}
if input_img is not None:
......@@ -153,7 +155,7 @@ def predict_mask(module, input_img, id, batch_size, extra=None, r_img=True):
visualization=r_img,
use_gpu=use_gpu,
batch_size=batch_size)
results = utils.handle_mask_results(results)
results = utils.handle_mask_results(results, data_len)
except Exception as err:
curr = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(time.time()))
print(curr, " - ", err)
......@@ -166,24 +168,36 @@ def predict_mask(module, input_img, id, batch_size, extra=None, r_img=True):
for index in range(len(results)):
results[index]["path"] = ""
results_pack = results
str_id = id + "*"
files_deleted = glob.glob(str_id)
for path in files_deleted:
if os.path.exists(path):
os.remove(path)
else:
input_img = input_img.get("image", [])
for index in range(len(input_img)):
item = input_img[index]
with open(os.path.join(output_folder, item), "rb") as fp:
b_head = "data:image/" + item.split(".")[-1] + ";base64"
b_body = base64.b64encode(fp.read())
b_body = str(b_body).replace("b'", "").replace("'", "")
b_img = b_head + "," + b_body
base64_list.append(b_img)
results[index]["path"] = results[index]["path"].replace(
id + "_", "") if results[index]["path"] != "" \
else ""
results[index].update({"base64": b_img})
file_path = os.path.join(output_folder, item)
if not os.path.exists(file_path):
results_pack.append(results[index])
os.remove(item)
os.remove(os.path.join(output_folder, item))
os.remove(item)
else:
with open(file_path, "rb") as fp:
b_head = "data:image/" + item.split(
".")[-1] + ";base64"
b_body = base64.b64encode(fp.read())
b_body = str(b_body).replace("b'", "").replace(
"'", "")
b_img = b_head + "," + b_body
base64_list.append(b_img)
results[index]["path"] = results[index]["path"].replace(
id + "_", "") if results[index]["path"] != "" \
else ""
results[index].update({"base64": b_img})
results_pack.append(results[index])
os.remove(item)
os.remove(os.path.join(output_folder, item))
else:
results_pack = results
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册