diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 1584bc76a9dd8ddff9d05a8cb693bcbd2e09fcde..b6a299ba4009ba73afc9673fe008b97ca139c57b 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -1,10 +1,11 @@
+repos:
- repo: https://github.com/PaddlePaddle/mirrors-yapf.git
- sha: 0d79c0c469bab64f7229c9aca2b1186ef47f0e37
+ rev: 0d79c0c469bab64f7229c9aca2b1186ef47f0e37
hooks:
- id: yapf
files: \.py$
- repo: https://github.com/pre-commit/pre-commit-hooks
- sha: a11d9314b22d8f8c7556443875b731ef05965464
+ rev: a11d9314b22d8f8c7556443875b731ef05965464
hooks:
- id: check-merge-conflict
- id: check-symlinks
@@ -15,7 +16,7 @@
- id: trailing-whitespace
files: \.md$
- repo: https://github.com/Lucas-C/pre-commit-hooks
- sha: v1.0.1
+ rev: v1.0.1
hooks:
- id: forbid-crlf
files: \.md$
diff --git a/configs/rec/rec_d28_can.yml b/configs/rec/rec_d28_can.yml
new file mode 100644
index 0000000000000000000000000000000000000000..aeaccb6b0d245cc4be479d17597d0029647db574
--- /dev/null
+++ b/configs/rec/rec_d28_can.yml
@@ -0,0 +1,114 @@
+Global:
+ use_gpu: True
+ epoch_num: 240
+ log_smooth_window: 20
+ print_batch_step: 10
+ save_model_dir: ./output/rec/can/
+ save_epoch_step: 1
+ # evaluation is run every 1105 iterations
+ eval_batch_step: [0, 1105]
+ cal_metric_during_train: True
+ pretrained_model: ./output/rec/can/CAN
+ checkpoints: ./output/rec/can/CAN
+ save_inference_dir: ./inference/rec_d28_can/
+ use_visualdl: False
+  infer_img: doc/imgs_hme/hme_01.jpg
+ # for data or label process
+ character_dict_path: ppocr/utils/dict/latex_symbol_dict.txt
+ max_text_length: 36
+ infer_mode: False
+ use_space_char: False
+ save_res_path: ./output/rec/predicts_can.txt
+
+Optimizer:
+ name: Momentum
+ momentum: 0.9
+ clip_norm_global: 100.0
+ lr:
+ name: TwoStepCosine
+ learning_rate: 0.01
+ warmup_epoch: 1
+ weight_decay: 0.0001
+
+Architecture:
+ model_type: rec
+ algorithm: CAN
+ in_channels: 1
+ Transform:
+ Backbone:
+ name: DenseNet
+ growthRate: 24
+ reduction: 0.5
+ bottleneck: True
+ use_dropout: True
+ input_channel: 1
+
+ Head:
+ name: CANHead
+ in_channel: 684
+ out_channel: 111
+ max_text_length: 36
+ ratio: 16
+ attdecoder:
+ is_train: True
+ input_size: 256
+ hidden_size: 256
+ encoder_out_channel: 684
+ dropout: True
+ dropout_ratio: 0.5
+ word_num: 111
+ counting_decoder_out_channel: 111
+ attention:
+ attention_dim: 512
+ word_conv_kernel: 1
+
+Loss:
+ name: CANLoss
+
+PostProcess:
+ name: SeqLabelDecode
+ character: 111
+
+Metric:
+ name: CANMetric
+ main_indicator: exp_rate
+
+Train:
+ dataset:
+ name: HMERDataSet
+ data_dir: ./train_data/CROHME/training/images/
+ transforms:
+ - DecodeImage:
+ channel_first: False
+ - GrayImageChannelFormat:
+ normalize: True
+ inverse: True
+ - KeepKeys:
+ keep_keys: ['image', 'label']
+ label_file_list: ["./train_data/CROHME/training/labels.json"]
+ loader:
+ shuffle: True
+ batch_size_per_card: 2
+ drop_last: True
+ num_workers: 1
+ collate_fn: DyMaskCollator
+
+Eval:
+ dataset:
+ name: HMERDataSet
+ data_dir: ./train_data/CROHME/evaluation/images/
+ transforms:
+ - DecodeImage:
+ channel_first: False
+ - GrayImageChannelFormat:
+ normalize: True
+ inverse: True
+ - KeepKeys:
+ keep_keys: ['image', 'label']
+ label_file_list: ["./train_data/CROHME/evaluation/labels.json"]
+ loader:
+ shuffle: False
+ drop_last: False
+ batch_size_per_card: 1
+ num_workers: 4
+ collate_fn: DyMaskCollator
diff --git a/doc/doc_ch/algorithm_rec_can.md b/doc/doc_ch/algorithm_rec_can.md
new file mode 100644
index 0000000000000000000000000000000000000000..9585dae0cf9ecc215b4ff2f6d656345418e197d4
--- /dev/null
+++ b/doc/doc_ch/algorithm_rec_can.md
@@ -0,0 +1,170 @@
+# Handwritten Mathematical Expression Recognition Algorithm - CAN
+
+- [1. Algorithm Introduction](#1)
+- [2. Environment Setup](#2)
+- [3. Model Training, Evaluation, Prediction](#3)
+    - [3.1 Training](#3-1)
+    - [3.2 Evaluation](#3-2)
+    - [3.3 Prediction](#3-3)
+- [4. Inference and Deployment](#4)
+    - [4.1 Python Inference](#4-1)
+    - [4.2 C++ Inference](#4-2)
+    - [4.3 Serving Deployment](#4-3)
+    - [4.4 More Inference Deployment Options](#4-4)
+- [5. FAQ](#5)
+
+
+## 1. Algorithm Introduction
+
+Paper:
+> [When Counting Meets HMER: Counting-Aware Network for Handwritten Mathematical Expression Recognition](https://arxiv.org/abs/2207.11463)
+> Bohan Li, Ye Yuan, Dingkang Liang, Xiao Liu, Zhilong Ji, Jinfeng Bai, Wenyu Liu, Xiang Bai
+> ECCV, 2022
+
+
+
+`CAN` is trained on the CROHME handwritten mathematical expression dataset; its accuracy on the corresponding test set is as follows:
+
+|Model|Backbone|Config|ExpRate|Download link|
+| ----- | ----- | ----- | ----- | ----- |
+|CAN|DenseNet|[rec_d28_can.yml](../../configs/rec/rec_d28_can.yml)|51.72|[trained model](https://paddleocr.bj.bcebos.com/rec_d28_can_train.tar)|
+
+
+## 2. Environment Setup
+Please refer to [Environment Preparation](./environment.md) to configure the PaddleOCR environment, and to [Project Clone](./clone.md) to clone the project code.
+
+
+
+## 3. Model Training, Evaluation, Prediction
+
+
+### 3.1 Training
+
+Please refer to the [text recognition training tutorial](./recognition.md). PaddleOCR modularizes the code; to train the `CAN` recognition model you only need to **switch the configuration file** to the `CAN` [configuration file](../../configs/rec/rec_d28_can.yml).
+
+#### Start Training
+
+
+Specifically, after the data preparation is completed, training can be started with the following commands:
+```shell
+# Single-GPU training (long training period, not recommended)
+python3 tools/train.py -c configs/rec/rec_d28_can.yml
+
+# Multi-GPU training, specify GPU ids through the --gpus parameter
+python3 -m paddle.distributed.launch --gpus '0,1,2,3' tools/train.py -c configs/rec/rec_d28_can.yml
+```
+
+**Note:**
+- The provided `CROHME dataset` stores handwritten expressions as white text on a black background. If your own dataset is stored the other way around, i.e. black text on a white background, make the following change when training:
+```shell
+python3 tools/train.py -c configs/rec/rec_d28_can.yml \
+  -o Train.dataset.transforms.GrayImageChannelFormat.inverse=False
+```
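+
+For reference on data preparation: the `HMERDataSet` added in this PR reads labels from a JSON file that maps each image stem (the file name without `.jpg`) to a space-separated string of symbol indices from `latex_symbol_dict.txt`. A minimal, purely illustrative sketch of such a `labels.json` (hypothetical sample, not shipped data):
+
+```python
+# Hypothetical example: each value is a space-separated index string.
+labels = {
+    "train_0001": "12 44 7 0",  # indices into latex_symbol_dict.txt (0 = eos)
+}
+```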
+
+
+### 3.2 Evaluation
+
+The trained [model file](#model) can be downloaded and evaluated with the following command:
+
+```shell
+# Set the pretrained_model path to a local path.
+python3 -m paddle.distributed.launch --gpus '0' tools/eval.py -c configs/rec/rec_d28_can.yml -o Global.pretrained_model=./rec_d28_can_train/best_accuracy
+```
+
+
+### 3.3 Prediction
+
+Use the following command for single-image prediction:
+```shell
+# Set the pretrained_model path to a local path.
+python3 tools/infer_rec.py -c configs/rec/rec_d28_can.yml -o Architecture.Head.attdecoder.is_train=False Global.infer_img='./doc/imgs_hme/hme_01.jpg' Global.pretrained_model=./rec_d28_can_train/best_accuracy
+
+# To predict all images in a folder, set infer_img to the folder path, e.g. Global.infer_img='./doc/imgs_hme/'.
+```
+
+
+
+## 4. Inference and Deployment
+
+
+### 4.1 Python Inference
+First, convert the best model obtained from training into an inference model. Taking the fully trained model as an example ([model download link](https://paddleocr.bj.bcebos.com/rec_d28_can_train.tar)), the conversion command is:
+
+```shell
+# Set the pretrained_model path to a local path.
+python3 tools/export_model.py -c configs/rec/rec_d28_can.yml -o Global.save_inference_dir=./inference/rec_d28_can/ Architecture.Head.attdecoder.is_train=False
+
+# The current static-graph model outputs at most 36 tokens by default. If you need to predict longer sequences, set a suitable output length when exporting, e.g. Architecture.Head.max_text_length=72
+```
+**Note:**
+- If you trained the model on your own dataset and adjusted the dictionary file, check that `character_dict_path` in the configuration file points to the dictionary you actually need.
+- If you modified the input size during training, update the `infer_shape` corresponding to CAN in `tools/export_model.py`.
+
+After a successful conversion, there are three files in the directory:
+```
+/inference/rec_d28_can/
+    ├── inference.pdiparams         # parameter file of the recognition inference model
+    ├── inference.pdiparams.info    # parameter info of the recognition inference model, can be ignored
+    └── inference.pdmodel           # program file of the recognition inference model
+```
+
+Run the following command for model inference:
+
+```shell
+python3 tools/infer/predict_rec.py --image_dir="./doc/imgs_hme/hme_01.jpg" --rec_algorithm="CAN" --rec_batch_num=1 --rec_model_dir="./inference/rec_d28_can/" --rec_char_dict_path="./ppocr/utils/dict/latex_symbol_dict.txt"
+
+# To predict all images in a folder, set image_dir to the folder path, e.g. --image_dir='./doc/imgs_hme/'.
+
+# To predict on images with black text on a white background, set --rec_image_inverse=False
+```
+
+![Sample test image](../imgs_hme/hme_00.jpg)
+
+After running the command, the prediction result (the recognized text) of the image above is printed to the screen, for example:
+```shell
+Predicts of ./doc/imgs_hme/hme_00.jpg:['x _ { k } x x _ { k } + y _ { k } y x _ { k }', []]
+```
+
+
+**Note**:
+
+- The input image must be **white text on a black background**, i.e. the handwritten expression is white and the background is black.
+- The `rec_char_dict_path` parameter specifies the dictionary at inference time; if you modified the dictionary, point this parameter to your own dictionary file.
+- If you modified the preprocessing, replace the CAN preprocessing in `tools/infer/predict_rec.py` with your own method; a simplified sketch of the default preprocessing is shown below.
+
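+The default CAN preprocessing added in this PR (`norm_img_can` in `tools/infer/predict_rec.py`) roughly performs the steps below. This is a simplified sketch; the minimum input size (132 x 519) is taken from the `--rec_image_shape` value used in the English documentation:
+
+```python
+import cv2
+import numpy as np
+
+
+def can_preprocess(img, inverse=True, min_h=132, min_w=519):
+    # grayscale -> optional inversion -> pad to the minimum size -> [0, 1]
+    img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
+    if inverse:  # CAN expects white symbols on a black background
+        img = 255 - img
+    h, w = img.shape
+    pad_h, pad_w = max(min_h - h, 0), max(min_w - w, 0)
+    img = np.pad(img, ((0, pad_h), (0, pad_w)), 'constant', constant_values=255)
+    return (img[np.newaxis, :] / 255.0).astype('float32')  # shape (1, H, W)
+```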
+
+
+### 4.2 C++ Inference Deployment
+
+Not supported yet, because the C++ preprocessing and postprocessing do not support CAN.
+
+
+### 4.3 Serving Deployment
+
+Not supported yet
+
+
+### 4.4 More Inference Deployment Options
+
+Not supported yet
+
+
+## 5. FAQ
+
+1. The CROHME dataset comes from the [official CAN repo](https://github.com/LBH1024/CAN).
+
+## Citation
+
+```bibtex
+@misc{https://doi.org/10.48550/arxiv.2207.11463,
+ doi = {10.48550/ARXIV.2207.11463},
+ url = {https://arxiv.org/abs/2207.11463},
+ author = {Li, Bohan and Yuan, Ye and Liang, Dingkang and Liu, Xiao and Ji, Zhilong and Bai, Jinfeng and Liu, Wenyu and Bai, Xiang},
+ keywords = {Computer Vision and Pattern Recognition (cs.CV), Artificial Intelligence (cs.AI), FOS: Computer and information sciences, FOS: Computer and information sciences},
+ title = {When Counting Meets HMER: Counting-Aware Network for Handwritten Mathematical Expression Recognition},
+ publisher = {arXiv},
+ year = {2022},
+ copyright = {arXiv.org perpetual, non-exclusive license}
+}
+```
diff --git a/doc/doc_en/algorithm_rec_can_en.md b/doc/doc_en/algorithm_rec_can_en.md
new file mode 100644
index 0000000000000000000000000000000000000000..f2bc645af0d2bea629e6dfe78f6c49f51a3af73d
--- /dev/null
+++ b/doc/doc_en/algorithm_rec_can_en.md
@@ -0,0 +1,115 @@
+# CAN
+
+- [1. Introduction](#1)
+- [2. Environment](#2)
+- [3. Model Training / Evaluation / Prediction](#3)
+ - [3.1 Training](#3-1)
+ - [3.2 Evaluation](#3-2)
+ - [3.3 Prediction](#3-3)
+- [4. Inference and Deployment](#4)
+ - [4.1 Python Inference](#4-1)
+ - [4.2 C++ Inference](#4-2)
+ - [4.3 Serving](#4-3)
+ - [4.4 More](#4-4)
+- [5. FAQ](#5)
+
+
+## 1. Introduction
+
+Paper:
+> [When Counting Meets HMER: Counting-Aware Network for Handwritten Mathematical Expression Recognition](https://arxiv.org/abs/2207.11463)
+> Bohan Li, Ye Yuan, Dingkang Liang, Xiao Liu, Zhilong Ji, Jinfeng Bai, Wenyu Liu, Xiang Bai
+> ECCV, 2022
+
+Using the CROHME handwritten mathematical expression recognition dataset for training, and evaluating on its test set, the algorithm reproduction accuracy is as follows:
+
+|Model|Backbone|config|exprate|Download link|
+| --- | --- | --- | --- | --- |
+|CAN|DenseNet|[rec_d28_can.yml](../../configs/rec/rec_d28_can.yml)|51.72|coming soon|
+
+
+## 2. Environment
+Please refer to ["Environment Preparation"](./environment_en.md) to configure the PaddleOCR environment, and refer to ["Project Clone"](./clone_en.md) to clone the project code.
+
+
+
+## 3. Model Training / Evaluation / Prediction
+
+Please refer to [Text Recognition Tutorial](./recognition_en.md). PaddleOCR modularizes the code, and training different recognition models only requires **changing the configuration file**.
+
+Training:
+
+Specifically, after the data preparation is completed, the training can be started. The training command is as follows:
+
+```
+#Single GPU training (long training period, not recommended)
+python3 tools/train.py -c configs/rec/rec_d28_can.yml
+
+#Multi GPU training, specify the gpu number through the --gpus parameter
+python3 -m paddle.distributed.launch --gpus '0,1,2,3' tools/train.py -c configs/rec/rec_d28_can.yml
+```
+
+Evaluation:
+
+```
+# GPU evaluation
+python3 -m paddle.distributed.launch --gpus '0' tools/eval.py -c configs/rec/rec_d28_can.yml -o Global.pretrained_model=./rec_d28_can_train/best_accuracy
+```
+
+Prediction:
+
+```
+# The configuration file used for prediction must match the training
+python3 tools/infer_rec.py -c configs/rec/rec_d28_can.yml -o Architecture.Head.attdecoder.is_train=False Global.infer_img='./doc/imgs_hme/hme_01.jpg' Global.pretrained_model=./rec_d28_can_train/best_accuracy
+```
+
+
+## 4. Inference and Deployment
+
+
+### 4.1 Python Inference
+First, the model saved during the CAN text recognition training process is converted into an inference model. You can use the following command to convert it:
+
+```
+python3 tools/export_model.py -c configs/rec/rec_d28_can.yml -o Global.save_inference_dir=./inference/rec_d28_can/ Architecture.Head.attdecoder.is_train=False
+```
+
+For CAN text recognition model inference, the following command can be executed:
+
+```
+python3 tools/infer/predict_rec.py --image_dir="./doc/imgs_hme/hme_01.jpg" --rec_algorithm="CAN" --rec_batch_num=1 --rec_model_dir="./inference/rec_d28_can/" --rec_image_shape="1, 132, 519" --rec_char_dict_path="./ppocr/utils/dict/latex_symbol_dict.txt"
+```
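+
+For reference, here is a sketch of the three inputs the exported CAN model consumes, based on the `InputSpec` added in `tools/export_model.py` and the batch construction in `tools/infer/predict_rec.py` (the concrete shapes are illustrative):
+
+```python
+import numpy as np
+
+# Illustrative only: an image, its pixel mask, and a dummy label tensor
+# whose length equals max_text_length (36 in the shipped config).
+image = np.zeros((1, 1, 132, 519), dtype="float32")  # normalized gray image
+image_mask = np.ones_like(image)                     # 1 marks valid pixels
+word_label = np.ones((1, 36), dtype="int64")         # ones at inference time
+```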
+
+
+### 4.2 C++ Inference
+
+Not supported
+
+
+### 4.3 Serving
+
+Not supported
+
+
+### 4.4 More
+
+Not supported
+
+
+## 5. FAQ
+
+
+## Citation
+
+```bibtex
+@misc{https://doi.org/10.48550/arxiv.2207.11463,
+ doi = {10.48550/ARXIV.2207.11463},
+ url = {https://arxiv.org/abs/2207.11463},
+ author = {Li, Bohan and Yuan, Ye and Liang, Dingkang and Liu, Xiao and Ji, Zhilong and Bai, Jinfeng and Liu, Wenyu and Bai, Xiang},
+ keywords = {Computer Vision and Pattern Recognition (cs.CV), Artificial Intelligence (cs.AI), FOS: Computer and information sciences, FOS: Computer and information sciences},
+ title = {When Counting Meets HMER: Counting-Aware Network for Handwritten Mathematical Expression Recognition},
+ publisher = {arXiv},
+ year = {2022},
+ copyright = {arXiv.org perpetual, non-exclusive license}
+}
+```
diff --git a/doc/imgs_hme/hme_00.jpg b/doc/imgs_hme/hme_00.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..66ff27db266b5d4fa05d8acd95ba881bb8a1aec0
Binary files /dev/null and b/doc/imgs_hme/hme_00.jpg differ
diff --git a/doc/imgs_hme/hme_01.jpg b/doc/imgs_hme/hme_01.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..68b7f09fc2f330ee523ded27a14486b3c92763cb
Binary files /dev/null and b/doc/imgs_hme/hme_01.jpg differ
diff --git a/doc/imgs_hme/hme_02.jpg b/doc/imgs_hme/hme_02.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..ecc760f5382bfe3d94de6141379f6a5a196e8430
Binary files /dev/null and b/doc/imgs_hme/hme_02.jpg differ
diff --git a/ppocr/data/__init__.py b/ppocr/data/__init__.py
index b602a346dbe4b0d45af287f25f05ead0f62daf44..1f3de63de72f44a7daf8d641b2d5c6bc5568df8f 100644
--- a/ppocr/data/__init__.py
+++ b/ppocr/data/__init__.py
@@ -37,6 +37,7 @@ from ppocr.data.simple_dataset import SimpleDataSet
from ppocr.data.lmdb_dataset import LMDBDataSet, LMDBDataSetSR
from ppocr.data.pgnet_dataset import PGDataSet
from ppocr.data.pubtab_dataset import PubTabDataSet
+from ppocr.data.hmer_dataset import HMERDataSet
__all__ = ['build_dataloader', 'transform', 'create_operators']
@@ -55,7 +56,7 @@ def build_dataloader(config, mode, device, logger, seed=None):
support_dict = [
'SimpleDataSet', 'LMDBDataSet', 'PGDataSet', 'PubTabDataSet',
- 'LMDBDataSetSR'
+ 'LMDBDataSetSR', 'HMERDataSet'
]
module_name = config[mode]['dataset']['name']
assert module_name in support_dict, Exception(
diff --git a/ppocr/data/collate_fn.py b/ppocr/data/collate_fn.py
index 0da6060f042a0e60cdf211d8bc13aede32d5930a..fec1e895ff033cd991ed3a927ad7e158989bab45 100644
--- a/ppocr/data/collate_fn.py
+++ b/ppocr/data/collate_fn.py
@@ -70,3 +70,49 @@ class SSLRotateCollate(object):
def __call__(self, batch):
output = [np.concatenate(d, axis=0) for d in zip(*batch)]
return output
+
+
+class DyMaskCollator(object):
+    """
+    Collate variable-sized samples into a padded batch:
+        image       [batch_size, channel, maxHinbatch, maxWinbatch]
+        image_mask  [batch_size, 1, maxHinbatch, maxWinbatch]
+        label       [batch_size, maxLabelLen]
+        label_mask  [batch_size, maxLabelLen]
+    """
+
+ def __call__(self, batch):
+ max_width, max_height, max_length = 0, 0, 0
+ bs, channel = len(batch), batch[0][0].shape[0]
+ proper_items = []
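+        # skip samples whose height x current max width (or width x current
+        # max height) would exceed a 1600 x 320 area budget after padding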
+ for item in batch:
+ if item[0].shape[1] * max_width > 1600 * 320 or item[0].shape[
+ 2] * max_height > 1600 * 320:
+ continue
+ max_height = item[0].shape[1] if item[0].shape[
+ 1] > max_height else max_height
+ max_width = item[0].shape[2] if item[0].shape[
+ 2] > max_width else max_width
+ max_length = item[1].shape[0] if item[1].shape[
+ 0] > max_length else max_length
+ proper_items.append(item)
+
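+        # zero-pad every sample to the batch maxima; the masks mark which
+        # pixels and label positions are real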
+ images, image_masks = np.zeros(
+ (len(proper_items), channel, max_height, max_width),
+ dtype='float32'), np.zeros(
+ (len(proper_items), 1, max_height, max_width), dtype='float32')
+ labels, label_masks = np.zeros(
+ (len(proper_items), max_length), dtype='int64'), np.zeros(
+ (len(proper_items), max_length), dtype='int64')
+
+ for i in range(len(proper_items)):
+ _, h, w = proper_items[i][0].shape
+ images[i][:, :h, :w] = proper_items[i][0]
+ image_masks[i][:, :h, :w] = 1
+ l = proper_items[i][1].shape[0]
+ labels[i][:l] = proper_items[i][1]
+ label_masks[i][:l] = 1
+
+ return images, image_masks, labels, label_masks
diff --git a/ppocr/data/hmer_dataset.py b/ppocr/data/hmer_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..d5d92f264b2604d68929bf08e8514c1a87b9198b
--- /dev/null
+++ b/ppocr/data/hmer_dataset.py
@@ -0,0 +1,99 @@
+# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import os, json, random, traceback
+import numpy as np
+
+from PIL import Image
+from paddle.io import Dataset
+
+from .imaug import transform, create_operators
+
+
+class HMERDataSet(Dataset):
+ def __init__(self, config, mode, logger, seed=None):
+ super(HMERDataSet, self).__init__()
+
+ self.logger = logger
+ self.seed = seed
+ self.mode = mode
+
+        global_config = config['Global']
+        dataset_config = config[mode]['dataset']
+        loader_config = config[mode]['loader']
+        self.data_dir = dataset_config['data_dir']
+        self.do_shuffle = loader_config['shuffle']
+
+        label_file_list = dataset_config['label_file_list']
+        data_source_num = len(label_file_list)
+        ratio_list = dataset_config.get("ratio_list", [1.0])
+
+        self.data_lines, self.labels = self.get_image_info_list(label_file_list,
+                                                                ratio_list)
+        self.data_idx_order_list = list(range(len(self.data_lines)))
+        if self.mode == "train" and self.do_shuffle:
+            self.shuffle_data_random()
+
+ if isinstance(ratio_list, (float, int)):
+ ratio_list = [float(ratio_list)] * int(data_source_num)
+
+ assert len(
+ ratio_list
+ ) == data_source_num, "The length of ratio_list should be the same as the file_list."
+
+ self.ops = create_operators(dataset_config['transforms'], global_config)
+ self.need_reset = True in [x < 1 for x in ratio_list]
+
+ def get_image_info_list(self, file_list, ratio_list):
+ if isinstance(file_list, str):
+ file_list = [file_list]
+ labels = {}
+        for file in file_list:
+            with open(file, "r") as f:
+                labels.update(json.load(f))
+        data_lines = list(labels.keys())
+ return data_lines, labels
+
+ def shuffle_data_random(self):
+ random.seed(self.seed)
+ random.shuffle(self.data_lines)
+ return
+
+ def __len__(self):
+ return len(self.data_idx_order_list)
+
+ def __getitem__(self, idx):
+ file_idx = self.data_idx_order_list[idx]
+ data_name = self.data_lines[file_idx]
+ try:
+ file_name = data_name + '.jpg'
+ img_path = os.path.join(self.data_dir, file_name)
+ if not os.path.exists(img_path):
+ raise Exception("{} does not exist!".format(img_path))
+ with open(img_path, 'rb') as f:
+ img = f.read()
+
+ label = self.labels.get(data_name).split()
+ label = np.array([int(item) for item in label])
+
+ data = {'image': img, 'label': label}
+ outs = transform(data, self.ops)
+        except Exception:
+ self.logger.error(
+ "When parsing line {}, error happened with msg: {}".format(
+ file_name, traceback.format_exc()))
+ outs = None
+ if outs is None:
+ # during evaluation, we should fix the idx to get same results for many times of evaluation.
+ rnd_idx = np.random.randint(self.__len__())
+ return self.__getitem__(rnd_idx)
+ return outs
diff --git a/ppocr/data/imaug/__init__.py b/ppocr/data/imaug/__init__.py
index 93d97446d44070b9c10064fbe10b0b5e05628a6a..a64092286f4cd36ddac5503fb00ebaa43b083a83 100644
--- a/ppocr/data/imaug/__init__.py
+++ b/ppocr/data/imaug/__init__.py
@@ -27,7 +27,7 @@ from .make_pse_gt import MakePseGt
from .rec_img_aug import BaseDataAugmentation, RecAug, RecConAug, RecResizeImg, ClsResizeImg, \
SRNRecResizeImg, GrayRecResizeImg, SARRecResizeImg, PRENResizeImg, \
ABINetRecResizeImg, SVTRRecResizeImg, ABINetRecAug, VLRecResizeImg, SPINRecResizeImg, RobustScannerRecResizeImg, \
- RFLRecResizeImg
+ RFLRecResizeImg, GrayImageChannelFormat
from .ssl_img_aug import SSLRotateResize
from .randaugment import RandAugment
from .copy_paste import CopyPaste
diff --git a/ppocr/data/imaug/rec_img_aug.py b/ppocr/data/imaug/rec_img_aug.py
index e22153bdeab06565feed79715633172a275aecc7..bc7fbc6046c0c773f2f666eed8a661f1f6d8f89a 100644
--- a/ppocr/data/imaug/rec_img_aug.py
+++ b/ppocr/data/imaug/rec_img_aug.py
@@ -465,6 +465,36 @@ class RobustScannerRecResizeImg(object):
return data
+class GrayImageChannelFormat(object):
+    """
+    Convert an image to a single gray channel: (h, w, 3) BGR -> (1, h, w)
+    Args:
+        normalize: True/False
+            when True, rescale the dynamic range [0, 255] -> [0, 1]
+        inverse: invert the gray image (swap white-on-black / black-on-white)
+    """
+
+ def __init__(self, normalize=True, inverse=False, **kwargs):
+ self.normalize = normalize
+ self.inverse = inverse
+
+ def __call__(self, data):
+ img = data['image']
+ img_single_channel = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
+ img_single_channel = np.expand_dims(img_single_channel, 0)
+
+ if self.normalize:
+ img_single_channel = img_single_channel / 255.0
+
+ if self.inverse:
+ data['image'] = np.abs(img_single_channel - 1).astype('float32')
+ else:
+ data['image'] = img_single_channel.astype('float32')
+
+ data['src_image'] = img
+ return data
+
+
def resize_norm_img_sar(img, image_shape, width_downsample_ratio=0.25):
imgC, imgH, imgW_min, imgW_max = image_shape
h = img.shape[0]
diff --git a/ppocr/losses/__init__.py b/ppocr/losses/__init__.py
index 6abaa408b3f6995a0b4c377206e8a1551b48c56b..6a34dd1c87a92db3618bfaeed6aeee3cb29b261a 100755
--- a/ppocr/losses/__init__.py
+++ b/ppocr/losses/__init__.py
@@ -40,6 +40,7 @@ from .rec_multi_loss import MultiLoss
from .rec_vl_loss import VLLoss
from .rec_spin_att_loss import SPINAttentionLoss
from .rec_rfl_loss import RFLLoss
+from .rec_can_loss import CANLoss
# cls loss
from .cls_loss import ClsLoss
@@ -71,7 +72,7 @@ def build_loss(config):
'CELoss', 'TableAttentionLoss', 'SARLoss', 'AsterLoss', 'SDMGRLoss',
'VQASerTokenLayoutLMLoss', 'LossFromOutput', 'PRENLoss', 'MultiLoss',
'TableMasterLoss', 'SPINAttentionLoss', 'VLLoss', 'StrokeFocusLoss',
- 'SLALoss', 'CTLoss', 'RFLLoss', 'DRRGLoss'
+ 'SLALoss', 'CTLoss', 'RFLLoss', 'DRRGLoss', 'CANLoss'
]
config = copy.deepcopy(config)
module_name = config.pop('name')
diff --git a/ppocr/losses/rec_can_loss.py b/ppocr/losses/rec_can_loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..a6c655e0e64ac17aabc853b7e70d31ad0cf059a6
--- /dev/null
+++ b/ppocr/losses/rec_can_loss.py
@@ -0,0 +1,61 @@
+import paddle
+import paddle.nn as nn
+import numpy as np
+
+
+class CANLoss(nn.Layer):
+    '''
+    CANLoss consists of two parts:
+        word_average_loss: average accuracy of the symbols
+        counting_loss: counting loss of every symbol
+    '''
+
+ def __init__(self):
+ super(CANLoss, self).__init__()
+
+ self.use_label_mask = False
+ self.out_channel = 111
+ self.cross = nn.CrossEntropyLoss(
+ reduction='none') if self.use_label_mask else nn.CrossEntropyLoss()
+ self.counting_loss = nn.SmoothL1Loss(reduction='mean')
+ self.ratio = 16
+
+ def forward(self, preds, batch):
+ word_probs = preds[0]
+ counting_preds = preds[1]
+ counting_preds1 = preds[2]
+ counting_preds2 = preds[3]
+ labels = batch[2]
+ labels_mask = batch[3]
+ counting_labels = gen_counting_label(labels, self.out_channel, True)
+ counting_loss = self.counting_loss(counting_preds1, counting_labels) + self.counting_loss(counting_preds2, counting_labels) \
+ + self.counting_loss(counting_preds, counting_labels)
+
+ word_loss = self.cross(
+ paddle.reshape(word_probs, [-1, word_probs.shape[-1]]),
+ paddle.reshape(labels, [-1]))
+ word_average_loss = paddle.sum(
+ paddle.reshape(word_loss * labels_mask, [-1])) / (
+ paddle.sum(labels_mask) + 1e-10
+ ) if self.use_label_mask else word_loss
+ loss = word_average_loss + counting_loss
+ return {'loss': loss}
+
+
+def gen_counting_label(labels, channel, tag):
+ b, t = labels.shape
+ counting_labels = np.zeros([b, channel])
+
+ if tag:
+ ignore = [0, 1, 107, 108, 109, 110]
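+        # 0/1 are eos/sos; 107-110 are "{", "}", "^" and "_" in latex_symbol_dict.txt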
+ else:
+ ignore = []
+ for i in range(b):
+ for j in range(t):
+ k = labels[i][j]
+ if k in ignore:
+ continue
+ else:
+ counting_labels[i][k] += 1
+ counting_labels = paddle.to_tensor(counting_labels, dtype='float32')
+ return counting_labels
diff --git a/ppocr/metrics/__init__.py b/ppocr/metrics/__init__.py
index 20aea8b5995a49306d427bc427048c9df8d0923d..5e840a194adc2683e92c308f232dc869df34de8e 100644
--- a/ppocr/metrics/__init__.py
+++ b/ppocr/metrics/__init__.py
@@ -22,7 +22,7 @@ import copy
__all__ = ["build_metric"]
from .det_metric import DetMetric, DetFCEMetric
-from .rec_metric import RecMetric, CNTMetric
+from .rec_metric import RecMetric, CNTMetric, CANMetric
from .cls_metric import ClsMetric
from .e2e_metric import E2EMetric
from .distillation_metric import DistillationMetric
@@ -38,7 +38,7 @@ def build_metric(config):
support_dict = [
"DetMetric", "DetFCEMetric", "RecMetric", "ClsMetric", "E2EMetric",
"DistillationMetric", "TableMetric", 'KIEMetric', 'VQASerTokenMetric',
- 'VQAReTokenMetric', 'SRMetric', 'CTMetric', 'CNTMetric'
+ 'VQAReTokenMetric', 'SRMetric', 'CTMetric', 'CNTMetric', 'CANMetric'
]
config = copy.deepcopy(config)
diff --git a/ppocr/metrics/rec_metric.py b/ppocr/metrics/rec_metric.py
index 4758e71d0930261044841a6a820308a04391fc0b..305b913c72da5842b6654f1fc9b27e6e2b46b436 100644
--- a/ppocr/metrics/rec_metric.py
+++ b/ppocr/metrics/rec_metric.py
@@ -13,6 +13,9 @@
# limitations under the License.
from rapidfuzz.distance import Levenshtein
+from difflib import SequenceMatcher
+
+import numpy as np
import string
@@ -106,3 +109,71 @@ class CNTMetric(object):
def reset(self):
self.correct_num = 0
self.all_num = 0
+
+
+class CANMetric(object):
+ def __init__(self, main_indicator='exp_rate', **kwargs):
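+        # exp_rate: fraction of expressions matched exactly (ExpRate);
+        # word_rate: mean per-symbol match ratio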
+ self.main_indicator = main_indicator
+ self.word_right = []
+ self.exp_right = []
+ self.word_total_length = 0
+ self.exp_total_num = 0
+ self.word_rate = 0
+ self.exp_rate = 0
+ self.reset()
+ self.epoch_reset()
+
+    def __call__(self, preds, batch, **kwargs):
+        # epoch_reset=True clears the statistics accumulated across batches
+        epoch_reset = kwargs.get('epoch_reset', False)
+        if epoch_reset:
+            self.epoch_reset()
+ word_probs = preds
+ word_label, word_label_mask = batch
+ line_right = 0
+ if word_probs is not None:
+ word_pred = word_probs.argmax(2)
+ word_pred = word_pred.cpu().detach().numpy()
+ word_scores = [
+ SequenceMatcher(
+ None,
+ s1[:int(np.sum(s3))],
+ s2[:int(np.sum(s3))],
+ autojunk=False).ratio() * (
+ len(s1[:int(np.sum(s3))]) + len(s2[:int(np.sum(s3))])) /
+ len(s1[:int(np.sum(s3))]) / 2
+ for s1, s2, s3 in zip(word_label, word_pred, word_label_mask)
+ ]
+ batch_size = len(word_scores)
+ for i in range(batch_size):
+ if word_scores[i] == 1:
+ line_right += 1
+ self.word_rate = np.mean(word_scores) #float
+ self.exp_rate = line_right / batch_size #float
+ exp_length, word_length = word_label.shape[:2]
+ self.word_right.append(self.word_rate * word_length)
+ self.exp_right.append(self.exp_rate * exp_length)
+ self.word_total_length = self.word_total_length + word_length
+ self.exp_total_num = self.exp_total_num + exp_length
+
+ def get_metric(self):
+ """
+ return {
+ 'word_rate': 0,
+ "exp_rate": 0,
+ }
+ """
+ cur_word_rate = sum(self.word_right) / self.word_total_length
+ cur_exp_rate = sum(self.exp_right) / self.exp_total_num
+ self.reset()
+ return {'word_rate': cur_word_rate, "exp_rate": cur_exp_rate}
+
+ def reset(self):
+ self.word_rate = 0
+ self.exp_rate = 0
+
+ def epoch_reset(self):
+ self.word_right = []
+ self.exp_right = []
+ self.word_total_length = 0
+ self.exp_total_num = 0
diff --git a/ppocr/modeling/backbones/__init__.py b/ppocr/modeling/backbones/__init__.py
index 84892fa9c7fd61838e984690b17931f367ab0585..e2c2e9c4a4ed526b36d512d824ae8a8a701c17bc 100755
--- a/ppocr/modeling/backbones/__init__.py
+++ b/ppocr/modeling/backbones/__init__.py
@@ -43,10 +43,12 @@ def build_backbone(config, model_type):
from .rec_svtrnet import SVTRNet
from .rec_vitstr import ViTSTR
from .rec_resnet_rfl import ResNetRFL
+ from .rec_densenet import DenseNet
support_dict = [
'MobileNetV1Enhance', 'MobileNetV3', 'ResNet', 'ResNetFPN', 'MTB',
'ResNet31', 'ResNet45', 'ResNet_ASTER', 'MicroNet',
- 'EfficientNetb3_PREN', 'SVTRNet', 'ViTSTR', 'ResNet32', 'ResNetRFL'
+ 'EfficientNetb3_PREN', 'SVTRNet', 'ViTSTR', 'ResNet32', 'ResNetRFL',
+ 'DenseNet'
]
elif model_type == 'e2e':
from .e2e_resnet_vd_pg import ResNet
diff --git a/ppocr/modeling/backbones/rec_densenet.py b/ppocr/modeling/backbones/rec_densenet.py
new file mode 100644
index 0000000000000000000000000000000000000000..d3391d4086857bc02bca20c1d0b9279172e6bc44
--- /dev/null
+++ b/ppocr/modeling/backbones/rec_densenet.py
@@ -0,0 +1,135 @@
+import math
+import paddle
+import paddle.nn as nn
+import paddle.nn.functional as F
+
+
+class Bottleneck(nn.Layer):
+ '''
+ ratio: 16
+ growthRate: 24
+ reduction: 0.5
+ bottleneck: True
+ use_dropout: True
+ '''
+
+ def __init__(self, nChannels, growthRate, use_dropout):
+ super(Bottleneck, self).__init__()
+ interChannels = 4 * growthRate
+ self.bn1 = nn.BatchNorm2D(interChannels)
+ self.conv1 = nn.Conv2D(
+ nChannels, interChannels, kernel_size=1,
+            bias_attr=None)  # bias_attr=None keeps Paddle's default bias
+ self.bn2 = nn.BatchNorm2D(growthRate)
+ self.conv2 = nn.Conv2D(
+ interChannels, growthRate, kernel_size=3, padding=1,
+            bias_attr=None)  # bias_attr=None keeps Paddle's default bias
+ self.use_dropout = use_dropout
+ self.dropout = nn.Dropout(p=0.2)
+
+ def forward(self, x):
+ out = F.relu(self.bn1(self.conv1(x)))
+ if self.use_dropout:
+ out = self.dropout(out)
+ out = F.relu(self.bn2(self.conv2(out)))
+ if self.use_dropout:
+ out = self.dropout(out)
+ out = paddle.concat([x, out], 1)
+ return out
+
+
+class SingleLayer(nn.Layer):
+ def __init__(self, nChannels, growthRate, use_dropout):
+ super(SingleLayer, self).__init__()
+ self.bn1 = nn.BatchNorm2D(nChannels)
+ self.conv1 = nn.Conv2D(
+ nChannels, growthRate, kernel_size=3, padding=1, bias_attr=False)
+
+ self.use_dropout = use_dropout
+ self.dropout = nn.Dropout(p=0.2)
+
+ def forward(self, x):
+ out = self.conv1(F.relu(x))
+ if self.use_dropout:
+ out = self.dropout(out)
+
+ out = paddle.concat([x, out], 1)
+ return out
+
+
+class Transition(nn.Layer):
+ def __init__(self, nChannels, out_channels, use_dropout):
+ super(Transition, self).__init__()
+ self.bn1 = nn.BatchNorm2D(out_channels)
+ self.conv1 = nn.Conv2D(
+ nChannels, out_channels, kernel_size=1, bias_attr=False)
+ self.use_dropout = use_dropout
+ self.dropout = nn.Dropout(p=0.2)
+
+ def forward(self, x):
+ out = F.relu(self.bn1(self.conv1(x)))
+ if self.use_dropout:
+ out = self.dropout(out)
+ out = F.avg_pool2d(out, 2, ceil_mode=True, exclusive=False)
+ return out
+
+
+class DenseNet(nn.Layer):
+ def __init__(self, growthRate, reduction, bottleneck, use_dropout,
+ input_channel, **kwargs):
+ super(DenseNet, self).__init__()
+ '''
+ ratio: 16
+ growthRate: 24
+ reduction: 0.5
+ '''
+ nDenseBlocks = 16
+ nChannels = 2 * growthRate
+
+ self.conv1 = nn.Conv2D(
+ input_channel,
+ nChannels,
+ kernel_size=7,
+ padding=3,
+ stride=2,
+ bias_attr=False)
+ self.dense1 = self._make_dense(nChannels, growthRate, nDenseBlocks,
+ bottleneck, use_dropout)
+ nChannels += nDenseBlocks * growthRate
+ out_channels = int(math.floor(nChannels * reduction))
+ self.trans1 = Transition(nChannels, out_channels, use_dropout)
+
+ nChannels = out_channels
+ self.dense2 = self._make_dense(nChannels, growthRate, nDenseBlocks,
+ bottleneck, use_dropout)
+ nChannels += nDenseBlocks * growthRate
+ out_channels = int(math.floor(nChannels * reduction))
+ self.trans2 = Transition(nChannels, out_channels, use_dropout)
+
+ nChannels = out_channels
+ self.dense3 = self._make_dense(nChannels, growthRate, nDenseBlocks,
+ bottleneck, use_dropout)
+ self.out_channels = out_channels
+
+ def _make_dense(self, nChannels, growthRate, nDenseBlocks, bottleneck,
+ use_dropout):
+ layers = []
+ for i in range(int(nDenseBlocks)):
+ if bottleneck:
+ layers.append(Bottleneck(nChannels, growthRate, use_dropout))
+ else:
+ layers.append(SingleLayer(nChannels, growthRate, use_dropout))
+ nChannels += growthRate
+ return nn.Sequential(*layers)
+
+ def forward(self, inputs):
+ x, x_m, y = inputs
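+        # x: padded gray images, x_m: pixel masks, y: labels; the mask and
+        # labels are passed through unchanged for the CAN head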
+ out = self.conv1(x)
+ out = F.relu(out)
+ out = F.max_pool2d(out, 2, ceil_mode=True)
+ out = self.dense1(out)
+ out = self.trans1(out)
+ out = self.dense2(out)
+ out = self.trans2(out)
+ out = self.dense3(out)
+ return out, x_m, y
diff --git a/ppocr/modeling/heads/__init__.py b/ppocr/modeling/heads/__init__.py
index 63002140c5be4bd7e32b56995c6410ecc8a0fa36..fdf5a8a96d587c66ac9b26ff4e1264ab9c8173f3 100755
--- a/ppocr/modeling/heads/__init__.py
+++ b/ppocr/modeling/heads/__init__.py
@@ -40,6 +40,7 @@ def build_head(config):
from .rec_robustscanner_head import RobustScannerHead
from .rec_visionlan_head import VLHead
from .rec_rfl_head import RFLHead
+ from .rec_can_head import CANHead
# cls head
from .cls_head import ClsHead
@@ -56,7 +57,7 @@ def build_head(config):
'TableAttentionHead', 'SARHead', 'AsterHead', 'SDMGRHead', 'PRENHead',
'MultiHead', 'ABINetHead', 'TableMasterHead', 'SPINAttentionHead',
'VLHead', 'SLAHead', 'RobustScannerHead', 'CT_Head', 'RFLHead',
- 'DRRGHead'
+ 'DRRGHead', 'CANHead'
]
#table head
diff --git a/ppocr/modeling/heads/rec_can_head.py b/ppocr/modeling/heads/rec_can_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..afd78ee9d277dde7838b96938f4b2221b690c5d2
--- /dev/null
+++ b/ppocr/modeling/heads/rec_can_head.py
@@ -0,0 +1,294 @@
+import math
+
+import paddle
+import paddle.nn as nn
+'''
+Counting Module
+'''
+
+
+class ChannelAtt(nn.Layer):
+ def __init__(self, channel, reduction):
+ super(ChannelAtt, self).__init__()
+ self.avg_pool = nn.AdaptiveAvgPool2D(1)
+
+ self.fc = nn.Sequential(
+ nn.Linear(channel, channel // reduction),
+ nn.ReLU(), nn.Linear(channel // reduction, channel), nn.Sigmoid())
+
+ def forward(self, x):
+ b, c, _, _ = x.shape
+ y = paddle.reshape(self.avg_pool(x), [b, c])
+ y = paddle.reshape(self.fc(y), [b, c, 1, 1])
+ return x * y
+
+
+class CountingDecoder(nn.Layer):
+ def __init__(self, in_channel, out_channel, kernel_size):
+ super(CountingDecoder, self).__init__()
+ self.in_channel = in_channel
+ self.out_channel = out_channel
+
+ self.trans_layer = nn.Sequential(
+ nn.Conv2D(
+ self.in_channel,
+ 512,
+ kernel_size=kernel_size,
+ padding=kernel_size // 2,
+ bias_attr=False),
+ nn.BatchNorm2D(512))
+
+ self.channel_att = ChannelAtt(512, 16)
+
+ self.pred_layer = nn.Sequential(
+ nn.Conv2D(
+ 512, self.out_channel, kernel_size=1, bias_attr=False),
+ nn.Sigmoid())
+
+ def forward(self, x, mask):
+ b, _, h, w = x.shape
+ x = self.trans_layer(x)
+ x = self.channel_att(x)
+ x = self.pred_layer(x)
+
+ if mask is not None:
+ x = x * mask
+ x = paddle.reshape(x, [b, self.out_channel, -1])
+ x1 = paddle.sum(x, axis=-1)
+
+ return x1, paddle.reshape(x, [b, self.out_channel, h, w])
+
+
+'''
+Attention Decoder
+'''
+
+
+class PositionEmbeddingSine(nn.Layer):
+ def __init__(self,
+ num_pos_feats=64,
+ temperature=10000,
+ normalize=False,
+ scale=None):
+ super().__init__()
+ self.num_pos_feats = num_pos_feats
+ self.temperature = temperature
+ self.normalize = normalize
+ if scale is not None and normalize is False:
+ raise ValueError("normalize should be True if scale is passed")
+ if scale is None:
+ scale = 2 * math.pi
+ self.scale = scale
+
+ def forward(self, x, mask):
+ y_embed = paddle.cumsum(mask, 1, dtype='float32')
+ x_embed = paddle.cumsum(mask, 2, dtype='float32')
+
+ if self.normalize:
+ eps = 1e-6
+ y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
+ x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
+ dim_t = paddle.arange(self.num_pos_feats, dtype='float32')
+ dim_d = paddle.expand(paddle.to_tensor(2), dim_t.shape)
+ dim_t = self.temperature**(2 * (dim_t / dim_d).astype('int64') /
+ self.num_pos_feats)
+
+ pos_x = paddle.unsqueeze(x_embed, [3]) / dim_t
+ pos_y = paddle.unsqueeze(y_embed, [3]) / dim_t
+
+ pos_x = paddle.flatten(
+ paddle.stack(
+ [
+ paddle.sin(pos_x[:, :, :, 0::2]),
+ paddle.cos(pos_x[:, :, :, 1::2])
+ ],
+ axis=4),
+ 3)
+ pos_y = paddle.flatten(
+ paddle.stack(
+ [
+ paddle.sin(pos_y[:, :, :, 0::2]),
+ paddle.cos(pos_y[:, :, :, 1::2])
+ ],
+ axis=4),
+ 3)
+
+ pos = paddle.transpose(
+ paddle.concat(
+ [pos_y, pos_x], axis=3), [0, 3, 1, 2])
+
+ return pos
+
+
+class AttDecoder(nn.Layer):
+ def __init__(self, ratio, is_train, input_size, hidden_size,
+ encoder_out_channel, dropout, dropout_ratio, word_num,
+ counting_decoder_out_channel, attention):
+ super(AttDecoder, self).__init__()
+ self.input_size = input_size
+ self.hidden_size = hidden_size
+ self.out_channel = encoder_out_channel
+ self.attention_dim = attention['attention_dim']
+ self.dropout_prob = dropout
+ self.ratio = ratio
+ self.word_num = word_num
+
+ self.counting_num = counting_decoder_out_channel
+ self.is_train = is_train
+
+ self.init_weight = nn.Linear(self.out_channel, self.hidden_size)
+ self.embedding = nn.Embedding(self.word_num, self.input_size)
+ self.word_input_gru = nn.GRUCell(self.input_size, self.hidden_size)
+ self.word_attention = Attention(hidden_size, attention['attention_dim'])
+
+ self.encoder_feature_conv = nn.Conv2D(
+ self.out_channel,
+ self.attention_dim,
+ kernel_size=attention['word_conv_kernel'],
+ padding=attention['word_conv_kernel'] // 2)
+
+ self.word_state_weight = nn.Linear(self.hidden_size, self.hidden_size)
+ self.word_embedding_weight = nn.Linear(self.input_size,
+ self.hidden_size)
+ self.word_context_weight = nn.Linear(self.out_channel, self.hidden_size)
+ self.counting_context_weight = nn.Linear(self.counting_num,
+ self.hidden_size)
+ self.word_convert = nn.Linear(self.hidden_size, self.word_num)
+
+ if dropout:
+ self.dropout = nn.Dropout(dropout_ratio)
+
+ def forward(self, cnn_features, labels, counting_preds, images_mask):
+ if self.is_train:
+ _, num_steps = labels.shape
+ else:
+ num_steps = 36
+
+ batch_size, _, height, width = cnn_features.shape
+ images_mask = images_mask[:, :, ::self.ratio, ::self.ratio]
+
+ word_probs = paddle.zeros((batch_size, num_steps, self.word_num))
+ word_alpha_sum = paddle.zeros((batch_size, 1, height, width))
+
+ hidden = self.init_hidden(cnn_features, images_mask)
+ counting_context_weighted = self.counting_context_weight(counting_preds)
+ cnn_features_trans = self.encoder_feature_conv(cnn_features)
+
+ position_embedding = PositionEmbeddingSine(256, normalize=True)
+ pos = position_embedding(cnn_features_trans, images_mask[:, 0, :, :])
+
+ cnn_features_trans = cnn_features_trans + pos
+
+ word = paddle.ones([batch_size, 1], dtype='int64') # init word as sos
+ word = word.squeeze(axis=1)
+ for i in range(num_steps):
+ word_embedding = self.embedding(word)
+ _, hidden = self.word_input_gru(word_embedding, hidden)
+ word_context_vec, _, word_alpha_sum = self.word_attention(
+ cnn_features, cnn_features_trans, hidden, word_alpha_sum,
+ images_mask)
+
+ current_state = self.word_state_weight(hidden)
+ word_weighted_embedding = self.word_embedding_weight(word_embedding)
+ word_context_weighted = self.word_context_weight(word_context_vec)
+
+ if self.dropout_prob:
+ word_out_state = self.dropout(
+ current_state + word_weighted_embedding +
+ word_context_weighted + counting_context_weighted)
+ else:
+ word_out_state = current_state + word_weighted_embedding + word_context_weighted + counting_context_weighted
+
+ word_prob = self.word_convert(word_out_state)
+ word_probs[:, i] = word_prob
+
+ if self.is_train:
+ word = labels[:, i]
+ else:
+ word = word_prob.argmax(1)
+ word = paddle.multiply(
+ word, labels[:, i]
+ ) # labels are oneslike tensor in infer/predict mode
+
+ return word_probs
+
+ def init_hidden(self, features, feature_mask):
+ average = paddle.sum(paddle.sum(features * feature_mask, axis=-1),
+ axis=-1) / paddle.sum(
+ (paddle.sum(feature_mask, axis=-1)), axis=-1)
+ average = self.init_weight(average)
+ return paddle.tanh(average)
+
+
+'''
+Attention Module
+'''
+
+
+class Attention(nn.Layer):
+ def __init__(self, hidden_size, attention_dim):
+ super(Attention, self).__init__()
+ self.hidden = hidden_size
+ self.attention_dim = attention_dim
+ self.hidden_weight = nn.Linear(self.hidden, self.attention_dim)
+ self.attention_conv = nn.Conv2D(
+ 1, 512, kernel_size=11, padding=5, bias_attr=False)
+ self.attention_weight = nn.Linear(
+ 512, self.attention_dim, bias_attr=False)
+ self.alpha_convert = nn.Linear(self.attention_dim, 1)
+
+ def forward(self,
+ cnn_features,
+ cnn_features_trans,
+ hidden,
+ alpha_sum,
+ image_mask=None):
+ query = self.hidden_weight(hidden)
+ alpha_sum_trans = self.attention_conv(alpha_sum)
+ coverage_alpha = self.attention_weight(
+ paddle.transpose(alpha_sum_trans, [0, 2, 3, 1]))
+ alpha_score = paddle.tanh(
+ paddle.unsqueeze(query, [1, 2]) + coverage_alpha + paddle.transpose(
+ cnn_features_trans, [0, 2, 3, 1]))
+ energy = self.alpha_convert(alpha_score)
+ energy = energy - energy.max()
+ energy_exp = paddle.exp(paddle.squeeze(energy, -1))
+
+ if image_mask is not None:
+ energy_exp = energy_exp * paddle.squeeze(image_mask, 1)
+ alpha = energy_exp / (paddle.unsqueeze(
+ paddle.sum(paddle.sum(energy_exp, -1), -1), [1, 2]) + 1e-10)
+ alpha_sum = paddle.unsqueeze(alpha, 1) + alpha_sum
+ context_vector = paddle.sum(
+ paddle.sum((paddle.unsqueeze(alpha, 1) * cnn_features), -1), -1)
+
+ return context_vector, alpha, alpha_sum
+
+
+class CANHead(nn.Layer):
+ def __init__(self, in_channel, out_channel, ratio, attdecoder, **kwargs):
+ super(CANHead, self).__init__()
+
+ self.in_channel = in_channel
+ self.out_channel = out_channel
+
+ self.counting_decoder1 = CountingDecoder(self.in_channel,
+ self.out_channel, 3) # mscm
+ self.counting_decoder2 = CountingDecoder(self.in_channel,
+ self.out_channel, 5)
+
+ self.decoder = AttDecoder(ratio, **attdecoder)
+
+ self.ratio = ratio
+
+ def forward(self, inputs, targets=None):
+ cnn_features, images_mask, labels = inputs
+
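+        # the DenseNet backbone downsamples by `ratio`, so stride the pixel
+        # mask to match the feature-map resolution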
+ counting_mask = images_mask[:, :, ::self.ratio, ::self.ratio]
+ counting_preds1, _ = self.counting_decoder1(cnn_features, counting_mask)
+ counting_preds2, _ = self.counting_decoder2(cnn_features, counting_mask)
+ counting_preds = (counting_preds1 + counting_preds2) / 2
+
+ word_probs = self.decoder(cnn_features, labels, counting_preds,
+ images_mask)
+ return word_probs, counting_preds, counting_preds1, counting_preds2
diff --git a/ppocr/optimizer/learning_rate.py b/ppocr/optimizer/learning_rate.py
index 7d45109b4857871f52764c64d6d32e5322fc7c57..be52a918458d64f0ae15b52ebf511e5068184f59 100644
--- a/ppocr/optimizer/learning_rate.py
+++ b/ppocr/optimizer/learning_rate.py
@@ -18,7 +18,7 @@ from __future__ import print_function
from __future__ import unicode_literals
from paddle.optimizer import lr
-from .lr_scheduler import CyclicalCosineDecay, OneCycleDecay
+from .lr_scheduler import CyclicalCosineDecay, OneCycleDecay, TwoStepCosineDecay
class Linear(object):
@@ -386,3 +386,44 @@ class MultiStepDecay(object):
end_lr=self.learning_rate,
last_epoch=self.last_epoch)
return learning_rate
+
+
+class TwoStepCosine(object):
+    """
+    Two-step cosine learning rate decay: the learning rate follows one cosine
+    schedule over the first 200 epochs, then a second, slower cosine schedule
+    over the full training run.
+    Args:
+        lr(float): initial learning rate
+        step_each_epoch(int): steps each epoch
+        epochs(int): total training epochs
+        last_epoch (int, optional): The index of last epoch. Can be set to restart training. Default: -1, means initial learning rate.
+    """
+
+ def __init__(self,
+ learning_rate,
+ step_each_epoch,
+ epochs,
+ warmup_epoch=0,
+ last_epoch=-1,
+ **kwargs):
+ super(TwoStepCosine, self).__init__()
+ self.learning_rate = learning_rate
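+        # the first cosine phase spans 200 epochs, the second the whole run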
+ self.T_max1 = step_each_epoch * 200
+ self.T_max2 = step_each_epoch * epochs
+ self.last_epoch = last_epoch
+ self.warmup_epoch = round(warmup_epoch * step_each_epoch)
+
+ def __call__(self):
+ learning_rate = TwoStepCosineDecay(
+ learning_rate=self.learning_rate,
+ T_max1=self.T_max1,
+ T_max2=self.T_max2,
+ last_epoch=self.last_epoch)
+ if self.warmup_epoch > 0:
+ learning_rate = lr.LinearWarmup(
+ learning_rate=learning_rate,
+ warmup_steps=self.warmup_epoch,
+ start_lr=0.0,
+ end_lr=self.learning_rate,
+ last_epoch=self.last_epoch)
+ return learning_rate
diff --git a/ppocr/optimizer/lr_scheduler.py b/ppocr/optimizer/lr_scheduler.py
index f62f1f3b0adbd8df0e03a66faa4565f2f7df28bc..cd09367e2ab8a649e3c375698f5b182eb5c3ff7a 100644
--- a/ppocr/optimizer/lr_scheduler.py
+++ b/ppocr/optimizer/lr_scheduler.py
@@ -160,3 +160,63 @@ class OneCycleDecay(LRScheduler):
start_step = phase['end_step']
return computed_lr
+
+
+class TwoStepCosineDecay(LRScheduler):
+ def __init__(self,
+ learning_rate,
+ T_max1,
+ T_max2,
+ eta_min=0,
+ last_epoch=-1,
+ verbose=False):
+        if not isinstance(T_max1, int):
+            raise TypeError(
+                "The type of 'T_max1' in 'TwoStepCosineDecay' must be 'int', but received %s."
+                % type(T_max1))
+        if not isinstance(T_max2, int):
+            raise TypeError(
+                "The type of 'T_max2' in 'TwoStepCosineDecay' must be 'int', but received %s."
+                % type(T_max2))
+        if not isinstance(eta_min, (float, int)):
+            raise TypeError(
+                "The type of 'eta_min' in 'TwoStepCosineDecay' must be 'float, int', but received %s."
+                % type(eta_min))
+        assert T_max1 > 0 and isinstance(
+            T_max1, int), " 'T_max1' must be a positive integer."
+        assert T_max2 > 0 and isinstance(
+            T_max2, int), " 'T_max2' must be a positive integer."
+ self.T_max1 = T_max1
+ self.T_max2 = T_max2
+ self.eta_min = float(eta_min)
+ super(TwoStepCosineDecay, self).__init__(learning_rate, last_epoch,
+ verbose)
+
+ def get_lr(self):
+
+ if self.last_epoch <= self.T_max1:
+ if self.last_epoch == 0:
+ return self.base_lr
+ elif (self.last_epoch - 1 - self.T_max1) % (2 * self.T_max1) == 0:
+ return self.last_lr + (self.base_lr - self.eta_min) * (
+ 1 - math.cos(math.pi / self.T_max1)) / 2
+
+ return (1 + math.cos(math.pi * self.last_epoch / self.T_max1)) / (
+ 1 + math.cos(math.pi * (self.last_epoch - 1) / self.T_max1)) * (
+ self.last_lr - self.eta_min) + self.eta_min
+ else:
+ if (self.last_epoch - 1 - self.T_max2) % (2 * self.T_max2) == 0:
+ return self.last_lr + (self.base_lr - self.eta_min) * (
+ 1 - math.cos(math.pi / self.T_max2)) / 2
+
+ return (1 + math.cos(math.pi * self.last_epoch / self.T_max2)) / (
+ 1 + math.cos(math.pi * (self.last_epoch - 1) / self.T_max2)) * (
+ self.last_lr - self.eta_min) + self.eta_min
+
+ def _get_closed_form_lr(self):
+ if self.last_epoch <= self.T_max1:
+ return self.eta_min + (self.base_lr - self.eta_min) * (1 + math.cos(
+ math.pi * self.last_epoch / self.T_max1)) / 2
+ else:
+ return self.eta_min + (self.base_lr - self.eta_min) * (1 + math.cos(
+ math.pi * self.last_epoch / self.T_max2)) / 2
diff --git a/ppocr/postprocess/__init__.py b/ppocr/postprocess/__init__.py
index 3a09030b25461029d9160699dc591eaedab9e0db..e86a7ea70c5271482774d5c7f14a8d649184804f 100644
--- a/ppocr/postprocess/__init__.py
+++ b/ppocr/postprocess/__init__.py
@@ -37,6 +37,7 @@ from .table_postprocess import TableMasterLabelDecode, TableLabelDecode
from .picodet_postprocess import PicoDetPostProcess
from .ct_postprocess import CTPostProcess
from .drrg_postprocess import DRRGPostprocess
+from .rec_postprocess import SeqLabelDecode
def build_post_process(config, global_config=None):
@@ -51,7 +52,7 @@ def build_post_process(config, global_config=None):
'TableMasterLabelDecode', 'SPINLabelDecode',
'DistillationSerPostProcess', 'DistillationRePostProcess',
'VLLabelDecode', 'PicoDetPostProcess', 'CTPostProcess',
- 'RFLLabelDecode', 'DRRGPostprocess'
+ 'RFLLabelDecode', 'DRRGPostprocess', 'SeqLabelDecode'
]
if config['name'] == 'PSEPostProcess':
diff --git a/ppocr/postprocess/rec_postprocess.py b/ppocr/postprocess/rec_postprocess.py
index 59b5254e480e7c52aca4ce648c379a280683db4f..4d88c278ec8ce02931478cc62492ab1a6f360594 100644
--- a/ppocr/postprocess/rec_postprocess.py
+++ b/ppocr/postprocess/rec_postprocess.py
@@ -897,3 +897,36 @@ class VLLabelDecode(BaseRecLabelDecode):
return text
label = self.decode(label)
return text, label
+
+
+class SeqLabelDecode(BaseRecLabelDecode):
+ """ Convert between latex-symbol and symbol-index """
+
+ def __init__(self, character_dict_path=None, use_space_char=False,
+ **kwargs):
+ super(SeqLabelDecode, self).__init__(character_dict_path,
+ use_space_char)
+
+ def decode(self, text_index, preds_prob=None):
+ result_list = []
+ batch_size = len(text_index)
+ for batch_idx in range(batch_size):
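+            # index 0 is "eos" in the dictionary, so argmin locates the first eos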
+ seq_end = text_index[batch_idx].argmin(0)
+ idx_list = text_index[batch_idx][:seq_end].tolist()
+ symbol_list = [self.character[idx] for idx in idx_list]
+ probs = []
+ if preds_prob is not None:
+ probs = preds_prob[batch_idx][:len(symbol_list)].tolist()
+
+ result_list.append([' '.join(symbol_list), probs])
+ return result_list
+
+ def __call__(self, preds, label=None, *args, **kwargs):
+ pred_prob, _, _, _ = preds
+ preds_idx = pred_prob.argmax(axis=2)
+
+ text = self.decode(preds_idx)
+ if label is None:
+ return text
+ label = self.decode(label)
+ return text, label
diff --git a/ppocr/utils/dict/latex_symbol_dict.txt b/ppocr/utils/dict/latex_symbol_dict.txt
new file mode 100644
index 0000000000000000000000000000000000000000..b43f1fa8b904e3107eb450f6d7332aec6b5b81e2
--- /dev/null
+++ b/ppocr/utils/dict/latex_symbol_dict.txt
@@ -0,0 +1,111 @@
+eos
+sos
+!
+'
+(
+)
++
+,
+-
+.
+/
+0
+1
+2
+3
+4
+5
+6
+7
+8
+9
+<
+=
+>
+A
+B
+C
+E
+F
+G
+H
+I
+L
+M
+N
+P
+R
+S
+T
+V
+X
+Y
+[
+\Delta
+\alpha
+\beta
+\cdot
+\cdots
+\cos
+\div
+\exists
+\forall
+\frac
+\gamma
+\geq
+\in
+\infty
+\int
+\lambda
+\ldots
+\leq
+\lim
+\log
+\mu
+\neq
+\phi
+\pi
+\pm
+\prime
+\rightarrow
+\sigma
+\sin
+\sqrt
+\sum
+\tan
+\theta
+\times
+]
+a
+b
+c
+d
+e
+f
+g
+h
+i
+j
+k
+l
+m
+n
+o
+p
+q
+r
+s
+t
+u
+v
+w
+x
+y
+z
+\{
+|
+\}
+{
+}
+^
+_
\ No newline at end of file
diff --git a/test_tipc/configs/rec_d28_can/rec_d28_can.yml b/test_tipc/configs/rec_d28_can/rec_d28_can.yml
new file mode 100644
index 0000000000000000000000000000000000000000..aeaccb6b0d245cc4be479d17597d0029647db574
--- /dev/null
+++ b/test_tipc/configs/rec_d28_can/rec_d28_can.yml
@@ -0,0 +1,114 @@
+Global:
+ use_gpu: True
+ epoch_num: 240
+ log_smooth_window: 20
+ print_batch_step: 10
+ save_model_dir: ./output/rec/can/
+ save_epoch_step: 1
+ # evaluation is run every 1105 iterations
+ eval_batch_step: [0, 1105]
+ cal_metric_during_train: True
+ pretrained_model: ./output/rec/can/CAN
+ checkpoints: ./output/rec/can/CAN
+ save_inference_dir: ./inference/rec_d28_can/
+ use_visualdl: False
+  infer_img: doc/imgs_hme/hme_01.jpg
+ # for data or label process
+ character_dict_path: ppocr/utils/dict/latex_symbol_dict.txt
+ max_text_length: 36
+ infer_mode: False
+ use_space_char: False
+ save_res_path: ./output/rec/predicts_can.txt
+
+Optimizer:
+ name: Momentum
+ momentum: 0.9
+ clip_norm_global: 100.0
+ lr:
+ name: TwoStepCosine
+ learning_rate: 0.01
+ warmup_epoch: 1
+ weight_decay: 0.0001
+
+Architecture:
+ model_type: rec
+ algorithm: CAN
+ in_channels: 1
+ Transform:
+ Backbone:
+ name: DenseNet
+ growthRate: 24
+ reduction: 0.5
+ bottleneck: True
+ use_dropout: True
+ input_channel: 1
+
+ Head:
+ name: CANHead
+ in_channel: 684
+ out_channel: 111
+ max_text_length: 36
+ ratio: 16
+ attdecoder:
+ is_train: True
+ input_size: 256
+ hidden_size: 256
+ encoder_out_channel: 684
+ dropout: True
+ dropout_ratio: 0.5
+ word_num: 111
+ counting_decoder_out_channel: 111
+ attention:
+ attention_dim: 512
+ word_conv_kernel: 1
+
+Loss:
+ name: CANLoss
+
+PostProcess:
+ name: SeqLabelDecode
+ character: 111
+
+Metric:
+ name: CANMetric
+ main_indicator: exp_rate
+
+Train:
+ dataset:
+ name: HMERDataSet
+ data_dir: ./train_data/CROHME/training/images/
+ transforms:
+ - DecodeImage:
+ channel_first: False
+ - GrayImageChannelFormat:
+ normalize: True
+ inverse: True
+ - KeepKeys:
+ keep_keys: ['image', 'label']
+ label_file_list: ["./train_data/CROHME/training/labels.json"]
+ loader:
+ shuffle: True
+ batch_size_per_card: 2
+ drop_last: True
+ num_workers: 1
+ collate_fn: DyMaskCollator
+
+Eval:
+ dataset:
+ name: HMERDataSet
+ data_dir: ./train_data/CROHME/evaluation/images/
+ transforms:
+ - DecodeImage:
+ channel_first: False
+ - GrayImageChannelFormat:
+ normalize: True
+ inverse: True
+ - KeepKeys:
+ keep_keys: ['image', 'label']
+ label_file_list: ["./train_data/CROHME/evaluation/labels.json"]
+ loader:
+ shuffle: False
+ drop_last: False
+ batch_size_per_card: 1
+ num_workers: 4
+ collate_fn: DyMaskCollator
diff --git a/test_tipc/configs/rec_d28_can/train_infer_python.txt b/test_tipc/configs/rec_d28_can/train_infer_python.txt
new file mode 100644
index 0000000000000000000000000000000000000000..be50c59805b2f1b8a0baa5a61b9d5cd4a21b68df
--- /dev/null
+++ b/test_tipc/configs/rec_d28_can/train_infer_python.txt
@@ -0,0 +1,53 @@
+===========================train_params===========================
+model_name:rec_d28_can
+python:python
+gpu_list:0|0,1
+Global.use_gpu:True|True
+Global.auto_cast:null
+Global.epoch_num:lite_train_lite_infer=2|whole_train_whole_infer=240
+Global.save_model_dir:./output/
+Train.loader.batch_size_per_card:lite_train_lite_infer=2|whole_train_whole_infer=8
+Global.pretrained_model:null
+train_model_name:latest
+train_infer_img_dir:./doc/imgs_hme
+null:null
+##
+trainer:norm_train
+norm_train:tools/train.py -c test_tipc/configs/rec_d28_can/rec_d28_can.yml -o
+pact_train:null
+fpgm_train:null
+distill_train:null
+null:null
+null:null
+##
+===========================eval_params===========================
+eval:tools/eval.py -c test_tipc/configs/rec_d28_can/rec_d28_can.yml -o
+null:null
+##
+===========================infer_params===========================
+Global.save_inference_dir:./output/
+Global.checkpoints:
+norm_export:tools/export_model.py -c test_tipc/configs/rec_d28_can/rec_d28_can.yml -o
+quant_export:null
+fpgm_export:null
+distill_export:null
+export1:null
+export2:null
+##
+train_model:./inference/rec_d28_can_train/best_accuracy
+infer_export:tools/export_model.py -c test_tipc/configs/rec_d28_can/rec_d28_can.yml -o
+infer_quant:False
+inference:tools/infer/predict_rec.py --rec_char_dict_path=./ppocr/utils/dict/latex_symbol_dict.txt --rec_image_shape="1,100,100" --rec_algorithm="CAN"
+--use_gpu:True|False
+--enable_mkldnn:False
+--cpu_threads:6
+--rec_batch_num:1
+--use_tensorrt:False
+--precision:fp32
+--rec_model_dir:./output/
+--image_dir:./doc/imgs_hme
+--save_log_path:./test/output/
+--benchmark:True
+null:null
+===========================infer_benchmark_params==========================
+random_infer_input:[{float32,[1,100,100]}]
diff --git a/tools/eval.py b/tools/eval.py
index 3d1d3813d33e251ec83a9729383fe772bc4cc225..21f4d94d5e4ed560b8775c8827ffdbbd00355218 100755
--- a/tools/eval.py
+++ b/tools/eval.py
@@ -74,7 +74,9 @@ def main():
config['Architecture']["Head"]['out_channels'] = char_num
model = build_model(config['Architecture'])
- extra_input_models = ["SRN", "NRTR", "SAR", "SEED", "SVTR", "VisionLAN", "RobustScanner"]
+ extra_input_models = [
+ "SRN", "NRTR", "SAR", "SEED", "SVTR", "VisionLAN", "RobustScanner"
+ ]
extra_input = False
if config['Architecture']['algorithm'] == 'Distillation':
for key in config['Architecture']["Models"]:
@@ -83,7 +85,10 @@ def main():
else:
extra_input = config['Architecture']['algorithm'] in extra_input_models
if "model_type" in config['Architecture'].keys():
- model_type = config['Architecture']['model_type']
+ if config['Architecture']['algorithm'] == 'CAN':
+ model_type = 'can'
+ else:
+ model_type = config['Architecture']['model_type']
else:
model_type = None
@@ -92,7 +97,7 @@ def main():
# amp
use_amp = config["Global"].get("use_amp", False)
amp_level = config["Global"].get("amp_level", 'O2')
- amp_custom_black_list = config['Global'].get('amp_custom_black_list',[])
+ amp_custom_black_list = config['Global'].get('amp_custom_black_list', [])
if use_amp:
AMP_RELATED_FLAGS_SETTING = {
'FLAGS_cudnn_batchnorm_spatial_persistent': 1,
@@ -120,7 +125,8 @@ def main():
# start eval
metric = program.eval(model, valid_dataloader, post_process_class,
- eval_class, model_type, extra_input, scaler, amp_level, amp_custom_black_list)
+ eval_class, model_type, extra_input, scaler,
+ amp_level, amp_custom_black_list)
logger.info('metric eval ***************')
for k, v in metric.items():
logger.info('{}:{}'.format(k, v))
diff --git a/tools/export_model.py b/tools/export_model.py
index 52f05bfcba0487d1c5abd0f7d7221c2feca40ae9..4b90fcae435619a53a3def8cc4dc46b4e2963bff 100755
--- a/tools/export_model.py
+++ b/tools/export_model.py
@@ -123,6 +123,17 @@ def export_single_model(model,
]
]
model = to_static(model, input_spec=other_shape)
+ elif arch_config["algorithm"] == "CAN":
+ other_shape = [[
+ paddle.static.InputSpec(
+ shape=[None, 1, None, None],
+ dtype="float32"), paddle.static.InputSpec(
+ shape=[None, 1, None, None], dtype="float32"),
+ paddle.static.InputSpec(
+ shape=[None, arch_config['Head']['max_text_length']],
+ dtype="int64")
+ ]]
+ model = to_static(model, input_spec=other_shape)
elif arch_config["algorithm"] in ["LayoutLM", "LayoutLMv2", "LayoutXLM"]:
input_spec = [
paddle.static.InputSpec(
diff --git a/tools/infer/predict_rec.py b/tools/infer/predict_rec.py
index bffeb25534068691fee21bbf946cc7cda7326d27..c1604798e7506561b4f03bd401f4588c28b6cb1a 100755
--- a/tools/infer/predict_rec.py
+++ b/tools/infer/predict_rec.py
@@ -108,6 +108,13 @@ class TextRecognizer(object):
}
elif self.rec_algorithm == "PREN":
postprocess_params = {'name': 'PRENLabelDecode'}
+ elif self.rec_algorithm == "CAN":
+ self.inverse = args.rec_image_inverse
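+            # SeqLabelDecode maps the predicted symbol indices back to
+            # LaTeX tokens via the dictionary in rec_char_dict_path.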
+ postprocess_params = {
+ 'name': 'SeqLabelDecode',
+ "character_dict_path": args.rec_char_dict_path,
+ "use_space_char": args.use_space_char
+ }
self.postprocess_op = build_post_process(postprocess_params)
self.predictor, self.input_tensor, self.output_tensors, self.config = \
utility.create_predictor(args, 'rec', logger)
@@ -351,6 +358,30 @@ class TextRecognizer(object):
return resized_image
+    def norm_img_can(self, img, image_shape):
+        # CAN only predicts grayscale images, so convert BGR input to gray;
+        # `image_shape` is unused here (the target size comes from
+        # self.rec_image_shape).
+        img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
+
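+        # Optionally invert so strokes are white on a black background,
+        # matching the preprocessing used at training time.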
+ if self.inverse:
+ img = 255 - img
+
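+        # Pad images below the minimum input size (imgH, imgW) with 255.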
+ if self.rec_image_shape[0] == 1:
+ h, w = img.shape
+ _, imgH, imgW = self.rec_image_shape
+ if h < imgH or w < imgW:
+ padding_h = max(imgH - h, 0)
+ padding_w = max(imgW - w, 0)
+ img_padded = np.pad(img, ((0, padding_h), (0, padding_w)),
+ 'constant',
+ constant_values=(255))
+ img = img_padded
+
+        img = np.expand_dims(img, 0) / 255.0  # (h, w) -> (1, h, w), in [0, 1]
+ img = img.astype('float32')
+
+ return img
+
def __call__(self, img_list):
img_num = len(img_list)
# Calculate the aspect ratio of all text bars
@@ -430,6 +461,17 @@ class TextRecognizer(object):
word_positions = np.array(range(0, 40)).astype('int64')
word_positions = np.expand_dims(word_positions, axis=0)
word_positions_list.append(word_positions)
+ elif self.rec_algorithm == "CAN":
+ norm_img = self.norm_img_can(img_list[indices[ino]],
+ max_wh_ratio)
+ norm_img = norm_img[np.newaxis, :]
+ norm_img_batch.append(norm_img)
+ norm_image_mask = np.ones(norm_img.shape, dtype='float32')
+ word_label = np.ones([1, 36], dtype='int64')
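+                # NOTE: these lists are re-created for every image, so only
+                # the last image's mask/label are kept; CAN inference should
+                # therefore run with rec_batch_num=1 (as in the TIPC config).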
+ norm_img_mask_batch = []
+ word_label_list = []
+ norm_img_mask_batch.append(norm_image_mask)
+ word_label_list.append(word_label)
else:
norm_img = self.resize_norm_img(img_list[indices[ino]],
max_wh_ratio)
@@ -527,6 +569,33 @@ class TextRecognizer(object):
if self.benchmark:
self.autolog.times.stamp()
preds = outputs[0]
+ elif self.rec_algorithm == "CAN":
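+                # CAN takes three inputs: the image batch, an image mask of
+                # the same shape, and a placeholder word-label tensor.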
+ norm_img_mask_batch = np.concatenate(norm_img_mask_batch)
+ word_label_list = np.concatenate(word_label_list)
+ inputs = [norm_img_batch, norm_img_mask_batch, word_label_list]
+ if self.use_onnx:
+ input_dict = {}
+ input_dict[self.input_tensor.name] = norm_img_batch
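+                    # NOTE: only the image is fed in the ONNX path; the mask
+                    # and word label are not passed to the runner here.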
+ outputs = self.predictor.run(self.output_tensors,
+ input_dict)
+ preds = outputs
+ else:
+ input_names = self.predictor.get_input_names()
+ input_tensor = []
+ for i in range(len(input_names)):
+ input_tensor_i = self.predictor.get_input_handle(
+ input_names[i])
+ input_tensor_i.copy_from_cpu(inputs[i])
+ input_tensor.append(input_tensor_i)
+ self.input_tensor = input_tensor
+ self.predictor.run()
+ outputs = []
+ for output_tensor in self.output_tensors:
+ output = output_tensor.copy_to_cpu()
+ outputs.append(output)
+ if self.benchmark:
+ self.autolog.times.stamp()
+ preds = outputs
else:
if self.use_onnx:
input_dict = {}
diff --git a/tools/infer/utility.py b/tools/infer/utility.py
index f6a44e35a5b303d6ed30bf8057a62409aa690fef..34cad2590f2904f79709530acf841033c89088e0 100644
--- a/tools/infer/utility.py
+++ b/tools/infer/utility.py
@@ -84,6 +84,7 @@ def init_args():
# params for text recognizer
parser.add_argument("--rec_algorithm", type=str, default='SVTR_LCNet')
parser.add_argument("--rec_model_dir", type=str)
+ parser.add_argument("--rec_image_inverse", type=str2bool, default=True)
parser.add_argument("--rec_image_shape", type=str, default="3, 48, 320")
parser.add_argument("--rec_batch_num", type=int, default=6)
parser.add_argument("--max_text_length", type=int, default=25)
diff --git a/tools/infer_rec.py b/tools/infer_rec.py
index cb8a6ec3050c878669f539b8b11d97214f5eec20..29aab9b57853b16bf615c893c30351a403270b57 100755
--- a/tools/infer_rec.py
+++ b/tools/infer_rec.py
@@ -141,6 +141,11 @@ def main():
paddle.to_tensor(valid_ratio),
paddle.to_tensor(word_positons),
]
+ if config['Architecture']['algorithm'] == "CAN":
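+            # The attention mask covers the whole image and the label is a
+            # placeholder of length max_text_length (36) at inference time.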
+        image_mask = paddle.ones(
+            np.expand_dims(batch[0], axis=0).shape,
+            dtype='float32')
+ label = paddle.ones((1, 36), dtype='int64')
images = np.expand_dims(batch[0], axis=0)
images = paddle.to_tensor(images)
if config['Architecture']['algorithm'] == "SRN":
@@ -149,6 +154,8 @@ def main():
preds = model(images, img_metas)
elif config['Architecture']['algorithm'] == "RobustScanner":
preds = model(images, img_metas)
+ elif config['Architecture']['algorithm'] == "CAN":
+ preds = model([images, image_mask, label])
else:
preds = model(images)
post_result = post_process_class(preds)
diff --git a/tools/program.py b/tools/program.py
index 5d2bd5bfb034940e3bec802b5e7041c8e82a9271..c491247a64697774cca73e209b714f7560c5fcdd 100755
--- a/tools/program.py
+++ b/tools/program.py
@@ -273,6 +273,8 @@ def train(config,
preds = model(images, data=batch[1:])
elif model_type in ["kie"]:
preds = model(batch)
+ elif algorithm in ['CAN']:
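+                    # batch[:3] = image, image mask, word label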
+ preds = model(batch[:3])
else:
preds = model(images)
preds = to_float32(preds)
@@ -286,6 +288,8 @@ def train(config,
preds = model(images, data=batch[1:])
elif model_type in ["kie", 'sr']:
preds = model(batch)
+ elif algorithm in ['CAN']:
+ preds = model(batch[:3])
else:
preds = model(images)
loss = loss_class(preds, batch)
@@ -302,6 +306,9 @@ def train(config,
elif model_type in ['table']:
post_result = post_process_class(preds, batch)
eval_class(post_result, batch)
+ elif algorithm in ['CAN']:
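+                # CANMetric consumes the word-probability output (preds[0])
+                # and the label/label-mask pair (batch[2:]); epoch_reset
+                # clears its buffers on the first batch of each epoch.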
+ model_type = 'can'
+ eval_class(preds[0], batch[2:], epoch_reset=(idx == 0))
else:
if config['Loss']['name'] in ['MultiLoss', 'MultiLoss_v2'
]: # for multi head loss
@@ -496,6 +503,8 @@ def eval(model,
preds = model(images, data=batch[1:])
elif model_type in ["kie"]:
preds = model(batch)
+ elif model_type in ['can']:
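+                # batch[:3] = image, image mask, word label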
+ preds = model(batch[:3])
elif model_type in ['sr']:
preds = model(batch)
sr_img = preds["sr_img"]
@@ -508,6 +517,8 @@ def eval(model,
preds = model(images, data=batch[1:])
elif model_type in ["kie"]:
preds = model(batch)
+ elif model_type in ['can']:
+ preds = model(batch[:3])
elif model_type in ['sr']:
preds = model(batch)
sr_img = preds["sr_img"]
@@ -532,6 +543,8 @@ def eval(model,
eval_class(post_result, batch_numpy)
elif model_type in ['sr']:
eval_class(preds, batch_numpy)
+ elif model_type in ['can']:
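+                # word probabilities (preds[0]) vs. label/label mask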
+ eval_class(preds[0], batch_numpy[2:], epoch_reset=False)
else:
post_result = post_process_class(preds, batch_numpy[1])
eval_class(post_result, batch_numpy)
@@ -629,7 +642,7 @@ def preprocess(is_train=False):
'CLS', 'PGNet', 'Distillation', 'NRTR', 'TableAttn', 'SAR', 'PSE',
'SEED', 'SDMGR', 'LayoutXLM', 'LayoutLM', 'LayoutLMv2', 'PREN', 'FCE',
'SVTR', 'ViTSTR', 'ABINet', 'DB++', 'TableMaster', 'SPIN', 'VisionLAN',
- 'Gestalt', 'SLANet', 'RobustScanner', 'CT', 'RFL', 'DRRG'
+ 'Gestalt', 'SLANet', 'RobustScanner', 'CT', 'RFL', 'DRRG', 'CAN'
]
if use_xpu: