未验证 提交 df96d1ed 编写于 作者: S Siming Dai 提交者: GitHub

fix sample error (#42595)

上级 be87caf2
......@@ -924,9 +924,11 @@ NeighborSampleResult GpuPsGraphTable::graph_neighbor_sample_v2(
{
platform::CUDAPlace place = platform::CUDAPlace(resource_->dev_id(gpu_id));
platform::CUDADeviceGuard guard(resource_->dev_id(gpu_id));
-thrust::device_ptr<int> t_actual_sample_size(actual_sample_size);
-int total_sample_size =
-thrust::reduce(t_actual_sample_size, t_actual_sample_size + len);
+thrust::device_vector<int> t_actual_sample_size(len);
+thrust::copy(actual_sample_size, actual_sample_size + len,
+t_actual_sample_size.begin());
+int total_sample_size = thrust::reduce(t_actual_sample_size.begin(),
+t_actual_sample_size.end());
result.actual_val_mem =
memory::AllocShared(place, total_sample_size * sizeof(int64_t));
result.actual_val = (int64_t*)(result.actual_val_mem)->ptr();
......@@ -934,7 +936,8 @@ NeighborSampleResult GpuPsGraphTable::graph_neighbor_sample_v2(
result.set_total_sample_size(total_sample_size);
thrust::device_vector<int> cumsum_actual_sample_size(len);
-thrust::exclusive_scan(t_actual_sample_size, t_actual_sample_size + len,
+thrust::exclusive_scan(t_actual_sample_size.begin(),
+t_actual_sample_size.end(),
cumsum_actual_sample_size.begin(), 0);
fill_actual_vals<<<grid_size, block_size_, 0, stream>>>(
val, result.actual_val, actual_sample_size,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册