From f594ca436939b1ef0133727eadf0d5470ff74f67 Mon Sep 17 00:00:00 2001 From: Yibing Liu Date: Wed, 10 Jan 2018 09:17:03 +0000 Subject: [PATCH] Reuse the usable variable in edit_distance_op --- paddle/operators/edit_distance_op.cc | 4 ++-- paddle/operators/edit_distance_op.cu | 8 ++++---- paddle/operators/edit_distance_op.h | 12 ++++++------ 3 files changed, 12 insertions(+), 12 deletions(-) diff --git a/paddle/operators/edit_distance_op.cc b/paddle/operators/edit_distance_op.cc index 7b92148f0..441ae2aa0 100644 --- a/paddle/operators/edit_distance_op.cc +++ b/paddle/operators/edit_distance_op.cc @@ -49,10 +49,10 @@ class EditDistanceOpMaker : public framework::OpProtoAndCheckerMaker { EditDistanceOpMaker(OpProto *proto, OpAttrChecker *op_checker) : OpProtoAndCheckerMaker(proto, op_checker) { AddInput("Hyps", - "(2-D LoDTensor, 2nd dim. equal to 1) " + "(2-D LoDTensor, 2nd dim. equal to 1) " "The indices for hypothesis strings."); AddInput("Refs", - "(2-D LoDTensor, 2nd dim. equal to 1) " + "(2-D LoDTensor, 2nd dim. equal to 1) " "The indices for reference strings."); AddAttr("normalized", "(bool, default false) Indicated whether to normalize " diff --git a/paddle/operators/edit_distance_op.cu b/paddle/operators/edit_distance_op.cu index b54834598..cf5ebc5c3 100644 --- a/paddle/operators/edit_distance_op.cu +++ b/paddle/operators/edit_distance_op.cu @@ -93,21 +93,21 @@ class EditDistanceGPUKernel : public framework::OpKernel { out_t->mutable_data(ctx.GetPlace()); auto out = out_t->data(); - std::vector distance(num_strs, 0.0); + T distance = 0.0; for (size_t num = 0; num < num_strs; num++) { auto m = static_cast(hyp_lod[num + 1] - hyp_lod[num]); auto n = static_cast(ref_lod[num + 1] - ref_lod[num]); if (m == 0 || n == 0) { - distance[num] = std::max(m, n); + distance = std::max(m, n); if (normalized) { PADDLE_ENFORCE(n > 0, "The reference string (#%d) cannot be empty " "when Attr(normalized) is enabled.", n); - distance[num] = distance[num] / n; + distance = distance / n; } memory::Copy(boost::get(ctx.GetPlace()), out + num, - platform::CPUPlace(), &distance[num], sizeof(T), stream); + platform::CPUPlace(), &distance, sizeof(T), stream); } else { framework::Tensor dist_t; dist_t.Resize({m + 1, n + 1}); diff --git a/paddle/operators/edit_distance_op.h b/paddle/operators/edit_distance_op.h index 6284f230e..537e70281 100644 --- a/paddle/operators/edit_distance_op.h +++ b/paddle/operators/edit_distance_op.h @@ -46,15 +46,15 @@ class EditDistanceKernel : public framework::OpKernel { out_t->mutable_data(ctx.GetPlace()); auto out = out_t->data(); - std::vector distance(num_strs, 0.0); + T distance = 0.0; for (size_t num = 0; num < num_strs; ++num) { auto m = static_cast(hyp_lod[num + 1] - hyp_lod[num]); auto n = static_cast(ref_lod[num + 1] - ref_lod[num]); if (m == 0) { - distance[num] = n; + distance = n; } else if (n == 0) { - distance[num] = m; + distance = m; } else { framework::Tensor dist_t; dist_t.Resize({m + 1, n + 1}); @@ -77,7 +77,7 @@ class EditDistanceKernel : public framework::OpKernel { dist[i * (n + 1) + j] = std::min(dels, std::min(ins, subs)); } } - distance[num] = dist[m * (n + 1) + n]; + distance = dist[m * (n + 1) + n]; } if (normalized) { @@ -85,9 +85,9 @@ class EditDistanceKernel : public framework::OpKernel { "The reference string (#%d) cannot be empty " "when Attr(normalized) is enabled.", n); - distance[num] = distance[num] / n; + distance = distance / n; } - out[num] = distance[num]; + out[num] = distance; } } }; -- GitLab