未验证 提交 49176dc0 编写于 作者: W wuzewu 提交者: GitHub

Merge pull request #342 from michaelowenliu/develop

implement deeplabv3p with resnet50/101 and xception65
# -*- encoding: utf-8 -*-
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from paddle import fluid
from paddle.fluid import dygraph
from paddle.fluid.dygraph import Conv2D
from paddle.fluid.dygraph import SyncBatchNorm as BatchNorm
import cv2
import os
import sys
class ConvBnRelu(dygraph.Layer):
def __init__(self,
num_channels,
num_filters,
filter_size,
using_sep_conv=False,
**kwargs):
super(ConvBnRelu, self).__init__()
if using_sep_conv:
self.conv = DepthwiseConvBnRelu(num_channels,
num_filters,
filter_size,
**kwargs)
else:
self.conv = Conv2D(num_channels,
num_filters,
filter_size,
**kwargs)
self.batch_norm = BatchNorm(num_filters)
def forward(self, x):
x = self.conv(x)
x = self.batch_norm(x)
x = fluid.layers.relu(x)
return x
class ConvBn(dygraph.Layer):
def __init__(self, num_channels, num_filters, filter_size, **kwargs):
super(ConvBn, self).__init__()
self.conv = Conv2D(num_channels,
num_filters,
filter_size,
**kwargs)
self.batch_norm = BatchNorm(num_filters)
def forward(self, x):
x = self.conv(x)
x = self.batch_norm(x)
return x
class ConvReluPool(dygraph.Layer):
def __init__(self, num_channels, num_filters):
super(ConvReluPool, self).__init__()
self.conv = Conv2D(num_channels,
num_filters,
filter_size=3,
stride=1,
padding=1,
dilation=1)
def forward(self, x):
x = self.conv(x)
x = fluid.layers.relu(x)
x = fluid.layers.pool2d(x, pool_size=2, pool_type="max", pool_stride=2)
return x
class ConvBnReluUpsample(dygraph.Layer):
def __init__(self, num_channels, num_filters):
super(ConvBnReluUpsample, self).__init__()
self.conv_bn_relu = ConvBnRelu(num_channels, num_filters)
def forward(self, x, upsample_scale=2):
x = self.conv_bn_relu(x)
new_shape = [x.shape[2] * upsample_scale, x.shape[3] * upsample_scale]
x = fluid.layers.resize_bilinear(x, new_shape)
return x
class DepthwiseConvBnRelu(dygraph.Layer):
def __init__(self,
num_channels,
num_filters,
filter_size,
**kwargs):
super(DepthwiseConvBnRelu, self).__init__()
self.depthwise_conv = ConvBn(num_channels,
num_filters=num_channels,
filter_size=filter_size,
groups=num_channels,
use_cudnn=False,
**kwargs)
self.piontwise_conv = ConvBnRelu(num_channels,
num_filters,
filter_size=1,
groups=1)
def forward(self, x):
x = self.depthwise_conv(x)
x = self.piontwise_conv(x)
return x
def compute_loss(logits, label, ignore_index=255):
mask = label != ignore_index
mask = fluid.layers.cast(mask, 'float32')
loss, probs = fluid.layers.softmax_with_cross_entropy(
logits,
label,
ignore_index=ignore_index,
return_softmax=True,
axis=1)
loss = loss * mask
avg_loss = fluid.layers.mean(loss) / (
fluid.layers.mean(mask) + 1e-5)
label.stop_gradient = True
mask.stop_gradient = True
return avg_loss
\ No newline at end of file
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import math
import numpy as np
import paddle
import paddle.fluid as fluid
from paddle.fluid.param_attr import ParamAttr
from paddle.fluid.layer_helper import LayerHelper
from paddle.fluid.dygraph.nn import Conv2D, Pool2D, BatchNorm, Linear, Dropout
from dygraph.utils import utils
__all__ = [
"ResNet18_vd", "ResNet34_vd", "ResNet50_vd", "ResNet101_vd", "ResNet152_vd"
]
class ConvBNLayer(fluid.dygraph.Layer):
def __init__(
self,
num_channels,
num_filters,
filter_size,
stride=1,
dilation=1,
groups=1,
is_vd_mode=False,
act=None,
name=None, ):
super(ConvBNLayer, self).__init__()
self.is_vd_mode = is_vd_mode
self._pool2d_avg = Pool2D(
pool_size=2, pool_stride=2, pool_padding=0, pool_type='avg', ceil_mode=True)
self._conv = Conv2D(
num_channels=num_channels,
num_filters=num_filters,
filter_size=filter_size,
stride=stride,
padding=(filter_size - 1) // 2 if dilation ==1 else 0,
dilation=dilation,
groups=groups,
act=None,
param_attr=ParamAttr(name=name + "_weights"),
bias_attr=False)
if name == "conv1":
bn_name = "bn_" + name
else:
bn_name = "bn" + name[3:]
self._batch_norm = BatchNorm(
num_filters,
act=act,
param_attr=ParamAttr(name=bn_name + '_scale'),
bias_attr=ParamAttr(bn_name + '_offset'),
moving_mean_name=bn_name + '_mean',
moving_variance_name=bn_name + '_variance')
def forward(self, inputs):
if self.is_vd_mode:
inputs = self._pool2d_avg(inputs)
y = self._conv(inputs)
y = self._batch_norm(y)
return y
class BottleneckBlock(fluid.dygraph.Layer):
def __init__(self,
num_channels,
num_filters,
stride,
shortcut=True,
if_first=False,
dilation=1,
name=None):
super(BottleneckBlock, self).__init__()
self.conv0 = ConvBNLayer(
num_channels=num_channels,
num_filters=num_filters,
filter_size=1,
act='relu',
name=name + "_branch2a")
self.dilation = dilation
self.conv1 = ConvBNLayer(
num_channels=num_filters,
num_filters=num_filters,
filter_size=3,
stride=stride,
act='relu',
dilation=dilation,
name=name + "_branch2b")
self.conv2 = ConvBNLayer(
num_channels=num_filters,
num_filters=num_filters * 4,
filter_size=1,
act=None,
name=name + "_branch2c")
if not shortcut:
self.short = ConvBNLayer(
num_channels=num_channels,
num_filters=num_filters * 4,
filter_size=1,
stride=1,
is_vd_mode=False if if_first or stride==1 else True,
name=name + "_branch1")
self.shortcut = shortcut
def forward(self, inputs):
y = self.conv0(inputs)
####################################################################
# If given dilation rate > 1, using corresponding padding
if self.dilation > 1:
padding = self.dilation
y = fluid.layers.pad(y, [0,0,0,0,padding,padding,padding,padding])
#####################################################################
conv1 = self.conv1(y)
conv2 = self.conv2(conv1)
if self.shortcut:
short = inputs
else:
short = self.short(inputs)
y = fluid.layers.elementwise_add(x=short, y=conv2)
layer_helper = LayerHelper(self.full_name(), act='relu')
return layer_helper.append_activation(y)
class BasicBlock(fluid.dygraph.Layer):
def __init__(self,
num_channels,
num_filters,
stride,
shortcut=True,
if_first=False,
name=None):
super(BasicBlock, self).__init__()
self.stride = stride
self.conv0 = ConvBNLayer(
num_channels=num_channels,
num_filters=num_filters,
filter_size=3,
stride=stride,
act='relu',
name=name + "_branch2a")
self.conv1 = ConvBNLayer(
num_channels=num_filters,
num_filters=num_filters,
filter_size=3,
act=None,
name=name + "_branch2b")
if not shortcut:
self.short = ConvBNLayer(
num_channels=num_channels,
num_filters=num_filters,
filter_size=1,
stride=1,
is_vd_mode=False if if_first else True,
name=name + "_branch1")
self.shortcut = shortcut
def forward(self, inputs):
y = self.conv0(inputs)
conv1 = self.conv1(y)
if self.shortcut:
short = inputs
else:
short = self.short(inputs)
y = fluid.layers.elementwise_add(x=short, y=conv1)
layer_helper = LayerHelper(self.full_name(), act='relu')
return layer_helper.append_activation(y)
class ResNet_vd(fluid.dygraph.Layer):
def __init__(self, layers=50, class_dim=1000, dilation_dict=None, multi_grid=(1, 2, 4), **kwargs):
super(ResNet_vd, self).__init__()
self.layers = layers
supported_layers = [18, 34, 50, 101, 152, 200]
assert layers in supported_layers, \
"supported layers are {} but input layer is {}".format(
supported_layers, layers)
if layers == 18:
depth = [2, 2, 2, 2]
elif layers == 34 or layers == 50:
depth = [3, 4, 6, 3]
elif layers == 101:
depth = [3, 4, 23, 3]
elif layers == 152:
depth = [3, 8, 36, 3]
elif layers == 200:
depth = [3, 12, 48, 3]
num_channels = [64, 256, 512,
1024] if layers >= 50 else [64, 64, 128, 256]
num_filters = [64, 128, 256, 512]
self.conv1_1 = ConvBNLayer(
num_channels=3,
num_filters=32,
filter_size=3,
stride=2,
act='relu',
name="conv1_1")
self.conv1_2 = ConvBNLayer(
num_channels=32,
num_filters=32,
filter_size=3,
stride=1,
act='relu',
name="conv1_2")
self.conv1_3 = ConvBNLayer(
num_channels=32,
num_filters=64,
filter_size=3,
stride=1,
act='relu',
name="conv1_3")
self.pool2d_max = Pool2D(
pool_size=3, pool_stride=2, pool_padding=1, pool_type='max')
# self.block_list = []
self.stage_list = []
if layers >= 50:
for block in range(len(depth)):
shortcut = False
block_list=[]
for i in range(depth[block]):
if layers in [101, 152] and block == 2:
if i == 0:
conv_name = "res" + str(block + 2) + "a"
else:
conv_name = "res" + str(block + 2) + "b" + str(i)
else:
conv_name = "res" + str(block + 2) + chr(97 + i)
###############################################################################
# Add dilation rate for some segmentation tasks, if dilation_dict is not None.
dilation_rate = dilation_dict[block] if dilation_dict and block in dilation_dict else 1
# Actually block here is 'stage', and i is 'block' in 'stage'
# At the stage 4, expand the the dilation_rate using multi_grid, default (1, 2, 4)
if block == 3:
dilation_rate = dilation_rate * multi_grid[i]
#print("stage {}, block {}: dilation rate".format(block, i), dilation_rate)
###############################################################################
bottleneck_block = self.add_sublayer(
'bb_%d_%d' % (block, i),
BottleneckBlock(
num_channels=num_channels[block] if i == 0 else num_filters[block] * 4,
num_filters=num_filters[block],
stride=2 if i == 0 and block != 0 and dilation_rate == 1 else 1,
shortcut=shortcut,
if_first=block == i == 0,
name=conv_name,
dilation=dilation_rate))
block_list.append(bottleneck_block)
shortcut = True
self.stage_list.append(block_list)
else:
for block in range(len(depth)):
shortcut = False
block_list=[]
for i in range(depth[block]):
conv_name = "res" + str(block + 2) + chr(97 + i)
basic_block = self.add_sublayer(
'bb_%d_%d' % (block, i),
BasicBlock(
num_channels=num_channels[block]
if i == 0 else num_filters[block],
num_filters=num_filters[block],
stride=2 if i == 0 and block != 0 else 1,
shortcut=shortcut,
if_first=block == i == 0,
name=conv_name))
block_list.append(basic_block)
shortcut = True
self.stage_list.append(block_list)
self.pool2d_avg = Pool2D(
pool_size=7, pool_type='avg', global_pooling=True)
self.pool2d_avg_channels = num_channels[-1] * 2
stdv = 1.0 / math.sqrt(self.pool2d_avg_channels * 1.0)
self.out = Linear(
self.pool2d_avg_channels,
class_dim,
param_attr=ParamAttr(
initializer=fluid.initializer.Uniform(-stdv, stdv),
name="fc_0.w_0"),
bias_attr=ParamAttr(name="fc_0.b_0"))
def forward(self, inputs):
y = self.conv1_1(inputs)
y = self.conv1_2(y)
y = self.conv1_3(y)
y = self.pool2d_max(y)
# A feature list saves the output feature map of each stage.
feat_list = []
for i, stage in enumerate(self.stage_list):
for j, block in enumerate(stage):
y = block(y)
#print("stage {} block {}".format(i+1, j+1), y.shape)
feat_list.append(y)
y = self.pool2d_avg(y)
y = fluid.layers.reshape(y, shape=[-1, self.pool2d_avg_channels])
y = self.out(y)
return y, feat_list
# def init_weight(self, pretrained_model=None):
# if pretrained_model is not None:
# if os.path.exists(pretrained_model):
# utils.load_pretrained_model(self, pretrained_model)
def ResNet18_vd(**args):
model = ResNet_vd(layers=18, **args)
return model
def ResNet34_vd(**args):
model = ResNet_vd(layers=34, **args)
return model
def ResNet50_vd(**args):
model = ResNet_vd(layers=50, **args)
return model
def ResNet101_vd(**args):
model = ResNet_vd(layers=101, **args)
return model
def ResNet152_vd(**args):
model = ResNet_vd(layers=152, **args)
return model
def ResNet200_vd(**args):
model = ResNet_vd(layers=200, **args)
return model
\ No newline at end of file
import paddle
import paddle.fluid as fluid
from paddle.fluid.param_attr import ParamAttr
from paddle.fluid.layer_helper import LayerHelper
from paddle.fluid.dygraph.nn import Conv2D, Pool2D, BatchNorm, Linear, Dropout
__all__ = ["Xception41_deeplab", "Xception65_deeplab", "Xception71_deeplab"]
def check_data(data, number):
if type(data) == int:
return [data] * number
assert len(data) == number
return data
def check_stride(s, os):
if s <= os:
return True
else:
return False
def check_points(count, points):
if points is None:
return False
else:
if isinstance(points, list):
return (True if count in points else False)
else:
return (True if count == points else False)
def gen_bottleneck_params(backbone='xception_65'):
if backbone == 'xception_65':
bottleneck_params = {
"entry_flow": (3, [2, 2, 2], [128, 256, 728]),
"middle_flow": (16, 1, 728),
"exit_flow": (2, [2, 1], [[728, 1024, 1024], [1536, 1536, 2048]])
}
elif backbone == 'xception_41':
bottleneck_params = {
"entry_flow": (3, [2, 2, 2], [128, 256, 728]),
"middle_flow": (8, 1, 728),
"exit_flow": (2, [2, 1], [[728, 1024, 1024], [1536, 1536, 2048]])
}
elif backbone == 'xception_71':
bottleneck_params = {
"entry_flow": (5, [2, 1, 2, 1, 2], [128, 256, 256, 728, 728]),
"middle_flow": (16, 1, 728),
"exit_flow": (2, [2, 1], [[728, 1024, 1024], [1536, 1536, 2048]])
}
else:
raise Exception(
"xception backbont only support xception_41/xception_65/xception_71"
)
return bottleneck_params
class ConvBNLayer(fluid.dygraph.Layer):
def __init__(self,
input_channels,
output_channels,
filter_size,
stride=1,
padding=0,
act=None,
name=None):
super(ConvBNLayer, self).__init__()
self._conv = Conv2D(
num_channels=input_channels,
num_filters=output_channels,
filter_size=filter_size,
stride=stride,
padding=padding,
param_attr=ParamAttr(name=name + "/weights"),
bias_attr=False)
self._bn = BatchNorm(
num_channels=output_channels,
act=act,
epsilon=1e-3,
momentum=0.99,
param_attr=ParamAttr(name=name + "/BatchNorm/gamma"),
bias_attr=ParamAttr(name=name + "/BatchNorm/beta"),
moving_mean_name=name + "/BatchNorm/moving_mean",
moving_variance_name=name + "/BatchNorm/moving_variance")
def forward(self, inputs):
return self._bn(self._conv(inputs))
class Seperate_Conv(fluid.dygraph.Layer):
def __init__(self,
input_channels,
output_channels,
stride,
filter,
dilation=1,
act=None,
name=None):
super(Seperate_Conv, self).__init__()
self._conv1 = Conv2D(
num_channels=input_channels,
num_filters=input_channels,
filter_size=filter,
stride=stride,
groups=input_channels,
padding=(filter) // 2 * dilation,
dilation=dilation,
param_attr=ParamAttr(name=name + "/depthwise/weights"),
bias_attr=False)
self._bn1 = BatchNorm(
input_channels,
act=act,
epsilon=1e-3,
momentum=0.99,
param_attr=ParamAttr(name=name + "/depthwise/BatchNorm/gamma"),
bias_attr=ParamAttr(name=name + "/depthwise/BatchNorm/beta"),
moving_mean_name=name + "/depthwise/BatchNorm/moving_mean",
moving_variance_name=name + "/depthwise/BatchNorm/moving_variance")
self._conv2 = Conv2D(
input_channels,
output_channels,
1,
stride=1,
groups=1,
padding=0,
param_attr=ParamAttr(name=name + "/pointwise/weights"),
bias_attr=False)
self._bn2 = BatchNorm(
output_channels,
act=act,
epsilon=1e-3,
momentum=0.99,
param_attr=ParamAttr(name=name + "/pointwise/BatchNorm/gamma"),
bias_attr=ParamAttr(name=name + "/pointwise/BatchNorm/beta"),
moving_mean_name=name + "/pointwise/BatchNorm/moving_mean",
moving_variance_name=name + "/pointwise/BatchNorm/moving_variance")
def forward(self, inputs):
x = self._conv1(inputs)
x = self._bn1(x)
x = self._conv2(x)
x = self._bn2(x)
return x
class Xception_Block(fluid.dygraph.Layer):
def __init__(self,
input_channels,
output_channels,
strides=1,
filter_size=3,
dilation=1,
skip_conv=True,
has_skip=True,
activation_fn_in_separable_conv=False,
name=None):
super(Xception_Block, self).__init__()
repeat_number = 3
output_channels = check_data(output_channels, repeat_number)
filter_size = check_data(filter_size, repeat_number)
strides = check_data(strides, repeat_number)
self.has_skip = has_skip
self.skip_conv = skip_conv
self.activation_fn_in_separable_conv = activation_fn_in_separable_conv
if not activation_fn_in_separable_conv:
self._conv1 = Seperate_Conv(
input_channels,
output_channels[0],
stride=strides[0],
filter=filter_size[0],
dilation=dilation,
name=name + "/separable_conv1")
self._conv2 = Seperate_Conv(
output_channels[0],
output_channels[1],
stride=strides[1],
filter=filter_size[1],
dilation=dilation,
name=name + "/separable_conv2")
self._conv3 = Seperate_Conv(
output_channels[1],
output_channels[2],
stride=strides[2],
filter=filter_size[2],
dilation=dilation,
name=name + "/separable_conv3")
else:
self._conv1 = Seperate_Conv(
input_channels,
output_channels[0],
stride=strides[0],
filter=filter_size[0],
act="relu",
dilation=dilation,
name=name + "/separable_conv1")
self._conv2 = Seperate_Conv(
output_channels[0],
output_channels[1],
stride=strides[1],
filter=filter_size[1],
act="relu",
dilation=dilation,
name=name + "/separable_conv2")
self._conv3 = Seperate_Conv(
output_channels[1],
output_channels[2],
stride=strides[2],
filter=filter_size[2],
act="relu",
dilation=dilation,
name=name + "/separable_conv3")
if has_skip and skip_conv:
self._short = ConvBNLayer(
input_channels,
output_channels[-1],
1,
stride=strides[-1],
padding=0,
name=name + "/shortcut")
def forward(self, inputs):
layer_helper = LayerHelper(self.full_name(), act='relu')
if not self.activation_fn_in_separable_conv:
x = layer_helper.append_activation(inputs)
x = self._conv1(x)
x = layer_helper.append_activation(x)
x = self._conv2(x)
x = layer_helper.append_activation(x)
x = self._conv3(x)
else:
x = self._conv1(inputs)
x = self._conv2(x)
x = self._conv3(x)
if self.has_skip is False:
return x
if self.skip_conv:
skip = self._short(inputs)
else:
skip = inputs
return fluid.layers.elementwise_add(x, skip)
class XceptionDeeplab(fluid.dygraph.Layer):
#def __init__(self, backbone, class_dim=1000):
# add output_stride
def __init__(self, backbone, output_stride=16, class_dim=1000, **kwargs):
super(XceptionDeeplab, self).__init__()
bottleneck_params = gen_bottleneck_params(backbone)
self.backbone = backbone
self._conv1 = ConvBNLayer(
3,
32,
3,
stride=2,
padding=1,
act="relu",
name=self.backbone + "/entry_flow/conv1")
self._conv2 = ConvBNLayer(
32,
64,
3,
stride=1,
padding=1,
act="relu",
name=self.backbone + "/entry_flow/conv2")
"""
bottleneck_params = {
"entry_flow": (3, [2, 2, 2], [128, 256, 728]),
"middle_flow": (16, 1, 728),
"exit_flow": (2, [2, 1], [[728, 1024, 1024], [1536, 1536, 2048]])
}
if output_stride == 16:
entry_block3_stride = 2
middle_block_dilation = 1
exit_block_dilations = (1, 2)
elif output_stride == 8:
entry_block3_stride = 1
middle_block_dilation = 2
exit_block_dilations = (2, 4)
"""
self.block_num = bottleneck_params["entry_flow"][0]
self.strides = bottleneck_params["entry_flow"][1]
self.chns = bottleneck_params["entry_flow"][2]
self.strides = check_data(self.strides, self.block_num)
self.chns = check_data(self.chns, self.block_num)
self.entry_flow = []
self.middle_flow = []
self.stride = 2
self.output_stride = output_stride
s = self.stride
for i in range(self.block_num):
stride = self.strides[i] if check_stride(s * self.strides[i],
self.output_stride) else 1
xception_block = self.add_sublayer(
self.backbone + "/entry_flow/block" + str(i + 1),
Xception_Block(
input_channels=64 if i == 0 else self.chns[i - 1],
output_channels=self.chns[i],
strides=[1, 1, self.stride],
name=self.backbone + "/entry_flow/block" + str(i + 1)))
self.entry_flow.append(xception_block)
s = s * stride
self.stride = s
self.block_num = bottleneck_params["middle_flow"][0]
self.strides = bottleneck_params["middle_flow"][1]
self.chns = bottleneck_params["middle_flow"][2]
self.strides = check_data(self.strides, self.block_num)
self.chns = check_data(self.chns, self.block_num)
s = self.stride
for i in range(self.block_num):
stride = self.strides[i] if check_stride(s * self.strides[i],
self.output_stride) else 1
xception_block = self.add_sublayer(
self.backbone + "/middle_flow/block" + str(i + 1),
Xception_Block(
input_channels=728,
output_channels=728,
strides=[1, 1, self.strides[i]],
skip_conv=False,
name=self.backbone + "/middle_flow/block" + str(i + 1)))
self.middle_flow.append(xception_block)
s = s * stride
self.stride = s
self.block_num = bottleneck_params["exit_flow"][0]
self.strides = bottleneck_params["exit_flow"][1]
self.chns = bottleneck_params["exit_flow"][2]
self.strides = check_data(self.strides, self.block_num)
self.chns = check_data(self.chns, self.block_num)
s = self.stride
stride = self.strides[0] if check_stride(s * self.strides[0],
self.output_stride) else 1
self._exit_flow_1 = Xception_Block(
728,
self.chns[0], [1, 1, stride],
name=self.backbone + "/exit_flow/block1")
s = s * stride
stride = self.strides[1] if check_stride(s * self.strides[1],
self.output_stride) else 1
self._exit_flow_2 = Xception_Block(
self.chns[0][-1],
self.chns[1], [1, 1, stride],
dilation=2,
has_skip=False,
activation_fn_in_separable_conv=True,
name=self.backbone + "/exit_flow/block2")
s = s * stride
self.stride = s
self._drop = Dropout(p=0.5)
self._pool = Pool2D(pool_type="avg", global_pooling=True)
self._fc = Linear(
self.chns[1][-1],
class_dim,
param_attr=ParamAttr(name="fc_weights"),
bias_attr=ParamAttr(name="fc_bias"))
def forward(self, inputs):
x = self._conv1(inputs)
x = self._conv2(x)
feat_list = []
for i, ef in enumerate(self.entry_flow):
x = ef(x)
if i == 0:
feat_list.append(x)
for mf in self.middle_flow:
x = mf(x)
x = self._exit_flow_1(x)
x = self._exit_flow_2(x)
feat_list.append(x)
x = self._drop(x)
x = self._pool(x)
x = fluid.layers.squeeze(x, axes=[2, 3])
x = self._fc(x)
return x, feat_list
def Xception41_deeplab(**args):
model = XceptionDeeplab('xception_41', **args)
return model
def Xception65_deeplab(**args):
model = XceptionDeeplab("xception_65", **args)
return model
def Xception71_deeplab(**args):
model = XceptionDeeplab("xception_71", **args)
return model
\ No newline at end of file
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import numpy as np
import paddle
from paddle import fluid
from paddle.fluid import dygraph
from paddle.fluid.dygraph import Conv2D
from .architectures import layer_utils, xception_deeplab, resnet_vd
from dygraph.utils import utils
__all__ = ['DeepLabV3P', "deeplabv3p_resnet101_vd", "deeplabv3p_resnet101_vd_os8",
"deeplabv3p_resnet50_vd", "deeplabv3p_resnet50_vd_os8",
"deeplabv3p_xception65_deeplab"]
class ImageAverage(dygraph.Layer):
"""
Global average pooling
Args:
num_channels (int): the number of input channels.
"""
def __init__(self, num_channels):
super(ImageAverage, self).__init__()
self.conv_bn_relu = layer_utils.ConvBnRelu(num_channels,
num_filters=256,
filter_size=1)
def forward(self, input):
x = fluid.layers.reduce_mean(input, dim=[2, 3], keep_dim=True)
x = self.conv_bn_relu(x)
x = fluid.layers.resize_bilinear(x, out_shape=input.shape[2:])
return x
class ASPP(dygraph.Layer):
"""
Decoder module of DeepLabV3P model
Args:
output_stride (int): the ratio of input size and final feature size. Support 16 or 8.
in_channels (int): the number of input channels in decoder module.
using_sep_conv (bool): whether use separable conv or not. Default to True.
"""
def __init__(self, output_stride, in_channels, using_sep_conv=True):
super(ASPP, self).__init__()
if output_stride == 16:
aspp_ratios = (6, 12, 18)
elif output_stride == 8:
aspp_ratios = (12, 24, 36)
else:
raise NotImplementedError("Only support output_stride is 8 or 16, but received{}".format(output_stride))
self.image_average = ImageAverage(num_channels=in_channels)
# The first aspp using 1*1 conv
self.aspp1 = layer_utils.ConvBnRelu(num_channels=in_channels,
num_filters=256,
filter_size=1,
using_sep_conv=False)
# The second aspp using 3*3 (separable) conv at dilated rate aspp_ratios[0]
self.aspp2 = layer_utils.ConvBnRelu(num_channels=in_channels,
num_filters=256,
filter_size=3,
using_sep_conv=using_sep_conv,
dilation=aspp_ratios[0],
padding=aspp_ratios[0])
# The Third aspp using 3*3 (separable) conv at dilated rate aspp_ratios[1]
self.aspp3 = layer_utils.ConvBnRelu(num_channels=in_channels,
num_filters=256,
filter_size=3,
using_sep_conv=using_sep_conv,
dilation=aspp_ratios[1],
padding=aspp_ratios[1])
# The Third aspp using 3*3 (separable) conv at dilated rate aspp_ratios[2]
self.aspp4 = layer_utils.ConvBnRelu(num_channels=in_channels,
num_filters=256,
filter_size=3,
using_sep_conv=using_sep_conv,
dilation=aspp_ratios[2],
padding=aspp_ratios[2])
# After concat op, using 1*1 conv
self.conv_bn_relu = layer_utils.ConvBnRelu(num_channels=1280,
num_filters=256,
filter_size=1)
def forward(self, x):
x1 = self.image_average(x)
x2 = self.aspp1(x)
x3 = self.aspp2(x)
x4 = self.aspp3(x)
x5 = self.aspp4(x)
x = fluid.layers.concat([x1, x2, x3, x4, x5], axis=1)
x = self.conv_bn_relu(x)
x = fluid.layers.dropout(x, dropout_prob=0.1)
return x
class Decoder(dygraph.Layer):
"""
Decoder module of DeepLabV3P model
Args:
num_classes (int): the number of classes.
in_channels (int): the number of input channels in decoder module.
using_sep_conv (bool): whether use separable conv or not. Default to True.
"""
def __init__(self, num_classes, in_channels, using_sep_conv=True):
super(Decoder, self).__init__()
self.conv_bn_relu1 = layer_utils.ConvBnRelu(num_channels=in_channels,
num_filters=48,
filter_size=1)
self.conv_bn_relu2 = layer_utils.ConvBnRelu(num_channels=304,
num_filters=256,
filter_size=3,
using_sep_conv=using_sep_conv,
padding=1)
self.conv_bn_relu3 = layer_utils.ConvBnRelu(num_channels=256,
num_filters=256,
filter_size=3,
using_sep_conv=using_sep_conv,
padding=1)
self.conv = Conv2D(num_channels=256,
num_filters=num_classes,
filter_size=1)
def forward(self, x, low_level_feat):
low_level_feat = self.conv_bn_relu1(low_level_feat)
x = fluid.layers.resize_bilinear(x, low_level_feat.shape[2:])
x = fluid.layers.concat([x, low_level_feat], axis=1)
x = self.conv_bn_relu2(x)
x = self.conv_bn_relu3(x)
x = self.conv(x)
return x
class DeepLabV3P(dygraph.Layer):
"""
The DeepLabV3P consists of three main components, Backbone, ASPP and Decoder
The orginal artile refers to
"Encoder-Decoder with Atrous Separable Convolution for Semantic Image Segmentation"
Liang-Chieh Chen, Yukun Zhu, George Papandreou, Florian Schroff, Hartwig Adam.
(https://arxiv.org/abs/1802.02611)
Args:
backbone (str): backbone name, currently support Xception65, Resnet101_vd. Default Resnet101_vd.
num_classes (int): the unique number of target classes. Default 2.
output_stride (int): the ratio of input size and final feature size. Default 16.
backbone_indices (tuple): two values in the tuple indicte the indices of output of backbone.
the first index will be taken as a low-level feature in Deconder component;
the second one will be taken as input of ASPP component.
Usually backbone consists of four downsampling stage, and return an output of
each stage, so we set default (0, 3), which means taking feature map of the first
stage in backbone as low-level feature used in Decoder, and feature map of the fourth
stage as input of ASPP.
backbone_channels (tuple): the same length with "backbone_indices". It indicates the channels of corresponding index.
ignore_index (int): the value of ground-truth mask would be ignored while doing evaluation. Default 255.
using_sep_conv (bool): a bool value indicates whether using separable convolutions
in ASPP and Decoder components. Default True.
pretrained_model (str): the pretrained_model path of backbone.
"""
def __init__(self,
backbone,
num_classes=2,
output_stride=16,
backbone_indices=(0,3),
backbone_channels=(256, 2048),
ignore_index=255,
using_sep_conv=True,
pretrained_model=None):
super(DeepLabV3P, self).__init__()
self.backbone = build_backbone(backbone, output_stride)
self.aspp = ASPP(output_stride, backbone_channels[1], using_sep_conv)
self.decoder = Decoder(num_classes, backbone_channels[0], using_sep_conv)
self.ignore_index = ignore_index
self.EPS = 1e-5
self.backbone_indices = backbone_indices
self.init_weight(pretrained_model)
def forward(self, input, label=None, mode='train'):
_, feat_list = self.backbone(input)
low_level_feat = feat_list[self.backbone_indices[0]]
x = feat_list[self.backbone_indices[1]]
x = self.aspp(x)
logit = self.decoder(x, low_level_feat)
logit = fluid.layers.resize_bilinear(logit, input.shape[2:])
if self.training:
return self._get_loss(logit, label)
else:
score_map = fluid.layers.softmax(logit, axis=1)
score_map = fluid.layers.transpose(score_map, [0, 2, 3, 1])
pred = fluid.layers.argmax(score_map, axis=3)
pred = fluid.layers.unsqueeze(pred, axes=[3])
return pred, score_map
def init_weight(self, pretrained_model=None):
"""
Initialize the parameters of model parts.
Args:
pretrained_model ([str], optional): the pretrained_model path of backbone. Defaults to None.
"""
if pretrained_model is not None:
if os.path.exists(pretrained_model):
utils.load_pretrained_model(self.backbone, pretrained_model)
# utils.load_pretrained_model(self, pretrained_model)
# for param in self.backbone.parameters():
# param.stop_gradient = True
def _get_loss(self, logit, label):
"""
compute forward loss of the model
Args:
logit (tensor): the logit of model output
label (tensor): ground truth
Returns:
avg_loss (tensor): forward loss
"""
logit = fluid.layers.transpose(logit, [0, 2, 3, 1])
label = fluid.layers.transpose(label, [0, 2, 3, 1])
mask = label != self.ignore_index
mask = fluid.layers.cast(mask, 'float32')
loss, probs = fluid.layers.softmax_with_cross_entropy(
logit,
label,
ignore_index=self.ignore_index,
return_softmax=True,
axis=-1)
loss = loss * mask
avg_loss = fluid.layers.mean(loss) / (
fluid.layers.mean(mask) + self.EPS)
label.stop_gradient = True
mask.stop_gradient = True
return avg_loss
def build_backbone(backbone, output_stride):
if output_stride == 8:
dilation_dict = {2: 2, 3: 4}
elif output_stride == 16:
dilation_dict = {3: 2}
else:
raise Exception("deeplab only support stride 8 or 16")
model_dict = {"ResNet50_vd":resnet_vd.ResNet50_vd,
"ResNet101_vd":resnet_vd.ResNet101_vd,
"Xception65_deeplab": xception_deeplab.Xception65_deeplab}
model = model_dict[backbone]
return model(dilation_dict=dilation_dict)
def build_aspp(output_stride, using_sep_conv):
return ASPP(output_stride=output_stride, using_sep_conv=using_sep_conv)
def build_decoder(num_classes, using_sep_conv):
return Decoder(num_classes, using_sep_conv=using_sep_conv)
def deeplabv3p_resnet101_vd(*args, **kwargs):
pretrained_model = None
return DeepLabV3P(backbone='ResNet101_vd', pretrained_model=pretrained_model, **kwargs)
def deeplabv3p_resnet101_vd_os8(*args, **kwargs):
pretrained_model = None
return DeepLabV3P(backbone='ResNet101_vd', output_stride=8, pretrained_model=pretrained_model, **kwargs)
def deeplabv3p_resnet50_vd(*args, **kwargs):
pretrained_model = None
return DeepLabV3P(backbone='ResNet50_vd', pretrained_model=pretrained_model, **kwargs)
def deeplabv3p_resnet50_vd_os8(*args, **kwargs):
pretrained_model = None
return DeepLabV3P(backbone='ResNet50_vd', output_stride=8, pretrained_model=pretrained_model, **kwargs)
def deeplabv3p_xception65_deeplab(*args, **kwargs):
pretrained_model = None
return DeepLabV3P(backbone='Xception65_deeplab',
pretrained_model=pretrained_model,
backbone_indices=(0,1),
backbone_channels=(128, 2048),
**kwargs)
\ No newline at end of file
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册