提交 f594ca43 编写于 作者: Y Yibing Liu

Reuse the usable variable in edit_distance_op

上级 0250e54c
...@@ -49,10 +49,10 @@ class EditDistanceOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -49,10 +49,10 @@ class EditDistanceOpMaker : public framework::OpProtoAndCheckerMaker {
EditDistanceOpMaker(OpProto *proto, OpAttrChecker *op_checker) EditDistanceOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) { : OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("Hyps", AddInput("Hyps",
"(2-D LoDTensor, 2nd dim. equal to 1) " "(2-D LoDTensor<int>, 2nd dim. equal to 1) "
"The indices for hypothesis strings."); "The indices for hypothesis strings.");
AddInput("Refs", AddInput("Refs",
"(2-D LoDTensor, 2nd dim. equal to 1) " "(2-D LoDTensor<int>, 2nd dim. equal to 1) "
"The indices for reference strings."); "The indices for reference strings.");
AddAttr<bool>("normalized", AddAttr<bool>("normalized",
"(bool, default false) Indicated whether to normalize " "(bool, default false) Indicated whether to normalize "
......
...@@ -93,21 +93,21 @@ class EditDistanceGPUKernel : public framework::OpKernel<T> { ...@@ -93,21 +93,21 @@ class EditDistanceGPUKernel : public framework::OpKernel<T> {
out_t->mutable_data<T>(ctx.GetPlace()); out_t->mutable_data<T>(ctx.GetPlace());
auto out = out_t->data<T>(); auto out = out_t->data<T>();
std::vector<T> distance(num_strs, 0.0); T distance = 0.0;
for (size_t num = 0; num < num_strs; num++) { for (size_t num = 0; num < num_strs; num++) {
auto m = static_cast<int64_t>(hyp_lod[num + 1] - hyp_lod[num]); auto m = static_cast<int64_t>(hyp_lod[num + 1] - hyp_lod[num]);
auto n = static_cast<int64_t>(ref_lod[num + 1] - ref_lod[num]); auto n = static_cast<int64_t>(ref_lod[num + 1] - ref_lod[num]);
if (m == 0 || n == 0) { if (m == 0 || n == 0) {
distance[num] = std::max(m, n); distance = std::max(m, n);
if (normalized) { if (normalized) {
PADDLE_ENFORCE(n > 0, PADDLE_ENFORCE(n > 0,
"The reference string (#%d) cannot be empty " "The reference string (#%d) cannot be empty "
"when Attr(normalized) is enabled.", "when Attr(normalized) is enabled.",
n); n);
distance[num] = distance[num] / n; distance = distance / n;
} }
memory::Copy(boost::get<Place>(ctx.GetPlace()), out + num, memory::Copy(boost::get<Place>(ctx.GetPlace()), out + num,
platform::CPUPlace(), &distance[num], sizeof(T), stream); platform::CPUPlace(), &distance, sizeof(T), stream);
} else { } else {
framework::Tensor dist_t; framework::Tensor dist_t;
dist_t.Resize({m + 1, n + 1}); dist_t.Resize({m + 1, n + 1});
......
...@@ -46,15 +46,15 @@ class EditDistanceKernel : public framework::OpKernel<T> { ...@@ -46,15 +46,15 @@ class EditDistanceKernel : public framework::OpKernel<T> {
out_t->mutable_data<float>(ctx.GetPlace()); out_t->mutable_data<float>(ctx.GetPlace());
auto out = out_t->data<T>(); auto out = out_t->data<T>();
std::vector<T> distance(num_strs, 0.0); T distance = 0.0;
for (size_t num = 0; num < num_strs; ++num) { for (size_t num = 0; num < num_strs; ++num) {
auto m = static_cast<int64_t>(hyp_lod[num + 1] - hyp_lod[num]); auto m = static_cast<int64_t>(hyp_lod[num + 1] - hyp_lod[num]);
auto n = static_cast<int64_t>(ref_lod[num + 1] - ref_lod[num]); auto n = static_cast<int64_t>(ref_lod[num + 1] - ref_lod[num]);
if (m == 0) { if (m == 0) {
distance[num] = n; distance = n;
} else if (n == 0) { } else if (n == 0) {
distance[num] = m; distance = m;
} else { } else {
framework::Tensor dist_t; framework::Tensor dist_t;
dist_t.Resize({m + 1, n + 1}); dist_t.Resize({m + 1, n + 1});
...@@ -77,7 +77,7 @@ class EditDistanceKernel : public framework::OpKernel<T> { ...@@ -77,7 +77,7 @@ class EditDistanceKernel : public framework::OpKernel<T> {
dist[i * (n + 1) + j] = std::min(dels, std::min(ins, subs)); dist[i * (n + 1) + j] = std::min(dels, std::min(ins, subs));
} }
} }
distance[num] = dist[m * (n + 1) + n]; distance = dist[m * (n + 1) + n];
} }
if (normalized) { if (normalized) {
...@@ -85,9 +85,9 @@ class EditDistanceKernel : public framework::OpKernel<T> { ...@@ -85,9 +85,9 @@ class EditDistanceKernel : public framework::OpKernel<T> {
"The reference string (#%d) cannot be empty " "The reference string (#%d) cannot be empty "
"when Attr(normalized) is enabled.", "when Attr(normalized) is enabled.",
n); n);
distance[num] = distance[num] / n; distance = distance / n;
} }
out[num] = distance[num]; out[num] = distance;
} }
} }
}; };
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册