\n GT')
+ for i, filename in enumerate(sorted(images_dir)):
+ if filename.endswith("txt"): continue
+ print(filename)
+ base = "{}".format(filename)
+ if True:
+ html.write("
+ html.write(f'
{filename}\n GT')
+ html.write('
GT 310\n
' % (base))
+ html.write("
+ html.write('\n')
+ html.write('
+ html.write('\n\n')
+ print("ok")
+def crop_seal_from_img(label_file, data_dir, save_dir, save_gt_path):
+ if not os.path.exists(save_dir):
+ os.makedirs(save_dir)
+ datas = open(label_file, 'r').readlines()
+ all_gts = []
+ count = 0
+ for idx, line in enumerate(datas):
+ img_path, label = line.strip().split('\t')
+ img_path = os.path.join(data_dir, img_path)
+ label = json.loads(label)
+ src_im = cv2.imread(img_path)
+ if src_im is None:
+ continue
+ for c, anno in enumerate(label):
+ seal_poly = anno['seal_box']
+ txt_boxes = anno['polys']
+ txts = anno['texts']
+ ignore_tags = anno['ignore_tags']
+ box = poly2box(seal_poly)
+ img_crop = src_im[box[0][1]:box[2][1], box[0][0]:box[2][0], :]
+ save_path = os.path.join(save_dir, f"{idx}_{c}.jpg")
+ cv2.imwrite(save_path, np.array(img_crop))
+ img_gt = []
+ for i in range(len(txts)):
+ txt_boxes_crop = np.array(txt_boxes[i])
+ txt_boxes_crop[:, 1] -= box[0, 1]
+ txt_boxes_crop[:, 0] -= box[0, 0]
+ img_gt.append({'transcription': txts[i], "points": txt_boxes_crop.tolist(), "ignore_tag": ignore_tags[i]})
+ if len(img_gt) >= 1:
+ count += 1
+ save_gt = f"{os.path.basename(save_path)}\t{json.dumps(img_gt)}\n"
+ all_gts.append(save_gt)
+ print(f"The num of all image: {len(all_gts)}, and the number of useful image: {count}")
+ if not os.path.exists(os.path.dirname(save_gt_path)):
+ os.makedirs(os.path.dirname(save_gt_path))
+ with open(save_gt_path, "w") as f:
+ f.writelines(all_gts)
+ f.close()
+ print("Done")
+if __name__ == "__main__":
+ # Data processing
+ # NOTE(review): gen_extract_label, vis_seal_ppocr and draw_html are defined outside this excerpt;
+ # judging by the arguments they produce the label files, the visualisations and the HTML page - confirm.
+ gen_extract_label("./seal_labeled_datas", "./seal_labeled_datas/Label.txt", "./seal_ppocr_gt/seal_det_img.txt", "./seal_ppocr_gt/seal_ppocr_img.txt")
+ vis_seal_ppocr("./seal_labeled_datas", "./seal_ppocr_gt/seal_ppocr_img.txt", "./seal_ppocr_gt/seal_ppocr_vis/")
+ draw_html("./seal_ppocr_gt/seal_ppocr_vis/", "./vis_seal_ppocr.html")
+ seal_ppocr_img_label = "./seal_ppocr_gt/seal_ppocr_img.txt"
+ # Crop each labelled seal into ./seal_img_crop and write the matching label file next to the crops.
+ crop_seal_from_img(seal_ppocr_img_label, "./seal_labeled_datas/", "./seal_img_crop", "./seal_img_crop/label.txt")
+├── seal_img_crop/
+│ ├── 0_0.jpg
+│ ├── ...
+│ └── label.txt
+├── seal_ppocr_gt/
+│ ├── seal_det_img.txt
+│ ├── seal_ppocr_img.txt
+│ └── seal_ppocr_vis/
+│ ├── test1.png
+│ ├── ...
+└── vis_seal_ppocr.html
+0_0.jpg [{"transcription": "\u7535\u5b50\u56de\u5355", "points": [[29, 73], [96, 73], [96, 90], [29, 90]], "ignore_tag": false}, {"transcription": "\u4e91\u5357\u7701\u519c\u6751\u4fe1\u7528\u793e", "points": [[9, 58], [26, 63], [30, 49], [38, 35], [47, 29], [64, 26], [81, 32], [90, 45], [94, 63], [118, 57], [110, 35], [95, 17], [67, 0], [38, 7], [21, 23], [10, 43]], "ignore_tag": false}, {"transcription": "\u4e13\u7528\u7ae0", "points": [[29, 87], [95, 87], [95, 106], [29, 106]], "ignore_tag": false}]
+img/test1.png [{"polys": [[408, 232], [537, 232], [537, 352], [408, 352]], "cls": 1}]
+import numpy as np
+import json
+import cv2
+import os
+from shapely.geometry import Polygon
+seal_train_gt = "./seal_ppocr_gt/seal_det_img.txt"
+# 注:仅用于示例,实际使用中需要分别转换训练集和测试集的标签
+seal_valid_gt = "./seal_ppocr_gt/seal_det_img.txt"
+def gen_main_train_txt(mode='train'):
+ if mode == "train":
+ file_path = seal_train_gt
+ if mode in ['valid', 'test']:
+ file_path = seal_valid_gt
+ save_path = f"./seal_VOC/ImageSets/Main/{mode}.txt"
+ save_train_path = f"./seal_VOC/{mode}.txt"
+ if not os.path.exists(os.path.dirname(save_path)):
+ os.makedirs(os.path.dirname(save_path))
+ datas = open(file_path, 'r').readlines()
+ img_names = []
+ train_names = []
+ for line in datas:
+ img_name = line.strip().split('\t')[0]
+ img_name = os.path.basename(img_name)
+ (i_name, extension) = os.path.splitext(img_name)
+ t_name = 'JPEGImages/'+str(img_name)+' '+'Annotations/'+str(i_name)+'.xml\n'
+ train_names.append(t_name)
+ img_names.append(i_name + "\n")
+ with open(save_train_path, "w") as f:
+ f.writelines(train_names)
+ f.close()
+ with open(save_path, "w") as f:
+ f.writelines(img_names)
+ f.close()
+ print(f"{mode} save done")
+def gen_xml_label(mode='train'):
+ """Copy each labelled image into ./seal_VOC/JPEGImages and write one annotation file per image.
+
+ mode selects the label file: 'train' reads seal_train_gt, 'valid'/'test' read seal_valid_gt.
+ Any other value leaves file_path unassigned, so the open() call below fails.
+
+ NOTE(review): the string literals written to xml_file look as if their XML tags were
+ stripped when this snippet was extracted - confirm against the original tutorial.
+ """
+ if mode == "train":
+ file_path = seal_train_gt
+ if mode in ['valid', 'test']:
+ file_path = seal_valid_gt
+ datas = open(file_path, 'r').readlines()
+ # img_names and train_names are never used in the visible code.
+ img_names = []
+ train_names = []
+ anno_path = "./seal_VOC/Annotations"
+ img_path = "./seal_VOC/JPEGImages"
+ if not os.path.exists(anno_path):
+ os.makedirs(anno_path)
+ if not os.path.exists(img_path):
+ os.makedirs(img_path)
+ for idx, line in enumerate(datas):
+ # Each label line is "<image path>\t<JSON annotation list>".
+ img_name, label = line.strip().split('\t')
+ # The image is re-read from ./seal_labeled_datas and re-written under its base name.
+ # cv2.imread returns None for an unreadable file, which is not checked here.
+ img = cv2.imread(os.path.join("./seal_labeled_datas", img_name))
+ cv2.imwrite(os.path.join(img_path, os.path.basename(img_name)), img)
+ height, width, c = img.shape
+ img_name = os.path.basename(img_name)
+ (i_name, extension) = os.path.splitext(img_name)
+ label = json.loads(label)
+ # The annotation file is named after the image stem.
+ xml_file = open(("./seal_VOC/Annotations" + '/' + i_name + '.xml'), 'w')
+ xml_file.write('\n')
+ xml_file.write(' seal_VOC\n')
+ xml_file.write(' ' + str(img_name) + '\n')
+ xml_file.write(' ' + 'Annotations/' + str(img_name) + '\n')
+ xml_file.write(' \n')
+ xml_file.write(' ' + str(width) + '\n')
+ xml_file.write(' ' + str(height) + '\n')
+ xml_file.write(' 3\n')
+ xml_file.write(' \n')
+ xml_file.write(' 0\n')
+ for anno in label:
+ poly = anno['polys']
+ # gt_cls is only assigned for class id 1.
+ if anno['cls'] == 1:
+ gt_cls = 'redseal'
+ # Reduce the polygon to its axis-aligned bounding box.
+ xmin = np.min(np.array(poly)[:, 0])
+ ymin = np.min(np.array(poly)[:, 1])
+ xmax = np.max(np.array(poly)[:, 0])
+ ymax = np.max(np.array(poly)[:, 1])
+ xmin,ymin,xmax,ymax= int(xmin),int(ymin),int(xmax),int(ymax)
+ # NOTE(review): gt_cls and the box values are not written anywhere in the visible code;
+ # the per-object write calls were presumably lost with the stripped tags - confirm.
+ xml_file.write(' \n')
+ xml_file.write('')
+ xml_file.close()
+ print(f'{mode} xml save done!')
+├── Annotations/
+├── ImageSets/
+│ └── Main/
+│ ├── train.txt
+│ └── valid.txt
+├── JPEGImages/
+├── train.txt
+└── valid.txt
+└── label_list.txt
+# 4. 印章检测实践
+- 选择算法
+- 修改数据集配置路径
+- 启动训练
+metric: VOC
+map_type: 11point
+num_classes: 2
+ !VOCDataSet
+ dataset_dir: dataset/seal_VOC
+ anno_path: train.txt
+ label_list: label_list.txt
+ data_fields: ['image', 'gt_bbox', 'gt_class', 'difficult']
+ !VOCDataSet
+ dataset_dir: dataset/seal_VOC
+ anno_path: test.txt
+ label_list: label_list.txt
+ data_fields: ['image', 'gt_bbox', 'gt_class', 'difficult']
+ !ImageFolder
+ anno_path: dataset/seal_VOC/label_list.txt
+!ln -s seal_VOC ./PaddleDetection/dataset/
+另外,图像中印章数量比较少,可以调整NMS后处理的检测框数量,即将keep_top_k,nms_top_k分别从100,1000调整为10,100。在配置文件'configs/ppyolo/ppyolo_mbv3_large.yml'末尾增加如下内容,即可完成后处理参数的调整:
+ decode:
+ name: YOLOBox
+ conf_thresh: 0.005
+ downsample_ratio: 32
+ clip_bbox: true
+ scale_x_y: 1.05
+ nms:
+ name: MultiClassNMS
+ keep_top_k: 10 # 修改前100
+ nms_threshold: 0.45
+ nms_top_k: 100 # 修改前1000
+ score_threshold: 0.005
+import os
+import numpy as np
+from ppdet.core.workspace import register, serializable
+from .dataset import DetDataset
+import cv2
+import json
+from ppdet.utils.logger import setup_logger
+logger = setup_logger(__name__)
+class SealDataSet(DetDataset):
+ """
+ Load dataset with COCO format.
+ Args:
+ dataset_dir (str): root directory for dataset.
+ image_dir (str): directory for images.
+ anno_path (str): coco annotation file path.
+ data_fields (list): key name of data dictionary, at least have 'image'.
+ sample_num (int): number of samples to load, -1 means all.
+ load_crowd (bool): whether to load crowded ground-truth.
+ False as default
+ allow_empty (bool): whether to load empty entry. False as default
+ empty_ratio (float): the ratio of empty record number to total
+ record's, if empty_ratio is out of [0. ,1.), do not sample the
+ records and use all the empty entries. 1. as default
+ """
+ def __init__(self,
+ dataset_dir=None,
+ image_dir=None,
+ anno_path=None,
+ data_fields=['image'],
+ sample_num=-1,
+ load_crowd=False,
+ allow_empty=False,
+ empty_ratio=1.):
+ super(SealDataSet, self).__init__(dataset_dir, image_dir, anno_path,
+ data_fields, sample_num)
+ self.load_image_only = False
+ self.load_semantic = False
+ self.load_crowd = load_crowd
+ self.allow_empty = allow_empty
+ self.empty_ratio = empty_ratio
+ def _sample_empty(self, records, num):
+ # if empty_ratio is out of [0. ,1.), do not sample the records
+ if self.empty_ratio < 0. or self.empty_ratio >= 1.:
+ return records
+ import random
+ sample_num = min(
+ int(num * self.empty_ratio / (1 - self.empty_ratio)), len(records))
+ records = random.sample(records, sample_num)
+ return records
+ def parse_dataset(self):
+ anno_path = os.path.join(self.dataset_dir, self.anno_path)
+ image_dir = os.path.join(self.dataset_dir, self.image_dir)
+ records = []
+ empty_records = []
+ ct = 0
+ assert anno_path.endswith('.txt'), \
+ 'invalid seal_gt file: ' + anno_path
+ all_datas = open(anno_path, 'r').readlines()
+ for idx, line in enumerate(all_datas):
+ im_path, label = line.strip().split('\t')
+ img_path = os.path.join(image_dir, im_path)
+ label = json.loads(label)
+ im_h, im_w, im_c = cv2.imread(img_path).shape
+ coco_rec = {
+ 'im_file': img_path,
+ 'im_id': np.array([idx]),
+ 'h': im_h,
+ 'w': im_w,
+ } if 'image' in self.data_fields else {}
+ if not self.load_image_only:
+ bboxes = []
+ for anno in label:
+ poly = anno['polys']
+ # poly to box
+ x1 = np.min(np.array(poly)[:, 0])
+ y1 = np.min(np.array(poly)[:, 1])
+ x2 = np.max(np.array(poly)[:, 0])
+ y2 = np.max(np.array(poly)[:, 1])
+ eps = 1e-5
+ if x2 - x1 > eps and y2 - y1 > eps:
+ clean_box = [
+ round(float(x), 3) for x in [x1, y1, x2, y2]
+ ]
+ anno = {'clean_box': clean_box, 'gt_cls':int(anno['cls'])}
+ bboxes.append(anno)
+ else:
+ logger.info("invalid box")
+ num_bbox = len(bboxes)
+ if num_bbox <= 0:
+ continue
+ gt_bbox = np.zeros((num_bbox, 4), dtype=np.float32)
+ gt_class = np.zeros((num_bbox, 1), dtype=np.int32)
+ is_crowd = np.zeros((num_bbox, 1), dtype=np.int32)
+ # gt_poly = [None] * num_bbox
+ for i, box in enumerate(bboxes):
+ gt_class[i][0] = box['gt_cls']
+ gt_bbox[i, :] = box['clean_box']
+ is_crowd[i][0] = 0
+ gt_rec = {
+ 'is_crowd': is_crowd,
+ 'gt_class': gt_class,
+ 'gt_bbox': gt_bbox,
+ # 'gt_poly': gt_poly,
+ }
+ for k, v in gt_rec.items():
+ if k in self.data_fields:
+ coco_rec[k] = v
+ records.append(coco_rec)
+ ct += 1
+ if self.sample_num > 0 and ct >= self.sample_num:
+ break
+ self.roidbs = records
+!python3 tools/train.py -c configs/ppyolo/ppyolo_mbv3_large.yml --eval
+# 分布式训练命令为:
+!python3 -m paddle.distributed.launch --gpus 0,1,2,3,4,5,6,7 tools/train.py -c configs/ppyolo/ppyolo_mbv3_large.yml --eval
+[07/05 11:42:09] ppdet.engine INFO: Eval iter: 0
+[07/05 11:42:14] ppdet.metrics.metrics INFO: Accumulating evaluatation results...
+[07/05 11:42:14] ppdet.metrics.metrics INFO: mAP(0.50, 11point) = 99.31%
+[07/05 11:42:14] ppdet.engine INFO: Total sample number: 112, averge FPS: 26.45840794253432
+[07/05 11:42:14] ppdet.engine INFO: Best test bbox ap is 0.996.
+!python3 tools/infer.py -c configs/ppyolo/ppyolo_mbv3_large.yml -o weights=./output/ppyolo_mbv3_large/model_final.pdparams --img_dir=./test.jpg
+# 5. 印章文字识别实践
+| 文字检测 | 文字识别 | 端对端算法 |
+| -------- | -------- | -------- |
+## 5.1 端对端印章文字识别实践
+- 修改配置文件
+- 启动训练
+ dataset:
+ name: PGDataSet
+ data_dir: ./train_data/seal_ppocr
+ label_file_list: [./train_data/seal_ppocr/seal_ppocr_img.txt]
+ ratio_list: [1.0]
+ dataset:
+ name: PGDataSet
+ data_dir: ./train_data/seal_ppocr_test
+ label_file_list: [./train_data/seal_ppocr_test/seal_ppocr_img.txt]
+!python3 tools/train.py -c configs/e2e/e2e_r50_vd_pg.yml
+## 5.2 两阶段印章文字识别实践
+### 5.2.1 印章文字检测
+ dataset:
+ name: SimpleDataSet
+ data_dir: ./train_data/seal_ppocr
+ label_file_list: [./train_data/seal_ppocr/seal_ppocr_img.txt]
+ ratio_list: [1.0]
+ dataset:
+ name: SimpleDataSet
+ data_dir: ./train_data/seal_ppocr_test
+ label_file_list: [./train_data/seal_ppocr_test/seal_ppocr_img.txt]
+!python3 tools/train.py -c configs/det/det_r50_db++_icdar15.yml -o Global.epoch_num=100
+### 5.2.2 印章文字识别
+import cv2
+import numpy as np
+def get_rotate_crop_image(img, points):
+ '''
+ img_height, img_width = img.shape[0:2]
+ left = int(np.min(points[:, 0]))
+ right = int(np.max(points[:, 0]))
+ top = int(np.min(points[:, 1]))
+ bottom = int(np.max(points[:, 1]))
+ img_crop = img[top:bottom, left:right, :].copy()
+ points[:, 0] = points[:, 0] - left
+ points[:, 1] = points[:, 1] - top
+ '''
+ assert len(points) == 4, "shape of points must be 4*2"
+ img_crop_width = int(
+ max(
+ np.linalg.norm(points[0] - points[1]),
+ np.linalg.norm(points[2] - points[3])))
+ img_crop_height = int(
+ max(
+ np.linalg.norm(points[0] - points[3]),
+ np.linalg.norm(points[1] - points[2])))
+ pts_std = np.float32([[0, 0], [img_crop_width, 0],
+ [img_crop_width, img_crop_height],
+ [0, img_crop_height]])
+ M = cv2.getPerspectiveTransform(points, pts_std)
+ dst_img = cv2.warpPerspective(
+ img,
+ M, (img_crop_width, img_crop_height),
+ borderMode=cv2.BORDER_REPLICATE,
+ flags=cv2.INTER_CUBIC)
+ dst_img_height, dst_img_width = dst_img.shape[0:2]
+ if dst_img_height * 1.0 / dst_img_width >= 1.5:
+ dst_img = np.rot90(dst_img)
+ return dst_img
+def run(data_dir, label_file, save_dir):
+ datas = open(label_file, 'r').readlines()
+ for idx, line in enumerate(datas):
+ img_path, label = line.strip().split('\t')
+ img_path = os.path.join(data_dir, img_path)
+ label = json.loads(label)
+ src_im = cv2.imread(img_path)
+ if src_im is None:
+ continue
+ for anno in label:
+ seal_box = anno['seal_box']
+ txt_boxes = anno['polys']
+ crop_im = get_rotate_crop_image(src_im, text_boxes)
+ save_path = os.path.join(save_dir, f'{idx}.png')
+ if not os.path.exists(save_dir):
+ os.makedirs(save_dir)
+ # print(src_im.shape)
+ cv2.imwrite(save_path, crop_im)
+ dataset:
+ name: SimpleDataSet
+ data_dir: ./train_data/seal_ppocr_crop/
+ label_file_list:
+ - ./train_data/seal_ppocr_crop/train_list.txt
+ dataset:
+ name: SimpleDataSet
+ data_dir: ./train_data/seal_ppocr_crop/
+ label_file_list:
+ - ./train_data/seal_ppocr_crop_test/train_list.txt
+!python3 tools/train.py -c configs/rec/PP-OCRv3/ch_PP-OCRv3_rec.yml
diff --git a/configs/rec/rec_r31_robustscanner.yml b/configs/rec/rec_r31_robustscanner.yml
index 40d39aee3c42c18085ace035944dba057b923245..54b69d456ef67c88289504ed5cf8719588ea0803 100644
--- a/configs/rec/rec_r31_robustscanner.yml
+++ b/configs/rec/rec_r31_robustscanner.yml
@@ -12,7 +12,7 @@ Global:
use_visualdl: False
- infer_img: ./inference/rec_inference
+ infer_img: doc/imgs_words_en/word_10.png
# for data or label process
character_dict_path: ppocr/utils/dict90.txt
max_text_length: &max_text_length 40
diff --git a/configs/rec/rec_r32_gaspin_bilstm_att.yml b/configs/rec/rec_r32_gaspin_bilstm_att.yml
index aea71388f703376120af4d0caf2fa8ccd4d92cce..91d3e104188c04f4372a6e574b7fc07608234a2e 100644
--- a/configs/rec/rec_r32_gaspin_bilstm_att.yml
+++ b/configs/rec/rec_r32_gaspin_bilstm_att.yml
@@ -12,7 +12,7 @@ Global:
use_visualdl: False
- infer_img: doc/imgs_words/ch/word_1.jpg
+ infer_img: doc/imgs_words_en/word_10.png
# for data or label process
character_dict_path: ./ppocr/utils/dict/spin_dict.txt
max_text_length: 25
diff --git a/configs/table/table_mv3.yml b/configs/table/table_mv3.yml
index 9355a236e15b60db18e8715c2702701fd5d36c71..9d286f4153eaab44bf0d259bbad4a0b3b8ada568 100755
--- a/configs/table/table_mv3.yml
+++ b/configs/table/table_mv3.yml
@@ -43,7 +43,6 @@ Architecture:
name: TableAttentionHead
hidden_size: 256
- loc_type: 2
max_text_length: *max_text_length
loc_reg_num: &loc_reg_num 4
diff --git a/doc/doc_ch/algorithm_overview.md b/doc/doc_ch/algorithm_overview.md
index 4351fdbfcb501945d6061bc1fcc3585bd76eab7d..e9bcc275d1a7157628188c337a312a49408d207b 100755
--- a/doc/doc_ch/algorithm_overview.md
+++ b/doc/doc_ch/algorithm_overview.md
@@ -101,11 +101,10 @@ PaddleOCR将**持续新增**支持OCR领域前沿算法与模型,**欢迎广
|ViTSTR|ViTSTR| 79.82% | rec_vitstr_none_ce | [训练模型](https://paddleocr.bj.bcebos.com/rec_vitstr_none_ce_train.tar) |
|ABINet|Resnet45| 90.75% | rec_r45_abinet | [训练模型](https://paddleocr.bj.bcebos.com/rec_r45_abinet_train.tar) |
|VisionLAN|Resnet45| 90.30% | rec_r45_visionlan | [训练模型](https://paddleocr.bj.bcebos.com/rec_r45_visionlan_train.tar) |
-|SPIN|ResNet32| 90.00% | rec_r32_gaspin_bilstm_att | coming soon |
-|RobustScanner|ResNet31| 87.77% | rec_r31_robustscanner | coming soon |
+|SPIN|ResNet32| 90.00% | rec_r32_gaspin_bilstm_att | [训练模型](https://paddleocr.bj.bcebos.com/contribution/rec_r32_gaspin_bilstm_att.tar) |
+|RobustScanner|ResNet31| 87.77% | rec_r31_robustscanner | [训练模型](https://paddleocr.bj.bcebos.com/contribution/rec_r31_robustscanner.tar)|
|RFL|ResNetRFL| 88.63% | rec_resnet_rfl_att | [训练模型](https://paddleocr.bj.bcebos.com/contribution/rec_resnet_rfl.tar) |
## 2. 端到端算法
diff --git a/doc/doc_ch/algorithm_rec_robustscanner.md b/doc/doc_ch/algorithm_rec_robustscanner.md
index 869f9a7c00b617de87ab3c96326e18e536bc18a8..a1ab3baf038f573c50fcc0b67fd8297459777cf8 100644
--- a/doc/doc_ch/algorithm_rec_robustscanner.md
+++ b/doc/doc_ch/algorithm_rec_robustscanner.md
@@ -26,7 +26,7 @@ Zhang
| --- | --- | --- | --- | --- |
-|RobustScanner|ResNet31|[rec_r31_robustscanner.yml](../../configs/rec/rec_r31_robustscanner.yml)|87.77%|coming soon|
diff --git a/doc/doc_ch/algorithm_rec_spin.md b/doc/doc_ch/algorithm_rec_spin.md
index c996992d2fa6297e6086ffae4bc36ad3e880873d..908a85a417c4070b95630b37b0830e08aae3ff4f 100644
--- a/doc/doc_ch/algorithm_rec_spin.md
+++ b/doc/doc_ch/algorithm_rec_spin.md
@@ -26,7 +26,7 @@ SPIN收录于AAAI2020。主要用于OCR识别任务。在任意形状文本识
| --- | --- | --- | --- | --- |
-|SPIN|ResNet32|[rec_r32_gaspin_bilstm_att.yml](../../configs/rec/rec_r32_gaspin_bilstm_att.yml)|90.0%|coming soon|
diff --git a/doc/doc_ch/inference_args.md b/doc/doc_ch/inference_args.md
index 36efc6fbf7a6ec62bc700964dc13261fecdb9bd5..24e7223e397c94fe65b0f26d993fc507b323ed16 100644
--- a/doc/doc_ch/inference_args.md
+++ b/doc/doc_ch/inference_args.md
@@ -7,6 +7,7 @@
| 参数名称 | 类型 | 默认值 | 含义 |
| :--: | :--: | :--: | :--: |
| image_dir | str | 无,必须显式指定 | 图像或者文件夹路径 |
+| page_num | int | 0 | 当输入类型为pdf文件时有效,指定预测前面page_num页,默认预测所有页 |
| vis_font_path | str | "./doc/fonts/simfang.ttf" | 用于可视化的字体路径 |
| drop_score | float | 0.5 | 识别得分小于该值的结果会被丢弃,不会作为返回结果 |
| use_pdserving | bool | False | 是否使用Paddle Serving进行预测 |
diff --git a/doc/doc_en/algorithm_overview_en.md b/doc/doc_en/algorithm_overview_en.md
index f7ef7ad4b8fc162a6ed7b275f3beea6955b86452..90449e1729fcff898f27641d3f777c8f002f6a97 100755
--- a/doc/doc_en/algorithm_overview_en.md
+++ b/doc/doc_en/algorithm_overview_en.md
@@ -98,8 +98,8 @@ Refer to [DTRB](https://arxiv.org/abs/1904.01906), the training and evaluation r
|ViTSTR|ViTSTR| 79.82% | rec_vitstr_none_ce | [trained model](https://paddleocr.bj.bcebos.com/rec_vitstr_none_none_train.tar) |
|ABINet|Resnet45| 90.75% | rec_r45_abinet | [trained model](https://paddleocr.bj.bcebos.com/rec_r45_abinet_train.tar) |
|VisionLAN|Resnet45| 90.30% | rec_r45_visionlan | [trained model](https://paddleocr.bj.bcebos.com/rec_r45_visionlan_train.tar) |
-|SPIN|ResNet32| 90.00% | rec_r32_gaspin_bilstm_att | coming soon |
-|RobustScanner|ResNet31| 87.77% | rec_r31_robustscanner | coming soon |
+|SPIN|ResNet32| 90.00% | rec_r32_gaspin_bilstm_att | [trained model](https://paddleocr.bj.bcebos.com/contribution/rec_r32_gaspin_bilstm_att.tar) |
+|RobustScanner|ResNet31| 87.77% | rec_r31_robustscanner | [trained model](https://paddleocr.bj.bcebos.com/contribution/rec_r31_robustscanner.tar)|
|RFL|ResNetRFL| 88.63% | rec_resnet_rfl_att | [trained model](https://paddleocr.bj.bcebos.com/contribution/rec_resnet_rfl.tar) |
diff --git a/doc/doc_en/algorithm_rec_robustscanner_en.md b/doc/doc_en/algorithm_rec_robustscanner_en.md
index a324a6d547a9e448566276234c750ad4497abf9c..99372d51300e6571827db7181a4f8537db4f1493 100644
--- a/doc/doc_en/algorithm_rec_robustscanner_en.md
+++ b/doc/doc_en/algorithm_rec_robustscanner_en.md
@@ -26,7 +26,7 @@ Using MJSynth and SynthText two text recognition datasets for training, and eval
|Model|Backbone|config|Acc|Download link|
| --- | --- | --- | --- | --- |
-|RobustScanner|ResNet31|[rec_r31_robustscanner.yml](../../configs/rec/rec_r31_robustscanner.yml)|87.77%|coming soon|
+|RobustScanner|ResNet31|[rec_r31_robustscanner.yml](../../configs/rec/rec_r31_robustscanner.yml)|87.77%|[trained model](https://paddleocr.bj.bcebos.com/contribution/rec_r31_robustscanner.tar)|
Note:In addition to using the two text recognition datasets MJSynth and SynthText, [SynthAdd](https://pan.baidu.com/share/init?surl=uV0LtoNmcxbO-0YA7Ch4dg) data (extraction code: 627x), and some real data are used in training, the specific data details can refer to the paper.
diff --git a/doc/doc_en/algorithm_rec_spin_en.md b/doc/doc_en/algorithm_rec_spin_en.md
index 43ab30ce7d96cbb64ddf87156fee3012d666b2bf..03f8d8f69986fc5eb14cfdf294fc25fafb06e269 100644
--- a/doc/doc_en/algorithm_rec_spin_en.md
+++ b/doc/doc_en/algorithm_rec_spin_en.md
@@ -25,7 +25,7 @@ Using MJSynth and SynthText two text recognition datasets for training, and eval
|Model|Backbone|config|Acc|Download link|
| --- | --- | --- | --- | --- |
-|SPIN|ResNet32|[rec_r32_gaspin_bilstm_att.yml](../../configs/rec/rec_r32_gaspin_bilstm_att.yml)|90.0%|coming soon|
+|SPIN|ResNet32|[rec_r32_gaspin_bilstm_att.yml](../../configs/rec/rec_r32_gaspin_bilstm_att.yml)|90.0%|[trained model](https://paddleocr.bj.bcebos.com/contribution/rec_r32_gaspin_bilstm_att.tar) |
diff --git a/doc/doc_en/inference_args_en.md b/doc/doc_en/inference_args_en.md
index f2c99fc8297d47f27a219bf7d8e7f2ea518257f0..b28cd8436da62dcd10f96f17751db9384ebcaa8d 100644
--- a/doc/doc_en/inference_args_en.md
+++ b/doc/doc_en/inference_args_en.md
@@ -7,6 +7,7 @@ When using PaddleOCR for model inference, you can customize the modification par
| parameters | type | default | implication |
| :--: | :--: | :--: | :--: |
| image_dir | str | None, must be specified explicitly | Image or folder path |
+| page_num | int | 0 | Only takes effect when the input is a PDF file: predict only the first page_num pages; with the default value 0, all pages are predicted |
| vis_font_path | str | "./doc/fonts/simfang.ttf" | font path for visualization |
| drop_score | float | 0.5 | Results with a recognition score less than this value will be discarded and will not be returned as results |
| use_pdserving | bool | False | Whether to use Paddle Serving for prediction |
diff --git a/paddleocr.py b/paddleocr.py
index fa732fc110dc7873f8d89b2ca2a21817a1e6d20d..d34b8f78a56a8d8d5455c18e7e1cf1e75df8f3f9 100644
--- a/paddleocr.py
+++ b/paddleocr.py
@@ -480,10 +480,11 @@ class PaddleOCR(predict_system.TextSystem):
params.rec_image_shape = "3, 48, 320"
params.rec_image_shape = "3, 32, 320"
- # download model
- maybe_download(params.det_model_dir, det_url)
- maybe_download(params.rec_model_dir, rec_url)
- maybe_download(params.cls_model_dir, cls_url)
+ # download model if using paddle infer
+ if not params.use_onnx:
+ maybe_download(params.det_model_dir, det_url)
+ maybe_download(params.rec_model_dir, rec_url)
+ maybe_download(params.cls_model_dir, cls_url)
if params.det_algorithm not in SUPPORT_DET_MODEL:
logger.error('det_algorithm must in {}'.format(SUPPORT_DET_MODEL))
diff --git a/ppocr/modeling/heads/table_att_head.py b/ppocr/modeling/heads/table_att_head.py
index d3c86e22b02e08c18d8d5cb193f2ffb8b07ad785..50910c5b73aa2a41f329d7222fc8c632509b4c91 100644
--- a/ppocr/modeling/heads/table_att_head.py
+++ b/ppocr/modeling/heads/table_att_head.py
@@ -16,6 +16,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import math
import paddle
import paddle.nn as nn
from paddle import ParamAttr
@@ -42,7 +43,6 @@ class TableAttentionHead(nn.Layer):
def __init__(self,
- loc_type,
@@ -57,20 +57,16 @@ class TableAttentionHead(nn.Layer):
self.structure_attention_cell = AttentionGRUCell(
self.input_size, hidden_size, self.out_channels, use_gru=False)
self.structure_generator = nn.Linear(hidden_size, self.out_channels)
- self.loc_type = loc_type
self.in_max_len = in_max_len
- if self.loc_type == 1:
- self.loc_generator = nn.Linear(hidden_size, 4)
+ if self.in_max_len == 640:
+ self.loc_fea_trans = nn.Linear(400, self.max_text_length + 1)
+ elif self.in_max_len == 800:
+ self.loc_fea_trans = nn.Linear(625, self.max_text_length + 1)
- if self.in_max_len == 640:
- self.loc_fea_trans = nn.Linear(400, self.max_text_length + 1)
- elif self.in_max_len == 800:
- self.loc_fea_trans = nn.Linear(625, self.max_text_length + 1)
- else:
- self.loc_fea_trans = nn.Linear(256, self.max_text_length + 1)
- self.loc_generator = nn.Linear(self.input_size + hidden_size,
- loc_reg_num)
+ self.loc_fea_trans = nn.Linear(256, self.max_text_length + 1)
+ self.loc_generator = nn.Linear(self.input_size + hidden_size,
+ loc_reg_num)
def _char_to_onehot(self, input_char, onehot_dim):
input_ont_hot = F.one_hot(input_char, onehot_dim)
@@ -80,16 +76,13 @@ class TableAttentionHead(nn.Layer):
# if and else branch are both needed when you want to assign a variable
# if you modify the var in just one branch, then the modification will not work.
fea = inputs[-1]
- if len(fea.shape) == 3:
- pass
- else:
- last_shape = int(np.prod(fea.shape[2:])) # gry added
- fea = paddle.reshape(fea, [fea.shape[0], fea.shape[1], last_shape])
- fea = fea.transpose([0, 2, 1]) # (NTC)(batch, width, channels)
+ last_shape = int(np.prod(fea.shape[2:])) # gry added
+ fea = paddle.reshape(fea, [fea.shape[0], fea.shape[1], last_shape])
+ fea = fea.transpose([0, 2, 1]) # (NTC)(batch, width, channels)
batch_size = fea.shape[0]
hidden = paddle.zeros((batch_size, self.hidden_size))
- output_hiddens = []
+ output_hiddens = paddle.zeros((batch_size, self.max_text_length + 1, self.hidden_size))
if self.training and targets is not None:
structure = targets[0]
for i in range(self.max_text_length + 1):
@@ -97,7 +90,8 @@ class TableAttentionHead(nn.Layer):
structure[:, i], onehot_dim=self.out_channels)
(outputs, hidden), alpha = self.structure_attention_cell(
hidden, fea, elem_onehots)
- output_hiddens.append(paddle.unsqueeze(outputs, axis=1))
+ output_hiddens[:, i, :] = outputs
+ # output_hiddens.append(paddle.unsqueeze(outputs, axis=1))
output = paddle.concat(output_hiddens, axis=1)
structure_probs = self.structure_generator(output)
if self.loc_type == 1:
@@ -118,30 +112,25 @@ class TableAttentionHead(nn.Layer):
outputs = None
alpha = None
max_text_length = paddle.to_tensor(self.max_text_length)
- i = 0
- while i < max_text_length + 1:
+ for i in range(max_text_length + 1):
elem_onehots = self._char_to_onehot(
temp_elem, onehot_dim=self.out_channels)
(outputs, hidden), alpha = self.structure_attention_cell(
hidden, fea, elem_onehots)
- output_hiddens.append(paddle.unsqueeze(outputs, axis=1))
+ output_hiddens[:, i, :] = outputs
+ # output_hiddens.append(paddle.unsqueeze(outputs, axis=1))
structure_probs_step = self.structure_generator(outputs)
temp_elem = structure_probs_step.argmax(axis=1, dtype="int32")
- i += 1
- output = paddle.concat(output_hiddens, axis=1)
+ output = output_hiddens
structure_probs = self.structure_generator(output)
structure_probs = F.softmax(structure_probs)
- if self.loc_type == 1:
- loc_preds = self.loc_generator(output)
- loc_preds = F.sigmoid(loc_preds)
- else:
- loc_fea = fea.transpose([0, 2, 1])
- loc_fea = self.loc_fea_trans(loc_fea)
- loc_fea = loc_fea.transpose([0, 2, 1])
- loc_concat = paddle.concat([output, loc_fea], axis=2)
- loc_preds = self.loc_generator(loc_concat)
- loc_preds = F.sigmoid(loc_preds)
+ loc_fea = fea.transpose([0, 2, 1])
+ loc_fea = self.loc_fea_trans(loc_fea)
+ loc_fea = loc_fea.transpose([0, 2, 1])
+ loc_concat = paddle.concat([output, loc_fea], axis=2)
+ loc_preds = self.loc_generator(loc_concat)
+ loc_preds = F.sigmoid(loc_preds)
return {'structure_probs': structure_probs, 'loc_preds': loc_preds}
diff --git a/ppstructure/table/README.md b/ppstructure/table/README.md
index 08635516ba8301e6f98f175e5eba8c0a97b1708e..1d082f3878c56e42d175d13c75e1fe17916e7781 100644
--- a/ppstructure/table/README.md
+++ b/ppstructure/table/README.md
@@ -114,7 +114,7 @@ python3 table/eval_table.py \
--det_model_dir=path/to/det_model_dir \
--rec_model_dir=path/to/rec_model_dir \
--table_model_dir=path/to/table_model_dir \
- --image_dir=../doc/table/1.png \
+ --image_dir=docs/table/table.jpg \
--rec_char_dict_path=../ppocr/utils/dict/table_dict.txt \
--table_char_dict_path=../ppocr/utils/dict/table_structure_dict.txt \
--det_limit_side_len=736 \
@@ -145,6 +145,7 @@ python3 table/eval_table.py \
--table_char_dict_path=../ppocr/utils/dict/table_structure_dict.txt \
--det_limit_side_len=736 \
--det_limit_type=min \
+ --rec_image_shape=3,32,320 \
diff --git a/ppstructure/table/README_ch.md b/ppstructure/table/README_ch.md
index 1ef126261d9ce832cd1919a1b3991f341add998c..feccb70adfe20fa8c1cd06f33a10ee6fa043e69e 100644
--- a/ppstructure/table/README_ch.md
+++ b/ppstructure/table/README_ch.md
@@ -118,7 +118,7 @@ python3 table/eval_table.py \
--det_model_dir=path/to/det_model_dir \
--rec_model_dir=path/to/rec_model_dir \
--table_model_dir=path/to/table_model_dir \
- --image_dir=../doc/table/1.png \
+ --image_dir=docs/table/table.jpg \
--rec_char_dict_path=../ppocr/utils/dict/table_dict.txt \
--table_char_dict_path=../ppocr/utils/dict/table_structure_dict.txt \
--det_limit_side_len=736 \
@@ -149,6 +149,7 @@ python3 table/eval_table.py \
--table_char_dict_path=../ppocr/utils/dict/table_structure_dict.txt \
--det_limit_side_len=736 \
--det_limit_type=min \
+ --rec_image_shape=3,32,320 \
diff --git a/requirements.txt b/requirements.txt
index 43cd8c1b082768ebad44a5cf58fc31980ebfe891..7a018b50952a876b4839eabbd72fac09d2bbd73b 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -15,3 +15,4 @@ premailer
diff --git a/test_tipc/configs/en_table_structure/model_linux_gpu_normal_normal_paddle2onnx_python_linux_cpu.txt b/test_tipc/configs/en_table_structure/model_linux_gpu_normal_normal_paddle2onnx_python_linux_cpu.txt
new file mode 100644
index 0000000000000000000000000000000000000000..068c4c6b1d2655b9dcda1120425de7d52d0d543d
--- /dev/null
+++ b/test_tipc/configs/en_table_structure/model_linux_gpu_normal_normal_paddle2onnx_python_linux_cpu.txt
@@ -0,0 +1,17 @@
+2onnx: paddle2onnx
+inference:ppstructure/table/predict_structure.py --table_char_dict_path=./ppocr/utils/dict/table_structure_dict.txt
\ No newline at end of file
diff --git a/test_tipc/configs/layoutxlm_ser/train_linux_gpu_fleet_normal_infer_python_linux_gpu_cpu.txt b/test_tipc/configs/layoutxlm_ser/train_linux_gpu_fleet_normal_infer_python_linux_gpu_cpu.txt
new file mode 100644
index 0000000000000000000000000000000000000000..96b43ceda84376f4d27e134d245229922d667e7e
--- /dev/null
+++ b/test_tipc/configs/layoutxlm_ser/train_linux_gpu_fleet_normal_infer_python_linux_gpu_cpu.txt
@@ -0,0 +1,53 @@
+norm_train:tools/train.py -c test_tipc/configs/layoutxlm_ser/ser_layoutxlm_xfund_zh.yml -o
+norm_export:tools/export_model.py -c test_tipc/configs/layoutxlm_ser/ser_layoutxlm_xfund_zh.yml -o
+inference:ppstructure/kie/predict_kie_token_ser.py --kie_algorithm=LayoutXLM --ser_dict_path=train_data/XFUND/class_list_xfun.txt --output=output
diff --git a/test_tipc/configs/layoutxlm_ser/train_linux_gpu_normal_amp_infer_python_linux_gpu_cpu.txt b/test_tipc/configs/layoutxlm_ser/train_linux_gpu_normal_amp_infer_python_linux_gpu_cpu.txt
new file mode 100644
index 0000000000000000000000000000000000000000..835395784022f4fcefbfff084dcef8a7bc2a146d
--- /dev/null
+++ b/test_tipc/configs/layoutxlm_ser/train_linux_gpu_normal_amp_infer_python_linux_gpu_cpu.txt
@@ -0,0 +1,53 @@
+norm_train:tools/train.py -c test_tipc/configs/layoutxlm_ser/ser_layoutxlm_xfund_zh.yml -o Global.print_batch_step=1 Global.eval_batch_step=[1000,1000] Train.loader.shuffle=false
+norm_export:tools/export_model.py -c test_tipc/configs/layoutxlm_ser/ser_layoutxlm_xfund_zh.yml -o
+inference:ppstructure/kie/predict_kie_token_ser.py --kie_algorithm=LayoutXLM --ser_dict_path=train_data/XFUND/class_list_xfun.txt --output=output
diff --git a/test_tipc/configs/slanet/model_linux_gpu_normal_normal_paddle2onnx_python_linux_cpu.txt b/test_tipc/configs/slanet/model_linux_gpu_normal_normal_paddle2onnx_python_linux_cpu.txt
new file mode 100644
index 0000000000000000000000000000000000000000..45e4e9e858914dd8596cef10625df8160afe45fb
--- /dev/null
+++ b/test_tipc/configs/slanet/model_linux_gpu_normal_normal_paddle2onnx_python_linux_cpu.txt
@@ -0,0 +1,17 @@
+2onnx: paddle2onnx
+inference:ppstructure/table/predict_structure.py --table_char_dict_path=./ppocr/utils/dict/table_structure_dict_ch.txt
\ No newline at end of file
diff --git a/test_tipc/configs/slanet/train_linux_gpu_fleet_normal_infer_python_linux_gpu_cpu.txt b/test_tipc/configs/slanet/train_linux_gpu_fleet_normal_infer_python_linux_gpu_cpu.txt
new file mode 100644
index 0000000000000000000000000000000000000000..4c9d8d654ad286264e8511c788e278cdbcd52ec9
--- /dev/null
+++ b/test_tipc/configs/slanet/train_linux_gpu_fleet_normal_infer_python_linux_gpu_cpu.txt
@@ -0,0 +1,53 @@
+norm_train:tools/train.py -c test_tipc/configs/slanet/SLANet.yml -o
+norm_export:tools/export_model.py -c test_tipc/configs/slanet/SLANet.yml -o
+inference:ppstructure/table/predict_table.py --det_model_dir=./inference/en_ppocr_mobile_v2.0_table_det_infer --rec_model_dir=./inference/en_ppocr_mobile_v2.0_table_rec_infer --rec_char_dict_path=./ppocr/utils/dict/table_dict.txt --table_char_dict_path=./ppocr/utils/dict/table_structure_dict.txt --image_dir=./ppstructure/docs/table/table.jpg --det_limit_side_len=736 --det_limit_type=min --output ./output/table
diff --git a/test_tipc/prepare.sh b/test_tipc/prepare.sh
index ecb1e36bb1bb83c6ee2dcf1cb243e6ee60de5dd8..688deac0f379b50865fe6739529f9301ebcd919b 100644
--- a/test_tipc/prepare.sh
+++ b/test_tipc/prepare.sh
@@ -700,10 +700,18 @@ if [ ${MODE} = "cpp_infer" ];then
wget -nc -P ./inference https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/ch_PP-OCRv3_rec_infer.tar --no-check-certificate
cd ./inference && tar xf ch_PP-OCRv3_det_infer.tar && tar xf ch_PP-OCRv3_rec_infer.tar && tar xf ch_det_data_50.tar && cd ../
elif [[ ${model_name} =~ "en_table_structure" ]];then
- wget -nc -P ./inference/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_structure_infer.tar --no-check-certificate
wget -nc -P ./inference/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_det_infer.tar --no-check-certificate
wget -nc -P ./inference/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_rec_infer.tar --no-check-certificate
- cd ./inference/ && tar xf en_ppocr_mobile_v2.0_table_structure_infer.tar && tar xf en_ppocr_mobile_v2.0_table_det_infer.tar && tar xf en_ppocr_mobile_v2.0_table_rec_infer.tar && cd ../
+ cd ./inference/ && tar xf en_ppocr_mobile_v2.0_table_det_infer.tar && tar xf en_ppocr_mobile_v2.0_table_rec_infer.tar
+ if [ ${model_name} == "en_table_structure" ]; then
+ wget -nc https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_structure_infer.tar --no-check-certificate
+ tar xf en_ppocr_mobile_v2.0_table_structure_infer.tar
+ elif [ ${model_name} == "en_table_structure_PACT" ]; then
+ wget -nc https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_structure_slim_infer.tar --no-check-certificate
+ tar xf en_ppocr_mobile_v2.0_table_structure_slim_infer.tar
+ fi
+ cd ../
elif [[ ${model_name} =~ "slanet" ]];then
wget -nc -P ./inference/ https://paddleocr.bj.bcebos.com/ppstructure/models/slanet/ch_ppstructure_mobile_v2.0_SLANet_infer.tar --no-check-certificate
wget -nc -P ./inference/ https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/ch_PP-OCRv3_det_infer.tar --no-check-certificate
@@ -791,6 +799,12 @@ if [ ${MODE} = "paddle2onnx_infer" ];then
wget -nc -P ./inference https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/ch_PP-OCRv3_det_infer.tar --no-check-certificate
wget -nc -P ./inference https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/ch_PP-OCRv3_rec_infer.tar --no-check-certificate
cd ./inference && tar xf ch_PP-OCRv3_det_infer.tar && tar xf ch_PP-OCRv3_rec_infer.tar && cd ../
+ elif [[ ${model_name} =~ "slanet" ]];then
+ wget -nc -P ./inference/ https://paddleocr.bj.bcebos.com/ppstructure/models/slanet/ch_ppstructure_mobile_v2.0_SLANet_infer.tar --no-check-certificate
+ cd ./inference/ && tar xf ch_ppstructure_mobile_v2.0_SLANet_infer.tar && cd ../
+ elif [[ ${model_name} =~ "en_table_structure" ]];then
+ wget -nc -P ./inference/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_structure_infer.tar --no-check-certificate
+ cd ./inference/ && tar xf en_ppocr_mobile_v2.0_table_structure_infer.tar && cd ../
# wget data
diff --git a/test_tipc/test_paddle2onnx.sh b/test_tipc/test_paddle2onnx.sh
index 04bfb590f7c6e64cf136d3feef8594994cb86877..f035e6bb645a1e7927844232c2bff72f0480e38e 100644
--- a/test_tipc/test_paddle2onnx.sh
+++ b/test_tipc/test_paddle2onnx.sh
@@ -105,6 +105,19 @@ function func_paddle2onnx(){
eval $trans_model_cmd
status_check $last_status "${trans_model_cmd}" "${status_log}" "${model_name}" "${trans_rec_log}"
+ elif [ ${model_name} = "slanet" ] || [ ${model_name} = "en_table_structure" ]; then
+ # trans det
+ set_dirname=$(func_set_params "--model_dir" "${det_infer_model_dir_value}")
+ set_model_filename=$(func_set_params "${model_filename_key}" "${model_filename_value}")
+ set_params_filename=$(func_set_params "${params_filename_key}" "${params_filename_value}")
+ set_save_model=$(func_set_params "--save_file" "${det_save_file_value}")
+ set_opset_version=$(func_set_params "${opset_version_key}" "${opset_version_value}")
+ set_enable_onnx_checker=$(func_set_params "${enable_onnx_checker_key}" "${enable_onnx_checker_value}")
+ trans_det_log="${LOG_PATH}/trans_model_det.log"
+ trans_model_cmd="${padlle2onnx_cmd} ${set_dirname} ${set_model_filename} ${set_params_filename} ${set_save_model} ${set_opset_version} ${set_enable_onnx_checker} --enable_dev_version=True > ${trans_det_log} 2>&1 "
+ eval $trans_model_cmd
+ last_status=${PIPESTATUS[0]}
+ status_check $last_status "${trans_model_cmd}" "${status_log}" "${model_name}" "${trans_det_log}"
# python inference
@@ -117,7 +130,7 @@ function func_paddle2onnx(){
set_det_model_dir=$(func_set_params "${det_model_key}" "${det_save_file_value}")
set_rec_model_dir=$(func_set_params "${rec_model_key}" "${rec_save_file_value}")
infer_model_cmd="${python} ${inference_py} ${set_gpu} ${set_img_dir} ${set_det_model_dir} ${set_rec_model_dir} --use_onnx=True > ${_save_log_path} 2>&1 "
- elif [[ ${model_name} =~ "det" ]]; then
+ elif [[ ${model_name} =~ "det" ]] || [ ${model_name} = "slanet" ] || [ ${model_name} = "en_table_structure" ]; then
set_det_model_dir=$(func_set_params "${det_model_key}" "${det_save_file_value}")
infer_model_cmd="${python} ${inference_py} ${set_gpu} ${set_img_dir} ${set_det_model_dir} --use_onnx=True > ${_save_log_path} 2>&1 "
elif [[ ${model_name} =~ "rec" ]]; then
@@ -136,7 +149,7 @@ function func_paddle2onnx(){
set_det_model_dir=$(func_set_params "${det_model_key}" "${det_save_file_value}")
set_rec_model_dir=$(func_set_params "${rec_model_key}" "${rec_save_file_value}")
infer_model_cmd="${python} ${inference_py} ${set_gpu} ${set_img_dir} ${set_det_model_dir} ${set_rec_model_dir} --use_onnx=True > ${_save_log_path} 2>&1 "
- elif [[ ${model_name} =~ "det" ]]; then
+ elif [[ ${model_name} =~ "det" ]]|| [ ${model_name} = "slanet" ] || [ ${model_name} = "en_table_structure" ]; then
set_det_model_dir=$(func_set_params "${det_model_key}" "${det_save_file_value}")
infer_model_cmd="${python} ${inference_py} ${set_gpu} ${set_img_dir} ${set_det_model_dir} --use_onnx=True > ${_save_log_path} 2>&1 "
elif [[ ${model_name} =~ "rec" ]]; then
diff --git a/tools/infer/predict_det.py b/tools/infer/predict_det.py
index 00fa2e9b7fafd949c59a0eebd43f2f88ae717320..52c225d2b3913cf8c0dc88abcc07f7ccfd3cc914 100755
--- a/tools/infer/predict_det.py
+++ b/tools/infer/predict_det.py
@@ -282,44 +282,67 @@ if __name__ == "__main__":
args = utility.parse_args()
image_file_list = get_image_file_list(args.image_dir)
text_detector = TextDetector(args)
- count = 0
total_time = 0
- draw_img_save = "./inference_results"
+ draw_img_save_dir = args.draw_img_save_dir
+ os.makedirs(draw_img_save_dir, exist_ok=True)
if args.warmup:
img = np.random.uniform(0, 255, [640, 640, 3]).astype(np.uint8)
for i in range(2):
res = text_detector(img)
- if not os.path.exists(draw_img_save):
- os.makedirs(draw_img_save)
save_results = []
- for image_file in image_file_list:
- img, flag, _ = check_and_read(image_file)
- if not flag:
+ for idx, image_file in enumerate(image_file_list):
+ img, flag_gif, flag_pdf = check_and_read(image_file)
+ if not flag_gif and not flag_pdf:
img = cv2.imread(image_file)
- if img is None:
- logger.info("error in loading image:{}".format(image_file))
- continue
- st = time.time()
- dt_boxes, _ = text_detector(img)
- elapse = time.time() - st
- if count > 0:
+ if not flag_pdf:
+ if img is None:
+ logger.debug("error in loading image:{}".format(image_file))
+ continue
+ imgs = [img]
+ else:
+ page_num = args.page_num
+ if page_num > len(img) or page_num == 0:
+ page_num = len(img)
+ imgs = img[:page_num]
+ for index, img in enumerate(imgs):
+ st = time.time()
+ dt_boxes, _ = text_detector(img)
+ elapse = time.time() - st
total_time += elapse
- count += 1
- save_pred = os.path.basename(image_file) + "\t" + str(
- json.dumps([x.tolist() for x in dt_boxes])) + "\n"
- save_results.append(save_pred)
- logger.info(save_pred)
- logger.info("The predict time of {}: {}".format(image_file, elapse))
- src_im = utility.draw_text_det_res(dt_boxes, image_file)
- img_name_pure = os.path.split(image_file)[-1]
- img_path = os.path.join(draw_img_save,
- "det_res_{}".format(img_name_pure))
- cv2.imwrite(img_path, src_im)
- logger.info("The visualized image saved in {}".format(img_path))
+ if len(imgs) > 1:
+ save_pred = os.path.basename(image_file) + '_' + str(
+ index) + "\t" + str(
+ json.dumps([x.tolist() for x in dt_boxes])) + "\n"
+ else:
+ save_pred = os.path.basename(image_file) + "\t" + str(
+ json.dumps([x.tolist() for x in dt_boxes])) + "\n"
+ save_results.append(save_pred)
+ logger.info(save_pred)
+ if len(imgs) > 1:
+ logger.info("{}_{} The predict time of {}: {}".format(
+ idx, index, image_file, elapse))
+ else:
+ logger.info("{} The predict time of {}: {}".format(
+ idx, image_file, elapse))
+ src_im = utility.draw_text_det_res(dt_boxes, img)
+ if flag_gif:
+ save_file = image_file[:-3] + "png"
+ elif flag_pdf:
+ save_file = image_file.replace('.pdf',
+ '_' + str(index) + '.png')
+ else:
+ save_file = image_file
+ img_path = os.path.join(
+ draw_img_save_dir,
+ "det_res_{}".format(os.path.basename(save_file)))
+ cv2.imwrite(img_path, src_im)
+ logger.info("The visualized image saved in {}".format(img_path))
- with open(os.path.join(draw_img_save, "det_results.txt"), 'w') as f:
+ with open(os.path.join(draw_img_save_dir, "det_results.txt"), 'w') as f:
if args.benchmark:
diff --git a/tools/infer/predict_system.py b/tools/infer/predict_system.py
index e0f2c41fa2aba23491efee920afbd76db1ec84e0..affd0d1bcd1283be02ead3cd61c01c375b49bdf9 100755
--- a/tools/infer/predict_system.py
+++ b/tools/infer/predict_system.py
@@ -159,50 +159,75 @@ def main(args):
count = 0
for idx, image_file in enumerate(image_file_list):
- img, flag, _ = check_and_read(image_file)
- if not flag:
+ img, flag_gif, flag_pdf = check_and_read(image_file)
+ if not flag_gif and not flag_pdf:
img = cv2.imread(image_file)
- if img is None:
- logger.debug("error in loading image:{}".format(image_file))
- continue
- starttime = time.time()
- dt_boxes, rec_res, time_dict = text_sys(img)
- elapse = time.time() - starttime
- total_time += elapse
- logger.debug(
- str(idx) + " Predict time of %s: %.3fs" % (image_file, elapse))
- for text, score in rec_res:
- logger.debug("{}, {:.3f}".format(text, score))
- res = [{
- "transcription": rec_res[idx][0],
- "points": np.array(dt_boxes[idx]).astype(np.int32).tolist(),
- } for idx in range(len(dt_boxes))]
- save_pred = os.path.basename(image_file) + "\t" + json.dumps(
- res, ensure_ascii=False) + "\n"
- save_results.append(save_pred)
- if is_visualize:
- image = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
- boxes = dt_boxes
- txts = [rec_res[i][0] for i in range(len(rec_res))]
- scores = [rec_res[i][1] for i in range(len(rec_res))]
- draw_img = draw_ocr_box_txt(
- image,
- boxes,
- txts,
- scores,
- drop_score=drop_score,
- font_path=font_path)
- if flag:
- image_file = image_file[:-3] + "png"
- cv2.imwrite(
- os.path.join(draw_img_save_dir, os.path.basename(image_file)),
- draw_img[:, :, ::-1])
- logger.debug("The visualized image saved in {}".format(
- os.path.join(draw_img_save_dir, os.path.basename(image_file))))
+ if not flag_pdf:
+ if img is None:
+ logger.debug("error in loading image:{}".format(image_file))
+ continue
+ imgs = [img]
+ else:
+ page_num = args.page_num
+ if page_num > len(img) or page_num == 0:
+ page_num = len(img)
+ imgs = img[:page_num]
+ for index, img in enumerate(imgs):
+ starttime = time.time()
+ dt_boxes, rec_res, time_dict = text_sys(img)
+ elapse = time.time() - starttime
+ total_time += elapse
+ if len(imgs) > 1:
+ logger.debug(
+ str(idx) + '_' + str(index) + " Predict time of %s: %.3fs"
+ % (image_file, elapse))
+ else:
+ logger.debug(
+ str(idx) + " Predict time of %s: %.3fs" % (image_file,
+ elapse))
+ for text, score in rec_res:
+ logger.debug("{}, {:.3f}".format(text, score))
+ res = [{
+ "transcription": rec_res[i][0],
+ "points": np.array(dt_boxes[i]).astype(np.int32).tolist(),
+ } for i in range(len(dt_boxes))]
+ if len(imgs) > 1:
+ save_pred = os.path.basename(image_file) + '_' + str(
+ index) + "\t" + json.dumps(
+ res, ensure_ascii=False) + "\n"
+ else:
+ save_pred = os.path.basename(image_file) + "\t" + json.dumps(
+ res, ensure_ascii=False) + "\n"
+ save_results.append(save_pred)
+ if is_visualize:
+ image = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
+ boxes = dt_boxes
+ txts = [rec_res[i][0] for i in range(len(rec_res))]
+ scores = [rec_res[i][1] for i in range(len(rec_res))]
+ draw_img = draw_ocr_box_txt(
+ image,
+ boxes,
+ txts,
+ scores,
+ drop_score=drop_score,
+ font_path=font_path)
+ if flag_gif:
+ save_file = image_file[:-3] + "png"
+ elif flag_pdf:
+ save_file = image_file.replace('.pdf',
+ '_' + str(index) + '.png')
+ else:
+ save_file = image_file
+ cv2.imwrite(
+ os.path.join(draw_img_save_dir,
+ os.path.basename(save_file)),
+ draw_img[:, :, ::-1])
+ logger.debug("The visualized image saved in {}".format(
+ os.path.join(draw_img_save_dir, os.path.basename(
+ save_file))))
logger.info("The predict total time is {}".format(time.time() - _st))
if args.benchmark:
diff --git a/tools/infer/utility.py b/tools/infer/utility.py
index b9c9490bdb99f3bee67cb9460a9975b93b0d6366..e555dbec1b314510aaaf6b31f1b35bf60fefa98e 100644
--- a/tools/infer/utility.py
+++ b/tools/infer/utility.py
@@ -45,6 +45,7 @@ def init_args():
# params for text detector
parser.add_argument("--image_dir", type=str)
+ parser.add_argument("--page_num", type=int, default=0)
parser.add_argument("--det_algorithm", type=str, default='DB')
parser.add_argument("--det_model_dir", type=str)
parser.add_argument("--det_limit_side_len", type=float, default=960)
@@ -337,12 +338,11 @@ def draw_e2e_res(dt_boxes, strs, img_path):
return src_im
-def draw_text_det_res(dt_boxes, img_path):
- src_im = cv2.imread(img_path)
+def draw_text_det_res(dt_boxes, img):
for box in dt_boxes:
box = np.array(box).astype(np.int32).reshape(-1, 2)
- cv2.polylines(src_im, [box], True, color=(255, 255, 0), thickness=2)
- return src_im
+ cv2.polylines(img, [box], True, color=(255, 255, 0), thickness=2)
+ return img
def resize_img(img, input_size=600):