未验证 提交 23bbd912 编写于 作者: Z zmxdream 提交者: GitHub

config fleet optimize. test=develop (#39849)

上级 2ec943a7
...@@ -46,6 +46,48 @@ void PSGPUTrainer::Initialize(const TrainerDesc& trainer_desc, ...@@ -46,6 +46,48 @@ void PSGPUTrainer::Initialize(const TrainerDesc& trainer_desc,
dense_grad_names_[table_id][j] = table.dense_grad_name(j); dense_grad_names_[table_id][j] = table.dense_grad_name(j);
} }
} }
InitializeGPUServer(trainer_desc);
scale_datanorm_ = trainer_desc.scale_datanorm();
int place_num = trainer_desc.worker_places_size();
const std::vector<paddle::framework::DataFeed*> readers =
dataset->GetReaders();
dump_file_num_ = trainer_desc.dump_file_num();
user_define_dump_filename_ = trainer_desc.user_define_dump_filename();
std::vector<int> dev_ids;
for (int i = 0; i < place_num; ++i) {
int num = trainer_desc.worker_places(i);
platform::CUDAPlace place = platform::CUDAPlace(num);
places_.push_back(place);
dev_ids.push_back(num);
}
for (int i = 0; i < trainer_desc.downpour_param().stat_var_names_size();
i++) {
need_merge_var_names_.push_back(
trainer_desc.downpour_param().stat_var_names(i));
}
VLOG(3) << "going to initialize pull dense worker";
SetDebug(trainer_desc.debug());
trainer_desc_ = trainer_desc;
workers_.resize(place_num);
for (int i = 0; i < place_num; ++i) {
workers_[i] = DeviceWorkerFactory::CreateDeviceWorker(
trainer_desc.device_worker_name());
workers_[i]->SetDeviceIndex(i);
workers_[i]->SetNeedDumpField(need_dump_field_);
workers_[i]->SetNeedDumpParam(need_dump_param_);
workers_[i]->SetDumpFieldVector(dump_fields_);
workers_[i]->SetDumpParamVector(dump_param_);
workers_[i]->InitRandomDumpConfig(trainer_desc);
workers_[i]->SetDataFeed(readers[i]);
workers_[i]->SetPlace(places_[i]);
workers_[i]->SetReaderPlace(places_[i]);
workers_[i]->Initialize(trainer_desc);
workers_[i]->SetWorkerNum(place_num);
}
return;
}
void PSGPUTrainer::InitializeGPUServer(const TrainerDesc& trainer_desc) {
// add for hbmps optimizer config // add for hbmps optimizer config
auto fleet_desc_str = trainer_desc.fleet_desc(); auto fleet_desc_str = trainer_desc.fleet_desc();
google::protobuf::TextFormat::ParseFromString(fleet_desc_str, &_ps_param); google::protobuf::TextFormat::ParseFromString(fleet_desc_str, &_ps_param);
...@@ -203,45 +245,6 @@ void PSGPUTrainer::Initialize(const TrainerDesc& trainer_desc, ...@@ -203,45 +245,6 @@ void PSGPUTrainer::Initialize(const TrainerDesc& trainer_desc,
auto ps_gpu_wrapper = paddle::framework::PSGPUWrapper::GetInstance(); auto ps_gpu_wrapper = paddle::framework::PSGPUWrapper::GetInstance();
ps_gpu_wrapper->InitializeGPUServer(config); ps_gpu_wrapper->InitializeGPUServer(config);
scale_datanorm_ = trainer_desc.scale_datanorm();
int place_num = trainer_desc.worker_places_size();
const std::vector<paddle::framework::DataFeed*> readers =
dataset->GetReaders();
dump_file_num_ = trainer_desc.dump_file_num();
user_define_dump_filename_ = trainer_desc.user_define_dump_filename();
std::vector<int> dev_ids;
for (int i = 0; i < place_num; ++i) {
int num = trainer_desc.worker_places(i);
platform::CUDAPlace place = platform::CUDAPlace(num);
places_.push_back(place);
dev_ids.push_back(num);
}
for (int i = 0; i < trainer_desc.downpour_param().stat_var_names_size();
i++) {
need_merge_var_names_.push_back(
trainer_desc.downpour_param().stat_var_names(i));
}
VLOG(3) << "going to initialize pull dense worker";
SetDebug(trainer_desc.debug());
trainer_desc_ = trainer_desc;
workers_.resize(place_num);
for (int i = 0; i < place_num; ++i) {
workers_[i] = DeviceWorkerFactory::CreateDeviceWorker(
trainer_desc.device_worker_name());
workers_[i]->SetDeviceIndex(i);
workers_[i]->SetNeedDumpField(need_dump_field_);
workers_[i]->SetNeedDumpParam(need_dump_param_);
workers_[i]->SetDumpFieldVector(dump_fields_);
workers_[i]->SetDumpParamVector(dump_param_);
workers_[i]->InitRandomDumpConfig(trainer_desc);
workers_[i]->SetDataFeed(readers[i]);
workers_[i]->SetPlace(places_[i]);
workers_[i]->SetReaderPlace(places_[i]);
workers_[i]->Initialize(trainer_desc);
workers_[i]->SetWorkerNum(place_num);
}
return;
} }
std::string PSGPUTrainer::GetDumpPath(int tid) { std::string PSGPUTrainer::GetDumpPath(int tid) {
......
...@@ -271,6 +271,7 @@ class PSGPUTrainer : public TrainerBase { ...@@ -271,6 +271,7 @@ class PSGPUTrainer : public TrainerBase {
template <typename T> template <typename T>
void MergeToRootScope(LoDTensor* root_tensor, LoDTensor* thread_tensor); void MergeToRootScope(LoDTensor* root_tensor, LoDTensor* thread_tensor);
void InitializeGPUServer(const TrainerDesc& trainer_desc);
protected: protected:
Dataset* dataset_; Dataset* dataset_;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册