diff --git a/ppocr/data/imaug/rec_img_aug.py b/ppocr/data/imaug/rec_img_aug.py index 960a11be16a9090d80b5c5a27069246d1bcaa3e7..fa3bc5ad5c7c98a48b282433a23613d276ba3b1c 100644 --- a/ppocr/data/imaug/rec_img_aug.py +++ b/ppocr/data/imaug/rec_img_aug.py @@ -81,7 +81,7 @@ class ClsResizeImg(object): def __call__(self, data): img = data['image'] - norm_img = resize_norm_img(img, self.image_shape) + norm_img, _ = resize_norm_img(img, self.image_shape) data['image'] = norm_img return data diff --git a/tools/eval.py b/tools/eval.py index 1038090ab4e8da139ff180feec2495a5f401fc54..7fd4fa7ada7b1550bcca8766f5acb9b4d4ed2049 100755 --- a/tools/eval.py +++ b/tools/eval.py @@ -74,9 +74,11 @@ def main(): model = build_model(config['Architecture']) extra_input_models = ["SRN", "NRTR", "SAR", "SEED", "SVTR"] + extra_input = False if config['Architecture']['algorithm'] == 'Distillation': - extra_input = config['Architecture']['Models']['Teacher'][ - 'algorithm'] in extra_input_models + for key in config['Architecture']["Models"]: + extra_input = extra_input or config['Architecture']['Models'][key][ + 'algorithm'] in extra_input_models else: extra_input = config['Architecture']['algorithm'] in extra_input_models if "model_type" in config['Architecture'].keys(): diff --git a/tools/program.py b/tools/program.py index 1742f6c9557929accd52a1748add68f0e569a6b9..90fd309ae9e1ae23723d8e67c62a905e79a073d3 100755 --- a/tools/program.py +++ b/tools/program.py @@ -202,9 +202,11 @@ def train(config, use_srn = config['Architecture']['algorithm'] == "SRN" extra_input_models = ["SRN", "NRTR", "SAR", "SEED", "SVTR"] + extra_input = False if config['Architecture']['algorithm'] == 'Distillation': - extra_input = config['Architecture']['Models']['Teacher'][ - 'algorithm'] in extra_input_models + for key in config['Architecture']["Models"]: + extra_input = extra_input or config['Architecture']['Models'][key][ + 'algorithm'] in extra_input_models else: extra_input = config['Architecture']['algorithm'] in extra_input_models try: