From 39449ba0b9bd8d235f6e353f924d50acebc00faf Mon Sep 17 00:00:00 2001 From: xujiaqi01 Date: Wed, 13 Mar 2019 18:28:26 +0800 Subject: [PATCH] fix bug && add DestroyReaders in trainer --- paddle/fluid/framework/data_feed.cc | 6 +++--- paddle/fluid/framework/data_set.cc | 4 ++-- paddle/fluid/framework/dist_multi_trainer.cc | 2 ++ paddle/fluid/framework/fleet/fleet_wrapper.cc | 20 ++++++++++++++----- paddle/fluid/framework/fleet/fleet_wrapper.h | 10 +++++----- paddle/fluid/framework/multi_trainer.cc | 3 ++- paddle/fluid/framework/trainer.h | 2 ++ 7 files changed, 31 insertions(+), 16 deletions(-) diff --git a/paddle/fluid/framework/data_feed.cc b/paddle/fluid/framework/data_feed.cc index 5cc1b8a6e..14daf9448 100644 --- a/paddle/fluid/framework/data_feed.cc +++ b/paddle/fluid/framework/data_feed.cc @@ -314,21 +314,21 @@ void InMemoryDataFeed::GlobalShuffle() { // todo get ins id // std::string ins_id = memory_data_[i].ins_id; // todo hash - int64_t random_num = fleet_ptr->local_random_engine()(); + int64_t random_num = fleet_ptr->LocalRandomEngine()(); int64_t node_id = random_num % trainer_num_; std::string str; SerializeIns((*memory_data_)[i], &str); send_str_vec[node_id] += str; if (i % fleet_send_batch_size_ == 0 && i != 0) { for (int j = 0; j < send_str_vec.size(); ++j) { - fleet_ptr->send_client2client_msg(0, j, send_str_vec[j]); + fleet_ptr->SendClientToClientMsg(0, j, send_str_vec[j]); send_str_vec[j] = ""; } } } for (int j = 0; j < send_str_vec.size(); ++j) { if (send_str_vec[j].length() != 0) { - fleet_ptr->send_client2client_msg(0, j, send_str_vec[j]); + fleet_ptr->SendClientToClientMsg(0, j, send_str_vec[j]); } } } diff --git a/paddle/fluid/framework/data_set.cc b/paddle/fluid/framework/data_set.cc index adeadf0ce..28cfbed4f 100644 --- a/paddle/fluid/framework/data_set.cc +++ b/paddle/fluid/framework/data_set.cc @@ -117,8 +117,8 @@ void DatasetImpl::GlobalShuffle() { // if it is not InMemory, memory_data_ is empty std::random_shuffle(memory_data_.begin(), memory_data_.end()); auto fleet_ptr = FleetWrapper::GetInstance(); - VLOG(3) << "registe_client2client_msg_handler"; - fleet_ptr->registe_client2client_msg_handler(0, + VLOG(3) << "RegisterClientToClientMsgHandler"; + fleet_ptr->RegisterClientToClientMsgHandler(0, [this](int msg_type, int client_id, const std::string& msg) -> int { return this->ReceiveFromClient(msg_type, client_id, msg); }); diff --git a/paddle/fluid/framework/dist_multi_trainer.cc b/paddle/fluid/framework/dist_multi_trainer.cc index 1bc6dd08d..4f177574b 100644 --- a/paddle/fluid/framework/dist_multi_trainer.cc +++ b/paddle/fluid/framework/dist_multi_trainer.cc @@ -25,6 +25,7 @@ namespace framework { void DistMultiTrainer::Initialize(const TrainerDesc& trainer_desc, Dataset* dataset) { thread_num_ = trainer_desc.thread_num(); + SetDataset(dataset); workers_.resize(thread_num_); dataset->CreateReaders(); @@ -55,6 +56,7 @@ void DistMultiTrainer::Finalize() { th.join(); } pull_dense_worker_->Stop(); + dataset_ptr_->DestroyReaders(); } } // end namespace framework diff --git a/paddle/fluid/framework/fleet/fleet_wrapper.cc b/paddle/fluid/framework/fleet/fleet_wrapper.cc index 2696259f5..ac6ee6c02 100644 --- a/paddle/fluid/framework/fleet/fleet_wrapper.cc +++ b/paddle/fluid/framework/fleet/fleet_wrapper.cc @@ -292,21 +292,31 @@ void FleetWrapper::PushSparseVarsWithLabelAsync( #endif } -int FleetWrapper::registe_client2client_msg_handler( +int FleetWrapper::RegisterClientToClientMsgHandler( int msg_type, MsgHandlerFunc handler) { +#ifdef PADDLE_WITH_PSLIB pslib_ptr_->_worker_ptr->registe_client2client_msg_handler( msg_type, handler); +#else + VLOG(0) << "FleetWrapper::RegisterClientToClientMsgHandler" + << " does nothing when no pslib"; +#endif return 0; } -int FleetWrapper::send_client2client_msg( +int FleetWrapper::SendClientToClientMsg( int msg_type, int to_client_id, const std::string& msg) { +#ifdef PADDLE_WITH_PSLIB pslib_ptr_->_worker_ptr->send_client2client_msg( msg_type, to_client_id, msg); +#else + VLOG(0) << "FleetWrapper::SendClientToClientMsg" + << " does nothing when no pslib"; +#endif return 0; } -std::default_random_engine& FleetWrapper::local_random_engine() { +std::default_random_engine& FleetWrapper::LocalRandomEngine() { struct engine_wrapper_t { std::default_random_engine engine; engine_wrapper_t() { @@ -330,7 +340,7 @@ void FleetWrapper::Serialize(const T& t, std::string* str) { ar << t; *str = std::string(ar.buffer(), ar.length()); #else - VLOG(0) << "FleetWrapper::Serialize do nothing when no pslib"; + VLOG(0) << "FleetWrapper::Serialize does nothing when no pslib"; #endif } @@ -341,7 +351,7 @@ void FleetWrapper::Deserialize(T* t, const std::string& str) { ar.set_read_buffer(const_cast(str.c_str()), str.length(), nullptr); *t = ar.get(); #else - VLOG(0) << "FleetWrapper::Deserialize do nothing when no pslib"; + VLOG(0) << "FleetWrapper::Deserialize does nothing when no pslib"; #endif } diff --git a/paddle/fluid/framework/fleet/fleet_wrapper.h b/paddle/fluid/framework/fleet/fleet_wrapper.h index 0e2027fcf..a649679b0 100644 --- a/paddle/fluid/framework/fleet/fleet_wrapper.h +++ b/paddle/fluid/framework/fleet/fleet_wrapper.h @@ -115,11 +115,11 @@ class FleetWrapper { void GatherServers(const std::vector& host_sign_list, int node_num); typedef std::function MsgHandlerFunc; - int registe_client2client_msg_handler(int msg_type, MsgHandlerFunc handler); - int send_client2client_msg(int msg_type, - int to_client_id, - const std::string& msg); - std::default_random_engine& local_random_engine(); + int RegisterClientToClientMsgHandler(int msg_type, MsgHandlerFunc handler); + int SendClientToClientMsg(int msg_type, + int to_client_id, + const std::string& msg); + std::default_random_engine& LocalRandomEngine(); template void Serialize(const T& t, std::string* str); diff --git a/paddle/fluid/framework/multi_trainer.cc b/paddle/fluid/framework/multi_trainer.cc index c3b38fade..a5edbe5fb 100644 --- a/paddle/fluid/framework/multi_trainer.cc +++ b/paddle/fluid/framework/multi_trainer.cc @@ -24,6 +24,7 @@ namespace framework { void MultiTrainer::Initialize(const TrainerDesc& trainer_desc, Dataset* dataset) { thread_num_ = trainer_desc.thread_num(); + SetDataset(dataset); // get filelist from trainer_desc here workers_.resize(thread_num_); VLOG(3) << "worker thread num: " << thread_num_; @@ -65,7 +66,7 @@ void MultiTrainer::Finalize() { for (auto& th : threads_) { th.join(); } - // todo dataset->DestroyReaders(); + dataset_ptr_->DestroyReaders(); } } // end namespace framework diff --git a/paddle/fluid/framework/trainer.h b/paddle/fluid/framework/trainer.h index 1cdc207c3..e57e04068 100644 --- a/paddle/fluid/framework/trainer.h +++ b/paddle/fluid/framework/trainer.h @@ -41,6 +41,7 @@ class TrainerBase { // model memory are hosted in root_scope void SetScope(Scope* root_scope); void SetDebug(const bool debug) { debug_ = debug; } + void SetDataset(Dataset* dataset_ptr) { dataset_ptr_ = dataset_ptr; } virtual void Initialize(const TrainerDesc& trainer_desc, Dataset* data_set) = 0; virtual void InitTrainerEnv(const ProgramDesc& main_program, @@ -52,6 +53,7 @@ class TrainerBase { protected: Scope* root_scope_; bool debug_; + Dataset* dataset_ptr_; }; // general trainer for async execution -- GitLab