未验证 提交 4970bca0 编写于 作者: C cuicheng01 提交者: GitHub

Update se_resnext_vd.py

上级 b7b5a0c3
......@@ -18,16 +18,18 @@ from __future__ import print_function
import numpy as np
import paddle
import paddle.fluid as fluid
from paddle.fluid.param_attr import ParamAttr
from paddle.fluid.dygraph.nn import Conv2D, Pool2D, BatchNorm, Linear, Dropout
from paddle import ParamAttr
import paddle.nn as nn
from paddle.nn import Conv2d, BatchNorm, Linear, Dropout
from paddle.nn import AdaptiveAvgPool2d, MaxPool2d, AvgPool2d
from paddle.nn.initializer import Uniform
import math
__all__ = ["SE_ResNeXt50_vd_32x4d", "SE_ResNeXt50_vd_32x4d", "SENet154_vd"]
class ConvBNLayer(fluid.dygraph.Layer):
class ConvBNLayer(nn.Layer):
def __init__(
self,
num_channels,
......@@ -37,21 +39,20 @@ class ConvBNLayer(fluid.dygraph.Layer):
groups=1,
is_vd_mode=False,
act=None,
name=None, ):
name=None):
super(ConvBNLayer, self).__init__()
self.is_vd_mode = is_vd_mode
self._pool2d_avg = Pool2D(
pool_size=2, pool_stride=2, pool_padding=0, pool_type='avg')
self._conv = Conv2D(
num_channels=num_channels,
num_filters=num_filters,
filter_size=filter_size,
self._pool2d_avg = AvgPool2d(
kernel_size=2, stride=2, padding=0, ceil_mode=True)
self._conv = Conv2d(
in_channels=num_channels,
out_channels=num_filters,
kernel_size=filter_size,
stride=stride,
padding=(filter_size - 1) // 2,
groups=groups,
act=None,
param_attr=ParamAttr(name=name + "_weights"),
weight_attr=ParamAttr(name=name + "_weights"),
bias_attr=False)
bn_name = name + '_bn'
self._batch_norm = BatchNorm(
......@@ -70,7 +71,7 @@ class ConvBNLayer(fluid.dygraph.Layer):
return y
class BottleneckBlock(fluid.dygraph.Layer):
class BottleneckBlock(nn.Layer):
def __init__(self,
num_channels,
num_filters,
......@@ -106,7 +107,7 @@ class BottleneckBlock(fluid.dygraph.Layer):
num_channels=num_filters * 2 if cardinality == 32 else num_filters,
num_filters=num_filters * 2 if cardinality == 32 else num_filters,
reduction_ratio=reduction_ratio,
name='fc_' + name)
name='fc' + name)
if not shortcut:
self.short = ConvBNLayer(
......@@ -130,15 +131,15 @@ class BottleneckBlock(fluid.dygraph.Layer):
short = inputs
else:
short = self.short(inputs)
y = fluid.layers.elementwise_add(x=short, y=scale, act='relu')
y = paddle.elementwise_add(x=short, y=scale, act='relu')
return y
class SELayer(fluid.dygraph.Layer):
class SELayer(nn.Layer):
def __init__(self, num_channels, num_filters, reduction_ratio, name=None):
super(SELayer, self).__init__()
self.pool2d_gap = Pool2D(pool_type='avg', global_pooling=True)
self.pool2d_gap = AdaptiveAvgPool2d(1)
self._num_channels = num_channels
......@@ -147,34 +148,35 @@ class SELayer(fluid.dygraph.Layer):
self.squeeze = Linear(
num_channels,
med_ch,
act="relu",
param_attr=ParamAttr(
initializer=fluid.initializer.Uniform(-stdv, stdv),
weight_attr=ParamAttr(
initializer=Uniform(-stdv, stdv),
name=name + "_sqz_weights"),
bias_attr=ParamAttr(name=name + '_sqz_offset'))
self.relu = nn.ReLU()
stdv = 1.0 / math.sqrt(med_ch * 1.0)
self.excitation = Linear(
med_ch,
num_filters,
act="sigmoid",
param_attr=ParamAttr(
initializer=fluid.initializer.Uniform(-stdv, stdv),
weight_attr=ParamAttr(
initializer=Uniform(-stdv, stdv),
name=name + "_exc_weights"),
bias_attr=ParamAttr(name=name + '_exc_offset'))
self.sigmoid = nn.Sigmoid()
def forward(self, input):
pool = self.pool2d_gap(input)
pool = fluid.layers.reshape(pool, shape=[-1, self._num_channels])
pool = paddle.reshape(pool, shape=[-1, self._num_channels])
squeeze = self.squeeze(pool)
squeeze = self.relu(squeeze)
excitation = self.excitation(squeeze)
excitation = fluid.layers.reshape(
excitation = self.sigmoid(excitation)
excitation = paddle.reshape(
excitation, shape=[-1, self._num_channels, 1, 1])
out = input * excitation
return out
class ResNeXt(fluid.dygraph.Layer):
class ResNeXt(nn.Layer):
def __init__(self, layers=50, class_dim=1000, cardinality=32):
super(ResNeXt, self).__init__()
......@@ -221,8 +223,7 @@ class ResNeXt(fluid.dygraph.Layer):
act='relu',
name="conv1_3")
self.pool2d_max = Pool2D(
pool_size=3, pool_stride=2, pool_padding=1, pool_type='max')
self.pool2d_max = MaxPool2d(kernel_size=3, stride=2, padding=1)
self.block_list = []
n = 1 if layers == 50 or layers == 101 else 3
......@@ -245,8 +246,7 @@ class ResNeXt(fluid.dygraph.Layer):
self.block_list.append(bottleneck_block)
shortcut = True
self.pool2d_avg = Pool2D(
pool_size=7, pool_type='avg', global_pooling=True)
self.pool2d_avg = AdaptiveAvgPool2d(1)
self.pool2d_avg_channels = num_channels[-1] * 2
......@@ -255,8 +255,8 @@ class ResNeXt(fluid.dygraph.Layer):
self.out = Linear(
self.pool2d_avg_channels,
class_dim,
param_attr=ParamAttr(
initializer=fluid.initializer.Uniform(-stdv, stdv),
weight_attr=ParamAttr(
initializer=Uniform(-stdv, stdv),
name="fc6_weights"),
bias_attr=ParamAttr(name="fc6_offset"))
......@@ -268,7 +268,7 @@ class ResNeXt(fluid.dygraph.Layer):
for block in self.block_list:
y = block(y)
y = self.pool2d_avg(y)
y = fluid.layers.reshape(y, shape=[-1, self.pool2d_avg_channels])
y = paddle.reshape(y, shape=[-1, self.pool2d_avg_channels])
y = self.out(y)
return y
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册