diff --git a/tools/export_model.py b/tools/export_model.py index e646b7cc0937a2f61458721045a1ce55404d6a56..4557542a446bd11e1c5ea2f505b38e3111989324 100644 --- a/tools/export_model.py +++ b/tools/export_model.py @@ -78,6 +78,15 @@ def parse_reader(reader_cfg, metric, arch): params['image_shape'] = image_shape[1:] if 'target_dim' in params: params.pop('target_dim') + if p['type'] == 'ResizeAndPad': + assert has_shape_def, "missing input shape" + p['type'] = 'Resize' + p['target_size'] = params['target_dim'] + p['max_size'] = params['target_dim'] + p['interp'] = params['interp'] + p['image_shape'] = image_shape[1:] + preprocess_list.append(p) + continue p.update(params) preprocess_list.append(p) batch_transforms = reader_cfg.get('batch_transforms', None) @@ -116,9 +125,9 @@ def dump_infer_config(FLAGS, config): 'Face': 3, 'TTFNet': 3, 'FCOS': 3, - 'EfficientDet': 40 } infer_arch = config['architecture'] + infer_arch = 'RetinaNet' if infer_arch == 'EfficientDet' else infer_arch for arch, min_subgraph_size in trt_min_subgraph.items(): if arch in infer_arch: