提交 58e558fb 编写于 作者: littletomatodonkey's avatar littletomatodonkey

fix effne

上级 8e42f786
......@@ -2,7 +2,6 @@ mode: 'train'
ARCHITECTURE:
name: "EfficientNetB0"
params:
is_test: False
padding_type : "SAME"
override_params:
drop_connect_rate: 0.1
......
......@@ -518,7 +518,6 @@ class MbConvBlock(nn.Layer):
use_se,
name=None,
drop_connect_rate=None,
is_test=False,
model_name=None,
cur_stage=None):
super(MbConvBlock, self).__init__()
......@@ -530,7 +529,6 @@ class MbConvBlock(nn.Layer):
self.id_skip = block_args.id_skip
self.expand_ratio = block_args.expand_ratio
self.drop_connect_rate = drop_connect_rate
self.is_test = is_test
if self.expand_ratio != 1:
self._ecn = ExpandConvNorm(
......@@ -583,7 +581,7 @@ class MbConvBlock(nn.Layer):
self.block_args.stride == 1 and \
self.block_args.input_filters == self.block_args.output_filters:
if self.drop_connect_rate:
x = _drop_connect(x, self.drop_connect_rate, self.is_test)
x = _drop_connect(x, self.drop_connect_rate, not self.training)
x = paddle.elementwise_add(x, inputs)
return x
......@@ -623,7 +621,6 @@ class ExtractFeatures(nn.Layer):
_global_params,
padding_type,
use_se,
is_test,
model_name=None):
super(ExtractFeatures, self).__init__()
......@@ -661,7 +658,7 @@ class ExtractFeatures(nn.Layer):
num_repeat=round_repeats(block_args.num_repeat,
_global_params))
drop_connect_rate = self._global_params.drop_connect_rate if not is_test else 0
drop_connect_rate = self._global_params.drop_connect_rate
if drop_connect_rate:
drop_connect_rate *= float(idx) / block_size
......@@ -682,7 +679,7 @@ class ExtractFeatures(nn.Layer):
block_args = block_args._replace(
input_filters=block_args.output_filters, stride=1)
for _ in range(block_args.num_repeat - 1):
drop_connect_rate = self._global_params.drop_connect_rate if not is_test else 0
drop_connect_rate = self._global_params.drop_connect_rate
if drop_connect_rate:
drop_connect_rate *= float(idx) / block_size
_mc_block = self.add_sublayer(
......@@ -711,7 +708,6 @@ class ExtractFeatures(nn.Layer):
class EfficientNet(nn.Layer):
def __init__(self,
name="b0",
is_test=True,
padding_type="SAME",
override_params=None,
use_se=True,
......@@ -724,7 +720,6 @@ class EfficientNet(nn.Layer):
model_name, override_params)
self.padding_type = padding_type
self.use_se = use_se
self.is_test = is_test
self._ef = ExtractFeatures(
3,
......@@ -732,7 +727,6 @@ class EfficientNet(nn.Layer):
self._global_params,
self.padding_type,
self.use_se,
self.is_test,
model_name=self.name)
output_channels = round_filters(1280, self._global_params)
......@@ -785,14 +779,12 @@ class EfficientNet(nn.Layer):
return x
def EfficientNetB0_small(is_test=True,
padding_type='DYNAMIC',
def EfficientNetB0_small(padding_type='DYNAMIC',
override_params=None,
use_se=False,
**args):
model = EfficientNet(
name='b0',
is_test=is_test,
padding_type=padding_type,
override_params=override_params,
use_se=use_se,
......@@ -800,14 +792,12 @@ def EfficientNetB0_small(is_test=True,
return model
def EfficientNetB0(is_test=False,
padding_type='SAME',
def EfficientNetB0(padding_type='SAME',
override_params=None,
use_se=True,
**args):
model = EfficientNet(
name='b0',
is_test=is_test,
padding_type=padding_type,
override_params=override_params,
use_se=use_se,
......@@ -815,14 +805,12 @@ def EfficientNetB0(is_test=False,
return model
def EfficientNetB1(is_test=False,
padding_type='SAME',
def EfficientNetB1(padding_type='SAME',
override_params=None,
use_se=True,
**args):
model = EfficientNet(
name='b1',
is_test=is_test,
padding_type=padding_type,
override_params=override_params,
use_se=use_se,
......@@ -830,14 +818,12 @@ def EfficientNetB1(is_test=False,
return model
def EfficientNetB2(is_test=False,
padding_type='SAME',
def EfficientNetB2(padding_type='SAME',
override_params=None,
use_se=True,
**args):
model = EfficientNet(
name='b2',
is_test=is_test,
padding_type=padding_type,
override_params=override_params,
use_se=use_se,
......@@ -845,14 +831,12 @@ def EfficientNetB2(is_test=False,
return model
def EfficientNetB3(is_test=False,
padding_type='SAME',
def EfficientNetB3(padding_type='SAME',
override_params=None,
use_se=True,
**args):
model = EfficientNet(
name='b3',
is_test=is_test,
padding_type=padding_type,
override_params=override_params,
use_se=use_se,
......@@ -860,14 +844,12 @@ def EfficientNetB3(is_test=False,
return model
def EfficientNetB4(is_test=False,
padding_type='SAME',
def EfficientNetB4(padding_type='SAME',
override_params=None,
use_se=True,
**args):
model = EfficientNet(
name='b4',
is_test=is_test,
padding_type=padding_type,
override_params=override_params,
use_se=use_se,
......@@ -875,14 +857,12 @@ def EfficientNetB4(is_test=False,
return model
def EfficientNetB5(is_test=False,
padding_type='SAME',
def EfficientNetB5(padding_type='SAME',
override_params=None,
use_se=True,
**args):
model = EfficientNet(
name='b5',
is_test=is_test,
padding_type=padding_type,
override_params=override_params,
use_se=use_se,
......@@ -890,14 +870,12 @@ def EfficientNetB5(is_test=False,
return model
def EfficientNetB6(is_test=False,
padding_type='SAME',
def EfficientNetB6(padding_type='SAME',
override_params=None,
use_se=True,
**args):
model = EfficientNet(
name='b6',
is_test=is_test,
padding_type=padding_type,
override_params=override_params,
use_se=use_se,
......@@ -905,14 +883,12 @@ def EfficientNetB6(is_test=False,
return model
def EfficientNetB7(is_test=False,
padding_type='SAME',
def EfficientNetB7(padding_type='SAME',
override_params=None,
use_se=True,
**args):
model = EfficientNet(
name='b7',
is_test=is_test,
padding_type=padding_type,
override_params=override_params,
use_se=use_se,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册