未验证 提交 5fd2de2e 编写于 作者: littletomatodonkey's avatar littletomatodonkey 提交者: GitHub

improve post process (#5758)

* improve post process

* rm unused code
上级 c2e45f2d
...@@ -54,22 +54,24 @@ class BaseRecLabelDecode(object): ...@@ -54,22 +54,24 @@ class BaseRecLabelDecode(object):
ignored_tokens = self.get_ignored_tokens() ignored_tokens = self.get_ignored_tokens()
batch_size = len(text_index) batch_size = len(text_index)
for batch_idx in range(batch_size): for batch_idx in range(batch_size):
char_list = [] selection = np.ones(len(text_index[batch_idx]), dtype=bool)
conf_list = []
for idx in range(len(text_index[batch_idx])):
if text_index[batch_idx][idx] in ignored_tokens:
continue
if is_remove_duplicate: if is_remove_duplicate:
# only for predict selection[1:] = text_index[batch_idx][1:] != text_index[
if idx > 0 and text_index[batch_idx][idx - 1] == text_index[ batch_idx][:-1]
batch_idx][idx]: for ignored_token in ignored_tokens:
continue selection &= text_index[batch_idx] != ignored_token
char_list.append(self.character[int(text_index[batch_idx][
idx])]) char_list = [
self.character[text_id]
for text_id in text_index[batch_idx][selection]
]
if text_prob is not None: if text_prob is not None:
conf_list.append(text_prob[batch_idx][idx]) conf_list = text_prob[batch_idx][selection]
else: else:
conf_list.append(1) conf_list = [1] * len(selection)
if len(conf_list) == 0:
conf_list = [0]
text = ''.join(char_list) text = ''.join(char_list)
result_list.append((text, np.mean(conf_list))) result_list.append((text, np.mean(conf_list)))
return result_list return result_list
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册