Unverified commit ef6ff4ef, authored by zmxdream and committed by GitHub

[XPUPS]fix hashtable_kernel.kps (#41790)

* refactor heter comm kernel

* update. test=develop

* update calc_shard_offset. test=develop

* update xpu kernel. test=develop

* update args of calc_shard_offset

* update. test=develop

* remove customGradMerger

* update. test=develop

* update. test=develop

* fix. test=develop

* update. test=develop

* update. test=develop

* update optimizer kernel

* update. test=develop

* update. test=develop

* update. test=develop

* update. test=develop

* update. test=develop

* update. test=develop

* update. test=develop

* update. test=develop

* fix. test=develop

* fix. test=develop

* add optimizer kernel. test=develop

* fix. test=develop

* fix. test=develop

* fix. test=develop

* fix. test=develop

* fix kunlun not support size_t. test=develop

* fix. test=develop

* fix. test=develop

* fix. test=develop

* fix. test=develop

* fix. test=develop

* fix. test=develop

* fix. test=develop

* fix. test=develop

* fix. test=develop

* update hashtable. test=develop

* update. test=develop

* fix. test=develop

* fix. test=develop

* fix. test=develop

* fix. test=develop

* fix. test=develop

* fix. test=develop

* fix. test=develop

* update. test=develop

* update. test=develop

* fix. test=develop

* fix. test=develop

* fix. test=develop

* fix. test=develop

* fix. test=develop

* fix. test=develop

* fix. test=develop

* fix. test=develop

* fix. test=develop

* fix. test=develop

* fix. test=develop

* fix. test=develop

* fix. test=develop

* fix. test=develop

* fix. test=develop

* fix. test=develop

* template init. test=develop

* hashtable template init. test=develop

* fix. test=develop

* fix. test=develop

* fix. test=develop

* fix. test=develop

* fix. test=develop

* fix. test=develop

* fix. test=develop

* fix. test=develop

* fix. test=develop

* fix. test=develop

* fix. test=develop

* fix. test=develop

* fix. test=develop

* fix. test=develop

* fix. test=develop

* fix. test=develop

* fix. test=develop

* fix. test=develop

* fix. test=develop

* fix. test=develop

* fix hashtable_kernel. test=develop

* fix. test=develop

* fix. test=develop

* fix. test=develop

* fix. test=develop
Co-authored-by: WorgenZhang <frank08081993@gmail.com>
Parent d7224482
@@ -74,7 +74,7 @@ class XPUCacheArray {
   // ValType* find(const KeyType& key) { return NULL; }
   // bool insert(const KeyType& key, const ValType& val) { return true; }
-  int prefetch(const int dev_id, XPUStream stream = NULL) {}
+  int prefetch(const int dev_id, XPUStream stream = NULL) { return 0; }
   size_t size() { return size_; }

  private:
...
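The `prefetch` change is a correctness fix: in C++, flowing off the end of a function declared to return non-void is undefined behavior once a caller uses the result, and many toolchains only warn about it. A minimal sketch of the failure mode and the fix, using a hypothetical Cache type rather than the actual XPUCacheArray:

// Sketch only: why the no-op stub still needs an explicit return value.
struct Cache {
  // Before: declared to return int but returns nothing -- undefined
  // behavior if any caller reads the result.
  // int prefetch(int dev_id) {}

  // After: a harmless default return keeps the stub well defined.
  int prefetch(int dev_id) { return 0; }
};

int main() {
  Cache c;
  return c.prefetch(/*dev_id=*/0);  // safe: always returns 0
}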
@@ -38,7 +38,7 @@ namespace framework {
 #if defined(PADDLE_WITH_XPU_KP)

-__device__ void update_lr(float* w, float* g2sum, float g,  // NOLINT
+__device__ void update_lr(float& w, float& g2sum, float g,  // NOLINT
                           float scale) {
   __local__ float local_learning_rate;
   __local__ float local_initial_g2sum;
@@ -55,17 +55,17 @@ __device__ void update_lr(float* w, float* g2sum, float g,  // NOLINT
       sqrt(local_initial_g2sum / (local_initial_g2sum + g2sum));
   double scaled_grad = g / scale;

-  (*w) += scaled_grad * ratio;
+  w += scaled_grad * ratio;

   if (w < local_min_bound) w = local_min_bound;
   if (w > local_max_bound) w = local_max_bound;

   add_g2sum += scaled_grad * scaled_grad;

-  (*g2sum) += add_g2sum;
+  g2sum += add_g2sum;
 }

-__device__ void update_mf(int n, float* w, float* g2sum, const float* g,
+__device__ void update_mf(int n, float* w, float& g2sum, const float* g,
                           float scale) {
   __local__ float local_mf_learning_rate;
   __local__ float local_mf_initial_g2sum;
@@ -92,16 +92,16 @@ __device__ void update_mf(int n, float* w, float* g2sum, const float* g,
     add_g2sum += scaled_grad * scaled_grad;
   }

-  (*g2sum) += add_g2sum / n;
+  g2sum += add_g2sum / n;
 }

 __device__ float xpu_rand_uniform() { return 0.1; }

 template <typename ValType, typename GradType>
-__device__ void update_value(ValType* val, const GradType* grad) {  // NOLINT
-  (*val).slot = (*grad).slot;
-  (*val).show += (*grad).show;
-  (*val).clk += (*grad).clk;
+__device__ void update_value(ValType& val, const GradType& grad) {  // NOLINT
+  val.slot = grad.slot;
+  val.show += grad.show;
+  val.clk += grad.clk;

   __local__ float local_nonclk_coeff;
   __local__ float local_clk_coeff;
@@ -114,25 +114,23 @@ __device__ void update_value(ValType* val, const GradType* grad) {  // NOLINT
   GM2LM(optimizer_config::mf_create_thresholds, &local_mf_create_thresholds,
         sizeof(float));

-  val.delta_score += local_nonclk_coeff * ((*grad).show - (*grad).clk) +
-                     local_clk_coeff * (*grad).clk;
+  val.delta_score +=
+      local_nonclk_coeff * (grad.show - grad.clk) + local_clk_coeff * grad.clk;

-  update_lr(&(*val).lr, &(*val).lr_g2sum, (*grad).lr_g, (*grad).show);
+  update_lr(val.lr, val.lr_g2sum, grad.lr_g, grad.show);

   if (val.mf_size == 0) {
     if (local_mf_create_thresholds <=
-        local_nonclk_coeff * ((*val).show - (*val).clk) +
-            local_clk_coeff * (*val).clk) {
+        local_nonclk_coeff * (val.show - val.clk) + local_clk_coeff * val.clk) {
       val.mf_size = MF_DIM + 1;
       val.mf[0] = 0;
-      xpu_rand_uniform(&);
       for (int i = 0; i < MF_DIM; ++i) {
-        (*val).mf[i + 1] = (xpu_rand_uniform()) * local_mf_initial_range;
+        val.mf[i + 1] = (xpu_rand_uniform()) * local_mf_initial_range;
       }
     }
   } else {
-    update_mf(MF_DIM, &val.mf[1], &val.mf[0], (*grad).mf_g, (*grad).show);
+    update_mf(MF_DIM, &val.mf[1], val.mf[0], grad.mf_g, grad.show);
   }
 }
...
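The dominant change in these optimizer kernels is mechanical: in/out parameters move from pointers to references, which removes the `(*p)` noise, rules out accidental pointer comparison (note the old code compared `w` itself against the bounds), and lets calls such as `update_lr(val.lr, ...)` read like the math they implement; the stray, unparsable call `xpu_rand_uniform(&);` is dropped outright. A host-side sketch of the same refactor, with made-up clipping bounds standing in for the config-driven ones:

#include <algorithm>

// Before: out-params as pointers; every use needs an explicit dereference.
void update_lr_ptr(float* w, float* g2sum, float g, float scale) {
  float scaled = g / scale;
  *w = std::min(std::max(*w + scaled, -10.0f), 10.0f);  // bounds are illustrative
  *g2sum += scaled * scaled;
}

// After: references keep the body and the call sites free of * and &.
void update_lr_ref(float& w, float& g2sum, float g, float scale) {
  float scaled = g / scale;
  w = std::min(std::max(w + scaled, -10.0f), 10.0f);
  g2sum += scaled * scaled;
}

int main() {
  float w = 0.5f, g2sum = 0.0f;
  update_lr_ref(w, g2sum, /*g=*/0.2f, /*scale=*/1.0f);  // values passed directly
  return 0;
}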
@@ -92,6 +92,7 @@ class HeterComm {
     nccl_inter_comms_ = inter_comms;
     node_size_ = comm_size;
   }
+#endif

   bool need_transfer(int send_id, int receive_id) {
     return ((send_id / 4 != receive_id / 4) && (send_id + 4) % 8 != receive_id);
@@ -101,8 +102,6 @@ class HeterComm {

   int get_transfer_devid(int send_id) { return (send_id + 4) % 8; }

-#endif
-
   void end_pass();

   struct Node {
...
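The two `#endif` edits in HeterComm shrink a conditional-compilation block so that only the NCCL-communicator setter stays backend-specific, while `need_transfer` and `get_transfer_devid` become visible to every backend, including the XPU build this commit targets. A generic sketch of the pattern, with a hypothetical `USE_NCCL` macro standing in for Paddle's actual guard:

class Comm {
 public:
#if defined(USE_NCCL)
  // Backend-specific: compiled only when NCCL is available.
  void set_nccl_comm(int comm_size) { node_size_ = comm_size; }
#endif

  // Backend-agnostic helpers: placed outside the guard so all builds see them.
  bool need_transfer(int send, int recv) {
    return (send / 4 != recv / 4) && (send + 4) % 8 != recv;
  }
  int get_transfer_devid(int send) { return (send + 4) % 8; }

 private:
  int node_size_ = 0;
};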
@@ -161,8 +161,8 @@ void HeterComm<KeyType, ValType, GradType>::destroy_storage(int start_index,
                           nodes[i].key_storage);
     allocator->DeviceFree(resource_->dev_id(nodes[i].dev_num),
                           nodes[i].val_storage);
-#endif
   }
+#endif
 }

 template <typename KeyType, typename ValType, typename GradType>
@@ -804,9 +804,9 @@ void HeterComm<KeyType, ValType, GradType>::push_sparse(int dev_num,
   auto dst_place = platform::CPUPlace();
   auto src_place = place;
   memory_copy(dst_place, h_left, src_place, d_left_ptr,
-              total_device * sizeof(int));
+              total_device * sizeof(int), stream);
   memory_copy(dst_place, h_right, src_place, d_right_ptr,
-              total_device * sizeof(int));
+              total_device * sizeof(int), stream);

   for (int i = 0; i < total_device; ++i) {
     int shard_len = h_right[i] - h_left[i] + 1;
...
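In `push_sparse`, `memory_copy` now receives the stream, so the device-to-host copies of the shard boundaries are enqueued behind the kernels already issued on that stream instead of on a default queue. A hedged sketch of the idea in CUDA terms (the real `memory_copy` is Paddle's portable wrapper; the helper below is illustrative):

#include <cuda_runtime.h>

// A copy issued on the same stream as the preceding kernels is ordered
// after them; a copy on the default stream would not be, in general.
void copy_shard_bounds(int* h_left, const int* d_left, int total_device,
                       cudaStream_t stream) {
  cudaMemcpyAsync(h_left, d_left, total_device * sizeof(int),
                  cudaMemcpyDeviceToHost, stream);
  // The host must not read h_left until the stream has drained.
  cudaStreamSynchronize(stream);
}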
@@ -236,55 +236,62 @@ __global__ void fill_dvals_kernel(ValType* d_shard_vals, ValType* d_vals,
 // xpu implementation of heter_comm_kernel.h

 template <typename T, typename StreamType>
-void fill_idx(T* idx, long long len, const StreamType& stream) {
+void HeterCommKernel::fill_idx(T* idx, long long len,
+                               const StreamType& stream) {
   fill_idx_kernel<T><<<4, 64, stream>>>(idx, len);
 }

 template <typename T, typename StreamType>
-void calc_shard_offset(T* idx, T* left, T* right, long long len, int total_devs,
-                       const StreamType& stream) {
+void HeterCommKernel::calc_shard_offset(T* idx, T* left, T* right,
+                                        long long len, int total_devs,
+                                        const StreamType& stream) {
   calc_shard_offset_kernel<T><<<4, 64, stream>>>(idx, left, right, len,
                                                  total_devs);
 }

 template <typename KeyType, typename T, typename StreamType>
-void calc_shard_index(KeyType* d_keys, long long len, T* shard_index,
-                      int total_devs, const StreamType& stream) {
+void HeterCommKernel::calc_shard_index(KeyType* d_keys, long long len,
+                                       T* shard_index, int total_devs,
+                                       const StreamType& stream) {
   calc_shard_index_kernel<KeyType, T><<<4, 64, stream>>>(
       d_keys, len, shard_index, total_devs);
 }

 template <typename KeyType, typename T, typename StreamType>
-void fill_shard_key(KeyType* d_shard_keys, KeyType* d_keys, T* idx,
-                    long long len, const StreamType& stream) {
+void HeterCommKernel::fill_shard_key(KeyType* d_shard_keys, KeyType* d_keys,
+                                     T* idx, long long len,
+                                     const StreamType& stream) {
   fill_shard_key_kernel<KeyType, T><<<4, 64, stream>>>(d_shard_keys, d_keys,
                                                        idx, len);
 }

 template <typename KeyType, typename GradType, typename T, typename StreamType>
-void fill_shard_grads(KeyType* d_shard_keys, KeyType* d_keys,
-                      GradType* d_shard_grads, GradType* d_grads, T* idx,
-                      long long len, const StreamType& stream) {
+void HeterCommKernel::fill_shard_grads(KeyType* d_shard_keys, KeyType* d_keys,
+                                       GradType* d_shard_grads,
+                                       GradType* d_grads, T* idx, long long len,
+                                       const StreamType& stream) {
   fill_shard_grads_kernel<KeyType, GradType, T><<<4, 64, stream>>>(
       d_shard_keys, d_keys, d_shard_grads, d_grads, idx, len);
 }

 template <typename ValType, typename T, typename StreamType>
-void fill_dvals(ValType* d_shard_vals, ValType* d_vals, T* idx, long long len,
-                const StreamType& stream) {
+void HeterCommKernel::fill_dvals(ValType* d_shard_vals, ValType* d_vals, T* idx,
+                                 long long len, const StreamType& stream) {
   fill_dvals_kernel<ValType, T><<<4, 64, stream>>>(d_shard_vals, d_vals, idx,
                                                    len);
 }

 template <typename KeyT, typename ValueT, typename StreamType>
-void sort_pairs(void* d_temp_storage, size_t& temp_storage_bytes,  // NOLINT
-                const KeyT* d_keys_in,                             // NOLINT
-                KeyT* d_keys_out, const ValueT* d_values_in,
-                ValueT* d_values_out, int num_items, int begin_bit, int end_bit,
-                StreamType stream, bool debug_synchronous) {}
+void HeterCommKernel::sort_pairs(void* d_temp_storage,
+                                 size_t& temp_storage_bytes,  // NOLINT
+                                 const KeyT* d_keys_in,       // NOLINT
+                                 KeyT* d_keys_out, const ValueT* d_values_in,
+                                 ValueT* d_values_out, int num_items,
+                                 int begin_bit, int end_bit, StreamType stream,
+                                 bool debug_synchronous) {}

 template <typename KeysInputIteratorT, typename UniqueOutputIteratorT,
-void reduce_by_key(
+void HeterCommKernel::reduce_by_key(
     void* d_temp_storage,
     size_t& temp_storage_bytes,  // NOLINT
     KeysInputIteratorT d_keys_in, UniqueOutputIteratorT d_unique_out,
@@ -293,31 +300,32 @@ template <typename KeysInputIteratorT, typename UniqueOutputIteratorT,
     NumRunsOutputIteratorT d_num_runs_out, int num_items,
     StreamType stream, bool debug_synchronous) {}

-template void fill_idx<int, XPUStream>(int* idx, long long len,
-                                       const XPUStream& stream);
+template void HeterCommKernel::fill_idx<int, XPUStream>(
+    int* idx, long long len, const XPUStream& stream);

-template void calc_shard_offset<int, XPUStream>(int* idx, int* left, int* right,
-                                                long long len, int total_devs,
-                                                const XPUStream& stream);
+template void HeterCommKernel::calc_shard_offset<int, XPUStream>(
+    int* idx, int* left, int* right, long long len, int total_devs,
+    const XPUStream& stream);

-template void calc_shard_index<unsigned long, int, XPUStream>(
+template void HeterCommKernel::calc_shard_index<unsigned long, int, XPUStream>(
     unsigned long* d_keys, long long len, int* shard_index, int total_devs,
     const XPUStream& stream);

-template void fill_shard_key<unsigned long, int, XPUStream>(
+template void HeterCommKernel::fill_shard_key<unsigned long, int, XPUStream>(
     unsigned long* d_shard_keys, unsigned long* d_keys, int* idx, long long len,
     const XPUStream& stream);

-template void
-fill_shard_grads<unsigned long, paddle::framework::FeaturePushValue, int,
-                 XPUStream>(unsigned long* d_shard_keys, unsigned long* d_keys,
-                            paddle::framework::FeaturePushValue* d_shard_grads,
-                            paddle::framework::FeaturePushValue* d_grads,
-                            int* idx, long long len, const XPUStream& stream);
+template void HeterCommKernel::fill_shard_grads<
+    unsigned long, paddle::framework::FeaturePushValue, int, XPUStream>(
+    unsigned long* d_shard_keys, unsigned long* d_keys,
+    paddle::framework::FeaturePushValue* d_shard_grads,
+    paddle::framework::FeaturePushValue* d_grads, int* idx, long long len,
+    const XPUStream& stream);

-template void fill_dvals<paddle::framework::FeatureValue, int, XPUStream>(
+template void
+HeterCommKernel::fill_dvals<paddle::framework::FeatureValue, int, XPUStream>(
     paddle::framework::FeatureValue* d_shard_vals,
     paddle::framework::FeatureValue* d_vals, int* idx, long long len,
     const XPUStream& stream);

-template void
-sort_pairs<unsigned long, paddle::framework::FeaturePushValue, XPUStream>(
+template void HeterCommKernel::sort_pairs<
+    unsigned long, paddle::framework::FeaturePushValue, XPUStream>(
     void* d_temp_storage,
     size_t& temp_storage_bytes,  // NOLINT
     const unsigned long* d_keys_in,  // NOLINT
@@ -326,14 +334,14 @@ sort_pairs<unsigned long, paddle::framework::FeaturePushValue, XPUStream>(
     paddle::framework::FeaturePushValue* d_values_out, int num_items,
     int begin_bit, int end_bit, XPUStream stream, bool debug_synchronous);

-template void sort_pairs<int, int, XPUStream>(
+template void HeterCommKernel::sort_pairs<int, int, XPUStream>(
     void* d_temp_storage,
     size_t& temp_storage_bytes,  // NOLINT
     const int* d_keys_in,        // NOLINT
     int* d_keys_out, const int* d_values_in, int* d_values_out, int num_items,
     int begin_bit, int end_bit, XPUStream stream, bool debug_synchronous);

-template void reduce_by_key<
+template void HeterCommKernel::reduce_by_key<
     unsigned long*, unsigned long*, paddle::framework::FeaturePushValue*,
     paddle::framework::FeaturePushValue*, int*, XPUStream>(
     void* d_temp_storage,
...
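The pattern running through this whole file: the former free functions become members of `HeterCommKernel`, so every definition in the `.kps` translation unit is qualified with `HeterCommKernel::` and then explicitly instantiated for the key/value types the XPU build actually uses; without those instantiation lines, callers in other translation units would hit undefined symbols at link time. A stripped-down sketch of the same structure, using a hypothetical `Kernel` class rather than the Paddle sources:

// kernel.h -- declares a member-function template.
struct Kernel {
  template <typename T>
  void fill_idx(T* idx, long long len);
};

// kernel.cc -- defines it out of line and explicitly instantiates the
// specializations that other translation units will link against.
template <typename T>
void Kernel::fill_idx(T* idx, long long len) {
  for (long long i = 0; i < len; ++i) idx[i] = static_cast<T>(i);
}

// Without this line, callers in other .cc files get undefined symbols.
template void Kernel::fill_idx<int>(int* idx, long long len);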