From 6aa35c18ae1e7696fbb94a255d8133ace8234ea0 Mon Sep 17 00:00:00 2001 From: smilelite Date: Sun, 10 Jul 2022 12:31:27 +0800 Subject: [PATCH] modified head --- ppocr/modeling/heads/__init__.py | 2 +- ppocr/modeling/heads/rec_robustscanner_head.py | 2 +- tools/export_model.py | 4 ++-- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/ppocr/modeling/heads/__init__.py b/ppocr/modeling/heads/__init__.py index 99cb59e6..9a1b7657 100755 --- a/ppocr/modeling/heads/__init__.py +++ b/ppocr/modeling/heads/__init__.py @@ -33,8 +33,8 @@ def build_head(config): from .rec_aster_head import AsterHead from .rec_pren_head import PRENHead from .rec_multi_head import MultiHead - from .rec_robustscanner_head import RobustScannerHead from .rec_abinet_head import ABINetHead + from .rec_robustscanner_head import RobustScannerHead # cls head from .cls_head import ClsHead diff --git a/ppocr/modeling/heads/rec_robustscanner_head.py b/ppocr/modeling/heads/rec_robustscanner_head.py index fc889d59..b9f8962d 100644 --- a/ppocr/modeling/heads/rec_robustscanner_head.py +++ b/ppocr/modeling/heads/rec_robustscanner_head.py @@ -1,4 +1,4 @@ -# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve. +# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tools/export_model.py b/tools/export_model.py index 11794f74..a9f4a62e 100755 --- a/tools/export_model.py +++ b/tools/export_model.py @@ -79,7 +79,7 @@ def export_single_model(model, ] model = to_static(model, input_spec=other_shape) elif arch_config["algorithm"] == "RobustScanner": - max_seq_len = arch_config["Head"]["max_seq_len"] + max_text_length = arch_config["Head"]["max_text_length"] other_shape = [ paddle.static.InputSpec( shape=[None, 3, 48, 160], dtype="float32"), @@ -89,7 +89,7 @@ def export_single_model(model, shape=[None, ], dtype="float32"), paddle.static.InputSpec( - shape=[None, max_seq_len], + shape=[None, max_text_length], dtype="int64") ] ] -- GitLab