diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 1584bc76a9dd8ddff9d05a8cb693bcbd2e09fcde..b6a299ba4009ba73afc9673fe008b97ca139c57b 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,10 +1,11 @@ +repos: - repo: https://github.com/PaddlePaddle/mirrors-yapf.git - sha: 0d79c0c469bab64f7229c9aca2b1186ef47f0e37 + rev: 0d79c0c469bab64f7229c9aca2b1186ef47f0e37 hooks: - id: yapf files: \.py$ - repo: https://github.com/pre-commit/pre-commit-hooks - sha: a11d9314b22d8f8c7556443875b731ef05965464 + rev: a11d9314b22d8f8c7556443875b731ef05965464 hooks: - id: check-merge-conflict - id: check-symlinks @@ -15,7 +16,7 @@ - id: trailing-whitespace files: \.md$ - repo: https://github.com/Lucas-C/pre-commit-hooks - sha: v1.0.1 + rev: v1.0.1 hooks: - id: forbid-crlf files: \.md$ diff --git a/configs/rec/rec_d28_can.yml b/configs/rec/rec_d28_can.yml new file mode 100644 index 0000000000000000000000000000000000000000..aeaccb6b0d245cc4be479d17597d0029647db574 --- /dev/null +++ b/configs/rec/rec_d28_can.yml @@ -0,0 +1,114 @@ +Global: + use_gpu: True + epoch_num: 240 + log_smooth_window: 20 + print_batch_step: 10 + save_model_dir: ./output/rec/can/ + save_epoch_step: 1 + # evaluation is run every 1105 iterations + eval_batch_step: [0, 1105] + cal_metric_during_train: True + pretrained_model: ./output/rec/can/CAN + checkpoints: ./output/rec/can/CAN + save_inference_dir: ./inference/rec_d28_can/ + use_visualdl: False + infer_img: doc/imgs_hme/hme_01.jpeg + # for data or label process + character_dict_path: ppocr/utils/dict/latex_symbol_dict.txt + max_text_length: 36 + infer_mode: False + use_space_char: False + save_res_path: ./output/rec/predicts_can.txt + +Optimizer: + name: Momentum + momentum: 0.9 + clip_norm_global: 100.0 + lr: + name: TwoStepCosine + learning_rate: 0.01 + warmup_epoch: 1 + weight_decay: 0.0001 + +Architecture: + model_type: rec + algorithm: CAN + in_channels: 1 + Transform: + Backbone: + name: DenseNet + growthRate: 24 + reduction: 0.5 + bottleneck: True + use_dropout: True + input_channel: 1 + + Head: + name: CANHead + in_channel: 684 + out_channel: 111 + max_text_length: 36 + ratio: 16 + attdecoder: + is_train: True + input_size: 256 + hidden_size: 256 + encoder_out_channel: 684 + dropout: True + dropout_ratio: 0.5 + word_num: 111 + counting_decoder_out_channel: 111 + attention: + attention_dim: 512 + word_conv_kernel: 1 + +Loss: + name: CANLoss + +PostProcess: + name: SeqLabelDecode + character: 111 + +Metric: + name: CANMetric + main_indicator: exp_rate + +Train: + dataset: + name: HMERDataSet + data_dir: ./train_data/CROHME/training/images/ + transforms: + - DecodeImage: + channel_first: False + - GrayImageChannelFormat: + normalize: True + inverse: True + - KeepKeys: + keep_keys: ['image', 'label'] + label_file_list: ["./train_data/CROHME/training/labels.json"] + loader: + shuffle: True + batch_size_per_card: 2 + drop_last: True + num_workers: 1 + collate_fn: DyMaskCollator + +Eval: + dataset: + name: HMERDataSet + data_dir: ./train_data/CROHME/evaluation/images/ + transforms: + - DecodeImage: + channel_first: False + - GrayImageChannelFormat: + normalize: True + inverse: True + - KeepKeys: + keep_keys: ['image', 'label'] + label_file_list: ["./train_data/CROHME/evaluation/labels.json"] + loader: + shuffle: False + drop_last: False + batch_size_per_card: 1 + num_workers: 4 + collate_fn: DyMaskCollator diff --git a/doc/doc_ch/algorithm_rec_can.md b/doc/doc_ch/algorithm_rec_can.md new file mode 100644 index 0000000000000000000000000000000000000000..9585dae0cf9ecc215b4ff2f6d656345418e197d4 --- /dev/null +++ b/doc/doc_ch/algorithm_rec_can.md @@ -0,0 +1,170 @@ +# 手写数学公式识别算法-ABINet + +- [1. 算法简介](#1) +- [2. 环境配置](#2) +- [3. 模型训练、评估、预测](#3) + - [3.1 训练](#3-1) + - [3.2 评估](#3-2) + - [3.3 预测](#3-3) +- [4. 推理部署](#4) + - [4.1 Python推理](#4-1) + - [4.2 C++推理](#4-2) + - [4.3 Serving服务化部署](#4-3) + - [4.4 更多推理部署](#4-4) +- [5. FAQ](#5) + + +## 1. 算法简介 + +论文信息: +> [When Counting Meets HMER: Counting-Aware Network for Handwritten Mathematical Expression Recognition](https://arxiv.org/abs/2207.11463) +> Bohan Li, Ye Yuan, Dingkang Liang, Xiao Liu, Zhilong Ji, Jinfeng Bai, Wenyu Liu, Xiang Bai +> ECCV, 2022 + + + +`CAN`使用CROHME手写公式数据集进行训练,在对应测试集上的精度如下: + +|模型 |骨干网络|配置文件|ExpRate|下载链接| +| ----- | ----- | ----- | ----- | ----- | +|CAN|DenseNet|[rec_d28_can.yml](../../configs/rec/rec_d28_can.yml)|51.72|[训练模型](https://paddleocr.bj.bcebos.com/rec_r45_abinet_train.tar)| + + +## 2. 环境配置 +请先参考[《运行环境准备》](./environment.md)配置PaddleOCR运行环境,参考[《项目克隆》](./clone.md)克隆项目代码。 + + + +## 3. 模型训练、评估、预测 + + +### 3.1 模型训练 + +请参考[文本识别训练教程](./recognition.md)。PaddleOCR对代码进行了模块化,训练`CAN`识别模型时需要**更换配置文件**为`CAN`的[配置文件](../../configs/rec/rec_d28_can.yml)。 + +#### 启动训练 + + +具体地,在完成数据准备后,便可以启动训练,训练命令如下: +```shell +#单卡训练(训练周期长,不建议) +python3 tools/train.py -c configs/rec/rec_d28_can.yml + +#多卡训练,通过--gpus参数指定卡号 +python3 -m paddle.distributed.launch --gpus '0,1,2,3' tools/train.py -c configs/rec/rec_d28_can.yml +``` + +**注意:** +- 我们提供的数据集,即`CROHME数据集`将手写公式存储为黑底白字的格式,若您自行准备的数据集与之相反,即以白底黑字模式存储,请在训练时做出如下修改 +``` +python3 tools/train.py -c configs/rec/rec_d28_can.yml +-o Train.dataset.transforms.GrayImageChannelFormat.inverse=False +``` + +# + +### 3.2 评估 + +可下载已训练完成的[模型文件](#model),使用如下命令进行评估: + +```shell +# 注意将pretrained_model的路径设置为本地路径。 +python3 -m paddle.distributed.launch --gpus '0' tools/eval.py -c configs/rec/rec_d28_can.yml -o Global.pretrained_model=./rec_d28_can_train/best_accuracy +``` + + +### 3.3 预测 + +使用如下命令进行单张图片预测: +```shell +# 注意将pretrained_model的路径设置为本地路径。 +python3 tools/infer_rec.py -c configs/rec/rec_d28_can.yml -o Architecture.Head.attdecoder.is_train=False Global.infer_img='./doc/imgs_hme/hme_01.jpg' Global.pretrained_model=./rec_d28_can_train/best_accuracy + +# 预测文件夹下所有图像时,可修改infer_img为文件夹,如 Global.infer_img='./doc/imgs_hme/'。 +``` + + + +## 4. 推理部署 + + +### 4.1 Python推理 +首先将训练得到best模型,转换成inference model。这里以训练完成的模型为例([模型下载地址](https://paddleocr.bj.bcebos.com/rec_d28_can_train.tar) ),可以使用如下命令进行转换: + +```shell +# 注意将pretrained_model的路径设置为本地路径。 +python3 tools/export_model.py -c configs/rec/rec_d28_can.yml -o Global.save_inference_dir=./inference/rec_d28_can/ Architecture.Head.attdecoder.is_train=False + +# 目前的静态图模型默认的输出长度最大为36,如果您需要预测更长的序列,请在导出模型时指定其输出序列为合适的值,例如 Architecture.Head.max_text_length=72 +``` +**注意:** +- 如果您是在自己的数据集上训练的模型,并且调整了字典文件,请注意修改配置文件中的`character_dict_path`是否是所需要的字典文件。 +- 如果您修改了训练时的输入大小,请修改`tools/export_model.py`文件中的对应ABINet的`infer_shape`。 + +转换成功后,在目录下有三个文件: +``` +/inference/rec_d28_can/ + ├── inference.pdiparams # 识别inference模型的参数文件 + ├── inference.pdiparams.info # 识别inference模型的参数信息,可忽略 + └── inference.pdmodel # 识别inference模型的program文件 +``` + +执行如下命令进行模型推理: + +```shell +python3 tools/infer/predict_rec.py --image_dir="./doc/imgs_hme/hme_01.jpg" --rec_algorithm="CAN" --rec_batch_num=1 --rec_model_dir="./inference/rec_d28_can/" --rec_char_dict_path="./ppocr/utils/dict/latex_symbol_dict.txt" + +# 预测文件夹下所有图像时,可修改image_dir为文件夹,如 --image_dir='./doc/imgs_hme/'。 + +# 如果您需要在白底黑字的图片上进行预测,请设置 --rec_image_inverse=False +``` + +![测试图片样例](../imgs_hme/hme_00.jpg) + +执行命令后,上面图像的预测结果(识别的文本)会打印到屏幕上,示例如下: +```shell +Predicts of ./doc/imgs_hme/hme_03.jpg:['x _ { k } x x _ { k } + y _ { k } y x _ { k }', []] +``` + + +**注意**: + +- 需要注意预测图像为**黑底白字**,即手写公式部分为白色,背景为黑色的图片。 +- 在推理时需要设置参数`rec_char_dict_path`指定字典,如果您修改了字典,请修改该参数为您的字典文件。 +- 如果您修改了预处理方法,需修改`tools/infer/predict_rec.py`中CAN的预处理为您的预处理方法。 + + + +### 4.2 C++推理部署 + +由于C++预处理后处理还未支持ABINet,所以暂未支持 + + +### 4.3 Serving服务化部署 + +暂不支持 + + +### 4.4 更多推理部署 + +暂不支持 + + +## 5. FAQ + +1. CROHME数据集来自于[CAN源repo](https://github.com/LBH1024/CAN) 。 + +## 引用 + +```bibtex +@misc{https://doi.org/10.48550/arxiv.2207.11463, + doi = {10.48550/ARXIV.2207.11463}, + url = {https://arxiv.org/abs/2207.11463}, + author = {Li, Bohan and Yuan, Ye and Liang, Dingkang and Liu, Xiao and Ji, Zhilong and Bai, Jinfeng and Liu, Wenyu and Bai, Xiang}, + keywords = {Computer Vision and Pattern Recognition (cs.CV), Artificial Intelligence (cs.AI), FOS: Computer and information sciences, FOS: Computer and information sciences}, + title = {When Counting Meets HMER: Counting-Aware Network for Handwritten Mathematical Expression Recognition}, + publisher = {arXiv}, + year = {2022}, + copyright = {arXiv.org perpetual, non-exclusive license} +} +``` diff --git a/doc/doc_en/algorithm_rec_can_en.md b/doc/doc_en/algorithm_rec_can_en.md new file mode 100644 index 0000000000000000000000000000000000000000..f2bc645af0d2bea629e6dfe78f6c49f51a3af73d --- /dev/null +++ b/doc/doc_en/algorithm_rec_can_en.md @@ -0,0 +1,115 @@ +# RobustScanner + +- [1. Introduction](#1) +- [2. Environment](#2) +- [3. Model Training / Evaluation / Prediction](#3) + - [3.1 Training](#3-1) + - [3.2 Evaluation](#3-2) + - [3.3 Prediction](#3-3) +- [4. Inference and Deployment](#4) + - [4.1 Python Inference](#4-1) + - [4.2 C++ Inference](#4-2) + - [4.3 Serving](#4-3) + - [4.4 More](#4-4) +- [5. FAQ](#5) + + +## 1. Introduction + +Paper: +> [When Counting Meets HMER: Counting-Aware Network for Handwritten Mathematical Expression Recognition](https://arxiv.org/abs/2207.11463) +> Bohan Li, Ye Yuan, Dingkang Liang, Xiao Liu, Zhilong Ji, Jinfeng Bai, Wenyu Liu, Xiang Bai +> ECCV, 2022 + +Using CROHME handwrittem mathematical expression recognition datasets for training, and evaluating on its test sets, the algorithm reproduction effect is as follows: + +|Model|Backbone|config|exprate|Download link| +| --- | --- | --- | --- | --- | +|CAN|DenseNet|[rec_d28_can.yml](../../configs/rec/rec_d28_can.yml)|51.72|coming soon| + + +## 2. Environment +Please refer to ["Environment Preparation"](./environment_en.md) to configure the PaddleOCR environment, and refer to ["Project Clone"](./clone_en.md) to clone the project code. + + + +## 3. Model Training / Evaluation / Prediction + +Please refer to [Text Recognition Tutorial](./recognition_en.md). PaddleOCR modularizes the code, and training different recognition models only requires **changing the configuration file**. + +Training: + +Specifically, after the data preparation is completed, the training can be started. The training command is as follows: + +``` +#Single GPU training (long training period, not recommended) +python3 tools/train.py -c configs/rec/rec_d28_can.yml + +#Multi GPU training, specify the gpu number through the --gpus parameter +python3 -m paddle.distributed.launch --gpus '0,1,2,3' tools/train.py -c configs/rec/rec_d28_can.yml +``` + +Evaluation: + +``` +# GPU evaluation +python3 -m paddle.distributed.launch --gpus '0' tools/eval.py -c configs/rec/rec_d28_can.yml -o Global.pretrained_model=./rec_d28_can_train/best_accuracy +``` + +Prediction: + +``` +# The configuration file used for prediction must match the training +python3 tools/infer_rec.py -c configs/rec/rec_d28_can.yml -o Architecture.Head.attdecoder.is_train=False Global.infer_img='./doc/imgs_hme/hme_01.jpg' Global.pretrained_model=./rec_d28_can_train/best_accuracy +``` + + +## 4. Inference and Deployment + + +### 4.1 Python Inference +First, the model saved during the RobustScanner text recognition training process is converted into an inference model. you can use the following command to convert: + +``` +python3 tools/export_model.py -c configs/rec/rec_d28_can.yml -o Global.save_inference_dir=./inference/rec_d28_can/ Architecture.Head.attdecoder.is_train=False +``` + +For RobustScanner text recognition model inference, the following commands can be executed: + +``` +python3 tools/infer/predict_rec.py --image_dir="./doc/imgs_hme/hme_01.jpg" --rec_algorithm="CAN" --rec_batch_num=1 --rec_model_dir="./inference/rec_d28_can/" --rec_image_shape="1, 132, 519" --rec_char_dict_path="./ppocr/utils/dict/latex_symbol_dict.txt" +``` + + +### 4.2 C++ Inference + +Not supported + + +### 4.3 Serving + +Not supported + + +### 4.4 More + +Not supported + + +## 5. FAQ + + +## Citation + +```bibtex +@misc{https://doi.org/10.48550/arxiv.2207.11463, + doi = {10.48550/ARXIV.2207.11463}, + url = {https://arxiv.org/abs/2207.11463}, + author = {Li, Bohan and Yuan, Ye and Liang, Dingkang and Liu, Xiao and Ji, Zhilong and Bai, Jinfeng and Liu, Wenyu and Bai, Xiang}, + keywords = {Computer Vision and Pattern Recognition (cs.CV), Artificial Intelligence (cs.AI), FOS: Computer and information sciences, FOS: Computer and information sciences}, + title = {When Counting Meets HMER: Counting-Aware Network for Handwritten Mathematical Expression Recognition}, + publisher = {arXiv}, + year = {2022}, + copyright = {arXiv.org perpetual, non-exclusive license} +} +``` diff --git a/doc/imgs_hme/hme_00.jpg b/doc/imgs_hme/hme_00.jpg new file mode 100644 index 0000000000000000000000000000000000000000..66ff27db266b5d4fa05d8acd95ba881bb8a1aec0 Binary files /dev/null and b/doc/imgs_hme/hme_00.jpg differ diff --git a/doc/imgs_hme/hme_01.jpg b/doc/imgs_hme/hme_01.jpg new file mode 100644 index 0000000000000000000000000000000000000000..68b7f09fc2f330ee523ded27a14486b3c92763cb Binary files /dev/null and b/doc/imgs_hme/hme_01.jpg differ diff --git a/doc/imgs_hme/hme_02.jpg b/doc/imgs_hme/hme_02.jpg new file mode 100644 index 0000000000000000000000000000000000000000..ecc760f5382bfe3d94de6141379f6a5a196e8430 Binary files /dev/null and b/doc/imgs_hme/hme_02.jpg differ diff --git a/ppocr/data/__init__.py b/ppocr/data/__init__.py index b602a346dbe4b0d45af287f25f05ead0f62daf44..1f3de63de72f44a7daf8d641b2d5c6bc5568df8f 100644 --- a/ppocr/data/__init__.py +++ b/ppocr/data/__init__.py @@ -37,6 +37,7 @@ from ppocr.data.simple_dataset import SimpleDataSet from ppocr.data.lmdb_dataset import LMDBDataSet, LMDBDataSetSR from ppocr.data.pgnet_dataset import PGDataSet from ppocr.data.pubtab_dataset import PubTabDataSet +from ppocr.data.hmer_dataset import HMERDataSet __all__ = ['build_dataloader', 'transform', 'create_operators'] @@ -55,7 +56,7 @@ def build_dataloader(config, mode, device, logger, seed=None): support_dict = [ 'SimpleDataSet', 'LMDBDataSet', 'PGDataSet', 'PubTabDataSet', - 'LMDBDataSetSR' + 'LMDBDataSetSR', 'HMERDataSet' ] module_name = config[mode]['dataset']['name'] assert module_name in support_dict, Exception( diff --git a/ppocr/data/collate_fn.py b/ppocr/data/collate_fn.py index 0da6060f042a0e60cdf211d8bc13aede32d5930a..fec1e895ff033cd991ed3a927ad7e158989bab45 100644 --- a/ppocr/data/collate_fn.py +++ b/ppocr/data/collate_fn.py @@ -70,3 +70,49 @@ class SSLRotateCollate(object): def __call__(self, batch): output = [np.concatenate(d, axis=0) for d in zip(*batch)] return output + + +class DyMaskCollator(object): + """ + batch: [ + image [batch_size, channel, maxHinbatch, maxWinbatch] + image_mask [batch_size, channel, maxHinbatch, maxWinbatch] + label [batch_size, maxLabelLen] + label_mask [batch_size, maxLabelLen] + ... + ] + """ + + def __call__(self, batch): + max_width, max_height, max_length = 0, 0, 0 + bs, channel = len(batch), batch[0][0].shape[0] + proper_items = [] + for item in batch: + if item[0].shape[1] * max_width > 1600 * 320 or item[0].shape[ + 2] * max_height > 1600 * 320: + continue + max_height = item[0].shape[1] if item[0].shape[ + 1] > max_height else max_height + max_width = item[0].shape[2] if item[0].shape[ + 2] > max_width else max_width + max_length = item[1].shape[0] if item[1].shape[ + 0] > max_length else max_length + proper_items.append(item) + + images, image_masks = np.zeros( + (len(proper_items), channel, max_height, max_width), + dtype='float32'), np.zeros( + (len(proper_items), 1, max_height, max_width), dtype='float32') + labels, label_masks = np.zeros( + (len(proper_items), max_length), dtype='int64'), np.zeros( + (len(proper_items), max_length), dtype='int64') + + for i in range(len(proper_items)): + _, h, w = proper_items[i][0].shape + images[i][:, :h, :w] = proper_items[i][0] + image_masks[i][:, :h, :w] = 1 + l = proper_items[i][1].shape[0] + labels[i][:l] = proper_items[i][1] + label_masks[i][:l] = 1 + + return images, image_masks, labels, label_masks diff --git a/ppocr/data/hmer_dataset.py b/ppocr/data/hmer_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..d5d92f264b2604d68929bf08e8514c1a87b9198b --- /dev/null +++ b/ppocr/data/hmer_dataset.py @@ -0,0 +1,99 @@ +# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os, json, random, traceback +import numpy as np + +from PIL import Image +from paddle.io import Dataset + +from .imaug import transform, create_operators + + +class HMERDataSet(Dataset): + def __init__(self, config, mode, logger, seed=None): + super(HMERDataSet, self).__init__() + + self.logger = logger + self.seed = seed + self.mode = mode + + global_config = config['Global'] + dataset_config = config[mode]['dataset'] + self.data_dir = config[mode]['dataset']['data_dir'] + + label_file_list = dataset_config['label_file_list'] + data_source_num = len(label_file_list) + ratio_list = dataset_config.get("ratio_list", [1.0]) + + self.data_lines, self.labels = self.get_image_info_list(label_file_list, + ratio_list) + self.data_idx_order_list = list(range(len(self.data_lines))) + if self.mode == "train" and self.do_shuffle: + self.shuffle_data_random() + + if isinstance(ratio_list, (float, int)): + ratio_list = [float(ratio_list)] * int(data_source_num) + + assert len( + ratio_list + ) == data_source_num, "The length of ratio_list should be the same as the file_list." + + self.ops = create_operators(dataset_config['transforms'], global_config) + self.need_reset = True in [x < 1 for x in ratio_list] + + def get_image_info_list(self, file_list, ratio_list): + if isinstance(file_list, str): + file_list = [file_list] + labels = {} + for idx, file in enumerate(file_list): + with open(file, "r") as f: + lines = json.load(f) + labels.update(lines) + data_lines = [name for name in labels.keys()] + return data_lines, labels + + def shuffle_data_random(self): + random.seed(self.seed) + random.shuffle(self.data_lines) + return + + def __len__(self): + return len(self.data_idx_order_list) + + def __getitem__(self, idx): + file_idx = self.data_idx_order_list[idx] + data_name = self.data_lines[file_idx] + try: + file_name = data_name + '.jpg' + img_path = os.path.join(self.data_dir, file_name) + if not os.path.exists(img_path): + raise Exception("{} does not exist!".format(img_path)) + with open(img_path, 'rb') as f: + img = f.read() + + label = self.labels.get(data_name).split() + label = np.array([int(item) for item in label]) + + data = {'image': img, 'label': label} + outs = transform(data, self.ops) + except: + self.logger.error( + "When parsing line {}, error happened with msg: {}".format( + file_name, traceback.format_exc())) + outs = None + if outs is None: + # during evaluation, we should fix the idx to get same results for many times of evaluation. + rnd_idx = np.random.randint(self.__len__()) + return self.__getitem__(rnd_idx) + return outs diff --git a/ppocr/data/imaug/__init__.py b/ppocr/data/imaug/__init__.py index 93d97446d44070b9c10064fbe10b0b5e05628a6a..a64092286f4cd36ddac5503fb00ebaa43b083a83 100644 --- a/ppocr/data/imaug/__init__.py +++ b/ppocr/data/imaug/__init__.py @@ -27,7 +27,7 @@ from .make_pse_gt import MakePseGt from .rec_img_aug import BaseDataAugmentation, RecAug, RecConAug, RecResizeImg, ClsResizeImg, \ SRNRecResizeImg, GrayRecResizeImg, SARRecResizeImg, PRENResizeImg, \ ABINetRecResizeImg, SVTRRecResizeImg, ABINetRecAug, VLRecResizeImg, SPINRecResizeImg, RobustScannerRecResizeImg, \ - RFLRecResizeImg + RFLRecResizeImg, GrayImageChannelFormat from .ssl_img_aug import SSLRotateResize from .randaugment import RandAugment from .copy_paste import CopyPaste diff --git a/ppocr/data/imaug/rec_img_aug.py b/ppocr/data/imaug/rec_img_aug.py index e22153bdeab06565feed79715633172a275aecc7..bc7fbc6046c0c773f2f666eed8a661f1f6d8f89a 100644 --- a/ppocr/data/imaug/rec_img_aug.py +++ b/ppocr/data/imaug/rec_img_aug.py @@ -465,6 +465,36 @@ class RobustScannerRecResizeImg(object): return data +class GrayImageChannelFormat(object): + """ + format gray scale image's channel: (3,h,w) -> (1,h,w) + Args: + normalize: True/False + when True convert image dynamic range [0,255]->[0,1] + inverse: inverse gray image + """ + + def __init__(self, normalize=True, inverse=False, **kwargs): + self.normalize = normalize + self.inverse = inverse + + def __call__(self, data): + img = data['image'] + img_single_channel = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) + img_single_channel = np.expand_dims(img_single_channel, 0) + + if self.normalize: + img_single_channel = img_single_channel / 255.0 + + if self.inverse: + data['image'] = np.abs(img_single_channel - 1).astype('float32') + else: + data['image'] = img_single_channel.astype('float32') + + data['src_image'] = img + return data + + def resize_norm_img_sar(img, image_shape, width_downsample_ratio=0.25): imgC, imgH, imgW_min, imgW_max = image_shape h = img.shape[0] diff --git a/ppocr/losses/__init__.py b/ppocr/losses/__init__.py index 6abaa408b3f6995a0b4c377206e8a1551b48c56b..6a34dd1c87a92db3618bfaeed6aeee3cb29b261a 100755 --- a/ppocr/losses/__init__.py +++ b/ppocr/losses/__init__.py @@ -40,6 +40,7 @@ from .rec_multi_loss import MultiLoss from .rec_vl_loss import VLLoss from .rec_spin_att_loss import SPINAttentionLoss from .rec_rfl_loss import RFLLoss +from .rec_can_loss import CANLoss # cls loss from .cls_loss import ClsLoss @@ -71,7 +72,7 @@ def build_loss(config): 'CELoss', 'TableAttentionLoss', 'SARLoss', 'AsterLoss', 'SDMGRLoss', 'VQASerTokenLayoutLMLoss', 'LossFromOutput', 'PRENLoss', 'MultiLoss', 'TableMasterLoss', 'SPINAttentionLoss', 'VLLoss', 'StrokeFocusLoss', - 'SLALoss', 'CTLoss', 'RFLLoss', 'DRRGLoss' + 'SLALoss', 'CTLoss', 'RFLLoss', 'DRRGLoss', 'CANLoss' ] config = copy.deepcopy(config) module_name = config.pop('name') diff --git a/ppocr/losses/rec_can_loss.py b/ppocr/losses/rec_can_loss.py new file mode 100644 index 0000000000000000000000000000000000000000..a6c655e0e64ac17aabc853b7e70d31ad0cf059a6 --- /dev/null +++ b/ppocr/losses/rec_can_loss.py @@ -0,0 +1,61 @@ +import paddle +import paddle.nn as nn +import numpy as np + + +class CANLoss(nn.Layer): + ''' + CANLoss is consist of two part: + word_average_loss: average accuracy of the symbol + counting_loss: counting loss of every symbol + ''' + + def __init__(self): + super(CANLoss, self).__init__() + + self.use_label_mask = False + self.out_channel = 111 + self.cross = nn.CrossEntropyLoss( + reduction='none') if self.use_label_mask else nn.CrossEntropyLoss() + self.counting_loss = nn.SmoothL1Loss(reduction='mean') + self.ratio = 16 + + def forward(self, preds, batch): + word_probs = preds[0] + counting_preds = preds[1] + counting_preds1 = preds[2] + counting_preds2 = preds[3] + labels = batch[2] + labels_mask = batch[3] + counting_labels = gen_counting_label(labels, self.out_channel, True) + counting_loss = self.counting_loss(counting_preds1, counting_labels) + self.counting_loss(counting_preds2, counting_labels) \ + + self.counting_loss(counting_preds, counting_labels) + + word_loss = self.cross( + paddle.reshape(word_probs, [-1, word_probs.shape[-1]]), + paddle.reshape(labels, [-1])) + word_average_loss = paddle.sum( + paddle.reshape(word_loss * labels_mask, [-1])) / ( + paddle.sum(labels_mask) + 1e-10 + ) if self.use_label_mask else word_loss + loss = word_average_loss + counting_loss + return {'loss': loss} + + +def gen_counting_label(labels, channel, tag): + b, t = labels.shape + counting_labels = np.zeros([b, channel]) + + if tag: + ignore = [0, 1, 107, 108, 109, 110] + else: + ignore = [] + for i in range(b): + for j in range(t): + k = labels[i][j] + if k in ignore: + continue + else: + counting_labels[i][k] += 1 + counting_labels = paddle.to_tensor(counting_labels, dtype='float32') + return counting_labels diff --git a/ppocr/metrics/__init__.py b/ppocr/metrics/__init__.py index 20aea8b5995a49306d427bc427048c9df8d0923d..5e840a194adc2683e92c308f232dc869df34de8e 100644 --- a/ppocr/metrics/__init__.py +++ b/ppocr/metrics/__init__.py @@ -22,7 +22,7 @@ import copy __all__ = ["build_metric"] from .det_metric import DetMetric, DetFCEMetric -from .rec_metric import RecMetric, CNTMetric +from .rec_metric import RecMetric, CNTMetric, CANMetric from .cls_metric import ClsMetric from .e2e_metric import E2EMetric from .distillation_metric import DistillationMetric @@ -38,7 +38,7 @@ def build_metric(config): support_dict = [ "DetMetric", "DetFCEMetric", "RecMetric", "ClsMetric", "E2EMetric", "DistillationMetric", "TableMetric", 'KIEMetric', 'VQASerTokenMetric', - 'VQAReTokenMetric', 'SRMetric', 'CTMetric', 'CNTMetric' + 'VQAReTokenMetric', 'SRMetric', 'CTMetric', 'CNTMetric', 'CANMetric' ] config = copy.deepcopy(config) diff --git a/ppocr/metrics/rec_metric.py b/ppocr/metrics/rec_metric.py index 4758e71d0930261044841a6a820308a04391fc0b..305b913c72da5842b6654f1fc9b27e6e2b46b436 100644 --- a/ppocr/metrics/rec_metric.py +++ b/ppocr/metrics/rec_metric.py @@ -13,6 +13,9 @@ # limitations under the License. from rapidfuzz.distance import Levenshtein +from difflib import SequenceMatcher + +import numpy as np import string @@ -106,3 +109,71 @@ class CNTMetric(object): def reset(self): self.correct_num = 0 self.all_num = 0 + + +class CANMetric(object): + def __init__(self, main_indicator='exp_rate', **kwargs): + self.main_indicator = main_indicator + self.word_right = [] + self.exp_right = [] + self.word_total_length = 0 + self.exp_total_num = 0 + self.word_rate = 0 + self.exp_rate = 0 + self.reset() + self.epoch_reset() + + def __call__(self, preds, batch, **kwargs): + for k, v in kwargs.items(): + epoch_reset = v + if epoch_reset: + self.epoch_reset() + word_probs = preds + word_label, word_label_mask = batch + line_right = 0 + if word_probs is not None: + word_pred = word_probs.argmax(2) + word_pred = word_pred.cpu().detach().numpy() + word_scores = [ + SequenceMatcher( + None, + s1[:int(np.sum(s3))], + s2[:int(np.sum(s3))], + autojunk=False).ratio() * ( + len(s1[:int(np.sum(s3))]) + len(s2[:int(np.sum(s3))])) / + len(s1[:int(np.sum(s3))]) / 2 + for s1, s2, s3 in zip(word_label, word_pred, word_label_mask) + ] + batch_size = len(word_scores) + for i in range(batch_size): + if word_scores[i] == 1: + line_right += 1 + self.word_rate = np.mean(word_scores) #float + self.exp_rate = line_right / batch_size #float + exp_length, word_length = word_label.shape[:2] + self.word_right.append(self.word_rate * word_length) + self.exp_right.append(self.exp_rate * exp_length) + self.word_total_length = self.word_total_length + word_length + self.exp_total_num = self.exp_total_num + exp_length + + def get_metric(self): + """ + return { + 'word_rate': 0, + "exp_rate": 0, + } + """ + cur_word_rate = sum(self.word_right) / self.word_total_length + cur_exp_rate = sum(self.exp_right) / self.exp_total_num + self.reset() + return {'word_rate': cur_word_rate, "exp_rate": cur_exp_rate} + + def reset(self): + self.word_rate = 0 + self.exp_rate = 0 + + def epoch_reset(self): + self.word_right = [] + self.exp_right = [] + self.word_total_length = 0 + self.exp_total_num = 0 diff --git a/ppocr/modeling/backbones/__init__.py b/ppocr/modeling/backbones/__init__.py index 84892fa9c7fd61838e984690b17931f367ab0585..e2c2e9c4a4ed526b36d512d824ae8a8a701c17bc 100755 --- a/ppocr/modeling/backbones/__init__.py +++ b/ppocr/modeling/backbones/__init__.py @@ -43,10 +43,12 @@ def build_backbone(config, model_type): from .rec_svtrnet import SVTRNet from .rec_vitstr import ViTSTR from .rec_resnet_rfl import ResNetRFL + from .rec_densenet import DenseNet support_dict = [ 'MobileNetV1Enhance', 'MobileNetV3', 'ResNet', 'ResNetFPN', 'MTB', 'ResNet31', 'ResNet45', 'ResNet_ASTER', 'MicroNet', - 'EfficientNetb3_PREN', 'SVTRNet', 'ViTSTR', 'ResNet32', 'ResNetRFL' + 'EfficientNetb3_PREN', 'SVTRNet', 'ViTSTR', 'ResNet32', 'ResNetRFL', + 'DenseNet' ] elif model_type == 'e2e': from .e2e_resnet_vd_pg import ResNet diff --git a/ppocr/modeling/backbones/rec_densenet.py b/ppocr/modeling/backbones/rec_densenet.py new file mode 100644 index 0000000000000000000000000000000000000000..d3391d4086857bc02bca20c1d0b9279172e6bc44 --- /dev/null +++ b/ppocr/modeling/backbones/rec_densenet.py @@ -0,0 +1,135 @@ +import math +import paddle +import paddle.nn as nn +import paddle.nn.functional as F + + +class Bottleneck(nn.Layer): + ''' + ratio: 16 + growthRate: 24 + reduction: 0.5 + bottleneck: True + use_dropout: True + ''' + + def __init__(self, nChannels, growthRate, use_dropout): + super(Bottleneck, self).__init__() + interChannels = 4 * growthRate + self.bn1 = nn.BatchNorm2D(interChannels) + self.conv1 = nn.Conv2D( + nChannels, interChannels, kernel_size=1, + bias_attr=None) # Xavier initialization + self.bn2 = nn.BatchNorm2D(growthRate) + self.conv2 = nn.Conv2D( + interChannels, growthRate, kernel_size=3, padding=1, + bias_attr=None) # Xavier initialization + self.use_dropout = use_dropout + self.dropout = nn.Dropout(p=0.2) + + def forward(self, x): + out = F.relu(self.bn1(self.conv1(x))) + if self.use_dropout: + out = self.dropout(out) + out = F.relu(self.bn2(self.conv2(out))) + if self.use_dropout: + out = self.dropout(out) + out = paddle.concat([x, out], 1) + return out + + +class SingleLayer(nn.Layer): + def __init__(self, nChannels, growthRate, use_dropout): + super(SingleLayer, self).__init__() + self.bn1 = nn.BatchNorm2D(nChannels) + self.conv1 = nn.Conv2D( + nChannels, growthRate, kernel_size=3, padding=1, bias_attr=False) + + self.use_dropout = use_dropout + self.dropout = nn.Dropout(p=0.2) + + def forward(self, x): + out = self.conv1(F.relu(x)) + if self.use_dropout: + out = self.dropout(out) + + out = paddle.concat([x, out], 1) + return out + + +class Transition(nn.Layer): + def __init__(self, nChannels, out_channels, use_dropout): + super(Transition, self).__init__() + self.bn1 = nn.BatchNorm2D(out_channels) + self.conv1 = nn.Conv2D( + nChannels, out_channels, kernel_size=1, bias_attr=False) + self.use_dropout = use_dropout + self.dropout = nn.Dropout(p=0.2) + + def forward(self, x): + out = F.relu(self.bn1(self.conv1(x))) + if self.use_dropout: + out = self.dropout(out) + out = F.avg_pool2d(out, 2, ceil_mode=True, exclusive=False) + return out + + +class DenseNet(nn.Layer): + def __init__(self, growthRate, reduction, bottleneck, use_dropout, + input_channel, **kwargs): + super(DenseNet, self).__init__() + ''' + ratio: 16 + growthRate: 24 + reduction: 0.5 + ''' + nDenseBlocks = 16 + nChannels = 2 * growthRate + + self.conv1 = nn.Conv2D( + input_channel, + nChannels, + kernel_size=7, + padding=3, + stride=2, + bias_attr=False) + self.dense1 = self._make_dense(nChannels, growthRate, nDenseBlocks, + bottleneck, use_dropout) + nChannels += nDenseBlocks * growthRate + out_channels = int(math.floor(nChannels * reduction)) + self.trans1 = Transition(nChannels, out_channels, use_dropout) + + nChannels = out_channels + self.dense2 = self._make_dense(nChannels, growthRate, nDenseBlocks, + bottleneck, use_dropout) + nChannels += nDenseBlocks * growthRate + out_channels = int(math.floor(nChannels * reduction)) + self.trans2 = Transition(nChannels, out_channels, use_dropout) + + nChannels = out_channels + self.dense3 = self._make_dense(nChannels, growthRate, nDenseBlocks, + bottleneck, use_dropout) + self.out_channels = out_channels + + def _make_dense(self, nChannels, growthRate, nDenseBlocks, bottleneck, + use_dropout): + layers = [] + for i in range(int(nDenseBlocks)): + if bottleneck: + layers.append(Bottleneck(nChannels, growthRate, use_dropout)) + else: + layers.append(SingleLayer(nChannels, growthRate, use_dropout)) + nChannels += growthRate + return nn.Sequential(*layers) + + def forward(self, inputs): + x, x_m, y = inputs + out = self.conv1(x) + out = F.relu(out) + out = F.max_pool2d(out, 2, ceil_mode=True) + out = self.dense1(out) + out = self.trans1(out) + out = self.dense2(out) + out = self.trans2(out) + out = self.dense3(out) + return out, x_m, y diff --git a/ppocr/modeling/heads/__init__.py b/ppocr/modeling/heads/__init__.py index 63002140c5be4bd7e32b56995c6410ecc8a0fa36..fdf5a8a96d587c66ac9b26ff4e1264ab9c8173f3 100755 --- a/ppocr/modeling/heads/__init__.py +++ b/ppocr/modeling/heads/__init__.py @@ -40,6 +40,7 @@ def build_head(config): from .rec_robustscanner_head import RobustScannerHead from .rec_visionlan_head import VLHead from .rec_rfl_head import RFLHead + from .rec_can_head import CANHead # cls head from .cls_head import ClsHead @@ -56,7 +57,7 @@ def build_head(config): 'TableAttentionHead', 'SARHead', 'AsterHead', 'SDMGRHead', 'PRENHead', 'MultiHead', 'ABINetHead', 'TableMasterHead', 'SPINAttentionHead', 'VLHead', 'SLAHead', 'RobustScannerHead', 'CT_Head', 'RFLHead', - 'DRRGHead' + 'DRRGHead', 'CANHead' ] #table head diff --git a/ppocr/modeling/heads/rec_can_head.py b/ppocr/modeling/heads/rec_can_head.py new file mode 100644 index 0000000000000000000000000000000000000000..afd78ee9d277dde7838b96938f4b2221b690c5d2 --- /dev/null +++ b/ppocr/modeling/heads/rec_can_head.py @@ -0,0 +1,294 @@ +from turtle import forward +import paddle.nn as nn +import paddle +import math +''' +Counting Module +''' + + +class ChannelAtt(nn.Layer): + def __init__(self, channel, reduction): + super(ChannelAtt, self).__init__() + self.avg_pool = nn.AdaptiveAvgPool2D(1) + + self.fc = nn.Sequential( + nn.Linear(channel, channel // reduction), + nn.ReLU(), nn.Linear(channel // reduction, channel), nn.Sigmoid()) + + def forward(self, x): + b, c, _, _ = x.shape + y = paddle.reshape(self.avg_pool(x), [b, c]) + y = paddle.reshape(self.fc(y), [b, c, 1, 1]) + return x * y + + +class CountingDecoder(nn.Layer): + def __init__(self, in_channel, out_channel, kernel_size): + super(CountingDecoder, self).__init__() + self.in_channel = in_channel + self.out_channel = out_channel + + self.trans_layer = nn.Sequential( + nn.Conv2D( + self.in_channel, + 512, + kernel_size=kernel_size, + padding=kernel_size // 2, + bias_attr=False), + nn.BatchNorm2D(512)) + + self.channel_att = ChannelAtt(512, 16) + + self.pred_layer = nn.Sequential( + nn.Conv2D( + 512, self.out_channel, kernel_size=1, bias_attr=False), + nn.Sigmoid()) + + def forward(self, x, mask): + b, _, h, w = x.shape + x = self.trans_layer(x) + x = self.channel_att(x) + x = self.pred_layer(x) + + if mask is not None: + x = x * mask + x = paddle.reshape(x, [b, self.out_channel, -1]) + x1 = paddle.sum(x, axis=-1) + + return x1, paddle.reshape(x, [b, self.out_channel, h, w]) + + +''' +Attention Decoder +''' + + +class PositionEmbeddingSine(nn.Layer): + def __init__(self, + num_pos_feats=64, + temperature=10000, + normalize=False, + scale=None): + super().__init__() + self.num_pos_feats = num_pos_feats + self.temperature = temperature + self.normalize = normalize + if scale is not None and normalize is False: + raise ValueError("normalize should be True if scale is passed") + if scale is None: + scale = 2 * math.pi + self.scale = scale + + def forward(self, x, mask): + y_embed = paddle.cumsum(mask, 1, dtype='float32') + x_embed = paddle.cumsum(mask, 2, dtype='float32') + + if self.normalize: + eps = 1e-6 + y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale + x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale + dim_t = paddle.arange(self.num_pos_feats, dtype='float32') + dim_d = paddle.expand(paddle.to_tensor(2), dim_t.shape) + dim_t = self.temperature**(2 * (dim_t / dim_d).astype('int64') / + self.num_pos_feats) + + pos_x = paddle.unsqueeze(x_embed, [3]) / dim_t + pos_y = paddle.unsqueeze(y_embed, [3]) / dim_t + + pos_x = paddle.flatten( + paddle.stack( + [ + paddle.sin(pos_x[:, :, :, 0::2]), + paddle.cos(pos_x[:, :, :, 1::2]) + ], + axis=4), + 3) + pos_y = paddle.flatten( + paddle.stack( + [ + paddle.sin(pos_y[:, :, :, 0::2]), + paddle.cos(pos_y[:, :, :, 1::2]) + ], + axis=4), + 3) + + pos = paddle.transpose( + paddle.concat( + [pos_y, pos_x], axis=3), [0, 3, 1, 2]) + + return pos + + +class AttDecoder(nn.Layer): + def __init__(self, ratio, is_train, input_size, hidden_size, + encoder_out_channel, dropout, dropout_ratio, word_num, + counting_decoder_out_channel, attention): + super(AttDecoder, self).__init__() + self.input_size = input_size + self.hidden_size = hidden_size + self.out_channel = encoder_out_channel + self.attention_dim = attention['attention_dim'] + self.dropout_prob = dropout + self.ratio = ratio + self.word_num = word_num + + self.counting_num = counting_decoder_out_channel + self.is_train = is_train + + self.init_weight = nn.Linear(self.out_channel, self.hidden_size) + self.embedding = nn.Embedding(self.word_num, self.input_size) + self.word_input_gru = nn.GRUCell(self.input_size, self.hidden_size) + self.word_attention = Attention(hidden_size, attention['attention_dim']) + + self.encoder_feature_conv = nn.Conv2D( + self.out_channel, + self.attention_dim, + kernel_size=attention['word_conv_kernel'], + padding=attention['word_conv_kernel'] // 2) + + self.word_state_weight = nn.Linear(self.hidden_size, self.hidden_size) + self.word_embedding_weight = nn.Linear(self.input_size, + self.hidden_size) + self.word_context_weight = nn.Linear(self.out_channel, self.hidden_size) + self.counting_context_weight = nn.Linear(self.counting_num, + self.hidden_size) + self.word_convert = nn.Linear(self.hidden_size, self.word_num) + + if dropout: + self.dropout = nn.Dropout(dropout_ratio) + + def forward(self, cnn_features, labels, counting_preds, images_mask): + if self.is_train: + _, num_steps = labels.shape + else: + num_steps = 36 + + batch_size, _, height, width = cnn_features.shape + images_mask = images_mask[:, :, ::self.ratio, ::self.ratio] + + word_probs = paddle.zeros((batch_size, num_steps, self.word_num)) + word_alpha_sum = paddle.zeros((batch_size, 1, height, width)) + + hidden = self.init_hidden(cnn_features, images_mask) + counting_context_weighted = self.counting_context_weight(counting_preds) + cnn_features_trans = self.encoder_feature_conv(cnn_features) + + position_embedding = PositionEmbeddingSine(256, normalize=True) + pos = position_embedding(cnn_features_trans, images_mask[:, 0, :, :]) + + cnn_features_trans = cnn_features_trans + pos + + word = paddle.ones([batch_size, 1], dtype='int64') # init word as sos + word = word.squeeze(axis=1) + for i in range(num_steps): + word_embedding = self.embedding(word) + _, hidden = self.word_input_gru(word_embedding, hidden) + word_context_vec, _, word_alpha_sum = self.word_attention( + cnn_features, cnn_features_trans, hidden, word_alpha_sum, + images_mask) + + current_state = self.word_state_weight(hidden) + word_weighted_embedding = self.word_embedding_weight(word_embedding) + word_context_weighted = self.word_context_weight(word_context_vec) + + if self.dropout_prob: + word_out_state = self.dropout( + current_state + word_weighted_embedding + + word_context_weighted + counting_context_weighted) + else: + word_out_state = current_state + word_weighted_embedding + word_context_weighted + counting_context_weighted + + word_prob = self.word_convert(word_out_state) + word_probs[:, i] = word_prob + + if self.is_train: + word = labels[:, i] + else: + word = word_prob.argmax(1) + word = paddle.multiply( + word, labels[:, i] + ) # labels are oneslike tensor in infer/predict mode + + return word_probs + + def init_hidden(self, features, feature_mask): + average = paddle.sum(paddle.sum(features * feature_mask, axis=-1), + axis=-1) / paddle.sum( + (paddle.sum(feature_mask, axis=-1)), axis=-1) + average = self.init_weight(average) + return paddle.tanh(average) + + +''' +Attention Module +''' + + +class Attention(nn.Layer): + def __init__(self, hidden_size, attention_dim): + super(Attention, self).__init__() + self.hidden = hidden_size + self.attention_dim = attention_dim + self.hidden_weight = nn.Linear(self.hidden, self.attention_dim) + self.attention_conv = nn.Conv2D( + 1, 512, kernel_size=11, padding=5, bias_attr=False) + self.attention_weight = nn.Linear( + 512, self.attention_dim, bias_attr=False) + self.alpha_convert = nn.Linear(self.attention_dim, 1) + + def forward(self, + cnn_features, + cnn_features_trans, + hidden, + alpha_sum, + image_mask=None): + query = self.hidden_weight(hidden) + alpha_sum_trans = self.attention_conv(alpha_sum) + coverage_alpha = self.attention_weight( + paddle.transpose(alpha_sum_trans, [0, 2, 3, 1])) + alpha_score = paddle.tanh( + paddle.unsqueeze(query, [1, 2]) + coverage_alpha + paddle.transpose( + cnn_features_trans, [0, 2, 3, 1])) + energy = self.alpha_convert(alpha_score) + energy = energy - energy.max() + energy_exp = paddle.exp(paddle.squeeze(energy, -1)) + + if image_mask is not None: + energy_exp = energy_exp * paddle.squeeze(image_mask, 1) + alpha = energy_exp / (paddle.unsqueeze( + paddle.sum(paddle.sum(energy_exp, -1), -1), [1, 2]) + 1e-10) + alpha_sum = paddle.unsqueeze(alpha, 1) + alpha_sum + context_vector = paddle.sum( + paddle.sum((paddle.unsqueeze(alpha, 1) * cnn_features), -1), -1) + + return context_vector, alpha, alpha_sum + + +class CANHead(nn.Layer): + def __init__(self, in_channel, out_channel, ratio, attdecoder, **kwargs): + super(CANHead, self).__init__() + + self.in_channel = in_channel + self.out_channel = out_channel + + self.counting_decoder1 = CountingDecoder(self.in_channel, + self.out_channel, 3) # mscm + self.counting_decoder2 = CountingDecoder(self.in_channel, + self.out_channel, 5) + + self.decoder = AttDecoder(ratio, **attdecoder) + + self.ratio = ratio + + def forward(self, inputs, targets=None): + cnn_features, images_mask, labels = inputs + + counting_mask = images_mask[:, :, ::self.ratio, ::self.ratio] + counting_preds1, _ = self.counting_decoder1(cnn_features, counting_mask) + counting_preds2, _ = self.counting_decoder2(cnn_features, counting_mask) + counting_preds = (counting_preds1 + counting_preds2) / 2 + + word_probs = self.decoder(cnn_features, labels, counting_preds, + images_mask) + return word_probs, counting_preds, counting_preds1, counting_preds2 diff --git a/ppocr/optimizer/learning_rate.py b/ppocr/optimizer/learning_rate.py index 7d45109b4857871f52764c64d6d32e5322fc7c57..be52a918458d64f0ae15b52ebf511e5068184f59 100644 --- a/ppocr/optimizer/learning_rate.py +++ b/ppocr/optimizer/learning_rate.py @@ -18,7 +18,7 @@ from __future__ import print_function from __future__ import unicode_literals from paddle.optimizer import lr -from .lr_scheduler import CyclicalCosineDecay, OneCycleDecay +from .lr_scheduler import CyclicalCosineDecay, OneCycleDecay, TwoStepCosineDecay class Linear(object): @@ -386,3 +386,44 @@ class MultiStepDecay(object): end_lr=self.learning_rate, last_epoch=self.last_epoch) return learning_rate + + +class TwoStepCosine(object): + """ + Cosine learning rate decay + lr = 0.05 * (math.cos(epoch * (math.pi / epochs)) + 1) + Args: + lr(float): initial learning rate + step_each_epoch(int): steps each epoch + epochs(int): total training epochs + last_epoch (int, optional): The index of last epoch. Can be set to restart training. Default: -1, means initial learning rate. + """ + + def __init__(self, + learning_rate, + step_each_epoch, + epochs, + warmup_epoch=0, + last_epoch=-1, + **kwargs): + super(TwoStepCosine, self).__init__() + self.learning_rate = learning_rate + self.T_max1 = step_each_epoch * 200 + self.T_max2 = step_each_epoch * epochs + self.last_epoch = last_epoch + self.warmup_epoch = round(warmup_epoch * step_each_epoch) + + def __call__(self): + learning_rate = TwoStepCosineDecay( + learning_rate=self.learning_rate, + T_max1=self.T_max1, + T_max2=self.T_max2, + last_epoch=self.last_epoch) + if self.warmup_epoch > 0: + learning_rate = lr.LinearWarmup( + learning_rate=learning_rate, + warmup_steps=self.warmup_epoch, + start_lr=0.0, + end_lr=self.learning_rate, + last_epoch=self.last_epoch) + return learning_rate diff --git a/ppocr/optimizer/lr_scheduler.py b/ppocr/optimizer/lr_scheduler.py index f62f1f3b0adbd8df0e03a66faa4565f2f7df28bc..cd09367e2ab8a649e3c375698f5b182eb5c3ff7a 100644 --- a/ppocr/optimizer/lr_scheduler.py +++ b/ppocr/optimizer/lr_scheduler.py @@ -160,3 +160,63 @@ class OneCycleDecay(LRScheduler): start_step = phase['end_step'] return computed_lr + + +class TwoStepCosineDecay(LRScheduler): + def __init__(self, + learning_rate, + T_max1, + T_max2, + eta_min=0, + last_epoch=-1, + verbose=False): + if not isinstance(T_max1, int): + raise TypeError( + "The type of 'T_max1' in 'CosineAnnealingDecay' must be 'int', but received %s." + % type(T_max1)) + if not isinstance(T_max2, int): + raise TypeError( + "The type of 'T_max2' in 'CosineAnnealingDecay' must be 'int', but received %s." + % type(T_max2)) + if not isinstance(eta_min, (float, int)): + raise TypeError( + "The type of 'eta_min' in 'CosineAnnealingDecay' must be 'float, int', but received %s." + % type(eta_min)) + assert T_max1 > 0 and isinstance( + T_max1, int), " 'T_max1' must be a positive integer." + assert T_max2 > 0 and isinstance( + T_max2, int), " 'T_max1' must be a positive integer." + self.T_max1 = T_max1 + self.T_max2 = T_max2 + self.eta_min = float(eta_min) + super(TwoStepCosineDecay, self).__init__(learning_rate, last_epoch, + verbose) + + def get_lr(self): + + if self.last_epoch <= self.T_max1: + if self.last_epoch == 0: + return self.base_lr + elif (self.last_epoch - 1 - self.T_max1) % (2 * self.T_max1) == 0: + return self.last_lr + (self.base_lr - self.eta_min) * ( + 1 - math.cos(math.pi / self.T_max1)) / 2 + + return (1 + math.cos(math.pi * self.last_epoch / self.T_max1)) / ( + 1 + math.cos(math.pi * (self.last_epoch - 1) / self.T_max1)) * ( + self.last_lr - self.eta_min) + self.eta_min + else: + if (self.last_epoch - 1 - self.T_max2) % (2 * self.T_max2) == 0: + return self.last_lr + (self.base_lr - self.eta_min) * ( + 1 - math.cos(math.pi / self.T_max2)) / 2 + + return (1 + math.cos(math.pi * self.last_epoch / self.T_max2)) / ( + 1 + math.cos(math.pi * (self.last_epoch - 1) / self.T_max2)) * ( + self.last_lr - self.eta_min) + self.eta_min + + def _get_closed_form_lr(self): + if self.last_epoch <= self.T_max1: + return self.eta_min + (self.base_lr - self.eta_min) * (1 + math.cos( + math.pi * self.last_epoch / self.T_max1)) / 2 + else: + return self.eta_min + (self.base_lr - self.eta_min) * (1 + math.cos( + math.pi * self.last_epoch / self.T_max2)) / 2 diff --git a/ppocr/postprocess/__init__.py b/ppocr/postprocess/__init__.py index 3a09030b25461029d9160699dc591eaedab9e0db..e86a7ea70c5271482774d5c7f14a8d649184804f 100644 --- a/ppocr/postprocess/__init__.py +++ b/ppocr/postprocess/__init__.py @@ -37,6 +37,7 @@ from .table_postprocess import TableMasterLabelDecode, TableLabelDecode from .picodet_postprocess import PicoDetPostProcess from .ct_postprocess import CTPostProcess from .drrg_postprocess import DRRGPostprocess +from .rec_postprocess import SeqLabelDecode def build_post_process(config, global_config=None): @@ -51,7 +52,7 @@ def build_post_process(config, global_config=None): 'TableMasterLabelDecode', 'SPINLabelDecode', 'DistillationSerPostProcess', 'DistillationRePostProcess', 'VLLabelDecode', 'PicoDetPostProcess', 'CTPostProcess', - 'RFLLabelDecode', 'DRRGPostprocess' + 'RFLLabelDecode', 'DRRGPostprocess', 'SeqLabelDecode' ] if config['name'] == 'PSEPostProcess': diff --git a/ppocr/postprocess/rec_postprocess.py b/ppocr/postprocess/rec_postprocess.py index 59b5254e480e7c52aca4ce648c379a280683db4f..4d88c278ec8ce02931478cc62492ab1a6f360594 100644 --- a/ppocr/postprocess/rec_postprocess.py +++ b/ppocr/postprocess/rec_postprocess.py @@ -897,3 +897,36 @@ class VLLabelDecode(BaseRecLabelDecode): return text label = self.decode(label) return text, label + + +class SeqLabelDecode(BaseRecLabelDecode): + """ Convert between latex-symbol and symbol-index """ + + def __init__(self, character_dict_path=None, use_space_char=False, + **kwargs): + super(SeqLabelDecode, self).__init__(character_dict_path, + use_space_char) + + def decode(self, text_index, preds_prob=None): + result_list = [] + batch_size = len(text_index) + for batch_idx in range(batch_size): + seq_end = text_index[batch_idx].argmin(0) + idx_list = text_index[batch_idx][:seq_end].tolist() + symbol_list = [self.character[idx] for idx in idx_list] + probs = [] + if preds_prob is not None: + probs = preds_prob[batch_idx][:len(symbol_list)].tolist() + + result_list.append([' '.join(symbol_list), probs]) + return result_list + + def __call__(self, preds, label=None, *args, **kwargs): + pred_prob, _, _, _ = preds + preds_idx = pred_prob.argmax(axis=2) + + text = self.decode(preds_idx) + if label is None: + return text + label = self.decode(label) + return text, label diff --git a/ppocr/utils/dict/latex_symbol_dict.txt b/ppocr/utils/dict/latex_symbol_dict.txt new file mode 100644 index 0000000000000000000000000000000000000000..b43f1fa8b904e3107eb450f6d7332aec6b5b81e2 --- /dev/null +++ b/ppocr/utils/dict/latex_symbol_dict.txt @@ -0,0 +1,111 @@ +eos +sos +! +' +( +) ++ +, +- +. +/ +0 +1 +2 +3 +4 +5 +6 +7 +8 +9 +< += +> +A +B +C +E +F +G +H +I +L +M +N +P +R +S +T +V +X +Y +[ +\Delta +\alpha +\beta +\cdot +\cdots +\cos +\div +\exists +\forall +\frac +\gamma +\geq +\in +\infty +\int +\lambda +\ldots +\leq +\lim +\log +\mu +\neq +\phi +\pi +\pm +\prime +\rightarrow +\sigma +\sin +\sqrt +\sum +\tan +\theta +\times +] +a +b +c +d +e +f +g +h +i +j +k +l +m +n +o +p +q +r +s +t +u +v +w +x +y +z +\{ +| +\} +{ +} +^ +_ \ No newline at end of file diff --git a/test_tipc/configs/rec_d28_can/rec_d28_can.yml b/test_tipc/configs/rec_d28_can/rec_d28_can.yml new file mode 100644 index 0000000000000000000000000000000000000000..aeaccb6b0d245cc4be479d17597d0029647db574 --- /dev/null +++ b/test_tipc/configs/rec_d28_can/rec_d28_can.yml @@ -0,0 +1,114 @@ +Global: + use_gpu: True + epoch_num: 240 + log_smooth_window: 20 + print_batch_step: 10 + save_model_dir: ./output/rec/can/ + save_epoch_step: 1 + # evaluation is run every 1105 iterations + eval_batch_step: [0, 1105] + cal_metric_during_train: True + pretrained_model: ./output/rec/can/CAN + checkpoints: ./output/rec/can/CAN + save_inference_dir: ./inference/rec_d28_can/ + use_visualdl: False + infer_img: doc/imgs_hme/hme_01.jpeg + # for data or label process + character_dict_path: ppocr/utils/dict/latex_symbol_dict.txt + max_text_length: 36 + infer_mode: False + use_space_char: False + save_res_path: ./output/rec/predicts_can.txt + +Optimizer: + name: Momentum + momentum: 0.9 + clip_norm_global: 100.0 + lr: + name: TwoStepCosine + learning_rate: 0.01 + warmup_epoch: 1 + weight_decay: 0.0001 + +Architecture: + model_type: rec + algorithm: CAN + in_channels: 1 + Transform: + Backbone: + name: DenseNet + growthRate: 24 + reduction: 0.5 + bottleneck: True + use_dropout: True + input_channel: 1 + + Head: + name: CANHead + in_channel: 684 + out_channel: 111 + max_text_length: 36 + ratio: 16 + attdecoder: + is_train: True + input_size: 256 + hidden_size: 256 + encoder_out_channel: 684 + dropout: True + dropout_ratio: 0.5 + word_num: 111 + counting_decoder_out_channel: 111 + attention: + attention_dim: 512 + word_conv_kernel: 1 + +Loss: + name: CANLoss + +PostProcess: + name: SeqLabelDecode + character: 111 + +Metric: + name: CANMetric + main_indicator: exp_rate + +Train: + dataset: + name: HMERDataSet + data_dir: ./train_data/CROHME/training/images/ + transforms: + - DecodeImage: + channel_first: False + - GrayImageChannelFormat: + normalize: True + inverse: True + - KeepKeys: + keep_keys: ['image', 'label'] + label_file_list: ["./train_data/CROHME/training/labels.json"] + loader: + shuffle: True + batch_size_per_card: 2 + drop_last: True + num_workers: 1 + collate_fn: DyMaskCollator + +Eval: + dataset: + name: HMERDataSet + data_dir: ./train_data/CROHME/evaluation/images/ + transforms: + - DecodeImage: + channel_first: False + - GrayImageChannelFormat: + normalize: True + inverse: True + - KeepKeys: + keep_keys: ['image', 'label'] + label_file_list: ["./train_data/CROHME/evaluation/labels.json"] + loader: + shuffle: False + drop_last: False + batch_size_per_card: 1 + num_workers: 4 + collate_fn: DyMaskCollator diff --git a/test_tipc/configs/rec_d28_can/train_infer_python.txt b/test_tipc/configs/rec_d28_can/train_infer_python.txt new file mode 100644 index 0000000000000000000000000000000000000000..be50c59805b2f1b8a0baa5a61b9d5cd4a21b68df --- /dev/null +++ b/test_tipc/configs/rec_d28_can/train_infer_python.txt @@ -0,0 +1,53 @@ +===========================train_params=========================== +model_name:rec_d28_can +python:python +gpu_list:0|0,1 +Global.use_gpu:True|True +Global.auto_cast:null +Global.epoch_num:lite_train_lite_infer=2|whole_train_whole_infer=240 +Global.save_model_dir:./output/ +Train.loader.batch_size_per_card:lite_train_lite_infer=2|whole_train_whole_infer=8 +Global.pretrained_model:null +train_model_name:latest +train_infer_img_dir:./doc/imgs_hme +null:null +## +trainer:norm_train +norm_train:tools/train.py -c test_tipc/configs/rec_d28_can/rec_d28_can.yml -o +pact_train:null +fpgm_train:null +distill_train:null +null:null +null:null +## +===========================eval_params=========================== +eval:tools/eval.py -c test_tipc/configs/rec_d28_can/rec_d28_can.yml -o +null:null +## +===========================infer_params=========================== +Global.save_inference_dir:./output/ +Global.checkpoints: +norm_export:tools/export_model.py -c test_tipc/configs/rec_d28_can/rec_d28_can.yml -o +quant_export:null +fpgm_export:null +distill_export:null +export1:null +export2:null +## +train_model:./inference/rec_d28_can_train/best_accuracy +infer_export:tools/export_model.py -c test_tipc/configs/rec_d28_can/rec_d28_can.yml -o +infer_quant:False +inference:tools/infer/predict_rec.py --rec_char_dict_path=./ppocr/utils/dict/latex_symbol_dict.txt --rec_image_shape="1,100,100" --rec_algorithm="CAN" +--use_gpu:True|False +--enable_mkldnn:False +--cpu_threads:6 +--rec_batch_num:1 +--use_tensorrt:False +--precision:fp32 +--rec_model_dir:./output/ +--image_dir:./doc/imgs_hme +--save_log_path:./test/output/ +--benchmark:True +null:null +===========================infer_benchmark_params========================== +random_infer_input:[{float32,[1,100,100]}] diff --git a/tools/eval.py b/tools/eval.py index 3d1d3813d33e251ec83a9729383fe772bc4cc225..21f4d94d5e4ed560b8775c8827ffdbbd00355218 100755 --- a/tools/eval.py +++ b/tools/eval.py @@ -74,7 +74,9 @@ def main(): config['Architecture']["Head"]['out_channels'] = char_num model = build_model(config['Architecture']) - extra_input_models = ["SRN", "NRTR", "SAR", "SEED", "SVTR", "VisionLAN", "RobustScanner"] + extra_input_models = [ + "SRN", "NRTR", "SAR", "SEED", "SVTR", "VisionLAN", "RobustScanner" + ] extra_input = False if config['Architecture']['algorithm'] == 'Distillation': for key in config['Architecture']["Models"]: @@ -83,7 +85,10 @@ def main(): else: extra_input = config['Architecture']['algorithm'] in extra_input_models if "model_type" in config['Architecture'].keys(): - model_type = config['Architecture']['model_type'] + if config['Architecture']['algorithm'] == 'CAN': + model_type = 'can' + else: + model_type = config['Architecture']['model_type'] else: model_type = None @@ -92,7 +97,7 @@ def main(): # amp use_amp = config["Global"].get("use_amp", False) amp_level = config["Global"].get("amp_level", 'O2') - amp_custom_black_list = config['Global'].get('amp_custom_black_list',[]) + amp_custom_black_list = config['Global'].get('amp_custom_black_list', []) if use_amp: AMP_RELATED_FLAGS_SETTING = { 'FLAGS_cudnn_batchnorm_spatial_persistent': 1, @@ -120,7 +125,8 @@ def main(): # start eval metric = program.eval(model, valid_dataloader, post_process_class, - eval_class, model_type, extra_input, scaler, amp_level, amp_custom_black_list) + eval_class, model_type, extra_input, scaler, + amp_level, amp_custom_black_list) logger.info('metric eval ***************') for k, v in metric.items(): logger.info('{}:{}'.format(k, v)) diff --git a/tools/export_model.py b/tools/export_model.py index 52f05bfcba0487d1c5abd0f7d7221c2feca40ae9..4b90fcae435619a53a3def8cc4dc46b4e2963bff 100755 --- a/tools/export_model.py +++ b/tools/export_model.py @@ -123,6 +123,17 @@ def export_single_model(model, ] ] model = to_static(model, input_spec=other_shape) + elif arch_config["algorithm"] == "CAN": + other_shape = [[ + paddle.static.InputSpec( + shape=[None, 1, None, None], + dtype="float32"), paddle.static.InputSpec( + shape=[None, 1, None, None], dtype="float32"), + paddle.static.InputSpec( + shape=[None, arch_config['Head']['max_text_length']], + dtype="int64") + ]] + model = to_static(model, input_spec=other_shape) elif arch_config["algorithm"] in ["LayoutLM", "LayoutLMv2", "LayoutXLM"]: input_spec = [ paddle.static.InputSpec( diff --git a/tools/infer/predict_rec.py b/tools/infer/predict_rec.py index bffeb25534068691fee21bbf946cc7cda7326d27..c1604798e7506561b4f03bd401f4588c28b6cb1a 100755 --- a/tools/infer/predict_rec.py +++ b/tools/infer/predict_rec.py @@ -108,6 +108,13 @@ class TextRecognizer(object): } elif self.rec_algorithm == "PREN": postprocess_params = {'name': 'PRENLabelDecode'} + elif self.rec_algorithm == "CAN": + self.inverse = args.rec_image_inverse + postprocess_params = { + 'name': 'SeqLabelDecode', + "character_dict_path": args.rec_char_dict_path, + "use_space_char": args.use_space_char + } self.postprocess_op = build_post_process(postprocess_params) self.predictor, self.input_tensor, self.output_tensors, self.config = \ utility.create_predictor(args, 'rec', logger) @@ -351,6 +358,30 @@ class TextRecognizer(object): return resized_image + def norm_img_can(self, img, image_shape): + + img = cv2.cvtColor( + img, cv2.COLOR_BGR2GRAY) # CAN only predict gray scale image + + if self.inverse: + img = 255 - img + + if self.rec_image_shape[0] == 1: + h, w = img.shape + _, imgH, imgW = self.rec_image_shape + if h < imgH or w < imgW: + padding_h = max(imgH - h, 0) + padding_w = max(imgW - w, 0) + img_padded = np.pad(img, ((0, padding_h), (0, padding_w)), + 'constant', + constant_values=(255)) + img = img_padded + + img = np.expand_dims(img, 0) / 255.0 # h,w,c -> c,h,w + img = img.astype('float32') + + return img + def __call__(self, img_list): img_num = len(img_list) # Calculate the aspect ratio of all text bars @@ -430,6 +461,17 @@ class TextRecognizer(object): word_positions = np.array(range(0, 40)).astype('int64') word_positions = np.expand_dims(word_positions, axis=0) word_positions_list.append(word_positions) + elif self.rec_algorithm == "CAN": + norm_img = self.norm_img_can(img_list[indices[ino]], + max_wh_ratio) + norm_img = norm_img[np.newaxis, :] + norm_img_batch.append(norm_img) + norm_image_mask = np.ones(norm_img.shape, dtype='float32') + word_label = np.ones([1, 36], dtype='int64') + norm_img_mask_batch = [] + word_label_list = [] + norm_img_mask_batch.append(norm_image_mask) + word_label_list.append(word_label) else: norm_img = self.resize_norm_img(img_list[indices[ino]], max_wh_ratio) @@ -527,6 +569,33 @@ class TextRecognizer(object): if self.benchmark: self.autolog.times.stamp() preds = outputs[0] + elif self.rec_algorithm == "CAN": + norm_img_mask_batch = np.concatenate(norm_img_mask_batch) + word_label_list = np.concatenate(word_label_list) + inputs = [norm_img_batch, norm_img_mask_batch, word_label_list] + if self.use_onnx: + input_dict = {} + input_dict[self.input_tensor.name] = norm_img_batch + outputs = self.predictor.run(self.output_tensors, + input_dict) + preds = outputs + else: + input_names = self.predictor.get_input_names() + input_tensor = [] + for i in range(len(input_names)): + input_tensor_i = self.predictor.get_input_handle( + input_names[i]) + input_tensor_i.copy_from_cpu(inputs[i]) + input_tensor.append(input_tensor_i) + self.input_tensor = input_tensor + self.predictor.run() + outputs = [] + for output_tensor in self.output_tensors: + output = output_tensor.copy_to_cpu() + outputs.append(output) + if self.benchmark: + self.autolog.times.stamp() + preds = outputs else: if self.use_onnx: input_dict = {} diff --git a/tools/infer/utility.py b/tools/infer/utility.py index f6a44e35a5b303d6ed30bf8057a62409aa690fef..34cad2590f2904f79709530acf841033c89088e0 100644 --- a/tools/infer/utility.py +++ b/tools/infer/utility.py @@ -84,6 +84,7 @@ def init_args(): # params for text recognizer parser.add_argument("--rec_algorithm", type=str, default='SVTR_LCNet') parser.add_argument("--rec_model_dir", type=str) + parser.add_argument("--rec_image_inverse", type=str2bool, default=True) parser.add_argument("--rec_image_shape", type=str, default="3, 48, 320") parser.add_argument("--rec_batch_num", type=int, default=6) parser.add_argument("--max_text_length", type=int, default=25) diff --git a/tools/infer_rec.py b/tools/infer_rec.py index cb8a6ec3050c878669f539b8b11d97214f5eec20..29aab9b57853b16bf615c893c30351a403270b57 100755 --- a/tools/infer_rec.py +++ b/tools/infer_rec.py @@ -141,6 +141,11 @@ def main(): paddle.to_tensor(valid_ratio), paddle.to_tensor(word_positons), ] + if config['Architecture']['algorithm'] == "CAN": + image_mask = paddle.ones( + (np.expand_dims( + batch[0], axis=0).shape), dtype='float32') + label = paddle.ones((1, 36), dtype='int64') images = np.expand_dims(batch[0], axis=0) images = paddle.to_tensor(images) if config['Architecture']['algorithm'] == "SRN": @@ -149,6 +154,8 @@ def main(): preds = model(images, img_metas) elif config['Architecture']['algorithm'] == "RobustScanner": preds = model(images, img_metas) + elif config['Architecture']['algorithm'] == "CAN": + preds = model([images, image_mask, label]) else: preds = model(images) post_result = post_process_class(preds) diff --git a/tools/program.py b/tools/program.py index 5d2bd5bfb034940e3bec802b5e7041c8e82a9271..c491247a64697774cca73e209b714f7560c5fcdd 100755 --- a/tools/program.py +++ b/tools/program.py @@ -273,6 +273,8 @@ def train(config, preds = model(images, data=batch[1:]) elif model_type in ["kie"]: preds = model(batch) + elif algorithm in ['CAN']: + preds = model(batch[:3]) else: preds = model(images) preds = to_float32(preds) @@ -286,6 +288,8 @@ def train(config, preds = model(images, data=batch[1:]) elif model_type in ["kie", 'sr']: preds = model(batch) + elif algorithm in ['CAN']: + preds = model(batch[:3]) else: preds = model(images) loss = loss_class(preds, batch) @@ -302,6 +306,9 @@ def train(config, elif model_type in ['table']: post_result = post_process_class(preds, batch) eval_class(post_result, batch) + elif algorithm in ['CAN']: + model_type = 'can' + eval_class(preds[0], batch[2:], epoch_reset=(idx == 0)) else: if config['Loss']['name'] in ['MultiLoss', 'MultiLoss_v2' ]: # for multi head loss @@ -496,6 +503,8 @@ def eval(model, preds = model(images, data=batch[1:]) elif model_type in ["kie"]: preds = model(batch) + elif model_type in ['can']: + preds = model(batch[:3]) elif model_type in ['sr']: preds = model(batch) sr_img = preds["sr_img"] @@ -508,6 +517,8 @@ def eval(model, preds = model(images, data=batch[1:]) elif model_type in ["kie"]: preds = model(batch) + elif model_type in ['can']: + preds = model(batch[:3]) elif model_type in ['sr']: preds = model(batch) sr_img = preds["sr_img"] @@ -532,6 +543,8 @@ def eval(model, eval_class(post_result, batch_numpy) elif model_type in ['sr']: eval_class(preds, batch_numpy) + elif model_type in ['can']: + eval_class(preds[0], batch_numpy[2:], epoch_reset=False) else: post_result = post_process_class(preds, batch_numpy[1]) eval_class(post_result, batch_numpy) @@ -629,7 +642,7 @@ def preprocess(is_train=False): 'CLS', 'PGNet', 'Distillation', 'NRTR', 'TableAttn', 'SAR', 'PSE', 'SEED', 'SDMGR', 'LayoutXLM', 'LayoutLM', 'LayoutLMv2', 'PREN', 'FCE', 'SVTR', 'ViTSTR', 'ABINet', 'DB++', 'TableMaster', 'SPIN', 'VisionLAN', - 'Gestalt', 'SLANet', 'RobustScanner', 'CT', 'RFL', 'DRRG' + 'Gestalt', 'SLANet', 'RobustScanner', 'CT', 'RFL', 'DRRG', 'CAN' ] if use_xpu: