未验证 提交 6439a4df 编写于 作者: X xiaoting 提交者: GitHub

Merge pull request #4218 from tink2123/fix_seed

Fix some typo for SEED
...@@ -37,7 +37,7 @@ Optimizer: ...@@ -37,7 +37,7 @@ Optimizer:
Architecture: Architecture:
model_type: rec model_type: rec
algorithm: seed algorithm: SEED
Transform: Transform:
name: STN_ON name: STN_ON
tps_inputsize: [32, 64] tps_inputsize: [32, 64]
......
...@@ -28,9 +28,10 @@ def build_backbone(config, model_type): ...@@ -28,9 +28,10 @@ def build_backbone(config, model_type):
from .rec_mv1_enhance import MobileNetV1Enhance from .rec_mv1_enhance import MobileNetV1Enhance
from .rec_nrtr_mtb import MTB from .rec_nrtr_mtb import MTB
from .rec_resnet_31 import ResNet31 from .rec_resnet_31 import ResNet31
from .rec_resnet_aster import ResNet_ASTER
support_dict = [ support_dict = [
'MobileNetV1Enhance', 'MobileNetV3', 'ResNet', 'ResNetFPN', 'MTB', 'MobileNetV1Enhance', 'MobileNetV3', 'ResNet', 'ResNetFPN', 'MTB',
"ResNet31" "ResNet31", "ResNet_ASTER"
] ]
elif model_type == "e2e": elif model_type == "e2e":
from .e2e_resnet_vd_pg import ResNet from .e2e_resnet_vd_pg import ResNet
...@@ -39,9 +40,6 @@ def build_backbone(config, model_type): ...@@ -39,9 +40,6 @@ def build_backbone(config, model_type):
from .table_resnet_vd import ResNet from .table_resnet_vd import ResNet
from .table_mobilenet_v3 import MobileNetV3 from .table_mobilenet_v3 import MobileNetV3
support_dict = ["ResNet", "MobileNetV3"] support_dict = ["ResNet", "MobileNetV3"]
elif model_type == "seed":
from .rec_resnet_aster import ResNet_ASTER
support_dict = ["ResNet_ASTER"]
else: else:
raise NotImplementedError raise NotImplementedError
......
...@@ -414,8 +414,7 @@ def preprocess(is_train=False): ...@@ -414,8 +414,7 @@ def preprocess(is_train=False):
assert alg in [ assert alg in [
'EAST', 'DB', 'SAST', 'Rosetta', 'CRNN', 'STARNet', 'RARE', 'SRN', 'EAST', 'DB', 'SAST', 'Rosetta', 'CRNN', 'STARNet', 'RARE', 'SRN',
'CLS', 'PGNet', 'Distillation', 'NRTR', 'TableAttn', 'SAR', 'PSE', 'CLS', 'PGNet', 'Distillation', 'NRTR', 'TableAttn', 'SAR', 'PSE',
'ASTER' 'SEED']
]
windows_not_support_list = ['PSE'] windows_not_support_list = ['PSE']
if platform.system() == "Windows" and alg in windows_not_support_list: if platform.system() == "Windows" and alg in windows_not_support_list:
logger.warning('{} is not support in Windows now'.format( logger.warning('{} is not support in Windows now'.format(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册