Commit 7efdf05a authored by fengjiayi

make look_up_op support tensor ids

Parent 56b50ee4
@@ -32,11 +32,16 @@ class LookupTableOp : public framework::OperatorWithKernel {
     auto table_dims = ctx->GetInputDim("W");
     auto ids_dims = ctx->GetInputDim("Ids");
+    int ids_rank = ids_dims.size();
 
-    PADDLE_ENFORCE_EQ(ids_dims.size(), 2);
-    PADDLE_ENFORCE_EQ(ids_dims[1], 1);
+    PADDLE_ENFORCE_EQ(table_dims.size(), 2);
+    PADDLE_ENFORCE_EQ(ids_dims[ids_rank - 1], 1,
+                      "The last dimension of the 'Ids' tensor must be 1.");
 
-    ctx->SetOutputDim("Out", {ids_dims[0], table_dims[1]});
+    auto output_dims =
+        framework::vectorize(framework::slice_ddim(ids_dims, 0, ids_rank - 1));
+    output_dims.push_back(table_dims[1]);
+    ctx->SetOutputDim("Out", framework::make_ddim(output_dims));
 
     if (ctx->GetOutputsVarType("Out")[0] ==
         framework::proto::VarType::LOD_TENSOR) {
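
In effect, the new InferShape keeps every leading dimension of Ids and appends the embedding width taken from W, instead of hard-coding a rank-2 column vector. Below is a minimal standalone sketch of that rule, assuming W has shape [vocab_size, emb_width] and the last dimension of Ids is 1; the helper name and shapes are illustrative, not Paddle APIs.

    #include <cassert>
    #include <cstdint>
    #include <iostream>
    #include <vector>

    // Illustrative helper (not a Paddle API): Out.dims = Ids.dims[0..rank-2] + {W.dims[1]}.
    std::vector<int64_t> LookupOutDims(const std::vector<int64_t>& ids_dims,
                                       const std::vector<int64_t>& table_dims) {
      assert(table_dims.size() == 2);  // W is [vocab_size, emb_width]
      assert(ids_dims.back() == 1);    // last dimension of Ids must be 1
      std::vector<int64_t> out(ids_dims.begin(), ids_dims.end() - 1);
      out.push_back(table_dims[1]);    // append the embedding width
      return out;
    }

    int main() {
      // e.g. Ids = [batch=4, seq_len=16, 1], W = [vocab=10000, emb=128]
      for (int64_t d : LookupOutDims({4, 16, 1}, {10000, 128})) std::cout << d << ' ';
      std::cout << '\n';  // prints: 4 16 128
    }

So an Ids tensor of shape [4, 16, 1] now yields Out of shape [4, 16, 128], where the old check rejected anything that was not [N, 1].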
@@ -61,8 +66,7 @@ class LookupTableOpMaker : public framework::OpProtoAndCheckerMaker {
     AddInput("Ids",
              "An input with type int32 or int64 "
              "contains the ids to be looked up in W. "
-             "Ids must be a column vector with rank = 2. "
-             "The 2nd dimension size must be 1.");
+             "The last dimension size must be 1.");
     AddOutput("Out", "The lookup results, which have the same type as W.");
     AddAttr<bool>("is_sparse",
                   "(boolean, default false) "
......
@@ -118,23 +118,23 @@ class LookupTableGradCUDAKernel : public framework::OpKernel<T> {
       auto *d_table = context.Output<SelectedRows>(framework::GradVarName("W"));
 
       auto *ids_data = ids->data<int64_t>();
-      auto ids_dim = ids->dims();
+      int64_t ids_num = ids->numel();
 
       auto stream = dev_ctx.stream();
       // copy GPU memory to CPU pinned memory
       framework::Vector<int64_t> new_rows;
-      new_rows.resize(ids_dim[0]);
+      new_rows.resize(ids_num);
       auto gpu_place = boost::get<platform::CUDAPlace>(context.GetPlace());
 
       // TODO(yuyang18): Strange code here.
       memory::Copy(platform::CPUPlace(),
                    new_rows.CUDAMutableData(context.GetPlace()), gpu_place,
-                   ids_data, ids_dim[0] * sizeof(int64_t), stream);
+                   ids_data, ids_num * sizeof(int64_t), stream);
 
       d_table->set_rows(new_rows);
 
       auto *d_table_value = d_table->mutable_value();
-      d_table_value->Resize({ids_dim[0], table->dims()[1]});
+      d_table_value->Resize({ids_num, table->dims()[1]});
       d_table_value->mutable_data<T>(context.GetPlace());
 
       auto *d_table_data = d_table_value->data<T>();
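
Both gradient kernels replace ids_dim[0] with ids->numel(): once the last dimension of Ids is pinned to 1, the element count equals the number of lookups regardless of rank, so the single contiguous copy of ids_num int64 values above still covers every id. A small illustrative check (plain C++, hypothetical shape):

    #include <cstdint>
    #include <iostream>
    #include <vector>

    int main() {
      // Hypothetical Ids shape [batch=4, seq_len=16, 1]; with the trailing 1,
      // the element count equals the number of ids to look up.
      std::vector<int64_t> ids_dims = {4, 16, 1};
      int64_t ids_num = 1;
      for (int64_t d : ids_dims) ids_num *= d;  // what Tensor::numel() computes
      std::cout << "ids to copy: " << ids_num << '\n';  // 64
    }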
......
@@ -109,17 +109,17 @@ class LookupTableGradKernel : public framework::OpKernel<T> {
       auto *d_table = context.Output<SelectedRows>(framework::GradVarName("W"));
 
       auto *ids_data = ids->data<int64_t>();
-      auto ids_dim = ids->dims();
+      int64_t ids_num = ids->numel();
 
       framework::Vector<int64_t> new_rows;
-      new_rows.reserve(ids_dim[0]);
-      for (int64_t i = 0; i < ids_dim[0]; i++) {
+      new_rows.reserve(ids_num);
+      for (int64_t i = 0; i < ids_num; i++) {
         new_rows.push_back(ids_data[i]);
       }
       d_table->set_rows(new_rows);
 
       auto *d_table_value = d_table->mutable_value();
-      d_table_value->Resize({ids_dim[0], table_dim[1]});
+      d_table_value->Resize({ids_num, table_dim[1]});
       d_table_value->mutable_data<T>(context.GetPlace());
 
       d_table->set_height(table_dim[0]);
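
In this sparse branch the W-gradient is emitted as a SelectedRows whose rows are simply the flattened ids (duplicates included) and whose value tensor is resized to [ids_num, table_dim[1]]. A rough standalone sketch of that layout, with std:: containers standing in for framework::Vector and SelectedRows and made-up ids:

    #include <cstdint>
    #include <iostream>
    #include <vector>

    int main() {
      // Made-up flattened ids for an Ids tensor of shape [2, 3, 1].
      std::vector<int64_t> ids_data = {7, 0, 7, 3, 3, 1};
      const int64_t ids_num = static_cast<int64_t>(ids_data.size());
      const int64_t emb_width = 128;  // plays the role of table_dim[1]

      // One row index per lookup, duplicates kept.
      std::vector<int64_t> rows(ids_data.begin(), ids_data.end());

      // Dense buffer of shape [ids_num, emb_width]; row i holds the output
      // gradient for lookup i.
      std::vector<float> value(static_cast<size_t>(ids_num * emb_width), 0.0f);

      std::cout << rows.size() << " rows x " << emb_width << " cols\n";  // 6 x 128
    }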
@@ -135,7 +135,6 @@ class LookupTableGradKernel : public framework::OpKernel<T> {
       auto *d_table = context.Output<LoDTensor>(framework::GradVarName("W"));
 
       auto *ids_data = ids->data<int64_t>();
-      auto ids_dim = ids->dims();
 
       int N = table_dim[0];
       int D = d_output->dims()[1];
......