From cf533b65c5acb9c1482d23c8c5201c8431d00d90 Mon Sep 17 00:00:00 2001
From: andyjpaddle
Date: Tue, 19 Jul 2022 12:38:54 +0000
Subject: [PATCH] add vl

---
 ppocr/data/imaug/__init__.py                 |   3 +-
 ppocr/data/imaug/label_ops.py                |  63 ++-
 ppocr/data/imaug/rec_img_aug.py              |  35 ++
 ppocr/data/imaug/text_image_aug/__init__.py  |   3 +-
 ppocr/data/imaug/text_image_aug/vl_aug.py    | 460 +++++++++++++++++
 ppocr/losses/__init__.py                     |   4 +-
 ppocr/modeling/backbones/__init__.py         |   4 +-
 ppocr/modeling/backbones/rec_resnet_aster.py | 111 +++++
 ppocr/modeling/heads/__init__.py             |   3 +-
 ppocr/modeling/heads/rec_visionlan_head.py   | 498 +++++++++++++++++++
 ppocr/postprocess/__init__.py                |   4 +-
 ppocr/postprocess/rec_postprocess.py         |  70 ++-
 tools/eval.py                                |   2 +-
 tools/export_model.py                        |   2 +-
 tools/infer/predict_rec.py                   |  20 +
 tools/program.py                             |  42 +-
 16 files changed, 1297 insertions(+), 27 deletions(-)
 create mode 100644 ppocr/data/imaug/text_image_aug/vl_aug.py
 create mode 100644 ppocr/modeling/heads/rec_visionlan_head.py

diff --git a/ppocr/data/imaug/__init__.py b/ppocr/data/imaug/__init__.py
index f0fd578f..20719e02 100644
--- a/ppocr/data/imaug/__init__.py
+++ b/ppocr/data/imaug/__init__.py
@@ -23,7 +23,8 @@ from .random_crop_data import EastRandomCropData, RandomCropImgMask
 from .make_pse_gt import MakePseGt
 from .rec_img_aug import BaseDataAugmentation, RecAug, RecConAug, RecResizeImg, ClsResizeImg, \
-    SRNRecResizeImg, NRTRRecResizeImg, SARRecResizeImg, PRENResizeImg
+    SRNRecResizeImg, NRTRRecResizeImg, SARRecResizeImg, PRENResizeImg, VLRecResizeImg
+from .text_image_aug import VLAug
 from .ssl_img_aug import SSLRotateResize
 from .randaugment import RandAugment
 from .copy_paste import CopyPaste
diff --git a/ppocr/data/imaug/label_ops.py b/ppocr/data/imaug/label_ops.py
index 02a5187d..304e190d 100644
--- a/ppocr/data/imaug/label_ops.py
+++ b/ppocr/data/imaug/label_ops.py
@@ -23,6 +23,7 @@ import string
 from shapely.geometry import LineString, Point, Polygon
 import json
 import copy
+from random import sample
 
 from ppocr.utils.logging import get_logger
 
@@ -443,7 +444,9 @@ class KieLabelEncode(object):
         elif 'key_cls' in anno.keys():
             labels.append(anno['key_cls'])
         else:
-            raise ValueError("Cannot found 'key_cls' in ann.keys(), please check your training annotation.")
+            raise ValueError(
+                "Cannot find 'key_cls' in ann.keys(), please check your training annotation."
+ ) edges.append(ann.get('edge', 0)) ann_infos = dict( image=data['image'], @@ -1044,3 +1047,61 @@ class MultiLabelEncode(BaseRecLabelEncode): data_out['label_sar'] = sar['label'] data_out['length'] = ctc['length'] return data_out + + +class VLLabelEncode(BaseRecLabelEncode): + """ Convert between text-label and text-index """ + + def __init__(self, + max_text_length, + character_dict_path=None, + use_space_char=False, + **kwargs): + super(VLLabelEncode, self).__init__(max_text_length, + character_dict_path, use_space_char) + + def __call__(self, data): + text = data['label'] # original string + # generate occluded text + len_str = len(text) + if len_str <= 0: + return None + change_num = 1 + order = list(range(len_str)) + change_id = sample(order, change_num)[0] + label_sub = text[change_id] + if change_id == (len_str - 1): + label_res = text[:change_id] + elif change_id == 0: + label_res = text[1:] + else: + label_res = text[:change_id] + text[change_id + 1:] + + data['label_res'] = label_res # remaining string + data['label_sub'] = label_sub # occluded character + data['label_id'] = change_id # character index + # encode label + text = self.encode(text) + if text is None: + return None + text = [i + 1 for i in text] + data['length'] = np.array(len(text)) + text = text + [0] * (self.max_text_len - len(text)) + data['label'] = np.array(text) + label_res = self.encode(label_res) + label_sub = self.encode(label_sub) + if label_res is None: + label_res = [] + else: + label_res = [i + 1 for i in label_res] + if label_sub is None: + label_sub = [] + else: + label_sub = [i + 1 for i in label_sub] + data['length_res'] = np.array(len(label_res)) + data['length_sub'] = np.array(len(label_sub)) + label_res = label_res + [0] * (self.max_text_len - len(label_res)) + label_sub = label_sub + [0] * (self.max_text_len - len(label_sub)) + data['label_res'] = np.array(label_res) + data['label_sub'] = np.array(label_sub) + return data diff --git a/ppocr/data/imaug/rec_img_aug.py b/ppocr/data/imaug/rec_img_aug.py index 32de2b3f..18d57963 100644 --- a/ppocr/data/imaug/rec_img_aug.py +++ b/ppocr/data/imaug/rec_img_aug.py @@ -213,6 +213,41 @@ class RecResizeImg(object): return data +class VLRecResizeImg(object): + def __init__(self, + image_shape, + infer_mode=False, + character_dict_path='./ppocr/utils/ppocr_keys_v1.txt', + padding=True, + **kwargs): + self.image_shape = image_shape + self.infer_mode = infer_mode + self.character_dict_path = character_dict_path + self.padding = padding + + def __call__(self, data): + img = data['image'] + if self.infer_mode and self.character_dict_path is not None: + norm_img, valid_ratio = resize_norm_img_chinese(img, + self.image_shape) + else: + imgC, imgH, imgW = self.image_shape + resized_image = cv2.resize( + img, (imgW, imgH), interpolation=cv2.INTER_LINEAR) + resized_w = imgW + resized_image = resized_image.astype('float32') + if self.image_shape[0] == 1: + resized_image = resized_image / 255 + norm_img = resized_image[np.newaxis, :] + else: + norm_img = resized_image.transpose((2, 0, 1)) / 255 + valid_ratio = min(1.0, float(resized_w / imgW)) + + data['image'] = norm_img + data['valid_ratio'] = valid_ratio + return data + + class SRNRecResizeImg(object): def __init__(self, image_shape, num_heads, max_text_length, **kwargs): self.image_shape = image_shape diff --git a/ppocr/data/imaug/text_image_aug/__init__.py b/ppocr/data/imaug/text_image_aug/__init__.py index bca26263..ca108b28 100644 --- a/ppocr/data/imaug/text_image_aug/__init__.py +++ 
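To make the new VLLabelEncode concrete: it occludes one randomly chosen
character, then encodes the full string, the remaining string, and the
occluded character with 1-based indices, reserving 0 for padding. A
standalone sketch of that logic (make_vl_targets and the 36-character
charset here are illustrative, not part of the patch):

from random import sample
import numpy as np

def make_vl_targets(text, char_to_idx, max_text_len=25):
    change_id = sample(range(len(text)), 1)[0]           # position to occlude
    label_sub = text[change_id]                          # occluded character
    label_res = text[:change_id] + text[change_id + 1:]  # remaining string

    def encode(s):
        # shift indices by +1 so that 0 can serve as the padding id
        ids = [char_to_idx[c] + 1 for c in s]
        return np.array(ids + [0] * (max_text_len - len(ids)))

    return encode(text), encode(label_res), encode(label_sub), change_id

chars = 'abcdefghijklmnopqrstuvwxyz1234567890'
char_to_idx = {c: i for i, c in enumerate(chars)}
label, label_res, label_sub, pos = make_vl_targets('paddle', char_to_idx)
# e.g. pos == 2 -> label_sub encodes 'd' and label_res encodes 'padle'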
b/ppocr/data/imaug/text_image_aug/__init__.py @@ -13,5 +13,6 @@ # limitations under the License. from .augment import tia_perspective, tia_distort, tia_stretch +from .vl_aug import VLAug -__all__ = ['tia_distort', 'tia_stretch', 'tia_perspective'] +__all__ = ['tia_distort', 'tia_stretch', 'tia_perspective', 'VLAug'] diff --git a/ppocr/data/imaug/text_image_aug/vl_aug.py b/ppocr/data/imaug/text_image_aug/vl_aug.py new file mode 100644 index 00000000..50b066b1 --- /dev/null +++ b/ppocr/data/imaug/text_image_aug/vl_aug.py @@ -0,0 +1,460 @@ +# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +import numbers +import random + +import cv2 +import numpy as np +from PIL import Image +from paddle.vision import transforms +from paddle.vision.transforms import Compose + + +def sample_asym(magnitude, size=None): + return np.random.beta(1, 4, size) * magnitude + + +def sample_sym(magnitude, size=None): + return (np.random.beta(4, 4, size=size) - 0.5) * 2 * magnitude + + +def sample_uniform(low, high, size=None): + return np.random.uniform(low, high, size=size) + + +def get_interpolation(type='random'): + if type == 'random': + choice = [ + cv2.INTER_NEAREST, cv2.INTER_LINEAR, cv2.INTER_CUBIC, cv2.INTER_AREA + ] + interpolation = choice[random.randint(0, len(choice) - 1)] + elif type == 'nearest': + interpolation = cv2.INTER_NEAREST + elif type == 'linear': + interpolation = cv2.INTER_LINEAR + elif type == 'cubic': + interpolation = cv2.INTER_CUBIC + elif type == 'area': + interpolation = cv2.INTER_AREA + else: + raise TypeError( + 'Interpolation types only nearest, linear, cubic, area are supported!' + ) + return interpolation + + +class CVRandomRotation(object): + def __init__(self, degrees=15): + assert isinstance(degrees, + numbers.Number), "degree should be a single number." + assert degrees >= 0, "degree must be positive." + self.degrees = degrees + + @staticmethod + def get_params(degrees): + return sample_sym(degrees) + + def __call__(self, img): + angle = self.get_params(self.degrees) + src_h, src_w = img.shape[:2] + M = cv2.getRotationMatrix2D( + center=(src_w / 2, src_h / 2), angle=angle, scale=1.0) + abs_cos, abs_sin = abs(M[0, 0]), abs(M[0, 1]) + dst_w = int(src_h * abs_sin + src_w * abs_cos) + dst_h = int(src_h * abs_cos + src_w * abs_sin) + M[0, 2] += (dst_w - src_w) / 2 + M[1, 2] += (dst_h - src_h) / 2 + + flags = get_interpolation() + return cv2.warpAffine( + img, + M, (dst_w, dst_h), + flags=flags, + borderMode=cv2.BORDER_REPLICATE) + + +class CVRandomAffine(object): + def __init__(self, degrees, translate=None, scale=None, shear=None): + assert isinstance(degrees, + numbers.Number), "degree should be a single number." + assert degrees >= 0, "degree must be positive." + self.degrees = degrees + + if translate is not None: + assert isinstance(translate, (tuple, list)) and len(translate) == 2, \ + "translate should be a list or tuple and it must be of length 2." 
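# A quick illustration of the samplers defined above: sample_asym draws from
# Beta(1, 4), which is skewed toward 0, so large magnitudes stay rare, while
# sample_sym maps Beta(4, 4) from [0, 1] onto [-magnitude, magnitude], giving
# a symmetric bell around 0. For instance:
#   angle = sample_sym(15)       # rotation angle in (-15, 15), peaked near 0
#   offsets = sample_asym(8, 4)  # four non-negative offsets, mostly small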
+ for t in translate: + if not (0.0 <= t <= 1.0): + raise ValueError( + "translation values should be between 0 and 1") + self.translate = translate + + if scale is not None: + assert isinstance(scale, (tuple, list)) and len(scale) == 2, \ + "scale should be a list or tuple and it must be of length 2." + for s in scale: + if s <= 0: + raise ValueError("scale values should be positive") + self.scale = scale + + if shear is not None: + if isinstance(shear, numbers.Number): + if shear < 0: + raise ValueError( + "If shear is a single number, it must be positive.") + self.shear = [shear] + else: + assert isinstance(shear, (tuple, list)) and (len(shear) == 2), \ + "shear should be a list or tuple and it must be of length 2." + self.shear = shear + else: + self.shear = shear + + def _get_inverse_affine_matrix(self, center, angle, translate, scale, + shear): + from numpy import sin, cos, tan + + if isinstance(shear, numbers.Number): + shear = [shear, 0] + + if not isinstance(shear, (tuple, list)) and len(shear) == 2: + raise ValueError( + "Shear should be a single value or a tuple/list containing " + + "two values. Got {}".format(shear)) + + rot = math.radians(angle) + sx, sy = [math.radians(s) for s in shear] + + cx, cy = center + tx, ty = translate + + # RSS without scaling + a = cos(rot - sy) / cos(sy) + b = -cos(rot - sy) * tan(sx) / cos(sy) - sin(rot) + c = sin(rot - sy) / cos(sy) + d = -sin(rot - sy) * tan(sx) / cos(sy) + cos(rot) + + # Inverted rotation matrix with scale and shear + # det([[a, b], [c, d]]) == 1, since det(rotation) = 1 and det(shear) = 1 + M = [d, -b, 0, -c, a, 0] + M = [x / scale for x in M] + + # Apply inverse of translation and of center translation: RSS^-1 * C^-1 * T^-1 + M[2] += M[0] * (-cx - tx) + M[1] * (-cy - ty) + M[5] += M[3] * (-cx - tx) + M[4] * (-cy - ty) + + # Apply center translation: C * RSS^-1 * C^-1 * T^-1 + M[2] += cx + M[5] += cy + return M + + @staticmethod + def get_params(degrees, translate, scale_ranges, shears, height): + angle = sample_sym(degrees) + if translate is not None: + max_dx = translate[0] * height + max_dy = translate[1] * height + translations = (np.round(sample_sym(max_dx)), + np.round(sample_sym(max_dy))) + else: + translations = (0, 0) + + if scale_ranges is not None: + scale = sample_uniform(scale_ranges[0], scale_ranges[1]) + else: + scale = 1.0 + + if shears is not None: + if len(shears) == 1: + shear = [sample_sym(shears[0]), 0.] 
+            elif len(shears) == 2:
+                shear = [sample_sym(shears[0]), sample_sym(shears[1])]
+        else:
+            shear = 0.0
+
+        return angle, translations, scale, shear
+
+    def __call__(self, img):
+        src_h, src_w = img.shape[:2]
+        angle, translate, scale, shear = self.get_params(
+            self.degrees, self.translate, self.scale, self.shear, src_h)
+
+        M = self._get_inverse_affine_matrix((src_w / 2, src_h / 2), angle,
+                                            (0, 0), scale, shear)
+        M = np.array(M).reshape(2, 3)
+
+        startpoints = [(0, 0), (src_w - 1, 0), (src_w - 1, src_h - 1),
+                       (0, src_h - 1)]
+        project = lambda x, y, a, b, c: int(a * x + b * y + c)
+        endpoints = [(project(x, y, *M[0]), project(x, y, *M[1]))
+                     for x, y in startpoints]
+
+        rect = cv2.minAreaRect(np.array(endpoints))
+        bbox = cv2.boxPoints(rect).astype(dtype=np.int32)
+        max_x, max_y = bbox[:, 0].max(), bbox[:, 1].max()
+        min_x, min_y = bbox[:, 0].min(), bbox[:, 1].min()
+
+        dst_w = int(max_x - min_x)
+        dst_h = int(max_y - min_y)
+        M[0, 2] += (dst_w - src_w) / 2
+        M[1, 2] += (dst_h - src_h) / 2
+
+        # add translate
+        dst_w += int(abs(translate[0]))
+        dst_h += int(abs(translate[1]))
+        if translate[0] < 0: M[0, 2] += abs(translate[0])
+        if translate[1] < 0: M[1, 2] += abs(translate[1])
+
+        flags = get_interpolation()
+        return cv2.warpAffine(
+            img,
+            M, (dst_w, dst_h),
+            flags=flags,
+            borderMode=cv2.BORDER_REPLICATE)
+
+
+class CVRandomPerspective(object):
+    def __init__(self, distortion=0.5):
+        self.distortion = distortion
+
+    def get_params(self, width, height, distortion):
+        offset_h = sample_asym(
+            distortion * height / 2, size=4).astype(dtype=np.int32)
+        offset_w = sample_asym(
+            distortion * width / 2, size=4).astype(dtype=np.int32)
+        topleft = (offset_w[0], offset_h[0])
+        topright = (width - 1 - offset_w[1], offset_h[1])
+        botright = (width - 1 - offset_w[2], height - 1 - offset_h[2])
+        botleft = (offset_w[3], height - 1 - offset_h[3])
+
+        startpoints = [(0, 0), (width - 1, 0), (width - 1, height - 1),
+                       (0, height - 1)]
+        endpoints = [topleft, topright, botright, botleft]
+        return np.array(
+            startpoints, dtype=np.float32), np.array(
+                endpoints, dtype=np.float32)
+
+    def __call__(self, img):
+        height, width = img.shape[:2]
+        startpoints, endpoints = self.get_params(width, height, self.distortion)
+        M = cv2.getPerspectiveTransform(startpoints, endpoints)
+
+        # TODO: a more robust way to crop the warped image
+        rect = cv2.minAreaRect(endpoints)
+        bbox = cv2.boxPoints(rect).astype(dtype=np.int32)
+        max_x, max_y = bbox[:, 0].max(), bbox[:, 1].max()
+        min_x, min_y = bbox[:, 0].min(), bbox[:, 1].min()
+        min_x, min_y = max(min_x, 0), max(min_y, 0)
+
+        flags = get_interpolation()
+        img = cv2.warpPerspective(
+            img,
+            M, (max_x, max_y),
+            flags=flags,
+            borderMode=cv2.BORDER_REPLICATE)
+        img = img[min_y:, min_x:]
+        return img
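The three geometric ops above are interchangeable callables over an HWC
numpy image; a minimal smoke test under that assumption (the random uint8
array stands in for a real text crop):

import numpy as np

img = (np.random.rand(64, 256, 3) * 255).astype(np.uint8)
rotated = CVRandomRotation(degrees=15)(img)
warped = CVRandomPerspective(distortion=0.5)(img)
# both return images whose size can differ from the input, since the
# affine/perspective ops enlarge the canvas to keep the warped text in view
print(rotated.shape, warped.shape)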
+
+
+class CVRescale(object):
+    def __init__(self, factor=4, base_size=(128, 512)):
+        """Define image scales using a Gaussian pyramid and rescale the image to the target scale.
+
+        Args:
+            factor: the decay factor from the base size; factor=4 keeps the target scale by default.
+            base_size: the base size used to build the bottom layer of the pyramid
+        """
+        if isinstance(factor, numbers.Number):
+            self.factor = round(sample_uniform(0, factor))
+        elif isinstance(factor, (tuple, list)) and len(factor) == 2:
+            self.factor = round(sample_uniform(factor[0], factor[1]))
+        else:
+            raise Exception('factor must be a number or a list of length 2')
+        # the sampled factor stays fixed for this augmenter instance
+        self.base_h, self.base_w = base_size[:2]
+
+    def __call__(self, img):
+        if self.factor == 0:
+            return img
+        src_h, src_w = img.shape[:2]
+        cur_w, cur_h = self.base_w, self.base_h
+        scale_img = cv2.resize(
+            img, (cur_w, cur_h), interpolation=get_interpolation())
+        for _ in range(int(self.factor)):
+            scale_img = cv2.pyrDown(scale_img)
+        scale_img = cv2.resize(
+            scale_img, (src_w, src_h), interpolation=get_interpolation())
+        return scale_img
+
+
+class CVGaussianNoise(object):
+    def __init__(self, mean=0, var=20):
+        self.mean = mean
+        if isinstance(var, numbers.Number):
+            self.var = max(int(sample_asym(var)), 1)
+        elif isinstance(var, (tuple, list)) and len(var) == 2:
+            self.var = int(sample_uniform(var[0], var[1]))
+        else:
+            raise Exception('var must be a number or a list of length 2')
+
+    def __call__(self, img):
+        noise = np.random.normal(self.mean, self.var**0.5, img.shape)
+        img = np.clip(img + noise, 0, 255).astype(np.uint8)
+        return img
+
+
+class CVMotionBlur(object):
+    def __init__(self, degrees=12, angle=90):
+        if isinstance(degrees, numbers.Number):
+            self.degree = max(int(sample_asym(degrees)), 1)
+        elif isinstance(degrees, (tuple, list)) and len(degrees) == 2:
+            self.degree = int(sample_uniform(degrees[0], degrees[1]))
+        else:
+            raise Exception('degrees must be a number or a list of length 2')
+        self.angle = sample_uniform(-angle, angle)
+
+    def __call__(self, img):
+        M = cv2.getRotationMatrix2D((self.degree // 2, self.degree // 2),
+                                    self.angle, 1)
+        motion_blur_kernel = np.zeros((self.degree, self.degree))
+        motion_blur_kernel[self.degree // 2, :] = 1
+        motion_blur_kernel = cv2.warpAffine(motion_blur_kernel, M,
+                                            (self.degree, self.degree))
+        motion_blur_kernel = motion_blur_kernel / self.degree
+        img = cv2.filter2D(img, -1, motion_blur_kernel)
+        img = np.clip(img, 0, 255).astype(np.uint8)
+        return img
+
+
+class CVGeometry(object):
+    def __init__(self,
+                 degrees=15,
+                 translate=(0.3, 0.3),
+                 scale=(0.5, 2.),
+                 shear=(45, 15),
+                 distortion=0.5,
+                 p=0.5):
+        self.p = p
+        type_p = random.random()
+        if type_p < 0.33:
+            self.transforms = CVRandomRotation(degrees=degrees)
+        elif type_p < 0.66:
+            self.transforms = CVRandomAffine(
+                degrees=degrees, translate=translate, scale=scale, shear=shear)
+        else:
+            self.transforms = CVRandomPerspective(distortion=distortion)
+
+    def __call__(self, img):
+        if random.random() < self.p:
+            return self.transforms(img)
+        else:
+            return img
+
+
+class CVDeterioration(object):
+    def __init__(self, var, degrees, factor, p=0.5):
+        self.p = p
+        transforms = []
+        if var is not None:
+            transforms.append(CVGaussianNoise(var=var))
+        if degrees is not None:
+            transforms.append(CVMotionBlur(degrees=degrees))
+        if factor is not None:
+            transforms.append(CVRescale(factor=factor))
+
+        random.shuffle(transforms)
+        transforms = Compose(transforms)
+        self.transforms = transforms
+
+    def __call__(self, img):
+        if random.random() < self.p:
+            return self.transforms(img)
+        else:
+            return img
+
+
+class CVColorJitter(object):
+    def __init__(self,
+                 brightness=0.5,
+                 contrast=0.5,
+                 saturation=0.5,
+                 hue=0.1,
+                 p=0.5):
+        self.p = p
+        self.transforms = transforms.ColorJitter(
brightness=brightness, + contrast=contrast, + saturation=saturation, + hue=hue) + + def __call__(self, img): + if random.random() < self.p: + return self.transforms(img) + else: + return img + + +class VLAug(object): + def __init__(self, + geometry_p=0.5, + Deterioration_p=0.25, + ColorJitter_p=0.25, + **kwargs): + self.Geometry = CVGeometry( + degrees=45, + translate=(0.0, 0.0), + scale=(0.5, 2.), + shear=(45, 15), + distortion=0.5, + p=geometry_p) + self.Deterioration = CVDeterioration( + var=20, degrees=6, factor=4, p=Deterioration_p) + self.ColorJitter = CVColorJitter( + brightness=0.5, + contrast=0.5, + saturation=0.5, + hue=0.1, + p=ColorJitter_p) + + def __call__(self, data): + img = data['image'] + img = self.Geometry(img) + img = self.Deterioration(img) + img = self.ColorJitter(img) + data['image'] = img + return data + + +if __name__ == '__main__': + + geo = CVGeometry( + degrees=45, + translate=(0.0, 0.0), + scale=(0.5, 2.), + shear=(45, 15), + distortion=0.5, + p=1) + det = CVDeterioration(var=20, degrees=6, factor=4, p=1) + color = CVColorJitter( + brightness=0.5, contrast=0.5, saturation=0.5, hue=0.1, p=1) + + img = np.ones((64, 256, 3)) + img = geo(img) + img = det(img) + img = color(img) + # import pdb + # pdb.set_trace() + # print() diff --git a/ppocr/losses/__init__.py b/ppocr/losses/__init__.py index de8419b7..e1e0635a 100755 --- a/ppocr/losses/__init__.py +++ b/ppocr/losses/__init__.py @@ -35,6 +35,7 @@ from .rec_sar_loss import SARLoss from .rec_aster_loss import AsterLoss from .rec_pren_loss import PRENLoss from .rec_multi_loss import MultiLoss +from .rec_vl_loss import VLLoss # cls loss from .cls_loss import ClsLoss @@ -61,7 +62,8 @@ def build_loss(config): 'DBLoss', 'PSELoss', 'EASTLoss', 'SASTLoss', 'FCELoss', 'CTCLoss', 'ClsLoss', 'AttentionLoss', 'SRNLoss', 'PGLoss', 'CombinedLoss', 'NRTRLoss', 'TableAttentionLoss', 'SARLoss', 'AsterLoss', 'SDMGRLoss', - 'VQASerTokenLayoutLMLoss', 'LossFromOutput', 'PRENLoss', 'MultiLoss' + 'VQASerTokenLayoutLMLoss', 'LossFromOutput', 'PRENLoss', 'MultiLoss', + 'VLLoss' ] config = copy.deepcopy(config) module_name = config.pop('name') diff --git a/ppocr/modeling/backbones/__init__.py b/ppocr/modeling/backbones/__init__.py index 072d6e0f..bd6dc5ce 100755 --- a/ppocr/modeling/backbones/__init__.py +++ b/ppocr/modeling/backbones/__init__.py @@ -28,14 +28,14 @@ def build_backbone(config, model_type): from .rec_mv1_enhance import MobileNetV1Enhance from .rec_nrtr_mtb import MTB from .rec_resnet_31 import ResNet31 - from .rec_resnet_aster import ResNet_ASTER + from .rec_resnet_aster import ResNet_ASTER, ResNet45 from .rec_micronet import MicroNet from .rec_efficientb3_pren import EfficientNetb3_PREN from .rec_svtrnet import SVTRNet support_dict = [ 'MobileNetV1Enhance', 'MobileNetV3', 'ResNet', 'ResNetFPN', 'MTB', "ResNet31", "ResNet_ASTER", 'MicroNet', 'EfficientNetb3_PREN', - 'SVTRNet' + 'SVTRNet', 'ResNet45' ] elif model_type == "e2e": from .e2e_resnet_vd_pg import ResNet diff --git a/ppocr/modeling/backbones/rec_resnet_aster.py b/ppocr/modeling/backbones/rec_resnet_aster.py index 6a2710df..a59c2da2 100644 --- a/ppocr/modeling/backbones/rec_resnet_aster.py +++ b/ppocr/modeling/backbones/rec_resnet_aster.py @@ -20,6 +20,10 @@ import paddle.nn as nn import sys import math +from paddle.nn.initializer import KaimingNormal, Constant + +zeros_ = Constant(value=0.) +ones_ = Constant(value=1.) 
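The ResNet45 added below halves the spatial size only in its first three
stages, so a 3x64x256 crop comes out as a 512-channel 8x32 feature map; the
8 * 32 = 256 positions of that map are exactly the n_position=256 sequence
the VisionLAN head flattens into its transformer input. A shape check under
the defaults in this patch:

import paddle

model = ResNet45(in_channels=3)
feat = model(paddle.rand([1, 3, 64, 256]))
print(feat.shape)  # [1, 512, 8, 32]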
 def conv3x3(in_planes, out_planes, stride=1):
@@ -141,3 +145,110 @@ class ResNet_ASTER(nn.Layer):
             return rnn_feat
         else:
             return cnn_feat
+
+
+class Block(nn.Layer):
+    def __init__(self, inplanes, planes, stride=1, downsample=None):
+        super(Block, self).__init__()
+        self.conv1 = conv1x1(inplanes, planes)
+        self.bn1 = nn.BatchNorm2D(planes)
+        self.relu = nn.ReLU()
+        self.conv2 = conv3x3(planes, planes, stride)
+        self.bn2 = nn.BatchNorm2D(planes)
+        self.downsample = downsample
+        self.stride = stride
+
+    def forward(self, x):
+        residual = x
+        out = self.conv1(x)
+        out = self.bn1(out)
+        out = self.relu(out)
+        out = self.conv2(out)
+        out = self.bn2(out)
+
+        if self.downsample is not None:
+            residual = self.downsample(x)
+        out += residual
+        out = self.relu(out)
+        return out
+
+
+class ResNet45(nn.Layer):
+    def __init__(self, in_channels=3, compress_layer=False):
+        super(ResNet45, self).__init__()
+        self.compress_layer = compress_layer
+
+        self.conv1_new = nn.Conv2D(
+            in_channels,
+            32,
+            kernel_size=(3, 3),
+            stride=1,
+            padding=1,
+            bias_attr=False)
+        self.bn1 = nn.BatchNorm2D(32)
+        self.relu = nn.ReLU()
+
+        self.inplanes = 32
+        self.layer1 = self._make_layer(32, 3, [2, 2])  # [32, 128]
+        self.layer2 = self._make_layer(64, 4, [2, 2])  # [16, 64]
+        self.layer3 = self._make_layer(128, 6, [2, 2])  # [8, 32]
+        self.layer4 = self._make_layer(256, 6, [1, 1])  # [8, 32]
+        self.layer5 = self._make_layer(512, 3, [1, 1])  # [8, 32]
+
+        if self.compress_layer:
+            self.layer6 = nn.Sequential(
+                nn.Conv2D(
+                    512, 256, kernel_size=(3, 1), padding=(0, 0), stride=(1,
+                                                                          1)),
+                nn.BatchNorm(256),
+                nn.ReLU())
+            self.out_channels = 256
+        else:
+            self.out_channels = 512
+
+        self.apply(self._init_weights)
+
+    def _init_weights(self, m):
+        # apply the initializers to the parameters; BatchNorm2D does not
+        # inherit from nn.BatchNorm, so the 2D class is checked directly
+        if isinstance(m, nn.Conv2D):
+            KaimingNormal()(m.weight)
+        elif isinstance(m, nn.BatchNorm2D):
+            ones_(m.weight)
+            zeros_(m.bias)
+
+    def _make_layer(self, planes, blocks, stride):
+        downsample = None
+        if stride != [1, 1] or self.inplanes != planes:
+            downsample = nn.Sequential(
+                conv1x1(self.inplanes, planes, stride), nn.BatchNorm2D(planes))
+
+        layers = []
+        layers.append(Block(self.inplanes, planes, stride, downsample))
+        self.inplanes = planes
+        for _ in range(1, blocks):
+            layers.append(Block(self.inplanes, planes))
+        return nn.Sequential(*layers)
+
+    def forward(self, x):
+        x = self.conv1_new(x)
+        x = self.bn1(x)
+        x = self.relu(x)
+        x1 = self.layer1(x)
+        x2 = self.layer2(x1)
+        x3 = self.layer3(x2)
+        x4 = self.layer4(x3)
+        x5 = self.layer5(x4)
+
+        if not self.compress_layer:
+            return x5
+        else:
+            x6 = self.layer6(x5)
+            return x6
+
+
+if __name__ == '__main__':
+    model = ResNet45()
+    x = paddle.rand([1, 3, 64, 256])
+    print(x.shape)
+    out = model(x)
+    print(out.shape)
diff --git a/ppocr/modeling/heads/__init__.py b/ppocr/modeling/heads/__init__.py
index 1670ea38..37ad6bd6 100755
--- a/ppocr/modeling/heads/__init__.py
+++ b/ppocr/modeling/heads/__init__.py
@@ -33,6 +33,7 @@ def build_head(config):
     from .rec_aster_head import AsterHead
     from .rec_pren_head import PRENHead
     from .rec_multi_head import MultiHead
+    from .rec_visionlan_head import VLHead
 
     # cls head
     from .cls_head import ClsHead
@@ -46,7 +47,7 @@ def build_head(config):
         'DBHead', 'PSEHead', 'FCEHead', 'EASTHead', 'SASTHead', 'CTCHead',
         'ClsHead', 'AttentionHead', 'SRNHead', 'PGHead', 'Transformer',
         'TableAttentionHead', 'SARHead', 'AsterHead', 'SDMGRHead', 'PRENHead',
-        'MultiHead'
+        'MultiHead', 'VLHead'
     ]
 
     #table head
diff --git a/ppocr/modeling/heads/rec_visionlan_head.py
b/ppocr/modeling/heads/rec_visionlan_head.py new file mode 100644 index 00000000..a5d60598 --- /dev/null +++ b/ppocr/modeling/heads/rec_visionlan_head.py @@ -0,0 +1,498 @@ +# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import paddle +from paddle import ParamAttr +import paddle.nn as nn +import paddle.nn.functional as F +from paddle.nn.initializer import Normal, XavierNormal +import numpy as np +from ppocr.modeling.backbones.rec_resnet_aster import ResNet45 + + +class PositionalEncoding(nn.Layer): + def __init__(self, d_hid, n_position=200): + super(PositionalEncoding, self).__init__() + self.register_buffer( + 'pos_table', self._get_sinusoid_encoding_table(n_position, d_hid)) + + def _get_sinusoid_encoding_table(self, n_position, d_hid): + ''' Sinusoid position encoding table ''' + + def get_position_angle_vec(position): + return [ + position / np.power(10000, 2 * (hid_j // 2) / d_hid) + for hid_j in range(d_hid) + ] + + sinusoid_table = np.array( + [get_position_angle_vec(pos_i) for pos_i in range(n_position)]) + sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i + sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1 + sinusoid_table = paddle.to_tensor(sinusoid_table, dtype='float32') + sinusoid_table = paddle.unsqueeze(sinusoid_table, axis=0) + return sinusoid_table + + def forward(self, x): + return x + self.pos_table[:, :x.shape[1]].clone().detach() + + +class ScaledDotProductAttention(nn.Layer): + "Scaled Dot-Product Attention" + + def __init__(self, temperature, attn_dropout=0.1): + super(ScaledDotProductAttention, self).__init__() + self.temperature = temperature + self.dropout = nn.Dropout(attn_dropout) + self.softmax = nn.Softmax(axis=2) + + def forward(self, q, k, v, mask=None): + k = paddle.transpose(k, perm=[0, 2, 1]) + attn = paddle.bmm(q, k) + attn = attn / self.temperature + if mask is not None: + attn = attn.masked_fill(mask, -1e9) + if mask.dim() == 3: + mask = paddle.unsqueeze(mask, axis=1) + elif mask.dim() == 2: + mask = paddle.unsqueeze(mask, axis=1) + mask = paddle.unsqueeze(mask, axis=1) + repeat_times = [ + attn.shape[1] // mask.shape[1], attn.shape[2] // mask.shape[2] + ] + mask = paddle.tile(mask, [1, repeat_times[0], repeat_times[1], 1]) + attn[mask == 0] = -1e9 + attn = self.softmax(attn) + attn = self.dropout(attn) + output = paddle.bmm(attn, v) + return output + + +class MultiHeadAttention(nn.Layer): + " Multi-Head Attention module" + + def __init__(self, n_head, d_model, d_k, d_v, dropout=0.1): + super(MultiHeadAttention, self).__init__() + self.n_head = n_head + self.d_k = d_k + self.d_v = d_v + self.w_qs = nn.Linear( + d_model, + n_head * d_k, + weight_attr=ParamAttr(initializer=Normal( + mean=0, std=np.sqrt(2.0 / (d_model + d_k))))) + self.w_ks = nn.Linear( + d_model, + n_head * d_k, + weight_attr=ParamAttr(initializer=Normal( + mean=0, 
std=np.sqrt(2.0 / (d_model + d_k))))) + self.w_vs = nn.Linear( + d_model, + n_head * d_v, + weight_attr=ParamAttr(initializer=Normal( + mean=0, std=np.sqrt(2.0 / (d_model + d_v))))) + + self.attention = ScaledDotProductAttention(temperature=np.power(d_k, + 0.5)) + self.layer_norm = nn.LayerNorm(d_model) + self.fc = nn.Linear( + n_head * d_v, + d_model, + weight_attr=ParamAttr(initializer=XavierNormal())) + self.dropout = nn.Dropout(dropout) + + def forward(self, q, k, v, mask=None): + d_k, d_v, n_head = self.d_k, self.d_v, self.n_head + sz_b, len_q, _ = q.shape + sz_b, len_k, _ = k.shape + sz_b, len_v, _ = v.shape + residual = q + + q = self.w_qs(q) + q = paddle.reshape( + q, shape=[-1, len_q, n_head, d_k]) # 4*21*512 ---- 4*21*8*64 + k = self.w_ks(k) + k = paddle.reshape(k, shape=[-1, len_k, n_head, d_k]) + v = self.w_vs(v) + v = paddle.reshape(v, shape=[-1, len_v, n_head, d_v]) + + q = paddle.transpose(q, perm=[2, 0, 1, 3]) + q = paddle.reshape(q, shape=[-1, len_q, d_k]) # (n*b) x lq x dk + k = paddle.transpose(k, perm=[2, 0, 1, 3]) + k = paddle.reshape(k, shape=[-1, len_k, d_k]) # (n*b) x lk x dk + v = paddle.transpose(v, perm=[2, 0, 1, 3]) + v = paddle.reshape(v, shape=[-1, len_v, d_v]) # (n*b) x lv x dv + + mask = paddle.tile( + mask, + [n_head, 1, 1]) if mask is not None else None # (n*b) x .. x .. + output = self.attention(q, k, v, mask=mask) + output = paddle.reshape(output, shape=[n_head, -1, len_q, d_v]) + output = paddle.transpose(output, perm=[1, 2, 0, 3]) + output = paddle.reshape( + output, shape=[-1, len_q, n_head * d_v]) # b x lq x (n*dv) + output = self.dropout(self.fc(output)) + output = self.layer_norm(output + residual) + return output + + +class PositionwiseFeedForward(nn.Layer): + def __init__(self, d_in, d_hid, dropout=0.1): + super(PositionwiseFeedForward, self).__init__() + self.w_1 = nn.Conv1D(d_in, d_hid, 1) # position-wise + self.w_2 = nn.Conv1D(d_hid, d_in, 1) # position-wise + self.layer_norm = nn.LayerNorm(d_in) + self.dropout = nn.Dropout(dropout) + + def forward(self, x): + residual = x + x = paddle.transpose(x, perm=[0, 2, 1]) + x = self.w_2(F.relu(self.w_1(x))) + x = paddle.transpose(x, perm=[0, 2, 1]) + x = self.dropout(x) + x = self.layer_norm(x + residual) + return x + + +class EncoderLayer(nn.Layer): + ''' Compose with two layers ''' + + def __init__(self, d_model, d_inner, n_head, d_k, d_v, dropout=0.1): + super(EncoderLayer, self).__init__() + self.slf_attn = MultiHeadAttention( + n_head, d_model, d_k, d_v, dropout=dropout) + self.pos_ffn = PositionwiseFeedForward( + d_model, d_inner, dropout=dropout) + + def forward(self, enc_input, slf_attn_mask=None): + enc_output = self.slf_attn( + enc_input, enc_input, enc_input, mask=slf_attn_mask) + enc_output = self.pos_ffn(enc_output) + return enc_output + + +class Transformer_Encoder(nn.Layer): + def __init__(self, + n_layers=2, + n_head=8, + d_word_vec=512, + d_k=64, + d_v=64, + d_model=512, + d_inner=2048, + dropout=0.1, + n_position=256): + super(Transformer_Encoder, self).__init__() + self.position_enc = PositionalEncoding( + d_word_vec, n_position=n_position) + self.dropout = nn.Dropout(p=dropout) + self.layer_stack = nn.LayerList([ + EncoderLayer( + d_model, d_inner, n_head, d_k, d_v, dropout=dropout) + for _ in range(n_layers) + ]) + self.layer_norm = nn.LayerNorm(d_model, epsilon=1e-6) + + def forward(self, enc_output, src_mask, return_attns=False): + enc_output = self.dropout( + self.position_enc(enc_output)) # position embeding + for enc_layer in self.layer_stack: + enc_output = 
enc_layer(enc_output, slf_attn_mask=src_mask) + enc_output = self.layer_norm(enc_output) + return enc_output + + +class PP_layer(nn.Layer): + def __init__(self, n_dim=512, N_max_character=25, n_position=256): + + super(PP_layer, self).__init__() + self.character_len = N_max_character + self.f0_embedding = nn.Embedding(N_max_character, n_dim) + self.w0 = nn.Linear(N_max_character, n_position) + self.wv = nn.Linear(n_dim, n_dim) + self.we = nn.Linear(n_dim, N_max_character) + self.active = nn.Tanh() + self.softmax = nn.Softmax(axis=2) + + def forward(self, enc_output): + # enc_output: b,256,512 + reading_order = paddle.arange(self.character_len, dtype='int64') + reading_order = reading_order.unsqueeze(0).expand( + [enc_output.shape[0], -1]) # (S,) -> (B, S) + reading_order = self.f0_embedding(reading_order) # b,25,512 + + # calculate attention + reading_order = paddle.transpose(reading_order, perm=[0, 2, 1]) + t = self.w0(reading_order) # b,512,256 + t = self.active( + paddle.transpose( + t, perm=[0, 2, 1]) + self.wv(enc_output)) # b,256,512 + t = self.we(t) # b,256,25 + t = self.softmax(paddle.transpose(t, perm=[0, 2, 1])) # b,25,256 + g_output = paddle.bmm(t, enc_output) # b,25,512 + return g_output + + +class Prediction(nn.Layer): + def __init__(self, + n_dim=512, + n_position=256, + N_max_character=25, + n_class=37): + super(Prediction, self).__init__() + self.pp = PP_layer( + n_dim=n_dim, N_max_character=N_max_character, n_position=n_position) + self.pp_share = PP_layer( + n_dim=n_dim, N_max_character=N_max_character, n_position=n_position) + self.w_vrm = nn.Linear(n_dim, n_class) # output layer + self.w_share = nn.Linear(n_dim, n_class) # output layer + self.nclass = n_class + + def forward(self, cnn_feature, f_res, f_sub, train_mode=False, + use_mlm=True): + if train_mode: + if not use_mlm: + g_output = self.pp(cnn_feature) # b,25,512 + g_output = self.w_vrm(g_output) + f_res = 0 + f_sub = 0 + return g_output, f_res, f_sub + g_output = self.pp(cnn_feature) # b,25,512 + f_res = self.pp_share(f_res) + f_sub = self.pp_share(f_sub) + g_output = self.w_vrm(g_output) + f_res = self.w_share(f_res) + f_sub = self.w_share(f_sub) + return g_output, f_res, f_sub + else: + g_output = self.pp(cnn_feature) # b,25,512 + g_output = self.w_vrm(g_output) + return g_output + + +class MLM(nn.Layer): + "Architecture of MLM" + + def __init__(self, n_dim=512, n_position=256, max_text_length=25): + super(MLM, self).__init__() + self.MLM_SequenceModeling_mask = Transformer_Encoder( + n_layers=2, n_position=n_position) + self.MLM_SequenceModeling_WCL = Transformer_Encoder( + n_layers=1, n_position=n_position) + self.pos_embedding = nn.Embedding(max_text_length, n_dim) + self.w0_linear = nn.Linear(1, n_position) + self.wv = nn.Linear(n_dim, n_dim) + self.active = nn.Tanh() + self.we = nn.Linear(n_dim, 1) + self.sigmoid = nn.Sigmoid() + + def forward(self, x, label_pos): + # transformer unit for generating mask_c + feature_v_seq = self.MLM_SequenceModeling_mask(x, src_mask=None) + # position embedding layer + label_pos = paddle.to_tensor(label_pos, dtype='int64') + pos_emb = self.pos_embedding(label_pos) + pos_emb = self.w0_linear(paddle.unsqueeze(pos_emb, axis=2)) + pos_emb = paddle.transpose(pos_emb, perm=[0, 2, 1]) + # fusion position embedding with features V & generate mask_c + att_map_sub = self.active(pos_emb + self.wv(feature_v_seq)) + att_map_sub = self.we(att_map_sub) # b,256,1 + att_map_sub = paddle.transpose(att_map_sub, perm=[0, 2, 1]) + att_map_sub = self.sigmoid(att_map_sub) # b,1,256 + # WCL + ## 
generate inputs for WCL + att_map_sub = paddle.transpose(att_map_sub, perm=[0, 2, 1]) + f_res = x * (1 - att_map_sub) # second path with remaining string + f_sub = x * att_map_sub # first path with occluded character + ## transformer units in WCL + f_res = self.MLM_SequenceModeling_WCL(f_res, src_mask=None) + f_sub = self.MLM_SequenceModeling_WCL(f_sub, src_mask=None) + return f_res, f_sub, att_map_sub + + +def trans_1d_2d(x): + b, w_h, c = x.shape # b, 256, 512 + x = paddle.transpose(x, perm=[0, 2, 1]) + x = paddle.reshape(x, [-1, c, 32, 8]) + x = paddle.transpose(x, perm=[0, 1, 3, 2]) # [b, c, 8, 32] + return x + + +class MLM_VRM(nn.Layer): + """ + MLM+VRM, MLM is only used in training. + ratio controls the occluded number in a batch. + The pipeline of VisionLAN in testing is very concise with only a backbone + sequence modeling(transformer unit) + prediction layer(pp layer). + x: input image + label_pos: character index + training_step: LF or LA process + output + text_pre: prediction of VRM + test_rem: prediction of remaining string in MLM + text_mas: prediction of occluded character in MLM + mask_c_show: visualization of Mask_c + """ + + def __init__(self, + n_layers=3, + n_position=256, + n_dim=512, + max_text_length=25, + nclass=37): + super(MLM_VRM, self).__init__() + self.MLM = MLM(n_dim=n_dim, + n_position=n_position, + max_text_length=max_text_length) + self.SequenceModeling = Transformer_Encoder( + n_layers=n_layers, n_position=n_position) + self.Prediction = Prediction( + n_dim=n_dim, + n_position=n_position, + N_max_character=max_text_length + + 1, # N_max_character = 1 eos + 25 characters + n_class=nclass) + self.nclass = nclass + self.max_text_length = max_text_length + + def forward(self, x, label_pos, training_step, train_mode=False): + b, c, h, w = x.shape + nT = self.max_text_length + x = paddle.transpose(x, perm=[0, 1, 3, 2]) + x = paddle.reshape(x, [-1, c, h * w]) + x = paddle.transpose(x, perm=[0, 2, 1]) + if train_mode: + if training_step == 'LF_1': + f_res = 0 + f_sub = 0 + x = self.SequenceModeling(x, src_mask=None) + text_pre, test_rem, text_mas = self.Prediction( + x, f_res, f_sub, train_mode=True, use_mlm=False) + return text_pre, text_pre, text_pre, text_pre + elif training_step == 'LF_2': + # MLM + f_res, f_sub, mask_c = self.MLM(x, label_pos) + x = self.SequenceModeling(x, src_mask=None) + text_pre, test_rem, text_mas = self.Prediction( + x, f_res, f_sub, train_mode=True) + mask_c_show = trans_1d_2d(mask_c) + return text_pre, test_rem, text_mas, mask_c_show + elif training_step == 'LA': + # MLM + f_res, f_sub, mask_c = self.MLM(x, label_pos) + ## use the mask_c (1 for occluded character and 0 for remaining characters) to occlude input + ## ratio controls the occluded number in a batch + character_mask = paddle.zeros_like(mask_c) + + ratio = b // 2 + if ratio >= 1: + with paddle.no_grad(): + character_mask[0:ratio, :, :] = mask_c[0:ratio, :, :] + else: + character_mask = mask_c + x = x * (1 - character_mask) + # VRM + ## transformer unit for VRM + x = self.SequenceModeling(x, src_mask=None) + ## prediction layer for MLM and VSR + text_pre, test_rem, text_mas = self.Prediction( + x, f_res, f_sub, train_mode=True) + mask_c_show = trans_1d_2d(mask_c) + return text_pre, test_rem, text_mas, mask_c_show + else: + raise NotImplementedError + else: # VRM is only used in the testing stage + f_res = 0 + f_sub = 0 + contextual_feature = self.SequenceModeling(x, src_mask=None) + text_pre = self.Prediction( + contextual_feature, + f_res, + f_sub, + train_mode=False, + 
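# The decoding loop below is a greedy length search: class id 0 is the end
# token, so the first step at which a sample's argmax hits 0 fixes that
# sample's out_length, and the valid prefix of every sample's logits is then
# packed into a single (sum(out_length), nclass) tensor for VLLabelDecode.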
use_mlm=False) + text_pre = paddle.transpose( + text_pre, perm=[1, 0, 2]) # (26, b, 37)) + lenText = nT + nsteps = nT + out_res = paddle.zeros( + shape=[lenText, b, self.nclass], dtype=x.dtype) # (25, b, 37) + out_length = paddle.zeros(shape=[b], dtype=x.dtype) + now_step = 0 + for _ in range(nsteps): + if 0 in out_length and now_step < nsteps: + tmp_result = text_pre[now_step, :, :] + out_res[now_step] = tmp_result + tmp_result = tmp_result.topk(1)[1].squeeze(axis=1) + for j in range(b): + if out_length[j] == 0 and tmp_result[j] == 0: + out_length[j] = now_step + 1 + now_step += 1 + # while 0 in out_length and now_step < nsteps: + # tmp_result = text_pre[now_step, :, :] + # out_res[now_step] = tmp_result + # tmp_result = tmp_result.topk(1)[1].squeeze(axis=1) + # for j in range(b): + # if out_length[j] == 0 and tmp_result[j] == 0: + # out_length[j] = now_step + 1 + # now_step += 1 + for j in range(0, b): + if int(out_length[j]) == 0: + out_length[j] = nsteps + start = 0 + output = paddle.zeros( + shape=[int(out_length.sum()), self.nclass], dtype=x.dtype) + for i in range(0, b): + cur_length = int(out_length[i]) + output[start:start + cur_length] = out_res[0:cur_length, i, :] + start += cur_length + return output, out_length + + +class VLHead(nn.Layer): + """ + Architecture of VisionLAN + """ + + def __init__(self, + in_channels, + out_channels=36, + n_layers=3, + n_position=256, + n_dim=512, + max_text_length=25, + training_step='LA'): + super(VLHead, self).__init__() + self.MLM_VRM = MLM_VRM( + n_layers=n_layers, + n_position=n_position, + n_dim=n_dim, + max_text_length=max_text_length, + nclass=out_channels + 1) + self.training_step = training_step + + def forward(self, feat, targets=None): + + if self.training: + label_pos = targets[-2] + text_pre, test_rem, text_mas, mask_map = self.MLM_VRM( + feat, label_pos, self.training_step, train_mode=True) + return text_pre, test_rem, text_mas, mask_map + else: + output, out_length = self.MLM_VRM( + feat, targets, self.training_step, train_mode=False) + return output, out_length diff --git a/ppocr/postprocess/__init__.py b/ppocr/postprocess/__init__.py index f50b5f1c..a22b7996 100644 --- a/ppocr/postprocess/__init__.py +++ b/ppocr/postprocess/__init__.py @@ -27,7 +27,7 @@ from .sast_postprocess import SASTPostProcess from .fce_postprocess import FCEPostProcess from .rec_postprocess import CTCLabelDecode, AttnLabelDecode, SRNLabelDecode, \ DistillationCTCLabelDecode, TableLabelDecode, NRTRLabelDecode, SARLabelDecode, \ - SEEDLabelDecode, PRENLabelDecode + SEEDLabelDecode, PRENLabelDecode, VLLabelDecode from .cls_postprocess import ClsPostProcess from .pg_postprocess import PGPostProcess from .vqa_token_ser_layoutlm_postprocess import VQASerTokenLayoutLMPostProcess @@ -42,7 +42,7 @@ def build_post_process(config, global_config=None): 'DistillationDBPostProcess', 'NRTRLabelDecode', 'SARLabelDecode', 'SEEDLabelDecode', 'VQASerTokenLayoutLMPostProcess', 'VQAReTokenLayoutLMPostProcess', 'PRENLabelDecode', - 'DistillationSARLabelDecode' + 'DistillationSARLabelDecode', 'VLLabelDecode' ] if config['name'] == 'PSEPostProcess': diff --git a/ppocr/postprocess/rec_postprocess.py b/ppocr/postprocess/rec_postprocess.py index bf0fd890..e2434cd7 100644 --- a/ppocr/postprocess/rec_postprocess.py +++ b/ppocr/postprocess/rec_postprocess.py @@ -27,7 +27,8 @@ class BaseRecLabelDecode(object): self.character_str = [] if character_dict_path is None: - self.character_str = "0123456789abcdefghijklmnopqrstuvwxyz" + # self.character_str = 
"0123456789abcdefghijklmnopqrstuvwxyz" + self.character_str = "abcdefghijklmnopqrstuvwxyz1234567890" dict_character = list(self.character_str) else: with open(character_dict_path, "rb") as fin: @@ -752,3 +753,70 @@ class PRENLabelDecode(BaseRecLabelDecode): return text label = self.decode(label) return text, label + + +class VLLabelDecode(BaseRecLabelDecode): + """ Convert between text-label and text-index """ + + def __init__(self, character_dict_path=None, use_space_char=False, + **kwargs): + super(VLLabelDecode, self).__init__(character_dict_path, use_space_char) + + def decode(self, text_index, text_prob=None, is_remove_duplicate=False): + """ convert text-index into text-label. """ + result_list = [] + ignored_tokens = self.get_ignored_tokens() + batch_size = len(text_index) + for batch_idx in range(batch_size): + selection = np.ones(len(text_index[batch_idx]), dtype=bool) + if is_remove_duplicate: + selection[1:] = text_index[batch_idx][1:] != text_index[ + batch_idx][:-1] + for ignored_token in ignored_tokens: + selection &= text_index[batch_idx] != ignored_token + + char_list = [ + self.character[text_id - 1] + for text_id in text_index[batch_idx][selection] + ] + if text_prob is not None: + conf_list = text_prob[batch_idx][selection] + else: + conf_list = [1] * len(selection) + if len(conf_list) == 0: + conf_list = [0] + + text = ''.join(char_list) + result_list.append((text, np.mean(conf_list).tolist())) + return result_list + + def __call__(self, preds, label=None, length=None, *args, **kwargs): + if len(preds) == 2: # eval mode + net_out, length = preds + else: # train mode + net_out = preds[0] + length = length + net_out = paddle.concat([t[:l] for t, l in zip(net_out, length)]) + text = [] + if not isinstance(net_out, paddle.Tensor): + net_out = paddle.to_tensor(net_out, dtype='float32') + # import pdb + # pdb.set_trace() + net_out = F.softmax(net_out, axis=1) + for i in range(0, length.shape[0]): + preds_idx = net_out[int(length[:i].sum()):int(length[:i].sum( + ) + length[i])].topk(1)[1][:, 0].tolist() + preds_text = ''.join([ + self.character[idx - 1] + if idx > 0 and idx <= len(self.character) else '' + for idx in preds_idx + ]) + preds_prob = net_out[int(length[:i].sum()):int(length[:i].sum( + ) + length[i])].topk(1)[0][:, 0] + preds_prob = paddle.exp( + paddle.log(preds_prob).sum() / (preds_prob.shape[0] + 1e-6)) + text.append((preds_text, preds_prob)) + if label is None: + return text + label = self.decode(label) + return text, label diff --git a/tools/eval.py b/tools/eval.py index cab28334..2fc53488 100755 --- a/tools/eval.py +++ b/tools/eval.py @@ -73,7 +73,7 @@ def main(): config['Architecture']["Head"]['out_channels'] = char_num model = build_model(config['Architecture']) - extra_input_models = ["SRN", "NRTR", "SAR", "SEED", "SVTR"] + extra_input_models = ["SRN", "NRTR", "SAR", "SEED", "SVTR", "VisionLAN"] extra_input = False if config['Architecture']['algorithm'] == 'Distillation': for key in config['Architecture']["Models"]: diff --git a/tools/export_model.py b/tools/export_model.py index c0cbcd36..5d17410a 100755 --- a/tools/export_model.py +++ b/tools/export_model.py @@ -55,7 +55,7 @@ def export_single_model(model, arch_config, save_path, logger, quanter=None): shape=[None, 3, 48, 160], dtype="float32"), ] model = to_static(model, input_spec=other_shape) - elif arch_config["algorithm"] == "SVTR": + elif arch_config["algorithm"] in ["SVTR", "VisionLAN"]: if arch_config["Head"]["name"] == 'MultiHead': other_shape = [ paddle.static.InputSpec( diff --git 
a/tools/infer/predict_rec.py b/tools/infer/predict_rec.py index 3664ef2c..cdfc984c 100755 --- a/tools/infer/predict_rec.py +++ b/tools/infer/predict_rec.py @@ -69,6 +69,12 @@ class TextRecognizer(object): "character_dict_path": args.rec_char_dict_path, "use_space_char": args.use_space_char } + elif self.rec_algorithm == "VisionLAN": + postprocess_params = { + 'name': 'VLLabelDecode', + "character_dict_path": args.rec_char_dict_path, + "use_space_char": args.use_space_char + } self.postprocess_op = build_post_process(postprocess_params) self.predictor, self.input_tensor, self.output_tensors, self.config = \ utility.create_predictor(args, 'rec', logger) @@ -143,6 +149,15 @@ class TextRecognizer(object): resized_image /= 0.5 return resized_image + def resize_norm_img_vl(self, img, image_shape): + + imgC, imgH, imgW = image_shape + resized_image = cv2.resize( + img, (imgW, imgH), interpolation=cv2.INTER_LINEAR) + resized_image = resized_image.astype('float32') + resized_image = resized_image.transpose((2, 0, 1)) / 255 + return resized_image + def resize_norm_img_srn(self, img, image_shape): imgC, imgH, imgW = image_shape @@ -300,6 +315,11 @@ class TextRecognizer(object): self.rec_image_shape) norm_img = norm_img[np.newaxis, :] norm_img_batch.append(norm_img) + elif self.rec_algorithm == "VisionLAN": + norm_img = self.resize_norm_img_vl(img_list[indices[ino]], + self.rec_image_shape) + norm_img = norm_img[np.newaxis, :] + norm_img_batch.append(norm_img) else: norm_img = self.resize_norm_img(img_list[indices[ino]], max_wh_ratio) diff --git a/tools/program.py b/tools/program.py index aa0d2698..bf774fd4 100755 --- a/tools/program.py +++ b/tools/program.py @@ -207,7 +207,7 @@ def train(config, model.train() use_srn = config['Architecture']['algorithm'] == "SRN" - extra_input_models = ["SRN", "NRTR", "SAR", "SEED", "SVTR"] + extra_input_models = ["SRN", "NRTR", "SAR", "SEED", "SVTR", "VisionLAN"] extra_input = False if config['Architecture']['algorithm'] == 'Distillation': for key in config['Architecture']["Models"]: @@ -249,7 +249,6 @@ def train(config, images = batch[0] if use_srn: model_average = True - # use amp if scaler: with paddle.amp.auto_cast(): @@ -264,7 +263,6 @@ def train(config, preds = model(batch) else: preds = model(images) - loss = loss_class(preds, batch) avg_loss = loss['loss'] @@ -286,6 +284,9 @@ def train(config, ]: # for multi head loss post_result = post_process_class( preds['ctc'], batch[1]) # for CTC head out + elif config['Loss']['name'] in ['VLLoss']: + post_result = post_process_class(preds, batch[1], + batch[-1]) else: post_result = post_process_class(preds, batch[1]) eval_class(post_result, batch) @@ -307,7 +308,8 @@ def train(config, train_stats.update(stats) if log_writer is not None and dist.get_rank() == 0: - log_writer.log_metrics(metrics=train_stats.get(), prefix="TRAIN", step=global_step) + log_writer.log_metrics( + metrics=train_stats.get(), prefix="TRAIN", step=global_step) if dist.get_rank() == 0 and ( (global_step > 0 and global_step % print_batch_step == 0) or @@ -354,7 +356,8 @@ def train(config, # logger metric if log_writer is not None: - log_writer.log_metrics(metrics=cur_metric, prefix="EVAL", step=global_step) + log_writer.log_metrics( + metrics=cur_metric, prefix="EVAL", step=global_step) if cur_metric[main_indicator] >= best_model_dict[ main_indicator]: @@ -377,11 +380,18 @@ def train(config, logger.info(best_str) # logger best metric if log_writer is not None: - log_writer.log_metrics(metrics={ - "best_{}".format(main_indicator): 
best_model_dict[main_indicator] - }, prefix="EVAL", step=global_step) - - log_writer.log_model(is_best=True, prefix="best_accuracy", metadata=best_model_dict) + log_writer.log_metrics( + metrics={ + "best_{}".format(main_indicator): + best_model_dict[main_indicator] + }, + prefix="EVAL", + step=global_step) + + log_writer.log_model( + is_best=True, + prefix="best_accuracy", + metadata=best_model_dict) reader_start = time.time() if dist.get_rank() == 0: @@ -413,7 +423,8 @@ def train(config, epoch=epoch, global_step=global_step) if log_writer is not None: - log_writer.log_model(is_best=False, prefix='iter_epoch_{}'.format(epoch)) + log_writer.log_model( + is_best=False, prefix='iter_epoch_{}'.format(epoch)) best_str = 'best metric, {}'.format(', '.join( ['{}: {}'.format(k, v) for k, v in best_model_dict.items()])) @@ -451,7 +462,6 @@ def eval(model, preds = model(batch) else: preds = model(images) - batch_numpy = [] for item in batch: if isinstance(item, paddle.Tensor): @@ -564,7 +574,8 @@ def preprocess(is_train=False): assert alg in [ 'EAST', 'DB', 'SAST', 'Rosetta', 'CRNN', 'STARNet', 'RARE', 'SRN', 'CLS', 'PGNet', 'Distillation', 'NRTR', 'TableAttn', 'SAR', 'PSE', - 'SEED', 'SDMGR', 'LayoutXLM', 'LayoutLM', 'PREN', 'FCE', 'SVTR' + 'SEED', 'SDMGR', 'LayoutXLM', 'LayoutLM', 'PREN', 'FCE', 'SVTR', + 'VisionLAN' ] if use_xpu: @@ -583,9 +594,10 @@ def preprocess(is_train=False): if 'use_visualdl' in config['Global'] and config['Global']['use_visualdl']: save_model_dir = config['Global']['save_model_dir'] vdl_writer_path = '{}/vdl/'.format(save_model_dir) - log_writer = VDLLogger(save_model_dir) + log_writer = VDLLogger(vdl_writer_path) loggers.append(log_writer) - if ('use_wandb' in config['Global'] and config['Global']['use_wandb']) or 'wandb' in config: + if ('use_wandb' in config['Global'] and + config['Global']['use_wandb']) or 'wandb' in config: save_dir = config['Global']['save_model_dir'] wandb_writer_path = "{}/wandb".format(save_dir) if "wandb" in config: -- GitLab
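Taken together, the patch registers 'VisionLAN' as an extra-input algorithm
whose backbone is ResNet45 and whose head is VLHead, with VLLabelDecode
unpacking the concatenated logits back into strings. A minimal
forward-shape sketch of how the new pieces compose (batch size, crop size
and the 36-class charset are illustrative):

import paddle
from ppocr.modeling.backbones.rec_resnet_aster import ResNet45
from ppocr.modeling.heads.rec_visionlan_head import VLHead

backbone = ResNet45(in_channels=3)
head = VLHead(in_channels=512, out_channels=36, training_step='LA')
head.eval()  # inference path: the MLM branch is skipped, only VRM runs

x = paddle.rand([2, 3, 64, 256])  # batch of resized text crops
feat = backbone(x)                # [2, 512, 8, 32]
output, out_length = head(feat)   # packed logits and per-sample lengths
print(output.shape, out_length)   # [sum(out_length), 37], lengths <= 25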