diff --git a/paddle/fluid/framework/fleet/heter_ps/graph_gpu_ps_table.h b/paddle/fluid/framework/fleet/heter_ps/graph_gpu_ps_table.h
index b8f9f0bfec9b2a0bf6b6fb1e122e40b3eaa90fa8..3d1599a76e8ebcf8d379e6d44d6cc475ab4b0b33 100644
--- a/paddle/fluid/framework/fleet/heter_ps/graph_gpu_ps_table.h
+++ b/paddle/fluid/framework/fleet/heter_ps/graph_gpu_ps_table.h
@@ -13,6 +13,7 @@
 // limitations under the License.
 
 #pragma once
+#include <thrust/host_vector.h>
 #include "heter_comm.h"
 #include "paddle/fluid/distributed/ps/table/common_graph_table.h"
 #include "paddle/fluid/framework/fleet/heter_ps/gpu_graph_node.h"
@@ -40,11 +41,13 @@ class GpuPsGraphTable : public HeterComm<int64_t, int, int> {
                                            int sample_size, int len);
   NodeQueryResult *query_node_list(int gpu_id, int start, int query_size);
   void clear_graph_info();
-  void move_neighbor_sample_result_to_source_gpu(int gpu_id, int gpu_num,
-                                                 int sample_size, int *h_left,
-                                                 int *h_right,
-                                                 int64_t *src_sample_res,
-                                                 int *actual_sample_size);
+  void move_neighbor_sample_result_to_source_gpu(
+      int gpu_id, int gpu_num, int *h_left, int *h_right,
+      int64_t *src_sample_res, thrust::host_vector<int> &total_sample_size);
+  void move_neighbor_sample_size_to_source_gpu(int gpu_id, int gpu_num,
+                                               int *h_left, int *h_right,
+                                               int *actual_sample_size,
+                                               int *total_sample_size);
   int init_cpu_table(const paddle::distributed::GraphParameter &graph);
   int load(const std::string &path, const std::string &param);
   virtual int32_t end_graph_sampling() {
diff --git a/paddle/fluid/framework/fleet/heter_ps/graph_gpu_ps_table_inl.h b/paddle/fluid/framework/fleet/heter_ps/graph_gpu_ps_table_inl.h
index 16a6857ae96eecaaa06b92b9912387f22612f53e..acd3f0a290d0b1b40ef71dd11b2741452f41e773 100644
--- a/paddle/fluid/framework/fleet/heter_ps/graph_gpu_ps_table_inl.h
+++ b/paddle/fluid/framework/fleet/heter_ps/graph_gpu_ps_table_inl.h
@@ -13,10 +13,23 @@
 // limitations under the License.
 
 #pragma once
+
+#include <assert.h>
+#include <curand_kernel.h>
+#include <thrust/copy.h>
+#include <thrust/device_vector.h>
+#include <thrust/host_vector.h>
+#include <thrust/reduce.h>
+#include <thrust/scan.h>
+#include <thrust/transform.h>
+
 #ifdef PADDLE_WITH_HETERPS
 //#include "paddle/fluid/framework/fleet/heter_ps/graph_gpu_ps_table.h"
 namespace paddle {
 namespace framework {
+
+constexpr int WARP_SIZE = 32;
+
 /*
 comment 0
 this kernel just serves as an example of how to sample nodes' neighbors.
@@ -29,20 +42,79 @@ sample_size;
 */
-__global__ void neighbor_sample_example(GpuPsCommGraph graph, int* index,
-                                        int* actual_size,
-                                        int64_t* sample_result, int sample_size,
-                                        int len) {
-  const size_t i = blockIdx.x * blockDim.x + threadIdx.x;
-  if (i < len) {
+struct MaxFunctor {
+  int sample_size;
+  HOSTDEVICE explicit inline MaxFunctor(int sample_size) {
+    this->sample_size = sample_size;
+  }
+  HOSTDEVICE inline int operator()(int x) const {
+    if (x > sample_size) {
+      return sample_size;
+    }
+    return x;
+  }
+};
+
+struct DegreeFunctor {
+  GpuPsCommGraph graph;
+  HOSTDEVICE explicit inline DegreeFunctor(GpuPsCommGraph graph) {
+    this->graph = graph;
+  }
+  HOSTDEVICE inline int operator()(int i) const {
+    return graph.node_list[i].neighbor_size;
+  }
+};
+
+template <int BLOCK_WARPS, int TILE_SIZE>
+__global__ void neighbor_sample(const uint64_t rand_seed, GpuPsCommGraph graph,
+                                int sample_size, int* index, int len,
+                                int64_t* sample_result, int* output_idx,
+                                int* output_offset) {
+  assert(blockDim.x == WARP_SIZE);
+  assert(blockDim.y == BLOCK_WARPS);
+
+  int i = blockIdx.x * TILE_SIZE + threadIdx.y;
+  const int last_idx = min(static_cast<int>(blockIdx.x + 1) * TILE_SIZE, len);
+  curandState rng;
+  curand_init(rand_seed * gridDim.x + blockIdx.x,
+              threadIdx.y * WARP_SIZE + threadIdx.x, 0, &rng);
+
+  while (i < last_idx) {
     auto node_index = index[i];
-    actual_size[i] = graph.node_list[node_index].neighbor_size < sample_size
-                         ? graph.node_list[node_index].neighbor_size
-                         : sample_size;
-    int offset = graph.node_list[node_index].neighbor_offset;
-    for (int j = 0; j < actual_size[i]; j++) {
-      sample_result[sample_size * i + j] = graph.neighbor_list[offset + j];
+    int degree = graph.node_list[node_index].neighbor_size;
+    const int offset = graph.node_list[node_index].neighbor_offset;
+    int output_start = output_offset[i];
+
+    if (degree <= sample_size) {
+      // Just copy
+      for (int j = threadIdx.x; j < degree; j += WARP_SIZE) {
+        sample_result[output_start + j] = graph.neighbor_list[offset + j];
+      }
+    } else {
+      for (int j = threadIdx.x; j < degree; j += WARP_SIZE) {
+        output_idx[output_start + j] = j;
+      }
+
+      __syncwarp();
+
+      for (int j = sample_size + threadIdx.x; j < degree; j += WARP_SIZE) {
+        const int num = curand(&rng) % (j + 1);
+        if (num < sample_size) {
+          atomicMax(
+              reinterpret_cast<unsigned int*>(output_idx + output_start + num),
+              static_cast<unsigned int>(j));
+        }
+      }
+
+      __syncwarp();
+
+      for (int j = threadIdx.x; j < sample_size; j += WARP_SIZE) {
+        const int perm_idx = output_idx[output_start + j] + offset;
+        sample_result[output_start + j] = graph.neighbor_list[perm_idx];
+      }
     }
+
+    i += BLOCK_WARPS;
   }
 }
@@ -79,7 +151,7 @@ int GpuPsGraphTable::load(const std::string& path, const std::string& param) {
   gpu i triggers a neighbor_sample task,
   when this task is done,
   this function is called to move the sample result on other gpu back
-  to gup i and aggragate the result.
+  to gpu i and aggregate the result.
   the sample_result is saved on src_sample_res and the actual sample size for
   each node is saved on actual_sample_size.
   the number of actual sample_result for
@@ -96,10 +168,50 @@ int GpuPsGraphTable::load(const std::string& path, const std::string& param) {
   that's what fill_dvals does.
 */
+void GpuPsGraphTable::move_neighbor_sample_size_to_source_gpu(
+    int gpu_id, int gpu_num, int* h_left, int* h_right, int* actual_sample_size,
+    int* total_sample_size) {
+  // This function copies actual_sample_size back to the source gpu,
+  // and computes the total sample size of each gpu.
+  for (int i = 0; i < gpu_num; i++) {
+    if (h_left[i] == -1 || h_right[i] == -1) {
+      continue;
+    }
+    auto shard_len = h_right[i] - h_left[i] + 1;
+    auto& node = path_[gpu_id][i].nodes_.front();
+    cudaMemcpyAsync(reinterpret_cast<char*>(actual_sample_size + h_left[i]),
+                    node.val_storage + sizeof(int) * shard_len,
+                    sizeof(int) * shard_len, cudaMemcpyDefault,
+                    node.out_stream);
+  }
+  for (int i = 0; i < gpu_num; ++i) {
+    if (h_left[i] == -1 || h_right[i] == -1) {
+      total_sample_size[i] = 0;
+      continue;
+    }
+    auto& node = path_[gpu_id][i].nodes_.front();
+    cudaStreamSynchronize(node.out_stream);
+
+    auto shard_len = h_right[i] - h_left[i] + 1;
+    thrust::device_vector<int> t_actual_sample_size(shard_len);
+    thrust::copy(actual_sample_size + h_left[i],
+                 actual_sample_size + h_left[i] + shard_len,
+                 t_actual_sample_size.begin());
+    total_sample_size[i] = thrust::reduce(t_actual_sample_size.begin(),
+                                          t_actual_sample_size.end());
+  }
+}
 void GpuPsGraphTable::move_neighbor_sample_result_to_source_gpu(
-    int gpu_id, int gpu_num, int sample_size, int* h_left, int* h_right,
-    int64_t* src_sample_res, int* actual_sample_size) {
+    int gpu_id, int gpu_num, int* h_left, int* h_right, int64_t* src_sample_res,
+    thrust::host_vector<int>& total_sample_size) {
+  /*
+  if total_sample_size is [4, 5, 1, 6],
+  then cumsum_total_sample_size is [0, 4, 9, 10];
+  */
+  thrust::host_vector<int> cumsum_total_sample_size(gpu_num, 0);
+  thrust::exclusive_scan(total_sample_size.begin(), total_sample_size.end(),
+                         cumsum_total_sample_size.begin(), 0);
   for (int i = 0; i < gpu_num; i++) {
     if (h_left[i] == -1 || h_right[i] == -1) {
       continue;
@@ -109,14 +221,10 @@ void GpuPsGraphTable::move_neighbor_sample_result_to_source_gpu(
     // auto& node = path_[gpu_id][i].nodes_[cur_step];
     auto& node = path_[gpu_id][i].nodes_.front();
     cudaMemcpyAsync(
-        reinterpret_cast<char*>(src_sample_res + h_left[i] * sample_size),
+        reinterpret_cast<char*>(src_sample_res + cumsum_total_sample_size[i]),
         node.val_storage + sizeof(int64_t) * shard_len,
-        node.val_bytes_len - sizeof(int64_t) * shard_len, cudaMemcpyDefault,
+        sizeof(int64_t) * total_sample_size[i], cudaMemcpyDefault,
         node.out_stream);
-    cudaMemcpyAsync(reinterpret_cast<char*>(actual_sample_size + h_left[i]),
-                    node.val_storage + sizeof(int) * shard_len,
-                    sizeof(int) * shard_len, cudaMemcpyDefault,
-                    node.out_stream);
   }
   for (int i = 0; i < gpu_num; ++i) {
     if (h_left[i] == -1 || h_right[i] == -1) {
@@ -131,17 +239,35 @@
   TODO: how to optimize it to eliminate the for loop
 */
-__global__ void fill_dvalues(int64_t* d_shard_vals, int64_t* d_vals,
-                             int* d_shard_actual_sample_size,
-                             int* d_actual_sample_size, int* idx,
-                             int sample_size, int len) {
+__global__ void fill_dvalues_actual_sample_size(int* d_shard_actual_sample_size,
+                                                int* d_actual_sample_size,
+                                                int* idx, int len) {
   const size_t i = blockIdx.x * blockDim.x + threadIdx.x;
   if (i < len) {
     d_actual_sample_size[idx[i]] = d_shard_actual_sample_size[i];
-    // d_vals[idx[i]] = d_shard_vals[i];
-    for (int j = 0; j < sample_size; j++) {
-      d_vals[idx[i] * sample_size + j] = d_shard_vals[i * sample_size + j];
+  }
+}
+
+template <int BLOCK_WARPS, int TILE_SIZE>
+__global__ void fill_dvalues_sample_result(int64_t* d_shard_vals,
+                                           int64_t* d_vals,
+                                           int* d_actual_sample_size, int* idx,
+                                           int* offset, int* d_offset,
+                                           int len) {
+  assert(blockDim.x == WARP_SIZE);
+  assert(blockDim.y == BLOCK_WARPS);
+
+  int i = blockIdx.x * TILE_SIZE + threadIdx.y;
+  const int last_idx = min(static_cast<int>(blockIdx.x + 1) * TILE_SIZE, len);
+  while (i < last_idx) {
+    const int sample_size = d_actual_sample_size[idx[i]];
+    for (int j = threadIdx.x; j < sample_size; j += WARP_SIZE) {
+      d_vals[offset[idx[i]] + j] = d_shard_vals[d_offset[i] + j];
     }
+#ifdef PADDLE_WITH_CUDA
+    __syncwarp();
+#endif
+    i += BLOCK_WARPS;
   }
 }
@@ -255,14 +381,12 @@ NeighborSampleResult* GpuPsGraphTable::graph_neighbor_sample(int gpu_id,
   h_left = [0,5],h_right = [4,8]
   */
+
   NeighborSampleResult* result = new NeighborSampleResult(sample_size, len);
   if (len == 0) {
     return result;
   }
-  cudaMalloc((void**)&result->val, len * sample_size * sizeof(int64_t));
-  cudaMalloc((void**)&result->actual_sample_size, len * sizeof(int));
-  int* actual_sample_size = result->actual_sample_size;
-  int64_t* val = result->val;
+
   int total_gpu = resource_->total_gpu();
   int dev_id = resource_->dev_id(gpu_id);
   platform::CUDAPlace place = platform::CUDAPlace(dev_id);
@@ -287,11 +411,6 @@ NeighborSampleResult* GpuPsGraphTable::graph_neighbor_sample(int gpu_id,
   auto d_shard_keys = memory::Alloc(place, len * sizeof(int64_t));
   int64_t* d_shard_keys_ptr = reinterpret_cast<int64_t*>(d_shard_keys->ptr());
-  auto d_shard_vals = memory::Alloc(place, sample_size * len * sizeof(int64_t));
-  int64_t* d_shard_vals_ptr = reinterpret_cast<int64_t*>(d_shard_vals->ptr());
-  auto d_shard_actual_sample_size = memory::Alloc(place, len * sizeof(int));
-  int* d_shard_actual_sample_size_ptr =
-      reinterpret_cast<int*>(d_shard_actual_sample_size->ptr());
 
   split_input_to_shard(key, d_idx_ptr, len, d_left_ptr, d_right_ptr, gpu_id);
@@ -331,6 +450,7 @@ NeighborSampleResult* GpuPsGraphTable::graph_neighbor_sample(int gpu_id,
     of alloc_mem_i, actual_sample_size_of_x equals
     ((int *)alloc_mem_i)[shard_len + x]
     */
+
     create_storage(gpu_id, i, shard_len * sizeof(int64_t),
                    shard_len * (1 + sample_size) * sizeof(int64_t));
   }
@@ -351,6 +471,7 @@ NeighborSampleResult* GpuPsGraphTable::graph_neighbor_sample(int gpu_id,
                    h_right[i] - h_left[i] + 1,
                    resource_->remote_stream(i, gpu_id));
   }
+
   for (int i = 0; i < total_gpu; ++i) {
     if (h_left[i] == -1) {
       continue;
     }
@@ -364,10 +485,42 @@ NeighborSampleResult* GpuPsGraphTable::graph_neighbor_sample(int gpu_id,
     int* res_array = reinterpret_cast<int*>(node.val_storage);
     int* actual_size_array = res_array + shard_len;
     int64_t* sample_array = (int64_t*)(res_array + shard_len * 2);
-    neighbor_sample_example<<<grid_size, block_size_, 0, resource_->remote_stream(i, gpu_id)>>>(
-        graph, res_array, actual_size_array, sample_array, sample_size,
-        shard_len);
+
+    // 1. get actual_size_array.
+    // 2. get sum of actual_size.
+    // 3. get offset ptr
+    thrust::device_vector<int> t_res_array(shard_len);
+    thrust::copy(res_array, res_array + shard_len, t_res_array.begin());
+    thrust::device_vector<int> t_actual_size_array(shard_len);
+    thrust::transform(t_res_array.begin(), t_res_array.end(),
+                      t_actual_size_array.begin(), DegreeFunctor(graph));
+
+    if (sample_size >= 0) {
+      thrust::transform(t_actual_size_array.begin(), t_actual_size_array.end(),
+                        t_actual_size_array.begin(), MaxFunctor(sample_size));
+    }
+
+    thrust::copy(t_actual_size_array.begin(), t_actual_size_array.end(),
+                 actual_size_array);
+
+    int total_sample_sum =
+        thrust::reduce(t_actual_size_array.begin(), t_actual_size_array.end());
+
+    thrust::device_vector<int> output_idx(total_sample_sum);
+    thrust::device_vector<int> output_offset(shard_len);
+    thrust::exclusive_scan(t_actual_size_array.begin(),
+                           t_actual_size_array.end(), output_offset.begin(), 0);
+
+    constexpr int BLOCK_WARPS = 128 / WARP_SIZE;
+    constexpr int TILE_SIZE = BLOCK_WARPS * 16;
+    const dim3 block_(WARP_SIZE, BLOCK_WARPS);
+    const dim3 grid_((shard_len + TILE_SIZE - 1) / TILE_SIZE);
+    neighbor_sample<
+        BLOCK_WARPS,
+        TILE_SIZE><<<grid_, block_, 0, resource_->remote_stream(i, gpu_id)>>>(
+        0, graph, sample_size, res_array, shard_len, sample_array,
+        thrust::raw_pointer_cast(output_idx.data()),
+        thrust::raw_pointer_cast(output_offset.data()));
   }
 
   for (int i = 0; i < total_gpu; ++i) {
@@ -378,13 +531,56 @@ NeighborSampleResult* GpuPsGraphTable::graph_neighbor_sample(int gpu_id,
     tables_[i]->rwlock_->UNLock();
   }
   // walk_to_src(num, total_gpu, h_left, h_right, d_shard_vals_ptr);
-  move_neighbor_sample_result_to_source_gpu(gpu_id, total_gpu, sample_size,
-                                            h_left, h_right, d_shard_vals_ptr,
-                                            d_shard_actual_sample_size_ptr);
-  fill_dvalues<<<grid_size, block_size_, 0, stream>>>(
-      d_shard_vals_ptr, val, d_shard_actual_sample_size_ptr, actual_sample_size,
-      d_idx_ptr, sample_size, len);
+  auto d_shard_actual_sample_size = memory::Alloc(place, len * sizeof(int));
+  int* d_shard_actual_sample_size_ptr =
+      reinterpret_cast<int*>(d_shard_actual_sample_size->ptr());
+  // Store total sample number of each gpu.
+  thrust::host_vector<int> d_shard_total_sample_size(total_gpu, 0);
+  move_neighbor_sample_size_to_source_gpu(
+      gpu_id, total_gpu, h_left, h_right, d_shard_actual_sample_size_ptr,
+      thrust::raw_pointer_cast(d_shard_total_sample_size.data()));
+  int allocate_sample_num = 0;
+  for (int i = 0; i < total_gpu; ++i) {
+    allocate_sample_num += d_shard_total_sample_size[i];
+  }
+  auto d_shard_vals =
+      memory::Alloc(place, allocate_sample_num * sizeof(int64_t));
+  int64_t* d_shard_vals_ptr = reinterpret_cast<int64_t*>(d_shard_vals->ptr());
+  move_neighbor_sample_result_to_source_gpu(gpu_id, total_gpu, h_left, h_right,
+                                            d_shard_vals_ptr,
+                                            d_shard_total_sample_size);
+
+  cudaMalloc((void**)&result->val, allocate_sample_num * sizeof(int64_t));
+  cudaMalloc((void**)&result->actual_sample_size, len * sizeof(int));
+  cudaMalloc((void**)&result->offset, len * sizeof(int));
+  int64_t* val = result->val;
+  int* actual_sample_size = result->actual_sample_size;
+  int* offset = result->offset;
+
+  fill_dvalues_actual_sample_size<<<grid_size, block_size_, 0, stream>>>(
+      d_shard_actual_sample_size_ptr, actual_sample_size, d_idx_ptr, len);
+  thrust::device_vector<int> t_actual_sample_size(len);
+  thrust::copy(actual_sample_size, actual_sample_size + len,
+               t_actual_sample_size.begin());
+  thrust::exclusive_scan(t_actual_sample_size.begin(),
+                         t_actual_sample_size.end(), offset, 0);
+  int* d_offset;
+  cudaMalloc(&d_offset, len * sizeof(int));
+  thrust::copy(d_shard_actual_sample_size_ptr,
+               d_shard_actual_sample_size_ptr + len,
+               t_actual_sample_size.begin());
+  thrust::exclusive_scan(t_actual_sample_size.begin(),
+                         t_actual_sample_size.end(), d_offset, 0);
+  constexpr int BLOCK_WARPS_ = 128 / WARP_SIZE;
+  constexpr int TILE_SIZE_ = BLOCK_WARPS_ * 16;
+  const dim3 block__(WARP_SIZE, BLOCK_WARPS_);
+  const dim3 grid__((len + TILE_SIZE_ - 1) / TILE_SIZE_);
+  fill_dvalues_sample_result<BLOCK_WARPS_, TILE_SIZE_><<<grid__, block__, 0, stream>>>(
+      d_shard_vals_ptr, val, actual_sample_size, d_idx_ptr, offset, d_offset,
+      len);
+  cudaStreamSynchronize(stream);
   for (int i = 0; i < total_gpu; ++i) {
     int shard_len = h_left[i] == -1 ? 0 : h_right[i] - h_left[i] + 1;
@@ -393,6 +589,7 @@ NeighborSampleResult* GpuPsGraphTable::graph_neighbor_sample(int gpu_id,
     }
     destroy_storage(gpu_id, i);
   }
+  cudaFree(d_offset);
   return result;
 }
diff --git a/paddle/fluid/framework/fleet/heter_ps/test_graph.cu b/paddle/fluid/framework/fleet/heter_ps/test_graph.cu
index 697e0ba2cdf3475d1e7ad48105bc55959461900f..06c7026eb51ca8ed808d528391ab6723fd83831c 100644
--- a/paddle/fluid/framework/fleet/heter_ps/test_graph.cu
+++ b/paddle/fluid/framework/fleet/heter_ps/test_graph.cu
@@ -94,19 +94,44 @@ TEST(TEST_FLEET, graph_comm) {
   0 --index--->0
   7 --index-->2
   */
+
   int64_t cpu_key[3] = {7, 0, 6};
   void *key;
   cudaMalloc((void **)&key, 3 * sizeof(int64_t));
   cudaMemcpy(key, cpu_key, 3 * sizeof(int64_t), cudaMemcpyHostToDevice);
   auto neighbor_sample_res = g.graph_neighbor_sample(0, (int64_t *)key, 3, 3);
-  res = new int64_t[9];
-  cudaMemcpy(res, neighbor_sample_res->val, 72, cudaMemcpyDeviceToHost);
-  int64_t expected_sample_val[] = {28, 29, 30, 0, -1, -1, 21, 22, 23};
-  for (int i = 0; i < 9; i++) {
-    if (expected_sample_val[i] != -1) {
-      ASSERT_EQ(res[i], expected_sample_val[i]);
+  res = new int64_t[7];
+  cudaMemcpy(res, neighbor_sample_res->val, 56, cudaMemcpyDeviceToHost);
+  int *actual_sample_size = new int[3];
+  cudaMemcpy(actual_sample_size, neighbor_sample_res->actual_sample_size, 12,
+             cudaMemcpyDeviceToHost);  // 3, 1, 3
+  int *cumsum_sample_size = new int[3];
+  cudaMemcpy(cumsum_sample_size, neighbor_sample_res->offset, 12,
+             cudaMemcpyDeviceToHost);  // 0, 3, 4
+
+  std::vector<std::vector<int64_t>> neighbors_;
+  std::vector<int64_t> neighbors_7 = {28, 29, 30, 31, 32, 33, 34, 35};
+  std::vector<int64_t> neighbors_0 = {0};
+  std::vector<int64_t> neighbors_6 = {21, 22, 23, 24, 25, 26, 27};
+  neighbors_.push_back(neighbors_7);
+  neighbors_.push_back(neighbors_0);
+  neighbors_.push_back(neighbors_6);
+  for (int i = 0; i < 3; i++) {
+    for (int j = cumsum_sample_size[i];
+         j < cumsum_sample_size[i] + actual_sample_size[i]; j++) {
+      bool flag = false;
+      for (int k = 0; k < neighbors_[i].size(); k++) {
+        if (res[j] == neighbors_[i][k]) {
+          flag = true;
+          break;
+        }
+      }
+      ASSERT_EQ(flag, true);
     }
   }
+  delete[] res;
+  delete[] actual_sample_size;
+  delete[] cumsum_sample_size;
   delete neighbor_sample_res;
 }