Unverified commit 6e607a0f, authored by OneYearIsEnough, committed by GitHub

[Feature] Add PREN Scene Text Recognition Model (Accepted in CVPR 2021) (#5563)

* [Feature] add PREN scene text recognition model

* [Patch] Optimize yml File

* [Patch] Save Label/Pred Preprocess Time Cost

* [BugFix] Modify Shape Conversion to Fit Inference Model Export

* [Patch] ?

* [Patch] ?

* What's going on...
Parent 3df64d50
configs/rec/rec_efficientb3_pren.yml (new file)
Global:
  use_gpu: True
  epoch_num: 8
  log_smooth_window: 20
  print_batch_step: 5
  save_model_dir: ./output/rec/pren_new
  save_epoch_step: 3
  # evaluation is run every 2000 iterations after the 4000th iteration
  eval_batch_step: [4000, 2000]
  cal_metric_during_train: True
  pretrained_model:
  checkpoints:
  save_inference_dir:
  use_visualdl: False
  infer_img: doc/imgs_words/ch/word_1.jpg
  # for data or label process
  character_dict_path:
  max_text_length: &max_text_length 25
  infer_mode: False
  use_space_char: False
  save_res_path: ./output/rec/predicts_pren.txt

Optimizer:
  name: Adadelta
  lr:
    name: Piecewise
    decay_epochs: [2, 5, 7]
    values: [0.5, 0.1, 0.01, 0.001]

Architecture:
  model_type: rec
  algorithm: PREN
  in_channels: 3
  Backbone:
    name: EfficientNetb3_PREN
  Neck:
    name: PRENFPN
    n_r: 5
    d_model: 384
    max_len: *max_text_length
    dropout: 0.1
  Head:
    name: PRENHead

Loss:
  name: PRENLoss

PostProcess:
  name: PRENLabelDecode

Metric:
  name: RecMetric
  main_indicator: acc

Train:
  dataset:
    name: LMDBDataSet
    data_dir: ./train_data/data_lmdb_release/training/
    transforms:
      - DecodeImage:
          img_mode: BGR
          channel_first: False
      - PRENLabelEncode:
      - RecAug:
      - PRENResizeImg:
          image_shape: [64, 256]  # h,w
      - KeepKeys:
          keep_keys: ['image', 'label']
  loader:
    shuffle: True
    batch_size_per_card: 128
    drop_last: True
    num_workers: 8

Eval:
  dataset:
    name: LMDBDataSet
    data_dir: ./train_data/data_lmdb_release/validation/
    transforms:
      - DecodeImage:
          img_mode: BGR
          channel_first: False
      - PRENLabelEncode:
      - PRENResizeImg:
          image_shape: [64, 256]  # h,w
      - KeepKeys:
          keep_keys: ['image', 'label']
  loader:
    shuffle: False
    drop_last: False
    batch_size_per_card: 64
    num_workers: 8
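For context, a minimal sketch (not part of this commit) of how the Architecture section above drives model construction; the config path and the out_channels value are illustrative assumptions, since the training entry point normally derives out_channels from the character dict (36 lowercase alphanumerics plus <PAD>/<EOS>/<UNK> = 39 when character_dict_path is empty):

# Hedged sketch: build the PREN architecture from the yml above.
import yaml
from ppocr.modeling.architectures import build_model

cfg = yaml.safe_load(open('configs/rec/rec_efficientb3_pren.yml'))
arch = cfg['Architecture']
arch['Head']['out_channels'] = 39  # assumed vocab size; normally set from the char dict
model = build_model(arch)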
ppocr/data/imaug/__init__.py
@@ -22,7 +22,8 @@ from .make_shrink_map import MakeShrinkMap
from .random_crop_data import EastRandomCropData, RandomCropImgMask
from .make_pse_gt import MakePseGt
from .rec_img_aug import RecAug, RecResizeImg, ClsResizeImg, \
    SRNRecResizeImg, NRTRRecResizeImg, SARRecResizeImg, PRENResizeImg
from .randaugment import RandAugment
from .copy_paste import CopyPaste
from .ColorJitter import ColorJitter
......
ppocr/data/imaug/label_ops.py
@@ -785,6 +785,53 @@ class SARLabelEncode(BaseRecLabelEncode):
        return [self.padding_idx]


class PRENLabelEncode(BaseRecLabelEncode):
    def __init__(self,
                 max_text_length,
                 character_dict_path,
                 use_space_char=False,
                 **kwargs):
        super(PRENLabelEncode, self).__init__(
            max_text_length, character_dict_path, use_space_char)

    def add_special_char(self, dict_character):
        padding_str = '<PAD>'  # 0
        end_str = '<EOS>'  # 1
        unknown_str = '<UNK>'  # 2
        dict_character = [padding_str, end_str, unknown_str] + dict_character
        self.padding_idx = 0
        self.end_idx = 1
        self.unknown_idx = 2
        return dict_character

    def encode(self, text):
        if len(text) == 0 or len(text) >= self.max_text_len:
            return None
        if self.lower:
            text = text.lower()
        text_list = []
        for char in text:
            if char not in self.dict:
                text_list.append(self.unknown_idx)
            else:
                text_list.append(self.dict[char])
        text_list.append(self.end_idx)
        if len(text_list) < self.max_text_len:
            text_list += [self.padding_idx] * (
                self.max_text_len - len(text_list))
        return text_list

    def __call__(self, data):
        text = data['label']
        encoded_text = self.encode(text)
        if encoded_text is None:
            return None
        data['label'] = np.array(encoded_text)
        return data
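To make the index layout concrete, a hypothetical round trip, assuming character_dict_path=None so the table is [<PAD>, <EOS>, <UNK>] + '0'..'9' + 'a'..'z':

# Hedged example of the encoder's output layout.
from ppocr.data.imaug.label_ops import PRENLabelEncode

encoder = PRENLabelEncode(max_text_length=25, character_dict_path=None)
out = encoder({'label': 'ab1'})
# 'a' -> 13, 'b' -> 14, '1' -> 4, then <EOS> (1), then <PAD> (0) up to length 25:
# out['label'] -> array([13, 14, 4, 1, 0, 0, ..., 0])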
class VQATokenLabelEncode(object):
    """
    Label encode for NLP VQA methods
    """
......
ppocr/data/imaug/rec_img_aug.py
@@ -141,6 +141,25 @@ class SARRecResizeImg(object):
        return data


class PRENResizeImg(object):
    def __init__(self, image_shape, **kwargs):
        """
        According to the original paper's implementation, this is a hard
        resize, so you may want to adapt it to better fit your own task.
        """
        self.dst_h, self.dst_w = image_shape

    def __call__(self, data):
        img = data['image']
        resized_img = cv2.resize(
            img, (self.dst_w, self.dst_h), interpolation=cv2.INTER_LINEAR)
        resized_img = resized_img.transpose((2, 0, 1)) / 255
        resized_img -= 0.5
        resized_img /= 0.5
        data['image'] = resized_img.astype(np.float32)
        return data
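A quick check of the transform's contract (illustrative input size):

# Hedged usage: BGR HWC uint8 in, CHW float32 in [-1, 1] out.
import numpy as np
from ppocr.data.imaug.rec_img_aug import PRENResizeImg

op = PRENResizeImg(image_shape=[64, 256])  # h, w, matching the yml above
data = op({'image': np.random.randint(0, 256, (48, 160, 3), dtype=np.uint8)})
print(data['image'].shape)  # (3, 64, 256)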
def resize_norm_img_sar(img, image_shape, width_downsample_ratio=0.25):
    imgC, imgH, imgW_min, imgW_max = image_shape
    h = img.shape[0]
......
ppocr/losses/__init__.py
@@ -32,6 +32,7 @@ from .rec_srn_loss import SRNLoss
from .rec_nrtr_loss import NRTRLoss
from .rec_sar_loss import SARLoss
from .rec_aster_loss import AsterLoss
from .rec_pren_loss import PRENLoss

# cls loss
from .cls_loss import ClsLoss
@@ -58,7 +59,7 @@ def build_loss(config):
        'DBLoss', 'PSELoss', 'EASTLoss', 'SASTLoss', 'CTCLoss', 'ClsLoss',
        'AttentionLoss', 'SRNLoss', 'PGLoss', 'CombinedLoss', 'NRTRLoss',
        'TableAttentionLoss', 'SARLoss', 'AsterLoss', 'SDMGRLoss',
        'VQASerTokenLayoutLMLoss', 'LossFromOutput', 'PRENLoss'
    ]
    config = copy.deepcopy(config)
    module_name = config.pop('name')
......
ppocr/losses/rec_pren_loss.py (new file)
# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from paddle import nn
class PRENLoss(nn.Layer):
    def __init__(self, **kwargs):
        super(PRENLoss, self).__init__()
        # note: 0 is padding idx
        self.loss_func = nn.CrossEntropyLoss(reduction='mean', ignore_index=0)

    def forward(self, predicts, batch):
        loss = self.loss_func(predicts, batch[1].astype('int64'))
        return {'loss': loss}
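Since the head emits [batch, max_len, num_classes] and batch[1] carries the padded label indices, the criterion can be exercised directly; a small shape check with illustrative sizes:

# Hedged shape check; index 0 (<PAD>) is ignored by the criterion.
import paddle
from ppocr.losses.rec_pren_loss import PRENLoss

loss_fn = PRENLoss()
predicts = paddle.randn([2, 25, 39])          # [batch, max_len, num_classes]
labels = paddle.randint(0, 39, shape=[2, 25])  # stands in for batch[1]
print(loss_fn(predicts, [None, labels])['loss'])  # scalar mean loss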
ppocr/modeling/backbones/__init__.py
@@ -30,9 +30,10 @@ def build_backbone(config, model_type):
        from .rec_resnet_31 import ResNet31
        from .rec_resnet_aster import ResNet_ASTER
        from .rec_micronet import MicroNet
        from .rec_efficientb3_pren import EfficientNetb3_PREN
        support_dict = [
            'MobileNetV1Enhance', 'MobileNetV3', 'ResNet', 'ResNetFPN', 'MTB',
            "ResNet31", "ResNet_ASTER", 'MicroNet', 'EfficientNetb3_PREN'
        ]
    elif model_type == "e2e":
        from .e2e_resnet_vd_pg import ResNet
......
ppocr/modeling/backbones/rec_efficientb3_pren.py (new file)
# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Code is adapted from:
https://github.com/RuijieJ/pren/blob/main/Nets/EfficientNet.py
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import math
from collections import namedtuple
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
__all__ = ['EfficientNetb3_PREN']
class EffB3Params:
    @staticmethod
    def get_global_params():
        """
        The following are efficientnet b3's architecture hyperparameters, but
        to fit the scene text recognition task, the resolution (image_size)
        here is changed from 300 to 64.
        """
        GlobalParams = namedtuple('GlobalParams', [
            'drop_connect_rate', 'width_coefficient', 'depth_coefficient',
            'depth_divisor', 'image_size'
        ])
        global_params = GlobalParams(
            drop_connect_rate=0.3,
            width_coefficient=1.2,
            depth_coefficient=1.4,
            depth_divisor=8,
            image_size=64)
        return global_params

    @staticmethod
    def get_block_params():
        BlockParams = namedtuple('BlockParams', [
            'kernel_size', 'num_repeat', 'input_filters', 'output_filters',
            'expand_ratio', 'id_skip', 'se_ratio', 'stride'
        ])
        block_params = [
            BlockParams(3, 1, 32, 16, 1, True, 0.25, 1),
            BlockParams(3, 2, 16, 24, 6, True, 0.25, 2),
            BlockParams(5, 2, 24, 40, 6, True, 0.25, 2),
            BlockParams(3, 3, 40, 80, 6, True, 0.25, 2),
            BlockParams(5, 3, 80, 112, 6, True, 0.25, 1),
            BlockParams(5, 4, 112, 192, 6, True, 0.25, 2),
            BlockParams(3, 1, 192, 320, 6, True, 0.25, 1)
        ]
        return block_params


class EffUtils:
    @staticmethod
    def round_filters(filters, global_params):
        """Calculate and round number of filters based on width multiplier."""
        multiplier = global_params.width_coefficient
        if not multiplier:
            return filters
        divisor = global_params.depth_divisor
        filters *= multiplier
        new_filters = int(filters + divisor / 2) // divisor * divisor
        if new_filters < 0.9 * filters:
            new_filters += divisor
        return int(new_filters)

    @staticmethod
    def round_repeats(repeats, global_params):
        """Round number of repeats based on depth multiplier."""
        multiplier = global_params.depth_coefficient
        if not multiplier:
            return repeats
        return int(math.ceil(multiplier * repeats))
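A worked instance of the two helpers under the b3 coefficients above (width 1.2, depth 1.4, divisor 8):

# 32 filters * 1.2 = 38.4, snapped to the nearest multiple of 8 -> 40
# (kept, since 40 >= 0.9 * 38.4); 2 repeats -> ceil(2 * 1.4) = 3.
gp = EffB3Params.get_global_params()
assert EffUtils.round_filters(32, gp) == 40
assert EffUtils.round_repeats(2, gp) == 3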
class ConvBlock(nn.Layer):
    def __init__(self, block_params):
        super(ConvBlock, self).__init__()
        self.block_args = block_params
        self.has_se = (self.block_args.se_ratio is not None) and \
            (0 < self.block_args.se_ratio <= 1)
        self.id_skip = block_params.id_skip

        # expansion phase
        self.input_filters = self.block_args.input_filters
        output_filters = \
            self.block_args.input_filters * self.block_args.expand_ratio
        if self.block_args.expand_ratio != 1:
            self.expand_conv = nn.Conv2D(
                self.input_filters, output_filters, 1, bias_attr=False)
            self.bn0 = nn.BatchNorm(output_filters)

        # depthwise conv phase
        k = self.block_args.kernel_size
        s = self.block_args.stride
        self.depthwise_conv = nn.Conv2D(
            output_filters,
            output_filters,
            groups=output_filters,
            kernel_size=k,
            stride=s,
            padding='same',
            bias_attr=False)
        self.bn1 = nn.BatchNorm(output_filters)

        # squeeze and excitation layer, if desired
        if self.has_se:
            num_squeezed_channels = max(1,
                                        int(self.block_args.input_filters *
                                            self.block_args.se_ratio))
            self.se_reduce = nn.Conv2D(output_filters, num_squeezed_channels, 1)
            self.se_expand = nn.Conv2D(num_squeezed_channels, output_filters, 1)

        # output phase
        self.final_oup = self.block_args.output_filters
        self.project_conv = nn.Conv2D(
            output_filters, self.final_oup, 1, bias_attr=False)
        self.bn2 = nn.BatchNorm(self.final_oup)
        self.swish = nn.Swish()

    def drop_connect(self, inputs, p, training):
        if not training:
            return inputs
        batch_size = inputs.shape[0]
        keep_prob = 1 - p
        random_tensor = keep_prob
        random_tensor += paddle.rand([batch_size, 1, 1, 1], dtype=inputs.dtype)
        random_tensor = paddle.to_tensor(random_tensor, place=inputs.place)
        binary_tensor = paddle.floor(random_tensor)
        output = inputs / keep_prob * binary_tensor
        return output

    def forward(self, inputs, drop_connect_rate=None):
        # expansion and depthwise conv
        x = inputs
        if self.block_args.expand_ratio != 1:
            x = self.swish(self.bn0(self.expand_conv(inputs)))
        x = self.swish(self.bn1(self.depthwise_conv(x)))

        # squeeze and excitation
        if self.has_se:
            x_squeezed = F.adaptive_avg_pool2d(x, 1)
            x_squeezed = self.se_expand(self.swish(self.se_reduce(x_squeezed)))
            x = F.sigmoid(x_squeezed) * x
        x = self.bn2(self.project_conv(x))

        # skip connection and drop connect
        if self.id_skip and self.block_args.stride == 1 and \
                self.input_filters == self.final_oup:
            if drop_connect_rate:
                x = self.drop_connect(
                    x, p=drop_connect_rate, training=self.training)
            x = x + inputs
        return x
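A minimal exercise of one block (the first b3 stage; expand_ratio is 1, so the expansion conv is skipped, and 32 != 16 channels disables the residual path):

# Hedged usage of a single MBConv block.
import paddle
from ppocr.modeling.backbones.rec_efficientb3_pren import ConvBlock, EffB3Params

block = ConvBlock(EffB3Params.get_block_params()[0])  # k3, 32 -> 16, stride 1
print(block(paddle.randn([1, 32, 16, 64])).shape)  # [1, 16, 16, 64]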
class EfficientNetb3_PREN(nn.Layer):
    def __init__(self, in_channels):
        super(EfficientNetb3_PREN, self).__init__()
        self.blocks_params = EffB3Params.get_block_params()
        self.global_params = EffB3Params.get_global_params()
        self.out_channels = []
        # stem
        stem_channels = EffUtils.round_filters(32, self.global_params)
        self.conv_stem = nn.Conv2D(
            in_channels, stem_channels, 3, 2, padding='same', bias_attr=False)
        self.bn0 = nn.BatchNorm(stem_channels)

        self.blocks = []
        # extract three feature maps for the fpn from the efficientnetb3 backbone
        self.concerned_block_idxes = [7, 17, 25]
        concerned_idx = 0
        for i, block_params in enumerate(self.blocks_params):
            block_params = block_params._replace(
                input_filters=EffUtils.round_filters(block_params.input_filters,
                                                     self.global_params),
                output_filters=EffUtils.round_filters(
                    block_params.output_filters, self.global_params),
                num_repeat=EffUtils.round_repeats(block_params.num_repeat,
                                                  self.global_params))
            self.blocks.append(
                self.add_sublayer("{}-0".format(i), ConvBlock(block_params)))
            concerned_idx += 1
            if concerned_idx in self.concerned_block_idxes:
                self.out_channels.append(block_params.output_filters)
            if block_params.num_repeat > 1:
                block_params = block_params._replace(
                    input_filters=block_params.output_filters, stride=1)
                for j in range(block_params.num_repeat - 1):
                    self.blocks.append(
                        self.add_sublayer('{}-{}'.format(i, j + 1),
                                          ConvBlock(block_params)))
                    concerned_idx += 1
                    if concerned_idx in self.concerned_block_idxes:
                        self.out_channels.append(block_params.output_filters)

        self.swish = nn.Swish()

    def forward(self, inputs):
        outs = []
        x = self.swish(self.bn0(self.conv_stem(inputs)))
        for idx, block in enumerate(self.blocks):
            drop_connect_rate = self.global_params.drop_connect_rate
            if drop_connect_rate:
                drop_connect_rate *= float(idx) / len(self.blocks)
            x = block(x, drop_connect_rate=drop_connect_rate)
            if idx in self.concerned_block_idxes:
                outs.append(x)
        return outs
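A shape sanity check under the 64x256 training resolution; the values follow from the scaling arithmetic above (treat them as expected, not guaranteed):

# Hedged shape check: the three tapped stages should carry 48, 136, 384 channels.
import paddle
from ppocr.modeling.backbones.rec_efficientb3_pren import EfficientNetb3_PREN

backbone = EfficientNetb3_PREN(in_channels=3)
f3, f5, f7 = backbone(paddle.randn([1, 3, 64, 256]))
print(backbone.out_channels)         # [48, 136, 384]
print(f3.shape, f5.shape, f7.shape)  # [1, 48, 8, 32] [1, 136, 4, 16] [1, 384, 2, 8]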
ppocr/modeling/heads/__init__.py
@@ -30,6 +30,7 @@ def build_head(config):
    from .rec_nrtr_head import Transformer
    from .rec_sar_head import SARHead
    from .rec_aster_head import AsterHead
    from .rec_pren_head import PRENHead

    # cls head
    from .cls_head import ClsHead
@@ -42,7 +43,7 @@ def build_head(config):
    support_dict = [
        'DBHead', 'PSEHead', 'EASTHead', 'SASTHead', 'CTCHead', 'ClsHead',
        'AttentionHead', 'SRNHead', 'PGHead', 'Transformer',
        'TableAttentionHead', 'SARHead', 'AsterHead', 'SDMGRHead', 'PRENHead'
    ]
    #table head
......
ppocr/modeling/heads/rec_pren_head.py (new file)
# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from paddle import nn
from paddle.nn import functional as F
class PRENHead(nn.Layer):
    def __init__(self, in_channels, out_channels, **kwargs):
        super(PRENHead, self).__init__()
        self.linear = nn.Linear(in_channels, out_channels)

    def forward(self, x, targets=None):
        predicts = self.linear(x)

        if not self.training:
            predicts = F.softmax(predicts, axis=2)

        return predicts
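The head is a single linear projection over the last axis, with softmax applied only at inference; illustrative shapes:

# Hedged usage: project [batch, max_len, d_model] to class scores.
import paddle
from ppocr.modeling.heads.rec_pren_head import PRENHead

head = PRENHead(in_channels=384, out_channels=39)
head.eval()  # take the softmax branch
print(head(paddle.randn([2, 25, 384])).shape)  # [2, 25, 39]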
ppocr/modeling/necks/__init__.py
@@ -23,7 +23,11 @@ def build_neck(config):
    from .pg_fpn import PGFPN
    from .table_fpn import TableFPN
    from .fpn import FPN
    from .pren_fpn import PRENFPN
    support_dict = [
        'FPN', 'DBFPN', 'EASTFPN', 'SASTFPN', 'SequenceEncoder', 'PGFPN',
        'TableFPN', 'PRENFPN'
    ]

    module_name = config.pop('name')
    assert module_name in support_dict, Exception('neck only support {}'.format(
......
ppocr/modeling/necks/pren_fpn.py (new file)
# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Code is adapted from:
https://github.com/RuijieJ/pren/blob/main/Nets/Aggregation.py
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import paddle
from paddle import nn
import paddle.nn.functional as F
class PoolAggregate(nn.Layer):
    def __init__(self, n_r, d_in, d_middle=None, d_out=None):
        super(PoolAggregate, self).__init__()
        if not d_middle:
            d_middle = d_in
        if not d_out:
            d_out = d_in
        self.d_in = d_in
        self.d_middle = d_middle
        self.d_out = d_out
        self.act = nn.Swish()

        self.n_r = n_r
        self.aggs = self._build_aggs()

    def _build_aggs(self):
        aggs = []
        for i in range(self.n_r):
            aggs.append(
                self.add_sublayer(
                    '{}'.format(i),
                    nn.Sequential(
                        ('conv1', nn.Conv2D(
                            self.d_in, self.d_middle, 3, 2, 1, bias_attr=False)
                         ), ('bn1', nn.BatchNorm(self.d_middle)),
                        ('act', self.act), ('conv2', nn.Conv2D(
                            self.d_middle, self.d_out, 3, 2, 1, bias_attr=False
                        )), ('bn2', nn.BatchNorm(self.d_out)))))
        return aggs

    def forward(self, x):
        b = x.shape[0]
        outs = []
        for agg in self.aggs:
            y = agg(x)
            p = F.adaptive_avg_pool2d(y, 1)
            outs.append(p.reshape((b, 1, self.d_out)))
        out = paddle.concat(outs, 1)
        return out


class WeightAggregate(nn.Layer):
    def __init__(self, n_r, d_in, d_middle=None, d_out=None):
        super(WeightAggregate, self).__init__()
        if not d_middle:
            d_middle = d_in
        if not d_out:
            d_out = d_in

        self.n_r = n_r
        self.d_out = d_out
        self.act = nn.Swish()

        self.conv_n = nn.Sequential(
            ('conv1', nn.Conv2D(
                d_in, d_in, 3, 1, 1,
                bias_attr=False)), ('bn1', nn.BatchNorm(d_in)),
            ('act1', self.act), ('conv2', nn.Conv2D(
                d_in, n_r, 1, bias_attr=False)), ('bn2', nn.BatchNorm(n_r)),
            ('act2', nn.Sigmoid()))
        self.conv_d = nn.Sequential(
            ('conv1', nn.Conv2D(
                d_in, d_middle, 3, 1, 1,
                bias_attr=False)), ('bn1', nn.BatchNorm(d_middle)),
            ('act1', self.act), ('conv2', nn.Conv2D(
                d_middle, d_out, 1,
                bias_attr=False)), ('bn2', nn.BatchNorm(d_out)))

    def forward(self, x):
        b, _, h, w = x.shape

        hmaps = self.conv_n(x)
        fmaps = self.conv_d(x)
        r = paddle.bmm(
            hmaps.reshape((b, self.n_r, h * w)),
            fmaps.reshape((b, self.d_out, h * w)).transpose((0, 2, 1)))
        return r


class GCN(nn.Layer):
    def __init__(self, d_in, n_in, d_out=None, n_out=None, dropout=0.1):
        super(GCN, self).__init__()
        if not d_out:
            d_out = d_in
        if not n_out:
            n_out = d_in

        self.conv_n = nn.Conv1D(n_in, n_out, 1)
        self.linear = nn.Linear(d_in, d_out)
        self.dropout = nn.Dropout(dropout)
        self.act = nn.Swish()

    def forward(self, x):
        x = self.conv_n(x)
        x = self.dropout(self.linear(x))
        return self.act(x)


class PRENFPN(nn.Layer):
    def __init__(self, in_channels, n_r, d_model, max_len, dropout):
        super(PRENFPN, self).__init__()
        assert len(in_channels) == 3, "in_channels' length must be 3."
        c1, c2, c3 = in_channels  # the depths are from big to small
        # build fpn
        assert d_model % 3 == 0, "{} can't be divided by 3.".format(d_model)
        self.agg_p1 = PoolAggregate(n_r, c1, d_out=d_model // 3)
        self.agg_p2 = PoolAggregate(n_r, c2, d_out=d_model // 3)
        self.agg_p3 = PoolAggregate(n_r, c3, d_out=d_model // 3)

        self.agg_w1 = WeightAggregate(n_r, c1, 4 * c1, d_model // 3)
        self.agg_w2 = WeightAggregate(n_r, c2, 4 * c2, d_model // 3)
        self.agg_w3 = WeightAggregate(n_r, c3, 4 * c3, d_model // 3)

        self.gcn_pool = GCN(d_model, n_r, d_model, max_len, dropout)
        self.gcn_weight = GCN(d_model, n_r, d_model, max_len, dropout)

        self.out_channels = d_model

    def forward(self, inputs):
        f3, f5, f7 = inputs

        rp1 = self.agg_p1(f3)
        rp2 = self.agg_p2(f5)
        rp3 = self.agg_p3(f7)
        rp = paddle.concat([rp1, rp2, rp3], 2)  # [b,nr,d]

        rw1 = self.agg_w1(f3)
        rw2 = self.agg_w2(f5)
        rw3 = self.agg_w3(f7)
        rw = paddle.concat([rw1, rw2, rw3], 2)  # [b,nr,d]

        y1 = self.gcn_pool(rp)
        y2 = self.gcn_weight(rw)
        y = 0.5 * (y1 + y2)
        return y  # [b,max_len,d]
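Putting the neck together, with feature shapes mimicking the backbone check above and n_r, d_model, max_len matching the yml:

# Hedged end-to-end check: both aggregation branches yield [b, n_r, d_model],
# then the two GCNs map the n_r=5 axis to max_len=25.
import paddle
from ppocr.modeling.necks.pren_fpn import PRENFPN

fpn = PRENFPN(in_channels=[48, 136, 384], n_r=5, d_model=384, max_len=25,
              dropout=0.1)
feats = [paddle.randn([2, 48, 8, 32]), paddle.randn([2, 136, 4, 16]),
         paddle.randn([2, 384, 2, 8])]
print(fpn(feats).shape)  # [2, 25, 384]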
ppocr/postprocess/__init__.py
@@ -24,8 +24,9 @@ __all__ = ['build_post_process']
from .db_postprocess import DBPostProcess, DistillationDBPostProcess
from .east_postprocess import EASTPostProcess
from .sast_postprocess import SASTPostProcess
from .rec_postprocess import CTCLabelDecode, AttnLabelDecode, SRNLabelDecode, \
    DistillationCTCLabelDecode, TableLabelDecode, NRTRLabelDecode, SARLabelDecode, \
    SEEDLabelDecode, PRENLabelDecode
from .cls_postprocess import ClsPostProcess
from .pg_postprocess import PGPostProcess
from .vqa_token_ser_layoutlm_postprocess import VQASerTokenLayoutLMPostProcess
@@ -39,7 +40,7 @@ def build_post_process(config, global_config=None):
        'DistillationCTCLabelDecode', 'TableLabelDecode',
        'DistillationDBPostProcess', 'NRTRLabelDecode', 'SARLabelDecode',
        'SEEDLabelDecode', 'VQASerTokenLayoutLMPostProcess',
        'VQAReTokenLayoutLMPostProcess', 'PRENLabelDecode'
    ]
    if config['name'] == 'PSEPostProcess':
......
ppocr/postprocess/rec_postprocess.py
@@ -11,8 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import paddle
from paddle.nn import functional as F
import re
@@ -652,3 +652,63 @@ class SARLabelDecode(BaseRecLabelDecode):
    def get_ignored_tokens(self):
        return [self.padding_idx]


class PRENLabelDecode(BaseRecLabelDecode):
    """ Convert between text-label and text-index """

    def __init__(self, character_dict_path=None, use_space_char=False,
                 **kwargs):
        super(PRENLabelDecode, self).__init__(character_dict_path,
                                              use_space_char)

    def add_special_char(self, dict_character):
        padding_str = '<PAD>'  # 0
        end_str = '<EOS>'  # 1
        unknown_str = '<UNK>'  # 2
        dict_character = [padding_str, end_str, unknown_str] + dict_character
        self.padding_idx = 0
        self.end_idx = 1
        self.unknown_idx = 2
        return dict_character

    def decode(self, text_index, text_prob=None):
        """ convert text-index into text-label. """
        result_list = []
        batch_size = len(text_index)

        for batch_idx in range(batch_size):
            char_list = []
            conf_list = []
            for idx in range(len(text_index[batch_idx])):
                if text_index[batch_idx][idx] == self.end_idx:
                    break
                if text_index[batch_idx][idx] in \
                        [self.padding_idx, self.unknown_idx]:
                    continue
                char_list.append(self.character[int(text_index[batch_idx][
                    idx])])
                if text_prob is not None:
                    conf_list.append(text_prob[batch_idx][idx])
                else:
                    conf_list.append(1)

            text = ''.join(char_list)
            if len(text) > 0:
                result_list.append((text, np.mean(conf_list)))
            else:
                # here confidence of empty recog result is 1
                result_list.append(('', 1))
        return result_list

    def __call__(self, preds, label=None, *args, **kwargs):
        preds = preds.numpy()
        preds_idx = preds.argmax(axis=2)
        preds_prob = preds.max(axis=2)
        text = self.decode(preds_idx, preds_prob)
        if label is None:
            return text
        label = self.decode(label)
        return text, label
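Decoding mirrors the encoder's special tokens; an illustrative call (random scores, so the text is meaningless, but it shows the contract):

# Hedged usage: argmax over classes, stop at <EOS>, skip <PAD>/<UNK>,
# and average the per-character confidences.
import paddle
from ppocr.postprocess.rec_postprocess import PRENLabelDecode

decoder = PRENLabelDecode()  # default lowercase alphanumeric dict
preds = paddle.nn.functional.softmax(paddle.randn([1, 25, 39]), axis=2)
print(decoder(preds))  # [(text, mean_confidence)]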
tools/export_model.py
@@ -28,7 +28,6 @@ from ppocr.modeling.architectures import build_model
from ppocr.postprocess import build_post_process
from ppocr.metrics import build_metric
from ppocr.utils.save_load import load_model
import tools.program as program
......
@@ -55,6 +55,12 @@ def export_single_model(model, arch_config, save_path, logger):
                shape=[None, 3, 48, 160], dtype="float32"),
        ]
        model = to_static(model, input_spec=other_shape)
    elif arch_config["algorithm"] == "PREN":
        other_shape = [
            paddle.static.InputSpec(
                shape=[None, 3, 64, 512], dtype="float32"),
        ]
        model = to_static(model, input_spec=other_shape)
    else:
        infer_shape = [3, -1, -1]
        if arch_config["model_type"] == "rec":
......
tools/program.py
@@ -541,7 +541,7 @@ def preprocess(is_train=False):
    assert alg in [
        'EAST', 'DB', 'SAST', 'Rosetta', 'CRNN', 'STARNet', 'RARE', 'SRN',
        'CLS', 'PGNet', 'Distillation', 'NRTR', 'TableAttn', 'SAR', 'PSE',
        'SEED', 'SDMGR', 'LayoutXLM', 'LayoutLM', 'PREN'
    ]
    device = 'cpu'
......