提交 f3f473d3 编写于 作者: D dorren

update CAN model

上级 8babfc86
repos:
- repo: https://github.com/PaddlePaddle/mirrors-yapf.git - repo: https://github.com/PaddlePaddle/mirrors-yapf.git
sha: 0d79c0c469bab64f7229c9aca2b1186ef47f0e37 rev: 0d79c0c469bab64f7229c9aca2b1186ef47f0e37
hooks: hooks:
- id: yapf - id: yapf
files: \.py$ files: \.py$
- repo: https://github.com/pre-commit/pre-commit-hooks - repo: https://github.com/pre-commit/pre-commit-hooks
sha: a11d9314b22d8f8c7556443875b731ef05965464 rev: a11d9314b22d8f8c7556443875b731ef05965464
hooks: hooks:
- id: check-merge-conflict - id: check-merge-conflict
- id: check-symlinks - id: check-symlinks
...@@ -15,7 +16,7 @@ ...@@ -15,7 +16,7 @@
- id: trailing-whitespace - id: trailing-whitespace
files: \.md$ files: \.md$
- repo: https://github.com/Lucas-C/pre-commit-hooks - repo: https://github.com/Lucas-C/pre-commit-hooks
sha: v1.0.1 rev: v1.0.1
hooks: hooks:
- id: forbid-crlf - id: forbid-crlf
files: \.md$ files: \.md$
......
Global:
use_gpu: True
epoch_num: 240
log_smooth_window: 20
print_batch_step: 10
save_model_dir: ./output/rec/can/
save_epoch_step: 1
# evaluation is run every 1105 iterations
eval_batch_step: [0, 1105]
cal_metric_during_train: True
pretrained_model: ./output/rec/can/CAN
checkpoints: ./output/rec/can/CAN
save_inference_dir: ./inference/rec_d28_can/
use_visualdl: False
infer_img: doc/imgs_hme/hme_01.jpeg
# for data or label process
character_dict_path: ppocr/utils/dict/latex_symbol_dict.txt
max_text_length: 36
infer_mode: False
use_space_char: False
save_res_path: ./output/rec/predicts_can.txt
Optimizer:
name: Momentum
momentum: 0.9
clip_norm_global: 100.0
lr:
name: TwoStepCosine
learning_rate: 0.01
warmup_epoch: 1
weight_decay: 0.0001
Architecture:
model_type: rec
algorithm: CAN
in_channels: 1
Transform:
Backbone:
name: DenseNet
growthRate: 24
reduction: 0.5
bottleneck: True
use_dropout: True
input_channel: 1
Head:
name: CANHead
in_channel: 684
out_channel: 111
max_text_length: 36
ratio: 16
attdecoder:
is_train: True
input_size: 256
hidden_size: 256
encoder_out_channel: 684
dropout: True
dropout_ratio: 0.5
word_num: 111
counting_decoder_out_channel: 111
attention:
attention_dim: 512
word_conv_kernel: 1
Loss:
name: CANLoss
PostProcess:
name: SeqLabelDecode
character: 111
Metric:
name: CANMetric
main_indicator: exp_rate
Train:
dataset:
name: HMERDataSet
data_dir: ./train_data/CROHME/training/images/
transforms:
- DecodeImage:
channel_first: False
- GrayImageChannelFormat:
normalize: True
inverse: True
- KeepKeys:
keep_keys: ['image', 'label']
label_file_list: ["./train_data/CROHME/training/labels.json"]
loader:
shuffle: True
batch_size_per_card: 2
drop_last: True
num_workers: 1
collate_fn: DyMaskCollator
Eval:
dataset:
name: HMERDataSet
data_dir: ./train_data/CROHME/evaluation/images/
transforms:
- DecodeImage:
channel_first: False
- GrayImageChannelFormat:
normalize: True
inverse: True
- KeepKeys:
keep_keys: ['image', 'label']
label_file_list: ["./train_data/CROHME/evaluation/labels.json"]
loader:
shuffle: False
drop_last: False
batch_size_per_card: 1
num_workers: 4
collate_fn: DyMaskCollator
# 手写数学公式识别算法-ABINet
- [1. 算法简介](#1)
- [2. 环境配置](#2)
- [3. 模型训练、评估、预测](#3)
- [3.1 训练](#3-1)
- [3.2 评估](#3-2)
- [3.3 预测](#3-3)
- [4. 推理部署](#4)
- [4.1 Python推理](#4-1)
- [4.2 C++推理](#4-2)
- [4.3 Serving服务化部署](#4-3)
- [4.4 更多推理部署](#4-4)
- [5. FAQ](#5)
<a name="1"></a>
## 1. 算法简介
论文信息:
> [When Counting Meets HMER: Counting-Aware Network for Handwritten Mathematical Expression Recognition](https://arxiv.org/abs/2207.11463)
> Bohan Li, Ye Yuan, Dingkang Liang, Xiao Liu, Zhilong Ji, Jinfeng Bai, Wenyu Liu, Xiang Bai
> ECCV, 2022
<a name="model"></a>
`CAN`使用CROHME手写公式数据集进行训练,在对应测试集上的精度如下:
|模型 |骨干网络|配置文件|ExpRate|下载链接|
| ----- | ----- | ----- | ----- | ----- |
|CAN|DenseNet|[rec_d28_can.yml](../../configs/rec/rec_d28_can.yml)|51.72|[训练模型](https://paddleocr.bj.bcebos.com/rec_r45_abinet_train.tar)|
<a name="2"></a>
## 2. 环境配置
请先参考[《运行环境准备》](./environment.md)配置PaddleOCR运行环境,参考[《项目克隆》](./clone.md)克隆项目代码。
<a name="3"></a>
## 3. 模型训练、评估、预测
<a name="3-1"></a>
### 3.1 模型训练
请参考[文本识别训练教程](./recognition.md)。PaddleOCR对代码进行了模块化,训练`CAN`识别模型时需要**更换配置文件**`CAN`[配置文件](../../configs/rec/rec_d28_can.yml)
#### 启动训练
具体地,在完成数据准备后,便可以启动训练,训练命令如下:
```shell
#单卡训练(训练周期长,不建议)
python3 tools/train.py -c configs/rec/rec_d28_can.yml
#多卡训练,通过--gpus参数指定卡号
python3 -m paddle.distributed.launch --gpus '0,1,2,3' tools/train.py -c configs/rec/rec_d28_can.yml
```
**注意:**
- 我们提供的数据集,即`CROHME数据集`将手写公式存储为黑底白字的格式,若您自行准备的数据集与之相反,即以白底黑字模式存储,请在训练时做出如下修改
```
python3 tools/train.py -c configs/rec/rec_d28_can.yml
-o Train.dataset.transforms.GrayImageChannelFormat.inverse=False
```
#
<a name="3-2"></a>
### 3.2 评估
可下载已训练完成的[模型文件](#model),使用如下命令进行评估:
```shell
# 注意将pretrained_model的路径设置为本地路径。
python3 -m paddle.distributed.launch --gpus '0' tools/eval.py -c configs/rec/rec_d28_can.yml -o Global.pretrained_model=./rec_d28_can_train/best_accuracy
```
<a name="3-3"></a>
### 3.3 预测
使用如下命令进行单张图片预测:
```shell
# 注意将pretrained_model的路径设置为本地路径。
python3 tools/infer_rec.py -c configs/rec/rec_d28_can.yml -o Architecture.Head.attdecoder.is_train=False Global.infer_img='./doc/imgs_hme/hme_01.jpg' Global.pretrained_model=./rec_d28_can_train/best_accuracy
# 预测文件夹下所有图像时,可修改infer_img为文件夹,如 Global.infer_img='./doc/imgs_hme/'。
```
<a name="4"></a>
## 4. 推理部署
<a name="4-1"></a>
### 4.1 Python推理
首先将训练得到best模型,转换成inference model。这里以训练完成的模型为例([模型下载地址](https://paddleocr.bj.bcebos.com/rec_d28_can_train.tar) ),可以使用如下命令进行转换:
```shell
# 注意将pretrained_model的路径设置为本地路径。
python3 tools/export_model.py -c configs/rec/rec_d28_can.yml -o Global.save_inference_dir=./inference/rec_d28_can/ Architecture.Head.attdecoder.is_train=False
# 目前的静态图模型默认的输出长度最大为36,如果您需要预测更长的序列,请在导出模型时指定其输出序列为合适的值,例如 Architecture.Head.max_text_length=72
```
**注意:**
- 如果您是在自己的数据集上训练的模型,并且调整了字典文件,请注意修改配置文件中的`character_dict_path`是否是所需要的字典文件。
- 如果您修改了训练时的输入大小,请修改`tools/export_model.py`文件中的对应ABINet的`infer_shape`
转换成功后,在目录下有三个文件:
```
/inference/rec_d28_can/
├── inference.pdiparams # 识别inference模型的参数文件
├── inference.pdiparams.info # 识别inference模型的参数信息,可忽略
└── inference.pdmodel # 识别inference模型的program文件
```
执行如下命令进行模型推理:
```shell
python3 tools/infer/predict_rec.py --image_dir="./doc/imgs_hme/hme_01.jpg" --rec_algorithm="CAN" --rec_batch_num=1 --rec_model_dir="./inference/rec_d28_can/" --rec_char_dict_path="./ppocr/utils/dict/latex_symbol_dict.txt"
# 预测文件夹下所有图像时,可修改image_dir为文件夹,如 --image_dir='./doc/imgs_hme/'。
# 如果您需要在白底黑字的图片上进行预测,请设置 --rec_image_inverse=False
```
![测试图片样例](../imgs_hme/hme_00.jpg)
执行命令后,上面图像的预测结果(识别的文本)会打印到屏幕上,示例如下:
```shell
Predicts of ./doc/imgs_hme/hme_03.jpg:['x _ { k } x x _ { k } + y _ { k } y x _ { k }', []]
```
**注意**
- 需要注意预测图像为**黑底白字**,即手写公式部分为白色,背景为黑色的图片。
- 在推理时需要设置参数`rec_char_dict_path`指定字典,如果您修改了字典,请修改该参数为您的字典文件。
- 如果您修改了预处理方法,需修改`tools/infer/predict_rec.py`中CAN的预处理为您的预处理方法。
<a name="4-2"></a>
### 4.2 C++推理部署
由于C++预处理后处理还未支持ABINet,所以暂未支持
<a name="4-3"></a>
### 4.3 Serving服务化部署
暂不支持
<a name="4-4"></a>
### 4.4 更多推理部署
暂不支持
<a name="5"></a>
## 5. FAQ
1. CROHME数据集来自于[CAN源repo](https://github.com/LBH1024/CAN)
## 引用
```bibtex
@misc{https://doi.org/10.48550/arxiv.2207.11463,
doi = {10.48550/ARXIV.2207.11463},
url = {https://arxiv.org/abs/2207.11463},
author = {Li, Bohan and Yuan, Ye and Liang, Dingkang and Liu, Xiao and Ji, Zhilong and Bai, Jinfeng and Liu, Wenyu and Bai, Xiang},
keywords = {Computer Vision and Pattern Recognition (cs.CV), Artificial Intelligence (cs.AI), FOS: Computer and information sciences, FOS: Computer and information sciences},
title = {When Counting Meets HMER: Counting-Aware Network for Handwritten Mathematical Expression Recognition},
publisher = {arXiv},
year = {2022},
copyright = {arXiv.org perpetual, non-exclusive license}
}
```
# RobustScanner
- [1. Introduction](#1)
- [2. Environment](#2)
- [3. Model Training / Evaluation / Prediction](#3)
- [3.1 Training](#3-1)
- [3.2 Evaluation](#3-2)
- [3.3 Prediction](#3-3)
- [4. Inference and Deployment](#4)
- [4.1 Python Inference](#4-1)
- [4.2 C++ Inference](#4-2)
- [4.3 Serving](#4-3)
- [4.4 More](#4-4)
- [5. FAQ](#5)
<a name="1"></a>
## 1. Introduction
Paper:
> [When Counting Meets HMER: Counting-Aware Network for Handwritten Mathematical Expression Recognition](https://arxiv.org/abs/2207.11463)
> Bohan Li, Ye Yuan, Dingkang Liang, Xiao Liu, Zhilong Ji, Jinfeng Bai, Wenyu Liu, Xiang Bai
> ECCV, 2022
Using CROHME handwrittem mathematical expression recognition datasets for training, and evaluating on its test sets, the algorithm reproduction effect is as follows:
|Model|Backbone|config|exprate|Download link|
| --- | --- | --- | --- | --- |
|CAN|DenseNet|[rec_d28_can.yml](../../configs/rec/rec_d28_can.yml)|51.72|coming soon|
<a name="2"></a>
## 2. Environment
Please refer to ["Environment Preparation"](./environment_en.md) to configure the PaddleOCR environment, and refer to ["Project Clone"](./clone_en.md) to clone the project code.
<a name="3"></a>
## 3. Model Training / Evaluation / Prediction
Please refer to [Text Recognition Tutorial](./recognition_en.md). PaddleOCR modularizes the code, and training different recognition models only requires **changing the configuration file**.
Training:
Specifically, after the data preparation is completed, the training can be started. The training command is as follows:
```
#Single GPU training (long training period, not recommended)
python3 tools/train.py -c configs/rec/rec_d28_can.yml
#Multi GPU training, specify the gpu number through the --gpus parameter
python3 -m paddle.distributed.launch --gpus '0,1,2,3' tools/train.py -c configs/rec/rec_d28_can.yml
```
Evaluation:
```
# GPU evaluation
python3 -m paddle.distributed.launch --gpus '0' tools/eval.py -c configs/rec/rec_d28_can.yml -o Global.pretrained_model=./rec_d28_can_train/best_accuracy
```
Prediction:
```
# The configuration file used for prediction must match the training
python3 tools/infer_rec.py -c configs/rec/rec_d28_can.yml -o Architecture.Head.attdecoder.is_train=False Global.infer_img='./doc/imgs_hme/hme_01.jpg' Global.pretrained_model=./rec_d28_can_train/best_accuracy
```
<a name="4"></a>
## 4. Inference and Deployment
<a name="4-1"></a>
### 4.1 Python Inference
First, the model saved during the RobustScanner text recognition training process is converted into an inference model. you can use the following command to convert:
```
python3 tools/export_model.py -c configs/rec/rec_d28_can.yml -o Global.save_inference_dir=./inference/rec_d28_can/ Architecture.Head.attdecoder.is_train=False
```
For RobustScanner text recognition model inference, the following commands can be executed:
```
python3 tools/infer/predict_rec.py --image_dir="./doc/imgs_hme/hme_01.jpg" --rec_algorithm="CAN" --rec_batch_num=1 --rec_model_dir="./inference/rec_d28_can/" --rec_image_shape="1, 132, 519" --rec_char_dict_path="./ppocr/utils/dict/latex_symbol_dict.txt"
```
<a name="4-2"></a>
### 4.2 C++ Inference
Not supported
<a name="4-3"></a>
### 4.3 Serving
Not supported
<a name="4-4"></a>
### 4.4 More
Not supported
<a name="5"></a>
## 5. FAQ
## Citation
```bibtex
@misc{https://doi.org/10.48550/arxiv.2207.11463,
doi = {10.48550/ARXIV.2207.11463},
url = {https://arxiv.org/abs/2207.11463},
author = {Li, Bohan and Yuan, Ye and Liang, Dingkang and Liu, Xiao and Ji, Zhilong and Bai, Jinfeng and Liu, Wenyu and Bai, Xiang},
keywords = {Computer Vision and Pattern Recognition (cs.CV), Artificial Intelligence (cs.AI), FOS: Computer and information sciences, FOS: Computer and information sciences},
title = {When Counting Meets HMER: Counting-Aware Network for Handwritten Mathematical Expression Recognition},
publisher = {arXiv},
year = {2022},
copyright = {arXiv.org perpetual, non-exclusive license}
}
```
...@@ -37,6 +37,7 @@ from ppocr.data.simple_dataset import SimpleDataSet ...@@ -37,6 +37,7 @@ from ppocr.data.simple_dataset import SimpleDataSet
from ppocr.data.lmdb_dataset import LMDBDataSet, LMDBDataSetSR from ppocr.data.lmdb_dataset import LMDBDataSet, LMDBDataSetSR
from ppocr.data.pgnet_dataset import PGDataSet from ppocr.data.pgnet_dataset import PGDataSet
from ppocr.data.pubtab_dataset import PubTabDataSet from ppocr.data.pubtab_dataset import PubTabDataSet
from ppocr.data.hmer_dataset import HMERDataSet
__all__ = ['build_dataloader', 'transform', 'create_operators'] __all__ = ['build_dataloader', 'transform', 'create_operators']
...@@ -55,7 +56,7 @@ def build_dataloader(config, mode, device, logger, seed=None): ...@@ -55,7 +56,7 @@ def build_dataloader(config, mode, device, logger, seed=None):
support_dict = [ support_dict = [
'SimpleDataSet', 'LMDBDataSet', 'PGDataSet', 'PubTabDataSet', 'SimpleDataSet', 'LMDBDataSet', 'PGDataSet', 'PubTabDataSet',
'LMDBDataSetSR' 'LMDBDataSetSR', 'HMERDataSet'
] ]
module_name = config[mode]['dataset']['name'] module_name = config[mode]['dataset']['name']
assert module_name in support_dict, Exception( assert module_name in support_dict, Exception(
......
...@@ -70,3 +70,49 @@ class SSLRotateCollate(object): ...@@ -70,3 +70,49 @@ class SSLRotateCollate(object):
def __call__(self, batch): def __call__(self, batch):
output = [np.concatenate(d, axis=0) for d in zip(*batch)] output = [np.concatenate(d, axis=0) for d in zip(*batch)]
return output return output
class DyMaskCollator(object):
"""
batch: [
image [batch_size, channel, maxHinbatch, maxWinbatch]
image_mask [batch_size, channel, maxHinbatch, maxWinbatch]
label [batch_size, maxLabelLen]
label_mask [batch_size, maxLabelLen]
...
]
"""
def __call__(self, batch):
max_width, max_height, max_length = 0, 0, 0
bs, channel = len(batch), batch[0][0].shape[0]
proper_items = []
for item in batch:
if item[0].shape[1] * max_width > 1600 * 320 or item[0].shape[
2] * max_height > 1600 * 320:
continue
max_height = item[0].shape[1] if item[0].shape[
1] > max_height else max_height
max_width = item[0].shape[2] if item[0].shape[
2] > max_width else max_width
max_length = item[1].shape[0] if item[1].shape[
0] > max_length else max_length
proper_items.append(item)
images, image_masks = np.zeros(
(len(proper_items), channel, max_height, max_width),
dtype='float32'), np.zeros(
(len(proper_items), 1, max_height, max_width), dtype='float32')
labels, label_masks = np.zeros(
(len(proper_items), max_length), dtype='int64'), np.zeros(
(len(proper_items), max_length), dtype='int64')
for i in range(len(proper_items)):
_, h, w = proper_items[i][0].shape
images[i][:, :h, :w] = proper_items[i][0]
image_masks[i][:, :h, :w] = 1
l = proper_items[i][1].shape[0]
labels[i][:l] = proper_items[i][1]
label_masks[i][:l] = 1
return images, image_masks, labels, label_masks
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os, json, random, traceback
import numpy as np
from PIL import Image
from paddle.io import Dataset
from .imaug import transform, create_operators
class HMERDataSet(Dataset):
def __init__(self, config, mode, logger, seed=None):
super(HMERDataSet, self).__init__()
self.logger = logger
self.seed = seed
self.mode = mode
global_config = config['Global']
dataset_config = config[mode]['dataset']
self.data_dir = config[mode]['dataset']['data_dir']
label_file_list = dataset_config['label_file_list']
data_source_num = len(label_file_list)
ratio_list = dataset_config.get("ratio_list", [1.0])
self.data_lines, self.labels = self.get_image_info_list(label_file_list,
ratio_list)
self.data_idx_order_list = list(range(len(self.data_lines)))
if self.mode == "train" and self.do_shuffle:
self.shuffle_data_random()
if isinstance(ratio_list, (float, int)):
ratio_list = [float(ratio_list)] * int(data_source_num)
assert len(
ratio_list
) == data_source_num, "The length of ratio_list should be the same as the file_list."
self.ops = create_operators(dataset_config['transforms'], global_config)
self.need_reset = True in [x < 1 for x in ratio_list]
def get_image_info_list(self, file_list, ratio_list):
if isinstance(file_list, str):
file_list = [file_list]
labels = {}
for idx, file in enumerate(file_list):
with open(file, "r") as f:
lines = json.load(f)
labels.update(lines)
data_lines = [name for name in labels.keys()]
return data_lines, labels
def shuffle_data_random(self):
random.seed(self.seed)
random.shuffle(self.data_lines)
return
def __len__(self):
return len(self.data_idx_order_list)
def __getitem__(self, idx):
file_idx = self.data_idx_order_list[idx]
data_name = self.data_lines[file_idx]
try:
file_name = data_name + '.jpg'
img_path = os.path.join(self.data_dir, file_name)
if not os.path.exists(img_path):
raise Exception("{} does not exist!".format(img_path))
with open(img_path, 'rb') as f:
img = f.read()
label = self.labels.get(data_name).split()
label = np.array([int(item) for item in label])
data = {'image': img, 'label': label}
outs = transform(data, self.ops)
except:
self.logger.error(
"When parsing line {}, error happened with msg: {}".format(
file_name, traceback.format_exc()))
outs = None
if outs is None:
# during evaluation, we should fix the idx to get same results for many times of evaluation.
rnd_idx = np.random.randint(self.__len__())
return self.__getitem__(rnd_idx)
return outs
...@@ -27,7 +27,7 @@ from .make_pse_gt import MakePseGt ...@@ -27,7 +27,7 @@ from .make_pse_gt import MakePseGt
from .rec_img_aug import BaseDataAugmentation, RecAug, RecConAug, RecResizeImg, ClsResizeImg, \ from .rec_img_aug import BaseDataAugmentation, RecAug, RecConAug, RecResizeImg, ClsResizeImg, \
SRNRecResizeImg, GrayRecResizeImg, SARRecResizeImg, PRENResizeImg, \ SRNRecResizeImg, GrayRecResizeImg, SARRecResizeImg, PRENResizeImg, \
ABINetRecResizeImg, SVTRRecResizeImg, ABINetRecAug, VLRecResizeImg, SPINRecResizeImg, RobustScannerRecResizeImg, \ ABINetRecResizeImg, SVTRRecResizeImg, ABINetRecAug, VLRecResizeImg, SPINRecResizeImg, RobustScannerRecResizeImg, \
RFLRecResizeImg RFLRecResizeImg, GrayImageChannelFormat
from .ssl_img_aug import SSLRotateResize from .ssl_img_aug import SSLRotateResize
from .randaugment import RandAugment from .randaugment import RandAugment
from .copy_paste import CopyPaste from .copy_paste import CopyPaste
......
...@@ -465,6 +465,36 @@ class RobustScannerRecResizeImg(object): ...@@ -465,6 +465,36 @@ class RobustScannerRecResizeImg(object):
return data return data
class GrayImageChannelFormat(object):
"""
format gray scale image's channel: (3,h,w) -> (1,h,w)
Args:
normalize: True/False
when True convert image dynamic range [0,255]->[0,1]
inverse: inverse gray image
"""
def __init__(self, normalize=True, inverse=False, **kwargs):
self.normalize = normalize
self.inverse = inverse
def __call__(self, data):
img = data['image']
img_single_channel = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
img_single_channel = np.expand_dims(img_single_channel, 0)
if self.normalize:
img_single_channel = img_single_channel / 255.0
if self.inverse:
data['image'] = np.abs(img_single_channel - 1).astype('float32')
else:
data['image'] = img_single_channel.astype('float32')
data['src_image'] = img
return data
def resize_norm_img_sar(img, image_shape, width_downsample_ratio=0.25): def resize_norm_img_sar(img, image_shape, width_downsample_ratio=0.25):
imgC, imgH, imgW_min, imgW_max = image_shape imgC, imgH, imgW_min, imgW_max = image_shape
h = img.shape[0] h = img.shape[0]
......
...@@ -40,6 +40,7 @@ from .rec_multi_loss import MultiLoss ...@@ -40,6 +40,7 @@ from .rec_multi_loss import MultiLoss
from .rec_vl_loss import VLLoss from .rec_vl_loss import VLLoss
from .rec_spin_att_loss import SPINAttentionLoss from .rec_spin_att_loss import SPINAttentionLoss
from .rec_rfl_loss import RFLLoss from .rec_rfl_loss import RFLLoss
from .rec_can_loss import CANLoss
# cls loss # cls loss
from .cls_loss import ClsLoss from .cls_loss import ClsLoss
...@@ -71,7 +72,7 @@ def build_loss(config): ...@@ -71,7 +72,7 @@ def build_loss(config):
'CELoss', 'TableAttentionLoss', 'SARLoss', 'AsterLoss', 'SDMGRLoss', 'CELoss', 'TableAttentionLoss', 'SARLoss', 'AsterLoss', 'SDMGRLoss',
'VQASerTokenLayoutLMLoss', 'LossFromOutput', 'PRENLoss', 'MultiLoss', 'VQASerTokenLayoutLMLoss', 'LossFromOutput', 'PRENLoss', 'MultiLoss',
'TableMasterLoss', 'SPINAttentionLoss', 'VLLoss', 'StrokeFocusLoss', 'TableMasterLoss', 'SPINAttentionLoss', 'VLLoss', 'StrokeFocusLoss',
'SLALoss', 'CTLoss', 'RFLLoss', 'DRRGLoss' 'SLALoss', 'CTLoss', 'RFLLoss', 'DRRGLoss', 'CANLoss'
] ]
config = copy.deepcopy(config) config = copy.deepcopy(config)
module_name = config.pop('name') module_name = config.pop('name')
......
import paddle
import paddle.nn as nn
import numpy as np
class CANLoss(nn.Layer):
'''
CANLoss is consist of two part:
word_average_loss: average accuracy of the symbol
counting_loss: counting loss of every symbol
'''
def __init__(self):
super(CANLoss, self).__init__()
self.use_label_mask = False
self.out_channel = 111
self.cross = nn.CrossEntropyLoss(
reduction='none') if self.use_label_mask else nn.CrossEntropyLoss()
self.counting_loss = nn.SmoothL1Loss(reduction='mean')
self.ratio = 16
def forward(self, preds, batch):
word_probs = preds[0]
counting_preds = preds[1]
counting_preds1 = preds[2]
counting_preds2 = preds[3]
labels = batch[2]
labels_mask = batch[3]
counting_labels = gen_counting_label(labels, self.out_channel, True)
counting_loss = self.counting_loss(counting_preds1, counting_labels) + self.counting_loss(counting_preds2, counting_labels) \
+ self.counting_loss(counting_preds, counting_labels)
word_loss = self.cross(
paddle.reshape(word_probs, [-1, word_probs.shape[-1]]),
paddle.reshape(labels, [-1]))
word_average_loss = paddle.sum(
paddle.reshape(word_loss * labels_mask, [-1])) / (
paddle.sum(labels_mask) + 1e-10
) if self.use_label_mask else word_loss
loss = word_average_loss + counting_loss
return {'loss': loss}
def gen_counting_label(labels, channel, tag):
b, t = labels.shape
counting_labels = np.zeros([b, channel])
if tag:
ignore = [0, 1, 107, 108, 109, 110]
else:
ignore = []
for i in range(b):
for j in range(t):
k = labels[i][j]
if k in ignore:
continue
else:
counting_labels[i][k] += 1
counting_labels = paddle.to_tensor(counting_labels, dtype='float32')
return counting_labels
...@@ -22,7 +22,7 @@ import copy ...@@ -22,7 +22,7 @@ import copy
__all__ = ["build_metric"] __all__ = ["build_metric"]
from .det_metric import DetMetric, DetFCEMetric from .det_metric import DetMetric, DetFCEMetric
from .rec_metric import RecMetric, CNTMetric from .rec_metric import RecMetric, CNTMetric, CANMetric
from .cls_metric import ClsMetric from .cls_metric import ClsMetric
from .e2e_metric import E2EMetric from .e2e_metric import E2EMetric
from .distillation_metric import DistillationMetric from .distillation_metric import DistillationMetric
...@@ -38,7 +38,7 @@ def build_metric(config): ...@@ -38,7 +38,7 @@ def build_metric(config):
support_dict = [ support_dict = [
"DetMetric", "DetFCEMetric", "RecMetric", "ClsMetric", "E2EMetric", "DetMetric", "DetFCEMetric", "RecMetric", "ClsMetric", "E2EMetric",
"DistillationMetric", "TableMetric", 'KIEMetric', 'VQASerTokenMetric', "DistillationMetric", "TableMetric", 'KIEMetric', 'VQASerTokenMetric',
'VQAReTokenMetric', 'SRMetric', 'CTMetric', 'CNTMetric' 'VQAReTokenMetric', 'SRMetric', 'CTMetric', 'CNTMetric', 'CANMetric'
] ]
config = copy.deepcopy(config) config = copy.deepcopy(config)
......
...@@ -13,6 +13,9 @@ ...@@ -13,6 +13,9 @@
# limitations under the License. # limitations under the License.
from rapidfuzz.distance import Levenshtein from rapidfuzz.distance import Levenshtein
from difflib import SequenceMatcher
import numpy as np
import string import string
...@@ -106,3 +109,71 @@ class CNTMetric(object): ...@@ -106,3 +109,71 @@ class CNTMetric(object):
def reset(self): def reset(self):
self.correct_num = 0 self.correct_num = 0
self.all_num = 0 self.all_num = 0
class CANMetric(object):
def __init__(self, main_indicator='exp_rate', **kwargs):
self.main_indicator = main_indicator
self.word_right = []
self.exp_right = []
self.word_total_length = 0
self.exp_total_num = 0
self.word_rate = 0
self.exp_rate = 0
self.reset()
self.epoch_reset()
def __call__(self, preds, batch, **kwargs):
for k, v in kwargs.items():
epoch_reset = v
if epoch_reset:
self.epoch_reset()
word_probs = preds
word_label, word_label_mask = batch
line_right = 0
if word_probs is not None:
word_pred = word_probs.argmax(2)
word_pred = word_pred.cpu().detach().numpy()
word_scores = [
SequenceMatcher(
None,
s1[:int(np.sum(s3))],
s2[:int(np.sum(s3))],
autojunk=False).ratio() * (
len(s1[:int(np.sum(s3))]) + len(s2[:int(np.sum(s3))])) /
len(s1[:int(np.sum(s3))]) / 2
for s1, s2, s3 in zip(word_label, word_pred, word_label_mask)
]
batch_size = len(word_scores)
for i in range(batch_size):
if word_scores[i] == 1:
line_right += 1
self.word_rate = np.mean(word_scores) #float
self.exp_rate = line_right / batch_size #float
exp_length, word_length = word_label.shape[:2]
self.word_right.append(self.word_rate * word_length)
self.exp_right.append(self.exp_rate * exp_length)
self.word_total_length = self.word_total_length + word_length
self.exp_total_num = self.exp_total_num + exp_length
def get_metric(self):
"""
return {
'word_rate': 0,
"exp_rate": 0,
}
"""
cur_word_rate = sum(self.word_right) / self.word_total_length
cur_exp_rate = sum(self.exp_right) / self.exp_total_num
self.reset()
return {'word_rate': cur_word_rate, "exp_rate": cur_exp_rate}
def reset(self):
self.word_rate = 0
self.exp_rate = 0
def epoch_reset(self):
self.word_right = []
self.exp_right = []
self.word_total_length = 0
self.exp_total_num = 0
...@@ -43,10 +43,12 @@ def build_backbone(config, model_type): ...@@ -43,10 +43,12 @@ def build_backbone(config, model_type):
from .rec_svtrnet import SVTRNet from .rec_svtrnet import SVTRNet
from .rec_vitstr import ViTSTR from .rec_vitstr import ViTSTR
from .rec_resnet_rfl import ResNetRFL from .rec_resnet_rfl import ResNetRFL
from .rec_densenet import DenseNet
support_dict = [ support_dict = [
'MobileNetV1Enhance', 'MobileNetV3', 'ResNet', 'ResNetFPN', 'MTB', 'MobileNetV1Enhance', 'MobileNetV3', 'ResNet', 'ResNetFPN', 'MTB',
'ResNet31', 'ResNet45', 'ResNet_ASTER', 'MicroNet', 'ResNet31', 'ResNet45', 'ResNet_ASTER', 'MicroNet',
'EfficientNetb3_PREN', 'SVTRNet', 'ViTSTR', 'ResNet32', 'ResNetRFL' 'EfficientNetb3_PREN', 'SVTRNet', 'ViTSTR', 'ResNet32', 'ResNetRFL',
'DenseNet'
] ]
elif model_type == 'e2e': elif model_type == 'e2e':
from .e2e_resnet_vd_pg import ResNet from .e2e_resnet_vd_pg import ResNet
......
import math
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
class Bottleneck(nn.Layer):
'''
ratio: 16
growthRate: 24
reduction: 0.5
bottleneck: True
use_dropout: True
'''
def __init__(self, nChannels, growthRate, use_dropout):
super(Bottleneck, self).__init__()
interChannels = 4 * growthRate
self.bn1 = nn.BatchNorm2D(interChannels)
self.conv1 = nn.Conv2D(
nChannels, interChannels, kernel_size=1,
bias_attr=None) # Xavier initialization
self.bn2 = nn.BatchNorm2D(growthRate)
self.conv2 = nn.Conv2D(
interChannels, growthRate, kernel_size=3, padding=1,
bias_attr=None) # Xavier initialization
self.use_dropout = use_dropout
self.dropout = nn.Dropout(p=0.2)
def forward(self, x):
out = F.relu(self.bn1(self.conv1(x)))
if self.use_dropout:
out = self.dropout(out)
out = F.relu(self.bn2(self.conv2(out)))
if self.use_dropout:
out = self.dropout(out)
out = paddle.concat([x, out], 1)
return out
class SingleLayer(nn.Layer):
def __init__(self, nChannels, growthRate, use_dropout):
super(SingleLayer, self).__init__()
self.bn1 = nn.BatchNorm2D(nChannels)
self.conv1 = nn.Conv2D(
nChannels, growthRate, kernel_size=3, padding=1, bias_attr=False)
self.use_dropout = use_dropout
self.dropout = nn.Dropout(p=0.2)
def forward(self, x):
out = self.conv1(F.relu(x))
if self.use_dropout:
out = self.dropout(out)
out = paddle.concat([x, out], 1)
return out
class Transition(nn.Layer):
def __init__(self, nChannels, out_channels, use_dropout):
super(Transition, self).__init__()
self.bn1 = nn.BatchNorm2D(out_channels)
self.conv1 = nn.Conv2D(
nChannels, out_channels, kernel_size=1, bias_attr=False)
self.use_dropout = use_dropout
self.dropout = nn.Dropout(p=0.2)
def forward(self, x):
out = F.relu(self.bn1(self.conv1(x)))
if self.use_dropout:
out = self.dropout(out)
out = F.avg_pool2d(out, 2, ceil_mode=True, exclusive=False)
return out
class DenseNet(nn.Layer):
def __init__(self, growthRate, reduction, bottleneck, use_dropout,
input_channel, **kwargs):
super(DenseNet, self).__init__()
'''
ratio: 16
growthRate: 24
reduction: 0.5
'''
nDenseBlocks = 16
nChannels = 2 * growthRate
self.conv1 = nn.Conv2D(
input_channel,
nChannels,
kernel_size=7,
padding=3,
stride=2,
bias_attr=False)
self.dense1 = self._make_dense(nChannels, growthRate, nDenseBlocks,
bottleneck, use_dropout)
nChannels += nDenseBlocks * growthRate
out_channels = int(math.floor(nChannels * reduction))
self.trans1 = Transition(nChannels, out_channels, use_dropout)
nChannels = out_channels
self.dense2 = self._make_dense(nChannels, growthRate, nDenseBlocks,
bottleneck, use_dropout)
nChannels += nDenseBlocks * growthRate
out_channels = int(math.floor(nChannels * reduction))
self.trans2 = Transition(nChannels, out_channels, use_dropout)
nChannels = out_channels
self.dense3 = self._make_dense(nChannels, growthRate, nDenseBlocks,
bottleneck, use_dropout)
self.out_channels = out_channels
def _make_dense(self, nChannels, growthRate, nDenseBlocks, bottleneck,
use_dropout):
layers = []
for i in range(int(nDenseBlocks)):
if bottleneck:
layers.append(Bottleneck(nChannels, growthRate, use_dropout))
else:
layers.append(SingleLayer(nChannels, growthRate, use_dropout))
nChannels += growthRate
return nn.Sequential(*layers)
def forward(self, inputs):
x, x_m, y = inputs
out = self.conv1(x)
out = F.relu(out)
out = F.max_pool2d(out, 2, ceil_mode=True)
out = self.dense1(out)
out = self.trans1(out)
out = self.dense2(out)
out = self.trans2(out)
out = self.dense3(out)
return out, x_m, y
...@@ -40,6 +40,7 @@ def build_head(config): ...@@ -40,6 +40,7 @@ def build_head(config):
from .rec_robustscanner_head import RobustScannerHead from .rec_robustscanner_head import RobustScannerHead
from .rec_visionlan_head import VLHead from .rec_visionlan_head import VLHead
from .rec_rfl_head import RFLHead from .rec_rfl_head import RFLHead
from .rec_can_head import CANHead
# cls head # cls head
from .cls_head import ClsHead from .cls_head import ClsHead
...@@ -56,7 +57,7 @@ def build_head(config): ...@@ -56,7 +57,7 @@ def build_head(config):
'TableAttentionHead', 'SARHead', 'AsterHead', 'SDMGRHead', 'PRENHead', 'TableAttentionHead', 'SARHead', 'AsterHead', 'SDMGRHead', 'PRENHead',
'MultiHead', 'ABINetHead', 'TableMasterHead', 'SPINAttentionHead', 'MultiHead', 'ABINetHead', 'TableMasterHead', 'SPINAttentionHead',
'VLHead', 'SLAHead', 'RobustScannerHead', 'CT_Head', 'RFLHead', 'VLHead', 'SLAHead', 'RobustScannerHead', 'CT_Head', 'RFLHead',
'DRRGHead' 'DRRGHead', 'CANHead'
] ]
#table head #table head
......
from turtle import forward
import paddle.nn as nn
import paddle
import math
'''
Counting Module
'''
class ChannelAtt(nn.Layer):
def __init__(self, channel, reduction):
super(ChannelAtt, self).__init__()
self.avg_pool = nn.AdaptiveAvgPool2D(1)
self.fc = nn.Sequential(
nn.Linear(channel, channel // reduction),
nn.ReLU(), nn.Linear(channel // reduction, channel), nn.Sigmoid())
def forward(self, x):
b, c, _, _ = x.shape
y = paddle.reshape(self.avg_pool(x), [b, c])
y = paddle.reshape(self.fc(y), [b, c, 1, 1])
return x * y
class CountingDecoder(nn.Layer):
def __init__(self, in_channel, out_channel, kernel_size):
super(CountingDecoder, self).__init__()
self.in_channel = in_channel
self.out_channel = out_channel
self.trans_layer = nn.Sequential(
nn.Conv2D(
self.in_channel,
512,
kernel_size=kernel_size,
padding=kernel_size // 2,
bias_attr=False),
nn.BatchNorm2D(512))
self.channel_att = ChannelAtt(512, 16)
self.pred_layer = nn.Sequential(
nn.Conv2D(
512, self.out_channel, kernel_size=1, bias_attr=False),
nn.Sigmoid())
def forward(self, x, mask):
b, _, h, w = x.shape
x = self.trans_layer(x)
x = self.channel_att(x)
x = self.pred_layer(x)
if mask is not None:
x = x * mask
x = paddle.reshape(x, [b, self.out_channel, -1])
x1 = paddle.sum(x, axis=-1)
return x1, paddle.reshape(x, [b, self.out_channel, h, w])
'''
Attention Decoder
'''
class PositionEmbeddingSine(nn.Layer):
def __init__(self,
num_pos_feats=64,
temperature=10000,
normalize=False,
scale=None):
super().__init__()
self.num_pos_feats = num_pos_feats
self.temperature = temperature
self.normalize = normalize
if scale is not None and normalize is False:
raise ValueError("normalize should be True if scale is passed")
if scale is None:
scale = 2 * math.pi
self.scale = scale
def forward(self, x, mask):
y_embed = paddle.cumsum(mask, 1, dtype='float32')
x_embed = paddle.cumsum(mask, 2, dtype='float32')
if self.normalize:
eps = 1e-6
y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
dim_t = paddle.arange(self.num_pos_feats, dtype='float32')
dim_d = paddle.expand(paddle.to_tensor(2), dim_t.shape)
dim_t = self.temperature**(2 * (dim_t / dim_d).astype('int64') /
self.num_pos_feats)
pos_x = paddle.unsqueeze(x_embed, [3]) / dim_t
pos_y = paddle.unsqueeze(y_embed, [3]) / dim_t
pos_x = paddle.flatten(
paddle.stack(
[
paddle.sin(pos_x[:, :, :, 0::2]),
paddle.cos(pos_x[:, :, :, 1::2])
],
axis=4),
3)
pos_y = paddle.flatten(
paddle.stack(
[
paddle.sin(pos_y[:, :, :, 0::2]),
paddle.cos(pos_y[:, :, :, 1::2])
],
axis=4),
3)
pos = paddle.transpose(
paddle.concat(
[pos_y, pos_x], axis=3), [0, 3, 1, 2])
return pos
class AttDecoder(nn.Layer):
def __init__(self, ratio, is_train, input_size, hidden_size,
encoder_out_channel, dropout, dropout_ratio, word_num,
counting_decoder_out_channel, attention):
super(AttDecoder, self).__init__()
self.input_size = input_size
self.hidden_size = hidden_size
self.out_channel = encoder_out_channel
self.attention_dim = attention['attention_dim']
self.dropout_prob = dropout
self.ratio = ratio
self.word_num = word_num
self.counting_num = counting_decoder_out_channel
self.is_train = is_train
self.init_weight = nn.Linear(self.out_channel, self.hidden_size)
self.embedding = nn.Embedding(self.word_num, self.input_size)
self.word_input_gru = nn.GRUCell(self.input_size, self.hidden_size)
self.word_attention = Attention(hidden_size, attention['attention_dim'])
self.encoder_feature_conv = nn.Conv2D(
self.out_channel,
self.attention_dim,
kernel_size=attention['word_conv_kernel'],
padding=attention['word_conv_kernel'] // 2)
self.word_state_weight = nn.Linear(self.hidden_size, self.hidden_size)
self.word_embedding_weight = nn.Linear(self.input_size,
self.hidden_size)
self.word_context_weight = nn.Linear(self.out_channel, self.hidden_size)
self.counting_context_weight = nn.Linear(self.counting_num,
self.hidden_size)
self.word_convert = nn.Linear(self.hidden_size, self.word_num)
if dropout:
self.dropout = nn.Dropout(dropout_ratio)
def forward(self, cnn_features, labels, counting_preds, images_mask):
if self.is_train:
_, num_steps = labels.shape
else:
num_steps = 36
batch_size, _, height, width = cnn_features.shape
images_mask = images_mask[:, :, ::self.ratio, ::self.ratio]
word_probs = paddle.zeros((batch_size, num_steps, self.word_num))
word_alpha_sum = paddle.zeros((batch_size, 1, height, width))
hidden = self.init_hidden(cnn_features, images_mask)
counting_context_weighted = self.counting_context_weight(counting_preds)
cnn_features_trans = self.encoder_feature_conv(cnn_features)
position_embedding = PositionEmbeddingSine(256, normalize=True)
pos = position_embedding(cnn_features_trans, images_mask[:, 0, :, :])
cnn_features_trans = cnn_features_trans + pos
word = paddle.ones([batch_size, 1], dtype='int64') # init word as sos
word = word.squeeze(axis=1)
for i in range(num_steps):
word_embedding = self.embedding(word)
_, hidden = self.word_input_gru(word_embedding, hidden)
word_context_vec, _, word_alpha_sum = self.word_attention(
cnn_features, cnn_features_trans, hidden, word_alpha_sum,
images_mask)
current_state = self.word_state_weight(hidden)
word_weighted_embedding = self.word_embedding_weight(word_embedding)
word_context_weighted = self.word_context_weight(word_context_vec)
if self.dropout_prob:
word_out_state = self.dropout(
current_state + word_weighted_embedding +
word_context_weighted + counting_context_weighted)
else:
word_out_state = current_state + word_weighted_embedding + word_context_weighted + counting_context_weighted
word_prob = self.word_convert(word_out_state)
word_probs[:, i] = word_prob
if self.is_train:
word = labels[:, i]
else:
word = word_prob.argmax(1)
word = paddle.multiply(
word, labels[:, i]
) # labels are oneslike tensor in infer/predict mode
return word_probs
def init_hidden(self, features, feature_mask):
average = paddle.sum(paddle.sum(features * feature_mask, axis=-1),
axis=-1) / paddle.sum(
(paddle.sum(feature_mask, axis=-1)), axis=-1)
average = self.init_weight(average)
return paddle.tanh(average)
'''
Attention Module
'''
class Attention(nn.Layer):
def __init__(self, hidden_size, attention_dim):
super(Attention, self).__init__()
self.hidden = hidden_size
self.attention_dim = attention_dim
self.hidden_weight = nn.Linear(self.hidden, self.attention_dim)
self.attention_conv = nn.Conv2D(
1, 512, kernel_size=11, padding=5, bias_attr=False)
self.attention_weight = nn.Linear(
512, self.attention_dim, bias_attr=False)
self.alpha_convert = nn.Linear(self.attention_dim, 1)
def forward(self,
cnn_features,
cnn_features_trans,
hidden,
alpha_sum,
image_mask=None):
query = self.hidden_weight(hidden)
alpha_sum_trans = self.attention_conv(alpha_sum)
coverage_alpha = self.attention_weight(
paddle.transpose(alpha_sum_trans, [0, 2, 3, 1]))
alpha_score = paddle.tanh(
paddle.unsqueeze(query, [1, 2]) + coverage_alpha + paddle.transpose(
cnn_features_trans, [0, 2, 3, 1]))
energy = self.alpha_convert(alpha_score)
energy = energy - energy.max()
energy_exp = paddle.exp(paddle.squeeze(energy, -1))
if image_mask is not None:
energy_exp = energy_exp * paddle.squeeze(image_mask, 1)
alpha = energy_exp / (paddle.unsqueeze(
paddle.sum(paddle.sum(energy_exp, -1), -1), [1, 2]) + 1e-10)
alpha_sum = paddle.unsqueeze(alpha, 1) + alpha_sum
context_vector = paddle.sum(
paddle.sum((paddle.unsqueeze(alpha, 1) * cnn_features), -1), -1)
return context_vector, alpha, alpha_sum
class CANHead(nn.Layer):
def __init__(self, in_channel, out_channel, ratio, attdecoder, **kwargs):
super(CANHead, self).__init__()
self.in_channel = in_channel
self.out_channel = out_channel
self.counting_decoder1 = CountingDecoder(self.in_channel,
self.out_channel, 3) # mscm
self.counting_decoder2 = CountingDecoder(self.in_channel,
self.out_channel, 5)
self.decoder = AttDecoder(ratio, **attdecoder)
self.ratio = ratio
def forward(self, inputs, targets=None):
cnn_features, images_mask, labels = inputs
counting_mask = images_mask[:, :, ::self.ratio, ::self.ratio]
counting_preds1, _ = self.counting_decoder1(cnn_features, counting_mask)
counting_preds2, _ = self.counting_decoder2(cnn_features, counting_mask)
counting_preds = (counting_preds1 + counting_preds2) / 2
word_probs = self.decoder(cnn_features, labels, counting_preds,
images_mask)
return word_probs, counting_preds, counting_preds1, counting_preds2
...@@ -18,7 +18,7 @@ from __future__ import print_function ...@@ -18,7 +18,7 @@ from __future__ import print_function
from __future__ import unicode_literals from __future__ import unicode_literals
from paddle.optimizer import lr from paddle.optimizer import lr
from .lr_scheduler import CyclicalCosineDecay, OneCycleDecay from .lr_scheduler import CyclicalCosineDecay, OneCycleDecay, TwoStepCosineDecay
class Linear(object): class Linear(object):
...@@ -386,3 +386,44 @@ class MultiStepDecay(object): ...@@ -386,3 +386,44 @@ class MultiStepDecay(object):
end_lr=self.learning_rate, end_lr=self.learning_rate,
last_epoch=self.last_epoch) last_epoch=self.last_epoch)
return learning_rate return learning_rate
class TwoStepCosine(object):
"""
Cosine learning rate decay
lr = 0.05 * (math.cos(epoch * (math.pi / epochs)) + 1)
Args:
lr(float): initial learning rate
step_each_epoch(int): steps each epoch
epochs(int): total training epochs
last_epoch (int, optional): The index of last epoch. Can be set to restart training. Default: -1, means initial learning rate.
"""
def __init__(self,
learning_rate,
step_each_epoch,
epochs,
warmup_epoch=0,
last_epoch=-1,
**kwargs):
super(TwoStepCosine, self).__init__()
self.learning_rate = learning_rate
self.T_max1 = step_each_epoch * 200
self.T_max2 = step_each_epoch * epochs
self.last_epoch = last_epoch
self.warmup_epoch = round(warmup_epoch * step_each_epoch)
def __call__(self):
learning_rate = TwoStepCosineDecay(
learning_rate=self.learning_rate,
T_max1=self.T_max1,
T_max2=self.T_max2,
last_epoch=self.last_epoch)
if self.warmup_epoch > 0:
learning_rate = lr.LinearWarmup(
learning_rate=learning_rate,
warmup_steps=self.warmup_epoch,
start_lr=0.0,
end_lr=self.learning_rate,
last_epoch=self.last_epoch)
return learning_rate
...@@ -160,3 +160,63 @@ class OneCycleDecay(LRScheduler): ...@@ -160,3 +160,63 @@ class OneCycleDecay(LRScheduler):
start_step = phase['end_step'] start_step = phase['end_step']
return computed_lr return computed_lr
class TwoStepCosineDecay(LRScheduler):
def __init__(self,
learning_rate,
T_max1,
T_max2,
eta_min=0,
last_epoch=-1,
verbose=False):
if not isinstance(T_max1, int):
raise TypeError(
"The type of 'T_max1' in 'CosineAnnealingDecay' must be 'int', but received %s."
% type(T_max1))
if not isinstance(T_max2, int):
raise TypeError(
"The type of 'T_max2' in 'CosineAnnealingDecay' must be 'int', but received %s."
% type(T_max2))
if not isinstance(eta_min, (float, int)):
raise TypeError(
"The type of 'eta_min' in 'CosineAnnealingDecay' must be 'float, int', but received %s."
% type(eta_min))
assert T_max1 > 0 and isinstance(
T_max1, int), " 'T_max1' must be a positive integer."
assert T_max2 > 0 and isinstance(
T_max2, int), " 'T_max1' must be a positive integer."
self.T_max1 = T_max1
self.T_max2 = T_max2
self.eta_min = float(eta_min)
super(TwoStepCosineDecay, self).__init__(learning_rate, last_epoch,
verbose)
def get_lr(self):
if self.last_epoch <= self.T_max1:
if self.last_epoch == 0:
return self.base_lr
elif (self.last_epoch - 1 - self.T_max1) % (2 * self.T_max1) == 0:
return self.last_lr + (self.base_lr - self.eta_min) * (
1 - math.cos(math.pi / self.T_max1)) / 2
return (1 + math.cos(math.pi * self.last_epoch / self.T_max1)) / (
1 + math.cos(math.pi * (self.last_epoch - 1) / self.T_max1)) * (
self.last_lr - self.eta_min) + self.eta_min
else:
if (self.last_epoch - 1 - self.T_max2) % (2 * self.T_max2) == 0:
return self.last_lr + (self.base_lr - self.eta_min) * (
1 - math.cos(math.pi / self.T_max2)) / 2
return (1 + math.cos(math.pi * self.last_epoch / self.T_max2)) / (
1 + math.cos(math.pi * (self.last_epoch - 1) / self.T_max2)) * (
self.last_lr - self.eta_min) + self.eta_min
def _get_closed_form_lr(self):
if self.last_epoch <= self.T_max1:
return self.eta_min + (self.base_lr - self.eta_min) * (1 + math.cos(
math.pi * self.last_epoch / self.T_max1)) / 2
else:
return self.eta_min + (self.base_lr - self.eta_min) * (1 + math.cos(
math.pi * self.last_epoch / self.T_max2)) / 2
...@@ -37,6 +37,7 @@ from .table_postprocess import TableMasterLabelDecode, TableLabelDecode ...@@ -37,6 +37,7 @@ from .table_postprocess import TableMasterLabelDecode, TableLabelDecode
from .picodet_postprocess import PicoDetPostProcess from .picodet_postprocess import PicoDetPostProcess
from .ct_postprocess import CTPostProcess from .ct_postprocess import CTPostProcess
from .drrg_postprocess import DRRGPostprocess from .drrg_postprocess import DRRGPostprocess
from .rec_postprocess import SeqLabelDecode
def build_post_process(config, global_config=None): def build_post_process(config, global_config=None):
...@@ -51,7 +52,7 @@ def build_post_process(config, global_config=None): ...@@ -51,7 +52,7 @@ def build_post_process(config, global_config=None):
'TableMasterLabelDecode', 'SPINLabelDecode', 'TableMasterLabelDecode', 'SPINLabelDecode',
'DistillationSerPostProcess', 'DistillationRePostProcess', 'DistillationSerPostProcess', 'DistillationRePostProcess',
'VLLabelDecode', 'PicoDetPostProcess', 'CTPostProcess', 'VLLabelDecode', 'PicoDetPostProcess', 'CTPostProcess',
'RFLLabelDecode', 'DRRGPostprocess' 'RFLLabelDecode', 'DRRGPostprocess', 'SeqLabelDecode'
] ]
if config['name'] == 'PSEPostProcess': if config['name'] == 'PSEPostProcess':
......
...@@ -897,3 +897,36 @@ class VLLabelDecode(BaseRecLabelDecode): ...@@ -897,3 +897,36 @@ class VLLabelDecode(BaseRecLabelDecode):
return text return text
label = self.decode(label) label = self.decode(label)
return text, label return text, label
class SeqLabelDecode(BaseRecLabelDecode):
""" Convert between latex-symbol and symbol-index """
def __init__(self, character_dict_path=None, use_space_char=False,
**kwargs):
super(SeqLabelDecode, self).__init__(character_dict_path,
use_space_char)
def decode(self, text_index, preds_prob=None):
result_list = []
batch_size = len(text_index)
for batch_idx in range(batch_size):
seq_end = text_index[batch_idx].argmin(0)
idx_list = text_index[batch_idx][:seq_end].tolist()
symbol_list = [self.character[idx] for idx in idx_list]
probs = []
if preds_prob is not None:
probs = preds_prob[batch_idx][:len(symbol_list)].tolist()
result_list.append([' '.join(symbol_list), probs])
return result_list
def __call__(self, preds, label=None, *args, **kwargs):
pred_prob, _, _, _ = preds
preds_idx = pred_prob.argmax(axis=2)
text = self.decode(preds_idx)
if label is None:
return text
label = self.decode(label)
return text, label
eos
sos
!
'
(
)
+
,
-
.
/
0
1
2
3
4
5
6
7
8
9
<
=
>
A
B
C
E
F
G
H
I
L
M
N
P
R
S
T
V
X
Y
[
\Delta
\alpha
\beta
\cdot
\cdots
\cos
\div
\exists
\forall
\frac
\gamma
\geq
\in
\infty
\int
\lambda
\ldots
\leq
\lim
\log
\mu
\neq
\phi
\pi
\pm
\prime
\rightarrow
\sigma
\sin
\sqrt
\sum
\tan
\theta
\times
]
a
b
c
d
e
f
g
h
i
j
k
l
m
n
o
p
q
r
s
t
u
v
w
x
y
z
\{
|
\}
{
}
^
_
\ No newline at end of file
Global:
use_gpu: True
epoch_num: 240
log_smooth_window: 20
print_batch_step: 10
save_model_dir: ./output/rec/can/
save_epoch_step: 1
# evaluation is run every 1105 iterations
eval_batch_step: [0, 1105]
cal_metric_during_train: True
pretrained_model: ./output/rec/can/CAN
checkpoints: ./output/rec/can/CAN
save_inference_dir: ./inference/rec_d28_can/
use_visualdl: False
infer_img: doc/imgs_hme/hme_01.jpeg
# for data or label process
character_dict_path: ppocr/utils/dict/latex_symbol_dict.txt
max_text_length: 36
infer_mode: False
use_space_char: False
save_res_path: ./output/rec/predicts_can.txt
Optimizer:
name: Momentum
momentum: 0.9
clip_norm_global: 100.0
lr:
name: TwoStepCosine
learning_rate: 0.01
warmup_epoch: 1
weight_decay: 0.0001
Architecture:
model_type: rec
algorithm: CAN
in_channels: 1
Transform:
Backbone:
name: DenseNet
growthRate: 24
reduction: 0.5
bottleneck: True
use_dropout: True
input_channel: 1
Head:
name: CANHead
in_channel: 684
out_channel: 111
max_text_length: 36
ratio: 16
attdecoder:
is_train: True
input_size: 256
hidden_size: 256
encoder_out_channel: 684
dropout: True
dropout_ratio: 0.5
word_num: 111
counting_decoder_out_channel: 111
attention:
attention_dim: 512
word_conv_kernel: 1
Loss:
name: CANLoss
PostProcess:
name: SeqLabelDecode
character: 111
Metric:
name: CANMetric
main_indicator: exp_rate
Train:
dataset:
name: HMERDataSet
data_dir: ./train_data/CROHME/training/images/
transforms:
- DecodeImage:
channel_first: False
- GrayImageChannelFormat:
normalize: True
inverse: True
- KeepKeys:
keep_keys: ['image', 'label']
label_file_list: ["./train_data/CROHME/training/labels.json"]
loader:
shuffle: True
batch_size_per_card: 2
drop_last: True
num_workers: 1
collate_fn: DyMaskCollator
Eval:
dataset:
name: HMERDataSet
data_dir: ./train_data/CROHME/evaluation/images/
transforms:
- DecodeImage:
channel_first: False
- GrayImageChannelFormat:
normalize: True
inverse: True
- KeepKeys:
keep_keys: ['image', 'label']
label_file_list: ["./train_data/CROHME/evaluation/labels.json"]
loader:
shuffle: False
drop_last: False
batch_size_per_card: 1
num_workers: 4
collate_fn: DyMaskCollator
===========================train_params===========================
model_name:rec_d28_can
python:python
gpu_list:0|0,1
Global.use_gpu:True|True
Global.auto_cast:null
Global.epoch_num:lite_train_lite_infer=2|whole_train_whole_infer=240
Global.save_model_dir:./output/
Train.loader.batch_size_per_card:lite_train_lite_infer=2|whole_train_whole_infer=8
Global.pretrained_model:null
train_model_name:latest
train_infer_img_dir:./doc/imgs_hme
null:null
##
trainer:norm_train
norm_train:tools/train.py -c test_tipc/configs/rec_d28_can/rec_d28_can.yml -o
pact_train:null
fpgm_train:null
distill_train:null
null:null
null:null
##
===========================eval_params===========================
eval:tools/eval.py -c test_tipc/configs/rec_d28_can/rec_d28_can.yml -o
null:null
##
===========================infer_params===========================
Global.save_inference_dir:./output/
Global.checkpoints:
norm_export:tools/export_model.py -c test_tipc/configs/rec_d28_can/rec_d28_can.yml -o
quant_export:null
fpgm_export:null
distill_export:null
export1:null
export2:null
##
train_model:./inference/rec_d28_can_train/best_accuracy
infer_export:tools/export_model.py -c test_tipc/configs/rec_d28_can/rec_d28_can.yml -o
infer_quant:False
inference:tools/infer/predict_rec.py --rec_char_dict_path=./ppocr/utils/dict/latex_symbol_dict.txt --rec_image_shape="1,100,100" --rec_algorithm="CAN"
--use_gpu:True|False
--enable_mkldnn:False
--cpu_threads:6
--rec_batch_num:1
--use_tensorrt:False
--precision:fp32
--rec_model_dir:./output/
--image_dir:./doc/imgs_hme
--save_log_path:./test/output/
--benchmark:True
null:null
===========================infer_benchmark_params==========================
random_infer_input:[{float32,[1,100,100]}]
...@@ -74,7 +74,9 @@ def main(): ...@@ -74,7 +74,9 @@ def main():
config['Architecture']["Head"]['out_channels'] = char_num config['Architecture']["Head"]['out_channels'] = char_num
model = build_model(config['Architecture']) model = build_model(config['Architecture'])
extra_input_models = ["SRN", "NRTR", "SAR", "SEED", "SVTR", "VisionLAN", "RobustScanner"] extra_input_models = [
"SRN", "NRTR", "SAR", "SEED", "SVTR", "VisionLAN", "RobustScanner"
]
extra_input = False extra_input = False
if config['Architecture']['algorithm'] == 'Distillation': if config['Architecture']['algorithm'] == 'Distillation':
for key in config['Architecture']["Models"]: for key in config['Architecture']["Models"]:
...@@ -83,7 +85,10 @@ def main(): ...@@ -83,7 +85,10 @@ def main():
else: else:
extra_input = config['Architecture']['algorithm'] in extra_input_models extra_input = config['Architecture']['algorithm'] in extra_input_models
if "model_type" in config['Architecture'].keys(): if "model_type" in config['Architecture'].keys():
model_type = config['Architecture']['model_type'] if config['Architecture']['algorithm'] == 'CAN':
model_type = 'can'
else:
model_type = config['Architecture']['model_type']
else: else:
model_type = None model_type = None
...@@ -92,7 +97,7 @@ def main(): ...@@ -92,7 +97,7 @@ def main():
# amp # amp
use_amp = config["Global"].get("use_amp", False) use_amp = config["Global"].get("use_amp", False)
amp_level = config["Global"].get("amp_level", 'O2') amp_level = config["Global"].get("amp_level", 'O2')
amp_custom_black_list = config['Global'].get('amp_custom_black_list',[]) amp_custom_black_list = config['Global'].get('amp_custom_black_list', [])
if use_amp: if use_amp:
AMP_RELATED_FLAGS_SETTING = { AMP_RELATED_FLAGS_SETTING = {
'FLAGS_cudnn_batchnorm_spatial_persistent': 1, 'FLAGS_cudnn_batchnorm_spatial_persistent': 1,
...@@ -120,7 +125,8 @@ def main(): ...@@ -120,7 +125,8 @@ def main():
# start eval # start eval
metric = program.eval(model, valid_dataloader, post_process_class, metric = program.eval(model, valid_dataloader, post_process_class,
eval_class, model_type, extra_input, scaler, amp_level, amp_custom_black_list) eval_class, model_type, extra_input, scaler,
amp_level, amp_custom_black_list)
logger.info('metric eval ***************') logger.info('metric eval ***************')
for k, v in metric.items(): for k, v in metric.items():
logger.info('{}:{}'.format(k, v)) logger.info('{}:{}'.format(k, v))
......
...@@ -123,6 +123,17 @@ def export_single_model(model, ...@@ -123,6 +123,17 @@ def export_single_model(model,
] ]
] ]
model = to_static(model, input_spec=other_shape) model = to_static(model, input_spec=other_shape)
elif arch_config["algorithm"] == "CAN":
other_shape = [[
paddle.static.InputSpec(
shape=[None, 1, None, None],
dtype="float32"), paddle.static.InputSpec(
shape=[None, 1, None, None], dtype="float32"),
paddle.static.InputSpec(
shape=[None, arch_config['Head']['max_text_length']],
dtype="int64")
]]
model = to_static(model, input_spec=other_shape)
elif arch_config["algorithm"] in ["LayoutLM", "LayoutLMv2", "LayoutXLM"]: elif arch_config["algorithm"] in ["LayoutLM", "LayoutLMv2", "LayoutXLM"]:
input_spec = [ input_spec = [
paddle.static.InputSpec( paddle.static.InputSpec(
......
...@@ -108,6 +108,13 @@ class TextRecognizer(object): ...@@ -108,6 +108,13 @@ class TextRecognizer(object):
} }
elif self.rec_algorithm == "PREN": elif self.rec_algorithm == "PREN":
postprocess_params = {'name': 'PRENLabelDecode'} postprocess_params = {'name': 'PRENLabelDecode'}
elif self.rec_algorithm == "CAN":
self.inverse = args.rec_image_inverse
postprocess_params = {
'name': 'SeqLabelDecode',
"character_dict_path": args.rec_char_dict_path,
"use_space_char": args.use_space_char
}
self.postprocess_op = build_post_process(postprocess_params) self.postprocess_op = build_post_process(postprocess_params)
self.predictor, self.input_tensor, self.output_tensors, self.config = \ self.predictor, self.input_tensor, self.output_tensors, self.config = \
utility.create_predictor(args, 'rec', logger) utility.create_predictor(args, 'rec', logger)
...@@ -351,6 +358,30 @@ class TextRecognizer(object): ...@@ -351,6 +358,30 @@ class TextRecognizer(object):
return resized_image return resized_image
def norm_img_can(self, img, image_shape):
img = cv2.cvtColor(
img, cv2.COLOR_BGR2GRAY) # CAN only predict gray scale image
if self.inverse:
img = 255 - img
if self.rec_image_shape[0] == 1:
h, w = img.shape
_, imgH, imgW = self.rec_image_shape
if h < imgH or w < imgW:
padding_h = max(imgH - h, 0)
padding_w = max(imgW - w, 0)
img_padded = np.pad(img, ((0, padding_h), (0, padding_w)),
'constant',
constant_values=(255))
img = img_padded
img = np.expand_dims(img, 0) / 255.0 # h,w,c -> c,h,w
img = img.astype('float32')
return img
def __call__(self, img_list): def __call__(self, img_list):
img_num = len(img_list) img_num = len(img_list)
# Calculate the aspect ratio of all text bars # Calculate the aspect ratio of all text bars
...@@ -430,6 +461,17 @@ class TextRecognizer(object): ...@@ -430,6 +461,17 @@ class TextRecognizer(object):
word_positions = np.array(range(0, 40)).astype('int64') word_positions = np.array(range(0, 40)).astype('int64')
word_positions = np.expand_dims(word_positions, axis=0) word_positions = np.expand_dims(word_positions, axis=0)
word_positions_list.append(word_positions) word_positions_list.append(word_positions)
elif self.rec_algorithm == "CAN":
norm_img = self.norm_img_can(img_list[indices[ino]],
max_wh_ratio)
norm_img = norm_img[np.newaxis, :]
norm_img_batch.append(norm_img)
norm_image_mask = np.ones(norm_img.shape, dtype='float32')
word_label = np.ones([1, 36], dtype='int64')
norm_img_mask_batch = []
word_label_list = []
norm_img_mask_batch.append(norm_image_mask)
word_label_list.append(word_label)
else: else:
norm_img = self.resize_norm_img(img_list[indices[ino]], norm_img = self.resize_norm_img(img_list[indices[ino]],
max_wh_ratio) max_wh_ratio)
...@@ -527,6 +569,33 @@ class TextRecognizer(object): ...@@ -527,6 +569,33 @@ class TextRecognizer(object):
if self.benchmark: if self.benchmark:
self.autolog.times.stamp() self.autolog.times.stamp()
preds = outputs[0] preds = outputs[0]
elif self.rec_algorithm == "CAN":
norm_img_mask_batch = np.concatenate(norm_img_mask_batch)
word_label_list = np.concatenate(word_label_list)
inputs = [norm_img_batch, norm_img_mask_batch, word_label_list]
if self.use_onnx:
input_dict = {}
input_dict[self.input_tensor.name] = norm_img_batch
outputs = self.predictor.run(self.output_tensors,
input_dict)
preds = outputs
else:
input_names = self.predictor.get_input_names()
input_tensor = []
for i in range(len(input_names)):
input_tensor_i = self.predictor.get_input_handle(
input_names[i])
input_tensor_i.copy_from_cpu(inputs[i])
input_tensor.append(input_tensor_i)
self.input_tensor = input_tensor
self.predictor.run()
outputs = []
for output_tensor in self.output_tensors:
output = output_tensor.copy_to_cpu()
outputs.append(output)
if self.benchmark:
self.autolog.times.stamp()
preds = outputs
else: else:
if self.use_onnx: if self.use_onnx:
input_dict = {} input_dict = {}
......
...@@ -84,6 +84,7 @@ def init_args(): ...@@ -84,6 +84,7 @@ def init_args():
# params for text recognizer # params for text recognizer
parser.add_argument("--rec_algorithm", type=str, default='SVTR_LCNet') parser.add_argument("--rec_algorithm", type=str, default='SVTR_LCNet')
parser.add_argument("--rec_model_dir", type=str) parser.add_argument("--rec_model_dir", type=str)
parser.add_argument("--rec_image_inverse", type=str2bool, default=True)
parser.add_argument("--rec_image_shape", type=str, default="3, 48, 320") parser.add_argument("--rec_image_shape", type=str, default="3, 48, 320")
parser.add_argument("--rec_batch_num", type=int, default=6) parser.add_argument("--rec_batch_num", type=int, default=6)
parser.add_argument("--max_text_length", type=int, default=25) parser.add_argument("--max_text_length", type=int, default=25)
......
...@@ -141,6 +141,11 @@ def main(): ...@@ -141,6 +141,11 @@ def main():
paddle.to_tensor(valid_ratio), paddle.to_tensor(valid_ratio),
paddle.to_tensor(word_positons), paddle.to_tensor(word_positons),
] ]
if config['Architecture']['algorithm'] == "CAN":
image_mask = paddle.ones(
(np.expand_dims(
batch[0], axis=0).shape), dtype='float32')
label = paddle.ones((1, 36), dtype='int64')
images = np.expand_dims(batch[0], axis=0) images = np.expand_dims(batch[0], axis=0)
images = paddle.to_tensor(images) images = paddle.to_tensor(images)
if config['Architecture']['algorithm'] == "SRN": if config['Architecture']['algorithm'] == "SRN":
...@@ -149,6 +154,8 @@ def main(): ...@@ -149,6 +154,8 @@ def main():
preds = model(images, img_metas) preds = model(images, img_metas)
elif config['Architecture']['algorithm'] == "RobustScanner": elif config['Architecture']['algorithm'] == "RobustScanner":
preds = model(images, img_metas) preds = model(images, img_metas)
elif config['Architecture']['algorithm'] == "CAN":
preds = model([images, image_mask, label])
else: else:
preds = model(images) preds = model(images)
post_result = post_process_class(preds) post_result = post_process_class(preds)
......
...@@ -273,6 +273,8 @@ def train(config, ...@@ -273,6 +273,8 @@ def train(config,
preds = model(images, data=batch[1:]) preds = model(images, data=batch[1:])
elif model_type in ["kie"]: elif model_type in ["kie"]:
preds = model(batch) preds = model(batch)
elif algorithm in ['CAN']:
preds = model(batch[:3])
else: else:
preds = model(images) preds = model(images)
preds = to_float32(preds) preds = to_float32(preds)
...@@ -286,6 +288,8 @@ def train(config, ...@@ -286,6 +288,8 @@ def train(config,
preds = model(images, data=batch[1:]) preds = model(images, data=batch[1:])
elif model_type in ["kie", 'sr']: elif model_type in ["kie", 'sr']:
preds = model(batch) preds = model(batch)
elif algorithm in ['CAN']:
preds = model(batch[:3])
else: else:
preds = model(images) preds = model(images)
loss = loss_class(preds, batch) loss = loss_class(preds, batch)
...@@ -302,6 +306,9 @@ def train(config, ...@@ -302,6 +306,9 @@ def train(config,
elif model_type in ['table']: elif model_type in ['table']:
post_result = post_process_class(preds, batch) post_result = post_process_class(preds, batch)
eval_class(post_result, batch) eval_class(post_result, batch)
elif algorithm in ['CAN']:
model_type = 'can'
eval_class(preds[0], batch[2:], epoch_reset=(idx == 0))
else: else:
if config['Loss']['name'] in ['MultiLoss', 'MultiLoss_v2' if config['Loss']['name'] in ['MultiLoss', 'MultiLoss_v2'
]: # for multi head loss ]: # for multi head loss
...@@ -496,6 +503,8 @@ def eval(model, ...@@ -496,6 +503,8 @@ def eval(model,
preds = model(images, data=batch[1:]) preds = model(images, data=batch[1:])
elif model_type in ["kie"]: elif model_type in ["kie"]:
preds = model(batch) preds = model(batch)
elif model_type in ['can']:
preds = model(batch[:3])
elif model_type in ['sr']: elif model_type in ['sr']:
preds = model(batch) preds = model(batch)
sr_img = preds["sr_img"] sr_img = preds["sr_img"]
...@@ -508,6 +517,8 @@ def eval(model, ...@@ -508,6 +517,8 @@ def eval(model,
preds = model(images, data=batch[1:]) preds = model(images, data=batch[1:])
elif model_type in ["kie"]: elif model_type in ["kie"]:
preds = model(batch) preds = model(batch)
elif model_type in ['can']:
preds = model(batch[:3])
elif model_type in ['sr']: elif model_type in ['sr']:
preds = model(batch) preds = model(batch)
sr_img = preds["sr_img"] sr_img = preds["sr_img"]
...@@ -532,6 +543,8 @@ def eval(model, ...@@ -532,6 +543,8 @@ def eval(model,
eval_class(post_result, batch_numpy) eval_class(post_result, batch_numpy)
elif model_type in ['sr']: elif model_type in ['sr']:
eval_class(preds, batch_numpy) eval_class(preds, batch_numpy)
elif model_type in ['can']:
eval_class(preds[0], batch_numpy[2:], epoch_reset=False)
else: else:
post_result = post_process_class(preds, batch_numpy[1]) post_result = post_process_class(preds, batch_numpy[1])
eval_class(post_result, batch_numpy) eval_class(post_result, batch_numpy)
...@@ -629,7 +642,7 @@ def preprocess(is_train=False): ...@@ -629,7 +642,7 @@ def preprocess(is_train=False):
'CLS', 'PGNet', 'Distillation', 'NRTR', 'TableAttn', 'SAR', 'PSE', 'CLS', 'PGNet', 'Distillation', 'NRTR', 'TableAttn', 'SAR', 'PSE',
'SEED', 'SDMGR', 'LayoutXLM', 'LayoutLM', 'LayoutLMv2', 'PREN', 'FCE', 'SEED', 'SDMGR', 'LayoutXLM', 'LayoutLM', 'LayoutLMv2', 'PREN', 'FCE',
'SVTR', 'ViTSTR', 'ABINet', 'DB++', 'TableMaster', 'SPIN', 'VisionLAN', 'SVTR', 'ViTSTR', 'ABINet', 'DB++', 'TableMaster', 'SPIN', 'VisionLAN',
'Gestalt', 'SLANet', 'RobustScanner', 'CT', 'RFL', 'DRRG' 'Gestalt', 'SLANet', 'RobustScanner', 'CT', 'RFL', 'DRRG', 'CAN'
] ]
if use_xpu: if use_xpu:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册