From 0b854bdb8b0aad6360cf2c15b1ca40b52a94d40c Mon Sep 17 00:00:00 2001 From: wanghaoshuang Date: Mon, 22 Jan 2018 09:37:23 +0800 Subject: [PATCH] Add sequence_erase option into edit distance python API --- python/paddle/v2/fluid/layers/nn.py | 23 ++++++++++++++++++++++- 1 file changed, 22 insertions(+), 1 deletion(-) diff --git a/python/paddle/v2/fluid/layers/nn.py b/python/paddle/v2/fluid/layers/nn.py index 5d05046bbac..c57811df1da 100644 --- a/python/paddle/v2/fluid/layers/nn.py +++ b/python/paddle/v2/fluid/layers/nn.py @@ -1864,7 +1864,7 @@ def matmul(x, y, transpose_x=False, transpose_y=False, name=None): return out -def edit_distance(input, label, normalized=False, name=None): +def edit_distance(input, label, normalized=False, tokens=None, name=None): """ EditDistance operator computes the edit distances between a batch of hypothesis strings and their references.Edit distance, also called Levenshtein distance, measures how dissimilar two strings are by counting the minimum number of operations to transform one string into anthor. Here the operations include insertion, deletion, and substitution. For example, given hypothesis string A = "kitten" and reference B = "sitting", the edit distance is 3 for A will be transformed into B at least after two substitutions and one insertion: @@ -1882,6 +1882,8 @@ def edit_distance(input, label, normalized=False, name=None): normalized(bool): Indicated whether to normalize the edit distance by the length of reference string. + tokens(list): Tokens that should be removed before calculating edit distance. + Returns: Variable: sequence-to-sequence edit distance loss in shape [batch_size, 1]. @@ -1895,6 +1897,25 @@ def edit_distance(input, label, normalized=False, name=None): """ helper = LayerHelper("edit_distance", **locals()) + # remove some tokens from input and labels + if tokens is not None and len(tokens) > 0: + erased_input = helper.create_tmp_variable(dtype="int64") + erased_label = helper.create_tmp_variable(dtype="int64") + + helper.append_op( + type="sequence_erase", + inputs={"X": [input]}, + outputs={"Out": [erased_input]}, + attrs={"tokens": tokens}) + input = erased_input + + helper.append_op( + type="sequence_erase", + inputs={"X": [label]}, + outputs={"Out": [erase_label]}, + attrs={"tokens": tokens}) + label = erased_label + # edit distance op edit_distance_out = helper.create_tmp_variable(dtype="int64") helper.append_op( -- GitLab