# table_att_loss.py (4.5 KB) -- last committed by MissPenguin.
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import paddle
from paddle import nn
from paddle.nn import functional as F
from paddle import fluid

class TableAttentionLoss(nn.Layer):
    """Loss for the table-attention head: structure cross-entropy plus box regression.

    The total loss is
    ``structure_weight * CE + loc_weight * MSE (+ giou_weight * GIoU if use_giou)``.

    Args:
        structure_weight (float): multiplier applied to the structure cross-entropy term.
        loc_weight (float): multiplier applied to the location MSE term.
        use_giou (bool): if True, also add a GIoU loss on the (masked) predicted boxes.
        giou_weight (float): multiplier applied to the GIoU term.
        **kwargs: accepted and ignored.
    """

    def __init__(self, structure_weight, loc_weight, use_giou=False, giou_weight=1.0, **kwargs):
        super(TableAttentionLoss, self).__init__()
        # reduction='none' keeps a per-token loss so an optional mask can be applied in forward().
        self.loss_func = nn.CrossEntropyLoss(weight=None, reduction='none')
        self.structure_weight = structure_weight
        self.loc_weight = loc_weight
        self.use_giou = use_giou
        self.giou_weight = giou_weight

    def giou_loss(self, preds, bbox, eps=1e-7, reduction='mean'):
        """Return the generalized-IoU loss ``1 - GIoU`` between two sets of boxes.

        Boxes are laid out as ``[x1, y1, x2, y2]`` along the LAST axis. Both flat
        ``[N, 4]`` inputs and batched ``[B, T, 4]`` inputs are therefore supported.

        Args:
            preds (Tensor): predicted boxes, shape ``[..., 4]``.
            bbox (Tensor): target boxes, same shape as ``preds``.
            eps (float): added to the union and enclosing areas to avoid division by zero.
            reduction (str): ``'mean'``, ``'sum'`` or ``'none'``.

        Returns:
            Tensor: a scalar for ``'mean'``/``'sum'``; for ``'none'``, the per-box
            loss with shape ``preds.shape[:-1]``.

        Raises:
            NotImplementedError: if ``reduction`` is not one of the values above.
        """
        # Take coordinates from the last axis. Indexing with [:, k] would select the
        # k-th *time step* of a 3-D [B, T, 4] tensor rather than the k-th coordinate.
        px1, py1, px2, py2 = preds[..., 0], preds[..., 1], preds[..., 2], preds[..., 3]
        bx1, by1, bx2, by2 = bbox[..., 0], bbox[..., 1], bbox[..., 2], bbox[..., 3]

        # Intersection rectangle, with its width/height clamped at zero. Every
        # width/height here gets a +1e-3 offset (presumably so degenerate boxes
        # keep a non-zero extent).
        iw = paddle.clip(paddle.minimum(px2, bx2) - paddle.maximum(px1, bx1) + 1e-3,
                         min=0., max=1e10)
        ih = paddle.clip(paddle.minimum(py2, by2) - paddle.maximum(py1, by1) + 1e-3,
                         min=0., max=1e10)
        inters = iw * ih

        # Union area = area(pred) + area(target) - intersection.
        area_pred = (px2 - px1 + 1e-3) * (py2 - py1 + 1e-3)
        area_bbox = (bx2 - bx1 + 1e-3) * (by2 - by1 + 1e-3)
        uni = area_pred + area_bbox - inters + eps

        ious = inters / uni

        # Smallest box enclosing both the prediction and the target.
        ew = paddle.clip(paddle.maximum(px2, bx2) - paddle.minimum(px1, bx1) + 1e-3,
                         min=0., max=1e10)
        eh = paddle.clip(paddle.maximum(py2, by2) - paddle.minimum(py1, by1) + 1e-3,
                         min=0., max=1e10)

        # Enclosing area.
        enclose = ew * eh + eps
        giou = ious - (enclose - uni) / enclose

        loss = 1 - giou

        if reduction == 'mean':
            return paddle.mean(loss)
        if reduction == 'sum':
            return paddle.sum(loss)
        if reduction == 'none':
            return loss
        raise NotImplementedError(
            "giou_loss: unsupported reduction {!r} (expected 'mean', 'sum' or 'none')".format(
                reduction))

    def forward(self, predicts, batch):
        """Compute the weighted structure and location losses.

        Args:
            predicts (dict): must provide ``'structure_probs'`` (last axis = number of
                structure classes) and ``'loc_preds'``.
            batch (list|tuple): the positions used are ``batch[1]`` (structure targets),
                ``batch[2]`` (location targets) and ``batch[4]`` (location mask). When
                ``len(batch) == 6``, ``batch[5]`` is also used as a per-token structure mask.

        Returns:
            dict: ``'loss'`` (the total), ``'structure_loss'`` and ``'loc_loss'``, plus
            ``'loc_loss_giou'`` when ``use_giou`` is enabled.
        """
        # ---- structure loss -------------------------------------------------
        # NOTE(review): nn.CrossEntropyLoss applies softmax by default, so despite
        # the name 'structure_probs' is presumably logits at training time -- confirm.
        structure_probs = predicts['structure_probs']
        structure_probs = paddle.reshape(structure_probs, [-1, structure_probs.shape[-1]])
        # The first target step is dropped (presumably a start token -- TODO confirm).
        structure_targets = batch[1].astype("int64")[:, 1:]
        structure_targets = paddle.reshape(structure_targets, [-1])
        structure_loss = self.loss_func(structure_probs, structure_targets)

        if len(batch) == 6:
            # Optional per-token mask. Keep the original int64 truncation, then cast
            # explicitly to the loss dtype instead of relying on an implicit
            # cross-dtype multiply.
            structure_mask = batch[5].astype("int64")[:, 1:]
            structure_mask = paddle.reshape(structure_mask, [-1])
            structure_loss = structure_loss * structure_mask.astype(structure_loss.dtype)

        # NOTE(review): the mean is taken over all positions, masked ones included.
        structure_loss = paddle.mean(structure_loss) * self.structure_weight

        # ---- location loss --------------------------------------------------
        loc_preds = predicts['loc_preds']
        loc_targets = batch[2].astype("float32")[:, 1:, :]
        loc_targets_mask = batch[4].astype("float32")[:, 1:, :]
        # Predictions at masked-out positions are zeroed before comparison with the targets.
        masked_loc_preds = loc_preds * loc_targets_mask
        loc_loss = F.mse_loss(masked_loc_preds, loc_targets) * self.loc_weight

        if self.use_giou:
            loc_loss_giou = self.giou_loss(masked_loc_preds, loc_targets) * self.giou_weight
            total_loss = structure_loss + loc_loss + loc_loss_giou
            return {'loss': total_loss, "structure_loss": structure_loss,
                    "loc_loss": loc_loss, "loc_loss_giou": loc_loss_giou}

        total_loss = structure_loss + loc_loss
        return {'loss': total_loss, "structure_loss": structure_loss, "loc_loss": loc_loss}