diff --git a/python/paddle/v2/fluid/layers/nn.py b/python/paddle/v2/fluid/layers/nn.py index 5d05046bbac825e09aeb3dab18b9812edbabd732..c57811df1da14aacaeac85d661009176e7da0ccf 100644 --- a/python/paddle/v2/fluid/layers/nn.py +++ b/python/paddle/v2/fluid/layers/nn.py @@ -1864,7 +1864,7 @@ def matmul(x, y, transpose_x=False, transpose_y=False, name=None): return out -def edit_distance(input, label, normalized=False, name=None): +def edit_distance(input, label, normalized=False, tokens=None, name=None): """ EditDistance operator computes the edit distances between a batch of hypothesis strings and their references.Edit distance, also called Levenshtein distance, measures how dissimilar two strings are by counting the minimum number of operations to transform one string into anthor. Here the operations include insertion, deletion, and substitution. For example, given hypothesis string A = "kitten" and reference B = "sitting", the edit distance is 3 for A will be transformed into B at least after two substitutions and one insertion: @@ -1882,6 +1882,8 @@ def edit_distance(input, label, normalized=False, name=None): normalized(bool): Indicated whether to normalize the edit distance by the length of reference string. + tokens(list): Tokens that should be removed before calculating edit distance. + Returns: Variable: sequence-to-sequence edit distance loss in shape [batch_size, 1]. @@ -1895,6 +1897,25 @@ def edit_distance(input, label, normalized=False, name=None): """ helper = LayerHelper("edit_distance", **locals()) + # remove some tokens from input and labels + if tokens is not None and len(tokens) > 0: + erased_input = helper.create_tmp_variable(dtype="int64") + erased_label = helper.create_tmp_variable(dtype="int64") + + helper.append_op( + type="sequence_erase", + inputs={"X": [input]}, + outputs={"Out": [erased_input]}, + attrs={"tokens": tokens}) + input = erased_input + + helper.append_op( + type="sequence_erase", + inputs={"X": [label]}, + outputs={"Out": [erase_label]}, + attrs={"tokens": tokens}) + label = erased_label + # edit distance op edit_distance_out = helper.create_tmp_variable(dtype="int64") helper.append_op(