diff --git a/ppocr/postprocess/rec_postprocess.py b/ppocr/postprocess/rec_postprocess.py index 93d385544e40af59a871d09ee6181888ce84691d..de771acca86a8956b06b366b840aac7e21f835a4 100644 --- a/ppocr/postprocess/rec_postprocess.py +++ b/ppocr/postprocess/rec_postprocess.py @@ -54,22 +54,24 @@ class BaseRecLabelDecode(object): ignored_tokens = self.get_ignored_tokens() batch_size = len(text_index) for batch_idx in range(batch_size): - char_list = [] - conf_list = [] - for idx in range(len(text_index[batch_idx])): - if text_index[batch_idx][idx] in ignored_tokens: - continue - if is_remove_duplicate: - # only for predict - if idx > 0 and text_index[batch_idx][idx - 1] == text_index[ - batch_idx][idx]: - continue - char_list.append(self.character[int(text_index[batch_idx][ - idx])]) - if text_prob is not None: - conf_list.append(text_prob[batch_idx][idx]) - else: - conf_list.append(1) + selection = np.ones(len(text_index[batch_idx]), dtype=bool) + if is_remove_duplicate: + selection[1:] = text_index[batch_idx][1:] != text_index[ + batch_idx][:-1] + for ignored_token in ignored_tokens: + selection &= text_index[batch_idx] != ignored_token + + char_list = [ + self.character[text_id] + for text_id in text_index[batch_idx][selection] + ] + if text_prob is not None: + conf_list = text_prob[batch_idx][selection] + else: + conf_list = [1] * len(selection) + if len(conf_list) == 0: + conf_list = [0] + text = ''.join(char_list) result_list.append((text, np.mean(conf_list))) return result_list