diff --git a/configs/efficientdet_d0.yml b/configs/efficientdet_d0.yml
index 30f8ca117edc8dbf02a750592ee534d1c60eb73c..a0fdbd70936a7ed51fa8b523fda8422d2a3778c5 100644
--- a/configs/efficientdet_d0.yml
+++ b/configs/efficientdet_d0.yml
@@ -19,9 +19,7 @@ EfficientDet:
   box_loss_weight: 50.
 
 EfficientNet:
-  # norm_type: sync_bn
-  # TODO
-  norm_type: bn
+  norm_type: sync_bn
   scale: b0
   use_se: true
 
@@ -39,9 +37,9 @@ EfficientHead:
     alpha: 0.25
     delta: 0.1
   output_decoder:
-    score_thresh: 0.05  # originally 0.
+    score_thresh: 0.0
     nms_thresh: 0.5
-    pre_nms_top_n: 1000  # originally 5000
+    pre_nms_top_n: 5000
     detections_per_im: 100
     nms_eta: 1.0
diff --git a/ppdet/modeling/architectures/efficientdet.py b/ppdet/modeling/architectures/efficientdet.py
index 17561b687ef9efcae67eb09652fe9fccccc68d16..2e374505d94d6f1d18c1d6b08c291d19fbfe594e 100644
--- a/ppdet/modeling/architectures/efficientdet.py
+++ b/ppdet/modeling/architectures/efficientdet.py
@@ -64,7 +64,7 @@ class EfficientDet(object):
         mixed_precision_enabled = mixed_precision_global_state() is not None
         if mixed_precision_enabled:
            im = fluid.layers.cast(im, 'float16')
-        body_feats = self.backbone(im)
+        body_feats = self.backbone(im, mode)
         if mixed_precision_enabled:
             body_feats = [fluid.layers.cast(f, 'float32') for f in body_feats]
         body_feats = self.fpn(body_feats)
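The two decoder changes restore the reference EfficientDet eval settings: with score_thresh 0.0 no box is discarded by score before NMS, so pre_nms_top_n becomes the only pre-NMS cap, and 5000 (instead of 1000) keeps far more low-score candidates alive, which mainly affects recall. A small numpy sketch of the candidate budget this implies for a 512x512 D0 input (illustrative arithmetic only, not code from this repo):

    import numpy as np

    # P3-P7 with 9 anchors per cell:
    # (64**2 + 32**2 + 16**2 + 8**2 + 4**2) * 9 = 49104 candidate boxes
    num_boxes = sum((512 // 2 ** lvl) ** 2 for lvl in range(3, 8)) * 9
    scores = np.random.rand(num_boxes)

    kept = scores[scores > 0.0]          # score_thresh 0.0 keeps everything
    top_k = np.sort(kept)[::-1][:5000]   # pre_nms_top_n is the only cap
    print(num_boxes, top_k.shape)        # 49104 (5000,)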
diff --git a/ppdet/modeling/backbones/bifpn.py b/ppdet/modeling/backbones/bifpn.py
index d65517ceaba06b61b60cb8bd44ff369640739f51..d4f41b07686273a9015a5634a66362374aaf8c7b 100644
--- a/ppdet/modeling/backbones/bifpn.py
+++ b/ppdet/modeling/backbones/bifpn.py
@@ -41,7 +41,8 @@ class FusionConv(object):
             groups=self.num_chan,
             param_attr=ParamAttr(
                 initializer=Xavier(), name=name + '_dw_w'),
-            bias_attr=False)
+            bias_attr=False,
+            use_cudnn=False)  # native kernel is preferred for depthwise conv
         # pointwise
         x = fluid.layers.conv2d(
             x,
@@ -66,58 +67,87 @@ class FusionConv(object):
 
 
 class BiFPNCell(object):
-    def __init__(self, num_chan, levels=5):
+    def __init__(self, num_chan, levels=5, inputs_layer_num=3):
+        """
+        # Node ids start from the input features and increase monotonically
+        # whenever a new node is added.
+
+        # [Node NO.] Here is an example for levels P3 - P7:
+        # {3: [0, 8],
+        #  4: [1, 7, 9],
+        #  5: [2, 6, 10],
+        #  6: [3, 5, 11],
+        #  7: [4, 12]}
+
+        # [Related Edge]
+        # {'feat_level': 6, 'inputs_offsets': [3, 4]},      # for P6'
+        # {'feat_level': 5, 'inputs_offsets': [2, 5]},      # for P5'
+        # {'feat_level': 4, 'inputs_offsets': [1, 6]},      # for P4'
+        # {'feat_level': 3, 'inputs_offsets': [0, 7]},      # for P3"
+        # {'feat_level': 4, 'inputs_offsets': [1, 7, 8]},   # for P4"
+        # {'feat_level': 5, 'inputs_offsets': [2, 6, 9]},   # for P5"
+        # {'feat_level': 6, 'inputs_offsets': [3, 5, 10]},  # for P6"
+        # {'feat_level': 7, 'inputs_offsets': [4, 11]},     # for P7"
+
+        P7 (4) --------------> P7" (12)
+          |----------|           ↑
+                     ↓           |
+        P6 (3) --> P6' (5) --> P6" (11)
+          |----------|----------↑↑
+                     ↓           |
+        P5 (2) --> P5' (6) --> P5" (10)
+          |----------|----------↑↑
+                     ↓           |
+        P4 (1) --> P4' (7) --> P4" (9)
+          |----------|----------↑↑
+          |----------↓           |
+        P3 (0) --------------> P3" (8)
+        """
+        super(BiFPNCell, self).__init__()
         self.levels = levels
         self.num_chan = num_chan
-        num_trigates = levels - 2
-        num_bigates = levels
+        self.inputs_layer_num = inputs_layer_num
+        # learnable fusion weights of [P4", P5", P6"] (three inputs each)
         self.trigates = fluid.layers.create_parameter(
-            shape=[num_trigates, 3],
+            shape=[levels - 2, 3],
             dtype='float32',
             default_initializer=fluid.initializer.Constant(1.))
+        # learnable fusion weights of [P6', P5', P4', P3", P7"] (two inputs each)
         self.bigates = fluid.layers.create_parameter(
-            shape=[num_bigates, 2],
+            shape=[levels, 2],
             dtype='float32',
             default_initializer=fluid.initializer.Constant(1.))
         self.eps = 1e-4
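The node numbering in the docstring can be generated mechanically rather than maintained by hand. Below is a self-contained sketch of that bookkeeping, adapted from the google/automl reference this file follows (the helper name bifpn_config is ours, not part of this diff):

    import itertools

    def bifpn_config(min_level, max_level):
        # reproduces the [Node NO.] and [Related Edge] tables above
        num_inputs = max_level - min_level + 1
        node_ids = {min_level + i: [i] for i in range(num_inputs)}
        next_id = itertools.count(num_inputs)
        edges = []
        for lvl in range(max_level - 1, min_level - 1, -1):  # top-down: P6'..P3"
            edges.append({'feat_level': lvl,
                          'inputs_offsets': [node_ids[lvl][-1],
                                             node_ids[lvl + 1][-1]]})
            node_ids[lvl].append(next(next_id))
        for lvl in range(min_level + 1, max_level + 1):      # bottom-up: P4"..P7"
            edges.append({'feat_level': lvl,
                          'inputs_offsets': list(node_ids[lvl]) +
                                            [node_ids[lvl - 1][-1]]})
            node_ids[lvl].append(next(next_id))
        return node_ids, edges

    node_ids, edges = bifpn_config(3, 7)
    print(node_ids)
    # {3: [0, 8], 4: [1, 7, 9], 5: [2, 6, 10], 6: [3, 5, 11], 7: [4, 12]}

Running it reproduces exactly the tables in the docstring, which is a cheap way to sanity-check the hand-written wiring below.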
 
-    def __call__(self, inputs, cell_name=''):
+    def __call__(self, inputs, cell_name='', is_first_time=False, p4_2_p5_2=[]):
         assert len(inputs) == self.levels
+        # the extra resampled P4/P5 features feed the first cell only
+        assert is_first_time == (len(p4_2_p5_2) != 0)
 
+        # upsample operator
         def upsample(feat):
             return fluid.layers.resize_nearest(feat, scale=2.)
 
+        # downsample operator
         def downsample(feat):
             return fluid.layers.pool2d(
                 feat,
                 pool_type='max',
                 pool_size=3,
                 pool_stride=2,
                 pool_padding='SAME')
 
+        # 3x3 fuse conv applied after every weighted combination
         fuse_conv = FusionConv(self.num_chan)
 
         # normalize weight
         trigates = fluid.layers.relu(self.trigates)
         bigates = fluid.layers.relu(self.bigates)
         trigates /= fluid.layers.reduce_sum(
             trigates, dim=1, keep_dim=True) + self.eps
         bigates /= fluid.layers.reduce_sum(
             bigates, dim=1, keep_dim=True) + self.eps
 
-        feature_maps = list(inputs)  # make a copy
+        feature_maps = list(inputs)  # make a copy, ordered [P3, P4, P5, P6, P7]
         # top down path
         for l in range(self.levels - 1):
             p = self.levels - l - 2
             w1 = fluid.layers.slice(
                 bigates, axes=[0, 1], starts=[l, 0], ends=[l + 1, 1])
             w2 = fluid.layers.slice(
                 bigates, axes=[0, 1], starts=[l, 1], ends=[l + 1, 2])
-            above = upsample(feature_maps[p + 1])
+            above_layer = upsample(feature_maps[p + 1])
             feature_maps[p] = fuse_conv(
-                w1 * above + w2 * inputs[p],
+                w1 * above_layer + w2 * inputs[p],
                 name='{}_tb_{}'.format(cell_name, l))
 
         # bottom up path
         for l in range(1, self.levels):
             p = l
@@ -125,22 +155,26 @@ class BiFPNCell(object):
             below = downsample(feature_maps[p - 1])
             if p == self.levels - 1:
                 # handle P7
                 w1 = fluid.layers.slice(
                     bigates, axes=[0, 1], starts=[p, 0], ends=[p + 1, 1])
                 w2 = fluid.layers.slice(
                     bigates, axes=[0, 1], starts=[p, 1], ends=[p + 1, 2])
                 feature_maps[p] = fuse_conv(
                     w1 * below + w2 * inputs[p], name=name)
             else:
                 w1 = fluid.layers.slice(
                     trigates, axes=[0, 1], starts=[p - 1, 0], ends=[p, 1])
                 w2 = fluid.layers.slice(
                     trigates, axes=[0, 1], starts=[p - 1, 1], ends=[p, 2])
                 w3 = fluid.layers.slice(
                     trigates, axes=[0, 1], starts=[p - 1, 2], ends=[p, 3])
-                feature_maps[p] = fuse_conv(
-                    w1 * feature_maps[p] + w2 * below + w3 * inputs[p],
-                    name=name)
+                # In the first cell, P4" and P5" take their skip input from
+                # the separately resampled backbone features (p4_2_p5_2);
+                # P6" and every later cell fuse the cell inputs directly.
+                if is_first_time and p < self.inputs_layer_num:
+                    skip = p4_2_p5_2[p - 1]
+                else:
+                    skip = inputs[p]
+                feature_maps[p] = fuse_conv(
+                    w1 * feature_maps[p] + w2 * below + w3 * skip,
+                    name=name)
         return feature_maps
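The weight handling above is the paper's "fast normalized fusion": each weight goes through ReLU and is divided by the sum plus a small eps, instead of a softmax. A numpy sketch of what one three-input fusion computes (illustrative only, not code from this repo):

    import numpy as np

    def fast_fusion(weights, feats, eps=1e-4):
        w = np.maximum(weights, 0.)   # relu keeps the weights non-negative
        w = w / (w.sum() + eps)       # normalize; eps guards against all-zero
        return sum(wi * f for wi, f in zip(w, feats))

    feats = [np.full((1, 4, 4), v) for v in (1., 2., 3.)]
    out = fast_fusion(np.array([1., 1., 1.]), feats)  # Constant(1.) init
    print(out[0, 0, 0])   # ~2.0, i.e. a plain average at initialization

With the Constant(1.) initializer used here, every cell therefore starts out as an unweighted average and learns the per-input weights during training.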
@@ -163,40 +197,81 @@ class BiFPN(object):
 
     def __call__(self, inputs):
         feats = []
-        # NOTE add two extra levels
-        for idx in range(self.levels):
-            if idx <= len(inputs):
-                if idx == len(inputs):
-                    feat = inputs[-1]
-                else:
-                    feat = inputs[idx]
-
-                if feat.shape[1] != self.num_chan:
-                    feat = fluid.layers.conv2d(
-                        feat,
-                        self.num_chan,
-                        filter_size=1,
-                        padding='SAME',
-                        param_attr=ParamAttr(initializer=Xavier()),
-                        bias_attr=ParamAttr(regularizer=L2Decay(0.)))
-                    feat = fluid.layers.batch_norm(
-                        feat,
-                        momentum=0.997,
-                        epsilon=1e-04,
-                        param_attr=ParamAttr(
-                            initializer=Constant(1.0), regularizer=L2Decay(0.)),
-                        bias_attr=ParamAttr(regularizer=L2Decay(0.)))
-
-            if idx >= len(inputs):
-                feat = fluid.layers.pool2d(
-                    feat,
-                    pool_type='max',
-                    pool_size=3,
-                    pool_stride=2,
-                    pool_padding='SAME')
-            feats.append(feat)
+        # squeeze the backbone channels with a 1x1 conv & bn where needed
+        for idx in range(len(inputs)):
+            if inputs[idx].shape[1] != self.num_chan:
+                feat = fluid.layers.conv2d(
+                    inputs[idx],
+                    self.num_chan,
+                    filter_size=1,
+                    padding='SAME',
+                    param_attr=ParamAttr(initializer=Xavier()),
+                    bias_attr=ParamAttr(regularizer=L2Decay(0.)),
+                    name='resample_conv_{}'.format(idx))
+                feat = fluid.layers.batch_norm(
+                    feat,
+                    momentum=0.997,
+                    epsilon=1e-04,
+                    param_attr=ParamAttr(
+                        initializer=Constant(1.0), regularizer=L2Decay(0.)),
+                    bias_attr=ParamAttr(regularizer=L2Decay(0.)),
+                    name='resample_bn_{}'.format(idx))
+            else:
+                feat = inputs[idx]
+            feats.append(feat)
+
+        # Build the extra input levels (P6, P7) that do not come from the
+        # backbone. The previous feature is already channel-squeezed, so the
+        # conv & bn here is only a safety net; in particular P7 is a bare
+        # max-pool of P6, matching
+        # https://github.com/google/automl/blob/master/efficientdet/keras/efficientdet_keras.py#L820
+        for idx in range(len(inputs), self.levels):
+            if feats[-1].shape[1] != self.num_chan:
+                feat = fluid.layers.conv2d(
+                    feats[-1],
+                    self.num_chan,
+                    filter_size=1,
+                    padding='SAME',
+                    param_attr=ParamAttr(initializer=Xavier()),
+                    bias_attr=ParamAttr(regularizer=L2Decay(0.)),
+                    name='resample_conv_{}'.format(idx))
+                feat = fluid.layers.batch_norm(
+                    feat,
+                    momentum=0.997,
+                    epsilon=1e-04,
+                    param_attr=ParamAttr(
+                        initializer=Constant(1.0), regularizer=L2Decay(0.)),
+                    bias_attr=ParamAttr(regularizer=L2Decay(0.)),
+                    name='resample_bn_{}'.format(idx))
+            else:
+                feat = feats[-1]
+            feat = fluid.layers.pool2d(
+                feat,
+                pool_type='max',
+                pool_size=3,
+                pool_stride=2,
+                pool_padding='SAME',
+                name='resample_downsample_{}'.format(idx))
+            feats.append(feat)
 
+        # resample P4 and P5 once more with a separate 1x1 conv & bn; these
+        # serve as the extra skip inputs of the first BiFPN cell
+        p4_2_p5_2 = []
+        for idx in range(1, len(inputs)):
+            feat = fluid.layers.conv2d(
+                inputs[idx],
+                self.num_chan,
+                filter_size=1,
+                padding='SAME',
+                param_attr=ParamAttr(initializer=Xavier()),
+                bias_attr=ParamAttr(regularizer=L2Decay(0.)),
+                name='resample2_conv_{}'.format(idx))
+            feat = fluid.layers.batch_norm(
+                feat,
+                momentum=0.997,
+                epsilon=1e-04,
+                param_attr=ParamAttr(
+                    initializer=Constant(1.0), regularizer=L2Decay(0.)),
+                bias_attr=ParamAttr(regularizer=L2Decay(0.)),
+                name='resample2_bn_{}'.format(idx))
+            p4_2_p5_2.append(feat)
 
-        biFPN = BiFPNCell(self.num_chan, self.levels)
+        # stacked BiFPN cells; only the first cell consumes p4_2_p5_2
+        biFPN = BiFPNCell(self.num_chan, self.levels, len(inputs))
         for r in range(self.repeat):
-            feats = biFPN(feats, 'bifpn_{}'.format(r))
+            if r == 0:
+                feats = biFPN(feats, cell_name='bifpn_{}'.format(r),
+                              is_first_time=True, p4_2_p5_2=p4_2_p5_2)
+            else:
+                feats = biFPN(feats, cell_name='bifpn_{}'.format(r))
         return feats
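Each extra level is produced from the previous one by the stride-2 max-pool, so spatial resolution halves per level. The resulting pyramid for a 512x512 input, assuming the usual C3-C5 backbone strides of 8/16/32 (an illustration, not values read from this diff):

    size = 512
    for level in range(3, 8):            # P3 .. P7
        side = size // 2 ** level
        print('P{}: {}x{}'.format(level, side, side))
    # P3: 64x64  P4: 32x32  P5: 16x16  P6: 8x8  P7: 4x4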
diff --git a/ppdet/modeling/backbones/efficientnet.py b/ppdet/modeling/backbones/efficientnet.py
index c70db3649b9855fb95a08bfc3d4d265358a61541..be6e5b230c5b07ea26c7c9cc688fcb7df788ba85 100644
--- a/ppdet/modeling/backbones/efficientnet.py
+++ b/ppdet/modeling/backbones/efficientnet.py
@@ -28,12 +28,15 @@ __all__ = ['EfficientNet']
 
 GlobalParams = collections.namedtuple('GlobalParams', [
     'batch_norm_momentum', 'batch_norm_epsilon', 'width_coefficient',
-    'depth_coefficient', 'depth_divisor'
+    'depth_coefficient', 'depth_divisor', 'min_depth', 'drop_connect_rate',
+    'relu_fn', 'batch_norm', 'use_se', 'local_pooling', 'condconv_num_experts',
+    'clip_projection_output', 'blocks_args', 'fix_head_stem'
 ])
 
 BlockArgs = collections.namedtuple('BlockArgs', [
     'kernel_size', 'num_repeat', 'input_filters', 'output_filters',
-    'expand_ratio', 'stride', 'se_ratio'
+    'expand_ratio', 'id_skip', 'stride', 'se_ratio', 'conv_type', 'fused_conv',
+    'super_pixel', 'condconv'
 ])
 
 GlobalParams.__new__.__defaults__ = (None, ) * len(GlobalParams._fields)
@@ -51,8 +54,8 @@ def _decode_block_string(block_string):
             key, value = splits[:2]
             options[key] = value
 
-    assert (('s' in options and len(options['s']) == 1) or
-            (len(options['s']) == 2 and options['s'][0] == options['s'][1]))
+    if 's' not in options or len(options['s']) != 2:
+        raise ValueError('Strides options should be a pair of integers.')
 
     return BlockArgs(
         kernel_size=int(options['k']),
@@ -60,8 +63,13 @@ def _decode_block_string(block_string):
         input_filters=int(options['i']),
         output_filters=int(options['o']),
         expand_ratio=int(options['e']),
+        id_skip=('noskip' not in block_string),
         se_ratio=float(options['se']) if 'se' in options else None,
-        stride=int(options['s'][0]))
+        stride=int(options['s'][0]),
+        conv_type=int(options['c']) if 'c' in options else 0,
+        fused_conv=int(options['f']) if 'f' in options else 0,
+        super_pixel=int(options['p']) if 'p' in options else 0,
+        condconv=('cc' in block_string))
 
 
 def get_model_params(scale):
@@ -88,37 +96,47 @@ def get_model_params(scale):
         'b5': (1.6, 2.2),
         'b6': (1.8, 2.6),
         'b7': (2.0, 3.1),
+        'l2': (4.3, 5.3),
     }
 
     w, d = params_dict[scale]
 
     global_params = GlobalParams(
+        blocks_args=block_strings,
         batch_norm_momentum=0.99,
         batch_norm_epsilon=1e-3,
+        drop_connect_rate=0 if scale == 'b0' else 0.2,
         width_coefficient=w,
         depth_coefficient=d,
-        depth_divisor=8)
+        depth_divisor=8,
+        min_depth=None,
+        fix_head_stem=False,
+        use_se=True,
+        clip_projection_output=False)
 
     return block_args, global_params
 
 
-def round_filters(filters, global_params):
+def round_filters(filters, global_params, skip=False):
+    """Round the number of filters based on the width multiplier."""
     multiplier = global_params.width_coefficient
-    if not multiplier:
-        return filters
     divisor = global_params.depth_divisor
+    min_depth = global_params.min_depth
+    if skip or not multiplier:
+        return filters
+
     filters *= multiplier
-    min_depth = divisor
+    min_depth = min_depth or divisor
     new_filters = max(min_depth,
                       int(filters + divisor / 2) // divisor * divisor)
     if new_filters < 0.9 * filters:  # prevent rounding by more than 10%
         new_filters += divisor
     return int(new_filters)
 
 
-def round_repeats(repeats, global_params):
+def round_repeats(repeats, global_params, skip=False):
+    """Round the number of repeats based on the depth multiplier."""
     multiplier = global_params.depth_coefficient
-    if not multiplier:
+    if skip or not multiplier:
         return repeats
     return int(math.ceil(multiplier * repeats))
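round_filters snaps the scaled channel count to a multiple of depth_divisor and refuses to round down by more than 10%, while round_repeats simply takes the ceiling. A standalone copy for a quick sanity check, using the b4 (1.4, 1.8) and b7 (2.0, 3.1) coefficients from params_dict above:

    import math

    def round_filters(filters, multiplier, divisor=8, min_depth=None):
        # standalone copy of the rounding rule above
        if not multiplier:
            return filters
        filters *= multiplier
        min_depth = min_depth or divisor
        new_filters = max(min_depth,
                          int(filters + divisor / 2) // divisor * divisor)
        if new_filters < 0.9 * filters:   # never round down more than 10%
            new_filters += divisor
        return int(new_filters)

    print(round_filters(32, 1.4))    # b4 stem: 32 -> 48
    print(round_filters(32, 2.0))    # b7 stem: 32 -> 64
    print(int(math.ceil(3 * 1.8)))   # b4 depth: a 3-repeat stage -> 6 blocks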
@@ -130,7 +148,8 @@ def conv2d(inputs,
            padding='SAME',
            groups=1,
            use_bias=False,
-           name='conv2d'):
+           name='conv2d',
+           use_cudnn=True):
     param_attr = fluid.ParamAttr(name=name + '_weights')
     bias_attr = False
     if use_bias:
@@ -145,7 +164,8 @@ def conv2d(inputs,
         stride=stride,
         padding=padding,
         param_attr=param_attr,
-        bias_attr=bias_attr)
+        bias_attr=bias_attr,
+        use_cudnn=use_cudnn)
     return feats
 
 
@@ -163,6 +183,16 @@ def batch_norm(inputs, momentum, eps, name=None):
         bias_attr=bias_attr)
 
 
+def _drop_connect(inputs, prob, mode):
+    """Drop the whole residual branch per sample (stochastic depth)."""
+    if mode != 'train':
+        return inputs
+    keep_prob = 1.0 - prob
+    inputs_shape = fluid.layers.shape(inputs)
+    # one Bernoulli(keep_prob) draw per sample, broadcast over C/H/W
+    random_tensor = keep_prob + fluid.layers.uniform_random(
+        shape=[inputs_shape[0], 1, 1, 1], min=0., max=1.)
+    binary_tensor = fluid.layers.floor(random_tensor)
+    # scale survivors by 1 / keep_prob so the expectation is unchanged
+    output = inputs / keep_prob * binary_tensor
+    return output
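_drop_connect zeroes a sample's entire residual branch with probability prob and scales survivors by 1 / keep_prob, so the expected activation is unchanged. A numpy sketch of the same math (illustrative, not the fluid code above):

    import numpy as np

    def drop_connect(x, prob, training, rng=np.random):
        if not training:
            return x
        keep_prob = 1.0 - prob
        # one Bernoulli(keep_prob) draw per sample, broadcast over C/H/W
        mask = np.floor(keep_prob + rng.uniform(size=(x.shape[0], 1, 1, 1)))
        return x / keep_prob * mask

    x = np.ones((1000, 1, 1, 1))
    y = drop_connect(x, prob=0.2, training=True)
    print(y.mean())   # ~1.0: the 1/keep_prob rescaling preserves the mean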
+
+
 def mb_conv_block(inputs,
                   input_filters,
                   output_filters,
@@ -171,30 +201,37 @@ def mb_conv_block(inputs,
                   stride,
                   momentum,
                   eps,
+                  block_arg,
+                  drop_connect_rate,
+                  mode,
                   se_ratio=None,
                   name=None):
     feats = inputs
     num_filters = input_filters * expand_ratio
 
+    # expansion
     if expand_ratio != 1:
         feats = conv2d(feats, num_filters, 1, name=name + '_expand_conv')
         feats = batch_norm(feats, momentum, eps, name=name + '_bn0')
         feats = fluid.layers.swish(feats)
 
+    # depthwise convolution
     feats = conv2d(
         feats,
         num_filters,
         kernel_size,
         stride,
         groups=num_filters,
-        name=name + '_depthwise_conv')
+        name=name + '_depthwise_conv',
+        use_cudnn=False)
     feats = batch_norm(feats, momentum, eps, name=name + '_bn1')
     feats = fluid.layers.swish(feats)
 
+    # squeeze and excitation
     if se_ratio is not None:
         filter_squeezed = max(1, int(input_filters * se_ratio))
         squeezed = fluid.layers.pool2d(
-            feats, pool_type='avg', global_pooling=True)
+            feats, pool_type='avg', global_pooling=True, use_cudnn=True)
         squeezed = conv2d(
             squeezed,
             filter_squeezed,
@@ -206,10 +243,14 @@ def mb_conv_block(inputs,
             squeezed, num_filters, 1, use_bias=True, name=name + '_se_expand')
         feats = feats * fluid.layers.sigmoid(squeezed)
 
+    # projection conv & bn
     feats = conv2d(feats, output_filters, 1, name=name + '_project_conv')
     feats = batch_norm(feats, momentum, eps, name=name + '_bn2')
 
-    if stride == 1 and input_filters == output_filters:
+    # skip connection and drop connect
+    if (block_arg.id_skip and block_arg.stride == 1 and
+            input_filters == output_filters):
+        if drop_connect_rate:
+            feats = _drop_connect(feats, drop_connect_rate, mode)
         feats = fluid.layers.elementwise_add(feats, inputs)
 
     return feats
@@ -227,7 +268,10 @@ class EfficientNet(object):
     """
     __shared__ = ['norm_type']
 
-    def __init__(self, scale='b0', use_se=True, norm_type='bn'):
+    def __init__(self,
+                 scale='b0',
+                 use_se=True,
+                 norm_type='bn'):
         assert scale in ['b' + str(i) for i in range(8)], \
             "valid scales are b0 - b7"
         assert norm_type in ['bn', 'sync_bn'], \
@@ -238,54 +282,80 @@ class EfficientNet(object):
         self.scale = scale
         self.use_se = use_se
 
-    def __call__(self, inputs):
+    def __call__(self, inputs, mode):
+        assert mode in ['train', 'test'], \
+            "only 'train' and 'test' modes are supported"
+
         blocks_args, global_params = get_model_params(self.scale)
         momentum = global_params.batch_norm_momentum
         eps = global_params.batch_norm_epsilon
 
-        num_filters = round_filters(32, global_params)
+        # stem
+        num_filters = round_filters(blocks_args[0].input_filters,
+                                    global_params, global_params.fix_head_stem)
         feats = conv2d(
             inputs,
             num_filters=num_filters,
             filter_size=3,
             stride=2,
             name='_conv_stem')
         feats = batch_norm(feats, momentum=momentum, eps=eps, name='_bn0')
         feats = fluid.layers.swish(feats)
 
-        layer_count = 0
+        # blocks
         feature_maps = []
-
-        for b, block_arg in enumerate(blocks_args):
-            for r in range(block_arg.num_repeat):
-                input_filters = round_filters(block_arg.input_filters,
-                                              global_params)
-                output_filters = round_filters(block_arg.output_filters,
-                                               global_params)
-                kernel_size = block_arg.kernel_size
-                stride = block_arg.stride
-                se_ratio = None
-                if self.use_se:
-                    se_ratio = block_arg.se_ratio
-
-                if r > 0:
-                    input_filters = output_filters
-                    stride = 1
-
+        layer_count = 0
+        # total block count *after* depth scaling, so the per-block drop
+        # connect rate below stays within [0, drop_connect_rate)
+        num_blocks = sum(
+            round_repeats(block_arg.num_repeat, global_params)
+            for block_arg in blocks_args)
+
+        for block_arg in blocks_args:
+            # update block input/output filters with the width multiplier and
+            # the repeat count with the depth multiplier
+            block_arg = block_arg._replace(
+                input_filters=round_filters(block_arg.input_filters,
+                                            global_params),
+                output_filters=round_filters(block_arg.output_filters,
+                                             global_params),
+                num_repeat=round_repeats(block_arg.num_repeat, global_params))
+
+            # the first block of each stage handles the stride and the filter
+            # size increase; drop connect rate grows linearly with depth
+            drop_connect_rate = global_params.drop_connect_rate
+            if drop_connect_rate:
+                drop_connect_rate *= float(layer_count) / num_blocks
+            feats = mb_conv_block(
+                feats,
+                block_arg.input_filters,
+                block_arg.output_filters,
+                block_arg.expand_ratio,
+                block_arg.kernel_size,
+                block_arg.stride,
+                momentum,
+                eps,
+                block_arg,
+                drop_connect_rate,
+                mode,
+                se_ratio=block_arg.se_ratio if self.use_se else None,
+                name='_blocks.{}.'.format(layer_count))
+            layer_count += 1
+
+            # the remaining blocks of the stage keep stride 1 and equal
+            # input/output filters
+            if block_arg.num_repeat > 1:
+                block_arg = block_arg._replace(
+                    input_filters=block_arg.output_filters, stride=1)
+
+            for _ in range(block_arg.num_repeat - 1):
+                drop_connect_rate = global_params.drop_connect_rate
+                if drop_connect_rate:
+                    drop_connect_rate *= float(layer_count) / num_blocks
                 feats = mb_conv_block(
                     feats,
-                    input_filters,
-                    output_filters,
+                    block_arg.input_filters,
+                    block_arg.output_filters,
                     block_arg.expand_ratio,
-                    kernel_size,
-                    stride,
+                    block_arg.kernel_size,
+                    block_arg.stride,
                     momentum,
                     eps,
-                    se_ratio=se_ratio,
+                    block_arg,
+                    drop_connect_rate,
+                    mode,
+                    se_ratio=block_arg.se_ratio if self.use_se else None,
                     name='_blocks.{}.'.format(layer_count))
                 layer_count += 1
 
             feature_maps.append(feats)
 
-        return list(feature_maps[i] for i in [2, 4, 6])
+        return list(feature_maps[i] for i in [2, 4, 6])  # strides 1/8, 1/16, 1/32
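The float(layer_count) / num_blocks factor turns drop connect into a linear schedule over depth: the first block is never dropped and the last is dropped at just under the base rate (0.2 for b1 and up; b0 disables it in this diff). Worked arithmetic for a 16-block network, the unscaled b0 layout, shown only to illustrate the schedule:

    base_rate = 0.2
    num_blocks = 16   # sum of the b0 stage repeats [1, 2, 2, 3, 3, 4, 1]
    rates = [base_rate * i / num_blocks for i in range(num_blocks)]
    print(rates[0], rates[8], rates[-1])   # 0.0 0.1 0.1875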