提交 32481e12 编写于 作者: D dangqingqing

Refine MobileNet SSD model.

上级 e9c9cfd2
import paddle.v2 as paddle
import paddle.fluid as fluid
import numpy as np
def load_vars():
vars = {}
name_map = {}
with open('./ssd_mobilenet_v1_coco/names.map', 'r') as map_file:
for param in map_file:
fd_name, tf_name = param.strip().split('\t')
name_map[fd_name] = tf_name
tf_vars = np.load(
'./ssd_mobilenet_v1_coco/ssd_mobilenet_v1_coco_2017_11_17.npy').item()
for fd_name in name_map:
tf_name = name_map[fd_name]
tf_var = tf_vars[tf_name]
if len(tf_var.shape) == 4 and 'depthwise' in tf_name:
vars[fd_name] = np.transpose(tf_var, (2, 3, 0, 1))
elif len(tf_var.shape) == 4:
vars[fd_name] = np.transpose(tf_var, (3, 2, 0, 1))
else:
vars[fd_name] = tf_var
return vars
def load_and_set_vars(place):
vars = load_vars()
for k, v in vars.items():
t = fluid.global_scope().find_var(k).get_tensor()
#print(np.array(t).shape, v.shape, k)
assert np.array(t).shape == v.shape
t.set(v, place)
if __name__ == "__main__":
load_vars()
import os
import paddle.v2 as paddle
import paddle.fluid as fluid
from paddle.fluid.initializer import MSRA
from paddle.fluid.param_attr import ParamAttr
import reader
import numpy as np
import load_model as load_model
parameter_attr = ParamAttr(initializer=MSRA())
def conv_bn_layer(input,
filter_size,
num_filters,
stride,
padding,
channels=None,
num_groups=1,
act='relu',
use_cudnn=True):
def conv_bn(input,
filter_size,
num_filters,
stride,
padding,
channels=None,
num_groups=1,
act='relu',
use_cudnn=True):
conv = fluid.layers.conv2d(
input=input,
num_filters=num_filters,
......@@ -37,7 +37,7 @@ def depthwise_separable(input, num_filters1, num_filters2, num_groups, stride,
scale):
"""
"""
depthwise_conv = conv_bn_layer(
depthwise_conv = conv_bn(
input=input,
filter_size=3,
num_filters=int(num_filters1 * scale),
......@@ -46,7 +46,7 @@ def depthwise_separable(input, num_filters1, num_filters2, num_groups, stride,
num_groups=int(num_groups * scale),
use_cudnn=False)
pointwise_conv = conv_bn_layer(
pointwise_conv = conv_bn(
input=depthwise_conv,
filter_size=1,
num_filters=int(num_filters2 * scale),
......@@ -56,9 +56,8 @@ def depthwise_separable(input, num_filters1, num_filters2, num_groups, stride,
def extra_block(input, num_filters1, num_filters2, num_groups, stride, scale):
"""
"""
pointwise_conv = conv_bn_layer(
# 1x1 conv
pointwise_conv = conv_bn(
input=input,
filter_size=1,
num_filters=int(num_filters1 * scale),
......@@ -66,7 +65,8 @@ def extra_block(input, num_filters1, num_filters2, num_groups, stride, scale):
num_groups=int(num_groups * scale),
padding=0)
normal_conv = conv_bn_layer(
# 3x3 conv
normal_conv = conv_bn(
input=pointwise_conv,
filter_size=3,
num_filters=int(num_filters2 * scale),
......@@ -77,130 +77,33 @@ def extra_block(input, num_filters1, num_filters2, num_groups, stride, scale):
def mobile_net(img, img_shape, scale=1.0):
# 300x300
tmp = conv_bn_layer(
img,
filter_size=3,
channels=3,
num_filters=int(32 * scale),
stride=2,
padding=1)
tmp = conv_bn(img, 3, int(32 * scale), 2, 1, 3)
# 150x150
tmp = depthwise_separable(
tmp,
num_filters1=32,
num_filters2=64,
num_groups=32,
stride=1,
scale=scale)
tmp = depthwise_separable(
tmp,
num_filters1=64,
num_filters2=128,
num_groups=64,
stride=2,
scale=scale)
tmp = depthwise_separable(tmp, 32, 64, 32, 1, scale)
tmp = depthwise_separable(tmp, 64, 128, 64, 2, scale)
# 75x75
tmp = depthwise_separable(
tmp,
num_filters1=128,
num_filters2=128,
num_groups=128,
stride=1,
scale=scale)
tmp = depthwise_separable(
tmp,
num_filters1=128,
num_filters2=256,
num_groups=128,
stride=2,
scale=scale)
tmp = depthwise_separable(tmp, 128, 128, 128, 1, scale)
tmp = depthwise_separable(tmp, 128, 256, 128, 2, scale)
# 38x38
tmp = depthwise_separable(
tmp,
num_filters1=256,
num_filters2=256,
num_groups=256,
stride=1,
scale=scale)
tmp = depthwise_separable(
tmp,
num_filters1=256,
num_filters2=512,
num_groups=256,
stride=2,
scale=scale)
tmp = depthwise_separable(tmp, 256, 256, 256, 1, scale)
tmp = depthwise_separable(tmp, 256, 512, 256, 2, scale)
# 19x19
for i in range(5):
tmp = depthwise_separable(
tmp,
num_filters1=512,
num_filters2=512,
num_groups=512,
stride=1,
scale=scale)
tmp = depthwise_separable(tmp, 512, 512, 512, 1, scale)
module11 = tmp
tmp = depthwise_separable(
tmp,
num_filters1=512,
num_filters2=1024,
num_groups=512,
stride=2,
scale=scale)
tmp = depthwise_separable(tmp, 512, 1024, 512, 2, scale)
# 10x10
module13 = depthwise_separable(
tmp,
num_filters1=1024,
num_filters2=1024,
num_groups=1024,
stride=1,
scale=scale)
module14 = extra_block(
module13,
num_filters1=256,
num_filters2=512,
num_groups=1,
stride=2,
scale=scale)
module13 = depthwise_separable(tmp, 1024, 1024, 1024, 1, scale)
module14 = extra_block(module13, 256, 512, 1, 2, scale)
# 5x5
module15 = extra_block(
module14,
num_filters1=128,
num_filters2=256,
num_groups=1,
stride=2,
scale=scale)
module15 = extra_block(module14, 128, 256, 1, 2, scale)
# 3x3
module16 = extra_block(
module15,
num_filters1=128,
num_filters2=256,
num_groups=1,
stride=2,
scale=scale)
module16 = extra_block(module15, 128, 256, 1, 2, scale)
# 2x2
module17 = extra_block(
module16,
num_filters1=64,
num_filters2=128,
num_groups=1,
stride=2,
scale=scale)
module17 = extra_block(module16, 64, 128, 1, 2, scale)
mbox_locs, mbox_confs, box, box_var = fluid.layers.multi_box_head(
inputs=[module11, module13, module14, module15, module16, module17],
image=img,
......@@ -230,7 +133,9 @@ def train(train_file_list,
gt_box = fluid.layers.data(
name='gt_box', shape=[4], dtype='float32', lod_level=1)
gt_label = fluid.layers.data(
name='gt_label', shape=[1], dtype='float32', lod_level=1)
name='gt_label', shape=[1], dtype='int32', lod_level=1)
difficult = fluid.layers.data(
name='gt_difficult', shape=[1], dtype='int32', lod_level=1)
mbox_locs, mbox_confs, box, box_var = mobile_net(image, image_shape)
nmsed_out = fluid.layers.detection_output(mbox_locs, mbox_confs, box,
......@@ -239,35 +144,62 @@ def train(train_file_list,
box, box_var)
loss = fluid.layers.nn.reduce_sum(loss_vec)
map_eval = fluid.evaluator.DetectionMAP(
nmsed_out,
gt_label,
gt_box,
difficult,
21,
overlap_threshold=0.5,
evaluate_difficult=False,
ap_version='11point')
test_program = fluid.default_main_program().clone(for_test=True)
optimizer = fluid.optimizer.Momentum(
learning_rate=fluid.learning_rate_decay.exponential_decay(
learning_rate=fluid.layers.exponential_decay(
learning_rate=learning_rate,
decay_steps=40000,
decay_rate=0.1,
staircase=True),
momentum=0.9,
regularization=fluid.regularizer.L2Decay(5 * 1e-5), )
regularization=fluid.regularizer.L2Decay(0.0005), )
opts = optimizer.minimize(loss)
place = fluid.CUDAPlace(0)
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())
load_model.load_and_set_vars(place)
train_reader = paddle.batch(
reader.train(data_args, train_file_list), batch_size=batch_size)
test_reader = paddle.batch(
reader.test(data_args, train_file_list), batch_size=batch_size)
feeder = fluid.DataFeeder(place=place, feed_list=[image, gt_box, gt_label])
feeder = fluid.DataFeeder(
place=place, feed_list=[image, gt_box, gt_label, difficult])
#print fluid.default_main_program()
map, accum_map = map_eval.get_map_var()
for pass_id in range(num_passes):
map_eval.reset(exe)
for batch_id, data in enumerate(train_reader()):
loss_v = exe.run(fluid.default_main_program(),
feed=feeder.feed(data),
fetch_list=[loss])
if batch_id % 50 == 0:
print("Pass {0}, batch {1}, loss {2}".format(pass_id, batch_id,
np.sum(loss_v)))
if pass_id % 1 == 0:
loss_v, map_v, accum_map_v = exe.run(
fluid.default_main_program(),
feed=feeder.feed(data),
fetch_list=[loss, map, accum_map])
print(
"Pass {0}, batch {1}, loss {2}, cur_map {3}, map {4}"
.format(pass_id, batch_id, loss_v[0], map_v[0], accum_map_v[0]))
map_eval.reset(exe)
test_map = None
for _, data in enumerate(test_reader()):
test_map = exe.run(test_program,
feed=feeder.feed(data),
fetch_list=[accum_map])
print("Test {0}, map {1}".format(pass_id, test_map[0]))
if pass_id % 10 == 0:
model_path = os.path.join(model_save_dir, str(pass_id))
print 'save models to %s' % (model_path)
fluid.io.save_inference_model(model_path, ['image'], [nmsed_out],
......@@ -285,6 +217,6 @@ if __name__ == '__main__':
train_file_list='./data/trainval.txt',
val_file_list='./data/test.txt',
data_args=data_args,
learning_rate=0.001,
learning_rate=0.004,
batch_size=32,
num_passes=300)
......@@ -159,7 +159,8 @@ def _reader_creator(settings, file_list, mode, shuffle):
if mode == 'train' and len(sample_labels) == 0: continue
yield img.astype(
'float32'
), sample_labels[:, 1:5], sample_labels[:, 0].astype('int')
), sample_labels[:, 1:5], sample_labels[:, 0].astype(
'int32'), sample_labels[:, 5].astype('int32')
elif mode == 'infer':
yield img.astype('float32')
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册