diff --git a/configs/rec/rec_resnet_stn_bilstm_att.yml b/configs/rec/rec_resnet_stn_bilstm_att.yml
new file mode 100644
index 0000000000000000000000000000000000000000..7b5a9c7117b1220c94b4ee3cf6036f67ffbc13a0
--- /dev/null
+++ b/configs/rec/rec_resnet_stn_bilstm_att.yml
@@ -0,0 +1,104 @@
+Global:
+  use_gpu: True
+  epoch_num: 400
+  log_smooth_window: 20
+  print_batch_step: 10
+  save_model_dir: ./output/rec/seed
+  save_epoch_step: 3
+  # evaluation is run every 2000 iterations, starting from the 0th iteration
+  eval_batch_step: [0, 2000]
+  cal_metric_during_train: True
+  pretrained_model:
+  checkpoints:
+  save_inference_dir:
+  use_visualdl: False
+  infer_img: doc/imgs_words_en/word_10.png
+  # for data or label process
+  character_dict_path:
+  character_type: EN_symbol
+  max_text_length: 100
+  infer_mode: False
+  use_space_char: False
+  eval_filter: True
+  save_res_path: ./output/rec/predicts_seed.txt
+
+
+Optimizer:
+  name: Adadelta
+  lr:
+    name: Piecewise
+    decay_epochs: [4, 5, 8]
+    values: [1.0, 0.1, 0.01]
+  regularizer:
+    name: 'L2'
+    factor: 2.0e-05
+
+
+Architecture:
+  model_type: seed
+  algorithm: SEED
+  Transform:
+    name: STN_ON
+    tps_inputsize: [32, 64]
+    tps_outputsize: [32, 100]
+    num_control_points: 20
+    tps_margins: [0.05, 0.05]
+    stn_activation: none
+  Backbone:
+    name: ResNet_ASTER
+  Head:
+    name: AsterHead # AttentionHead
+    sDim: 512
+    attDim: 512
+    max_len_labels: 100
+
+Loss:
+  name: AsterLoss
+
+PostProcess:
+  name: SEEDLabelDecode
+
+Metric:
+  name: RecMetric
+  main_indicator: acc
+  is_filter: True
+
+Train:
+  dataset:
+    name: LMDBDataSet
+    data_dir: ./train_data/data_lmdb_release/training/
+    transforms:
+      - Fasttext:
+          path: "./cc.en.300.bin"
+      - DecodeImage: # load image
+          img_mode: BGR
+          channel_first: False
+      - SEEDLabelEncode: # Class handling label
+      - SEEDResize:
+          image_shape: [3, 64, 256]
+      - KeepKeys:
+          keep_keys: ['image', 'label', 'length', 'fast_label'] # dataloader will return list in this order
+  loader:
+    shuffle: True
+    batch_size_per_card: 256
+    drop_last: True
+    num_workers: 6
+
+Eval:
+  dataset:
+    name: LMDBDataSet
+    data_dir: ./train_data/data_lmdb_release/evaluation/
+    transforms:
+      - DecodeImage: # load image
+          img_mode: BGR
+          channel_first: False
+      - SEEDLabelEncode: # Class handling label
+      - SEEDResize:
+          image_shape: [3, 64, 256]
+      - KeepKeys:
+          keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order
+  loader:
+    shuffle: False
+    drop_last: True
+    batch_size_per_card: 256
+    num_workers: 4
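For orientation, `eval_batch_step` is consumed by the training loop as a [start_step, interval] pair. A simplified Python sketch of that gating (the real check lives in the train loop in tools/program.py, so treat the exact condition as an approximation):

import sys

# eval_batch_step: [0, 2000] -> evaluate every 2000 global steps from step 0
start_eval_step, eval_interval = 0, 2000
for global_step in range(1, 10001):
    run_eval = (global_step > start_eval_step and
                (global_step - start_eval_step) % eval_interval == 0)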
diff --git a/ppocr/data/imaug/__init__.py b/ppocr/data/imaug/__init__.py
index 8bfc17508398224b86d7e9d9a03a250c266a54a8..31deec7b43b5bad1d932abde74787d471d7d35d0
--- a/ppocr/data/imaug/__init__.py
+++ b/ppocr/data/imaug/__init__.py
@@ -21,7 +21,7 @@
 from .make_border_map import MakeBorderMap
 from .make_shrink_map import MakeShrinkMap
 from .random_crop_data import EastRandomCropData, PSERandomCrop
-from .rec_img_aug import RecAug, RecResizeImg, ClsResizeImg, SRNRecResizeImg, NRTRRecResizeImg, SARRecResizeImg
+from .rec_img_aug import RecAug, RecResizeImg, ClsResizeImg, SRNRecResizeImg, NRTRRecResizeImg, SARRecResizeImg, SEEDResize
 from .randaugment import RandAugment
 from .copy_paste import CopyPaste
 from .operators import *
diff --git a/ppocr/data/imaug/label_ops.py b/ppocr/data/imaug/label_ops.py
index 643ec70503fb0015e9f7a448d3a6cf9f99171493..07725cf4af298a0538a79355ade247b6db454626
--- a/ppocr/data/imaug/label_ops.py
+++ b/ppocr/data/imaug/label_ops.py
@@ -106,6 +106,7 @@ class BaseRecLabelEncode(object):
         self.max_text_len = max_text_length
         self.beg_str = "sos"
         self.end_str = "eos"
+        self.unknown = "UNKNOWN"
         if character_type == "en":
             self.character_str = "0123456789abcdefghijklmnopqrstuvwxyz"
             dict_character = list(self.character_str)
@@ -174,6 +175,7 @@ class NRTRLabelEncode(BaseRecLabelEncode):
         super(NRTRLabelEncode,
               self).__init__(max_text_length, character_dict_path,
                              character_type, use_space_char)
+
     def __call__(self, data):
         text = data['label']
         text = self.encode(text)
@@ -185,10 +187,12 @@ class NRTRLabelEncode(BaseRecLabelEncode):
         text = text + [0] * (self.max_text_len - len(text))
         data['label'] = np.array(text)
         return data
+
     def add_special_char(self, dict_character):
-        dict_character = ['blank','<unk>','<s>','</s>'] + dict_character
+        dict_character = ['blank', '<unk>', '<s>', '</s>'] + dict_character
         return dict_character
+
 
 class CTCLabelEncode(BaseRecLabelEncode):
     """ Convert between text-label and text-index """
@@ -337,6 +341,39 @@ class AttnLabelEncode(BaseRecLabelEncode):
         return idx
 
 
+class SEEDLabelEncode(BaseRecLabelEncode):
+    """ Convert between text-label and text-index """
+
+    def __init__(self,
+                 max_text_length,
+                 character_dict_path=None,
+                 character_type='ch',
+                 use_space_char=False,
+                 **kwargs):
+        super(SEEDLabelEncode,
+              self).__init__(max_text_length, character_dict_path,
+                             character_type, use_space_char)
+
+    def add_special_char(self, dict_character):
+        self.beg_str = "sos"
+        self.end_str = "eos"
+        dict_character = dict_character + [self.end_str]
+        return dict_character
+
+    def __call__(self, data):
+        text = data['label']
+        text = self.encode(text)
+        if text is None:
+            return None
+        if len(text) >= self.max_text_len:
+            return None
+        data['length'] = np.array(len(text)) + 1  # include eos
+        text = text + [len(self.character) - 1] * (self.max_text_len -
+                                                   len(text))
+        data['label'] = np.array(text)
+        return data
+
+
 class SRNLabelEncode(BaseRecLabelEncode):
     """ Convert between text-label and text-index """
@@ -416,7 +453,6 @@ class TableLabelEncode(object):
         substr = lines[0].decode('utf-8').strip("\r\n").split("\t")
         character_num = int(substr[0])
         elem_num = int(substr[1])
-
         for cno in range(1, 1 + character_num):
             character = lines[cno].decode('utf-8').strip("\r\n")
             list_character.append(character)
@@ -588,7 +624,7 @@ class SARLabelEncode(BaseRecLabelEncode):
         data['length'] = np.array(len(text))
         target = [self.start_idx] + text + [self.end_idx]
         padded_text = [self.padding_idx for _ in range(self.max_text_len)]
-
+
         padded_text[:len(target)] = target
         data['label'] = np.array(padded_text)
         return data
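As a standalone illustration of SEEDLabelEncode's padding rule (outside the patch, with a hypothetical charset): the label keeps its character indices, the reported length counts the trailing <eos>, and padding reuses the last dictionary slot, which is the appended "eos" entry.

import numpy as np

def seed_encode(text, charset, max_text_len):
    # mirrors SEEDLabelEncode.__call__; charset[-1] is the "eos" entry
    idx = [charset.index(c) for c in text]
    if len(idx) >= max_text_len:
        return None
    length = np.array(len(idx)) + 1  # include eos
    label = np.array(idx + [len(charset) - 1] * (max_text_len - len(idx)))
    return label, length

charset = list("0123456789abcdefghijklmnopqrstuvwxyz") + ["eos"]
print(seed_encode("ab", charset, 8))  # ([10 11 36 36 36 36 36 36], 3)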
diff --git a/ppocr/data/imaug/operators.py b/ppocr/data/imaug/operators.py
index bbdf8f9647999304ea586b4089cd99589ac1fb94..87e3088d07a8c5a2eea5d4deff87c69a753e215b
--- a/ppocr/data/imaug/operators.py
+++ b/ppocr/data/imaug/operators.py
@@ -23,6 +23,7 @@ import sys
 import six
 import cv2
 import numpy as np
+import fasttext
 
 
 class DecodeImage(object):
@@ -83,12 +84,13 @@ class NRTRDecodeImage(object):
         elif self.img_mode == 'RGB':
             assert img.shape[2] == 3, 'invalid shape of image[%s]' % (img.shape)
             img = img[:, :, ::-1]
-        img = cv2.cvtColor(img,cv2.COLOR_BGR2GRAY)
+        img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
         if self.channel_first:
             img = img.transpose((2, 0, 1))
         data['image'] = img
         return data
 
+
 class NormalizeImage(object):
     """ normalize image such as substract mean, divide std
     """
@@ -133,6 +135,17 @@ class ToCHWImage(object):
         return data
 
 
+class Fasttext(object):
+    def __init__(self, path="None", **kwargs):
+        self.fast_model = fasttext.load_model(path)
+
+    def __call__(self, data):
+        label = data['label']
+        fast_label = self.fast_model[label]
+        data['fast_label'] = fast_label
+        return data
+
+
 class KeepKeys(object):
     def __init__(self, keep_keys, **kwargs):
         self.keep_keys = keep_keys
diff --git a/ppocr/data/imaug/rec_img_aug.py b/ppocr/data/imaug/rec_img_aug.py
index 51f5855ac36c40f1808fdec4dc00c540792b15e7..b4bca2fc149f753875db56c6ab566c334c9c890e
--- a/ppocr/data/imaug/rec_img_aug.py
+++ b/ppocr/data/imaug/rec_img_aug.py
@@ -82,6 +82,18 @@ class RecResizeImg(object):
         return data
 
 
+class SEEDResize(object):
+    def __init__(self, image_shape, infer_mode=False, **kwargs):
+        self.image_shape = image_shape
+        self.infer_mode = infer_mode
+
+    def __call__(self, data):
+        img = data['image']
+        norm_img = resize_no_padding_img(img, self.image_shape)
+        data['image'] = norm_img
+        return data
+
+
 class SRNRecResizeImg(object):
     def __init__(self, image_shape, num_heads, max_text_length, **kwargs):
         self.image_shape = image_shape
@@ -109,7 +121,8 @@ class SARRecResizeImg(object):
 
     def __call__(self, data):
         img = data['image']
-        norm_img, resize_shape, pad_shape, valid_ratio = resize_norm_img_sar(img, self.image_shape, self.width_downsample_ratio)
+        norm_img, resize_shape, pad_shape, valid_ratio = resize_norm_img_sar(
+            img, self.image_shape, self.width_downsample_ratio)
         data['image'] = norm_img
         data['resized_shape'] = resize_shape
         data['pad_shape'] = pad_shape
@@ -175,6 +188,17 @@ def resize_norm_img(img, image_shape):
     return padding_im
 
 
+def resize_no_padding_img(img, image_shape):
+    imgC, imgH, imgW = image_shape
+    resized_image = cv2.resize(
+        img, (imgW, imgH), interpolation=cv2.INTER_LINEAR)
+    resized_image = resized_image.astype('float32')
+    resized_image = resized_image.transpose((2, 0, 1)) / 255
+    resized_image -= 0.5
+    resized_image /= 0.5
+    return resized_image
+
+
 def resize_norm_img_chinese(img, image_shape):
     imgC, imgH, imgW = image_shape
     # todo: change to 0 and modified image shape
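The Fasttext transform above uses the fasttext package's dictionary-style lookup. A minimal standalone check, assuming the cc.en.300.bin model referenced by the config has already been downloaded to the working directory:

import fasttext

model = fasttext.load_model("./cc.en.300.bin")  # path assumed, as in the yml
vec = model["hello"]  # same as model.get_word_vector("hello")
print(vec.shape)      # (300,) -- the 'fast_label' that reaches AsterLoss as batch[3]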
diff --git a/ppocr/losses/__init__.py b/ppocr/losses/__init__.py
index 0484542f0b8c3cc46ce643d9f4577fe87c6346b7..1c637c4276ec1e743b6bfd2d8f363cbe087471a4 100755
--- a/ppocr/losses/__init__.py
+++ b/ppocr/losses/__init__.py
@@ -42,10 +42,14 @@ from .combined_loss import CombinedLoss
 # table loss
 from .table_att_loss import TableAttentionLoss
 
+from .rec_aster_loss import AsterLoss
+
+
 def build_loss(config):
     support_dict = [
         'DBLoss', 'EASTLoss', 'SASTLoss', 'CTCLoss', 'ClsLoss', 'AttentionLoss',
-        'SRNLoss', 'PGLoss', 'CombinedLoss', 'NRTRLoss', 'TableAttentionLoss', 'SARLoss'
+        'SRNLoss', 'PGLoss', 'CombinedLoss', 'NRTRLoss', 'TableAttentionLoss',
+        'SARLoss', 'AsterLoss'
     ]
 
     config = copy.deepcopy(config)
diff --git a/ppocr/losses/rec_aster_loss.py b/ppocr/losses/rec_aster_loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..fbb99d29a638540b02649a8912051339c08b22dd
--- /dev/null
+++ b/ppocr/losses/rec_aster_loss.py
@@ -0,0 +1,99 @@
+# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import paddle
+from paddle import nn
+
+
+class CosineEmbeddingLoss(nn.Layer):
+    def __init__(self, margin=0.):
+        super(CosineEmbeddingLoss, self).__init__()
+        self.margin = margin
+        self.epsilon = 1e-12
+
+    def forward(self, x1, x2, target):
+        similarity = paddle.sum(
+            x1 * x2, axis=-1) / (paddle.norm(
+                x1, axis=-1) * paddle.norm(
+                    x2, axis=-1) + self.epsilon)
+        one_list = paddle.full_like(target, fill_value=1)
+        out = paddle.mean(
+            paddle.where(
+                paddle.equal(target, one_list), 1. - similarity,
+                paddle.maximum(
+                    paddle.zeros_like(similarity), similarity - self.margin)))
+
+        return out
+
+
+class AsterLoss(nn.Layer):
+    def __init__(self,
+                 weight=None,
+                 size_average=True,
+                 ignore_index=-100,
+                 sequence_normalize=False,
+                 sample_normalize=True,
+                 **kwargs):
+        super(AsterLoss, self).__init__()
+        self.weight = weight
+        self.size_average = size_average
+        self.ignore_index = ignore_index
+        self.sequence_normalize = sequence_normalize
+        self.sample_normalize = sample_normalize
+        self.loss_sem = CosineEmbeddingLoss()
+        self.is_cosin_loss = True
+        self.loss_func_rec = nn.CrossEntropyLoss(weight=None, reduction='none')
+
+    def forward(self, predicts, batch):
+        targets = batch[1].astype("int64")
+        label_lengths = batch[2].astype('int64')
+        sem_target = batch[3].astype('float32')
+        embedding_vectors = predicts['embedding_vectors']
+        rec_pred = predicts['rec_pred']
+
+        if not self.is_cosin_loss:
+            sem_loss = paddle.sum(self.loss_sem(embedding_vectors, sem_target))
+        else:
+            label_target = paddle.ones([embedding_vectors.shape[0]])
+            sem_loss = paddle.sum(
+                self.loss_sem(embedding_vectors, sem_target, label_target))
+
+        # rec loss
+        batch_size, def_max_length = targets.shape[0], targets.shape[1]
+
+        mask = paddle.zeros([batch_size, def_max_length])
+        for i in range(batch_size):
+            mask[i, :label_lengths[i]] = 1
+        mask = paddle.cast(mask, "float32")
+        max_length = max(label_lengths)
+        assert max_length == rec_pred.shape[1]
+        targets = targets[:, :max_length]
+        mask = mask[:, :max_length]
+        rec_pred = paddle.reshape(rec_pred, [-1, rec_pred.shape[2]])
+        log_probs = nn.functional.log_softmax(rec_pred, axis=1)
+        targets = paddle.reshape(targets, [-1, 1])
+        mask = paddle.reshape(mask, [-1, 1])
+        output = -paddle.index_sample(log_probs, index=targets) * mask
+        output = paddle.sum(output)
+        if self.sequence_normalize:
+            output = output / paddle.sum(mask)
+        if self.sample_normalize:
+            output = output / batch_size
+
+        loss = output + sem_loss * 0.1
+        return {'loss': loss}
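A quick numeric check of the cosine term used above: with target == 1 the loss reduces to 1 - cos(x1, x2), so the semantic loss vanishes when the predicted embedding matches the fastText target vector.

import paddle

x1 = paddle.to_tensor([[1.0, 0.0]])
x2 = paddle.to_tensor([[1.0, 0.0]])
cos = paddle.sum(x1 * x2, axis=-1) / (
    paddle.norm(x1, axis=-1) * paddle.norm(x2, axis=-1) + 1e-12)
print(float(1. - cos))  # ~0.0 for identical vectors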
diff --git a/ppocr/metrics/rec_metric.py b/ppocr/metrics/rec_metric.py
index 3e82fe756b1e42d6a77192af41e27e0fe30a86ab..db2f41c3a140ecebc42b71ee03f0ecb5cf50ca80
--- a/ppocr/metrics/rec_metric.py
+++ b/ppocr/metrics/rec_metric.py
@@ -13,13 +13,20 @@
 # limitations under the License.
 
 import Levenshtein
+import string
 
 
 class RecMetric(object):
-    def __init__(self, main_indicator='acc', **kwargs):
+    def __init__(self, main_indicator='acc', is_filter=False, **kwargs):
         self.main_indicator = main_indicator
+        self.is_filter = is_filter
         self.reset()
 
+    def _normalize_text(self, text):
+        text = ''.join(
+            filter(lambda x: x in (string.digits + string.ascii_letters), text))
+        return text.lower()
+
     def __call__(self, pred_label, *args, **kwargs):
         preds, labels = pred_label
         correct_num = 0
@@ -28,6 +35,9 @@ class RecMetric(object):
         for (pred, pred_conf), (target, _) in zip(preds, labels):
             pred = pred.replace(" ", "")
             target = target.replace(" ", "")
+            if self.is_filter:
+                pred = self._normalize_text(pred)
+                target = self._normalize_text(target)
             norm_edit_dis += Levenshtein.distance(pred, target) / max(
                 len(pred), len(target), 1)
             if pred == target:
@@ -57,4 +67,3 @@ class RecMetric(object):
         self.correct_num = 0
         self.all_num = 0
         self.norm_edit_dis = 0
-
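With is_filter enabled, as in the SEED config, accuracy becomes case- and punctuation-insensitive; the rule reduces to the following standalone snippet:

import string

def normalize(text):
    # same filtering as RecMetric._normalize_text
    return ''.join(
        c for c in text if c in string.digits + string.ascii_letters).lower()

print(normalize("Hello, World!"))  # "helloworld"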
diff --git a/ppocr/modeling/backbones/__init__.py b/ppocr/modeling/backbones/__init__.py
index 50438893c7b8511b699f0ebeaf6f4ff46e7d3106..d9815021c9a9a1ee40106c5e323bf346c3b4376d 100755
--- a/ppocr/modeling/backbones/__init__.py
+++ b/ppocr/modeling/backbones/__init__.py
@@ -29,7 +29,8 @@ def build_backbone(config, model_type):
         from .rec_nrtr_mtb import MTB
         from .rec_resnet_31 import ResNet31
         support_dict = [
-            'MobileNetV1Enhance', 'MobileNetV3', 'ResNet', 'ResNetFPN', 'MTB', "ResNet31"
+            'MobileNetV1Enhance', 'MobileNetV3', 'ResNet', 'ResNetFPN', 'MTB',
+            "ResNet31"
         ]
     elif model_type == "e2e":
         from .e2e_resnet_vd_pg import ResNet
@@ -38,6 +39,9 @@ def build_backbone(config, model_type):
         from .table_resnet_vd import ResNet
         from .table_mobilenet_v3 import MobileNetV3
         support_dict = ["ResNet", "MobileNetV3"]
+    elif model_type == "seed":
+        from .rec_resnet_aster import ResNet_ASTER
+        support_dict = ["ResNet_ASTER"]
     else:
         raise NotImplementedError
diff --git a/ppocr/modeling/backbones/rec_resnet_aster.py b/ppocr/modeling/backbones/rec_resnet_aster.py
new file mode 100644
index 0000000000000000000000000000000000000000..5bb58035753e48f1bcf09a9028049e24775df210
--- /dev/null
+++ b/ppocr/modeling/backbones/rec_resnet_aster.py
@@ -0,0 +1,148 @@
+# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import paddle
+import paddle.nn as nn
+
+import sys
+import math
+
+
+def conv3x3(in_planes, out_planes, stride=1):
+    """3x3 convolution with padding"""
+    return nn.Conv2D(
+        in_planes,
+        out_planes,
+        kernel_size=3,
+        stride=stride,
+        padding=1,
+        bias_attr=False)
+
+
+def conv1x1(in_planes, out_planes, stride=1):
+    """1x1 convolution"""
+    return nn.Conv2D(
+        in_planes, out_planes, kernel_size=1, stride=stride, bias_attr=False)
+
+
+def get_sinusoid_encoding(n_position, feat_dim, wave_length=10000):
+    # [n_position]
+    positions = paddle.arange(0, n_position)
+    # [feat_dim]
+    dim_range = paddle.arange(0, feat_dim)
+    # paddle.pow expects a Tensor base, so lift wave_length to a Tensor
+    dim_range = paddle.pow(
+        paddle.to_tensor(float(wave_length)), 2 * (dim_range // 2) / feat_dim)
+    # [n_position, feat_dim]
+    angles = paddle.unsqueeze(
+        positions, axis=1) / paddle.unsqueeze(
+            dim_range, axis=0)
+    angles = paddle.cast(angles, "float32")
+    angles[:, 0::2] = paddle.sin(angles[:, 0::2])
+    angles[:, 1::2] = paddle.cos(angles[:, 1::2])
+    return angles
+
+
+class AsterBlock(nn.Layer):
+    def __init__(self, inplanes, planes, stride=1, downsample=None):
+        super(AsterBlock, self).__init__()
+        self.conv1 = conv1x1(inplanes, planes, stride)
+        self.bn1 = nn.BatchNorm2D(planes)
+        self.relu = nn.ReLU()
+        self.conv2 = conv3x3(planes, planes)
+        self.bn2 = nn.BatchNorm2D(planes)
+        self.downsample = downsample
+        self.stride = stride
+
+    def forward(self, x):
+        residual = x
+        out = self.conv1(x)
+        out = self.bn1(out)
+        out = self.relu(out)
+        out = self.conv2(out)
+        out = self.bn2(out)
+
+        if self.downsample is not None:
+            residual = self.downsample(x)
+        out += residual
+        out = self.relu(out)
+        return out
+
+
+class ResNet_ASTER(nn.Layer):
+    """For aster or crnn"""
+
+    def __init__(self, with_lstm=True, n_group=1, in_channels=3):
+        super(ResNet_ASTER, self).__init__()
+        self.with_lstm = with_lstm
+        self.n_group = n_group
+
+        self.layer0 = nn.Sequential(
+            nn.Conv2D(
+                in_channels,
+                32,
+                kernel_size=(3, 3),
+                stride=1,
+                padding=1,
+                bias_attr=False),
+            nn.BatchNorm2D(32),
+            nn.ReLU())
+
+        self.inplanes = 32
+        self.layer1 = self._make_layer(32, 3, [2, 2])  # [16, 50]
+        self.layer2 = self._make_layer(64, 4, [2, 2])  # [8, 25]
+        self.layer3 = self._make_layer(128, 6, [2, 1])  # [4, 25]
+        self.layer4 = self._make_layer(256, 6, [2, 1])  # [2, 25]
+        self.layer5 = self._make_layer(512, 3, [2, 1])  # [1, 25]
+
+        if with_lstm:
+            self.rnn = nn.LSTM(512, 256, direction="bidirect", num_layers=2)
+            self.out_channels = 2 * 256
+        else:
+            self.out_channels = 512
+
+    def _make_layer(self, planes, blocks, stride):
+        downsample = None
+        if stride != [1, 1] or self.inplanes != planes:
+            downsample = nn.Sequential(
+                conv1x1(self.inplanes, planes, stride), nn.BatchNorm2D(planes))
+
+        layers = []
+        layers.append(AsterBlock(self.inplanes, planes, stride, downsample))
+        self.inplanes = planes
+        for _ in range(1, blocks):
+            layers.append(AsterBlock(self.inplanes, planes))
+        return nn.Sequential(*layers)
+
+    def forward(self, x):
+        x0 = self.layer0(x)
+        x1 = self.layer1(x0)
+        x2 = self.layer2(x1)
+        x3 = self.layer3(x2)
+        x4 = self.layer4(x3)
+        x5 = self.layer5(x4)
+
+        cnn_feat = x5.squeeze(2)  # [N, c, w]
+        cnn_feat = paddle.transpose(cnn_feat, perm=[0, 2, 1])
+        if self.with_lstm:
+            rnn_feat, _ = self.rnn(cnn_feat)
+            return rnn_feat
+        else:
+            return cnn_feat
+
+
+if __name__ == "__main__":
+    x = paddle.randn([3, 3, 32, 100])
+    net = ResNet_ASTER()
+    encoder_feat = net(x)
+    print(encoder_feat.shape)
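The stride schedule explains the 32x100 tps_outputsize in the config: height must collapse to 1 before squeeze(2). A standalone trace of the five stages:

# Per-stage strides of ResNet_ASTER: (2,2), (2,2), (2,1), (2,1), (2,1)
h, w = 32, 100
for sh, sw in [(2, 2), (2, 2), (2, 1), (2, 1), (2, 1)]:
    h, w = h // sh, w // sw
print(h, w)  # 1 25 -> a 25-step feature sequence (512-d with the BiLSTM)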
diff --git a/ppocr/modeling/heads/__init__.py b/ppocr/modeling/heads/__init__.py
index 80311f92a6c0c7a67673b90ee19ff5d8778ac0e8..4068515b66f0528c54236bd5e62aa023c76d2e4a 100755
--- a/ppocr/modeling/heads/__init__.py
+++ b/ppocr/modeling/heads/__init__.py
@@ -28,12 +28,14 @@ def build_head(config):
     from .rec_srn_head import SRNHead
     from .rec_nrtr_head import Transformer
     from .rec_sar_head import SARHead
+    from .rec_aster_head import AsterHead
 
     # cls head
     from .cls_head import ClsHead
     support_dict = [
         'DBHead', 'EASTHead', 'SASTHead', 'CTCHead', 'ClsHead', 'AttentionHead',
-        'SRNHead', 'PGHead', 'Transformer', 'TableAttentionHead', 'SARHead'
+        'SRNHead', 'PGHead', 'TableAttentionHead', 'SARHead', 'Transformer',
+        'AsterHead'
     ]
 
     #table head
diff --git a/ppocr/modeling/heads/rec_aster_head.py b/ppocr/modeling/heads/rec_aster_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..ed520669e36593906e240c69c9f5de28ca2d9ee8
--- /dev/null
+++ b/ppocr/modeling/heads/rec_aster_head.py
@@ -0,0 +1,391 @@
+# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import sys
+
+import paddle
+from paddle import nn
+from paddle.nn import functional as F
+
+
+class AsterHead(nn.Layer):
+    def __init__(self,
+                 in_channels,
+                 out_channels,
+                 sDim,
+                 attDim,
+                 max_len_labels,
+                 time_step=25,
+                 beam_width=5,
+                 **kwargs):
+        super(AsterHead, self).__init__()
+        self.num_classes = out_channels
+        self.in_planes = in_channels
+        self.sDim = sDim
+        self.attDim = attDim
+        self.max_len_labels = max_len_labels
+        self.decoder = AttentionRecognitionHead(in_channels, out_channels, sDim,
+                                                attDim, max_len_labels)
+        self.time_step = time_step
+        self.embeder = Embedding(self.time_step, in_channels)
+        self.beam_width = beam_width
+        self.eos = self.num_classes - 1
+
+    def forward(self, x, targets=None, embed=None):
+        return_dict = {}
+        embedding_vectors = self.embeder(x)
+
+        if self.training:
+            rec_targets, rec_lengths, _ = targets
+            rec_pred = self.decoder([x, rec_targets, rec_lengths],
+                                    embedding_vectors)
+            return_dict['rec_pred'] = rec_pred
+            return_dict['embedding_vectors'] = embedding_vectors
+        else:
+            rec_pred, rec_pred_scores = self.decoder.beam_search(
+                x, self.beam_width, self.eos, embedding_vectors)
+            return_dict['rec_pred'] = rec_pred
+            return_dict['rec_pred_scores'] = rec_pred_scores
+            return_dict['embedding_vectors'] = embedding_vectors
+
+        return return_dict
+
+
+class Embedding(nn.Layer):
+    def __init__(self, in_timestep, in_planes, mid_dim=4096, embed_dim=300):
+        super(Embedding, self).__init__()
+        self.in_timestep = in_timestep
+        self.in_planes = in_planes
+        self.embed_dim = embed_dim
+        self.mid_dim = mid_dim
+        self.eEmbed = nn.Linear(
+            in_timestep * in_planes,
+            self.embed_dim)  # Embed encoder output to a word-embedding-like vector
+
+    def forward(self, x):
+        x = paddle.reshape(x, [paddle.shape(x)[0], -1])
+        x = self.eEmbed(x)
+        return x
+
+
+class AttentionRecognitionHead(nn.Layer):
+    """
+    input: [b x 16 x 64 x in_planes]
+    output: probability sequence: [b x T x num_classes]
+    """
+
+    def __init__(self, in_channels, out_channels, sDim, attDim, max_len_labels):
+        super(AttentionRecognitionHead, self).__init__()
+        self.num_classes = out_channels  # this is the output classes, so it includes the <eos>.
+        self.in_planes = in_channels
+        self.sDim = sDim
+        self.attDim = attDim
+        self.max_len_labels = max_len_labels
+
+        self.decoder = DecoderUnit(
+            sDim=sDim, xDim=in_channels, yDim=self.num_classes, attDim=attDim)
+
+    def forward(self, x, embed):
+        x, targets, lengths = x
+        batch_size = paddle.shape(x)[0]
+        # Decoder
+        state = self.decoder.get_initial_state(embed)
+        outputs = []
+        for i in range(max(lengths)):
+            if i == 0:
+                y_prev = paddle.full(
+                    shape=[batch_size], fill_value=self.num_classes)
+            else:
+                y_prev = targets[:, i - 1]
+            output, state = self.decoder(x, state, y_prev)
+            outputs.append(output)
+        outputs = paddle.concat([_.unsqueeze(1) for _ in outputs], 1)
+        return outputs
+
+    # inference stage.
+    def sample(self, x):
+        x, _, _ = x
+        batch_size = paddle.shape(x)[0]
+        # Decoder
+        state = paddle.zeros([1, batch_size, self.sDim])
+
+        predicted_ids, predicted_scores = [], []
+        for i in range(self.max_len_labels):
+            if i == 0:
+                y_prev = paddle.full(
+                    shape=[batch_size], fill_value=self.num_classes)
+            else:
+                y_prev = predicted
+
+            output, state = self.decoder(x, state, y_prev)
+            output = F.softmax(output, axis=1)
+            score = paddle.max(output, axis=1)
+            predicted = paddle.argmax(output, axis=1)
+            predicted_ids.append(predicted.unsqueeze(1))
+            predicted_scores.append(score.unsqueeze(1))
+        predicted_ids = paddle.concat(predicted_ids, 1)
+        predicted_scores = paddle.concat(predicted_scores, 1)
+        # return predicted_ids.squeeze(), predicted_scores.squeeze()
+        return predicted_ids, predicted_scores
+
+    def beam_search(self, x, beam_width, eos, embed):
+        def _inflate(tensor, times, dim):
+            repeat_dims = [1] * tensor.dim()
+            repeat_dims[dim] = times
+            output = paddle.tile(tensor, repeat_dims)
+            return output
+
+        # https://github.com/IBM/pytorch-seq2seq/blob/fede87655ddce6c94b38886089e05321dc9802af/seq2seq/models/TopKDecoder.py
+        batch_size, l, d = x.shape
+        # inflated_encoder_feats = _inflate(encoder_feats, beam_width, 0)  # ABC --> AABBCC -/-> ABCABC
+        x = paddle.tile(
+            paddle.transpose(
+                x.unsqueeze(1), perm=[1, 0, 2, 3]), [beam_width, 1, 1, 1])
+        inflated_encoder_feats = paddle.reshape(
+            paddle.transpose(
+                x, perm=[1, 0, 2, 3]), [-1, l, d])
+
+        # Initialize the decoder
+        state = self.decoder.get_initial_state(embed, tile_times=beam_width)
+
+        pos_index = paddle.reshape(
+            paddle.arange(batch_size) * beam_width, shape=[-1, 1])
+
+        # Initialize the scores
+        sequence_scores = paddle.full(
+            shape=[batch_size * beam_width, 1], fill_value=-float('Inf'))
+        index = [i * beam_width for i in range(0, batch_size)]
+        sequence_scores[index] = 0.0
+
+        # Initialize the input vector
+        y_prev = paddle.full(
+            shape=[batch_size * beam_width], fill_value=self.num_classes)
+
+        # Store decisions for backtracking
+        stored_scores = list()
+        stored_predecessors = list()
+        stored_emitted_symbols = list()
+
+        for i in range(self.max_len_labels):
+            output, state = self.decoder(inflated_encoder_feats, state, y_prev)
+            state = paddle.unsqueeze(state, axis=0)
+            log_softmax_output = paddle.nn.functional.log_softmax(
+                output, axis=1)
+
+            sequence_scores = _inflate(sequence_scores, self.num_classes, 1)
+            sequence_scores += log_softmax_output
+            scores, candidates = paddle.topk(
+                paddle.reshape(sequence_scores, [batch_size, -1]),
+                beam_width,
+                axis=1)
+
+            # Reshape input = (bk, 1) and sequence_scores = (bk, 1)
+            y_prev = paddle.reshape(
+                candidates % self.num_classes, shape=[batch_size * beam_width])
+            sequence_scores = paddle.reshape(
+                scores, shape=[batch_size * beam_width, 1])
+
+            # Update fields for next timestep
+            pos_index = paddle.expand_as(pos_index, candidates)
+            predecessors = paddle.cast(
+                candidates / self.num_classes + pos_index, dtype='int64')
+            predecessors = paddle.reshape(
+                predecessors, shape=[batch_size * beam_width, 1])
+            state = paddle.index_select(
+                state, index=predecessors.squeeze(), axis=1)
+
+            # Update sequence scores and erase scores for the <eos> symbol so that they aren't expanded
+            stored_scores.append(sequence_scores.clone())
+            y_prev = paddle.reshape(y_prev, shape=[-1, 1])
+            eos_prev = paddle.full_like(y_prev, fill_value=eos)
+            mask = eos_prev == y_prev
+            mask = paddle.nonzero(mask)
+            if mask.dim() > 0:
+                sequence_scores = sequence_scores.numpy()
+                mask = mask.numpy()
+                sequence_scores[mask] = -float('inf')
+                sequence_scores = paddle.to_tensor(sequence_scores)
+
+            # Cache results for backtracking
+            stored_predecessors.append(predecessors)
+            y_prev = paddle.squeeze(y_prev)
+            stored_emitted_symbols.append(y_prev)
+
+        # Do backtracking to return the optimal values
+        # ====== backtrack ======
+        # Initialize return variables given different types
+        p = list()
+        l = [[self.max_len_labels] * beam_width for _ in range(batch_size)
+             ]  # Placeholder for lengths of top-k sequences
+
+        # the last step outputs of the beams are not sorted,
+        # so they are sorted here
+        sorted_score, sorted_idx = paddle.topk(
+            paddle.reshape(
+                stored_scores[-1], shape=[batch_size, beam_width]),
+            beam_width)
+
+        # initialize the sequence scores with the sorted last step beam scores
+        s = sorted_score.clone()
+
+        batch_eos_found = [0] * batch_size  # the number of EOS found
+        # in the backward loop below for each batch
+        t = self.max_len_labels - 1
+        # initialize the back pointer with the sorted order of the last step beams.
+        # add pos_index for indexing variable with b*k as the first dimension.
+        t_predecessors = paddle.reshape(
+            sorted_idx + pos_index.expand_as(sorted_idx),
+            shape=[batch_size * beam_width])
+        while t >= 0:
+            # Re-order the variables with the back pointer
+            current_symbol = paddle.index_select(
+                stored_emitted_symbols[t], index=t_predecessors, axis=0)
+            t_predecessors = paddle.index_select(
+                stored_predecessors[t].squeeze(), index=t_predecessors, axis=0)
+            eos_indices = stored_emitted_symbols[t] == eos
+            eos_indices = paddle.nonzero(eos_indices)
+
+            if eos_indices.dim() > 0:
+                for i in range(eos_indices.shape[0] - 1, -1, -1):
+                    # Indices of the EOS symbol for both variables
+                    # with b*k as the first dimension, and b, k for
+                    # the first two dimensions
+                    idx = eos_indices[i]
+                    b_idx = int(idx[0] / beam_width)
+                    # The indices of the replacing position
+                    # according to the replacement strategy noted above
+                    res_k_idx = beam_width - (batch_eos_found[b_idx] %
+                                              beam_width) - 1
+                    batch_eos_found[b_idx] += 1
+                    res_idx = b_idx * beam_width + res_k_idx
+
+                    # Replace the old information in return variables
+                    # with the new ended sequence information
+                    t_predecessors[res_idx] = stored_predecessors[t][idx[0]]
+                    current_symbol[res_idx] = stored_emitted_symbols[t][idx[0]]
+                    s[b_idx, res_k_idx] = stored_scores[t][idx[0], 0]
+                    l[b_idx][res_k_idx] = t + 1
+
+            # record the back tracked results
+            p.append(current_symbol)
+            t -= 1
+
+        # Sort and re-order again as the added ended sequences may change
+        # the order (very unlikely)
+        s, re_sorted_idx = s.topk(beam_width)
+        for b_idx in range(batch_size):
+            l[b_idx] = [
+                l[b_idx][k_idx.item()] for k_idx in re_sorted_idx[b_idx, :]
+            ]
+
+        re_sorted_idx = paddle.reshape(
+            re_sorted_idx + pos_index.expand_as(re_sorted_idx),
+            [batch_size * beam_width])
+
+        # Reverse the sequences and re-order at the same time
+        # It is reversed because the backtracking happens in reverse time order
+        p = [
+            paddle.reshape(
+                paddle.index_select(step, re_sorted_idx, 0),
+                shape=[batch_size, beam_width, -1]) for step in reversed(p)
+        ]
+        p = paddle.concat(p, -1)[:, 0, :]
+        return p, paddle.ones_like(p)
+
+
+class AttentionUnit(nn.Layer):
+    def __init__(self, sDim, xDim, attDim):
+        super(AttentionUnit, self).__init__()
+
+        self.sDim = sDim
+        self.xDim = xDim
+        self.attDim = attDim
+
+        self.sEmbed = nn.Linear(sDim, attDim)
+        self.xEmbed = nn.Linear(xDim, attDim)
+        self.wEmbed = nn.Linear(attDim, 1)
+
+    def forward(self, x, sPrev):
+        batch_size, T, _ = x.shape  # [b x T x xDim]
+        x = paddle.reshape(x, [-1, self.xDim])  # [(b x T) x xDim]
+        xProj = self.xEmbed(x)  # [(b x T) x attDim]
+        xProj = paddle.reshape(xProj, [batch_size, T, -1])  # [b x T x attDim]
+
+        sPrev = sPrev.squeeze(0)
+        sProj = self.sEmbed(sPrev)  # [b x attDim]
+        sProj = paddle.unsqueeze(sProj, 1)  # [b x 1 x attDim]
+        sProj = paddle.expand(sProj,
+                              [batch_size, T, self.attDim])  # [b x T x attDim]
+
+        sumTanh = paddle.tanh(sProj + xProj)
+        sumTanh = paddle.reshape(sumTanh, [-1, self.attDim])
+
+        vProj = self.wEmbed(sumTanh)  # [(b x T) x 1]
+        vProj = paddle.reshape(vProj, [batch_size, T])
+        alpha = F.softmax(
+            vProj, axis=1)  # attention weights for each sample in the minibatch
+        return alpha
+
+
+class DecoderUnit(nn.Layer):
+    def __init__(self, sDim, xDim, yDim, attDim):
+        super(DecoderUnit, self).__init__()
+        self.sDim = sDim
+        self.xDim = xDim
+        self.yDim = yDim
+        self.attDim = attDim
+        self.emdDim = attDim
+
+        self.attention_unit = AttentionUnit(sDim, xDim, attDim)
+        self.tgt_embedding = nn.Embedding(
+            yDim + 1, self.emdDim, weight_attr=nn.initializer.Normal(
+                std=0.01))  # the last is used for <BOS>
+        self.gru = nn.GRUCell(input_size=xDim + self.emdDim, hidden_size=sDim)
+        self.fc = nn.Linear(
+            sDim,
+            yDim,
+            weight_attr=nn.initializer.Normal(std=0.01),
+            bias_attr=nn.initializer.Constant(value=0))
+        self.embed_fc = nn.Linear(300, self.sDim)
+
+    def get_initial_state(self, embed, tile_times=1):
+        assert embed.shape[1] == 300
+        state = self.embed_fc(embed)  # N * sDim
+        if tile_times != 1:
+            state = state.unsqueeze(1)
+            trans_state = paddle.transpose(state, perm=[1, 0, 2])
+            state = paddle.tile(trans_state, repeat_times=[tile_times, 1, 1])
+            trans_state = paddle.transpose(state, perm=[1, 0, 2])
+            state = paddle.reshape(trans_state, shape=[-1, self.sDim])
+        state = state.unsqueeze(0)  # 1 * N * sDim
+        return state
+
+    def forward(self, x, sPrev, yPrev):
+        # x: feature sequence from the image decoder.
+        batch_size, T, _ = x.shape
+        alpha = self.attention_unit(x, sPrev)
+        context = paddle.squeeze(paddle.matmul(alpha.unsqueeze(1), x), axis=1)
+        yPrev = paddle.cast(yPrev, dtype="int64")
+        yProj = self.tgt_embedding(yPrev)
+
+        concat_context = paddle.concat([yProj, context], 1)
+        concat_context = paddle.squeeze(concat_context, 1)
+        sPrev = paddle.squeeze(sPrev, 0)
+        output, state = self.gru(concat_context, sPrev)
+        output = paddle.squeeze(output, axis=1)
+        output = self.fc(output)
+        return output, state
\ No newline at end of file
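The `_inflate` helper inside beam_search tiles whole tensors (ABC -> ABCABC), which is why the encoder features go through the unsqueeze/transpose trick instead — a standalone demo of the plain tiling behaviour:

import paddle

def _inflate(tensor, times, dim):
    repeat_dims = [1] * tensor.dim()
    repeat_dims[dim] = times
    return paddle.tile(tensor, repeat_dims)

x = paddle.to_tensor([[1.0], [2.0]])
# the whole batch is repeated (ABCABC pattern), not per-sample blocks (AABBCC)
print(_inflate(x, 3, 0).numpy().ravel())  # [1. 2. 1. 2. 1. 2.]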
diff --git a/ppocr/modeling/transforms/__init__.py b/ppocr/modeling/transforms/__init__.py
index 78eaecccc55f77d6624aa0c5bdb839acc3462129..0e02a1c0cf88eefa3f7dde4f606b7db3016e4e1c 100755
--- a/ppocr/modeling/transforms/__init__.py
+++ b/ppocr/modeling/transforms/__init__.py
@@ -17,8 +17,9 @@ __all__ = ['build_transform']
 
 def build_transform(config):
     from .tps import TPS
+    from .tps import STN_ON
 
-    support_dict = ['TPS']
+    support_dict = ['TPS', 'STN_ON']
 
     module_name = config.pop('name')
     assert module_name in support_dict, Exception(
diff --git a/ppocr/modeling/transforms/stn.py b/ppocr/modeling/transforms/stn.py
new file mode 100644
index 0000000000000000000000000000000000000000..23bd21891f3eb5ea794e4d46f8236010acc65b45
--- /dev/null
+++ b/ppocr/modeling/transforms/stn.py
@@ -0,0 +1,108 @@
+# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import math
+import paddle
+from paddle import nn, ParamAttr
+from paddle.nn import functional as F
+import numpy as np
+
+
+def conv3x3_block(in_channels, out_channels, stride=1):
+    n = 3 * 3 * out_channels
+    w = math.sqrt(2. / n)
+    conv_layer = nn.Conv2D(
+        in_channels,
+        out_channels,
+        kernel_size=3,
+        stride=stride,
+        padding=1,
+        weight_attr=nn.initializer.Normal(
+            mean=0.0, std=w),
+        bias_attr=nn.initializer.Constant(0))
+    block = nn.Sequential(conv_layer, nn.BatchNorm2D(out_channels), nn.ReLU())
+    return block
+
+
+class STN(nn.Layer):
+    def __init__(self, in_channels, num_ctrlpoints, activation='none'):
+        super(STN, self).__init__()
+        self.in_channels = in_channels
+        self.num_ctrlpoints = num_ctrlpoints
+        self.activation = activation
+        self.stn_convnet = nn.Sequential(
+            conv3x3_block(in_channels, 32),  # 32x64
+            nn.MaxPool2D(
+                kernel_size=2, stride=2),
+            conv3x3_block(32, 64),  # 16x32
+            nn.MaxPool2D(
+                kernel_size=2, stride=2),
+            conv3x3_block(64, 128),  # 8x16
+            nn.MaxPool2D(
+                kernel_size=2, stride=2),
+            conv3x3_block(128, 256),  # 4x8
+            nn.MaxPool2D(
+                kernel_size=2, stride=2),
+            conv3x3_block(256, 256),  # 2x4
+            nn.MaxPool2D(
+                kernel_size=2, stride=2),
+            conv3x3_block(256, 256))  # 1x2
+        self.stn_fc1 = nn.Sequential(
+            nn.Linear(
+                2 * 256,
+                512,
+                weight_attr=nn.initializer.Normal(0, 0.001),
+                bias_attr=nn.initializer.Constant(0)),
+            nn.BatchNorm1D(512),
+            nn.ReLU())
+        fc2_bias = self.init_stn()
+        self.stn_fc2 = nn.Linear(
+            512,
+            num_ctrlpoints * 2,
+            weight_attr=nn.initializer.Constant(0.0),
+            bias_attr=nn.initializer.Assign(fc2_bias))
+
+    def init_stn(self):
+        margin = 0.01
+        sampling_num_per_side = int(self.num_ctrlpoints / 2)
+        ctrl_pts_x = np.linspace(margin, 1. - margin, sampling_num_per_side)
+        ctrl_pts_y_top = np.ones(sampling_num_per_side) * margin
+        ctrl_pts_y_bottom = np.ones(sampling_num_per_side) * (1 - margin)
+        ctrl_pts_top = np.stack([ctrl_pts_x, ctrl_pts_y_top], axis=1)
+        ctrl_pts_bottom = np.stack([ctrl_pts_x, ctrl_pts_y_bottom], axis=1)
+        ctrl_points = np.concatenate(
+            [ctrl_pts_top, ctrl_pts_bottom], axis=0).astype(np.float32)
+        if self.activation == 'none':
+            pass
+        elif self.activation == 'sigmoid':
+            ctrl_points = -np.log(1. / ctrl_points - 1.)
+        ctrl_points = paddle.to_tensor(ctrl_points)
+        fc2_bias = paddle.reshape(
+            ctrl_points, shape=[ctrl_points.shape[0] * ctrl_points.shape[1]])
+        return fc2_bias
+
+    def forward(self, x):
+        x = self.stn_convnet(x)
+        batch_size, _, h, w = x.shape
+        x = paddle.reshape(x, shape=(batch_size, -1))
+        img_feat = self.stn_fc1(x)
+        x = self.stn_fc2(0.1 * img_feat)
+        if self.activation == 'sigmoid':
+            x = F.sigmoid(x)
+        x = paddle.reshape(x, shape=[-1, self.num_ctrlpoints, 2])
+        return img_feat, x
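init_stn biases stn_fc2 so that, before any training, the predicted control points form two horizontal rows just inside the image; the layout can be checked standalone:

import numpy as np

margin, n_side = 0.01, 10  # 20 control points, as in the config
xs = np.linspace(margin, 1. - margin, n_side)
top = np.stack([xs, np.full(n_side, margin)], axis=1)
bottom = np.stack([xs, np.full(n_side, 1. - margin)], axis=1)
print(np.concatenate([top, bottom]).shape)  # (20, 2), all points inside [0, 1]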
diff --git a/ppocr/modeling/transforms/tps.py b/ppocr/modeling/transforms/tps.py
index dcce6246ac64b4b84229cbd69a4dc53c658b4c7b..81221b0351944ef8a054d44d3857a8c0ad61e052
--- a/ppocr/modeling/transforms/tps.py
+++ b/ppocr/modeling/transforms/tps.py
@@ -22,6 +22,9 @@ from paddle import nn, ParamAttr
 from paddle.nn import functional as F
 import numpy as np
 
+from .tps_spatial_transformer import TPSSpatialTransformer
+from .stn import STN
+
 
 class ConvBNLayer(nn.Layer):
     def __init__(self,
@@ -231,7 +234,8 @@ class GridGenerator(nn.Layer):
         """ Return inv_delta_C which is needed to calculate T """
         F = self.F
         hat_eye = paddle.eye(F, dtype='float64')  # F x F
-        hat_C = paddle.norm(C.reshape([1, F, 2]) - C.reshape([F, 1, 2]), axis=2) + hat_eye
+        hat_C = paddle.norm(
+            C.reshape([1, F, 2]) - C.reshape([F, 1, 2]), axis=2) + hat_eye
         hat_C = (hat_C**2) * paddle.log(hat_C)
         delta_C = paddle.concat(  # F+3 x F+3
             [
@@ -301,3 +305,25 @@ class TPS(nn.Layer):
             [-1, image.shape[2], image.shape[3], 2])
         batch_I_r = F.grid_sample(x=image, grid=batch_P_prime)
         return batch_I_r
+
+
+class STN_ON(nn.Layer):
+    def __init__(self, in_channels, tps_inputsize, tps_outputsize,
+                 num_control_points, tps_margins, stn_activation):
+        super(STN_ON, self).__init__()
+        self.tps = TPSSpatialTransformer(
+            output_image_size=tuple(tps_outputsize),
+            num_control_points=num_control_points,
+            margins=tuple(tps_margins))
+        self.stn_head = STN(in_channels=in_channels,
+                            num_ctrlpoints=num_control_points,
+                            activation=stn_activation)
+        self.tps_inputsize = tps_inputsize
+        self.out_channels = in_channels
+
+    def forward(self, image):
+        stn_input = paddle.nn.functional.interpolate(
+            image, self.tps_inputsize, mode="bilinear", align_corners=True)
+        stn_img_feat, ctrl_points = self.stn_head(stn_input)
+        x, _ = self.tps(image, ctrl_points)
+        return x
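A shape walk through STN_ON with this patch's config values (sizes taken from the yml, so treat them as assumptions): the localization branch sees a small 32x64 copy, while the TPS warp samples from the full-resolution input into the 32x100 rectified crop.

import paddle

image = paddle.randn([2, 3, 64, 256])  # SEEDResize output
stn_input = paddle.nn.functional.interpolate(
    image, [32, 64], mode="bilinear", align_corners=True)
print(stn_input.shape)  # [2, 3, 32, 64] -> STN predicts 20 (x, y) control points
# self.tps(image, ctrl_points) then yields the [2, 3, 32, 100] crop fed to
# ResNet_ASTER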
diff --git a/ppocr/modeling/transforms/tps_spatial_transformer.py b/ppocr/modeling/transforms/tps_spatial_transformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..731e3ee9f0e0090f8fe4c6642325667633e560e0
--- /dev/null
+++ b/ppocr/modeling/transforms/tps_spatial_transformer.py
@@ -0,0 +1,157 @@
+# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import math
+import paddle
+from paddle import nn, ParamAttr
+from paddle.nn import functional as F
+import numpy as np
+import itertools
+
+
+def grid_sample(input, grid, canvas=None):
+    input.stop_gradient = False
+    output = F.grid_sample(input, grid)
+    if canvas is None:
+        return output
+    else:
+        input_mask = paddle.ones(shape=input.shape)
+        output_mask = F.grid_sample(input_mask, grid)
+        padded_output = output * output_mask + canvas * (1 - output_mask)
+        return padded_output
+
+
+# phi(x1, x2) = r^2 * log(r), where r = ||x1 - x2||_2
+def compute_partial_repr(input_points, control_points):
+    N = input_points.shape[0]
+    M = control_points.shape[0]
+    pairwise_diff = paddle.reshape(
+        input_points, shape=[N, 1, 2]) - paddle.reshape(
+            control_points, shape=[1, M, 2])
+    # original implementation, very slow
+    # pairwise_dist = torch.sum(pairwise_diff ** 2, dim = 2)  # square of distance
+    pairwise_diff_square = pairwise_diff * pairwise_diff
+    pairwise_dist = pairwise_diff_square[:, :, 0] + pairwise_diff_square[:, :,
+                                                                         1]
+    repr_matrix = 0.5 * pairwise_dist * paddle.log(pairwise_dist)
+    # fix numerical error for 0 * log(0), substitute all nan with 0
+    mask = repr_matrix != repr_matrix
+    repr_matrix[mask] = 0
+    return repr_matrix
+
+
+# output_ctrl_pts are specified, according to our task.
+def build_output_control_points(num_control_points, margins):
+    margin_x, margin_y = margins
+    num_ctrl_pts_per_side = num_control_points // 2
+    ctrl_pts_x = np.linspace(margin_x, 1.0 - margin_x, num_ctrl_pts_per_side)
+    ctrl_pts_y_top = np.ones(num_ctrl_pts_per_side) * margin_y
+    ctrl_pts_y_bottom = np.ones(num_ctrl_pts_per_side) * (1.0 - margin_y)
+    ctrl_pts_top = np.stack([ctrl_pts_x, ctrl_pts_y_top], axis=1)
+    ctrl_pts_bottom = np.stack([ctrl_pts_x, ctrl_pts_y_bottom], axis=1)
+    # ctrl_pts_top = ctrl_pts_top[1:-1, :]
+    # ctrl_pts_bottom = ctrl_pts_bottom[1:-1, :]
+    output_ctrl_pts_arr = np.concatenate(
+        [ctrl_pts_top, ctrl_pts_bottom], axis=0)
+    output_ctrl_pts = paddle.to_tensor(output_ctrl_pts_arr)
+    return output_ctrl_pts
+
+
+class TPSSpatialTransformer(nn.Layer):
+    def __init__(self,
+                 output_image_size=None,
+                 num_control_points=None,
+                 margins=None):
+        super(TPSSpatialTransformer, self).__init__()
+        self.output_image_size = output_image_size
+        self.num_control_points = num_control_points
+        self.margins = margins
+
+        self.target_height, self.target_width = output_image_size
+        target_control_points = build_output_control_points(num_control_points,
+                                                            margins)
+        N = num_control_points
+        # N = N - 4
+
+        # create padded kernel matrix
+        forward_kernel = paddle.zeros(shape=[N + 3, N + 3])
+        target_control_partial_repr = compute_partial_repr(
+            target_control_points, target_control_points)
+        target_control_partial_repr = paddle.cast(target_control_partial_repr,
+                                                  forward_kernel.dtype)
+        forward_kernel[:N, :N] = target_control_partial_repr
+        forward_kernel[:N, -3] = 1
+        forward_kernel[-3, :N] = 1
+        target_control_points = paddle.cast(target_control_points,
+                                            forward_kernel.dtype)
+        forward_kernel[:N, -2:] = target_control_points
+        forward_kernel[-2:, :N] = paddle.transpose(
+            target_control_points, perm=[1, 0])
+        # compute inverse matrix
+        inverse_kernel = paddle.inverse(forward_kernel)
+
+        # create target coordinate matrix
+        HW = self.target_height * self.target_width
+        target_coordinate = list(
+            itertools.product(
+                range(self.target_height), range(self.target_width)))
+        target_coordinate = paddle.to_tensor(target_coordinate)  # HW x 2
+        Y, X = paddle.split(
+            target_coordinate, target_coordinate.shape[1], axis=1)
+        # Y, X = target_coordinate.split(1, dim=1)
+        Y = Y / (self.target_height - 1)
+        X = X / (self.target_width - 1)
+        target_coordinate = paddle.concat(
+            [X, Y], axis=1)  # convert from (y, x) to (x, y)
+        target_coordinate_partial_repr = compute_partial_repr(
+            target_coordinate, target_control_points)
+        target_coordinate_repr = paddle.concat(
+            [
+                target_coordinate_partial_repr, paddle.ones(shape=[HW, 1]),
+                target_coordinate
+            ],
+            axis=1)
+
+        # register precomputed matrices
+        self.inverse_kernel = inverse_kernel
+        self.padding_matrix = paddle.zeros(shape=[3, 2])
+        self.target_coordinate_repr = target_coordinate_repr
+        self.target_control_points = target_control_points
+
+    def forward(self, input, source_control_points):
+        assert source_control_points.ndimension() == 3
+        assert source_control_points.shape[1] == self.num_control_points
+        assert source_control_points.shape[2] == 2
+        # batch_size = source_control_points.shape[0]
+        batch_size = paddle.shape(source_control_points)[0]
+
+        self.padding_matrix = paddle.expand(
+            self.padding_matrix, shape=[batch_size, 3, 2])
+        Y = paddle.concat([source_control_points, self.padding_matrix], 1)
+        mapping_matrix = paddle.matmul(self.inverse_kernel, Y)
+        source_coordinate = paddle.matmul(self.target_coordinate_repr,
+                                          mapping_matrix)
+
+        grid = paddle.reshape(
+            source_coordinate,
+            shape=[-1, self.target_height, self.target_width, 2])
+        grid = paddle.clip(grid, 0,
+                           1)  # the source_control_points may be out of [0, 1]
+        # the input to grid_sample is normalized [-1, 1], but what we get is [0, 1]
+        grid = 2.0 * grid - 1.0
+        output_maps = grid_sample(input, grid, canvas=None)
+        return output_maps, source_coordinate
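Since pairwise_dist above is already the squared distance r^2, the kernel evaluates 0.5 * r^2 * log(r^2) = r^2 * log(r), matching the phi noted in the comment; a numeric spot check:

import paddle

r2 = paddle.to_tensor([1.0, 4.0])  # squared distances
phi = 0.5 * r2 * paddle.log(r2)    # == r^2 * log(r)
print(phi.numpy())                 # [0.     2.7726]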
diff --git a/ppocr/optimizer/optimizer.py b/ppocr/optimizer/optimizer.py
index 8215b92d8c8d05c2b3c2e95ac989bf4ea011310b..34098c0fad553f7d39f6b5341e4da70a263eeaea
--- a/ppocr/optimizer/optimizer.py
+++ b/ppocr/optimizer/optimizer.py
@@ -127,3 +127,33 @@ class RMSProp(object):
             grad_clip=self.grad_clip,
             parameters=parameters)
         return opt
+
+
+class Adadelta(object):
+    def __init__(self,
+                 learning_rate=0.001,
+                 epsilon=1e-08,
+                 rho=0.95,
+                 parameter_list=None,
+                 weight_decay=None,
+                 grad_clip=None,
+                 name=None,
+                 **kwargs):
+        self.learning_rate = learning_rate
+        self.epsilon = epsilon
+        self.rho = rho
+        self.parameter_list = parameter_list
+        self.weight_decay = weight_decay
+        self.grad_clip = grad_clip
+        self.name = name
+
+    def __call__(self, parameters):
+        opt = optim.Adadelta(
+            learning_rate=self.learning_rate,
+            epsilon=self.epsilon,
+            rho=self.rho,
+            weight_decay=self.weight_decay,
+            grad_clip=self.grad_clip,
+            name=self.name,
+            parameters=parameters)
+        return opt
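The wrapper ultimately makes a plain paddle.optimizer.Adadelta call; a minimal standalone equivalent with the same rho/epsilon defaults:

import paddle

layer = paddle.nn.Linear(4, 2)
opt = paddle.optimizer.Adadelta(
    learning_rate=1.0, rho=0.95, epsilon=1e-08,
    parameters=layer.parameters())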
diff --git a/ppocr/postprocess/__init__.py b/ppocr/postprocess/__init__.py
index 77081abeb691f01d0d3fcc8285d78cf5878411a7..cc03984756dd594202ba3cb4cd989371ed7a808e
--- a/ppocr/postprocess/__init__.py
+++ b/ppocr/postprocess/__init__.py
@@ -24,17 +24,19 @@ __all__ = ['build_post_process']
 from .db_postprocess import DBPostProcess, DistillationDBPostProcess
 from .east_postprocess import EASTPostProcess
 from .sast_postprocess import SASTPostProcess
-from .rec_postprocess import CTCLabelDecode, AttnLabelDecode, SRNLabelDecode, DistillationCTCLabelDecode, NRTRLabelDecode, \
-    TableLabelDecode, SARLabelDecode
+from .rec_postprocess import CTCLabelDecode, AttnLabelDecode, SRNLabelDecode, DistillationCTCLabelDecode, \
+    TableLabelDecode, NRTRLabelDecode, SARLabelDecode, SEEDLabelDecode
 from .cls_postprocess import ClsPostProcess
 from .pg_postprocess import PGPostProcess
 
+
 def build_post_process(config, global_config=None):
     support_dict = [
         'DBPostProcess', 'EASTPostProcess', 'SASTPostProcess', 'CTCLabelDecode',
         'AttnLabelDecode', 'ClsPostProcess', 'SRNLabelDecode', 'PGPostProcess',
         'DistillationCTCLabelDecode', 'TableLabelDecode',
-        'DistillationDBPostProcess', 'NRTRLabelDecode', 'SARLabelDecode'
+        'DistillationDBPostProcess', 'NRTRLabelDecode', 'SARLabelDecode',
+        'SEEDLabelDecode'
     ]
 
     config = copy.deepcopy(config)
diff --git a/ppocr/postprocess/rec_postprocess.py b/ppocr/postprocess/rec_postprocess.py
index 6ff375eb43e773f68f89f66663835ffe45da09b5..5d874eaefaac57d4645f063a775983101b3f9039
--- a/ppocr/postprocess/rec_postprocess.py
+++ b/ppocr/postprocess/rec_postprocess.py
@@ -303,6 +303,81 @@ class AttnLabelDecode(BaseRecLabelDecode):
         return idx
 
 
+class SEEDLabelDecode(BaseRecLabelDecode):
+    """ Convert between text-label and text-index """
+
+    def __init__(self,
+                 character_dict_path=None,
+                 character_type='ch',
+                 use_space_char=False,
+                 **kwargs):
+        super(SEEDLabelDecode, self).__init__(character_dict_path,
+                                              character_type, use_space_char)
+
+    def add_special_char(self, dict_character):
+        self.beg_str = "sos"
+        self.end_str = "eos"
+        dict_character = dict_character + [self.end_str]
+        return dict_character
+
+    def get_ignored_tokens(self):
+        end_idx = self.get_beg_end_flag_idx("eos")
+        return [end_idx]
+
+    def get_beg_end_flag_idx(self, beg_or_end):
+        if beg_or_end == "sos":
+            idx = np.array(self.dict[self.beg_str])
+        elif beg_or_end == "eos":
+            idx = np.array(self.dict[self.end_str])
+        else:
+            assert False, "unsupport type %s in get_beg_end_flag_idx" % beg_or_end
+        return idx
+
+    def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
+        """ convert text-index into text-label. """
+        result_list = []
+        [end_idx] = self.get_ignored_tokens()
+        batch_size = len(text_index)
+        for batch_idx in range(batch_size):
+            char_list = []
+            conf_list = []
+            for idx in range(len(text_index[batch_idx])):
+                if int(text_index[batch_idx][idx]) == int(end_idx):
+                    break
+                if is_remove_duplicate:
+                    # only for predict
+                    if idx > 0 and text_index[batch_idx][idx - 1] == text_index[
+                            batch_idx][idx]:
+                        continue
+                char_list.append(self.character[int(text_index[batch_idx][
+                    idx])])
+                if text_prob is not None:
+                    conf_list.append(text_prob[batch_idx][idx])
+                else:
+                    conf_list.append(1)
+            text = ''.join(char_list)
+            result_list.append((text, np.mean(conf_list)))
+        return result_list
+
+    def __call__(self, preds, label=None, *args, **kwargs):
+        if "rec_pred_scores" in preds:
+            # beam search already returns indices and scores
+            preds_idx = preds["rec_pred"]
+            preds_prob = preds["rec_pred_scores"]
+        else:
+            preds_idx = preds["rec_pred"].argmax(axis=2)
+            preds_prob = preds["rec_pred"].max(axis=2)
+        if isinstance(preds_idx, paddle.Tensor):
+            preds_idx = preds_idx.numpy()
+        if isinstance(preds_prob, paddle.Tensor):
+            preds_prob = preds_prob.numpy()
+        text = self.decode(preds_idx, preds_prob, is_remove_duplicate=False)
+        if label is None:
+            return text
+        label = self.decode(label, is_remove_duplicate=False)
+        return text, label
+
+
 class SRNLabelDecode(BaseRecLabelDecode):
     """ Convert between text-label and text-index """
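The greedy branch of SEEDLabelDecode reads argmax indices until the <eos> slot is hit; a toy run with a hypothetical four-entry dictionary (not the real charset):

import numpy as np

character = ['a', 'b', 'c', 'eos']   # 'eos' occupies the last slot
preds_idx = np.array([[0, 1, 3, 3]])
chars = []
for i in preds_idx[0]:
    if i == 3:                       # end_idx -> stop decoding
        break
    chars.append(character[i])
print(''.join(chars))                # "ab"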
""" + result_list = [] + [end_idx] = self.get_ignored_tokens() + batch_size = len(text_index) + for batch_idx in range(batch_size): + char_list = [] + conf_list = [] + for idx in range(len(text_index[batch_idx])): + if int(text_index[batch_idx][idx]) == int(end_idx): + break + if is_remove_duplicate: + # only for predict + if idx > 0 and text_index[batch_idx][idx - 1] == text_index[ + batch_idx][idx]: + continue + char_list.append(self.character[int(text_index[batch_idx][ + idx])]) + if text_prob is not None: + conf_list.append(text_prob[batch_idx][idx]) + else: + conf_list.append(1) + text = ''.join(char_list) + result_list.append((text, np.mean(conf_list))) + return result_list + + def __call__(self, preds, label=None, *args, **kwargs): + """ + text = self.decode(text) + if label is None: + return text + else: + label = self.decode(label, is_remove_duplicate=False) + return text, label + """ + preds_idx = preds["rec_pred"] + if isinstance(preds_idx, paddle.Tensor): + preds_idx = preds_idx.numpy() + if "rec_pred_scores" in preds: + preds_idx = preds["rec_pred"] + preds_prob = preds["rec_pred_scores"] + else: + preds_idx = preds["rec_pred"].argmax(axis=2) + preds_prob = preds["rec_pred"].max(axis=2) + text = self.decode(preds_idx, preds_prob, is_remove_duplicate=False) + if label is None: + return text + label = self.decode(label, is_remove_duplicate=False) + return text, label + + class SRNLabelDecode(BaseRecLabelDecode): """ Convert between text-label and text-index """ diff --git a/tools/program.py b/tools/program.py index d6d47d047b5bc65024b5db866b67ee260f8f8bb3..6de327c7b784472a02c1cfc78926a84c7de9272a 100755 --- a/tools/program.py +++ b/tools/program.py @@ -188,10 +188,12 @@ def train(config, use_srn = config['Architecture']['algorithm'] == "SRN" use_nrtr = config['Architecture']['algorithm'] == "NRTR" use_sar = config['Architecture']['algorithm'] == 'SAR' + use_seed = config['Architecture']['algorithm'] == 'SEED' try: model_type = config['Architecture']['model_type'] except: model_type = None + algorithm = config['Architecture']['algorithm'] if 'start_epoch' in best_model_dict: start_epoch = best_model_dict['start_epoch'] @@ -215,7 +217,7 @@ def train(config, images = batch[0] if use_srn: model_average = True - if use_srn or model_type == 'table' or use_nrtr or use_sar: + if use_srn or model_type == 'table' or use_nrtr or use_sar or use_seed: preds = model(images, data=batch[1:]) else: preds = model(images) @@ -402,7 +404,7 @@ def preprocess(is_train=False): alg = config['Architecture']['algorithm'] assert alg in [ 'EAST', 'DB', 'SAST', 'Rosetta', 'CRNN', 'STARNet', 'RARE', 'SRN', - 'CLS', 'PGNet', 'Distillation', 'NRTR', 'TableAttn', 'SAR' + 'CLS', 'PGNet', 'Distillation', 'NRTR', 'TableAttn', 'SAR', 'ASTER' ] device = 'gpu:{}'.format(dist.ParallelEnv().dev_id) if use_gpu else 'cpu'