提交 b1e8b3ef 编写于 作者: C chulutao

Support MULTI_LOSS_WEIGHT for lovasz loss

上级 3f658a36
......@@ -18,6 +18,7 @@ from __future__ import division
from __future__ import print_function
import paddle.fluid as fluid
import numpy as np
from utils.config import cfg
def _cumsum(x):
......@@ -203,3 +204,43 @@ def flatten_probas(probas, labels, ignore=None):
vprobas = fluid.layers.gather(probas, indexs[:, 0])
vlabels = fluid.layers.gather(labels, indexs[:, 0])
return vprobas, vlabels
def multi_lovasz_softmax_loss(logits, label, ignore_mask=None):
if isinstance(logits, tuple):
avg_loss = 0
for i, logit in enumerate(logits):
if label.shape[2] != logit.shape[2] or label.shape[
3] != logit.shape[3]:
logit_label = fluid.layers.resize_nearest(
label, logit.shape[2:])
else:
logit_label = label
logit_mask = (logit_label.astype('int32') !=
cfg.DATASET.IGNORE_INDEX).astype('int32')
probas = fluid.layers.softmax(logit, axis=1)
loss = lovasz_softmax(probas, logit_label, ignore=logit_mask)
avg_loss += cfg.MODEL.MULTI_LOSS_WEIGHT[i] * loss
else:
probas = fluid.layers.softmax(logits, axis=1)
avg_loss = lovasz_softmax(probas, label, ignore=ignore_mask)
return avg_loss
def multi_lovasz_hinge_loss(logits, label, ignore_mask=None):
if isinstance(logits, tuple):
avg_loss = 0
for i, logit in enumerate(logits):
if label.shape[2] != logit.shape[2] or label.shape[
3] != logit.shape[3]:
logit_label = fluid.layers.resize_nearest(
label, logit.shape[2:])
else:
logit_label = label
logit_mask = (logit_label.astype('int32') !=
cfg.DATASET.IGNORE_INDEX).astype('int32')
loss = lovasz_hinge(logit, logit_label, ignore=logit_mask)
avg_loss += cfg.MODEL.MULTI_LOSS_WEIGHT[i] * loss
else:
avg_loss = lovasz_hinge(logits, label, ignore=ignore_mask)
return avg_loss
......@@ -24,8 +24,7 @@ from utils.config import cfg
from loss import multi_softmax_with_loss
from loss import multi_dice_loss
from loss import multi_bce_loss
from lovasz_losses import lovasz_hinge
from lovasz_losses import lovasz_softmax
from lovasz_losses import multi_lovasz_hinge_loss, multi_lovasz_softmax_loss
from models.modeling import deeplab, unet, icnet, pspnet, hrnet, fast_scnn, ocrnet
......@@ -189,13 +188,12 @@ def build_model(main_prog, start_prog, phase=ModelPhase.TRAIN):
valid_loss.append("bce_loss")
if "lovasz_hinge_loss" in loss_type:
avg_loss_list.append(
lovasz_hinge(logits, label, ignore=mask))
multi_lovasz_hinge_loss(logits, label, mask))
loss_valid = True
valid_loss.append("lovasz_hinge_loss")
if "lovasz_softmax_loss" in loss_type:
probas = fluid.layers.softmax(logits, axis=1)
avg_loss_list.append(
lovasz_softmax(probas, label, ignore=mask))
multi_lovasz_softmax_loss(logits, label, mask))
loss_valid = True
valid_loss.append("lovasz_softmax_loss")
if not loss_valid:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册