From c9e1077daac3efb2e5c42ebf879aa363d4c59db4 Mon Sep 17 00:00:00 2001
From: tink2123
Date: Mon, 30 Aug 2021 06:32:54 +0000
Subject: [PATCH] polish code

---
 configs/rec/rec_resnet_stn_bilstm_att.yml          |  65 +-
 ppocr/data/imaug/__init__.py                       |   2 +-
 ppocr/data/imaug/label_ops.py                      |  38 +-
 ppocr/data/imaug/operators.py                      |  16 +-
 ppocr/data/imaug/rec_img_aug.py                    |  23 +
 ppocr/data/simple_dataset.py                       |   1 -
 ppocr/losses/rec_aster_loss.py                     |  55 +-
 ppocr/losses/rec_att_loss.py                       |   2 -
 ppocr/metrics/rec_metric.py                        |  12 +-
 ppocr/modeling/backbones/__init__.py               |   7 +-
 ppocr/modeling/backbones/levit.py                  | 707 ------------------
 ppocr/modeling/heads/__init__.py                   |   1 -
 ppocr/modeling/heads/rec_aster_head.py             | 208 +++++-
 ppocr/modeling/heads/rec_att_head.py               |   5 -
 ppocr/modeling/transforms/stn.py                   |  13 -
 ppocr/modeling/transforms/tps.py                   |   1 +
 .../transforms/tps_spatial_transformer.py          |  27 +-
 ppocr/modeling/transforms/tps_torch.py             | 149 ----
 ppocr/optimizer/optimizer.py                       |  31 +
 ppocr/postprocess/__init__.py                      |   4 +-
 ppocr/postprocess/rec_postprocess.py               |  87 ++-
 ppocr/utils/save_load.py                           |  17 +-
 tools/program.py                                   |  10 +-
 23 files changed, 461 insertions(+), 1020 deletions(-)
 delete mode 100644 ppocr/modeling/backbones/levit.py
 delete mode 100644 ppocr/modeling/transforms/tps_torch.py

diff --git a/configs/rec/rec_resnet_stn_bilstm_att.yml b/configs/rec/rec_resnet_stn_bilstm_att.yml
index f705f1e2..7b5a9c71 100644
--- a/configs/rec/rec_resnet_stn_bilstm_att.yml
+++ b/configs/rec/rec_resnet_stn_bilstm_att.yml
@@ -1,9 +1,9 @@
 Global:
-  use_gpu: False
+  use_gpu: True
   epoch_num: 400
   log_smooth_window: 20
   print_batch_step: 10
-  save_model_dir: ./output/rec/b3_rare_r34_none_gru/
+  save_model_dir: ./output/rec/seed
   save_epoch_step: 3
   # evaluation is run every 5000 iterations after the 4000th iteration
   eval_batch_step: [0, 2000]
@@ -12,28 +12,32 @@ Global:
   checkpoints:
   save_inference_dir:
   use_visualdl: False
-  infer_img: doc/imgs_words/ch/word_1.jpg
+  infer_img: doc/imgs_words_en/word_10.png
   # for data or label process
   character_dict_path:
   character_type: EN_symbol
-  max_text_length: 25
+  max_text_length: 100
   infer_mode: False
   use_space_char: False
-  save_res_path: ./output/rec/predicts_b3_rare_r34_none_gru.txt
+  eval_filter: True
+  save_res_path: ./output/rec/predicts_seed.txt
 
 Optimizer:
-  name: Adam
-  beta1: 0.9
-  beta2: 0.999
+  name: Adadelta
+  weight_decay: 0.0
+  momentum: 0.9
   lr:
-    learning_rate: 0.0005
+    name: Piecewise
+    decay_epochs: [4,5,8]
+    values: [1.0, 0.1, 0.01]
   regularizer:
     name: 'L2'
-    factor: 0.00000
+    factor: 2.0e-05
+
 
 Architecture:
-  model_type: rec
+  model_type: seed
   algorithm: ASTER
   Transform:
     name: STN_ON
@@ -54,48 +58,49 @@ Loss:
   name: AsterLoss
 
 PostProcess:
-  name: AttnLabelDecode
+  name: SEEDLabelDecode
 
 Metric:
   name: RecMetric
   main_indicator: acc
+  is_filter: True
 
 Train:
   dataset:
-    name: SimpleDataSet
-    data_dir: ./train_data/ic15_data/
-    label_file_list: ["./train_data/ic15_data/1.txt"]
+    name: LMDBDataSet
+    data_dir: ./train_data/data_lmdb_release/training/
     transforms:
+      - Fasttext:
+          path: "./cc.en.300.bin"
       - DecodeImage: # load image
           img_mode: BGR
           channel_first: False
-      - AttnLabelEncode: # Class handling label
-      - RecResizeImg:
-          image_shape: [3, 32, 100]
+      - SEEDLabelEncode: # Class handling label
+      - SEEDResize:
+          image_shape: [3, 64, 256]
       - KeepKeys:
-          keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order
+          keep_keys: ['image', 'label', 'length', 'fast_label'] # dataloader will return list in this order
   loader:
     shuffle: True
-    batch_size_per_card: 2
+    batch_size_per_card: 256
     drop_last: True
-    num_workers: 8
+    num_workers: 6
 
 Eval:
   dataset:
-    name: SimpleDataSet
-    data_dir: ./train_data/ic15_data/
-    label_file_list: ["./train_data/ic15_data/1.txt"]
+    name: LMDBDataSet
+    data_dir: ./train_data/data_lmdb_release/evaluation/
     transforms:
       - DecodeImage: # load image
           img_mode: BGR
           channel_first: False
-      - AttnLabelEncode: # Class handling label
-      - RecResizeImg:
-          image_shape: [3, 32, 100]
+      - SEEDLabelEncode: # Class handling label
+      - SEEDResize:
+          image_shape: [3, 64, 256]
       - KeepKeys:
           keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order
   loader:
     shuffle: False
-    drop_last: False
-    batch_size_per_card: 2
-    num_workers: 8
+    drop_last: True
+    batch_size_per_card: 256
+    num_workers: 4
diff --git a/ppocr/data/imaug/__init__.py b/ppocr/data/imaug/__init__.py
index 52194eb9..7a792c2f 100644
--- a/ppocr/data/imaug/__init__.py
+++ b/ppocr/data/imaug/__init__.py
@@ -21,7 +21,7 @@ from .make_border_map import MakeBorderMap
 from .make_shrink_map import MakeShrinkMap
 from .random_crop_data import EastRandomCropData, PSERandomCrop
 
-from .rec_img_aug import RecAug, RecResizeImg, ClsResizeImg, SRNRecResizeImg
+from .rec_img_aug import RecAug, RecResizeImg, ClsResizeImg, SRNRecResizeImg, SEEDResize
 from .randaugment import RandAugment
 from .copy_paste import CopyPaste
 from .operators import *
diff --git a/ppocr/data/imaug/label_ops.py b/ppocr/data/imaug/label_ops.py
index 0e1d4939..21d91030 100644
--- a/ppocr/data/imaug/label_ops.py
+++ b/ppocr/data/imaug/label_ops.py
@@ -276,9 +276,7 @@ class AttnLabelEncode(BaseRecLabelEncode):
     def add_special_char(self, dict_character):
         self.beg_str = "sos"
         self.end_str = "eos"
-        self.unknown = "UNKNOWN"
-        dict_character = [self.beg_str] + dict_character + [self.end_str
-                                                            ] + [self.unknown]
+        dict_character = [self.beg_str] + dict_character + [self.end_str]
         return dict_character
 
     def __call__(self, data):
@@ -291,7 +289,6 @@ class AttnLabelEncode(BaseRecLabelEncode):
         data['length'] = np.array(len(text))
         text = [0] + text + [len(self.character) - 1] + [0] * (self.max_text_len
                                                                - len(text) - 2)
-
         data['label'] = np.array(text)
         return data
 
@@ -311,6 +308,39 @@ class AttnLabelEncode(BaseRecLabelEncode):
         return idx
 
 
+class SEEDLabelEncode(BaseRecLabelEncode):
+    """ Convert between text-label and text-index """
+
+    def __init__(self,
+                 max_text_length,
+                 character_dict_path=None,
+                 character_type='ch',
+                 use_space_char=False,
+                 **kwargs):
+        super(SEEDLabelEncode,
+              self).__init__(max_text_length, character_dict_path,
+                             character_type, use_space_char)
+
+    def add_special_char(self, dict_character):
+        self.beg_str = "sos"
+        self.end_str = "eos"
+        dict_character = dict_character + [self.end_str]
+        return dict_character
+
+    def __call__(self, data):
+        text = data['label']
+        text = self.encode(text)
+        if text is None:
+            return None
+        if len(text) >= self.max_text_len:
+            return None
+        data['length'] = np.array(len(text)) + 1  # include eos
+        text = text + [len(self.character) - 1] * (self.max_text_len - len(text)
+                                                   )
+        data['label'] = np.array(text)
+        return data
+
+
 class SRNLabelEncode(BaseRecLabelEncode):
     """ Convert between text-label and text-index """
 
diff --git a/ppocr/data/imaug/operators.py b/ppocr/data/imaug/operators.py
index 2535b442..ba5f01b4 100644
--- a/ppocr/data/imaug/operators.py
+++ b/ppocr/data/imaug/operators.py
@@ -23,6 +23,7 @@ import sys
 import six
 import cv2
 import numpy as np
+import fasttext
 
 
 class DecodeImage(object):
@@ -81,7 +82,7 @@ class NormalizeImage(object):
"invalid input 'img' in NormalizeImage" data['image'] = ( - img.astype('float32') * self.scale - self.mean) / self.std + img.astype('float32') * self.scale - self.mean) / self.std return data @@ -101,6 +102,17 @@ class ToCHWImage(object): return data +class Fasttext(object): + def __init__(self, path="None", **kwargs): + self.fast_model = fasttext.load_model(path) + + def __call__(self, data): + label = data['label'] + fast_label = self.fast_model[label] + data['fast_label'] = fast_label + return data + + class KeepKeys(object): def __init__(self, keep_keys, **kwargs): self.keep_keys = keep_keys @@ -183,7 +195,7 @@ class DetResizeForTest(object): else: ratio = 1. elif self.limit_type == 'resize_long': - ratio = float(limit_side_len) / max(h,w) + ratio = float(limit_side_len) / max(h, w) else: raise Exception('not support limit type, image ') resize_h = int(h * ratio) diff --git a/ppocr/data/imaug/rec_img_aug.py b/ppocr/data/imaug/rec_img_aug.py index 28e6bd0b..ed5b7a52 100644 --- a/ppocr/data/imaug/rec_img_aug.py +++ b/ppocr/data/imaug/rec_img_aug.py @@ -63,6 +63,18 @@ class RecResizeImg(object): return data +class SEEDResize(object): + def __init__(self, image_shape, infer_mode=False, **kwargs): + self.image_shape = image_shape + self.infer_mode = infer_mode + + def __call__(self, data): + img = data['image'] + norm_img = resize_no_padding_img(img, self.image_shape) + data['image'] = norm_img + return data + + class SRNRecResizeImg(object): def __init__(self, image_shape, num_heads, max_text_length, **kwargs): self.image_shape = image_shape @@ -106,6 +118,17 @@ def resize_norm_img(img, image_shape): return padding_im +def resize_no_padding_img(img, image_shape): + imgC, imgH, imgW = image_shape + resized_image = cv2.resize( + img, (imgW, imgH), interpolation=cv2.INTER_LINEAR) + resized_image = resized_image.astype('float32') + resized_image = resized_image.transpose((2, 0, 1)) / 255 + resized_image -= 0.5 + resized_image /= 0.5 + return resized_image + + def resize_norm_img_chinese(img, image_shape): imgC, imgH, imgW = image_shape # todo: change to 0 and modified image shape diff --git a/ppocr/data/simple_dataset.py b/ppocr/data/simple_dataset.py index b519f4fd..ce9e1b38 100644 --- a/ppocr/data/simple_dataset.py +++ b/ppocr/data/simple_dataset.py @@ -22,7 +22,6 @@ from .imaug import transform, create_operators class SimpleDataSet(Dataset): def __init__(self, config, mode, logger, seed=None): - print("===== simpledataset ========") super(SimpleDataSet, self).__init__() self.logger = logger self.mode = mode.lower() diff --git a/ppocr/losses/rec_aster_loss.py b/ppocr/losses/rec_aster_loss.py index 858fadc0..d900617f 100644 --- a/ppocr/losses/rec_aster_loss.py +++ b/ppocr/losses/rec_aster_loss.py @@ -18,7 +18,26 @@ from __future__ import print_function import paddle from paddle import nn -import fasttext + + +class CosineEmbeddingLoss(nn.Layer): + def __init__(self, margin=0.): + super(CosineEmbeddingLoss, self).__init__() + self.margin = margin + + def forward(self, x1, x2, target): + similarity = paddle.fluid.layers.reduce_sum( + x1 * x2, dim=-1) / (paddle.norm( + x1, axis=-1) * paddle.norm( + x2, axis=-1)) + one_list = paddle.full_like(target, fill_value=1) + out = paddle.fluid.layers.reduce_mean( + paddle.where( + paddle.equal(target, one_list), 1. 
+                paddle.maximum(
+                    paddle.zeros_like(similarity), similarity - self.margin)))
+
+        return out
 
 
 class AsterLoss(nn.Layer):
@@ -35,28 +54,28 @@ class AsterLoss(nn.Layer):
         self.ignore_index = ignore_index
         self.sequence_normalize = sequence_normalize
         self.sample_normalize = sample_normalize
-        self.loss_func = paddle.nn.CosineSimilarity()
+        self.loss_sem = CosineEmbeddingLoss()
+        self.is_cosine_loss = True
+        self.loss_func_rec = nn.CrossEntropyLoss(weight=None, reduction='none')
 
     def forward(self, predicts, batch):
         targets = batch[1].astype("int64")
         label_lengths = batch[2].astype('int64')
-        # sem_target = batch[3].astype('float32')
+        sem_target = batch[3].astype('float32')
         embedding_vectors = predicts['embedding_vectors']
         rec_pred = predicts['rec_pred']
 
-        # semantic loss
-        # print(embedding_vectors)
-        # print(embedding_vectors.shape)
-        # targets = fasttext[targets]
-        # sem_loss = 1 - self.loss_func(embedding_vectors, targets)
+        if not self.is_cosine_loss:
+            sem_loss = paddle.sum(self.loss_sem(embedding_vectors, sem_target))
+        else:
+            label_target = paddle.ones([embedding_vectors.shape[0]])
+            sem_loss = paddle.sum(
+                self.loss_sem(embedding_vectors, sem_target, label_target))
 
         # rec loss
-        batch_size, num_steps, num_classes = rec_pred.shape[0], rec_pred.shape[
-            1], rec_pred.shape[2]
-        assert len(targets.shape) == len(list(rec_pred.shape)) - 1, \
-            "The target's shape and inputs's shape is [N, d] and [N, num_steps]"
+        batch_size, def_max_length = targets.shape[0], targets.shape[1]
 
-        mask = paddle.zeros([batch_size, num_steps])
+        mask = paddle.zeros([batch_size, def_max_length])
         for i in range(batch_size):
             mask[i, :label_lengths[i]] = 1
         mask = paddle.cast(mask, "float32")
@@ -64,16 +83,16 @@
         assert max_length == rec_pred.shape[1]
         targets = targets[:, :max_length]
         mask = mask[:, :max_length]
-        rec_pred = paddle.reshape(rec_pred, [-1, rec_pred.shape[-1]])
+        rec_pred = paddle.reshape(rec_pred, [-1, rec_pred.shape[2]])
         input = nn.functional.log_softmax(rec_pred, axis=1)
         targets = paddle.reshape(targets, [-1, 1])
         mask = paddle.reshape(mask, [-1, 1])
-        # print("input:", input)
-        output = -paddle.gather(input, index=targets, axis=1) * mask
+        output = -paddle.index_sample(input, index=targets) * mask
         output = paddle.sum(output)
         if self.sequence_normalize:
             output = output / paddle.sum(mask)
         if self.sample_normalize:
             output = output / batch_size
-        loss = output
-        return {'loss': loss}  # , 'sem_loss':sem_loss}
+
+        loss = output + sem_loss * 0.1
+        return {'loss': loss}
diff --git a/ppocr/losses/rec_att_loss.py b/ppocr/losses/rec_att_loss.py
index 2d8d64b9..6e2f6748 100644
--- a/ppocr/losses/rec_att_loss.py
+++ b/ppocr/losses/rec_att_loss.py
@@ -35,7 +35,5 @@ class AttentionLoss(nn.Layer):
         inputs = paddle.reshape(predicts, [-1, predicts.shape[-1]])
         targets = paddle.reshape(targets, [-1])
 
-        print("input:", paddle.argmax(inputs, axis=1))
-        print("targets:", targets)
         return {'loss': paddle.sum(self.loss_func(inputs, targets))}
diff --git a/ppocr/metrics/rec_metric.py b/ppocr/metrics/rec_metric.py
index 66c084d7..db2f41c3 100644
--- a/ppocr/metrics/rec_metric.py
+++ b/ppocr/metrics/rec_metric.py
@@ -13,13 +13,20 @@
 # limitations under the License.
 
 import Levenshtein
+import string
 
 
 class RecMetric(object):
-    def __init__(self, main_indicator='acc', **kwargs):
+    def __init__(self, main_indicator='acc', is_filter=False, **kwargs):
         self.main_indicator = main_indicator
+        self.is_filter = is_filter
         self.reset()
 
+    def _normalize_text(self, text):
+        text = ''.join(
+            filter(lambda x: x in (string.digits + string.ascii_letters), text))
+        return text.lower()
+
     def __call__(self, pred_label, *args, **kwargs):
         preds, labels = pred_label
         correct_num = 0
@@ -28,6 +35,9 @@ class RecMetric(object):
         for (pred, pred_conf), (target, _) in zip(preds, labels):
             pred = pred.replace(" ", "")
             target = target.replace(" ", "")
+            if self.is_filter:
+                pred = self._normalize_text(pred)
+                target = self._normalize_text(target)
             norm_edit_dis += Levenshtein.distance(pred, target) / max(
                 len(pred), len(target), 1)
             if pred == target:
diff --git a/ppocr/modeling/backbones/__init__.py b/ppocr/modeling/backbones/__init__.py
index e0bc45b4..25cedb16 100755
--- a/ppocr/modeling/backbones/__init__.py
+++ b/ppocr/modeling/backbones/__init__.py
@@ -26,10 +26,8 @@ def build_backbone(config, model_type):
         from .rec_resnet_vd import ResNet
         from .rec_resnet_fpn import ResNetFPN
         from .rec_mv1_enhance import MobileNetV1Enhance
-        from .rec_resnet_aster import ResNet_ASTER
         support_dict = [
-            "MobileNetV1Enhance", "MobileNetV3", "ResNet", "ResNetFPN",
-            "ResNet_ASTER"
+            "MobileNetV1Enhance", "MobileNetV3", "ResNet", "ResNetFPN"
         ]
     elif model_type == "e2e":
         from .e2e_resnet_vd_pg import ResNet
@@ -38,6 +36,9 @@ def build_backbone(config, model_type):
         from .table_resnet_vd import ResNet
         from .table_mobilenet_v3 import MobileNetV3
         support_dict = ["ResNet", "MobileNetV3"]
+    elif model_type == "seed":
+        from .rec_resnet_aster import ResNet_ASTER
+        support_dict = ["ResNet_ASTER"]
     else:
         raise NotImplementedError
 
diff --git a/ppocr/modeling/backbones/levit.py b/ppocr/modeling/backbones/levit.py
deleted file mode 100644
index 8b04e9de..00000000
--- a/ppocr/modeling/backbones/levit.py
+++ /dev/null
@@ -1,707 +0,0 @@
-# Copyright (c) 2015-present, Facebook, Inc.
-# All rights reserved.
- -# Modified from -# https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py -# Copyright 2020 Ross Wightman, Apache-2.0 License - -import paddle -import itertools -#import utils -import math -import warnings -import paddle.nn.functional as F -from paddle.nn.initializer import TruncatedNormal, Constant - -#from timm.models.vision_transformer import trunc_normal_ -#from timm.models.registry import register_model - -specification = { - 'LeViT_128S': { - 'C': '128_256_384', - 'D': 16, - 'N': '4_6_8', - 'X': '2_3_4', - 'drop_path': 0, - 'weights': - 'https://dl.fbaipublicfiles.com/LeViT/LeViT-128S-96703c44.pth' - }, - 'LeViT_128': { - 'C': '128_256_384', - 'D': 16, - 'N': '4_8_12', - 'X': '4_4_4', - 'drop_path': 0, - 'weights': 'https://dl.fbaipublicfiles.com/LeViT/LeViT-128-b88c2750.pth' - }, - 'LeViT_192': { - 'C': '192_288_384', - 'D': 32, - 'N': '3_5_6', - 'X': '4_4_4', - 'drop_path': 0, - 'weights': 'https://dl.fbaipublicfiles.com/LeViT/LeViT-192-92712e41.pth' - }, - 'LeViT_256': { - 'C': '256_384_512', - 'D': 32, - 'N': '4_6_8', - 'X': '4_4_4', - 'drop_path': 0, - 'weights': 'https://dl.fbaipublicfiles.com/LeViT/LeViT-256-13b5763e.pth' - }, - 'LeViT_384': { - 'C': '384_512_768', - 'D': 32, - 'N': '6_9_12', - 'X': '4_4_4', - 'drop_path': 0.1, - 'weights': 'https://dl.fbaipublicfiles.com/LeViT/LeViT-384-9bdaf2e2.pth' - }, -} - -__all__ = [specification.keys()] - -trunc_normal_ = TruncatedNormal(std=.02) -zeros_ = Constant(value=0.) -ones_ = Constant(value=1.) - - -#@register_model -def LeViT_128S(class_dim=1000, distillation=True, pretrained=False, fuse=False): - return model_factory( - **specification['LeViT_128S'], - class_dim=class_dim, - distillation=distillation, - pretrained=pretrained, - fuse=fuse) - - -#@register_model -def LeViT_128(class_dim=1000, distillation=True, pretrained=False, fuse=False): - return model_factory( - **specification['LeViT_128'], - class_dim=class_dim, - distillation=distillation, - pretrained=pretrained, - fuse=fuse) - - -#@register_model -def LeViT_192(class_dim=1000, distillation=True, pretrained=False, fuse=False): - return model_factory( - **specification['LeViT_192'], - class_dim=class_dim, - distillation=distillation, - pretrained=pretrained, - fuse=fuse) - - -#@register_model -def LeViT_256(class_dim=1000, distillation=False, pretrained=False, fuse=False): - return model_factory( - **specification['LeViT_256'], - class_dim=class_dim, - distillation=distillation, - pretrained=pretrained, - fuse=fuse) - - -#@register_model -def LeViT_384(class_dim=1000, distillation=True, pretrained=False, fuse=False): - return model_factory( - **specification['LeViT_384'], - class_dim=class_dim, - distillation=distillation, - pretrained=pretrained, - fuse=fuse) - - -FLOPS_COUNTER = 0 - - -class Conv2d_BN(paddle.nn.Sequential): - def __init__(self, - a, - b, - ks=1, - stride=1, - pad=0, - dilation=1, - groups=1, - bn_weight_init=1, - resolution=-10000): - super().__init__() - self.add_sublayer( - 'c', - paddle.nn.Conv2D( - a, b, ks, stride, pad, dilation, groups, bias_attr=False)) - bn = paddle.nn.BatchNorm2D(b) - ones_(bn.weight) - zeros_(bn.bias) - self.add_sublayer('bn', bn) - - global FLOPS_COUNTER - output_points = ( - (resolution + 2 * pad - dilation * (ks - 1) - 1) // stride + 1)**2 - FLOPS_COUNTER += a * b * output_points * (ks**2) - - @paddle.no_grad() - def fuse(self): - c, bn = self._modules.values() - w = bn.weight / (bn.running_var + bn.eps)**0.5 - w = c.weight * w[:, None, None, None] - b = bn.bias - 
bn.running_mean * bn.weight / \ - (bn.running_var + bn.eps)**0.5 - m = paddle.nn.Conv2D( - w.size(1), - w.size(0), - w.shape[2:], - stride=self.c.stride, - padding=self.c.padding, - dilation=self.c.dilation, - groups=self.c.groups) - m.weight.data.copy_(w) - m.bias.data.copy_(b) - return m - - -class Linear_BN(paddle.nn.Sequential): - def __init__(self, a, b, bn_weight_init=1, resolution=-100000): - super().__init__() - self.add_sublayer('c', paddle.nn.Linear(a, b, bias_attr=False)) - bn = paddle.nn.BatchNorm1D(b) - ones_(bn.weight) - zeros_(bn.bias) - self.add_sublayer('bn', bn) - - global FLOPS_COUNTER - output_points = resolution**2 - FLOPS_COUNTER += a * b * output_points - - @paddle.no_grad() - def fuse(self): - l, bn = self._modules.values() - w = bn.weight / (bn.running_var + bn.eps)**0.5 - w = l.weight * w[:, None] - b = bn.bias - bn.running_mean * bn.weight / \ - (bn.running_var + bn.eps)**0.5 - m = paddle.nn.Linear(w.size(1), w.size(0)) - m.weight.data.copy_(w) - m.bias.data.copy_(b) - return m - - def forward(self, x): - l, bn = self._sub_layers.values() - x = l(x) - return paddle.reshape(bn(x.flatten(0, 1)), x.shape) - - -class BN_Linear(paddle.nn.Sequential): - def __init__(self, a, b, bias=True, std=0.02): - super().__init__() - self.add_sublayer('bn', paddle.nn.BatchNorm1D(a)) - l = paddle.nn.Linear(a, b, bias_attr=bias) - trunc_normal_(l.weight) - if bias: - zeros_(l.bias) - self.add_sublayer('l', l) - global FLOPS_COUNTER - FLOPS_COUNTER += a * b - - @paddle.no_grad() - def fuse(self): - bn, l = self._modules.values() - w = bn.weight / (bn.running_var + bn.eps)**0.5 - b = bn.bias - self.bn.running_mean * \ - self.bn.weight / (bn.running_var + bn.eps)**0.5 - w = l.weight * w[None, :] - if l.bias is None: - b = b @self.l.weight.T - else: - b = (l.weight @b[:, None]).view(-1) + self.l.bias - m = paddle.nn.Linear(w.size(1), w.size(0)) - m.weight.data.copy_(w) - m.bias.data.copy_(b) - return m - - -def b16(n, activation, resolution=224): - return paddle.nn.Sequential( - Conv2d_BN( - 3, n // 8, 3, 2, 1, resolution=resolution), - activation(), - Conv2d_BN( - n // 8, n // 4, 3, 2, 1, resolution=resolution // 2), - activation(), - Conv2d_BN( - n // 4, n // 2, 3, 2, 1, resolution=resolution // 4), - activation(), - Conv2d_BN( - n // 2, n, 3, 2, 1, resolution=resolution // 8)) - - -class Residual(paddle.nn.Layer): - def __init__(self, m, drop): - super().__init__() - self.m = m - self.drop = drop - - def forward(self, x): - if self.training and self.drop > 0: - return x + self.m(x) * paddle.rand( - x.size(0), 1, 1, - device=x.device).ge_(self.drop).div(1 - self.drop).detach() - else: - return x + self.m(x) - - -class Attention(paddle.nn.Layer): - def __init__(self, - dim, - key_dim, - num_heads=8, - attn_ratio=4, - activation=None, - resolution=14): - super().__init__() - self.num_heads = num_heads - self.scale = key_dim**-0.5 - self.key_dim = key_dim - self.nh_kd = nh_kd = key_dim * num_heads - self.d = int(attn_ratio * key_dim) - self.dh = int(attn_ratio * key_dim) * num_heads - self.attn_ratio = attn_ratio - self.h = self.dh + nh_kd * 2 - self.qkv = Linear_BN(dim, self.h, resolution=resolution) - self.proj = paddle.nn.Sequential( - activation(), - Linear_BN( - self.dh, dim, bn_weight_init=0, resolution=resolution)) - points = list(itertools.product(range(resolution), range(resolution))) - N = len(points) - attention_offsets = {} - idxs = [] - for p1 in points: - for p2 in points: - offset = (abs(p1[0] - p2[0]), abs(p1[1] - p2[1])) - if offset not in attention_offsets: - 
attention_offsets[offset] = len(attention_offsets) - idxs.append(attention_offsets[offset]) - self.attention_biases = self.create_parameter( - shape=(num_heads, len(attention_offsets)), - default_initializer=zeros_) - tensor_idxs = paddle.to_tensor(idxs, dtype='int64') - self.register_buffer('attention_bias_idxs', - paddle.reshape(tensor_idxs, [N, N])) - - global FLOPS_COUNTER - #queries * keys - FLOPS_COUNTER += num_heads * (resolution**4) * key_dim - # softmax - FLOPS_COUNTER += num_heads * (resolution**4) - #attention * v - FLOPS_COUNTER += num_heads * self.d * (resolution**4) - - @paddle.no_grad() - def train(self, mode=True): - if mode: - super().train() - else: - super().eval() - if mode and hasattr(self, 'ab'): - del self.ab - else: - gather_list = [] - attention_bias_t = paddle.transpose(self.attention_biases, (1, 0)) - for idx in self.attention_bias_idxs: - gather = paddle.gather(attention_bias_t, idx) - gather_list.append(gather) - attention_biases = paddle.transpose( - paddle.concat(gather_list), (1, 0)).reshape( - (0, self.attention_bias_idxs.shape[0], - self.attention_bias_idxs.shape[1])) - self.ab = attention_biases - #self.ab = self.attention_biases[:, self.attention_bias_idxs] - - def forward(self, x): # x (B,N,C) - self.training = True - B, N, C = x.shape - qkv = self.qkv(x) - qkv = paddle.reshape(qkv, - [B, N, self.num_heads, self.h // self.num_heads]) - q, k, v = paddle.split( - qkv, [self.key_dim, self.key_dim, self.d], axis=3) - q = paddle.transpose(q, perm=[0, 2, 1, 3]) - k = paddle.transpose(k, perm=[0, 2, 1, 3]) - v = paddle.transpose(v, perm=[0, 2, 1, 3]) - k_transpose = paddle.transpose(k, perm=[0, 1, 3, 2]) - - if self.training: - gather_list = [] - attention_bias_t = paddle.transpose(self.attention_biases, (1, 0)) - for idx in self.attention_bias_idxs: - gather = paddle.gather(attention_bias_t, idx) - gather_list.append(gather) - attention_biases = paddle.transpose( - paddle.concat(gather_list), (1, 0)).reshape( - (0, self.attention_bias_idxs.shape[0], - self.attention_bias_idxs.shape[1])) - else: - attention_biases = self.ab - #np_ = paddle.to_tensor(self.attention_biases.numpy()[:, self.attention_bias_idxs.numpy()]) - #print(self.attention_bias_idxs.shape) - #print(attention_biases.shape) - #print(np_.shape) - #print(np_.equal(attention_biases)) - #exit() - - attn = ((q @k_transpose) * self.scale + attention_biases) - attn = F.softmax(attn) - x = paddle.transpose(attn @v, perm=[0, 2, 1, 3]) - x = paddle.reshape(x, [B, N, self.dh]) - x = self.proj(x) - return x - - -class Subsample(paddle.nn.Layer): - def __init__(self, stride, resolution): - super().__init__() - self.stride = stride - self.resolution = resolution - - def forward(self, x): - B, N, C = x.shape - x = paddle.reshape(x, [B, self.resolution, self.resolution, - C])[:, ::self.stride, ::self.stride] - x = paddle.reshape(x, [B, -1, C]) - return x - - -class AttentionSubsample(paddle.nn.Layer): - def __init__(self, - in_dim, - out_dim, - key_dim, - num_heads=8, - attn_ratio=2, - activation=None, - stride=2, - resolution=14, - resolution_=7): - super().__init__() - self.num_heads = num_heads - self.scale = key_dim**-0.5 - self.key_dim = key_dim - self.nh_kd = nh_kd = key_dim * num_heads - self.d = int(attn_ratio * key_dim) - self.dh = int(attn_ratio * key_dim) * self.num_heads - self.attn_ratio = attn_ratio - self.resolution_ = resolution_ - self.resolution_2 = resolution_**2 - self.training = True - h = self.dh + nh_kd - self.kv = Linear_BN(in_dim, h, resolution=resolution) - - self.q = 
paddle.nn.Sequential( - Subsample(stride, resolution), - Linear_BN( - in_dim, nh_kd, resolution=resolution_)) - self.proj = paddle.nn.Sequential( - activation(), Linear_BN( - self.dh, out_dim, resolution=resolution_)) - - self.stride = stride - self.resolution = resolution - points = list(itertools.product(range(resolution), range(resolution))) - points_ = list( - itertools.product(range(resolution_), range(resolution_))) - - N = len(points) - N_ = len(points_) - attention_offsets = {} - idxs = [] - i = 0 - j = 0 - for p1 in points_: - i += 1 - for p2 in points: - j += 1 - size = 1 - offset = (abs(p1[0] * stride - p2[0] + (size - 1) / 2), - abs(p1[1] * stride - p2[1] + (size - 1) / 2)) - if offset not in attention_offsets: - attention_offsets[offset] = len(attention_offsets) - idxs.append(attention_offsets[offset]) - self.attention_biases = self.create_parameter( - shape=(num_heads, len(attention_offsets)), - default_initializer=zeros_) - - tensor_idxs_ = paddle.to_tensor(idxs, dtype='int64') - self.register_buffer('attention_bias_idxs', - paddle.reshape(tensor_idxs_, [N_, N])) - - global FLOPS_COUNTER - #queries * keys - FLOPS_COUNTER += num_heads * \ - (resolution**2) * (resolution_**2) * key_dim - # softmax - FLOPS_COUNTER += num_heads * (resolution**2) * (resolution_**2) - #attention * v - FLOPS_COUNTER += num_heads * \ - (resolution**2) * (resolution_**2) * self.d - - @paddle.no_grad() - def train(self, mode=True): - if mode: - super().train() - else: - super().eval() - if mode and hasattr(self, 'ab'): - del self.ab - else: - gather_list = [] - attention_bias_t = paddle.transpose(self.attention_biases, (1, 0)) - for idx in self.attention_bias_idxs: - gather = paddle.gather(attention_bias_t, idx) - gather_list.append(gather) - attention_biases = paddle.transpose( - paddle.concat(gather_list), (1, 0)).reshape( - (0, self.attention_bias_idxs.shape[0], - self.attention_bias_idxs.shape[1])) - self.ab = attention_biases - #self.ab = self.attention_biases[:, self.attention_bias_idxs] - - def forward(self, x): - self.training = True - B, N, C = x.shape - kv = self.kv(x) - kv = paddle.reshape(kv, [B, N, self.num_heads, -1]) - k, v = paddle.split(kv, [self.key_dim, self.d], axis=3) - k = paddle.transpose(k, perm=[0, 2, 1, 3]) # BHNC - v = paddle.transpose(v, perm=[0, 2, 1, 3]) - q = paddle.reshape( - self.q(x), [B, self.resolution_2, self.num_heads, self.key_dim]) - q = paddle.transpose(q, perm=[0, 2, 1, 3]) - - if self.training: - gather_list = [] - attention_bias_t = paddle.transpose(self.attention_biases, (1, 0)) - for idx in self.attention_bias_idxs: - gather = paddle.gather(attention_bias_t, idx) - gather_list.append(gather) - attention_biases = paddle.transpose( - paddle.concat(gather_list), (1, 0)).reshape( - (0, self.attention_bias_idxs.shape[0], - self.attention_bias_idxs.shape[1])) - else: - attention_biases = self.ab - - attn = (q @paddle.transpose( - k, perm=[0, 1, 3, 2])) * self.scale + attention_biases - attn = F.softmax(attn) - - x = paddle.reshape( - paddle.transpose( - (attn @v), perm=[0, 2, 1, 3]), [B, -1, self.dh]) - x = self.proj(x) - return x - - -class LeViT(paddle.nn.Layer): - """ Vision Transformer with support for patch or hybrid CNN input stage - """ - - def __init__(self, - img_size=224, - patch_size=16, - in_chans=3, - class_dim=1000, - embed_dim=[192], - key_dim=[64], - depth=[12], - num_heads=[3], - attn_ratio=[2], - mlp_ratio=[2], - hybrid_backbone=None, - down_ops=[], - attention_activation=paddle.nn.Hardswish, - mlp_activation=paddle.nn.Hardswish, - 
distillation=True, - drop_path=0): - super().__init__() - global FLOPS_COUNTER - - self.class_dim = class_dim - self.num_features = embed_dim[-1] - self.embed_dim = embed_dim - self.distillation = distillation - - self.patch_embed = hybrid_backbone - - self.blocks = [] - down_ops.append(['']) - resolution = img_size // patch_size - for i, (ed, kd, dpth, nh, ar, mr, do) in enumerate( - zip(embed_dim, key_dim, depth, num_heads, attn_ratio, mlp_ratio, - down_ops)): - for _ in range(dpth): - self.blocks.append( - Residual( - Attention( - ed, - kd, - nh, - attn_ratio=ar, - activation=attention_activation, - resolution=resolution, ), - drop_path)) - if mr > 0: - h = int(ed * mr) - self.blocks.append( - Residual( - paddle.nn.Sequential( - Linear_BN( - ed, h, resolution=resolution), - mlp_activation(), - Linear_BN( - h, - ed, - bn_weight_init=0, - resolution=resolution), ), - drop_path)) - if do[0] == 'Subsample': - #('Subsample',key_dim, num_heads, attn_ratio, mlp_ratio, stride) - resolution_ = (resolution - 1) // do[5] + 1 - self.blocks.append( - AttentionSubsample( - *embed_dim[i:i + 2], - key_dim=do[1], - num_heads=do[2], - attn_ratio=do[3], - activation=attention_activation, - stride=do[5], - resolution=resolution, - resolution_=resolution_)) - resolution = resolution_ - if do[4] > 0: # mlp_ratio - h = int(embed_dim[i + 1] * do[4]) - self.blocks.append( - Residual( - paddle.nn.Sequential( - Linear_BN( - embed_dim[i + 1], h, resolution=resolution), - mlp_activation(), - Linear_BN( - h, - embed_dim[i + 1], - bn_weight_init=0, - resolution=resolution), ), - drop_path)) - self.blocks = paddle.nn.Sequential(*self.blocks) - - # Classifier head - self.head = BN_Linear( - embed_dim[-1], class_dim) if class_dim > 0 else paddle.nn.Identity() - if distillation: - self.head_dist = BN_Linear( - embed_dim[-1], - class_dim) if class_dim > 0 else paddle.nn.Identity() - - self.FLOPS = FLOPS_COUNTER - FLOPS_COUNTER = 0 - - def no_weight_decay(self): - return {x for x in self.state_dict().keys() if 'attention_biases' in x} - - def forward(self, x): - x = self.patch_embed(x) - x = x.flatten(2) - x = paddle.transpose(x, perm=[0, 2, 1]) - x = self.blocks(x) - x = x.mean(1) - if self.distillation: - x = self.head(x), self.head_dist(x) - if not self.training: - x = (x[0] + x[1]) / 2 - else: - x = self.head(x) - return x - - -def model_factory(C, D, X, N, drop_path, weights, class_dim, distillation, - pretrained, fuse): - embed_dim = [int(x) for x in C.split('_')] - num_heads = [int(x) for x in N.split('_')] - depth = [int(x) for x in X.split('_')] - act = paddle.nn.Hardswish - model = LeViT( - patch_size=16, - embed_dim=embed_dim, - num_heads=num_heads, - key_dim=[D] * 3, - depth=depth, - attn_ratio=[2, 2, 2], - mlp_ratio=[2, 2, 2], - down_ops=[ - #('Subsample',key_dim, num_heads, attn_ratio, mlp_ratio, stride) - ['Subsample', D, embed_dim[0] // D, 4, 2, 2], - ['Subsample', D, embed_dim[1] // D, 4, 2, 2], - ], - attention_activation=act, - mlp_activation=act, - hybrid_backbone=b16(embed_dim[0], activation=act), - class_dim=class_dim, - drop_path=drop_path, - distillation=distillation) - # if pretrained: - # checkpoint = torch.hub.load_state_dict_from_url( - # weights, map_location='cpu') - # model.load_state_dict(checkpoint['model']) - if fuse: - utils.replace_batchnorm(model) - - return model - - -if __name__ == '__main__': - ''' - import torch - checkpoint = torch.load('../LeViT/pretrained256.pth') - torch_dict = checkpoint['net'] - paddle_dict = {} - fc_names = ["c.weight", "l.weight", "qkv.weight", "fc1.weight", 
"fc2.weight", "downsample.reduction.weight", "head.weight", "attn.proj.weight"] - rename_dict = {"running_mean": "_mean", "running_var": "_variance"} - range_tuple = (0, 502) - idx = 0 - for key in torch_dict: - idx += 1 - weight = torch_dict[key].cpu().numpy() - flag = [i in key for i in fc_names] - if any(flag): - if "emb" not in key: - print("weight {} need to be trans".format(key)) - weight = weight.transpose() - key = key.replace("running_mean", "_mean") - key = key.replace("running_var", "_variance") - paddle_dict[key]=weight - ''' - import numpy as np - net = globals()['LeViT_256'](fuse=False, - pretrained=False, - distillation=False) - load_layer_state_dict = paddle.load( - "./LeViT_256_official_nodistillation_paddle.pdparams") - #net.set_state_dict(paddle_dict) - net.set_state_dict(load_layer_state_dict) - net.eval() - #paddle.save(net.state_dict(), "./LeViT_256_official_paddle.pdparams") - #model = paddle.jit.to_static(net,input_spec=[paddle.static.InputSpec(shape=[None, 3, 224, 224], dtype='float32')]) - #paddle.jit.save(model, "./LeViT_256_official_inference/inference") - #exit() - np.random.seed(123) - img = np.random.rand(1, 3, 224, 224).astype('float32') - img = paddle.to_tensor(img) - outputs = net(img).numpy() - print(outputs[0][:10]) - #print(outputs.shape) diff --git a/ppocr/modeling/heads/__init__.py b/ppocr/modeling/heads/__init__.py index cd923d78..c04ff81a 100755 --- a/ppocr/modeling/heads/__init__.py +++ b/ppocr/modeling/heads/__init__.py @@ -42,6 +42,5 @@ def build_head(config): module_name = config.pop('name') assert module_name in support_dict, Exception('head only support {}'.format( support_dict)) - print(config) module_class = eval(module_name)(**config) return module_class diff --git a/ppocr/modeling/heads/rec_aster_head.py b/ppocr/modeling/heads/rec_aster_head.py index 055b1097..ed520669 100644 --- a/ppocr/modeling/heads/rec_aster_head.py +++ b/ppocr/modeling/heads/rec_aster_head.py @@ -43,13 +43,14 @@ class AsterHead(nn.Layer): self.time_step = time_step self.embeder = Embedding(self.time_step, in_channels) self.beam_width = beam_width + self.eos = self.num_classes - 1 def forward(self, x, targets=None, embed=None): return_dict = {} embedding_vectors = self.embeder(x) - rec_targets, rec_lengths = targets if self.training: + rec_targets, rec_lengths, _ = targets rec_pred = self.decoder([x, rec_targets, rec_lengths], embedding_vectors) return_dict['rec_pred'] = rec_pred @@ -104,14 +105,12 @@ class AttentionRecognitionHead(nn.Layer): # Decoder state = self.decoder.get_initial_state(embed) outputs = [] - for i in range(max(lengths)): if i == 0: y_prev = paddle.full( shape=[batch_size], fill_value=self.num_classes) else: y_prev = targets[:, i - 1] - output, state = self.decoder(x, state, y_prev) outputs.append(output) outputs = paddle.concat([_.unsqueeze(1) for _ in outputs], 1) @@ -142,6 +141,170 @@ class AttentionRecognitionHead(nn.Layer): # return predicted_ids.squeeze(), predicted_scores.squeeze() return predicted_ids, predicted_scores + def beam_search(self, x, beam_width, eos, embed): + def _inflate(tensor, times, dim): + repeat_dims = [1] * tensor.dim() + repeat_dims[dim] = times + output = paddle.tile(tensor, repeat_dims) + return output + + # https://github.com/IBM/pytorch-seq2seq/blob/fede87655ddce6c94b38886089e05321dc9802af/seq2seq/models/TopKDecoder.py + batch_size, l, d = x.shape + # inflated_encoder_feats = _inflate(encoder_feats, beam_width, 0) # ABC --> AABBCC -/-> ABCABC + x = paddle.tile( + paddle.transpose( + x.unsqueeze(1), perm=[1, 0, 2, 
+        inflated_encoder_feats = paddle.reshape(
+            paddle.transpose(
+                x, perm=[1, 0, 2, 3]), [-1, l, d])
+
+        # Initialize the decoder
+        state = self.decoder.get_initial_state(embed, tile_times=beam_width)
+
+        pos_index = paddle.reshape(
+            paddle.arange(batch_size) * beam_width, shape=[-1, 1])
+
+        # Initialize the scores
+        sequence_scores = paddle.full(
+            shape=[batch_size * beam_width, 1], fill_value=-float('Inf'))
+        index = [i * beam_width for i in range(0, batch_size)]
+        sequence_scores[index] = 0.0
+
+        # Initialize the input vector
+        y_prev = paddle.full(
+            shape=[batch_size * beam_width], fill_value=self.num_classes)
+
+        # Store decisions for backtracking
+        stored_scores = list()
+        stored_predecessors = list()
+        stored_emitted_symbols = list()
+
+        for i in range(self.max_len_labels):
+            output, state = self.decoder(inflated_encoder_feats, state, y_prev)
+            state = paddle.unsqueeze(state, axis=0)
+            log_softmax_output = paddle.nn.functional.log_softmax(
+                output, axis=1)
+
+            sequence_scores = _inflate(sequence_scores, self.num_classes, 1)
+            sequence_scores += log_softmax_output
+            scores, candidates = paddle.topk(
+                paddle.reshape(sequence_scores, [batch_size, -1]),
+                beam_width,
+                axis=1)
+
+            # Reshape input = (bk, 1) and sequence_scores = (bk, 1)
+            y_prev = paddle.reshape(
+                candidates % self.num_classes, shape=[batch_size * beam_width])
+            sequence_scores = paddle.reshape(
+                scores, shape=[batch_size * beam_width, 1])
+
+            # Update fields for next timestep
+            pos_index = paddle.expand_as(pos_index, candidates)
+            predecessors = paddle.cast(
+                candidates / self.num_classes + pos_index, dtype='int64')
+            predecessors = paddle.reshape(
+                predecessors, shape=[batch_size * beam_width, 1])
+            state = paddle.index_select(
+                state, index=predecessors.squeeze(), axis=1)
+
+            # Update sequence scores and erase scores for symbol so that they aren't expanded
+            stored_scores.append(sequence_scores.clone())
+            y_prev = paddle.reshape(y_prev, shape=[-1, 1])
+            eos_prev = paddle.full_like(y_prev, fill_value=eos)
+            mask = eos_prev == y_prev
+            mask = paddle.nonzero(mask)
+            if mask.dim() > 0:
+                sequence_scores = sequence_scores.numpy()
+                mask = mask.numpy()
+                sequence_scores[mask] = -float('inf')
+                sequence_scores = paddle.to_tensor(sequence_scores)
+
+            # Cache results for backtracking
+            stored_predecessors.append(predecessors)
+            y_prev = paddle.squeeze(y_prev)
+            stored_emitted_symbols.append(y_prev)
+
+        # Do backtracking to return the optimal values
+        #====== backtrack ======#
+        # Initialize return variables given different types
+        p = list()
+        l = [[self.max_len_labels] * beam_width for _ in range(batch_size)
+             ]  # Placeholder for lengths of top-k sequences
+
+        # the last step output of the beams are not sorted
+        # thus they are sorted here
+        sorted_score, sorted_idx = paddle.topk(
+            paddle.reshape(
+                stored_scores[-1], shape=[batch_size, beam_width]),
+            beam_width)
+
+        # initialize the sequence scores with the sorted last step beam scores
+        s = sorted_score.clone()
+
+        batch_eos_found = [0] * batch_size  # the number of EOS found
+        # in the backward loop below for each batch
+        t = self.max_len_labels - 1
+        # initialize the back pointer with the sorted order of the last step beams.
+        # add pos_index for indexing variable with b*k as the first dimension.
+        t_predecessors = paddle.reshape(
+            sorted_idx + pos_index.expand_as(sorted_idx),
+            shape=[batch_size * beam_width])
+        while t >= 0:
+            # Re-order the variables with the back pointer
+            current_symbol = paddle.index_select(
+                stored_emitted_symbols[t], index=t_predecessors, axis=0)
+            t_predecessors = paddle.index_select(
+                stored_predecessors[t].squeeze(), index=t_predecessors, axis=0)
+            eos_indices = stored_emitted_symbols[t] == eos
+            eos_indices = paddle.nonzero(eos_indices)
+
+            if eos_indices.dim() > 0:
+                for i in range(eos_indices.shape[0] - 1, -1, -1):
+                    # Indices of the EOS symbol for both variables
+                    # with b*k as the first dimension, and b, k for
+                    # the first two dimensions
+                    idx = eos_indices[i]
+                    b_idx = int(idx[0] / beam_width)
+                    # The indices of the replacing position
+                    # according to the replacement strategy noted above
+                    res_k_idx = beam_width - (batch_eos_found[b_idx] %
+                                              beam_width) - 1
+                    batch_eos_found[b_idx] += 1
+                    res_idx = b_idx * beam_width + res_k_idx
+
+                    # Replace the old information in return variables
+                    # with the new ended sequence information
+                    t_predecessors[res_idx] = stored_predecessors[t][idx[0]]
+                    current_symbol[res_idx] = stored_emitted_symbols[t][idx[0]]
+                    s[b_idx, res_k_idx] = stored_scores[t][idx[0], 0]
+                    l[b_idx][res_k_idx] = t + 1
+
+            # record the back tracked results
+            p.append(current_symbol)
+            t -= 1
+
+        # Sort and re-order again as the added ended sequences may change
+        # the order (very unlikely)
+        s, re_sorted_idx = s.topk(beam_width)
+        for b_idx in range(batch_size):
+            l[b_idx] = [
+                l[b_idx][k_idx.item()] for k_idx in re_sorted_idx[b_idx, :]
+            ]
+
+        re_sorted_idx = paddle.reshape(
+            re_sorted_idx + pos_index.expand_as(re_sorted_idx),
+            [batch_size * beam_width])
+
+        # Reverse the sequences and re-order at the same time
+        # It is reversed because the backtracking happens in reverse time order
+        p = [
+            paddle.reshape(
+                paddle.index_select(step, re_sorted_idx, 0),
+                shape=[batch_size, beam_width, -1]) for step in reversed(p)
+        ]
+        p = paddle.concat(p, -1)[:, 0, :]
+        return p, paddle.ones_like(p)
+
 
 class AttentionUnit(nn.Layer):
     def __init__(self, sDim, xDim, attDim):
@@ -151,21 +314,9 @@
         self.xDim = xDim
         self.attDim = attDim
 
-        self.sEmbed = nn.Linear(
-            sDim,
-            attDim,
-            weight_attr=paddle.nn.initializer.Normal(std=0.01),
-            bias_attr=paddle.nn.initializer.Constant(0.0))
-        self.xEmbed = nn.Linear(
-            xDim,
-            attDim,
-            weight_attr=paddle.nn.initializer.Normal(std=0.01),
-            bias_attr=paddle.nn.initializer.Constant(0.0))
-        self.wEmbed = nn.Linear(
-            attDim,
-            1,
-            weight_attr=paddle.nn.initializer.Normal(std=0.01),
-            bias_attr=paddle.nn.initializer.Constant(0.0))
+        self.sEmbed = nn.Linear(sDim, attDim)
+        self.xEmbed = nn.Linear(xDim, attDim)
+        self.wEmbed = nn.Linear(attDim, 1)
 
     def forward(self, x, sPrev):
         batch_size, T, _ = x.shape  # [b x T x xDim]
@@ -184,10 +335,8 @@
         vProj = self.wEmbed(sumTanh)  # [(b x T) x 1]
         vProj = paddle.reshape(vProj, [batch_size, T])
-
         alpha = F.softmax(
             vProj, axis=1)  # attention weights for each sample in the minibatch
-
         return alpha
 
 
@@ -238,21 +387,4 @@ class DecoderUnit(nn.Layer):
         output, state = self.gru(concat_context, sPrev)
         output = paddle.squeeze(output, axis=1)
         output = self.fc(output)
-        return output, state
-
-
-if __name__ == "__main__":
-    model = AttentionRecognitionHead(
-        num_classes=20,
-        in_channels=30,
-        sDim=512,
-        attDim=512,
-        max_len_labels=25,
-        out_channels=38)
-
-    data = paddle.ones([16, 64, 3])
-    targets = paddle.ones([16, 25])
-    length = paddle.to_tensor(20)
-    x = [data, targets, length]
-    output = model(x)
-    print(output.shape)
+        return output, state
\ No newline at end of file
diff --git a/ppocr/modeling/heads/rec_att_head.py b/ppocr/modeling/heads/rec_att_head.py
index 79f112f7..4286d769 100644
--- a/ppocr/modeling/heads/rec_att_head.py
+++ b/ppocr/modeling/heads/rec_att_head.py
@@ -44,13 +44,10 @@ class AttentionHead(nn.Layer):
         hidden = paddle.zeros((batch_size, self.hidden_size))
         output_hiddens = []
 
-        targets = targets[0]
-        print(targets)
         if targets is not None:
             for i in range(num_steps):
                 char_onehots = self._char_to_onehot(
                     targets[:, i], onehot_dim=self.num_classes)
-                # print("char_onehots:", char_onehots)
                 (outputs, hidden), alpha = self.attention_cell(hidden, inputs,
                                                                char_onehots)
                 output_hiddens.append(paddle.unsqueeze(outputs, axis=1))
@@ -107,8 +104,6 @@ class AttentionGRUCell(nn.Layer):
         alpha = paddle.transpose(alpha, [0, 2, 1])
         context = paddle.squeeze(paddle.mm(alpha, batch_H), axis=1)
         concat_context = paddle.concat([context, char_onehots], 1)
-        # print("concat_context:", concat_context.shape)
-        # print("prev_hidden:", prev_hidden.shape)
 
         cur_hidden = self.rnn(concat_context, prev_hidden)
 
diff --git a/ppocr/modeling/transforms/stn.py b/ppocr/modeling/transforms/stn.py
index 0b26e27a..23bd2189 100644
--- a/ppocr/modeling/transforms/stn.py
+++ b/ppocr/modeling/transforms/stn.py
@@ -106,16 +106,3 @@ class STN(nn.Layer):
             x = F.sigmoid(x)
         x = paddle.reshape(x, shape=[-1, self.num_ctrlpoints, 2])
         return img_feat, x
-
-
-if __name__ == "__main__":
-    in_planes = 3
-    num_ctrlpoints = 20
-    np.random.seed(100)
-    activation = 'none'  # 'sigmoid'
-    stn_head = STN(in_planes, num_ctrlpoints, activation)
-    data = np.random.randn(10, 3, 32, 64).astype("float32")
-    print("data:", np.sum(data))
-    input = paddle.to_tensor(data)
-    #input = paddle.randn([10, 3, 32, 64])
-    control_points = stn_head(input)
diff --git a/ppocr/modeling/transforms/tps.py b/ppocr/modeling/transforms/tps.py
index fc462100..de4bb7a6 100644
--- a/ppocr/modeling/transforms/tps.py
+++ b/ppocr/modeling/transforms/tps.py
@@ -326,5 +326,6 @@ class STN_ON(nn.Layer):
             image, self.tps_inputsize, mode="bilinear", align_corners=True)
         stn_img_feat, ctrl_points = self.stn_head(stn_input)
         x, _ = self.tps(image, ctrl_points)
+        #print("x:", np.sum(x.numpy()))
         # print(x.shape)
         return x
diff --git a/ppocr/modeling/transforms/tps_spatial_transformer.py b/ppocr/modeling/transforms/tps_spatial_transformer.py
index da54ffb7..731e3ee9 100644
--- a/ppocr/modeling/transforms/tps_spatial_transformer.py
+++ b/ppocr/modeling/transforms/tps_spatial_transformer.py
@@ -136,7 +136,8 @@ class TPSSpatialTransformer(nn.Layer):
         assert source_control_points.ndimension() == 3
         assert source_control_points.shape[1] == self.num_control_points
         assert source_control_points.shape[2] == 2
-        batch_size = source_control_points.shape[0]
+        #batch_size = source_control_points.shape[0]
+        batch_size = paddle.shape(source_control_points)[0]
         self.padding_matrix = paddle.expand(
             self.padding_matrix, shape=[batch_size, 3, 2])
 
@@ -151,28 +152,6 @@ class TPSSpatialTransformer(nn.Layer):
         grid = paddle.clip(grid, 0, 1)  # the source_control_points may be out of [0, 1].
# the input to grid_sample is normalized [-1, 1], but what we get is [0, 1] - # grid = 2.0 * grid - 1.0 + grid = 2.0 * grid - 1.0 output_maps = grid_sample(input, grid, canvas=None) return output_maps, source_coordinate - - -if __name__ == "__main__": - from stn import STN - in_planes = 3 - num_ctrlpoints = 20 - np.random.seed(100) - activation = 'none' # 'sigmoid' - stn_head = STN(in_planes, num_ctrlpoints, activation) - data = np.random.randn(10, 3, 32, 64).astype("float32") - input = paddle.to_tensor(data) - #input = paddle.randn([10, 3, 32, 64]) - control_points = stn_head(input) - #print("control points:", control_points) - #input = paddle.randn(shape=[10,3,32,100]) - tps = TPSSpatialTransformer( - output_image_size=[32, 320], - num_control_points=20, - margins=[0.05, 0.05]) - out = tps(input, control_points[1]) - print("out 0 :", out[0].shape) - print("out 1:", out[1].shape) diff --git a/ppocr/modeling/transforms/tps_torch.py b/ppocr/modeling/transforms/tps_torch.py deleted file mode 100644 index 7aee133a..00000000 --- a/ppocr/modeling/transforms/tps_torch.py +++ /dev/null @@ -1,149 +0,0 @@ -from __future__ import absolute_import - -import numpy as np -import itertools - -import torch -import torch.nn as nn -import torch.nn.functional as F - - -def grid_sample(input, grid, canvas=None): - output = F.grid_sample(input, grid) - if canvas is None: - return output - else: - input_mask = input.data.new(input.size()).fill_(1) - output_mask = F.grid_sample(input_mask, grid) - padded_output = output * output_mask + canvas * (1 - output_mask) - return padded_output - - -# phi(x1, x2) = r^2 * log(r), where r = ||x1 - x2||_2 -def compute_partial_repr(input_points, control_points): - N = input_points.size(0) - M = control_points.size(0) - pairwise_diff = input_points.view(N, 1, 2) - control_points.view(1, M, 2) - # original implementation, very slow - # pairwise_dist = torch.sum(pairwise_diff ** 2, dim = 2) # square of distance - pairwise_diff_square = pairwise_diff * pairwise_diff - pairwise_dist = pairwise_diff_square[:, :, 0] + pairwise_diff_square[:, :, - 1] - repr_matrix = 0.5 * pairwise_dist * torch.log(pairwise_dist) - # fix numerical error for 0 * log(0), substitute all nan with 0 - mask = repr_matrix != repr_matrix - repr_matrix.masked_fill_(mask, 0) - return repr_matrix - - -# output_ctrl_pts are specified, according to our task. 
-def build_output_control_points(num_control_points, margins): - margin_x, margin_y = margins - num_ctrl_pts_per_side = num_control_points // 2 - ctrl_pts_x = np.linspace(margin_x, 1.0 - margin_x, num_ctrl_pts_per_side) - ctrl_pts_y_top = np.ones(num_ctrl_pts_per_side) * margin_y - ctrl_pts_y_bottom = np.ones(num_ctrl_pts_per_side) * (1.0 - margin_y) - ctrl_pts_top = np.stack([ctrl_pts_x, ctrl_pts_y_top], axis=1) - ctrl_pts_bottom = np.stack([ctrl_pts_x, ctrl_pts_y_bottom], axis=1) - # ctrl_pts_top = ctrl_pts_top[1:-1,:] - # ctrl_pts_bottom = ctrl_pts_bottom[1:-1,:] - output_ctrl_pts_arr = np.concatenate( - [ctrl_pts_top, ctrl_pts_bottom], axis=0) - output_ctrl_pts = torch.Tensor(output_ctrl_pts_arr) - return output_ctrl_pts - - -# demo: ~/test/models/test_tps_transformation.py -class TPSSpatialTransformer(nn.Module): - def __init__(self, - output_image_size=None, - num_control_points=None, - margins=None): - super(TPSSpatialTransformer, self).__init__() - self.output_image_size = output_image_size - self.num_control_points = num_control_points - self.margins = margins - - self.target_height, self.target_width = output_image_size - target_control_points = build_output_control_points(num_control_points, - margins) - N = num_control_points - # N = N - 4 - - # create padded kernel matrix - forward_kernel = torch.zeros(N + 3, N + 3) - target_control_partial_repr = compute_partial_repr( - target_control_points, target_control_points) - forward_kernel[:N, :N].copy_(target_control_partial_repr) - forward_kernel[:N, -3].fill_(1) - forward_kernel[-3, :N].fill_(1) - forward_kernel[:N, -2:].copy_(target_control_points) - forward_kernel[-2:, :N].copy_(target_control_points.transpose(0, 1)) - # compute inverse matrix - inverse_kernel = torch.inverse(forward_kernel) - - # create target cordinate matrix - HW = self.target_height * self.target_width - target_coordinate = list( - itertools.product( - range(self.target_height), range(self.target_width))) - target_coordinate = torch.Tensor(target_coordinate) # HW x 2 - Y, X = target_coordinate.split(1, dim=1) - Y = Y / (self.target_height - 1) - X = X / (self.target_width - 1) - target_coordinate = torch.cat([X, Y], - dim=1) # convert from (y, x) to (x, y) - target_coordinate_partial_repr = compute_partial_repr( - target_coordinate, target_control_points) - target_coordinate_repr = torch.cat([ - target_coordinate_partial_repr, torch.ones(HW, 1), target_coordinate - ], - dim=1) - - # register precomputed matrices - self.register_buffer('inverse_kernel', inverse_kernel) - self.register_buffer('padding_matrix', torch.zeros(3, 2)) - self.register_buffer('target_coordinate_repr', target_coordinate_repr) - self.register_buffer('target_control_points', target_control_points) - - def forward(self, input, source_control_points): - assert source_control_points.ndimension() == 3 - assert source_control_points.size(1) == self.num_control_points - assert source_control_points.size(2) == 2 - batch_size = source_control_points.size(0) - - Y = torch.cat([ - source_control_points, self.padding_matrix.expand(batch_size, 3, 2) - ], 1) - mapping_matrix = torch.matmul(self.inverse_kernel, Y) - source_coordinate = torch.matmul(self.target_coordinate_repr, - mapping_matrix) - - grid = source_coordinate.view(-1, self.target_height, self.target_width, - 2) - grid = torch.clamp(grid, 0, - 1) # the source_control_points may be out of [0, 1]. 
-        # the input to grid_sample is normalized [-1, 1], but what we get is [0, 1]
-        grid = 2.0 * grid - 1.0
-        output_maps = grid_sample(input, grid, canvas=None)
-        return output_maps, source_coordinate
-
-
-if __name__ == "__main__":
-    from stn_torch import STNHead
-    in_planes = 3
-    num_ctrlpoints = 20
-    torch.manual_seed(10)
-    activation = 'none'  # 'sigmoid'
-    stn_head = STNHead(in_planes, num_ctrlpoints, activation)
-    np.random.seed(100)
-    data = np.random.randn(10, 3, 32, 64).astype("float32")
-    input = torch.tensor(data)
-    control_points = stn_head(input)
-    tps = TPSSpatialTransformer(
-        output_image_size=[32, 320],
-        num_control_points=20,
-        margins=[0.05, 0.05])
-    out = tps(input, control_points[1])
-    print("out 0 :", out[0].shape)
-    print("out 1:", out[1].shape)
diff --git a/ppocr/optimizer/optimizer.py b/ppocr/optimizer/optimizer.py
index 8215b92d..34098c0f 100644
--- a/ppocr/optimizer/optimizer.py
+++ b/ppocr/optimizer/optimizer.py
@@ -127,3 +127,34 @@ class RMSProp(object):
             grad_clip=self.grad_clip,
             parameters=parameters)
         return opt
+
+
+class Adadelta(object):
+    def __init__(self,
+                 learning_rate=0.001,
+                 epsilon=1e-08,
+                 rho=0.95,
+                 parameter_list=None,
+                 weight_decay=None,
+                 grad_clip=None,
+                 name=None,
+                 **kwargs):
+        self.learning_rate = learning_rate
+        self.epsilon = epsilon
+        self.rho = rho
+        self.parameter_list = parameter_list
+        self.learning_rate = learning_rate
+        self.weight_decay = weight_decay
+        self.grad_clip = grad_clip
+        self.name = name
+
+    def __call__(self, parameters):
+        opt = optim.Adadelta(
+            learning_rate=self.learning_rate,
+            epsilon=self.epsilon,
+            rho=self.rho,
+            weight_decay=self.weight_decay,
+            grad_clip=self.grad_clip,
+            name=self.name,
+            parameters=parameters)
+        return opt
diff --git a/ppocr/postprocess/__init__.py b/ppocr/postprocess/__init__.py
index 2f5bdc3b..ba7e06db 100644
--- a/ppocr/postprocess/__init__.py
+++ b/ppocr/postprocess/__init__.py
@@ -25,7 +25,7 @@ from .db_postprocess import DBPostProcess
 from .east_postprocess import EASTPostProcess
 from .sast_postprocess import SASTPostProcess
 from .rec_postprocess import CTCLabelDecode, AttnLabelDecode, SRNLabelDecode, DistillationCTCLabelDecode, \
-    TableLabelDecode
+    TableLabelDecode, SEEDLabelDecode
 from .cls_postprocess import ClsPostProcess
 from .pg_postprocess import PGPostProcess
 
@@ -34,7 +34,7 @@ def build_post_process(config, global_config=None):
     support_dict = [
         'DBPostProcess', 'EASTPostProcess', 'SASTPostProcess', 'CTCLabelDecode',
         'AttnLabelDecode', 'ClsPostProcess', 'SRNLabelDecode', 'PGPostProcess',
-        'DistillationCTCLabelDecode', 'TableLabelDecode'
+        'DistillationCTCLabelDecode', 'TableLabelDecode', 'SEEDLabelDecode'
     ]
 
     config = copy.deepcopy(config)
diff --git a/ppocr/postprocess/rec_postprocess.py b/ppocr/postprocess/rec_postprocess.py
index 17fc7e46..921d619a 100644
--- a/ppocr/postprocess/rec_postprocess.py
+++ b/ppocr/postprocess/rec_postprocess.py
@@ -170,10 +170,8 @@ class AttnLabelDecode(BaseRecLabelDecode):
     def add_special_char(self, dict_character):
         self.beg_str = "sos"
         self.end_str = "eos"
-        self.unkonwn = "UNKNOWN"
         dict_character = dict_character
-        dict_character = [self.beg_str] + dict_character + [self.end_str
-                                                            ] + [self.unkonwn]
+        dict_character = [self.beg_str] + dict_character + [self.end_str]
         return dict_character
 
     def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
@@ -214,7 +212,6 @@ class AttnLabelDecode(BaseRecLabelDecode):
             label = self.decode(label, is_remove_duplicate=False)
             return text, label
         """
-        preds = preds["rec_pred"]
         if isinstance(preds, paddle.Tensor):
             preds = preds.numpy()
 
@@ -242,6 +239,88 @@ class AttnLabelDecode(BaseRecLabelDecode):
         return idx
 
 
+class SEEDLabelDecode(BaseRecLabelDecode):
+    """ Convert between text-label and text-index """
+
+    def __init__(self,
+                 character_dict_path=None,
+                 character_type='ch',
+                 use_space_char=False,
+                 **kwargs):
+        super(SEEDLabelDecode, self).__init__(character_dict_path,
+                                              character_type, use_space_char)
+
+    def add_special_char(self, dict_character):
+        self.beg_str = "sos"
+        self.end_str = "eos"
+        dict_character = dict_character
+        dict_character = dict_character + [self.end_str]
+        return dict_character
+
+    def get_ignored_tokens(self):
+        end_idx = self.get_beg_end_flag_idx("eos")
+        return [end_idx]
+
+    def get_beg_end_flag_idx(self, beg_or_end):
+        if beg_or_end == "sos":
+            idx = np.array(self.dict[self.beg_str])
+        elif beg_or_end == "eos":
+            idx = np.array(self.dict[self.end_str])
+        else:
+            assert False, "unsupport type %s in get_beg_end_flag_idx" % beg_or_end
+        return idx
+
+    def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
+        """ convert text-index into text-label. """
+        result_list = []
+        [end_idx] = self.get_ignored_tokens()
+        batch_size = len(text_index)
+        for batch_idx in range(batch_size):
+            char_list = []
+            conf_list = []
+            for idx in range(len(text_index[batch_idx])):
+                if int(text_index[batch_idx][idx]) == int(end_idx):
+                    break
+                if is_remove_duplicate:
+                    # only for predict
+                    if idx > 0 and text_index[batch_idx][idx - 1] == text_index[
+                            batch_idx][idx]:
+                        continue
+                char_list.append(self.character[int(text_index[batch_idx][
+                    idx])])
+                if text_prob is not None:
+                    conf_list.append(text_prob[batch_idx][idx])
+                else:
+                    conf_list.append(1)
+            text = ''.join(char_list)
+            result_list.append((text, np.mean(conf_list)))
+        return result_list
+
+    def __call__(self, preds, label=None, *args, **kwargs):
+        """
+        text = self.decode(text)
+        if label is None:
+            return text
+        else:
+            label = self.decode(label, is_remove_duplicate=False)
+            return text, label
+        """
+        preds_idx = preds["rec_pred"]
+        if isinstance(preds_idx, paddle.Tensor):
+            preds_idx = preds_idx.numpy()
+        if "rec_pred_scores" in preds:
+            preds_idx = preds["rec_pred"]
+            preds_prob = preds["rec_pred_scores"]
+        else:
+            preds_idx = preds["rec_pred"].argmax(axis=2)
+            preds_prob = preds["rec_pred"].max(axis=2)
+        text = self.decode(preds_idx, preds_prob, is_remove_duplicate=False)
+        if label is None:
+            return text
+        label = self.decode(label, is_remove_duplicate=False)
+        return text, label
+
+
 class SRNLabelDecode(BaseRecLabelDecode):
     """ Convert between text-label and text-index """
 
diff --git a/ppocr/utils/save_load.py b/ppocr/utils/save_load.py
index 0453509c..1d760e98 100644
--- a/ppocr/utils/save_load.py
+++ b/ppocr/utils/save_load.py
@@ -105,16 +105,13 @@ def load_dygraph_params(config, model, logger, optimizer):
             params = paddle.load(pm)
             state_dict = model.state_dict()
             new_state_dict = {}
-            # for k1, k2 in zip(state_dict.keys(), params.keys()):
-            for k1 in state_dict.keys():
-                if k1 not in params:
-                    continue
-                if list(state_dict[k1].shape) == list(params[k1].shape):
-                    new_state_dict[k1] = params[k1]
-                else:
-                    logger.info(
-                        f"The shape of model params {k1} {state_dict[k1].shape} not matched with loaded params {k1} {params[k1].shape} !"
-                    )
+            for k1, k2 in zip(state_dict.keys(), params.keys()):
+                if list(state_dict[k1].shape) == list(params[k2].shape):
+                    new_state_dict[k1] = params[k2]
+                else:
+                    logger.info(
+                        f"The shape of model params {k1} {state_dict[k1].shape} not matched with loaded params {k2} {params[k2].shape} !"
+                    )
             model.set_state_dict(new_state_dict)
             logger.info(f"loaded pretrained_model successful from {pm}")
         return {}
diff --git a/tools/program.py b/tools/program.py
index 920cf417..3479ff26 100755
--- a/tools/program.py
+++ b/tools/program.py
@@ -211,11 +211,10 @@ def train(config,
             images = batch[0]
             if use_srn:
                 model_average = True
-            # if use_srn or model_type == 'table' or algorithm == "ASTER":
-            #     preds = model(images, data=batch[1:])
-            # else:
-            #     preds = model(images)
-            preds = model(images, data=batch[1:])
+            if use_srn or model_type == 'table' or model_type == 'seed':
+                preds = model(images, data=batch[1:])
+            else:
+                preds = model(images)
             state_dict = model.state_dict()
             # for key in state_dict:
             #     print(key)
@@ -415,6 +414,7 @@ def preprocess(is_train=False):
             yaml.dump(
                 dict(config), f, default_flow_style=False, sort_keys=False)
         log_file = '{}/train.log'.format(save_model_dir)
+        print("log has been saved in {}/train.log".format(save_model_dir))
     else:
         log_file = None
     logger = get_logger(name='root', log_file=log_file)
-- 
GitLab