From e760df4b1e59f98ee98b084746e3cf781e25cc9e Mon Sep 17 00:00:00 2001 From: andyjpaddle Date: Wed, 13 Oct 2021 11:40:38 +0000 Subject: [PATCH] add sar dict --- ppocr/losses/rec_sar_loss.py | 9 ++-- ppocr/utils/dict90.txt | 90 ++++++++++++++++++++++++++++++++++++ 2 files changed, 96 insertions(+), 3 deletions(-) create mode 100644 ppocr/utils/dict90.txt diff --git a/ppocr/losses/rec_sar_loss.py b/ppocr/losses/rec_sar_loss.py index 9e1c6495..c8bd8bb0 100644 --- a/ppocr/losses/rec_sar_loss.py +++ b/ppocr/losses/rec_sar_loss.py @@ -9,11 +9,14 @@ from paddle import nn class SARLoss(nn.Layer): def __init__(self, **kwargs): super(SARLoss, self).__init__() - self.loss_func = paddle.nn.loss.CrossEntropyLoss(reduction="mean", ignore_index=96) + self.loss_func = paddle.nn.loss.CrossEntropyLoss( + reduction="mean", ignore_index=92) def forward(self, predicts, batch): - predict = predicts[:, :-1, :] # ignore last index of outputs to be in same seq_len with targets - label = batch[1].astype("int64")[:, 1:] # ignore first index of target in loss calculation + predict = predicts[:, : + -1, :] # ignore last index of outputs to be in same seq_len with targets + label = batch[1].astype( + "int64")[:, 1:] # ignore first index of target in loss calculation batch_size, num_steps, num_classes = predict.shape[0], predict.shape[ 1], predict.shape[2] assert len(label.shape) == len(list(predict.shape)) - 1, \ diff --git a/ppocr/utils/dict90.txt b/ppocr/utils/dict90.txt new file mode 100644 index 00000000..a945ae9c --- /dev/null +++ b/ppocr/utils/dict90.txt @@ -0,0 +1,90 @@ +0 +1 +2 +3 +4 +5 +6 +7 +8 +9 +a +b +c +d +e +f +g +h +i +j +k +l +m +n +o +p +q +r +s +t +u +v +w +x +y +z +A +B +C +D +E +F +G +H +I +J +K +L +M +N +O +P +Q +R +S +T +U +V +W +X +Y +Z +! +" +# +$ +% +& +' +( +) +* ++ +, +- +. +/ +: +; +< += +> +? +@ +[ +\ +] +_ +` +~ \ No newline at end of file -- GitLab