From df4a2f6a7ee5efbf48a36b42a70c89511fdb4ac6 Mon Sep 17 00:00:00 2001 From: andyjpaddle Date: Tue, 7 Sep 2021 03:33:02 +0000 Subject: [PATCH] update rec_sar_head --- ppocr/losses/rec_sar_loss.py | 2 +- ppocr/modeling/heads/rec_sar_head.py | 18 +++++++++--------- 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/ppocr/losses/rec_sar_loss.py b/ppocr/losses/rec_sar_loss.py index 1afb21fe..9e1c6495 100644 --- a/ppocr/losses/rec_sar_loss.py +++ b/ppocr/losses/rec_sar_loss.py @@ -9,7 +9,7 @@ from paddle import nn class SARLoss(nn.Layer): def __init__(self, **kwargs): super(SARLoss, self).__init__() - self.loss_func = paddle.nn.loss.CrossEntropyLoss(reduction="mean", ignore_index=92) + self.loss_func = paddle.nn.loss.CrossEntropyLoss(reduction="mean", ignore_index=96) def forward(self, predicts, batch): predict = predicts[:, :-1, :] # ignore last index of outputs to be in same seq_len with targets diff --git a/ppocr/modeling/heads/rec_sar_head.py b/ppocr/modeling/heads/rec_sar_head.py index fb37b8ce..98b00ed0 100644 --- a/ppocr/modeling/heads/rec_sar_head.py +++ b/ppocr/modeling/heads/rec_sar_head.py @@ -118,8 +118,7 @@ class BaseDecoder(nn.Layer): class ParallelSARDecoder(BaseDecoder): """ Args: - num_classes (int): Output class number. - channels (list[int]): Network layer channels. + out_channels (int): Output class number. enc_bi_rnn (bool): If True, use bidirectional RNN in encoder. dec_bi_rnn (bool): If True, use bidirectional RNN in decoder. dec_drop_rnn (float): Dropout of RNN layer in decoder. @@ -137,7 +136,7 @@ class ParallelSARDecoder(BaseDecoder): """ def __init__(self, - num_classes=93, # 90 + unknown + start + padding + out_channels, # 90 + unknown + start + padding enc_bi_rnn=False, dec_bi_rnn=False, dec_drop_rnn=0.0, @@ -148,8 +147,6 @@ class ParallelSARDecoder(BaseDecoder): pred_dropout=0.1, max_text_length=30, mask=True, - start_idx=91, - padding_idx=92, # 92 pred_concat=True, **kwargs): super().__init__() @@ -157,7 +154,8 @@ class ParallelSARDecoder(BaseDecoder): self.num_classes = num_classes self.enc_bi_rnn = enc_bi_rnn self.d_k = d_k - self.start_idx = start_idx + self.start_idx = out_channels - 2 + self.padding_idx = out_channels - 1 self.max_seq_len = max_text_length self.mask = mask self.pred_concat = pred_concat @@ -191,7 +189,7 @@ class ParallelSARDecoder(BaseDecoder): # Decoder input embedding self.embedding = nn.Embedding( - self.num_classes, encoder_rnn_out_size, padding_idx=padding_idx) + self.num_classes, encoder_rnn_out_size, padding_idx=self.padding_idx) # Prediction layer self.pred_dropout = nn.Dropout(pred_dropout) @@ -330,6 +328,7 @@ class ParallelSARDecoder(BaseDecoder): class SARHead(nn.Layer): def __init__(self, + out_channels, enc_bi_rnn=False, enc_drop_rnn=0.1, enc_gru=False, @@ -351,7 +350,8 @@ class SARHead(nn.Layer): # decoder module self.decoder = ParallelSARDecoder( - enc_bi_rnn=enc_bi_rnn, + out_channels=out_channels, + enc_bi_rnn=enc_bi_rnn, dec_bi_rnn=dec_bi_rnn, dec_drop_rnn=dec_drop_rnn, dec_gru=dec_gru, @@ -375,4 +375,4 @@ class SARHead(nn.Layer): # (bsz, seq_len, num_classes) return final_out - \ No newline at end of file + -- GitLab