未验证 提交 6becabaa 编写于 作者: Z zmxdream 提交者: GitHub

[XPUPS]add hashtable interface (#41987)

* add hashtable interface. test=develop

* update. test=develop

* update. test=develop

* fix. test=develop

* fix optimizer config for xpups. test=develop

* fix. test=develop

* fix. test=develop
上级 ec995c59
...@@ -41,6 +41,10 @@ limitations under the License. */ ...@@ -41,6 +41,10 @@ limitations under the License. */
#include "xpu/kernel/simd.h" #include "xpu/kernel/simd.h"
#endif #endif
#if defined(PADDLE_WITH_XPU_KP)
#include "paddle/fluid/framework/fleet/heter_ps/optimizer_conf.h"
#endif
namespace paddle { namespace paddle {
namespace framework { namespace framework {
...@@ -56,11 +60,10 @@ class TableContainer ...@@ -56,11 +60,10 @@ class TableContainer
capacity, ValType()) {} capacity, ValType()) {}
}; };
#elif defined(PADDLE_WITH_XPU_KP) #elif defined(PADDLE_WITH_XPU_KP)
template <typename KeyType, typename ValType> template <typename KeyType, typename ValType>
class XPUCacheArray { class XPUCacheArray {
public: public:
explicit XPUCacheArray(size_t capacity) : capacity_(capacity), size_(0) { explicit XPUCacheArray(long long capacity) : capacity_(capacity), size_(0) {
xpu_malloc(reinterpret_cast<void**>(&keys), capacity_ * sizeof(KeyType)); xpu_malloc(reinterpret_cast<void**>(&keys), capacity_ * sizeof(KeyType));
xpu_malloc(reinterpret_cast<void**>(&vals), capacity_ * sizeof(ValType)); xpu_malloc(reinterpret_cast<void**>(&vals), capacity_ * sizeof(ValType));
} }
...@@ -71,8 +74,27 @@ class XPUCacheArray { ...@@ -71,8 +74,27 @@ class XPUCacheArray {
} }
void print() {} void print() {}
// ValType* find(const KeyType& key) { return NULL; }
// bool insert(const KeyType& key, const ValType& val) { return true; } #if defined(__xpu__)
__device__ ValType* find(const KeyType& key) {
for (int i = 0; i < size_; i++) {
if (keys[i] == key) return &vals[i];
}
return NULL;
}
__device__ bool insert(const KeyType& key, const ValType& val) {
// # NOTE(zhangminxu): we set the capacity larger than the feasign number of
// one batch
if (size_ == capacity_) {
return false;
} else {
keys[size_] = key;
vals[size_] = val;
size_++;
return true;
}
}
#endif
int prefetch(const int dev_id, XPUStream stream = NULL) { return 0; } int prefetch(const int dev_id, XPUStream stream = NULL) { return 0; }
size_t size() { return size_; } size_t size() { return size_; }
...@@ -110,6 +132,11 @@ class HashTable { ...@@ -110,6 +132,11 @@ class HashTable {
void show(); void show();
#if defined(PADDLE_WITH_XPU_KP)
void set_sparse_sgd(const OptimizerConfig& optimizer_config);
void set_embedx_sgd(const OptimizerConfig& optimizer_config);
#endif
template <typename StreamType> template <typename StreamType>
void dump_to_cpu(int devid, StreamType stream); void dump_to_cpu(int devid, StreamType stream);
...@@ -151,6 +178,8 @@ class HashTable { ...@@ -151,6 +178,8 @@ class HashTable {
TableContainer<KeyType, ValType>* container_; TableContainer<KeyType, ValType>* container_;
#elif defined(PADDLE_WITH_XPU_KP) #elif defined(PADDLE_WITH_XPU_KP)
XPUCacheArray<KeyType, ValType>* container_; XPUCacheArray<KeyType, ValType>* container_;
OptimizerConfig* xpu_optimizer_config_;
OptimizerConfig cpu_optimizer_config_;
#endif #endif
int BLOCK_SIZE_{256}; int BLOCK_SIZE_{256};
float LOAD_FACTOR{0.75f}; float LOAD_FACTOR{0.75f};
......
...@@ -14,41 +14,21 @@ limitations under the License. */ ...@@ -14,41 +14,21 @@ limitations under the License. */
#ifdef PADDLE_WITH_HETERPS #ifdef PADDLE_WITH_HETERPS
#include "paddle/fluid/framework/fleet/heter_ps/hashtable.h" #include "paddle/fluid/framework/fleet/heter_ps/hashtable.h"
#include "paddle/fluid/framework/fleet/heter_ps/optimizer_conf.h"
namespace optimizer_config {
extern _global_ptr_ float* nonclk_coeff;
extern _global_ptr_ float* clk_coeff;
extern _global_ptr_ float* min_bound;
extern _global_ptr_ float* max_bound;
extern _global_ptr_ float* learning_rate;
extern _global_ptr_ float* initial_g2sum;
extern _global_ptr_ float* initial_range;
extern _global_ptr_ float* mf_create_thresholds;
extern _global_ptr_ float* mf_learning_rate;
extern _global_ptr_ float* mf_initial_g2sum;
extern _global_ptr_ float* mf_initial_range;
extern _global_ptr_ float* mf_min_bound;
extern _global_ptr_ float* mf_max_bound;
}
namespace paddle { namespace paddle {
namespace framework { namespace framework {
#if defined(PADDLE_WITH_XPU_KP) #if defined(PADDLE_WITH_XPU_KP)
__device__ void update_lr(float& w, float& g2sum, float g, // NOLINT __device__ void update_lr(OptimizerConfig& optimizer_config, float& w,
float& g2sum,
float g, // NOLINT
float scale) { float scale) {
__local__ float local_learning_rate; float local_learning_rate = optimizer_config.learning_rate;
__local__ float local_initial_g2sum; float local_initial_g2sum = optimizer_config.initial_g2sum;
__local__ float local_min_bound; float local_min_bound = optimizer_config.min_bound;
__local__ float local_max_bound; float local_max_bound = optimizer_config.max_bound;
GM2LM(optimizer_config::learning_rate, &local_learning_rate, sizeof(float));
GM2LM(optimizer_config::initial_g2sum, &local_initial_g2sum, sizeof(float));
GM2LM(optimizer_config::min_bound, &local_min_bound, sizeof(float));
GM2LM(optimizer_config::max_bound, &local_max_bound, sizeof(float));
double add_g2sum = 0; double add_g2sum = 0;
double ratio = local_learning_rate * double ratio = local_learning_rate *
...@@ -65,19 +45,12 @@ __device__ void update_lr(float& w, float& g2sum, float g, // NOLINT ...@@ -65,19 +45,12 @@ __device__ void update_lr(float& w, float& g2sum, float g, // NOLINT
g2sum += add_g2sum; g2sum += add_g2sum;
} }
__device__ void update_mf(int n, float* w, float& g2sum, const float* g, __device__ void update_mf(OptimizerConfig& optimizer_config, int n, float* w,
float scale) { float& g2sum, const float* g, float scale) {
__local__ float local_mf_learning_rate; float local_mf_learning_rate = optimizer_config.mf_learning_rate;
__local__ float local_mf_initial_g2sum; float local_mf_initial_g2sum = optimizer_config.mf_initial_g2sum;
__local__ float local_mf_min_bound; float local_mf_min_bound = optimizer_config.mf_min_bound;
__local__ float local_mf_max_bound; float local_mf_max_bound = optimizer_config.mf_max_bound;
GM2LM(optimizer_config::mf_learning_rate, &local_mf_learning_rate,
sizeof(float));
GM2LM(optimizer_config::mf_initial_g2sum, &local_mf_initial_g2sum,
sizeof(float));
GM2LM(optimizer_config::mf_min_bound, &local_mf_min_bound, sizeof(float));
GM2LM(optimizer_config::mf_max_bound, &local_mf_max_bound, sizeof(float));
double add_g2sum = 0; double add_g2sum = 0;
double ratio = double ratio =
...@@ -98,26 +71,22 @@ __device__ void update_mf(int n, float* w, float& g2sum, const float* g, ...@@ -98,26 +71,22 @@ __device__ void update_mf(int n, float* w, float& g2sum, const float* g,
__device__ float xpu_rand_uniform() { return 0.1; } __device__ float xpu_rand_uniform() { return 0.1; }
template <typename ValType, typename GradType> template <typename ValType, typename GradType>
__device__ void update_value(ValType& val, const GradType& grad) { // NOLINT __device__ void update_value(OptimizerConfig& optimizer_config, ValType& val,
const GradType& grad) { // NOLINT
val.slot = grad.slot; val.slot = grad.slot;
val.show += grad.show; val.show += grad.show;
val.clk += grad.clk; val.clk += grad.clk;
__local__ float local_nonclk_coeff; float local_nonclk_coeff = optimizer_config.nonclk_coeff;
__local__ float local_clk_coeff; float local_clk_coeff = optimizer_config.clk_coeff;
__local__ float local_mf_create_thresholds; float local_mf_create_thresholds = optimizer_config.mf_create_thresholds;
__local__ float local_mf_initial_range; float local_mf_initial_range = optimizer_config.mf_initial_range;
GM2LM(optimizer_config::nonclk_coeff, &local_nonclk_coeff, sizeof(float));
GM2LM(optimizer_config::clk_coeff, &local_clk_coeff, sizeof(float));
GM2LM(optimizer_config::mf_create_thresholds, &local_mf_create_thresholds,
sizeof(float));
val.delta_score += val.delta_score +=
local_nonclk_coeff * (grad.show - grad.clk) + local_clk_coeff * grad.clk; local_nonclk_coeff * (grad.show - grad.clk) + local_clk_coeff * grad.clk;
update_lr(val.lr, val.lr_g2sum, grad.lr_g, grad.show); update_lr(optimizer_config, val.lr, val.lr_g2sum, grad.lr_g, grad.show);
if (val.mf_size == 0) { if (val.mf_size == 0) {
if (local_mf_create_thresholds <= if (local_mf_create_thresholds <=
...@@ -130,12 +99,13 @@ __device__ void update_value(ValType& val, const GradType& grad) { // NOLINT ...@@ -130,12 +99,13 @@ __device__ void update_value(ValType& val, const GradType& grad) { // NOLINT
} }
} }
} else { } else {
update_mf(MF_DIM, &val.mf[1], val.mf[0], grad.mf_g, grad.show); update_mf(optimizer_config, MF_DIM, &val.mf[1], val.mf[0], grad.mf_g,
grad.show);
} }
} }
template <typename KeyType, typename ValType, typename Table> template <typename KeyType, typename ValType, typename Table>
__global__ void insert_kernel(Table* table, const KeyType* const keys, __global__ void insert_kernel(Table& table, const KeyType* const keys,
const ValType* const vals, long long len) { const ValType* const vals, long long len) {
int cid = core_id(); int cid = core_id();
int ncores = core_num(); int ncores = core_num();
...@@ -156,14 +126,14 @@ __global__ void insert_kernel(Table* table, const KeyType* const keys, ...@@ -156,14 +126,14 @@ __global__ void insert_kernel(Table* table, const KeyType* const keys,
GM2LM(keys, local_keys, read_len * sizeof(KeyType)); GM2LM(keys, local_keys, read_len * sizeof(KeyType));
GM2LM(vals, local_vals, read_len * sizeof(ValType)); GM2LM(vals, local_vals, read_len * sizeof(ValType));
for (int k = 0; k < read_len; k++) { for (int k = 0; k < read_len; k++) {
// auto status = table->insert(local_keys[k], local_vals[k]); auto status = table.insert(local_keys[k], local_vals[k]);
// assert(status != false && "error: insert fails: table is full"); assert(status != false && "error: insert fails: table is full");
} }
} }
} }
template <typename KeyType, typename ValType, typename Table> template <typename KeyType, typename ValType, typename Table>
__global__ void search_kernel(Table* table, const KeyType* const keys, __global__ void search_kernel(Table& table, const KeyType* const keys,
ValType* const vals, long long len) { ValType* const vals, long long len) {
int cid = core_id(); int cid = core_id();
int ncores = core_num(); int ncores = core_num();
...@@ -183,17 +153,18 @@ __global__ void search_kernel(Table* table, const KeyType* const keys, ...@@ -183,17 +153,18 @@ __global__ void search_kernel(Table* table, const KeyType* const keys,
int read_len = min(len_per_loop, len - i); int read_len = min(len_per_loop, len - i);
GM2LM(keys, local_keys, read_len * sizeof(KeyType)); GM2LM(keys, local_keys, read_len * sizeof(KeyType));
for (int k = 0; k < read_len; k++) { for (int k = 0; k < read_len; k++) {
// ValType* val = table->find(local_keys[k]); ValType* val = table.find(local_keys[k]);
// if (val != NULL) { if (val != NULL) {
// local_vals[k] = *val; local_vals[k] = *val;
// } }
} }
LM2GM(local_vals, vals + i, read_len * sizeof(ValType)); LM2GM(local_vals, vals + i, read_len * sizeof(ValType));
} }
} }
template <typename KeyType, typename ValType, typename Table, typename GradType> template <typename KeyType, typename ValType, typename Table, typename GradType>
__global__ void update_kernel(Table* table, const KeyType* const keys, __global__ void update_kernel(OptimizerConfig& optimizer_config, Table& table,
const KeyType* const keys,
const GradType* const grads, long long len) { const GradType* const grads, long long len) {
int cid = core_id(); int cid = core_id();
int ncores = core_num(); int ncores = core_num();
...@@ -216,10 +187,10 @@ __global__ void update_kernel(Table* table, const KeyType* const keys, ...@@ -216,10 +187,10 @@ __global__ void update_kernel(Table* table, const KeyType* const keys,
GM2LM(grads, local_grads, read_len * sizeof(GradType)); GM2LM(grads, local_grads, read_len * sizeof(GradType));
for (int k = 0; k < read_len; k++) { for (int k = 0; k < read_len; k++) {
// ValType* val = table->find(local_keys[k]); ValType* val = table.find(local_keys[k]);
// if (val != NULL) { if (val != NULL) {
// update_value(*val, grads[i]); update_value(optimizer_config, *val, local_grads[i]);
//} }
} }
} }
} }
...@@ -229,14 +200,23 @@ HashTable<KeyType, ValType>::HashTable(size_t capacity) { ...@@ -229,14 +200,23 @@ HashTable<KeyType, ValType>::HashTable(size_t capacity) {
auto tmp_container = XPUCacheArray<KeyType, ValType>(capacity); auto tmp_container = XPUCacheArray<KeyType, ValType>(capacity);
xpu_malloc(reinterpret_cast<void**>(&container_), xpu_malloc(reinterpret_cast<void**>(&container_),
sizeof(XPUCacheArray<KeyType, ValType>)); sizeof(XPUCacheArray<KeyType, ValType>));
xpu_memcpy(container_, &tmp_container, xpu_memcpy((void*)container_, &tmp_container,
sizeof(XPUCacheArray<KeyType, ValType>), XPU_HOST_TO_DEVICE); sizeof(XPUCacheArray<KeyType, ValType>), XPU_HOST_TO_DEVICE);
OptimizerConfig tmp_opt_config;
xpu_malloc(reinterpret_cast<void**>(&xpu_optimizer_config_),
sizeof(OptimizerConfig));
xpu_memcpy((void*)xpu_optimizer_config_, &tmp_opt_config,
sizeof(OptimizerConfig), XPU_HOST_TO_DEVICE);
rwlock_.reset(new phi::RWLock); rwlock_.reset(new phi::RWLock);
} }
template <typename KeyType, typename ValType> template <typename KeyType, typename ValType>
HashTable<KeyType, ValType>::~HashTable() { HashTable<KeyType, ValType>::~HashTable() {
xpu_free((void*)container_); xpu_free((void*)container_);
xpu_free((void*)xpu_optimizer_config_);
} }
template <typename KeyType, typename ValType> template <typename KeyType, typename ValType>
...@@ -244,6 +224,34 @@ void HashTable<KeyType, ValType>::show() { ...@@ -244,6 +224,34 @@ void HashTable<KeyType, ValType>::show() {
container_->print(); container_->print();
} }
template <typename KeyType, typename ValType>
void HashTable<KeyType, ValType>::set_sparse_sgd(
const OptimizerConfig& optimizer_config) {
cpu_optimizer_config_.nonclk_coeff = optimizer_config.nonclk_coeff;
cpu_optimizer_config_.clk_coeff = optimizer_config.clk_coeff;
cpu_optimizer_config_.min_bound = optimizer_config.min_bound;
cpu_optimizer_config_.max_bound = optimizer_config.max_bound;
cpu_optimizer_config_.learning_rate = optimizer_config.learning_rate;
cpu_optimizer_config_.initial_g2sum = optimizer_config.initial_g2sum;
cpu_optimizer_config_.initial_range = optimizer_config.initial_range;
xpu_memcpy((void*)xpu_optimizer_config_, &cpu_optimizer_config_,
sizeof(OptimizerConfig), XPU_HOST_TO_DEVICE);
}
template <typename KeyType, typename ValType>
void HashTable<KeyType, ValType>::set_embedx_sgd(
const OptimizerConfig& optimizer_config) {
cpu_optimizer_config_.mf_create_thresholds =
optimizer_config.mf_create_thresholds;
cpu_optimizer_config_.mf_learning_rate = optimizer_config.mf_learning_rate;
cpu_optimizer_config_.mf_initial_g2sum = optimizer_config.mf_initial_g2sum;
cpu_optimizer_config_.mf_initial_range = optimizer_config.mf_initial_range;
cpu_optimizer_config_.mf_min_bound = optimizer_config.mf_min_bound;
cpu_optimizer_config_.mf_max_bound = optimizer_config.mf_max_bound;
xpu_memcpy((void*)xpu_optimizer_config_, &cpu_optimizer_config_,
sizeof(OptimizerConfig), XPU_HOST_TO_DEVICE);
}
template <typename KeyType, typename ValType> template <typename KeyType, typename ValType>
template <typename StreamType> template <typename StreamType>
void HashTable<KeyType, ValType>::get(const KeyType* d_keys, ValType* d_vals, void HashTable<KeyType, ValType>::get(const KeyType* d_keys, ValType* d_vals,
...@@ -254,7 +262,7 @@ void HashTable<KeyType, ValType>::get(const KeyType* d_keys, ValType* d_vals, ...@@ -254,7 +262,7 @@ void HashTable<KeyType, ValType>::get(const KeyType* d_keys, ValType* d_vals,
long long c_len = (long long)len; long long c_len = (long long)len;
search_kernel<KeyType, ValType, search_kernel<KeyType, ValType,
XPUCacheArray<KeyType, ValType>><<<4, 64, stream>>>( XPUCacheArray<KeyType, ValType>><<<4, 64, stream>>>(
container_, d_keys, d_vals, c_len); *container_, d_keys, d_vals, c_len);
} }
template <typename KeyType, typename ValType> template <typename KeyType, typename ValType>
...@@ -278,7 +286,7 @@ void HashTable<KeyType, ValType>::insert(const KeyType* d_keys, ...@@ -278,7 +286,7 @@ void HashTable<KeyType, ValType>::insert(const KeyType* d_keys,
long long c_len = (long long)len; long long c_len = (long long)len;
insert_kernel<KeyType, ValType, insert_kernel<KeyType, ValType,
XPUCacheArray<KeyType, ValType>><<<4, 64, stream>>>( XPUCacheArray<KeyType, ValType>><<<4, 64, stream>>>(
container_, d_keys, d_vals, c_len); *container_, d_keys, d_vals, c_len);
} }
template <typename KeyType, typename ValType> template <typename KeyType, typename ValType>
...@@ -297,8 +305,8 @@ void HashTable<KeyType, ValType>::update(const KeyType* d_keys, ...@@ -297,8 +305,8 @@ void HashTable<KeyType, ValType>::update(const KeyType* d_keys,
} }
long long c_len = (long long)len; long long c_len = (long long)len;
update_kernel<KeyType, ValType, XPUCacheArray<KeyType, ValType>, update_kernel<KeyType, ValType, XPUCacheArray<KeyType, ValType>,
GradType><<<4, 64, stream>>>(container_, d_keys, d_grads, GradType><<<4, 64, stream>>>(
c_len); *xpu_optimizer_config_, *container_, d_keys, d_grads, c_len);
} }
template <typename KeyType, typename ValType> template <typename KeyType, typename ValType>
......
...@@ -21,6 +21,7 @@ limitations under the License. */ ...@@ -21,6 +21,7 @@ limitations under the License. */
#include "paddle/fluid/platform/dynload/nccl.h" #include "paddle/fluid/platform/dynload/nccl.h"
#include "thrust/pair.h" #include "thrust/pair.h"
#elif defined(PADDLE_WITH_XPU_KP) #elif defined(PADDLE_WITH_XPU_KP)
// #include "paddle/fluid/framework/fleet/heter_ps/optimizer_conf.h"
#include <xpu/runtime.h> #include <xpu/runtime.h>
#include "paddle/fluid/platform/device/xpu/enforce_xpu.h" #include "paddle/fluid/platform/device/xpu/enforce_xpu.h"
#endif #endif
...@@ -64,6 +65,11 @@ class HeterComm { ...@@ -64,6 +65,11 @@ class HeterComm {
void push_sparse(int num, KeyType* d_keys, GradType* d_grads, size_t len); void push_sparse(int num, KeyType* d_keys, GradType* d_grads, size_t len);
#endif #endif
#if defined(PADDLE_WITH_XPU_KP)
void set_sparse_sgd(const OptimizerConfig& optimizer_config);
void set_embedx_sgd(const OptimizerConfig& optimizer_config);
#endif
int log2i(int x); int log2i(int x);
template <typename DstPlace, typename SrcPlace, typename StreamType> template <typename DstPlace, typename SrcPlace, typename StreamType>
......
...@@ -338,6 +338,24 @@ int HeterComm<KeyType, ValType, GradType>::get_index_by_devid(int devid) { ...@@ -338,6 +338,24 @@ int HeterComm<KeyType, ValType, GradType>::get_index_by_devid(int devid) {
return resource_->get_index_by_devid(devid); return resource_->get_index_by_devid(devid);
} }
#if defined(PADDLE_WITH_XPU_KP)
template <typename KeyType, typename ValType, typename GradType>
void HeterComm<KeyType, ValType, GradType>::set_sparse_sgd(
const OptimizerConfig& optimizer_config) {
for (auto& table : tables_) {
table->set_sparse_sgd(optimizer_config);
}
}
template <typename KeyType, typename ValType, typename GradType>
void HeterComm<KeyType, ValType, GradType>::set_embedx_sgd(
const OptimizerConfig& optimizer_config) {
for (auto& table : tables_) {
table->set_embedx_sgd(optimizer_config);
}
}
#endif
template <typename KeyType, typename ValType, typename GradType> template <typename KeyType, typename ValType, typename GradType>
void HeterComm<KeyType, ValType, GradType>::build_ps( void HeterComm<KeyType, ValType, GradType>::build_ps(
int dev_num, KeyType* h_keys, ValType* h_vals, size_t len, int dev_num, KeyType* h_keys, ValType* h_vals, size_t len,
......
...@@ -50,6 +50,16 @@ int HeterPs::get_index_by_devid(int devid) { ...@@ -50,6 +50,16 @@ int HeterPs::get_index_by_devid(int devid) {
return comm_->get_index_by_devid(devid); return comm_->get_index_by_devid(devid);
} }
#if defined(PADDLE_WITH_XPU_KP)
void HeterPs::set_sparse_sgd(const OptimizerConfig& optimizer_config) {
comm_->set_sparse_sgd(optimizer_config);
}
void HeterPs::set_embedx_sgd(const OptimizerConfig& optimizer_config) {
comm_->set_embedx_sgd(optimizer_config);
}
#endif
void HeterPs::end_pass() { comm_->end_pass(); } void HeterPs::end_pass() { comm_->end_pass(); }
void HeterPs::show_one_table(int gpu_num) { comm_->show_one_table(gpu_num); } void HeterPs::show_one_table(int gpu_num) { comm_->show_one_table(gpu_num); }
......
...@@ -33,22 +33,27 @@ class HeterPs : public HeterPsBase { ...@@ -33,22 +33,27 @@ class HeterPs : public HeterPsBase {
HeterPs(const HeterPs&) = delete; HeterPs(const HeterPs&) = delete;
HeterPs& operator=(const HeterPs&) = delete; HeterPs& operator=(const HeterPs&) = delete;
virtual void pull_sparse(int num, FeatureKey* d_keys, FeatureValue* d_vals, void pull_sparse(int num, FeatureKey* d_keys, FeatureValue* d_vals,
size_t len) override; size_t len) override;
virtual void build_ps(int num, FeatureKey* h_keys, FeatureValue* h_vals, void build_ps(int num, FeatureKey* h_keys, FeatureValue* h_vals, size_t len,
size_t len, size_t chunk_size, int stream_num) override; size_t chunk_size, int stream_num) override;
#if defined(PADDLE_WITH_CUDA) #if defined(PADDLE_WITH_CUDA)
virtual void set_nccl_comm_and_size( void set_nccl_comm_and_size(const std::vector<ncclComm_t>& inner_comms,
const std::vector<ncclComm_t>& inner_comms, const std::vector<ncclComm_t>& inter_comms,
const std::vector<ncclComm_t>& inter_comms, int comm_size) override; int comm_size) override;
#endif #endif
virtual void end_pass() override; #if defined(PADDLE_WITH_XPU_KP)
virtual int get_index_by_devid(int devid) override; void set_sparse_sgd(const OptimizerConfig& optimizer_config) override;
virtual void show_one_table(int gpu_num) override; void set_embedx_sgd(const OptimizerConfig& optimizer_config) override;
virtual void push_sparse(int num, FeatureKey* d_keys, #endif
FeaturePushValue* d_grads, size_t len) override;
void end_pass() override;
int get_index_by_devid(int devid) override;
void show_one_table(int gpu_num) override;
void push_sparse(int num, FeatureKey* d_keys, FeaturePushValue* d_grads,
size_t len) override;
private: private:
std::shared_ptr<HeterComm<FeatureKey, FeatureValue, FeaturePushValue>> comm_; std::shared_ptr<HeterComm<FeatureKey, FeatureValue, FeaturePushValue>> comm_;
......
...@@ -16,6 +16,9 @@ limitations under the License. */ ...@@ -16,6 +16,9 @@ limitations under the License. */
#include <vector> #include <vector>
#include "paddle/fluid/framework/fleet/heter_ps/feature_value.h" #include "paddle/fluid/framework/fleet/heter_ps/feature_value.h"
#include "paddle/fluid/framework/fleet/heter_ps/heter_resource.h" #include "paddle/fluid/framework/fleet/heter_ps/heter_resource.h"
#if defined(PADDLE_WITH_XPU_KP)
#include "paddle/fluid/framework/fleet/heter_ps/optimizer_conf.h"
#endif
#ifdef PADDLE_WITH_HETERPS #ifdef PADDLE_WITH_HETERPS
...@@ -24,9 +27,9 @@ namespace framework { ...@@ -24,9 +27,9 @@ namespace framework {
class HeterPsBase { class HeterPsBase {
public: public:
HeterPsBase(){}; HeterPsBase() {}
HeterPsBase(size_t capacity, std::shared_ptr<HeterPsResource> resource){}; HeterPsBase(size_t capacity, std::shared_ptr<HeterPsResource> resource) {}
virtual ~HeterPsBase(){}; virtual ~HeterPsBase() {}
HeterPsBase(const HeterPsBase&) = delete; HeterPsBase(const HeterPsBase&) = delete;
HeterPsBase& operator=(const HeterPsBase&) = delete; HeterPsBase& operator=(const HeterPsBase&) = delete;
...@@ -44,6 +47,12 @@ class HeterPsBase { ...@@ -44,6 +47,12 @@ class HeterPsBase {
virtual void show_one_table(int gpu_num) = 0; virtual void show_one_table(int gpu_num) = 0;
virtual void push_sparse(int num, FeatureKey* d_keys, virtual void push_sparse(int num, FeatureKey* d_keys,
FeaturePushValue* d_grads, size_t len) = 0; FeaturePushValue* d_grads, size_t len) = 0;
#if defined(PADDLE_WITH_XPU_KP)
virtual void set_sparse_sgd(const OptimizerConfig& optimizer_config) {}
virtual void set_embedx_sgd(const OptimizerConfig& optimizer_config) {}
#endif
static HeterPsBase* get_instance(size_t capacity, static HeterPsBase* get_instance(size_t capacity,
std::shared_ptr<HeterPsResource> resource); std::shared_ptr<HeterPsResource> resource);
}; };
......
...@@ -14,16 +14,10 @@ limitations under the License. */ ...@@ -14,16 +14,10 @@ limitations under the License. */
#pragma once #pragma once
#if defined(PADDLE_WITH_XPU_KP) #if defined(PADDLE_WITH_CUDA)
#include "xpu/kernel/cluster_header.h"
#include "xpu/kernel/debug.h"
#include "xpu/kernel/math.h"
#endif
namespace optimizer_config { namespace optimizer_config {
#if defined(PADDLE_WITH_CUDA)
__constant__ float nonclk_coeff = 0.1; __constant__ float nonclk_coeff = 0.1;
__constant__ float clk_coeff = 1; __constant__ float clk_coeff = 1;
...@@ -39,24 +33,31 @@ __constant__ float mf_initial_g2sum = 3.0; ...@@ -39,24 +33,31 @@ __constant__ float mf_initial_g2sum = 3.0;
__constant__ float mf_initial_range = 1e-4; __constant__ float mf_initial_range = 1e-4;
__constant__ float mf_min_bound = -10; __constant__ float mf_min_bound = -10;
__constant__ float mf_max_bound = 10; __constant__ float mf_max_bound = 10;
} // namespace optimizer_config
#elif defined(PADDLE_WITH_XPU_KP) #elif defined(PADDLE_WITH_XPU_KP)
namespace paddle {
_global_ptr_ float* nonclk_coeff; namespace framework {
_global_ptr_ float* clk_coeff;
class OptimizerConfig {
_global_ptr_ float* min_bound; public:
_global_ptr_ float* max_bound; float nonclk_coeff;
_global_ptr_ float* learning_rate; float clk_coeff;
_global_ptr_ float* initial_g2sum;
_global_ptr_ float* initial_range; float min_bound;
float max_bound;
_global_ptr_ float* mf_create_thresholds; float learning_rate;
_global_ptr_ float* mf_learning_rate; float initial_g2sum;
_global_ptr_ float* mf_initial_g2sum; float initial_range;
_global_ptr_ float* mf_initial_range;
_global_ptr_ float* mf_min_bound; float mf_create_thresholds;
_global_ptr_ float* mf_max_bound; float mf_learning_rate;
float mf_initial_g2sum;
float mf_initial_range;
float mf_min_bound;
float mf_max_bound;
};
} // namespace framework
} // namespace paddle
#endif #endif
} // namespace optimizer_config
...@@ -18,7 +18,6 @@ limitations under the License. */ ...@@ -18,7 +18,6 @@ limitations under the License. */
#include <ctime> #include <ctime>
#include <memory> #include <memory>
#include <numeric> #include <numeric>
#include "paddle/fluid/framework/fleet/heter_ps/optimizer_conf.h"
#include "paddle/fluid/framework/fleet/ps_gpu_wrapper.h" #include "paddle/fluid/framework/fleet/ps_gpu_wrapper.h"
#include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/lod_tensor.h"
#include "xpu/kernel/cluster_header.h" // NOLINT #include "xpu/kernel/cluster_header.h" // NOLINT
...@@ -162,23 +161,7 @@ __global__ void PushCopy(FeaturePushValue* dest, float** src, long long* len, ...@@ -162,23 +161,7 @@ __global__ void PushCopy(FeaturePushValue* dest, float** src, long long* len,
} }
} }
PSGPUWrapper::~PSGPUWrapper() { PSGPUWrapper::~PSGPUWrapper() { delete HeterPs_; }
delete HeterPs_;
xpu_free((void*)optimizer_config::nonclk_coeff);
xpu_free((void*)optimizer_config::clk_coeff);
xpu_free((void*)optimizer_config::min_bound);
xpu_free((void*)optimizer_config::max_bound);
xpu_free((void*)optimizer_config::learning_rate);
xpu_free((void*)optimizer_config::initial_g2sum);
xpu_free((void*)optimizer_config::initial_range);
xpu_free((void*)optimizer_config::mf_create_thresholds);
xpu_free((void*)optimizer_config::mf_learning_rate);
xpu_free((void*)optimizer_config::mf_initial_g2sum);
xpu_free((void*)optimizer_config::mf_initial_range);
xpu_free((void*)optimizer_config::mf_min_bound);
xpu_free((void*)optimizer_config::mf_max_bound);
}
void PSGPUWrapper::CopyForPull(const paddle::platform::Place& place, void PSGPUWrapper::CopyForPull(const paddle::platform::Place& place,
uint64_t** gpu_keys, uint64_t** gpu_keys,
...@@ -272,66 +255,29 @@ void PSGPUWrapper::SetSparseSGD(float nonclk_coeff, float clk_coeff, ...@@ -272,66 +255,29 @@ void PSGPUWrapper::SetSparseSGD(float nonclk_coeff, float clk_coeff,
float min_bound, float max_bound, float min_bound, float max_bound,
float learning_rate, float initial_g2sum, float learning_rate, float initial_g2sum,
float initial_range) { float initial_range) {
xpu_malloc(reinterpret_cast<void**>(&optimizer_config::nonclk_coeff), OptimizerConfig optimizer_config;
sizeof(float)); optimizer_config.nonclk_coeff = nonclk_coeff;
xpu_malloc(reinterpret_cast<void**>(&optimizer_config::clk_coeff), optimizer_config.clk_coeff = clk_coeff;
sizeof(float)); optimizer_config.min_bound = min_bound;
xpu_malloc(reinterpret_cast<void**>(&optimizer_config::min_bound), optimizer_config.max_bound = max_bound;
sizeof(float)); optimizer_config.learning_rate = learning_rate;
xpu_malloc(reinterpret_cast<void**>(&optimizer_config::max_bound), optimizer_config.initial_g2sum = initial_g2sum;
sizeof(float)); optimizer_config.initial_range = initial_range;
xpu_malloc(reinterpret_cast<void**>(&optimizer_config::learning_rate), HeterPs_->set_sparse_sgd(optimizer_config);
sizeof(float));
xpu_malloc(reinterpret_cast<void**>(&optimizer_config::initial_g2sum),
sizeof(float));
xpu_malloc(reinterpret_cast<void**>(&optimizer_config::initial_range),
sizeof(float));
xpu_memcpy((void*)optimizer_config::nonclk_coeff, &nonclk_coeff,
sizeof(float), XPU_HOST_TO_DEVICE);
xpu_memcpy((void*)optimizer_config::clk_coeff, &clk_coeff, sizeof(float),
XPU_HOST_TO_DEVICE);
xpu_memcpy((void*)optimizer_config::min_bound, &min_bound, sizeof(float),
XPU_HOST_TO_DEVICE);
xpu_memcpy((void*)optimizer_config::max_bound, &max_bound, sizeof(float),
XPU_HOST_TO_DEVICE);
xpu_memcpy((void*)optimizer_config::learning_rate, &learning_rate,
sizeof(float), XPU_HOST_TO_DEVICE);
xpu_memcpy((void*)optimizer_config::initial_g2sum, &initial_g2sum,
sizeof(float), XPU_HOST_TO_DEVICE);
xpu_memcpy((void*)optimizer_config::initial_range, &initial_range,
sizeof(float), XPU_HOST_TO_DEVICE);
} }
void PSGPUWrapper::SetEmbedxSGD(float mf_create_thresholds, void PSGPUWrapper::SetEmbedxSGD(float mf_create_thresholds,
float mf_learning_rate, float mf_initial_g2sum, float mf_learning_rate, float mf_initial_g2sum,
float mf_initial_range, float mf_min_bound, float mf_initial_range, float mf_min_bound,
float mf_max_bound) { float mf_max_bound) {
xpu_malloc(reinterpret_cast<void**>(&optimizer_config::mf_create_thresholds), OptimizerConfig optimizer_config;
sizeof(float)); optimizer_config.mf_create_thresholds = mf_create_thresholds;
xpu_malloc(reinterpret_cast<void**>(&optimizer_config::mf_learning_rate), optimizer_config.mf_learning_rate = mf_learning_rate;
sizeof(float)); optimizer_config.mf_initial_g2sum = mf_initial_g2sum;
xpu_malloc(reinterpret_cast<void**>(&optimizer_config::mf_initial_g2sum), optimizer_config.mf_initial_range = mf_initial_range;
sizeof(float)); optimizer_config.mf_min_bound = mf_min_bound;
xpu_malloc(reinterpret_cast<void**>(&optimizer_config::mf_initial_range), optimizer_config.mf_max_bound = mf_max_bound;
sizeof(float)); HeterPs_->set_embedx_sgd(optimizer_config);
xpu_malloc(reinterpret_cast<void**>(&optimizer_config::mf_min_bound),
sizeof(float));
xpu_malloc(reinterpret_cast<void**>(&optimizer_config::mf_max_bound),
sizeof(float));
xpu_memcpy((void*)optimizer_config::mf_create_thresholds,
&mf_create_thresholds, sizeof(float), XPU_HOST_TO_DEVICE);
xpu_memcpy((void*)optimizer_config::mf_initial_g2sum, &mf_initial_g2sum,
sizeof(float), XPU_HOST_TO_DEVICE);
xpu_memcpy((void*)optimizer_config::mf_initial_range, &mf_initial_range,
sizeof(float), XPU_HOST_TO_DEVICE);
xpu_memcpy((void*)optimizer_config::mf_min_bound, &mf_min_bound,
sizeof(float), XPU_HOST_TO_DEVICE);
xpu_memcpy((void*)optimizer_config::mf_max_bound, &mf_max_bound,
sizeof(float), XPU_HOST_TO_DEVICE);
xpu_memcpy((void*)optimizer_config::mf_learning_rate, &mf_learning_rate,
sizeof(float), XPU_HOST_TO_DEVICE);
} }
} // end namespace framework } // end namespace framework
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册