diff --git a/paddle/fluid/operators/lookup_table_op.cc b/paddle/fluid/operators/lookup_table_op.cc index 3e8f3ec5c5cd683343bcbdfc2388bd37c25e00f9..d77b095c5d783a2a9fab87eb8b458117a6a3d225 100644 --- a/paddle/fluid/operators/lookup_table_op.cc +++ b/paddle/fluid/operators/lookup_table_op.cc @@ -32,11 +32,16 @@ class LookupTableOp : public framework::OperatorWithKernel { auto table_dims = ctx->GetInputDim("W"); auto ids_dims = ctx->GetInputDim("Ids"); + int ids_rank = ids_dims.size(); - PADDLE_ENFORCE_EQ(ids_dims.size(), 2); - PADDLE_ENFORCE_EQ(ids_dims[1], 1); + PADDLE_ENFORCE_EQ(table_dims.size(), 2); + PADDLE_ENFORCE_EQ(ids_dims[ids_rank - 1], 1, + "The last dimension of the 'Ids' tensor must be 1."); - ctx->SetOutputDim("Out", {ids_dims[0], table_dims[1]}); + auto output_dims = + framework::vectorize(framework::slice_ddim(ids_dims, 0, ids_rank - 1)); + output_dims.push_back(table_dims[1]); + ctx->SetOutputDim("Out", framework::make_ddim(output_dims)); if (ctx->GetOutputsVarType("Out")[0] == framework::proto::VarType::LOD_TENSOR) { @@ -61,8 +66,7 @@ class LookupTableOpMaker : public framework::OpProtoAndCheckerMaker { AddInput("Ids", "An input with type int32 or int64 " "contains the ids to be looked up in W. " - "Ids must be a column vector with rank = 2. " - "The 2nd dimension size must be 1."); + "The last dimension size must be 1."); AddOutput("Out", "The lookup results, which have the same type as W."); AddAttr("is_sparse", "(boolean, default false) " diff --git a/paddle/fluid/operators/lookup_table_op.cu b/paddle/fluid/operators/lookup_table_op.cu index 27483372b93a850d313445386c7973838c4a0710..e8462129fd3b08683b44f09590d284f45b48d0fa 100644 --- a/paddle/fluid/operators/lookup_table_op.cu +++ b/paddle/fluid/operators/lookup_table_op.cu @@ -118,23 +118,23 @@ class LookupTableGradCUDAKernel : public framework::OpKernel { auto *d_table = context.Output(framework::GradVarName("W")); auto *ids_data = ids->data(); - auto ids_dim = ids->dims(); + int64_t ids_num = ids->numel(); auto stream = dev_ctx.stream(); // copy GPU memory to CPU pinned memory framework::Vector new_rows; - new_rows.resize(ids_dim[0]); + new_rows.resize(ids_num); auto gpu_place = boost::get(context.GetPlace()); // TODO(yuyang18): Strange code here. memory::Copy(platform::CPUPlace(), new_rows.CUDAMutableData(context.GetPlace()), gpu_place, - ids_data, ids_dim[0] * sizeof(int64_t), stream); + ids_data, ids_num * sizeof(int64_t), stream); d_table->set_rows(new_rows); auto *d_table_value = d_table->mutable_value(); - d_table_value->Resize({ids_dim[0], table->dims()[1]}); + d_table_value->Resize({ids_num, table->dims()[1]}); d_table_value->mutable_data(context.GetPlace()); auto *d_table_data = d_table_value->data(); diff --git a/paddle/fluid/operators/lookup_table_op.h b/paddle/fluid/operators/lookup_table_op.h index c9f074ca0e8dafb374dc9368165df5af5053a6b8..ccf77af6a67e16f3249c0965602cde247f8bbe2a 100644 --- a/paddle/fluid/operators/lookup_table_op.h +++ b/paddle/fluid/operators/lookup_table_op.h @@ -109,17 +109,17 @@ class LookupTableGradKernel : public framework::OpKernel { auto *d_table = context.Output(framework::GradVarName("W")); auto *ids_data = ids->data(); - auto ids_dim = ids->dims(); + int64_t ids_num = ids->numel(); framework::Vector new_rows; - new_rows.reserve(ids_dim[0]); - for (int64_t i = 0; i < ids_dim[0]; i++) { + new_rows.reserve(ids_num); + for (int64_t i = 0; i < ids_num; i++) { new_rows.push_back(ids_data[i]); } d_table->set_rows(new_rows); auto *d_table_value = d_table->mutable_value(); - d_table_value->Resize({ids_dim[0], table_dim[1]}); + d_table_value->Resize({ids_num, table_dim[1]}); d_table_value->mutable_data(context.GetPlace()); d_table->set_height(table_dim[0]); @@ -135,7 +135,6 @@ class LookupTableGradKernel : public framework::OpKernel { auto *d_table = context.Output(framework::GradVarName("W")); auto *ids_data = ids->data(); - auto ids_dim = ids->dims(); int N = table_dim[0]; int D = d_output->dims()[1];