From 490886aee7926b064e186c3609c1a83b73cef9e9 Mon Sep 17 00:00:00 2001 From: wangxinxin08 <69842442+wangxinxin08@users.noreply.github.com> Date: Sat, 19 Dec 2020 00:10:45 +0800 Subject: [PATCH] fix problems in export_model, test=dygraph (#1934) --- dygraph/tools/export_model.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/dygraph/tools/export_model.py b/dygraph/tools/export_model.py index 26c0871db..06ac2006b 100644 --- a/dygraph/tools/export_model.py +++ b/dygraph/tools/export_model.py @@ -57,8 +57,10 @@ def parse_args(): def dygraph_to_static(model, save_dir, cfg): if not os.path.exists(save_dir): os.makedirs(save_dir) - inputs_def = cfg['TestReader']['inputs_def'] - image_shape = inputs_def.get('image_shape') + image_shape = None + if 'inputs_def' in cfg['TestReader']: + inputs_def = cfg['TestReader']['inputs_def'] + image_shape = inputs_def.get('image_shape', None) if image_shape is None: image_shape = [3, None, None] # Save infer cfg @@ -102,7 +104,7 @@ def main(): cfg = load_config(FLAGS.config) # TODO: to be refined in the future - if cfg.norm_type == 'sync_bn': + if 'norm_type' in cfg and cfg['norm_type'] == 'sync_bn': FLAGS.opt['norm_type'] = 'bn' merge_config(FLAGS.opt) check_config(cfg) -- GitLab