未验证 提交 c933dcd8 编写于 作者: L littletomatodonkey 提交者: GitHub

minor fix init (#384)

* add densent init

* fix export model
上级 00a0f7fb
...@@ -23,7 +23,7 @@ from .se_resnet_vd import SE_ResNet18_vd, SE_ResNet34_vd, SE_ResNet50_vd, SE_Res ...@@ -23,7 +23,7 @@ from .se_resnet_vd import SE_ResNet18_vd, SE_ResNet34_vd, SE_ResNet50_vd, SE_Res
from .se_resnext_vd import SE_ResNeXt50_vd_32x4d, SE_ResNeXt50_vd_32x4d, SENet154_vd from .se_resnext_vd import SE_ResNeXt50_vd_32x4d, SE_ResNeXt50_vd_32x4d, SENet154_vd
from .se_resnext import SE_ResNeXt50_32x4d, SE_ResNeXt101_32x4d, SE_ResNeXt152_64x4d from .se_resnext import SE_ResNeXt50_32x4d, SE_ResNeXt101_32x4d, SE_ResNeXt152_64x4d
from .dpn import DPN68, DPN92, DPN98, DPN107, DPN131 from .dpn import DPN68, DPN92, DPN98, DPN107, DPN131
from .densenet import DenseNet121 from .densenet import DenseNet121, DenseNet161, DenseNet169, DenseNet201, DenseNet264
from .hrnet import HRNet_W18_C, HRNet_W30_C, HRNet_W32_C, HRNet_W40_C, HRNet_W44_C, HRNet_W48_C, HRNet_W60_C, HRNet_W64_C, SE_HRNet_W18_C, SE_HRNet_W30_C, SE_HRNet_W32_C, SE_HRNet_W40_C, SE_HRNet_W44_C, SE_HRNet_W48_C, SE_HRNet_W60_C, SE_HRNet_W64_C from .hrnet import HRNet_W18_C, HRNet_W30_C, HRNet_W32_C, HRNet_W40_C, HRNet_W44_C, HRNet_W48_C, HRNet_W60_C, HRNet_W64_C, SE_HRNet_W18_C, SE_HRNet_W30_C, SE_HRNet_W32_C, SE_HRNet_W40_C, SE_HRNet_W44_C, SE_HRNet_W48_C, SE_HRNet_W60_C, SE_HRNet_W64_C
from .efficientnet import EfficientNetB0, EfficientNetB1, EfficientNetB2, EfficientNetB3, EfficientNetB4, EfficientNetB5, EfficientNetB6, EfficientNetB7 from .efficientnet import EfficientNetB0, EfficientNetB1, EfficientNetB2, EfficientNetB3, EfficientNetB4, EfficientNetB5, EfficientNetB6, EfficientNetB7
from .resnest import ResNeSt50_fast_1s1x64d, ResNeSt50 from .resnest import ResNeSt50_fast_1s1x64d, ResNeSt50
......
...@@ -243,11 +243,12 @@ inp_shape = { ...@@ -243,11 +243,12 @@ inp_shape = {
def _drop_connect(inputs, prob, is_test): def _drop_connect(inputs, prob, is_test):
if is_test: if is_test:
return inputs return inputs
keep_prob = 1.0 - prob keep_prob = 1.0 - prob
inputs_shape = paddle.shape(inputs) inputs_shape = paddle.shape(inputs)
random_tensor = keep_prob + paddle.rand(shape=[inputs_shape[0], 1, 1, 1]) random_tensor = keep_prob + paddle.rand(shape=[inputs_shape[0], 1, 1, 1])
binary_tensor = paddle.floor(random_tensor) binary_tensor = paddle.floor(random_tensor)
output = inputs / keep_prob * binary_tensor output = paddle.multiply(inputs, binary_tensor) / keep_prob
return output return output
...@@ -507,7 +508,8 @@ class SEBlock(nn.Layer): ...@@ -507,7 +508,8 @@ class SEBlock(nn.Layer):
x = self._pool(inputs) x = self._pool(inputs)
x = self._conv1(x) x = self._conv1(x)
x = self._conv2(x) x = self._conv2(x)
return paddle.multiply(inputs, x) out = paddle.multiply(inputs, x)
return out
class MbConvBlock(nn.Layer): class MbConvBlock(nn.Layer):
...@@ -572,11 +574,13 @@ class MbConvBlock(nn.Layer): ...@@ -572,11 +574,13 @@ class MbConvBlock(nn.Layer):
if self.expand_ratio != 1: if self.expand_ratio != 1:
x = self._ecn(x) x = self._ecn(x)
x = F.swish(x) x = F.swish(x)
x = self._dcn(x) x = self._dcn(x)
x = F.swish(x) x = F.swish(x)
if self.has_se: if self.has_se:
x = self._se(x) x = self._se(x)
x = self._pcn(x) x = self._pcn(x)
if self.id_skip and \ if self.id_skip and \
self.block_args.stride == 1 and \ self.block_args.stride == 1 and \
self.block_args.input_filters == self.block_args.output_filters: self.block_args.input_filters == self.block_args.output_filters:
......
...@@ -65,11 +65,12 @@ def main(): ...@@ -65,11 +65,12 @@ def main():
net = architectures.__dict__[args.model] net = architectures.__dict__[args.model]
model = Net(net, to_static, args.class_dim) model = Net(net, to_static, args.class_dim)
load_dygraph_pretrain( load_dygraph_pretrain(
model.pre_net, model.pre_net,
path=args.pretrained_model, path=args.pretrained_model,
load_static_weights=args.load_static_weights) load_static_weights=args.load_static_weights)
model.eval()
paddle.jit.save(model, args.output_path) paddle.jit.save(model, args.output_path)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册