Commit d6115158 authored by andyjpaddle

fix code style

Parent ae09ef60
@@ -19,6 +19,7 @@ class SAREncoder(nn.Layer):
         d_enc (int): Dim of encoder RNN layer.
         mask (bool): If True, mask padding in RNN sequence.
     """
+
     def __init__(self,
                  enc_bi_rnn=False,
                  enc_drop_rnn=0.1,
@@ -51,8 +52,7 @@ class SAREncoder(nn.Layer):
             num_layers=2,
             time_major=False,
             dropout=enc_drop_rnn,
-            direction=direction
-        )
+            direction=direction)
         if enc_gru:
             self.rnn_encoder = nn.GRU(**kwargs)
         else:
@@ -72,8 +72,7 @@ class SAREncoder(nn.Layer):
         h_feat = feat.shape[2]  # bsz c h w
         feat_v = F.max_pool2d(
-            feat, kernel_size=(h_feat, 1), stride=1, padding=0
-        )
+            feat, kernel_size=(h_feat, 1), stride=1, padding=0)
         feat_v = feat_v.squeeze(2)  # bsz * C * W
         feat_v = paddle.transpose(feat_v, perm=[0, 2, 1])  # bsz * W * C
         holistic_feat = self.rnn_encoder(feat_v)[0]  # bsz * T * C
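Aside on the hunk above (not part of the commit): the encoder collapses the feature map's height with a max pool before feeding the RNN. A minimal shape sketch, assuming an arbitrary 2 x 512 x 8 x 32 backbone output:

import paddle
import paddle.nn.functional as F

feat = paddle.rand([2, 512, 8, 32])                 # bsz * C * H * W (assumed sizes)
h_feat = feat.shape[2]
feat_v = F.max_pool2d(feat, kernel_size=(h_feat, 1), stride=1, padding=0)
feat_v = feat_v.squeeze(2)                          # bsz * C * W
feat_v = paddle.transpose(feat_v, perm=[0, 2, 1])   # bsz * W * C, ready for the RNN
print(feat_v.shape)                                 # [2, 32, 512]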
@@ -135,7 +134,8 @@ class ParallelSARDecoder(BaseDecoder):
             attention with holistic feature and hidden state.
     """

-    def __init__(self,
-                 out_channels,  # 90 + unknown + start + padding
-                 enc_bi_rnn=False,
-                 dec_bi_rnn=False,
+    def __init__(
+            self,
+            out_channels,  # 90 + unknown + start + padding
+            enc_bi_rnn=False,
+            dec_bi_rnn=False,
@@ -165,7 +165,8 @@ class ParallelSARDecoder(BaseDecoder):
         # 2D attention layer
         self.conv1x1_1 = nn.Linear(decoder_rnn_out_size, d_k)
-        self.conv3x3_1 = nn.Conv2D(d_model, d_k, kernel_size=3, stride=1, padding=1)
+        self.conv3x3_1 = nn.Conv2D(
+            d_model, d_k, kernel_size=3, stride=1, padding=1)
         self.conv1x1_2 = nn.Linear(d_k, 1)

         # Decoder RNN layer
...@@ -180,8 +181,7 @@ class ParallelSARDecoder(BaseDecoder): ...@@ -180,8 +181,7 @@ class ParallelSARDecoder(BaseDecoder):
num_layers=2, num_layers=2,
time_major=False, time_major=False,
dropout=dec_drop_rnn, dropout=dec_drop_rnn,
direction=direction direction=direction)
)
if dec_gru: if dec_gru:
self.rnn_decoder = nn.GRU(**kwargs) self.rnn_decoder = nn.GRU(**kwargs)
else: else:
@@ -189,7 +189,9 @@ class ParallelSARDecoder(BaseDecoder):
         # Decoder input embedding
         self.embedding = nn.Embedding(
-            self.num_classes, encoder_rnn_out_size, padding_idx=self.padding_idx)
+            self.num_classes,
+            encoder_rnn_out_size,
+            padding_idx=self.padding_idx)

         # Prediction layer
         self.pred_dropout = nn.Dropout(pred_dropout)
@@ -242,13 +244,16 @@ class ParallelSARDecoder(BaseDecoder):
         attn_weight = paddle.transpose(attn_weight, perm=[0, 1, 4, 2, 3])
         # attn_weight: bsz * T * c * h * w
         # feat: bsz * c * h * w
-        attn_feat = paddle.sum(paddle.multiply(feat.unsqueeze(1), attn_weight), (3, 4), keepdim=False)
+        attn_feat = paddle.sum(paddle.multiply(feat.unsqueeze(1), attn_weight),
+                               (3, 4),
+                               keepdim=False)
         # bsz * (seq_len + 1) * C

         # Linear transformation
         if self.pred_concat:
             hf_c = holistic_feat.shape[-1]
-            holistic_feat = paddle.expand(holistic_feat, shape=[bsz, seq_len, hf_c])
+            holistic_feat = paddle.expand(
+                holistic_feat, shape=[bsz, seq_len, hf_c])
             y = self.prediction(paddle.concat((y, attn_feat, holistic_feat), 2))
         else:
             y = self.prediction(attn_feat)
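Side note on the attention hunk above: the reformatted paddle.sum(paddle.multiply(...)) call pools the feature map with the per-step attention weights over the spatial axes. A minimal shape sketch with made-up sizes (not taken from the SAR config):

import paddle

bsz, T, c, h, w = 2, 5, 8, 4, 16               # assumed toy sizes
feat = paddle.rand([bsz, c, h, w])             # backbone feature map
attn_weight = paddle.rand([bsz, T, c, h, w])   # 2D attention weights per decode step

# Broadcast feat across the time axis, weight it, then sum over h and w.
attn_feat = paddle.sum(
    paddle.multiply(feat.unsqueeze(1), attn_weight), (3, 4), keepdim=False)
print(attn_feat.shape)                         # [2, 5, 8] -> bsz * T * c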
@@ -277,8 +282,7 @@ class ParallelSARDecoder(BaseDecoder):
         in_dec = paddle.concat((out_enc, lab_embedding), axis=1)
         # bsz * (seq_len + 1) * C
         out_dec = self._2d_attention(
-            in_dec, feat, out_enc, valid_ratios=valid_ratios
-        )
+            in_dec, feat, out_enc, valid_ratios=valid_ratios)
         # bsz * (seq_len + 1) * num_classes

         return out_dec[:, 1:, :]  # bsz * seq_len * num_classes
@@ -293,9 +297,8 @@ class ParallelSARDecoder(BaseDecoder):
         seq_len = self.max_seq_len
         bsz = feat.shape[0]

-        start_token = paddle.full((bsz, ),
-                                  fill_value=self.start_idx,
-                                  dtype='int64')
+        start_token = paddle.full(
+            (bsz, ), fill_value=self.start_idx, dtype='int64')
         # bsz
         start_token = self.embedding(start_token)
         # bsz * emb_dim
@@ -311,8 +314,7 @@ class ParallelSARDecoder(BaseDecoder):
         outputs = []
         for i in range(1, seq_len + 1):
             decoder_output = self._2d_attention(
-                decoder_input, feat, out_enc, valid_ratios=valid_ratios
-            )
+                decoder_input, feat, out_enc, valid_ratios=valid_ratios)
             char_output = decoder_output[:, i, :]  # bsz * num_classes
             char_output = F.softmax(char_output, -1)
             outputs.append(char_output)
@@ -344,9 +346,7 @@ class SARHead(nn.Layer):
         # encoder module
         self.encoder = SAREncoder(
-            enc_bi_rnn=enc_bi_rnn,
-            enc_drop_rnn=enc_drop_rnn,
-            enc_gru=enc_gru)
+            enc_bi_rnn=enc_bi_rnn, enc_drop_rnn=enc_drop_rnn, enc_gru=enc_gru)

         # decoder module
         self.decoder = ParallelSARDecoder(
@@ -369,10 +369,15 @@ class SARHead(nn.Layer):
         if self.training:
             label = targets[0]  # label
             label = paddle.to_tensor(label, dtype='int64')
-            final_out = self.decoder(feat, holistic_feat, label, img_metas=targets)
+            final_out = self.decoder(
+                feat, holistic_feat, label, img_metas=targets)
         if not self.training:
-            final_out = self.decoder(feat, holistic_feat, label=None, img_metas=targets, train_mode=False)
+            final_out = self.decoder(
+                feat,
+                holistic_feat,
+                label=None,
+                img_metas=targets,
+                train_mode=False)
         # (bsz, seq_len, num_classes)
         return final_out
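For context (not part of this commit), a hedged sketch of how SARHead might be exercised at inference time; the import path, constructor signature, feature-map shape, and the targets=None convention are assumptions based on this file, not verified against the full repository:

import paddle
from ppocr.modeling.heads.rec_sar_head import SARHead  # assumed module path

head = SARHead(out_channels=93)      # e.g. 90 chars + unknown + start + padding (assumed signature)
head.eval()                          # self.training == False -> decoder runs in test mode
feat = paddle.rand([1, 512, 8, 32])  # bsz * C * H * W backbone output (assumed shape)
with paddle.no_grad():
    probs = head(feat, targets=None)  # (bsz, seq_len, num_classes)
print(probs.shape)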