提交 1623c17c 编写于 作者: T Topdu

add rec_nrtr

上级 b6f0a903
...@@ -3,22 +3,38 @@ Global: ...@@ -3,22 +3,38 @@ Global:
epoch_num: 21 epoch_num: 21
log_smooth_window: 20 log_smooth_window: 20
print_batch_step: 10 print_batch_step: 10
<<<<<<< HEAD
save_model_dir: ./output/rec/nrtr_final/ save_model_dir: ./output/rec/nrtr_final/
save_epoch_step: 1 save_epoch_step: 1
# evaluation is run every 2000 iterations # evaluation is run every 2000 iterations
eval_batch_step: [0, 2000] eval_batch_step: [0, 2000]
cal_metric_during_train: True cal_metric_during_train: True
=======
save_model_dir: ./output/rec/piloptimnrtr/
save_epoch_step: 1
# evaluation is run every 2000 iterations
eval_batch_step: [0, 2000]
cal_metric_during_train: False
>>>>>>> 9c67a7f... add rec_nrtr
pretrained_model: pretrained_model:
checkpoints: checkpoints:
save_inference_dir: save_inference_dir:
use_visualdl: False use_visualdl: False
infer_img: doc/imgs_words_en/word_10.png infer_img: doc/imgs_words_en/word_10.png
# for data or label process # for data or label process
<<<<<<< HEAD
character_dict_path: character_dict_path:
character_type: EN_symbol character_type: EN_symbol
max_text_length: 25 max_text_length: 25
infer_mode: False infer_mode: False
use_space_char: True use_space_char: True
=======
character_dict_path: ppocr/utils/dict_99.txt
character_type: dict_99
max_text_length: 25
infer_mode: False
use_space_char: False
>>>>>>> 9c67a7f... add rec_nrtr
save_res_path: ./output/rec/predicts_nrtr.txt save_res_path: ./output/rec/predicts_nrtr.txt
Optimizer: Optimizer:
......
...@@ -44,6 +44,7 @@ PaddleOCR基于动态图开源的文本识别算法列表: ...@@ -44,6 +44,7 @@ PaddleOCR基于动态图开源的文本识别算法列表:
- [x] STAR-Net([paper](http://www.bmva.org/bmvc/2016/papers/paper043/index.html))[11] - [x] STAR-Net([paper](http://www.bmva.org/bmvc/2016/papers/paper043/index.html))[11]
- [x] RARE([paper](https://arxiv.org/abs/1603.03915v1))[12] - [x] RARE([paper](https://arxiv.org/abs/1603.03915v1))[12]
- [x] SRN([paper](https://arxiv.org/abs/2003.12294))[5] - [x] SRN([paper](https://arxiv.org/abs/2003.12294))[5]
- [x] NRTR([paper](https://arxiv.org/abs/1806.00926v2)
参考[DTRB][3](https://arxiv.org/abs/1904.01906)文字识别训练和评估流程,使用MJSynth和SynthText两个文字识别数据集训练,在IIIT, SVT, IC03, IC13, IC15, SVTP, CUTE数据集上进行评估,算法效果如下: 参考[DTRB][3](https://arxiv.org/abs/1904.01906)文字识别训练和评估流程,使用MJSynth和SynthText两个文字识别数据集训练,在IIIT, SVT, IC03, IC13, IC15, SVTP, CUTE数据集上进行评估,算法效果如下:
...@@ -58,6 +59,7 @@ PaddleOCR基于动态图开源的文本识别算法列表: ...@@ -58,6 +59,7 @@ PaddleOCR基于动态图开源的文本识别算法列表:
|RARE|MobileNetV3|82.5%|rec_mv3_tps_bilstm_att |[下载链接](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_mv3_tps_bilstm_att_v2.0_train.tar)| |RARE|MobileNetV3|82.5%|rec_mv3_tps_bilstm_att |[下载链接](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_mv3_tps_bilstm_att_v2.0_train.tar)|
|RARE|Resnet34_vd|83.6%|rec_r34_vd_tps_bilstm_att |[下载链接](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_r34_vd_tps_bilstm_att_v2.0_train.tar)| |RARE|Resnet34_vd|83.6%|rec_r34_vd_tps_bilstm_att |[下载链接](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_r34_vd_tps_bilstm_att_v2.0_train.tar)|
|SRN|Resnet50_vd_fpn| 88.52% | rec_r50fpn_vd_none_srn | [下载链接](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_r50_vd_srn_train.tar) | |SRN|Resnet50_vd_fpn| 88.52% | rec_r50fpn_vd_none_srn | [下载链接](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_r50_vd_srn_train.tar) |
|NRTR|NRTR_MTB| 84.1% | rec_mtb_nrtr | [下载链接](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_mtb_nrtr_train.tar) |
PaddleOCR文本识别算法的训练和使用请参考文档教程中[模型训练/评估中的文本识别部分](./recognition.md) PaddleOCR文本识别算法的训练和使用请参考文档教程中[模型训练/评估中的文本识别部分](./recognition.md)
...@@ -215,6 +215,7 @@ PaddleOCR支持训练和评估交替进行, 可以在 `configs/rec/rec_icdar15_t ...@@ -215,6 +215,7 @@ PaddleOCR支持训练和评估交替进行, 可以在 `configs/rec/rec_icdar15_t
| rec_mv3_tps_bilstm_att.yml | CRNN | Mobilenet_v3 | TPS | BiLSTM | att | | rec_mv3_tps_bilstm_att.yml | CRNN | Mobilenet_v3 | TPS | BiLSTM | att |
| rec_r34_vd_tps_bilstm_att.yml | CRNN | Resnet34_vd | TPS | BiLSTM | att | | rec_r34_vd_tps_bilstm_att.yml | CRNN | Resnet34_vd | TPS | BiLSTM | att |
| rec_r50fpn_vd_none_srn.yml | SRN | Resnet50_fpn_vd | None | rnn | srn | | rec_r50fpn_vd_none_srn.yml | SRN | Resnet50_fpn_vd | None | rnn | srn |
| rec_mtb_nrtr.yml | NRTR | nrtr_mtb | None | transformer encoder | transformer decoder |
训练中文数据,推荐使用[rec_chinese_lite_train_v2.0.yml](../../configs/rec/ch_ppocr_v2.0/rec_chinese_lite_train_v2.0.yml),如您希望尝试其他算法在中文数据集上的效果,请参考下列说明修改配置文件: 训练中文数据,推荐使用[rec_chinese_lite_train_v2.0.yml](../../configs/rec/ch_ppocr_v2.0/rec_chinese_lite_train_v2.0.yml),如您希望尝试其他算法在中文数据集上的效果,请参考下列说明修改配置文件:
......
...@@ -60,5 +60,6 @@ Refer to [DTRB](https://arxiv.org/abs/1904.01906), the training and evaluation r ...@@ -60,5 +60,6 @@ Refer to [DTRB](https://arxiv.org/abs/1904.01906), the training and evaluation r
|RARE|MobileNetV3|82.5%|rec_mv3_tps_bilstm_att |[Download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_mv3_tps_bilstm_att_v2.0_train.tar)| |RARE|MobileNetV3|82.5%|rec_mv3_tps_bilstm_att |[Download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_mv3_tps_bilstm_att_v2.0_train.tar)|
|RARE|Resnet34_vd|83.6%|rec_r34_vd_tps_bilstm_att |[Download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_r34_vd_tps_bilstm_att_v2.0_train.tar)| |RARE|Resnet34_vd|83.6%|rec_r34_vd_tps_bilstm_att |[Download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_r34_vd_tps_bilstm_att_v2.0_train.tar)|
|SRN|Resnet50_vd_fpn| 88.52% | rec_r50fpn_vd_none_srn |[Download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_r50_vd_srn_train.tar)| |SRN|Resnet50_vd_fpn| 88.52% | rec_r50fpn_vd_none_srn |[Download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_r50_vd_srn_train.tar)|
|NRTR|NRTR_MTB| 84.1% | rec_mtb_nrtr | [Download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_mtb_nrtr_train.tar) |
Please refer to the document for training guide and use of PaddleOCR text recognition algorithms [Text recognition model training/evaluation/prediction](./recognition_en.md) Please refer to the document for training guide and use of PaddleOCR text recognition algorithms [Text recognition model training/evaluation/prediction](./recognition_en.md)
...@@ -207,7 +207,7 @@ If the evaluation set is large, the test will be time-consuming. It is recommend ...@@ -207,7 +207,7 @@ If the evaluation set is large, the test will be time-consuming. It is recommend
| rec_mv3_tps_bilstm_att.yml | CRNN | Mobilenet_v3 | TPS | BiLSTM | att | | rec_mv3_tps_bilstm_att.yml | CRNN | Mobilenet_v3 | TPS | BiLSTM | att |
| rec_r34_vd_tps_bilstm_att.yml | CRNN | Resnet34_vd | TPS | BiLSTM | att | | rec_r34_vd_tps_bilstm_att.yml | CRNN | Resnet34_vd | TPS | BiLSTM | att |
| rec_r50fpn_vd_none_srn.yml | SRN | Resnet50_fpn_vd | None | rnn | srn | | rec_r50fpn_vd_none_srn.yml | SRN | Resnet50_fpn_vd | None | rnn | srn |
| rec_mtb_nrtr.yml | NRTR | nrtr_mtb | None | transformer encoder | transformer decoder |
For training Chinese data, it is recommended to use For training Chinese data, it is recommended to use
[rec_chinese_lite_train_v2.0.yml](../../configs/rec/ch_ppocr_v2.0/rec_chinese_lite_train_v2.0.yml). If you want to try the result of other algorithms on the Chinese data set, please refer to the following instructions to modify the configuration file: [rec_chinese_lite_train_v2.0.yml](../../configs/rec/ch_ppocr_v2.0/rec_chinese_lite_train_v2.0.yml). If you want to try the result of other algorithms on the Chinese data set, please refer to the following instructions to modify the configuration file:
......
...@@ -21,7 +21,7 @@ from .make_border_map import MakeBorderMap ...@@ -21,7 +21,7 @@ from .make_border_map import MakeBorderMap
from .make_shrink_map import MakeShrinkMap from .make_shrink_map import MakeShrinkMap
from .random_crop_data import EastRandomCropData, PSERandomCrop from .random_crop_data import EastRandomCropData, PSERandomCrop
from .rec_img_aug import RecAug, RecResizeImg, ClsResizeImg, SRNRecResizeImg from .rec_img_aug import RecAug, RecResizeImg, ClsResizeImg, SRNRecResizeImg, PILResize, CVResize
from .randaugment import RandAugment from .randaugment import RandAugment
from .operators import * from .operators import *
from .label_ops import * from .label_ops import *
......
...@@ -96,7 +96,7 @@ class BaseRecLabelEncode(object): ...@@ -96,7 +96,7 @@ class BaseRecLabelEncode(object):
'ch', 'en', 'EN_symbol', 'french', 'german', 'japan', 'korean', 'ch', 'en', 'EN_symbol', 'french', 'german', 'japan', 'korean',
'EN', 'it', 'xi', 'pu', 'ru', 'ar', 'ta', 'ug', 'fa', 'ur', 'rs', 'EN', 'it', 'xi', 'pu', 'ru', 'ar', 'ta', 'ug', 'fa', 'ur', 'rs',
'oc', 'rsc', 'bg', 'uk', 'be', 'te', 'ka', 'chinese_cht', 'hi', 'oc', 'rsc', 'bg', 'uk', 'be', 'te', 'ka', 'chinese_cht', 'hi',
'mr', 'ne', 'latin', 'arabic', 'cyrillic', 'devanagari' 'mr', 'ne', 'latin', 'arabic', 'cyrillic', 'devanagari','dict_99'
] ]
assert character_type in support_character_type, "Only {} are supported now but get {}".format( assert character_type in support_character_type, "Only {} are supported now but get {}".format(
support_character_type, character_type) support_character_type, character_type)
......
...@@ -57,6 +57,38 @@ class DecodeImage(object): ...@@ -57,6 +57,38 @@ class DecodeImage(object):
return data return data
class NRTRDecodeImage(object):
""" decode image """
def __init__(self, img_mode='RGB', channel_first=False, **kwargs):
self.img_mode = img_mode
self.channel_first = channel_first
def __call__(self, data):
img = data['image']
if six.PY2:
assert type(img) is str and len(
img) > 0, "invalid input 'img' in DecodeImage"
else:
assert type(img) is bytes and len(
img) > 0, "invalid input 'img' in DecodeImage"
img = np.frombuffer(img, dtype='uint8')
img = cv2.imdecode(img, 1)
if img is None:
return None
if self.img_mode == 'GRAY':
img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
elif self.img_mode == 'RGB':
assert img.shape[2] == 3, 'invalid shape of image[%s]' % (img.shape)
img = img[:, :, ::-1]
img = cv2.cvtColor(img,cv2.COLOR_BGR2GRAY)
if self.channel_first:
img = img.transpose((2, 0, 1))
data['image'] = img
return data
class NormalizeImage(object): class NormalizeImage(object):
""" normalize image such as substract mean, divide std """ normalize image such as substract mean, divide std
""" """
......
...@@ -16,7 +16,7 @@ import math ...@@ -16,7 +16,7 @@ import math
import cv2 import cv2
import numpy as np import numpy as np
import random import random
from PIL import Image
from .text_image_aug import tia_perspective, tia_stretch, tia_distort from .text_image_aug import tia_perspective, tia_stretch, tia_distort
...@@ -42,6 +42,34 @@ class ClsResizeImg(object): ...@@ -42,6 +42,34 @@ class ClsResizeImg(object):
data['image'] = norm_img data['image'] = norm_img
return data return data
class PILResize(object):
def __init__(self, image_shape, **kwargs):
self.image_shape = image_shape
def __call__(self, data):
img = data['image']
image_pil = Image.fromarray(np.uint8(img))
norm_img = image_pil.resize(self.image_shape, Image.ANTIALIAS)
norm_img = np.array(norm_img)
norm_img = np.expand_dims(norm_img, -1)
norm_img = norm_img.transpose((2, 0, 1))
data['image'] = norm_img.astype(np.float32) / 128. - 1.
return data
class CVResize(object):
def __init__(self, image_shape, **kwargs):
self.image_shape = image_shape
def __call__(self, data):
img = data['image']
#print(img)
norm_img = cv2.resize(img,self.image_shape)
norm_img = np.expand_dims(norm_img, -1)
norm_img = norm_img.transpose((2, 0, 1))
data['image'] = norm_img.astype(np.float32) / 128. - 1.
return data
class RecResizeImg(object): class RecResizeImg(object):
def __init__(self, def __init__(self,
......
...@@ -25,7 +25,7 @@ from .det_sast_loss import SASTLoss ...@@ -25,7 +25,7 @@ from .det_sast_loss import SASTLoss
from .rec_ctc_loss import CTCLoss from .rec_ctc_loss import CTCLoss
from .rec_att_loss import AttentionLoss from .rec_att_loss import AttentionLoss
from .rec_srn_loss import SRNLoss from .rec_srn_loss import SRNLoss
from .rec_nrtr_loss import NRTRLoss
# cls loss # cls loss
from .cls_loss import ClsLoss from .cls_loss import ClsLoss
...@@ -42,8 +42,8 @@ from .combined_loss import CombinedLoss ...@@ -42,8 +42,8 @@ from .combined_loss import CombinedLoss
def build_loss(config): def build_loss(config):
support_dict = [ support_dict = [
'DBLoss', 'EASTLoss', 'SASTLoss', 'CTCLoss', 'ClsLoss', 'AttentionLoss', 'DBLoss', 'EASTLoss', 'SASTLoss', 'CTCLoss', 'ClsLoss', 'AttentionLoss',
'SRNLoss', 'PGLoss', 'CombinedLoss' 'SRNLoss', 'PGLoss', 'CombinedLoss', 'NRTRLoss']
]
config = copy.deepcopy(config) config = copy.deepcopy(config)
module_name = config.pop('name') module_name = config.pop('name')
assert module_name in support_dict, Exception('loss only support {}'.format( assert module_name in support_dict, Exception('loss only support {}'.format(
......
import paddle
from paddle import nn
import paddle.nn.functional as F
def cal_performance(pred, tgt):
pred = pred.max(1)[1]
tgt = tgt.contiguous().view(-1)
non_pad_mask = tgt.ne(0)
n_correct = pred.eq(tgt)
n_correct = n_correct.masked_select(non_pad_mask).sum().item()
return n_correct
class NRTRLoss(nn.Layer):
def __init__(self,smoothing=True, **kwargs):
super(NRTRLoss, self).__init__()
self.loss_func = nn.CrossEntropyLoss(reduction='mean',ignore_index=0)
self.smoothing = smoothing
def forward(self, pred, batch):
pred = pred.reshape([-1, pred.shape[2]])
max_len = batch[2].max()
tgt = batch[1][:,1:2+max_len]
tgt = tgt.reshape([-1] )
if self.smoothing:
eps = 0.1
n_class = pred.shape[1]
one_hot = F.one_hot(tgt, pred.shape[1])
one_hot = one_hot * (1 - eps) + (1 - one_hot) * eps / (n_class - 1)
log_prb = F.log_softmax(pred, axis=1)
non_pad_mask = paddle.not_equal(tgt, paddle.zeros(tgt.shape,dtype='int64'))
loss = -(one_hot * log_prb).sum(axis=1)
loss = loss.masked_select(non_pad_mask).mean()
else:
loss = self.loss_func(pred, tgt)
return {'loss': loss}
...@@ -30,7 +30,7 @@ class RecMetric(object): ...@@ -30,7 +30,7 @@ class RecMetric(object):
target = target.replace(" ", "") target = target.replace(" ", "")
norm_edit_dis += Levenshtein.distance(pred, target) / max( norm_edit_dis += Levenshtein.distance(pred, target) / max(
len(pred), len(target), 1) len(pred), len(target), 1)
if pred == target: if pred.lower() == target.lower():
correct_num += 1 correct_num += 1
all_num += 1 all_num += 1
self.correct_num += correct_num self.correct_num += correct_num
...@@ -48,8 +48,8 @@ class RecMetric(object): ...@@ -48,8 +48,8 @@ class RecMetric(object):
'norm_edit_dis': 0, 'norm_edit_dis': 0,
} }
""" """
acc = 1.0 * self.correct_num / self.all_num acc = 1.0 * self.correct_num / (self.all_num)
norm_edit_dis = 1 - self.norm_edit_dis / self.all_num norm_edit_dis = 1 - self.norm_edit_dis / (self.all_num)
self.reset() self.reset()
return {'acc': acc, 'norm_edit_dis': norm_edit_dis} return {'acc': acc, 'norm_edit_dis': norm_edit_dis}
...@@ -57,3 +57,4 @@ class RecMetric(object): ...@@ -57,3 +57,4 @@ class RecMetric(object):
self.correct_num = 0 self.correct_num = 0
self.all_num = 0 self.all_num = 0
self.norm_edit_dis = 0 self.norm_edit_dis = 0
\ No newline at end of file
...@@ -14,7 +14,7 @@ ...@@ -14,7 +14,7 @@
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
import paddle
from paddle import nn from paddle import nn
from ppocr.modeling.transforms import build_transform from ppocr.modeling.transforms import build_transform
from ppocr.modeling.backbones import build_backbone from ppocr.modeling.backbones import build_backbone
......
...@@ -25,7 +25,10 @@ def build_backbone(config, model_type): ...@@ -25,7 +25,10 @@ def build_backbone(config, model_type):
from .rec_mobilenet_v3 import MobileNetV3 from .rec_mobilenet_v3 import MobileNetV3
from .rec_resnet_vd import ResNet from .rec_resnet_vd import ResNet
from .rec_resnet_fpn import ResNetFPN from .rec_resnet_fpn import ResNetFPN
support_dict = ['MobileNetV3', 'ResNet', 'ResNetFPN'] from .rec_nrtr_mtb import MTB
from .rec_swin import SwinTransformer
support_dict = ['MobileNetV3', 'ResNet', 'ResNetFPN','MTB','SwinTransformer']
elif model_type == 'e2e': elif model_type == 'e2e':
from .e2e_resnet_vd_pg import ResNet from .e2e_resnet_vd_pg import ResNet
support_dict = ['ResNet'] support_dict = ['ResNet']
......
import paddle
from paddle import nn
import paddle.nn.functional as F
from paddle.nn import Linear
from paddle.nn.initializer import XavierUniform as xavier_uniform_
from paddle.nn.initializer import Constant as constant_
from paddle.nn.initializer import XavierNormal as xavier_normal_
zeros_ = constant_(value=0.)
ones_ = constant_(value=1.)
class MultiheadAttention(nn.Layer):
r"""Allows the model to jointly attend to information
from different representation subspaces.
See reference: Attention Is All You Need
.. math::
\text{MultiHead}(Q, K, V) = \text{Concat}(head_1,\dots,head_h)W^O
\text{where} head_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)
Args:
embed_dim: total dimension of the model
num_heads: parallel attention layers, or heads
Examples::
>>> multihead_attn = nn.MultiheadAttention(embed_dim, num_heads)
>>> attn_output, attn_output_weights = multihead_attn(query, key, value)
"""
def __init__(self, embed_dim, num_heads, dropout=0., bias=True, add_bias_kv=False, add_zero_attn=False):
super(MultiheadAttention, self).__init__()
self.embed_dim = embed_dim
self.num_heads = num_heads
self.dropout = dropout
self.head_dim = embed_dim // num_heads
assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
self.scaling = self.head_dim ** -0.5
self.out_proj = Linear(embed_dim, embed_dim, bias_attr=bias)
if add_bias_kv:
self.bias_k = self.create_parameter(
shape=(1, 1, embed_dim), default_initializer=zeros_)
self.add_parameter("bias_k", self.bias_k)
self.bias_v = self.create_parameter(
shape=(1, 1, embed_dim), default_initializer=zeros_)
self.add_parameter("bias_v", self.bias_v)
else:
self.bias_k = self.bias_v = None
self.add_zero_attn = add_zero_attn
self._reset_parameters()
self.conv1 = paddle.nn.Conv2D(in_channels=embed_dim, out_channels=embed_dim, kernel_size=(1, 1))
self.conv2 = paddle.nn.Conv2D(in_channels=embed_dim, out_channels=embed_dim * 2, kernel_size=(1, 1))
self.conv3 = paddle.nn.Conv2D(in_channels=embed_dim, out_channels=embed_dim * 3, kernel_size=(1, 1))
def _reset_parameters(self):
xavier_uniform_(self.out_proj.weight)
if self.bias_k is not None:
xavier_normal_(self.bias_k)
if self.bias_v is not None:
xavier_normal_(self.bias_v)
def forward(self, query, key, value, key_padding_mask=None, incremental_state=None,
need_weights=True, static_kv=False, attn_mask=None, qkv_ = [False,False,False]):
"""
Inputs of forward function
query: [target length, batch size, embed dim]
key: [sequence length, batch size, embed dim]
value: [sequence length, batch size, embed dim]
key_padding_mask: if True, mask padding based on batch size
incremental_state: if provided, previous time steps are cashed
need_weights: output attn_output_weights
static_kv: key and value are static
Outputs of forward function
attn_output: [target length, batch size, embed dim]
attn_output_weights: [batch size, target length, sequence length]
"""
qkv_same = qkv_[0]
kv_same = qkv_[1]
tgt_len, bsz, embed_dim = query.shape
assert embed_dim == self.embed_dim
assert list(query.shape) == [tgt_len, bsz, embed_dim]
assert key.shape == value.shape
if qkv_same:
# self-attention
q, k, v = self._in_proj_qkv(query)
elif kv_same:
# encoder-decoder attention
q = self._in_proj_q(query)
if key is None:
assert value is None
k = v = None
else:
k, v = self._in_proj_kv(key)
else:
q = self._in_proj_q(query)
k = self._in_proj_k(key)
v = self._in_proj_v(value)
q *= self.scaling
if self.bias_k is not None:
assert self.bias_v is not None
self.bias_k = paddle.concat([self.bias_k for i in range(bsz)],axis=1)
self.bias_v = paddle.concat([self.bias_v for i in range(bsz)],axis=1)
k = paddle.concat([k, self.bias_k])
v = paddle.concat([v, self.bias_v])
if attn_mask is not None:
attn_mask = paddle.concat([attn_mask, paddle.zeros([attn_mask.shape[0], 1],dtype=attn_mask.dtype)], axis=1)
if key_padding_mask is not None:
key_padding_mask = paddle.concat(
[key_padding_mask,paddle.zeros([key_padding_mask.shape[0], 1],dtype=key_padding_mask.dtype)], axis=1)
q = q.reshape([tgt_len, bsz * self.num_heads, self.head_dim]).transpose([1, 0, 2])
if k is not None:
k = k.reshape([-1, bsz * self.num_heads, self.head_dim]).transpose([1, 0, 2])
if v is not None:
v = v.reshape([-1, bsz * self.num_heads, self.head_dim]).transpose([1, 0, 2])
src_len = k.shape[1]
if key_padding_mask is not None:
assert key_padding_mask.shape[0] == bsz
assert key_padding_mask.shape[1] == src_len
if self.add_zero_attn:
src_len += 1
k = paddle.concat([k, paddle.zeros((k.shape[0], 1) + k.shape[2:],dtype=k.dtype)], axis=1)
v = paddle.concat([v, paddle.zeros((v.shape[0], 1) + v.shape[2:],dtype=v.dtype)], axis=1)
if attn_mask is not None:
attn_mask = paddle.concat([attn_mask, paddle.zeros([attn_mask.shape[0], 1],dtype=attn_mask.dtype)], axis=1)
if key_padding_mask is not None:
key_padding_mask = paddle.concat(
[key_padding_mask, paddle.zeros([key_padding_mask.shape[0], 1],dtype=key_padding_mask.dtype)], axis=1)
attn_output_weights = paddle.bmm(q, k.transpose([0,2,1]))
assert list(attn_output_weights.shape) == [bsz * self.num_heads, tgt_len, src_len]
if attn_mask is not None:
attn_mask = attn_mask.unsqueeze(0)
attn_output_weights += attn_mask
if key_padding_mask is not None:
attn_output_weights = attn_output_weights.reshape([bsz, self.num_heads, tgt_len, src_len])
key = key_padding_mask.unsqueeze(1).unsqueeze(2).astype('float32')
y = paddle.full(shape=key.shape, dtype='float32', fill_value='-inf')
y = paddle.where(key==0.,key, y)
attn_output_weights += y
attn_output_weights = attn_output_weights.reshape([bsz*self.num_heads, tgt_len, src_len])
attn_output_weights = F.softmax(
attn_output_weights.astype('float32'), axis=-1,
dtype=paddle.float32 if attn_output_weights.dtype == paddle.float16 else attn_output_weights.dtype)
attn_output_weights = F.dropout(attn_output_weights, p=self.dropout, training=self.training)
attn_output = paddle.bmm(attn_output_weights, v)
assert list(attn_output.shape) == [bsz * self.num_heads, tgt_len, self.head_dim]
attn_output = attn_output.transpose([1, 0,2]).reshape([tgt_len, bsz, embed_dim])
attn_output = self.out_proj(attn_output)
if need_weights:
# average attention weights over heads
attn_output_weights = attn_output_weights.reshape([bsz, self.num_heads, tgt_len, src_len])
attn_output_weights = attn_output_weights.sum(axis=1) / self.num_heads
else:
attn_output_weights = None
return attn_output, attn_output_weights
def _in_proj_qkv(self, query):
query = query.transpose([1, 2, 0])
query = paddle.unsqueeze(query, axis=2)
res = self.conv3(query)
res = paddle.squeeze(res, axis=2)
res = res.transpose([2, 0, 1])
return res.chunk(3, axis=-1)
def _in_proj_kv(self, key):
key = key.transpose([1, 2, 0])
key = paddle.unsqueeze(key, axis=2)
res = self.conv2(key)
res = paddle.squeeze(res, axis=2)
res = res.transpose([2, 0, 1])
return res.chunk(2, axis=-1)
def _in_proj_q(self, query):
query = query.transpose([1, 2, 0])
query = paddle.unsqueeze(query, axis=2)
res = self.conv1(query)
res = paddle.squeeze(res, axis=2)
res = res.transpose([2, 0, 1])
return res
def _in_proj_k(self, key):
key = key.transpose([1, 2, 0])
key = paddle.unsqueeze(key, axis=2)
res = self.conv1(key)
res = paddle.squeeze(res, axis=2)
res = res.transpose([2, 0, 1])
return res
def _in_proj_v(self, value):
value = value.transpose([1,2,0])#(1, 2, 0)
value = paddle.unsqueeze(value, axis=2)
res = self.conv1(value)
res = paddle.squeeze(res, axis=2)
res = res.transpose([2, 0, 1])
return res
class MultiheadAttentionOptim(nn.Layer):
r"""Allows the model to jointly attend to information
from different representation subspaces.
See reference: Attention Is All You Need
.. math::
\text{MultiHead}(Q, K, V) = \text{Concat}(head_1,\dots,head_h)W^O
\text{where} head_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)
Args:
embed_dim: total dimension of the model
num_heads: parallel attention layers, or heads
Examples::
>>> multihead_attn = nn.MultiheadAttention(embed_dim, num_heads)
>>> attn_output, attn_output_weights = multihead_attn(query, key, value)
"""
def __init__(self, embed_dim, num_heads, dropout=0., bias=True, add_bias_kv=False, add_zero_attn=False):
super(MultiheadAttentionOptim, self).__init__()
self.embed_dim = embed_dim
self.num_heads = num_heads
self.dropout = dropout
self.head_dim = embed_dim // num_heads
assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
self.scaling = self.head_dim ** -0.5
self.out_proj = Linear(embed_dim, embed_dim, bias_attr=bias)
self._reset_parameters()
self.conv1 = paddle.nn.Conv2D(in_channels=embed_dim, out_channels=embed_dim, kernel_size=(1, 1))
self.conv2 = paddle.nn.Conv2D(in_channels=embed_dim, out_channels=embed_dim, kernel_size=(1, 1))
self.conv3 = paddle.nn.Conv2D(in_channels=embed_dim, out_channels=embed_dim, kernel_size=(1, 1))
def _reset_parameters(self):
xavier_uniform_(self.out_proj.weight)
def forward(self, query, key, value, key_padding_mask=None, incremental_state=None,
need_weights=True, static_kv=False, attn_mask=None):
"""
Inputs of forward function
query: [target length, batch size, embed dim]
key: [sequence length, batch size, embed dim]
value: [sequence length, batch size, embed dim]
key_padding_mask: if True, mask padding based on batch size
incremental_state: if provided, previous time steps are cashed
need_weights: output attn_output_weights
static_kv: key and value are static
Outputs of forward function
attn_output: [target length, batch size, embed dim]
attn_output_weights: [batch size, target length, sequence length]
"""
tgt_len, bsz, embed_dim = query.shape
assert embed_dim == self.embed_dim
assert list(query.shape) == [tgt_len, bsz, embed_dim]
assert key.shape == value.shape
q = self._in_proj_q(query)
k = self._in_proj_k(key)
v = self._in_proj_v(value)
q *= self.scaling
q = q.reshape([tgt_len, bsz * self.num_heads, self.head_dim]).transpose([1, 0, 2])
k = k.reshape([-1, bsz * self.num_heads, self.head_dim]).transpose([1, 0, 2])
v = v.reshape([-1, bsz * self.num_heads, self.head_dim]).transpose([1, 0, 2])
src_len = k.shape[1]
if key_padding_mask is not None:
assert key_padding_mask.shape[0] == bsz
assert key_padding_mask.shape[1] == src_len
attn_output_weights = paddle.bmm(q, k.transpose([0,2,1]))
assert list(attn_output_weights.shape) == [bsz * self.num_heads, tgt_len, src_len]
if attn_mask is not None:
attn_mask = attn_mask.unsqueeze(0)
attn_output_weights += attn_mask
if key_padding_mask is not None:
attn_output_weights = attn_output_weights.reshape([bsz, self.num_heads, tgt_len, src_len])
key = key_padding_mask.unsqueeze(1).unsqueeze(2).astype('float32')
y = paddle.full(shape=key.shape, dtype='float32', fill_value='-inf')
y = paddle.where(key==0.,key, y)
attn_output_weights += y
attn_output_weights = attn_output_weights.reshape([bsz*self.num_heads, tgt_len, src_len])
attn_output_weights = F.softmax(
attn_output_weights.astype('float32'), axis=-1,
dtype=paddle.float32 if attn_output_weights.dtype == paddle.float16 else attn_output_weights.dtype)
attn_output_weights = F.dropout(attn_output_weights, p=self.dropout, training=self.training)
attn_output = paddle.bmm(attn_output_weights, v)
assert list(attn_output.shape) == [bsz * self.num_heads, tgt_len, self.head_dim]
attn_output = attn_output.transpose([1, 0,2]).reshape([tgt_len, bsz, embed_dim])
attn_output = self.out_proj(attn_output)
if need_weights:
# average attention weights over heads
attn_output_weights = attn_output_weights.reshape([bsz, self.num_heads, tgt_len, src_len])
attn_output_weights = attn_output_weights.sum(axis=1) / self.num_heads
else:
attn_output_weights = None
return attn_output, attn_output_weights
def _in_proj_q(self, query):
query = query.transpose([1, 2, 0])
query = paddle.unsqueeze(query, axis=2)
res = self.conv1(query)
res = paddle.squeeze(res, axis=2)
res = res.transpose([2, 0, 1])
return res
def _in_proj_k(self, key):
key = key.transpose([1, 2, 0])
key = paddle.unsqueeze(key, axis=2)
res = self.conv2(key)
res = paddle.squeeze(res, axis=2)
res = res.transpose([2, 0, 1])
return res
def _in_proj_v(self, value):
value = value.transpose([1,2,0])#(1, 2, 0)
value = paddle.unsqueeze(value, axis=2)
res = self.conv3(value)
res = paddle.squeeze(res, axis=2)
res = res.transpose([2, 0, 1])
return res
\ No newline at end of file
from paddle import nn
class MTB(nn.Layer):
def __init__(self, cnn_num, in_channels):
super(MTB, self).__init__()
self.block = nn.Sequential()
self.out_channels = in_channels
self.cnn_num = cnn_num
if self.cnn_num == 2:
for i in range(self.cnn_num):
self.block.add_sublayer('conv_{}'.format(i), nn.Conv2D(
in_channels = in_channels if i == 0 else 32*(2**(i-1)),
out_channels = 32*(2**i),
kernel_size = 3,
stride = 2,
padding=1))
self.block.add_sublayer('relu_{}'.format(i), nn.ReLU())
self.block.add_sublayer('bn_{}'.format(i), nn.BatchNorm2D(32*(2**i)))
def forward(self, images):
x = self.block(images)
if self.cnn_num == 2:
# (b, w, h, c)
x = x.transpose([0, 3, 2, 1])
x_shape = x.shape
x = x.reshape([x_shape[0], x_shape[1], x_shape[2] * x_shape[3]])
return x
...@@ -7,7 +7,11 @@ from paddle.nn import LayerList ...@@ -7,7 +7,11 @@ from paddle.nn import LayerList
from paddle.nn.initializer import XavierNormal as xavier_uniform_ from paddle.nn.initializer import XavierNormal as xavier_uniform_
from paddle.nn import Dropout, Linear, LayerNorm, Conv2D from paddle.nn import Dropout, Linear, LayerNorm, Conv2D
import numpy as np import numpy as np
<<<<<<< HEAD
from ppocr.modeling.heads.multiheadAttention import MultiheadAttentionOptim from ppocr.modeling.heads.multiheadAttention import MultiheadAttentionOptim
=======
from ppocr.modeling.backbones.multiheadAttention import MultiheadAttentionOptim
>>>>>>> 9c67a7f... add rec_nrtr
from paddle.nn.initializer import Constant as constant_ from paddle.nn.initializer import Constant as constant_
from paddle.nn.initializer import XavierNormal as xavier_normal_ from paddle.nn.initializer import XavierNormal as xavier_normal_
......
...@@ -21,7 +21,7 @@ def build_neck(config): ...@@ -21,7 +21,7 @@ def build_neck(config):
from .sast_fpn import SASTFPN from .sast_fpn import SASTFPN
from .rnn import SequenceEncoder from .rnn import SequenceEncoder
from .pg_fpn import PGFPN from .pg_fpn import PGFPN
support_dict = ['DBFPN', 'EASTFPN', 'SASTFPN', 'SequenceEncoder', 'PGFPN'] support_dict = ['DBFPN', 'EASTFPN', 'SASTFPN', 'SequenceEncoder', 'PGFPN','TFEncoder']
module_name = config.pop('name') module_name = config.pop('name')
assert module_name in support_dict, Exception('neck only support {}'.format( assert module_name in support_dict, Exception('neck only support {}'.format(
......
...@@ -24,16 +24,15 @@ __all__ = ['build_post_process'] ...@@ -24,16 +24,15 @@ __all__ = ['build_post_process']
from .db_postprocess import DBPostProcess from .db_postprocess import DBPostProcess
from .east_postprocess import EASTPostProcess from .east_postprocess import EASTPostProcess
from .sast_postprocess import SASTPostProcess from .sast_postprocess import SASTPostProcess
from .rec_postprocess import CTCLabelDecode, AttnLabelDecode, SRNLabelDecode, DistillationCTCLabelDecode from .rec_postprocess import CTCLabelDecode, AttnLabelDecode, SRNLabelDecode, DistillationCTCLabelDecode, NRTRLabelDecode
from .cls_postprocess import ClsPostProcess from .cls_postprocess import ClsPostProcess
from .pg_postprocess import PGPostProcess from .pg_postprocess import PGPostProcess
def build_post_process(config, global_config=None): def build_post_process(config, global_config=None):
support_dict = [ support_dict = [
'DBPostProcess', 'EASTPostProcess', 'SASTPostProcess', 'CTCLabelDecode', 'DBPostProcess', 'EASTPostProcess', 'SASTPostProcess', 'CTCLabelDecode',
'AttnLabelDecode', 'ClsPostProcess', 'SRNLabelDecode', 'PGPostProcess', 'AttnLabelDecode', 'ClsPostProcess', 'SRNLabelDecode', 'PGPostProcess',
'DistillationCTCLabelDecode' 'DistillationCTCLabelDecode', 'NRTRLabelDecode'
] ]
config = copy.deepcopy(config) config = copy.deepcopy(config)
......
...@@ -28,7 +28,7 @@ class BaseRecLabelDecode(object): ...@@ -28,7 +28,7 @@ class BaseRecLabelDecode(object):
'ch', 'en', 'EN_symbol', 'french', 'german', 'japan', 'korean', 'ch', 'en', 'EN_symbol', 'french', 'german', 'japan', 'korean',
'it', 'xi', 'pu', 'ru', 'ar', 'ta', 'ug', 'fa', 'ur', 'rs', 'oc', 'it', 'xi', 'pu', 'ru', 'ar', 'ta', 'ug', 'fa', 'ur', 'rs', 'oc',
'rsc', 'bg', 'uk', 'be', 'te', 'ka', 'chinese_cht', 'hi', 'mr', 'rsc', 'bg', 'uk', 'be', 'te', 'ka', 'chinese_cht', 'hi', 'mr',
'ne', 'EN', 'latin', 'arabic', 'cyrillic', 'devanagari' 'ne', 'EN', 'latin', 'arabic', 'cyrillic', 'devanagari','dict_99'
] ]
assert character_type in support_character_type, "Only {} are supported now but get {}".format( assert character_type in support_character_type, "Only {} are supported now but get {}".format(
support_character_type, character_type) support_character_type, character_type)
...@@ -256,8 +256,7 @@ class AttnLabelDecode(BaseRecLabelDecode): ...@@ -256,8 +256,7 @@ class AttnLabelDecode(BaseRecLabelDecode):
if idx > 0 and text_index[batch_idx][idx - 1] == text_index[ if idx > 0 and text_index[batch_idx][idx - 1] == text_index[
batch_idx][idx]: batch_idx][idx]:
continue continue
char_list.append(self.character[int(text_index[batch_idx][ char_list.append(self.character[int(text_index[batch_idx][idx])])
idx])])
if text_prob is not None: if text_prob is not None:
conf_list.append(text_prob[batch_idx][idx]) conf_list.append(text_prob[batch_idx][idx])
else: else:
......
!
"
#
$
%
&
'
(
)
*
+
,
-
.
/
0
1
2
3
4
5
6
7
8
9
:
;
<
=
>
?
@
A
B
C
D
E
F
G
H
I
J
K
L
M
N
O
P
Q
R
S
T
U
V
W
X
Y
Z
[
\
]
^
_
`
a
b
c
d
e
f
g
h
i
j
k
l
m
n
o
p
q
r
s
t
u
v
w
x
y
z
{
|
}
~
\ No newline at end of file
...@@ -22,6 +22,7 @@ import sys ...@@ -22,6 +22,7 @@ import sys
__dir__ = os.path.dirname(os.path.abspath(__file__)) __dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(__dir__) sys.path.append(__dir__)
sys.path.append(os.path.abspath(os.path.join(__dir__, '..'))) sys.path.append(os.path.abspath(os.path.join(__dir__, '..')))
from ppocr.data import build_dataloader from ppocr.data import build_dataloader
from ppocr.modeling.architectures import build_model from ppocr.modeling.architectures import build_model
from ppocr.postprocess import build_post_process from ppocr.postprocess import build_post_process
...@@ -30,6 +31,7 @@ from ppocr.utils.save_load import init_model ...@@ -30,6 +31,7 @@ from ppocr.utils.save_load import init_model
from ppocr.utils.utility import print_dict from ppocr.utils.utility import print_dict
import tools.program as program import tools.program as program
def main(): def main():
global_config = config['Global'] global_config = config['Global']
# build dataloader # build dataloader
......
...@@ -186,7 +186,7 @@ def train(config, ...@@ -186,7 +186,7 @@ def train(config,
model.train() model.train()
use_srn = config['Architecture']['algorithm'] == "SRN" use_srn = config['Architecture']['algorithm'] == "SRN"
use_nrtr = config['Architecture']['algorithm'] == "NRTR"
if 'start_epoch' in best_model_dict: if 'start_epoch' in best_model_dict:
start_epoch = best_model_dict['start_epoch'] start_epoch = best_model_dict['start_epoch']
else: else:
...@@ -211,6 +211,9 @@ def train(config, ...@@ -211,6 +211,9 @@ def train(config,
others = batch[-4:] others = batch[-4:]
preds = model(images, others) preds = model(images, others)
model_average = True model_average = True
elif use_nrtr:
max_len = batch[2].max()
preds = model(images, batch[1][:,:2+max_len])
else: else:
preds = model(images) preds = model(images)
loss = loss_class(preds, batch) loss = loss_class(preds, batch)
...@@ -350,13 +353,11 @@ def eval(model, valid_dataloader, post_process_class, eval_class, ...@@ -350,13 +353,11 @@ def eval(model, valid_dataloader, post_process_class, eval_class,
break break
images = batch[0] images = batch[0]
start = time.time() start = time.time()
if use_srn: if use_srn:
others = batch[-4:] others = batch[-4:]
preds = model(images, others) preds = model(images, others)
else: else:
preds = model(images) preds = model(images)
batch = [item.numpy() for item in batch] batch = [item.numpy() for item in batch]
# Obtain usable results from post-processing methods # Obtain usable results from post-processing methods
post_result = post_process_class(preds, batch[1]) post_result = post_process_class(preds, batch[1])
...@@ -386,7 +387,7 @@ def preprocess(is_train=False): ...@@ -386,7 +387,7 @@ def preprocess(is_train=False):
alg = config['Architecture']['algorithm'] alg = config['Architecture']['algorithm']
assert alg in [ assert alg in [
'EAST', 'DB', 'SAST', 'Rosetta', 'CRNN', 'STARNet', 'RARE', 'SRN', 'EAST', 'DB', 'SAST', 'Rosetta', 'CRNN', 'STARNet', 'RARE', 'SRN',
'CLS', 'PGNet', 'Distillation' 'CLS', 'PGNet', 'Distillation','NRTR'
] ]
device = 'gpu:{}'.format(dist.ParallelEnv().dev_id) if use_gpu else 'cpu' device = 'gpu:{}'.format(dist.ParallelEnv().dev_id) if use_gpu else 'cpu'
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册