From ca8c4f3e9b48e8196c5b9d69db79d8bcd55f10f8 Mon Sep 17 00:00:00 2001 From: zhaocaibei123 <48509226+zhaocaibei123@users.noreply.github.com> Date: Wed, 17 Nov 2021 15:08:56 +0800 Subject: [PATCH] update dataset (#37194) --- paddle/fluid/distributed/fleet.cc | 11 ++++- .../fluid/distributed/service/communicator.cc | 13 ------ paddle/fluid/framework/data_set.cc | 41 +++++++++++++------ paddle/fluid/framework/multi_trainer.cc | 27 +++++++++++- .../distributed/fleet/runtime/the_one_ps.py | 4 ++ 5 files changed, 67 insertions(+), 29 deletions(-) diff --git a/paddle/fluid/distributed/fleet.cc b/paddle/fluid/distributed/fleet.cc index 4a3dfc3e485..871e503ca42 100644 --- a/paddle/fluid/distributed/fleet.cc +++ b/paddle/fluid/distributed/fleet.cc @@ -710,8 +710,15 @@ int FleetWrapper::RegisterClientToClientMsgHandler(int msg_type, MsgHandlerFunc handler) { VLOG(1) << "calling FleetWrapper::RegisterClientToClientMsgHandler"; auto* communicator = Communicator::GetInstance(); - return communicator->_worker_ptr->registe_client2client_msg_handler(msg_type, - handler); + // for unittest which does not call fleet.init_worker() first + if (communicator == nullptr) { + VLOG(0) << "FleetWrapper::RegisterClientToClientMsgHandler communicator is " + "null"; + return -1; + } else { + return communicator->_worker_ptr->registe_client2client_msg_handler( + msg_type, handler); + } } std::future FleetWrapper::SendClientToClientMsg( diff --git a/paddle/fluid/distributed/service/communicator.cc b/paddle/fluid/distributed/service/communicator.cc index f51ffbcf811..b4d2fa3e82e 100644 --- a/paddle/fluid/distributed/service/communicator.cc +++ b/paddle/fluid/distributed/service/communicator.cc @@ -368,20 +368,7 @@ void Communicator::InitParams(const RecvCtxMap &recv_varname_to_ctx) { VLOG(1) << "push dense param to table " << table_id << " from 0' trainer done"; } - BarrierWithTable(1); - } else { - BarrierWithTable(1); - for (auto &iter : recv_varname_to_ctx) { - auto &table_id = iter.first; - auto &varnames = iter.second; - RpcRecvDense(varnames, table_id, recv_scope_); - VLOG(1) << "pull dense param to table " << table_id - << " from 0' trainer done"; - } } - std::this_thread::sleep_for( - std::chrono::milliseconds(100 + trainer_id_ * 10)); - BarrierWithTable(1); return; } diff --git a/paddle/fluid/framework/data_set.cc b/paddle/fluid/framework/data_set.cc index 2a071665b26..ca5e27dac3a 100644 --- a/paddle/fluid/framework/data_set.cc +++ b/paddle/fluid/framework/data_set.cc @@ -19,6 +19,10 @@ #include "paddle/fluid/platform/monitor.h" #include "paddle/fluid/platform/timer.h" +#ifdef PADDLE_WITH_PSCORE +#include "paddle/fluid/distributed/fleet.h" +#endif + #if defined _WIN32 || defined __APPLE__ #else #define _LINUX @@ -208,13 +212,17 @@ void DatasetImpl::CreateChannel() { // if sent message between workers, should first call this function template void DatasetImpl::RegisterClientToClientMsgHandler() { - auto fleet_ptr = FleetWrapper::GetInstance(); - VLOG(3) << "RegisterClientToClientMsgHandler"; +#ifdef PADDLE_WITH_PSCORE + auto fleet_ptr = distributed::FleetWrapper::GetInstance(); +#else + auto fleet_ptr = framework::FleetWrapper::GetInstance(); +#endif + VLOG(1) << "RegisterClientToClientMsgHandler"; fleet_ptr->RegisterClientToClientMsgHandler( 0, [this](int msg_type, int client_id, const std::string& msg) -> int { return this->ReceiveFromClient(msg_type, client_id, msg); }); - VLOG(3) << "RegisterClientToClientMsgHandler done"; + VLOG(1) << "RegisterClientToClientMsgHandler done"; } static void compute_left_batch_num(const int ins_num, const int thread_num, std::vector>* offset, @@ -523,7 +531,7 @@ void DatasetImpl::LocalShuffle() { VLOG(3) << "DatasetImpl::LocalShuffle() end, no data to shuffle"; return; } - auto fleet_ptr = FleetWrapper::GetInstance(); + auto fleet_ptr = framework::FleetWrapper::GetInstance(); input_channel_->Close(); std::vector data; input_channel_->ReadAll(data); @@ -540,11 +548,14 @@ void DatasetImpl::LocalShuffle() { } void MultiSlotDataset::GlobalShuffle(int thread_num) { -#ifdef PADDLE_WITH_PSLIB VLOG(3) << "MultiSlotDataset::GlobalShuffle() begin"; platform::Timer timeline; timeline.Start(); - auto fleet_ptr = FleetWrapper::GetInstance(); +#ifdef PADDLE_WITH_PSCORE + auto fleet_ptr = distributed::FleetWrapper::GetInstance(); +#else + auto fleet_ptr = framework::FleetWrapper::GetInstance(); +#endif if (!input_channel_ || input_channel_->Size() == 0) { VLOG(3) << "MultiSlotDataset::GlobalShuffle() end, no data to shuffle"; @@ -576,7 +587,12 @@ void MultiSlotDataset::GlobalShuffle(int thread_num) { }; auto global_shuffle_func = [this, get_client_id]() { - auto fleet_ptr = FleetWrapper::GetInstance(); +#ifdef PADDLE_WITH_PSCORE + auto fleet_ptr = distributed::FleetWrapper::GetInstance(); +#else + auto fleet_ptr = framework::FleetWrapper::GetInstance(); +#endif + // auto fleet_ptr = framework::FleetWrapper::GetInstance(); std::vector data; while (this->input_channel_->Read(data)) { std::vector ars(this->trainer_num_); @@ -633,7 +649,6 @@ void MultiSlotDataset::GlobalShuffle(int thread_num) { timeline.Pause(); VLOG(3) << "DatasetImpl::GlobalShuffle() end, cost time=" << timeline.ElapsedSec() << " seconds"; -#endif } template @@ -936,7 +951,7 @@ int MultiSlotDataset::ReceiveFromClient(int msg_type, int client_id, } CHECK(ar.Cursor() == ar.Finish()); - auto fleet_ptr = FleetWrapper::GetInstance(); + auto fleet_ptr = framework::FleetWrapper::GetInstance(); // not use random because it doesn't perform well here. // to make sure each channel get data equally, we just put data to // channel one by one. @@ -976,7 +991,7 @@ void MultiSlotDataset::DynamicAdjustReadersNum(int thread_num) { void MultiSlotDataset::PostprocessInstance() { // divide pv instance, and merge to input_channel_ if (enable_pv_merge_) { - auto fleet_ptr = FleetWrapper::GetInstance(); + auto fleet_ptr = framework::FleetWrapper::GetInstance(); std::shuffle(input_records_.begin(), input_records_.end(), fleet_ptr->LocalRandomEngine()); input_channel_->Open(); @@ -1014,7 +1029,7 @@ void MultiSlotDataset::PreprocessInstance() { if (!enable_pv_merge_) { // means to use Record this->LocalShuffle(); } else { // means to use Pv - auto fleet_ptr = FleetWrapper::GetInstance(); + auto fleet_ptr = framework::FleetWrapper::GetInstance(); input_channel_->Close(); std::vector pv_data; input_channel_->ReadAll(input_records_); @@ -1073,7 +1088,7 @@ void MultiSlotDataset::GenerateLocalTablesUnlock(int table_id, int feadim, } CHECK(multi_output_channel_.size() != 0); // NOLINT - auto fleet_ptr_ = FleetWrapper::GetInstance(); + auto fleet_ptr_ = framework::FleetWrapper::GetInstance(); std::vector>>& local_map_tables = fleet_ptr_->GetLocalTable(); local_map_tables.resize(shard_num); @@ -1315,7 +1330,7 @@ void MultiSlotDataset::MergeByInsId() { LOG(WARNING) << "total drop ins num: " << drop_ins_num; results.shrink_to_fit(); - auto fleet_ptr = FleetWrapper::GetInstance(); + auto fleet_ptr = framework::FleetWrapper::GetInstance(); std::shuffle(results.begin(), results.end(), fleet_ptr->LocalRandomEngine()); channel_data->Open(); channel_data->Write(std::move(results)); diff --git a/paddle/fluid/framework/multi_trainer.cc b/paddle/fluid/framework/multi_trainer.cc index 2a022ea4bb9..45087036b5d 100644 --- a/paddle/fluid/framework/multi_trainer.cc +++ b/paddle/fluid/framework/multi_trainer.cc @@ -15,6 +15,7 @@ limitations under the License. */ #include #include "paddle/fluid/framework/device_worker_factory.h" #include "paddle/fluid/framework/trainer.h" +#include "paddle/fluid/platform/lodtensor_printer.h" #if defined PADDLE_WITH_PSCORE #include "paddle/fluid/distributed/service/communicator.h" @@ -153,7 +154,20 @@ void MultiTrainer::InitOtherEnv(const ProgramDesc& main_program) { if (need_dump_field_ || need_dump_param_) { InitDumpEnv(); } - VLOG(3) << "init other env done."; + +#ifdef PADDLE_WITH_PSCORE + // pull dense param first + auto communicator = paddle::distributed::Communicator::GetInstance(); + // for unittest which call train_from_dataset but does not call + // fleet.init_worker() first + if (communicator == nullptr) { + VLOG(0) << "MultiTrainer::InitOtherEnv Communicator is null!"; + } else { + auto& recv_ctx = communicator->GetRecvCtxMap(); + communicator->PullDense(recv_ctx); + VLOG(3) << "init other env done."; + } +#endif } Scope* MultiTrainer::GetWorkerScope(int thread_id) { @@ -253,6 +267,17 @@ void MultiTrainer::Finalize() { #ifdef PADDLE_WITH_HETERPS MergeDenseParam(); #endif + +#if defined PADDLE_WITH_PSCORE + auto communicator = paddle::distributed::Communicator::GetInstance(); + // for unittest which does not call fleet.init_worker() first + if (communicator == nullptr) { + VLOG(0) << "MultiTrainer::Finalize communicator is null!"; + } else { + communicator->_worker_ptr->flush(); + VLOG(1) << "MultiTrainer::Finalize ps client flush done"; + } +#endif root_scope_->DropKids(); } diff --git a/python/paddle/distributed/fleet/runtime/the_one_ps.py b/python/paddle/distributed/fleet/runtime/the_one_ps.py index a5f124b80b6..c229d82dd04 100644 --- a/python/paddle/distributed/fleet/runtime/the_one_ps.py +++ b/python/paddle/distributed/fleet/runtime/the_one_ps.py @@ -577,8 +577,12 @@ class TheOnePSRuntime(RuntimeBase): else: init_params = dense_map + import paddle.distributed.fleet as fleet if not is_test: self._communicator.init_params(init_params) + fleet.util.barrier() + self._communicator.pull_dense(init_params) + fleet.util.barrier() if not self._communicator.is_running(): self._communicator.start() -- GitLab