未验证 提交 99ee41d8 编写于 作者: Z zhoujun 提交者: GitHub

Merge pull request #1299 from WenmuZhou/fix_predict_system

add predict_cls to predict_system
...@@ -23,7 +23,7 @@ import copy ...@@ -23,7 +23,7 @@ import copy
import numpy as np import numpy as np
import math import math
import time import time
import traceback
import paddle.fluid as fluid import paddle.fluid as fluid
import tools.infer.utility as utility import tools.infer.utility as utility
...@@ -106,10 +106,10 @@ class TextClassifier(object): ...@@ -106,10 +106,10 @@ class TextClassifier(object):
norm_img_batch = fluid.core.PaddleTensor(norm_img_batch) norm_img_batch = fluid.core.PaddleTensor(norm_img_batch)
self.predictor.run([norm_img_batch]) self.predictor.run([norm_img_batch])
prob_out = self.output_tensors[0].copy_to_cpu() prob_out = self.output_tensors[0].copy_to_cpu()
cls_res = self.postprocess_op(prob_out) cls_result = self.postprocess_op(prob_out)
elapse += time.time() - starttime elapse += time.time() - starttime
for rno in range(len(cls_res)): for rno in range(len(cls_result)):
label, score = cls_res[rno] label, score = cls_result[rno]
cls_res[indices[beg_img_no + rno]] = [label, score] cls_res[indices[beg_img_no + rno]] = [label, score]
if '180' in label and score > self.cls_thresh: if '180' in label and score > self.cls_thresh:
img_list[indices[beg_img_no + rno]] = cv2.rotate( img_list[indices[beg_img_no + rno]] = cv2.rotate(
...@@ -133,8 +133,8 @@ def main(args): ...@@ -133,8 +133,8 @@ def main(args):
img_list.append(img) img_list.append(img)
try: try:
img_list, cls_res, predict_time = text_classifier(img_list) img_list, cls_res, predict_time = text_classifier(img_list)
except Exception as e: except:
print(e) logger.info(traceback.format_exc())
logger.info( logger.info(
"ERROR!!!! \n" "ERROR!!!! \n"
"Please read the FAQ:https://github.com/PaddlePaddle/PaddleOCR#faq \n" "Please read the FAQ:https://github.com/PaddlePaddle/PaddleOCR#faq \n"
...@@ -143,10 +143,10 @@ def main(args): ...@@ -143,10 +143,10 @@ def main(args):
"Please set --rec_image_shape='3,32,100' and --rec_char_type='en' ") "Please set --rec_image_shape='3,32,100' and --rec_char_type='en' ")
exit() exit()
for ino in range(len(img_list)): for ino in range(len(img_list)):
print("Predicts of {}:{}".format(valid_image_file_list[ino], cls_res[ logger.info("Predicts of {}:{}".format(valid_image_file_list[ino], cls_res[
ino])) ino]))
print("Total predict time for {} images, cost: {:.3f}".format( logger.info("Total predict time for {} images, cost: {:.3f}".format(
len(img_list), predict_time)) len(img_list), predict_time))
if __name__ == "__main__": if __name__ == "__main__":
main(utility.parse_args()) main(utility.parse_args())
...@@ -178,11 +178,12 @@ if __name__ == "__main__": ...@@ -178,11 +178,12 @@ if __name__ == "__main__":
if count > 0: if count > 0:
total_time += elapse total_time += elapse
count += 1 count += 1
print("Predict time of {}: {}".format(image_file, elapse)) logger.info("Predict time of {}: {}".format(image_file, elapse))
src_im = utility.draw_text_det_res(dt_boxes, image_file) src_im = utility.draw_text_det_res(dt_boxes, image_file)
img_name_pure = os.path.split(image_file)[-1] img_name_pure = os.path.split(image_file)[-1]
img_path = os.path.join(draw_img_save, img_path = os.path.join(draw_img_save,
"det_res_{}".format(img_name_pure)) "det_res_{}".format(img_name_pure))
cv2.imwrite(img_path, src_im) cv2.imwrite(img_path, src_im)
logger.info("The visualized image saved in {}".format(img_path))
if count > 1: if count > 1:
print("Avg Time:", total_time / (count - 1)) logger.info("Avg Time:", total_time / (count - 1))
...@@ -22,7 +22,7 @@ import cv2 ...@@ -22,7 +22,7 @@ import cv2
import numpy as np import numpy as np
import math import math
import time import time
import traceback
import paddle.fluid as fluid import paddle.fluid as fluid
import tools.infer.utility as utility import tools.infer.utility as utility
...@@ -135,8 +135,8 @@ def main(args): ...@@ -135,8 +135,8 @@ def main(args):
img_list.append(img) img_list.append(img)
try: try:
rec_res, predict_time = text_recognizer(img_list) rec_res, predict_time = text_recognizer(img_list)
except Exception as e: except:
print(e) logger.info(traceback.format_exc())
logger.info( logger.info(
"ERROR!!!! \n" "ERROR!!!! \n"
"Please read the FAQ:https://github.com/PaddlePaddle/PaddleOCR#faq \n" "Please read the FAQ:https://github.com/PaddlePaddle/PaddleOCR#faq \n"
...@@ -145,9 +145,9 @@ def main(args): ...@@ -145,9 +145,9 @@ def main(args):
"Please set --rec_image_shape='3,32,100' and --rec_char_type='en' ") "Please set --rec_image_shape='3,32,100' and --rec_char_type='en' ")
exit() exit()
for ino in range(len(img_list)): for ino in range(len(img_list)):
print("Predicts of {}:{}".format(valid_image_file_list[ino], rec_res[ logger.info("Predicts of {}:{}".format(valid_image_file_list[ino], rec_res[
ino])) ino]))
print("Total predict time for {} images, cost: {:.3f}".format( logger.info("Total predict time for {} images, cost: {:.3f}".format(
len(img_list), predict_time)) len(img_list), predict_time))
......
...@@ -23,17 +23,21 @@ import numpy as np ...@@ -23,17 +23,21 @@ import numpy as np
import time import time
from PIL import Image from PIL import Image
import tools.infer.utility as utility import tools.infer.utility as utility
from tools.infer.utility import draw_ocr
import tools.infer.predict_rec as predict_rec import tools.infer.predict_rec as predict_rec
import tools.infer.predict_det as predict_det import tools.infer.predict_det as predict_det
import tools.infer.predict_cls as predict_cls
from ppocr.utils.utility import get_image_file_list, check_and_read_gif from ppocr.utils.utility import get_image_file_list, check_and_read_gif
from ppocr.utils.logging import get_logger from ppocr.utils.logging import get_logger
from tools.infer.utility import draw_ocr_box_txt
class TextSystem(object): class TextSystem(object):
def __init__(self, args): def __init__(self, args):
self.text_detector = predict_det.TextDetector(args) self.text_detector = predict_det.TextDetector(args)
self.text_recognizer = predict_rec.TextRecognizer(args) self.text_recognizer = predict_rec.TextRecognizer(args)
self.use_angle_cls = args.use_angle_cls
if self.use_angle_cls:
self.text_classifier = predict_cls.TextClassifier(args)
def get_rotate_crop_image(self, img, points): def get_rotate_crop_image(self, img, points):
''' '''
...@@ -72,12 +76,12 @@ class TextSystem(object): ...@@ -72,12 +76,12 @@ class TextSystem(object):
bbox_num = len(img_crop_list) bbox_num = len(img_crop_list)
for bno in range(bbox_num): for bno in range(bbox_num):
cv2.imwrite("./output/img_crop_%d.jpg" % bno, img_crop_list[bno]) cv2.imwrite("./output/img_crop_%d.jpg" % bno, img_crop_list[bno])
print(bno, rec_res[bno]) logger.info(bno, rec_res[bno])
def __call__(self, img): def __call__(self, img):
ori_im = img.copy() ori_im = img.copy()
dt_boxes, elapse = self.text_detector(img) dt_boxes, elapse = self.text_detector(img)
print("dt_boxes num : {}, elapse : {}".format(len(dt_boxes), elapse)) logger.info("dt_boxes num : {}, elapse : {}".format(len(dt_boxes), elapse))
if dt_boxes is None: if dt_boxes is None:
return None, None return None, None
img_crop_list = [] img_crop_list = []
...@@ -88,8 +92,14 @@ class TextSystem(object): ...@@ -88,8 +92,14 @@ class TextSystem(object):
tmp_box = copy.deepcopy(dt_boxes[bno]) tmp_box = copy.deepcopy(dt_boxes[bno])
img_crop = self.get_rotate_crop_image(ori_im, tmp_box) img_crop = self.get_rotate_crop_image(ori_im, tmp_box)
img_crop_list.append(img_crop) img_crop_list.append(img_crop)
if self.use_angle_cls:
img_crop_list, angle_list, elapse = self.text_classifier(
img_crop_list)
logger.info("cls num : {}, elapse : {}".format(
len(img_crop_list), elapse))
rec_res, elapse = self.text_recognizer(img_crop_list) rec_res, elapse = self.text_recognizer(img_crop_list)
print("rec_res num : {}, elapse : {}".format(len(rec_res), elapse)) logger.info("rec_res num : {}, elapse : {}".format(len(rec_res), elapse))
# self.print_draw_crop_rec_res(img_crop_list, rec_res) # self.print_draw_crop_rec_res(img_crop_list, rec_res)
return dt_boxes, rec_res return dt_boxes, rec_res
...@@ -119,7 +129,8 @@ def main(args): ...@@ -119,7 +129,8 @@ def main(args):
image_file_list = get_image_file_list(args.image_dir) image_file_list = get_image_file_list(args.image_dir)
text_sys = TextSystem(args) text_sys = TextSystem(args)
is_visualize = True is_visualize = True
tackle_img_num = 0 font_path = args.vis_font_path
drop_score = args.drop_score
for image_file in image_file_list: for image_file in image_file_list:
img, flag = check_and_read_gif(image_file) img, flag = check_and_read_gif(image_file)
if not flag: if not flag:
...@@ -128,20 +139,16 @@ def main(args): ...@@ -128,20 +139,16 @@ def main(args):
logger.info("error in loading image:{}".format(image_file)) logger.info("error in loading image:{}".format(image_file))
continue continue
starttime = time.time() starttime = time.time()
tackle_img_num += 1
if not args.use_gpu and args.enable_mkldnn and tackle_img_num % 30 == 0:
text_sys = TextSystem(args)
dt_boxes, rec_res = text_sys(img) dt_boxes, rec_res = text_sys(img)
elapse = time.time() - starttime elapse = time.time() - starttime
print("Predict time of %s: %.3fs" % (image_file, elapse)) logger.info("Predict time of %s: %.3fs" % (image_file, elapse))
drop_score = 0.5
dt_num = len(dt_boxes) dt_num = len(dt_boxes)
for dno in range(dt_num): for dno in range(dt_num):
text, score = rec_res[dno] text, score = rec_res[dno]
if score >= drop_score: if score >= drop_score:
text_str = "%s, %.3f" % (text, score) text_str = "%s, %.3f" % (text, score)
print(text_str) logger.info(text_str)
if is_visualize: if is_visualize:
image = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB)) image = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
...@@ -149,15 +156,20 @@ def main(args): ...@@ -149,15 +156,20 @@ def main(args):
txts = [rec_res[i][0] for i in range(len(rec_res))] txts = [rec_res[i][0] for i in range(len(rec_res))]
scores = [rec_res[i][1] for i in range(len(rec_res))] scores = [rec_res[i][1] for i in range(len(rec_res))]
draw_img = draw_ocr( draw_img = draw_ocr_box_txt(
image, boxes, txts, scores, drop_score=drop_score) image,
boxes,
txts,
scores,
drop_score=drop_score,
font_path=font_path)
draw_img_save = "./inference_results/" draw_img_save = "./inference_results/"
if not os.path.exists(draw_img_save): if not os.path.exists(draw_img_save):
os.makedirs(draw_img_save) os.makedirs(draw_img_save)
cv2.imwrite( cv2.imwrite(
os.path.join(draw_img_save, os.path.basename(image_file)), os.path.join(draw_img_save, os.path.basename(image_file)),
draw_img[:, :, ::-1]) draw_img[:, :, ::-1])
print("The visualized image saved in {}".format( logger.info("The visualized image saved in {}".format(
os.path.join(draw_img_save, os.path.basename(image_file)))) os.path.join(draw_img_save, os.path.basename(image_file))))
......
...@@ -71,6 +71,7 @@ def parse_args(): ...@@ -71,6 +71,7 @@ def parse_args():
parser.add_argument("--use_space_char", type=str2bool, default=True) parser.add_argument("--use_space_char", type=str2bool, default=True)
parser.add_argument( parser.add_argument(
"--vis_font_path", type=str, default="./doc/simfang.ttf") "--vis_font_path", type=str, default="./doc/simfang.ttf")
parser.add_argument("--drop_score", type=float, default=0.5)
# params for text classifier # params for text classifier
parser.add_argument("--use_angle_cls", type=str2bool, default=False) parser.add_argument("--use_angle_cls", type=str2bool, default=False)
...@@ -202,7 +203,12 @@ def draw_ocr(image, ...@@ -202,7 +203,12 @@ def draw_ocr(image,
return image return image
def draw_ocr_box_txt(image, boxes, txts): def draw_ocr_box_txt(image,
boxes,
txts,
scores=None,
drop_score=0.5,
font_path="./doc/simfang.ttf"):
h, w = image.height, image.width h, w = image.height, image.width
img_left = image.copy() img_left = image.copy()
img_right = Image.new('RGB', (w, h), (255, 255, 255)) img_right = Image.new('RGB', (w, h), (255, 255, 255))
...@@ -212,7 +218,9 @@ def draw_ocr_box_txt(image, boxes, txts): ...@@ -212,7 +218,9 @@ def draw_ocr_box_txt(image, boxes, txts):
random.seed(0) random.seed(0)
draw_left = ImageDraw.Draw(img_left) draw_left = ImageDraw.Draw(img_left)
draw_right = ImageDraw.Draw(img_right) draw_right = ImageDraw.Draw(img_right)
for (box, txt) in zip(boxes, txts): for idx, (box, txt) in enumerate(zip(boxes, txts)):
if scores is not None and scores[idx] < drop_score:
continue
color = (random.randint(0, 255), random.randint(0, 255), color = (random.randint(0, 255), random.randint(0, 255),
random.randint(0, 255)) random.randint(0, 255))
draw_left.polygon(box, fill=color) draw_left.polygon(box, fill=color)
...@@ -222,14 +230,13 @@ def draw_ocr_box_txt(image, boxes, txts): ...@@ -222,14 +230,13 @@ def draw_ocr_box_txt(image, boxes, txts):
box[2][1], box[3][0], box[3][1] box[2][1], box[3][0], box[3][1]
], ],
outline=color) outline=color)
box_height = math.sqrt((box[0][0] - box[3][0])**2 + (box[0][1] - box[3][ box_height = math.sqrt((box[0][0] - box[3][0]) ** 2 + (box[0][1] - box[3][
1])**2) 1]) ** 2)
box_width = math.sqrt((box[0][0] - box[1][0])**2 + (box[0][1] - box[1][ box_width = math.sqrt((box[0][0] - box[1][0]) ** 2 + (box[0][1] - box[1][
1])**2) 1]) ** 2)
if box_height > 2 * box_width: if box_height > 2 * box_width:
font_size = max(int(box_width * 0.9), 10) font_size = max(int(box_width * 0.9), 10)
font = ImageFont.truetype( font = ImageFont.truetype(font_path, font_size, encoding="utf-8")
"./doc/simfang.ttf", font_size, encoding="utf-8")
cur_y = box[0][1] cur_y = box[0][1]
for c in txt: for c in txt:
char_size = font.getsize(c) char_size = font.getsize(c)
...@@ -238,8 +245,7 @@ def draw_ocr_box_txt(image, boxes, txts): ...@@ -238,8 +245,7 @@ def draw_ocr_box_txt(image, boxes, txts):
cur_y += char_size[1] cur_y += char_size[1]
else: else:
font_size = max(int(box_height * 0.8), 10) font_size = max(int(box_height * 0.8), 10)
font = ImageFont.truetype( font = ImageFont.truetype(font_path, font_size, encoding="utf-8")
"./doc/simfang.ttf", font_size, encoding="utf-8")
draw_right.text( draw_right.text(
[box[0][0], box[0][1]], txt, fill=(0, 0, 0), font=font) [box[0][0], box[0][1]], txt, fill=(0, 0, 0), font=font)
img_left = Image.blend(image, img_left, 0.5) img_left = Image.blend(image, img_left, 0.5)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册