diff --git a/paddle/fluid/framework/data_feed.cc b/paddle/fluid/framework/data_feed.cc index 5cc1b8a6e3f09798a7e41326680375675936d66e..14daf9448b3cb25f5d91b68616fa9f375b89d5c9 100644 --- a/paddle/fluid/framework/data_feed.cc +++ b/paddle/fluid/framework/data_feed.cc @@ -314,21 +314,21 @@ void InMemoryDataFeed::GlobalShuffle() { // todo get ins id // std::string ins_id = memory_data_[i].ins_id; // todo hash - int64_t random_num = fleet_ptr->local_random_engine()(); + int64_t random_num = fleet_ptr->LocalRandomEngine()(); int64_t node_id = random_num % trainer_num_; std::string str; SerializeIns((*memory_data_)[i], &str); send_str_vec[node_id] += str; if (i % fleet_send_batch_size_ == 0 && i != 0) { for (int j = 0; j < send_str_vec.size(); ++j) { - fleet_ptr->send_client2client_msg(0, j, send_str_vec[j]); + fleet_ptr->SendClientToClientMsg(0, j, send_str_vec[j]); send_str_vec[j] = ""; } } } for (int j = 0; j < send_str_vec.size(); ++j) { if (send_str_vec[j].length() != 0) { - fleet_ptr->send_client2client_msg(0, j, send_str_vec[j]); + fleet_ptr->SendClientToClientMsg(0, j, send_str_vec[j]); } } } diff --git a/paddle/fluid/framework/data_set.cc b/paddle/fluid/framework/data_set.cc index adeadf0cecf89f92c7434b89c140350d2341ae8c..28cfbed4f419c1bb2ef3ca49a987abf8f7e260db 100644 --- a/paddle/fluid/framework/data_set.cc +++ b/paddle/fluid/framework/data_set.cc @@ -117,8 +117,8 @@ void DatasetImpl::GlobalShuffle() { // if it is not InMemory, memory_data_ is empty std::random_shuffle(memory_data_.begin(), memory_data_.end()); auto fleet_ptr = FleetWrapper::GetInstance(); - VLOG(3) << "registe_client2client_msg_handler"; - fleet_ptr->registe_client2client_msg_handler(0, + VLOG(3) << "RegisterClientToClientMsgHandler"; + fleet_ptr->RegisterClientToClientMsgHandler(0, [this](int msg_type, int client_id, const std::string& msg) -> int { return this->ReceiveFromClient(msg_type, client_id, msg); }); diff --git a/paddle/fluid/framework/dist_multi_trainer.cc b/paddle/fluid/framework/dist_multi_trainer.cc index 1bc6dd08d76d74df5dbc262625d8e3455fd1025b..4f177574b63ce9515b9d065f85fe33b1b7b43386 100644 --- a/paddle/fluid/framework/dist_multi_trainer.cc +++ b/paddle/fluid/framework/dist_multi_trainer.cc @@ -25,6 +25,7 @@ namespace framework { void DistMultiTrainer::Initialize(const TrainerDesc& trainer_desc, Dataset* dataset) { thread_num_ = trainer_desc.thread_num(); + SetDataset(dataset); workers_.resize(thread_num_); dataset->CreateReaders(); @@ -55,6 +56,7 @@ void DistMultiTrainer::Finalize() { th.join(); } pull_dense_worker_->Stop(); + dataset_ptr_->DestroyReaders(); } } // end namespace framework diff --git a/paddle/fluid/framework/fleet/fleet_wrapper.cc b/paddle/fluid/framework/fleet/fleet_wrapper.cc index 2696259f5501d70598fb16f39e5fa55ea984ccb2..ac6ee6c0246afebaa746a21c2c4e09e4fd092736 100644 --- a/paddle/fluid/framework/fleet/fleet_wrapper.cc +++ b/paddle/fluid/framework/fleet/fleet_wrapper.cc @@ -292,21 +292,31 @@ void FleetWrapper::PushSparseVarsWithLabelAsync( #endif } -int FleetWrapper::registe_client2client_msg_handler( +int FleetWrapper::RegisterClientToClientMsgHandler( int msg_type, MsgHandlerFunc handler) { +#ifdef PADDLE_WITH_PSLIB pslib_ptr_->_worker_ptr->registe_client2client_msg_handler( msg_type, handler); +#else + VLOG(0) << "FleetWrapper::RegisterClientToClientMsgHandler" + << " does nothing when no pslib"; +#endif return 0; } -int FleetWrapper::send_client2client_msg( +int FleetWrapper::SendClientToClientMsg( int msg_type, int to_client_id, const std::string& msg) { +#ifdef PADDLE_WITH_PSLIB pslib_ptr_->_worker_ptr->send_client2client_msg( msg_type, to_client_id, msg); +#else + VLOG(0) << "FleetWrapper::SendClientToClientMsg" + << " does nothing when no pslib"; +#endif return 0; } -std::default_random_engine& FleetWrapper::local_random_engine() { +std::default_random_engine& FleetWrapper::LocalRandomEngine() { struct engine_wrapper_t { std::default_random_engine engine; engine_wrapper_t() { @@ -330,7 +340,7 @@ void FleetWrapper::Serialize(const T& t, std::string* str) { ar << t; *str = std::string(ar.buffer(), ar.length()); #else - VLOG(0) << "FleetWrapper::Serialize do nothing when no pslib"; + VLOG(0) << "FleetWrapper::Serialize does nothing when no pslib"; #endif } @@ -341,7 +351,7 @@ void FleetWrapper::Deserialize(T* t, const std::string& str) { ar.set_read_buffer(const_cast(str.c_str()), str.length(), nullptr); *t = ar.get(); #else - VLOG(0) << "FleetWrapper::Deserialize do nothing when no pslib"; + VLOG(0) << "FleetWrapper::Deserialize does nothing when no pslib"; #endif } diff --git a/paddle/fluid/framework/fleet/fleet_wrapper.h b/paddle/fluid/framework/fleet/fleet_wrapper.h index 0e2027fcf893d38603fcb4fc460395c149303007..a649679b0d829ed9692bf9b4b9fadec474a1b49c 100644 --- a/paddle/fluid/framework/fleet/fleet_wrapper.h +++ b/paddle/fluid/framework/fleet/fleet_wrapper.h @@ -115,11 +115,11 @@ class FleetWrapper { void GatherServers(const std::vector& host_sign_list, int node_num); typedef std::function MsgHandlerFunc; - int registe_client2client_msg_handler(int msg_type, MsgHandlerFunc handler); - int send_client2client_msg(int msg_type, - int to_client_id, - const std::string& msg); - std::default_random_engine& local_random_engine(); + int RegisterClientToClientMsgHandler(int msg_type, MsgHandlerFunc handler); + int SendClientToClientMsg(int msg_type, + int to_client_id, + const std::string& msg); + std::default_random_engine& LocalRandomEngine(); template void Serialize(const T& t, std::string* str); diff --git a/paddle/fluid/framework/multi_trainer.cc b/paddle/fluid/framework/multi_trainer.cc index c3b38fadedc08bd8e916a1cebe08fecee4409d28..a5edbe5fb3bc7519de3f85986f4825af5eed1418 100644 --- a/paddle/fluid/framework/multi_trainer.cc +++ b/paddle/fluid/framework/multi_trainer.cc @@ -24,6 +24,7 @@ namespace framework { void MultiTrainer::Initialize(const TrainerDesc& trainer_desc, Dataset* dataset) { thread_num_ = trainer_desc.thread_num(); + SetDataset(dataset); // get filelist from trainer_desc here workers_.resize(thread_num_); VLOG(3) << "worker thread num: " << thread_num_; @@ -65,7 +66,7 @@ void MultiTrainer::Finalize() { for (auto& th : threads_) { th.join(); } - // todo dataset->DestroyReaders(); + dataset_ptr_->DestroyReaders(); } } // end namespace framework diff --git a/paddle/fluid/framework/trainer.h b/paddle/fluid/framework/trainer.h index 1cdc207c389c932bd23bb7a1c6ce90c6cbabec22..e57e04068b387262e479ac2328f81969f4e6f7d9 100644 --- a/paddle/fluid/framework/trainer.h +++ b/paddle/fluid/framework/trainer.h @@ -41,6 +41,7 @@ class TrainerBase { // model memory are hosted in root_scope void SetScope(Scope* root_scope); void SetDebug(const bool debug) { debug_ = debug; } + void SetDataset(Dataset* dataset_ptr) { dataset_ptr_ = dataset_ptr; } virtual void Initialize(const TrainerDesc& trainer_desc, Dataset* data_set) = 0; virtual void InitTrainerEnv(const ProgramDesc& main_program, @@ -52,6 +53,7 @@ class TrainerBase { protected: Scope* root_scope_; bool debug_; + Dataset* dataset_ptr_; }; // general trainer for async execution