rec_sar_head.py 12.9 KB
Newer Older
A
andyjpaddle 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import math
import paddle
from paddle import ParamAttr
import paddle.nn as nn
import paddle.nn.functional as F


class SAREncoder(nn.Layer):
    """Encode a backbone feature map into one holistic feature vector.

    The feature map is max-pooled over its height, run through a two-layer
    RNN along the width axis, and the RNN output at the last valid time
    step is passed through a linear layer.

    Args:
        enc_bi_rnn (bool): If True, use bidirectional RNN in encoder.
        enc_drop_rnn (float): Dropout probability of RNN layer in encoder.
        enc_gru (bool): If True, use GRU, else LSTM in encoder.
        d_model (int): Dim of channels from backbone.
        d_enc (int): Dim of encoder RNN layer.
        mask (bool): If True, mask padding in RNN sequence.
        **kwargs: Accepted and ignored.
    """

    def __init__(self,
                 enc_bi_rnn=False,
                 enc_drop_rnn=0.1,
                 enc_gru=False,
                 d_model=512,
                 d_enc=512,
                 mask=True,
                 **kwargs):
        super().__init__()
        assert isinstance(enc_bi_rnn, bool)
        assert isinstance(enc_drop_rnn, (int, float))
        assert 0 <= enc_drop_rnn < 1.0
        assert isinstance(enc_gru, bool)
        assert isinstance(d_model, int)
        assert isinstance(d_enc, int)
        assert isinstance(mask, bool)

        self.enc_bi_rnn = enc_bi_rnn
        self.enc_drop_rnn = enc_drop_rnn
        self.mask = mask

        # Recurrent encoder (GRU or LSTM), two stacked layers, batch-major.
        rnn_cls = nn.GRU if enc_gru else nn.LSTM
        self.rnn_encoder = rnn_cls(
            input_size=d_model,
            hidden_size=d_enc,
            num_layers=2,
            time_major=False,
            dropout=enc_drop_rnn,
            direction='bidirectional' if enc_bi_rnn else 'forward')

        # Global feature transformation; a bidirectional RNN doubles the
        # width of its output.
        rnn_out_dim = d_enc * 2 if enc_bi_rnn else d_enc
        self.linear = nn.Linear(rnn_out_dim, rnn_out_dim)

    def forward(self, feat, img_metas=None):
        """Compute the holistic feature for a batch.

        Args:
            feat: Backbone feature map, bsz * C * H * W.
            img_metas: Optional sequence whose first element has one entry
                per sample and whose last element holds the per-sample
                valid ratios.

        Returns:
            Holistic feature, bsz * C.
        """
        has_metas = img_metas is not None
        if has_metas:
            assert len(img_metas[0]) == feat.shape[0]

        # Valid ratios are only honoured when masking is enabled.
        valid_ratios = img_metas[-1] if has_metas and self.mask else None

        # Collapse the height axis: bsz * C * H * W -> bsz * W * C.
        pooled = F.max_pool2d(
            feat, kernel_size=(feat.shape[2], 1), stride=1, padding=0)
        seq_in = paddle.transpose(pooled.squeeze(2), perm=[0, 2, 1])
        rnn_out = self.rnn_encoder(seq_in)[0]  # bsz * T * C

        if valid_ratios is None:
            last_valid = rnn_out[:, -1, :]  # bsz * C
        else:
            # Pick, per sample, the output at the last non-padded step.
            num_steps = rnn_out.shape[1]
            last_valid = paddle.stack(
                [
                    rnn_out[idx,
                            min(num_steps, math.ceil(num_steps * ratio)) - 1,
                            :] for idx, ratio in enumerate(valid_ratios)
                ],
                axis=0)

        return self.linear(last_valid)  # bsz * C
A
andyjpaddle 已提交
92

A
andyjpaddle 已提交
93 94 95 96 97 98 99 100 101 102 103

class BaseDecoder(nn.Layer):
    """Base class for decoders with separate train and test paths.

    Subclasses implement ``forward_train`` and ``forward_test``;
    ``forward`` dispatches between them.
    """

    def __init__(self, **kwargs):
        super().__init__()

    def forward_train(self, feat, out_enc, targets, img_metas):
        """Training-time decoding; must be overridden."""
        raise NotImplementedError

    def forward_test(self, feat, out_enc, img_metas):
        """Inference-time decoding; must be overridden."""
        raise NotImplementedError

    def forward(self,
                feat,
                out_enc,
                label=None,
                img_metas=None,
                train_mode=True):
        """Record ``train_mode`` on the instance and run the matching path.

        ``label`` is forwarded only to ``forward_train``.
        """
        self.train_mode = train_mode

        if not train_mode:
            return self.forward_test(feat, out_enc, img_metas)
        return self.forward_train(feat, out_enc, label, img_metas)


class ParallelSARDecoder(BaseDecoder):
    """Parallel decoder with 2D attention for the SAR recognition head.

    Args:
        out_channels (int): Output class number. The start token index is
            ``out_channels - 2`` and the padding index is
            ``out_channels - 1``; the prediction layer emits
            ``out_channels - 1`` channels.
        enc_bi_rnn (bool): If True, use bidirectional RNN in encoder.
        dec_bi_rnn (bool): If True, use bidirectional RNN in decoder.
        dec_drop_rnn (float): Dropout of RNN layer in decoder.
        dec_gru (bool): If True, use GRU, else LSTM in decoder.
        d_model (int): Dim of channels from backbone.
        d_enc (int): Dim of encoder RNN layer.
        d_k (int): Dim of channels of attention module.
        pred_dropout (float): Dropout probability of prediction layer.
        max_text_length (int): Maximum sequence length for decoding.
        mask (bool): If True, mask padding in feature map.
        pred_concat (bool): If True, concat glimpse feature from
            attention with holistic feature and hidden state.
        **kwargs: Accepted and ignored.
    """

    def __init__(
            self,
            out_channels,  # 90 + unknown + start + padding
            enc_bi_rnn=False,
            dec_bi_rnn=False,
            dec_drop_rnn=0.0,
            dec_gru=False,
            d_model=512,
            d_enc=512,
            d_k=64,
            pred_dropout=0.1,
            max_text_length=30,
            mask=True,
            pred_concat=True,
            **kwargs):
        super().__init__()

        self.num_classes = out_channels
        self.enc_bi_rnn = enc_bi_rnn
        self.d_k = d_k
        self.start_idx = out_channels - 2
        self.padding_idx = out_channels - 1
        self.max_seq_len = max_text_length
        self.mask = mask
        self.pred_concat = pred_concat

        encoder_rnn_out_size = d_enc * (int(enc_bi_rnn) + 1)
        decoder_rnn_out_size = encoder_rnn_out_size * (int(dec_bi_rnn) + 1)

        # 2D attention layer
        self.conv1x1_1 = nn.Linear(decoder_rnn_out_size, d_k)
        self.conv3x3_1 = nn.Conv2D(
            d_model, d_k, kernel_size=3, stride=1, padding=1)
        self.conv1x1_2 = nn.Linear(d_k, 1)

        # Decoder RNN layer
        if dec_bi_rnn:
            direction = 'bidirectional'
        else:
            direction = 'forward'

        # Named rnn_kwargs so the **kwargs parameter is not shadowed.
        rnn_kwargs = dict(
            input_size=encoder_rnn_out_size,
            hidden_size=encoder_rnn_out_size,
            num_layers=2,
            time_major=False,
            dropout=dec_drop_rnn,
            direction=direction)
        if dec_gru:
            self.rnn_decoder = nn.GRU(**rnn_kwargs)
        else:
            self.rnn_decoder = nn.LSTM(**rnn_kwargs)

        # Decoder input embedding
        self.embedding = nn.Embedding(
            self.num_classes,
            encoder_rnn_out_size,
            padding_idx=self.padding_idx)

        # Prediction layer
        self.pred_dropout = nn.Dropout(pred_dropout)
        pred_num_classes = self.num_classes - 1
        if pred_concat:
            # The holistic feature concatenated in _2d_attention has width
            # encoder_rnn_out_size (d_enc doubled for a bidirectional
            # encoder), so that width -- not d_enc -- must be used here.
            fc_in_channel = (
                decoder_rnn_out_size + d_model + encoder_rnn_out_size)
        else:
            fc_in_channel = d_model
        self.prediction = nn.Linear(fc_in_channel, pred_num_classes)

    def _2d_attention(self,
                      decoder_input,
                      feat,
                      holistic_feat,
                      valid_ratios=None):
        """Run the decoder RNN and attend over the 2D feature map.

        Args:
            decoder_input: Decoder RNN input, bsz * (seq_len + 1) * emb_dim.
            feat: Backbone feature map, bsz * C * h * w.
            holistic_feat: Encoder output, bsz * 1 * C; only used when
                ``pred_concat`` is True.
            valid_ratios: Optional per-sample valid width ratios; columns
                beyond the valid width are excluded from the attention.

        Returns:
            Prediction scores, bsz * (seq_len + 1) * (num_classes - 1).
        """
        y = self.rnn_decoder(decoder_input)[0]
        # y: bsz * (seq_len + 1) * hidden_size

        attn_query = self.conv1x1_1(y)  # bsz * (seq_len + 1) * attn_size
        bsz, seq_len, _ = attn_query.shape
        attn_query = paddle.unsqueeze(attn_query, axis=[3, 4])
        # (bsz, seq_len + 1, attn_size, 1, 1)

        attn_key = self.conv3x3_1(feat)
        # bsz * attn_size * h * w
        attn_key = attn_key.unsqueeze(1)
        # bsz * 1 * attn_size * h * w

        attn_weight = paddle.tanh(paddle.add(attn_key, attn_query))

        # bsz * (seq_len + 1) * attn_size * h * w
        attn_weight = paddle.transpose(attn_weight, perm=[0, 1, 3, 4, 2])
        # bsz * (seq_len + 1) * h * w * attn_size
        attn_weight = self.conv1x1_2(attn_weight)
        # bsz * (seq_len + 1) * h * w * 1
        bsz, T, h, w, c = attn_weight.shape
        assert c == 1

        if valid_ratios is not None:
            # cal mask of attention weight: -inf scores vanish after softmax
            for i, valid_ratio in enumerate(valid_ratios):
                valid_width = min(w, math.ceil(w * valid_ratio))
                if valid_width < w:
                    attn_weight[i, :, :, valid_width:, :] = float('-inf')

        # Softmax jointly over all h * w positions.
        attn_weight = paddle.reshape(attn_weight, [bsz, T, -1])
        attn_weight = F.softmax(attn_weight, axis=-1)

        attn_weight = paddle.reshape(attn_weight, [bsz, T, h, w, c])
        attn_weight = paddle.transpose(attn_weight, perm=[0, 1, 4, 2, 3])
        # attn_weight: bsz * T * c * h * w
        # feat: bsz * c * h * w
        attn_feat = paddle.sum(paddle.multiply(feat.unsqueeze(1), attn_weight),
                               (3, 4),
                               keepdim=False)
        # bsz * (seq_len + 1) * C

        # Linear transformation
        if self.pred_concat:
            hf_c = holistic_feat.shape[-1]
            holistic_feat = paddle.expand(
                holistic_feat, shape=[bsz, seq_len, hf_c])
            y = self.prediction(paddle.concat((y, attn_feat, holistic_feat), 2))
        else:
            y = self.prediction(attn_feat)
        # bsz * (seq_len + 1) * num_classes
        if self.train_mode:
            y = self.pred_dropout(y)

        return y

    def forward_train(self, feat, out_enc, label, img_metas):
        '''
        Decode with teacher forcing.

        img_metas: [label, valid_ratio]

        Returns prediction scores of shape bsz * seq_len * num_classes.
        '''
        if img_metas is not None:
            assert len(img_metas[0]) == feat.shape[0]

        valid_ratios = None
        if img_metas is not None and self.mask:
            valid_ratios = img_metas[-1]

        # No explicit label.cuda(): forcing CUDA here fails on CPU-only
        # runs, so the label stays on whatever device the caller placed it.
        lab_embedding = self.embedding(label)
        # bsz * seq_len * emb_dim
        out_enc = out_enc.unsqueeze(1)
        # bsz * 1 * emb_dim
        in_dec = paddle.concat((out_enc, lab_embedding), axis=1)
        # bsz * (seq_len + 1) * C
        out_dec = self._2d_attention(
            in_dec, feat, out_enc, valid_ratios=valid_ratios)
        # bsz * (seq_len + 1) * num_classes

        # Drop the step fed by the holistic feature.
        return out_dec[:, 1:, :]  # bsz * seq_len * num_classes

    def forward_test(self, feat, out_enc, img_metas):
        '''
        Decode greedily for ``max_seq_len`` steps.

        img_metas: [label, valid_ratio]

        Returns softmax probabilities of shape bsz * seq_len * num_classes.
        '''
        if img_metas is not None:
            assert len(img_metas[0]) == feat.shape[0]

        valid_ratios = None
        if img_metas is not None and self.mask:
            valid_ratios = img_metas[-1]

        seq_len = self.max_seq_len
        bsz = feat.shape[0]
        start_token = paddle.full(
            (bsz, ), fill_value=self.start_idx, dtype='int64')
        # bsz
        start_token = self.embedding(start_token)
        # bsz * emb_dim
        emb_dim = start_token.shape[1]
        start_token = start_token.unsqueeze(1)
        start_token = paddle.expand(start_token, shape=[bsz, seq_len, emb_dim])
        # bsz * seq_len * emb_dim
        out_enc = out_enc.unsqueeze(1)
        # bsz * 1 * emb_dim
        decoder_input = paddle.concat((out_enc, start_token), axis=1)
        # bsz * (seq_len + 1) * emb_dim

        outputs = []
        for i in range(1, seq_len + 1):
            # The whole sequence is re-decoded each step; only position i
            # is read, then its argmax embedding becomes input i + 1.
            decoder_output = self._2d_attention(
                decoder_input, feat, out_enc, valid_ratios=valid_ratios)
            char_output = decoder_output[:, i, :]  # bsz * num_classes
            char_output = F.softmax(char_output, -1)
            outputs.append(char_output)
            max_idx = paddle.argmax(char_output, axis=1, keepdim=False)
            char_embedding = self.embedding(max_idx)  # bsz * emb_dim
            if i < seq_len:
                decoder_input[:, i + 1, :] = char_embedding

        outputs = paddle.stack(outputs, 1)  # bsz * seq_len * num_classes

        return outputs


class SARHead(nn.Layer):
    """SAR recognition head: a ``SAREncoder`` followed by a
    ``ParallelSARDecoder``.

    The constructor arguments are forwarded to the encoder and decoder
    under the same names; extra ``**kwargs`` are accepted and ignored.
    """

    def __init__(self,
                 out_channels,
                 enc_bi_rnn=False,
                 enc_drop_rnn=0.1,
                 enc_gru=False,
                 dec_bi_rnn=False,
                 dec_drop_rnn=0.0,
                 dec_gru=False,
                 d_k=512,
                 pred_dropout=0.1,
                 max_text_length=30,
                 pred_concat=True,
                 **kwargs):
        super(SARHead, self).__init__()

        # encoder module
        encoder_cfg = dict(
            enc_bi_rnn=enc_bi_rnn,
            enc_drop_rnn=enc_drop_rnn,
            enc_gru=enc_gru)
        self.encoder = SAREncoder(**encoder_cfg)

        # decoder module
        decoder_cfg = dict(
            out_channels=out_channels,
            enc_bi_rnn=enc_bi_rnn,
            dec_bi_rnn=dec_bi_rnn,
            dec_drop_rnn=dec_drop_rnn,
            dec_gru=dec_gru,
            d_k=d_k,
            pred_dropout=pred_dropout,
            max_text_length=max_text_length,
            pred_concat=pred_concat)
        self.decoder = ParallelSARDecoder(**decoder_cfg)

    def forward(self, feat, targets=None):
        '''
        Encode ``feat`` and decode it into per-step class scores.

        targets: [label, valid_ratio]; passed to the encoder and decoder
            as ``img_metas``. ``targets[0]`` is read as the label when
            the layer is in training mode.
        '''
        holistic_feat = self.encoder(feat, targets)  # bsz c

        if self.training:
            label = paddle.to_tensor(targets[0], dtype='int64')
            return self.decoder(
                feat, holistic_feat, label, img_metas=targets)

        # (bsz, seq_len, num_classes)
        return self.decoder(
            feat,
            holistic_feat,
            label=None,
            img_metas=targets,
            train_mode=False)