diff --git a/fluid/object_detection/load_model.py b/fluid/object_detection/load_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..aa2839bc7df1b19f65e4a218cbf0d0d7b8954a5e
--- /dev/null
+++ b/fluid/object_detection/load_model.py
@@ -0,0 +1,39 @@
+import paddle.v2 as paddle
+import paddle.fluid as fluid
+import numpy as np
+
+
+def load_vars():
+    vars = {}
+    name_map = {}
+    with open('./ssd_mobilenet_v1_coco/names.map', 'r') as map_file:
+        for param in map_file:
+            fd_name, tf_name = param.strip().split('\t')
+            name_map[fd_name] = tf_name
+
+    tf_vars = np.load(
+        './ssd_mobilenet_v1_coco/ssd_mobilenet_v1_coco_2017_11_17.npy').item()
+    for fd_name in name_map:
+        tf_name = name_map[fd_name]
+        tf_var = tf_vars[tf_name]
+        if len(tf_var.shape) == 4 and 'depthwise' in tf_name:
+            vars[fd_name] = np.transpose(tf_var, (2, 3, 0, 1))
+        elif len(tf_var.shape) == 4:
+            vars[fd_name] = np.transpose(tf_var, (3, 2, 0, 1))
+        else:
+            vars[fd_name] = tf_var
+
+    return vars
+
+
+def load_and_set_vars(place):
+    vars = load_vars()
+    for k, v in vars.items():
+        t = fluid.global_scope().find_var(k).get_tensor()
+        #print(np.array(t).shape, v.shape, k)
+        assert np.array(t).shape == v.shape
+        t.set(v, place)
+
+
+if __name__ == "__main__":
+    load_vars()
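Note on the two transposes in load_vars() (illustration only, not part of the patch): TensorFlow stores regular conv kernels as HWIO (height, width, input channels, output channels), so transposing by (3, 2, 0, 1) yields the OIHW layout Fluid expects; depthwise kernels are stored as (height, width, channels, multiplier), and with a multiplier of 1 the (2, 3, 0, 1) transpose yields the (channels, 1, h, w) weight of a grouped conv2d with groups == channels. A minimal numpy shape check, with made-up shapes:

    import numpy as np

    # regular conv: HWIO (3, 3, in=32, out=64) -> OIHW (64, 32, 3, 3)
    hwio = np.zeros((3, 3, 32, 64), dtype='float32')
    assert np.transpose(hwio, (3, 2, 0, 1)).shape == (64, 32, 3, 3)

    # depthwise conv: (3, 3, channels=512, multiplier=1) -> (512, 1, 3, 3)
    hwcm = np.zeros((3, 3, 512, 1), dtype='float32')
    assert np.transpose(hwcm, (2, 3, 0, 1)).shape == (512, 1, 3, 3)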
diff --git a/fluid/object_detection/mobilenet_ssd_fluid.py b/fluid/object_detection/mobilenet_ssd.py
similarity index 56%
rename from fluid/object_detection/mobilenet_ssd_fluid.py
rename to fluid/object_detection/mobilenet_ssd.py
index 633dd8e165f387a2ffcc5819da6fbc359e8a680a..cb3cada6492ece9c9f4344aacfe32f78ea942375 100644
--- a/fluid/object_detection/mobilenet_ssd_fluid.py
+++ b/fluid/object_detection/mobilenet_ssd.py
@@ -1,24 +1,24 @@
 import os
-
 import paddle.v2 as paddle
 import paddle.fluid as fluid
 from paddle.fluid.initializer import MSRA
 from paddle.fluid.param_attr import ParamAttr
 import reader
 import numpy as np
+import load_model as load_model
 
 parameter_attr = ParamAttr(initializer=MSRA())
 
 
-def conv_bn_layer(input,
-                  filter_size,
-                  num_filters,
-                  stride,
-                  padding,
-                  channels=None,
-                  num_groups=1,
-                  act='relu',
-                  use_cudnn=True):
+def conv_bn(input,
+            filter_size,
+            num_filters,
+            stride,
+            padding,
+            channels=None,
+            num_groups=1,
+            act='relu',
+            use_cudnn=True):
     conv = fluid.layers.conv2d(
         input=input,
         num_filters=num_filters,
@@ -37,7 +37,7 @@ def depthwise_separable(input, num_filters1, num_filters2, num_groups, stride,
                         scale):
     """
     """
-    depthwise_conv = conv_bn_layer(
+    depthwise_conv = conv_bn(
         input=input,
         filter_size=3,
         num_filters=int(num_filters1 * scale),
@@ -46,7 +46,7 @@ def depthwise_separable(input, num_filters1, num_filters2, num_groups, stride,
         num_groups=int(num_groups * scale),
         use_cudnn=False)
 
-    pointwise_conv = conv_bn_layer(
+    pointwise_conv = conv_bn(
         input=depthwise_conv,
         filter_size=1,
         num_filters=int(num_filters2 * scale),
@@ -56,9 +56,8 @@ def depthwise_separable(input, num_filters1, num_filters2, num_groups, stride,
 
 
 def extra_block(input, num_filters1, num_filters2, num_groups, stride, scale):
-    """
-    """
-    pointwise_conv = conv_bn_layer(
+    # 1x1 conv
+    pointwise_conv = conv_bn(
         input=input,
         filter_size=1,
         num_filters=int(num_filters1 * scale),
@@ -66,7 +65,8 @@ def extra_block(input, num_filters1, num_filters2, num_groups, stride, scale):
         num_groups=int(num_groups * scale),
         padding=0)
 
-    normal_conv = conv_bn_layer(
+    # 3x3 conv
+    normal_conv = conv_bn(
         input=pointwise_conv,
         filter_size=3,
         num_filters=int(num_filters2 * scale),
@@ -77,130 +77,33 @@ def extra_block(input, num_filters1, num_filters2, num_groups, stride, scale):
 
 
 def mobile_net(img, img_shape, scale=1.0):
-    # 300x300
-    tmp = conv_bn_layer(
-        img,
-        filter_size=3,
-        channels=3,
-        num_filters=int(32 * scale),
-        stride=2,
-        padding=1)
-
+    tmp = conv_bn(img, 3, int(32 * scale), 2, 1, 3)
     # 150x150
-    tmp = depthwise_separable(
-        tmp,
-        num_filters1=32,
-        num_filters2=64,
-        num_groups=32,
-        stride=1,
-        scale=scale)
-
-    tmp = depthwise_separable(
-        tmp,
-        num_filters1=64,
-        num_filters2=128,
-        num_groups=64,
-        stride=2,
-        scale=scale)
-
+    tmp = depthwise_separable(tmp, 32, 64, 32, 1, scale)
+    tmp = depthwise_separable(tmp, 64, 128, 64, 2, scale)
     # 75x75
-    tmp = depthwise_separable(
-        tmp,
-        num_filters1=128,
-        num_filters2=128,
-        num_groups=128,
-        stride=1,
-        scale=scale)
-
-    tmp = depthwise_separable(
-        tmp,
-        num_filters1=128,
-        num_filters2=256,
-        num_groups=128,
-        stride=2,
-        scale=scale)
-
+    tmp = depthwise_separable(tmp, 128, 128, 128, 1, scale)
+    tmp = depthwise_separable(tmp, 128, 256, 128, 2, scale)
     # 38x38
-    tmp = depthwise_separable(
-        tmp,
-        num_filters1=256,
-        num_filters2=256,
-        num_groups=256,
-        stride=1,
-        scale=scale)
-
-    tmp = depthwise_separable(
-        tmp,
-        num_filters1=256,
-        num_filters2=512,
-        num_groups=256,
-        stride=2,
-        scale=scale)
+    tmp = depthwise_separable(tmp, 256, 256, 256, 1, scale)
+    tmp = depthwise_separable(tmp, 256, 512, 256, 2, scale)
 
     # 19x19
     for i in range(5):
-        tmp = depthwise_separable(
-            tmp,
-            num_filters1=512,
-            num_filters2=512,
-            num_groups=512,
-            stride=1,
-            scale=scale)
+        tmp = depthwise_separable(tmp, 512, 512, 512, 1, scale)
     module11 = tmp
-
-    tmp = depthwise_separable(
-        tmp,
-        num_filters1=512,
-        num_filters2=1024,
-        num_groups=512,
-        stride=2,
-        scale=scale)
+    tmp = depthwise_separable(tmp, 512, 1024, 512, 2, scale)
 
     # 10x10
-    module13 = depthwise_separable(
-        tmp,
-        num_filters1=1024,
-        num_filters2=1024,
-        num_groups=1024,
-        stride=1,
-        scale=scale)
-
-    module14 = extra_block(
-        module13,
-        num_filters1=256,
-        num_filters2=512,
-        num_groups=1,
-        stride=2,
-        scale=scale)
-
+    module13 = depthwise_separable(tmp, 1024, 1024, 1024, 1, scale)
+    module14 = extra_block(module13, 256, 512, 1, 2, scale)
     # 5x5
-    module15 = extra_block(
-        module14,
-        num_filters1=128,
-        num_filters2=256,
-        num_groups=1,
-        stride=2,
-        scale=scale)
-
+    module15 = extra_block(module14, 128, 256, 1, 2, scale)
     # 3x3
-    module16 = extra_block(
-        module15,
-        num_filters1=128,
-        num_filters2=256,
-        num_groups=1,
-        stride=2,
-        scale=scale)
-
+    module16 = extra_block(module15, 128, 256, 1, 2, scale)
     # 2x2
-    module17 = extra_block(
-        module16,
-        num_filters1=64,
-        num_filters2=128,
-        num_groups=1,
-        stride=2,
-        scale=scale)
-
+    module17 = extra_block(module16, 64, 128, 1, 2, scale)
     mbox_locs, mbox_confs, box, box_var = fluid.layers.multi_box_head(
         inputs=[module11, module13, module14, module15, module16, module17],
         image=img,
@@ -230,7 +133,9 @@ def train(train_file_list,
     gt_box = fluid.layers.data(
         name='gt_box', shape=[4], dtype='float32', lod_level=1)
     gt_label = fluid.layers.data(
-        name='gt_label', shape=[1], dtype='float32', lod_level=1)
+        name='gt_label', shape=[1], dtype='int32', lod_level=1)
+    difficult = fluid.layers.data(
+        name='gt_difficult', shape=[1], dtype='int32', lod_level=1)
 
     mbox_locs, mbox_confs, box, box_var = mobile_net(image, image_shape)
     nmsed_out = fluid.layers.detection_output(mbox_locs, mbox_confs, box,
@@ -239,35 +144,62 @@ def train(train_file_list,
                                               box, box_var)
     loss = fluid.layers.nn.reduce_sum(loss_vec)
 
+    map_eval = fluid.evaluator.DetectionMAP(
+        nmsed_out,
+        gt_label,
+        gt_box,
+        difficult,
+        21,
+        overlap_threshold=0.5,
+        evaluate_difficult=False,
+        ap_version='11point')
+
+    test_program = fluid.default_main_program().clone(for_test=True)
+
     optimizer = fluid.optimizer.Momentum(
-        learning_rate=fluid.learning_rate_decay.exponential_decay(
+        learning_rate=fluid.layers.exponential_decay(
             learning_rate=learning_rate,
             decay_steps=40000,
             decay_rate=0.1,
             staircase=True),
         momentum=0.9,
-        regularization=fluid.regularizer.L2Decay(5 * 1e-5), )
+        regularization=fluid.regularizer.L2Decay(0.0005), )
 
     opts = optimizer.minimize(loss)
 
     place = fluid.CUDAPlace(0)
     exe = fluid.Executor(place)
     exe.run(fluid.default_startup_program())
 
+    load_model.load_and_set_vars(place)
+
     train_reader = paddle.batch(
         reader.train(data_args, train_file_list), batch_size=batch_size)
     test_reader = paddle.batch(
         reader.test(data_args, train_file_list), batch_size=batch_size)
-    feeder = fluid.DataFeeder(place=place, feed_list=[image, gt_box, gt_label])
+    feeder = fluid.DataFeeder(
+        place=place, feed_list=[image, gt_box, gt_label, difficult])
+    #print fluid.default_main_program()
 
+    map, accum_map = map_eval.get_map_var()
     for pass_id in range(num_passes):
+        map_eval.reset(exe)
         for batch_id, data in enumerate(train_reader()):
-            loss_v = exe.run(fluid.default_main_program(),
-                             feed=feeder.feed(data),
-                             fetch_list=[loss])
-            if batch_id % 50 == 0:
-                print("Pass {0}, batch {1}, loss {2}".format(pass_id, batch_id,
-                                                             np.sum(loss_v)))
-        if pass_id % 1 == 0:
+            loss_v, map_v, accum_map_v = exe.run(
+                fluid.default_main_program(),
+                feed=feeder.feed(data),
+                fetch_list=[loss, map, accum_map])
+            print(
+                "Pass {0}, batch {1}, loss {2}, cur_map {3}, map {4}"
+                .format(pass_id, batch_id, loss_v[0], map_v[0], accum_map_v[0]))
+
+        map_eval.reset(exe)
+        test_map = None
+        for _, data in enumerate(test_reader()):
+            test_map = exe.run(test_program,
+                               feed=feeder.feed(data),
+                               fetch_list=[accum_map])
+        print("Test {0}, map {1}".format(pass_id, test_map[0]))
+
+        if pass_id % 10 == 0:
             model_path = os.path.join(model_save_dir, str(pass_id))
             print 'save models to %s' % (model_path)
             fluid.io.save_inference_model(model_path, ['image'], [nmsed_out],
@@ -285,6 +217,6 @@ if __name__ == '__main__':
         train_file_list='./data/trainval.txt',
         val_file_list='./data/test.txt',
         data_args=data_args,
-        learning_rate=0.001,
+        learning_rate=0.004,
         batch_size=32,
         num_passes=300)
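Note on the collapsed call sites in mobile_net (illustration only, not part of the patch): the one-line calls pass arguments positionally, so the first call maps onto conv_bn's signature as spelled out below; the depthwise_separable and extra_block calls likewise follow the (input, num_filters1, num_filters2, num_groups, stride, scale) order of their signatures.

    # conv_bn(input, filter_size, num_filters, stride, padding, channels=None, ...)
    tmp = conv_bn(img, 3, int(32 * scale), 2, 1, 3)
    # is equivalent to the keyword form:
    tmp = conv_bn(
        img,
        filter_size=3,
        num_filters=int(32 * scale),
        stride=2,
        padding=1,
        channels=3)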
diff --git a/fluid/object_detection/reader.py b/fluid/object_detection/reader.py
index aa721d89823df01360c21bfbecbd5e3947ffe928..6564384118c55db13d88d36c85ec1212f1be2ce5 100644
--- a/fluid/object_detection/reader.py
+++ b/fluid/object_detection/reader.py
@@ -159,7 +159,8 @@ def _reader_creator(settings, file_list, mode, shuffle):
             if mode == 'train' and len(sample_labels) == 0:
                 continue
             yield img.astype(
                 'float32'
-            ), sample_labels[:, 1:5], sample_labels[:, 0].astype('int')
+            ), sample_labels[:, 1:5], sample_labels[:, 0].astype(
+                'int32'), sample_labels[:, 5].astype('int32')
         elif mode == 'infer':
             yield img.astype('float32')
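Note on the reader change (illustration only, not part of the patch): the train reader now yields a 4-tuple whose order must match feed_list=[image, gt_box, gt_label, difficult] in mobilenet_ssd.py. A minimal numpy sketch, assuming each sample_labels row is laid out as [label, xmin, ymin, xmax, ymax, difficult]:

    import numpy as np

    # one annotated object: class 11, a normalized corner-format box, difficult=0
    sample_labels = np.array([[11., 0.1, 0.2, 0.8, 0.9, 0.]])
    gt_box = sample_labels[:, 1:5]                   # (1, 4) float boxes
    gt_label = sample_labels[:, 0].astype('int32')   # class indices
    difficult = sample_labels[:, 5].astype('int32')  # VOC 'difficult' flags
    assert gt_box.shape == (1, 4)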