diff --git a/configs/rec/rec_resnet_stn_bilstm_att.yml b/configs/rec/rec_resnet_stn_bilstm_att.yml
new file mode 100644
index 0000000000000000000000000000000000000000..f705f1e23db61799fa1974804daa368e3cc82661
--- /dev/null
+++ b/configs/rec/rec_resnet_stn_bilstm_att.yml
@@ -0,0 +1,101 @@
+Global:
+ use_gpu: False
+ epoch_num: 400
+ log_smooth_window: 20
+ print_batch_step: 10
+ save_model_dir: ./output/rec/rec_resnet_stn_bilstm_att/
+ save_epoch_step: 3
+ # evaluation is run every 2000 iterations
+ eval_batch_step: [0, 2000]
+ cal_metric_during_train: True
+ pretrained_model:
+ checkpoints:
+ save_inference_dir:
+ use_visualdl: False
+ infer_img: doc/imgs_words/ch/word_1.jpg
+ # for data or label process
+ character_dict_path:
+ character_type: EN_symbol
+ max_text_length: 25
+ infer_mode: False
+ use_space_char: False
+ save_res_path: ./output/rec/predicts_resnet_stn_bilstm_att.txt
+
+
+Optimizer:
+ name: Adam
+ beta1: 0.9
+ beta2: 0.999
+ lr:
+ learning_rate: 0.0005
+ regularizer:
+ name: 'L2'
+ factor: 0.00000
+
+Architecture:
+ model_type: rec
+ algorithm: ASTER
+ Transform:
+ name: STN_ON
+ tps_inputsize: [32, 64]
+ tps_outputsize: [32, 100]
+ num_control_points: 20
+ tps_margins: [0.05,0.05]
+ stn_activation: none
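+ # the STN sees a 32x64 downsampled copy; TPS emits 32x100 images for the backbone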
+ Backbone:
+ name: ResNet_ASTER
+ Head:
+ name: AsterHead
+ sDim: 512
+ attDim: 512
+ max_len_labels: 100
+
+Loss:
+ name: AsterLoss
+
+PostProcess:
+ name: AttnLabelDecode
+
+Metric:
+ name: RecMetric
+ main_indicator: acc
+
+Train:
+ dataset:
+ name: SimpleDataSet
+ data_dir: ./train_data/ic15_data/
+ label_file_list: ["./train_data/ic15_data/1.txt"]
+ transforms:
+ - DecodeImage: # load image
+ img_mode: BGR
+ channel_first: False
+ - AttnLabelEncode: # Class handling label
+ - RecResizeImg:
+ image_shape: [3, 32, 100]
+ - KeepKeys:
+ keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order
+ loader:
+ shuffle: True
+ batch_size_per_card: 2
+ drop_last: True
+ num_workers: 8
+
+Eval:
+ dataset:
+ name: SimpleDataSet
+ data_dir: ./train_data/ic15_data/
+ label_file_list: ["./train_data/ic15_data/1.txt"]
+ transforms:
+ - DecodeImage: # load image
+ img_mode: BGR
+ channel_first: False
+ - AttnLabelEncode: # Class handling label
+ - RecResizeImg:
+ image_shape: [3, 32, 100]
+ - KeepKeys:
+ keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order
+ loader:
+ shuffle: False
+ drop_last: False
+ batch_size_per_card: 2
+ num_workers: 8
diff --git a/ppocr/data/imaug/label_ops.py b/ppocr/data/imaug/label_ops.py
index e25cce79b553f127afc0167f18b6f663ceb617d7..0e1d4939d607df8dba27ab446134cd68b8770393 100644
--- a/ppocr/data/imaug/label_ops.py
+++ b/ppocr/data/imaug/label_ops.py
@@ -104,6 +104,7 @@ class BaseRecLabelEncode(object):
self.max_text_len = max_text_length
self.beg_str = "sos"
self.end_str = "eos"
+ self.unknown = "UNKNOWN"
if character_type == "en":
self.character_str = "0123456789abcdefghijklmnopqrstuvwxyz"
dict_character = list(self.character_str)
@@ -275,7 +276,9 @@ class AttnLabelEncode(BaseRecLabelEncode):
def add_special_char(self, dict_character):
self.beg_str = "sos"
self.end_str = "eos"
- dict_character = [self.beg_str] + dict_character + [self.end_str]
+ self.unknown = "UNKNOWN"
+ dict_character = [self.beg_str] + dict_character + \
+ [self.end_str] + [self.unknown]
return dict_character
def __call__(self, data):
@@ -352,19 +356,22 @@ class SRNLabelEncode(BaseRecLabelEncode):
% beg_or_end
return idx
+
class TableLabelEncode(object):
""" Convert between text-label and text-index """
- def __init__(self,
- max_text_length,
- max_elem_length,
- max_cell_num,
- character_dict_path,
- span_weight = 1.0,
- **kwargs):
+
+ def __init__(self,
+ max_text_length,
+ max_elem_length,
+ max_cell_num,
+ character_dict_path,
+ span_weight=1.0,
+ **kwargs):
self.max_text_length = max_text_length
self.max_elem_length = max_elem_length
self.max_cell_num = max_cell_num
- list_character, list_elem = self.load_char_elem_dict(character_dict_path)
+ list_character, list_elem = self.load_char_elem_dict(
+ character_dict_path)
list_character = self.add_special_char(list_character)
list_elem = self.add_special_char(list_elem)
self.dict_character = {}
@@ -374,7 +381,7 @@ class TableLabelEncode(object):
for i, elem in enumerate(list_elem):
self.dict_elem[elem] = i
self.span_weight = span_weight
-
+
def load_char_elem_dict(self, character_dict_path):
list_character = []
list_elem = []
@@ -383,27 +390,27 @@ class TableLabelEncode(object):
substr = lines[0].decode('utf-8').strip("\n").split("\t")
character_num = int(substr[0])
elem_num = int(substr[1])
- for cno in range(1, 1+character_num):
+ for cno in range(1, 1 + character_num):
character = lines[cno].decode('utf-8').strip("\n")
list_character.append(character)
- for eno in range(1+character_num, 1+character_num+elem_num):
+ for eno in range(1 + character_num, 1 + character_num + elem_num):
elem = lines[eno].decode('utf-8').strip("\n")
list_elem.append(elem)
return list_character, list_elem
-
+
def add_special_char(self, list_character):
self.beg_str = "sos"
self.end_str = "eos"
list_character = [self.beg_str] + list_character + [self.end_str]
return list_character
-
+
def get_span_idx_list(self):
span_idx_list = []
for elem in self.dict_elem:
if 'span' in elem:
span_idx_list.append(self.dict_elem[elem])
return span_idx_list
-
+
def __call__(self, data):
cells = data['cells']
structure = data['structure']['tokens']
@@ -412,18 +419,22 @@ class TableLabelEncode(object):
return None
elem_num = len(structure)
structure = [0] + structure + [len(self.dict_elem) - 1]
- structure = structure + [0] * (self.max_elem_length + 2 - len(structure))
+ structure = structure + [0] * (self.max_elem_length + 2 - len(structure)
+ )
structure = np.array(structure)
data['structure'] = structure
 elem_char_idx1 = self.dict_elem['<td>']
 elem_char_idx2 = self.dict_elem['<td']
 span_idx_list = self.get_span_idx_list()
 td_idx_list = np.logical_or(structure == elem_char_idx1,
 structure == elem_char_idx2)
 td_idx_list = np.where(td_idx_list)[0]
 structure_mask = np.ones((self.max_elem_length + 2, 1), dtype=np.float32)
 bbox_list = np.zeros((self.max_elem_length + 2, 4), dtype=np.float32)
 bbox_list_mask = np.zeros((self.max_elem_length + 2, 1), dtype=np.float32)
 img_height, img_width, img_ch = data['image'].shape
 if len(span_idx_list) > 0:
 span_weight = len(td_idx_list) * 1.0 / len(span_idx_list)
@@ -450,9 +461,11 @@ class TableLabelEncode(object):
char_end_idx = self.get_beg_end_flag_idx('end', 'char')
elem_beg_idx = self.get_beg_end_flag_idx('beg', 'elem')
elem_end_idx = self.get_beg_end_flag_idx('end', 'elem')
- data['sp_tokens'] = np.array([char_beg_idx, char_end_idx, elem_beg_idx,
- elem_end_idx, elem_char_idx1, elem_char_idx2, self.max_text_length,
- self.max_elem_length, self.max_cell_num, elem_num])
+ data['sp_tokens'] = np.array([
+ char_beg_idx, char_end_idx, elem_beg_idx, elem_end_idx,
+ elem_char_idx1, elem_char_idx2, self.max_text_length,
+ self.max_elem_length, self.max_cell_num, elem_num
+ ])
return data
def encode(self, text, char_or_elem):
@@ -504,9 +517,8 @@ class TableLabelEncode(object):
idx = np.array(self.dict_elem[self.end_str])
else:
assert False, "Unsupport type %s in get_beg_end_flag_idx of elem" \
- % beg_or_end
+ % beg_or_end
else:
assert False, "Unsupport type %s in char_or_elem" \
- % char_or_elem
+ % char_or_elem
return idx
-
\ No newline at end of file
diff --git a/ppocr/losses/__init__.py b/ppocr/losses/__init__.py
index 025ae7ca5cc604eea59423ca7f523c37c1492e35..2a6737745457802b7cdcbb66ebf6e199404b0491 100755
--- a/ppocr/losses/__init__.py
+++ b/ppocr/losses/__init__.py
@@ -41,10 +41,13 @@ from .combined_loss import CombinedLoss
# table loss
from .table_att_loss import TableAttentionLoss
+from .rec_aster_loss import AsterLoss
+
+
def build_loss(config):
support_dict = [
'DBLoss', 'EASTLoss', 'SASTLoss', 'CTCLoss', 'ClsLoss', 'AttentionLoss',
- 'SRNLoss', 'PGLoss', 'CombinedLoss', 'TableAttentionLoss'
+ 'SRNLoss', 'PGLoss', 'CombinedLoss', 'TableAttentionLoss', 'AsterLoss'
]
config = copy.deepcopy(config)
module_name = config.pop('name')
diff --git a/ppocr/losses/rec_aster_loss.py b/ppocr/losses/rec_aster_loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..858fadc021a7e6edbdea5e79f25e6c45880688f8
--- /dev/null
+++ b/ppocr/losses/rec_aster_loss.py
@@ -0,0 +1,79 @@
+# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import paddle
+from paddle import nn
+
+
+class AsterLoss(nn.Layer):
+ def __init__(self,
+ weight=None,
+ size_average=True,
+ ignore_index=-100,
+ sequence_normalize=False,
+ sample_normalize=True,
+ **kwargs):
+ super(AsterLoss, self).__init__()
+ self.weight = weight
+ self.size_average = size_average
+ self.ignore_index = ignore_index
+ self.sequence_normalize = sequence_normalize
+ self.sample_normalize = sample_normalize
+ self.loss_func = paddle.nn.CosineSimilarity()
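+ # CosineSimilarity is kept for the semantic-embedding loss sketched below (currently disabled)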
+
+ def forward(self, predicts, batch):
+ targets = batch[1].astype("int64")
+ label_lengths = batch[2].astype('int64')
+ # sem_target = batch[3].astype('float32')
+ embedding_vectors = predicts['embedding_vectors']
+ rec_pred = predicts['rec_pred']
+
+ # semantic loss (disabled until the FastText word embeddings are wired in)
+ # targets = fasttext[targets]
+ # sem_loss = 1 - self.loss_func(embedding_vectors, targets)
+
+ # rec loss
+ batch_size, num_steps, num_classes = rec_pred.shape[0], rec_pred.shape[
+ 1], rec_pred.shape[2]
+ assert len(targets.shape) == len(list(rec_pred.shape)) - 1, \
+ "targets should be [N, num_steps] to match rec_pred [N, num_steps, num_classes]"
+
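+ # build a [B, T] mask so timesteps beyond each label's length do not contribute to the loss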
+ mask = paddle.zeros([batch_size, num_steps])
+ for i in range(batch_size):
+ mask[i, :label_lengths[i]] = 1
+ mask = paddle.cast(mask, "float32")
+ max_length = max(label_lengths)
+ assert max_length == rec_pred.shape[1]
+ targets = targets[:, :max_length]
+ mask = mask[:, :max_length]
+ rec_pred = paddle.reshape(rec_pred, [-1, rec_pred.shape[-1]])
+ input = nn.functional.log_softmax(rec_pred, axis=1)
+ targets = paddle.reshape(targets, [-1, 1])
+ mask = paddle.reshape(mask, [-1, 1])
+ # print("input:", input)
+ output = -paddle.gather(input, index=targets, axis=1) * mask
+ output = paddle.sum(output)
+ if self.sequence_normalize:
+ output = output / paddle.sum(mask)
+ if self.sample_normalize:
+ output = output / batch_size
+ loss = output
+ return {'loss': loss} # , 'sem_loss':sem_loss}
diff --git a/ppocr/modeling/backbones/__init__.py b/ppocr/modeling/backbones/__init__.py
index f4fe8c76be0835f55f402f35ad6a91a5ca116d88..e0bc45b4762df1f9dfeebec579eb3bca13184a5c 100755
--- a/ppocr/modeling/backbones/__init__.py
+++ b/ppocr/modeling/backbones/__init__.py
@@ -26,8 +26,10 @@ def build_backbone(config, model_type):
from .rec_resnet_vd import ResNet
from .rec_resnet_fpn import ResNetFPN
from .rec_mv1_enhance import MobileNetV1Enhance
+ from .rec_resnet_aster import ResNet_ASTER
support_dict = [
- "MobileNetV1Enhance", "MobileNetV3", "ResNet", "ResNetFPN"
+ "MobileNetV1Enhance", "MobileNetV3", "ResNet", "ResNetFPN",
+ "ResNet_ASTER"
]
elif model_type == "e2e":
from .e2e_resnet_vd_pg import ResNet
diff --git a/ppocr/modeling/backbones/levit.py b/ppocr/modeling/backbones/levit.py
new file mode 100644
index 0000000000000000000000000000000000000000..8b04e9def992c5c72f0fc5cae4004576e325ea5c
--- /dev/null
+++ b/ppocr/modeling/backbones/levit.py
@@ -0,0 +1,707 @@
+# Copyright (c) 2015-present, Facebook, Inc.
+# All rights reserved.
+
+# Modified from
+# https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
+# Copyright 2020 Ross Wightman, Apache-2.0 License
+
+import paddle
+import itertools
+#import utils
+import math
+import warnings
+import paddle.nn.functional as F
+from paddle.nn.initializer import TruncatedNormal, Constant
+
+#from timm.models.vision_transformer import trunc_normal_
+#from timm.models.registry import register_model
+
+specification = {
+ 'LeViT_128S': {
+ 'C': '128_256_384',
+ 'D': 16,
+ 'N': '4_6_8',
+ 'X': '2_3_4',
+ 'drop_path': 0,
+ 'weights':
+ 'https://dl.fbaipublicfiles.com/LeViT/LeViT-128S-96703c44.pth'
+ },
+ 'LeViT_128': {
+ 'C': '128_256_384',
+ 'D': 16,
+ 'N': '4_8_12',
+ 'X': '4_4_4',
+ 'drop_path': 0,
+ 'weights': 'https://dl.fbaipublicfiles.com/LeViT/LeViT-128-b88c2750.pth'
+ },
+ 'LeViT_192': {
+ 'C': '192_288_384',
+ 'D': 32,
+ 'N': '3_5_6',
+ 'X': '4_4_4',
+ 'drop_path': 0,
+ 'weights': 'https://dl.fbaipublicfiles.com/LeViT/LeViT-192-92712e41.pth'
+ },
+ 'LeViT_256': {
+ 'C': '256_384_512',
+ 'D': 32,
+ 'N': '4_6_8',
+ 'X': '4_4_4',
+ 'drop_path': 0,
+ 'weights': 'https://dl.fbaipublicfiles.com/LeViT/LeViT-256-13b5763e.pth'
+ },
+ 'LeViT_384': {
+ 'C': '384_512_768',
+ 'D': 32,
+ 'N': '6_9_12',
+ 'X': '4_4_4',
+ 'drop_path': 0.1,
+ 'weights': 'https://dl.fbaipublicfiles.com/LeViT/LeViT-384-9bdaf2e2.pth'
+ },
+}
+
+__all__ = list(specification.keys())
+
+trunc_normal_ = TruncatedNormal(std=.02)
+zeros_ = Constant(value=0.)
+ones_ = Constant(value=1.)
+
+
+#@register_model
+def LeViT_128S(class_dim=1000, distillation=True, pretrained=False, fuse=False):
+ return model_factory(
+ **specification['LeViT_128S'],
+ class_dim=class_dim,
+ distillation=distillation,
+ pretrained=pretrained,
+ fuse=fuse)
+
+
+#@register_model
+def LeViT_128(class_dim=1000, distillation=True, pretrained=False, fuse=False):
+ return model_factory(
+ **specification['LeViT_128'],
+ class_dim=class_dim,
+ distillation=distillation,
+ pretrained=pretrained,
+ fuse=fuse)
+
+
+#@register_model
+def LeViT_192(class_dim=1000, distillation=True, pretrained=False, fuse=False):
+ return model_factory(
+ **specification['LeViT_192'],
+ class_dim=class_dim,
+ distillation=distillation,
+ pretrained=pretrained,
+ fuse=fuse)
+
+
+#@register_model
+def LeViT_256(class_dim=1000, distillation=False, pretrained=False, fuse=False):
+ return model_factory(
+ **specification['LeViT_256'],
+ class_dim=class_dim,
+ distillation=distillation,
+ pretrained=pretrained,
+ fuse=fuse)
+
+
+#@register_model
+def LeViT_384(class_dim=1000, distillation=True, pretrained=False, fuse=False):
+ return model_factory(
+ **specification['LeViT_384'],
+ class_dim=class_dim,
+ distillation=distillation,
+ pretrained=pretrained,
+ fuse=fuse)
+
+
+FLOPS_COUNTER = 0
+
+
+class Conv2d_BN(paddle.nn.Sequential):
+ def __init__(self,
+ a,
+ b,
+ ks=1,
+ stride=1,
+ pad=0,
+ dilation=1,
+ groups=1,
+ bn_weight_init=1,
+ resolution=-10000):
+ super().__init__()
+ self.add_sublayer(
+ 'c',
+ paddle.nn.Conv2D(
+ a, b, ks, stride, pad, dilation, groups, bias_attr=False))
+ bn = paddle.nn.BatchNorm2D(b)
+ ones_(bn.weight)
+ zeros_(bn.bias)
+ self.add_sublayer('bn', bn)
+
+ global FLOPS_COUNTER
+ output_points = (
+ (resolution + 2 * pad - dilation * (ks - 1) - 1) // stride + 1)**2
+ FLOPS_COUNTER += a * b * output_points * (ks**2)
+
+ @paddle.no_grad()
+ def fuse(self):
+ c, bn = self._sub_layers.values()
+ w = bn.weight / (bn._variance + bn._epsilon)**0.5
+ w = c.weight * w[:, None, None, None]
+ b = bn.bias - bn._mean * bn.weight / \
+ (bn._variance + bn._epsilon)**0.5
+ m = paddle.nn.Conv2D(
+ w.shape[1] * self.c._groups,
+ w.shape[0],
+ w.shape[2:],
+ stride=self.c._stride,
+ padding=self.c._padding,
+ dilation=self.c._dilation,
+ groups=self.c._groups)
+ m.weight.set_value(w)
+ m.bias.set_value(b)
+ return m
+
+
+class Linear_BN(paddle.nn.Sequential):
+ def __init__(self, a, b, bn_weight_init=1, resolution=-100000):
+ super().__init__()
+ self.add_sublayer('c', paddle.nn.Linear(a, b, bias_attr=False))
+ bn = paddle.nn.BatchNorm1D(b)
+ ones_(bn.weight)
+ zeros_(bn.bias)
+ self.add_sublayer('bn', bn)
+
+ global FLOPS_COUNTER
+ output_points = resolution**2
+ FLOPS_COUNTER += a * b * output_points
+
+ @paddle.no_grad()
+ def fuse(self):
+ l, bn = self._sub_layers.values()
+ scale = bn.weight / (bn._variance + bn._epsilon)**0.5
+ # paddle.nn.Linear stores weight as [in, out], so scale the output axis
+ w = l.weight * scale[None, :]
+ b = bn.bias - bn._mean * scale
+ m = paddle.nn.Linear(w.shape[0], w.shape[1])
+ m.weight.set_value(w)
+ m.bias.set_value(b)
+ return m
+
+ def forward(self, x):
+ l, bn = self._sub_layers.values()
+ x = l(x)
+ return paddle.reshape(bn(x.flatten(0, 1)), x.shape)
+
+
+class BN_Linear(paddle.nn.Sequential):
+ def __init__(self, a, b, bias=True, std=0.02):
+ super().__init__()
+ self.add_sublayer('bn', paddle.nn.BatchNorm1D(a))
+ l = paddle.nn.Linear(a, b, bias_attr=bias)
+ trunc_normal_(l.weight)
+ if bias:
+ zeros_(l.bias)
+ self.add_sublayer('l', l)
+ global FLOPS_COUNTER
+ FLOPS_COUNTER += a * b
+
+ @paddle.no_grad()
+ def fuse(self):
+ bn, l = self._sub_layers.values()
+ scale = bn.weight / (bn._variance + bn._epsilon)**0.5
+ b = bn.bias - bn._mean * scale
+ # paddle.nn.Linear stores weight as [in, out]
+ w = l.weight * scale[:, None]
+ if l.bias is None:
+ b = b @ l.weight
+ else:
+ b = paddle.reshape(b[None, :] @ l.weight, [-1]) + l.bias
+ m = paddle.nn.Linear(w.shape[0], w.shape[1])
+ m.weight.set_value(w)
+ m.bias.set_value(b)
+ return m
+
+
+def b16(n, activation, resolution=224):
+ return paddle.nn.Sequential(
+ Conv2d_BN(
+ 3, n // 8, 3, 2, 1, resolution=resolution),
+ activation(),
+ Conv2d_BN(
+ n // 8, n // 4, 3, 2, 1, resolution=resolution // 2),
+ activation(),
+ Conv2d_BN(
+ n // 4, n // 2, 3, 2, 1, resolution=resolution // 4),
+ activation(),
+ Conv2d_BN(
+ n // 2, n, 3, 2, 1, resolution=resolution // 8))
+
+
+class Residual(paddle.nn.Layer):
+ def __init__(self, m, drop):
+ super().__init__()
+ self.m = m
+ self.drop = drop
+
+ def forward(self, x):
+ if self.training and self.drop > 0:
+ # stochastic depth: randomly drop the residual branch per sample
+ keep = (paddle.rand([x.shape[0], 1, 1]) >= self.drop).astype(x.dtype)
+ return x + self.m(x) * (keep / (1 - self.drop)).detach()
+ else:
+ return x + self.m(x)
+
+
+class Attention(paddle.nn.Layer):
+ def __init__(self,
+ dim,
+ key_dim,
+ num_heads=8,
+ attn_ratio=4,
+ activation=None,
+ resolution=14):
+ super().__init__()
+ self.num_heads = num_heads
+ self.scale = key_dim**-0.5
+ self.key_dim = key_dim
+ self.nh_kd = nh_kd = key_dim * num_heads
+ self.d = int(attn_ratio * key_dim)
+ self.dh = int(attn_ratio * key_dim) * num_heads
+ self.attn_ratio = attn_ratio
+ self.h = self.dh + nh_kd * 2
+ self.qkv = Linear_BN(dim, self.h, resolution=resolution)
+ self.proj = paddle.nn.Sequential(
+ activation(),
+ Linear_BN(
+ self.dh, dim, bn_weight_init=0, resolution=resolution))
+ points = list(itertools.product(range(resolution), range(resolution)))
+ N = len(points)
+ attention_offsets = {}
+ idxs = []
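+ # every unique absolute (dy, dx) offset between grid positions shares one learned bias per head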
+ for p1 in points:
+ for p2 in points:
+ offset = (abs(p1[0] - p2[0]), abs(p1[1] - p2[1]))
+ if offset not in attention_offsets:
+ attention_offsets[offset] = len(attention_offsets)
+ idxs.append(attention_offsets[offset])
+ self.attention_biases = self.create_parameter(
+ shape=(num_heads, len(attention_offsets)),
+ default_initializer=zeros_)
+ tensor_idxs = paddle.to_tensor(idxs, dtype='int64')
+ self.register_buffer('attention_bias_idxs',
+ paddle.reshape(tensor_idxs, [N, N]))
+
+ global FLOPS_COUNTER
+ #queries * keys
+ FLOPS_COUNTER += num_heads * (resolution**4) * key_dim
+ # softmax
+ FLOPS_COUNTER += num_heads * (resolution**4)
+ #attention * v
+ FLOPS_COUNTER += num_heads * self.d * (resolution**4)
+
+ @paddle.no_grad()
+ def train(self, mode=True):
+ if mode:
+ super().train()
+ else:
+ super().eval()
+ if mode and hasattr(self, 'ab'):
+ del self.ab
+ else:
+ gather_list = []
+ attention_bias_t = paddle.transpose(self.attention_biases, (1, 0))
+ for idx in self.attention_bias_idxs:
+ gather = paddle.gather(attention_bias_t, idx)
+ gather_list.append(gather)
+ attention_biases = paddle.transpose(
+ paddle.concat(gather_list), (1, 0)).reshape(
+ (0, self.attention_bias_idxs.shape[0],
+ self.attention_bias_idxs.shape[1]))
+ self.ab = attention_biases
+ #self.ab = self.attention_biases[:, self.attention_bias_idxs]
+
+ def forward(self, x): # x (B,N,C)
+ self.training = True
+ B, N, C = x.shape
+ qkv = self.qkv(x)
+ qkv = paddle.reshape(qkv,
+ [B, N, self.num_heads, self.h // self.num_heads])
+ q, k, v = paddle.split(
+ qkv, [self.key_dim, self.key_dim, self.d], axis=3)
+ q = paddle.transpose(q, perm=[0, 2, 1, 3])
+ k = paddle.transpose(k, perm=[0, 2, 1, 3])
+ v = paddle.transpose(v, perm=[0, 2, 1, 3])
+ k_transpose = paddle.transpose(k, perm=[0, 1, 3, 2])
+
+ if self.training:
+ gather_list = []
+ attention_bias_t = paddle.transpose(self.attention_biases, (1, 0))
+ for idx in self.attention_bias_idxs:
+ gather = paddle.gather(attention_bias_t, idx)
+ gather_list.append(gather)
+ attention_biases = paddle.transpose(
+ paddle.concat(gather_list), (1, 0)).reshape(
+ (0, self.attention_bias_idxs.shape[0],
+ self.attention_bias_idxs.shape[1]))
+ else:
+ attention_biases = self.ab
+ #np_ = paddle.to_tensor(self.attention_biases.numpy()[:, self.attention_bias_idxs.numpy()])
+ #print(self.attention_bias_idxs.shape)
+ #print(attention_biases.shape)
+ #print(np_.shape)
+ #print(np_.equal(attention_biases))
+ #exit()
+
+ attn = ((q @k_transpose) * self.scale + attention_biases)
+ attn = F.softmax(attn)
+ x = paddle.transpose(attn @v, perm=[0, 2, 1, 3])
+ x = paddle.reshape(x, [B, N, self.dh])
+ x = self.proj(x)
+ return x
+
+
+class Subsample(paddle.nn.Layer):
+ def __init__(self, stride, resolution):
+ super().__init__()
+ self.stride = stride
+ self.resolution = resolution
+
+ def forward(self, x):
+ B, N, C = x.shape
+ x = paddle.reshape(x, [B, self.resolution, self.resolution,
+ C])[:, ::self.stride, ::self.stride]
+ x = paddle.reshape(x, [B, -1, C])
+ return x
+
+
+class AttentionSubsample(paddle.nn.Layer):
+ def __init__(self,
+ in_dim,
+ out_dim,
+ key_dim,
+ num_heads=8,
+ attn_ratio=2,
+ activation=None,
+ stride=2,
+ resolution=14,
+ resolution_=7):
+ super().__init__()
+ self.num_heads = num_heads
+ self.scale = key_dim**-0.5
+ self.key_dim = key_dim
+ self.nh_kd = nh_kd = key_dim * num_heads
+ self.d = int(attn_ratio * key_dim)
+ self.dh = int(attn_ratio * key_dim) * self.num_heads
+ self.attn_ratio = attn_ratio
+ self.resolution_ = resolution_
+ self.resolution_2 = resolution_**2
+ self.training = True
+ h = self.dh + nh_kd
+ self.kv = Linear_BN(in_dim, h, resolution=resolution)
+
+ self.q = paddle.nn.Sequential(
+ Subsample(stride, resolution),
+ Linear_BN(
+ in_dim, nh_kd, resolution=resolution_))
+ self.proj = paddle.nn.Sequential(
+ activation(), Linear_BN(
+ self.dh, out_dim, resolution=resolution_))
+
+ self.stride = stride
+ self.resolution = resolution
+ points = list(itertools.product(range(resolution), range(resolution)))
+ points_ = list(
+ itertools.product(range(resolution_), range(resolution_)))
+
+ N = len(points)
+ N_ = len(points_)
+ attention_offsets = {}
+ idxs = []
+ i = 0
+ j = 0
+ for p1 in points_:
+ i += 1
+ for p2 in points:
+ j += 1
+ size = 1
+ offset = (abs(p1[0] * stride - p2[0] + (size - 1) / 2),
+ abs(p1[1] * stride - p2[1] + (size - 1) / 2))
+ if offset not in attention_offsets:
+ attention_offsets[offset] = len(attention_offsets)
+ idxs.append(attention_offsets[offset])
+ self.attention_biases = self.create_parameter(
+ shape=(num_heads, len(attention_offsets)),
+ default_initializer=zeros_)
+
+ tensor_idxs_ = paddle.to_tensor(idxs, dtype='int64')
+ self.register_buffer('attention_bias_idxs',
+ paddle.reshape(tensor_idxs_, [N_, N]))
+
+ global FLOPS_COUNTER
+ #queries * keys
+ FLOPS_COUNTER += num_heads * \
+ (resolution**2) * (resolution_**2) * key_dim
+ # softmax
+ FLOPS_COUNTER += num_heads * (resolution**2) * (resolution_**2)
+ #attention * v
+ FLOPS_COUNTER += num_heads * \
+ (resolution**2) * (resolution_**2) * self.d
+
+ @paddle.no_grad()
+ def train(self, mode=True):
+ if mode:
+ super().train()
+ else:
+ super().eval()
+ if mode and hasattr(self, 'ab'):
+ del self.ab
+ else:
+ gather_list = []
+ attention_bias_t = paddle.transpose(self.attention_biases, (1, 0))
+ for idx in self.attention_bias_idxs:
+ gather = paddle.gather(attention_bias_t, idx)
+ gather_list.append(gather)
+ attention_biases = paddle.transpose(
+ paddle.concat(gather_list), (1, 0)).reshape(
+ (0, self.attention_bias_idxs.shape[0],
+ self.attention_bias_idxs.shape[1]))
+ self.ab = attention_biases
+ #self.ab = self.attention_biases[:, self.attention_bias_idxs]
+
+ def forward(self, x):
+ self.training = True
+ B, N, C = x.shape
+ kv = self.kv(x)
+ kv = paddle.reshape(kv, [B, N, self.num_heads, -1])
+ k, v = paddle.split(kv, [self.key_dim, self.d], axis=3)
+ k = paddle.transpose(k, perm=[0, 2, 1, 3]) # BHNC
+ v = paddle.transpose(v, perm=[0, 2, 1, 3])
+ q = paddle.reshape(
+ self.q(x), [B, self.resolution_2, self.num_heads, self.key_dim])
+ q = paddle.transpose(q, perm=[0, 2, 1, 3])
+
+ if self.training:
+ gather_list = []
+ attention_bias_t = paddle.transpose(self.attention_biases, (1, 0))
+ for idx in self.attention_bias_idxs:
+ gather = paddle.gather(attention_bias_t, idx)
+ gather_list.append(gather)
+ attention_biases = paddle.transpose(
+ paddle.concat(gather_list), (1, 0)).reshape(
+ (0, self.attention_bias_idxs.shape[0],
+ self.attention_bias_idxs.shape[1]))
+ else:
+ attention_biases = self.ab
+
+ attn = (q @paddle.transpose(
+ k, perm=[0, 1, 3, 2])) * self.scale + attention_biases
+ attn = F.softmax(attn)
+
+ x = paddle.reshape(
+ paddle.transpose(
+ (attn @v), perm=[0, 2, 1, 3]), [B, -1, self.dh])
+ x = self.proj(x)
+ return x
+
+
+class LeViT(paddle.nn.Layer):
+ """ Vision Transformer with support for patch or hybrid CNN input stage
+ """
+
+ def __init__(self,
+ img_size=224,
+ patch_size=16,
+ in_chans=3,
+ class_dim=1000,
+ embed_dim=[192],
+ key_dim=[64],
+ depth=[12],
+ num_heads=[3],
+ attn_ratio=[2],
+ mlp_ratio=[2],
+ hybrid_backbone=None,
+ down_ops=[],
+ attention_activation=paddle.nn.Hardswish,
+ mlp_activation=paddle.nn.Hardswish,
+ distillation=True,
+ drop_path=0):
+ super().__init__()
+ global FLOPS_COUNTER
+
+ self.class_dim = class_dim
+ self.num_features = embed_dim[-1]
+ self.embed_dim = embed_dim
+ self.distillation = distillation
+
+ self.patch_embed = hybrid_backbone
+
+ self.blocks = []
+ down_ops.append([''])
+ resolution = img_size // patch_size
+ for i, (ed, kd, dpth, nh, ar, mr, do) in enumerate(
+ zip(embed_dim, key_dim, depth, num_heads, attn_ratio, mlp_ratio,
+ down_ops)):
+ for _ in range(dpth):
+ self.blocks.append(
+ Residual(
+ Attention(
+ ed,
+ kd,
+ nh,
+ attn_ratio=ar,
+ activation=attention_activation,
+ resolution=resolution, ),
+ drop_path))
+ if mr > 0:
+ h = int(ed * mr)
+ self.blocks.append(
+ Residual(
+ paddle.nn.Sequential(
+ Linear_BN(
+ ed, h, resolution=resolution),
+ mlp_activation(),
+ Linear_BN(
+ h,
+ ed,
+ bn_weight_init=0,
+ resolution=resolution), ),
+ drop_path))
+ if do[0] == 'Subsample':
+ #('Subsample',key_dim, num_heads, attn_ratio, mlp_ratio, stride)
+ resolution_ = (resolution - 1) // do[5] + 1
+ self.blocks.append(
+ AttentionSubsample(
+ *embed_dim[i:i + 2],
+ key_dim=do[1],
+ num_heads=do[2],
+ attn_ratio=do[3],
+ activation=attention_activation,
+ stride=do[5],
+ resolution=resolution,
+ resolution_=resolution_))
+ resolution = resolution_
+ if do[4] > 0: # mlp_ratio
+ h = int(embed_dim[i + 1] * do[4])
+ self.blocks.append(
+ Residual(
+ paddle.nn.Sequential(
+ Linear_BN(
+ embed_dim[i + 1], h, resolution=resolution),
+ mlp_activation(),
+ Linear_BN(
+ h,
+ embed_dim[i + 1],
+ bn_weight_init=0,
+ resolution=resolution), ),
+ drop_path))
+ self.blocks = paddle.nn.Sequential(*self.blocks)
+
+ # Classifier head
+ self.head = BN_Linear(
+ embed_dim[-1], class_dim) if class_dim > 0 else paddle.nn.Identity()
+ if distillation:
+ self.head_dist = BN_Linear(
+ embed_dim[-1],
+ class_dim) if class_dim > 0 else paddle.nn.Identity()
+
+ self.FLOPS = FLOPS_COUNTER
+ FLOPS_COUNTER = 0
+
+ def no_weight_decay(self):
+ return {x for x in self.state_dict().keys() if 'attention_biases' in x}
+
+ def forward(self, x):
+ x = self.patch_embed(x)
+ x = x.flatten(2)
+ x = paddle.transpose(x, perm=[0, 2, 1])
+ x = self.blocks(x)
+ x = x.mean(1)
+ if self.distillation:
+ x = self.head(x), self.head_dist(x)
+ if not self.training:
+ x = (x[0] + x[1]) / 2
+ else:
+ x = self.head(x)
+ return x
+
+
+def model_factory(C, D, X, N, drop_path, weights, class_dim, distillation,
+ pretrained, fuse):
+ embed_dim = [int(x) for x in C.split('_')]
+ num_heads = [int(x) for x in N.split('_')]
+ depth = [int(x) for x in X.split('_')]
+ act = paddle.nn.Hardswish
+ model = LeViT(
+ patch_size=16,
+ embed_dim=embed_dim,
+ num_heads=num_heads,
+ key_dim=[D] * 3,
+ depth=depth,
+ attn_ratio=[2, 2, 2],
+ mlp_ratio=[2, 2, 2],
+ down_ops=[
+ #('Subsample',key_dim, num_heads, attn_ratio, mlp_ratio, stride)
+ ['Subsample', D, embed_dim[0] // D, 4, 2, 2],
+ ['Subsample', D, embed_dim[1] // D, 4, 2, 2],
+ ],
+ attention_activation=act,
+ mlp_activation=act,
+ hybrid_backbone=b16(embed_dim[0], activation=act),
+ class_dim=class_dim,
+ drop_path=drop_path,
+ distillation=distillation)
+ # if pretrained:
+ # checkpoint = torch.hub.load_state_dict_from_url(
+ # weights, map_location='cpu')
+ # model.load_state_dict(checkpoint['model'])
+ if fuse:
+ # replace_batchnorm comes from the original LeViT repo's utils and has not
+ # been ported to paddle; fail loudly instead of raising a NameError
+ raise NotImplementedError("fuse=True is not supported in this port")
+
+ return model
+
+
+if __name__ == '__main__':
+ '''
+ import torch
+ checkpoint = torch.load('../LeViT/pretrained256.pth')
+ torch_dict = checkpoint['net']
+ paddle_dict = {}
+ fc_names = ["c.weight", "l.weight", "qkv.weight", "fc1.weight", "fc2.weight", "downsample.reduction.weight", "head.weight", "attn.proj.weight"]
+ rename_dict = {"running_mean": "_mean", "running_var": "_variance"}
+ range_tuple = (0, 502)
+ idx = 0
+ for key in torch_dict:
+ idx += 1
+ weight = torch_dict[key].cpu().numpy()
+ flag = [i in key for i in fc_names]
+ if any(flag):
+ if "emb" not in key:
+ print("weight {} need to be trans".format(key))
+ weight = weight.transpose()
+ key = key.replace("running_mean", "_mean")
+ key = key.replace("running_var", "_variance")
+ paddle_dict[key]=weight
+ '''
+ import numpy as np
+ net = globals()['LeViT_256'](fuse=False,
+ pretrained=False,
+ distillation=False)
+ load_layer_state_dict = paddle.load(
+ "./LeViT_256_official_nodistillation_paddle.pdparams")
+ #net.set_state_dict(paddle_dict)
+ net.set_state_dict(load_layer_state_dict)
+ net.eval()
+ #paddle.save(net.state_dict(), "./LeViT_256_official_paddle.pdparams")
+ #model = paddle.jit.to_static(net,input_spec=[paddle.static.InputSpec(shape=[None, 3, 224, 224], dtype='float32')])
+ #paddle.jit.save(model, "./LeViT_256_official_inference/inference")
+ #exit()
+ np.random.seed(123)
+ img = np.random.rand(1, 3, 224, 224).astype('float32')
+ img = paddle.to_tensor(img)
+ outputs = net(img).numpy()
+ print(outputs[0][:10])
+ #print(outputs.shape)
diff --git a/ppocr/modeling/backbones/rec_resnet_aster.py b/ppocr/modeling/backbones/rec_resnet_aster.py
new file mode 100644
index 0000000000000000000000000000000000000000..5bb58035753e48f1bcf09a9028049e24775df210
--- /dev/null
+++ b/ppocr/modeling/backbones/rec_resnet_aster.py
@@ -0,0 +1,147 @@
+# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import paddle
+import paddle.nn as nn
+
+import sys
+import math
+
+
+def conv3x3(in_planes, out_planes, stride=1):
+ """3x3 convolution with padding"""
+ return nn.Conv2D(
+ in_planes,
+ out_planes,
+ kernel_size=3,
+ stride=stride,
+ padding=1,
+ bias_attr=False)
+
+
+def conv1x1(in_planes, out_planes, stride=1):
+ """1x1 convolution"""
+ return nn.Conv2D(
+ in_planes, out_planes, kernel_size=1, stride=stride, bias_attr=False)
+
+
+def get_sinusoid_encoding(n_position, feat_dim, wave_length=10000):
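+ # sinusoidal positional encoding: sin on even feature dims, cos on odd dims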
+ # [n_position]
+ positions = paddle.arange(0, n_position)
+ # [feat_dim]
+ dim_range = paddle.arange(0, feat_dim)
+ # paddle.pow needs a Tensor base, so lift the scalar wavelength to a Tensor
+ dim_range = paddle.pow(
+ paddle.to_tensor(wave_length, dtype="float32"),
+ paddle.cast(2 * (dim_range // 2), "float32") / feat_dim)
+ # [n_position, feat_dim]
+ angles = paddle.unsqueeze(
+ positions, axis=1) / paddle.unsqueeze(
+ dim_range, axis=0)
+ angles = paddle.cast(angles, "float32")
+ angles[:, 0::2] = paddle.sin(angles[:, 0::2])
+ angles[:, 1::2] = paddle.cos(angles[:, 1::2])
+ return angles
+
+
+class AsterBlock(nn.Layer):
+ def __init__(self, inplanes, planes, stride=1, downsample=None):
+ super(AsterBlock, self).__init__()
+ self.conv1 = conv1x1(inplanes, planes, stride)
+ self.bn1 = nn.BatchNorm2D(planes)
+ self.relu = nn.ReLU()
+ self.conv2 = conv3x3(planes, planes)
+ self.bn2 = nn.BatchNorm2D(planes)
+ self.downsample = downsample
+ self.stride = stride
+
+ def forward(self, x):
+ residual = x
+ out = self.conv1(x)
+ out = self.bn1(out)
+ out = self.relu(out)
+ out = self.conv2(out)
+ out = self.bn2(out)
+
+ if self.downsample is not None:
+ residual = self.downsample(x)
+ out += residual
+ out = self.relu(out)
+ return out
+
+
+class ResNet_ASTER(nn.Layer):
+ """For aster or crnn"""
+
+ def __init__(self, with_lstm=True, n_group=1, in_channels=3):
+ super(ResNet_ASTER, self).__init__()
+ self.with_lstm = with_lstm
+ self.n_group = n_group
+
+ self.layer0 = nn.Sequential(
+ nn.Conv2D(
+ in_channels,
+ 32,
+ kernel_size=(3, 3),
+ stride=1,
+ padding=1,
+ bias_attr=False),
+ nn.BatchNorm2D(32),
+ nn.ReLU())
+
+ self.inplanes = 32
+ self.layer1 = self._make_layer(32, 3, [2, 2]) # [16, 50]
+ self.layer2 = self._make_layer(64, 4, [2, 2]) # [8, 25]
+ self.layer3 = self._make_layer(128, 6, [2, 1]) # [4, 25]
+ self.layer4 = self._make_layer(256, 6, [2, 1]) # [2, 25]
+ self.layer5 = self._make_layer(512, 3, [2, 1]) # [1, 25]
+
+ if with_lstm:
+ self.rnn = nn.LSTM(512, 256, direction="bidirect", num_layers=2)
+ self.out_channels = 2 * 256
+ else:
+ self.out_channels = 512
+
+ def _make_layer(self, planes, blocks, stride):
+ downsample = None
+ if stride != [1, 1] or self.inplanes != planes:
+ downsample = nn.Sequential(
+ conv1x1(self.inplanes, planes, stride), nn.BatchNorm2D(planes))
+
+ layers = []
+ layers.append(AsterBlock(self.inplanes, planes, stride, downsample))
+ self.inplanes = planes
+ for _ in range(1, blocks):
+ layers.append(AsterBlock(self.inplanes, planes))
+ return nn.Sequential(*layers)
+
+ def forward(self, x):
+ x0 = self.layer0(x)
+ x1 = self.layer1(x0)
+ x2 = self.layer2(x1)
+ x3 = self.layer3(x2)
+ x4 = self.layer4(x3)
+ x5 = self.layer5(x4)
+
+ cnn_feat = x5.squeeze(2) # [N, c, w]
+ cnn_feat = paddle.transpose(cnn_feat, perm=[0, 2, 1])
+ if self.with_lstm:
+ rnn_feat, _ = self.rnn(cnn_feat)
+ return rnn_feat
+ else:
+ return cnn_feat
+
+
+if __name__ == "__main__":
+ x = paddle.randn([3, 3, 32, 100])
+ net = ResNet_ASTER()
+ encoder_feat = net(x)
+ print(encoder_feat.shape)
diff --git a/ppocr/modeling/heads/__init__.py b/ppocr/modeling/heads/__init__.py
index 5096479415f504aa9f074d55bd9b2e4a31c730b4..cd923d78be9c6c6eb1db849e923fcbe5e5b432b4 100755
--- a/ppocr/modeling/heads/__init__.py
+++ b/ppocr/modeling/heads/__init__.py
@@ -26,12 +26,15 @@ def build_head(config):
from .rec_ctc_head import CTCHead
from .rec_att_head import AttentionHead
from .rec_srn_head import SRNHead
+ from .rec_aster_head import AttentionRecognitionHead, AsterHead
# cls head
from .cls_head import ClsHead
support_dict = [
'DBHead', 'EASTHead', 'SASTHead', 'CTCHead', 'ClsHead', 'AttentionHead',
- 'SRNHead', 'PGHead', 'TableAttentionHead']
+ 'SRNHead', 'PGHead', 'TableAttentionHead', 'AttentionRecognitionHead',
+ 'AsterHead'
+ ]
#table head
from .table_att_head import TableAttentionHead
diff --git a/ppocr/modeling/heads/rec_aster_head.py b/ppocr/modeling/heads/rec_aster_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..055b109730fc1c8ff9b886ce700e3783ee569b17
--- /dev/null
+++ b/ppocr/modeling/heads/rec_aster_head.py
@@ -0,0 +1,258 @@
+# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import sys
+
+import paddle
+from paddle import nn
+from paddle.nn import functional as F
+
+
+class AsterHead(nn.Layer):
+ def __init__(self,
+ in_channels,
+ out_channels,
+ sDim,
+ attDim,
+ max_len_labels,
+ time_step=25,
+ beam_width=5,
+ **kwargs):
+ super(AsterHead, self).__init__()
+ self.num_classes = out_channels
+ self.in_planes = in_channels
+ self.sDim = sDim
+ self.attDim = attDim
+ self.max_len_labels = max_len_labels
+ self.decoder = AttentionRecognitionHead(in_channels, out_channels, sDim,
+ attDim, max_len_labels)
+ self.time_step = time_step
+ self.embeder = Embedding(self.time_step, in_channels)
+ self.beam_width = beam_width
+ # eos index in the charset [sos] + characters + [eos] + [unknown]; used by beam search
+ self.eos = self.num_classes - 2
+
+ def forward(self, x, targets=None, embed=None):
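+ # x: [N, T, C] sequence from the backbone; targets = (rec_targets, rec_lengths)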
+ return_dict = {}
+ embedding_vectors = self.embeder(x)
+ rec_targets, rec_lengths = targets
+
+ if self.training:
+ rec_pred = self.decoder([x, rec_targets, rec_lengths],
+ embedding_vectors)
+ return_dict['rec_pred'] = rec_pred
+ return_dict['embedding_vectors'] = embedding_vectors
+ else:
+ rec_pred, rec_pred_scores = self.decoder.beam_search(
+ x, self.beam_width, self.eos, embedding_vectors)
+ return_dict['rec_pred'] = rec_pred
+ return_dict['rec_pred_scores'] = rec_pred_scores
+ return_dict['embedding_vectors'] = embedding_vectors
+
+ return return_dict
+
+
+class Embedding(nn.Layer):
+ def __init__(self, in_timestep, in_planes, mid_dim=4096, embed_dim=300):
+ super(Embedding, self).__init__()
+ self.in_timestep = in_timestep
+ self.in_planes = in_planes
+ self.embed_dim = embed_dim
+ self.mid_dim = mid_dim
+ self.eEmbed = nn.Linear(
+ in_timestep * in_planes,
+ self.embed_dim) # embed the flattened encoder output into a word-embedding-like vector
+
+ def forward(self, x):
+ x = paddle.reshape(x, [paddle.shape(x)[0], -1])
+ x = self.eEmbed(x)
+ return x
+
+
+class AttentionRecognitionHead(nn.Layer):
+ """
+ input: [b x 16 x 64 x in_planes]
+ output: probability sequence: [b x T x num_classes]
+ """
+
+ def __init__(self, in_channels, out_channels, sDim, attDim, max_len_labels):
+ super(AttentionRecognitionHead, self).__init__()
+ self.num_classes = out_channels # number of output classes, including the special tokens
+ self.in_planes = in_channels
+ self.sDim = sDim
+ self.attDim = attDim
+ self.max_len_labels = max_len_labels
+
+ self.decoder = DecoderUnit(
+ sDim=sDim, xDim=in_channels, yDim=self.num_classes, attDim=attDim)
+
+ def forward(self, x, embed):
+ x, targets, lengths = x
+ batch_size = paddle.shape(x)[0]
+ # Decoder
+ state = self.decoder.get_initial_state(embed)
+ outputs = []
+
+ for i in range(max(lengths)):
+ if i == 0:
+ y_prev = paddle.full(
+ shape=[batch_size], fill_value=self.num_classes)
+ else:
+ y_prev = targets[:, i - 1]
+
+ output, state = self.decoder(x, state, y_prev)
+ outputs.append(output)
+ outputs = paddle.concat([_.unsqueeze(1) for _ in outputs], 1)
+ return outputs
+
+ # inference stage.
+ def sample(self, x):
+ x, _, _ = x
+ batch_size = paddle.shape(x)[0]
+ # Decoder
+ state = paddle.zeros([1, batch_size, self.sDim])
+
+ predicted_ids, predicted_scores = [], []
+ for i in range(self.max_len_labels):
+ if i == 0:
+ y_prev = paddle.full(
+ shape=[batch_size], fill_value=self.num_classes)
+ else:
+ y_prev = predicted
+
+ output, state = self.decoder(x, state, y_prev)
+ output = F.softmax(output, axis=1)
+ score = paddle.max(output, axis=1)
+ predicted = paddle.argmax(output, axis=1)
+ predicted_ids.append(predicted.unsqueeze(1))
+ predicted_scores.append(score.unsqueeze(1))
+ predicted_ids = paddle.concat(predicted_ids, 1)
+ predicted_scores = paddle.concat(predicted_scores, 1)
+ # return predicted_ids.squeeze(), predicted_scores.squeeze()
+ return predicted_ids, predicted_scores
+
+
+class AttentionUnit(nn.Layer):
+ def __init__(self, sDim, xDim, attDim):
+ super(AttentionUnit, self).__init__()
+
+ self.sDim = sDim
+ self.xDim = xDim
+ self.attDim = attDim
+
+ self.sEmbed = nn.Linear(
+ sDim,
+ attDim,
+ weight_attr=paddle.nn.initializer.Normal(std=0.01),
+ bias_attr=paddle.nn.initializer.Constant(0.0))
+ self.xEmbed = nn.Linear(
+ xDim,
+ attDim,
+ weight_attr=paddle.nn.initializer.Normal(std=0.01),
+ bias_attr=paddle.nn.initializer.Constant(0.0))
+ self.wEmbed = nn.Linear(
+ attDim,
+ 1,
+ weight_attr=paddle.nn.initializer.Normal(std=0.01),
+ bias_attr=paddle.nn.initializer.Constant(0.0))
+
+ def forward(self, x, sPrev):
+ batch_size, T, _ = x.shape # [b x T x xDim]
+ x = paddle.reshape(x, [-1, self.xDim]) # [(b x T) x xDim]
+ xProj = self.xEmbed(x) # [(b x T) x attDim]
+ xProj = paddle.reshape(xProj, [batch_size, T, -1]) # [b x T x attDim]
+
+ sPrev = sPrev.squeeze(0)
+ sProj = self.sEmbed(sPrev) # [b x attDim]
+ sProj = paddle.unsqueeze(sProj, 1) # [b x 1 x attDim]
+ sProj = paddle.expand(sProj,
+ [batch_size, T, self.attDim]) # [b x T x attDim]
+
+ sumTanh = paddle.tanh(sProj + xProj)
+ sumTanh = paddle.reshape(sumTanh, [-1, self.attDim])
+
+ vProj = self.wEmbed(sumTanh) # [(b x T) x 1]
+ vProj = paddle.reshape(vProj, [batch_size, T])
+
+ alpha = F.softmax(
+ vProj, axis=1) # attention weights for each sample in the minibatch
+
+ return alpha
+
+
+class DecoderUnit(nn.Layer):
+ def __init__(self, sDim, xDim, yDim, attDim):
+ super(DecoderUnit, self).__init__()
+ self.sDim = sDim
+ self.xDim = xDim
+ self.yDim = yDim
+ self.attDim = attDim
+ self.emdDim = attDim
+
+ self.attention_unit = AttentionUnit(sDim, xDim, attDim)
+ self.tgt_embedding = nn.Embedding(
+ yDim + 1, self.emdDim, weight_attr=nn.initializer.Normal(
+ std=0.01)) # the extra last index (yDim) is used for the start token
+ self.gru = nn.GRUCell(input_size=xDim + self.emdDim, hidden_size=sDim)
+ self.fc = nn.Linear(
+ sDim,
+ yDim,
+ weight_attr=nn.initializer.Normal(std=0.01),
+ bias_attr=nn.initializer.Constant(value=0))
+ self.embed_fc = nn.Linear(300, self.sDim)
+
+ def get_initial_state(self, embed, tile_times=1):
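+ # project the 300-dim embedding to sDim; tile_times > 1 repeats the state for beam search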
+ assert embed.shape[1] == 300
+ state = self.embed_fc(embed) # N * sDim
+ if tile_times != 1:
+ state = state.unsqueeze(1)
+ trans_state = paddle.transpose(state, perm=[1, 0, 2])
+ state = paddle.tile(trans_state, repeat_times=[tile_times, 1, 1])
+ trans_state = paddle.transpose(state, perm=[1, 0, 2])
+ state = paddle.reshape(trans_state, shape=[-1, self.sDim])
+ state = state.unsqueeze(0) # 1 * N * sDim
+ return state
+
+ def forward(self, x, sPrev, yPrev):
+ # x: feature sequence from the image decoder.
+ batch_size, T, _ = x.shape
+ alpha = self.attention_unit(x, sPrev)
+ context = paddle.squeeze(paddle.matmul(alpha.unsqueeze(1), x), axis=1)
+ yPrev = paddle.cast(yPrev, dtype="int64")
+ yProj = self.tgt_embedding(yPrev)
+
+ concat_context = paddle.concat([yProj, context], 1)
+ concat_context = paddle.squeeze(concat_context, 1)
+ sPrev = paddle.squeeze(sPrev, 0)
+ output, state = self.gru(concat_context, sPrev)
+ output = paddle.squeeze(output, axis=1)
+ output = self.fc(output)
+ return output, state
+
+
+if __name__ == "__main__":
+ model = AttentionRecognitionHead(
+ num_classes=20,
+ in_channels=30,
+ sDim=512,
+ attDim=512,
+ max_len_labels=25,
+ out_channels=38)
+
+ data = paddle.ones([16, 64, 3])
+ targets = paddle.ones([16, 25])
+ length = paddle.to_tensor(20)
+ x = [data, targets, length]
+ output = model(x)
+ print(output.shape)
diff --git a/ppocr/modeling/heads/rec_att_head.py b/ppocr/modeling/heads/rec_att_head.py
index 4286d7691d1abcf80c283d1c1ab76f8cd1f4a634..79f112f723d8ded05dc114b0ce0cb5fd098c1f45 100644
--- a/ppocr/modeling/heads/rec_att_head.py
+++ b/ppocr/modeling/heads/rec_att_head.py
@@ -44,10 +44,12 @@ class AttentionHead(nn.Layer):
 hidden = paddle.zeros((batch_size, self.hidden_size))
 output_hiddens = []
 if targets is not None:
+ # targets arrives as [label, length]; only the label tensor is needed here
+ targets = targets[0]
 for i in range(num_steps):
 char_onehots = self._char_to_onehot(
 targets[:, i], onehot_dim=self.num_classes)
 (outputs, hidden), alpha = self.attention_cell(hidden, inputs,
 char_onehots)
 output_hiddens.append(paddle.unsqueeze(outputs, axis=1))
diff --git a/ppocr/modeling/transforms/__init__.py b/ppocr/modeling/transforms/__init__.py
index 78eaecccc55f77d6624aa0c5bdb839acc3462129..0e02a1c0cf88eefa3f7dde4f606b7db3016e4e1c 100755
--- a/ppocr/modeling/transforms/__init__.py
+++ b/ppocr/modeling/transforms/__init__.py
@@ -17,8 +17,9 @@ __all__ = ['build_transform']
def build_transform(config):
from .tps import TPS
+ from .tps import STN_ON
- support_dict = ['TPS']
+ support_dict = ['TPS', 'STN_ON']
module_name = config.pop('name')
assert module_name in support_dict, Exception(
diff --git a/ppocr/modeling/transforms/stn.py b/ppocr/modeling/transforms/stn.py
new file mode 100644
index 0000000000000000000000000000000000000000..0b26e27aea37c82c9c717a2450efac4cdd9c4648
--- /dev/null
+++ b/ppocr/modeling/transforms/stn.py
@@ -0,0 +1,121 @@
+# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import math
+import paddle
+from paddle import nn, ParamAttr
+from paddle.nn import functional as F
+import numpy as np
+
+
+def conv3x3_block(in_channels, out_channels, stride=1):
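+ # He-style init for the 3x3 conv: std = sqrt(2 / fan_out)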
+ n = 3 * 3 * out_channels
+ w = math.sqrt(2. / n)
+ conv_layer = nn.Conv2D(
+ in_channels,
+ out_channels,
+ kernel_size=3,
+ stride=stride,
+ padding=1,
+ weight_attr=nn.initializer.Normal(
+ mean=0.0, std=w),
+ bias_attr=nn.initializer.Constant(0))
+ block = nn.Sequential(conv_layer, nn.BatchNorm2D(out_channels), nn.ReLU())
+ return block
+
+
+class STN(nn.Layer):
+ def __init__(self, in_channels, num_ctrlpoints, activation='none'):
+ super(STN, self).__init__()
+ self.in_channels = in_channels
+ self.num_ctrlpoints = num_ctrlpoints
+ self.activation = activation
+ self.stn_convnet = nn.Sequential(
+ conv3x3_block(in_channels, 32), #32x64
+ nn.MaxPool2D(
+ kernel_size=2, stride=2),
+ conv3x3_block(32, 64), #16x32
+ nn.MaxPool2D(
+ kernel_size=2, stride=2),
+ conv3x3_block(64, 128), # 8*16
+ nn.MaxPool2D(
+ kernel_size=2, stride=2),
+ conv3x3_block(128, 256), # 4*8
+ nn.MaxPool2D(
+ kernel_size=2, stride=2),
+ conv3x3_block(256, 256), # 2*4,
+ nn.MaxPool2D(
+ kernel_size=2, stride=2),
+ conv3x3_block(256, 256)) # 1*2
+ self.stn_fc1 = nn.Sequential(
+ nn.Linear(
+ 2 * 256,
+ 512,
+ weight_attr=nn.initializer.Normal(0, 0.001),
+ bias_attr=nn.initializer.Constant(0)),
+ nn.BatchNorm1D(512),
+ nn.ReLU())
+ fc2_bias = self.init_stn()
+ self.stn_fc2 = nn.Linear(
+ 512,
+ num_ctrlpoints * 2,
+ weight_attr=nn.initializer.Constant(0.0),
+ bias_attr=nn.initializer.Assign(fc2_bias))
+
+ def init_stn(self):
+ margin = 0.01
+ sampling_num_per_side = int(self.num_ctrlpoints / 2)
+ ctrl_pts_x = np.linspace(margin, 1. - margin, sampling_num_per_side)
+ ctrl_pts_y_top = np.ones(sampling_num_per_side) * margin
+ ctrl_pts_y_bottom = np.ones(sampling_num_per_side) * (1 - margin)
+ ctrl_pts_top = np.stack([ctrl_pts_x, ctrl_pts_y_top], axis=1)
+ ctrl_pts_bottom = np.stack([ctrl_pts_x, ctrl_pts_y_bottom], axis=1)
+ ctrl_points = np.concatenate(
+ [ctrl_pts_top, ctrl_pts_bottom], axis=0).astype(np.float32)
+ if self.activation == 'none':
+ pass
+ elif self.activation == 'sigmoid':
+ ctrl_points = -np.log(1. / ctrl_points - 1.)
+ ctrl_points = paddle.to_tensor(ctrl_points)
+ fc2_bias = paddle.reshape(
+ ctrl_points, shape=[ctrl_points.shape[0] * ctrl_points.shape[1]])
+ return fc2_bias
+
+ def forward(self, x):
+ x = self.stn_convnet(x)
+ batch_size, _, h, w = x.shape
+ x = paddle.reshape(x, shape=(batch_size, -1))
+ img_feat = self.stn_fc1(x)
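+ # scale by 0.1 so early predictions stay near the identity control points set in fc2's bias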
+ x = self.stn_fc2(0.1 * img_feat)
+ if self.activation == 'sigmoid':
+ x = F.sigmoid(x)
+ x = paddle.reshape(x, shape=[-1, self.num_ctrlpoints, 2])
+ return img_feat, x
+
+
+if __name__ == "__main__":
+ in_planes = 3
+ num_ctrlpoints = 20
+ np.random.seed(100)
+ activation = 'none' # 'sigmoid'
+ stn_head = STN(in_planes, num_ctrlpoints, activation)
+ data = np.random.randn(10, 3, 32, 64).astype("float32")
+ print("data:", np.sum(data))
+ input = paddle.to_tensor(data)
+ #input = paddle.randn([10, 3, 32, 64])
+ control_points = stn_head(input)
diff --git a/ppocr/modeling/transforms/tps.py b/ppocr/modeling/transforms/tps.py
index dcce6246ac64b4b84229cbd69a4dc53c658b4c7b..fc462100716f4ac360eea5354b9efaf7359c8976 100644
--- a/ppocr/modeling/transforms/tps.py
+++ b/ppocr/modeling/transforms/tps.py
@@ -22,6 +22,9 @@ from paddle import nn, ParamAttr
from paddle.nn import functional as F
import numpy as np
+from .tps_spatial_transformer import TPSSpatialTransformer
+from .stn import STN
+
class ConvBNLayer(nn.Layer):
def __init__(self,
@@ -231,7 +234,8 @@ class GridGenerator(nn.Layer):
""" Return inv_delta_C which is needed to calculate T """
F = self.F
hat_eye = paddle.eye(F, dtype='float64') # F x F
- hat_C = paddle.norm(C.reshape([1, F, 2]) - C.reshape([F, 1, 2]), axis=2) + hat_eye
+ hat_C = paddle.norm(
+ C.reshape([1, F, 2]) - C.reshape([F, 1, 2]), axis=2) + hat_eye
hat_C = (hat_C**2) * paddle.log(hat_C)
delta_C = paddle.concat( # F+3 x F+3
[
@@ -301,3 +305,26 @@ class TPS(nn.Layer):
[-1, image.shape[2], image.shape[3], 2])
batch_I_r = F.grid_sample(x=image, grid=batch_P_prime)
return batch_I_r
+
+
+class STN_ON(nn.Layer):
+ def __init__(self, in_channels, tps_inputsize, tps_outputsize,
+ num_control_points, tps_margins, stn_activation):
+ super(STN_ON, self).__init__()
+ self.tps = TPSSpatialTransformer(
+ output_image_size=tuple(tps_outputsize),
+ num_control_points=num_control_points,
+ margins=tuple(tps_margins))
+ self.stn_head = STN(in_channels=in_channels,
+ num_ctrlpoints=num_control_points,
+ activation=stn_activation)
+ self.tps_inputsize = tps_inputsize
+ self.out_channels = in_channels
+
+ def forward(self, image):
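+ # run the STN on a small resized copy; TPS then warps the full-resolution image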
+ stn_input = paddle.nn.functional.interpolate(
+ image, self.tps_inputsize, mode="bilinear", align_corners=True)
+ stn_img_feat, ctrl_points = self.stn_head(stn_input)
+ x, _ = self.tps(image, ctrl_points)
+ return x
diff --git a/ppocr/modeling/transforms/tps_spatial_transformer.py b/ppocr/modeling/transforms/tps_spatial_transformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..da54ffb7868229d9674d918b05b26dd8ee35935b
--- /dev/null
+++ b/ppocr/modeling/transforms/tps_spatial_transformer.py
@@ -0,0 +1,178 @@
+# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import math
+import paddle
+from paddle import nn, ParamAttr
+from paddle.nn import functional as F
+import numpy as np
+import itertools
+
+
+def grid_sample(input, grid, canvas=None):
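+ # with a canvas, samples falling outside the source image are filled from it instead of zeros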
+ input.stop_gradient = False
+ output = F.grid_sample(input, grid)
+ if canvas is None:
+ return output
+ else:
+ input_mask = paddle.ones(shape=input.shape)
+ output_mask = F.grid_sample(input_mask, grid)
+ padded_output = output * output_mask + canvas * (1 - output_mask)
+ return padded_output
+
+
+# phi(x1, x2) = r^2 * log(r), where r = ||x1 - x2||_2
+def compute_partial_repr(input_points, control_points):
+ N = input_points.shape[0]
+ M = control_points.shape[0]
+ pairwise_diff = paddle.reshape(
+ input_points, shape=[N, 1, 2]) - paddle.reshape(
+ control_points, shape=[1, M, 2])
+ # original implementation, very slow
+ # pairwise_dist = torch.sum(pairwise_diff ** 2, dim = 2) # square of distance
+ pairwise_diff_square = pairwise_diff * pairwise_diff
+ pairwise_dist = pairwise_diff_square[:, :, 0] + pairwise_diff_square[:, :,
+ 1]
+ repr_matrix = 0.5 * pairwise_dist * paddle.log(pairwise_dist)
+ # fix numerical error for 0 * log(0), substitute all nan with 0
+ mask = repr_matrix != repr_matrix
+ repr_matrix[mask] = 0
+ return repr_matrix
+
+
+# output_ctrl_pts are specified, according to our task.
+def build_output_control_points(num_control_points, margins):
+ margin_x, margin_y = margins
+ num_ctrl_pts_per_side = num_control_points // 2
+ ctrl_pts_x = np.linspace(margin_x, 1.0 - margin_x, num_ctrl_pts_per_side)
+ ctrl_pts_y_top = np.ones(num_ctrl_pts_per_side) * margin_y
+ ctrl_pts_y_bottom = np.ones(num_ctrl_pts_per_side) * (1.0 - margin_y)
+ ctrl_pts_top = np.stack([ctrl_pts_x, ctrl_pts_y_top], axis=1)
+ ctrl_pts_bottom = np.stack([ctrl_pts_x, ctrl_pts_y_bottom], axis=1)
+ # ctrl_pts_top = ctrl_pts_top[1:-1,:]
+ # ctrl_pts_bottom = ctrl_pts_bottom[1:-1,:]
+ output_ctrl_pts_arr = np.concatenate(
+ [ctrl_pts_top, ctrl_pts_bottom], axis=0)
+ output_ctrl_pts = paddle.to_tensor(output_ctrl_pts_arr)
+ return output_ctrl_pts
+
+
+class TPSSpatialTransformer(nn.Layer):
+ def __init__(self,
+ output_image_size=None,
+ num_control_points=None,
+ margins=None):
+ super(TPSSpatialTransformer, self).__init__()
+ self.output_image_size = output_image_size
+ self.num_control_points = num_control_points
+ self.margins = margins
+
+ self.target_height, self.target_width = output_image_size
+ target_control_points = build_output_control_points(num_control_points,
+ margins)
+ N = num_control_points
+ # N = N - 4
+
+ # create padded kernel matrix
+ forward_kernel = paddle.zeros(shape=[N + 3, N + 3])
+ target_control_partial_repr = compute_partial_repr(
+ target_control_points, target_control_points)
+ target_control_partial_repr = paddle.cast(target_control_partial_repr,
+ forward_kernel.dtype)
+ forward_kernel[:N, :N] = target_control_partial_repr
+ forward_kernel[:N, -3] = 1
+ forward_kernel[-3, :N] = 1
+ target_control_points = paddle.cast(target_control_points,
+ forward_kernel.dtype)
+ forward_kernel[:N, -2:] = target_control_points
+ forward_kernel[-2:, :N] = paddle.transpose(
+ target_control_points, perm=[1, 0])
+ # compute inverse matrix
+ inverse_kernel = paddle.inverse(forward_kernel)
+
+        # create target coordinate matrix
+ HW = self.target_height * self.target_width
+ target_coordinate = list(
+ itertools.product(
+ range(self.target_height), range(self.target_width)))
+ target_coordinate = paddle.to_tensor(target_coordinate) # HW x 2
+ Y, X = paddle.split(
+ target_coordinate, target_coordinate.shape[1], axis=1)
+ #Y, X = target_coordinate.split(1, dim = 1)
+ Y = Y / (self.target_height - 1)
+ X = X / (self.target_width - 1)
+ target_coordinate = paddle.concat(
+ [X, Y], axis=1) # convert from (y, x) to (x, y)
+ target_coordinate_partial_repr = compute_partial_repr(
+ target_coordinate, target_control_points)
+ target_coordinate_repr = paddle.concat(
+ [
+ target_coordinate_partial_repr, paddle.ones(shape=[HW, 1]),
+ target_coordinate
+ ],
+ axis=1)
+
+ # register precomputed matrices
+ self.inverse_kernel = inverse_kernel
+ self.padding_matrix = paddle.zeros(shape=[3, 2])
+ self.target_coordinate_repr = target_coordinate_repr
+ self.target_control_points = target_control_points
+
+ def forward(self, input, source_control_points):
+ assert source_control_points.ndimension() == 3
+ assert source_control_points.shape[1] == self.num_control_points
+ assert source_control_points.shape[2] == 2
+ batch_size = source_control_points.shape[0]
+
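+        # shapes: Y is [B, N+3, 2]; inverse_kernel ([N+3, N+3]) broadcasts over
+        # the batch, giving mapping_matrix [B, N+3, 2], and multiplying by
+        # target_coordinate_repr ([HW, N+3]) gives source_coordinate [B, HW, 2]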
+        # expand into a local tensor: writing the expanded result back to
+        # self.padding_matrix would overwrite the [3, 2] buffer and break a
+        # later call with a different batch size
+        padding_matrix = paddle.expand(
+            self.padding_matrix, shape=[batch_size, 3, 2])
+        Y = paddle.concat([source_control_points, padding_matrix], 1)
+ mapping_matrix = paddle.matmul(self.inverse_kernel, Y)
+ source_coordinate = paddle.matmul(self.target_coordinate_repr,
+ mapping_matrix)
+
+ grid = paddle.reshape(
+ source_coordinate,
+ shape=[-1, self.target_height, self.target_width, 2])
+ grid = paddle.clip(grid, 0,
+ 1) # the source_control_points may be out of [0, 1].
+ # the input to grid_sample is normalized [-1, 1], but what we get is [0, 1]
+        grid = 2.0 * grid - 1.0
+ output_maps = grid_sample(input, grid, canvas=None)
+ return output_maps, source_coordinate
+
+
+if __name__ == "__main__":
+ from stn import STN
+ in_planes = 3
+ num_ctrlpoints = 20
+ np.random.seed(100)
+ activation = 'none' # 'sigmoid'
+ stn_head = STN(in_planes, num_ctrlpoints, activation)
+ data = np.random.randn(10, 3, 32, 64).astype("float32")
+ input = paddle.to_tensor(data)
+ #input = paddle.randn([10, 3, 32, 64])
+ control_points = stn_head(input)
+ #print("control points:", control_points)
+ #input = paddle.randn(shape=[10,3,32,100])
+ tps = TPSSpatialTransformer(
+ output_image_size=[32, 320],
+ num_control_points=20,
+ margins=[0.05, 0.05])
+ out = tps(input, control_points[1])
+ print("out 0 :", out[0].shape)
+ print("out 1:", out[1].shape)
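+    # expected with the [10, 3, 32, 64] input and 32x320 output size:
+    # out 0 : [10, 3, 32, 320], out 1: [10, 10240, 2] (10240 = 32 * 320)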
diff --git a/ppocr/modeling/transforms/tps_torch.py b/ppocr/modeling/transforms/tps_torch.py
new file mode 100644
index 0000000000000000000000000000000000000000..7aee133ae3b7fc64316b8cef6eca682655d09867
--- /dev/null
+++ b/ppocr/modeling/transforms/tps_torch.py
@@ -0,0 +1,149 @@
+from __future__ import absolute_import
+
+import numpy as np
+import itertools
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+def grid_sample(input, grid, canvas=None):
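+    # note: since PyTorch 1.3, F.grid_sample defaults to align_corners=False,
+    # while paddle.nn.functional.grid_sample defaults to align_corners=True;
+    # pass the flag explicitly if the two ports must match exactly.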
+ output = F.grid_sample(input, grid)
+ if canvas is None:
+ return output
+ else:
+ input_mask = input.data.new(input.size()).fill_(1)
+ output_mask = F.grid_sample(input_mask, grid)
+ padded_output = output * output_mask + canvas * (1 - output_mask)
+ return padded_output
+
+
+# phi(x1, x2) = r^2 * log(r), where r = ||x1 - x2||_2
+def compute_partial_repr(input_points, control_points):
+ N = input_points.size(0)
+ M = control_points.size(0)
+ pairwise_diff = input_points.view(N, 1, 2) - control_points.view(1, M, 2)
+ # original implementation, very slow
+ # pairwise_dist = torch.sum(pairwise_diff ** 2, dim = 2) # square of distance
+ pairwise_diff_square = pairwise_diff * pairwise_diff
+ pairwise_dist = pairwise_diff_square[:, :, 0] + pairwise_diff_square[:, :,
+ 1]
+ repr_matrix = 0.5 * pairwise_dist * torch.log(pairwise_dist)
+ # fix numerical error for 0 * log(0), substitute all nan with 0
+ mask = repr_matrix != repr_matrix
+ repr_matrix.masked_fill_(mask, 0)
+ return repr_matrix
+
+
+# output_ctrl_pts are specified, according to our task.
+def build_output_control_points(num_control_points, margins):
+ margin_x, margin_y = margins
+ num_ctrl_pts_per_side = num_control_points // 2
+ ctrl_pts_x = np.linspace(margin_x, 1.0 - margin_x, num_ctrl_pts_per_side)
+ ctrl_pts_y_top = np.ones(num_ctrl_pts_per_side) * margin_y
+ ctrl_pts_y_bottom = np.ones(num_ctrl_pts_per_side) * (1.0 - margin_y)
+ ctrl_pts_top = np.stack([ctrl_pts_x, ctrl_pts_y_top], axis=1)
+ ctrl_pts_bottom = np.stack([ctrl_pts_x, ctrl_pts_y_bottom], axis=1)
+ # ctrl_pts_top = ctrl_pts_top[1:-1,:]
+ # ctrl_pts_bottom = ctrl_pts_bottom[1:-1,:]
+ output_ctrl_pts_arr = np.concatenate(
+ [ctrl_pts_top, ctrl_pts_bottom], axis=0)
+ output_ctrl_pts = torch.Tensor(output_ctrl_pts_arr)
+ return output_ctrl_pts
+
+
+# demo: ~/test/models/test_tps_transformation.py
+class TPSSpatialTransformer(nn.Module):
+ def __init__(self,
+ output_image_size=None,
+ num_control_points=None,
+ margins=None):
+ super(TPSSpatialTransformer, self).__init__()
+ self.output_image_size = output_image_size
+ self.num_control_points = num_control_points
+ self.margins = margins
+
+ self.target_height, self.target_width = output_image_size
+ target_control_points = build_output_control_points(num_control_points,
+ margins)
+ N = num_control_points
+ # N = N - 4
+
+ # create padded kernel matrix
+ forward_kernel = torch.zeros(N + 3, N + 3)
+ target_control_partial_repr = compute_partial_repr(
+ target_control_points, target_control_points)
+ forward_kernel[:N, :N].copy_(target_control_partial_repr)
+ forward_kernel[:N, -3].fill_(1)
+ forward_kernel[-3, :N].fill_(1)
+ forward_kernel[:N, -2:].copy_(target_control_points)
+ forward_kernel[-2:, :N].copy_(target_control_points.transpose(0, 1))
+ # compute inverse matrix
+ inverse_kernel = torch.inverse(forward_kernel)
+
+        # create target coordinate matrix
+ HW = self.target_height * self.target_width
+ target_coordinate = list(
+ itertools.product(
+ range(self.target_height), range(self.target_width)))
+ target_coordinate = torch.Tensor(target_coordinate) # HW x 2
+ Y, X = target_coordinate.split(1, dim=1)
+ Y = Y / (self.target_height - 1)
+ X = X / (self.target_width - 1)
+ target_coordinate = torch.cat([X, Y],
+ dim=1) # convert from (y, x) to (x, y)
+ target_coordinate_partial_repr = compute_partial_repr(
+ target_coordinate, target_control_points)
+ target_coordinate_repr = torch.cat([
+ target_coordinate_partial_repr, torch.ones(HW, 1), target_coordinate
+ ],
+ dim=1)
+
+ # register precomputed matrices
+ self.register_buffer('inverse_kernel', inverse_kernel)
+ self.register_buffer('padding_matrix', torch.zeros(3, 2))
+ self.register_buffer('target_coordinate_repr', target_coordinate_repr)
+ self.register_buffer('target_control_points', target_control_points)
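+        # buffers move with .to(device) and are saved in the state_dict without
+        # becoming trainable parameters; the Paddle port above stores these as
+        # plain attributes instead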
+
+ def forward(self, input, source_control_points):
+ assert source_control_points.ndimension() == 3
+ assert source_control_points.size(1) == self.num_control_points
+ assert source_control_points.size(2) == 2
+ batch_size = source_control_points.size(0)
+
+ Y = torch.cat([
+ source_control_points, self.padding_matrix.expand(batch_size, 3, 2)
+ ], 1)
+ mapping_matrix = torch.matmul(self.inverse_kernel, Y)
+ source_coordinate = torch.matmul(self.target_coordinate_repr,
+ mapping_matrix)
+
+ grid = source_coordinate.view(-1, self.target_height, self.target_width,
+ 2)
+ grid = torch.clamp(grid, 0,
+ 1) # the source_control_points may be out of [0, 1].
+ # the input to grid_sample is normalized [-1, 1], but what we get is [0, 1]
+ grid = 2.0 * grid - 1.0
+ output_maps = grid_sample(input, grid, canvas=None)
+ return output_maps, source_coordinate
+
+
+if __name__ == "__main__":
+ from stn_torch import STNHead
+ in_planes = 3
+ num_ctrlpoints = 20
+ torch.manual_seed(10)
+ activation = 'none' # 'sigmoid'
+ stn_head = STNHead(in_planes, num_ctrlpoints, activation)
+ np.random.seed(100)
+ data = np.random.randn(10, 3, 32, 64).astype("float32")
+ input = torch.tensor(data)
+ control_points = stn_head(input)
+ tps = TPSSpatialTransformer(
+ output_image_size=[32, 320],
+ num_control_points=20,
+ margins=[0.05, 0.05])
+ out = tps(input, control_points[1])
+ print("out 0 :", out[0].shape)
+ print("out 1:", out[1].shape)
diff --git a/ppocr/postprocess/rec_postprocess.py b/ppocr/postprocess/rec_postprocess.py
index 8426bcf2b9a71e0293d912e25f1b617fd18c59fc..17fc7e461c10b0320115a15c305d880604229725 100644
--- a/ppocr/postprocess/rec_postprocess.py
+++ b/ppocr/postprocess/rec_postprocess.py
@@ -170,8 +170,10 @@ class AttnLabelDecode(BaseRecLabelDecode):
def add_special_char(self, dict_character):
self.beg_str = "sos"
self.end_str = "eos"
+        self.unknown = "UNKNOWN"
dict_character = dict_character
- dict_character = [self.beg_str] + dict_character + [self.end_str]
+ dict_character = [self.beg_str] + dict_character + [self.end_str
+        ] + [self.unknown]
return dict_character
def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
@@ -212,6 +214,7 @@ class AttnLabelDecode(BaseRecLabelDecode):
label = self.decode(label, is_remove_duplicate=False)
return text, label
"""
+        if isinstance(preds, dict):  # the ASTER head returns a dict of outputs
+            preds = preds["rec_pred"]
if isinstance(preds, paddle.Tensor):
preds = preds.numpy()
@@ -324,10 +327,9 @@ class SRNLabelDecode(BaseRecLabelDecode):
class TableLabelDecode(object):
""" """
- def __init__(self,
- character_dict_path,
- **kwargs):
- list_character, list_elem = self.load_char_elem_dict(character_dict_path)
+ def __init__(self, character_dict_path, **kwargs):
+ list_character, list_elem = self.load_char_elem_dict(
+ character_dict_path)
list_character = self.add_special_char(list_character)
list_elem = self.add_special_char(list_elem)
self.dict_character = {}
@@ -366,14 +368,14 @@ class TableLabelDecode(object):
def __call__(self, preds):
structure_probs = preds['structure_probs']
loc_preds = preds['loc_preds']
- if isinstance(structure_probs,paddle.Tensor):
+ if isinstance(structure_probs, paddle.Tensor):
structure_probs = structure_probs.numpy()
- if isinstance(loc_preds,paddle.Tensor):
+ if isinstance(loc_preds, paddle.Tensor):
loc_preds = loc_preds.numpy()
structure_idx = structure_probs.argmax(axis=2)
structure_probs = structure_probs.max(axis=2)
- structure_str, structure_pos, result_score_list, result_elem_idx_list = self.decode(structure_idx,
- structure_probs, 'elem')
+ structure_str, structure_pos, result_score_list, result_elem_idx_list = self.decode(
+ structure_idx, structure_probs, 'elem')
res_html_code_list = []
res_loc_list = []
batch_num = len(structure_str)
@@ -388,8 +390,13 @@ class TableLabelDecode(object):
res_loc = np.array(res_loc)
res_html_code_list.append(res_html_code)
res_loc_list.append(res_loc)
- return {'res_html_code': res_html_code_list, 'res_loc': res_loc_list, 'res_score_list': result_score_list,
- 'res_elem_idx_list': result_elem_idx_list,'structure_str_list':structure_str}
+ return {
+ 'res_html_code': res_html_code_list,
+ 'res_loc': res_loc_list,
+ 'res_score_list': result_score_list,
+ 'res_elem_idx_list': result_elem_idx_list,
+ 'structure_str_list': structure_str
+ }
def decode(self, text_index, structure_probs, char_or_elem):
"""convert text-label into text-index.
diff --git a/ppocr/utils/save_load.py b/ppocr/utils/save_load.py
index 1d760e983a635dcc6b48b839ee99434c67b4378d..0453509c7bc2fb9a641d74820cf7449079b879a9 100644
--- a/ppocr/utils/save_load.py
+++ b/ppocr/utils/save_load.py
@@ -105,13 +105,16 @@ def load_dygraph_params(config, model, logger, optimizer):
params = paddle.load(pm)
state_dict = model.state_dict()
new_state_dict = {}
- for k1, k2 in zip(state_dict.keys(), params.keys()):
- if list(state_dict[k1].shape) == list(params[k2].shape):
- new_state_dict[k1] = params[k2]
- else:
- logger.info(
- f"The shape of model params {k1} {state_dict[k1].shape} not matched with loaded params {k2} {params[k2].shape} !"
- )
+ for k1 in state_dict.keys():
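+        # keys introduced by this branch (e.g. the new ASTER modules) may be
+        # absent from the checkpoint; skip them instead of mis-pairing keys by
+        # position as the old zip() loop did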
+ if k1 not in params:
+ continue
+ if list(state_dict[k1].shape) == list(params[k1].shape):
+ new_state_dict[k1] = params[k1]
+ else:
+ logger.info(
+ f"The shape of model params {k1} {state_dict[k1].shape} not matched with loaded params {k1} {params[k1].shape} !"
+ )
model.set_state_dict(new_state_dict)
logger.info(f"loaded pretrained_model successful from {pm}")
return {}
diff --git a/tools/program.py b/tools/program.py
index 2d99f2968a3f0c8acc359ed0fbb199650bd7010c..920cf417ce1768de6061744ac34bdf68a15d18c9 100755
--- a/tools/program.py
+++ b/tools/program.py
@@ -187,6 +187,7 @@ def train(config,
use_srn = config['Architecture']['algorithm'] == "SRN"
model_type = config['Architecture']['model_type']
+ algorithm = config['Architecture']['algorithm']
if 'start_epoch' in best_model_dict:
start_epoch = best_model_dict['start_epoch']
@@ -210,10 +211,14 @@ def train(config,
images = batch[0]
if use_srn:
model_average = True
- if use_srn or model_type == 'table':
- preds = model(images, data=batch[1:])
- else:
- preds = model(images)
+            if use_srn or model_type == 'table' or algorithm == "ASTER":
+                # ASTER's attention decoder needs the labels from the batch at
+                # training time (teacher forcing), like SRN and table models
+                preds = model(images, data=batch[1:])
+            else:
+                preds = model(images)
loss = loss_class(preds, batch)
avg_loss = loss['loss']
avg_loss.backward()
@@ -395,7 +400,7 @@ def preprocess(is_train=False):
alg = config['Architecture']['algorithm']
assert alg in [
'EAST', 'DB', 'SAST', 'Rosetta', 'CRNN', 'STARNet', 'RARE', 'SRN',
- 'CLS', 'PGNet', 'Distillation', 'TableAttn'
+ 'CLS', 'PGNet', 'Distillation', 'TableAttn', 'ASTER'
]
device = 'gpu:{}'.format(dist.ParallelEnv().dev_id) if use_gpu else 'cpu'
diff --git a/tools/train.py b/tools/train.py
index 20f5a670d5c8e666678259e0042b3b790e528590..e1515f57c9402adf1f7d00c03aee78ea20c2a371 100755
--- a/tools/train.py
+++ b/tools/train.py
@@ -72,6 +72,8 @@ def main(config, device, logger, vdl_writer):
# for rec algorithm
if hasattr(post_process_class, 'character'):
char_num = len(getattr(post_process_class, 'character'))
if config['Architecture']["algorithm"] in ["Distillation",
]: # distillation model
for key in config['Architecture']["Models"]:
|