diff --git a/paddle/operators/lookup_table_op.h b/paddle/operators/lookup_table_op.h index 4da8079b91624c3510cae89fd599a7035a4c7477..877b36cef4ea9cdaaaf37c97d5e5bfce55b91436 100644 --- a/paddle/operators/lookup_table_op.h +++ b/paddle/operators/lookup_table_op.h @@ -30,12 +30,12 @@ class LookupTableKernel : public framework::OpKernel { auto ids_t = context.Input("Ids"); // int tensor auto output_t = context.Output("Out"); // float tensor - size_t N = table_t->dims()[0]; - size_t D = table_t->dims()[1]; + int N = table_t->dims()[0]; + int D = table_t->dims()[1]; auto ids = ids_t->data(); auto table = table_t->data(); auto output = output_t->mutable_data(context.GetPlace()); - for (size_t i = 0; i < product(ids_t->dims()); ++i) { + for (ssize_t i = 0; i < product(ids_t->dims()); ++i) { PADDLE_ENFORCE_LT(ids[i], N); PADDLE_ENFORCE_GE(ids[i], 0); memcpy(output + i * D, table + ids[i] * D, D * sizeof(T)); @@ -51,8 +51,8 @@ class LookupTableGradKernel : public framework::OpKernel { auto d_output_t = context.Input(framework::GradVarName("Out")); auto d_table_t = context.Output(framework::GradVarName("W")); - size_t N = d_table_t->dims()[0]; - size_t D = d_table_t->dims()[1]; + int N = d_table_t->dims()[0]; + int D = d_table_t->dims()[1]; auto ids = ids_t->data(); const T* d_output = d_output_t->data(); T* d_table = d_table_t->mutable_data(context.GetPlace()); @@ -61,10 +61,10 @@ class LookupTableGradKernel : public framework::OpKernel { t.device(context.GetEigenDevice()) = t.constant(static_cast(0)); - for (size_t i = 0; i < product(ids_t->dims()); ++i) { + for (ssize_t i = 0; i < product(ids_t->dims()); ++i) { PADDLE_ENFORCE_LT(ids[i], N); PADDLE_ENFORCE_GE(ids[i], 0); - for (size_t j = 0; j < D; ++j) { + for (int j = 0; j < D; ++j) { d_table[ids[i] * D + j] += d_output[i * D + j]; } }