diff --git a/paddle/fluid/operators/distributed_ops/lookup_remote_table_op.cc b/paddle/fluid/operators/distributed_ops/lookup_remote_table_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..06e96a7f98303af06db8f93405ecac9dcd513c93 --- /dev/null +++ b/paddle/fluid/operators/distributed_ops/lookup_remote_table_op.cc @@ -0,0 +1,104 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/operators/distributed_ops/lookup_remote_table_op.h" +#include "paddle/fluid/framework/var_type_inference.h" + +namespace paddle { +namespace operators { + +class LookupRemoteTableOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("W"), + "Input(W) of LookupRemoteTableOp should not be null."); + PADDLE_ENFORCE(ctx->HasInput("Ids"), + "Input(Ids) of LookupRemoteTableOp should not be null."); + PADDLE_ENFORCE(ctx->HasOutput("Out"), + "Output(Out) of LookupRemoteTableOp should not be null."); + + auto table_dims = ctx->GetInputDim("W"); + auto ids_dims = ctx->GetInputDim("Ids"); + int ids_rank = ids_dims.size(); + + PADDLE_ENFORCE_EQ(table_dims.size(), 2); + PADDLE_ENFORCE_EQ(ids_dims[ids_rank - 1], 1, + "The last dimension of the 'Ids' tensor must be 1."); + + auto output_dims = + framework::vectorize(framework::slice_ddim(ids_dims, 0, ids_rank - 1)); + output_dims.push_back(table_dims[1]); + ctx->SetOutputDim("Out", framework::make_ddim(output_dims)); + + if (ctx->GetOutputsVarType("Out")[0] == + framework::proto::VarType::LOD_TENSOR) { + ctx->ShareLoD("Ids", /*->*/ "Out"); + } + } + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + auto data_type = framework::GetDataTypeOfVar(ctx.InputVar("W")); + return framework::OpKernelType(data_type, ctx.device_context()); + } +}; + +class LookupRemoteTableOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddInput("W", + "(Tensor) The input represents embedding tensors, " + "which is a learnable parameter."); + AddInput("Ids", + "An input with type int32 or int64 " + "contains the ids to be looked up in W. " + "The last dimension size must be 1."); + AddOutput("Out", "The lookup results, which have the same type as W."); + AddAttr("padding_idx", + "(int64, default -1) " + "If the value is -1, it makes no effect to lookup. " + "Otherwise the given value indicates padding the output " + "with zeros whenever lookup encounters it in Ids.") + .SetDefault(kNoPadding); + // NOTE(minqiyang): grad_inplace is an temporal attribute, + // please do NOT set this attribute in python layer. + AddAttr("grad_inplace", + "(boolean, default false) " + "If the grad op reuse the input's variable.") + .SetDefault(false); + AddComment(R"DOC( +Lookup Remote Table Operator. + +This operator is used to perform lookups on the parameter W, +then concatenated into a dense tensor. + +The input Ids can carry the LoD (Level of Details) information, +or not. And the output only shares the LoD information with input Ids. + +)DOC"); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OPERATOR(lookup_remote_table, ops::LookupRemoteTableOp, + ops::EmptyGradOpMaker, ops::LookupRemoteTableOpMaker); + +REGISTER_OP_CPU_KERNEL(lookup_remote_table, ops::LookupRemoteTableKernel, + ops::LookupRemoteTableKernel); diff --git a/paddle/fluid/operators/distributed_ops/lookup_remote_table.h b/paddle/fluid/operators/distributed_ops/lookup_remote_table_op.h similarity index 54% rename from paddle/fluid/operators/distributed_ops/lookup_remote_table.h rename to paddle/fluid/operators/distributed_ops/lookup_remote_table_op.h index 5b066c8196108bb482da7d372bbe2e39541a38b7..1a383f6d3e6f0123c0c83b8a1aa51d66bb6bad17 100644 --- a/paddle/fluid/operators/distributed_ops/lookup_remote_table.h +++ b/paddle/fluid/operators/distributed_ops/lookup_remote_table_op.h @@ -14,21 +14,22 @@ limitations under the License. */ #include // NOLINT #include -#include #include #include +#include #include "paddle/fluid/framework/data_type.h" #include "paddle/fluid/framework/lod_tensor.h" -#include "paddle/fluid/operators/detail/macros.h" #include "paddle/fluid/memory/memcpy.h" +#include "paddle/fluid/operators/detail/macros.h" #include "paddle/fluid/operators/distributed_ops/send_recv_util.h" namespace paddle { namespace operators { namespace distributed { -inline size_t GetSectionIndex(int64_t id, const std::vector& abs_sections) { +inline size_t GetSectionIndex(int64_t id, + const std::vector& abs_sections) { for (size_t i = 1; i < abs_sections.size(); ++i) { if (row < abs_sections[i]) { return i - 1; @@ -38,7 +39,7 @@ inline size_t GetSectionIndex(int64_t id, const std::vector& abs_sectio } inline std::vector ToAbsoluteSection( - const std::vector& height_sections) { + const std::vector& height_sections) { std::vector abs_sections; abs_sections.resize(height_sections.size()); abs_sections[0] = 0; @@ -49,9 +50,8 @@ inline std::vector ToAbsoluteSection( } inline std::vector> SplitIds( - const std::string& id_name, - const std::vector& height_section, - framework::Scope* scope) { + const std::string& id_name, const std::vector& height_section, + framework::Scope* scope) { auto& id_tensor = scope->Var(id_name)->Get(); auto* id_data = id_tensor.data(); std::set all_ids; @@ -68,32 +68,32 @@ inline std::vector> SplitIds( } inline void SplitIdsIntoMultipleVarsBySection( - const std::string& id_name, - const std::vector& in_var_names, - const std::vector& height_section, - const std::vector>& splited_ids, - framework::Scope* scope) { + const std::string& id_name, const std::vector& in_var_names, + const std::vector& height_section, + const std::vector>& splited_ids, + framework::Scope* scope) { PADDLE_ENFORCE_EQ(in_var_names.size(), height_section.size() + 1, ""); auto place = platform::CPUPlace(); for (size_t i = 0; i < in_var_names.size(); ++i) { - auto* id_tensor = scope->Var(in_var_names[i])->GetMutable(); + auto* id_tensor = + scope->Var(in_var_names[i])->GetMutable(); auto& ids = splited_ids[i]; if (!ids.empty()) { - auto* id_tensor_data = id_tensor->mutable_data(framework::make_ddim({ids.size(), 1}), place); + auto* id_tensor_data = id_tensor->mutable_data( + framework::make_ddim({ids.size(), 1}), place); memcpy(id_tensor_data, ids.data(), sizeof(int64_t) * ids.size()); } } } inline void MergeMultipleVarsIntoOnBySection( - const std::string& id_name, - const std::string& out_name, - const std::vector& out_var_names, - const std::vector& height_section, - const std::vector>& splited_ids, - framework::Scope* scope) { + const std::string& id_name, const std::string& out_name, + const std::vector& out_var_names, + const std::vector& height_section, + const std::vector>& splited_ids, + framework::Scope* scope) { PADDLE_ENFORCE_EQ(in_var_names.size(), height_section.size() + 1, ""); auto cpu_place = platform::CPUPlace(); @@ -109,9 +109,11 @@ inline void MergeMultipleVarsIntoOnBySection( auto& out_tensor = scope->Var(out_name)->Get(); auto* out_tensor_data = out_tensor.mutable_data(); - for (size_t section_idx = 0; section_idx < out_var_names.size(); ++section_idx) { + for (size_t section_idx = 0; section_idx < out_var_names.size(); + ++section_idx) { auto& ids_in_this_section = splited_ids[section_idx]; - auto& prefetch_out_var = scope->Var(out_var_names[section_idx])->Get(); + auto& prefetch_out_var = + scope->Var(out_var_names[section_idx])->Get(); const auto* out_var_data = prefetch_out_var.mutable_data(); auto& dims = prefetch_out_var.dims(); @@ -126,31 +128,27 @@ inline void MergeMultipleVarsIntoOnBySection( auto& offsets = id_to_offset[origin_id]; for (auto& offset : offsets) { // should support GPU tensor - memory::Copy(cpu_place, out_tensor_data + offset * row_numel, - cpu_place, out_var_data + i * grad_row_numel, + memory::Copy(cpu_place, out_tensor_data + offset * row_numel, cpu_place, + out_var_data + i * grad_row_numel, sizeof(T) * grad_row_numel); } } } } -inline void prefetch( - const std::string& table_name, - const std::string& id_name, - const std::string& out_name, - const std::vector& epmap, - const std::vector& height_section, - const framework::Scope& scope, - const platform::Place& place) const { - +inline void prefetch(const std::string& table_name, const std::string& id_name, + const std::string& out_name, + const std::vector& epmap, + const std::vector& height_section, + const framework::Scope& scope, + const platform::Place& place) const { auto local_scope = scope.NewScope(); platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); auto& ctx = *pool.Get(place); distributed::RPCClient* rpc_client = - distributed::RPCClient::GetInstance( - Attr("trainer_id")); + distributed::RPCClient::GetInstance(Attr("trainer_id")); std::vector in_var_names; std::vector out_var_names; @@ -160,7 +158,8 @@ inline void prefetch( } auto splited_ids = SplitIds(id_name, height_section, local_scope); - SplitIdsIntoMultipleVarsBySection(id_name, in_var_names, height_section, splited_ids, local_scope); + SplitIdsIntoMultipleVarsBySection(id_name, in_var_names, height_section, + splited_ids, local_scope); // create output var in local scope for (auto& name : out_var_names) { @@ -171,9 +170,9 @@ inline void prefetch( for (size_t i = 0; i < ins.size(); i++) { if (NeedSend(local_scope, ins[i])) { VLOG(30) << "sending " << ins[i] << " to " << epmap[i] << " to get " - << outs[i] << " back"; - rets.push_back(rpc_client->AsyncPrefetchVar(epmap[i], ctx, local_scope, - in_var_names[i], out_var_names[i])); + << outs[i] << " back"; + rets.push_back(rpc_client->AsyncPrefetchVar( + epmap[i], ctx, local_scope, in_var_names[i], out_var_names[i])); } else { VLOG(30) << "don't send no-initialied variable: " << out_var_names[i]; } @@ -182,11 +181,71 @@ inline void prefetch( PADDLE_ENFORCE(rets[i]->Wait(), "internal error in RPCClient"); } - MergeMultipleVarsIntoOnBySection(id_name, out_name, out_var_names, height_section, plited_ids, scope) + MergeMultipleVarsIntoOnBySection(id_name, out_name, out_var_names, + height_section, plited_ids, scope) - scope.DeleteScope(local_scope); + scope.DeleteScope(local_scope); } +using Tensor = framework::Tensor; +using LoDTensor = framework::LoDTensor; +using SelectedRows = framework::SelectedRows; +using DDim = framework::DDim; + +constexpr int64_t kNoPadding = -1; + +template +class LookupRemoteTableKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + auto* ids_t = context.Input("Ids"); // int tensor + auto* output_t = context.Output("Out"); // float tensor + auto* table_var = context.InputVar("W"); + + int64_t padding_idx = context.Attr("padding_idx"); + int64_t* ids = const_cast(ids_t->data()); + int64_t ids_numel = ids_t->numel(); + + if (table_var->IsType()) { + auto* table_t = context.Input("W"); + int64_t row_number = table_t->dims()[0]; + int64_t row_width = table_t->dims()[1]; + + auto* table = table_t->data(); + auto* output = output_t->mutable_data(context.GetPlace()); + + for (int64_t i = 0; i < ids_numel; ++i) { + if (padding_idx != kNoPadding && ids[i] == padding_idx) { + memset(output + i * row_width, 0, row_width * sizeof(T)); + } else { + PADDLE_ENFORCE_LT(ids[i], row_number); + PADDLE_ENFORCE_GE(ids[i], 0, "ids %d", i); + memcpy(output + i * row_width, table + ids[i] * row_width, + row_width * sizeof(T)); + } + } + } else if (table_var->IsType()) { + const auto& table_t = table_var->Get(); + int64_t row_width = table_t.value().dims()[1]; + const auto* table = table_t.value().data(); + auto* output = output_t->mutable_data(context.GetPlace()); + + auto blas = math::GetBlas(context); + for (int64_t i = 0; i < ids_numel; ++i) { + if (padding_idx != kNoPadding && ids[i] == padding_idx) { + memset(output + i * row_width, 0, row_width * sizeof(T)); + } else { + PADDLE_ENFORCE_GE(ids[i], 0); + auto id_index = table_t.Index(ids[i]); + PADDLE_ENFORCE_GE(id_index, 0, "the input key should be exists."); + blas.VCOPY(row_width, table + id_index * row_width, + output + i * row_width); + } + } + } + } +}; + } // namespace distributed } // namespace operators } // namespace paddle