提交 0c280c70 编写于 作者: A andyjpaddle

fix key for dis and cls resize

上级 7756e66b
...@@ -81,7 +81,7 @@ class ClsResizeImg(object): ...@@ -81,7 +81,7 @@ class ClsResizeImg(object):
def __call__(self, data): def __call__(self, data):
img = data['image'] img = data['image']
norm_img = resize_norm_img(img, self.image_shape) norm_img, _ = resize_norm_img(img, self.image_shape)
data['image'] = norm_img data['image'] = norm_img
return data return data
......
...@@ -74,9 +74,11 @@ def main(): ...@@ -74,9 +74,11 @@ def main():
model = build_model(config['Architecture']) model = build_model(config['Architecture'])
extra_input_models = ["SRN", "NRTR", "SAR", "SEED", "SVTR"] extra_input_models = ["SRN", "NRTR", "SAR", "SEED", "SVTR"]
extra_input = False
if config['Architecture']['algorithm'] == 'Distillation': if config['Architecture']['algorithm'] == 'Distillation':
extra_input = config['Architecture']['Models']['Teacher'][ for key in config['Architecture']["Models"]:
'algorithm'] in extra_input_models extra_input = extra_input or config['Architecture']['Models'][key][
'algorithm'] in extra_input_models
else: else:
extra_input = config['Architecture']['algorithm'] in extra_input_models extra_input = config['Architecture']['algorithm'] in extra_input_models
if "model_type" in config['Architecture'].keys(): if "model_type" in config['Architecture'].keys():
......
...@@ -202,9 +202,11 @@ def train(config, ...@@ -202,9 +202,11 @@ def train(config,
use_srn = config['Architecture']['algorithm'] == "SRN" use_srn = config['Architecture']['algorithm'] == "SRN"
extra_input_models = ["SRN", "NRTR", "SAR", "SEED", "SVTR"] extra_input_models = ["SRN", "NRTR", "SAR", "SEED", "SVTR"]
extra_input = False
if config['Architecture']['algorithm'] == 'Distillation': if config['Architecture']['algorithm'] == 'Distillation':
extra_input = config['Architecture']['Models']['Teacher'][ for key in config['Architecture']["Models"]:
'algorithm'] in extra_input_models extra_input = extra_input or config['Architecture']['Models'][key][
'algorithm'] in extra_input_models
else: else:
extra_input = config['Architecture']['algorithm'] in extra_input_models extra_input = config['Architecture']['algorithm'] in extra_input_models
try: try:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册