Unverified commit c4a52b83, authored by zmxdream, committed by GitHub

[GPUPS]fix merge_grad&push_sparse (#43840)

Parent 40a77319
@@ -112,20 +112,7 @@ __global__ void dy_mf_search_kernel(Table* table,
       }
     } else {
       if (keys[i] != 0) {
-        printf("warning::pull miss key: %d", keys[i]);
+        printf("warning::pull miss key: %llu", keys[i]);
       }
-      FeatureValue* cur = (FeatureValue*)(vals + i * pull_feature_value_size);
-      cur->delta_score = 0;
-      cur->show = 0;
-      cur->clk = 0;
-      cur->slot = -1;
-      cur->lr = 0;
-      cur->lr_g2sum = 0;
-      cur->mf_size = 0;
-      cur->mf_dim = 8;
-      cur->cpu_ptr;
-      for (int j = 0; j < cur->mf_dim + 1; j++) {
-        cur->mf[j] = 0;
-      }
     }
   }
@@ -163,7 +150,7 @@ __global__ void dy_mf_update_kernel(Table* table,
       sgd.dy_mf_update_value(optimizer_config, (it.getter())->second, *cur);
     } else {
       if (keys[i] != 0) {
-        printf("warning::push miss key: %d", keys[i]);
+        printf("warning::push miss key: %llu", keys[i]);
       }
     }
   }
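Why the specifier change in the two hunks above: GPUPS feature keys are 64-bit unsigned integers, while %d tells device-side printf to consume a 32-bit signed int, so a miss on any key above INT_MAX was reported wrongly. A minimal standalone sketch of the corrected pattern; the report_miss kernel and its launch are illustrative, not Paddle code:

#include <cstdint>
#include <cstdio>

__global__ void report_miss(const uint64_t* keys, size_t n) {
  size_t i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n && keys[i] != 0) {
    // %llu matches the 64-bit unsigned key; %d would consume only 32 bits
    // and print e.g. 1 for the key 0x100000001.
    printf("warning::pull miss key: %llu\n",
           static_cast<unsigned long long>(keys[i]));
  }
}

int main() {
  uint64_t h_key = 0x100000001ULL;  // needs all 64 bits
  uint64_t* d_keys = nullptr;
  cudaMalloc(&d_keys, sizeof(uint64_t));
  cudaMemcpy(d_keys, &h_key, sizeof(uint64_t), cudaMemcpyHostToDevice);
  report_miss<<<1, 1>>>(d_keys, 1);
  cudaDeviceSynchronize();  // flush device-side printf before exit
  cudaFree(d_keys);
  return 0;
}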
@@ -1026,14 +1026,9 @@ void HeterComm<KeyType, ValType, GradType>::push_sparse(int dev_num,
   auto d_shard_keys = memory::Alloc(place, len * sizeof(KeyType));
   KeyType* d_shard_keys_ptr = reinterpret_cast<KeyType*>(d_shard_keys->ptr());
-  GradType* d_shard_grads_ptr;
-  if (!multi_mf_dim_) {
-    auto d_shard_grads = memory::Alloc(place, len * sizeof(GradType));
-    d_shard_grads_ptr = reinterpret_cast<GradType*>(d_shard_grads->ptr());
-  } else {
-    auto d_shard_grads = memory::Alloc(place, len * grad_value_size);
-    d_shard_grads_ptr = reinterpret_cast<GradType*>(d_shard_grads->ptr());
-  }
+  auto d_shard_grads = memory::Alloc(place, len * grad_value_size);
+  GradType* d_shard_grads_ptr =
+      reinterpret_cast<GradType*>(d_shard_grads->ptr());

   int uniq_len = len;
   dynamic_merge_grad(dev_num, d_keys, d_grads, len, uniq_len);
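Why the multi_mf_dim_ branch could be removed outright: with dynamic mf dims, every gradient record occupies grad_value_size bytes, so that is the only correct size for the shard buffer here. The old code also had a likely lifetime bug: each d_shard_grads handle was declared inside the if/else, so (assuming memory::Alloc returns an RAII allocation handle, as its use at function scope in this same hunk suggests) the buffer was released when the branch ended and d_shard_grads_ptr was left dangling. A minimal plain-C++ sketch of that pitfall, with std::unique_ptr standing in for the allocation handle; names are illustrative:

#include <memory>

float* dangling_example() {
  float* ptr = nullptr;
  if (true) {
    auto buf = std::make_unique<float[]>(1024);  // handle local to the branch
    ptr = buf.get();
  }  // buf destroyed here: the storage is freed
  return ptr;  // dangling, as in the removed if/else
}

float* fixed_example(std::unique_ptr<float[]>& buf) {
  buf = std::make_unique<float[]>(1024);  // handle lives in the caller's scope
  return buf.get();  // stays valid as long as buf does
}

int main() {
  std::unique_ptr<float[]> buf;
  float* p = fixed_example(buf);
  p[0] = 1.0f;  // safe: buf still owns the storage
  return 0;
}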
@@ -153,7 +153,6 @@ __global__ void merge_gradients_kernel(const uint32_t* offset,
                                         size_t grad_value_size,
                                         DynamicGradMerger& merger_) {
   const size_t i = blockIdx.x * blockDim.x + threadIdx.x;
-
   if (i < n) {
     uint32_t start = offset[i];
     uint32_t num = fea_num[i];
@@ -164,8 +163,9 @@ __global__ void merge_gradients_kernel(const uint32_t* offset,
     merger_.update_one(out, in);
     for (int j = 1; j < num; ++j) {
       ori_index = index[start + j];
-      in = *(FeaturePushValue*)(input + size_t(ori_index) * grad_value_size);
-      merger_.merge_one(out, in);
+      FeaturePushValue& rhs =
+          *(FeaturePushValue*)(input + size_t(ori_index) * grad_value_size);
+      merger_.merge_one(out, rhs);
     }
   }
 }
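Why a reference instead of reassigning `in`: each input record occupies grad_value_size bytes, but assignment copies only sizeof(FeaturePushValue) of them, dropping the variable-length tail of mf gradients before merge_one can accumulate it. Binding a fresh reference keeps the full record addressable. A minimal host-side sketch of the hazard, using an illustrative variable-length struct rather than Paddle's actual FeaturePushValue layout:

#include <cstdio>
#include <cstdlib>

struct PushValue {
  float show;
  int mf_dim;
  float mf_g[1];  // trailing array; a record really holds mf_dim floats
};

// Accumulates n records of record_bytes each into out, which must be
// allocated with the same record size.
void merge(const char* records, size_t record_bytes, int n, PushValue* out) {
  for (int j = 0; j < n; ++j) {
    // A by-value copy (PushValue in = *p;) would slice the record down to
    // sizeof(PushValue) and drop most of mf_g; the reference does not.
    const PushValue& rhs =
        *(const PushValue*)(records + (size_t)j * record_bytes);
    out->show += rhs.show;
    for (int k = 0; k < rhs.mf_dim; ++k) out->mf_g[k] += rhs.mf_g[k];
  }
}

int main() {
  const int dim = 8;
  size_t bytes = sizeof(PushValue) + (dim - 1) * sizeof(float);
  char* buf = (char*)calloc(2, bytes);
  for (int j = 0; j < 2; ++j) {
    PushValue* v = (PushValue*)(buf + j * bytes);
    v->show = 1.0f;
    v->mf_dim = dim;
    for (int k = 0; k < dim; ++k) v->mf_g[k] = 0.5f;
  }
  PushValue* out = (PushValue*)calloc(1, bytes);
  merge(buf, bytes, 2, out);
  printf("show=%.1f mf_g[7]=%.1f\n", out->show, out->mf_g[7]);  // 2.0, 1.0
  free(buf);
  free(out);
  return 0;
}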