提交 4ffcb3b1 编写于 作者: Z zhongpu 提交者: hong

fix se_resnet model for dygraph incompatible upgrade (#4111)

* fix se_resnet model for Optimizer and Linear upgrade, test=develop

* polish code style for se_resnet, test=develop
上级 f9243e6a
......@@ -21,7 +21,7 @@ import paddle
import paddle.fluid as fluid
from paddle.fluid import core
from paddle.fluid.layer_helper import LayerHelper
from paddle.fluid.dygraph.nn import Conv2D, Pool2D, BatchNorm, FC
from paddle.fluid.dygraph.nn import Conv2D, Pool2D, BatchNorm, Linear
from paddle.fluid.dygraph.base import to_variable
import sys
import math
......@@ -59,7 +59,7 @@ momentum_rate = 0.9
l2_decay = 1.2e-4
def optimizer_setting(params):
def optimizer_setting(params, parameter_list):
ls = params["learning_strategy"]
if "total_images" not in params:
total_images = 6149
......@@ -75,23 +75,24 @@ def optimizer_setting(params):
learning_rate=fluid.layers.cosine_decay(
learning_rate=lr, step_each_epoch=step, epochs=num_epochs),
momentum=momentum_rate,
regularization=fluid.regularizer.L2Decay(l2_decay))
regularization=fluid.regularizer.L2Decay(l2_decay),
parameter_list=parameter_list)
return optimizer
class ConvBNLayer(fluid.dygraph.Layer):
def __init__(self,
name_scope,
num_channels,
num_filters,
filter_size,
stride=1,
groups=1,
act=None):
super(ConvBNLayer, self).__init__(name_scope)
super(ConvBNLayer, self).__init__()
self._conv = Conv2D(
"conv2d",
num_channels=num_channels,
num_filters=num_filters,
filter_size=filter_size,
stride=stride,
......@@ -101,7 +102,7 @@ class ConvBNLayer(fluid.dygraph.Layer):
bias_attr=False,
param_attr=fluid.ParamAttr(name="weights"))
self._batch_norm = BatchNorm(self.full_name(), num_filters, act=act)
self._batch_norm = BatchNorm(num_filters, act=act)
def forward(self, inputs):
y = self._conv(inputs)
......@@ -111,29 +112,30 @@ class ConvBNLayer(fluid.dygraph.Layer):
class SqueezeExcitation(fluid.dygraph.Layer):
    """Squeeze-and-Excitation block (SE-Net).

    Globally average-pools the input to one value per channel, passes it
    through a two-layer bottleneck (reduce by ``reduction_ratio``, then
    restore), and rescales the input channels by the resulting sigmoid
    gates.

    Args:
        num_channels (int): number of input/output channels.
        reduction_ratio (int): channel reduction factor of the bottleneck.
    """

    def __init__(self, num_channels, reduction_ratio):
        super(SqueezeExcitation, self).__init__()
        self._num_channels = num_channels
        # pool_size is ignored when global_pooling=True; output is N x C x 1 x 1.
        self._pool = Pool2D(pool_size=0, pool_type='avg', global_pooling=True)
        stdv = 1.0 / math.sqrt(num_channels * 1.0)
        self._fc = Linear(
            num_channels,
            num_channels // reduction_ratio,
            param_attr=fluid.ParamAttr(
                initializer=fluid.initializer.Uniform(-stdv, stdv)),
            act='relu')
        # NOTE(review): this init scale hard-codes a reduction ratio of 16
        # instead of using ``reduction_ratio`` — kept as-is to preserve the
        # original behavior, but looks like an upstream inconsistency.
        stdv = 1.0 / math.sqrt(num_channels / 16.0 * 1.0)
        self._excitation = Linear(
            num_channels // reduction_ratio,
            num_channels,
            param_attr=fluid.ParamAttr(
                initializer=fluid.initializer.Uniform(-stdv, stdv)),
            act='sigmoid')

    def forward(self, input):
        # Squeeze: N x C x H x W -> N x C x 1 x 1 -> flatten to N x C for Linear.
        y = self._pool(input)
        y = fluid.layers.reshape(y, shape=[-1, self._num_channels])
        # Excitation: bottleneck FC pair producing per-channel gates in (0, 1).
        y = self._fc(y)
        y = self._excitation(y)
        # Rescale: broadcast the N x C gates over H x W (axis=0 aligns N, C).
        y = fluid.layers.elementwise_mul(x=input, y=y, axis=0)
        return y
......@@ -141,41 +143,39 @@ class SqueezeExcitation(fluid.dygraph.Layer):
class BottleneckBlock(fluid.dygraph.Layer):
    """SE-ResNeXt bottleneck residual block (constructor part).

    Stacks a 1x1 reduce conv, a 3x3 grouped conv (``cardinality`` groups),
    and a 1x1 expand conv (x2 channels), followed by a Squeeze-Excitation
    rescale; builds a projection shortcut when the identity shortcut does
    not match in shape.

    Args:
        num_channels (int): input channel count.
        num_filters (int): base filter count; block outputs ``num_filters * 2``.
        stride (int): stride of the 3x3 conv (2 for downsampling stages).
        cardinality (int): number of groups in the 3x3 conv (ResNeXt width).
        reduction_ratio (int): SE bottleneck reduction factor.
        shortcut (bool): True if the identity shortcut already matches the
            output shape; False adds a 1x1 projection ConvBNLayer.
    """

    def __init__(self,
                 num_channels,
                 num_filters,
                 stride,
                 cardinality,
                 reduction_ratio,
                 shortcut=True):
        super(BottleneckBlock, self).__init__()

        # 1x1 reduce.
        self.conv0 = ConvBNLayer(
            num_channels=num_channels,
            num_filters=num_filters,
            filter_size=1,
            act="relu")
        # 3x3 grouped conv carries the spatial stride.
        self.conv1 = ConvBNLayer(
            num_channels=num_filters,
            num_filters=num_filters,
            filter_size=3,
            stride=stride,
            groups=cardinality,
            act="relu")
        # 1x1 expand; no activation before the SE rescale / residual add.
        self.conv2 = ConvBNLayer(
            num_channels=num_filters,
            num_filters=num_filters * 2,
            filter_size=1,
            act=None)

        self.scale = SqueezeExcitation(
            num_channels=num_filters * 2,
            reduction_ratio=reduction_ratio)

        if not shortcut:
            # Projection shortcut: match channel count (x2) and stride.
            self.short = ConvBNLayer(
                num_channels=num_channels,
                num_filters=num_filters * 2,
                filter_size=1,
                stride=stride)
......@@ -200,8 +200,8 @@ class BottleneckBlock(fluid.dygraph.Layer):
class SeResNeXt(fluid.dygraph.Layer):
def __init__(self, name_scope, layers=50, class_dim=102):
super(SeResNeXt, self).__init__(name_scope)
def __init__(self, layers=50, class_dim=102):
super(SeResNeXt, self).__init__()
self.layers = layers
supported_layers = [50, 101, 152]
......@@ -214,13 +214,12 @@ class SeResNeXt(fluid.dygraph.Layer):
depth = [3, 4, 6, 3]
num_filters = [128, 256, 512, 1024]
self.conv0 = ConvBNLayer(
self.full_name(),
num_channels=3,
num_filters=64,
filter_size=7,
stride=2,
act='relu')
self.pool = Pool2D(
self.full_name(),
pool_size=3,
pool_stride=2,
pool_padding=1,
......@@ -231,13 +230,12 @@ class SeResNeXt(fluid.dygraph.Layer):
depth = [3, 4, 23, 3]
num_filters = [128, 256, 512, 1024]
self.conv0 = ConvBNLayer(
self.full_name(),
num_channels=3,
num_filters=64,
filter_size=7,
stride=2,
act='relu')
self.pool = Pool2D(
self.full_name(),
pool_size=3,
pool_stride=2,
pool_padding=1,
......@@ -248,25 +246,24 @@ class SeResNeXt(fluid.dygraph.Layer):
depth = [3, 8, 36, 3]
num_filters = [128, 256, 512, 1024]
self.conv0 = ConvBNLayer(
self.full_name(),
num_channels=3,
num_filters=64,
filter_size=3,
stride=2,
act='relu')
self.conv1 = ConvBNLayer(
self.full_name(),
num_channels=64,
num_filters=64,
filter_size=3,
stride=1,
act='relu')
self.conv2 = ConvBNLayer(
self.full_name(),
num_channels=64,
num_filters=128,
filter_size=3,
stride=1,
act='relu')
self.pool = Pool2D(
self.full_name(),
pool_size=3,
pool_stride=2,
pool_padding=1,
......@@ -274,13 +271,14 @@ class SeResNeXt(fluid.dygraph.Layer):
self.bottleneck_block_list = []
num_channels = 64
if layers == 152:
num_channels = 128
for block in range(len(depth)):
shortcut = False
for i in range(depth[block]):
bottleneck_block = self.add_sublayer(
'bb_%d_%d' % (block, i),
BottleneckBlock(
self.full_name(),
num_channels=num_channels,
num_filters=num_filters[block],
stride=2 if i == 0 and block != 0 else 1,
......@@ -292,11 +290,13 @@ class SeResNeXt(fluid.dygraph.Layer):
shortcut = True
self.pool2d_avg = Pool2D(
self.full_name(), pool_size=7, pool_type='avg', global_pooling=True)
pool_size=7, pool_type='avg', global_pooling=True)
stdv = 1.0 / math.sqrt(2048 * 1.0)
self.out = FC(self.full_name(),
size=class_dim,
self.pool2d_avg_output = num_filters[len(num_filters) - 1] * 2 * 1 * 1
self.out = Linear(self.pool2d_avg_output,
class_dim,
param_attr=fluid.param_attr.ParamAttr(
initializer=fluid.initializer.Uniform(-stdv, stdv)))
......@@ -306,14 +306,15 @@ class SeResNeXt(fluid.dygraph.Layer):
y = self.pool(y)
elif self.layers == 152:
y = self.conv0(inputs)
y = self.conv1(inputs)
y = self.conv2(inputs)
y = self.conv1(y)
y = self.conv2(y)
y = self.pool(y)
for bottleneck_block in self.bottleneck_block_list:
y = bottleneck_block(y)
y = self.pool2d_avg(y)
y = fluid.layers.dropout(y, dropout_prob=0.5, seed=100)
y = fluid.layers.reshape(y, shape=[-1, self.pool2d_avg_output])
y = self.out(y)
return y
......@@ -383,8 +384,8 @@ def train():
fluid.default_main_program().random_seed = seed
if args.use_data_parallel:
strategy = fluid.dygraph.parallel.prepare_context()
se_resnext = SeResNeXt("se_resnext")
optimizer = optimizer_setting(train_parameters)
se_resnext = SeResNeXt()
optimizer = optimizer_setting(train_parameters, se_resnext.parameters())
if args.use_data_parallel:
se_resnext = fluid.dygraph.parallel.DataParallel(se_resnext,
strategy)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册