From 6e0da01c61050bf9220de2feb4eb296ce1b9f3ba Mon Sep 17 00:00:00 2001 From: yaoxuefeng Date: Thu, 14 Jan 2021 14:18:21 +0800 Subject: [PATCH] Heter ps new (#30198) --- paddle/fluid/framework/fleet/heter_context.h | 36 +++ .../framework/fleet/heter_ps/optimizer.cuh | 6 +- .../fluid/framework/fleet/ps_gpu_wrapper.cc | 160 +++++++++++++- .../fluid/framework/fleet/ps_gpu_wrapper.cu | 37 ++++ paddle/fluid/framework/fleet/ps_gpu_wrapper.h | 75 ++++++- paddle/fluid/framework/ps_gpu_trainer.cc | 208 ------------------ paddle/fluid/framework/trainer.h | 8 +- paddle/fluid/pybind/ps_gpu_wrapper_py.cc | 5 + 8 files changed, 305 insertions(+), 230 deletions(-) diff --git a/paddle/fluid/framework/fleet/heter_context.h b/paddle/fluid/framework/fleet/heter_context.h index 3fad689c17d..78aced804c3 100644 --- a/paddle/fluid/framework/fleet/heter_context.h +++ b/paddle/fluid/framework/fleet/heter_context.h @@ -16,6 +16,7 @@ limitations under the License. */ #if (defined PADDLE_WITH_NCCL) && (defined PADDLE_WITH_PSLIB) +#include #include #include #include @@ -33,6 +34,8 @@ class HeterContext { std::vector> feature_keys_; std::vector> value_ptr_; std::vector> feature_values_; + std::vector mutex_lock_; + uint32_t shard_num_ = 37; uint64_t size() { uint64_t total_size = 0; for (auto& keys : feature_keys_) { @@ -40,6 +43,39 @@ class HeterContext { } return total_size; } + void SetShardNum(uint32_t shard_num) { shard_num_ = shard_num; } + uint32_t ShardNum() { return shard_num_; } + void init() { feature_keys_.resize(shard_num_); } + void batch_add_keys(const std::vector>& thread_keys) { + assert(thread_keys.size() == feature_keys_.size()); + + for (uint32_t i = 0; i < shard_num_; i++) { + int idx = 0; + // mutex_lock_[i]->lock(); + idx = feature_keys_[i].size(); + feature_keys_[i].resize(feature_keys_[i].size() + thread_keys[i].size()); + for (uint64_t j = 0; j < thread_keys[i].size(); j++) { + feature_keys_[i][idx + j] = thread_keys[i][j]; + } + // mutex_lock_[i]->unlock(); + } + } + void UniqueKeys() { + std::vector threads; + auto unique_func = [this](int i) { + auto& cur_keys = feature_keys_[i]; + std::sort(cur_keys.begin(), cur_keys.end()); + std::vector::iterator it; + it = std::unique(cur_keys.begin(), cur_keys.end()); + cur_keys.resize(std::distance(cur_keys.begin(), it)); + }; + for (uint32_t i = 0; i < shard_num_; i++) { + threads.push_back(std::thread(unique_func, i)); + } + for (std::thread& t : threads) { + t.join(); + } + } }; } // end namespace framework diff --git a/paddle/fluid/framework/fleet/heter_ps/optimizer.cuh b/paddle/fluid/framework/fleet/heter_ps/optimizer.cuh index 7263f610fcb..e8e027f383f 100644 --- a/paddle/fluid/framework/fleet/heter_ps/optimizer.cuh +++ b/paddle/fluid/framework/fleet/heter_ps/optimizer.cuh @@ -14,6 +14,7 @@ limitations under the License. */ #pragma once #include +#include #include "optimizer_conf.h" #include "paddle/fluid/framework/fleet/heter_ps/feature_value.h" @@ -106,8 +107,11 @@ class Optimizer { optimizer_config::clk_coeff * val.clk) { val.mf_size = MF_DIM + 1; val.mf[0] = 0; + int tid_x = blockIdx.x * blockDim.x + threadIdx.x; + curandState state; + curand_init(clock64(), tid_x, 0, &state); for (int i = 0; i < MF_DIM; ++i) { - val.mf[i + 1] = (cuda_normal_random((int)grad.show) * 2 - 1) * + val.mf[i + 1] = (curand_uniform(&state)) * optimizer_config::mf_initial_range; } } diff --git a/paddle/fluid/framework/fleet/ps_gpu_wrapper.cc b/paddle/fluid/framework/fleet/ps_gpu_wrapper.cc index e70b1ca84f9..a3c90fa944f 100644 --- a/paddle/fluid/framework/fleet/ps_gpu_wrapper.cc +++ b/paddle/fluid/framework/fleet/ps_gpu_wrapper.cc @@ -27,13 +27,10 @@ See the License for the specific language governing permissions and limitations under the License. */ #if (defined PADDLE_WITH_NCCL) && (defined PADDLE_WITH_PSLIB) -/* + #include -#include -#include "paddle/fluid/framework/io/fs.h" -#include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/framework/scope.h" -*/ +#include + #include "paddle/fluid/framework/fleet/ps_gpu_wrapper.h" #include "paddle/fluid/platform/timer.h" @@ -43,10 +40,142 @@ namespace framework { std::shared_ptr PSGPUWrapper::s_instance_ = NULL; bool PSGPUWrapper::is_initialized_ = false; -void PSGPUWrapper::BuildGPUPS(uint64_t table_id, int feature_dim, - std::shared_ptr gpu_task) { +void PSGPUWrapper::BuildTask(uint64_t table_id, int feature_dim) { + VLOG(3) << "PSGPUWrapper::BuildGPUPSTask begin"; + platform::Timer timeline; + timeline.Start(); + MultiSlotDataset* dataset = dynamic_cast(dataset_); + std::shared_ptr gpu_task = gpu_task_pool_.Get(); + auto input_channel = dataset->GetInputChannel(); + auto& local_keys = gpu_task->feature_keys_; + auto& local_values = gpu_task->feature_values_; + auto& local_ptr = gpu_task->value_ptr_; + std::vector threads; + auto fleet_ptr = FleetWrapper::GetInstance(); + + // data should be in input channel + thread_keys_.resize(thread_keys_thread_num_); + for (int i = 0; i < thread_keys_thread_num_; i++) { + thread_keys_[i].resize(thread_keys_shard_num_); + for (int j = 0; j < thread_keys_shard_num_; j++) { + thread_keys_[i][j].reserve(2 * max_fea_num_per_pass_ / + thread_keys_shard_num_ / + thread_keys_thread_num_); + } + } + const std::deque& vec_data = input_channel->GetData(); + size_t total_len = vec_data.size(); + size_t len_per_thread = total_len / thread_keys_thread_num_; + int remain = total_len % thread_keys_thread_num_; + size_t begin = 0; + auto gen_func = [this](const std::deque& total_data, int begin_index, + int end_index, int i) { + for (auto iter = total_data.begin() + begin_index; + iter != total_data.begin() + end_index; iter++) { + const auto& ins = *iter; + const auto& feasign_v = ins.uint64_feasigns_; + for (const auto feasign : feasign_v) { + uint64_t cur_key = feasign.sign().uint64_feasign_; + int shard_id = cur_key % thread_keys_shard_num_; + this->thread_keys_[i][shard_id].push_back(cur_key); + } + } + }; + for (int i = 0; i < thread_keys_thread_num_; i++) { + threads.push_back(std::thread(gen_func, std::ref(vec_data), begin, + begin + len_per_thread + (i < remain ? 1 : 0), + i)); + begin += len_per_thread + (i < remain ? 1 : 0); + } + for (std::thread& t : threads) { + t.join(); + } + timeline.Pause(); + VLOG(0) << "GpuPs build task cost " << timeline.ElapsedSec() << " seconds."; + + timeline.Start(); + + // merge thread_keys to shard_keys + gpu_task->init(); + for (size_t i = 0; i < thread_keys_.size(); i++) { + gpu_task->batch_add_keys(thread_keys_[i]); + for (int j = 0; j < thread_keys_thread_num_; j++) { + thread_keys_[i][j].clear(); + } + } + timeline.Pause(); + + VLOG(0) << "GpuPs task unique11111 cost " << timeline.ElapsedSec() + << " seconds."; + VLOG(0) << "FK1"; + timeline.Start(); + gpu_task->UniqueKeys(); + timeline.Pause(); + + VLOG(0) << "GpuPs task unique cost " << timeline.ElapsedSec() << " seconds."; + + for (int i = 0; i < thread_keys_shard_num_; i++) { + local_values[i].resize(local_keys[i].size()); + local_ptr[i].resize(local_keys[i].size()); + } + + auto ptl_func = [this, &local_keys, &local_values, &local_ptr, &table_id, + &fleet_ptr](int i) { + size_t key_size = local_keys[i].size(); + auto tt = fleet_ptr->pslib_ptr_->_worker_ptr->pull_sparse_ptr( + reinterpret_cast(local_ptr[i].data()), table_id, + local_keys[i].data(), key_size); + tt.wait(); + auto status = tt.get(); + // auto status = 0; + if (status != 0) { + LOG(ERROR) << "fleet pull sparse failed, status[" << status << "]"; + sleep(300); + exit(-1); + } else { + VLOG(3) << "FleetWrapper Pull sparse to local done with table size: " + << local_keys[i].size(); + } + for (size_t num = 0; num < local_ptr[i].size(); ++num) { + float* ptr_val = local_ptr[i][num]->data(); + FeatureValue& val = local_values[i][num]; + size_t dim = local_ptr[i][num]->size(); + + val.delta_score = ptr_val[1]; + val.show = ptr_val[2]; + val.clk = ptr_val[3]; + val.slot = ptr_val[6]; + val.lr = ptr_val[4]; + val.lr_g2sum = ptr_val[5]; + + if (dim > 7) { + val.mf_size = MF_DIM + 1; + for (int x = 0; x < val.mf_size; x++) { + val.mf[x] = ptr_val[x + 7]; + } + } else { + val.mf_size = 0; + for (int x = 0; x < MF_DIM + 1; x++) { + val.mf[x] = 0; + } + } + } + }; + for (size_t i = 0; i < threads.size(); i++) { + threads[i] = std::thread(ptl_func, i); + } + for (std::thread& t : threads) { + t.join(); + } + timeline.Pause(); + VLOG(0) << "GpuPs pull sparse cost " << timeline.ElapsedSec() << " seconds."; +} + +void PSGPUWrapper::BuildGPUPS(uint64_t table_id, int feature_dim) { + BuildTask(table_id, feature_dim); platform::Timer timeline; timeline.Start(); + std::shared_ptr gpu_task = gpu_task_pool_.Get(); int shard_num = gpu_task->feature_keys_.size(); if (shard_num == 0) { return; @@ -62,13 +191,20 @@ void PSGPUWrapper::BuildGPUPS(uint64_t table_id, int feature_dim, HeterPs_->show_one_table(0); return; } + std::vector threads(shard_num); HeterPs_ = HeterPsBase::get_instance(size_max, resource_); - for (int i = 0; i < shard_num; ++i) { + auto build_func = [this, &gpu_task, &feature_keys_count](int i) { std::cout << "building table: " << i << std::endl; - HeterPs_->build_ps(i, gpu_task->feature_keys_[i].data(), - gpu_task->feature_values_[i].data(), - feature_keys_count[i], 10000, 2); + this->HeterPs_->build_ps(i, gpu_task->feature_keys_[i].data(), + gpu_task->feature_values_[i].data(), + feature_keys_count[i], 10000, 2); HeterPs_->show_one_table(i); + }; + for (size_t i = 0; i < threads.size(); i++) { + threads[i] = std::thread(build_func, i); + } + for (std::thread& t : threads) { + t.join(); } timeline.Pause(); VLOG(0) << "GpuPs build table total costs: " << timeline.ElapsedSec() diff --git a/paddle/fluid/framework/fleet/ps_gpu_wrapper.cu b/paddle/fluid/framework/fleet/ps_gpu_wrapper.cu index 9b7920acef3..2eedcd5f1c7 100644 --- a/paddle/fluid/framework/fleet/ps_gpu_wrapper.cu +++ b/paddle/fluid/framework/fleet/ps_gpu_wrapper.cu @@ -17,6 +17,7 @@ limitations under the License. */ #include #include #include +#include "paddle/fluid/framework/fleet/heter_ps/optimizer_conf.h" #include "paddle/fluid/framework/fleet/ps_gpu_wrapper.h" #include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/platform/gpu_info.h" @@ -177,6 +178,42 @@ void PSGPUWrapper::CopyForPush(const paddle::platform::Place& place, slot_lengths.size(), total_length, batch_size, d_slot_vector); cudaStreamSynchronize(stream); } + +void PSGPUWrapper::SetSparseSGD(float nonclk_coeff, float clk_coeff, + float min_bound, float max_bound, + float learning_rate, float initial_g2sum, + float initial_range) { + cudaMemcpyToSymbol(optimizer_config::nonclk_coeff, &nonclk_coeff, + sizeof(float)); + cudaMemcpyToSymbol(optimizer_config::clk_coeff, &clk_coeff, sizeof(float)); + cudaMemcpyToSymbol(optimizer_config::min_bound, &min_bound, sizeof(float)); + cudaMemcpyToSymbol(optimizer_config::max_bound, &max_bound, sizeof(float)); + cudaMemcpyToSymbol(optimizer_config::learning_rate, &learning_rate, + sizeof(float)); + cudaMemcpyToSymbol(optimizer_config::initial_g2sum, &initial_g2sum, + sizeof(float)); + cudaMemcpyToSymbol(optimizer_config::initial_range, &initial_range, + sizeof(float)); +} + +void PSGPUWrapper::SetEmbedxSGD(float mf_create_thresholds, + float mf_learning_rate, float mf_initial_g2sum, + float mf_initial_range, float mf_min_bound, + float mf_max_bound) { + cudaMemcpyToSymbol(optimizer_config::mf_create_thresholds, + &mf_create_thresholds, sizeof(float)); + cudaMemcpyToSymbol(optimizer_config::mf_learning_rate, &mf_learning_rate, + sizeof(float)); + cudaMemcpyToSymbol(optimizer_config::mf_initial_g2sum, &mf_initial_g2sum, + sizeof(float)); + cudaMemcpyToSymbol(optimizer_config::mf_initial_range, &mf_initial_range, + sizeof(float)); + cudaMemcpyToSymbol(optimizer_config::mf_min_bound, &mf_min_bound, + sizeof(float)); + cudaMemcpyToSymbol(optimizer_config::mf_max_bound, &mf_max_bound, + sizeof(float)); +} + } // end namespace framework } // end namespace paddle #endif diff --git a/paddle/fluid/framework/fleet/ps_gpu_wrapper.h b/paddle/fluid/framework/fleet/ps_gpu_wrapper.h index df6af23d701..ed06000c307 100644 --- a/paddle/fluid/framework/fleet/ps_gpu_wrapper.h +++ b/paddle/fluid/framework/fleet/ps_gpu_wrapper.h @@ -23,8 +23,10 @@ limitations under the License. */ #include #include #include +#include #include +#include "paddle/fluid/framework/data_set.h" #include "paddle/fluid/framework/fleet/heter_context.h" #include "paddle/fluid/framework/fleet/heter_ps/heter_ps_base.h" #include "paddle/fluid/framework/fleet/heter_ps/heter_resource.h" @@ -73,16 +75,77 @@ class PSGPUWrapper { const int hidden_size, const int64_t total_length, const int batch_size); - void BuildGPUPS(const uint64_t table_id, int feature_dim, - std::shared_ptr context); + void BuildGPUPS(const uint64_t table_id, int feature_dim); + void BuildTask(uint64_t table_id, int feature_dim); void InitializeGPU(const std::vector& dev_ids) { if (s_instance_ != NULL) { VLOG(3) << "PSGPUWrapper Begin InitializeGPU"; resource_ = std::make_shared(dev_ids); resource_->enable_p2p(); keys_tensor.resize(resource_->total_gpu()); + heter_devices_ = dev_ids; } } + + void SetSparseSGD(float nonclk_coeff, float clk_coeff, float min_bound, + float max_bound, float learning_rate, float initial_g2sum, + float initial_range); + void SetEmbedxSGD(float mf_create_thresholds, float mf_learning_rate, + float mf_initial_g2sum, float mf_initial_range, + float mf_min_bound, float mf_max_bound); + void InitializeGPUServer(std::unordered_map config) { + float nonclk_coeff = (config.find("nonclk_coeff") == config.end()) + ? 1.0 + : config["nonclk_coeff"]; + float clk_coeff = + (config.find("clk_coeff") == config.end()) ? 1.0 : config["clk_coeff"]; + float min_bound = (config.find("min_bound") == config.end()) + ? -10000.0 + : config["min_bound"]; + float max_bound = (config.find("max_bound") == config.end()) + ? 10000.0 + : config["max_bound"]; + float learning_rate = (config.find("learning_rate") == config.end()) + ? 1.0 + : config["learning_rate"]; + float initial_g2sum = (config.find("initial_g2sum") == config.end()) + ? 1.0 + : config["initial_g2sum"]; + float initial_range = (config.find("initial_range") == config.end()) + ? 1.0 + : config["initial_range"]; + + // mf config settings + float mf_create_thresholds = + (config.find("mf_create_thresholds") == config.end()) + ? static_cast(1.0) + : config["mf_create_thresholds"]; + float mf_learning_rate = (config.find("mf_learning_rate") == config.end()) + ? 1.0 + : config["mf_learning_rate"]; + float mf_initial_g2sum = (config.find("mf_initial_g2sum") == config.end()) + ? 1.0 + : config["mf_initial_g2sum"]; + float mf_initial_range = (config.find("mf_initial_range") == config.end()) + ? 1.0 + : config["mf_initial_range"]; + float mf_min_bound = (config.find("mf_min_bound") == config.end()) + ? 1.0 + : config["mf_min_bound"]; + float mf_max_bound = (config.find("mf_max_bound") == config.end()) + ? 1.0 + : config["mf_max_bound"]; + for (size_t i = 0; i < heter_devices_.size(); i++) { + PADDLE_ENFORCE_CUDA_SUCCESS(cudaSetDevice(heter_devices_[i])); + this->SetSparseSGD(nonclk_coeff, clk_coeff, min_bound, max_bound, + learning_rate, initial_g2sum, initial_range); + this->SetEmbedxSGD(mf_create_thresholds, mf_learning_rate, + mf_initial_g2sum, mf_initial_range, mf_min_bound, + mf_max_bound); + } + } + void SetDataset(Dataset* dataset) { dataset_ = dataset; } + // PSGPUWrapper singleton static std::shared_ptr GetInstance() { if (NULL == s_instance_) { @@ -100,6 +163,7 @@ class PSGPUWrapper { private: static std::shared_ptr s_instance_; + Dataset* dataset_; std::unordered_map< uint64_t, std::vector>>> local_tables_; @@ -108,6 +172,13 @@ class PSGPUWrapper { std::shared_ptr resource_; int32_t sleep_seconds_before_fail_exit_; std::vector slot_vector_; + std::vector heter_devices_; + std::unordered_set gpu_ps_config_keys_; + HeterObjectPool gpu_task_pool_; + std::vector>> thread_keys_; + int thread_keys_thread_num_ = 37; + int thread_keys_shard_num_ = 37; + uint64_t max_fea_num_per_pass_ = 5000000000; protected: static bool is_initialized_; diff --git a/paddle/fluid/framework/ps_gpu_trainer.cc b/paddle/fluid/framework/ps_gpu_trainer.cc index 530750d98ac..4ac98e977d3 100644 --- a/paddle/fluid/framework/ps_gpu_trainer.cc +++ b/paddle/fluid/framework/ps_gpu_trainer.cc @@ -131,219 +131,11 @@ void PSGPUTrainer::InitOtherEnv(const ProgramDesc& main_program) { } void PSGPUTrainer::Run() { - BuildGPUPSTask(0, 8); for (size_t thidx = 0; thidx < places_.size(); ++thidx) { threads_.push_back( std::thread(&DeviceWorker::TrainFiles, workers_[thidx].get())); } } -void PSGPUTrainer::BuildGPUPSTask(int table_id, int feadim) { - VLOG(3) << "PSGPUTrainer::BuildGPUPSTask begin"; - platform::Timer timeline; - timeline.Start(); - MultiSlotDataset* dataset = dynamic_cast(dataset_); - auto fleet_ptr = FleetWrapper::GetInstance(); - std::shared_ptr heter_context = - std::make_shared(); - auto& multi_output_channel = dataset->GetCurOutputChannel(); - auto& input_channel = dataset->GetInputChannelRef(); - int gen_shard_num = multi_output_channel.size(); - int device_num = places_.size(); - auto gpu_ps_wrapper = PSGPUWrapper::GetInstance(); - auto& local_keys = heter_context->feature_keys_; - local_keys.resize(device_num); - auto& local_values = heter_context->feature_values_; - local_values.resize(device_num); - auto& local_ptr = heter_context->value_ptr_; - local_ptr.resize(device_num); - for (auto& ks : local_keys) { - ks.reserve(100000); - } - // read thread - std::vector threads(gen_shard_num); - std::vector> consume_task_pool(device_num); - for (size_t i = 0; i < consume_task_pool.size(); i++) { - consume_task_pool[i].reset(new ::ThreadPool(1)); - } - auto consume_func = [&local_keys](int shard_id, int feadim, - std::vector& keys) { - local_keys[shard_id].insert(local_keys[shard_id].end(), keys.begin(), - keys.end()); - }; - - if (input_channel->Size() == 0) { - // output_channel_ should hold one pass instances now - uint64_t output_channels_data_size = 0; - for (size_t i = 0; i < multi_output_channel.size(); i++) { - int cur_channel_size = multi_output_channel[i]->Size(); - output_channels_data_size += cur_channel_size; - } - CHECK(output_channels_data_size > 0); - for (auto& ks : local_keys) { - ks.reserve(output_channels_data_size * 10); // magic number - } - auto gen_func = [&dataset, &device_num, &feadim, &consume_task_pool, - &multi_output_channel, &consume_func](int i) { - const std::deque& vec_data = multi_output_channel[i]->GetData(); - std::vector> task_keys(device_num); - std::vector> task_futures; - for (size_t j = 0; j < vec_data.size(); j++) { - for (auto& feature : vec_data[j].uint64_feasigns_) { - int shard = feature.sign().uint64_feasign_ % device_num; - task_keys[shard].push_back(feature.sign().uint64_feasign_); - } - } - - for (int shard_id = 0; shard_id < device_num; shard_id++) { - task_futures.emplace_back(consume_task_pool[shard_id]->enqueue( - consume_func, shard_id, feadim, task_keys[shard_id])); - } - - for (auto& tf : task_futures) { - tf.wait(); - } - for (auto& tk : task_keys) { - tk.clear(); - std::vector().swap(tk); - } - task_keys.clear(); - std::vector>().swap(task_keys); - }; - for (size_t i = 0; i < threads.size(); i++) { - threads[i] = std::thread(gen_func, i); - } - for (std::thread& t : threads) { - t.join(); - } - } else { - int input_channel_size = input_channel->Size(); - CHECK(input_channel_size > 0); - CHECK(gen_shard_num > 0); - for (auto& ks : local_keys) { - ks.reserve(input_channel_size * 10); // magic number - } - const std::deque& vec_data = input_channel->GetData(); - auto gen_func = [&dataset, &vec_data, &device_num, &gen_shard_num, - &input_channel_size, &feadim, &consume_task_pool, - multi_output_channel, &consume_func](int i) { - std::vector> task_keys(device_num); - std::vector> task_futures; - size_t per_shard_num = input_channel_size / gen_shard_num + 1; - size_t total_size = vec_data.size(); - size_t start_index = i * per_shard_num; - size_t end_index = - std::min(start_index + per_shard_num - 1, total_size - 1); - for (size_t j = start_index; j <= end_index; j++) { - for (auto& feature : vec_data[j].uint64_feasigns_) { - int shard = feature.sign().uint64_feasign_ % device_num; - task_keys[shard].push_back(feature.sign().uint64_feasign_); - } - } - - for (int shard_id = 0; shard_id < device_num; shard_id++) { - task_futures.emplace_back(consume_task_pool[shard_id]->enqueue( - consume_func, shard_id, feadim, task_keys[shard_id])); - } - - for (auto& tf : task_futures) { - tf.wait(); - } - for (auto& tk : task_keys) { - tk.clear(); - std::vector().swap(tk); - } - task_keys.clear(); - std::vector>().swap(task_keys); - }; - for (size_t i = 0; i < threads.size(); i++) { - threads[i] = std::thread(gen_func, i); - } - for (std::thread& t : threads) { - t.join(); - } - } - timeline.Pause(); - VLOG(0) << "GpuPs build task cost " << timeline.ElapsedSec() << " seconds."; - timeline.Start(); - auto unique_func = [&local_keys](int i) { - auto& cur_keys = local_keys[i]; - std::sort(cur_keys.begin(), cur_keys.end()); - cur_keys.erase(std::unique(cur_keys.begin(), cur_keys.end()), - cur_keys.end()); - }; - for (size_t i = 0; i < threads.size(); i++) { - threads[i] = std::thread(unique_func, i); - } - for (std::thread& t : threads) { - t.join(); - } - timeline.Pause(); - - VLOG(0) << "GpuPs task unique cost " << timeline.ElapsedSec() << " seconds."; - - timeline.Start(); - for (size_t i = 0; i < consume_task_pool.size(); i++) { - consume_task_pool[i].reset(); - } - consume_task_pool.clear(); - - for (int i = 0; i < device_num; i++) { - local_values[i].resize(local_keys[i].size()); - local_ptr[i].resize(local_keys[i].size()); - } - - auto ptl_func = [this, &local_keys, &local_values, &local_ptr, &table_id, - &fleet_ptr](int i) { - size_t key_size = local_keys[i].size(); - auto tt = fleet_ptr->pslib_ptr_->_worker_ptr->pull_sparse_ptr( - (char**)(local_ptr[i].data()), table_id, local_keys[i].data(), - key_size); - tt.wait(); - auto status = tt.get(); - // auto status = 0; - if (status != 0) { - LOG(ERROR) << "fleet pull sparse failed, status[" << status << "]"; - sleep(300); - exit(-1); - } else { - VLOG(3) << "FleetWrapper Pull sparse to local done with table size: " - << local_keys[i].size(); - } - for (size_t num = 0; num < local_ptr[i].size(); ++num) { - float* ptr_val = local_ptr[i][num]->data(); - FeatureValue& val = local_values[i][num]; - size_t dim = local_ptr[i][num]->size(); - - val.delta_score = ptr_val[1]; - val.show = ptr_val[2]; - val.clk = ptr_val[3]; - val.slot = ptr_val[6]; - val.lr = ptr_val[4]; - val.lr_g2sum = ptr_val[5]; - - if (dim > 7) { - val.mf_size = MF_DIM + 1; - for (int x = 0; x < val.mf_size; x++) { - val.mf[x] = ptr_val[x + 7]; - } - } else { - val.mf_size = 0; - for (int x = 0; x < MF_DIM + 1; x++) { - val.mf[x] = 0; - } - } - } - }; - for (size_t i = 0; i < threads.size(); i++) { - threads[i] = std::thread(ptl_func, i); - } - for (std::thread& t : threads) { - t.join(); - } - timeline.Pause(); - VLOG(0) << "GpuPs pull sparse cost " << timeline.ElapsedSec() << " seconds."; - gpu_ps_wrapper->BuildGPUPS(table_id, feadim, heter_context); -} Scope* PSGPUTrainer::GetWorkerScope(int thread_id) { return nullptr; } diff --git a/paddle/fluid/framework/trainer.h b/paddle/fluid/framework/trainer.h index 25b215df3e4..ca57a89ca98 100644 --- a/paddle/fluid/framework/trainer.h +++ b/paddle/fluid/framework/trainer.h @@ -26,6 +26,7 @@ limitations under the License. */ #include "paddle/fluid/framework/data_feed.h" #include "paddle/fluid/framework/data_set.h" #include "paddle/fluid/framework/device_worker.h" +#include "paddle/fluid/framework/fleet/heter_context.h" #include "paddle/fluid/framework/fleet/heter_wrapper.h" #include "paddle/fluid/framework/heter_service.h" #include "paddle/fluid/framework/lod_tensor.h" @@ -296,13 +297,6 @@ class PSGPUTrainer : public TrainerBase { } virtual std::string GetDumpPath(int tid) { return ""; } virtual void InitDumpEnv() {} - void BuildGPUPSTask(int table_id, int feadim); - /* - template - void HeterMemCpy(LoDTensor* tensor, LoDTensor* root_tensor, - const paddle::platform::Place& thread_place, - cudaStream_t stream); - */ template void MergeToRootScope(LoDTensor* root_tensor, LoDTensor* thread_tensor); diff --git a/paddle/fluid/pybind/ps_gpu_wrapper_py.cc b/paddle/fluid/pybind/ps_gpu_wrapper_py.cc index 0bbe8091975..d2327495039 100644 --- a/paddle/fluid/pybind/ps_gpu_wrapper_py.cc +++ b/paddle/fluid/pybind/ps_gpu_wrapper_py.cc @@ -21,6 +21,7 @@ limitations under the License. */ #undef _XOPEN_SOURCE #endif +#include #include #include @@ -37,6 +38,10 @@ void BindPSGPUWrapper(py::module* m) { *m, "PSGPU") .def(py::init([]() { return framework::PSGPUWrapper::GetInstance(); })) .def("set_slot_vector", &framework::PSGPUWrapper::SetSlotVector, + py::call_guard()) + .def("init_GPU_server", &framework::PSGPUWrapper::InitializeGPUServer, + py::call_guard()) + .def("build_gpu_ps", &framework::PSGPUWrapper::BuildGPUPS, py::call_guard()); } // end PSGPUWrapper #endif -- GitLab