From f4c990e7b8493304b61249417aaaca45d95e5174 Mon Sep 17 00:00:00 2001 From: minqiyang Date: Mon, 7 Jan 2019 12:54:37 +0800 Subject: [PATCH] Add fused embedding ops --- .../fused/fused_embedding_seq_pool_op.cc | 194 ++++++++++++++++++ .../fused/fused_embedding_seq_pool_op.h | 142 +++++++++++++ 2 files changed, 336 insertions(+) create mode 100644 paddle/fluid/operators/fused/fused_embedding_seq_pool_op.cc create mode 100644 paddle/fluid/operators/fused/fused_embedding_seq_pool_op.h diff --git a/paddle/fluid/operators/fused/fused_embedding_seq_pool_op.cc b/paddle/fluid/operators/fused/fused_embedding_seq_pool_op.cc new file mode 100644 index 0000000000..fe4c73f472 --- /dev/null +++ b/paddle/fluid/operators/fused/fused_embedding_seq_pool_op.cc @@ -0,0 +1,194 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/operators/fused/fused_embedding_seq_pool_op.h" +#include "paddle/fluid/framework/var_type_inference.h" + +namespace paddle { +namespace operators { + +class FusedEmbeddingSeqPoolOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("W"), + "Input W of FusedEmbeddingSeqPoolOp should not be null."); + PADDLE_ENFORCE(ctx->HasInput("Ids"), + "Input Ids of FusedEmbeddingSeqPoolOp should not be null."); + PADDLE_ENFORCE(ctx->HasOutput("Out"), + "Output of FusedEmbeddingSeqPoolOp should not be null."); + + auto table_dims = ctx->GetInputDim("W"); + auto ids_dims = ctx->GetInputDim("Ids"); + const std::string& combiner = ctx->Attrs().Get("combiner"); + + PADDLE_ENFORCE_EQ(table_dims.size(), 2); + PADDLE_ENFORCE_GE(ids_dims.size(), 1, + "The dim size of the 'Ids' tensor must greater than 1."); + PADDLE_ENFORCE_EQ(ids_dims[ids_dims.size() - 1], 1, + "The last dimension of the 'Ids' tensor must be 1."); + // we only support sum now + PADDLE_ENFORCE_EQ(combiner, "sum"); + + int64_t last_dim = table_dims[1]; + for (int i = 1; i != ids_dims.size(); ++i) { + last_dim *= ids_dims[i]; + } + + if (ctx->IsRuntime()) { + framework::Variable* ids_var = + boost::get(ctx->GetInputVarPtrs("Ids")[0]); + const auto& ids_lod = ids_var->Get().lod(); + + // in run time, the LoD of ids must be 1 + PADDLE_ENFORCE(ids_lod.size(), 1u, + "The LoD level of Input(Ids) must be 1"); + PADDLE_ENFORCE_GE(ids_lod[0].size(), 1u, "The LoD could NOT be empty"); + + int64_t batch_size = ids_lod[0].size() - 1; + + // in run time, the shape from Ids -> output + // should be [seq_length, 1] -> [batch_size, embedding_size] + ctx->SetOutputDim("Out", framework::make_ddim({batch_size, last_dim})); + } else { + // in compile time, the lod level of ids must be 1 + framework::VarDesc* ids_desc = + boost::get(ctx->GetInputVarPtrs("Ids")[0]); + PADDLE_ENFORCE_EQ(ids_desc->GetLoDLevel(), 1); + + // in compile time, the shape from Ids -> output + // should be [-1, 1] -> [-1, embedding_size] + ctx->SetOutputDim("Out", framework::make_ddim({-1, last_dim})); + } + } + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + auto data_type = framework::GetDataTypeOfVar(ctx.InputVar("W")); + return framework::OpKernelType(data_type, ctx.device_context()); + } +}; + +class FusedEmbeddingSeqPoolOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddInput("W", + "(Tensor) The input represents embedding tensors, " + "which is a learnable parameter."); + AddInput("Ids", + "An input with type int32 or int64 " + "contains the ids to be looked up in W. " + "The last dimension size must be 1."); + AddOutput("Out", "The lookup results, which have the same type as W."); + AddAttr("combiner", + "(string, default sum) " + "A string specifying the reduction op. Currently sum " + "are supported, sum computes the weighted sum of the " + "embedding results for each row.") + .SetDefault("sum"); + // NOTE(minqiyang): grad_inplace is an temporal attribute, + // please do NOT set this attribute in python layer. + AddAttr("grad_inplace", + "(boolean, default false) " + "If the grad op reuse the input's variable.") + .SetDefault(false); + AddAttr("is_sparse", + "(boolean, default false) " + "Sparse update.") + .SetDefault(false); + AddComment(R"DOC( +FusedEmbeddingSeqPool Operator. + +Computes embeddings for the given ids and weights. + +This operator is used to perform lookups on the parameter W, +then computes the weighted sum of the lookups results for each row +and concatenated into a dense tensor. + +The input Ids should carry the LoD (Level of Details) information. +And the output will change the LoD information with input Ids. + +)DOC"); + } +}; + +class FusedEmbeddingSeqPoolOpGradDescMaker + : public framework::DefaultGradOpDescMaker { + using ::paddle::framework::DefaultGradOpDescMaker< + true>::DefaultGradOpDescMaker; + + protected: + virtual std::string GradOpType() const { + return "fused_embedding_seq_pool_grad"; + } +}; + +class FusedEmbeddingSeqPoolOpGrad : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override { + auto table_dims = ctx->GetInputDim("W"); + ctx->SetOutputDim(framework::GradVarName("W"), table_dims); + } + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + auto data_type = framework::GetDataTypeOfVar(ctx.InputVar("W")); + return framework::OpKernelType(data_type, ctx.device_context()); + } +}; + +class FusedEmbeddingSeqPoolOpGradVarTypeInference + : public framework::VarTypeInference { + public: + void operator()(const framework::OpDesc& op_desc, + framework::BlockDesc* block) const override { + auto out_var_name = op_desc.Output(framework::GradVarName("W")).front(); + auto attr = op_desc.GetAttr("is_sparse"); + bool is_sparse = boost::get(attr); + if (is_sparse) { + VLOG(3) << "fused_embedding_seq_pool_grad op " + << framework::GradVarName("W") << " is set to SelectedRows"; + block->Var(out_var_name) + ->SetType(framework::proto::VarType::SELECTED_ROWS); + } else { + VLOG(3) << "fused_embedding_seq_pool_grad op " + << framework::GradVarName("W") << " is set to LoDTensor"; + block->Var(out_var_name)->SetType(framework::proto::VarType::LOD_TENSOR); + } + block->Var(out_var_name)->SetDataType(block->Var("W")->GetDataType()); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OPERATOR(fused_embedding_seq_pool, ops::FusedEmbeddingSeqPoolOp, + ops::FusedEmbeddingSeqPoolOpGradDescMaker, + ops::FusedEmbeddingSeqPoolOpMaker); +REGISTER_OPERATOR(fused_embedding_seq_pool_grad, + ops::FusedEmbeddingSeqPoolOpGrad, + ops::FusedEmbeddingSeqPoolOpGradVarTypeInference); + +REGISTER_OP_CPU_KERNEL(fused_embedding_seq_pool, + ops::FusedEmbeddingSeqPoolKernel, + ops::FusedEmbeddingSeqPoolKernel); +REGISTER_OP_CPU_KERNEL(fused_embedding_seq_pool_grad, + ops::FusedEmbeddingSeqPoolGradKernel, + ops::FusedEmbeddingSeqPoolGradKernel); diff --git a/paddle/fluid/operators/fused/fused_embedding_seq_pool_op.h b/paddle/fluid/operators/fused/fused_embedding_seq_pool_op.h new file mode 100644 index 0000000000..38dfae8ad6 --- /dev/null +++ b/paddle/fluid/operators/fused/fused_embedding_seq_pool_op.h @@ -0,0 +1,142 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include +#include + +#include "paddle/fluid/framework/eigen.h" +#include "paddle/fluid/framework/lod_tensor.h" +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/framework/selected_rows.h" +#include "paddle/fluid/operators/math/blas.h" + +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; +using LoDTensor = framework::LoDTensor; +using SelectedRows = framework::SelectedRows; +using DDim = framework::DDim; + +template +struct EmbeddingVSumFunctor { + void operator()(const framework::ExecutionContext &context, + const LoDTensor *table_t, const LoDTensor *ids_t, + LoDTensor *output_t) { + auto *table = table_t->data(); + int64_t row_number = table_t->dims()[0]; + int64_t row_width = table_t->dims()[1]; + int64_t last_dim = output_t->dims()[1]; + int64_t *ids = const_cast(ids_t->data()); + auto ids_lod = ids_t->lod()[0]; + int64_t ids_count = ids_t->numel() / ids_lod.back(); + + auto *output = output_t->mutable_data(context.GetPlace()); + + auto blas = math::GetBlas(context); + for (int64_t i = 0; i != ids_lod.size() - 1; ++i) { + size_t begin = ids_lod[i] * ids_count; + for (int64_t j = 0; j != ids_count; ++j) { + PADDLE_ENFORCE_LT(ids[begin], row_number); + PADDLE_ENFORCE_GE(ids[begin], 0, "ids %d", i); + blas.VCOPY(row_width, table + ids[begin + j] * row_width, + output + i * last_dim + j * row_width); + } + + for (int64_t r = (ids_lod[i] + 1) * ids_count; + r < ids_lod[i + 1] * ids_count; ++r) { + PADDLE_ENFORCE_LT(ids[r], row_number); + PADDLE_ENFORCE_GE(ids[r], 0, "ids %d", i); + blas.AXPY(row_width, 1., table + ids[r] * row_width, + output + i * last_dim + (r % ids_count) * row_width); + } + } + } +}; + +template +class FusedEmbeddingSeqPoolKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext &context) const override { + const LoDTensor *ids_t = context.Input("Ids"); // int tensor + LoDTensor *output_t = context.Output("Out"); // float tensor + const LoDTensor *table_var = context.Input("W"); + const std::string &combiner_type = context.Attr("combiner"); + + if (combiner_type == "sum") { + EmbeddingVSumFunctor functor; + functor(context, table_var, ids_t, output_t); + } + } +}; + +template +class FusedEmbeddingSeqPoolGradKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext &context) const override { + auto *table_var = context.InputVar("W"); + DDim table_dim; + if (table_var->IsType()) { + table_dim = context.Input("W")->dims(); + } else if (table_var->IsType()) { + auto *table_t = context.Input("W"); + table_dim = table_t->value().dims(); + } else { + PADDLE_THROW( + "The parameter W of a LookupTable " + "must be either LoDTensor or SelectedRows"); + } + + bool is_sparse = context.Attr("is_sparse"); + // Since paddings are not trainable and fixed in forward, the gradient of + // paddings makes no sense and we don't deal with it in backward. + if (is_sparse) { + auto *ids = context.Input("Ids"); + auto *d_output = context.Input(framework::GradVarName("Out")); + auto *d_table = context.Output(framework::GradVarName("W")); + + auto *ids_data = ids->data(); + int64_t ids_num = ids->numel(); + auto lod = ids->lod()[0]; + int64_t row_width = d_output->dims()[1]; + + framework::Vector *new_rows = d_table->mutable_rows(); + new_rows->resize(ids_num); + std::memcpy(&(*new_rows)[0], ids_data, ids_num * sizeof(int64_t)); + + auto *d_table_value = d_table->mutable_value(); + d_table_value->Resize({ids_num, table_dim[1]}); + T *d_table_data = d_table_value->mutable_data(context.GetPlace()); + const T *d_output_data = d_output->data(); + + auto blas = math::GetBlas(context); + for (int i = 0; i < static_cast(lod.size()) - 1; ++i) { + int64_t h = static_cast(lod[i + 1] - lod[i]); + int64_t in_offset = lod[i] * row_width; + const T *out_pos = d_output_data + i * row_width; + T *in_pos = d_table_data + in_offset; + for (int r = 0; r != h; ++r) { + blas.VCOPY(row_width, out_pos, in_pos + r * row_width); + } + } + } else { + LOG(ERROR) << "Dense is not supported in fused_embedding_seq_pool_op now"; + } + } +}; + +} // namespace operators +} // namespace paddle -- GitLab