From 00e9e0797603b93e3e4d32c7c360e090f0adec8b Mon Sep 17 00:00:00 2001
From: LDOUBLEV
Date: Thu, 28 May 2020 15:46:05 +0800
Subject: [PATCH] fix bug in results visualization

---
 ppocr/data/det/db_process.py | 12 ++++++++++++
 tools/infer/utility.py       | 38 ++++++++++--------------------------
 2 files changed, 22 insertions(+), 28 deletions(-)

diff --git a/ppocr/data/det/db_process.py b/ppocr/data/det/db_process.py
index 80762c7c..d347ed44 100644
--- a/ppocr/data/det/db_process.py
+++ b/ppocr/data/det/db_process.py
@@ -25,6 +25,10 @@ from .make_border_map import MakeBorderMap
 
 
 class DBProcessTrain(object):
+    """
+    DB pre-process for Train mode
+    """
+
     def __init__(self, params):
         self.img_set_dir = params['img_set_dir']
         self.image_shape = params['image_shape']
@@ -109,6 +113,10 @@ class DBProcessTrain(object):
 
 
 class DBProcessTest(object):
+    """
+    DB pre-process for Test mode
+    """
+
     def __init__(self, params):
         super(DBProcessTest, self).__init__()
         self.resize_type = 0
@@ -124,6 +132,10 @@ class DBProcessTest(object):
     def resize_image_type0(self, im):
         """
         resize image to a size multiple of 32 which is required by the network
+        args:
+            img(array): array with shape [h, w, c]
+        return(tuple):
+            img, (ratio_h, ratio_w)
         """
         max_side_len = self.max_side_len
         h, w, _ = im.shape
diff --git a/tools/infer/utility.py b/tools/infer/utility.py
index 492721e0..f55a606a 100755
--- a/tools/infer/utility.py
+++ b/tools/infer/utility.py
@@ -107,7 +107,7 @@ def create_predictor(args, mode):
     return predictor, input_tensor, output_tensors
 
 
-def draw_text_det_res(dt_boxes, img_path, return_img=True):
+def draw_text_det_res(dt_boxes, img_path):
     src_im = cv2.imread(img_path)
     for box in dt_boxes:
         box = np.array(box).astype(np.int32).reshape(-1, 2)
@@ -117,10 +117,10 @@ def draw_text_det_res(dt_boxes, img_path):
 
 def resize_img(img, input_size=600):
     """
+    resize img and limit the longest side of the image to input_size
     """
     img = np.array(img)
     im_shape = img.shape
-    im_size_min = np.min(im_shape[0:2])
     im_size_max = np.max(im_shape[0:2])
     im_scale = float(input_size) / float(im_size_max)
     im = cv2.resize(img, None, None, fx=im_scale, fy=im_scale)
@@ -131,7 +131,7 @@ def draw_ocr(image, boxes, txts, scores, draw_txt=True, drop_score=0.5):
     """
     Visualize the results of OCR detection and recognition
     args:
-        image(Image): image from Image.open
+        image(Image|array): RGB image
         boxes(list): boxes with shape(N, 4, 2)
         txts(list): the texts
         scores(list): txxs corresponding scores
@@ -140,31 +140,14 @@ def draw_ocr(image, boxes, txts, scores, draw_txt=True, drop_score=0.5):
         return(array): the visualized img
     """
-    from PIL import Image, ImageDraw, ImageFont
-
-    img = image.copy()
-    draw = ImageDraw.Draw(img)
+    img = image
 
     if scores is None:
         scores = [1] * len(boxes)
     for (box, score) in zip(boxes, scores):
-        if score < drop_score:
+        if score < drop_score or math.isnan(score):
             continue
-        draw.line([(box[0][0], box[0][1]), (box[1][0], box[1][1])], fill='red')
-        draw.line([(box[1][0], box[1][1]), (box[2][0], box[2][1])], fill='red')
-        draw.line([(box[2][0], box[2][1]), (box[3][0], box[3][1])], fill='red')
-        draw.line([(box[3][0], box[3][1]), (box[0][0], box[0][1])], fill='red')
-        draw.line(
-            [(box[0][0] - 1, box[0][1] + 1), (box[1][0] - 1, box[1][1] + 1)],
-            fill='red')
-        draw.line(
-            [(box[1][0] - 1, box[1][1] + 1), (box[2][0] - 1, box[2][1] + 1)],
-            fill='red')
-        draw.line(
-            [(box[2][0] - 1, box[2][1] + 1), (box[3][0] - 1, box[3][1] + 1)],
-            fill='red')
-        draw.line(
-            [(box[3][0] - 1, box[3][1] + 1), (box[0][0] - 1, box[0][1] + 1)],
-            fill='red')
+        box = np.reshape(np.array(box), [-1, 1, 2]).astype(np.int64)
+        img = cv2.polylines(np.array(image), [box], True, (255, 0, 0), 3)
 
     if draw_txt:
         img = np.array(resize_img(img, input_size=600))
@@ -233,7 +216,7 @@ def text_visual(texts, scores, img_h=400, img_w=600, threshold=0.):
     count, index = 0, 0
     for idx, txt in enumerate(texts):
         index += 1
-        if scores[idx] < threshold:
+        if scores[idx] < threshold or math.isnan(scores[idx]):
            index -= 1
            continue
        first_line = True
@@ -256,11 +239,11 @@ def text_visual(texts, scores, img_h=400, img_w=600, threshold=0.):
             if first_line:
                 new_txt = str(index) + ': ' + txt + ' ' + '%.3f' % (scores[idx])
             else:
-                new_txt = " " + txt + " " + '%.3f' % (scores[idx])
+                new_txt = " " + txt + " " + '%.3f' % (scores[idx])
             draw_txt.text((0, gap * (count + 1)), new_txt, txt_color, font=font)
             count += 1
             # whether add new blank img or not
-            if count >= img_h // gap - 1 and idx + 1 < len(texts):
+            if count > img_h // gap - 1 and idx + 1 < len(texts):
                 txt_img_list.append(np.array(blank_img))
                 blank_img, draw_txt = create_blank_img()
                 count = 0
@@ -270,7 +253,6 @@ def text_visual(texts, scores, img_h=400, img_w=600, threshold=0.):
         blank_img = np.array(txt_img_list[0])
     else:
         blank_img = np.concatenate(txt_img_list, axis=1)
-    # cv2.imwrite("./draw_txt.jpg", np.array(blank_img))
     return np.array(blank_img)
 
 
--
GitLab
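
Note: the core of this fix is replacing the eight PIL ImageDraw.line calls in draw_ocr with a single cv2.polylines call and skipping boxes whose score is below drop_score or is NaN. Below is a minimal standalone sketch of that drawing logic; the helper name draw_boxes and the demo canvas are illustrative only (the real function is draw_ocr in tools/infer/utility.py), and the sketch casts points to int32 rather than the patch's int64.

# Minimal sketch of the box-drawing logic introduced by this patch.
# `draw_boxes` and the demo canvas are hypothetical, not part of PaddleOCR.
import math

import cv2
import numpy as np


def draw_boxes(image, boxes, scores=None, drop_score=0.5):
    """Draw detection quadrilaterals with cv2.polylines, skipping low or NaN scores."""
    img = np.array(image)
    if scores is None:
        scores = [1] * len(boxes)
    for box, score in zip(boxes, scores):
        if score < drop_score or math.isnan(score):
            continue
        # polylines takes an integer point array shaped (N, 1, 2); True closes the polygon
        pts = np.reshape(np.array(box), [-1, 1, 2]).astype(np.int32)
        img = cv2.polylines(img, [pts], True, (255, 0, 0), 3)
    return img


if __name__ == "__main__":
    canvas = np.full((200, 300, 3), 255, dtype=np.uint8)  # white canvas
    demo_boxes = [[[30, 40], [220, 40], [220, 90], [30, 90]]]
    out = draw_boxes(canvas, demo_boxes, scores=[0.92])
    cv2.imwrite("demo_boxes.jpg", out)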