# rec_sar_head.py
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import math
import paddle
from paddle import ParamAttr
import paddle.nn as nn
import paddle.nn.functional as F


class SAREncoder(nn.Layer):
    """
    Args:
        enc_bi_rnn (bool): If True, use bidirectional RNN in encoder.
        enc_drop_rnn (float): Dropout probability of RNN layer in encoder.
        enc_gru (bool): If True, use GRU, else LSTM in encoder.
        d_model (int): Dim of channels from backbone.
        d_enc (int): Dim of encoder RNN layer.
        mask (bool): If True, mask padding in RNN sequence.
    """

    def __init__(self,
                 enc_bi_rnn=False,
                 enc_drop_rnn=0.1,
                 enc_gru=False,
                 d_model=512,
                 d_enc=512,
                 mask=True,
                 **kwargs):
        super().__init__()
        assert isinstance(enc_bi_rnn, bool)
        assert isinstance(enc_drop_rnn, (int, float))
        assert 0 <= enc_drop_rnn < 1.0
        assert isinstance(enc_gru, bool)
        assert isinstance(d_model, int)
        assert isinstance(d_enc, int)
        assert isinstance(mask, bool)

        self.enc_bi_rnn = enc_bi_rnn
        self.enc_drop_rnn = enc_drop_rnn
        self.mask = mask

        # LSTM Encoder
        if enc_bi_rnn:
            direction = 'bidirectional'
        else:
            direction = 'forward'
        kwargs = dict(
            input_size=d_model,
            hidden_size=d_enc,
            num_layers=2,
            time_major=False,
            dropout=enc_drop_rnn,
            direction=direction)
        if enc_gru:
            self.rnn_encoder = nn.GRU(**kwargs)
        else:
            self.rnn_encoder = nn.LSTM(**kwargs)

        # global feature transformation
        encoder_rnn_out_size = d_enc * (int(enc_bi_rnn) + 1)
        self.linear = nn.Linear(encoder_rnn_out_size, encoder_rnn_out_size)

    def forward(self, feat, img_metas=None):
        if img_metas is not None:
            assert len(img_metas[0]) == feat.shape[0]

        valid_ratios = None
        if img_metas is not None and self.mask:
            valid_ratios = img_metas[-1]
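            # each valid_ratio is true_width / padded_width for one image in the
            # batch, i.e. the fraction of the (right-padded) feature map that is
            # real content rather than padding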

        h_feat = feat.shape[2]  # bsz c h w
        feat_v = F.max_pool2d(
            feat, kernel_size=(h_feat, 1), stride=1, padding=0)
        feat_v = feat_v.squeeze(2)  # bsz * C * W
        feat_v = paddle.transpose(feat_v, perm=[0, 2, 1])  # bsz * W * C
        holistic_feat = self.rnn_encoder(feat_v)[0]  # bsz * T * C

        if valid_ratios is not None:
            valid_hf = []
            T = holistic_feat.shape[1]
            for i, valid_ratio in enumerate(valid_ratios):
                valid_step = min(T, math.ceil(T * valid_ratio)) - 1
                valid_hf.append(holistic_feat[i, valid_step, :])
            valid_hf = paddle.stack(valid_hf, axis=0)
        else:
            valid_hf = holistic_feat[:, -1, :]  # bsz * C
        holistic_feat = self.linear(valid_hf)  # bsz * C

        return holistic_feat
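
# A minimal usage sketch of SAREncoder (illustrative shapes, not from the
# original file): given a backbone feature map of shape [bsz, d_model, H, W],
# the encoder returns a holistic feature of shape [bsz, d_enc]
# ([bsz, 2 * d_enc] when enc_bi_rnn=True), e.g.
#   encoder = SAREncoder(d_model=512, d_enc=512)
#   feat = paddle.rand([2, 512, 6, 40])
#   holistic = encoder(feat)  # shape: [2, 512]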


class BaseDecoder(nn.Layer):
    def __init__(self, **kwargs):
        super().__init__()

    def forward_train(self, feat, out_enc, targets, img_metas):
        raise NotImplementedError

    def forward_test(self, feat, out_enc, img_metas):
        raise NotImplementedError

    def forward(self,
                feat,
                out_enc,
                label=None,
                img_metas=None,
                train_mode=True):
        self.train_mode = train_mode

        if train_mode:
            return self.forward_train(feat, out_enc, label, img_metas)
        return self.forward_test(feat, out_enc, img_metas)


class ParallelSARDecoder(BaseDecoder):
    """
    Args:
        out_channels (int): Output class number.
        enc_bi_rnn (bool): If True, use bidirectional RNN in encoder.
        dec_bi_rnn (bool): If True, use bidirectional RNN in decoder.
        dec_drop_rnn (float): Dropout of RNN layer in decoder.
        dec_gru (bool): If True, use GRU, else LSTM in decoder.
        d_model (int): Dim of channels from backbone.
        d_enc (int): Dim of encoder RNN layer.
        d_k (int): Dim of channels of attention module.
        pred_dropout (float): Dropout probability of prediction layer.
        max_text_length (int): Maximum sequence length for decoding.
        mask (bool): If True, mask padding in feature map.
        start_idx (int): Index of start token.
        padding_idx (int): Index of padding token.
        pred_concat (bool): If True, concat glimpse feature from
            attention with holistic feature and hidden state.
    """

    def __init__(
            self,
            out_channels,  # 90 + unknown + start + padding
            enc_bi_rnn=False,
            dec_bi_rnn=False,
            dec_drop_rnn=0.0,
            dec_gru=False,
            d_model=512,
            d_enc=512,
            d_k=64,
            pred_dropout=0.1,
            max_text_length=30,
            mask=True,
            pred_concat=True,
            **kwargs):
        super().__init__()

        self.num_classes = out_channels
        self.enc_bi_rnn = enc_bi_rnn
        self.d_k = d_k
        self.start_idx = out_channels - 2
        self.padding_idx = out_channels - 1
        self.max_seq_len = max_text_length
        self.mask = mask
        self.pred_concat = pred_concat

        encoder_rnn_out_size = d_enc * (int(enc_bi_rnn) + 1)
        decoder_rnn_out_size = encoder_rnn_out_size * (int(dec_bi_rnn) + 1)

        # 2D attention layer
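        #   conv1x1_1 projects the decoder hidden state into the d_k query space,
        #   conv3x3_1 projects the backbone feature map into the d_k key space,
        #   and conv1x1_2 turns each spatial position's response into a scalar score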
        self.conv1x1_1 = nn.Linear(decoder_rnn_out_size, d_k)
        self.conv3x3_1 = nn.Conv2D(
            d_model, d_k, kernel_size=3, stride=1, padding=1)
        self.conv1x1_2 = nn.Linear(d_k, 1)

        # Decoder RNN layer
        if dec_bi_rnn:
            direction = 'bidirectional'
        else:
            direction = 'forward'

        kwargs = dict(
            input_size=encoder_rnn_out_size,
            hidden_size=encoder_rnn_out_size,
            num_layers=2,
            time_major=False,
            dropout=dec_drop_rnn,
            direction=direction)
        if dec_gru:
            self.rnn_decoder = nn.GRU(**kwargs)
        else:
            self.rnn_decoder = nn.LSTM(**kwargs)

        # Decoder input embedding
        self.embedding = nn.Embedding(
            self.num_classes,
            encoder_rnn_out_size,
            padding_idx=self.padding_idx)

        # Prediction layer
        self.pred_dropout = nn.Dropout(pred_dropout)
        pred_num_classes = self.num_classes - 1
        if pred_concat:
            # matches the channel width of concat([y, attn_feat, holistic_feat])
            fc_in_channel = decoder_rnn_out_size + d_model + encoder_rnn_out_size
        else:
            fc_in_channel = d_model
        self.prediction = nn.Linear(fc_in_channel, pred_num_classes)

    def _2d_attention(self,
                      decoder_input,
                      feat,
                      holistic_feat,
                      valid_ratios=None):
        y = self.rnn_decoder(decoder_input)[0]
        # y: bsz * (seq_len + 1) * hidden_size

        attn_query = self.conv1x1_1(y)  # bsz * (seq_len + 1) * attn_size
        bsz, seq_len, attn_size = attn_query.shape
        attn_query = paddle.unsqueeze(attn_query, axis=[3, 4])
        # (bsz, seq_len + 1, attn_size, 1, 1)

        attn_key = self.conv3x3_1(feat)
        # bsz * attn_size * h * w
        attn_key = attn_key.unsqueeze(1)
        # bsz * 1 * attn_size * h * w

        attn_weight = paddle.tanh(paddle.add(attn_key, attn_query))
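        # additive attention: the query (bsz, T, attn_size, 1, 1) broadcasts
        # against the key (bsz, 1, attn_size, h, w), pairing every decoding
        # step with every spatial position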
        # bsz * (seq_len + 1) * attn_size * h * w
        attn_weight = paddle.transpose(attn_weight, perm=[0, 1, 3, 4, 2])
        # bsz * (seq_len + 1) * h * w * attn_size
        attn_weight = self.conv1x1_2(attn_weight)
        # bsz * (seq_len + 1) * h * w * 1
        bsz, T, h, w, c = attn_weight.shape
        assert c == 1

        if valid_ratios is not None:
            # mask the padded width region so the softmax ignores those positions
            for i, valid_ratio in enumerate(valid_ratios):
                valid_width = min(w, math.ceil(w * valid_ratio))
                attn_weight[i, :, :, valid_width:, :] = float('-inf')

        attn_weight = paddle.reshape(attn_weight, [bsz, T, -1])
        attn_weight = F.softmax(attn_weight, axis=-1)

        attn_weight = paddle.reshape(attn_weight, [bsz, T, h, w, c])
        attn_weight = paddle.transpose(attn_weight, perm=[0, 1, 4, 2, 3])
        # attn_weight: bsz * T * c * h * w
        # feat: bsz * c * h * w
        attn_feat = paddle.sum(paddle.multiply(feat.unsqueeze(1), attn_weight),
                               (3, 4),
                               keepdim=False)
        # bsz * (seq_len + 1) * C

        # Linear transformation
        if self.pred_concat:
            hf_c = holistic_feat.shape[-1]
            holistic_feat = paddle.expand(
                holistic_feat, shape=[bsz, seq_len, hf_c])
            y = self.prediction(paddle.concat((y, attn_feat, holistic_feat), 2))
        else:
            y = self.prediction(attn_feat)
        # bsz * (seq_len + 1) * num_classes
        if self.train_mode:
            y = self.pred_dropout(y)

        return y

    def forward_train(self, feat, out_enc, label, img_metas):
        '''
        img_metas: [label, valid_ratio]
        '''
        if img_metas is not None:
            assert len(img_metas[0]) == feat.shape[0]

        valid_ratios = None
        if img_metas is not None and self.mask:
            valid_ratios = img_metas[-1]

        label = label.cuda()
        lab_embedding = self.embedding(label)
        # bsz * seq_len * emb_dim
        out_enc = out_enc.unsqueeze(1)
        # bsz * 1 * emb_dim
        in_dec = paddle.concat((out_enc, lab_embedding), axis=1)
        # bsz * (seq_len + 1) * C
        out_dec = self._2d_attention(
            in_dec, feat, out_enc, valid_ratios=valid_ratios)
        # bsz * (seq_len + 1) * num_classes

        return out_dec[:, 1:, :]  # bsz * seq_len * num_classes

    def forward_test(self, feat, out_enc, img_metas):
        if img_metas is not None:
            assert len(img_metas[0]) == feat.shape[0]

        valid_ratios = None
        if img_metas is not None and self.mask:
            valid_ratios = img_metas[-1]

        seq_len = self.max_seq_len
        bsz = feat.shape[0]
        start_token = paddle.full(
            (bsz, ), fill_value=self.start_idx, dtype='int64')
        # bsz
        start_token = self.embedding(start_token)
        # bsz * emb_dim
        emb_dim = start_token.shape[1]
        start_token = start_token.unsqueeze(1)
        start_token = paddle.expand(start_token, shape=[bsz, seq_len, emb_dim])
        # bsz * seq_len * emb_dim
        out_enc = out_enc.unsqueeze(1)
        # bsz * 1 * emb_dim
        decoder_input = paddle.concat((out_enc, start_token), axis=1)
        # bsz * (seq_len + 1) * emb_dim

        outputs = []
        for i in range(1, seq_len + 1):
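            # greedy decoding: recompute the parallel attention over the current
            # decoder_input, take the argmax character at step i, and feed its
            # embedding back in as the next decoder input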
            decoder_output = self._2d_attention(
                decoder_input, feat, out_enc, valid_ratios=valid_ratios)
            char_output = decoder_output[:, i, :]  # bsz * num_classes
            char_output = F.softmax(char_output, -1)
            outputs.append(char_output)
            max_idx = paddle.argmax(char_output, axis=1, keepdim=False)
            char_embedding = self.embedding(max_idx)  # bsz * emb_dim
            if i < seq_len:
                decoder_input[:, i + 1, :] = char_embedding

        outputs = paddle.stack(outputs, 1)  # bsz * seq_len * num_classes

        return outputs


class SARHead(nn.Layer):
    def __init__(self,
                 out_channels,
                 enc_bi_rnn=False,
                 enc_drop_rnn=0.1,
                 enc_gru=False,
                 dec_bi_rnn=False,
                 dec_drop_rnn=0.0,
                 dec_gru=False,
                 d_k=512,
                 pred_dropout=0.1,
                 max_text_length=30,
                 pred_concat=True,
                 **kwargs):
        super(SARHead, self).__init__()

        # encoder module
        self.encoder = SAREncoder(
            enc_bi_rnn=enc_bi_rnn, enc_drop_rnn=enc_drop_rnn, enc_gru=enc_gru)

        # decoder module
        self.decoder = ParallelSARDecoder(
            out_channels=out_channels,
            enc_bi_rnn=enc_bi_rnn,
            dec_bi_rnn=dec_bi_rnn,
            dec_drop_rnn=dec_drop_rnn,
            dec_gru=dec_gru,
            d_k=d_k,
            pred_dropout=pred_dropout,
            max_text_length=max_text_length,
            pred_concat=pred_concat)

    def forward(self, feat, targets=None):
        '''
        targets: [label, valid_ratio]
        '''
        holistic_feat = self.encoder(feat, targets)  # bsz c

        if self.training:
            label = targets[0]  # label
            label = paddle.to_tensor(label, dtype='int64')
            final_out = self.decoder(
                feat, holistic_feat, label, img_metas=targets)
        if not self.training:
            final_out = self.decoder(
                feat,
                holistic_feat,
                label=None,
                img_metas=targets,
                train_mode=False)
            # (bsz, seq_len, num_classes)

        return final_out
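

if __name__ == "__main__":
    # A minimal smoke test, not part of the original file; the channel count
    # (512) and class count (92) below are illustrative assumptions only.
    # Running SARHead in eval mode on a dummy backbone feature map should
    # produce per-step class distributions of shape
    # [bsz, max_text_length, out_channels - 1].
    head = SARHead(out_channels=92)
    head.eval()
    dummy_feat = paddle.rand([2, 512, 6, 40])  # bsz * C * H * W
    with paddle.no_grad():
        probs = head(dummy_feat, targets=None)
    print(probs.shape)  # expected: [2, 30, 91]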