提交 19472373 编写于 作者: S shippingwang

fix

上级 d4b42a38
#copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import math
import numpy as np
import paddle
import paddle.fluid as fluid
from paddle.fluid.param_attr import ParamAttr
from paddle.fluid.layer_helper import LayerHelper
from paddle.fluid.dygraph.nn import Conv2D, Pool2D, BatchNorm, Linear, Dropout
import math
__all__ = [
"ResNet", "ResNet18", "ResNet34", "ResNet50", "ResNet101", "ResNet152"
]
__all__ = ["ResNet18", "ResNet34", "ResNet50", "ResNet101", "ResNet152"]
class ResNet():
def __init__(self, layers=50):
self.layers = layers
class ConvBNLayer(fluid.dygraph.Layer):
def __init__(self,
num_channels,
num_filters,
filter_size,
stride=1,
groups=1,
act=None,
name=None):
super(ConvBNLayer, self).__init__()
self._conv = Conv2D(
num_channels=num_channels,
num_filters=num_filters,
filter_size=filter_size,
stride=stride,
padding=(filter_size - 1) // 2,
groups=groups,
act=None,
param_attr=ParamAttr(name=name + "_weights"),
bias_attr=False)
if name == "conv1":
bn_name = "bn_" + name
else:
bn_name = "bn" + name[3:]
self._batch_norm = BatchNorm(
num_filters,
act=act,
param_attr=ParamAttr(name=bn_name + "_scale"),
bias_attr=ParamAttr(bn_name + "_offset"),
moving_mean_name=bn_name + "_mean",
moving_variance_name=bn_name + "_variance")
def forward(self, inputs):
y = self._conv(inputs)
y = self._batch_norm(y)
return y
class BottleneckBlock(fluid.dygraph.Layer):
def __init__(self,
num_channels,
num_filters,
stride,
shortcut=True,
name=None):
super(BottleneckBlock, self).__init__()
self.conv0 = ConvBNLayer(
num_channels=num_channels,
num_filters=num_filters,
filter_size=1,
act="relu",
name=name + "_branch2a")
self.conv1 = ConvBNLayer(
num_channels=num_filters,
num_filters=num_filters,
filter_size=3,
stride=stride,
act="relu",
name=name + "_branch2b")
self.conv2 = ConvBNLayer(
num_channels=num_filters,
num_filters=num_filters * 4,
filter_size=1,
act=None,
name=name + "_branch2c")
if not shortcut:
self.short = ConvBNLayer(
num_channels=num_channels,
num_filters=num_filters * 4,
filter_size=1,
stride=stride,
name=name + "_branch1")
self.shortcut = shortcut
self._num_channels_out = num_filters * 4
def forward(self, inputs):
y = self.conv0(inputs)
conv1 = self.conv1(y)
conv2 = self.conv2(conv1)
if self.shortcut:
short = inputs
else:
short = self.short(inputs)
y = fluid.layers.elementwise_add(x=short, y=conv2)
layer_helper = LayerHelper(self.full_name(), act="relu")
return layer_helper.append_activation(y)
class BasicBlock(fluid.dygraph.Layer):
def __init__(self,
num_channels,
num_filters,
stride,
shortcut=True,
name=None):
super(BasicBlock, self).__init__()
self.stride = stride
self.conv0 = ConvBNLayer(
num_channels=num_channels,
num_filters=num_filters,
filter_size=3,
stride=stride,
act="relu",
name=name + "_branch2a")
self.conv1 = ConvBNLayer(
num_channels=num_filters,
num_filters=num_filters,
filter_size=3,
act=None,
name=name + "_branch2b")
if not shortcut:
self.short = ConvBNLayer(
num_channels=num_channels,
num_filters=num_filters,
filter_size=1,
stride=stride,
name=name + "_branch1")
self.shortcut = shortcut
def forward(self, inputs):
y = self.conv0(inputs)
conv1 = self.conv1(y)
if self.shortcut:
short = inputs
else:
short = self.short(inputs)
y = fluid.layers.elementwise_add(x=short, y=conv1)
def net(self, input, class_dim=1000, data_format="NCHW"):
layers = self.layers
layer_helper = LayerHelper(self.full_name(), act="relu")
return layer_helper.append_activation(y)
class ResNet(fluid.dygraph.Layer):
def __init__(self, layers=50, class_dim=1000):
super(ResNet, self).__init__()
self.layers = layers
supported_layers = [18, 34, 50, 101, 152]
assert layers in supported_layers, \
"supported layers are {} but input layer is {}".format(supported_layers, layers)
"supported layers are {} but input layer is {}".format(
supported_layers, layers)
if layers == 18:
depth = [2, 2, 2, 2]
......@@ -45,25 +189,24 @@ class ResNet():
depth = [3, 4, 23, 3]
elif layers == 152:
depth = [3, 8, 36, 3]
num_channels = [64, 256, 512,
1024] if layers >= 50 else [64, 64, 128, 256]
num_filters = [64, 128, 256, 512]
conv = self.conv_bn_layer(
input=input,
self.conv = ConvBNLayer(
num_channels=3,
num_filters=64,
filter_size=7,
stride=2,
act='relu',
name="conv1",
data_format=data_format)
conv = fluid.layers.pool2d(
input=conv,
pool_size=3,
pool_stride=2,
pool_padding=1,
pool_type='max',
data_format=data_format)
act="relu",
name="conv1")
self.pool2d_max = Pool2D(
pool_size=3, pool_stride=2, pool_padding=1, pool_type="max")
self.block_list = []
if layers >= 50:
for block in range(len(depth)):
shortcut = False
for i in range(depth[block]):
if layers in [101, 152] and block == 2:
if i == 0:
......@@ -72,169 +215,80 @@ class ResNet():
conv_name = "res" + str(block + 2) + "b" + str(i)
else:
conv_name = "res" + str(block + 2) + chr(97 + i)
conv = self.bottleneck_block(
input=conv,
num_filters=num_filters[block],
stride=2 if i == 0 and block != 0 else 1,
name=conv_name,
data_format=data_format)
bottleneck_block = self.add_sublayer(
conv_name,
BottleneckBlock(
num_channels=num_channels[block]
if i == 0 else num_filters[block] * 4,
num_filters=num_filters[block],
stride=2 if i == 0 and block != 0 else 1,
shortcut=shortcut,
name=conv_name))
self.block_list.append(bottleneck_block)
shortcut = True
else:
for block in range(len(depth)):
shortcut = False
for i in range(depth[block]):
conv_name = "res" + str(block + 2) + chr(97 + i)
conv = self.basic_block(
input=conv,
num_filters=num_filters[block],
stride=2 if i == 0 and block != 0 else 1,
is_first=block == i == 0,
name=conv_name,
data_format=data_format)
pool = fluid.layers.pool2d(
input=conv,
pool_type='avg',
global_pooling=True,
data_format=data_format)
stdv = 1.0 / math.sqrt(pool.shape[1] * 1.0)
out = fluid.layers.fc(
input=pool,
size=class_dim,
param_attr=fluid.param_attr.ParamAttr(
name="fc_0.w_0",
initializer=fluid.initializer.Uniform(-stdv, stdv)),
bias_attr=ParamAttr(name="fc_0.b_0"))
return out
def conv_bn_layer(self,
input,
num_filters,
filter_size,
stride=1,
groups=1,
act=None,
name=None,
data_format='NCHW'):
conv = fluid.layers.conv2d(
input=input,
num_filters=num_filters,
filter_size=filter_size,
stride=stride,
padding=(filter_size - 1) // 2,
groups=groups,
act=None,
param_attr=ParamAttr(name=name + "_weights"),
bias_attr=False,
name=name + '.conv2d.output.1',
data_format=data_format)
basic_block = self.add_sublayer(
conv_name,
BasicBlock(
num_channels=num_channels[block]
if i == 0 else num_filters[block],
num_filters=num_filters[block],
stride=2 if i == 0 and block != 0 else 1,
shortcut=shortcut,
name=conv_name))
self.block_list.append(basic_block)
shortcut = True
if name == "conv1":
bn_name = "bn_" + name
else:
bn_name = "bn" + name[3:]
return fluid.layers.batch_norm(
input=conv,
act=act,
name=bn_name + '.output.1',
param_attr=ParamAttr(name=bn_name + '_scale'),
bias_attr=ParamAttr(bn_name + '_offset'),
moving_mean_name=bn_name + '_mean',
moving_variance_name=bn_name + '_variance',
data_layout=data_format)
def shortcut(self, input, ch_out, stride, is_first, name, data_format):
if data_format == 'NCHW':
ch_in = input.shape[1]
else:
ch_in = input.shape[-1]
if ch_in != ch_out or stride != 1 or is_first == True:
return self.conv_bn_layer(
input, ch_out, 1, stride, name=name, data_format=data_format)
else:
return input
self.pool2d_avg = Pool2D(
pool_size=7, pool_type='avg', global_pooling=True)
def bottleneck_block(self, input, num_filters, stride, name, data_format):
conv0 = self.conv_bn_layer(
input=input,
num_filters=num_filters,
filter_size=1,
act='relu',
name=name + "_branch2a",
data_format=data_format)
conv1 = self.conv_bn_layer(
input=conv0,
num_filters=num_filters,
filter_size=3,
stride=stride,
act='relu',
name=name + "_branch2b",
data_format=data_format)
conv2 = self.conv_bn_layer(
input=conv1,
num_filters=num_filters * 4,
filter_size=1,
act=None,
name=name + "_branch2c",
data_format=data_format)
short = self.shortcut(
input,
num_filters * 4,
stride,
is_first=False,
name=name + "_branch1",
data_format=data_format)
return fluid.layers.elementwise_add(
x=short, y=conv2, act='relu', name=name + ".add.output.5")
def basic_block(self, input, num_filters, stride, is_first, name,
data_format):
conv0 = self.conv_bn_layer(
input=input,
num_filters=num_filters,
filter_size=3,
act='relu',
stride=stride,
name=name + "_branch2a",
data_format=data_format)
conv1 = self.conv_bn_layer(
input=conv0,
num_filters=num_filters,
filter_size=3,
act=None,
name=name + "_branch2b",
data_format=data_format)
short = self.shortcut(
input,
num_filters,
stride,
is_first,
name=name + "_branch1",
data_format=data_format)
return fluid.layers.elementwise_add(x=short, y=conv1, act='relu')
self.pool2d_avg_channels = num_channels[-1] * 2
stdv = 1.0 / math.sqrt(self.pool2d_avg_channels * 1.0)
self.out = Linear(
self.pool2d_avg_channels,
class_dim,
param_attr=ParamAttr(
initializer=fluid.initializer.Uniform(-stdv, stdv),
name="fc_0.w_0"),
bias_attr=ParamAttr(name="fc_0.b_0"))
def forward(self, inputs):
y = self.conv(inputs)
y = self.pool2d_max(y)
for block in self.block_list:
y = block(y)
y = self.pool2d_avg(y)
y = fluid.layers.reshape(y, shape=[-1, self.pool2d_avg_channels])
y = self.out(y)
return y
def ResNet18():
model = ResNet(layers=18)
def ResNet18(**args):
model = ResNet(layers=18, **args)
return model
def ResNet34():
model = ResNet(layers=34)
def ResNet34(**args):
model = ResNet(layers=34, **args)
return model
def ResNet50():
model = ResNet(layers=50)
def ResNet50(**args):
model = ResNet(layers=50, **args)
return model
def ResNet101():
model = ResNet(layers=101)
def ResNet101(**args):
model = ResNet(layers=101, **args)
return model
def ResNet152():
model = ResNet(layers=152)
def ResNet152(**args):
model = ResNet(layers=152, **args)
return model
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册