det_db_loss.py 3.2 KB
Newer Older
W
WenmuZhou 已提交
1
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
L
LDOUBLEV 已提交
2
#
W
WenmuZhou 已提交
3 4 5
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
L
LDOUBLEV 已提交
6 7 8
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
W
WenmuZhou 已提交
9 10 11 12 13
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
W
WenmuZhou 已提交
14 15 16 17
"""
This code is adapted from:
https://github.com/WenmuZhou/DBNet.pytorch/blob/master/models/losses/DB_loss.py
"""
L
LDOUBLEV 已提交
18 19 20 21 22

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

H
huangjun12 已提交
23
import paddle
W
WenmuZhou 已提交
24 25
from paddle import nn

L
LDOUBLEV 已提交
26 27 28
from .det_basic_loss import BalanceLoss, MaskL1Loss, DiceLoss


W
WenmuZhou 已提交
29
class DBLoss(nn.Layer):
    """
    Differentiable Binarization (DB) loss.

    Combines three terms computed from channels of ``predicts['maps']``:
    a ``BalanceLoss`` term on the shrink map (channel 0), a ``MaskL1Loss``
    term on the threshold map (channel 1) and a ``DiceLoss`` term on the
    binary map (channel 2). An optional CBN term is added when the
    prediction dict carries CBN outputs (see ``forward``).

    Args:
        balance_loss (bool): forwarded to ``BalanceLoss(balance_loss=...)``.
        main_loss_type (str): forwarded to
            ``BalanceLoss(main_loss_type=...)``.
        alpha (float): weight applied to the shrink-map loss.
        beta (float): weight applied to the threshold-map loss.
        ohem_ratio (int | float): forwarded to ``BalanceLoss`` as
            ``negative_ratio``.
        eps (float): epsilon forwarded to ``DiceLoss`` and ``MaskL1Loss``.
        **kwargs: accepted (presumably for config compatibility) and ignored.
    """

    def __init__(self,
                 balance_loss=True,
                 main_loss_type='DiceLoss',
                 alpha=5,
                 beta=10,
                 ohem_ratio=3,
                 eps=1e-6,
                 **kwargs):
        super(DBLoss, self).__init__()
        self.alpha = alpha
        self.beta = beta
        self.dice_loss = DiceLoss(eps=eps)
        self.l1_loss = MaskL1Loss(eps=eps)
        # NOTE: named ``bce_loss`` historically, but it is a BalanceLoss whose
        # inner loss is selected by ``main_loss_type`` (default 'DiceLoss').
        self.bce_loss = BalanceLoss(
            balance_loss=balance_loss,
            main_loss_type=main_loss_type,
            negative_ratio=ohem_ratio)

    def forward(self, predicts, labels):
        """
        Compute the weighted DB loss terms.

        Args:
            predicts (dict): must contain ``'maps'``, indexed as
                ``maps[:, c, :, :]`` with c = 0 (shrink), 1 (threshold),
                2 (binary). If it contains ``'distance_maps'`` it must also
                contain ``'cbn_maps'``, whose channel 0 feeds the CBN term.
            labels (sequence): ``labels[1:]`` must unpack into exactly four
                items: threshold map, threshold mask, shrink map, shrink
                mask. ``labels[0]`` is not used here.

        Returns:
            dict: ``'loss'`` (sum of all terms including CBN),
            ``'loss_shrink_maps'`` (already multiplied by ``alpha``),
            ``'loss_threshold_maps'`` (already multiplied by ``beta``),
            ``'loss_binary_maps'`` and ``'loss_cbn'``.
        """
        predict_maps = predicts['maps']
        (label_threshold_map, label_threshold_mask, label_shrink_map,
         label_shrink_mask) = labels[1:]

        shrink_maps = predict_maps[:, 0, :, :]
        threshold_maps = predict_maps[:, 1, :, :]
        binary_maps = predict_maps[:, 2, :, :]

        loss_shrink_maps = self.bce_loss(shrink_maps, label_shrink_map,
                                         label_shrink_mask)
        loss_threshold_maps = self.l1_loss(threshold_maps, label_threshold_map,
                                           label_threshold_mask)
        # The binary map is supervised with the same targets as the shrink map.
        loss_binary_maps = self.dice_loss(binary_maps, label_shrink_map,
                                          label_shrink_mask)

        loss_shrink_maps = self.alpha * loss_shrink_maps
        loss_threshold_maps = self.beta * loss_threshold_maps

        # CBN loss.
        # NOTE(review): gated on 'distance_maps' (unchanged behavior) although
        # only 'cbn_maps' is read; confirm heads always emit both together.
        if 'distance_maps' in predicts:
            cbn_maps = predicts['cbn_maps']
            cbn_loss = self.bce_loss(cbn_maps[:, 0, :, :], label_shrink_map,
                                     label_shrink_mask)
        else:
            # Zero placeholder so 'loss_cbn' is always reported.
            cbn_loss = paddle.to_tensor([0.])

        loss_all = loss_shrink_maps + loss_threshold_maps + loss_binary_maps
        losses = {
            'loss': loss_all + cbn_loss,
            "loss_shrink_maps": loss_shrink_maps,
            "loss_threshold_maps": loss_threshold_maps,
            "loss_binary_maps": loss_binary_maps,
            "loss_cbn": cbn_loss,
        }
        return losses