Commit 34edfa05 authored by littletomatodonkey

add dpn, densenet and hrnet dygraph model

Parent 3b93ffa0
#copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import numpy as np
import argparse
import ast
import paddle
import paddle.fluid as fluid
from paddle.fluid.param_attr import ParamAttr
from paddle.fluid.layer_helper import LayerHelper
from paddle.fluid.dygraph.nn import Conv2D, Pool2D, BatchNorm, Linear
from paddle.fluid.dygraph.base import to_variable
from paddle.fluid import framework
import math
import sys
import time
__all__ = [
"DPN",
"DPN68",
"DPN92",
"DPN98",
"DPN107",
"DPN131",
]
class ConvBNLayer(fluid.dygraph.Layer):
def __init__(self,
num_channels,
num_filters,
filter_size,
stride=1,
pad=0,
groups=1,
act="relu",
name=None):
super(ConvBNLayer, self).__init__()
self._conv = Conv2D(
num_channels=num_channels,
num_filters=num_filters,
filter_size=filter_size,
stride=stride,
padding=pad,
groups=groups,
act=None,
param_attr=ParamAttr(name=name + "_weights"),
bias_attr=False)
self._batch_norm = BatchNorm(
num_filters,
act=act,
param_attr=ParamAttr(name=name + '_bn_scale'),
bias_attr=ParamAttr(name + '_bn_offset'),
moving_mean_name=name + '_bn_mean',
moving_variance_name=name + '_bn_variance')
def forward(self, input):
y = self._conv(input)
y = self._batch_norm(y)
return y
class BNACConvLayer(fluid.dygraph.Layer):
def __init__(self,
num_channels,
num_filters,
filter_size,
stride=1,
pad=0,
groups=1,
act="relu",
name=None):
super(BNACConvLayer, self).__init__()
self.num_channels = num_channels
self.name = name
self._batch_norm = BatchNorm(
num_channels,
act=act,
param_attr=ParamAttr(name=name + '_bn_scale'),
bias_attr=ParamAttr(name + '_bn_offset'),
moving_mean_name=name + '_bn_mean',
moving_variance_name=name + '_bn_variance')
self._conv = Conv2D(
num_channels=num_channels,
num_filters=num_filters,
filter_size=filter_size,
stride=stride,
padding=pad,
groups=groups,
act=None,
param_attr=ParamAttr(name=name + "_weights"),
bias_attr=False)
def forward(self, input):
y = self._batch_norm(input)
y = self._conv(y)
return y
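# Note: unlike ConvBNLayer above (conv -> BN -> activation), BNACConvLayer
# applies BN and the activation *before* the convolution -- the
# pre-activation ordering that the dual-path blocks below are built from.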
class DualPathFactory(fluid.dygraph.Layer):
def __init__(self,
num_channels,
num_1x1_a,
num_3x3_b,
num_1x1_c,
inc,
G,
_type='normal',
name=None):
super(DualPathFactory, self).__init__()
self.num_1x1_c = num_1x1_c
self.inc = inc
self.name = name
kw = 3
kh = 3
pw = (kw - 1) // 2
ph = (kh - 1) // 2
# type
if _type == 'proj':
key_stride = 1
self.has_proj = True
elif _type == 'down':
key_stride = 2
self.has_proj = True
elif _type == 'normal':
key_stride = 1
self.has_proj = False
else:
print("not implemented now!!!")
sys.exit(1)
data_in_ch = sum(num_channels) if isinstance(num_channels,
list) else num_channels
if self.has_proj:
self.c1x1_w_func = BNACConvLayer(
num_channels=data_in_ch,
num_filters=num_1x1_c + 2 * inc,
filter_size=(1, 1),
pad=(0, 0),
stride=(key_stride, key_stride),
name=name + "_match")
self.c1x1_a_func = BNACConvLayer(
num_channels=data_in_ch,
num_filters=num_1x1_a,
filter_size=(1, 1),
pad=(0, 0),
name=name + "_conv1")
self.c3x3_b_func = BNACConvLayer(
num_channels=num_1x1_a,
num_filters=num_3x3_b,
filter_size=(kw, kh),
pad=(pw, ph),
stride=(key_stride, key_stride),
groups=G,
name=name + "_conv2")
self.c1x1_c_func = BNACConvLayer(
num_channels=num_3x3_b,
num_filters=num_1x1_c + inc,
filter_size=(1, 1),
pad=(0, 0),
name=name + "_conv3")
def forward(self, input):
# PROJ
if isinstance(input, list):
data_in = fluid.layers.concat([input[0], input[1]], axis=1)
else:
data_in = input
if self.has_proj:
c1x1_w = self.c1x1_w_func(data_in)
data_o1, data_o2 = fluid.layers.split(
c1x1_w, num_or_sections=[self.num_1x1_c, 2 * self.inc], dim=1)
else:
data_o1 = input[0]
data_o2 = input[1]
c1x1_a = self.c1x1_a_func(data_in)
c3x3_b = self.c3x3_b_func(c1x1_a)
c1x1_c = self.c1x1_c_func(c3x3_b)
c1x1_c1, c1x1_c2 = fluid.layers.split(
c1x1_c, num_or_sections=[self.num_1x1_c, self.inc], dim=1)
# OUTPUTS
summ = fluid.layers.elementwise_add(x=data_o1, y=c1x1_c1)
dense = fluid.layers.concat([data_o2, c1x1_c2], axis=1)
# tensor, channels
return [summ, dense]
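# The block returns [summ, dense]: a residual path of num_1x1_c channels and
# a dense path that grows by inc channels per 'normal' block. A minimal
# sketch of that channel bookkeeping (hypothetical sizes; in the real model
# bw and inc come from get_net_args):
def _dual_path_channels_demo(bw=256, inc=16, normal_blocks=3):
    channels = [bw, 3 * inc]        # after the stage's 'proj'/'down' block
    for _ in range(normal_blocks):  # each 'normal' block adds inc dense channels
        channels = [channels[0], channels[1] + inc]
    return channels                 # e.g. [256, 96]; the next block concats sum(channels) inputs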
__all__ = ["DPN", "DPN68", "DPN92", "DPN98", "DPN107", "DPN131"]
class DPN(fluid.dygraph.Layer):
def __init__(self, layers=60, class_dim=1000):
super(DPN, self).__init__()
        self._class_dim = class_dim

        args = self.get_net_args(layers)
        bws = args['bw']
        inc_sec = args['inc_sec']
        rs = args['r']
@@ -45,39 +215,23 @@ class DPN(object):
        init_filter_size = args['init_filter_size']
        init_padding = args['init_padding']
        self.k_sec = k_sec

        # conv1
        self.conv1_x_1_func = ConvBNLayer(
            num_channels=3,
            num_filters=init_num_filter,
            filter_size=3,
            stride=2,
            pad=1,
            act='relu',
            name="conv1")

        self.pool2d_max = Pool2D(
            pool_size=3, pool_stride=2, pool_padding=1, pool_type='max')

        num_channel_dpn = init_num_filter

        self.dpn_func_list = []
        #conv2 - conv5
        match_list, num = [], 0
        for gc in range(4):
@@ -93,43 +247,82 @@ class DPN(object):
                _type2 = 'normal'
                match = match + k_sec[gc - 1]
            match_list.append(match)
self.dpn_func_list.append(
self.add_sublayer(
"dpn{}".format(match),
DualPathFactory(
num_channels=num_channel_dpn,
num_1x1_a=R,
num_3x3_b=R,
num_1x1_c=bw,
inc=inc,
G=G,
_type=_type1,
name="dpn" + str(match))))
num_channel_dpn = [bw, 3 * inc]

            for i_ly in range(2, k_sec[gc] + 1):
                num += 1
                if num in match_list:
                    num += 1
                self.dpn_func_list.append(
                    self.add_sublayer(
                        "dpn{}".format(num),
                        DualPathFactory(
                            num_channels=num_channel_dpn,
                            num_1x1_a=R,
                            num_3x3_b=R,
                            num_1x1_c=bw,
                            inc=inc,
                            G=G,
                            _type=_type2,
                            name="dpn" + str(num))))
                num_channel_dpn = [
                    num_channel_dpn[0], num_channel_dpn[1] + inc
                ]

        out_channel = sum(num_channel_dpn)

        self.conv5_x_x_bn = BatchNorm(
            num_channels=sum(num_channel_dpn),
            act="relu",
            param_attr=ParamAttr(name='final_concat_bn_scale'),
            bias_attr=ParamAttr('final_concat_bn_offset'),
            moving_mean_name='final_concat_bn_mean',
            moving_variance_name='final_concat_bn_variance')

        self.pool2d_avg = Pool2D(pool_type='avg', global_pooling=True)

        stdv = 0.01

        self.out = Linear(
            out_channel,
            class_dim,
            param_attr=ParamAttr(
                initializer=fluid.initializer.Uniform(-stdv, stdv),
                name="fc_weights"),
            bias_attr=ParamAttr(name="fc_offset"))

    def forward(self, input):
        conv1_x_1 = self.conv1_x_1_func(input)
        convX_x_x = self.pool2d_max(conv1_x_1)

        dpn_idx = 0
        for gc in range(4):
            convX_x_x = self.dpn_func_list[dpn_idx](convX_x_x)
            dpn_idx += 1
            for i_ly in range(2, self.k_sec[gc] + 1):
                convX_x_x = self.dpn_func_list[dpn_idx](convX_x_x)
                dpn_idx += 1

        conv5_x_x = fluid.layers.concat(convX_x_x, axis=1)
        conv5_x_x = self.conv5_x_x_bn(conv5_x_x)

        y = self.pool2d_avg(conv5_x_x)
        y = fluid.layers.reshape(y, shape=[0, -1])
        y = self.out(y)
        return y
    def get_net_args(self, layers):
        if layers == 68:
@@ -198,119 +391,6 @@ class DPN(object):
        return net_arg
def dual_path_factory(self,
data,
num_1x1_a,
num_3x3_b,
num_1x1_c,
inc,
G,
_type='normal',
name=None):
kw = 3
kh = 3
pw = (kw - 1) // 2
ph = (kh - 1) // 2
# type
        if _type == 'proj':
            key_stride = 1
            has_proj = True
        elif _type == 'down':
            key_stride = 2
            has_proj = True
        elif _type == 'normal':
            key_stride = 1
            has_proj = False
# PROJ
        if isinstance(data, list):
data_in = fluid.layers.concat([data[0], data[1]], axis=1)
else:
data_in = data
if has_proj:
c1x1_w = self.bn_ac_conv(
data=data_in,
num_filter=(num_1x1_c + 2 * inc),
kernel=(1, 1),
pad=(0, 0),
stride=(key_stride, key_stride),
name=name + "_match")
data_o1, data_o2 = fluid.layers.split(
c1x1_w,
num_or_sections=[num_1x1_c, 2 * inc],
dim=1,
name=name + "_match_conv_Slice")
else:
data_o1 = data[0]
data_o2 = data[1]
# MAIN
c1x1_a = self.bn_ac_conv(
data=data_in,
num_filter=num_1x1_a,
kernel=(1, 1),
pad=(0, 0),
name=name + "_conv1")
c3x3_b = self.bn_ac_conv(
data=c1x1_a,
num_filter=num_3x3_b,
kernel=(kw, kh),
pad=(pw, ph),
stride=(key_stride, key_stride),
num_group=G,
name=name + "_conv2")
c1x1_c = self.bn_ac_conv(
data=c3x3_b,
num_filter=(num_1x1_c + inc),
kernel=(1, 1),
pad=(0, 0),
name=name + "_conv3")
c1x1_c1, c1x1_c2 = fluid.layers.split(
c1x1_c,
num_or_sections=[num_1x1_c, inc],
dim=1,
name=name + "_conv3_Slice")
# OUTPUTS
summ = fluid.layers.elementwise_add(
x=data_o1, y=c1x1_c1, name=name + "_elewise")
dense = fluid.layers.concat(
[data_o2, c1x1_c2], axis=1, name=name + "_concat")
return [summ, dense]
def bn_ac_conv(self,
data,
num_filter,
kernel,
pad,
stride=(1, 1),
num_group=1,
name=None):
bn_ac = fluid.layers.batch_norm(
input=data,
act='relu',
is_test=False,
name=name + '.output.1',
param_attr=ParamAttr(name=name + '_bn_scale'),
bias_attr=ParamAttr(name + '_bn_offset'),
moving_mean_name=name + '_bn_mean',
moving_variance_name=name + '_bn_variance', )
bn_ac_conv = fluid.layers.conv2d(
input=bn_ac,
num_filters=num_filter,
filter_size=kernel,
stride=stride,
padding=pad,
groups=num_group,
act=None,
bias_attr=False,
param_attr=ParamAttr(name=name + "_weights"))
return bn_ac_conv
def DPN68():
    model = DPN(layers=68)
......
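# A minimal forward-pass sketch for the dygraph DPN models (an assumption:
# DPN68 returns the constructed model, as the elided body above suggests):
def _dpn_demo():
    with fluid.dygraph.guard():
        model = DPN68()
        img = np.random.uniform(0, 1, [1, 3, 224, 224]).astype("float32")
        logits = model(to_variable(img))
        return logits.shape  # [1, class_dim], i.e. [1, 1000] by default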
#copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np
import argparse
import ast
import paddle
import paddle.fluid as fluid
from paddle.fluid.param_attr import ParamAttr
from paddle.fluid.layer_helper import LayerHelper
from paddle.fluid.dygraph.nn import Conv2D, Pool2D, BatchNorm, Linear
from paddle.fluid.dygraph.base import to_variable
from paddle.fluid import framework

import math
import sys
import time
__all__ = [
    "HRNet_W18_C",
    "HRNet_W30_C",
    "HRNet_W32_C",
    "HRNet_W40_C",
    "HRNet_W44_C",
    "HRNet_W48_C",
    "HRNet_W60_C",
    "HRNet_W64_C",
    "SE_HRNet_W18_C",
    "SE_HRNet_W30_C",
    "SE_HRNet_W32_C",
    "SE_HRNet_W40_C",
    "SE_HRNet_W44_C",
    "SE_HRNet_W48_C",
    "SE_HRNet_W60_C",
    "SE_HRNet_W64_C",
]
class ConvBNLayer(fluid.dygraph.Layer):
    def __init__(self,
                 num_channels,
                 num_filters,
                 filter_size,
                 stride=1,
                 groups=1,
                 act="relu",
                 name=None):
        super(ConvBNLayer, self).__init__()

        self._conv = Conv2D(
            num_channels=num_channels,
            num_filters=num_filters,
            filter_size=filter_size,
            stride=stride,
            padding=(filter_size - 1) // 2,
            groups=groups,
            act=None,
            param_attr=ParamAttr(name=name + "_weights"),
            bias_attr=False)
        bn_name = name + '_bn'
        self._batch_norm = BatchNorm(
            num_filters,
            act=act,
            param_attr=ParamAttr(name=bn_name + '_scale'),
            bias_attr=ParamAttr(bn_name + '_offset'),
            moving_mean_name=bn_name + '_mean',
            moving_variance_name=bn_name + '_variance')

    def forward(self, input):
        y = self._conv(input)
        y = self._batch_norm(y)
        return y
class Layer1(fluid.dygraph.Layer):
    def __init__(self, num_channels, has_se=False, name=None):
        super(Layer1, self).__init__()

        self.bottleneck_block_list = []
        for i in range(4):
            bottleneck_block = self.add_sublayer(
                "bb_{}_{}".format(name, i + 1),
                BottleneckBlock(
                    num_channels=num_channels if i == 0 else 256,
                    num_filters=64,
                    has_se=has_se,
                    stride=1,
                    downsample=True if i == 0 else False,
                    name=name + '_' + str(i + 1)))
            self.bottleneck_block_list.append(bottleneck_block)

    def forward(self, input):
        conv = input
        for block_func in self.bottleneck_block_list:
            conv = block_func(conv)
        return conv
class TransitionLayer(fluid.dygraph.Layer):
    def __init__(self, in_channels, out_channels, name=None):
        super(TransitionLayer, self).__init__()

        num_in = len(in_channels)
        num_out = len(out_channels)
        self.conv_bn_func_list = []
        for i in range(num_out):
            residual = None
            if i < num_in:
                if in_channels[i] != out_channels[i]:
                    residual = self.add_sublayer(
                        "transition_{}_layer_{}".format(name, i + 1),
                        ConvBNLayer(
                            num_channels=in_channels[i],
                            num_filters=out_channels[i],
                            filter_size=3,
                            name=name + '_layer_' + str(i + 1)))
            else:
                residual = self.add_sublayer(
                    "transition_{}_layer_{}".format(name, i + 1),
                    ConvBNLayer(
                        num_channels=in_channels[-1],
                        num_filters=out_channels[i],
                        filter_size=3,
                        stride=2,
                        name=name + '_layer_' + str(i + 1)))
            self.conv_bn_func_list.append(residual)

    def forward(self, input):
        outs = []
        for idx, conv_bn_func in enumerate(self.conv_bn_func_list):
            if conv_bn_func is None:
                outs.append(input[idx])
            else:
                if idx < len(input):
                    outs.append(conv_bn_func(input[idx]))
                else:
                    outs.append(conv_bn_func(input[-1]))
        return outs
class Branches(fluid.dygraph.Layer):
    def __init__(self,
                 block_num,
                 in_channels,
                 out_channels,
                 has_se=False,
                 name=None):
        super(Branches, self).__init__()

        self.basic_block_list = []

        for i in range(len(out_channels)):
            self.basic_block_list.append([])
            for j in range(block_num):
                in_ch = in_channels[i] if j == 0 else out_channels[i]
                basic_block_func = self.add_sublayer(
                    "bb_{}_branch_layer_{}_{}".format(name, i + 1, j + 1),
                    BasicBlock(
                        num_channels=in_ch,
                        num_filters=out_channels[i],
                        has_se=has_se,
                        name=name + '_branch_layer_' + str(i + 1) + '_' +
                        str(j + 1)))
                self.basic_block_list[i].append(basic_block_func)

    def forward(self, inputs):
        outs = []
        for idx, input in enumerate(inputs):
            conv = input
            for basic_block_func in self.basic_block_list[idx]:
                conv = basic_block_func(conv)
            outs.append(conv)
        return outs
class BottleneckBlock(fluid.dygraph.Layer):
    def __init__(self,
                 num_channels,
                 num_filters,
                 has_se,
                 stride=1,
                 downsample=False,
                 name=None):
        super(BottleneckBlock, self).__init__()

        self.has_se = has_se
        self.downsample = downsample

        self.conv1 = ConvBNLayer(
            num_channels=num_channels,
            num_filters=num_filters,
            filter_size=1,
            act="relu",
            name=name + "_conv1")
        self.conv2 = ConvBNLayer(
            num_channels=num_filters,
            num_filters=num_filters,
            filter_size=3,
            stride=stride,
            act="relu",
            name=name + "_conv2")
        self.conv3 = ConvBNLayer(
            num_channels=num_filters,
            num_filters=num_filters * 4,
            filter_size=1,
            act=None,
            name=name + "_conv3")

        if self.downsample:
            self.conv_down = ConvBNLayer(
                num_channels=num_channels,
                num_filters=num_filters * 4,
                filter_size=1,
                act=None,
                name=name + "_downsample")

        if self.has_se:
            self.se = SELayer(
                num_channels=num_filters * 4,
                num_filters=num_filters * 4,
                reduction_ratio=16,
                name='fc' + name)

    def forward(self, input):
        residual = input
        conv1 = self.conv1(input)
        conv2 = self.conv2(conv1)
        conv3 = self.conv3(conv2)

        if self.downsample:
            residual = self.conv_down(input)

        if self.has_se:
            conv3 = self.se(conv3)

        y = fluid.layers.elementwise_add(x=conv3, y=residual, act="relu")
        return y
class BasicBlock(fluid.dygraph.Layer):
    def __init__(self,
                 num_channels,
                 num_filters,
                 stride=1,
                 has_se=False,
                 downsample=False,
                 name=None):
        super(BasicBlock, self).__init__()

        self.has_se = has_se
        self.downsample = downsample

        self.conv1 = ConvBNLayer(
            num_channels=num_channels,
            num_filters=num_filters,
            filter_size=3,
            stride=stride,
            act="relu",
            name=name + "_conv1")
        self.conv2 = ConvBNLayer(
            num_channels=num_filters,
            num_filters=num_filters,
            filter_size=3,
            stride=1,
            act=None,
            name=name + "_conv2")

        if self.downsample:
            self.conv_down = ConvBNLayer(
                num_channels=num_channels,
                num_filters=num_filters * 4,
                filter_size=1,
                act="relu",
                name=name + "_downsample")

        if self.has_se:
            self.se = SELayer(
                num_channels=num_filters,
                num_filters=num_filters,
                reduction_ratio=16,
                name='fc' + name)

    def forward(self, input):
        residual = input
        conv1 = self.conv1(input)
        conv2 = self.conv2(conv1)

        if self.downsample:
            residual = self.conv_down(input)

        if self.has_se:
            conv2 = self.se(conv2)

        y = fluid.layers.elementwise_add(x=conv2, y=residual, act="relu")
        return y
class SELayer(fluid.dygraph.Layer):
def __init__(self, num_channels, num_filters, reduction_ratio, name=None):
super(SELayer, self).__init__()
self.pool2d_gap = Pool2D(pool_type='avg', global_pooling=True)
self._num_channels = num_channels
med_ch = int(num_channels / reduction_ratio)
stdv = 1.0 / math.sqrt(num_channels * 1.0)
self.squeeze = Linear(
            num_channels,
            med_ch,
            act="relu",
            param_attr=ParamAttr(
                initializer=fluid.initializer.Uniform(-stdv, stdv),
                name=name + "_sqz_weights"),
            bias_attr=ParamAttr(name=name + '_sqz_offset'))

        stdv = 1.0 / math.sqrt(med_ch * 1.0)
        self.excitation = Linear(
            med_ch,
            num_filters,
            act="sigmoid",
            param_attr=ParamAttr(
                initializer=fluid.initializer.Uniform(-stdv, stdv),
                name=name + "_exc_weights"),
            bias_attr=ParamAttr(name=name + '_exc_offset'))

    def forward(self, input):
        pool = self.pool2d_gap(input)
        pool = fluid.layers.reshape(pool, shape=[-1, self._num_channels])
        squeeze = self.squeeze(pool)
        excitation = self.excitation(squeeze)
        excitation = fluid.layers.reshape(
            excitation, shape=[-1, self._num_channels, 1, 1])
        out = input * excitation
        return out
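# Stand-alone sanity check for the squeeze-and-excitation gate above
# (hypothetical sizes; needs a dygraph context to run):
def _se_layer_demo():
    with fluid.dygraph.guard():
        se = SELayer(
            num_channels=64, num_filters=64, reduction_ratio=16,
            name="fc_demo")
        x = to_variable(
            np.random.uniform(0, 1, [2, 64, 7, 7]).astype("float32"))
        y = se(x)
        return y.shape  # [2, 64, 7, 7]: same shape, channels rescaled by a (0, 1) gate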
class Stage(fluid.dygraph.Layer):
def __init__(self,
num_channels,
num_modules,
                 num_filters,
                 has_se=False,
                 multi_scale_output=True,
                 name=None):
        super(Stage, self).__init__()

        self._num_modules = num_modules

        self.stage_func_list = []
        for i in range(num_modules):
            if i == num_modules - 1 and not multi_scale_output:
                stage_func = self.add_sublayer(
                    "stage_{}_{}".format(name, i + 1),
                    HighResolutionModule(
                        num_channels=num_channels,
                        num_filters=num_filters,
                        has_se=has_se,
                        multi_scale_output=False,
                        name=name + '_' + str(i + 1)))
            else:
                stage_func = self.add_sublayer(
                    "stage_{}_{}".format(name, i + 1),
                    HighResolutionModule(
                        num_channels=num_channels,
                        num_filters=num_filters,
                        has_se=has_se,
                        name=name + '_' + str(i + 1)))
            self.stage_func_list.append(stage_func)
def forward(self, input):
out = input
for idx in range(self._num_modules):
out = self.stage_func_list[idx](out)
return out
class HighResolutionModule(fluid.dygraph.Layer):
def __init__(self,
num_channels,
num_filters,
has_se=False,
multi_scale_output=True,
name=None):
super(HighResolutionModule, self).__init__()
self.branches_func = Branches(
block_num=4,
in_channels=num_channels,
out_channels=num_filters,
has_se=has_se,
name=name)
self.fuse_func = FuseLayers(
in_channels=num_filters,
out_channels=num_filters,
multi_scale_output=multi_scale_output,
name=name)
def forward(self, input):
out = self.branches_func(input)
out = self.fuse_func(out)
return out
class FuseLayers(fluid.dygraph.Layer):
def __init__(self,
in_channels,
out_channels,
multi_scale_output=True,
name=None):
super(FuseLayers, self).__init__()
self._actual_ch = len(in_channels) if multi_scale_output else 1
self._in_channels = in_channels
self.residual_func_list = []
for i in range(self._actual_ch):
for j in range(len(in_channels)):
residual_func = None
if j > i:
residual_func = self.add_sublayer(
"residual_{}_layer_{}_{}".format(name, i + 1, j + 1),
ConvBNLayer(
num_channels=in_channels[j],
num_filters=out_channels[i],
filter_size=1,
stride=1,
act=None, act=None,
name=name + '_layer_' + str(i + 1) + '_' +
str(j + 1)))
self.residual_func_list.append(residual_func)
elif j < i:
pre_num_filters = in_channels[j]
for k in range(i - j):
if k == i - j - 1:
residual_func = self.add_sublayer(
"residual_{}_layer_{}_{}_{}".format(
name, i + 1, j + 1, k + 1),
ConvBNLayer(
num_channels=pre_num_filters,
num_filters=out_channels[i],
filter_size=3,
stride=2,
act=None,
name=name + '_layer_' + str(i + 1) + '_' +
str(j + 1) + '_' + str(k + 1)))
pre_num_filters = out_channels[i]
else:
residual_func = self.add_sublayer(
"residual_{}_layer_{}_{}_{}".format(
name, i + 1, j + 1, k + 1),
ConvBNLayer(
num_channels=pre_num_filters,
num_filters=out_channels[j],
filter_size=3,
stride=2,
act="relu",
name=name + '_layer_' + str(i + 1) + '_' +
str(j + 1) + '_' + str(k + 1)))
pre_num_filters = out_channels[j]
self.residual_func_list.append(residual_func)
def forward(self, input):
outs = []
residual_func_idx = 0
for i in range(self._actual_ch):
residual = input[i]
for j in range(len(self._in_channels)):
if j > i:
y = self.residual_func_list[residual_func_idx](input[j])
residual_func_idx += 1
y = fluid.layers.resize_nearest(input=y, scale=2**(j - i))
residual = fluid.layers.elementwise_add(
x=residual, y=y, act=None)
elif j < i:
y = input[j]
for k in range(i - j):
y = self.residual_func_list[residual_func_idx](y)
residual_func_idx += 1
residual = fluid.layers.elementwise_add(
x=residual, y=y, act=None)
layer_helper = LayerHelper(self.full_name(), act='relu')
residual = layer_helper.append_activation(residual)
outs.append(residual)
return outs
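# Fusion bookkeeping in FuseLayers, spelled out: a higher-index (lower
# resolution) branch j > i is projected by a 1x1 conv and upsampled by a
# nearest resize of 2**(j - i); a branch j < i passes through i - j
# stride-2 3x3 convs; j == i contributes as the identity. A small sketch
# matching the loop structure above:
def _fuse_paths_demo(num_branches=4):
    paths = []
    for i in range(num_branches):      # output branch (resolution index)
        for j in range(num_branches):  # input branch
            if j > i:
                paths.append(
                    (i, j, "1x1 conv + nearest resize x{}".format(2 ** (j - i))))
            elif j < i:
                paths.append((i, j, "{} stride-2 conv(s)".format(i - j)))
            else:
                paths.append((i, j, "identity"))
    return paths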
class LastClsOut(fluid.dygraph.Layer):
def __init__(self,
num_channel_list,
has_se,
num_filters_list=[32, 64, 128, 256],
name=None):
super(LastClsOut, self).__init__()
self.func_list = []
for idx in range(len(num_channel_list)):
func = self.add_sublayer(
"conv_{}_conv_{}".format(name, idx + 1),
BottleneckBlock(
num_channels=num_channel_list[idx],
num_filters=num_filters_list[idx],
has_se=has_se,
downsample=True,
name=name + 'conv_' + str(idx + 1)))
self.func_list.append(func)
def forward(self, inputs):
outs = []
for idx, input in enumerate(inputs):
out = self.func_list[idx](input)
outs.append(out)
return outs
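# LastClsOut maps the four resolution branches through bottleneck heads of
# 32/64/128/256 filters; since BottleneckBlock expands by 4x, the branches
# leave this layer with 128/256/512/1024 channels, matching the
# cls_head_conv_list input widths in the HRNet class below.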
class HRNet(fluid.dygraph.Layer):
def __init__(self, width=18, has_se=False, class_dim=1000):
super(HRNet, self).__init__()
self.width = width
self.has_se = has_se
self.channels = {
18: [[18, 36], [18, 36, 72], [18, 36, 72, 144]],
30: [[30, 60], [30, 60, 120], [30, 60, 120, 240]],
32: [[32, 64], [32, 64, 128], [32, 64, 128, 256]],
40: [[40, 80], [40, 80, 160], [40, 80, 160, 320]],
44: [[44, 88], [44, 88, 176], [44, 88, 176, 352]],
48: [[48, 96], [48, 96, 192], [48, 96, 192, 384]],
60: [[60, 120], [60, 120, 240], [60, 120, 240, 480]],
64: [[64, 128], [64, 128, 256], [64, 128, 256, 512]]
}
self._class_dim = class_dim
channels_2, channels_3, channels_4 = self.channels[width]
num_modules_2, num_modules_3, num_modules_4 = 1, 4, 3
self.conv_layer1_1 = ConvBNLayer(
num_channels=3,
num_filters=64,
filter_size=3,
stride=2,
act='relu',
name="layer1_1")
self.conv_layer1_2 = ConvBNLayer(
num_channels=64,
num_filters=64,
filter_size=3,
stride=2,
act='relu',
name="layer1_2")
self.la1 = Layer1(num_channels=64, has_se=has_se, name="layer2")
self.tr1 = TransitionLayer(
in_channels=[256], out_channels=channels_2, name="tr1")
self.st2 = Stage(
num_channels=channels_2,
num_modules=num_modules_2,
num_filters=channels_2,
has_se=self.has_se,
name="st2")
self.tr2 = TransitionLayer(
in_channels=channels_2, out_channels=channels_3, name="tr2")
self.st3 = Stage(
num_channels=channels_3,
num_modules=num_modules_3,
num_filters=channels_3,
has_se=self.has_se,
name="st3")
self.tr3 = TransitionLayer(
in_channels=channels_3, out_channels=channels_4, name="tr3")
self.st4 = Stage(
num_channels=channels_4,
num_modules=num_modules_4,
num_filters=channels_4,
has_se=self.has_se,
name="st4")
# classification
num_filters_list = [32, 64, 128, 256]
self.last_cls = LastClsOut(
num_channel_list=channels_4,
has_se=self.has_se,
num_filters_list=num_filters_list,
name="cls_head", )
last_num_filters = [256, 512, 1024]
self.cls_head_conv_list = []
for idx in range(3):
self.cls_head_conv_list.append(
self.add_sublayer(
"cls_head_add{}".format(idx + 1),
ConvBNLayer(
num_channels=num_filters_list[idx] * 4,
num_filters=last_num_filters[idx],
filter_size=3,
stride=2,
name="cls_head_add" + str(idx + 1))))
self.conv_last = ConvBNLayer(
num_channels=1024,
num_filters=2048,
filter_size=1,
stride=1,
name="cls_head_last_conv")
self.pool2d_avg = Pool2D(pool_type='avg', global_pooling=True)
stdv = 1.0 / math.sqrt(2048 * 1.0)
self.out = Linear(
2048,
class_dim,
            param_attr=ParamAttr(
                initializer=fluid.initializer.Uniform(-stdv, stdv),
                name="fc_weights"),
            bias_attr=ParamAttr(name="fc_offset"))

    def forward(self, input):
        conv1 = self.conv_layer1_1(input)
        conv2 = self.conv_layer1_2(conv1)

        la1 = self.la1(conv2)

        tr1 = self.tr1([la1])
        st2 = self.st2(tr1)

        tr2 = self.tr2(st2)
        st3 = self.st3(tr2)
tr3 = self.tr3(st3)
st4 = self.st4(tr3)
last_cls = self.last_cls(st4)
y = last_cls[0]
for idx in range(3):
y = last_cls[idx + 1] + self.cls_head_conv_list[idx](y)
y = self.conv_last(y)
y = self.pool2d_avg(y)
y = fluid.layers.reshape(y, shape=[0, -1])
y = self.out(y)
return y
def HRNet_W18_C():
......
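# Matching forward-pass sketch for the dygraph HRNet (an assumption:
# HRNet_W18_C returns the constructed model; the factory bodies are elided
# above):
def _hrnet_demo():
    with fluid.dygraph.guard():
        model = HRNet_W18_C()
        img = np.random.uniform(0, 1, [1, 3, 224, 224]).astype("float32")
        logits = model(to_variable(img))
        return logits.shape  # [1, class_dim], i.e. [1, 1000] by default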