Unverified · Commit 9909d790 · Authored by yukavio, committed by GitHub

fix lr schedule in prune demo (#595)

* fix prune demo batchsize

* fix lr schedule in prune demo
Co-authored-by: wanghaoshuang <wanghaoshuang@baidu.com>
Parent 143087b8
@@ -29,7 +29,6 @@ add_arg('lr_strategy', str, "piecewise_decay", "The learning rate decay
 add_arg('l2_decay', float, 3e-5, "The l2_decay parameter.")
 add_arg('momentum_rate', float, 0.9, "The value of momentum_rate.")
 add_arg('num_epochs', int, 120, "The number of total epochs.")
-add_arg('total_images', int, 1281167, "The number of total training images.")
 parser.add_argument('--step_epochs', nargs='+', type=int, default=[30, 60, 90], help="piecewise decay step")
 add_arg('config_file', str, None, "The config file for compression with yaml format.")
 add_arg('data', str, "mnist", "Which data to use. 'mnist' or 'imagenet'")
@@ -65,9 +64,8 @@ def get_pruned_params(args, program):
     return params


-def piecewise_decay(args):
-    step = int(math.ceil(float(args.total_images) / args.batch_size))
-    bd = [step * e for e in args.step_epochs]
+def piecewise_decay(args, step_per_epoch):
+    bd = [step_per_epoch * e for e in args.step_epochs]
     lr = [args.lr * (0.1**i) for i in range(len(bd) + 1)]
     learning_rate = paddle.optimizer.lr.PiecewiseDecay(boundaries=bd, values=lr)
@@ -75,25 +73,24 @@ def piecewise_decay(args):
         learning_rate=learning_rate,
         momentum=args.momentum_rate,
         weight_decay=paddle.regularizer.L2Decay(args.l2_decay))
-    return optimizer
+    return optimizer, learning_rate


-def cosine_decay(args):
-    step = int(math.ceil(float(args.total_images) / args.batch_size))
+def cosine_decay(args, step_per_epoch):
     learning_rate = paddle.optimizer.lr.CosineAnnealingDecay(
-        learning_rate=args.lr, T_max=args.num_epochs * step)
+        learning_rate=args.lr, T_max=args.num_epochs * step_per_epoch)
     optimizer = paddle.optimizer.Momentum(
         learning_rate=learning_rate,
         momentum=args.momentum_rate,
         weight_decay=paddle.regularizer.L2Decay(args.l2_decay))
-    return optimizer
+    return optimizer, learning_rate


-def create_optimizer(args):
+def create_optimizer(args, step_per_epoch):
     if args.lr_strategy == "piecewise_decay":
-        return piecewise_decay(args)
+        return piecewise_decay(args, step_per_epoch)
     elif args.lr_strategy == "cosine_decay":
-        return cosine_decay(args)
+        return cosine_decay(args, step_per_epoch)


 def compress(args):
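For context outside the patch: both schedules now take the per-epoch step count as an argument instead of deriving it from the removed hard-coded `--total_images` flag. Below is a minimal sketch of the two schedulers in isolation; the values of `lr`, `step_per_epoch`, `step_epochs`, and `num_epochs` are illustrative assumptions, not taken from the patch.

```python
import paddle

# Illustrative values only; the demo reads these from args.
lr, step_per_epoch, num_epochs = 0.1, 391, 120
step_epochs = [30, 60, 90]

# Piecewise decay: boundaries are optimizer steps, so the epoch
# indices are scaled by the number of steps in one epoch.
bd = [step_per_epoch * e for e in step_epochs]
values = [lr * (0.1**i) for i in range(len(bd) + 1)]
piecewise = paddle.optimizer.lr.PiecewiseDecay(boundaries=bd, values=values)

# Cosine decay: T_max is the total step count over all epochs.
cosine = paddle.optimizer.lr.CosineAnnealingDecay(
    learning_rate=lr, T_max=num_epochs * step_per_epoch)
```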
@@ -118,34 +115,13 @@ def compress(args):
     image_shape = [int(m) for m in image_shape.split(",")]
     assert args.model in model_list, "{} is not in lists: {}".format(args.model,
                                                                      model_list)
-    image = paddle.static.data(
-        name='image', shape=[None] + image_shape, dtype='float32')
-    label = paddle.static.data(name='label', shape=[None, 1], dtype='int64')
-    # model definition
-    model = models.__dict__[args.model]()
-    out = model.net(input=image, class_dim=class_dim)
-    avg_cost = paddle.nn.functional.loss.cross_entropy(input=out, label=label)
-    acc_top1 = paddle.metric.accuracy(input=out, label=label, k=1)
-    acc_top5 = paddle.metric.accuracy(input=out, label=label, k=5)
-    val_program = paddle.static.default_main_program().clone(for_test=True)
-    opt = create_optimizer(args)
-    opt.minimize(avg_cost)
     places = paddle.static.cuda_places(
     ) if args.use_gpu else paddle.static.cpu_places()
     place = places[0]
     exe = paddle.static.Executor(place)
-    exe.run(paddle.static.default_startup_program())
+    image = paddle.static.data(
+        name='image', shape=[None] + image_shape, dtype='float32')
+    label = paddle.static.data(name='label', shape=[None, 1], dtype='int64')

-    if args.pretrained_model:
-
-        def if_exist(var):
-            return os.path.exists(os.path.join(args.pretrained_model, var.name))
-
-        _logger.info("Load pretrained model from {}".format(
-            args.pretrained_model))
-        paddle.static.load(paddle.static.default_main_program(),
-                           args.pretrained_model, exe)
     batch_size_per_card = int(args.batch_size / len(places))
     train_loader = paddle.io.DataLoader(
         train_dataset,
@@ -166,6 +142,29 @@ def compress(args):
         use_shared_memory=True,
         batch_size=batch_size_per_card,
         shuffle=False)
+    step_per_epoch = int(np.ceil(len(train_dataset) * 1. / args.batch_size))
+
+    # model definition
+    model = models.__dict__[args.model]()
+    out = model.net(input=image, class_dim=class_dim)
+    avg_cost = paddle.nn.functional.loss.cross_entropy(input=out, label=label)
+    acc_top1 = paddle.metric.accuracy(input=out, label=label, k=1)
+    acc_top5 = paddle.metric.accuracy(input=out, label=label, k=5)
+
+    val_program = paddle.static.default_main_program().clone(for_test=True)
+
+    opt, learning_rate = create_optimizer(args, step_per_epoch)
+    opt.minimize(avg_cost)
+    exe.run(paddle.static.default_startup_program())
+
+    if args.pretrained_model:
+
+        def if_exist(var):
+            return os.path.exists(os.path.join(args.pretrained_model, var.name))
+
+        _logger.info("Load pretrained model from {}".format(
+            args.pretrained_model))
+        paddle.static.load(paddle.static.default_main_program(),
+                           args.pretrained_model, exe)

     def test(epoch, program):
         acc_top1_ns = []
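The reordering in this hunk lets `step_per_epoch` be computed from the dataset the loader was actually built on, replacing the removed `--total_images` flag. A short sketch of that computation, with an assumed dataset size and batch size:

```python
import numpy as np

# Assumed values; in the demo they are len(train_dataset) and args.batch_size.
dataset_len, batch_size = 50000, 128

# Steps per epoch, rounding up so the final partial batch counts,
# mirroring the expression added in this hunk.
step_per_epoch = int(np.ceil(dataset_len * 1. / batch_size))  # -> 391
```

Returning the scheduler alongside the optimizer (`opt, learning_rate = create_optimizer(...)`) is what lets the training loop below step the schedule and log the current rate.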
@@ -189,15 +188,6 @@
             np.mean(np.array(acc_top1_ns)), np.mean(np.array(acc_top5_ns))))

     def train(epoch, program):
-        build_strategy = paddle.static.BuildStrategy()
-        exec_strategy = paddle.static.ExecutionStrategy()
-        train_program = paddle.static.CompiledProgram(
-            program).with_data_parallel(
-                loss_name=avg_cost.name,
-                build_strategy=build_strategy,
-                exec_strategy=exec_strategy)
-
         for batch_id, data in enumerate(train_loader):
             start_time = time.time()
             loss_n, acc_top1_n, acc_top5_n = exe.run(
@@ -210,9 +200,11 @@ def compress(args):
             acc_top5_n = np.mean(acc_top5_n)
             if batch_id % args.log_period == 0:
                 _logger.info(
-                    "epoch[{}]-batch[{}] - loss: {}; acc_top1: {}; acc_top5: {}; time: {}".
-                    format(epoch, batch_id, loss_n, acc_top1_n, acc_top5_n,
-                           end_time - start_time))
+                    "epoch[{}]-batch[{}] lr: {:.6f} - loss: {}; acc_top1: {}; acc_top5: {}; time: {}".
+                    format(epoch, batch_id,
+                           learning_rate.get_lr(), loss_n, acc_top1_n,
+                           acc_top5_n, end_time - start_time))
+            learning_rate.step()
             batch_id += 1

     test(0, val_program)
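`paddle.optimizer.lr` schedulers are not advanced implicitly, which is why the patch adds a `learning_rate.step()` after each batch and reads the current value with `get_lr()` for logging. A standalone sketch of that loop shape, using a toy scheduler and a dummy loop in place of the real training step:

```python
import paddle

scheduler = paddle.optimizer.lr.PiecewiseDecay(
    boundaries=[3, 6], values=[0.1, 0.01, 0.001])

for batch_id in range(8):  # stand-in for iterating train_loader
    # ... exe.run(...) would execute one training step here ...
    print("batch {} lr: {:.6f}".format(batch_id, scheduler.get_lr()))
    scheduler.step()  # advance the schedule once per batch
```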
@@ -236,8 +228,16 @@ def compress(args):
         place=place)
     _logger.info("FLOPs after pruning: {}".format(flops(pruned_program)))

+    build_strategy = paddle.static.BuildStrategy()
+    exec_strategy = paddle.static.ExecutionStrategy()
+    train_program = paddle.static.CompiledProgram(
+        pruned_program).with_data_parallel(
+            loss_name=avg_cost.name,
+            build_strategy=build_strategy,
+            exec_strategy=exec_strategy)
+
     for i in range(args.num_epochs):
-        train(i, pruned_program)
+        train(i, train_program)
         if i % args.test_period == 0:
            test(i, pruned_val_program)
            save_model(exe, pruned_val_program,
...
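This last hunk moves program compilation out of `train()`, so the pruned program is compiled exactly once and reused across all epochs instead of being rebuilt on every call. A hedged sketch of the same pattern on a toy static-graph program (the tiny network here is a stand-in for the demo's `pruned_program`, not the demo's model):

```python
import paddle

paddle.enable_static()

# Toy program standing in for the demo's pruned_program.
x = paddle.static.data(name='x', shape=[None, 4], dtype='float32')
y = paddle.static.data(name='y', shape=[None, 1], dtype='float32')
pred = paddle.static.nn.fc(x, size=1)
avg_cost = paddle.mean(paddle.nn.functional.square_error_cost(pred, y))
paddle.optimizer.SGD(learning_rate=0.01).minimize(avg_cost)

# Compile once, outside the epoch loop, as the patch does.
build_strategy = paddle.static.BuildStrategy()
exec_strategy = paddle.static.ExecutionStrategy()
train_program = paddle.static.CompiledProgram(
    paddle.static.default_main_program()).with_data_parallel(
        loss_name=avg_cost.name,
        build_strategy=build_strategy,
        exec_strategy=exec_strategy)
# Every epoch then reuses train_program instead of rebuilding it.
```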