未验证 提交 df96d1ed 编写于 作者: S Siming Dai 提交者: GitHub

fix sample error (#42595)

上级 be87caf2
...@@ -924,9 +924,11 @@ NeighborSampleResult GpuPsGraphTable::graph_neighbor_sample_v2( ...@@ -924,9 +924,11 @@ NeighborSampleResult GpuPsGraphTable::graph_neighbor_sample_v2(
{ {
platform::CUDAPlace place = platform::CUDAPlace(resource_->dev_id(gpu_id)); platform::CUDAPlace place = platform::CUDAPlace(resource_->dev_id(gpu_id));
platform::CUDADeviceGuard guard(resource_->dev_id(gpu_id)); platform::CUDADeviceGuard guard(resource_->dev_id(gpu_id));
thrust::device_ptr<int> t_actual_sample_size(actual_sample_size); thrust::device_vector<int> t_actual_sample_size(len);
int total_sample_size = thrust::copy(actual_sample_size, actual_sample_size + len,
thrust::reduce(t_actual_sample_size, t_actual_sample_size + len); t_actual_sample_size.begin());
int total_sample_size = thrust::reduce(t_actual_sample_size.begin(),
t_actual_sample_size.end());
result.actual_val_mem = result.actual_val_mem =
memory::AllocShared(place, total_sample_size * sizeof(int64_t)); memory::AllocShared(place, total_sample_size * sizeof(int64_t));
result.actual_val = (int64_t*)(result.actual_val_mem)->ptr(); result.actual_val = (int64_t*)(result.actual_val_mem)->ptr();
...@@ -934,7 +936,8 @@ NeighborSampleResult GpuPsGraphTable::graph_neighbor_sample_v2( ...@@ -934,7 +936,8 @@ NeighborSampleResult GpuPsGraphTable::graph_neighbor_sample_v2(
result.set_total_sample_size(total_sample_size); result.set_total_sample_size(total_sample_size);
thrust::device_vector<int> cumsum_actual_sample_size(len); thrust::device_vector<int> cumsum_actual_sample_size(len);
thrust::exclusive_scan(t_actual_sample_size, t_actual_sample_size + len, thrust::exclusive_scan(t_actual_sample_size.begin(),
t_actual_sample_size.end(),
cumsum_actual_sample_size.begin(), 0); cumsum_actual_sample_size.begin(), 0);
fill_actual_vals<<<grid_size, block_size_, 0, stream>>>( fill_actual_vals<<<grid_size, block_size_, 0, stream>>>(
val, result.actual_val, actual_sample_size, val, result.actual_val, actual_sample_size,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册