# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import paddle
import paddle.nn as nn
import paddle.nn.functional as F
import numpy as np

from .rec_att_head import AttentionGRUCell


class TableAttentionHead(nn.Layer):
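    """Attention-based head for table structure recognition.

    An attention GRU cell decodes a sequence of table structure tokens, and
    cell location coordinates are regressed from the decoder states
    (optionally fused with visual features, depending on loc_type).

    Args:
        in_channels (list): channels of the incoming feature maps; only the
            last entry is used.
        hidden_size (int): hidden size of the attention GRU cell.
        loc_type (int): 1 to regress 4 coordinates per step directly from the
            hidden state; any other value fuses visual features and predicts
            point_num * 2 coordinates per step.
        in_max_len (int): input size preset used to pick the flattened visual
            feature length for the location branch (400 for 640, 625 for 800,
            256 otherwise).
        max_text_length (int): maximum number of decoding steps.
        out_channels (int): size of the structure token vocabulary.
        point_num (int): number of points per cell box when loc_type != 1.
    """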
    def __init__(self,
                 in_channels,
                 hidden_size,
                 loc_type,
                 in_max_len=488,
                 max_text_length=800,
                 out_channels=30,
                 point_num=2,
                 **kwargs):
        super(TableAttentionHead, self).__init__()
        self.input_size = in_channels[-1]
        self.hidden_size = hidden_size
        self.out_channels = out_channels
        self.max_text_length = max_text_length

        self.structure_attention_cell = AttentionGRUCell(
            self.input_size, hidden_size, self.out_channels, use_gru=False)
        self.structure_generator = nn.Linear(hidden_size, self.out_channels)
        self.loc_type = loc_type
        self.in_max_len = in_max_len

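        # loc_type == 1: regress 4 coordinates per step directly from the GRU
        # hidden state. Otherwise, project the flattened visual features to
        # max_text_length + 1 steps and fuse them with the hidden states to
        # predict point_num * 2 coordinates per step.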
        if self.loc_type == 1:
            self.loc_generator = nn.Linear(hidden_size, 4)
        else:
            if self.in_max_len == 640:
                self.loc_fea_trans = nn.Linear(400, self.max_text_length + 1)
            elif self.in_max_len == 800:
                self.loc_fea_trans = nn.Linear(625, self.max_text_length + 1)
            else:
                self.loc_fea_trans = nn.Linear(256, self.max_text_length + 1)
            self.loc_generator = nn.Linear(self.input_size + hidden_size,
                                           point_num * 2)

    def _char_to_onehot(self, input_char, onehot_dim):
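        # Convert integer token ids to one-hot vectors used as the attention
        # cell's character input.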
        input_one_hot = F.one_hot(input_char, onehot_dim)
        return input_one_hot

    def forward(self, inputs, targets=None):
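        """Decode structure tokens and regress cell locations.

        Args:
            inputs: list of feature maps; only the last one is used.
            targets: during training, targets[0] is the ground-truth structure
                token sequence used for teacher forcing.

        Returns:
            dict with 'structure_probs' (logits during training, softmax
            probabilities at inference) and 'loc_preds' (sigmoid-normalized
            location predictions).
        """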
        # Both the if and else branches are needed when assigning to a variable:
        # if the variable is modified in only one branch, the assignment does
        # not take effect.
        fea = inputs[-1]
        if len(fea.shape) == 3:
            pass
        else:
            last_shape = int(np.prod(fea.shape[2:]))  # flatten H * W into one axis
            fea = paddle.reshape(fea, [fea.shape[0], fea.shape[1], last_shape])
            fea = fea.transpose([0, 2, 1])  # (NTC)(batch, width, channels)
        batch_size = fea.shape[0]

        hidden = paddle.zeros((batch_size, self.hidden_size))
        output_hiddens = []
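        # Training: teacher forcing, feeding the ground-truth structure token
        # of each step into the attention cell.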
        if self.training and targets is not None:
            structure = targets[0]
            for i in range(self.max_text_length + 1):
                elem_onehots = self._char_to_onehot(
                    structure[:, i], onehot_dim=self.out_channels)
                (outputs, hidden), alpha = self.structure_attention_cell(
                    hidden, fea, elem_onehots)
                output_hiddens.append(paddle.unsqueeze(outputs, axis=1))
            output = paddle.concat(output_hiddens, axis=1)
            structure_probs = self.structure_generator(output)
            if self.loc_type == 1:
                loc_preds = self.loc_generator(output)
                loc_preds = F.sigmoid(loc_preds)
            else:
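                # Project the visual features to max_text_length + 1 steps and
                # concatenate them with the decoder outputs before regressing
                # point_num * 2 coordinates per step.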
                loc_fea = fea.transpose([0, 2, 1])
                loc_fea = self.loc_fea_trans(loc_fea)
                loc_fea = loc_fea.transpose([0, 2, 1])
                loc_concat = paddle.concat([output, loc_fea], axis=2)
                loc_preds = self.loc_generator(loc_concat)
                loc_preds = F.sigmoid(loc_preds)
        else:
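            # Inference: greedy decoding, feeding the argmax token of each step
            # back as the input of the next step.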
            temp_elem = paddle.zeros(shape=[batch_size], dtype="int32")
            structure_probs = None
            loc_preds = None
            elem_onehots = None
            outputs = None
            alpha = None
            max_text_length = paddle.to_tensor(self.max_text_length)
            i = 0
            while i < max_text_length + 1:
                elem_onehots = self._char_to_onehot(
                    temp_elem, onehot_dim=self.out_channels)
                (outputs, hidden), alpha = self.structure_attention_cell(
                    hidden, fea, elem_onehots)
                output_hiddens.append(paddle.unsqueeze(outputs, axis=1))
                structure_probs_step = self.structure_generator(outputs)
                temp_elem = structure_probs_step.argmax(axis=1, dtype="int32")
                i += 1

            output = paddle.concat(output_hiddens, axis=1)
            structure_probs = self.structure_generator(output)
            structure_probs = F.softmax(structure_probs)
            if self.loc_type == 1:
                loc_preds = self.loc_generator(output)
                loc_preds = F.sigmoid(loc_preds)
            else:
                loc_fea = fea.transpose([0, 2, 1])
                loc_fea = self.loc_fea_trans(loc_fea)
                loc_fea = loc_fea.transpose([0, 2, 1])
                loc_concat = paddle.concat([output, loc_fea], axis=2)
                loc_preds = self.loc_generator(loc_concat)
                loc_preds = F.sigmoid(loc_preds)
        return {'structure_probs': structure_probs, 'loc_preds': loc_preds}
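
# A minimal usage sketch (assumptions: the module lives at
# ppocr/modeling/heads/table_att_head.py as in PaddleOCR, and the backbone
# yields a single [N, 512, 16, 16] feature map, so the flattened width
# 16 * 16 = 256 matches the default in_max_len=488 branch; values are
# illustrative only):
#
#     import paddle
#     from ppocr.modeling.heads.table_att_head import TableAttentionHead
#
#     head = TableAttentionHead(in_channels=[512], hidden_size=256, loc_type=2)
#     feat = paddle.rand([1, 512, 16, 16])
#     preds = head([feat])  # no targets, so the greedy decoding branch runs
#     # preds['structure_probs']: [1, 801, 30], preds['loc_preds']: [1, 801, 4]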