未验证 提交 96f645a5 编写于 作者: W whs 提交者: GitHub

Add option for saving inference model in pruning demo. (#196)

上级 e4c6ae55
...@@ -38,6 +38,7 @@ add_arg('test_period', int, 10, "Test period in epoches.") ...@@ -38,6 +38,7 @@ add_arg('test_period', int, 10, "Test period in epoches.")
add_arg('model_path', str, "./models", "The path to save model.") add_arg('model_path', str, "./models", "The path to save model.")
add_arg('pruned_ratio', float, None, "The ratios to be pruned.") add_arg('pruned_ratio', float, None, "The ratios to be pruned.")
add_arg('criterion', str, "l1_norm", "The prune criterion to be used, support l1_norm and batch_norm_scale.") add_arg('criterion', str, "l1_norm", "The prune criterion to be used, support l1_norm and batch_norm_scale.")
add_arg('save_inference', bool, False, "Whether to save inference model.")
# yapf: enable # yapf: enable
model_list = models.__all__ model_list = models.__all__
...@@ -230,6 +231,13 @@ def compress(args): ...@@ -230,6 +231,13 @@ def compress(args):
test(i, pruned_val_program) test(i, pruned_val_program)
save_model(exe, pruned_val_program, save_model(exe, pruned_val_program,
os.path.join(args.model_path, str(i))) os.path.join(args.model_path, str(i)))
if args.save_inference:
infer_model_path = os.path.join(args.model_path, "infer_models",
str(i))
fluid.io.save_inference_model(infer_model_path, ["image"], [out],
exe, pruned_val_program)
_logger.info("Saved inference model into [{}]".format(
infer_model_path))
def main(): def main():
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册