Unverified commit 90a441e2, authored by FlyingQianMM, committed by GitHub

add epsilon to denominator when computing loss (#292)

* add 0.00001 to denominator when computing loss

* use DEFAULT_EPSILON
Parent commit: b6fbe07f
@@ -77,7 +77,7 @@ def softmax_with_loss(logit,
     weighted_label_one_hot.stop_gradient = True
     loss = loss * ignore_mask
-    avg_loss = fluid.layers.mean(loss) / fluid.layers.mean(ignore_mask)
+    avg_loss = fluid.layers.mean(loss) / (fluid.layers.mean(ignore_mask) + cfg.MODEL.DEFAULT_EPSILON)
     label.stop_gradient = True
     ignore_mask.stop_gradient = True
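The guard matters because fluid.layers.mean(ignore_mask) is exactly 0 whenever every pixel in a batch carries the ignore label, and the resulting 0/0 propagates NaN through the whole graph. Below is a minimal NumPy sketch of that failure mode, not part of the commit; DEFAULT_EPSILON is assumed to be the 0.00001 mentioned in the commit message, while the real value comes from cfg.MODEL.DEFAULT_EPSILON.

import numpy as np

DEFAULT_EPSILON = 1e-5  # assumed stand-in for cfg.MODEL.DEFAULT_EPSILON

# A crop in which every pixel is ignored: the mask is all zeros.
loss = np.zeros((4, 4), dtype=np.float32)
ignore_mask = np.zeros((4, 4), dtype=np.float32)

unguarded = loss.mean() / ignore_mask.mean()                    # 0 / 0 -> nan
guarded = loss.mean() / (ignore_mask.mean() + DEFAULT_EPSILON)  # -> 0.0
print(unguarded, guarded)  # nan 0.0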
@@ -133,10 +133,12 @@ def multi_softmax_with_loss(logits,
         for i, logit in enumerate(logits):
             if label.shape[2] != logit.shape[2] or label.shape[
                     3] != logit.shape[3]:
-                label = fluid.layers.resize_nearest(label, logit.shape[2:])
-            logit_mask = (label.astype('int32') !=
+                logit_label = fluid.layers.resize_nearest(label, logit.shape[2:])
+            else:
+                logit_label = label
+            logit_mask = (logit_label.astype('int32') !=
                           cfg.DATASET.IGNORE_INDEX).astype('int32')
-            loss = softmax_with_loss(logit, label, logit_mask, num_classes)
+            loss = softmax_with_loss(logit, logit_label, logit_mask, num_classes, weight=weight)
             avg_loss += cfg.MODEL.MULTI_LOSS_WEIGHT[i] * loss
     else:
         avg_loss = softmax_with_loss(
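Two behavioural points in this hunk beyond the epsilon guard: the possibly resized label is now stored in a separate logit_label variable, so the original full-resolution label is no longer overwritten inside the loop and remains available for the other logits in the tuple, and the per-class weight is now forwarded to softmax_with_loss at every scale. The following self-contained sketch exercises only the shape-check idiom; pick_label_for_logit and toy_resize_nearest are hypothetical helpers, with NumPy arrays standing in for fluid tensors in NCHW layout.

import numpy as np

def pick_label_for_logit(label, logit, resize_nearest):
    # Resample only when the spatial sizes differ; otherwise reuse the
    # original label so it is never overwritten between scales.
    if label.shape[2] != logit.shape[2] or label.shape[3] != logit.shape[3]:
        return resize_nearest(label, logit.shape[2:])
    return label

def toy_resize_nearest(label, out_shape):
    # Toy nearest-neighbour resize used purely to exercise the control flow.
    _, _, h, w = label.shape
    oh, ow = out_shape
    rows = np.arange(oh) * h // oh
    cols = np.arange(ow) * w // ow
    return label[:, :, rows][:, :, :, cols]

label = np.arange(2 * 1 * 4 * 4).reshape(2, 1, 4, 4)
logit_small = np.zeros((2, 19, 2, 2))  # downsampled auxiliary head
logit_full = np.zeros((2, 19, 4, 4))   # main head, same size as the label

assert pick_label_for_logit(label, logit_small, toy_resize_nearest).shape[2:] == (2, 2)
assert pick_label_for_logit(label, logit_full, toy_resize_nearest) is label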
@@ -148,7 +150,11 @@ def multi_dice_loss(logits, label, ignore_mask=None):
     if isinstance(logits, tuple):
         avg_loss = 0
         for i, logit in enumerate(logits):
-            logit_label = fluid.layers.resize_nearest(label, logit.shape[2:])
+            if label.shape[2] != logit.shape[2] or label.shape[
+                    3] != logit.shape[3]:
+                logit_label = fluid.layers.resize_nearest(label, logit.shape[2:])
+            else:
+                logit_label = label
             logit_mask = (logit_label.astype('int32') !=
                           cfg.DATASET.IGNORE_INDEX).astype('int32')
             loss = dice_loss(logit, logit_label, logit_mask)
@@ -162,7 +168,11 @@ def multi_bce_loss(logits, label, ignore_mask=None):
     if isinstance(logits, tuple):
         avg_loss = 0
         for i, logit in enumerate(logits):
-            logit_label = fluid.layers.resize_nearest(label, logit.shape[2:])
+            if label.shape[2] != logit.shape[2] or label.shape[
+                    3] != logit.shape[3]:
+                logit_label = fluid.layers.resize_nearest(label, logit.shape[2:])
+            else:
+                logit_label = label
             logit_mask = (logit_label.astype('int32') !=
                           cfg.DATASET.IGNORE_INDEX).astype('int32')
             loss = bce_loss(logit, logit_label, logit_mask)