提交 0dd3919a 编写于 作者: W wanghaoshuang

Add python wrapper for ctc_evaluator

上级 144854d2
......@@ -50,6 +50,7 @@ __all__ = [
'sequence_last_step',
'dropout',
'split',
'greedy_ctc_evaluator',
]
......@@ -1547,13 +1548,13 @@ def split(input, num_or_sections, dim=-1):
Args:
input (Variable): The input variable which is a Tensor or LoDTensor.
num_or_sections (int|list): If :attr:`num_or_sections` is an integer,
then the integer indicates the number of equal sized sub-tensors
that the tensor will be divided into. If :attr:`num_or_sections`
is a list of integers, the length of list indicates the number of
sub-tensors and the integers indicate the sizes of sub-tensors'
num_or_sections (int|list): If :attr:`num_or_sections` is an integer,
then the integer indicates the number of equal sized sub-tensors
that the tensor will be divided into. If :attr:`num_or_sections`
is a list of integers, the length of list indicates the number of
sub-tensors and the integers indicate the sizes of sub-tensors'
:attr:`dim` dimension orderly.
dim (int): The dimension along which to split. If :math:`dim < 0`, the
dim (int): The dimension along which to split. If :math:`dim < 0`, the
dimension to split along is :math:`rank(input) + dim`.
Returns:
......@@ -1597,3 +1598,39 @@ def split(input, num_or_sections, dim=-1):
'axis': dim
})
return outs
def greedy_ctc_evaluator(input, label, blank, normalized=False, name=None):
"""
"""
helper = LayerHelper("greedy_ctc_evalutor", **locals())
# top 1 op
topk_out = helper.create_tmp_variable(dtype=input.dtype)
topk_indices = helper.create_tmp_variable(dtype="int64")
helper.append_op(
type="top_k",
inputs={"X": [input]},
outputs={"Out": [topk_out],
"Indices": [topk_indices]},
attrs={"k": 1})
# ctc align op
ctc_out = helper.create_tmp_variable(dtype="int64")
helper.append_op(
type="ctc_align",
inputs={"Input": [topk_indices]},
outputs={"Out": [ctc_out]},
attrs={"merge_repeated": True,
"blank": blank})
# edit distance op
edit_distance_out = helper.create_tmp_variable(dtype="int64")
helper.append_op(
type="edit_distance",
inputs={"Hyps": [ctc_out],
"Refs": [label]},
outputs={"Out": [edit_distance_out]},
attrs={"normalized": normalized})
return edit_distance_out
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册