提交 8143a426 编写于 作者: W wanghaoshuang

1. Add more comments

上级 1bc8de32
...@@ -208,20 +208,46 @@ class ChunkEvaluator(Evaluator): ...@@ -208,20 +208,46 @@ class ChunkEvaluator(Evaluator):
class EditDistance(Evaluator): class EditDistance(Evaluator):
""" """
Average edit distance error for multiple mini-batches. Accumulate edit distance sum and sequence number from mini-batches and
compute the average edit_distance of all batches.
Args:
input: the sequences predicted by network
label: the target sequences which must has same sequence count
with input.
ignored_tokens(list of int): Tokens that should be removed before
calculating edit distance.
Example:
exe = fluid.executor(place)
distance_evaluator = fluid.Evaluator.EditDistance(input, label)
for epoch in PASS_NUM:
distance_evaluator.reset(exe)
for data in batches:
loss, sum_distance = exe.run(fetch_list=[cost] + distance_evaluator.metrics)
avg_distance = distance_evaluator.eval(exe)
pass_distance = distance_evaluator.eval(exe)
In the above example:
'sum_distance' is the sum of the batch's edit distance.
'avg_distance' is the average of edit distance from the firt batch to the current batch.
'pass_distance' is the average of edit distance from all the pass.
""" """
def __init__(self, input, label, k=1, **kwargs): def __init__(self, input, label, ignored_tokens=None, **kwargs):
super(EditDistance, self).__init__("edit_distance", **kwargs) super(EditDistance, self).__init__("edit_distance", **kwargs)
main_program = self.helper.main_program main_program = self.helper.main_program
if main_program.current_block().idx != 0: if main_program.current_block().idx != 0:
raise ValueError("You can only invoke Evaluator in root block") raise ValueError("You can only invoke Evaluator in root block")
self.total_error = self.create_state( self.total_error = self.create_state(
dtype='float32', shape=[1], suffix='total') dtype='float32', shape=[1], suffix='total_error')
self.seq_num = self.create_state( self.seq_num = self.create_state(
dtype='int64', shape=[1], suffix='total') dtype='int64', shape=[1], suffix='seq_num')
error, seq_num = layers.edit_distance(input=input, label=label) error, seq_num = layers.edit_distance(
input=input, label=label, ignored_tokens=ignored_tokens)
#error = layers.cast(x=error, dtype='float32') #error = layers.cast(x=error, dtype='float32')
sum_error = layers.reduce_sum(error) sum_error = layers.reduce_sum(error)
layers.sums(input=[self.total_error, sum_error], out=self.total_error) layers.sums(input=[self.total_error, sum_error], out=self.total_error)
......
...@@ -1864,7 +1864,11 @@ def matmul(x, y, transpose_x=False, transpose_y=False, name=None): ...@@ -1864,7 +1864,11 @@ def matmul(x, y, transpose_x=False, transpose_y=False, name=None):
return out return out
def edit_distance(input, label, normalized=False, tokens=None, name=None): def edit_distance(input,
label,
normalized=False,
ignored_tokens=None,
name=None):
""" """
EditDistance operator computes the edit distances between a batch of hypothesis strings and their references.Edit distance, also called Levenshtein distance, measures how dissimilar two strings are by counting the minimum number of operations to transform one string into anthor. Here the operations include insertion, deletion, and substitution. For example, given hypothesis string A = "kitten" and reference B = "sitting", the edit distance is 3 for A will be transformed into B at least after two substitutions and one insertion: EditDistance operator computes the edit distances between a batch of hypothesis strings and their references.Edit distance, also called Levenshtein distance, measures how dissimilar two strings are by counting the minimum number of operations to transform one string into anthor. Here the operations include insertion, deletion, and substitution. For example, given hypothesis string A = "kitten" and reference B = "sitting", the edit distance is 3 for A will be transformed into B at least after two substitutions and one insertion:
...@@ -1882,10 +1886,10 @@ def edit_distance(input, label, normalized=False, tokens=None, name=None): ...@@ -1882,10 +1886,10 @@ def edit_distance(input, label, normalized=False, tokens=None, name=None):
normalized(bool): Indicated whether to normalize the edit distance by the length of reference string. normalized(bool): Indicated whether to normalize the edit distance by the length of reference string.
tokens(list): Tokens that should be removed before calculating edit distance. ignored_tokens(list of int): Tokens that should be removed before calculating edit distance.
Returns: Returns:
Variable: sequence-to-sequence edit distance loss in shape [batch_size, 1]. Variable: sequence-to-sequence edit distance in shape [batch_size, 1].
Examples: Examples:
.. code-block:: python .. code-block:: python
...@@ -1898,7 +1902,7 @@ def edit_distance(input, label, normalized=False, tokens=None, name=None): ...@@ -1898,7 +1902,7 @@ def edit_distance(input, label, normalized=False, tokens=None, name=None):
helper = LayerHelper("edit_distance", **locals()) helper = LayerHelper("edit_distance", **locals())
# remove some tokens from input and labels # remove some tokens from input and labels
if tokens is not None and len(tokens) > 0: if ignored_tokens is not None and len(ignored_tokens) > 0:
erased_input = helper.create_tmp_variable(dtype="int64") erased_input = helper.create_tmp_variable(dtype="int64")
erased_label = helper.create_tmp_variable(dtype="int64") erased_label = helper.create_tmp_variable(dtype="int64")
...@@ -1906,14 +1910,14 @@ def edit_distance(input, label, normalized=False, tokens=None, name=None): ...@@ -1906,14 +1910,14 @@ def edit_distance(input, label, normalized=False, tokens=None, name=None):
type="sequence_erase", type="sequence_erase",
inputs={"X": [input]}, inputs={"X": [input]},
outputs={"Out": [erased_input]}, outputs={"Out": [erased_input]},
attrs={"tokens": tokens}) attrs={"tokens": ignored_tokens})
input = erased_input input = erased_input
helper.append_op( helper.append_op(
type="sequence_erase", type="sequence_erase",
inputs={"X": [label]}, inputs={"X": [label]},
outputs={"Out": [erase_label]}, outputs={"Out": [erase_label]},
attrs={"tokens": tokens}) attrs={"tokens": ignored_tokens})
label = erased_label label = erased_label
# edit distance op # edit distance op
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册