diff --git a/paddle/fluid/operators/fused_embedding_seq_pool_op.cc b/paddle/fluid/operators/fused_embedding_seq_pool_op.cc index e862769051719d386cd22de08a85fa5abb82e3b0..6b6b898d4c720b0e1d90d77ae58aeea7b0d8ba07 100644 --- a/paddle/fluid/operators/fused_embedding_seq_pool_op.cc +++ b/paddle/fluid/operators/fused_embedding_seq_pool_op.cc @@ -42,8 +42,14 @@ class FusedEmbeddingSeqPoolOp : public framework::OperatorWithKernel { // we only support sum now PADDLE_ENFORCE_EQ(combiner, "sum"); + int64_t last_dim = table_dims[1]; + for (int i = 1; i != ids_dims.size(); ++i) { + last_dim *= ids_dims[i]; + } + if (ctx->IsRuntime()) { - Variable* ids_var = boost::get(ctx->GetInputVarPtrs("Ids")[0]); + framework::Variable* ids_var = + boost::get(ctx->GetInputVarPtrs("Ids")[0]); const auto& ids_lod = ids_var->Get().lod(); // in run time, the LoD of ids must be 1 @@ -51,20 +57,20 @@ class FusedEmbeddingSeqPoolOp : public framework::OperatorWithKernel { "The LoD level of Input(Ids) must be 1"); PADDLE_ENFORCE_GE(ids_lod[0].size(), 1u, "The LoD could NOT be empty"); - size_t batch_size = ids_lod[0].size() - 1; + int64_t batch_size = ids_lod[0].size() - 1; // in run time, the shape from Ids -> output // should be [seq_length, 1] -> [batch_size, embedding_size] - ctx->SetOutputDim("Out", - framework::make_ddim({batch_size, table_dims[1]})); + ctx->SetOutputDim("Out", framework::make_ddim({batch_size, last_dim})); } else { // in compile time, the lod level of ids must be 1 - VarDesc* ids_desc = boost::get(ctx->GetInputVarPtrs("Ids")[0]); + framework::VarDesc* ids_desc = + boost::get(ctx->GetInputVarPtrs("Ids")[0]); PADDLE_ENFORCE_EQ(ids_desc->GetLoDLevel(), 1); // in compile time, the shape from Ids -> output // should be [-1, 1] -> [-1, embedding_size] - ctx->SetOutputDim("Out", framework::make_ddim({-1, table_dims[1]})); + ctx->SetOutputDim("Out", framework::make_ddim({-1, last_dim})); } } diff --git a/paddle/fluid/operators/fused_embedding_seq_pool_op.h b/paddle/fluid/operators/fused_embedding_seq_pool_op.h index 5af234b9375dfaa89c7bb61e6fa41bd9ca356057..7385c8da334c3ee31230efa0da7ba238f7df8f3c 100644 --- a/paddle/fluid/operators/fused_embedding_seq_pool_op.h +++ b/paddle/fluid/operators/fused_embedding_seq_pool_op.h @@ -31,31 +31,38 @@ using LoDTensor = framework::LoDTensor; using SelectedRows = framework::SelectedRows; using DDim = framework::DDim; -template +template struct EmbeddingVSumFunctor { - void operator()(const DeviceContext &context, LoDTensor *table_t, - LoDTensor *ids_t, LoDTensor *output_t) { + void operator()(const framework::ExecutionContext &context, + const LoDTensor *table_t, const LoDTensor *ids_t, + LoDTensor *output_t) { auto *table = table_t->data(); - int64_t row_number = table->dims()[0]; - int64_t row_width = table->dims()[1]; + int64_t row_number = table_t->dims()[0]; + int64_t row_width = table_t->dims()[1]; + int64_t last_dim = output_t->dims()[1]; int64_t *ids = const_cast(ids_t->data()); - auto ids_lod = ids_t->LoD()[0]; + auto ids_lod = ids_t->lod()[0]; + int64_t ids_count = ids_t->numel() / ids_lod.back(); + auto *output = output_t->mutable_data(context.GetPlace()); - auto blas = math::GetBlas(context); + auto blas = math::GetBlas(context); for (int64_t i = 0; i != ids_lod.size() - 1; ++i) { - size_t begin = ids_lod[i]; + for (int64_t j = 0; j != ids_count; ++j) { + size_t begin = ids_lod[i] * ids_count; - PADDLE_ENFORCE_LT(ids[begin], row_number); - PADDLE_ENFORCE_GE(ids[begin], 0, "ids %d", i); - blas.VCOPY(row_width, table + ids[begin] * row_width, - output + i * row_width); + PADDLE_ENFORCE_LT(ids[begin], row_number); + PADDLE_ENFORCE_GE(ids[begin], 0, "ids %d", i); + blas.VCOPY(row_width, table + ids[begin] * row_width, + output + i * last_dim + j * row_width); + } - for (int64_t r = ids_lod[i] + 1; r < ids_lod[i + 1]; ++r) { + for (int64_t r = (ids_lod[i] + 1) * ids_count; + r < ids_lod[i + 1] * ids_count; ++r) { PADDLE_ENFORCE_LT(ids[r], row_number); PADDLE_ENFORCE_GE(ids[r], 0, "ids %d", i); blas.AXPY(row_width, 1., table + ids[r] * row_width, - output + i * row_width); + output + i * row_width + (r % ids_count) * row_width); } } } @@ -65,14 +72,14 @@ template class FusedEmbeddingSeqPoolKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext &context) const override { - LoDTensor *ids_t = context.Input("Ids"); // int tensor - LoDTensor *output_t = context.Output("Out"); // float tensor - LoDTensor *table_var = context.Input("W"); + const LoDTensor *ids_t = context.Input("Ids"); // int tensor + LoDTensor *output_t = context.Output("Out"); // float tensor + const LoDTensor *table_var = context.Input("W"); const std::string &combiner_type = context.Attr("combiner"); if (combiner_type == "sum") { EmbeddingVSumFunctor functor; - functor(context.template device_context(), ids_t, output_t, table_var); + functor(context, table_var, ids_t, output_t); } } }; @@ -105,7 +112,7 @@ class FusedEmbeddingSeqPoolGradKernel : public framework::OpKernel { auto *ids_data = ids->data(); int64_t ids_num = ids->numel(); auto lod = ids->lod()[0]; - int64_t row_width = table_dim[1]; + int64_t row_width = d_output->dims()[1]; framework::Vector new_rows; new_rows.resize(ids_num); @@ -113,11 +120,11 @@ class FusedEmbeddingSeqPoolGradKernel : public framework::OpKernel { d_table->set_rows(new_rows); auto *d_table_value = d_table->mutable_value(); - d_table_value->Resize({ids_num, row_width}); + d_table_value->Resize({ids_num, table_dim[1]}); T *d_table_data = d_table_value->mutable_data(context.GetPlace()); const T *d_output_data = d_output->data(); - auto blas = math::GetBlas(context); + auto blas = math::GetBlas(context); for (int i = 0; i < static_cast(lod.size()) - 1; ++i) { int64_t h = static_cast(lod[i + 1] - lod[i]); int64_t in_offset = lod[i] * row_width;