From 9b15efce771346d53b113ba182edb601a5926c7c Mon Sep 17 00:00:00 2001
From: Siming Dai <908660116@qq.com>
Date: Tue, 17 May 2022 17:48:38 +0800
Subject: [PATCH] refine cpu query (#42803)

---
 .../fleet/heter_ps/graph_gpu_ps_table_inl.cu  | 206 ++++++++++--------
 1 file changed, 117 insertions(+), 89 deletions(-)

diff --git a/paddle/fluid/framework/fleet/heter_ps/graph_gpu_ps_table_inl.cu b/paddle/fluid/framework/fleet/heter_ps/graph_gpu_ps_table_inl.cu
index 4cf579ce004..631ca962fae 100644
--- a/paddle/fluid/framework/fleet/heter_ps/graph_gpu_ps_table_inl.cu
+++ b/paddle/fluid/framework/fleet/heter_ps/graph_gpu_ps_table_inl.cu
@@ -32,10 +32,11 @@ sample_result is to save the neighbor sampling result, its size is len *
 sample_size;
 
 */
-__global__ void get_cpu_id_index(int64_t* key, int64_t* val, int64_t* cpu_key,
-                                 int* sum, int* index, int len) {
+__global__ void get_cpu_id_index(int64_t* key, int* actual_sample_size,
+                                 int64_t* cpu_key, int* sum, int* index,
+                                 int len) {
   CUDA_KERNEL_LOOP(i, len) {
-    if (val[i] == -1) {
+    if (actual_sample_size[i] == -1) {
       int old = atomicAdd(sum, 1);
       cpu_key[old] = key[i];
       index[old] = i;
@@ -44,11 +45,35 @@ __global__ void get_cpu_id_index(int64_t* key, int64_t* val, int64_t* cpu_key,
   }
 }
 
+__global__ void get_actual_gpu_ac(int* gpu_ac, int number_on_cpu) {
+  CUDA_KERNEL_LOOP(i, number_on_cpu) { gpu_ac[i] /= sizeof(int64_t); }
+}
+
+template <int WARP_SIZE, int BLOCK_WARPS, int TILE_SIZE>
+__global__ void copy_buffer_ac_to_final_place(
+    int64_t* gpu_buffer, int* gpu_ac, int64_t* val, int* actual_sample_size,
+    int* index, int* cumsum_gpu_ac, int number_on_cpu, int sample_size) {
+  assert(blockDim.x == WARP_SIZE);
+  assert(blockDim.y == BLOCK_WARPS);
+
+  int i = blockIdx.x * TILE_SIZE + threadIdx.y;
+  const int last_idx =
+      min(static_cast<int>(blockIdx.x + 1) * TILE_SIZE, number_on_cpu);
+  while (i < last_idx) {
+    actual_sample_size[index[i]] = gpu_ac[i];
+    for (int j = threadIdx.x; j < gpu_ac[i]; j += WARP_SIZE) {
+      val[index[i] * sample_size + j] = gpu_buffer[cumsum_gpu_ac[i] + j];
+    }
+    i += BLOCK_WARPS;
+  }
+}
+
 template <int WARP_SIZE, int BLOCK_WARPS, int TILE_SIZE>
 __global__ void neighbor_sample_example_v2(GpuPsCommGraph graph,
                                            int64_t* node_index,
                                            int* actual_size, int64_t* res,
-                                           int sample_len, int n) {
+                                           int sample_len, int n,
+                                           int default_value) {
   assert(blockDim.x == WARP_SIZE);
   assert(blockDim.y == BLOCK_WARPS);
 
@@ -59,7 +84,7 @@ __global__ void neighbor_sample_example_v2(GpuPsCommGraph graph,
 
   while (i < last_idx) {
     if (node_index[i] == -1) {
-      actual_size[i] = 0;
+      actual_size[i] = default_value;
       i += BLOCK_WARPS;
       continue;
     }
@@ -762,6 +787,10 @@ NeighborSampleResult GpuPsGraphTable::graph_neighbor_sample_v2(
   auto d_right = memory::Alloc(place, total_gpu * sizeof(int));
   int* d_left_ptr = reinterpret_cast<int*>(d_left->ptr());
   int* d_right_ptr = reinterpret_cast<int*>(d_right->ptr());
+  int default_value = 0;
+  if (cpu_query_switch) {
+    default_value = -1;
+  }
 
   cudaMemsetAsync(d_left_ptr, -1, total_gpu * sizeof(int), stream);
   cudaMemsetAsync(d_right_ptr, -1, total_gpu * sizeof(int), stream);
@@ -796,14 +825,9 @@ NeighborSampleResult GpuPsGraphTable::graph_neighbor_sample_v2(
                        sizeof(int) * (shard_len + shard_len % 2));
   }
   walk_to_dest(gpu_id, total_gpu, h_left, h_right, d_shard_keys_ptr, NULL);
-  // For cpu_query_switch, we need global items.
-  std::vector<thrust::device_vector<int64_t>> cpu_keys_list;
-  std::vector<thrust::device_vector<int>> cpu_index_list;
-  thrust::device_vector<int64_t> tmp1;
-  thrust::device_vector<int> tmp2;
+
   for (int i = 0; i < total_gpu; ++i) {
     if (h_left[i] == -1) {
-      // Insert empty object
       continue;
     }
     int shard_len = h_left[i] == -1 ? 0 : h_right[i] - h_left[i] + 1;
@@ -832,92 +856,16 @@ NeighborSampleResult GpuPsGraphTable::graph_neighbor_sample_v2(
         WARP_SIZE, BLOCK_WARPS,
         TILE_SIZE><<<grid, block, 0, resource_->remote_stream(i, gpu_id)>>>(
         graph, id_array, actual_size_array, sample_array, sample_size,
-        shard_len);
-    // cpu_graph_table->random_sample_neighbors
-    // if (cpu_query_switch) {
-    //}
+        shard_len, default_value);
   }
   for (int i = 0; i < total_gpu; ++i) {
     if (h_left[i] == -1) {
-      if (cpu_query_switch) {
-        cpu_keys_list.emplace_back(tmp1);
-        cpu_index_list.emplace_back(tmp2);
-      }
       continue;
     }
     cudaStreamSynchronize(resource_->remote_stream(i, gpu_id));
-    if (cpu_query_switch) {
-      platform::CUDADeviceGuard guard(resource_->dev_id(i));
-      int shard_len = h_left[i] == -1 ? 0 : h_right[i] - h_left[i] + 1;
-      auto& node = path_[gpu_id][i].nodes_.back();
-      int64_t* id_array = reinterpret_cast<int64_t*>(node.val_storage);
-      int* actual_size_array = (int*)(id_array + shard_len);
-      int64_t* sample_array =
-          (int64_t*)(actual_size_array + shard_len + shard_len % 2);
-      thrust::device_vector<int64_t> cpu_keys_ptr(shard_len);
-      thrust::device_vector<int> index_ptr(shard_len + 1, 0);
-      int64_t* node_id_array = reinterpret_cast<int64_t*>(node.key_storage);
-      int grid_size2 = (shard_len - 1) / block_size_ + 1;
-      get_cpu_id_index<<<grid_size2, block_size_, 0,
-                         resource_->remote_stream(i, gpu_id)>>>(
-          node_id_array, id_array,
-          thrust::raw_pointer_cast(cpu_keys_ptr.data()),
-          thrust::raw_pointer_cast(index_ptr.data()),
-          thrust::raw_pointer_cast(index_ptr.data()) + 1, shard_len);
-      cudaStreamSynchronize(resource_->remote_stream(i, gpu_id));
-      cpu_keys_list.emplace_back(cpu_keys_ptr);
-      cpu_index_list.emplace_back(index_ptr);
-    }
-  }
-  if (cpu_query_switch) {
-    for (int i = 0; i < total_gpu; ++i) {
-      if (h_left[i] == -1) {
-        continue;
-      }
-      platform::CUDADeviceGuard guard(resource_->dev_id(i));
-      auto shard_len = h_right[i] - h_left[i] + 1;
-      int* cpu_index = new int[shard_len + 1];
-      cudaMemcpy(cpu_index, thrust::raw_pointer_cast(cpu_index_list[i].data()),
-                 (shard_len + 1) * sizeof(int), cudaMemcpyDeviceToHost);
-      if (cpu_index[0] > 0) {
-        int number_on_cpu = cpu_index[0];
-        int64_t* cpu_keys = new int64_t[number_on_cpu];
-        cudaMemcpy(cpu_keys, thrust::raw_pointer_cast(cpu_keys_list[i].data()),
-                   number_on_cpu * sizeof(int64_t), cudaMemcpyDeviceToHost);
-        std::vector<std::shared_ptr<char>> buffers(number_on_cpu);
-        std::vector<int> ac(number_on_cpu);
-        auto status = cpu_graph_table->random_sample_neighbors(
-            0, cpu_keys, sample_size, buffers, ac, false);
-
-        auto& node = path_[gpu_id][i].nodes_.back();
-        // display_sample_res(node.key_storage,node.val_storage,shard_len,sample_size);
-        int64_t* id_array = reinterpret_cast<int64_t*>(node.val_storage);
-        int* actual_size_array = (int*)(id_array + shard_len);
-        int64_t* sample_array =
-            (int64_t*)(actual_size_array + shard_len + shard_len % 2);
-        for (int j = 0; j < number_on_cpu; j++) {
-          int offset = cpu_index[j + 1] * sample_size;
-          ac[j] = ac[j] / sizeof(int64_t);
-          /*
-          std::cerr<<"for cpu key "<
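
Note on the new CPU-fallback path (a sketch, not part of the patch): with
cpu_query_switch enabled, neighbor_sample_example_v2 writes default_value = -1
into actual_size for keys missing on the GPU; get_cpu_id_index then compacts
those keys and their original positions via atomicAdd, the CPU table samples
them through cpu_graph_table->random_sample_neighbors, and the two new kernels
scatter the results back. The tail of the final hunk is cut off above, so the
launch of copy_buffer_ac_to_final_place is not visible; the sketch below shows
one plausible driver. The launch constants, the helper name
scatter_cpu_samples, and the cumsum_gpu_ac exclusive prefix-sum input are
assumptions; only the kernel signatures come from the patch.

    #include <cstdint>
    #include <cuda_runtime.h>

    // Assumes get_actual_gpu_ac and copy_buffer_ac_to_final_place from the
    // patch above are in scope.
    constexpr int WARP_SIZE = 32;
    constexpr int BLOCK_WARPS = 4;               // illustrative choice
    constexpr int TILE_SIZE = BLOCK_WARPS * 16;  // illustrative choice

    // Hypothetical driver: scatter CPU-sampled neighbors back into the
    // val / actual_sample_size arrays built by graph_neighbor_sample_v2.
    void scatter_cpu_samples(int64_t* gpu_buffer, int* gpu_ac,
                             int* cumsum_gpu_ac, int64_t* val,
                             int* actual_sample_size, int* index,
                             int number_on_cpu, int sample_size,
                             cudaStream_t stream) {
      // random_sample_neighbors reports per-key sizes in bytes; divide by
      // sizeof(int64_t) to get neighbor counts.
      int grid1 = (number_on_cpu + 255) / 256;
      get_actual_gpu_ac<<<grid1, 256, 0, stream>>>(gpu_ac, number_on_cpu);

      // One warp per CPU-sampled key: restore its actual sample size at the
      // original slot index[i], and copy its neighbors from the packed
      // buffer (offset = cumsum_gpu_ac[i]) into val[index[i] * sample_size].
      const dim3 block(WARP_SIZE, BLOCK_WARPS);
      const dim3 grid((number_on_cpu + TILE_SIZE - 1) / TILE_SIZE);
      copy_buffer_ac_to_final_place<WARP_SIZE, BLOCK_WARPS, TILE_SIZE>
          <<<grid, block, 0, stream>>>(gpu_buffer, gpu_ac, val,
                                       actual_sample_size, index,
                                       cumsum_gpu_ac, number_on_cpu,
                                       sample_size);
    }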