diff --git a/paddle/fluid/API.spec b/paddle/fluid/API.spec index 4cf1850a3344251b2fe3b22e4b94d156ca2518e6..e4b046be81899363000d7bd9e1a0ef85f5f0dc9a 100644 --- a/paddle/fluid/API.spec +++ b/paddle/fluid/API.spec @@ -139,7 +139,7 @@ paddle.fluid.layers.sequence_slice (ArgSpec(args=['input', 'offset', 'length', ' paddle.fluid.layers.dropout (ArgSpec(args=['x', 'dropout_prob', 'is_test', 'seed', 'name', 'dropout_implementation'], varargs=None, keywords=None, defaults=(False, None, None, 'downgrade_in_infer')), ('document', '558d13133596209190df9a624264f28f')) paddle.fluid.layers.split (ArgSpec(args=['input', 'num_or_sections', 'dim', 'name'], varargs=None, keywords=None, defaults=(-1, None)), ('document', '78cf3a7323d1a7697658242e13f63759')) paddle.fluid.layers.ctc_greedy_decoder (ArgSpec(args=['input', 'blank', 'name'], varargs=None, keywords=None, defaults=(None,)), ('document', '2bc3a59efa9d52b628a6255422d9f0e8')) -paddle.fluid.layers.edit_distance (ArgSpec(args=['input', 'label', 'normalized', 'ignored_tokens'], varargs=None, keywords=None, defaults=(True, None)), ('document', 'f2c252aa2f83f8e503ffaf79668eaa28')) +paddle.fluid.layers.edit_distance (ArgSpec(args=['input', 'label', 'normalized', 'ignored_tokens', 'input_length', 'label_length'], varargs=None, keywords=None, defaults=(True, None, None, None)), ('document', '77cbfb28cd2fc589f589c7013c5086cd')) paddle.fluid.layers.l2_normalize (ArgSpec(args=['x', 'axis', 'epsilon', 'name'], varargs=None, keywords=None, defaults=(1e-12, None)), ('document', 'c1df110ea65998984f564c5c10abc54a')) paddle.fluid.layers.matmul (ArgSpec(args=['x', 'y', 'transpose_x', 'transpose_y', 'alpha', 'name'], varargs=None, keywords=None, defaults=(False, False, 1.0, None)), ('document', 'fa2081f6e731bb9de7cd535ca07f523a')) paddle.fluid.layers.topk (ArgSpec(args=['input', 'k', 'name'], varargs=None, keywords=None, defaults=(None,)), ('document', 'e50940f3ce5a08cc477b72f517491bf3')) diff --git a/paddle/fluid/operators/edit_distance_op.cc b/paddle/fluid/operators/edit_distance_op.cc index de25a3dab53492e38a92fbcf07ccbe43f7546950..a854d470dddab074813d99f8c64d2e68ec291892 100644 --- a/paddle/fluid/operators/edit_distance_op.cc +++ b/paddle/fluid/operators/edit_distance_op.cc @@ -29,12 +29,30 @@ class EditDistanceOp : public framework::OperatorWithKernel { "Output(SequenceNum) shouldn't be null."); auto hyp_dims = ctx->GetInputDim("Hyps"); auto ref_dims = ctx->GetInputDim("Refs"); - PADDLE_ENFORCE(hyp_dims.size() == 2 && hyp_dims[1] == 1, - "Input(Hyps) must be a 2-D LoDTensor with the 2nd dimension " - "equal to 1."); - PADDLE_ENFORCE(ref_dims.size() == 2 && ref_dims[1] == 1, - "Input(Refs) must be a 2-D LoDTensor with the 2nd dimension " - "equal to 1."); + + if (ctx->HasInput("HypsLength") && ctx->HasInput("RefsLength")) { + auto hyp_length_dims = ctx->GetInputDim("HypsLength"); + auto ref_length_dims = ctx->GetInputDim("RefsLength"); + + PADDLE_ENFORCE(hyp_dims.size() == 2 && ref_dims.size() == 2 && + hyp_dims[0] == ref_dims[0], + "Input(Hyps) and Input(Refs) must be 2-D Tensors with " + "identical first dimension"); + PADDLE_ENFORCE(hyp_length_dims[0] == ref_length_dims[0] && + hyp_length_dims[0] == hyp_dims[0], + "Input(HypsLength), Input(RefsLength) and Input(Hyps) " + "should have identical first dimension"); + } else { + PADDLE_ENFORCE( + hyp_dims.size() == 2 && hyp_dims[1] == 1, + "Input(Hyps) must be a 2-D LoDTensor with the 2nd dimension " + "equal to 1."); + PADDLE_ENFORCE( + ref_dims.size() == 2 && ref_dims[1] == 1, + "Input(Refs) must be a 2-D LoDTensor with the 2nd dimension " + "equal to 1."); + } + ctx->SetOutputDim("Out", ctx->GetInputDim("Refs")); ctx->SetOutputDim("SequenceNum", {1}); } @@ -51,11 +69,21 @@ class EditDistanceOpMaker : public framework::OpProtoAndCheckerMaker { public: void Make() override { AddInput("Hyps", - "(2-D LoDTensor, 2nd dim. equal to 1) " + "2-D Tensor, or 2-D LoDTensor with last " + "dimension being 1. " "The indices for hypothesis strings."); AddInput("Refs", - "(2-D LoDTensor, 2nd dim. equal to 1) " + "2-D Tensor, or 2-D LoDTensor with last " + "dimension being 1. " "The indices for reference strings."); + AddInput("HypsLength", + "1-D Tensor. " + "Sequence length for hyps when hyps is a tensor") + .AsDispensable(); + AddInput("RefsLength", + "1-D Tensor. " + "Sequence length for refs when refs is a tensor") + .AsDispensable(); AddOutput("SequenceNum", "The sequence count of current batch"); AddAttr("normalized", "(bool, default false) Indicated whether to normalize " @@ -78,12 +106,11 @@ insertion: "kitten" -> "sitten" -> "sittin" -> "sitting" -Input(Hyps) is a LoDTensor consisting of all the hypothesis strings with the total -number denoted by `batch_size`, and the separation is specified by the LoD information. +Input(Hyps) is a 2-D Tensor or a 2-D LoDTensor consisting of all the hypothesis strings. And the `batch_size` reference strings are arranged in order in the same way in the -LoDTensor Input(Refs). +Input(Refs). -Output(Out) contains the `batch_size` results and each stands for the edit stance +Output(Out) contains the `batch_size` results and each stands for the edit distance for a pair of strings respectively. If Attr(normalized) is true, the edit distance will be divided by the length of reference string. )DOC"); diff --git a/paddle/fluid/operators/edit_distance_op.cu b/paddle/fluid/operators/edit_distance_op.cu index c25b7d2f9ec32bcef44db239de43feefd855bfe5..c7217b9f750b5a83f95b8df161de23a89241925d 100644 --- a/paddle/fluid/operators/edit_distance_op.cu +++ b/paddle/fluid/operators/edit_distance_op.cu @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. */ #include +#include "paddle/fluid/framework/mixed_vector.h" #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/operators/edit_distance_op.h" #include "paddle/fluid/operators/math/math_function.h" @@ -76,20 +77,43 @@ class EditDistanceGPUKernel : public framework::OpKernel { auto* x2_t = ctx.Input("Refs"); auto* sequence_num = ctx.Output("SequenceNum"); sequence_num->mutable_data(ctx.GetPlace()); + auto batch_size = x1_t->dims()[0]; auto normalized = ctx.Attr("normalized"); auto stream = reinterpret_cast( ctx.device_context()) .stream(); - auto hyp_lod = x1_t->lod()[0]; - auto ref_lod = x2_t->lod()[0]; - PADDLE_ENFORCE( - hyp_lod.size() == ref_lod.size(), - "Input(Hyps) and Input(Refs) must have the same batch size."); - for (size_t i = 1; i < ref_lod.size(); ++i) { - PADDLE_ENFORCE(ref_lod[i] > ref_lod[i - 1], - "Reference string %d is empty.", i); + framework::Vector hyp_lod(batch_size + 1); + framework::Vector ref_lod(batch_size + 1); + + bool use_length = ctx.HasInput("HypsLength"); + + if (use_length) { + // build lod when using padding + auto* hyp_length = ctx.Input("HypsLength"); + auto* ref_length = ctx.Input("RefsLength"); + + framework::Tensor hyp_length_cpu; + framework::Tensor ref_length_cpu; + framework::TensorCopy(*hyp_length, platform::CPUPlace(), &hyp_length_cpu); + framework::TensorCopy(*ref_length, platform::CPUPlace(), &ref_length_cpu); + + for (auto i = 0; i < batch_size; i++) { + hyp_lod[i + 1] = hyp_lod[i] + hyp_length_cpu.data()[i]; + ref_lod[i + 1] = ref_lod[i] + ref_length_cpu.data()[i]; + } + + } else { + hyp_lod = x1_t->lod()[0]; + ref_lod = x2_t->lod()[0]; + } + + if (normalized) { + for (size_t i = 1; i < ref_lod.size(); ++i) { + PADDLE_ENFORCE(ref_lod[i] > ref_lod[i - 1], + "Reference string %d is empty.", i); + } } const size_t num_strs = hyp_lod.size() - 1; @@ -108,10 +132,6 @@ class EditDistanceGPUKernel : public framework::OpKernel { if (m == 0 || n == 0) { distance = std::max(m, n); if (normalized) { - PADDLE_ENFORCE(n > 0, - "The reference string (#%d) cannot be empty " - "when Attr(normalized) is enabled.", - n); distance = distance / n; } memory::Copy(boost::get(ctx.GetPlace()), out + num, @@ -121,14 +141,17 @@ class EditDistanceGPUKernel : public framework::OpKernel { dist_t.Resize({m + 1, n + 1}); dist_t.mutable_data(ctx.GetPlace()); auto dist = dist_t.data(); - auto x1 = x1_t->data() + hyp_lod[num]; - auto x2 = x2_t->data() + ref_lod[num]; + auto hyp_offset = use_length ? num * x1_t->dims()[1] : hyp_lod[num]; + auto ref_offset = use_length ? num * x2_t->dims()[1] : ref_lod[num]; + auto x1 = x1_t->data() + hyp_offset; + auto x2 = x2_t->data() + ref_offset; FillFirstColumn<<<1 + m / PADDLE_CUDA_NUM_THREADS, PADDLE_CUDA_NUM_THREADS, 0, stream>>>(dist, m, n); FillFirstRow<<<1 + n / PADDLE_CUDA_NUM_THREADS, PADDLE_CUDA_NUM_THREADS, 0, stream>>>(dist, n); + // Compute the elements of distance matrix in the anti-diagonal diretion for (int64_t slice = 2; slice < m + n + 1; ++slice) { int z_m = slice < m + 1 ? 0 : slice - m; diff --git a/paddle/fluid/operators/edit_distance_op.h b/paddle/fluid/operators/edit_distance_op.h index 73d0af490b3730c01fcd2842ced388583b7acbe6..3e1aec7ceeec781dbf00ac5a24a8a4e95c999850 100644 --- a/paddle/fluid/operators/edit_distance_op.h +++ b/paddle/fluid/operators/edit_distance_op.h @@ -15,6 +15,7 @@ limitations under the License. */ #pragma once #include #include "paddle/fluid/framework/eigen.h" +#include "paddle/fluid/framework/mixed_vector.h" #include "paddle/fluid/framework/op_registry.h" namespace paddle { namespace operators { @@ -29,17 +30,37 @@ class EditDistanceKernel : public framework::OpKernel { auto* x2_t = ctx.Input("Refs"); auto* sequence_num = ctx.Output("SequenceNum"); int64_t* seq_num_data = sequence_num->mutable_data(ctx.GetPlace()); + auto batch_size = x1_t->dims()[0]; auto normalized = ctx.Attr("normalized"); - auto hyp_lod = x1_t->lod()[0]; - auto ref_lod = x2_t->lod()[0]; - PADDLE_ENFORCE( - hyp_lod.size() == ref_lod.size(), - "Input(Hyps) and Input(Refs) must have the same batch size."); - for (size_t i = 1; i < ref_lod.size(); ++i) { - PADDLE_ENFORCE(ref_lod[i] > ref_lod[i - 1], - "Reference string %d is empty.", i); + framework::Vector hyp_lod(batch_size + 1); + framework::Vector ref_lod(batch_size + 1); + + bool use_length = ctx.HasInput("HypsLength"); + + if (use_length) { + // build lod when using padding + auto hyp_length_ptr = + ctx.Input("HypsLength")->data(); + auto ref_length_ptr = + ctx.Input("RefsLength")->data(); + + for (auto i = 0; i < batch_size; i++) { + hyp_lod[i + 1] = hyp_lod[i] + hyp_length_ptr[i]; + ref_lod[i + 1] = ref_lod[i] + ref_length_ptr[i]; + } + + } else { + hyp_lod = x1_t->lod()[0]; + ref_lod = x2_t->lod()[0]; + } + + if (normalized) { + for (size_t i = 1; i < ref_lod.size(); ++i) { + PADDLE_ENFORCE(ref_lod[i] > ref_lod[i - 1], + "Reference string %d is empty.", i); + } } auto num_strs = hyp_lod.size() - 1; *seq_num_data = static_cast(num_strs); @@ -62,8 +83,10 @@ class EditDistanceKernel : public framework::OpKernel { dist_t.Resize({m + 1, n + 1}); dist_t.mutable_data(ctx.GetPlace()); auto dist = dist_t.data(); - auto x1 = x1_t->data() + hyp_lod[num]; - auto x2 = x2_t->data() + ref_lod[num]; + auto hyp_offset = use_length ? num * x1_t->dims()[1] : hyp_lod[num]; + auto ref_offset = use_length ? num * x2_t->dims()[1] : ref_lod[num]; + auto x1 = x1_t->data() + hyp_offset; + auto x2 = x2_t->data() + ref_offset; for (int64_t i = 0; i < m + 1; ++i) { dist[i * (n + 1)] = i; } diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index 029adda67822b4d3021faae0bb259eb412518e7c..41f485799743cd6e1fc0b6a9fbc6ec6234817943 100644 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -5353,7 +5353,12 @@ def topk(input, k, name=None): return values, indices -def edit_distance(input, label, normalized=True, ignored_tokens=None): +def edit_distance(input, + label, + normalized=True, + ignored_tokens=None, + input_length=None, + label_length=None): """ Edit distance operator computes the edit distances between a batch of hypothesis strings and their references. Edit distance, also called @@ -5367,52 +5372,49 @@ def edit_distance(input, label, normalized=True, ignored_tokens=None): "kitten" -> "sitten" -> "sittin" -> "sitting" - The input is a LoDTensor consisting of all the hypothesis strings with + The input is a LoDTensor/Tensor consisting of all the hypothesis strings with the total number denoted by `batch_size`, and the separation is specified - by the LoD information. And the `batch_size` reference strings are arranged - in order in the same way in the input LoDTensor. + by the LoD information or input_length. And the `batch_size` reference strings are arranged + in order in the same way as `input`. The output contains the `batch_size` results and each stands for the edit distance for a pair of strings respectively. If Attr(normalized) is true, the edit distance will be divided by the length of reference string. Args: - input(Variable): The indices for hypothesis strings. - label(Variable): The indices for reference strings. + input(Variable): The indices for hypothesis strings, it should have rank 2 and dtype int64. + label(Variable): The indices for reference strings, it should have rank 2 and dtype int64. normalized(bool, default True): Indicated whether to normalize the edit distance by the length of reference string. ignored_tokens(list, default None): Tokens that should be removed before calculating edit distance. - name (str): The name of this layer. It is optional. + input_length(Variable): The length for each sequence in `input` if it's of Tensor type, it should have shape `[batch_size]` and dtype int64. + label_length(Variable): The length for each sequence in `label` if it's of Tensor type, it should have shape `[batch_size]` and dtype int64. Returns: - Variable: sequence-to-sequence edit distance in shape [batch_size, 1]. + edit_distance_out(Variable): edit distance result in shape [batch_size, 1]. \n + sequence_num(Variable): sequence number in shape []. + Examples: .. code-block:: python - + import paddle.fluid as fluid - x = fluid.layers.data(name='x', shape=[1], dtype='int64') - y = fluid.layers.data(name='y', shape=[1], dtype='int64') - cost, _ = fluid.layers.edit_distance(input=x, label=y) - cpu = fluid.core.CPUPlace() - exe = fluid.Executor(cpu) - exe.run(fluid.default_startup_program()) + # using LoDTensor + x_lod = fluid.layers.data(name='x_lod', shape=[1], dtype='int64', lod_level=1) + y_lod = fluid.layers.data(name='y_lod', shape=[1], dtype='int64', lod_level=1) + distance_lod, seq_num_lod = fluid.layers.edit_distance(input=x_lod, label=y_lod) - import numpy - x_ = numpy.random.randint(5, size=(2, 1)).astype('int64') - y_ = numpy.random.randint(5, size=(2, 1)).astype('int64') - - print(x_) - print(y_) - - x = fluid.create_lod_tensor(x_, [[2]], cpu) - y = fluid.create_lod_tensor(y_, [[2]], cpu) + # using Tensor + x_seq_len = 5 + y_seq_len = 6 + x_pad = fluid.layers.data(name='x_pad', shape=[x_seq_len], dtype='int64') + y_pad = fluid.layers.data(name='y_pad', shape=[y_seq_len], dtype='int64') + x_len = fluid.layers.data(name='x_len', shape=[], dtype='int64') + y_len = fluid.layers.data(name='y_len', shape=[], dtype='int64') + distance_pad, seq_num_pad = fluid.layers.edit_distance(input=x_pad, label=y_pad, input_length=x_len, label_length=y_len) - outs = exe.run(feed={'x':x, 'y':y}, fetch_list=[cost.name]) - - print(outs) """ helper = LayerHelper("edit_distance", **locals()) @@ -5435,13 +5437,17 @@ def edit_distance(input, label, normalized=True, ignored_tokens=None): attrs={"tokens": ignored_tokens}) label = erased_label + this_inputs = {"Hyps": [input], "Refs": [label]} + if input_length and label_length: + this_inputs['HypsLength'] = [input_length] + this_inputs['RefsLength'] = [label_length] + # edit distance op edit_distance_out = helper.create_variable_for_type_inference(dtype="int64") sequence_num = helper.create_variable_for_type_inference(dtype="int64") helper.append_op( type="edit_distance", - inputs={"Hyps": [input], - "Refs": [label]}, + inputs=this_inputs, outputs={"Out": [edit_distance_out], "SequenceNum": [sequence_num]}, attrs={"normalized": normalized}) diff --git a/python/paddle/fluid/tests/unittests/test_edit_distance_op.py b/python/paddle/fluid/tests/unittests/test_edit_distance_op.py index 0a334197ab76fa444fdeb81690b70a35b67219ac..ba48b143a8e43731f633e6b64299225c0fefe0e9 100644 --- a/python/paddle/fluid/tests/unittests/test_edit_distance_op.py +++ b/python/paddle/fluid/tests/unittests/test_edit_distance_op.py @@ -89,27 +89,31 @@ class TestEditDistanceOpNormalizedCase0(OpTest): def reset_config(self): pass + def post_config(self): + pass + def setUp(self): self.op_type = "edit_distance" normalized = True - x1 = np.array([[10, 3, 6, 5, 8, 2]]).astype("int64") - x2 = np.array([[10, 4, 6, 7, 8]]).astype("int64") - x1 = np.transpose(x1) - x2 = np.transpose(x2) + self.x1 = np.array([[10, 3, 6, 5, 8, 2]]).astype("int64") + self.x2 = np.array([[10, 4, 6, 7, 8]]).astype("int64") self.x1_lod = [3, 0, 3] self.x2_lod = [2, 1, 2] + self.x1 = np.transpose(self.x1) + self.x2 = np.transpose(self.x2) + self.reset_config() num_strs = len(self.x1_lod) distance = np.zeros((num_strs, 1)).astype("float32") - sequence_num = np.array(3).astype("int64") + sequence_num = np.array(num_strs).astype("int64") x1_offset = 0 x2_offset = 0 for i in range(0, num_strs): distance[i] = Levenshtein( - hyp=x1[x1_offset:(x1_offset + self.x1_lod[i])], - ref=x2[x2_offset:(x2_offset + self.x2_lod[i])]) + hyp=self.x1[x1_offset:(x1_offset + self.x1_lod[i])], + ref=self.x2[x2_offset:(x2_offset + self.x2_lod[i])]) x1_offset += self.x1_lod[i] x2_offset += self.x2_lod[i] if normalized is True: @@ -117,9 +121,14 @@ class TestEditDistanceOpNormalizedCase0(OpTest): distance[i] = distance[i] / len_ref self.attrs = {'normalized': normalized} - self.inputs = {'Hyps': (x1, [self.x1_lod]), 'Refs': (x2, [self.x2_lod])} + self.inputs = { + 'Hyps': (self.x1, [self.x1_lod]), + 'Refs': (self.x2, [self.x2_lod]) + } self.outputs = {'Out': distance, 'SequenceNum': sequence_num} + self.post_config() + def test_check_output(self): self.check_output() @@ -136,5 +145,43 @@ class TestEditDistanceOpNormalizedCase2(TestEditDistanceOpNormalizedCase0): self.x2_lod = [2, 2, 1] +class TestEditDistanceOpNormalizedTensor(OpTest): + def reset_config(self): + self.x1 = np.array([[10, 3, 0, 0], [6, 5, 8, 2]], dtype=np.int64) + self.x2 = np.array([[10, 4, 0], [6, 7, 8]], dtype=np.int64) + self.x1_lod = np.array([2, 4], dtype=np.int64) + self.x2_lod = np.array([2, 3], dtype=np.int64) + + def setUp(self): + self.op_type = "edit_distance" + normalized = True + + self.reset_config() + + num_strs = len(self.x1_lod) + distance = np.zeros((num_strs, 1)).astype("float32") + sequence_num = np.array(num_strs).astype("int64") + + for i in range(0, num_strs): + distance[i] = Levenshtein( + hyp=self.x1[i][0:self.x1_lod[i]], + ref=self.x2[i][0:self.x2_lod[i]]) + if normalized is True: + len_ref = self.x2_lod[i] + distance[i] = distance[i] / len_ref + + self.attrs = {'normalized': normalized} + self.inputs = { + 'Hyps': self.x1, + 'Refs': self.x2, + 'HypsLength': self.x1_lod, + 'RefsLength': self.x2_lod + } + self.outputs = {'Out': distance, 'SequenceNum': sequence_num} + + def test_check_output(self): + self.check_output() + + if __name__ == '__main__': unittest.main()