未验证 提交 bde50863 编写于 作者: X xiaoting 提交者: GitHub

Merge pull request #6061 from Topdu/dygraph

add_rec_svtr
Global:
use_gpu: True
epoch_num: 20
log_smooth_window: 20
print_batch_step: 10
save_model_dir: ./output/rec/svtr/
save_epoch_step: 1
# evaluation is run every 2000 iterations after the 0th iteration
eval_batch_step: [0, 2000]
cal_metric_during_train: True
pretrained_model:
checkpoints:
save_inference_dir:
use_visualdl: False
infer_img: doc/imgs_words_en/word_10.png
# for data or label process
character_dict_path:
character_type: en
max_text_length: 25
infer_mode: False
use_space_char: False
save_res_path: ./output/rec/predicts_svtr_tiny.txt
Optimizer:
name: AdamW
beta1: 0.9
beta2: 0.99
epsilon: 0.00000008
weight_decay: 0.05
no_weight_decay_name: norm pos_embed
one_dim_param_no_weight_decay: true
lr:
name: Cosine
learning_rate: 0.0005
warmup_epoch: 2
Architecture:
model_type: rec
algorithm: SVTR
Transform:
name: STN_ON
tps_inputsize: [32, 64]
tps_outputsize: [32, 100]
num_control_points: 20
tps_margins: [0.05,0.05]
stn_activation: none
Backbone:
name: SVTRNet
img_size: [32, 100]
out_char_num: 25
out_channels: 192
patch_merging: 'Conv'
embed_dim: [64, 128, 256]
depth: [3, 6, 3]
num_heads: [2, 4, 8]
mixer: ['Local','Local','Local','Local','Local','Local','Global','Global','Global','Global','Global','Global']
local_mixer: [[7, 11], [7, 11], [7, 11]]
last_stage: True
prenorm: false
Neck:
name: SequenceEncoder
encoder_type: reshape
Head:
name: CTCHead
Loss:
name: CTCLoss
PostProcess:
name: CTCLabelDecode
Metric:
name: RecMetric
main_indicator: acc
Train:
dataset:
name: LMDBDataSet
data_dir: ./train_data/data_lmdb_release/training/
transforms:
- DecodeImage: # load image
img_mode: BGR
channel_first: False
- CTCLabelEncode: # Class handling label
- RecResizeImg:
character_dict_path:
image_shape: [3, 64, 256]
padding: False
- KeepKeys:
keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order
loader:
shuffle: True
batch_size_per_card: 512
drop_last: True
num_workers: 4
Eval:
dataset:
name: LMDBDataSet
data_dir: ./train_data/data_lmdb_release/validation/
transforms:
- DecodeImage: # load image
img_mode: BGR
channel_first: False
- CTCLabelEncode: # Class handling label
- RecResizeImg:
character_dict_path:
image_shape: [3, 64, 256]
padding: False
- KeepKeys:
keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order
loader:
shuffle: False
drop_last: False
batch_size_per_card: 256
num_workers: 2
...@@ -137,7 +137,7 @@ def main(config, device, logger, vdl_writer): ...@@ -137,7 +137,7 @@ def main(config, device, logger, vdl_writer):
config['Optimizer'], config['Optimizer'],
epochs=config['Global']['epoch_num'], epochs=config['Global']['epoch_num'],
step_each_epoch=len(train_dataloader), step_each_epoch=len(train_dataloader),
parameters=model.parameters()) model=model)
# resume PACT training process # resume PACT training process
if config["Global"]["checkpoints"] is not None: if config["Global"]["checkpoints"] is not None:
......
...@@ -23,7 +23,7 @@ from .random_crop_data import EastRandomCropData, RandomCropImgMask ...@@ -23,7 +23,7 @@ from .random_crop_data import EastRandomCropData, RandomCropImgMask
from .make_pse_gt import MakePseGt from .make_pse_gt import MakePseGt
from .rec_img_aug import RecAug, RecConAug, RecResizeImg, ClsResizeImg, \ from .rec_img_aug import RecAug, RecConAug, RecResizeImg, ClsResizeImg, \
SRNRecResizeImg, NRTRRecResizeImg, SARRecResizeImg, PRENResizeImg SRNRecResizeImg, NRTRRecResizeImg, SARRecResizeImg, PRENResizeImg, SVTRRecResizeImg
from .randaugment import RandAugment from .randaugment import RandAugment
from .copy_paste import CopyPaste from .copy_paste import CopyPaste
from .ColorJitter import ColorJitter from .ColorJitter import ColorJitter
......
...@@ -16,6 +16,7 @@ import math ...@@ -16,6 +16,7 @@ import math
import cv2 import cv2
import numpy as np import numpy as np
import random import random
import copy
from PIL import Image from PIL import Image
from .text_image_aug import tia_perspective, tia_stretch, tia_distort from .text_image_aug import tia_perspective, tia_stretch, tia_distort
...@@ -206,6 +207,25 @@ class PRENResizeImg(object): ...@@ -206,6 +207,25 @@ class PRENResizeImg(object):
return data return data
class SVTRRecResizeImg(object):
def __init__(self,
image_shape,
infer_mode=False,
character_dict_path='./ppocr/utils/ppocr_keys_v1.txt',
padding=True,
**kwargs):
self.image_shape = image_shape
self.infer_mode = infer_mode
self.character_dict_path = character_dict_path
self.padding = padding
def __call__(self, data):
img = data['image']
norm_img = resize_norm_img_svtr(img, self.image_shape, self.padding)
data['image'] = norm_img
return data
def resize_norm_img_sar(img, image_shape, width_downsample_ratio=0.25): def resize_norm_img_sar(img, image_shape, width_downsample_ratio=0.25):
imgC, imgH, imgW_min, imgW_max = image_shape imgC, imgH, imgW_min, imgW_max = image_shape
h = img.shape[0] h = img.shape[0]
...@@ -324,6 +344,58 @@ def resize_norm_img_srn(img, image_shape): ...@@ -324,6 +344,58 @@ def resize_norm_img_srn(img, image_shape):
return np.reshape(img_black, (c, row, col)).astype(np.float32) return np.reshape(img_black, (c, row, col)).astype(np.float32)
def resize_norm_img_svtr(img, image_shape, padding=False):
imgC, imgH, imgW = image_shape
h = img.shape[0]
w = img.shape[1]
if not padding:
if h > 2.0 * w:
image = Image.fromarray(img)
image1 = image.rotate(90, expand=True)
image2 = image.rotate(-90, expand=True)
img1 = np.array(image1)
img2 = np.array(image2)
else:
img1 = copy.deepcopy(img)
img2 = copy.deepcopy(img)
resized_image = cv2.resize(
img, (imgW, imgH), interpolation=cv2.INTER_LINEAR)
resized_image1 = cv2.resize(
img1, (imgW, imgH), interpolation=cv2.INTER_LINEAR)
resized_image2 = cv2.resize(
img2, (imgW, imgH), interpolation=cv2.INTER_LINEAR)
resized_w = imgW
else:
ratio = w / float(h)
if math.ceil(imgH * ratio) > imgW:
resized_w = imgW
else:
resized_w = int(math.ceil(imgH * ratio))
resized_image = cv2.resize(img, (resized_w, imgH))
resized_image = resized_image.astype('float32')
resized_image1 = resized_image1.astype('float32')
resized_image2 = resized_image2.astype('float32')
if image_shape[0] == 1:
resized_image = resized_image / 255
resized_image = resized_image[np.newaxis, :]
else:
resized_image = resized_image.transpose((2, 0, 1)) / 255
resized_image1 = resized_image1.transpose((2, 0, 1)) / 255
resized_image2 = resized_image2.transpose((2, 0, 1)) / 255
resized_image -= 0.5
resized_image /= 0.5
resized_image1 -= 0.5
resized_image1 /= 0.5
resized_image2 -= 0.5
resized_image2 /= 0.5
padding_im = np.zeros((3, imgC, imgH, imgW), dtype=np.float32)
padding_im[0, :, :, 0:resized_w] = resized_image
padding_im[1, :, :, 0:resized_w] = resized_image1
padding_im[2, :, :, 0:resized_w] = resized_image2
return padding_im
def srn_other_inputs(image_shape, num_heads, max_text_length): def srn_other_inputs(image_shape, num_heads, max_text_length):
imgC, imgH, imgW = image_shape imgC, imgH, imgW = image_shape
......
...@@ -296,47 +296,49 @@ class PatchEmbed(nn.Layer): ...@@ -296,47 +296,49 @@ class PatchEmbed(nn.Layer):
if sub_num == 2: if sub_num == 2:
self.proj = nn.Sequential( self.proj = nn.Sequential(
ConvBNLayer( ConvBNLayer(
in_channels, in_channels=in_channels,
embed_dim // 2, out_channels=embed_dim // 2,
3, kernel_size=3,
2, stride=2,
1, padding=1,
act=nn.GELU, act=nn.GELU,
bias_attr=None), bias_attr=None),
ConvBNLayer( ConvBNLayer(
embed_dim // 2, in_channels=embed_dim // 2,
embed_dim, out_channels=embed_dim,
3, kernel_size=3,
2, stride=2,
1, padding=1,
act=nn.GELU, act=nn.GELU,
bias_attr=None)) bias_attr=None))
if sub_num == 3: if sub_num == 3:
self.proj = nn.Sequential( self.proj = nn.Sequential(
ConvBNLayer( ConvBNLayer(
in_channels, in_channels=in_channels,
embed_dim // 4, out_channels=embed_dim // 4,
3, kernel_size=3,
2, stride=2,
1, padding=1,
act=nn.GELU, act=nn.GELU,
bias_attr=None), bias_attr=None),
ConvBNLayer( ConvBNLayer(
embed_dim // 4, in_channels=embed_dim // 4,
embed_dim // 2, out_channels=embed_dim // 2,
3, kernel_size=3,
2, stride=2,
1, padding=1,
act=nn.GELU, act=nn.GELU,
bias_attr=None), bias_attr=None),
ConvBNLayer( ConvBNLayer(
embed_dim // 2, embed_dim // 2,
embed_dim, embed_dim,
3, in_channels=embed_dim // 2,
2, out_channels=embed_dim,
1, kernel_size=3,
stride=2,
padding=1,
act=nn.GELU, act=nn.GELU,
bias_attr=None), ) bias_attr=None))
def forward(self, x): def forward(self, x):
B, C, H, W = x.shape B, C, H, W = x.shape
...@@ -455,7 +457,7 @@ class SVTRNet(nn.Layer): ...@@ -455,7 +457,7 @@ class SVTRNet(nn.Layer):
qkv_bias=qkv_bias, qkv_bias=qkv_bias,
qk_scale=qk_scale, qk_scale=qk_scale,
drop=drop_rate, drop=drop_rate,
act_layer=nn.Swish, act_layer=eval(act),
attn_drop=attn_drop_rate, attn_drop=attn_drop_rate,
drop_path=dpr[0:depth[0]][i], drop_path=dpr[0:depth[0]][i],
norm_layer=norm_layer, norm_layer=norm_layer,
......
...@@ -128,6 +128,8 @@ class STN_ON(nn.Layer): ...@@ -128,6 +128,8 @@ class STN_ON(nn.Layer):
self.out_channels = in_channels self.out_channels = in_channels
def forward(self, image): def forward(self, image):
if len(image.shape)==5:
image = image.reshape([0, image.shape[-3], image.shape[-2], image.shape[-1]])
stn_input = paddle.nn.functional.interpolate( stn_input = paddle.nn.functional.interpolate(
image, self.tps_inputsize, mode="bilinear", align_corners=True) image, self.tps_inputsize, mode="bilinear", align_corners=True)
stn_img_feat, ctrl_points = self.stn_head(stn_input) stn_img_feat, ctrl_points = self.stn_head(stn_input)
......
...@@ -138,9 +138,9 @@ class TPSSpatialTransformer(nn.Layer): ...@@ -138,9 +138,9 @@ class TPSSpatialTransformer(nn.Layer):
assert source_control_points.shape[2] == 2 assert source_control_points.shape[2] == 2
batch_size = paddle.shape(source_control_points)[0] batch_size = paddle.shape(source_control_points)[0]
self.padding_matrix = paddle.expand( padding_matrix = paddle.expand(
self.padding_matrix, shape=[batch_size, 3, 2]) self.padding_matrix, shape=[batch_size, 3, 2])
Y = paddle.concat([source_control_points, self.padding_matrix], 1) Y = paddle.concat([source_control_points, padding_matrix], 1)
mapping_matrix = paddle.matmul(self.inverse_kernel, Y) mapping_matrix = paddle.matmul(self.inverse_kernel, Y)
source_coordinate = paddle.matmul(self.target_coordinate_repr, source_coordinate = paddle.matmul(self.target_coordinate_repr,
mapping_matrix) mapping_matrix)
......
...@@ -30,7 +30,7 @@ def build_lr_scheduler(lr_config, epochs, step_each_epoch): ...@@ -30,7 +30,7 @@ def build_lr_scheduler(lr_config, epochs, step_each_epoch):
return lr return lr
def build_optimizer(config, epochs, step_each_epoch, parameters): def build_optimizer(config, epochs, step_each_epoch, model):
from . import regularizer, optimizer from . import regularizer, optimizer
config = copy.deepcopy(config) config = copy.deepcopy(config)
# step1 build lr # step1 build lr
...@@ -43,6 +43,8 @@ def build_optimizer(config, epochs, step_each_epoch, parameters): ...@@ -43,6 +43,8 @@ def build_optimizer(config, epochs, step_each_epoch, parameters):
if not hasattr(regularizer, reg_name): if not hasattr(regularizer, reg_name):
reg_name += 'Decay' reg_name += 'Decay'
reg = getattr(regularizer, reg_name)(**reg_config)() reg = getattr(regularizer, reg_name)(**reg_config)()
elif 'weight_decay' in config:
reg = config.pop('weight_decay')
else: else:
reg = None reg = None
...@@ -57,4 +59,4 @@ def build_optimizer(config, epochs, step_each_epoch, parameters): ...@@ -57,4 +59,4 @@ def build_optimizer(config, epochs, step_each_epoch, parameters):
weight_decay=reg, weight_decay=reg,
grad_clip=grad_clip, grad_clip=grad_clip,
**config) **config)
return optim(parameters), lr return optim(model), lr
...@@ -42,13 +42,13 @@ class Momentum(object): ...@@ -42,13 +42,13 @@ class Momentum(object):
self.weight_decay = weight_decay self.weight_decay = weight_decay
self.grad_clip = grad_clip self.grad_clip = grad_clip
def __call__(self, parameters): def __call__(self, model):
opt = optim.Momentum( opt = optim.Momentum(
learning_rate=self.learning_rate, learning_rate=self.learning_rate,
momentum=self.momentum, momentum=self.momentum,
weight_decay=self.weight_decay, weight_decay=self.weight_decay,
grad_clip=self.grad_clip, grad_clip=self.grad_clip,
parameters=parameters) parameters=model.parameters())
return opt return opt
...@@ -75,7 +75,7 @@ class Adam(object): ...@@ -75,7 +75,7 @@ class Adam(object):
self.name = name self.name = name
self.lazy_mode = lazy_mode self.lazy_mode = lazy_mode
def __call__(self, parameters): def __call__(self, model):
opt = optim.Adam( opt = optim.Adam(
learning_rate=self.learning_rate, learning_rate=self.learning_rate,
beta1=self.beta1, beta1=self.beta1,
...@@ -85,7 +85,7 @@ class Adam(object): ...@@ -85,7 +85,7 @@ class Adam(object):
grad_clip=self.grad_clip, grad_clip=self.grad_clip,
name=self.name, name=self.name,
lazy_mode=self.lazy_mode, lazy_mode=self.lazy_mode,
parameters=parameters) parameters=model.parameters())
return opt return opt
...@@ -117,7 +117,7 @@ class RMSProp(object): ...@@ -117,7 +117,7 @@ class RMSProp(object):
self.weight_decay = weight_decay self.weight_decay = weight_decay
self.grad_clip = grad_clip self.grad_clip = grad_clip
def __call__(self, parameters): def __call__(self, model):
opt = optim.RMSProp( opt = optim.RMSProp(
learning_rate=self.learning_rate, learning_rate=self.learning_rate,
momentum=self.momentum, momentum=self.momentum,
...@@ -125,7 +125,7 @@ class RMSProp(object): ...@@ -125,7 +125,7 @@ class RMSProp(object):
epsilon=self.epsilon, epsilon=self.epsilon,
weight_decay=self.weight_decay, weight_decay=self.weight_decay,
grad_clip=self.grad_clip, grad_clip=self.grad_clip,
parameters=parameters) parameters=model.parameters())
return opt return opt
...@@ -148,7 +148,7 @@ class Adadelta(object): ...@@ -148,7 +148,7 @@ class Adadelta(object):
self.grad_clip = grad_clip self.grad_clip = grad_clip
self.name = name self.name = name
def __call__(self, parameters): def __call__(self, model):
opt = optim.Adadelta( opt = optim.Adadelta(
learning_rate=self.learning_rate, learning_rate=self.learning_rate,
epsilon=self.epsilon, epsilon=self.epsilon,
...@@ -156,7 +156,7 @@ class Adadelta(object): ...@@ -156,7 +156,7 @@ class Adadelta(object):
weight_decay=self.weight_decay, weight_decay=self.weight_decay,
grad_clip=self.grad_clip, grad_clip=self.grad_clip,
name=self.name, name=self.name,
parameters=parameters) parameters=model.parameters())
return opt return opt
...@@ -165,31 +165,55 @@ class AdamW(object): ...@@ -165,31 +165,55 @@ class AdamW(object):
learning_rate=0.001, learning_rate=0.001,
beta1=0.9, beta1=0.9,
beta2=0.999, beta2=0.999,
epsilon=1e-08, epsilon=1e-8,
weight_decay=0.01, weight_decay=0.01,
multi_precision=False,
grad_clip=None, grad_clip=None,
no_weight_decay_name=None,
one_dim_param_no_weight_decay=False,
name=None, name=None,
lazy_mode=False, lazy_mode=False,
**kwargs): **args):
super().__init__()
self.learning_rate = learning_rate self.learning_rate = learning_rate
self.beta1 = beta1 self.beta1 = beta1
self.beta2 = beta2 self.beta2 = beta2
self.epsilon = epsilon self.epsilon = epsilon
self.learning_rate = learning_rate self.grad_clip = grad_clip
self.weight_decay = 0.01 if weight_decay is None else weight_decay self.weight_decay = 0.01 if weight_decay is None else weight_decay
self.grad_clip = grad_clip self.grad_clip = grad_clip
self.name = name self.name = name
self.lazy_mode = lazy_mode self.lazy_mode = lazy_mode
self.multi_precision = multi_precision
def __call__(self, parameters): self.no_weight_decay_name_list = no_weight_decay_name.split(
) if no_weight_decay_name else []
self.one_dim_param_no_weight_decay = one_dim_param_no_weight_decay
def __call__(self, model):
parameters = model.parameters()
self.no_weight_decay_param_name_list = [
p.name for n, p in model.named_parameters() if any(nd in n for nd in self.no_weight_decay_name_list)
]
if self.one_dim_param_no_weight_decay:
self.no_weight_decay_param_name_list += [
p.name for n, p in model.named_parameters() if len(p.shape) == 1
]
opt = optim.AdamW( opt = optim.AdamW(
learning_rate=self.learning_rate, learning_rate=self.learning_rate,
beta1=self.beta1, beta1=self.beta1,
beta2=self.beta2, beta2=self.beta2,
epsilon=self.epsilon, epsilon=self.epsilon,
parameters=parameters,
weight_decay=self.weight_decay, weight_decay=self.weight_decay,
multi_precision=self.multi_precision,
grad_clip=self.grad_clip, grad_clip=self.grad_clip,
name=self.name, name=self.name,
lazy_mode=self.lazy_mode, lazy_mode=self.lazy_mode,
parameters=parameters) apply_decay_param_fun=self._apply_decay_param_fun)
return opt return opt
def _apply_decay_param_fun(self, name):
return name not in self.no_weight_decay_param_name_list
\ No newline at end of file
...@@ -27,7 +27,7 @@ from .sast_postprocess import SASTPostProcess ...@@ -27,7 +27,7 @@ from .sast_postprocess import SASTPostProcess
from .fce_postprocess import FCEPostProcess from .fce_postprocess import FCEPostProcess
from .rec_postprocess import CTCLabelDecode, AttnLabelDecode, SRNLabelDecode, \ from .rec_postprocess import CTCLabelDecode, AttnLabelDecode, SRNLabelDecode, \
DistillationCTCLabelDecode, TableLabelDecode, NRTRLabelDecode, SARLabelDecode, \ DistillationCTCLabelDecode, TableLabelDecode, NRTRLabelDecode, SARLabelDecode, \
SEEDLabelDecode, PRENLabelDecode SEEDLabelDecode, PRENLabelDecode, SVTRLabelDecode
from .cls_postprocess import ClsPostProcess from .cls_postprocess import ClsPostProcess
from .pg_postprocess import PGPostProcess from .pg_postprocess import PGPostProcess
from .vqa_token_ser_layoutlm_postprocess import VQASerTokenLayoutLMPostProcess from .vqa_token_ser_layoutlm_postprocess import VQASerTokenLayoutLMPostProcess
...@@ -42,7 +42,7 @@ def build_post_process(config, global_config=None): ...@@ -42,7 +42,7 @@ def build_post_process(config, global_config=None):
'DistillationDBPostProcess', 'NRTRLabelDecode', 'SARLabelDecode', 'DistillationDBPostProcess', 'NRTRLabelDecode', 'SARLabelDecode',
'SEEDLabelDecode', 'VQASerTokenLayoutLMPostProcess', 'SEEDLabelDecode', 'VQASerTokenLayoutLMPostProcess',
'VQAReTokenLayoutLMPostProcess', 'PRENLabelDecode', 'VQAReTokenLayoutLMPostProcess', 'PRENLabelDecode',
'DistillationSARLabelDecode' 'DistillationSARLabelDecode', 'SVTRLabelDecode'
] ]
if config['name'] == 'PSEPostProcess': if config['name'] == 'PSEPostProcess':
......
...@@ -752,3 +752,40 @@ class PRENLabelDecode(BaseRecLabelDecode): ...@@ -752,3 +752,40 @@ class PRENLabelDecode(BaseRecLabelDecode):
return text return text
label = self.decode(label) label = self.decode(label)
return text, label return text, label
class SVTRLabelDecode(BaseRecLabelDecode):
""" Convert between text-label and text-index """
def __init__(self, character_dict_path=None, use_space_char=False,
**kwargs):
super(SVTRLabelDecode, self).__init__(character_dict_path,
use_space_char)
def __call__(self, preds, label=None, *args, **kwargs):
if isinstance(preds, tuple):
preds = preds[-1]
if isinstance(preds, paddle.Tensor):
preds = preds.numpy()
preds_idx = preds.argmax(axis=-1)
preds_prob = preds.max(axis=-1)
text = self.decode(preds_idx, preds_prob, is_remove_duplicate=True)
return_text = []
for i in range(0, len(text), 3):
text0 = text[i]
text1 = text[i + 1]
text2 = text[i + 2]
text_pred = [text0[0], text1[0], text2[0]]
text_prob = [text0[1], text1[1], text2[1]]
id_max = text_prob.index(max(text_prob))
return_text.append((text_pred[id_max], text_prob[id_max]))
if label is None:
return return_text
label = self.decode(label)
return return_text, label
def add_special_char(self, dict_character):
dict_character = ['blank'] + dict_character
return dict_character
\ No newline at end of file
...@@ -61,6 +61,11 @@ def export_single_model(model, arch_config, save_path, logger): ...@@ -61,6 +61,11 @@ def export_single_model(model, arch_config, save_path, logger):
paddle.static.InputSpec( paddle.static.InputSpec(
shape=[None, 3, 48, -1], dtype="float32"), shape=[None, 3, 48, -1], dtype="float32"),
] ]
else:
other_shape = [
paddle.static.InputSpec(
shape=[None, 3, 64, 256], dtype="float32"),
]
model = to_static(model, input_spec=other_shape) model = to_static(model, input_spec=other_shape)
elif arch_config["algorithm"] == "PREN": elif arch_config["algorithm"] == "PREN":
other_shape = [ other_shape = [
......
...@@ -131,6 +131,17 @@ class TextRecognizer(object): ...@@ -131,6 +131,17 @@ class TextRecognizer(object):
padding_im = np.zeros((imgC, imgH, imgW), dtype=np.float32) padding_im = np.zeros((imgC, imgH, imgW), dtype=np.float32)
padding_im[:, :, 0:resized_w] = resized_image padding_im[:, :, 0:resized_w] = resized_image
return padding_im return padding_im
def resize_norm_img_svtr(self, img, image_shape):
imgC, imgH, imgW = image_shape
resized_image = cv2.resize(
img, (imgW, imgH), interpolation=cv2.INTER_LINEAR)
resized_image = resized_image.astype('float32')
resized_image = resized_image.transpose((2, 0, 1)) / 255
resized_image -= 0.5
resized_image /= 0.5
return resized_image
def resize_norm_img_srn(self, img, image_shape): def resize_norm_img_srn(self, img, image_shape):
imgC, imgH, imgW = image_shape imgC, imgH, imgW = image_shape
...@@ -263,12 +274,8 @@ class TextRecognizer(object): ...@@ -263,12 +274,8 @@ class TextRecognizer(object):
wh_ratio = w * 1.0 / h wh_ratio = w * 1.0 / h
max_wh_ratio = max(max_wh_ratio, wh_ratio) max_wh_ratio = max(max_wh_ratio, wh_ratio)
for ino in range(beg_img_no, end_img_no): for ino in range(beg_img_no, end_img_no):
if self.rec_algorithm != "SRN" and self.rec_algorithm != "SAR":
norm_img = self.resize_norm_img(img_list[indices[ino]], if self.rec_algorithm == "SAR":
max_wh_ratio)
norm_img = norm_img[np.newaxis, :]
norm_img_batch.append(norm_img)
elif self.rec_algorithm == "SAR":
norm_img, _, _, valid_ratio = self.resize_norm_img_sar( norm_img, _, _, valid_ratio = self.resize_norm_img_sar(
img_list[indices[ino]], self.rec_image_shape) img_list[indices[ino]], self.rec_image_shape)
norm_img = norm_img[np.newaxis, :] norm_img = norm_img[np.newaxis, :]
...@@ -276,7 +283,7 @@ class TextRecognizer(object): ...@@ -276,7 +283,7 @@ class TextRecognizer(object):
valid_ratios = [] valid_ratios = []
valid_ratios.append(valid_ratio) valid_ratios.append(valid_ratio)
norm_img_batch.append(norm_img) norm_img_batch.append(norm_img)
else: elif self.rec_algorithm == "SRN":
norm_img = self.process_image_srn( norm_img = self.process_image_srn(
img_list[indices[ino]], self.rec_image_shape, 8, 25) img_list[indices[ino]], self.rec_image_shape, 8, 25)
encoder_word_pos_list = [] encoder_word_pos_list = []
...@@ -288,6 +295,16 @@ class TextRecognizer(object): ...@@ -288,6 +295,16 @@ class TextRecognizer(object):
gsrm_slf_attn_bias1_list.append(norm_img[3]) gsrm_slf_attn_bias1_list.append(norm_img[3])
gsrm_slf_attn_bias2_list.append(norm_img[4]) gsrm_slf_attn_bias2_list.append(norm_img[4])
norm_img_batch.append(norm_img[0]) norm_img_batch.append(norm_img[0])
elif self.rec_algorithm == "SVTR":
norm_img = self.resize_norm_img_svtr(
img_list[indices[ino]], self.rec_image_shape)
norm_img = norm_img[np.newaxis, :]
norm_img_batch.append(norm_img)
else:
norm_img = self.resize_norm_img(img_list[indices[ino]],
max_wh_ratio)
norm_img = norm_img[np.newaxis, :]
norm_img_batch.append(norm_img)
norm_img_batch = np.concatenate(norm_img_batch) norm_img_batch = np.concatenate(norm_img_batch)
norm_img_batch = norm_img_batch.copy() norm_img_batch = norm_img_batch.copy()
if self.benchmark: if self.benchmark:
......
...@@ -129,7 +129,7 @@ def main(config, device, logger, vdl_writer): ...@@ -129,7 +129,7 @@ def main(config, device, logger, vdl_writer):
config['Optimizer'], config['Optimizer'],
epochs=config['Global']['epoch_num'], epochs=config['Global']['epoch_num'],
step_each_epoch=len(train_dataloader), step_each_epoch=len(train_dataloader),
parameters=model.parameters()) model=model)
# build metric # build metric
eval_class = build_metric(config['Metric']) eval_class = build_metric(config['Metric'])
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册