From 17c8014fcd2071920a605f12951d4f6ae1ddcab9 Mon Sep 17 00:00:00 2001
From: minqiyang
Date: Tue, 6 Nov 2018 17:42:43 +0800
Subject: [PATCH] Complete implementation

test=develop
---
 .../operators/fused_embedding_seq_pool_op.cc |   6 +
 .../operators/fused_embedding_seq_pool_op.h  | 182 ++++++------------
 2 files changed, 63 insertions(+), 125 deletions(-)

diff --git a/paddle/fluid/operators/fused_embedding_seq_pool_op.cc b/paddle/fluid/operators/fused_embedding_seq_pool_op.cc
index 5ebaf865fcd..e8627690517 100644
--- a/paddle/fluid/operators/fused_embedding_seq_pool_op.cc
+++ b/paddle/fluid/operators/fused_embedding_seq_pool_op.cc
@@ -93,6 +93,12 @@ class FusedEmbeddingSeqPoolOpMaker : public framework::OpProtoAndCheckerMaker {
               "are supported, sum computes the weighted sum of the "
               "embedding results for each row.")
         .SetDefault("sum");
+    // NOTE(minqiyang): grad_inplace is a temporary attribute;
+    // please do NOT set this attribute in the Python layer.
+    AddAttr<bool>("grad_inplace",
+                  "(boolean, default false) "
+                  "If the grad op reuses the input's variable.")
+        .SetDefault(false);
     AddAttr<bool>("is_sparse",
                   "(boolean, default false) "
                   "Sparse update.")
diff --git a/paddle/fluid/operators/fused_embedding_seq_pool_op.h b/paddle/fluid/operators/fused_embedding_seq_pool_op.h
index 24cffc60a80..5af234b9375 100644
--- a/paddle/fluid/operators/fused_embedding_seq_pool_op.h
+++ b/paddle/fluid/operators/fused_embedding_seq_pool_op.h
@@ -31,62 +31,54 @@ using LoDTensor = framework::LoDTensor;
 using SelectedRows = framework::SelectedRows;
 using DDim = framework::DDim;
 
+template <typename T>
+struct EmbeddingVSumFunctor {
+  void operator()(const framework::ExecutionContext &context,
+                  const LoDTensor *table_t, const LoDTensor *ids_t,
+                  LoDTensor *output_t) {
+    auto *table = table_t->data<T>();
+    int64_t row_number = table_t->dims()[0];
+    int64_t row_width = table_t->dims()[1];
+    const int64_t *ids = ids_t->data<int64_t>();
+    auto ids_lod = ids_t->lod()[0];
+    auto *output = output_t->mutable_data<T>(context.GetPlace());
+
+    auto blas = math::GetBlas<platform::CPUDeviceContext, T>(context);
+    for (size_t i = 0; i != ids_lod.size() - 1; ++i) {
+      size_t begin = ids_lod[i];
+
+      PADDLE_ENFORCE_LT(ids[begin], row_number);
+      PADDLE_ENFORCE_GE(ids[begin], 0, "ids %d", i);
+      blas.VCOPY(row_width, table + ids[begin] * row_width,
+                 output + i * row_width);
+
+      for (size_t r = ids_lod[i] + 1; r < ids_lod[i + 1]; ++r) {
+        PADDLE_ENFORCE_LT(ids[r], row_number);
+        PADDLE_ENFORCE_GE(ids[r], 0, "ids %d", i);
+        blas.AXPY(row_width, static_cast<T>(1), table + ids[r] * row_width,
+                  output + i * row_width);
+      }
+    }
+  }
+};
+
 template <typename T>
-class LookupTableKernel : public framework::OpKernel<T> {
+class FusedEmbeddingSeqPoolKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext &context) const override {
-    auto *ids_t = context.Input<LoDTensor>("Ids");      // int tensor
-    auto *output_t = context.Output<LoDTensor>("Out");  // float tensor
-    auto *table_var = context.InputVar("W");
-
-    int64_t padding_idx = context.Attr<int64_t>("padding_idx");
-    int64_t *ids = const_cast<int64_t *>(ids_t->data<int64_t>());
-    int64_t ids_numel = ids_t->numel();
-
-    if (table_var->IsType<LoDTensor>()) {
-      auto *table_t = context.Input<LoDTensor>("W");
-      int64_t row_number = table_t->dims()[0];
-      int64_t row_width = table_t->dims()[1];
-
-      auto *table = table_t->data<T>();
-      auto *output = output_t->mutable_data<T>(context.GetPlace());
-
-      for (int64_t i = 0; i < ids_numel; ++i) {
-        if (padding_idx != kNoPadding && ids[i] == padding_idx) {
-          memset(output + i * row_width, 0, row_width * sizeof(T));
-        } else {
-          PADDLE_ENFORCE_LT(ids[i], row_number);
-          PADDLE_ENFORCE_GE(ids[i], 0, "ids %d", i);
-          memcpy(output + i * row_width, table + ids[i] * row_width,
-                 row_width * sizeof(T));
-        }
-      }
-    } else if (table_var->IsType<SelectedRows>()) {
-      const auto &table_t = table_var->Get<SelectedRows>();
-      int64_t row_width = table_t.value().dims()[1];
-      const auto *table = table_t.value().data<T>();
-      auto *output = output_t->mutable_data<T>(context.GetPlace());
-
-      auto blas = math::GetBlas<platform::CPUDeviceContext, T>(context);
-      for (int64_t i = 0; i < ids_numel; ++i) {
-        if (padding_idx != kNoPadding && ids[i] == padding_idx) {
-          memset(output + i * row_width, 0, row_width * sizeof(T));
-        } else {
-          PADDLE_ENFORCE_GE(ids[i], 0);
-          auto id_index = table_t.Index(ids[i]);
-          PADDLE_ENFORCE_GE(id_index, 0, "the input key should be exists.");
-          // memcpy(output + i * row_width, table + id_index * row_width,
-          //        row_width * sizeof(T));
-          blas.VCOPY(row_width, table + id_index * row_width,
-                     output + i * row_width);
-        }
-      }
+    const LoDTensor *ids_t = context.Input<LoDTensor>("Ids");  // int tensor
+    LoDTensor *output_t = context.Output<LoDTensor>("Out");    // float tensor
+    const LoDTensor *table_var = context.Input<LoDTensor>("W");
+    const std::string &combiner_type = context.Attr<std::string>("combiner");
+
+    if (combiner_type == "sum") {
+      EmbeddingVSumFunctor<T> functor;
+      functor(context, table_var, ids_t, output_t);
     }
   }
 };
 
 template <typename T>
-class LookupTableGradKernel : public framework::OpKernel<T> {
+class FusedEmbeddingSeqPoolGradKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext &context) const override {
     auto *table_var = context.InputVar("W");
@@ -106,97 +98,37 @@ class LookupTableGradKernel : public framework::OpKernel<T> {
     // Since paddings are not trainable and fixed in forward, the gradient of
     // paddings makes no sense and we don't deal with it in backward.
     if (is_sparse) {
-      // auto start = std::chrono::system_clock::now();
       auto *ids = context.Input<LoDTensor>("Ids");
       auto *d_output = context.Input<LoDTensor>(framework::GradVarName("Out"));
       auto *d_table = context.Output<SelectedRows>(framework::GradVarName("W"));
 
       auto *ids_data = ids->data<int64_t>();
       int64_t ids_num = ids->numel();
-      // auto end = std::chrono::system_clock::now();
-      // std::chrono::duration<double> diff = end - start;
+      auto lod = ids->lod()[0];
+      int64_t row_width = table_dim[1];
 
-      // auto copy_start = std::chrono::system_clock::now();
-      std::vector<int64_t> new_rows;
+      framework::Vector<int64_t> new_rows;
       new_rows.resize(ids_num);
       std::memcpy(&new_rows[0], ids_data, ids_num * sizeof(int64_t));
-      // for (int64_t i = 0; i < ids_num; i++) {
-      //   new_rows.push_back(ids_data[i]);
-      // }
-      // auto copy_end = std::chrono::system_clock::now();
-      // std::chrono::duration<double> copy_diff = copy_end - copy_start;
-      // diff += copy_diff;
-      // LOG(ERROR) << "run emb_grad copy end, cost: " << copy_diff.count()
-      // << " " << ids_num;
-
-      // copy_start = std::chrono::system_clock::now();
       d_table->set_rows(new_rows);
 
       auto *d_table_value = d_table->mutable_value();
-      d_table_value->Resize({ids_num, table_dim[1]});
-      d_table_value->ShareDataWith(*d_output);
-      // d_table_value->mutable_data<T>(context.GetPlace());
-
-      // // copy_end = std::chrono::system_clock::now();
-      // // copy_diff = copy_end - copy_start;
-      // // diff += copy_diff;
-      // // LOG(ERROR) << "run emb_grad resize table end, cost: " <<
-      // // copy_diff.count() << " " << ids_num;
-
-      // // copy_start = std::chrono::system_clock::now();
-      // d_table->set_height(table_dim[0]);
-
-      // auto *d_output_data = d_output->data<T>();
-      // auto *d_table_data = d_table_value->data<T>();
-
-      // // copy_end = std::chrono::system_clock::now();
-      // // copy_diff = copy_end - copy_start;
-      // // diff += copy_diff;
-      // // LOG(ERROR) << "run emb_grad set height end, cost: " <<
-      // // copy_diff.count() << " " << ids_num;
-
-      // auto d_output_dims = d_output->dims();
-      // PADDLE_ENFORCE_EQ(
-      //     d_table_value->dims(),
-      //     framework::flatten_to_2d(d_output_dims, d_output_dims.size() - 1));
-      // // copy_start = std::chrono::system_clock::now();
-      // auto blas = math::GetBlas<platform::CPUDeviceContext, T>(context);
-      // blas.VCOPY(d_output->numel(), d_output_data, d_table_data);
-      // cblas_scopy(d_output->numel(), d_output_data, 1, d_table_data, 1);
-      // // for (int i = 0; i != d_output->numel(); ++i) {
-      // //   *(d_table_data++) = *(d_output_data++);
-      // // }
-      // // memcpy(d_table_data, d_output_data, sizeof(T) * d_output->numel());
-      // // copy_end = std::chrono::system_clock::now();
-      // // copy_diff = copy_end - copy_start;
-      // // diff += copy_diff;
-      // // LOG(ERROR) << "run emb_grad core end, cost: " << copy_diff.count()
-      // // << " " << ids_num << " " << d_output->numel();
-
-      // // LOG(ERROR) << "run emb_grad end, cost: " << diff.count();
-    } else {
-      auto *ids = context.Input<LoDTensor>("Ids");
-      auto *d_output = context.Input<LoDTensor>(framework::GradVarName("Out"));
-      auto *d_table = context.Output<LoDTensor>(framework::GradVarName("W"));
-
-      auto *ids_data = ids->data<int64_t>();
-
-      int N = table_dim[0];
-      int D = table_dim[1];
-
-      auto *d_output_data = d_output->data<T>();
-      auto *d_table_data = d_table->mutable_data<T>(context.GetPlace());
-
-      memset(d_table_data, 0, d_table->numel() * sizeof(T));
-
-      for (int64_t i = 0; i < ids->numel(); ++i) {
-        PADDLE_ENFORCE_LT(ids_data[i], N);
-        PADDLE_ENFORCE_GE(ids_data[i], 0);
-        for (int j = 0; j < D; ++j) {
-          d_table_data[ids_data[i] * D + j] += d_output_data[i * D + j];
+      d_table_value->Resize({ids_num, row_width});
+      T *d_table_data = d_table_value->mutable_data<T>(context.GetPlace());
+      const T *d_output_data = d_output->data<T>();
+
+      auto blas = math::GetBlas<platform::CPUDeviceContext, T>(context);
+      for (int i = 0; i < static_cast<int>(lod.size()) - 1; ++i) {
+        int64_t h = static_cast<int64_t>(lod[i + 1] - lod[i]);
+        int64_t in_offset = lod[i] * row_width;
+        const T *out_pos = d_output_data + i * row_width;
+        T *in_pos = d_table_data + in_offset;
+        for (int64_t r = 0; r != h; ++r) {
+          blas.VCOPY(row_width, out_pos, in_pos + r * row_width);
         }
       }
+    } else {
+      LOG(ERROR)
+          << "Dense is not supported in fused_embedding_seq_pool_op now";
     }
   }
 };
-- 
GitLab
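
For reviewers who want to sanity-check the pooling semantics in isolation, here is a minimal standalone sketch of the computation EmbeddingVSumFunctor performs, written in plain C++ with no Paddle dependencies. EmbeddingSumPool and all parameter names below are illustrative only, not part of the patch; the sketch assumes float embeddings and a single-level LoD.

#include <cassert>
#include <cstdint>
#include <vector>

// table: row_number x row_width embedding matrix, row-major.
// ids:   flattened id sequence; lod marks segment offsets into ids,
//        e.g. lod = {0, 2, 5} means segments [0, 2) and [2, 5).
// Returns one pooled row per segment: out row i is the sum of the table
// rows selected by the ids in segment i.
std::vector<float> EmbeddingSumPool(const std::vector<float> &table,
                                    int64_t row_number, int64_t row_width,
                                    const std::vector<int64_t> &ids,
                                    const std::vector<size_t> &lod) {
  std::vector<float> out((lod.size() - 1) * row_width, 0.0f);
  for (size_t i = 0; i + 1 < lod.size(); ++i) {
    for (size_t r = lod[i]; r < lod[i + 1]; ++r) {
      assert(ids[r] >= 0 && ids[r] < row_number);  // same bounds checks as the op
      // Accumulate table row ids[r] into output row i, the scalar
      // equivalent of the VCOPY-then-AXPY loop in the functor.
      for (int64_t j = 0; j < row_width; ++j) {
        out[i * row_width + j] += table[ids[r] * row_width + j];
      }
    }
  }
  return out;
}

With lod = {0, 2, 5}, out row 0 is table[ids[0]] + table[ids[1]] and out row 1 sums the next three table rows, which is exactly what the functor computes per LoD segment.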
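
A matching sketch of the sparse backward path: instead of a dense table gradient, it builds a SelectedRows-style (rows, values) pair, where rows are the raw ids and every id in segment i receives a copy of d_out row i, since d(sum)/d(table row) is 1 for each id in the segment. EmbeddingSumPoolGrad and its parameters are again illustrative names, not part of the patch.

#include <cstdint>
#include <utility>
#include <vector>

// d_out: (lod.size() - 1) x row_width gradient of the pooled output.
// Returns (rows, values): one gradient row per input id, with the value
// row for every id in segment i copied from d_out row i.
std::pair<std::vector<int64_t>, std::vector<float>> EmbeddingSumPoolGrad(
    const std::vector<float> &d_out, int64_t row_width,
    const std::vector<int64_t> &ids, const std::vector<size_t> &lod) {
  std::vector<int64_t> rows(ids.begin(), ids.end());
  std::vector<float> values(ids.size() * row_width);
  for (size_t i = 0; i + 1 < lod.size(); ++i) {
    for (size_t r = lod[i]; r < lod[i + 1]; ++r) {
      // Scalar equivalent of the per-segment VCOPY loop in the grad kernel.
      for (int64_t j = 0; j < row_width; ++j) {
        values[r * row_width + j] = d_out[i * row_width + j];
      }
    }
  }
  return {std::move(rows), std::move(values)};
}

Note that duplicate ids deliberately stay as duplicate rows here, mirroring the patch, which stores the raw id list via set_rows and leaves any merging of repeated rows to downstream consumers of the SelectedRows.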