提交 20466055 编写于 作者: W WenmuZhou

Add a `--table_output` argument so the output save directory can be configured via args

上级 0bf30fea
......@@ -38,6 +38,8 @@ logger = get_logger()
def parse_args():
parser = utility.init_args()
# params for output
parser.add_argument("--table_output", type=str, default='output/table')
# params for table structure
parser.add_argument("--table_max_len", type=int, default=488)
parser.add_argument("--table_max_text_length", type=int, default=100)
......@@ -65,9 +67,9 @@ class OCRSystem():
layout_res = self.table_layout(copy.deepcopy(img))
for region in layout_res:
x1, y1, x2, y2 = region['bbox']
roi_img = ori_im[y1:y2, x1:x2,:]
roi_img = ori_im[y1:y2, x1:x2, :]
if region['label'] == 'table':
res = self.table_system(roi_img)
res = self.text_system(roi_img)
else:
res = self.text_system(roi_img)
region['res'] = res
......@@ -77,15 +79,15 @@ class OCRSystem():
def main(args):
image_file_list = get_image_file_list(args.image_dir)
image_file_list = image_file_list[args.process_id::args.total_process_num]
excel_save_folder = 'output/table'
os.makedirs(excel_save_folder, exist_ok=True)
save_folder = args.table_output
os.makedirs(save_folder, exist_ok=True)
text_sys = OCRSystem(args)
img_num = len(image_file_list)
for i, image_file in enumerate(image_file_list):
logger.info("[{}/{}] {}".format(i, img_num, image_file))
img, flag = check_and_read_gif(image_file)
imgname = os.path.basename(image_file).split('.')[0]
img_name = os.path.basename(image_file).split('.')[0]
# excel_path = os.path.join(excel_save_folder, + '.xlsx')
if not flag:
img = cv2.imread(image_file)
......@@ -95,11 +97,17 @@ def main(args):
starttime = time.time()
res = text_sys(img)
excel_save_folder = os.path.join(save_folder, img_name)
os.makedirs(excel_save_folder, exist_ok=True)
# save res
for region in res:
if region['label'] == 'table':
# x1, y1, x2, y2 = region['bbox']
excel_path = os.path.join(excel_save_folder, '{}_{}.xlsx'.format(imgname,region['bbox']))
to_excel(region['res'],excel_path)
excel_path = os.path.join(excel_save_folder, '{}.xlsx'.format(region['bbox']))
to_excel(region['res'], excel_path)
else:
with open(os.path.join(excel_save_folder, 'res.txt'),'a',encoding='utf8') as f:
for box, rec_res in zip(*region['res']):
f.write('{}\t{}\n'.format(np.array(box).reshape(-1).tolist(), rec_res))
logger.info(res)
elapse = time.time() - starttime
logger.info("Predict time : {:.3f}s".format(elapse))
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册