diff --git a/paddle/fluid/framework/fleet/heter_context.h b/paddle/fluid/framework/fleet/heter_context.h index 3fdcf2379cb54af955f0a45f60b18bbca33820a1..11217b6c485fc28ac6eb9f0e771c3bd4fc89585f 100644 --- a/paddle/fluid/framework/fleet/heter_context.h +++ b/paddle/fluid/framework/fleet/heter_context.h @@ -129,11 +129,6 @@ class HeterContext { for (size_t i = 0; i < feature_dim_keys_.size(); i++) { feature_dim_keys_[i].resize(dim_num); value_dim_ptr_[i].resize(dim_num); - if (i == 0) { - for (int j = 0; j < dim_num; j++) { - feature_dim_keys_[i][j].push_back(0); - } - } } device_values_.resize(device_num); device_dim_values_.resize(device_num); diff --git a/paddle/fluid/framework/fleet/heter_ps/feature_value.h b/paddle/fluid/framework/fleet/heter_ps/feature_value.h index b633394e7a81179ba8edf74950014951ffda2ee3..682c4568cb7e119d7a704de7d3fb5279a2206b3b 100644 --- a/paddle/fluid/framework/fleet/heter_ps/feature_value.h +++ b/paddle/fluid/framework/fleet/heter_ps/feature_value.h @@ -32,17 +32,33 @@ struct FeatureValue { float lr; float lr_g2sum; int mf_size; - float mf[MF_DIM + 1]; + int mf_dim; uint64_t cpu_ptr; + float mf[0]; friend std::ostream& operator<<(std::ostream& out, FeatureValue& val) { out << "show: " << val.show << " clk: " << val.clk << " slot: " << val.slot - << " lr: " << val.lr << " mf_size: " << val.mf_size << " mf:"; - for (int i = 0; i < val.mf_size; ++i) { + << " lr: " << val.lr << " mf_dim: " << val.mf_dim + << " cpu_ptr: " << val.cpu_ptr << " mf_size: " << val.mf_size << " mf:"; + for (int i = 0; i < val.mf_dim + 1; ++i) { out << " " << val.mf[i]; } return out; } + __device__ __forceinline__ void operator=(const FeatureValue& in) { + delta_score = in.delta_score; + show = in.show; + clk = in.clk; + slot = in.slot; + lr = in.lr; + lr_g2sum = in.lr_g2sum; + mf_size = in.mf_size; + mf_dim = in.mf_dim; + cpu_ptr = in.cpu_ptr; + for (int i = 0; i < mf_dim + 1; i++) { + mf[i] = in.mf[i]; + } + } }; struct FeaturePushValue { @@ -50,20 +66,19 @@ struct FeaturePushValue { float clk; int slot; float lr_g; - float mf_g[MF_DIM]; + int mf_dim; + float mf_g[0]; - // __device__ __forceinline__ FeaturePushValue - // operator+(const FeaturePushValue& a) const { - // FeaturePushValue out; - // out.slot = a.slot; - // out.show = a.show + show; - // out.clk = a.clk + clk; - // out.lr_g = a.lr_g + lr_g; - // for (int i = 0; i < MF_DIM; ++i) { - // out.mf_g[i] = a.mf_g[i] + mf_g[i]; - // } - // return out; - // } + __device__ __forceinline__ void operator=(const FeaturePushValue& in) { + show = in.show; + clk = in.clk; + slot = in.slot; + lr_g = in.lr_g; + mf_dim = in.mf_dim; + for (int i = 0; i < mf_dim; i++) { + mf_g[i] = in.mf_g[i]; + } + } }; } // end namespace framework diff --git a/paddle/fluid/framework/fleet/heter_ps/hashtable.h b/paddle/fluid/framework/fleet/heter_ps/hashtable.h index e2f362d40745897b3b0aa44477bdf1559966bc1b..234aa15ebf74d1da276cf1e2664017ca7893f66f 100644 --- a/paddle/fluid/framework/fleet/heter_ps/hashtable.h +++ b/paddle/fluid/framework/fleet/heter_ps/hashtable.h @@ -118,8 +118,8 @@ class HashTable { StreamType stream); template - void insert(const KeyType* d_keys, size_t len, char* pool, size_t start_index, - StreamType stream); + void insert(const KeyType* d_keys, size_t len, char* pool, + size_t feature_value_size, size_t start_index, StreamType stream); template void get(const KeyType* d_keys, ValType* d_vals, size_t len, diff --git a/paddle/fluid/framework/fleet/heter_ps/hashtable_kernel.cu b/paddle/fluid/framework/fleet/heter_ps/hashtable_kernel.cu index
5edc218796ef8a3c3052d3aec9cad1c101f67191..32dbd98992b5d8a63f035a30b034d314725260b2 100644 --- a/paddle/fluid/framework/fleet/heter_ps/hashtable_kernel.cu +++ b/paddle/fluid/framework/fleet/heter_ps/hashtable_kernel.cu @@ -50,7 +50,8 @@ __global__ void insert_kernel(Table* table, template __global__ void insert_kernel(Table* table, const typename Table::key_type* const keys, - size_t len, char* pool, int start_index) { + size_t len, char* pool, size_t feature_value_size, + int start_index) { ReplaceOp op; thrust::pair kv; @@ -58,7 +59,8 @@ __global__ void insert_kernel(Table* table, if (i < len) { kv.first = keys[i]; - kv.second = (Table::mapped_type)(pool + (start_index + i) * 80); + uint64_t offset = uint64_t(start_index + i) * feature_value_size; + kv.second = (Table::mapped_type)(pool + offset); auto it = table->insert(kv, op); assert(it != table->end() && "error: insert fails: table is full"); } @@ -81,14 +83,16 @@ __global__ void search_kernel(Table* table, template __global__ void dy_mf_search_kernel(Table* table, const typename Table::key_type* const keys, - char* const vals, size_t len, + char* vals, size_t len, size_t pull_feature_value_size) { const size_t i = blockIdx.x * blockDim.x + threadIdx.x; if (i < len) { auto it = table->find(keys[i]); if (it != table->end()) { - *(FeatureValue*)(vals + i * pull_feature_value_size) = *(it->second); + uint64_t offset = i * pull_feature_value_size; + FeatureValue& cur = *(FeatureValue*)(vals + offset); + FeatureValue& input = *(FeatureValue*)(it->second); + cur = input; } } } @@ -121,7 +125,7 @@ __global__ void dy_mf_update_kernel(Table* table, FeaturePushValue* cur = (FeaturePushValue*)(grads + i * grad_value_size); sgd.dy_mf_update_value(optimizer_config, (it.getter())->second, *cur); } else { - printf("yxf::push miss key: %d", keys[i]); + printf("warning: push miss key: %lu", keys[i]); } } } @@ -201,7 +205,8 @@ void HashTable::insert(const KeyType* d_keys, template template void HashTable::insert(const KeyType* d_keys, size_t len, - char* pool, size_t start_index, + char* pool, size_t feature_value_size, + size_t start_index, StreamType stream) { if (len == 0) { return; @@ -210,8 +215,8 @@ void HashTable::insert(const KeyType* d_keys, size_t len, return; } const int grid_size = (len - 1) / BLOCK_SIZE_ + 1; - insert_kernel<<>>(container_, d_keys, len, - pool, start_index); + insert_kernel<<>>( + container_, d_keys, len, pool, feature_value_size, start_index); } template @@ -319,6 +324,7 @@ void HashTable::update(const KeyType* d_keys, } template class HashTable; +template class HashTable; template class HashTable; template class HashTable; template class HashTable; @@ -331,6 +337,10 @@ template void HashTable::get< paddle::framework::FeatureValue* d_vals, size_t len, cudaStream_t stream); +template void +HashTable::get( + const unsigned long* d_keys, char* d_vals, size_t len, cudaStream_t stream); + template void HashTable::get(const long* d_keys, int* d_vals, size_t len, cudaStream_t stream); @@ -354,6 +364,11 @@ template void HashTable::insert< const paddle::framework::FeatureValue* d_vals, size_t len, cudaStream_t stream); +template void HashTable:: + insert(const unsigned long* d_keys, size_t len, char* pool, + size_t feature_value_size, size_t start_index, + cudaStream_t stream); + template void HashTable::insert(const long* d_keys, const int* d_vals, size_t len, @@ -393,6 +408,16 @@ template void HashTable::update< sgd, cudaStream_t stream); +template void +HashTable::update< + Optimizer, + cudaStream_t>(const unsigned long* d_keys, const char*
d_grads, size_t len, + Optimizer + sgd, + cudaStream_t stream); + // template void HashTable::update< // Optimizer #include +#include "cub/cub.cuh" +#include "cub/util_allocator.cuh" #if defined(PADDLE_WITH_CUDA) #include "paddle/fluid/framework/fleet/heter_ps/optimizer.cuh.h" #include "paddle/fluid/platform/cuda_device_guard.h" #include "paddle/fluid/platform/dynload/nccl.h" +#include "paddle/fluid/platform/timer.h" #include "thrust/pair.h" #elif defined(PADDLE_WITH_XPU_KP) // #include "paddle/fluid/framework/fleet/heter_ps/optimizer_conf.h" @@ -38,6 +41,9 @@ limitations under the License. */ namespace paddle { namespace framework { +#define TYPEALIGN(ALIGNVAL, LEN) \ + (((uint64_t)(LEN) + ((ALIGNVAL)-1)) & ~((uint64_t)((ALIGNVAL)-1))) + template class HeterComm { public: @@ -50,9 +56,13 @@ class HeterComm { int* left, int* right, int gpu_num); void merge_grad(int gpu_num, KeyType* d_keys, GradType* d_grads, size_t len, int& uniq_len); // NOLINT + void dynamic_merge_grad(int gpu_num, KeyType* d_keys, GradType* d_grads, + size_t len, int& uniq_len); void pull_sparse(int num, KeyType* d_keys, ValType* d_vals, size_t len); void build_ps(int num, KeyType* h_keys, ValType* h_vals, size_t len, size_t chunk_size, int stream_num); + void build_ps(int num, KeyType* h_keys, char* pool, size_t len, + size_t feature_value_size, size_t chunk_size, int stream_num); void dump(); void show_one_table(int gpu_num); int get_index_by_devid(int devid); @@ -96,6 +106,11 @@ class HeterComm { nccl_inter_comms_ = inter_comms; node_size_ = comm_size; } + + void set_multi_mf_dim(int multi_mf_dim, int max_mf_dim) { + multi_mf_dim_ = multi_mf_dim; + max_mf_dim_ = max_mf_dim; + } #endif bool need_transfer(int send_id, int receive_id) { @@ -114,8 +129,8 @@ class HeterComm { char* key_storage; char* val_storage; int sync; - int key_bytes_len; - int val_bytes_len; + size_t key_bytes_len; + size_t val_bytes_len; int dev_num; }; @@ -206,12 +221,18 @@ class HeterComm { void destroy_storage(int start_index, int end_index); void walk_to_dest(int start_index, int gpu_num, int* h_left, int* h_right, KeyType* src_key, GradType* src_val); + void walk_to_dest(int start_index, int gpu_num, int* h_left, int* h_right, + KeyType* src_key, char* src_val, size_t val_size); void walk_to_src(int start_index, int gpu_num, int* h_left, int* h_right, ValType* src_val); + void walk_to_src(int start_index, int gpu_num, int* h_left, int* h_right, + char* src_val, size_t val_size); protected: using Table = HashTable; + using PtrTable = HashTable; std::vector tables_; + std::vector ptr_tables_; std::shared_ptr resource_; std::vector> path_; float load_factor_{0.75}; @@ -221,6 +242,7 @@ class HeterComm { private: int topo_aware_{0}; std::vector storage_; + DynamicGradMerger merger_; int feanum_{1800 * 2048}; int multi_node_{0}; int node_size_; @@ -228,6 +250,8 @@ class HeterComm { #if defined(PADDLE_WITH_CUDA) std::vector nccl_inner_comms_; std::vector nccl_inter_comms_; + int multi_mf_dim_{8}; + int max_mf_dim_ = 8; std::vector> allocators_; #endif }; diff --git a/paddle/fluid/framework/fleet/heter_ps/heter_comm_inl.h b/paddle/fluid/framework/fleet/heter_ps/heter_comm_inl.h index d23719ea9eb774843f894a6bed5db9b1206b4ee0..506a0c0b1863f7f103e923e0fb931958975a42f5 100644 --- a/paddle/fluid/framework/fleet/heter_ps/heter_comm_inl.h +++ b/paddle/fluid/framework/fleet/heter_ps/heter_comm_inl.h @@ -14,6 +14,7 @@ limitations under the License. 
*/ #pragma once #ifdef PADDLE_WITH_HETERPS #include +#include "paddle/fluid/framework/fleet/heter_ps/feature_value.h" #include "paddle/fluid/framework/fleet/heter_ps/heter_comm_kernel.h" #include "paddle/fluid/platform/device_context.h" #ifdef PADDLE_WITH_XPU_KP @@ -22,20 +23,31 @@ limitations under the License. */ namespace paddle { namespace framework { - template HeterComm::HeterComm( size_t capacity, std::shared_ptr resource) { resource_ = resource; storage_.resize(resource_->total_device()); + multi_mf_dim_ = resource->multi_mf(); for (int i = 0; i < resource_->total_device(); ++i) { #if defined(PADDLE_WITH_CUDA) platform::CUDADeviceGuard guard(resource_->dev_id(i)); allocators_.push_back(std::make_shared( 8, 1, (unsigned int)-1, (size_t)-1, false, false)); // NOLINT #endif - auto table = new Table(capacity / load_factor_); - tables_.push_back(table); + if (!multi_mf_dim_) { + auto table = new Table(capacity / load_factor_); + tables_.push_back(table); + } else { + max_mf_dim_ = resource_->max_mf_dim(); + size_t val_type_size = TYPEALIGN( + 8, sizeof(FeatureValue) + sizeof(float) * (max_mf_dim_ + 1)); + size_t grad_type_size = TYPEALIGN( + 8, sizeof(FeaturePushValue) + (max_mf_dim_ * sizeof(float))); + auto ptr_table = new PtrTable(capacity / load_factor_); + ptr_table->set_feature_value_size(val_type_size, grad_type_size); + ptr_tables_.push_back(ptr_table); + } if (multi_node_) { storage_[i].init(feanum_, resource_->dev_id(i)); } @@ -238,95 +250,128 @@ void HeterComm::walk_to_dest(int start_index, } template -void HeterComm::walk_to_src(int start_index, - int num, int* h_left, - int* h_right, - ValType* src_val) { +void HeterComm::walk_to_dest( + int start_index, int gpu_num, int* h_left, int* h_right, KeyType* src_key, + char* src_val, size_t val_size) { + int need_copy_val = 0; + if (src_val) { + need_copy_val = 1; + } std::queue que; + for (int i = 0; i < gpu_num; i++) { + if (h_left[i] == -1 || h_right[i] == -1) { + continue; + } + int size = path_[start_index][i].nodes_.size(); + auto& node = path_[start_index][i].nodes_[0]; + CopyTask t(&path_[start_index][i], 0); + que.push(t); + cudaMemcpyAsync(node.key_storage, + reinterpret_cast(src_key + h_left[i]), + node.key_bytes_len, cudaMemcpyDefault, node.in_stream); + if (need_copy_val) { + cudaMemcpyAsync(node.val_storage, + src_val + uint64_t(h_left[i]) * uint64_t(val_size), + node.val_bytes_len, cudaMemcpyDefault, node.in_stream); + } + } + while (!que.empty()) { + CopyTask& cur_task = que.front(); + que.pop(); + if (cur_task.path->nodes_[cur_task.step].sync) { + cudaStreamSynchronize(cur_task.path->nodes_[cur_task.step].in_stream); + } + if (cur_task.step != cur_task.path->nodes_.size() - 1) { + int cur_step = cur_task.step; + CopyTask c(cur_task.path, cur_step + 1); + que.push(c); + cudaMemcpyAsync(cur_task.path->nodes_[cur_step + 1].key_storage, + cur_task.path->nodes_[cur_step].key_storage, + cur_task.path->nodes_[cur_step + 1].key_bytes_len, + cudaMemcpyDefault, + cur_task.path->nodes_[cur_step + 1].in_stream); + if (need_copy_val) { + cudaMemcpyAsync(cur_task.path->nodes_[cur_step + 1].val_storage, + cur_task.path->nodes_[cur_step].val_storage, + cur_task.path->nodes_[cur_step + 1].val_bytes_len, + cudaMemcpyDefault, + cur_task.path->nodes_[cur_step + 1].in_stream); + } + } + } +} - for (int i = 0; i < num; i++) { +template +void HeterComm::walk_to_src( + int start_index, int gpu_num, int* h_left, int* h_right, char* src_val, + size_t val_size) { + std::queue que; + for (int i = 0; i < gpu_num; i++) { if (h_left[i] == -1 
|| h_right[i] == -1) { continue; } int cur_step = path_[start_index][i].nodes_.size() - 1; auto& node = path_[start_index][i].nodes_[cur_step]; - - auto src_dev_id = resource_->dev_id(i); - auto src_place = DevPlace(src_dev_id); - if (cur_step == 0) { - auto dst_dev_id = resource_->dev_id(start_index); - auto dst_place = DevPlace(dst_dev_id); - memory_copy(dst_place, reinterpret_cast(src_val + h_left[i]), - src_place, node.val_storage, node.val_bytes_len, - node.out_stream); + cudaMemcpyAsync(src_val + uint64_t(h_left[i]) * val_size, + node.val_storage, node.val_bytes_len, cudaMemcpyDefault, + node.out_stream); } else { CopyTask t(&path_[start_index][i], cur_step - 1); que.push(t); - - auto dst_dev_id = - resource_->dev_id(path_[start_index][i].nodes_[cur_step - 1].dev_num); - auto dst_place = DevPlace(dst_dev_id); - - memory_copy(dst_place, - path_[start_index][i].nodes_[cur_step - 1].val_storage, - src_place, node.val_storage, - path_[start_index][i].nodes_[cur_step - 1].val_bytes_len, - path_[start_index][i].nodes_[cur_step - 1].out_stream); + cudaMemcpyAsync(path_[start_index][i].nodes_[cur_step - 1].val_storage, + node.val_storage, + path_[start_index][i].nodes_[cur_step - 1].val_bytes_len, + cudaMemcpyDefault, + path_[start_index][i].nodes_[cur_step - 1].out_stream); } } - while (!que.empty()) { CopyTask& cur_task = que.front(); que.pop(); int cur_step = cur_task.step; if (cur_task.path->nodes_[cur_step].sync) { - sync_stream(cur_task.path->nodes_[cur_step].out_stream); + cudaStreamSynchronize(cur_task.path->nodes_[cur_step].out_stream); } - - auto src_dev_id = - resource_->dev_id(cur_task.path->nodes_[cur_step].dev_num); - auto src_place = DevPlace(src_dev_id); - if (cur_step > 0) { CopyTask c(cur_task.path, cur_step - 1); que.push(c); - - auto dst_dev_id = - resource_->dev_id(cur_task.path->nodes_[cur_step - 1].dev_num); - auto dst_place = DevPlace(dst_dev_id); - - memory_copy(dst_place, cur_task.path->nodes_[cur_step - 1].val_storage, - src_place, cur_task.path->nodes_[cur_step].val_storage, - cur_task.path->nodes_[cur_step - 1].val_bytes_len, - cur_task.path->nodes_[cur_step - 1].out_stream); - + cudaMemcpyAsync(cur_task.path->nodes_[cur_step - 1].val_storage, + cur_task.path->nodes_[cur_step].val_storage, + cur_task.path->nodes_[cur_step - 1].val_bytes_len, + cudaMemcpyDefault, + cur_task.path->nodes_[cur_step - 1].out_stream); } else if (cur_step == 0) { int end_index = cur_task.path->nodes_.back().dev_num; - - auto dst_dev_id = resource_->dev_id(end_index); - auto dst_place = DevPlace(dst_dev_id); - - memory_copy(dst_place, - reinterpret_cast(src_val + h_left[end_index]), - src_place, cur_task.path->nodes_[cur_step].val_storage, - cur_task.path->nodes_[cur_step].val_bytes_len, - cur_task.path->nodes_[cur_step].out_stream); + cudaMemcpyAsync(src_val + uint64_t(h_left[end_index]) * val_size, + cur_task.path->nodes_[cur_step].val_storage, + cur_task.path->nodes_[cur_step].val_bytes_len, + cudaMemcpyDefault, + cur_task.path->nodes_[cur_step].out_stream); } } } template HeterComm::~HeterComm() { - for (auto& table : tables_) { - delete table; - table = nullptr; + if (!multi_mf_dim_) { + for (auto& table : tables_) { + delete table; + table = nullptr; + } + } else { + for (auto& table : ptr_tables_) { + delete table; + table = nullptr; + } } } template -void HeterComm::show_one_table(int num) { - tables_[num]->show(); +void HeterComm::show_one_table(int gpu_num) { + if (!multi_mf_dim_) { + tables_[gpu_num]->show(); + } } template @@ -418,59 +463,165 @@ void HeterComm::build_ps( } } 
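Note on the pool-based overloads that follow: once values become variable-length (an mf_dim-sized float tail after a fixed header), every copy above is done in raw bytes with an explicit val_size stride rather than sizeof(ValType). Below is a small self-contained C++ sketch of that layout arithmetic; TYPEALIGN is the macro this patch adds to heter_comm.h, while the trimmed struct, the 64-dim value, and the 10-slot pool are illustrative assumptions, not code from this PR.

```cpp
#include <cstdint>
#include <cstdio>
#include <vector>

// Copied from this patch (heter_comm.h): round LEN up to a multiple of ALIGNVAL.
#define TYPEALIGN(ALIGNVAL, LEN) \
  (((uint64_t)(LEN) + ((ALIGNVAL)-1)) & ~((uint64_t)((ALIGNVAL)-1)))

// Trimmed stand-in for FeatureValue: fixed header, then a flexible float tail
// (the real struct ends in `float mf[0]`, sized at runtime by mf_dim + 1).
struct ValueSketch {
  int mf_size;
  int mf_dim;
  uint64_t cpu_ptr;
  float mf[];  // flexible array member (compiler extension, as in the PR)
};

int main() {
  const int max_mf_dim = 64;  // assumed maximum embedding dim for this demo
  const size_t val_size =
      TYPEALIGN(8, sizeof(ValueSketch) + sizeof(float) * (max_mf_dim + 1));
  std::vector<char> pool(10 * val_size);  // ten consecutive slots, like the GPU pool
  for (uint64_t i = 0; i < 10; ++i) {
    // Slot i lives at base + i * val_size; the 64-bit multiply mirrors the
    // uint64_t casts the patch adds so large pools cannot overflow 32 bits.
    auto* v = reinterpret_cast<ValueSketch*>(pool.data() + i * val_size);
    v->mf_dim = max_mf_dim;
    v->mf_size = 0;
  }
  std::printf("bytes per aligned value: %zu\n", (size_t)val_size);
  return 0;
}
```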
+template +void HeterComm::build_ps(int num, KeyType* h_keys, + char* pool, size_t len, + size_t feature_value_size, + size_t chunk_size, + int stream_num) { + if (len <= 0) { + return; + } + int dev_id = resource_->dev_id(num); + DevPlace place = DevPlace(dev_id); + AnyDeviceGuard guard(dev_id); + std::vector d_key_bufs; + ppStream streams[stream_num]; // NOLINT + for (int i = 0; i < stream_num; ++i) { + create_stream(&(streams[i])); + auto d_k_buf = memory::Alloc(place, chunk_size * sizeof(KeyType)); + d_key_bufs.push_back(std::move(d_k_buf)); + } + int cur_len = 0; + int cur_stream = 0; + while (cur_len < len) { + cur_stream = cur_stream % stream_num; + auto cur_use_stream = streams[cur_stream]; +#if defined(PADDLE_WITH_XPU_KP) + cur_use_stream = 0; +#endif + int tmp_len = cur_len + chunk_size > len ? len - cur_len : chunk_size; + auto dst_place = place; + auto src_place = platform::CPUPlace(); + memory_copy( + dst_place, reinterpret_cast(d_key_bufs[cur_stream]->ptr()), + src_place, h_keys + cur_len, sizeof(KeyType) * tmp_len, cur_use_stream); + ptr_tables_[num]->insert( + reinterpret_cast(d_key_bufs[cur_stream]->ptr()), tmp_len, + pool, feature_value_size, cur_len, cur_use_stream); + cur_stream += 1; + cur_len += tmp_len; + } + for (int i = 0; i < stream_num; ++i) { + sync_stream(streams[i]); + destroy_stream(streams[i]); + } +} + template void HeterComm::merge_grad( int dev_num, KeyType* d_keys, GradType* d_grads, size_t len, int& uniq_len) { // NOLINT - int dev_id = resource_->dev_id(dev_num); DevPlace place = DevPlace(dev_id); AnyDeviceGuard guard(dev_id); auto stream = resource_->local_stream(dev_num, 0); - size_t temp_storage_bytes; - auto d_merge_keys = memory::Alloc(place, len * sizeof(KeyType)); KeyType* d_merge_keys_ptr = reinterpret_cast(d_merge_keys->ptr()); auto d_merge_grads = memory::Alloc(place, len * sizeof(GradType)); GradType* d_merge_grads_ptr = reinterpret_cast(d_merge_grads->ptr()); - heter_comm_kernel_->sort_pairs(NULL, temp_storage_bytes, d_keys, d_merge_keys_ptr, d_grads, d_merge_grads_ptr, len, 0, 8 * sizeof(KeyType), stream, false); - auto d_temp_storage = memory::Alloc(place, temp_storage_bytes); - heter_comm_kernel_->sort_pairs( d_temp_storage->ptr(), temp_storage_bytes, d_keys, d_merge_keys_ptr, d_grads, d_merge_grads_ptr, len, 0, 8 * sizeof(KeyType), stream, false); temp_storage_bytes = 0; - auto d_num_runs_out_mem = memory::Alloc(place, sizeof(int)); int* d_num_runs_out = reinterpret_cast(d_num_runs_out_mem->ptr()); - heter_comm_kernel_->reduce_by_key(NULL, temp_storage_bytes, d_merge_keys_ptr, d_keys, d_merge_grads_ptr, d_grads, d_num_runs_out, len, stream, false); - if (d_temp_storage->size() < temp_storage_bytes) { d_temp_storage = NULL; d_temp_storage = memory::Alloc(place, temp_storage_bytes); } - heter_comm_kernel_->reduce_by_key( d_temp_storage->ptr(), temp_storage_bytes, d_merge_keys_ptr, d_keys, d_merge_grads_ptr, d_grads, d_num_runs_out, len, stream, false); - auto dst_place = platform::CPUPlace(); auto src_place = place; memory_copy(dst_place, &uniq_len, src_place, d_num_runs_out, sizeof(int), stream); - sync_stream(stream); } +template +void HeterComm::dynamic_merge_grad( + int gpu_num, KeyType* d_keys, GradType* d_grads, size_t len, + int& uniq_len) { + int dev_id = resource_->dev_id(gpu_num); + platform::CUDAPlace place = platform::CUDAPlace(dev_id); + platform::CUDADeviceGuard guard(dev_id); + auto stream = resource_->local_stream(gpu_num, 0); + size_t temp_storage_bytes; + size_t grad_value_size = + TYPEALIGN(8, sizeof(FeaturePushValue) + 
(max_mf_dim_ * sizeof(float))); + auto d_merge_keys = memory::Alloc(place, len * sizeof(KeyType)); + KeyType* d_merge_keys_ptr = reinterpret_cast(d_merge_keys->ptr()); + auto d_merge_grads = memory::Alloc(place, len * grad_value_size); + GradType* d_merge_grads_ptr = + reinterpret_cast(d_merge_grads->ptr()); + auto d_fea_num_info = memory::Alloc(place, sizeof(uint32_t) * (len * 3 + 1)); + uint32_t* d_fea_num_info_ptr = + reinterpret_cast(d_fea_num_info->ptr()); + uint32_t* d_index = (uint32_t*)&d_fea_num_info_ptr[len]; + uint32_t* d_idx = (uint32_t*)&d_index[len]; + int* d_merged_size = (int*)&d_idx[len]; + int grid_size = (len - 1) / block_size_ + 1; + heter_comm_kernel_->fill_idx(d_idx, len, stream); + PADDLE_ENFORCE_GPU_SUCCESS(cub::DeviceRadixSort::SortPairs( + NULL, temp_storage_bytes, d_keys, d_merge_keys_ptr, d_idx, d_index, len, + 0, 8 * sizeof(KeyType), stream)); + void* d_buff = NULL; + auto d_temp_storage = memory::Alloc(place, temp_storage_bytes); + PADDLE_ENFORCE_GPU_SUCCESS(cub::DeviceRadixSort::SortPairs( + d_temp_storage->ptr(), temp_storage_bytes, d_keys, d_merge_keys_ptr, + d_idx, d_index, len, 0, 8 * sizeof(KeyType), stream)); + PADDLE_ENFORCE_GPU_SUCCESS(cudaStreamSynchronize(stream)); + temp_storage_bytes = 0; + PADDLE_ENFORCE_GPU_SUCCESS(cub::DeviceRunLengthEncode::Encode( + NULL, temp_storage_bytes, d_merge_keys_ptr, d_keys, d_fea_num_info_ptr, + d_merged_size, len, stream)); + if (d_temp_storage->size() < temp_storage_bytes) { + d_temp_storage = NULL; + d_temp_storage = memory::Alloc(place, temp_storage_bytes); + } + PADDLE_ENFORCE_GPU_SUCCESS(cub::DeviceRunLengthEncode::Encode( + d_temp_storage->ptr(), temp_storage_bytes, d_merge_keys_ptr, d_keys, + d_fea_num_info_ptr, d_merged_size, len, stream)); + + cudaMemcpyAsync((void*)&uniq_len, d_merged_size, sizeof(int), + cudaMemcpyDeviceToHost, stream); + PADDLE_ENFORCE_GPU_SUCCESS(cudaStreamSynchronize(stream)); + + assert(uniq_len > 0); + uint32_t* d_offset = (uint32_t*)&d_index[len]; + temp_storage_bytes = 0; + PADDLE_ENFORCE_GPU_SUCCESS(cub::DeviceScan::ExclusiveSum( + NULL, temp_storage_bytes, d_fea_num_info_ptr, d_offset, uniq_len, + stream)); + if (d_temp_storage->size() < temp_storage_bytes) { + d_temp_storage = NULL; + d_temp_storage = memory::Alloc(place, temp_storage_bytes); + } + PADDLE_ENFORCE_GPU_SUCCESS(cub::DeviceScan::ExclusiveSum( + d_temp_storage->ptr(), temp_storage_bytes, d_fea_num_info_ptr, d_offset, + uniq_len, stream)); + PADDLE_ENFORCE_GPU_SUCCESS(cudaStreamSynchronize(stream)); + heter_comm_kernel_->merge_gradient( + d_offset, d_fea_num_info_ptr, d_index, (char*)d_grads, + (char*)d_merge_grads_ptr, uniq_len, grad_value_size, merger_, stream); + PADDLE_ENFORCE_GPU_SUCCESS(cudaStreamSynchronize(stream)); + PADDLE_ENFORCE_GPU_SUCCESS(cudaMemcpyAsync(d_grads, d_merge_grads_ptr, + grad_value_size * uniq_len, + cudaMemcpyDeviceToDevice, stream)); + PADDLE_ENFORCE_GPU_SUCCESS(cudaStreamSynchronize(stream)); +} + template void HeterComm::split_input_to_shard( KeyType* d_keys, int* d_idx_ptr, size_t len, int* left, int* right, @@ -529,8 +680,6 @@ void HeterComm::pull_sparse(int num, AnyDeviceGuard guard(dev_id); auto stream = resource_->local_stream(num, 0); - // int grid_size = (len - 1) / block_size_ + 1; - int h_left[total_device]; // NOLINT int h_right[total_device]; // NOLINT @@ -562,10 +711,11 @@ void HeterComm::pull_sparse(int num, auto d_idx = memory::Alloc(place, len * sizeof(int)); int* d_idx_ptr = reinterpret_cast(d_idx->ptr()); - + size_t val_type_size = + TYPEALIGN(8,
sizeof(FeatureValue) + sizeof(float) * (max_mf_dim_ + 1)); auto d_shard_keys = memory::Alloc(place, len * sizeof(KeyType)); KeyType* d_shard_keys_ptr = reinterpret_cast(d_shard_keys->ptr()); - auto d_shard_vals = memory::Alloc(place, len * sizeof(ValType)); + auto d_shard_vals = memory::Alloc(place, len * val_type_size); ValType* d_shard_vals_ptr = reinterpret_cast(d_shard_vals->ptr()); split_input_to_shard(d_keys, d_idx_ptr, len, d_left_ptr, d_right_ptr, num); @@ -589,9 +739,8 @@ void HeterComm::pull_sparse(int num, continue; } create_storage(num, i, shard_len * sizeof(KeyType), - shard_len * sizeof(ValType)); + shard_len * val_type_size); } - walk_to_dest(num, total_device, h_left, h_right, d_shard_keys_ptr, NULL); for (int i = 0; i < total_device; ++i) { @@ -600,14 +749,11 @@ void HeterComm::pull_sparse(int num, } auto& node = path_[num][i].nodes_.back(); sync_stream(node.in_stream); - AnyDeviceGuard guard(resource_->dev_id(i)); - - tables_[i]->rwlock_->RDLock(); - tables_[i]->get(reinterpret_cast(node.key_storage), - reinterpret_cast(node.val_storage), - h_right[i] - h_left[i] + 1, - resource_->remote_stream(i, num)); + ptr_tables_[i]->rwlock_->RDLock(); + ptr_tables_[i]->get(reinterpret_cast(node.key_storage), + node.val_storage, h_right[i] - h_left[i] + 1, + resource_->remote_stream(i, num)); } for (int i = 0; i < total_device; ++i) { @@ -615,21 +761,18 @@ void HeterComm::pull_sparse(int num, if (h_left[i] == -1) { continue; } - tables_[i]->rwlock_->UNLock(); + ptr_tables_[i]->rwlock_->UNLock(); } - - walk_to_src(num, total_device, h_left, h_right, d_shard_vals_ptr); - + walk_to_src(num, total_device, h_left, h_right, + reinterpret_cast(d_shard_vals_ptr), val_type_size); for (int i = 0; i < total_device; ++i) { auto& node = path_[num][i].nodes_.front(); sync_stream(node.out_stream); } - - heter_comm_kernel_->fill_dvals(d_shard_vals_ptr, d_vals, d_idx_ptr, len, - stream); + heter_comm_kernel_->dy_mf_fill_dvals(d_shard_vals_ptr, d_vals, d_idx_ptr, len, + val_type_size, stream); sync_stream(stream); - for (int i = 0; i < total_device; ++i) { if (h_left[i] == -1 || h_right[i] == -1) { continue; @@ -653,6 +796,8 @@ void HeterComm::push_sparse(int dev_num, int total_device = resource_->total_device(); int dev_id = resource_->dev_id(dev_num); + size_t grad_value_size = + TYPEALIGN(8, sizeof(FeaturePushValue) + (max_mf_dim_ * sizeof(float))); DevPlace place = DevPlace(dev_id); AnyDeviceGuard guard(dev_id); auto stream = resource_->local_stream(dev_num, 0); @@ -691,21 +836,19 @@ void HeterComm::push_sparse(int dev_num, auto d_shard_keys = memory::Alloc(place, len * sizeof(KeyType)); KeyType* d_shard_keys_ptr = reinterpret_cast(d_shard_keys->ptr()); - auto d_shard_grads = memory::Alloc(place, len * sizeof(GradType)); - GradType* d_shard_grads_ptr = - reinterpret_cast(d_shard_grads->ptr()); - + GradType* d_shard_grads_ptr; + auto d_shard_grads = memory::Alloc(place, len * grad_value_size); + d_shard_grads_ptr = reinterpret_cast(d_shard_grads->ptr()); int uniq_len = len; - merge_grad(dev_num, d_keys, d_grads, len, uniq_len); + dynamic_merge_grad(dev_num, d_keys, d_grads, len, uniq_len); - // int grid_size = (uniq_len - 1) / block_size_ + 1; + int grid_size = (uniq_len - 1) / block_size_ + 1; split_input_to_shard(d_keys, d_idx_ptr, uniq_len, d_left_ptr, d_right_ptr, dev_num); - - heter_comm_kernel_->fill_shard_grads(d_shard_keys_ptr, d_keys, - d_shard_grads_ptr, d_grads, d_idx_ptr, - uniq_len, stream); + heter_comm_kernel_->dy_mf_fill_shard_grads( + d_shard_keys_ptr, d_keys, d_shard_grads_ptr, 
d_grads, d_idx_ptr, uniq_len, + grad_value_size, stream); sync_stream(stream); @@ -721,12 +864,22 @@ void HeterComm::push_sparse(int dev_num, if (h_left[i] == -1 || h_right[i] == -1) { continue; } - create_storage(dev_num, i, shard_len * sizeof(KeyType), - shard_len * sizeof(GradType)); + if (!multi_mf_dim_) { + create_storage(dev_num, i, shard_len * sizeof(KeyType), + shard_len * sizeof(GradType)); + } else { + create_storage(dev_num, i, shard_len * sizeof(KeyType), + shard_len * grad_value_size); + } } - walk_to_dest(dev_num, total_device, h_left, h_right, d_shard_keys_ptr, - d_shard_grads_ptr); + if (!multi_mf_dim_) { + walk_to_dest(dev_num, total_device, h_left, h_right, d_shard_keys_ptr, + d_shard_grads_ptr); + } else { + walk_to_dest(dev_num, total_device, h_left, h_right, d_shard_keys_ptr, + reinterpret_cast(d_shard_grads_ptr), grad_value_size); + } for (int i = 0; i < total_device; ++i) { if (h_left[i] == -1 || h_right[i] == -1) { @@ -736,17 +889,28 @@ void HeterComm::push_sparse(int dev_num, sync_stream(node.in_stream); AnyDeviceGuard guard(resource_->dev_id(i)); - tables_[i]->rwlock_->WRLock(); - tables_[i]->update(reinterpret_cast(node.key_storage), - reinterpret_cast(node.val_storage), - h_right[i] - h_left[i] + 1, sgd, - resource_->remote_stream(i, dev_num)); + if (!multi_mf_dim_) { + tables_[i]->rwlock_->WRLock(); + tables_[i]->update(reinterpret_cast(node.key_storage), + reinterpret_cast(node.val_storage), + h_right[i] - h_left[i] + 1, sgd, + resource_->remote_stream(i, dev_num)); + } else { + ptr_tables_[i]->rwlock_->WRLock(); + ptr_tables_[i]->update(reinterpret_cast(node.key_storage), + node.val_storage, h_right[i] - h_left[i] + 1, sgd, + resource_->remote_stream(i, dev_num)); + } } for (int i = 0; i < total_device; ++i) { sync_stream(resource_->remote_stream(i, dev_num)); if (h_left[i] != -1) { - tables_[i]->rwlock_->UNLock(); + if (!multi_mf_dim_) { + tables_[i]->rwlock_->UNLock(); + } else { + ptr_tables_[i]->rwlock_->UNLock(); + } } } @@ -1078,11 +1242,13 @@ void HeterComm::end_pass() { tables_[index]->dump_to_cpu(dev_id, stream); }; - for (int i = 0; i < total_device; ++i) { - threads.push_back(std::thread(dump_to_cpu_func, i)); - } - for (auto& t : threads) { - t.join(); + if (!multi_mf_dim_) { + for (int i = 0; i < total_device; ++i) { + threads.push_back(std::thread(dump_to_cpu_func, i)); + } + for (auto& t : threads) { + t.join(); + } } } diff --git a/paddle/fluid/framework/fleet/heter_ps/heter_comm_kernel.cu b/paddle/fluid/framework/fleet/heter_ps/heter_comm_kernel.cu index bdeb696a92bcef6592d43d4d3050f6838f6760a6..f44803982a55a137f365223e4a52fc6a3d00b380 100644 --- a/paddle/fluid/framework/fleet/heter_ps/heter_comm_kernel.cu +++ b/paddle/fluid/framework/fleet/heter_ps/heter_comm_kernel.cu @@ -117,6 +117,52 @@ __global__ void fill_dvals_kernel(ValType* d_shard_vals, ValType* d_vals, } } +template +__global__ void dy_mf_fill_shard_grads_kernel( + KeyType* d_shard_keys, KeyType* d_keys, GradType* d_shard_grads, + GradType* d_grads, T* idx, size_t len, size_t grad_value_size) { + const size_t i = blockIdx.x * blockDim.x + threadIdx.x; + if (i < len) { + d_shard_keys[i] = d_keys[idx[i]]; + *(GradType*)((char*)d_shard_grads + i * grad_value_size) = + *(GradType*)((char*)d_grads + uint64_t(idx[i]) * grad_value_size); + } +} + +__global__ void merge_gradients_kernel(const uint32_t* offset, + const uint32_t* fea_num, + const uint32_t* index, const char* input, + char* output, int n, + size_t grad_value_size, + DynamicGradMerger& merger_) { + const size_t i = blockIdx.x 
* blockDim.x + threadIdx.x; + if (i < n) { + uint32_t start = offset[i]; + uint32_t num = fea_num[i]; + int ori_index = index[start]; + FeaturePushValue& out = *(FeaturePushValue*)(output + i * grad_value_size); + FeaturePushValue& in = + *(FeaturePushValue*)(input + size_t(ori_index) * grad_value_size); + merger_.update_one(out, in); + for (int j = 1; j < num; ++j) { + ori_index = index[start + j]; + FeaturePushValue& rhs = + *(FeaturePushValue*)(input + size_t(ori_index) * grad_value_size); + merger_.merge_one(out, rhs); + } + } +} + +template +__global__ void dy_mf_fill_dvals_kernel(ValType* d_shard_vals, ValType* d_vals, + T* idx, size_t len, size_t val_size) { + const size_t i = blockIdx.x * blockDim.x + threadIdx.x; + if (i < len) { + uint64_t new_offset = uint64_t(idx[i]) * val_size; + *(ValType*)((char*)d_vals + new_offset) = + *(ValType*)((char*)d_shard_vals + i * val_size); + } +} + // cuda implemention of heter_comm_kernel.h template void HeterCommKernel::fill_idx(T* idx, long long len, @@ -207,8 +253,42 @@ void HeterCommKernel::reduce_by_key(void* d_temp_storage, debug_synchronous)); } +template +void HeterCommKernel::dy_mf_fill_shard_grads( + KeyType* d_shard_keys, KeyType* d_keys, GradType* d_shard_grads, + GradType* d_grads, T* idx, long long len, size_t grad_value_size, + const StreamType& stream) { + int grid_size = (len - 1) / block_size_ + 1; + size_t c_len = (size_t)len; + dy_mf_fill_shard_grads_kernel<<>>( + d_shard_keys, d_keys, d_shard_grads, d_grads, idx, c_len, + grad_value_size); +} + +template +void HeterCommKernel::merge_gradient( + const uint32_t* offset, const uint32_t* fea_num, const uint32_t* index, + const char* input, char* output, int n, size_t grad_value_size, + DynamicGradMerger& merger_, const StreamType& stream) { + int grid_size = (n - 1) / block_size_ + 1; + merge_gradients_kernel<<>>( + offset, fea_num, index, input, output, n, grad_value_size, merger_); +} + +template +void HeterCommKernel::dy_mf_fill_dvals(ValType* d_shard_vals, ValType* d_vals, + T* idx, long long len, size_t val_size, + const StreamType& stream) { + int grid_size = (len - 1) / block_size_ + 1; + size_t c_len = (size_t)len; + dy_mf_fill_dvals_kernel<<>>( + d_shard_vals, d_vals, idx, c_len, val_size); +} + template void HeterCommKernel::fill_idx( int* idx, long long len, const cudaStream_t& stream); +template void HeterCommKernel::fill_idx( + uint32_t* idx, long long len, const cudaStream_t& stream); template void HeterCommKernel::calc_shard_offset( int* idx, int* left, int* right, long long len, int total_devs, @@ -270,6 +350,23 @@ template void HeterCommKernel::reduce_by_key< paddle::framework::FeaturePushValue* d_aggregates_out, int* d_num_runs_out, int num_items, cudaStream_t stream, bool debug_synchronous); +template void HeterCommKernel::dy_mf_fill_shard_grads< + unsigned long, paddle::framework::FeaturePushValue, int, cudaStream_t>( + unsigned long* d_shard_keys, unsigned long* d_keys, + paddle::framework::FeaturePushValue* d_shard_grads, + paddle::framework::FeaturePushValue* d_grads, int* idx, long long len, + size_t grad_value_size, const cudaStream_t& stream); + +template void HeterCommKernel::merge_gradient( + const uint32_t* offset, const uint32_t* fea_num, const uint32_t* index, + const char* input, char* output, int n, size_t grad_value_size, + DynamicGradMerger& merger_, const cudaStream_t& stream); + +template void HeterCommKernel::dy_mf_fill_dvals( + paddle::framework::FeatureValue* d_shard_vals, + paddle::framework::FeatureValue* d_vals, int* idx, long long len, + size_t val_size,
const cudaStream_t& stream); #endif } // namespace framework diff --git a/paddle/fluid/framework/fleet/heter_ps/heter_comm_kernel.h b/paddle/fluid/framework/fleet/heter_ps/heter_comm_kernel.h index 9d2ee5d272c722c6668c09495f305263b31eb62e..4f866ccda820179769636347ff15d1b22f4d6648 100644 --- a/paddle/fluid/framework/fleet/heter_ps/heter_comm_kernel.h +++ b/paddle/fluid/framework/fleet/heter_ps/heter_comm_kernel.h @@ -27,6 +27,42 @@ limitations under the License. */ namespace paddle { namespace framework { +struct DynamicGradMerger { + template + CUB_RUNTIME_FUNCTION __forceinline__ __device__ T + operator()(const T& a, const T& b) const { + T out; + out.slot = a.slot; + out.mf_dim = a.mf_dim; + out.show = a.show + b.show; + out.clk = a.clk + b.clk; + out.lr_g = a.lr_g + b.lr_g; + + return out; + } + + template + __device__ __forceinline__ void update_one(T& output, const T& input) { + output.slot = input.slot; + output.show = input.show; + output.clk = input.clk; + output.mf_dim = input.mf_dim; + output.lr_g = input.lr_g; + for (int i = 0; i < output.mf_dim; ++i) { + output.mf_g[i] = input.mf_g[i]; + } + } + template + __device__ __forceinline__ void merge_one(T& output, const T& input) { + output.show += input.show; + output.clk += input.clk; + output.lr_g += input.lr_g; + for (int i = 0; i < input.mf_dim; ++i) { + output.mf_g[i] += input.mf_g[i]; + } + } +}; + class HeterCommKernel { public: HeterCommKernel() {} @@ -80,6 +116,24 @@ class HeterCommKernel { StreamType stream = NULL, bool debug_synchronous = false); + template + void dy_mf_fill_shard_grads(KeyType* d_shard_keys, KeyType* d_keys, + GradType* d_shard_grads, GradType* d_grads, + T* idx, long long len, size_t grad_value_size, + const StreamType& stream); + + template + void merge_gradient(const uint32_t* offset, const uint32_t* fea_num, + const uint32_t* index, const char* input, char* output, + int n, size_t grad_value_size, DynamicGradMerger& merger_, + const StreamType& stream); + + template + void dy_mf_fill_dvals(ValType* d_shard_vals, ValType* d_vals, T* idx, + long long len, size_t val_size, + const StreamType& stream); + private: int block_size_{256}; }; diff --git a/paddle/fluid/framework/fleet/heter_ps/heter_ps.cu b/paddle/fluid/framework/fleet/heter_ps/heter_ps.cu index 66e06b13b046f401ddab0de822e94465924c903c..43b84ee5d26fbe98487dca95b328f7cc395a46ff 100644 --- a/paddle/fluid/framework/fleet/heter_ps/heter_ps.cu +++ b/paddle/fluid/framework/fleet/heter_ps/heter_ps.cu @@ -44,6 +44,13 @@ void HeterPs::build_ps(int num, FeatureKey* h_keys, FeatureValue* h_vals, comm_->build_ps(num, h_keys, h_vals, len, chunk_size, stream_num); } +void HeterPs::build_ps(int num, FeatureKey* h_keys, char* pool, size_t len, + size_t feature_value_size, size_t chunk_size, + int stream_num) { + comm_->build_ps(num, h_keys, pool, len, feature_value_size, chunk_size, + stream_num); +} + int HeterPs::get_index_by_devid(int devid) { return comm_->get_index_by_devid(devid); } @@ -72,6 +79,10 @@ void HeterPs::set_nccl_comm_and_size(const std::vector& inner_comms, comm_->set_nccl_comm_and_size(inner_comms, inter_comms, comm_size); } +void HeterPs::set_multi_mf_dim(int multi_mf_dim, int max_mf_dim) { + comm_->set_multi_mf_dim(multi_mf_dim, max_mf_dim); +} + } // end namespace framework } // end namespace paddle #endif diff --git a/paddle/fluid/framework/fleet/heter_ps/heter_ps.h b/paddle/fluid/framework/fleet/heter_ps/heter_ps.h index 70b88350f2720a85851db0a92a3f99f88cf3afd4..8449a4048b72f9493feffdc29969eaf87f572938 100644 --- 
a/paddle/fluid/framework/fleet/heter_ps/heter_ps.h +++ b/paddle/fluid/framework/fleet/heter_ps/heter_ps.h @@ -37,11 +37,14 @@ class HeterPs : public HeterPsBase { size_t len) override; void build_ps(int num, FeatureKey* h_keys, FeatureValue* h_vals, size_t len, size_t chunk_size, int stream_num) override; - + void build_ps(int num, FeatureKey* h_keys, char* pool, size_t len, + size_t feature_value_size, size_t chunk_size, + int stream_num) override; #if defined(PADDLE_WITH_CUDA) void set_nccl_comm_and_size(const std::vector& inner_comms, const std::vector& inter_comms, int comm_size) override; + void set_multi_mf_dim(int multi_mf_dim, int max_mf_dim) override; #endif void set_sparse_sgd(const OptimizerConfig& optimizer_config) override; diff --git a/paddle/fluid/framework/fleet/heter_ps/heter_ps_base.h b/paddle/fluid/framework/fleet/heter_ps/heter_ps_base.h index 0727e2c2dbce1cb9c359ab02fdfa732aba2ede78..2c312e9d4d60aa7494573138c89848dd0b765474 100644 --- a/paddle/fluid/framework/fleet/heter_ps/heter_ps_base.h +++ b/paddle/fluid/framework/fleet/heter_ps/heter_ps_base.h @@ -35,11 +35,15 @@ class HeterPsBase { size_t len) = 0; virtual void build_ps(int num, FeatureKey* h_keys, FeatureValue* h_vals, size_t len, size_t chunk_size, int stream_num) = 0; + virtual void build_ps(int num, FeatureKey* h_keys, char* pool, size_t len, + size_t feature_value_size, size_t chunk_size, + int stream_num) = 0; virtual int get_index_by_devid(int devid) = 0; #if defined(PADDLE_WITH_CUDA) virtual void set_nccl_comm_and_size( const std::vector& inner_comms, const std::vector& inter_comms, int comm_size) = 0; + virtual void set_multi_mf_dim(int multi_mf_dim, int max_mf_dim) = 0; #endif virtual void end_pass() = 0; virtual void show_one_table(int gpu_num) = 0; diff --git a/paddle/fluid/framework/fleet/heter_ps/heter_resource.h b/paddle/fluid/framework/fleet/heter_ps/heter_resource.h index 17bc12a5af1a7305f9202e02548b10bd0a9a9860..5717f44d400a55ae21cf2ef5293c522c986b657d 100644 --- a/paddle/fluid/framework/fleet/heter_ps/heter_resource.h +++ b/paddle/fluid/framework/fleet/heter_ps/heter_resource.h @@ -107,6 +107,8 @@ class HeterPsResource { int get_index_by_devid(int devid); int dev_id(int num); void set_multi_mf(int multi_mf_dim, int max_mf_dim); + int multi_mf() { return multi_mf_dim_; } + int max_mf_dim() { return max_mf_dim_; } ppStream local_stream(int dev_num, int stream_num); ppStream remote_stream(int dev_num, int stream_num); diff --git a/paddle/fluid/framework/fleet/heter_ps/optimizer.cuh.h b/paddle/fluid/framework/fleet/heter_ps/optimizer.cuh.h index 065d5e6d527fc0614905d9afaaa0d7f57037b713..4684b4a0bc155c76286f9731dab63cf7c6606b3d 100644 --- a/paddle/fluid/framework/fleet/heter_ps/optimizer.cuh.h +++ b/paddle/fluid/framework/fleet/heter_ps/optimizer.cuh.h @@ -125,20 +125,21 @@ class Optimizer { if (optimizer_config.mf_create_thresholds <= optimizer_config.nonclk_coeff * (ptr->show - ptr->clk) + optimizer_config.clk_coeff * ptr->clk) { - // ptr->mf_size = ptr->mf_dim + 1; + ptr->mf_size = ptr->mf_dim + 1; - ptr->mf_size = MF_DIM + 1; + // ptr->mf_size = MF_DIM + 1; ptr->mf[0] = 0; int tid_x = blockIdx.x * blockDim.x + threadIdx.x; curandState state; curand_init(clock64(), tid_x, 0, &state); - for (int i = 0; i < MF_DIM; ++i) { + for (int i = 0; i < ptr->mf_dim; ++i) { ptr->mf[i + 1] = (curand_uniform(&state)) * optimizer_config.mf_initial_range; } } } else { - update_mf(optimizer_config, MF_DIM, &(ptr->mf[1]), ptr->mf[0], grad.mf_g, + update_mf(optimizer_config, ptr->mf_dim, &(ptr->mf[1]), ptr->mf[0], + 
grad.mf_g, + grad.show); // for local test } } diff --git a/paddle/fluid/framework/fleet/ps_gpu_wrapper.cc b/paddle/fluid/framework/fleet/ps_gpu_wrapper.cc index f512fcc7b9fdbee8baad3d9241cc48305580c83b..6e4ddc2f020fe7560451528f910f46415369d369 100644 --- a/paddle/fluid/framework/fleet/ps_gpu_wrapper.cc +++ b/paddle/fluid/framework/fleet/ps_gpu_wrapper.cc @@ -31,7 +31,6 @@ limitations under the License. */ #include #include -#include "paddle/fluid/framework/data_set.h" #include "paddle/fluid/framework/fleet/ps_gpu_wrapper.h" #include "paddle/fluid/platform/timer.h" @@ -112,12 +111,8 @@ void PSGPUWrapper::PreBuildTask(std::shared_ptr gpu_task) { } else { gpu_task->init(thread_keys_shard_num_, device_num, multi_mf_dim_); } - auto& local_keys = gpu_task->feature_keys_; - auto& local_ptr = gpu_task->value_ptr_; std::vector threads; - - // data should be in input channel if (!multi_mf_dim_) { thread_keys_.resize(thread_keys_thread_num_); for (int i = 0; i < thread_keys_thread_num_; i++) { @@ -141,11 +136,9 @@ void PSGPUWrapper::PreBuildTask(std::shared_ptr gpu_task) { std::string data_set_name = std::string(typeid(*dataset_).name()); if (data_set_name.find("SlotRecordDataset") != std::string::npos) { - VLOG(0) << "ps_gpu_wrapper use SlotRecordDataset"; SlotRecordDataset* dataset = dynamic_cast(dataset_); auto input_channel = dataset->GetInputChannel(); - VLOG(0) << "yxf::buildtask::inputslotchannle size: " - << input_channel->Size(); + VLOG(0) << "psgpu wrapper input slot channel size: " << input_channel->Size(); const std::deque& vec_data = input_channel->GetData(); total_len = vec_data.size(); len_per_thread = total_len / thread_keys_thread_num_; @@ -176,21 +169,12 @@ void PSGPUWrapper::PreBuildTask(std::shared_ptr gpu_task) { j < slot_offset[slot_offset_vector_[slot_idx] + 1]; j++) { int shard_id = feasign_v[j] % thread_keys_shard_num_; int dim_id = slot_index_vec_[slot_idx]; - this->thread_dim_keys_[i][shard_id][dim_id].insert(feasign_v[j]); + if (feasign_v[j] != 0) { + this->thread_dim_keys_[i][shard_id][dim_id].insert(feasign_v[j]); + } } } } - /* - for (auto iter = total_data.begin() + begin_index; - iter != total_data.begin() + end_index; iter++) { - const auto& ins = *iter; - const auto& feasign_v = ins->slot_uint64_feasigns_.slot_values; - for (const auto feasign : feasign_v) { - int shard_id = feasign % thread_keys_shard_num_; - this->thread_dim_keys_[i][shard_id][0].insert(feasign); - } - } - */ }; for (int i = 0; i < thread_keys_thread_num_; i++) { if (!multi_mf_dim_) { @@ -264,12 +248,6 @@ void PSGPUWrapper::PreBuildTask(std::shared_ptr gpu_task) { thread_dim_keys_[i][shard_num][dim_id].clear(); } }; - // for (size_t i = 0; i < thread_keys_.size(); i++) { - // gpu_task->batch_add_keys(thread_keys_[i]); - // for (int j = 0; j < thread_keys_thread_num_; j++) { - // thread_keys_[i][j].clear(); - // } - //} for (int i = 0; i < thread_keys_shard_num_; ++i) { if (!multi_mf_dim_) { threads.push_back(std::thread(merge_ins_func, i)); } else { threads.push_back(std::thread(merge_ins_dynamic_mf_func, i)); } } for (auto& t : threads) { t.join(); } timeline.Pause(); VLOG(0) << "GpuPs task unique cost " << timeline.ElapsedSec() << " seconds."; - - if (!multi_mf_dim_) { - for (int i = 0; i < thread_keys_shard_num_; i++) { - VLOG(0) << "GpuPs shard: " << i << " key len: " << local_keys[i].size(); - local_ptr[i].resize(local_keys[i].size()); - } - } else { - for (int i = 0; i < thread_keys_shard_num_; i++) { - for (int j = 0; j < multi_mf_dim_; j++) { - VLOG(0) << "GpuPs shard: " << i << "mf dim: " << index_dim_vec_[j] - << " key
len: " << gpu_task->feature_dim_keys_[i][j].size(); - gpu_task->value_dim_ptr_[i][j].resize( - gpu_task->feature_dim_keys_[i][j].size()); + for (int i = 0; i < thread_keys_shard_num_; i++) { + for (int j = 0; j < multi_mf_dim_; j++) { + if (i == 0 && j == multi_mf_dim_ - 1) { + gpu_task->feature_dim_keys_[i][j].push_back(0); } + VLOG(0) << "GpuPs shard: " << i << "mf dim: " << index_dim_vec_[j] + << " key len: " << gpu_task->feature_dim_keys_[i][j].size(); + gpu_task->value_dim_ptr_[i][j].resize( + gpu_task->feature_dim_keys_[i][j].size()); } } } @@ -353,85 +326,6 @@ void PSGPUWrapper::BuildPull(std::shared_ptr gpu_task) { #endif timeline.Start(); - auto ptl_func = [this, &local_keys, &local_ptr, &fleet_ptr](int i) { - size_t key_size = local_keys[i].size(); - int32_t status = -1; -#ifdef PADDLE_WITH_PSLIB - // auto tt = fleet_ptr->pslib_ptr_->_worker_ptr->pull_sparse_ptr( - // reinterpret_cast(local_ptr[i].data()), this->table_id_, - // local_keys[i].data(), key_size); - int32_t cnt = 0; - while (true) { - auto tt = fleet_ptr->pslib_ptr_->_worker_ptr->pull_sparse_ptr( - i, reinterpret_cast(local_ptr[i].data()), this->table_id_, - local_keys[i].data(), key_size); - bool flag = true; - - tt.wait(); - - try { - status = tt.get(); - } catch (const std::future_error& e) { - VLOG(0) << "Caught a future_error with code" << e.code() - << ", Message:" << e.what(); - } - if (status != 0) { - VLOG(0) << "fleet pull sparse failed, status[" << status << "]"; - sleep(sleep_seconds_before_fail_exit_); - flag = false; - cnt++; - } - if (cnt > 3) { - VLOG(0) << "fleet pull sparse failed, retry 3 times"; - exit(-1); - } - - if (flag) { - break; - } - } -#endif -#ifdef PADDLE_WITH_PSCORE - int32_t cnt = 0; - while (true) { - auto tt = fleet_ptr->worker_ptr_->PullSparsePtr( - reinterpret_cast(local_ptr[i].data()), this->table_id_, - local_keys[i].data(), key_size); - bool flag = true; - - tt.wait(); - - try { - status = tt.get(); - } catch (const std::future_error& e) { - VLOG(0) << "Caught a future_error with code" << e.code() - << ", Message:" << e.what(); - } - if (status != 0) { - VLOG(0) << "fleet pull sparse failed, status[" << status << "]"; - sleep(sleep_seconds_before_fail_exit_); - flag = false; - cnt++; - } - if (cnt > 3) { - VLOG(0) << "fleet pull sparse failed, retry 3 times"; - exit(-1); - } - - if (flag) { - break; - } - } -#endif - if (status != 0) { - LOG(ERROR) << "fleet pull sparse failed, status[" << status << "]"; - sleep(300); - exit(-1); - } else { - VLOG(3) << "FleetWrapper Pull sparse to local done with table size: " - << local_keys[i].size(); - } - }; auto ptl_dynamic_mf_func = [this, &local_dim_keys, &local_dim_ptr, &fleet_ptr](int i, int j) { @@ -478,21 +372,18 @@ void PSGPUWrapper::BuildPull(std::shared_ptr gpu_task) { } #endif }; - if (!multi_mf_dim_) { - for (size_t i = 0; i < threads.size(); i++) { - threads[i] = std::thread(ptl_func, i); - } - } else { - threads.resize(thread_keys_shard_num_ * multi_mf_dim_); - for (int i = 0; i < thread_keys_shard_num_; i++) { - for (int j = 0; j < multi_mf_dim_; j++) { - threads[i * multi_mf_dim_ + j] = std::thread(ptl_dynamic_mf_func, i, j); - } + + threads.resize(thread_keys_shard_num_ * multi_mf_dim_); + for (int i = 0; i < thread_keys_shard_num_; i++) { + for (int j = 0; j < multi_mf_dim_; j++) { + task_futures.emplace_back( + pull_thread_pool_[i]->enqueue(ptl_dynamic_mf_func, i, j)); } } - for (std::thread& t : threads) { - t.join(); + for (auto& f : task_futures) { + f.wait(); } + task_futures.clear(); timeline.Pause(); VLOG(0) << 
"pull sparse from CpuPS into GpuPS cost " << timeline.ElapsedSec() << " seconds."; @@ -509,19 +400,12 @@ void PSGPUWrapper::BuildPull(std::shared_ptr gpu_task) { std::vector>> pass_values; bool record_status = false; -#ifdef PADDLE_WITH_PSLIB - uint16_t pass_id = 0; - if (multi_node_) { - record_status = fleet_ptr->pslib_ptr_->_worker_ptr->take_sparse_record( - table_id_, pass_id, pass_values); - } -#endif auto& device_task_keys = gpu_task->device_task_keys_; auto& device_task_ptrs = gpu_task->device_task_ptr_; - auto build_dynamic_mf_func = [this, device_num, &local_dim_keys, - &local_dim_ptr, &device_dim_keys, - &device_dim_ptr, - &device_dim_mutex](int i, int j) { + auto build_pull_dynamic_mf_func = [this, device_num, &local_dim_keys, + &local_dim_ptr, &device_dim_keys, + &device_dim_ptr, + &device_dim_mutex](int i, int j) { #ifdef PADDLE_WITH_PSLIB std::vector> task_keys(device_num); std::vector> task_ptrs( @@ -532,20 +416,16 @@ void PSGPUWrapper::BuildPull(std::shared_ptr gpu_task) { task_ptrs[shard].push_back(local_dim_ptr[i][j][k]); } for (int dev = 0; dev < device_num; dev++) { - for (int dim = 0; dim < multi_mf_dim_; dim++) { - device_dim_mutex[dev][dim]->lock(); - - int len = task_keys[dev].size(); - int cur = device_dim_keys[dev][dim].size(); - device_dim_keys[dev][dim].resize(device_dim_keys[dev][dim].size() + - len); - device_dim_ptr[dev][dim].resize(device_dim_ptr[dev][dim].size() + len); - for (int k = 0; k < len; ++k) { - device_dim_keys[dev][dim][cur + k] = task_keys[dev][k]; - device_dim_ptr[dev][dim][cur + k] = task_ptrs[dev][k]; - } - device_dim_mutex[dev][dim]->unlock(); + device_dim_mutex[dev][j]->lock(); + int len = task_keys[dev].size(); + int cur = device_dim_keys[dev][j].size(); + device_dim_keys[dev][j].resize(device_dim_keys[dev][j].size() + len); + device_dim_ptr[dev][j].resize(device_dim_ptr[dev][j].size() + len); + for (int k = 0; k < len; ++k) { + device_dim_keys[dev][j][cur + k] = task_keys[dev][k]; + device_dim_ptr[dev][j][cur + k] = task_ptrs[dev][k]; } + device_dim_mutex[dev][j]->unlock(); } #endif }; @@ -697,7 +577,7 @@ void PSGPUWrapper::BuildPull(std::shared_ptr gpu_task) { for (int i = 0; i < thread_keys_shard_num_; i++) { for (int j = 0; j < multi_mf_dim_; j++) { threads[i * multi_mf_dim_ + j] = - std::thread(build_dynamic_mf_func, i, j); + std::thread(build_pull_dynamic_mf_func, i, j); } } for (std::thread& t : threads) { @@ -727,21 +607,17 @@ void PSGPUWrapper::BuildGPUTask(std::shared_ptr gpu_task) { std::vector feature_keys_count(device_num); size_t size_max = 0; - if (!multi_mf_dim_) { - for (int i = 0; i < device_num; i++) { - feature_keys_count[i] = gpu_task->device_keys_[i].size(); - VLOG(0) << i << " card contains feasign nums: " << feature_keys_count[i]; - size_max = std::max(size_max, feature_keys_count[i]); - } - } else { - for (int i = 0; i < device_num; i++) { - for (int j = 0; j < multi_mf_dim_; j++) { - feature_keys_count[i] += gpu_task->device_dim_ptr_[i][j].size(); - } - VLOG(0) << i << " card with dynamic mf contains feasign nums: " - << feature_keys_count[i]; - size_max = std::max(size_max, feature_keys_count[i]); + + for (int i = 0; i < device_num; i++) { + for (int j = 0; j < multi_mf_dim_; j++) { + feature_keys_count[i] += gpu_task->device_dim_ptr_[i][j].size(); + VLOG(1) << i << " card with dynamic mf dim: " << index_dim_vec_[j] + << " dim index: " << j << " contains feasign nums: " + << gpu_task->device_dim_ptr_[i][j].size(); } + VLOG(1) << i << " card with dynamic mf contains feasign nums total: " + << 
feature_keys_count[i]; + size_max = std::max(size_max, feature_keys_count[i]); } if (HeterPs_) { delete HeterPs_; @@ -756,17 +632,73 @@ void PSGPUWrapper::BuildGPUTask(std::shared_ptr gpu_task) { #ifdef PADDLE_WITH_CUDA HeterPs_->set_nccl_comm_and_size(inner_comms_, inter_comms_, node_size_); #endif - auto build_func = [this, &gpu_task, &feature_keys_count](int i) { - VLOG(3) << "building table: " << i; - this->HeterPs_->build_ps(i, gpu_task->device_keys_[i].data(), - gpu_task->device_values_[i].data(), - feature_keys_count[i], 500000, 2); - // if (feature_keys_count[i] > 0) { - // HeterPs_->show_one_table(i); - // } + auto build_dynamic_mf_func = [this, &gpu_task](int i, int j) { + this->HeterPs_->set_multi_mf_dim(multi_mf_dim_, max_mf_dim_); + int mf_dim = this->index_dim_vec_[j]; + VLOG(0) << "building table: " << i << " with mf dim: " << mf_dim; + size_t feature_value_size = + TYPEALIGN(8, sizeof(FeatureValue) + ((mf_dim + 1) * sizeof(float))); + auto& device_dim_keys = gpu_task->device_dim_keys_[i][j]; + auto& device_dim_ptrs = gpu_task->device_dim_ptr_[i][j]; + size_t len = device_dim_keys.size(); + CHECK(len == device_dim_ptrs.size()); + this->mem_pools_[i * this->multi_mf_dim_ + j] = + new MemoryPool(len, feature_value_size); + auto& mem_pool = this->mem_pools_[i * this->multi_mf_dim_ + j]; + for (size_t k = 0; k < len; k++) { + FeatureValue* val = (FeatureValue*)(mem_pool->mem_address(k)); + float* ptr_val = device_dim_ptrs[k]->data(); + size_t dim = device_dim_ptrs[k]->size(); +#ifdef PADDLE_WITH_PSLIB + val->delta_score = + ptr_val[paddle::ps::DownpourCtrDymfAccessor:: + DownpourCtrDymfFeatureValue::delta_score_index()]; + val->show = ptr_val[paddle::ps::DownpourCtrDymfAccessor:: + DownpourCtrDymfFeatureValue::show_index()]; + val->clk = ptr_val[paddle::ps::DownpourCtrDymfAccessor:: + DownpourCtrDymfFeatureValue::click_index()]; + val->slot = int(ptr_val[paddle::ps::DownpourCtrDymfAccessor:: + DownpourCtrDymfFeatureValue::slot_index()]); + val->lr = ptr_val[paddle::ps::DownpourCtrDymfAccessor:: + DownpourCtrDymfFeatureValue::embed_w_index()]; + val->lr_g2sum = + ptr_val[paddle::ps::DownpourCtrDymfAccessor:: + DownpourCtrDymfFeatureValue::embed_g2sum_index()]; + val->cpu_ptr = (uint64_t)(device_dim_ptrs[k]); + ptr_val[paddle::ps::DownpourCtrDymfAccessor::DownpourCtrDymfFeatureValue:: + mf_dim_index()] = float(mf_dim); + val->mf_dim = mf_dim; +#endif + if (dim > 8) { // CpuPS already expanded to mf_dim + val->mf_size = mf_dim + 1; + for (int x = 0; x < val->mf_dim + 1; x++) { + val->mf[x] = ptr_val[x + 8]; + } + } else { + val->mf_size = 0; + for (int x = 0; x < val->mf_dim + 1; x++) { + val->mf[x] = 0; + } + } + } + platform::CUDADeviceGuard guard(resource_->dev_id(i)); + this->hbm_pools_[i * this->multi_mf_dim_ + j] = new HBMMemoryPool(mem_pool); + auto& cur_pool = this->hbm_pools_[i * this->multi_mf_dim_ + j]; + this->HeterPs_->build_ps(i, device_dim_keys.data(), cur_pool->mem(), len, + feature_value_size, 500000, 2); + if (device_dim_keys.size() > 0) { + VLOG(0) << "show ptr table: " << i + << " table kv size: " << device_dim_keys.size() + << " dim: " << mf_dim << " len: " << len; + this->HeterPs_->show_one_table(i); + } + delete mem_pool; }; - for (size_t i = 0; i < threads.size(); i++) { - threads[i] = std::thread(build_func, i); + threads.resize(device_num * multi_mf_dim_); + for (int i = 0; i < device_num; i++) { + for (int j = 0; j < multi_mf_dim_; j++) { + threads[i + j * device_num] = std::thread(build_dynamic_mf_func, i, j); + } } for (std::thread& t : threads) { t.join();
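For orientation, build_dynamic_mf_func above proceeds per (device, mf_dim) pair: it packs every pulled CPU value into a host-side MemoryPool at a fixed feature_value_size stride, mirrors the pool into device memory via HBMMemoryPool, then hands the raw buffer to build_ps. Below is a hedged mock of the pool contract those calls rely on; only mem_address/mem/capacity come from the calls visible in this hunk, and the implementation itself is an assumption, not the PR's MemoryPool.

```cpp
#include <cstddef>
#include <cstdlib>

// Hypothetical stand-in for the MemoryPool used by build_dynamic_mf_func:
// a flat CPU buffer holding `capacity` values, each `byte_size` bytes.
class MemoryPoolSketch {
 public:
  MemoryPoolSketch(size_t capacity, size_t byte_size)
      : capacity_(capacity),
        byte_size_(byte_size),
        mem_(static_cast<char*>(std::malloc(capacity * byte_size))) {}
  ~MemoryPoolSketch() { std::free(mem_); }

  // The build function casts this to FeatureValue* and fills slot k in place.
  void* mem_address(size_t k) { return mem_ + k * byte_size_; }
  char* mem() { return mem_; }  // contiguous bytes, ready to copy to HBM
  size_t capacity() const { return capacity_; }

 private:
  size_t capacity_;
  size_t byte_size_;
  char* mem_;
};
// An HBMMemoryPool would presumably cudaMalloc(capacity * byte_size) and
// cudaMemcpy this whole region to the device, which is the pointer that
// cur_pool->mem() passes into HeterPs_->build_ps.
```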
@@ -788,7 +720,6 @@ void PSGPUWrapper::LoadIntoMemory(bool is_shuffle) { if (is_shuffle) { dataset_->LocalShuffle(); } - std::shared_ptr gpu_task = gpu_task_pool_.Get(); gpu_task->Reset(); data_ready_channel_->Put(gpu_task); @@ -874,17 +805,86 @@ void PSGPUWrapper::EndPass() { size_t keysize_max = 0; // in case of feasign_num = 0, skip dump_to_cpu for (size_t i = 0; i < heter_devices_.size(); i++) { - keysize_max = std::max(keysize_max, current_task_->device_keys_[i].size()); + for (int j = 0; j < multi_mf_dim_; j++) { + keysize_max = + std::max(keysize_max, current_task_->device_dim_keys_[i][j].size()); + } + } + + auto dump_pool_to_cpu_func = [this](int i, int j) { + PADDLE_ENFORCE_GPU_SUCCESS(cudaSetDevice(this->resource_->dev_id(i))); + auto& hbm_pool = this->hbm_pools_[i * this->multi_mf_dim_ + j]; + auto& device_keys = this->current_task_->device_dim_keys_[i][j]; + size_t len = device_keys.size(); + int mf_dim = this->index_dim_vec_[j]; + VLOG(0) << "dump pool to cpu table: " << i << " with mf dim: " << mf_dim; + size_t feature_value_size = + TYPEALIGN(8, sizeof(FeatureValue) + ((mf_dim + 1) * sizeof(float))); + char* test_build_values = (char*)malloc(feature_value_size * len); + cudaMemcpy(test_build_values, hbm_pool->mem(), feature_value_size * len, + cudaMemcpyDeviceToHost); + CHECK(len == hbm_pool->capacity()); +#ifdef PADDLE_WITH_PSLIB + uint64_t unuse_key = std::numeric_limits::max(); + for (size_t i = 0; i < len; ++i) { + if (device_keys[i] == unuse_key) { + continue; + } + size_t offset = i * feature_value_size; + FeatureValue* gpu_val = (FeatureValue*)(test_build_values + offset); + auto* downpour_value = + (paddle::ps::DownpourFixedFeatureValue*)(gpu_val->cpu_ptr); + int downpour_value_size = downpour_value->size(); + if (gpu_val->mf_size > 0 && downpour_value_size == 8) { + downpour_value->resize(gpu_val->mf_dim + 1 + downpour_value_size); + } + float* cpu_val = downpour_value->data(); + cpu_val[paddle::ps::DownpourCtrDymfAccessor::DownpourCtrDymfFeatureValue:: + delta_score_index()] = gpu_val->delta_score; + cpu_val[paddle::ps::DownpourCtrDymfAccessor::DownpourCtrDymfFeatureValue:: + show_index()] = gpu_val->show; + cpu_val[paddle::ps::DownpourCtrDymfAccessor::DownpourCtrDymfFeatureValue:: + click_index()] = gpu_val->clk; + cpu_val[paddle::ps::DownpourCtrDymfAccessor::DownpourCtrDymfFeatureValue:: + embed_w_index()] = gpu_val->lr; + cpu_val[paddle::ps::DownpourCtrDymfAccessor::DownpourCtrDymfFeatureValue:: + embed_g2sum_index()] = gpu_val->lr_g2sum; + cpu_val[paddle::ps::DownpourCtrDymfAccessor::DownpourCtrDymfFeatureValue:: + slot_index()] = gpu_val->slot; + if (gpu_val->mf_size > 0) { + for (int x = 0; x < gpu_val->mf_dim + 1; x++) { + cpu_val[x + 8] = gpu_val->mf[x]; + } + } + } +#endif + free(test_build_values); + }; + if (multi_mf_dim_) { + VLOG(0) << "psgpu wrapper dump pool: multi_mf_dim_: " << multi_mf_dim_; + size_t device_num = heter_devices_.size(); + std::vector threads(device_num * multi_mf_dim_); + for (size_t i = 0; i < device_num; i++) { + for (int j = 0; j < multi_mf_dim_; j++) { + threads[i + j * device_num] = std::thread(dump_pool_to_cpu_func, i, j); + } + } + for (std::thread& t : threads) { + t.join(); + } } if (keysize_max != 0) { HeterPs_->end_pass(); } + for (size_t i = 0; i < hbm_pools_.size(); i++) { + delete hbm_pools_[i]; + } gpu_task_pool_.Push(current_task_); current_task_ = nullptr; gpu_free_channel_->Put(current_task_); timer.Pause(); - VLOG(0) << "EndPass end, cost time: " << timer.ElapsedSec() << "s"; + VLOG(1) << "EndPass end, cost time:
" << timer.ElapsedSec() << "s"; } void PSGPUWrapper::PullSparse(const paddle::platform::Place& place, @@ -936,8 +936,6 @@ void PSGPUWrapper::PullSparse(const paddle::platform::Place& place, pull_gpups_timer.Start(); HeterPs_->pull_sparse(devid_2_index, total_keys, total_values_gpu, static_cast(total_length)); - // PADDLE_ENFORCE_EQ(ret, 0, platform::errors::PreconditionNotMet( - // "PullSparseGPU failed in GPUPS.")); pull_gpups_timer.Pause(); VLOG(3) << "Begin Copy result to tensor, total_length[" << total_length @@ -945,6 +943,98 @@ void PSGPUWrapper::PullSparse(const paddle::platform::Place& place, this->CopyForPull(place, gpu_keys, values, total_values_gpu, gpu_len, static_cast(slot_lengths.size()), hidden_size, total_length); + } else { + PADDLE_THROW(platform::errors::PreconditionNotMet( + "GpuPs: PullSparse Only Support CUDAPlace Now.")); + } + all_timer.Pause(); + VLOG(3) << "GpuPs PullSparse total costs: " << all_timer.ElapsedSec() + << " s, of which GPUPS costs: " << pull_gpups_timer.ElapsedSec() + << " s"; + VLOG(3) << "End PullSparse"; +} + +void PSGPUWrapper::PullSparse(const paddle::platform::Place& place, + const int table_id, + const std::vector& keys, + const std::vector& values, + const std::vector& slot_lengths, + const std::vector& slot_dim, + const int hidden_size) { + VLOG(3) << "Begine Gpu Ps PullSparse"; + platform::Timer all_timer; + platform::Timer pull_gpups_timer; + all_timer.Start(); + size_t total_length = + std::accumulate(slot_lengths.begin(), slot_lengths.end(), 0UL); + size_t feature_value_size = 0; + + feature_value_size = TYPEALIGN( + 8, sizeof(FeatureValue) + sizeof(float) * (index_dim_vec_.back() + 1)); + VLOG(0) << "yxf pull sparse feature_value_size: " << feature_value_size; + +#ifdef PADDLE_WITH_CUDA + VLOG(3) << "Begine Gpu Ps PullSparse"; + auto buf = memory::Alloc(place, total_length * feature_value_size); + FeatureValue* total_values_gpu = reinterpret_cast(buf->ptr()); +#endif +#ifdef PADDLE_WITH_XPU_KP + VLOG(3) << "Begine Xpu Ps PullSparse"; + FeatureValue* total_values_gpu = nullptr; + xpu_malloc(reinterpret_cast(&total_values_gpu), + total_length * feature_value_size); +#endif + if (platform::is_cpu_place(place)) { + PADDLE_THROW(platform::errors::Unimplemented( + "Warning:: CPUPlace is not supported in GpuPs now.")); + } else if (platform::is_gpu_place(place)) { + VLOG(3) << "Begin copy keys, key_num[" << total_length << "]"; + int device_id = place.GetDeviceId(); + int devid_2_index = HeterPs_->get_index_by_devid(device_id); + LoDTensor& total_keys_tensor = keys_tensor[devid_2_index]; + uint64_t* total_keys = + reinterpret_cast(total_keys_tensor.mutable_data( + {int64_t(total_length), 1}, place)); + + // construct slot_level lod info + auto slot_lengths_lod = slot_lengths; + for (size_t i = 1; i < slot_lengths_lod.size(); i++) { + slot_lengths_lod[i] += slot_lengths_lod[i - 1]; + } + auto buf_key = memory::Alloc(place, keys.size() * sizeof(uint64_t*)); + auto buf_length = + memory::Alloc(place, slot_lengths.size() * sizeof(int64_t)); + uint64_t** gpu_keys = reinterpret_cast(buf_key->ptr()); + int64_t* gpu_len = reinterpret_cast(buf_length->ptr()); + cudaMemcpy(gpu_keys, keys.data(), keys.size() * sizeof(uint64_t*), + cudaMemcpyHostToDevice); + cudaMemcpy(gpu_len, slot_lengths_lod.data(), + slot_lengths.size() * sizeof(int64_t), cudaMemcpyHostToDevice); + + auto buf_dim = memory::Alloc(place, slot_dim.size() * sizeof(int)); + int* gpu_dim = reinterpret_cast(buf_dim->ptr()); + cudaMemcpy(gpu_dim, slot_dim.data(), slot_dim.size() * 
@@ -1013,7 +1103,10 @@ void PSGPUWrapper::PushSparseGrad(const paddle::platform::Place& place,
       std::accumulate(slot_lengths.begin(), slot_lengths.end(), 0UL);
   // #ifdef PADDLE_WITH_CUDA
   VLOG(3) << "Begin GPUPS PushSparseGrad";
-  auto buf = memory::Alloc(place, total_length * sizeof(FeaturePushValue));
+  size_t grad_value_size =
+      TYPEALIGN(8, sizeof(FeaturePushValue) + (max_mf_dim_ * sizeof(float)));
+  auto buf = memory::Alloc(place, total_length * grad_value_size);
+  VLOG(3) << "Push Sparse Max mf dimension: " << max_mf_dim_;
   FeaturePushValue* total_grad_values_gpu =
       reinterpret_cast<FeaturePushValue*>(buf->ptr());
   if (platform::is_cpu_place(place)) {
@@ -1027,8 +1120,13 @@ void PSGPUWrapper::PushSparseGrad(const paddle::platform::Place& place,
     uint64_t* total_keys =
         reinterpret_cast<uint64_t*>(cached_total_keys_tensor.data<int64_t>());
     VLOG(3) << "Begin copy grad tensor to gpups struct";
-    this->CopyForPush(place, grad_values, total_grad_values_gpu, slot_lengths,
-                      hidden_size, total_length, batch_size);
+    if (!multi_mf_dim_) {
+      this->CopyForPush(place, grad_values, total_grad_values_gpu, slot_lengths,
+                        hidden_size, total_length, batch_size);
+    } else {
+      this->CopyForPush(place, grad_values, total_grad_values_gpu, slot_lengths,
+                        total_length, batch_size, grad_value_size);
+    }
 
     VLOG(3) << "Begin call PushSparseGPU in GPUPS, dev: " << devid_2_index
             << " len: " << total_length;
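Gradients are staged at a fixed stride, so the buffer is sized for the largest embedding table: grad_value_size is the 8-byte-aligned size of a FeaturePushValue header plus max_mf_dim_ trailing floats. A sketch of that arithmetic, assuming a hypothetical 20-byte header (the stand-in below is not the real struct):

#include <cstdint>
#include <cstdio>
#include <initializer_list>

#define TYPEALIGN(a, len) (((uint64_t)(len) + ((a)-1)) & ~((uint64_t)((a)-1)))

struct PushRec {   // stand-in header; the real struct ends in float mf_g[0]
  float show, clk;
  int slot;
  float lr_g;
  int mf_dim;
};

int main() {
  for (int max_mf_dim : {8, 64}) {
    uint64_t stride =
        TYPEALIGN(8, sizeof(PushRec) + max_mf_dim * sizeof(float));
    printf("max_mf_dim=%d -> grad_value_size=%llu bytes\n", max_mf_dim,
           (unsigned long long)stride);
  }
  return 0;
}

The fixed stride wastes tail bytes for slots with a smaller mf_dim, but keeps record addressing a single multiply in the kernels.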
diff --git a/paddle/fluid/framework/fleet/ps_gpu_wrapper.cu b/paddle/fluid/framework/fleet/ps_gpu_wrapper.cu
index 3df5a4b473861e249521358243f93bea3b93a3c1..488a9ef8ce78ffe969b94dd3c283b927b2ec9a45 100644
--- a/paddle/fluid/framework/fleet/ps_gpu_wrapper.cu
+++ b/paddle/fluid/framework/fleet/ps_gpu_wrapper.cu
@@ -61,6 +61,45 @@ __global__ void PullCopy(float** dest, const FeatureValue* src,
   }
 }
 
+__global__ void PullCopy(float** dest, const FeatureValue* src,
+                         const int64_t* len, int slot_num, int total_len,
+                         uint64_t** keys, uint64_t max_val_size, int* gpu_dim) {
+  CUDA_KERNEL_LOOP(i, total_len) {
+    int low = 0;
+    int high = slot_num - 1;
+    while (low < high) {
+      int mid = (low + high) / 2;
+      if (i < len[mid])
+        high = mid;
+      else
+        low = mid + 1;
+    }
+    int x = low;
+    int y = i - (x ? len[x - 1] : 0);
+    FeatureValue* feature_value_ptr =
+        (FeatureValue*)((char*)src + uint64_t(i) * uint64_t(max_val_size));
+    int mf_dim = gpu_dim[x] - 3;
+    if (*(keys[x] + y) == 0) {
+      *(dest[x] + y * (mf_dim + 3)) = 0;
+      *(dest[x] + y * (mf_dim + 3) + 1) = 0;
+      *(dest[x] + y * (mf_dim + 3) + 2) = 0;
+    } else {
+      *(dest[x] + y * (mf_dim + 3)) = feature_value_ptr->show;
+      *(dest[x] + y * (mf_dim + 3) + 1) = feature_value_ptr->clk;
+      *(dest[x] + y * (mf_dim + 3) + 2) = feature_value_ptr->lr;
+    }
+    if ((feature_value_ptr)->mf_size == 0 || *(keys[x] + y) == 0) {
+      for (int j = 0; j < mf_dim; j++) {
+        *(dest[x] + y * (mf_dim + 3) + 3 + j) = 0;
+      }
+    } else {
+      for (int j = 0; j < mf_dim; j++) {
+        *(dest[x] + y * (mf_dim + 3) + 3 + j) = feature_value_ptr->mf[1 + j];
+      }
+    }
+  }
+}
+
 __global__ void CopyKeysKernel(uint64_t** src_keys, uint64_t* dest_total_keys,
                                const int64_t* len, int slot_num,
                                int total_len) {
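The kernel maps a flat index i to its owning slot with a lower-bound binary search over the prefix-summed lengths, then y = i - len[x-1] is the row within slot x. Each output row is mf_dim + 3 floats (show, clk, lr, then the embedding), and the embedding copy starts at mf[1], since mf[0] of the pool record appears to hold non-embedding optimizer state rather than a weight. The decode, mirrored on the host with the prefix sums from the earlier sketch:

#include <cstdint>
#include <cstdio>

int main() {
  const int slot_num = 4;
  const int64_t len[slot_num] = {3, 3, 8, 10};  // inclusive prefix sums
  for (int i = 0; i < 10; ++i) {                // total_len == len[slot_num - 1]
    int low = 0, high = slot_num - 1;
    while (low < high) {                        // lower_bound over len[]
      int mid = (low + high) / 2;
      if (i < len[mid])
        high = mid;
      else
        low = mid + 1;
    }
    int x = low;                                // owning slot
    int y = i - (x ? (int)len[x - 1] : 0);      // row within the slot
    printf("flat %d -> slot %d, row %d\n", i, x, y);
  }
  return 0;
}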
@@ -105,6 +144,35 @@ __global__ void PushCopy(FeaturePushValue* dest, float** src, int64_t* len,
   }
 }
 
+__global__ void PushCopyWithPool(FeaturePushValue* dest, float** src,
+                                 int64_t* len, int slot_num, uint64_t total_len,
+                                 int bs, int* slot_vector, int* mf_dim_vector,
+                                 size_t grad_value_size) {
+  CUDA_KERNEL_LOOP(i, total_len) {
+    int low = 0;
+    int high = slot_num - 1;
+    while (low < high) {
+      int mid = (low + high) / 2;
+      if (i < len[mid])
+        high = mid;
+      else
+        low = mid + 1;
+    }
+    int x = low;
+    int y = i - (x ? len[low - 1] : 0);
+    FeaturePushValue* cur =
+        (FeaturePushValue*)((char*)dest + i * grad_value_size);
+    cur->slot = slot_vector[x];
+    int mf_dim = mf_dim_vector[x];
+    cur->mf_dim = mf_dim;
+    cur->show = *(src[x] + y * (mf_dim + 3));
+    cur->clk = *(src[x] + y * (mf_dim + 3) + 1);
+    cur->lr_g = *(src[x] + y * (mf_dim + 3) + 2) * -1. * bs;
+    for (int j = 0; j < cur->mf_dim; j++) {
+      cur->mf_g[j] = *(src[x] + y * (mf_dim + 3) + 3 + j) * -1. * bs;
+    }
+  }
+}
+
 PSGPUWrapper::~PSGPUWrapper() { delete HeterPs_; }
 
 void PSGPUWrapper::CopyForPull(const paddle::platform::Place& place,
@@ -128,6 +196,26 @@ void PSGPUWrapper::CopyForPull(const paddle::platform::Place& place,
   cudaStreamSynchronize(stream);
 }
 
+void PSGPUWrapper::CopyForPull(const paddle::platform::Place& place,
+                               uint64_t** gpu_keys,
+                               const std::vector<float*>& values,
+                               const FeatureValue* total_values_gpu,
+                               const int64_t* gpu_len, const int slot_num,
+                               const int hidden_size,
+                               const int64_t total_length, int* gpu_dim) {
+  auto stream = dynamic_cast<platform::CUDADeviceContext*>(
+                    platform::DeviceContextPool::Instance().Get(place))
+                    ->stream();
+  auto buf_value = memory::Alloc(place, values.size() * sizeof(float*));
+  float** gpu_values = reinterpret_cast<float**>(buf_value->ptr());
+  cudaMemcpy(gpu_values, values.data(), values.size() * sizeof(float*),
+             cudaMemcpyHostToDevice);
+  PullCopy<<<(total_length + 1024 - 1) / 1024, 1024, 0, stream>>>(
+      gpu_values, total_values_gpu, gpu_len, slot_num, total_length, gpu_keys,
+      val_type_size_, gpu_dim);
+  cudaStreamSynchronize(stream);
+}
+
 void PSGPUWrapper::CopyKeys(const paddle::platform::Place& place,
                             uint64_t** origin_keys, uint64_t* total_keys,
                             const int64_t* gpu_len, int slot_num,
@@ -177,6 +265,45 @@ void PSGPUWrapper::CopyForPush(const paddle::platform::Place& place,
   cudaStreamSynchronize(stream);
 }
 
+void PSGPUWrapper::CopyForPush(const paddle::platform::Place& place,
+                               const std::vector<const float*>& grad_values,
+                               FeaturePushValue* total_grad_values_gpu,
+                               const std::vector<int64_t>& slot_lengths,
+                               const uint64_t total_length,
+                               const int batch_size, size_t grad_value_size) {
+  auto stream = dynamic_cast<platform::CUDADeviceContext*>(
+                    platform::DeviceContextPool::Instance().Get(place))
+                    ->stream();
+  auto slot_lengths_lod = slot_lengths;
+  for (int i = 1; i < slot_lengths_lod.size(); i++) {
+    slot_lengths_lod[i] += slot_lengths_lod[i - 1];
+  }
+  auto buf_grad_value =
+      memory::Alloc(place, grad_values.size() * sizeof(float*));
+  auto buf_length = memory::Alloc(place, slot_lengths.size() * sizeof(int64_t));
+  auto buf_slot_vector =
+      memory::Alloc(place, slot_lengths_lod.size() * sizeof(int));
+  auto buf_mf_dim_vector =
+      memory::Alloc(place, slot_lengths_lod.size() * sizeof(int));
+  float** gpu_values = reinterpret_cast<float**>(buf_grad_value->ptr());
+  int64_t* gpu_len = reinterpret_cast<int64_t*>(buf_length->ptr());
+  int* d_slot_vector = reinterpret_cast<int*>(buf_slot_vector->ptr());
+  int* d_mf_dim_vector = reinterpret_cast<int*>(buf_mf_dim_vector->ptr());
+  cudaMemcpy(gpu_values, grad_values.data(),
+             grad_values.size() * sizeof(float*), cudaMemcpyHostToDevice);
+  cudaMemcpy(gpu_len, slot_lengths_lod.data(),
+             slot_lengths.size() * sizeof(int64_t), cudaMemcpyHostToDevice);
+  cudaMemcpy(d_slot_vector, slot_vector_.data(),
+             slot_lengths_lod.size() * sizeof(int), cudaMemcpyHostToDevice);
+  cudaMemcpy(d_mf_dim_vector, slot_mf_dim_vector_.data(),
+             slot_lengths_lod.size() * sizeof(int), cudaMemcpyHostToDevice);
+  PushCopyWithPool<<<(total_length + 1024 - 1) / 1024, 1024, 0, stream>>>(
+      total_grad_values_gpu, gpu_values, gpu_len, slot_lengths.size(),
+      total_length, batch_size, d_slot_vector, d_mf_dim_vector,
+      grad_value_size);
+  cudaStreamSynchronize(stream);
+}
+
 void PSGPUWrapper::SetSparseSGD(float nonclk_coeff, float clk_coeff,
                                 float min_bound, float max_bound,
                                 float learning_rate, float initial_g2sum,
diff --git a/paddle/fluid/framework/fleet/ps_gpu_wrapper.h b/paddle/fluid/framework/fleet/ps_gpu_wrapper.h
index c38b819822c28bd87909a64e7fb71451ae709862..824f6007198b031799d2e29f91ba9da5759773b5 100644
--- 
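PushCopyWithPool is the inverse staging step: record i is written at byte offset i * grad_value_size, tagged with its slot and mf_dim, and every gradient component is scaled by -1 * bs (presumably the sign and batch-scaling convention the GPU optimizer expects). A self-contained host mirror of that write pattern, again with stand-in types rather than the real declarations:

#include <cstdint>
#include <cstdio>
#include <cstdlib>

#define TYPEALIGN(a, len) (((uint64_t)(len) + ((a)-1)) & ~((uint64_t)((a)-1)))

struct PushRec {     // stand-in for FeaturePushValue
  float show, clk;
  int slot;
  float lr_g;
  int mf_dim;
  float mf_g[1];     // variable-length tail (float mf_g[0] in the real struct)
};

int main() {
  const int bs = 32, mf_dim = 8, slot = 5;
  const size_t total_len = 3;
  size_t grad_value_size =
      TYPEALIGN(8, sizeof(PushRec) + mf_dim * sizeof(float));
  char* dest = (char*)calloc(total_len, grad_value_size);
  float row[3 + 8] = {1.f, 0.f, 0.25f};  // dense slot row: show, clk, lr, mf...
  for (size_t i = 0; i < total_len; ++i) {
    PushRec* cur = (PushRec*)(dest + i * grad_value_size);  // strided record
    cur->slot = slot;
    cur->mf_dim = mf_dim;
    cur->show = row[0];
    cur->clk = row[1];
    cur->lr_g = row[2] * -1.f * bs;  // same sign/scale as the kernel
    for (int j = 0; j < mf_dim; ++j) cur->mf_g[j] = row[3 + j] * -1.f * bs;
  }
  printf("stride=%zu, lr_g=%f\n", grad_value_size, ((PushRec*)dest)->lr_g);
  free(dest);
  return 0;
}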
a/paddle/fluid/framework/fleet/ps_gpu_wrapper.h
+++ b/paddle/fluid/framework/fleet/ps_gpu_wrapper.h
@@ -27,6 +27,7 @@ limitations under the License. */
 #include <vector>
 #ifdef PADDLE_WITH_GLOO
 #include <gloo/broadcast.h>
+#include "paddle/fluid/framework/data_set.h"
 #include "paddle/fluid/framework/fleet/gloo_wrapper.h"
 #endif
 #include "paddle/fluid/distributed/ps/thirdparty/round_robin.h"
@@ -54,6 +55,9 @@ limitations under the License. */
 #ifdef PADDLE_WITH_PSLIB
 #include "afs_api.h"
 #endif
+#ifdef PADDLE_WITH_PSLIB
+#include "downpour_accessor.h"  // NOLINT
+#endif
 
 namespace paddle {
 namespace framework {
@@ -95,12 +99,21 @@ class PSGPUWrapper {
   PSGPUWrapper() {
     HeterPs_ = NULL;
     sleep_seconds_before_fail_exit_ = 300;
+    pull_thread_pool_.resize(thread_keys_shard_num_);
+    for (size_t i = 0; i < pull_thread_pool_.size(); i++) {
+      pull_thread_pool_[i].reset(new ::ThreadPool(1));
+    }
     hbm_thread_pool_.resize(thread_keys_shard_num_);
     for (size_t i = 0; i < hbm_thread_pool_.size(); i++) {
       hbm_thread_pool_[i].reset(new ::ThreadPool(1));
     }
   }
 
+  void PullSparse(const paddle::platform::Place& place, const int table_id,
+                  const std::vector<const uint64_t*>& keys,
+                  const std::vector<float*>& values,
+                  const std::vector<int64_t>& slot_lengths,
+                  const std::vector<int>& slot_dim, const int hidden_size);
   void PullSparse(const paddle::platform::Place& place, const int table_id,
                   const std::vector<const uint64_t*>& keys,
                   const std::vector<float*>& values,
@@ -119,13 +132,23 @@ class PSGPUWrapper {
                    const FeatureValue* total_values_gpu, const int64_t* gpu_len,
                    const int slot_num, const int hidden_size,
                    const int64_t total_length);
-
+  void CopyForPull(const paddle::platform::Place& place, uint64_t** gpu_keys,
+                   const std::vector<float*>& values,
+                   const FeatureValue* total_values_gpu, const int64_t* gpu_len,
+                   const int slot_num, const int hidden_size,
+                   const int64_t total_length, int* gpu_dim);
   void CopyForPush(const paddle::platform::Place& place,
                    const std::vector<const float*>& grad_values,
                    FeaturePushValue* total_grad_values_gpu,
                    const std::vector<int64_t>& slot_lengths,
                    const int hidden_size, const int64_t total_length,
                    const int batch_size);
+  void CopyForPush(const paddle::platform::Place& place,
+                   const std::vector<const float*>& grad_values,
+                   FeaturePushValue* total_grad_values_gpu,
+                   const std::vector<int64_t>& slot_lengths,
+                   const uint64_t total_length, const int batch_size,
+                   size_t grad_value_size);
 
   void BuildGPUTask(std::shared_ptr<HeterContext> gpu_task);
   void PreBuildTask(std::shared_ptr<HeterContext> gpu_task);
@@ -428,6 +451,7 @@ class PSGPUWrapper {
   std::shared_ptr<HeterContext> current_task_ = nullptr;
   std::thread pre_build_threads_;
   bool running_ = false;
+  std::vector<std::shared_ptr<::ThreadPool>> pull_thread_pool_;
   std::vector<std::shared_ptr<::ThreadPool>> hbm_thread_pool_;
 
  protected:
diff --git a/paddle/fluid/operators/pull_gpups_sparse_op.h b/paddle/fluid/operators/pull_gpups_sparse_op.h
index f721608cffb0826613013cc625209bae8844d486..abfdb62ec34ac3633c530f175da54fb17c9389c6 100644
--- a/paddle/fluid/operators/pull_gpups_sparse_op.h
+++ b/paddle/fluid/operators/pull_gpups_sparse_op.h
@@ -26,6 +26,7 @@ template <typename T>
 static void PullGpuPSSparseFunctor(const framework::ExecutionContext &ctx) {
   auto inputs = ctx.MultiInput<framework::Tensor>("Ids");
   auto outputs = ctx.MultiOutput<framework::Tensor>("Out");
+  auto embedding_size_vec = ctx.Attr<std::vector<int>>("size");
   const auto slot_size = inputs.size();
   std::vector<const uint64_t *> all_keys(slot_size);
   // GpuPS only supports float now
@@ -44,7 +45,7 @@ static void PullGpuPSSparseFunctor(const framework::ExecutionContext &ctx) {
 #ifdef PADDLE_WITH_HETERPS
   auto gpu_ps_ptr = paddle::framework::PSGPUWrapper::GetInstance();
   gpu_ps_ptr->PullSparse(ctx.GetPlace(), 0, all_keys, all_values, slot_lengths,
-                         0);
+                         embedding_size_vec, 0);
 #endif
 }
 
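The operator now forwards its per-slot "size" attribute to the new PullSparse overload as slot_dim, and nn.py below sizes the parameter from size[0] instead of the hard-coded 11. Judging from PullCopy's mf_dim = gpu_dim[x] - 3 (and the old constant 11 = 8 + 3), each slot_dim entry is the full per-key output width: three stat floats plus the embedding. A tiny illustration with hypothetical widths:

#include <cstdio>
#include <vector>

int main() {
  // hypothetical per-slot widths, e.g. from the op's "size" attribute
  std::vector<int> slot_dim = {11, 11, 67};
  for (size_t x = 0; x < slot_dim.size(); ++x)
    printf("slot %zu: width %d = 3 stats + mf_dim %d\n", x, slot_dim[x],
           slot_dim[x] - 3);
  return 0;
}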
diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py
index 799d93918f2efd6b67da1d34c8a94bb39697dcf5..97506ead5fad4db942208eacf680db34a2554fb6 100755
--- a/python/paddle/fluid/layers/nn.py
+++ b/python/paddle/fluid/layers/nn.py
@@ -737,7 +737,7 @@ def _pull_gpups_sparse(input,
         for i in range(len(inputs))
     ]
     w = helper.create_parameter(
-        attr=helper.param_attr, shape=[11], dtype=dtype, is_bias=False)
+        attr=helper.param_attr, shape=[size[0]], dtype=dtype, is_bias=False)
     helper.append_op(
         type='pull_gpups_sparse',
         inputs={'Ids': inputs,