From 946094e3124bed2b098894d1f9433c3fcb7d1cc4 Mon Sep 17 00:00:00 2001 From: wubinghong Date: Fri, 18 Sep 2020 14:30:11 +0800 Subject: [PATCH] Add drop-connected in efficientnet & refine the bifpn --- configs/efficientdet_d0.yml | 2 +- ppdet/modeling/backbones/bifpn.py | 127 ++++++++++------------- ppdet/modeling/backbones/efficientnet.py | 66 +++++------- 3 files changed, 85 insertions(+), 110 deletions(-) diff --git a/configs/efficientdet_d0.yml b/configs/efficientdet_d0.yml index a0fdbd709..b7c969e76 100644 --- a/configs/efficientdet_d0.yml +++ b/configs/efficientdet_d0.yml @@ -39,7 +39,7 @@ EfficientHead: output_decoder: score_thresh: 0.0 nms_thresh: 0.5 - pre_nms_top_n: 5000 + pre_nms_top_n: 1000 # originally 5000 detections_per_im: 100 nms_eta: 1.0 diff --git a/ppdet/modeling/backbones/bifpn.py b/ppdet/modeling/backbones/bifpn.py index d4f41b076..8d25fa304 100644 --- a/ppdet/modeling/backbones/bifpn.py +++ b/ppdet/modeling/backbones/bifpn.py @@ -41,8 +41,7 @@ class FusionConv(object): groups=self.num_chan, param_attr=ParamAttr( initializer=Xavier(), name=name + '_dw_w'), - bias_attr=False, - use_cudnn=False) + bias_attr=False) # pointwise x = fluid.layers.conv2d( x, @@ -68,53 +67,18 @@ class FusionConv(object): class BiFPNCell(object): def __init__(self, num_chan, levels=5, inputs_layer_num=3): - """ - # Node id starts from the input features and monotonically increase whenever - - # [Node NO.] Here is an example for level P3 - P7: - # {3: [0, 8], - # 4: [1, 7, 9], - # 5: [2, 6, 10], - # 6: [3, 5, 11], - # 7: [4, 12]} - - # [Related Edge] - # {'feat_level': 6, 'inputs_offsets': [3, 4]}, # for P6' - # {'feat_level': 5, 'inputs_offsets': [2, 5]}, # for P5' - # {'feat_level': 4, 'inputs_offsets': [1, 6]}, # for P4' - # {'feat_level': 3, 'inputs_offsets': [0, 7]}, # for P3" - # {'feat_level': 4, 'inputs_offsets': [1, 7, 8]}, # for P4" - # {'feat_level': 5, 'inputs_offsets': [2, 6, 9]}, # for P5" - # {'feat_level': 6, 'inputs_offsets': [3, 5, 10]}, # for P6" - # {'feat_level': 7, 'inputs_offsets': [4, 11]}, # for P7" - - P7 (4) --------------> P7" (12) - |----------| ↑ - ↓ | - P6 (3) --> P6' (5) --> P6" (11) - |----------|----------↑↑ - ↓ | - P5 (2) --> P5' (6) --> P5" (10) - |----------|----------↑↑ - ↓ | - P4 (1) --> P4' (7) --> P4" (9) - |----------|----------↑↑ - |----------↓| - P3 (0) --------------> P3" (8) - """ - super(BiFPNCell, self).__init__() self.levels = levels self.num_chan = num_chan + num_trigates = levels - 2 + num_bigates = levels self.inputs_layer_num = inputs_layer_num - # Learnable weights of [P4", P5", P6"] self.trigates = fluid.layers.create_parameter( - shape=[levels - 2, 3], + shape=[num_trigates, 3], dtype='float32', default_initializer=fluid.initializer.Constant(1.)) - # Learnable weights of [P6', P5', P4', P3", P7"] self.bigates = fluid.layers.create_parameter( - shape=[levels, 2], + shape=[num_bigates, 2], dtype='float32', default_initializer=fluid.initializer.Constant(1.)) self.eps = 1e-4 @@ -123,31 +87,38 @@ class BiFPNCell(object): assert len(inputs) == self.levels assert ((is_first_time) and (len(p4_2_p5_2) != 0)) or ((not is_first_time) and (len(p4_2_p5_2) == 0)) - # upsample operator def upsample(feat): return fluid.layers.resize_nearest(feat, scale=2.) - # downsample operator def downsample(feat): - return fluid.layers.pool2d(feat, pool_type='max', pool_size=3, pool_stride=2, pool_padding='SAME') + return fluid.layers.pool2d( + feat, + pool_type='max', + pool_size=3, + pool_stride=2, + pool_padding='SAME') - # 3x3 fuse conv after OP combine fuse_conv = FusionConv(self.num_chan) - # Normalize weight + # normalize weight trigates = fluid.layers.relu(self.trigates) bigates = fluid.layers.relu(self.bigates) - trigates /= fluid.layers.reduce_sum(trigates, dim=1, keep_dim=True) + self.eps - bigates /= fluid.layers.reduce_sum(bigates, dim=1, keep_dim=True) + self.eps + trigates /= fluid.layers.reduce_sum( + trigates, dim=1, keep_dim=True) + self.eps + bigates /= fluid.layers.reduce_sum( + bigates, dim=1, keep_dim=True) + self.eps - feature_maps = list(inputs) # make a copy, 依次是 [P3, P4, P5, P6, P7] - # top down path + feature_maps = list(inputs) # make a copy # top down path for l in range(self.levels - 1): p = self.levels - l - 2 - w1 = fluid.layers.slice(bigates, axes=[0, 1], starts=[l, 0], ends=[l + 1, 1]) - w2 = fluid.layers.slice(bigates, axes=[0, 1], starts=[l, 1], ends=[l + 1, 2]) - above_layer = upsample(feature_maps[p + 1]) - feature_maps[p] = fuse_conv(w1 * above_layer + w2 * inputs[p], name='{}_tb_{}'.format(cell_name, l)) + w1 = fluid.layers.slice( + bigates, axes=[0, 1], starts=[l, 0], ends=[l + 1, 1]) + w2 = fluid.layers.slice( + bigates, axes=[0, 1], starts=[l, 1], ends=[l + 1, 2]) + above = upsample(feature_maps[p + 1]) + feature_maps[p] = fuse_conv( + w1 * above + w2 * inputs[p], + name='{}_tb_{}'.format(cell_name, l)) # bottom up path for l in range(1, self.levels): p = l @@ -155,26 +126,40 @@ class BiFPNCell(object): below = downsample(feature_maps[p - 1]) if p == self.levels - 1: # handle P7 - w1 = fluid.layers.slice(bigates, axes=[0, 1], starts=[p, 0], ends=[p + 1, 1]) - w2 = fluid.layers.slice(bigates, axes=[0, 1], starts=[p, 1], ends=[p + 1, 2]) - feature_maps[p] = fuse_conv(w1 * below + w2 * inputs[p], name=name) + w1 = fluid.layers.slice( + bigates, axes=[0, 1], starts=[p, 0], ends=[p + 1, 1]) + w2 = fluid.layers.slice( + bigates, axes=[0, 1], starts=[p, 1], ends=[p + 1, 2]) + feature_maps[p] = fuse_conv( + w1 * below + w2 * inputs[p], name=name) else: if is_first_time: if p < self.inputs_layer_num: - w1 = fluid.layers.slice(trigates, axes=[0, 1], starts=[p - 1, 0], ends=[p, 1]) - w2 = fluid.layers.slice(trigates, axes=[0, 1], starts=[p - 1, 1], ends=[p, 2]) + w1 = fluid.layers.slice( + trigates, axes=[0, 1], starts=[p - 1, 0], ends=[p, 1]) + w2 = fluid.layers.slice( + trigates, axes=[0, 1], starts=[p - 1, 1], ends=[p, 2]) w3 = fluid.layers.slice(trigates, axes=[0, 1], starts=[p - 1, 2], ends=[p, 3]) - feature_maps[p] = fuse_conv(w1 * feature_maps[p] + w2 * below + w3 * p4_2_p5_2[p - 1], name=name) + feature_maps[p] = fuse_conv( + w1 * feature_maps[p] + w2 * below + w3 * p4_2_p5_2[p - 1], name=name) else: # For P6" - w1 = fluid.layers.slice(trigates, axes=[0, 1], starts=[p - 1, 0], ends=[p, 1]) - w2 = fluid.layers.slice(trigates, axes=[0, 1], starts=[p - 1, 1], ends=[p, 2]) - w3 = fluid.layers.slice(trigates, axes=[0, 1], starts=[p - 1, 2], ends=[p, 3]) - feature_maps[p] = fuse_conv(w1 * feature_maps[p] + w2 * below + w3 * inputs[p], name=name) + w1 = fluid.layers.slice( + trigates, axes=[0, 1], starts=[p - 1, 0], ends=[p, 1]) + w2 = fluid.layers.slice( + trigates, axes=[0, 1], starts=[p - 1, 1], ends=[p, 2]) + w3 = fluid.layers.slice( + trigates, axes=[0, 1], starts=[p - 1, 2], ends=[p, 3]) + feature_maps[p] = fuse_conv( + w1 * feature_maps[p] + w2 * below + w3 * inputs[p], name=name) else: - w1 = fluid.layers.slice(trigates, axes=[0, 1], starts=[p - 1, 0], ends=[p, 1]) - w2 = fluid.layers.slice(trigates, axes=[0, 1], starts=[p - 1, 1], ends=[p, 2]) - w3 = fluid.layers.slice(trigates, axes=[0, 1], starts=[p - 1, 2], ends=[p, 3]) - feature_maps[p] = fuse_conv(w1 * feature_maps[p] + w2 * below + w3 * inputs[p], name=name) + w1 = fluid.layers.slice( + trigates, axes=[0, 1], starts=[p - 1, 0], ends=[p, 1]) + w2 = fluid.layers.slice( + trigates, axes=[0, 1], starts=[p - 1, 1], ends=[p, 2]) + w3 = fluid.layers.slice( + trigates, axes=[0, 1], starts=[p - 1, 2], ends=[p, 3]) + feature_maps[p] = fuse_conv( + w1 * feature_maps[p] + w2 * below + w3 * inputs[p], name=name) return feature_maps @@ -197,7 +182,7 @@ class BiFPN(object): def __call__(self, inputs): feats = [] - # Squeeze the channel with 1x1 conv + # NOTE add two extra levels for idx in range(len(inputs)): if inputs[idx].shape[1] != self.num_chan: feat = fluid.layers.conv2d( @@ -212,7 +197,8 @@ class BiFPN(object): feat, momentum=0.997, epsilon=1e-04, - param_attr=ParamAttr(initializer=Constant(1.0), regularizer=L2Decay(0.)), + param_attr=ParamAttr( + initializer=Constant(1.0), regularizer=L2Decay(0.)), bias_attr=ParamAttr(regularizer=L2Decay(0.)), name='resample_bn_{}'.format(idx)) else: @@ -266,7 +252,6 @@ class BiFPN(object): name='resample2_bn_{}'.format(idx)) p4_2_p5_2.append(feat) - # BiFPN, repeated biFPN = BiFPNCell(self.num_chan, self.levels, len(inputs)) for r in range(self.repeat): if r == 0: diff --git a/ppdet/modeling/backbones/efficientnet.py b/ppdet/modeling/backbones/efficientnet.py index be6e5b230..adb606c3e 100644 --- a/ppdet/modeling/backbones/efficientnet.py +++ b/ppdet/modeling/backbones/efficientnet.py @@ -54,8 +54,8 @@ def _decode_block_string(block_string): key, value = splits[:2] options[key] = value - if 's' not in options or len(options['s']) != 2: - raise ValueError('Strides options should be a pair of integers.') + assert (('s' in options and len(options['s']) == 1) or + (len(options['s']) == 2 and options['s'][0] == options['s'][1])) return BlockArgs( kernel_size=int(options['k']), @@ -118,23 +118,20 @@ def get_model_params(scale): def round_filters(filters, global_params, skip=False): - """Round number of filters based on depth multiplier.""" multiplier = global_params.width_coefficient - divisor = global_params.depth_divisor - min_depth = global_params.min_depth if skip or not multiplier: return filters - + divisor = global_params.depth_divisor filters *= multiplier - min_depth = min_depth or divisor - new_filters = max(min_depth, int(filters + divisor / 2) // divisor * divisor) + min_depth = global_params.min_depth or divisor + new_filters = max(min_depth, + int(filters + divisor / 2) // divisor * divisor) if new_filters < 0.9 * filters: # prevent rounding by more than 10% new_filters += divisor return int(new_filters) def round_repeats(repeats, global_params, skip=False): - """Round number of filters based on depth multiplier.""" multiplier = global_params.depth_coefficient if skip or not multiplier: return repeats @@ -148,8 +145,7 @@ def conv2d(inputs, padding='SAME', groups=1, use_bias=False, - name='conv2d', - use_cudnn=True): + name='conv2d'): param_attr = fluid.ParamAttr(name=name + '_weights') bias_attr = False if use_bias: @@ -164,8 +160,7 @@ def conv2d(inputs, stride=stride, padding=padding, param_attr=param_attr, - bias_attr=bias_attr, - use_cudnn=use_cudnn) + bias_attr=bias_attr) return feats @@ -193,45 +188,42 @@ def _drop_connect(inputs, prob, mode): output = inputs / keep_prob * binary_tensor return output + def mb_conv_block(inputs, input_filters, output_filters, expand_ratio, kernel_size, stride, + id_skip, + drop_connect_rate, momentum, eps, - block_arg, - drop_connect_rate, mode, se_ratio=None, name=None): feats = inputs num_filters = input_filters * expand_ratio - # Expansion if expand_ratio != 1: feats = conv2d(feats, num_filters, 1, name=name + '_expand_conv') feats = batch_norm(feats, momentum, eps, name=name + '_bn0') feats = fluid.layers.swish(feats) - # Depthwise Convolution feats = conv2d( feats, num_filters, kernel_size, stride, groups=num_filters, - name=name + '_depthwise_conv', - use_cudnn=False) + name=name + '_depthwise_conv') feats = batch_norm(feats, momentum, eps, name=name + '_bn1') feats = fluid.layers.swish(feats) - # Squeeze and Excitation if se_ratio is not None: filter_squeezed = max(1, int(input_filters * se_ratio)) squeezed = fluid.layers.pool2d( - feats, pool_type='avg', global_pooling=True, use_cudnn=True) + feats, pool_type='avg', global_pooling=True) squeezed = conv2d( squeezed, filter_squeezed, @@ -243,12 +235,10 @@ def mb_conv_block(inputs, squeezed, num_filters, 1, use_bias=True, name=name + '_se_expand') feats = feats * fluid.layers.sigmoid(squeezed) - # Project_conv_norm feats = conv2d(feats, output_filters, 1, name=name + '_project_conv') feats = batch_norm(feats, momentum, eps, name=name + '_bn2') - # Skip connection and drop connect - if block_arg.id_skip and block_arg.stride == 1 and input_filters == output_filters: + if id_skip and stride == 1 and input_filters == output_filters: if drop_connect_rate: feats = _drop_connect(feats, drop_connect_rate, mode) feats = fluid.layers.elementwise_add(feats, inputs) @@ -268,10 +258,7 @@ class EfficientNet(object): """ __shared__ = ['norm_type'] - def __init__(self, - scale='b0', - use_se=True, - norm_type='bn'): + def __init__(self, scale='b0', use_se=True, norm_type='bn'): assert scale in ['b' + str(i) for i in range(8)], \ "valid scales are b0 - b7" assert norm_type in ['bn', 'sync_bn'], \ @@ -285,21 +272,23 @@ class EfficientNet(object): def __call__(self, inputs, mode): assert mode in ['train', 'test'], \ "only 'train' and 'test' mode are supported" - blocks_args, global_params = get_model_params(self.scale) momentum = global_params.batch_norm_momentum eps = global_params.batch_norm_epsilon - # Stem part. num_filters = round_filters(blocks_args[0].input_filters, global_params, global_params.fix_head_stem) - feats = conv2d(inputs, num_filters=num_filters, filter_size=3, stride=2, name='_conv_stem') + feats = conv2d( + inputs, + num_filters=num_filters, + filter_size=3, + stride=2, + name='_conv_stem') feats = batch_norm(feats, momentum=momentum, eps=eps, name='_bn0') feats = fluid.layers.swish(feats) - # Builds blocks. - feature_maps = [] layer_count = 0 num_blocks = sum([block_arg.num_repeat for block_arg in blocks_args]) + feature_maps = [] for block_arg in blocks_args: # Update block input and output filters based on depth multiplier. @@ -323,10 +312,10 @@ class EfficientNet(object): block_arg.expand_ratio, block_arg.kernel_size, block_arg.stride, + block_arg.id_skip, + drop_connect_rate, momentum, eps, - block_arg, - drop_connect_rate, mode, se_ratio=block_arg.se_ratio, name='_blocks.{}.'.format(layer_count)) @@ -347,15 +336,16 @@ class EfficientNet(object): block_arg.expand_ratio, block_arg.kernel_size, block_arg.stride, + block_arg.id_skip, + drop_connect_rate, momentum, eps, - block_arg, - drop_connect_rate, mode, se_ratio=block_arg.se_ratio, name='_blocks.{}.'.format(layer_count)) + layer_count += 1 feature_maps.append(feats) - return list(feature_maps[i] for i in [2, 4, 6]) # 1/8, 1/16, 1/32 + return list(feature_maps[i] for i in [2, 4, 6]) -- GitLab