Commit 9ecc07df authored by littletomatodonkey

fix squeezenet, vgg and shufflenet

Parent 251e47c1
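This commit ports the three architectures from the legacy paddle.fluid dygraph API to the paddle.nn API: Conv2D/Pool2D from paddle.fluid.dygraph.nn become paddle.nn.Conv2d, MaxPool2d and AdaptiveAvgPool2d, param_attr becomes weight_attr, the fused act="relu" argument is dropped, and activations move into forward() via paddle.nn.functional. A minimal sketch of that recurring pattern, assuming the paddle 2.0-beta spelling used in this diff (later Paddle releases spell it Conv2D); the ConvReLU helper and its argument names are illustrative, not part of the commit:

import paddle.nn as nn
import paddle.nn.functional as F
from paddle import ParamAttr


class ConvReLU(nn.Layer):
    # Hypothetical helper mirroring the migration pattern in the diff below.
    def __init__(self, in_c, out_c, k, name):
        super(ConvReLU, self).__init__()
        # fluid style:     Conv2D(num_channels, num_filters, filter_size,
        #                         act="relu", param_attr=...)
        # paddle.nn style: explicit channel/kernel names, weight_attr, no act
        self._conv = nn.Conv2d(
            in_channels=in_c,
            out_channels=out_c,
            kernel_size=k,
            weight_attr=ParamAttr(name=name + "_weights"),
            bias_attr=False)

    def forward(self, x):
        # the activation Conv2D used to fuse is now applied explicitly
        return F.relu(self._conv(x))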
@@ -35,5 +35,8 @@ from .alexnet import AlexNet
 from .inception_v4 import InceptionV4
 from .xception_deeplab import Xception41_deeplab, Xception65_deeplab, Xception71_deeplab
 from .resnext101_wsl import ResNeXt101_32x8d_wsl, ResNeXt101_32x16d_wsl, ResNeXt101_32x32d_wsl, ResNeXt101_32x48d_wsl
+from .shufflenet_v2 import ShuffleNetV2_x0_25, ShuffleNetV2_x0_33, ShuffleNetV2_x0_5, ShuffleNetV2, ShuffleNetV2_x1_5, ShuffleNetV2_x2_0, ShuffleNetV2_swish
+from .squeezenet import SqueezeNet1_0, SqueezeNet1_1
+from .vgg import VGG11, VGG13, VGG16, VGG19
 from .distillation_models import ResNet50_vd_distill_MobileNetV3_large_x1_0
@@ -18,15 +18,17 @@ from __future__ import print_function
 import numpy as np
 import paddle
-import paddle.fluid as fluid
-from paddle.fluid.param_attr import ParamAttr
-from paddle.fluid.dygraph.nn import Conv2D, Pool2D, BatchNorm, Linear, Dropout
-from paddle.fluid.initializer import MSRA
+from paddle import ParamAttr
+import paddle.nn as nn
+import paddle.nn.functional as F
+from paddle.nn import Conv2d, BatchNorm, Linear, Dropout
+from paddle.nn import AdaptiveAvgPool2d, MaxPool2d, AvgPool2d
+from paddle.nn.initializer import MSRA
 import math
 
 __all__ = [
     "ShuffleNetV2_x0_25", "ShuffleNetV2_x0_33", "ShuffleNetV2_x0_5",
-    "ShuffleNetV2_x1_0", "ShuffleNetV2_x1_5", "ShuffleNetV2_x2_0",
+    "ShuffleNetV2", "ShuffleNetV2_x1_5", "ShuffleNetV2_x2_0",
     "ShuffleNetV2_swish"
 ]
@@ -37,17 +39,16 @@ def channel_shuffle(x, groups):
     channels_per_group = num_channels // groups
 
     # reshape
-    x = fluid.layers.reshape(
+    x = paddle.reshape(
         x=x, shape=[batchsize, groups, channels_per_group, height, width])
-    x = fluid.layers.transpose(x=x, perm=[0, 2, 1, 3, 4])
+    x = paddle.transpose(x=x, perm=[0, 2, 1, 3, 4])
 
     # flatten
-    x = fluid.layers.reshape(
-        x=x, shape=[batchsize, num_channels, height, width])
+    x = paddle.reshape(x=x, shape=[batchsize, num_channels, height, width])
     return x
 
 
-class ConvBNLayer(fluid.dygraph.Layer):
+class ConvBNLayer(nn.Layer):
     def __init__(self,
                  num_channels,
                  filter_size,
@@ -58,24 +59,21 @@ class ConvBNLayer(fluid.dygraph.Layer):
                  num_groups=1,
                  if_act=True,
                  act='relu',
-                 name=None,
-                 use_cudnn=True):
+                 name=None):
         super(ConvBNLayer, self).__init__()
         self._if_act = if_act
         assert act in ['relu', 'swish'], \
             "supported act are {} but your act is {}".format(
                 ['relu', 'swish'], act)
         self._act = act
-        self._conv = Conv2D(
-            num_channels=num_channels,
-            num_filters=num_filters,
-            filter_size=filter_size,
+        self._conv = Conv2d(
+            in_channels=num_channels,
+            out_channels=num_filters,
+            kernel_size=filter_size,
             stride=stride,
             padding=padding,
             groups=num_groups,
-            act=None,
-            use_cudnn=use_cudnn,
-            param_attr=ParamAttr(
+            weight_attr=ParamAttr(
                 initializer=MSRA(), name=name + "_weights"),
             bias_attr=False)
@@ -90,12 +88,11 @@ class ConvBNLayer(fluid.dygraph.Layer):
         y = self._conv(inputs)
         y = self._batch_norm(y)
         if self._if_act:
-            y = fluid.layers.relu(
-                y) if self._act == 'relu' else fluid.layers.swish(y)
+            y = F.relu(y) if self._act == 'relu' else F.swish(y)
         return y
 
 
-class InvertedResidualUnit(fluid.dygraph.Layer):
+class InvertedResidualUnit(nn.Layer):
     def __init__(self,
                  num_channels,
                  num_filters,
@@ -130,7 +127,6 @@ class InvertedResidualUnit(fluid.dygraph.Layer):
                 num_groups=oup_inc,
                 if_act=False,
                 act=act,
-                use_cudnn=False,
                 name='stage_' + name + '_conv2')
             self._conv_linear = ConvBNLayer(
                 num_channels=oup_inc,
@@ -153,7 +149,6 @@ class InvertedResidualUnit(fluid.dygraph.Layer):
                 num_groups=inp,
                 if_act=False,
                 act=act,
-                use_cudnn=False,
                 name='stage_' + name + '_conv4')
             self._conv_linear_1 = ConvBNLayer(
                 num_channels=inp,
@@ -185,7 +180,6 @@ class InvertedResidualUnit(fluid.dygraph.Layer):
                 num_groups=oup_inc,
                 if_act=False,
                 act=act,
-                use_cudnn=False,
                 name='stage_' + name + '_conv2')
             self._conv_linear_2 = ConvBNLayer(
                 num_channels=oup_inc,
@@ -200,14 +194,14 @@ class InvertedResidualUnit(fluid.dygraph.Layer):
     def forward(self, inputs):
         if self.benchmodel == 1:
-            x1, x2 = fluid.layers.split(
+            x1, x2 = paddle.split(
                 inputs,
                 num_or_sections=[inputs.shape[1] // 2, inputs.shape[1] // 2],
-                dim=1)
+                axis=1)
             x2 = self._conv_pw(x2)
             x2 = self._conv_dw(x2)
             x2 = self._conv_linear(x2)
-            out = fluid.layers.concat([x1, x2], axis=1)
+            out = paddle.concat([x1, x2], axis=1)
         else:
             x1 = self._conv_dw_1(inputs)
             x1 = self._conv_linear_1(x1)
@@ -215,12 +209,12 @@ class InvertedResidualUnit(fluid.dygraph.Layer):
             x2 = self._conv_pw_2(inputs)
             x2 = self._conv_dw_2(x2)
             x2 = self._conv_linear_2(x2)
-            out = fluid.layers.concat([x1, x2], axis=1)
+            out = paddle.concat([x1, x2], axis=1)
 
         return channel_shuffle(out, 2)
 
 
-class ShuffleNet(fluid.dygraph.Layer):
+class ShuffleNet(nn.Layer):
     def __init__(self, class_dim=1000, scale=1.0, act='relu'):
         super(ShuffleNet, self).__init__()
         self.scale = scale
@@ -252,8 +246,7 @@ class ShuffleNet(fluid.dygraph.Layer):
             if_act=True,
             act=act,
             name='stage1_conv')
-        self._max_pool = Pool2D(
-            pool_type='max', pool_size=3, pool_stride=2, pool_padding=1)
+        self._max_pool = MaxPool2d(kernel_size=3, stride=2, padding=1)
 
         # 2. bottleneck sequences
         self._block_list = []
@@ -298,13 +291,13 @@ class ShuffleNet(fluid.dygraph.Layer):
             name='conv5')
         # 4. pool
-        self._pool2d_avg = Pool2D(pool_type='avg', global_pooling=True)
+        self._pool2d_avg = AdaptiveAvgPool2d(1)
         self._out_c = stage_out_channels[-1]
         # 5. fc
         self._fc = Linear(
             stage_out_channels[-1],
             class_dim,
-            param_attr=ParamAttr(name='fc6_weights'),
+            weight_attr=ParamAttr(name='fc6_weights'),
             bias_attr=ParamAttr(name='fc6_offset'))
 
     def forward(self, inputs):
@@ -314,7 +307,7 @@ class ShuffleNet(fluid.dygraph.Layer):
             y = inv(y)
         y = self._last_conv(y)
         y = self._pool2d_avg(y)
-        y = fluid.layers.reshape(y, shape=[-1, self._out_c])
+        y = paddle.reshape(y, shape=[-1, self._out_c])
         y = self._fc(y)
         return y
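Note that channel_shuffle above only swaps fluid.layers.reshape/transpose for paddle.reshape/transpose; the reshape-transpose-reshape trick itself is unchanged. A small illustrative check of what it computes (the toy tensor is mine, not from the diff), assuming a Paddle build where dygraph mode is active:

import paddle

# 8 channels labelled 0..7, shuffled with groups=2
x = paddle.reshape(paddle.arange(8, dtype="float32"), [1, 8, 1, 1])
y = paddle.reshape(x, [1, 2, 4, 1, 1])         # split channels into 2 groups
y = paddle.transpose(y, perm=[0, 2, 1, 3, 4])  # swap the group and channel axes
y = paddle.reshape(y, [1, 8, 1, 1])            # flatten back
print(y.numpy().reshape(-1))                   # [0. 4. 1. 5. 2. 6. 3. 7.]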
 import paddle
-import paddle.fluid as fluid
-from paddle.fluid.param_attr import ParamAttr
-from paddle.fluid.dygraph.nn import Conv2D, Pool2D, BatchNorm, Linear, Dropout
+from paddle import ParamAttr
+import paddle.nn as nn
+import paddle.nn.functional as F
+from paddle.nn import Conv2d, BatchNorm, Linear, Dropout
+from paddle.nn import AdaptiveAvgPool2d, MaxPool2d, AvgPool2d
 
 __all__ = ["SqueezeNet1_0", "SqueezeNet1_1"]
 
 
-class MakeFireConv(fluid.dygraph.Layer):
+class MakeFireConv(nn.Layer):
     def __init__(self,
                  input_channels,
                  output_channels,
                  filter_size,
                  padding=0,
                  name=None):
         super(MakeFireConv, self).__init__()
-        self._conv = Conv2D(input_channels,
-                            output_channels,
-                            filter_size,
-                            padding=padding,
-                            act="relu",
-                            param_attr=ParamAttr(name=name + "_weights"),
-                            bias_attr=ParamAttr(name=name + "_offset"))
+        self._conv = Conv2d(
+            input_channels,
+            output_channels,
+            filter_size,
+            padding=padding,
+            weight_attr=ParamAttr(name=name + "_weights"),
+            bias_attr=ParamAttr(name=name + "_offset"))
 
-    def forward(self, inputs):
-        return self._conv(inputs)
+    def forward(self, x):
+        x = self._conv(x)
+        x = F.relu(x)
+        return x
 
 
-class MakeFire(fluid.dygraph.Layer):
+class MakeFire(nn.Layer):
     def __init__(self,
                  input_channels,
                  squeeze_channels,
                  expand1x1_channels,
                  expand3x3_channels,
                  name=None):
         super(MakeFire, self).__init__()
-        self._conv = MakeFireConv(input_channels,
-                                  squeeze_channels,
-                                  1,
-                                  name=name + "_squeeze1x1")
-        self._conv_path1 = MakeFireConv(squeeze_channels,
-                                        expand1x1_channels,
-                                        1,
-                                        name=name + "_expand1x1")
-        self._conv_path2 = MakeFireConv(squeeze_channels,
-                                        expand3x3_channels,
-                                        3,
-                                        padding=1,
-                                        name=name + "_expand3x3")
+        self._conv = MakeFireConv(
+            input_channels, squeeze_channels, 1, name=name + "_squeeze1x1")
+        self._conv_path1 = MakeFireConv(
+            squeeze_channels, expand1x1_channels, 1, name=name + "_expand1x1")
+        self._conv_path2 = MakeFireConv(
+            squeeze_channels,
+            expand3x3_channels,
+            3,
+            padding=1,
+            name=name + "_expand3x3")
 
     def forward(self, inputs):
         x = self._conv(inputs)
         x1 = self._conv_path1(x)
         x2 = self._conv_path2(x)
-        return fluid.layers.concat([x1, x2], axis=1)
+        return paddle.concat([x1, x2], axis=1)
 
 
-class SqueezeNet(fluid.dygraph.Layer):
+class SqueezeNet(nn.Layer):
     def __init__(self, version, class_dim=1000):
         super(SqueezeNet, self).__init__()
         self.version = version
 
         if self.version == "1.0":
-            self._conv = Conv2D(3,
-                                96,
-                                7,
-                                stride=2,
-                                act="relu",
-                                param_attr=ParamAttr(name="conv1_weights"),
-                                bias_attr=ParamAttr(name="conv1_offset"))
-            self._pool = Pool2D(pool_size=3,
-                                pool_stride=2,
-                                pool_type="max")
+            self._conv = Conv2d(
+                3,
+                96,
+                7,
+                stride=2,
+                weight_attr=ParamAttr(name="conv1_weights"),
+                bias_attr=ParamAttr(name="conv1_offset"))
+            self._pool = MaxPool2d(kernel_size=3, stride=2, padding=0)
             self._conv1 = MakeFire(96, 16, 64, 64, name="fire2")
             self._conv2 = MakeFire(128, 16, 64, 64, name="fire3")
             self._conv3 = MakeFire(128, 32, 128, 128, name="fire4")
@@ -79,17 +81,15 @@ class SqueezeNet(fluid.dygraph.Layer):
             self._conv8 = MakeFire(512, 64, 256, 256, name="fire9")
         else:
-            self._conv = Conv2D(3,
-                                64,
-                                3,
-                                stride=2,
-                                padding=1,
-                                act="relu",
-                                param_attr=ParamAttr(name="conv1_weights"),
-                                bias_attr=ParamAttr(name="conv1_offset"))
-            self._pool = Pool2D(pool_size=3,
-                                pool_stride=2,
-                                pool_type="max")
+            self._conv = Conv2d(
+                3,
+                64,
+                3,
+                stride=2,
+                padding=1,
+                weight_attr=ParamAttr(name="conv1_weights"),
+                bias_attr=ParamAttr(name="conv1_offset"))
+            self._pool = MaxPool2d(kernel_size=3, stride=2, padding=0)
 
             self._conv1 = MakeFire(64, 16, 64, 64, name="fire2")
             self._conv2 = MakeFire(128, 16, 64, 64, name="fire3")
@@ -102,19 +102,19 @@ class SqueezeNet(fluid.dygraph.Layer):
             self._conv8 = MakeFire(512, 64, 256, 256, name="fire9")
 
         self._drop = Dropout(p=0.5)
-        self._conv9 = Conv2D(512,
-                             class_dim,
-                             1,
-                             act="relu",
-                             param_attr=ParamAttr(name="conv10_weights"),
-                             bias_attr=ParamAttr(name="conv10_offset"))
-        self._avg_pool = Pool2D(pool_type="avg",
-                                global_pooling=True)
+        self._conv9 = Conv2d(
+            512,
+            class_dim,
+            1,
+            weight_attr=ParamAttr(name="conv10_weights"),
+            bias_attr=ParamAttr(name="conv10_offset"))
+        self._avg_pool = AdaptiveAvgPool2d(1)
 
     def forward(self, inputs):
         x = self._conv(inputs)
+        x = F.relu(x)
         x = self._pool(x)
-        if self.version=="1.0":
+        if self.version == "1.0":
             x = self._conv1(x)
             x = self._conv2(x)
             x = self._conv3(x)
@@ -138,14 +138,17 @@ class SqueezeNet(fluid.dygraph.Layer):
             x = self._conv8(x)
         x = self._drop(x)
         x = self._conv9(x)
+        x = F.relu(x)
         x = self._avg_pool(x)
-        x = fluid.layers.squeeze(x, axes=[2,3])
+        x = paddle.squeeze(x, axis=[2, 3])
         return x
 
 
 def SqueezeNet1_0(**args):
     model = SqueezeNet(version="1.0", **args)
     return model
 
 
 def SqueezeNet1_1(**args):
     model = SqueezeNet(version="1.1", **args)
     return model
\ No newline at end of file
 import paddle
-import paddle.fluid as fluid
-from paddle.fluid.param_attr import ParamAttr
-from paddle.fluid.dygraph.nn import Conv2D, Pool2D, BatchNorm, Linear
+from paddle import ParamAttr
+import paddle.nn as nn
+import paddle.nn.functional as F
+from paddle.nn import Conv2d, BatchNorm, Linear, Dropout
+from paddle.nn import AdaptiveAvgPool2d, MaxPool2d, AvgPool2d
 
 __all__ = ["VGG11", "VGG13", "VGG16", "VGG19"]
 
 
-class ConvBlock(fluid.dygraph.Layer):
-    def __init__(self,
-                 input_channels,
-                 output_channels,
-                 groups,
-                 name=None):
+class ConvBlock(nn.Layer):
+    def __init__(self, input_channels, output_channels, groups, name=None):
         super(ConvBlock, self).__init__()
 
         self.groups = groups
-        self._conv_1 = Conv2D(num_channels=input_channels,
-                              num_filters=output_channels,
-                              filter_size=3,
-                              stride=1,
-                              padding=1,
-                              act="relu",
-                              param_attr=ParamAttr(name=name + "1_weights"),
-                              bias_attr=False)
+        self._conv_1 = Conv2d(
+            in_channels=input_channels,
+            out_channels=output_channels,
+            kernel_size=3,
+            stride=1,
+            padding=1,
+            weight_attr=ParamAttr(name=name + "1_weights"),
+            bias_attr=False)
         if groups == 2 or groups == 3 or groups == 4:
-            self._conv_2 = Conv2D(num_channels=output_channels,
-                                  num_filters=output_channels,
-                                  filter_size=3,
-                                  stride=1,
-                                  padding=1,
-                                  act="relu",
-                                  param_attr=ParamAttr(name=name + "2_weights"),
-                                  bias_attr=False)
+            self._conv_2 = Conv2d(
+                in_channels=output_channels,
+                out_channels=output_channels,
+                kernel_size=3,
+                stride=1,
+                padding=1,
+                weight_attr=ParamAttr(name=name + "2_weights"),
+                bias_attr=False)
         if groups == 3 or groups == 4:
-            self._conv_3 = Conv2D(num_channels=output_channels,
-                                  num_filters=output_channels,
-                                  filter_size=3,
-                                  stride=1,
-                                  padding=1,
-                                  act="relu",
-                                  param_attr=ParamAttr(name=name + "3_weights"),
-                                  bias_attr=False)
+            self._conv_3 = Conv2d(
+                in_channels=output_channels,
+                out_channels=output_channels,
+                kernel_size=3,
+                stride=1,
+                padding=1,
+                weight_attr=ParamAttr(name=name + "3_weights"),
+                bias_attr=False)
         if groups == 4:
-            self._conv_4 = Conv2D(num_channels=output_channels,
-                                  num_filters=output_channels,
-                                  filter_size=3,
-                                  stride=1,
-                                  padding=1,
-                                  act="relu",
-                                  param_attr=ParamAttr(name=name + "4_weights"),
-                                  bias_attr=False)
-        self._pool = Pool2D(pool_size=2,
-                            pool_type="max",
-                            pool_stride=2)
+            self._conv_4 = Conv2d(
+                in_channels=output_channels,
+                out_channels=output_channels,
+                kernel_size=3,
+                stride=1,
+                padding=1,
+                weight_attr=ParamAttr(name=name + "4_weights"),
+                bias_attr=False)
+
+        self._pool = MaxPool2d(kernel_size=2, stride=2, padding=0)
 
     def forward(self, inputs):
         x = self._conv_1(inputs)
+        x = F.relu(x)
         if self.groups == 2 or self.groups == 3 or self.groups == 4:
             x = self._conv_2(x)
-        if self.groups == 3 or self.groups == 4 :
+            x = F.relu(x)
+        if self.groups == 3 or self.groups == 4:
             x = self._conv_3(x)
+            x = F.relu(x)
         if self.groups == 4:
             x = self._conv_4(x)
+            x = F.relu(x)
         x = self._pool(x)
         return x
 
 
-class VGGNet(fluid.dygraph.Layer):
+class VGGNet(nn.Layer):
     def __init__(self, layers=11, class_dim=1000):
         super(VGGNet, self).__init__()
 
         self.layers = layers
-        self.vgg_configure = {11: [1, 1, 2, 2, 2],
-                              13: [2, 2, 2, 2, 2],
-                              16: [2, 2, 3, 3, 3],
-                              19: [2, 2, 4, 4, 4]}
+        self.vgg_configure = {
+            11: [1, 1, 2, 2, 2],
+            13: [2, 2, 2, 2, 2],
+            16: [2, 2, 3, 3, 3],
+            19: [2, 2, 4, 4, 4]
+        }
         assert self.layers in self.vgg_configure.keys(), \
-            "supported layers are {} but input layer is {}".format(vgg_configure.keys(), layers)
+            "supported layers are {} but input layer is {}".format(
+                vgg_configure.keys(), layers)
         self.groups = self.vgg_configure[self.layers]
 
         self._conv_block_1 = ConvBlock(3, 64, self.groups[0], name="conv1_")
@@ -83,21 +89,22 @@ class VGGNet(fluid.dygraph.Layer):
         self._conv_block_4 = ConvBlock(256, 512, self.groups[3], name="conv4_")
         self._conv_block_5 = ConvBlock(512, 512, self.groups[4], name="conv5_")
 
-        self._drop = fluid.dygraph.Dropout(p=0.5)
-        self._fc1 = Linear(input_dim=7*7*512,
-                           output_dim=4096,
-                           act="relu",
-                           param_attr=ParamAttr(name="fc6_weights"),
-                           bias_attr=ParamAttr(name="fc6_offset"))
-        self._fc2 = Linear(input_dim=4096,
-                           output_dim=4096,
-                           act="relu",
-                           param_attr=ParamAttr(name="fc7_weights"),
-                           bias_attr=ParamAttr(name="fc7_offset"))
-        self._out = Linear(input_dim=4096,
-                           output_dim=class_dim,
-                           param_attr=ParamAttr(name="fc8_weights"),
-                           bias_attr=ParamAttr(name="fc8_offset"))
+        self._drop = Dropout(p=0.5)
+        self._fc1 = Linear(
+            7 * 7 * 512,
+            4096,
+            weight_attr=ParamAttr(name="fc6_weights"),
+            bias_attr=ParamAttr(name="fc6_offset"))
+        self._fc2 = Linear(
+            4096,
+            4096,
+            weight_attr=ParamAttr(name="fc7_weights"),
+            bias_attr=ParamAttr(name="fc7_offset"))
+        self._out = Linear(
+            4096,
+            class_dim,
+            weight_attr=ParamAttr(name="fc8_weights"),
+            bias_attr=ParamAttr(name="fc8_offset"))
 
     def forward(self, inputs):
         x = self._conv_block_1(inputs)
@@ -106,26 +113,32 @@ class VGGNet(fluid.dygraph.Layer):
         x = self._conv_block_4(x)
         x = self._conv_block_5(x)
 
-        x = fluid.layers.reshape(x, [0,-1])
+        x = paddle.reshape(x, [0, -1])
         x = self._fc1(x)
+        x = F.relu(x)
         x = self._drop(x)
         x = self._fc2(x)
+        x = F.relu(x)
         x = self._drop(x)
         x = self._out(x)
         return x
 
 
 def VGG11(**args):
     model = VGGNet(layers=11, **args)
     return model
 
 
 def VGG13(**args):
     model = VGGNet(layers=13, **args)
     return model
 
 
 def VGG16(**args):
     model = VGGNet(layers=16, **args)
     return model
 
 
 def VGG19(**args):
     model = VGGNet(layers=19, **args)
     return model
\ No newline at end of file
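With the __init__.py hunk at the top of this commit, the ported models are exposed from the architectures package. A minimal smoke test, assuming dygraph mode is active and that VGG16 is imported from the vgg module shown above (the exact package path is not shown in this diff, so the import line is a placeholder); the 224x224 input size follows from the 7 * 7 * 512 flatten in VGGNet:

import paddle

from vgg import VGG16  # placeholder import path; adjust to the real package

# paddle.disable_static() may be needed on builds where static mode is the default
model = VGG16(class_dim=1000)
x = paddle.rand([1, 3, 224, 224])
y = model(x)
print(y.shape)  # [1, 1000]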