提交 39449ba0 编写于 作者: X xujiaqi01 提交者: dongdaxiang

fix bug && add DestroyReaders in trainer

上级 3641a78b
......@@ -314,21 +314,21 @@ void InMemoryDataFeed<T>::GlobalShuffle() {
// todo get ins id
// std::string ins_id = memory_data_[i].ins_id;
// todo hash
int64_t random_num = fleet_ptr->local_random_engine()();
int64_t random_num = fleet_ptr->LocalRandomEngine()();
int64_t node_id = random_num % trainer_num_;
std::string str;
SerializeIns((*memory_data_)[i], &str);
send_str_vec[node_id] += str;
if (i % fleet_send_batch_size_ == 0 && i != 0) {
for (int j = 0; j < send_str_vec.size(); ++j) {
fleet_ptr->send_client2client_msg(0, j, send_str_vec[j]);
fleet_ptr->SendClientToClientMsg(0, j, send_str_vec[j]);
send_str_vec[j] = "";
}
}
}
for (int j = 0; j < send_str_vec.size(); ++j) {
if (send_str_vec[j].length() != 0) {
fleet_ptr->send_client2client_msg(0, j, send_str_vec[j]);
fleet_ptr->SendClientToClientMsg(0, j, send_str_vec[j]);
}
}
}
......
......@@ -117,8 +117,8 @@ void DatasetImpl<T>::GlobalShuffle() {
// if it is not InMemory, memory_data_ is empty
std::random_shuffle(memory_data_.begin(), memory_data_.end());
auto fleet_ptr = FleetWrapper::GetInstance();
VLOG(3) << "registe_client2client_msg_handler";
fleet_ptr->registe_client2client_msg_handler(0,
VLOG(3) << "RegisterClientToClientMsgHandler";
fleet_ptr->RegisterClientToClientMsgHandler(0,
[this](int msg_type, int client_id, const std::string& msg) -> int {
return this->ReceiveFromClient(msg_type, client_id, msg);
});
......
......@@ -25,6 +25,7 @@ namespace framework {
void DistMultiTrainer::Initialize(const TrainerDesc& trainer_desc,
Dataset* dataset) {
thread_num_ = trainer_desc.thread_num();
SetDataset(dataset);
workers_.resize(thread_num_);
dataset->CreateReaders();
......@@ -55,6 +56,7 @@ void DistMultiTrainer::Finalize() {
th.join();
}
pull_dense_worker_->Stop();
dataset_ptr_->DestroyReaders();
}
} // end namespace framework
......
......@@ -292,21 +292,31 @@ void FleetWrapper::PushSparseVarsWithLabelAsync(
#endif
}
int FleetWrapper::registe_client2client_msg_handler(
int FleetWrapper::RegisterClientToClientMsgHandler(
int msg_type, MsgHandlerFunc handler) {
#ifdef PADDLE_WITH_PSLIB
pslib_ptr_->_worker_ptr->registe_client2client_msg_handler(
msg_type, handler);
#else
VLOG(0) << "FleetWrapper::RegisterClientToClientMsgHandler"
<< " does nothing when no pslib";
#endif
return 0;
}
int FleetWrapper::send_client2client_msg(
int FleetWrapper::SendClientToClientMsg(
int msg_type, int to_client_id, const std::string& msg) {
#ifdef PADDLE_WITH_PSLIB
pslib_ptr_->_worker_ptr->send_client2client_msg(
msg_type, to_client_id, msg);
#else
VLOG(0) << "FleetWrapper::SendClientToClientMsg"
<< " does nothing when no pslib";
#endif
return 0;
}
std::default_random_engine& FleetWrapper::local_random_engine() {
std::default_random_engine& FleetWrapper::LocalRandomEngine() {
struct engine_wrapper_t {
std::default_random_engine engine;
engine_wrapper_t() {
......@@ -330,7 +340,7 @@ void FleetWrapper::Serialize(const T& t, std::string* str) {
ar << t;
*str = std::string(ar.buffer(), ar.length());
#else
VLOG(0) << "FleetWrapper::Serialize do nothing when no pslib";
VLOG(0) << "FleetWrapper::Serialize does nothing when no pslib";
#endif
}
......@@ -341,7 +351,7 @@ void FleetWrapper::Deserialize(T* t, const std::string& str) {
ar.set_read_buffer(const_cast<char*>(str.c_str()), str.length(), nullptr);
*t = ar.get<T>();
#else
VLOG(0) << "FleetWrapper::Deserialize do nothing when no pslib";
VLOG(0) << "FleetWrapper::Deserialize does nothing when no pslib";
#endif
}
......
......@@ -115,11 +115,11 @@ class FleetWrapper {
void GatherServers(const std::vector<uint64_t>& host_sign_list, int node_num);
typedef std::function<int32_t (int, int, const std::string&)> MsgHandlerFunc;
int registe_client2client_msg_handler(int msg_type, MsgHandlerFunc handler);
int send_client2client_msg(int msg_type,
int to_client_id,
const std::string& msg);
std::default_random_engine& local_random_engine();
int RegisterClientToClientMsgHandler(int msg_type, MsgHandlerFunc handler);
int SendClientToClientMsg(int msg_type,
int to_client_id,
const std::string& msg);
std::default_random_engine& LocalRandomEngine();
template<typename T>
void Serialize(const T& t, std::string* str);
......
......@@ -24,6 +24,7 @@ namespace framework {
void MultiTrainer::Initialize(const TrainerDesc& trainer_desc,
Dataset* dataset) {
thread_num_ = trainer_desc.thread_num();
SetDataset(dataset);
// get filelist from trainer_desc here
workers_.resize(thread_num_);
VLOG(3) << "worker thread num: " << thread_num_;
......@@ -65,7 +66,7 @@ void MultiTrainer::Finalize() {
for (auto& th : threads_) {
th.join();
}
// todo dataset->DestroyReaders();
dataset_ptr_->DestroyReaders();
}
} // end namespace framework
......
......@@ -41,6 +41,7 @@ class TrainerBase {
// model memory are hosted in root_scope
void SetScope(Scope* root_scope);
void SetDebug(const bool debug) { debug_ = debug; }
void SetDataset(Dataset* dataset_ptr) { dataset_ptr_ = dataset_ptr; }
virtual void Initialize(const TrainerDesc& trainer_desc,
Dataset* data_set) = 0;
virtual void InitTrainerEnv(const ProgramDesc& main_program,
......@@ -52,6 +53,7 @@ class TrainerBase {
protected:
Scope* root_scope_;
bool debug_;
Dataset* dataset_ptr_;
};
// general trainer for async execution
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册