diff --git a/demo/prune/train.py b/demo/prune/train.py
index a91084aa91ed2058234bb74efdf224a64b99a180..231cda9cc183ede5f83bf4f51033d8d321fd1e1a 100644
--- a/demo/prune/train.py
+++ b/demo/prune/train.py
@@ -38,6 +38,7 @@ add_arg('test_period', int, 10, "Test period in epoches.")
 add_arg('model_path', str, "./models", "The path to save model.")
 add_arg('pruned_ratio', float, None, "The ratios to be pruned.")
 add_arg('criterion', str, "l1_norm", "The prune criterion to be used, support l1_norm and batch_norm_scale.")
+add_arg('save_inference', bool, False, "Whether to save inference model.")
 # yapf: enable

 model_list = models.__all__
@@ -230,6 +231,13 @@ def compress(args):
             test(i, pruned_val_program)
             save_model(exe, pruned_val_program,
                        os.path.join(args.model_path, str(i)))
+        if args.save_inference:
+            infer_model_path = os.path.join(args.model_path, "infer_models",
+                                            str(i))
+            fluid.io.save_inference_model(infer_model_path, ["image"], [out],
+                                          exe, pruned_val_program)
+            _logger.info("Saved inference model into [{}]".format(
+                infer_model_path))


 def main():
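
Below is a minimal sketch (not part of the diff) of how a model saved via the new --save_inference flag could be loaded and run with the fluid 1.x inference API. The epoch directory "./models/infer_models/0" and the (1, 3, 224, 224) input shape are illustrative assumptions; they should match the model_path argument and the input size actually used for training.

import numpy as np
import paddle.fluid as fluid

place = fluid.CPUPlace()
exe = fluid.Executor(place)

# Hypothetical path for epoch 0, following the layout written by train.py above.
infer_model_path = "./models/infer_models/0"

# Load the saved inference program along with its feed names and fetch targets.
[inference_program, feed_target_names,
 fetch_targets] = fluid.io.load_inference_model(infer_model_path, exe)

# Feed one dummy image batch; the shape is an assumed example, not prescribed by the diff.
image = np.random.random((1, 3, 224, 224)).astype("float32")
results = exe.run(inference_program,
                  feed={feed_target_names[0]: image},
                  fetch_list=fetch_targets)
print(results[0].shape)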