Commit 8064bab9 authored by wubinghong

Add drop connect in EfficientNet & refine the BiFPN

Parent ae38108a
...
@@ -19,9 +19,7 @@ EfficientDet:
   box_loss_weight: 50.
 
 EfficientNet:
-  # norm_type: sync_bn
-  # TODO
-  norm_type: bn
+  norm_type: sync_bn
   scale: b0
   use_se: true
...
@@ -39,9 +37,9 @@ EfficientHead:
   alpha: 0.25
   delta: 0.1
   output_decoder:
-    score_thresh: 0.05   # originally 0.
+    score_thresh: 0.0
     nms_thresh: 0.5
-    pre_nms_top_n: 1000  # originally 5000
+    pre_nms_top_n: 5000
     detections_per_im: 100
     nms_eta: 1.0
...
...
@@ -64,7 +64,7 @@ class EfficientDet(object):
         mixed_precision_enabled = mixed_precision_global_state() is not None
         if mixed_precision_enabled:
             im = fluid.layers.cast(im, 'float16')
-        body_feats = self.backbone(im)
+        body_feats = self.backbone(im, mode)
         if mixed_precision_enabled:
             body_feats = [fluid.layers.cast(f, 'float32') for f in body_feats]
         body_feats = self.fpn(body_feats)
...
...
@@ -41,7 +41,8 @@ class FusionConv(object):
             groups=self.num_chan,
             param_attr=ParamAttr(
                 initializer=Xavier(), name=name + '_dw_w'),
-            bias_attr=False)
+            bias_attr=False,
+            use_cudnn=False)
         # pointwise
         x = fluid.layers.conv2d(
             x,
...
@@ -66,58 +67,87 @@ class FusionConv(object):
 class BiFPNCell(object):
-    def __init__(self, num_chan, levels=5):
+    def __init__(self, num_chan, levels=5, inputs_layer_num=3):
+        """
+        # Node id starts from the input features and monotonically
+        # increases whenever a new node is added.
+        # [Node NO.] Here is an example for levels P3 - P7:
+        #     {3: [0, 8],
+        #      4: [1, 7, 9],
+        #      5: [2, 6, 10],
+        #      6: [3, 5, 11],
+        #      7: [4, 12]}
+        # [Related Edge]
+        #     {'feat_level': 6, 'inputs_offsets': [3, 4]},     # for P6'
+        #     {'feat_level': 5, 'inputs_offsets': [2, 5]},     # for P5'
+        #     {'feat_level': 4, 'inputs_offsets': [1, 6]},     # for P4'
+        #     {'feat_level': 3, 'inputs_offsets': [0, 7]},     # for P3"
+        #     {'feat_level': 4, 'inputs_offsets': [1, 7, 8]},  # for P4"
+        #     {'feat_level': 5, 'inputs_offsets': [2, 6, 9]},  # for P5"
+        #     {'feat_level': 6, 'inputs_offsets': [3, 5, 10]}, # for P6"
+        #     {'feat_level': 7, 'inputs_offsets': [4, 11]},    # for P7"
+
+        P7 (4) --------------> P7" (12)
+          |----------|           ↑
+                     ↓           |
+        P6 (3) --> P6' (5) --> P6" (11)
+          |----------|----------↑↑
+                     ↓           |
+        P5 (2) --> P5' (6) --> P5" (10)
+          |----------|----------↑↑
+                     ↓           |
+        P4 (1) --> P4' (7) --> P4" (9)
+          |----------|----------↑↑
+          |----------↓|
+        P3 (0) --------------> P3" (8)
+        """
         super(BiFPNCell, self).__init__()
         self.levels = levels
         self.num_chan = num_chan
-        num_trigates = levels - 2
-        num_bigates = levels
+        self.inputs_layer_num = inputs_layer_num
+        # Learnable weights of [P4", P5", P6"]
         self.trigates = fluid.layers.create_parameter(
-            shape=[num_trigates, 3],
+            shape=[levels - 2, 3],
             dtype='float32',
             default_initializer=fluid.initializer.Constant(1.))
+        # Learnable weights of [P6', P5', P4', P3", P7"]
         self.bigates = fluid.layers.create_parameter(
-            shape=[num_bigates, 2],
+            shape=[levels, 2],
             dtype='float32',
             default_initializer=fluid.initializer.Constant(1.))
         self.eps = 1e-4
 
-    def __call__(self, inputs, cell_name=''):
+    def __call__(self, inputs, cell_name='', is_first_time=False, p4_2_p5_2=[]):
         assert len(inputs) == self.levels
+        assert ((is_first_time and len(p4_2_p5_2) != 0) or
+                (not is_first_time and len(p4_2_p5_2) == 0))
 
+        # upsample operator
         def upsample(feat):
             return fluid.layers.resize_nearest(feat, scale=2.)
 
+        # downsample operator
         def downsample(feat):
             return fluid.layers.pool2d(
                 feat,
                 pool_type='max',
                 pool_size=3,
                 pool_stride=2,
                 pool_padding='SAME')
 
+        # 3x3 fuse conv after OP combine
         fuse_conv = FusionConv(self.num_chan)
 
-        # normalize weight
+        # Normalize weights
         trigates = fluid.layers.relu(self.trigates)
         bigates = fluid.layers.relu(self.bigates)
         trigates /= fluid.layers.reduce_sum(
             trigates, dim=1, keep_dim=True) + self.eps
         bigates /= fluid.layers.reduce_sum(
             bigates, dim=1, keep_dim=True) + self.eps
 
-        feature_maps = list(inputs)  # make a copy
+        feature_maps = list(inputs)  # make a copy, in order [P3, P4, P5, P6, P7]
         # top down path
         for l in range(self.levels - 1):
             p = self.levels - l - 2
             w1 = fluid.layers.slice(
                 bigates, axes=[0, 1], starts=[l, 0], ends=[l + 1, 1])
             w2 = fluid.layers.slice(
                 bigates, axes=[0, 1], starts=[l, 1], ends=[l + 1, 2])
-            above = upsample(feature_maps[p + 1])
+            above_layer = upsample(feature_maps[p + 1])
             feature_maps[p] = fuse_conv(
-                w1 * above + w2 * inputs[p],
+                w1 * above_layer + w2 * inputs[p],
                 name='{}_tb_{}'.format(cell_name, l))
         # bottom up path
         for l in range(1, self.levels):
             p = l
...
@@ -125,22 +155,26 @@ class BiFPNCell(object):
             below = downsample(feature_maps[p - 1])
             if p == self.levels - 1:
                 # handle P7
                 w1 = fluid.layers.slice(
                     bigates, axes=[0, 1], starts=[p, 0], ends=[p + 1, 1])
                 w2 = fluid.layers.slice(
                     bigates, axes=[0, 1], starts=[p, 1], ends=[p + 1, 2])
                 feature_maps[p] = fuse_conv(
                     w1 * below + w2 * inputs[p], name=name)
             else:
                 w1 = fluid.layers.slice(
                     trigates, axes=[0, 1], starts=[p - 1, 0], ends=[p, 1])
                 w2 = fluid.layers.slice(
                     trigates, axes=[0, 1], starts=[p - 1, 1], ends=[p, 2])
                 w3 = fluid.layers.slice(
                     trigates, axes=[0, 1], starts=[p - 1, 2], ends=[p, 3])
-                feature_maps[p] = fuse_conv(
-                    w1 * feature_maps[p] + w2 * below + w3 * inputs[p],
-                    name=name)
+                if is_first_time and p < self.inputs_layer_num:
+                    # inner levels of the first cell fuse the extra skip input
+                    feature_maps[p] = fuse_conv(
+                        w1 * feature_maps[p] + w2 * below +
+                        w3 * p4_2_p5_2[p - 1],
+                        name=name)
+                else:  # for P6" and for every repeated cell
+                    feature_maps[p] = fuse_conv(
+                        w1 * feature_maps[p] + w2 * below + w3 * inputs[p],
+                        name=name)
         return feature_maps
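Note: the `trigates`/`bigates` arithmetic above is EfficientDet's "fast normalized fusion": weights are kept non-negative with a ReLU and normalized by their sum rather than a softmax. A minimal NumPy sketch of the idea (a standalone illustration, not the Paddle code above):

```python
import numpy as np

def fast_normalized_fusion(feats, weights, eps=1e-4):
    """Fuse same-shape feature maps with non-negative learned weights."""
    w = np.maximum(weights, 0.0)   # ReLU keeps every weight >= 0
    w = w / (w.sum() + eps)        # normalize so the weights sum to ~1
    return sum(wi * fi for wi, fi in zip(w, feats))

# Two-input fusion as in the top-down pass, e.g. P6' from P6 and upsampled P7
p6 = np.random.rand(1, 64, 10, 10)
p7_up = np.random.rand(1, 64, 10, 10)
p6_td = fast_normalized_fusion([p6, p7_up], np.array([1.0, 1.0]))
```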
...
@@ -163,40 +197,81 @@ class BiFPN(object):
     def __call__(self, inputs):
         feats = []
-        # NOTE add two extra levels
-        for idx in range(self.levels):
-            if idx <= len(inputs):
-                if idx == len(inputs):
-                    feat = inputs[-1]
-                else:
-                    feat = inputs[idx]
-
-                if feat.shape[1] != self.num_chan:
-                    feat = fluid.layers.conv2d(
-                        feat,
-                        self.num_chan,
-                        filter_size=1,
-                        padding='SAME',
-                        param_attr=ParamAttr(initializer=Xavier()),
-                        bias_attr=ParamAttr(regularizer=L2Decay(0.)))
-                    feat = fluid.layers.batch_norm(
-                        feat,
-                        momentum=0.997,
-                        epsilon=1e-04,
-                        param_attr=ParamAttr(
-                            initializer=Constant(1.0), regularizer=L2Decay(0.)),
-                        bias_attr=ParamAttr(regularizer=L2Decay(0.)))
-
-            if idx >= len(inputs):
-                feat = fluid.layers.pool2d(
-                    feat,
-                    pool_type='max',
-                    pool_size=3,
-                    pool_stride=2,
-                    pool_padding='SAME')
-            feats.append(feat)
-
-        biFPN = BiFPNCell(self.num_chan, self.levels)
+        # Squeeze the channels with a 1x1 conv
+        for idx in range(len(inputs)):
+            if inputs[idx].shape[1] != self.num_chan:
+                feat = fluid.layers.conv2d(
+                    inputs[idx],
+                    self.num_chan,
+                    filter_size=1,
+                    padding='SAME',
+                    param_attr=ParamAttr(initializer=Xavier()),
+                    bias_attr=ParamAttr(regularizer=L2Decay(0.)),
+                    name='resample_conv_{}'.format(idx))
+                feat = fluid.layers.batch_norm(
+                    feat,
+                    momentum=0.997,
+                    epsilon=1e-04,
+                    param_attr=ParamAttr(
+                        initializer=Constant(1.0), regularizer=L2Decay(0.)),
+                    bias_attr=ParamAttr(regularizer=L2Decay(0.)),
+                    name='resample_bn_{}'.format(idx))
+            else:
+                feat = inputs[idx]
+            feats.append(feat)
+
+        # Build additional input features that are not from the backbone.
+        # The P7 layer uses pool2d alone (no conv & bn), since its channel
+        # count already matches P6; see
+        # https://github.com/google/automl/blob/master/efficientdet/keras/efficientdet_keras.py#L820
+        for idx in range(len(inputs), self.levels):
+            if feats[-1].shape[1] != self.num_chan:
+                feat = fluid.layers.conv2d(
+                    feats[-1],
+                    self.num_chan,
+                    filter_size=1,
+                    padding='SAME',
+                    param_attr=ParamAttr(initializer=Xavier()),
+                    bias_attr=ParamAttr(regularizer=L2Decay(0.)),
+                    name='resample_conv_{}'.format(idx))
+                feat = fluid.layers.batch_norm(
+                    feat,
+                    momentum=0.997,
+                    epsilon=1e-04,
+                    param_attr=ParamAttr(
+                        initializer=Constant(1.0), regularizer=L2Decay(0.)),
+                    bias_attr=ParamAttr(regularizer=L2Decay(0.)),
+                    name='resample_bn_{}'.format(idx))
+            feat = fluid.layers.pool2d(
+                feat,
+                pool_type='max',
+                pool_size=3,
+                pool_stride=2,
+                pool_padding='SAME',
+                name='resample_downsample_{}'.format(idx))
+            feats.append(feat)
+
+        # Handle p4_2 and p5_2 with another 1x1 conv & bn layer
+        p4_2_p5_2 = []
+        for idx in range(1, len(inputs)):
+            feat = fluid.layers.conv2d(
+                inputs[idx],
+                self.num_chan,
+                filter_size=1,
+                padding='SAME',
+                param_attr=ParamAttr(initializer=Xavier()),
+                bias_attr=ParamAttr(regularizer=L2Decay(0.)),
+                name='resample2_conv_{}'.format(idx))
+            feat = fluid.layers.batch_norm(
+                feat,
+                momentum=0.997,
+                epsilon=1e-04,
+                param_attr=ParamAttr(
+                    initializer=Constant(1.0), regularizer=L2Decay(0.)),
+                bias_attr=ParamAttr(regularizer=L2Decay(0.)),
+                name='resample2_bn_{}'.format(idx))
+            p4_2_p5_2.append(feat)
+
+        # BiFPN, repeated
+        biFPN = BiFPNCell(self.num_chan, self.levels, len(inputs))
         for r in range(self.repeat):
-            feats = biFPN(feats, 'bifpn_{}'.format(r))
+            if r == 0:
+                feats = biFPN(feats, cell_name='bifpn_{}'.format(r),
+                              is_first_time=True, p4_2_p5_2=p4_2_p5_2)
+            else:
+                feats = biFPN(feats, cell_name='bifpn_{}'.format(r))
         return feats
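Note: the resampling code above turns the three backbone features into the five levels the cell expects, with levels halving in resolution each step. A quick sanity check of the pyramid sizes for a hypothetical 640x640 input:

```python
# Feature stride doubles per level; levels 3..7 for a 640x640 input.
sizes = {lvl: 640 // 2 ** lvl for lvl in range(3, 8)}
print(sizes)  # {3: 80, 4: 40, 5: 20, 6: 10, 7: 5}
```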
...
@@ -28,12 +28,15 @@ __all__ = ['EfficientNet']
 GlobalParams = collections.namedtuple('GlobalParams', [
     'batch_norm_momentum', 'batch_norm_epsilon', 'width_coefficient',
-    'depth_coefficient', 'depth_divisor'
+    'depth_coefficient', 'depth_divisor', 'min_depth', 'drop_connect_rate',
+    'relu_fn', 'batch_norm', 'use_se', 'local_pooling', 'condconv_num_experts',
+    'clip_projection_output', 'blocks_args', 'fix_head_stem'
 ])
 
 BlockArgs = collections.namedtuple('BlockArgs', [
     'kernel_size', 'num_repeat', 'input_filters', 'output_filters',
-    'expand_ratio', 'stride', 'se_ratio'
+    'expand_ratio', 'id_skip', 'stride', 'se_ratio', 'conv_type', 'fused_conv',
+    'super_pixel', 'condconv'
 ])
 
 GlobalParams.__new__.__defaults__ = (None, ) * len(GlobalParams._fields)
...
@@ -51,8 +54,8 @@ def _decode_block_string(block_string):
         key, value = splits[:2]
         options[key] = value
 
-    assert (('s' in options and len(options['s']) == 1) or
-            (len(options['s']) == 2 and options['s'][0] == options['s'][1]))
+    if 's' not in options or len(options['s']) != 2:
+        raise ValueError('Strides options should be a pair of integers.')
 
     return BlockArgs(
         kernel_size=int(options['k']),
...
@@ -60,8 +63,13 @@ def _decode_block_string(block_string):
         input_filters=int(options['i']),
         output_filters=int(options['o']),
         expand_ratio=int(options['e']),
+        id_skip=('noskip' not in block_string),
         se_ratio=float(options['se']) if 'se' in options else None,
-        stride=int(options['s'][0]))
+        stride=int(options['s'][0]),
+        conv_type=int(options['c']) if 'c' in options else 0,
+        fused_conv=int(options['f']) if 'f' in options else 0,
+        super_pixel=int(options['p']) if 'p' in options else 0,
+        condconv=('cc' in block_string))
 
 
 def get_model_params(scale):
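Note: the block strings decoded here follow the upstream EfficientNet convention (`r` repeats, `k` kernel size, `s` stride, `e` expand ratio, `i`/`o` input/output filters, `se` squeeze-excite ratio). A self-contained sketch of just the parsing step, with a hypothetical example string:

```python
import re

def parse_options(block_string):
    """Split e.g. 'r1_k3_s11_e1_i32_o16_se0.25' into a key/value dict."""
    options = {}
    for op in block_string.split('_'):
        splits = re.split(r'(\d.*)', op)
        if len(splits) >= 2:
            key, value = splits[:2]
            options[key] = value
    return options

opts = parse_options('r1_k3_s11_e1_i32_o16_se0.25')
assert opts == {'r': '1', 'k': '3', 's': '11', 'e': '1',
                'i': '32', 'o': '16', 'se': '0.25'}
assert len(opts['s']) == 2  # the new check: stride must be a pair, e.g. '11'
```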
...
@@ -88,37 +96,47 @@ def get_model_params(scale):
         'b5': (1.6, 2.2),
         'b6': (1.8, 2.6),
         'b7': (2.0, 3.1),
+        'l2': (4.3, 5.3),
     }
     w, d = params_dict[scale]
     global_params = GlobalParams(
+        blocks_args=block_strings,
         batch_norm_momentum=0.99,
         batch_norm_epsilon=1e-3,
+        drop_connect_rate=0 if scale == 'b0' else 0.2,
         width_coefficient=w,
         depth_coefficient=d,
-        depth_divisor=8)
+        depth_divisor=8,
+        min_depth=None,
+        fix_head_stem=False,
+        use_se=True,
+        clip_projection_output=False)
 
     return block_args, global_params
 
 
-def round_filters(filters, global_params):
+def round_filters(filters, global_params, skip=False):
+    """Round number of filters based on width multiplier."""
     multiplier = global_params.width_coefficient
-    if not multiplier:
-        return filters
     divisor = global_params.depth_divisor
+    min_depth = global_params.min_depth
+    if skip or not multiplier:
+        return filters
     filters *= multiplier
-    min_depth = divisor
+    min_depth = min_depth or divisor
     new_filters = max(min_depth,
                       int(filters + divisor / 2) // divisor * divisor)
     if new_filters < 0.9 * filters:  # prevent rounding by more than 10%
         new_filters += divisor
     return int(new_filters)
 
 
-def round_repeats(repeats, global_params):
+def round_repeats(repeats, global_params, skip=False):
+    """Round number of repeats based on depth multiplier."""
     multiplier = global_params.depth_coefficient
-    if not multiplier:
+    if skip or not multiplier:
         return repeats
     return int(math.ceil(multiplier * repeats))
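Note: as a sanity check on the rounding rules, here is a standalone sketch (no Paddle needed) with values computed from the width/depth pairs in `params_dict`:

```python
import math

def round_filters(filters, multiplier, divisor=8, min_depth=None):
    if not multiplier:
        return filters
    filters *= multiplier
    min_depth = min_depth or divisor
    # round to the nearest multiple of divisor
    new_filters = max(min_depth, int(filters + divisor / 2) // divisor * divisor)
    if new_filters < 0.9 * filters:  # never round down by more than 10%
        new_filters += divisor
    return int(new_filters)

def round_repeats(repeats, multiplier):
    return int(math.ceil(multiplier * repeats)) if multiplier else repeats

# b0 (w=1.0, d=1.0) leaves everything unchanged;
# b4 (w=1.4, d=1.8): 32 filters -> 48, 2 repeats -> 4
assert round_filters(32, 1.4) == 48   # 32 * 1.4 = 44.8, rounded up to 48
assert round_repeats(2, 1.8) == 4     # ceil(3.6) = 4
```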
...
@@ -130,7 +148,8 @@ def conv2d(inputs,
            padding='SAME',
            groups=1,
            use_bias=False,
-           name='conv2d'):
+           name='conv2d',
+           use_cudnn=True):
     param_attr = fluid.ParamAttr(name=name + '_weights')
     bias_attr = False
     if use_bias:
...
@@ -145,7 +164,8 @@ def conv2d(inputs,
         stride=stride,
         padding=padding,
         param_attr=param_attr,
-        bias_attr=bias_attr)
+        bias_attr=bias_attr,
+        use_cudnn=use_cudnn)
     return feats
...
@@ -163,6 +183,16 @@ def batch_norm(inputs, momentum, eps, name=None):
         bias_attr=bias_attr)
 
 
+def _drop_connect(inputs, prob, mode):
+    if mode != 'train':
+        return inputs
+    keep_prob = 1.0 - prob
+    inputs_shape = fluid.layers.shape(inputs)
+    random_tensor = keep_prob + fluid.layers.uniform_random(
+        shape=[inputs_shape[0], 1, 1, 1], min=0., max=1.)
+    binary_tensor = fluid.layers.floor(random_tensor)
+    output = inputs / keep_prob * binary_tensor
+    return output
+
+
 def mb_conv_block(inputs,
                   input_filters,
                   output_filters,
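Note: `_drop_connect` is stochastic depth at the sample level. `floor(keep_prob + U[0, 1))` equals 1 with probability `keep_prob` and 0 otherwise, so whole examples in the batch have their residual branch zeroed, and survivors are rescaled by `1 / keep_prob` to keep the expected activation unchanged. A NumPy sketch of the same computation:

```python
import numpy as np

def drop_connect(x, prob, training, rng=np.random.default_rng(0)):
    if not training:
        return x          # inference: identity, no rescaling needed
    keep_prob = 1.0 - prob
    # one Bernoulli(keep_prob) draw per sample, broadcast over C, H, W
    mask = np.floor(keep_prob + rng.random((x.shape[0], 1, 1, 1)))
    return x / keep_prob * mask

x = np.ones((8, 4, 3, 3))
y = drop_connect(x, prob=0.2, training=True)
# each sample is now either all zeros or scaled by 1 / 0.8
```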
...
@@ -171,30 +201,37 @@ def mb_conv_block(inputs,
                   stride,
                   momentum,
                   eps,
+                  block_arg,
+                  drop_connect_rate,
+                  mode,
                   se_ratio=None,
                   name=None):
     feats = inputs
     num_filters = input_filters * expand_ratio
 
+    # Expansion
     if expand_ratio != 1:
         feats = conv2d(feats, num_filters, 1, name=name + '_expand_conv')
         feats = batch_norm(feats, momentum, eps, name=name + '_bn0')
         feats = fluid.layers.swish(feats)
 
+    # Depthwise convolution
     feats = conv2d(
         feats,
         num_filters,
         kernel_size,
         stride,
         groups=num_filters,
-        name=name + '_depthwise_conv')
+        name=name + '_depthwise_conv',
+        use_cudnn=False)
     feats = batch_norm(feats, momentum, eps, name=name + '_bn1')
     feats = fluid.layers.swish(feats)
 
+    # Squeeze and excitation
     if se_ratio is not None:
         filter_squeezed = max(1, int(input_filters * se_ratio))
         squeezed = fluid.layers.pool2d(
-            feats, pool_type='avg', global_pooling=True)
+            feats, pool_type='avg', global_pooling=True, use_cudnn=True)
         squeezed = conv2d(
             squeezed,
             filter_squeezed,
...
@@ -206,10 +243,14 @@ def mb_conv_block(inputs,
             squeezed, num_filters, 1, use_bias=True, name=name + '_se_expand')
         feats = feats * fluid.layers.sigmoid(squeezed)
 
+    # Projection conv & norm
     feats = conv2d(feats, output_filters, 1, name=name + '_project_conv')
     feats = batch_norm(feats, momentum, eps, name=name + '_bn2')
 
-    if stride == 1 and input_filters == output_filters:
+    # Skip connection and drop connect
+    if block_arg.id_skip and block_arg.stride == 1 \
+            and input_filters == output_filters:
+        if drop_connect_rate:
+            feats = _drop_connect(feats, drop_connect_rate, mode)
         feats = fluid.layers.elementwise_add(feats, inputs)
     return feats
...
@@ -227,7 +268,10 @@ class EfficientNet(object):
     """
 
     __shared__ = ['norm_type']
 
-    def __init__(self, scale='b0', use_se=True, norm_type='bn'):
+    def __init__(self,
+                 scale='b0',
+                 use_se=True,
+                 norm_type='bn'):
         assert scale in ['b' + str(i) for i in range(8)], \
             "valid scales are b0 - b7"
         assert norm_type in ['bn', 'sync_bn'], \
...
@@ -238,54 +282,80 @@ class EfficientNet(object):
         self.scale = scale
         self.use_se = use_se
 
-    def __call__(self, inputs):
+    def __call__(self, inputs, mode):
+        assert mode in ['train', 'test'], \
+            "only 'train' and 'test' mode are supported"
+
         blocks_args, global_params = get_model_params(self.scale)
         momentum = global_params.batch_norm_momentum
         eps = global_params.batch_norm_epsilon
 
-        num_filters = round_filters(32, global_params)
+        # Stem part.
+        num_filters = round_filters(blocks_args[0].input_filters,
+                                    global_params, global_params.fix_head_stem)
         feats = conv2d(
             inputs,
             num_filters=num_filters,
             filter_size=3,
             stride=2,
             name='_conv_stem')
         feats = batch_norm(feats, momentum=momentum, eps=eps, name='_bn0')
         feats = fluid.layers.swish(feats)
 
-        layer_count = 0
+        # Build blocks.
         feature_maps = []
+        layer_count = 0
+        num_blocks = sum([block_arg.num_repeat for block_arg in blocks_args])
 
-        for b, block_arg in enumerate(blocks_args):
-            for r in range(block_arg.num_repeat):
-                input_filters = round_filters(block_arg.input_filters,
-                                              global_params)
-                output_filters = round_filters(block_arg.output_filters,
-                                               global_params)
-                kernel_size = block_arg.kernel_size
-                stride = block_arg.stride
-                se_ratio = None
-                if self.use_se:
-                    se_ratio = block_arg.se_ratio
-                if r > 0:
-                    input_filters = output_filters
-                    stride = 1
-                feats = mb_conv_block(
-                    feats,
-                    input_filters,
-                    output_filters,
-                    block_arg.expand_ratio,
-                    kernel_size,
-                    stride,
-                    momentum,
-                    eps,
-                    se_ratio=se_ratio,
-                    name='_blocks.{}.'.format(layer_count))
-                layer_count += 1
-            feature_maps.append(feats)
+        for block_arg in blocks_args:
+            # Update block input and output filters based on the multipliers.
+            block_arg = block_arg._replace(
+                input_filters=round_filters(block_arg.input_filters,
+                                            global_params),
+                output_filters=round_filters(block_arg.output_filters,
+                                             global_params),
+                num_repeat=round_repeats(block_arg.num_repeat, global_params))
+
+            # The first block of each stage takes care of the stride and
+            # filter size increase.
+            drop_connect_rate = global_params.drop_connect_rate
+            if drop_connect_rate:
+                drop_connect_rate *= float(layer_count) / num_blocks
+            feats = mb_conv_block(
+                feats,
+                block_arg.input_filters,
+                block_arg.output_filters,
+                block_arg.expand_ratio,
+                block_arg.kernel_size,
+                block_arg.stride,
+                momentum,
+                eps,
+                block_arg,
+                drop_connect_rate,
+                mode,
+                se_ratio=block_arg.se_ratio,
+                name='_blocks.{}.'.format(layer_count))
+            layer_count += 1
+
+            # Remaining blocks of the stage keep stride 1 and equal filters.
+            if block_arg.num_repeat > 1:
+                block_arg = block_arg._replace(
+                    input_filters=block_arg.output_filters, stride=1)
+            for _ in range(block_arg.num_repeat - 1):
+                drop_connect_rate = global_params.drop_connect_rate
+                if drop_connect_rate:
+                    drop_connect_rate *= float(layer_count) / num_blocks
+                feats = mb_conv_block(
+                    feats,
+                    block_arg.input_filters,
+                    block_arg.output_filters,
+                    block_arg.expand_ratio,
+                    block_arg.kernel_size,
+                    block_arg.stride,
+                    momentum,
+                    eps,
+                    block_arg,
+                    drop_connect_rate,
+                    mode,
+                    se_ratio=block_arg.se_ratio,
+                    name='_blocks.{}.'.format(layer_count))
+                layer_count += 1
+            feature_maps.append(feats)
 
-        return list(feature_maps[i] for i in [2, 4, 6])
+        return list(feature_maps[i] for i in [2, 4, 6])  # 1/8, 1/16, 1/32
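Note: the loop above also gives each block its own rate. The base `drop_connect_rate` is scaled by `layer_count / num_blocks`, so shallow blocks are almost never dropped and the deepest block is dropped most often; b0 sets the base rate to 0, so it is unaffected. A small sketch of the schedule, e.g. with 16 blocks:

```python
# Linear drop-connect schedule: block i of N gets base_rate * i / N.
base_rate, num_blocks = 0.2, 16
rates = [base_rate * i / num_blocks for i in range(num_blocks)]
print(rates[0], rates[-1])  # 0.0 0.1875
```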