未验证 提交 c933dcd8 编写于 作者: L littletomatodonkey 提交者: GitHub

minor fix init (#384)

* add densent init

* fix export model
上级 00a0f7fb
......@@ -23,7 +23,7 @@ from .se_resnet_vd import SE_ResNet18_vd, SE_ResNet34_vd, SE_ResNet50_vd, SE_Res
from .se_resnext_vd import SE_ResNeXt50_vd_32x4d, SE_ResNeXt50_vd_32x4d, SENet154_vd
from .se_resnext import SE_ResNeXt50_32x4d, SE_ResNeXt101_32x4d, SE_ResNeXt152_64x4d
from .dpn import DPN68, DPN92, DPN98, DPN107, DPN131
from .densenet import DenseNet121
from .densenet import DenseNet121, DenseNet161, DenseNet169, DenseNet201, DenseNet264
from .hrnet import HRNet_W18_C, HRNet_W30_C, HRNet_W32_C, HRNet_W40_C, HRNet_W44_C, HRNet_W48_C, HRNet_W60_C, HRNet_W64_C, SE_HRNet_W18_C, SE_HRNet_W30_C, SE_HRNet_W32_C, SE_HRNet_W40_C, SE_HRNet_W44_C, SE_HRNet_W48_C, SE_HRNet_W60_C, SE_HRNet_W64_C
from .efficientnet import EfficientNetB0, EfficientNetB1, EfficientNetB2, EfficientNetB3, EfficientNetB4, EfficientNetB5, EfficientNetB6, EfficientNetB7
from .resnest import ResNeSt50_fast_1s1x64d, ResNeSt50
......
......@@ -243,11 +243,12 @@ inp_shape = {
def _drop_connect(inputs, prob, is_test):
if is_test:
return inputs
keep_prob = 1.0 - prob
inputs_shape = paddle.shape(inputs)
random_tensor = keep_prob + paddle.rand(shape=[inputs_shape[0], 1, 1, 1])
binary_tensor = paddle.floor(random_tensor)
output = inputs / keep_prob * binary_tensor
output = paddle.multiply(inputs, binary_tensor) / keep_prob
return output
......@@ -507,7 +508,8 @@ class SEBlock(nn.Layer):
x = self._pool(inputs)
x = self._conv1(x)
x = self._conv2(x)
return paddle.multiply(inputs, x)
out = paddle.multiply(inputs, x)
return out
class MbConvBlock(nn.Layer):
......@@ -572,11 +574,13 @@ class MbConvBlock(nn.Layer):
if self.expand_ratio != 1:
x = self._ecn(x)
x = F.swish(x)
x = self._dcn(x)
x = F.swish(x)
if self.has_se:
x = self._se(x)
x = self._pcn(x)
if self.id_skip and \
self.block_args.stride == 1 and \
self.block_args.input_filters == self.block_args.output_filters:
......
......@@ -65,11 +65,12 @@ def main():
net = architectures.__dict__[args.model]
model = Net(net, to_static, args.class_dim)
load_dygraph_pretrain(
model.pre_net,
path=args.pretrained_model,
load_static_weights=args.load_static_weights)
model.eval()
paddle.jit.save(model, args.output_path)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册