提交 a0ed3fef 编写于 作者: littletomatodonkey's avatar littletomatodonkey

fix res2net

上级 1921935e
...@@ -18,9 +18,12 @@ from __future__ import print_function ...@@ -18,9 +18,12 @@ from __future__ import print_function
import numpy as np import numpy as np
import paddle import paddle
import paddle.fluid as fluid from paddle import ParamAttr
from paddle.fluid.param_attr import ParamAttr import paddle.nn as nn
from paddle.fluid.dygraph.nn import Conv2D, Pool2D, BatchNorm, Linear, Dropout import paddle.nn.functional as F
from paddle.nn import Conv2d, BatchNorm, Linear, Dropout
from paddle.nn import AdaptiveAvgPool2d, MaxPool2d, AvgPool2d
from paddle.nn.initializer import Uniform
import math import math
...@@ -31,7 +34,7 @@ __all__ = [ ...@@ -31,7 +34,7 @@ __all__ = [
] ]
class ConvBNLayer(fluid.dygraph.Layer): class ConvBNLayer(nn.Layer):
def __init__( def __init__(
self, self,
num_channels, num_channels,
...@@ -45,21 +48,17 @@ class ConvBNLayer(fluid.dygraph.Layer): ...@@ -45,21 +48,17 @@ class ConvBNLayer(fluid.dygraph.Layer):
super(ConvBNLayer, self).__init__() super(ConvBNLayer, self).__init__()
self.is_vd_mode = is_vd_mode self.is_vd_mode = is_vd_mode
self._pool2d_avg = Pool2D( self._pool2d_avg = AvgPool2d(
pool_size=2, kernel_size=2, stride=2, padding=0, ceil_mode=True)
pool_stride=2, self._conv = Conv2d(
pool_padding=0, in_channels=num_channels,
pool_type='avg', out_channels=num_filters,
ceil_mode=True) kernel_size=filter_size,
self._conv = Conv2D(
num_channels=num_channels,
num_filters=num_filters,
filter_size=filter_size,
stride=stride, stride=stride,
padding=(filter_size - 1) // 2, padding=(filter_size - 1) // 2,
groups=groups, groups=groups,
act=None, act=None,
param_attr=ParamAttr(name=name + "_weights"), weight_attr=ParamAttr(name=name + "_weights"),
bias_attr=False) bias_attr=False)
if name == "conv1": if name == "conv1":
bn_name = "bn_" + name bn_name = "bn_" + name
...@@ -81,7 +80,7 @@ class ConvBNLayer(fluid.dygraph.Layer): ...@@ -81,7 +80,7 @@ class ConvBNLayer(fluid.dygraph.Layer):
return y return y
class BottleneckBlock(fluid.dygraph.Layer): class BottleneckBlock(nn.Layer):
def __init__(self, def __init__(self,
num_channels1, num_channels1,
num_channels2, num_channels2,
...@@ -112,8 +111,8 @@ class BottleneckBlock(fluid.dygraph.Layer): ...@@ -112,8 +111,8 @@ class BottleneckBlock(fluid.dygraph.Layer):
act='relu', act='relu',
name=name + '_branch2b_' + str(s + 1))) name=name + '_branch2b_' + str(s + 1)))
self.conv1_list.append(conv1) self.conv1_list.append(conv1)
self.pool2d_avg = Pool2D( self.pool2d_avg = AvgPool2d(
pool_size=3, pool_stride=stride, pool_padding=1, pool_type='avg') kernel_size=3, stride=stride, padding=1, ceil_mode=True)
self.conv2 = ConvBNLayer( self.conv2 = ConvBNLayer(
num_channels=num_filters, num_channels=num_filters,
...@@ -135,7 +134,7 @@ class BottleneckBlock(fluid.dygraph.Layer): ...@@ -135,7 +134,7 @@ class BottleneckBlock(fluid.dygraph.Layer):
def forward(self, inputs): def forward(self, inputs):
y = self.conv0(inputs) y = self.conv0(inputs)
xs = fluid.layers.split(y, self.scales, 1) xs = paddle.split(y, self.scales, 1)
ys = [] ys = []
for s, conv1 in enumerate(self.conv1_list): for s, conv1 in enumerate(self.conv1_list):
if s == 0 or self.stride == 2: if s == 0 or self.stride == 2:
...@@ -146,18 +145,18 @@ class BottleneckBlock(fluid.dygraph.Layer): ...@@ -146,18 +145,18 @@ class BottleneckBlock(fluid.dygraph.Layer):
ys.append(xs[-1]) ys.append(xs[-1])
else: else:
ys.append(self.pool2d_avg(xs[-1])) ys.append(self.pool2d_avg(xs[-1]))
conv1 = fluid.layers.concat(ys, axis=1) conv1 = paddle.concat(ys, axis=1)
conv2 = self.conv2(conv1) conv2 = self.conv2(conv1)
if self.shortcut: if self.shortcut:
short = inputs short = inputs
else: else:
short = self.short(inputs) short = self.short(inputs)
y = fluid.layers.elementwise_add(x=short, y=conv2, act='relu') y = paddle.elementwise_add(x=short, y=conv2, act='relu')
return y return y
class Res2Net_vd(fluid.dygraph.Layer): class Res2Net_vd(nn.Layer):
def __init__(self, layers=50, scales=4, width=26, class_dim=1000): def __init__(self, layers=50, scales=4, width=26, class_dim=1000):
super(Res2Net_vd, self).__init__() super(Res2Net_vd, self).__init__()
...@@ -203,8 +202,8 @@ class Res2Net_vd(fluid.dygraph.Layer): ...@@ -203,8 +202,8 @@ class Res2Net_vd(fluid.dygraph.Layer):
stride=1, stride=1,
act='relu', act='relu',
name="conv1_3") name="conv1_3")
self.pool2d_max = Pool2D( self.pool2d_max = MaxPool2d(
pool_size=3, pool_stride=2, pool_padding=1, pool_type='max') kernel_size=3, stride=2, padding=1, ceil_mode=True)
self.block_list = [] self.block_list = []
for block in range(len(depth)): for block in range(len(depth)):
...@@ -232,8 +231,7 @@ class Res2Net_vd(fluid.dygraph.Layer): ...@@ -232,8 +231,7 @@ class Res2Net_vd(fluid.dygraph.Layer):
self.block_list.append(bottleneck_block) self.block_list.append(bottleneck_block)
shortcut = True shortcut = True
self.pool2d_avg = Pool2D( self.pool2d_avg = AdaptiveAvgPool2d(1)
pool_size=7, pool_type='avg', global_pooling=True)
self.pool2d_avg_channels = num_channels[-1] * 2 self.pool2d_avg_channels = num_channels[-1] * 2
...@@ -242,9 +240,8 @@ class Res2Net_vd(fluid.dygraph.Layer): ...@@ -242,9 +240,8 @@ class Res2Net_vd(fluid.dygraph.Layer):
self.out = Linear( self.out = Linear(
self.pool2d_avg_channels, self.pool2d_avg_channels,
class_dim, class_dim,
param_attr=ParamAttr( weight_attr=ParamAttr(
initializer=fluid.initializer.Uniform(-stdv, stdv), initializer=Uniform(-stdv, stdv), name="fc_weights"),
name="fc_weights"),
bias_attr=ParamAttr(name="fc_offset")) bias_attr=ParamAttr(name="fc_offset"))
def forward(self, inputs): def forward(self, inputs):
...@@ -255,7 +252,7 @@ class Res2Net_vd(fluid.dygraph.Layer): ...@@ -255,7 +252,7 @@ class Res2Net_vd(fluid.dygraph.Layer):
for block in self.block_list: for block in self.block_list:
y = block(y) y = block(y)
y = self.pool2d_avg(y) y = self.pool2d_avg(y)
y = fluid.layers.reshape(y, shape=[-1, self.pool2d_avg_channels]) y = paddle.reshape(y, shape=[-1, self.pool2d_avg_channels])
y = self.out(y) y = self.out(y)
return y return y
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册