From 2d385d258381392aea906f35c1dc00ea02e654b6 Mon Sep 17 00:00:00 2001 From: Bai Yifan Date: Tue, 17 Mar 2020 10:18:39 +0800 Subject: [PATCH] save inference model in slim/distillation (#339) --- slim/distillation/distill.py | 19 ++++++++++++++++++- 1 file changed, 18 insertions(+), 1 deletion(-) diff --git a/slim/distillation/distill.py b/slim/distillation/distill.py index 134e1f37e..a2b901223 100644 --- a/slim/distillation/distill.py +++ b/slim/distillation/distill.py @@ -335,6 +335,12 @@ def main(): checkpoint.save(exe, fluid.default_main_program(), os.path.join(save_dir, save_name)) + if FLAGS.save_inference: + feeded_var_names = ['image', 'im_size'] + targets = list(fetches.values()) + fluid.io.save_inference_model(save_dir + '/infer', + feeded_var_names, targets, exe, + eval_prog) # eval results = eval_run(exe, compiled_eval_prog, eval_loader, eval_keys, eval_values, eval_cls) @@ -349,7 +355,13 @@ def main(): best_box_ap_list[1] = step_id checkpoint.save(exe, fluid.default_main_program(), - os.path.join("./", "best_model")) + os.path.join(save_dir, "best_model")) + if FLAGS.save_inference: + feeded_var_names = ['image', 'im_size'] + targets = list(fetches.values()) + fluid.io.save_inference_model(save_dir + '/infer', + feeded_var_names, targets, + exe, eval_prog) logger.info("Best test box ap: {}, in step: {}".format( best_box_ap_list[0], best_box_ap_list[1])) train_loader.reset() @@ -379,5 +391,10 @@ if __name__ == '__main__': default=None, type=str, help="Evaluation directory, default is current directory.") + parser.add_argument( + "--save_inference", + default=False, + type=bool, + help="Whether to save inference model.") FLAGS = parser.parse_args() main() -- GitLab