From d9c174d135707307e312ce8cb7c121e09031568b Mon Sep 17 00:00:00 2001 From: yaoxuefeng Date: Wed, 29 Dec 2021 18:15:16 +0800 Subject: [PATCH] add hashtable dynamic mf support (#38493) add hashtable dynamic mf support --- .../framework/fleet/heter_ps/hashtable.h | 20 +++++ .../framework/fleet/heter_ps/hashtable_inl.h | 88 +++++++++++++++++++ .../framework/fleet/heter_ps/optimizer.cuh.h | 34 +++++++ 3 files changed, 142 insertions(+) diff --git a/paddle/fluid/framework/fleet/heter_ps/hashtable.h b/paddle/fluid/framework/fleet/heter_ps/hashtable.h index e7f098320c..509b43431b 100644 --- a/paddle/fluid/framework/fleet/heter_ps/hashtable.h +++ b/paddle/fluid/framework/fleet/heter_ps/hashtable.h @@ -27,6 +27,8 @@ limitations under the License. */ #include "thrust/pair.h" // #include "cudf/concurrent_unordered_map.cuh.h" #include "paddle/fluid/framework/fleet/heter_ps/cudf/concurrent_unordered_map.cuh.h" +#include "paddle/fluid/framework/fleet/heter_ps/feature_value.h" +#include "paddle/fluid/framework/fleet/heter_ps/mem_pool.h" #ifdef PADDLE_WITH_HETERPS #include "paddle/fluid/platform/device/gpu/gpu_types.h" @@ -53,8 +55,11 @@ class HashTable { HashTable& operator=(const HashTable&) = delete; void insert(const KeyType* d_keys, const ValType* d_vals, size_t len, gpuStream_t stream); + void insert(const KeyType* d_keys, size_t len, char* pool, size_t start_index, + gpuStream_t stream); void get(const KeyType* d_keys, ValType* d_vals, size_t len, gpuStream_t stream); + void get(const KeyType* d_keys, char* d_vals, size_t len, gpuStream_t stream); void show(); void dump_to_cpu(int devid, cudaStream_t stream); @@ -62,8 +67,20 @@ class HashTable { void update(const KeyType* d_keys, const GradType* d_grads, size_t len, Sgd sgd, gpuStream_t stream); + template + void update(const KeyType* d_keys, const char* d_grads, size_t len, Sgd sgd, + gpuStream_t stream); + int size() { return container_->size(); } + void set_feature_value_size(size_t pull_feature_value_size, + size_t push_grad_value_size) { + pull_feature_value_size_ = pull_feature_value_size; + push_grad_value_size_ = push_grad_value_size; + VLOG(3) << "hashtable set pull value size: " << pull_feature_value_size_ + << " push value size: " << push_grad_value_size_; + } + std::unique_ptr rwlock_{nullptr}; private: @@ -71,6 +88,9 @@ class HashTable { int BLOCK_SIZE_{256}; float LOAD_FACTOR{0.75f}; size_t capacity_; + size_t max_mf_dim_ = 8; + size_t pull_feature_value_size_; + size_t push_grad_value_size_; }; } // end namespace framework } // end namespace paddle diff --git a/paddle/fluid/framework/fleet/heter_ps/hashtable_inl.h b/paddle/fluid/framework/fleet/heter_ps/hashtable_inl.h index 9f3d1a7adc..dec7357468 100644 --- a/paddle/fluid/framework/fleet/heter_ps/hashtable_inl.h +++ b/paddle/fluid/framework/fleet/heter_ps/hashtable_inl.h @@ -42,6 +42,23 @@ __global__ void insert_kernel(Table* table, } } +template +__global__ void insert_kernel(Table* table, + const typename Table::key_type* const keys, + size_t len, char* pool, int start_index) { + ReplaceOp op; + thrust::pair kv; + + const size_t i = blockIdx.x * blockDim.x + threadIdx.x; + + if (i < len) { + kv.first = keys[i]; + kv.second = (Table::mapped_type)(pool + (start_index + i) * 80); + auto it = table->insert(kv, op); + assert(it != table->end() && "error: insert fails: table is full"); + } +} + template __global__ void search_kernel(Table* table, const typename Table::key_type* const keys, @@ -56,6 +73,20 @@ __global__ void search_kernel(Table* table, } } +template +__global__ void dy_mf_search_kernel(Table* table, + const typename Table::key_type* const keys, + char* const vals, size_t len, + size_t pull_feature_value_size) { + const size_t i = blockIdx.x * blockDim.x + threadIdx.x; + if (i < len) { + auto it = table->find(keys[i]); + + if (it != table->end()) { + *(FeatureValue*)(vals + i * pull_feature_value_size) = *(it->second); + } + } +} template __global__ void update_kernel(Table* table, const typename Table::key_type* const keys, @@ -70,6 +101,23 @@ __global__ void update_kernel(Table* table, } } +template +__global__ void dy_mf_update_kernel(Table* table, + const typename Table::key_type* const keys, + const char* const grads, size_t len, + Sgd sgd, size_t grad_value_size) { + const size_t i = blockIdx.x * blockDim.x + threadIdx.x; + if (i < len) { + auto it = table->find(keys[i]); + if (it != table->end()) { + FeaturePushValue* cur = (FeaturePushValue*)(grads + i * grad_value_size); + sgd.dy_mf_update_value((it.getter())->second, *cur); + } else { + printf("yxf::push miss key: %d", keys[i]); + } + } +} + template HashTable::HashTable(size_t capacity) { container_ = new TableContainer(capacity); @@ -97,6 +145,17 @@ void HashTable::get(const KeyType* d_keys, ValType* d_vals, d_vals, len); } +template +void HashTable::get(const KeyType* d_keys, char* d_vals, + size_t len, gpuStream_t stream) { + if (len == 0) { + return; + } + const int grid_size = (len - 1) / BLOCK_SIZE_ + 1; + dy_mf_search_kernel<<>>( + container_, d_keys, d_vals, len, pull_feature_value_size_); +} + template void HashTable::insert(const KeyType* d_keys, const ValType* d_vals, size_t len, @@ -109,6 +168,21 @@ void HashTable::insert(const KeyType* d_keys, d_vals, len); } +template +void HashTable::insert(const KeyType* d_keys, size_t len, + char* pool, size_t start_index, + gpuStream_t stream) { + if (len == 0) { + return; + } + const int grid_size = (len - 1) / BLOCK_SIZE_ + 1; + if (pool == NULL) { + return; + } + insert_kernel<<>>(container_, d_keys, len, + pool, start_index); +} + template void HashTable::dump_to_cpu(int devid, cudaStream_t stream) { container_->prefetch(cudaCpuDeviceId, stream); @@ -166,6 +240,20 @@ void HashTable::update(const KeyType* d_keys, d_grads, len, sgd); } +template +template +void HashTable::update(const KeyType* d_keys, + const char* d_grads, size_t len, + Sgd sgd, gpuStream_t stream) { + if (len == 0) { + return; + } + const int grid_size = (len - 1) / BLOCK_SIZE_ + 1; + + dy_mf_update_kernel<<>>( + container_, d_keys, d_grads, len, sgd, push_grad_value_size_); +} + } // end namespace framework } // end namespace paddle #endif diff --git a/paddle/fluid/framework/fleet/heter_ps/optimizer.cuh.h b/paddle/fluid/framework/fleet/heter_ps/optimizer.cuh.h index 374984ecdb..ff9976db5d 100644 --- a/paddle/fluid/framework/fleet/heter_ps/optimizer.cuh.h +++ b/paddle/fluid/framework/fleet/heter_ps/optimizer.cuh.h @@ -96,6 +96,40 @@ class Optimizer { update_mf(MF_DIM, &val.mf[1], val.mf[0], grad.mf_g, grad.show); } } + + __device__ void dy_mf_update_value(ValType* ptr, const GradType& grad) { + ptr->slot = grad.slot; + ptr->show += grad.show; + ptr->clk += grad.clk; + ptr->delta_score += + optimizer_config::nonclk_coeff * (grad.show - grad.clk) + + optimizer_config::clk_coeff * grad.clk; + + update_lr(ptr->lr, ptr->lr_g2sum, grad.lr_g, grad.show); + // use MF_DIM temporarily + // ptr->mf_dim = grad.mf_dim; + + if (ptr->mf_size == 0) { + if (optimizer_config::mf_create_thresholds <= + optimizer_config::nonclk_coeff * (ptr->show - ptr->clk) + + optimizer_config::clk_coeff * ptr->clk) { + // ptr->mf_size = ptr->mf_dim + 1; + + ptr->mf_size = MF_DIM + 1; + ptr->mf[0] = 0; + int tid_x = blockIdx.x * blockDim.x + threadIdx.x; + curandState state; + curand_init(clock64(), tid_x, 0, &state); + for (int i = 0; i < MF_DIM; ++i) { + ptr->mf[i + 1] = + (curand_uniform(&state)) * optimizer_config::mf_initial_range; + } + } + } else { + update_mf(MF_DIM, &(ptr->mf[1]), ptr->mf[0], grad.mf_g, + grad.show); // for local test + } + } }; } // end namespace framework -- GitLab