未验证 提交 c8d6ba6b 编写于 作者: W wangxinxin08 提交者: GitHub

[Dygraph]fix problems in export_model.py (#1932)

* fix problems in export_model.py

* modify code
上级 f3caf39c
...@@ -57,8 +57,10 @@ def parse_args(): ...@@ -57,8 +57,10 @@ def parse_args():
def dygraph_to_static(model, save_dir, cfg): def dygraph_to_static(model, save_dir, cfg):
if not os.path.exists(save_dir): if not os.path.exists(save_dir):
os.makedirs(save_dir) os.makedirs(save_dir)
image_shape = None
if 'inputs_def' in cfg['TestReader']:
inputs_def = cfg['TestReader']['inputs_def'] inputs_def = cfg['TestReader']['inputs_def']
image_shape = inputs_def.get('image_shape') image_shape = inputs_def.get('image_shape', None)
if image_shape is None: if image_shape is None:
image_shape = [3, None, None] image_shape = [3, None, None]
# Save infer cfg # Save infer cfg
...@@ -102,7 +104,7 @@ def main(): ...@@ -102,7 +104,7 @@ def main():
cfg = load_config(FLAGS.config) cfg = load_config(FLAGS.config)
# TODO: to be refined in the future # TODO: to be refined in the future
if cfg.norm_type == 'sync_bn': if 'norm_type' in cfg and cfg['norm_type'] == 'sync_bn':
FLAGS.opt['norm_type'] = 'bn' FLAGS.opt['norm_type'] = 'bn'
merge_config(FLAGS.opt) merge_config(FLAGS.opt)
check_config(cfg) check_config(cfg)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册