DB_loss.py 2.0 KB
Newer Older
Z
zhoujun 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49
import paddle
from models.losses.basic_loss import BalanceCrossEntropyLoss, MaskL1Loss, DiceLoss


class DBLoss(paddle.nn.Layer):
    def __init__(self,
                 alpha=1.0,
                 beta=10,
                 ohem_ratio=3,
                 reduction='mean',
                 eps=1e-06):
        """
        Implement DB Loss: a weighted sum of a shrink-map loss, a
        threshold-map loss and (when available) a binary-map loss.

        :param alpha: weight applied to the shrink-map loss
        :param beta: weight applied to the threshold-map loss
        :param ohem_ratio: OHEM ratio, forwarded to BalanceCrossEntropyLoss
            as ``negative_ratio``
        :param reduction: 'mean' or 'sum'; intended to select averaging or
            summing the loss over the batch
        :param eps: small epsilon forwarded to DiceLoss and MaskL1Loss
        """
        super().__init__()
        assert reduction in ['mean', 'sum'], " reduction must in ['mean','sum']"
        self.alpha = alpha
        self.beta = beta
        self.bce_loss = BalanceCrossEntropyLoss(negative_ratio=ohem_ratio)
        self.dice_loss = DiceLoss(eps=eps)
        self.l1_loss = MaskL1Loss(eps=eps)
        self.ohem_ratio = ohem_ratio
        # NOTE(review): ``reduction`` is validated and stored but is not read
        # anywhere in this class; confirm whether the sub-losses should use it.
        self.reduction = reduction

    def forward(self, pred, batch):
        """
        Compute the DB loss terms for one batch.

        :param pred: prediction tensor indexed as ``pred[:, c, :, :]``.
            Channel 0 is used as the shrink map and channel 1 as the
            threshold map. Channel 2, used as the binary map, is optional
            (``pred.shape[1] > 2``). Presumably laid out as (N, C, H, W);
            TODO confirm.
        :param batch: dict holding the targets 'shrink_map', 'shrink_mask',
            'threshold_map' and 'threshold_mask'
        :return: dict with 'loss_shrink_maps', 'loss_threshold_maps' and
            'loss'; also 'loss_binary_maps' when channel 2 is present
        """
        shrink_maps = pred[:, 0, :, :]
        threshold_maps = pred[:, 1, :, :]
        loss_shrink_maps = self.bce_loss(shrink_maps, batch['shrink_map'],
                                         batch['shrink_mask'])
        loss_threshold_maps = self.l1_loss(
            threshold_maps, batch['threshold_map'], batch['threshold_mask'])
        metrics = dict(
            loss_shrink_maps=loss_shrink_maps,
            loss_threshold_maps=loss_threshold_maps)
        if pred.shape[1] > 2:
            # Slice channel 2 only when it exists, so 2-channel predictions
            # reach the else branch instead of failing on an out-of-range
            # index.
            binary_maps = pred[:, 2, :, :]
            loss_binary_maps = self.dice_loss(binary_maps, batch['shrink_map'],
                                              batch['shrink_mask'])
            metrics['loss_binary_maps'] = loss_binary_maps
            # The binary-map loss is supervised by the same shrink targets
            # and is added unweighted.
            loss_all = (self.alpha * loss_shrink_maps + self.beta *
                        loss_threshold_maps + loss_binary_maps)
            metrics['loss'] = loss_all
        else:
            # Without a binary map, only the (unweighted) shrink-map loss
            # forms the total. The threshold loss is still reported in
            # ``metrics``.
            metrics['loss'] = loss_shrink_maps
        return metrics