From 0250e54c2dc8ae4687a2ede661cd25dadfb66ce9 Mon Sep 17 00:00:00 2001
From: Yibing Liu
Date: Wed, 3 Jan 2018 09:48:29 +0000
Subject: [PATCH] Enable batch input in edit_distance_op

---
 paddle/operators/edit_distance_op.cc          | 49 ++++++----
 paddle/operators/edit_distance_op.cu          | 98 +++++++++++--------
 paddle/operators/edit_distance_op.h           | 91 ++++++++++-------
 .../v2/fluid/tests/test_edit_distance_op.py   | 52 ++++++++--
 4 files changed, 189 insertions(+), 101 deletions(-)

diff --git a/paddle/operators/edit_distance_op.cc b/paddle/operators/edit_distance_op.cc
index 6022a7a4bd..7b92148f0e 100644
--- a/paddle/operators/edit_distance_op.cc
+++ b/paddle/operators/edit_distance_op.cc
@@ -22,10 +22,18 @@ class EditDistanceOp : public framework::OperatorWithKernel {
   using framework::OperatorWithKernel::OperatorWithKernel;
 
   void InferShape(framework::InferShapeContext *ctx) const override {
-    PADDLE_ENFORCE(ctx->HasInput("Hyp"), "Input(Hyp) shouldn't be null.");
-    PADDLE_ENFORCE(ctx->HasInput("Ref"), "Input(Ref) shouldn't be null.");
+    PADDLE_ENFORCE(ctx->HasInput("Hyps"), "Input(Hyps) shouldn't be null.");
+    PADDLE_ENFORCE(ctx->HasInput("Refs"), "Input(Refs) shouldn't be null.");
     PADDLE_ENFORCE(ctx->HasOutput("Out"), "Output(Out) shouldn't be null.");
-    ctx->SetOutputDim("Out", {1});
+    auto hyp_dims = ctx->GetInputDim("Hyps");
+    auto ref_dims = ctx->GetInputDim("Refs");
+    PADDLE_ENFORCE(hyp_dims.size() == 2 && hyp_dims[1] == 1,
+                   "Input(Hyps) must be a 2-D LoDTensor with the 2nd dimension "
+                   "equal to 1.");
+    PADDLE_ENFORCE(ref_dims.size() == 2 && ref_dims[1] == 1,
+                   "Input(Refs) must be a 2-D LoDTensor with the 2nd dimension "
+                   "equal to 1.");
+    ctx->SetOutputDim("Out", ctx->GetInputDim("Refs"));
   }
 
  protected:
@@ -40,24 +48,23 @@ class EditDistanceOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
   EditDistanceOpMaker(OpProto *proto, OpAttrChecker *op_checker)
       : OpProtoAndCheckerMaker(proto, op_checker) {
-    AddInput("Hyp",
-             "(2-D tensor with shape [M x 1]) The indices for "
-             "hypothesis string");
-    AddInput("Ref",
-             "(2-D tensor with shape [N x 1]) The indices "
-             "for reference string.");
+    AddInput("Hyps",
+             "(2-D LoDTensor, 2nd dim. equal to 1) "
+             "The indices for hypothesis strings.");
+    AddInput("Refs",
+             "(2-D LoDTensor, 2nd dim. equal to 1) "
+             "The indices for reference strings.");
     AddAttr<bool>("normalized",
-                  "(bool, default false) Indicated whether "
-                  "normalize the Output(Out) by the length of reference "
-                  "string (Ref).")
+                  "(bool, default false) Indicates whether to normalize "
+                  "the edit distance by the length of the reference string.")
         .SetDefault(false);
     AddOutput("Out",
-              "(2-D tensor with shape [1 x 1]) "
-              "The output distance of EditDistance operator.");
+              "(2-D Tensor with shape [`batch_size` x 1]) "
+              "The output edit distances of EditDistance operator.");
     AddComment(R"DOC(
-EditDistance operator computes the edit distance of two sequences, one named
-hypothesis with length M and another named reference with length N.
+EditDistance operator computes the edit distances between a batch of hypothesis
+strings and their references.
 
 Edit distance, also called Levenshtein distance, measures how dissimilar two strings
 are by counting the minimum number of operations to transform one string into another.
 Here the operations include insertion, deletion, and substitution. For example,
 the edit distance between "kitten" and "sitting" is 3, requiring two substitutions
 and one insertion:
 
    "kitten" -> "sitten" -> "sittin" -> "sitting"
 
-If Attr(normalized) is true, the edit distance will be divided by the length of
-reference string N.
+Input(Hyps) is a LoDTensor consisting of all the hypothesis strings, with the
+total number of strings denoted by `batch_size`. The separation between strings
+is specified by the LoD information. The `batch_size` reference strings are
+arranged in the same way in the LoDTensor Input(Refs).
+
+Output(Out) contains the `batch_size` results, each of which is the edit distance
+for the corresponding pair of strings. If Attr(normalized) is true, each edit
+distance is divided by the length of its reference string.
 )DOC");
   }
 };
diff --git a/paddle/operators/edit_distance_op.cu b/paddle/operators/edit_distance_op.cu
index fed91ffb43..b548345986 100644
--- a/paddle/operators/edit_distance_op.cu
+++ b/paddle/operators/edit_distance_op.cu
@@ -70,53 +70,71 @@ class EditDistanceGPUKernel : public framework::OpKernel<T> {
   void Compute(const framework::ExecutionContext& ctx) const {
     auto* out_t = ctx.Output<framework::Tensor>("Out");
 
-    auto* x1_t = ctx.Input<framework::Tensor>("Hyp");
-    auto* x2_t = ctx.Input<framework::Tensor>("Ref");
-
-    out_t->mutable_data<T>(ctx.GetPlace());
-    auto out = out_t->data<T>();
+    auto* x1_t = ctx.Input<framework::LoDTensor>("Hyps");
+    auto* x2_t = ctx.Input<framework::LoDTensor>("Refs");
 
     auto normalized = ctx.Attr<bool>("normalized");
     auto stream = reinterpret_cast<const platform::CUDADeviceContext&>(
                       ctx.device_context())
                       .stream();
 
-    auto m = x1_t->numel();
-    auto n = x2_t->numel();
-    T distance = 0.0;
-    if (m == 0 || n == 0) {
-      distance = std::max(m, n);
-      if (normalized) {
-        distance = distance / n;
-      }
-      memory::Copy(boost::get<platform::CUDAPlace>(ctx.GetPlace()), out, platform::CPUPlace(),
-                   &distance, sizeof(T), stream);
-    } else {
-      framework::Tensor dist_t;
-      dist_t.Resize({m + 1, n + 1});
-      dist_t.mutable_data<T>(ctx.GetPlace());
-      auto dist = dist_t.data<T>();
-      auto x1 = x1_t->data<int>();
-      auto x2 = x2_t->data<int>();
-
-      FillFirstColumn<T><<<1 + m / PADDLE_CUDA_NUM_THREADS,
-                           PADDLE_CUDA_NUM_THREADS, 0, stream>>>(dist, m, n);
-
-      FillFirstRow<T><<<1 + n / PADDLE_CUDA_NUM_THREADS,
-                        PADDLE_CUDA_NUM_THREADS, 0, stream>>>(dist, n);
-      // Compute the elements of distance matrix in the anti-diagonal diretion
-      for (int64_t slice = 2; slice < m + n + 1; ++slice) {
-        int z_m = slice < m + 1 ? 0 : slice - m;
-        int z_n = slice < n + 1 ? 0 : slice - n;
-        int size = slice - (z_m + z_n) + 1;  // number of elments in the same
-                                             // anti-diagonal line to update
-        // the start index at which computes from
-        int start = slice < n + 1 ? slice : (z_n + 1) * (n + 1) - 1;
-        Levenshtein<T><<<1 + (size - 1) / PADDLE_CUDA_NUM_THREADS,
-                         PADDLE_CUDA_NUM_THREADS, 0, stream>>>(dist, x1, x2, m,
-                                                               n, start);
+    auto hyp_lod = x1_t->lod()[0];
+    auto ref_lod = x2_t->lod()[0];
+    PADDLE_ENFORCE(
+        hyp_lod.size() == ref_lod.size(),
+        "Input(Hyps) and Input(Refs) must have the same batch size.");
+    for (size_t i = 1; i < ref_lod.size(); ++i) {
+      PADDLE_ENFORCE(ref_lod[i] > ref_lod[i - 1],
+                     "Reference string %d is empty.", i);
+    }
+
+    auto num_strs = hyp_lod.size() - 1;
+    out_t->Resize({static_cast<int64_t>(num_strs), 1});
+    out_t->mutable_data<T>(ctx.GetPlace());
+    auto out = out_t->data<T>();
+
+    std::vector<T> distance(num_strs, 0.0);
+    for (size_t num = 0; num < num_strs; num++) {
+      auto m = static_cast<int64_t>(hyp_lod[num + 1] - hyp_lod[num]);
+      auto n = static_cast<int64_t>(ref_lod[num + 1] - ref_lod[num]);
+      if (m == 0 || n == 0) {
+        distance[num] = std::max(m, n);
+        if (normalized) {
+          PADDLE_ENFORCE(n > 0,
+                         "The reference string (#%d) cannot be empty "
+                         "when Attr(normalized) is enabled.",
+                         num);
+          distance[num] = distance[num] / n;
+        }
+        memory::Copy(boost::get<platform::CUDAPlace>(ctx.GetPlace()), out + num,
+                     platform::CPUPlace(), &distance[num], sizeof(T), stream);
+      } else {
+        framework::Tensor dist_t;
+        dist_t.Resize({m + 1, n + 1});
+        dist_t.mutable_data<T>(ctx.GetPlace());
+        auto dist = dist_t.data<T>();
+        auto x1 = x1_t->data<int>() + hyp_lod[num];
+        auto x2 = x2_t->data<int>() + ref_lod[num];
+
+        FillFirstColumn<T><<<1 + m / PADDLE_CUDA_NUM_THREADS,
+                             PADDLE_CUDA_NUM_THREADS, 0, stream>>>(dist, m, n);
+
+        FillFirstRow<T><<<1 + n / PADDLE_CUDA_NUM_THREADS,
+                          PADDLE_CUDA_NUM_THREADS, 0, stream>>>(dist, n);
+        // Compute the elements of distance matrix in the anti-diagonal direction
+        for (int64_t slice = 2; slice < m + n + 1; ++slice) {
+          int z_m = slice < m + 1 ? 0 : slice - m;
+          int z_n = slice < n + 1 ? 0 : slice - n;
+          int size = slice - (z_m + z_n) + 1;  // number of elements in the same
+                                               // anti-diagonal line to update
+          // the start index from which to compute
+          int start = slice < n + 1 ? slice : (z_n + 1) * (n + 1) - 1;
+          Levenshtein<T><<<1 + (size - 1) / PADDLE_CUDA_NUM_THREADS,
+                           PADDLE_CUDA_NUM_THREADS, 0, stream>>>(dist, x1, x2,
+                                                                 m, n, start);
+        }
+        SetOutput<T><<<1, 1, 0, stream>>>(out + num, dist, m, n, normalized);
       }
-      SetOutput<T><<<1, 1, 0, stream>>>(out, dist, m, n, normalized);
     }
   }
 };
diff --git a/paddle/operators/edit_distance_op.h b/paddle/operators/edit_distance_op.h
index abde4fe97c..6284f230e5 100644
--- a/paddle/operators/edit_distance_op.h
+++ b/paddle/operators/edit_distance_op.h
@@ -26,50 +26,69 @@ class EditDistanceKernel : public framework::OpKernel<T> {
   void Compute(const framework::ExecutionContext& ctx) const {
     auto* out_t = ctx.Output<framework::Tensor>("Out");
 
-    auto* x1_t = ctx.Input<framework::Tensor>("Hyp");
-    auto* x2_t = ctx.Input<framework::Tensor>("Ref");
+    auto* x1_t = ctx.Input<framework::LoDTensor>("Hyps");
+    auto* x2_t = ctx.Input<framework::LoDTensor>("Refs");
+    auto normalized = ctx.Attr<bool>("normalized");
+
+    auto hyp_lod = x1_t->lod()[0];
+    auto ref_lod = x2_t->lod()[0];
+    PADDLE_ENFORCE(
+        hyp_lod.size() == ref_lod.size(),
+        "Input(Hyps) and Input(Refs) must have the same batch size.");
+    for (size_t i = 1; i < ref_lod.size(); ++i) {
+      PADDLE_ENFORCE(ref_lod[i] > ref_lod[i - 1],
+                     "Reference string %d is empty.", i);
+    }
+    auto num_strs = hyp_lod.size() - 1;
 
+    out_t->Resize({static_cast<int64_t>(num_strs), 1});
     out_t->mutable_data<T>(ctx.GetPlace());
+    auto out = out_t->data<T>();
 
-    auto normalized = ctx.Attr<bool>("normalized");
+    std::vector<T> distance(num_strs, 0.0);
+    for (size_t num = 0; num < num_strs; ++num) {
+      auto m = static_cast<int64_t>(hyp_lod[num + 1] - hyp_lod[num]);
+      auto n = static_cast<int64_t>(ref_lod[num + 1] - ref_lod[num]);
 
-    auto m = x1_t->numel();
-    auto n = x2_t->numel();
-    T distance = 0.0;
-    if (m == 0) {
-      distance = n;
-    } else if (n == 0) {
-      distance = m;
-    } else {
-      framework::Tensor dist_t;
-      dist_t.Resize({m + 1, n + 1});
-      dist_t.mutable_data<T>(ctx.GetPlace());
-      auto dist = dist_t.data<T>();
-      auto x1 = x1_t->data<int>();
-      auto x2 = x2_t->data<int>();
-      for (int64_t i = 0; i < m + 1; ++i) {
-        dist[i * (n + 1)] = i;
-      }
-      for (int64_t j = 0; j < n + 1; ++j) {
-        dist[j] = j;
-      }
-      for (int64_t i = 1; i < m + 1; ++i) {
-        for (int64_t j = 1; j < n + 1; ++j) {
-          int cost = x1[i - 1] == x2[j - 1] ? 0 : 1;
-          int dels = dist[(i - 1) * (n + 1) + j] + 1;
-          int ins = dist[i * (n + 1) + (j - 1)] + 1;
-          int subs = dist[(i - 1) * (n + 1) + (j - 1)] + cost;
-          dist[i * (n + 1) + j] = std::min(dels, std::min(ins, subs));
+      if (m == 0) {
+        distance[num] = n;
+      } else if (n == 0) {
+        distance[num] = m;
+      } else {
+        framework::Tensor dist_t;
+        dist_t.Resize({m + 1, n + 1});
+        dist_t.mutable_data<T>(ctx.GetPlace());
+        auto dist = dist_t.data<T>();
+        auto x1 = x1_t->data<int>() + hyp_lod[num];
+        auto x2 = x2_t->data<int>() + ref_lod[num];
+        for (int64_t i = 0; i < m + 1; ++i) {
+          dist[i * (n + 1)] = i;
+        }
+        for (int64_t j = 0; j < n + 1; ++j) {
+          dist[j] = j;
+        }
+        for (int64_t i = 1; i < m + 1; ++i) {
+          for (int64_t j = 1; j < n + 1; ++j) {
+            int cost = x1[i - 1] == x2[j - 1] ? 0 : 1;
+            int dels = dist[(i - 1) * (n + 1) + j] + 1;
+            int ins = dist[i * (n + 1) + (j - 1)] + 1;
+            int subs = dist[(i - 1) * (n + 1) + (j - 1)] + cost;
+            dist[i * (n + 1) + j] = std::min(dels, std::min(ins, subs));
+          }
         }
+        distance[num] = dist[m * (n + 1) + n];
       }
-      distance = dist[m * (n + 1) + n];
-    }
 
-    if (normalized) {
-      distance = distance / n;
+      if (normalized) {
+        PADDLE_ENFORCE(n > 0,
+                       "The reference string (#%d) cannot be empty "
+                       "when Attr(normalized) is enabled.",
+                       num);
+        distance[num] = distance[num] / n;
+      }
+      out[num] = distance[num];
     }
-
-    auto out = out_t->data<T>();
-    out[0] = distance;
   }
 };
diff --git a/python/paddle/v2/fluid/tests/test_edit_distance_op.py b/python/paddle/v2/fluid/tests/test_edit_distance_op.py
index df1ac620e7..24f2f0c5c2 100644
--- a/python/paddle/v2/fluid/tests/test_edit_distance_op.py
+++ b/python/paddle/v2/fluid/tests/test_edit_distance_op.py
@@ -18,7 +18,7 @@ def Levenshtein(hyp, ref):
     if n == 0:
         return m
 
-    dist = np.zeros((m + 1, n + 1))
+    dist = np.zeros((m + 1, n + 1)).astype("float32")
     for i in range(0, m + 1):
         dist[i][0] = i
     for j in range(0, n + 1):
@@ -35,17 +35,55 @@ def Levenshtein(hyp, ref):
 
 
 class TestCTCEditDistanceOp(OpTest):
+    def setUp(self):
+        self.op_type = "edit_distance"
+        normalized = False
+        x1 = np.array([[0, 12, 3, 5, 8, 2]]).astype("int32")
+        x2 = np.array([[0, 12, 4, 7, 8]]).astype("int32")
+        x1 = np.transpose(x1)
+        x2 = np.transpose(x2)
+        x1_lod = [0, 1, 5]
+        x2_lod = [0, 3, 4]
+
+        num_strs = len(x1_lod) - 1
+        distance = np.zeros((num_strs, 1)).astype("float32")
+        for i in range(0, num_strs):
+            distance[i] = Levenshtein(
+                hyp=x1[x1_lod[i]:x1_lod[i + 1]],
+                ref=x2[x2_lod[i]:x2_lod[i + 1]])
+            if normalized is True:
+                len_ref = x2_lod[i + 1] - x2_lod[i]
+                distance[i] = distance[i] / len_ref
+        self.attrs = {'normalized': normalized}
+        self.inputs = {'Hyps': (x1, [x1_lod]), 'Refs': (x2, [x2_lod])}
+        self.outputs = {'Out': distance}
+
+    def test_check_output(self):
+        self.check_output()
+
+
+class TestCTCEditDistanceOpNormalized(OpTest):
     def setUp(self):
         self.op_type = "edit_distance"
         normalized = True
-        x1 = np.array([0, 12, 3, 5]).astype("int32")
-        x2 = np.array([0, 12, 4, 7, 8]).astype("int32")
+        x1 = np.array([[0, 10, 3, 6, 5, 8, 2]]).astype("int32")
+        x2 = np.array([[0, 10, 4, 6, 7, 8]]).astype("int32")
+        x1 = np.transpose(x1)
+        x2 = np.transpose(x2)
+        x1_lod = [0, 1, 3, 6]
+        x2_lod = [0, 2, 3, 5]
 
-        distance = Levenshtein(hyp=x1, ref=x2)
-        if normalized is True:
-            distance = distance / len(x2)
+        num_strs = len(x1_lod) - 1
+        distance = np.zeros((num_strs, 1)).astype("float32")
+        for i in range(0, num_strs):
+            distance[i] = Levenshtein(
+                hyp=x1[x1_lod[i]:x1_lod[i + 1]],
+                ref=x2[x2_lod[i]:x2_lod[i + 1]])
+            if normalized is True:
+                len_ref = x2_lod[i + 1] - x2_lod[i]
+                distance[i] = distance[i] / len_ref
 
         self.attrs = {'normalized': normalized}
-        self.inputs = {'Hyp': x1, 'Ref': x2}
+        self.inputs = {'Hyps': (x1, [x1_lod]), 'Refs': (x2, [x2_lod])}
         self.outputs = {'Out': distance}
 
     def test_check_output(self):
-- 
GitLab
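
Note: the GPU kernel above parallelizes the classic Levenshtein dynamic program by
filling the distance matrix along anti-diagonals, since every cell on one
anti-diagonal depends only on cells of the two preceding anti-diagonals and can be
updated concurrently. For readers unfamiliar with the level-0 LoD convention the
batching relies on, here is a minimal NumPy sketch of what the operator computes.
It is illustrative only, not part of the patch; the helper names `levenshtein` and
`batch_edit_distance` are made up, and only the LoD slicing and the DP recurrence
mirror the code above.

import numpy as np

def levenshtein(hyp, ref):
    # Same O(m*n) dynamic program as the CPU kernel above:
    # each cell is the min of deletion, insertion, and substitution.
    m, n = len(hyp), len(ref)
    if m == 0 or n == 0:
        return float(max(m, n))
    dist = np.zeros((m + 1, n + 1), dtype="float32")
    dist[:, 0] = np.arange(m + 1)
    dist[0, :] = np.arange(n + 1)
    for i in range(1, m + 1):
        for j in range(1, n + 1):
            cost = 0 if hyp[i - 1] == ref[j - 1] else 1
            dist[i, j] = min(dist[i - 1, j] + 1,         # deletion
                             dist[i, j - 1] + 1,         # insertion
                             dist[i - 1, j - 1] + cost)  # substitution
    return float(dist[m, n])

def batch_edit_distance(hyps, hyp_lod, refs, ref_lod, normalized=False):
    # hyp_lod / ref_lod are level-0 LoD offsets: string i occupies the
    # half-open range [lod[i], lod[i + 1]) of the flattened index array.
    assert len(hyp_lod) == len(ref_lod), "batch sizes must match"
    num_strs = len(hyp_lod) - 1
    out = np.zeros((num_strs, 1), dtype="float32")
    for i in range(num_strs):
        hyp = hyps[hyp_lod[i]:hyp_lod[i + 1]]
        ref = refs[ref_lod[i]:ref_lod[i + 1]]
        d = levenshtein(hyp, ref)
        if normalized:
            assert len(ref) > 0, "empty reference with normalized=True"
            d /= len(ref)
        out[i] = d
    return out

# Same data as TestCTCEditDistanceOp above; expected output [[2.], [4.]]:
hyps = np.array([0, 12, 3, 5, 8, 2], dtype="int32")
refs = np.array([0, 12, 4, 7, 8], dtype="int32")
print(batch_edit_distance(hyps, [0, 1, 5], refs, [0, 3, 4]))

String i occupying [lod[i], lod[i + 1]) of the flat [total_len x 1] input is also
why the kernels in the patch offset the raw data pointers by hyp_lod[num] and
ref_lod[num] before running the per-pair dynamic program.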