提交 f594ca43 编写于 作者: Y Yibing Liu

Reuse the usable variable in edit_distance_op

上级 0250e54c
......@@ -49,10 +49,10 @@ class EditDistanceOpMaker : public framework::OpProtoAndCheckerMaker {
EditDistanceOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("Hyps",
"(2-D LoDTensor, 2nd dim. equal to 1) "
"(2-D LoDTensor<int>, 2nd dim. equal to 1) "
"The indices for hypothesis strings.");
AddInput("Refs",
"(2-D LoDTensor, 2nd dim. equal to 1) "
"(2-D LoDTensor<int>, 2nd dim. equal to 1) "
"The indices for reference strings.");
AddAttr<bool>("normalized",
"(bool, default false) Indicated whether to normalize "
......
......@@ -93,21 +93,21 @@ class EditDistanceGPUKernel : public framework::OpKernel<T> {
out_t->mutable_data<T>(ctx.GetPlace());
auto out = out_t->data<T>();
std::vector<T> distance(num_strs, 0.0);
T distance = 0.0;
for (size_t num = 0; num < num_strs; num++) {
auto m = static_cast<int64_t>(hyp_lod[num + 1] - hyp_lod[num]);
auto n = static_cast<int64_t>(ref_lod[num + 1] - ref_lod[num]);
if (m == 0 || n == 0) {
distance[num] = std::max(m, n);
distance = std::max(m, n);
if (normalized) {
PADDLE_ENFORCE(n > 0,
"The reference string (#%d) cannot be empty "
"when Attr(normalized) is enabled.",
n);
distance[num] = distance[num] / n;
distance = distance / n;
}
memory::Copy(boost::get<Place>(ctx.GetPlace()), out + num,
platform::CPUPlace(), &distance[num], sizeof(T), stream);
platform::CPUPlace(), &distance, sizeof(T), stream);
} else {
framework::Tensor dist_t;
dist_t.Resize({m + 1, n + 1});
......
......@@ -46,15 +46,15 @@ class EditDistanceKernel : public framework::OpKernel<T> {
out_t->mutable_data<float>(ctx.GetPlace());
auto out = out_t->data<T>();
std::vector<T> distance(num_strs, 0.0);
T distance = 0.0;
for (size_t num = 0; num < num_strs; ++num) {
auto m = static_cast<int64_t>(hyp_lod[num + 1] - hyp_lod[num]);
auto n = static_cast<int64_t>(ref_lod[num + 1] - ref_lod[num]);
if (m == 0) {
distance[num] = n;
distance = n;
} else if (n == 0) {
distance[num] = m;
distance = m;
} else {
framework::Tensor dist_t;
dist_t.Resize({m + 1, n + 1});
......@@ -77,7 +77,7 @@ class EditDistanceKernel : public framework::OpKernel<T> {
dist[i * (n + 1) + j] = std::min(dels, std::min(ins, subs));
}
}
distance[num] = dist[m * (n + 1) + n];
distance = dist[m * (n + 1) + n];
}
if (normalized) {
......@@ -85,9 +85,9 @@ class EditDistanceKernel : public framework::OpKernel<T> {
"The reference string (#%d) cannot be empty "
"when Attr(normalized) is enabled.",
n);
distance[num] = distance[num] / n;
distance = distance / n;
}
out[num] = distance[num];
out[num] = distance;
}
}
};
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册