table_att_loss.py
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import paddle
from paddle import nn
from paddle.nn import functional as F


class TableAttentionLoss(nn.Layer):
    def __init__(self,
                 structure_weight,
                 loc_weight,
                 use_giou=False,
                 giou_weight=1.0,
                 **kwargs):
        super(TableAttentionLoss, self).__init__()
        self.loss_func = nn.CrossEntropyLoss(weight=None, reduction='none')
        self.structure_weight = structure_weight
        self.loc_weight = loc_weight
        self.use_giou = use_giou
        self.giou_weight = giou_weight

    def giou_loss(self, preds, bbox, eps=1e-7, reduction='mean'):
        '''
        Generalized IoU (GIoU) loss between predicted and ground-truth boxes.
        :param preds: predicted boxes, [[x1, y1, x2, y2], ...]
        :param bbox: ground-truth boxes, [[x1, y1, x2, y2], ...]
        :param eps: small constant added to avoid division by zero
        :param reduction: 'mean' or 'sum'
        :return: loss
        '''
        # corners of the intersection between each predicted box and its target
        ix1 = paddle.maximum(preds[:, 0], bbox[:, 0])
        iy1 = paddle.maximum(preds[:, 1], bbox[:, 1])
        ix2 = paddle.minimum(preds[:, 2], bbox[:, 2])
        iy2 = paddle.minimum(preds[:, 3], bbox[:, 3])

        iw = paddle.clip(ix2 - ix1 + 1e-3, 0., 1e10)
        ih = paddle.clip(iy2 - iy1 + 1e-3, 0., 1e10)

        # overlap
        inters = iw * ih

        # union
        area_p = (preds[:, 2] - preds[:, 0] + 1e-3) * (
            preds[:, 3] - preds[:, 1] + 1e-3)
        area_g = (bbox[:, 2] - bbox[:, 0] + 1e-3) * (
            bbox[:, 3] - bbox[:, 1] + 1e-3)
        uni = area_p + area_g - inters + eps

        # ious
        ious = inters / uni

        # corners and size of the smallest box enclosing both pred and target
        ex1 = paddle.minimum(preds[:, 0], bbox[:, 0])
        ey1 = paddle.minimum(preds[:, 1], bbox[:, 1])
        ex2 = paddle.maximum(preds[:, 2], bbox[:, 2])
        ey2 = paddle.maximum(preds[:, 3], bbox[:, 3])
        ew = paddle.clip(ex2 - ex1 + 1e-3, 0., 1e10)
        eh = paddle.clip(ey2 - ey1 + 1e-3, 0., 1e10)

        # enclose area and GIoU = IoU - (enclose - union) / enclose
        enclose = ew * eh + eps
        giou = ious - (enclose - uni) / enclose

        loss = 1 - giou

        if reduction == 'mean':
            loss = paddle.mean(loss)
        elif reduction == 'sum':
            loss = paddle.sum(loss)
        else:
            raise NotImplementedError
        return loss

    def forward(self, predicts, batch):
        # structure branch: cross-entropy over the predicted structure tokens
        structure_probs = predicts['structure_probs']
        structure_targets = batch[1].astype("int64")
        # drop the first (start) token so targets align with the decode steps
        structure_targets = structure_targets[:, 1:]
        structure_probs = paddle.reshape(structure_probs,
                                         [-1, structure_probs.shape[-1]])
        structure_targets = paddle.reshape(structure_targets, [-1])
        structure_loss = self.loss_func(structure_probs, structure_targets)

        structure_loss = paddle.mean(structure_loss) * self.structure_weight

        # location branch: regression of the cell bounding boxes
        loc_preds = predicts['loc_preds']
        loc_targets = batch[2].astype("float32")
        loc_targets_mask = batch[3].astype("float32")
        loc_targets = loc_targets[:, 1:, :]
        loc_targets_mask = loc_targets_mask[:, 1:, :]
        loc_loss = F.mse_loss(loc_preds * loc_targets_mask,
                              loc_targets) * self.loc_weight
        if self.use_giou:
            # flatten to [N, 4] boxes as expected by giou_loss
            loc_loss_giou = self.giou_loss(
                paddle.reshape(loc_preds * loc_targets_mask, [-1, 4]),
                paddle.reshape(loc_targets, [-1, 4])) * self.giou_weight
            total_loss = structure_loss + loc_loss + loc_loss_giou
            return {
                'loss': total_loss,
                "structure_loss": structure_loss,
                "loc_loss": loc_loss,
                "loc_loss_giou": loc_loss_giou
            }
        else:
            total_loss = structure_loss + loc_loss
            return {
                'loss': total_loss,
                "structure_loss": structure_loss,
                "loc_loss": loc_loss
            }
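

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original loss module): builds random
# tensors and runs one forward pass of TableAttentionLoss. The shapes below
# (batch of 2, 6 decode steps, 30 structure classes) and the predicts/batch
# layout are illustrative assumptions, not values from a real PaddleOCR
# config.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    batch_size, num_steps, num_classes = 2, 6, 30

    # predicted structure scores and box locations for each decode step
    predicts = {
        'structure_probs': paddle.randn(
            [batch_size, num_steps, num_classes]),
        'loc_preds': paddle.rand([batch_size, num_steps, 4]),
    }

    # assumed batch layout: index 1 = structure targets (with a leading start
    # token, hence num_steps + 1), index 2 = box targets, index 3 = box mask
    structure_targets = paddle.randint(
        0, num_classes, [batch_size, num_steps + 1], dtype='int64')
    loc_targets = paddle.rand([batch_size, num_steps + 1, 4])
    loc_targets_mask = paddle.ones([batch_size, num_steps + 1, 1])
    batch = [None, structure_targets, loc_targets, loc_targets_mask]

    loss_fn = TableAttentionLoss(
        structure_weight=1.0, loc_weight=2.0, use_giou=True, giou_weight=1.0)
    out = loss_fn(predicts, batch)
    for name, value in out.items():
        print(name, float(value))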