提交 39449ba0 编写于 作者: X xujiaqi01 提交者: dongdaxiang

fix bug && add DestroyReaders in trainer

上级 3641a78b
...@@ -314,21 +314,21 @@ void InMemoryDataFeed<T>::GlobalShuffle() { ...@@ -314,21 +314,21 @@ void InMemoryDataFeed<T>::GlobalShuffle() {
// todo get ins id // todo get ins id
// std::string ins_id = memory_data_[i].ins_id; // std::string ins_id = memory_data_[i].ins_id;
// todo hash // todo hash
int64_t random_num = fleet_ptr->local_random_engine()(); int64_t random_num = fleet_ptr->LocalRandomEngine()();
int64_t node_id = random_num % trainer_num_; int64_t node_id = random_num % trainer_num_;
std::string str; std::string str;
SerializeIns((*memory_data_)[i], &str); SerializeIns((*memory_data_)[i], &str);
send_str_vec[node_id] += str; send_str_vec[node_id] += str;
if (i % fleet_send_batch_size_ == 0 && i != 0) { if (i % fleet_send_batch_size_ == 0 && i != 0) {
for (int j = 0; j < send_str_vec.size(); ++j) { for (int j = 0; j < send_str_vec.size(); ++j) {
fleet_ptr->send_client2client_msg(0, j, send_str_vec[j]); fleet_ptr->SendClientToClientMsg(0, j, send_str_vec[j]);
send_str_vec[j] = ""; send_str_vec[j] = "";
} }
} }
} }
for (int j = 0; j < send_str_vec.size(); ++j) { for (int j = 0; j < send_str_vec.size(); ++j) {
if (send_str_vec[j].length() != 0) { if (send_str_vec[j].length() != 0) {
fleet_ptr->send_client2client_msg(0, j, send_str_vec[j]); fleet_ptr->SendClientToClientMsg(0, j, send_str_vec[j]);
} }
} }
} }
......
...@@ -117,8 +117,8 @@ void DatasetImpl<T>::GlobalShuffle() { ...@@ -117,8 +117,8 @@ void DatasetImpl<T>::GlobalShuffle() {
// if it is not InMemory, memory_data_ is empty // if it is not InMemory, memory_data_ is empty
std::random_shuffle(memory_data_.begin(), memory_data_.end()); std::random_shuffle(memory_data_.begin(), memory_data_.end());
auto fleet_ptr = FleetWrapper::GetInstance(); auto fleet_ptr = FleetWrapper::GetInstance();
VLOG(3) << "registe_client2client_msg_handler"; VLOG(3) << "RegisterClientToClientMsgHandler";
fleet_ptr->registe_client2client_msg_handler(0, fleet_ptr->RegisterClientToClientMsgHandler(0,
[this](int msg_type, int client_id, const std::string& msg) -> int { [this](int msg_type, int client_id, const std::string& msg) -> int {
return this->ReceiveFromClient(msg_type, client_id, msg); return this->ReceiveFromClient(msg_type, client_id, msg);
}); });
......
...@@ -25,6 +25,7 @@ namespace framework { ...@@ -25,6 +25,7 @@ namespace framework {
void DistMultiTrainer::Initialize(const TrainerDesc& trainer_desc, void DistMultiTrainer::Initialize(const TrainerDesc& trainer_desc,
Dataset* dataset) { Dataset* dataset) {
thread_num_ = trainer_desc.thread_num(); thread_num_ = trainer_desc.thread_num();
SetDataset(dataset);
workers_.resize(thread_num_); workers_.resize(thread_num_);
dataset->CreateReaders(); dataset->CreateReaders();
...@@ -55,6 +56,7 @@ void DistMultiTrainer::Finalize() { ...@@ -55,6 +56,7 @@ void DistMultiTrainer::Finalize() {
th.join(); th.join();
} }
pull_dense_worker_->Stop(); pull_dense_worker_->Stop();
dataset_ptr_->DestroyReaders();
} }
} // end namespace framework } // end namespace framework
......
...@@ -292,21 +292,31 @@ void FleetWrapper::PushSparseVarsWithLabelAsync( ...@@ -292,21 +292,31 @@ void FleetWrapper::PushSparseVarsWithLabelAsync(
#endif #endif
} }
int FleetWrapper::registe_client2client_msg_handler( int FleetWrapper::RegisterClientToClientMsgHandler(
int msg_type, MsgHandlerFunc handler) { int msg_type, MsgHandlerFunc handler) {
#ifdef PADDLE_WITH_PSLIB
pslib_ptr_->_worker_ptr->registe_client2client_msg_handler( pslib_ptr_->_worker_ptr->registe_client2client_msg_handler(
msg_type, handler); msg_type, handler);
#else
VLOG(0) << "FleetWrapper::RegisterClientToClientMsgHandler"
<< " does nothing when no pslib";
#endif
return 0; return 0;
} }
int FleetWrapper::send_client2client_msg( int FleetWrapper::SendClientToClientMsg(
int msg_type, int to_client_id, const std::string& msg) { int msg_type, int to_client_id, const std::string& msg) {
#ifdef PADDLE_WITH_PSLIB
pslib_ptr_->_worker_ptr->send_client2client_msg( pslib_ptr_->_worker_ptr->send_client2client_msg(
msg_type, to_client_id, msg); msg_type, to_client_id, msg);
#else
VLOG(0) << "FleetWrapper::SendClientToClientMsg"
<< " does nothing when no pslib";
#endif
return 0; return 0;
} }
std::default_random_engine& FleetWrapper::local_random_engine() { std::default_random_engine& FleetWrapper::LocalRandomEngine() {
struct engine_wrapper_t { struct engine_wrapper_t {
std::default_random_engine engine; std::default_random_engine engine;
engine_wrapper_t() { engine_wrapper_t() {
...@@ -330,7 +340,7 @@ void FleetWrapper::Serialize(const T& t, std::string* str) { ...@@ -330,7 +340,7 @@ void FleetWrapper::Serialize(const T& t, std::string* str) {
ar << t; ar << t;
*str = std::string(ar.buffer(), ar.length()); *str = std::string(ar.buffer(), ar.length());
#else #else
VLOG(0) << "FleetWrapper::Serialize do nothing when no pslib"; VLOG(0) << "FleetWrapper::Serialize does nothing when no pslib";
#endif #endif
} }
...@@ -341,7 +351,7 @@ void FleetWrapper::Deserialize(T* t, const std::string& str) { ...@@ -341,7 +351,7 @@ void FleetWrapper::Deserialize(T* t, const std::string& str) {
ar.set_read_buffer(const_cast<char*>(str.c_str()), str.length(), nullptr); ar.set_read_buffer(const_cast<char*>(str.c_str()), str.length(), nullptr);
*t = ar.get<T>(); *t = ar.get<T>();
#else #else
VLOG(0) << "FleetWrapper::Deserialize do nothing when no pslib"; VLOG(0) << "FleetWrapper::Deserialize does nothing when no pslib";
#endif #endif
} }
......
...@@ -115,11 +115,11 @@ class FleetWrapper { ...@@ -115,11 +115,11 @@ class FleetWrapper {
void GatherServers(const std::vector<uint64_t>& host_sign_list, int node_num); void GatherServers(const std::vector<uint64_t>& host_sign_list, int node_num);
typedef std::function<int32_t (int, int, const std::string&)> MsgHandlerFunc; typedef std::function<int32_t (int, int, const std::string&)> MsgHandlerFunc;
int registe_client2client_msg_handler(int msg_type, MsgHandlerFunc handler); int RegisterClientToClientMsgHandler(int msg_type, MsgHandlerFunc handler);
int send_client2client_msg(int msg_type, int SendClientToClientMsg(int msg_type,
int to_client_id, int to_client_id,
const std::string& msg); const std::string& msg);
std::default_random_engine& local_random_engine(); std::default_random_engine& LocalRandomEngine();
template<typename T> template<typename T>
void Serialize(const T& t, std::string* str); void Serialize(const T& t, std::string* str);
......
...@@ -24,6 +24,7 @@ namespace framework { ...@@ -24,6 +24,7 @@ namespace framework {
void MultiTrainer::Initialize(const TrainerDesc& trainer_desc, void MultiTrainer::Initialize(const TrainerDesc& trainer_desc,
Dataset* dataset) { Dataset* dataset) {
thread_num_ = trainer_desc.thread_num(); thread_num_ = trainer_desc.thread_num();
SetDataset(dataset);
// get filelist from trainer_desc here // get filelist from trainer_desc here
workers_.resize(thread_num_); workers_.resize(thread_num_);
VLOG(3) << "worker thread num: " << thread_num_; VLOG(3) << "worker thread num: " << thread_num_;
...@@ -65,7 +66,7 @@ void MultiTrainer::Finalize() { ...@@ -65,7 +66,7 @@ void MultiTrainer::Finalize() {
for (auto& th : threads_) { for (auto& th : threads_) {
th.join(); th.join();
} }
// todo dataset->DestroyReaders(); dataset_ptr_->DestroyReaders();
} }
} // end namespace framework } // end namespace framework
......
...@@ -41,6 +41,7 @@ class TrainerBase { ...@@ -41,6 +41,7 @@ class TrainerBase {
// model memory are hosted in root_scope // model memory are hosted in root_scope
void SetScope(Scope* root_scope); void SetScope(Scope* root_scope);
void SetDebug(const bool debug) { debug_ = debug; } void SetDebug(const bool debug) { debug_ = debug; }
void SetDataset(Dataset* dataset_ptr) { dataset_ptr_ = dataset_ptr; }
virtual void Initialize(const TrainerDesc& trainer_desc, virtual void Initialize(const TrainerDesc& trainer_desc,
Dataset* data_set) = 0; Dataset* data_set) = 0;
virtual void InitTrainerEnv(const ProgramDesc& main_program, virtual void InitTrainerEnv(const ProgramDesc& main_program,
...@@ -52,6 +53,7 @@ class TrainerBase { ...@@ -52,6 +53,7 @@ class TrainerBase {
protected: protected:
Scope* root_scope_; Scope* root_scope_;
bool debug_; bool debug_;
Dataset* dataset_ptr_;
}; };
// general trainer for async execution // general trainer for async execution
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册