提交 a3a09515 编写于 作者: A andyjpaddle

add vl

上级 0401e520
Global:
use_gpu: true
epoch_num: 8
log_smooth_window: 200
print_batch_step: 200
save_model_dir: /paddle/backup/visionlan/LA_v2
save_epoch_step: 1
# evaluation is run every 2000 iterations
eval_batch_step: [0, 2000]
cal_metric_during_train: True
pretrained_model: ./pretrained_model/LF_2_ocr
checkpoints:
save_inference_dir:
use_visualdl: True
infer_img: doc/imgs_words/en/word_2.png
# for data or label process
character_dict_path: ppocr/utils/dict36.txt
max_text_length: &max_text_length 25
training_step: &training_step LA
infer_mode: False
use_space_char: False
save_res_path: ./output/rec/predicts_visionlan.txt
Optimizer:
name: Adam
beta1: 0.9
beta2: 0.999
clip_norm: 20.0
group_lr: true
training_step: *training_step
lr:
name: Piecewise
decay_epochs: [6]
values: [0.0001, 0.00001]
regularizer:
name: 'L2'
factor: 0
Architecture:
model_type: rec
algorithm: VisionLAN
Transform:
Backbone:
name: ResNet45
strides: [2, 2, 2, 1, 1]
Head:
name: VLHead
n_layers: 3
n_position: 256
n_dim: 512
max_text_length: *max_text_length
training_step: *training_step
Loss:
name: VLLoss
mode: *training_step
weight_res: 0.5
weight_mas: 0.5
PostProcess:
name: VLLabelDecode
Metric:
name: RecMetric
is_filter: true
Train:
dataset:
name: LMDBDataSet
data_dir: ./train_data/data_lmdb_release/training/
transforms:
- DecodeImage: # load image
img_mode: RGB
channel_first: False
- ABINetRecAug:
- VLLabelEncode: # Class handling label
- VLRecResizeImg:
image_shape: [3, 64, 256]
- KeepKeys:
keep_keys: ['image', 'label', 'label_res', 'label_sub', 'label_id', 'length'] # dataloader will return list in this order
loader:
shuffle: True
batch_size_per_card: 220
drop_last: True
num_workers: 4
Eval:
dataset:
name: LMDBDataSet
data_dir: ./train_data/data_lmdb_release/validation/
transforms:
- DecodeImage: # load image
img_mode: RGB
channel_first: False
- VLLabelEncode: # Class handling label
- VLRecResizeImg:
image_shape: [3, 64, 256]
- KeepKeys:
keep_keys: ['image', 'label', 'label_res', 'label_sub', 'label_id', 'length'] # dataloader will return list in this order
loader:
shuffle: False
drop_last: False
batch_size_per_card: 64
num_workers: 4
......@@ -99,12 +99,13 @@ class BaseRecLabelEncode(object):
def __init__(self,
max_text_length,
character_dict_path=None,
use_space_char=False):
use_space_char=False,
lower=False):
self.max_text_len = max_text_length
self.beg_str = "sos"
self.end_str = "eos"
self.lower = False
self.lower = lower
if character_dict_path is None:
logger = get_logger()
......@@ -1227,9 +1228,10 @@ class VLLabelEncode(BaseRecLabelEncode):
max_text_length,
character_dict_path=None,
use_space_char=False,
lower=True,
**kwargs):
super(VLLabelEncode, self).__init__(max_text_length,
character_dict_path, use_space_char)
super(VLLabelEncode, self).__init__(
max_text_length, character_dict_path, use_space_char, lower)
def __call__(self, data):
text = data['label'] # original string
......
......@@ -13,6 +13,5 @@
# limitations under the License.
from .augment import tia_perspective, tia_distort, tia_stretch
from .vl_aug import VLAug
__all__ = ['tia_distort', 'tia_stretch', 'tia_perspective', 'VLAug']
__all__ = ['tia_distort', 'tia_stretch', 'tia_perspective']
# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import numbers
import random
import cv2
import numpy as np
from PIL import Image
from paddle.vision import transforms
from paddle.vision.transforms import Compose
def sample_asym(magnitude, size=None):
return np.random.beta(1, 4, size) * magnitude
def sample_sym(magnitude, size=None):
return (np.random.beta(4, 4, size=size) - 0.5) * 2 * magnitude
def sample_uniform(low, high, size=None):
return np.random.uniform(low, high, size=size)
def get_interpolation(type='random'):
if type == 'random':
choice = [
cv2.INTER_NEAREST, cv2.INTER_LINEAR, cv2.INTER_CUBIC, cv2.INTER_AREA
]
interpolation = choice[random.randint(0, len(choice) - 1)]
elif type == 'nearest':
interpolation = cv2.INTER_NEAREST
elif type == 'linear':
interpolation = cv2.INTER_LINEAR
elif type == 'cubic':
interpolation = cv2.INTER_CUBIC
elif type == 'area':
interpolation = cv2.INTER_AREA
else:
raise TypeError(
'Interpolation types only nearest, linear, cubic, area are supported!'
)
return interpolation
class CVRandomRotation(object):
def __init__(self, degrees=15):
assert isinstance(degrees,
numbers.Number), "degree should be a single number."
assert degrees >= 0, "degree must be positive."
self.degrees = degrees
@staticmethod
def get_params(degrees):
return sample_sym(degrees)
def __call__(self, img):
angle = self.get_params(self.degrees)
src_h, src_w = img.shape[:2]
M = cv2.getRotationMatrix2D(
center=(src_w / 2, src_h / 2), angle=angle, scale=1.0)
abs_cos, abs_sin = abs(M[0, 0]), abs(M[0, 1])
dst_w = int(src_h * abs_sin + src_w * abs_cos)
dst_h = int(src_h * abs_cos + src_w * abs_sin)
M[0, 2] += (dst_w - src_w) / 2
M[1, 2] += (dst_h - src_h) / 2
flags = get_interpolation()
return cv2.warpAffine(
img,
M, (dst_w, dst_h),
flags=flags,
borderMode=cv2.BORDER_REPLICATE)
class CVRandomAffine(object):
def __init__(self, degrees, translate=None, scale=None, shear=None):
assert isinstance(degrees,
numbers.Number), "degree should be a single number."
assert degrees >= 0, "degree must be positive."
self.degrees = degrees
if translate is not None:
assert isinstance(translate, (tuple, list)) and len(translate) == 2, \
"translate should be a list or tuple and it must be of length 2."
for t in translate:
if not (0.0 <= t <= 1.0):
raise ValueError(
"translation values should be between 0 and 1")
self.translate = translate
if scale is not None:
assert isinstance(scale, (tuple, list)) and len(scale) == 2, \
"scale should be a list or tuple and it must be of length 2."
for s in scale:
if s <= 0:
raise ValueError("scale values should be positive")
self.scale = scale
if shear is not None:
if isinstance(shear, numbers.Number):
if shear < 0:
raise ValueError(
"If shear is a single number, it must be positive.")
self.shear = [shear]
else:
assert isinstance(shear, (tuple, list)) and (len(shear) == 2), \
"shear should be a list or tuple and it must be of length 2."
self.shear = shear
else:
self.shear = shear
def _get_inverse_affine_matrix(self, center, angle, translate, scale,
shear):
from numpy import sin, cos, tan
if isinstance(shear, numbers.Number):
shear = [shear, 0]
if not isinstance(shear, (tuple, list)) and len(shear) == 2:
raise ValueError(
"Shear should be a single value or a tuple/list containing " +
"two values. Got {}".format(shear))
rot = math.radians(angle)
sx, sy = [math.radians(s) for s in shear]
cx, cy = center
tx, ty = translate
# RSS without scaling
a = cos(rot - sy) / cos(sy)
b = -cos(rot - sy) * tan(sx) / cos(sy) - sin(rot)
c = sin(rot - sy) / cos(sy)
d = -sin(rot - sy) * tan(sx) / cos(sy) + cos(rot)
# Inverted rotation matrix with scale and shear
# det([[a, b], [c, d]]) == 1, since det(rotation) = 1 and det(shear) = 1
M = [d, -b, 0, -c, a, 0]
M = [x / scale for x in M]
# Apply inverse of translation and of center translation: RSS^-1 * C^-1 * T^-1
M[2] += M[0] * (-cx - tx) + M[1] * (-cy - ty)
M[5] += M[3] * (-cx - tx) + M[4] * (-cy - ty)
# Apply center translation: C * RSS^-1 * C^-1 * T^-1
M[2] += cx
M[5] += cy
return M
@staticmethod
def get_params(degrees, translate, scale_ranges, shears, height):
angle = sample_sym(degrees)
if translate is not None:
max_dx = translate[0] * height
max_dy = translate[1] * height
translations = (np.round(sample_sym(max_dx)),
np.round(sample_sym(max_dy)))
else:
translations = (0, 0)
if scale_ranges is not None:
scale = sample_uniform(scale_ranges[0], scale_ranges[1])
else:
scale = 1.0
if shears is not None:
if len(shears) == 1:
shear = [sample_sym(shears[0]), 0.]
elif len(shears) == 2:
shear = [sample_sym(shears[0]), sample_sym(shears[1])]
else:
shear = 0.0
return angle, translations, scale, shear
def __call__(self, img):
src_h, src_w = img.shape[:2]
angle, translate, scale, shear = self.get_params(
self.degrees, self.translate, self.scale, self.shear, src_h)
M = self._get_inverse_affine_matrix((src_w / 2, src_h / 2), angle,
(0, 0), scale, shear)
M = np.array(M).reshape(2, 3)
startpoints = [(0, 0), (src_w - 1, 0), (src_w - 1, src_h - 1),
(0, src_h - 1)]
project = lambda x, y, a, b, c: int(a * x + b * y + c)
endpoints = [(project(x, y, *M[0]), project(x, y, *M[1]))
for x, y in startpoints]
rect = cv2.minAreaRect(np.array(endpoints))
bbox = cv2.boxPoints(rect).astype(dtype=np.int)
max_x, max_y = bbox[:, 0].max(), bbox[:, 1].max()
min_x, min_y = bbox[:, 0].min(), bbox[:, 1].min()
dst_w = int(max_x - min_x)
dst_h = int(max_y - min_y)
M[0, 2] += (dst_w - src_w) / 2
M[1, 2] += (dst_h - src_h) / 2
# add translate
dst_w += int(abs(translate[0]))
dst_h += int(abs(translate[1]))
if translate[0] < 0: M[0, 2] += abs(translate[0])
if translate[1] < 0: M[1, 2] += abs(translate[1])
flags = get_interpolation()
return cv2.warpAffine(
img,
M, (dst_w, dst_h),
flags=flags,
borderMode=cv2.BORDER_REPLICATE)
class CVRandomPerspective(object):
def __init__(self, distortion=0.5):
self.distortion = distortion
def get_params(self, width, height, distortion):
offset_h = sample_asym(
distortion * height / 2, size=4).astype(dtype=np.int)
offset_w = sample_asym(
distortion * width / 2, size=4).astype(dtype=np.int)
topleft = (offset_w[0], offset_h[0])
topright = (width - 1 - offset_w[1], offset_h[1])
botright = (width - 1 - offset_w[2], height - 1 - offset_h[2])
botleft = (offset_w[3], height - 1 - offset_h[3])
startpoints = [(0, 0), (width - 1, 0), (width - 1, height - 1),
(0, height - 1)]
endpoints = [topleft, topright, botright, botleft]
return np.array(
startpoints, dtype=np.float32), np.array(
endpoints, dtype=np.float32)
def __call__(self, img):
height, width = img.shape[:2]
startpoints, endpoints = self.get_params(width, height, self.distortion)
M = cv2.getPerspectiveTransform(startpoints, endpoints)
# TODO: more robust way to crop image
rect = cv2.minAreaRect(endpoints)
bbox = cv2.boxPoints(rect).astype(dtype=np.int)
max_x, max_y = bbox[:, 0].max(), bbox[:, 1].max()
min_x, min_y = bbox[:, 0].min(), bbox[:, 1].min()
min_x, min_y = max(min_x, 0), max(min_y, 0)
flags = get_interpolation()
img = cv2.warpPerspective(
img,
M, (max_x, max_y),
flags=flags,
borderMode=cv2.BORDER_REPLICATE)
img = img[min_y:, min_x:]
return img
class CVRescale(object):
def __init__(self, factor=4, base_size=(128, 512)):
""" Define image scales using gaussian pyramid and rescale image to target scale.
Args:
factor: the decayed factor from base size, factor=4 keeps target scale by default.
base_size: base size the build the bottom layer of pyramid
"""
if isinstance(factor, numbers.Number):
self.factor = round(sample_uniform(0, factor))
elif isinstance(factor, (tuple, list)) and len(factor) == 2:
self.factor = round(sample_uniform(factor[0], factor[1]))
else:
raise Exception('factor must be number or list with length 2')
# assert factor is valid
self.base_h, self.base_w = base_size[:2]
def __call__(self, img):
if self.factor == 0:
return img
src_h, src_w = img.shape[:2]
cur_w, cur_h = self.base_w, self.base_h
scale_img = cv2.resize(
img, (cur_w, cur_h), interpolation=get_interpolation())
for _ in range(np.int(self.factor)):
scale_img = cv2.pyrDown(scale_img)
scale_img = cv2.resize(
scale_img, (src_w, src_h), interpolation=get_interpolation())
return scale_img
class CVGaussianNoise(object):
def __init__(self, mean=0, var=20):
self.mean = mean
if isinstance(var, numbers.Number):
self.var = max(int(sample_asym(var)), 1)
elif isinstance(var, (tuple, list)) and len(var) == 2:
self.var = int(sample_uniform(var[0], var[1]))
else:
raise Exception('degree must be number or list with length 2')
def __call__(self, img):
noise = np.random.normal(self.mean, self.var**0.5, img.shape)
img = np.clip(img + noise, 0, 255).astype(np.uint8)
return img
class CVMotionBlur(object):
def __init__(self, degrees=12, angle=90):
if isinstance(degrees, numbers.Number):
self.degree = max(int(sample_asym(degrees)), 1)
elif isinstance(degrees, (tuple, list)) and len(degrees) == 2:
self.degree = int(sample_uniform(degrees[0], degrees[1]))
else:
raise Exception('degree must be number or list with length 2')
self.angle = sample_uniform(-angle, angle)
def __call__(self, img):
M = cv2.getRotationMatrix2D((self.degree // 2, self.degree // 2),
self.angle, 1)
motion_blur_kernel = np.zeros((self.degree, self.degree))
motion_blur_kernel[self.degree // 2, :] = 1
motion_blur_kernel = cv2.warpAffine(motion_blur_kernel, M,
(self.degree, self.degree))
motion_blur_kernel = motion_blur_kernel / self.degree
img = cv2.filter2D(img, -1, motion_blur_kernel)
img = np.clip(img, 0, 255).astype(np.uint8)
return img
class CVGeometry(object):
def __init__(self,
degrees=15,
translate=(0.3, 0.3),
scale=(0.5, 2.),
shear=(45, 15),
distortion=0.5,
p=0.5):
self.p = p
type_p = random.random()
if type_p < 0.33:
self.transforms = CVRandomRotation(degrees=degrees)
elif type_p < 0.66:
self.transforms = CVRandomAffine(
degrees=degrees, translate=translate, scale=scale, shear=shear)
else:
self.transforms = CVRandomPerspective(distortion=distortion)
def __call__(self, img):
if random.random() < self.p:
return self.transforms(img)
else:
return img
class CVDeterioration(object):
def __init__(self, var, degrees, factor, p=0.5):
self.p = p
transforms = []
if var is not None:
transforms.append(CVGaussianNoise(var=var))
if degrees is not None:
transforms.append(CVMotionBlur(degrees=degrees))
if factor is not None:
transforms.append(CVRescale(factor=factor))
random.shuffle(transforms)
transforms = Compose(transforms)
self.transforms = transforms
def __call__(self, img):
if random.random() < self.p:
return self.transforms(img)
else:
return img
class CVColorJitter(object):
def __init__(self,
brightness=0.5,
contrast=0.5,
saturation=0.5,
hue=0.1,
p=0.5):
self.p = p
self.transforms = transforms.ColorJitter(
brightness=brightness,
contrast=contrast,
saturation=saturation,
hue=hue)
def __call__(self, img):
if random.random() < self.p:
return self.transforms(img)
else:
return img
class VLAug(object):
def __init__(self,
geometry_p=0.5,
Deterioration_p=0.25,
ColorJitter_p=0.25,
**kwargs):
self.Geometry = CVGeometry(
degrees=45,
translate=(0.0, 0.0),
scale=(0.5, 2.),
shear=(45, 15),
distortion=0.5,
p=geometry_p)
self.Deterioration = CVDeterioration(
var=20, degrees=6, factor=4, p=Deterioration_p)
self.ColorJitter = CVColorJitter(
brightness=0.5,
contrast=0.5,
saturation=0.5,
hue=0.1,
p=ColorJitter_p)
def __call__(self, data):
img = data['image']
img = self.Geometry(img)
img = self.Deterioration(img)
img = self.ColorJitter(img)
data['image'] = img
return data
if __name__ == '__main__':
geo = CVGeometry(
degrees=45,
translate=(0.0, 0.0),
scale=(0.5, 2.),
shear=(45, 15),
distortion=0.5,
p=1)
det = CVDeterioration(var=20, degrees=6, factor=4, p=1)
color = CVColorJitter(
brightness=0.5, contrast=0.5, saturation=0.5, hue=0.1, p=1)
img = np.ones((64, 256, 3))
img = geo(img)
img = det(img)
img = color(img)
# import pdb
# pdb.set_trace()
# print()
# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import paddle
from paddle import nn
class VLLoss(nn.Layer):
def __init__(self, mode='LF_1', weight_res=0.5, weight_mas=0.5, **kwargs):
super(VLLoss, self).__init__()
self.loss_func = paddle.nn.loss.CrossEntropyLoss(reduction="mean")
assert mode in ['LF_1', 'LF_2', 'LA']
self.mode = mode
self.weight_res = weight_res
self.weight_mas = weight_mas
def flatten_label(self, target):
label_flatten = []
label_length = []
for i in range(0, target.shape[0]):
cur_label = target[i].tolist()
label_flatten += cur_label[:cur_label.index(0) + 1]
label_length.append(cur_label.index(0) + 1)
label_flatten = paddle.to_tensor(label_flatten, dtype='int64')
label_length = paddle.to_tensor(label_length, dtype='int32')
return (label_flatten, label_length)
def _flatten(self, sources, lengths):
return paddle.concat([t[:l] for t, l in zip(sources, lengths)])
def forward(self, predicts, batch):
text_pre = predicts[0]
target = batch[1].astype('int64')
label_flatten, length = self.flatten_label(target)
text_pre = self._flatten(text_pre, length)
if self.mode == 'LF_1':
loss = self.loss_func(text_pre, label_flatten)
else:
text_rem = predicts[1]
text_mas = predicts[2]
target_res = batch[2].astype('int64')
target_sub = batch[3].astype('int64')
label_flatten_res, length_res = self.flatten_label(target_res)
label_flatten_sub, length_sub = self.flatten_label(target_sub)
text_rem = self._flatten(text_rem, length_res)
text_mas = self._flatten(text_mas, length_sub)
loss_ori = self.loss_func(text_pre, label_flatten)
loss_res = self.loss_func(text_rem, label_flatten_res)
loss_mas = self.loss_func(text_mas, label_flatten_sub)
loss = loss_ori + loss_res * self.weight_res + loss_mas * self.weight_mas
return {'loss': loss}
......@@ -84,11 +84,15 @@ class BasicBlock(nn.Layer):
class ResNet45(nn.Layer):
def __init__(self, block=BasicBlock, layers=[3, 4, 6, 6, 3], in_channels=3):
def __init__(self,
in_channels=3,
block=BasicBlock,
layers=[3, 4, 6, 6, 3],
strides=[2, 1, 2, 1, 1]):
self.inplanes = 32
super(ResNet45, self).__init__()
self.conv1 = nn.Conv2D(
3,
in_channels,
32,
kernel_size=3,
stride=1,
......@@ -98,18 +102,13 @@ class ResNet45(nn.Layer):
self.bn1 = nn.BatchNorm2D(32)
self.relu = nn.ReLU()
self.layer1 = self._make_layer(block, 32, layers[0], stride=2)
self.layer2 = self._make_layer(block, 64, layers[1], stride=1)
self.layer3 = self._make_layer(block, 128, layers[2], stride=2)
self.layer4 = self._make_layer(block, 256, layers[3], stride=1)
self.layer5 = self._make_layer(block, 512, layers[4], stride=1)
self.layer1 = self._make_layer(block, 32, layers[0], stride=strides[0])
self.layer2 = self._make_layer(block, 64, layers[1], stride=strides[1])
self.layer3 = self._make_layer(block, 128, layers[2], stride=strides[2])
self.layer4 = self._make_layer(block, 256, layers[3], stride=strides[3])
self.layer5 = self._make_layer(block, 512, layers[4], stride=strides[4])
self.out_channels = 512
# for m in self.modules():
# if isinstance(m, nn.Conv2D):
# n = m._kernel_size[0] * m._kernel_size[1] * m._out_channels
# m.weight.data.normal_(0, math.sqrt(2. / n))
def _make_layer(self, block, planes, blocks, stride=1):
downsample = None
if stride != 1 or self.inplanes != planes * block.expansion:
......@@ -137,11 +136,9 @@ class ResNet45(nn.Layer):
x = self.conv1(x)
x = self.bn1(x)
x = self.relu(x)
# print(x)
x = self.layer1(x)
x = self.layer2(x)
x = self.layer3(x)
# print(x)
x = self.layer4(x)
x = self.layer5(x)
return x
......@@ -20,10 +20,6 @@ import paddle.nn as nn
import sys
import math
from paddle.nn.initializer import KaimingNormal, Constant
zeros_ = Constant(value=0.)
ones_ = Constant(value=1.)
def conv3x3(in_planes, out_planes, stride=1):
......@@ -144,111 +140,4 @@ class ResNet_ASTER(nn.Layer):
rnn_feat, _ = self.rnn(cnn_feat)
return rnn_feat
else:
return cnn_feat
class Block(nn.Layer):
def __init__(self, inplanes, planes, stride=1, downsample=None):
super(Block, self).__init__()
self.conv1 = conv1x1(inplanes, planes)
self.bn1 = nn.BatchNorm2D(planes)
self.relu = nn.ReLU()
self.conv2 = conv3x3(planes, planes, stride)
self.bn2 = nn.BatchNorm2D(planes)
self.downsample = downsample
self.stride = stride
def forward(self, x):
residual = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
if self.downsample is not None:
residual = self.downsample(x)
out += residual
out = self.relu(out)
return out
class ResNet45(nn.Layer):
def __init__(self, in_channels=3, compress_layer=False):
super(ResNet45, self).__init__()
self.compress_layer = compress_layer
self.conv1_new = nn.Conv2D(
in_channels,
32,
kernel_size=(3, 3),
stride=1,
padding=1,
bias_attr=False)
self.bn1 = nn.BatchNorm2D(32)
self.relu = nn.ReLU()
self.inplanes = 32
self.layer1 = self._make_layer(32, 3, [2, 2]) # [32, 128]
self.layer2 = self._make_layer(64, 4, [2, 2]) # [16, 64]
self.layer3 = self._make_layer(128, 6, [2, 2]) # [8, 32]
self.layer4 = self._make_layer(256, 6, [1, 1]) # [8, 32]
self.layer5 = self._make_layer(512, 3, [1, 1]) # [8, 32]
if self.compress_layer:
self.layer6 = nn.Sequential(
nn.Conv2D(
512, 256, kernel_size=(3, 1), padding=(0, 0), stride=(1,
1)),
nn.BatchNorm(256),
nn.ReLU())
self.out_channels = 256
else:
self.out_channels = 512
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, nn.Conv2D):
KaimingNormal(m.weight)
elif isinstance(m, nn.BatchNorm):
ones_(m.weight)
zeros_(m.bias)
def _make_layer(self, planes, blocks, stride):
downsample = None
if stride != [1, 1] or self.inplanes != planes:
downsample = nn.Sequential(
conv1x1(self.inplanes, planes, stride), nn.BatchNorm2D(planes))
layers = []
layers.append(Block(self.inplanes, planes, stride, downsample))
self.inplanes = planes
for _ in range(1, blocks):
layers.append(Block(self.inplanes, planes))
return nn.Sequential(*layers)
def forward(self, x):
x = self.conv1_new(x)
x = self.bn1(x)
x = self.relu(x)
x1 = self.layer1(x)
x2 = self.layer2(x1)
x3 = self.layer3(x2)
x4 = self.layer4(x3)
x5 = self.layer5(x4)
if not self.compress_layer:
return x5
else:
x6 = self.layer6(x5)
return x6
if __name__ == '__main__':
model = ResNet45()
x = paddle.rand([1, 3, 64, 256])
x = paddle.to_tensor(x)
print(x.shape)
out = model(x)
print(out.shape)
return cnn_feat
\ No newline at end of file
......@@ -22,7 +22,7 @@ import paddle.nn as nn
import paddle.nn.functional as F
from paddle.nn.initializer import Normal, XavierNormal
import numpy as np
from ppocr.modeling.backbones.rec_resnet_aster import ResNet45
from ppocr.modeling.backbones.rec_resnet_45 import ResNet45
class PositionalEncoding(nn.Layer):
......@@ -442,14 +442,6 @@ class MLM_VRM(nn.Layer):
if out_length[j] == 0 and tmp_result[j] == 0:
out_length[j] = now_step + 1
now_step += 1
# while 0 in out_length and now_step < nsteps:
# tmp_result = text_pre[now_step, :, :]
# out_res[now_step] = tmp_result
# tmp_result = tmp_result.topk(1)[1].squeeze(axis=1)
# for j in range(b):
# if out_length[j] == 0 and tmp_result[j] == 0:
# out_length[j] = now_step + 1
# now_step += 1
for j in range(0, b):
if int(out_length[j]) == 0:
out_length[j] = nsteps
......
......@@ -77,11 +77,62 @@ class Adam(object):
self.grad_clip = grad_clip
self.name = name
self.lazy_mode = lazy_mode
self.group_lr = kwargs.get('group_lr', False)
self.training_step = kwargs.get('training_step', None)
def __call__(self, model):
train_params = [
param for param in model.parameters() if param.trainable is True
]
if self.group_lr:
if self.training_step == 'LF_2':
import paddle
if isinstance(model, paddle.fluid.dygraph.parallel.
DataParallel): # multi gpu
mlm = model._layers.head.MLM_VRM.MLM.parameters()
pre_mlm_pp = model._layers.head.MLM_VRM.Prediction.pp_share.parameters(
)
pre_mlm_w = model._layers.head.MLM_VRM.Prediction.w_share.parameters(
)
else: # single gpu
mlm = model.head.MLM_VRM.MLM.parameters()
pre_mlm_pp = model.head.MLM_VRM.Prediction.pp_share.parameters(
)
pre_mlm_w = model.head.MLM_VRM.Prediction.w_share.parameters(
)
total = []
for param in mlm:
total.append(id(param))
for param in pre_mlm_pp:
total.append(id(param))
for param in pre_mlm_w:
total.append(id(param))
group_base_params = [
param for param in model.parameters() if id(param) in total
]
group_small_params = [
param for param in model.parameters()
if id(param) not in total
]
train_params = [{
'params': group_base_params
}, {
'params': group_small_params,
'learning_rate': self.learning_rate.values[0] * 0.1
}]
else:
print(
'group lr currently only support VisionLAN in LF_2 training step'
)
train_params = [
param for param in model.parameters()
if param.trainable is True
]
else:
train_params = [
param for param in model.parameters() if param.trainable is True
]
opt = optim.Adam(
learning_rate=self.learning_rate,
beta1=self.beta1,
......
......@@ -27,8 +27,7 @@ class BaseRecLabelDecode(object):
self.character_str = []
if character_dict_path is None:
# self.character_str = "0123456789abcdefghijklmnopqrstuvwxyz"
self.character_str = "abcdefghijklmnopqrstuvwxyz1234567890"
self.character_str = "0123456789abcdefghijklmnopqrstuvwxyz"
dict_character = list(self.character_str)
else:
with open(character_dict_path, "rb") as fin:
......
a
b
c
d
e
f
g
h
i
j
k
l
m
n
o
p
q
r
s
t
u
v
w
x
y
z
1
2
3
4
5
6
7
8
9
0
\ No newline at end of file
......@@ -60,7 +60,7 @@ def export_single_model(model,
shape=[None, 3, 48, 160], dtype="float32"),
]
model = to_static(model, input_spec=other_shape)
elif arch_config["algorithm"] in ["SVTR", "VisionLAN"]:
elif arch_config["algorithm"] == "SVTR":
if arch_config["Head"]["name"] == 'MultiHead':
other_shape = [
paddle.static.InputSpec(
......@@ -97,6 +97,12 @@ def export_single_model(model,
shape=[None, 1, 32, 100], dtype="float32"),
]
model = to_static(model, input_spec=other_shape)
elif arch_config["algorithm"] == "VisionLAN":
other_shape = [
paddle.static.InputSpec(
shape=[None, 3, 64, 256], dtype="float32"),
]
model = to_static(model, input_spec=other_shape)
elif arch_config["algorithm"] in ["LayoutLM", "LayoutLMv2", "LayoutXLM"]:
input_spec = [
paddle.static.InputSpec(
......@@ -217,4 +223,4 @@ def main():
if __name__ == "__main__":
main()
main()
\ No newline at end of file
......@@ -366,6 +366,8 @@ class TextRecognizer(object):
elif self.rec_algorithm == "VisionLAN":
norm_img = self.resize_norm_img_vl(img_list[indices[ino]],
self.rec_image_shape)
norm_img = norm_img[np.newaxis, :]
norm_img_batch.append(norm_img)
elif self.rec_algorithm == "ABINet":
norm_img = self.resize_norm_img_abinet(
img_list[indices[ino]], self.rec_image_shape)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册