Commit d6115158 authored by andyjpaddle

fix code style

Parent ae09ef60
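
This commit is a mechanical re-wrap of multi-line call arguments: closing parentheses move onto the last argument line, and argument lists are packed up to the column limit. A minimal before/after sketch of the pattern, assuming a yapf-style formatter (the commit itself does not name the tool):

# Before (the style this commit removes): arguments one per line, with the
# closing parenthesis on its own line.
kwargs = dict(
    num_layers=2,
    time_major=False,
    dropout=0.1,
    direction="forward"
)

# After (the style this commit applies): the closing parenthesis hugs the
# last argument, and short argument lists are packed onto as few lines as
# fit under the column limit.
kwargs = dict(
    num_layers=2, time_major=False, dropout=0.1, direction="forward")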
@@ -19,6 +19,7 @@ class SAREncoder(nn.Layer):
         d_enc (int): Dim of encoder RNN layer.
         mask (bool): If True, mask padding in RNN sequence.
     """
+
     def __init__(self,
                  enc_bi_rnn=False,
                  enc_drop_rnn=0.1,
@@ -51,8 +52,7 @@ class SAREncoder(nn.Layer):
             num_layers=2,
             time_major=False,
             dropout=enc_drop_rnn,
-            direction=direction
-        )
+            direction=direction)
         if enc_gru:
             self.rnn_encoder = nn.GRU(**kwargs)
         else:
@@ -72,8 +72,7 @@ class SAREncoder(nn.Layer):

         h_feat = feat.shape[2]  # bsz c h w
         feat_v = F.max_pool2d(
-            feat, kernel_size=(h_feat, 1), stride=1, padding=0
-        )
+            feat, kernel_size=(h_feat, 1), stride=1, padding=0)
         feat_v = feat_v.squeeze(2)  # bsz * C * W
         feat_v = paddle.transpose(feat_v, perm=[0, 2, 1])  # bsz * W * C
         holistic_feat = self.rnn_encoder(feat_v)[0]  # bsz * T * C
@@ -135,7 +134,8 @@ class ParallelSARDecoder(BaseDecoder):
         attention with holistic feature and hidden state.
     """

-    def __init__(self,
+    def __init__(
+            self,
             out_channels,  # 90 + unknown + start + padding
             enc_bi_rnn=False,
             dec_bi_rnn=False,
@@ -165,7 +165,8 @@ class ParallelSARDecoder(BaseDecoder):

         # 2D attention layer
         self.conv1x1_1 = nn.Linear(decoder_rnn_out_size, d_k)
-        self.conv3x3_1 = nn.Conv2D(d_model, d_k, kernel_size=3, stride=1, padding=1)
+        self.conv3x3_1 = nn.Conv2D(
+            d_model, d_k, kernel_size=3, stride=1, padding=1)
         self.conv1x1_2 = nn.Linear(d_k, 1)

         # Decoder RNN layer
@@ -180,8 +181,7 @@ class ParallelSARDecoder(BaseDecoder):
             num_layers=2,
             time_major=False,
             dropout=dec_drop_rnn,
-            direction=direction
-        )
+            direction=direction)
         if dec_gru:
             self.rnn_decoder = nn.GRU(**kwargs)
         else:
@@ -189,7 +189,9 @@ class ParallelSARDecoder(BaseDecoder):

         # Decoder input embedding
         self.embedding = nn.Embedding(
-            self.num_classes, encoder_rnn_out_size, padding_idx=self.padding_idx)
+            self.num_classes,
+            encoder_rnn_out_size,
+            padding_idx=self.padding_idx)

         # Prediction layer
         self.pred_dropout = nn.Dropout(pred_dropout)
@@ -242,13 +244,16 @@ class ParallelSARDecoder(BaseDecoder):
         attn_weight = paddle.transpose(attn_weight, perm=[0, 1, 4, 2, 3])
         # attn_weight: bsz * T * c * h * w
         # feat: bsz * c * h * w
-        attn_feat = paddle.sum(paddle.multiply(feat.unsqueeze(1), attn_weight), (3, 4), keepdim=False)
+        attn_feat = paddle.sum(paddle.multiply(feat.unsqueeze(1), attn_weight),
+                               (3, 4),
+                               keepdim=False)
         # bsz * (seq_len + 1) * C

         # Linear transformation
         if self.pred_concat:
             hf_c = holistic_feat.shape[-1]
-            holistic_feat = paddle.expand(holistic_feat, shape=[bsz, seq_len, hf_c])
+            holistic_feat = paddle.expand(
+                holistic_feat, shape=[bsz, seq_len, hf_c])
             y = self.prediction(paddle.concat((y, attn_feat, holistic_feat), 2))
         else:
             y = self.prediction(attn_feat)
@@ -277,8 +282,7 @@ class ParallelSARDecoder(BaseDecoder):
         in_dec = paddle.concat((out_enc, lab_embedding), axis=1)
         # bsz * (seq_len + 1) * C
         out_dec = self._2d_attention(
-            in_dec, feat, out_enc, valid_ratios=valid_ratios
-        )
+            in_dec, feat, out_enc, valid_ratios=valid_ratios)
         # bsz * (seq_len + 1) * num_classes

         return out_dec[:, 1:, :]  # bsz * seq_len * num_classes
@@ -293,9 +297,8 @@ class ParallelSARDecoder(BaseDecoder):

         seq_len = self.max_seq_len
         bsz = feat.shape[0]
-        start_token = paddle.full((bsz, ),
-                                  fill_value=self.start_idx,
-                                  dtype='int64')
+        start_token = paddle.full(
+            (bsz, ), fill_value=self.start_idx, dtype='int64')
         # bsz
         start_token = self.embedding(start_token)
         # bsz * emb_dim
@@ -311,8 +314,7 @@ class ParallelSARDecoder(BaseDecoder):
         outputs = []
         for i in range(1, seq_len + 1):
             decoder_output = self._2d_attention(
-                decoder_input, feat, out_enc, valid_ratios=valid_ratios
-            )
+                decoder_input, feat, out_enc, valid_ratios=valid_ratios)
             char_output = decoder_output[:, i, :]  # bsz * num_classes
             char_output = F.softmax(char_output, -1)
             outputs.append(char_output)
@@ -344,9 +346,7 @@ class SARHead(nn.Layer):

         # encoder module
         self.encoder = SAREncoder(
-            enc_bi_rnn=enc_bi_rnn,
-            enc_drop_rnn=enc_drop_rnn,
-            enc_gru=enc_gru)
+            enc_bi_rnn=enc_bi_rnn, enc_drop_rnn=enc_drop_rnn, enc_gru=enc_gru)

         # decoder module
         self.decoder = ParallelSARDecoder(
@@ -369,10 +369,15 @@ class SARHead(nn.Layer):
         if self.training:
             label = targets[0]  # label
             label = paddle.to_tensor(label, dtype='int64')
-            final_out = self.decoder(feat, holistic_feat, label, img_metas=targets)
+            final_out = self.decoder(
+                feat, holistic_feat, label, img_metas=targets)
         if not self.training:
-            final_out = self.decoder(feat, holistic_feat, label=None, img_metas=targets, train_mode=False)
+            final_out = self.decoder(
+                feat,
+                holistic_feat,
+                label=None,
+                img_metas=targets,
+                train_mode=False)
         # (bsz, seq_len, num_classes)
         return final_out
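
To reproduce this kind of formatting pass locally, here is a minimal sketch using yapf's Python API. The tool choice, the file path, and the style options are assumptions for illustration, not facts recorded in the commit:

# Minimal sketch: apply a yapf pass to the file this commit touches.
# The path and style options below are assumed, not taken from the commit.
from yapf.yapflib.yapf_api import FormatCode

path = "ppocr/modeling/heads/rec_sar_head.py"  # assumed file location

with open(path) as f:
    source = f.read()

# An inline style string; 'pep8' with an 80-column limit would produce
# wrapping like that visible in the diff above.
formatted, changed = FormatCode(
    source, style_config="{based_on_style: pep8, column_limit: 80}")

if changed:
    with open(path, "w") as f:
        f.write(formatted)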