diff --git a/configs/rec/rec_resnet_stn_bilstm_att.yml b/configs/rec/rec_resnet_stn_bilstm_att.yml index b18bb685739597ee2667008f3549c915b6ad3060..1f6e534a6878a7ae84fc7fa7e1d975077f164d80 100644 --- a/configs/rec/rec_resnet_stn_bilstm_att.yml +++ b/configs/rec/rec_resnet_stn_bilstm_att.yml @@ -37,7 +37,7 @@ Optimizer: Architecture: model_type: rec - algorithm: seed + algorithm: SEED Transform: name: STN_ON tps_inputsize: [32, 64] diff --git a/ppocr/modeling/backbones/__init__.py b/ppocr/modeling/backbones/__init__.py index d9815021c9a9a1ee40106c5e323bf346c3b4376d..169eb821f110d4a212068ebab4d46d636e241307 100755 --- a/ppocr/modeling/backbones/__init__.py +++ b/ppocr/modeling/backbones/__init__.py @@ -28,9 +28,10 @@ def build_backbone(config, model_type): from .rec_mv1_enhance import MobileNetV1Enhance from .rec_nrtr_mtb import MTB from .rec_resnet_31 import ResNet31 + from .rec_resnet_aster import ResNet_ASTER support_dict = [ 'MobileNetV1Enhance', 'MobileNetV3', 'ResNet', 'ResNetFPN', 'MTB', - "ResNet31" + "ResNet31", "ResNet_ASTER" ] elif model_type == "e2e": from .e2e_resnet_vd_pg import ResNet @@ -39,9 +40,6 @@ def build_backbone(config, model_type): from .table_resnet_vd import ResNet from .table_mobilenet_v3 import MobileNetV3 support_dict = ["ResNet", "MobileNetV3"] - elif model_type == "seed": - from .rec_resnet_aster import ResNet_ASTER - support_dict = ["ResNet_ASTER"] else: raise NotImplementedError diff --git a/tools/program.py b/tools/program.py index 8750dd9adcd51889bc1737985cad9f6fc2f8f4b3..4df87c16868260f7e09979b4dcfa76bccef72a79 100755 --- a/tools/program.py +++ b/tools/program.py @@ -402,8 +402,7 @@ def preprocess(is_train=False): assert alg in [ 'EAST', 'DB', 'SAST', 'Rosetta', 'CRNN', 'STARNet', 'RARE', 'SRN', 'CLS', 'PGNet', 'Distillation', 'NRTR', 'TableAttn', 'SAR', 'PSE', - 'ASTER' - ] + 'SEED'] device = 'gpu:{}'.format(dist.ParallelEnv().dev_id) if use_gpu else 'cpu' device = paddle.set_device(device)