diff --git a/paddle/fluid/framework/fleet/heter_ps/hashtable.h b/paddle/fluid/framework/fleet/heter_ps/hashtable.h index b821ccecf0a29c725e91fa827441b7c226dcebec..b860ea5d39cb54551aa1122ee129a93cf2b26a62 100644 --- a/paddle/fluid/framework/fleet/heter_ps/hashtable.h +++ b/paddle/fluid/framework/fleet/heter_ps/hashtable.h @@ -41,6 +41,10 @@ limitations under the License. */ #include "xpu/kernel/simd.h" #endif +#if defined(PADDLE_WITH_XPU_KP) +#include "paddle/fluid/framework/fleet/heter_ps/optimizer_conf.h" +#endif + namespace paddle { namespace framework { @@ -56,11 +60,10 @@ class TableContainer capacity, ValType()) {} }; #elif defined(PADDLE_WITH_XPU_KP) - template class XPUCacheArray { public: - explicit XPUCacheArray(size_t capacity) : capacity_(capacity), size_(0) { + explicit XPUCacheArray(long long capacity) : capacity_(capacity), size_(0) { xpu_malloc(reinterpret_cast(&keys), capacity_ * sizeof(KeyType)); xpu_malloc(reinterpret_cast(&vals), capacity_ * sizeof(ValType)); } @@ -71,8 +74,27 @@ class XPUCacheArray { } void print() {} - // ValType* find(const KeyType& key) { return NULL; } - // bool insert(const KeyType& key, const ValType& val) { return true; } + +#if defined(__xpu__) + __device__ ValType* find(const KeyType& key) { + for (int i = 0; i < size_; i++) { + if (keys[i] == key) return &vals[i]; + } + return NULL; + } + __device__ bool insert(const KeyType& key, const ValType& val) { + // # NOTE(zhangminxu): we set the capacity larger than the feasign number of + // one batch + if (size_ == capacity_) { + return false; + } else { + keys[size_] = key; + vals[size_] = val; + size_++; + return true; + } + } +#endif int prefetch(const int dev_id, XPUStream stream = NULL) { return 0; } size_t size() { return size_; } @@ -110,6 +132,11 @@ class HashTable { void show(); +#if defined(PADDLE_WITH_XPU_KP) + void set_sparse_sgd(const OptimizerConfig& optimizer_config); + void set_embedx_sgd(const OptimizerConfig& optimizer_config); +#endif + template void dump_to_cpu(int devid, StreamType stream); @@ -151,6 +178,8 @@ class HashTable { TableContainer* container_; #elif defined(PADDLE_WITH_XPU_KP) XPUCacheArray* container_; + OptimizerConfig* xpu_optimizer_config_; + OptimizerConfig cpu_optimizer_config_; #endif int BLOCK_SIZE_{256}; float LOAD_FACTOR{0.75f}; diff --git a/paddle/fluid/framework/fleet/heter_ps/hashtable_kernel.kps b/paddle/fluid/framework/fleet/heter_ps/hashtable_kernel.kps index e879d817b14dd1345c2f865c50bafe2581d130b4..cd43a73b44ec302902bbfc951ffd4c9dacd8c616 100644 --- a/paddle/fluid/framework/fleet/heter_ps/hashtable_kernel.kps +++ b/paddle/fluid/framework/fleet/heter_ps/hashtable_kernel.kps @@ -14,41 +14,21 @@ limitations under the License. */ #ifdef PADDLE_WITH_HETERPS #include "paddle/fluid/framework/fleet/heter_ps/hashtable.h" - -namespace optimizer_config { -extern _global_ptr_ float* nonclk_coeff; -extern _global_ptr_ float* clk_coeff; - -extern _global_ptr_ float* min_bound; -extern _global_ptr_ float* max_bound; -extern _global_ptr_ float* learning_rate; -extern _global_ptr_ float* initial_g2sum; -extern _global_ptr_ float* initial_range; - -extern _global_ptr_ float* mf_create_thresholds; -extern _global_ptr_ float* mf_learning_rate; -extern _global_ptr_ float* mf_initial_g2sum; -extern _global_ptr_ float* mf_initial_range; -extern _global_ptr_ float* mf_min_bound; -extern _global_ptr_ float* mf_max_bound; -} +#include "paddle/fluid/framework/fleet/heter_ps/optimizer_conf.h" namespace paddle { namespace framework { #if defined(PADDLE_WITH_XPU_KP) -__device__ void update_lr(float& w, float& g2sum, float g, // NOLINT +__device__ void update_lr(OptimizerConfig& optimizer_config, float& w, + float& g2sum, + float g, // NOLINT float scale) { - __local__ float local_learning_rate; - __local__ float local_initial_g2sum; - __local__ float local_min_bound; - __local__ float local_max_bound; - - GM2LM(optimizer_config::learning_rate, &local_learning_rate, sizeof(float)); - GM2LM(optimizer_config::initial_g2sum, &local_initial_g2sum, sizeof(float)); - GM2LM(optimizer_config::min_bound, &local_min_bound, sizeof(float)); - GM2LM(optimizer_config::max_bound, &local_max_bound, sizeof(float)); + float local_learning_rate = optimizer_config.learning_rate; + float local_initial_g2sum = optimizer_config.initial_g2sum; + float local_min_bound = optimizer_config.min_bound; + float local_max_bound = optimizer_config.max_bound; double add_g2sum = 0; double ratio = local_learning_rate * @@ -65,19 +45,12 @@ __device__ void update_lr(float& w, float& g2sum, float g, // NOLINT g2sum += add_g2sum; } -__device__ void update_mf(int n, float* w, float& g2sum, const float* g, - float scale) { - __local__ float local_mf_learning_rate; - __local__ float local_mf_initial_g2sum; - __local__ float local_mf_min_bound; - __local__ float local_mf_max_bound; - - GM2LM(optimizer_config::mf_learning_rate, &local_mf_learning_rate, - sizeof(float)); - GM2LM(optimizer_config::mf_initial_g2sum, &local_mf_initial_g2sum, - sizeof(float)); - GM2LM(optimizer_config::mf_min_bound, &local_mf_min_bound, sizeof(float)); - GM2LM(optimizer_config::mf_max_bound, &local_mf_max_bound, sizeof(float)); +__device__ void update_mf(OptimizerConfig& optimizer_config, int n, float* w, + float& g2sum, const float* g, float scale) { + float local_mf_learning_rate = optimizer_config.mf_learning_rate; + float local_mf_initial_g2sum = optimizer_config.mf_initial_g2sum; + float local_mf_min_bound = optimizer_config.mf_min_bound; + float local_mf_max_bound = optimizer_config.mf_max_bound; double add_g2sum = 0; double ratio = @@ -98,26 +71,22 @@ __device__ void update_mf(int n, float* w, float& g2sum, const float* g, __device__ float xpu_rand_uniform() { return 0.1; } template -__device__ void update_value(ValType& val, const GradType& grad) { // NOLINT +__device__ void update_value(OptimizerConfig& optimizer_config, ValType& val, + const GradType& grad) { // NOLINT val.slot = grad.slot; val.show += grad.show; val.clk += grad.clk; - __local__ float local_nonclk_coeff; - __local__ float local_clk_coeff; + float local_nonclk_coeff = optimizer_config.nonclk_coeff; + float local_clk_coeff = optimizer_config.clk_coeff; - __local__ float local_mf_create_thresholds; - __local__ float local_mf_initial_range; - - GM2LM(optimizer_config::nonclk_coeff, &local_nonclk_coeff, sizeof(float)); - GM2LM(optimizer_config::clk_coeff, &local_clk_coeff, sizeof(float)); - GM2LM(optimizer_config::mf_create_thresholds, &local_mf_create_thresholds, - sizeof(float)); + float local_mf_create_thresholds = optimizer_config.mf_create_thresholds; + float local_mf_initial_range = optimizer_config.mf_initial_range; val.delta_score += local_nonclk_coeff * (grad.show - grad.clk) + local_clk_coeff * grad.clk; - update_lr(val.lr, val.lr_g2sum, grad.lr_g, grad.show); + update_lr(optimizer_config, val.lr, val.lr_g2sum, grad.lr_g, grad.show); if (val.mf_size == 0) { if (local_mf_create_thresholds <= @@ -130,12 +99,13 @@ __device__ void update_value(ValType& val, const GradType& grad) { // NOLINT } } } else { - update_mf(MF_DIM, &val.mf[1], val.mf[0], grad.mf_g, grad.show); + update_mf(optimizer_config, MF_DIM, &val.mf[1], val.mf[0], grad.mf_g, + grad.show); } } template -__global__ void insert_kernel(Table* table, const KeyType* const keys, +__global__ void insert_kernel(Table& table, const KeyType* const keys, const ValType* const vals, long long len) { int cid = core_id(); int ncores = core_num(); @@ -156,14 +126,14 @@ __global__ void insert_kernel(Table* table, const KeyType* const keys, GM2LM(keys, local_keys, read_len * sizeof(KeyType)); GM2LM(vals, local_vals, read_len * sizeof(ValType)); for (int k = 0; k < read_len; k++) { - // auto status = table->insert(local_keys[k], local_vals[k]); - // assert(status != false && "error: insert fails: table is full"); + auto status = table.insert(local_keys[k], local_vals[k]); + assert(status != false && "error: insert fails: table is full"); } } } template -__global__ void search_kernel(Table* table, const KeyType* const keys, +__global__ void search_kernel(Table& table, const KeyType* const keys, ValType* const vals, long long len) { int cid = core_id(); int ncores = core_num(); @@ -183,17 +153,18 @@ __global__ void search_kernel(Table* table, const KeyType* const keys, int read_len = min(len_per_loop, len - i); GM2LM(keys, local_keys, read_len * sizeof(KeyType)); for (int k = 0; k < read_len; k++) { - // ValType* val = table->find(local_keys[k]); - // if (val != NULL) { - // local_vals[k] = *val; - // } + ValType* val = table.find(local_keys[k]); + if (val != NULL) { + local_vals[k] = *val; + } } LM2GM(local_vals, vals + i, read_len * sizeof(ValType)); } } template -__global__ void update_kernel(Table* table, const KeyType* const keys, +__global__ void update_kernel(OptimizerConfig& optimizer_config, Table& table, + const KeyType* const keys, const GradType* const grads, long long len) { int cid = core_id(); int ncores = core_num(); @@ -216,10 +187,10 @@ __global__ void update_kernel(Table* table, const KeyType* const keys, GM2LM(grads, local_grads, read_len * sizeof(GradType)); for (int k = 0; k < read_len; k++) { - // ValType* val = table->find(local_keys[k]); - // if (val != NULL) { - // update_value(*val, grads[i]); - //} + ValType* val = table.find(local_keys[k]); + if (val != NULL) { + update_value(optimizer_config, *val, local_grads[i]); + } } } } @@ -229,14 +200,23 @@ HashTable::HashTable(size_t capacity) { auto tmp_container = XPUCacheArray(capacity); xpu_malloc(reinterpret_cast(&container_), sizeof(XPUCacheArray)); - xpu_memcpy(container_, &tmp_container, + xpu_memcpy((void*)container_, &tmp_container, sizeof(XPUCacheArray), XPU_HOST_TO_DEVICE); + + OptimizerConfig tmp_opt_config; + xpu_malloc(reinterpret_cast(&xpu_optimizer_config_), + sizeof(OptimizerConfig)); + + xpu_memcpy((void*)xpu_optimizer_config_, &tmp_opt_config, + sizeof(OptimizerConfig), XPU_HOST_TO_DEVICE); + rwlock_.reset(new phi::RWLock); } template HashTable::~HashTable() { xpu_free((void*)container_); + xpu_free((void*)xpu_optimizer_config_); } template @@ -244,6 +224,34 @@ void HashTable::show() { container_->print(); } +template +void HashTable::set_sparse_sgd( + const OptimizerConfig& optimizer_config) { + cpu_optimizer_config_.nonclk_coeff = optimizer_config.nonclk_coeff; + cpu_optimizer_config_.clk_coeff = optimizer_config.clk_coeff; + cpu_optimizer_config_.min_bound = optimizer_config.min_bound; + cpu_optimizer_config_.max_bound = optimizer_config.max_bound; + cpu_optimizer_config_.learning_rate = optimizer_config.learning_rate; + cpu_optimizer_config_.initial_g2sum = optimizer_config.initial_g2sum; + cpu_optimizer_config_.initial_range = optimizer_config.initial_range; + xpu_memcpy((void*)xpu_optimizer_config_, &cpu_optimizer_config_, + sizeof(OptimizerConfig), XPU_HOST_TO_DEVICE); +} + +template +void HashTable::set_embedx_sgd( + const OptimizerConfig& optimizer_config) { + cpu_optimizer_config_.mf_create_thresholds = + optimizer_config.mf_create_thresholds; + cpu_optimizer_config_.mf_learning_rate = optimizer_config.mf_learning_rate; + cpu_optimizer_config_.mf_initial_g2sum = optimizer_config.mf_initial_g2sum; + cpu_optimizer_config_.mf_initial_range = optimizer_config.mf_initial_range; + cpu_optimizer_config_.mf_min_bound = optimizer_config.mf_min_bound; + cpu_optimizer_config_.mf_max_bound = optimizer_config.mf_max_bound; + xpu_memcpy((void*)xpu_optimizer_config_, &cpu_optimizer_config_, + sizeof(OptimizerConfig), XPU_HOST_TO_DEVICE); +} + template template void HashTable::get(const KeyType* d_keys, ValType* d_vals, @@ -254,7 +262,7 @@ void HashTable::get(const KeyType* d_keys, ValType* d_vals, long long c_len = (long long)len; search_kernel><<<4, 64, stream>>>( - container_, d_keys, d_vals, c_len); + *container_, d_keys, d_vals, c_len); } template @@ -278,7 +286,7 @@ void HashTable::insert(const KeyType* d_keys, long long c_len = (long long)len; insert_kernel><<<4, 64, stream>>>( - container_, d_keys, d_vals, c_len); + *container_, d_keys, d_vals, c_len); } template @@ -297,8 +305,8 @@ void HashTable::update(const KeyType* d_keys, } long long c_len = (long long)len; update_kernel, - GradType><<<4, 64, stream>>>(container_, d_keys, d_grads, - c_len); + GradType><<<4, 64, stream>>>( + *xpu_optimizer_config_, *container_, d_keys, d_grads, c_len); } template diff --git a/paddle/fluid/framework/fleet/heter_ps/heter_comm.h b/paddle/fluid/framework/fleet/heter_ps/heter_comm.h index 73387f652de6dad6347feee93c5fb29e29c71061..f14a4d648b1d3272144151b321a6d21c75fe9e09 100644 --- a/paddle/fluid/framework/fleet/heter_ps/heter_comm.h +++ b/paddle/fluid/framework/fleet/heter_ps/heter_comm.h @@ -21,6 +21,7 @@ limitations under the License. */ #include "paddle/fluid/platform/dynload/nccl.h" #include "thrust/pair.h" #elif defined(PADDLE_WITH_XPU_KP) +// #include "paddle/fluid/framework/fleet/heter_ps/optimizer_conf.h" #include #include "paddle/fluid/platform/device/xpu/enforce_xpu.h" #endif @@ -64,6 +65,11 @@ class HeterComm { void push_sparse(int num, KeyType* d_keys, GradType* d_grads, size_t len); #endif +#if defined(PADDLE_WITH_XPU_KP) + void set_sparse_sgd(const OptimizerConfig& optimizer_config); + void set_embedx_sgd(const OptimizerConfig& optimizer_config); +#endif + int log2i(int x); template diff --git a/paddle/fluid/framework/fleet/heter_ps/heter_comm_inl.h b/paddle/fluid/framework/fleet/heter_ps/heter_comm_inl.h index e0cef022d3145b48d9e5561dd0de3f708912081e..c39806f88444f5eb3db98f5df6e69ffe01ea4a2d 100644 --- a/paddle/fluid/framework/fleet/heter_ps/heter_comm_inl.h +++ b/paddle/fluid/framework/fleet/heter_ps/heter_comm_inl.h @@ -338,6 +338,24 @@ int HeterComm::get_index_by_devid(int devid) { return resource_->get_index_by_devid(devid); } +#if defined(PADDLE_WITH_XPU_KP) +template +void HeterComm::set_sparse_sgd( + const OptimizerConfig& optimizer_config) { + for (auto& table : tables_) { + table->set_sparse_sgd(optimizer_config); + } +} + +template +void HeterComm::set_embedx_sgd( + const OptimizerConfig& optimizer_config) { + for (auto& table : tables_) { + table->set_embedx_sgd(optimizer_config); + } +} +#endif + template void HeterComm::build_ps( int dev_num, KeyType* h_keys, ValType* h_vals, size_t len, diff --git a/paddle/fluid/framework/fleet/heter_ps/heter_ps.cu b/paddle/fluid/framework/fleet/heter_ps/heter_ps.cu index 583eb926a26a513e57cf1e779de41e2548969a6b..8a877f85076efa7e7bfd3330d7d1489b5be89de7 100644 --- a/paddle/fluid/framework/fleet/heter_ps/heter_ps.cu +++ b/paddle/fluid/framework/fleet/heter_ps/heter_ps.cu @@ -50,6 +50,16 @@ int HeterPs::get_index_by_devid(int devid) { return comm_->get_index_by_devid(devid); } +#if defined(PADDLE_WITH_XPU_KP) +void HeterPs::set_sparse_sgd(const OptimizerConfig& optimizer_config) { + comm_->set_sparse_sgd(optimizer_config); +} + +void HeterPs::set_embedx_sgd(const OptimizerConfig& optimizer_config) { + comm_->set_embedx_sgd(optimizer_config); +} +#endif + void HeterPs::end_pass() { comm_->end_pass(); } void HeterPs::show_one_table(int gpu_num) { comm_->show_one_table(gpu_num); } diff --git a/paddle/fluid/framework/fleet/heter_ps/heter_ps.h b/paddle/fluid/framework/fleet/heter_ps/heter_ps.h index 7fb50f4da1fce3876efceff5da86e325d70f18a8..7060817be91ebf1f7c7cc337feae9e848669cff6 100644 --- a/paddle/fluid/framework/fleet/heter_ps/heter_ps.h +++ b/paddle/fluid/framework/fleet/heter_ps/heter_ps.h @@ -33,22 +33,27 @@ class HeterPs : public HeterPsBase { HeterPs(const HeterPs&) = delete; HeterPs& operator=(const HeterPs&) = delete; - virtual void pull_sparse(int num, FeatureKey* d_keys, FeatureValue* d_vals, - size_t len) override; - virtual void build_ps(int num, FeatureKey* h_keys, FeatureValue* h_vals, - size_t len, size_t chunk_size, int stream_num) override; + void pull_sparse(int num, FeatureKey* d_keys, FeatureValue* d_vals, + size_t len) override; + void build_ps(int num, FeatureKey* h_keys, FeatureValue* h_vals, size_t len, + size_t chunk_size, int stream_num) override; #if defined(PADDLE_WITH_CUDA) - virtual void set_nccl_comm_and_size( - const std::vector& inner_comms, - const std::vector& inter_comms, int comm_size) override; + void set_nccl_comm_and_size(const std::vector& inner_comms, + const std::vector& inter_comms, + int comm_size) override; #endif - virtual void end_pass() override; - virtual int get_index_by_devid(int devid) override; - virtual void show_one_table(int gpu_num) override; - virtual void push_sparse(int num, FeatureKey* d_keys, - FeaturePushValue* d_grads, size_t len) override; +#if defined(PADDLE_WITH_XPU_KP) + void set_sparse_sgd(const OptimizerConfig& optimizer_config) override; + void set_embedx_sgd(const OptimizerConfig& optimizer_config) override; +#endif + + void end_pass() override; + int get_index_by_devid(int devid) override; + void show_one_table(int gpu_num) override; + void push_sparse(int num, FeatureKey* d_keys, FeaturePushValue* d_grads, + size_t len) override; private: std::shared_ptr> comm_; diff --git a/paddle/fluid/framework/fleet/heter_ps/heter_ps_base.h b/paddle/fluid/framework/fleet/heter_ps/heter_ps_base.h index ddbf02df6c578904d8fa79934f4704ad00c4d121..79061ab66af1cb654360d0904199369e91071073 100644 --- a/paddle/fluid/framework/fleet/heter_ps/heter_ps_base.h +++ b/paddle/fluid/framework/fleet/heter_ps/heter_ps_base.h @@ -16,6 +16,9 @@ limitations under the License. */ #include #include "paddle/fluid/framework/fleet/heter_ps/feature_value.h" #include "paddle/fluid/framework/fleet/heter_ps/heter_resource.h" +#if defined(PADDLE_WITH_XPU_KP) +#include "paddle/fluid/framework/fleet/heter_ps/optimizer_conf.h" +#endif #ifdef PADDLE_WITH_HETERPS @@ -24,9 +27,9 @@ namespace framework { class HeterPsBase { public: - HeterPsBase(){}; - HeterPsBase(size_t capacity, std::shared_ptr resource){}; - virtual ~HeterPsBase(){}; + HeterPsBase() {} + HeterPsBase(size_t capacity, std::shared_ptr resource) {} + virtual ~HeterPsBase() {} HeterPsBase(const HeterPsBase&) = delete; HeterPsBase& operator=(const HeterPsBase&) = delete; @@ -44,6 +47,12 @@ class HeterPsBase { virtual void show_one_table(int gpu_num) = 0; virtual void push_sparse(int num, FeatureKey* d_keys, FeaturePushValue* d_grads, size_t len) = 0; + +#if defined(PADDLE_WITH_XPU_KP) + virtual void set_sparse_sgd(const OptimizerConfig& optimizer_config) {} + virtual void set_embedx_sgd(const OptimizerConfig& optimizer_config) {} +#endif + static HeterPsBase* get_instance(size_t capacity, std::shared_ptr resource); }; diff --git a/paddle/fluid/framework/fleet/heter_ps/optimizer_conf.h b/paddle/fluid/framework/fleet/heter_ps/optimizer_conf.h index 6d924a395e19ac063236a352c1145f29c84ded67..2a80aa4b52d91165447a9f556c036f555f032dbd 100644 --- a/paddle/fluid/framework/fleet/heter_ps/optimizer_conf.h +++ b/paddle/fluid/framework/fleet/heter_ps/optimizer_conf.h @@ -14,16 +14,10 @@ limitations under the License. */ #pragma once -#if defined(PADDLE_WITH_XPU_KP) -#include "xpu/kernel/cluster_header.h" -#include "xpu/kernel/debug.h" -#include "xpu/kernel/math.h" -#endif +#if defined(PADDLE_WITH_CUDA) namespace optimizer_config { -#if defined(PADDLE_WITH_CUDA) - __constant__ float nonclk_coeff = 0.1; __constant__ float clk_coeff = 1; @@ -39,24 +33,31 @@ __constant__ float mf_initial_g2sum = 3.0; __constant__ float mf_initial_range = 1e-4; __constant__ float mf_min_bound = -10; __constant__ float mf_max_bound = 10; +} // namespace optimizer_config #elif defined(PADDLE_WITH_XPU_KP) - -_global_ptr_ float* nonclk_coeff; -_global_ptr_ float* clk_coeff; - -_global_ptr_ float* min_bound; -_global_ptr_ float* max_bound; -_global_ptr_ float* learning_rate; -_global_ptr_ float* initial_g2sum; -_global_ptr_ float* initial_range; - -_global_ptr_ float* mf_create_thresholds; -_global_ptr_ float* mf_learning_rate; -_global_ptr_ float* mf_initial_g2sum; -_global_ptr_ float* mf_initial_range; -_global_ptr_ float* mf_min_bound; -_global_ptr_ float* mf_max_bound; +namespace paddle { +namespace framework { + +class OptimizerConfig { + public: + float nonclk_coeff; + float clk_coeff; + + float min_bound; + float max_bound; + float learning_rate; + float initial_g2sum; + float initial_range; + + float mf_create_thresholds; + float mf_learning_rate; + float mf_initial_g2sum; + float mf_initial_range; + float mf_min_bound; + float mf_max_bound; +}; +} // namespace framework +} // namespace paddle #endif -} // namespace optimizer_config diff --git a/paddle/fluid/framework/fleet/ps_gpu_wrapper.kps b/paddle/fluid/framework/fleet/ps_gpu_wrapper.kps index 6d69ae0136d68e4c42026d50ddc24bf45350c194..571a090b9b4a6a408e577eabcc8b673b9bc72a36 100644 --- a/paddle/fluid/framework/fleet/ps_gpu_wrapper.kps +++ b/paddle/fluid/framework/fleet/ps_gpu_wrapper.kps @@ -18,7 +18,6 @@ limitations under the License. */ #include #include #include -#include "paddle/fluid/framework/fleet/heter_ps/optimizer_conf.h" #include "paddle/fluid/framework/fleet/ps_gpu_wrapper.h" #include "paddle/fluid/framework/lod_tensor.h" #include "xpu/kernel/cluster_header.h" // NOLINT @@ -162,23 +161,7 @@ __global__ void PushCopy(FeaturePushValue* dest, float** src, long long* len, } } -PSGPUWrapper::~PSGPUWrapper() { - delete HeterPs_; - xpu_free((void*)optimizer_config::nonclk_coeff); - xpu_free((void*)optimizer_config::clk_coeff); - xpu_free((void*)optimizer_config::min_bound); - xpu_free((void*)optimizer_config::max_bound); - xpu_free((void*)optimizer_config::learning_rate); - xpu_free((void*)optimizer_config::initial_g2sum); - xpu_free((void*)optimizer_config::initial_range); - - xpu_free((void*)optimizer_config::mf_create_thresholds); - xpu_free((void*)optimizer_config::mf_learning_rate); - xpu_free((void*)optimizer_config::mf_initial_g2sum); - xpu_free((void*)optimizer_config::mf_initial_range); - xpu_free((void*)optimizer_config::mf_min_bound); - xpu_free((void*)optimizer_config::mf_max_bound); -} +PSGPUWrapper::~PSGPUWrapper() { delete HeterPs_; } void PSGPUWrapper::CopyForPull(const paddle::platform::Place& place, uint64_t** gpu_keys, @@ -272,66 +255,29 @@ void PSGPUWrapper::SetSparseSGD(float nonclk_coeff, float clk_coeff, float min_bound, float max_bound, float learning_rate, float initial_g2sum, float initial_range) { - xpu_malloc(reinterpret_cast(&optimizer_config::nonclk_coeff), - sizeof(float)); - xpu_malloc(reinterpret_cast(&optimizer_config::clk_coeff), - sizeof(float)); - xpu_malloc(reinterpret_cast(&optimizer_config::min_bound), - sizeof(float)); - xpu_malloc(reinterpret_cast(&optimizer_config::max_bound), - sizeof(float)); - xpu_malloc(reinterpret_cast(&optimizer_config::learning_rate), - sizeof(float)); - xpu_malloc(reinterpret_cast(&optimizer_config::initial_g2sum), - sizeof(float)); - xpu_malloc(reinterpret_cast(&optimizer_config::initial_range), - sizeof(float)); - - xpu_memcpy((void*)optimizer_config::nonclk_coeff, &nonclk_coeff, - sizeof(float), XPU_HOST_TO_DEVICE); - xpu_memcpy((void*)optimizer_config::clk_coeff, &clk_coeff, sizeof(float), - XPU_HOST_TO_DEVICE); - xpu_memcpy((void*)optimizer_config::min_bound, &min_bound, sizeof(float), - XPU_HOST_TO_DEVICE); - xpu_memcpy((void*)optimizer_config::max_bound, &max_bound, sizeof(float), - XPU_HOST_TO_DEVICE); - xpu_memcpy((void*)optimizer_config::learning_rate, &learning_rate, - sizeof(float), XPU_HOST_TO_DEVICE); - xpu_memcpy((void*)optimizer_config::initial_g2sum, &initial_g2sum, - sizeof(float), XPU_HOST_TO_DEVICE); - xpu_memcpy((void*)optimizer_config::initial_range, &initial_range, - sizeof(float), XPU_HOST_TO_DEVICE); + OptimizerConfig optimizer_config; + optimizer_config.nonclk_coeff = nonclk_coeff; + optimizer_config.clk_coeff = clk_coeff; + optimizer_config.min_bound = min_bound; + optimizer_config.max_bound = max_bound; + optimizer_config.learning_rate = learning_rate; + optimizer_config.initial_g2sum = initial_g2sum; + optimizer_config.initial_range = initial_range; + HeterPs_->set_sparse_sgd(optimizer_config); } void PSGPUWrapper::SetEmbedxSGD(float mf_create_thresholds, float mf_learning_rate, float mf_initial_g2sum, float mf_initial_range, float mf_min_bound, float mf_max_bound) { - xpu_malloc(reinterpret_cast(&optimizer_config::mf_create_thresholds), - sizeof(float)); - xpu_malloc(reinterpret_cast(&optimizer_config::mf_learning_rate), - sizeof(float)); - xpu_malloc(reinterpret_cast(&optimizer_config::mf_initial_g2sum), - sizeof(float)); - xpu_malloc(reinterpret_cast(&optimizer_config::mf_initial_range), - sizeof(float)); - xpu_malloc(reinterpret_cast(&optimizer_config::mf_min_bound), - sizeof(float)); - xpu_malloc(reinterpret_cast(&optimizer_config::mf_max_bound), - sizeof(float)); - - xpu_memcpy((void*)optimizer_config::mf_create_thresholds, - &mf_create_thresholds, sizeof(float), XPU_HOST_TO_DEVICE); - xpu_memcpy((void*)optimizer_config::mf_initial_g2sum, &mf_initial_g2sum, - sizeof(float), XPU_HOST_TO_DEVICE); - xpu_memcpy((void*)optimizer_config::mf_initial_range, &mf_initial_range, - sizeof(float), XPU_HOST_TO_DEVICE); - xpu_memcpy((void*)optimizer_config::mf_min_bound, &mf_min_bound, - sizeof(float), XPU_HOST_TO_DEVICE); - xpu_memcpy((void*)optimizer_config::mf_max_bound, &mf_max_bound, - sizeof(float), XPU_HOST_TO_DEVICE); - xpu_memcpy((void*)optimizer_config::mf_learning_rate, &mf_learning_rate, - sizeof(float), XPU_HOST_TO_DEVICE); + OptimizerConfig optimizer_config; + optimizer_config.mf_create_thresholds = mf_create_thresholds; + optimizer_config.mf_learning_rate = mf_learning_rate; + optimizer_config.mf_initial_g2sum = mf_initial_g2sum; + optimizer_config.mf_initial_range = mf_initial_range; + optimizer_config.mf_min_bound = mf_min_bound; + optimizer_config.mf_max_bound = mf_max_bound; + HeterPs_->set_embedx_sgd(optimizer_config); } } // end namespace framework