提交 cf533b65 编写于 作者: A andyjpaddle

add vl

上级 05a98305
......@@ -23,7 +23,8 @@ from .random_crop_data import EastRandomCropData, RandomCropImgMask
from .make_pse_gt import MakePseGt
from .rec_img_aug import BaseDataAugmentation, RecAug, RecConAug, RecResizeImg, ClsResizeImg, \
SRNRecResizeImg, NRTRRecResizeImg, SARRecResizeImg, PRENResizeImg
SRNRecResizeImg, NRTRRecResizeImg, SARRecResizeImg, PRENResizeImg, VLRecResizeImg
from .text_image_aug import VLAug
from .ssl_img_aug import SSLRotateResize
from .randaugment import RandAugment
from .copy_paste import CopyPaste
......
......@@ -23,6 +23,7 @@ import string
from shapely.geometry import LineString, Point, Polygon
import json
import copy
from random import sample
from ppocr.utils.logging import get_logger
......@@ -443,7 +444,9 @@ class KieLabelEncode(object):
elif 'key_cls' in anno.keys():
labels.append(anno['key_cls'])
else:
raise ValueError("Cannot found 'key_cls' in ann.keys(), please check your training annotation.")
raise ValueError(
"Cannot found 'key_cls' in ann.keys(), please check your training annotation."
)
edges.append(ann.get('edge', 0))
ann_infos = dict(
image=data['image'],
......@@ -1044,3 +1047,61 @@ class MultiLabelEncode(BaseRecLabelEncode):
data_out['label_sar'] = sar['label']
data_out['length'] = ctc['length']
return data_out
class VLLabelEncode(BaseRecLabelEncode):
""" Convert between text-label and text-index """
def __init__(self,
max_text_length,
character_dict_path=None,
use_space_char=False,
**kwargs):
super(VLLabelEncode, self).__init__(max_text_length,
character_dict_path, use_space_char)
def __call__(self, data):
text = data['label'] # original string
# generate occluded text
len_str = len(text)
if len_str <= 0:
return None
change_num = 1
order = list(range(len_str))
change_id = sample(order, change_num)[0]
label_sub = text[change_id]
if change_id == (len_str - 1):
label_res = text[:change_id]
elif change_id == 0:
label_res = text[1:]
else:
label_res = text[:change_id] + text[change_id + 1:]
data['label_res'] = label_res # remaining string
data['label_sub'] = label_sub # occluded character
data['label_id'] = change_id # character index
# encode label
text = self.encode(text)
if text is None:
return None
text = [i + 1 for i in text]
data['length'] = np.array(len(text))
text = text + [0] * (self.max_text_len - len(text))
data['label'] = np.array(text)
label_res = self.encode(label_res)
label_sub = self.encode(label_sub)
if label_res is None:
label_res = []
else:
label_res = [i + 1 for i in label_res]
if label_sub is None:
label_sub = []
else:
label_sub = [i + 1 for i in label_sub]
data['length_res'] = np.array(len(label_res))
data['length_sub'] = np.array(len(label_sub))
label_res = label_res + [0] * (self.max_text_len - len(label_res))
label_sub = label_sub + [0] * (self.max_text_len - len(label_sub))
data['label_res'] = np.array(label_res)
data['label_sub'] = np.array(label_sub)
return data
......@@ -213,6 +213,41 @@ class RecResizeImg(object):
return data
class VLRecResizeImg(object):
def __init__(self,
image_shape,
infer_mode=False,
character_dict_path='./ppocr/utils/ppocr_keys_v1.txt',
padding=True,
**kwargs):
self.image_shape = image_shape
self.infer_mode = infer_mode
self.character_dict_path = character_dict_path
self.padding = padding
def __call__(self, data):
img = data['image']
if self.infer_mode and self.character_dict_path is not None:
norm_img, valid_ratio = resize_norm_img_chinese(img,
self.image_shape)
else:
imgC, imgH, imgW = self.image_shape
resized_image = cv2.resize(
img, (imgW, imgH), interpolation=cv2.INTER_LINEAR)
resized_w = imgW
resized_image = resized_image.astype('float32')
if self.image_shape[0] == 1:
resized_image = resized_image / 255
norm_img = resized_image[np.newaxis, :]
else:
norm_img = resized_image.transpose((2, 0, 1)) / 255
valid_ratio = min(1.0, float(resized_w / imgW))
data['image'] = norm_img
data['valid_ratio'] = valid_ratio
return data
class SRNRecResizeImg(object):
def __init__(self, image_shape, num_heads, max_text_length, **kwargs):
self.image_shape = image_shape
......
......@@ -13,5 +13,6 @@
# limitations under the License.
from .augment import tia_perspective, tia_distort, tia_stretch
from .vl_aug import VLAug
__all__ = ['tia_distort', 'tia_stretch', 'tia_perspective']
__all__ = ['tia_distort', 'tia_stretch', 'tia_perspective', 'VLAug']
# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import numbers
import random
import cv2
import numpy as np
from PIL import Image
from paddle.vision import transforms
from paddle.vision.transforms import Compose
def sample_asym(magnitude, size=None):
return np.random.beta(1, 4, size) * magnitude
def sample_sym(magnitude, size=None):
return (np.random.beta(4, 4, size=size) - 0.5) * 2 * magnitude
def sample_uniform(low, high, size=None):
return np.random.uniform(low, high, size=size)
def get_interpolation(type='random'):
if type == 'random':
choice = [
cv2.INTER_NEAREST, cv2.INTER_LINEAR, cv2.INTER_CUBIC, cv2.INTER_AREA
]
interpolation = choice[random.randint(0, len(choice) - 1)]
elif type == 'nearest':
interpolation = cv2.INTER_NEAREST
elif type == 'linear':
interpolation = cv2.INTER_LINEAR
elif type == 'cubic':
interpolation = cv2.INTER_CUBIC
elif type == 'area':
interpolation = cv2.INTER_AREA
else:
raise TypeError(
'Interpolation types only nearest, linear, cubic, area are supported!'
)
return interpolation
class CVRandomRotation(object):
def __init__(self, degrees=15):
assert isinstance(degrees,
numbers.Number), "degree should be a single number."
assert degrees >= 0, "degree must be positive."
self.degrees = degrees
@staticmethod
def get_params(degrees):
return sample_sym(degrees)
def __call__(self, img):
angle = self.get_params(self.degrees)
src_h, src_w = img.shape[:2]
M = cv2.getRotationMatrix2D(
center=(src_w / 2, src_h / 2), angle=angle, scale=1.0)
abs_cos, abs_sin = abs(M[0, 0]), abs(M[0, 1])
dst_w = int(src_h * abs_sin + src_w * abs_cos)
dst_h = int(src_h * abs_cos + src_w * abs_sin)
M[0, 2] += (dst_w - src_w) / 2
M[1, 2] += (dst_h - src_h) / 2
flags = get_interpolation()
return cv2.warpAffine(
img,
M, (dst_w, dst_h),
flags=flags,
borderMode=cv2.BORDER_REPLICATE)
class CVRandomAffine(object):
def __init__(self, degrees, translate=None, scale=None, shear=None):
assert isinstance(degrees,
numbers.Number), "degree should be a single number."
assert degrees >= 0, "degree must be positive."
self.degrees = degrees
if translate is not None:
assert isinstance(translate, (tuple, list)) and len(translate) == 2, \
"translate should be a list or tuple and it must be of length 2."
for t in translate:
if not (0.0 <= t <= 1.0):
raise ValueError(
"translation values should be between 0 and 1")
self.translate = translate
if scale is not None:
assert isinstance(scale, (tuple, list)) and len(scale) == 2, \
"scale should be a list or tuple and it must be of length 2."
for s in scale:
if s <= 0:
raise ValueError("scale values should be positive")
self.scale = scale
if shear is not None:
if isinstance(shear, numbers.Number):
if shear < 0:
raise ValueError(
"If shear is a single number, it must be positive.")
self.shear = [shear]
else:
assert isinstance(shear, (tuple, list)) and (len(shear) == 2), \
"shear should be a list or tuple and it must be of length 2."
self.shear = shear
else:
self.shear = shear
def _get_inverse_affine_matrix(self, center, angle, translate, scale,
shear):
from numpy import sin, cos, tan
if isinstance(shear, numbers.Number):
shear = [shear, 0]
if not isinstance(shear, (tuple, list)) and len(shear) == 2:
raise ValueError(
"Shear should be a single value or a tuple/list containing " +
"two values. Got {}".format(shear))
rot = math.radians(angle)
sx, sy = [math.radians(s) for s in shear]
cx, cy = center
tx, ty = translate
# RSS without scaling
a = cos(rot - sy) / cos(sy)
b = -cos(rot - sy) * tan(sx) / cos(sy) - sin(rot)
c = sin(rot - sy) / cos(sy)
d = -sin(rot - sy) * tan(sx) / cos(sy) + cos(rot)
# Inverted rotation matrix with scale and shear
# det([[a, b], [c, d]]) == 1, since det(rotation) = 1 and det(shear) = 1
M = [d, -b, 0, -c, a, 0]
M = [x / scale for x in M]
# Apply inverse of translation and of center translation: RSS^-1 * C^-1 * T^-1
M[2] += M[0] * (-cx - tx) + M[1] * (-cy - ty)
M[5] += M[3] * (-cx - tx) + M[4] * (-cy - ty)
# Apply center translation: C * RSS^-1 * C^-1 * T^-1
M[2] += cx
M[5] += cy
return M
@staticmethod
def get_params(degrees, translate, scale_ranges, shears, height):
angle = sample_sym(degrees)
if translate is not None:
max_dx = translate[0] * height
max_dy = translate[1] * height
translations = (np.round(sample_sym(max_dx)),
np.round(sample_sym(max_dy)))
else:
translations = (0, 0)
if scale_ranges is not None:
scale = sample_uniform(scale_ranges[0], scale_ranges[1])
else:
scale = 1.0
if shears is not None:
if len(shears) == 1:
shear = [sample_sym(shears[0]), 0.]
elif len(shears) == 2:
shear = [sample_sym(shears[0]), sample_sym(shears[1])]
else:
shear = 0.0
return angle, translations, scale, shear
def __call__(self, img):
src_h, src_w = img.shape[:2]
angle, translate, scale, shear = self.get_params(
self.degrees, self.translate, self.scale, self.shear, src_h)
M = self._get_inverse_affine_matrix((src_w / 2, src_h / 2), angle,
(0, 0), scale, shear)
M = np.array(M).reshape(2, 3)
startpoints = [(0, 0), (src_w - 1, 0), (src_w - 1, src_h - 1),
(0, src_h - 1)]
project = lambda x, y, a, b, c: int(a * x + b * y + c)
endpoints = [(project(x, y, *M[0]), project(x, y, *M[1]))
for x, y in startpoints]
rect = cv2.minAreaRect(np.array(endpoints))
bbox = cv2.boxPoints(rect).astype(dtype=np.int)
max_x, max_y = bbox[:, 0].max(), bbox[:, 1].max()
min_x, min_y = bbox[:, 0].min(), bbox[:, 1].min()
dst_w = int(max_x - min_x)
dst_h = int(max_y - min_y)
M[0, 2] += (dst_w - src_w) / 2
M[1, 2] += (dst_h - src_h) / 2
# add translate
dst_w += int(abs(translate[0]))
dst_h += int(abs(translate[1]))
if translate[0] < 0: M[0, 2] += abs(translate[0])
if translate[1] < 0: M[1, 2] += abs(translate[1])
flags = get_interpolation()
return cv2.warpAffine(
img,
M, (dst_w, dst_h),
flags=flags,
borderMode=cv2.BORDER_REPLICATE)
class CVRandomPerspective(object):
def __init__(self, distortion=0.5):
self.distortion = distortion
def get_params(self, width, height, distortion):
offset_h = sample_asym(
distortion * height / 2, size=4).astype(dtype=np.int)
offset_w = sample_asym(
distortion * width / 2, size=4).astype(dtype=np.int)
topleft = (offset_w[0], offset_h[0])
topright = (width - 1 - offset_w[1], offset_h[1])
botright = (width - 1 - offset_w[2], height - 1 - offset_h[2])
botleft = (offset_w[3], height - 1 - offset_h[3])
startpoints = [(0, 0), (width - 1, 0), (width - 1, height - 1),
(0, height - 1)]
endpoints = [topleft, topright, botright, botleft]
return np.array(
startpoints, dtype=np.float32), np.array(
endpoints, dtype=np.float32)
def __call__(self, img):
height, width = img.shape[:2]
startpoints, endpoints = self.get_params(width, height, self.distortion)
M = cv2.getPerspectiveTransform(startpoints, endpoints)
# TODO: more robust way to crop image
rect = cv2.minAreaRect(endpoints)
bbox = cv2.boxPoints(rect).astype(dtype=np.int)
max_x, max_y = bbox[:, 0].max(), bbox[:, 1].max()
min_x, min_y = bbox[:, 0].min(), bbox[:, 1].min()
min_x, min_y = max(min_x, 0), max(min_y, 0)
flags = get_interpolation()
img = cv2.warpPerspective(
img,
M, (max_x, max_y),
flags=flags,
borderMode=cv2.BORDER_REPLICATE)
img = img[min_y:, min_x:]
return img
class CVRescale(object):
def __init__(self, factor=4, base_size=(128, 512)):
""" Define image scales using gaussian pyramid and rescale image to target scale.
Args:
factor: the decayed factor from base size, factor=4 keeps target scale by default.
base_size: base size the build the bottom layer of pyramid
"""
if isinstance(factor, numbers.Number):
self.factor = round(sample_uniform(0, factor))
elif isinstance(factor, (tuple, list)) and len(factor) == 2:
self.factor = round(sample_uniform(factor[0], factor[1]))
else:
raise Exception('factor must be number or list with length 2')
# assert factor is valid
self.base_h, self.base_w = base_size[:2]
def __call__(self, img):
if self.factor == 0:
return img
src_h, src_w = img.shape[:2]
cur_w, cur_h = self.base_w, self.base_h
scale_img = cv2.resize(
img, (cur_w, cur_h), interpolation=get_interpolation())
for _ in range(np.int(self.factor)):
scale_img = cv2.pyrDown(scale_img)
scale_img = cv2.resize(
scale_img, (src_w, src_h), interpolation=get_interpolation())
return scale_img
class CVGaussianNoise(object):
def __init__(self, mean=0, var=20):
self.mean = mean
if isinstance(var, numbers.Number):
self.var = max(int(sample_asym(var)), 1)
elif isinstance(var, (tuple, list)) and len(var) == 2:
self.var = int(sample_uniform(var[0], var[1]))
else:
raise Exception('degree must be number or list with length 2')
def __call__(self, img):
noise = np.random.normal(self.mean, self.var**0.5, img.shape)
img = np.clip(img + noise, 0, 255).astype(np.uint8)
return img
class CVMotionBlur(object):
def __init__(self, degrees=12, angle=90):
if isinstance(degrees, numbers.Number):
self.degree = max(int(sample_asym(degrees)), 1)
elif isinstance(degrees, (tuple, list)) and len(degrees) == 2:
self.degree = int(sample_uniform(degrees[0], degrees[1]))
else:
raise Exception('degree must be number or list with length 2')
self.angle = sample_uniform(-angle, angle)
def __call__(self, img):
M = cv2.getRotationMatrix2D((self.degree // 2, self.degree // 2),
self.angle, 1)
motion_blur_kernel = np.zeros((self.degree, self.degree))
motion_blur_kernel[self.degree // 2, :] = 1
motion_blur_kernel = cv2.warpAffine(motion_blur_kernel, M,
(self.degree, self.degree))
motion_blur_kernel = motion_blur_kernel / self.degree
img = cv2.filter2D(img, -1, motion_blur_kernel)
img = np.clip(img, 0, 255).astype(np.uint8)
return img
class CVGeometry(object):
def __init__(self,
degrees=15,
translate=(0.3, 0.3),
scale=(0.5, 2.),
shear=(45, 15),
distortion=0.5,
p=0.5):
self.p = p
type_p = random.random()
if type_p < 0.33:
self.transforms = CVRandomRotation(degrees=degrees)
elif type_p < 0.66:
self.transforms = CVRandomAffine(
degrees=degrees, translate=translate, scale=scale, shear=shear)
else:
self.transforms = CVRandomPerspective(distortion=distortion)
def __call__(self, img):
if random.random() < self.p:
return self.transforms(img)
else:
return img
class CVDeterioration(object):
def __init__(self, var, degrees, factor, p=0.5):
self.p = p
transforms = []
if var is not None:
transforms.append(CVGaussianNoise(var=var))
if degrees is not None:
transforms.append(CVMotionBlur(degrees=degrees))
if factor is not None:
transforms.append(CVRescale(factor=factor))
random.shuffle(transforms)
transforms = Compose(transforms)
self.transforms = transforms
def __call__(self, img):
if random.random() < self.p:
return self.transforms(img)
else:
return img
class CVColorJitter(object):
def __init__(self,
brightness=0.5,
contrast=0.5,
saturation=0.5,
hue=0.1,
p=0.5):
self.p = p
self.transforms = transforms.ColorJitter(
brightness=brightness,
contrast=contrast,
saturation=saturation,
hue=hue)
def __call__(self, img):
if random.random() < self.p:
return self.transforms(img)
else:
return img
class VLAug(object):
def __init__(self,
geometry_p=0.5,
Deterioration_p=0.25,
ColorJitter_p=0.25,
**kwargs):
self.Geometry = CVGeometry(
degrees=45,
translate=(0.0, 0.0),
scale=(0.5, 2.),
shear=(45, 15),
distortion=0.5,
p=geometry_p)
self.Deterioration = CVDeterioration(
var=20, degrees=6, factor=4, p=Deterioration_p)
self.ColorJitter = CVColorJitter(
brightness=0.5,
contrast=0.5,
saturation=0.5,
hue=0.1,
p=ColorJitter_p)
def __call__(self, data):
img = data['image']
img = self.Geometry(img)
img = self.Deterioration(img)
img = self.ColorJitter(img)
data['image'] = img
return data
if __name__ == '__main__':
geo = CVGeometry(
degrees=45,
translate=(0.0, 0.0),
scale=(0.5, 2.),
shear=(45, 15),
distortion=0.5,
p=1)
det = CVDeterioration(var=20, degrees=6, factor=4, p=1)
color = CVColorJitter(
brightness=0.5, contrast=0.5, saturation=0.5, hue=0.1, p=1)
img = np.ones((64, 256, 3))
img = geo(img)
img = det(img)
img = color(img)
# import pdb
# pdb.set_trace()
# print()
......@@ -35,6 +35,7 @@ from .rec_sar_loss import SARLoss
from .rec_aster_loss import AsterLoss
from .rec_pren_loss import PRENLoss
from .rec_multi_loss import MultiLoss
from .rec_vl_loss import VLLoss
# cls loss
from .cls_loss import ClsLoss
......@@ -61,7 +62,8 @@ def build_loss(config):
'DBLoss', 'PSELoss', 'EASTLoss', 'SASTLoss', 'FCELoss', 'CTCLoss',
'ClsLoss', 'AttentionLoss', 'SRNLoss', 'PGLoss', 'CombinedLoss',
'NRTRLoss', 'TableAttentionLoss', 'SARLoss', 'AsterLoss', 'SDMGRLoss',
'VQASerTokenLayoutLMLoss', 'LossFromOutput', 'PRENLoss', 'MultiLoss'
'VQASerTokenLayoutLMLoss', 'LossFromOutput', 'PRENLoss', 'MultiLoss',
'VLLoss'
]
config = copy.deepcopy(config)
module_name = config.pop('name')
......
......@@ -28,14 +28,14 @@ def build_backbone(config, model_type):
from .rec_mv1_enhance import MobileNetV1Enhance
from .rec_nrtr_mtb import MTB
from .rec_resnet_31 import ResNet31
from .rec_resnet_aster import ResNet_ASTER
from .rec_resnet_aster import ResNet_ASTER, ResNet45
from .rec_micronet import MicroNet
from .rec_efficientb3_pren import EfficientNetb3_PREN
from .rec_svtrnet import SVTRNet
support_dict = [
'MobileNetV1Enhance', 'MobileNetV3', 'ResNet', 'ResNetFPN', 'MTB',
"ResNet31", "ResNet_ASTER", 'MicroNet', 'EfficientNetb3_PREN',
'SVTRNet'
'SVTRNet', 'ResNet45'
]
elif model_type == "e2e":
from .e2e_resnet_vd_pg import ResNet
......
......@@ -20,6 +20,10 @@ import paddle.nn as nn
import sys
import math
from paddle.nn.initializer import KaimingNormal, Constant
zeros_ = Constant(value=0.)
ones_ = Constant(value=1.)
def conv3x3(in_planes, out_planes, stride=1):
......@@ -141,3 +145,110 @@ class ResNet_ASTER(nn.Layer):
return rnn_feat
else:
return cnn_feat
class Block(nn.Layer):
def __init__(self, inplanes, planes, stride=1, downsample=None):
super(Block, self).__init__()
self.conv1 = conv1x1(inplanes, planes)
self.bn1 = nn.BatchNorm2D(planes)
self.relu = nn.ReLU()
self.conv2 = conv3x3(planes, planes, stride)
self.bn2 = nn.BatchNorm2D(planes)
self.downsample = downsample
self.stride = stride
def forward(self, x):
residual = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
if self.downsample is not None:
residual = self.downsample(x)
out += residual
out = self.relu(out)
return out
class ResNet45(nn.Layer):
def __init__(self, in_channels=3, compress_layer=False):
super(ResNet45, self).__init__()
self.compress_layer = compress_layer
self.conv1_new = nn.Conv2D(
in_channels,
32,
kernel_size=(3, 3),
stride=1,
padding=1,
bias_attr=False)
self.bn1 = nn.BatchNorm2D(32)
self.relu = nn.ReLU()
self.inplanes = 32
self.layer1 = self._make_layer(32, 3, [2, 2]) # [32, 128]
self.layer2 = self._make_layer(64, 4, [2, 2]) # [16, 64]
self.layer3 = self._make_layer(128, 6, [2, 2]) # [8, 32]
self.layer4 = self._make_layer(256, 6, [1, 1]) # [8, 32]
self.layer5 = self._make_layer(512, 3, [1, 1]) # [8, 32]
if self.compress_layer:
self.layer6 = nn.Sequential(
nn.Conv2D(
512, 256, kernel_size=(3, 1), padding=(0, 0), stride=(1,
1)),
nn.BatchNorm(256),
nn.ReLU())
self.out_channels = 256
else:
self.out_channels = 512
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, nn.Conv2D):
KaimingNormal(m.weight)
elif isinstance(m, nn.BatchNorm):
ones_(m.weight)
zeros_(m.bias)
def _make_layer(self, planes, blocks, stride):
downsample = None
if stride != [1, 1] or self.inplanes != planes:
downsample = nn.Sequential(
conv1x1(self.inplanes, planes, stride), nn.BatchNorm2D(planes))
layers = []
layers.append(Block(self.inplanes, planes, stride, downsample))
self.inplanes = planes
for _ in range(1, blocks):
layers.append(Block(self.inplanes, planes))
return nn.Sequential(*layers)
def forward(self, x):
x = self.conv1_new(x)
x = self.bn1(x)
x = self.relu(x)
x1 = self.layer1(x)
x2 = self.layer2(x1)
x3 = self.layer3(x2)
x4 = self.layer4(x3)
x5 = self.layer5(x4)
if not self.compress_layer:
return x5
else:
x6 = self.layer6(x5)
return x6
if __name__ == '__main__':
model = ResNet45()
x = paddle.rand([1, 3, 64, 256])
x = paddle.to_tensor(x)
print(x.shape)
out = model(x)
print(out.shape)
......@@ -33,6 +33,7 @@ def build_head(config):
from .rec_aster_head import AsterHead
from .rec_pren_head import PRENHead
from .rec_multi_head import MultiHead
from .rec_visionlan_head import VLHead
# cls head
from .cls_head import ClsHead
......@@ -46,7 +47,7 @@ def build_head(config):
'DBHead', 'PSEHead', 'FCEHead', 'EASTHead', 'SASTHead', 'CTCHead',
'ClsHead', 'AttentionHead', 'SRNHead', 'PGHead', 'Transformer',
'TableAttentionHead', 'SARHead', 'AsterHead', 'SDMGRHead', 'PRENHead',
'MultiHead'
'MultiHead', 'VLHead'
]
#table head
......
# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import paddle
from paddle import ParamAttr
import paddle.nn as nn
import paddle.nn.functional as F
from paddle.nn.initializer import Normal, XavierNormal
import numpy as np
from ppocr.modeling.backbones.rec_resnet_aster import ResNet45
class PositionalEncoding(nn.Layer):
def __init__(self, d_hid, n_position=200):
super(PositionalEncoding, self).__init__()
self.register_buffer(
'pos_table', self._get_sinusoid_encoding_table(n_position, d_hid))
def _get_sinusoid_encoding_table(self, n_position, d_hid):
''' Sinusoid position encoding table '''
def get_position_angle_vec(position):
return [
position / np.power(10000, 2 * (hid_j // 2) / d_hid)
for hid_j in range(d_hid)
]
sinusoid_table = np.array(
[get_position_angle_vec(pos_i) for pos_i in range(n_position)])
sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i
sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1
sinusoid_table = paddle.to_tensor(sinusoid_table, dtype='float32')
sinusoid_table = paddle.unsqueeze(sinusoid_table, axis=0)
return sinusoid_table
def forward(self, x):
return x + self.pos_table[:, :x.shape[1]].clone().detach()
class ScaledDotProductAttention(nn.Layer):
"Scaled Dot-Product Attention"
def __init__(self, temperature, attn_dropout=0.1):
super(ScaledDotProductAttention, self).__init__()
self.temperature = temperature
self.dropout = nn.Dropout(attn_dropout)
self.softmax = nn.Softmax(axis=2)
def forward(self, q, k, v, mask=None):
k = paddle.transpose(k, perm=[0, 2, 1])
attn = paddle.bmm(q, k)
attn = attn / self.temperature
if mask is not None:
attn = attn.masked_fill(mask, -1e9)
if mask.dim() == 3:
mask = paddle.unsqueeze(mask, axis=1)
elif mask.dim() == 2:
mask = paddle.unsqueeze(mask, axis=1)
mask = paddle.unsqueeze(mask, axis=1)
repeat_times = [
attn.shape[1] // mask.shape[1], attn.shape[2] // mask.shape[2]
]
mask = paddle.tile(mask, [1, repeat_times[0], repeat_times[1], 1])
attn[mask == 0] = -1e9
attn = self.softmax(attn)
attn = self.dropout(attn)
output = paddle.bmm(attn, v)
return output
class MultiHeadAttention(nn.Layer):
" Multi-Head Attention module"
def __init__(self, n_head, d_model, d_k, d_v, dropout=0.1):
super(MultiHeadAttention, self).__init__()
self.n_head = n_head
self.d_k = d_k
self.d_v = d_v
self.w_qs = nn.Linear(
d_model,
n_head * d_k,
weight_attr=ParamAttr(initializer=Normal(
mean=0, std=np.sqrt(2.0 / (d_model + d_k)))))
self.w_ks = nn.Linear(
d_model,
n_head * d_k,
weight_attr=ParamAttr(initializer=Normal(
mean=0, std=np.sqrt(2.0 / (d_model + d_k)))))
self.w_vs = nn.Linear(
d_model,
n_head * d_v,
weight_attr=ParamAttr(initializer=Normal(
mean=0, std=np.sqrt(2.0 / (d_model + d_v)))))
self.attention = ScaledDotProductAttention(temperature=np.power(d_k,
0.5))
self.layer_norm = nn.LayerNorm(d_model)
self.fc = nn.Linear(
n_head * d_v,
d_model,
weight_attr=ParamAttr(initializer=XavierNormal()))
self.dropout = nn.Dropout(dropout)
def forward(self, q, k, v, mask=None):
d_k, d_v, n_head = self.d_k, self.d_v, self.n_head
sz_b, len_q, _ = q.shape
sz_b, len_k, _ = k.shape
sz_b, len_v, _ = v.shape
residual = q
q = self.w_qs(q)
q = paddle.reshape(
q, shape=[-1, len_q, n_head, d_k]) # 4*21*512 ---- 4*21*8*64
k = self.w_ks(k)
k = paddle.reshape(k, shape=[-1, len_k, n_head, d_k])
v = self.w_vs(v)
v = paddle.reshape(v, shape=[-1, len_v, n_head, d_v])
q = paddle.transpose(q, perm=[2, 0, 1, 3])
q = paddle.reshape(q, shape=[-1, len_q, d_k]) # (n*b) x lq x dk
k = paddle.transpose(k, perm=[2, 0, 1, 3])
k = paddle.reshape(k, shape=[-1, len_k, d_k]) # (n*b) x lk x dk
v = paddle.transpose(v, perm=[2, 0, 1, 3])
v = paddle.reshape(v, shape=[-1, len_v, d_v]) # (n*b) x lv x dv
mask = paddle.tile(
mask,
[n_head, 1, 1]) if mask is not None else None # (n*b) x .. x ..
output = self.attention(q, k, v, mask=mask)
output = paddle.reshape(output, shape=[n_head, -1, len_q, d_v])
output = paddle.transpose(output, perm=[1, 2, 0, 3])
output = paddle.reshape(
output, shape=[-1, len_q, n_head * d_v]) # b x lq x (n*dv)
output = self.dropout(self.fc(output))
output = self.layer_norm(output + residual)
return output
class PositionwiseFeedForward(nn.Layer):
def __init__(self, d_in, d_hid, dropout=0.1):
super(PositionwiseFeedForward, self).__init__()
self.w_1 = nn.Conv1D(d_in, d_hid, 1) # position-wise
self.w_2 = nn.Conv1D(d_hid, d_in, 1) # position-wise
self.layer_norm = nn.LayerNorm(d_in)
self.dropout = nn.Dropout(dropout)
def forward(self, x):
residual = x
x = paddle.transpose(x, perm=[0, 2, 1])
x = self.w_2(F.relu(self.w_1(x)))
x = paddle.transpose(x, perm=[0, 2, 1])
x = self.dropout(x)
x = self.layer_norm(x + residual)
return x
class EncoderLayer(nn.Layer):
''' Compose with two layers '''
def __init__(self, d_model, d_inner, n_head, d_k, d_v, dropout=0.1):
super(EncoderLayer, self).__init__()
self.slf_attn = MultiHeadAttention(
n_head, d_model, d_k, d_v, dropout=dropout)
self.pos_ffn = PositionwiseFeedForward(
d_model, d_inner, dropout=dropout)
def forward(self, enc_input, slf_attn_mask=None):
enc_output = self.slf_attn(
enc_input, enc_input, enc_input, mask=slf_attn_mask)
enc_output = self.pos_ffn(enc_output)
return enc_output
class Transformer_Encoder(nn.Layer):
def __init__(self,
n_layers=2,
n_head=8,
d_word_vec=512,
d_k=64,
d_v=64,
d_model=512,
d_inner=2048,
dropout=0.1,
n_position=256):
super(Transformer_Encoder, self).__init__()
self.position_enc = PositionalEncoding(
d_word_vec, n_position=n_position)
self.dropout = nn.Dropout(p=dropout)
self.layer_stack = nn.LayerList([
EncoderLayer(
d_model, d_inner, n_head, d_k, d_v, dropout=dropout)
for _ in range(n_layers)
])
self.layer_norm = nn.LayerNorm(d_model, epsilon=1e-6)
def forward(self, enc_output, src_mask, return_attns=False):
enc_output = self.dropout(
self.position_enc(enc_output)) # position embeding
for enc_layer in self.layer_stack:
enc_output = enc_layer(enc_output, slf_attn_mask=src_mask)
enc_output = self.layer_norm(enc_output)
return enc_output
class PP_layer(nn.Layer):
def __init__(self, n_dim=512, N_max_character=25, n_position=256):
super(PP_layer, self).__init__()
self.character_len = N_max_character
self.f0_embedding = nn.Embedding(N_max_character, n_dim)
self.w0 = nn.Linear(N_max_character, n_position)
self.wv = nn.Linear(n_dim, n_dim)
self.we = nn.Linear(n_dim, N_max_character)
self.active = nn.Tanh()
self.softmax = nn.Softmax(axis=2)
def forward(self, enc_output):
# enc_output: b,256,512
reading_order = paddle.arange(self.character_len, dtype='int64')
reading_order = reading_order.unsqueeze(0).expand(
[enc_output.shape[0], -1]) # (S,) -> (B, S)
reading_order = self.f0_embedding(reading_order) # b,25,512
# calculate attention
reading_order = paddle.transpose(reading_order, perm=[0, 2, 1])
t = self.w0(reading_order) # b,512,256
t = self.active(
paddle.transpose(
t, perm=[0, 2, 1]) + self.wv(enc_output)) # b,256,512
t = self.we(t) # b,256,25
t = self.softmax(paddle.transpose(t, perm=[0, 2, 1])) # b,25,256
g_output = paddle.bmm(t, enc_output) # b,25,512
return g_output
class Prediction(nn.Layer):
def __init__(self,
n_dim=512,
n_position=256,
N_max_character=25,
n_class=37):
super(Prediction, self).__init__()
self.pp = PP_layer(
n_dim=n_dim, N_max_character=N_max_character, n_position=n_position)
self.pp_share = PP_layer(
n_dim=n_dim, N_max_character=N_max_character, n_position=n_position)
self.w_vrm = nn.Linear(n_dim, n_class) # output layer
self.w_share = nn.Linear(n_dim, n_class) # output layer
self.nclass = n_class
def forward(self, cnn_feature, f_res, f_sub, train_mode=False,
use_mlm=True):
if train_mode:
if not use_mlm:
g_output = self.pp(cnn_feature) # b,25,512
g_output = self.w_vrm(g_output)
f_res = 0
f_sub = 0
return g_output, f_res, f_sub
g_output = self.pp(cnn_feature) # b,25,512
f_res = self.pp_share(f_res)
f_sub = self.pp_share(f_sub)
g_output = self.w_vrm(g_output)
f_res = self.w_share(f_res)
f_sub = self.w_share(f_sub)
return g_output, f_res, f_sub
else:
g_output = self.pp(cnn_feature) # b,25,512
g_output = self.w_vrm(g_output)
return g_output
class MLM(nn.Layer):
"Architecture of MLM"
def __init__(self, n_dim=512, n_position=256, max_text_length=25):
super(MLM, self).__init__()
self.MLM_SequenceModeling_mask = Transformer_Encoder(
n_layers=2, n_position=n_position)
self.MLM_SequenceModeling_WCL = Transformer_Encoder(
n_layers=1, n_position=n_position)
self.pos_embedding = nn.Embedding(max_text_length, n_dim)
self.w0_linear = nn.Linear(1, n_position)
self.wv = nn.Linear(n_dim, n_dim)
self.active = nn.Tanh()
self.we = nn.Linear(n_dim, 1)
self.sigmoid = nn.Sigmoid()
def forward(self, x, label_pos):
# transformer unit for generating mask_c
feature_v_seq = self.MLM_SequenceModeling_mask(x, src_mask=None)
# position embedding layer
label_pos = paddle.to_tensor(label_pos, dtype='int64')
pos_emb = self.pos_embedding(label_pos)
pos_emb = self.w0_linear(paddle.unsqueeze(pos_emb, axis=2))
pos_emb = paddle.transpose(pos_emb, perm=[0, 2, 1])
# fusion position embedding with features V & generate mask_c
att_map_sub = self.active(pos_emb + self.wv(feature_v_seq))
att_map_sub = self.we(att_map_sub) # b,256,1
att_map_sub = paddle.transpose(att_map_sub, perm=[0, 2, 1])
att_map_sub = self.sigmoid(att_map_sub) # b,1,256
# WCL
## generate inputs for WCL
att_map_sub = paddle.transpose(att_map_sub, perm=[0, 2, 1])
f_res = x * (1 - att_map_sub) # second path with remaining string
f_sub = x * att_map_sub # first path with occluded character
## transformer units in WCL
f_res = self.MLM_SequenceModeling_WCL(f_res, src_mask=None)
f_sub = self.MLM_SequenceModeling_WCL(f_sub, src_mask=None)
return f_res, f_sub, att_map_sub
def trans_1d_2d(x):
b, w_h, c = x.shape # b, 256, 512
x = paddle.transpose(x, perm=[0, 2, 1])
x = paddle.reshape(x, [-1, c, 32, 8])
x = paddle.transpose(x, perm=[0, 1, 3, 2]) # [b, c, 8, 32]
return x
class MLM_VRM(nn.Layer):
"""
MLM+VRM, MLM is only used in training.
ratio controls the occluded number in a batch.
The pipeline of VisionLAN in testing is very concise with only a backbone + sequence modeling(transformer unit) + prediction layer(pp layer).
x: input image
label_pos: character index
training_step: LF or LA process
output
text_pre: prediction of VRM
test_rem: prediction of remaining string in MLM
text_mas: prediction of occluded character in MLM
mask_c_show: visualization of Mask_c
"""
def __init__(self,
n_layers=3,
n_position=256,
n_dim=512,
max_text_length=25,
nclass=37):
super(MLM_VRM, self).__init__()
self.MLM = MLM(n_dim=n_dim,
n_position=n_position,
max_text_length=max_text_length)
self.SequenceModeling = Transformer_Encoder(
n_layers=n_layers, n_position=n_position)
self.Prediction = Prediction(
n_dim=n_dim,
n_position=n_position,
N_max_character=max_text_length +
1, # N_max_character = 1 eos + 25 characters
n_class=nclass)
self.nclass = nclass
self.max_text_length = max_text_length
def forward(self, x, label_pos, training_step, train_mode=False):
b, c, h, w = x.shape
nT = self.max_text_length
x = paddle.transpose(x, perm=[0, 1, 3, 2])
x = paddle.reshape(x, [-1, c, h * w])
x = paddle.transpose(x, perm=[0, 2, 1])
if train_mode:
if training_step == 'LF_1':
f_res = 0
f_sub = 0
x = self.SequenceModeling(x, src_mask=None)
text_pre, test_rem, text_mas = self.Prediction(
x, f_res, f_sub, train_mode=True, use_mlm=False)
return text_pre, text_pre, text_pre, text_pre
elif training_step == 'LF_2':
# MLM
f_res, f_sub, mask_c = self.MLM(x, label_pos)
x = self.SequenceModeling(x, src_mask=None)
text_pre, test_rem, text_mas = self.Prediction(
x, f_res, f_sub, train_mode=True)
mask_c_show = trans_1d_2d(mask_c)
return text_pre, test_rem, text_mas, mask_c_show
elif training_step == 'LA':
# MLM
f_res, f_sub, mask_c = self.MLM(x, label_pos)
## use the mask_c (1 for occluded character and 0 for remaining characters) to occlude input
## ratio controls the occluded number in a batch
character_mask = paddle.zeros_like(mask_c)
ratio = b // 2
if ratio >= 1:
with paddle.no_grad():
character_mask[0:ratio, :, :] = mask_c[0:ratio, :, :]
else:
character_mask = mask_c
x = x * (1 - character_mask)
# VRM
## transformer unit for VRM
x = self.SequenceModeling(x, src_mask=None)
## prediction layer for MLM and VSR
text_pre, test_rem, text_mas = self.Prediction(
x, f_res, f_sub, train_mode=True)
mask_c_show = trans_1d_2d(mask_c)
return text_pre, test_rem, text_mas, mask_c_show
else:
raise NotImplementedError
else: # VRM is only used in the testing stage
f_res = 0
f_sub = 0
contextual_feature = self.SequenceModeling(x, src_mask=None)
text_pre = self.Prediction(
contextual_feature,
f_res,
f_sub,
train_mode=False,
use_mlm=False)
text_pre = paddle.transpose(
text_pre, perm=[1, 0, 2]) # (26, b, 37))
lenText = nT
nsteps = nT
out_res = paddle.zeros(
shape=[lenText, b, self.nclass], dtype=x.dtype) # (25, b, 37)
out_length = paddle.zeros(shape=[b], dtype=x.dtype)
now_step = 0
for _ in range(nsteps):
if 0 in out_length and now_step < nsteps:
tmp_result = text_pre[now_step, :, :]
out_res[now_step] = tmp_result
tmp_result = tmp_result.topk(1)[1].squeeze(axis=1)
for j in range(b):
if out_length[j] == 0 and tmp_result[j] == 0:
out_length[j] = now_step + 1
now_step += 1
# while 0 in out_length and now_step < nsteps:
# tmp_result = text_pre[now_step, :, :]
# out_res[now_step] = tmp_result
# tmp_result = tmp_result.topk(1)[1].squeeze(axis=1)
# for j in range(b):
# if out_length[j] == 0 and tmp_result[j] == 0:
# out_length[j] = now_step + 1
# now_step += 1
for j in range(0, b):
if int(out_length[j]) == 0:
out_length[j] = nsteps
start = 0
output = paddle.zeros(
shape=[int(out_length.sum()), self.nclass], dtype=x.dtype)
for i in range(0, b):
cur_length = int(out_length[i])
output[start:start + cur_length] = out_res[0:cur_length, i, :]
start += cur_length
return output, out_length
class VLHead(nn.Layer):
"""
Architecture of VisionLAN
"""
def __init__(self,
in_channels,
out_channels=36,
n_layers=3,
n_position=256,
n_dim=512,
max_text_length=25,
training_step='LA'):
super(VLHead, self).__init__()
self.MLM_VRM = MLM_VRM(
n_layers=n_layers,
n_position=n_position,
n_dim=n_dim,
max_text_length=max_text_length,
nclass=out_channels + 1)
self.training_step = training_step
def forward(self, feat, targets=None):
if self.training:
label_pos = targets[-2]
text_pre, test_rem, text_mas, mask_map = self.MLM_VRM(
feat, label_pos, self.training_step, train_mode=True)
return text_pre, test_rem, text_mas, mask_map
else:
output, out_length = self.MLM_VRM(
feat, targets, self.training_step, train_mode=False)
return output, out_length
......@@ -27,7 +27,7 @@ from .sast_postprocess import SASTPostProcess
from .fce_postprocess import FCEPostProcess
from .rec_postprocess import CTCLabelDecode, AttnLabelDecode, SRNLabelDecode, \
DistillationCTCLabelDecode, TableLabelDecode, NRTRLabelDecode, SARLabelDecode, \
SEEDLabelDecode, PRENLabelDecode
SEEDLabelDecode, PRENLabelDecode, VLLabelDecode
from .cls_postprocess import ClsPostProcess
from .pg_postprocess import PGPostProcess
from .vqa_token_ser_layoutlm_postprocess import VQASerTokenLayoutLMPostProcess
......@@ -42,7 +42,7 @@ def build_post_process(config, global_config=None):
'DistillationDBPostProcess', 'NRTRLabelDecode', 'SARLabelDecode',
'SEEDLabelDecode', 'VQASerTokenLayoutLMPostProcess',
'VQAReTokenLayoutLMPostProcess', 'PRENLabelDecode',
'DistillationSARLabelDecode'
'DistillationSARLabelDecode', 'VLLabelDecode'
]
if config['name'] == 'PSEPostProcess':
......
......@@ -27,7 +27,8 @@ class BaseRecLabelDecode(object):
self.character_str = []
if character_dict_path is None:
self.character_str = "0123456789abcdefghijklmnopqrstuvwxyz"
# self.character_str = "0123456789abcdefghijklmnopqrstuvwxyz"
self.character_str = "abcdefghijklmnopqrstuvwxyz1234567890"
dict_character = list(self.character_str)
else:
with open(character_dict_path, "rb") as fin:
......@@ -752,3 +753,70 @@ class PRENLabelDecode(BaseRecLabelDecode):
return text
label = self.decode(label)
return text, label
class VLLabelDecode(BaseRecLabelDecode):
""" Convert between text-label and text-index """
def __init__(self, character_dict_path=None, use_space_char=False,
**kwargs):
super(VLLabelDecode, self).__init__(character_dict_path, use_space_char)
def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
""" convert text-index into text-label. """
result_list = []
ignored_tokens = self.get_ignored_tokens()
batch_size = len(text_index)
for batch_idx in range(batch_size):
selection = np.ones(len(text_index[batch_idx]), dtype=bool)
if is_remove_duplicate:
selection[1:] = text_index[batch_idx][1:] != text_index[
batch_idx][:-1]
for ignored_token in ignored_tokens:
selection &= text_index[batch_idx] != ignored_token
char_list = [
self.character[text_id - 1]
for text_id in text_index[batch_idx][selection]
]
if text_prob is not None:
conf_list = text_prob[batch_idx][selection]
else:
conf_list = [1] * len(selection)
if len(conf_list) == 0:
conf_list = [0]
text = ''.join(char_list)
result_list.append((text, np.mean(conf_list).tolist()))
return result_list
def __call__(self, preds, label=None, length=None, *args, **kwargs):
if len(preds) == 2: # eval mode
net_out, length = preds
else: # train mode
net_out = preds[0]
length = length
net_out = paddle.concat([t[:l] for t, l in zip(net_out, length)])
text = []
if not isinstance(net_out, paddle.Tensor):
net_out = paddle.to_tensor(net_out, dtype='float32')
# import pdb
# pdb.set_trace()
net_out = F.softmax(net_out, axis=1)
for i in range(0, length.shape[0]):
preds_idx = net_out[int(length[:i].sum()):int(length[:i].sum(
) + length[i])].topk(1)[1][:, 0].tolist()
preds_text = ''.join([
self.character[idx - 1]
if idx > 0 and idx <= len(self.character) else ''
for idx in preds_idx
])
preds_prob = net_out[int(length[:i].sum()):int(length[:i].sum(
) + length[i])].topk(1)[0][:, 0]
preds_prob = paddle.exp(
paddle.log(preds_prob).sum() / (preds_prob.shape[0] + 1e-6))
text.append((preds_text, preds_prob))
if label is None:
return text
label = self.decode(label)
return text, label
......@@ -73,7 +73,7 @@ def main():
config['Architecture']["Head"]['out_channels'] = char_num
model = build_model(config['Architecture'])
extra_input_models = ["SRN", "NRTR", "SAR", "SEED", "SVTR"]
extra_input_models = ["SRN", "NRTR", "SAR", "SEED", "SVTR", "VisionLAN"]
extra_input = False
if config['Architecture']['algorithm'] == 'Distillation':
for key in config['Architecture']["Models"]:
......
......@@ -55,7 +55,7 @@ def export_single_model(model, arch_config, save_path, logger, quanter=None):
shape=[None, 3, 48, 160], dtype="float32"),
]
model = to_static(model, input_spec=other_shape)
elif arch_config["algorithm"] == "SVTR":
elif arch_config["algorithm"] in ["SVTR", "VisionLAN"]:
if arch_config["Head"]["name"] == 'MultiHead':
other_shape = [
paddle.static.InputSpec(
......
......@@ -69,6 +69,12 @@ class TextRecognizer(object):
"character_dict_path": args.rec_char_dict_path,
"use_space_char": args.use_space_char
}
elif self.rec_algorithm == "VisionLAN":
postprocess_params = {
'name': 'VLLabelDecode',
"character_dict_path": args.rec_char_dict_path,
"use_space_char": args.use_space_char
}
self.postprocess_op = build_post_process(postprocess_params)
self.predictor, self.input_tensor, self.output_tensors, self.config = \
utility.create_predictor(args, 'rec', logger)
......@@ -143,6 +149,15 @@ class TextRecognizer(object):
resized_image /= 0.5
return resized_image
def resize_norm_img_vl(self, img, image_shape):
imgC, imgH, imgW = image_shape
resized_image = cv2.resize(
img, (imgW, imgH), interpolation=cv2.INTER_LINEAR)
resized_image = resized_image.astype('float32')
resized_image = resized_image.transpose((2, 0, 1)) / 255
return resized_image
def resize_norm_img_srn(self, img, image_shape):
imgC, imgH, imgW = image_shape
......@@ -300,6 +315,11 @@ class TextRecognizer(object):
self.rec_image_shape)
norm_img = norm_img[np.newaxis, :]
norm_img_batch.append(norm_img)
elif self.rec_algorithm == "VisionLAN":
norm_img = self.resize_norm_img_vl(img_list[indices[ino]],
self.rec_image_shape)
norm_img = norm_img[np.newaxis, :]
norm_img_batch.append(norm_img)
else:
norm_img = self.resize_norm_img(img_list[indices[ino]],
max_wh_ratio)
......
......@@ -207,7 +207,7 @@ def train(config,
model.train()
use_srn = config['Architecture']['algorithm'] == "SRN"
extra_input_models = ["SRN", "NRTR", "SAR", "SEED", "SVTR"]
extra_input_models = ["SRN", "NRTR", "SAR", "SEED", "SVTR", "VisionLAN"]
extra_input = False
if config['Architecture']['algorithm'] == 'Distillation':
for key in config['Architecture']["Models"]:
......@@ -249,7 +249,6 @@ def train(config,
images = batch[0]
if use_srn:
model_average = True
# use amp
if scaler:
with paddle.amp.auto_cast():
......@@ -264,7 +263,6 @@ def train(config,
preds = model(batch)
else:
preds = model(images)
loss = loss_class(preds, batch)
avg_loss = loss['loss']
......@@ -286,6 +284,9 @@ def train(config,
]: # for multi head loss
post_result = post_process_class(
preds['ctc'], batch[1]) # for CTC head out
elif config['Loss']['name'] in ['VLLoss']:
post_result = post_process_class(preds, batch[1],
batch[-1])
else:
post_result = post_process_class(preds, batch[1])
eval_class(post_result, batch)
......@@ -307,7 +308,8 @@ def train(config,
train_stats.update(stats)
if log_writer is not None and dist.get_rank() == 0:
log_writer.log_metrics(metrics=train_stats.get(), prefix="TRAIN", step=global_step)
log_writer.log_metrics(
metrics=train_stats.get(), prefix="TRAIN", step=global_step)
if dist.get_rank() == 0 and (
(global_step > 0 and global_step % print_batch_step == 0) or
......@@ -354,7 +356,8 @@ def train(config,
# logger metric
if log_writer is not None:
log_writer.log_metrics(metrics=cur_metric, prefix="EVAL", step=global_step)
log_writer.log_metrics(
metrics=cur_metric, prefix="EVAL", step=global_step)
if cur_metric[main_indicator] >= best_model_dict[
main_indicator]:
......@@ -377,11 +380,18 @@ def train(config,
logger.info(best_str)
# logger best metric
if log_writer is not None:
log_writer.log_metrics(metrics={
"best_{}".format(main_indicator): best_model_dict[main_indicator]
}, prefix="EVAL", step=global_step)
log_writer.log_model(is_best=True, prefix="best_accuracy", metadata=best_model_dict)
log_writer.log_metrics(
metrics={
"best_{}".format(main_indicator):
best_model_dict[main_indicator]
},
prefix="EVAL",
step=global_step)
log_writer.log_model(
is_best=True,
prefix="best_accuracy",
metadata=best_model_dict)
reader_start = time.time()
if dist.get_rank() == 0:
......@@ -413,7 +423,8 @@ def train(config,
epoch=epoch,
global_step=global_step)
if log_writer is not None:
log_writer.log_model(is_best=False, prefix='iter_epoch_{}'.format(epoch))
log_writer.log_model(
is_best=False, prefix='iter_epoch_{}'.format(epoch))
best_str = 'best metric, {}'.format(', '.join(
['{}: {}'.format(k, v) for k, v in best_model_dict.items()]))
......@@ -451,7 +462,6 @@ def eval(model,
preds = model(batch)
else:
preds = model(images)
batch_numpy = []
for item in batch:
if isinstance(item, paddle.Tensor):
......@@ -564,7 +574,8 @@ def preprocess(is_train=False):
assert alg in [
'EAST', 'DB', 'SAST', 'Rosetta', 'CRNN', 'STARNet', 'RARE', 'SRN',
'CLS', 'PGNet', 'Distillation', 'NRTR', 'TableAttn', 'SAR', 'PSE',
'SEED', 'SDMGR', 'LayoutXLM', 'LayoutLM', 'PREN', 'FCE', 'SVTR'
'SEED', 'SDMGR', 'LayoutXLM', 'LayoutLM', 'PREN', 'FCE', 'SVTR',
'VisionLAN'
]
if use_xpu:
......@@ -583,9 +594,10 @@ def preprocess(is_train=False):
if 'use_visualdl' in config['Global'] and config['Global']['use_visualdl']:
save_model_dir = config['Global']['save_model_dir']
vdl_writer_path = '{}/vdl/'.format(save_model_dir)
log_writer = VDLLogger(save_model_dir)
log_writer = VDLLogger(vdl_writer_path)
loggers.append(log_writer)
if ('use_wandb' in config['Global'] and config['Global']['use_wandb']) or 'wandb' in config:
if ('use_wandb' in config['Global'] and
config['Global']['use_wandb']) or 'wandb' in config:
save_dir = config['Global']['save_model_dir']
wandb_writer_path = "{}/wandb".format(save_dir)
if "wandb" in config:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册