提交 0dd3919a 编写于 作者: W wanghaoshuang

Add python wrapper for ctc_evaluator

上级 144854d2
...@@ -50,6 +50,7 @@ __all__ = [ ...@@ -50,6 +50,7 @@ __all__ = [
'sequence_last_step', 'sequence_last_step',
'dropout', 'dropout',
'split', 'split',
'greedy_ctc_evaluator',
] ]
...@@ -1597,3 +1598,39 @@ def split(input, num_or_sections, dim=-1): ...@@ -1597,3 +1598,39 @@ def split(input, num_or_sections, dim=-1):
'axis': dim 'axis': dim
}) })
return outs return outs
def greedy_ctc_evaluator(input, label, blank, normalized=False, name=None):
"""
"""
helper = LayerHelper("greedy_ctc_evalutor", **locals())
# top 1 op
topk_out = helper.create_tmp_variable(dtype=input.dtype)
topk_indices = helper.create_tmp_variable(dtype="int64")
helper.append_op(
type="top_k",
inputs={"X": [input]},
outputs={"Out": [topk_out],
"Indices": [topk_indices]},
attrs={"k": 1})
# ctc align op
ctc_out = helper.create_tmp_variable(dtype="int64")
helper.append_op(
type="ctc_align",
inputs={"Input": [topk_indices]},
outputs={"Out": [ctc_out]},
attrs={"merge_repeated": True,
"blank": blank})
# edit distance op
edit_distance_out = helper.create_tmp_variable(dtype="int64")
helper.append_op(
type="edit_distance",
inputs={"Hyps": [ctc_out],
"Refs": [label]},
outputs={"Out": [edit_distance_out]},
attrs={"normalized": normalized})
return edit_distance_out
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册