config.py 8.6 KB
Newer Older
W
wuzewu 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24
# -*- coding: utf-8 -*-
#   Copyright (c) 2019  PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function
from __future__ import unicode_literals
from utils.collect import SegConfig
import numpy as np

cfg = SegConfig()

########################## 基本配置 ###########################################
# 均值,图像预处理减去的均值
W
wuzewu 已提交
25
cfg.MEAN = [0.5, 0.5, 0.5]
W
wuzewu 已提交
26
# 标准差,图像预处理除以标准差·
W
wuzewu 已提交
27
cfg.STD = [0.5, 0.5, 0.5]
W
wuzewu 已提交
28 29 30 31 32 33
# 批处理大小
cfg.BATCH_SIZE = 1
# 验证时图像裁剪尺寸(宽,高)
cfg.EVAL_CROP_SIZE = tuple()
# 训练时图像裁剪尺寸(宽,高)
cfg.TRAIN_CROP_SIZE = tuple()
34 35 36 37
# 多进程训练总进程数
cfg.NUM_TRAINERS = 1
# 多进程训练进程ID
cfg.TRAINER_ID = 0
W
wuzewu 已提交
38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70
########################## 数据载入配置 #######################################
# 数据载入时的并发数, 建议值8
cfg.DATALOADER.NUM_WORKERS = 8
# 数据载入时缓存队列大小, 建议值256
cfg.DATALOADER.BUF_SIZE = 256

########################## 数据集配置 #########################################
# 数据主目录目录
cfg.DATASET.DATA_DIR = './dataset/cityscapes/'
# 训练集列表
cfg.DATASET.TRAIN_FILE_LIST = './dataset/cityscapes/train.list'
# 训练集数量
cfg.DATASET.TRAIN_TOTAL_IMAGES = 2975
# 验证集列表
cfg.DATASET.VAL_FILE_LIST = './dataset/cityscapes/val.list'
# 验证数据数量
cfg.DATASET.VAL_TOTAL_IMAGES = 500
# 测试数据列表
cfg.DATASET.TEST_FILE_LIST = './dataset/cityscapes/test.list'
# 测试数据数量
cfg.DATASET.TEST_TOTAL_IMAGES = 500
# Tensorboard 可视化的数据集
cfg.DATASET.VIS_FILE_LIST = None
# 类别数(需包括背景类)
cfg.DATASET.NUM_CLASSES = 19
# 输入图像类型, 支持三通道'rgb',四通道'rgba',单通道灰度图'gray'
cfg.DATASET.IMAGE_TYPE = 'rgb'
# 输入图片的通道数
cfg.DATASET.DATA_DIM = 3
# 数据列表分割符, 默认为空格
cfg.DATASET.SEPARATOR = ' '
# 忽略的像素标签值, 默认为255,一般无需改动
cfg.DATASET.IGNORE_INDEX = 255
W
wuzewu 已提交
71 72
# 数据增强是图像的padding值
cfg.DATASET.PADDING_VALUE = [127.5, 127.5, 127.5]
W
wuzewu 已提交
73 74 75 76

########################### 数据增强配置 ######################################
# 图像resize的方式有三种:
# unpadding(固定尺寸),stepscaling(按比例resize),rangescaling(长边对齐)
77 78 79
cfg.AUG.AUG_METHOD = 'unpadding'
# 图像resize的固定尺寸(宽,高),非负
cfg.AUG.FIX_RESIZE_SIZE = (512, 512)
W
wuzewu 已提交
80 81 82 83 84 85 86 87 88 89 90 91 92 93 94
# 图像resize方式为stepscaling,resize最小尺度,非负
cfg.AUG.MIN_SCALE_FACTOR = 0.5
# 图像resize方式为stepscaling,resize最大尺度,不小于MIN_SCALE_FACTOR
cfg.AUG.MAX_SCALE_FACTOR = 2.0
# 图像resize方式为stepscaling,resize尺度范围间隔,非负
cfg.AUG.SCALE_STEP_SIZE = 0.25
# 图像resize方式为rangescaling,训练时长边resize的范围最小值,非负
cfg.AUG.MIN_RESIZE_VALUE = 400
# 图像resize方式为rangescaling,训练时长边resize的范围最大值,
# 不小于MIN_RESIZE_VALUE
cfg.AUG.MAX_RESIZE_VALUE = 600
# 图像resize方式为rangescaling, 测试验证可视化模式下长边resize的长度,
# 在MIN_RESIZE_VALUE到MAX_RESIZE_VALUE范围内
cfg.AUG.INF_RESIZE_VALUE = 500

95 96 97 98 99 100 101
# 图像镜像左右翻转
cfg.AUG.MIRROR = True
# 图像上下翻转开关,True/False
cfg.AUG.FLIP = False
# 图像启动上下翻转的概率,0-1
cfg.AUG.FLIP_RATIO = 0.5

W
wuzewu 已提交
102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124
# RichCrop数据增广开关,用于提升模型鲁棒性
cfg.AUG.RICH_CROP.ENABLE = False
# 图像旋转最大角度,0-90
cfg.AUG.RICH_CROP.MAX_ROTATION = 15
# 裁取图像与原始图像面积比,0-1
cfg.AUG.RICH_CROP.MIN_AREA_RATIO = 0.5
# 裁取图像宽高比范围,非负
cfg.AUG.RICH_CROP.ASPECT_RATIO = 0.33
# 亮度调节范围,0-1
cfg.AUG.RICH_CROP.BRIGHTNESS_JITTER_RATIO = 0.5
# 饱和度调节范围,0-1
cfg.AUG.RICH_CROP.SATURATION_JITTER_RATIO = 0.5
# 对比度调节范围,0-1
cfg.AUG.RICH_CROP.CONTRAST_JITTER_RATIO = 0.5
# 图像模糊开关,True/False
cfg.AUG.RICH_CROP.BLUR = False
# 图像启动模糊百分比,0-1
cfg.AUG.RICH_CROP.BLUR_RATIO = 0.1

########################### 训练配置 ##########################################
# 模型保存路径
cfg.TRAIN.MODEL_SAVE_DIR = ''
# 预训练模型路径
W
wuzewu 已提交
125
cfg.TRAIN.PRETRAINED_MODEL_DIR = ''
W
wuzewu 已提交
126
# 是否resume,继续训练
W
wuzewu 已提交
127
cfg.TRAIN.RESUME_MODEL_DIR = ''
W
wuzewu 已提交
128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155
# 是否使用多卡间同步BatchNorm均值和方差
cfg.TRAIN.SYNC_BATCH_NORM = False
# 模型参数保存的epoch间隔数,可用来继续训练中断的模型
cfg.TRAIN.SNAPSHOT_EPOCH = 10

########################### 模型优化相关配置 ##################################
# 初始学习率
cfg.SOLVER.LR = 0.1
# 学习率下降方法, 支持poly piecewise cosine 三种
cfg.SOLVER.LR_POLICY = "poly"
# 优化算法, 支持SGD和Adam两种算法
cfg.SOLVER.OPTIMIZER = "sgd"
# 动量参数
cfg.SOLVER.MOMENTUM = 0.9
# 二阶矩估计的指数衰减率
cfg.SOLVER.MOMENTUM2 = 0.999
# 学习率Poly下降指数
cfg.SOLVER.POWER = 0.9
# step下降指数
cfg.SOLVER.GAMMA = 0.1
# step下降间隔
cfg.SOLVER.DECAY_EPOCH = [10, 20]
# 学习率权重衰减,0-1
cfg.SOLVER.WEIGHT_DECAY = 0.00004
# 训练开始epoch数,默认为1
cfg.SOLVER.BEGIN_EPOCH = 1
# 训练epoch数,正整数
cfg.SOLVER.NUM_EPOCHS = 30
W
wuyefeilin 已提交
156 157
# loss的选择,支持softmax_loss, bce_loss, dice_loss
cfg.SOLVER.LOSS = ["softmax_loss"]
F
fuyi02 已提交
158 159 160 161
# 是否开启warmup学习策略 
cfg.SOLVER.LR_WARMUP = False 
# warmup的迭代次数
cfg.SOLVER.LR_WARMUP_STEPS = 2000 
W
wuzewu 已提交
162 163 164 165 166 167

########################## 测试配置 ###########################################
# 测试模型路径
cfg.TEST.TEST_MODEL = ''

########################## 模型通用配置 #######################################
168
# 模型名称, 已支持deeplabv3p, unet, icnet,pspnet,hrnet
W
wuzewu 已提交
169 170 171 172 173 174 175 176 177 178 179 180 181
cfg.MODEL.MODEL_NAME = ''
# BatchNorm类型: bn、gn(group_norm)
cfg.MODEL.DEFAULT_NORM_TYPE = 'bn'
# 多路损失加权值
cfg.MODEL.MULTI_LOSS_WEIGHT = [1.0]
# DEFAULT_NORM_TYPE为gn时group数
cfg.MODEL.DEFAULT_GROUP_NUMBER = 32
# 极小值, 防止分母除0溢出,一般无需改动
cfg.MODEL.DEFAULT_EPSILON = 1e-5
# BatchNorm动量, 一般无需改动
cfg.MODEL.BN_MOMENTUM = 0.99
# 是否使用FP16训练
cfg.MODEL.FP16 = False
182 183
# 混合精度训练需对LOSS进行scale, 默认为动态scale,静态scale可以设置为512.0
cfg.MODEL.SCALE_LOSS = "DYNAMIC"
W
wuzewu 已提交
184 185 186 187 188 189

########################## DeepLab模型配置 ####################################
# DeepLab backbone 配置, 可选项xception_65, mobilenetv2
cfg.MODEL.DEEPLAB.BACKBONE = "xception_65"
# DeepLab output stride
cfg.MODEL.DEEPLAB.OUTPUT_STRIDE = 16
W
wuzewu 已提交
190
# MobileNet v2 backbone scale 设置
W
wuzewu 已提交
191
cfg.MODEL.DEEPLAB.DEPTH_MULTIPLIER = 1.0
W
wuzewu 已提交
192
# MobileNet v2 backbone scale 设置
W
wuzewu 已提交
193
cfg.MODEL.DEEPLAB.ENCODER_WITH_ASPP = True
W
wuzewu 已提交
194
# MobileNet v2 backbone scale 设置
W
wuzewu 已提交
195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210
cfg.MODEL.DEEPLAB.ENABLE_DECODER = True
# ASPP是否使用可分离卷积
cfg.MODEL.DEEPLAB.ASPP_WITH_SEP_CONV = True
# 解码器是否使用可分离卷积
cfg.MODEL.DEEPLAB.DECODER_USE_SEP_CONV = True

########################## UNET模型配置 #######################################
# 上采样方式, 默认为双线性插值
cfg.MODEL.UNET.UPSAMPLE_MODE = 'bilinear'

########################## ICNET模型配置 ######################################
# RESNET backbone scale 设置
cfg.MODEL.ICNET.DEPTH_MULTIPLIER = 0.5
# RESNET 层数 设置
cfg.MODEL.ICNET.LAYERS = 50

P
pengmian 已提交
211 212 213
########################## PSPNET模型配置 ######################################
# RESNET backbone scale 设置
cfg.MODEL.PSPNET.DEPTH_MULTIPLIER = 1
P
pennypm 已提交
214
# RESNET backbone 层数 设置
P
pengmian 已提交
215 216
cfg.MODEL.PSPNET.LAYERS = 50

W
wuyefeilin 已提交
217 218 219 220 221 222 223 224 225 226 227 228
########################## HRNET模型配置 ######################################
# HRNET STAGE2 设置
cfg.MODEL.HRNET.STAGE2.NUM_MODULES = 1
cfg.MODEL.HRNET.STAGE2.NUM_CHANNELS = [40, 80]
# HRNET STAGE3 设置
cfg.MODEL.HRNET.STAGE3.NUM_MODULES = 4
cfg.MODEL.HRNET.STAGE3.NUM_CHANNELS = [40, 80, 160]
# HRNET STAGE4 设置
cfg.MODEL.HRNET.STAGE4.NUM_MODULES = 3
cfg.MODEL.HRNET.STAGE4.NUM_CHANNELS = [40, 80, 160, 320]


W
wuzewu 已提交
229 230 231 232 233 234 235
########################## 预测部署模型配置 ###################################
# 预测保存的模型名称
cfg.FREEZE.MODEL_FILENAME = '__model__'
# 预测保存的参数名称
cfg.FREEZE.PARAMS_FILENAME = '__params__'
# 预测模型参数保存的路径
cfg.FREEZE.SAVE_DIR = 'freeze_model'