提交 5db482cd 编写于 作者: A andyjpaddle

fix det slim export

上级 2d14b487
...@@ -18,6 +18,7 @@ Global: ...@@ -18,6 +18,7 @@ Global:
Architecture: Architecture:
name: DistillationModel name: DistillationModel
algorithm: Distillation algorithm: Distillation
model_type: det
Models: Models:
Teacher: Teacher:
freeze_params: true freeze_params: true
......
...@@ -111,7 +111,7 @@ def main(): ...@@ -111,7 +111,7 @@ def main():
valid_dataloader = build_dataloader(config, 'Eval', device, logger) valid_dataloader = build_dataloader(config, 'Eval', device, logger)
use_srn = config['Architecture']['algorithm'] == "SRN" use_srn = config['Architecture']['algorithm'] == "SRN"
model_type = None model_type = config['Architecture'].get('model_type', None)
# start eval # start eval
metric = program.eval(model, valid_dataloader, post_process_class, metric = program.eval(model, valid_dataloader, post_process_class,
eval_class, model_type, use_srn) eval_class, model_type, use_srn)
...@@ -120,7 +120,7 @@ def main(): ...@@ -120,7 +120,7 @@ def main():
for k, v in metric.items(): for k, v in metric.items():
logger.info('{}:{}'.format(k, v)) logger.info('{}:{}'.format(k, v))
infer_shape = [3, 32, 100] if model_type != "det" else [3, 640, 640] infer_shape = [3, 32, 100] if model_type == "rec" else [3, 640, 640]
save_path = config["Global"]["save_inference_dir"] save_path = config["Global"]["save_inference_dir"]
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册