From df96d1ed937ed52c8b8af7a9afc39ac221c782ba Mon Sep 17 00:00:00 2001 From: Siming Dai <908660116@qq.com> Date: Tue, 10 May 2022 10:56:57 +0800 Subject: [PATCH] fix sample error (#42595) --- .../framework/fleet/heter_ps/graph_gpu_ps_table_inl.h | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/paddle/fluid/framework/fleet/heter_ps/graph_gpu_ps_table_inl.h b/paddle/fluid/framework/fleet/heter_ps/graph_gpu_ps_table_inl.h index 605019cb60..d28ae0ab5d 100644 --- a/paddle/fluid/framework/fleet/heter_ps/graph_gpu_ps_table_inl.h +++ b/paddle/fluid/framework/fleet/heter_ps/graph_gpu_ps_table_inl.h @@ -924,9 +924,11 @@ NeighborSampleResult GpuPsGraphTable::graph_neighbor_sample_v2( { platform::CUDAPlace place = platform::CUDAPlace(resource_->dev_id(gpu_id)); platform::CUDADeviceGuard guard(resource_->dev_id(gpu_id)); - thrust::device_ptr t_actual_sample_size(actual_sample_size); - int total_sample_size = - thrust::reduce(t_actual_sample_size, t_actual_sample_size + len); + thrust::device_vector t_actual_sample_size(len); + thrust::copy(actual_sample_size, actual_sample_size + len, + t_actual_sample_size.begin()); + int total_sample_size = thrust::reduce(t_actual_sample_size.begin(), + t_actual_sample_size.end()); result.actual_val_mem = memory::AllocShared(place, total_sample_size * sizeof(int64_t)); result.actual_val = (int64_t*)(result.actual_val_mem)->ptr(); @@ -934,7 +936,8 @@ NeighborSampleResult GpuPsGraphTable::graph_neighbor_sample_v2( result.set_total_sample_size(total_sample_size); thrust::device_vector cumsum_actual_sample_size(len); - thrust::exclusive_scan(t_actual_sample_size, t_actual_sample_size + len, + thrust::exclusive_scan(t_actual_sample_size.begin(), + t_actual_sample_size.end(), cumsum_actual_sample_size.begin(), 0); fill_actual_vals<<>>( val, result.actual_val, actual_sample_size, -- GitLab