未验证 提交 5c664bf4 编写于 作者: X xiaoting 提交者: GitHub

Merge pull request #3721 from Topdu/dygraph

add rec_nrtr
Global:
use_gpu: True
epoch_num: 21
log_smooth_window: 20
print_batch_step: 10
save_model_dir: ./output/rec/nrtr/
save_epoch_step: 1
# evaluation is run every 2000 iterations
eval_batch_step: [0, 2000]
cal_metric_during_train: True
pretrained_model:
checkpoints:
save_inference_dir:
use_visualdl: False
infer_img: doc/imgs_words_en/word_10.png
# for data or label process
character_dict_path:
character_type: EN_symbol
max_text_length: 25
infer_mode: False
use_space_char: True
save_res_path: ./output/rec/predicts_nrtr.txt
Optimizer:
name: Adam
beta1: 0.9
beta2: 0.99
clip_norm: 5.0
lr:
name: Cosine
learning_rate: 0.0005
warmup_epoch: 2
regularizer:
name: 'L2'
factor: 0.
Architecture:
model_type: rec
algorithm: NRTR
in_channels: 1
Transform:
Backbone:
name: MTB
cnn_num: 2
Head:
name: Transformer
d_model: 512
num_encoder_layers: 6
beam_size: 10 # When Beam size is greater than 0, it means to use beam search when evaluation.
Loss:
name: NRTRLoss
smoothing: True
PostProcess:
name: NRTRLabelDecode
Metric:
name: RecMetric
main_indicator: acc
Train:
dataset:
name: LMDBDataSet
data_dir: ./train_data/data_lmdb_release/training/
transforms:
- NRTRDecodeImage: # load image
img_mode: BGR
channel_first: False
- NRTRLabelEncode: # Class handling label
- NRTRRecResizeImg:
image_shape: [100, 32]
resize_type: PIL # PIL or OpenCV
- KeepKeys:
keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order
loader:
shuffle: True
batch_size_per_card: 512
drop_last: True
num_workers: 8
Eval:
dataset:
name: LMDBDataSet
data_dir: ./train_data/data_lmdb_release/evaluation/
transforms:
- NRTRDecodeImage: # load image
img_mode: BGR
channel_first: False
- NRTRLabelEncode: # Class handling label
- NRTRRecResizeImg:
image_shape: [100, 32]
resize_type: PIL # PIL or OpenCV
- KeepKeys:
keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order
loader:
shuffle: False
drop_last: False
batch_size_per_card: 256
num_workers: 1
use_shared_memory: False
......@@ -44,6 +44,7 @@ PaddleOCR基于动态图开源的文本识别算法列表:
- [x] STAR-Net([paper](http://www.bmva.org/bmvc/2016/papers/paper043/index.html))[11]
- [x] RARE([paper](https://arxiv.org/abs/1603.03915v1))[12]
- [x] SRN([paper](https://arxiv.org/abs/2003.12294))[5]
- [x] NRTR([paper](https://arxiv.org/abs/1806.00926v2))
参考[DTRB][3](https://arxiv.org/abs/1904.01906)文字识别训练和评估流程,使用MJSynth和SynthText两个文字识别数据集训练,在IIIT, SVT, IC03, IC13, IC15, SVTP, CUTE数据集上进行评估,算法效果如下:
......@@ -58,6 +59,7 @@ PaddleOCR基于动态图开源的文本识别算法列表:
|RARE|MobileNetV3|82.5%|rec_mv3_tps_bilstm_att |[下载链接](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_mv3_tps_bilstm_att_v2.0_train.tar)|
|RARE|Resnet34_vd|83.6%|rec_r34_vd_tps_bilstm_att |[下载链接](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_r34_vd_tps_bilstm_att_v2.0_train.tar)|
|SRN|Resnet50_vd_fpn| 88.52% | rec_r50fpn_vd_none_srn | [下载链接](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_r50_vd_srn_train.tar) |
|NRTR|NRTR_MTB| 84.3% | rec_mtb_nrtr | [下载链接](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_mtb_nrtr_train.tar) |
PaddleOCR文本识别算法的训练和使用请参考文档教程中[模型训练/评估中的文本识别部分](./recognition.md)
......@@ -215,6 +215,7 @@ PaddleOCR支持训练和评估交替进行, 可以在 `configs/rec/rec_icdar15_t
| rec_mv3_tps_bilstm_att.yml | CRNN | Mobilenet_v3 | TPS | BiLSTM | att |
| rec_r34_vd_tps_bilstm_att.yml | CRNN | Resnet34_vd | TPS | BiLSTM | att |
| rec_r50fpn_vd_none_srn.yml | SRN | Resnet50_fpn_vd | None | rnn | srn |
| rec_mtb_nrtr.yml | NRTR | nrtr_mtb | None | transformer encoder | transformer decoder |
训练中文数据,推荐使用[rec_chinese_lite_train_v2.0.yml](../../configs/rec/ch_ppocr_v2.0/rec_chinese_lite_train_v2.0.yml),如您希望尝试其他算法在中文数据集上的效果,请参考下列说明修改配置文件:
......
......@@ -46,6 +46,7 @@ PaddleOCR open-source text recognition algorithms list:
- [x] STAR-Net([paper](http://www.bmva.org/bmvc/2016/papers/paper043/index.html))[11]
- [x] RARE([paper](https://arxiv.org/abs/1603.03915v1))[12]
- [x] SRN([paper](https://arxiv.org/abs/2003.12294))[5]
- [x] NRTR([paper](https://arxiv.org/abs/1806.00926v2))
Refer to [DTRB](https://arxiv.org/abs/1904.01906), the training and evaluation result of these above text recognition (using MJSynth and SynthText for training, evaluate on IIIT, SVT, IC03, IC13, IC15, SVTP, CUTE) is as follow:
......@@ -60,5 +61,6 @@ Refer to [DTRB](https://arxiv.org/abs/1904.01906), the training and evaluation r
|RARE|MobileNetV3|82.5%|rec_mv3_tps_bilstm_att |[Download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_mv3_tps_bilstm_att_v2.0_train.tar)|
|RARE|Resnet34_vd|83.6%|rec_r34_vd_tps_bilstm_att |[Download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_r34_vd_tps_bilstm_att_v2.0_train.tar)|
|SRN|Resnet50_vd_fpn| 88.52% | rec_r50fpn_vd_none_srn |[Download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_r50_vd_srn_train.tar)|
|NRTR|NRTR_MTB| 84.3% | rec_mtb_nrtr | [Download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_mtb_nrtr_train.tar) |
Please refer to the document for training guide and use of PaddleOCR text recognition algorithms [Text recognition model training/evaluation/prediction](./recognition_en.md)
......@@ -207,7 +207,7 @@ If the evaluation set is large, the test will be time-consuming. It is recommend
| rec_mv3_tps_bilstm_att.yml | CRNN | Mobilenet_v3 | TPS | BiLSTM | att |
| rec_r34_vd_tps_bilstm_att.yml | CRNN | Resnet34_vd | TPS | BiLSTM | att |
| rec_r50fpn_vd_none_srn.yml | SRN | Resnet50_fpn_vd | None | rnn | srn |
| rec_mtb_nrtr.yml | NRTR | nrtr_mtb | None | transformer encoder | transformer decoder |
For training Chinese data, it is recommended to use
[rec_chinese_lite_train_v2.0.yml](../../configs/rec/ch_ppocr_v2.0/rec_chinese_lite_train_v2.0.yml). If you want to try the result of other algorithms on the Chinese data set, please refer to the following instructions to modify the configuration file:
......
......@@ -21,7 +21,7 @@ from .make_border_map import MakeBorderMap
from .make_shrink_map import MakeShrinkMap
from .random_crop_data import EastRandomCropData, PSERandomCrop
from .rec_img_aug import RecAug, RecResizeImg, ClsResizeImg, SRNRecResizeImg
from .rec_img_aug import RecAug, RecResizeImg, ClsResizeImg, SRNRecResizeImg, NRTRRecResizeImg
from .randaugment import RandAugment
from .copy_paste import CopyPaste
from .operators import *
......
......@@ -161,6 +161,34 @@ class BaseRecLabelEncode(object):
return text_list
class NRTRLabelEncode(BaseRecLabelEncode):
""" Convert between text-label and text-index """
def __init__(self,
max_text_length,
character_dict_path=None,
character_type='EN_symbol',
use_space_char=False,
**kwargs):
super(NRTRLabelEncode,
self).__init__(max_text_length, character_dict_path,
character_type, use_space_char)
def __call__(self, data):
text = data['label']
text = self.encode(text)
if text is None:
return None
data['length'] = np.array(len(text))
text.insert(0, 2)
text.append(3)
text = text + [0] * (self.max_text_len - len(text))
data['label'] = np.array(text)
return data
def add_special_char(self, dict_character):
dict_character = ['blank','<unk>','<s>','</s>'] + dict_character
return dict_character
class CTCLabelEncode(BaseRecLabelEncode):
""" Convert between text-label and text-index """
......
......@@ -57,6 +57,38 @@ class DecodeImage(object):
return data
class NRTRDecodeImage(object):
""" decode image """
def __init__(self, img_mode='RGB', channel_first=False, **kwargs):
self.img_mode = img_mode
self.channel_first = channel_first
def __call__(self, data):
img = data['image']
if six.PY2:
assert type(img) is str and len(
img) > 0, "invalid input 'img' in DecodeImage"
else:
assert type(img) is bytes and len(
img) > 0, "invalid input 'img' in DecodeImage"
img = np.frombuffer(img, dtype='uint8')
img = cv2.imdecode(img, 1)
if img is None:
return None
if self.img_mode == 'GRAY':
img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
elif self.img_mode == 'RGB':
assert img.shape[2] == 3, 'invalid shape of image[%s]' % (img.shape)
img = img[:, :, ::-1]
img = cv2.cvtColor(img,cv2.COLOR_BGR2GRAY)
if self.channel_first:
img = img.transpose((2, 0, 1))
data['image'] = img
return data
class NormalizeImage(object):
""" normalize image such as substract mean, divide std
"""
......
......@@ -16,7 +16,7 @@ import math
import cv2
import numpy as np
import random
from PIL import Image
from .text_image_aug import tia_perspective, tia_stretch, tia_distort
......@@ -43,6 +43,25 @@ class ClsResizeImg(object):
return data
class NRTRRecResizeImg(object):
def __init__(self, image_shape, resize_type, **kwargs):
self.image_shape = image_shape
self.resize_type = resize_type
def __call__(self, data):
img = data['image']
if self.resize_type == 'PIL':
image_pil = Image.fromarray(np.uint8(img))
img = image_pil.resize(self.image_shape, Image.ANTIALIAS)
img = np.array(img)
if self.resize_type == 'OpenCV':
img = cv2.resize(img, self.image_shape)
norm_img = np.expand_dims(img, -1)
norm_img = norm_img.transpose((2, 0, 1))
data['image'] = norm_img.astype(np.float32) / 128. - 1.
return data
class RecResizeImg(object):
def __init__(self,
image_shape,
......
......@@ -25,7 +25,7 @@ from .det_sast_loss import SASTLoss
from .rec_ctc_loss import CTCLoss
from .rec_att_loss import AttentionLoss
from .rec_srn_loss import SRNLoss
from .rec_nrtr_loss import NRTRLoss
# cls loss
from .cls_loss import ClsLoss
......@@ -44,8 +44,9 @@ from .table_att_loss import TableAttentionLoss
def build_loss(config):
support_dict = [
'DBLoss', 'EASTLoss', 'SASTLoss', 'CTCLoss', 'ClsLoss', 'AttentionLoss',
'SRNLoss', 'PGLoss', 'CombinedLoss', 'TableAttentionLoss'
'SRNLoss', 'PGLoss', 'CombinedLoss', 'NRTRLoss', 'TableAttentionLoss'
]
config = copy.deepcopy(config)
module_name = config.pop('name')
assert module_name in support_dict, Exception('loss only support {}'.format(
......
import paddle
from paddle import nn
import paddle.nn.functional as F
class NRTRLoss(nn.Layer):
def __init__(self, smoothing=True, **kwargs):
super(NRTRLoss, self).__init__()
self.loss_func = nn.CrossEntropyLoss(reduction='mean', ignore_index=0)
self.smoothing = smoothing
def forward(self, pred, batch):
pred = pred.reshape([-1, pred.shape[2]])
max_len = batch[2].max()
tgt = batch[1][:, 1:2 + max_len]
tgt = tgt.reshape([-1])
if self.smoothing:
eps = 0.1
n_class = pred.shape[1]
one_hot = F.one_hot(tgt, pred.shape[1])
one_hot = one_hot * (1 - eps) + (1 - one_hot) * eps / (n_class - 1)
log_prb = F.log_softmax(pred, axis=1)
non_pad_mask = paddle.not_equal(
tgt, paddle.zeros(
tgt.shape, dtype='int64'))
loss = -(one_hot * log_prb).sum(axis=1)
loss = loss.masked_select(non_pad_mask).mean()
else:
loss = self.loss_func(pred, tgt)
return {'loss': loss}
......@@ -57,3 +57,4 @@ class RecMetric(object):
self.correct_num = 0
self.all_num = 0
self.norm_edit_dis = 0
......@@ -14,7 +14,6 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from paddle import nn
from ppocr.modeling.transforms import build_transform
from ppocr.modeling.backbones import build_backbone
......
......@@ -26,8 +26,9 @@ def build_backbone(config, model_type):
from .rec_resnet_vd import ResNet
from .rec_resnet_fpn import ResNetFPN
from .rec_mv1_enhance import MobileNetV1Enhance
from .rec_nrtr_mtb import MTB
support_dict = [
"MobileNetV1Enhance", "MobileNetV3", "ResNet", "ResNetFPN"
'MobileNetV1Enhance', 'MobileNetV3', 'ResNet', 'ResNetFPN', 'MTB'
]
elif model_type == "e2e":
from .e2e_resnet_vd_pg import ResNet
......
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from paddle import nn
class MTB(nn.Layer):
def __init__(self, cnn_num, in_channels):
super(MTB, self).__init__()
self.block = nn.Sequential()
self.out_channels = in_channels
self.cnn_num = cnn_num
if self.cnn_num == 2:
for i in range(self.cnn_num):
self.block.add_sublayer(
'conv_{}'.format(i),
nn.Conv2D(
in_channels=in_channels
if i == 0 else 32 * (2**(i - 1)),
out_channels=32 * (2**i),
kernel_size=3,
stride=2,
padding=1))
self.block.add_sublayer('relu_{}'.format(i), nn.ReLU())
self.block.add_sublayer('bn_{}'.format(i),
nn.BatchNorm2D(32 * (2**i)))
def forward(self, images):
x = self.block(images)
if self.cnn_num == 2:
# (b, w, h, c)
x = x.transpose([0, 3, 2, 1])
x_shape = x.shape
x = x.reshape([x_shape[0], x_shape[1], x_shape[2] * x_shape[3]])
return x
......@@ -26,12 +26,14 @@ def build_head(config):
from .rec_ctc_head import CTCHead
from .rec_att_head import AttentionHead
from .rec_srn_head import SRNHead
from .rec_nrtr_head import Transformer
# cls head
from .cls_head import ClsHead
support_dict = [
'DBHead', 'EASTHead', 'SASTHead', 'CTCHead', 'ClsHead', 'AttentionHead',
'SRNHead', 'PGHead', 'TableAttentionHead']
'SRNHead', 'PGHead', 'Transformer', 'TableAttentionHead'
]
#table head
from .table_att_head import TableAttentionHead
......
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
from paddle import nn
import paddle.nn.functional as F
from paddle.nn import Linear
from paddle.nn.initializer import XavierUniform as xavier_uniform_
from paddle.nn.initializer import Constant as constant_
from paddle.nn.initializer import XavierNormal as xavier_normal_
zeros_ = constant_(value=0.)
ones_ = constant_(value=1.)
class MultiheadAttention(nn.Layer):
"""Allows the model to jointly attend to information
from different representation subspaces.
See reference: Attention Is All You Need
.. math::
\text{MultiHead}(Q, K, V) = \text{Concat}(head_1,\dots,head_h)W^O
\text{where} head_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)
Args:
embed_dim: total dimension of the model
num_heads: parallel attention layers, or heads
"""
def __init__(self,
embed_dim,
num_heads,
dropout=0.,
bias=True,
add_bias_kv=False,
add_zero_attn=False):
super(MultiheadAttention, self).__init__()
self.embed_dim = embed_dim
self.num_heads = num_heads
self.dropout = dropout
self.head_dim = embed_dim // num_heads
assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
self.scaling = self.head_dim**-0.5
self.out_proj = Linear(embed_dim, embed_dim, bias_attr=bias)
self._reset_parameters()
self.conv1 = paddle.nn.Conv2D(
in_channels=embed_dim, out_channels=embed_dim, kernel_size=(1, 1))
self.conv2 = paddle.nn.Conv2D(
in_channels=embed_dim, out_channels=embed_dim, kernel_size=(1, 1))
self.conv3 = paddle.nn.Conv2D(
in_channels=embed_dim, out_channels=embed_dim, kernel_size=(1, 1))
def _reset_parameters(self):
xavier_uniform_(self.out_proj.weight)
def forward(self,
query,
key,
value,
key_padding_mask=None,
incremental_state=None,
need_weights=True,
static_kv=False,
attn_mask=None):
"""
Inputs of forward function
query: [target length, batch size, embed dim]
key: [sequence length, batch size, embed dim]
value: [sequence length, batch size, embed dim]
key_padding_mask: if True, mask padding based on batch size
incremental_state: if provided, previous time steps are cashed
need_weights: output attn_output_weights
static_kv: key and value are static
Outputs of forward function
attn_output: [target length, batch size, embed dim]
attn_output_weights: [batch size, target length, sequence length]
"""
tgt_len, bsz, embed_dim = query.shape
assert embed_dim == self.embed_dim
assert list(query.shape) == [tgt_len, bsz, embed_dim]
assert key.shape == value.shape
q = self._in_proj_q(query)
k = self._in_proj_k(key)
v = self._in_proj_v(value)
q *= self.scaling
q = q.reshape([tgt_len, bsz * self.num_heads, self.head_dim]).transpose(
[1, 0, 2])
k = k.reshape([-1, bsz * self.num_heads, self.head_dim]).transpose(
[1, 0, 2])
v = v.reshape([-1, bsz * self.num_heads, self.head_dim]).transpose(
[1, 0, 2])
src_len = k.shape[1]
if key_padding_mask is not None:
assert key_padding_mask.shape[0] == bsz
assert key_padding_mask.shape[1] == src_len
attn_output_weights = paddle.bmm(q, k.transpose([0, 2, 1]))
assert list(attn_output_weights.
shape) == [bsz * self.num_heads, tgt_len, src_len]
if attn_mask is not None:
attn_mask = attn_mask.unsqueeze(0)
attn_output_weights += attn_mask
if key_padding_mask is not None:
attn_output_weights = attn_output_weights.reshape(
[bsz, self.num_heads, tgt_len, src_len])
key = key_padding_mask.unsqueeze(1).unsqueeze(2).astype('float32')
y = paddle.full(shape=key.shape, dtype='float32', fill_value='-inf')
y = paddle.where(key == 0., key, y)
attn_output_weights += y
attn_output_weights = attn_output_weights.reshape(
[bsz * self.num_heads, tgt_len, src_len])
attn_output_weights = F.softmax(
attn_output_weights.astype('float32'),
axis=-1,
dtype=paddle.float32 if attn_output_weights.dtype == paddle.float16
else attn_output_weights.dtype)
attn_output_weights = F.dropout(
attn_output_weights, p=self.dropout, training=self.training)
attn_output = paddle.bmm(attn_output_weights, v)
assert list(attn_output.
shape) == [bsz * self.num_heads, tgt_len, self.head_dim]
attn_output = attn_output.transpose([1, 0, 2]).reshape(
[tgt_len, bsz, embed_dim])
attn_output = self.out_proj(attn_output)
if need_weights:
# average attention weights over heads
attn_output_weights = attn_output_weights.reshape(
[bsz, self.num_heads, tgt_len, src_len])
attn_output_weights = attn_output_weights.sum(
axis=1) / self.num_heads
else:
attn_output_weights = None
return attn_output, attn_output_weights
def _in_proj_q(self, query):
query = query.transpose([1, 2, 0])
query = paddle.unsqueeze(query, axis=2)
res = self.conv1(query)
res = paddle.squeeze(res, axis=2)
res = res.transpose([2, 0, 1])
return res
def _in_proj_k(self, key):
key = key.transpose([1, 2, 0])
key = paddle.unsqueeze(key, axis=2)
res = self.conv2(key)
res = paddle.squeeze(res, axis=2)
res = res.transpose([2, 0, 1])
return res
def _in_proj_v(self, value):
value = value.transpose([1, 2, 0]) #(1, 2, 0)
value = paddle.unsqueeze(value, axis=2)
res = self.conv3(value)
res = paddle.squeeze(res, axis=2)
res = res.transpose([2, 0, 1])
return res
此差异已折叠。
......@@ -24,18 +24,16 @@ __all__ = ['build_post_process']
from .db_postprocess import DBPostProcess, DistillationDBPostProcess
from .east_postprocess import EASTPostProcess
from .sast_postprocess import SASTPostProcess
from .rec_postprocess import CTCLabelDecode, AttnLabelDecode, SRNLabelDecode, DistillationCTCLabelDecode, \
from .rec_postprocess import CTCLabelDecode, AttnLabelDecode, SRNLabelDecode, DistillationCTCLabelDecode, NRTRLabelDecode, \
TableLabelDecode
from .cls_postprocess import ClsPostProcess
from .pg_postprocess import PGPostProcess
def build_post_process(config, global_config=None):
support_dict = [
'DBPostProcess', 'EASTPostProcess', 'SASTPostProcess', 'CTCLabelDecode',
'AttnLabelDecode', 'ClsPostProcess', 'SRNLabelDecode', 'PGPostProcess',
'DistillationCTCLabelDecode', 'TableLabelDecode',
'DistillationDBPostProcess'
'DistillationCTCLabelDecode', 'NRTRLabelDecode', 'TableLabelDecode', 'DistillationDBPostProcess'
]
config = copy.deepcopy(config)
......
......@@ -156,6 +156,69 @@ class DistillationCTCLabelDecode(CTCLabelDecode):
return output
class NRTRLabelDecode(BaseRecLabelDecode):
""" Convert between text-label and text-index """
def __init__(self,
character_dict_path=None,
character_type='EN_symbol',
use_space_char=True,
**kwargs):
super(NRTRLabelDecode, self).__init__(character_dict_path,
character_type, use_space_char)
def __call__(self, preds, label=None, *args, **kwargs):
if preds.dtype == paddle.int64:
if isinstance(preds, paddle.Tensor):
preds = preds.numpy()
if preds[0][0]==2:
preds_idx = preds[:,1:]
else:
preds_idx = preds
text = self.decode(preds_idx)
if label is None:
return text
label = self.decode(label[:,1:])
else:
if isinstance(preds, paddle.Tensor):
preds = preds.numpy()
preds_idx = preds.argmax(axis=2)
preds_prob = preds.max(axis=2)
text = self.decode(preds_idx, preds_prob, is_remove_duplicate=False)
if label is None:
return text
label = self.decode(label[:,1:])
return text, label
def add_special_char(self, dict_character):
dict_character = ['blank','<unk>','<s>','</s>'] + dict_character
return dict_character
def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
""" convert text-index into text-label. """
result_list = []
batch_size = len(text_index)
for batch_idx in range(batch_size):
char_list = []
conf_list = []
for idx in range(len(text_index[batch_idx])):
if text_index[batch_idx][idx] == 3: # end
break
try:
char_list.append(self.character[int(text_index[batch_idx][idx])])
except:
continue
if text_prob is not None:
conf_list.append(text_prob[batch_idx][idx])
else:
conf_list.append(1)
text = ''.join(char_list)
result_list.append((text.lower(), np.mean(conf_list)))
return result_list
class AttnLabelDecode(BaseRecLabelDecode):
""" Convert between text-label and text-index """
......@@ -193,8 +256,7 @@ class AttnLabelDecode(BaseRecLabelDecode):
if idx > 0 and text_index[batch_idx][idx - 1] == text_index[
batch_idx][idx]:
continue
char_list.append(self.character[int(text_index[batch_idx][
idx])])
char_list.append(self.character[int(text_index[batch_idx][idx])])
if text_prob is not None:
conf_list.append(text_prob[batch_idx][idx])
else:
......
......@@ -186,9 +186,11 @@ def train(config,
model.train()
use_srn = config['Architecture']['algorithm'] == "SRN"
try:
use_nrtr = config['Architecture']['algorithm'] == "NRTR"
try:
model_type = config['Architecture']['model_type']
except:
except:
model_type = None
if 'start_epoch' in best_model_dict:
......@@ -213,7 +215,7 @@ def train(config,
images = batch[0]
if use_srn:
model_average = True
if use_srn or model_type == 'table':
if use_srn or model_type == 'table' or use_nrtr:
preds = model(images, data=batch[1:])
else:
preds = model(images)
......@@ -398,7 +400,7 @@ def preprocess(is_train=False):
alg = config['Architecture']['algorithm']
assert alg in [
'EAST', 'DB', 'SAST', 'Rosetta', 'CRNN', 'STARNet', 'RARE', 'SRN',
'CLS', 'PGNet', 'Distillation', 'TableAttn'
'CLS', 'PGNet', 'Distillation', 'NRTR', 'TableAttn'
]
device = 'gpu:{}'.format(dist.ParallelEnv().dev_id) if use_gpu else 'cpu'
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册