未验证 提交 24856d05 编写于 作者: W wuzewu 提交者: GitHub

add mobilenetv3 backbone (#306)

* Add mobilenet v3

* Update config
上级 ea0978eb
EVAL_CROP_SIZE: (2049, 1025) # (width, height), for unpadding rangescaling and stepscaling
TRAIN_CROP_SIZE: (769, 769) # (width, height), for unpadding rangescaling and stepscaling
AUG:
AUG_METHOD: "stepscaling" # choice unpadding rangescaling and stepscaling
MAX_SCALE_FACTOR: 2.0 # for stepscaling
MIN_SCALE_FACTOR: 0.5 # for stepscaling
SCALE_STEP_SIZE: 0.25 # for stepscaling
MIRROR: True
BATCH_SIZE: 32
DATASET:
DATA_DIR: "./dataset/cityscapes/"
IMAGE_TYPE: "rgb" # choice rgb or rgba
NUM_CLASSES: 19
TEST_FILE_LIST: "dataset/cityscapes/val.list"
TRAIN_FILE_LIST: "dataset/cityscapes/train.list"
VAL_FILE_LIST: "dataset/cityscapes/val.list"
IGNORE_INDEX: 255
SEPARATOR: " "
FREEZE:
MODEL_FILENAME: "model"
PARAMS_FILENAME: "params"
MODEL:
DEFAULT_NORM_TYPE: "bn"
MODEL_NAME: "deeplabv3p"
DEEPLAB:
BACKBONE: "mobilenetv3_large"
ASPP_WITH_SEP_CONV: True
DECODER_USE_SEP_CONV: True
ENCODER_WITH_ASPP: True
ENABLE_DECODER: True
OUTPUT_STRIDE: 32
BACKBONE_LR_MULT_LIST: [0.15,0.35,0.65,0.85,1]
ENCODER:
POOLING_STRIDE: (4, 5)
POOLING_CROP_SIZE: (769, 769)
ASPP_WITH_SE: True
SE_USE_QSIGMOID: True
ASPP_CONVS_FILTERS: 128
ASPP_WITH_CONCAT_PROJECTION: False
ADD_IMAGE_LEVEL_FEATURE: False
DECODER:
USE_SUM_MERGE: True
CONV_FILTERS: 19
OUTPUT_IS_LOGITS: True
TRAIN:
PRETRAINED_MODEL_DIR: u"pretrained_model/mobilenetv3-1-0_large_bn_imagenet"
MODEL_SAVE_DIR: "saved_model/deeplabv3p_mobilenetv3_large_cityscapes"
SNAPSHOT_EPOCH: 1
SYNC_BATCH_NORM: True
TEST:
TEST_MODEL: "saved_model/deeplabv3p_mobilenetv3_large_cityscapes/final"
SOLVER:
LR: 0.2
LR_POLICY: "poly"
OPTIMIZER: "sgd"
NUM_EPOCHS: 850
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import paddle.fluid as fluid
from paddle.fluid.param_attr import ParamAttr
__all__ = [
'MobileNetV3', 'MobileNetV3_small_x0_35', 'MobileNetV3_small_x0_5',
'MobileNetV3_small_x0_75', 'MobileNetV3_small_x1_0',
'MobileNetV3_small_x1_25', 'MobileNetV3_large_x0_35',
'MobileNetV3_large_x0_5', 'MobileNetV3_large_x0_75',
'MobileNetV3_large_x1_0', 'MobileNetV3_large_x1_25'
]
class MobileNetV3():
def __init__(self,
scale=1.0,
model_name='small',
lr_mult_list=[1.0, 1.0, 1.0, 1.0, 1.0],
output_stride=None):
self.scale = scale
self.inplanes = 16
self.lr_mult_list = lr_mult_list
assert len(self.lr_mult_list) == 5, \
"lr_mult_list length in MobileNetV3 must be 5 but got {}!!".format(
len(self.lr_mult_list))
self.curr_stage = 0
self.decode_point = None
self.end_point = None
if model_name == "large":
self.cfg = [
# k, exp, c, se, nl, s,
[3, 16, 16, False, 'relu', 1],
[3, 64, 24, False, 'relu', 2],
[3, 72, 24, False, 'relu', 1],
[5, 72, 40, True, 'relu', 2],
[5, 120, 40, True, 'relu', 1],
[5, 120, 40, True, 'relu', 1],
[3, 240, 80, False, 'hard_swish', 2],
[3, 200, 80, False, 'hard_swish', 1],
[3, 184, 80, False, 'hard_swish', 1],
[3, 184, 80, False, 'hard_swish', 1],
[3, 480, 112, True, 'hard_swish', 1],
[3, 672, 112, True, 'hard_swish', 1],
# The number of channels in the last 4 stages is reduced by a
# factor of 2 compared to the standard implementation.
[5, 336, 80, True, 'hard_swish', 2],
[5, 480, 80, True, 'hard_swish', 1],
[5, 480, 80, True, 'hard_swish', 1],
]
self.cls_ch_squeeze = 480
self.cls_ch_expand = 1280
self.lr_interval = 3
elif model_name == "small":
self.cfg = [
# k, exp, c, se, nl, s,
[3, 16, 16, True, 'relu', 2],
[3, 72, 24, False, 'relu', 2],
[3, 88, 24, False, 'relu', 1],
[5, 96, 40, True, 'hard_swish', 2],
[5, 240, 40, True, 'hard_swish', 1],
[5, 240, 40, True, 'hard_swish', 1],
[5, 120, 48, True, 'hard_swish', 1],
[5, 144, 48, True, 'hard_swish', 1],
# The number of channels in the last 4 stages is reduced by a
# factor of 2 compared to the standard implementation.
[5, 144, 48, True, 'hard_swish', 2],
[5, 288, 48, True, 'hard_swish', 1],
[5, 288, 48, True, 'hard_swish', 1],
]
self.cls_ch_squeeze = 288
self.cls_ch_expand = 1280
self.lr_interval = 2
else:
raise NotImplementedError(
"mode[{}_model] is not implemented!".format(model_name))
self.modify_bottle_params(output_stride)
def modify_bottle_params(self, output_stride=None):
if output_stride is not None and output_stride % 2 != 0:
raise Exception("output stride must to be even number")
if output_stride is None:
return
else:
stride = 2
for i, _cfg in enumerate(self.cfg):
stride = stride * _cfg[-1]
if stride > output_stride:
s = 1
self.cfg[i][-1] = s
def net(self, input, class_dim=1000, end_points=None, decode_points=None):
scale = self.scale
inplanes = self.inplanes
cfg = self.cfg
cls_ch_squeeze = self.cls_ch_squeeze
cls_ch_expand = self.cls_ch_expand
# conv1
conv = self.conv_bn_layer(
input,
filter_size=3,
num_filters=self.make_divisible(inplanes * scale),
stride=2,
padding=1,
num_groups=1,
if_act=True,
act='hard_swish',
name='conv1')
i = 0
inplanes = self.make_divisible(inplanes * scale)
for layer_cfg in cfg:
conv = self.residual_unit(
input=conv,
num_in_filter=inplanes,
num_mid_filter=self.make_divisible(scale * layer_cfg[1]),
num_out_filter=self.make_divisible(scale * layer_cfg[2]),
act=layer_cfg[4],
stride=layer_cfg[5],
filter_size=layer_cfg[0],
use_se=layer_cfg[3],
name='conv' + str(i + 2))
inplanes = self.make_divisible(scale * layer_cfg[2])
i += 1
self.curr_stage = i
conv = self.conv_bn_layer(
input=conv,
filter_size=1,
num_filters=self.make_divisible(scale * cls_ch_squeeze),
stride=1,
padding=0,
num_groups=1,
if_act=True,
act='hard_swish',
name='conv_last')
return conv, self.decode_point
conv = fluid.layers.pool2d(
input=conv, pool_type='avg', global_pooling=True, use_cudnn=False)
conv = fluid.layers.conv2d(
input=conv,
num_filters=cls_ch_expand,
filter_size=1,
stride=1,
padding=0,
act=None,
param_attr=ParamAttr(name='last_1x1_conv_weights'),
bias_attr=False)
conv = fluid.layers.hard_swish(conv)
drop = fluid.layers.dropout(x=conv, dropout_prob=0.2)
out = fluid.layers.fc(
input=drop,
size=class_dim,
param_attr=ParamAttr(name='fc_weights'),
bias_attr=ParamAttr(name='fc_offset'))
return out
def conv_bn_layer(self,
input,
filter_size,
num_filters,
stride,
padding,
num_groups=1,
if_act=True,
act=None,
name=None,
use_cudnn=True,
res_last_bn_init=False):
lr_idx = self.curr_stage // self.lr_interval
lr_idx = min(lr_idx, len(self.lr_mult_list) - 1)
lr_mult = self.lr_mult_list[lr_idx]
conv = fluid.layers.conv2d(
input=input,
num_filters=num_filters,
filter_size=filter_size,
stride=stride,
padding=padding,
groups=num_groups,
act=None,
use_cudnn=use_cudnn,
param_attr=ParamAttr(name=name + '_weights', learning_rate=lr_mult),
bias_attr=False)
bn_name = name + '_bn'
bn = fluid.layers.batch_norm(
input=conv,
param_attr=ParamAttr(
name=bn_name + "_scale",
regularizer=fluid.regularizer.L2DecayRegularizer(
regularization_coeff=0.0)),
bias_attr=ParamAttr(
name=bn_name + "_offset",
regularizer=fluid.regularizer.L2DecayRegularizer(
regularization_coeff=0.0)),
moving_mean_name=bn_name + '_mean',
moving_variance_name=bn_name + '_variance')
if if_act:
if act == 'relu':
bn = fluid.layers.relu(bn)
elif act == 'hard_swish':
bn = fluid.layers.hard_swish(bn)
return bn
def make_divisible(self, v, divisor=8, min_value=None):
if min_value is None:
min_value = divisor
new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
if new_v < 0.9 * v:
new_v += divisor
return new_v
def se_block(self, input, num_out_filter, ratio=4, name=None):
lr_idx = self.curr_stage // self.lr_interval
lr_idx = min(lr_idx, len(self.lr_mult_list) - 1)
lr_mult = self.lr_mult_list[lr_idx]
num_mid_filter = num_out_filter // ratio
pool = fluid.layers.pool2d(
input=input, pool_type='avg', global_pooling=True, use_cudnn=False)
conv1 = fluid.layers.conv2d(
input=pool,
filter_size=1,
num_filters=num_mid_filter,
act='relu',
param_attr=ParamAttr(
name=name + '_1_weights', learning_rate=lr_mult),
bias_attr=ParamAttr(name=name + '_1_offset', learning_rate=lr_mult))
conv2 = fluid.layers.conv2d(
input=conv1,
filter_size=1,
num_filters=num_out_filter,
act='hard_sigmoid',
param_attr=ParamAttr(
name=name + '_2_weights', learning_rate=lr_mult),
bias_attr=ParamAttr(name=name + '_2_offset', learning_rate=lr_mult))
scale = fluid.layers.elementwise_mul(x=input, y=conv2, axis=0)
return scale
def residual_unit(self,
input,
num_in_filter,
num_mid_filter,
num_out_filter,
stride,
filter_size,
act=None,
use_se=False,
name=None):
conv0 = self.conv_bn_layer(
input=input,
filter_size=1,
num_filters=num_mid_filter,
stride=1,
padding=0,
if_act=True,
act=act,
name=name + '_expand')
conv1 = self.conv_bn_layer(
input=conv0,
filter_size=filter_size,
num_filters=num_mid_filter,
stride=stride,
padding=int((filter_size - 1) // 2),
if_act=True,
act=act,
num_groups=num_mid_filter,
use_cudnn=False,
name=name + '_depthwise')
if self.curr_stage == 5:
self.decode_point = conv1
if use_se:
conv1 = self.se_block(
input=conv1, num_out_filter=num_mid_filter, name=name + '_se')
conv2 = self.conv_bn_layer(
input=conv1,
filter_size=1,
num_filters=num_out_filter,
stride=1,
padding=0,
if_act=False,
name=name + '_linear',
res_last_bn_init=True)
if num_in_filter != num_out_filter or stride != 1:
return conv2
else:
return fluid.layers.elementwise_add(x=input, y=conv2, act=None)
def MobileNetV3_small_x0_35():
model = MobileNetV3(model_name='small', scale=0.35)
return model
def MobileNetV3_small_x0_5():
model = MobileNetV3(model_name='small', scale=0.5)
return model
def MobileNetV3_small_x0_75():
model = MobileNetV3(model_name='small', scale=0.75)
return model
def MobileNetV3_small_x1_0(**args):
model = MobileNetV3(model_name='small', scale=1.0, **args)
return model
def MobileNetV3_small_x1_25():
model = MobileNetV3(model_name='small', scale=1.25)
return model
def MobileNetV3_large_x0_35():
model = MobileNetV3(model_name='large', scale=0.35)
return model
def MobileNetV3_large_x0_5():
model = MobileNetV3(model_name='large', scale=0.5)
return model
def MobileNetV3_large_x0_75():
model = MobileNetV3(model_name='large', scale=0.75)
return model
def MobileNetV3_large_x1_0(**args):
model = MobileNetV3(model_name='large', scale=1.0, **args)
return model
def MobileNetV3_large_x1_25():
model = MobileNetV3(model_name='large', scale=1.25)
return model
......@@ -109,6 +109,10 @@ def bn_relu(data):
return fluid.layers.relu(bn(data))
def qsigmoid(data):
return fluid.layers.relu6(data + 3) * 0.16667
def relu(data):
return fluid.layers.relu(data)
......
......@@ -21,10 +21,11 @@ import paddle
import paddle.fluid as fluid
from utils.config import cfg
from models.libs.model_libs import scope, name_scope
from models.libs.model_libs import bn, bn_relu, relu
from models.libs.model_libs import bn, bn_relu, relu, qsigmoid
from models.libs.model_libs import conv
from models.libs.model_libs import separate_conv
from models.backbone.mobilenet_v2 import MobileNetV2 as mobilenet_backbone
from models.backbone.mobilenet_v2 import MobileNetV2 as mobilenet_v2_backbone
from models.backbone.mobilenet_v3 import MobileNetV3 as mobilenet_v3_backbone
from models.backbone.xception import Xception as xception_backbone
from models.backbone.resnet_vd import ResNet as resnet_vd_backbone
......@@ -35,22 +36,42 @@ def encoder(input):
# OUTPUT_STRIDE: 下采样倍数,8或16,决定aspp_ratios大小
# aspp_ratios:ASPP模块空洞卷积的采样率
if cfg.MODEL.DEEPLAB.OUTPUT_STRIDE == 16:
aspp_ratios = [6, 12, 18]
elif cfg.MODEL.DEEPLAB.OUTPUT_STRIDE == 8:
aspp_ratios = [12, 24, 36]
if not cfg.MODEL.DEEPLAB.ENCODER.ASPP_RATIOS:
if cfg.MODEL.DEEPLAB.OUTPUT_STRIDE == 16:
aspp_ratios = [6, 12, 18]
elif cfg.MODEL.DEEPLAB.OUTPUT_STRIDE == 8:
aspp_ratios = [12, 24, 36]
else:
aspp_ratios = []
else:
raise Exception("deeplab only support stride 8 or 16")
aspp_ratios = cfg.MODEL.DEEPLAB.ENCODER.ASPP_RATIOS
param_attr = fluid.ParamAttr(
name=name_scope + 'weights',
regularizer=None,
initializer=fluid.initializer.TruncatedNormal(loc=0.0, scale=0.06))
concat_logits = []
with scope('encoder'):
channel = 256
channel = cfg.MODEL.DEEPLAB.ENCODER.ASPP_CONVS_FILTERS
with scope("image_pool"):
image_avg = fluid.layers.reduce_mean(input, [2, 3], keep_dim=True)
image_avg = bn_relu(
if not cfg.MODEL.DEEPLAB.ENCODER.POOLING_CROP_SIZE:
image_avg = fluid.layers.reduce_mean(
input, [2, 3], keep_dim=True)
else:
pool_w = int((cfg.MODEL.DEEPLAB.ENCODER.POOLING_CROP_SIZE[0] -
1.0) / cfg.MODEL.DEEPLAB.OUTPUT_STRIDE + 1.0)
pool_h = int((cfg.MODEL.DEEPLAB.ENCODER.POOLING_CROP_SIZE[1] -
1.0) / cfg.MODEL.DEEPLAB.OUTPUT_STRIDE + 1.0)
image_avg = fluid.layers.pool2d(
input,
pool_size=(pool_h, pool_w),
pool_stride=cfg.MODEL.DEEPLAB.ENCODER.POOLING_STRIDE,
pool_type='avg',
pool_padding='VALID')
act = qsigmoid if cfg.MODEL.DEEPLAB.ENCODER.SE_USE_QSIGMOID else bn_relu
image_avg = act(
conv(
image_avg,
channel,
......@@ -60,6 +81,8 @@ def encoder(input):
padding=0,
param_attr=param_attr))
image_avg = fluid.layers.resize_bilinear(image_avg, input.shape[2:])
if cfg.MODEL.DEEPLAB.ENCODER.ADD_IMAGE_LEVEL_FEATURE:
concat_logits.append(image_avg)
with scope("aspp0"):
aspp0 = bn_relu(
......@@ -71,62 +94,154 @@ def encoder(input):
groups=1,
padding=0,
param_attr=param_attr))
with scope("aspp1"):
if cfg.MODEL.DEEPLAB.ASPP_WITH_SEP_CONV:
aspp1 = separate_conv(
input, channel, 1, 3, dilation=aspp_ratios[0], act=relu)
else:
aspp1 = bn_relu(
conv(
input,
channel,
stride=1,
filter_size=3,
dilation=aspp_ratios[0],
padding=aspp_ratios[0],
param_attr=param_attr))
with scope("aspp2"):
if cfg.MODEL.DEEPLAB.ASPP_WITH_SEP_CONV:
aspp2 = separate_conv(
input, channel, 1, 3, dilation=aspp_ratios[1], act=relu)
else:
aspp2 = bn_relu(
conv(
input,
channel,
stride=1,
filter_size=3,
dilation=aspp_ratios[1],
padding=aspp_ratios[1],
param_attr=param_attr))
with scope("aspp3"):
if cfg.MODEL.DEEPLAB.ASPP_WITH_SEP_CONV:
aspp3 = separate_conv(
input, channel, 1, 3, dilation=aspp_ratios[2], act=relu)
else:
aspp3 = bn_relu(
concat_logits.append(aspp0)
if aspp_ratios:
with scope("aspp1"):
if cfg.MODEL.DEEPLAB.ASPP_WITH_SEP_CONV:
aspp1 = separate_conv(
input, channel, 1, 3, dilation=aspp_ratios[0], act=relu)
else:
aspp1 = bn_relu(
conv(
input,
channel,
stride=1,
filter_size=3,
dilation=aspp_ratios[0],
padding=aspp_ratios[0],
param_attr=param_attr))
concat_logits.append(aspp1)
with scope("aspp2"):
if cfg.MODEL.DEEPLAB.ASPP_WITH_SEP_CONV:
aspp2 = separate_conv(
input, channel, 1, 3, dilation=aspp_ratios[1], act=relu)
else:
aspp2 = bn_relu(
conv(
input,
channel,
stride=1,
filter_size=3,
dilation=aspp_ratios[1],
padding=aspp_ratios[1],
param_attr=param_attr))
concat_logits.append(aspp2)
with scope("aspp3"):
if cfg.MODEL.DEEPLAB.ASPP_WITH_SEP_CONV:
aspp3 = separate_conv(
input, channel, 1, 3, dilation=aspp_ratios[2], act=relu)
else:
aspp3 = bn_relu(
conv(
input,
channel,
stride=1,
filter_size=3,
dilation=aspp_ratios[2],
padding=aspp_ratios[2],
param_attr=param_attr))
concat_logits.append(aspp3)
with scope("concat"):
data = fluid.layers.concat(concat_logits, axis=1)
if cfg.MODEL.DEEPLAB.ENCODER.ASPP_WITH_CONCAT_PROJECTION:
data = bn_relu(
conv(
input,
data,
channel,
stride=1,
filter_size=3,
dilation=aspp_ratios[2],
padding=aspp_ratios[2],
1,
1,
groups=1,
padding=0,
param_attr=param_attr))
with scope("concat"):
data = fluid.layers.concat([image_avg, aspp0, aspp1, aspp2, aspp3],
axis=1)
data = bn_relu(
data = fluid.layers.dropout(data, 0.9)
if cfg.MODEL.DEEPLAB.ENCODER.ASPP_WITH_SE:
data = data * image_avg
return data
def _decoder_with_sum_merge(encode_data, decode_shortcut, param_attr):
encode_data = fluid.layers.resize_bilinear(encode_data,
decode_shortcut.shape[2:])
encode_data = conv(
encode_data,
cfg.MODEL.DEEPLAB.DECODER.CONV_FILTERS,
1,
1,
groups=1,
padding=0,
param_attr=param_attr)
with scope('merge'):
decode_shortcut = conv(
decode_shortcut,
cfg.MODEL.DEEPLAB.DECODER.CONV_FILTERS,
1,
1,
groups=1,
padding=0,
param_attr=param_attr)
return encode_data + decode_shortcut
def _decoder_with_concat(encode_data, decode_shortcut, param_attr):
with scope('concat'):
decode_shortcut = bn_relu(
conv(
decode_shortcut,
48,
1,
1,
groups=1,
padding=0,
param_attr=param_attr))
encode_data = fluid.layers.resize_bilinear(encode_data,
decode_shortcut.shape[2:])
encode_data = fluid.layers.concat([encode_data, decode_shortcut],
axis=1)
if cfg.MODEL.DEEPLAB.DECODER_USE_SEP_CONV:
with scope("separable_conv1"):
encode_data = separate_conv(
encode_data,
cfg.MODEL.DEEPLAB.DECODER.CONV_FILTERS,
1,
3,
dilation=1,
act=relu)
with scope("separable_conv2"):
encode_data = separate_conv(
encode_data,
cfg.MODEL.DEEPLAB.DECODER.CONV_FILTERS,
1,
3,
dilation=1,
act=relu)
else:
with scope("decoder_conv1"):
encode_data = bn_relu(
conv(
data,
channel,
1,
1,
groups=1,
padding=0,
encode_data,
cfg.MODEL.DEEPLAB.DECODER.CONV_FILTERS,
stride=1,
filter_size=3,
dilation=1,
padding=1,
param_attr=param_attr))
data = fluid.layers.dropout(data, 0.9)
return data
with scope("decoder_conv2"):
encode_data = bn_relu(
conv(
encode_data,
cfg.MODEL.DEEPLAB.DECODER.CONV_FILTERS,
stride=1,
filter_size=3,
dilation=1,
padding=1,
param_attr=param_attr))
return encode_data
def decoder(encode_data, decode_shortcut):
......@@ -139,61 +254,49 @@ def decoder(encode_data, decode_shortcut):
regularizer=None,
initializer=fluid.initializer.TruncatedNormal(loc=0.0, scale=0.06))
with scope('decoder'):
with scope('concat'):
decode_shortcut = bn_relu(
conv(
decode_shortcut,
48,
1,
1,
groups=1,
padding=0,
param_attr=param_attr))
if cfg.MODEL.DEEPLAB.DECODER.USE_SUM_MERGE:
return _decoder_with_sum_merge(encode_data, decode_shortcut,
param_attr)
encode_data = fluid.layers.resize_bilinear(
encode_data, decode_shortcut.shape[2:])
encode_data = fluid.layers.concat([encode_data, decode_shortcut],
axis=1)
if cfg.MODEL.DEEPLAB.DECODER_USE_SEP_CONV:
with scope("separable_conv1"):
encode_data = separate_conv(
encode_data, 256, 1, 3, dilation=1, act=relu)
with scope("separable_conv2"):
encode_data = separate_conv(
encode_data, 256, 1, 3, dilation=1, act=relu)
else:
with scope("decoder_conv1"):
encode_data = bn_relu(
conv(
encode_data,
256,
stride=1,
filter_size=3,
dilation=1,
padding=1,
param_attr=param_attr))
with scope("decoder_conv2"):
encode_data = bn_relu(
conv(
encode_data,
256,
stride=1,
filter_size=3,
dilation=1,
padding=1,
param_attr=param_attr))
return encode_data
return _decoder_with_concat(encode_data, decode_shortcut, param_attr)
def mobilenet(input):
if 'v3' in cfg.MODEL.DEEPLAB.BACKBONE:
model_name = 'large' if 'large' in cfg.MODEL.DEEPLAB.BACKBONE else 'small'
return _mobilenetv3(input, model_name)
return _mobilenetv2(input)
def mobilenetv2(input):
def _mobilenetv3(input, model_name='large'):
# Backbone: mobilenetv3结构配置
# DEPTH_MULTIPLIER: mobilenetv3的scale设置,默认1.0
# OUTPUT_STRIDE:下采样倍数
scale = cfg.MODEL.DEEPLAB.DEPTH_MULTIPLIER
output_stride = cfg.MODEL.DEEPLAB.OUTPUT_STRIDE
lr_mult_shortcut = cfg.MODEL.DEEPLAB.BACKBONE_LR_MULT_LIST
model = mobilenet_v3_backbone(
scale=scale,
output_stride=output_stride,
model_name=model_name,
lr_mult_list=lr_mult_shortcut)
data, decode_shortcut = model.net(input)
return data, decode_shortcut
def _mobilenetv2(input):
# Backbone: mobilenetv2结构配置
# DEPTH_MULTIPLIER: mobilenetv2的scale设置,默认1.0
# OUTPUT_STRIDE:下采样倍数
# end_points: mobilenetv2的block数
# decode_point: 从mobilenetv2中引出分支所在block数, 作为decoder输入
if cfg.MODEL.DEEPLAB.BACKBONE_LR_MULT_LIST is not None:
print(
'mobilenetv2 backbone do not support BACKBONE_LR_MULT_LIST setting')
scale = cfg.MODEL.DEEPLAB.DEPTH_MULTIPLIER
output_stride = cfg.MODEL.DEEPLAB.OUTPUT_STRIDE
model = mobilenet_backbone(scale=scale, output_stride=output_stride)
model = mobilenet_v2_backbone(scale=scale, output_stride=output_stride)
end_points = 18
decode_point = 4
data, decode_shortcuts = model.net(
......@@ -270,11 +373,7 @@ def deeplabv3p(img, num_classes):
'xception backbone do not support BACKBONE_LR_MULT_LIST setting'
)
elif 'mobilenet' in cfg.MODEL.DEEPLAB.BACKBONE:
data, decode_shortcut = mobilenetv2(img)
if cfg.MODEL.DEEPLAB.BACKBONE_LR_MULT_LIST is not None:
print(
'mobilenetv2 backbone do not support BACKBONE_LR_MULT_LIST setting'
)
data, decode_shortcut = mobilenet(img)
elif 'resnet' in cfg.MODEL.DEEPLAB.BACKBONE:
data, decode_shortcut = resnet_vd(img)
else:
......@@ -294,16 +393,20 @@ def deeplabv3p(img, num_classes):
regularizer=fluid.regularizer.L2DecayRegularizer(
regularization_coeff=0.0),
initializer=fluid.initializer.TruncatedNormal(loc=0.0, scale=0.01))
with scope('logit'):
with fluid.name_scope('last_conv'):
logit = conv(
data,
num_classes,
1,
stride=1,
padding=0,
bias_attr=True,
param_attr=param_attr)
logit = fluid.layers.resize_bilinear(logit, img.shape[2:])
if not cfg.MODEL.DEEPLAB.DECODER.OUTPUT_IS_LOGITS:
with scope('logit'):
with fluid.name_scope('last_conv'):
logit = conv(
data,
num_classes,
1,
stride=1,
padding=0,
bias_attr=True,
param_attr=param_attr)
else:
logit = data
logit = fluid.layers.resize_bilinear(logit, img.shape[2:])
return logit
......@@ -198,17 +198,28 @@ cfg.MODEL.SCALE_LOSS = "DYNAMIC"
cfg.MODEL.DEEPLAB.BACKBONE = "xception_65"
# DeepLab output stride
cfg.MODEL.DEEPLAB.OUTPUT_STRIDE = 16
# MobileNet v2 backbone scale 设置
# MobileNet v2/v3 backbone scale 设置
cfg.MODEL.DEEPLAB.DEPTH_MULTIPLIER = 1.0
# MobileNet v2 backbone scale 设置
# DeepLab Encoder 设置
cfg.MODEL.DEEPLAB.ENCODER_WITH_ASPP = True
# MobileNet v2 backbone scale 设置
cfg.MODEL.DEEPLAB.ENCODER.POOLING_STRIDE = [1, 1]
cfg.MODEL.DEEPLAB.ENCODER.POOLING_CROP_SIZE = None
cfg.MODEL.DEEPLAB.ENCODER.ASPP_WITH_SE = False
cfg.MODEL.DEEPLAB.ENCODER.SE_USE_QSIGMOID = False
cfg.MODEL.DEEPLAB.ENCODER.ASPP_CONVS_FILTERS = 256
cfg.MODEL.DEEPLAB.ENCODER.ASPP_WITH_CONCAT_PROJECTION = True
cfg.MODEL.DEEPLAB.ENCODER.ADD_IMAGE_LEVEL_FEATURE = True
cfg.MODEL.DEEPLAB.ENCODER.ASPP_RATIOS = None
# DeepLab Decoder 设置
cfg.MODEL.DEEPLAB.ENABLE_DECODER = True
cfg.MODEL.DEEPLAB.DECODER.USE_SUM_MERGE = False
cfg.MODEL.DEEPLAB.DECODER.CONV_FILTERS = 256
cfg.MODEL.DEEPLAB.DECODER.OUTPUT_IS_LOGITS = False
# ASPP是否使用可分离卷积
cfg.MODEL.DEEPLAB.ASPP_WITH_SEP_CONV = True
# 解码器是否使用可分离卷积
cfg.MODEL.DEEPLAB.DECODER_USE_SEP_CONV = True
# resnet_vd分阶段学习率
# Backbone分阶段学习率
cfg.MODEL.DEEPLAB.BACKBONE_LR_MULT_LIST = None
########################## UNET模型配置 #######################################
......
......@@ -34,6 +34,10 @@ model_urls = {
"https://paddle-imagenet-models-name.bj.bcebos.com/MobileNetV2_x0_5_pretrained.tar",
"mobilenetv2-0-25_bn_imagenet":
"https://paddle-imagenet-models-name.bj.bcebos.com/MobileNetV2_x0_25_pretrained.tar",
"mobilenetv3-1-0_large_bn_imagenet":
"https://paddle-imagenet-models-name.bj.bcebos.com/MobileNetV3_large_x1_0_ssld_pretrained.tar",
"mobilenetv3-1-0_small_bn_imagenet":
"https://paddle-imagenet-models-name.bj.bcebos.com/MobileNetV3_small_x1_0_ssld_pretrained.tar",
"xception41_imagenet":
"https://paddleseg.bj.bcebos.com/models/Xception41_pretrained.tgz",
"xception65_imagenet":
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册