diff --git a/configs/det/det_mv3_db.yml b/configs/det/det_mv3_db.yml index 583c18933f41facfd9fb8c7b84c664eb2e12a385..37e5dc596ea5ff50150199d4e705527af28b2822 100755 --- a/configs/det/det_mv3_db.yml +++ b/configs/det/det_mv3_db.yml @@ -11,7 +11,7 @@ Global: test_batch_size_per_card: 16 image_shape: [3, 640, 640] reader_yml: ./configs/det/det_db_icdar15_reader.yml - pretrain_weights: ./pretrain_models/MobileNetV3_pretrained/MobileNetV3_large_x0_5_pretrained/ + pretrain_weights: ./pretrain_models/MobileNetV3_large_x0_5_pretrained/ checkpoints: save_res_path: ./output/det_db/predicts_db.txt save_inference_dir: diff --git a/ppocr/data/det/dataset_traversal.py b/ppocr/data/det/dataset_traversal.py index 0feedeeb65da67ec887425344ddbe0a59b528719..d2cbeaaf47fe90200a733943b27eb45c2f387b56 100755 --- a/ppocr/data/det/dataset_traversal.py +++ b/ppocr/data/det/dataset_traversal.py @@ -89,13 +89,13 @@ class EvalTestReader(object): def batch_iter_reader(): batch_outs = [] - for img_path, img_name in img_list: + for img_path in img_list: img = cv2.imread(img_path) if img is None: logger.info("load image error:" + img_path) continue outs = process_function(img) - outs.append(img_name) + outs.append(img_path) batch_outs.append(outs) if len(batch_outs) == batch_size: yield batch_outs diff --git a/tools/infer/predict_det.py b/tools/infer/predict_det.py index 85796fc561c15d53e2f43f0dd38281632613d92e..dd1fbaa5b57acd4180a5583f729580748a0bc2ba 100755 --- a/tools/infer/predict_det.py +++ b/tools/infer/predict_det.py @@ -20,11 +20,14 @@ from ppocr.data.det.east_process import EASTProcessTest from ppocr.data.det.db_process import DBProcessTest from ppocr.postprocess.db_postprocess import DBPostProcess from ppocr.postprocess.east_postprocess import EASTPostPocess +from ppocr.utils.utility import get_image_file_list +from tools.infer.utility import draw_ocr import copy import numpy as np import math import time import sys +import os class TextDetector(object): @@ -152,7 +155,7 @@ class TextDetector(object): if __name__ == "__main__": args = utility.parse_args() - image_file_list = utility.get_image_file_list(args.image_dir) + image_file_list = get_image_file_list(args.image_dir) text_detector = TextDetector(args) count = 0 total_time = 0 @@ -166,5 +169,14 @@ if __name__ == "__main__": total_time += elapse count += 1 print("Predict time of %s:" % image_file, elapse) - utility.draw_text_det_res(dt_boxes, image_file) + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) + draw_img = draw_ocr(img, dt_boxes, None, None, False) + draw_img_save = "./inference_results/" + if not os.path.exists(draw_img_save): + os.makedirs(draw_img_save) + cv2.imwrite( + os.path.join(draw_img_save, os.path.basename(image_file)), + draw_img[:, :, ::-1]) + print("The visualized image saved in {}".format( + os.path.join(draw_img_save, os.path.basename(image_file)))) print("Avg Time:", total_time / (count - 1)) diff --git a/tools/infer/utility.py b/tools/infer/utility.py index 8f2bd754201287f2eb652052253d26ebda412121..95c2332169d4cbb66c12a4d938cb76555404a670 100755 --- a/tools/infer/utility.py +++ b/tools/infer/utility.py @@ -127,10 +127,10 @@ def resize_img(img, input_size=600): def draw_ocr(image, boxes, txts, scores, draw_txt=True, drop_score=0.5): from PIL import Image, ImageDraw, ImageFont - w, h = image.size img = image.copy() draw = ImageDraw.Draw(img) - + if scores is None: + scores = [1] * len(boxes) for (box, score) in zip(boxes, scores): if score < drop_score: continue diff --git a/tools/infer_det.py b/tools/infer_det.py index 8d591a654d4b461ab07a81e78cbd046b753f0e96..9da617d1fa0039db39efbaaa913f545956524c94 100755 --- a/tools/infer_det.py +++ b/tools/infer_det.py @@ -40,7 +40,7 @@ set_paddle_flags( ) from paddle import fluid -from ppocr.utils.utility import create_module +from ppocr.utils.utility import create_module, get_image_file_list import program from ppocr.utils.save_load import init_model from ppocr.data.reader_main import reader_main @@ -50,20 +50,18 @@ from ppocr.utils.utility import initial_logger logger = initial_logger() -def draw_det_res(dt_boxes, config, img_name, ino): +def draw_det_res(dt_boxes, config, img, img_name): if len(dt_boxes) > 0: - img_set_path = config['TestReader']['img_set_dir'] - img_path = img_set_path + img_name import cv2 - src_im = cv2.imread(img_path) + src_im = img for box in dt_boxes: box = box.astype(np.int32).reshape((-1, 1, 2)) cv2.polylines(src_im, [box], True, color=(255, 255, 0), thickness=2) - save_det_path = os.path.basename(config['Global'][ + save_det_path = os.path.dirname(config['Global'][ 'save_res_path']) + "/det_results/" if not os.path.exists(save_det_path): os.makedirs(save_det_path) - save_path = os.path.join(save_det_path, "det_{}.jpg".format(img_name)) + save_path = os.path.join(save_det_path, os.path.basename(img_name)) cv2.imwrite(save_path, src_im) logger.info("The detected Image saved in {}".format(save_path)) @@ -103,8 +101,12 @@ def main(): raise Exception("{} not exists!".format(checkpoints)) save_res_path = config['Global']['save_res_path'] + if not os.path.exists(os.path.dirname(save_res_path)): + os.makedirs(os.path.dirname(save_res_path)) with open(save_res_path, "wb") as fout: + test_reader = reader_main(config=config, mode='test') + # image_file_list = get_image_file_list(args.image_dir) tackling_num = 0 for data in test_reader(): img_num = len(data) @@ -128,7 +130,13 @@ def main(): postprocess_params.update(global_params) postprocess = create_module(postprocess_params['function'])\ (params=postprocess_params) - dt_boxes_list = postprocess({"maps": outs[0]}, ratio_list) + if config['Global']['algorithm'] == 'EAST': + dic = {'f_score': outs[0], 'f_geo': outs[1]} + elif config['Global']['algorithm'] == 'DB': + dic = {'maps': outs[0]} + else: + raise Exception("only support algorithm: ['EAST', 'BD']") + dt_boxes_list = postprocess(dic, ratio_list) for ino in range(img_num): dt_boxes = dt_boxes_list[ino] img_name = img_name_list[ino] @@ -139,7 +147,8 @@ def main(): dt_boxes_json.append(tmp_json) otstr = img_name + "\t" + json.dumps(dt_boxes_json) + "\n" fout.write(otstr.encode()) - draw_det_res(dt_boxes, config, img_name, ino) + src_img = cv2.imread(img_name) + draw_det_res(dt_boxes, config, src_img, img_name) logger.info("success!")