diff --git a/ppdet/modeling/architectures/efficientdet.py b/ppdet/modeling/architectures/efficientdet.py index 2e374505d94d6f1d18c1d6b08c291d19fbfe594e..17561b687ef9efcae67eb09652fe9fccccc68d16 100644 --- a/ppdet/modeling/architectures/efficientdet.py +++ b/ppdet/modeling/architectures/efficientdet.py @@ -64,7 +64,7 @@ class EfficientDet(object): mixed_precision_enabled = mixed_precision_global_state() is not None if mixed_precision_enabled: im = fluid.layers.cast(im, 'float16') - body_feats = self.backbone(im, mode) + body_feats = self.backbone(im) if mixed_precision_enabled: body_feats = [fluid.layers.cast(f, 'float32') for f in body_feats] body_feats = self.fpn(body_feats) diff --git a/ppdet/modeling/backbones/bifpn.py b/ppdet/modeling/backbones/bifpn.py index 8d25fa3046387f876c317512408edc51978cd3f9..6f9f52415aa6d45fd5c541fecd699131eb0ded4c 100644 --- a/ppdet/modeling/backbones/bifpn.py +++ b/ppdet/modeling/backbones/bifpn.py @@ -83,9 +83,7 @@ class BiFPNCell(object): default_initializer=fluid.initializer.Constant(1.)) self.eps = 1e-4 - def __call__(self, inputs, cell_name='', is_first_time=False, p4_2_p5_2=[]): - assert len(inputs) == self.levels - assert ((is_first_time) and (len(p4_2_p5_2) != 0)) or ((not is_first_time) and (len(p4_2_p5_2) == 0)) + def __call__(self, inputs, cell_name=''): def upsample(feat): return fluid.layers.resize_nearest(feat, scale=2.) @@ -108,7 +106,8 @@ class BiFPNCell(object): bigates /= fluid.layers.reduce_sum( bigates, dim=1, keep_dim=True) + self.eps - feature_maps = list(inputs) # make a copy # top down path + # top down path + feature_maps = list(inputs[:self.levels]) # make a copy for l in range(self.levels - 1): p = self.levels - l - 2 w1 = fluid.layers.slice( @@ -133,7 +132,8 @@ class BiFPNCell(object): feature_maps[p] = fuse_conv( w1 * below + w2 * inputs[p], name=name) else: - if is_first_time: + # For the first loop in BiFPN + if len(inputs) != self.levels: if p < self.inputs_layer_num: w1 = fluid.layers.slice( trigates, axes=[0, 1], starts=[p - 1, 0], ends=[p, 1]) @@ -141,7 +141,7 @@ class BiFPNCell(object): trigates, axes=[0, 1], starts=[p - 1, 1], ends=[p, 2]) w3 = fluid.layers.slice(trigates, axes=[0, 1], starts=[p - 1, 2], ends=[p, 3]) feature_maps[p] = fuse_conv( - w1 * feature_maps[p] + w2 * below + w3 * p4_2_p5_2[p - 1], name=name) + w1 * feature_maps[p] + w2 * below + w3 * inputs[p - 1 + self.levels], name=name) else: # For P6" w1 = fluid.layers.slice( trigates, axes=[0, 1], starts=[p - 1, 0], ends=[p, 1]) @@ -233,7 +233,6 @@ class BiFPN(object): name='resample_downsample_{}'.format(idx)) feats.append(feat) # Handle the p4_2 and p5_2 with another 1x1 conv & bn layer - p4_2_p5_2 = [] for idx in range(1, len(inputs)): feat = fluid.layers.conv2d( inputs[idx], @@ -250,13 +249,10 @@ class BiFPN(object): param_attr=ParamAttr(initializer=Constant(1.0), regularizer=L2Decay(0.)), bias_attr=ParamAttr(regularizer=L2Decay(0.)), name='resample2_bn_{}'.format(idx)) - p4_2_p5_2.append(feat) + feats.append(feat) biFPN = BiFPNCell(self.num_chan, self.levels, len(inputs)) for r in range(self.repeat): - if r == 0: - feats = biFPN(feats, cell_name='bifpn_{}'.format(r), is_first_time=True, p4_2_p5_2=p4_2_p5_2) - else: feats = biFPN(feats, cell_name='bifpn_{}'.format(r)) return feats diff --git a/ppdet/modeling/backbones/efficientnet.py b/ppdet/modeling/backbones/efficientnet.py index adb606c3e3850b2ff69a522f6a4c4b0abb078ad6..c70db3649b9855fb95a08bfc3d4d265358a61541 100644 --- a/ppdet/modeling/backbones/efficientnet.py +++ b/ppdet/modeling/backbones/efficientnet.py @@ -28,15 +28,12 @@ __all__ = ['EfficientNet'] GlobalParams = collections.namedtuple('GlobalParams', [ 'batch_norm_momentum', 'batch_norm_epsilon', 'width_coefficient', - 'depth_coefficient', 'depth_divisor', 'min_depth', 'drop_connect_rate', - 'relu_fn', 'batch_norm', 'use_se', 'local_pooling', 'condconv_num_experts', - 'clip_projection_output', 'blocks_args', 'fix_head_stem' + 'depth_coefficient', 'depth_divisor' ]) BlockArgs = collections.namedtuple('BlockArgs', [ 'kernel_size', 'num_repeat', 'input_filters', 'output_filters', - 'expand_ratio', 'id_skip', 'stride', 'se_ratio', 'conv_type', 'fused_conv', - 'super_pixel', 'condconv' + 'expand_ratio', 'stride', 'se_ratio' ]) GlobalParams.__new__.__defaults__ = (None, ) * len(GlobalParams._fields) @@ -63,13 +60,8 @@ def _decode_block_string(block_string): input_filters=int(options['i']), output_filters=int(options['o']), expand_ratio=int(options['e']), - id_skip=('noskip' not in block_string), se_ratio=float(options['se']) if 'se' in options else None, - stride=int(options['s'][0]), - conv_type=int(options['c']) if 'c' in options else 0, - fused_conv=int(options['f']) if 'f' in options else 0, - super_pixel=int(options['p']) if 'p' in options else 0, - condconv=('cc' in block_string)) + stride=int(options['s'][0])) def get_model_params(scale): @@ -96,34 +88,27 @@ def get_model_params(scale): 'b5': (1.6, 2.2), 'b6': (1.8, 2.6), 'b7': (2.0, 3.1), - 'l2': (4.3, 5.3), } w, d = params_dict[scale] global_params = GlobalParams( - blocks_args=block_strings, batch_norm_momentum=0.99, batch_norm_epsilon=1e-3, - drop_connect_rate=0 if scale == 'b0' else 0.2, width_coefficient=w, depth_coefficient=d, - depth_divisor=8, - min_depth=None, - fix_head_stem=False, - use_se=True, - clip_projection_output=False) + depth_divisor=8) return block_args, global_params -def round_filters(filters, global_params, skip=False): +def round_filters(filters, global_params): multiplier = global_params.width_coefficient - if skip or not multiplier: + if not multiplier: return filters divisor = global_params.depth_divisor filters *= multiplier - min_depth = global_params.min_depth or divisor + min_depth = divisor new_filters = max(min_depth, int(filters + divisor / 2) // divisor * divisor) if new_filters < 0.9 * filters: # prevent rounding by more than 10% @@ -131,9 +116,9 @@ def round_filters(filters, global_params, skip=False): return int(new_filters) -def round_repeats(repeats, global_params, skip=False): +def round_repeats(repeats, global_params): multiplier = global_params.depth_coefficient - if skip or not multiplier: + if not multiplier: return repeats return int(math.ceil(multiplier * repeats)) @@ -178,28 +163,14 @@ def batch_norm(inputs, momentum, eps, name=None): bias_attr=bias_attr) -def _drop_connect(inputs, prob, mode): - if mode != 'train': - return inputs - keep_prob = 1.0 - prob - inputs_shape = fluid.layers.shape(inputs) - random_tensor = keep_prob + fluid.layers.uniform_random(shape=[inputs_shape[0], 1, 1, 1], min=0., max=1.) - binary_tensor = fluid.layers.floor(random_tensor) - output = inputs / keep_prob * binary_tensor - return output - - def mb_conv_block(inputs, input_filters, output_filters, expand_ratio, kernel_size, stride, - id_skip, - drop_connect_rate, momentum, eps, - mode, se_ratio=None, name=None): feats = inputs @@ -238,9 +209,7 @@ def mb_conv_block(inputs, feats = conv2d(feats, output_filters, 1, name=name + '_project_conv') feats = batch_norm(feats, momentum, eps, name=name + '_bn2') - if id_skip and stride == 1 and input_filters == output_filters: - if drop_connect_rate: - feats = _drop_connect(feats, drop_connect_rate, mode) + if stride == 1 and input_filters == output_filters: feats = fluid.layers.elementwise_add(feats, inputs) return feats @@ -269,14 +238,12 @@ class EfficientNet(object): self.scale = scale self.use_se = use_se - def __call__(self, inputs, mode): - assert mode in ['train', 'test'], \ - "only 'train' and 'test' mode are supported" + def __call__(self, inputs): blocks_args, global_params = get_model_params(self.scale) momentum = global_params.batch_norm_momentum eps = global_params.batch_norm_epsilon - num_filters = round_filters(blocks_args[0].input_filters, global_params, global_params.fix_head_stem) + num_filters = round_filters(32, global_params) feats = conv2d( inputs, num_filters=num_filters, @@ -287,61 +254,34 @@ class EfficientNet(object): feats = fluid.layers.swish(feats) layer_count = 0 - num_blocks = sum([block_arg.num_repeat for block_arg in blocks_args]) feature_maps = [] - for block_arg in blocks_args: - # Update block input and output filters based on depth multiplier. - block_arg = block_arg._replace( - input_filters=round_filters(block_arg.input_filters, - global_params), - output_filters=round_filters(block_arg.output_filters, - global_params), - num_repeat=round_repeats(block_arg.num_repeat, - global_params)) - - # The first block needs to take care of stride, - # and filter size increase. - drop_connect_rate = global_params.drop_connect_rate - if drop_connect_rate: - drop_connect_rate *= float(layer_count) / num_blocks - feats = mb_conv_block( - feats, - block_arg.input_filters, - block_arg.output_filters, - block_arg.expand_ratio, - block_arg.kernel_size, - block_arg.stride, - block_arg.id_skip, - drop_connect_rate, - momentum, - eps, - mode, - se_ratio=block_arg.se_ratio, - name='_blocks.{}.'.format(layer_count)) - layer_count += 1 - - # Other block - if block_arg.num_repeat > 1: - block_arg = block_arg._replace(input_filters=block_arg.output_filters, stride=1) - - for _ in range(block_arg.num_repeat - 1): - drop_connect_rate = global_params.drop_connect_rate - if drop_connect_rate: - drop_connect_rate *= float(layer_count) / num_blocks + for b, block_arg in enumerate(blocks_args): + for r in range(block_arg.num_repeat): + input_filters = round_filters(block_arg.input_filters, + global_params) + output_filters = round_filters(block_arg.output_filters, + global_params) + kernel_size = block_arg.kernel_size + stride = block_arg.stride + se_ratio = None + if self.use_se: + se_ratio = block_arg.se_ratio + + if r > 0: + input_filters = output_filters + stride = 1 + feats = mb_conv_block( feats, - block_arg.input_filters, - block_arg.output_filters, + input_filters, + output_filters, block_arg.expand_ratio, - block_arg.kernel_size, - block_arg.stride, - block_arg.id_skip, - drop_connect_rate, + kernel_size, + stride, momentum, eps, - mode, - se_ratio=block_arg.se_ratio, + se_ratio=se_ratio, name='_blocks.{}.'.format(layer_count)) layer_count += 1