未验证 提交 41e40edb 编写于 作者: B Bai Yifan 提交者: GitHub

save inference model in distillation demo (#181)

上级 6a359a7a
......@@ -25,6 +25,7 @@ add_arg = functools.partial(add_arguments, argparser=parser)
# yapf: disable
add_arg('batch_size', int, 64, "Minibatch size.")
add_arg('use_gpu', bool, True, "Whether to use GPU or not.")
add_arg('save_inference', bool, False, "Whether to save inference model.")
add_arg('total_images', int, 1281167, "Training image number.")
add_arg('image_shape', str, "3,224,224", "Input image size")
add_arg('lr', float, 0.1, "The learning rate used to fine-tune pruned model.")
......@@ -214,6 +215,10 @@ def compress(args):
"valid_epoch {} step {} loss {:.6f}, top1 {:.6f}, top5 {:.6f}".
format(epoch_id, step_id, val_loss[0], val_acc1[0],
val_acc5[0]))
if args.save_inference:
fluid.io.save_inference_model(
os.path.join("./saved_models", str(epoch_id)), ["image"],
[out], exe, student_program)
_logger.info("epoch {} top1 {:.6f}, top5 {:.6f}".format(
epoch_id, np.mean(val_acc1s), np.mean(val_acc5s)))
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册