> points = box;
+
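+ // Compute the axis-aligned bounding rectangle of the detected quadrilateral.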
+ int x_collect[4] = {box[0][0], box[1][0], box[2][0], box[3][0]};
+ int y_collect[4] = {box[0][1], box[1][1], box[2][1], box[3][1]};
+ int left = int(*std::min_element(x_collect, x_collect + 4));
+ int right = int(*std::max_element(x_collect, x_collect + 4));
+ int top = int(*std::min_element(y_collect, y_collect + 4));
+ int bottom = int(*std::max_element(y_collect, y_collect + 4));
+
+ cv::Mat img_crop;
+ image(cv::Rect(left, top, right - left, bottom - top)).copyTo(img_crop);
+
+ for (int i = 0; i < points.size(); i++) {
+ points[i][0] -= left;
+ points[i][1] -= top;
+ }
+
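+ // Estimate the target crop size from the lengths of the quadrilateral's top and left edges.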
+ int img_crop_width = int(sqrt(pow(points[0][0] - points[1][0], 2) +
+ pow(points[0][1] - points[1][1], 2)));
+ int img_crop_height = int(sqrt(pow(points[0][0] - points[3][0], 2) +
+ pow(points[0][1] - points[3][1], 2)));
+
+ cv::Point2f pts_std[4];
+ pts_std[0] = cv::Point2f(0., 0.);
+ pts_std[1] = cv::Point2f(img_crop_width, 0.);
+ pts_std[2] = cv::Point2f(img_crop_width, img_crop_height);
+ pts_std[3] = cv::Point2f(0.f, img_crop_height);
+
+ cv::Point2f pointsf[4];
+ pointsf[0] = cv::Point2f(points[0][0], points[0][1]);
+ pointsf[1] = cv::Point2f(points[1][0], points[1][1]);
+ pointsf[2] = cv::Point2f(points[2][0], points[2][1]);
+ pointsf[3] = cv::Point2f(points[3][0], points[3][1]);
+
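+ // Warp the source quadrilateral onto an upright rectangle of the estimated size.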
+ cv::Mat M = cv::getPerspectiveTransform(pointsf, pts_std);
+
+ cv::Mat dst_img;
+ cv::warpPerspective(img_crop, dst_img, M,
+ cv::Size(img_crop_width, img_crop_height),
+ cv::BORDER_REPLICATE);
+
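+ // If the rectified crop is much taller than it is wide, rotate it 90 degrees so the text reads horizontally.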
+ if (float(dst_img.rows) >= float(dst_img.cols) * 1.5) {
+ cv::Mat srcCopy = cv::Mat(dst_img.rows, dst_img.cols, dst_img.depth());
+ cv::transpose(dst_img, srcCopy);
+ cv::flip(srcCopy, srcCopy, 0);
+ return srcCopy;
+ } else {
+ return dst_img;
+ }
+}
+
} // namespace PaddleOCR
\ No newline at end of file
diff --git a/deploy/cpp_infer/tools/config.txt b/deploy/cpp_infer/tools/config.txt
deleted file mode 100644
index d4d66d65225bc9d1d4d62f45550db71fb5d8414e..0000000000000000000000000000000000000000
--- a/deploy/cpp_infer/tools/config.txt
+++ /dev/null
@@ -1,31 +0,0 @@
-# model load config
-use_gpu 0
-gpu_id 0
-gpu_mem 4000
-cpu_math_library_num_threads 10
-use_mkldnn 0
-
-# det config
-max_side_len 960
-det_db_thresh 0.3
-det_db_box_thresh 0.5
-det_db_unclip_ratio 1.6
-use_polygon_score 1
-det_model_dir ./inference/ch_ppocr_mobile_v2.0_det_infer/
-
-# cls config
-use_angle_cls 0
-cls_model_dir ./inference/ch_ppocr_mobile_v2.0_cls_infer/
-cls_thresh 0.9
-
-# rec config
-rec_model_dir ./inference/ch_ppocr_mobile_v2.0_rec_infer/
-char_list_file ../../ppocr/utils/ppocr_keys_v1.txt
-
-# show the detection results
-visualize 0
-
-# use_tensorrt
-use_tensorrt 0
-use_fp16 0
-
diff --git a/deploy/cpp_infer/tools/run.sh b/deploy/cpp_infer/tools/run.sh
deleted file mode 100755
index fa61da75e3a71262f539ee348c69fb82ed2574fb..0000000000000000000000000000000000000000
--- a/deploy/cpp_infer/tools/run.sh
+++ /dev/null
@@ -1,2 +0,0 @@
-
-./build/ocr_system ./tools/config.txt ../../doc/imgs/12.jpg
diff --git a/deploy/hubserving/readme.md b/deploy/hubserving/readme.md
index a39ac5a42b905b1efa73c02d7594511c8a7ea103..11b843fec1052c3ad401ca0b7d1cb602401af8f8 100755
--- a/deploy/hubserving/readme.md
+++ b/deploy/hubserving/readme.md
@@ -29,7 +29,8 @@ deploy/hubserving/ocr_system/
### 1. 准备环境
```shell
# 安装paddlehub
-pip3 install paddlehub==1.8.3 --upgrade -i https://pypi.tuna.tsinghua.edu.cn/simple
+# paddlehub 需要 python>3.6.2
+pip3 install paddlehub==2.1.0 --upgrade -i https://pypi.tuna.tsinghua.edu.cn/simple
```
### 2. 下载推理模型
diff --git a/deploy/hubserving/readme_en.md b/deploy/hubserving/readme_en.md
index 7d9a8629ef7d27e84e636f029202602a94d1d3f7..539ad722cae78b8315b87d35f9af6ab81140c5b3 100755
--- a/deploy/hubserving/readme_en.md
+++ b/deploy/hubserving/readme_en.md
@@ -30,7 +30,8 @@ The following steps take the 2-stage series service as an example. If only the d
### 1. Prepare the environment
```shell
# Install paddlehub
-pip3 install paddlehub==1.8.3 --upgrade -i https://pypi.tuna.tsinghua.edu.cn/simple
+# python>3.6.2 is required by paddlehub
+pip3 install paddlehub==2.1.0 --upgrade -i https://pypi.tuna.tsinghua.edu.cn/simple
```
### 2. Download inference model
diff --git a/deploy/slim/quantization/export_model.py b/deploy/slim/quantization/export_model.py
index 100b107a1deb1ce9932c9cefa50659c060f5803e..d94e53034a2bf67b364e6d91f83acfb9e5445b8a 100755
--- a/deploy/slim/quantization/export_model.py
+++ b/deploy/slim/quantization/export_model.py
@@ -37,6 +37,17 @@ from paddleslim.dygraph.quant import QAT
from ppocr.data import build_dataloader
+def export_single_model(quanter, model, infer_shape, save_path, logger):
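+ # save the quantized model as an inference model; the input spec fixes the input to shape [None] + infer_shape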
+ quanter.save_quantized_model(
+ model,
+ save_path,
+ input_spec=[
+ paddle.static.InputSpec(
+ shape=[None] + infer_shape, dtype='float32')
+ ])
+ logger.info('inference QAT model is saved to {}'.format(save_path))
+
+
def main():
############################################################################################################
# 1. quantization configs
@@ -76,14 +87,21 @@ def main():
# for rec algorithm
if hasattr(post_process_class, 'character'):
char_num = len(getattr(post_process_class, 'character'))
- config['Architecture']["Head"]['out_channels'] = char_num
+ if config['Architecture']["algorithm"] in ["Distillation",
+ ]: # distillation model
+ for key in config['Architecture']["Models"]:
+ config['Architecture']["Models"][key]["Head"][
+ 'out_channels'] = char_num
+ else: # base rec model
+ config['Architecture']["Head"]['out_channels'] = char_num
+
model = build_model(config['Architecture'])
# get QAT model
quanter = QAT(config=quant_config)
quanter.quantize(model)
- init_model(config, model, logger)
+ init_model(config, model)
model.eval()
# build metric
@@ -92,25 +110,30 @@ def main():
# build dataloader
valid_dataloader = build_dataloader(config, 'Eval', device, logger)
+ use_srn = config['Architecture']['algorithm'] == "SRN"
+ model_type = config['Architecture']['model_type']
# start eval
- metirc = program.eval(model, valid_dataloader, post_process_class,
- eval_class)
+ metric = program.eval(model, valid_dataloader, post_process_class,
+ eval_class, model_type, use_srn)
+
logger.info('metric eval ***************')
- for k, v in metirc.items():
+ for k, v in metric.items():
logger.info('{}:{}'.format(k, v))
- save_path = '{}/inference'.format(config['Global']['save_inference_dir'])
infer_shape = [3, 32, 100] if config['Architecture'][
'model_type'] != "det" else [3, 640, 640]
- quanter.save_quantized_model(
- model,
- save_path,
- input_spec=[
- paddle.static.InputSpec(
- shape=[None] + infer_shape, dtype='float32')
- ])
- logger.info('inference QAT model is saved to {}'.format(save_path))
+ save_path = config["Global"]["save_inference_dir"]
+
+ arch_config = config["Architecture"]
+ if arch_config["algorithm"] in ["Distillation", ]: # distillation model
+ for idx, name in enumerate(model.model_name_list):
+ sub_model_save_path = os.path.join(save_path, name, "inference")
+ export_single_model(quanter, model.model_list[idx], infer_shape,
+ sub_model_save_path, logger)
+ else:
+ save_path = os.path.join(save_path, "inference")
+ export_single_model(quanter, model, infer_shape, save_path, logger)
if __name__ == "__main__":
diff --git a/deploy/slim/quantization/quant.py b/deploy/slim/quantization/quant.py
index 315e3b4321a544e77795c43d493873fcf46e1930..37aab68a0e88afce54e10fb6248c73684b58d808 100755
--- a/deploy/slim/quantization/quant.py
+++ b/deploy/slim/quantization/quant.py
@@ -109,9 +109,18 @@ def main(config, device, logger, vdl_writer):
# for rec algorithm
if hasattr(post_process_class, 'character'):
char_num = len(getattr(post_process_class, 'character'))
- config['Architecture']["Head"]['out_channels'] = char_num
+ if config['Architecture']["algorithm"] in ["Distillation",
+ ]: # distillation model
+ for key in config['Architecture']["Models"]:
+ config['Architecture']["Models"][key]["Head"][
+ 'out_channels'] = char_num
+ else: # base rec model
+ config['Architecture']["Head"]['out_channels'] = char_num
model = build_model(config['Architecture'])
+ quanter = QAT(config=quant_config, act_preprocess=PACT)
+ quanter.quantize(model)
+
if config['Global']['distributed']:
model = paddle.DataParallel(model)
@@ -132,8 +141,6 @@ def main(config, device, logger, vdl_writer):
logger.info('train dataloader has {} iters, valid dataloader has {} iters'.
format(len(train_dataloader), len(valid_dataloader)))
- quanter = QAT(config=quant_config, act_preprocess=PACT)
- quanter.quantize(model)
# start train
program.train(config, train_dataloader, valid_dataloader, device, model,
diff --git a/doc/doc_ch/detection.md b/doc/doc_ch/detection.md
index 08b309cbc0def059df1f12a180be39d94511a78c..6fc85992c04123a10ad937f2694b513b50a37876 100644
--- a/doc/doc_ch/detection.md
+++ b/doc/doc_ch/detection.md
@@ -18,9 +18,9 @@ PaddleOCR 也提供了数据格式转换脚本,可以将官网 label 转换支
```
# 将官网下载的标签文件转换为 train_icdar2015_label.txt
-python gen_label.py --mode="det" --root_path="icdar_c4_train_imgs/" \
- --input_path="ch4_training_localization_transcription_gt" \
- --output_label="train_icdar2015_label.txt"
+python gen_label.py --mode="det" --root_path="/path/to/icdar_c4_train_imgs/" \
+ --input_path="/path/to/ch4_training_localization_transcription_gt" \
+ --output_label="/path/to/train_icdar2015_label.txt"
```
解压数据集和下载标注文件后,PaddleOCR/train_data/ 有两个文件夹和两个文件,分别是:
diff --git a/doc/doc_ch/inference.md b/doc/doc_ch/inference.md
index 97e3b92bdf19874927d1cc87e5b14ed7197c3184..b9be1e4cb2d1b256a05b82ef5d6db49dfcb2f31f 100755
--- a/doc/doc_ch/inference.md
+++ b/doc/doc_ch/inference.md
@@ -147,12 +147,12 @@ python3 tools/infer/predict_det.py --image_dir="./doc/imgs/00018069.jpg" --det_m
如果输入图片的分辨率比较大,而且想使用更大的分辨率预测,可以设置det_limit_side_len 为想要的值,比如1216:
```
-python3 tools/infer/predict_det.py --image_dir="./doc/imgs/2.jpg" --det_model_dir="./inference/det_db/" --det_limit_type=max --det_limit_side_len=1216
+python3 tools/infer/predict_det.py --image_dir="./doc/imgs/1.jpg" --det_model_dir="./inference/det_db/" --det_limit_type=max --det_limit_side_len=1216
```
如果想使用CPU进行预测,执行命令如下
```
-python3 tools/infer/predict_det.py --image_dir="./doc/imgs/2.jpg" --det_model_dir="./inference/det_db/" --use_gpu=False
+python3 tools/infer/predict_det.py --image_dir="./doc/imgs/1.jpg" --det_model_dir="./inference/det_db/" --use_gpu=False
```
@@ -221,7 +221,7 @@ python3 tools/export_model.py -c configs/det/det_r50_vd_sast_totaltext.yml -o Gl
```
-**SAST文本检测模型推理,需要设置参数`--det_algorithm="SAST"`,同时,还需要增加参数`--det_sast_polygon=True`,**可以执行如下命令:
+SAST文本检测模型推理,需要设置参数`--det_algorithm="SAST"`,同时,还需要增加参数`--det_sast_polygon=True`,可以执行如下命令:
```
python3 tools/infer/predict_det.py --det_algorithm="SAST" --image_dir="./doc/imgs_en/img623.jpg" --det_model_dir="./inference/det_sast_tt/" --det_sast_polygon=True
```
diff --git a/doc/doc_ch/knowledge_distillation.md b/doc/doc_ch/knowledge_distillation.md
new file mode 100644
index 0000000000000000000000000000000000000000..b561f718491011e8dddcd44e66bfd6da62101ba6
--- /dev/null
+++ b/doc/doc_ch/knowledge_distillation.md
@@ -0,0 +1,251 @@
+# 知识蒸馏
+
+
+## 1. 简介
+
+### 1.1 知识蒸馏介绍
+
+近年来,深度神经网络在计算机视觉、自然语言处理等领域被验证是一种极其有效的解决问题的方法。通过构建合适的神经网络,加以训练,最终网络模型的性能指标基本上都会超过传统算法。
+
+在数据量足够大的情况下,通过合理构建网络模型的方式增加其参数量,可以显著改善模型性能,但是这又带来了模型复杂度急剧提升的问题。大模型在实际场景中使用的成本较高。
+
+深度神经网络一般有较多的参数冗余,目前有几种主要的方法对模型进行压缩,减小其参数量。如裁剪、量化、知识蒸馏等,其中知识蒸馏是指使用教师模型(teacher model)去指导学生模型(student model)学习特定任务,保证小模型在参数量不变的情况下,得到比较大的性能提升。
+
+此外,在知识蒸馏任务中,也衍生出了互学习的模型训练方法,论文[Deep Mutual Learning](https://arxiv.org/abs/1706.00384)中指出,使用两个完全相同的模型在训练的过程中互相监督,可以达到比单个模型训练更好的效果。
+
+### 1.2 PaddleOCR知识蒸馏简介
+
+无论是大模型蒸馏小模型,还是小模型之间互相学习,更新参数,它们本质上都是不同模型之间输出或者特征图(feature map)之间的相互监督,区别仅在于:(1) 模型是否需要固定参数;(2) 模型是否需要加载预训练模型。
+
+对于大模型蒸馏小模型的情况,大模型一般需要加载预训练模型并固定参数;对于小模型之间互相蒸馏的情况,小模型一般都不加载预训练模型,参数也都是可学习的状态。
+
+在知识蒸馏任务中,不只有2个模型之间进行蒸馏的情况,多个模型之间互相学习的情况也非常普遍。因此在知识蒸馏代码框架中,也有必要支持该种类别的蒸馏方法。
+
+PaddleOCR中集成了知识蒸馏的算法,具体地,有以下几个主要的特点:
+- 支持任意网络的互相学习,不要求子网络结构完全一致或者具有预训练模型;同时子网络数量也没有任何限制,只需要在配置文件中添加即可。
+- 支持loss函数通过配置文件任意配置,不仅可以使用某种loss,也可以使用多种loss的组合。
+- 支持知识蒸馏训练、预测、评估与导出等所有模型相关的环节,方便使用与部署。
+
+
+通过知识蒸馏,在中英文通用文字识别任务中,不增加任何预测耗时的情况下,可以给模型带来3%以上的精度提升,结合学习率调整策略以及模型结构微调策略,最终提升超过5%。
+
+
+
+## 2. 配置文件解析
+
+在知识蒸馏训练的过程中,数据预处理、优化器、学习率、全局的一些属性没有任何变化。模型结构、损失函数、后处理、指标计算等模块的配置文件需要进行微调。
+
+下面以识别与检测的知识蒸馏配置文件为例,对知识蒸馏的训练与配置进行解析。
+
+### 2.1 识别配置文件解析
+
+配置文件在[rec_chinese_lite_train_distillation_v2.1.yml](../../configs/rec/ch_ppocr_v2.1/rec_chinese_lite_train_distillation_v2.1.yml)。
+
+#### 2.1.1 模型结构
+
+知识蒸馏任务中,模型结构配置如下所示。
+
+```yaml
+Architecture:
+ model_type: &model_type "rec" # 模型类别,rec、det等,每个子网络的模型类别都与此保持一致
+ name: DistillationModel # 结构名称,蒸馏任务中,为DistillationModel,用于构建对应的结构
+ algorithm: Distillation # 算法名称
+ Models: # 模型,包含子网络的配置信息
+ Teacher: # 子网络名称,至少需要包含`pretrained`与`freeze_params`信息,其他的参数为子网络的构造参数
+ pretrained: # 该子网络是否需要加载预训练模型
+ freeze_params: false # 是否需要固定参数
+ return_all_feats: true # 子网络的参数,表示是否需要返回所有的features,如果为False,则只返回最后的输出
+ model_type: *model_type # 模型类别
+ algorithm: CRNN # 子网络的算法名称,该子网络剩余参数均为构造参数,与普通的模型训练配置一致
+ Transform:
+ Backbone:
+ name: MobileNetV1Enhance
+ scale: 0.5
+ Neck:
+ name: SequenceEncoder
+ encoder_type: rnn
+ hidden_size: 64
+ Head:
+ name: CTCHead
+ mid_channels: 96
+ fc_decay: 0.00002
+ Student: # 另外一个子网络,这里给的是DML的蒸馏示例,两个子网络结构相同,均需要学习参数
+ pretrained: # 下面的组网参数同上
+ freeze_params: false
+ return_all_feats: true
+ model_type: *model_type
+ algorithm: CRNN
+ Transform:
+ Backbone:
+ name: MobileNetV1Enhance
+ scale: 0.5
+ Neck:
+ name: SequenceEncoder
+ encoder_type: rnn
+ hidden_size: 64
+ Head:
+ name: CTCHead
+ mid_channels: 96
+ fc_decay: 0.00002
+```
+
+当然,这里如果希望添加更多的子网络进行训练,也可以按照`Student`与`Teacher`的添加方式,在配置文件中添加相应的字段。比如说如果希望有3个模型互相监督,共同训练,那么`Architecture`可以写为如下格式。
+
+```yaml
+Architecture:
+ model_type: &model_type "rec"
+ name: DistillationModel
+ algorithm: Distillation
+ Models:
+ Teacher:
+ pretrained:
+ freeze_params: false
+ return_all_feats: true
+ model_type: *model_type
+ algorithm: CRNN
+ Transform:
+ Backbone:
+ name: MobileNetV1Enhance
+ scale: 0.5
+ Neck:
+ name: SequenceEncoder
+ encoder_type: rnn
+ hidden_size: 64
+ Head:
+ name: CTCHead
+ mid_channels: 96
+ fc_decay: 0.00002
+ Student:
+ pretrained:
+ freeze_params: false
+ return_all_feats: true
+ model_type: *model_type
+ algorithm: CRNN
+ Transform:
+ Backbone:
+ name: MobileNetV1Enhance
+ scale: 0.5
+ Neck:
+ name: SequenceEncoder
+ encoder_type: rnn
+ hidden_size: 64
+ Head:
+ name: CTCHead
+ mid_channels: 96
+ fc_decay: 0.00002
+ Student2: # 知识蒸馏任务中引入的新的子网络,其他部分与上述配置相同
+ pretrained:
+ freeze_params: false
+ return_all_feats: true
+ model_type: *model_type
+ algorithm: CRNN
+ Transform:
+ Backbone:
+ name: MobileNetV1Enhance
+ scale: 0.5
+ Neck:
+ name: SequenceEncoder
+ encoder_type: rnn
+ hidden_size: 64
+ Head:
+ name: CTCHead
+ mid_channels: 96
+ fc_decay: 0.00002
+```
+
+最终该模型训练时,包含3个子网络:`Teacher`, `Student`, `Student2`。
+
+蒸馏模型`DistillationModel`类的具体实现代码可以参考[distillation_model.py](../../ppocr/modeling/architectures/distillation_model.py)。
+
+最终模型`forward`输出为一个字典,key为所有的子网络名称,例如这里为`Student`与`Teacher`,value为对应子网络的输出,可以为`Tensor`(只返回该网络的最后一层)或`dict`(同时返回中间的特征信息)。
+
+在识别任务中,为了添加更多损失函数,保证蒸馏方法的可扩展性,将每个子网络的输出保存为`dict`,其中包含子模块输出。以该识别模型为例,每个子网络的输出结果均为`dict`,key包含`backbone_out`,`neck_out`, `head_out`,`value`为对应模块的tensor,最终对于上述配置文件,`DistillationModel`的输出格式如下。
+
+```json
+{
+ "Teacher": {
+ "backbone_out": tensor,
+ "neck_out": tensor,
+ "head_out": tensor,
+ },
+ "Student": {
+ "backbone_out": tensor,
+ "neck_out": tensor,
+ "head_out": tensor,
+ }
+}
+```
+
+#### 2.1.2 损失函数
+
+知识蒸馏任务中,损失函数配置如下所示。
+
+```yaml
+Loss:
+ name: CombinedLoss # 损失函数名称,基于该名称,构建用于损失函数的类
+ loss_config_list: # 损失函数配置列表,为CombinedLoss的必备参数
+ - DistillationCTCLoss: # 基于蒸馏的CTC损失函数,继承自标准的CTC loss
+ weight: 1.0 # 损失函数的权重,loss_config_list中,每个损失函数的配置都必须包含该字段
+ model_name_list: ["Student", "Teacher"] # 对于蒸馏模型的预测结果,提取这两个子网络的输出,与gt计算CTC loss
+ key: head_out # 取子网络输出dict中,该key对应的tensor
+ - DistillationDMLLoss: # 蒸馏的DML损失函数,继承自标准的DMLLoss
+ weight: 1.0 # 权重
+ act: "softmax" # 激活函数,对输入使用激活函数处理,可以为softmax, sigmoid或者为None,默认为None
+ model_name_pairs: # 用于计算DML loss的子网络名称对,如果希望计算其他子网络的DML loss,可以在列表下面继续填充
+ - ["Student", "Teacher"]
+ key: head_out # 取子网络输出dict中,该key对应的tensor
+ - DistillationDistanceLoss: # 蒸馏的距离损失函数
+ weight: 1.0 # 权重
+ mode: "l2" # 距离计算方法,目前支持l1, l2, smooth_l1
+ model_name_pairs: # 用于计算distance loss的子网络名称对
+ - ["Student", "Teacher"]
+ key: backbone_out # 取子网络输出dict中,该key对应的tensor
+```
+
+上述损失函数中,所有的蒸馏损失函数均继承自标准的损失函数类,主要功能为: 对蒸馏模型的输出进行解析,找到用于计算损失的中间节点(tensor),再使用标准的损失函数类去计算。
+
+以上述配置为例,最终蒸馏训练的损失函数包含下面3个部分。
+
+- `Student`和`Teacher`的最终输出(`head_out`)与gt的CTC loss,权重为1。在这里因为2个子网络都需要更新参数,因此两者都需要计算与gt的loss。
+- `Student`和`Teacher`的最终输出(`head_out`)之间的DML loss,权重为1。
+- `Student`和`Teacher`的骨干网络输出(`backbone_out`)之间的l2 loss,权重为1。
+
+关于`CombinedLoss`更加具体的实现可以参考: [combined_loss.py](../../ppocr/losses/combined_loss.py#L23)。关于`DistillationCTCLoss`等蒸馏损失函数更加具体的实现可以参考[distillation_loss.py](../../ppocr/losses/distillation_loss.py)。
+
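+下面给出一段简化的示意代码(这只是一个假设性的草图,并非 PaddleOCR 中蒸馏损失的实际实现,类名 `DistillLossSketch` 为本文虚构),用于说明蒸馏损失如何从模型输出 dict 中取出指定子网络、指定 `key` 对应的 tensor,再交给标准损失函数计算:
+
+```python
+import paddle.nn as nn
+
+
+class DistillLossSketch(nn.Layer):
+    def __init__(self, model_name_pairs, key, base_loss):
+        super().__init__()
+        self.model_name_pairs = model_name_pairs  # 例如 [["Student", "Teacher"]]
+        self.key = key  # 例如 "head_out"
+        self.base_loss = base_loss  # 任意标准损失函数实例,例如 DMLLoss()
+
+    def forward(self, predicts, batch):
+        # predicts 为 DistillationModel 的输出 dict:{子网络名: {"head_out": tensor, ...}}
+        loss_dict = {}
+        for name1, name2 in self.model_name_pairs:
+            out1 = predicts[name1][self.key]
+            out2 = predicts[name2][self.key]
+            loss_dict['loss_{}_{}'.format(name1, name2)] = self.base_loss(out1, out2)
+        return loss_dict
+```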
+
+#### 2.1.3 后处理
+
+知识蒸馏任务中,后处理配置如下所示。
+
+```yaml
+PostProcess:
+ name: DistillationCTCLabelDecode # 蒸馏任务的CTC解码后处理,继承自标准的CTCLabelDecode类
+ model_name: ["Student", "Teacher"] # 对于蒸馏模型的预测结果,提取这两个子网络的输出,进行解码
+ key: head_out # 取子网络输出dict中,该key对应的tensor
+```
+
+以上述配置为例,最终会同时计算`Student`和`Teacher`两个子网络的CTC解码输出,返回一个`dict`,`key`为子网络名称,`value`为对应子网络的解码结果。
+
+关于`DistillationCTCLabelDecode`更加具体的实现可以参考: [rec_postprocess.py](../../ppocr/postprocess/rec_postprocess.py#L128)
+
+
+#### 2.1.4 指标计算
+
+知识蒸馏任务中,指标计算配置如下所示。
+
+```yaml
+Metric:
+ name: DistillationMetric # 蒸馏任务的指标计算类,对各个子网络分别计算指标
+ base_metric_name: RecMetric # 指标计算的基类,对于模型的输出,会基于该类,计算指标
+ main_indicator: acc # 指标的名称
+ key: "Student" # 选取该子网络的 main_indicator 作为作为保存保存best model的判断标准
+```
+
+以上述配置为例,最终会使用`Student`子网络的acc指标作为保存best model的判断指标,同时,日志中也会打印出所有子网络的acc指标。
+
+关于`DistillationMetric`更加具体的实现可以参考: [distillation_metric.py](../../ppocr/metrics/distillation_metric.py#L24)。
+
+
+### 2.2 检测配置文件解析
+
+* coming soon!
diff --git a/doc/doc_ch/recognition.md b/doc/doc_ch/recognition.md
index 6ce3003c75a638aa6282c03ba6567059d42f3dbc..f803b0bdefe555093bd92322686f6d00cdbe4e8d 100644
--- a/doc/doc_ch/recognition.md
+++ b/doc/doc_ch/recognition.md
@@ -331,6 +331,8 @@ PaddleOCR目前已支持80种(除中文外)语种识别,`configs/rec/multi
```
+意大利文由拉丁字母组成,因此执行完命令后会得到名为 rec_latin_lite_train.yml 的配置文件。
+
2. 手动修改配置文件
您也可以手动修改模版中的以下几个字段:
@@ -376,7 +378,9 @@ PaddleOCR目前已支持80种(除中文外)语种识别,`configs/rec/multi
更多支持语种请参考: [多语言模型](https://github.com/PaddlePaddle/PaddleOCR/blob/release/2.1/doc/doc_ch/multi_languages.md#%E8%AF%AD%E7%A7%8D%E7%BC%A9%E5%86%99)
-多语言模型训练方式与中文模型一致,训练数据集均为100w的合成数据,少量的字体可以在 [百度网盘](https://pan.baidu.com/s/1bS_u207Rm7YbY33wOECKDA) 上下载,提取码:frgi。
+多语言模型训练方式与中文模型一致,训练数据集均为100w的合成数据,少量的字体可以通过下面两种方式下载。
+* [百度网盘](https://pan.baidu.com/s/1bS_u207Rm7YbY33wOECKDA),提取码:frgi。
+* [google drive](https://drive.google.com/file/d/18cSWX7wXSy4G0tbKJ0d9PuIaiwRLHpjA/view)
如您希望在现有模型效果的基础上调优,请参考下列说明修改配置文件:
diff --git a/doc/doc_ch/whl.md b/doc/doc_ch/whl.md
index c341b49a7b12aa10f0f3187bc861306fcae05c29..167ed7b2b8a13706dfe1533265b6d96560265511 100644
--- a/doc/doc_ch/whl.md
+++ b/doc/doc_ch/whl.md
@@ -5,26 +5,32 @@
### 1.1 安装whl包
pip安装
+
```bash
pip install "paddleocr>=2.0.1" # 推荐使用2.0.1+版本
```
本地构建并安装
+
```bash
python3 setup.py bdist_wheel
pip3 install dist/paddleocr-x.x.x-py3-none-any.whl # x.x.x是paddleocr的版本号
```
## 2 使用
+
### 2.1 代码使用
+
paddleocr whl包会自动下载ppocr轻量级模型作为默认模型,可以根据第3节**自定义模型**进行自定义更换。
* 检测+方向分类器+识别全流程
+
```python
from paddleocr import PaddleOCR, draw_ocr
+
# Paddleocr目前支持中英文、英文、法语、德语、韩语、日语,可以通过修改lang参数进行切换
# 参数依次为`ch`, `en`, `french`, `german`, `korean`, `japan`。
-ocr = PaddleOCR(use_angle_cls=True, lang="ch") # need to run only once to download and load model into memory
+ocr = PaddleOCR(use_angle_cls=True, lang="ch") # need to run only once to download and load model into memory
img_path = 'PaddleOCR/doc/imgs/11.jpg'
result = ocr.ocr(img_path, cls=True)
for line in result:
@@ -32,6 +38,7 @@ for line in result:
# 显示结果
from PIL import Image
+
image = Image.open(img_path).convert('RGB')
boxes = [line[0] for line in result]
txts = [line[1][0] for line in result]
@@ -40,31 +47,36 @@ im_show = draw_ocr(image, boxes, txts, scores, font_path='/path/to/PaddleOCR/doc
im_show = Image.fromarray(im_show)
im_show.save('result.jpg')
```
+
结果是一个list,每个item包含了文本框,文字和识别置信度
+
```bash
[[[24.0, 36.0], [304.0, 34.0], [304.0, 72.0], [24.0, 74.0]], ['纯臻营养护发素', 0.964739]]
[[[24.0, 80.0], [172.0, 80.0], [172.0, 104.0], [24.0, 104.0]], ['产品信息/参数', 0.98069626]]
[[[24.0, 109.0], [333.0, 109.0], [333.0, 136.0], [24.0, 136.0]], ['(45元/每公斤,100公斤起订)', 0.9676722]]
......
```
+
结果可视化
-
* 检测+识别
+
```python
from paddleocr import PaddleOCR, draw_ocr
-ocr = PaddleOCR() # need to run only once to download and load model into memory
+
+ocr = PaddleOCR() # need to run only once to download and load model into memory
img_path = 'PaddleOCR/doc/imgs/11.jpg'
-result = ocr.ocr(img_path,cls=False)
+result = ocr.ocr(img_path, cls=False)
for line in result:
print(line)
# 显示结果
from PIL import Image
+
image = Image.open(img_path).convert('RGB')
boxes = [line[0] for line in result]
txts = [line[1][0] for line in result]
@@ -73,38 +85,46 @@ im_show = draw_ocr(image, boxes, txts, scores, font_path='/path/to/PaddleOCR/doc
im_show = Image.fromarray(im_show)
im_show.save('result.jpg')
```
+
结果是一个list,每个item包含了文本框,文字和识别置信度
+
```bash
[[[24.0, 36.0], [304.0, 34.0], [304.0, 72.0], [24.0, 74.0]], ['纯臻营养护发素', 0.964739]]
[[[24.0, 80.0], [172.0, 80.0], [172.0, 104.0], [24.0, 104.0]], ['产品信息/参数', 0.98069626]]
[[[24.0, 109.0], [333.0, 109.0], [333.0, 136.0], [24.0, 136.0]], ['(45元/每公斤,100公斤起订)', 0.9676722]]
......
```
+
结果可视化
-
* 方向分类器+识别
+
```python
from paddleocr import PaddleOCR
-ocr = PaddleOCR(use_angle_cls=True) # need to run only once to download and load model into memory
+
+ocr = PaddleOCR(use_angle_cls=True) # need to run only once to download and load model into memory
img_path = 'PaddleOCR/doc/imgs_words/ch/word_1.jpg'
result = ocr.ocr(img_path, det=False, cls=True)
for line in result:
print(line)
```
+
结果是一个list,每个item只包含识别结果和识别置信度
+
```bash
['韩国小馆', 0.9907421]
```
* 单独执行检测
+
```python
from paddleocr import PaddleOCR, draw_ocr
-ocr = PaddleOCR() # need to run only once to download and load model into memory
+
+ocr = PaddleOCR() # need to run only once to download and load model into memory
img_path = 'PaddleOCR/doc/imgs/11.jpg'
result = ocr.ocr(img_path, rec=False)
for line in result:
@@ -118,13 +138,16 @@ im_show = draw_ocr(image, result, txts=None, scores=None, font_path='/path/to/Pa
im_show = Image.fromarray(im_show)
im_show.save('result.jpg')
```
+
结果是一个list,每个item只包含文本框
+
```bash
[[26.0, 457.0], [137.0, 457.0], [137.0, 477.0], [26.0, 477.0]]
[[25.0, 425.0], [372.0, 425.0], [372.0, 448.0], [25.0, 448.0]]
[[128.0, 397.0], [273.0, 397.0], [273.0, 414.0], [128.0, 414.0]]
......
```
+
结果可视化
@@ -133,29 +156,37 @@ im_show.save('result.jpg')
* 单独执行识别
+
```python
from paddleocr import PaddleOCR
-ocr = PaddleOCR() # need to run only once to download and load model into memory
+
+ocr = PaddleOCR() # need to run only once to download and load model into memory
img_path = 'PaddleOCR/doc/imgs_words/ch/word_1.jpg'
result = ocr.ocr(img_path, det=False)
for line in result:
print(line)
```
+
结果是一个list,每个item只包含识别结果和识别置信度
+
```bash
['韩国小馆', 0.9907421]
```
* 单独执行方向分类器
+
```python
from paddleocr import PaddleOCR
-ocr = PaddleOCR(use_angle_cls=True) # need to run only once to download and load model into memory
+
+ocr = PaddleOCR(use_angle_cls=True) # need to run only once to download and load model into memory
img_path = 'PaddleOCR/doc/imgs_words/ch/word_1.jpg'
result = ocr.ocr(img_path, det=False, rec=False, cls=True)
for line in result:
print(line)
```
+
结果是一个list,每个item只包含分类结果和分类置信度
+
```bash
['0', 0.9999924]
```
@@ -163,15 +194,19 @@ for line in result:
### 2.2 通过命令行使用
查看帮助信息
+
```bash
paddleocr -h
```
* 检测+方向分类器+识别全流程
+
```bash
paddleocr --image_dir PaddleOCR/doc/imgs/11.jpg --use_angle_cls true
```
+
结果是一个list,每个item包含了文本框,文字和识别置信度
+
```bash
[[[24.0, 36.0], [304.0, 34.0], [304.0, 72.0], [24.0, 74.0]], ['纯臻营养护发素', 0.964739]]
[[[24.0, 80.0], [172.0, 80.0], [172.0, 104.0], [24.0, 104.0]], ['产品信息/参数', 0.98069626]]
@@ -180,10 +215,13 @@ paddleocr --image_dir PaddleOCR/doc/imgs/11.jpg --use_angle_cls true
```
* 检测+识别
+
```bash
paddleocr --image_dir PaddleOCR/doc/imgs/11.jpg
```
+
结果是一个list,每个item包含了文本框,文字和识别置信度
+
```bash
[[[24.0, 36.0], [304.0, 34.0], [304.0, 72.0], [24.0, 74.0]], ['纯臻营养护发素', 0.964739]]
[[[24.0, 80.0], [172.0, 80.0], [172.0, 104.0], [24.0, 104.0]], ['产品信息/参数', 0.98069626]]
@@ -192,20 +230,25 @@ paddleocr --image_dir PaddleOCR/doc/imgs/11.jpg
```
* 方向分类器+识别
+
```bash
paddleocr --image_dir PaddleOCR/doc/imgs_words/ch/word_1.jpg --use_angle_cls true --det false
```
结果是一个list,每个item只包含识别结果和识别置信度
+
```bash
['韩国小馆', 0.9907421]
```
* 单独执行检测
+
```bash
paddleocr --image_dir PaddleOCR/doc/imgs/11.jpg --rec false
```
+
结果是一个list,每个item只包含文本框
+
```bash
[[26.0, 457.0], [137.0, 457.0], [137.0, 477.0], [26.0, 477.0]]
[[25.0, 425.0], [372.0, 425.0], [372.0, 448.0], [25.0, 448.0]]
@@ -214,34 +257,42 @@ paddleocr --image_dir PaddleOCR/doc/imgs/11.jpg --rec false
```
* 单独执行识别
+
```bash
paddleocr --image_dir PaddleOCR/doc/imgs_words/ch/word_1.jpg --det false
```
结果是一个list,每个item只包含识别结果和识别置信度
+
```bash
['韩国小馆', 0.9907421]
```
* 单独执行方向分类器
+
```bash
paddleocr --image_dir PaddleOCR/doc/imgs_words/ch/word_1.jpg --use_angle_cls true --det false --rec false
```
结果是一个list,每个item只包含分类结果和分类置信度
+
```bash
['0', 0.9999924]
```
## 3 自定义模型
-当内置模型无法满足需求时,需要使用到自己训练的模型。
-首先,参照[inference.md](./inference.md) 第一节转换将检测、分类和识别模型转换为inference模型,然后按照如下方式使用
+
+当内置模型无法满足需求时,需要使用自己训练的模型。首先,参照[inference.md](./inference.md) 第一节,将检测、分类和识别模型转换为inference模型,然后按照如下方式使用
### 3.1 代码使用
+
```python
from paddleocr import PaddleOCR, draw_ocr
+
# 模型路径下必须含有model和params文件
-ocr = PaddleOCR(det_model_dir='{your_det_model_dir}', rec_model_dir='{your_rec_model_dir}', rec_char_dict_path='{your_rec_char_dict_path}', cls_model_dir='{your_cls_model_dir}', use_angle_cls=True)
+ocr = PaddleOCR(det_model_dir='{your_det_model_dir}', rec_model_dir='{your_rec_model_dir}',
+ rec_char_dict_path='{your_rec_char_dict_path}', cls_model_dir='{your_cls_model_dir}',
+ use_angle_cls=True)
img_path = 'PaddleOCR/doc/imgs/11.jpg'
result = ocr.ocr(img_path, cls=True)
for line in result:
@@ -249,6 +300,7 @@ for line in result:
# 显示结果
from PIL import Image
+
image = Image.open(img_path).convert('RGB')
boxes = [line[0] for line in result]
txts = [line[1][0] for line in result]
@@ -269,11 +321,13 @@ paddleocr --image_dir PaddleOCR/doc/imgs/11.jpg --det_model_dir {your_det_model_
### 4.1 网络图片
- 代码使用
+
```python
-from paddleocr import PaddleOCR, draw_ocr
+from paddleocr import PaddleOCR, draw_ocr, download_with_progressbar
+
# Paddleocr目前支持中英文、英文、法语、德语、韩语、日语,可以通过修改lang参数进行切换
# 参数依次为`ch`, `en`, `french`, `german`, `korean`, `japan`。
-ocr = PaddleOCR(use_angle_cls=True, lang="ch") # need to run only once to download and load model into memory
+ocr = PaddleOCR(use_angle_cls=True, lang="ch") # need to run only once to download and load model into memory
img_path = 'http://n.sinaimg.cn/ent/transform/w630h933/20171222/o111-fypvuqf1838418.jpg'
result = ocr.ocr(img_path, cls=True)
for line in result:
@@ -281,7 +335,9 @@ for line in result:
# 显示结果
from PIL import Image
-image = Image.open(img_path).convert('RGB')
+
+download_with_progressbar(img_path, 'tmp.jpg')
+image = Image.open('tmp.jpg').convert('RGB')
boxes = [line[0] for line in result]
txts = [line[1][0] for line in result]
scores = [line[1][1] for line in result]
@@ -289,18 +345,24 @@ im_show = draw_ocr(image, boxes, txts, scores, font_path='/path/to/PaddleOCR/doc
im_show = Image.fromarray(im_show)
im_show.save('result.jpg')
```
+
- 命令行模式
+
```bash
paddleocr --image_dir http://n.sinaimg.cn/ent/transform/w630h933/20171222/o111-fypvuqf1838418.jpg --use_angle_cls=true
```
### 4.2 numpy数组
+
仅通过代码使用时支持numpy数组作为输入
+
```python
+import cv2
from paddleocr import PaddleOCR, draw_ocr
+
# Paddleocr目前支持中英文、英文、法语、德语、韩语、日语,可以通过修改lang参数进行切换
# 参数依次为`ch`, `en`, `french`, `german`, `korean`, `japan`。
-ocr = PaddleOCR(use_angle_cls=True, lang="ch") # need to run only once to download and load model into memory
+ocr = PaddleOCR(use_angle_cls=True, lang="ch") # need to run only once to download and load model into memory
img_path = 'PaddleOCR/doc/imgs/11.jpg'
img = cv2.imread(img_path)
# img = cv2.cvtColor(img,cv2.COLOR_BGR2GRAY), 如果你自己训练的模型支持灰度图,可以将这句话的注释取消
@@ -310,6 +372,7 @@ for line in result:
# 显示结果
from PIL import Image
+
image = Image.open(img_path).convert('RGB')
boxes = [line[0] for line in result]
txts = [line[1][0] for line in result]
@@ -355,3 +418,5 @@ im_show.save('result.jpg')
| det | 前向时使用启动检测 | TRUE |
| rec | 前向时是否启动识别 | TRUE |
| cls | 前向时是否启动分类 (命令行模式下使用use_angle_cls控制前向是否启动分类) | FALSE |
+| show_log | 是否打印det和rec等信息 | FALSE |
+| type | 执行ocr或者表格结构化, 值可选['ocr','structure'] | ocr |
diff --git a/doc/doc_en/inference_en.md b/doc/doc_en/inference_en.md
index a78bb6b6fef98f8d025b8f21a9ecfe67bb2af007..e30355fb8e29031bd4ce040a86ad0f57d18ce398 100755
--- a/doc/doc_en/inference_en.md
+++ b/doc/doc_en/inference_en.md
@@ -154,12 +154,12 @@ Set as `limit_type='min', det_limit_side_len=960`, it means that the shortest si
If the resolution of the input picture is relatively large and you want to use a larger resolution prediction, you can set det_limit_side_len to the desired value, such as 1216:
```
-python3 tools/infer/predict_det.py --image_dir="./doc/imgs/22.jpg" --det_model_dir="./inference/det_db/" --det_limit_type=max --det_limit_side_len=1216
+python3 tools/infer/predict_det.py --image_dir="./doc/imgs/1.jpg" --det_model_dir="./inference/det_db/" --det_limit_type=max --det_limit_side_len=1216
```
If you want to use the CPU for prediction, execute the command as follows
```
-python3 tools/infer/predict_det.py --image_dir="./doc/imgs/22.jpg" --det_model_dir="./inference/det_db/" --use_gpu=False
+python3 tools/infer/predict_det.py --image_dir="./doc/imgs/1.jpg" --det_model_dir="./inference/det_db/" --use_gpu=False
```
@@ -230,7 +230,7 @@ First, convert the model saved in the SAST text detection training process into
python3 tools/export_model.py -c configs/det/det_r50_vd_sast_totaltext.yml -o Global.pretrained_model=./det_r50_vd_sast_totaltext_v2.0_train/best_accuracy Global.save_inference_dir=./inference/det_sast_tt
```
-**For SAST curved text detection model inference, you need to set the parameter `--det_algorithm="SAST"` and `--det_sast_polygon=True`**, run the following command:
+For SAST curved text detection model inference, you need to set the parameter `--det_algorithm="SAST"` and `--det_sast_polygon=True`, run the following command:
```
python3 tools/infer/predict_det.py --det_algorithm="SAST" --image_dir="./doc/imgs_en/img623.jpg" --det_model_dir="./inference/det_sast_tt/" --det_sast_polygon=True
diff --git a/doc/doc_en/recognition_en.md b/doc/doc_en/recognition_en.md
index 7a5e827deea972163eef3adce351dd40048497be..e884651756bdb49bc52e929c2a6e8b6b6c81966f 100644
--- a/doc/doc_en/recognition_en.md
+++ b/doc/doc_en/recognition_en.md
@@ -329,6 +329,7 @@ There are two ways to create the required configuration file::
...
```
+Italian is written in the Latin alphabet, so after executing the command you will get a configuration file named rec_latin_lite_train.yml.
2. Manually modify the configuration file
@@ -375,7 +376,9 @@ Currently, the multi-language algorithms supported by PaddleOCR are:
For more supported languages, please refer to : [Multi-language model](https://github.com/PaddlePaddle/PaddleOCR/blob/release/2.1/doc/doc_en/multi_languages_en.md#4-support-languages-and-abbreviations)
-The multi-language model training method is the same as the Chinese model. The training data set is 100w synthetic data. A small amount of fonts and test data can be downloaded on [Baidu Netdisk](https://pan.baidu.com/s/1bS_u207Rm7YbY33wOECKDA),Extraction code:frgi.
+The multi-language models are trained in the same way as the Chinese model. The training set consists of 1 million synthetic samples. A small set of fonts and test data can be downloaded from either of the following sources.
+* [Baidu Netdisk](https://pan.baidu.com/s/1bS_u207Rm7YbY33wOECKDA), extraction code: frgi.
+* [Google Drive](https://drive.google.com/file/d/18cSWX7wXSy4G0tbKJ0d9PuIaiwRLHpjA/view)
If you want to finetune on the basis of the existing model effect, please refer to the following instructions to modify the configuration file:
diff --git a/doc/doc_en/update_en.md b/doc/doc_en/update_en.md
index 1e80012e0608f0e28291d0f57b5a5d0beffe2e8c..ca2ecb0535ce27bc7f98a476752a131f869761d5 100644
--- a/doc/doc_en/update_en.md
+++ b/doc/doc_en/update_en.md
@@ -15,8 +15,6 @@
- 2020.6.8 Add [datasets](./datasets_en.md) and keep updating
- 2020.6.5 Support exporting `attention` model to `inference_model`
- 2020.6.5 Support separate prediction and recognition, output result score
-- 2020.6.5 Support exporting `attention` model to `inference_model`
-- 2020.6.5 Support separate prediction and recognition, output result score
- 2020.5.30 Provide Lightweight Chinese OCR online experience
- 2020.5.30 Model prediction and training support on Windows system
- 2020.5.30 Open source general Chinese OCR model
diff --git a/doc/doc_en/whl_en.md b/doc/doc_en/whl_en.md
index eeaf1347dc77a24f158ba8ba2c6f013b1fd89b81..c8c8353accdf7f6ce179d3700547bfe9bd70c200 100644
--- a/doc/doc_en/whl_en.md
+++ b/doc/doc_en/whl_en.md
@@ -305,7 +305,8 @@ paddleocr --image_dir http://n.sinaimg.cn/ent/transform/w630h933/20171222/o111-f
Support numpy array as input only when used by code
```python
-from paddleocr import PaddleOCR, draw_ocr
+import cv2
+from paddleocr import PaddleOCR, draw_ocr, download_with_progressbar
ocr = PaddleOCR(use_angle_cls=True, lang="ch") # need to run only once to download and load model into memory
img_path = 'PaddleOCR/doc/imgs/11.jpg'
img = cv2.imread(img_path)
@@ -316,7 +317,9 @@ for line in result:
# show result
from PIL import Image
-image = Image.open(img_path).convert('RGB')
+
+download_with_progressbar(img_path, 'tmp.jpg')
+image = Image.open('tmp.jpg').convert('RGB')
boxes = [line[0] for line in result]
txts = [line[1][0] for line in result]
scores = [line[1][1] for line in result]
@@ -362,3 +365,5 @@ im_show.save('result.jpg')
| det | Enable detction when `ppocr.ocr` func exec | TRUE |
| rec | Enable recognition when `ppocr.ocr` func exec | TRUE |
| cls | Enable classification when `ppocr.ocr` func exec((Use use_angle_cls in command line mode to control whether to start classification in the forward direction) | FALSE |
+| show_log | Whether to print logs during det and rec | FALSE |
+| type | Perform OCR or table structure analysis; the value can be 'ocr' or 'structure' | ocr |
\ No newline at end of file
diff --git a/doc/joinus.PNG b/doc/joinus.PNG
index 4a274e631d8516789fca47e2db66bc0ac2d0f223..1228ce0a4ddd549b9ddfe00090675d9bd7e3cb6b 100644
Binary files a/doc/joinus.PNG and b/doc/joinus.PNG differ
diff --git a/doc/table/1.png b/doc/table/1.png
new file mode 100644
index 0000000000000000000000000000000000000000..47df618ab1bef431a5dd94418c01be16b09d31aa
Binary files /dev/null and b/doc/table/1.png differ
diff --git a/doc/table/layout.jpg b/doc/table/layout.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..db7246b314556d73cd49d049b9b480887b6ef994
Binary files /dev/null and b/doc/table/layout.jpg differ
diff --git a/doc/table/paper-image.jpg b/doc/table/paper-image.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..db7246b314556d73cd49d049b9b480887b6ef994
Binary files /dev/null and b/doc/table/paper-image.jpg differ
diff --git a/doc/table/pipeline.jpg b/doc/table/pipeline.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..8cea262149199e010450b69d3323b9b06e40c773
Binary files /dev/null and b/doc/table/pipeline.jpg differ
diff --git a/doc/table/pipeline_en.jpg b/doc/table/pipeline_en.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..2e4d1a03546308ff79f4dfb6b67e8e83420951c5
Binary files /dev/null and b/doc/table/pipeline_en.jpg differ
diff --git a/doc/table/ppstructure.GIF b/doc/table/ppstructure.GIF
new file mode 100644
index 0000000000000000000000000000000000000000..bff836e3ea53d447c948309de56ac5d2ad553264
Binary files /dev/null and b/doc/table/ppstructure.GIF differ
diff --git a/doc/table/result_all.jpg b/doc/table/result_all.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..3bffd40ed8821bb5259337bc7651cde51c5a7861
Binary files /dev/null and b/doc/table/result_all.jpg differ
diff --git a/doc/table/result_text.jpg b/doc/table/result_text.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..5f164a1ab10d3f3e63c761e6e059c2905ba9c878
Binary files /dev/null and b/doc/table/result_text.jpg differ
diff --git a/doc/table/table.jpg b/doc/table/table.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..3daa619e52dc2471df62ea7767be3bff350b623f
Binary files /dev/null and b/doc/table/table.jpg differ
diff --git a/doc/table/tableocr_pipeline.jpg b/doc/table/tableocr_pipeline.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..da868791b16af3b56cb07c86f18e25b45c6f5b47
Binary files /dev/null and b/doc/table/tableocr_pipeline.jpg differ
diff --git a/doc/table/tableocr_pipeline_en.jpg b/doc/table/tableocr_pipeline_en.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..cedc9bd5ca06147d6c4d22e709418fc7081d940e
Binary files /dev/null and b/doc/table/tableocr_pipeline_en.jpg differ
diff --git a/paddleocr.py b/paddleocr.py
index 1e4d94ff4e72da951e1ffb92edb50715482581ae..c52737f55b61cd29c08367adb6d7e05c561e933e 100644
--- a/paddleocr.py
+++ b/paddleocr.py
@@ -19,154 +19,119 @@ __dir__ = os.path.dirname(__file__)
sys.path.append(os.path.join(__dir__, ''))
import cv2
+import logging
import numpy as np
from pathlib import Path
-import tarfile
-import requests
-from tqdm import tqdm
from tools.infer import predict_system
from ppocr.utils.logging import get_logger
logger = get_logger()
from ppocr.utils.utility import check_and_read_gif, get_image_file_list
-from tools.infer.utility import draw_ocr, init_args, str2bool
+from ppocr.utils.network import maybe_download, download_with_progressbar, is_link, confirm_model_dir_url
+from tools.infer.utility import draw_ocr, str2bool
+from ppstructure.utility import init_args, draw_structure_result
+from ppstructure.predict_system import OCRSystem, save_structure_res
-__all__ = ['PaddleOCR']
+__all__ = ['PaddleOCR', 'PPStructure', 'draw_ocr', 'draw_structure_result', 'save_structure_res','download_with_progressbar']
model_urls = {
'det': {
'ch':
- 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_det_infer.tar',
+ 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_det_infer.tar',
'en':
- 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/en_ppocr_mobile_v2.0_det_infer.tar'
+ 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/en_ppocr_mobile_v2.0_det_infer.tar',
+ 'structure': 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_det_infer.tar'
},
'rec': {
'ch': {
'url':
- 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_rec_infer.tar',
+ 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_rec_infer.tar',
'dict_path': './ppocr/utils/ppocr_keys_v1.txt'
},
'en': {
'url':
- 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/en_number_mobile_v2.0_rec_infer.tar',
+ 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/en_number_mobile_v2.0_rec_infer.tar',
'dict_path': './ppocr/utils/en_dict.txt'
},
'french': {
'url':
- 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/french_mobile_v2.0_rec_infer.tar',
+ 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/french_mobile_v2.0_rec_infer.tar',
'dict_path': './ppocr/utils/dict/french_dict.txt'
},
'german': {
'url':
- 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/german_mobile_v2.0_rec_infer.tar',
+ 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/german_mobile_v2.0_rec_infer.tar',
'dict_path': './ppocr/utils/dict/german_dict.txt'
},
'korean': {
'url':
- 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/korean_mobile_v2.0_rec_infer.tar',
+ 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/korean_mobile_v2.0_rec_infer.tar',
'dict_path': './ppocr/utils/dict/korean_dict.txt'
},
'japan': {
'url':
- 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/japan_mobile_v2.0_rec_infer.tar',
+ 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/japan_mobile_v2.0_rec_infer.tar',
'dict_path': './ppocr/utils/dict/japan_dict.txt'
},
'chinese_cht': {
'url':
- 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/chinese_cht_mobile_v2.0_rec_infer.tar',
+ 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/chinese_cht_mobile_v2.0_rec_infer.tar',
'dict_path': './ppocr/utils/dict/chinese_cht_dict.txt'
},
'ta': {
'url':
- 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/ta_mobile_v2.0_rec_infer.tar',
+ 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/ta_mobile_v2.0_rec_infer.tar',
'dict_path': './ppocr/utils/dict/ta_dict.txt'
},
'te': {
'url':
- 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/te_mobile_v2.0_rec_infer.tar',
+ 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/te_mobile_v2.0_rec_infer.tar',
'dict_path': './ppocr/utils/dict/te_dict.txt'
},
'ka': {
'url':
- 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/ka_mobile_v2.0_rec_infer.tar',
+ 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/ka_mobile_v2.0_rec_infer.tar',
'dict_path': './ppocr/utils/dict/ka_dict.txt'
},
'latin': {
'url':
- 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/latin_ppocr_mobile_v2.0_rec_infer.tar',
+ 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/latin_ppocr_mobile_v2.0_rec_infer.tar',
'dict_path': './ppocr/utils/dict/latin_dict.txt'
},
'arabic': {
'url':
- 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/arabic_ppocr_mobile_v2.0_rec_infer.tar',
+ 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/arabic_ppocr_mobile_v2.0_rec_infer.tar',
'dict_path': './ppocr/utils/dict/arabic_dict.txt'
},
'cyrillic': {
'url':
- 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/cyrillic_ppocr_mobile_v2.0_rec_infer.tar',
+ 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/cyrillic_ppocr_mobile_v2.0_rec_infer.tar',
'dict_path': './ppocr/utils/dict/cyrillic_dict.txt'
},
'devanagari': {
'url':
- 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/devanagari_ppocr_mobile_v2.0_rec_infer.tar',
+ 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/devanagari_ppocr_mobile_v2.0_rec_infer.tar',
'dict_path': './ppocr/utils/dict/devanagari_dict.txt'
+ },
+ 'structure': {
+ 'url': 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_rec_infer.tar',
+ 'dict_path': 'ppocr/utils/dict/table_dict.txt'
}
},
- 'cls':
- 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_cls_infer.tar'
+ 'cls': 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_cls_infer.tar',
+ 'table': {
+ 'url': 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_structure_infer.tar',
+ 'dict_path': 'ppocr/utils/dict/table_structure_dict.txt'
+ }
}
SUPPORT_DET_MODEL = ['DB']
-VERSION = '2.1'
+VERSION = '2.2'
SUPPORT_REC_MODEL = ['CRNN']
BASE_DIR = os.path.expanduser("~/.paddleocr/")
-def download_with_progressbar(url, save_path):
- response = requests.get(url, stream=True)
- total_size_in_bytes = int(response.headers.get('content-length', 0))
- block_size = 1024 # 1 Kibibyte
- progress_bar = tqdm(total=total_size_in_bytes, unit='iB', unit_scale=True)
- with open(save_path, 'wb') as file:
- for data in response.iter_content(block_size):
- progress_bar.update(len(data))
- file.write(data)
- progress_bar.close()
- if total_size_in_bytes == 0 or progress_bar.n != total_size_in_bytes:
- logger.error("Something went wrong while downloading models")
- sys.exit(0)
-
-
-def maybe_download(model_storage_directory, url):
- # using custom model
- tar_file_name_list = [
- 'inference.pdiparams', 'inference.pdiparams.info', 'inference.pdmodel'
- ]
- if not os.path.exists(
- os.path.join(model_storage_directory, 'inference.pdiparams')
- ) or not os.path.exists(
- os.path.join(model_storage_directory, 'inference.pdmodel')):
- tmp_path = os.path.join(model_storage_directory, url.split('/')[-1])
- print('download {} to {}'.format(url, tmp_path))
- os.makedirs(model_storage_directory, exist_ok=True)
- download_with_progressbar(url, tmp_path)
- with tarfile.open(tmp_path, 'r') as tarObj:
- for member in tarObj.getmembers():
- filename = None
- for tar_file_name in tar_file_name_list:
- if tar_file_name in member.name:
- filename = tar_file_name
- if filename is None:
- continue
- file = tarObj.extractfile(member)
- with open(
- os.path.join(model_storage_directory, filename),
- 'wb') as f:
- f.write(file.read())
- os.remove(tmp_path)
-
-
def parse_args(mMain=True):
import argparse
parser = init_args()
@@ -174,9 +139,10 @@ def parse_args(mMain=True):
parser.add_argument("--lang", type=str, default='ch')
parser.add_argument("--det", type=str2bool, default=True)
parser.add_argument("--rec", type=str2bool, default=True)
+ parser.add_argument("--type", type=str, default='ocr')
for action in parser._actions:
- if action.dest == 'rec_char_dict_path':
+ if action.dest in ['rec_char_dict_path', 'table_char_dict_path']:
action.default = None
if mMain:
return parser.parse_args()
@@ -187,6 +153,42 @@ def parse_args(mMain=True):
return argparse.Namespace(**inference_args_dict)
+def parse_lang(lang):
+ latin_lang = [
+ 'af', 'az', 'bs', 'cs', 'cy', 'da', 'de', 'es', 'et', 'fr', 'ga',
+ 'hr', 'hu', 'id', 'is', 'it', 'ku', 'la', 'lt', 'lv', 'mi', 'ms',
+ 'mt', 'nl', 'no', 'oc', 'pi', 'pl', 'pt', 'ro', 'rs_latin', 'sk',
+ 'sl', 'sq', 'sv', 'sw', 'tl', 'tr', 'uz', 'vi'
+ ]
+ arabic_lang = ['ar', 'fa', 'ug', 'ur']
+ cyrillic_lang = [
+ 'ru', 'rs_cyrillic', 'be', 'bg', 'uk', 'mn', 'abq', 'ady', 'kbd',
+ 'ava', 'dar', 'inh', 'che', 'lbe', 'lez', 'tab'
+ ]
+ devanagari_lang = [
+ 'hi', 'mr', 'ne', 'bh', 'mai', 'ang', 'bho', 'mah', 'sck', 'new',
+ 'gom', 'sa', 'bgc'
+ ]
+ if lang in latin_lang:
+ lang = "latin"
+ elif lang in arabic_lang:
+ lang = "arabic"
+ elif lang in cyrillic_lang:
+ lang = "cyrillic"
+ elif lang in devanagari_lang:
+ lang = "devanagari"
+ assert lang in model_urls[
+ 'rec'], 'param lang must in {}, but got {}'.format(
+ model_urls['rec'].keys(), lang)
+ if lang == "ch":
+ det_lang = "ch"
+ elif lang == 'structure':
+ det_lang = 'structure'
+ else:
+ det_lang = "en"
+ return lang, det_lang
+
+
class PaddleOCR(predict_system.TextSystem):
def __init__(self, **kwargs):
"""
@@ -194,75 +196,41 @@ class PaddleOCR(predict_system.TextSystem):
args:
**kwargs: other params show in paddleocr --help
"""
- postprocess_params = parse_args(mMain=False)
- postprocess_params.__dict__.update(**kwargs)
- self.use_angle_cls = postprocess_params.use_angle_cls
- lang = postprocess_params.lang
- latin_lang = [
- 'af', 'az', 'bs', 'cs', 'cy', 'da', 'de', 'es', 'et', 'fr', 'ga',
- 'hr', 'hu', 'id', 'is', 'it', 'ku', 'la', 'lt', 'lv', 'mi', 'ms',
- 'mt', 'nl', 'no', 'oc', 'pi', 'pl', 'pt', 'ro', 'rs_latin', 'sk',
- 'sl', 'sq', 'sv', 'sw', 'tl', 'tr', 'uz', 'vi'
- ]
- arabic_lang = ['ar', 'fa', 'ug', 'ur']
- cyrillic_lang = [
- 'ru', 'rs_cyrillic', 'be', 'bg', 'uk', 'mn', 'abq', 'ady', 'kbd',
- 'ava', 'dar', 'inh', 'che', 'lbe', 'lez', 'tab'
- ]
- devanagari_lang = [
- 'hi', 'mr', 'ne', 'bh', 'mai', 'ang', 'bho', 'mah', 'sck', 'new',
- 'gom', 'sa', 'bgc'
- ]
- if lang in latin_lang:
- lang = "latin"
- elif lang in arabic_lang:
- lang = "arabic"
- elif lang in cyrillic_lang:
- lang = "cyrillic"
- elif lang in devanagari_lang:
- lang = "devanagari"
- assert lang in model_urls[
- 'rec'], 'param lang must in {}, but got {}'.format(
- model_urls['rec'].keys(), lang)
- if lang == "ch":
- det_lang = "ch"
- else:
- det_lang = "en"
- use_inner_dict = False
- if postprocess_params.rec_char_dict_path is None:
- use_inner_dict = True
- postprocess_params.rec_char_dict_path = model_urls['rec'][lang][
- 'dict_path']
+ params = parse_args(mMain=False)
+ params.__dict__.update(**kwargs)
+ if not params.show_log:
+ logger.setLevel(logging.INFO)
+ self.use_angle_cls = params.use_angle_cls
+ lang, det_lang = parse_lang(params.lang)
# init model dir
- if postprocess_params.det_model_dir is None:
- postprocess_params.det_model_dir = os.path.join(BASE_DIR, VERSION,
- 'det', det_lang)
- if postprocess_params.rec_model_dir is None:
- postprocess_params.rec_model_dir = os.path.join(BASE_DIR, VERSION,
- 'rec', lang)
- if postprocess_params.cls_model_dir is None:
- postprocess_params.cls_model_dir = os.path.join(BASE_DIR, 'cls')
- print(postprocess_params)
+ params.det_model_dir, det_url = confirm_model_dir_url(params.det_model_dir,
+ os.path.join(BASE_DIR, VERSION, 'ocr', 'det', det_lang),
+ model_urls['det'][det_lang])
+ params.rec_model_dir, rec_url = confirm_model_dir_url(params.rec_model_dir,
+ os.path.join(BASE_DIR, VERSION, 'ocr', 'rec', lang),
+ model_urls['rec'][lang]['url'])
+ params.cls_model_dir, cls_url = confirm_model_dir_url(params.cls_model_dir,
+ os.path.join(BASE_DIR, VERSION, 'ocr', 'cls'),
+ model_urls['cls'])
# download model
- maybe_download(postprocess_params.det_model_dir,
- model_urls['det'][det_lang])
- maybe_download(postprocess_params.rec_model_dir,
- model_urls['rec'][lang]['url'])
- maybe_download(postprocess_params.cls_model_dir, model_urls['cls'])
+ maybe_download(params.det_model_dir, det_url)
+ maybe_download(params.rec_model_dir, rec_url)
+ maybe_download(params.cls_model_dir, cls_url)
- if postprocess_params.det_algorithm not in SUPPORT_DET_MODEL:
+ if params.det_algorithm not in SUPPORT_DET_MODEL:
logger.error('det_algorithm must in {}'.format(SUPPORT_DET_MODEL))
sys.exit(0)
- if postprocess_params.rec_algorithm not in SUPPORT_REC_MODEL:
+ if params.rec_algorithm not in SUPPORT_REC_MODEL:
logger.error('rec_algorithm must in {}'.format(SUPPORT_REC_MODEL))
sys.exit(0)
- if use_inner_dict:
- postprocess_params.rec_char_dict_path = str(
- Path(__file__).parent / postprocess_params.rec_char_dict_path)
+ if params.rec_char_dict_path is None:
+ params.rec_char_dict_path = str(Path(__file__).parent / model_urls['rec'][lang]['dict_path'])
+
+ print(params)
# init det_model and rec_model
- super().__init__(postprocess_params)
+ super().__init__(params)
def ocr(self, img, det=True, rec=True, cls=True):
"""
@@ -316,11 +284,64 @@ class PaddleOCR(predict_system.TextSystem):
return rec_res
+class PPStructure(OCRSystem):
+ def __init__(self, **kwargs):
+ params = parse_args(mMain=False)
+ params.__dict__.update(**kwargs)
+ if not params.show_log:
+ logger.setLevel(logging.INFO)
+ lang, det_lang = parse_lang(params.lang)
+
+ # init model dir
+ params.det_model_dir, det_url = confirm_model_dir_url(params.det_model_dir,
+ os.path.join(BASE_DIR, VERSION, 'ocr', 'det', det_lang),
+ model_urls['det'][det_lang])
+ params.rec_model_dir, rec_url = confirm_model_dir_url(params.rec_model_dir,
+ os.path.join(BASE_DIR, VERSION, 'ocr', 'rec', lang),
+ model_urls['rec'][lang]['url'])
+ params.table_model_dir, table_url = confirm_model_dir_url(params.table_model_dir,
+ os.path.join(BASE_DIR, VERSION, 'ocr', 'table'),
+ model_urls['table']['url'])
+ # download model
+ maybe_download(params.det_model_dir, det_url)
+ maybe_download(params.rec_model_dir, rec_url)
+ maybe_download(params.table_model_dir, table_url)
+
+ if params.rec_char_dict_path is None:
+ params.rec_char_dict_path = str(Path(__file__).parent / model_urls['rec'][lang]['dict_path'])
+ if params.table_char_dict_path is None:
+ params.table_char_dict_path = str(Path(__file__).parent / model_urls['table']['dict_path'])
+
+ print(params)
+ super().__init__(params)
+
+ def __call__(self, img):
+ if isinstance(img, str):
+ # download net image
+ if img.startswith('http'):
+ download_with_progressbar(img, 'tmp.jpg')
+ img = 'tmp.jpg'
+ image_file = img
+ img, flag = check_and_read_gif(image_file)
+ if not flag:
+ with open(image_file, 'rb') as f:
+ np_arr = np.frombuffer(f.read(), dtype=np.uint8)
+ img = cv2.imdecode(np_arr, cv2.IMREAD_COLOR)
+ if img is None:
+ logger.error("error in loading image:{}".format(image_file))
+ return None
+ if isinstance(img, np.ndarray) and len(img.shape) == 2:
+ img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
+
+ res = super().__call__(img)
+ return res
+
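+# A minimal usage sketch for PPStructure (illustrative only; the image path and
+# output directory below are examples, not requirements):
+#     table_engine = PPStructure(show_log=True)
+#     result = table_engine('doc/table/1.png')
+#     save_structure_res(result, './output', '1')
+#     for region in result:
+#         region.pop('img')
+#         print(region)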
+
def main():
# for cmd
args = parse_args(mMain=True)
image_dir = args.image_dir
- if image_dir.startswith('http'):
+ if is_link(image_dir):
download_with_progressbar(image_dir, 'tmp.jpg')
image_file_list = ['tmp.jpg']
else:
@@ -328,14 +349,29 @@ def main():
if len(image_file_list) == 0:
logger.error('no images find in {}'.format(args.image_dir))
return
+ if args.type == 'ocr':
+ engine = PaddleOCR(**(args.__dict__))
+ elif args.type == 'structure':
+ engine = PPStructure(**(args.__dict__))
+ else:
+ raise NotImplementedError
- ocr_engine = PaddleOCR(**(args.__dict__))
for img_path in image_file_list:
+ img_name = os.path.basename(img_path).split('.')[0]
logger.info('{}{}{}'.format('*' * 10, img_path, '*' * 10))
- result = ocr_engine.ocr(img_path,
+ if args.type == 'ocr':
+ result = engine.ocr(img_path,
det=args.det,
rec=args.rec,
cls=args.use_angle_cls)
- if result is not None:
- for line in result:
- logger.info(line)
+ if result is not None:
+ for line in result:
+ logger.info(line)
+ elif args.type == 'structure':
+ result = engine(img_path)
+ save_structure_res(result, args.output, img_name)
+
+ for item in result:
+ item.pop('img')
+ logger.info(item)
+
diff --git a/ppocr/data/__init__.py b/ppocr/data/__init__.py
index 728b8317f54687ee76b519cba18f4d7807493821..e860c5a6986f495e6384d9df93c24795c04a0d5f 100644
--- a/ppocr/data/__init__.py
+++ b/ppocr/data/__init__.py
@@ -35,6 +35,7 @@ from ppocr.data.imaug import transform, create_operators
from ppocr.data.simple_dataset import SimpleDataSet
from ppocr.data.lmdb_dataset import LMDBDataSet
from ppocr.data.pgnet_dataset import PGDataSet
+from ppocr.data.pubtab_dataset import PubTabDataSet
__all__ = ['build_dataloader', 'transform', 'create_operators']
@@ -55,7 +56,7 @@ signal.signal(signal.SIGTERM, term_mp)
def build_dataloader(config, mode, device, logger, seed=None):
config = copy.deepcopy(config)
- support_dict = ['SimpleDataSet', 'LMDBDataSet', 'PGDataSet']
+ support_dict = ['SimpleDataSet', 'LMDBDataSet', 'PGDataSet', 'PubTabDataSet']
module_name = config[mode]['dataset']['name']
assert module_name in support_dict, Exception(
'DataSet only support {}'.format(support_dict))
diff --git a/ppocr/data/imaug/__init__.py b/ppocr/data/imaug/__init__.py
index 9f175382cc662099e0e8660652ee5aa1521f2f4d..5c384c1d316041d11b41ba9dedbf61e98a0259ac 100644
--- a/ppocr/data/imaug/__init__.py
+++ b/ppocr/data/imaug/__init__.py
@@ -23,12 +23,14 @@ from .random_crop_data import EastRandomCropData, PSERandomCrop
from .rec_img_aug import RecAug, RecResizeImg, ClsResizeImg, SRNRecResizeImg, PILResize, CVResize
from .randaugment import RandAugment
+from .copy_paste import CopyPaste
from .operators import *
from .label_ops import *
from .east_process import *
from .sast_process import *
from .pg_process import *
+from .gen_table_mask import *
def transform(data, ops=None):
diff --git a/ppocr/data/imaug/copy_paste.py b/ppocr/data/imaug/copy_paste.py
new file mode 100644
index 0000000000000000000000000000000000000000..bbf62e2a3d813671551efa1a76c03754b1b764f5
--- /dev/null
+++ b/ppocr/data/imaug/copy_paste.py
@@ -0,0 +1,166 @@
+# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import copy
+import cv2
+import random
+import numpy as np
+from PIL import Image
+from shapely.geometry import Polygon
+
+from ppocr.data.imaug.iaa_augment import IaaAugment
+from ppocr.data.imaug.random_crop_data import is_poly_outside_rect
+from tools.infer.utility import get_rotate_crop_image
+
+
+class CopyPaste(object):
+ def __init__(self, objects_paste_ratio=0.2, limit_paste=True, **kwargs):
+ self.ext_data_num = 1
+ self.objects_paste_ratio = objects_paste_ratio
+ self.limit_paste = limit_paste
+ augmenter_args = [{'type': 'Resize', 'args': {'size': [0.5, 3]}}]
+ self.aug = IaaAugment(augmenter_args)
+
+ def __call__(self, data):
+ src_img = data['image']
+ src_polys = data['polys'].tolist()
+ src_ignores = data['ignore_tags'].tolist()
+ ext_data = data['ext_data'][0]
+ ext_image = ext_data['image']
+ ext_polys = ext_data['polys']
+ ext_ignores = ext_data['ignore_tags']
+
+ indexs = [i for i in range(len(ext_ignores)) if not ext_ignores[i]]
+ select_num = max(
+ 1, min(int(self.objects_paste_ratio * len(ext_polys)), 30))
+
+ random.shuffle(indexs)
+ select_idxs = indexs[:select_num]
+ select_polys = ext_polys[select_idxs]
+ select_ignores = ext_ignores[select_idxs]
+
+ src_img = cv2.cvtColor(src_img, cv2.COLOR_BGR2RGB)
+ ext_image = cv2.cvtColor(ext_image, cv2.COLOR_BGR2RGB)
+ src_img = Image.fromarray(src_img).convert('RGBA')
+ for poly, tag in zip(select_polys, select_ignores):
+ box_img = get_rotate_crop_image(ext_image, poly)
+
+ src_img, box = self.paste_img(src_img, box_img, src_polys)
+ if box is not None:
+ src_polys.append(box)
+ src_ignores.append(tag)
+ src_img = cv2.cvtColor(np.array(src_img), cv2.COLOR_RGB2BGR)
+ h, w = src_img.shape[:2]
+ src_polys = np.array(src_polys)
+ src_polys[:, :, 0] = np.clip(src_polys[:, :, 0], 0, w)
+ src_polys[:, :, 1] = np.clip(src_polys[:, :, 1], 0, h)
+ data['image'] = src_img
+ data['polys'] = src_polys
+ data['ignore_tags'] = np.array(src_ignores)
+ return data
+
+ def paste_img(self, src_img, box_img, src_polys):
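+ # Rotate the cropped text image by a random angle, pick a paste position
+ # that avoids existing polygons, and alpha-composite it onto src_img.
+ # Returns the updated image and the pasted box (None if no position fits).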
+ box_img_pil = Image.fromarray(box_img).convert('RGBA')
+ src_w, src_h = src_img.size
+ box_w, box_h = box_img_pil.size
+
+ angle = np.random.randint(0, 360)
+ box = np.array([[[0, 0], [box_w, 0], [box_w, box_h], [0, box_h]]])
+ box = rotate_bbox(box_img, box, angle)[0]
+ box_img_pil = box_img_pil.rotate(angle, expand=1)
+ box_w, box_h = box_img_pil.width, box_img_pil.height
+ if src_w - box_w < 0 or src_h - box_h < 0:
+ return src_img, None
+
+ paste_x, paste_y = self.select_coord(src_polys, box, src_w - box_w,
+ src_h - box_h)
+ if paste_x is None:
+ return src_img, None
+ box[:, 0] += paste_x
+ box[:, 1] += paste_y
+ r, g, b, A = box_img_pil.split()
+ src_img.paste(box_img_pil, (paste_x, paste_y), mask=A)
+
+ return src_img, box
+
+ def select_coord(self, src_polys, box, endx, endy):
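+ # With limit_paste enabled, try up to 50 random positions and reject any
+ # that would cover an existing text polygon; otherwise pick a random spot.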
+ if self.limit_paste:
+ xmin, ymin, xmax, ymax = box[:, 0].min(), box[:, 1].min(
+ ), box[:, 0].max(), box[:, 1].max()
+ for _ in range(50):
+ paste_x = random.randint(0, endx)
+ paste_y = random.randint(0, endy)
+ xmin1 = xmin + paste_x
+ xmax1 = xmax + paste_x
+ ymin1 = ymin + paste_y
+ ymax1 = ymax + paste_y
+
+ num_poly_in_rect = 0
+ for poly in src_polys:
+ if not is_poly_outside_rect(poly, xmin1, ymin1,
+ xmax1 - xmin1, ymax1 - ymin1):
+ num_poly_in_rect += 1
+ break
+ if num_poly_in_rect == 0:
+ return paste_x, paste_y
+ return None, None
+ else:
+ paste_x = random.randint(0, endx)
+ paste_y = random.randint(0, endy)
+ return paste_x, paste_y
+
+
+def get_union(pD, pG):
+ return Polygon(pD).union(Polygon(pG)).area
+
+
+def get_intersection_over_union(pD, pG):
+ return get_intersection(pD, pG) / get_union(pD, pG)
+
+
+def get_intersection(pD, pG):
+ return Polygon(pD).intersection(Polygon(pG)).area
+
+
+def rotate_bbox(img, text_polys, angle, scale=1):
+ """
+ from https://github.com/WenmuZhou/DBNet.pytorch/blob/master/data_loader/modules/augment.py
+ Args:
+ img: np.ndarray
+ text_polys: np.ndarray N*4*2
+ angle: int
+ scale: int
+
+ Returns:
+
+ """
+ w = img.shape[1]
+ h = img.shape[0]
+
+ rangle = np.deg2rad(angle)
+ nw = (abs(np.sin(rangle) * h) + abs(np.cos(rangle) * w))
+ nh = (abs(np.cos(rangle) * h) + abs(np.sin(rangle) * w))
+ rot_mat = cv2.getRotationMatrix2D((nw * 0.5, nh * 0.5), angle, scale)
+ rot_move = np.dot(rot_mat, np.array([(nw - w) * 0.5, (nh - h) * 0.5, 0]))
+ rot_mat[0, 2] += rot_move[0]
+ rot_mat[1, 2] += rot_move[1]
+
+ # ---------------------- rotate box ----------------------
+ rot_text_polys = list()
+ for bbox in text_polys:
+ point1 = np.dot(rot_mat, np.array([bbox[0, 0], bbox[0, 1], 1]))
+ point2 = np.dot(rot_mat, np.array([bbox[1, 0], bbox[1, 1], 1]))
+ point3 = np.dot(rot_mat, np.array([bbox[2, 0], bbox[2, 1], 1]))
+ point4 = np.dot(rot_mat, np.array([bbox[3, 0], bbox[3, 1], 1]))
+ rot_text_polys.append([point1, point2, point3, point4])
+ return np.array(rot_text_polys, dtype=np.float32)
diff --git a/ppocr/data/imaug/gen_table_mask.py b/ppocr/data/imaug/gen_table_mask.py
new file mode 100644
index 0000000000000000000000000000000000000000..08e35d5d1df7f9663b4e008451308d0ee409cf5a
--- /dev/null
+++ b/ppocr/data/imaug/gen_table_mask.py
@@ -0,0 +1,244 @@
+"""
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+from __future__ import unicode_literals
+
+import sys
+import six
+import cv2
+import numpy as np
+
+
+class GenTableMask(object):
+ """ gen table mask """
+
+ def __init__(self, shrink_h_max, shrink_w_max, mask_type=0, **kwargs):
+ self.shrink_h_max = 5
+ self.shrink_w_max = 5
+ self.mask_type = mask_type
+
+ def projection(self, erosion, h, w, spilt_threshold=0):
+ # horizontal projection
+ projection_map = np.ones_like(erosion)
+ project_val_array = [0 for _ in range(0, h)]
+
+ for j in range(0, h):
+ for i in range(0, w):
+ if erosion[j, i] == 255:
+ project_val_array[j] += 1
+ # get the split points from the projection array
+ start_idx = 0 # index where a text region starts
+ end_idx = 0 # index where a blank region starts
+ in_text = False # whether the scan is currently inside a text region
+ box_list = []
+ for i in range(len(project_val_array)):
+ if in_text == False and project_val_array[i] > spilt_threshold: # entered a text region
+ in_text = True
+ start_idx = i
+ elif project_val_array[i] <= spilt_threshold and in_text == True: # entered a blank region
+ end_idx = i
+ in_text = False
+ if end_idx - start_idx <= 2:
+ continue
+ box_list.append((start_idx, end_idx + 1))
+
+ if in_text:
+ box_list.append((start_idx, h - 1))
+ # draw the projection histogram
+ for j in range(0, h):
+ for i in range(0, project_val_array[j]):
+ projection_map[j, i] = 0
+ return box_list, projection_map
+
+ def projection_cx(self, box_img):
+ box_gray_img = cv2.cvtColor(box_img, cv2.COLOR_BGR2GRAY)
+ h, w = box_gray_img.shape
+ # binarize the grayscale image
+ ret, thresh1 = cv2.threshold(box_gray_img, 200, 255, cv2.THRESH_BINARY_INV)
+ # vertical erosion
+ if h < w:
+ kernel = np.ones((2, 1), np.uint8)
+ erode = cv2.erode(thresh1, kernel, iterations=1)
+ else:
+ erode = thresh1
+ # horizontal dilation
+ kernel = np.ones((1, 5), np.uint8)
+ erosion = cv2.dilate(erode, kernel, iterations=1)
+ # horizontal projection
+ projection_map = np.ones_like(erosion)
+ project_val_array = [0 for _ in range(0, h)]
+
+ for j in range(0, h):
+ for i in range(0, w):
+ if erosion[j, i] == 255:
+ project_val_array[j] += 1
+ # get the split points from the projection array
+ start_idx = 0 # index where a text region starts
+ end_idx = 0 # index where a blank region starts
+ in_text = False # whether the scan is currently inside a text region
+ box_list = []
+ spilt_threshold = 0
+ for i in range(len(project_val_array)):
+ if in_text == False and project_val_array[i] > spilt_threshold: # entered a text region
+ in_text = True
+ start_idx = i
+ elif project_val_array[i] <= spilt_threshold and in_text == True: # entered a blank region
+ end_idx = i
+ in_text = False
+ if end_idx - start_idx <= 2:
+ continue
+ box_list.append((start_idx, end_idx + 1))
+
+ if in_text:
+ box_list.append((start_idx, h - 1))
+ # draw the projection histogram
+ for j in range(0, h):
+ for i in range(0, project_val_array[j]):
+ projection_map[j, i] = 0
+ split_bbox_list = []
+ if len(box_list) > 1:
+ for i, (h_start, h_end) in enumerate(box_list):
+ if i == 0:
+ h_start = 0
+ if i == len(box_list):
+ h_end = h
+ word_img = erosion[h_start:h_end + 1, :]
+ word_h, word_w = word_img.shape
+ w_split_list, w_projection_map = self.projection(word_img.T, word_w, word_h)
+ w_start, w_end = w_split_list[0][0], w_split_list[-1][1]
+ if h_start > 0:
+ h_start -= 1
+ h_end += 1
+ word_img = box_img[h_start:h_end + 1:, w_start:w_end + 1, :]
+ split_bbox_list.append([w_start, h_start, w_end, h_end])
+ else:
+ split_bbox_list.append([0, 0, w, h])
+ return split_bbox_list
+
+ def shrink_bbox(self, bbox):
+ left, top, right, bottom = bbox
+ sh_h = min(max(int((bottom - top) * 0.1), 1), self.shrink_h_max)
+ sh_w = min(max(int((right - left) * 0.1), 1), self.shrink_w_max)
+ left_new = left + sh_w
+ right_new = right - sh_w
+ top_new = top + sh_h
+ bottom_new = bottom - sh_h
+ if left_new >= right_new:
+ left_new = left
+ right_new = right
+ if top_new >= bottom_new:
+ top_new = top
+ bottom_new = bottom
+ return [left_new, top_new, right_new, bottom_new]
+
+ def __call__(self, data):
+ img = data['image']
+ cells = data['cells']
+ height, width = img.shape[0:2]
+ if self.mask_type == 1:
+ mask_img = np.zeros((height, width), dtype=np.float32)
+ else:
+ mask_img = np.zeros((height, width, 3), dtype=np.float32)
+ cell_num = len(cells)
+ for cno in range(cell_num):
+ if "bbox" in cells[cno]:
+ bbox = cells[cno]['bbox']
+ left, top, right, bottom = bbox
+ box_img = img[top:bottom, left:right, :].copy()
+ split_bbox_list = self.projection_cx(box_img)
+ for sno in range(len(split_bbox_list)):
+ split_bbox_list[sno][0] += left
+ split_bbox_list[sno][1] += top
+ split_bbox_list[sno][2] += left
+ split_bbox_list[sno][3] += top
+
+ for sno in range(len(split_bbox_list)):
+ left, top, right, bottom = split_bbox_list[sno]
+ left, top, right, bottom = self.shrink_bbox([left, top, right, bottom])
+ if self.mask_type == 1:
+ mask_img[top:bottom, left:right] = 1.0
+ data['mask_img'] = mask_img
+ else:
+ mask_img[top:bottom, left:right, :] = (255, 255, 255)
+ data['image'] = mask_img
+ return data
+
+class ResizeTableImage(object):
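+ # Resize the table image so its longer side equals max_len, and rescale all
+ # cell bboxes by the same ratio.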
+ def __init__(self, max_len, **kwargs):
+ super(ResizeTableImage, self).__init__()
+ self.max_len = max_len
+
+ def get_img_bbox(self, cells):
+ bbox_list = []
+ if len(cells) == 0:
+ return bbox_list
+ cell_num = len(cells)
+ for cno in range(cell_num):
+ if "bbox" in cells[cno]:
+ bbox = cells[cno]['bbox']
+ bbox_list.append(bbox)
+ return bbox_list
+
+ def resize_img_table(self, img, bbox_list, max_len):
+ height, width = img.shape[0:2]
+ ratio = max_len / (max(height, width) * 1.0)
+ resize_h = int(height * ratio)
+ resize_w = int(width * ratio)
+ img_new = cv2.resize(img, (resize_w, resize_h))
+ bbox_list_new = []
+ for bno in range(len(bbox_list)):
+ left, top, right, bottom = bbox_list[bno].copy()
+ left = int(left * ratio)
+ top = int(top * ratio)
+ right = int(right * ratio)
+ bottom = int(bottom * ratio)
+ bbox_list_new.append([left, top, right, bottom])
+ return img_new, bbox_list_new
+
+ def __call__(self, data):
+ img = data['image']
+ if 'cells' not in data:
+ cells = []
+ else:
+ cells = data['cells']
+ bbox_list = self.get_img_bbox(cells)
+ img_new, bbox_list_new = self.resize_img_table(img, bbox_list, self.max_len)
+ data['image'] = img_new
+ cell_num = len(cells)
+ bno = 0
+ for cno in range(cell_num):
+ if "bbox" in data['cells'][cno]:
+ data['cells'][cno]['bbox'] = bbox_list_new[bno]
+ bno += 1
+ data['max_len'] = self.max_len
+ return data
+
+class PaddingTableImage(object):
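+ # Pad the resized image into a max_len x max_len canvas (top-left aligned)
+ # so every table image in a batch has the same size.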
+ def __init__(self, **kwargs):
+ super(PaddingTableImage, self).__init__()
+
+ def __call__(self, data):
+ img = data['image']
+ max_len = data['max_len']
+ padding_img = np.zeros((max_len, max_len, 3), dtype=np.float32)
+ height, width = img.shape[0:2]
+ padding_img[0:height, 0:width, :] = img.copy()
+ data['image'] = padding_img
+ return data
+
\ No newline at end of file
diff --git a/ppocr/data/imaug/label_ops.py b/ppocr/data/imaug/label_ops.py
index 39ff89303f4895ec9f6e21eff159544d49527a27..f6263950959b0ee6a96647fb248098bb5c567651 100644
--- a/ppocr/data/imaug/label_ops.py
+++ b/ppocr/data/imaug/label_ops.py
@@ -19,6 +19,7 @@ from __future__ import unicode_literals
import numpy as np
import string
+import json
class ClsLabelEncode(object):
@@ -39,7 +40,6 @@ class DetLabelEncode(object):
pass
def __call__(self, data):
- import json
label = data['label']
label = json.loads(label)
nBox = len(label)
@@ -53,6 +53,8 @@ class DetLabelEncode(object):
txt_tags.append(True)
else:
txt_tags.append(False)
+ if len(boxes) == 0:
+ return None
boxes = self.expand_points_num(boxes)
boxes = np.array(boxes, dtype=np.float32)
txt_tags = np.array(txt_tags, dtype=np.bool)
@@ -379,3 +381,171 @@ class SRNLabelEncode(BaseRecLabelEncode):
assert False, "Unsupport type %s in get_beg_end_flag_idx" \
% beg_or_end
return idx
+
+
+class TableLabelEncode(object):
+ """ Convert between text-label and text-index """
+
+ def __init__(self,
+ max_text_length,
+ max_elem_length,
+ max_cell_num,
+ character_dict_path,
+ span_weight=1.0,
+ **kwargs):
+ self.max_text_length = max_text_length
+ self.max_elem_length = max_elem_length
+ self.max_cell_num = max_cell_num
+ list_character, list_elem = self.load_char_elem_dict(
+ character_dict_path)
+ list_character = self.add_special_char(list_character)
+ list_elem = self.add_special_char(list_elem)
+ self.dict_character = {}
+ for i, char in enumerate(list_character):
+ self.dict_character[char] = i
+ self.dict_elem = {}
+ for i, elem in enumerate(list_elem):
+ self.dict_elem[elem] = i
+ self.span_weight = span_weight
+
+ def load_char_elem_dict(self, character_dict_path):
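+ # The dict file starts with a header line "<character_num>\t<elem_num>",
+ # followed by one text character per line and then one structure element
+ # (HTML token) per line.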
+ list_character = []
+ list_elem = []
+ with open(character_dict_path, "rb") as fin:
+ lines = fin.readlines()
+ substr = lines[0].decode('utf-8').strip("\r\n").split("\t")
+ character_num = int(substr[0])
+ elem_num = int(substr[1])
+
+ for cno in range(1, 1 + character_num):
+ character = lines[cno].decode('utf-8').strip("\r\n")
+ list_character.append(character)
+ for eno in range(1 + character_num, 1 + character_num + elem_num):
+ elem = lines[eno].decode('utf-8').strip("\r\n")
+ list_elem.append(elem)
+ return list_character, list_elem
+
+ def add_special_char(self, list_character):
+ self.beg_str = "sos"
+ self.end_str = "eos"
+ list_character = [self.beg_str] + list_character + [self.end_str]
+ return list_character
+
+ def get_span_idx_list(self):
+ span_idx_list = []
+ for elem in self.dict_elem:
+ if 'span' in elem:
+ span_idx_list.append(self.dict_elem[elem])
+ return span_idx_list
+
+ def __call__(self, data):
+ cells = data['cells']
+ structure = data['structure']['tokens']
+ structure = self.encode(structure, 'elem')
+ if structure is None:
+ return None
+ elem_num = len(structure)
+ structure = [0] + structure + [len(self.dict_elem) - 1]
+ structure = structure + [0] * (self.max_elem_length + 2 - len(structure)
+ )
+ structure = np.array(structure)
+ data['structure'] = structure
+ elem_char_idx1 = self.dict_elem['<td>']
+ elem_char_idx2 = self.dict_elem['<td']
+ span_idx_list = self.get_span_idx_list()
+ td_idx_list = np.logical_or(structure == elem_char_idx1,
+ structure == elem_char_idx2)
+ td_idx_list = np.where(td_idx_list)[0]
+ structure_mask = np.ones(
+ (self.max_elem_length + 2, 1), dtype=np.float32)
+ bbox_list = np.zeros((self.max_elem_length + 2, 4), dtype=np.float32)
+ bbox_list_mask = np.zeros(
+ (self.max_elem_length + 2, 1), dtype=np.float32)
+ img_height, img_width, img_ch = data['image'].shape
+ if len(span_idx_list) > 0:
+ span_weight = len(td_idx_list) * 1.0 / len(span_idx_list)
+ span_weight = min(max(span_weight, 1.0), self.span_weight)
+ for cno in range(len(cells)):
+ if 'bbox' in cells[cno]:
+ bbox = cells[cno]['bbox'].copy()
+ bbox[0] = bbox[0] * 1.0 / img_width
+ bbox[1] = bbox[1] * 1.0 / img_height
+ bbox[2] = bbox[2] * 1.0 / img_width
+ bbox[3] = bbox[3] * 1.0 / img_height
+ td_idx = td_idx_list[cno]
+ bbox_list[td_idx] = bbox
+ bbox_list_mask[td_idx] = 1.0
+ cand_span_idx = td_idx + 1
+ if cand_span_idx < (self.max_elem_length + 2):
+ if structure[cand_span_idx] in span_idx_list:
+ structure_mask[cand_span_idx] = span_weight
+
+ data['bbox_list'] = bbox_list
+ data['bbox_list_mask'] = bbox_list_mask
+ data['structure_mask'] = structure_mask
+ char_beg_idx = self.get_beg_end_flag_idx('beg', 'char')
+ char_end_idx = self.get_beg_end_flag_idx('end', 'char')
+ elem_beg_idx = self.get_beg_end_flag_idx('beg', 'elem')
+ elem_end_idx = self.get_beg_end_flag_idx('end', 'elem')
+ data['sp_tokens'] = np.array([
+ char_beg_idx, char_end_idx, elem_beg_idx, elem_end_idx,
+ elem_char_idx1, elem_char_idx2, self.max_text_length,
+ self.max_elem_length, self.max_cell_num, elem_num
+ ])
+ return data
+
+ def encode(self, text, char_or_elem):
+ """convert text-label into text-index.
+ """
+ if char_or_elem == "char":
+ max_len = self.max_text_length
+ current_dict = self.dict_character
+ else:
+ max_len = self.max_elem_length
+ current_dict = self.dict_elem
+ if len(text) > max_len:
+ return None
+ if len(text) == 0:
+ if char_or_elem == "char":
+ return [self.dict_character['space']]
+ else:
+ return None
+ text_list = []
+ for char in text:
+ if char not in current_dict:
+ return None
+ text_list.append(current_dict[char])
+ if len(text_list) == 0:
+ if char_or_elem == "char":
+ return [self.dict_character['space']]
+ else:
+ return None
+ return text_list
+
+ def get_ignored_tokens(self, char_or_elem):
+ beg_idx = self.get_beg_end_flag_idx("beg", char_or_elem)
+ end_idx = self.get_beg_end_flag_idx("end", char_or_elem)
+ return [beg_idx, end_idx]
+
+ def get_beg_end_flag_idx(self, beg_or_end, char_or_elem):
+ if char_or_elem == "char":
+ if beg_or_end == "beg":
+ idx = np.array(self.dict_character[self.beg_str])
+ elif beg_or_end == "end":
+ idx = np.array(self.dict_character[self.end_str])
+ else:
+ assert False, "Unsupport type %s in get_beg_end_flag_idx of char" \
+ % beg_or_end
+ elif char_or_elem == "elem":
+ if beg_or_end == "beg":
+ idx = np.array(self.dict_elem[self.beg_str])
+ elif beg_or_end == "end":
+ idx = np.array(self.dict_elem[self.end_str])
+ else:
+ assert False, "Unsupport type %s in get_beg_end_flag_idx of elem" \
+ % beg_or_end
+ else:
+ assert False, "Unsupport type %s in char_or_elem" \
+ % char_or_elem
+ return idx
diff --git a/ppocr/data/imaug/operators.py b/ppocr/data/imaug/operators.py
index 950c9377988fe3454f506b9e947445be9094241a..aa3acd1d1264fb2af75c773efd1e5c575b465fe5 100644
--- a/ppocr/data/imaug/operators.py
+++ b/ppocr/data/imaug/operators.py
@@ -113,7 +113,7 @@ class NormalizeImage(object):
assert isinstance(img,
np.ndarray), "invalid input 'img' in NormalizeImage"
data['image'] = (
- img.astype('float32') * self.scale - self.mean) / self.std
+ img.astype('float32') * self.scale - self.mean) / self.std
return data
@@ -195,7 +195,7 @@ class DetResizeForTest(object):
img, (ratio_h, ratio_w)
"""
limit_side_len = self.limit_side_len
- h, w, _ = img.shape
+ h, w, c = img.shape
# limit the max side
if self.limit_type == 'max':
@@ -206,7 +206,7 @@ class DetResizeForTest(object):
ratio = float(limit_side_len) / w
else:
ratio = 1.
- else:
+ elif self.limit_type == 'min':
if min(h, w) < limit_side_len:
if h < w:
ratio = float(limit_side_len) / h
@@ -214,6 +214,10 @@ class DetResizeForTest(object):
ratio = float(limit_side_len) / w
else:
ratio = 1.
+ elif self.limit_type == 'resize_long':
+ ratio = float(limit_side_len) / max(h,w)
+ else:
+ raise Exception('unsupported limit_type: {}'.format(self.limit_type))
resize_h = int(h * ratio)
resize_w = int(w * ratio)
diff --git a/ppocr/data/pubtab_dataset.py b/ppocr/data/pubtab_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..78b76c5afb8c96bc96730c7b8ad76b4bafa31c67
--- /dev/null
+++ b/ppocr/data/pubtab_dataset.py
@@ -0,0 +1,107 @@
+# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import numpy as np
+import os
+import random
+from paddle.io import Dataset
+import json
+
+from .imaug import transform, create_operators
+
+
+class PubTabDataSet(Dataset):
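+ # Dataset for table recognition. Each label line is a JSON object holding the
+ # image 'filename' and an 'html' field with cell contents/bboxes and the
+ # structure token sequence (PubTabNet-style annotations).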
+ def __init__(self, config, mode, logger, seed=None):
+ super(PubTabDataSet, self).__init__()
+ self.logger = logger
+
+ global_config = config['Global']
+ dataset_config = config[mode]['dataset']
+ loader_config = config[mode]['loader']
+
+ label_file_path = dataset_config.pop('label_file_path')
+
+ self.data_dir = dataset_config['data_dir']
+ self.do_shuffle = loader_config['shuffle']
+ self.do_hard_select = False
+ if 'hard_select' in loader_config:
+ self.do_hard_select = loader_config['hard_select']
+ self.hard_prob = loader_config['hard_prob']
+ if self.do_hard_select:
+ self.img_select_prob = self.load_hard_select_prob()
+ self.table_select_type = None
+ if 'table_select_type' in loader_config:
+ self.table_select_type = loader_config['table_select_type']
+ self.table_select_prob = loader_config['table_select_prob']
+
+ self.seed = seed
+ logger.info("Initialize indexs of datasets:%s" % label_file_path)
+ with open(label_file_path, "rb") as f:
+ self.data_lines = f.readlines()
+ self.data_idx_order_list = list(range(len(self.data_lines)))
+ if mode.lower() == "train":
+ self.shuffle_data_random()
+ self.ops = create_operators(dataset_config['transforms'], global_config)
+
+ def shuffle_data_random(self):
+ if self.do_shuffle:
+ random.seed(self.seed)
+ random.shuffle(self.data_lines)
+ return
+
+ def __getitem__(self, idx):
+ try:
+ data_line = self.data_lines[idx]
+ data_line = data_line.decode('utf-8').strip("\n")
+ info = json.loads(data_line)
+ file_name = info['filename']
+ select_flag = True
+ if self.do_hard_select:
+ prob = self.img_select_prob[file_name]
+ if prob < random.uniform(0, 1):
+ select_flag = False
+
+ if self.table_select_type:
+ structure = info['html']['structure']['tokens'].copy()
+ structure_str = ''.join(structure)
+ table_type = "simple"
+ if 'colspan' in structure_str or 'rowspan' in structure_str:
+ table_type = "complex"
+ if table_type == "complex":
+ if self.table_select_prob < random.uniform(0, 1):
+ select_flag = False
+
+ if select_flag:
+ cells = info['html']['cells'].copy()
+ structure = info['html']['structure'].copy()
+ img_path = os.path.join(self.data_dir, file_name)
+ data = {'img_path': img_path, 'cells': cells, 'structure':structure}
+ if not os.path.exists(img_path):
+ raise Exception("{} does not exist!".format(img_path))
+ with open(data['img_path'], 'rb') as f:
+ img = f.read()
+ data['image'] = img
+ outs = transform(data, self.ops)
+ else:
+ outs = None
+ except Exception as e:
+ self.logger.error(
+ "When parsing line {}, error happened with msg: {}".format(
+ data_line, e))
+ outs = None
+ if outs is None:
+ return self.__getitem__(np.random.randint(self.__len__()))
+ return outs
+
+ def __len__(self):
+ return len(self.data_idx_order_list)
diff --git a/ppocr/data/simple_dataset.py b/ppocr/data/simple_dataset.py
index 8f8fcb4dbdf3c68587875b50cb30a834a3943216..e9c3394cbe930d5169ae005e7582a2902e697b7e 100644
--- a/ppocr/data/simple_dataset.py
+++ b/ppocr/data/simple_dataset.py
@@ -69,12 +69,42 @@ class SimpleDataSet(Dataset):
random.shuffle(self.data_lines)
return
+ def get_ext_data(self):
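+ # Load the extra samples requested by augmenters such as CopyPaste (which
+ # exposes ext_data_num); only the first two transforms (decode / label ops)
+ # are applied to these samples.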
+ ext_data_num = 0
+ for op in self.ops:
+ if hasattr(op, 'ext_data_num'):
+ ext_data_num = getattr(op, 'ext_data_num')
+ break
+ load_data_ops = self.ops[:2]
+ ext_data = []
+
+ while len(ext_data) < ext_data_num:
+ file_idx = self.data_idx_order_list[np.random.randint(self.__len__(
+ ))]
+ data_line = self.data_lines[file_idx]
+ data_line = data_line.decode('utf-8')
+ substr = data_line.strip("\n").split(self.delimiter)
+ file_name = substr[0]
+ label = substr[1]
+ img_path = os.path.join(self.data_dir, file_name)
+ data = {'img_path': img_path, 'label': label}
+ if not os.path.exists(img_path):
+ continue
+ with open(data['img_path'], 'rb') as f:
+ img = f.read()
+ data['image'] = img
+ data = transform(data, load_data_ops)
+ if data is None:
+ continue
+ ext_data.append(data)
+ return ext_data
+
def __getitem__(self, idx):
file_idx = self.data_idx_order_list[idx]
data_line = self.data_lines[file_idx]
try:
data_line = data_line.decode('utf-8')
- substr = data_line.strip("\n").strip("\r").split(self.delimiter)
+ substr = data_line.strip("\n").split(self.delimiter)
file_name = substr[0]
label = substr[1]
img_path = os.path.join(self.data_dir, file_name)
@@ -84,6 +114,7 @@ class SimpleDataSet(Dataset):
with open(data['img_path'], 'rb') as f:
img = f.read()
data['image'] = img
+ data['ext_data'] = self.get_ext_data()
outs = transform(data, self.ops)
except Exception as e:
self.logger.error(
diff --git a/ppocr/losses/__init__.py b/ppocr/losses/__init__.py
index e1c3ed95e3212a4837fa9d8b443b1dc721fa8bae..7f4ab152969b46f7a716d9211fc56f7ecb489e75 100755
--- a/ppocr/losses/__init__.py
+++ b/ppocr/losses/__init__.py
@@ -38,11 +38,15 @@ from .basic_loss import DistanceLoss
# combined loss function
from .combined_loss import CombinedLoss
+# table loss
+from .table_att_loss import TableAttentionLoss
def build_loss(config):
support_dict = [
'DBLoss', 'EASTLoss', 'SASTLoss', 'CTCLoss', 'ClsLoss', 'AttentionLoss',
- 'SRNLoss', 'PGLoss', 'CombinedLoss', 'NRTRLoss']
+
+ 'SRNLoss', 'PGLoss', 'CombinedLoss', 'NRTRLoss', 'TableAttentionLoss'
+ ]
config = copy.deepcopy(config)
module_name = config.pop('name')
diff --git a/ppocr/losses/basic_loss.py b/ppocr/losses/basic_loss.py
index fa3ceda1b747aad3c4b275611b1257bf6950f013..8306523ac1a933f0c664fc0b4cf077659cccdee3 100644
--- a/ppocr/losses/basic_loss.py
+++ b/ppocr/losses/basic_loss.py
@@ -54,6 +54,27 @@ class CELoss(nn.Layer):
return loss
+class KLJSLoss(object):
+ def __init__(self, mode='kl'):
+ assert mode in ['kl', 'js', 'KL', 'JS'], "mode can only be one of ['kl', 'js', 'KL', 'JS']"
+ self.mode = mode
+
+ def __call__(self, p1, p2, reduction="mean"):
+
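+ # KL: p2 * log(p2 / p1), computed per element; JS additionally adds the
+ # reverse direction and halves the result. The 1e-5 terms guard against
+ # log(0) and division by zero.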
+ loss = paddle.multiply(p2, paddle.log( (p2+1e-5)/(p1+1e-5) + 1e-5))
+
+ if self.mode.lower() == "js":
+ loss += paddle.multiply(p1, paddle.log((p1+1e-5)/(p2+1e-5) + 1e-5))
+ loss *= 0.5
+ if reduction == "mean":
+ loss = paddle.mean(loss, axis=[1,2])
+ elif reduction=="none" or reduction is None:
+ return loss
+ else:
+ loss = paddle.sum(loss, axis=[1,2])
+
+ return loss
+
class DMLLoss(nn.Layer):
"""
DMLLoss
@@ -69,17 +90,21 @@ class DMLLoss(nn.Layer):
self.act = nn.Sigmoid()
else:
self.act = None
+
+ self.jskl_loss = KLJSLoss(mode="js")
def forward(self, out1, out2):
if self.act is not None:
out1 = self.act(out1)
out2 = self.act(out2)
-
- log_out1 = paddle.log(out1)
- log_out2 = paddle.log(out2)
- loss = (F.kl_div(
- log_out1, out2, reduction='batchmean') + F.kl_div(
- log_out2, out1, reduction='batchmean')) / 2.0
+ if len(out1.shape) < 2:
+ log_out1 = paddle.log(out1)
+ log_out2 = paddle.log(out2)
+ loss = (F.kl_div(
+ log_out1, out2, reduction='batchmean') + F.kl_div(
+ log_out2, out1, reduction='batchmean')) / 2.0
+ else:
+ loss = self.jskl_loss(out1, out2)
return loss
diff --git a/ppocr/losses/combined_loss.py b/ppocr/losses/combined_loss.py
index 54da70174cba7bf5ca35e8fbf5aa137a437ae29c..0d6fe968d0d7733200a4cfd21d779196cccaba03 100644
--- a/ppocr/losses/combined_loss.py
+++ b/ppocr/losses/combined_loss.py
@@ -17,7 +17,7 @@ import paddle.nn as nn
from .distillation_loss import DistillationCTCLoss
from .distillation_loss import DistillationDMLLoss
-from .distillation_loss import DistillationDistanceLoss
+from .distillation_loss import DistillationDistanceLoss, DistillationDBLoss, DistillationDilaDBLoss
class CombinedLoss(nn.Layer):
@@ -44,15 +44,16 @@ class CombinedLoss(nn.Layer):
def forward(self, input, batch, **kargs):
loss_dict = {}
+ loss_all = 0.
for idx, loss_func in enumerate(self.loss_func):
loss = loss_func(input, batch, **kargs)
if isinstance(loss, paddle.Tensor):
loss = {"loss_{}_{}".format(str(loss), idx): loss}
weight = self.loss_weight[idx]
- loss = {
- "{}_{}".format(key, idx): loss[key] * weight
- for key in loss
- }
- loss_dict.update(loss)
- loss_dict["loss"] = paddle.add_n(list(loss_dict.values()))
+ for key in loss.keys():
+ if key == "loss":
+ loss_all += loss[key] * weight
+ else:
+ loss_dict["{}_{}".format(key, idx)] = loss[key]
+ loss_dict["loss"] = loss_all
return loss_dict
diff --git a/ppocr/losses/distillation_loss.py b/ppocr/losses/distillation_loss.py
index 1e8aa0d8602e3ddd49913e6a572914859377ca42..75f0a773152e52c98ada5c1907f1c8cc2f72d8f3 100644
--- a/ppocr/losses/distillation_loss.py
+++ b/ppocr/losses/distillation_loss.py
@@ -14,23 +14,76 @@
import paddle
import paddle.nn as nn
+import numpy as np
+import cv2
from .rec_ctc_loss import CTCLoss
from .basic_loss import DMLLoss
from .basic_loss import DistanceLoss
+from .det_db_loss import DBLoss
+from .det_basic_loss import BalanceLoss, MaskL1Loss, DiceLoss
+
+
+def _sum_loss(loss_dict):
+ if "loss" in loss_dict.keys():
+ return loss_dict
+ else:
+ loss_dict["loss"] = 0.
+ for k, value in loss_dict.items():
+ if k == "loss":
+ continue
+ else:
+ loss_dict["loss"] += value
+ return loss_dict
class DistillationDMLLoss(DMLLoss):
"""
"""
- def __init__(self, model_name_pairs=[], act=None, key=None,
- name="loss_dml"):
+ def __init__(self,
+ model_name_pairs=[],
+ act=None,
+ key=None,
+ maps_name=None,
+ name="dml"):
super().__init__(act=act)
assert isinstance(model_name_pairs, list)
self.key = key
- self.model_name_pairs = model_name_pairs
+ self.model_name_pairs = self._check_model_name_pairs(model_name_pairs)
self.name = name
+ self.maps_name = self._check_maps_name(maps_name)
+
+ def _check_model_name_pairs(self, model_name_pairs):
+ if not isinstance(model_name_pairs, list):
+ return []
+ elif isinstance(model_name_pairs[0], list) and isinstance(model_name_pairs[0][0], str):
+ return model_name_pairs
+ else:
+ return [model_name_pairs]
+
+ def _check_maps_name(self, maps_name):
+ if maps_name is None:
+ return None
+ elif type(maps_name) == str:
+ return [maps_name]
+ elif type(maps_name) == list:
+ return maps_name
+ else:
+ return None
+
+ def _slice_out(self, outs):
+ new_outs = {}
+ for k in self.maps_name:
+ if k == "thrink_maps":
+ new_outs[k] = outs[:, 0, :, :]
+ elif k == "threshold_maps":
+ new_outs[k] = outs[:, 1, :, :]
+ elif k == "binary_maps":
+ new_outs[k] = outs[:, 2, :, :]
+ else:
+ continue
+ return new_outs
def forward(self, predicts, batch):
loss_dict = dict()
@@ -40,13 +93,30 @@ class DistillationDMLLoss(DMLLoss):
if self.key is not None:
out1 = out1[self.key]
out2 = out2[self.key]
- loss = super().forward(out1, out2)
- if isinstance(loss, dict):
- for key in loss:
- loss_dict["{}_{}_{}_{}".format(key, pair[0], pair[1],
- idx)] = loss[key]
+
+ if self.maps_name is None:
+ loss = super().forward(out1, out2)
+ if isinstance(loss, dict):
+ for key in loss:
+ loss_dict["{}_{}_{}_{}".format(key, pair[0], pair[1],
+ idx)] = loss[key]
+ else:
+ loss_dict["{}_{}".format(self.name, idx)] = loss
else:
- loss_dict["{}_{}".format(self.name, idx)] = loss
+ outs1 = self._slice_out(out1)
+ outs2 = self._slice_out(out2)
+ for _c, k in enumerate(outs1.keys()):
+ loss = super().forward(outs1[k], outs2[k])
+ if isinstance(loss, dict):
+ for key in loss:
+ loss_dict["{}_{}_{}_{}_{}".format(key, pair[
+ 0], pair[1], map_name, idx)] = loss[key]
+ else:
+ loss_dict["{}_{}_{}".format(self.name, self.maps_name[_c],
+ idx)] = loss
+
+ loss_dict = _sum_loss(loss_dict)
+
return loss_dict
@@ -73,6 +143,98 @@ class DistillationCTCLoss(CTCLoss):
return loss_dict
+class DistillationDBLoss(DBLoss):
+ def __init__(self,
+ model_name_list=[],
+ balance_loss=True,
+ main_loss_type='DiceLoss',
+ alpha=5,
+ beta=10,
+ ohem_ratio=3,
+ eps=1e-6,
+ name="db",
+ **kwargs):
+ super().__init__()
+ self.model_name_list = model_name_list
+ self.name = name
+ self.key = None
+
+ def forward(self, predicts, batch):
+ loss_dict = {}
+ for idx, model_name in enumerate(self.model_name_list):
+ out = predicts[model_name]
+ if self.key is not None:
+ out = out[self.key]
+ loss = super().forward(out, batch)
+
+ if isinstance(loss, dict):
+ for key in loss.keys():
+ if key == "loss":
+ continue
+ name = "{}_{}_{}".format(self.name, model_name, key)
+ loss_dict[name] = loss[key]
+ else:
+ loss_dict["{}_{}".format(self.name, model_name)] = loss
+
+ loss_dict = _sum_loss(loss_dict)
+ return loss_dict
+
+
+class DistillationDilaDBLoss(DBLoss):
+ def __init__(self,
+ model_name_pairs=[],
+ key=None,
+ balance_loss=True,
+ main_loss_type='DiceLoss',
+ alpha=5,
+ beta=10,
+ ohem_ratio=3,
+ eps=1e-6,
+ name="dila_dbloss"):
+ super().__init__()
+ self.model_name_pairs = model_name_pairs
+ self.name = name
+ self.key = key
+
+ def forward(self, predicts, batch):
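+ # Supervise the student's shrink/binary maps with the teacher's shrink map,
+ # binarized at 0.3 and dilated with a 2x2 kernel, using the DB BCE and dice
+ # losses masked by the ground-truth shrink mask.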
+ loss_dict = dict()
+ for idx, pair in enumerate(self.model_name_pairs):
+ stu_outs = predicts[pair[0]]
+ tch_outs = predicts[pair[1]]
+ if self.key is not None:
+ stu_preds = stu_outs[self.key]
+ tch_preds = tch_outs[self.key]
+
+ stu_shrink_maps = stu_preds[:, 0, :, :]
+ stu_binary_maps = stu_preds[:, 2, :, :]
+
+ # dilation to teacher prediction
+ dilation_w = np.array([[1, 1], [1, 1]])
+ th_shrink_maps = tch_preds[:, 0, :, :]
+ th_shrink_maps = th_shrink_maps.numpy() > 0.3 # thresh = 0.3
+ dilate_maps = np.zeros_like(th_shrink_maps).astype(np.float32)
+ for i in range(th_shrink_maps.shape[0]):
+ dilate_maps[i] = cv2.dilate(
+ th_shrink_maps[i, :, :].astype(np.uint8), dilation_w)
+ th_shrink_maps = paddle.to_tensor(dilate_maps)
+
+ label_threshold_map, label_threshold_mask, label_shrink_map, label_shrink_mask = batch[
+ 1:]
+
+ # calculate the shrink map loss
+ bce_loss = self.alpha * self.bce_loss(
+ stu_shrink_maps, th_shrink_maps, label_shrink_mask)
+ loss_binary_maps = self.dice_loss(stu_binary_maps, th_shrink_maps,
+ label_shrink_mask)
+
+ # k = f"{self.name}_{pair[0]}_{pair[1]}"
+ k = "{}_{}_{}".format(self.name, pair[0], pair[1])
+ loss_dict[k] = bce_loss + loss_binary_maps
+
+ loss_dict = _sum_loss(loss_dict)
+ return loss_dict
+
+
class DistillationDistanceLoss(DistanceLoss):
"""
"""
diff --git a/ppocr/losses/table_att_loss.py b/ppocr/losses/table_att_loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..d7fd99e6952aacc0182a482ca5ae5ddaf959a026
--- /dev/null
+++ b/ppocr/losses/table_att_loss.py
@@ -0,0 +1,109 @@
+# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import paddle
+from paddle import nn
+from paddle.nn import functional as F
+from paddle import fluid
+
+class TableAttentionLoss(nn.Layer):
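+ # Combines a cross-entropy loss on structure tokens (optionally weighted for
+ # span tokens via structure_mask) with an MSE loss on cell locations, plus an
+ # optional GIoU term when use_giou is set.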
+ def __init__(self, structure_weight, loc_weight, use_giou=False, giou_weight=1.0, **kwargs):
+ super(TableAttentionLoss, self).__init__()
+ self.loss_func = nn.CrossEntropyLoss(weight=None, reduction='none')
+ self.structure_weight = structure_weight
+ self.loc_weight = loc_weight
+ self.use_giou = use_giou
+ self.giou_weight = giou_weight
+
+ def giou_loss(self, preds, bbox, eps=1e-7, reduction='mean'):
+ '''
+ :param preds:[[x1,y1,x2,y2], [x1,y1,x2,y2],,,]
+ :param bbox:[[x1,y1,x2,y2], [x1,y1,x2,y2],,,]
+ :return: loss
+ '''
+ ix1 = fluid.layers.elementwise_max(preds[:, 0], bbox[:, 0])
+ iy1 = fluid.layers.elementwise_max(preds[:, 1], bbox[:, 1])
+ ix2 = fluid.layers.elementwise_min(preds[:, 2], bbox[:, 2])
+ iy2 = fluid.layers.elementwise_min(preds[:, 3], bbox[:, 3])
+
+ iw = fluid.layers.clip(ix2 - ix1 + 1e-3, 0., 1e10)
+ ih = fluid.layers.clip(iy2 - iy1 + 1e-3, 0., 1e10)
+
+ # overlap
+ inters = iw * ih
+
+ # union
+ uni = (preds[:, 2] - preds[:, 0] + 1e-3) * (preds[:, 3] - preds[:, 1] + 1e-3
+ ) + (bbox[:, 2] - bbox[:, 0] + 1e-3) * (
+ bbox[:, 3] - bbox[:, 1] + 1e-3) - inters + eps
+
+ # ious
+ ious = inters / uni
+
+ ex1 = fluid.layers.elementwise_min(preds[:, 0], bbox[:, 0])
+ ey1 = fluid.layers.elementwise_min(preds[:, 1], bbox[:, 1])
+ ex2 = fluid.layers.elementwise_max(preds[:, 2], bbox[:, 2])
+ ey2 = fluid.layers.elementwise_max(preds[:, 3], bbox[:, 3])
+ ew = fluid.layers.clip(ex2 - ex1 + 1e-3, 0., 1e10)
+ eh = fluid.layers.clip(ey2 - ey1 + 1e-3, 0., 1e10)
+
+ # enclose erea
+ enclose = ew * eh + eps
+ giou = ious - (enclose - uni) / enclose
+
+ loss = 1 - giou
+
+ if reduction == 'mean':
+ loss = paddle.mean(loss)
+ elif reduction == 'sum':
+ loss = paddle.sum(loss)
+ else:
+ raise NotImplementedError
+ return loss
+
+ def forward(self, predicts, batch):
+ structure_probs = predicts['structure_probs']
+ structure_targets = batch[1].astype("int64")
+ structure_targets = structure_targets[:, 1:]
+ if len(batch) == 6:
+ structure_mask = batch[5].astype("int64")
+ structure_mask = structure_mask[:, 1:]
+ structure_mask = paddle.reshape(structure_mask, [-1])
+ structure_probs = paddle.reshape(structure_probs, [-1, structure_probs.shape[-1]])
+ structure_targets = paddle.reshape(structure_targets, [-1])
+ structure_loss = self.loss_func(structure_probs, structure_targets)
+
+ if len(batch) == 6:
+ structure_loss = structure_loss * structure_mask
+
+# structure_loss = paddle.sum(structure_loss) * self.structure_weight
+ structure_loss = paddle.mean(structure_loss) * self.structure_weight
+
+ loc_preds = predicts['loc_preds']
+ loc_targets = batch[2].astype("float32")
+ loc_targets_mask = batch[4].astype("float32")
+ loc_targets = loc_targets[:, 1:, :]
+ loc_targets_mask = loc_targets_mask[:, 1:, :]
+ loc_loss = F.mse_loss(loc_preds * loc_targets_mask, loc_targets) * self.loc_weight
+ if self.use_giou:
+ loc_loss_giou = self.giou_loss(loc_preds * loc_targets_mask, loc_targets) * self.giou_weight
+ total_loss = structure_loss + loc_loss + loc_loss_giou
+ return {'loss':total_loss, "structure_loss":structure_loss, "loc_loss":loc_loss, "loc_loss_giou":loc_loss_giou}
+ else:
+ total_loss = structure_loss + loc_loss
+ return {'loss':total_loss, "structure_loss":structure_loss, "loc_loss":loc_loss}
\ No newline at end of file
diff --git a/ppocr/metrics/__init__.py b/ppocr/metrics/__init__.py
index 9e9060fa999bd3175c31dfc0797cd293d4e7afec..64f62e51cdf922773c03bb784a4edffdc17f506f 100644
--- a/ppocr/metrics/__init__.py
+++ b/ppocr/metrics/__init__.py
@@ -26,11 +26,11 @@ from .rec_metric import RecMetric
from .cls_metric import ClsMetric
from .e2e_metric import E2EMetric
from .distillation_metric import DistillationMetric
-
+from .table_metric import TableMetric
def build_metric(config):
support_dict = [
- "DetMetric", "RecMetric", "ClsMetric", "E2EMetric", "DistillationMetric"
+ "DetMetric", "RecMetric", "ClsMetric", "E2EMetric", "DistillationMetric", "TableMetric"
]
config = copy.deepcopy(config)
diff --git a/ppocr/metrics/det_metric.py b/ppocr/metrics/det_metric.py
index 0f9e94df42bb8f31ebc79693a01968d441b16faa..d3d353042575671826da3fc56bf02ccf40dfa5d4 100644
--- a/ppocr/metrics/det_metric.py
+++ b/ppocr/metrics/det_metric.py
@@ -55,6 +55,7 @@ class DetMetric(object):
result = self.evaluator.evaluate_image(gt_info_list, det_info_list)
self.results.append(result)
+
def get_metric(self):
"""
return metrics {
diff --git a/ppocr/metrics/distillation_metric.py b/ppocr/metrics/distillation_metric.py
index a7d3d095a7d384bf8cdc69b97f8109c359ac2b5b..c440cebdd0f96493fc33000a0d304cbe5e3f0624 100644
--- a/ppocr/metrics/distillation_metric.py
+++ b/ppocr/metrics/distillation_metric.py
@@ -24,8 +24,8 @@ from .cls_metric import ClsMetric
class DistillationMetric(object):
def __init__(self,
key=None,
- base_metric_name="RecMetric",
- main_indicator='acc',
+ base_metric_name=None,
+ main_indicator=None,
**kwargs):
self.main_indicator = main_indicator
self.key = key
@@ -42,16 +42,13 @@ class DistillationMetric(object):
main_indicator=self.main_indicator, **self.kwargs)
self.metrics[key].reset()
- def __call__(self, preds, *args, **kwargs):
+ def __call__(self, preds, batch, **kwargs):
assert isinstance(preds, dict)
if self.metrics is None:
self._init_metrcis(preds)
output = dict()
for key in preds:
- metric = self.metrics[key].__call__(preds[key], *args, **kwargs)
- for sub_key in metric:
- output["{}_{}".format(key, sub_key)] = metric[sub_key]
- return output
+ self.metrics[key].__call__(preds[key], batch, **kwargs)
def get_metric(self):
"""
diff --git a/ppocr/metrics/table_metric.py b/ppocr/metrics/table_metric.py
new file mode 100644
index 0000000000000000000000000000000000000000..80d1c789ecc3979bd4c33620af91ccd28012f7a8
--- /dev/null
+++ b/ppocr/metrics/table_metric.py
@@ -0,0 +1,50 @@
+# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import numpy as np
+class TableMetric(object):
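+ # Structure accuracy: a sample counts as correct only when its predicted
+ # structure token sequence matches the label exactly.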
+ def __init__(self, main_indicator='acc', **kwargs):
+ self.main_indicator = main_indicator
+ self.reset()
+
+ def __call__(self, pred, batch, *args, **kwargs):
+ structure_probs = pred['structure_probs'].numpy()
+ structure_labels = batch[1]
+ correct_num = 0
+ all_num = 0
+ structure_probs = np.argmax(structure_probs, axis=2)
+ structure_labels = structure_labels[:, 1:]
+ batch_size = structure_probs.shape[0]
+ for bno in range(batch_size):
+ all_num += 1
+ if (structure_probs[bno] == structure_labels[bno]).all():
+ correct_num += 1
+ self.correct_num += correct_num
+ self.all_num += all_num
+ return {
+ 'acc': correct_num * 1.0 / all_num,
+ }
+
+ def get_metric(self):
+ """
+ return metrics {
+ 'acc': 0,
+ }
+ """
+ acc = 1.0 * self.correct_num / self.all_num
+ self.reset()
+ return {'acc': acc}
+
+ def reset(self):
+ self.correct_num = 0
+ self.all_num = 0
diff --git a/ppocr/modeling/architectures/base_model.py b/ppocr/modeling/architectures/base_model.py
index 52ad15930efc60ee0b87c5268cf2da75df88cddb..c498d9862abcfc85eaf29ed1d949230a1dc1629c 100644
--- a/ppocr/modeling/architectures/base_model.py
+++ b/ppocr/modeling/architectures/base_model.py
@@ -1,4 +1,4 @@
-# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -77,11 +77,11 @@ class BaseModel(nn.Layer):
if self.use_neck:
x = self.neck(x)
y["neck_out"] = x
- if data is None:
- x = self.head(x)
+ x = self.head(x, targets=data)
+ if isinstance(x, dict):
+ y.update(x)
else:
- x = self.head(x, data)
- y["head_out"] = x
+ y["head_out"] = x
if self.return_all_feats:
return y
else:
diff --git a/ppocr/modeling/architectures/distillation_model.py b/ppocr/modeling/architectures/distillation_model.py
index 2e512331afcfc20e422dbef4ba1a4acd581df9e7..1e95fe574433eaca6f322ff47c8547cc1a29a248 100644
--- a/ppocr/modeling/architectures/distillation_model.py
+++ b/ppocr/modeling/architectures/distillation_model.py
@@ -21,7 +21,7 @@ from ppocr.modeling.backbones import build_backbone
from ppocr.modeling.necks import build_neck
from ppocr.modeling.heads import build_head
from .base_model import BaseModel
-from ppocr.utils.save_load import init_model
+from ppocr.utils.save_load import init_model, load_pretrained_params
__all__ = ['DistillationModel']
@@ -46,7 +46,7 @@ class DistillationModel(nn.Layer):
pretrained = model_config.pop("pretrained")
model = BaseModel(model_config)
if pretrained is not None:
- init_model(model, path=pretrained)
+ load_pretrained_params(model, pretrained)
if freeze_params:
for param in model.parameters():
param.trainable = False
diff --git a/ppocr/modeling/backbones/__init__.py b/ppocr/modeling/backbones/__init__.py
index 49c348640ea7f2a7b354f4500eaf4415ad460f9e..618b827de5884af2049ddce218a72e175b354b28 100755
--- a/ppocr/modeling/backbones/__init__.py
+++ b/ppocr/modeling/backbones/__init__.py
@@ -12,32 +12,36 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-__all__ = ['build_backbone']
+__all__ = ["build_backbone"]
def build_backbone(config, model_type):
- if model_type == 'det':
+ if model_type == "det":
from .det_mobilenet_v3 import MobileNetV3
from .det_resnet_vd import ResNet
from .det_resnet_vd_sast import ResNet_SAST
- support_dict = ['MobileNetV3', 'ResNet', 'ResNet_SAST']
- elif model_type == 'rec' or model_type == 'cls':
+ support_dict = ["MobileNetV3", "ResNet", "ResNet_SAST"]
+ elif model_type == "rec" or model_type == "cls":
from .rec_mobilenet_v3 import MobileNetV3
from .rec_resnet_vd import ResNet
from .rec_resnet_fpn import ResNetFPN
+ from .rec_mv1_enhance import MobileNetV1Enhance
from .rec_nrtr_mtb import MTB
from .rec_swin import SwinTransformer
- support_dict = ['MobileNetV3', 'ResNet', 'ResNetFPN', 'MTB', 'SwinTransformer']
-
- elif model_type == 'e2e':
+ support_dict = ['MobileNetV1Enhance', 'MobileNetV3', 'ResNet', 'ResNetFPN', 'MTB', 'SwinTransformer']
+ elif model_type == "e2e":
from .e2e_resnet_vd_pg import ResNet
- support_dict = ['ResNet']
+ support_dict = ["ResNet"]
+ elif model_type == "table":
+ from .table_resnet_vd import ResNet
+ from .table_mobilenet_v3 import MobileNetV3
+ support_dict = ["ResNet", "MobileNetV3"]
else:
raise NotImplementedError
- module_name = config.pop('name')
+ module_name = config.pop("name")
assert module_name in support_dict, Exception(
- 'when model typs is {}, backbone only support {}'.format(model_type,
+ "when model typs is {}, backbone only support {}".format(model_type,
support_dict))
module_class = eval(module_name)(**config)
return module_class
diff --git a/ppocr/modeling/backbones/rec_mv1_enhance.py b/ppocr/modeling/backbones/rec_mv1_enhance.py
new file mode 100644
index 0000000000000000000000000000000000000000..fe874fac1af439bfb47ba9050a61f02db302e224
--- /dev/null
+++ b/ppocr/modeling/backbones/rec_mv1_enhance.py
@@ -0,0 +1,256 @@
+# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+import paddle
+from paddle import ParamAttr
+import paddle.nn as nn
+import paddle.nn.functional as F
+from paddle.nn import Conv2D, BatchNorm, Linear, Dropout
+from paddle.nn import AdaptiveAvgPool2D, MaxPool2D, AvgPool2D
+from paddle.nn.initializer import KaimingNormal
+import math
+import numpy as np
+import paddle
+from paddle import ParamAttr, reshape, transpose, concat, split
+import paddle.nn as nn
+import paddle.nn.functional as F
+from paddle.nn import Conv2D, BatchNorm, Linear, Dropout
+from paddle.nn import AdaptiveAvgPool2D, MaxPool2D, AvgPool2D
+from paddle.nn.initializer import KaimingNormal
+import math
+from paddle.nn.functional import hardswish, hardsigmoid
+from paddle.regularizer import L2Decay
+
+
+class ConvBNLayer(nn.Layer):
+ def __init__(self,
+ num_channels,
+ filter_size,
+ num_filters,
+ stride,
+ padding,
+ channels=None,
+ num_groups=1,
+ act='hard_swish'):
+ super(ConvBNLayer, self).__init__()
+
+ self._conv = Conv2D(
+ in_channels=num_channels,
+ out_channels=num_filters,
+ kernel_size=filter_size,
+ stride=stride,
+ padding=padding,
+ groups=num_groups,
+ weight_attr=ParamAttr(initializer=KaimingNormal()),
+ bias_attr=False)
+
+ self._batch_norm = BatchNorm(
+ num_filters,
+ act=act,
+ param_attr=ParamAttr(regularizer=L2Decay(0.0)),
+ bias_attr=ParamAttr(regularizer=L2Decay(0.0)))
+
+ def forward(self, inputs):
+ y = self._conv(inputs)
+ y = self._batch_norm(y)
+ return y
+
+
+class DepthwiseSeparable(nn.Layer):
+ def __init__(self,
+ num_channels,
+ num_filters1,
+ num_filters2,
+ num_groups,
+ stride,
+ scale,
+ dw_size=3,
+ padding=1,
+ use_se=False):
+ super(DepthwiseSeparable, self).__init__()
+ self.use_se = use_se
+ self._depthwise_conv = ConvBNLayer(
+ num_channels=num_channels,
+ num_filters=int(num_filters1 * scale),
+ filter_size=dw_size,
+ stride=stride,
+ padding=padding,
+ num_groups=int(num_groups * scale))
+ if use_se:
+ self._se = SEModule(int(num_filters1 * scale))
+ self._pointwise_conv = ConvBNLayer(
+ num_channels=int(num_filters1 * scale),
+ filter_size=1,
+ num_filters=int(num_filters2 * scale),
+ stride=1,
+ padding=0)
+
+ def forward(self, inputs):
+ y = self._depthwise_conv(inputs)
+ if self.use_se:
+ y = self._se(y)
+ y = self._pointwise_conv(y)
+ return y
+
+
+class MobileNetV1Enhance(nn.Layer):
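+ # MobileNetV1 variant for recognition: the deeper stages use stride (2, 1),
+ # presumably to downsample height only and keep horizontal resolution for
+ # text, and the last blocks add SE modules with 5x5 depthwise kernels.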
+ def __init__(self, in_channels=3, scale=0.5, **kwargs):
+ super().__init__()
+ self.scale = scale
+ self.block_list = []
+
+ self.conv1 = ConvBNLayer(
+ num_channels=3,
+ filter_size=3,
+ channels=3,
+ num_filters=int(32 * scale),
+ stride=2,
+ padding=1)
+
+ conv2_1 = DepthwiseSeparable(
+ num_channels=int(32 * scale),
+ num_filters1=32,
+ num_filters2=64,
+ num_groups=32,
+ stride=1,
+ scale=scale)
+ self.block_list.append(conv2_1)
+
+ conv2_2 = DepthwiseSeparable(
+ num_channels=int(64 * scale),
+ num_filters1=64,
+ num_filters2=128,
+ num_groups=64,
+ stride=1,
+ scale=scale)
+ self.block_list.append(conv2_2)
+
+ conv3_1 = DepthwiseSeparable(
+ num_channels=int(128 * scale),
+ num_filters1=128,
+ num_filters2=128,
+ num_groups=128,
+ stride=1,
+ scale=scale)
+ self.block_list.append(conv3_1)
+
+ conv3_2 = DepthwiseSeparable(
+ num_channels=int(128 * scale),
+ num_filters1=128,
+ num_filters2=256,
+ num_groups=128,
+ stride=(2, 1),
+ scale=scale)
+ self.block_list.append(conv3_2)
+
+ conv4_1 = DepthwiseSeparable(
+ num_channels=int(256 * scale),
+ num_filters1=256,
+ num_filters2=256,
+ num_groups=256,
+ stride=1,
+ scale=scale)
+ self.block_list.append(conv4_1)
+
+ conv4_2 = DepthwiseSeparable(
+ num_channels=int(256 * scale),
+ num_filters1=256,
+ num_filters2=512,
+ num_groups=256,
+ stride=(2, 1),
+ scale=scale)
+ self.block_list.append(conv4_2)
+
+ for _ in range(5):
+ conv5 = DepthwiseSeparable(
+ num_channels=int(512 * scale),
+ num_filters1=512,
+ num_filters2=512,
+ num_groups=512,
+ stride=1,
+ dw_size=5,
+ padding=2,
+ scale=scale,
+ use_se=False)
+ self.block_list.append(conv5)
+
+ conv5_6 = DepthwiseSeparable(
+ num_channels=int(512 * scale),
+ num_filters1=512,
+ num_filters2=1024,
+ num_groups=512,
+ stride=(2, 1),
+ dw_size=5,
+ padding=2,
+ scale=scale,
+ use_se=True)
+ self.block_list.append(conv5_6)
+
+ conv6 = DepthwiseSeparable(
+ num_channels=int(1024 * scale),
+ num_filters1=1024,
+ num_filters2=1024,
+ num_groups=1024,
+ stride=1,
+ dw_size=5,
+ padding=2,
+ use_se=True,
+ scale=scale)
+ self.block_list.append(conv6)
+
+ self.block_list = nn.Sequential(*self.block_list)
+
+ self.pool = nn.MaxPool2D(kernel_size=2, stride=2, padding=0)
+ self.out_channels = int(1024 * scale)
+
+ def forward(self, inputs):
+ y = self.conv1(inputs)
+ y = self.block_list(y)
+ y = self.pool(y)
+ return y
+
+
+class SEModule(nn.Layer):
+ def __init__(self, channel, reduction=4):
+ super(SEModule, self).__init__()
+ self.avg_pool = AdaptiveAvgPool2D(1)
+ self.conv1 = Conv2D(
+ in_channels=channel,
+ out_channels=channel // reduction,
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ weight_attr=ParamAttr(),
+ bias_attr=ParamAttr())
+ self.conv2 = Conv2D(
+ in_channels=channel // reduction,
+ out_channels=channel,
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ weight_attr=ParamAttr(),
+ bias_attr=ParamAttr())
+
+ def forward(self, inputs):
+ outputs = self.avg_pool(inputs)
+ outputs = self.conv1(outputs)
+ outputs = F.relu(outputs)
+ outputs = self.conv2(outputs)
+ outputs = hardsigmoid(outputs)
+ return paddle.multiply(x=inputs, y=outputs)
diff --git a/ppocr/modeling/backbones/table_mobilenet_v3.py b/ppocr/modeling/backbones/table_mobilenet_v3.py
new file mode 100644
index 0000000000000000000000000000000000000000..daa87f976038d8d5eeafadceb869b9232ba22cd9
--- /dev/null
+++ b/ppocr/modeling/backbones/table_mobilenet_v3.py
@@ -0,0 +1,287 @@
+# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import paddle
+from paddle import nn
+import paddle.nn.functional as F
+from paddle import ParamAttr
+
+__all__ = ['MobileNetV3']
+
+
+def make_divisible(v, divisor=8, min_value=None):
+ if min_value is None:
+ min_value = divisor
+ new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
+ if new_v < 0.9 * v:
+ new_v += divisor
+ return new_v
+
+
+class MobileNetV3(nn.Layer):
+ def __init__(self,
+ in_channels=3,
+ model_name='large',
+ scale=0.5,
+ disable_se=False,
+ **kwargs):
+ """
+ the MobilenetV3 backbone network for detection module.
+ Args:
+ params(dict): the super parameters for build network
+ """
+ super(MobileNetV3, self).__init__()
+
+ self.disable_se = disable_se
+
+ if model_name == "large":
+ cfg = [
+ # k, exp, c, se, nl, s,
+ [3, 16, 16, False, 'relu', 1],
+ [3, 64, 24, False, 'relu', 2],
+ [3, 72, 24, False, 'relu', 1],
+ [5, 72, 40, True, 'relu', 2],
+ [5, 120, 40, True, 'relu', 1],
+ [5, 120, 40, True, 'relu', 1],
+ [3, 240, 80, False, 'hardswish', 2],
+ [3, 200, 80, False, 'hardswish', 1],
+ [3, 184, 80, False, 'hardswish', 1],
+ [3, 184, 80, False, 'hardswish', 1],
+ [3, 480, 112, True, 'hardswish', 1],
+ [3, 672, 112, True, 'hardswish', 1],
+ [5, 672, 160, True, 'hardswish', 2],
+ [5, 960, 160, True, 'hardswish', 1],
+ [5, 960, 160, True, 'hardswish', 1],
+ ]
+ cls_ch_squeeze = 960
+ elif model_name == "small":
+ cfg = [
+ # k, exp, c, se, nl, s,
+ [3, 16, 16, True, 'relu', 2],
+ [3, 72, 24, False, 'relu', 2],
+ [3, 88, 24, False, 'relu', 1],
+ [5, 96, 40, True, 'hardswish', 2],
+ [5, 240, 40, True, 'hardswish', 1],
+ [5, 240, 40, True, 'hardswish', 1],
+ [5, 120, 48, True, 'hardswish', 1],
+ [5, 144, 48, True, 'hardswish', 1],
+ [5, 288, 96, True, 'hardswish', 2],
+ [5, 576, 96, True, 'hardswish', 1],
+ [5, 576, 96, True, 'hardswish', 1],
+ ]
+ cls_ch_squeeze = 576
+ else:
+ raise NotImplementedError("mode[" + model_name +
+ "_model] is not implemented!")
+
+ supported_scale = [0.35, 0.5, 0.75, 1.0, 1.25]
+ assert scale in supported_scale, \
+ "supported scale are {} but input scale is {}".format(supported_scale, scale)
+ inplanes = 16
+ # conv1
+ self.conv = ConvBNLayer(
+ in_channels=in_channels,
+ out_channels=make_divisible(inplanes * scale),
+ kernel_size=3,
+ stride=2,
+ padding=1,
+ groups=1,
+ if_act=True,
+ act='hardswish',
+ name='conv1')
+
+ self.stages = []
+ self.out_channels = []
+ block_list = []
+ i = 0
+ inplanes = make_divisible(inplanes * scale)
+ for (k, exp, c, se, nl, s) in cfg:
+ se = se and not self.disable_se
+ start_idx = 2 if model_name == 'large' else 0
+ if s == 2 and i > start_idx:
+ self.out_channels.append(inplanes)
+ self.stages.append(nn.Sequential(*block_list))
+ block_list = []
+ block_list.append(
+ ResidualUnit(
+ in_channels=inplanes,
+ mid_channels=make_divisible(scale * exp),
+ out_channels=make_divisible(scale * c),
+ kernel_size=k,
+ stride=s,
+ use_se=se,
+ act=nl,
+ name="conv" + str(i + 2)))
+ inplanes = make_divisible(scale * c)
+ i += 1
+ block_list.append(
+ ConvBNLayer(
+ in_channels=inplanes,
+ out_channels=make_divisible(scale * cls_ch_squeeze),
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ groups=1,
+ if_act=True,
+ act='hardswish',
+ name='conv_last'))
+ self.stages.append(nn.Sequential(*block_list))
+ self.out_channels.append(make_divisible(scale * cls_ch_squeeze))
+ for i, stage in enumerate(self.stages):
+ self.add_sublayer(sublayer=stage, name="stage{}".format(i))
+
+ def forward(self, x):
+ x = self.conv(x)
+ out_list = []
+ for stage in self.stages:
+ x = stage(x)
+ out_list.append(x)
+ return out_list
+
+
+class ConvBNLayer(nn.Layer):
+ def __init__(self,
+ in_channels,
+ out_channels,
+ kernel_size,
+ stride,
+ padding,
+ groups=1,
+ if_act=True,
+ act=None,
+ name=None):
+ super(ConvBNLayer, self).__init__()
+ self.if_act = if_act
+ self.act = act
+ self.conv = nn.Conv2D(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ kernel_size=kernel_size,
+ stride=stride,
+ padding=padding,
+ groups=groups,
+ weight_attr=ParamAttr(name=name + '_weights'),
+ bias_attr=False)
+
+ self.bn = nn.BatchNorm(
+ num_channels=out_channels,
+ act=None,
+ param_attr=ParamAttr(name=name + "_bn_scale"),
+ bias_attr=ParamAttr(name=name + "_bn_offset"),
+ moving_mean_name=name + "_bn_mean",
+ moving_variance_name=name + "_bn_variance")
+
+ def forward(self, x):
+ x = self.conv(x)
+ x = self.bn(x)
+ if self.if_act:
+ if self.act == "relu":
+ x = F.relu(x)
+ elif self.act == "hardswish":
+ x = F.hardswish(x)
+ else:
+ print("The activation function({}) is selected incorrectly.".
+ format(self.act))
+ exit()
+ return x
+
+
+class ResidualUnit(nn.Layer):
+ def __init__(self,
+ in_channels,
+ mid_channels,
+ out_channels,
+ kernel_size,
+ stride,
+ use_se,
+ act=None,
+ name=''):
+ super(ResidualUnit, self).__init__()
+ self.if_shortcut = stride == 1 and in_channels == out_channels
+ self.if_se = use_se
+
+ self.expand_conv = ConvBNLayer(
+ in_channels=in_channels,
+ out_channels=mid_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ if_act=True,
+ act=act,
+ name=name + "_expand")
+ self.bottleneck_conv = ConvBNLayer(
+ in_channels=mid_channels,
+ out_channels=mid_channels,
+ kernel_size=kernel_size,
+ stride=stride,
+ padding=int((kernel_size - 1) // 2),
+ groups=mid_channels,
+ if_act=True,
+ act=act,
+ name=name + "_depthwise")
+ if self.if_se:
+ self.mid_se = SEModule(mid_channels, name=name + "_se")
+ self.linear_conv = ConvBNLayer(
+ in_channels=mid_channels,
+ out_channels=out_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ if_act=False,
+ act=None,
+ name=name + "_linear")
+
+ def forward(self, inputs):
+ x = self.expand_conv(inputs)
+ x = self.bottleneck_conv(x)
+ if self.if_se:
+ x = self.mid_se(x)
+ x = self.linear_conv(x)
+ if self.if_shortcut:
+ x = paddle.add(inputs, x)
+ return x
+
+
+class SEModule(nn.Layer):
+ def __init__(self, in_channels, reduction=4, name=""):
+ super(SEModule, self).__init__()
+ self.avg_pool = nn.AdaptiveAvgPool2D(1)
+ self.conv1 = nn.Conv2D(
+ in_channels=in_channels,
+ out_channels=in_channels // reduction,
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ weight_attr=ParamAttr(name=name + "_1_weights"),
+ bias_attr=ParamAttr(name=name + "_1_offset"))
+ self.conv2 = nn.Conv2D(
+ in_channels=in_channels // reduction,
+ out_channels=in_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ weight_attr=ParamAttr(name + "_2_weights"),
+ bias_attr=ParamAttr(name=name + "_2_offset"))
+
+ def forward(self, inputs):
+ outputs = self.avg_pool(inputs)
+ outputs = self.conv1(outputs)
+ outputs = F.relu(outputs)
+ outputs = self.conv2(outputs)
+ outputs = F.hardsigmoid(outputs, slope=0.2, offset=0.5)
+ return inputs * outputs
\ No newline at end of file
diff --git a/ppocr/modeling/backbones/table_resnet_vd.py b/ppocr/modeling/backbones/table_resnet_vd.py
new file mode 100644
index 0000000000000000000000000000000000000000..1c07c2684eec8d0c4a445cc88c543bfe1da9c864
--- /dev/null
+++ b/ppocr/modeling/backbones/table_resnet_vd.py
@@ -0,0 +1,280 @@
+# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import paddle
+from paddle import ParamAttr
+import paddle.nn as nn
+import paddle.nn.functional as F
+
+__all__ = ["ResNet"]
+
+
+class ConvBNLayer(nn.Layer):
+ def __init__(
+ self,
+ in_channels,
+ out_channels,
+ kernel_size,
+ stride=1,
+ groups=1,
+ is_vd_mode=False,
+ act=None,
+ name=None, ):
+ super(ConvBNLayer, self).__init__()
+
+ self.is_vd_mode = is_vd_mode
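+ # "vd" mode: average-pool the input before the conv, the ResNet-D trick used for downsampling shortcut branches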
+ self._pool2d_avg = nn.AvgPool2D(
+ kernel_size=2, stride=2, padding=0, ceil_mode=True)
+ self._conv = nn.Conv2D(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ kernel_size=kernel_size,
+ stride=stride,
+ padding=(kernel_size - 1) // 2,
+ groups=groups,
+ weight_attr=ParamAttr(name=name + "_weights"),
+ bias_attr=False)
+ if name == "conv1":
+ bn_name = "bn_" + name
+ else:
+ bn_name = "bn" + name[3:]
+ self._batch_norm = nn.BatchNorm(
+ out_channels,
+ act=act,
+ param_attr=ParamAttr(name=bn_name + '_scale'),
+ bias_attr=ParamAttr(bn_name + '_offset'),
+ moving_mean_name=bn_name + '_mean',
+ moving_variance_name=bn_name + '_variance')
+
+ def forward(self, inputs):
+ if self.is_vd_mode:
+ inputs = self._pool2d_avg(inputs)
+ y = self._conv(inputs)
+ y = self._batch_norm(y)
+ return y
+
+
+class BottleneckBlock(nn.Layer):
+ def __init__(self,
+ in_channels,
+ out_channels,
+ stride,
+ shortcut=True,
+ if_first=False,
+ name=None):
+ super(BottleneckBlock, self).__init__()
+
+ self.conv0 = ConvBNLayer(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ kernel_size=1,
+ act='relu',
+ name=name + "_branch2a")
+ self.conv1 = ConvBNLayer(
+ in_channels=out_channels,
+ out_channels=out_channels,
+ kernel_size=3,
+ stride=stride,
+ act='relu',
+ name=name + "_branch2b")
+ self.conv2 = ConvBNLayer(
+ in_channels=out_channels,
+ out_channels=out_channels * 4,
+ kernel_size=1,
+ act=None,
+ name=name + "_branch2c")
+
+ if not shortcut:
+ self.short = ConvBNLayer(
+ in_channels=in_channels,
+ out_channels=out_channels * 4,
+ kernel_size=1,
+ stride=1,
+ is_vd_mode=False if if_first else True,
+ name=name + "_branch1")
+
+ self.shortcut = shortcut
+
+ def forward(self, inputs):
+ y = self.conv0(inputs)
+ conv1 = self.conv1(y)
+ conv2 = self.conv2(conv1)
+
+ if self.shortcut:
+ short = inputs
+ else:
+ short = self.short(inputs)
+ y = paddle.add(x=short, y=conv2)
+ y = F.relu(y)
+ return y
+
+
+class BasicBlock(nn.Layer):
+ def __init__(self,
+ in_channels,
+ out_channels,
+ stride,
+ shortcut=True,
+ if_first=False,
+ name=None):
+ super(BasicBlock, self).__init__()
+ self.stride = stride
+ self.conv0 = ConvBNLayer(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ kernel_size=3,
+ stride=stride,
+ act='relu',
+ name=name + "_branch2a")
+ self.conv1 = ConvBNLayer(
+ in_channels=out_channels,
+ out_channels=out_channels,
+ kernel_size=3,
+ act=None,
+ name=name + "_branch2b")
+
+ if not shortcut:
+ self.short = ConvBNLayer(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ kernel_size=1,
+ stride=1,
+ is_vd_mode=False if if_first else True,
+ name=name + "_branch1")
+
+ self.shortcut = shortcut
+
+ def forward(self, inputs):
+ y = self.conv0(inputs)
+ conv1 = self.conv1(y)
+
+ if self.shortcut:
+ short = inputs
+ else:
+ short = self.short(inputs)
+ y = paddle.add(x=short, y=conv1)
+ y = F.relu(y)
+ return y
+
+
+class ResNet(nn.Layer):
+ def __init__(self, in_channels=3, layers=50, **kwargs):
+ super(ResNet, self).__init__()
+
+ self.layers = layers
+ supported_layers = [18, 34, 50, 101, 152, 200]
+ assert layers in supported_layers, \
+ "supported layers are {} but input layer is {}".format(
+ supported_layers, layers)
+
+ if layers == 18:
+ depth = [2, 2, 2, 2]
+ elif layers == 34 or layers == 50:
+ depth = [3, 4, 6, 3]
+ elif layers == 101:
+ depth = [3, 4, 23, 3]
+ elif layers == 152:
+ depth = [3, 8, 36, 3]
+ elif layers == 200:
+ depth = [3, 12, 48, 3]
+ num_channels = [64, 256, 512,
+ 1024] if layers >= 50 else [64, 64, 128, 256]
+ num_filters = [64, 128, 256, 512]
+
+ self.conv1_1 = ConvBNLayer(
+ in_channels=in_channels,
+ out_channels=32,
+ kernel_size=3,
+ stride=2,
+ act='relu',
+ name="conv1_1")
+ self.conv1_2 = ConvBNLayer(
+ in_channels=32,
+ out_channels=32,
+ kernel_size=3,
+ stride=1,
+ act='relu',
+ name="conv1_2")
+ self.conv1_3 = ConvBNLayer(
+ in_channels=32,
+ out_channels=64,
+ kernel_size=3,
+ stride=1,
+ act='relu',
+ name="conv1_3")
+ self.pool2d_max = nn.MaxPool2D(kernel_size=3, stride=2, padding=1)
+
+ self.stages = []
+ self.out_channels = []
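+ # ResNet-50 and deeper use BottleneckBlock (4x channel expansion); shallower variants use BasicBlock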
+ if layers >= 50:
+ for block in range(len(depth)):
+ block_list = []
+ shortcut = False
+ for i in range(depth[block]):
+ if layers in [101, 152] and block == 2:
+ if i == 0:
+ conv_name = "res" + str(block + 2) + "a"
+ else:
+ conv_name = "res" + str(block + 2) + "b" + str(i)
+ else:
+ conv_name = "res" + str(block + 2) + chr(97 + i)
+ bottleneck_block = self.add_sublayer(
+ 'bb_%d_%d' % (block, i),
+ BottleneckBlock(
+ in_channels=num_channels[block]
+ if i == 0 else num_filters[block] * 4,
+ out_channels=num_filters[block],
+ stride=2 if i == 0 and block != 0 else 1,
+ shortcut=shortcut,
+ if_first=block == i == 0,
+ name=conv_name))
+ shortcut = True
+ block_list.append(bottleneck_block)
+ self.out_channels.append(num_filters[block] * 4)
+ self.stages.append(nn.Sequential(*block_list))
+ else:
+ for block in range(len(depth)):
+ block_list = []
+ shortcut = False
+ for i in range(depth[block]):
+ conv_name = "res" + str(block + 2) + chr(97 + i)
+ basic_block = self.add_sublayer(
+ 'bb_%d_%d' % (block, i),
+ BasicBlock(
+ in_channels=num_channels[block]
+ if i == 0 else num_filters[block],
+ out_channels=num_filters[block],
+ stride=2 if i == 0 and block != 0 else 1,
+ shortcut=shortcut,
+ if_first=block == i == 0,
+ name=conv_name))
+ shortcut = True
+ block_list.append(basic_block)
+ self.out_channels.append(num_filters[block])
+ self.stages.append(nn.Sequential(*block_list))
+
+ def forward(self, inputs):
+ y = self.conv1_1(inputs)
+ y = self.conv1_2(y)
+ y = self.conv1_3(y)
+ y = self.pool2d_max(y)
+ out = []
+ for block in self.stages:
+ y = block(y)
+ out.append(y)
+ return out
diff --git a/ppocr/modeling/heads/__init__.py b/ppocr/modeling/heads/__init__.py
index 10acd0fa1b99f4b5d86b58dcb956675e889f3028..63951cd57d7ab4acff0e872edfa14084e4d85474 100755
--- a/ppocr/modeling/heads/__init__.py
+++ b/ppocr/modeling/heads/__init__.py
@@ -32,8 +32,12 @@ def build_head(config):
from .cls_head import ClsHead
support_dict = [
'DBHead', 'EASTHead', 'SASTHead', 'CTCHead', 'ClsHead', 'AttentionHead',
- 'SRNHead', 'PGHead', 'TransformerOptim']
+ 'SRNHead', 'PGHead', 'TransformerOptim', 'TableAttentionHead']
+
+
+ #table head
+ from .table_att_head import TableAttentionHead
module_name = config.pop('name')
assert module_name in support_dict, Exception('head only support {}'.format(
diff --git a/ppocr/modeling/heads/cls_head.py b/ppocr/modeling/heads/cls_head.py
index d9b78b841b3c31ea349cfbf4e767328b12f39aa7..91bfa615a8206b5ec0f993429ccae990a05d0b9b 100644
--- a/ppocr/modeling/heads/cls_head.py
+++ b/ppocr/modeling/heads/cls_head.py
@@ -43,7 +43,7 @@ class ClsHead(nn.Layer):
initializer=nn.initializer.Uniform(-stdv, stdv)),
bias_attr=ParamAttr(name="fc_0.b_0"), )
- def forward(self, x):
+ def forward(self, x, targets=None):
x = self.pool(x)
x = paddle.reshape(x, shape=[x.shape[0], x.shape[1]])
x = self.fc(x)
diff --git a/ppocr/modeling/heads/det_db_head.py b/ppocr/modeling/heads/det_db_head.py
index 83e7a5ebfe131ed209b7dd2d4b5a324605be8370..f76cb34d37af7d81b5e628d06c1a4cfe126f8bb4 100644
--- a/ppocr/modeling/heads/det_db_head.py
+++ b/ppocr/modeling/heads/det_db_head.py
@@ -106,7 +106,7 @@ class DBHead(nn.Layer):
def step_function(self, x, y):
return paddle.reciprocal(1 + paddle.exp(-self.k * (x - y)))
- def forward(self, x):
+ def forward(self, x, targets=None):
shrink_maps = self.binarize(x)
if not self.training:
return {'maps': shrink_maps}
diff --git a/ppocr/modeling/heads/det_east_head.py b/ppocr/modeling/heads/det_east_head.py
index 9d0c3c4cf83adb018fcc368374cbe305658e07a9..004eb5d7bb9a134d1a84f980e37e5336dc43a29a 100644
--- a/ppocr/modeling/heads/det_east_head.py
+++ b/ppocr/modeling/heads/det_east_head.py
@@ -109,7 +109,7 @@ class EASTHead(nn.Layer):
act=None,
name="f_geo")
- def forward(self, x):
+ def forward(self, x, targets=None):
f_det = self.det_conv1(x)
f_det = self.det_conv2(f_det)
f_score = self.score_conv(f_det)
diff --git a/ppocr/modeling/heads/det_sast_head.py b/ppocr/modeling/heads/det_sast_head.py
index 263b28672299e733369938fa03952dca7685fabe..7a88a2db6c29c8c4fa1ee94d27bd0701cdbc90f8 100644
--- a/ppocr/modeling/heads/det_sast_head.py
+++ b/ppocr/modeling/heads/det_sast_head.py
@@ -116,7 +116,7 @@ class SASTHead(nn.Layer):
self.head1 = SAST_Header1(in_channels)
self.head2 = SAST_Header2(in_channels)
- def forward(self, x):
+ def forward(self, x, targets=None):
f_score, f_border = self.head1(x)
f_tvo, f_tco = self.head2(x)
diff --git a/ppocr/modeling/heads/e2e_pg_head.py b/ppocr/modeling/heads/e2e_pg_head.py
index 0da9de7580a0ceb473f971b2246c966497026a5d..274e1cdac5172f45590c9f7d7b50522c74db6750 100644
--- a/ppocr/modeling/heads/e2e_pg_head.py
+++ b/ppocr/modeling/heads/e2e_pg_head.py
@@ -220,7 +220,7 @@ class PGHead(nn.Layer):
weight_attr=ParamAttr(name="conv_f_direc{}".format(4)),
bias_attr=False)
- def forward(self, x):
+ def forward(self, x, targets=None):
f_score = self.conv_f_score1(x)
f_score = self.conv_f_score2(f_score)
f_score = self.conv_f_score3(f_score)
diff --git a/ppocr/modeling/heads/rec_ctc_head.py b/ppocr/modeling/heads/rec_ctc_head.py
index 481f93e47e58f8267b23e632df1a1e80733d5944..9c38d31fa0abcf39a583e5edcebfc8f336f41c46 100755
--- a/ppocr/modeling/heads/rec_ctc_head.py
+++ b/ppocr/modeling/heads/rec_ctc_head.py
@@ -33,19 +33,47 @@ def get_para_bias_attr(l2_decay, k):
class CTCHead(nn.Layer):
- def __init__(self, in_channels, out_channels, fc_decay=0.0004, **kwargs):
+ def __init__(self,
+ in_channels,
+ out_channels,
+ fc_decay=0.0004,
+ mid_channels=None,
+ **kwargs):
super(CTCHead, self).__init__()
- weight_attr, bias_attr = get_para_bias_attr(
- l2_decay=fc_decay, k=in_channels)
- self.fc = nn.Linear(
- in_channels,
- out_channels,
- weight_attr=weight_attr,
- bias_attr=bias_attr)
+ if mid_channels is None:
+ weight_attr, bias_attr = get_para_bias_attr(
+ l2_decay=fc_decay, k=in_channels)
+ self.fc = nn.Linear(
+ in_channels,
+ out_channels,
+ weight_attr=weight_attr,
+ bias_attr=bias_attr)
+ else:
+ weight_attr1, bias_attr1 = get_para_bias_attr(
+ l2_decay=fc_decay, k=in_channels)
+ self.fc1 = nn.Linear(
+ in_channels,
+ mid_channels,
+ weight_attr=weight_attr1,
+ bias_attr=bias_attr1)
+
+ weight_attr2, bias_attr2 = get_para_bias_attr(
+ l2_decay=fc_decay, k=mid_channels)
+ self.fc2 = nn.Linear(
+ mid_channels,
+ out_channels,
+ weight_attr=weight_attr2,
+ bias_attr=bias_attr2)
self.out_channels = out_channels
+ self.mid_channels = mid_channels
- def forward(self, x, labels=None):
- predicts = self.fc(x)
+ def forward(self, x, targets=None):
+ if self.mid_channels is None:
+ predicts = self.fc(x)
+ else:
+ predicts = self.fc1(x)
+ predicts = self.fc2(predicts)
+
if not self.training:
predicts = F.softmax(predicts, axis=2)
return predicts
diff --git a/ppocr/modeling/heads/rec_srn_head.py b/ppocr/modeling/heads/rec_srn_head.py
index d2c7fc028d28c79057708d4e6f306c417ba6306a..8d59e4711a043afd9234f430a62c9876c0a8f6f4 100644
--- a/ppocr/modeling/heads/rec_srn_head.py
+++ b/ppocr/modeling/heads/rec_srn_head.py
@@ -250,7 +250,8 @@ class SRNHead(nn.Layer):
self.gsrm.wrap_encoder1.prepare_decoder.emb0 = self.gsrm.wrap_encoder0.prepare_decoder.emb0
- def forward(self, inputs, others):
+ def forward(self, inputs, targets=None):
+ others = targets[-4:]
encoder_word_pos = others[0]
gsrm_word_pos = others[1]
gsrm_slf_attn_bias1 = others[2]
diff --git a/ppocr/modeling/heads/table_att_head.py b/ppocr/modeling/heads/table_att_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..155f036d15673135eae9e5ee493648603609535d
--- /dev/null
+++ b/ppocr/modeling/heads/table_att_head.py
@@ -0,0 +1,238 @@
+# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import paddle
+import paddle.nn as nn
+import paddle.nn.functional as F
+import numpy as np
+
+
+class TableAttentionHead(nn.Layer):
+ def __init__(self, in_channels, hidden_size, loc_type, in_max_len=488, **kwargs):
+ super(TableAttentionHead, self).__init__()
+ self.input_size = in_channels[-1]
+ self.hidden_size = hidden_size
+ self.elem_num = 30
+ self.max_text_length = 100
+ self.max_elem_length = 500
+ self.max_cell_num = 500
+
+ self.structure_attention_cell = AttentionGRUCell(
+ self.input_size, hidden_size, self.elem_num, use_gru=False)
+ self.structure_generator = nn.Linear(hidden_size, self.elem_num)
+ self.loc_type = loc_type
+ self.in_max_len = in_max_len
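+ # loc_type == 1: regress cell boxes directly from the decoder output; otherwise project the visual feature to max_elem_length + 1 steps and concatenate it with the decoder output before box regression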
+
+ if self.loc_type == 1:
+ self.loc_generator = nn.Linear(hidden_size, 4)
+ else:
+ if self.in_max_len == 640:
+ self.loc_fea_trans = nn.Linear(400, self.max_elem_length+1)
+ elif self.in_max_len == 800:
+ self.loc_fea_trans = nn.Linear(625, self.max_elem_length+1)
+ else:
+ self.loc_fea_trans = nn.Linear(256, self.max_elem_length+1)
+ self.loc_generator = nn.Linear(self.input_size + hidden_size, 4)
+
+ def _char_to_onehot(self, input_char, onehot_dim):
+ input_ont_hot = F.one_hot(input_char, onehot_dim)
+ return input_ont_hot
+
+ def forward(self, inputs, targets=None):
+ # Both the if and the else branch are needed when assigning to a variable:
+ # modifying it in only one branch does not take effect.
+ fea = inputs[-1]
+ if len(fea.shape) == 3:
+ pass
+ else:
+ last_shape = int(np.prod(fea.shape[2:])) # gry added
+ fea = paddle.reshape(fea, [fea.shape[0], fea.shape[1], last_shape])
+ fea = fea.transpose([0, 2, 1]) # (NTC)(batch, width, channels)
+ batch_size = fea.shape[0]
+
+ hidden = paddle.zeros((batch_size, self.hidden_size))
+ output_hiddens = []
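+ # training: teacher-forced decoding over the ground-truth structure tokens; inference: greedy decoding that feeds back the argmax token at each step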
+ if self.training and targets is not None:
+ structure = targets[0]
+ for i in range(self.max_elem_length+1):
+ elem_onehots = self._char_to_onehot(
+ structure[:, i], onehot_dim=self.elem_num)
+ (outputs, hidden), alpha = self.structure_attention_cell(
+ hidden, fea, elem_onehots)
+ output_hiddens.append(paddle.unsqueeze(outputs, axis=1))
+ output = paddle.concat(output_hiddens, axis=1)
+ structure_probs = self.structure_generator(output)
+ if self.loc_type == 1:
+ loc_preds = self.loc_generator(output)
+ loc_preds = F.sigmoid(loc_preds)
+ else:
+ loc_fea = fea.transpose([0, 2, 1])
+ loc_fea = self.loc_fea_trans(loc_fea)
+ loc_fea = loc_fea.transpose([0, 2, 1])
+ loc_concat = paddle.concat([output, loc_fea], axis=2)
+ loc_preds = self.loc_generator(loc_concat)
+ loc_preds = F.sigmoid(loc_preds)
+ else:
+ temp_elem = paddle.zeros(shape=[batch_size], dtype="int32")
+ structure_probs = None
+ loc_preds = None
+ elem_onehots = None
+ outputs = None
+ alpha = None
+ max_elem_length = paddle.to_tensor(self.max_elem_length)
+ i = 0
+ while i < max_elem_length+1:
+ elem_onehots = self._char_to_onehot(
+ temp_elem, onehot_dim=self.elem_num)
+ (outputs, hidden), alpha = self.structure_attention_cell(
+ hidden, fea, elem_onehots)
+ output_hiddens.append(paddle.unsqueeze(outputs, axis=1))
+ structure_probs_step = self.structure_generator(outputs)
+ temp_elem = structure_probs_step.argmax(axis=1, dtype="int32")
+ i += 1
+
+ output = paddle.concat(output_hiddens, axis=1)
+ structure_probs = self.structure_generator(output)
+ structure_probs = F.softmax(structure_probs)
+ if self.loc_type == 1:
+ loc_preds = self.loc_generator(output)
+ loc_preds = F.sigmoid(loc_preds)
+ else:
+ loc_fea = fea.transpose([0, 2, 1])
+ loc_fea = self.loc_fea_trans(loc_fea)
+ loc_fea = loc_fea.transpose([0, 2, 1])
+ loc_concat = paddle.concat([output, loc_fea], axis=2)
+ loc_preds = self.loc_generator(loc_concat)
+ loc_preds = F.sigmoid(loc_preds)
+ return {'structure_probs': structure_probs, 'loc_preds': loc_preds}
+
+
+class AttentionGRUCell(nn.Layer):
+ def __init__(self, input_size, hidden_size, num_embeddings, use_gru=False):
+ super(AttentionGRUCell, self).__init__()
+ self.i2h = nn.Linear(input_size, hidden_size, bias_attr=False)
+ self.h2h = nn.Linear(hidden_size, hidden_size)
+ self.score = nn.Linear(hidden_size, 1, bias_attr=False)
+ self.rnn = nn.GRUCell(
+ input_size=input_size + num_embeddings, hidden_size=hidden_size)
+ self.hidden_size = hidden_size
+
+ def forward(self, prev_hidden, batch_H, char_onehots):
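+ # additive attention: score each encoder position against the previous hidden state, build a context vector, then feed [context, one-hot token] into the GRU cell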
+ batch_H_proj = self.i2h(batch_H)
+ prev_hidden_proj = paddle.unsqueeze(self.h2h(prev_hidden), axis=1)
+ res = paddle.add(batch_H_proj, prev_hidden_proj)
+ res = paddle.tanh(res)
+ e = self.score(res)
+ alpha = F.softmax(e, axis=1)
+ alpha = paddle.transpose(alpha, [0, 2, 1])
+ context = paddle.squeeze(paddle.mm(alpha, batch_H), axis=1)
+ concat_context = paddle.concat([context, char_onehots], 1)
+ cur_hidden = self.rnn(concat_context, prev_hidden)
+ return cur_hidden, alpha
+
+
+class AttentionLSTM(nn.Layer):
+ def __init__(self, in_channels, out_channels, hidden_size, **kwargs):
+ super(AttentionLSTM, self).__init__()
+ self.input_size = in_channels
+ self.hidden_size = hidden_size
+ self.num_classes = out_channels
+
+ self.attention_cell = AttentionLSTMCell(
+ in_channels, hidden_size, out_channels, use_gru=False)
+ self.generator = nn.Linear(hidden_size, out_channels)
+
+ def _char_to_onehot(self, input_char, onehot_dim):
+ input_ont_hot = F.one_hot(input_char, onehot_dim)
+ return input_ont_hot
+
+ def forward(self, inputs, targets=None, batch_max_length=25):
+ batch_size = inputs.shape[0]
+ num_steps = batch_max_length
+
+ hidden = (paddle.zeros((batch_size, self.hidden_size)), paddle.zeros(
+ (batch_size, self.hidden_size)))
+ output_hiddens = []
+
+ if targets is not None:
+ for i in range(num_steps):
+ # one-hot vector for the i-th char
+ char_onehots = self._char_to_onehot(
+ targets[:, i], onehot_dim=self.num_classes)
+ hidden, alpha = self.attention_cell(hidden, inputs,
+ char_onehots)
+
+ hidden = (hidden[1][0], hidden[1][1])
+ output_hiddens.append(paddle.unsqueeze(hidden[0], axis=1))
+ output = paddle.concat(output_hiddens, axis=1)
+ probs = self.generator(output)
+
+ else:
+ targets = paddle.zeros(shape=[batch_size], dtype="int32")
+ probs = None
+
+ for i in range(num_steps):
+ char_onehots = self._char_to_onehot(
+ targets, onehot_dim=self.num_classes)
+ hidden, alpha = self.attention_cell(hidden, inputs,
+ char_onehots)
+ probs_step = self.generator(hidden[0])
+ hidden = (hidden[1][0], hidden[1][1])
+ if probs is None:
+ probs = paddle.unsqueeze(probs_step, axis=1)
+ else:
+ probs = paddle.concat(
+ [probs, paddle.unsqueeze(
+ probs_step, axis=1)], axis=1)
+
+ next_input = probs_step.argmax(axis=1)
+
+ targets = next_input
+
+ return probs
+
+
+class AttentionLSTMCell(nn.Layer):
+ def __init__(self, input_size, hidden_size, num_embeddings, use_gru=False):
+ super(AttentionLSTMCell, self).__init__()
+ self.i2h = nn.Linear(input_size, hidden_size, bias_attr=False)
+ self.h2h = nn.Linear(hidden_size, hidden_size)
+ self.score = nn.Linear(hidden_size, 1, bias_attr=False)
+ if not use_gru:
+ self.rnn = nn.LSTMCell(
+ input_size=input_size + num_embeddings, hidden_size=hidden_size)
+ else:
+ self.rnn = nn.GRUCell(
+ input_size=input_size + num_embeddings, hidden_size=hidden_size)
+
+ self.hidden_size = hidden_size
+
+ def forward(self, prev_hidden, batch_H, char_onehots):
+ batch_H_proj = self.i2h(batch_H)
+ prev_hidden_proj = paddle.unsqueeze(self.h2h(prev_hidden[0]), axis=1)
+ res = paddle.add(batch_H_proj, prev_hidden_proj)
+ res = paddle.tanh(res)
+ e = self.score(res)
+
+ alpha = F.softmax(e, axis=1)
+ alpha = paddle.transpose(alpha, [0, 2, 1])
+ context = paddle.squeeze(paddle.mm(alpha, batch_H), axis=1)
+ concat_context = paddle.concat([context, char_onehots], 1)
+ cur_hidden = self.rnn(concat_context, prev_hidden)
+
+ return cur_hidden, alpha
diff --git a/ppocr/modeling/necks/__init__.py b/ppocr/modeling/necks/__init__.py
index 37a5cf7863cb386884d82ed88c756c9fc06a541d..e97c4f64bdc9acd6729d67a9c6ff7a7563f6c95e 100644
--- a/ppocr/modeling/necks/__init__.py
+++ b/ppocr/modeling/necks/__init__.py
@@ -21,7 +21,8 @@ def build_neck(config):
from .sast_fpn import SASTFPN
from .rnn import SequenceEncoder
from .pg_fpn import PGFPN
- support_dict = ['DBFPN', 'EASTFPN', 'SASTFPN', 'SequenceEncoder', 'PGFPN']
+ from .table_fpn import TableFPN
+ support_dict = ['DBFPN', 'EASTFPN', 'SASTFPN', 'SequenceEncoder', 'PGFPN', 'TableFPN']
module_name = config.pop('name')
assert module_name in support_dict, Exception('neck only support {}'.format(
diff --git a/ppocr/modeling/necks/table_fpn.py b/ppocr/modeling/necks/table_fpn.py
new file mode 100644
index 0000000000000000000000000000000000000000..734f15af65e4e15a7ddb4004954a61bfa1934246
--- /dev/null
+++ b/ppocr/modeling/necks/table_fpn.py
@@ -0,0 +1,110 @@
+# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import paddle
+from paddle import nn
+import paddle.nn.functional as F
+from paddle import ParamAttr
+
+
+class TableFPN(nn.Layer):
+ def __init__(self, in_channels, out_channels, **kwargs):
+ super(TableFPN, self).__init__()
+ self.out_channels = 512
+ weight_attr = paddle.nn.initializer.KaimingUniform()
+ self.in2_conv = nn.Conv2D(
+ in_channels=in_channels[0],
+ out_channels=self.out_channels,
+ kernel_size=1,
+ weight_attr=ParamAttr(initializer=weight_attr),
+ bias_attr=False)
+ self.in3_conv = nn.Conv2D(
+ in_channels=in_channels[1],
+ out_channels=self.out_channels,
+ kernel_size=1,
+ stride=1,
+ weight_attr=ParamAttr(initializer=weight_attr),
+ bias_attr=False)
+ self.in4_conv = nn.Conv2D(
+ in_channels=in_channels[2],
+ out_channels=self.out_channels,
+ kernel_size=1,
+ weight_attr=ParamAttr(initializer=weight_attr),
+ bias_attr=False)
+ self.in5_conv = nn.Conv2D(
+ in_channels=in_channels[3],
+ out_channels=self.out_channels,
+ kernel_size=1,
+ weight_attr=ParamAttr(initializer=weight_attr),
+ bias_attr=False)
+ self.p5_conv = nn.Conv2D(
+ in_channels=self.out_channels,
+ out_channels=self.out_channels // 4,
+ kernel_size=3,
+ padding=1,
+ weight_attr=ParamAttr(initializer=weight_attr),
+ bias_attr=False)
+ self.p4_conv = nn.Conv2D(
+ in_channels=self.out_channels,
+ out_channels=self.out_channels // 4,
+ kernel_size=3,
+ padding=1,
+ weight_attr=ParamAttr(initializer=weight_attr),
+ bias_attr=False)
+ self.p3_conv = nn.Conv2D(
+ in_channels=self.out_channels,
+ out_channels=self.out_channels // 4,
+ kernel_size=3,
+ padding=1,
+ weight_attr=ParamAttr(initializer=weight_attr),
+ bias_attr=False)
+ self.p2_conv = nn.Conv2D(
+ in_channels=self.out_channels,
+ out_channels=self.out_channels // 4,
+ kernel_size=3,
+ padding=1,
+ weight_attr=ParamAttr(initializer=weight_attr),
+ bias_attr=False)
+ self.fuse_conv = nn.Conv2D(
+ in_channels=self.out_channels * 4,
+ out_channels=512,
+ kernel_size=3,
+ padding=1,
+ weight_attr=ParamAttr(initializer=weight_attr), bias_attr=False)
+
+ def forward(self, x):
+ c2, c3, c4, c5 = x
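+ # top-down pathway: upsample and add lateral features, resize every level to the c5 resolution, concatenate and fuse; the 0.005 factor keeps the fused term a small residual added back onto c5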
+
+ in5 = self.in5_conv(c5)
+ in4 = self.in4_conv(c4)
+ in3 = self.in3_conv(c3)
+ in2 = self.in2_conv(c2)
+
+ out4 = in4 + F.upsample(
+ in5, size=in4.shape[2:4], mode="nearest", align_mode=1) # 1/16
+ out3 = in3 + F.upsample(
+ out4, size=in3.shape[2:4], mode="nearest", align_mode=1) # 1/8
+ out2 = in2 + F.upsample(
+ out3, size=in2.shape[2:4], mode="nearest", align_mode=1) # 1/4
+
+ p4 = F.upsample(out4, size=in5.shape[2:4], mode="nearest", align_mode=1)
+ p3 = F.upsample(out3, size=in5.shape[2:4], mode="nearest", align_mode=1)
+ p2 = F.upsample(out2, size=in5.shape[2:4], mode="nearest", align_mode=1)
+ fuse = paddle.concat([in5, p4, p3, p2], axis=1)
+ fuse_conv = self.fuse_conv(fuse) * 0.005
+ return [c5 + fuse_conv]
diff --git a/ppocr/modeling/transforms/tps.py b/ppocr/modeling/transforms/tps.py
index 78338edf67d69e32322912d75dec01ce1e63cb49..dcce6246ac64b4b84229cbd69a4dc53c658b4c7b 100644
--- a/ppocr/modeling/transforms/tps.py
+++ b/ppocr/modeling/transforms/tps.py
@@ -230,15 +230,8 @@ class GridGenerator(nn.Layer):
def build_inv_delta_C_paddle(self, C):
""" Return inv_delta_C which is needed to calculate T """
F = self.F
- hat_C = paddle.zeros((F, F), dtype='float64') # F x F
- for i in range(0, F):
- for j in range(i, F):
- if i == j:
- hat_C[i, j] = 1
- else:
- r = paddle.norm(C[i] - C[j])
- hat_C[i, j] = r
- hat_C[j, i] = r
+ hat_eye = paddle.eye(F, dtype='float64') # F x F
+ hat_C = paddle.norm(C.reshape([1, F, 2]) - C.reshape([F, 1, 2]), axis=2) + hat_eye
hat_C = (hat_C**2) * paddle.log(hat_C)
delta_C = paddle.concat( # F+3 x F+3
[
diff --git a/ppocr/postprocess/__init__.py b/ppocr/postprocess/__init__.py
index f7f1bcd61949471b9358666eea4f9407c30f2f66..f1829e3edd726a63d7381e29ac29587694850e49 100644
--- a/ppocr/postprocess/__init__.py
+++ b/ppocr/postprocess/__init__.py
@@ -21,10 +21,13 @@ import copy
__all__ = ['build_post_process']
-from .db_postprocess import DBPostProcess
+from .db_postprocess import DBPostProcess, DistillationDBPostProcess
from .east_postprocess import EASTPostProcess
from .sast_postprocess import SASTPostProcess
-from .rec_postprocess import CTCLabelDecode, AttnLabelDecode, SRNLabelDecode, DistillationCTCLabelDecode, NRTRLabelDecode
+
+from .rec_postprocess import CTCLabelDecode, AttnLabelDecode, SRNLabelDecode, DistillationCTCLabelDecode, NRTRLabelDecode, \
+ TableLabelDecode
+
from .cls_postprocess import ClsPostProcess
from .pg_postprocess import PGPostProcess
@@ -32,7 +35,7 @@ def build_post_process(config, global_config=None):
support_dict = [
'DBPostProcess', 'EASTPostProcess', 'SASTPostProcess', 'CTCLabelDecode',
'AttnLabelDecode', 'ClsPostProcess', 'SRNLabelDecode', 'PGPostProcess',
- 'DistillationCTCLabelDecode', 'NRTRLabelDecode'
+ 'DistillationCTCLabelDecode', 'NRTRLabelDecode', 'TableLabelDecode', 'DistillationDBPostProcess'
]
config = copy.deepcopy(config)
diff --git a/ppocr/postprocess/db_postprocess.py b/ppocr/postprocess/db_postprocess.py
index 769ddbe23253ce58e2bccd46ef5074cc2a7d27da..d9c9869dfcd35cb9b491db826f3bff5f766723f4 100755
--- a/ppocr/postprocess/db_postprocess.py
+++ b/ppocr/postprocess/db_postprocess.py
@@ -187,3 +187,29 @@ class DBPostProcess(object):
boxes_batch.append({'points': boxes})
return boxes_batch
+
+
+class DistillationDBPostProcess(object):
+ def __init__(self, model_name=["student"],
+ key=None,
+ thresh=0.3,
+ box_thresh=0.6,
+ max_candidates=1000,
+ unclip_ratio=1.5,
+ use_dilation=False,
+ score_mode="fast",
+ **kwargs):
+ self.model_name = model_name
+ self.key = key
+ self.post_process = DBPostProcess(thresh=thresh,
+ box_thresh=box_thresh,
+ max_candidates=max_candidates,
+ unclip_ratio=unclip_ratio,
+ use_dilation=use_dilation,
+ score_mode=score_mode)
+
+ def __call__(self, predicts, shape_list):
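+ # run standard DB post-processing on each named sub-model output (e.g. the student branch) and return the boxes keyed by model name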
+ results = {}
+ for k in self.model_name:
+ results[k] = self.post_process(predicts[k], shape_list=shape_list)
+ return results
diff --git a/ppocr/postprocess/rec_postprocess.py b/ppocr/postprocess/rec_postprocess.py
index 371e238664c2d8de8ffbc0958785a84b64d9cf5a..9f23b5495f63a41283656ceaf9df76f96b8d1592 100644
--- a/ppocr/postprocess/rec_postprocess.py
+++ b/ppocr/postprocess/rec_postprocess.py
@@ -44,16 +44,16 @@ class BaseRecLabelDecode(object):
self.character_str = string.printable[:-6]
dict_character = list(self.character_str)
elif character_type in support_character_type:
- self.character_str = ""
+ self.character_str = []
assert character_dict_path is not None, "character_dict_path should not be None when character_type is {}".format(
character_type)
with open(character_dict_path, "rb") as fin:
lines = fin.readlines()
for line in lines:
line = line.decode('utf-8').strip("\n").strip("\r\n")
- self.character_str += line
+ self.character_str.append(line)
if use_space_char:
- self.character_str += " "
+ self.character_str.append(" ")
dict_character = list(self.character_str)
else:
@@ -381,3 +381,138 @@ class SRNLabelDecode(BaseRecLabelDecode):
assert False, "unsupport type %s in get_beg_end_flag_idx" \
% beg_or_end
return idx
+
+
+class TableLabelDecode(object):
+ """ """
+
+ def __init__(self,
+ character_dict_path,
+ **kwargs):
+ list_character, list_elem = self.load_char_elem_dict(character_dict_path)
+ list_character = self.add_special_char(list_character)
+ list_elem = self.add_special_char(list_elem)
+ self.dict_character = {}
+ self.dict_idx_character = {}
+ for i, char in enumerate(list_character):
+ self.dict_idx_character[i] = char
+ self.dict_character[char] = i
+ self.dict_elem = {}
+ self.dict_idx_elem = {}
+ for i, elem in enumerate(list_elem):
+ self.dict_idx_elem[i] = elem
+ self.dict_elem[elem] = i
+
+ def load_char_elem_dict(self, character_dict_path):
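+ # the dict file starts with a header whose first two tab-separated fields are the character count and the element count, followed by that many character lines and structure-token lines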
+ list_character = []
+ list_elem = []
+ with open(character_dict_path, "rb") as fin:
+ lines = fin.readlines()
+ substr = lines[0].decode('utf-8').strip("\n").strip("\r\n").split("\t")
+ character_num = int(substr[0])
+ elem_num = int(substr[1])
+ for cno in range(1, 1 + character_num):
+ character = lines[cno].decode('utf-8').strip("\n").strip("\r\n")
+ list_character.append(character)
+ for eno in range(1 + character_num, 1 + character_num + elem_num):
+ elem = lines[eno].decode('utf-8').strip("\n").strip("\r\n")
+ list_elem.append(elem)
+ return list_character, list_elem
+
+ def add_special_char(self, list_character):
+ self.beg_str = "sos"
+ self.end_str = "eos"
+ list_character = [self.beg_str] + list_character + [self.end_str]
+ return list_character
+
+ def __call__(self, preds):
+ structure_probs = preds['structure_probs']
+ loc_preds = preds['loc_preds']
+ if isinstance(structure_probs, paddle.Tensor):
+ structure_probs = structure_probs.numpy()
+ if isinstance(loc_preds, paddle.Tensor):
+ loc_preds = loc_preds.numpy()
+ structure_idx = structure_probs.argmax(axis=2)
+ structure_probs = structure_probs.max(axis=2)
+ structure_str, structure_pos, result_score_list, result_elem_idx_list = self.decode(structure_idx,
+ structure_probs, 'elem')
+ res_html_code_list = []
+ res_loc_list = []
+ batch_num = len(structure_str)
+ for bno in range(batch_num):
+ res_loc = []
+ for sno in range(len(structure_str[bno])):
+ text = structure_str[bno][sno]
+ if text in ['<td>', '<td']:
+ pos = structure_pos[bno][sno]
+ res_loc.append(loc_preds[bno, pos])
+ res_html_code = ''.join(structure_str[bno])
+ res_loc = np.array(res_loc)
+ res_html_code_list.append(res_html_code)
+ res_loc_list.append(res_loc)
+ return {'res_html_code': res_html_code_list, 'res_loc': res_loc_list,
+ 'res_score_list': result_score_list,
+ 'res_elem_idx_list': result_elem_idx_list,
+ 'structure_str_list': structure_str}
+
+ def decode(self, text_index, structure_probs, char_or_elem):
+ """convert text-index into text-label."""
+ if char_or_elem == "char":
+ current_dict = self.dict_idx_character
+ else:
+ current_dict = self.dict_idx_elem
+ ignored_tokens = self.get_ignored_tokens('elem')
+ beg_idx, end_idx = ignored_tokens
+
+ result_list = []
+ result_pos_list = []
+ result_score_list = []
+ result_elem_idx_list = []
+ batch_size = len(text_index)
+ for batch_idx in range(batch_size):
+ char_list = []
+ elem_pos_list = []
+ elem_idx_list = []
+ score_list = []
+ for idx in range(len(text_index[batch_idx])):
+ tmp_elem_idx = int(text_index[batch_idx][idx])
+ if idx > 0 and tmp_elem_idx == end_idx:
+ break
+ if tmp_elem_idx in ignored_tokens:
+ continue
+
+ char_list.append(current_dict[tmp_elem_idx])
+ elem_pos_list.append(idx)
+ score_list.append(structure_probs[batch_idx, idx])
+ elem_idx_list.append(tmp_elem_idx)
+ result_list.append(char_list)
+ result_pos_list.append(elem_pos_list)
+ result_score_list.append(score_list)
+ result_elem_idx_list.append(elem_idx_list)
+ return result_list, result_pos_list, result_score_list, result_elem_idx_list
+
+ def get_ignored_tokens(self, char_or_elem):
+ beg_idx = self.get_beg_end_flag_idx("beg", char_or_elem)
+ end_idx = self.get_beg_end_flag_idx("end", char_or_elem)
+ return [beg_idx, end_idx]
+
+ def get_beg_end_flag_idx(self, beg_or_end, char_or_elem):
+ if char_or_elem == "char":
+ if beg_or_end == "beg":
+ idx = self.dict_character[self.beg_str]
+ elif beg_or_end == "end":
+ idx = self.dict_character[self.end_str]
+ else:
+ assert False, "Unsupport type %s in get_beg_end_flag_idx of char" \
+ % beg_or_end
+ elif char_or_elem == "elem":
+ if beg_or_end == "beg":
+ idx = self.dict_elem[self.beg_str]
+ elif beg_or_end == "end":
+ idx = self.dict_elem[self.end_str]
+ else:
+ assert False, "Unsupport type %s in get_beg_end_flag_idx of elem" \
+ % beg_or_end
+ else:
+ assert False, "Unsupport type %s in char_or_elem" \
+ % char_or_elem
+ return idx
diff --git a/ppocr/utils/dict/table_dict.txt b/ppocr/utils/dict/table_dict.txt
new file mode 100644
index 0000000000000000000000000000000000000000..2ef028c786cbce6d1e25856c62986d757b31f93b
--- /dev/null
+++ b/ppocr/utils/dict/table_dict.txt
@@ -0,0 +1,277 @@
+←
+
+☆
+─
+α
+
+
+⋅
+$
+ω
+ψ
+χ
+(
+υ
+≥
+σ
+,
+ρ
+ε
+0
+■
+4
+8
+✗
+b
+<
+✓
+Ψ
+Ω
+€
+D
+3
+Π
+H
+║
+
+L
+Φ
+Χ
+θ
+P
+κ
+λ
+μ
+T
+ξ
+X
+β
+γ
+δ
+\
+ζ
+η
+`
+d
+
+h
+f
+l
+Θ
+p
+√
+t
+
+x
+Β
+Γ
+Δ
+|
+ǂ
+ɛ
+j
+̧
+➢
+
+̌
+′
+«
+△
+▲
+#
+
+'
+Ι
++
+¶
+/
+▼
+⇑
+□
+·
+7
+▪
+;
+?
+➔
+∩
+C
+÷
+G
+⇒
+K
+
+O
+S
+С
+W
+Α
+[
+○
+_
+●
+‡
+c
+z
+g
+
+o
+
+〈
+〉
+s
+⩽
+w
+φ
+ʹ
+{
+»
+∣
+̆
+e
+ˆ
+∈
+τ
+◆
+ι
+∅
+∆
+∙
+∘
+Ø
+ß
+✔
+∞
+∑
+−
+×
+◊
+∗
+∖
+˃
+˂
+∫
+"
+i
+&
+π
+↔
+*
+∥
+æ
+∧
+.
+⁄
+ø
+Q
+∼
+6
+⁎
+:
+★
+>
+a
+B
+≈
+F
+J
+̄
+N
+♯
+R
+V
+
+―
+Z
+♣
+^
+¤
+¥
+§
+
+¢
+£
+≦
+
+≤
+‖
+Λ
+©
+n
+↓
+→
+↑
+r
+°
+±
+v
+
+♂
+k
+♀
+~
+ᅟ
+̇
+@
+”
+♦
+ł
+®
+⊕
+„
+!
+
+%
+⇓
+)
+-
+1
+5
+9
+=
+А
+A
+‰
+⋆
+Σ
+E
+◦
+I
+※
+M
+m
+̨
+⩾
+†
+
+•
+U
+Y
+
+]
+̸
+2
+‐
+–
+‒
+̂
+—
+̀
+́
+’
+‘
+⋮
+⋯
+̊
+“
+̈
+≧
+q
+u
+ı
+y
+
+
+̃
+}
+ν
diff --git a/ppocr/utils/dict/table_structure_dict.txt b/ppocr/utils/dict/table_structure_dict.txt
new file mode 100644
index 0000000000000000000000000000000000000000..9c4531e5f3b8c498e70d3c2ea0471e5e746a2c30
--- /dev/null
+++ b/ppocr/utils/dict/table_structure_dict.txt
@@ -0,0 +1,2759 @@
+277 28 1267 1186
+
+V
+a
+r
+i
+b
+l
+e
+
+H
+z
+d
+
+t
+o
+9
+5
+%
+C
+I
+
+p
+
+v
+u
+*
+A
+g
+(
+m
+n
+)
+0
+.
+7
+1
+6
+≤
+>
+8
+3
+–
+2
+G
+4
+M
+F
+T
+y
+f
+s
+L
+w
+c
+U
+h
+D
+S
+Q
+R
+x
+P
+-
+E
+O
+/
+k
+,
++
+N
+K
+q
+′
+[
+]
+<
+≥
+
+−
+
+μ
+±
+J
+j
+W
+_
+Δ
+B
+“
+:
+Y
+α
+λ
+;
+
+
+?
+∼
+=
+°
+#
+̊
+̈
+̂
+’
+Z
+X
+∗
+—
+β
+'
+†
+~
+@
+"
+γ
+↓
+↑
+&
+‡
+χ
+”
+σ
+§
+|
+¶
+‐
+×
+$
+→
+√
+✓
+‘
+\
+∞
+π
+•
+®
+^
+∆
+≧
+
+
+́
+♀
+♂
+‒
+⁎
+▲
+·
+£
+φ
+Ψ
+ß
+△
+☆
+▪
+η
+€
+∧
+̃
+Φ
+ρ
+̄
+δ
+‰
+̧
+Ω
+♦
+{
+}
+̀
+∑
+∫
+ø
+κ
+ε
+¥
+※
+`
+ω
+Σ
+➔
+‖
+Β
+̸
+
+─
+●
+⩾
+Χ
+Α
+⋅
+◆
+★
+■
+ψ
+ǂ
+□
+ζ
+!
+Γ
+↔
+θ
+⁄
+〈
+〉
+―
+υ
+τ
+⋆
+Ø
+©
+∥
+С
+˂
+➢
+ɛ
+
+✗
+←
+○
+¢
+⩽
+∖
+˃
+
+≈
+Π
+̌
+≦
+∅
+ᅟ
+
+
+∣
+¤
+♯
+̆
+ξ
+÷
+▼
+
+ι
+ν
+║
+
+
+◦
+
+◊
+∙
+«
+»
+ł
+ı
+Θ
+∈
+„
+∘
+✔
+̇
+æ
+ʹ
+ˆ
+♣
+⇓
+∩
+⊕
+⇒
+⇑
+̨
+Ι
+Λ
+⋯
+А
+⋮
+
+
+
+ |
+
+
+
+
+
+ colspan="2"
+ colspan="3"
+ rowspan="2"
+ colspan="4"
+ colspan="6"
+ rowspan="3"
+ colspan="9"
+ colspan="10"
+ colspan="7"
+ rowspan="4"
+ rowspan="5"
+ rowspan="9"
+ colspan="8"
+ rowspan="8"
+ rowspan="6"
+ rowspan="7"
+ rowspan="10"
+0 2924682
+1 3405345
+2 2363468
+3 2709165
+4 4078680
+5 3250792
+6 1923159
+7 1617890
+8 1450532
+9 1717624
+10 1477550
+11 1489223
+12 915528
+13 819193
+14 593660
+15 518924
+16 682065
+17 494584
+18 400591
+19 396421
+20 340994
+21 280688
+22 250328
+23 226786
+24 199927
+25 182707
+26 164629
+27 141613
+28 127554
+29 116286
+30 107682
+31 96367
+32 88002
+33 79234
+34 72186
+35 65921
+36 60374
+37 55976
+38 52166
+39 47414
+40 44932
+41 41279
+42 38232
+43 35463
+44 33703
+45 30557
+46 29639
+47 27000
+48 25447
+49 23186
+50 22093
+51 20412
+52 19844
+53 18261
+54 17561
+55 16499
+56 15597
+57 14558
+58 14372
+59 13445
+60 13514
+61 12058
+62 11145
+63 10767
+64 10370
+65 9630
+66 9337
+67 8881
+68 8727
+69 8060
+70 7994
+71 7740
+72 7189
+73 6729
+74 6749
+75 6548
+76 6321
+77 5957
+78 5740
+79 5407
+80 5370
+81 5035
+82 4921
+83 4656
+84 4600
+85 4519
+86 4277
+87 4023
+88 3939
+89 3910
+90 3861
+91 3560
+92 3483
+93 3406
+94 3346
+95 3229
+96 3122
+97 3086
+98 3001
+99 2884
+100 2822
+101 2677
+102 2670
+103 2610
+104 2452
+105 2446
+106 2400
+107 2300
+108 2316
+109 2196
+110 2089
+111 2083
+112 2041
+113 1881
+114 1838
+115 1896
+116 1795
+117 1786
+118 1743
+119 1765
+120 1750
+121 1683
+122 1563
+123 1499
+124 1513
+125 1462
+126 1388
+127 1441
+128 1417
+129 1392
+130 1306
+131 1321
+132 1274
+133 1294
+134 1240
+135 1126
+136 1157
+137 1130
+138 1084
+139 1130
+140 1083
+141 1040
+142 980
+143 1031
+144 974
+145 980
+146 932
+147 898
+148 960
+149 907
+150 852
+151 912
+152 859
+153 847
+154 876
+155 792
+156 791
+157 765
+158 788
+159 787
+160 744
+161 673
+162 683
+163 697
+164 666
+165 680
+166 632
+167 677
+168 657
+169 618
+170 587
+171 585
+172 567
+173 549
+174 562
+175 548
+176 542
+177 539
+178 542
+179 549
+180 547
+181 526
+182 525
+183 514
+184 512
+185 505
+186 515
+187 467
+188 475
+189 458
+190 435
+191 443
+192 427
+193 424
+194 404
+195 389
+196 429
+197 404
+198 386
+199 351
+200 388
+201 408
+202 361
+203 346
+204 324
+205 361
+206 363
+207 364
+208 323
+209 336
+210 342
+211 315
+212 325
+213 328
+214 314
+215 327
+216 320
+217 300
+218 295
+219 315
+220 310
+221 295
+222 275
+223 248
+224 274
+225 232
+226 293
+227 259
+228 286
+229 263
+230 242
+231 214
+232 261
+233 231
+234 211
+235 250
+236 233
+237 206
+238 224
+239 210
+240 233
+241 223
+242 216
+243 222
+244 207
+245 212
+246 196
+247 205
+248 201
+249 202
+250 211
+251 201
+252 215
+253 179
+254 163
+255 179
+256 191
+257 188
+258 196
+259 150
+260 154
+261 176
+262 211
+263 166
+264 171
+265 165
+266 149
+267 182
+268 159
+269 161
+270 164
+271 161
+272 141
+273 151
+274 127
+275 129
+276 142
+277 158
+278 148
+279 135
+280 127
+281 134
+282 138
+283 131
+284 126
+285 125
+286 130
+287 126
+288 135
+289 125
+290 135
+291 131
+292 95
+293 135
+294 106
+295 117
+296 136
+297 128
+298 128
+299 118
+300 109
+301 112
+302 117
+303 108
+304 120
+305 100
+306 95
+307 108
+308 112
+309 77
+310 120
+311 104
+312 109
+313 89
+314 98
+315 82
+316 98
+317 93
+318 77
+319 93
+320 77
+321 98
+322 93
+323 86
+324 89
+325 73
+326 70
+327 71
+328 77
+329 87
+330 77
+331 93
+332 100
+333 83
+334 72
+335 74
+336 69
+337 77
+338 68
+339 78
+340 90
+341 98
+342 75
+343 80
+344 63
+345 71
+346 83
+347 66
+348 71
+349 70
+350 62
+351 62
+352 59
+353 63
+354 62
+355 52
+356 64
+357 64
+358 56
+359 49
+360 57
+361 63
+362 60
+363 68
+364 62
+365 55
+366 54
+367 40
+368 75
+369 70
+370 53
+371 58
+372 57
+373 55
+374 69
+375 57
+376 53
+377 43
+378 45
+379 47
+380 56
+381 51
+382 59
+383 51
+384 43
+385 34
+386 57
+387 49
+388 39
+389 46
+390 48
+391 43
+392 40
+393 54
+394 50
+395 41
+396 43
+397 33
+398 27
+399 49
+400 44
+401 44
+402 38
+403 30
+404 32
+405 37
+406 39
+407 42
+408 53
+409 39
+410 34
+411 31
+412 32
+413 52
+414 27
+415 41
+416 34
+417 36
+418 50
+419 35
+420 32
+421 33
+422 45
+423 35
+424 40
+425 29
+426 41
+427 40
+428 39
+429 32
+430 31
+431 34
+432 29
+433 27
+434 26
+435 22
+436 34
+437 28
+438 30
+439 38
+440 35
+441 36
+442 36
+443 27
+444 24
+445 33
+446 31
+447 25
+448 33
+449 27
+450 32
+451 46
+452 31
+453 35
+454 35
+455 34
+456 26
+457 21
+458 25
+459 26
+460 24
+461 27
+462 33
+463 30
+464 35
+465 21
+466 32
+467 19
+468 27
+469 16
+470 28
+471 26
+472 27
+473 26
+474 25
+475 25
+476 27
+477 20
+478 28
+479 22
+480 23
+481 16
+482 25
+483 27
+484 19
+485 23
+486 19
+487 15
+488 15
+489 23
+490 24
+491 19
+492 20
+493 18
+494 17
+495 30
+496 28
+497 20
+498 29
+499 17
+500 19
+501 21
+502 15
+503 24
+504 15
+505 19
+506 25
+507 16
+508 23
+509 26
+510 21
+511 15
+512 12
+513 16
+514 18
+515 24
+516 26
+517 18
+518 8
+519 25
+520 14
+521 8
+522 24
+523 20
+524 18
+525 15
+526 13
+527 17
+528 18
+529 22
+530 21
+531 9
+532 16
+533 17
+534 13
+535 17
+536 15
+537 13
+538 20
+539 13
+540 19
+541 29
+542 10
+543 8
+544 18
+545 13
+546 9
+547 18
+548 10
+549 18
+550 18
+551 9
+552 9
+553 15
+554 13
+555 15
+556 14
+557 14
+558 18
+559 8
+560 13
+561 9
+562 7
+563 12
+564 6
+565 9
+566 9
+567 18
+568 9
+569 10
+570 13
+571 14
+572 13
+573 21
+574 8
+575 16
+576 12
+577 9
+578 16
+579 17
+580 22
+581 6
+582 14
+583 13
+584 15
+585 11
+586 13
+587 5
+588 12
+589 13
+590 15
+591 13
+592 15
+593 12
+594 7
+595 18
+596 12
+597 13
+598 13
+599 13
+600 12
+601 12
+602 10
+603 11
+604 6
+605 6
+606 2
+607 9
+608 8
+609 12
+610 9
+611 12
+612 13
+613 12
+614 14
+615 9
+616 8
+617 9
+618 14
+619 13
+620 12
+621 6
+622 8
+623 8
+624 8
+625 12
+626 8
+627 7
+628 5
+629 8
+630 12
+631 6
+632 10
+633 10
+634 7
+635 8
+636 9
+637 6
+638 9
+639 4
+640 12
+641 4
+642 3
+643 11
+644 10
+645 6
+646 12
+647 12
+648 4
+649 4
+650 9
+651 8
+652 6
+653 5
+654 14
+655 10
+656 11
+657 8
+658 5
+659 5
+660 9
+661 13
+662 4
+663 5
+664 9
+665 11
+666 12
+667 7
+668 13
+669 2
+670 1
+671 7
+672 7
+673 7
+674 10
+675 9
+676 6
+677 5
+678 7
+679 6
+680 3
+681 3
+682 4
+683 9
+684 8
+685 5
+686 3
+687 11
+688 9
+689 2
+690 6
+691 5
+692 9
+693 5
+694 6
+695 5
+696 9
+697 8
+698 3
+699 7
+700 5
+701 9
+702 8
+703 7
+704 2
+705 3
+706 7
+707 6
+708 6
+709 10
+710 2
+711 10
+712 6
+713 7
+714 5
+715 6
+716 4
+717 6
+718 8
+719 4
+720 6
+721 7
+722 5
+723 7
+724 3
+725 10
+726 10
+727 3
+728 7
+729 7
+730 5
+731 2
+732 1
+733 5
+734 1
+735 5
+736 6
+737 2
+738 2
+739 3
+740 7
+741 2
+742 7
+743 4
+744 5
+745 4
+746 5
+747 3
+748 1
+749 4
+750 4
+751 2
+752 4
+753 6
+754 6
+755 6
+756 3
+757 2
+758 5
+759 5
+760 3
+761 4
+762 2
+763 1
+764 8
+765 3
+766 4
+767 3
+768 1
+769 5
+770 3
+771 3
+772 4
+773 4
+774 1
+775 3
+776 2
+777 2
+778 3
+779 3
+780 1
+781 4
+782 3
+783 4
+784 6
+785 3
+786 5
+787 4
+788 2
+789 4
+790 5
+791 4
+792 6
+794 4
+795 1
+796 1
+797 4
+798 2
+799 3
+800 3
+801 1
+802 5
+803 5
+804 3
+805 3
+806 3
+807 4
+808 4
+809 2
+811 5
+812 4
+813 6
+814 3
+815 2
+816 2
+817 3
+818 5
+819 3
+820 1
+821 1
+822 4
+823 3
+824 4
+825 8
+826 3
+827 5
+828 5
+829 3
+830 6
+831 3
+832 4
+833 8
+834 5
+835 3
+836 3
+837 2
+838 4
+839 2
+840 1
+841 3
+842 2
+843 1
+844 3
+846 4
+847 4
+848 3
+849 3
+850 2
+851 3
+853 1
+854 4
+855 4
+856 2
+857 4
+858 1
+859 2
+860 5
+861 1
+862 1
+863 4
+864 2
+865 2
+867 5
+868 1
+869 4
+870 1
+871 1
+872 1
+873 2
+875 5
+876 3
+877 1
+878 3
+879 3
+880 3
+881 2
+882 1
+883 6
+884 2
+885 2
+886 1
+887 1
+888 3
+889 2
+890 2
+891 3
+892 1
+893 3
+894 1
+895 5
+896 1
+897 3
+899 2
+900 2
+902 1
+903 2
+904 4
+905 4
+906 3
+907 1
+908 1
+909 2
+910 5
+911 2
+912 3
+914 1
+915 1
+916 2
+918 2
+919 2
+920 4
+921 4
+922 1
+923 1
+924 4
+925 5
+926 1
+928 2
+929 1
+930 1
+931 1
+932 1
+933 1
+934 2
+935 1
+936 1
+937 1
+938 2
+939 1
+941 1
+942 4
+944 2
+945 2
+946 2
+947 1
+948 1
+950 1
+951 2
+953 1
+954 2
+955 1
+956 1
+957 2
+958 1
+960 3
+962 4
+963 1
+964 1
+965 3
+966 2
+967 2
+968 1
+969 3
+970 3
+972 1
+974 4
+975 3
+976 3
+977 2
+979 2
+980 1
+981 1
+983 5
+984 1
+985 3
+986 1
+987 2
+988 4
+989 2
+991 2
+992 2
+993 1
+994 1
+996 2
+997 2
+998 1
+999 3
+1000 2
+1001 1
+1002 3
+1003 3
+1004 2
+1005 3
+1006 1
+1007 2
+1009 1
+1011 1
+1013 3
+1014 1
+1016 2
+1017 1
+1018 1
+1019 1
+1020 4
+1021 1
+1022 2
+1025 1
+1026 1
+1027 2
+1028 1
+1030 1
+1031 2
+1032 4
+1034 3
+1035 2
+1036 1
+1038 1
+1039 1
+1040 1
+1041 1
+1042 2
+1043 1
+1044 2
+1045 4
+1048 1
+1050 1
+1051 1
+1052 2
+1054 1
+1055 3
+1056 2
+1057 1
+1059 1
+1061 2
+1063 1
+1064 1
+1065 1
+1066 1
+1067 1
+1068 1
+1069 2
+1074 1
+1075 1
+1077 1
+1078 1
+1079 1
+1082 1
+1085 1
+1088 1
+1090 1
+1091 1
+1092 2
+1094 2
+1097 2
+1098 1
+1099 2
+1101 2
+1102 1
+1104 1
+1105 1
+1107 1
+1109 1
+1111 2
+1112 1
+1114 2
+1115 2
+1116 2
+1117 1
+1118 1
+1119 1
+1120 1
+1122 1
+1123 1
+1127 1
+1128 3
+1132 2
+1138 3
+1142 1
+1145 4
+1150 1
+1153 2
+1154 1
+1158 1
+1159 1
+1163 1
+1165 1
+1169 2
+1174 1
+1176 1
+1177 1
+1178 2
+1179 1
+1180 2
+1181 1
+1182 1
+1183 2
+1185 1
+1187 1
+1191 2
+1193 1
+1195 3
+1196 1
+1201 3
+1203 1
+1206 1
+1210 1
+1213 1
+1214 1
+1215 2
+1218 1
+1220 1
+1221 1
+1225 1
+1226 1
+1233 2
+1241 1
+1243 1
+1249 1
+1250 2
+1251 1
+1254 1
+1255 2
+1260 1
+1268 1
+1270 1
+1273 1
+1274 1
+1277 1
+1284 1
+1287 1
+1291 1
+1292 2
+1294 1
+1295 2
+1297 1
+1298 1
+1301 1
+1307 1
+1308 3
+1311 2
+1313 1
+1316 1
+1321 1
+1324 1
+1325 1
+1330 1
+1333 1
+1334 1
+1338 2
+1340 1
+1341 1
+1342 1
+1343 1
+1345 1
+1355 1
+1357 1
+1360 2
+1375 1
+1376 1
+1380 1
+1383 1
+1387 1
+1389 1
+1393 1
+1394 1
+1396 1
+1398 1
+1410 1
+1414 1
+1419 1
+1425 1
+1434 1
+1435 1
+1438 1
+1439 1
+1447 1
+1455 2
+1460 1
+1461 1
+1463 1
+1466 1
+1470 1
+1473 1
+1478 1
+1480 1
+1483 1
+1484 1
+1485 2
+1492 2
+1499 1
+1509 1
+1512 1
+1513 1
+1523 1
+1524 1
+1525 2
+1529 1
+1539 1
+1544 1
+1568 1
+1584 1
+1591 1
+1598 1
+1600 1
+1604 1
+1614 1
+1617 1
+1621 1
+1622 1
+1626 1
+1638 1
+1648 1
+1658 1
+1661 1
+1679 1
+1682 1
+1693 1
+1700 1
+1705 1
+1707 1
+1722 1
+1728 1
+1758 1
+1762 1
+1763 1
+1775 1
+1776 1
+1801 1
+1810 1
+1812 1
+1827 1
+1834 1
+1846 1
+1847 1
+1848 1
+1851 1
+1862 1
+1866 1
+1877 2
+1884 1
+1888 1
+1903 1
+1912 1
+1925 1
+1938 1
+1955 1
+1998 1
+2054 1
+2058 1
+2065 1
+2069 1
+2076 1
+2089 1
+2104 1
+2111 1
+2133 1
+2138 1
+2156 1
+2204 1
+2212 1
+2237 1
+2246 2
+2298 1
+2304 1
+2360 1
+2400 1
+2481 1
+2544 1
+2586 1
+2622 1
+2666 1
+2682 1
+2725 1
+2920 1
+3997 1
+4019 1
+5211 1
+12 19
+14 1
+16 401
+18 2
+20 421
+22 557
+24 625
+26 50
+28 4481
+30 52
+32 550
+34 5840
+36 4644
+38 87
+40 5794
+41 33
+42 571
+44 11805
+46 4711
+47 7
+48 597
+49 12
+50 678
+51 2
+52 14715
+53 3
+54 7322
+55 3
+56 508
+57 39
+58 3486
+59 11
+60 8974
+61 45
+62 1276
+63 4
+64 15693
+65 15
+66 657
+67 13
+68 6409
+69 10
+70 3188
+71 25
+72 1889
+73 27
+74 10370
+75 9
+76 12432
+77 23
+78 520
+79 15
+80 1534
+81 29
+82 2944
+83 23
+84 12071
+85 36
+86 1502
+87 10
+88 10978
+89 11
+90 889
+91 16
+92 4571
+93 17
+94 7855
+95 21
+96 2271
+97 33
+98 1423
+99 15
+100 11096
+101 21
+102 4082
+103 13
+104 5442
+105 25
+106 2113
+107 26
+108 3779
+109 43
+110 1294
+111 29
+112 7860
+113 29
+114 4965
+115 22
+116 7898
+117 25
+118 1772
+119 28
+120 1149
+121 38
+122 1483
+123 32
+124 10572
+125 25
+126 1147
+127 31
+128 1699
+129 22
+130 5533
+131 22
+132 4669
+133 34
+134 3777
+135 10
+136 5412
+137 21
+138 855
+139 26
+140 2485
+141 46
+142 1970
+143 27
+144 6565
+145 40
+146 933
+147 15
+148 7923
+149 16
+150 735
+151 23
+152 1111
+153 33
+154 3714
+155 27
+156 2445
+157 30
+158 3367
+159 10
+160 4646
+161 27
+162 990
+163 23
+164 5679
+165 25
+166 2186
+167 17
+168 899
+169 32
+170 1034
+171 22
+172 6185
+173 32
+174 2685
+175 17
+176 1354
+177 38
+178 1460
+179 15
+180 3478
+181 20
+182 958
+183 20
+184 6055
+185 23
+186 2180
+187 15
+188 1416
+189 30
+190 1284
+191 22
+192 1341
+193 21
+194 2413
+195 18
+196 4984
+197 13
+198 830
+199 22
+200 1834
+201 19
+202 2238
+203 9
+204 3050
+205 22
+206 616
+207 17
+208 2892
+209 22
+210 711
+211 30
+212 2631
+213 19
+214 3341
+215 21
+216 987
+217 26
+218 823
+219 9
+220 3588
+221 20
+222 692
+223 7
+224 2925
+225 31
+226 1075
+227 16
+228 2909
+229 18
+230 673
+231 20
+232 2215
+233 14
+234 1584
+235 21
+236 1292
+237 29
+238 1647
+239 25
+240 1014
+241 30
+242 1648
+243 19
+244 4465
+245 10
+246 787
+247 11
+248 480
+249 25
+250 842
+251 15
+252 1219
+253 23
+254 1508
+255 8
+256 3525
+257 16
+258 490
+259 12
+260 1678
+261 14
+262 822
+263 16
+264 1729
+265 28
+266 604
+267 11
+268 2572
+269 7
+270 1242
+271 15
+272 725
+273 18
+274 1983
+275 13
+276 1662
+277 19
+278 491
+279 12
+280 1586
+281 14
+282 563
+283 10
+284 2363
+285 10
+286 656
+287 14
+288 725
+289 28
+290 871
+291 9
+292 2606
+293 12
+294 961
+295 9
+296 478
+297 13
+298 1252
+299 10
+300 736
+301 19
+302 466
+303 13
+304 2254
+305 12
+306 486
+307 14
+308 1145
+309 13
+310 955
+311 13
+312 1235
+313 13
+314 931
+315 14
+316 1768
+317 11
+318 330
+319 10
+320 539
+321 23
+322 570
+323 12
+324 1789
+325 13
+326 884
+327 5
+328 1422
+329 14
+330 317
+331 11
+332 509
+333 13
+334 1062
+335 12
+336 577
+337 27
+338 378
+339 10
+340 2313
+341 9
+342 391
+343 13
+344 894
+345 17
+346 664
+347 9
+348 453
+349 6
+350 363
+351 15
+352 1115
+353 13
+354 1054
+355 8
+356 1108
+357 12
+358 354
+359 7
+360 363
+361 16
+362 344
+363 11
+364 1734
+365 12
+366 265
+367 10
+368 969
+369 16
+370 316
+371 12
+372 757
+373 7
+374 563
+375 15
+376 857
+377 9
+378 469
+379 9
+380 385
+381 12
+382 921
+383 15
+384 764
+385 14
+386 246
+387 6
+388 1108
+389 14
+390 230
+391 8
+392 266
+393 11
+394 641
+395 8
+396 719
+397 9
+398 243
+399 4
+400 1108
+401 7
+402 229
+403 7
+404 903
+405 7
+406 257
+407 12
+408 244
+409 3
+410 541
+411 6
+412 744
+413 8
+414 419
+415 8
+416 388
+417 19
+418 470
+419 14
+420 612
+421 6
+422 342
+423 3
+424 1179
+425 3
+426 116
+427 14
+428 207
+429 6
+430 255
+431 4
+432 288
+433 12
+434 343
+435 6
+436 1015
+437 3
+438 538
+439 10
+440 194
+441 6
+442 188
+443 15
+444 524
+445 7
+446 214
+447 7
+448 574
+449 6
+450 214
+451 5
+452 635
+453 9
+454 464
+455 5
+456 205
+457 9
+458 163
+459 2
+460 558
+461 4
+462 171
+463 14
+464 444
+465 11
+466 543
+467 5
+468 388
+469 6
+470 141
+471 4
+472 647
+473 3
+474 210
+475 4
+476 193
+477 7
+478 195
+479 7
+480 443
+481 10
+482 198
+483 3
+484 816
+485 6
+486 128
+487 9
+488 215
+489 9
+490 328
+491 7
+492 158
+493 11
+494 335
+495 8
+496 435
+497 6
+498 174
+499 1
+500 373
+501 5
+502 140
+503 7
+504 330
+505 9
+506 149
+507 5
+508 642
+509 3
+510 179
+511 3
+512 159
+513 8
+514 204
+515 7
+516 306
+517 4
+518 110
+519 5
+520 326
+521 6
+522 305
+523 6
+524 294
+525 7
+526 268
+527 5
+528 149
+529 4
+530 133
+531 2
+532 513
+533 10
+534 116
+535 5
+536 258
+537 4
+538 113
+539 4
+540 138
+541 6
+542 116
+544 485
+545 4
+546 93
+547 9
+548 299
+549 3
+550 256
+551 6
+552 92
+553 3
+554 175
+555 6
+556 253
+557 7
+558 95
+559 2
+560 128
+561 4
+562 206
+563 2
+564 465
+565 3
+566 69
+567 3
+568 157
+569 7
+570 97
+571 8
+572 118
+573 5
+574 130
+575 4
+576 301
+577 6
+578 177
+579 2
+580 397
+581 3
+582 80
+583 1
+584 128
+585 5
+586 52
+587 2
+588 72
+589 1
+590 84
+591 6
+592 323
+593 11
+594 77
+595 5
+596 205
+597 1
+598 244
+599 4
+600 69
+601 3
+602 89
+603 5
+604 254
+605 6
+606 147
+607 3
+608 83
+609 3
+610 77
+611 3
+612 194
+613 1
+614 98
+615 3
+616 243
+617 3
+618 50
+619 8
+620 188
+621 4
+622 67
+623 4
+624 123
+625 2
+626 50
+627 1
+628 239
+629 2
+630 51
+631 4
+632 65
+633 5
+634 188
+636 81
+637 3
+638 46
+639 3
+640 103
+641 1
+642 136
+643 3
+644 188
+645 3
+646 58
+648 122
+649 4
+650 47
+651 2
+652 155
+653 4
+654 71
+655 1
+656 71
+657 3
+658 50
+659 2
+660 177
+661 5
+662 66
+663 2
+664 183
+665 3
+666 50
+667 2
+668 53
+669 2
+670 115
+672 66
+673 2
+674 47
+675 1
+676 197
+677 2
+678 46
+679 3
+680 95
+681 3
+682 46
+683 3
+684 107
+685 1
+686 86
+687 2
+688 158
+689 4
+690 51
+691 1
+692 80
+694 56
+695 4
+696 40
+698 43
+699 3
+700 95
+701 2
+702 51
+703 2
+704 133
+705 1
+706 100
+707 2
+708 121
+709 2
+710 15
+711 3
+712 35
+713 2
+714 20
+715 3
+716 37
+717 2
+718 78
+720 55
+721 1
+722 42
+723 2
+724 218
+725 3
+726 23
+727 2
+728 26
+729 1
+730 64
+731 2
+732 65
+734 24
+735 2
+736 53
+737 1
+738 32
+739 1
+740 60
+742 81
+743 1
+744 77
+745 1
+746 47
+747 1
+748 62
+749 1
+750 19
+751 1
+752 86
+753 3
+754 40
+756 55
+757 2
+758 38
+759 1
+760 101
+761 1
+762 22
+764 67
+765 2
+766 35
+767 1
+768 38
+769 1
+770 22
+771 1
+772 82
+773 1
+774 73
+776 29
+777 1
+778 55
+780 23
+781 1
+782 16
+784 84
+785 3
+786 28
+788 59
+789 1
+790 33
+791 3
+792 24
+794 13
+795 1
+796 110
+797 2
+798 15
+800 22
+801 3
+802 29
+803 1
+804 87
+806 21
+808 29
+810 48
+812 28
+813 1
+814 58
+815 1
+816 48
+817 1
+818 31
+819 1
+820 66
+822 17
+823 2
+824 58
+826 10
+827 2
+828 25
+829 1
+830 29
+831 1
+832 63
+833 1
+834 26
+835 3
+836 52
+837 1
+838 18
+840 27
+841 2
+842 12
+843 1
+844 83
+845 1
+846 7
+847 1
+848 10
+850 26
+852 25
+853 1
+854 15
+856 27
+858 32
+859 1
+860 15
+862 43
+864 32
+865 1
+866 6
+868 39
+870 11
+872 25
+873 1
+874 10
+875 1
+876 20
+877 2
+878 19
+879 1
+880 30
+882 11
+884 53
+886 25
+887 1
+888 28
+890 6
+892 36
+894 10
+896 13
+898 14
+900 31
+902 14
+903 2
+904 43
+906 25
+908 9
+910 11
+911 1
+912 16
+913 1
+914 24
+916 27
+918 6
+920 15
+922 27
+923 1
+924 23
+926 13
+928 42
+929 1
+930 3
+932 27
+934 17
+936 8
+937 1
+938 11
+940 33
+942 4
+943 1
+944 18
+946 15
+948 13
+950 18
+952 12
+954 11
+956 21
+958 10
+960 13
+962 5
+964 32
+966 13
+968 8
+970 8
+971 1
+972 23
+973 2
+974 12
+975 1
+976 22
+978 7
+979 1
+980 14
+982 8
+984 22
+985 1
+986 6
+988 17
+989 1
+990 6
+992 13
+994 19
+996 11
+998 4
+1000 9
+1002 2
+1004 14
+1006 5
+1008 3
+1010 9
+1012 29
+1014 6
+1016 22
+1017 1
+1018 8
+1019 1
+1020 7
+1022 6
+1023 1
+1024 10
+1026 2
+1028 8
+1030 11
+1031 2
+1032 8
+1034 9
+1036 13
+1038 12
+1040 12
+1042 3
+1044 12
+1046 3
+1048 11
+1050 2
+1051 1
+1052 2
+1054 11
+1056 6
+1058 8
+1059 1
+1060 23
+1062 6
+1063 1
+1064 8
+1066 3
+1068 6
+1070 8
+1071 1
+1072 5
+1074 3
+1076 5
+1078 3
+1080 11
+1081 1
+1082 7
+1084 18
+1086 4
+1087 1
+1088 3
+1090 3
+1092 7
+1094 3
+1096 12
+1098 6
+1099 1
+1100 2
+1102 6
+1104 14
+1106 3
+1108 6
+1110 5
+1112 2
+1114 8
+1116 3
+1118 3
+1120 7
+1122 10
+1124 6
+1126 8
+1128 1
+1130 4
+1132 3
+1134 2
+1136 5
+1138 5
+1140 8
+1142 3
+1144 7
+1146 3
+1148 11
+1150 1
+1152 5
+1154 1
+1156 5
+1158 1
+1160 5
+1162 3
+1164 6
+1165 1
+1166 1
+1168 4
+1169 1
+1170 3
+1171 1
+1172 2
+1174 5
+1176 3
+1177 1
+1180 8
+1182 2
+1184 4
+1186 2
+1188 3
+1190 2
+1192 5
+1194 6
+1196 1
+1198 2
+1200 2
+1204 10
+1206 2
+1208 9
+1210 1
+1214 6
+1216 3
+1218 4
+1220 9
+1221 2
+1222 1
+1224 5
+1226 4
+1228 8
+1230 1
+1232 1
+1234 3
+1236 5
+1240 3
+1242 1
+1244 3
+1245 1
+1246 4
+1248 6
+1250 2
+1252 7
+1256 3
+1258 2
+1260 2
+1262 3
+1264 4
+1265 1
+1266 1
+1270 1
+1271 1
+1272 2
+1274 3
+1276 3
+1278 1
+1280 3
+1284 1
+1286 1
+1290 1
+1292 3
+1294 1
+1296 7
+1300 2
+1302 4
+1304 3
+1306 2
+1308 2
+1312 1
+1314 1
+1316 3
+1318 2
+1320 1
+1324 8
+1326 1
+1330 1
+1331 1
+1336 2
+1338 1
+1340 3
+1341 1
+1344 1
+1346 2
+1347 1
+1348 3
+1352 1
+1354 2
+1356 1
+1358 1
+1360 3
+1362 1
+1364 4
+1366 1
+1370 1
+1372 3
+1380 2
+1384 2
+1388 2
+1390 2
+1392 2
+1394 1
+1396 1
+1398 1
+1400 2
+1402 1
+1404 1
+1406 1
+1410 1
+1412 5
+1418 1
+1420 1
+1424 1
+1432 2
+1434 2
+1442 3
+1444 5
+1448 1
+1454 1
+1456 1
+1460 3
+1462 4
+1468 1
+1474 1
+1476 1
+1478 2
+1480 1
+1486 2
+1488 1
+1492 1
+1496 1
+1500 3
+1503 1
+1506 1
+1512 2
+1516 1
+1522 1
+1524 2
+1534 4
+1536 1
+1538 1
+1540 2
+1544 2
+1548 1
+1556 1
+1560 1
+1562 1
+1564 2
+1566 1
+1568 1
+1570 1
+1572 1
+1576 1
+1590 1
+1594 1
+1604 1
+1608 1
+1614 1
+1622 1
+1624 2
+1628 1
+1629 1
+1636 1
+1642 1
+1654 2
+1660 1
+1664 1
+1670 1
+1684 4
+1698 1
+1732 3
+1742 1
+1752 1
+1760 1
+1764 1
+1772 2
+1798 1
+1808 1
+1820 1
+1852 1
+1856 1
+1874 1
+1902 1
+1908 1
+1952 1
+2004 1
+2018 1
+2020 1
+2028 1
+2174 1
+2233 1
+2244 1
+2280 1
+2290 1
+2352 1
+2604 1
+4190 1
diff --git a/ppocr/utils/gen_label.py b/ppocr/utils/gen_label.py
index 43afe9ddf182ad0da8df023ff29cd3759011d890..fb78bd38bcfc1a59cac48a28bbb655ecb83bcb3f 100644
--- a/ppocr/utils/gen_label.py
+++ b/ppocr/utils/gen_label.py
@@ -1,16 +1,16 @@
-#copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
+# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
-#Licensed under the Apache License, Version 2.0 (the "License");
-#you may not use this file except in compliance with the License.
-#You may obtain a copy of the License at
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
-#Unless required by applicable law or agreed to in writing, software
-#distributed under the License is distributed on an "AS IS" BASIS,
-#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-#See the License for the specific language governing permissions and
-#limitations under the License.
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
import os
import argparse
import json
@@ -31,7 +31,9 @@ def gen_det_label(root_path, input_dir, out_label):
for label_file in os.listdir(input_dir):
img_path = root_path + label_file[3:-4] + ".jpg"
label = []
- with open(os.path.join(input_dir, label_file), 'r') as f:
+ with open(
+ os.path.join(input_dir, label_file), 'r',
+ encoding='utf-8-sig') as f:
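+            # utf-8-sig transparently strips a UTF-8 BOM, so annotation files saved with a BOM parse cleanly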
for line in f.readlines():
tmp = line.strip("\n\r").replace("\xef\xbb\xbf",
"").split(',')
diff --git a/ppocr/utils/logging.py b/ppocr/utils/logging.py
index 951141db8f39acac612029c8b69f4a29a0ab27ce..11896c37d9285e19a9526caa9c637d7eda7b1979 100644
--- a/ppocr/utils/logging.py
+++ b/ppocr/utils/logging.py
@@ -22,7 +22,7 @@ logger_initialized = {}
@functools.lru_cache()
-def get_logger(name='root', log_file=None, log_level=logging.INFO):
+def get_logger(name='root', log_file=None, log_level=logging.DEBUG):
"""Initialize and get a logger by name.
If the logger has not been initialized, this method will initialize the
logger by adding one or two handlers, otherwise the initialized logger will
diff --git a/ppocr/utils/network.py b/ppocr/utils/network.py
new file mode 100644
index 0000000000000000000000000000000000000000..453abb693d4c0ed370c1031b677d5bf51661add9
--- /dev/null
+++ b/ppocr/utils/network.py
@@ -0,0 +1,82 @@
+# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import sys
+import tarfile
+import requests
+from tqdm import tqdm
+
+from ppocr.utils.logging import get_logger
+
+
+def download_with_progressbar(url, save_path):
+ logger = get_logger()
+ response = requests.get(url, stream=True)
+ total_size_in_bytes = int(response.headers.get('content-length', 0))
+ block_size = 1024 # 1 Kibibyte
+ progress_bar = tqdm(total=total_size_in_bytes, unit='iB', unit_scale=True)
+ with open(save_path, 'wb') as file:
+ for data in response.iter_content(block_size):
+ progress_bar.update(len(data))
+ file.write(data)
+ progress_bar.close()
+ if total_size_in_bytes == 0 or progress_bar.n != total_size_in_bytes:
+ logger.error("Something went wrong while downloading models")
+ sys.exit(0)
+
+
+def maybe_download(model_storage_directory, url):
+ # using custom model
+ tar_file_name_list = [
+ 'inference.pdiparams', 'inference.pdiparams.info', 'inference.pdmodel'
+ ]
+ if not os.path.exists(
+ os.path.join(model_storage_directory, 'inference.pdiparams')
+ ) or not os.path.exists(
+ os.path.join(model_storage_directory, 'inference.pdmodel')):
+ assert url.endswith('.tar'), 'Only supports tar compressed package'
+ tmp_path = os.path.join(model_storage_directory, url.split('/')[-1])
+ print('download {} to {}'.format(url, tmp_path))
+ os.makedirs(model_storage_directory, exist_ok=True)
+ download_with_progressbar(url, tmp_path)
+ with tarfile.open(tmp_path, 'r') as tarObj:
+ for member in tarObj.getmembers():
+ filename = None
+ for tar_file_name in tar_file_name_list:
+ if tar_file_name in member.name:
+ filename = tar_file_name
+ if filename is None:
+ continue
+ file = tarObj.extractfile(member)
+ with open(
+ os.path.join(model_storage_directory, filename),
+ 'wb') as f:
+ f.write(file.read())
+ os.remove(tmp_path)
+
+
+def is_link(s):
+ return s is not None and s.startswith('http')
+
+
+def confirm_model_dir_url(model_dir, default_model_dir, default_url):
+ url = default_url
+ if model_dir is None or is_link(model_dir):
+ if is_link(model_dir):
+ url = model_dir
+ file_name = url.split('/')[-1][:-4]
+ model_dir = default_model_dir
+ model_dir = os.path.join(model_dir, file_name)
+ return model_dir, url
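+
+
+# Minimal usage sketch (illustrative only; the destination directory and model URL below are
+# placeholders, substitute the model you actually need):
+#
+#   model_dir, url = confirm_model_dir_url(
+#       None, './inference/det',
+#       'https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_det_infer.tar')
+#   maybe_download(model_dir, url)  # downloads and unpacks the tar only if the inference files are missing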
diff --git a/ppocr/utils/save_load.py b/ppocr/utils/save_load.py
index 23f5401bb71a2ef50ff2ff2c3c27275d7e10b3c0..3bb022ed98b140995b79ceea93d7f494d3f5930d 100644
--- a/ppocr/utils/save_load.py
+++ b/ppocr/utils/save_load.py
@@ -25,7 +25,7 @@ import paddle
from ppocr.utils.logging import get_logger
-__all__ = ['init_model', 'save_model', 'load_dygraph_pretrain']
+__all__ = ['init_model', 'save_model', 'load_dygraph_params']
def _mkdir_if_not_exist(path, logger):
@@ -89,6 +89,55 @@ def init_model(config, model, optimizer=None, lr_scheduler=None):
return best_model_dict
+def load_dygraph_params(config, model, logger, optimizer):
+ ckp = config['Global']['checkpoints']
+ if ckp and os.path.exists(ckp + ".pdparams"):
+ pre_best_model_dict = init_model(config, model, optimizer)
+ return pre_best_model_dict
+ else:
+ pm = config['Global']['pretrained_model']
+ if pm is None:
+ return {}
+ if not os.path.exists(pm) and not os.path.exists(pm + ".pdparams"):
+            logger.info(f"The pretrained_model {pm} does not exist!")
+ return {}
+ pm = pm if pm.endswith('.pdparams') else pm + '.pdparams'
+ params = paddle.load(pm)
+ state_dict = model.state_dict()
+ new_state_dict = {}
+ for k1, k2 in zip(state_dict.keys(), params.keys()):
+ if list(state_dict[k1].shape) == list(params[k2].shape):
+ new_state_dict[k1] = params[k2]
+ else:
+ logger.info(
+ f"The shape of model params {k1} {state_dict[k1].shape} not matched with loaded params {k2} {params[k2].shape} !"
+ )
+ model.set_state_dict(new_state_dict)
+        logger.info(f"loaded pretrained_model successfully from {pm}")
+ return {}
+
+def load_pretrained_params(model, path):
+ if path is None:
+ return False
+ if not os.path.exists(path) and not os.path.exists(path + ".pdparams"):
+        print(f"The pretrained_model {path} does not exist!")
+ return False
+
+ path = path if path.endswith('.pdparams') else path + '.pdparams'
+ params = paddle.load(path)
+ state_dict = model.state_dict()
+ new_state_dict = {}
+ for k1, k2 in zip(state_dict.keys(), params.keys()):
+ if list(state_dict[k1].shape) == list(params[k2].shape):
+ new_state_dict[k1] = params[k2]
+ else:
+ print(
+ f"The shape of model params {k1} {state_dict[k1].shape} not matched with loaded params {k2} {params[k2].shape} !"
+ )
+ model.set_state_dict(new_state_dict)
+    print(f"loaded pretrained model successfully from {path}")
+ return model
+
def save_model(model,
optimizer,
model_path,
diff --git a/ppstructure/README.md b/ppstructure/README.md
index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..8e1642cc75cc52b179d0f8441a8da2fe86e78d7b 100644
--- a/ppstructure/README.md
+++ b/ppstructure/README.md
@@ -0,0 +1,189 @@
+English | [简体中文](README_ch.md)
+
+# PP-Structure
+
+PP-Structure is an OCR toolkit for complex document analysis. Its main features are as follows:
+- Supports layout analysis of documents, dividing them into 5 types of areas: **text, title, table, image and list** (in conjunction with Layout-Parser)
+- Supports extracting text from the text, title, picture and list areas (in conjunction with PP-OCR)
+- Supports extracting Excel files from the table areas
+- Supports both Python whl package and command-line usage, easy to use
+- Supports custom training for layout analysis and table structure tasks
+
+## 1. Visualization
+
+
+
+
+
+## 2. Installation
+
+### 2.1 Install requirements
+
+- **(1) Install PaddlePaddle**
+
+```bash
+pip3 install --upgrade pip
+
+# GPU
+python3 -m pip install paddlepaddle-gpu==2.1.1 -i https://mirror.baidu.com/pypi/simple
+
+# CPU
+python3 -m pip install paddlepaddle==2.1.1 -i https://mirror.baidu.com/pypi/simple
+
+# For more installation options, refer to https://www.paddlepaddle.org.cn/install/quick
+```
+
+- **(2) Install Layout-Parser**
+
+```bash
+pip3 install -U premailer paddleocr https://paddleocr.bj.bcebos.com/whl/layoutparser-0.0.0-py3-none-any.whl
+```
+
+### 2.2 Install PaddleOCR(including PP-OCR and PP-Structure)
+
+- **(1) PIP install PaddleOCR whl package(inference only)**
+
+```bash
+pip install "paddleocr>=2.2"
+```
+
+- **(2) Clone PaddleOCR(Inference+training)**
+
+```bash
+git clone https://github.com/PaddlePaddle/PaddleOCR
+```
+
+
+## 3. Quick Start
+
+### 3.1 Use by command line
+
+```bash
+paddleocr --image_dir=../doc/table/1.png --type=structure
+```
+
+### 3.2 Use by python API
+
+```python
+import os
+import cv2
+from paddleocr import PPStructure,draw_structure_result,save_structure_res
+
+table_engine = PPStructure(show_log=True)
+
+save_folder = './output/table'
+img_path = '../doc/table/1.png'
+img = cv2.imread(img_path)
+result = table_engine(img)
+save_structure_res(result, save_folder,os.path.basename(img_path).split('.')[0])
+
+for line in result:
+ line.pop('img')
+ print(line)
+
+from PIL import Image
+
+font_path = '../doc/fonts/simfang.ttf'
+image = Image.open(img_path).convert('RGB')
+im_show = draw_structure_result(image, result,font_path=font_path)
+im_show = Image.fromarray(im_show)
+im_show.save('result.jpg')
+```
+### 3.3 Returned results format
+The result returned by PP-Structure is a list of dicts; an example is as follows
+
+```shell
+[
+ { 'type': 'Text',
+ 'bbox': [34, 432, 345, 462],
+ 'res': ([[36.0, 437.0, 341.0, 437.0, 341.0, 446.0, 36.0, 447.0], [41.0, 454.0, 125.0, 453.0, 125.0, 459.0, 41.0, 460.0]],
+ [('Tigure-6. The performance of CNN and IPT models using difforen', 0.90060663), ('Tent ', 0.465441)])
+ }
+]
+```
+The description of each field in the dict is as follows
+
+| Parameter | Description |
+| --------------- | -------------|
+|type|Type of image area|
+|bbox|The coordinates of the image area in the original image, in the order [upper-left x, upper-left y, lower-right x, lower-right y]|
+|res|OCR or table recognition result of the image area. Table: the HTML string of the table; OCR: a tuple containing the detection coordinates and recognition results of each line of text|
+
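+Building on the format above, the following sketch shows one way to consume the two kinds of `res` values; it assumes the result list comes from the `table_engine(img)` call in section 3.2:
+
+```python
+for region in result:
+    if region['type'] == 'Table':
+        html = region['res']            # HTML string describing the table
+        print(region['bbox'], html[:80])
+    else:
+        boxes, rec_res = region['res']  # detection boxes and (text, confidence) pairs
+        for text, score in rec_res:
+            print(region['bbox'], text, score)
+```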
+
+### 3.4 Parameter description:
+
+| Parameter | Description | Default value |
+| --------------- | ---------------------------------------- | ------------------------------------------- |
+| output | The path where excel and recognition results are saved | ./output/table |
+| table_max_len | The long side of the image is resized to this length before the table structure model | 488 |
+| table_model_dir | inference model path of the table structure model | None |
+| table_char_type | dict path of the table structure model | ../ppocr/utils/dict/table_structure_dict.txt |
+
+Most of the parameters are consistent with the paddleocr whl package; see the [whl package documentation](../doc/doc_en/whl_en.md).
+
+After running, each image will have a directory with the same name under the directory specified by the `output` field. Each table in the image is stored as an Excel file and each figure area is cropped and saved; the Excel and image file names are the coordinates of the corresponding region in the image.
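+
+As a sketch of how these parameters can be passed in (assuming `PPStructure` forwards them like the other whl-package arguments), you might write:
+
+```python
+from paddleocr import PPStructure
+
+# output, table_max_len and table_model_dir are the parameters documented above;
+# the model directory below is a placeholder for your own table structure model
+table_engine = PPStructure(output='./output/table',
+                           table_max_len=488,
+                           table_model_dir='./inference/en_ppocr_mobile_v2.0_table_structure_infer',
+                           show_log=True)
+```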
+
+## 4. PP-Structure Pipeline
+
+The process is as follows:
+
+
+In PP-Structure, the image is first analyzed by Layout-Parser. During layout analysis, the regions of the image are classified into 5 categories: **text, title, image, list and table**. For the first 4 region types, PP-OCR is used directly to complete text detection and recognition. Table regions are converted, via table recognition, into Excel files with the same table layout.
+
+### 4.1 LayoutParser
+
+Layout analysis divides a document into regions. The layout analysis guide covers using the tool from Python scripts, extracting detection boxes of specific categories, performance metrics, and training a custom layout analysis model. For details, please refer to the [layout analysis document](layout/README.md).
+
+### 4.2 Table Recognition
+
+Table recognition converts table images into Excel documents. It includes the detection and recognition of table text as well as the prediction of table structure and cell coordinates. For details, please refer to the [table recognition document](table/README.md)
+
+## 5. Prediction by inference engine
+
+Use the following commands to complete the inference.
+
+```bash
+cd PaddleOCR/ppstructure
+
+# download model
+mkdir inference && cd inference
+# Download the detection model of the ultra-lightweight Chinese OCR model and uncompress it
+wget https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_det_infer.tar && tar xf ch_ppocr_mobile_v2.0_det_infer.tar
+# Download the recognition model of the ultra-lightweight Chinese OCR model and uncompress it
+wget https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_rec_infer.tar && tar xf ch_ppocr_mobile_v2.0_rec_infer.tar
+# Download the ultra-lightweight English table structure model and uncompress it
+wget https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_structure_infer.tar && tar xf en_ppocr_mobile_v2.0_table_structure_infer.tar
+cd ..
+
+python3 predict_system.py --det_model_dir=inference/ch_ppocr_mobile_v2.0_det_infer --rec_model_dir=inference/ch_ppocr_mobile_v2.0_rec_infer --table_model_dir=inference/en_ppocr_mobile_v2.0_table_structure_infer --image_dir=../doc/table/1.png --rec_char_dict_path=../ppocr/utils/ppocr_keys_v1.txt --table_char_dict_path=../ppocr/utils/dict/table_structure_dict.txt --rec_char_type=ch --output=../output/table --vis_font_path=../doc/fonts/simfang.ttf
+```
+After running, each image will have a directory with the same name under the directory specified by the `output` field. Each table in the image is stored as an Excel file and each figure area is cropped and saved; the Excel and image file names are the coordinates of the corresponding region in the image.
+
+**Model List**
+
+
+|model name|description|config|model size|download|
+| --- | --- | --- | --- | --- |
+|en_ppocr_mobile_v2.0_table_structure|Table structure prediction for English table scenarios|[table_mv3.yml](../configs/table/table_mv3.yml)|18.6M|[inference model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_structure_infer.tar) |
+
+**Model List**
+
+LayoutParser model
+
+|model name|description|download|
+| --- | --- | --- |
+| ppyolov2_r50vd_dcn_365e_publaynet | The layout analysis model trained on the PubLayNet data set can be divided into 5 types of areas **text, title, table, picture and list** | [PubLayNet](https://paddle-model-ecology.bj.bcebos.com/model/layout-parser/ppyolov2_r50vd_dcn_365e_publaynet.tar) |
+| ppyolov2_r50vd_dcn_365e_tableBank_word | The layout analysis model trained on the TableBank Word dataset can only detect tables | [TableBank Word](https://paddle-model-ecology.bj.bcebos.com/model/layout-parser/ppyolov2_r50vd_dcn_365e_tableBank_word.tar) |
+| ppyolov2_r50vd_dcn_365e_tableBank_latex | The layout analysis model trained on the TableBank Latex dataset can only detect tables | [TableBank Latex](https://paddle-model-ecology.bj.bcebos.com/model/layout-parser/ppyolov2_r50vd_dcn_365e_tableBank_latex.tar) |
+
+OCR and table recognition model
+
+|model name|description|model size|download|
+| --- | --- | --- | --- |
+|ch_ppocr_mobile_slim_v2.0_det|Slim pruned lightweight model, supporting Chinese, English, multilingual text detection|2.6M|[inference model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/slim/ch_ppocr_mobile_v2.0_det_prune_infer.tar) |
+|ch_ppocr_mobile_slim_v2.0_rec|Slim pruned and quantized lightweight model, supporting Chinese, English and number recognition|6M|[inference model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_rec_slim_infer.tar) |
+|en_ppocr_mobile_v2.0_table_det|Text detection of English table scenes trained on PubLayNet dataset|4.7M|[inference model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_det_infer.tar) |
+|en_ppocr_mobile_v2.0_table_rec|Text recognition of English table scene trained on PubLayNet dataset|6.9M|[inference model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_rec_infer.tar) |
+|en_ppocr_mobile_v2.0_table_structure|Table structure prediction of English table scene trained on PubLayNet dataset|18.6M|[inference model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_structure_infer.tar) |
+
+If you need to use other models, you can download a model from the [model_list](../doc/doc_en/models_list_en.md) or use your own trained model, and configure it via the three fields `det_model_dir`, `rec_model_dir` and `table_model_dir`.
diff --git a/ppstructure/README_ch.md b/ppstructure/README_ch.md
index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..c8acac590039647cf52f47b16a99092ff68f2b6e 100644
--- a/ppstructure/README_ch.md
+++ b/ppstructure/README_ch.md
@@ -0,0 +1,188 @@
+[English](README.md) | 简体中文
+
+# PP-Structure
+
+PP-Structure是一个可用于复杂文档结构分析和处理的OCR工具包,主要特性如下:
+- 支持对图片形式的文档进行版面分析,可以划分**文字、标题、表格、图片以及列表**5类区域(与Layout-Parser联合使用)
+- 支持文字、标题、图片以及列表区域提取为文字字段(与PP-OCR联合使用)
+- 支持表格区域进行结构化分析,最终结果输出Excel文件
+- 支持python whl包和命令行两种方式,简单易用
+- 支持版面分析和表格结构化两类任务自定义训练
+
+## 1. 效果展示
+
+
+
+
+
+## 2. 安装
+
+### 2.1 安装依赖
+
+- **(1) 安装PaddlePaddle**
+
+```bash
+pip3 install --upgrade pip
+
+# GPU安装
+python3 -m pip install paddlepaddle-gpu==2.1.1 -i https://mirror.baidu.com/pypi/simple
+
+# CPU安装
+python3 -m pip install paddlepaddle==2.1.1 -i https://mirror.baidu.com/pypi/simple
+
+# 更多需求,请参照[安装文档](https://www.paddlepaddle.org.cn/install/quick)中的说明进行操作。
+```
+
+- **(2) 安装 Layout-Parser**
+
+```bash
+pip3 install -U premailer paddleocr https://paddleocr.bj.bcebos.com/whl/layoutparser-0.0.0-py3-none-any.whl
+```
+
+### 2.2 安装PaddleOCR(包含PP-OCR和PP-Structure)
+
+- **(1) PIP快速安装PaddleOCR whl包(仅预测)**
+
+```bash
+pip install "paddleocr>=2.2" # 推荐使用2.2+版本
+```
+
+- **(2) 完整克隆PaddleOCR源码(预测+训练)**
+
+```bash
+【推荐】git clone https://github.com/PaddlePaddle/PaddleOCR
+
+#如果因为网络问题无法pull成功,也可选择使用码云上的托管:
+git clone https://gitee.com/paddlepaddle/PaddleOCR
+
+#注:码云托管代码可能无法实时同步本github项目更新,存在3~5天延时,请优先使用推荐方式。
+```
+
+
+## 3. PP-Structure 快速开始
+
+### 3.1 命令行使用(默认参数,极简)
+
+```bash
+paddleocr --image_dir=../doc/table/1.png --type=structure
+```
+
+### 3.2 Python脚本使用(自定义参数,灵活)
+
+```python
+import os
+import cv2
+from paddleocr import PPStructure,draw_structure_result,save_structure_res
+
+table_engine = PPStructure(show_log=True)
+
+save_folder = './output/table'
+img_path = '../doc/table/1.png'
+img = cv2.imread(img_path)
+result = table_engine(img)
+save_structure_res(result, save_folder,os.path.basename(img_path).split('.')[0])
+
+for line in result:
+ line.pop('img')
+ print(line)
+
+from PIL import Image
+
+font_path = '../doc/fonts/simfang.ttf' # PaddleOCR下提供字体包
+image = Image.open(img_path).convert('RGB')
+im_show = draw_structure_result(image, result,font_path=font_path)
+im_show = Image.fromarray(im_show)
+im_show.save('result.jpg')
+```
+
+### 3.3 返回结果说明
+PP-Structure的返回结果为一个dict组成的list,示例如下
+
+```shell
+[
+ { 'type': 'Text',
+ 'bbox': [34, 432, 345, 462],
+ 'res': ([[36.0, 437.0, 341.0, 437.0, 341.0, 446.0, 36.0, 447.0], [41.0, 454.0, 125.0, 453.0, 125.0, 459.0, 41.0, 460.0]],
+ [('Tigure-6. The performance of CNN and IPT models using difforen', 0.90060663), ('Tent ', 0.465441)])
+ }
+]
+```
+dict 里各个字段说明如下
+
+| 字段 | 说明 |
+| --------------- | -------------|
+|type|图片区域的类型|
+|bbox|图片区域的在原图的坐标,分别[左上角x,左上角y,右下角x,右下角y]|
+|res|图片区域的OCR或表格识别结果。 表格: 表格的HTML字符串; OCR: 一个包含各个单行文字的检测坐标和识别结果的元组|
+
+
+### 3.4 参数说明
+
+| 字段 | 说明 | 默认值 |
+| --------------- | ---------------------------------------- | ------------------------------------------- |
+| output | excel和识别结果保存的地址 | ./output/table |
+| table_max_len | 表格结构模型预测时,图像的长边resize尺度 | 488 |
+| table_model_dir | 表格结构模型 inference 模型地址 | None |
+| table_char_type | 表格结构模型所用字典地址 | ../ppocr/utils/dict/table_structure_dict.tx |
+
+大部分参数和paddleocr whl包保持一致,见 [whl包文档](../doc/doc_ch/whl.md)
+
+运行完成后,每张图片会在`output`字段指定的目录下有一个同名目录,图片里的每个表格会存储为一个excel,图片区域会被裁剪之后保存下来,excel文件和图片名为表格在图片里的坐标。
+
+
+## 4. PP-Structure Pipeline介绍
+
+
+
+在PP-Structure中,图片会先经由Layout-Parser进行版面分析,在版面分析中,会对图片里的区域进行分类,包括**文字、标题、图片、列表和表格**5类。对于前4类区域,直接使用PP-OCR完成对应区域文字检测与识别。对于表格类区域,经过表格结构化处理后,表格图片转换为相同表格样式的Excel文件。
+
+### 4.1 版面分析
+
+版面分析对文档数据进行区域分类,其中包括版面分析工具的Python脚本使用、提取指定类别检测框、性能指标以及自定义训练版面分析模型,详细内容可以参考[文档](layout/README_ch.md)。
+
+### 4.2 表格识别
+
+表格识别将表格图片转换为excel文档,其中包含对于表格文本的检测和识别以及对于表格结构和单元格坐标的预测,详细说明参考[文档](table/README_ch.md)
+
+## 5. 预测引擎推理(与whl包效果相同)
+
+使用如下命令即可完成预测引擎的推理
+
+```bash
+cd ppstructure
+
+# 下载模型
+mkdir inference && cd inference
+# 下载超轻量级中文OCR模型的检测模型并解压
+wget https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_det_infer.tar && tar xf ch_ppocr_mobile_v2.0_det_infer.tar
+# 下载超轻量级中文OCR模型的识别模型并解压
+wget https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_rec_infer.tar && tar xf ch_ppocr_mobile_v2.0_rec_infer.tar
+# 下载超轻量级英文表格结构模型并解压
+wget https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_structure_infer.tar && tar xf en_ppocr_mobile_v2.0_table_structure_infer.tar
+cd ..
+
+python3 predict_system.py --det_model_dir=inference/ch_ppocr_mobile_v2.0_det_infer --rec_model_dir=inference/ch_ppocr_mobile_v2.0_rec_infer --table_model_dir=inference/en_ppocr_mobile_v2.0_table_structure_infer --image_dir=../doc/table/1.png --rec_char_dict_path=../ppocr/utils/ppocr_keys_v1.txt --table_char_dict_path=../ppocr/utils/dict/table_structure_dict.txt --rec_char_type=ch --output=../output/table --vis_font_path=../doc/fonts/simfang.ttf
+```
+运行完成后,每张图片会在`output`字段指定的目录下有一个同名目录,图片里的每个表格会存储为一个excel,图片区域会被裁剪之后保存下来,excel文件和图片名为表格在图片里的坐标。
+
+**Model List**
+
+LayoutParser 模型
+
+|模型名称|模型简介|下载地址|
+| --- | --- | --- |
+| ppyolov2_r50vd_dcn_365e_publaynet | PubLayNet 数据集训练的版面分析模型,可以划分**文字、标题、表格、图片以及列表**5类区域 | [PubLayNet](https://paddle-model-ecology.bj.bcebos.com/model/layout-parser/ppyolov2_r50vd_dcn_365e_publaynet.tar) |
+| ppyolov2_r50vd_dcn_365e_tableBank_word | TableBank Word 数据集训练的版面分析模型,只能检测表格 | [TableBank Word](https://paddle-model-ecology.bj.bcebos.com/model/layout-parser/ppyolov2_r50vd_dcn_365e_tableBank_word.tar) |
+| ppyolov2_r50vd_dcn_365e_tableBank_latex | TableBank Latex 数据集训练的版面分析模型,只能检测表格 | [TableBank Latex](https://paddle-model-ecology.bj.bcebos.com/model/layout-parser/ppyolov2_r50vd_dcn_365e_tableBank_latex.tar) |
+
+OCR和表格识别模型
+
+|模型名称|模型简介|推理模型大小|下载地址|
+| --- | --- | --- | --- |
+|ch_ppocr_mobile_slim_v2.0_det|slim裁剪版超轻量模型,支持中英文、多语种文本检测|2.6M|[推理模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/slim/ch_ppocr_mobile_v2.0_det_prune_infer.tar) |
+|ch_ppocr_mobile_slim_v2.0_rec|slim裁剪量化版超轻量模型,支持中英文、数字识别|6M|[推理模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_rec_slim_infer.tar) |
+|en_ppocr_mobile_v2.0_table_det|PubLayNet数据集训练的英文表格场景的文字检测|4.7M|[推理模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_det_infer.tar) |
+|en_ppocr_mobile_v2.0_table_rec|PubLayNet数据集训练的英文表格场景的文字识别|6.9M|[推理模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_rec_infer.tar) |
+|en_ppocr_mobile_v2.0_table_structure|PubLayNet数据集训练的英文表格场景的表格结构预测|18.6M|[推理模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_structure_infer.tar) |
+
+如需要使用其他模型,可以在 [model_list](../doc/doc_ch/models_list.md) 下载模型或者使用自己训练好的模型配置到`det_model_dir`,`rec_model_dir`,`table_model_dir`三个字段即可。
diff --git a/ppstructure/__init__.py b/ppstructure/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..1d11e265597c7c8e39098a228108da3bb954b892
--- /dev/null
+++ b/ppstructure/__init__.py
@@ -0,0 +1,13 @@
+# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/ppstructure/layout/README.md b/ppstructure/layout/README.md
index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..74cb928e30c012d5b469d685fd63b443a7d22613 100644
--- a/ppstructure/layout/README.md
+++ b/ppstructure/layout/README.md
@@ -0,0 +1,141 @@
+English | [简体中文](README_ch.md)
+
+
+# Getting Started
+
+[1. Install whl package](#Install)
+
+[2. Quick Start](#QuickStart)
+
+[3. PostProcess](#PostProcess)
+
+[4. Results](#Results)
+
+[5. Training](#Training)
+
+
+
+## 1. Install whl package
+```bash
+wget https://paddleocr.bj.bcebos.com/whl/layoutparser-0.0.0-py3-none-any.whl
+pip install -U layoutparser-0.0.0-py3-none-any.whl
+```
+
+
+
+## 2. Quick Start
+
+Use LayoutParser to identify the layout of a document:
+
+```python
+import cv2
+import layoutparser as lp
+image = cv2.imread("doc/table/layout.jpg")
+image = image[..., ::-1]
+
+# load model
+model = lp.PaddleDetectionLayoutModel(config_path="lp://PubLayNet/ppyolov2_r50vd_dcn_365e_publaynet/config",
+ threshold=0.5,
+ label_map={0: "Text", 1: "Title", 2: "List", 3:"Table", 4:"Figure"},
+ enforce_cpu=False,
+ enable_mkldnn=True)
+# detect
+layout = model.detect(image)
+
+# show result
+show_img = lp.draw_box(image, layout, box_width=3, show_element_type=True)
+show_img.show()
+```
+
+The following figure shows the result, with different colored detection boxes representing different categories; with `show_element_type` enabled, the specific category is shown in the upper-left corner of each box:
+
+
+ 
+
+`PaddleDetectionLayoutModel` parameters are described as follows:
+
+| parameter | description | default | remark |
+| :------------: | :------------------------------------------------------: | :---------: | :----------------------------------------------------------: |
+| config_path | model config path | None | Specifying config_path automatically downloads the model (only the first time; afterwards the model already exists and is not downloaded again) |
+| model_path | model path | None | local model path; exactly one of config_path and model_path must be set, they cannot both be None |
+| threshold | threshold of prediction score | 0.5 | \ |
+| input_shape | picture size of reshape | [3,640,640] | \ |
+| batch_size | testing batch size | 1 | \ |
+| label_map | category mapping table | None | If config_path is set, this can be None; the label_map is obtained automatically from the dataset name |
+| enforce_cpu | whether to use CPU | False | False to use GPU, and True to force the use of CPU |
+| enable_mkldnn | whether MKL-DNN acceleration is enabled for CPU prediction | True | \ |
+| thread_num | the number of CPU threads | 10 | \ |
+
+The following model configurations and label maps are currently supported; you can use them by modifying `--config_path` and `--label_map` to detect different types of content:
+
+| dataset | config_path | label_map |
+| ------------------------------------------------------------ | ------------------------------------------------------------ | --------------------------------------------------------- |
+| [TableBank](https://doc-analysis.github.io/tablebank-page/index.html) word | lp://TableBank/ppyolov2_r50vd_dcn_365e_tableBank_word/config | {0:"Table"} |
+| TableBank latex | lp://TableBank/ppyolov2_r50vd_dcn_365e_tableBank_latex/config | {0:"Table"} |
+| [PubLayNet](https://github.com/ibm-aur-nlp/PubLayNet) | lp://PubLayNet/ppyolov2_r50vd_dcn_365e_publaynet/config | {0: "Text", 1: "Title", 2: "List", 3:"Table", 4:"Figure"} |
+
+* TableBank word and TableBank latex are trained on datasets of word documents and latex documents respectively;
+* The downloadable TableBank dataset contains both the word and latex subsets.
+
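+For example, switching to the TableBank word model only requires swapping in the corresponding `config_path` and `label_map` from the table above:
+
+```python
+import cv2
+import layoutparser as lp
+
+image = cv2.imread("doc/table/layout.jpg")
+image = image[..., ::-1]
+
+# TableBank models only detect tables, so the label map has a single class
+model = lp.PaddleDetectionLayoutModel(
+    config_path="lp://TableBank/ppyolov2_r50vd_dcn_365e_tableBank_word/config",
+    threshold=0.5,
+    label_map={0: "Table"},
+    enforce_cpu=False,
+    enable_mkldnn=True)
+layout = model.detect(image)
+```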
+
+
+## 3. PostProcess
+
+Layout analysis returns boxes of multiple categories; if you only want the detection boxes of a specific category (such as the "Text" category), you can use the following code:
+
+```python
+# follow the above code
+# filter areas for a specific text type
+text_blocks = lp.Layout([b for b in layout if b.type=='Text'])
+figure_blocks = lp.Layout([b for b in layout if b.type=='Figure'])
+
+# text areas may be detected within the image area, delete these areas
+text_blocks = lp.Layout([b for b in text_blocks \
+ if not any(b.is_in(b_fig) for b_fig in figure_blocks)])
+
+# sort text areas and assign ID
+h, w = image.shape[:2]
+
+left_interval = lp.Interval(0, w/2*1.05, axis='x').put_on_canvas(image)
+
+left_blocks = text_blocks.filter_by(left_interval, center=True)
+left_blocks.sort(key = lambda b:b.coordinates[1])
+
+right_blocks = [b for b in text_blocks if b not in left_blocks]
+right_blocks.sort(key = lambda b:b.coordinates[1])
+
+# the two lists are merged and the indexes are added in order
+text_blocks = lp.Layout([b.set(id = idx) for idx, b in enumerate(left_blocks + right_blocks)])
+
+# display result
+show_img = lp.draw_box(image, text_blocks,
+ box_width=3,
+ show_element_id=True)
+show_img.show()
+```
+
+Displays results with only the "Text" category:
+
+
+ 
+
+
+
+## 4. Results
+
+| Dataset | mAP | CPU time cost | GPU time cost |
+| --------- | ---- | ------------- | ------------- |
+| PubLayNet | 93.6 | 1713.7ms | 66.6ms |
+| TableBank | 96.2 | 1968.4ms | 65.1ms |
+
+**Environment:**
+
+ **CPU:** Intel(R) Xeon(R) CPU E5-2650 v4 @ 2.20GHz, 24 cores
+
+ **GPU:** a single NVIDIA Tesla P40
+
+
+
+## 5. Training
+
+The above model is based on [PaddleDetection](https://github.com/PaddlePaddle/PaddleDetection). If you want to train your own layout parser model, please refer to: [train_layoutparser_model](train_layoutparser_model.md)
diff --git a/ppstructure/layout/README_ch.md b/ppstructure/layout/README_ch.md
index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..c722e0bd88f40ff6b711edecff0433029e101f87 100644
--- a/ppstructure/layout/README_ch.md
+++ b/ppstructure/layout/README_ch.md
@@ -0,0 +1,141 @@
+[English](README.md) | 简体中文
+
+# 版面分析使用说明
+
+[1. 安装whl包](#安装whl包)
+
+[2. 使用](#使用)
+
+[3. 后处理](#后处理)
+
+[4. 指标](#指标)
+
+[5. 训练版面分析模型](#训练版面分析模型)
+
+
+
+## 1. 安装whl包
+```bash
+pip install -U https://paddleocr.bj.bcebos.com/whl/layoutparser-0.0.0-py3-none-any.whl
+```
+
+
+
+## 2. 使用
+
+使用layoutparser识别给定文档的布局:
+
+```python
+import cv2
+import layoutparser as lp
+image = cv2.imread("doc/table/layout.jpg")
+image = image[..., ::-1]
+
+# 加载模型
+model = lp.PaddleDetectionLayoutModel(config_path="lp://PubLayNet/ppyolov2_r50vd_dcn_365e_publaynet/config",
+ threshold=0.5,
+ label_map={0: "Text", 1: "Title", 2: "List", 3:"Table", 4:"Figure"},
+ enforce_cpu=False,
+ enable_mkldnn=True)
+# 检测
+layout = model.detect(image)
+
+# 显示结果
+show_img = lp.draw_box(image, layout, box_width=3, show_element_type=True)
+show_img.show()
+```
+
+下图展示了结果,不同颜色的检测框表示不同的类别,并通过`show_element_type`在框的左上角显示具体类别:
+
+
+ 
+
+
+`PaddleDetectionLayoutModel`函数参数说明如下:
+
+| 参数 | 含义 | 默认值 | 备注 |
+| :------------: | :-------------------------: | :---------: | :----------------------------------------------------------: |
+| config_path | 模型配置路径 | None | 指定config_path会自动下载模型(仅第一次,之后模型存在,不会再下载) |
+| model_path | 模型路径 | None | 本地模型路径,config_path和model_path必须设置一个,不能同时为None |
+| threshold | 预测得分的阈值 | 0.5 | \ |
+| input_shape | reshape之后图片尺寸 | [3,640,640] | \ |
+| batch_size | 测试batch size | 1 | \ |
+| label_map | 类别映射表 | None | 设置config_path时,可以为None,根据数据集名称自动获取label_map |
+| enforce_cpu | 代码是否使用CPU运行 | False | 设置为False表示使用GPU,True表示强制使用CPU |
+| enable_mkldnn | CPU预测中是否开启MKLDNN加速 | True | \ |
+| thread_num | 设置CPU线程数 | 10 | \ |
+
+目前支持以下几种模型配置和label map,您可以通过修改 `--config_path`和 `--label_map`使用这些模型,从而检测不同类型的内容:
+
+| dataset | config_path | label_map |
+| ------------------------------------------------------------ | ------------------------------------------------------------ | --------------------------------------------------------- |
+| [TableBank](https://doc-analysis.github.io/tablebank-page/index.html) word | lp://TableBank/ppyolov2_r50vd_dcn_365e_tableBank_word/config | {0:"Table"} |
+| TableBank latex | lp://TableBank/ppyolov2_r50vd_dcn_365e_tableBank_latex/config | {0:"Table"} |
+| [PubLayNet](https://github.com/ibm-aur-nlp/PubLayNet) | lp://PubLayNet/ppyolov2_r50vd_dcn_365e_publaynet/config | {0: "Text", 1: "Title", 2: "List", 3:"Table", 4:"Figure"} |
+
+* TableBank word和TableBank latex分别在word文档、latex文档数据集训练;
+* 下载的TableBank数据集里同时包含word和latex。
+
+
+
+## 3. 后处理
+
+版面分析检测包含多个类别,如果只想获取指定类别(如"Text"类别)的检测框、可以使用下述代码:
+
+```python
+# 接上面代码
+# 首先过滤特定文本类型的区域
+text_blocks = lp.Layout([b for b in layout if b.type=='Text'])
+figure_blocks = lp.Layout([b for b in layout if b.type=='Figure'])
+
+# 因为在图像区域内可能检测到文本区域,所以只需要删除它们
+text_blocks = lp.Layout([b for b in text_blocks \
+ if not any(b.is_in(b_fig) for b_fig in figure_blocks)])
+
+# 对文本区域排序并分配id
+h, w = image.shape[:2]
+
+left_interval = lp.Interval(0, w/2*1.05, axis='x').put_on_canvas(image)
+
+left_blocks = text_blocks.filter_by(left_interval, center=True)
+left_blocks.sort(key = lambda b:b.coordinates[1])
+
+right_blocks = [b for b in text_blocks if b not in left_blocks]
+right_blocks.sort(key = lambda b:b.coordinates[1])
+
+# 最终合并两个列表,并按顺序添加索引
+text_blocks = lp.Layout([b.set(id = idx) for idx, b in enumerate(left_blocks + right_blocks)])
+
+# 显示结果
+show_img = lp.draw_box(image, text_blocks,
+ box_width=3,
+ show_element_id=True)
+show_img.show()
+```
+
+显示只有"Text"类别的结果:
+
+
+ 
+
+
+
+
+## 4. 指标
+
+| Dataset | mAP | CPU time cost | GPU time cost |
+| --------- | ---- | ------------- | ------------- |
+| PubLayNet | 93.6 | 1713.7ms | 66.6ms |
+| TableBank | 96.2 | 1968.4ms | 65.1ms |
+
+**Environment:**
+
+ **CPU:** Intel(R) Xeon(R) CPU E5-2650 v4 @ 2.20GHz,24core
+
+ **GPU:** a single NVIDIA Tesla P40
+
+
+
+## 5. 训练版面分析模型
+
+上述模型基于[PaddleDetection](https://github.com/PaddlePaddle/PaddleDetection) 训练,如果您想训练自己的版面分析模型,请参考:[train_layoutparser_model](train_layoutparser_model_ch.md)
diff --git a/ppstructure/layout/train_layoutparser_model.md b/ppstructure/layout/train_layoutparser_model.md
new file mode 100644
index 0000000000000000000000000000000000000000..08f5ebbf1aa276e4a3ecf27af46442161afcda1f
--- /dev/null
+++ b/ppstructure/layout/train_layoutparser_model.md
@@ -0,0 +1,203 @@
+# Training layout-parse
+
+[1. Installation](#Installation)
+
+ [1.1 Requirements](#Requirements)
+
+ [1.2 Install PaddleDetection](#Install PaddleDetection)
+
+[2. Data preparation](#Data preparation)
+
+[3. Configuration](#Configuration)
+
+[4. Training](#Training)
+
+[5. Prediction](#Prediction)
+
+[6. Deployment](#Deployment)
+
+ [6.1 Export model](#Export model)
+
+ [6.2 Inference](#Inference)
+
+
+
+## 1. Installation
+
+
+
+### 1.1 Requirements
+
+- PaddlePaddle 2.1
+- OS 64 bit
+- Python 3(3.5.1+/3.6/3.7/3.8/3.9),64 bit
+- pip/pip3(9.0.1+), 64 bit
+- CUDA >= 10.1
+- cuDNN >= 7.6
+
+
+
+### 1.2 Install PaddleDetection
+
+```bash
+# Clone PaddleDetection repository
+cd
+git clone https://github.com/PaddlePaddle/PaddleDetection.git
+
+cd PaddleDetection
+# Install other dependencies
+pip install -r requirements.txt
+```
+
+For more installation tutorials, please refer to: [Install doc](https://github.com/PaddlePaddle/PaddleDetection/blob/release/2.1/docs/tutorials/INSTALL_cn.md)
+
+
+
+## 2. Data preparation
+
+Download the [PubLayNet](https://github.com/ibm-aur-nlp/PubLayNet) dataset
+
+```bash
+cd PaddleDetection/dataset/
+mkdir publaynet
+# execute the command to download PubLayNet
+wget -O publaynet.tar.gz https://dax-cdn.cdn.appdomain.cloud/dax-publaynet/1.0.0/publaynet.tar.gz?_ga=2.104193024.1076900768.1622560733-649911202.1622560733
+# unpack
+tar -xvf publaynet.tar.gz
+```
+
+PubLayNet directory structure after decompressing:
+
+| File or Folder | Description | num |
+| :------------- | :----------------------------------------------- | ------- |
+| `train/` | Images in the training subset | 335,703 |
+| `val/` | Images in the validation subset | 11,245 |
+| `test/` | Images in the testing subset | 11,405 |
+| `train.json` | Annotations for training images | 1 |
+| `val.json` | Annotations for validation images | 1 |
+| `LICENSE.txt` | Plaintext version of the CDLA-Permissive license | 1 |
+| `README.txt` | Text file with the file names and description | 1 |
+
+For other datasets, please refer to [PrepareDataSet](https://github.com/PaddlePaddle/PaddleDetection/blob/release/2.1/docs/tutorials/PrepareDataSet.md).
+
+
+
+## 3. Configuration
+
+We use the `configs/ppyolo/ppyolov2_r50vd_dcn_365e_coco.yml` configuration for training; the configuration file is as follows:
+
+```bash
+_BASE_: [
+ '../datasets/coco_detection.yml',
+ '../runtime.yml',
+ './_base_/ppyolov2_r50vd_dcn.yml',
+ './_base_/optimizer_365e.yml',
+ './_base_/ppyolov2_reader.yml',
+]
+
+snapshot_epoch: 8
+weights: output/ppyolov2_r50vd_dcn_365e_coco/model_final
+```
+The `ppyolov2_r50vd_dcn_365e_coco.yml` configuration depends on other configuration files, in this case:
+
+- coco_detection.yml: mainly describes the paths of the training and validation data
+
+- runtime.yml: mainly describes the common runtime parameters, such as whether to use the GPU and how many epochs to train between checkpoints, etc.
+
+- optimizer_365e.yml: mainly describes the learning rate and optimizer configuration
+
+- ppyolov2_r50vd_dcn.yml: mainly describes the model and the backbone network
+
+- ppyolov2_reader.yml: mainly describes the configuration of the data readers, such as the batch size and the number of concurrent loading child processes, and also includes the preprocessing applied after reading, such as resize and data augmentation, etc.
+
+
+Modify the preceding files as needed, for example the dataset path and batch size.
+
+
+
+## 4. Training
+
+PaddleDetection provides single-card/multi-card training modes to meet the various training needs of users:
+
+* GPU single card training
+
+```bash
+export CUDA_VISIBLE_DEVICES=0 # this command is not needed on Windows and macOS
+python tools/train.py -c configs/ppyolo/ppyolov2_r50vd_dcn_365e_coco.yml
+```
+
+* GPU multi-card training
+
+```bash
+export CUDA_VISIBLE_DEVICES=0,1,2,3
+python -m paddle.distributed.launch --gpus 0,1,2,3 tools/train.py -c configs/ppyolo/ppyolov2_r50vd_dcn_365e_coco.yml --eval
+```
+
+--eval: evaluate while training
+
+* Model recovery training
+
+If training is interrupted for some reason, you can resume it with the `-r` option:
+
+```bash
+export CUDA_VISIBLE_DEVICES=0,1,2,3
+python -m paddle.distributed.launch --gpus 0,1,2,3 tools/train.py -c configs/ppyolo/ppyolov2_r50vd_dcn_365e_coco.yml --eval -r output/ppyolov2_r50vd_dcn_365e_coco/10000
+```
+
+Note: If you encounter an "`Out of memory error`", try reducing `batch_size` in the `ppyolov2_reader.yml` file.
+
+
+## 5. Prediction
+
+Set parameters and use PaddleDetection to predict:
+
+```bash
+export CUDA_VISIBLE_DEVICES=0
+python tools/infer.py -c configs/ppyolo/ppyolov2_r50vd_dcn_365e_coco.yml --infer_img=images/paper-image.jpg --output_dir=infer_output/ --draw_threshold=0.5 -o weights=output/ppyolov2_r50vd_dcn_365e_coco/model_final --use_vdl=True
+```
+
+`--draw_threshold` is an optional parameter. According to the calculation of [NMS](https://ieeexplore.ieee.org/document/1699659), different thresholds will produce different results. `keep_top_k` represents the maximum number of output targets; the default value is 100. You can set a different value according to your actual situation.
+
+
+
+## 6. Deployment
+
+Use your trained model in Layout Parser
+
+
+
+### 6.1 Export model
+
+In the process of model training, the saved model file contains both the forward prediction and back-propagation parts. In actual industrial deployment, back propagation is not needed, so the model should be exported into the format required for deployment. PaddleDetection provides the `tools/export_model.py` script to export the model.
+
+The exported model name defaults to `model.*`, while Layout-Parser expects `inference.*`, so modify [PaddleDetection/ppdet/engine/trainer.py](https://github.com/PaddlePaddle/PaddleDetection/blob/b87a1ea86fa18ce69e44a17ad1b49c1326f19ff9/ppdet/engine/trainer.py#L512) (click the link to see the exact line of code) and change 'model' to 'inference'.
+
+Execute the script to export model:
+
+```bash
+python tools/export_model.py -c configs/ppyolo/ppyolov2_r50vd_dcn_365e_coco.yml --output_dir=./inference -o weights=output/ppyolov2_r50vd_dcn_365e_coco/model_final.pdparams
+```
+
+The prediction model is exported to `inference/ppyolov2_r50vd_dcn_365e_coco`, including `infer_cfg.yml` (not required for prediction), `inference.pdiparams`, `inference.pdiparams.info` and `inference.pdmodel`.
+
+For more model export tutorials, please refer to [EXPORT_MODEL](https://github.com/PaddlePaddle/PaddleDetection/blob/release/2.1/deploy/EXPORT_MODEL.md).
+
+
+
+### 6.2 Inference
+
+`model_path` specifies the trained model path; layoutparser is then used to predict:
+
+```python
+import layoutparser as lp
+model = lp.PaddleDetectionLayoutModel(model_path="inference/ppyolov2_r50vd_dcn_365e_coco", threshold=0.5,label_map={0: "Text", 1: "Title", 2: "List", 3:"Table", 4:"Figure"},enforce_cpu=True,enable_mkldnn=True)
+```
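+
+From here the exported model behaves like the pre-trained ones; a minimal follow-up sketch (the image path is a placeholder):
+
+```python
+import cv2
+
+image = cv2.imread("images/paper-image.jpg")
+image = image[..., ::-1]        # convert BGR (OpenCV) to RGB for layoutparser
+layout = model.detect(image)    # run layout detection with the model loaded above
+print(layout)
+```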
+
+
+
+***
+
+For more PaddleDetection training tutorials, please refer to [PaddleDetection Training](https://github.com/PaddlePaddle/PaddleDetection/blob/release/2.1/docs/tutorials/GETTING_STARTED_cn.md)
+
+***
diff --git a/ppstructure/layout/train_layoutparser_model_ch.md b/ppstructure/layout/train_layoutparser_model_ch.md
new file mode 100644
index 0000000000000000000000000000000000000000..2f73c63adcea3f82ae579222e658291224f46237
--- /dev/null
+++ b/ppstructure/layout/train_layoutparser_model_ch.md
@@ -0,0 +1,203 @@
+# 训练版面分析
+
+[1. 安装](#安装)
+
+ [1.1 环境要求](#环境要求)
+
+ [1.2 安装PaddleDetection](#安装PaddleDetection)
+
+[2. 准备数据](#准备数据)
+
+[3. 配置文件改动和说明](#配置文件改动和说明)
+
+[4. PaddleDetection训练](#训练)
+
+[5. PaddleDetection预测](#预测)
+
+[6. 预测部署](#预测部署)
+
+ [6.1 模型导出](#模型导出)
+
+ [6.2 layout parser预测](#layout_parser预测)
+
+
+
+## 1. 安装
+
+
+
+### 1.1 环境要求
+
+- PaddlePaddle 2.1
+- OS 64 bit
+- Python 3(3.5.1+/3.6/3.7/3.8/3.9),64 bit
+- pip/pip3(9.0.1+), 64 bit
+- CUDA >= 10.1
+- cuDNN >= 7.6
+
+
+
+### 1.2 安装PaddleDetection
+
+```bash
+# 克隆PaddleDetection仓库
+cd
+git clone https://github.com/PaddlePaddle/PaddleDetection.git
+
+cd PaddleDetection
+# 安装其他依赖
+pip install -r requirements.txt
+```
+
+更多安装教程,请参考: [Install doc](https://github.com/PaddlePaddle/PaddleDetection/blob/release/2.1/docs/tutorials/INSTALL_cn.md)
+
+
+
+## 2. 准备数据
+
+下载 [PubLayNet](https://github.com/ibm-aur-nlp/PubLayNet) 数据集:
+
+```bash
+cd PaddleDetection/dataset/
+mkdir publaynet
+# 执行命令,下载
+wget -O publaynet.tar.gz https://dax-cdn.cdn.appdomain.cloud/dax-publaynet/1.0.0/publaynet.tar.gz?_ga=2.104193024.1076900768.1622560733-649911202.1622560733
+# 解压
+tar -xvf publaynet.tar.gz
+```
+
+解压之后PubLayNet目录结构:
+
+| File or Folder | Description | num |
+| :------------- | :----------------------------------------------- | ------- |
+| `train/` | Images in the training subset | 335,703 |
+| `val/` | Images in the validation subset | 11,245 |
+| `test/` | Images in the testing subset | 11,405 |
+| `train.json` | Annotations for training images | 1 |
+| `val.json` | Annotations for validation images | 1 |
+| `LICENSE.txt` | Plaintext version of the CDLA-Permissive license | 1 |
+| `README.txt` | Text file with the file names and description | 1 |
+
+如果使用其它数据集,请参考[准备训练数据](https://github.com/PaddlePaddle/PaddleDetection/blob/release/2.1/docs/tutorials/PrepareDataSet.md)
+
+
+
+## 3. 配置文件改动和说明
+
+我们使用 `configs/ppyolo/ppyolov2_r50vd_dcn_365e_coco.yml`配置进行训练,配置文件摘要如下:
+
+```bash
+_BASE_: [
+ '../datasets/coco_detection.yml',
+ '../runtime.yml',
+ './_base_/ppyolov2_r50vd_dcn.yml',
+ './_base_/optimizer_365e.yml',
+ './_base_/ppyolov2_reader.yml',
+]
+
+snapshot_epoch: 8
+weights: output/ppyolov2_r50vd_dcn_365e_coco/model_final
+```
+从中可以看到 `ppyolov2_r50vd_dcn_365e_coco.yml` 配置需要依赖其他的配置文件,在该例子中需要依赖:
+
+- coco_detection.yml:主要说明了训练数据和验证数据的路径
+
+- runtime.yml:主要说明了公共的运行参数,比如是否使用GPU、每多少个epoch存储checkpoint等
+
+- optimizer_365e.yml:主要说明了学习率和优化器的配置
+
+- ppyolov2_r50vd_dcn.yml:主要说明模型和主干网络的情况
+
+- ppyolov2_reader.yml:主要说明数据读取器配置,如batch size,并发加载子进程数等,同时包含读取后预处理操作,如resize、数据增强等等
+
+
+根据实际情况,修改上述文件,比如数据集路径、batch size等。
+
+
+
+## 4. PaddleDetection训练
+
+PaddleDetection提供了单卡/多卡训练模式,满足用户多种训练需求
+
+* GPU 单卡训练
+
+```bash
+export CUDA_VISIBLE_DEVICES=0 #windows和Mac下不需要执行该命令
+python tools/train.py -c configs/ppyolo/ppyolov2_r50vd_dcn_365e_coco.yml
+```
+
+* GPU多卡训练
+
+```bash
+export CUDA_VISIBLE_DEVICES=0,1,2,3
+python -m paddle.distributed.launch --gpus 0,1,2,3 tools/train.py -c configs/ppyolo/ppyolov2_r50vd_dcn_365e_coco.yml --eval
+```
+
+--eval:表示边训练边验证
+
+* 模型恢复训练
+
+在日常训练过程中,有的用户由于一些原因导致训练中断,用户可以使用-r的命令恢复训练:
+
+```bash
+export CUDA_VISIBLE_DEVICES=0,1,2,3
+python -m paddle.distributed.launch --gpus 0,1,2,3 tools/train.py -c configs/ppyolo/ppyolov2_r50vd_dcn_365e_coco.yml --eval -r output/ppyolov2_r50vd_dcn_365e_coco/10000
+```
+
+注意:如果遇到 "`Out of memory error`" 问题, 尝试在 `ppyolov2_reader.yml` 文件中调小`batch_size`
+
+
+
+## 5. PaddleDetection预测
+
+设置参数,使用PaddleDetection预测:
+
+```bash
+export CUDA_VISIBLE_DEVICES=0
+python tools/infer.py -c configs/ppyolo/ppyolov2_r50vd_dcn_365e_coco.yml --infer_img=images/paper-image.jpg --output_dir=infer_output/ --draw_threshold=0.5 -o weights=output/ppyolov2_r50vd_dcn_365e_coco/model_final --use_vdl=True
+```
+
+`--draw_threshold` 是个可选参数. 根据 [NMS](https://ieeexplore.ieee.org/document/1699659) 的计算,不同阈值会产生不同的结果 `keep_top_k`表示设置输出目标的最大数量,默认值为100,用户可以根据自己的实际情况进行设定。
+
+
+
+## 6. 预测部署
+
+在layout parser中使用自己训练好的模型。
+
+
+
+### 6.1 模型导出
+
+在模型训练过程中保存的模型文件是包含前向预测和反向传播的过程,在实际的工业部署则不需要反向传播,因此需要将模型进行导成部署需要的模型格式。 在PaddleDetection中提供了 `tools/export_model.py`脚本来导出模型。
+
+导出模型名称默认是`model.*`,layout parser代码模型名称是`inference.*`, 所以修改[PaddleDetection/ppdet/engine/trainer.py ](https://github.com/PaddlePaddle/PaddleDetection/blob/b87a1ea86fa18ce69e44a17ad1b49c1326f19ff9/ppdet/engine/trainer.py#L512) (点开链接查看详细代码行),将`model`改为`inference`即可。
+
+执行导出模型脚本:
+
+```bash
+python tools/export_model.py -c configs/ppyolo/ppyolov2_r50vd_dcn_365e_coco.yml --output_dir=./inference -o weights=output/ppyolov2_r50vd_dcn_365e_coco/model_final.pdparams
+```
+
+预测模型会导出到`inference/ppyolov2_r50vd_dcn_365e_coco`目录下,分别为`infer_cfg.yml`(预测不需要), `inference.pdiparams`, `inference.pdiparams.info`,`inference.pdmodel` 。
+
+更多模型导出教程,请参考:[EXPORT_MODEL](https://github.com/PaddlePaddle/PaddleDetection/blob/release/2.1/deploy/EXPORT_MODEL.md)
+
+
+
+### 6.2 layout_parser预测
+
+`model_path`指定训练好的模型路径,使用layout parser进行预测:
+
+```python
+import layoutparser as lp
+model = lp.PaddleDetectionLayoutModel(model_path="inference/ppyolov2_r50vd_dcn_365e_coco", threshold=0.5,label_map={0: "Text", 1: "Title", 2: "List", 3:"Table", 4:"Figure"},enforce_cpu=True,enable_mkldnn=True)
+```
+
+
+
+***
+
+更多PaddleDetection训练教程,请参考:[PaddleDetection训练](https://github.com/PaddlePaddle/PaddleDetection/blob/release/2.1/docs/tutorials/GETTING_STARTED_cn.md)
+
+***
diff --git a/ppstructure/predict_system.py b/ppstructure/predict_system.py
index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..b2de3d4de80b39f046cf6cbc8a9ebbc52bf69334 100644
--- a/ppstructure/predict_system.py
+++ b/ppstructure/predict_system.py
@@ -0,0 +1,152 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import sys
+import subprocess
+
+__dir__ = os.path.dirname(os.path.abspath(__file__))
+sys.path.append(__dir__)
+sys.path.append(os.path.abspath(os.path.join(__dir__, '..')))
+
+os.environ["FLAGS_allocator_strategy"] = 'auto_growth'
+import cv2
+import numpy as np
+import time
+import logging
+
+from ppocr.utils.utility import get_image_file_list, check_and_read_gif
+from ppocr.utils.logging import get_logger
+from tools.infer.predict_system import TextSystem
+from ppstructure.table.predict_table import TableSystem, to_excel
+from ppstructure.utility import parse_args, draw_structure_result
+
+logger = get_logger()
+
+
+class OCRSystem(object):
+ def __init__(self, args):
+ import layoutparser as lp
+ # args.det_limit_type = 'resize_long'
+ args.drop_score = 0
+ if not args.show_log:
+ logger.setLevel(logging.INFO)
+ self.text_system = TextSystem(args)
+ self.table_system = TableSystem(args, self.text_system.text_detector, self.text_system.text_recognizer)
+
+ config_path = None
+ model_path = None
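+        # layout_path_model may be either a lp:// config path or a local model directory; only the matching argument is passed on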
+ if os.path.isdir(args.layout_path_model):
+ model_path = args.layout_path_model
+ else:
+ config_path = args.layout_path_model
+ self.table_layout = lp.PaddleDetectionLayoutModel(config_path=config_path,
+ model_path=model_path,
+ threshold=0.5, enable_mkldnn=args.enable_mkldnn,
+ enforce_cpu=not args.use_gpu, thread_num=args.cpu_threads)
+ self.use_angle_cls = args.use_angle_cls
+ self.drop_score = args.drop_score
+
+ def __call__(self, img):
+ ori_im = img.copy()
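+        # layoutparser expects an RGB image, while cv2 provides BGR, so reverse the channel order before layout detection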
+ layout_res = self.table_layout.detect(img[..., ::-1])
+ res_list = []
+ for region in layout_res:
+ x1, y1, x2, y2 = region.coordinates
+ x1, y1, x2, y2 = int(x1), int(y1), int(x2), int(y2)
+ roi_img = ori_im[y1:y2, x1:x2, :]
+ if region.type == 'Table':
+ res = self.table_system(roi_img)
+ else:
+ filter_boxes, filter_rec_res = self.text_system(roi_img)
+ filter_boxes = [x + [x1, y1] for x in filter_boxes]
+ filter_boxes = [x.reshape(-1).tolist() for x in filter_boxes]
+ # remove style char
+                style_token = [
+                    '<strike>', '<strike>', '<sup>', '</sub>', '<b>', '</b>',
+                    '<sub>', '</sup>', '<overline>', '</overline>',
+                    '<underline>', '</underline>', '<i>', '</i>'
+                ]
+ filter_rec_res_tmp = []
+ for rec_res in filter_rec_res:
+ rec_str, rec_conf = rec_res
+ for token in style_token:
+ if token in rec_str:
+ rec_str = rec_str.replace(token, '')
+ filter_rec_res_tmp.append((rec_str, rec_conf))
+ res = (filter_boxes, filter_rec_res_tmp)
+ res_list.append({'type': region.type, 'bbox': [x1, y1, x2, y2], 'img': roi_img, 'res': res})
+ return res_list
+
+
+def save_structure_res(res, save_folder, img_name):
+ excel_save_folder = os.path.join(save_folder, img_name)
+ os.makedirs(excel_save_folder, exist_ok=True)
+    # save all region results: tables as <bbox>.xlsx, figure crops as <bbox>.jpg, text lines into res.txt
+ with open(os.path.join(excel_save_folder, 'res.txt'), 'w', encoding='utf8') as f:
+ for region in res:
+ if region['type'] == 'Table':
+ excel_path = os.path.join(excel_save_folder, '{}.xlsx'.format(region['bbox']))
+ to_excel(region['res'], excel_path)
+ if region['type'] == 'Figure':
+ roi_img = region['img']
+ img_path = os.path.join(excel_save_folder, '{}.jpg'.format(region['bbox']))
+ cv2.imwrite(img_path, roi_img)
+ else:
+ for box, rec_res in zip(region['res'][0], region['res'][1]):
+ f.write('{}\t{}\n'.format(np.array(box).reshape(-1).tolist(), rec_res))
+
+
+def main(args):
+ image_file_list = get_image_file_list(args.image_dir)
+ image_file_list = image_file_list
+ image_file_list = image_file_list[args.process_id::args.total_process_num]
+ save_folder = args.output
+ os.makedirs(save_folder, exist_ok=True)
+
+ structure_sys = OCRSystem(args)
+ img_num = len(image_file_list)
+ for i, image_file in enumerate(image_file_list):
+ logger.info("[{}/{}] {}".format(i, img_num, image_file))
+ img, flag = check_and_read_gif(image_file)
+ img_name = os.path.basename(image_file).split('.')[0]
+
+ if not flag:
+ img = cv2.imread(image_file)
+ if img is None:
+ logger.error("error in loading image:{}".format(image_file))
+ continue
+ starttime = time.time()
+ res = structure_sys(img)
+ save_structure_res(res, save_folder, img_name)
+ draw_img = draw_structure_result(img, res, args.vis_font_path)
+ cv2.imwrite(os.path.join(save_folder, img_name, 'show.jpg'), draw_img)
+ logger.info('result save to {}'.format(os.path.join(save_folder, img_name)))
+ elapse = time.time() - starttime
+ logger.info("Predict time : {:.3f}s".format(elapse))
+
+
+if __name__ == "__main__":
+ args = parse_args()
+ if args.use_mp:
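+        # spawn one worker subprocess per shard; each re-runs this script with --use_mp=False and picks its images via list slicing in main()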
+ p_list = []
+ total_process_num = args.total_process_num
+ for process_id in range(total_process_num):
+ cmd = [sys.executable, "-u"] + sys.argv + [
+ "--process_id={}".format(process_id),
+ "--use_mp={}".format(False)
+ ]
+ p = subprocess.Popen(cmd, stdout=sys.stdout, stderr=sys.stdout)
+ p_list.append(p)
+ for p in p_list:
+ p.wait()
+ else:
+ main(args)
diff --git a/ppstructure/table/README.md b/ppstructure/table/README.md
index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..a8d10b79e507ab59ef2481982a33902e4a95e73e 100644
--- a/ppstructure/table/README.md
+++ b/ppstructure/table/README.md
@@ -0,0 +1,116 @@
+# Table Recognition
+
+## 1. pipeline
+Table recognition mainly contains three models:
+1. Single line text detection-DB
+2. Single line text recognition-CRNN
+3. Table structure and cell coordinate prediction-RARE
+
+The table recognition flow chart is as follows
+
+
+
+1. The coordinates of each single-line text are detected by the DB model and then sent to the recognition model to obtain the recognition results.
+2. The table structure and cell coordinates are predicted by the RARE model.
+3. The recognition result of each cell is assembled from the single-line coordinates, the single-line recognition results and the cell coordinates.
+4. The cell recognition results and the table structure together construct the HTML string of the table, as sketched below.
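+
+The following is a simplified, illustrative sketch of steps 3 and 4. It is not the exact implementation (see `table/predict_table.py` added in this PR); the helper names and the `[x0, y0, x1, y1]` box format are assumptions made only for this example:
+
+```python
+def iou(a, b):
+    # intersection-over-union of two [x0, y0, x1, y1] boxes
+    ix0, iy0 = max(a[0], b[0]), max(a[1], b[1])
+    ix1, iy1 = min(a[2], b[2]), min(a[3], b[3])
+    inter = max(0, ix1 - ix0) * max(0, iy1 - iy0)
+    union = (a[2] - a[0]) * (a[3] - a[1]) + (b[2] - b[0]) * (b[3] - b[1]) - inter
+    return inter / union if union > 0 else 0.0
+
+def rebuild_html(structure_tokens, cell_boxes, ocr_boxes, ocr_texts):
+    # 1. assign every recognized text line to its "nearest" predicted cell
+    cell_texts = {}
+    for box, text in zip(ocr_boxes, ocr_texts):
+        keys = [(1.0 - iou(box, cell), sum(abs(b - c) for b, c in zip(box, cell)))
+                for cell in cell_boxes]
+        best = min(range(len(cell_boxes)), key=lambda j: keys[j])
+        cell_texts.setdefault(best, []).append(text)
+    # 2. walk the predicted structure tokens and fill the text into each cell slot
+    html, td_index = [], 0
+    for tag in structure_tokens:
+        if '</td>' in tag:
+            html.append(' '.join(cell_texts.get(td_index, [])))
+            td_index += 1
+        html.append(tag)
+    return ''.join(html)
+```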
+
+## 2. Performance
+We evaluated the algorithm on the PubTabNet[1] eval dataset, and the performance is as follows:
+
+
+|Method|[TEDS(Tree-Edit-Distance-based Similarity)](https://github.com/ibm-aur-nlp/PubTabNet/tree/master/src)|
+| --- | --- |
+| EDD[2] | 88.3 |
+| Ours | 93.32 |
+
+## 3. How to use
+
+### 3.1 Quick start
+
+```bash
+cd PaddleOCR/ppstructure
+
+# download model
+mkdir inference && cd inference
+# Download the detection model of the ultra-lightweight table English OCR model and unzip it
+wget https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_det_infer.tar && tar xf en_ppocr_mobile_v2.0_table_det_infer.tar
+# Download the recognition model of the ultra-lightweight table English OCR model and unzip it
+wget https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_rec_infer.tar && tar xf en_ppocr_mobile_v2.0_table_rec_infer.tar
+# Download the ultra-lightweight English table structure model and unzip it
+wget https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_structure_infer.tar && tar xf en_ppocr_mobile_v2.0_table_structure_infer.tar
+cd ..
+# run
+python3 table/predict_table.py --det_model_dir=inference/en_ppocr_mobile_v2.0_table_det_infer --rec_model_dir=inference/en_ppocr_mobile_v2.0_table_rec_infer --table_model_dir=inference/en_ppocr_mobile_v2.0_table_structure_infer --image_dir=../doc/table/table.jpg --rec_char_dict_path=../ppocr/utils/ppocr_keys_v1.txt --table_char_dict_path=../ppocr/utils/dict/table_structure_dict.txt --rec_char_type=ch --det_limit_side_len=736 --det_limit_type=min --output ../output/table
+```
+Note: The above models are trained on the PubTabNet dataset and only support English scanned documents. To recognize other scenarios, you need to train your own models and replace the three fields `det_model_dir`, `rec_model_dir`, `table_model_dir`.
+
+After running, the Excel file of each image will be saved in the directory specified by the output field.
+
+### 3.2 Train
+
+In this chapter, we only introduce the training of the table structure model. For training the [text detection](../../doc/doc_en/detection_en.md) and [text recognition](../../doc/doc_en/recognition_en.md) models, please refer to the corresponding documents.
+
+#### Data preparation
+The training data uses the public dataset [PubTabNet](https://arxiv.org/abs/1911.10683), which can be downloaded from the official [website](https://github.com/ibm-aur-nlp/PubTabNet). The PubTabNet dataset contains about 500,000 images, together with annotations in HTML format.
+
+#### Start training
+*If you install the CPU version of Paddle, please set the `use_gpu` field in the configuration file to false*
+```shell
+# single GPU training
+python3 tools/train.py -c configs/table/table_mv3.yml
+# multi-GPU training
+# Set the GPU ID used by the '--gpus' parameter.
+python3 -m paddle.distributed.launch --gpus '0,1,2,3' tools/train.py -c configs/table/table_mv3.yml
+```
+
+In the above command, `-c` selects training with the `configs/table/table_mv3.yml` configuration file.
+For a detailed explanation of the configuration file, please refer to [config](../../doc/doc_en/config_en.md).
+
+#### Load a trained model and continue training
+
+If you want to load a trained model and resume training, you can set the parameter `Global.checkpoints` to the model path to be loaded.
+
+```shell
+python3 tools/train.py -c configs/table/table_mv3.yml -o Global.checkpoints=./your/trained/model
+```
+
+**Note**: The priority of `Global.checkpoints` is higher than that of `Global.pretrain_weights`, that is, when two parameters are specified at the same time, the model specified by `Global.checkpoints` will be loaded first. If the model path specified by `Global.checkpoints` is wrong, the one specified by `Global.pretrain_weights` will be loaded.
+
+### 3.3 Eval
+
+The table uses [TEDS(Tree-Edit-Distance-based Similarity)](https://github.com/ibm-aur-nlp/PubTabNet/tree/master/src) as the evaluation metric of the model. Before the model evaluation, the three models in the pipeline need to be exported as inference models (we have provided them), and the gt for evaluation needs to be prepared. Examples of gt are as follows:
+```json
+{"PMC4289340_004_00.png": [
+ ["", "", "", "", "", "", " | ", "", " | ", "", " | ", " ", "", "", "", "", " | ", "", " | ", "", " | ", " ", "", " ", "", ""],
+ [[1, 4, 29, 13], [137, 4, 161, 13], [215, 4, 236, 13], [1, 17, 30, 27], [137, 17, 147, 27], [215, 17, 225, 27]],
+ [["", "F", "e", "a", "t", "u", "r", "e", ""], ["", "G", "b", "3", " ", "+", ""], ["", "G", "b", "3", " ", "-", ""], ["", "P", "a", "t", "i", "e", "n", "t", "s", ""], ["6", "2"], ["4", "5"]]
+]}
+```
+In the gt json, the key is the image name and the value is the corresponding gt. The gt is a list composed of three items (a small loading snippet is sketched after this list), which are
+1. HTML string list of the table structure
+2. The coordinates of each cell (not including cells whose text is empty)
+3. The text information in each cell (not including cells whose text is empty)
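+
+As a quick sanity check, a gt file with the layout above can be loaded and inspected with a few lines of Python (the path `gt.json` is only a placeholder for this sketch):
+
+```python
+import json
+
+with open('gt.json') as f:
+    gt = json.load(f)
+
+for img_name, (structure_tokens, cell_boxes, cell_texts) in gt.items():
+    html = ''.join(structure_tokens)                  # table structure as an HTML string
+    n_cells = len(cell_boxes)                         # one [x0, y0, x1, y1] box per non-empty cell
+    texts = [''.join(chars) for chars in cell_texts]  # characters joined back into cell strings
+    print(img_name, n_cells, texts)
+```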
+
+Use the following command to evaluate. After the evaluation is completed, the TEDS metric will be printed.
+```bash
+cd PaddleOCR/ppstructure
+python3 table/eval_table.py --det_model_dir=path/to/det_model_dir --rec_model_dir=path/to/rec_model_dir --table_model_dir=path/to/table_model_dir --image_dir=../doc/table/1.png --rec_char_dict_path=../ppocr/utils/dict/table_dict.txt --table_char_dict_path=../ppocr/utils/dict/table_structure_dict.txt --rec_char_type=EN --det_limit_side_len=736 --det_limit_type=min --gt_path=path/to/gt.json
+```
+
+If the PubTabNet eval dataset is used, the output will be
+```bash
+teds: 93.32
+```
+
+### 3.4 Inference
+
+```bash
+cd PaddleOCR/ppstructure
+python3 table/predict_table.py --det_model_dir=path/to/det_model_dir --rec_model_dir=path/to/rec_model_dir --table_model_dir=path/to/table_model_dir --image_dir=../doc/table/1.png --rec_char_dict_path=../ppocr/utils/dict/table_dict.txt --table_char_dict_path=../ppocr/utils/dict/table_structure_dict.txt --rec_char_type=EN --det_limit_side_len=736 --det_limit_type=min --output ../output/table
+```
+After running, the Excel file of each image will be saved in the directory specified by the output field.
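+
+If you prefer to drive the pipeline from Python instead of the command line, a minimal sketch is shown below. It reuses `parse_args`, `TableSystem` and `to_excel` added in this PR; the model directories and file paths are placeholders you would replace with your own:
+
+```python
+import cv2
+from ppstructure.utility import parse_args
+from ppstructure.table.predict_table import TableSystem, to_excel
+
+args = parse_args()  # accepts the same flags as table/predict_table.py
+args.det_model_dir = 'path/to/det_model_dir'
+args.rec_model_dir = 'path/to/rec_model_dir'
+args.table_model_dir = 'path/to/table_model_dir'
+
+table_sys = TableSystem(args)
+img = cv2.imread('../doc/table/table.jpg')
+pred_html = table_sys(img)         # HTML string of the reconstructed table
+to_excel(pred_html, 'table.xlsx')
+```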
+
+Reference
+1. https://github.com/ibm-aur-nlp/PubTabNet
+2. https://arxiv.org/pdf/1911.10683
\ No newline at end of file
diff --git a/ppstructure/table/README_ch.md b/ppstructure/table/README_ch.md
index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..2ded403c371984a447f94268d23ca1c6240cf432 100644
--- a/ppstructure/table/README_ch.md
+++ b/ppstructure/table/README_ch.md
@@ -0,0 +1,113 @@
+# Table Recognition
+
+## 1. Table recognition pipeline
+Table recognition mainly contains three models:
+1. Single-line text detection - DB
+2. Single-line text recognition - CRNN
+3. Table structure and cell coordinate prediction - RARE
+
+The flow chart is as follows
+
+
+
+Pipeline description:
+
+1. The coordinates of each single-line text are detected by the text detection model and then sent to the recognition model to obtain the recognition results.
+2. The table structure and cell coordinate prediction model extracts the table structure and the coordinates of each cell from the image.
+3. The recognition result of each cell is assembled from the single-line coordinates, the single-line recognition results and the cell coordinates.
+4. The cell recognition results and the table structure together construct the HTML string of the table.
+
+## 2. Performance
+We evaluated the algorithm on the PubTabNet[1] eval dataset, and the performance is as follows
+
+
+|Method|[TEDS(Tree-Edit-Distance-based Similarity)](https://github.com/ibm-aur-nlp/PubTabNet/tree/master/src)|
+| --- | --- |
+| EDD[2] | 88.3 |
+| Ours | 93.32 |
+
+## 3. How to use
+
+### 3.1 Quick start
+
+```bash
+cd PaddleOCR/ppstructure
+
+# download models
+mkdir inference && cd inference
+# Download the detection model of the ultra-lightweight table English OCR model and unzip it
+wget https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_det_infer.tar && tar xf en_ppocr_mobile_v2.0_table_det_infer.tar
+# Download the recognition model of the ultra-lightweight table English OCR model and unzip it
+wget https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_rec_infer.tar && tar xf en_ppocr_mobile_v2.0_table_rec_infer.tar
+# Download the ultra-lightweight English table structure model and unzip it
+wget https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_structure_infer.tar && tar xf en_ppocr_mobile_v2.0_table_structure_infer.tar
+cd ..
+# run prediction
+python3 table/predict_table.py --det_model_dir=inference/en_ppocr_mobile_v2.0_table_det_infer --rec_model_dir=inference/en_ppocr_mobile_v2.0_table_rec_infer --table_model_dir=inference/en_ppocr_mobile_v2.0_table_structure_infer --image_dir=../doc/table/table.jpg --rec_char_dict_path=../ppocr/utils/ppocr_keys_v1.txt --table_char_dict_path=../ppocr/utils/dict/table_structure_dict.txt --rec_char_type=ch --det_limit_side_len=736 --det_limit_type=min --output ../output/table
+```
+After running, the Excel file of each image will be saved in the directory specified by the output field.
+
+Note: The above models are trained on the PubTabNet dataset and only support English scanned documents. To recognize other scenarios, you need to train your own models and then replace the three fields `det_model_dir`, `rec_model_dir`, `table_model_dir`.
+
+### 3.2 Train
+In this chapter, we only introduce the training of the table structure model. For training the [text detection](../../doc/doc_ch/detection.md) and [text recognition](../../doc/doc_ch/recognition.md) models, please refer to the corresponding documents.
+
+#### Data preparation
+The training data uses the public dataset PubTabNet ([paper](https://arxiv.org/abs/1911.10683), [download](https://github.com/ibm-aur-nlp/PubTabNet)). The PubTabNet dataset contains about 500,000 table images, together with annotations in HTML format.
+
+#### Start training
+*If you install the CPU version of Paddle, please set the `use_gpu` field in the configuration file to false*
+```shell
+# single-machine single-GPU training
+python3 tools/train.py -c configs/table/table_mv3.yml
+# single-machine multi-GPU training; set the GPU IDs to use with the --gpus parameter
+python3 -m paddle.distributed.launch --gpus '0,1,2,3' tools/train.py -c configs/table/table_mv3.yml
+```
+
+In the above command, `-c` selects training with the `configs/table/table_mv3.yml` configuration file. For a detailed explanation of the configuration file, please refer to [this link](../../doc/doc_ch/config.md).
+
+#### Resume training
+
+If the training program is interrupted and you want to resume training from the interrupted model, you can set `Global.checkpoints` to the model path to be loaded:
+```shell
+python3 tools/train.py -c configs/table/table_mv3.yml -o Global.checkpoints=./your/trained/model
+```
+
+**Note**: The priority of `Global.checkpoints` is higher than that of `Global.pretrain_weights`, that is, when the two parameters are specified at the same time, the model specified by `Global.checkpoints` will be loaded first. If the model path specified by `Global.checkpoints` is wrong, the one specified by `Global.pretrain_weights` will be loaded.
+
+
+### 3.3 Eval
+
+The table uses [TEDS(Tree-Edit-Distance-based Similarity)](https://github.com/ibm-aur-nlp/PubTabNet/tree/master/src) as the evaluation metric of the model. Before the model evaluation, the three models in the pipeline need to be exported as inference models (we have provided them), and the gt for evaluation needs to be prepared. An example of the gt is as follows:
+```json
+{"PMC4289340_004_00.png": [
+ ["", "", "", "", "", "", " | ", "", " | ", "", " | ", " ", "", "", "", "", " | ", "", " | ", "", " | ", " ", "", " ", "", ""],
+ [[1, 4, 29, 13], [137, 4, 161, 13], [215, 4, 236, 13], [1, 17, 30, 27], [137, 17, 147, 27], [215, 17, 225, 27]],
+ [["", "F", "e", "a", "t", "u", "r", "e", ""], ["", "G", "b", "3", " ", "+", ""], ["", "G", "b", "3", " ", "-", ""], ["", "P", "a", "t", "i", "e", "n", "t", "s", ""], ["6", "2"], ["4", "5"]]
+]}
+```
+In the gt json, the key is the image name and the value is the corresponding gt. The gt is a list composed of three items, which are
+1. HTML string list of the table structure
+2. The coordinates of each cell (not including cells whose text is empty)
+3. The text information in each cell (not including cells whose text is empty)
+
+After the preparation is completed, use the following command to evaluate. After the evaluation is completed, the TEDS metric will be printed.
+```bash
+cd PaddleOCR/ppstructure
+python3 table/eval_table.py --det_model_dir=path/to/det_model_dir --rec_model_dir=path/to/rec_model_dir --table_model_dir=path/to/table_model_dir --image_dir=../doc/table/1.png --rec_char_dict_path=../ppocr/utils/dict/table_dict.txt --table_char_dict_path=../ppocr/utils/dict/table_structure_dict.txt --rec_char_type=EN --det_limit_side_len=736 --det_limit_type=min --gt_path=path/to/gt.json
+```
+If the PubTabNet eval dataset is used, the output will be
+```bash
+teds: 93.32
+```
+
+### 3.4 Inference
+
+```bash
+cd PaddleOCR/ppstructure
+python3 table/predict_table.py --det_model_dir=path/to/det_model_dir --rec_model_dir=path/to/rec_model_dir --table_model_dir=path/to/table_model_dir --image_dir=../doc/table/1.png --rec_char_dict_path=../ppocr/utils/dict/table_dict.txt --table_char_dict_path=../ppocr/utils/dict/table_structure_dict.txt --rec_char_type=EN --det_limit_side_len=736 --det_limit_type=min --output ../output/table
+```
+
+Reference
+1. https://github.com/ibm-aur-nlp/PubTabNet
+2. https://arxiv.org/pdf/1911.10683
\ No newline at end of file
diff --git a/ppstructure/table/__init__.py b/ppstructure/table/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..1d11e265597c7c8e39098a228108da3bb954b892
--- /dev/null
+++ b/ppstructure/table/__init__.py
@@ -0,0 +1,13 @@
+# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/ppstructure/table/eval_table.py b/ppstructure/table/eval_table.py
new file mode 100755
index 0000000000000000000000000000000000000000..87b44d3d9792356ec1cdc65693392c288bf67448
--- /dev/null
+++ b/ppstructure/table/eval_table.py
@@ -0,0 +1,72 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import os
+import sys
+__dir__ = os.path.dirname(os.path.abspath(__file__))
+sys.path.append(__dir__)
+sys.path.append(os.path.abspath(os.path.join(__dir__, '../..')))
+
+import cv2
+import json
+from tqdm import tqdm
+from ppstructure.table.table_metric import TEDS
+from ppstructure.table.predict_table import TableSystem
+from ppstructure.utility import init_args
+from ppocr.utils.logging import get_logger
+
+logger = get_logger()
+
+
+def parse_args():
+ parser = init_args()
+ parser.add_argument("--gt_path", type=str)
+ return parser.parse_args()
+
+def main(gt_path, img_root, args):
+ teds = TEDS(n_jobs=16)
+
+ text_sys = TableSystem(args)
+ jsons_gt = json.load(open(gt_path)) # gt
+ pred_htmls = []
+ gt_htmls = []
+ for img_name in tqdm(jsons_gt):
+ # read image
+ img = cv2.imread(os.path.join(img_root,img_name))
+ pred_html = text_sys(img)
+ pred_htmls.append(pred_html)
+
+ gt_structures, gt_bboxes, gt_contents = jsons_gt[img_name]
+ gt_html, gt = get_gt_html(gt_structures, gt_contents)
+ gt_htmls.append(gt_html)
+ scores = teds.batch_evaluate_html(gt_htmls, pred_htmls)
+ logger.info('teds: {}'.format(sum(scores) / len(scores)))
+
+
+def get_gt_html(gt_structures, gt_contents):
+ end_html = []
+ td_index = 0
+ for tag in gt_structures:
+ if '</td>' in tag:
+ if gt_contents[td_index] != []:
+ end_html.extend(gt_contents[td_index])
+ end_html.append(tag)
+ td_index += 1
+ else:
+ end_html.append(tag)
+ return ''.join(end_html), end_html
+
+
+if __name__ == '__main__':
+ args = parse_args()
+ main(args.gt_path,args.image_dir, args)
diff --git a/ppstructure/table/matcher.py b/ppstructure/table/matcher.py
new file mode 100755
index 0000000000000000000000000000000000000000..c3b56384403f5fd92a8db4b4bb378a6d55e5a76c
--- /dev/null
+++ b/ppstructure/table/matcher.py
@@ -0,0 +1,192 @@
+import json
+def distance(box_1, box_2):
+ x1, y1, x2, y2 = box_1
+ x3, y3, x4, y4 = box_2
+ dis = abs(x3 - x1) + abs(y3 - y1) + abs(x4- x2) + abs(y4 - y2)
+ dis_2 = abs(x3 - x1) + abs(y3 - y1)
+ dis_3 = abs(x4- x2) + abs(y4 - y2)
+ return dis + min(dis_2, dis_3)
+
+def compute_iou(rec1, rec2):
+ """
+ computing IoU
+ :param rec1: (y0, x0, y1, x1), which reflects
+ (top, left, bottom, right)
+ :param rec2: (y0, x0, y1, x1)
+ :return: scala value of IoU
+ """
+ # computing area of each rectangles
+ S_rec1 = (rec1[2] - rec1[0]) * (rec1[3] - rec1[1])
+ S_rec2 = (rec2[2] - rec2[0]) * (rec2[3] - rec2[1])
+
+ # computing the sum_area
+ sum_area = S_rec1 + S_rec2
+
+ # find the each edge of intersect rectangle
+ left_line = max(rec1[1], rec2[1])
+ right_line = min(rec1[3], rec2[3])
+ top_line = max(rec1[0], rec2[0])
+ bottom_line = min(rec1[2], rec2[2])
+
+ # judge if there is an intersect
+ if left_line >= right_line or top_line >= bottom_line:
+ return 0.0
+ else:
+ intersect = (right_line - left_line) * (bottom_line - top_line)
+ return (intersect / (sum_area - intersect))*1.0
+
+
+
+def matcher_merge(ocr_bboxes, pred_bboxes):
+ all_dis = []
+ ious = []
+ matched = {}
+ for i, gt_box in enumerate(ocr_bboxes):
+ distances = []
+ for j, pred_box in enumerate(pred_bboxes):
+ # compute l1 distence and IOU between two boxes
+ distances.append((distance(gt_box, pred_box), 1. - compute_iou(gt_box, pred_box)))
+ sorted_distances = distances.copy()
+ # select nearest cell
+ sorted_distances = sorted(sorted_distances, key = lambda item: (item[1], item[0]))
+ if distances.index(sorted_distances[0]) not in matched.keys():
+ matched[distances.index(sorted_distances[0])] = [i]
+ else:
+ matched[distances.index(sorted_distances[0])].append(i)
+ return matched#, sum(ious) / len(ious)
+
+def complex_num(pred_bboxes):
+ complex_nums = []
+ for bbox in pred_bboxes:
+ distances = []
+ temp_ious = []
+ for pred_bbox in pred_bboxes:
+ if bbox != pred_bbox:
+ distances.append(distance(bbox, pred_bbox))
+ temp_ious.append(compute_iou(bbox, pred_bbox))
+ complex_nums.append(temp_ious[distances.index(min(distances))])
+ return sum(complex_nums) / len(complex_nums)
+
+def get_rows(pred_bboxes):
+ pre_bbox = pred_bboxes[0]
+ res = []
+ step = 0
+ for i in range(len(pred_bboxes)):
+ bbox = pred_bboxes[i]
+ if bbox[1] - pre_bbox[1] > 2 or bbox[0] - pre_bbox[0] < 0:
+ break
+ else:
+ res.append(bbox)
+ step += 1
+ for i in range(step):
+ pred_bboxes.pop(0)
+ return res, pred_bboxes
+def refine_rows(pred_bboxes): # fine-tune the boxes in one row so that they lie on the same horizontal line
+ ys_1 = []
+ ys_2 = []
+ for box in pred_bboxes:
+ ys_1.append(box[1])
+ ys_2.append(box[3])
+ min_y_1 = sum(ys_1) / len(ys_1)
+ min_y_2 = sum(ys_2) / len(ys_2)
+ re_boxes = []
+ for box in pred_bboxes:
+ box[1] = min_y_1
+ box[3] = min_y_2
+ re_boxes.append(box)
+ return re_boxes
+
+def matcher_refine_row(gt_bboxes, pred_bboxes):
+ before_refine_pred_bboxes = pred_bboxes.copy()
+ pred_bboxes = []
+ while(len(before_refine_pred_bboxes) != 0):
+ row_bboxes, before_refine_pred_bboxes = get_rows(before_refine_pred_bboxes)
+ print(row_bboxes)
+ pred_bboxes.extend(refine_rows(row_bboxes))
+ all_dis = []
+ ious = []
+ matched = {}
+ for i, gt_box in enumerate(gt_bboxes):
+ distances = []
+ #temp_ious = []
+ for j, pred_box in enumerate(pred_bboxes):
+ distances.append(distance(gt_box, pred_box))
+ #temp_ious.append(compute_iou(gt_box, pred_box))
+ #all_dis.append(min(distances))
+ #ious.append(temp_ious[distances.index(min(distances))])
+ if distances.index(min(distances)) not in matched.keys():
+ matched[distances.index(min(distances))] = [i]
+ else:
+ matched[distances.index(min(distances))].append(i)
+ return matched#, sum(ious) / len(ious)
+
+
+
+# first pick out one row of boxes, then match
+def matcher_structure_1(gt_bboxes, pred_bboxes_rows, pred_bboxes):
+ gt_box_index = 0
+ delete_gt_bboxes = gt_bboxes.copy()
+ match_bboxes_ready = []
+ matched = {}
+ while(len(delete_gt_bboxes) != 0):
+ row_bboxes, delete_gt_bboxes = get_rows(delete_gt_bboxes)
+ row_bboxes = sorted(row_bboxes, key = lambda key: key[0])
+ if len(pred_bboxes_rows) > 0:
+ match_bboxes_ready.extend(pred_bboxes_rows.pop(0))
+ print(row_bboxes)
+ for i, gt_box in enumerate(row_bboxes):
+ #print(gt_box)
+ pred_distances = []
+ distances = []
+ for pred_bbox in pred_bboxes:
+ pred_distances.append(distance(gt_box, pred_bbox))
+ for j, pred_box in enumerate(match_bboxes_ready):
+ distances.append(distance(gt_box, pred_box))
+ index = pred_distances.index(min(distances))
+ #print('index', index)
+ if index not in matched.keys():
+ matched[index] = [gt_box_index]
+ else:
+ matched[index].append(gt_box_index)
+ gt_box_index += 1
+ return matched
+
+def matcher_structure(gt_bboxes, pred_bboxes_rows, pred_bboxes):
+ '''
+ gt_bboxes: sorted gt boxes
+ pred_bboxes:
+ '''
+ pre_bbox = gt_bboxes[0]
+ matched = {}
+ match_bboxes_ready = []
+ match_bboxes_ready.extend(pred_bboxes_rows.pop(0))
+ for i, gt_box in enumerate(gt_bboxes):
+
+ pred_distances = []
+ for pred_bbox in pred_bboxes:
+ pred_distances.append(distance(gt_box, pred_bbox))
+ distances = []
+ gap_pre = gt_box[1] - pre_bbox[1]
+ gap_pre_1 = gt_box[0] - pre_bbox[2]
+ #print(gap_pre, len(pred_bboxes_rows))
+ if (gap_pre_1 < 0 and len(pred_bboxes_rows) > 0):
+ match_bboxes_ready.extend(pred_bboxes_rows.pop(0))
+ if len(pred_bboxes_rows) == 1:
+ match_bboxes_ready.extend(pred_bboxes_rows.pop(0))
+ if len(match_bboxes_ready) == 0 and len(pred_bboxes_rows) > 0:
+ match_bboxes_ready.extend(pred_bboxes_rows.pop(0))
+ if len(match_bboxes_ready) == 0 and len(pred_bboxes_rows) == 0:
+ break
+ #print(match_bboxes_ready)
+ for j, pred_box in enumerate(match_bboxes_ready):
+ distances.append(distance(gt_box, pred_box))
+ index = pred_distances.index(min(distances))
+ #print(gt_box, index)
+ #match_bboxes_ready.pop(distances.index(min(distances)))
+ print(gt_box, match_bboxes_ready[distances.index(min(distances))])
+ if index not in matched.keys():
+ matched[index] = [i]
+ else:
+ matched[index].append(i)
+ pre_bbox = gt_box
+ return matched
diff --git a/ppstructure/table/predict_structure.py b/ppstructure/table/predict_structure.py
new file mode 100755
index 0000000000000000000000000000000000000000..fc85327b3a446573259546d84c439f5f8e5b3ac7
--- /dev/null
+++ b/ppstructure/table/predict_structure.py
@@ -0,0 +1,136 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import os
+import sys
+
+__dir__ = os.path.dirname(os.path.abspath(__file__))
+sys.path.append(__dir__)
+sys.path.append(os.path.abspath(os.path.join(__dir__, '../..')))
+
+os.environ["FLAGS_allocator_strategy"] = 'auto_growth'
+
+import cv2
+import numpy as np
+import time
+
+import tools.infer.utility as utility
+from ppocr.data import create_operators, transform
+from ppocr.postprocess import build_post_process
+from ppocr.utils.logging import get_logger
+from ppocr.utils.utility import get_image_file_list, check_and_read_gif
+from ppstructure.utility import parse_args
+
+logger = get_logger()
+
+
+class TableStructurer(object):
+ def __init__(self, args):
+ pre_process_list = [{
+ 'ResizeTableImage': {
+ 'max_len': args.table_max_len
+ }
+ }, {
+ 'NormalizeImage': {
+ 'std': [0.229, 0.224, 0.225],
+ 'mean': [0.485, 0.456, 0.406],
+ 'scale': '1./255.',
+ 'order': 'hwc'
+ }
+ }, {
+ 'PaddingTableImage': None
+ }, {
+ 'ToCHWImage': None
+ }, {
+ 'KeepKeys': {
+ 'keep_keys': ['image']
+ }
+ }]
+ postprocess_params = {
+ 'name': 'TableLabelDecode',
+ "character_type": args.table_char_type,
+ "character_dict_path": args.table_char_dict_path,
+ }
+
+ self.preprocess_op = create_operators(pre_process_list)
+ self.postprocess_op = build_post_process(postprocess_params)
+ self.predictor, self.input_tensor, self.output_tensors, self.config = \
+ utility.create_predictor(args, 'table', logger)
+
+ def __call__(self, img):
+ ori_im = img.copy()
+ data = {'image': img}
+ data = transform(data, self.preprocess_op)
+ img = data[0]
+ if img is None:
+ return None, 0
+ img = np.expand_dims(img, axis=0)
+ img = img.copy()
+ starttime = time.time()
+
+ self.input_tensor.copy_from_cpu(img)
+ self.predictor.run()
+ outputs = []
+ for output_tensor in self.output_tensors:
+ output = output_tensor.copy_to_cpu()
+ outputs.append(output)
+
+ preds = {}
+ preds['structure_probs'] = outputs[1]
+ preds['loc_preds'] = outputs[0]
+
+ post_result = self.postprocess_op(preds)
+
+ structure_str_list = post_result['structure_str_list']
+ res_loc = post_result['res_loc']
+ imgh, imgw = ori_im.shape[0:2]
+ res_loc_final = []
+ for rno in range(len(res_loc[0])):
+ x0, y0, x1, y1 = res_loc[0][rno]
+ left = max(int(imgw * x0), 0)
+ top = max(int(imgh * y0), 0)
+ right = min(int(imgw * x1), imgw - 1)
+ bottom = min(int(imgh * y1), imgh - 1)
+ res_loc_final.append([left, top, right, bottom])
+
+ structure_str_list = structure_str_list[0][:-1]
+ structure_str_list = ['<html>', '<body>', '<table>'] + structure_str_list + ['</table>', '</body>', '</html>']
+
+ elapse = time.time() - starttime
+ return (structure_str_list, res_loc_final), elapse
+
+
+def main(args):
+ image_file_list = get_image_file_list(args.image_dir)
+ table_structurer = TableStructurer(args)
+ count = 0
+ total_time = 0
+ for image_file in image_file_list:
+ img, flag = check_and_read_gif(image_file)
+ if not flag:
+ img = cv2.imread(image_file)
+ if img is None:
+ logger.info("error in loading image:{}".format(image_file))
+ continue
+ structure_res, elapse = table_structurer(img)
+
+ logger.info("result: {}".format(structure_res))
+
+ if count > 0:
+ total_time += elapse
+ count += 1
+ logger.info("Predict time of {}: {}".format(image_file, elapse))
+
+
+if __name__ == "__main__":
+ main(parse_args())
diff --git a/ppstructure/table/predict_table.py b/ppstructure/table/predict_table.py
new file mode 100644
index 0000000000000000000000000000000000000000..352ae84de1f435f91258cf0ced4dce9345de1220
--- /dev/null
+++ b/ppstructure/table/predict_table.py
@@ -0,0 +1,221 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import sys
+import subprocess
+
+__dir__ = os.path.dirname(os.path.abspath(__file__))
+sys.path.append(__dir__)
+sys.path.append(os.path.abspath(os.path.join(__dir__, '..')))
+sys.path.append(os.path.abspath(os.path.join(__dir__, '../..')))
+
+os.environ["FLAGS_allocator_strategy"] = 'auto_growth'
+import cv2
+import copy
+import numpy as np
+import time
+import tools.infer.predict_rec as predict_rec
+import tools.infer.predict_det as predict_det
+from ppocr.utils.utility import get_image_file_list, check_and_read_gif
+from ppocr.utils.logging import get_logger
+from ppstructure.table.matcher import distance, compute_iou
+from ppstructure.utility import parse_args
+import ppstructure.table.predict_structure as predict_strture
+
+logger = get_logger()
+
+
+def expand(pix, det_box, shape):
+ x0, y0, x1, y1 = det_box
+ # print(shape)
+ h, w, c = shape
+ tmp_x0 = x0 - pix
+ tmp_x1 = x1 + pix
+ tmp_y0 = y0 - pix
+ tmp_y1 = y1 + pix
+ x0_ = tmp_x0 if tmp_x0 >= 0 else 0
+ x1_ = tmp_x1 if tmp_x1 <= w else w
+ y0_ = tmp_y0 if tmp_y0 >= 0 else 0
+ y1_ = tmp_y1 if tmp_y1 <= h else h
+ return x0_, y0_, x1_, y1_
+
+
+class TableSystem(object):
+ def __init__(self, args, text_detector=None, text_recognizer=None):
+ self.text_detector = predict_det.TextDetector(args) if text_detector is None else text_detector
+ self.text_recognizer = predict_rec.TextRecognizer(args) if text_recognizer is None else text_recognizer
+ self.table_structurer = predict_strture.TableStructurer(args)
+
+ def __call__(self, img):
+ ori_im = img.copy()
+ structure_res, elapse = self.table_structurer(copy.deepcopy(img))
+ dt_boxes, elapse = self.text_detector(copy.deepcopy(img))
+ dt_boxes = sorted_boxes(dt_boxes)
+
+ r_boxes = []
+ for box in dt_boxes:
+ x_min = box[:, 0].min() - 1
+ x_max = box[:, 0].max() + 1
+ y_min = box[:, 1].min() - 1
+ y_max = box[:, 1].max() + 1
+ box = [x_min, y_min, x_max, y_max]
+ r_boxes.append(box)
+ dt_boxes = np.array(r_boxes)
+
+ logger.debug("dt_boxes num : {}, elapse : {}".format(
+ len(dt_boxes), elapse))
+ if dt_boxes is None:
+ return None, None
+ img_crop_list = []
+
+ for i in range(len(dt_boxes)):
+ det_box = dt_boxes[i]
+ x0, y0, x1, y1 = expand(2, det_box, ori_im.shape)
+ text_rect = ori_im[int(y0):int(y1), int(x0):int(x1), :]
+ img_crop_list.append(text_rect)
+ rec_res, elapse = self.text_recognizer(img_crop_list)
+ logger.debug("rec_res num : {}, elapse : {}".format(
+ len(rec_res), elapse))
+
+ pred_html, pred = self.rebuild_table(structure_res, dt_boxes, rec_res)
+ return pred_html
+
+ def rebuild_table(self, structure_res, dt_boxes, rec_res):
+ pred_structures, pred_bboxes = structure_res
+ matched_index = self.match_result(dt_boxes, pred_bboxes)
+ pred_html, pred = self.get_pred_html(pred_structures, matched_index, rec_res)
+ return pred_html, pred
+
+ def match_result(self, dt_boxes, pred_bboxes):
+ matched = {}
+ for i, gt_box in enumerate(dt_boxes):
+ # gt_box = [np.min(gt_box[:, 0]), np.min(gt_box[:, 1]), np.max(gt_box[:, 0]), np.max(gt_box[:, 1])]
+ distances = []
+ for j, pred_box in enumerate(pred_bboxes):
+ distances.append(
+ (distance(gt_box, pred_box), 1. - compute_iou(gt_box, pred_box)))  # the L1 distance and 1 - IoU between each pair of cells
+ sorted_distances = distances.copy()
+ # pick the "nearest" cell according to distance and IoU
+ sorted_distances = sorted(sorted_distances, key=lambda item: (item[1], item[0]))
+ if distances.index(sorted_distances[0]) not in matched.keys():
+ matched[distances.index(sorted_distances[0])] = [i]
+ else:
+ matched[distances.index(sorted_distances[0])].append(i)
+ return matched
+
+ def get_pred_html(self, pred_structures, matched_index, ocr_contents):
+ end_html = []
+ td_index = 0
+ for tag in pred_structures:
+ if '</td>' in tag:
+ if td_index in matched_index.keys():
+ b_with = False
+ if '<b>' in ocr_contents[matched_index[td_index][0]] and len(matched_index[td_index]) > 1:
+ b_with = True
+ end_html.extend('<b>')
+ for i, td_index_index in enumerate(matched_index[td_index]):
+ content = ocr_contents[td_index_index][0]
+ if len(matched_index[td_index]) > 1:
+ if len(content) == 0:
+ continue
+ if content[0] == ' ':
+ content = content[1:]
+ if '<b>' in content:
+ content = content[3:]
+ if '</b>' in content:
+ content = content[:-4]
+ if len(content) == 0:
+ continue
+ if i != len(matched_index[td_index]) - 1 and ' ' != content[-1]:
+ content += ' '
+ end_html.extend(content)
+ if b_with:
+ end_html.extend('</b>')
+
+ end_html.append(tag)
+ td_index += 1
+ else:
+ end_html.append(tag)
+ return ''.join(end_html), end_html
+
+
+def sorted_boxes(dt_boxes):
+ """
+ Sort text boxes in order from top to bottom, left to right
+ args:
+ dt_boxes(array):detected text boxes with shape [4, 2]
+ return:
+ sorted boxes(array) with shape [4, 2]
+ """
+ num_boxes = dt_boxes.shape[0]
+ sorted_boxes = sorted(dt_boxes, key=lambda x: (x[0][1], x[0][0]))
+ _boxes = list(sorted_boxes)
+
+ for i in range(num_boxes - 1):
+ if abs(_boxes[i + 1][0][1] - _boxes[i][0][1]) < 10 and \
+ (_boxes[i + 1][0][0] < _boxes[i][0][0]):
+ tmp = _boxes[i]
+ _boxes[i] = _boxes[i + 1]
+ _boxes[i + 1] = tmp
+ return _boxes
+
+
+def to_excel(html_table, excel_path):
+ from tablepyxl import tablepyxl
+ tablepyxl.document_to_xl(html_table, excel_path)
+
+
+def main(args):
+ image_file_list = get_image_file_list(args.image_dir)
+ image_file_list = image_file_list[args.process_id::args.total_process_num]
+ os.makedirs(args.output, exist_ok=True)
+
+ text_sys = TableSystem(args)
+ img_num = len(image_file_list)
+ for i, image_file in enumerate(image_file_list):
+ logger.info("[{}/{}] {}".format(i, img_num, image_file))
+ img, flag = check_and_read_gif(image_file)
+ excel_path = os.path.join(args.output, os.path.basename(image_file).split('.')[0] + '.xlsx')
+ if not flag:
+ img = cv2.imread(image_file)
+ if img is None:
+ logger.error("error in loading image:{}".format(image_file))
+ continue
+ starttime = time.time()
+ pred_html = text_sys(img)
+
+ to_excel(pred_html, excel_path)
+ logger.info('excel saved to {}'.format(excel_path))
+ logger.info(pred_html)
+ elapse = time.time() - starttime
+ logger.info("Predict time : {:.3f}s".format(elapse))
+
+
+if __name__ == "__main__":
+ args = parse_args()
+ if args.use_mp:
+ p_list = []
+ total_process_num = args.total_process_num
+ for process_id in range(total_process_num):
+ cmd = [sys.executable, "-u"] + sys.argv + [
+ "--process_id={}".format(process_id),
+ "--use_mp={}".format(False)
+ ]
+ p = subprocess.Popen(cmd, stdout=sys.stdout, stderr=sys.stdout)
+ p_list.append(p)
+ for p in p_list:
+ p.wait()
+ else:
+ main(args)
diff --git a/ppstructure/table/table_metric/__init__.py b/ppstructure/table/table_metric/__init__.py
new file mode 100755
index 0000000000000000000000000000000000000000..de2d307430f68881ece1e41357d3b2f423e07ddd
--- /dev/null
+++ b/ppstructure/table/table_metric/__init__.py
@@ -0,0 +1,16 @@
+# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+__all__ = ['TEDS']
+from .table_metric import TEDS
\ No newline at end of file
diff --git a/ppstructure/table/table_metric/parallel.py b/ppstructure/table/table_metric/parallel.py
new file mode 100755
index 0000000000000000000000000000000000000000..f7326a1f506ca5fb7b3e97b0d077dc016e7eb7c7
--- /dev/null
+++ b/ppstructure/table/table_metric/parallel.py
@@ -0,0 +1,51 @@
+from tqdm import tqdm
+from concurrent.futures import ProcessPoolExecutor, as_completed
+
+
+def parallel_process(array, function, n_jobs=16, use_kwargs=False, front_num=0):
+ """
+ A parallel version of the map function with a progress bar.
+ Args:
+ array (array-like): An array to iterate over.
+ function (function): A python function to apply to the elements of array
+ n_jobs (int, default=16): The number of cores to use
+ use_kwargs (boolean, default=False): Whether to consider the elements of array as dictionaries of
+ keyword arguments to function
+ front_num (int, default=0): The number of iterations to run serially before kicking off the parallel job.
+ Useful for catching bugs
+ Returns:
+ [function(array[0]), function(array[1]), ...]
+ """
+ # We run the first few iterations serially to catch bugs
+ if front_num > 0:
+ front = [function(**a) if use_kwargs else function(a)
+ for a in array[:front_num]]
+ else:
+ front = []
+ # If we set n_jobs to 1, just run a list comprehension. This is useful for benchmarking and debugging.
+ if n_jobs == 1:
+ return front + [function(**a) if use_kwargs else function(a) for a in tqdm(array[front_num:])]
+ # Assemble the workers
+ with ProcessPoolExecutor(max_workers=n_jobs) as pool:
+ # Pass the elements of array into function
+ if use_kwargs:
+ futures = [pool.submit(function, **a) for a in array[front_num:]]
+ else:
+ futures = [pool.submit(function, a) for a in array[front_num:]]
+ kwargs = {
+ 'total': len(futures),
+ 'unit': 'it',
+ 'unit_scale': True,
+ 'leave': True
+ }
+ # Print out the progress as tasks complete
+ for f in tqdm(as_completed(futures), **kwargs):
+ pass
+ out = []
+ # Get the results from the futures.
+ for i, future in tqdm(enumerate(futures)):
+ try:
+ out.append(future.result())
+ except Exception as e:
+ out.append(e)
+ return front + out
diff --git a/ppstructure/table/table_metric/table_metric.py b/ppstructure/table/table_metric/table_metric.py
new file mode 100755
index 0000000000000000000000000000000000000000..9aca98ad785d4614a803fa5a277a6e4a27b3b078
--- /dev/null
+++ b/ppstructure/table/table_metric/table_metric.py
@@ -0,0 +1,247 @@
+# Copyright 2020 IBM
+# Author: peter.zhong@au1.ibm.com
+#
+# This is free software; you can redistribute it and/or modify
+# it under the terms of the Apache 2.0 License.
+#
+# This software is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# Apache 2.0 License for more details.
+
+import distance
+from apted import APTED, Config
+from apted.helpers import Tree
+from lxml import etree, html
+from collections import deque
+from .parallel import parallel_process
+from tqdm import tqdm
+
+
+class TableTree(Tree):
+ def __init__(self, tag, colspan=None, rowspan=None, content=None, *children):
+ self.tag = tag
+ self.colspan = colspan
+ self.rowspan = rowspan
+ self.content = content
+ self.children = list(children)
+
+ def bracket(self):
+ """Show tree using brackets notation"""
+ if self.tag == 'td':
+ result = '"tag": %s, "colspan": %d, "rowspan": %d, "text": %s' % \
+ (self.tag, self.colspan, self.rowspan, self.content)
+ else:
+ result = '"tag": %s' % self.tag
+ for child in self.children:
+ result += child.bracket()
+ return "{{{}}}".format(result)
+
+
+class CustomConfig(Config):
+ @staticmethod
+ def maximum(*sequences):
+ """Get maximum possible value
+ """
+ return max(map(len, sequences))
+
+ def normalized_distance(self, *sequences):
+ """Get distance from 0 to 1
+ """
+ return float(distance.levenshtein(*sequences)) / self.maximum(*sequences)
+
+ def rename(self, node1, node2):
+ """Compares attributes of trees"""
+ #print(node1.tag)
+ if (node1.tag != node2.tag) or (node1.colspan != node2.colspan) or (node1.rowspan != node2.rowspan):
+ return 1.
+ if node1.tag == 'td':
+ if node1.content or node2.content:
+ #print(node1.content, )
+ return self.normalized_distance(node1.content, node2.content)
+ return 0.
+
+
+
+class CustomConfig_del_short(Config):
+ @staticmethod
+ def maximum(*sequences):
+ """Get maximum possible value
+ """
+ return max(map(len, sequences))
+
+ def normalized_distance(self, *sequences):
+ """Get distance from 0 to 1
+ """
+ return float(distance.levenshtein(*sequences)) / self.maximum(*sequences)
+
+ def rename(self, node1, node2):
+ """Compares attributes of trees"""
+ if (node1.tag != node2.tag) or (node1.colspan != node2.colspan) or (node1.rowspan != node2.rowspan):
+ return 1.
+ if node1.tag == 'td':
+ if node1.content or node2.content:
+ #print('before')
+ #print(node1.content, node2.content)
+ #print('after')
+ node1_content = node1.content
+ node2_content = node2.content
+ if len(node1_content) < 3:
+ node1_content = ['####']
+ if len(node2_content) < 3:
+ node2_content = ['####']
+ return self.normalized_distance(node1_content, node2_content)
+ return 0.
+
+class CustomConfig_del_block(Config):
+ @staticmethod
+ def maximum(*sequences):
+ """Get maximum possible value
+ """
+ return max(map(len, sequences))
+
+ def normalized_distance(self, *sequences):
+ """Get distance from 0 to 1
+ """
+ return float(distance.levenshtein(*sequences)) / self.maximum(*sequences)
+
+ def rename(self, node1, node2):
+ """Compares attributes of trees"""
+ if (node1.tag != node2.tag) or (node1.colspan != node2.colspan) or (node1.rowspan != node2.rowspan):
+ return 1.
+ if node1.tag == 'td':
+ if node1.content or node2.content:
+
+ node1_content = node1.content
+ node2_content = node2.content
+ while ' ' in node1_content:
+ print(node1_content.index(' '))
+ node1_content.pop(node1_content.index(' '))
+ while ' ' in node2_content:
+ print(node2_content.index(' '))
+ node2_content.pop(node2_content.index(' '))
+ return self.normalized_distance(node1_content, node2_content)
+ return 0.
+
+class TEDS(object):
+ ''' Tree Edit Distance based Similarity
+ '''
+
+ def __init__(self, structure_only=False, n_jobs=1, ignore_nodes=None):
+ assert isinstance(n_jobs, int) and (
+ n_jobs >= 1), 'n_jobs must be an integer greater than or equal to 1'
+ self.structure_only = structure_only
+ self.n_jobs = n_jobs
+ self.ignore_nodes = ignore_nodes
+ self.__tokens__ = []
+
+ def tokenize(self, node):
+ ''' Tokenizes table cells
+ '''
+ self.__tokens__.append('<%s>' % node.tag)
+ if node.text is not None:
+ self.__tokens__ += list(node.text)
+ for n in node.getchildren():
+ self.tokenize(n)
+ if node.tag != 'unk':
+ self.__tokens__.append('</%s>' % node.tag)
+ if node.tag != 'td' and node.tail is not None:
+ self.__tokens__ += list(node.tail)
+
+ def load_html_tree(self, node, parent=None):
+ ''' Converts HTML tree to the format required by apted
+ '''
+ global __tokens__
+ if node.tag == 'td':
+ if self.structure_only:
+ cell = []
+ else:
+ self.__tokens__ = []
+ self.tokenize(node)
+ cell = self.__tokens__[1:-1].copy()
+ new_node = TableTree(node.tag,
+ int(node.attrib.get('colspan', '1')),
+ int(node.attrib.get('rowspan', '1')),
+ cell, *deque())
+ else:
+ new_node = TableTree(node.tag, None, None, None, *deque())
+ if parent is not None:
+ parent.children.append(new_node)
+ if node.tag != 'td':
+ for n in node.getchildren():
+ self.load_html_tree(n, new_node)
+ if parent is None:
+ return new_node
+
+ def evaluate(self, pred, true):
+ ''' Computes TEDS score between the prediction and the ground truth of a
+ given sample
+ '''
+ if (not pred) or (not true):
+ return 0.0
+ parser = html.HTMLParser(remove_comments=True, encoding='utf-8')
+ pred = html.fromstring(pred, parser=parser)
+ true = html.fromstring(true, parser=parser)
+ if pred.xpath('body/table') and true.xpath('body/table'):
+ pred = pred.xpath('body/table')[0]
+ true = true.xpath('body/table')[0]
+ if self.ignore_nodes:
+ etree.strip_tags(pred, *self.ignore_nodes)
+ etree.strip_tags(true, *self.ignore_nodes)
+ n_nodes_pred = len(pred.xpath(".//*"))
+ n_nodes_true = len(true.xpath(".//*"))
+ n_nodes = max(n_nodes_pred, n_nodes_true)
+ tree_pred = self.load_html_tree(pred)
+ tree_true = self.load_html_tree(true)
+ distance = APTED(tree_pred, tree_true,
+ CustomConfig()).compute_edit_distance()
+ return 1.0 - (float(distance) / n_nodes)
+ else:
+ return 0.0
+
+ def batch_evaluate(self, pred_json, true_json):
+ ''' Computes TEDS score between the prediction and the ground truth of
+ a batch of samples
+ @params pred_json: {'FILENAME': 'HTML CODE', ...}
+ @params true_json: {'FILENAME': {'html': 'HTML CODE'}, ...}
+ @output: {'FILENAME': 'TEDS SCORE', ...}
+ '''
+ samples = true_json.keys()
+ if self.n_jobs == 1:
+ scores = [self.evaluate(pred_json.get(
+ filename, ''), true_json[filename]['html']) for filename in tqdm(samples)]
+ else:
+ inputs = [{'pred': pred_json.get(
+ filename, ''), 'true': true_json[filename]['html']} for filename in samples]
+ scores = parallel_process(
+ inputs, self.evaluate, use_kwargs=True, n_jobs=self.n_jobs, front_num=1)
+ scores = dict(zip(samples, scores))
+ return scores
+
+ def batch_evaluate_html(self, pred_htmls, true_htmls):
+ ''' Computes TEDS score between the prediction and the ground truth of
+ a batch of samples
+ '''
+ if self.n_jobs == 1:
+ scores = [self.evaluate(pred_html, true_html) for (
+ pred_html, true_html) in zip(pred_htmls, true_htmls)]
+ else:
+ inputs = [{"pred": pred_html, "true": true_html} for(
+ pred_html, true_html) in zip(pred_htmls, true_htmls)]
+
+ scores = parallel_process(
+ inputs, self.evaluate, use_kwargs=True, n_jobs=self.n_jobs, front_num=1)
+ return scores
+
+
+if __name__ == '__main__':
+ import json
+ import pprint
+ with open('sample_pred.json') as fp:
+ pred_json = json.load(fp)
+ with open('sample_gt.json') as fp:
+ true_json = json.load(fp)
+ teds = TEDS(n_jobs=4)
+ scores = teds.batch_evaluate(pred_json, true_json)
+ pp = pprint.PrettyPrinter()
+ pp.pprint(scores)
diff --git a/ppstructure/table/tablepyxl/__init__.py b/ppstructure/table/tablepyxl/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..dc0085071cf4497b01fc648e7c38f2e8d9d173d0
--- /dev/null
+++ b/ppstructure/table/tablepyxl/__init__.py
@@ -0,0 +1,13 @@
+# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
\ No newline at end of file
diff --git a/ppstructure/table/tablepyxl/style.py b/ppstructure/table/tablepyxl/style.py
new file mode 100644
index 0000000000000000000000000000000000000000..ebd794b1b47d7f9e4f9294dde7330f592d613656
--- /dev/null
+++ b/ppstructure/table/tablepyxl/style.py
@@ -0,0 +1,283 @@
+# This is where we handle translating css styles into openpyxl styles
+# and cascading those from parent to child in the dom.
+
+from openpyxl.cell import cell
+from openpyxl.styles import Font, Alignment, PatternFill, NamedStyle, Border, Side, Color
+from openpyxl.styles.fills import FILL_SOLID
+from openpyxl.styles.numbers import FORMAT_CURRENCY_USD_SIMPLE, FORMAT_PERCENTAGE
+from openpyxl.styles.colors import BLACK
+
+FORMAT_DATE_MMDDYYYY = 'mm/dd/yyyy'
+
+
+def colormap(color):
+ """
+ Convenience for looking up known colors
+ """
+ cmap = {'black': BLACK}
+ return cmap.get(color, color)
+
+
+def style_string_to_dict(style):
+ """
+ Convert css style string to a python dictionary
+ """
+ def clean_split(string, delim):
+ return (s.strip() for s in string.split(delim))
+ styles = [clean_split(s, ":") for s in style.split(";") if ":" in s]
+ return dict(styles)
+
+
+def get_side(style, name):
+ return {'border_style': style.get('border-{}-style'.format(name)),
+ 'color': colormap(style.get('border-{}-color'.format(name)))}
+
+known_styles = {}
+
+
+def style_dict_to_named_style(style_dict, number_format=None):
+ """
+ Change css style (stored in a python dictionary) to openpyxl NamedStyle
+ """
+
+ style_and_format_string = str({
+ 'style_dict': style_dict,
+ 'parent': style_dict.parent,
+ 'number_format': number_format,
+ })
+
+ if style_and_format_string not in known_styles:
+ # Font
+ font = Font(bold=style_dict.get('font-weight') == 'bold',
+ color=style_dict.get_color('color', None),
+ size=style_dict.get('font-size'))
+
+ # Alignment
+ alignment = Alignment(horizontal=style_dict.get('text-align', 'general'),
+ vertical=style_dict.get('vertical-align'),
+ wrap_text=style_dict.get('white-space', 'nowrap') == 'normal')
+
+ # Fill
+ bg_color = style_dict.get_color('background-color')
+ fg_color = style_dict.get_color('foreground-color', Color())
+ fill_type = style_dict.get('fill-type')
+ if bg_color and bg_color != 'transparent':
+ fill = PatternFill(fill_type=fill_type or FILL_SOLID,
+ start_color=bg_color,
+ end_color=fg_color)
+ else:
+ fill = PatternFill()
+
+ # Border
+ border = Border(left=Side(**get_side(style_dict, 'left')),
+ right=Side(**get_side(style_dict, 'right')),
+ top=Side(**get_side(style_dict, 'top')),
+ bottom=Side(**get_side(style_dict, 'bottom')),
+ diagonal=Side(**get_side(style_dict, 'diagonal')),
+ diagonal_direction=None,
+ outline=Side(**get_side(style_dict, 'outline')),
+ vertical=None,
+ horizontal=None)
+
+ name = 'Style {}'.format(len(known_styles) + 1)
+
+ pyxl_style = NamedStyle(name=name, font=font, fill=fill, alignment=alignment, border=border,
+ number_format=number_format)
+
+ known_styles[style_and_format_string] = pyxl_style
+
+ return known_styles[style_and_format_string]
+
+
+class StyleDict(dict):
+ """
+ It's like a dictionary, but it looks for items in the parent dictionary
+ """
+ def __init__(self, *args, **kwargs):
+ self.parent = kwargs.pop('parent', None)
+ super(StyleDict, self).__init__(*args, **kwargs)
+
+ def __getitem__(self, item):
+ if item in self:
+ return super(StyleDict, self).__getitem__(item)
+ elif self.parent:
+ return self.parent[item]
+ else:
+ raise KeyError('{} not found'.format(item))
+
+ def __hash__(self):
+ return hash(tuple([(k, self.get(k)) for k in self._keys()]))
+
+ # Yielding the keys avoids creating unnecessary data structures
+ # and happily works with both python2 and python3 where the
+ # .keys() method is a dictionary_view in python3 and a list in python2.
+ def _keys(self):
+ yielded = set()
+ for k in self.keys():
+ yielded.add(k)
+ yield k
+ if self.parent:
+ for k in self.parent._keys():
+ if k not in yielded:
+ yielded.add(k)
+ yield k
+
+ def get(self, k, d=None):
+ try:
+ return self[k]
+ except KeyError:
+ return d
+
+ def get_color(self, k, d=None):
+ """
+ Strip leading # off colors if necessary
+ """
+ color = self.get(k, d)
+ if hasattr(color, 'startswith') and color.startswith('#'):
+ color = color[1:]
+ if len(color) == 3: # Premailers reduces colors like #00ff00 to #0f0, openpyxl doesn't like that
+ color = ''.join(2 * c for c in color)
+ return color
+
+
+class Element(object):
+ """
+ Our base class for representing an html element along with a cascading style.
+ The element is created along with a parent so that the StyleDict that we store
+ can point to the parent's StyleDict.
+ """
+ def __init__(self, element, parent=None):
+ self.element = element
+ self.number_format = None
+ parent_style = parent.style_dict if parent else None
+ self.style_dict = StyleDict(style_string_to_dict(element.get('style', '')), parent=parent_style)
+ self._style_cache = None
+
+ def style(self):
+ """
+ Turn the css styles for this element into an openpyxl NamedStyle.
+ """
+ if not self._style_cache:
+ self._style_cache = style_dict_to_named_style(self.style_dict, number_format=self.number_format)
+ return self._style_cache
+
+ def get_dimension(self, dimension_key):
+ """
+ Extracts the dimension from the style dict of the Element and returns it as a float.
+ """
+ dimension = self.style_dict.get(dimension_key)
+ if dimension:
+ if dimension[-2:] in ['px', 'em', 'pt', 'in', 'cm']:
+ dimension = dimension[:-2]
+ dimension = float(dimension)
+ return dimension
+
+
+class Table(Element):
+ """
+ The concrete implementations of Elements are semantically named for the types of elements we are interested in.
+ This defines a very concrete tree structure for html tables that we expect to deal with. I prefer this compared to
+ allowing Element to have an arbitrary number of children and dealing with an abstract element tree.
+ """
+ def __init__(self, table):
+ """
+ takes an html table object (from lxml)
+ """
+ super(Table, self).__init__(table)
+ table_head = table.find('thead')
+ self.head = TableHead(table_head, parent=self) if table_head is not None else None
+ table_body = table.find('tbody')
+ self.body = TableBody(table_body if table_body is not None else table, parent=self)
+
+
+class TableHead(Element):
+ """
+ This class maps to the `<thead>` element of the html table.
+ """
+ def __init__(self, head, parent=None):
+ super(TableHead, self).__init__(head, parent=parent)
+ self.rows = [TableRow(tr, parent=self) for tr in head.findall('tr')]
+
+
+class TableBody(Element):
+ """
+ This class maps to the `<tbody>` element of the html table.
+ """
+ def __init__(self, body, parent=None):
+ super(TableBody, self).__init__(body, parent=parent)
+ self.rows = [TableRow(tr, parent=self) for tr in body.findall('tr')]
+
+
+class TableRow(Element):
+ """
+ This class maps to the `<tr>` element of the html table.
+ """
+ def __init__(self, tr, parent=None):
+ super(TableRow, self).__init__(tr, parent=parent)
+ self.cells = [TableCell(cell, parent=self) for cell in tr.findall('th') + tr.findall('td')]
+
+
+def element_to_string(el):
+ return _element_to_string(el).strip()
+
+
+def _element_to_string(el):
+ string = ''
+
+ for x in el.iterchildren():
+ string += '\n' + _element_to_string(x)
+
+ text = el.text.strip() if el.text else ''
+ tail = el.tail.strip() if el.tail else ''
+
+ return text + string + '\n' + tail
+
+
+class TableCell(Element):
+ """
+ This class maps to the `<td>` or `<th>` element of the html table.
+ """
+ CELL_TYPES = {'TYPE_STRING', 'TYPE_FORMULA', 'TYPE_NUMERIC', 'TYPE_BOOL', 'TYPE_CURRENCY', 'TYPE_PERCENTAGE',
+ 'TYPE_NULL', 'TYPE_INLINE', 'TYPE_ERROR', 'TYPE_FORMULA_CACHE_STRING', 'TYPE_INTEGER'}
+
+ def __init__(self, cell, parent=None):
+ super(TableCell, self).__init__(cell, parent=parent)
+ self.value = element_to_string(cell)
+ self.number_format = self.get_number_format()
+
+ def data_type(self):
+ cell_types = self.CELL_TYPES & set(self.element.get('class', '').split())
+ if cell_types:
+ if 'TYPE_FORMULA' in cell_types:
+ # Make sure TYPE_FORMULA takes precedence over the other classes in the set.
+ cell_type = 'TYPE_FORMULA'
+ elif cell_types & {'TYPE_CURRENCY', 'TYPE_INTEGER', 'TYPE_PERCENTAGE'}:
+ cell_type = 'TYPE_NUMERIC'
+ else:
+ cell_type = cell_types.pop()
+ else:
+ cell_type = 'TYPE_STRING'
+ return getattr(cell, cell_type)
+
+ def get_number_format(self):
+ if 'TYPE_CURRENCY' in self.element.get('class', '').split():
+ return FORMAT_CURRENCY_USD_SIMPLE
+ if 'TYPE_INTEGER' in self.element.get('class', '').split():
+ return '#,##0'
+ if 'TYPE_PERCENTAGE' in self.element.get('class', '').split():
+ return FORMAT_PERCENTAGE
+ if 'TYPE_DATE' in self.element.get('class', '').split():
+ return FORMAT_DATE_MMDDYYYY
+ if self.data_type() == cell.TYPE_NUMERIC:
+ try:
+ int(self.value)
+ except ValueError:
+ return '#,##0.##'
+ else:
+ return '#,##0'
+
+ def format(self, cell):
+ cell.style = self.style()
+ data_type = self.data_type()
+ if data_type:
+ cell.data_type = data_type
\ No newline at end of file
diff --git a/ppstructure/table/tablepyxl/tablepyxl.py b/ppstructure/table/tablepyxl/tablepyxl.py
new file mode 100644
index 0000000000000000000000000000000000000000..ba3cc0fc499fccd93ffe3993a99296bc6603ed8a
--- /dev/null
+++ b/ppstructure/table/tablepyxl/tablepyxl.py
@@ -0,0 +1,118 @@
+# Do imports like python3 so our package works for 2 and 3
+from __future__ import absolute_import
+
+from lxml import html
+from openpyxl import Workbook
+from openpyxl.utils import get_column_letter
+from premailer import Premailer
+from tablepyxl.style import Table
+
+
+def string_to_int(s):
+ if s.isdigit():
+ return int(s)
+ return 0
+
+
+def get_Tables(doc):
+ tree = html.fromstring(doc)
+ comments = tree.xpath('//comment()')
+ for comment in comments:
+ comment.drop_tag()
+ return [Table(table) for table in tree.xpath('//table')]
+
+
+def write_rows(worksheet, elem, row, column=1):
+ """
+ Writes every tr child element of elem to a row in the worksheet
+ returns the next row after all rows are written
+ """
+ from openpyxl.cell.cell import MergedCell
+
+ initial_column = column
+ for table_row in elem.rows:
+ for table_cell in table_row.cells:
+ cell = worksheet.cell(row=row, column=column)
+ while isinstance(cell, MergedCell):
+ column += 1
+ cell = worksheet.cell(row=row, column=column)
+
+ colspan = string_to_int(table_cell.element.get("colspan", "1"))
+ rowspan = string_to_int(table_cell.element.get("rowspan", "1"))
+ if rowspan > 1 or colspan > 1:
+ worksheet.merge_cells(start_row=row, start_column=column,
+ end_row=row + rowspan - 1, end_column=column + colspan - 1)
+
+ cell.value = table_cell.value
+ table_cell.format(cell)
+ min_width = table_cell.get_dimension('min-width')
+ max_width = table_cell.get_dimension('max-width')
+
+ if colspan == 1:
+ # Initially, when iterating for the first time through the loop, the width of all the cells is None.
+ # As we start filling in contents, the initial width of the cell (which can be retrieved by:
+ # worksheet.column_dimensions[get_column_letter(column)].width) is equal to the width of the previous
+ # cell in the same column (i.e. width of A2 = width of A1)
+ width = max(worksheet.column_dimensions[get_column_letter(column)].width or 0, len(table_cell.value) + 2)
+ if max_width and width > max_width:
+ width = max_width
+ elif min_width and width < min_width:
+ width = min_width
+ worksheet.column_dimensions[get_column_letter(column)].width = width
+ column += colspan
+ row += 1
+ column = initial_column
+ return row
+
+
+def table_to_sheet(table, wb):
+ """
+ Takes a table and workbook and writes the table to a new sheet.
+ The sheet title will be the same as the table attribute name.
+ """
+ ws = wb.create_sheet(title=table.element.get('name'))
+ insert_table(table, ws, 1, 1)
+
+
+def document_to_workbook(doc, wb=None, base_url=None):
+ """
+ Takes a string representation of an html document and writes one sheet for
+ every table in the document.
+ The workbook is returned.
+ """
+ if not wb:
+ wb = Workbook()
+ wb.remove(wb.active)
+
+ inline_styles_doc = Premailer(doc, base_url=base_url, remove_classes=False).transform()
+ tables = get_Tables(inline_styles_doc)
+
+ for table in tables:
+ table_to_sheet(table, wb)
+
+ return wb
+
+
+def document_to_xl(doc, filename, base_url=None):
+ """
+ Takes a string representation of an html document and writes one sheet for
+ every table in the document. The workbook is written out to a file called filename.
+ """
+ wb = document_to_workbook(doc, base_url=base_url)
+ wb.save(filename)
+
+
+def insert_table(table, worksheet, column, row):
+ if table.head:
+ row = write_rows(worksheet, table.head, row, column)
+ if table.body:
+ row = write_rows(worksheet, table.body, row, column)
+
+
+def insert_table_at_cell(table, cell):
+ """
+ Inserts a table at the location of an openpyxl Cell object.
+ """
+ ws = cell.parent
+ column, row = cell.column, cell.row
+ insert_table(table, ws, column, row)
\ No newline at end of file
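The docstrings above describe the intended entry points. As a minimal usage sketch — assuming ppstructure/table is on PYTHONPATH so that `tablepyxl` is importable, and with an illustrative HTML string and output filename — converting a document to a workbook looks roughly like this:

```python
# Minimal sketch: one worksheet per <table>; names and data are illustrative.
from tablepyxl.tablepyxl import document_to_xl

html_doc = """
<table name="prices">
  <thead><tr><th>item</th><th>price</th></tr></thead>
  <tbody>
    <tr><td>apple</td><td class="TYPE_CURRENCY">1.50</td></tr>
    <tr><td>pear</td><td class="TYPE_CURRENCY">2.00</td></tr>
  </tbody>
</table>
"""

# Writes a sheet named "prices" into prices.xlsx; the TYPE_CURRENCY class is
# mapped to a currency number format by the style module above.
document_to_xl(html_doc, "prices.xlsx")
```

Because `remove_classes=False` is passed to Premailer, the `TYPE_*` classes survive CSS inlining and still reach `get_number_format` in the style module.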
diff --git a/ppstructure/utility.py b/ppstructure/utility.py
new file mode 100644
index 0000000000000000000000000000000000000000..7d9fa76d0ada58e363243c114519d001de3fbf2a
--- /dev/null
+++ b/ppstructure/utility.py
@@ -0,0 +1,52 @@
+# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from PIL import Image
+import numpy as np
+from tools.infer.utility import draw_ocr_box_txt, init_args as infer_args
+
+
+def init_args():
+ parser = infer_args()
+
+ # params for output
+ parser.add_argument("--output", type=str, default='./output/table')
+ # params for table structure
+ parser.add_argument("--table_max_len", type=int, default=488)
+ parser.add_argument("--table_model_dir", type=str)
+ parser.add_argument("--table_char_type", type=str, default='en')
+ parser.add_argument("--table_char_dict_path", type=str, default="../ppocr/utils/dict/table_structure_dict.txt")
+ parser.add_argument("--layout_path_model", type=str, default="lp://PubLayNet/ppyolov2_r50vd_dcn_365e_publaynet/config")
+ return parser
+
+
+def parse_args():
+ parser = init_args()
+ return parser.parse_args()
+
+
+def draw_structure_result(image, result, font_path):
+ if isinstance(image, np.ndarray):
+ image = Image.fromarray(image)
+ boxes, txts, scores = [], [], []
+ for region in result:
+ if region['type'] == 'Table':
+ pass
+ else:
+ for box, rec_res in zip(region['res'][0], region['res'][1]):
+ boxes.append(np.array(box).reshape(-1, 2))
+ txts.append(rec_res[0])
+ scores.append(rec_res[1])
+ im_show = draw_ocr_box_txt(image, boxes, txts, scores, font_path=font_path, drop_score=0)
+ return im_show
\ No newline at end of file
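`draw_structure_result` expects each region dict to carry a `type` and, for non-table regions, a `res` pair of (boxes, recognition results). A hedged sketch of a call — the region content below is made up to show the expected shape, and it assumes the script runs from the PaddleOCR repo root so that `ppstructure.utility` and its `tools.infer.utility` import resolve:

```python
# Illustrative input only; real region dicts are produced by the ppstructure pipeline.
from PIL import Image
from ppstructure.utility import draw_structure_result

image = Image.new('RGB', (200, 100), 'white')
result = [{
    'type': 'Text',
    'res': (
        [[[10, 10], [120, 10], [120, 40], [10, 40]]],  # one 4-point box
        [('hello world', 0.98)],                       # one (text, score) pair
    ),
}]

# Any TTF works; PaddleOCR ships one under doc/fonts/.
im_show = draw_structure_result(image, result, font_path='./doc/fonts/simfang.ttf')
```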
diff --git a/requirements.txt b/requirements.txt
index d96b3a9d84d30b221e3e0d4b302d634c1681272f..351d409092a1f387b720c3ff2d43889170f320a7 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -7,4 +7,4 @@ tqdm
numpy
visualdl
python-Levenshtein
-opencv-contrib-python==4.2.0.32
\ No newline at end of file
+opencv-contrib-python==4.4.0.46
\ No newline at end of file
diff --git a/setup.py b/setup.py
index a1ddbbb6d6d0c2657bb699a72bde75ef07ab3a94..7d4d871d89defcf832910c60f18b094f10ba11db 100644
--- a/setup.py
+++ b/setup.py
@@ -14,6 +14,7 @@
from setuptools import setup
from io import open
+from paddleocr import VERSION
with open('requirements.txt', encoding="utf-8-sig") as f:
requirements = f.readlines()
@@ -32,7 +33,7 @@ setup(
package_dir={'paddleocr': ''},
include_package_data=True,
entry_points={"console_scripts": ["paddleocr= paddleocr.paddleocr:main"]},
- version='2.0.6',
+ version=VERSION,
install_requires=requirements,
license='Apache License 2.0',
description='Awesome OCR toolkits based on PaddlePaddle (8.6M ultra-lightweight pre-trained model, support training and deployment among server, mobile, embeded and IoT devices',
diff --git a/tests/ocr_det_params.txt b/tests/ocr_det_params.txt
new file mode 100644
index 0000000000000000000000000000000000000000..6aff66c6aa8591c9f48c81cf857809f956a3cda2
--- /dev/null
+++ b/tests/ocr_det_params.txt
@@ -0,0 +1,52 @@
+===========================train_params===========================
+model_name:ocr_det
+python:python3.7
+gpu_list:0|0,1
+Global.use_gpu:True|True
+Global.auto_cast:null
+Global.epoch_num:lite_train_infer=2|whole_train_infer=300
+Global.save_model_dir:./output/
+Train.loader.batch_size_per_card:lite_train_infer=2|whole_train_infer=4
+Global.pretrained_model:null
+train_model_name:latest
+train_infer_img_dir:./train_data/icdar2015/text_localization/ch4_test_images/
+null:null
+##
+trainer:norm_train|pact_train
+norm_train:tools/train.py -c configs/det/det_mv3_db.yml -o Global.pretrained_model=./pretrain_models/MobileNetV3_large_x0_5_pretrained
+pact_train:deploy/slim/quantization/quant.py -c configs/det/det_mv3_db.yml -o
+fpgm_train:null
+distill_train:null
+null:null
+null:null
+##
+===========================eval_params===========================
+eval:tools/eval.py -c configs/det/det_mv3_db.yml -o
+null:null
+##
+===========================infer_params===========================
+Global.save_inference_dir:./output/
+Global.pretrained_model:
+norm_export:tools/export_model.py -c configs/det/det_mv3_db.yml -o
+quant_export:deploy/slim/quantization/export_model.py -c configs/det/det_mv3_db.yml -o
+fpgm_export:deploy/slim/prune/export_prune_model.py
+distill_export:null
+export1:null
+export2:null
+##
+infer_model:./inference/ch_ppocr_mobile_v2.0_det_infer/
+infer_export:null
+infer_quant:False
+inference:tools/infer/predict_det.py
+--use_gpu:True|False
+--enable_mkldnn:True|False
+--cpu_threads:1|6
+--rec_batch_num:1
+--use_tensorrt:False|True
+--precision:fp32|fp16|int8
+--det_model_dir:
+--image_dir:./inference/ch_det_data_50/all-sum-510/
+--save_log_path:null
+--benchmark:True
+null:null
+
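Each non-comment line in this file (and in ocr_rec_params.txt below) is a `key:value` entry; `|` separates the alternatives that test.sh sweeps over, and entries such as `lite_train_infer=2|whole_train_infer=300` carry per-mode values. A simplified Python sketch of that parsing, mirroring the `func_parser_*` shell functions defined later in tests/test.sh (illustrative only, not part of the repo):

```python
# Illustrative re-implementation of the test.sh line parsers, for readability only.
def parser_key(line):
    return line.split(':', 1)[0]

def parser_value(line):
    return line.split(':', 1)[1]

def parser_params(line, mode):
    # Pick the value that matches the active MODE from "key:mode1=v1|mode2=v2".
    value = parser_value(line)
    for item in value.split('|'):
        if '=' in item:
            m, v = item.split('=', 1)
            if m == mode:
                return v
    return value

line = "Global.epoch_num:lite_train_infer=2|whole_train_infer=300"
print(parser_key(line))                          # Global.epoch_num
print(parser_params(line, "lite_train_infer"))   # 2
```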
diff --git a/tests/ocr_rec_params.txt b/tests/ocr_rec_params.txt
new file mode 100644
index 0000000000000000000000000000000000000000..71d12f90b3bda128c3f6047c6740911dac417954
--- /dev/null
+++ b/tests/ocr_rec_params.txt
@@ -0,0 +1,51 @@
+===========================train_params===========================
+model_name:ocr_rec
+python:python3.7
+gpu_list:0|2,3
+Global.use_gpu:True|True
+Global.auto_cast:null
+Global.epoch_num:lite_train_infer=2|whole_train_infer=300
+Global.save_model_dir:./output/
+Train.loader.batch_size_per_card:lite_train_infer=128|whole_train_infer=128
+Global.pretrained_model:null
+train_model_name:latest
+train_infer_img_dir:./train_data/ic15_data/train
+null:null
+##
+trainer:norm_train|pact_train
+norm_train:tools/train.py -c configs/rec/rec_icdar15_train.yml -o
+pact_train:deploy/slim/quantization/quant.py -c configs/rec/rec_icdar15_train.yml -o
+fpgm_train:null
+distill_train:null
+null:null
+null:null
+##
+===========================eval_params===========================
+eval:tools/eval.py -c configs/rec/rec_icdar15_train.yml -o
+null:null
+##
+===========================infer_params===========================
+Global.save_inference_dir:./output/
+Global.pretrained_model:
+norm_export:tools/export_model.py -c configs/rec/rec_icdar15_train.yml -o
+quant_export:deploy/slim/quantization/export_model.py -c configs/rec/rec_icdar15_train.yml -o
+fpgm_export:null
+distill_export:null
+export1:null
+export2:null
+##
+infer_model:./inference/ch_ppocr_mobile_v2.0_rec_infer/
+infer_export:null
+infer_quant:False
+inference:tools/infer/predict_rec.py
+--use_gpu:True|False
+--enable_mkldnn:True|False
+--cpu_threads:1|6
+--rec_batch_num:1
+--use_tensorrt:True|False
+--precision:fp32|fp16|int8
+--rec_model_dir:
+--image_dir:./inference/rec_inference
+--save_log_path:./test/output/
+--benchmark:True
+null:null
diff --git a/tests/prepare.sh b/tests/prepare.sh
new file mode 100644
index 0000000000000000000000000000000000000000..418e5661ad0f315bc60b8fda37742c115b395b7c
--- /dev/null
+++ b/tests/prepare.sh
@@ -0,0 +1,76 @@
+#!/bin/bash
+FILENAME=$1
+# MODE must be one of ['lite_train_infer', 'whole_infer', 'whole_train_infer', 'infer']
+MODE=$2
+
+dataline=$(cat ${FILENAME})
+
+# parser params
+IFS=$'\n'
+lines=(${dataline})
+function func_parser_key(){
+ strs=$1
+ IFS=":"
+ array=(${strs})
+ tmp=${array[0]}
+ echo ${tmp}
+}
+function func_parser_value(){
+ strs=$1
+ IFS=":"
+ array=(${strs})
+ tmp=${array[1]}
+ echo ${tmp}
+}
+IFS=$'\n'
+# The training params
+model_name=$(func_parser_value "${lines[1]}")
+
+trainer_list=$(func_parser_value "${lines[14]}")
+
+# MODE must be one of ['lite_train_infer', 'whole_infer', 'whole_train_infer']
+MODE=$2
+
+if [ ${MODE} = "lite_train_infer" ];then
+ # pretrain lite train data
+ wget -nc -P ./pretrain_models/ https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/MobileNetV3_large_x0_5_pretrained.pdparams
+ rm -rf ./train_data/icdar2015
+ rm -rf ./train_data/ic15_data
+ wget -nc -P ./train_data/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/test/icdar2015_lite.tar
+ wget -nc -P ./train_data/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/test/ic15_data.tar # todo change to bcebos
+
+ cd ./train_data/ && tar xf icdar2015_lite.tar && tar xf ic15_data.tar
+ ln -s ./icdar2015_lite ./icdar2015
+ cd ../
+elif [ ${MODE} = "whole_train_infer" ];then
+ wget -nc -P ./pretrain_models/ https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/MobileNetV3_large_x0_5_pretrained.pdparams
+ rm -rf ./train_data/icdar2015
+ rm -rf ./train_data/ic15_data
+ wget -nc -P ./train_data/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/test/icdar2015.tar
+ wget -nc -P ./train_data/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/test/ic15_data.tar
+ cd ./train_data/ && tar xf icdar2015.tar && tar xf ic15_data.tar && cd ../
+elif [ ${MODE} = "whole_infer" ];then
+ wget -nc -P ./pretrain_models/ https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/MobileNetV3_large_x0_5_pretrained.pdparams
+ rm -rf ./train_data/icdar2015
+ rm -rf ./train_data/ic15_data
+ wget -nc -P ./train_data/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/test/icdar2015_infer.tar
+ wget -nc -P ./train_data/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/test/ic15_data.tar
+ cd ./train_data/ && tar xf icdar2015_infer.tar && tar xf ic15_data.tar
+ ln -s ./icdar2015_infer ./icdar2015
+ cd ../
+else
+ if [ ${model_name} = "ocr_det" ]; then
+ eval_model_name="ch_ppocr_mobile_v2.0_det_infer"
+ rm -rf ./train_data/icdar2015
+ wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/test/ch_det_data_50.tar
+ wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_det_infer.tar
+ cd ./inference && tar xf ${eval_model_name}.tar && tar xf ch_det_data_50.tar && cd ../
+ else
+ rm -rf ./train_data/ic15_data
+ eval_model_name="ch_ppocr_mobile_v2.0_rec_infer"
+ wget -nc -P ./train_data/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/test/ic15_data.tar
+ wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_rec_infer.tar
+ cd ./inference && tar xf ${eval_model_name}.tar && tar xf ic15_data.tar && cd ../
+ fi
+fi
+
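In lite_train_infer mode, prepare.sh leaves the pretrained weights under ./pretrain_models and the datasets under ./train_data (paths taken from the wget/tar calls above). A hypothetical pre-flight check, not part of the repo, that one might run before test.sh:

```python
# Hypothetical sanity check for the lite_train_infer assets downloaded by prepare.sh.
import os

expected = [
    './pretrain_models/MobileNetV3_large_x0_5_pretrained.pdparams',
    './train_data/icdar2015',
    './train_data/ic15_data',
]
missing = [p for p in expected if not os.path.exists(p)]
print('missing:', missing or 'none')
```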
diff --git a/tests/readme.md b/tests/readme.md
new file mode 100644
index 0000000000000000000000000000000000000000..1c5e0faee90cad9709b6e4d517cbf7830aa2bb8e
--- /dev/null
+++ b/tests/readme.md
@@ -0,0 +1,58 @@
+
+# Introduction
+
+test.sh is used together with the *_params.txt files to run the end-to-end test of the lightweight OCR detection and recognition models, from training through inference.
+
+# Install dependencies
+- Install PaddlePaddle >= 2.0
+- Install the PaddleOCR requirements
+ ```
+ pip3 install -r ../requirements.txt
+ ```
+- Install autolog
+ ```
+ git clone https://github.com/LDOUBLEV/AutoLog
+ cd AutoLog
+ pip3 install -r requirements.txt
+ python3 setup.py bdist_wheel
+ pip3 install ./dist/auto_log-1.0.0-py3-none-any.whl
+ cd ../
+ ```
+
+# Directory layout
+
+```bash
+tests/
+├── ocr_det_params.txt # parameter config file for testing the OCR detection model
+├── ocr_rec_params.txt # parameter config file for testing the OCR recognition model
+├── prepare.sh # downloads the data and models needed to run test.sh
+└── test.sh # runs the train-to-inference test according to the params file
+```
+
+# Usage
+test.sh supports four run modes. Each mode uses a different amount of data and checks either speed or accuracy:
+- Mode 1, lite_train_infer: train on a small amount of data to quickly verify that the training-to-inference pipeline runs end to end; accuracy and speed are not checked;
+```
+bash tests/prepare.sh ./tests/ocr_det_params.txt 'lite_train_infer'
+bash tests/test.sh ./tests/ocr_det_params.txt 'lite_train_infer'
+```
+- Mode 2, whole_infer: train on a small amount of data and run inference on a moderate amount of data, to verify that the trained model can run inference and that its inference speed is reasonable;
+```
+bash tests/prepare.sh ./tests/ocr_det_params.txt 'whole_infer'
+bash tests/test.sh ./tests/ocr_det_params.txt 'whole_infer'
+```
+
+- Mode 3, infer: no training; run inference on the full dataset, covering released-model evaluation and dynamic-to-static export, and check the inference model's prediction time and accuracy;
+```
+bash tests/prepare.sh ./tests/ocr_det_params.txt 'infer'
+# Usage 1:
+bash tests/test.sh ./tests/ocr_det_params.txt 'infer'
+# Usage 2: run inference on a specific GPU; the third argument is the GPU id
+bash tests/test.sh ./tests/ocr_det_params.txt 'infer' '1'
+```
+
+- Mode 4, whole_train_infer (CE): train on the full dataset and run inference on the full dataset, to verify training accuracy, inference accuracy, and inference speed;
+```
+bash tests/prepare.sh ./tests/ocr_det_params.txt 'whole_train_infer'
+bash tests/test.sh ./tests/ocr_det_params.txt 'whole_train_infer'
+```
diff --git a/tests/test.sh b/tests/test.sh
new file mode 100644
index 0000000000000000000000000000000000000000..9888e0faabb13b00acdf41ad154ba0a0e7ec2b63
--- /dev/null
+++ b/tests/test.sh
@@ -0,0 +1,365 @@
+#!/bin/bash
+FILENAME=$1
+# MODE must be one of ['lite_train_infer', 'whole_infer', 'whole_train_infer', 'infer']
+MODE=$2
+
+dataline=$(cat ${FILENAME})
+
+# parser params
+IFS=$'\n'
+lines=(${dataline})
+
+function func_parser_key(){
+ strs=$1
+ IFS=":"
+ array=(${strs})
+ tmp=${array[0]}
+ echo ${tmp}
+}
+function func_parser_value(){
+ strs=$1
+ IFS=":"
+ array=(${strs})
+ tmp=${array[1]}
+ echo ${tmp}
+}
+function func_set_params(){
+ key=$1
+ value=$2
+ if [ ${key} = "null" ];then
+ echo " "
+ elif [[ ${value} = "null" ]] || [[ ${value} = " " ]] || [ ${#value} -le 0 ];then
+ echo " "
+ else
+ echo "${key}=${value}"
+ fi
+}
+function func_parser_params(){
+ strs=$1
+ IFS=":"
+ array=(${strs})
+ key=${array[0]}
+ tmp=${array[1]}
+ IFS="|"
+ res=""
+ for _params in ${tmp[*]}; do
+ IFS="="
+ array=(${_params})
+ mode=${array[0]}
+ value=${array[1]}
+ if [[ ${mode} = ${MODE} ]]; then
+ IFS="|"
+ #echo $(func_set_params "${mode}" "${value}")
+ echo $value
+ break
+ fi
+ IFS="|"
+ done
+ echo ${res}
+}
+function status_check(){
+ last_status=$1 # the exit code
+ run_command=$2
+ run_log=$3
+ if [ $last_status -eq 0 ]; then
+ echo -e "\033[33m Run successfully with command - ${run_command}! \033[0m" | tee -a ${run_log}
+ else
+ echo -e "\033[33m Run failed with command - ${run_command}! \033[0m" | tee -a ${run_log}
+ fi
+}
+
+IFS=$'\n'
+# The training params
+model_name=$(func_parser_value "${lines[1]}")
+python=$(func_parser_value "${lines[2]}")
+gpu_list=$(func_parser_value "${lines[3]}")
+train_use_gpu_key=$(func_parser_key "${lines[4]}")
+train_use_gpu_value=$(func_parser_value "${lines[4]}")
+autocast_list=$(func_parser_value "${lines[5]}")
+autocast_key=$(func_parser_key "${lines[5]}")
+epoch_key=$(func_parser_key "${lines[6]}")
+epoch_num=$(func_parser_params "${lines[6]}")
+save_model_key=$(func_parser_key "${lines[7]}")
+train_batch_key=$(func_parser_key "${lines[8]}")
+train_batch_value=$(func_parser_params "${lines[8]}")
+pretrain_model_key=$(func_parser_key "${lines[9]}")
+pretrain_model_value=$(func_parser_value "${lines[9]}")
+train_model_name=$(func_parser_value "${lines[10]}")
+train_infer_img_dir=$(func_parser_value "${lines[11]}")
+train_param_key1=$(func_parser_key "${lines[12]}")
+train_param_value1=$(func_parser_value "${lines[12]}")
+
+trainer_list=$(func_parser_value "${lines[14]}")
+trainer_norm=$(func_parser_key "${lines[15]}")
+norm_trainer=$(func_parser_value "${lines[15]}")
+pact_key=$(func_parser_key "${lines[16]}")
+pact_trainer=$(func_parser_value "${lines[16]}")
+fpgm_key=$(func_parser_key "${lines[17]}")
+fpgm_trainer=$(func_parser_value "${lines[17]}")
+distill_key=$(func_parser_key "${lines[18]}")
+distill_trainer=$(func_parser_value "${lines[18]}")
+trainer_key1=$(func_parser_key "${lines[19]}")
+trainer_value1=$(func_parser_value "${lines[19]}")
+trainer_key2=$(func_parser_key "${lines[20]}")
+trainer_value2=$(func_parser_value "${lines[20]}")
+
+eval_py=$(func_parser_value "${lines[23]}")
+eval_key1=$(func_parser_key "${lines[24]}")
+eval_value1=$(func_parser_value "${lines[24]}")
+
+save_infer_key=$(func_parser_key "${lines[27]}")
+export_weight=$(func_parser_key "${lines[28]}")
+norm_export=$(func_parser_value "${lines[29]}")
+pact_export=$(func_parser_value "${lines[30]}")
+fpgm_export=$(func_parser_value "${lines[31]}")
+distill_export=$(func_parser_value "${lines[32]}")
+export_key1=$(func_parser_key "${lines[33]}")
+export_value1=$(func_parser_value "${lines[33]}")
+export_key2=$(func_parser_key "${lines[34]}")
+export_value2=$(func_parser_value "${lines[34]}")
+
+# parser inference model
+infer_model_dir_list=$(func_parser_value "${lines[36]}")
+infer_export_list=$(func_parser_value "${lines[37]}")
+infer_is_quant=$(func_parser_value "${lines[38]}")
+# parser inference
+inference_py=$(func_parser_value "${lines[39]}")
+use_gpu_key=$(func_parser_key "${lines[40]}")
+use_gpu_list=$(func_parser_value "${lines[40]}")
+use_mkldnn_key=$(func_parser_key "${lines[41]}")
+use_mkldnn_list=$(func_parser_value "${lines[41]}")
+cpu_threads_key=$(func_parser_key "${lines[42]}")
+cpu_threads_list=$(func_parser_value "${lines[42]}")
+batch_size_key=$(func_parser_key "${lines[43]}")
+batch_size_list=$(func_parser_value "${lines[43]}")
+use_trt_key=$(func_parser_key "${lines[44]}")
+use_trt_list=$(func_parser_value "${lines[44]}")
+precision_key=$(func_parser_key "${lines[45]}")
+precision_list=$(func_parser_value "${lines[45]}")
+infer_model_key=$(func_parser_key "${lines[46]}")
+image_dir_key=$(func_parser_key "${lines[47]}")
+infer_img_dir=$(func_parser_value "${lines[47]}")
+save_log_key=$(func_parser_key "${lines[48]}")
+benchmark_key=$(func_parser_key "${lines[49]}")
+benchmark_value=$(func_parser_value "${lines[49]}")
+infer_key1=$(func_parser_key "${lines[50]}")
+infer_value1=$(func_parser_value "${lines[50]}")
+
+LOG_PATH="./tests/output"
+mkdir -p ${LOG_PATH}
+status_log="${LOG_PATH}/results.log"
+
+
+function func_inference(){
+ IFS='|'
+ _python=$1
+ _script=$2
+ _model_dir=$3
+ _log_path=$4
+ _img_dir=$5
+ _flag_quant=$6
+ # inference
+ for use_gpu in ${use_gpu_list[*]}; do
+ if [ ${use_gpu} = "False" ] || [ ${use_gpu} = "cpu" ]; then
+ for use_mkldnn in ${use_mkldnn_list[*]}; do
+ if [ ${use_mkldnn} = "False" ] && [ ${_flag_quant} = "True" ]; then
+ continue
+ fi
+ for threads in ${cpu_threads_list[*]}; do
+ for batch_size in ${batch_size_list[*]}; do
+ _save_log_path="${_log_path}/infer_cpu_usemkldnn_${use_mkldnn}_threads_${threads}_batchsize_${batch_size}.log"
+ set_infer_data=$(func_set_params "${image_dir_key}" "${_img_dir}")
+ set_benchmark=$(func_set_params "${benchmark_key}" "${benchmark_value}")
+ set_batchsize=$(func_set_params "${batch_size_key}" "${batch_size}")
+ set_cpu_threads=$(func_set_params "${cpu_threads_key}" "${threads}")
+ set_model_dir=$(func_set_params "${infer_model_key}" "${_model_dir}")
+ set_infer_params1=$(func_set_params "${infer_key1}" "${infer_value1}")
+ command="${_python} ${_script} ${use_gpu_key}=${use_gpu} ${use_mkldnn_key}=${use_mkldnn} ${set_cpu_threads} ${set_model_dir} ${set_batchsize} ${set_infer_data} ${set_benchmark} ${set_infer_params1} > ${_save_log_path} 2>&1 "
+ eval $command
+ last_status=${PIPESTATUS[0]}
+ eval "cat ${_save_log_path}"
+ status_check $last_status "${command}" "${status_log}"
+ done
+ done
+ done
+ elif [ ${use_gpu} = "True" ] || [ ${use_gpu} = "gpu" ]; then
+ for use_trt in ${use_trt_list[*]}; do
+ for precision in ${precision_list[*]}; do
+ if [[ ${_flag_quant} = "False" ]] && [[ ${precision} =~ "int8" ]]; then
+ continue
+ fi
+ if [[ ${precision} =~ "fp16" || ${precision} =~ "int8" ]] && [ ${use_trt} = "False" ]; then
+ continue
+ fi
+ if [[ ${use_trt} = "False" || ${precision} =~ "int8" ]] && [ ${_flag_quant} = "True" ]; then
+ continue
+ fi
+ for batch_size in ${batch_size_list[*]}; do
+ _save_log_path="${_log_path}/infer_gpu_usetrt_${use_trt}_precision_${precision}_batchsize_${batch_size}.log"
+ set_infer_data=$(func_set_params "${image_dir_key}" "${_img_dir}")
+ set_benchmark=$(func_set_params "${benchmark_key}" "${benchmark_value}")
+ set_batchsize=$(func_set_params "${batch_size_key}" "${batch_size}")
+ set_tensorrt=$(func_set_params "${use_trt_key}" "${use_trt}")
+ set_precision=$(func_set_params "${precision_key}" "${precision}")
+ set_model_dir=$(func_set_params "${infer_model_key}" "${_model_dir}")
+ set_infer_params1=$(func_set_params "${infer_key1}" "${infer_value1}")
+ command="${_python} ${_script} ${use_gpu_key}=${use_gpu} ${set_tensorrt} ${set_precision} ${set_model_dir} ${set_batchsize} ${set_infer_data} ${set_benchmark} ${set_infer_params1} > ${_save_log_path} 2>&1 "
+ eval $command
+ last_status=${PIPESTATUS[0]}
+ eval "cat ${_save_log_path}"
+ status_check $last_status "${command}" "${status_log}"
+
+ done
+ done
+ done
+ else
+ echo "Does not support hardware other than CPU and GPU Currently!"
+ fi
+ done
+}
+
+if [ ${MODE} = "infer" ]; then
+ GPUID=$3
+ if [ ${#GPUID} -le 0 ];then
+ env=" "
+ else
+ env="export CUDA_VISIBLE_DEVICES=${GPUID}"
+ fi
+ # set CUDA_VISIBLE_DEVICES
+ eval $env
+ export Count=0
+ IFS="|"
+ infer_run_exports=(${infer_export_list})
+ infer_quant_flag=(${infer_is_quant})
+ for infer_model in ${infer_model_dir_list[*]}; do
+ # run export
+ if [ ${infer_run_exports[Count]} != "null" ];then
+ save_infer_dir=$(dirname $infer_model)
+ set_export_weight=$(func_set_params "${export_weight}" "${infer_model}")
+ set_save_infer_key=$(func_set_params "${save_infer_key}" "${save_infer_dir}")
+ export_cmd="${python} ${norm_export} ${set_export_weight} ${set_save_infer_key}"
+ eval $export_cmd
+ status_export=$?
+ if [ ${status_export} = 0 ];then
+ status_check $status_export "${export_cmd}" "${status_log}"
+ fi
+ else
+ save_infer_dir=${infer_model}
+ fi
+ #run inference
+ is_quant=${infer_quant_flag[Count]}
+ func_inference "${python}" "${inference_py}" "${save_infer_dir}" "${LOG_PATH}" "${infer_img_dir}" ${is_quant}
+ Count=$(($Count + 1))
+ done
+
+else
+ IFS="|"
+ export Count=0
+ USE_GPU_KEY=(${train_use_gpu_value})
+ for gpu in ${gpu_list[*]}; do
+ use_gpu=${USE_GPU_KEY[Count]}
+ Count=$(($Count + 1))
+ if [ ${gpu} = "-1" ];then
+ env=""
+ elif [ ${#gpu} -le 1 ];then
+ env="export CUDA_VISIBLE_DEVICES=${gpu}"
+ eval ${env}
+ elif [ ${#gpu} -le 15 ];then
+ IFS=","
+ array=(${gpu})
+ env="export CUDA_VISIBLE_DEVICES=${array[0]}"
+ IFS="|"
+ else
+ IFS=";"
+ array=(${gpu})
+ ips=${array[0]}
+ gpu=${array[1]}
+ IFS="|"
+ env=" "
+ fi
+ for autocast in ${autocast_list[*]}; do
+ for trainer in ${trainer_list[*]}; do
+ flag_quant=False
+ if [ ${trainer} = ${pact_key} ]; then
+ run_train=${pact_trainer}
+ run_export=${pact_export}
+ flag_quant=True
+ elif [ ${trainer} = "${fpgm_key}" ]; then
+ run_train=${fpgm_trainer}
+ run_export=${fpgm_export}
+ elif [ ${trainer} = "${distill_key}" ]; then
+ run_train=${distill_trainer}
+ run_export=${distill_export}
+ elif [ ${trainer} = ${trainer_key1} ]; then
+ run_train=${trainer_value1}
+ run_export=${export_value1}
+ elif [[ ${trainer} = ${trainer_key2} ]]; then
+ run_train=${trainer_value2}
+ run_export=${export_value2}
+ else
+ run_train=${norm_trainer}
+ run_export=${norm_export}
+ fi
+
+ if [ ${run_train} = "null" ]; then
+ continue
+ fi
+
+ set_autocast=$(func_set_params "${autocast_key}" "${autocast}")
+ set_epoch=$(func_set_params "${epoch_key}" "${epoch_num}")
+ set_pretrain=$(func_set_params "${pretrain_model_key}" "${pretrain_model_value}")
+ set_batchsize=$(func_set_params "${train_batch_key}" "${train_batch_value}")
+ set_train_params1=$(func_set_params "${train_param_key1}" "${train_param_value1}")
+ set_use_gpu=$(func_set_params "${train_use_gpu_key}" "${use_gpu}")
+ save_log="${LOG_PATH}/${trainer}_gpus_${gpu}_autocast_${autocast}"
+
+ # load pretrain from norm training if current trainer is pact or fpgm trainer
+ if [ ${trainer} = ${pact_key} ] || [ ${trainer} = ${fpgm_key} ]; then
+ set_pretrain="${load_norm_train_model}"
+ fi
+
+ set_save_model=$(func_set_params "${save_model_key}" "${save_log}")
+ if [ ${#gpu} -le 2 ];then # train with cpu or single gpu
+ cmd="${python} ${run_train} ${set_use_gpu} ${set_save_model} ${set_epoch} ${set_pretrain} ${set_autocast} ${set_batchsize} ${set_train_params1} "
+ elif [ ${#gpu} -le 15 ];then # train with multi-gpu
+ cmd="${python} -m paddle.distributed.launch --gpus=${gpu} ${run_train} ${set_save_model} ${set_epoch} ${set_pretrain} ${set_autocast} ${set_batchsize} ${set_train_params1}"
+ else # train with multi-machine
+ cmd="${python} -m paddle.distributed.launch --ips=${ips} --gpus=${gpu} ${run_train} ${set_save_model} ${set_pretrain} ${set_epoch} ${set_autocast} ${set_batchsize} ${set_train_params1}"
+ fi
+ # run train
+ eval "unset CUDA_VISIBLE_DEVICES"
+ eval $cmd
+ status_check $? "${cmd}" "${status_log}"
+
+ set_eval_pretrain=$(func_set_params "${pretrain_model_key}" "${save_log}/${train_model_name}")
+ # save norm trained models to set pretrain for pact training and fpgm training
+ if [ ${trainer} = ${trainer_norm} ]; then
+ load_norm_train_model=${set_eval_pretrain}
+ fi
+ # run eval
+ if [ ${eval_py} != "null" ]; then
+ set_eval_params1=$(func_set_params "${eval_key1}" "${eval_value1}")
+ eval_cmd="${python} ${eval_py} ${set_eval_pretrain} ${set_use_gpu} ${set_eval_params1}"
+ eval $eval_cmd
+ status_check $? "${eval_cmd}" "${status_log}"
+ fi
+ # run export model
+ if [ ${run_export} != "null" ]; then
+ # run export model
+ save_infer_path="${save_log}"
+ set_export_weight=$(func_set_params "${export_weight}" "${save_log}/${train_model_name}")
+ set_save_infer_key=$(func_set_params "${save_infer_key}" "${save_infer_path}")
+ export_cmd="${python} ${run_export} ${set_export_weight} ${set_save_infer_key}"
+ eval $export_cmd
+ status_check $? "${export_cmd}" "${status_log}"
+
+ #run inference
+ eval $env
+ save_infer_path="${save_log}"
+ func_inference "${python}" "${inference_py}" "${save_infer_path}" "${LOG_PATH}" "${train_infer_img_dir}" "${flag_quant}"
+ eval "unset CUDA_VISIBLE_DEVICES"
+ fi
+ done # done with: for trainer in ${trainer_list[*]}; do
+ done # done with: for autocast in ${autocast_list[*]}; do
+ done # done with: for gpu in ${gpu_list[*]}; do
+fi # end if [ ${MODE} = "infer" ]; then
diff --git a/tools/eval.py b/tools/eval.py
index 66eb315f9b37ed681f6a899613fa43c1313bc654..0120baab0f34d5fadbbf4df20d92d6b62dd176a2 100755
--- a/tools/eval.py
+++ b/tools/eval.py
@@ -27,7 +27,7 @@ from ppocr.data import build_dataloader
from ppocr.modeling.architectures import build_model
from ppocr.postprocess import build_post_process
from ppocr.metrics import build_metric
-from ppocr.utils.save_load import init_model
+from ppocr.utils.save_load import init_model, load_pretrained_params
from ppocr.utils.utility import print_dict
import tools.program as program
@@ -44,10 +44,21 @@ def main():
# build model
# for rec algorithm
if hasattr(post_process_class, 'character'):
- config['Architecture']["Head"]['out_channels'] = len(
- getattr(post_process_class, 'character'))
+ char_num = len(getattr(post_process_class, 'character'))
+ if config['Architecture']["algorithm"] in ["Distillation",
+ ]: # distillation model
+ for key in config['Architecture']["Models"]:
+ config['Architecture']["Models"][key]["Head"][
+ 'out_channels'] = char_num
+ else: # base rec model
+ config['Architecture']["Head"]['out_channels'] = char_num
+
model = build_model(config['Architecture'])
use_srn = config['Architecture']['algorithm'] == "SRN"
+ if "model_type" in config['Architecture'].keys():
+ model_type = config['Architecture']['model_type']
+ else:
+ model_type = None
best_model_dict = init_model(config, model)
if len(best_model_dict):
@@ -60,7 +71,7 @@ def main():
# start eval
metric = program.eval(model, valid_dataloader, post_process_class,
- eval_class, use_srn)
+ eval_class, model_type, use_srn)
logger.info('metric eval ***************')
for k, v in metric.items():
logger.info('{}:{}'.format(k, v))
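The hunk above handles two config layouts: a plain recognition model with a single `Architecture.Head`, and a distillation model whose sub-models sit under `Architecture.Models`. A hedged sketch of the dict shapes being mutated (sub-model names and the character count are illustrative; real values come from the YAML configs):

```python
# Illustrative config shapes only; real configs are loaded from configs/*.yml.
char_num = 6625  # e.g. len(getattr(post_process_class, 'character'))

# Base recognition model: a single Head.
base_rec = {'Architecture': {'algorithm': 'CRNN',
                             'Head': {'out_channels': None}}}
base_rec['Architecture']['Head']['out_channels'] = char_num

# Distillation model: every sub-model under Models gets its own Head.
distill = {'Architecture': {'algorithm': 'Distillation',
                            'Models': {'Teacher': {'Head': {'out_channels': None}},
                                       'Student': {'Head': {'out_channels': None}}}}}
for key in distill['Architecture']['Models']:
    distill['Architecture']['Models'][key]['Head']['out_channels'] = char_num
```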
diff --git a/tools/export_model.py b/tools/export_model.py
index 625c82468edff7c3eeb787422bdef07b4b274460..785aca10e46200bda49bdff2b89ba00cafbe7a20 100755
--- a/tools/export_model.py
+++ b/tools/export_model.py
@@ -60,7 +60,8 @@ def export_single_model(model, arch_config, save_path, logger):
"When there is tps in the network, variable length input is not supported, and the input size needs to be the same as during training"
)
infer_shape[-1] = 100
-
+ elif arch_config["model_type"] == "table":
+ infer_shape = [3, 488, 488]
model = to_static(
model,
input_spec=[
diff --git a/tools/infer/benchmark_utils.py b/tools/infer/benchmark_utils.py
deleted file mode 100644
index 1a241d063368d19567e253bf1dada09801d468bc..0000000000000000000000000000000000000000
--- a/tools/infer/benchmark_utils.py
+++ /dev/null
@@ -1,232 +0,0 @@
-# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import argparse
-import os
-import time
-import logging
-
-import paddle
-import paddle.inference as paddle_infer
-
-from pathlib import Path
-
-CUR_DIR = os.path.dirname(os.path.abspath(__file__))
-
-
-class PaddleInferBenchmark(object):
- def __init__(self,
- config,
- model_info: dict={},
- data_info: dict={},
- perf_info: dict={},
- resource_info: dict={},
- save_log_path: str="",
- **kwargs):
- """
- Construct PaddleInferBenchmark Class to format logs.
- args:
- config(paddle.inference.Config): paddle inference config
- model_info(dict): basic model info
- {'model_name': 'resnet50'
- 'precision': 'fp32'}
- data_info(dict): input data info
- {'batch_size': 1
- 'shape': '3,224,224'
- 'data_num': 1000}
- perf_info(dict): performance result
- {'preprocess_time_s': 1.0
- 'inference_time_s': 2.0
- 'postprocess_time_s': 1.0
- 'total_time_s': 4.0}
- resource_info(dict):
- cpu and gpu resources
- {'cpu_rss': 100
- 'gpu_rss': 100
- 'gpu_util': 60}
- """
- # PaddleInferBenchmark Log Version
- self.log_version = 1.0
-
- # Paddle Version
- self.paddle_version = paddle.__version__
- self.paddle_commit = paddle.__git_commit__
- paddle_infer_info = paddle_infer.get_version()
- self.paddle_branch = paddle_infer_info.strip().split(': ')[-1]
-
- # model info
- self.model_info = model_info
-
- # data info
- self.data_info = data_info
-
- # perf info
- self.perf_info = perf_info
-
- try:
- self.model_name = model_info['model_name']
- self.precision = model_info['precision']
-
- self.batch_size = data_info['batch_size']
- self.shape = data_info['shape']
- self.data_num = data_info['data_num']
-
- self.preprocess_time_s = round(perf_info['preprocess_time_s'], 4)
- self.inference_time_s = round(perf_info['inference_time_s'], 4)
- self.postprocess_time_s = round(perf_info['postprocess_time_s'], 4)
- self.total_time_s = round(perf_info['total_time_s'], 4)
- except:
- self.print_help()
- raise ValueError(
- "Set argument wrong, please check input argument and its type")
-
- # conf info
- self.config_status = self.parse_config(config)
- self.save_log_path = save_log_path
- # mem info
- if isinstance(resource_info, dict):
- self.cpu_rss_mb = int(resource_info.get('cpu_rss_mb', 0))
- self.gpu_rss_mb = int(resource_info.get('gpu_rss_mb', 0))
- self.gpu_util = round(resource_info.get('gpu_util', 0), 2)
- else:
- self.cpu_rss_mb = 0
- self.gpu_rss_mb = 0
- self.gpu_util = 0
-
- # init benchmark logger
- self.benchmark_logger()
-
- def benchmark_logger(self):
- """
- benchmark logger
- """
- # Init logger
- FORMAT = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
- log_output = f"{self.save_log_path}/{self.model_name}.log"
- Path(f"{self.save_log_path}").mkdir(parents=True, exist_ok=True)
- logging.basicConfig(
- level=logging.INFO,
- format=FORMAT,
- handlers=[
- logging.FileHandler(
- filename=log_output, mode='w'),
- logging.StreamHandler(),
- ])
- self.logger = logging.getLogger(__name__)
- self.logger.info(
- f"Paddle Inference benchmark log will be saved to {log_output}")
-
- def parse_config(self, config) -> dict:
- """
- parse paddle predictor config
- args:
- config(paddle.inference.Config): paddle inference config
- return:
- config_status(dict): dict style config info
- """
- config_status = {}
- config_status['runtime_device'] = "gpu" if config.use_gpu() else "cpu"
- config_status['ir_optim'] = config.ir_optim()
- config_status['enable_tensorrt'] = config.tensorrt_engine_enabled()
- config_status['precision'] = self.precision
- config_status['enable_mkldnn'] = config.mkldnn_enabled()
- config_status[
- 'cpu_math_library_num_threads'] = config.cpu_math_library_num_threads(
- )
- return config_status
-
- def report(self, identifier=None):
- """
- print log report
- args:
- identifier(string): identify log
- """
- if identifier:
- identifier = f"[{identifier}]"
- else:
- identifier = ""
-
- self.logger.info("\n")
- self.logger.info(
- "---------------------- Paddle info ----------------------")
- self.logger.info(f"{identifier} paddle_version: {self.paddle_version}")
- self.logger.info(f"{identifier} paddle_commit: {self.paddle_commit}")
- self.logger.info(f"{identifier} paddle_branch: {self.paddle_branch}")
- self.logger.info(f"{identifier} log_api_version: {self.log_version}")
- self.logger.info(
- "----------------------- Conf info -----------------------")
- self.logger.info(
- f"{identifier} runtime_device: {self.config_status['runtime_device']}"
- )
- self.logger.info(
- f"{identifier} ir_optim: {self.config_status['ir_optim']}")
- self.logger.info(f"{identifier} enable_memory_optim: {True}")
- self.logger.info(
- f"{identifier} enable_tensorrt: {self.config_status['enable_tensorrt']}"
- )
- self.logger.info(
- f"{identifier} enable_mkldnn: {self.config_status['enable_mkldnn']}")
- self.logger.info(
- f"{identifier} cpu_math_library_num_threads: {self.config_status['cpu_math_library_num_threads']}"
- )
- self.logger.info(
- "----------------------- Model info ----------------------")
- self.logger.info(f"{identifier} model_name: {self.model_name}")
- self.logger.info(f"{identifier} precision: {self.precision}")
- self.logger.info(
- "----------------------- Data info -----------------------")
- self.logger.info(f"{identifier} batch_size: {self.batch_size}")
- self.logger.info(f"{identifier} input_shape: {self.shape}")
- self.logger.info(f"{identifier} data_num: {self.data_num}")
- self.logger.info(
- "----------------------- Perf info -----------------------")
- self.logger.info(
- f"{identifier} cpu_rss(MB): {self.cpu_rss_mb}, gpu_rss(MB): {self.gpu_rss_mb}, gpu_util: {self.gpu_util}%"
- )
- self.logger.info(
- f"{identifier} total time spent(s): {self.total_time_s}")
- self.logger.info(
- f"{identifier} preprocess_time(ms): {round(self.preprocess_time_s*1000, 1)}, inference_time(ms): {round(self.inference_time_s*1000, 1)}, postprocess_time(ms): {round(self.postprocess_time_s*1000, 1)}"
- )
-
- def print_help(self):
- """
- print function help
- """
- print("""Usage:
- ==== Print inference benchmark logs. ====
- config = paddle.inference.Config()
- model_info = {'model_name': 'resnet50'
- 'precision': 'fp32'}
- data_info = {'batch_size': 1
- 'shape': '3,224,224'
- 'data_num': 1000}
- perf_info = {'preprocess_time_s': 1.0
- 'inference_time_s': 2.0
- 'postprocess_time_s': 1.0
- 'total_time_s': 4.0}
- resource_info = {'cpu_rss_mb': 100
- 'gpu_rss_mb': 100
- 'gpu_util': 60}
- log = PaddleInferBenchmark(config, model_info, data_info, perf_info, resource_info)
- log('Test')
- """)
-
- def __call__(self, identifier=None):
- """
- __call__
- args:
- identifier(string): identify log
- """
- self.report(identifier)
diff --git a/tools/infer/predict_cls.py b/tools/infer/predict_cls.py
index 0037b226df8e1de8edbdb7668e349925a942e8b9..53e50bd6d1d1a2bd07b9f1204b9f56594c669d13 100755
--- a/tools/infer/predict_cls.py
+++ b/tools/infer/predict_cls.py
@@ -48,8 +48,6 @@ class TextClassifier(object):
self.predictor, self.input_tensor, self.output_tensors, _ = \
utility.create_predictor(args, 'cls', logger)
- self.cls_times = utility.Timer()
-
def resize_norm_img(self, img):
imgC, imgH, imgW = self.cls_image_shape
h = img.shape[0]
@@ -85,35 +83,28 @@ class TextClassifier(object):
cls_res = [['', 0.0]] * img_num
batch_num = self.cls_batch_num
elapse = 0
- self.cls_times.total_time.start()
for beg_img_no in range(0, img_num, batch_num):
end_img_no = min(img_num, beg_img_no + batch_num)
norm_img_batch = []
max_wh_ratio = 0
+ starttime = time.time()
for ino in range(beg_img_no, end_img_no):
h, w = img_list[indices[ino]].shape[0:2]
wh_ratio = w * 1.0 / h
max_wh_ratio = max(max_wh_ratio, wh_ratio)
- self.cls_times.preprocess_time.start()
for ino in range(beg_img_no, end_img_no):
norm_img = self.resize_norm_img(img_list[indices[ino]])
norm_img = norm_img[np.newaxis, :]
norm_img_batch.append(norm_img)
norm_img_batch = np.concatenate(norm_img_batch)
norm_img_batch = norm_img_batch.copy()
- starttime = time.time()
- self.cls_times.preprocess_time.end()
- self.cls_times.inference_time.start()
self.input_tensor.copy_from_cpu(norm_img_batch)
self.predictor.run()
prob_out = self.output_tensors[0].copy_to_cpu()
- self.cls_times.inference_time.end()
- self.cls_times.postprocess_time.start()
self.predictor.try_shrink_memory()
cls_result = self.postprocess_op(prob_out)
- self.cls_times.postprocess_time.end()
elapse += time.time() - starttime
for rno in range(len(cls_result)):
label, score = cls_result[rno]
@@ -121,9 +112,6 @@ class TextClassifier(object):
if '180' in label and score > self.cls_thresh:
img_list[indices[beg_img_no + rno]] = cv2.rotate(
img_list[indices[beg_img_no + rno]], 1)
- self.cls_times.total_time.end()
- self.cls_times.img_num += img_num
- elapse = self.cls_times.total_time.value()
return img_list, cls_res, elapse
@@ -157,7 +145,6 @@ def main(args):
cls_res[ino]))
logger.info(
"The predict time about text angle classify module is as follows: ")
- text_classifier.cls_times.info(average=False)
if __name__ == "__main__":
diff --git a/tools/infer/predict_det.py b/tools/infer/predict_det.py
index f5bade36315fbe321927df82cdd7cd8bf40b2ae5..3de00d83a8f9f55af9b89d5d2cd5c877399c5930 100755
--- a/tools/infer/predict_det.py
+++ b/tools/infer/predict_det.py
@@ -31,8 +31,6 @@ from ppocr.utils.utility import get_image_file_list, check_and_read_gif
from ppocr.data import create_operators, transform
from ppocr.postprocess import build_post_process
-import tools.infer.benchmark_utils as benchmark_utils
-
logger = get_logger()
@@ -43,7 +41,7 @@ class TextDetector(object):
pre_process_list = [{
'DetResizeForTest': {
'limit_side_len': args.det_limit_side_len,
- 'limit_type': args.det_limit_type
+ 'limit_type': args.det_limit_type,
}
}, {
'NormalizeImage': {
@@ -100,7 +98,24 @@ class TextDetector(object):
self.predictor, self.input_tensor, self.output_tensors, self.config = utility.create_predictor(
args, 'det', logger)
- self.det_times = utility.Timer()
+ if args.benchmark:
+ import auto_log
+ pid = os.getpid()
+ self.autolog = auto_log.AutoLogger(
+ model_name="det",
+ model_precision=args.precision,
+ batch_size=1,
+ data_shape="dynamic",
+ save_path=None,
+ inference_config=self.config,
+ pids=pid,
+ process_name=None,
+ gpu_ids=0,
+ time_keys=[
+ 'preprocess_time', 'inference_time', 'postprocess_time'
+ ],
+ warmup=2,
+ logger=logger)
def order_points_clockwise(self, pts):
"""
@@ -158,8 +173,12 @@ class TextDetector(object):
def __call__(self, img):
ori_im = img.copy()
data = {'image': img}
- self.det_times.total_time.start()
- self.det_times.preprocess_time.start()
+
+ st = time.time()
+
+ if self.args.benchmark:
+ self.autolog.times.start()
+
data = transform(data, self.preprocess_op)
img, shape_list = data
if img is None:
@@ -168,8 +187,8 @@ class TextDetector(object):
shape_list = np.expand_dims(shape_list, axis=0)
img = img.copy()
- self.det_times.preprocess_time.end()
- self.det_times.inference_time.start()
+ if self.args.benchmark:
+ self.autolog.times.stamp()
self.input_tensor.copy_from_cpu(img)
self.predictor.run()
@@ -177,7 +196,8 @@ class TextDetector(object):
for output_tensor in self.output_tensors:
output = output_tensor.copy_to_cpu()
outputs.append(output)
- self.det_times.inference_time.end()
+ if self.args.benchmark:
+ self.autolog.times.stamp()
preds = {}
if self.det_algorithm == "EAST":
@@ -193,9 +213,7 @@ class TextDetector(object):
else:
raise NotImplementedError
- self.det_times.postprocess_time.start()
-
- self.predictor.try_shrink_memory()
+ #self.predictor.try_shrink_memory()
post_result = self.postprocess_op(preds, shape_list)
dt_boxes = post_result[0]['points']
if self.det_algorithm == "SAST" and self.det_sast_polygon:
@@ -203,10 +221,10 @@ class TextDetector(object):
else:
dt_boxes = self.filter_tag_det_res(dt_boxes, ori_im.shape)
- self.det_times.postprocess_time.end()
- self.det_times.total_time.end()
- self.det_times.img_num += 1
- return dt_boxes, self.det_times.total_time.value()
+ if self.args.benchmark:
+ self.autolog.times.end(stamp=True)
+ et = time.time()
+ return dt_boxes, et - st
if __name__ == "__main__":
@@ -216,12 +234,11 @@ if __name__ == "__main__":
count = 0
total_time = 0
draw_img_save = "./inference_results"
- cpu_mem, gpu_mem, gpu_util = 0, 0, 0
- # warmup 10 times
- fake_img = np.random.uniform(-1, 1, [640, 640, 3]).astype(np.float32)
- for i in range(10):
- dt_boxes, _ = text_detector(fake_img)
+ if args.warmup:
+ img = np.random.uniform(0, 255, [640, 640, 3]).astype(np.uint8)
+ for i in range(2):
+ res = text_detector(img)
if not os.path.exists(draw_img_save):
os.makedirs(draw_img_save)
@@ -239,49 +256,13 @@ if __name__ == "__main__":
total_time += elapse
count += 1
- if args.benchmark:
- cm, gm, gu = utility.get_current_memory_mb(0)
- cpu_mem += cm
- gpu_mem += gm
- gpu_util += gu
-
logger.info("Predict time of {}: {}".format(image_file, elapse))
src_im = utility.draw_text_det_res(dt_boxes, image_file)
img_name_pure = os.path.split(image_file)[-1]
img_path = os.path.join(draw_img_save,
"det_res_{}".format(img_name_pure))
-
+ cv2.imwrite(img_path, src_im)
logger.info("The visualized image saved in {}".format(img_path))
- # print the information about memory and time-spent
- if args.benchmark:
- mems = {
- 'cpu_rss_mb': cpu_mem / count,
- 'gpu_rss_mb': gpu_mem / count,
- 'gpu_util': gpu_util * 100 / count
- }
- else:
- mems = None
- logger.info("The predict time about detection module is as follows: ")
- det_time_dict = text_detector.det_times.report(average=True)
- det_model_name = args.det_model_dir
if args.benchmark:
- # construct log information
- model_info = {
- 'model_name': args.det_model_dir.split('/')[-1],
- 'precision': args.precision
- }
- data_info = {
- 'batch_size': 1,
- 'shape': 'dynamic_shape',
- 'data_num': det_time_dict['img_num']
- }
- perf_info = {
- 'preprocess_time_s': det_time_dict['preprocess_time'],
- 'inference_time_s': det_time_dict['inference_time'],
- 'postprocess_time_s': det_time_dict['postprocess_time'],
- 'total_time_s': det_time_dict['total_time']
- }
- benchmark_log = benchmark_utils.PaddleInferBenchmark(
- text_detector.config, model_info, data_info, perf_info, mems)
- benchmark_log("Det")
+ text_detector.autolog.report()
diff --git a/tools/infer/predict_rec.py b/tools/infer/predict_rec.py
index 2eeb39b2a0bff15241ea7762b4981e4daaada096..bb4a31706471b9b1745519ac9f390d01b60d5d44 100755
--- a/tools/infer/predict_rec.py
+++ b/tools/infer/predict_rec.py
@@ -28,7 +28,6 @@ import traceback
import paddle
import tools.infer.utility as utility
-import tools.infer.benchmark_utils as benchmark_utils
from ppocr.postprocess import build_post_process
from ppocr.utils.logging import get_logger
from ppocr.utils.utility import get_image_file_list, check_and_read_gif
@@ -65,8 +64,25 @@ class TextRecognizer(object):
self.postprocess_op = build_post_process(postprocess_params)
self.predictor, self.input_tensor, self.output_tensors, self.config = \
utility.create_predictor(args, 'rec', logger)
-
- self.rec_times = utility.Timer()
+ self.benchmark = args.benchmark
+ if args.benchmark:
+ import auto_log
+ pid = os.getpid()
+ self.autolog = auto_log.AutoLogger(
+ model_name="rec",
+ model_precision=args.precision,
+ batch_size=args.rec_batch_num,
+ data_shape="dynamic",
+ save_path=None, #args.save_log_path,
+ inference_config=self.config,
+ pids=pid,
+ process_name=None,
+ gpu_ids=0 if args.use_gpu else None,
+ time_keys=[
+ 'preprocess_time', 'inference_time', 'postprocess_time'
+ ],
+ warmup=2,
+ logger=logger)
def resize_norm_img(self, img, max_wh_ratio):
imgC, imgH, imgW = self.rec_image_shape
@@ -168,14 +184,15 @@ class TextRecognizer(object):
width_list.append(img.shape[1] / float(img.shape[0]))
# Sorting can speed up the recognition process
indices = np.argsort(np.array(width_list))
- self.rec_times.total_time.start()
rec_res = [['', 0.0]] * img_num
batch_num = self.rec_batch_num
+ st = time.time()
+ if self.benchmark:
+ self.autolog.times.start()
for beg_img_no in range(0, img_num, batch_num):
end_img_no = min(img_num, beg_img_no + batch_num)
norm_img_batch = []
max_wh_ratio = 0
- self.rec_times.preprocess_time.start()
for ino in range(beg_img_no, end_img_no):
h, w = img_list[indices[ino]].shape[0:2]
wh_ratio = w * 1.0 / h
@@ -200,6 +217,8 @@ class TextRecognizer(object):
norm_img_batch.append(norm_img[0])
norm_img_batch = np.concatenate(norm_img_batch)
norm_img_batch = norm_img_batch.copy()
+ if self.benchmark:
+ self.autolog.times.stamp()
if self.rec_algorithm == "SRN":
encoder_word_pos_list = np.concatenate(encoder_word_pos_list)
@@ -216,23 +235,20 @@ class TextRecognizer(object):
gsrm_slf_attn_bias1_list,
gsrm_slf_attn_bias2_list,
]
- self.rec_times.preprocess_time.end()
- self.rec_times.inference_time.start()
input_names = self.predictor.get_input_names()
for i in range(len(input_names)):
input_tensor = self.predictor.get_input_handle(input_names[
i])
input_tensor.copy_from_cpu(inputs[i])
self.predictor.run()
- self.rec_times.inference_time.end()
outputs = []
for output_tensor in self.output_tensors:
output = output_tensor.copy_to_cpu()
outputs.append(output)
+ if self.benchmark:
+ self.autolog.times.stamp()
preds = {"predict": outputs[2]}
else:
- self.rec_times.preprocess_time.end()
- self.rec_times.inference_time.start()
self.input_tensor.copy_from_cpu(norm_img_batch)
self.predictor.run()
@@ -240,16 +256,15 @@ class TextRecognizer(object):
for output_tensor in self.output_tensors:
output = output_tensor.copy_to_cpu()
outputs.append(output)
+ if self.benchmark:
+ self.autolog.times.stamp()
preds = outputs[0]
- self.rec_times.inference_time.end()
- self.rec_times.postprocess_time.start()
rec_result = self.postprocess_op(preds)
for rno in range(len(rec_result)):
rec_res[indices[beg_img_no + rno]] = rec_result[rno]
- self.rec_times.postprocess_time.end()
- self.rec_times.img_num += int(norm_img_batch.shape[0])
- self.rec_times.total_time.end()
- return rec_res, self.rec_times.total_time.value()
+ if self.benchmark:
+ self.autolog.times.end(stamp=True)
+ return rec_res, time.time() - st
def main(args):
@@ -257,13 +272,12 @@ def main(args):
text_recognizer = TextRecognizer(args)
valid_image_file_list = []
img_list = []
- cpu_mem, gpu_mem, gpu_util = 0, 0, 0
- count = 0
- # warmup 10 times
- fake_img = np.random.uniform(-1, 1, [1, 32, 320, 3]).astype(np.float32)
- for i in range(10):
- dt_boxes, _ = text_recognizer(fake_img)
+ # warmup 2 times
+ if args.warmup:
+ img = np.random.uniform(0, 255, [32, 320, 3]).astype(np.uint8)
+ for i in range(2):
+ res = text_recognizer([img])
for image_file in image_file_list:
img, flag = check_and_read_gif(image_file)
@@ -276,12 +290,6 @@ def main(args):
img_list.append(img)
try:
rec_res, _ = text_recognizer(img_list)
- if args.benchmark:
- cm, gm, gu = utility.get_current_memory_mb(0)
- cpu_mem += cm
- gpu_mem += gm
- gpu_util += gu
- count += 1
except Exception as E:
logger.info(traceback.format_exc())
@@ -291,37 +299,7 @@ def main(args):
logger.info("Predicts of {}:{}".format(valid_image_file_list[ino],
rec_res[ino]))
if args.benchmark:
- mems = {
- 'cpu_rss_mb': cpu_mem / count,
- 'gpu_rss_mb': gpu_mem / count,
- 'gpu_util': gpu_util * 100 / count
- }
- else:
- mems = None
- logger.info("The predict time about recognizer module is as follows: ")
- rec_time_dict = text_recognizer.rec_times.report(average=True)
- rec_model_name = args.rec_model_dir
-
- if args.benchmark:
- # construct log information
- model_info = {
- 'model_name': args.rec_model_dir.split('/')[-1],
- 'precision': args.precision
- }
- data_info = {
- 'batch_size': args.rec_batch_num,
- 'shape': 'dynamic_shape',
- 'data_num': rec_time_dict['img_num']
- }
- perf_info = {
- 'preprocess_time_s': rec_time_dict['preprocess_time'],
- 'inference_time_s': rec_time_dict['inference_time'],
- 'postprocess_time_s': rec_time_dict['postprocess_time'],
- 'total_time_s': rec_time_dict['total_time']
- }
- benchmark_log = benchmark_utils.PaddleInferBenchmark(
- text_recognizer.config, model_info, data_info, perf_info, mems)
- benchmark_log("Rec")
+ text_recognizer.autolog.report()
if __name__ == "__main__":
diff --git a/tools/infer/predict_system.py b/tools/infer/predict_system.py
index ad1b7d4ef432924f1781a16eae07c171e571826b..eae0e27cd284ccce9f41f0c20b05dee09f46fc84 100755
--- a/tools/infer/predict_system.py
+++ b/tools/infer/predict_system.py
@@ -13,6 +13,7 @@
# limitations under the License.
import os
import sys
+import subprocess
__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(__dir__)
@@ -24,6 +25,7 @@ import cv2
import copy
import numpy as np
import time
+import logging
from PIL import Image
import tools.infer.utility as utility
import tools.infer.predict_rec as predict_rec
@@ -31,13 +33,15 @@ import tools.infer.predict_det as predict_det
import tools.infer.predict_cls as predict_cls
from ppocr.utils.utility import get_image_file_list, check_and_read_gif
from ppocr.utils.logging import get_logger
-from tools.infer.utility import draw_ocr_box_txt, get_current_memory_mb
-import tools.infer.benchmark_utils as benchmark_utils
+from tools.infer.utility import draw_ocr_box_txt, get_rotate_crop_image
logger = get_logger()
class TextSystem(object):
def __init__(self, args):
+ if not args.show_log:
+ logger.setLevel(logging.INFO)
+
self.text_detector = predict_det.TextDetector(args)
self.text_recognizer = predict_rec.TextRecognizer(args)
self.use_angle_cls = args.use_angle_cls
@@ -45,39 +49,6 @@ class TextSystem(object):
if self.use_angle_cls:
self.text_classifier = predict_cls.TextClassifier(args)
- def get_rotate_crop_image(self, img, points):
- '''
- img_height, img_width = img.shape[0:2]
- left = int(np.min(points[:, 0]))
- right = int(np.max(points[:, 0]))
- top = int(np.min(points[:, 1]))
- bottom = int(np.max(points[:, 1]))
- img_crop = img[top:bottom, left:right, :].copy()
- points[:, 0] = points[:, 0] - left
- points[:, 1] = points[:, 1] - top
- '''
- img_crop_width = int(
- max(
- np.linalg.norm(points[0] - points[1]),
- np.linalg.norm(points[2] - points[3])))
- img_crop_height = int(
- max(
- np.linalg.norm(points[0] - points[3]),
- np.linalg.norm(points[1] - points[2])))
- pts_std = np.float32([[0, 0], [img_crop_width, 0],
- [img_crop_width, img_crop_height],
- [0, img_crop_height]])
- M = cv2.getPerspectiveTransform(points, pts_std)
- dst_img = cv2.warpPerspective(
- img,
- M, (img_crop_width, img_crop_height),
- borderMode=cv2.BORDER_REPLICATE,
- flags=cv2.INTER_CUBIC)
- dst_img_height, dst_img_width = dst_img.shape[0:2]
- if dst_img_height * 1.0 / dst_img_width >= 1.5:
- dst_img = np.rot90(dst_img)
- return dst_img
-
def print_draw_crop_rec_res(self, img_crop_list, rec_res):
bbox_num = len(img_crop_list)
for bno in range(bbox_num):
@@ -88,8 +59,7 @@ class TextSystem(object):
ori_im = img.copy()
dt_boxes, elapse = self.text_detector(img)
- logger.info("dt_boxes num : {}, elapse : {}".format(
-
+ logger.debug("dt_boxes num : {}, elapse : {}".format(
len(dt_boxes), elapse))
if dt_boxes is None:
return None, None
@@ -99,16 +69,16 @@ class TextSystem(object):
for bno in range(len(dt_boxes)):
tmp_box = copy.deepcopy(dt_boxes[bno])
- img_crop = self.get_rotate_crop_image(ori_im, tmp_box)
+ img_crop = get_rotate_crop_image(ori_im, tmp_box)
img_crop_list.append(img_crop)
if self.use_angle_cls and cls:
img_crop_list, angle_list, elapse = self.text_classifier(
img_crop_list)
- logger.info("cls num : {}, elapse : {}".format(
+ logger.debug("cls num : {}, elapse : {}".format(
len(img_crop_list), elapse))
rec_res, elapse = self.text_recognizer(img_crop_list)
- logger.info("rec_res num : {}, elapse : {}".format(
+ logger.debug("rec_res num : {}, elapse : {}".format(
len(rec_res), elapse))
# self.print_draw_crop_rec_res(img_crop_list, rec_res)
filter_boxes, filter_rec_res = [], []
@@ -143,15 +113,24 @@ def sorted_boxes(dt_boxes):
def main(args):
image_file_list = get_image_file_list(args.image_dir)
+ image_file_list = image_file_list[args.process_id::args.total_process_num]
text_sys = TextSystem(args)
is_visualize = True
font_path = args.vis_font_path
drop_score = args.drop_score
+
+ # warm up 10 times
+ if args.warmup:
+ img = np.random.uniform(0, 255, [640, 640, 3]).astype(np.uint8)
+ for i in range(10):
+ res = text_sys(img)
+
total_time = 0
cpu_mem, gpu_mem, gpu_util = 0, 0, 0
_st = time.time()
count = 0
for idx, image_file in enumerate(image_file_list):
+
img, flag = check_and_read_gif(image_file)
if not flag:
img = cv2.imread(image_file)
@@ -162,12 +141,6 @@ def main(args):
dt_boxes, rec_res = text_sys(img)
elapse = time.time() - starttime
total_time += elapse
- if args.benchmark and idx % 20 == 0:
- cm, gm, gu = get_current_memory_mb(0)
- cpu_mem += cm
- gpu_mem += gm
- gpu_util += gu
- count += 1
logger.info(
str(idx) + " Predict time of %s: %.3fs" % (image_file, elapse))
@@ -201,63 +174,20 @@ def main(args):
logger.info("The predict total time is {}".format(time.time() - _st))
logger.info("\nThe predict total time is {}".format(total_time))
- img_num = text_sys.text_detector.det_times.img_num
- if args.benchmark:
- mems = {
- 'cpu_rss_mb': cpu_mem / count,
- 'gpu_rss_mb': gpu_mem / count,
- 'gpu_util': gpu_util * 100 / count
- }
- else:
- mems = None
- det_time_dict = text_sys.text_detector.det_times.report(average=True)
- rec_time_dict = text_sys.text_recognizer.rec_times.report(average=True)
- det_model_name = args.det_model_dir
- rec_model_name = args.rec_model_dir
-
- # construct det log information
- model_info = {
- 'model_name': args.det_model_dir.split('/')[-1],
- 'precision': args.precision
- }
- data_info = {
- 'batch_size': 1,
- 'shape': 'dynamic_shape',
- 'data_num': det_time_dict['img_num']
- }
- perf_info = {
- 'preprocess_time_s': det_time_dict['preprocess_time'],
- 'inference_time_s': det_time_dict['inference_time'],
- 'postprocess_time_s': det_time_dict['postprocess_time'],
- 'total_time_s': det_time_dict['total_time']
- }
-
- benchmark_log = benchmark_utils.PaddleInferBenchmark(
- text_sys.text_detector.config, model_info, data_info, perf_info, mems,
- args.save_log_path)
- benchmark_log("Det")
-
- # construct rec log information
- model_info = {
- 'model_name': args.rec_model_dir.split('/')[-1],
- 'precision': args.precision
- }
- data_info = {
- 'batch_size': args.rec_batch_num,
- 'shape': 'dynamic_shape',
- 'data_num': rec_time_dict['img_num']
- }
- perf_info = {
- 'preprocess_time_s': rec_time_dict['preprocess_time'],
- 'inference_time_s': rec_time_dict['inference_time'],
- 'postprocess_time_s': rec_time_dict['postprocess_time'],
- 'total_time_s': rec_time_dict['total_time']
- }
- benchmark_log = benchmark_utils.PaddleInferBenchmark(
- text_sys.text_recognizer.config, model_info, data_info, perf_info, mems,
- args.save_log_path)
- benchmark_log("Rec")
-
if __name__ == "__main__":
- main(utility.parse_args())
+ args = utility.parse_args()
+ if args.use_mp:
+ p_list = []
+ total_process_num = args.total_process_num
+ for process_id in range(total_process_num):
+ cmd = [sys.executable, "-u"] + sys.argv + [
+ "--process_id={}".format(process_id),
+ "--use_mp={}".format(False)
+ ]
+ p = subprocess.Popen(cmd, stdout=sys.stdout, stderr=sys.stdout)
+ p_list.append(p)
+ for p in p_list:
+ p.wait()
+ else:
+ main(args)
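Note on the multi-process branch added above: the parent process re-launches the same script once per worker with a distinct `--process_id`, and each worker keeps only its strided slice of the image list. A minimal standalone sketch of that pattern (the `worker.py` name and the flag are placeholders, not part of the repository):

```python
# Sketch of the strided multi-process pattern used in predict_system.py.
# "worker.py" and the flag name are illustrative placeholders.
import subprocess
import sys


def split_for_worker(files, process_id, total_process_num):
    # Each worker takes every total_process_num-th file, offset by its id.
    return files[process_id::total_process_num]


def launch_workers(total_process_num):
    procs = []
    for process_id in range(total_process_num):
        cmd = [sys.executable, "-u", "worker.py",
               "--process_id={}".format(process_id)]
        procs.append(subprocess.Popen(cmd))
    for p in procs:
        p.wait()
```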
diff --git a/tools/infer/utility.py b/tools/infer/utility.py
index 69f28e00b39b657954902e1b4c851fe357ea3619..1c82280099f17f6d3bf848669e47439505f10576 100755
--- a/tools/infer/utility.py
+++ b/tools/infer/utility.py
@@ -24,8 +24,6 @@ from paddle import inference
import time
from ppocr.utils.logging import get_logger
-logger = get_logger()
-
def str2bool(v):
return v.lower() in ("true", "t", "1")
@@ -37,6 +35,7 @@ def init_args():
parser.add_argument("--use_gpu", type=str2bool, default=True)
parser.add_argument("--ir_optim", type=str2bool, default=True)
parser.add_argument("--use_tensorrt", type=str2bool, default=False)
+ parser.add_argument("--min_subgraph_size", type=int, default=10)
parser.add_argument("--precision", type=str, default="fp32")
parser.add_argument("--gpu_mem", type=int, default=500)
@@ -49,10 +48,10 @@ def init_args():
    # DB params
parser.add_argument("--det_db_thresh", type=float, default=0.3)
- parser.add_argument("--det_db_box_thresh", type=float, default=0.5)
- parser.add_argument("--det_db_unclip_ratio", type=float, default=1.6)
+ parser.add_argument("--det_db_box_thresh", type=float, default=0.6)
+ parser.add_argument("--det_db_unclip_ratio", type=float, default=1.5)
parser.add_argument("--max_batch_size", type=int, default=10)
- parser.add_argument("--use_dilation", type=bool, default=False)
+ parser.add_argument("--use_dilation", type=str2bool, default=False)
parser.add_argument("--det_db_score_mode", type=str, default="fast")
    # EAST params
parser.add_argument("--det_east_score_thresh", type=float, default=0.8)
@@ -62,7 +61,7 @@ def init_args():
    # SAST params
parser.add_argument("--det_sast_score_thresh", type=float, default=0.5)
parser.add_argument("--det_sast_nms_thresh", type=float, default=0.2)
- parser.add_argument("--det_sast_polygon", type=bool, default=False)
+ parser.add_argument("--det_sast_polygon", type=str2bool, default=False)
# params for text recognizer
parser.add_argument("--rec_algorithm", type=str, default='CRNN')
@@ -91,7 +90,7 @@ def init_args():
parser.add_argument(
"--e2e_char_dict_path", type=str, default="./ppocr/utils/ic15_dict.txt")
parser.add_argument("--e2e_pgnet_valid_set", type=str, default='totaltext')
- parser.add_argument("--e2e_pgnet_polygon", type=bool, default=True)
+ parser.add_argument("--e2e_pgnet_polygon", type=str2bool, default=True)
parser.add_argument("--e2e_pgnet_mode", type=str, default='fast')
# params for text classifier
@@ -105,15 +104,17 @@ def init_args():
parser.add_argument("--enable_mkldnn", type=str2bool, default=False)
parser.add_argument("--cpu_threads", type=int, default=10)
parser.add_argument("--use_pdserving", type=str2bool, default=False)
+ parser.add_argument("--warmup", type=str2bool, default=True)
+ # multi-process
parser.add_argument("--use_mp", type=str2bool, default=False)
parser.add_argument("--total_process_num", type=int, default=1)
parser.add_argument("--process_id", type=int, default=0)
-
- parser.add_argument("--benchmark", type=bool, default=False)
- parser.add_argument("--save_log_path", type=str, default="./log_output/")
+ parser.add_argument("--benchmark", type=str2bool, default=False)
+ parser.add_argument("--save_log_path", type=str, default="./log_output/")
+ parser.add_argument("--show_log", type=str2bool, default=True)
return parser
@@ -122,76 +123,6 @@ def parse_args():
return parser.parse_args()
-class Times(object):
- def __init__(self):
- self.time = 0.
- self.st = 0.
- self.et = 0.
-
- def start(self):
- self.st = time.time()
-
- def end(self, accumulative=True):
- self.et = time.time()
- if accumulative:
- self.time += self.et - self.st
- else:
- self.time = self.et - self.st
-
- def reset(self):
- self.time = 0.
- self.st = 0.
- self.et = 0.
-
- def value(self):
- return round(self.time, 4)
-
-
-class Timer(Times):
- def __init__(self):
- super(Timer, self).__init__()
- self.total_time = Times()
- self.preprocess_time = Times()
- self.inference_time = Times()
- self.postprocess_time = Times()
- self.img_num = 0
-
- def info(self, average=False):
- logger.info("----------------------- Perf info -----------------------")
- logger.info("total_time: {}, img_num: {}".format(self.total_time.value(
- ), self.img_num))
- preprocess_time = round(self.preprocess_time.value() / self.img_num,
- 4) if average else self.preprocess_time.value()
- postprocess_time = round(
- self.postprocess_time.value() / self.img_num,
- 4) if average else self.postprocess_time.value()
- inference_time = round(self.inference_time.value() / self.img_num,
- 4) if average else self.inference_time.value()
-
- average_latency = self.total_time.value() / self.img_num
- logger.info("average_latency(ms): {:.2f}, QPS: {:2f}".format(
- average_latency * 1000, 1 / average_latency))
- logger.info(
- "preprocess_latency(ms): {:.2f}, inference_latency(ms): {:.2f}, postprocess_latency(ms): {:.2f}".
- format(preprocess_time * 1000, inference_time * 1000,
- postprocess_time * 1000))
-
- def report(self, average=False):
- dic = {}
- dic['preprocess_time'] = round(
- self.preprocess_time.value() / self.img_num,
- 4) if average else self.preprocess_time.value()
- dic['postprocess_time'] = round(
- self.postprocess_time.value() / self.img_num,
- 4) if average else self.postprocess_time.value()
- dic['inference_time'] = round(
- self.inference_time.value() / self.img_num,
- 4) if average else self.inference_time.value()
- dic['img_num'] = self.img_num
- dic['total_time'] = round(self.total_time.value(), 4)
- return dic
-
-
def create_predictor(args, mode, logger):
if mode == "det":
model_dir = args.det_model_dir
@@ -199,6 +130,8 @@ def create_predictor(args, mode, logger):
model_dir = args.cls_model_dir
elif mode == 'rec':
model_dir = args.rec_model_dir
+ elif mode == 'table':
+ model_dir = args.table_model_dir
else:
model_dir = args.e2e_model_dir
@@ -208,11 +141,10 @@ def create_predictor(args, mode, logger):
model_file_path = model_dir + "/inference.pdmodel"
params_file_path = model_dir + "/inference.pdiparams"
if not os.path.exists(model_file_path):
- logger.info("not find model file path {}".format(model_file_path))
- sys.exit(0)
+ raise ValueError("not find model file path {}".format(model_file_path))
if not os.path.exists(params_file_path):
- logger.info("not find params file path {}".format(params_file_path))
- sys.exit(0)
+ raise ValueError("not find params file path {}".format(
+ params_file_path))
config = inference.Config(model_file_path, params_file_path)
@@ -230,71 +162,74 @@ def create_predictor(args, mode, logger):
config.enable_use_gpu(args.gpu_mem, 0)
if args.use_tensorrt:
config.enable_tensorrt_engine(
- precision_mode=inference.PrecisionType.Float32,
+ precision_mode=precision,
max_batch_size=args.max_batch_size,
- min_subgraph_size=3) # skip the minmum trt subgraph
- if mode == "det" and "mobile" in model_file_path:
+ min_subgraph_size=args.min_subgraph_size)
+            # skip the minimum TRT subgraph
+ if mode == "det":
min_input_shape = {
"x": [1, 3, 50, 50],
- "conv2d_92.tmp_0": [1, 96, 20, 20],
- "conv2d_91.tmp_0": [1, 96, 10, 10],
- "nearest_interp_v2_1.tmp_0": [1, 96, 10, 10],
- "nearest_interp_v2_2.tmp_0": [1, 96, 20, 20],
- "nearest_interp_v2_3.tmp_0": [1, 24, 20, 20],
- "nearest_interp_v2_4.tmp_0": [1, 24, 20, 20],
- "nearest_interp_v2_5.tmp_0": [1, 24, 20, 20],
+ "conv2d_92.tmp_0": [1, 120, 20, 20],
+ "conv2d_91.tmp_0": [1, 24, 10, 10],
+ "conv2d_59.tmp_0": [1, 96, 20, 20],
+ "nearest_interp_v2_1.tmp_0": [1, 256, 10, 10],
+ "nearest_interp_v2_2.tmp_0": [1, 256, 20, 20],
+ "conv2d_124.tmp_0": [1, 256, 20, 20],
+ "nearest_interp_v2_3.tmp_0": [1, 64, 20, 20],
+ "nearest_interp_v2_4.tmp_0": [1, 64, 20, 20],
+ "nearest_interp_v2_5.tmp_0": [1, 64, 20, 20],
"elementwise_add_7": [1, 56, 2, 2],
- "nearest_interp_v2_0.tmp_0": [1, 96, 2, 2]
+ "nearest_interp_v2_0.tmp_0": [1, 256, 2, 2]
}
max_input_shape = {
"x": [1, 3, 2000, 2000],
- "conv2d_92.tmp_0": [1, 96, 400, 400],
- "conv2d_91.tmp_0": [1, 96, 200, 200],
- "nearest_interp_v2_1.tmp_0": [1, 96, 200, 200],
- "nearest_interp_v2_2.tmp_0": [1, 96, 400, 400],
- "nearest_interp_v2_3.tmp_0": [1, 24, 400, 400],
- "nearest_interp_v2_4.tmp_0": [1, 24, 400, 400],
- "nearest_interp_v2_5.tmp_0": [1, 24, 400, 400],
+ "conv2d_92.tmp_0": [1, 120, 400, 400],
+ "conv2d_91.tmp_0": [1, 24, 200, 200],
+ "conv2d_59.tmp_0": [1, 96, 400, 400],
+ "nearest_interp_v2_1.tmp_0": [1, 256, 200, 200],
+ "conv2d_124.tmp_0": [1, 256, 400, 400],
+ "nearest_interp_v2_2.tmp_0": [1, 256, 400, 400],
+ "nearest_interp_v2_3.tmp_0": [1, 64, 400, 400],
+ "nearest_interp_v2_4.tmp_0": [1, 64, 400, 400],
+ "nearest_interp_v2_5.tmp_0": [1, 64, 400, 400],
"elementwise_add_7": [1, 56, 400, 400],
- "nearest_interp_v2_0.tmp_0": [1, 96, 400, 400]
+ "nearest_interp_v2_0.tmp_0": [1, 256, 400, 400]
}
opt_input_shape = {
"x": [1, 3, 640, 640],
- "conv2d_92.tmp_0": [1, 96, 160, 160],
- "conv2d_91.tmp_0": [1, 96, 80, 80],
- "nearest_interp_v2_1.tmp_0": [1, 96, 80, 80],
- "nearest_interp_v2_2.tmp_0": [1, 96, 160, 160],
- "nearest_interp_v2_3.tmp_0": [1, 24, 160, 160],
- "nearest_interp_v2_4.tmp_0": [1, 24, 160, 160],
- "nearest_interp_v2_5.tmp_0": [1, 24, 160, 160],
+ "conv2d_92.tmp_0": [1, 120, 160, 160],
+ "conv2d_91.tmp_0": [1, 24, 80, 80],
+ "conv2d_59.tmp_0": [1, 96, 160, 160],
+ "nearest_interp_v2_1.tmp_0": [1, 256, 80, 80],
+ "nearest_interp_v2_2.tmp_0": [1, 256, 160, 160],
+ "conv2d_124.tmp_0": [1, 256, 160, 160],
+ "nearest_interp_v2_3.tmp_0": [1, 64, 160, 160],
+ "nearest_interp_v2_4.tmp_0": [1, 64, 160, 160],
+ "nearest_interp_v2_5.tmp_0": [1, 64, 160, 160],
"elementwise_add_7": [1, 56, 40, 40],
- "nearest_interp_v2_0.tmp_0": [1, 96, 40, 40]
+ "nearest_interp_v2_0.tmp_0": [1, 256, 40, 40]
}
- if mode == "det" and "server" in model_file_path:
- min_input_shape = {
- "x": [1, 3, 50, 50],
- "conv2d_59.tmp_0": [1, 96, 20, 20],
- "nearest_interp_v2_2.tmp_0": [1, 96, 20, 20],
- "nearest_interp_v2_3.tmp_0": [1, 24, 20, 20],
- "nearest_interp_v2_4.tmp_0": [1, 24, 20, 20],
- "nearest_interp_v2_5.tmp_0": [1, 24, 20, 20]
+ min_pact_shape = {
+ "nearest_interp_v2_26.tmp_0": [1, 256, 20, 20],
+ "nearest_interp_v2_27.tmp_0": [1, 64, 20, 20],
+ "nearest_interp_v2_28.tmp_0": [1, 64, 20, 20],
+ "nearest_interp_v2_29.tmp_0": [1, 64, 20, 20]
}
- max_input_shape = {
- "x": [1, 3, 2000, 2000],
- "conv2d_59.tmp_0": [1, 96, 400, 400],
- "nearest_interp_v2_2.tmp_0": [1, 96, 400, 400],
- "nearest_interp_v2_3.tmp_0": [1, 24, 400, 400],
- "nearest_interp_v2_4.tmp_0": [1, 24, 400, 400],
- "nearest_interp_v2_5.tmp_0": [1, 24, 400, 400]
+ max_pact_shape = {
+ "nearest_interp_v2_26.tmp_0": [1, 256, 400, 400],
+ "nearest_interp_v2_27.tmp_0": [1, 64, 400, 400],
+ "nearest_interp_v2_28.tmp_0": [1, 64, 400, 400],
+ "nearest_interp_v2_29.tmp_0": [1, 64, 400, 400]
}
- opt_input_shape = {
- "x": [1, 3, 640, 640],
- "conv2d_59.tmp_0": [1, 96, 160, 160],
- "nearest_interp_v2_2.tmp_0": [1, 96, 160, 160],
- "nearest_interp_v2_3.tmp_0": [1, 24, 160, 160],
- "nearest_interp_v2_4.tmp_0": [1, 24, 160, 160],
- "nearest_interp_v2_5.tmp_0": [1, 24, 160, 160]
+ opt_pact_shape = {
+ "nearest_interp_v2_26.tmp_0": [1, 256, 160, 160],
+ "nearest_interp_v2_27.tmp_0": [1, 64, 160, 160],
+ "nearest_interp_v2_28.tmp_0": [1, 64, 160, 160],
+ "nearest_interp_v2_29.tmp_0": [1, 64, 160, 160]
}
+ min_input_shape.update(min_pact_shape)
+ max_input_shape.update(max_pact_shape)
+ opt_input_shape.update(opt_pact_shape)
elif mode == "rec":
min_input_shape = {"x": [args.rec_batch_num, 3, 32, 10]}
max_input_shape = {"x": [args.rec_batch_num, 3, 32, 2000]}
@@ -324,10 +259,13 @@ def create_predictor(args, mode, logger):
# enable memory optim
config.enable_memory_optim()
- config.disable_glog_info()
+    # config.disable_glog_info()
config.delete_pass("conv_transpose_eltwiseadd_bn_fuse_pass")
+ if mode == 'table':
+ config.delete_pass("fc_fuse_pass") # not supported for table
config.switch_use_feed_fetch_ops(False)
+ config.switch_ir_optim(True)
# create predictor
predictor = inference.create_predictor(config)
@@ -590,29 +528,39 @@ def draw_boxes(image, boxes, scores=None, drop_score=0.5):
return image
-def get_current_memory_mb(gpu_id=None):
- """
- It is used to Obtain the memory usage of the CPU and GPU during the running of the program.
- And this function Current program is time-consuming.
- """
- import pynvml
- import psutil
- import GPUtil
- pid = os.getpid()
- p = psutil.Process(pid)
- info = p.memory_full_info()
- cpu_mem = info.uss / 1024. / 1024.
- gpu_mem = 0
- gpu_percent = 0
- if gpu_id is not None:
- GPUs = GPUtil.getGPUs()
- gpu_load = GPUs[gpu_id].load
- gpu_percent = gpu_load
- pynvml.nvmlInit()
- handle = pynvml.nvmlDeviceGetHandleByIndex(0)
- meminfo = pynvml.nvmlDeviceGetMemoryInfo(handle)
- gpu_mem = meminfo.used / 1024. / 1024.
- return round(cpu_mem, 4), round(gpu_mem, 4), round(gpu_percent, 4)
+def get_rotate_crop_image(img, points):
+ '''
+ img_height, img_width = img.shape[0:2]
+ left = int(np.min(points[:, 0]))
+ right = int(np.max(points[:, 0]))
+ top = int(np.min(points[:, 1]))
+ bottom = int(np.max(points[:, 1]))
+ img_crop = img[top:bottom, left:right, :].copy()
+ points[:, 0] = points[:, 0] - left
+ points[:, 1] = points[:, 1] - top
+ '''
+ assert len(points) == 4, "shape of points must be 4*2"
+ img_crop_width = int(
+ max(
+ np.linalg.norm(points[0] - points[1]),
+ np.linalg.norm(points[2] - points[3])))
+ img_crop_height = int(
+ max(
+ np.linalg.norm(points[0] - points[3]),
+ np.linalg.norm(points[1] - points[2])))
+ pts_std = np.float32([[0, 0], [img_crop_width, 0],
+ [img_crop_width, img_crop_height],
+ [0, img_crop_height]])
+ M = cv2.getPerspectiveTransform(points, pts_std)
+ dst_img = cv2.warpPerspective(
+ img,
+ M, (img_crop_width, img_crop_height),
+ borderMode=cv2.BORDER_REPLICATE,
+ flags=cv2.INTER_CUBIC)
+ dst_img_height, dst_img_width = dst_img.shape[0:2]
+ if dst_img_height * 1.0 / dst_img_width >= 1.5:
+ dst_img = np.rot90(dst_img)
+ return dst_img
if __name__ == '__main__':
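Several flags above move from `type=bool` to the `str2bool` helper. With argparse, `type=bool` simply calls `bool()` on the raw string, so any non-empty value, including `"False"`, parses as `True`. A small self-contained check of the difference, using only the standard library:

```python
# bool("False") is True because bool() only checks whether the string is
# empty, so argparse flags declared with type=bool cannot be switched off
# from the command line; str2bool parses the content instead.
import argparse


def str2bool(v):
    return v.lower() in ("true", "t", "1")


parser = argparse.ArgumentParser()
parser.add_argument("--buggy", type=bool, default=False)
parser.add_argument("--fixed", type=str2bool, default=False)

args = parser.parse_args(["--buggy=False", "--fixed=False"])
print(args.buggy)  # True  -- any non-empty string is truthy
print(args.fixed)  # False -- value parsed by content
```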
diff --git a/tools/infer_det.py b/tools/infer_det.py
index 674f52ee35aab25356ccdbf371f8bac5b52b871a..a964cd28c934504ce79ea4873d3345295c1266e5 100755
--- a/tools/infer_det.py
+++ b/tools/infer_det.py
@@ -112,4 +112,4 @@ def main():
if __name__ == '__main__':
config, device, logger, vdl_writer = program.preprocess()
- main()
\ No newline at end of file
+ main()
diff --git a/tools/infer_table.py b/tools/infer_table.py
new file mode 100644
index 0000000000000000000000000000000000000000..f743d87540f7fd64157a808db156c9f62a042d9c
--- /dev/null
+++ b/tools/infer_table.py
@@ -0,0 +1,107 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+import os
+import sys
+import json
+
+__dir__ = os.path.dirname(os.path.abspath(__file__))
+sys.path.append(__dir__)
+sys.path.append(os.path.abspath(os.path.join(__dir__, '..')))
+
+os.environ["FLAGS_allocator_strategy"] = 'auto_growth'
+
+import paddle
+from paddle.jit import to_static
+
+from ppocr.data import create_operators, transform
+from ppocr.modeling.architectures import build_model
+from ppocr.postprocess import build_post_process
+from ppocr.utils.save_load import init_model
+from ppocr.utils.utility import get_image_file_list
+import tools.program as program
+import cv2
+
+def main(config, device, logger, vdl_writer):
+ global_config = config['Global']
+
+ # build post process
+ post_process_class = build_post_process(config['PostProcess'],
+ global_config)
+
+ # build model
+ if hasattr(post_process_class, 'character'):
+ config['Architecture']["Head"]['out_channels'] = len(
+ getattr(post_process_class, 'character'))
+
+ model = build_model(config['Architecture'])
+
+ init_model(config, model, logger)
+
+ # create data ops
+ transforms = []
+ use_padding = False
+ for op in config['Eval']['dataset']['transforms']:
+ op_name = list(op)[0]
+ if 'Label' in op_name:
+ continue
+ if op_name == 'KeepKeys':
+ op[op_name]['keep_keys'] = ['image']
+ if op_name == "ResizeTableImage":
+ use_padding = True
+ padding_max_len = op['ResizeTableImage']['max_len']
+ transforms.append(op)
+
+ global_config['infer_mode'] = True
+ ops = create_operators(transforms, global_config)
+
+ model.eval()
+ for file in get_image_file_list(config['Global']['infer_img']):
+ logger.info("infer_img: {}".format(file))
+ with open(file, 'rb') as f:
+ img = f.read()
+ data = {'image': img}
+ batch = transform(data, ops)
+ images = np.expand_dims(batch[0], axis=0)
+ images = paddle.to_tensor(images)
+ preds = model(images)
+ post_result = post_process_class(preds)
+ res_html_code = post_result['res_html_code']
+ res_loc = post_result['res_loc']
+ img = cv2.imread(file)
+ imgh, imgw = img.shape[0:2]
+ res_loc_final = []
+ for rno in range(len(res_loc[0])):
+ x0, y0, x1, y1 = res_loc[0][rno]
+ left = max(int(imgw * x0), 0)
+ top = max(int(imgh * y0), 0)
+ right = min(int(imgw * x1), imgw - 1)
+ bottom = min(int(imgh * y1), imgh - 1)
+ cv2.rectangle(img, (left, top), (right, bottom), (0, 0, 255), 2)
+ res_loc_final.append([left, top, right, bottom])
+ res_loc_str = json.dumps(res_loc_final)
+ logger.info("result: {}, {}".format(res_html_code, res_loc_final))
+ logger.info("success!")
+
+
+if __name__ == '__main__':
+ config, device, logger, vdl_writer = program.preprocess()
+ main(config, device, logger, vdl_writer)
+
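For reference, the loop above maps the normalized cell locations in `res_loc` back to pixel coordinates and clamps them to the image bounds before drawing. A tiny worked example of that arithmetic (image size and box values are illustrative only):

```python
# Illustrative numbers only: a 640x480 image and a normalized box
# [0.1, 0.2, 0.5, 0.9] become the clamped pixel box (64, 96, 320, 432).
imgw, imgh = 640, 480
x0, y0, x1, y1 = 0.1, 0.2, 0.5, 0.9
left = max(int(imgw * x0), 0)            # 64
top = max(int(imgh * y0), 0)             # 96
right = min(int(imgw * x1), imgw - 1)    # 320
bottom = min(int(imgh * y1), imgh - 1)   # 432
print(left, top, right, bottom)
```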
diff --git a/tools/program.py b/tools/program.py
index 71076a19f7a391bd60ef585e3893617941ac5bb5..aa5f9388064f79f3c360e29c0c93623e708ad5cf 100755
--- a/tools/program.py
+++ b/tools/program.py
@@ -1,4 +1,4 @@
-# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -186,7 +186,14 @@ def train(config,
model.train()
use_srn = config['Architecture']['algorithm'] == "SRN"
+
use_nrtr = config['Architecture']['algorithm'] == "NRTR"
+
+ try:
+ model_type = config['Architecture']['model_type']
+    except KeyError:  # model_type is optional in the config
+ model_type = None
+
if 'start_epoch' in best_model_dict:
start_epoch = best_model_dict['start_epoch']
else:
@@ -208,9 +215,9 @@ def train(config,
lr = optimizer.get_lr()
images = batch[0]
if use_srn:
- others = batch[-4:]
- preds = model(images, others)
model_average = True
+ if use_srn or model_type == 'table':
+ preds = model(images, data=batch[1:])
elif use_nrtr:
max_len = batch[2].max()
preds = model(images, batch[1][:,:2+max_len])
@@ -235,8 +242,11 @@ def train(config,
if cal_metric_during_train: # only rec and cls need
batch = [item.numpy() for item in batch]
- post_result = post_process_class(preds, batch[1])
- eval_class(post_result, batch)
+ if model_type == 'table':
+ eval_class(preds, batch)
+ else:
+ post_result = post_process_class(preds, batch[1])
+ eval_class(post_result, batch)
metric = eval_class.get_metric()
train_stats.update(metric)
@@ -272,6 +282,7 @@ def train(config,
valid_dataloader,
post_process_class,
eval_class,
+ model_type,
use_srn=use_srn)
cur_metric_str = 'cur metric, {}'.format(', '.join(
['{}: {}'.format(k, v) for k, v in cur_metric.items()]))
@@ -339,7 +350,11 @@ def train(config,
return
-def eval(model, valid_dataloader, post_process_class, eval_class,
+def eval(model,
+ valid_dataloader,
+ post_process_class,
+ eval_class,
+ model_type,
use_srn=False):
model.eval()
with paddle.no_grad():
@@ -353,17 +368,20 @@ def eval(model, valid_dataloader, post_process_class, eval_class,
break
images = batch[0]
start = time.time()
- if use_srn:
- others = batch[-4:]
- preds = model(images, others)
+
+ if use_srn or model_type == 'table':
+ preds = model(images, data=batch[1:])
else:
preds = model(images)
batch = [item.numpy() for item in batch]
# Obtain usable results from post-processing methods
- post_result = post_process_class(preds, batch[1])
total_time += time.time() - start
# Evaluate the results of the current batch
- eval_class(post_result, batch)
+ if model_type == 'table':
+ eval_class(preds, batch)
+ else:
+ post_result = post_process_class(preds, batch[1])
+ eval_class(post_result, batch)
pbar.update(1)
total_frame += len(images)
# Get final metric,eg. acc or hmean
@@ -387,7 +405,9 @@ def preprocess(is_train=False):
alg = config['Architecture']['algorithm']
assert alg in [
'EAST', 'DB', 'SAST', 'Rosetta', 'CRNN', 'STARNet', 'RARE', 'SRN',
- 'CLS', 'PGNet', 'Distillation', 'NRTR'
+
+ 'CLS', 'PGNet', 'Distillation', 'NRTR', 'TableAttn'
+
]
device = 'gpu:{}'.format(dist.ParallelEnv().dev_id) if use_gpu else 'cpu'
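Note that `eval()` now takes `model_type` as a positional parameter, so any caller of `program.eval` must supply it (the training loop above passes the value read from the config). A condensed sketch of the resulting dispatch, with the surrounding callables treated as given:

```python
# Condensed view of the new branching in train()/eval(): table models are
# scored on raw predictions, other heads go through post-processing first.
# post_process_class and eval_class are the callables built by the pipeline.
def score_batch(preds, batch, model_type, post_process_class, eval_class):
    if model_type == 'table':
        eval_class(preds, batch)
    else:
        post_result = post_process_class(preds, batch[1])
        eval_class(post_result, batch)
```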
diff --git a/tools/train.py b/tools/train.py
index b024240b4d5d4973645336c62d3762087ec7bbeb..05d295aa99718c25b94a123c23d08c2904fe8c6a 100755
--- a/tools/train.py
+++ b/tools/train.py
@@ -35,7 +35,7 @@ from ppocr.losses import build_loss
from ppocr.optimizer import build_optimizer
from ppocr.postprocess import build_post_process
from ppocr.metrics import build_metric
-from ppocr.utils.save_load import init_model
+from ppocr.utils.save_load import init_model, load_dygraph_params
import tools.program as program
dist.get_world_size()
@@ -97,8 +97,7 @@ def main(config, device, logger, vdl_writer):
# build metric
eval_class = build_metric(config['Metric'])
# load pretrain model
- pre_best_model_dict = init_model(config, model, optimizer)
-
+ pre_best_model_dict = load_dygraph_params(config, model, logger, optimizer)
logger.info('train dataloader has {} iters'.format(len(train_dataloader)))
if valid_dataloader is not None:
logger.info('valid dataloader has {} iters'.format(
|