未验证 提交 60960064 编写于 作者: W wuzewu 提交者: GitHub

Merge pull request #315 from tianlanshidai/develop

EVAL_CROP_SIZE: (2048, 1024) # (width, height), for unpadding rangescaling and stepscaling
TRAIN_CROP_SIZE: (1024, 512) # (width, height), for unpadding rangescaling and stepscaling
AUG:
# AUG_METHOD: "unpadding" # choice unpadding rangescaling and stepscaling
AUG_METHOD: "stepscaling" # choice unpadding rangescaling and stepscaling
FIX_RESIZE_SIZE: (1024, 512) # (width, height), for unpadding
INF_RESIZE_VALUE: 500 # for rangescaling
MAX_RESIZE_VALUE: 600 # for rangescaling
MIN_RESIZE_VALUE: 400 # for rangescaling
MAX_SCALE_FACTOR: 2.0 # for stepscaling
MIN_SCALE_FACTOR: 0.5 # for stepscaling
SCALE_STEP_SIZE: 0.25 # for stepscaling
MIRROR: True
BATCH_SIZE: 4
#BATCH_SIZE: 4
DATASET:
DATA_DIR: "./dataset/cityscapes/"
IMAGE_TYPE: "rgb" # choice rgb or rgba
NUM_CLASSES: 19
TEST_FILE_LIST: "./dataset/cityscapes/val.list"
TRAIN_FILE_LIST: "./dataset/cityscapes/train.list"
VAL_FILE_LIST: "./dataset/cityscapes/val.list"
VIS_FILE_LIST: "./dataset/cityscapes/val.list"
IGNORE_INDEX: 255
SEPARATOR: " "
FREEZE:
MODEL_FILENAME: "model"
PARAMS_FILENAME: "params"
MODEL:
MODEL_NAME: "ocnet"
DEFAULT_NORM_TYPE: "bn"
HRNET:
STAGE2:
NUM_CHANNELS: [18, 36]
STAGE3:
NUM_CHANNELS: [18, 36, 72]
STAGE4:
NUM_CHANNELS: [18, 36, 72, 144]
OCR:
OCR_MID_CHANNELS: 512
OCR_KEY_CHANNELS: 256
MULTI_LOSS_WEIGHT: [1.0, 1.0]
TRAIN:
PRETRAINED_MODEL_DIR: u"./pretrained_model/ocnet_w18_cityscape/best_model"
MODEL_SAVE_DIR: "output/ocnet_w18_bn_cityscapes"
SNAPSHOT_EPOCH: 1
SYNC_BATCH_NORM: True
TEST:
TEST_MODEL: "output/ocnet_w18_bn_cityscapes/first"
SOLVER:
LR: 0.01
LR_POLICY: "poly"
OPTIMIZER: "sgd"
NUM_EPOCHS: 500
......@@ -66,5 +66,6 @@ train数据集合为Cityscapes训练集合,测试为Cityscapes的验证集合
| PSPNet/bn | Cityscapes |[pspnet101_cityscapes.tgz](https://paddleseg.bj.bcebos.com/models/pspnet101_cityscapes.tgz) |16|false| 0.7734 |
| HRNet_W18/bn | Cityscapes |[hrnet_w18_bn_cityscapes.tgz](https://paddleseg.bj.bcebos.com/models/hrnet_w18_bn_cityscapes.tgz) | 4 | false | 0.7936 |
| Fast-SCNN/bn | Cityscapes |[fast_scnn_cityscapes.tar](https://paddleseg.bj.bcebos.com/models/fast_scnn_cityscape.tar) | 32 | false | 0.6964 |
| OCNet/bn | Cityscapes |[ocnet_w18_bn_cityscapes.tar.gz](https://paddleseg.bj.bcebos.com/models/ocnet_w18_bn_cityscapes.tar.gz) | 4 | false | 0.8023 |
测试环境为python 3.7.3,v100,cudnn 7.6.2。
......@@ -26,7 +26,7 @@ from loss import multi_dice_loss
from loss import multi_bce_loss
from lovasz_losses import lovasz_hinge
from lovasz_losses import lovasz_softmax
from models.modeling import deeplab, unet, icnet, pspnet, hrnet, fast_scnn
from models.modeling import deeplab, unet, icnet, pspnet, hrnet, fast_scnn,ocnet
class ModelPhase(object):
......@@ -85,6 +85,8 @@ def seg_model(image, class_num):
logits = hrnet.hrnet(image, class_num)
elif model_name == 'fast_scnn':
logits = fast_scnn.fast_scnn(image, class_num)
elif model_name == 'ocnet':
logits = ocnet.ocnet(image, class_num)
else:
raise Exception(
"unknow model name, only support unet, deeplabv3p, icnet, pspnet, hrnet, fast_scnn"
......
# coding: utf8
# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import paddle
import paddle.fluid as fluid
from paddle.fluid.initializer import MSRA
from paddle.fluid.param_attr import ParamAttr
from utils.config import cfg
def conv_bn_layer(input, filter_size, num_filters, stride=1, padding=1, num_groups=1, if_act=True, name=None):
conv = fluid.layers.conv2d(
input=input,
num_filters=num_filters,
filter_size=filter_size,
stride=stride,
padding=(filter_size - 1) // 2,
groups=num_groups,
act=None,
# param_attr=ParamAttr(initializer=MSRA(), learning_rate=1.0, name=name + '_weights'),
param_attr=ParamAttr(initializer=fluid.initializer.Normal(scale=0.001), learning_rate=1.0, name=name + '_weights'),
bias_attr=False)
bn_name = name + '_bn'
bn = fluid.layers.batch_norm(input=conv,
param_attr=ParamAttr(name=bn_name + "_scale",
initializer=fluid.initializer.Constant(1.0)),
bias_attr=ParamAttr(name=bn_name + "_offset",
initializer=fluid.initializer.Constant(0.0)),
moving_mean_name=bn_name + '_mean',
moving_variance_name=bn_name + '_variance')
if if_act:
bn = fluid.layers.relu(bn)
return bn
def basic_block(input, num_filters, stride=1, downsample=False, name=None):
residual = input
conv = conv_bn_layer(input=input, filter_size=3, num_filters=num_filters, stride=stride, name=name + '_conv1')
conv = conv_bn_layer(input=conv, filter_size=3, num_filters=num_filters, if_act=False, name=name + '_conv2')
if downsample:
residual = conv_bn_layer(input=input, filter_size=1, num_filters=num_filters, if_act=False,
name=name + '_downsample')
return fluid.layers.elementwise_add(x=residual, y=conv, act='relu')
def bottleneck_block(input, num_filters, stride=1, downsample=False, name=None):
residual = input
conv = conv_bn_layer(input=input, filter_size=1, num_filters=num_filters, name=name + '_conv1')
conv = conv_bn_layer(input=conv, filter_size=3, num_filters=num_filters, stride=stride, name=name + '_conv2')
conv = conv_bn_layer(input=conv, filter_size=1, num_filters=num_filters * 4, if_act=False,
name=name + '_conv3')
if downsample:
residual = conv_bn_layer(input=input, filter_size=1, num_filters=num_filters * 4, if_act=False,
name=name + '_downsample')
return fluid.layers.elementwise_add(x=residual, y=conv, act='relu')
def fuse_layers(x, channels, multi_scale_output=True, name=None):
out = []
for i in range(len(channels) if multi_scale_output else 1):
residual = x[i]
shape = residual.shape
width = shape[-1]
height = shape[-2]
for j in range(len(channels)):
if j > i:
y = conv_bn_layer(x[j], filter_size=1, num_filters=channels[i], if_act=False,
name=name + '_layer_' + str(i + 1) + '_' + str(j + 1))
y = fluid.layers.resize_bilinear(input=y, out_shape=[height, width])
residual = fluid.layers.elementwise_add(x=residual, y=y, act=None)
elif j < i:
y = x[j]
for k in range(i - j):
if k == i - j - 1:
y = conv_bn_layer(y, filter_size=3, num_filters=channels[i], stride=2, if_act=False,
name=name + '_layer_' + str(i + 1) + '_' + str(j + 1) + '_' + str(k + 1))
else:
y = conv_bn_layer(y, filter_size=3, num_filters=channels[j], stride=2,
name=name + '_layer_' + str(i + 1) + '_' + str(j + 1) + '_' + str(k + 1))
residual = fluid.layers.elementwise_add(x=residual, y=y, act=None)
residual = fluid.layers.relu(residual)
out.append(residual)
return out
def branches(x, block_num, channels, name=None):
out = []
for i in range(len(channels)):
residual = x[i]
for j in range(block_num):
residual = basic_block(residual, channels[i],
name=name + '_branch_layer_' + str(i + 1) + '_' + str(j + 1))
out.append(residual)
return out
def high_resolution_module(x, channels, multi_scale_output=True, name=None):
residual = branches(x, 4, channels, name=name)
out = fuse_layers(residual, channels, multi_scale_output=multi_scale_output, name=name)
return out
def transition_layer(x, in_channels, out_channels, name=None):
num_in = len(in_channels)
num_out = len(out_channels)
out = []
for i in range(num_out):
if i < num_in:
if in_channels[i] != out_channels[i]:
residual = conv_bn_layer(x[i], filter_size=3, num_filters=out_channels[i],
name=name + '_layer_' + str(i + 1))
out.append(residual)
else:
out.append(x[i])
else:
residual = conv_bn_layer(x[-1], filter_size=3, num_filters=out_channels[i], stride=2,
name=name + '_layer_' + str(i + 1))
out.append(residual)
return out
def stage(x, num_modules, channels, multi_scale_output=True, name=None):
out = x
for i in range(num_modules):
if i == num_modules - 1 and multi_scale_output == False:
out = high_resolution_module(out, channels, multi_scale_output=False, name=name + '_' + str(i + 1))
else:
out = high_resolution_module(out, channels, name=name + '_' + str(i + 1))
return out
def layer1(input, name=None):
conv = input
for i in range(4):
conv = bottleneck_block(conv, num_filters=64, downsample=True if i == 0 else False,
name=name + '_' + str(i + 1))
return conv
def aux_head(input, last_inp_channels, num_classes):
x = conv_bn_layer(input=input, filter_size=1, num_filters=last_inp_channels, stride=1, padding=0, name='aux_head_conv1')
x = fluid.layers.conv2d(
input=x,
num_filters=num_classes,
filter_size=1,
stride=1,
padding=0,
act=None,
# param_attr=ParamAttr(initializer=MSRA(), learning_rate=1.0, name='aux_head_conv2_weights'),
param_attr=ParamAttr(initializer=fluid.initializer.Normal(scale=0.001), learning_rate=1.0, name='aux_head_conv2_weights'),
bias_attr=ParamAttr(initializer=fluid.initializer.Constant(0.0), name="aux_head_conv2_bias")
)
return x
def conv3x3_ocr(input, ocr_mid_channels):
x = conv_bn_layer(input=input, filter_size=3, num_filters=ocr_mid_channels, stride=1, padding=1, name='conv3x3_ocr')
return x
def f_pixel(input, key_channels):
x = conv_bn_layer(input=input, filter_size=1, num_filters=key_channels, stride=1, padding=0, name='f_pixel_conv1')
x = conv_bn_layer(input=x, filter_size=1, num_filters=key_channels, stride=1, padding=0, name='f_pixel_conv2')
return x
def f_object(input, key_channels):
x = conv_bn_layer(input=input, filter_size=1, num_filters=key_channels, stride=1, padding=0, name='f_object_conv1')
x = conv_bn_layer(input=x, filter_size=1, num_filters=key_channels, stride=1, padding=0, name='f_object_conv2')
return x
def f_down(input, key_channels):
x = conv_bn_layer(input=input, filter_size=1, num_filters=key_channels, stride=1, padding=0, name='f_down_conv')
return x
def f_up(input, in_channels):
x = conv_bn_layer(input=input, filter_size=1, num_filters=in_channels, stride=1, padding=0, name='f_up_conv')
return x
def object_context_block(x, proxy, in_channels, key_channels, scale):
batch_size, _, h, w = x.shape
if scale > 1:
x = fluid.layers.pool2d(x, pool_size=[scale, scale], pool_type='max')
query = f_pixel(x, key_channels)
query = fluid.layers.reshape(query, shape=[batch_size, key_channels, query.shape[2]*query.shape[3]])
query = fluid.layers.transpose(query, perm=[0, 2, 1])
key = f_object(proxy, key_channels)
key = fluid.layers.reshape(key, shape=[batch_size, key_channels, key.shape[2]*key.shape[3]])
value = f_down(proxy, key_channels)
value = fluid.layers.reshape(value, shape=[batch_size, key_channels, value.shape[2]*value.shape[3]])
value = fluid.layers.transpose(value, perm=[0, 2, 1])
sim_map = fluid.layers.matmul(query, key)
sim_map = (key_channels**-.5) * sim_map
sim_map = fluid.layers.softmax(sim_map, axis=-1)
context = fluid.layers.matmul(sim_map, value)
context = fluid.layers.transpose(context, perm=[0, 2, 1])
context = fluid.layers.reshape(context, shape=[batch_size, key_channels, x.shape[2], x.shape[3]])
context = f_up(context, in_channels)
if scale > 1:
context = fluid.layers.resize_bilinear(context, out_shape=[h, w])
return context
def ocr_gather_head(feats, probs, scale=1):
feats = fluid.layers.reshape(feats, shape=[feats.shape[0], feats.shape[1], feats.shape[2]*feats.shape[3]])
feats = fluid.layers.transpose(feats, perm=[0, 2, 1])
probs = fluid.layers.reshape(probs, shape=[probs.shape[0], probs.shape[1], probs.shape[2]*probs.shape[3]])
probs = fluid.layers.softmax(scale * probs, axis=2)
ocr_context = fluid.layers.matmul(probs, feats)
ocr_context = fluid.layers.transpose(ocr_context, perm=[0, 2, 1])
ocr_context = fluid.layers.unsqueeze(ocr_context, axes=[3])
return ocr_context
def ocr_distri_head(feats, proxy_feats, ocr_mid_channels, ocr_key_channels, scale=1, dropout=0.05):
context = object_context_block(feats, proxy_feats, ocr_mid_channels, ocr_key_channels, scale)
x = fluid.layers.concat([context, feats], axis=1)
x = conv_bn_layer(input=x, filter_size=1, num_filters=ocr_mid_channels, stride=1, padding=0, name='spatial_ocr_conv')
x = fluid.layers.dropout(x, dropout_prob=dropout)
return x
def cls_head(input, num_classes):
x = fluid.layers.conv2d(
input=input,
num_filters=num_classes,
filter_size=1,
stride=1,
padding=0,
act=None,
# param_attr=ParamAttr(initializer=MSRA(), learning_rate=1.0, name='cls_head_conv_weights'),
param_attr=ParamAttr(initializer=fluid.initializer.Normal(scale=0.001), learning_rate=1.0, name='cls_head_conv_weights'),
bias_attr=ParamAttr(initializer=fluid.initializer.Constant(0.0), name="cls_head_conv_bias")
)
return x
def ocr_module(input, last_inp_channels, num_classes, ocr_mid_channels, ocr_key_channels):
out_aux = aux_head(input, last_inp_channels, num_classes)
feats = conv3x3_ocr(input, ocr_mid_channels)
context = ocr_gather_head(feats, out_aux)
feats = ocr_distri_head(feats, context, ocr_mid_channels, ocr_key_channels)
out = cls_head(feats, num_classes)
return out, out_aux
def high_resolution_ocr_net(input, num_classes):
channels_2 = cfg.MODEL.HRNET.STAGE2.NUM_CHANNELS
channels_3 = cfg.MODEL.HRNET.STAGE3.NUM_CHANNELS
channels_4 = cfg.MODEL.HRNET.STAGE4.NUM_CHANNELS
num_modules_2 = cfg.MODEL.HRNET.STAGE2.NUM_MODULES
num_modules_3 = cfg.MODEL.HRNET.STAGE3.NUM_MODULES
num_modules_4 = cfg.MODEL.HRNET.STAGE4.NUM_MODULES
ocr_mid_channels = cfg.MODEL.OCR.OCR_MID_CHANNELS
ocr_key_channels = cfg.MODEL.OCR.OCR_KEY_CHANNELS
last_inp_channels = sum(channels_4)
x = conv_bn_layer(input=input, filter_size=3, num_filters=64, stride=2, if_act=True, name='layer1_1')
x = conv_bn_layer(input=x, filter_size=3, num_filters=64, stride=2, if_act=True, name='layer1_2')
la1 = layer1(x, name='layer2')
tr1 = transition_layer([la1], [256], channels_2, name='tr1')
st2 = stage(tr1, num_modules_2, channels_2, name='st2')
tr2 = transition_layer(st2, channels_2, channels_3, name='tr2')
st3 = stage(tr2, num_modules_3, channels_3, name='st3')
tr3 = transition_layer(st3, channels_3, channels_4, name='tr3')
st4 = stage(tr3, num_modules_4, channels_4, name='st4')
# upsample
shape = st4[0].shape
height, width = shape[-2], shape[-1]
st4[1] = fluid.layers.resize_bilinear(
st4[1], out_shape=[height, width])
st4[2] = fluid.layers.resize_bilinear(
st4[2], out_shape=[height, width])
st4[3] = fluid.layers.resize_bilinear(
st4[3], out_shape=[height, width])
feats = fluid.layers.concat(st4, axis=1)
out, out_aux = ocr_module(feats, last_inp_channels, num_classes, ocr_mid_channels, ocr_key_channels)
out = fluid.layers.resize_bilinear(out, input.shape[2:])
out_aux = fluid.layers.resize_bilinear(out_aux, input.shape[2:])
return out, out_aux
def ocnet(input, num_classes):
logit = high_resolution_ocr_net(input, num_classes)
return logit
if __name__ == '__main__':
image_shape = [-1, 3, 769, 769]
image = fluid.data(name='image', shape=image_shape, dtype='float32')
logit = ocnet(image, 4)
print("logit:", logit.shape)
......@@ -248,7 +248,10 @@ cfg.MODEL.HRNET.STAGE3.NUM_CHANNELS = [40, 80, 160]
# HRNET STAGE4 设置
cfg.MODEL.HRNET.STAGE4.NUM_MODULES = 3
cfg.MODEL.HRNET.STAGE4.NUM_CHANNELS = [40, 80, 160, 320]
########################## OCNET模型配置 ######################################
cfg.MODEL.OCR.OCR_MID_CHANNELS = 512
cfg.MODEL.OCR.OCR_KEY_CHANNELS = 256
########################## 预测部署模型配置 ###################################
# 预测保存的模型名称
cfg.FREEZE.MODEL_FILENAME = '__model__'
......
......@@ -92,6 +92,8 @@ model_urls = {
"https://paddleseg.bj.bcebos.com/models/hrnet_w18_bn_cityscapes.tgz",
"fast_scnn_cityscapes":
"https://paddleseg.bj.bcebos.com/models/fast_scnn_cityscape.tar",
"ocnet_w18_bn_cityscapes":
"https://paddleseg.bj.bcebos.com/models/ocnet_w18_bn_cityscapes.tar.gz",
}
if __name__ == "__main__":
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册