Unverified commit c4a52b83, authored by zmxdream, committed by GitHub

[GPUPS]fix merge_grad&push_sparse (#43840)

Parent 40a77319
@@ -112,20 +112,7 @@ __global__ void dy_mf_search_kernel(Table* table,
       }
     } else {
       if (keys[i] != 0) {
-        printf("warning::pull miss key: %d", keys[i]);
-      }
-      FeatureValue* cur = (FeatureValue*)(vals + i * pull_feature_value_size);
-      cur->delta_score = 0;
-      cur->show = 0;
-      cur->clk = 0;
-      cur->slot = -1;
-      cur->lr = 0;
-      cur->lr_g2sum = 0;
-      cur->mf_size = 0;
-      cur->mf_dim = 8;
-      cur->cpu_ptr;
-      for (int j = 0; j < cur->mf_dim + 1; j++) {
-        cur->mf[j] = 0;
+        printf("warning::pull miss key: %llu", keys[i]);
       }
     }
   }
@@ -163,7 +150,7 @@ __global__ void dy_mf_update_kernel(Table* table,
       sgd.dy_mf_update_value(optimizer_config, (it.getter())->second, *cur);
     } else {
       if (keys[i] != 0) {
-        printf("warning::push miss key: %d", keys[i]);
+        printf("warning::push miss key: %llu", keys[i]);
       }
     }
   }
......
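Both warning hunks above fix the same issue: the hash-table keys here are 64-bit unsigned integers, so printing one with %d misreads it (a mismatched printf specifier is undefined behavior), while %llu matches the full width. The pull-miss hunk additionally drops the block that default-initialized a FeatureValue for a missing key, so a miss now only logs. A minimal host-side C++ sketch of the specifier pitfall, assuming a uint64_t key type as a stand-in for the table's key type:

```cpp
#include <cstdint>
#include <cstdio>

int main() {
  // Stand-in for a sparse feature key, which is wider than 32 bits.
  uint64_t key = 18446744073709551615ULL;  // UINT64_MAX

  // Buggy: %d expects an int, so the key is misread (undefined behavior).
  // printf("warning::pull miss key: %d\n", key);

  // Fixed: %llu matches an unsigned 64-bit value; the cast keeps it portable
  // on platforms where uint64_t is unsigned long rather than unsigned long long.
  printf("warning::pull miss key: %llu\n",
         static_cast<unsigned long long>(key));
  return 0;
}
```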
@@ -1026,14 +1026,9 @@ void HeterComm<KeyType, ValType, GradType>::push_sparse(int dev_num,
   auto d_shard_keys = memory::Alloc(place, len * sizeof(KeyType));
   KeyType* d_shard_keys_ptr = reinterpret_cast<KeyType*>(d_shard_keys->ptr());
-  GradType* d_shard_grads_ptr;
-  if (!multi_mf_dim_) {
-    auto d_shard_grads = memory::Alloc(place, len * sizeof(GradType));
-    d_shard_grads_ptr = reinterpret_cast<GradType*>(d_shard_grads->ptr());
-  } else {
-    auto d_shard_grads = memory::Alloc(place, len * grad_value_size);
-    d_shard_grads_ptr = reinterpret_cast<GradType*>(d_shard_grads->ptr());
-  }
+  auto d_shard_grads = memory::Alloc(place, len * grad_value_size);
+  GradType* d_shard_grads_ptr =
+      reinterpret_cast<GradType*>(d_shard_grads->ptr());

   int uniq_len = len;
   dynamic_merge_grad(dev_num, d_keys, d_grads, len, uniq_len);
......
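In the push_sparse hunk above, the old code declared the allocation handle inside the if/else; if the handle returned by memory::Alloc owns the buffer (as the auto locals suggest), d_shard_grads_ptr was left pointing at memory that could be released when the branch ended. The new code keeps a single allocation, always sized by grad_value_size, alive in the enclosing scope. A small C++ sketch of that scoping pitfall, using std::unique_ptr as a stand-in for the framework's scoped allocation handle (an assumption for illustration; memory::Alloc's exact return type is not shown in the hunk):

```cpp
#include <cstddef>
#include <memory>

int main() {
  const std::size_t len = 8;
  const std::size_t grad_value_size = 64;  // bytes per dynamically sized gradient

  char* dangling_ptr = nullptr;
  {
    // Buggy pattern: the owning handle lives only inside this block,
    // mirroring `auto d_shard_grads = memory::Alloc(...)` inside the branch.
    auto scoped = std::make_unique<char[]>(len * grad_value_size);
    dangling_ptr = scoped.get();
  }  // buffer released here; dangling_ptr is now invalid

  // Fixed pattern: one allocation, sized by grad_value_size, owned in the
  // enclosing scope so the raw pointer stays valid for the rest of the call.
  auto d_shard_grads = std::make_unique<char[]>(len * grad_value_size);
  char* d_shard_grads_ptr = d_shard_grads.get();

  (void)dangling_ptr;
  (void)d_shard_grads_ptr;
  return 0;
}
```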
@@ -153,7 +153,6 @@ __global__ void merge_gradients_kernel(const uint32_t* offset,
                                        size_t grad_value_size,
                                        DynamicGradMerger& merger_) {
   const size_t i = blockIdx.x * blockDim.x + threadIdx.x;
   if (i < n) {
     uint32_t start = offset[i];
     uint32_t num = fea_num[i];
@@ -164,8 +163,9 @@ __global__ void merge_gradients_kernel(const uint32_t* offset,
     merger_.update_one(out, in);
     for (int j = 1; j < num; ++j) {
       ori_index = index[start + j];
-      in = *(FeaturePushValue*)(input + size_t(ori_index) * grad_value_size);
-      merger_.merge_one(out, in);
+      FeaturePushValue& rhs =
+          *(FeaturePushValue*)(input + size_t(ori_index) * grad_value_size);
+      merger_.merge_one(out, rhs);
     }
   }
 }
......
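The last hunk stops copying each input record into the local `in` before merging and instead binds a reference to the original record. A plausible reading (the FeaturePushValue layout itself is not shown here): with dynamic mf dimensions each record occupies grad_value_size bytes, more than sizeof(FeaturePushValue), so assigning by value copies only the fixed-size head and the merge would read a truncated gradient. A short C++ sketch of that truncation, using a hypothetical PushValueHead type purely for illustration:

```cpp
#include <cstddef>
#include <cstdio>
#include <vector>

// Hypothetical stand-in for a dynamically sized push value: a fixed header
// followed by a variable-length gradient tail stored in the same byte buffer.
struct PushValueHead {
  float show;
  float clk;
  int mf_dim;
  float mf_g[1];  // first element of the variable-length tail
};

int main() {
  const int mf_dim = 8;
  const std::size_t value_size =
      sizeof(PushValueHead) + (mf_dim - 1) * sizeof(float);

  std::vector<char> buffer(value_size, 0);
  auto* src = reinterpret_cast<PushValueHead*>(buffer.data());
  src->mf_dim = mf_dim;

  // Copying by value moves only sizeof(PushValueHead) bytes, so most of the
  // gradient tail never reaches the copy. Binding a reference (as the fixed
  // kernel does) lets the merger read the full value_size-byte record.
  PushValueHead by_value = *src;
  PushValueHead& by_ref = *src;

  printf("full record: %zu bytes, value copy sees: %zu bytes\n",
         value_size, sizeof(by_value));
  (void)by_ref;
  return 0;
}
```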