未验证 提交 a7a4843c 编写于 作者: Z zmxdream 提交者: GitHub

[GPUPS]Optimize dymf kernel (#43911)

上级 aa45f931
......@@ -89,31 +89,42 @@ __global__ void dy_mf_search_kernel(Table* table,
char* vals,
size_t len,
size_t pull_feature_value_size) {
const size_t i = blockIdx.x * blockDim.x + threadIdx.x;
// return;
const size_t i = blockIdx.x * blockDim.y + threadIdx.y;
const size_t k = threadIdx.x;
if (i < len) {
auto it = table->find(keys[i]);
if (it != table->end()) {
uint64_t offset = i * pull_feature_value_size;
FeatureValue* cur = (FeatureValue*)(vals + offset);
FeatureValue& input = *(FeatureValue*)(it->second);
cur->slot = input.slot;
cur->show = input.show;
cur->clk = input.clk;
cur->mf_dim = input.mf_dim;
cur->lr = input.lr;
cur->mf_size = input.mf_size;
cur->cpu_ptr = input.cpu_ptr;
cur->delta_score = input.delta_score;
cur->lr_g2sum = input.lr_g2sum;
for (int j = 0; j < cur->mf_dim + 1; ++j) {
cur->mf[j] = input.mf[j];
char* cur_p = (char*)cur;
char* input_p = (char*)(&input);
int len = 9 + input.mf_dim + 1;
if (k == 3 || k == 6 || k == 7)
*(int*)(cur_p + k * 4) = *(int*)(input_p + k * 4);
else if (k < 8)
*(float*)(cur_p + k * 4) = *(float*)(input_p + k * 4);
else if (k == 8) {
*(uint64_t*)(cur_p + k * 4) = *(uint64_t*)(input_p + k * 4);
} else {
int len_per_thread = (len - 9) / (blockDim.y - 9);
int remain = (len - 9) % (blockDim.y - 9);
int real_len = len_per_thread;
if ((k - 9) < remain) real_len++;
int left = -1, right = -1;
if ((k - 9) < remain) {
left = 9 + (k - 9) * (len_per_thread + 1);
right = left + real_len;
} else {
left = 9 + remain * (len_per_thread + 1) +
(k - 9 - remain) * len_per_thread;
right = left + real_len;
}
for (int j = left; j < right; j++)
*(float*)(cur_p + (j + 1) * 4) = *(float*)(input_p + (j + 1) * 4);
}
} else {
if (keys[i] != 0) {
printf("warning::pull miss key: %llu", keys[i]);
}
if (keys[i] != 0) printf("pull miss key: %llu", keys[i]);
}
}
}
......@@ -220,8 +231,10 @@ void HashTable<KeyType, ValType>::get(const KeyType* d_keys,
if (len == 0) {
return;
}
const int grid_size = (len - 1) / BLOCK_SIZE_ + 1;
dy_mf_search_kernel<<<grid_size, BLOCK_SIZE_, 0, stream>>>(
dim3 block_dims(32, 32);
const int grid_size = (len - 1) / 32 + 1;
dim3 grid_dims(grid_size);
dy_mf_search_kernel<<<grid_dims, block_dims, 0, stream>>>(
container_, d_keys, d_vals, len, pull_feature_value_size_);
}
......
......@@ -760,6 +760,7 @@ void HeterComm<KeyType, ValType, GradType>::dynamic_merge_grad(
(char*)d_grads,
(char*)d_merge_grads_ptr,
uniq_len,
max_mf_dim_,
grad_value_size,
merger_,
stream);
......
......@@ -144,28 +144,106 @@ __global__ void dy_mf_fill_shard_grads_kernel(KeyType* d_shard_keys,
}
}
__global__ void merge_gradients_kernel(const uint32_t* offset,
const uint32_t* fea_num,
const uint32_t* index,
const char* input,
char* output,
int n,
size_t grad_value_size,
DynamicGradMerger& merger_) {
// optimized version
template <>
__global__ void
dy_mf_fill_shard_grads_kernel<FeatureKey, FeaturePushValue, int>(
FeatureKey* d_shard_keys,
FeatureKey* d_keys,
FeaturePushValue* d_shard_grads,
FeaturePushValue* d_grads,
int* idx,
size_t len,
size_t grad_value_size) {
const size_t i = blockIdx.x * blockDim.y + threadIdx.y;
const size_t k = threadIdx.x;
if (i < len) {
if (k == 0) {
d_shard_keys[i] = d_keys[idx[i]];
}
FeaturePushValue* cur =
(FeaturePushValue*)((char*)d_shard_grads + i * grad_value_size);
FeaturePushValue& input = *(
FeaturePushValue*)((char*)d_grads + uint64_t(idx[i]) * grad_value_size);
char* cur_p = (char*)cur;
char* input_p = (char*)(&input);
int len = 5 + input.mf_dim;
if (k == 2 || k == 4)
*(int*)(cur_p + k * 4) = *(int*)(input_p + k * 4);
else if (k < 5)
*(float*)(cur_p + k * 4) = *(float*)(input_p + k * 4);
else {
int len_per_thread = (len - 5) / (blockDim.y - 5);
int remain = (len - 5) % (blockDim.y - 5);
int real_len = len_per_thread;
if ((k - 5) < remain) real_len++;
int left = -1, right = -1;
if ((k - 5) < remain) {
left = 5 + (k - 5) * (len_per_thread + 1);
right = left + real_len;
} else {
left = 5 + remain * (len_per_thread + 1) +
(k - 5 - remain) * len_per_thread;
right = left + real_len;
}
for (int j = left; j < right; j++)
*(float*)(cur_p + j * 4) = *(float*)(input_p + j * 4);
}
}
}
__global__ void merge_gradients_basic_kernel(const uint32_t* offset,
const uint32_t* fea_num,
const uint32_t* index,
const char* input,
char* output,
int n,
size_t grad_value_size,
DynamicGradMerger& merger) {
const size_t i = blockIdx.x * blockDim.x + threadIdx.x;
if (i < n) {
uint32_t start = offset[i];
uint32_t num = fea_num[i];
int ori_index = index[start];
FeaturePushValue& out = *(FeaturePushValue*)(output + i * grad_value_size);
FeaturePushValue& lhs = *(FeaturePushValue*)(output + i * grad_value_size);
FeaturePushValue& in =
*(FeaturePushValue*)(input + size_t(ori_index) * grad_value_size);
merger_.update_one(out, in);
merger.update_basic(lhs, in);
for (int j = 1; j < num; ++j) {
ori_index = index[start + j];
FeaturePushValue& rhs =
*(FeaturePushValue*)(input + size_t(ori_index) * grad_value_size);
merger_.merge_one(out, rhs);
merger.merge_basic(lhs, rhs);
}
}
}
__global__ void merge_gradients_embedx_kernel(const uint32_t* offset,
const uint32_t* fea_num,
const uint32_t* index,
const char* input,
char* output,
int n,
size_t grad_dim,
size_t grad_value_size,
DynamicGradMerger& merger) {
const size_t i = blockIdx.x * blockDim.x + threadIdx.x;
if (i < n) {
size_t value_idx = i / grad_dim;
size_t field_idx = i % grad_dim;
uint32_t start = offset[value_idx];
uint32_t num = fea_num[value_idx];
int ori_index = index[start];
FeaturePushValue& in =
*(FeaturePushValue*)(input + size_t(ori_index) * grad_value_size);
FeaturePushValue& lhs =
*(FeaturePushValue*)(output + value_idx * grad_value_size);
merger.update_embedx(lhs, in, field_idx);
for (int j = 1; j < num; ++j) {
int ori_index = index[start + j];
FeaturePushValue& rhs =
*(FeaturePushValue*)(input + size_t(ori_index) * grad_value_size);
merger.merge_embedx(lhs, rhs, field_idx);
}
}
}
......@@ -184,6 +262,49 @@ __global__ void dy_mf_fill_dvals_kernel(ValType* d_shard_vals,
}
}
// optimized version
template <>
__global__ void dy_mf_fill_dvals_kernel<FeatureValue, int>(
FeatureValue* d_shard_vals,
FeatureValue* d_vals,
int* idx,
size_t len,
size_t val_size) {
const size_t i = blockIdx.x * blockDim.y + threadIdx.y;
const size_t k = threadIdx.x;
if (i < len) {
uint64_t new_offset = uint64_t(idx[i]) * val_size;
FeatureValue* cur = (FeatureValue*)((char*)d_vals + new_offset);
FeatureValue& input = *(FeatureValue*)((char*)d_shard_vals + i * val_size);
char* cur_p = (char*)cur;
char* input_p = (char*)(&input);
int len = 9 + input.mf_dim + 1;
if (k == 3 || k == 6 || k == 7)
*(int*)(cur_p + k * 4) = *(int*)(input_p + k * 4);
else if (k < 8)
*(float*)(cur_p + k * 4) = *(float*)(input_p + k * 4);
else if (k == 8) {
*(uint64_t*)(cur_p + k * 4) = *(uint64_t*)(input_p + k * 4);
} else {
int len_per_thread = (len - 9) / (blockDim.x - 9);
int remain = (len - 9) % (blockDim.y - 9);
int real_len = len_per_thread;
if ((k - 9) < remain) real_len++;
int left = -1, right = -1;
if ((k - 9) < remain) {
left = 9 + (k - 9) * (len_per_thread + 1);
right = left + real_len;
} else {
left = 9 + remain * (len_per_thread + 1) +
(k - 9 - remain) * len_per_thread;
right = left + real_len;
}
for (int j = left; j < right; j++)
*(float*)(cur_p + (j + 1) * 4) = *(float*)(input_p + (j + 1) * 4);
}
}
}
// cuda implemention of heter_comm_kernel.h
template <typename T, typename StreamType>
void HeterCommKernel::fill_idx(T* idx,
......@@ -321,9 +442,12 @@ void HeterCommKernel::dy_mf_fill_shard_grads(KeyType* d_shard_keys,
long long len,
size_t grad_value_size,
const StreamType& stream) {
int grid_size = (len - 1) / block_size_ + 1;
// int grid_size = (len - 1) / block_size_ + 1;
size_t c_len = (size_t)len;
dy_mf_fill_shard_grads_kernel<<<grid_size, block_size_, 0, stream>>>(
dim3 block_dims(32, 32);
const size_t grid_size = (len - 1) / 32 + 1;
dim3 grid_dims(grid_size);
dy_mf_fill_shard_grads_kernel<<<grid_dims, block_dims, 0, stream>>>(
d_shard_keys,
d_keys,
d_shard_grads,
......@@ -340,12 +464,26 @@ void HeterCommKernel::merge_gradient(const uint32_t* offset,
const char* input,
char* output,
int n,
size_t grad_dim,
size_t grad_value_size,
DynamicGradMerger& merger_,
const StreamType& stream) {
int grid_size = (n - 1) / block_size_ + 1;
merge_gradients_kernel<<<grid_size, block_size_, 0, stream>>>(
merge_gradients_basic_kernel<<<grid_size, block_size_, 0, stream>>>(
offset, fea_num, index, input, output, n, grad_value_size, merger_);
if (grad_dim > 0) {
int grid_size2 = (n * grad_dim - 1) / block_size_ + 1;
merge_gradients_embedx_kernel<<<grid_size2, block_size_, 0, stream>>>(
offset,
fea_num,
index,
input,
output,
n * grad_dim,
grad_dim,
grad_value_size,
merger_);
}
}
template <typename ValType, typename T, typename StreamType>
......@@ -355,9 +493,12 @@ void HeterCommKernel::dy_mf_fill_dvals(ValType* d_shard_vals,
long long len,
size_t val_size,
const StreamType& stream) {
int grid_size = (len - 1) / block_size_ + 1;
// int grid_size = (len - 1) / block_size_ + 1;
size_t c_len = (size_t)len;
dy_mf_fill_dvals_kernel<<<grid_size, block_size_, 0, stream>>>(
dim3 block_dims(32, 32);
const size_t grid_size_ = (len - 1) / 32 + 1;
dim3 grid_dims(grid_size_);
dy_mf_fill_dvals_kernel<<<grid_dims, block_dims, 0, stream>>>(
d_shard_vals, d_vals, idx, c_len, val_size);
}
......@@ -487,6 +628,7 @@ template void HeterCommKernel::merge_gradient<cudaStream_t>(
const char* input,
char* output,
int n,
size_t grad_dim,
size_t grad_value_size,
DynamicGradMerger& merger_,
const cudaStream_t& stream);
......
......@@ -42,23 +42,41 @@ struct DynamicGradMerger {
}
template <typename T>
__device__ __forceinline__ void update_one(T& output, const T& input) {
__device__ __forceinline__ void update_basic(T& output, const T& input) {
output.slot = input.slot;
output.show = input.show;
output.clk = input.clk;
output.mf_dim = input.mf_dim;
output.lr_g = input.lr_g;
for (int i = 0; i < output.mf_dim; ++i) {
output.mf_g[i] = input.mf_g[i];
}
// for (int i = 0; i < output.mf_dim; ++i) {
// output.mf_g[i] = input.mf_g[i];
//}
}
template <typename T>
__device__ __forceinline__ void merge_one(T& output, const T& input) {
__device__ __forceinline__ void merge_basic(T& output, const T& input) {
output.show += input.show;
output.clk += input.clk;
output.lr_g += input.lr_g;
for (int i = 0; i < input.mf_dim; ++i) {
output.mf_g[i] += input.mf_g[i];
// for (int i = 0; i < input.mf_dim; ++i) {
// output.mf_g[i] += input.mf_g[i];
//}
}
template <typename T>
__device__ __forceinline__ void update_embedx(T& output,
const T& input,
size_t embedx_id) {
if (embedx_id < output.mf_dim) {
output.mf_g[embedx_id] = input.mf_g[embedx_id];
}
}
template <typename T>
__device__ __forceinline__ void merge_embedx(T& output,
const T& input,
size_t embedx_id) {
if (embedx_id < output.mf_dim) {
output.mf_g[embedx_id] += input.mf_g[embedx_id];
}
}
};
......@@ -165,6 +183,7 @@ class HeterCommKernel {
const char* input,
char* output,
int n,
size_t grad_dim,
size_t grad_value_size,
DynamicGradMerger& merger_,
const StreamType& stream);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册