未验证 提交 01141454 编写于 作者: D Double_V 提交者: GitHub

Merge pull request #103 from LDOUBLEV/fixocr

fix bug in results visualization
...@@ -25,6 +25,10 @@ from .make_border_map import MakeBorderMap ...@@ -25,6 +25,10 @@ from .make_border_map import MakeBorderMap
class DBProcessTrain(object): class DBProcessTrain(object):
"""
DB pre-process for Train mode
"""
def __init__(self, params): def __init__(self, params):
self.img_set_dir = params['img_set_dir'] self.img_set_dir = params['img_set_dir']
self.image_shape = params['image_shape'] self.image_shape = params['image_shape']
...@@ -109,6 +113,10 @@ class DBProcessTrain(object): ...@@ -109,6 +113,10 @@ class DBProcessTrain(object):
class DBProcessTest(object): class DBProcessTest(object):
"""
DB pre-process for Test mode
"""
def __init__(self, params): def __init__(self, params):
super(DBProcessTest, self).__init__() super(DBProcessTest, self).__init__()
self.resize_type = 0 self.resize_type = 0
...@@ -124,6 +132,10 @@ class DBProcessTest(object): ...@@ -124,6 +132,10 @@ class DBProcessTest(object):
def resize_image_type0(self, im): def resize_image_type0(self, im):
""" """
resize image to a size multiple of 32 which is required by the network resize image to a size multiple of 32 which is required by the network
args:
img(array): array with shape [h, w, c]
return(tuple):
img, (ratio_h, ratio_w)
""" """
max_side_len = self.max_side_len max_side_len = self.max_side_len
h, w, _ = im.shape h, w, _ = im.shape
......
...@@ -107,7 +107,7 @@ def create_predictor(args, mode): ...@@ -107,7 +107,7 @@ def create_predictor(args, mode):
return predictor, input_tensor, output_tensors return predictor, input_tensor, output_tensors
def draw_text_det_res(dt_boxes, img_path, return_img=True): def draw_text_det_res(dt_boxes, img_path):
src_im = cv2.imread(img_path) src_im = cv2.imread(img_path)
for box in dt_boxes: for box in dt_boxes:
box = np.array(box).astype(np.int32).reshape(-1, 2) box = np.array(box).astype(np.int32).reshape(-1, 2)
...@@ -117,10 +117,10 @@ def draw_text_det_res(dt_boxes, img_path, return_img=True): ...@@ -117,10 +117,10 @@ def draw_text_det_res(dt_boxes, img_path, return_img=True):
def resize_img(img, input_size=600): def resize_img(img, input_size=600):
""" """
resize img and limit the longest side of the image to input_size
""" """
img = np.array(img) img = np.array(img)
im_shape = img.shape im_shape = img.shape
im_size_min = np.min(im_shape[0:2])
im_size_max = np.max(im_shape[0:2]) im_size_max = np.max(im_shape[0:2])
im_scale = float(input_size) / float(im_size_max) im_scale = float(input_size) / float(im_size_max)
im = cv2.resize(img, None, None, fx=im_scale, fy=im_scale) im = cv2.resize(img, None, None, fx=im_scale, fy=im_scale)
...@@ -131,7 +131,7 @@ def draw_ocr(image, boxes, txts, scores, draw_txt=True, drop_score=0.5): ...@@ -131,7 +131,7 @@ def draw_ocr(image, boxes, txts, scores, draw_txt=True, drop_score=0.5):
""" """
Visualize the results of OCR detection and recognition Visualize the results of OCR detection and recognition
args: args:
image(Image): image from Image.open image(Image|array): RGB image
boxes(list): boxes with shape(N, 4, 2) boxes(list): boxes with shape(N, 4, 2)
txts(list): the texts txts(list): the texts
scores(list): txxs corresponding scores scores(list): txxs corresponding scores
...@@ -140,31 +140,14 @@ def draw_ocr(image, boxes, txts, scores, draw_txt=True, drop_score=0.5): ...@@ -140,31 +140,14 @@ def draw_ocr(image, boxes, txts, scores, draw_txt=True, drop_score=0.5):
return(array): return(array):
the visualized img the visualized img
""" """
from PIL import Image, ImageDraw, ImageFont img = image
img = image.copy()
draw = ImageDraw.Draw(img)
if scores is None: if scores is None:
scores = [1] * len(boxes) scores = [1] * len(boxes)
for (box, score) in zip(boxes, scores): for (box, score) in zip(boxes, scores):
if score < drop_score: if score < drop_score or math.isnan(score):
continue continue
draw.line([(box[0][0], box[0][1]), (box[1][0], box[1][1])], fill='red') box = np.reshape(np.array(box), [-1, 1, 2]).astype(np.int64)
draw.line([(box[1][0], box[1][1]), (box[2][0], box[2][1])], fill='red') img = cv2.polylines(np.array(image), [box], True, (255, 0, 0), 3)
draw.line([(box[2][0], box[2][1]), (box[3][0], box[3][1])], fill='red')
draw.line([(box[3][0], box[3][1]), (box[0][0], box[0][1])], fill='red')
draw.line(
[(box[0][0] - 1, box[0][1] + 1), (box[1][0] - 1, box[1][1] + 1)],
fill='red')
draw.line(
[(box[1][0] - 1, box[1][1] + 1), (box[2][0] - 1, box[2][1] + 1)],
fill='red')
draw.line(
[(box[2][0] - 1, box[2][1] + 1), (box[3][0] - 1, box[3][1] + 1)],
fill='red')
draw.line(
[(box[3][0] - 1, box[3][1] + 1), (box[0][0] - 1, box[0][1] + 1)],
fill='red')
if draw_txt: if draw_txt:
img = np.array(resize_img(img, input_size=600)) img = np.array(resize_img(img, input_size=600))
...@@ -233,7 +216,7 @@ def text_visual(texts, scores, img_h=400, img_w=600, threshold=0.): ...@@ -233,7 +216,7 @@ def text_visual(texts, scores, img_h=400, img_w=600, threshold=0.):
count, index = 0, 0 count, index = 0, 0
for idx, txt in enumerate(texts): for idx, txt in enumerate(texts):
index += 1 index += 1
if scores[idx] < threshold: if scores[idx] < threshold or math.isnan(scores[idx]):
index -= 1 index -= 1
continue continue
first_line = True first_line = True
...@@ -256,11 +239,11 @@ def text_visual(texts, scores, img_h=400, img_w=600, threshold=0.): ...@@ -256,11 +239,11 @@ def text_visual(texts, scores, img_h=400, img_w=600, threshold=0.):
if first_line: if first_line:
new_txt = str(index) + ': ' + txt + ' ' + '%.3f' % (scores[idx]) new_txt = str(index) + ': ' + txt + ' ' + '%.3f' % (scores[idx])
else: else:
new_txt = " " + txt + " " + '%.3f' % (scores[idx]) new_txt = " " + txt + " " + '%.3f' % (scores[idx])
draw_txt.text((0, gap * (count + 1)), new_txt, txt_color, font=font) draw_txt.text((0, gap * (count + 1)), new_txt, txt_color, font=font)
count += 1 count += 1
# whether add new blank img or not # whether add new blank img or not
if count >= img_h // gap - 1 and idx + 1 < len(texts): if count > img_h // gap - 1 and idx + 1 < len(texts):
txt_img_list.append(np.array(blank_img)) txt_img_list.append(np.array(blank_img))
blank_img, draw_txt = create_blank_img() blank_img, draw_txt = create_blank_img()
count = 0 count = 0
...@@ -270,7 +253,6 @@ def text_visual(texts, scores, img_h=400, img_w=600, threshold=0.): ...@@ -270,7 +253,6 @@ def text_visual(texts, scores, img_h=400, img_w=600, threshold=0.):
blank_img = np.array(txt_img_list[0]) blank_img = np.array(txt_img_list[0])
else: else:
blank_img = np.concatenate(txt_img_list, axis=1) blank_img = np.concatenate(txt_img_list, axis=1)
# cv2.imwrite("./draw_txt.jpg", np.array(blank_img))
return np.array(blank_img) return np.array(blank_img)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册