提交 d6115158 编写于 作者: A andyjpaddle

fix code style

上级 ae09ef60
...@@ -19,6 +19,7 @@ class SAREncoder(nn.Layer): ...@@ -19,6 +19,7 @@ class SAREncoder(nn.Layer):
d_enc (int): Dim of encoder RNN layer. d_enc (int): Dim of encoder RNN layer.
mask (bool): If True, mask padding in RNN sequence. mask (bool): If True, mask padding in RNN sequence.
""" """
def __init__(self, def __init__(self,
enc_bi_rnn=False, enc_bi_rnn=False,
enc_drop_rnn=0.1, enc_drop_rnn=0.1,
...@@ -51,33 +52,31 @@ class SAREncoder(nn.Layer): ...@@ -51,33 +52,31 @@ class SAREncoder(nn.Layer):
num_layers=2, num_layers=2,
time_major=False, time_major=False,
dropout=enc_drop_rnn, dropout=enc_drop_rnn,
direction=direction direction=direction)
)
if enc_gru: if enc_gru:
self.rnn_encoder = nn.GRU(**kwargs) self.rnn_encoder = nn.GRU(**kwargs)
else: else:
self.rnn_encoder = nn.LSTM(**kwargs) self.rnn_encoder = nn.LSTM(**kwargs)
# global feature transformation # global feature transformation
encoder_rnn_out_size = d_enc * (int(enc_bi_rnn) + 1) encoder_rnn_out_size = d_enc * (int(enc_bi_rnn) + 1)
self.linear = nn.Linear(encoder_rnn_out_size, encoder_rnn_out_size) self.linear = nn.Linear(encoder_rnn_out_size, encoder_rnn_out_size)
def forward(self, feat, img_metas=None): def forward(self, feat, img_metas=None):
if img_metas is not None: if img_metas is not None:
assert len(img_metas[0]) == feat.shape[0] assert len(img_metas[0]) == feat.shape[0]
valid_ratios = None valid_ratios = None
if img_metas is not None and self.mask: if img_metas is not None and self.mask:
valid_ratios = img_metas[-1] valid_ratios = img_metas[-1]
h_feat = feat.shape[2] # bsz c h w h_feat = feat.shape[2] # bsz c h w
feat_v = F.max_pool2d( feat_v = F.max_pool2d(
feat, kernel_size=(h_feat, 1), stride=1, padding=0 feat, kernel_size=(h_feat, 1), stride=1, padding=0)
) feat_v = feat_v.squeeze(2) # bsz * C * W
feat_v = feat_v.squeeze(2) # bsz * C * W feat_v = paddle.transpose(feat_v, perm=[0, 2, 1]) # bsz * W * C
feat_v = paddle.transpose(feat_v, perm=[0, 2, 1]) # bsz * W * C holistic_feat = self.rnn_encoder(feat_v)[0] # bsz * T * C
holistic_feat = self.rnn_encoder(feat_v)[0] # bsz * T * C
if valid_ratios is not None: if valid_ratios is not None:
valid_hf = [] valid_hf = []
T = holistic_feat.shape[1] T = holistic_feat.shape[1]
...@@ -86,11 +85,11 @@ class SAREncoder(nn.Layer): ...@@ -86,11 +85,11 @@ class SAREncoder(nn.Layer):
valid_hf.append(holistic_feat[i, valid_step, :]) valid_hf.append(holistic_feat[i, valid_step, :])
valid_hf = paddle.stack(valid_hf, axis=0) valid_hf = paddle.stack(valid_hf, axis=0)
else: else:
valid_hf = holistic_feat[:, -1, :] # bsz * C valid_hf = holistic_feat[:, -1, :] # bsz * C
holistic_feat = self.linear(valid_hf) # bsz * C holistic_feat = self.linear(valid_hf) # bsz * C
return holistic_feat return holistic_feat
class BaseDecoder(nn.Layer): class BaseDecoder(nn.Layer):
def __init__(self, **kwargs): def __init__(self, **kwargs):
...@@ -102,7 +101,7 @@ class BaseDecoder(nn.Layer): ...@@ -102,7 +101,7 @@ class BaseDecoder(nn.Layer):
def forward_test(self, feat, out_enc, img_metas): def forward_test(self, feat, out_enc, img_metas):
raise NotImplementedError raise NotImplementedError
def forward(self, def forward(self,
feat, feat,
out_enc, out_enc,
label=None, label=None,
...@@ -135,20 +134,21 @@ class ParallelSARDecoder(BaseDecoder): ...@@ -135,20 +134,21 @@ class ParallelSARDecoder(BaseDecoder):
attention with holistic feature and hidden state. attention with holistic feature and hidden state.
""" """
def __init__(self, def __init__(
out_channels, # 90 + unknown + start + padding self,
enc_bi_rnn=False, out_channels, # 90 + unknown + start + padding
dec_bi_rnn=False, enc_bi_rnn=False,
dec_drop_rnn=0.0, dec_bi_rnn=False,
dec_gru=False, dec_drop_rnn=0.0,
d_model=512, dec_gru=False,
d_enc=512, d_model=512,
d_k=64, d_enc=512,
pred_dropout=0.1, d_k=64,
max_text_length=30, pred_dropout=0.1,
mask=True, max_text_length=30,
pred_concat=True, mask=True,
**kwargs): pred_concat=True,
**kwargs):
super().__init__() super().__init__()
self.num_classes = out_channels self.num_classes = out_channels
...@@ -165,7 +165,8 @@ class ParallelSARDecoder(BaseDecoder): ...@@ -165,7 +165,8 @@ class ParallelSARDecoder(BaseDecoder):
# 2D attention layer # 2D attention layer
self.conv1x1_1 = nn.Linear(decoder_rnn_out_size, d_k) self.conv1x1_1 = nn.Linear(decoder_rnn_out_size, d_k)
self.conv3x3_1 = nn.Conv2D(d_model, d_k, kernel_size=3, stride=1, padding=1) self.conv3x3_1 = nn.Conv2D(
d_model, d_k, kernel_size=3, stride=1, padding=1)
self.conv1x1_2 = nn.Linear(d_k, 1) self.conv1x1_2 = nn.Linear(d_k, 1)
# Decoder RNN layer # Decoder RNN layer
...@@ -180,8 +181,7 @@ class ParallelSARDecoder(BaseDecoder): ...@@ -180,8 +181,7 @@ class ParallelSARDecoder(BaseDecoder):
num_layers=2, num_layers=2,
time_major=False, time_major=False,
dropout=dec_drop_rnn, dropout=dec_drop_rnn,
direction=direction direction=direction)
)
if dec_gru: if dec_gru:
self.rnn_decoder = nn.GRU(**kwargs) self.rnn_decoder = nn.GRU(**kwargs)
else: else:
...@@ -189,8 +189,10 @@ class ParallelSARDecoder(BaseDecoder): ...@@ -189,8 +189,10 @@ class ParallelSARDecoder(BaseDecoder):
# Decoder input embedding # Decoder input embedding
self.embedding = nn.Embedding( self.embedding = nn.Embedding(
self.num_classes, encoder_rnn_out_size, padding_idx=self.padding_idx) self.num_classes,
encoder_rnn_out_size,
padding_idx=self.padding_idx)
# Prediction layer # Prediction layer
self.pred_dropout = nn.Dropout(pred_dropout) self.pred_dropout = nn.Dropout(pred_dropout)
pred_num_classes = num_classes - 1 pred_num_classes = num_classes - 1
...@@ -205,11 +207,11 @@ class ParallelSARDecoder(BaseDecoder): ...@@ -205,11 +207,11 @@ class ParallelSARDecoder(BaseDecoder):
feat, feat,
holistic_feat, holistic_feat,
valid_ratios=None): valid_ratios=None):
y = self.rnn_decoder(decoder_input)[0] y = self.rnn_decoder(decoder_input)[0]
# y: bsz * (seq_len + 1) * hidden_size # y: bsz * (seq_len + 1) * hidden_size
attn_query = self.conv1x1_1(y) # bsz * (seq_len + 1) * attn_size attn_query = self.conv1x1_1(y) # bsz * (seq_len + 1) * attn_size
bsz, seq_len, attn_size = attn_query.shape bsz, seq_len, attn_size = attn_query.shape
attn_query = paddle.unsqueeze(attn_query, axis=[3, 4]) attn_query = paddle.unsqueeze(attn_query, axis=[3, 4])
# (bsz, seq_len + 1, attn_size, 1, 1) # (bsz, seq_len + 1, attn_size, 1, 1)
...@@ -220,7 +222,7 @@ class ParallelSARDecoder(BaseDecoder): ...@@ -220,7 +222,7 @@ class ParallelSARDecoder(BaseDecoder):
# bsz * 1 * attn_size * h * w # bsz * 1 * attn_size * h * w
attn_weight = paddle.tanh(paddle.add(attn_key, attn_query)) attn_weight = paddle.tanh(paddle.add(attn_key, attn_query))
# bsz * (seq_len + 1) * attn_size * h * w # bsz * (seq_len + 1) * attn_size * h * w
attn_weight = paddle.transpose(attn_weight, perm=[0, 1, 3, 4, 2]) attn_weight = paddle.transpose(attn_weight, perm=[0, 1, 3, 4, 2])
# bsz * (seq_len + 1) * h * w * attn_size # bsz * (seq_len + 1) * h * w * attn_size
...@@ -237,25 +239,28 @@ class ParallelSARDecoder(BaseDecoder): ...@@ -237,25 +239,28 @@ class ParallelSARDecoder(BaseDecoder):
attn_weight = paddle.reshape(attn_weight, [bsz, T, -1]) attn_weight = paddle.reshape(attn_weight, [bsz, T, -1])
attn_weight = F.softmax(attn_weight, axis=-1) attn_weight = F.softmax(attn_weight, axis=-1)
attn_weight = paddle.reshape(attn_weight, [bsz, T, h, w, c]) attn_weight = paddle.reshape(attn_weight, [bsz, T, h, w, c])
attn_weight = paddle.transpose(attn_weight, perm=[0, 1, 4, 2, 3]) attn_weight = paddle.transpose(attn_weight, perm=[0, 1, 4, 2, 3])
# attn_weight: bsz * T * c * h * w # attn_weight: bsz * T * c * h * w
# feat: bsz * c * h * w # feat: bsz * c * h * w
attn_feat = paddle.sum(paddle.multiply(feat.unsqueeze(1), attn_weight), (3, 4), keepdim=False) attn_feat = paddle.sum(paddle.multiply(feat.unsqueeze(1), attn_weight),
(3, 4),
keepdim=False)
# bsz * (seq_len + 1) * C # bsz * (seq_len + 1) * C
# Linear transformation # Linear transformation
if self.pred_concat: if self.pred_concat:
hf_c = holistic_feat.shape[-1] hf_c = holistic_feat.shape[-1]
holistic_feat = paddle.expand(holistic_feat, shape=[bsz, seq_len, hf_c]) holistic_feat = paddle.expand(
holistic_feat, shape=[bsz, seq_len, hf_c])
y = self.prediction(paddle.concat((y, attn_feat, holistic_feat), 2)) y = self.prediction(paddle.concat((y, attn_feat, holistic_feat), 2))
else: else:
y = self.prediction(attn_feat) y = self.prediction(attn_feat)
# bsz * (seq_len + 1) * num_classes # bsz * (seq_len + 1) * num_classes
if self.train_mode: if self.train_mode:
y = self.pred_dropout(y) y = self.pred_dropout(y)
return y return y
def forward_train(self, feat, out_enc, label, img_metas): def forward_train(self, feat, out_enc, label, img_metas):
...@@ -268,7 +273,7 @@ class ParallelSARDecoder(BaseDecoder): ...@@ -268,7 +273,7 @@ class ParallelSARDecoder(BaseDecoder):
valid_ratios = None valid_ratios = None
if img_metas is not None and self.mask: if img_metas is not None and self.mask:
valid_ratios = img_metas[-1] valid_ratios = img_metas[-1]
label = label.cuda() label = label.cuda()
lab_embedding = self.embedding(label) lab_embedding = self.embedding(label)
# bsz * seq_len * emb_dim # bsz * seq_len * emb_dim
...@@ -277,11 +282,10 @@ class ParallelSARDecoder(BaseDecoder): ...@@ -277,11 +282,10 @@ class ParallelSARDecoder(BaseDecoder):
in_dec = paddle.concat((out_enc, lab_embedding), axis=1) in_dec = paddle.concat((out_enc, lab_embedding), axis=1)
# bsz * (seq_len + 1) * C # bsz * (seq_len + 1) * C
out_dec = self._2d_attention( out_dec = self._2d_attention(
in_dec, feat, out_enc, valid_ratios=valid_ratios in_dec, feat, out_enc, valid_ratios=valid_ratios)
)
# bsz * (seq_len + 1) * num_classes # bsz * (seq_len + 1) * num_classes
return out_dec[:, 1:, :] # bsz * seq_len * num_classes return out_dec[:, 1:, :] # bsz * seq_len * num_classes
def forward_test(self, feat, out_enc, img_metas): def forward_test(self, feat, out_enc, img_metas):
if img_metas is not None: if img_metas is not None:
...@@ -289,13 +293,12 @@ class ParallelSARDecoder(BaseDecoder): ...@@ -289,13 +293,12 @@ class ParallelSARDecoder(BaseDecoder):
valid_ratios = None valid_ratios = None
if img_metas is not None and self.mask: if img_metas is not None and self.mask:
valid_ratios = img_metas[-1] valid_ratios = img_metas[-1]
seq_len = self.max_seq_len seq_len = self.max_seq_len
bsz = feat.shape[0] bsz = feat.shape[0]
start_token = paddle.full((bsz, ), start_token = paddle.full(
fill_value=self.start_idx, (bsz, ), fill_value=self.start_idx, dtype='int64')
dtype='int64')
# bsz # bsz
start_token = self.embedding(start_token) start_token = self.embedding(start_token)
# bsz * emb_dim # bsz * emb_dim
...@@ -311,68 +314,70 @@ class ParallelSARDecoder(BaseDecoder): ...@@ -311,68 +314,70 @@ class ParallelSARDecoder(BaseDecoder):
outputs = [] outputs = []
for i in range(1, seq_len + 1): for i in range(1, seq_len + 1):
decoder_output = self._2d_attention( decoder_output = self._2d_attention(
decoder_input, feat, out_enc, valid_ratios=valid_ratios decoder_input, feat, out_enc, valid_ratios=valid_ratios)
) char_output = decoder_output[:, i, :] # bsz * num_classes
char_output = decoder_output[:, i, :] # bsz * num_classes
char_output = F.softmax(char_output, -1) char_output = F.softmax(char_output, -1)
outputs.append(char_output) outputs.append(char_output)
max_idx = paddle.argmax(char_output, axis=1, keepdim=False) max_idx = paddle.argmax(char_output, axis=1, keepdim=False)
char_embedding = self.embedding(max_idx) # bsz * emb_dim char_embedding = self.embedding(max_idx) # bsz * emb_dim
if i < seq_len: if i < seq_len:
decoder_input[:, i + 1, :] = char_embedding decoder_input[:, i + 1, :] = char_embedding
outputs = paddle.stack(outputs, 1) # bsz * seq_len * num_classes outputs = paddle.stack(outputs, 1) # bsz * seq_len * num_classes
return outputs return outputs
class SARHead(nn.Layer):
    """SAR (Show, Attend and Read) text-recognition head.

    Composes a ``SAREncoder`` (holistic feature over the backbone feature
    map) with a ``ParallelSARDecoder`` (2D attention decoding).

    Args:
        out_channels (int): Number of output classes
            (charset size + unknown + start + padding).
        enc_bi_rnn (bool): Use a bidirectional RNN in the encoder.
        enc_drop_rnn (float): Dropout probability of the encoder RNN.
        enc_gru (bool): Use GRU (instead of LSTM) in the encoder.
        dec_bi_rnn (bool): Use a bidirectional RNN in the decoder.
        dec_drop_rnn (float): Dropout probability of the decoder RNN.
        dec_gru (bool): Use GRU (instead of LSTM) in the decoder.
        d_k (int): Dim of the attention projection in the decoder.
        pred_dropout (float): Dropout probability of the prediction layer.
        max_text_length (int): Maximum decoded sequence length.
        pred_concat (bool): Concatenate RNN output, attention feature and
            holistic feature before prediction.
    """

    def __init__(self,
                 out_channels,
                 enc_bi_rnn=False,
                 enc_drop_rnn=0.1,
                 enc_gru=False,
                 dec_bi_rnn=False,
                 dec_drop_rnn=0.0,
                 dec_gru=False,
                 d_k=512,
                 pred_dropout=0.1,
                 max_text_length=30,
                 pred_concat=True,
                 **kwargs):
        super(SARHead, self).__init__()

        # encoder module: produces the holistic (sentence-level) feature
        self.encoder = SAREncoder(
            enc_bi_rnn=enc_bi_rnn, enc_drop_rnn=enc_drop_rnn, enc_gru=enc_gru)

        # decoder module: 2D-attention autoregressive decoder
        self.decoder = ParallelSARDecoder(
            out_channels=out_channels,
            enc_bi_rnn=enc_bi_rnn,
            dec_bi_rnn=dec_bi_rnn,
            dec_drop_rnn=dec_drop_rnn,
            dec_gru=dec_gru,
            d_k=d_k,
            pred_dropout=pred_dropout,
            max_text_length=max_text_length,
            pred_concat=pred_concat)

    def forward(self, feat, targets=None):
        '''
        img_metas: [label, valid_ratio]
        '''
        # holistic_feat: bsz * C — presumably; verify against SAREncoder
        holistic_feat = self.encoder(feat, targets)  # bsz c

        # Single if/else instead of two independent `if` tests: the
        # condition is evaluated once and `final_out` is provably bound
        # on every path.
        if self.training:
            label = targets[0]  # label
            label = paddle.to_tensor(label, dtype='int64')
            final_out = self.decoder(
                feat, holistic_feat, label, img_metas=targets)
        else:
            final_out = self.decoder(
                feat,
                holistic_feat,
                label=None,
                img_metas=targets,
                train_mode=False)
            # (bsz, seq_len, num_classes)

        return final_out
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册