diff --git a/dygraph/tools/export_model.py b/dygraph/tools/export_model.py index 26c0871dbbf4565ef1f18d857ce58e2edb5e0eff..06ac2006b22e25cf4032b3fe8678fd5812314c46 100644 --- a/dygraph/tools/export_model.py +++ b/dygraph/tools/export_model.py @@ -57,8 +57,10 @@ def parse_args(): def dygraph_to_static(model, save_dir, cfg): if not os.path.exists(save_dir): os.makedirs(save_dir) - inputs_def = cfg['TestReader']['inputs_def'] - image_shape = inputs_def.get('image_shape') + image_shape = None + if 'inputs_def' in cfg['TestReader']: + inputs_def = cfg['TestReader']['inputs_def'] + image_shape = inputs_def.get('image_shape', None) if image_shape is None: image_shape = [3, None, None] # Save infer cfg @@ -102,7 +104,7 @@ def main(): cfg = load_config(FLAGS.config) # TODO: to be refined in the future - if cfg.norm_type == 'sync_bn': + if 'norm_type' in cfg and cfg['norm_type'] == 'sync_bn': FLAGS.opt['norm_type'] = 'bn' merge_config(FLAGS.opt) check_config(cfg)