diff --git a/demo/prune/train.py b/demo/prune/train.py index 574f515803c7a4485bae0cf11bfc95babea9eec8..da2c7eb5ae279fd3096565b208a3c49fe95c32a0 100644 --- a/demo/prune/train.py +++ b/demo/prune/train.py @@ -239,7 +239,7 @@ def compress(args): for i in range(args.num_epochs): train(i, train_program) - if i % args.test_period == 0: + if (i + 1) % args.test_period == 0: test(i, pruned_val_program) save_model(exe, pruned_val_program, os.path.join(args.model_path, str(i)))