Unverified commit 39f7ed30, authored by yukavio, committed by GitHub

fix prune demo (#597)

Parent 82f6ef8a
......
@@ -130,7 +130,6 @@ class MobileNet():
         with fluid.name_scope('last_fc'):
             output = fluid.layers.fc(input=input,
                                      size=class_dim,
-                                     act='softmax',
                                      param_attr=ParamAttr(
                                          initializer=MSRA(),
                                          name="fc7_weights"),
......
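Note on the recurring `act='softmax'` deletions (repeated below for MobileNetV2, MobileNetV3, PVANet, and ResNet): the final fc layers now emit raw logits, because the reworked demo computes its loss with `paddle.nn.functional.cross_entropy`, which applies softmax internally. Keeping the activation in the network would apply softmax twice and flatten the loss surface. A minimal sketch of the intended split, using PaddlePaddle 2.x dynamic graph purely for illustration:

```python
import paddle
import paddle.nn.functional as F

logits = paddle.randn([4, 10])             # raw fc output, no activation
labels = paddle.randint(0, 10, shape=[4])  # int64 class ids

loss = F.cross_entropy(logits, labels)     # softmax is applied inside the loss
probs = F.softmax(logits)                  # explicit softmax only at inference
```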
......
@@ -110,7 +110,6 @@ class MobileNetV2():
         output = fluid.layers.fc(input=input,
                                  size=class_dim,
-                                 act='softmax',
                                  param_attr=ParamAttr(name='fc10_weights'),
                                  bias_attr=ParamAttr(name='fc10_offset'))
         return output
......
......
@@ -119,7 +119,6 @@ class MobileNetV3():
             conv = self.hard_swish(conv)
             out = fluid.layers.fc(input=conv,
                                   size=class_dim,
-                                  act='softmax',
                                   param_attr=ParamAttr(name='fc_weights'),
                                   bias_attr=ParamAttr(name='fc_offset'))
             return out
......
@@ -244,8 +243,7 @@ class MobileNetV3():
         if num_in_filter != num_out_filter or stride != 1:
             return conv2
         else:
-            return fluid.layers.elementwise_add(
-                x=input_data, y=conv2, act=None)
+            return fluid.layers.elementwise_add(x=input_data, y=conv2, act=None)

 def MobileNetV3_small_x0_25():
......
......
@@ -59,10 +59,8 @@ class PVANet():
             block_configs=[
                 BlockConfig(2, '64 48-96 24-48-48 96 128', True,
                             BLOCK_TYPE_INCEP),
-                BlockConfig(1, '64 64-96 24-48-48 128', True,
-                            BLOCK_TYPE_INCEP),
-                BlockConfig(1, '64 64-96 24-48-48 128', True,
-                            BLOCK_TYPE_INCEP),
+                BlockConfig(1, '64 64-96 24-48-48 128', True, BLOCK_TYPE_INCEP),
+                BlockConfig(1, '64 64-96 24-48-48 128', True, BLOCK_TYPE_INCEP),
                 BlockConfig(1, '64 64-96 24-48-48 128', True, BLOCK_TYPE_INCEP)
             ],
             name='conv4',
......
@@ -76,9 +74,8 @@ class PVANet():
                 BlockConfig(1, '64 96-128 32-64-64 196', True,
                             BLOCK_TYPE_INCEP),
                 BlockConfig(1, '64 96-128 32-64-64 196', True,
-                            BLOCK_TYPE_INCEP), BlockConfig(
-                                1, '64 96-128 32-64-64 196', True,
-                                BLOCK_TYPE_INCEP)
+                            BLOCK_TYPE_INCEP),
+                BlockConfig(1, '64 96-128 32-64-64 196', True, BLOCK_TYPE_INCEP)
             ],
             name='conv5',
             end_points=end_points)
......
@@ -89,7 +86,6 @@ class PVANet():
             output = fluid.layers.fc(input=input,
                                      size=class_dim,
-                                     act='softmax',
                                      param_attr=ParamAttr(
                                          initializer=MSRA(), name="fc_weights"),
                                      bias_attr=ParamAttr(name="fc_offset"))
......
@@ -182,9 +178,8 @@ class PVANet():
                 conv_stride = stride
             else:
                 conv_stride = 1
-            path_net = self._conv_bn_relu(path_net, num_output,
-                                          kernel_size, name + scope,
-                                          conv_stride)
+            path_net = self._conv_bn_relu(path_net, num_output, kernel_size,
+                                          name + scope, conv_stride)
             paths.append(path_net)
         if stride > 1:
......
@@ -359,8 +354,8 @@ class PVANet():
                      name,
                      stride=1,
                      groups=1):
-        return self._conv_bn_relu(input, num_filters, filter_size, name,
-                                  stride, groups)
+        return self._conv_bn_relu(input, num_filters, filter_size, name, stride,
+                                  groups)

 def Fpn_Fusion(blocks, net):
......
@@ -433,8 +428,7 @@ def east(input, class_num=31):
                 out[i], k, 1, name='fusion_' + str(len(blocks)))
         elif j <= 4:
             conv = net.deconv_bn_layer(
-                out[i], k, 2 * j, j, j // 2,
-                name='fusion_' + str(len(blocks)))
+                out[i], k, 2 * j, j, j // 2, name='fusion_' + str(len(blocks)))
         else:
             conv = net.deconv_bn_layer(
                 out[i], 32, 8, 4, 2, name='fusion_' + str(len(blocks)) + '_1')
......
......
@@ -105,7 +105,6 @@ class ResNet():
         out = fluid.layers.fc(
             input=pool,
             size=class_dim,
-            act='softmax',
             name=fc_name,
             param_attr=fluid.param_attr.ParamAttr(
                 initializer=fluid.initializer.Uniform(-stdv, stdv)))
......
@@ -138,8 +137,7 @@ class ResNet():
             bn_name = "bn" + name[3:]
         else:
             if name.split("_")[1] == "conv1":
-                bn_name = name.split("_", 1)[0] + "_bn_" + name.split("_",
-                                                                      1)[1]
+                bn_name = name.split("_", 1)[0] + "_bn_" + name.split("_", 1)[1]
             else:
                 bn_name = name.split("_", 1)[0] + "_bn" + name.split("_",
                                                                      1)[1][3:]
......
......
@@ -29,7 +29,6 @@ add_arg('lr_strategy', str, "piecewise_decay", "The learning rate decay
 add_arg('l2_decay', float, 3e-5, "The l2_decay parameter.")
 add_arg('momentum_rate', float, 0.9, "The value of momentum_rate.")
 add_arg('num_epochs', int, 120, "The number of total epochs.")
-add_arg('total_images', int, 1281167, "The number of total training images.")
 parser.add_argument('--step_epochs', nargs='+', type=int, default=[30, 60, 90], help="piecewise decay step")
 add_arg('config_file', str, None, "The config file for compression with yaml format.")
 add_arg('data', str, "mnist", "Which data to use. 'mnist' or 'imagenet'")
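Dropping `total_images` removes a likely source of the bug this commit fixes: the flag's default (1281167, the ImageNet-1k training-set size) fed the LR schedule even when `--data mnist` was selected. The step count is now derived from the dataset itself; a hedged sketch of the replacement arithmetic, reusing the demo's own names:

```python
import numpy as np

# The dataset length replaces the --total_images flag; train_dataset and
# args come from the surrounding demo script.
step_per_epoch = int(np.ceil(len(train_dataset) * 1. / args.batch_size))
```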
......
@@ -65,9 +64,8 @@ def get_pruned_params(args, program):
     return params

-def piecewise_decay(args):
-    step = int(math.ceil(float(args.total_images) / args.batch_size))
-    bd = [step * e for e in args.step_epochs]
+def piecewise_decay(args, step_per_epoch):
+    bd = [step_per_epoch * e for e in args.step_epochs]
     lr = [args.lr * (0.1**i) for i in range(len(bd) + 1)]
     learning_rate = paddle.optimizer.lr.PiecewiseDecay(boundaries=bd, values=lr)
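A worked example of the boundary arithmetic above, assuming the demo's default `--step_epochs 30 60 90`, a hypothetical `args.lr` of 0.1, and a hypothetical 500 steps per epoch:

```python
step_per_epoch = 500                               # hypothetical value
bd = [step_per_epoch * e for e in [30, 60, 90]]    # [15000, 30000, 45000]
lr = [0.1 * (0.1**i) for i in range(len(bd) + 1)]  # [0.1, 0.01, 0.001, 0.0001]
# PiecewiseDecay yields lr[i] while the global step count is below bd[i],
# and lr[-1] from bd[-1] onwards.
```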
......
@@ -75,25 +73,24 @@ def piecewise_decay(args):
         learning_rate=learning_rate,
         momentum=args.momentum_rate,
         weight_decay=paddle.regularizer.L2Decay(args.l2_decay))
-    return optimizer
+    return optimizer, learning_rate

-def cosine_decay(args):
-    step = int(math.ceil(float(args.total_images) / args.batch_size))
+def cosine_decay(args, step_per_epoch):
     learning_rate = paddle.optimizer.lr.CosineAnnealingDecay(
-        learning_rate=args.lr, T_max=args.num_epochs)
+        learning_rate=args.lr, T_max=args.num_epochs * step_per_epoch)
     optimizer = paddle.optimizer.Momentum(
         learning_rate=learning_rate,
         momentum=args.momentum_rate,
         weight_decay=paddle.regularizer.L2Decay(args.l2_decay))
-    return optimizer
+    return optimizer, learning_rate

-def create_optimizer(args):
+def create_optimizer(args, step_per_epoch):
     if args.lr_strategy == "piecewise_decay":
-        return piecewise_decay(args)
+        return piecewise_decay(args, step_per_epoch)
     elif args.lr_strategy == "cosine_decay":
-        return cosine_decay(args)
+        return cosine_decay(args, step_per_epoch)

 def compress(args):
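Two related changes in this hunk: each factory now returns the scheduler alongside the optimizer, since a Paddle 2.x `LRScheduler` is advanced explicitly by the caller, and `T_max` of `CosineAnnealingDecay` switches from epochs to total steps because the training loop below calls `step()` once per batch. A minimal sketch of the unit change, with made-up sizes:

```python
import paddle

num_epochs, step_per_epoch = 120, 500  # made-up sizes for illustration
scheduler = paddle.optimizer.lr.CosineAnnealingDecay(
    learning_rate=0.1, T_max=num_epochs * step_per_epoch)

for _ in range(3):
    scheduler.step()         # one call per batch, matching the loop below
print(scheduler.get_lr())    # follows a cosine that reaches ~0 at T_max steps
```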
......
@@ -118,34 +115,13 @@ def compress(args):
     image_shape = [int(m) for m in image_shape.split(",")]
     assert args.model in model_list, "{} is not in lists: {}".format(args.model,
                                                                      model_list)
-    image = paddle.static.data(
-        name='image', shape=[None] + image_shape, dtype='float32')
-    label = paddle.static.data(name='label', shape=[None, 1], dtype='int64')
-    # model definition
-    model = models.__dict__[args.model]()
-    out = model.net(input=image, class_dim=class_dim)
-    avg_cost = paddle.nn.functional.loss.cross_entropy(input=out, label=label)
-    acc_top1 = paddle.metric.accuracy(input=out, label=label, k=1)
-    acc_top5 = paddle.metric.accuracy(input=out, label=label, k=5)
-    val_program = paddle.static.default_main_program().clone(for_test=True)
-    opt = create_optimizer(args)
-    opt.minimize(avg_cost)
     places = paddle.static.cuda_places(
     ) if args.use_gpu else paddle.static.cpu_places()
     place = places[0]
     exe = paddle.static.Executor(place)
-    exe.run(paddle.static.default_startup_program())
-    if args.pretrained_model:
-        def if_exist(var):
-            return os.path.exists(os.path.join(args.pretrained_model, var.name))
-        _logger.info("Load pretrained model from {}".format(
-            args.pretrained_model))
-        paddle.static.load(paddle.static.default_main_program(),
-                           args.pretrained_model, exe)
+    image = paddle.static.data(
+        name='image', shape=[None] + image_shape, dtype='float32')
+    label = paddle.static.data(name='label', shape=[None, 1], dtype='int64')
     batch_size_per_card = int(args.batch_size / len(places))
     train_loader = paddle.io.DataLoader(
         train_dataset,
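The deleted block is moved, not removed: it reappears after the data loaders in the next hunk. The reordering is forced by a dependency chain, sketched here with names from the surrounding script:

```python
# 1. build train_dataset / train_loader            -> data pipeline first
# 2. step_per_epoch = ceil(len(train_dataset) / args.batch_size)
# 3. opt, learning_rate = create_optimizer(args, step_per_epoch)
# 4. opt.minimize(avg_cost), run startup program, load pretrained weights
```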
......
@@ -166,6 +142,30 @@ def compress(args):
         use_shared_memory=True,
         batch_size=batch_size_per_card,
         shuffle=False)
+    step_per_epoch = int(np.ceil(len(train_dataset) * 1. / args.batch_size))
+    # model definition
+    model = models.__dict__[args.model]()
+    out = model.net(input=image, class_dim=class_dim)
+    cost = paddle.nn.functional.loss.cross_entropy(input=out, label=label)
+    avg_cost = paddle.mean(x=cost)
+    acc_top1 = paddle.metric.accuracy(input=out, label=label, k=1)
+    acc_top5 = paddle.metric.accuracy(input=out, label=label, k=5)
+    val_program = paddle.static.default_main_program().clone(for_test=True)
+    opt, learning_rate = create_optimizer(args, step_per_epoch)
+    opt.minimize(avg_cost)
+    exe.run(paddle.static.default_startup_program())
+    if args.pretrained_model:
+        def if_exist(var):
+            return os.path.exists(os.path.join(args.pretrained_model, var.name))
+        _logger.info("Load pretrained model from {}".format(
+            args.pretrained_model))
+        paddle.static.load(paddle.static.default_main_program(),
+                           args.pretrained_model, exe)

     def test(epoch, program):
         acc_top1_ns = []
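The loss is now split into a per-element `cost` and an explicit `paddle.mean`. Depending on the `reduction` setting of `cross_entropy`, its output may not be a scalar, and `minimize` needs a scalar objective; the explicit mean guarantees one. A hedged sketch:

```python
import paddle
import paddle.nn.functional as F

logits = paddle.randn([8, 10])
labels = paddle.randint(0, 10, shape=[8])

cost = F.cross_entropy(logits, labels, reduction='none')  # per-sample losses
avg_cost = paddle.mean(cost)                              # scalar, safe to minimize
```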
......
@@ -189,15 +189,6 @@ def compress(args):
             np.mean(np.array(acc_top1_ns)), np.mean(np.array(acc_top5_ns))))

     def train(epoch, program):
-        build_strategy = paddle.static.BuildStrategy()
-        exec_strategy = paddle.static.ExecutionStrategy()
-        train_program = paddle.static.CompiledProgram(
-            program).with_data_parallel(
-                loss_name=avg_cost.name,
-                build_strategy=build_strategy,
-                exec_strategy=exec_strategy)
         for batch_id, data in enumerate(train_loader):
             start_time = time.time()
             loss_n, acc_top1_n, acc_top5_n = exe.run(
......
@@ -210,9 +201,11 @@ def compress(args):
             acc_top5_n = np.mean(acc_top5_n)
             if batch_id % args.log_period == 0:
                 _logger.info(
-                    "epoch[{}]-batch[{}] - loss: {}; acc_top1: {}; acc_top5: {}; time: {}".
-                    format(epoch, batch_id, loss_n, acc_top1_n, acc_top5_n,
-                           end_time - start_time))
+                    "epoch[{}]-batch[{}] lr: {:.6f} - loss: {}; acc_top1: {}; acc_top5: {}; time: {}".
+                    format(epoch, batch_id,
+                           learning_rate.get_lr(), loss_n, acc_top1_n,
+                           acc_top5_n, end_time - start_time))
+            learning_rate.step()
             batch_id += 1

     test(0, val_program)
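The loop now logs `learning_rate.get_lr()` and advances the schedule once per batch with `learning_rate.step()`: `get_lr()` reports the rate in force for the current batch, and `step()` moves to the next. A self-contained illustration of that protocol with a toy schedule:

```python
import paddle

sched = paddle.optimizer.lr.PiecewiseDecay(
    boundaries=[2, 4], values=[0.1, 0.01, 0.001])  # toy schedule
for batch_id in range(5):
    print(batch_id, sched.get_lr())  # 0.1, 0.1, 0.01, 0.01, 0.001
    sched.step()                     # advance after the batch completes
```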
......
@@ -236,8 +229,16 @@ def compress(args):
         place=place)
     _logger.info("FLOPs after pruning: {}".format(flops(pruned_program)))
+    build_strategy = paddle.static.BuildStrategy()
+    exec_strategy = paddle.static.ExecutionStrategy()
+    train_program = paddle.static.CompiledProgram(
+        pruned_program).with_data_parallel(
+            loss_name=avg_cost.name,
+            build_strategy=build_strategy,
+            exec_strategy=exec_strategy)
     for i in range(args.num_epochs):
-        train(i, pruned_program)
+        train(i, train_program)
         if i % args.test_period == 0:
             test(i, pruned_val_program)
             save_model(exe, pruned_val_program,
......
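The `CompiledProgram` setup removed from `train()` above is rebuilt here, once, from `pruned_program`: the graph is compiled for data-parallel execution a single time instead of on every `train()` call, and training now runs the compiled pruned graph rather than the raw `pruned_program`. Roughly, as a fragment using the demo script's variables:

```python
# Compile once, reuse every epoch; pruned_program and avg_cost are the
# demo script's variables.
train_program = paddle.static.CompiledProgram(
    pruned_program).with_data_parallel(loss_name=avg_cost.name)

for i in range(args.num_epochs):
    train(i, train_program)  # train() just feeds batches to exe.run(...)
```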