提交 77cf21e5 编写于 作者: W wanghaoshuang

Change input data type to int64_t

上级 3388e52d
...@@ -49,10 +49,10 @@ class EditDistanceOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -49,10 +49,10 @@ class EditDistanceOpMaker : public framework::OpProtoAndCheckerMaker {
EditDistanceOpMaker(OpProto *proto, OpAttrChecker *op_checker) EditDistanceOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) { : OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("Hyps", AddInput("Hyps",
"(2-D LoDTensor<int>, 2nd dim. equal to 1) " "(2-D LoDTensor<int64_t>, 2nd dim. equal to 1) "
"The indices for hypothesis strings."); "The indices for hypothesis strings.");
AddInput("Refs", AddInput("Refs",
"(2-D LoDTensor<int>, 2nd dim. equal to 1) " "(2-D LoDTensor<int64_t>, 2nd dim. equal to 1) "
"The indices for reference strings."); "The indices for reference strings.");
AddAttr<bool>("normalized", AddAttr<bool>("normalized",
"(bool, default false) Indicated whether to normalize " "(bool, default false) Indicated whether to normalize "
......
...@@ -113,8 +113,8 @@ class EditDistanceGPUKernel : public framework::OpKernel<T> { ...@@ -113,8 +113,8 @@ class EditDistanceGPUKernel : public framework::OpKernel<T> {
dist_t.Resize({m + 1, n + 1}); dist_t.Resize({m + 1, n + 1});
dist_t.mutable_data<T>(ctx.GetPlace()); dist_t.mutable_data<T>(ctx.GetPlace());
auto dist = dist_t.data<T>(); auto dist = dist_t.data<T>();
auto x1 = x1_t->data<int>() + hyp_lod[num]; auto x1 = x1_t->data<int64_t>() + hyp_lod[num];
auto x2 = x2_t->data<int>() + ref_lod[num]; auto x2 = x2_t->data<int64_t>() + ref_lod[num];
FillFirstColumn<T><<<1 + m / PADDLE_CUDA_NUM_THREADS, FillFirstColumn<T><<<1 + m / PADDLE_CUDA_NUM_THREADS,
PADDLE_CUDA_NUM_THREADS, 0, stream>>>(dist, m, n); PADDLE_CUDA_NUM_THREADS, 0, stream>>>(dist, m, n);
......
...@@ -60,8 +60,8 @@ class EditDistanceKernel : public framework::OpKernel<T> { ...@@ -60,8 +60,8 @@ class EditDistanceKernel : public framework::OpKernel<T> {
dist_t.Resize({m + 1, n + 1}); dist_t.Resize({m + 1, n + 1});
dist_t.mutable_data<T>(ctx.GetPlace()); dist_t.mutable_data<T>(ctx.GetPlace());
auto dist = dist_t.data<T>(); auto dist = dist_t.data<T>();
auto x1 = x1_t->data<int>() + hyp_lod[num]; auto x1 = x1_t->data<int64_t>() + hyp_lod[num];
auto x2 = x2_t->data<int>() + ref_lod[num]; auto x2 = x2_t->data<int64_t>() + ref_lod[num];
for (int64_t i = 0; i < m + 1; ++i) { for (int64_t i = 0; i < m + 1; ++i) {
dist[i * (n + 1)] = i; dist[i * (n + 1)] = i;
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册