提交 9d054b2a 编写于 作者: G gaotingquan 提交者: Tingquan Gao

support NHWC data format

上级 16968200
...@@ -49,7 +49,8 @@ class ConvLayer(nn.Layer): ...@@ -49,7 +49,8 @@ class ConvLayer(nn.Layer):
stride=1, stride=1,
groups=1, groups=1,
act=None, act=None,
name=None): name=None,
data_format="NCHW"):
super(ConvLayer, self).__init__() super(ConvLayer, self).__init__()
self._conv = Conv2D( self._conv = Conv2D(
...@@ -60,7 +61,8 @@ class ConvLayer(nn.Layer): ...@@ -60,7 +61,8 @@ class ConvLayer(nn.Layer):
padding=(filter_size - 1) // 2, padding=(filter_size - 1) // 2,
groups=groups, groups=groups,
weight_attr=ParamAttr(name=name + "_weights"), weight_attr=ParamAttr(name=name + "_weights"),
bias_attr=False) bias_attr=False,
data_format=data_format)
def forward(self, inputs): def forward(self, inputs):
y = self._conv(inputs) y = self._conv(inputs)
...@@ -77,29 +79,50 @@ class Inception(nn.Layer): ...@@ -77,29 +79,50 @@ class Inception(nn.Layer):
filter5R, filter5R,
filter5, filter5,
proj, proj,
name=None): name=None,
data_format="NCHW"):
super(Inception, self).__init__() super(Inception, self).__init__()
self.data_format = data_format
self._conv1 = ConvLayer( self._conv1 = ConvLayer(
input_channels, filter1, 1, name="inception_" + name + "_1x1") input_channels,
filter1,
1,
name="inception_" + name + "_1x1",
data_format=data_format)
self._conv3r = ConvLayer( self._conv3r = ConvLayer(
input_channels, input_channels,
filter3R, filter3R,
1, 1,
name="inception_" + name + "_3x3_reduce") name="inception_" + name + "_3x3_reduce",
data_format=data_format)
self._conv3 = ConvLayer( self._conv3 = ConvLayer(
filter3R, filter3, 3, name="inception_" + name + "_3x3") filter3R,
filter3,
3,
name="inception_" + name + "_3x3",
data_format=data_format)
self._conv5r = ConvLayer( self._conv5r = ConvLayer(
input_channels, input_channels,
filter5R, filter5R,
1, 1,
name="inception_" + name + "_5x5_reduce") name="inception_" + name + "_5x5_reduce",
data_format=data_format)
self._conv5 = ConvLayer( self._conv5 = ConvLayer(
filter5R, filter5, 5, name="inception_" + name + "_5x5") filter5R,
self._pool = MaxPool2D(kernel_size=3, stride=1, padding=1) filter5,
5,
name="inception_" + name + "_5x5",
data_format=data_format)
self._pool = MaxPool2D(
kernel_size=3, stride=1, padding=1, data_format=data_format)
self._convprj = ConvLayer( self._convprj = ConvLayer(
input_channels, proj, 1, name="inception_" + name + "_3x3_proj") input_channels,
proj,
1,
name="inception_" + name + "_3x3_proj",
data_format=data_format)
def forward(self, inputs): def forward(self, inputs):
conv1 = self._conv1(inputs) conv1 = self._conv1(inputs)
...@@ -113,50 +136,142 @@ class Inception(nn.Layer): ...@@ -113,50 +136,142 @@ class Inception(nn.Layer):
pool = self._pool(inputs) pool = self._pool(inputs)
convprj = self._convprj(pool) convprj = self._convprj(pool)
cat = paddle.concat([conv1, conv3, conv5, convprj], axis=1) if self.data_format == "NHWC":
cat = paddle.concat([conv1, conv3, conv5, convprj], axis=3)
else:
cat = paddle.concat([conv1, conv3, conv5, convprj], axis=1)
cat = F.relu(cat) cat = F.relu(cat)
return cat return cat
class GoogLeNetDY(nn.Layer): class GoogLeNetDY(nn.Layer):
def __init__(self, class_num=1000): def __init__(self, class_num=1000, data_format="NCHW"):
super(GoogLeNetDY, self).__init__() super(GoogLeNetDY, self).__init__()
self._conv = ConvLayer(3, 64, 7, 2, name="conv1") self.data_format = data_format
self._pool = MaxPool2D(kernel_size=3, stride=2) self._conv = ConvLayer(
self._conv_1 = ConvLayer(64, 64, 1, name="conv2_1x1") 3, 64, 7, 2, name="conv1", data_format=data_format)
self._conv_2 = ConvLayer(64, 192, 3, name="conv2_3x3") self._pool = MaxPool2D(
kernel_size=3, stride=2, data_format=data_format)
self._conv_1 = ConvLayer(
64, 64, 1, name="conv2_1x1", data_format=data_format)
self._conv_2 = ConvLayer(
64, 192, 3, name="conv2_3x3", data_format=data_format)
self._ince3a = Inception( self._ince3a = Inception(
192, 192, 64, 96, 128, 16, 32, 32, name="ince3a") 192,
192,
64,
96,
128,
16,
32,
32,
name="ince3a",
data_format=data_format)
self._ince3b = Inception( self._ince3b = Inception(
256, 256, 128, 128, 192, 32, 96, 64, name="ince3b") 256,
256,
128,
128,
192,
32,
96,
64,
name="ince3b",
data_format=data_format)
self._ince4a = Inception( self._ince4a = Inception(
480, 480, 192, 96, 208, 16, 48, 64, name="ince4a") 480,
480,
192,
96,
208,
16,
48,
64,
name="ince4a",
data_format=data_format)
self._ince4b = Inception( self._ince4b = Inception(
512, 512, 160, 112, 224, 24, 64, 64, name="ince4b") 512,
512,
160,
112,
224,
24,
64,
64,
name="ince4b",
data_format=data_format)
self._ince4c = Inception( self._ince4c = Inception(
512, 512, 128, 128, 256, 24, 64, 64, name="ince4c") 512,
512,
128,
128,
256,
24,
64,
64,
name="ince4c",
data_format=data_format)
self._ince4d = Inception( self._ince4d = Inception(
512, 512, 112, 144, 288, 32, 64, 64, name="ince4d") 512,
512,
112,
144,
288,
32,
64,
64,
name="ince4d",
data_format=data_format)
self._ince4e = Inception( self._ince4e = Inception(
528, 528, 256, 160, 320, 32, 128, 128, name="ince4e") 528,
528,
256,
160,
320,
32,
128,
128,
name="ince4e",
data_format=data_format)
self._ince5a = Inception( self._ince5a = Inception(
832, 832, 256, 160, 320, 32, 128, 128, name="ince5a") 832,
832,
256,
160,
320,
32,
128,
128,
name="ince5a",
data_format=data_format)
self._ince5b = Inception( self._ince5b = Inception(
832, 832, 384, 192, 384, 48, 128, 128, name="ince5b") 832,
832,
self._pool_5 = AdaptiveAvgPool2D(1) 384,
192,
384,
48,
128,
128,
name="ince5b",
data_format=data_format)
self._pool_5 = AdaptiveAvgPool2D(1, data_format=data_format)
self._drop = Dropout(p=0.4, mode="downscale_in_infer") self._drop = Dropout(p=0.4, mode="downscale_in_infer")
self.flatten = nn.Flatten()
self._fc_out = Linear( self._fc_out = Linear(
1024, 1024,
class_num, class_num,
weight_attr=xavier(1024, 1, "out"), weight_attr=xavier(1024, 1, "out"),
bias_attr=ParamAttr(name="out_offset")) bias_attr=ParamAttr(name="out_offset"))
self._pool_o1 = AvgPool2D(kernel_size=5, stride=3) self._pool_o1 = AvgPool2D(
self._conv_o1 = ConvLayer(512, 128, 1, name="conv_o1") kernel_size=5, stride=3, data_format=data_format)
self._conv_o1 = ConvLayer(
512, 128, 1, name="conv_o1", data_format=data_format)
self._fc_o1 = Linear( self._fc_o1 = Linear(
1152, 1152,
1024, 1024,
...@@ -168,8 +283,10 @@ class GoogLeNetDY(nn.Layer): ...@@ -168,8 +283,10 @@ class GoogLeNetDY(nn.Layer):
class_num, class_num,
weight_attr=xavier(1024, 1, "out1"), weight_attr=xavier(1024, 1, "out1"),
bias_attr=ParamAttr(name="out1_offset")) bias_attr=ParamAttr(name="out1_offset"))
self._pool_o2 = AvgPool2D(kernel_size=5, stride=3) self._pool_o2 = AvgPool2D(
self._conv_o2 = ConvLayer(528, 128, 1, name="conv_o2") kernel_size=5, stride=3, data_format=data_format)
self._conv_o2 = ConvLayer(
528, 128, 1, name="conv_o2", data_format=data_format)
self._fc_o2 = Linear( self._fc_o2 = Linear(
1152, 1152,
1024, 1024,
...@@ -183,6 +300,9 @@ class GoogLeNetDY(nn.Layer): ...@@ -183,6 +300,9 @@ class GoogLeNetDY(nn.Layer):
bias_attr=ParamAttr(name="out2_offset")) bias_attr=ParamAttr(name="out2_offset"))
def forward(self, inputs): def forward(self, inputs):
if self.data_format == "NHWC":
inputs = paddle.transpose(inputs, [0, 2, 3, 1])
inputs.stop_gradient = True
x = self._conv(inputs) x = self._conv(inputs)
x = self._pool(x) x = self._pool(x)
x = self._conv_1(x) x = self._conv_1(x)
...@@ -205,12 +325,12 @@ class GoogLeNetDY(nn.Layer): ...@@ -205,12 +325,12 @@ class GoogLeNetDY(nn.Layer):
x = self._pool_5(ince5b) x = self._pool_5(ince5b)
x = self._drop(x) x = self._drop(x)
x = paddle.squeeze(x, axis=[2, 3]) x = self.flatten(x)
out = self._fc_out(x) out = self._fc_out(x)
x = self._pool_o1(ince4a) x = self._pool_o1(ince4a)
x = self._conv_o1(x) x = self._conv_o1(x)
x = paddle.flatten(x, start_axis=1, stop_axis=-1) x = self.flatten(x)
x = self._fc_o1(x) x = self._fc_o1(x)
x = F.relu(x) x = F.relu(x)
x = self._drop_o1(x) x = self._drop_o1(x)
...@@ -218,7 +338,7 @@ class GoogLeNetDY(nn.Layer): ...@@ -218,7 +338,7 @@ class GoogLeNetDY(nn.Layer):
x = self._pool_o2(ince4d) x = self._pool_o2(ince4d)
x = self._conv_o2(x) x = self._conv_o2(x)
x = paddle.flatten(x, start_axis=1, stop_axis=-1) x = self.flatten(x)
x = self._fc_o2(x) x = self._fc_o2(x)
x = self._drop_o2(x) x = self._drop_o2(x)
out2 = self._out2(x) out2 = self._out2(x)
......
...@@ -58,7 +58,8 @@ class ConvBNLayer(nn.Layer): ...@@ -58,7 +58,8 @@ class ConvBNLayer(nn.Layer):
channels=None, channels=None,
num_groups=1, num_groups=1,
name=None, name=None,
use_cudnn=True): use_cudnn=True,
data_format="NCHW"):
super(ConvBNLayer, self).__init__() super(ConvBNLayer, self).__init__()
self._conv = Conv2D( self._conv = Conv2D(
...@@ -69,14 +70,16 @@ class ConvBNLayer(nn.Layer): ...@@ -69,14 +70,16 @@ class ConvBNLayer(nn.Layer):
padding=padding, padding=padding,
groups=num_groups, groups=num_groups,
weight_attr=ParamAttr(name=name + "_weights"), weight_attr=ParamAttr(name=name + "_weights"),
bias_attr=False) bias_attr=False,
data_format=data_format)
self._batch_norm = BatchNorm( self._batch_norm = BatchNorm(
num_filters, num_filters,
param_attr=ParamAttr(name=name + "_bn_scale"), param_attr=ParamAttr(name=name + "_bn_scale"),
bias_attr=ParamAttr(name=name + "_bn_offset"), bias_attr=ParamAttr(name=name + "_bn_offset"),
moving_mean_name=name + "_bn_mean", moving_mean_name=name + "_bn_mean",
moving_variance_name=name + "_bn_variance") moving_variance_name=name + "_bn_variance",
data_layout=data_format)
def forward(self, inputs, if_act=True): def forward(self, inputs, if_act=True):
y = self._conv(inputs) y = self._conv(inputs)
...@@ -87,8 +90,16 @@ class ConvBNLayer(nn.Layer): ...@@ -87,8 +90,16 @@ class ConvBNLayer(nn.Layer):
class InvertedResidualUnit(nn.Layer): class InvertedResidualUnit(nn.Layer):
def __init__(self, num_channels, num_in_filter, num_filters, stride, def __init__(self,
filter_size, padding, expansion_factor, name): num_channels,
num_in_filter,
num_filters,
stride,
filter_size,
padding,
expansion_factor,
name,
data_format="NCHW"):
super(InvertedResidualUnit, self).__init__() super(InvertedResidualUnit, self).__init__()
num_expfilter = int(round(num_in_filter * expansion_factor)) num_expfilter = int(round(num_in_filter * expansion_factor))
self._expand_conv = ConvBNLayer( self._expand_conv = ConvBNLayer(
...@@ -98,7 +109,8 @@ class InvertedResidualUnit(nn.Layer): ...@@ -98,7 +109,8 @@ class InvertedResidualUnit(nn.Layer):
stride=1, stride=1,
padding=0, padding=0,
num_groups=1, num_groups=1,
name=name + "_expand") name=name + "_expand",
data_format=data_format)
self._bottleneck_conv = ConvBNLayer( self._bottleneck_conv = ConvBNLayer(
num_channels=num_expfilter, num_channels=num_expfilter,
...@@ -108,7 +120,8 @@ class InvertedResidualUnit(nn.Layer): ...@@ -108,7 +120,8 @@ class InvertedResidualUnit(nn.Layer):
padding=padding, padding=padding,
num_groups=num_expfilter, num_groups=num_expfilter,
use_cudnn=False, use_cudnn=False,
name=name + "_dwise") name=name + "_dwise",
data_format=data_format)
self._linear_conv = ConvBNLayer( self._linear_conv = ConvBNLayer(
num_channels=num_expfilter, num_channels=num_expfilter,
...@@ -117,7 +130,8 @@ class InvertedResidualUnit(nn.Layer): ...@@ -117,7 +130,8 @@ class InvertedResidualUnit(nn.Layer):
stride=1, stride=1,
padding=0, padding=0,
num_groups=1, num_groups=1,
name=name + "_linear") name=name + "_linear",
data_format=data_format)
def forward(self, inputs, ifshortcut): def forward(self, inputs, ifshortcut):
y = self._expand_conv(inputs, if_act=True) y = self._expand_conv(inputs, if_act=True)
...@@ -129,7 +143,7 @@ class InvertedResidualUnit(nn.Layer): ...@@ -129,7 +143,7 @@ class InvertedResidualUnit(nn.Layer):
class InvresiBlocks(nn.Layer): class InvresiBlocks(nn.Layer):
def __init__(self, in_c, t, c, n, s, name): def __init__(self, in_c, t, c, n, s, name, data_format="NCHW"):
super(InvresiBlocks, self).__init__() super(InvresiBlocks, self).__init__()
self._first_block = InvertedResidualUnit( self._first_block = InvertedResidualUnit(
...@@ -140,7 +154,8 @@ class InvresiBlocks(nn.Layer): ...@@ -140,7 +154,8 @@ class InvresiBlocks(nn.Layer):
filter_size=3, filter_size=3,
padding=1, padding=1,
expansion_factor=t, expansion_factor=t,
name=name + "_1") name=name + "_1",
data_format=data_format)
self._block_list = [] self._block_list = []
for i in range(1, n): for i in range(1, n):
...@@ -154,7 +169,8 @@ class InvresiBlocks(nn.Layer): ...@@ -154,7 +169,8 @@ class InvresiBlocks(nn.Layer):
filter_size=3, filter_size=3,
padding=1, padding=1,
expansion_factor=t, expansion_factor=t,
name=name + "_" + str(i + 1))) name=name + "_" + str(i + 1),
data_format=data_format))
self._block_list.append(block) self._block_list.append(block)
def forward(self, inputs): def forward(self, inputs):
...@@ -165,10 +181,15 @@ class InvresiBlocks(nn.Layer): ...@@ -165,10 +181,15 @@ class InvresiBlocks(nn.Layer):
class MobileNet(nn.Layer): class MobileNet(nn.Layer):
def __init__(self, class_num=1000, scale=1.0, prefix_name=""): def __init__(self,
class_num=1000,
scale=1.0,
prefix_name="",
data_format="NCHW"):
super(MobileNet, self).__init__() super(MobileNet, self).__init__()
self.scale = scale self.scale = scale
self.class_num = class_num self.class_num = class_num
self.data_format = data_format
bottleneck_params_list = [ bottleneck_params_list = [
(1, 16, 1, 1), (1, 16, 1, 1),
...@@ -186,7 +207,8 @@ class MobileNet(nn.Layer): ...@@ -186,7 +207,8 @@ class MobileNet(nn.Layer):
filter_size=3, filter_size=3,
stride=2, stride=2,
padding=1, padding=1,
name=prefix_name + "conv1_1") name=prefix_name + "conv1_1",
data_format=data_format)
self.block_list = [] self.block_list = []
i = 1 i = 1
...@@ -202,7 +224,8 @@ class MobileNet(nn.Layer): ...@@ -202,7 +224,8 @@ class MobileNet(nn.Layer):
c=int(c * scale), c=int(c * scale),
n=n, n=n,
s=s, s=s,
name=prefix_name + "conv" + str(i))) name=prefix_name + "conv" + str(i),
data_format=data_format))
self.block_list.append(block) self.block_list.append(block)
in_c = int(c * scale) in_c = int(c * scale)
...@@ -213,9 +236,10 @@ class MobileNet(nn.Layer): ...@@ -213,9 +236,10 @@ class MobileNet(nn.Layer):
filter_size=1, filter_size=1,
stride=1, stride=1,
padding=0, padding=0,
name=prefix_name + "conv9") name=prefix_name + "conv9",
data_format=data_format)
self.pool2d_avg = AdaptiveAvgPool2D(1) self.pool2d_avg = AdaptiveAvgPool2D(1, data_format=data_format)
self.out = Linear( self.out = Linear(
self.out_c, self.out_c,
...@@ -224,6 +248,9 @@ class MobileNet(nn.Layer): ...@@ -224,6 +248,9 @@ class MobileNet(nn.Layer):
bias_attr=ParamAttr(name=prefix_name + "fc10_offset")) bias_attr=ParamAttr(name=prefix_name + "fc10_offset"))
def forward(self, inputs): def forward(self, inputs):
if self.data_format == "NHWC":
inputs = paddle.transpose(inputs, [0, 2, 3, 1])
inputs.stop_gradient = True
y = self.conv1(inputs, if_act=True) y = self.conv1(inputs, if_act=True)
for block in self.block_list: for block in self.block_list:
y = block(y) y = block(y)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册