diff --git a/tools/infer/predict_det.py b/tools/infer/predict_det.py
index dd1fbaa5b57acd4180a5583f729580748a0bc2ba..2a446cdff2e6231881d8d4334fc4342cdd4bb512 100755
--- a/tools/infer/predict_det.py
+++ b/tools/infer/predict_det.py
@@ -142,8 +142,8 @@ class TextDetector(object):
             outputs.append(output)
         outs_dict = {}
         if self.det_algorithm == "EAST":
-            outs_dict['f_score'] = outputs[0]
-            outs_dict['f_geo'] = outputs[1]
+            outs_dict['f_geo'] = outputs[0]
+            outs_dict['f_score'] = outputs[1]
         else:
             outs_dict['maps'] = outputs[0]
         dt_boxes_list = self.postprocess_op(outs_dict, [ratio_list])
@@ -153,6 +153,8 @@ class TextDetector(object):
         return dt_boxes, elapse
 
 
+from tools.infer.utility import draw_text_det_res
+
 if __name__ == "__main__":
     args = utility.parse_args()
     image_file_list = get_image_file_list(args.image_dir)
@@ -169,14 +171,12 @@ if __name__ == "__main__":
             total_time += elapse
         count += 1
         print("Predict time of %s:" % image_file, elapse)
-        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
-        draw_img = draw_ocr(img, dt_boxes, None, None, False)
-        draw_img_save = "./inference_results/"
-        if not os.path.exists(draw_img_save):
-            os.makedirs(draw_img_save)
-        cv2.imwrite(
-            os.path.join(draw_img_save, os.path.basename(image_file)),
-            draw_img[:, :, ::-1])
-        print("The visualized image saved in {}".format(
-            os.path.join(draw_img_save, os.path.basename(image_file))))
+        img_draw = draw_text_det_res(dt_boxes, image_file)
+        save_dir = "./inference_det/"
+        if not os.path.exists(save_dir):
+            os.makedirs(save_dir)
+        save_path = os.path.join(save_dir, os.path.basename(image_file))
+        cv2.imwrite(save_path, img_draw)
+        print("The visualized image saved in {}".format(save_path))
+
     print("Avg Time:", total_time / (count - 1))
diff --git a/tools/infer/predict_rec.py b/tools/infer/predict_rec.py
index caab4eac9b372aed8e69ff0712b6a302933e03d4..0681cfb5510b5a4c3869d4c815ec9a029554578a 100755
--- a/tools/infer/predict_rec.py
+++ b/tools/infer/predict_rec.py
@@ -114,7 +114,6 @@ if __name__ == "__main__":
         valid_image_file_list.append(image_file)
         img_list.append(img)
     rec_res, predict_time = text_recognizer(img_list)
-    rec_res, predict_time = text_recognizer(img_list)
     for ino in range(len(img_list)):
         print("Predicts of %s:%s" % (valid_image_file_list[ino], rec_res[ino]))
     print("Total predict time for %d images:%.3f" %
diff --git a/tools/infer/utility.py b/tools/infer/utility.py
index 95c2332169d4cbb66c12a4d938cb76555404a670..ff45878709a03c17520788316568beb9095ccebd 100755
--- a/tools/infer/utility.py
+++ b/tools/infer/utility.py
@@ -103,13 +103,12 @@ def create_predictor(args, mode):
     return predictor, input_tensor, output_tensors
 
 
 def draw_text_det_res(dt_boxes, img_path):
     src_im = cv2.imread(img_path)
     for box in dt_boxes:
         box = np.array(box).astype(np.int32).reshape(-1, 2)
         cv2.polylines(src_im, [box], True, color=(255, 255, 0), thickness=2)
-    img_name_pure = img_path.split("/")[-1]
-    cv2.imwrite("./output/%s" % img_name_pure, src_im)
+    return src_im
 
 
 def resize_img(img, input_size=600):
diff --git a/tools/program.py b/tools/program.py
index 6c9e9904c296129708b5cbb8eb97c196cb4c4784..b6318c3bbf99c670752de9c855d8bdbdd16bf09b 100755
--- a/tools/program.py
+++ b/tools/program.py
@@ -191,8 +191,8 @@ def build_export(config, main_prog, startup_prog):
             func_infor = config['Architecture']['function']
             model = create_module(func_infor)(params=config)
             image, outputs = model(mode='export')
-            fetches_var = [outputs[name] for name in outputs]
-            fetches_var_name = [name for name in outputs]
+            fetches_var_name = sorted([name for name in outputs.keys()])
+            fetches_var = [outputs[name] for name in fetches_var_name]
     feeded_var_names = [image.name]
     target_vars = fetches_var
     return feeded_var_names, target_vars, fetches_var_name
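Note (not part of the patch): the EAST output reordering in predict_det.py pairs with the sorting added to build_export in tools/program.py. Assuming the inference predictor returns its outputs in the same order as the exported fetch list, sorting the fetch names alphabetically puts 'f_geo' before 'f_score', so outputs[0] becomes the geometry map and outputs[1] the score map. A minimal sketch of that ordering, using hypothetical placeholder values:

# Hypothetical stand-ins for the EAST head's output variables.
outputs = {"f_score": "score_var", "f_geo": "geo_var"}

# Mirrors the sorting added to build_export in this patch.
fetches_var_name = sorted([name for name in outputs.keys()])
fetches_var = [outputs[name] for name in fetches_var_name]

# fetches_var_name == ['f_geo', 'f_score'], so in predict_det.py
# outputs[0] corresponds to f_geo and outputs[1] to f_score.
print(fetches_var_name, fetches_var)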