diff --git a/pdseg/loss.py b/pdseg/loss.py index bd8406102a0b942bf4b3d2bd48caba3f57b153bf..c5ea306d4d5b709e1e5b2ffd68c5b8aa287f8bbd 100644 --- a/pdseg/loss.py +++ b/pdseg/loss.py @@ -77,7 +77,7 @@ def softmax_with_loss(logit, weighted_label_one_hot.stop_gradient = True loss = loss * ignore_mask - avg_loss = fluid.layers.mean(loss) / fluid.layers.mean(ignore_mask) + avg_loss = fluid.layers.mean(loss) / (fluid.layers.mean(ignore_mask) + cfg.MODEL.DEFAULT_EPSILON) label.stop_gradient = True ignore_mask.stop_gradient = True @@ -133,10 +133,12 @@ def multi_softmax_with_loss(logits, for i, logit in enumerate(logits): if label.shape[2] != logit.shape[2] or label.shape[ 3] != logit.shape[3]: - label = fluid.layers.resize_nearest(label, logit.shape[2:]) - logit_mask = (label.astype('int32') != + logit_label = fluid.layers.resize_nearest(label, logit.shape[2:]) + else: + logit_label = label + logit_mask = (logit_label.astype('int32') != cfg.DATASET.IGNORE_INDEX).astype('int32') - loss = softmax_with_loss(logit, label, logit_mask, num_classes) + loss = softmax_with_loss(logit, logit_label, logit_mask, num_classes, weight=weight) avg_loss += cfg.MODEL.MULTI_LOSS_WEIGHT[i] * loss else: avg_loss = softmax_with_loss( @@ -148,7 +150,11 @@ def multi_dice_loss(logits, label, ignore_mask=None): if isinstance(logits, tuple): avg_loss = 0 for i, logit in enumerate(logits): - logit_label = fluid.layers.resize_nearest(label, logit.shape[2:]) + if label.shape[2] != logit.shape[2] or label.shape[ + 3] != logit.shape[3]: + logit_label = fluid.layers.resize_nearest(label, logit.shape[2:]) + else: + logit_label = label logit_mask = (logit_label.astype('int32') != cfg.DATASET.IGNORE_INDEX).astype('int32') loss = dice_loss(logit, logit_label, logit_mask) @@ -162,7 +168,11 @@ def multi_bce_loss(logits, label, ignore_mask=None): if isinstance(logits, tuple): avg_loss = 0 for i, logit in enumerate(logits): - logit_label = fluid.layers.resize_nearest(label, logit.shape[2:]) + if label.shape[2] != logit.shape[2] or label.shape[ + 3] != logit.shape[3]: + logit_label = fluid.layers.resize_nearest(label, logit.shape[2:]) + else: + logit_label = label logit_mask = (logit_label.astype('int32') != cfg.DATASET.IGNORE_INDEX).astype('int32') loss = bce_loss(logit, logit_label, logit_mask)