提交 fca76c3a 编写于 作者: W wuyefeilin 提交者: wuzewu

add dice loss and bce loss (#52)

* update solver_group.md
上级 3c65454c
...@@ -122,3 +122,40 @@ L2正则化系数 ...@@ -122,3 +122,40 @@ L2正则化系数
<br/> <br/>
<br/> <br/>
## `loss`
训练时选择的损失函数, 支持`softmax_loss(sotfmax with cross entroy loss)`,
`dice_loss(dice coefficient loss)`, `bce_loss(binary cross entroy loss)`三种损失函数。
其中`dice_loss``bce_loss`仅在两类分割问题中适用,`softmax_loss`不能与`dice_loss`
`bce_loss`组合,`dice_loss`可以和`bce_loss`组合使用。使用示例如下:
`['softmax_loss']``['dice_loss','bce_loss']`
* softmax_loss
![equation](http://latex.codecogs.com/gif.latex?softmax\\_loss=\sum_{i=1}^Ny_i{log(p_i)})
<br/>
* dice_loss
![equation](http://latex.codecogs.com/gif.latex?dice\\_loss=1-\frac{2|Y\bigcap{P}|}{|Y|\bigcup|P|})
[dice系数](https://zh.wikipedia.org/wiki/Dice%E7%B3%BB%E6%95%B0)
<br/>
* bce_loss
![equation](http://latex.codecogs.com/gif.latex?bce\\_loss=y_i{log(p_i)}+(1-y_i)log(1-p_i))
其中![equation](http://latex.codecogs.com/gif.latex?y_i)*Y*为标签,
![equation](http://latex.codecogs.com/gif.latex?p_i)*P*为预测结果
### 默认值
['softmax_loss']
<br/>
<br/>
...@@ -48,6 +48,42 @@ def softmax_with_loss(logit, label, ignore_mask=None, num_classes=2): ...@@ -48,6 +48,42 @@ def softmax_with_loss(logit, label, ignore_mask=None, num_classes=2):
ignore_mask.stop_gradient = True ignore_mask.stop_gradient = True
return avg_loss return avg_loss
# to change, how to appicate ignore index and ignore mask
def dice_loss(logit, label, ignore_mask=None, epsilon=0.00001):
if logit.shape[1] != 1 or label.shape[1] != 1 or ignore_mask.shape[1] != 1:
raise Exception("dice loss is only applicable to one channel classfication")
ignore_mask = fluid.layers.cast(ignore_mask, 'float32')
logit = fluid.layers.transpose(logit, [0, 2, 3, 1])
label = fluid.layers.transpose(label, [0, 2, 3, 1])
label = fluid.layers.cast(label, 'int64')
ignore_mask = fluid.layers.transpose(ignore_mask, [0, 2, 3, 1])
logit = fluid.layers.sigmoid(logit)
logit = logit * ignore_mask
label = label * ignore_mask
reduce_dim = list(range(1, len(logit.shape)))
inse = fluid.layers.reduce_sum(logit * label, dim=reduce_dim)
dice_denominator = fluid.layers.reduce_sum(
logit, dim=reduce_dim) + fluid.layers.reduce_sum(
label, dim=reduce_dim)
dice_score = 1 - inse * 2 / (dice_denominator + epsilon)
label.stop_gradient = True
ignore_mask.stop_gradient = True
return fluid.layers.reduce_mean(dice_score)
def bce_loss(logit, label, ignore_mask=None):
if logit.shape[1] != 1 or label.shape[1] != 1 or ignore_mask.shape[1] != 1:
raise Exception("bce loss is only applicable to binary classfication")
label = fluid.layers.cast(label, 'float32')
loss = fluid.layers.sigmoid_cross_entropy_with_logits(
x=logit,
label=label,
ignore_index=cfg.DATASET.IGNORE_INDEX,
normalize=True) # or False
loss = fluid.layers.reduce_sum(loss)
label.stop_gradient = True
ignore_mask.stop_gradient = True
return loss
def multi_softmax_with_loss(logits, label, ignore_mask=None, num_classes=2): def multi_softmax_with_loss(logits, label, ignore_mask=None, num_classes=2):
if isinstance(logits, tuple): if isinstance(logits, tuple):
...@@ -63,19 +99,28 @@ def multi_softmax_with_loss(logits, label, ignore_mask=None, num_classes=2): ...@@ -63,19 +99,28 @@ def multi_softmax_with_loss(logits, label, ignore_mask=None, num_classes=2):
avg_loss = softmax_with_loss(logits, label, ignore_mask, num_classes) avg_loss = softmax_with_loss(logits, label, ignore_mask, num_classes)
return avg_loss return avg_loss
def multi_dice_loss(logits, label, ignore_mask=None):
if isinstance(logits, tuple):
avg_loss = 0
for i, logit in enumerate(logits):
logit_label = fluid.layers.resize_nearest(label, logit.shape[2:])
logit_mask = (logit_label.astype('int32') !=
cfg.DATASET.IGNORE_INDEX).astype('int32')
loss = dice_loss(logit, logit_label, logit_mask)
avg_loss += cfg.MODEL.MULTI_LOSS_WEIGHT[i] * loss
else:
avg_loss = dice_loss(logits, label, ignore_mask)
return avg_loss
# to change, how to appicate ignore index and ignore mask def multi_bce_loss(logits, label, ignore_mask=None):
def dice_loss(logit, label, ignore_mask=None, num_classes=2): if isinstance(logits, tuple):
if num_classes != 2: avg_loss = 0
raise Exception("dice loss is only applicable to binary classfication") for i, logit in enumerate(logits):
ignore_mask = fluid.layers.cast(ignore_mask, 'float32') logit_label = fluid.layers.resize_nearest(label, logit.shape[2:])
label = fluid.layers.elementwise_min( logit_mask = (logit_label.astype('int32') !=
label, fluid.layers.assign(np.array([num_classes - 1], dtype=np.int32))) cfg.DATASET.IGNORE_INDEX).astype('int32')
logit = fluid.layers.transpose(logit, [0, 2, 3, 1]) loss = bce_loss(logit, logit_label, logit_mask)
logit = fluid.layers.reshape(logit, [-1, num_classes]) avg_loss += cfg.MODEL.MULTI_LOSS_WEIGHT[i] * loss
logit = fluid.layers.softmax(logit) else:
label = fluid.layers.reshape(label, [-1, 1]) avg_loss = bce_loss(logits, label, ignore_mask)
label = fluid.layers.cast(label, 'int64') return avg_loss
ignore_mask = fluid.layers.reshape(ignore_mask, [-1, 1])
loss = fluid.layers.dice_loss(logit, label)
return loss
...@@ -24,6 +24,8 @@ from paddle.fluid.proto.framework_pb2 import VarType ...@@ -24,6 +24,8 @@ from paddle.fluid.proto.framework_pb2 import VarType
import solver import solver
from utils.config import cfg from utils.config import cfg
from loss import multi_softmax_with_loss from loss import multi_softmax_with_loss
from loss import multi_dice_loss
from loss import multi_bce_loss
class ModelPhase(object): class ModelPhase(object):
...@@ -109,6 +111,17 @@ def softmax(logit): ...@@ -109,6 +111,17 @@ def softmax(logit):
logit = fluid.layers.transpose(logit, [0, 3, 1, 2]) logit = fluid.layers.transpose(logit, [0, 3, 1, 2])
return logit return logit
def sigmoid_to_softmax(logit):
"""
one channel to two channel
"""
logit = fluid.layers.transpose(logit, [0, 2, 3, 1])
logit = fluid.layers.sigmoid(logit)
logit_back = 1 - logit
logit = fluid.layers.concat([logit_back, logit], axis=-1)
logit = fluid.layers.transpose(logit, [0, 3, 1, 2])
return logit
def build_model(main_prog, start_prog, phase=ModelPhase.TRAIN): def build_model(main_prog, start_prog, phase=ModelPhase.TRAIN):
if not ModelPhase.is_valid_phase(phase): if not ModelPhase.is_valid_phase(phase):
...@@ -144,11 +157,35 @@ def build_model(main_prog, start_prog, phase=ModelPhase.TRAIN): ...@@ -144,11 +157,35 @@ def build_model(main_prog, start_prog, phase=ModelPhase.TRAIN):
image = fluid.layers.cast(image, "float16") image = fluid.layers.cast(image, "float16")
model_name = map_model_name(cfg.MODEL.MODEL_NAME) model_name = map_model_name(cfg.MODEL.MODEL_NAME)
model_func = get_func("modeling." + model_name) model_func = get_func("modeling." + model_name)
loss_type = cfg.SOLVER.LOSS
if ("dice_loss" in loss_type) or ("bce_loss" in loss_type):
class_num = 1
if "softmax_loss" in loss_type:
raise Exception("softmax loss can not combine with dice loss or bce loss")
logits = model_func(image, class_num) logits = model_func(image, class_num)
if ModelPhase.is_train(phase) or ModelPhase.is_eval(phase): if ModelPhase.is_train(phase) or ModelPhase.is_eval(phase):
avg_loss = multi_softmax_with_loss(logits, label, mask, loss_valid = False
class_num) avg_loss_list = []
if "softmax_loss" in loss_type:
avg_loss_list.append(multi_softmax_with_loss(logits,
label, mask,class_num))
loss_valid = True
if "dice_loss" in loss_type:
avg_loss_list.append(multi_dice_loss(logits, label, mask))
loss_valid = True
if "bce_loss" in loss_type:
avg_loss_list.append(multi_bce_loss(logits, label, mask))
loss_valid = True
if not loss_valid:
raise Exception("SOLVER.LOSS: {} is set wrong. it should "
"include one of (softmax_loss, bce_loss, dice_loss) at least"
" example: ['softmax_loss'], ['dice_loss'], ['bce_loss', 'dice_loss']".format(cfg.SOLVER.LOSS))
avg_loss = 0
for i in range(0, len(avg_loss_list)):
avg_loss += avg_loss_list[i]
#get pred result in original size #get pred result in original size
if isinstance(logits, tuple): if isinstance(logits, tuple):
...@@ -161,16 +198,24 @@ def build_model(main_prog, start_prog, phase=ModelPhase.TRAIN): ...@@ -161,16 +198,24 @@ def build_model(main_prog, start_prog, phase=ModelPhase.TRAIN):
# return image input and logit output for inference graph prune # return image input and logit output for inference graph prune
if ModelPhase.is_predict(phase): if ModelPhase.is_predict(phase):
if class_num == 1:
logit = sigmoid_to_softmax(logit)
else:
logit = softmax(logit) logit = softmax(logit)
return image, logit return image, logit
if class_num == 1:
out = fluid.layers.transpose(x=logit, perm=[0, 2, 3, 1]) out = sigmoid_to_softmax(logit)
out = fluid.layers.transpose(out, [0, 2, 3, 1])
else:
out = fluid.layers.transpose(logit, [0, 2, 3, 1])
if cfg.MODEL.FP16: if cfg.MODEL.FP16:
out = fluid.layers.cast(out, 'float32') out = fluid.layers.cast(out, 'float32')
pred = fluid.layers.argmax(out, axis=3) pred = fluid.layers.argmax(out, axis=3)
pred = fluid.layers.unsqueeze(pred, axes=[3]) pred = fluid.layers.unsqueeze(pred, axes=[3])
if ModelPhase.is_visual(phase): if ModelPhase.is_visual(phase):
if class_num == 1:
logit = sigmoid_to_softmax(logit)
else:
logit = softmax(logit) logit = softmax(logit)
return pred, logit return pred, logit
......
...@@ -149,6 +149,8 @@ cfg.SOLVER.WEIGHT_DECAY = 0.00004 ...@@ -149,6 +149,8 @@ cfg.SOLVER.WEIGHT_DECAY = 0.00004
cfg.SOLVER.BEGIN_EPOCH = 1 cfg.SOLVER.BEGIN_EPOCH = 1
# 训练epoch数,正整数 # 训练epoch数,正整数
cfg.SOLVER.NUM_EPOCHS = 30 cfg.SOLVER.NUM_EPOCHS = 30
# loss的选择,支持softmax_loss, bce_loss, dice_loss
cfg.SOLVER.LOSS = ["softmax_loss"]
########################## 测试配置 ########################################### ########################## 测试配置 ###########################################
# 测试模型路径 # 测试模型路径
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册