Unverified · Commit 97d30602, authored by T Thunderbrook, committed by GitHub

[HeterPs]ps gpu dump (#36157)

* ps gpu dump

* remove log
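
This change moves the data/parameter dump logic of the GPU parameter-server path onto the shared trainer/worker dump interface: PSGPUTrainer now reads the dump settings from the TrainerDesc (ParseDumpConfig, mpi_rank/mpi_size, dump_file_num, user_define_dump_filename), implements GetDumpPath and InitDumpEnv, and starts the dump threads; PSGPUWorker drops its private SetNeedDump/DumpParam stubs and instead calls the common DumpField/DumpParam helpers once per batch, flushing the channel writer when the reader is exhausted.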
Parent 58c8f6b3
@@ -454,7 +454,6 @@ class PSGPUWorker : public HogwildWorker {
   virtual void Initialize(const TrainerDesc& desc);
   virtual void TrainFiles();
   virtual void TrainFilesWithProfiler();
-  virtual void SetNeedDump(bool need_dump_field);
   virtual void SetChannelWriter(ChannelObject<std::string>* queue);
   virtual void SetWorkerNum(int num) { worker_num_ = num; }
   virtual void CacheProgram(const ProgramDesc& main_program) {
@@ -467,7 +466,6 @@ class PSGPUWorker : public HogwildWorker {
  protected:
   void PushGradients();
-  void DumpParam();
   void CopySparseTable();
   void CopyDenseTable();
   void CopyDenseVars();
@@ -475,18 +473,12 @@ class PSGPUWorker : public HogwildWorker {
  private:
   int mpi_rank_;
   std::mutex mutex_;
-  std::vector<std::string> send_var_list_;
   int worker_num_;
   ProgramDesc program_;
   HeterObjectPool<HeterTask> object_pool_;
-  bool need_dump_param_;
-  std::vector<std::string> dump_param_;
   bool need_to_push_dense_;
-  bool need_dump_field_;
   bool dump_slot_;
   bool need_to_push_sparse_;
-  std::vector<std::string> dump_fields_;
-  ChannelWriter<std::string> writer_;
   DownpourWorkerParameter param_;
   float scale_datanorm_;
   // just save the value in param_ for easy access
...
@@ -29,9 +29,12 @@ namespace framework {
 void PSGPUTrainer::Initialize(const TrainerDesc& trainer_desc,
                               Dataset* dataset) {
-  dataset_ = dataset;
+  SetDataset(dataset);
   thread_num_ = trainer_desc.thread_num();
   param_ = trainer_desc.downpour_param();
+  ParseDumpConfig(trainer_desc);
+  mpi_rank_ = trainer_desc.mpi_rank();
+  mpi_size_ = trainer_desc.mpi_size();
   for (int i = 0; i < param_.dense_table_size(); ++i) {
     uint64_t table_id = static_cast<uint64_t>(param_.dense_table(i).table_id());
     auto table = param_.dense_table(i);
@@ -44,6 +47,8 @@ void PSGPUTrainer::Initialize(const TrainerDesc& trainer_desc,
   int place_num = trainer_desc.worker_places_size();
   const std::vector<paddle::framework::DataFeed*> readers =
       dataset->GetReaders();
+  dump_file_num_ = trainer_desc.dump_file_num();
+  user_define_dump_filename_ = trainer_desc.user_define_dump_filename();
   std::vector<int> dev_ids;
   for (int i = 0; i < place_num; ++i) {
     int num = trainer_desc.worker_places(i);
@@ -64,6 +69,11 @@ void PSGPUTrainer::Initialize(const TrainerDesc& trainer_desc,
     workers_[i] = DeviceWorkerFactory::CreateDeviceWorker(
         trainer_desc.device_worker_name());
     workers_[i]->SetDeviceIndex(i);
+    workers_[i]->SetNeedDumpField(need_dump_field_);
+    workers_[i]->SetNeedDumpParam(need_dump_param_);
+    workers_[i]->SetDumpFieldVector(dump_fields_);
+    workers_[i]->SetDumpParamVector(dump_param_);
+    workers_[i]->InitRandomDumpConfig(trainer_desc);
     workers_[i]->SetDataFeed(readers[i]);
     workers_[i]->Initialize(trainer_desc);
     workers_[i]->SetWorkerNum(place_num);
@@ -71,7 +81,14 @@ void PSGPUTrainer::Initialize(const TrainerDesc& trainer_desc,
   return;
 }
-void PSGPUTrainer::DumpWork(int tid) {}
+std::string PSGPUTrainer::GetDumpPath(int tid) {
+  if (user_define_dump_filename_ != "") {
+    return string::format_string("%s/part-%s-%05d", dump_fields_path_.c_str(),
+                                 user_define_dump_filename_.c_str(), tid);
+  }
+  return string::format_string("%s/part-%03d-%05d", dump_fields_path_.c_str(),
+                               mpi_rank_, tid);
+}
 void PSGPUTrainer::RegisterHeterCallback() {
   /*
@@ -124,7 +141,28 @@ void PSGPUTrainer::InitTrainerEnv(const ProgramDesc& main_program,
   return;
 }
+void PSGPUTrainer::InitDumpEnv() {
+  queue_ = paddle::framework::MakeChannel<std::string>();
+  for (size_t i = 0; i < places_.size(); ++i) {
+    workers_[i]->SetChannelWriter(queue_.get());
+  }
+  dump_thread_num_ = 1;
+  if (dump_file_num_ > mpi_size_) {
+    dump_thread_num_ = dump_file_num_ / mpi_size_;
+    if (dump_file_num_ % mpi_size_ > mpi_rank_) {
+      dump_thread_num_ += 1;
+    }
+  }
+  for (int i = 0; i < dump_thread_num_; i++) {
+    dump_thread_.push_back(
+        std::thread(std::bind(&TrainerBase::DumpWork, this, i)));
+  }
+}
 void PSGPUTrainer::InitOtherEnv(const ProgramDesc& main_program) {
+  if (need_dump_field_ || need_dump_param_) {
+    InitDumpEnv();
+  }
   VLOG(3) << "init other env done.";
 }
@@ -204,6 +242,9 @@ void PSGPUTrainer::Finalize() {
     }
   }
   MergeDenseParam();
+  if (need_dump_field_ || need_dump_param_) {
+    FinalizeDumpEnv();
+  }
   root_scope_->DropKids();
 }
 }  // namespace framework
...
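To make the division of dump files concrete, here is a minimal standalone C++ sketch of the arithmetic in GetDumpPath and InitDumpEnv above. It uses plain printf instead of Paddle's string::format_string, and all concrete values (path, rank, sizes) are hypothetical:

```cpp
#include <cstdio>

int main() {
  // Hypothetical configuration; the real values come from the TrainerDesc.
  const char* dump_fields_path = "/tmp/dump";
  const int mpi_rank = 1, mpi_size = 4, dump_file_num = 18;

  // InitDumpEnv: each rank runs dump_file_num / mpi_size dump threads,
  // and the first dump_file_num % mpi_size ranks take one extra thread.
  int dump_thread_num = 1;
  if (dump_file_num > mpi_size) {
    dump_thread_num = dump_file_num / mpi_size;  // 18 / 4 = 4
    if (dump_file_num % mpi_size > mpi_rank) {   // 18 % 4 = 2 > rank 1
      dump_thread_num += 1;                      // so rank 1 runs 5 threads
    }
  }

  // GetDumpPath (no user_define_dump_filename set): each dump thread
  // writes one part-<rank>-<tid> file under dump_fields_path.
  for (int tid = 0; tid < dump_thread_num; ++tid) {
    std::printf("%s/part-%03d-%05d\n", dump_fields_path, mpi_rank, tid);
  }
  return 0;
}
```

With these numbers, ranks 0 and 1 run five dump threads each and ranks 2 and 3 run four, spreading the 18 requested files across the 4 ranks; rank 1 writes /tmp/dump/part-001-00000 through part-001-00004.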
@@ -34,11 +34,6 @@ void PSGPUWorker::Initialize(const TrainerDesc& desc) {
   dev_ctx_ = platform::DeviceContextPool::Instance().Get(place_);
   mpi_rank_ = desc.mpi_rank();
   trainer_desc_ = desc;
-  /*
-  for (int i = 0; i < trainer_desc_.xpu_recv_list_size(); ++i) {
-    send_var_list_.push_back(trainer_desc_.xpu_recv_list(i));
-  }
-  */
   for (int i = 0; i < param_.sparse_table_size(); ++i) {
     uint64_t table_id =
         static_cast<uint64_t>(param_.sparse_table(i).table_id());
@@ -89,19 +84,7 @@ void PSGPUWorker::Initialize(const TrainerDesc& desc) {
   no_cvm_ = desc.no_cvm();
   scale_datanorm_ = desc.scale_datanorm();
   dump_slot_ = desc.dump_slot();
-  dump_fields_.resize(desc.dump_fields_size());
-  for (int i = 0; i < desc.dump_fields_size(); ++i) {
-    dump_fields_[i] = desc.dump_fields(i);
-  }
   adjust_ins_weight_config_ = desc.adjust_ins_weight_config();
-  need_dump_param_ = false;
-  dump_param_.resize(desc.dump_param_size());
-  for (int i = 0; i < desc.dump_param_size(); ++i) {
-    dump_param_[i] = desc.dump_param(i);
-  }
-  if (desc.dump_param_size() != 0) {
-    need_dump_param_ = true;
-  }
   for (int i = 0; i < desc.check_nan_var_names_size(); ++i) {
     check_nan_var_names_.push_back(desc.check_nan_var_names(i));
   }
@@ -134,12 +117,6 @@ void PSGPUWorker::SetChannelWriter(ChannelObject<std::string>* queue) {
   writer_.Reset(queue);
 }
-void PSGPUWorker::SetNeedDump(bool need_dump_field) {
-  need_dump_field_ = need_dump_field;
-}
-void PSGPUWorker::DumpParam() {}
 void PSGPUWorker::TrainFiles() {
   platform::SetNumThreads(1);
   platform::Timer timeline;
@@ -150,6 +127,7 @@ void PSGPUWorker::TrainFiles() {
   // how to accumulate fetched values here
   device_reader_->Start();
   int cur_batch;
+  int batch_cnt = 0;
   while ((cur_batch = device_reader_->Next()) > 0) {
     total_ins_num += cur_batch;
     for (auto& op : ops_) {
@@ -164,9 +142,19 @@ void PSGPUWorker::TrainFiles() {
         op->Run(*thread_scope_, place_);
       }
     }
+    if (need_dump_field_) {
+      DumpField(*thread_scope_, dump_mode_, dump_interval_);
+    }
+    if (need_dump_param_ && thread_id_ == 0) {
+      DumpParam(*thread_scope_, batch_cnt);
+    }
     PrintFetchVars();
     thread_scope_->DropKids();
+    ++batch_cnt;
+  }
+  if (need_dump_field_ || need_dump_param_) {
+    writer_.Flush();
   }
   timeline.Pause();
   VLOG(1) << "GpuPs worker " << thread_id_ << " train cost "
...
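A note on the worker-side flow in the TrainFiles loop above: every worker dumps its configured fields once per batch via DumpField, only thread 0 dumps parameters (passing batch_cnt as the batch id to DumpParam), and writer_.Flush() runs once after the reader is exhausted so the dump threads started in InitDumpEnv can drain the remaining channel entries.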
@@ -258,13 +258,12 @@ class PSGPUTrainer : public TrainerBase {
   virtual void Run();
   virtual void Finalize();
   virtual void RegisterHeterCallback();
-  virtual void DumpWork(int tid);
   virtual Scope* GetWorkerScope(int thread_id);
   virtual void CacheProgram(const ProgramDesc& main_program) {
     new (&program_) ProgramDesc(main_program);
   }
-  virtual std::string GetDumpPath(int tid) { return ""; }
-  virtual void InitDumpEnv() {}
+  virtual std::string GetDumpPath(int tid);
+  virtual void InitDumpEnv() override;
   virtual void MergeDenseParam();
   template <typename T>
@@ -286,6 +285,9 @@ class PSGPUTrainer : public TrainerBase {
   std::vector<std::thread> threads_;
   int use_ps_gpu_;
   int thread_num_;
+  int mpi_rank_;
+  int mpi_size_;
+  int dump_file_num_;
 };
 #endif
...