未验证 提交 97d30602 编写于 作者: T Thunderbrook 提交者: GitHub

[HeterPs]ps gpu dump (#36157)

* ps gpu dump

* remove log
上级 58c8f6b3
......@@ -454,7 +454,6 @@ class PSGPUWorker : public HogwildWorker {
virtual void Initialize(const TrainerDesc& desc);
virtual void TrainFiles();
virtual void TrainFilesWithProfiler();
virtual void SetNeedDump(bool need_dump_field);
virtual void SetChannelWriter(ChannelObject<std::string>* queue);
virtual void SetWorkerNum(int num) { worker_num_ = num; }
virtual void CacheProgram(const ProgramDesc& main_program) {
......@@ -467,7 +466,6 @@ class PSGPUWorker : public HogwildWorker {
protected:
void PushGradients();
void DumpParam();
void CopySparseTable();
void CopyDenseTable();
void CopyDenseVars();
......@@ -475,18 +473,12 @@ class PSGPUWorker : public HogwildWorker {
private:
int mpi_rank_;
std::mutex mutex_;
std::vector<std::string> send_var_list_;
int worker_num_;
ProgramDesc program_;
HeterObjectPool<HeterTask> object_pool_;
bool need_dump_param_;
std::vector<std::string> dump_param_;
bool need_to_push_dense_;
bool need_dump_field_;
bool dump_slot_;
bool need_to_push_sparse_;
std::vector<std::string> dump_fields_;
ChannelWriter<std::string> writer_;
DownpourWorkerParameter param_;
float scale_datanorm_;
// just save the value in param_ for easy access
......
......@@ -29,9 +29,12 @@ namespace framework {
void PSGPUTrainer::Initialize(const TrainerDesc& trainer_desc,
Dataset* dataset) {
dataset_ = dataset;
SetDataset(dataset);
thread_num_ = trainer_desc.thread_num();
param_ = trainer_desc.downpour_param();
ParseDumpConfig(trainer_desc);
mpi_rank_ = trainer_desc.mpi_rank();
mpi_size_ = trainer_desc.mpi_size();
for (int i = 0; i < param_.dense_table_size(); ++i) {
uint64_t table_id = static_cast<uint64_t>(param_.dense_table(i).table_id());
auto table = param_.dense_table(i);
......@@ -44,6 +47,8 @@ void PSGPUTrainer::Initialize(const TrainerDesc& trainer_desc,
int place_num = trainer_desc.worker_places_size();
const std::vector<paddle::framework::DataFeed*> readers =
dataset->GetReaders();
dump_file_num_ = trainer_desc.dump_file_num();
user_define_dump_filename_ = trainer_desc.user_define_dump_filename();
std::vector<int> dev_ids;
for (int i = 0; i < place_num; ++i) {
int num = trainer_desc.worker_places(i);
......@@ -64,6 +69,11 @@ void PSGPUTrainer::Initialize(const TrainerDesc& trainer_desc,
workers_[i] = DeviceWorkerFactory::CreateDeviceWorker(
trainer_desc.device_worker_name());
workers_[i]->SetDeviceIndex(i);
workers_[i]->SetNeedDumpField(need_dump_field_);
workers_[i]->SetNeedDumpParam(need_dump_param_);
workers_[i]->SetDumpFieldVector(dump_fields_);
workers_[i]->SetDumpParamVector(dump_param_);
workers_[i]->InitRandomDumpConfig(trainer_desc);
workers_[i]->SetDataFeed(readers[i]);
workers_[i]->Initialize(trainer_desc);
workers_[i]->SetWorkerNum(place_num);
......@@ -71,7 +81,14 @@ void PSGPUTrainer::Initialize(const TrainerDesc& trainer_desc,
return;
}
void PSGPUTrainer::DumpWork(int tid) {}
std::string PSGPUTrainer::GetDumpPath(int tid) {
if (user_define_dump_filename_ != "") {
return string::format_string("%s/part-%s-%05d", dump_fields_path_.c_str(),
user_define_dump_filename_.c_str(), tid);
}
return string::format_string("%s/part-%03d-%05d", dump_fields_path_.c_str(),
mpi_rank_, tid);
}
void PSGPUTrainer::RegisterHeterCallback() {
/*
......@@ -124,7 +141,28 @@ void PSGPUTrainer::InitTrainerEnv(const ProgramDesc& main_program,
return;
}
void PSGPUTrainer::InitDumpEnv() {
queue_ = paddle::framework::MakeChannel<std::string>();
for (size_t i = 0; i < places_.size(); ++i) {
workers_[i]->SetChannelWriter(queue_.get());
}
dump_thread_num_ = 1;
if (dump_file_num_ > mpi_size_) {
dump_thread_num_ = dump_file_num_ / mpi_size_;
if (dump_file_num_ % mpi_size_ > mpi_rank_) {
dump_thread_num_ += 1;
}
}
for (int i = 0; i < dump_thread_num_; i++) {
dump_thread_.push_back(
std::thread(std::bind(&TrainerBase::DumpWork, this, i)));
}
}
void PSGPUTrainer::InitOtherEnv(const ProgramDesc& main_program) {
if (need_dump_field_ || need_dump_param_) {
InitDumpEnv();
}
VLOG(3) << "init other env done.";
}
......@@ -204,6 +242,9 @@ void PSGPUTrainer::Finalize() {
}
}
MergeDenseParam();
if (need_dump_field_ || need_dump_param_) {
FinalizeDumpEnv();
}
root_scope_->DropKids();
}
} // namespace framework
......
......@@ -34,11 +34,6 @@ void PSGPUWorker::Initialize(const TrainerDesc& desc) {
dev_ctx_ = platform::DeviceContextPool::Instance().Get(place_);
mpi_rank_ = desc.mpi_rank();
trainer_desc_ = desc;
/*
for (int i = 0; i < trainer_desc_.xpu_recv_list_size(); ++i) {
send_var_list_.push_back(trainer_desc_.xpu_recv_list(i));
}
*/
for (int i = 0; i < param_.sparse_table_size(); ++i) {
uint64_t table_id =
static_cast<uint64_t>(param_.sparse_table(i).table_id());
......@@ -89,19 +84,7 @@ void PSGPUWorker::Initialize(const TrainerDesc& desc) {
no_cvm_ = desc.no_cvm();
scale_datanorm_ = desc.scale_datanorm();
dump_slot_ = desc.dump_slot();
dump_fields_.resize(desc.dump_fields_size());
for (int i = 0; i < desc.dump_fields_size(); ++i) {
dump_fields_[i] = desc.dump_fields(i);
}
adjust_ins_weight_config_ = desc.adjust_ins_weight_config();
need_dump_param_ = false;
dump_param_.resize(desc.dump_param_size());
for (int i = 0; i < desc.dump_param_size(); ++i) {
dump_param_[i] = desc.dump_param(i);
}
if (desc.dump_param_size() != 0) {
need_dump_param_ = true;
}
for (int i = 0; i < desc.check_nan_var_names_size(); ++i) {
check_nan_var_names_.push_back(desc.check_nan_var_names(i));
}
......@@ -134,12 +117,6 @@ void PSGPUWorker::SetChannelWriter(ChannelObject<std::string>* queue) {
writer_.Reset(queue);
}
void PSGPUWorker::SetNeedDump(bool need_dump_field) {
need_dump_field_ = need_dump_field;
}
void PSGPUWorker::DumpParam() {}
void PSGPUWorker::TrainFiles() {
platform::SetNumThreads(1);
platform::Timer timeline;
......@@ -150,6 +127,7 @@ void PSGPUWorker::TrainFiles() {
// how to accumulate fetched values here
device_reader_->Start();
int cur_batch;
int batch_cnt = 0;
while ((cur_batch = device_reader_->Next()) > 0) {
total_ins_num += cur_batch;
for (auto& op : ops_) {
......@@ -164,9 +142,19 @@ void PSGPUWorker::TrainFiles() {
op->Run(*thread_scope_, place_);
}
}
if (need_dump_field_) {
DumpField(*thread_scope_, dump_mode_, dump_interval_);
}
if (need_dump_param_ && thread_id_ == 0) {
DumpParam(*thread_scope_, batch_cnt);
}
PrintFetchVars();
thread_scope_->DropKids();
++batch_cnt;
}
if (need_dump_field_ || need_dump_param_) {
writer_.Flush();
}
timeline.Pause();
VLOG(1) << "GpuPs worker " << thread_id_ << " train cost "
......
......@@ -258,13 +258,12 @@ class PSGPUTrainer : public TrainerBase {
virtual void Run();
virtual void Finalize();
virtual void RegisterHeterCallback();
virtual void DumpWork(int tid);
virtual Scope* GetWorkerScope(int thread_id);
virtual void CacheProgram(const ProgramDesc& main_program) {
new (&program_) ProgramDesc(main_program);
}
virtual std::string GetDumpPath(int tid) { return ""; }
virtual void InitDumpEnv() {}
virtual std::string GetDumpPath(int tid);
virtual void InitDumpEnv() override;
virtual void MergeDenseParam();
template <typename T>
......@@ -286,6 +285,9 @@ class PSGPUTrainer : public TrainerBase {
std::vector<std::thread> threads_;
int use_ps_gpu_;
int thread_num_;
int mpi_rank_;
int mpi_size_;
int dump_file_num_;
};
#endif
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册