diff --git a/demo/prune/train.py b/demo/prune/train.py index 44082c9563f146a523a08184a51e8008cb08d8d8..b85b7c7e3beb1528134e92630fb4b06be7e2956a 100644 --- a/demo/prune/train.py +++ b/demo/prune/train.py @@ -81,7 +81,7 @@ def piecewise_decay(args): def cosine_decay(args): step = int(math.ceil(float(args.total_images) / args.batch_size)) learning_rate = paddle.optimizer.lr.CosineAnnealingDecay( - learning_rate=args.lr, T_max=args.num_epochs) + learning_rate=args.lr, T_max=args.num_epochs * step) optimizer = paddle.optimizer.Momentum( learning_rate=learning_rate, momentum=args.momentum_rate,