未验证 提交 9a44e279 编写于 作者: X xiaoting 提交者: GitHub

Merge pull request #3798 from andyjpaddle/add_rec_sar

Add rec_sar
Global:
use_gpu: true
epoch_num: 5
log_smooth_window: 20
print_batch_step: 20
save_model_dir: ./sar_rec
save_epoch_step: 1
# evaluation is run every 2000 iterations
eval_batch_step: [0, 2000]
cal_metric_during_train: True
pretrained_model:
checkpoints:
save_inference_dir:
use_visualdl: False
infer_img:
# for data or label process
character_dict_path: ppocr/utils/dict90.txt
character_type: EN_symbol
max_text_length: 30
infer_mode: False
use_space_char: False
rm_symbol: True
save_res_path: ./output/rec/predicts_sar.txt
Optimizer:
name: Adam
beta1: 0.9
beta2: 0.999
lr:
name: Piecewise
decay_epochs: [3, 4]
values: [0.001, 0.0001, 0.00001]
regularizer:
name: 'L2'
factor: 0
Architecture:
model_type: rec
algorithm: SAR
Transform:
Backbone:
name: ResNet31
Head:
name: SARHead
Loss:
name: SARLoss
PostProcess:
name: SARLabelDecode
Metric:
name: RecMetric
Train:
dataset:
name: SimpleDataSet
label_file_list: ['./train_data/train_list.txt']
data_dir: ./train_data/
ratio_list: 1.0
transforms:
- DecodeImage: # load image
img_mode: BGR
channel_first: False
- SARLabelEncode: # Class handling label
- SARRecResizeImg:
image_shape: [3, 48, 48, 160] # h:48 w:[48,160]
width_downsample_ratio: 0.25
- KeepKeys:
keep_keys: ['image', 'label', 'valid_ratio'] # dataloader will return list in this order
loader:
shuffle: True
batch_size_per_card: 64
drop_last: True
num_workers: 8
use_shared_memory: False
Eval:
dataset:
name: LMDBDataSet
data_dir: ./eval_data/evaluation/
transforms:
- DecodeImage: # load image
img_mode: BGR
channel_first: False
- SARLabelEncode: # Class handling label
- SARRecResizeImg:
image_shape: [3, 48, 48, 160]
width_downsample_ratio: 0.25
- KeepKeys:
keep_keys: ['image', 'label', 'valid_ratio'] # dataloader will return list in this order
loader:
shuffle: False
drop_last: False
batch_size_per_card: 64
num_workers: 4
use_shared_memory: False
......@@ -45,6 +45,7 @@ PaddleOCR基于动态图开源的文本识别算法列表:
- [x] RARE([paper](https://arxiv.org/abs/1603.03915v1))[12]
- [x] SRN([paper](https://arxiv.org/abs/2003.12294))[5]
- [x] NRTR([paper](https://arxiv.org/abs/1806.00926v2))
- [x] SAR([paper](https://arxiv.org/abs/1811.00751v2))
参考[DTRB][3](https://arxiv.org/abs/1904.01906)文字识别训练和评估流程,使用MJSynth和SynthText两个文字识别数据集训练,在IIIT, SVT, IC03, IC13, IC15, SVTP, CUTE数据集上进行评估,算法效果如下:
......@@ -60,6 +61,6 @@ PaddleOCR基于动态图开源的文本识别算法列表:
|RARE|Resnet34_vd|83.6%|rec_r34_vd_tps_bilstm_att |[下载链接](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_r34_vd_tps_bilstm_att_v2.0_train.tar)|
|SRN|Resnet50_vd_fpn| 88.52% | rec_r50fpn_vd_none_srn | [下载链接](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_r50_vd_srn_train.tar) |
|NRTR|NRTR_MTB| 84.3% | rec_mtb_nrtr | [下载链接](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_mtb_nrtr_train.tar) |
|SAR|Resnet31| 87.2% | rec_r31_sar | [下载链接](https://paddleocr.bj.bcebos.com/dygraph_v2.1/rec/rec_r31_sar_train.tar) |
PaddleOCR文本识别算法的训练和使用请参考文档教程中[模型训练/评估中的文本识别部分](./recognition.md)
......@@ -86,7 +86,10 @@ train_data/rec/train/word_002.jpg 用科技让复杂的世界更简单
若您本地没有数据集,可以在官网下载 [ICDAR2015](http://rrc.cvc.uab.es/?ch=4&com=downloads) 数据,用于快速验证。也可以参考[DTRB](https://github.com/clovaai/deep-text-recognition-benchmark#download-lmdb-dataset-for-traininig-and-evaluation-from-here) ,下载 benchmark 所需的lmdb格式数据集。
如果希望复现SAR的论文指标,需要下载[SynthAdd](https://pan.baidu.com/share/init?surl=uV0LtoNmcxbO-0YA7Ch4dg), 提取码:627x。此外,真实数据集icdar2013, icdar2015, cocotext, IIIT5也作为训练数据的一部分。具体数据细节可以参考论文SAR。
如果你使用的是icdar2015的公开数据集,PaddleOCR 提供了一份用于训练 ICDAR2015 数据集的标签文件,通过以下方式下载:
```
# 训练集标签
wget -P ./train_data/ic15_data https://paddleocr.bj.bcebos.com/dataset/rec_gt_train.txt
......@@ -230,6 +233,7 @@ PaddleOCR支持训练和评估交替进行, 可以在 `configs/rec/rec_icdar15_t
| rec_r34_vd_tps_bilstm_att.yml | CRNN | Resnet34_vd | TPS | BiLSTM | att |
| rec_r50fpn_vd_none_srn.yml | SRN | Resnet50_fpn_vd | None | rnn | srn |
| rec_mtb_nrtr.yml | NRTR | nrtr_mtb | None | transformer encoder | transformer decoder |
| rec_r31_sar.yml | SAR | ResNet31 | None | LSTM encoder | LSTM decoder |
训练中文数据,推荐使用[rec_chinese_lite_train_v2.0.yml](../../configs/rec/ch_ppocr_v2.0/rec_chinese_lite_train_v2.0.yml),如您希望尝试其他算法在中文数据集上的效果,请参考下列说明修改配置文件:
......
......@@ -47,6 +47,7 @@ PaddleOCR open-source text recognition algorithms list:
- [x] RARE([paper](https://arxiv.org/abs/1603.03915v1))[12]
- [x] SRN([paper](https://arxiv.org/abs/2003.12294))[5]
- [x] NRTR([paper](https://arxiv.org/abs/1806.00926v2))
- [x] SAR([paper](https://arxiv.org/abs/1811.00751v2))
Refer to [DTRB](https://arxiv.org/abs/1904.01906), the training and evaluation result of these above text recognition (using MJSynth and SynthText for training, evaluate on IIIT, SVT, IC03, IC13, IC15, SVTP, CUTE) is as follow:
......@@ -62,5 +63,6 @@ Refer to [DTRB](https://arxiv.org/abs/1904.01906), the training and evaluation r
|RARE|Resnet34_vd|83.6%|rec_r34_vd_tps_bilstm_att |[Download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_r34_vd_tps_bilstm_att_v2.0_train.tar)|
|SRN|Resnet50_vd_fpn| 88.52% | rec_r50fpn_vd_none_srn |[Download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_r50_vd_srn_train.tar)|
|NRTR|NRTR_MTB| 84.3% | rec_mtb_nrtr | [Download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_mtb_nrtr_train.tar) |
|SAR|Resnet31| 87.2% | rec_r31_sar | [Download link](https://paddleocr.bj.bcebos.com/dygraph_v2.1/rec/rec_r31_sar_train.tar) |
Please refer to the document for training guide and use of PaddleOCR text recognition algorithms [Text recognition model training/evaluation/prediction](./recognition_en.md)
......@@ -92,6 +92,8 @@ Similar to the training set, the test set also needs to be provided a folder con
If you do not have a dataset locally, you can download it on the official website [icdar2015](http://rrc.cvc.uab.es/?ch=4&com=downloads).
Also refer to [DTRB](https://github.com/clovaai/deep-text-recognition-benchmark#download-lmdb-dataset-for-traininig-and-evaluation-from-here) ,download the lmdb format dataset required for benchmark
If you want to reproduce the paper SAR, you need to download extra dataset [SynthAdd](https://pan.baidu.com/share/init?surl=uV0LtoNmcxbO-0YA7Ch4dg), extraction code: 627x. Besides, icdar2013, icdar2015, cocotext, IIIT5k datasets are also used to train. For specific details, please refer to the paper SAR.
PaddleOCR provides label files for training the icdar2015 dataset, which can be downloaded in the following ways:
```
......@@ -236,6 +238,8 @@ If the evaluation set is large, the test will be time-consuming. It is recommend
| rec_r34_vd_tps_bilstm_att.yml | CRNN | Resnet34_vd | TPS | BiLSTM | att |
| rec_r50fpn_vd_none_srn.yml | SRN | Resnet50_fpn_vd | None | rnn | srn |
| rec_mtb_nrtr.yml | NRTR | nrtr_mtb | None | transformer encoder | transformer decoder |
| rec_r31_sar.yml | SAR | ResNet31 | None | LSTM encoder | LSTM decoder |
For training Chinese data, it is recommended to use
[rec_chinese_lite_train_v2.0.yml](../../configs/rec/ch_ppocr_v2.0/rec_chinese_lite_train_v2.0.yml). If you want to try the result of other algorithms on the Chinese data set, please refer to the following instructions to modify the configuration file:
......
......@@ -21,7 +21,7 @@ from .make_border_map import MakeBorderMap
from .make_shrink_map import MakeShrinkMap
from .random_crop_data import EastRandomCropData, PSERandomCrop
from .rec_img_aug import RecAug, RecResizeImg, ClsResizeImg, SRNRecResizeImg, NRTRRecResizeImg
from .rec_img_aug import RecAug, RecResizeImg, ClsResizeImg, SRNRecResizeImg, NRTRRecResizeImg, SARRecResizeImg
from .randaugment import RandAugment
from .copy_paste import CopyPaste
from .operators import *
......
......@@ -549,3 +549,49 @@ class TableLabelEncode(object):
assert False, "Unsupport type %s in char_or_elem" \
% char_or_elem
return idx
class SARLabelEncode(BaseRecLabelEncode):
""" Convert between text-label and text-index """
def __init__(self,
max_text_length,
character_dict_path=None,
character_type='ch',
use_space_char=False,
**kwargs):
super(SARLabelEncode,
self).__init__(max_text_length, character_dict_path,
character_type, use_space_char)
def add_special_char(self, dict_character):
beg_end_str = "<BOS/EOS>"
unknown_str = "<UKN>"
padding_str = "<PAD>"
dict_character = dict_character + [unknown_str]
self.unknown_idx = len(dict_character) - 1
dict_character = dict_character + [beg_end_str]
self.start_idx = len(dict_character) - 1
self.end_idx = len(dict_character) - 1
dict_character = dict_character + [padding_str]
self.padding_idx = len(dict_character) - 1
return dict_character
def __call__(self, data):
text = data['label']
text = self.encode(text)
if text is None:
return None
if len(text) >= self.max_text_len - 1:
return None
data['length'] = np.array(len(text))
target = [self.start_idx] + text + [self.end_idx]
padded_text = [self.padding_idx for _ in range(self.max_text_len)]
padded_text[:len(target)] = target
data['label'] = np.array(padded_text)
return data
def get_ignored_tokens(self):
return [self.padding_idx]
......@@ -102,6 +102,56 @@ class SRNRecResizeImg(object):
return data
class SARRecResizeImg(object):
def __init__(self, image_shape, width_downsample_ratio=0.25, **kwargs):
self.image_shape = image_shape
self.width_downsample_ratio = width_downsample_ratio
def __call__(self, data):
img = data['image']
norm_img, resize_shape, pad_shape, valid_ratio = resize_norm_img_sar(img, self.image_shape, self.width_downsample_ratio)
data['image'] = norm_img
data['resized_shape'] = resize_shape
data['pad_shape'] = pad_shape
data['valid_ratio'] = valid_ratio
return data
def resize_norm_img_sar(img, image_shape, width_downsample_ratio=0.25):
imgC, imgH, imgW_min, imgW_max = image_shape
h = img.shape[0]
w = img.shape[1]
valid_ratio = 1.0
# make sure new_width is an integral multiple of width_divisor.
width_divisor = int(1 / width_downsample_ratio)
# resize
ratio = w / float(h)
resize_w = math.ceil(imgH * ratio)
if resize_w % width_divisor != 0:
resize_w = round(resize_w / width_divisor) * width_divisor
if imgW_min is not None:
resize_w = max(imgW_min, resize_w)
if imgW_max is not None:
valid_ratio = min(1.0, 1.0 * resize_w / imgW_max)
resize_w = min(imgW_max, resize_w)
resized_image = cv2.resize(img, (resize_w, imgH))
resized_image = resized_image.astype('float32')
# norm
if image_shape[0] == 1:
resized_image = resized_image / 255
resized_image = resized_image[np.newaxis, :]
else:
resized_image = resized_image.transpose((2, 0, 1)) / 255
resized_image -= 0.5
resized_image /= 0.5
resize_shape = resized_image.shape
padding_im = -1.0 * np.ones((imgC, imgH, imgW_max), dtype=np.float32)
padding_im[:, :, 0:resize_w] = resized_image
pad_shape = padding_im.shape
return padding_im, resize_shape, pad_shape, valid_ratio
def resize_norm_img(img, image_shape):
imgC, imgH, imgW = image_shape
h = img.shape[0]
......
......@@ -26,6 +26,7 @@ from .rec_ctc_loss import CTCLoss
from .rec_att_loss import AttentionLoss
from .rec_srn_loss import SRNLoss
from .rec_nrtr_loss import NRTRLoss
from .rec_sar_loss import SARLoss
# cls loss
from .cls_loss import ClsLoss
......@@ -44,7 +45,7 @@ from .table_att_loss import TableAttentionLoss
def build_loss(config):
support_dict = [
'DBLoss', 'EASTLoss', 'SASTLoss', 'CTCLoss', 'ClsLoss', 'AttentionLoss',
'SRNLoss', 'PGLoss', 'CombinedLoss', 'NRTRLoss', 'TableAttentionLoss'
'SRNLoss', 'PGLoss', 'CombinedLoss', 'NRTRLoss', 'TableAttentionLoss', 'SARLoss'
]
config = copy.deepcopy(config)
......
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import paddle
from paddle import nn
class SARLoss(nn.Layer):
def __init__(self, **kwargs):
super(SARLoss, self).__init__()
self.loss_func = paddle.nn.loss.CrossEntropyLoss(reduction="mean", ignore_index=96)
def forward(self, predicts, batch):
predict = predicts[:, :-1, :] # ignore last index of outputs to be in same seq_len with targets
label = batch[1].astype("int64")[:, 1:] # ignore first index of target in loss calculation
batch_size, num_steps, num_classes = predict.shape[0], predict.shape[
1], predict.shape[2]
assert len(label.shape) == len(list(predict.shape)) - 1, \
"The target's shape and inputs's shape is [N, d] and [N, num_steps]"
inputs = paddle.reshape(predict, [-1, num_classes])
targets = paddle.reshape(label, [-1])
loss = self.loss_func(inputs, targets)
return {'loss': loss}
......@@ -27,8 +27,9 @@ def build_backbone(config, model_type):
from .rec_resnet_fpn import ResNetFPN
from .rec_mv1_enhance import MobileNetV1Enhance
from .rec_nrtr_mtb import MTB
from .rec_resnet_31 import ResNet31
support_dict = [
'MobileNetV1Enhance', 'MobileNetV3', 'ResNet', 'ResNetFPN', 'MTB'
'MobileNetV1Enhance', 'MobileNetV3', 'ResNet', 'ResNetFPN', 'MTB', "ResNet31"
]
elif model_type == "e2e":
from .e2e_resnet_vd_pg import ResNet
......
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import paddle
from paddle import ParamAttr
import paddle.nn as nn
import paddle.nn.functional as F
import numpy as np
__all__ = ["ResNet31"]
def conv3x3(in_channel, out_channel, stride=1):
return nn.Conv2D(
in_channel,
out_channel,
kernel_size=3,
stride=stride,
padding=1,
bias_attr=False
)
class BasicBlock(nn.Layer):
expansion = 1
def __init__(self, in_channels, channels, stride=1, downsample=False):
super().__init__()
self.conv1 = conv3x3(in_channels, channels, stride)
self.bn1 = nn.BatchNorm2D(channels)
self.relu = nn.ReLU()
self.conv2 = conv3x3(channels, channels)
self.bn2 = nn.BatchNorm2D(channels)
self.downsample = downsample
if downsample:
self.downsample = nn.Sequential(
nn.Conv2D(in_channels, channels * self.expansion, 1, stride, bias_attr=False),
nn.BatchNorm2D(channels * self.expansion),
)
else:
self.downsample = nn.Sequential()
self.stride = stride
def forward(self, x):
residual = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
if self.downsample:
residual = self.downsample(x)
out += residual
out = self.relu(out)
return out
class ResNet31(nn.Layer):
'''
Args:
in_channels (int): Number of channels of input image tensor.
layers (list[int]): List of BasicBlock number for each stage.
channels (list[int]): List of out_channels of Conv2d layer.
out_indices (None | Sequence[int]): Indices of output stages.
last_stage_pool (bool): If True, add `MaxPool2d` layer to last stage.
'''
def __init__(self,
in_channels=3,
layers=[1, 2, 5, 3],
channels=[64, 128, 256, 256, 512, 512, 512],
out_indices=None,
last_stage_pool=False):
super(ResNet31, self).__init__()
assert isinstance(in_channels, int)
assert isinstance(last_stage_pool, bool)
self.out_indices = out_indices
self.last_stage_pool = last_stage_pool
# conv 1 (Conv Conv)
self.conv1_1 = nn.Conv2D(in_channels, channels[0], kernel_size=3, stride=1, padding=1)
self.bn1_1 = nn.BatchNorm2D(channels[0])
self.relu1_1 = nn.ReLU()
self.conv1_2 = nn.Conv2D(channels[0], channels[1], kernel_size=3, stride=1, padding=1)
self.bn1_2 = nn.BatchNorm2D(channels[1])
self.relu1_2 = nn.ReLU()
# conv 2 (Max-pooling, Residual block, Conv)
self.pool2 = nn.MaxPool2D(kernel_size=2, stride=2, padding=0, ceil_mode=True)
self.block2 = self._make_layer(channels[1], channels[2], layers[0])
self.conv2 = nn.Conv2D(channels[2], channels[2], kernel_size=3, stride=1, padding=1)
self.bn2 = nn.BatchNorm2D(channels[2])
self.relu2 = nn.ReLU()
# conv 3 (Max-pooling, Residual block, Conv)
self.pool3 = nn.MaxPool2D(kernel_size=2, stride=2, padding=0, ceil_mode=True)
self.block3 = self._make_layer(channels[2], channels[3], layers[1])
self.conv3 = nn.Conv2D(channels[3], channels[3], kernel_size=3, stride=1, padding=1)
self.bn3 = nn.BatchNorm2D(channels[3])
self.relu3 = nn.ReLU()
# conv 4 (Max-pooling, Residual block, Conv)
self.pool4 = nn.MaxPool2D(kernel_size=(2, 1), stride=(2, 1), padding=0, ceil_mode=True)
self.block4 = self._make_layer(channels[3], channels[4], layers[2])
self.conv4 = nn.Conv2D(channels[4], channels[4], kernel_size=3, stride=1, padding=1)
self.bn4 = nn.BatchNorm2D(channels[4])
self.relu4 = nn.ReLU()
# conv 5 ((Max-pooling), Residual block, Conv)
self.pool5 = None
if self.last_stage_pool:
self.pool5 = nn.MaxPool2D(kernel_size=2, stride=2, padding=0, ceil_mode=True)
self.block5 = self._make_layer(channels[4], channels[5], layers[3])
self.conv5 = nn.Conv2D(channels[5], channels[5], kernel_size=3, stride=1, padding=1)
self.bn5 = nn.BatchNorm2D(channels[5])
self.relu5 = nn.ReLU()
self.out_channels = channels[-1]
def _make_layer(self, input_channels, output_channels, blocks):
layers = []
for _ in range(blocks):
downsample = None
if input_channels != output_channels:
downsample = nn.Sequential(
nn.Conv2D(
input_channels,
output_channels,
kernel_size=1,
stride=1,
bias_attr=False),
nn.BatchNorm2D(output_channels),
)
layers.append(BasicBlock(input_channels, output_channels, downsample=downsample))
input_channels = output_channels
return nn.Sequential(*layers)
def forward(self, x):
x = self.conv1_1(x)
x = self.bn1_1(x)
x = self.relu1_1(x)
x = self.conv1_2(x)
x = self.bn1_2(x)
x = self.relu1_2(x)
outs = []
for i in range(4):
layer_index = i + 2
pool_layer = getattr(self, f'pool{layer_index}')
block_layer = getattr(self, f'block{layer_index}')
conv_layer = getattr(self, f'conv{layer_index}')
bn_layer = getattr(self, f'bn{layer_index}')
relu_layer = getattr(self, f'relu{layer_index}')
if pool_layer is not None:
x = pool_layer(x)
x = block_layer(x)
x = conv_layer(x)
x = bn_layer(x)
x= relu_layer(x)
outs.append(x)
if self.out_indices is not None:
return tuple([outs[i] for i in self.out_indices])
return x
......@@ -27,12 +27,13 @@ def build_head(config):
from .rec_att_head import AttentionHead
from .rec_srn_head import SRNHead
from .rec_nrtr_head import Transformer
from .rec_sar_head import SARHead
# cls head
from .cls_head import ClsHead
support_dict = [
'DBHead', 'EASTHead', 'SASTHead', 'CTCHead', 'ClsHead', 'AttentionHead',
'SRNHead', 'PGHead', 'Transformer', 'TableAttentionHead'
'SRNHead', 'PGHead', 'Transformer', 'TableAttentionHead', 'SARHead'
]
#table head
......
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import math
import paddle
from paddle import ParamAttr
import paddle.nn as nn
import paddle.nn.functional as F
class SAREncoder(nn.Layer):
"""
Args:
enc_bi_rnn (bool): If True, use bidirectional RNN in encoder.
enc_drop_rnn (float): Dropout probability of RNN layer in encoder.
enc_gru (bool): If True, use GRU, else LSTM in encoder.
d_model (int): Dim of channels from backbone.
d_enc (int): Dim of encoder RNN layer.
mask (bool): If True, mask padding in RNN sequence.
"""
def __init__(self,
enc_bi_rnn=False,
enc_drop_rnn=0.1,
enc_gru=False,
d_model=512,
d_enc=512,
mask=True,
**kwargs):
super().__init__()
assert isinstance(enc_bi_rnn, bool)
assert isinstance(enc_drop_rnn, (int, float))
assert 0 <= enc_drop_rnn < 1.0
assert isinstance(enc_gru, bool)
assert isinstance(d_model, int)
assert isinstance(d_enc, int)
assert isinstance(mask, bool)
self.enc_bi_rnn = enc_bi_rnn
self.enc_drop_rnn = enc_drop_rnn
self.mask = mask
# LSTM Encoder
if enc_bi_rnn:
direction = 'bidirectional'
else:
direction = 'forward'
kwargs = dict(
input_size=d_model,
hidden_size=d_enc,
num_layers=2,
time_major=False,
dropout=enc_drop_rnn,
direction=direction)
if enc_gru:
self.rnn_encoder = nn.GRU(**kwargs)
else:
self.rnn_encoder = nn.LSTM(**kwargs)
# global feature transformation
encoder_rnn_out_size = d_enc * (int(enc_bi_rnn) + 1)
self.linear = nn.Linear(encoder_rnn_out_size, encoder_rnn_out_size)
def forward(self, feat, img_metas=None):
if img_metas is not None:
assert len(img_metas[0]) == feat.shape[0]
valid_ratios = None
if img_metas is not None and self.mask:
valid_ratios = img_metas[-1]
h_feat = feat.shape[2] # bsz c h w
feat_v = F.max_pool2d(
feat, kernel_size=(h_feat, 1), stride=1, padding=0)
feat_v = feat_v.squeeze(2) # bsz * C * W
feat_v = paddle.transpose(feat_v, perm=[0, 2, 1]) # bsz * W * C
holistic_feat = self.rnn_encoder(feat_v)[0] # bsz * T * C
if valid_ratios is not None:
valid_hf = []
T = holistic_feat.shape[1]
for i, valid_ratio in enumerate(valid_ratios):
valid_step = min(T, math.ceil(T * valid_ratio)) - 1
valid_hf.append(holistic_feat[i, valid_step, :])
valid_hf = paddle.stack(valid_hf, axis=0)
else:
valid_hf = holistic_feat[:, -1, :] # bsz * C
holistic_feat = self.linear(valid_hf) # bsz * C
return holistic_feat
class BaseDecoder(nn.Layer):
def __init__(self, **kwargs):
super().__init__()
def forward_train(self, feat, out_enc, targets, img_metas):
raise NotImplementedError
def forward_test(self, feat, out_enc, img_metas):
raise NotImplementedError
def forward(self,
feat,
out_enc,
label=None,
img_metas=None,
train_mode=True):
self.train_mode = train_mode
if train_mode:
return self.forward_train(feat, out_enc, label, img_metas)
return self.forward_test(feat, out_enc, img_metas)
class ParallelSARDecoder(BaseDecoder):
"""
Args:
out_channels (int): Output class number.
enc_bi_rnn (bool): If True, use bidirectional RNN in encoder.
dec_bi_rnn (bool): If True, use bidirectional RNN in decoder.
dec_drop_rnn (float): Dropout of RNN layer in decoder.
dec_gru (bool): If True, use GRU, else LSTM in decoder.
d_model (int): Dim of channels from backbone.
d_enc (int): Dim of encoder RNN layer.
d_k (int): Dim of channels of attention module.
pred_dropout (float): Dropout probability of prediction layer.
max_seq_len (int): Maximum sequence length for decoding.
mask (bool): If True, mask padding in feature map.
start_idx (int): Index of start token.
padding_idx (int): Index of padding token.
pred_concat (bool): If True, concat glimpse feature from
attention with holistic feature and hidden state.
"""
def __init__(
self,
out_channels, # 90 + unknown + start + padding
enc_bi_rnn=False,
dec_bi_rnn=False,
dec_drop_rnn=0.0,
dec_gru=False,
d_model=512,
d_enc=512,
d_k=64,
pred_dropout=0.1,
max_text_length=30,
mask=True,
pred_concat=True,
**kwargs):
super().__init__()
self.num_classes = out_channels
self.enc_bi_rnn = enc_bi_rnn
self.d_k = d_k
self.start_idx = out_channels - 2
self.padding_idx = out_channels - 1
self.max_seq_len = max_text_length
self.mask = mask
self.pred_concat = pred_concat
encoder_rnn_out_size = d_enc * (int(enc_bi_rnn) + 1)
decoder_rnn_out_size = encoder_rnn_out_size * (int(dec_bi_rnn) + 1)
# 2D attention layer
self.conv1x1_1 = nn.Linear(decoder_rnn_out_size, d_k)
self.conv3x3_1 = nn.Conv2D(
d_model, d_k, kernel_size=3, stride=1, padding=1)
self.conv1x1_2 = nn.Linear(d_k, 1)
# Decoder RNN layer
if dec_bi_rnn:
direction = 'bidirectional'
else:
direction = 'forward'
kwargs = dict(
input_size=encoder_rnn_out_size,
hidden_size=encoder_rnn_out_size,
num_layers=2,
time_major=False,
dropout=dec_drop_rnn,
direction=direction)
if dec_gru:
self.rnn_decoder = nn.GRU(**kwargs)
else:
self.rnn_decoder = nn.LSTM(**kwargs)
# Decoder input embedding
self.embedding = nn.Embedding(
self.num_classes,
encoder_rnn_out_size,
padding_idx=self.padding_idx)
# Prediction layer
self.pred_dropout = nn.Dropout(pred_dropout)
pred_num_classes = self.num_classes - 1
if pred_concat:
fc_in_channel = decoder_rnn_out_size + d_model + d_enc
else:
fc_in_channel = d_model
self.prediction = nn.Linear(fc_in_channel, pred_num_classes)
def _2d_attention(self,
decoder_input,
feat,
holistic_feat,
valid_ratios=None):
y = self.rnn_decoder(decoder_input)[0]
# y: bsz * (seq_len + 1) * hidden_size
attn_query = self.conv1x1_1(y) # bsz * (seq_len + 1) * attn_size
bsz, seq_len, attn_size = attn_query.shape
attn_query = paddle.unsqueeze(attn_query, axis=[3, 4])
# (bsz, seq_len + 1, attn_size, 1, 1)
attn_key = self.conv3x3_1(feat)
# bsz * attn_size * h * w
attn_key = attn_key.unsqueeze(1)
# bsz * 1 * attn_size * h * w
attn_weight = paddle.tanh(paddle.add(attn_key, attn_query))
# bsz * (seq_len + 1) * attn_size * h * w
attn_weight = paddle.transpose(attn_weight, perm=[0, 1, 3, 4, 2])
# bsz * (seq_len + 1) * h * w * attn_size
attn_weight = self.conv1x1_2(attn_weight)
# bsz * (seq_len + 1) * h * w * 1
bsz, T, h, w, c = attn_weight.shape
assert c == 1
if valid_ratios is not None:
# cal mask of attention weight
for i, valid_ratio in enumerate(valid_ratios):
valid_width = min(w, math.ceil(w * valid_ratio))
attn_weight[i, :, :, valid_width:, :] = float('-inf')
attn_weight = paddle.reshape(attn_weight, [bsz, T, -1])
attn_weight = F.softmax(attn_weight, axis=-1)
attn_weight = paddle.reshape(attn_weight, [bsz, T, h, w, c])
attn_weight = paddle.transpose(attn_weight, perm=[0, 1, 4, 2, 3])
# attn_weight: bsz * T * c * h * w
# feat: bsz * c * h * w
attn_feat = paddle.sum(paddle.multiply(feat.unsqueeze(1), attn_weight),
(3, 4),
keepdim=False)
# bsz * (seq_len + 1) * C
# Linear transformation
if self.pred_concat:
hf_c = holistic_feat.shape[-1]
holistic_feat = paddle.expand(
holistic_feat, shape=[bsz, seq_len, hf_c])
y = self.prediction(paddle.concat((y, attn_feat, holistic_feat), 2))
else:
y = self.prediction(attn_feat)
# bsz * (seq_len + 1) * num_classes
if self.train_mode:
y = self.pred_dropout(y)
return y
def forward_train(self, feat, out_enc, label, img_metas):
'''
img_metas: [label, valid_ratio]
'''
if img_metas is not None:
assert len(img_metas[0]) == feat.shape[0]
valid_ratios = None
if img_metas is not None and self.mask:
valid_ratios = img_metas[-1]
label = label.cuda()
lab_embedding = self.embedding(label)
# bsz * seq_len * emb_dim
out_enc = out_enc.unsqueeze(1)
# bsz * 1 * emb_dim
in_dec = paddle.concat((out_enc, lab_embedding), axis=1)
# bsz * (seq_len + 1) * C
out_dec = self._2d_attention(
in_dec, feat, out_enc, valid_ratios=valid_ratios)
# bsz * (seq_len + 1) * num_classes
return out_dec[:, 1:, :] # bsz * seq_len * num_classes
def forward_test(self, feat, out_enc, img_metas):
if img_metas is not None:
assert len(img_metas[0]) == feat.shape[0]
valid_ratios = None
if img_metas is not None and self.mask:
valid_ratios = img_metas[-1]
seq_len = self.max_seq_len
bsz = feat.shape[0]
start_token = paddle.full(
(bsz, ), fill_value=self.start_idx, dtype='int64')
# bsz
start_token = self.embedding(start_token)
# bsz * emb_dim
emb_dim = start_token.shape[1]
start_token = start_token.unsqueeze(1)
start_token = paddle.expand(start_token, shape=[bsz, seq_len, emb_dim])
# bsz * seq_len * emb_dim
out_enc = out_enc.unsqueeze(1)
# bsz * 1 * emb_dim
decoder_input = paddle.concat((out_enc, start_token), axis=1)
# bsz * (seq_len + 1) * emb_dim
outputs = []
for i in range(1, seq_len + 1):
decoder_output = self._2d_attention(
decoder_input, feat, out_enc, valid_ratios=valid_ratios)
char_output = decoder_output[:, i, :] # bsz * num_classes
char_output = F.softmax(char_output, -1)
outputs.append(char_output)
max_idx = paddle.argmax(char_output, axis=1, keepdim=False)
char_embedding = self.embedding(max_idx) # bsz * emb_dim
if i < seq_len:
decoder_input[:, i + 1, :] = char_embedding
outputs = paddle.stack(outputs, 1) # bsz * seq_len * num_classes
return outputs
class SARHead(nn.Layer):
def __init__(self,
out_channels,
enc_bi_rnn=False,
enc_drop_rnn=0.1,
enc_gru=False,
dec_bi_rnn=False,
dec_drop_rnn=0.0,
dec_gru=False,
d_k=512,
pred_dropout=0.1,
max_text_length=30,
pred_concat=True,
**kwargs):
super(SARHead, self).__init__()
# encoder module
self.encoder = SAREncoder(
enc_bi_rnn=enc_bi_rnn, enc_drop_rnn=enc_drop_rnn, enc_gru=enc_gru)
# decoder module
self.decoder = ParallelSARDecoder(
out_channels=out_channels,
enc_bi_rnn=enc_bi_rnn,
dec_bi_rnn=dec_bi_rnn,
dec_drop_rnn=dec_drop_rnn,
dec_gru=dec_gru,
d_k=d_k,
pred_dropout=pred_dropout,
max_text_length=max_text_length,
pred_concat=pred_concat)
def forward(self, feat, targets=None):
'''
img_metas: [label, valid_ratio]
'''
holistic_feat = self.encoder(feat, targets) # bsz c
if self.training:
label = targets[0] # label
label = paddle.to_tensor(label, dtype='int64')
final_out = self.decoder(
feat, holistic_feat, label, img_metas=targets)
if not self.training:
final_out = self.decoder(
feat,
holistic_feat,
label=None,
img_metas=targets,
train_mode=False)
# (bsz, seq_len, num_classes)
return final_out
......@@ -25,7 +25,7 @@ from .db_postprocess import DBPostProcess, DistillationDBPostProcess
from .east_postprocess import EASTPostProcess
from .sast_postprocess import SASTPostProcess
from .rec_postprocess import CTCLabelDecode, AttnLabelDecode, SRNLabelDecode, DistillationCTCLabelDecode, NRTRLabelDecode, \
TableLabelDecode
TableLabelDecode, SARLabelDecode
from .cls_postprocess import ClsPostProcess
from .pg_postprocess import PGPostProcess
......@@ -33,7 +33,8 @@ def build_post_process(config, global_config=None):
support_dict = [
'DBPostProcess', 'EASTPostProcess', 'SASTPostProcess', 'CTCLabelDecode',
'AttnLabelDecode', 'ClsPostProcess', 'SRNLabelDecode', 'PGPostProcess',
'DistillationCTCLabelDecode', 'NRTRLabelDecode', 'TableLabelDecode', 'DistillationDBPostProcess'
'DistillationCTCLabelDecode', 'TableLabelDecode',
'DistillationDBPostProcess', 'NRTRLabelDecode', 'SARLabelDecode'
]
config = copy.deepcopy(config)
......
......@@ -15,6 +15,7 @@ import numpy as np
import string
import paddle
from paddle.nn import functional as F
import re
class BaseRecLabelDecode(object):
......@@ -165,21 +166,21 @@ class NRTRLabelDecode(BaseRecLabelDecode):
use_space_char=True,
**kwargs):
super(NRTRLabelDecode, self).__init__(character_dict_path,
character_type, use_space_char)
character_type, use_space_char)
def __call__(self, preds, label=None, *args, **kwargs):
if preds.dtype == paddle.int64:
if isinstance(preds, paddle.Tensor):
preds = preds.numpy()
if preds[0][0]==2:
preds_idx = preds[:,1:]
if preds[0][0] == 2:
preds_idx = preds[:, 1:]
else:
preds_idx = preds
text = self.decode(preds_idx)
if label is None:
return text
label = self.decode(label[:,1:])
label = self.decode(label[:, 1:])
else:
if isinstance(preds, paddle.Tensor):
preds = preds.numpy()
......@@ -188,13 +189,13 @@ class NRTRLabelDecode(BaseRecLabelDecode):
text = self.decode(preds_idx, preds_prob, is_remove_duplicate=False)
if label is None:
return text
label = self.decode(label[:,1:])
label = self.decode(label[:, 1:])
return text, label
def add_special_char(self, dict_character):
dict_character = ['blank','<unk>','<s>','</s>'] + dict_character
dict_character = ['blank', '<unk>', '<s>', '</s>'] + dict_character
return dict_character
def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
""" convert text-index into text-label. """
result_list = []
......@@ -203,10 +204,11 @@ class NRTRLabelDecode(BaseRecLabelDecode):
char_list = []
conf_list = []
for idx in range(len(text_index[batch_idx])):
if text_index[batch_idx][idx] == 3: # end
if text_index[batch_idx][idx] == 3: # end
break
try:
char_list.append(self.character[int(text_index[batch_idx][idx])])
char_list.append(self.character[int(text_index[batch_idx][
idx])])
except:
continue
if text_prob is not None:
......@@ -218,7 +220,6 @@ class NRTRLabelDecode(BaseRecLabelDecode):
return result_list
class AttnLabelDecode(BaseRecLabelDecode):
""" Convert between text-label and text-index """
......@@ -256,7 +257,8 @@ class AttnLabelDecode(BaseRecLabelDecode):
if idx > 0 and text_index[batch_idx][idx - 1] == text_index[
batch_idx][idx]:
continue
char_list.append(self.character[int(text_index[batch_idx][idx])])
char_list.append(self.character[int(text_index[batch_idx][
idx])])
if text_prob is not None:
conf_list.append(text_prob[batch_idx][idx])
else:
......@@ -386,10 +388,9 @@ class SRNLabelDecode(BaseRecLabelDecode):
class TableLabelDecode(object):
""" """
def __init__(self,
character_dict_path,
**kwargs):
list_character, list_elem = self.load_char_elem_dict(character_dict_path)
def __init__(self, character_dict_path, **kwargs):
list_character, list_elem = self.load_char_elem_dict(
character_dict_path)
list_character = self.add_special_char(list_character)
list_elem = self.add_special_char(list_elem)
self.dict_character = {}
......@@ -408,7 +409,8 @@ class TableLabelDecode(object):
list_elem = []
with open(character_dict_path, "rb") as fin:
lines = fin.readlines()
substr = lines[0].decode('utf-8').strip("\n").strip("\r\n").split("\t")
substr = lines[0].decode('utf-8').strip("\n").strip("\r\n").split(
"\t")
character_num = int(substr[0])
elem_num = int(substr[1])
for cno in range(1, 1 + character_num):
......@@ -428,14 +430,14 @@ class TableLabelDecode(object):
def __call__(self, preds):
structure_probs = preds['structure_probs']
loc_preds = preds['loc_preds']
if isinstance(structure_probs,paddle.Tensor):
if isinstance(structure_probs, paddle.Tensor):
structure_probs = structure_probs.numpy()
if isinstance(loc_preds,paddle.Tensor):
if isinstance(loc_preds, paddle.Tensor):
loc_preds = loc_preds.numpy()
structure_idx = structure_probs.argmax(axis=2)
structure_probs = structure_probs.max(axis=2)
structure_str, structure_pos, result_score_list, result_elem_idx_list = self.decode(structure_idx,
structure_probs, 'elem')
structure_str, structure_pos, result_score_list, result_elem_idx_list = self.decode(
structure_idx, structure_probs, 'elem')
res_html_code_list = []
res_loc_list = []
batch_num = len(structure_str)
......@@ -450,8 +452,13 @@ class TableLabelDecode(object):
res_loc = np.array(res_loc)
res_html_code_list.append(res_html_code)
res_loc_list.append(res_loc)
return {'res_html_code': res_html_code_list, 'res_loc': res_loc_list, 'res_score_list': result_score_list,
'res_elem_idx_list': result_elem_idx_list,'structure_str_list':structure_str}
return {
'res_html_code': res_html_code_list,
'res_loc': res_loc_list,
'res_score_list': result_score_list,
'res_elem_idx_list': result_elem_idx_list,
'structure_str_list': structure_str
}
def decode(self, text_index, structure_probs, char_or_elem):
"""convert text-label into text-index.
......@@ -516,3 +523,82 @@ class TableLabelDecode(object):
assert False, "Unsupport type %s in char_or_elem" \
% char_or_elem
return idx
class SARLabelDecode(BaseRecLabelDecode):
""" Convert between text-label and text-index """
def __init__(self,
character_dict_path=None,
character_type='ch',
use_space_char=False,
**kwargs):
super(SARLabelDecode, self).__init__(character_dict_path,
character_type, use_space_char)
self.rm_symbol = kwargs.get('rm_symbol', False)
def add_special_char(self, dict_character):
beg_end_str = "<BOS/EOS>"
unknown_str = "<UKN>"
padding_str = "<PAD>"
dict_character = dict_character + [unknown_str]
self.unknown_idx = len(dict_character) - 1
dict_character = dict_character + [beg_end_str]
self.start_idx = len(dict_character) - 1
self.end_idx = len(dict_character) - 1
dict_character = dict_character + [padding_str]
self.padding_idx = len(dict_character) - 1
return dict_character
def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
""" convert text-index into text-label. """
result_list = []
ignored_tokens = self.get_ignored_tokens()
batch_size = len(text_index)
for batch_idx in range(batch_size):
char_list = []
conf_list = []
for idx in range(len(text_index[batch_idx])):
if text_index[batch_idx][idx] in ignored_tokens:
continue
if int(text_index[batch_idx][idx]) == int(self.end_idx):
if text_prob is None and idx == 0:
continue
else:
break
if is_remove_duplicate:
# only for predict
if idx > 0 and text_index[batch_idx][idx - 1] == text_index[
batch_idx][idx]:
continue
char_list.append(self.character[int(text_index[batch_idx][
idx])])
if text_prob is not None:
conf_list.append(text_prob[batch_idx][idx])
else:
conf_list.append(1)
text = ''.join(char_list)
if self.rm_symbol:
comp = re.compile('[^A-Z^a-z^0-9^\u4e00-\u9fa5]')
text = text.lower()
text = comp.sub('', text)
result_list.append((text, np.mean(conf_list)))
return result_list
def __call__(self, preds, label=None, *args, **kwargs):
if isinstance(preds, paddle.Tensor):
preds = preds.numpy()
preds_idx = preds.argmax(axis=2)
preds_prob = preds.max(axis=2)
text = self.decode(preds_idx, preds_prob, is_remove_duplicate=False)
if label is None:
return text
label = self.decode(label, is_remove_duplicate=False)
return text, label
def get_ignored_tokens(self):
return [self.padding_idx]
......@@ -55,6 +55,7 @@ def main():
model = build_model(config['Architecture'])
use_srn = config['Architecture']['algorithm'] == "SRN"
use_sar = config['Architecture']['algorithm'] == "SAR"
if "model_type" in config['Architecture'].keys():
model_type = config['Architecture']['model_type']
else:
......@@ -71,7 +72,7 @@ def main():
# start eval
metric = program.eval(model, valid_dataloader, post_process_class,
eval_class, model_type, use_srn)
eval_class, model_type, use_srn, use_sar)
logger.info('metric eval ***************')
for k, v in metric.items():
logger.info('{}:{}'.format(k, v))
......
......@@ -74,6 +74,10 @@ def main():
'image', 'encoder_word_pos', 'gsrm_word_pos',
'gsrm_slf_attn_bias1', 'gsrm_slf_attn_bias2'
]
elif config['Architecture']['algorithm'] == "SAR":
op[op_name]['keep_keys'] = [
'image', 'valid_ratio'
]
else:
op[op_name]['keep_keys'] = ['image']
transforms.append(op)
......@@ -106,11 +110,16 @@ def main():
paddle.to_tensor(gsrm_slf_attn_bias1_list),
paddle.to_tensor(gsrm_slf_attn_bias2_list)
]
if config['Architecture']['algorithm'] == "SAR":
valid_ratio = np.expand_dims(batch[-1], axis=0)
img_metas = [paddle.to_tensor(valid_ratio)]
images = np.expand_dims(batch[0], axis=0)
images = paddle.to_tensor(images)
if config['Architecture']['algorithm'] == "SRN":
preds = model(images, others)
elif config['Architecture']['algorithm'] == "SAR":
preds = model(images, img_metas)
else:
preds = model(images)
post_result = post_process_class(preds)
......
......@@ -187,7 +187,7 @@ def train(config,
use_srn = config['Architecture']['algorithm'] == "SRN"
use_nrtr = config['Architecture']['algorithm'] == "NRTR"
use_sar = config['Architecture']['algorithm'] == 'SAR'
try:
model_type = config['Architecture']['model_type']
except:
......@@ -215,7 +215,7 @@ def train(config,
images = batch[0]
if use_srn:
model_average = True
if use_srn or model_type == 'table' or use_nrtr:
if use_srn or model_type == 'table' or use_nrtr or use_sar:
preds = model(images, data=batch[1:])
else:
preds = model(images)
......@@ -279,7 +279,8 @@ def train(config,
post_process_class,
eval_class,
model_type,
use_srn=use_srn)
use_srn=use_srn,
use_sar=use_sar)
cur_metric_str = 'cur metric, {}'.format(', '.join(
['{}: {}'.format(k, v) for k, v in cur_metric.items()]))
logger.info(cur_metric_str)
......@@ -351,7 +352,8 @@ def eval(model,
post_process_class,
eval_class,
model_type,
use_srn=False):
use_srn=False,
use_sar=False):
model.eval()
with paddle.no_grad():
total_frame = 0.0
......@@ -364,7 +366,7 @@ def eval(model,
break
images = batch[0]
start = time.time()
if use_srn or model_type == 'table':
if use_srn or model_type == 'table' or use_sar:
preds = model(images, data=batch[1:])
else:
preds = model(images)
......@@ -400,7 +402,7 @@ def preprocess(is_train=False):
alg = config['Architecture']['algorithm']
assert alg in [
'EAST', 'DB', 'SAST', 'Rosetta', 'CRNN', 'STARNet', 'RARE', 'SRN',
'CLS', 'PGNet', 'Distillation', 'NRTR', 'TableAttn'
'CLS', 'PGNet', 'Distillation', 'NRTR', 'TableAttn', 'SAR'
]
device = 'gpu:{}'.format(dist.ParallelEnv().dev_id) if use_gpu else 'cpu'
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册