diff --git a/configs/rec/rec_svtrnet.yml b/configs/rec/rec_svtrnet.yml index 233d5e276577cad0144456ef7df1e20de99891f9..f3195948553233dd315fd6b7e923e608d99cbb65 100644 --- a/configs/rec/rec_svtrnet.yml +++ b/configs/rec/rec_svtrnet.yml @@ -83,8 +83,7 @@ Train: img_mode: BGR channel_first: False - CTCLabelEncode: # Class handling label - - RecResizeImg: - character_dict_path: + - SVTRRecResizeImg: image_shape: [3, 64, 256] padding: False - KeepKeys: @@ -98,14 +97,13 @@ Train: Eval: dataset: name: LMDBDataSet - data_dir: ./train_data/data_lmdb_release/validation/ + data_dir: ./train_data/data_lmdb_release/evaluation/ transforms: - DecodeImage: # load image img_mode: BGR channel_first: False - CTCLabelEncode: # Class handling label - - RecResizeImg: - character_dict_path: + - SVTRRecResizeImg: image_shape: [3, 64, 256] padding: False - KeepKeys: diff --git a/configs/rec/rec_svtrnet_ch.yml b/configs/rec/rec_svtrnet_ch.yml new file mode 100644 index 0000000000000000000000000000000000000000..0d3f63d125ea12fa097fe49454b04423710e2f68 --- /dev/null +++ b/configs/rec/rec_svtrnet_ch.yml @@ -0,0 +1,155 @@ +Global: + use_gpu: true + epoch_num: 100 + log_smooth_window: 20 + print_batch_step: 10 + save_model_dir: ./output/rec/svtr_ch_all/ + save_epoch_step: 10 + eval_batch_step: + - 0 + - 2000 + cal_metric_during_train: true + pretrained_model: null + checkpoints: null + save_inference_dir: null + use_visualdl: false + infer_img: doc/imgs_words/ch/word_1.jpg + character_dict_path: ppocr/utils/ppocr_keys_v1.txt + max_text_length: 25 + infer_mode: false + use_space_char: true + save_res_path: ./output/rec/predicts_svtr_tiny_ch_all.txt +Optimizer: + name: AdamW + beta1: 0.9 + beta2: 0.99 + epsilon: 8.0e-08 + weight_decay: 0.05 + no_weight_decay_name: norm pos_embed + one_dim_param_no_weight_decay: true + lr: + name: Cosine + learning_rate: 0.0005 + warmup_epoch: 2 +Architecture: + model_type: rec + algorithm: SVTR + Transform: null + Backbone: + name: SVTRNet + img_size: + - 32 + - 320 + out_char_num: 40 + out_channels: 96 + patch_merging: Conv + embed_dim: + - 64 + - 128 + - 256 + depth: + - 3 + - 6 + - 3 + num_heads: + - 2 + - 4 + - 8 + mixer: + - Local + - Local + - Local + - Local + - Local + - Local + - Global + - Global + - Global + - Global + - Global + - Global + local_mixer: + - - 7 + - 11 + - - 7 + - 11 + - - 7 + - 11 + last_stage: true + prenorm: false + Neck: + name: SequenceEncoder + encoder_type: reshape + Head: + name: CTCHead +Loss: + name: CTCLoss +PostProcess: + name: CTCLabelDecode +Metric: + name: RecMetric + main_indicator: acc +Train: + dataset: + name: SimpleDataSet + data_dir: ./train_data + label_file_list: + - ./train_data/train_list.txt + ext_op_transform_idx: 1 + transforms: + - DecodeImage: + img_mode: BGR + channel_first: false + - RecConAug: + prob: 0.5 + ext_data_num: 2 + image_shape: + - 32 + - 320 + - 3 + - RecAug: null + - CTCLabelEncode: null + - SVTRRecResizeImg: + image_shape: + - 3 + - 32 + - 320 + padding: true + - KeepKeys: + keep_keys: + - image + - label + - length + loader: + shuffle: true + batch_size_per_card: 256 + drop_last: true + num_workers: 8 +Eval: + dataset: + name: SimpleDataSet + data_dir: ./train_data + label_file_list: + - ./train_data/val_list.txt + transforms: + - DecodeImage: + img_mode: BGR + channel_first: false + - CTCLabelEncode: null + - SVTRRecResizeImg: + image_shape: + - 3 + - 32 + - 320 + padding: true + - KeepKeys: + keep_keys: + - image + - label + - length + loader: + shuffle: false + drop_last: false + batch_size_per_card: 256 + num_workers: 2 +profiler_options: null diff --git a/ppocr/data/imaug/__init__.py b/ppocr/data/imaug/__init__.py index 548832fb0d116ba2de622bd97562b591d74501d8..65497e63f9da03d8fc1fd1aa6baba673461ab8bc 100644 --- a/ppocr/data/imaug/__init__.py +++ b/ppocr/data/imaug/__init__.py @@ -23,7 +23,8 @@ from .random_crop_data import EastRandomCropData, RandomCropImgMask from .make_pse_gt import MakePseGt from .rec_img_aug import RecAug, RecConAug, RecResizeImg, ClsResizeImg, \ - SRNRecResizeImg, NRTRRecResizeImg, SARRecResizeImg, PRENResizeImg + SRNRecResizeImg, NRTRRecResizeImg, SARRecResizeImg, PRENResizeImg, \ + SVTRRecResizeImg from .ssl_img_aug import SSLRotateResize from .randaugment import RandAugment from .copy_paste import CopyPaste diff --git a/ppocr/data/imaug/rec_img_aug.py b/ppocr/data/imaug/rec_img_aug.py index 7483dffe5b6d9a0a2204702757fcb49762a1cc7a..2c897dce07a7867b9dd2eed1e7b24fa046336f8c 100644 --- a/ppocr/data/imaug/rec_img_aug.py +++ b/ppocr/data/imaug/rec_img_aug.py @@ -207,6 +207,21 @@ class PRENResizeImg(object): return data +class SVTRRecResizeImg(object): + def __init__(self, image_shape, padding=True, **kwargs): + self.image_shape = image_shape + self.padding = padding + + def __call__(self, data): + img = data['image'] + + norm_img, valid_ratio = resize_norm_img(img, self.image_shape, + self.padding) + data['image'] = norm_img + data['valid_ratio'] = valid_ratio + return data + + def resize_norm_img_sar(img, image_shape, width_downsample_ratio=0.25): imgC, imgH, imgW_min, imgW_max = image_shape h = img.shape[0] diff --git a/tools/export_model.py b/tools/export_model.py index 3ea0228f857a2fadb36678ecd3b91bc865e56e46..8ccaea2908b6e6121e0d30f2769f6e93bef49392 100755 --- a/tools/export_model.py +++ b/tools/export_model.py @@ -31,7 +31,12 @@ from ppocr.utils.logging import get_logger from tools.program import load_config, merge_config, ArgsParser -def export_single_model(model, arch_config, save_path, logger, quanter=None): +def export_single_model(model, + arch_config, + save_path, + logger, + input_shape=None, + quanter=None): if arch_config["algorithm"] == "SRN": max_text_length = arch_config["Head"]["max_text_length"] other_shape = [ @@ -64,7 +69,7 @@ def export_single_model(model, arch_config, save_path, logger, quanter=None): else: other_shape = [ paddle.static.InputSpec( - shape=[None, 3, 64, 256], dtype="float32"), + shape=[None] + input_shape, dtype="float32"), ] model = to_static(model, input_spec=other_shape) elif arch_config["algorithm"] == "PREN": @@ -157,6 +162,13 @@ def main(): arch_config = config["Architecture"] + if arch_config["algorithm"] == "SVTR" and arch_config["Head"][ + "name"] != 'MultiHead': + input_shape = config["Eval"]["dataset"]["transforms"][-2][ + 'SVTRRecResizeImg']['image_shape'] + else: + input_shape = None + if arch_config["algorithm"] in ["Distillation", ]: # distillation model archs = list(arch_config["Models"].values()) for idx, name in enumerate(model.model_name_list): @@ -165,7 +177,8 @@ def main(): sub_model_save_path, logger) else: save_path = os.path.join(save_path, "inference") - export_single_model(model, arch_config, save_path, logger) + export_single_model( + model, arch_config, save_path, logger, input_shape=input_shape) if __name__ == "__main__":