From 0a343fd3aae14b866281d897474d0669b1b26282 Mon Sep 17 00:00:00 2001 From: Topdu <784990967@qq.com> Date: Thu, 23 Jun 2022 08:17:00 +0000 Subject: [PATCH] svtr ch large model --- configs/rec/rec_svtrnet.yml | 4 +- configs/rec/rec_svtrnet_ch.yml | 159 ++++++++++++++++++++++++++++ doc/doc_ch/algorithm_rec_svtr.md | 1 - doc/doc_en/algorithm_rec_svtr_en.md | 1 - ppocr/data/imaug/__init__.py | 3 +- ppocr/data/imaug/rec_img_aug.py | 124 +++++++++++++--------- tools/export_model.py | 19 +++- 7 files changed, 252 insertions(+), 59 deletions(-) create mode 100644 configs/rec/rec_svtrnet_ch.yml diff --git a/configs/rec/rec_svtrnet.yml b/configs/rec/rec_svtrnet.yml index a3d292b6..5896b105 100644 --- a/configs/rec/rec_svtrnet.yml +++ b/configs/rec/rec_svtrnet.yml @@ -83,7 +83,7 @@ Train: img_mode: BGR channel_first: False - CTCLabelEncode: # Class handling label - - RecResizeImg: + - SVTRRecResizeImg: character_dict_path: image_shape: [3, 64, 256] padding: False @@ -104,7 +104,7 @@ Eval: img_mode: BGR channel_first: False - CTCLabelEncode: # Class handling label - - RecResizeImg: + - SVTRRecResizeImg: character_dict_path: image_shape: [3, 64, 256] padding: False diff --git a/configs/rec/rec_svtrnet_ch.yml b/configs/rec/rec_svtrnet_ch.yml new file mode 100644 index 00000000..b327edf4 --- /dev/null +++ b/configs/rec/rec_svtrnet_ch.yml @@ -0,0 +1,159 @@ +Global: + use_gpu: true + epoch_num: 100 + log_smooth_window: 20 + print_batch_step: 10 + save_model_dir: ./output/rec/svtr_ch_all/ + save_epoch_step: 10 + eval_batch_step: + - 0 + - 2000 + cal_metric_during_train: true + pretrained_model: null + checkpoints: null + save_inference_dir: null + use_visualdl: false + infer_img: doc/imgs_words/ch/word_1.jpg + character_dict_path: ppocr/utils/ppocr_keys_v1.txt + max_text_length: 25 + infer_mode: false + use_space_char: true + save_res_path: ./output/rec/predicts_svtr_tiny_ch_all.txt +Optimizer: + name: AdamW + beta1: 0.9 + beta2: 0.99 + epsilon: 8.0e-08 + weight_decay: 0.05 + no_weight_decay_name: norm pos_embed + one_dim_param_no_weight_decay: true + lr: + name: Cosine + learning_rate: 0.0005 + warmup_epoch: 2 +Architecture: + model_type: rec + algorithm: SVTR + Transform: null + Backbone: + name: SVTRNet + img_size: + - 32 + - 320 + out_char_num: 40 + out_channels: 96 + patch_merging: Conv + embed_dim: + - 64 + - 128 + - 256 + depth: + - 3 + - 6 + - 3 + num_heads: + - 2 + - 4 + - 8 + mixer: + - Local + - Local + - Local + - Local + - Local + - Local + - Global + - Global + - Global + - Global + - Global + - Global + local_mixer: + - - 7 + - 11 + - - 7 + - 11 + - - 7 + - 11 + last_stage: true + prenorm: false + Neck: + name: SequenceEncoder + encoder_type: reshape + Head: + name: CTCHead +Loss: + name: CTCLoss +PostProcess: + name: CTCLabelDecode +Metric: + name: RecMetric + main_indicator: acc +Train: + dataset: + name: SimpleDataSet + label_file_list: + - /paddle/data/ocr_all/train_all_list.txt + data_dir: /paddle/data/ocr_all + ext_op_transform_idx: 1 + transforms: + - DecodeImage: + img_mode: BGR + channel_first: false + - RecConAug: + prob: 0.5 + ext_data_num: 2 + image_shape: + - 32 + - 320 + - 3 + - RecAug: null + - CTCLabelEncode: null + - SVTRRecResizeImg: + character_dict_path: null + infer_mode: False + image_shape: + - 3 + - 32 + - 320 + padding: true + - KeepKeys: + keep_keys: + - image + - label + - length + loader: + shuffle: true + batch_size_per_card: 256 + drop_last: true + num_workers: 8 +Eval: + dataset: + name: SimpleDataSet + data_dir: /paddle/data/ocr_all + label_file_list: + - /paddle/data/ocr_all/val_list.txt + transforms: + - DecodeImage: + img_mode: BGR + channel_first: false + - CTCLabelEncode: null + - SVTRRecResizeImg: + character_dict_path: null + infer_mode: False + image_shape: + - 3 + - 32 + - 320 + padding: true + - KeepKeys: + keep_keys: + - image + - label + - length + loader: + shuffle: false + drop_last: false + batch_size_per_card: 256 + num_workers: 2 +profiler_options: null diff --git a/doc/doc_ch/algorithm_rec_svtr.md b/doc/doc_ch/algorithm_rec_svtr.md index 41a22ca6..c0e26433 100644 --- a/doc/doc_ch/algorithm_rec_svtr.md +++ b/doc/doc_ch/algorithm_rec_svtr.md @@ -111,7 +111,6 @@ python3 tools/export_model.py -c ./rec_svtr_tiny_none_ctc_en_train/rec_svtr_tiny **注意:** - 如果您是在自己的数据集上训练的模型,并且调整了字典文件,请注意修改配置文件中的`character_dict_path`是否为所正确的字典文件。 -- 如果您修改了训练时的输入大小,请修改`tools/export_model.py`文件中的对应SVTR的`infer_shape`。 转换成功后,在目录下有三个文件: ``` diff --git a/doc/doc_en/algorithm_rec_svtr_en.md b/doc/doc_en/algorithm_rec_svtr_en.md index d402a6b4..37cd35f3 100644 --- a/doc/doc_en/algorithm_rec_svtr_en.md +++ b/doc/doc_en/algorithm_rec_svtr_en.md @@ -88,7 +88,6 @@ python3 tools/export_model.py -c configs/rec/rec_svtrnet.yml -o Global.pretraine **Note:** - If you are training the model on your own dataset and have modified the dictionary file, please pay attention to modify the `character_dict_path` in the configuration file to the modified dictionary file. -- If you modified the input size during training, please modify the `infer_shape` corresponding to SVTR in the `tools/export_model.py` file. After the conversion is successful, there are three files in the directory: ``` diff --git a/ppocr/data/imaug/__init__.py b/ppocr/data/imaug/__init__.py index 58e8a5c7..437e0152 100644 --- a/ppocr/data/imaug/__init__.py +++ b/ppocr/data/imaug/__init__.py @@ -24,7 +24,8 @@ from .make_pse_gt import MakePseGt from .rec_img_aug import RecAug, RecConAug, RecResizeImg, ClsResizeImg, \ - SRNRecResizeImg, GrayRecResizeImg, SARRecResizeImg, PRENResizeImg, ABINetRecResizeImg + SRNRecResizeImg, GrayRecResizeImg, SARRecResizeImg, PRENResizeImg, \ + ABINetRecResizeImg, SVTRRecResizeImg from .ssl_img_aug import SSLRotateResize from .randaugment import RandAugment from .copy_paste import CopyPaste diff --git a/ppocr/data/imaug/rec_img_aug.py b/ppocr/data/imaug/rec_img_aug.py index a8b3b813..874d9aa0 100644 --- a/ppocr/data/imaug/rec_img_aug.py +++ b/ppocr/data/imaug/rec_img_aug.py @@ -148,57 +148,6 @@ class ClsResizeImg(object): return data -class GrayRecResizeImg(object): - def __init__(self, - image_shape, - resize_type, - inter_type='Image.ANTIALIAS', - scale=True, - padding=False, - **kwargs): - self.image_shape = image_shape - self.resize_type = resize_type - self.padding = padding - self.inter_type = eval(inter_type) - self.scale = scale - - def __call__(self, data): - img = data['image'] - img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) - image_shape = self.image_shape - if self.padding: - imgC, imgH, imgW = image_shape - # todo: change to 0 and modified image shape - h = img.shape[0] - w = img.shape[1] - ratio = w / float(h) - if math.ceil(imgH * ratio) > imgW: - resized_w = imgW - else: - resized_w = int(math.ceil(imgH * ratio)) - resized_image = cv2.resize(img, (resized_w, imgH)) - norm_img = np.expand_dims(resized_image, -1) - norm_img = norm_img.transpose((2, 0, 1)) - resized_image = norm_img.astype(np.float32) / 128. - 1. - padding_im = np.zeros((imgC, imgH, imgW), dtype=np.float32) - padding_im[:, :, 0:resized_w] = resized_image - data['image'] = padding_im - return data - if self.resize_type == 'PIL': - image_pil = Image.fromarray(np.uint8(img)) - img = image_pil.resize(self.image_shape, self.inter_type) - img = np.array(img) - if self.resize_type == 'OpenCV': - img = cv2.resize(img, self.image_shape) - norm_img = np.expand_dims(img, -1) - norm_img = norm_img.transpose((2, 0, 1)) - if self.scale: - data['image'] = norm_img.astype(np.float32) / 128. - 1. - else: - data['image'] = norm_img.astype(np.float32) / 255. - return data - - class RecResizeImg(object): def __init__(self, image_shape, @@ -279,6 +228,57 @@ class PRENResizeImg(object): return data +class GrayRecResizeImg(object): + def __init__(self, + image_shape, + resize_type, + inter_type='Image.ANTIALIAS', + scale=True, + padding=False, + **kwargs): + self.image_shape = image_shape + self.resize_type = resize_type + self.padding = padding + self.inter_type = eval(inter_type) + self.scale = scale + + def __call__(self, data): + img = data['image'] + img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) + image_shape = self.image_shape + if self.padding: + imgC, imgH, imgW = image_shape + # todo: change to 0 and modified image shape + h = img.shape[0] + w = img.shape[1] + ratio = w / float(h) + if math.ceil(imgH * ratio) > imgW: + resized_w = imgW + else: + resized_w = int(math.ceil(imgH * ratio)) + resized_image = cv2.resize(img, (resized_w, imgH)) + norm_img = np.expand_dims(resized_image, -1) + norm_img = norm_img.transpose((2, 0, 1)) + resized_image = norm_img.astype(np.float32) / 128. - 1. + padding_im = np.zeros((imgC, imgH, imgW), dtype=np.float32) + padding_im[:, :, 0:resized_w] = resized_image + data['image'] = padding_im + return data + if self.resize_type == 'PIL': + image_pil = Image.fromarray(np.uint8(img)) + img = image_pil.resize(self.image_shape, self.inter_type) + img = np.array(img) + if self.resize_type == 'OpenCV': + img = cv2.resize(img, self.image_shape) + norm_img = np.expand_dims(img, -1) + norm_img = norm_img.transpose((2, 0, 1)) + if self.scale: + data['image'] = norm_img.astype(np.float32) / 128. - 1. + else: + data['image'] = norm_img.astype(np.float32) / 255. + return data + + class ABINetRecResizeImg(object): def __init__(self, image_shape, @@ -297,6 +297,28 @@ class ABINetRecResizeImg(object): return data +class SVTRRecResizeImg(object): + def __init__(self, + image_shape, + infer_mode=False, + character_dict_path='./ppocr/utils/ppocr_keys_v1.txt', + padding=True, + **kwargs): + self.image_shape = image_shape + self.infer_mode = infer_mode + self.character_dict_path = character_dict_path + self.padding = padding + + def __call__(self, data): + img = data['image'] + + norm_img, valid_ratio = resize_norm_img(img, self.image_shape, + self.padding) + data['image'] = norm_img + data['valid_ratio'] = valid_ratio + return data + + def resize_norm_img_sar(img, image_shape, width_downsample_ratio=0.25): imgC, imgH, imgW_min, imgW_max = image_shape h = img.shape[0] diff --git a/tools/export_model.py b/tools/export_model.py index e2673239..b10d41d5 100755 --- a/tools/export_model.py +++ b/tools/export_model.py @@ -31,7 +31,12 @@ from ppocr.utils.logging import get_logger from tools.program import load_config, merge_config, ArgsParser -def export_single_model(model, arch_config, save_path, logger, quanter=None): +def export_single_model(model, + arch_config, + save_path, + logger, + input_shape=None, + quanter=None): if arch_config["algorithm"] == "SRN": max_text_length = arch_config["Head"]["max_text_length"] other_shape = [ @@ -64,7 +69,7 @@ def export_single_model(model, arch_config, save_path, logger, quanter=None): else: other_shape = [ paddle.static.InputSpec( - shape=[None, 3, 64, 256], dtype="float32"), + shape=[None] + input_shape, dtype="float32"), ] model = to_static(model, input_spec=other_shape) elif arch_config["algorithm"] == "PREN": @@ -174,6 +179,13 @@ def main(): arch_config = config["Architecture"] + if arch_config["algorithm"] == "SVTR" and arch_config["Head"][ + "name"] != 'MultiHead': + input_shape = config["Eval"]["dataset"]["transforms"][-2][ + 'SVTRRecResizeImg']['image_shape'] + else: + input_shape = None + if arch_config["algorithm"] in ["Distillation", ]: # distillation model archs = list(arch_config["Models"].values()) for idx, name in enumerate(model.model_name_list): @@ -182,7 +194,8 @@ def main(): sub_model_save_path, logger) else: save_path = os.path.join(save_path, "inference") - export_single_model(model, arch_config, save_path, logger) + export_single_model( + model, arch_config, save_path, logger, input_shape=input_shape) if __name__ == "__main__": -- GitLab