rec_sar_head.py 13.7 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This code is adapted from:
https://github.com/open-mmlab/mmocr/blob/main/mmocr/models/textrecog/encoders/sar_encoder.py
https://github.com/open-mmlab/mmocr/blob/main/mmocr/models/textrecog/decoders/sar_decoder.py
"""

A
andyjpaddle 已提交
20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import math
import paddle
from paddle import ParamAttr
import paddle.nn as nn
import paddle.nn.functional as F


class SAREncoder(nn.Layer):
    """Holistic-feature encoder of SAR.

    The feature map is max-pooled over its full height, the resulting column
    sequence is fed through a 2-layer RNN, and the RNN output at each
    sample's last valid time step is projected by a linear layer.

    Args:
        enc_bi_rnn (bool): If True, use bidirectional RNN in encoder.
        enc_drop_rnn (float): Dropout probability of RNN layer in encoder.
        enc_gru (bool): If True, use GRU, else LSTM in encoder.
        d_model (int): Dim of channels from backbone.
        d_enc (int): Dim of encoder RNN layer.
        mask (bool): If True, mask padding in RNN sequence.
    """

    def __init__(self,
                 enc_bi_rnn=False,
                 enc_drop_rnn=0.1,
                 enc_gru=False,
                 d_model=512,
                 d_enc=512,
                 mask=True,
                 **kwargs):
        super().__init__()
        assert isinstance(enc_bi_rnn, bool)
        assert isinstance(enc_drop_rnn, (int, float))
        assert 0 <= enc_drop_rnn < 1.0
        assert isinstance(enc_gru, bool)
        assert isinstance(d_model, int)
        assert isinstance(d_enc, int)
        assert isinstance(mask, bool)

        self.enc_bi_rnn = enc_bi_rnn
        self.enc_drop_rnn = enc_drop_rnn
        self.mask = mask

        # Encoder RNN: 2 layers, batch-first input (time_major=False).
        if enc_bi_rnn:
            direction = 'bidirectional'
        else:
            direction = 'forward'
        rnn_kwargs = dict(
            input_size=d_model,
            hidden_size=d_enc,
            num_layers=2,
            time_major=False,
            dropout=enc_drop_rnn,
            direction=direction)
        if enc_gru:
            self.rnn_encoder = nn.GRU(**rnn_kwargs)
        else:
            self.rnn_encoder = nn.LSTM(**rnn_kwargs)

        # global feature transformation; a bidirectional RNN doubles the
        # output width.
        encoder_rnn_out_size = d_enc * (int(enc_bi_rnn) + 1)
        self.linear = nn.Linear(encoder_rnn_out_size, encoder_rnn_out_size)

    def forward(self, feat, img_metas=None):
        """Encode a backbone feature map into one holistic vector per sample.

        Args:
            feat (Tensor): Feature map, indexed below as (bsz, C, H, W).
            img_metas (list | None): Only ``img_metas[0]`` (its length is
                checked against the batch size) and ``img_metas[-1]``
                (per-sample valid width ratios, used when ``self.mask`` is
                True) are read.

        Returns:
            Tensor: (bsz, encoder_rnn_out_size) holistic feature.
        """
        if img_metas is not None:
            assert len(img_metas[0]) == feat.shape[0]

        valid_ratios = None
        if img_metas is not None and self.mask:
            valid_ratios = img_metas[-1]

        h_feat = feat.shape[2]  # bsz c h w
        # Collapse the height axis so every column becomes one time step.
        feat_v = F.max_pool2d(
            feat, kernel_size=(h_feat, 1), stride=1, padding=0)
        feat_v = feat_v.squeeze(2)  # bsz * C * W
        feat_v = paddle.transpose(feat_v, perm=[0, 2, 1])  # bsz * W * C
        holistic_feat = self.rnn_encoder(feat_v)[0]  # bsz * T * C

        if valid_ratios is not None:
            # Pick the output at each sample's last non-padded step. The valid
            # length is clamped to >= 1 so a zero ratio cannot produce index
            # -1, which would silently select the padded final step.
            valid_hf = []
            T = holistic_feat.shape[1]
            for i, valid_ratio in enumerate(valid_ratios):
                valid_len = max(1, min(T, math.ceil(T * valid_ratio)))
                valid_hf.append(holistic_feat[i, valid_len - 1, :])
            valid_hf = paddle.stack(valid_hf, axis=0)
        else:
            valid_hf = holistic_feat[:, -1, :]  # bsz * C
        holistic_feat = self.linear(valid_hf)  # bsz * C

        return holistic_feat

class BaseDecoder(nn.Layer):
    """Abstract decoder that dispatches to a training or an inference path.

    Subclasses provide ``forward_train`` and ``forward_test``. ``forward``
    stores the requested mode on ``self.train_mode`` and calls the matching
    implementation.
    """

    def __init__(self, **kwargs):
        super().__init__()

    def forward_train(self, feat, out_enc, targets, img_metas):
        """Training-time decoding; must be overridden by subclasses."""
        raise NotImplementedError

    def forward_test(self, feat, out_enc, img_metas):
        """Inference-time decoding; must be overridden by subclasses."""
        raise NotImplementedError

    def forward(self,
                feat,
                out_enc,
                label=None,
                img_metas=None,
                train_mode=True):
        """Route to ``forward_train`` when ``train_mode`` is truthy, otherwise
        to ``forward_test``. ``label`` is only used on the training path.
        """
        self.train_mode = train_mode
        if not train_mode:
            return self.forward_test(feat, out_enc, img_metas)
        return self.forward_train(feat, out_enc, label, img_metas)


class ParallelSARDecoder(BaseDecoder):
    """SAR decoder with 2D attention over the backbone feature map.

    Args:
        out_channels (int): Number of classes including the start and padding
            tokens. Index ``out_channels - 2`` is used as the start token and
            ``out_channels - 1`` as the padding token.
        enc_bi_rnn (bool): If True, use bidirectional RNN in encoder.
        dec_bi_rnn (bool): If True, use bidirectional RNN in decoder.
        dec_drop_rnn (float): Dropout of RNN layer in decoder.
        dec_gru (bool): If True, use GRU, else LSTM in decoder.
        d_model (int): Dim of channels from backbone.
        d_enc (int): Dim of encoder RNN layer.
        d_k (int): Dim of channels of attention module.
        pred_dropout (float): Dropout probability of prediction layer.
        max_text_length (int): Maximum sequence length for decoding (stored
            as ``self.max_seq_len``).
        mask (bool): If True, mask padding in feature map.
        pred_concat (bool): If True, concat glimpse feature from
            attention with holistic feature and hidden state.
    """

    def __init__(
            self,
            out_channels,  # 90 + unknown + start + padding
            enc_bi_rnn=False,
            dec_bi_rnn=False,
            dec_drop_rnn=0.0,
            dec_gru=False,
            d_model=512,
            d_enc=512,
            d_k=64,
            pred_dropout=0.1,
            max_text_length=30,
            mask=True,
            pred_concat=True,
            **kwargs):
        super().__init__()

        self.num_classes = out_channels
        self.enc_bi_rnn = enc_bi_rnn
        self.d_k = d_k
        # The last two class indices are reserved for start and padding.
        self.start_idx = out_channels - 2
        self.padding_idx = out_channels - 1
        self.max_seq_len = max_text_length
        self.mask = mask
        self.pred_concat = pred_concat

        encoder_rnn_out_size = d_enc * (int(enc_bi_rnn) + 1)
        decoder_rnn_out_size = encoder_rnn_out_size * (int(dec_bi_rnn) + 1)

        # 2D attention layer
        self.conv1x1_1 = nn.Linear(decoder_rnn_out_size, d_k)
        self.conv3x3_1 = nn.Conv2D(
            d_model, d_k, kernel_size=3, stride=1, padding=1)
        self.conv1x1_2 = nn.Linear(d_k, 1)

        # Decoder RNN layer: 2 layers, batch-first input.
        if dec_bi_rnn:
            direction = 'bidirectional'
        else:
            direction = 'forward'

        rnn_kwargs = dict(
            input_size=encoder_rnn_out_size,
            hidden_size=encoder_rnn_out_size,
            num_layers=2,
            time_major=False,
            dropout=dec_drop_rnn,
            direction=direction)
        if dec_gru:
            self.rnn_decoder = nn.GRU(**rnn_kwargs)
        else:
            self.rnn_decoder = nn.LSTM(**rnn_kwargs)

        # Decoder input embedding
        self.embedding = nn.Embedding(
            self.num_classes,
            encoder_rnn_out_size,
            padding_idx=self.padding_idx)

        # Prediction layer. It outputs num_classes - 1 scores, so the padding
        # token (the last index) is never predicted.
        self.pred_dropout = nn.Dropout(pred_dropout)
        pred_num_classes = self.num_classes - 1
        if pred_concat:
            fc_in_channel = decoder_rnn_out_size + d_model + encoder_rnn_out_size
        else:
            fc_in_channel = d_model
        self.prediction = nn.Linear(fc_in_channel, pred_num_classes)

    def _2d_attention(self,
                      decoder_input,
                      feat,
                      holistic_feat,
                      valid_ratios=None):
        """Run the decoder RNN and 2D attention, then predict class scores.

        Args:
            decoder_input (Tensor): (bsz, seq_len + 1, emb_dim) RNN input.
            feat (Tensor): (bsz, C, h, w) backbone feature map.
            holistic_feat (Tensor): (bsz, 1, C_enc) encoder output; it is
                expanded along the sequence when ``pred_concat`` is True.
            valid_ratios (sequence | None): Per-sample valid width ratios;
                attention to columns beyond the valid width is masked.

        Returns:
            Tensor: (bsz, seq_len + 1, num_classes - 1) scores (pre-softmax).
        """
        y = self.rnn_decoder(decoder_input)[0]
        # y: bsz * (seq_len + 1) * hidden_size

        attn_query = self.conv1x1_1(y)  # bsz * (seq_len + 1) * attn_size
        bsz, seq_len, _ = attn_query.shape
        attn_query = paddle.unsqueeze(attn_query, axis=[3, 4])
        # (bsz, seq_len + 1, attn_size, 1, 1)

        attn_key = self.conv3x3_1(feat)
        # bsz * attn_size * h * w
        attn_key = attn_key.unsqueeze(1)
        # bsz * 1 * attn_size * h * w

        attn_weight = paddle.tanh(paddle.add(attn_key, attn_query))
        # bsz * (seq_len + 1) * attn_size * h * w
        attn_weight = paddle.transpose(attn_weight, perm=[0, 1, 3, 4, 2])
        # bsz * (seq_len + 1) * h * w * attn_size
        attn_weight = self.conv1x1_2(attn_weight)
        # bsz * (seq_len + 1) * h * w * 1
        bsz, T, h, w, c = attn_weight.shape
        assert c == 1

        if valid_ratios is not None:
            # Mask attention to padded columns. At least one column is kept so
            # a zero ratio cannot set a whole row to -inf, which would make
            # the softmax below produce NaN.
            for i, valid_ratio in enumerate(valid_ratios):
                valid_width = max(1, min(w, math.ceil(w * valid_ratio)))
                if valid_width < w:
                    attn_weight[i, :, :, valid_width:, :] = float('-inf')

        # Softmax jointly over all h * w spatial positions.
        attn_weight = paddle.reshape(attn_weight, [bsz, T, -1])
        attn_weight = F.softmax(attn_weight, axis=-1)

        attn_weight = paddle.reshape(attn_weight, [bsz, T, h, w, c])
        attn_weight = paddle.transpose(attn_weight, perm=[0, 1, 4, 2, 3])
        # attn_weight: bsz * T * c * h * w  (c == 1)
        # feat: bsz * C * h * w
        attn_feat = paddle.sum(paddle.multiply(feat.unsqueeze(1), attn_weight),
                               (3, 4),
                               keepdim=False)
        # bsz * (seq_len + 1) * C

        # Linear transformation
        if self.pred_concat:
            hf_c = holistic_feat.shape[-1]
            holistic_feat = paddle.expand(
                holistic_feat, shape=[bsz, seq_len, hf_c])
            y = self.prediction(paddle.concat((y, attn_feat, holistic_feat), 2))
        else:
            y = self.prediction(attn_feat)
        # bsz * (seq_len + 1) * num_classes
        if self.train_mode:
            y = self.pred_dropout(y)

        return y

    def forward_train(self, feat, out_enc, label, img_metas):
        '''
        Teacher-forced decoding over the whole label sequence at once.

        img_metas: [label, valid_ratio]; ``img_metas[-1]`` is used as the
        valid ratios when ``self.mask`` is True.

        Returns scores of shape (bsz, seq_len, num_classes - 1); the output
        at the holistic-feature position is dropped.
        '''
        if img_metas is not None:
            assert len(img_metas[0]) == feat.shape[0]

        valid_ratios = None
        if img_metas is not None and self.mask:
            valid_ratios = img_metas[-1]

        lab_embedding = self.embedding(label)
        # bsz * seq_len * emb_dim
        out_enc = out_enc.unsqueeze(1)
        # bsz * 1 * emb_dim
        # The holistic feature is prepended as the first decoder input.
        in_dec = paddle.concat((out_enc, lab_embedding), axis=1)
        # bsz * (seq_len + 1) * C
        out_dec = self._2d_attention(
            in_dec, feat, out_enc, valid_ratios=valid_ratios)
        # bsz * (seq_len + 1) * num_classes

        return out_dec[:, 1:, :]  # bsz * seq_len * num_classes

    def forward_test(self, feat, out_enc, img_metas):
        """Greedy decoding for exactly ``self.max_seq_len`` steps.

        The decoder input starts as [holistic feature, start token * seq_len].
        At step ``i`` the whole sequence is re-run, the softmax at position
        ``i`` is recorded, and the embedding of its argmax is written into
        input slot ``i + 1``. There is no early stop.

        Returns:
            Tensor: (bsz, max_seq_len, num_classes - 1) softmax probabilities.
        """
        if img_metas is not None:
            assert len(img_metas[0]) == feat.shape[0]

        valid_ratios = None
        if img_metas is not None and self.mask:
            valid_ratios = img_metas[-1]

        seq_len = self.max_seq_len
        bsz = feat.shape[0]
        start_token = paddle.full(
            (bsz, ), fill_value=self.start_idx, dtype='int64')
        # bsz
        start_token = self.embedding(start_token)
        # bsz * emb_dim
        emb_dim = start_token.shape[1]
        start_token = start_token.unsqueeze(1)
        start_token = paddle.expand(start_token, shape=[bsz, seq_len, emb_dim])
        # bsz * seq_len * emb_dim
        out_enc = out_enc.unsqueeze(1)
        # bsz * 1 * emb_dim
        decoder_input = paddle.concat((out_enc, start_token), axis=1)
        # bsz * (seq_len + 1) * emb_dim

        outputs = []
        for i in range(1, seq_len + 1):
            decoder_output = self._2d_attention(
                decoder_input, feat, out_enc, valid_ratios=valid_ratios)
            char_output = decoder_output[:, i, :]  # bsz * num_classes
            char_output = F.softmax(char_output, -1)
            outputs.append(char_output)
            max_idx = paddle.argmax(char_output, axis=1, keepdim=False)
            char_embedding = self.embedding(max_idx)  # bsz * emb_dim
            if i < seq_len:
                decoder_input[:, i + 1, :] = char_embedding

        outputs = paddle.stack(outputs, 1)  # bsz * seq_len * num_classes

        return outputs


class SARHead(nn.Layer):
    """SAR recognition head: ``SAREncoder`` followed by ``ParallelSARDecoder``.

    Args:
        out_channels (int): Number of classes given to the decoder, including
            its start and padding tokens.
        enc_bi_rnn (bool): Bidirectional encoder RNN (also sizes the decoder).
        enc_drop_rnn (float): Dropout of the encoder RNN.
        enc_gru (bool): Use GRU instead of LSTM in the encoder.
        dec_bi_rnn (bool): Bidirectional decoder RNN.
        dec_drop_rnn (float): Dropout of the decoder RNN.
        dec_gru (bool): Use GRU instead of LSTM in the decoder.
        d_k (int): Dim of the attention module.
        pred_dropout (float): Dropout probability of the prediction layer.
        max_text_length (int): Number of decoding steps at inference.
        pred_concat (bool): Concatenate hidden state, glimpse and holistic
            feature before prediction.
        d_model (int): Channels of the input feature map, shared by encoder
            and decoder. Defaults to 512, the sub-modules' default.
        d_enc (int): Hidden size of the encoder RNN, shared by encoder and
            decoder. Defaults to 512, the sub-modules' default.
        mask (bool): If True, the valid ratios in ``targets`` mask padding in
            both encoder and decoder.
    """

    def __init__(self,
                 out_channels,
                 enc_bi_rnn=False,
                 enc_drop_rnn=0.1,
                 enc_gru=False,
                 dec_bi_rnn=False,
                 dec_drop_rnn=0.0,
                 dec_gru=False,
                 d_k=512,
                 pred_dropout=0.1,
                 max_text_length=30,
                 pred_concat=True,
                 d_model=512,
                 d_enc=512,
                 mask=True,
                 **kwargs):
        super(SARHead, self).__init__()

        # encoder module
        self.encoder = SAREncoder(
            enc_bi_rnn=enc_bi_rnn,
            enc_drop_rnn=enc_drop_rnn,
            enc_gru=enc_gru,
            d_model=d_model,
            d_enc=d_enc,
            mask=mask)

        # decoder module; d_model / d_enc must match the encoder's so the
        # attention and prediction layers line up with its outputs.
        self.decoder = ParallelSARDecoder(
            out_channels=out_channels,
            enc_bi_rnn=enc_bi_rnn,
            dec_bi_rnn=dec_bi_rnn,
            dec_drop_rnn=dec_drop_rnn,
            dec_gru=dec_gru,
            d_model=d_model,
            d_enc=d_enc,
            d_k=d_k,
            pred_dropout=pred_dropout,
            max_text_length=max_text_length,
            mask=mask,
            pred_concat=pred_concat)

    def forward(self, feat, targets=None):
        '''
        Encode ``feat`` and decode it into per-step class scores.

        targets: [label, ..., valid_ratio]. ``targets[0]`` is the label and is
        required in training mode; ``targets[-1]`` supplies per-sample valid
        ratios when masking is enabled. May be None at inference.

        Returns (bsz, seq_len, out_channels - 1): pre-softmax scores in
        training mode (seq_len = label length), softmax probabilities in eval
        mode (seq_len = max_text_length).
        '''
        holistic_feat = self.encoder(feat, targets)  # bsz c

        if self.training:
            label = paddle.to_tensor(targets[0], dtype='int64')
            final_out = self.decoder(
                feat, holistic_feat, label, img_metas=targets)
        else:
            final_out = self.decoder(
                feat,
                holistic_feat,
                label=None,
                img_metas=targets,
                train_mode=False)
            # (bsz, seq_len, num_classes)

        return final_out