未验证 提交 e807901b 编写于 作者: D dyning 提交者: GitHub

Merge pull request #39 from LDOUBLEV/fixocr

Fixocr
...@@ -142,8 +142,8 @@ class TextDetector(object): ...@@ -142,8 +142,8 @@ class TextDetector(object):
outputs.append(output) outputs.append(output)
outs_dict = {} outs_dict = {}
if self.det_algorithm == "EAST": if self.det_algorithm == "EAST":
outs_dict['f_score'] = outputs[0] outs_dict['f_geo'] = outputs[0]
outs_dict['f_geo'] = outputs[1] outs_dict['f_score'] = outputs[1]
else: else:
outs_dict['maps'] = outputs[0] outs_dict['maps'] = outputs[0]
dt_boxes_list = self.postprocess_op(outs_dict, [ratio_list]) dt_boxes_list = self.postprocess_op(outs_dict, [ratio_list])
...@@ -153,6 +153,8 @@ class TextDetector(object): ...@@ -153,6 +153,8 @@ class TextDetector(object):
return dt_boxes, elapse return dt_boxes, elapse
from tools.infer.utility import draw_text_det_res
if __name__ == "__main__": if __name__ == "__main__":
args = utility.parse_args() args = utility.parse_args()
image_file_list = get_image_file_list(args.image_dir) image_file_list = get_image_file_list(args.image_dir)
...@@ -169,14 +171,9 @@ if __name__ == "__main__": ...@@ -169,14 +171,9 @@ if __name__ == "__main__":
total_time += elapse total_time += elapse
count += 1 count += 1
print("Predict time of %s:" % image_file, elapse) print("Predict time of %s:" % image_file, elapse)
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) img_draw = draw_text_det_res(dt_boxes, image_file, return_img=True)
draw_img = draw_ocr(img, dt_boxes, None, None, False) save_path = os.path.join("./inference_det/",
draw_img_save = "./inference_results/" os.path.basename(image_file))
if not os.path.exists(draw_img_save): print("The visualized image saved in {}".format(save_path))
os.makedirs(draw_img_save)
cv2.imwrite(
os.path.join(draw_img_save, os.path.basename(image_file)),
draw_img[:, :, ::-1])
print("The visualized image saved in {}".format(
os.path.join(draw_img_save, os.path.basename(image_file))))
print("Avg Time:", total_time / (count - 1)) print("Avg Time:", total_time / (count - 1))
...@@ -114,7 +114,6 @@ if __name__ == "__main__": ...@@ -114,7 +114,6 @@ if __name__ == "__main__":
valid_image_file_list.append(image_file) valid_image_file_list.append(image_file)
img_list.append(img) img_list.append(img)
rec_res, predict_time = text_recognizer(img_list) rec_res, predict_time = text_recognizer(img_list)
rec_res, predict_time = text_recognizer(img_list)
for ino in range(len(img_list)): for ino in range(len(img_list)):
print("Predicts of %s:%s" % (valid_image_file_list[ino], rec_res[ino])) print("Predicts of %s:%s" % (valid_image_file_list[ino], rec_res[ino]))
print("Total predict time for %d images:%.3f" % print("Total predict time for %d images:%.3f" %
......
...@@ -103,13 +103,12 @@ def create_predictor(args, mode): ...@@ -103,13 +103,12 @@ def create_predictor(args, mode):
return predictor, input_tensor, output_tensors return predictor, input_tensor, output_tensors
def draw_text_det_res(dt_boxes, img_path): def draw_text_det_res(dt_boxes, img_path, return_img=True):
src_im = cv2.imread(img_path) src_im = cv2.imread(img_path)
for box in dt_boxes: for box in dt_boxes:
box = np.array(box).astype(np.int32).reshape(-1, 2) box = np.array(box).astype(np.int32).reshape(-1, 2)
cv2.polylines(src_im, [box], True, color=(255, 255, 0), thickness=2) cv2.polylines(src_im, [box], True, color=(255, 255, 0), thickness=2)
img_name_pure = img_path.split("/")[-1] return src_im
cv2.imwrite("./output/%s" % img_name_pure, src_im)
def resize_img(img, input_size=600): def resize_img(img, input_size=600):
......
...@@ -191,8 +191,8 @@ def build_export(config, main_prog, startup_prog): ...@@ -191,8 +191,8 @@ def build_export(config, main_prog, startup_prog):
func_infor = config['Architecture']['function'] func_infor = config['Architecture']['function']
model = create_module(func_infor)(params=config) model = create_module(func_infor)(params=config)
image, outputs = model(mode='export') image, outputs = model(mode='export')
fetches_var = [outputs[name] for name in outputs] fetches_var = sorted([outputs[name] for name in outputs])
fetches_var_name = [name for name in outputs] fetches_var_name = [name for name in fetches_var]
feeded_var_names = [image.name] feeded_var_names = [image.name]
target_vars = fetches_var target_vars = fetches_var
return feeded_var_names, target_vars, fetches_var_name return feeded_var_names, target_vars, fetches_var_name
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册