// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include #include #include #include #include #include "paddle/fluid/operators/distributed/parameter_recv.h" #include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/scope.h" #include "paddle/fluid/framework/selected_rows.h" #include "paddle/fluid/framework/tensor.h" #include "paddle/fluid/operators/distributed/distributed.h" #include "paddle/fluid/operators/distributed/rpc_client.h" #include "paddle/fluid/operators/distributed/variable_response.h" #include "paddle/fluid/operators/distributed_ops/send_recv_util.h" #include "paddle/fluid/operators/strided_memcpy.h" #include "paddle/fluid/platform/profiler.h" namespace paddle { namespace operators { namespace distributed { using LoDTensor = framework::LoDTensor; using LoDTensor = framework::LoDTensor; using SelectedRows = framework::SelectedRows; using DDim = framework::DDim; template void RecvSelectedRows(const CommContext &rpc_ctx, const framework::Scope &scope) { platform::RecordEvent record_event_grpc("ParameterRecv::RecvSelectedRows", platform::EventRole::kInnerOp); platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance(); auto cpu_place = platform::CPUPlace(); auto &cpu_ctx = *pool.Get(cpu_place); distributed::RPCClient *rpc_client = distributed::RPCClient::GetInstance(rpc_ctx.trainer_id); std::unique_ptr local_scope = scope.NewTmpScope(); std::vector rets; for (size_t i = 0; i < rpc_ctx.splited_varnames.size(); i++) { auto &recv_var_name = rpc_ctx.splited_varnames[i]; local_scope->Var(recv_var_name); VLOG(4) << "recv " << recv_var_name << " from " << rpc_ctx.epmap[i]; // sparse param in recv_scope is LoDTensor rets.push_back(rpc_client->AsyncGetVar(rpc_ctx.epmap[i], cpu_ctx, *local_scope.get(), recv_var_name, recv_var_name, recv_var_name)); } for (size_t i = 0; i < rets.size(); i++) { PADDLE_ENFORCE_NE(rets[i]->Wait(), 0U, platform::errors::ExecutionTimeout( "internal error in RPCClient")); } int64_t height = 0; int64_t ids_num = 0; int64_t width = 0; std::vector all_ids; auto pserver_num = rpc_ctx.splited_varnames.size(); for (size_t i = 0; i < rpc_ctx.splited_varnames.size(); i++) { auto &recv_var_name = rpc_ctx.splited_varnames[i]; auto *recv_var = local_scope->FindVar(recv_var_name); auto &recv_t = recv_var->Get(); height += recv_t.height(); ids_num += recv_t.rows().size(); width = recv_t.value().dims()[1]; std::transform(recv_t.rows().begin(), recv_t.rows().end(), std::back_inserter(all_ids), [&](int64_t id) { return id * pserver_num + i; }); } auto *var = scope.FindVar(rpc_ctx.var_name); auto *t_ = var->GetMutable(); T *out_data = t_->mutable_value()->mutable_data({ids_num, width}, cpu_place); t_->set_height(height); t_->set_rows(all_ids); int64_t cnt = 0; for (size_t i = 0; i < rpc_ctx.splited_varnames.size(); i++) { auto &recv_var_name = rpc_ctx.splited_varnames[i]; auto *recv_var = local_scope->FindVar(recv_var_name); auto &recv_t = recv_var->Get(); auto rows = recv_t.rows().size(); const T *in_data = recv_t.value().data(); std::copy_n(in_data, rows * width, out_data + cnt); cnt += rows * width; } t_->SyncIndex(); } template void RecvLodTensor(const CommContext &rpc_ctx, const framework::Scope &scope) { platform::RecordEvent record_event_grpc("ParameterRecv::RecvLodTensor", platform::EventRole::kInnerOp); platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance(); auto cpu_place = platform::CPUPlace(); auto &cpu_ctx = *pool.Get(cpu_place); distributed::RPCClient *rpc_client = distributed::RPCClient::GetInstance(rpc_ctx.trainer_id); std::vector rets; // variable do not spilt if (rpc_ctx.origin_varnames.size() == 1 && rpc_ctx.splited_varnames.size() == 1) { auto varname = rpc_ctx.origin_varnames[0]; VLOG(4) << "recv " << varname << " from " << rpc_ctx.epmap[0]; rets.push_back(rpc_client->AsyncGetVarNoBarrier(rpc_ctx.epmap[0], cpu_ctx, scope, varname, varname)); for (size_t i = 0; i < rets.size(); i++) { PADDLE_ENFORCE_NE( rets[i]->Wait(), 0U, platform::errors::ExecutionTimeout("internal error in RPCClient")); } VLOG(3) << "ParameterRecv out " << rpc_ctx.var_name; return; } else { PADDLE_ENFORCE(false, platform::errors::Unimplemented( "ParameterRecv can not recv dense with multi " "parts now, add it soon.")); } } template void ParameterRecv::operator()(const CommContext &rpc_ctx, const framework::Scope &scope, bool barrier) { VLOG(3) << "ParameterRecv in " << rpc_ctx.var_name; PADDLE_ENFORCE_GE(rpc_ctx.origin_varnames.size(), 1, platform::errors::InvalidArgument( "origin_varnames.size() >= 1 is permitted")); if (rpc_ctx.is_sparse) { RecvSelectedRows(rpc_ctx, scope); } else { RecvLodTensor(rpc_ctx, scope); } VLOG(3) << "ParameterRecv out " << rpc_ctx.var_name; } template void ParameterRecv::operator()(const CommContext &rpc_ctx, const framework::Scope &scope) { this->operator()(rpc_ctx, scope, true); } template struct ParameterRecv; }; // namespace distributed }; // namespace operators }; // namespace paddle