“40162227bbd88e4f60644058ae5209e0eaa62665”上不存在“test/dygraph_to_static/test_save_inference_model.py”
提交 34edfa05 编写于 作者: littletomatodonkey's avatar littletomatodonkey

add dpn, densenet and hrnet dygraph model

上级 3b93ffa0
#copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import numpy as np
import time
import sys
import math
import argparse
import ast
import paddle
import paddle.fluid as fluid
from paddle.fluid.param_attr import ParamAttr
from paddle.fluid.layer_helper import LayerHelper
from paddle.fluid.dygraph.nn import Conv2D, Pool2D, BatchNorm, Linear
from paddle.fluid.dygraph.base import to_variable
from paddle.fluid import framework
import math
import sys
import time
__all__ = [
"DPN",
"DPN68",
"DPN92",
"DPN98",
"DPN107",
"DPN131",
]
class ConvBNLayer(fluid.dygraph.Layer):
def __init__(self,
num_channels,
num_filters,
filter_size,
stride=1,
pad=0,
groups=1,
act="relu",
name=None):
super(ConvBNLayer, self).__init__()
self._conv = Conv2D(
num_channels=num_channels,
num_filters=num_filters,
filter_size=filter_size,
stride=stride,
padding=pad,
groups=groups,
act=None,
param_attr=ParamAttr(name=name + "_weights"),
bias_attr=False)
self._batch_norm = BatchNorm(
num_filters,
act=act,
param_attr=ParamAttr(name=name + '_bn_scale'),
bias_attr=ParamAttr(name + '_bn_offset'),
moving_mean_name=name + '_bn_mean',
moving_variance_name=name + '_bn_variance')
def forward(self, input):
y = self._conv(input)
y = self._batch_norm(y)
return y
class BNACConvLayer(fluid.dygraph.Layer):
def __init__(self,
num_channels,
num_filters,
filter_size,
stride=1,
pad=0,
groups=1,
act="relu",
name=None):
super(BNACConvLayer, self).__init__()
self.num_channels = num_channels
self.name = name
self._batch_norm = BatchNorm(
num_channels,
act=act,
param_attr=ParamAttr(name=name + '_bn_scale'),
bias_attr=ParamAttr(name + '_bn_offset'),
moving_mean_name=name + '_bn_mean',
moving_variance_name=name + '_bn_variance')
self._conv = Conv2D(
num_channels=num_channels,
num_filters=num_filters,
filter_size=filter_size,
stride=stride,
padding=pad,
groups=groups,
act=None,
param_attr=ParamAttr(name=name + "_weights"),
bias_attr=False)
def forward(self, input):
y = self._batch_norm(input)
y = self._conv(y)
return y
class DualPathFactory(fluid.dygraph.Layer):
def __init__(self,
num_channels,
num_1x1_a,
num_3x3_b,
num_1x1_c,
inc,
G,
_type='normal',
name=None):
super(DualPathFactory, self).__init__()
self.num_1x1_c = num_1x1_c
self.inc = inc
self.name = name
kw = 3
kh = 3
pw = (kw - 1) // 2
ph = (kh - 1) // 2
# type
if _type == 'proj':
key_stride = 1
self.has_proj = True
elif _type == 'down':
key_stride = 2
self.has_proj = True
elif _type == 'normal':
key_stride = 1
self.has_proj = False
else:
print("not implemented now!!!")
sys.exit(1)
data_in_ch = sum(num_channels) if isinstance(num_channels,
list) else num_channels
if self.has_proj:
self.c1x1_w_func = BNACConvLayer(
num_channels=data_in_ch,
num_filters=num_1x1_c + 2 * inc,
filter_size=(1, 1),
pad=(0, 0),
stride=(key_stride, key_stride),
name=name + "_match")
self.c1x1_a_func = BNACConvLayer(
num_channels=data_in_ch,
num_filters=num_1x1_a,
filter_size=(1, 1),
pad=(0, 0),
name=name + "_conv1")
self.c3x3_b_func = BNACConvLayer(
num_channels=num_1x1_a,
num_filters=num_3x3_b,
filter_size=(kw, kh),
pad=(pw, ph),
stride=(key_stride, key_stride),
groups=G,
name=name + "_conv2")
self.c1x1_c_func = BNACConvLayer(
num_channels=num_3x3_b,
num_filters=num_1x1_c + inc,
filter_size=(1, 1),
pad=(0, 0),
name=name + "_conv3")
def forward(self, input):
# PROJ
if isinstance(input, list):
data_in = fluid.layers.concat([input[0], input[1]], axis=1)
else:
data_in = input
if self.has_proj:
c1x1_w = self.c1x1_w_func(data_in)
data_o1, data_o2 = fluid.layers.split(
c1x1_w, num_or_sections=[self.num_1x1_c, 2 * self.inc], dim=1)
else:
data_o1 = input[0]
data_o2 = input[1]
c1x1_a = self.c1x1_a_func(data_in)
c3x3_b = self.c3x3_b_func(c1x1_a)
c1x1_c = self.c1x1_c_func(c3x3_b)
c1x1_c1, c1x1_c2 = fluid.layers.split(
c1x1_c, num_or_sections=[self.num_1x1_c, self.inc], dim=1)
# OUTPUTS
summ = fluid.layers.elementwise_add(x=data_o1, y=c1x1_c1)
dense = fluid.layers.concat([data_o2, c1x1_c2], axis=1)
# tensor, channels
return [summ, dense]
__all__ = ["DPN", "DPN68", "DPN92", "DPN98", "DPN107", "DPN131"]
class DPN(fluid.dygraph.Layer):
def __init__(self, layers=60, class_dim=1000):
super(DPN, self).__init__()
class DPN(object):
def __init__(self, layers=68):
self.layers = layers
self._class_dim = class_dim
def net(self, input, class_dim=1000):
# get network args
args = self.get_net_args(self.layers)
args = self.get_net_args(layers)
bws = args['bw']
inc_sec = args['inc_sec']
rs = args['r']
......@@ -45,39 +215,23 @@ class DPN(object):
init_filter_size = args['init_filter_size']
init_padding = args['init_padding']
## define Dual Path Network
self.k_sec = k_sec
# conv1
conv1_x_1 = fluid.layers.conv2d(
input=input,
self.conv1_x_1_func = ConvBNLayer(
num_channels=3,
num_filters=init_num_filter,
filter_size=init_filter_size,
filter_size=3,
stride=2,
padding=init_padding,
groups=1,
act=None,
bias_attr=False,
name="conv1",
param_attr=ParamAttr(name="conv1_weights"), )
conv1_x_1 = fluid.layers.batch_norm(
input=conv1_x_1,
pad=1,
act='relu',
is_test=False,
name="conv1_bn",
param_attr=ParamAttr(name='conv1_bn_scale'),
bias_attr=ParamAttr('conv1_bn_offset'),
moving_mean_name='conv1_bn_mean',
moving_variance_name='conv1_bn_variance', )
convX_x_x = fluid.layers.pool2d(
input=conv1_x_1,
pool_size=3,
pool_stride=2,
pool_padding=1,
pool_type='max',
name="pool1")
name="conv1")
self.pool2d_max = Pool2D(
pool_size=3, pool_stride=2, pool_padding=1, pool_type='max')
num_channel_dpn = init_num_filter
self.dpn_func_list = []
#conv2 - conv5
match_list, num = [], 0
for gc in range(4):
......@@ -93,43 +247,82 @@ class DPN(object):
_type2 = 'normal'
match = match + k_sec[gc - 1]
match_list.append(match)
self.dpn_func_list.append(
self.add_sublayer(
"dpn{}".format(match),
DualPathFactory(
num_channels=num_channel_dpn,
num_1x1_a=R,
num_3x3_b=R,
num_1x1_c=bw,
inc=inc,
G=G,
_type=_type1,
name="dpn" + str(match))))
num_channel_dpn = [bw, 3 * inc]
convX_x_x = self.dual_path_factory(
convX_x_x, R, R, bw, inc, G, _type1, name="dpn" + str(match))
for i_ly in range(2, k_sec[gc] + 1):
num += 1
if num in match_list:
num += 1
convX_x_x = self.dual_path_factory(
convX_x_x, R, R, bw, inc, G, _type2, name="dpn" + str(num))
conv5_x_x = fluid.layers.concat(convX_x_x, axis=1)
conv5_x_x = fluid.layers.batch_norm(
input=conv5_x_x,
act='relu',
is_test=False,
name="final_concat_bn",
self.dpn_func_list.append(
self.add_sublayer(
"dpn{}".format(num),
DualPathFactory(
num_channels=num_channel_dpn,
num_1x1_a=R,
num_3x3_b=R,
num_1x1_c=bw,
inc=inc,
G=G,
_type=_type2,
name="dpn" + str(num))))
num_channel_dpn = [
num_channel_dpn[0], num_channel_dpn[1] + inc
]
out_channel = sum(num_channel_dpn)
self.conv5_x_x_bn = BatchNorm(
num_channels=sum(num_channel_dpn),
act="relu",
param_attr=ParamAttr(name='final_concat_bn_scale'),
bias_attr=ParamAttr('final_concat_bn_offset'),
moving_mean_name='final_concat_bn_mean',
moving_variance_name='final_concat_bn_variance', )
pool5 = fluid.layers.pool2d(
input=conv5_x_x,
pool_size=7,
pool_stride=1,
pool_padding=0,
pool_type='avg', )
moving_variance_name='final_concat_bn_variance')
self.pool2d_avg = Pool2D(pool_type='avg', global_pooling=True)
stdv = 0.01
fc6 = fluid.layers.fc(
input=pool5,
size=class_dim,
self.out = Linear(
out_channel,
class_dim,
param_attr=ParamAttr(
initializer=fluid.initializer.Uniform(-stdv, stdv),
name='fc_weights'),
bias_attr=ParamAttr(name='fc_offset'))
name="fc_weights"),
bias_attr=ParamAttr(name="fc_offset"))
def forward(self, input):
conv1_x_1 = self.conv1_x_1_func(input)
convX_x_x = self.pool2d_max(conv1_x_1)
return fc6
dpn_idx = 0
for gc in range(4):
convX_x_x = self.dpn_func_list[dpn_idx](convX_x_x)
dpn_idx += 1
for i_ly in range(2, self.k_sec[gc] + 1):
convX_x_x = self.dpn_func_list[dpn_idx](convX_x_x)
dpn_idx += 1
conv5_x_x = fluid.layers.concat(convX_x_x, axis=1)
conv5_x_x = self.conv5_x_x_bn(conv5_x_x)
y = self.pool2d_avg(conv5_x_x)
y = fluid.layers.reshape(y, shape=[0, -1])
y = self.out(y)
return y
def get_net_args(self, layers):
if layers == 68:
......@@ -198,119 +391,6 @@ class DPN(object):
return net_arg
def dual_path_factory(self,
data,
num_1x1_a,
num_3x3_b,
num_1x1_c,
inc,
G,
_type='normal',
name=None):
kw = 3
kh = 3
pw = (kw - 1) // 2
ph = (kh - 1) // 2
# type
if _type is 'proj':
key_stride = 1
has_proj = True
if _type is 'down':
key_stride = 2
has_proj = True
if _type is 'normal':
key_stride = 1
has_proj = False
# PROJ
if type(data) is list:
data_in = fluid.layers.concat([data[0], data[1]], axis=1)
else:
data_in = data
if has_proj:
c1x1_w = self.bn_ac_conv(
data=data_in,
num_filter=(num_1x1_c + 2 * inc),
kernel=(1, 1),
pad=(0, 0),
stride=(key_stride, key_stride),
name=name + "_match")
data_o1, data_o2 = fluid.layers.split(
c1x1_w,
num_or_sections=[num_1x1_c, 2 * inc],
dim=1,
name=name + "_match_conv_Slice")
else:
data_o1 = data[0]
data_o2 = data[1]
# MAIN
c1x1_a = self.bn_ac_conv(
data=data_in,
num_filter=num_1x1_a,
kernel=(1, 1),
pad=(0, 0),
name=name + "_conv1")
c3x3_b = self.bn_ac_conv(
data=c1x1_a,
num_filter=num_3x3_b,
kernel=(kw, kh),
pad=(pw, ph),
stride=(key_stride, key_stride),
num_group=G,
name=name + "_conv2")
c1x1_c = self.bn_ac_conv(
data=c3x3_b,
num_filter=(num_1x1_c + inc),
kernel=(1, 1),
pad=(0, 0),
name=name + "_conv3")
c1x1_c1, c1x1_c2 = fluid.layers.split(
c1x1_c,
num_or_sections=[num_1x1_c, inc],
dim=1,
name=name + "_conv3_Slice")
# OUTPUTS
summ = fluid.layers.elementwise_add(
x=data_o1, y=c1x1_c1, name=name + "_elewise")
dense = fluid.layers.concat(
[data_o2, c1x1_c2], axis=1, name=name + "_concat")
return [summ, dense]
def bn_ac_conv(self,
data,
num_filter,
kernel,
pad,
stride=(1, 1),
num_group=1,
name=None):
bn_ac = fluid.layers.batch_norm(
input=data,
act='relu',
is_test=False,
name=name + '.output.1',
param_attr=ParamAttr(name=name + '_bn_scale'),
bias_attr=ParamAttr(name + '_bn_offset'),
moving_mean_name=name + '_bn_mean',
moving_variance_name=name + '_bn_variance', )
bn_ac_conv = fluid.layers.conv2d(
input=bn_ac,
num_filters=num_filter,
filter_size=kernel,
stride=stride,
padding=pad,
groups=num_group,
act=None,
bias_attr=False,
param_attr=ParamAttr(name=name + "_weights"))
return bn_ac_conv
def DPN68():
model = DPN(layers=68)
......
#copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import math
import numpy as np
import argparse
import ast
import paddle
import paddle.fluid as fluid
from paddle.fluid.initializer import MSRA
from paddle.fluid.param_attr import ParamAttr
from paddle.fluid.layer_helper import LayerHelper
from paddle.fluid.dygraph.nn import Conv2D, Pool2D, BatchNorm, Linear
from paddle.fluid.dygraph.base import to_variable
from paddle.fluid import framework
import math
import sys
import time
__all__ = [
"HRNet", "HRNet_W18_C", "HRNet_W30_C", "HRNet_W32_C", "HRNet_W40_C",
"HRNet_W44_C", "HRNet_W48_C", "HRNet_W60_C", "HRNet_W64_C",
"SE_HRNet_W18_C", "SE_HRNet_W30_C", "SE_HRNet_W32_C", "SE_HRNet_W40_C",
"SE_HRNet_W44_C", "SE_HRNet_W48_C", "SE_HRNet_W60_C", "SE_HRNet_W64_C"
"HRNet_W18_C",
"HRNet_W30_C",
"HRNet_W32_C",
"HRNet_W40_C",
"HRNet_W44_C",
"HRNet_W48_C",
"HRNet_W60_C",
"HRNet_W64_C",
"SE_HRNet_W18_C",
"SE_HRNet_W30_C",
"SE_HRNet_W32_C",
"SE_HRNet_W40_C",
"SE_HRNet_W44_C",
"SE_HRNet_W48_C",
"SE_HRNet_W60_C",
"SE_HRNet_W64_C",
]
class HRNet():
def __init__(self, width=18, has_se=False):
self.width = width
self.has_se = has_se
self.channels = {
18: [[18, 36], [18, 36, 72], [18, 36, 72, 144]],
30: [[30, 60], [30, 60, 120], [30, 60, 120, 240]],
32: [[32, 64], [32, 64, 128], [32, 64, 128, 256]],
40: [[40, 80], [40, 80, 160], [40, 80, 160, 320]],
44: [[44, 88], [44, 88, 176], [44, 88, 176, 352]],
48: [[48, 96], [48, 96, 192], [48, 96, 192, 384]],
60: [[60, 120], [60, 120, 240], [60, 120, 240, 480]],
64: [[64, 128], [64, 128, 256], [64, 128, 256, 512]]
}
class ConvBNLayer(fluid.dygraph.Layer):
def __init__(self,
num_channels,
num_filters,
filter_size,
stride=1,
groups=1,
act="relu",
name=None):
super(ConvBNLayer, self).__init__()
def net(self, input, class_dim=1000):
width = self.width
channels_2, channels_3, channels_4 = self.channels[width]
num_modules_2, num_modules_3, num_modules_4 = 1, 4, 3
self._conv = Conv2D(
num_channels=num_channels,
num_filters=num_filters,
filter_size=filter_size,
stride=stride,
padding=(filter_size - 1) // 2,
groups=groups,
act=None,
param_attr=ParamAttr(name=name + "_weights"),
bias_attr=False)
bn_name = name + '_bn'
self._batch_norm = BatchNorm(
num_filters,
act=act,
param_attr=ParamAttr(name=bn_name + '_scale'),
bias_attr=ParamAttr(bn_name + '_offset'),
moving_mean_name=bn_name + '_mean',
moving_variance_name=bn_name + '_variance')
x = self.conv_bn_layer(
input=input,
filter_size=3,
num_filters=64,
stride=2,
if_act=True,
name='layer1_1')
x = self.conv_bn_layer(
input=x,
filter_size=3,
num_filters=64,
stride=2,
if_act=True,
name='layer1_2')
la1 = self.layer1(x, name='layer2')
tr1 = self.transition_layer([la1], [256], channels_2, name='tr1')
st2 = self.stage(tr1, num_modules_2, channels_2, name='st2')
tr2 = self.transition_layer(st2, channels_2, channels_3, name='tr2')
st3 = self.stage(tr2, num_modules_3, channels_3, name='st3')
tr3 = self.transition_layer(st3, channels_3, channels_4, name='tr3')
st4 = self.stage(tr3, num_modules_4, channels_4, name='st4')
#classification
last_cls = self.last_cls_out(x=st4, name='cls_head')
y = last_cls[0]
last_num_filters = [256, 512, 1024]
for i in range(3):
y = fluid.layers.elementwise_add(
last_cls[i + 1],
self.conv_bn_layer(
input=y,
filter_size=3,
num_filters=last_num_filters[i],
stride=2,
name='cls_head_add' + str(i + 1)))
def forward(self, input):
y = self._conv(input)
y = self._batch_norm(y)
return y
y = self.conv_bn_layer(
input=y,
filter_size=1,
num_filters=2048,
stride=1,
name='cls_head_last_conv')
pool = fluid.layers.pool2d(
input=y, pool_type='avg', global_pooling=True)
stdv = 1.0 / math.sqrt(pool.shape[1] * 1.0)
out = fluid.layers.fc(
input=pool,
size=class_dim,
param_attr=ParamAttr(
name='fc_weights',
initializer=fluid.initializer.Uniform(-stdv, stdv)),
bias_attr=ParamAttr(name='fc_offset'))
return out
def layer1(self, input, name=None):
conv = input
class Layer1(fluid.dygraph.Layer):
def __init__(self, num_channels, has_se=False, name=None):
super(Layer1, self).__init__()
self.bottleneck_block_list = []
for i in range(4):
conv = self.bottleneck_block(
conv,
bottleneck_block = self.add_sublayer(
"bb_{}_{}".format(name, i + 1),
BottleneckBlock(
num_channels=num_channels if i == 0 else 256,
num_filters=64,
has_se=has_se,
stride=1,
downsample=True if i == 0 else False,
name=name + '_' + str(i + 1))
name=name + '_' + str(i + 1)))
self.bottleneck_block_list.append(bottleneck_block)
def forward(self, input):
conv = input
for block_func in self.bottleneck_block_list:
conv = block_func(conv)
return conv
def transition_layer(self, x, in_channels, out_channels, name=None):
class TransitionLayer(fluid.dygraph.Layer):
def __init__(self, in_channels, out_channels, name=None):
super(TransitionLayer, self).__init__()
num_in = len(in_channels)
num_out = len(out_channels)
out = []
self.conv_bn_func_list = []
for i in range(num_out):
residual = None
if i < num_in:
if in_channels[i] != out_channels[i]:
residual = self.conv_bn_layer(
x[i],
filter_size=3,
residual = self.add_sublayer(
"transition_{}_layer_{}".format(name, i + 1),
ConvBNLayer(
num_channels=in_channels[i],
num_filters=out_channels[i],
name=name + '_layer_' + str(i + 1))
out.append(residual)
else:
out.append(x[i])
else:
residual = self.conv_bn_layer(
x[-1],
filter_size=3,
name=name + '_layer_' + str(i + 1)))
else:
residual = self.add_sublayer(
"transition_{}_layer_{}".format(name, i + 1),
ConvBNLayer(
num_channels=in_channels[-1],
num_filters=out_channels[i],
stride=2,
name=name + '_layer_' + str(i + 1))
out.append(residual)
return out
def branches(self, x, block_num, channels, name=None):
out = []
for i in range(len(channels)):
residual = x[i]
for j in range(block_num):
residual = self.basic_block(
residual,
channels[i],
name=name + '_branch_layer_' + str(i + 1) + '_' +
str(j + 1))
out.append(residual)
return out
def fuse_layers(self, x, channels, multi_scale_output=True, name=None):
out = []
for i in range(len(channels) if multi_scale_output else 1):
residual = x[i]
for j in range(len(channels)):
if j > i:
y = self.conv_bn_layer(
x[j],
filter_size=1,
num_filters=channels[i],
if_act=False,
name=name + '_layer_' + str(i + 1) + '_' + str(j + 1))
y = fluid.layers.resize_nearest(input=y, scale=2**(j - i))
residual = fluid.layers.elementwise_add(
x=residual, y=y, act=None)
elif j < i:
y = x[j]
for k in range(i - j):
if k == i - j - 1:
y = self.conv_bn_layer(
y,
filter_size=3,
num_filters=channels[i],
stride=2,
if_act=False,
name=name + '_layer_' + str(i + 1) + '_' +
str(j + 1) + '_' + str(k + 1))
name=name + '_layer_' + str(i + 1)))
self.conv_bn_func_list.append(residual)
def forward(self, input):
outs = []
for idx, conv_bn_func in enumerate(self.conv_bn_func_list):
if conv_bn_func is None:
outs.append(input[idx])
else:
y = self.conv_bn_layer(
y,
filter_size=3,
num_filters=channels[j],
stride=2,
name=name + '_layer_' + str(i + 1) + '_' +
str(j + 1) + '_' + str(k + 1))
residual = fluid.layers.elementwise_add(
x=residual, y=y, act=None)
if idx < len(input):
outs.append(conv_bn_func(input[idx]))
else:
outs.append(conv_bn_func(input[-1]))
return outs
residual = fluid.layers.relu(residual)
out.append(residual)
return out
def high_resolution_module(self,
x,
channels,
multi_scale_output=True,
class Branches(fluid.dygraph.Layer):
def __init__(self,
block_num,
in_channels,
out_channels,
has_se=False,
name=None):
residual = self.branches(x, 4, channels, name=name)
out = self.fuse_layers(
residual,
channels,
multi_scale_output=multi_scale_output,
name=name)
return out
super(Branches, self).__init__()
def stage(self,
x,
num_modules,
channels,
multi_scale_output=True,
name=None):
out = x
for i in range(num_modules):
if i == num_modules - 1 and multi_scale_output == False:
out = self.high_resolution_module(
out,
channels,
multi_scale_output=False,
name=name + '_' + str(i + 1))
else:
out = self.high_resolution_module(
out, channels, name=name + '_' + str(i + 1))
self.basic_block_list = []
return out
for i in range(len(out_channels)):
self.basic_block_list.append([])
for j in range(block_num):
in_ch = in_channels[i] if j == 0 else out_channels[i]
basic_block_func = self.add_sublayer(
"bb_{}_branch_layer_{}_{}".format(name, i + 1, j + 1),
BasicBlock(
num_channels=in_ch,
num_filters=out_channels[i],
has_se=has_se,
name=name + '_branch_layer_' + str(i + 1) + '_' +
str(j + 1)))
self.basic_block_list[i].append(basic_block_func)
def last_cls_out(self, x, name=None):
out = []
num_filters_list = [32, 64, 128, 256]
for i in range(len(x)):
out.append(
self.bottleneck_block(
input=x[i],
num_filters=num_filters_list[i],
name=name + 'conv_' + str(i + 1),
downsample=True))
def forward(self, inputs):
outs = []
for idx, input in enumerate(inputs):
conv = input
for basic_block_func in self.basic_block_list[idx]:
conv = basic_block_func(conv)
outs.append(conv)
return outs
return out
def basic_block(self,
input,
class BottleneckBlock(fluid.dygraph.Layer):
def __init__(self,
num_channels,
num_filters,
has_se,
stride=1,
downsample=False,
name=None):
residual = input
conv = self.conv_bn_layer(
input=input,
filter_size=3,
num_filters=num_filters,
stride=stride,
name=name + '_conv1')
conv = self.conv_bn_layer(
input=conv,
filter_size=3,
super(BottleneckBlock, self).__init__()
self.has_se = has_se
self.downsample = downsample
self.conv1 = ConvBNLayer(
num_channels=num_channels,
num_filters=num_filters,
if_act=False,
name=name + '_conv2')
if downsample:
residual = self.conv_bn_layer(
input=input,
filter_size=1,
act="relu",
name=name + "_conv1", )
self.conv2 = ConvBNLayer(
num_channels=num_filters,
num_filters=num_filters,
if_act=False,
name=name + '_downsample')
if self.has_se:
conv = self.squeeze_excitation(
input=conv,
filter_size=3,
stride=stride,
act="relu",
name=name + "_conv2")
self.conv3 = ConvBNLayer(
num_channels=num_filters,
num_filters=num_filters * 4,
filter_size=1,
act=None,
name=name + "_conv3")
if self.downsample:
self.conv_down = ConvBNLayer(
num_channels=num_channels,
num_filters=num_filters * 4,
filter_size=1,
act=None,
name=name + "_downsample")
if self.has_se:
self.se = SELayer(
num_channels=num_filters * 4,
num_filters=num_filters * 4,
reduction_ratio=16,
name=name + '_fc')
return fluid.layers.elementwise_add(x=residual, y=conv, act='relu')
name='fc' + name)
def bottleneck_block(self,
input,
def forward(self, input):
residual = input
conv1 = self.conv1(input)
conv2 = self.conv2(conv1)
conv3 = self.conv3(conv2)
if self.downsample:
residual = self.conv_down(input)
if self.has_se:
conv3 = self.se(conv3)
y = fluid.layers.elementwise_add(x=conv3, y=residual, act="relu")
return y
class BasicBlock(fluid.dygraph.Layer):
def __init__(self,
num_channels,
num_filters,
stride=1,
has_se=False,
downsample=False,
name=None):
residual = input
conv = self.conv_bn_layer(
input=input,
filter_size=1,
super(BasicBlock, self).__init__()
self.has_se = has_se
self.downsample = downsample
self.conv1 = ConvBNLayer(
num_channels=num_channels,
num_filters=num_filters,
name=name + '_conv1')
conv = self.conv_bn_layer(
input=conv,
filter_size=3,
num_filters=num_filters,
stride=stride,
name=name + '_conv2')
conv = self.conv_bn_layer(
input=conv,
filter_size=1,
act="relu",
name=name + "_conv1")
self.conv2 = ConvBNLayer(
num_channels=num_filters,
num_filters=num_filters,
filter_size=3,
stride=1,
act=None,
name=name + "_conv2")
if self.downsample:
self.conv_down = ConvBNLayer(
num_channels=num_channels,
num_filters=num_filters * 4,
if_act=False,
name=name + '_conv3')
if downsample:
residual = self.conv_bn_layer(
input=input,
filter_size=1,
num_filters=num_filters * 4,
if_act=False,
name=name + '_downsample')
act="relu",
name=name + "_downsample")
if self.has_se:
conv = self.squeeze_excitation(
input=conv,
num_channels=num_filters * 4,
self.se = SELayer(
num_channels=num_filters,
num_filters=num_filters,
reduction_ratio=16,
name=name + '_fc')
return fluid.layers.elementwise_add(x=residual, y=conv, act='relu')
name='fc' + name)
def forward(self, input):
residual = input
conv1 = self.conv1(input)
conv2 = self.conv2(conv1)
def squeeze_excitation(self,
input,
if self.downsample:
residual = self.conv_down(input)
if self.has_se:
conv2 = self.se(conv2)
y = fluid.layers.elementwise_add(x=conv2, y=residual, act="relu")
return y
class SELayer(fluid.dygraph.Layer):
def __init__(self, num_channels, num_filters, reduction_ratio, name=None):
super(SELayer, self).__init__()
self.pool2d_gap = Pool2D(pool_type='avg', global_pooling=True)
self._num_channels = num_channels
med_ch = int(num_channels / reduction_ratio)
stdv = 1.0 / math.sqrt(num_channels * 1.0)
self.squeeze = Linear(
num_channels,
reduction_ratio,
name=None):
pool = fluid.layers.pool2d(
input=input, pool_size=0, pool_type='avg', global_pooling=True)
stdv = 1.0 / math.sqrt(pool.shape[1] * 1.0)
squeeze = fluid.layers.fc(
input=pool,
size=num_channels / reduction_ratio,
act='relu',
param_attr=fluid.param_attr.ParamAttr(
med_ch,
act="relu",
param_attr=ParamAttr(
initializer=fluid.initializer.Uniform(-stdv, stdv),
name=name + '_sqz_weights'),
name=name + "_sqz_weights"),
bias_attr=ParamAttr(name=name + '_sqz_offset'))
stdv = 1.0 / math.sqrt(squeeze.shape[1] * 1.0)
excitation = fluid.layers.fc(
input=squeeze,
size=num_channels,
act='sigmoid',
param_attr=fluid.param_attr.ParamAttr(
stdv = 1.0 / math.sqrt(med_ch * 1.0)
self.excitation = Linear(
med_ch,
num_filters,
act="sigmoid",
param_attr=ParamAttr(
initializer=fluid.initializer.Uniform(-stdv, stdv),
name=name + '_exc_weights'),
name=name + "_exc_weights"),
bias_attr=ParamAttr(name=name + '_exc_offset'))
scale = fluid.layers.elementwise_mul(x=input, y=excitation, axis=0)
return scale
def conv_bn_layer(self,
input,
filter_size,
def forward(self, input):
pool = self.pool2d_gap(input)
pool = fluid.layers.reshape(pool, shape=[-1, self._num_channels])
squeeze = self.squeeze(pool)
excitation = self.excitation(squeeze)
excitation = fluid.layers.reshape(
excitation, shape=[-1, self._num_channels, 1, 1])
out = input * excitation
return out
class Stage(fluid.dygraph.Layer):
def __init__(self,
num_channels,
num_modules,
num_filters,
stride=1,
padding=1,
num_groups=1,
if_act=True,
has_se=False,
multi_scale_output=True,
name=None):
conv = fluid.layers.conv2d(
input=input,
super(Stage, self).__init__()
self._num_modules = num_modules
self.stage_func_list = []
for i in range(num_modules):
if i == num_modules - 1 and not multi_scale_output:
stage_func = self.add_sublayer(
"stage_{}_{}".format(name, i + 1),
HighResolutionModule(
num_channels=num_channels,
num_filters=num_filters,
filter_size=filter_size,
stride=stride,
padding=(filter_size - 1) // 2,
groups=num_groups,
has_se=has_se,
multi_scale_output=False,
name=name + '_' + str(i + 1)))
else:
stage_func = self.add_sublayer(
"stage_{}_{}".format(name, i + 1),
HighResolutionModule(
num_channels=num_channels,
num_filters=num_filters,
has_se=has_se,
name=name + '_' + str(i + 1)))
self.stage_func_list.append(stage_func)
def forward(self, input):
out = input
for idx in range(self._num_modules):
out = self.stage_func_list[idx](out)
return out
class HighResolutionModule(fluid.dygraph.Layer):
def __init__(self,
num_channels,
num_filters,
has_se=False,
multi_scale_output=True,
name=None):
super(HighResolutionModule, self).__init__()
self.branches_func = Branches(
block_num=4,
in_channels=num_channels,
out_channels=num_filters,
has_se=has_se,
name=name)
self.fuse_func = FuseLayers(
in_channels=num_filters,
out_channels=num_filters,
multi_scale_output=multi_scale_output,
name=name)
def forward(self, input):
out = self.branches_func(input)
out = self.fuse_func(out)
return out
class FuseLayers(fluid.dygraph.Layer):
def __init__(self,
in_channels,
out_channels,
multi_scale_output=True,
name=None):
super(FuseLayers, self).__init__()
self._actual_ch = len(in_channels) if multi_scale_output else 1
self._in_channels = in_channels
self.residual_func_list = []
for i in range(self._actual_ch):
for j in range(len(in_channels)):
residual_func = None
if j > i:
residual_func = self.add_sublayer(
"residual_{}_layer_{}_{}".format(name, i + 1, j + 1),
ConvBNLayer(
num_channels=in_channels[j],
num_filters=out_channels[i],
filter_size=1,
stride=1,
act=None,
name=name + '_layer_' + str(i + 1) + '_' +
str(j + 1)))
self.residual_func_list.append(residual_func)
elif j < i:
pre_num_filters = in_channels[j]
for k in range(i - j):
if k == i - j - 1:
residual_func = self.add_sublayer(
"residual_{}_layer_{}_{}_{}".format(
name, i + 1, j + 1, k + 1),
ConvBNLayer(
num_channels=pre_num_filters,
num_filters=out_channels[i],
filter_size=3,
stride=2,
act=None,
name=name + '_layer_' + str(i + 1) + '_' +
str(j + 1) + '_' + str(k + 1)))
pre_num_filters = out_channels[i]
else:
residual_func = self.add_sublayer(
"residual_{}_layer_{}_{}_{}".format(
name, i + 1, j + 1, k + 1),
ConvBNLayer(
num_channels=pre_num_filters,
num_filters=out_channels[j],
filter_size=3,
stride=2,
act="relu",
name=name + '_layer_' + str(i + 1) + '_' +
str(j + 1) + '_' + str(k + 1)))
pre_num_filters = out_channels[j]
self.residual_func_list.append(residual_func)
def forward(self, input):
outs = []
residual_func_idx = 0
for i in range(self._actual_ch):
residual = input[i]
for j in range(len(self._in_channels)):
if j > i:
y = self.residual_func_list[residual_func_idx](input[j])
residual_func_idx += 1
y = fluid.layers.resize_nearest(input=y, scale=2**(j - i))
residual = fluid.layers.elementwise_add(
x=residual, y=y, act=None)
elif j < i:
y = input[j]
for k in range(i - j):
y = self.residual_func_list[residual_func_idx](y)
residual_func_idx += 1
residual = fluid.layers.elementwise_add(
x=residual, y=y, act=None)
layer_helper = LayerHelper(self.full_name(), act='relu')
residual = layer_helper.append_activation(residual)
outs.append(residual)
return outs
class LastClsOut(fluid.dygraph.Layer):
def __init__(self,
num_channel_list,
has_se,
num_filters_list=[32, 64, 128, 256],
name=None):
super(LastClsOut, self).__init__()
self.func_list = []
for idx in range(len(num_channel_list)):
func = self.add_sublayer(
"conv_{}_conv_{}".format(name, idx + 1),
BottleneckBlock(
num_channels=num_channel_list[idx],
num_filters=num_filters_list[idx],
has_se=has_se,
downsample=True,
name=name + 'conv_' + str(idx + 1)))
self.func_list.append(func)
def forward(self, inputs):
outs = []
for idx, input in enumerate(inputs):
out = self.func_list[idx](input)
outs.append(out)
return outs
class HRNet(fluid.dygraph.Layer):
def __init__(self, width=18, has_se=False, class_dim=1000):
super(HRNet, self).__init__()
self.width = width
self.has_se = has_se
self.channels = {
18: [[18, 36], [18, 36, 72], [18, 36, 72, 144]],
30: [[30, 60], [30, 60, 120], [30, 60, 120, 240]],
32: [[32, 64], [32, 64, 128], [32, 64, 128, 256]],
40: [[40, 80], [40, 80, 160], [40, 80, 160, 320]],
44: [[44, 88], [44, 88, 176], [44, 88, 176, 352]],
48: [[48, 96], [48, 96, 192], [48, 96, 192, 384]],
60: [[60, 120], [60, 120, 240], [60, 120, 240, 480]],
64: [[64, 128], [64, 128, 256], [64, 128, 256, 512]]
}
self._class_dim = class_dim
channels_2, channels_3, channels_4 = self.channels[width]
num_modules_2, num_modules_3, num_modules_4 = 1, 4, 3
self.conv_layer1_1 = ConvBNLayer(
num_channels=3,
num_filters=64,
filter_size=3,
stride=2,
act='relu',
name="layer1_1")
self.conv_layer1_2 = ConvBNLayer(
num_channels=64,
num_filters=64,
filter_size=3,
stride=2,
act='relu',
name="layer1_2")
self.la1 = Layer1(num_channels=64, has_se=has_se, name="layer2")
self.tr1 = TransitionLayer(
in_channels=[256], out_channels=channels_2, name="tr1")
self.st2 = Stage(
num_channels=channels_2,
num_modules=num_modules_2,
num_filters=channels_2,
has_se=self.has_se,
name="st2")
self.tr2 = TransitionLayer(
in_channels=channels_2, out_channels=channels_3, name="tr2")
self.st3 = Stage(
num_channels=channels_3,
num_modules=num_modules_3,
num_filters=channels_3,
has_se=self.has_se,
name="st3")
self.tr3 = TransitionLayer(
in_channels=channels_3, out_channels=channels_4, name="tr3")
self.st4 = Stage(
num_channels=channels_4,
num_modules=num_modules_4,
num_filters=channels_4,
has_se=self.has_se,
name="st4")
# classification
num_filters_list = [32, 64, 128, 256]
self.last_cls = LastClsOut(
num_channel_list=channels_4,
has_se=self.has_se,
num_filters_list=num_filters_list,
name="cls_head", )
last_num_filters = [256, 512, 1024]
self.cls_head_conv_list = []
for idx in range(3):
self.cls_head_conv_list.append(
self.add_sublayer(
"cls_head_add{}".format(idx + 1),
ConvBNLayer(
num_channels=num_filters_list[idx] * 4,
num_filters=last_num_filters[idx],
filter_size=3,
stride=2,
name="cls_head_add" + str(idx + 1))))
self.conv_last = ConvBNLayer(
num_channels=1024,
num_filters=2048,
filter_size=1,
stride=1,
name="cls_head_last_conv")
self.pool2d_avg = Pool2D(pool_type='avg', global_pooling=True)
stdv = 1.0 / math.sqrt(2048 * 1.0)
self.out = Linear(
2048,
class_dim,
param_attr=ParamAttr(
initializer=MSRA(), name=name + '_weights'),
bias_attr=False)
bn_name = name + '_bn'
bn = fluid.layers.batch_norm(
input=conv,
param_attr=ParamAttr(
name=bn_name + "_scale",
initializer=fluid.initializer.Constant(1.0)),
bias_attr=ParamAttr(
name=bn_name + "_offset",
initializer=fluid.initializer.Constant(0.0)),
moving_mean_name=bn_name + '_mean',
moving_variance_name=bn_name + '_variance')
if if_act:
bn = fluid.layers.relu(bn)
return bn
initializer=fluid.initializer.Uniform(-stdv, stdv),
name="fc_weights"),
bias_attr=ParamAttr(name="fc_offset"))
def forward(self, input):
conv1 = self.conv_layer1_1(input)
conv2 = self.conv_layer1_2(conv1)
la1 = self.la1(conv2)
tr1 = self.tr1([la1])
st2 = self.st2(tr1)
tr2 = self.tr2(st2)
st3 = self.st3(tr2)
tr3 = self.tr3(st3)
st4 = self.st4(tr3)
last_cls = self.last_cls(st4)
y = last_cls[0]
for idx in range(3):
y = last_cls[idx + 1] + self.cls_head_conv_list[idx](y)
y = self.conv_last(y)
y = self.pool2d_avg(y)
y = fluid.layers.reshape(y, shape=[0, -1])
y = self.out(y)
return y
def HRNet_W18_C():
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册