diff --git a/paddle/fluid/framework/fleet/heter_context.h b/paddle/fluid/framework/fleet/heter_context.h index 78aced804c3da7b6a09989d625e9bfcc0b2f1c67..2ea3c10fd87beceb3b1a9ea95effa2d4f46480bd 100644 --- a/paddle/fluid/framework/fleet/heter_context.h +++ b/paddle/fluid/framework/fleet/heter_context.h @@ -30,11 +30,19 @@ namespace framework { class HeterContext { public: + ~HeterContext() { + for (size_t i = 0; i < mutex_.size(); ++i) { + delete mutex_[i]; + } + mutex_.clear(); + } Scope* scope_{nullptr}; std::vector> feature_keys_; std::vector> value_ptr_; - std::vector> feature_values_; - std::vector mutex_lock_; + std::vector> device_values_; + std::vector> device_keys_; + std::vector mutex_; + uint32_t shard_num_ = 37; uint64_t size() { uint64_t total_size = 0; @@ -45,19 +53,28 @@ class HeterContext { } void SetShardNum(uint32_t shard_num) { shard_num_ = shard_num; } uint32_t ShardNum() { return shard_num_; } - void init() { feature_keys_.resize(shard_num_); } + void init(int shard_num, int device_num) { + shard_num_ = shard_num; + feature_keys_.resize(shard_num_); + value_ptr_.resize(shard_num_); + + device_values_.resize(device_num); + device_keys_.resize(device_num); + mutex_.resize(device_num); + for (size_t i = 0; i < mutex_.size(); ++i) { + mutex_[i] = new std::mutex(); + } + } void batch_add_keys(const std::vector>& thread_keys) { assert(thread_keys.size() == feature_keys_.size()); for (uint32_t i = 0; i < shard_num_; i++) { int idx = 0; - // mutex_lock_[i]->lock(); idx = feature_keys_[i].size(); feature_keys_[i].resize(feature_keys_[i].size() + thread_keys[i].size()); for (uint64_t j = 0; j < thread_keys[i].size(); j++) { feature_keys_[i][idx + j] = thread_keys[i][j]; } - // mutex_lock_[i]->unlock(); } } void UniqueKeys() { diff --git a/paddle/fluid/framework/fleet/ps_gpu_wrapper.cc b/paddle/fluid/framework/fleet/ps_gpu_wrapper.cc index a3c90fa944fb2e69e62d440c13821f9aa543637e..67b24a3b037665d90dcb5d060f7bce4156bf515b 100644 --- a/paddle/fluid/framework/fleet/ps_gpu_wrapper.cc +++ b/paddle/fluid/framework/fleet/ps_gpu_wrapper.cc @@ -40,16 +40,22 @@ namespace framework { std::shared_ptr PSGPUWrapper::s_instance_ = NULL; bool PSGPUWrapper::is_initialized_ = false; -void PSGPUWrapper::BuildTask(uint64_t table_id, int feature_dim) { +void PSGPUWrapper::BuildTask(std::shared_ptr gpu_task, + uint64_t table_id, int feature_dim) { VLOG(3) << "PSGPUWrapper::BuildGPUPSTask begin"; platform::Timer timeline; timeline.Start(); + int device_num = heter_devices_.size(); MultiSlotDataset* dataset = dynamic_cast(dataset_); - std::shared_ptr gpu_task = gpu_task_pool_.Get(); + gpu_task->init(thread_keys_shard_num_, device_num); auto input_channel = dataset->GetInputChannel(); auto& local_keys = gpu_task->feature_keys_; - auto& local_values = gpu_task->feature_values_; auto& local_ptr = gpu_task->value_ptr_; + + auto& device_keys = gpu_task->device_keys_; + auto& device_vals = gpu_task->device_values_; + auto& device_mutex = gpu_task->mutex_; + std::vector threads; auto fleet_ptr = FleetWrapper::GetInstance(); @@ -91,12 +97,11 @@ void PSGPUWrapper::BuildTask(uint64_t table_id, int feature_dim) { t.join(); } timeline.Pause(); - VLOG(0) << "GpuPs build task cost " << timeline.ElapsedSec() << " seconds."; + VLOG(1) << "GpuPs build task cost " << timeline.ElapsedSec() << " seconds."; timeline.Start(); // merge thread_keys to shard_keys - gpu_task->init(); for (size_t i = 0; i < thread_keys_.size(); i++) { gpu_task->batch_add_keys(thread_keys_[i]); for (int j = 0; j < thread_keys_thread_num_; j++) { @@ -105,21 +110,20 @@ void PSGPUWrapper::BuildTask(uint64_t table_id, int feature_dim) { } timeline.Pause(); - VLOG(0) << "GpuPs task unique11111 cost " << timeline.ElapsedSec() + VLOG(1) << "GpuPs task unique11111 cost " << timeline.ElapsedSec() << " seconds."; - VLOG(0) << "FK1"; timeline.Start(); gpu_task->UniqueKeys(); timeline.Pause(); - VLOG(0) << "GpuPs task unique cost " << timeline.ElapsedSec() << " seconds."; + VLOG(1) << "GpuPs task unique cost " << timeline.ElapsedSec() << " seconds."; for (int i = 0; i < thread_keys_shard_num_; i++) { - local_values[i].resize(local_keys[i].size()); + VLOG(3) << "GpuPs shard: " << i << " key len: " << local_keys[i].size(); local_ptr[i].resize(local_keys[i].size()); } - auto ptl_func = [this, &local_keys, &local_values, &local_ptr, &table_id, + auto ptl_func = [this, &local_keys, &local_ptr, &table_id, &fleet_ptr](int i) { size_t key_size = local_keys[i].size(); auto tt = fleet_ptr->pslib_ptr_->_worker_ptr->pull_sparse_ptr( @@ -136,68 +140,102 @@ void PSGPUWrapper::BuildTask(uint64_t table_id, int feature_dim) { VLOG(3) << "FleetWrapper Pull sparse to local done with table size: " << local_keys[i].size(); } - for (size_t num = 0; num < local_ptr[i].size(); ++num) { - float* ptr_val = local_ptr[i][num]->data(); - FeatureValue& val = local_values[i][num]; - size_t dim = local_ptr[i][num]->size(); - - val.delta_score = ptr_val[1]; - val.show = ptr_val[2]; - val.clk = ptr_val[3]; - val.slot = ptr_val[6]; - val.lr = ptr_val[4]; - val.lr_g2sum = ptr_val[5]; - - if (dim > 7) { - val.mf_size = MF_DIM + 1; - for (int x = 0; x < val.mf_size; x++) { - val.mf[x] = ptr_val[x + 7]; - } - } else { - val.mf_size = 0; - for (int x = 0; x < MF_DIM + 1; x++) { - val.mf[x] = 0; + }; + for (size_t i = 0; i < threads.size(); i++) { + threads[i] = std::thread(ptl_func, i); + } + for (std::thread& t : threads) { + t.join(); + } + timeline.Pause(); + VLOG(1) << "GpuPs pull sparse cost " << timeline.ElapsedSec() << " seconds."; + + timeline.Start(); + auto build_func = [device_num, &local_keys, &local_ptr, &device_keys, + &device_vals, &device_mutex](int i) { + std::vector> task_keys(device_num); + std::vector> task_ptrs( + device_num); + + for (size_t j = 0; j < local_keys[i].size(); j++) { + int shard = local_keys[i][j] % device_num; + task_keys[shard].push_back(local_keys[i][j]); + task_ptrs[shard].push_back(local_ptr[i][j]); + } + + for (int dev = 0; dev < device_num; dev++) { + device_mutex[dev]->lock(); + + int len = task_keys[dev].size(); + int cur = device_keys[dev].size(); + device_keys[dev].resize(device_keys[dev].size() + len); + device_vals[dev].resize(device_vals[dev].size() + len); + + for (int j = 0; j < len; ++j) { + device_keys[dev][cur + j] = task_keys[dev][j]; + float* ptr_val = task_ptrs[dev][j]->data(); + FeatureValue& val = device_vals[dev][cur + j]; + size_t dim = task_ptrs[dev][j]->size(); + + val.delta_score = ptr_val[1]; + val.show = ptr_val[2]; + val.clk = ptr_val[3]; + val.slot = ptr_val[6]; + val.lr = ptr_val[4]; + val.lr_g2sum = ptr_val[5]; + + if (dim > 7) { + val.mf_size = MF_DIM + 1; + for (int x = 0; x < val.mf_size; x++) { + val.mf[x] = ptr_val[x + 7]; + } + } else { + val.mf_size = 0; + for (int x = 0; x < MF_DIM + 1; x++) { + val.mf[x] = 0; + } } } + + device_mutex[dev]->unlock(); } }; + for (size_t i = 0; i < threads.size(); i++) { - threads[i] = std::thread(ptl_func, i); + threads[i] = std::thread(build_func, i); } for (std::thread& t : threads) { t.join(); } timeline.Pause(); - VLOG(0) << "GpuPs pull sparse cost " << timeline.ElapsedSec() << " seconds."; + VLOG(1) << "GpuPs prepare for build hbm cost " << timeline.ElapsedSec() + << " seconds."; } void PSGPUWrapper::BuildGPUPS(uint64_t table_id, int feature_dim) { - BuildTask(table_id, feature_dim); + int device_num = heter_devices_.size(); + std::shared_ptr gpu_task = gpu_task_pool_.Get(); + BuildTask(gpu_task, table_id, feature_dim); platform::Timer timeline; timeline.Start(); - std::shared_ptr gpu_task = gpu_task_pool_.Get(); - int shard_num = gpu_task->feature_keys_.size(); - if (shard_num == 0) { - return; - } - std::vector feature_keys_count(shard_num); + std::vector feature_keys_count(device_num); size_t size_max = 0; - for (int i = 0; i < shard_num; i++) { - feature_keys_count[i] = gpu_task->feature_keys_[i].size(); + for (int i = 0; i < device_num; i++) { + feature_keys_count[i] = gpu_task->device_keys_[i].size(); size_max = std::max(size_max, feature_keys_count[i]); } if (HeterPs_) { HeterPs_->show_one_table(0); return; } - std::vector threads(shard_num); + std::vector threads(device_num); HeterPs_ = HeterPsBase::get_instance(size_max, resource_); auto build_func = [this, &gpu_task, &feature_keys_count](int i) { std::cout << "building table: " << i << std::endl; - this->HeterPs_->build_ps(i, gpu_task->feature_keys_[i].data(), - gpu_task->feature_values_[i].data(), - feature_keys_count[i], 10000, 2); + this->HeterPs_->build_ps(i, gpu_task->device_keys_[i].data(), + gpu_task->device_values_[i].data(), + feature_keys_count[i], 500000, 2); HeterPs_->show_one_table(i); }; for (size_t i = 0; i < threads.size(); i++) { @@ -207,7 +245,7 @@ void PSGPUWrapper::BuildGPUPS(uint64_t table_id, int feature_dim) { t.join(); } timeline.Pause(); - VLOG(0) << "GpuPs build table total costs: " << timeline.ElapsedSec() + VLOG(1) << "GpuPs build table total costs: " << timeline.ElapsedSec() << " s."; } diff --git a/paddle/fluid/framework/fleet/ps_gpu_wrapper.h b/paddle/fluid/framework/fleet/ps_gpu_wrapper.h index ed06000c3076916108fccdd32edf2e172d33e01e..631c8456c562976cdb8e6b8af86ed94d8855d829 100644 --- a/paddle/fluid/framework/fleet/ps_gpu_wrapper.h +++ b/paddle/fluid/framework/fleet/ps_gpu_wrapper.h @@ -76,7 +76,8 @@ class PSGPUWrapper { const int batch_size); void BuildGPUPS(const uint64_t table_id, int feature_dim); - void BuildTask(uint64_t table_id, int feature_dim); + void BuildTask(std::shared_ptr gpu_task, uint64_t table_id, + int feature_dim); void InitializeGPU(const std::vector& dev_ids) { if (s_instance_ != NULL) { VLOG(3) << "PSGPUWrapper Begin InitializeGPU"; diff --git a/paddle/fluid/framework/ps_gpu_trainer.cc b/paddle/fluid/framework/ps_gpu_trainer.cc index 4ac98e977d38036338f291aebad33d49e1b61f17..bca1843dd8f2366d8b6115c711bbbbc2e75c6fd9 100644 --- a/paddle/fluid/framework/ps_gpu_trainer.cc +++ b/paddle/fluid/framework/ps_gpu_trainer.cc @@ -74,8 +74,6 @@ void PSGPUTrainer::Initialize(const TrainerDesc& trainer_desc, workers_[i]->Initialize(trainer_desc); workers_[i]->SetWorkerNum(place_num); } - auto gpu_ps_wrapper = PSGPUWrapper::GetInstance(); - gpu_ps_wrapper->InitializeGPU(dev_ids); return; } diff --git a/paddle/fluid/pybind/ps_gpu_wrapper_py.cc b/paddle/fluid/pybind/ps_gpu_wrapper_py.cc index d2327495039bcbfb4ef414e0fa63417bcc4a4195..b8ecdfe9a56a3883888cebd5373750d1f60f256c 100644 --- a/paddle/fluid/pybind/ps_gpu_wrapper_py.cc +++ b/paddle/fluid/pybind/ps_gpu_wrapper_py.cc @@ -41,6 +41,10 @@ void BindPSGPUWrapper(py::module* m) { py::call_guard()) .def("init_GPU_server", &framework::PSGPUWrapper::InitializeGPUServer, py::call_guard()) + .def("set_dataset", &framework::PSGPUWrapper::SetDataset, + py::call_guard()) + .def("init_gpu_ps", &framework::PSGPUWrapper::InitializeGPU, + py::call_guard()) .def("build_gpu_ps", &framework::PSGPUWrapper::BuildGPUPS, py::call_guard()); } // end PSGPUWrapper