未验证 提交 5fd2de2e 编写于 作者: littletomatodonkey's avatar littletomatodonkey 提交者: GitHub

improve post process (#5758)

* improve post process

* rm unused code
上级 c2e45f2d
...@@ -54,22 +54,24 @@ class BaseRecLabelDecode(object): ...@@ -54,22 +54,24 @@ class BaseRecLabelDecode(object):
ignored_tokens = self.get_ignored_tokens() ignored_tokens = self.get_ignored_tokens()
batch_size = len(text_index) batch_size = len(text_index)
for batch_idx in range(batch_size): for batch_idx in range(batch_size):
char_list = [] selection = np.ones(len(text_index[batch_idx]), dtype=bool)
conf_list = [] if is_remove_duplicate:
for idx in range(len(text_index[batch_idx])): selection[1:] = text_index[batch_idx][1:] != text_index[
if text_index[batch_idx][idx] in ignored_tokens: batch_idx][:-1]
continue for ignored_token in ignored_tokens:
if is_remove_duplicate: selection &= text_index[batch_idx] != ignored_token
# only for predict
if idx > 0 and text_index[batch_idx][idx - 1] == text_index[ char_list = [
batch_idx][idx]: self.character[text_id]
continue for text_id in text_index[batch_idx][selection]
char_list.append(self.character[int(text_index[batch_idx][ ]
idx])]) if text_prob is not None:
if text_prob is not None: conf_list = text_prob[batch_idx][selection]
conf_list.append(text_prob[batch_idx][idx]) else:
else: conf_list = [1] * len(selection)
conf_list.append(1) if len(conf_list) == 0:
conf_list = [0]
text = ''.join(char_list) text = ''.join(char_list)
result_list.append((text, np.mean(conf_list))) result_list.append((text, np.mean(conf_list)))
return result_list return result_list
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册