From dd02918bb4488b3113e8666bbaef322a61ab246c Mon Sep 17 00:00:00 2001 From: xiaoting <31891223+tink2123@users.noreply.github.com> Date: Mon, 13 Feb 2023 11:17:58 +0800 Subject: [PATCH] rename PP-OCRv3 algorithm to SVTR_LCNet (#9025) * rename PP-OCRv3 algorithm to SVTR_LCNet * rename PP-OCRv3 algorithm to SVTR_LCNet * update multi-lang config for v3 --- configs/rec/PP-OCRv3/ch_PP-OCRv3_rec.yml | 2 +- .../PP-OCRv3/ch_PP-OCRv3_rec_distillation.yml | 4 ++-- configs/rec/PP-OCRv3/en_PP-OCRv3_rec.yml | 2 +- .../multi_language/arabic_PP-OCRv3_rec.yml | 2 +- .../chinese_cht_PP-OCRv3_rec.yml | 2 +- .../multi_language/cyrillic_PP-OCRv3_rec.yml | 2 +- .../devanagari_PP-OCRv3_rec.yml | 2 +- .../multi_language/japan_PP-OCRv3_rec.yml | 2 +- .../multi_language/ka_PP-OCRv3_rec.yml | 2 +- .../multi_language/korean_PP-OCRv3_rec.yml | 2 +- .../multi_language/latin_PP-OCRv3_rec.yml | 2 +- .../multi_language/ta_PP-OCRv3_rec.yml | 2 +- .../multi_language/te_PP-OCRv3_rec.yml | 2 +- ppocr/modeling/architectures/__init__.py | 4 ++-- tools/eval.py | 3 ++- tools/export_model.py | 20 +++++++++---------- tools/program.py | 8 ++++---- 17 files changed, 32 insertions(+), 31 deletions(-) diff --git a/configs/rec/PP-OCRv3/ch_PP-OCRv3_rec.yml b/configs/rec/PP-OCRv3/ch_PP-OCRv3_rec.yml index 7e98280b..aa974d16 100644 --- a/configs/rec/PP-OCRv3/ch_PP-OCRv3_rec.yml +++ b/configs/rec/PP-OCRv3/ch_PP-OCRv3_rec.yml @@ -36,7 +36,7 @@ Optimizer: Architecture: model_type: rec - algorithm: SVTR + algorithm: SVTR_LCNet Transform: Backbone: name: MobileNetV1Enhance diff --git a/configs/rec/PP-OCRv3/ch_PP-OCRv3_rec_distillation.yml b/configs/rec/PP-OCRv3/ch_PP-OCRv3_rec_distillation.yml index afa012b6..7843f02a 100644 --- a/configs/rec/PP-OCRv3/ch_PP-OCRv3_rec_distillation.yml +++ b/configs/rec/PP-OCRv3/ch_PP-OCRv3_rec_distillation.yml @@ -45,7 +45,7 @@ Architecture: freeze_params: false return_all_feats: true model_type: *model_type - algorithm: SVTR + algorithm: SVTR_LCNet Transform: Backbone: name: MobileNetV1Enhance @@ -72,7 +72,7 @@ Architecture: freeze_params: false return_all_feats: true model_type: *model_type - algorithm: SVTR + algorithm: SVTR_LCNet Transform: Backbone: name: MobileNetV1Enhance diff --git a/configs/rec/PP-OCRv3/en_PP-OCRv3_rec.yml b/configs/rec/PP-OCRv3/en_PP-OCRv3_rec.yml index c728e0ac..51e29e9b 100644 --- a/configs/rec/PP-OCRv3/en_PP-OCRv3_rec.yml +++ b/configs/rec/PP-OCRv3/en_PP-OCRv3_rec.yml @@ -36,7 +36,7 @@ Optimizer: Architecture: model_type: rec - algorithm: SVTR + algorithm: SVTR_LCNet Transform: Backbone: name: MobileNetV1Enhance diff --git a/configs/rec/PP-OCRv3/multi_language/arabic_PP-OCRv3_rec.yml b/configs/rec/PP-OCRv3/multi_language/arabic_PP-OCRv3_rec.yml index 8c650bd8..6c78f0af 100644 --- a/configs/rec/PP-OCRv3/multi_language/arabic_PP-OCRv3_rec.yml +++ b/configs/rec/PP-OCRv3/multi_language/arabic_PP-OCRv3_rec.yml @@ -36,7 +36,7 @@ Optimizer: Architecture: model_type: rec - algorithm: SVTR + algorithm: SVTR_LCNet Transform: Backbone: name: MobileNetV1Enhance diff --git a/configs/rec/PP-OCRv3/multi_language/chinese_cht_PP-OCRv3_rec.yml b/configs/rec/PP-OCRv3/multi_language/chinese_cht_PP-OCRv3_rec.yml index 28e0c10a..2c9b6182 100644 --- a/configs/rec/PP-OCRv3/multi_language/chinese_cht_PP-OCRv3_rec.yml +++ b/configs/rec/PP-OCRv3/multi_language/chinese_cht_PP-OCRv3_rec.yml @@ -36,7 +36,7 @@ Optimizer: Architecture: model_type: rec - algorithm: SVTR + algorithm: SVTR_LCNet Transform: Backbone: name: MobileNetV1Enhance diff --git a/configs/rec/PP-OCRv3/multi_language/cyrillic_PP-OCRv3_rec.yml b/configs/rec/PP-OCRv3/multi_language/cyrillic_PP-OCRv3_rec.yml index fbdbe6c4..6faffb72 100644 --- a/configs/rec/PP-OCRv3/multi_language/cyrillic_PP-OCRv3_rec.yml +++ b/configs/rec/PP-OCRv3/multi_language/cyrillic_PP-OCRv3_rec.yml @@ -36,7 +36,7 @@ Optimizer: Architecture: model_type: rec - algorithm: SVTR + algorithm: SVTR_LCNet Transform: Backbone: name: MobileNetV1Enhance diff --git a/configs/rec/PP-OCRv3/multi_language/devanagari_PP-OCRv3_rec.yml b/configs/rec/PP-OCRv3/multi_language/devanagari_PP-OCRv3_rec.yml index 48eb38df..488994d2 100644 --- a/configs/rec/PP-OCRv3/multi_language/devanagari_PP-OCRv3_rec.yml +++ b/configs/rec/PP-OCRv3/multi_language/devanagari_PP-OCRv3_rec.yml @@ -36,7 +36,7 @@ Optimizer: Architecture: model_type: rec - algorithm: SVTR + algorithm: SVTR_LCNet Transform: Backbone: name: MobileNetV1Enhance diff --git a/configs/rec/PP-OCRv3/multi_language/japan_PP-OCRv3_rec.yml b/configs/rec/PP-OCRv3/multi_language/japan_PP-OCRv3_rec.yml index 6cab0d44..b3ddde39 100644 --- a/configs/rec/PP-OCRv3/multi_language/japan_PP-OCRv3_rec.yml +++ b/configs/rec/PP-OCRv3/multi_language/japan_PP-OCRv3_rec.yml @@ -36,7 +36,7 @@ Optimizer: Architecture: model_type: rec - algorithm: SVTR + algorithm: SVTR_LCNet Transform: Backbone: name: MobileNetV1Enhance diff --git a/configs/rec/PP-OCRv3/multi_language/ka_PP-OCRv3_rec.yml b/configs/rec/PP-OCRv3/multi_language/ka_PP-OCRv3_rec.yml index 7a9c8241..3e845032 100644 --- a/configs/rec/PP-OCRv3/multi_language/ka_PP-OCRv3_rec.yml +++ b/configs/rec/PP-OCRv3/multi_language/ka_PP-OCRv3_rec.yml @@ -36,7 +36,7 @@ Optimizer: Architecture: model_type: rec - algorithm: SVTR + algorithm: SVTR_LCNet Transform: Backbone: name: MobileNetV1Enhance diff --git a/configs/rec/PP-OCRv3/multi_language/korean_PP-OCRv3_rec.yml b/configs/rec/PP-OCRv3/multi_language/korean_PP-OCRv3_rec.yml index 29ff5707..17ea15fa 100644 --- a/configs/rec/PP-OCRv3/multi_language/korean_PP-OCRv3_rec.yml +++ b/configs/rec/PP-OCRv3/multi_language/korean_PP-OCRv3_rec.yml @@ -36,7 +36,7 @@ Optimizer: Architecture: model_type: rec - algorithm: SVTR + algorithm: SVTR_LCNet Transform: Backbone: name: MobileNetV1Enhance diff --git a/configs/rec/PP-OCRv3/multi_language/latin_PP-OCRv3_rec.yml b/configs/rec/PP-OCRv3/multi_language/latin_PP-OCRv3_rec.yml index 1784bfe6..34ade4af 100644 --- a/configs/rec/PP-OCRv3/multi_language/latin_PP-OCRv3_rec.yml +++ b/configs/rec/PP-OCRv3/multi_language/latin_PP-OCRv3_rec.yml @@ -36,7 +36,7 @@ Optimizer: Architecture: model_type: rec - algorithm: SVTR + algorithm: SVTR_LCNet Transform: Backbone: name: MobileNetV1Enhance diff --git a/configs/rec/PP-OCRv3/multi_language/ta_PP-OCRv3_rec.yml b/configs/rec/PP-OCRv3/multi_language/ta_PP-OCRv3_rec.yml index 70b26aa8..17cc870e 100644 --- a/configs/rec/PP-OCRv3/multi_language/ta_PP-OCRv3_rec.yml +++ b/configs/rec/PP-OCRv3/multi_language/ta_PP-OCRv3_rec.yml @@ -36,7 +36,7 @@ Optimizer: Architecture: model_type: rec - algorithm: SVTR + algorithm: SVTR_LCNet Transform: Backbone: name: MobileNetV1Enhance diff --git a/configs/rec/PP-OCRv3/multi_language/te_PP-OCRv3_rec.yml b/configs/rec/PP-OCRv3/multi_language/te_PP-OCRv3_rec.yml index 3617af79..5f3dad7c 100644 --- a/configs/rec/PP-OCRv3/multi_language/te_PP-OCRv3_rec.yml +++ b/configs/rec/PP-OCRv3/multi_language/te_PP-OCRv3_rec.yml @@ -36,7 +36,7 @@ Optimizer: Architecture: model_type: rec - algorithm: SVTR + algorithm: SVTR_LCNet Transform: Backbone: name: MobileNetV1Enhance diff --git a/ppocr/modeling/architectures/__init__.py b/ppocr/modeling/architectures/__init__.py index 1c955ef3..384ae4cc 100755 --- a/ppocr/modeling/architectures/__init__.py +++ b/ppocr/modeling/architectures/__init__.py @@ -40,7 +40,7 @@ def apply_to_static(model, config, logger): return model assert "image_shape" in config[ "Global"], "image_shape must be assigned for static training mode..." - supported_list = ["DB", "SVTR"] + supported_list = ["DB", "SVTR_LCNet"] if config["Architecture"]["algorithm"] in ["Distillation"]: algo = list(config["Architecture"]["Models"].values())[0]["algorithm"] else: @@ -52,7 +52,7 @@ def apply_to_static(model, config, logger): [None] + config["Global"]["image_shape"], dtype='float32') ] - if algo == "SVTR": + if algo == "SVTR_LCNet": specs.append([ InputSpec( [None, config["Global"]["max_text_length"]], diff --git a/tools/eval.py b/tools/eval.py index 21f4d94d..8b58b660 100755 --- a/tools/eval.py +++ b/tools/eval.py @@ -75,7 +75,8 @@ def main(): model = build_model(config['Architecture']) extra_input_models = [ - "SRN", "NRTR", "SAR", "SEED", "SVTR", "VisionLAN", "RobustScanner" + "SRN", "NRTR", "SAR", "SEED", "SVTR", "SVTR_LCNet", "VisionLAN", + "RobustScanner" ] extra_input = False if config['Architecture']['algorithm'] == 'Distillation': diff --git a/tools/export_model.py b/tools/export_model.py index 9d3bf562..cf2a2d8b 100755 --- a/tools/export_model.py +++ b/tools/export_model.py @@ -62,17 +62,17 @@ def export_single_model(model, shape=[None], dtype="float32")] ] model = to_static(model, input_spec=other_shape) + elif arch_config["algorithm"] == "SVTR_LCNet": + other_shape = [ + paddle.static.InputSpec( + shape=[None, 3, 48, -1], dtype="float32"), + ] + model = to_static(model, input_spec=other_shape) elif arch_config["algorithm"] == "SVTR": - if arch_config["Head"]["name"] == 'MultiHead': - other_shape = [ - paddle.static.InputSpec( - shape=[None, 3, 48, -1], dtype="float32"), - ] - else: - other_shape = [ - paddle.static.InputSpec( - shape=[None] + input_shape, dtype="float32"), - ] + other_shape = [ + paddle.static.InputSpec( + shape=[None] + input_shape, dtype="float32"), + ] model = to_static(model, input_spec=other_shape) elif arch_config["algorithm"] == "PREN": other_shape = [ diff --git a/tools/program.py b/tools/program.py index 9134472b..f36620a9 100755 --- a/tools/program.py +++ b/tools/program.py @@ -219,7 +219,7 @@ def train(config, use_srn = config['Architecture']['algorithm'] == "SRN" extra_input_models = [ - "SRN", "NRTR", "SAR", "SEED", "SVTR", "SPIN", "VisionLAN", + "SRN", "NRTR", "SAR", "SEED", "SVTR", "SVTR_LCNet", "SPIN", "VisionLAN", "RobustScanner", "RFL", 'DRRG', 'SATRN' ] extra_input = False @@ -641,9 +641,9 @@ def preprocess(is_train=False): 'EAST', 'DB', 'SAST', 'Rosetta', 'CRNN', 'STARNet', 'RARE', 'SRN', 'CLS', 'PGNet', 'Distillation', 'NRTR', 'TableAttn', 'SAR', 'PSE', 'SEED', 'SDMGR', 'LayoutXLM', 'LayoutLM', 'LayoutLMv2', 'PREN', 'FCE', - 'SVTR', 'ViTSTR', 'ABINet', 'DB++', 'TableMaster', 'SPIN', 'VisionLAN', - 'Gestalt', 'SLANet', 'RobustScanner', 'CT', 'RFL', 'DRRG', 'CAN', - 'Telescope', 'SATRN' + 'SVTR', 'SVTR_LCNet', 'ViTSTR', 'ABINet', 'DB++', 'TableMaster', 'SPIN', + 'VisionLAN', 'Gestalt', 'SLANet', 'RobustScanner', 'CT', 'RFL', 'DRRG', + 'CAN', 'Telescope', 'SATRN' ] if use_xpu: -- GitLab