提交 a88ce7a5 编写于 作者: W WenmuZhou

修正对label decode时重复字符会消失的bug

上级 ca9ea622
...@@ -70,6 +70,7 @@ class BaseRecLabelDecode(object): ...@@ -70,6 +70,7 @@ class BaseRecLabelDecode(object):
if text_index[batch_idx][idx] in ignored_tokens: if text_index[batch_idx][idx] in ignored_tokens:
continue continue
if is_remove_duplicate: if is_remove_duplicate:
# only for predict
if idx > 0 and text_index[batch_idx][idx - 1] == text_index[ if idx > 0 and text_index[batch_idx][idx - 1] == text_index[
batch_idx][idx]: batch_idx][idx]:
continue continue
...@@ -107,7 +108,7 @@ class CTCLabelDecode(BaseRecLabelDecode): ...@@ -107,7 +108,7 @@ class CTCLabelDecode(BaseRecLabelDecode):
text = self.decode(preds_idx, preds_prob) text = self.decode(preds_idx, preds_prob)
if label is None: if label is None:
return text return text
label = self.decode(label) label = self.decode(label, is_remove_duplicate=False)
return text, label return text, label
def add_special_char(self, dict_character): def add_special_char(self, dict_character):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册