diff --git a/paddle/fluid/operators/distributed/CMakeLists.txt b/paddle/fluid/operators/distributed/CMakeLists.txt
index 03f47b594d9bc4a186aacd2eb457335f2a1bf752..231f4b3bc41e1bf3f0a7d52217d3234df0a68967 100644
--- a/paddle/fluid/operators/distributed/CMakeLists.txt
+++ b/paddle/fluid/operators/distributed/CMakeLists.txt
@@ -30,7 +30,7 @@ if(WITH_GRPC)
 else()
     set(BRPC_SRCS brpc/brpc_client.cc brpc/brpc_server.cc brpc/brpc_sendrecvop_utils.cc brpc/brpc_variable_response.cc brpc/brpc_rdma_pool.cc)
-    set_source_files_properties(${BRPC_SRCS} parameter_prefetch.cc parameter_send.cc rpc_server_test.cc brpc/brpc_serde_test.cc collective_server.cc collective_server_test.cc collective_client.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
+    set_source_files_properties(${BRPC_SRCS} parameter_prefetch.cc parameter_send.cc parameter_recv.cc rpc_server_test.cc brpc/brpc_serde_test.cc collective_server.cc collective_server_test.cc collective_client.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
 
     set(BRPC_DEPS brpc ssl crypto protobuf leveldb snappystream snappy zlib)
@@ -53,6 +53,7 @@ cc_test(rpc_server_test SRCS rpc_server_test.cc
 cc_test(varhandle_test SRCS varhandle_test.cc DEPS profiler scope)
 cc_library(parameter_prefetch SRCS parameter_prefetch.cc DEPS sendrecvop_rpc memory)
 cc_library(parameter_send SRCS parameter_send.cc DEPS sendrecvop_rpc memory)
+cc_library(parameter_recv SRCS parameter_recv.cc DEPS sendrecvop_rpc memory)
 if(WITH_GPU)
     cc_test(collective_server_test SRCS collective_server_test.cc DEPS sendrecvop_rpc executor ${RPC_DEPS}
diff --git a/paddle/fluid/operators/distributed/parameter_recv.cc b/paddle/fluid/operators/distributed/parameter_recv.cc
new file mode 100644
index 0000000000000000000000000000000000000000..e5b486d1218e4989de3e7f4030a1150e857f9a5d
--- /dev/null
+++ b/paddle/fluid/operators/distributed/parameter_recv.cc
@@ -0,0 +1,178 @@
+// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "paddle/fluid/operators/distributed/parameter_recv.h"
+
+#include "paddle/fluid/framework/lod_tensor.h"
+#include "paddle/fluid/framework/scope.h"
+#include "paddle/fluid/framework/selected_rows.h"
+#include "paddle/fluid/framework/tensor.h"
+
+#include "paddle/fluid/operators/distributed/distributed.h"
+#include "paddle/fluid/operators/distributed/rpc_client.h"
+#include "paddle/fluid/operators/distributed/variable_response.h"
+#include "paddle/fluid/operators/distributed_ops/send_recv_util.h"
+
+namespace paddle {
+namespace operators {
+namespace distributed {
+
+using LoDTensor = framework::LoDTensor;
+using SelectedRows = framework::SelectedRows;
+using DDim = framework::DDim;
+
+template <typename T>
+void ParameterRecv<T>::operator()(const std::string &var_name,
+                                  const std::vector<std::string> &send_varnames,
+                                  const std::vector<std::string> &epmap,
+                                  const std::vector<int64_t> &height_sections,
+                                  const framework::ExecutionContext &ctx,
+                                  const framework::Scope &scope, bool sync) {
+  framework::Scope *local_scope = scope.NewTmpScope();
+
+  platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
+  auto &cpu_ctx = *pool.Get(platform::CPUPlace());
+
+  distributed::RPCClient *rpc_client =
+      distributed::RPCClient::GetInstance<RPCCLIENT_T>(
+          ctx.Attr<int>("trainer_id"));
+
+  auto *send_var = scope.FindVar(var_name);
+  size_t out_num = send_varnames.size();
+  if (send_var->IsType<framework::LoDTensor>()) {
+    if (out_num > 1) {
+      auto &send_tensor = send_var->Get<framework::LoDTensor>();
+      auto &send_tensor_dims = send_tensor.dims();
+      std::vector<framework::DDim> outs_dims;
+      outs_dims.reserve(out_num);
+
+      // infer output shape
+      PADDLE_ENFORCE_EQ(height_sections.size(), out_num,
+                        "tensor split sections size "
+                        "should be equal to output size.");
+      for (size_t i = 0; i < out_num; ++i) {
+        auto dim = send_tensor_dims;
+        dim[0] = height_sections[i];
+        outs_dims.push_back(dim);
+      }
+
+      // create output var in local scope
+      size_t row_offset = 0;
+      for (size_t i = 0; i < out_num; ++i) {
+        framework::Tensor *out = local_scope->Var(send_varnames[i])
+                                     ->GetMutable<framework::LoDTensor>();
+        *out = send_tensor.Slice(row_offset, row_offset + outs_dims[i][0]);
+        row_offset += outs_dims[i][0];
+      }
+    }
+  } else if (send_var->IsType<framework::SelectedRows>()) {
+    auto &send_slr = send_var->Get<framework::SelectedRows>();
+    auto abs_sections = ToAbsoluteSection(height_sections);
+
+    auto &send_rows = send_slr.rows();
+    std::vector<std::vector<size_t>> outs_rows_idx;
+    std::vector<std::vector<size_t>> outs_dense_idx;
+
+    outs_rows_idx.resize(out_num);
+    outs_dense_idx.resize(out_num);
+
+    auto row_numel = send_slr.value().numel() / send_slr.value().dims()[0];
+    auto *src = send_slr.value().data<T>();
+
+    // create output var in local scope
+    std::vector<framework::SelectedRows *> outs;
+    for (auto &name : send_varnames) {
+      auto *out = local_scope->Var(name)->GetMutable<framework::SelectedRows>();
+      outs.push_back(out);
+    }
+
+    // split rows index into output sparse vars
+    for (size_t i = 0; i < send_rows.size(); ++i) {
+      int out_idx = FindOutIdx(send_rows[i], abs_sections);
+      outs_rows_idx[out_idx].push_back(send_rows[i]);
+      outs_dense_idx[out_idx].push_back(i);
+    }
+    auto place = ctx.GetPlace();
+
+    for (size_t i = 0; i < outs_rows_idx.size(); ++i) {
+      auto rows_idx = outs_rows_idx[i];
+      outs[i]->set_height(height_sections[i]);
+      auto dims = send_slr.GetCompleteDims();
+      dims[0] = rows_idx.size();
+      outs[i]->mutable_value()->mutable_data<T>(dims, send_slr.place());
+      outs[i]->mutable_rows()->clear();
+      if (rows_idx.size() > 0) {
+        for (auto idx : rows_idx) {
+          outs[i]->mutable_rows()->push_back(idx - abs_sections[i]);
+        }
+        auto dst = outs[i]->mutable_value()->mutable_data<T>(ctx.GetPlace());
+        for (size_t j = 0; j < rows_idx.size(); j++) {
+          if (platform::is_cpu_place(place)) {
+            memory::Copy(
+                platform::CPUPlace(), dst + j * row_numel, platform::CPUPlace(),
+                src + outs_dense_idx[i][j] * row_numel, sizeof(T) * row_numel);
+          } else {
+#ifdef PADDLE_WITH_CUDA
+            auto stream = ctx.cuda_device_context().stream();
+            memory::Copy(platform::CUDAPlace(), dst + j * row_numel,
+                         platform::CUDAPlace(),
+                         src + outs_dense_idx[i][j] * row_numel,
+                         sizeof(T) * row_numel, stream);
+#else
+            PADDLE_THROW("Paddle is not compiled with GPU");
+#endif
+          }
+        }
+      }
+      PADDLE_ENFORCE_EQ(rows_idx.size(), outs[i]->rows().size(),
+                        "rows should have the same size as tensor dim 0");
+    }
+
+  } else {
+    PADDLE_THROW("unsupported var type to send!");
+  }
+
+  std::vector<distributed::VarHandlePtr> rets;
+  for (size_t i = 0; i < send_varnames.size(); i++) {
+    auto &send_var_name = send_varnames[i];
+    auto &endpoint = epmap[i];
+    if (NeedSend(*local_scope, send_var_name)) {
+      VLOG(3) << "sending " << send_var_name << " to " << endpoint;
+      rets.push_back(rpc_client->AsyncSendVar(endpoint, cpu_ctx, *local_scope,
+                                              send_var_name));
+    } else {
+      VLOG(3) << "don't send non-initialized variable: " << send_varnames[i];
+    }
+  }
+
+  // note: only sync send is supported for now
+  if (true || sync) {
+    for (size_t i = 0; i < rets.size(); i++) {
+      PADDLE_ENFORCE(rets[i]->Wait(), "internal error in RPCClient");
+    }
+  }
+
+  delete local_scope;
+}
+
+template struct ParameterRecv<float>;
+
+}  // namespace distributed
+}  // namespace operators
+}  // namespace paddle
diff --git a/paddle/fluid/operators/distributed/parameter_recv.h b/paddle/fluid/operators/distributed/parameter_recv.h
new file mode 100644
index 0000000000000000000000000000000000000000..817115e2d1e0adeaf40623a3eb83e900899ec4ea
--- /dev/null
+++ b/paddle/fluid/operators/distributed/parameter_recv.h
@@ -0,0 +1,38 @@
+// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include <string>
+#include <vector>
+
+#include "paddle/fluid/framework/operator.h"
+
+namespace paddle {
+namespace operators {
+namespace distributed {
+
+template <typename T>
+struct ParameterRecv {
+  void operator()(const std::string &var_name,
+                  const std::vector<std::string> &send_varnames,
+                  const std::vector<std::string> &epmap,
+                  const std::vector<int64_t> &height_sections,
+                  const framework::ExecutionContext &context,
+                  const framework::Scope &scope, bool sync);
+};
+
+}  // namespace distributed
+}  // namespace operators
+}  // namespace paddle
diff --git a/paddle/fluid/operators/distributed_ops/CMakeLists.txt b/paddle/fluid/operators/distributed_ops/CMakeLists.txt
index 0eb30ce695a0364e4e4b759622a6516e5e80b885..3bcfc532e8665cf8226c85c2b10cf049feae748e 100644
--- a/paddle/fluid/operators/distributed_ops/CMakeLists.txt
+++ b/paddle/fluid/operators/distributed_ops/CMakeLists.txt
@@ -2,9 +2,9 @@ include(operators)
 
 set(DISTRIBUTE_DEPS "")
 if(WITH_GRPC)
-    set(DISTRIBUTE_DEPS sendrecvop_rpc parameter_send grpc++_unsecure grpc_unsecure gpr cares zlib protobuf node)
+    set(DISTRIBUTE_DEPS sendrecvop_rpc parameter_send parameter_recv grpc++_unsecure grpc_unsecure gpr cares zlib protobuf node)
 else()
-    set(DISTRIBUTE_DEPS sendrecvop_rpc parameter_send brpc leveldb snappystream snappy protobuf ssl crypto zlib node)
+    set(DISTRIBUTE_DEPS sendrecvop_rpc parameter_send parameter_recv brpc leveldb snappystream snappy protobuf ssl crypto zlib node)
     if(WITH_BRPC_RDMA)
         find_library(IBVERBS_LIBRARY NAMES ibverbs)
         ADD_LIBRARY(ibverbs SHARED IMPORTED GLOBAL)
diff --git a/paddle/fluid/operators/distributed_ops/recv_op.cc b/paddle/fluid/operators/distributed_ops/recv_op.cc
index 120c65f29699bf2745b09ea312d1de069c8173c5..5e004a7a3cbc592bf82231f30973e5b2eb71eb21 100644
--- a/paddle/fluid/operators/distributed_ops/recv_op.cc
+++ b/paddle/fluid/operators/distributed_ops/recv_op.cc
@@ -110,6 +110,11 @@ This operator can get variables from server side.
          "for example: we need var named 'moment_1@127.0.0.1:1001', "
          "and it real name on parameter server is 'moment_1'. ")
         .SetDefault({});
+    AddAttr<std::vector<std::string>>(
+        "recv_varnames",
+        "(vector<string>) "
+        "the split parameter varnames to be received from pserver")
+        .SetDefault(std::vector<std::string>{});
   }
 };
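
For context, a minimal sketch of how a kernel might drive the new functor. Only the ParameterRecv<T> signature (from parameter_recv.h) and the "epmap"/"recv_varnames" attributes (from recv_op.cc above) come from this diff; the wrapper name RunParameterRecv, the parameter name "param", and the hardcoded section sizes are illustrative assumptions, not part of the change:

    #include <string>
    #include <vector>

    #include "paddle/fluid/framework/operator.h"
    #include "paddle/fluid/operators/distributed/parameter_recv.h"

    namespace paddle {
    namespace operators {

    // Hypothetical helper: read the attributes added above and run the functor.
    void RunParameterRecv(const framework::ExecutionContext &ctx,
                          const framework::Scope &scope) {
      auto epmap = ctx.Attr<std::vector<std::string>>("epmap");
      auto recv_varnames = ctx.Attr<std::vector<std::string>>("recv_varnames");
      // Assumed split sizes for illustration; in practice these would come
      // from the height sections produced when the parameter was split.
      std::vector<int64_t> height_sections = {512, 512};

      distributed::ParameterRecv<float> recv_functor;
      // The functor currently waits on every RPC handle regardless of `sync`
      // (see the `if (true || sync)` block in parameter_recv.cc).
      recv_functor("param", recv_varnames, epmap, height_sections, ctx, scope,
                   /*sync=*/true);
    }

    }  // namespace operators
    }  // namespace paddle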