未验证 提交 b23b1096 编写于 作者: Z Zeyu Chen 提交者: GitHub

Merge pull request #15 from pennypm/master

add pspnet
EVAL_CROP_SIZE: (2049, 1025) # (width, height), for unpadding rangescaling and stepscaling
TRAIN_CROP_SIZE: (713, 713) # (width, height), for unpadding rangescaling and stepscaling
AUG:
AUG_METHOD: "stepscaling" # choice unpadding rangescaling and stepscaling
FIX_RESIZE_SIZE: (640, 640) # (width, height), for unpadding
INF_RESIZE_VALUE: 500 # for rangescaling
MAX_RESIZE_VALUE: 600 # for rangescaling
MIN_RESIZE_VALUE: 400 # for rangescaling
MAX_SCALE_FACTOR: 2.0 # for stepscaling
MIN_SCALE_FACTOR: 0.5 # for stepscaling
SCALE_STEP_SIZE: 0.25 # for stepscaling
MIRROR: True
BATCH_SIZE: 4
DATASET:
DATA_DIR: "./dataset/cityscapes/"
IMAGE_TYPE: "rgb" # choice rgb or rgba
NUM_CLASSES: 19
TEST_FILE_LIST: "dataset/cityscapes/val.list"
TRAIN_FILE_LIST: "dataset/cityscapes/train.list"
VAL_FILE_LIST: "dataset/cityscapes/val.list"
IGNORE_INDEX: 255
FREEZE:
MODEL_FILENAME: "model"
PARAMS_FILENAME: "params"
MODEL:
MODEL_NAME: "pspnet"
DEFAULT_NORM_TYPE: "bn"
TEST:
TEST_MODEL: "pretrained_model/pspnet50_ADE20K/"
TRAIN:
MODEL_SAVE_DIR: "snapshots/cityscape_pspnet50/"
PRETRAINED_MODEL_DIR: u"pretrained_model/pspnet50_ADE20K/"
SNAPSHOT_EPOCH: 10
SOLVER:
LR: 0.001
LR_POLICY: "poly"
OPTIMIZER: "sgd"
NUM_EPOCHS: 700
......@@ -374,7 +374,7 @@ def rand_crop(crop_img, crop_seg, mode=ModelPhase.TRAIN):
Args:
crop_img(numpy.ndarray): 输入图像
crop_seg(numpy.ndarray): 标签图
mode(string): 模式, 默认训练模式,验证或预测模式时crop尺寸需大于原始图片尺寸, 其他模式无限制
mode(string): 模式, 默认训练模式,验证或预测、可视化模式时crop尺寸需大于原始图片尺寸
Returns:
裁剪后的图片和标签图
......@@ -391,7 +391,7 @@ def rand_crop(crop_img, crop_seg, mode=ModelPhase.TRAIN):
crop_width = cfg.EVAL_CROP_SIZE[0]
crop_height = cfg.EVAL_CROP_SIZE[1]
if ModelPhase.is_eval(mode) or ModelPhase.is_predict(mode):
if not ModelPhase.is_train(mode):
if (crop_height < img_height or crop_width < img_width):
raise Exception(
"Crop size({},{}) must large than img size({},{}) when in EvalPhase."
......
......@@ -85,7 +85,7 @@ class ResNet():
depth = [3, 8, 36, 3]
num_filters = [64, 128, 256, 512]
if self.stem == 'icnet':
if self.stem == 'icnet' or self.stem == 'pspnet':
conv = self.conv_bn_layer(
input=input,
num_filters=int(64 * self.scale),
......@@ -139,7 +139,7 @@ class ResNet():
else:
conv_name = "res" + str(block + 2) + "b" + str(i)
else:
conv_name = "conv" + str(block + 2) + '_' + str(1 + i)
conv_name = "res" + str(block + 2) + chr(97 + i)
dilation_rate = get_dilated_rate(dilation_dict, block)
conv = self.bottleneck_block(
......@@ -215,6 +215,12 @@ class ResNet():
groups=1,
act=None,
name=None):
if self.stem == 'pspnet':
bias_attr=ParamAttr(name=name + "_biases")
else:
bias_attr=False
conv = fluid.layers.conv2d(
input=input,
num_filters=num_filters,
......@@ -224,20 +230,21 @@ class ResNet():
dilation=dilation,
groups=groups,
act=None,
param_attr=ParamAttr(name=name + "/weights"),
bias_attr=False,
param_attr=ParamAttr(name=name + "_weights"),
bias_attr=bias_attr,
name=name + '.conv2d.output.1')
bn_name = name + '/BatchNorm/'
return fluid.layers.batch_norm(
input=conv,
if name == "conv1":
bn_name = "bn_" + name
else:
bn_name = "bn" + name[3:]
return fluid.layers.batch_norm(input=conv,
act=act,
name=bn_name + '.output.1',
param_attr=ParamAttr(name=bn_name + 'gamma'),
bias_attr=ParamAttr(bn_name + 'beta'),
moving_mean_name=bn_name + 'moving_mean',
moving_variance_name=bn_name + 'moving_variance',
)
param_attr=ParamAttr(name=bn_name + '_scale'),
bias_attr=ParamAttr(bn_name + '_offset'),
moving_mean_name=bn_name + '_mean',
moving_variance_name=bn_name + '_variance', )
def shortcut(self, input, ch_out, stride, is_first, name):
ch_in = input.shape[1]
......@@ -247,12 +254,17 @@ class ResNet():
return input
def bottleneck_block(self, input, num_filters, stride, name, dilation=1):
if self.stem == 'pspnet' and self.layers == 101:
strides = [1, stride]
else:
strides = [stride, 1]
conv0 = self.conv_bn_layer(
input=input,
num_filters=num_filters,
filter_size=1,
dilation=1,
stride=stride,
stride=strides[0],
act='relu',
name=name + "_branch2a")
if dilation > 1:
......@@ -262,6 +274,7 @@ class ResNet():
num_filters=num_filters,
filter_size=3,
dilation=dilation,
stride=strides[1],
act='relu',
name=name + "_branch2b")
conv2 = self.conv_bn_layer(
......
......@@ -73,6 +73,7 @@ def map_model_name(model_name):
"unet": "unet.unet",
"deeplabv3p": "deeplab.deeplabv3p",
"icnet": "icnet.icnet",
"pspnet": "pspnet.pspnet",
}
if model_name in name_dict.keys():
return name_dict[model_name]
......
# coding: utf8
# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import paddle.fluid as fluid
from paddle.fluid.param_attr import ParamAttr
from models.libs.model_libs import scope, name_scope
from models.libs.model_libs import avg_pool, conv, bn
from models.backbone.resnet import ResNet as resnet_backbone
from utils.config import cfg
def get_logit_interp(input, num_classes, out_shape, name="logit"):
    """Final 1x1 conv producing per-class scores, bilinearly upsampled.

    Args:
        input: feature map from the fused PSP module.
        num_classes: number of segmentation classes (conv output channels).
        out_shape: target spatial size (the original image's H, W).
        name: scope/prefix used for parameter and op names.

    Returns:
        Per-class logits resized to ``out_shape``.
    """
    # Truncated-normal init with weight decay explicitly disabled for
    # the classifier weights.
    param_attr = fluid.ParamAttr(
        name=name + 'weights',
        regularizer=fluid.regularizer.L2DecayRegularizer(
            regularization_coeff=0.0),
        initializer=fluid.initializer.TruncatedNormal(loc=0.0, scale=0.01))

    with scope(name):
        score = conv(
            input,
            num_classes,
            filter_size=1,
            param_attr=param_attr,
            bias_attr=True,
            name=name + '_conv')
        return fluid.layers.resize_bilinear(
            score, out_shape=out_shape, name=name + '_interp')
def psp_module(input, out_features):
    """Pyramid Scene Parsing pooling module.

    Pools the backbone feature map at several bin sizes; each pooled map
    goes through a 1x1 conv + BN(relu) and is bilinearly upsampled back
    to the input's spatial size. The branches are concatenated with the
    raw input (input first, then branches from the largest bin down),
    and the result is fused by a 3x3 conv + BN(relu).

    Args:
        input: backbone feature map.
        out_features: channel count for each branch and for the fused output.

    Returns:
        Fused pyramid-pooling feature map.
    """
    bin_sizes = (1, 2, 3, 6)
    branch_outputs = []
    for bin_size in bin_sizes:
        branch_name = "psp" + str(bin_size)
        with scope(branch_name):
            pooled = fluid.layers.adaptive_pool2d(
                input,
                pool_size=[bin_size, bin_size],
                pool_type='avg',
                name=branch_name + '_adapool')
            reduced = conv(
                pooled,
                out_features,
                filter_size=1,
                bias_attr=True,
                name=branch_name + '_conv')
            normalized = bn(reduced, act='relu')
            upsampled = fluid.layers.resize_bilinear(
                normalized,
                out_shape=input.shape[2:],
                name=branch_name + '_interp')
            branch_outputs.append(upsampled)

    # Input first, then branches in reverse creation order (bin 6 -> 1).
    concatenated = fluid.layers.concat(
        [input] + branch_outputs[::-1], axis=1, name='psp_cat')

    fuse_name = "psp_end"
    with scope(fuse_name):
        fused = conv(
            concatenated,
            out_features,
            filter_size=3,
            padding=1,
            bias_attr=True,
            name=fuse_name)
        result = bn(fused, act='relu')
    return result
def resnet(input):
    """Build the PSPNet backbone: a ResNet with the 'pspnet' stem.

    Depth multiplier and layer count come from cfg.MODEL.PSPNET
    (ResNet-50 by default, per the config defaults).

    Args:
        input: the network input image tensor.

    Returns:
        The backbone feature map (auxiliary output is discarded).
    """
    depth_multiplier = cfg.MODEL.PSPNET.DEPTH_MULTIPLIER
    num_layers = cfg.MODEL.PSPNET.LAYERS
    backbone = resnet_backbone(num_layers, depth_multiplier, stem='pspnet')
    # end_points: last ResNet stage to run; dilation_dict maps block
    # index -> dilation rate for the dilated (atrous) stages.
    features, _ = backbone.net(
        input,
        end_points=num_layers - 1,
        dilation_dict={2: 2, 3: 4})
    return features
def pspnet(input, num_classes):
    """PSPNet forward graph: backbone -> pyramid pooling -> dropout -> logits.

    Args:
        input: input image tensor.
        num_classes: number of segmentation classes.

    Returns:
        Per-class logits upsampled to the input's spatial size.
    """
    backbone_features = resnet(input)
    pyramid_features = psp_module(backbone_features, 512)
    regularized = fluid.layers.dropout(
        pyramid_features, dropout_prob=0.1, name="dropout")
    return get_logit_interp(regularized, num_classes, input.shape[2:])
......@@ -196,6 +196,12 @@ cfg.MODEL.ICNET.DEPTH_MULTIPLIER = 0.5
# RESNET 层数 设置
cfg.MODEL.ICNET.LAYERS = 50
########################## PSPNET模型配置 ######################################
# RESNET backbone scale 设置
cfg.MODEL.PSPNET.DEPTH_MULTIPLIER = 1
# RESNET 层数 设置 50或101
cfg.MODEL.PSPNET.LAYERS = 50
########################## 预测部署模型配置 ###################################
# 预测保存的模型名称
cfg.FREEZE.MODEL_FILENAME = '__model__'
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册