# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This code is adapted from:
https://github.com/open-mmlab/mmocr/blob/main/mmocr/models/textrecog/encoders/sar_encoder.py
https://github.com/open-mmlab/mmocr/blob/main/mmocr/models/textrecog/decoders/sar_decoder.py
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import math
import paddle
from paddle import ParamAttr
import paddle.nn as nn
import paddle.nn.functional as F


class SAREncoder(nn.Layer):
    """
    Args:
        enc_bi_rnn (bool): If True, use bidirectional RNN in encoder.
        enc_drop_rnn (float): Dropout probability of RNN layer in encoder.
        enc_gru (bool): If True, use GRU, else LSTM in encoder.
        d_model (int): Dim of channels from backbone.
        d_enc (int): Dim of encoder RNN layer.
        mask (bool): If True, mask padding in RNN sequence.
    """

    def __init__(self,
                 enc_bi_rnn=False,
                 enc_drop_rnn=0.1,
                 enc_gru=False,
                 d_model=512,
                 d_enc=512,
                 mask=True,
                 **kwargs):
        super().__init__()
        assert isinstance(enc_bi_rnn, bool)
        assert isinstance(enc_drop_rnn, (int, float))
        assert 0 <= enc_drop_rnn < 1.0
        assert isinstance(enc_gru, bool)
        assert isinstance(d_model, int)
        assert isinstance(d_enc, int)
        assert isinstance(mask, bool)

        self.enc_bi_rnn = enc_bi_rnn
        self.enc_drop_rnn = enc_drop_rnn
        self.mask = mask

        # LSTM Encoder
        if enc_bi_rnn:
            direction = 'bidirectional'
        else:
            direction = 'forward'
        kwargs = dict(
            input_size=d_model,
            hidden_size=d_enc,
            num_layers=2,
            time_major=False,
            dropout=enc_drop_rnn,
            direction=direction)
        if enc_gru:
            self.rnn_encoder = nn.GRU(**kwargs)
        else:
            self.rnn_encoder = nn.LSTM(**kwargs)

        # global feature transformation
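        # (a bidirectional encoder doubles the RNN output channel size)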
        encoder_rnn_out_size = d_enc * (int(enc_bi_rnn) + 1)
        self.linear = nn.Linear(encoder_rnn_out_size, encoder_rnn_out_size)

    def forward(self, feat, img_metas=None):
        if img_metas is not None:
            assert len(img_metas[0]) == paddle.shape(feat)[0]

        valid_ratios = None
        if img_metas is not None and self.mask:
            valid_ratios = img_metas[-1]

        h_feat = feat.shape[2]  # bsz c h w
        feat_v = F.max_pool2d(
            feat, kernel_size=(h_feat, 1), stride=1, padding=0)
        feat_v = feat_v.squeeze(2)  # bsz * C * W
        feat_v = paddle.transpose(feat_v, perm=[0, 2, 1])  # bsz * W * C
        holistic_feat = self.rnn_encoder(feat_v)[0]  # bsz * T * C

        if valid_ratios is not None:
            valid_hf = []
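            # keep the hidden state at the last valid time step of each sample;
            # the valid length is derived from that sample's valid width ratio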
            T = paddle.shape(holistic_feat)[1]
            for i in range(paddle.shape(valid_ratios)[0]):
                valid_step = paddle.minimum(
                    T, paddle.ceil(valid_ratios[i] * T).astype('int32')) - 1
                valid_hf.append(holistic_feat[i, valid_step, :])
            valid_hf = paddle.stack(valid_hf, axis=0)
        else:
            valid_hf = holistic_feat[:, -1, :]  # bsz * C
        holistic_feat = self.linear(valid_hf)  # bsz * C

        return holistic_feat
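
# A minimal usage sketch for SAREncoder (illustrative only, not part of the
# original module; the feature shape below is an assumed 512-channel backbone
# output):
#
#     encoder = SAREncoder(d_model=512, d_enc=512)
#     feat = paddle.rand([2, 512, 6, 40])   # bsz * C * H * W
#     holistic_feat = encoder(feat)         # bsz * C, here [2, 512]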


class BaseDecoder(nn.Layer):
    def __init__(self, **kwargs):
        super().__init__()

    def forward_train(self, feat, out_enc, targets, img_metas):
        raise NotImplementedError

    def forward_test(self, feat, out_enc, img_metas):
        raise NotImplementedError

    def forward(self,
                feat,
                out_enc,
                label=None,
                img_metas=None,
                train_mode=True):
        self.train_mode = train_mode

        if train_mode:
            return self.forward_train(feat, out_enc, label, img_metas)
        return self.forward_test(feat, out_enc, img_metas)


class ParallelSARDecoder(BaseDecoder):
    """
    Args:
        out_channels (int): Output class number.
        enc_bi_rnn (bool): If True, use bidirectional RNN in encoder.
        dec_bi_rnn (bool): If True, use bidirectional RNN in decoder.
        dec_drop_rnn (float): Dropout of RNN layer in decoder.
        dec_gru (bool): If True, use GRU, else LSTM in decoder.
        d_model (int): Dim of channels from backbone.
        d_enc (int): Dim of encoder RNN layer.
        d_k (int): Dim of channels of attention module.
        pred_dropout (float): Dropout probability of prediction layer.
        max_text_length (int): Maximum sequence length for decoding.
        mask (bool): If True, mask padding in feature map.
        start_idx (int): Index of start token.
        padding_idx (int): Index of padding token.
        pred_concat (bool): If True, concat glimpse feature from
            attention with holistic feature and hidden state.
    """

    def __init__(
            self,
            out_channels,  # 90 + unknown + start + padding
            enc_bi_rnn=False,
            dec_bi_rnn=False,
            dec_drop_rnn=0.0,
            dec_gru=False,
            d_model=512,
            d_enc=512,
            d_k=64,
            pred_dropout=0.1,
            max_text_length=30,
            mask=True,
            pred_concat=True,
            **kwargs):
        super().__init__()

        self.num_classes = out_channels
        self.enc_bi_rnn = enc_bi_rnn
        self.d_k = d_k
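        # the last two class indices are reserved for the start and padding
        # tokens (out_channels = characters + unknown + start + padding)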
        self.start_idx = out_channels - 2
        self.padding_idx = out_channels - 1
        self.max_seq_len = max_text_length
        self.mask = mask
        self.pred_concat = pred_concat

        encoder_rnn_out_size = d_enc * (int(enc_bi_rnn) + 1)
        decoder_rnn_out_size = encoder_rnn_out_size * (int(dec_bi_rnn) + 1)

        # 2D attention layer
        self.conv1x1_1 = nn.Linear(decoder_rnn_out_size, d_k)
        self.conv3x3_1 = nn.Conv2D(
            d_model, d_k, kernel_size=3, stride=1, padding=1)
        self.conv1x1_2 = nn.Linear(d_k, 1)

        # Decoder RNN layer
        if dec_bi_rnn:
            direction = 'bidirectional'
        else:
            direction = 'forward'

        kwargs = dict(
            input_size=encoder_rnn_out_size,
            hidden_size=encoder_rnn_out_size,
            num_layers=2,
            time_major=False,
            dropout=dec_drop_rnn,
            direction=direction)
        if dec_gru:
            self.rnn_decoder = nn.GRU(**kwargs)
        else:
            self.rnn_decoder = nn.LSTM(**kwargs)

        # Decoder input embedding
        self.embedding = nn.Embedding(
            self.num_classes,
            encoder_rnn_out_size,
            padding_idx=self.padding_idx)

        # Prediction layer
        self.pred_dropout = nn.Dropout(pred_dropout)
        pred_num_classes = self.num_classes - 1
        if pred_concat:
            fc_in_channel = decoder_rnn_out_size + d_model + encoder_rnn_out_size
        else:
            fc_in_channel = d_model
        self.prediction = nn.Linear(fc_in_channel, pred_num_classes)

    def _2d_attention(self,
                      decoder_input,
                      feat,
                      holistic_feat,
                      valid_ratios=None):

        y = self.rnn_decoder(decoder_input)[0]
        # y: bsz * (seq_len + 1) * hidden_size

        attn_query = self.conv1x1_1(y)  # bsz * (seq_len + 1) * attn_size
        bsz, seq_len, attn_size = attn_query.shape
        attn_query = paddle.unsqueeze(attn_query, axis=[3, 4])
        # (bsz, seq_len + 1, attn_size, 1, 1)

        attn_key = self.conv3x3_1(feat)
        # bsz * attn_size * h * w
        attn_key = attn_key.unsqueeze(1)
        # bsz * 1 * attn_size * h * w

        attn_weight = paddle.tanh(paddle.add(attn_key, attn_query))
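        # attn_key (bsz, 1, attn_size, h, w) and attn_query
        # (bsz, seq_len + 1, attn_size, 1, 1) are combined by broadcasting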

        # bsz * (seq_len + 1) * attn_size * h * w
        attn_weight = paddle.transpose(attn_weight, perm=[0, 1, 3, 4, 2])
        # bsz * (seq_len + 1) * h * w * attn_size
        attn_weight = self.conv1x1_2(attn_weight)
        # bsz * (seq_len + 1) * h * w * 1
        bsz, T, h, w, c = paddle.shape(attn_weight)
        assert c == 1

        if valid_ratios is not None:
            # mask the attention weights that fall in the padded (invalid) width region
            for i in range(paddle.shape(valid_ratios)[0]):
                valid_width = paddle.minimum(
                    w, paddle.ceil(valid_ratios[i] * w).astype("int32"))
                if valid_width < w:
                    attn_weight[i, :, :, valid_width:, :] = float('-inf')

        attn_weight = paddle.reshape(attn_weight, [bsz, T, -1])
        attn_weight = F.softmax(attn_weight, axis=-1)

        attn_weight = paddle.reshape(attn_weight, [bsz, T, h, w, c])
        attn_weight = paddle.transpose(attn_weight, perm=[0, 1, 4, 2, 3])
        # attn_weight: bsz * T * c * h * w
        # feat: bsz * c * h * w
        attn_feat = paddle.sum(paddle.multiply(feat.unsqueeze(1), attn_weight),
                               (3, 4),
                               keepdim=False)
        # bsz * (seq_len + 1) * C

        # Linear transformation
        if self.pred_concat:
            hf_c = holistic_feat.shape[-1]
            holistic_feat = paddle.expand(
                holistic_feat, shape=[bsz, seq_len, hf_c])
            y = self.prediction(paddle.concat((y, attn_feat, holistic_feat), 2))
        else:
            y = self.prediction(attn_feat)
        # bsz * (seq_len + 1) * num_classes
        if self.train_mode:
            y = self.pred_dropout(y)

        return y

    def forward_train(self, feat, out_enc, label, img_metas):
        '''
        img_metas: [label, valid_ratio]
        '''
        if img_metas is not None:
            assert paddle.shape(img_metas[0])[0] == paddle.shape(feat)[0]

        valid_ratios = None
        if img_metas is not None and self.mask:
            valid_ratios = img_metas[-1]

        lab_embedding = self.embedding(label)
        # bsz * seq_len * emb_dim
        out_enc = out_enc.unsqueeze(1)
        # bsz * 1 * emb_dim
        in_dec = paddle.concat((out_enc, lab_embedding), axis=1)
        # bsz * (seq_len + 1) * C
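        # the holistic feature acts as the first decoder input (a start token),
        # so the prediction at step 0 is dropped below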
        out_dec = self._2d_attention(
            in_dec, feat, out_enc, valid_ratios=valid_ratios)

        return out_dec[:, 1:, :]  # bsz * seq_len * num_classes

    def forward_test(self, feat, out_enc, img_metas):
        if img_metas is not None:
            assert len(img_metas[0]) == feat.shape[0]

        valid_ratios = None
        if img_metas is not None and self.mask:
            valid_ratios = img_metas[-1]

        seq_len = self.max_seq_len
        bsz = feat.shape[0]
        start_token = paddle.full(
            (bsz, ), fill_value=self.start_idx, dtype='int64')
        # bsz
        start_token = self.embedding(start_token)
        # bsz * emb_dim
        emb_dim = start_token.shape[1]
        start_token = start_token.unsqueeze(1)
        start_token = paddle.expand(start_token, shape=[bsz, seq_len, emb_dim])
        # bsz * seq_len * emb_dim
        out_enc = out_enc.unsqueeze(1)
        # bsz * 1 * emb_dim
        decoder_input = paddle.concat((out_enc, start_token), axis=1)
        # bsz * (seq_len + 1) * emb_dim

        outputs = []
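        # greedy autoregressive decoding: at each step, the argmax of the
        # current prediction is embedded and fed back as the next decoder input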
        for i in range(1, seq_len + 1):
            decoder_output = self._2d_attention(
                decoder_input, feat, out_enc, valid_ratios=valid_ratios)
            char_output = decoder_output[:, i, :]  # bsz * num_classes
            char_output = F.softmax(char_output, -1)
            outputs.append(char_output)
            max_idx = paddle.argmax(char_output, axis=1, keepdim=False)
            char_embedding = self.embedding(max_idx)  # bsz * emb_dim
            if i < seq_len:
                decoder_input[:, i + 1, :] = char_embedding

        outputs = paddle.stack(outputs, 1)  # bsz * seq_len * num_classes

        return outputs


class SARHead(nn.Layer):
    def __init__(self,
                 in_channels,
                 out_channels,
                 enc_dim=512,
                 max_text_length=30,
                 enc_bi_rnn=False,
                 enc_drop_rnn=0.1,
                 enc_gru=False,
                 dec_bi_rnn=False,
                 dec_drop_rnn=0.0,
                 dec_gru=False,
                 d_k=512,
                 pred_dropout=0.1,
                 pred_concat=True,
                 **kwargs):
        super(SARHead, self).__init__()

        # encoder module
        self.encoder = SAREncoder(
            enc_bi_rnn=enc_bi_rnn,
            enc_drop_rnn=enc_drop_rnn,
            enc_gru=enc_gru,
            d_model=in_channels,
            d_enc=enc_dim)

        # decoder module
        self.decoder = ParallelSARDecoder(
            out_channels=out_channels,
            enc_bi_rnn=enc_bi_rnn,
            dec_bi_rnn=dec_bi_rnn,
            dec_drop_rnn=dec_drop_rnn,
            dec_gru=dec_gru,
            d_model=in_channels,
            d_enc=enc_dim,
            d_k=d_k,
            pred_dropout=pred_dropout,
            max_text_length=max_text_length,
            pred_concat=pred_concat)

    def forward(self, feat, targets=None):
        '''
        targets (img_metas): [label, valid_ratio]
        '''
        holistic_feat = self.encoder(feat, targets)  # bsz c

        if self.training:
            label = targets[0]  # label
            final_out = self.decoder(
                feat, holistic_feat, label, img_metas=targets)
        else:
            final_out = self.decoder(
                feat,
                holistic_feat,
                label=None,
                img_metas=targets,
                train_mode=False)
            # (bsz, seq_len, num_classes)

        return final_out
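

# A minimal, self-contained smoke test (not part of the original PaddleOCR
# module). The class count, feature-map shape and the [label, valid_ratio]
# target layout below are assumptions for illustration only.
if __name__ == '__main__':
    bsz, c, h, w = 2, 512, 6, 40
    num_classes = 93  # e.g. 90 characters + unknown + start + padding
    max_len = 30

    head = SARHead(
        in_channels=c, out_channels=num_classes, max_text_length=max_len)

    feat = paddle.rand([bsz, c, h, w])  # backbone feature map
    label = paddle.randint(0, num_classes, [bsz, max_len + 1], dtype='int64')
    valid_ratio = paddle.ones([bsz], dtype='float32')  # fully valid widths

    head.eval()
    with paddle.no_grad():
        out = head(feat, targets=[label, valid_ratio])
    # greedy decoding output: [bsz, max_text_length, num_classes - 1]
    print(out.shape)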