diff --git a/paddle/fluid/framework/fleet/heter_ps/graph_gpu_ps_table_inl.h b/paddle/fluid/framework/fleet/heter_ps/graph_gpu_ps_table_inl.h index 605019cb607fc41bd32bd9053128fd4791bb2c40..d28ae0ab5d93f484b8ebdfca4099d63eab080812 100644 --- a/paddle/fluid/framework/fleet/heter_ps/graph_gpu_ps_table_inl.h +++ b/paddle/fluid/framework/fleet/heter_ps/graph_gpu_ps_table_inl.h @@ -924,9 +924,11 @@ NeighborSampleResult GpuPsGraphTable::graph_neighbor_sample_v2( { platform::CUDAPlace place = platform::CUDAPlace(resource_->dev_id(gpu_id)); platform::CUDADeviceGuard guard(resource_->dev_id(gpu_id)); - thrust::device_ptr t_actual_sample_size(actual_sample_size); - int total_sample_size = - thrust::reduce(t_actual_sample_size, t_actual_sample_size + len); + thrust::device_vector t_actual_sample_size(len); + thrust::copy(actual_sample_size, actual_sample_size + len, + t_actual_sample_size.begin()); + int total_sample_size = thrust::reduce(t_actual_sample_size.begin(), + t_actual_sample_size.end()); result.actual_val_mem = memory::AllocShared(place, total_sample_size * sizeof(int64_t)); result.actual_val = (int64_t*)(result.actual_val_mem)->ptr(); @@ -934,7 +936,8 @@ NeighborSampleResult GpuPsGraphTable::graph_neighbor_sample_v2( result.set_total_sample_size(total_sample_size); thrust::device_vector cumsum_actual_sample_size(len); - thrust::exclusive_scan(t_actual_sample_size, t_actual_sample_size + len, + thrust::exclusive_scan(t_actual_sample_size.begin(), + t_actual_sample_size.end(), cumsum_actual_sample_size.begin(), 0); fill_actual_vals<<>>( val, result.actual_val, actual_sample_size,