未验证 提交 bde8cad0 编写于 作者: T topduke 提交者: GitHub

add svtr ch model (#7134)

上级 05b6d296
......@@ -83,8 +83,7 @@ Train:
img_mode: BGR
channel_first: False
- CTCLabelEncode: # Class handling label
- RecResizeImg:
character_dict_path:
- SVTRRecResizeImg:
image_shape: [3, 64, 256]
padding: False
- KeepKeys:
......@@ -98,14 +97,13 @@ Train:
Eval:
dataset:
name: LMDBDataSet
data_dir: ./train_data/data_lmdb_release/validation/
data_dir: ./train_data/data_lmdb_release/evaluation/
transforms:
- DecodeImage: # load image
img_mode: BGR
channel_first: False
- CTCLabelEncode: # Class handling label
- RecResizeImg:
character_dict_path:
- SVTRRecResizeImg:
image_shape: [3, 64, 256]
padding: False
- KeepKeys:
......
Global:
use_gpu: true
epoch_num: 100
log_smooth_window: 20
print_batch_step: 10
save_model_dir: ./output/rec/svtr_ch_all/
save_epoch_step: 10
eval_batch_step:
- 0
- 2000
cal_metric_during_train: true
pretrained_model: null
checkpoints: null
save_inference_dir: null
use_visualdl: false
infer_img: doc/imgs_words/ch/word_1.jpg
character_dict_path: ppocr/utils/ppocr_keys_v1.txt
max_text_length: 25
infer_mode: false
use_space_char: true
save_res_path: ./output/rec/predicts_svtr_tiny_ch_all.txt
Optimizer:
name: AdamW
beta1: 0.9
beta2: 0.99
epsilon: 8.0e-08
weight_decay: 0.05
no_weight_decay_name: norm pos_embed
one_dim_param_no_weight_decay: true
lr:
name: Cosine
learning_rate: 0.0005
warmup_epoch: 2
Architecture:
model_type: rec
algorithm: SVTR
Transform: null
Backbone:
name: SVTRNet
img_size:
- 32
- 320
out_char_num: 40
out_channels: 96
patch_merging: Conv
embed_dim:
- 64
- 128
- 256
depth:
- 3
- 6
- 3
num_heads:
- 2
- 4
- 8
mixer:
- Local
- Local
- Local
- Local
- Local
- Local
- Global
- Global
- Global
- Global
- Global
- Global
local_mixer:
- - 7
- 11
- - 7
- 11
- - 7
- 11
last_stage: true
prenorm: false
Neck:
name: SequenceEncoder
encoder_type: reshape
Head:
name: CTCHead
Loss:
name: CTCLoss
PostProcess:
name: CTCLabelDecode
Metric:
name: RecMetric
main_indicator: acc
Train:
dataset:
name: SimpleDataSet
data_dir: ./train_data
label_file_list:
- ./train_data/train_list.txt
ext_op_transform_idx: 1
transforms:
- DecodeImage:
img_mode: BGR
channel_first: false
- RecConAug:
prob: 0.5
ext_data_num: 2
image_shape:
- 32
- 320
- 3
- RecAug: null
- CTCLabelEncode: null
- SVTRRecResizeImg:
image_shape:
- 3
- 32
- 320
padding: true
- KeepKeys:
keep_keys:
- image
- label
- length
loader:
shuffle: true
batch_size_per_card: 256
drop_last: true
num_workers: 8
Eval:
dataset:
name: SimpleDataSet
data_dir: ./train_data
label_file_list:
- ./train_data/val_list.txt
transforms:
- DecodeImage:
img_mode: BGR
channel_first: false
- CTCLabelEncode: null
- SVTRRecResizeImg:
image_shape:
- 3
- 32
- 320
padding: true
- KeepKeys:
keep_keys:
- image
- label
- length
loader:
shuffle: false
drop_last: false
batch_size_per_card: 256
num_workers: 2
profiler_options: null
......@@ -23,7 +23,8 @@ from .random_crop_data import EastRandomCropData, RandomCropImgMask
from .make_pse_gt import MakePseGt
from .rec_img_aug import RecAug, RecConAug, RecResizeImg, ClsResizeImg, \
SRNRecResizeImg, NRTRRecResizeImg, SARRecResizeImg, PRENResizeImg
SRNRecResizeImg, NRTRRecResizeImg, SARRecResizeImg, PRENResizeImg, \
SVTRRecResizeImg
from .ssl_img_aug import SSLRotateResize
from .randaugment import RandAugment
from .copy_paste import CopyPaste
......
......@@ -207,6 +207,21 @@ class PRENResizeImg(object):
return data
class SVTRRecResizeImg(object):
def __init__(self, image_shape, padding=True, **kwargs):
self.image_shape = image_shape
self.padding = padding
def __call__(self, data):
img = data['image']
norm_img, valid_ratio = resize_norm_img(img, self.image_shape,
self.padding)
data['image'] = norm_img
data['valid_ratio'] = valid_ratio
return data
def resize_norm_img_sar(img, image_shape, width_downsample_ratio=0.25):
imgC, imgH, imgW_min, imgW_max = image_shape
h = img.shape[0]
......
......@@ -31,7 +31,12 @@ from ppocr.utils.logging import get_logger
from tools.program import load_config, merge_config, ArgsParser
def export_single_model(model, arch_config, save_path, logger, quanter=None):
def export_single_model(model,
arch_config,
save_path,
logger,
input_shape=None,
quanter=None):
if arch_config["algorithm"] == "SRN":
max_text_length = arch_config["Head"]["max_text_length"]
other_shape = [
......@@ -64,7 +69,7 @@ def export_single_model(model, arch_config, save_path, logger, quanter=None):
else:
other_shape = [
paddle.static.InputSpec(
shape=[None, 3, 64, 256], dtype="float32"),
shape=[None] + input_shape, dtype="float32"),
]
model = to_static(model, input_spec=other_shape)
elif arch_config["algorithm"] == "PREN":
......@@ -157,6 +162,13 @@ def main():
arch_config = config["Architecture"]
if arch_config["algorithm"] == "SVTR" and arch_config["Head"][
"name"] != 'MultiHead':
input_shape = config["Eval"]["dataset"]["transforms"][-2][
'SVTRRecResizeImg']['image_shape']
else:
input_shape = None
if arch_config["algorithm"] in ["Distillation", ]: # distillation model
archs = list(arch_config["Models"].values())
for idx, name in enumerate(model.model_name_list):
......@@ -165,7 +177,8 @@ def main():
sub_model_save_path, logger)
else:
save_path = os.path.join(save_path, "inference")
export_single_model(model, arch_config, save_path, logger)
export_single_model(
model, arch_config, save_path, logger, input_shape=input_shape)
if __name__ == "__main__":
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册