diff --git a/paddle/fluid/operators/fused_embedding_seq_pool_op.cc b/paddle/fluid/operators/fused_embedding_seq_pool_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..ea960782919f51653b9653791475f35827cc1576 --- /dev/null +++ b/paddle/fluid/operators/fused_embedding_seq_pool_op.cc @@ -0,0 +1,158 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/operators/fused_embedding_seq_pool_op.h" +#include "paddle/fluid/framework/var_type_inference.h" + +namespace paddle { +namespace operators { + +class LookupTableOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("W"), + "Input(W) of LookupTableOp should not be null."); + PADDLE_ENFORCE(ctx->HasInput("Ids"), + "Input(Ids) of LookupTableOp should not be null."); + PADDLE_ENFORCE(ctx->HasOutput("Out"), + "Output(Out) of LookupTableOp should not be null."); + + auto table_dims = ctx->GetInputDim("W"); + auto ids_dims = ctx->GetInputDim("Ids"); + int ids_rank = ids_dims.size(); + + PADDLE_ENFORCE_EQ(table_dims.size(), 2); + PADDLE_ENFORCE_EQ(ids_dims[ids_rank - 1], 1, + "The last dimension of the 'Ids' tensor must be 1."); + + auto output_dims = + framework::vectorize(framework::slice_ddim(ids_dims, 0, ids_rank - 1)); + output_dims.push_back(table_dims[1]); + ctx->SetOutputDim("Out", framework::make_ddim(output_dims)); + + if (ctx->GetOutputsVarType("Out")[0] == + framework::proto::VarType::LOD_TENSOR) { + ctx->ShareLoD("Ids", /*->*/ "Out"); + } + } + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + auto data_type = framework::GetDataTypeOfVar(ctx.InputVar("W")); + return framework::OpKernelType(data_type, ctx.device_context()); + } +}; + +class LookupTableOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddInput("W", + "(Tensor) The input represents embedding tensors, " + "which is a learnable parameter."); + AddInput("Ids", + "An input with type int32 or int64 " + "contains the ids to be looked up in W. " + "The last dimension size must be 1."); + AddOutput("Out", "The lookup results, which have the same type as W."); + AddAttr("is_sparse", + "(boolean, default false) " + "Sparse update.") + .SetDefault(false); + AddAttr("is_distributed", + "(boolean, default false) distributed lookup table.") + .SetDefault(false); + AddAttr("padding_idx", + "(int64, default -1) " + "If the value is -1, it makes no effect to lookup. " + "Otherwise the given value indicates padding the output " + "with zeros whenever lookup encounters it in Ids.") + .SetDefault(kNoPadding); + AddComment(R"DOC( +Lookup Table Operator. + +This operator is used to perform lookups on the parameter W, +then concatenated into a dense tensor. + +The input Ids can carry the LoD (Level of Details) information, +or not. And the output only shares the LoD information with input Ids. + +)DOC"); + } +}; + +class LookupTableOpGradDescMaker + : public framework::DefaultGradOpDescMaker { + using ::paddle::framework::DefaultGradOpDescMaker< + true>::DefaultGradOpDescMaker; + + protected: + virtual std::string GradOpType() const { return "lookup_table_grad"; } +}; + +class LookupTableOpGrad : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override { + auto table_dims = ctx->GetInputDim("W"); + ctx->SetOutputDim(framework::GradVarName("W"), table_dims); + } + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + auto data_type = framework::GetDataTypeOfVar(ctx.InputVar("W")); + return framework::OpKernelType(data_type, ctx.device_context()); + } +}; + +class LookupTableOpGradVarTypeInference : public framework::VarTypeInference { + public: + void operator()(const framework::OpDesc& op_desc, + framework::BlockDesc* block) const override { + auto out_var_name = op_desc.Output(framework::GradVarName("W")).front(); + auto attr = op_desc.GetAttr("is_sparse"); + bool is_sparse = boost::get(attr); + if (is_sparse) { + VLOG(3) << "lookup_table_grad op " << framework::GradVarName("W") + << " is set to SelectedRows"; + block->Var(out_var_name) + ->SetType(framework::proto::VarType::SELECTED_ROWS); + } else { + VLOG(3) << "lookup_table_grad op " << framework::GradVarName("W") + << " is set to LoDTensor"; + block->Var(out_var_name)->SetType(framework::proto::VarType::LOD_TENSOR); + } + block->Var(out_var_name)->SetDataType(block->Var("W")->GetDataType()); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OPERATOR(lookup_table, ops::LookupTableOp, + ops::LookupTableOpGradDescMaker, ops::LookupTableOpMaker); +REGISTER_OPERATOR(lookup_table_grad, ops::LookupTableOpGrad, + ops::LookupTableOpGradVarTypeInference); + +// REGISTER_OP_CPU_KERNEL(lookup_table, ops::LookupTableKernel, +// ops::LookupTableKernel); +// REGISTER_OP_CPU_KERNEL(lookup_table_grad, ops::LookupTableGradKernel, +// ops::LookupTableGradKernel); +REGISTER_OP_CPU_KERNEL(lookup_table, ops::LookupTableKernel); +REGISTER_OP_CPU_KERNEL(lookup_table_grad, ops::LookupTableGradKernel); diff --git a/paddle/fluid/operators/fused_embedding_seq_pool_op.h b/paddle/fluid/operators/fused_embedding_seq_pool_op.h new file mode 100644 index 0000000000000000000000000000000000000000..6dcf4f44a7109c181dd9cb964082731a49df1524 --- /dev/null +++ b/paddle/fluid/operators/fused_embedding_seq_pool_op.h @@ -0,0 +1,207 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include +#include + +#include "paddle/fluid/framework/eigen.h" +#include "paddle/fluid/framework/lod_tensor.h" +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/framework/selected_rows.h" +#include "paddle/fluid/operators/math/blas.h" + +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; +using LoDTensor = framework::LoDTensor; +using SelectedRows = framework::SelectedRows; +using DDim = framework::DDim; + +constexpr int64_t kNoPadding = -1; + +template +class LookupTableKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext &context) const override { + auto *ids_t = context.Input("Ids"); // int tensor + auto *output_t = context.Output("Out"); // float tensor + auto *table_var = context.InputVar("W"); + + int64_t padding_idx = context.Attr("padding_idx"); + int64_t *ids = const_cast(ids_t->data()); + int64_t ids_numel = ids_t->numel(); + + if (table_var->IsType()) { + auto *table_t = context.Input("W"); + int64_t row_number = table_t->dims()[0]; + int64_t row_width = table_t->dims()[1]; + + auto *table = table_t->data(); + auto *output = output_t->mutable_data(context.GetPlace()); + + for (int64_t i = 0; i < ids_numel; ++i) { + if (padding_idx != kNoPadding && ids[i] == padding_idx) { + memset(output + i * row_width, 0, row_width * sizeof(T)); + } else { + PADDLE_ENFORCE_LT(ids[i], row_number); + PADDLE_ENFORCE_GE(ids[i], 0, "ids %d", i); + memcpy(output + i * row_width, table + ids[i] * row_width, + row_width * sizeof(T)); + } + } + } else if (table_var->IsType()) { + const auto &table_t = table_var->Get(); + int64_t row_width = table_t.value().dims()[1]; + const auto *table = table_t.value().data(); + auto *output = output_t->mutable_data(context.GetPlace()); + + auto blas = math::GetBlas(context); + for (int64_t i = 0; i < ids_numel; ++i) { + if (padding_idx != kNoPadding && ids[i] == padding_idx) { + memset(output + i * row_width, 0, row_width * sizeof(T)); + } else { + PADDLE_ENFORCE_GE(ids[i], 0); + auto id_index = table_t.Index(ids[i]); + PADDLE_ENFORCE_GE(id_index, 0, "the input key should be exists."); + // memcpy(output + i * row_width, table + id_index * row_width, + // row_width * sizeof(T)); + blas.VCOPY(row_width, table + id_index * row_width, + output + i * row_width); + } + } + } + } +}; + +template +class LookupTableGradKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext &context) const override { + auto *table_var = context.InputVar("W"); + DDim table_dim; + if (table_var->IsType()) { + table_dim = context.Input("W")->dims(); + } else if (table_var->IsType()) { + auto *table_t = context.Input("W"); + table_dim = table_t->value().dims(); + } else { + PADDLE_THROW( + "The parameter W of a LookupTable " + "must be either LoDTensor or SelectedRows"); + } + + bool is_sparse = context.Attr("is_sparse"); + // Since paddings are not trainable and fixed in forward, the gradient of + // paddings makes no sense and we don't deal with it in backward. + if (is_sparse) { + // auto start = std::chrono::system_clock::now(); + auto *ids = context.Input("Ids"); + auto *d_output = context.Input(framework::GradVarName("Out")); + auto *d_table = context.Output(framework::GradVarName("W")); + + auto *ids_data = ids->data(); + int64_t ids_num = ids->numel(); + // auto end = std::chrono::system_clock::now(); + // std::chrono::duration diff = end - start; + + // auto copy_start = std::chrono::system_clock::now(); + std::vector new_rows; + new_rows.resize(ids_num); + std::memcpy(&new_rows[0], ids_data, ids_num * sizeof(int64_t)); + // for (int64_t i = 0; i < ids_num; i++) { + // new_rows.push_back(ids_data[i]); + // } + // auto copy_end = std::chrono::system_clock::now(); + // std::chrono::duration copy_diff = copy_end - copy_start; + // diff += copy_diff; + // LOG(ERROR) << "run emb_grad copy end, cost: " << copy_diff.count() << " + // " << ids_num; + + // copy_start = std::chrono::system_clock::now(); + d_table->set_rows(new_rows); + + auto *d_table_value = d_table->mutable_value(); + d_table_value->Resize({ids_num, table_dim[1]}); + d_table_value->ShareDataWith(*d_output); + // d_table_value->mutable_data(context.GetPlace()); + + // // copy_end = std::chrono::system_clock::now(); + // // copy_diff = copy_end - copy_start; + // // diff += copy_diff; + // // LOG(ERROR) << "run emb_grad resize table end, cost: " << + // // copy_diff.count() << " " << ids_num; + + // // copy_start = std::chrono::system_clock::now(); + // d_table->set_height(table_dim[0]); + + // auto *d_output_data = d_output->data(); + // auto *d_table_data = d_table_value->data(); + + // // copy_end = std::chrono::system_clock::now(); + // // copy_diff = copy_end - copy_start; + // // diff += copy_diff; + // // LOG(ERROR) << "run emb_grad set height end, cost: " << + // // copy_diff.count() << " " << ids_num; + + // auto d_output_dims = d_output->dims(); + // PADDLE_ENFORCE_EQ( + // d_table_value->dims(), + // framework::flatten_to_2d(d_output_dims, d_output_dims.size() - 1)); + // // copy_start = std::chrono::system_clock::now(); + // auto blas = math::GetBlas(context); + // blas.VCOPY(d_output->numel(), d_output_data, d_table_data); + // cblas_scopy(d_output->numel(), d_output_data, 1, d_table_data, 1); + // // for (int i = 0; i != d_output->numel(), ++i) { + // // *(d_table_data++) = *(d_output_data++); + // // } + // // memcpy(d_table_data, d_output_data, sizeof(T) * d_output->numel()); + // // copy_end = std::chrono::system_clock::now(); + // // copy_diff = copy_end - copy_start; + // // diff += copy_diff; + // // LOG(ERROR) << "run emb_grad core end, cost: " << copy_diff.count() + // << " + // // " << ids_num << " " << d_output->numel(); + + // // LOG(ERROR) << "run emb_grad end, cost: " << diff.count(); + } else { + auto *ids = context.Input("Ids"); + auto *d_output = context.Input(framework::GradVarName("Out")); + auto *d_table = context.Output(framework::GradVarName("W")); + + auto *ids_data = ids->data(); + + int N = table_dim[0]; + int D = table_dim[1]; + + auto *d_output_data = d_output->data(); + auto *d_table_data = d_table->mutable_data(context.GetPlace()); + + memset(d_table_data, 0, d_table->numel() * sizeof(T)); + + for (int64_t i = 0; i < ids->numel(); ++i) { + PADDLE_ENFORCE_LT(ids_data[i], N); + PADDLE_ENFORCE_GE(ids_data[i], 0); + for (int j = 0; j < D; ++j) { + d_table_data[ids_data[i] * D + j] += d_output_data[i * D + j]; + } + } + } + } +}; + +} // namespace operators +} // namespace paddle