diff --git a/ppocr/postprocess/rec_postprocess.py b/ppocr/postprocess/rec_postprocess.py index b0517982f00ff7e283b613309b3676d793e8b7ad..c769b7b4a3076645b0fefe27d1271dedd4ad2d19 100644 --- a/ppocr/postprocess/rec_postprocess.py +++ b/ppocr/postprocess/rec_postprocess.py @@ -216,6 +216,7 @@ class SRNLabelDecode(BaseRecLabelDecode): character_type='en', use_space_char=False, **kwargs): + self.max_text_length = kwargs['max_text_length'] super(SRNLabelDecode, self).__init__(character_dict_path, character_type, use_space_char) @@ -229,9 +230,9 @@ class SRNLabelDecode(BaseRecLabelDecode): preds_idx = np.argmax(pred, axis=1) preds_prob = np.max(pred, axis=1) - preds_idx = np.reshape(preds_idx, [-1, 25]) + preds_idx = np.reshape(preds_idx, [-1, self.max_text_length]) - preds_prob = np.reshape(preds_prob, [-1, 25]) + preds_prob = np.reshape(preds_prob, [-1, self.max_text_length]) text = self.decode(preds_idx, preds_prob)