diff --git a/configs/rec/PP-OCRv3/ch_PP-OCRv3_rec.yml b/configs/rec/PP-OCRv3/ch_PP-OCRv3_rec.yml index 7e98280b32558b8d3d203084e6e327bc7cd782bf..aa974d16586d2a0974835ee5801227239d37320c 100644 --- a/configs/rec/PP-OCRv3/ch_PP-OCRv3_rec.yml +++ b/configs/rec/PP-OCRv3/ch_PP-OCRv3_rec.yml @@ -36,7 +36,7 @@ Optimizer: Architecture: model_type: rec - algorithm: SVTR + algorithm: SVTR_LCNet Transform: Backbone: name: MobileNetV1Enhance diff --git a/configs/rec/PP-OCRv3/ch_PP-OCRv3_rec_distillation.yml b/configs/rec/PP-OCRv3/ch_PP-OCRv3_rec_distillation.yml index afa012b607eec2c223c94d2d0b09fbf4b2704f1b..7843f02a23d253fbcbbe65b3e86d3a22c25958de 100644 --- a/configs/rec/PP-OCRv3/ch_PP-OCRv3_rec_distillation.yml +++ b/configs/rec/PP-OCRv3/ch_PP-OCRv3_rec_distillation.yml @@ -45,7 +45,7 @@ Architecture: freeze_params: false return_all_feats: true model_type: *model_type - algorithm: SVTR + algorithm: SVTR_LCNet Transform: Backbone: name: MobileNetV1Enhance @@ -72,7 +72,7 @@ Architecture: freeze_params: false return_all_feats: true model_type: *model_type - algorithm: SVTR + algorithm: SVTR_LCNet Transform: Backbone: name: MobileNetV1Enhance diff --git a/configs/rec/PP-OCRv3/en_PP-OCRv3_rec.yml b/configs/rec/PP-OCRv3/en_PP-OCRv3_rec.yml index c728e0ac823b0bf835322dcbd0c385c3ac7b2489..51e29e9b419f198e0a6f1edaf5b3aad302455ea3 100644 --- a/configs/rec/PP-OCRv3/en_PP-OCRv3_rec.yml +++ b/configs/rec/PP-OCRv3/en_PP-OCRv3_rec.yml @@ -36,7 +36,7 @@ Optimizer: Architecture: model_type: rec - algorithm: SVTR + algorithm: SVTR_LCNet Transform: Backbone: name: MobileNetV1Enhance diff --git a/configs/rec/PP-OCRv3/multi_language/arabic_PP-OCRv3_rec.yml b/configs/rec/PP-OCRv3/multi_language/arabic_PP-OCRv3_rec.yml index 8c650bd826d127f25c907f97d20d1a52f67f9203..6c78f0af2c1c82b8c78583e8e808010b611111f1 100644 --- a/configs/rec/PP-OCRv3/multi_language/arabic_PP-OCRv3_rec.yml +++ b/configs/rec/PP-OCRv3/multi_language/arabic_PP-OCRv3_rec.yml @@ -36,7 +36,7 @@ Optimizer: Architecture: model_type: rec - algorithm: SVTR + algorithm: SVTR_LCNet Transform: Backbone: name: MobileNetV1Enhance diff --git a/configs/rec/PP-OCRv3/multi_language/chinese_cht_PP-OCRv3_rec.yml b/configs/rec/PP-OCRv3/multi_language/chinese_cht_PP-OCRv3_rec.yml index 28e0c10aa0f83fdf8e621aae04bf2b7374255adc..2c9b61828b644f7238ca8b0d8ad76d31b3a0a5c7 100644 --- a/configs/rec/PP-OCRv3/multi_language/chinese_cht_PP-OCRv3_rec.yml +++ b/configs/rec/PP-OCRv3/multi_language/chinese_cht_PP-OCRv3_rec.yml @@ -36,7 +36,7 @@ Optimizer: Architecture: model_type: rec - algorithm: SVTR + algorithm: SVTR_LCNet Transform: Backbone: name: MobileNetV1Enhance diff --git a/configs/rec/PP-OCRv3/multi_language/cyrillic_PP-OCRv3_rec.yml b/configs/rec/PP-OCRv3/multi_language/cyrillic_PP-OCRv3_rec.yml index fbdbe6c44c689ea267c9995f832305d800046edb..6faffb729bacd71dd2362e61d7f1b9646c30030b 100644 --- a/configs/rec/PP-OCRv3/multi_language/cyrillic_PP-OCRv3_rec.yml +++ b/configs/rec/PP-OCRv3/multi_language/cyrillic_PP-OCRv3_rec.yml @@ -36,7 +36,7 @@ Optimizer: Architecture: model_type: rec - algorithm: SVTR + algorithm: SVTR_LCNet Transform: Backbone: name: MobileNetV1Enhance diff --git a/configs/rec/PP-OCRv3/multi_language/devanagari_PP-OCRv3_rec.yml b/configs/rec/PP-OCRv3/multi_language/devanagari_PP-OCRv3_rec.yml index 48eb38df36f931b76b8e9fb8369daf06ad037d25..488994d2dc0a173bd131d08f396f788c6ea4f320 100644 --- a/configs/rec/PP-OCRv3/multi_language/devanagari_PP-OCRv3_rec.yml +++ b/configs/rec/PP-OCRv3/multi_language/devanagari_PP-OCRv3_rec.yml @@ -36,7 +36,7 @@ Optimizer: Architecture: model_type: rec - algorithm: SVTR + algorithm: SVTR_LCNet Transform: Backbone: name: MobileNetV1Enhance diff --git a/configs/rec/PP-OCRv3/multi_language/japan_PP-OCRv3_rec.yml b/configs/rec/PP-OCRv3/multi_language/japan_PP-OCRv3_rec.yml index 6cab0d447247e28bb58b30384d4f9d032d6ce9d0..b3ddde396f212fc0f058c4b6b2ebc9044a8c90fb 100644 --- a/configs/rec/PP-OCRv3/multi_language/japan_PP-OCRv3_rec.yml +++ b/configs/rec/PP-OCRv3/multi_language/japan_PP-OCRv3_rec.yml @@ -36,7 +36,7 @@ Optimizer: Architecture: model_type: rec - algorithm: SVTR + algorithm: SVTR_LCNet Transform: Backbone: name: MobileNetV1Enhance diff --git a/configs/rec/PP-OCRv3/multi_language/ka_PP-OCRv3_rec.yml b/configs/rec/PP-OCRv3/multi_language/ka_PP-OCRv3_rec.yml index 7a9c8241d1564e5f1295655ba64694a117064bd8..3e845032c2f22f021093e9575c7f36904d42388f 100644 --- a/configs/rec/PP-OCRv3/multi_language/ka_PP-OCRv3_rec.yml +++ b/configs/rec/PP-OCRv3/multi_language/ka_PP-OCRv3_rec.yml @@ -36,7 +36,7 @@ Optimizer: Architecture: model_type: rec - algorithm: SVTR + algorithm: SVTR_LCNet Transform: Backbone: name: MobileNetV1Enhance diff --git a/configs/rec/PP-OCRv3/multi_language/korean_PP-OCRv3_rec.yml b/configs/rec/PP-OCRv3/multi_language/korean_PP-OCRv3_rec.yml index 29ff570772a621ba747e0388bcc0c042db0dba43..17ea15fa8d08216160ef0b2540bb6a0cfc2ca3fa 100644 --- a/configs/rec/PP-OCRv3/multi_language/korean_PP-OCRv3_rec.yml +++ b/configs/rec/PP-OCRv3/multi_language/korean_PP-OCRv3_rec.yml @@ -36,7 +36,7 @@ Optimizer: Architecture: model_type: rec - algorithm: SVTR + algorithm: SVTR_LCNet Transform: Backbone: name: MobileNetV1Enhance diff --git a/configs/rec/PP-OCRv3/multi_language/latin_PP-OCRv3_rec.yml b/configs/rec/PP-OCRv3/multi_language/latin_PP-OCRv3_rec.yml index 1784bfe611366c45230fd2abf69ab16e3a1c3ae9..34ade4af900b07ba81776e839fcdb4b15c99e911 100644 --- a/configs/rec/PP-OCRv3/multi_language/latin_PP-OCRv3_rec.yml +++ b/configs/rec/PP-OCRv3/multi_language/latin_PP-OCRv3_rec.yml @@ -36,7 +36,7 @@ Optimizer: Architecture: model_type: rec - algorithm: SVTR + algorithm: SVTR_LCNet Transform: Backbone: name: MobileNetV1Enhance diff --git a/configs/rec/PP-OCRv3/multi_language/ta_PP-OCRv3_rec.yml b/configs/rec/PP-OCRv3/multi_language/ta_PP-OCRv3_rec.yml index 70b26aa84a2178111edab9f094c369c5d22e31a9..17cc870ec8cf9f878770bef40ebf3b39cb0b8e0e 100644 --- a/configs/rec/PP-OCRv3/multi_language/ta_PP-OCRv3_rec.yml +++ b/configs/rec/PP-OCRv3/multi_language/ta_PP-OCRv3_rec.yml @@ -36,7 +36,7 @@ Optimizer: Architecture: model_type: rec - algorithm: SVTR + algorithm: SVTR_LCNet Transform: Backbone: name: MobileNetV1Enhance diff --git a/configs/rec/PP-OCRv3/multi_language/te_PP-OCRv3_rec.yml b/configs/rec/PP-OCRv3/multi_language/te_PP-OCRv3_rec.yml index 3617af79e3b9c5a55ef22d549465ba2109618e32..5f3dad7cb938dfc462f86e504044e8a6940593c7 100644 --- a/configs/rec/PP-OCRv3/multi_language/te_PP-OCRv3_rec.yml +++ b/configs/rec/PP-OCRv3/multi_language/te_PP-OCRv3_rec.yml @@ -36,7 +36,7 @@ Optimizer: Architecture: model_type: rec - algorithm: SVTR + algorithm: SVTR_LCNet Transform: Backbone: name: MobileNetV1Enhance diff --git a/ppocr/modeling/architectures/__init__.py b/ppocr/modeling/architectures/__init__.py index 1c955ef3abe9c38e816616cc9b5399c6832aa5f1..384ae4cc2bcb8b840b7c35bc30a387e2bb89bb84 100755 --- a/ppocr/modeling/architectures/__init__.py +++ b/ppocr/modeling/architectures/__init__.py @@ -40,7 +40,7 @@ def apply_to_static(model, config, logger): return model assert "image_shape" in config[ "Global"], "image_shape must be assigned for static training mode..." - supported_list = ["DB", "SVTR"] + supported_list = ["DB", "SVTR_LCNet"] if config["Architecture"]["algorithm"] in ["Distillation"]: algo = list(config["Architecture"]["Models"].values())[0]["algorithm"] else: @@ -52,7 +52,7 @@ def apply_to_static(model, config, logger): [None] + config["Global"]["image_shape"], dtype='float32') ] - if algo == "SVTR": + if algo == "SVTR_LCNet": specs.append([ InputSpec( [None, config["Global"]["max_text_length"]], diff --git a/tools/eval.py b/tools/eval.py index 21f4d94d5e4ed560b8775c8827ffdbbd00355218..8b58b660934a1085b36e31d2a1a7cfd26df89685 100755 --- a/tools/eval.py +++ b/tools/eval.py @@ -75,7 +75,8 @@ def main(): model = build_model(config['Architecture']) extra_input_models = [ - "SRN", "NRTR", "SAR", "SEED", "SVTR", "VisionLAN", "RobustScanner" + "SRN", "NRTR", "SAR", "SEED", "SVTR", "SVTR_LCNet", "VisionLAN", + "RobustScanner" ] extra_input = False if config['Architecture']['algorithm'] == 'Distillation': diff --git a/tools/export_model.py b/tools/export_model.py index 9d3bf5629e639bc0a7112090cd18e4cb57bd55d0..cf2a2d8b2f77f7bef947a3190d4ec57d01295542 100755 --- a/tools/export_model.py +++ b/tools/export_model.py @@ -62,17 +62,17 @@ def export_single_model(model, shape=[None], dtype="float32")] ] model = to_static(model, input_spec=other_shape) + elif arch_config["algorithm"] == "SVTR_LCNet": + other_shape = [ + paddle.static.InputSpec( + shape=[None, 3, 48, -1], dtype="float32"), + ] + model = to_static(model, input_spec=other_shape) elif arch_config["algorithm"] == "SVTR": - if arch_config["Head"]["name"] == 'MultiHead': - other_shape = [ - paddle.static.InputSpec( - shape=[None, 3, 48, -1], dtype="float32"), - ] - else: - other_shape = [ - paddle.static.InputSpec( - shape=[None] + input_shape, dtype="float32"), - ] + other_shape = [ + paddle.static.InputSpec( + shape=[None] + input_shape, dtype="float32"), + ] model = to_static(model, input_spec=other_shape) elif arch_config["algorithm"] == "PREN": other_shape = [ diff --git a/tools/program.py b/tools/program.py index 9134472b83f785bbabcf44dc0675c0e0d9a08fc9..f36620a9b311c8dc2ac4842013e46797fa504e1c 100755 --- a/tools/program.py +++ b/tools/program.py @@ -219,7 +219,7 @@ def train(config, use_srn = config['Architecture']['algorithm'] == "SRN" extra_input_models = [ - "SRN", "NRTR", "SAR", "SEED", "SVTR", "SPIN", "VisionLAN", + "SRN", "NRTR", "SAR", "SEED", "SVTR", "SVTR_LCNet", "SPIN", "VisionLAN", "RobustScanner", "RFL", 'DRRG', 'SATRN' ] extra_input = False @@ -641,9 +641,9 @@ def preprocess(is_train=False): 'EAST', 'DB', 'SAST', 'Rosetta', 'CRNN', 'STARNet', 'RARE', 'SRN', 'CLS', 'PGNet', 'Distillation', 'NRTR', 'TableAttn', 'SAR', 'PSE', 'SEED', 'SDMGR', 'LayoutXLM', 'LayoutLM', 'LayoutLMv2', 'PREN', 'FCE', - 'SVTR', 'ViTSTR', 'ABINet', 'DB++', 'TableMaster', 'SPIN', 'VisionLAN', - 'Gestalt', 'SLANet', 'RobustScanner', 'CT', 'RFL', 'DRRG', 'CAN', - 'Telescope', 'SATRN' + 'SVTR', 'SVTR_LCNet', 'ViTSTR', 'ABINet', 'DB++', 'TableMaster', 'SPIN', + 'VisionLAN', 'Gestalt', 'SLANet', 'RobustScanner', 'CT', 'RFL', 'DRRG', + 'CAN', 'Telescope', 'SATRN' ] if use_xpu: