未验证 提交 ef6ff4ef 编写于 作者: Z zmxdream 提交者: GitHub

[XPUPS]fix hashtable_kernel.kps (#41790)

* refactor heter comm kernel

* update. test=develop

* update calc_shard_offset. test=develop

* update xpu kernel. test=develop

* update args of calc_shard_offset

* update. test=develop

* remove customGradMerger

* update. test=develop

* update. test=develop

* fix. test=develop

* update. test=develop

* update. test=develop

* update optimizer kernel

* update. test=develop

* update. test=develop

* update. test=develop

* update. test=develop

* update. test=develop

* update. test=develop

* update. test=develop

* update. test=develop

* fix. test=develop

* fix. test=develop

* add optimizer kernel. test=develop

* fix. test=develop

* fix. test=develop

* fix. test=develop

* fix. test=develop

* fix kunlun not support size_t. test=develop

* fix. test=develop

* fix. test=develop

* fix. test=develop

* fix. test=develop

* fix. test=develop

* fix. test=develop

* fix. test=develop

* fix. test=develop

* fix. test=develop

* update hashtable. test=develop

* update. test=develop

* fix. test=develop

* fix. test=develop

* fix. test=develop

* fix. test=develop

* fix. test=develop

* fix. test=develop

* fix. test=develop

* update. test=develop

* update. test=develop

* fix. test=develop

* fix. test=develop

* fix. test=develop

* fix. test=develop

* fix. test=develop

* fix. test=develop

* fix. test=develop

* fix. test=develop

* fix. test=develop

* fix. test=develop

* fix. test=develop

* fix. test=develop

* fix. test=develop

* fix. test=develop

* fix. test=develop

* fix. test=develop

* template init. test=develop

* hashtable template init. test=develop

* fix. test=develop

* fix. test=devlop

* fix. test=develop

* fix. test=develop

* fix. test=develop

* fix. test=develop

* fix. test=develop

* fix. test=develop

* fix. test=develop

* fix. test=develop

* fix. test=develop

* fix. test=develop

* fix. test=develop

* fix. test=develop

* fix. test=develop

* fix. test=develop

* fix. test=develop

* fix. test=develop

* fix. test=develop

* fix. test=develop

* fix hashtable_kernel. test=develop

* fix. test=develop

* fix. test=develop

* fix. test=develop

* fix. test=develop
Co-authored-by: NWorgenZhang <frank08081993@gmail.com>
上级 d7224482
......@@ -74,7 +74,7 @@ class XPUCacheArray {
// ValType* find(const KeyType& key) { return NULL; }
// bool insert(const KeyType& key, const ValType& val) { return true; }
int prefetch(const int dev_id, XPUStream stream = NULL) {}
int prefetch(const int dev_id, XPUStream stream = NULL) { return 0; }
size_t size() { return size_; }
private:
......
......@@ -38,7 +38,7 @@ namespace framework {
#if defined(PADDLE_WITH_XPU_KP)
__device__ void update_lr(float* w, float* g2sum, float g, // NOLINT
__device__ void update_lr(float& w, float& g2sum, float g, // NOLINT
float scale) {
__local__ float local_learning_rate;
__local__ float local_initial_g2sum;
......@@ -55,17 +55,17 @@ __device__ void update_lr(float* w, float* g2sum, float g, // NOLINT
sqrt(local_initial_g2sum / (local_initial_g2sum + g2sum));
double scaled_grad = g / scale;
(*w) += scaled_grad * ratio;
w += scaled_grad * ratio;
if (w < local_min_bound) w = local_min_bound;
if (w > local_max_bound) w = local_max_bound;
add_g2sum += scaled_grad * scaled_grad;
(*g2sum) += add_g2sum;
g2sum += add_g2sum;
}
__device__ void update_mf(int n, float* w, float* g2sum, const float* g,
__device__ void update_mf(int n, float* w, float& g2sum, const float* g,
float scale) {
__local__ float local_mf_learning_rate;
__local__ float local_mf_initial_g2sum;
......@@ -92,16 +92,16 @@ __device__ void update_mf(int n, float* w, float* g2sum, const float* g,
add_g2sum += scaled_grad * scaled_grad;
}
(*g2sum) += add_g2sum / n;
g2sum += add_g2sum / n;
}
__device__ float xpu_rand_uniform() { return 0.1; }
template <typename ValType, typename GradType>
__device__ void update_value(ValType* val, const GradType* grad) { // NOLINT
(*val).slot = (*grad).slot;
(*val).show += (*grad).show;
(*val).clk += (*grad).clk;
__device__ void update_value(ValType& val, const GradType& grad) { // NOLINT
val.slot = grad.slot;
val.show += grad.show;
val.clk += grad.clk;
__local__ float local_nonclk_coeff;
__local__ float local_clk_coeff;
......@@ -114,25 +114,23 @@ __device__ void update_value(ValType* val, const GradType* grad) { // NOLINT
GM2LM(optimizer_config::mf_create_thresholds, &local_mf_create_thresholds,
sizeof(float));
val.delta_score += local_nonclk_coeff * ((*grad).show - (*grad).clk) +
local_clk_coeff * (*grad).clk;
val.delta_score +=
local_nonclk_coeff * (grad.show - grad.clk) + local_clk_coeff * grad.clk;
update_lr(&(*val).lr, &(*val).lr_g2sum, (*grad).lr_g, (*grad).show);
update_lr(val.lr, val.lr_g2sum, grad.lr_g, grad.show);
if (val.mf_size == 0) {
if (local_mf_create_thresholds <=
local_nonclk_coeff * ((*val).show - (*val).clk) +
local_clk_coeff * (*val).clk) {
local_nonclk_coeff * (val.show - val.clk) + local_clk_coeff * val.clk) {
val.mf_size = MF_DIM + 1;
val.mf[0] = 0;
xpu_rand_uniform(&);
for (int i = 0; i < MF_DIM; ++i) {
(*val).mf[i + 1] = (xpu_rand_uniform()) * local_mf_initial_range;
val.mf[i + 1] = (xpu_rand_uniform()) * local_mf_initial_range;
}
}
} else {
update_mf(MF_DIM, &val.mf[1], &val.mf[0], (*grad).mf_g, (*grad).show);
update_mf(MF_DIM, &val.mf[1], val.mf[0], grad.mf_g, grad.show);
}
}
......
......@@ -92,6 +92,7 @@ class HeterComm {
nccl_inter_comms_ = inter_comms;
node_size_ = comm_size;
}
#endif
bool need_transfer(int send_id, int receive_id) {
return ((send_id / 4 != receive_id / 4) && (send_id + 4) % 8 != receive_id);
......@@ -101,8 +102,6 @@ class HeterComm {
int get_transfer_devid(int send_id) { return (send_id + 4) % 8; }
#endif
void end_pass();
struct Node {
......
......@@ -161,8 +161,8 @@ void HeterComm<KeyType, ValType, GradType>::destroy_storage(int start_index,
nodes[i].key_storage);
allocator->DeviceFree(resource_->dev_id(nodes[i].dev_num),
nodes[i].val_storage);
#endif
}
#endif
}
template <typename KeyType, typename ValType, typename GradType>
......@@ -804,9 +804,9 @@ void HeterComm<KeyType, ValType, GradType>::push_sparse(int dev_num,
auto dst_place = platform::CPUPlace();
auto src_place = place;
memory_copy(dst_place, h_left, src_place, d_left_ptr,
total_device * sizeof(int));
total_device * sizeof(int), stream);
memory_copy(dst_place, h_right, src_place, d_right_ptr,
total_device * sizeof(int));
total_device * sizeof(int), stream);
for (int i = 0; i < total_device; ++i) {
int shard_len = h_right[i] - h_left[i] + 1;
......
......@@ -236,55 +236,62 @@ __global__ void fill_dvals_kernel(ValType* d_shard_vals, ValType* d_vals,
// xpu implementation of heter_comm_kernel.h
template <typename T, typename StreamType>
void fill_idx(T* idx, long long len, const StreamType& stream) {
void HeterCommKernel::fill_idx(T* idx, long long len,
const StreamType& stream) {
fill_idx_kernel<T><<<4, 64, stream>>>(idx, len);
}
template <typename T, typename StreamType>
void calc_shard_offset(T* idx, T* left, T* right, long long len, int total_devs,
const StreamType& stream) {
void HeterCommKernel::calc_shard_offset(T* idx, T* left, T* right,
long long len, int total_devs,
const StreamType& stream) {
calc_shard_offset_kernel<T><<<4, 64, stream>>>(idx, left, right, len,
total_devs);
}
template <typename KeyType, typename T, typename StreamType>
void calc_shard_index(KeyType* d_keys, long long len, T* shard_index,
int total_devs, const StreamType& stream) {
void HeterCommKernel::calc_shard_index(KeyType* d_keys, long long len,
T* shard_index, int total_devs,
const StreamType& stream) {
calc_shard_index_kernel<KeyType, T><<<4, 64, stream>>>(
d_keys, len, shard_index, total_devs);
}
template <typename KeyType, typename T, typename StreamType>
void fill_shard_key(KeyType* d_shard_keys, KeyType* d_keys, T* idx,
long long len, const StreamType& stream) {
void HeterCommKernel::fill_shard_key(KeyType* d_shard_keys, KeyType* d_keys,
T* idx, long long len,
const StreamType& stream) {
fill_shard_key_kernel<KeyType, T><<<4, 64, stream>>>(d_shard_keys, d_keys,
idx, len);
}
template <typename KeyType, typename GradType, typename T, typename StreamType>
void fill_shard_grads(KeyType* d_shard_keys, KeyType* d_keys,
GradType* d_shard_grads, GradType* d_grads, T* idx,
long long len, const StreamType& stream) {
void HeterCommKernel::fill_shard_grads(KeyType* d_shard_keys, KeyType* d_keys,
GradType* d_shard_grads,
GradType* d_grads, T* idx, long long len,
const StreamType& stream) {
fill_shard_grads_kernel<KeyType, GradType, T><<<4, 64, stream>>>(
d_shard_keys, d_keys, d_shard_grads, d_grads, idx, len);
}
template <typename ValType, typename T, typename StreamType>
void fill_dvals(ValType* d_shard_vals, ValType* d_vals, T* idx, long long len,
const StreamType& stream) {
void HeterCommKernel::fill_dvals(ValType* d_shard_vals, ValType* d_vals, T* idx,
long long len, const StreamType& stream) {
fill_dvals_kernel<ValType, T><<<4, 64, stream>>>(d_shard_vals, d_vals, idx,
len);
}
template <typename KeyT, typename ValueT, typename StreamType>
void sort_pairs(void* d_temp_storage, size_t& temp_storage_bytes, // NOLINT
const KeyT* d_keys_in, // NOLINT
KeyT* d_keys_out, const ValueT* d_values_in,
ValueT* d_values_out, int num_items, int begin_bit, int end_bit,
StreamType stream, bool debug_synchronous) {}
void HeterCommKernel::sort_pairs(void* d_temp_storage,
size_t& temp_storage_bytes, // NOLINT
const KeyT* d_keys_in, // NOLINT
KeyT* d_keys_out, const ValueT* d_values_in,
ValueT* d_values_out, int num_items,
int begin_bit, int end_bit, StreamType stream,
bool debug_synchronous) {}
template <typename KeysInputIteratorT, typename UniqueOutputIteratorT,
void reduce_by_key(
void HeterCommKernel::reduce_by_key(
void* d_temp_storage,
size_t& temp_storage_bytes, // NOLINT
KeysInputIteratorT d_keys_in, UniqueOutputIteratorT d_unique_out,
......@@ -293,31 +300,32 @@ template <typename KeysInputIteratorT, typename UniqueOutputIteratorT,
NumRunsOutputIteratorT d_num_runs_out, int num_items,
StreamType stream, bool debug_synchronous) {}
template void fill_idx<int, XPUStream>(int* idx, long long len,
const XPUStream& stream);
template void calc_shard_offset<int, XPUStream>(int* idx, int* left, int* right,
long long len, int total_devs,
const XPUStream& stream);
template void calc_shard_index<unsigned long, int, XPUStream>(
template void HeterCommKernel::fill_idx<int, XPUStream>(
int* idx, long long len, const XPUStream& stream);
template void HeterCommKernel::calc_shard_offset<int, XPUStream>(
int* idx, int* left, int* right, long long len, int total_devs,
const XPUStream& stream);
template void HeterCommKernel::calc_shard_index<unsigned long, int, XPUStream>(
unsigned long* d_keys, long long len, int* shard_index, int total_devs,
const XPUStream& stream);
template void fill_shard_key<unsigned long, int, XPUStream>(
template void HeterCommKernel::fill_shard_key<unsigned long, int, XPUStream>(
unsigned long* d_shard_keys, unsigned long* d_keys, int* idx, long long len,
const XPUStream& stream);
template void HeterCommKernel::fill_shard_grads<
unsigned long, paddle::framework::FeaturePushValue, int, XPUStream>(
unsigned long* d_shard_keys, unsigned long* d_keys,
paddle::framework::FeaturePushValue* d_shard_grads,
paddle::framework::FeaturePushValue* d_grads, int* idx, long long len,
const XPUStream& stream);
template void
fill_shard_grads<unsigned long, paddle::framework::FeaturePushValue, int,
XPUStream>(unsigned long* d_shard_keys, unsigned long* d_keys,
paddle::framework::FeaturePushValue* d_shard_grads,
paddle::framework::FeaturePushValue* d_grads,
int* idx, long long len, const XPUStream& stream);
template void fill_dvals<paddle::framework::FeatureValue, int, XPUStream>(
HeterCommKernel::fill_dvals<paddle::framework::FeatureValue, int, XPUStream>(
paddle::framework::FeatureValue* d_shard_vals,
paddle::framework::FeatureValue* d_vals, int* idx, long long len,
const XPUStream& stream);
template void
sort_pairs<unsigned long, paddle::framework::FeaturePushValue, XPUStream>(
template void HeterCommKernel::sort_pairs<
unsigned long, paddle::framework::FeaturePushValue, XPUStream>(
void* d_temp_storage,
size_t& temp_storage_bytes, // NOLINT
const unsigned long* d_keys_in, // NOLINT
......@@ -326,14 +334,14 @@ sort_pairs<unsigned long, paddle::framework::FeaturePushValue, XPUStream>(
paddle::framework::FeaturePushValue* d_values_out, int num_items,
int begin_bit, int end_bit, XPUStream stream, bool debug_synchronous);
template void sort_pairs<int, int, XPUStream>(
template void HeterCommKernel::sort_pairs<int, int, XPUStream>(
void* d_temp_storage,
size_t& temp_storage_bytes, // NOLINT
const int* d_keys_in, // NOLINT
int* d_keys_out, const int* d_values_in, int* d_values_out, int num_items,
int begin_bit, int end_bit, XPUStream stream, bool debug_synchronous);
template void reduce_by_key<
template void HeterCommKernel::reduce_by_key<
unsigned long*, unsigned long*, paddle::framework::FeaturePushValue*,
paddle::framework::FeaturePushValue*, int*, XPUStream>(
void* d_temp_storage,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册