From d611515803d71a847f7b262704fe7e92607b0ef6 Mon Sep 17 00:00:00 2001 From: andyjpaddle Date: Tue, 7 Sep 2021 06:13:56 +0000 Subject: [PATCH] fix code style --- ppocr/modeling/heads/rec_sar_head.py | 179 ++++++++++++++------------- 1 file changed, 92 insertions(+), 87 deletions(-) diff --git a/ppocr/modeling/heads/rec_sar_head.py b/ppocr/modeling/heads/rec_sar_head.py index ba0aa8eb..3c131c8b 100644 --- a/ppocr/modeling/heads/rec_sar_head.py +++ b/ppocr/modeling/heads/rec_sar_head.py @@ -19,6 +19,7 @@ class SAREncoder(nn.Layer): d_enc (int): Dim of encoder RNN layer. mask (bool): If True, mask padding in RNN sequence. """ + def __init__(self, enc_bi_rnn=False, enc_drop_rnn=0.1, @@ -51,33 +52,31 @@ class SAREncoder(nn.Layer): num_layers=2, time_major=False, dropout=enc_drop_rnn, - direction=direction - ) + direction=direction) if enc_gru: self.rnn_encoder = nn.GRU(**kwargs) else: self.rnn_encoder = nn.LSTM(**kwargs) - + # global feature transformation encoder_rnn_out_size = d_enc * (int(enc_bi_rnn) + 1) self.linear = nn.Linear(encoder_rnn_out_size, encoder_rnn_out_size) - + def forward(self, feat, img_metas=None): if img_metas is not None: assert len(img_metas[0]) == feat.shape[0] - + valid_ratios = None if img_metas is not None and self.mask: valid_ratios = img_metas[-1] - - h_feat = feat.shape[2] # bsz c h w + + h_feat = feat.shape[2] # bsz c h w feat_v = F.max_pool2d( - feat, kernel_size=(h_feat, 1), stride=1, padding=0 - ) - feat_v = feat_v.squeeze(2) # bsz * C * W - feat_v = paddle.transpose(feat_v, perm=[0, 2, 1]) # bsz * W * C - holistic_feat = self.rnn_encoder(feat_v)[0] # bsz * T * C - + feat, kernel_size=(h_feat, 1), stride=1, padding=0) + feat_v = feat_v.squeeze(2) # bsz * C * W + feat_v = paddle.transpose(feat_v, perm=[0, 2, 1]) # bsz * W * C + holistic_feat = self.rnn_encoder(feat_v)[0] # bsz * T * C + if valid_ratios is not None: valid_hf = [] T = holistic_feat.shape[1] @@ -86,11 +85,11 @@ class SAREncoder(nn.Layer): valid_hf.append(holistic_feat[i, valid_step, :]) valid_hf = paddle.stack(valid_hf, axis=0) else: - valid_hf = holistic_feat[:, -1, :] # bsz * C - holistic_feat = self.linear(valid_hf) # bsz * C - + valid_hf = holistic_feat[:, -1, :] # bsz * C + holistic_feat = self.linear(valid_hf) # bsz * C + return holistic_feat - + class BaseDecoder(nn.Layer): def __init__(self, **kwargs): @@ -102,7 +101,7 @@ class BaseDecoder(nn.Layer): def forward_test(self, feat, out_enc, img_metas): raise NotImplementedError - def forward(self, + def forward(self, feat, out_enc, label=None, @@ -135,20 +134,21 @@ class ParallelSARDecoder(BaseDecoder): attention with holistic feature and hidden state. """ - def __init__(self, - out_channels, # 90 + unknown + start + padding - enc_bi_rnn=False, - dec_bi_rnn=False, - dec_drop_rnn=0.0, - dec_gru=False, - d_model=512, - d_enc=512, - d_k=64, - pred_dropout=0.1, - max_text_length=30, - mask=True, - pred_concat=True, - **kwargs): + def __init__( + self, + out_channels, # 90 + unknown + start + padding + enc_bi_rnn=False, + dec_bi_rnn=False, + dec_drop_rnn=0.0, + dec_gru=False, + d_model=512, + d_enc=512, + d_k=64, + pred_dropout=0.1, + max_text_length=30, + mask=True, + pred_concat=True, + **kwargs): super().__init__() self.num_classes = out_channels @@ -165,7 +165,8 @@ class ParallelSARDecoder(BaseDecoder): # 2D attention layer self.conv1x1_1 = nn.Linear(decoder_rnn_out_size, d_k) - self.conv3x3_1 = nn.Conv2D(d_model, d_k, kernel_size=3, stride=1, padding=1) + self.conv3x3_1 = nn.Conv2D( + d_model, d_k, kernel_size=3, stride=1, padding=1) self.conv1x1_2 = nn.Linear(d_k, 1) # Decoder RNN layer @@ -180,8 +181,7 @@ class ParallelSARDecoder(BaseDecoder): num_layers=2, time_major=False, dropout=dec_drop_rnn, - direction=direction - ) + direction=direction) if dec_gru: self.rnn_decoder = nn.GRU(**kwargs) else: @@ -189,8 +189,10 @@ class ParallelSARDecoder(BaseDecoder): # Decoder input embedding self.embedding = nn.Embedding( - self.num_classes, encoder_rnn_out_size, padding_idx=self.padding_idx) - + self.num_classes, + encoder_rnn_out_size, + padding_idx=self.padding_idx) + # Prediction layer self.pred_dropout = nn.Dropout(pred_dropout) pred_num_classes = num_classes - 1 @@ -205,11 +207,11 @@ class ParallelSARDecoder(BaseDecoder): feat, holistic_feat, valid_ratios=None): - + y = self.rnn_decoder(decoder_input)[0] # y: bsz * (seq_len + 1) * hidden_size - - attn_query = self.conv1x1_1(y) # bsz * (seq_len + 1) * attn_size + + attn_query = self.conv1x1_1(y) # bsz * (seq_len + 1) * attn_size bsz, seq_len, attn_size = attn_query.shape attn_query = paddle.unsqueeze(attn_query, axis=[3, 4]) # (bsz, seq_len + 1, attn_size, 1, 1) @@ -220,7 +222,7 @@ class ParallelSARDecoder(BaseDecoder): # bsz * 1 * attn_size * h * w attn_weight = paddle.tanh(paddle.add(attn_key, attn_query)) - + # bsz * (seq_len + 1) * attn_size * h * w attn_weight = paddle.transpose(attn_weight, perm=[0, 1, 3, 4, 2]) # bsz * (seq_len + 1) * h * w * attn_size @@ -237,25 +239,28 @@ class ParallelSARDecoder(BaseDecoder): attn_weight = paddle.reshape(attn_weight, [bsz, T, -1]) attn_weight = F.softmax(attn_weight, axis=-1) - + attn_weight = paddle.reshape(attn_weight, [bsz, T, h, w, c]) attn_weight = paddle.transpose(attn_weight, perm=[0, 1, 4, 2, 3]) # attn_weight: bsz * T * c * h * w # feat: bsz * c * h * w - attn_feat = paddle.sum(paddle.multiply(feat.unsqueeze(1), attn_weight), (3, 4), keepdim=False) + attn_feat = paddle.sum(paddle.multiply(feat.unsqueeze(1), attn_weight), + (3, 4), + keepdim=False) # bsz * (seq_len + 1) * C # Linear transformation if self.pred_concat: hf_c = holistic_feat.shape[-1] - holistic_feat = paddle.expand(holistic_feat, shape=[bsz, seq_len, hf_c]) + holistic_feat = paddle.expand( + holistic_feat, shape=[bsz, seq_len, hf_c]) y = self.prediction(paddle.concat((y, attn_feat, holistic_feat), 2)) else: y = self.prediction(attn_feat) # bsz * (seq_len + 1) * num_classes if self.train_mode: y = self.pred_dropout(y) - + return y def forward_train(self, feat, out_enc, label, img_metas): @@ -268,7 +273,7 @@ class ParallelSARDecoder(BaseDecoder): valid_ratios = None if img_metas is not None and self.mask: valid_ratios = img_metas[-1] - + label = label.cuda() lab_embedding = self.embedding(label) # bsz * seq_len * emb_dim @@ -277,11 +282,10 @@ class ParallelSARDecoder(BaseDecoder): in_dec = paddle.concat((out_enc, lab_embedding), axis=1) # bsz * (seq_len + 1) * C out_dec = self._2d_attention( - in_dec, feat, out_enc, valid_ratios=valid_ratios - ) + in_dec, feat, out_enc, valid_ratios=valid_ratios) # bsz * (seq_len + 1) * num_classes - - return out_dec[:, 1:, :] # bsz * seq_len * num_classes + + return out_dec[:, 1:, :] # bsz * seq_len * num_classes def forward_test(self, feat, out_enc, img_metas): if img_metas is not None: @@ -289,13 +293,12 @@ class ParallelSARDecoder(BaseDecoder): valid_ratios = None if img_metas is not None and self.mask: - valid_ratios = img_metas[-1] - + valid_ratios = img_metas[-1] + seq_len = self.max_seq_len bsz = feat.shape[0] - start_token = paddle.full((bsz, ), - fill_value=self.start_idx, - dtype='int64') + start_token = paddle.full( + (bsz, ), fill_value=self.start_idx, dtype='int64') # bsz start_token = self.embedding(start_token) # bsz * emb_dim @@ -311,68 +314,70 @@ class ParallelSARDecoder(BaseDecoder): outputs = [] for i in range(1, seq_len + 1): decoder_output = self._2d_attention( - decoder_input, feat, out_enc, valid_ratios=valid_ratios - ) - char_output = decoder_output[:, i, :] # bsz * num_classes + decoder_input, feat, out_enc, valid_ratios=valid_ratios) + char_output = decoder_output[:, i, :] # bsz * num_classes char_output = F.softmax(char_output, -1) outputs.append(char_output) max_idx = paddle.argmax(char_output, axis=1, keepdim=False) - char_embedding = self.embedding(max_idx) # bsz * emb_dim + char_embedding = self.embedding(max_idx) # bsz * emb_dim if i < seq_len: decoder_input[:, i + 1, :] = char_embedding - - outputs = paddle.stack(outputs, 1) # bsz * seq_len * num_classes + + outputs = paddle.stack(outputs, 1) # bsz * seq_len * num_classes return outputs class SARHead(nn.Layer): - def __init__(self, - out_channels, - enc_bi_rnn=False, - enc_drop_rnn=0.1, - enc_gru=False, - dec_bi_rnn=False, - dec_drop_rnn=0.0, - dec_gru=False, - d_k=512, - pred_dropout=0.1, - max_text_length=30, - pred_concat=True, - **kwargs): + def __init__(self, + out_channels, + enc_bi_rnn=False, + enc_drop_rnn=0.1, + enc_gru=False, + dec_bi_rnn=False, + dec_drop_rnn=0.0, + dec_gru=False, + d_k=512, + pred_dropout=0.1, + max_text_length=30, + pred_concat=True, + **kwargs): super(SARHead, self).__init__() # encoder module self.encoder = SAREncoder( - enc_bi_rnn=enc_bi_rnn, - enc_drop_rnn=enc_drop_rnn, - enc_gru=enc_gru) + enc_bi_rnn=enc_bi_rnn, enc_drop_rnn=enc_drop_rnn, enc_gru=enc_gru) # decoder module self.decoder = ParallelSARDecoder( out_channels=out_channels, - enc_bi_rnn=enc_bi_rnn, + enc_bi_rnn=enc_bi_rnn, dec_bi_rnn=dec_bi_rnn, dec_drop_rnn=dec_drop_rnn, dec_gru=dec_gru, d_k=d_k, pred_dropout=pred_dropout, max_text_length=max_text_length, - pred_concat=pred_concat) - + pred_concat=pred_concat) + def forward(self, feat, targets=None): ''' img_metas: [label, valid_ratio] ''' - holistic_feat = self.encoder(feat, targets) # bsz c - + holistic_feat = self.encoder(feat, targets) # bsz c + if self.training: - label = targets[0] # label + label = targets[0] # label label = paddle.to_tensor(label, dtype='int64') - final_out = self.decoder(feat, holistic_feat, label, img_metas=targets) + final_out = self.decoder( + feat, holistic_feat, label, img_metas=targets) if not self.training: - final_out = self.decoder(feat, holistic_feat, label=None, img_metas=targets, train_mode=False) + final_out = self.decoder( + feat, + holistic_feat, + label=None, + img_metas=targets, + train_mode=False) # (bsz, seq_len, num_classes) - + return final_out - -- GitLab