From 32f369a86a333970709afd47a030a4064746d603 Mon Sep 17 00:00:00 2001
From: pangengzheng <117730991+pangengzheng@users.noreply.github.com>
Date: Thu, 9 Mar 2023 16:49:55 +0800
Subject: [PATCH] Enable gpups run on rec model (#51115)

* support running the haokanctr model in heterps-models
* polish setup.py
* polish JVM_LIB in env_dict
---
 cmake/external/pslib.cmake                    |   1 +
 .../fleet/heter_ps/hashtable_kernel.cu        |   7 +
 .../framework/fleet/heter_ps/heter_ps.cu      |  25 +++-
 .../framework/fleet/heter_ps/optimizer.cuh.h  | 136 +++++++++++++++++-
 paddle/fluid/framework/fleet/ps_gpu_wrapper.h |  41 +++---
 python/env_dict.py.in                         |   1 +
 .../fleet/parameter_server/pslib/node.py      |   2 +-
 .../incubate/distributed/fleet/role_maker.py  |   2 +-
 python/setup.py.in                            |   2 +
 setup.py                                      |   2 +
 10 files changed, 186 insertions(+), 33 deletions(-)

diff --git a/cmake/external/pslib.cmake b/cmake/external/pslib.cmake
index 45d75165379..d7de1aae860 100644
--- a/cmake/external/pslib.cmake
+++ b/cmake/external/pslib.cmake
@@ -43,6 +43,7 @@ set(PSLIB_ROOT ${PSLIB_INSTALL_DIR})
 set(PSLIB_INC_DIR ${PSLIB_ROOT}/include)
 set(PSLIB_LIB_DIR ${PSLIB_ROOT}/lib)
 set(PSLIB_LIB ${PSLIB_LIB_DIR}/libps.so)
+set(JVM_LIB ${PSLIB_LIB_DIR}/libjvm.so)
 set(PSLIB_VERSION_PY ${PSLIB_DOWNLOAD_DIR}/pslib/version.py)
 set(PSLIB_IOMP_LIB ${PSLIB_LIB_DIR}/libiomp5.so) #todo what is this
 set(CMAKE_INSTALL_RPATH "${CMAKE_INSTALL_RPATH}" "${PSLIB_ROOT}/lib")
diff --git a/paddle/fluid/framework/fleet/heter_ps/hashtable_kernel.cu b/paddle/fluid/framework/fleet/heter_ps/hashtable_kernel.cu
index 752c1123944..c67ed130796 100644
--- a/paddle/fluid/framework/fleet/heter_ps/hashtable_kernel.cu
+++ b/paddle/fluid/framework/fleet/heter_ps/hashtable_kernel.cu
@@ -524,6 +524,13 @@ template void HashTable<uint64_t, float*>::update<
     size_t len,
     SparseAdagradOptimizer<CommonFeatureValueAccessor> sgd,
     cudaStream_t stream);
+template void HashTable<uint64_t, float*>::update<
+    StdAdagradOptimizer<CommonFeatureValueAccessor>,
+    cudaStream_t>(const uint64_t* d_keys,
+                  const char* d_grads,
+                  size_t len,
+                  StdAdagradOptimizer<CommonFeatureValueAccessor> sgd,
+                  cudaStream_t stream);
 template void HashTable<uint64_t, float*>::update<
     SparseAdamOptimizer<CommonFeatureValueAccessor>,
     cudaStream_t>(const uint64_t* d_keys,
diff --git a/paddle/fluid/framework/fleet/heter_ps/heter_ps.cu b/paddle/fluid/framework/fleet/heter_ps/heter_ps.cu
index 01e8a6212f9..3c794238207 100644
--- a/paddle/fluid/framework/fleet/heter_ps/heter_ps.cu
+++ b/paddle/fluid/framework/fleet/heter_ps/heter_ps.cu
@@ -27,12 +27,12 @@ HeterPsBase* HeterPsBase::get_instance(
     std::unordered_map<std::string, std::string> fleet_config,
     std::string accessor_type,
     int optimizer_type) {
+  auto* accessor_wrapper_ptr =
+      GlobalAccessorFactory::GetInstance().GetAccessorWrapper();
+  CommonFeatureValueAccessor* gpu_accessor =
+      ((AccessorWrapper<CommonFeatureValueAccessor>*)accessor_wrapper_ptr)
+          ->AccessorPtr();
   if (accessor_type == "CtrDymfAccessor") {
-    auto* accessor_wrapper_ptr =
-        GlobalAccessorFactory::GetInstance().GetAccessorWrapper();
-    CommonFeatureValueAccessor* gpu_accessor =
-        ((AccessorWrapper<CommonFeatureValueAccessor>*)accessor_wrapper_ptr)
-            ->AccessorPtr();
     if (optimizer_type == 1) {
       return new HeterPs<CommonFeatureValueAccessor, SparseAdagradOptimizer>(
           capacity, resource, *gpu_accessor);
@@ -43,9 +43,20 @@ HeterPsBase* HeterPsBase::get_instance(
       return new HeterPs<CommonFeatureValueAccessor, SparseAdamSharedOptimizer>(
           capacity, resource, *gpu_accessor);
     }
+  } else if (accessor_type == "DownpourCtrDymfAccessor" ||
+             accessor_type == "DownpourCtrDoubleDymfAccessor") {
+    if (optimizer_type == 1) {  // adagrad
+      return new HeterPs<CommonFeatureValueAccessor, SparseAdagradOptimizer>(
+          capacity, resource, *gpu_accessor);
+    } else if (optimizer_type == 2) {  // std_adagrad
+      return new HeterPs<CommonFeatureValueAccessor, StdAdagradOptimizer>(
+          capacity, resource, *gpu_accessor);
+    }
   } else {
-    VLOG(0) << " HeterPsBase get_instance Warning: now only support "
-               "CtrDymfAccessor, but get "
+    VLOG(0) << "HeterPsBase get_instance Warning: now only support "
+               "CtrDymfAccessor, DownpourCtrDymfAccessor or "
+               "DownpourCtrDoubleDymfAccessor, "
+               "but get "
             << accessor_type;
   }
 }
diff --git a/paddle/fluid/framework/fleet/heter_ps/optimizer.cuh.h b/paddle/fluid/framework/fleet/heter_ps/optimizer.cuh.h
index ef633528af1..24d9516b6f0 100644
--- a/paddle/fluid/framework/fleet/heter_ps/optimizer.cuh.h
+++ b/paddle/fluid/framework/fleet/heter_ps/optimizer.cuh.h
@@ -31,7 +31,7 @@ template <typename GPUAccessor>
 class SparseAdagradOptimizer {
  public:
   SparseAdagradOptimizer() {}
-  SparseAdagradOptimizer(GPUAccessor gpu_accessor) {
+  explicit SparseAdagradOptimizer(const GPUAccessor& gpu_accessor) {
     gpu_accessor_ = gpu_accessor;
     _lr_embedding_dim = 1;
     _embedding_dim = gpu_accessor_.common_feature_value.EmbedWDim();
@@ -102,7 +102,8 @@ class SparseAdagradOptimizer {
               scale,
               slot);
 
-    int mf_dim = int(ptr[gpu_accessor_.common_feature_value.MfDimIndex()]);
+    int mf_dim =
+        static_cast<int>(ptr[gpu_accessor_.common_feature_value.MfDimIndex()]);
     if (ptr[gpu_accessor_.common_feature_value.MfSizeIndex()] == 0) {
       if (optimizer_config.mf_create_thresholds <=
           optimizer_config.nonclk_coeff *
@@ -145,11 +146,132 @@ class SparseAdagradOptimizer {
   size_t _lr_embedding_dim;
 };
 
+template <typename GPUAccessor>
+class StdAdagradOptimizer {
+ public:
+  StdAdagradOptimizer() {}
+  explicit StdAdagradOptimizer(const GPUAccessor& gpu_accessor) {
+    gpu_accessor_ = gpu_accessor;
+  }
+
+  ~StdAdagradOptimizer() {}
+
+  __device__ void update_lr(const OptimizerConfig& optimizer_config,
+                            float* w,
+                            float* g2sum,
+                            float g,
+                            float scale) {
+    double ratio = optimizer_config.learning_rate *
+                   sqrt(optimizer_config.initial_g2sum /
+                        (optimizer_config.initial_g2sum + *g2sum));
+    double scaled_grad = g / scale;
+
+    *w += scaled_grad * ratio;
+
+    if (*w < optimizer_config.min_bound) *w = optimizer_config.min_bound;
+    if (*w > optimizer_config.max_bound) *w = optimizer_config.max_bound;
+
+    *g2sum += scaled_grad * scaled_grad;
+  }
+
+  __device__ int g2sum_index() { return 0; }
+
+  __device__ void update_mf(const OptimizerConfig& optimizer_config,
+                            int n,
+                            float* w,
+                            float* sgd,
+                            const float* g,
+                            float scale) {
+    for (int i = 0; i < n; ++i) {
+      float& g2sum = sgd[g2sum_index() + i];
+      double scaled_grad = g[i] / scale;
+
+      double ratio = optimizer_config.mf_learning_rate *
+                     sqrt(optimizer_config.mf_initial_g2sum /
+                          (optimizer_config.mf_initial_g2sum + g2sum));
+
+      w[i] += scaled_grad * ratio;
+
+      if (w[i] < optimizer_config.mf_min_bound)
+        w[i] = optimizer_config.mf_min_bound;
+      if (w[i] > optimizer_config.mf_max_bound)
+        w[i] = optimizer_config.mf_max_bound;
+
+      g2sum += scaled_grad * scaled_grad;
+    }
+  }
+
+  __device__ void dy_mf_update_value(const OptimizerConfig& optimizer_config,
+                                     float* ptr,
+                                     const float* grad) {
+    float grad_show = grad[gpu_accessor_.common_push_value.ShowIndex()];
+    float grad_clk = grad[gpu_accessor_.common_push_value.ClickIndex()];
+
+    ptr[gpu_accessor_.common_feature_value.SlotIndex()] =
+        grad[gpu_accessor_.common_push_value.SlotIndex()];
+
+    ptr[gpu_accessor_.common_feature_value.ShowIndex()] += grad_show;
+    ptr[gpu_accessor_.common_feature_value.ClickIndex()] += grad_clk;
+
+    ptr[gpu_accessor_.common_feature_value.DeltaScoreIndex()] +=
+        optimizer_config.nonclk_coeff * (grad_show - grad_clk) +
+        optimizer_config.clk_coeff * grad_clk;
+
+    float ptr_show = ptr[gpu_accessor_.common_feature_value.ShowIndex()];
+    float ptr_clk = ptr[gpu_accessor_.common_feature_value.ClickIndex()];
+    float grad_lr_g = grad[gpu_accessor_.common_push_value.EmbedGIndex()];
+
+    float ptr_mf_size = ptr[gpu_accessor_.common_feature_value.MfSizeIndex()];
+    int ptr_mf_dim =
+        static_cast<int>(ptr[gpu_accessor_.common_feature_value.MfDimIndex()]);
+
+    update_lr(optimizer_config,
+              ptr + gpu_accessor_.common_feature_value.EmbedWIndex(),
+              ptr + gpu_accessor_.common_feature_value.EmbedG2SumIndex(),
+              grad_lr_g,
+              grad_show);
+
+    if (ptr_mf_size == 0.0) {
+      if (optimizer_config.mf_create_thresholds <=
+          optimizer_config.nonclk_coeff * (ptr_show - ptr_clk) +
+              optimizer_config.clk_coeff * ptr_clk) {
+        ptr[gpu_accessor_.common_feature_value.MfSizeIndex()] =
+            gpu_accessor_.common_feature_value.MFSize(ptr_mf_dim) /
+            sizeof(float);
+
+        // get embedxw index
+        int embedx_w_index =
+            gpu_accessor_.common_feature_value.EmbedxWOffsetIndex(ptr);
+        int tid_x = blockIdx.x * blockDim.x + threadIdx.x;
+        curandState state;
+        curand_init(clock64(), tid_x, 0, &state);
+        for (int i = 0; i < ptr_mf_dim; ++i) {
+          ptr[embedx_w_index + i] =
+              (curand_uniform(&state)) * optimizer_config.mf_initial_range;
+          ptr[gpu_accessor_.common_feature_value.EmbedxG2SumIndex() + i] = 0;
+        }
+      }
+    } else {
+      int embedx_w_index =
+          gpu_accessor_.common_feature_value.EmbedxWOffsetIndex(ptr);
+      update_mf(optimizer_config,
+                ptr_mf_dim,
+                &ptr[embedx_w_index],
+                &ptr[gpu_accessor_.common_feature_value.EmbedxG2SumIndex()],
+                &grad[gpu_accessor_.common_push_value.EmbedxGIndex()],
+                grad_show);
+    }
+  }
+
+ private:
+  GPUAccessor gpu_accessor_;
+};
+
 template <typename GPUAccessor>
 class SparseAdamOptimizer {
  public:
   SparseAdamOptimizer() {}
-  SparseAdamOptimizer(GPUAccessor gpu_accessor) {
+  explicit SparseAdamOptimizer(const GPUAccessor& gpu_accessor) {
     gpu_accessor_ = gpu_accessor;
     _lr_embedding_dim = 1;
     _embedding_dim = gpu_accessor_.common_feature_value.EmbedWDim();
@@ -263,7 +385,8 @@ class SparseAdamOptimizer {
               ptr + gpu_accessor_.common_feature_value.EmbedG2SumIndex(),
               grad + gpu_accessor_.common_push_value.EmbedGIndex(),
               g_show);
-    int mf_dim = int(ptr[gpu_accessor_.common_feature_value.MfDimIndex()]);
+    int mf_dim =
+        static_cast<int>(ptr[gpu_accessor_.common_feature_value.MfDimIndex()]);
     if (ptr[gpu_accessor_.common_feature_value.MfSizeIndex()] == 0) {
       if (optimizer_config.mf_create_thresholds <=
           optimizer_config.nonclk_coeff *
@@ -331,7 +454,7 @@ template <typename GPUAccessor>
 class SparseAdamSharedOptimizer {
  public:
   SparseAdamSharedOptimizer() {}
-  SparseAdamSharedOptimizer(GPUAccessor gpu_accessor) {
+  explicit SparseAdamSharedOptimizer(const GPUAccessor& gpu_accessor) {
     gpu_accessor_ = gpu_accessor;
     _lr_embedding_dim = 1;
     _embedding_dim = gpu_accessor_.common_feature_value.EmbedWDim();
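[Reviewer sketch, not part of the patch] The new `StdAdagradOptimizer` above is plain per-dimension Adagrad with Paddle's `initial_g2sum` damping: each embedding dimension keeps its own accumulator (`sgd[g2sum_index() + i]`), and the effective step size is `lr * sqrt(g0 / (g0 + g2sum_i))`, which decays from `lr` toward zero as squared gradients accumulate — presumably the "std" distinction from the existing shared-state sparse variant. A self-contained host-side sketch of one `update_mf` step; the config values (`lr`, `g0`, bounds, gradients) are made up, and the real code runs per feature inside a CUDA kernel:

```cpp
#include <cmath>
#include <cstdio>

int main() {
  // Assumed stand-ins for OptimizerConfig fields: mf_learning_rate,
  // mf_initial_g2sum, mf_min_bound, mf_max_bound.
  const double lr = 0.05, g0 = 3.0;
  const double min_b = -10.0, max_b = 10.0;

  double w[2] = {0.1, -0.2};           // embedx weights
  double g2sum[2] = {0.0, 0.0};        // one accumulator per dimension
  const double grad[2] = {0.4, -0.3};  // pushed gradient
  const double scale = 2.0;            // show-based gradient scale

  for (int i = 0; i < 2; ++i) {
    double scaled = grad[i] / scale;
    // Step size decays from lr toward 0 as g2sum[i] grows.
    double ratio = lr * std::sqrt(g0 / (g0 + g2sum[i]));
    w[i] += scaled * ratio;
    if (w[i] < min_b) w[i] = min_b;  // clamp, as in update_mf
    if (w[i] > max_b) w[i] = max_b;
    g2sum[i] += scaled * scaled;  // per-dimension, hence "std" adagrad
  }
  std::printf("w = [%f, %f]\n", w[0], w[1]);
  return 0;
}
```

Since the accumulators start at zero, the first update uses the full `lr`; later updates shrink per dimension, matching the `update_lr`/`update_mf` bodies in the hunk above.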
diff --git a/paddle/fluid/framework/fleet/ps_gpu_wrapper.h b/paddle/fluid/framework/fleet/ps_gpu_wrapper.h
index 84076aad26a..0d4a6c4871d 100644
--- a/paddle/fluid/framework/fleet/ps_gpu_wrapper.h
+++ b/paddle/fluid/framework/fleet/ps_gpu_wrapper.h
@@ -439,24 +439,6 @@ class PSGPUWrapper {
     google::protobuf::TextFormat::ParseFromString(dist_desc, &ps_param);
     auto sparse_table =
         ps_param.server_param().downpour_server_param().downpour_table_param(0);
-    // set build thread_num and shard_num
-    thread_keys_thread_num_ = sparse_table.shard_num();
-    thread_keys_shard_num_ = sparse_table.shard_num();
-    VLOG(1) << "ps_gpu build phase thread_num:" << thread_keys_thread_num_
-            << " shard_num:" << thread_keys_shard_num_;
-
-    pull_thread_pool_.resize(thread_keys_shard_num_);
-    for (size_t i = 0; i < pull_thread_pool_.size(); i++) {
-      pull_thread_pool_[i].reset(new ::ThreadPool(1));
-    }
-    hbm_thread_pool_.resize(device_num_);
-    for (size_t i = 0; i < hbm_thread_pool_.size(); i++) {
-      hbm_thread_pool_[i].reset(new ::ThreadPool(1));
-    }
-    cpu_work_pool_.resize(device_num_);
-    for (size_t i = 0; i < cpu_work_pool_.size(); i++) {
-      cpu_work_pool_[i].reset(new ::ThreadPool(cpu_device_thread_num_));
-    }
 
     auto sparse_table_accessor = sparse_table.accessor();
     auto sparse_table_accessor_parameter =
@@ -480,6 +462,7 @@ class PSGPUWrapper {
       add_sparse_optimizer(
           config, sparse_table_accessor.embedx_sgd_param(), "mf_");
     }
+    config["sparse_shard_num"] = sparse_table.shard_num();
 
     fleet_config_ = config;
     GlobalAccessorFactory::GetInstance().Init(accessor_class_);
@@ -660,6 +643,28 @@ class PSGPUWrapper {
 #endif
 
   void InitializeGPUServer(std::unordered_map<std::string, float> config) {
+    // set build thread_num and shard_num
+    int sparse_shard_num = (config.find("sparse_shard_num") == config.end())
+                               ? 37
+                               : config["sparse_shard_num"];
+    thread_keys_thread_num_ = sparse_shard_num;
+    thread_keys_shard_num_ = sparse_shard_num;
+    VLOG(1) << "ps_gpu build phase thread_num:" << thread_keys_thread_num_
+            << " shard_num:" << thread_keys_shard_num_;
+
+    pull_thread_pool_.resize(thread_keys_shard_num_);
+    for (size_t i = 0; i < pull_thread_pool_.size(); i++) {
+      pull_thread_pool_[i].reset(new ::ThreadPool(1));
+    }
+    hbm_thread_pool_.resize(device_num_);
+    for (size_t i = 0; i < hbm_thread_pool_.size(); i++) {
+      hbm_thread_pool_[i].reset(new ::ThreadPool(1));
+    }
+    cpu_work_pool_.resize(device_num_);
+    for (size_t i = 0; i < cpu_work_pool_.size(); i++) {
+      cpu_work_pool_[i].reset(new ::ThreadPool(cpu_device_thread_num_));
+    }
+
     float nonclk_coeff = (config.find("nonclk_coeff") == config.end())
                              ? 1.0
                              : config["nonclk_coeff"];
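[Reviewer sketch, not part of the patch] The ps_gpu_wrapper.h change above is a plumbing move: thread pools used to be sized while parsing the PS table proto, and now they are sized in `InitializeGPUServer`, with the shard count carried through the string-to-float `config` map as `"sparse_shard_num"` and defaulting to 37 when no sparse table supplied it. That is what lets a rec-model run that never goes through the proto path still get working pools. A minimal sketch of the lookup-with-default pattern; the map contents here are hypothetical:

```cpp
#include <iostream>
#include <string>
#include <unordered_map>

int main() {
  // Hypothetical config; the real one is assembled from the table proto
  // (config["sparse_shard_num"] = sparse_table.shard_num()) or passed in
  // directly by the caller.
  std::unordered_map<std::string, float> config = {{"nonclk_coeff", 0.1f}};

  // Same pattern as the diff: fall back to 37 shards when the key is absent.
  int sparse_shard_num = (config.find("sparse_shard_num") == config.end())
                             ? 37
                             : static_cast<int>(config["sparse_shard_num"]);
  std::cout << "shard_num = " << sparse_shard_num << "\n";  // prints 37
  return 0;
}
```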
diff --git a/python/env_dict.py.in b/python/env_dict.py.in
index d06a9d1af3a..ade6b5610e1 100644
--- a/python/env_dict.py.in
+++ b/python/env_dict.py.in
@@ -37,6 +37,7 @@ env_dict={
     'CINN_INCLUDE_DIR':'@CINN_INCLUDE_DIR@',
     'CMAKE_BUILD_TYPE':'@CMAKE_BUILD_TYPE@',
     'PSLIB_LIB':'@PSLIB_LIB@',
+    'JVM_LIB':'@JVM_LIB@',
     'PSLIB_VERSION_PY':'@PSLIB_VERSION_PY@',
     'WITH_MKLDNN':'@WITH_MKLDNN@',
     'MKLDNN_SHARED_LIB':'@MKLDNN_SHARED_LIB@',
diff --git a/python/paddle/incubate/distributed/fleet/parameter_server/pslib/node.py b/python/paddle/incubate/distributed/fleet/parameter_server/pslib/node.py
index 94b4373f549..8dbd73abbdd 100644
--- a/python/paddle/incubate/distributed/fleet/parameter_server/pslib/node.py
+++ b/python/paddle/incubate/distributed/fleet/parameter_server/pslib/node.py
@@ -197,7 +197,6 @@ class DownpourServer(Server):
         if (
             accessor_class == 'DownpourFeatureValueAccessor'
             or accessor_class == 'DownpourCtrAccessor'
-            or accessor_class == 'DownpourCtrDymfAccessor'
             or accessor_class == 'DownpourCtrDoubleAccessor'
         ):
             table.accessor.sparse_sgd_param.learning_rate = strategy.get(
@@ -350,6 +349,7 @@ class DownpourServer(Server):
         elif (
             accessor_class == 'DownpourUnitAccessor'
             or accessor_class == 'DownpourDoubleUnitAccessor'
+            or accessor_class == 'DownpourCtrDymfAccessor'
         ):
             self.add_sparse_table_common_config(table, strategy)
             self.add_sparse_optimizer(
diff --git a/python/paddle/incubate/distributed/fleet/role_maker.py b/python/paddle/incubate/distributed/fleet/role_maker.py
index 323d150443d..1fec0f50a1f 100644
--- a/python/paddle/incubate/distributed/fleet/role_maker.py
+++ b/python/paddle/incubate/distributed/fleet/role_maker.py
@@ -1029,7 +1029,7 @@ class GeneralRoleMaker(RoleMakerBase):
         return "lo"
 
     def __start_kv_server(self, http_server_d, size_d):
-        from paddle.distributed.launch.utils.kv_server import KVServer
+        from paddle.distributed.fleet.utils.http_server import KVServer
 
         http_server = KVServer(int(self._http_ip_port[1]), size_d)
         http_server.start()
diff --git a/python/setup.py.in b/python/setup.py.in
index 452ca01e2b6..d4df12b1f46 100644
--- a/python/setup.py.in
+++ b/python/setup.py.in
@@ -593,9 +593,11 @@ if '${WITH_CINN}' == 'ON':
 
 if '${WITH_PSLIB}' == 'ON':
     shutil.copy('${PSLIB_LIB}', libs_path)
+    shutil.copy('${JVM_LIB}', libs_path)
     if os.path.exists('${PSLIB_VERSION_PY}'):
         shutil.copy('${PSLIB_VERSION_PY}', '${PADDLE_BINARY_DIR}/python/paddle/incubate/distributed/fleet/parameter_server/pslib/')
     package_data['paddle.libs'] += ['libps' + ext_name]
+    package_data['paddle.libs'] += ['libjvm' + ext_name]
 
 if '${WITH_MKLDNN}' == 'ON':
     if '${CMAKE_BUILD_TYPE}' == 'Release' and os.name != 'nt':
diff --git a/setup.py b/setup.py
index 07c96b720fa..75969c5f41c 100644
--- a/setup.py
+++ b/setup.py
@@ -975,6 +975,7 @@ def get_package_data_and_package_dir():
         )
     if env_dict.get("WITH_PSLIB") == 'ON':
         shutil.copy(env_dict.get("PSLIB_LIB"), libs_path)
+        shutil.copy(env_dict.get("JVM_LIB"), libs_path)
         if os.path.exists(env_dict.get("PSLIB_VERSION_PY")):
             shutil.copy(
                 env_dict.get("PSLIB_VERSION_PY"),
@@ -982,6 +983,7 @@ def get_package_data_and_package_dir():
                 + '/python/paddle/incubate/distributed/fleet/parameter_server/pslib/',
             )
         package_data['paddle.libs'] += ['libps' + ext_suffix]
+        package_data['paddle.libs'] += ['libjvm' + ext_suffix]
     if env_dict.get("WITH_MKLDNN") == 'ON':
         if env_dict.get("CMAKE_BUILD_TYPE") == 'Release' and os.name != 'nt':
             # only change rpath in Release mode.
--
GitLab