提交 ad0dc8fd 编写于 作者: Q qingqing01 提交者: GitHub

Merge pull request #3897 from Canpio/fix_warnings_in_lookup_op

Fix compile warnings in lookup_op
...@@ -30,12 +30,12 @@ class LookupTableKernel : public framework::OpKernel { ...@@ -30,12 +30,12 @@ class LookupTableKernel : public framework::OpKernel {
auto ids_t = context.Input<Tensor>("Ids"); // int tensor auto ids_t = context.Input<Tensor>("Ids"); // int tensor
auto output_t = context.Output<Tensor>("Out"); // float tensor auto output_t = context.Output<Tensor>("Out"); // float tensor
size_t N = table_t->dims()[0]; int N = table_t->dims()[0];
size_t D = table_t->dims()[1]; int D = table_t->dims()[1];
auto ids = ids_t->data<int32_t>(); auto ids = ids_t->data<int32_t>();
auto table = table_t->data<T>(); auto table = table_t->data<T>();
auto output = output_t->mutable_data<T>(context.GetPlace()); auto output = output_t->mutable_data<T>(context.GetPlace());
for (size_t i = 0; i < product(ids_t->dims()); ++i) { for (ssize_t i = 0; i < product(ids_t->dims()); ++i) {
PADDLE_ENFORCE_LT(ids[i], N); PADDLE_ENFORCE_LT(ids[i], N);
PADDLE_ENFORCE_GE(ids[i], 0); PADDLE_ENFORCE_GE(ids[i], 0);
memcpy(output + i * D, table + ids[i] * D, D * sizeof(T)); memcpy(output + i * D, table + ids[i] * D, D * sizeof(T));
...@@ -51,8 +51,8 @@ class LookupTableGradKernel : public framework::OpKernel { ...@@ -51,8 +51,8 @@ class LookupTableGradKernel : public framework::OpKernel {
auto d_output_t = context.Input<Tensor>(framework::GradVarName("Out")); auto d_output_t = context.Input<Tensor>(framework::GradVarName("Out"));
auto d_table_t = context.Output<Tensor>(framework::GradVarName("W")); auto d_table_t = context.Output<Tensor>(framework::GradVarName("W"));
size_t N = d_table_t->dims()[0]; int N = d_table_t->dims()[0];
size_t D = d_table_t->dims()[1]; int D = d_table_t->dims()[1];
auto ids = ids_t->data<int32_t>(); auto ids = ids_t->data<int32_t>();
const T* d_output = d_output_t->data<T>(); const T* d_output = d_output_t->data<T>();
T* d_table = d_table_t->mutable_data<T>(context.GetPlace()); T* d_table = d_table_t->mutable_data<T>(context.GetPlace());
...@@ -61,10 +61,10 @@ class LookupTableGradKernel : public framework::OpKernel { ...@@ -61,10 +61,10 @@ class LookupTableGradKernel : public framework::OpKernel {
t.device(context.GetEigenDevice<platform::CPUPlace>()) = t.device(context.GetEigenDevice<platform::CPUPlace>()) =
t.constant(static_cast<T>(0)); t.constant(static_cast<T>(0));
for (size_t i = 0; i < product(ids_t->dims()); ++i) { for (ssize_t i = 0; i < product(ids_t->dims()); ++i) {
PADDLE_ENFORCE_LT(ids[i], N); PADDLE_ENFORCE_LT(ids[i], N);
PADDLE_ENFORCE_GE(ids[i], 0); PADDLE_ENFORCE_GE(ids[i], 0);
for (size_t j = 0; j < D; ++j) { for (int j = 0; j < D; ++j) {
d_table[ids[i] * D + j] += d_output[i * D + j]; d_table[ids[i] * D + j] += d_output[i * D + j];
} }
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册