提交 77cf21e5 编写于 作者: W wanghaoshuang

Change input data type to int64_t

上级 3388e52d
...@@ -49,10 +49,10 @@ class EditDistanceOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -49,10 +49,10 @@ class EditDistanceOpMaker : public framework::OpProtoAndCheckerMaker {
EditDistanceOpMaker(OpProto *proto, OpAttrChecker *op_checker) EditDistanceOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) { : OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("Hyps", AddInput("Hyps",
"(2-D LoDTensor<int>, 2nd dim. equal to 1) " "(2-D LoDTensor<int64_t>, 2nd dim. equal to 1) "
"The indices for hypothesis strings."); "The indices for hypothesis strings.");
AddInput("Refs", AddInput("Refs",
"(2-D LoDTensor<int>, 2nd dim. equal to 1) " "(2-D LoDTensor<int64_t>, 2nd dim. equal to 1) "
"The indices for reference strings."); "The indices for reference strings.");
AddAttr<bool>("normalized", AddAttr<bool>("normalized",
"(bool, default false) Indicated whether to normalize " "(bool, default false) Indicated whether to normalize "
...@@ -66,22 +66,22 @@ class EditDistanceOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -66,22 +66,22 @@ class EditDistanceOpMaker : public framework::OpProtoAndCheckerMaker {
EditDistance operator computes the edit distances between a batch of hypothesis EditDistance operator computes the edit distances between a batch of hypothesis
strings and their references. strings and their references.
Edit distance, also called Levenshtein distance, measures how dissimilar two strings Edit distance, also called Levenshtein distance, measures how dissimilar two strings
are by counting the minimum number of operations to transform one string into anthor. are by counting the minimum number of operations to transform one string into anthor.
Here the operations include insertion, deletion, and substitution. For example, Here the operations include insertion, deletion, and substitution. For example,
given hypothesis string A = "kitten" and reference B = "sitting", the edit distance given hypothesis string A = "kitten" and reference B = "sitting", the edit distance
is 3 for A will be transformed into B at least after two substitutions and one is 3 for A will be transformed into B at least after two substitutions and one
insertion: insertion:
"kitten" -> "sitten" -> "sittin" -> "sitting" "kitten" -> "sitten" -> "sittin" -> "sitting"
Input(Hyps) is a LoDTensor consisting of all the hypothesis strings with the total Input(Hyps) is a LoDTensor consisting of all the hypothesis strings with the total
number denoted by `batch_size`, and the separation is specified by the LoD information. number denoted by `batch_size`, and the separation is specified by the LoD information.
And the `batch_size` reference strings are arranged in order in the same way in the And the `batch_size` reference strings are arranged in order in the same way in the
LoDTensor Input(Refs). LoDTensor Input(Refs).
Output(Out) contains the `batch_size` results and each stands for the edit stance Output(Out) contains the `batch_size` results and each stands for the edit stance
for a pair of strings respectively. If Attr(normalized) is true, the edit distance for a pair of strings respectively. If Attr(normalized) is true, the edit distance
will be divided by the length of reference string. will be divided by the length of reference string.
)DOC"); )DOC");
} }
......
...@@ -113,8 +113,8 @@ class EditDistanceGPUKernel : public framework::OpKernel<T> { ...@@ -113,8 +113,8 @@ class EditDistanceGPUKernel : public framework::OpKernel<T> {
dist_t.Resize({m + 1, n + 1}); dist_t.Resize({m + 1, n + 1});
dist_t.mutable_data<T>(ctx.GetPlace()); dist_t.mutable_data<T>(ctx.GetPlace());
auto dist = dist_t.data<T>(); auto dist = dist_t.data<T>();
auto x1 = x1_t->data<int>() + hyp_lod[num]; auto x1 = x1_t->data<int64_t>() + hyp_lod[num];
auto x2 = x2_t->data<int>() + ref_lod[num]; auto x2 = x2_t->data<int64_t>() + ref_lod[num];
FillFirstColumn<T><<<1 + m / PADDLE_CUDA_NUM_THREADS, FillFirstColumn<T><<<1 + m / PADDLE_CUDA_NUM_THREADS,
PADDLE_CUDA_NUM_THREADS, 0, stream>>>(dist, m, n); PADDLE_CUDA_NUM_THREADS, 0, stream>>>(dist, m, n);
......
...@@ -60,8 +60,8 @@ class EditDistanceKernel : public framework::OpKernel<T> { ...@@ -60,8 +60,8 @@ class EditDistanceKernel : public framework::OpKernel<T> {
dist_t.Resize({m + 1, n + 1}); dist_t.Resize({m + 1, n + 1});
dist_t.mutable_data<T>(ctx.GetPlace()); dist_t.mutable_data<T>(ctx.GetPlace());
auto dist = dist_t.data<T>(); auto dist = dist_t.data<T>();
auto x1 = x1_t->data<int>() + hyp_lod[num]; auto x1 = x1_t->data<int64_t>() + hyp_lod[num];
auto x2 = x2_t->data<int>() + ref_lod[num]; auto x2 = x2_t->data<int64_t>() + ref_lod[num];
for (int64_t i = 0; i < m + 1; ++i) { for (int64_t i = 0; i < m + 1; ++i) {
dist[i * (n + 1)] = i; dist[i * (n + 1)] = i;
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册