提交 58e558fb 编写于 作者: littletomatodonkey's avatar littletomatodonkey

fix effne

上级 8e42f786
...@@ -2,7 +2,6 @@ mode: 'train' ...@@ -2,7 +2,6 @@ mode: 'train'
ARCHITECTURE: ARCHITECTURE:
name: "EfficientNetB0" name: "EfficientNetB0"
params: params:
is_test: False
padding_type : "SAME" padding_type : "SAME"
override_params: override_params:
drop_connect_rate: 0.1 drop_connect_rate: 0.1
......
...@@ -518,7 +518,6 @@ class MbConvBlock(nn.Layer): ...@@ -518,7 +518,6 @@ class MbConvBlock(nn.Layer):
use_se, use_se,
name=None, name=None,
drop_connect_rate=None, drop_connect_rate=None,
is_test=False,
model_name=None, model_name=None,
cur_stage=None): cur_stage=None):
super(MbConvBlock, self).__init__() super(MbConvBlock, self).__init__()
...@@ -530,7 +529,6 @@ class MbConvBlock(nn.Layer): ...@@ -530,7 +529,6 @@ class MbConvBlock(nn.Layer):
self.id_skip = block_args.id_skip self.id_skip = block_args.id_skip
self.expand_ratio = block_args.expand_ratio self.expand_ratio = block_args.expand_ratio
self.drop_connect_rate = drop_connect_rate self.drop_connect_rate = drop_connect_rate
self.is_test = is_test
if self.expand_ratio != 1: if self.expand_ratio != 1:
self._ecn = ExpandConvNorm( self._ecn = ExpandConvNorm(
...@@ -583,7 +581,7 @@ class MbConvBlock(nn.Layer): ...@@ -583,7 +581,7 @@ class MbConvBlock(nn.Layer):
self.block_args.stride == 1 and \ self.block_args.stride == 1 and \
self.block_args.input_filters == self.block_args.output_filters: self.block_args.input_filters == self.block_args.output_filters:
if self.drop_connect_rate: if self.drop_connect_rate:
x = _drop_connect(x, self.drop_connect_rate, self.is_test) x = _drop_connect(x, self.drop_connect_rate, not self.training)
x = paddle.elementwise_add(x, inputs) x = paddle.elementwise_add(x, inputs)
return x return x
...@@ -623,7 +621,6 @@ class ExtractFeatures(nn.Layer): ...@@ -623,7 +621,6 @@ class ExtractFeatures(nn.Layer):
_global_params, _global_params,
padding_type, padding_type,
use_se, use_se,
is_test,
model_name=None): model_name=None):
super(ExtractFeatures, self).__init__() super(ExtractFeatures, self).__init__()
...@@ -661,7 +658,7 @@ class ExtractFeatures(nn.Layer): ...@@ -661,7 +658,7 @@ class ExtractFeatures(nn.Layer):
num_repeat=round_repeats(block_args.num_repeat, num_repeat=round_repeats(block_args.num_repeat,
_global_params)) _global_params))
drop_connect_rate = self._global_params.drop_connect_rate if not is_test else 0 drop_connect_rate = self._global_params.drop_connect_rate
if drop_connect_rate: if drop_connect_rate:
drop_connect_rate *= float(idx) / block_size drop_connect_rate *= float(idx) / block_size
...@@ -682,7 +679,7 @@ class ExtractFeatures(nn.Layer): ...@@ -682,7 +679,7 @@ class ExtractFeatures(nn.Layer):
block_args = block_args._replace( block_args = block_args._replace(
input_filters=block_args.output_filters, stride=1) input_filters=block_args.output_filters, stride=1)
for _ in range(block_args.num_repeat - 1): for _ in range(block_args.num_repeat - 1):
drop_connect_rate = self._global_params.drop_connect_rate if not is_test else 0 drop_connect_rate = self._global_params.drop_connect_rate
if drop_connect_rate: if drop_connect_rate:
drop_connect_rate *= float(idx) / block_size drop_connect_rate *= float(idx) / block_size
_mc_block = self.add_sublayer( _mc_block = self.add_sublayer(
...@@ -711,7 +708,6 @@ class ExtractFeatures(nn.Layer): ...@@ -711,7 +708,6 @@ class ExtractFeatures(nn.Layer):
class EfficientNet(nn.Layer): class EfficientNet(nn.Layer):
def __init__(self, def __init__(self,
name="b0", name="b0",
is_test=True,
padding_type="SAME", padding_type="SAME",
override_params=None, override_params=None,
use_se=True, use_se=True,
...@@ -724,7 +720,6 @@ class EfficientNet(nn.Layer): ...@@ -724,7 +720,6 @@ class EfficientNet(nn.Layer):
model_name, override_params) model_name, override_params)
self.padding_type = padding_type self.padding_type = padding_type
self.use_se = use_se self.use_se = use_se
self.is_test = is_test
self._ef = ExtractFeatures( self._ef = ExtractFeatures(
3, 3,
...@@ -732,7 +727,6 @@ class EfficientNet(nn.Layer): ...@@ -732,7 +727,6 @@ class EfficientNet(nn.Layer):
self._global_params, self._global_params,
self.padding_type, self.padding_type,
self.use_se, self.use_se,
self.is_test,
model_name=self.name) model_name=self.name)
output_channels = round_filters(1280, self._global_params) output_channels = round_filters(1280, self._global_params)
...@@ -785,14 +779,12 @@ class EfficientNet(nn.Layer): ...@@ -785,14 +779,12 @@ class EfficientNet(nn.Layer):
return x return x
def EfficientNetB0_small(is_test=True, def EfficientNetB0_small(padding_type='DYNAMIC',
padding_type='DYNAMIC',
override_params=None, override_params=None,
use_se=False, use_se=False,
**args): **args):
model = EfficientNet( model = EfficientNet(
name='b0', name='b0',
is_test=is_test,
padding_type=padding_type, padding_type=padding_type,
override_params=override_params, override_params=override_params,
use_se=use_se, use_se=use_se,
...@@ -800,14 +792,12 @@ def EfficientNetB0_small(is_test=True, ...@@ -800,14 +792,12 @@ def EfficientNetB0_small(is_test=True,
return model return model
def EfficientNetB0(is_test=False, def EfficientNetB0(padding_type='SAME',
padding_type='SAME',
override_params=None, override_params=None,
use_se=True, use_se=True,
**args): **args):
model = EfficientNet( model = EfficientNet(
name='b0', name='b0',
is_test=is_test,
padding_type=padding_type, padding_type=padding_type,
override_params=override_params, override_params=override_params,
use_se=use_se, use_se=use_se,
...@@ -815,14 +805,12 @@ def EfficientNetB0(is_test=False, ...@@ -815,14 +805,12 @@ def EfficientNetB0(is_test=False,
return model return model
def EfficientNetB1(is_test=False, def EfficientNetB1(padding_type='SAME',
padding_type='SAME',
override_params=None, override_params=None,
use_se=True, use_se=True,
**args): **args):
model = EfficientNet( model = EfficientNet(
name='b1', name='b1',
is_test=is_test,
padding_type=padding_type, padding_type=padding_type,
override_params=override_params, override_params=override_params,
use_se=use_se, use_se=use_se,
...@@ -830,14 +818,12 @@ def EfficientNetB1(is_test=False, ...@@ -830,14 +818,12 @@ def EfficientNetB1(is_test=False,
return model return model
def EfficientNetB2(is_test=False, def EfficientNetB2(padding_type='SAME',
padding_type='SAME',
override_params=None, override_params=None,
use_se=True, use_se=True,
**args): **args):
model = EfficientNet( model = EfficientNet(
name='b2', name='b2',
is_test=is_test,
padding_type=padding_type, padding_type=padding_type,
override_params=override_params, override_params=override_params,
use_se=use_se, use_se=use_se,
...@@ -845,14 +831,12 @@ def EfficientNetB2(is_test=False, ...@@ -845,14 +831,12 @@ def EfficientNetB2(is_test=False,
return model return model
def EfficientNetB3(is_test=False, def EfficientNetB3(padding_type='SAME',
padding_type='SAME',
override_params=None, override_params=None,
use_se=True, use_se=True,
**args): **args):
model = EfficientNet( model = EfficientNet(
name='b3', name='b3',
is_test=is_test,
padding_type=padding_type, padding_type=padding_type,
override_params=override_params, override_params=override_params,
use_se=use_se, use_se=use_se,
...@@ -860,14 +844,12 @@ def EfficientNetB3(is_test=False, ...@@ -860,14 +844,12 @@ def EfficientNetB3(is_test=False,
return model return model
def EfficientNetB4(is_test=False, def EfficientNetB4(padding_type='SAME',
padding_type='SAME',
override_params=None, override_params=None,
use_se=True, use_se=True,
**args): **args):
model = EfficientNet( model = EfficientNet(
name='b4', name='b4',
is_test=is_test,
padding_type=padding_type, padding_type=padding_type,
override_params=override_params, override_params=override_params,
use_se=use_se, use_se=use_se,
...@@ -875,14 +857,12 @@ def EfficientNetB4(is_test=False, ...@@ -875,14 +857,12 @@ def EfficientNetB4(is_test=False,
return model return model
def EfficientNetB5(is_test=False, def EfficientNetB5(padding_type='SAME',
padding_type='SAME',
override_params=None, override_params=None,
use_se=True, use_se=True,
**args): **args):
model = EfficientNet( model = EfficientNet(
name='b5', name='b5',
is_test=is_test,
padding_type=padding_type, padding_type=padding_type,
override_params=override_params, override_params=override_params,
use_se=use_se, use_se=use_se,
...@@ -890,14 +870,12 @@ def EfficientNetB5(is_test=False, ...@@ -890,14 +870,12 @@ def EfficientNetB5(is_test=False,
return model return model
def EfficientNetB6(is_test=False, def EfficientNetB6(padding_type='SAME',
padding_type='SAME',
override_params=None, override_params=None,
use_se=True, use_se=True,
**args): **args):
model = EfficientNet( model = EfficientNet(
name='b6', name='b6',
is_test=is_test,
padding_type=padding_type, padding_type=padding_type,
override_params=override_params, override_params=override_params,
use_se=use_se, use_se=use_se,
...@@ -905,14 +883,12 @@ def EfficientNetB6(is_test=False, ...@@ -905,14 +883,12 @@ def EfficientNetB6(is_test=False,
return model return model
def EfficientNetB7(is_test=False, def EfficientNetB7(padding_type='SAME',
padding_type='SAME',
override_params=None, override_params=None,
use_se=True, use_se=True,
**args): **args):
model = EfficientNet( model = EfficientNet(
name='b7', name='b7',
is_test=is_test,
padding_type=padding_type, padding_type=padding_type,
override_params=override_params, override_params=override_params,
use_se=use_se, use_se=use_se,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册