提交 0c280c70 编写于 作者: A andyjpaddle

fix key for dis and cls resize

上级 7756e66b
......@@ -81,7 +81,7 @@ class ClsResizeImg(object):
def __call__(self, data):
img = data['image']
norm_img = resize_norm_img(img, self.image_shape)
norm_img, _ = resize_norm_img(img, self.image_shape)
data['image'] = norm_img
return data
......
......@@ -74,9 +74,11 @@ def main():
model = build_model(config['Architecture'])
extra_input_models = ["SRN", "NRTR", "SAR", "SEED", "SVTR"]
extra_input = False
if config['Architecture']['algorithm'] == 'Distillation':
extra_input = config['Architecture']['Models']['Teacher'][
'algorithm'] in extra_input_models
for key in config['Architecture']["Models"]:
extra_input = extra_input or config['Architecture']['Models'][key][
'algorithm'] in extra_input_models
else:
extra_input = config['Architecture']['algorithm'] in extra_input_models
if "model_type" in config['Architecture'].keys():
......
......@@ -202,9 +202,11 @@ def train(config,
use_srn = config['Architecture']['algorithm'] == "SRN"
extra_input_models = ["SRN", "NRTR", "SAR", "SEED", "SVTR"]
extra_input = False
if config['Architecture']['algorithm'] == 'Distillation':
extra_input = config['Architecture']['Models']['Teacher'][
'algorithm'] in extra_input_models
for key in config['Architecture']["Models"]:
extra_input = extra_input or config['Architecture']['Models'][key][
'algorithm'] in extra_input_models
else:
extra_input = config['Architecture']['algorithm'] in extra_input_models
try:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册