From 361cb0e078d1942e06ffcb3586e68be11c465d29 Mon Sep 17 00:00:00 2001
From: Qiao Longfei
Date: Fri, 23 Nov 2018 10:53:35 +0800
Subject: [PATCH] lookup remote table can compile

---
 .../distributed_ops/lookup_remote_table_op.cc |  12 +-
 .../distributed_ops/lookup_remote_table_op.h  | 220 ++++++++++--------
 2 files changed, 133 insertions(+), 99 deletions(-)

diff --git a/paddle/fluid/operators/distributed_ops/lookup_remote_table_op.cc b/paddle/fluid/operators/distributed_ops/lookup_remote_table_op.cc
index 06e96a7f9..5d3a50a44 100644
--- a/paddle/fluid/operators/distributed_ops/lookup_remote_table_op.cc
+++ b/paddle/fluid/operators/distributed_ops/lookup_remote_table_op.cc
@@ -68,6 +68,15 @@ class LookupRemoteTableOpMaker : public framework::OpProtoAndCheckerMaker {
              "contains the ids to be looked up in W. "
              "The last dimension size must be 1.");
     AddOutput("Out", "The lookup results, which have the same type as W.");
+    AddAttr<std::vector<int64_t>>("height_sections",
+                                  "Height for each output SelectedRows.")
+        .SetDefault(std::vector<int64_t>({}));
+    AddAttr<int>("trainer_id", "trainer id from 0 ~ worker_num.").SetDefault(0);
+    AddAttr<std::vector<std::string>>(
+        "epmap",
+        "(string vector, default 127.0.0.1:6164)"
+        "Server endpoints in the order of input variables for mapping")
+        .SetDefault({"127.0.0.1:6164"});
     AddAttr<int64_t>("padding_idx",
                      "(int64, default -1) "
                      "If the value is -1, it makes no effect to lookup. "
@@ -98,7 +107,8 @@ or not. And the output only shares the LoD information with input Ids.
 
 namespace ops = paddle::operators;
 REGISTER_OPERATOR(lookup_remote_table, ops::LookupRemoteTableOp,
-                  ops::EmptyGradOpMaker, ops::LookupRemoteTableOpMaker);
+                  paddle::framework::EmptyGradOpMaker,
+                  ops::LookupRemoteTableOpMaker);
 REGISTER_OP_CPU_KERNEL(lookup_remote_table, ops::LookupRemoteTableKernel<float>,
                        ops::LookupRemoteTableKernel<double>);
diff --git a/paddle/fluid/operators/distributed_ops/lookup_remote_table_op.h b/paddle/fluid/operators/distributed_ops/lookup_remote_table_op.h
index 1a383f6d3..ddf57016d 100644
--- a/paddle/fluid/operators/distributed_ops/lookup_remote_table_op.h
+++ b/paddle/fluid/operators/distributed_ops/lookup_remote_table_op.h
@@ -12,26 +12,32 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
 */
+#pragma once
+
 #include <future>  // NOLINT
 #include <ostream>
 #include <set>
+#include <string>
 #include <unordered_map>
 #include <vector>
 
 #include "paddle/fluid/framework/data_type.h"
+#include "paddle/fluid/framework/eigen.h"
 #include "paddle/fluid/framework/lod_tensor.h"
+#include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/framework/selected_rows.h"
 #include "paddle/fluid/memory/memcpy.h"
 #include "paddle/fluid/operators/detail/macros.h"
 #include "paddle/fluid/operators/distributed_ops/send_recv_util.h"
+#include "paddle/fluid/operators/math/blas.h"
 
 namespace paddle {
 namespace operators {
-namespace distributed {
 
 inline size_t GetSectionIndex(int64_t id,
                               const std::vector<int64_t>& abs_sections) {
   for (size_t i = 1; i < abs_sections.size(); ++i) {
-    if (row < abs_sections[i]) {
+    if (id < abs_sections[i]) {
       return i - 1;
     }
   }
@@ -62,9 +68,10 @@ inline std::vector<std::vector<int64_t>> SplitIds(
   std::vector<std::vector<int64_t>> splited_ids;
   splited_ids.resize(height_section.size() + 1);
   for (auto& id : all_ids) {
-    auto section_index = GetSectionIndex(id);
+    auto section_index = GetSectionIndex(id, abs_sections);
     splited_ids[section_index].push_back(id - abs_sections[section_index]);
   }
+  return splited_ids;
 }
 
 inline void SplitIdsIntoMultipleVarsBySection(
@@ -82,7 +89,7 @@ inline void SplitIdsIntoMultipleVarsBySection(
     auto& ids = splited_ids[i];
     if (!ids.empty()) {
       auto* id_tensor_data = id_tensor->mutable_data<int64_t>(
-          framework::make_ddim({ids.size(), 1}), place);
+          framework::make_ddim({static_cast<int64_t>(ids.size()), 1}), place);
       memcpy(id_tensor_data, ids.data(), sizeof(int64_t) * ids.size());
     }
   }
@@ -93,8 +100,8 @@ inline void MergeMultipleVarsIntoOnBySection(
     const std::vector<std::string>& out_var_names,
     const std::vector<int64_t>& height_section,
     const std::vector<std::vector<int64_t>>& splited_ids,
-    framework::Scope* scope) {
-  PADDLE_ENFORCE_EQ(in_var_names.size(), height_section.size() + 1, "");
+    const framework::ExecutionContext& context, framework::Scope* scope) {
+  PADDLE_ENFORCE_EQ(out_var_names.size(), height_section.size() + 1, "");
 
   auto cpu_place = platform::CPUPlace();
 
@@ -106,15 +113,15 @@ inline void MergeMultipleVarsIntoOnBySection(
     id_to_offset[id_data[i]].push_back(i);
   }
 
-  auto& out_tensor = scope->Var(out_name)->Get<framework::LoDTensor>();
-  auto* out_tensor_data = out_tensor.mutable_data<float>();
+  auto* out_tensor = scope->Var(out_name)->GetMutable<framework::LoDTensor>();
+  auto* out_tensor_data = out_tensor->mutable_data<float>(context.GetPlace());
 
   for (size_t section_idx = 0; section_idx < out_var_names.size();
        ++section_idx) {
     auto& ids_in_this_section = splited_ids[section_idx];
     auto& prefetch_out_var =
         scope->Var(out_var_names[section_idx])->Get<framework::LoDTensor>();
-    const auto* out_var_data = prefetch_out_var.mutable_data<float>();
+    const auto* out_var_data = prefetch_out_var.data<float>();
     auto& dims = prefetch_out_var.dims();
 
     PADDLE_ENFORCE_EQ(dims.size(), 2, "");
@@ -129,63 +136,64 @@ inline void MergeMultipleVarsIntoOnBySection(
       for (auto& offset : offsets) {
         // should support GPU tensor
         memory::Copy(cpu_place, out_tensor_data + offset * row_numel, cpu_place,
-                     out_var_data + i * grad_row_numel,
-                     sizeof(T) * grad_row_numel);
+                     out_var_data + i * row_numel, sizeof(float) * row_numel);
       }
     }
   }
 }
 
-inline void prefetch(const std::string& table_name, const std::string& id_name,
-                     const std::string& out_name,
-                     const std::vector<std::string>& epmap,
-                     const std::vector<int64_t>& height_section,
-                     const framework::Scope& scope,
-                     const platform::Place& place) const {
-  auto local_scope = scope.NewScope();
-
-  platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
-  auto& ctx = *pool.Get(place);
-
-  distributed::RPCClient* rpc_client =
-      distributed::RPCClient::GetInstance<RPCCLIENT_T>(Attr<int>("trainer_id"));
-
-  std::vector<std::string> in_var_names;
-  std::vector<std::string> out_var_names;
-  for (size_t i = 0; i < epmap.size(); ++i) {
-    in_var_names.push_back(id_name + "@" + epmap[i]);
-    out_var_names.push_back(out_name + "@" + epmap[i]);
-  }
-
-  auto splited_ids = SplitIds(id_name, height_section, local_scope);
-  SplitIdsIntoMultipleVarsBySection(id_name, in_var_names, height_section,
-                                    splited_ids, local_scope);
-
-  // create output var in local scope
-  for (auto& name : out_var_names) {
-    local_scope.Var(name)->GetMutable<framework::LoDTensor>();
-  }
-
-  std::vector<distributed::VarHandlePtr> rets;
-  for (size_t i = 0; i < ins.size(); i++) {
-    if (NeedSend(local_scope, ins[i])) {
-      VLOG(30) << "sending " << ins[i] << " to " << epmap[i] << " to get "
-               << outs[i] << " back";
-      rets.push_back(rpc_client->AsyncPrefetchVar(
-          epmap[i], ctx, local_scope, in_var_names[i], out_var_names[i]));
-    } else {
-      VLOG(30) << "don't send no-initialied variable: " << out_var_names[i];
-    }
-  }
-  for (size_t i = 0; i < rets.size(); i++) {
-    PADDLE_ENFORCE(rets[i]->Wait(), "internal error in RPCClient");
-  }
-
-  MergeMultipleVarsIntoOnBySection(id_name, out_name, out_var_names,
-                                   height_section, plited_ids, scope)
-
-  scope.DeleteScope(local_scope);
-}
+// inline void prefetch(const std::string& table_name, const std::string&
+// id_name,
+//                      const std::string& out_name,
+//                      const std::vector<std::string>& epmap,
+//                      const std::vector<int64_t>& height_section,
+//                      const framework::Scope& scope,
+//                      const platform::Place& place) {
+//   auto& local_scope = scope.NewScope();
+//
+//   platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
+//   auto& ctx = *pool.Get(place);
+//
+//   distributed::RPCClient* rpc_client =
+//       distributed::RPCClient::GetInstance<RPCCLIENT_T>(Attr<int>("trainer_id"));
+//
+//   std::vector<std::string> in_var_names;
+//   std::vector<std::string> out_var_names;
+//   for (size_t i = 0; i < epmap.size(); ++i) {
+//     in_var_names.push_back(id_name + "@" + epmap[i]);
+//     out_var_names.push_back(out_name + "@" + epmap[i]);
+//   }
+//
+//   auto splited_ids = SplitIds(id_name, height_section, &local_scope);
+//   SplitIdsIntoMultipleVarsBySection(id_name, in_var_names, height_section,
+//                                     splited_ids, &local_scope);
+//
+//   // create output var in local scope
+//   for (auto& name : out_var_names) {
+//     local_scope.Var(name)->GetMutable<framework::LoDTensor>();
+//   }
+//
+//   std::vector<distributed::VarHandlePtr> rets;
+//   for (size_t i = 0; i < in_var_names.size(); i++) {
+//     if (NeedSend(local_scope, in_var_names[i])) {
+//       VLOG(30) << "sending " << in_var_names[i] << " to " << epmap[i] << " to
+//       get "
+//                << out_var_names[i] << " back";
+//       rets.push_back(rpc_client->AsyncPrefetchVar(
+//           epmap[i], ctx, local_scope, in_var_names[i], out_var_names[i]));
+//     } else {
+//       VLOG(30) << "don't send no-initialied variable: " << out_var_names[i];
+//     }
+//   }
+//   for (size_t i = 0; i < rets.size(); i++) {
+//     PADDLE_ENFORCE(rets[i]->Wait(), "internal error in RPCClient");
+//   }
+//
+//   MergeMultipleVarsIntoOnBySection(id_name, out_name, out_var_names,
+//                                    height_section, splited_ids, &local_scope);
+//
+//   scope.DeleteScope(&local_scope);
+//}
 
 using Tensor = framework::Tensor;
 using LoDTensor = framework::LoDTensor;
@@ -198,54 +206,70 @@ template <typename T>
 class LookupRemoteTableKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& context) const override {
-    auto* ids_t = context.Input<LoDTensor>("Ids");  // int tensor
+    std::string id_name = context.Inputs("Ids").front();
+    auto* ids_t = context.Input<LoDTensor>("Ids");  // int tensor
+
+    std::string out_name = context.Outputs("Out").front();
     auto* output_t =
context.Output("Out"); // float tensor + + std::string table_name = context.Inputs("W").front(); auto* table_var = context.InputVar("W"); int64_t padding_idx = context.Attr("padding_idx"); int64_t* ids = const_cast(ids_t->data()); int64_t ids_numel = ids_t->numel(); - if (table_var->IsType()) { - auto* table_t = context.Input("W"); - int64_t row_number = table_t->dims()[0]; - int64_t row_width = table_t->dims()[1]; - - auto* table = table_t->data(); - auto* output = output_t->mutable_data(context.GetPlace()); - - for (int64_t i = 0; i < ids_numel; ++i) { - if (padding_idx != kNoPadding && ids[i] == padding_idx) { - memset(output + i * row_width, 0, row_width * sizeof(T)); - } else { - PADDLE_ENFORCE_LT(ids[i], row_number); - PADDLE_ENFORCE_GE(ids[i], 0, "ids %d", i); - memcpy(output + i * row_width, table + ids[i] * row_width, - row_width * sizeof(T)); - } - } - } else if (table_var->IsType()) { - const auto& table_t = table_var->Get(); - int64_t row_width = table_t.value().dims()[1]; - const auto* table = table_t.value().data(); - auto* output = output_t->mutable_data(context.GetPlace()); - - auto blas = math::GetBlas(context); - for (int64_t i = 0; i < ids_numel; ++i) { - if (padding_idx != kNoPadding && ids[i] == padding_idx) { - memset(output + i * row_width, 0, row_width * sizeof(T)); - } else { - PADDLE_ENFORCE_GE(ids[i], 0); - auto id_index = table_t.Index(ids[i]); - PADDLE_ENFORCE_GE(id_index, 0, "the input key should be exists."); - blas.VCOPY(row_width, table + id_index * row_width, - output + i * row_width); - } + auto epmap = context.Attr>("epmap"); + auto height_sections = + context.Attr>("height_sections"); + + auto& local_scope = context.scope().NewScope(); + + platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); + auto& ctx = *pool.Get(context.GetPlace()); + + distributed::RPCClient* rpc_client = + distributed::RPCClient::GetInstance( + context.Attr("trainer_id")); + + std::vector in_var_names; + std::vector out_var_names; + for (size_t i = 0; i < epmap.size(); ++i) { + in_var_names.push_back(id_name + "@" + epmap[i]); + out_var_names.push_back(out_name + "@" + epmap[i]); + } + + auto splited_ids = SplitIds(id_name, height_sections, &local_scope); + SplitIdsIntoMultipleVarsBySection(id_name, in_var_names, height_sections, + splited_ids, &local_scope); + + // create output var in local scope + for (auto& name : out_var_names) { + local_scope.Var(name)->GetMutable(); + } + + std::vector rets; + for (size_t i = 0; i < in_var_names.size(); i++) { + if (NeedSend(local_scope, in_var_names[i])) { + VLOG(30) << "sending " << in_var_names[i] << " to " << epmap[i] + << " to get " << out_var_names[i] << " back"; + rets.push_back(rpc_client->AsyncPrefetchVar( + epmap[i], ctx, local_scope, in_var_names[i], out_var_names[i])); + } else { + VLOG(30) << "don't send no-initialied variable: " << out_var_names[i]; } } + for (size_t i = 0; i < rets.size(); i++) { + PADDLE_ENFORCE(rets[i]->Wait(), "internal error in RPCClient"); + } + + MergeMultipleVarsIntoOnBySection(id_name, out_name, out_var_names, + height_sections, splited_ids, context, + &local_scope); + + context.scope().DeleteScope(&local_scope); } }; -} // namespace distributed } // namespace operators } // namespace paddle -- GitLab