Unverified commit fcb854f7, authored by LutaoChu, committed by GitHub

Optimize lovasz loss implementation, support multiple loss weighted combination (#215)

* modify label tool

* add lovasz loss

* lovasz

* model builder

* test

* lovasz

* update lovasz

* update lovasz

* add sh yaml

* add sh yaml

* add sh yaml

* no per_image and device_guard

* no per_image and device_guard

* Update lovasz_loss.md

* Update lovasz_loss.md

* Update lovasz_loss.md

* Update lovasz_loss.md

* Update lovasz_loss.md

* elementwise_mul replace matmul
Parent 1dbbce6c
EVAL_CROP_SIZE: (1025, 1025) # (width, height), for unpadding rangescaling and stepscaling
TRAIN_CROP_SIZE: (769, 769) # (width, height), for unpadding rangescaling and stepscaling
AUG:
    AUG_METHOD: u"stepscaling" # choice unpadding rangescaling and stepscaling
    FIX_RESIZE_SIZE: (640, 640) # (width, height), for unpadding
    INF_RESIZE_VALUE: 500 # for rangescaling
    MAX_RESIZE_VALUE: 600 # for rangescaling
    MIN_RESIZE_VALUE: 400 # for rangescaling
    MAX_SCALE_FACTOR: 2.0 # for stepscaling
    MIN_SCALE_FACTOR: 0.5 # for stepscaling
    SCALE_STEP_SIZE: 0.25 # for stepscaling
    FLIP: True
BATCH_SIZE: 24
DATASET:
    DATA_DIR: "./dataset/MiniDeepGlobeRoadExtraction/"
    IMAGE_TYPE: "rgb" # choice rgb or rgba
    NUM_CLASSES: 2
    TEST_FILE_LIST: "dataset/MiniDeepGlobeRoadExtraction/val.txt"
    TRAIN_FILE_LIST: "dataset/MiniDeepGlobeRoadExtraction/train.txt"
    VAL_FILE_LIST: "dataset/MiniDeepGlobeRoadExtraction/val.txt"
    IGNORE_INDEX: 255
    SEPARATOR: '|'
FREEZE:
    MODEL_FILENAME: "model"
    PARAMS_FILENAME: "params"
    SAVE_DIR: "freeze_model"
MODEL:
    DEFAULT_NORM_TYPE: "bn"
    MODEL_NAME: "deeplabv3p"
    DEEPLAB:
        BACKBONE: "mobilenetv2"
        DEPTH_MULTIPLIER: 1.0
        ENCODER_WITH_ASPP: False
        ENABLE_DECODER: False
TEST:
    TEST_MODEL: "./saved_model/lovasz_hinge_deeplabv3p_mobilenet_road/final"
TRAIN:
    MODEL_SAVE_DIR: "./saved_model/lovasz_hinge_deeplabv3p_mobilenet_road/"
    PRETRAINED_MODEL_DIR: "./pretrained_model/deeplabv3p_mobilenetv2-1-0_bn_coco/"
    SNAPSHOT_EPOCH: 10
SOLVER:
    LR: 0.1
    LR_POLICY: "poly"
    OPTIMIZER: "sgd"
    NUM_EPOCHS: 300
    LOSS: ["lovasz_hinge_loss","bce_loss"]
    LOSS_WEIGHT:
        LOVASZ_HINGE_LOSS: 0.5
        BCE_LOSS: 0.5
TRAIN_CROP_SIZE: (500, 500) # (width, height), for unpadding rangescaling and stepscaling; image crop size during training
EVAL_CROP_SIZE: (500, 500) # (width, height), for unpadding rangescaling and stepscaling; image crop size during evaluation
AUG:
    AUG_METHOD: "stepscaling" # choice unpadding rangescaling and stepscaling
    FIX_RESIZE_SIZE: (500, 500) # (width, height), for unpadding
    INF_RESIZE_VALUE: 500 # for rangescaling
    MAX_RESIZE_VALUE: 600 # for rangescaling
    MIN_RESIZE_VALUE: 400 # for rangescaling
    MAX_SCALE_FACTOR: 1.25 # for stepscaling
    MIN_SCALE_FACTOR: 0.75 # for stepscaling
    SCALE_STEP_SIZE: 0.05 # for stepscaling
    MIRROR: True
    FLIP: True
BATCH_SIZE: 16 # batch size
DATASET:
    DATA_DIR: "./dataset/VOCtrainval_11-May-2012/VOC2012/" # image directory
    IMAGE_TYPE: "rgb" # choice rgb or rgba; image type
    NUM_CLASSES: 21 # number of classes (including the background class)
    TEST_FILE_LIST: "dataset/VOCtrainval_11-May-2012/VOC2012/ImageSets/Segmentation/val.list"
    TRAIN_FILE_LIST: "dataset/VOCtrainval_11-May-2012/VOC2012/ImageSets/Segmentation/train.list"
    VAL_FILE_LIST: "dataset/VOCtrainval_11-May-2012/VOC2012/ImageSets/Segmentation/val.list"
    IGNORE_INDEX: 255
    SEPARATOR: " "
MODEL:
    MODEL_NAME: "deeplabv3p"
    DEFAULT_NORM_TYPE: "bn" # norm type; bn and gn (default) are available, meaning batch norm and group norm respectively
    DEEPLAB:
        BACKBONE: "mobilenetv2"
        DEPTH_MULTIPLIER: 1.0
        ENCODER_WITH_ASPP: False
        ENABLE_DECODER: False
TRAIN:
    PRETRAINED_MODEL_DIR: "./pretrained_model/deeplabv3p_mobilenetv2-1-0_bn_coco/"
    MODEL_SAVE_DIR: "./saved_model/lovasz-softmax-voc" # model save path
    SNAPSHOT_EPOCH: 10
TEST:
    TEST_MODEL: "./saved_model/lovasz-softmax-voc/final" # path of the model to test
SOLVER:
    NUM_EPOCHS: 100 # number of training epochs, a positive integer
    LR: 0.0001 # initial learning rate
    LR_POLICY: "poly" # learning-rate decay policy; options are poly, piecewise and cosine
    OPTIMIZER: "sgd" # optimizer; options are sgd and adam
    LOSS: ["lovasz_softmax_loss","softmax_loss"]
    LOSS_WEIGHT:
        LOVASZ_SOFTMAX_LOSS: 0.2
        SOFTMAX_LOSS: 0.8
# Lovasz loss
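In image segmentation tasks, class distributions are often imbalanced, for example in defect detection for industrial products, road extraction, and lesion-region extraction.
Lovasz loss can be used to address this problem. Depending on the number of target classes, it comes in two forms: lovasz hinge loss for binary segmentation and lovasz softmax loss for multi-class segmentation.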
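
To make the idea concrete, here is a minimal NumPy sketch of the binary Lovasz hinge loss on flat arrays. It is an illustrative reference written for this document, not PaddleSeg's implementation (which builds a `paddle.fluid` graph); the function names simply mirror those in `lovasz_losses.py`, and returning `0.0` for empty input is an assumption.

```python
import numpy as np


def lovasz_grad(gt_sorted):
    """Gradient of the Lovasz extension w.r.t. sorted errors (Alg. 1 of the paper)."""
    gt_sorted = np.asarray(gt_sorted, dtype=np.float64)
    gts = gt_sorted.sum()
    # Intersection and union after taking the i pixels with the largest errors.
    intersection = gts - np.cumsum(gt_sorted)
    union = gts + np.cumsum(1.0 - gt_sorted)
    jaccard = 1.0 - intersection / union
    # Turn cumulative Jaccard losses into per-pixel increments.
    jaccard[1:] = jaccard[1:] - jaccard[:-1]
    return jaccard


def lovasz_hinge_flat(logits, labels):
    """Binary Lovasz hinge loss; logits and labels are flat, labels are 0 or 1."""
    logits = np.asarray(logits, dtype=np.float64)
    labels = np.asarray(labels, dtype=np.float64)
    if logits.size == 0:
        return 0.0  # assumption: an empty input contributes no loss
    signs = 2.0 * labels - 1.0
    errors = 1.0 - logits * signs
    # Sort hinge errors in descending order.
    perm = np.argsort(-errors, kind="stable")
    grad = lovasz_grad(labels[perm])
    return float(np.dot(np.maximum(errors[perm], 0.0), grad))
```

The `union` line follows Alg. 1 of the paper; the fluid code's `intersection + range_` computes the same quantity, because `gts - cumsum(gt) + i` equals `gts + cumsum(1 - gt)` at position `i`.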
## Lovasz hinge loss
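### Usage
PaddleSeg selects the training loss function through the `cfg.SOLVER.LOSS` parameter.
`cfg.SOLVER.LOSS=['lovasz_hinge_loss','bce_loss']` sets the training loss to a combination of `lovasz hinge loss` and `bce loss`.
Lovasz hinge loss can be used in three ways: (1) training with it directly; (2) combining it with bce loss; (3) training with bce loss first, then fine-tuning with lovasz hinge loss. The first way does not always give good results, so the latter two are recommended. This document uses the second way as its example.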
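
When several losses are listed, the `build_model` change in this commit multiplies each loss by the weight stored under its upper-cased name in `cfg.SOLVER.LOSS_WEIGHT` and sums the results. The following is a minimal plain-Python sketch of that weighting. The `combine_losses` helper and the loss values are hypothetical and for illustration only; the 0.5/0.5 weights are the ones from the road-extraction config.

```python
def combine_losses(loss_values, loss_weights):
    """Weighted sum of named losses.

    loss_values: dict mapping a loss name (e.g. 'bce_loss') to its value.
    loss_weights: dict mapping the upper-cased name (e.g. 'BCE_LOSS') to a weight.
    """
    total = 0.0
    for name, value in loss_values.items():
        # Same lookup rule as the commit: the weight key is the upper-cased loss name.
        total += loss_weights[name.upper()] * value
    return total


# Hypothetical per-batch loss values, for illustration only.
values = {"lovasz_hinge_loss": 0.8, "bce_loss": 0.4}
weights = {"LOVASZ_HINGE_LOSS": 0.5, "BCE_LOSS": 0.5}
total = combine_losses(values, weights)
```

A dictionary lookup is used here instead of the `eval` call in the commit; for the keys shown, both resolve to the same weight.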
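### Example
We apply lovasz hinge loss to a road extraction task.
In the Road Extraction track of the DeepGlobe challenge, roads account for only 4.5% of the training data. A sample image is shown below: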
<p align="center">
<img src="./imgs/deepglobe.png" hspace='10'/> <br />
</p>
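As the image shows, roads occupy only a small fraction of the whole picture.
#### Experimental comparison
The comparison was run on the MiniDeepGlobeRoadExtraction dataset.
* Dataset download
We randomly sampled 800 images from the training set of the DeepGlobe Road Extraction track as the training set and 200 images as the validation set,
producing a small road extraction dataset, [MiniDeepGlobeRoadExtraction](https://paddleseg.bj.bcebos.com/dataset/MiniDeepGlobeRoadExtraction.zip).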
```shell
python dataset/download_mini_deepglobe_road_extraction.py
```
* Pretrained model download
```shell
python pretrained_model/download_model.py deeplabv3p_mobilenetv2-1-0_bn_coco
```
* Config/data validation
```shell
python pdseg/check.py --cfg ./configs/lovasz_hinge_deeplabv3p_mobilenet_road.yaml
```
* Training
```shell
python pdseg/train.py --cfg ./configs/lovasz_hinge_deeplabv3p_mobilenet_road.yaml --use_gpu --use_mpio SOLVER.LOSS "['lovasz_hinge_loss','bce_loss']"
```
* Evaluation
```shell
python pdseg/eval.py --cfg ./configs/lovasz_hinge_deeplabv3p_mobilenet_road.yaml --use_gpu --use_mpio SOLVER.LOSS "['lovasz_hinge_loss','bce_loss']"
```
* Result comparison
The figure below compares lovasz hinge loss + bce loss against softmax loss.
<p align="center">
<img src="./imgs/lovasz-hinge.png" hspace='10'/> <br />
</p>
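In the figure, the blue curve is lovasz hinge loss + bce loss, with a best mIoU of 76.2%; the orange curve is softmax loss, with a best mIoU of 73.44%. The combination improves mIoU by 2.76 percentage points.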
## Lovasz softmax loss
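### Usage
PaddleSeg selects the training loss function through the `cfg.SOLVER.LOSS` parameter.
`cfg.SOLVER.LOSS=['lovasz_softmax_loss','softmax_loss']` sets the training loss to a combination of `lovasz softmax loss` and `softmax loss`.
Lovasz softmax loss can be used in three ways: (1) training with it directly; (2) combining it with softmax loss; (3) training with softmax loss first, then fine-tuning with lovasz softmax loss. The first way does not always give good results, so the latter two are recommended. This document uses the second way as its example.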
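
For intuition, the multi-class variant can be sketched in NumPy as a per-class Lovasz loss averaged over classes. This is an illustrative reference only, not the `paddle.fluid` implementation in this commit; it handles just the `'all'` and `'present'` options, and returning `0.0` when no class qualifies is an assumption.

```python
import numpy as np


def lovasz_grad(gt_sorted):
    """Gradient of the Lovasz extension w.r.t. sorted errors."""
    gts = gt_sorted.sum()
    intersection = gts - np.cumsum(gt_sorted)
    union = gts + np.cumsum(1.0 - gt_sorted)
    jaccard = 1.0 - intersection / union
    jaccard[1:] = jaccard[1:] - jaccard[:-1]
    return jaccard


def lovasz_softmax_flat(probas, labels, classes="present"):
    """probas: [P, C] class probabilities; labels: [P] integer labels in [0, C - 1]."""
    probas = np.asarray(probas, dtype=np.float64)
    labels = np.asarray(labels)
    losses = []
    for c in range(probas.shape[1]):
        fg = (labels == c).astype(np.float64)  # 1 where the pixel belongs to class c
        if classes == "present" and fg.sum() == 0:
            continue  # skip classes that do not occur in the labels
        errors = np.abs(fg - probas[:, c])
        perm = np.argsort(-errors, kind="stable")  # descending errors
        losses.append(float(np.dot(errors[perm], lovasz_grad(fg[perm]))))
    # Assumption: no qualifying class means zero loss.
    return float(np.mean(losses)) if losses else 0.0
```

With `classes="present"`, a class that never appears in the batch does not dilute the average, which matters when some classes are rare.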
### Example
We apply lovasz softmax loss to Pascal VOC.
#### Experimental comparison
The comparison against softmax loss was run on the Pascal VOC dataset.
* Dataset download
```shell
python dataset/download_and_convert_voc2012.py
```
* Pretrained model download
```shell
python pretrained_model/download_model.py deeplabv3p_mobilenetv2-1-0_bn_coco
```
* Config/data validation
```shell
python pdseg/check.py --cfg ./configs/lovasz_softmax_deeplabv3p_mobilenet_pascal.yaml
```
* Training
```shell
python pdseg/train.py --cfg ./configs/lovasz_softmax_deeplabv3p_mobilenet_pascal.yaml --use_gpu --use_mpio SOLVER.LOSS "['lovasz_softmax_loss','softmax_loss']"
```
* Evaluation
```shell
python pdseg/eval.py --cfg ./configs/lovasz_softmax_deeplabv3p_mobilenet_pascal.yaml --use_gpu --use_mpio SOLVER.LOSS "['lovasz_softmax_loss','softmax_loss']"
```
* Result comparison
The figure below compares lovasz softmax loss + softmax loss against softmax loss.
<p align="center">
<img src="./imgs/lovasz-softmax.png" hspace='10' /> <br />
</p>
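In the figure, the orange curve is lovasz softmax loss + softmax loss, with a best mIoU of 64.63%; the blue curve is softmax loss, with a best mIoU of 63.55%. The combination improves mIoU by 1.08 percentage points.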
# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Lovasz-Softmax and Jaccard hinge loss in PaddlePaddle"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import paddle.fluid as fluid
import numpy as np


def _cumsum(x):
    y = np.array(x)
    return np.cumsum(y, axis=0)


def create_tmp_var(name, dtype, shape):
    return fluid.default_main_program().current_block().create_var(
        name=name, dtype=dtype, shape=shape)


def lovasz_grad(gt_sorted):
    """
    Computes gradient of the Lovasz extension w.r.t sorted errors
    See Alg. 1 in paper
    """
    gt_sorted = fluid.layers.squeeze(gt_sorted, axes=[1])
    gts = fluid.layers.reduce_sum(gt_sorted)
    len_gt = fluid.layers.shape(gt_sorted)

    # Acceleration is achieved by reducing the number of calls to cumsum.
    # This calculation method is equivalent to that of the original paper.
    var_one = fluid.layers.fill_constant(shape=[1], value=1, dtype='int32')
    range_ = fluid.layers.range(1, len_gt + var_one, 1, 'int32')
    tmp_var = create_tmp_var(
        name='tmp_var', dtype=gt_sorted.dtype, shape=gt_sorted.shape)
    cumsum_ = fluid.layers.py_func(func=_cumsum, x=gt_sorted, out=tmp_var)
    intersection = gts - cumsum_
    union = intersection + range_

    jaccard = 1.0 - intersection / union
    jaccard0 = fluid.layers.slice(jaccard, axes=[0], starts=[0], ends=[1])
    jaccard1 = fluid.layers.slice(jaccard, axes=[0], starts=[1], ends=[len_gt])
    jaccard2 = fluid.layers.slice(jaccard, axes=[0], starts=[0], ends=[-1])
    jaccard = fluid.layers.concat([jaccard0, jaccard1 - jaccard2], axis=0)
    jaccard = fluid.layers.unsqueeze(jaccard, axes=[1])
    return jaccard


def lovasz_hinge(logits, labels, ignore=None):
    """
    Binary Lovasz hinge loss
      logits: [N, C, H, W] Tensor, logits at each pixel (between -\infty and +\infty)
      labels: [N, 1, H, W] Tensor, binary ground truth masks (0 or 1)
      ignore: [N, 1, H, W] Tensor. Void class labels, ignore pixels which value=0
    """
    loss = lovasz_hinge_flat(*flatten_binary_scores(logits, labels, ignore))
    return loss


def lovasz_hinge_flat(logits, labels):
    """
    Binary Lovasz hinge loss
      logits: [P] Tensor, logits at each prediction (between -\infty and +\infty)
      labels: [P] Tensor, binary ground truth labels (0 or 1)
    """
    shape = fluid.layers.shape(logits)
    y = fluid.layers.zeros_like(shape[0])

    out_var = fluid.layers.create_tensor("float32")
    with fluid.layers.control_flow.Switch() as switch:
        with switch.case(fluid.layers.equal(shape[0], y)):
            loss = fluid.layers.reduce_sum(logits) * 0.
            fluid.layers.assign(input=loss, output=out_var)
        with switch.case(fluid.layers.greater_than(shape[0], y)):
            labelsf = fluid.layers.cast(labels, logits.dtype)
            signs = labelsf * 2 - 1.
            signs.stop_gradient = True
            errors = 1.0 - fluid.layers.elementwise_mul(logits, signs)
            errors_sorted, perm = fluid.layers.argsort(
                errors, axis=0, descending=True)
            errors_sorted.stop_gradient = False
            gt_sorted = fluid.layers.gather(labelsf, perm)

            grad = lovasz_grad(gt_sorted)
            grad.stop_gradient = True
            loss = fluid.layers.reduce_sum(
                fluid.layers.relu(errors_sorted) * grad)
            fluid.layers.assign(input=loss, output=out_var)
    return out_var


def flatten_binary_scores(scores, labels, ignore=None):
    """
    Flattens predictions in the batch (binary case)
    Remove labels according to 'ignore'
    """
    scores = fluid.layers.reshape(scores, [-1, 1])
    labels = fluid.layers.reshape(labels, [-1, 1])
    labels.stop_gradient = True
    if ignore is None:
        return scores, labels
    ignore = fluid.layers.cast(ignore, 'int32')
    ignore_mask = fluid.layers.reshape(ignore, (-1, 1))
    indexs = fluid.layers.where(ignore_mask == 1)
    indexs.stop_gradient = True
    vscores = fluid.layers.gather(scores, indexs[:, 0])
    vlabels = fluid.layers.gather(labels, indexs[:, 0])
    return vscores, vlabels


def lovasz_softmax(probas, labels, classes='present', ignore=None):
    """
    Multi-class Lovasz-Softmax loss
      probas: [N, C, H, W] Tensor, class probabilities at each prediction (between 0 and 1).
      labels: [N, 1, H, W] Tensor, ground truth labels (between 0 and C - 1)
      classes: 'all' for all, 'present' for classes present in labels, or a list of classes to average.
      ignore: [N, 1, H, W] Tensor. Void class labels, ignore pixels which value=0
    """
    vprobas, vlabels = flatten_probas(probas, labels, ignore)
    loss = lovasz_softmax_flat(vprobas, vlabels, classes=classes)
    return loss


def lovasz_softmax_flat(probas, labels, classes='present'):
    """
    Multi-class Lovasz-Softmax loss
      probas: [P, C] Tensor, class probabilities at each prediction (between 0 and 1)
      labels: [P] Tensor, ground truth labels (between 0 and C - 1)
      classes: 'all' for all, 'present' for classes present in labels, or a list of classes to average.
    """
    C = probas.shape[1]
    losses = []
    present = []
    classes_to_sum = list(range(C)) if classes in ['all', 'present'
                                                   ] else classes
    for c in classes_to_sum:
        fg = fluid.layers.cast(labels == c, probas.dtype)
        fg.stop_gradient = True
        if classes == 'present':
            present.append(
                fluid.layers.cast(fluid.layers.reduce_sum(fg) > 0, "int64"))
        if C == 1:
            if len(classes_to_sum) > 1:
                raise ValueError('Sigmoid output possible only with 1 class')
            class_pred = probas[:, 0]
        else:
            class_pred = probas[:, c]
        errors = fluid.layers.abs(fg - class_pred)
        errors_sorted, perm = fluid.layers.argsort(
            errors, axis=0, descending=True)
        errors_sorted.stop_gradient = False

        fg_sorted = fluid.layers.gather(fg, perm)
        fg_sorted.stop_gradient = True

        grad = lovasz_grad(fg_sorted)
        grad.stop_gradient = True
        loss = fluid.layers.reduce_sum(errors_sorted * grad)

        losses.append(loss)

    if len(classes_to_sum) == 1:
        return losses[0]

    losses_tensor = fluid.layers.stack(losses)
    if classes == 'present':
        present_tensor = fluid.layers.stack(present)
        index = fluid.layers.where(present_tensor == 1)
        index.stop_gradient = True
        losses_tensor = fluid.layers.gather(losses_tensor, index[:, 0])
    loss = fluid.layers.mean(losses_tensor)
    return loss


def flatten_probas(probas, labels, ignore=None):
    """
    Flattens predictions in the batch
    """
    if len(probas.shape) == 3:
        # fluid.layers.unsqueeze takes `axes`, not `axis`
        probas = fluid.layers.unsqueeze(probas, axes=[1])
    C = probas.shape[1]
    probas = fluid.layers.transpose(probas, [0, 2, 3, 1])
    probas = fluid.layers.reshape(probas, [-1, C])
    labels = fluid.layers.reshape(labels, [-1, 1])
    if ignore is None:
        return probas, labels
    ignore = fluid.layers.cast(ignore, 'int32')
    ignore_mask = fluid.layers.reshape(ignore, [-1, 1])
    indexs = fluid.layers.where(ignore_mask == 1)
    indexs.stop_gradient = True
    vprobas = fluid.layers.gather(probas, indexs[:, 0])
    vlabels = fluid.layers.gather(labels, indexs[:, 0])
    return vprobas, vlabels
......@@ -24,6 +24,8 @@ from utils.config import cfg
from loss import multi_softmax_with_loss
from loss import multi_dice_loss
from loss import multi_bce_loss
from lovasz_losses import lovasz_hinge
from lovasz_losses import lovasz_softmax
from models.modeling import deeplab, unet, icnet, pspnet, hrnet, fast_scnn
......@@ -204,19 +206,22 @@ def build_model(main_prog, start_prog, phase=ModelPhase.TRAIN):
if not isinstance(loss_type, list):
loss_type = list(loss_type)
# dice_loss and bce_loss apply only to binary segmentation
if class_num > 2 and (("dice_loss" in loss_type) or
# lovasz_hinge_loss, dice_loss and bce_loss apply only to binary segmentation
if class_num > 2 and (("lovasz_hinge_loss" in loss_type) or
("dice_loss" in loss_type) or
("bce_loss" in loss_type)):
raise Exception(
"dice loss and bce loss is only applicable to binary classfication"
"lovasz hinge loss, dice loss and bce loss are only applicable to binary classfication."
)
# In binary segmentation, when dice_loss or bce_loss is selected, the final logit output has 1 channel
if ("dice_loss" in loss_type) or ("bce_loss" in loss_type):
# In binary segmentation, when lovasz_hinge_loss, dice_loss or bce_loss is selected, the final logit output has 1 channel
if ("dice_loss" in loss_type) or ("bce_loss" in loss_type) or (
"lovasz_hinge_loss" in loss_type):
class_num = 1
if "softmax_loss" in loss_type:
if ("softmax_loss" in loss_type) or (
"lovasz_softmax_loss" in loss_type):
raise Exception(
"softmax loss can not combine with dice loss or bce loss"
"softmax loss or lovasz softmax loss can not combine with bce loss or dice loss or lovasz hinge loss."
)
logits = seg_model(image, class_num)
......@@ -240,11 +245,22 @@ def build_model(main_prog, start_prog, phase=ModelPhase.TRAIN):
avg_loss_list.append(multi_bce_loss(logits, label, mask))
loss_valid = True
valid_loss.append("bce_loss")
if "lovasz_hinge_loss" in loss_type:
avg_loss_list.append(
lovasz_hinge(logits, label, ignore=mask))
loss_valid = True
valid_loss.append("lovasz_hinge_loss")
if "lovasz_softmax_loss" in loss_type:
probas = fluid.layers.softmax(logits, axis=1)
avg_loss_list.append(
lovasz_softmax(probas, label, ignore=mask))
loss_valid = True
valid_loss.append("lovasz_softmax_loss")
if not loss_valid:
raise Exception(
"SOLVER.LOSS: {} is set wrong. it should "
"include one of (softmax_loss, bce_loss, dice_loss) at least"
" example: ['softmax_loss'], ['dice_loss'], ['bce_loss', 'dice_loss']"
"include one of (softmax_loss, bce_loss, dice_loss, lovasz_hinge_loss, lovasz_softmax_loss) at least"
" example: ['softmax_loss'], ['dice_loss'], ['bce_loss', 'dice_loss'], ['lovasz_hinge_loss','bce_loss'], ['lovasz_softmax_loss','softmax_loss']"
.format(cfg.SOLVER.LOSS))
invalid_loss = [x for x in loss_type if x not in valid_loss]
......@@ -255,7 +271,9 @@ def build_model(main_prog, start_prog, phase=ModelPhase.TRAIN):
avg_loss = 0
for i in range(0, len(avg_loss_list)):
avg_loss += avg_loss_list[i]
loss_name = valid_loss[i].upper()
loss_weight = eval('cfg.SOLVER.LOSS_WEIGHT.' + loss_name)
avg_loss += loss_weight * avg_loss_list[i]
#get pred result in original size
if isinstance(logits, tuple):
......@@ -268,7 +286,7 @@ def build_model(main_prog, start_prog, phase=ModelPhase.TRAIN):
# return image input and logit output for inference graph prune
if ModelPhase.is_predict(phase):
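# In binary segmentation, dice_loss or bce_loss returns a single-channel logit, which is converted to two channels
# In binary segmentation, lovasz_hinge_loss, dice_loss or bce_loss returns a single-channel logit, which is converted to two channels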
if class_num == 1:
logit = sigmoid_to_softmax(logit)
else:
......
......@@ -155,6 +155,12 @@ cfg.SOLVER.BEGIN_EPOCH = 1
cfg.SOLVER.NUM_EPOCHS = 30
# Loss selection; supports softmax_loss, bce_loss, dice_loss
cfg.SOLVER.LOSS = ["softmax_loss"]
# Loss weights, used when combining multiple losses; they only take effect for losses listed in SOLVER.LOSS
cfg.SOLVER.LOSS_WEIGHT.SOFTMAX_LOSS = 1
cfg.SOLVER.LOSS_WEIGHT.DICE_LOSS = 1
cfg.SOLVER.LOSS_WEIGHT.BCE_LOSS = 1
cfg.SOLVER.LOSS_WEIGHT.LOVASZ_HINGE_LOSS = 1
cfg.SOLVER.LOSS_WEIGHT.LOVASZ_SOFTMAX_LOSS = 1
# Whether to enable the warmup learning-rate strategy
cfg.SOLVER.LR_WARMUP = False
# Number of warmup iterations
......