提交 38f27a53 编写于 作者: W WenmuZhou

merge upstream

...@@ -32,12 +32,10 @@ class SimpleDataSet(Dataset): ...@@ -32,12 +32,10 @@ class SimpleDataSet(Dataset):
self.delimiter = dataset_config.get('delimiter', '\t') self.delimiter = dataset_config.get('delimiter', '\t')
label_file_list = dataset_config.pop('label_file_list') label_file_list = dataset_config.pop('label_file_list')
data_source_num = len(label_file_list) data_source_num = len(label_file_list)
if data_source_num == 1: ratio_list = dataset_config.get("ratio_list", [1.0])
ratio_list = [1.0] if isinstance(ratio_list, (float, int)):
else: ratio_list = [float(ratio_list)] * len(data_source_num)
ratio_list = dataset_config.pop('ratio_list')
assert sum(ratio_list) == 1, "The sum of the ratio_list should be 1."
assert len( assert len(
ratio_list ratio_list
) == data_source_num, "The length of ratio_list should be the same as the file_list." ) == data_source_num, "The length of ratio_list should be the same as the file_list."
...@@ -45,62 +43,32 @@ class SimpleDataSet(Dataset): ...@@ -45,62 +43,32 @@ class SimpleDataSet(Dataset):
self.do_shuffle = loader_config['shuffle'] self.do_shuffle = loader_config['shuffle']
logger.info("Initialize indexs of datasets:%s" % label_file_list) logger.info("Initialize indexs of datasets:%s" % label_file_list)
self.data_lines_list, data_num_list = self.get_image_info_list( self.data_lines = self.get_image_info_list(label_file_list, ratio_list)
label_file_list) self.data_idx_order_list = list(range(len(self.data_lines)))
self.data_idx_order_list = self.dataset_traversal( if mode.lower() == "train":
data_num_list, ratio_list, batch_size)
self.shuffle_data_random() self.shuffle_data_random()
self.ops = create_operators(dataset_config['transforms'], global_config) self.ops = create_operators(dataset_config['transforms'], global_config)
def get_image_info_list(self, file_list): def get_image_info_list(self, file_list, ratio_list):
if isinstance(file_list, str): if isinstance(file_list, str):
file_list = [file_list] file_list = [file_list]
data_lines_list = [] data_lines = []
data_num_list = [] for idx, file in enumerate(file_list):
for file in file_list:
with open(file, "rb") as f: with open(file, "rb") as f:
lines = f.readlines() lines = f.readlines()
data_lines_list.append(lines) lines = random.sample(lines,
data_num_list.append(len(lines)) round(len(lines) * ratio_list[idx]))
return data_lines_list, data_num_list data_lines.extend(lines)
return data_lines
def dataset_traversal(self, data_num_list, ratio_list, batch_size):
select_num_list = []
dataset_num = len(data_num_list)
for dno in range(dataset_num):
select_num = round(batch_size * ratio_list[dno])
select_num = max(select_num, 1)
select_num_list.append(select_num)
data_idx_order_list = []
cur_index_sets = [0] * dataset_num
while True:
finish_read_num = 0
for dataset_idx in range(dataset_num):
cur_index = cur_index_sets[dataset_idx]
if cur_index >= data_num_list[dataset_idx]:
finish_read_num += 1
else:
select_num = select_num_list[dataset_idx]
for sno in range(select_num):
cur_index = cur_index_sets[dataset_idx]
if cur_index >= data_num_list[dataset_idx]:
break
data_idx_order_list.append((dataset_idx, cur_index))
cur_index_sets[dataset_idx] += 1
if finish_read_num == dataset_num:
break
return data_idx_order_list
def shuffle_data_random(self): def shuffle_data_random(self):
if self.do_shuffle: if self.do_shuffle:
for dno in range(len(self.data_lines_list)): random.shuffle(self.data_lines)
random.shuffle(self.data_lines_list[dno])
return return
def __getitem__(self, idx): def __getitem__(self, idx):
dataset_idx, file_idx = self.data_idx_order_list[idx] file_idx = self.data_idx_order_list[idx]
data_line = self.data_lines_list[dataset_idx][file_idx] data_line = self.data_lines[file_idx]
try: try:
data_line = data_line.decode('utf-8') data_line = data_line.decode('utf-8')
substr = data_line.strip("\n").split(self.delimiter) substr = data_line.strip("\n").split(self.delimiter)
......
...@@ -23,7 +23,7 @@ import copy ...@@ -23,7 +23,7 @@ import copy
import numpy as np import numpy as np
import math import math
import time import time
import traceback
import paddle.fluid as fluid import paddle.fluid as fluid
import tools.infer.utility as utility import tools.infer.utility as utility
...@@ -106,10 +106,10 @@ class TextClassifier(object): ...@@ -106,10 +106,10 @@ class TextClassifier(object):
norm_img_batch = fluid.core.PaddleTensor(norm_img_batch) norm_img_batch = fluid.core.PaddleTensor(norm_img_batch)
self.predictor.run([norm_img_batch]) self.predictor.run([norm_img_batch])
prob_out = self.output_tensors[0].copy_to_cpu() prob_out = self.output_tensors[0].copy_to_cpu()
cls_res = self.postprocess_op(prob_out) cls_result = self.postprocess_op(prob_out)
elapse += time.time() - starttime elapse += time.time() - starttime
for rno in range(len(cls_res)): for rno in range(len(cls_result)):
label, score = cls_res[rno] label, score = cls_result[rno]
cls_res[indices[beg_img_no + rno]] = [label, score] cls_res[indices[beg_img_no + rno]] = [label, score]
if '180' in label and score > self.cls_thresh: if '180' in label and score > self.cls_thresh:
img_list[indices[beg_img_no + rno]] = cv2.rotate( img_list[indices[beg_img_no + rno]] = cv2.rotate(
...@@ -133,8 +133,8 @@ def main(args): ...@@ -133,8 +133,8 @@ def main(args):
img_list.append(img) img_list.append(img)
try: try:
img_list, cls_res, predict_time = text_classifier(img_list) img_list, cls_res, predict_time = text_classifier(img_list)
except Exception as e: except:
print(e) logger.info(traceback.format_exc())
logger.info( logger.info(
"ERROR!!!! \n" "ERROR!!!! \n"
"Please read the FAQ:https://github.com/PaddlePaddle/PaddleOCR#faq \n" "Please read the FAQ:https://github.com/PaddlePaddle/PaddleOCR#faq \n"
...@@ -143,10 +143,10 @@ def main(args): ...@@ -143,10 +143,10 @@ def main(args):
"Please set --rec_image_shape='3,32,100' and --rec_char_type='en' ") "Please set --rec_image_shape='3,32,100' and --rec_char_type='en' ")
exit() exit()
for ino in range(len(img_list)): for ino in range(len(img_list)):
print("Predicts of {}:{}".format(valid_image_file_list[ino], cls_res[ logger.info("Predicts of {}:{}".format(valid_image_file_list[ino], cls_res[
ino])) ino]))
print("Total predict time for {} images, cost: {:.3f}".format( logger.info("Total predict time for {} images, cost: {:.3f}".format(
len(img_list), predict_time)) len(img_list), predict_time))
if __name__ == "__main__": if __name__ == "__main__":
main(utility.parse_args()) main(utility.parse_args())
...@@ -178,11 +178,12 @@ if __name__ == "__main__": ...@@ -178,11 +178,12 @@ if __name__ == "__main__":
if count > 0: if count > 0:
total_time += elapse total_time += elapse
count += 1 count += 1
print("Predict time of {}: {}".format(image_file, elapse)) logger.info("Predict time of {}: {}".format(image_file, elapse))
src_im = utility.draw_text_det_res(dt_boxes, image_file) src_im = utility.draw_text_det_res(dt_boxes, image_file)
img_name_pure = os.path.split(image_file)[-1] img_name_pure = os.path.split(image_file)[-1]
img_path = os.path.join(draw_img_save, img_path = os.path.join(draw_img_save,
"det_res_{}".format(img_name_pure)) "det_res_{}".format(img_name_pure))
cv2.imwrite(img_path, src_im) cv2.imwrite(img_path, src_im)
logger.info("The visualized image saved in {}".format(img_path))
if count > 1: if count > 1:
print("Avg Time:", total_time / (count - 1)) logger.info("Avg Time:", total_time / (count - 1))
...@@ -22,7 +22,7 @@ import cv2 ...@@ -22,7 +22,7 @@ import cv2
import numpy as np import numpy as np
import math import math
import time import time
import traceback
import paddle.fluid as fluid import paddle.fluid as fluid
import tools.infer.utility as utility import tools.infer.utility as utility
...@@ -135,8 +135,8 @@ def main(args): ...@@ -135,8 +135,8 @@ def main(args):
img_list.append(img) img_list.append(img)
try: try:
rec_res, predict_time = text_recognizer(img_list) rec_res, predict_time = text_recognizer(img_list)
except Exception as e: except:
print(e) logger.info(traceback.format_exc())
logger.info( logger.info(
"ERROR!!!! \n" "ERROR!!!! \n"
"Please read the FAQ:https://github.com/PaddlePaddle/PaddleOCR#faq \n" "Please read the FAQ:https://github.com/PaddlePaddle/PaddleOCR#faq \n"
...@@ -145,9 +145,9 @@ def main(args): ...@@ -145,9 +145,9 @@ def main(args):
"Please set --rec_image_shape='3,32,100' and --rec_char_type='en' ") "Please set --rec_image_shape='3,32,100' and --rec_char_type='en' ")
exit() exit()
for ino in range(len(img_list)): for ino in range(len(img_list)):
print("Predicts of {}:{}".format(valid_image_file_list[ino], rec_res[ logger.info("Predicts of {}:{}".format(valid_image_file_list[ino], rec_res[
ino])) ino]))
print("Total predict time for {} images, cost: {:.3f}".format( logger.info("Total predict time for {} images, cost: {:.3f}".format(
len(img_list), predict_time)) len(img_list), predict_time))
......
...@@ -23,17 +23,21 @@ import numpy as np ...@@ -23,17 +23,21 @@ import numpy as np
import time import time
from PIL import Image from PIL import Image
import tools.infer.utility as utility import tools.infer.utility as utility
from tools.infer.utility import draw_ocr
import tools.infer.predict_rec as predict_rec import tools.infer.predict_rec as predict_rec
import tools.infer.predict_det as predict_det import tools.infer.predict_det as predict_det
import tools.infer.predict_cls as predict_cls
from ppocr.utils.utility import get_image_file_list, check_and_read_gif from ppocr.utils.utility import get_image_file_list, check_and_read_gif
from ppocr.utils.logging import get_logger from ppocr.utils.logging import get_logger
from tools.infer.utility import draw_ocr_box_txt
class TextSystem(object): class TextSystem(object):
def __init__(self, args): def __init__(self, args):
self.text_detector = predict_det.TextDetector(args) self.text_detector = predict_det.TextDetector(args)
self.text_recognizer = predict_rec.TextRecognizer(args) self.text_recognizer = predict_rec.TextRecognizer(args)
self.use_angle_cls = args.use_angle_cls
if self.use_angle_cls:
self.text_classifier = predict_cls.TextClassifier(args)
def get_rotate_crop_image(self, img, points): def get_rotate_crop_image(self, img, points):
''' '''
...@@ -72,12 +76,13 @@ class TextSystem(object): ...@@ -72,12 +76,13 @@ class TextSystem(object):
bbox_num = len(img_crop_list) bbox_num = len(img_crop_list)
for bno in range(bbox_num): for bno in range(bbox_num):
cv2.imwrite("./output/img_crop_%d.jpg" % bno, img_crop_list[bno]) cv2.imwrite("./output/img_crop_%d.jpg" % bno, img_crop_list[bno])
print(bno, rec_res[bno]) logger.info(bno, rec_res[bno])
def __call__(self, img): def __call__(self, img):
ori_im = img.copy() ori_im = img.copy()
dt_boxes, elapse = self.text_detector(img) dt_boxes, elapse = self.text_detector(img)
print("dt_boxes num : {}, elapse : {}".format(len(dt_boxes), elapse)) logger.info("dt_boxes num : {}, elapse : {}".format(
len(dt_boxes), elapse))
if dt_boxes is None: if dt_boxes is None:
return None, None return None, None
img_crop_list = [] img_crop_list = []
...@@ -88,8 +93,15 @@ class TextSystem(object): ...@@ -88,8 +93,15 @@ class TextSystem(object):
tmp_box = copy.deepcopy(dt_boxes[bno]) tmp_box = copy.deepcopy(dt_boxes[bno])
img_crop = self.get_rotate_crop_image(ori_im, tmp_box) img_crop = self.get_rotate_crop_image(ori_im, tmp_box)
img_crop_list.append(img_crop) img_crop_list.append(img_crop)
if self.use_angle_cls:
img_crop_list, angle_list, elapse = self.text_classifier(
img_crop_list)
logger.info("cls num : {}, elapse : {}".format(
len(img_crop_list), elapse))
rec_res, elapse = self.text_recognizer(img_crop_list) rec_res, elapse = self.text_recognizer(img_crop_list)
print("rec_res num : {}, elapse : {}".format(len(rec_res), elapse)) logger.info("rec_res num : {}, elapse : {}".format(
len(rec_res), elapse))
# self.print_draw_crop_rec_res(img_crop_list, rec_res) # self.print_draw_crop_rec_res(img_crop_list, rec_res)
return dt_boxes, rec_res return dt_boxes, rec_res
...@@ -119,7 +131,8 @@ def main(args): ...@@ -119,7 +131,8 @@ def main(args):
image_file_list = get_image_file_list(args.image_dir) image_file_list = get_image_file_list(args.image_dir)
text_sys = TextSystem(args) text_sys = TextSystem(args)
is_visualize = True is_visualize = True
tackle_img_num = 0 font_path = args.vis_font_path
drop_score = args.drop_score
for image_file in image_file_list: for image_file in image_file_list:
img, flag = check_and_read_gif(image_file) img, flag = check_and_read_gif(image_file)
if not flag: if not flag:
...@@ -128,20 +141,16 @@ def main(args): ...@@ -128,20 +141,16 @@ def main(args):
logger.info("error in loading image:{}".format(image_file)) logger.info("error in loading image:{}".format(image_file))
continue continue
starttime = time.time() starttime = time.time()
tackle_img_num += 1
if not args.use_gpu and args.enable_mkldnn and tackle_img_num % 30 == 0:
text_sys = TextSystem(args)
dt_boxes, rec_res = text_sys(img) dt_boxes, rec_res = text_sys(img)
elapse = time.time() - starttime elapse = time.time() - starttime
print("Predict time of %s: %.3fs" % (image_file, elapse)) logger.info("Predict time of %s: %.3fs" % (image_file, elapse))
drop_score = 0.5
dt_num = len(dt_boxes) dt_num = len(dt_boxes)
for dno in range(dt_num): for dno in range(dt_num):
text, score = rec_res[dno] text, score = rec_res[dno]
if score >= drop_score: if score >= drop_score:
text_str = "%s, %.3f" % (text, score) text_str = "%s, %.3f" % (text, score)
print(text_str) logger.info(text_str)
if is_visualize: if is_visualize:
image = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB)) image = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
...@@ -149,15 +158,20 @@ def main(args): ...@@ -149,15 +158,20 @@ def main(args):
txts = [rec_res[i][0] for i in range(len(rec_res))] txts = [rec_res[i][0] for i in range(len(rec_res))]
scores = [rec_res[i][1] for i in range(len(rec_res))] scores = [rec_res[i][1] for i in range(len(rec_res))]
draw_img = draw_ocr( draw_img = draw_ocr_box_txt(
image, boxes, txts, scores, drop_score=drop_score) image,
boxes,
txts,
scores,
drop_score=drop_score,
font_path=font_path)
draw_img_save = "./inference_results/" draw_img_save = "./inference_results/"
if not os.path.exists(draw_img_save): if not os.path.exists(draw_img_save):
os.makedirs(draw_img_save) os.makedirs(draw_img_save)
cv2.imwrite( cv2.imwrite(
os.path.join(draw_img_save, os.path.basename(image_file)), os.path.join(draw_img_save, os.path.basename(image_file)),
draw_img[:, :, ::-1]) draw_img[:, :, ::-1])
print("The visualized image saved in {}".format( logger.info("The visualized image saved in {}".format(
os.path.join(draw_img_save, os.path.basename(image_file)))) os.path.join(draw_img_save, os.path.basename(image_file))))
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册