未验证 提交 58a408ab 编写于 作者: A andyjpaddle 提交者: GitHub

Merge pull request #6069 from andyjpaddle/dygraph

fix key for distillation and cls resize
...@@ -81,7 +81,7 @@ class ClsResizeImg(object): ...@@ -81,7 +81,7 @@ class ClsResizeImg(object):
def __call__(self, data): def __call__(self, data):
img = data['image'] img = data['image']
norm_img = resize_norm_img(img, self.image_shape) norm_img, _ = resize_norm_img(img, self.image_shape)
data['image'] = norm_img data['image'] = norm_img
return data return data
......
...@@ -74,8 +74,10 @@ def main(): ...@@ -74,8 +74,10 @@ def main():
model = build_model(config['Architecture']) model = build_model(config['Architecture'])
extra_input_models = ["SRN", "NRTR", "SAR", "SEED", "SVTR"] extra_input_models = ["SRN", "NRTR", "SAR", "SEED", "SVTR"]
extra_input = False
if config['Architecture']['algorithm'] == 'Distillation': if config['Architecture']['algorithm'] == 'Distillation':
extra_input = config['Architecture']['Models']['Teacher'][ for key in config['Architecture']["Models"]:
extra_input = extra_input or config['Architecture']['Models'][key][
'algorithm'] in extra_input_models 'algorithm'] in extra_input_models
else: else:
extra_input = config['Architecture']['algorithm'] in extra_input_models extra_input = config['Architecture']['algorithm'] in extra_input_models
......
...@@ -202,8 +202,10 @@ def train(config, ...@@ -202,8 +202,10 @@ def train(config,
use_srn = config['Architecture']['algorithm'] == "SRN" use_srn = config['Architecture']['algorithm'] == "SRN"
extra_input_models = ["SRN", "NRTR", "SAR", "SEED", "SVTR"] extra_input_models = ["SRN", "NRTR", "SAR", "SEED", "SVTR"]
extra_input = False
if config['Architecture']['algorithm'] == 'Distillation': if config['Architecture']['algorithm'] == 'Distillation':
extra_input = config['Architecture']['Models']['Teacher'][ for key in config['Architecture']["Models"]:
extra_input = extra_input or config['Architecture']['Models'][key][
'algorithm'] in extra_input_models 'algorithm'] in extra_input_models
else: else:
extra_input = config['Architecture']['algorithm'] in extra_input_models extra_input = config['Architecture']['algorithm'] in extra_input_models
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册