+    html.write('<table border="1">\n')
+    html.write('<meta http-equiv="Content-Type" content="text/html; charset=utf-8" />')
+
+    for i, filename in enumerate(sorted(os.listdir(images_dir))):
+        if filename.endswith("txt"): continue
+        print(filename)
+
+        base = "{}".format(filename)
+        # one table row per image: the file name, then the rendered image
+        html.write("<tr>\n")
+        html.write(f'<td> {filename}\n GT')
+        html.write('<td>GT 310\n<img src="%s" width=640></td>' % (base))
+        html.write("</tr>\n")
+
+    html.write('</table>\n')
+    html.write('</html>\n</body>\n')
+    html.close()
+    print("ok")
+
+
+def crop_seal_from_img(label_file, data_dir, save_dir, save_gt_path):
+
+    if not os.path.exists(save_dir):
+        os.makedirs(save_dir)
+
+    datas = open(label_file, 'r').readlines()
+    all_gts = []
+    count = 0
+    for idx, line in enumerate(datas):
+        img_path, label = line.strip().split('\t')
+        img_path = os.path.join(data_dir, img_path)
+
+        label = json.loads(label)
+        src_im = cv2.imread(img_path)
+        if src_im is None:
+            continue
+
+        for c, anno in enumerate(label):
+            seal_poly = anno['seal_box']
+            txt_boxes = anno['polys']
+            txts = anno['texts']
+            ignore_tags = anno['ignore_tags']
+
+            # crop the seal region out of the source image
+            box = poly2box(seal_poly)
+            img_crop = src_im[box[0][1]:box[2][1], box[0][0]:box[2][0], :]
+
+            save_path = os.path.join(save_dir, f"{idx}_{c}.jpg")
+            cv2.imwrite(save_path, np.array(img_crop))
+
+            # shift the text polygons into the cropped image's coordinate frame
+            img_gt = []
+            for i in range(len(txts)):
+                txt_boxes_crop = np.array(txt_boxes[i])
+                txt_boxes_crop[:, 1] -= box[0, 1]
+                txt_boxes_crop[:, 0] -= box[0, 0]
+                img_gt.append({'transcription': txts[i], "points": txt_boxes_crop.tolist(), "ignore_tag": ignore_tags[i]})
+
+            if len(img_gt) >= 1:
+                count += 1
+            save_gt = f"{os.path.basename(save_path)}\t{json.dumps(img_gt)}\n"
+
+            all_gts.append(save_gt)
+
+    print(f"Total images: {len(all_gts)}, useful images (with at least one text region): {count}")
+    if not os.path.exists(os.path.dirname(save_gt_path)):
+        os.makedirs(os.path.dirname(save_gt_path))
+
+    with open(save_gt_path, "w") as f:
+        f.writelines(all_gts)
+    print("Done")
+
+
+
+if __name__ == "__main__":
+
+    # data processing
+    gen_extract_label("./seal_labeled_datas", "./seal_labeled_datas/Label.txt", "./seal_ppocr_gt/seal_det_img.txt", "./seal_ppocr_gt/seal_ppocr_img.txt")
+    vis_seal_ppocr("./seal_labeled_datas", "./seal_ppocr_gt/seal_ppocr_img.txt", "./seal_ppocr_gt/seal_ppocr_vis/")
+    draw_html("./seal_ppocr_gt/seal_ppocr_vis/", "./vis_seal_ppocr.html")
+    seal_ppocr_img_label = "./seal_ppocr_gt/seal_ppocr_img.txt"
+    crop_seal_from_img(seal_ppocr_img_label, "./seal_labeled_datas/", "./seal_img_crop", "./seal_img_crop/label.txt")
+
+```
+
+After processing, the generated files are organized as follows:
+```
+├── seal_img_crop/
+│   ├── 0_0.jpg
+│   ├── ...
+│   └── label.txt
+├── seal_ppocr_gt/
+│   ├── seal_det_img.txt
+│   ├── seal_ppocr_img.txt
+│   └── seal_ppocr_vis/
+│       ├── test1.png
+│       └── ...
+└── vis_seal_ppocr.html
+
+```
+The file `seal_img_crop/label.txt` is the seal recognition label file; its content format is:
+```
+0_0.jpg [{"transcription": "\u7535\u5b50\u56de\u5355", "points": [[29, 73], [96, 73], [96, 90], [29, 90]], "ignore_tag": false}, {"transcription": "\u4e91\u5357\u7701\u519c\u6751\u4fe1\u7528\u793e", "points": [[9, 58], [26, 63], [30, 49], [38, 35], [47, 29], [64, 26], [81, 32], [90, 45], [94, 63], [118, 57], [110, 35], [95, 17], [67, 0], [38, 7], [21, 23], [10, 43]], "ignore_tag": false}, {"transcription": "\u4e13\u7528\u7ae0", "points": [[29, 87], [95, 87], [95, 106], [29, 106]], "ignore_tag": false}]
+```
+It can be used directly to train PaddleOCR's PGNet algorithm.
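+
+As a quick sanity check, each line of this label file can be parsed back into Python; a minimal sketch (the path matches the processing step above):
+```
+import json
+
+# parse the recognition label file produced above
+with open("./seal_img_crop/label.txt") as f:
+    for line in f:
+        img_name, label = line.strip().split("\t")
+        for region in json.loads(label):
+            print(img_name, region["transcription"], len(region["points"]), "points")
+```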
+
+`seal_ppocr_gt/seal_det_img.txt` is the seal detection label file; its content format is:
+```
+img/test1.png [{"polys": [[408, 232], [537, 232], [537, 352], [408, 352]], "cls": 1}]
+```
+To train a seal detection model with the PaddleDetection toolkit, `seal_det_img.txt` needs to be converted to the COCO or VOC annotation format.
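+
+The core of this conversion is collapsing each labeled polygon into an axis-aligned box; a minimal sketch:
+```
+import numpy as np
+
+# the axis-aligned box enclosing one seal polygon
+poly = np.array([[408, 232], [537, 232], [537, 352], [408, 352]])
+xmin, ymin = poly.min(axis=0)
+xmax, ymax = poly.max(axis=0)
+print(xmin, ymin, xmax, ymax)  # 408 232 537 352
+```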
+
+The seal detection annotations can be converted to VOC format directly with the following code.
+
+
+```
+import numpy as np
+import json
+import cv2
+import os
+
+seal_train_gt = "./seal_ppocr_gt/seal_det_img.txt"
+# Note: for demonstration only; in practice, convert the training and test set labels separately
+seal_valid_gt = "./seal_ppocr_gt/seal_det_img.txt"
+
+def gen_main_train_txt(mode='train'):
+    if mode == "train":
+        file_path = seal_train_gt
+    if mode in ['valid', 'test']:
+        file_path = seal_valid_gt
+
+    save_path = f"./seal_VOC/ImageSets/Main/{mode}.txt"
+    save_train_path = f"./seal_VOC/{mode}.txt"
+    if not os.path.exists(os.path.dirname(save_path)):
+        os.makedirs(os.path.dirname(save_path))
+
+    datas = open(file_path, 'r').readlines()
+    img_names = []
+    train_names = []
+    for line in datas:
+        img_name = line.strip().split('\t')[0]
+        img_name = os.path.basename(img_name)
+        (i_name, extension) = os.path.splitext(img_name)
+        t_name = 'JPEGImages/'+str(img_name)+' '+'Annotations/'+str(i_name)+'.xml\n'
+        train_names.append(t_name)
+        img_names.append(i_name + "\n")
+
+    with open(save_train_path, "w") as f:
+        f.writelines(train_names)
+
+    with open(save_path, "w") as f:
+        f.writelines(img_names)
+
+    print(f"{mode} save done")
+
+
+def gen_xml_label(mode='train'):
+    if mode == "train":
+        file_path = seal_train_gt
+    if mode in ['valid', 'test']:
+        file_path = seal_valid_gt
+
+    datas = open(file_path, 'r').readlines()
+    anno_path = "./seal_VOC/Annotations"
+    img_path = "./seal_VOC/JPEGImages"
+
+    if not os.path.exists(anno_path):
+        os.makedirs(anno_path)
+    if not os.path.exists(img_path):
+        os.makedirs(img_path)
+
+    for idx, line in enumerate(datas):
+        img_name, label = line.strip().split('\t')
+        img = cv2.imread(os.path.join("./seal_labeled_datas", img_name))
+        if img is None:
+            continue
+        cv2.imwrite(os.path.join(img_path, os.path.basename(img_name)), img)
+        height, width, c = img.shape
+        img_name = os.path.basename(img_name)
+        (i_name, extension) = os.path.splitext(img_name)
+        label = json.loads(label)
+
+        xml_file = open(os.path.join(anno_path, i_name + '.xml'), 'w')
+        xml_file.write('<annotation>\n')
+        xml_file.write('    <folder>seal_VOC</folder>\n')
+        xml_file.write('    <filename>' + str(img_name) + '</filename>\n')
+        xml_file.write('    <path>' + 'Annotations/' + str(img_name) + '</path>\n')
+        xml_file.write('    <size>\n')
+        xml_file.write('        <width>' + str(width) + '</width>\n')
+        xml_file.write('        <height>' + str(height) + '</height>\n')
+        xml_file.write('        <depth>3</depth>\n')
+        xml_file.write('    </size>\n')
+        xml_file.write('    <segmented>0</segmented>\n')
+
+        for anno in label:
+            poly = anno['polys']
+            if anno['cls'] == 1:
+                gt_cls = 'redseal'
+            xmin = np.min(np.array(poly)[:, 0])
+            ymin = np.min(np.array(poly)[:, 1])
+            xmax = np.max(np.array(poly)[:, 0])
+            ymax = np.max(np.array(poly)[:, 1])
+            xmin, ymin, xmax, ymax = int(xmin), int(ymin), int(xmax), int(ymax)
+            xml_file.write('    <object>\n')
+            xml_file.write('        <name>' + gt_cls + '</name>\n')
+            xml_file.write('        <pose>Unspecified</pose>\n')
+            xml_file.write('        <truncated>0</truncated>\n')
+            xml_file.write('        <difficult>0</difficult>\n')
+            xml_file.write('        <bndbox>\n')
+            xml_file.write('            <xmin>' + str(xmin) + '</xmin>\n')
+            xml_file.write('            <ymin>' + str(ymin) + '</ymin>\n')
+            xml_file.write('            <xmax>' + str(xmax) + '</xmax>\n')
+            xml_file.write('            <ymax>' + str(ymax) + '</ymax>\n')
+            xml_file.write('        </bndbox>\n')
+            xml_file.write('    </object>\n')
+        xml_file.write('</annotation>')
+        xml_file.close()
+    print(f'{mode} xml save done!')
+
+
+gen_main_train_txt()
+gen_main_train_txt('valid')
+gen_xml_label('train')
+gen_xml_label('valid')
+
+```
+
+After data processing, the seal detection data converted to VOC format is stored under the ~/data/seal_VOC directory, organized as follows:
+
+```
+├── Annotations/
+├── ImageSets/
+│   └── Main/
+│       ├── train.txt
+│       └── valid.txt
+├── JPEGImages/
+├── train.txt
+├── valid.txt
+└── label_list.txt
+```
+
+The Annotations directory holds the data labels, the JPEGImages directory holds the image files, and label_list.txt is the class list file for the labeled detection boxes.
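+
+Note that the conversion script above does not generate label_list.txt itself; a minimal sketch to create it, assuming the single 'redseal' class written by gen_xml_label:
+```
+# write the class list file expected by the VOC config
+with open("./seal_VOC/label_list.txt", "w") as f:
+    f.write("redseal\n")
+```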
+
+The next section describes how to train the seal detection model using the PaddleDetection toolkit.
+
+# 4. Seal Detection in Practice
+
+In real applications, seals mostly appear on contracts, invoices, announcements, and similar documents. Seal text recognition needs to exclude the influence of background text in the image, so the seal region in the image must be detected first.
+
+
+Seal detection is easy to implement with the PaddleDetection object detection library. The workflow for training a seal detector with PaddleDetection is:
+
+- choose an algorithm
+- update the dataset paths in the config
+- start training
+
+
+**Algorithm selection**
+
+PaddleDetection offers many detection algorithms to choose from. Since the seal region in each sample is fairly distinct, and inference performance matters, this project uses the PPYOLO algorithm with a MobileNetV3 backbone for seal detection. The corresponding config file is configs/ppyolo/ppyolo_mbv3_large.yml.
+
+
+
+**Modifying the config file**
+
+The default data paths in the config file point to COCO and must be changed to the seal detection data paths. Append the following to the end of 'configs/ppyolo/ppyolo_mbv3_large.yml':
+```
+metric: VOC
+map_type: 11point
+num_classes: 2
+
+TrainDataset:
+  !VOCDataSet
+    dataset_dir: dataset/seal_VOC
+    anno_path: train.txt
+    label_list: label_list.txt
+    data_fields: ['image', 'gt_bbox', 'gt_class', 'difficult']
+
+EvalDataset:
+  !VOCDataSet
+    dataset_dir: dataset/seal_VOC
+    anno_path: valid.txt
+    label_list: label_list.txt
+    data_fields: ['image', 'gt_bbox', 'gt_class', 'difficult']
+
+TestDataset:
+  !ImageFolder
+    anno_path: dataset/seal_VOC/label_list.txt
+```
+
+The data paths set in the config file are under the PaddleDetection/dataset directory, so either move the processed seal detection training data into PaddleDetection/dataset or create a symbolic link:
+
+```
+!ln -s seal_VOC ./PaddleDetection/dataset/
+```
+
+In addition, images contain relatively few seals, so the number of detection boxes kept by NMS post-processing can be reduced: keep_top_k and nms_top_k are changed from 100 and 1000 to 10 and 100 respectively. Append the following to the end of 'configs/ppyolo/ppyolo_mbv3_large.yml' to adjust the post-processing parameters:
+```
+BBoxPostProcess:
+  decode:
+    name: YOLOBox
+    conf_thresh: 0.005
+    downsample_ratio: 32
+    clip_bbox: true
+    scale_x_y: 1.05
+  nms:
+    name: MultiClassNMS
+    keep_top_k: 10  # was 100
+    nms_threshold: 0.45
+    nms_top_k: 100  # was 1000
+    score_threshold: 0.005
+```
+
+
+After these changes, PaddleDetection still needs code to load the seal data: create a seal.py file under the PaddleDetection/ppdet/data/source/ directory and fill it with the following code:
+```
+import os
+import numpy as np
+from ppdet.core.workspace import register, serializable
+from .dataset import DetDataset
+import cv2
+import json
+
+from ppdet.utils.logger import setup_logger
+logger = setup_logger(__name__)
+
+
+@register
+@serializable
+class SealDataSet(DetDataset):
+    """
+    Load the seal detection dataset (PaddleOCR-style txt labels).
+
+    Args:
+        dataset_dir (str): root directory for dataset.
+        image_dir (str): directory for images.
+        anno_path (str): seal annotation file (txt) path.
+        data_fields (list): key name of data dictionary, at least have 'image'.
+        sample_num (int): number of samples to load, -1 means all.
+        load_crowd (bool): whether to load crowded ground-truth.
+            False as default
+        allow_empty (bool): whether to load empty entry. False as default
+        empty_ratio (float): the ratio of empty record number to total
+            record's, if empty_ratio is out of [0. ,1.), do not sample the
+            records and use all the empty entries. 1. as default
+    """
+
+    def __init__(self,
+                 dataset_dir=None,
+                 image_dir=None,
+                 anno_path=None,
+                 data_fields=['image'],
+                 sample_num=-1,
+                 load_crowd=False,
+                 allow_empty=False,
+                 empty_ratio=1.):
+        super(SealDataSet, self).__init__(dataset_dir, image_dir, anno_path,
+                                          data_fields, sample_num)
+        self.load_image_only = False
+        self.load_semantic = False
+        self.load_crowd = load_crowd
+        self.allow_empty = allow_empty
+        self.empty_ratio = empty_ratio
+
+    def _sample_empty(self, records, num):
+        # if empty_ratio is out of [0. ,1.), do not sample the records
+        if self.empty_ratio < 0. or self.empty_ratio >= 1.:
+            return records
+        import random
+        sample_num = min(
+            int(num * self.empty_ratio / (1 - self.empty_ratio)), len(records))
+        records = random.sample(records, sample_num)
+        return records
+
+    def parse_dataset(self):
+        anno_path = os.path.join(self.dataset_dir, self.anno_path)
+        image_dir = os.path.join(self.dataset_dir, self.image_dir)
+
+        records = []
+        empty_records = []
+        ct = 0
+
+        assert anno_path.endswith('.txt'), \
+            'invalid seal_gt file: ' + anno_path
+
+        all_datas = open(anno_path, 'r').readlines()
+
+        for idx, line in enumerate(all_datas):
+            im_path, label = line.strip().split('\t')
+            img_path = os.path.join(image_dir, im_path)
+            label = json.loads(label)
+            im = cv2.imread(img_path)
+            if im is None:
+                continue
+            im_h, im_w, im_c = im.shape
+
+            coco_rec = {
+                'im_file': img_path,
+                'im_id': np.array([idx]),
+                'h': im_h,
+                'w': im_w,
+            } if 'image' in self.data_fields else {}
+
+            if not self.load_image_only:
+                bboxes = []
+                for anno in label:
+                    poly = anno['polys']
+                    # poly to box
+                    x1 = np.min(np.array(poly)[:, 0])
+                    y1 = np.min(np.array(poly)[:, 1])
+                    x2 = np.max(np.array(poly)[:, 0])
+                    y2 = np.max(np.array(poly)[:, 1])
+                    eps = 1e-5
+                    if x2 - x1 > eps and y2 - y1 > eps:
+                        clean_box = [
+                            round(float(x), 3) for x in [x1, y1, x2, y2]
+                        ]
+                        anno = {'clean_box': clean_box, 'gt_cls': int(anno['cls'])}
+                        bboxes.append(anno)
+                    else:
+                        logger.info("invalid box")
+
+                num_bbox = len(bboxes)
+                if num_bbox <= 0:
+                    continue
+
+                gt_bbox = np.zeros((num_bbox, 4), dtype=np.float32)
+                gt_class = np.zeros((num_bbox, 1), dtype=np.int32)
+                is_crowd = np.zeros((num_bbox, 1), dtype=np.int32)
+                # gt_poly = [None] * num_bbox
+
+                for i, box in enumerate(bboxes):
+                    gt_class[i][0] = box['gt_cls']
+                    gt_bbox[i, :] = box['clean_box']
+                    is_crowd[i][0] = 0
+
+                gt_rec = {
+                    'is_crowd': is_crowd,
+                    'gt_class': gt_class,
+                    'gt_bbox': gt_bbox,
+                    # 'gt_poly': gt_poly,
+                }
+
+                for k, v in gt_rec.items():
+                    if k in self.data_fields:
+                        coco_rec[k] = v
+
+            records.append(coco_rec)
+            ct += 1
+            if self.sample_num > 0 and ct >= self.sample_num:
+                break
+        self.roidbs = records
+```
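+
+For PaddleDetection to actually find the new class, it also needs to be exported from the package; a minimal sketch of the extra import, assuming the usual layout of ppdet/data/source/__init__.py (the exact style may vary across PaddleDetection versions):
+```
+# in PaddleDetection/ppdet/data/source/__init__.py
+from . import seal
+from .seal import SealDataSet
+```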
+
+**Starting training**
+
+The command for single-GPU training is:
+```
+!python3 tools/train.py -c configs/ppyolo/ppyolo_mbv3_large.yml --eval
+
+# distributed training command:
+!python3 -m paddle.distributed.launch --gpus 0,1,2,3,4,5,6,7 tools/train.py -c configs/ppyolo/ppyolo_mbv3_large.yml --eval
+```
+
+After training finishes, the model accuracy is printed in the log:
+
+```
+[07/05 11:42:09] ppdet.engine INFO: Eval iter: 0
+[07/05 11:42:14] ppdet.metrics.metrics INFO: Accumulating evaluatation results...
+[07/05 11:42:14] ppdet.metrics.metrics INFO: mAP(0.50, 11point) = 99.31%
+[07/05 11:42:14] ppdet.engine INFO: Total sample number: 112, averge FPS: 26.45840794253432
+[07/05 11:42:14] ppdet.engine INFO: Best test bbox ap is 0.996.
+```
+
+
+We can use the trained model to inspect the prediction results:
+```
+!python3 tools/infer.py -c configs/ppyolo/ppyolo_mbv3_large.yml -o weights=./output/ppyolo_mbv3_large/model_final.pdparams --infer_img=./test.jpg
+```
+The prediction results are shown below:
+
+![](https://ai-studio-static-online.cdn.bcebos.com/0f650c032b0f4d56bd639713924768cc820635e9977845008d233f465291a29e)
+
+# 5. Seal Text Recognition in Practice
+
+After detecting the seal region with PPYOLO, we next use PaddleOCR's text recognition capabilities to recognize the text inside the seal.
+
+The OCR algorithms in PaddleOCR include text detection algorithms, text recognition algorithms, and end-to-end OCR algorithms.
+
+A text detection algorithm locates the text in an image, and a text recognition model then reads the detected text, which together complete the OCR task. This architecture, chaining text detection and text recognition, is called a two-stage OCR approach. Correspondingly, an end-to-end OCR method completes both text detection and recognition with a single algorithm.
+
+
+| Text detection | Text recognition | End-to-end |
+| -------- | -------- | -------- |
+| DB\DB++\EAST\SAST\PSENet | SVTR\CRNN\NRTR\ABINet\SAR\... | PGNet |
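+
+Conceptually, the two-stage pipeline simply chains the two models; a minimal sketch, where detect_text, recognize_text, and crop_region are hypothetical stand-ins for a trained detector, recognizer, and cropping helper:
+```
+def two_stage_ocr(image, detect_text, recognize_text, crop_region):
+    # stage 1: the detector proposes text regions
+    boxes = detect_text(image)
+    # stage 2: the recognizer reads each cropped region
+    return [(box, recognize_text(crop_region(image, box))) for box in boxes]
+```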
+
+
+This section walks through both the end-to-end text detection and recognition approach and the two-stage approach on the seal detection and recognition task.
+
+
+## 5.1 End-to-End Seal Text Recognition in Practice
+
+This section uses PaddleOCR's PGNet algorithm to perform seal text recognition.
+
+PGNet is an end-to-end text detection and recognition algorithm; its config file in PaddleOCR is
+[PaddleOCR/configs/e2e/e2e_r50_vd_pg.yml](https://github.com/PaddlePaddle/PaddleOCR/blob/release/2.6/configs/e2e/e2e_r50_vd_pg.yml).
+
+The steps to run text detection and recognition with PGNet are:
+- modify the config file
+- start training
+
+The default PGNet config file points to the totaltext dataset; for this training run, change it to the label file and data directory produced by the data processing in the previous section.
+
+The training data config after modification:
+```
+Train:
+  dataset:
+    name: PGDataSet
+    data_dir: ./train_data/seal_ppocr
+    label_file_list: [./train_data/seal_ppocr/seal_ppocr_img.txt]
+    ratio_list: [1.0]
+```
+The test data config after modification:
+```
+Eval:
+  dataset:
+    name: PGDataSet
+    data_dir: ./train_data/seal_ppocr_test
+    label_file_list: [./train_data/seal_ppocr_test/seal_ppocr_img.txt]
+```
+
+The command to start training is:
+```
+!python3 tools/train.py -c configs/e2e/e2e_r50_vd_pg.yml
+```
+After training, the final accuracy reaches 47.4%. The small amount of data and its relatively poor quality limit the trained accuracy; with more data involved in training, the accuracy would improve further.
+
+To obtain the trained model, scan the QR code at the end of this article and fill out the questionnaire to join the official PaddleOCR group, where you can get download links for all vertical-domain OCR models, the "Dive into OCR" e-book, and a full set of OCR learning materials 🎁
+
+## 5.2 Two-Stage Seal Text Recognition in Practice
+
+The previous section covered the PGNet training pipeline for seal recognition. This subsection uses PaddleOCR's text detection and text recognition algorithms to detect and then recognize seal text separately.
+
+### 5.2.1 Seal Text Detection
+
+PaddleOCR offers a rich set of text detection algorithms, including DB, DB++, EAST, SAST, and PSENet. Among them, DB, DB++, and PSENet support curved text detection; this project uses DB++ as the curved seal text detection algorithm.
+
+The DB++ text detection model released in PaddleOCR is an English text detection model, so the model needs to be retrained.
+
+
+Modify the data paths in the DB++ config file. The default DB++ config file is [configs/det/det_r50_db++_icdar15.yml](https://github.com/PaddlePaddle/PaddleOCR/blob/release/2.6/configs/det/det_r50_db%2B%2B_icdar15.yml). The training data config after modification:
+
+
+```
+Train:
+  dataset:
+    name: SimpleDataSet
+    data_dir: ./train_data/seal_ppocr
+    label_file_list: [./train_data/seal_ppocr/seal_ppocr_img.txt]
+    ratio_list: [1.0]
+```
+The test data config after modification:
+```
+Eval:
+  dataset:
+    name: SimpleDataSet
+    data_dir: ./train_data/seal_ppocr_test
+    label_file_list: [./train_data/seal_ppocr_test/seal_ppocr_img.txt]
+```
+
+
+Start training:
+```
+!python3 tools/train.py -c configs/det/det_r50_db++_icdar15.yml -o Global.epoch_num=100
+```
+
+Given the small amount of data, Global.epoch_num is set so that only 100 epochs are trained.
+After training, the visualized predictions on the test set look as follows:
+
+![](https://ai-studio-static-online.cdn.bcebos.com/498119182f0a414ab86ae2de752fa31c9ddc3a74a76847049cc57884602cb269)
+
+
+To obtain the trained model, scan the QR code at the end of this article and fill out the questionnaire to join the official PaddleOCR group, where you can get download links for all vertical-domain OCR models, the "Dive into OCR" e-book, and a full set of OCR learning materials 🎁
+
+
+### 5.2.2 Seal Text Recognition
+
+The previous section trained the seal text detection model; this section trains the seal text recognition model. The recognition model uses SVTR, a text recognition algorithm published at IJCAI that is both ultra-lightweight and highly accurate.
+
+Before launching training, the dataset for seal text recognition must be prepared: the following code crops the text regions out of the seals to build the training set.
+
+```
+import os
+import json
+import cv2
+import numpy as np
+
+def get_rotate_crop_image(img, points):
+    '''
+    img_height, img_width = img.shape[0:2]
+    left = int(np.min(points[:, 0]))
+    right = int(np.max(points[:, 0]))
+    top = int(np.min(points[:, 1]))
+    bottom = int(np.max(points[:, 1]))
+    img_crop = img[top:bottom, left:right, :].copy()
+    points[:, 0] = points[:, 0] - left
+    points[:, 1] = points[:, 1] - top
+    '''
+    assert len(points) == 4, "shape of points must be 4*2"
+    img_crop_width = int(
+        max(
+            np.linalg.norm(points[0] - points[1]),
+            np.linalg.norm(points[2] - points[3])))
+    img_crop_height = int(
+        max(
+            np.linalg.norm(points[0] - points[3]),
+            np.linalg.norm(points[1] - points[2])))
+    pts_std = np.float32([[0, 0], [img_crop_width, 0],
+                          [img_crop_width, img_crop_height],
+                          [0, img_crop_height]])
+    M = cv2.getPerspectiveTransform(points, pts_std)
+    dst_img = cv2.warpPerspective(
+        img,
+        M, (img_crop_width, img_crop_height),
+        borderMode=cv2.BORDER_REPLICATE,
+        flags=cv2.INTER_CUBIC)
+    dst_img_height, dst_img_width = dst_img.shape[0:2]
+    if dst_img_height * 1.0 / dst_img_width >= 1.5:
+        dst_img = np.rot90(dst_img)
+    return dst_img
+
+
+def run(data_dir, label_file, save_dir):
+    datas = open(label_file, 'r').readlines()
+    for idx, line in enumerate(datas):
+        img_path, label = line.strip().split('\t')
+        img_path = os.path.join(data_dir, img_path)
+
+        label = json.loads(label)
+        src_im = cv2.imread(img_path)
+        if src_im is None:
+            continue
+
+        for c, anno in enumerate(label):
+            seal_box = anno['seal_box']
+            txt_boxes = anno['polys']
+            # crop and rectify the text region; the perspective transform expects float32 points
+            crop_im = get_rotate_crop_image(src_im, np.array(txt_boxes, dtype=np.float32))
+
+            save_path = os.path.join(save_dir, f'{idx}_{c}.png')
+            if not os.path.exists(save_dir):
+                os.makedirs(save_dir)
+            cv2.imwrite(save_path, crop_im)
+
+```
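+
+A hypothetical invocation, with paths following the directory layout assumed earlier in this tutorial:
+```
+run("./seal_labeled_datas", "./seal_ppocr_gt/seal_ppocr_img.txt", "./train_data/seal_ppocr_crop")
+```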
+
+
+Once the data is processed, the training config file can be set up. The SVTR config file used is [configs/rec/PP-OCRv3/ch_PP-OCRv3_rec.yml](https://github.com/PaddlePaddle/PaddleOCR/blob/release/2.6/configs/rec/PP-OCRv3/ch_PP-OCRv3_rec.yml).
+Modify the training data section of the SVTR config file as follows:
+
+```
+Train:
+  dataset:
+    name: SimpleDataSet
+    data_dir: ./train_data/seal_ppocr_crop/
+    label_file_list:
+      - ./train_data/seal_ppocr_crop/train_list.txt
+```
+
+Modify the evaluation section of the config file:
+```
+Eval:
+  dataset:
+    name: SimpleDataSet
+    data_dir: ./train_data/seal_ppocr_crop_test/
+    label_file_list:
+      - ./train_data/seal_ppocr_crop_test/train_list.txt
+```
+
+Start training:
+
+```
+!python3 tools/train.py -c configs/rec/PP-OCRv3/ch_PP-OCRv3_rec.yml
+
+```
+
+After training, the test set metric reaches 61%.
+Because the data is limited, the acc metric on the training set turns out to be much higher than on the test set, i.e., the model overfits. Adding more data and applying data augmentation can mitigate this problem.
+
+
+
+To obtain the trained model, scan the QR code below and fill out the questionnaire to join the official PaddleOCR group, where you can get download links for all vertical-domain OCR models, the "Dive into OCR" e-book, and a full set of OCR learning materials 🎁
+
+
+![](https://ai-studio-static-online.cdn.bcebos.com/ea32877b717643289dc2121a2e573526d99d0f9eecc64ad4bd8dcf121cb5abde)
diff --git a/configs/rec/rec_r31_robustscanner.yml b/configs/rec/rec_r31_robustscanner.yml
index 40d39aee3c42c18085ace035944dba057b923245..54b69d456ef67c88289504ed5cf8719588ea0803 100644
--- a/configs/rec/rec_r31_robustscanner.yml
+++ b/configs/rec/rec_r31_robustscanner.yml
@@ -12,7 +12,7 @@ Global:
checkpoints:
save_inference_dir:
use_visualdl: False
- infer_img: ./inference/rec_inference
+ infer_img: doc/imgs_words_en/word_10.png
# for data or label process
character_dict_path: ppocr/utils/dict90.txt
max_text_length: &max_text_length 40
diff --git a/configs/rec/rec_r32_gaspin_bilstm_att.yml b/configs/rec/rec_r32_gaspin_bilstm_att.yml
index aea71388f703376120af4d0caf2fa8ccd4d92cce..91d3e104188c04f4372a6e574b7fc07608234a2e 100644
--- a/configs/rec/rec_r32_gaspin_bilstm_att.yml
+++ b/configs/rec/rec_r32_gaspin_bilstm_att.yml
@@ -12,7 +12,7 @@ Global:
checkpoints:
save_inference_dir:
use_visualdl: False
- infer_img: doc/imgs_words/ch/word_1.jpg
+ infer_img: doc/imgs_words_en/word_10.png
# for data or label process
character_dict_path: ./ppocr/utils/dict/spin_dict.txt
max_text_length: 25
diff --git a/configs/table/table_mv3.yml b/configs/table/table_mv3.yml
index 9355a236e15b60db18e8715c2702701fd5d36c71..9d286f4153eaab44bf0d259bbad4a0b3b8ada568 100755
--- a/configs/table/table_mv3.yml
+++ b/configs/table/table_mv3.yml
@@ -43,7 +43,6 @@ Architecture:
Head:
name: TableAttentionHead
hidden_size: 256
- loc_type: 2
max_text_length: *max_text_length
loc_reg_num: &loc_reg_num 4
diff --git a/doc/doc_ch/algorithm_overview.md b/doc/doc_ch/algorithm_overview.md
index 4351fdbfcb501945d6061bc1fcc3585bd76eab7d..e9bcc275d1a7157628188c337a312a49408d207b 100755
--- a/doc/doc_ch/algorithm_overview.md
+++ b/doc/doc_ch/algorithm_overview.md
@@ -101,11 +101,10 @@ PaddleOCR将**持续新增**支持OCR领域前沿算法与模型,**欢迎广
|ViTSTR|ViTSTR| 79.82% | rec_vitstr_none_ce | [训练模型](https://paddleocr.bj.bcebos.com/rec_vitstr_none_ce_train.tar) |
|ABINet|Resnet45| 90.75% | rec_r45_abinet | [训练模型](https://paddleocr.bj.bcebos.com/rec_r45_abinet_train.tar) |
|VisionLAN|Resnet45| 90.30% | rec_r45_visionlan | [训练模型](https://paddleocr.bj.bcebos.com/rec_r45_visionlan_train.tar) |
-|SPIN|ResNet32| 90.00% | rec_r32_gaspin_bilstm_att | coming soon |
-|RobustScanner|ResNet31| 87.77% | rec_r31_robustscanner | coming soon |
+|SPIN|ResNet32| 90.00% | rec_r32_gaspin_bilstm_att | [训练模型](https://paddleocr.bj.bcebos.com/contribution/rec_r32_gaspin_bilstm_att.tar) |
+|RobustScanner|ResNet31| 87.77% | rec_r31_robustscanner | [训练模型](https://paddleocr.bj.bcebos.com/contribution/rec_r31_robustscanner.tar)|
|RFL|ResNetRFL| 88.63% | rec_resnet_rfl_att | [训练模型](https://paddleocr.bj.bcebos.com/contribution/rec_resnet_rfl.tar) |
-
## 2. 端到端算法
diff --git a/doc/doc_ch/algorithm_rec_robustscanner.md b/doc/doc_ch/algorithm_rec_robustscanner.md
index 869f9a7c00b617de87ab3c96326e18e536bc18a8..a1ab3baf038f573c50fcc0b67fd8297459777cf8 100644
--- a/doc/doc_ch/algorithm_rec_robustscanner.md
+++ b/doc/doc_ch/algorithm_rec_robustscanner.md
@@ -26,7 +26,7 @@ Zhang
|模型|骨干网络|配置文件|Acc|下载链接|
| --- | --- | --- | --- | --- |
-|RobustScanner|ResNet31|[rec_r31_robustscanner.yml](../../configs/rec/rec_r31_robustscanner.yml)|87.77%|coming soon|
+|RobustScanner|ResNet31|[rec_r31_robustscanner.yml](../../configs/rec/rec_r31_robustscanner.yml)|87.77%|[训练模型](https://paddleocr.bj.bcebos.com/contribution/rec_r31_robustscanner.tar)|
注:除了使用MJSynth和SynthText两个文字识别数据集外,还加入了[SynthAdd](https://pan.baidu.com/share/init?surl=uV0LtoNmcxbO-0YA7Ch4dg)数据(提取码:627x),和部分真实数据,具体数据细节可以参考论文。
diff --git a/doc/doc_ch/algorithm_rec_spin.md b/doc/doc_ch/algorithm_rec_spin.md
index c996992d2fa6297e6086ffae4bc36ad3e880873d..908a85a417c4070b95630b37b0830e08aae3ff4f 100644
--- a/doc/doc_ch/algorithm_rec_spin.md
+++ b/doc/doc_ch/algorithm_rec_spin.md
@@ -26,7 +26,7 @@ SPIN收录于AAAI2020。主要用于OCR识别任务。在任意形状文本识
|模型|骨干网络|配置文件|Acc|下载链接|
| --- | --- | --- | --- | --- |
-|SPIN|ResNet32|[rec_r32_gaspin_bilstm_att.yml](../../configs/rec/rec_r32_gaspin_bilstm_att.yml)|90.0%|coming soon|
+|SPIN|ResNet32|[rec_r32_gaspin_bilstm_att.yml](../../configs/rec/rec_r32_gaspin_bilstm_att.yml)|90.0%|[训练模型](https://paddleocr.bj.bcebos.com/contribution/rec_r32_gaspin_bilstm_att.tar)|
diff --git a/doc/doc_ch/inference_args.md b/doc/doc_ch/inference_args.md
index 36efc6fbf7a6ec62bc700964dc13261fecdb9bd5..24e7223e397c94fe65b0f26d993fc507b323ed16 100644
--- a/doc/doc_ch/inference_args.md
+++ b/doc/doc_ch/inference_args.md
@@ -7,6 +7,7 @@
| 参数名称 | 类型 | 默认值 | 含义 |
| :--: | :--: | :--: | :--: |
| image_dir | str | 无,必须显式指定 | 图像或者文件夹路径 |
+| page_num | int | 0 | 当输入类型为pdf文件时有效,指定预测前面page_num页,默认预测所有页 |
| vis_font_path | str | "./doc/fonts/simfang.ttf" | 用于可视化的字体路径 |
| drop_score | float | 0.5 | 识别得分小于该值的结果会被丢弃,不会作为返回结果 |
| use_pdserving | bool | False | 是否使用Paddle Serving进行预测 |
diff --git a/doc/doc_en/algorithm_overview_en.md b/doc/doc_en/algorithm_overview_en.md
index f7ef7ad4b8fc162a6ed7b275f3beea6955b86452..90449e1729fcff898f27641d3f777c8f002f6a97 100755
--- a/doc/doc_en/algorithm_overview_en.md
+++ b/doc/doc_en/algorithm_overview_en.md
@@ -98,8 +98,8 @@ Refer to [DTRB](https://arxiv.org/abs/1904.01906), the training and evaluation r
|ViTSTR|ViTSTR| 79.82% | rec_vitstr_none_ce | [trained model](https://paddleocr.bj.bcebos.com/rec_vitstr_none_none_train.tar) |
|ABINet|Resnet45| 90.75% | rec_r45_abinet | [trained model](https://paddleocr.bj.bcebos.com/rec_r45_abinet_train.tar) |
|VisionLAN|Resnet45| 90.30% | rec_r45_visionlan | [trained model](https://paddleocr.bj.bcebos.com/rec_r45_visionlan_train.tar) |
-|SPIN|ResNet32| 90.00% | rec_r32_gaspin_bilstm_att | coming soon |
-|RobustScanner|ResNet31| 87.77% | rec_r31_robustscanner | coming soon |
+|SPIN|ResNet32| 90.00% | rec_r32_gaspin_bilstm_att | [trained model](https://paddleocr.bj.bcebos.com/contribution/rec_r32_gaspin_bilstm_att.tar) |
+|RobustScanner|ResNet31| 87.77% | rec_r31_robustscanner | [trained model](https://paddleocr.bj.bcebos.com/contribution/rec_r31_robustscanner.tar)|
|RFL|ResNetRFL| 88.63% | rec_resnet_rfl_att | [trained model](https://paddleocr.bj.bcebos.com/contribution/rec_resnet_rfl.tar) |
diff --git a/doc/doc_en/algorithm_rec_robustscanner_en.md b/doc/doc_en/algorithm_rec_robustscanner_en.md
index a324a6d547a9e448566276234c750ad4497abf9c..99372d51300e6571827db7181a4f8537db4f1493 100644
--- a/doc/doc_en/algorithm_rec_robustscanner_en.md
+++ b/doc/doc_en/algorithm_rec_robustscanner_en.md
@@ -26,7 +26,7 @@ Using MJSynth and SynthText two text recognition datasets for training, and eval
|Model|Backbone|config|Acc|Download link|
| --- | --- | --- | --- | --- |
-|RobustScanner|ResNet31|[rec_r31_robustscanner.yml](../../configs/rec/rec_r31_robustscanner.yml)|87.77%|coming soon|
+|RobustScanner|ResNet31|[rec_r31_robustscanner.yml](../../configs/rec/rec_r31_robustscanner.yml)|87.77%|[trained model](https://paddleocr.bj.bcebos.com/contribution/rec_r31_robustscanner.tar)|
Note:In addition to using the two text recognition datasets MJSynth and SynthText, [SynthAdd](https://pan.baidu.com/share/init?surl=uV0LtoNmcxbO-0YA7Ch4dg) data (extraction code: 627x), and some real data are used in training, the specific data details can refer to the paper.
diff --git a/doc/doc_en/algorithm_rec_spin_en.md b/doc/doc_en/algorithm_rec_spin_en.md
index 43ab30ce7d96cbb64ddf87156fee3012d666b2bf..03f8d8f69986fc5eb14cfdf294fc25fafb06e269 100644
--- a/doc/doc_en/algorithm_rec_spin_en.md
+++ b/doc/doc_en/algorithm_rec_spin_en.md
@@ -25,7 +25,7 @@ Using MJSynth and SynthText two text recognition datasets for training, and eval
|Model|Backbone|config|Acc|Download link|
| --- | --- | --- | --- | --- |
-|SPIN|ResNet32|[rec_r32_gaspin_bilstm_att.yml](../../configs/rec/rec_r32_gaspin_bilstm_att.yml)|90.0%|coming soon|
+|SPIN|ResNet32|[rec_r32_gaspin_bilstm_att.yml](../../configs/rec/rec_r32_gaspin_bilstm_att.yml)|90.0%|[trained model](https://paddleocr.bj.bcebos.com/contribution/rec_r32_gaspin_bilstm_att.tar) |
diff --git a/doc/doc_en/inference_args_en.md b/doc/doc_en/inference_args_en.md
index f2c99fc8297d47f27a219bf7d8e7f2ea518257f0..b28cd8436da62dcd10f96f17751db9384ebcaa8d 100644
--- a/doc/doc_en/inference_args_en.md
+++ b/doc/doc_en/inference_args_en.md
@@ -7,6 +7,7 @@ When using PaddleOCR for model inference, you can customize the modification par
| parameters | type | default | implication |
| :--: | :--: | :--: | :--: |
| image_dir | str | None, must be specified explicitly | Image or folder path |
+| page_num | int | 0 | Valid when the input type is pdf file, specify to predict the previous page_num pages, all pages are predicted by default |
| vis_font_path | str | "./doc/fonts/simfang.ttf" | font path for visualization |
| drop_score | float | 0.5 | Results with a recognition score less than this value will be discarded and will not be returned as results |
| use_pdserving | bool | False | Whether to use Paddle Serving for prediction |
diff --git a/paddleocr.py b/paddleocr.py
index fa732fc110dc7873f8d89b2ca2a21817a1e6d20d..d34b8f78a56a8d8d5455c18e7e1cf1e75df8f3f9 100644
--- a/paddleocr.py
+++ b/paddleocr.py
@@ -480,10 +480,11 @@ class PaddleOCR(predict_system.TextSystem):
params.rec_image_shape = "3, 48, 320"
else:
params.rec_image_shape = "3, 32, 320"
- # download model
- maybe_download(params.det_model_dir, det_url)
- maybe_download(params.rec_model_dir, rec_url)
- maybe_download(params.cls_model_dir, cls_url)
+ # download model if using paddle infer
+ if not params.use_onnx:
+ maybe_download(params.det_model_dir, det_url)
+ maybe_download(params.rec_model_dir, rec_url)
+ maybe_download(params.cls_model_dir, cls_url)
if params.det_algorithm not in SUPPORT_DET_MODEL:
logger.error('det_algorithm must in {}'.format(SUPPORT_DET_MODEL))
diff --git a/ppocr/modeling/heads/table_att_head.py b/ppocr/modeling/heads/table_att_head.py
index d3c86e22b02e08c18d8d5cb193f2ffb8b07ad785..50910c5b73aa2a41f329d7222fc8c632509b4c91 100644
--- a/ppocr/modeling/heads/table_att_head.py
+++ b/ppocr/modeling/heads/table_att_head.py
@@ -16,6 +16,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import math
import paddle
import paddle.nn as nn
from paddle import ParamAttr
@@ -42,7 +43,6 @@ class TableAttentionHead(nn.Layer):
def __init__(self,
in_channels,
hidden_size,
- loc_type,
in_max_len=488,
max_text_length=800,
out_channels=30,
@@ -57,20 +57,16 @@ class TableAttentionHead(nn.Layer):
self.structure_attention_cell = AttentionGRUCell(
self.input_size, hidden_size, self.out_channels, use_gru=False)
self.structure_generator = nn.Linear(hidden_size, self.out_channels)
- self.loc_type = loc_type
self.in_max_len = in_max_len
- if self.loc_type == 1:
- self.loc_generator = nn.Linear(hidden_size, 4)
+ if self.in_max_len == 640:
+ self.loc_fea_trans = nn.Linear(400, self.max_text_length + 1)
+ elif self.in_max_len == 800:
+ self.loc_fea_trans = nn.Linear(625, self.max_text_length + 1)
else:
- if self.in_max_len == 640:
- self.loc_fea_trans = nn.Linear(400, self.max_text_length + 1)
- elif self.in_max_len == 800:
- self.loc_fea_trans = nn.Linear(625, self.max_text_length + 1)
- else:
- self.loc_fea_trans = nn.Linear(256, self.max_text_length + 1)
- self.loc_generator = nn.Linear(self.input_size + hidden_size,
- loc_reg_num)
+ self.loc_fea_trans = nn.Linear(256, self.max_text_length + 1)
+ self.loc_generator = nn.Linear(self.input_size + hidden_size,
+ loc_reg_num)
def _char_to_onehot(self, input_char, onehot_dim):
input_ont_hot = F.one_hot(input_char, onehot_dim)
@@ -80,16 +76,13 @@ class TableAttentionHead(nn.Layer):
# if and else branch are both needed when you want to assign a variable
# if you modify the var in just one branch, then the modification will not work.
fea = inputs[-1]
- if len(fea.shape) == 3:
- pass
- else:
- last_shape = int(np.prod(fea.shape[2:])) # gry added
- fea = paddle.reshape(fea, [fea.shape[0], fea.shape[1], last_shape])
- fea = fea.transpose([0, 2, 1]) # (NTC)(batch, width, channels)
+ last_shape = int(np.prod(fea.shape[2:])) # gry added
+ fea = paddle.reshape(fea, [fea.shape[0], fea.shape[1], last_shape])
+ fea = fea.transpose([0, 2, 1]) # (NTC)(batch, width, channels)
batch_size = fea.shape[0]
hidden = paddle.zeros((batch_size, self.hidden_size))
- output_hiddens = []
+ output_hiddens = paddle.zeros((batch_size, self.max_text_length + 1, self.hidden_size))
if self.training and targets is not None:
structure = targets[0]
for i in range(self.max_text_length + 1):
@@ -97,7 +90,8 @@ class TableAttentionHead(nn.Layer):
structure[:, i], onehot_dim=self.out_channels)
(outputs, hidden), alpha = self.structure_attention_cell(
hidden, fea, elem_onehots)
- output_hiddens.append(paddle.unsqueeze(outputs, axis=1))
+ output_hiddens[:, i, :] = outputs
+ # output_hiddens.append(paddle.unsqueeze(outputs, axis=1))
output = paddle.concat(output_hiddens, axis=1)
structure_probs = self.structure_generator(output)
if self.loc_type == 1:
@@ -118,30 +112,25 @@ class TableAttentionHead(nn.Layer):
outputs = None
alpha = None
max_text_length = paddle.to_tensor(self.max_text_length)
- i = 0
- while i < max_text_length + 1:
+ for i in range(max_text_length + 1):
elem_onehots = self._char_to_onehot(
temp_elem, onehot_dim=self.out_channels)
(outputs, hidden), alpha = self.structure_attention_cell(
hidden, fea, elem_onehots)
- output_hiddens.append(paddle.unsqueeze(outputs, axis=1))
+ output_hiddens[:, i, :] = outputs
+ # output_hiddens.append(paddle.unsqueeze(outputs, axis=1))
structure_probs_step = self.structure_generator(outputs)
temp_elem = structure_probs_step.argmax(axis=1, dtype="int32")
- i += 1
- output = paddle.concat(output_hiddens, axis=1)
+ output = output_hiddens
structure_probs = self.structure_generator(output)
structure_probs = F.softmax(structure_probs)
- if self.loc_type == 1:
- loc_preds = self.loc_generator(output)
- loc_preds = F.sigmoid(loc_preds)
- else:
- loc_fea = fea.transpose([0, 2, 1])
- loc_fea = self.loc_fea_trans(loc_fea)
- loc_fea = loc_fea.transpose([0, 2, 1])
- loc_concat = paddle.concat([output, loc_fea], axis=2)
- loc_preds = self.loc_generator(loc_concat)
- loc_preds = F.sigmoid(loc_preds)
+ loc_fea = fea.transpose([0, 2, 1])
+ loc_fea = self.loc_fea_trans(loc_fea)
+ loc_fea = loc_fea.transpose([0, 2, 1])
+ loc_concat = paddle.concat([output, loc_fea], axis=2)
+ loc_preds = self.loc_generator(loc_concat)
+ loc_preds = F.sigmoid(loc_preds)
return {'structure_probs': structure_probs, 'loc_preds': loc_preds}
diff --git a/ppstructure/table/README.md b/ppstructure/table/README.md
index 08635516ba8301e6f98f175e5eba8c0a97b1708e..1d082f3878c56e42d175d13c75e1fe17916e7781 100644
--- a/ppstructure/table/README.md
+++ b/ppstructure/table/README.md
@@ -114,7 +114,7 @@ python3 table/eval_table.py \
--det_model_dir=path/to/det_model_dir \
--rec_model_dir=path/to/rec_model_dir \
--table_model_dir=path/to/table_model_dir \
- --image_dir=../doc/table/1.png \
+ --image_dir=docs/table/table.jpg \
--rec_char_dict_path=../ppocr/utils/dict/table_dict.txt \
--table_char_dict_path=../ppocr/utils/dict/table_structure_dict.txt \
--det_limit_side_len=736 \
@@ -145,6 +145,7 @@ python3 table/eval_table.py \
--table_char_dict_path=../ppocr/utils/dict/table_structure_dict.txt \
--det_limit_side_len=736 \
--det_limit_type=min \
+ --rec_image_shape=3,32,320 \
--gt_path=path/to/gt.txt
```
diff --git a/ppstructure/table/README_ch.md b/ppstructure/table/README_ch.md
index 1ef126261d9ce832cd1919a1b3991f341add998c..feccb70adfe20fa8c1cd06f33a10ee6fa043e69e 100644
--- a/ppstructure/table/README_ch.md
+++ b/ppstructure/table/README_ch.md
@@ -118,7 +118,7 @@ python3 table/eval_table.py \
--det_model_dir=path/to/det_model_dir \
--rec_model_dir=path/to/rec_model_dir \
--table_model_dir=path/to/table_model_dir \
- --image_dir=../doc/table/1.png \
+ --image_dir=docs/table/table.jpg \
--rec_char_dict_path=../ppocr/utils/dict/table_dict.txt \
--table_char_dict_path=../ppocr/utils/dict/table_structure_dict.txt \
--det_limit_side_len=736 \
@@ -149,6 +149,7 @@ python3 table/eval_table.py \
--table_char_dict_path=../ppocr/utils/dict/table_structure_dict.txt \
--det_limit_side_len=736 \
--det_limit_type=min \
+ --rec_image_shape=3,32,320 \
--gt_path=path/to/gt.txt
```
diff --git a/requirements.txt b/requirements.txt
index 43cd8c1b082768ebad44a5cf58fc31980ebfe891..7a018b50952a876b4839eabbd72fac09d2bbd73b 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -15,3 +15,4 @@ premailer
openpyxl
attrdict
Polygon3
+PyMuPDF==1.18.7
diff --git a/test_tipc/configs/en_table_structure/model_linux_gpu_normal_normal_paddle2onnx_python_linux_cpu.txt b/test_tipc/configs/en_table_structure/model_linux_gpu_normal_normal_paddle2onnx_python_linux_cpu.txt
new file mode 100644
index 0000000000000000000000000000000000000000..068c4c6b1d2655b9dcda1120425de7d52d0d543d
--- /dev/null
+++ b/test_tipc/configs/en_table_structure/model_linux_gpu_normal_normal_paddle2onnx_python_linux_cpu.txt
@@ -0,0 +1,17 @@
+===========================paddle2onnx_params===========================
+model_name:en_table_structure
+python:python3.7
+2onnx: paddle2onnx
+--det_model_dir:./inference/en_ppocr_mobile_v2.0_table_structure_infer/
+--model_filename:inference.pdmodel
+--params_filename:inference.pdiparams
+--det_save_file:./inference/en_ppocr_mobile_v2.0_table_structure_infer/model.onnx
+--rec_model_dir:
+--rec_save_file:
+--opset_version:10
+--enable_onnx_checker:True
+inference:ppstructure/table/predict_structure.py --table_char_dict_path=./ppocr/utils/dict/table_structure_dict.txt
+--use_gpu:True|False
+--det_model_dir:
+--rec_model_dir:
+--image_dir:./ppstructure/docs/table/table.jpg
\ No newline at end of file
diff --git a/test_tipc/configs/layoutxlm_ser/train_linux_gpu_fleet_normal_infer_python_linux_gpu_cpu.txt b/test_tipc/configs/layoutxlm_ser/train_linux_gpu_fleet_normal_infer_python_linux_gpu_cpu.txt
new file mode 100644
index 0000000000000000000000000000000000000000..96b43ceda84376f4d27e134d245229922d667e7e
--- /dev/null
+++ b/test_tipc/configs/layoutxlm_ser/train_linux_gpu_fleet_normal_infer_python_linux_gpu_cpu.txt
@@ -0,0 +1,53 @@
+===========================train_params===========================
+model_name:layoutxlm_ser
+python:python3.7
+gpu_list:192.168.0.1,192.168.0.2;0,1
+Global.use_gpu:True
+Global.auto_cast:fp32
+Global.epoch_num:lite_train_lite_infer=1|whole_train_whole_infer=17
+Global.save_model_dir:./output/
+Train.loader.batch_size_per_card:lite_train_lite_infer=4|whole_train_whole_infer=8
+Architecture.Backbone.checkpoints:null
+train_model_name:latest
+train_infer_img_dir:ppstructure/docs/kie/input/zh_val_42.jpg
+null:null
+##
+trainer:norm_train
+norm_train:tools/train.py -c test_tipc/configs/layoutxlm_ser/ser_layoutxlm_xfund_zh.yml -o
+pact_train:null
+fpgm_train:null
+distill_train:null
+null:null
+null:null
+##
+===========================eval_params===========================
+eval:null
+null:null
+##
+===========================infer_params===========================
+Global.save_inference_dir:./output/
+Architecture.Backbone.checkpoints:
+norm_export:tools/export_model.py -c test_tipc/configs/layoutxlm_ser/ser_layoutxlm_xfund_zh.yml -o
+quant_export:
+fpgm_export:
+distill_export:null
+export1:null
+export2:null
+##
+infer_model:null
+infer_export:null
+infer_quant:False
+inference:ppstructure/kie/predict_kie_token_ser.py --kie_algorithm=LayoutXLM --ser_dict_path=train_data/XFUND/class_list_xfun.txt --output=output
+--use_gpu:False
+--enable_mkldnn:False
+--cpu_threads:6
+--rec_batch_num:1
+--use_tensorrt:False
+--precision:fp32
+--ser_model_dir:
+--image_dir:./ppstructure/docs/kie/input/zh_val_42.jpg
+null:null
+--benchmark:False
+null:null
+===========================infer_benchmark_params==========================
+random_infer_input:[{float32,[3,224,224]}]
diff --git a/test_tipc/configs/layoutxlm_ser/train_linux_gpu_normal_amp_infer_python_linux_gpu_cpu.txt b/test_tipc/configs/layoutxlm_ser/train_linux_gpu_normal_amp_infer_python_linux_gpu_cpu.txt
new file mode 100644
index 0000000000000000000000000000000000000000..835395784022f4fcefbfff084dcef8a7bc2a146d
--- /dev/null
+++ b/test_tipc/configs/layoutxlm_ser/train_linux_gpu_normal_amp_infer_python_linux_gpu_cpu.txt
@@ -0,0 +1,53 @@
+===========================train_params===========================
+model_name:layoutxlm_ser
+python:python3.7
+gpu_list:0|0,1
+Global.use_gpu:True|True
+Global.auto_cast:amp
+Global.epoch_num:lite_train_lite_infer=1|whole_train_whole_infer=17
+Global.save_model_dir:./output/
+Train.loader.batch_size_per_card:lite_train_lite_infer=4|whole_train_whole_infer=8
+Architecture.Backbone.checkpoints:null
+train_model_name:latest
+train_infer_img_dir:ppstructure/docs/kie/input/zh_val_42.jpg
+null:null
+##
+trainer:norm_train
+norm_train:tools/train.py -c test_tipc/configs/layoutxlm_ser/ser_layoutxlm_xfund_zh.yml -o Global.print_batch_step=1 Global.eval_batch_step=[1000,1000] Train.loader.shuffle=false
+pact_train:null
+fpgm_train:null
+distill_train:null
+null:null
+null:null
+##
+===========================eval_params===========================
+eval:null
+null:null
+##
+===========================infer_params===========================
+Global.save_inference_dir:./output/
+Architecture.Backbone.checkpoints:
+norm_export:tools/export_model.py -c test_tipc/configs/layoutxlm_ser/ser_layoutxlm_xfund_zh.yml -o
+quant_export:
+fpgm_export:
+distill_export:null
+export1:null
+export2:null
+##
+infer_model:null
+infer_export:null
+infer_quant:False
+inference:ppstructure/kie/predict_kie_token_ser.py --kie_algorithm=LayoutXLM --ser_dict_path=train_data/XFUND/class_list_xfun.txt --output=output
+--use_gpu:True|False
+--enable_mkldnn:False
+--cpu_threads:6
+--rec_batch_num:1
+--use_tensorrt:False
+--precision:fp32
+--ser_model_dir:
+--image_dir:./ppstructure/docs/kie/input/zh_val_42.jpg
+null:null
+--benchmark:False
+null:null
+===========================infer_benchmark_params==========================
+random_infer_input:[{float32,[3,224,224]}]
diff --git a/test_tipc/configs/slanet/model_linux_gpu_normal_normal_paddle2onnx_python_linux_cpu.txt b/test_tipc/configs/slanet/model_linux_gpu_normal_normal_paddle2onnx_python_linux_cpu.txt
new file mode 100644
index 0000000000000000000000000000000000000000..45e4e9e858914dd8596cef10625df8160afe45fb
--- /dev/null
+++ b/test_tipc/configs/slanet/model_linux_gpu_normal_normal_paddle2onnx_python_linux_cpu.txt
@@ -0,0 +1,17 @@
+===========================paddle2onnx_params===========================
+model_name:slanet
+python:python3.7
+2onnx: paddle2onnx
+--det_model_dir:./inference/ch_ppstructure_mobile_v2.0_SLANet_infer/
+--model_filename:inference.pdmodel
+--params_filename:inference.pdiparams
+--det_save_file:./inference/ch_ppstructure_mobile_v2.0_SLANet_infer/model.onnx
+--rec_model_dir:
+--rec_save_file:
+--opset_version:10
+--enable_onnx_checker:True
+inference:ppstructure/table/predict_structure.py --table_char_dict_path=./ppocr/utils/dict/table_structure_dict_ch.txt
+--use_gpu:True|False
+--det_model_dir:
+--rec_model_dir:
+--image_dir:./ppstructure/docs/table/table.jpg
\ No newline at end of file
diff --git a/test_tipc/configs/slanet/train_linux_gpu_fleet_normal_infer_python_linux_gpu_cpu.txt b/test_tipc/configs/slanet/train_linux_gpu_fleet_normal_infer_python_linux_gpu_cpu.txt
new file mode 100644
index 0000000000000000000000000000000000000000..4c9d8d654ad286264e8511c788e278cdbcd52ec9
--- /dev/null
+++ b/test_tipc/configs/slanet/train_linux_gpu_fleet_normal_infer_python_linux_gpu_cpu.txt
@@ -0,0 +1,53 @@
+===========================train_params===========================
+model_name:slanet
+python:python3.7
+gpu_list:192.168.0.1,192.168.0.2;0,1
+Global.use_gpu:True
+Global.auto_cast:fp32
+Global.epoch_num:lite_train_lite_infer=3|whole_train_whole_infer=50
+Global.save_model_dir:./output/
+Train.loader.batch_size_per_card:lite_train_lite_infer=16|whole_train_whole_infer=128
+Global.pretrained_model:./pretrain_models/en_ppstructure_mobile_v2.0_SLANet_train/best_accuracy
+train_model_name:latest
+train_infer_img_dir:./ppstructure/docs/table/table.jpg
+null:null
+##
+trainer:norm_train
+norm_train:tools/train.py -c test_tipc/configs/slanet/SLANet.yml -o
+pact_train:null
+fpgm_train:null
+distill_train:null
+null:null
+null:null
+##
+===========================eval_params===========================
+eval:null
+null:null
+##
+===========================infer_params===========================
+Global.save_inference_dir:./output/
+Global.checkpoints:
+norm_export:tools/export_model.py -c test_tipc/configs/slanet/SLANet.yml -o
+quant_export:
+fpgm_export:
+distill_export:null
+export1:null
+export2:null
+##
+infer_model:./inference/en_ppstructure_mobile_v2.0_SLANet_train
+infer_export:null
+infer_quant:False
+inference:ppstructure/table/predict_table.py --det_model_dir=./inference/en_ppocr_mobile_v2.0_table_det_infer --rec_model_dir=./inference/en_ppocr_mobile_v2.0_table_rec_infer --rec_char_dict_path=./ppocr/utils/dict/table_dict.txt --table_char_dict_path=./ppocr/utils/dict/table_structure_dict.txt --image_dir=./ppstructure/docs/table/table.jpg --det_limit_side_len=736 --det_limit_type=min --output ./output/table
+--use_gpu:False
+--enable_mkldnn:False
+--cpu_threads:6
+--rec_batch_num:1
+--use_tensorrt:False
+--precision:fp32
+--table_model_dir:
+--image_dir:./ppstructure/docs/table/table.jpg
+null:null
+--benchmark:False
+null:null
+===========================infer_benchmark_params==========================
+random_infer_input:[{float32,[3,488,488]}]
diff --git a/test_tipc/prepare.sh b/test_tipc/prepare.sh
index ecb1e36bb1bb83c6ee2dcf1cb243e6ee60de5dd8..688deac0f379b50865fe6739529f9301ebcd919b 100644
--- a/test_tipc/prepare.sh
+++ b/test_tipc/prepare.sh
@@ -700,10 +700,18 @@ if [ ${MODE} = "cpp_infer" ];then
wget -nc -P ./inference https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/ch_PP-OCRv3_rec_infer.tar --no-check-certificate
cd ./inference && tar xf ch_PP-OCRv3_det_infer.tar && tar xf ch_PP-OCRv3_rec_infer.tar && tar xf ch_det_data_50.tar && cd ../
elif [[ ${model_name} =~ "en_table_structure" ]];then
- wget -nc -P ./inference/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_structure_infer.tar --no-check-certificate
wget -nc -P ./inference/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_det_infer.tar --no-check-certificate
wget -nc -P ./inference/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_rec_infer.tar --no-check-certificate
- cd ./inference/ && tar xf en_ppocr_mobile_v2.0_table_structure_infer.tar && tar xf en_ppocr_mobile_v2.0_table_det_infer.tar && tar xf en_ppocr_mobile_v2.0_table_rec_infer.tar && cd ../
+
+ cd ./inference/ && tar xf en_ppocr_mobile_v2.0_table_det_infer.tar && tar xf en_ppocr_mobile_v2.0_table_rec_infer.tar
+ if [ ${model_name} == "en_table_structure" ]; then
+ wget -nc https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_structure_infer.tar --no-check-certificate
+ tar xf en_ppocr_mobile_v2.0_table_structure_infer.tar
+ elif [ ${model_name} == "en_table_structure_PACT" ]; then
+ wget -nc https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_structure_slim_infer.tar --no-check-certificate
+ tar xf en_ppocr_mobile_v2.0_table_structure_slim_infer.tar
+ fi
+ cd ../
elif [[ ${model_name} =~ "slanet" ]];then
wget -nc -P ./inference/ https://paddleocr.bj.bcebos.com/ppstructure/models/slanet/ch_ppstructure_mobile_v2.0_SLANet_infer.tar --no-check-certificate
wget -nc -P ./inference/ https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/ch_PP-OCRv3_det_infer.tar --no-check-certificate
@@ -791,6 +799,12 @@ if [ ${MODE} = "paddle2onnx_infer" ];then
wget -nc -P ./inference https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/ch_PP-OCRv3_det_infer.tar --no-check-certificate
wget -nc -P ./inference https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/ch_PP-OCRv3_rec_infer.tar --no-check-certificate
cd ./inference && tar xf ch_PP-OCRv3_det_infer.tar && tar xf ch_PP-OCRv3_rec_infer.tar && cd ../
+ elif [[ ${model_name} =~ "slanet" ]];then
+ wget -nc -P ./inference/ https://paddleocr.bj.bcebos.com/ppstructure/models/slanet/ch_ppstructure_mobile_v2.0_SLANet_infer.tar --no-check-certificate
+ cd ./inference/ && tar xf ch_ppstructure_mobile_v2.0_SLANet_infer.tar && cd ../
+ elif [[ ${model_name} =~ "en_table_structure" ]];then
+ wget -nc -P ./inference/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_structure_infer.tar --no-check-certificate
+ cd ./inference/ && tar xf en_ppocr_mobile_v2.0_table_structure_infer.tar && cd ../
fi
# wget data
diff --git a/test_tipc/test_paddle2onnx.sh b/test_tipc/test_paddle2onnx.sh
index 04bfb590f7c6e64cf136d3feef8594994cb86877..f035e6bb645a1e7927844232c2bff72f0480e38e 100644
--- a/test_tipc/test_paddle2onnx.sh
+++ b/test_tipc/test_paddle2onnx.sh
@@ -105,6 +105,19 @@ function func_paddle2onnx(){
eval $trans_model_cmd
last_status=${PIPESTATUS[0]}
status_check $last_status "${trans_model_cmd}" "${status_log}" "${model_name}" "${trans_rec_log}"
+ elif [ ${model_name} = "slanet" ] || [ ${model_name} = "en_table_structure" ]; then
+ # trans det
+ set_dirname=$(func_set_params "--model_dir" "${det_infer_model_dir_value}")
+ set_model_filename=$(func_set_params "${model_filename_key}" "${model_filename_value}")
+ set_params_filename=$(func_set_params "${params_filename_key}" "${params_filename_value}")
+ set_save_model=$(func_set_params "--save_file" "${det_save_file_value}")
+ set_opset_version=$(func_set_params "${opset_version_key}" "${opset_version_value}")
+ set_enable_onnx_checker=$(func_set_params "${enable_onnx_checker_key}" "${enable_onnx_checker_value}")
+ trans_det_log="${LOG_PATH}/trans_model_det.log"
+ trans_model_cmd="${padlle2onnx_cmd} ${set_dirname} ${set_model_filename} ${set_params_filename} ${set_save_model} ${set_opset_version} ${set_enable_onnx_checker} --enable_dev_version=True > ${trans_det_log} 2>&1 "
+ eval $trans_model_cmd
+ last_status=${PIPESTATUS[0]}
+ status_check $last_status "${trans_model_cmd}" "${status_log}" "${model_name}" "${trans_det_log}"
fi
# python inference
@@ -117,7 +130,7 @@ function func_paddle2onnx(){
set_det_model_dir=$(func_set_params "${det_model_key}" "${det_save_file_value}")
set_rec_model_dir=$(func_set_params "${rec_model_key}" "${rec_save_file_value}")
infer_model_cmd="${python} ${inference_py} ${set_gpu} ${set_img_dir} ${set_det_model_dir} ${set_rec_model_dir} --use_onnx=True > ${_save_log_path} 2>&1 "
- elif [[ ${model_name} =~ "det" ]]; then
+ elif [[ ${model_name} =~ "det" ]] || [ ${model_name} = "slanet" ] || [ ${model_name} = "en_table_structure" ]; then
set_det_model_dir=$(func_set_params "${det_model_key}" "${det_save_file_value}")
infer_model_cmd="${python} ${inference_py} ${set_gpu} ${set_img_dir} ${set_det_model_dir} --use_onnx=True > ${_save_log_path} 2>&1 "
elif [[ ${model_name} =~ "rec" ]]; then
@@ -136,7 +149,7 @@ function func_paddle2onnx(){
set_det_model_dir=$(func_set_params "${det_model_key}" "${det_save_file_value}")
set_rec_model_dir=$(func_set_params "${rec_model_key}" "${rec_save_file_value}")
infer_model_cmd="${python} ${inference_py} ${set_gpu} ${set_img_dir} ${set_det_model_dir} ${set_rec_model_dir} --use_onnx=True > ${_save_log_path} 2>&1 "
- elif [[ ${model_name} =~ "det" ]]; then
+ elif [[ ${model_name} =~ "det" ]]|| [ ${model_name} = "slanet" ] || [ ${model_name} = "en_table_structure" ]; then
set_det_model_dir=$(func_set_params "${det_model_key}" "${det_save_file_value}")
infer_model_cmd="${python} ${inference_py} ${set_gpu} ${set_img_dir} ${set_det_model_dir} --use_onnx=True > ${_save_log_path} 2>&1 "
elif [[ ${model_name} =~ "rec" ]]; then
diff --git a/tools/infer/predict_det.py b/tools/infer/predict_det.py
index 00fa2e9b7fafd949c59a0eebd43f2f88ae717320..52c225d2b3913cf8c0dc88abcc07f7ccfd3cc914 100755
--- a/tools/infer/predict_det.py
+++ b/tools/infer/predict_det.py
@@ -282,44 +282,67 @@ if __name__ == "__main__":
args = utility.parse_args()
image_file_list = get_image_file_list(args.image_dir)
text_detector = TextDetector(args)
- count = 0
total_time = 0
- draw_img_save = "./inference_results"
+ draw_img_save_dir = args.draw_img_save_dir
+ os.makedirs(draw_img_save_dir, exist_ok=True)
if args.warmup:
img = np.random.uniform(0, 255, [640, 640, 3]).astype(np.uint8)
for i in range(2):
res = text_detector(img)
- if not os.path.exists(draw_img_save):
- os.makedirs(draw_img_save)
save_results = []
- for image_file in image_file_list:
- img, flag, _ = check_and_read(image_file)
- if not flag:
+ for idx, image_file in enumerate(image_file_list):
+ img, flag_gif, flag_pdf = check_and_read(image_file)
+ if not flag_gif and not flag_pdf:
img = cv2.imread(image_file)
- if img is None:
- logger.info("error in loading image:{}".format(image_file))
- continue
- st = time.time()
- dt_boxes, _ = text_detector(img)
- elapse = time.time() - st
- if count > 0:
+ if not flag_pdf:
+ if img is None:
+ logger.debug("error in loading image:{}".format(image_file))
+ continue
+ imgs = [img]
+ else:
+ page_num = args.page_num
+ if page_num > len(img) or page_num == 0:
+ page_num = len(img)
+ imgs = img[:page_num]
+ for index, img in enumerate(imgs):
+ st = time.time()
+ dt_boxes, _ = text_detector(img)
+ elapse = time.time() - st
total_time += elapse
- count += 1
- save_pred = os.path.basename(image_file) + "\t" + str(
- json.dumps([x.tolist() for x in dt_boxes])) + "\n"
- save_results.append(save_pred)
- logger.info(save_pred)
- logger.info("The predict time of {}: {}".format(image_file, elapse))
- src_im = utility.draw_text_det_res(dt_boxes, image_file)
- img_name_pure = os.path.split(image_file)[-1]
- img_path = os.path.join(draw_img_save,
- "det_res_{}".format(img_name_pure))
- cv2.imwrite(img_path, src_im)
- logger.info("The visualized image saved in {}".format(img_path))
+ if len(imgs) > 1:
+ save_pred = os.path.basename(image_file) + '_' + str(
+ index) + "\t" + str(
+ json.dumps([x.tolist() for x in dt_boxes])) + "\n"
+ else:
+ save_pred = os.path.basename(image_file) + "\t" + str(
+ json.dumps([x.tolist() for x in dt_boxes])) + "\n"
+ save_results.append(save_pred)
+ logger.info(save_pred)
+ if len(imgs) > 1:
+ logger.info("{}_{} The predict time of {}: {}".format(
+ idx, index, image_file, elapse))
+ else:
+ logger.info("{} The predict time of {}: {}".format(
+ idx, image_file, elapse))
+
+ src_im = utility.draw_text_det_res(dt_boxes, img)
+
+ if flag_gif:
+ save_file = image_file[:-3] + "png"
+ elif flag_pdf:
+ save_file = image_file.replace('.pdf',
+ '_' + str(index) + '.png')
+ else:
+ save_file = image_file
+ img_path = os.path.join(
+ draw_img_save_dir,
+ "det_res_{}".format(os.path.basename(save_file)))
+ cv2.imwrite(img_path, src_im)
+ logger.info("The visualized image saved in {}".format(img_path))
- with open(os.path.join(draw_img_save, "det_results.txt"), 'w') as f:
+ with open(os.path.join(draw_img_save_dir, "det_results.txt"), 'w') as f:
f.writelines(save_results)
f.close()
if args.benchmark:
diff --git a/tools/infer/predict_system.py b/tools/infer/predict_system.py
index e0f2c41fa2aba23491efee920afbd76db1ec84e0..affd0d1bcd1283be02ead3cd61c01c375b49bdf9 100755
--- a/tools/infer/predict_system.py
+++ b/tools/infer/predict_system.py
@@ -159,50 +159,75 @@ def main(args):
count = 0
for idx, image_file in enumerate(image_file_list):
- img, flag, _ = check_and_read(image_file)
- if not flag:
+ img, flag_gif, flag_pdf = check_and_read(image_file)
+ if not flag_gif and not flag_pdf:
img = cv2.imread(image_file)
- if img is None:
- logger.debug("error in loading image:{}".format(image_file))
- continue
- starttime = time.time()
- dt_boxes, rec_res, time_dict = text_sys(img)
- elapse = time.time() - starttime
- total_time += elapse
-
- logger.debug(
- str(idx) + " Predict time of %s: %.3fs" % (image_file, elapse))
- for text, score in rec_res:
- logger.debug("{}, {:.3f}".format(text, score))
-
- res = [{
- "transcription": rec_res[idx][0],
- "points": np.array(dt_boxes[idx]).astype(np.int32).tolist(),
- } for idx in range(len(dt_boxes))]
- save_pred = os.path.basename(image_file) + "\t" + json.dumps(
- res, ensure_ascii=False) + "\n"
- save_results.append(save_pred)
-
- if is_visualize:
- image = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
- boxes = dt_boxes
- txts = [rec_res[i][0] for i in range(len(rec_res))]
- scores = [rec_res[i][1] for i in range(len(rec_res))]
-
- draw_img = draw_ocr_box_txt(
- image,
- boxes,
- txts,
- scores,
- drop_score=drop_score,
- font_path=font_path)
- if flag:
- image_file = image_file[:-3] + "png"
- cv2.imwrite(
- os.path.join(draw_img_save_dir, os.path.basename(image_file)),
- draw_img[:, :, ::-1])
- logger.debug("The visualized image saved in {}".format(
- os.path.join(draw_img_save_dir, os.path.basename(image_file))))
+ if not flag_pdf:
+ if img is None:
+ logger.debug("error in loading image:{}".format(image_file))
+ continue
+ imgs = [img]
+ else:
+ page_num = args.page_num
+ if page_num > len(img) or page_num == 0:
+ page_num = len(img)
+ imgs = img[:page_num]
+ for index, img in enumerate(imgs):
+ starttime = time.time()
+ dt_boxes, rec_res, time_dict = text_sys(img)
+ elapse = time.time() - starttime
+ total_time += elapse
+ if len(imgs) > 1:
+ logger.debug(
+ str(idx) + '_' + str(index) + " Predict time of %s: %.3fs"
+ % (image_file, elapse))
+ else:
+ logger.debug(
+ str(idx) + " Predict time of %s: %.3fs" % (image_file,
+ elapse))
+ for text, score in rec_res:
+ logger.debug("{}, {:.3f}".format(text, score))
+
+ res = [{
+ "transcription": rec_res[i][0],
+ "points": np.array(dt_boxes[i]).astype(np.int32).tolist(),
+ } for i in range(len(dt_boxes))]
+ if len(imgs) > 1:
+ save_pred = os.path.basename(image_file) + '_' + str(
+ index) + "\t" + json.dumps(
+ res, ensure_ascii=False) + "\n"
+ else:
+ save_pred = os.path.basename(image_file) + "\t" + json.dumps(
+ res, ensure_ascii=False) + "\n"
+ save_results.append(save_pred)
+
+ if is_visualize:
+ image = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
+ boxes = dt_boxes
+ txts = [rec_res[i][0] for i in range(len(rec_res))]
+ scores = [rec_res[i][1] for i in range(len(rec_res))]
+
+ draw_img = draw_ocr_box_txt(
+ image,
+ boxes,
+ txts,
+ scores,
+ drop_score=drop_score,
+ font_path=font_path)
+ if flag_gif:
+ save_file = image_file[:-3] + "png"
+ elif flag_pdf:
+ save_file = image_file.replace('.pdf',
+ '_' + str(index) + '.png')
+ else:
+ save_file = image_file
+ cv2.imwrite(
+ os.path.join(draw_img_save_dir,
+ os.path.basename(save_file)),
+ draw_img[:, :, ::-1])
+ logger.debug("The visualized image saved in {}".format(
+ os.path.join(draw_img_save_dir, os.path.basename(
+ save_file))))
logger.info("The predict total time is {}".format(time.time() - _st))
if args.benchmark:
diff --git a/tools/infer/utility.py b/tools/infer/utility.py
index b9c9490bdb99f3bee67cb9460a9975b93b0d6366..e555dbec1b314510aaaf6b31f1b35bf60fefa98e 100644
--- a/tools/infer/utility.py
+++ b/tools/infer/utility.py
@@ -45,6 +45,7 @@ def init_args():
# params for text detector
parser.add_argument("--image_dir", type=str)
+ parser.add_argument("--page_num", type=int, default=0)
parser.add_argument("--det_algorithm", type=str, default='DB')
parser.add_argument("--det_model_dir", type=str)
parser.add_argument("--det_limit_side_len", type=float, default=960)
@@ -337,12 +338,11 @@ def draw_e2e_res(dt_boxes, strs, img_path):
return src_im
-def draw_text_det_res(dt_boxes, img_path):
- src_im = cv2.imread(img_path)
+def draw_text_det_res(dt_boxes, img):
for box in dt_boxes:
box = np.array(box).astype(np.int32).reshape(-1, 2)
- cv2.polylines(src_im, [box], True, color=(255, 255, 0), thickness=2)
- return src_im
+ cv2.polylines(img, [box], True, color=(255, 255, 0), thickness=2)
+ return img
def resize_img(img, input_size=600):