From b61a6e71216cbbdd813ee04b2fe51741bf7c0886 Mon Sep 17 00:00:00 2001 From: seemingwang Date: Mon, 16 May 2022 09:40:03 +0800 Subject: [PATCH] fix node transfer problem (#42674) * enable graph-engine to return all id * change vector's dimension * change vector's dimension * enlarge returned ids dimensions * add actual_val * change vlog * fix bug * bug fix * bug fix * fix display test * singleton of gpu_graph_wrapper * change sample result's structure to fit training * recover sample code * fix * secondary sample * add graph partition * fix pybind * optimize buffer allocation * fix node transfer problem * remove log * support 32G+ graph on single gpu * remove logs * fix * fix * fix cpu query * display info * remove log * remove empty file Co-authored-by: DesmonDay <908660116@qq.com> --- .../ps/table/common_graph_table.cc | 52 ++-- .../framework/fleet/heter_ps/CMakeLists.txt | 7 +- .../framework/fleet/heter_ps/gpu_graph_node.h | 49 ++- .../fleet/heter_ps/graph_gpu_ps_table.h | 23 +- ..._table_inl.h => graph_gpu_ps_table_inl.cu} | 287 +++++++++--------- .../fleet/heter_ps/graph_gpu_wrapper.cu | 163 +++------- .../fleet/heter_ps/graph_gpu_wrapper.h | 14 +- .../fleet/heter_ps/hashtable_kernel.cu | 8 + .../framework/fleet/heter_ps/heter_comm_inl.h | 7 +- .../fleet/heter_ps/test_cpu_query.cu | 6 + paddle/fluid/pybind/fleet_py.cc | 13 +- 11 files changed, 301 insertions(+), 328 deletions(-) rename paddle/fluid/framework/fleet/heter_ps/{graph_gpu_ps_table_inl.h => graph_gpu_ps_table_inl.cu} (85%) diff --git a/paddle/fluid/distributed/ps/table/common_graph_table.cc b/paddle/fluid/distributed/ps/table/common_graph_table.cc index a3fa80b386..b53044b749 100644 --- a/paddle/fluid/distributed/ps/table/common_graph_table.cc +++ b/paddle/fluid/distributed/ps/table/common_graph_table.cc @@ -80,7 +80,7 @@ paddle::framework::GpuPsCommGraph GraphTable::make_gpu_ps_graph( } for (int i = 0; i < (int)tasks.size(); i++) tasks[i].get(); paddle::framework::GpuPsCommGraph res; - unsigned int tot_len = 0; + int64_t tot_len = 0; for (int i = 0; i < task_pool_size_; i++) { tot_len += edge_array[i].size(); } @@ -88,8 +88,8 @@ paddle::framework::GpuPsCommGraph GraphTable::make_gpu_ps_graph( // res.node_size = ids.size(); // res.neighbor_list = new int64_t[tot_len]; // res.node_list = new paddle::framework::GpuPsGraphNode[ids.size()]; - res.init_on_cpu(tot_len, (unsigned int)ids.size()); - unsigned int offset = 0, ind = 0; + res.init_on_cpu(tot_len, ids.size()); + int64_t offset = 0, ind = 0; for (int i = 0; i < task_pool_size_; i++) { for (int j = 0; j < (int)node_array[i].size(); j++) { res.node_list[ind] = node_array[i][j]; @@ -126,8 +126,8 @@ int32_t GraphTable::add_node_to_ssd(int type_id, int idx, int64_t src_id, _db->put(src_id % shard_num % task_pool_size_, ch, sizeof(int) * 2 + sizeof(int64_t), (char *)data, len); } - _db->flush(src_id % shard_num % task_pool_size_); - std::string x; + // _db->flush(src_id % shard_num % task_pool_size_); + // std::string x; // if (_db->get(src_id % shard_num % task_pool_size_, ch, sizeof(int64_t) + // 2 * sizeof(int), x) ==0){ // VLOG(0)<<"put result"; // for (int i = 0; i < len; i += sizeof(int64_t)) { // VLOG(0)<<"get an id "<<*((int64_t *)(x.c_str() + i)); // } //} // if(src_id == 429){ // str = ""; // _db->get(src_id % shard_num % task_pool_size_, ch, // sizeof(int) * 2 + sizeof(int64_t), str); // int64_t *stored_data = ((int64_t *)str.c_str()); // int n = str.size() / sizeof(int64_t); // VLOG(0)<<"429 has "<
edge_array[task_pool_size_]; std::vector> count(task_pool_size_); @@ -387,9 +403,9 @@ int32_t GraphTable::dump_edges_to_ssd(int idx) { [&, i, this]() -> int64_t { int64_t cost = 0; std::vector &v = shards[i]->get_bucket(); - std::vector s; size_t ind = i % this->task_pool_size_; for (size_t j = 0; j < v.size(); j++) { + std::vector s; for (int k = 0; k < v[j]->get_neighbor_size(); k++) { s.push_back(v[j]->get_neighbor_id(k)); } @@ -405,7 +421,7 @@ int32_t GraphTable::dump_edges_to_ssd(int idx) { } int32_t GraphTable::make_complementary_graph(int idx, int64_t byte_size) { VLOG(0) << "make_complementary_graph"; - const int64_t fixed_size = 10000; + const int64_t fixed_size = byte_size / 8; // std::vector edge_array[task_pool_size_]; std::vector> count(task_pool_size_); std::vector> tasks; @@ -416,7 +432,7 @@ int32_t GraphTable::make_complementary_graph(int idx, int64_t byte_size) { std::vector &v = shards[i]->get_bucket(); size_t ind = i % this->task_pool_size_; for (size_t j = 0; j < v.size(); j++) { - size_t location = v[j]->get_id(); + // size_t location = v[j]->get_id(); for (int k = 0; k < v[j]->get_neighbor_size(); k++) { count[ind][v[j]->get_neighbor_id(k)]++; } @@ -424,19 +440,12 @@ int32_t GraphTable::make_complementary_graph(int idx, int64_t byte_size) { return 0; })); } - + for (size_t i = 0; i < tasks.size(); i++) tasks[i].get(); std::unordered_map final_count; std::map> count_to_id; std::vector buffer; - for (auto p : edge_shards[idx]) { - delete p; - } + clear_graph(idx); - edge_shards[idx].clear(); - for (size_t i = 0; i < shard_num_per_server; i++) { - edge_shards[idx].push_back(new GraphShard()); - } - for (size_t i = 0; i < tasks.size(); i++) tasks[i].get(); for (int i = 0; i < task_pool_size_; i++) { for (auto &p : count[i]) { final_count[p.first] = final_count[p.first] + p.second; @@ -447,13 +456,13 @@ int32_t GraphTable::make_complementary_graph(int idx, int64_t byte_size) { count_to_id[p.second].push_back(p.first); VLOG(2) << p.first << " appear " << p.second << " times"; } - // std::map>::iterator iter= count_to_id.rbegin(); auto iter = count_to_id.rbegin(); while (iter != count_to_id.rend() && byte_size > 0) { for (auto x : iter->second) { buffer.push_back(x); if (buffer.size() >= fixed_size) { int64_t res = load_graph_to_memory_from_ssd(idx, buffer); + buffer.clear(); byte_size -= res; } if (byte_size <= 0) break; @@ -1265,13 +1274,14 @@ int32_t GraphTable::random_sample_neighbors( if (node == nullptr) { #ifdef PADDLE_WITH_HETERPS if (search_level == 2) { - VLOG(2) << "enter sample from ssd"; + VLOG(2) << "enter sample from ssd for node_id " << node_id; char *buffer_addr = random_sample_neighbor_from_ssd( idx, node_id, sample_size, rng, actual_size); if (actual_size != 0) { - std::shared_ptr &buffer = buffers[idx]; + std::shared_ptr &buffer = buffers[idy]; buffer.reset(buffer_addr, char_del); } + VLOG(2) << "actual sampled size from ssd = " << actual_sizes[idy]; continue; } #endif diff --git a/paddle/fluid/framework/fleet/heter_ps/CMakeLists.txt b/paddle/fluid/framework/fleet/heter_ps/CMakeLists.txt index 51456457d0..d62fc1c084 100644 --- a/paddle/fluid/framework/fleet/heter_ps/CMakeLists.txt +++ b/paddle/fluid/framework/fleet/heter_ps/CMakeLists.txt @@ -13,11 +13,10 @@ IF(WITH_GPU) nv_test(test_heter_comm SRCS feature_value.h DEPS heter_comm) nv_library(heter_ps SRCS heter_ps.cu DEPS heter_comm) if(WITH_PSCORE) - nv_library(graph_gpu_ps SRCS graph_gpu_ps_table.h DEPS heter_comm table hashtable_kernel) + nv_library(graph_gpu_ps SRCS graph_gpu_ps_table_inl.cu DEPS 
heter_comm table hashtable_kernel) nv_library(graph_sampler SRCS graph_sampler_inl.h DEPS graph_gpu_ps) - - nv_test(test_cpu_query SRCS test_cpu_query.cu DEPS heter_comm table heter_comm_kernel hashtable_kernel heter_ps ${HETERPS_DEPS}) - nv_library(graph_gpu_wrapper SRCS graph_gpu_wrapper.cu DEPS heter_comm table heter_comm_kernel hashtable_kernel heter_ps ${HETERPS_DEPS}) + nv_library(graph_gpu_wrapper SRCS graph_gpu_wrapper.cu DEPS heter_comm table heter_comm_kernel hashtable_kernel heter_ps ${HETERPS_DEPS} graph_gpu_ps) + nv_test(test_cpu_query SRCS test_cpu_query.cu DEPS heter_comm table heter_comm_kernel hashtable_kernel heter_ps ${HETERPS_DEPS} graph_gpu_ps graph_gpu_wrapper) #ADD_EXECUTABLE(test_sample_rate test_sample_rate.cu) #target_link_libraries(test_sample_rate heter_comm table heter_comm_kernel hashtable_kernel heter_ps ${HETERPS_DEPS}) #nv_test(test_sample_rate SRCS test_sample_rate.cu DEPS heter_comm table heter_comm_kernel hashtable_kernel heter_ps ${HETERPS_DEPS}) diff --git a/paddle/fluid/framework/fleet/heter_ps/gpu_graph_node.h b/paddle/fluid/framework/fleet/heter_ps/gpu_graph_node.h index e7601edb0c..19c355c671 100644 --- a/paddle/fluid/framework/fleet/heter_ps/gpu_graph_node.h +++ b/paddle/fluid/framework/fleet/heter_ps/gpu_graph_node.h @@ -24,7 +24,7 @@ namespace paddle { namespace framework { struct GpuPsGraphNode { int64_t node_id; - unsigned int neighbor_size, neighbor_offset; + int64_t neighbor_size, neighbor_offset; // this node's neighbor is stored on [neighbor_offset,neighbor_offset + // neighbor_size) of int64_t *neighbor_list; }; @@ -32,17 +32,17 @@ struct GpuPsGraphNode { struct GpuPsCommGraph { int64_t *neighbor_list; GpuPsGraphNode *node_list; - unsigned int neighbor_size, node_size; + int64_t neighbor_size, node_size; // the size of neighbor array and graph_node_list array GpuPsCommGraph() : neighbor_list(NULL), node_list(NULL), neighbor_size(0), node_size(0) {} GpuPsCommGraph(int64_t *neighbor_list_, GpuPsGraphNode *node_list_, - unsigned int neighbor_size_, unsigned int node_size_) + int64_t neighbor_size_, int64_t node_size_) : neighbor_list(neighbor_list_), node_list(node_list_), neighbor_size(neighbor_size_), node_size(node_size_) {} - void init_on_cpu(unsigned int neighbor_size, unsigned int node_size) { + void init_on_cpu(int64_t neighbor_size, int64_t node_size) { this->neighbor_size = neighbor_size; this->node_size = node_size; this->neighbor_list = new int64_t[neighbor_size]; @@ -208,12 +208,43 @@ struct NeighborSampleResult { delete[] ac_size; VLOG(0) << " ------------------"; } - NeighborSampleResult(){}; - ~NeighborSampleResult() { - // if (val != NULL) cudaFree(val); - // if (actual_sample_size != NULL) cudaFree(actual_sample_size); - // if (offset != NULL) cudaFree(offset); + std::vector get_sampled_graph(NeighborSampleQuery q) { + std::vector graph; + int64_t *sample_keys = new int64_t[q.len]; + std::string key_str; + cudaMemcpy(sample_keys, q.key, q.len * sizeof(int64_t), + cudaMemcpyDeviceToHost); + int64_t *res = new int64_t[sample_size * key_size]; + cudaMemcpy(res, val, sample_size * key_size * sizeof(int64_t), + cudaMemcpyDeviceToHost); + int *ac_size = new int[key_size]; + cudaMemcpy(ac_size, actual_sample_size, key_size * sizeof(int), + cudaMemcpyDeviceToHost); // 3, 1, 3 + int total_sample_size = 0; + for (int i = 0; i < key_size; i++) { + total_sample_size += ac_size[i]; + } + int64_t *res2 = new int64_t[total_sample_size]; // r + cudaMemcpy(res2, actual_val, total_sample_size * sizeof(int64_t), + cudaMemcpyDeviceToHost); // r 
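+ // res2 now holds every key's samples laid out back to back; the loop + // below flattens them into [key, actual_size, neighbor ids...] records, + // one record per key, which is the layout get_sampled_graph returns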
+ + int start = 0; + for (int i = 0; i < key_size; i++) { + graph.push_back(sample_keys[i]); + graph.push_back(ac_size[i]); + for (int j = 0; j < ac_size[i]; j++) { + graph.push_back(res2[start + j]); + } + start += ac_size[i]; // r + } + delete[] res; + delete[] res2; // r + delete[] ac_size; + delete[] sample_keys; + return graph; } + NeighborSampleResult(){}; + ~NeighborSampleResult() {} }; struct NodeQueryResult { diff --git a/paddle/fluid/framework/fleet/heter_ps/graph_gpu_ps_table.h b/paddle/fluid/framework/fleet/heter_ps/graph_gpu_ps_table.h index 8a0088114e..9e7ee80edc 100644 --- a/paddle/fluid/framework/fleet/heter_ps/graph_gpu_ps_table.h +++ b/paddle/fluid/framework/fleet/heter_ps/graph_gpu_ps_table.h @@ -23,15 +23,17 @@ #ifdef PADDLE_WITH_HETERPS namespace paddle { namespace framework { -class GpuPsGraphTable : public HeterComm { +class GpuPsGraphTable : public HeterComm { public: GpuPsGraphTable(std::shared_ptr resource, int topo_aware) - : HeterComm(1, resource) { + : HeterComm(1, resource) { load_factor_ = 0.25; rw_lock.reset(new pthread_rwlock_t()); gpu_num = resource_->total_device(); + memset(global_device_map, -1, sizeof(global_device_map)); for (int i = 0; i < gpu_num; i++) { gpu_graph_list.push_back(GpuPsCommGraph()); + global_device_map[resource_->dev_id(i)] = i; sample_status.push_back(NULL); tables_.push_back(NULL); } @@ -98,27 +100,20 @@ class GpuPsGraphTable : public HeterComm { NeighborSampleResult graph_neighbor_sample_v2(int gpu_id, int64_t *key, int sample_size, int len, bool cpu_query_switch); + void init_sample_status(); + void free_sample_status(); NodeQueryResult query_node_list(int gpu_id, int start, int query_size); void clear_graph_info(); + void display_sample_res(void *key, void *val, int len, int sample_len); void move_neighbor_sample_result_to_source_gpu(int gpu_id, int gpu_num, int sample_size, int *h_left, int *h_right, int64_t *src_sample_res, int *actual_sample_size); - // void move_neighbor_sample_result_to_source_gpu( - // int gpu_id, int gpu_num, int *h_left, int *h_right, - // int64_t *src_sample_res, thrust::host_vector &total_sample_size); - // void move_neighbor_sample_size_to_source_gpu(int gpu_id, int gpu_num, - // int *h_left, int *h_right, - // int *actual_sample_size, - // int *total_sample_size); int init_cpu_table(const paddle::distributed::GraphParameter &graph); - // int load(const std::string &path, const std::string ¶m); - // virtual int32_t end_graph_sampling() { - // return cpu_graph_table->end_graph_sampling(); - // } int gpu_num; std::vector gpu_graph_list; + int global_device_map[32]; std::vector sample_status; const int parallel_sample_size = 1; const int dim_y = 256; @@ -130,5 +125,5 @@ class GpuPsGraphTable : public HeterComm { }; } }; -#include "paddle/fluid/framework/fleet/heter_ps/graph_gpu_ps_table_inl.h" +//#include "paddle/fluid/framework/fleet/heter_ps/graph_gpu_ps_table_inl.h" #endif diff --git a/paddle/fluid/framework/fleet/heter_ps/graph_gpu_ps_table_inl.h b/paddle/fluid/framework/fleet/heter_ps/graph_gpu_ps_table_inl.cu similarity index 85% rename from paddle/fluid/framework/fleet/heter_ps/graph_gpu_ps_table_inl.h rename to paddle/fluid/framework/fleet/heter_ps/graph_gpu_ps_table_inl.cu index d28ae0ab5d..4cf579ce00 100644 --- a/paddle/fluid/framework/fleet/heter_ps/graph_gpu_ps_table_inl.h +++ b/paddle/fluid/framework/fleet/heter_ps/graph_gpu_ps_table_inl.cu @@ -18,7 +18,7 @@ #include #pragma once #ifdef PADDLE_WITH_HETERPS -//#include "paddle/fluid/framework/fleet/heter_ps/graph_gpu_ps_table.h" +#include 
"paddle/fluid/framework/fleet/heter_ps/graph_gpu_ps_table.h" namespace paddle { namespace framework { /* @@ -32,21 +32,21 @@ sample_result is to save the neighbor sampling result, its size is len * sample_size; */ -__global__ void get_cpu_id_index(int64_t* key, unsigned int* val, - int64_t* cpu_key, int* sum, int* index, - int len) { +__global__ void get_cpu_id_index(int64_t* key, int64_t* val, int64_t* cpu_key, + int* sum, int* index, int len) { CUDA_KERNEL_LOOP(i, len) { - if (val[i] == ((unsigned int)-1)) { + if (val[i] == -1) { int old = atomicAdd(sum, 1); cpu_key[old] = key[i]; index[old] = i; + // printf("old %d i-%d key:%lld\n",old,i,key[i]); } } } template __global__ void neighbor_sample_example_v2(GpuPsCommGraph graph, - unsigned int* node_index, + int64_t* node_index, int* actual_size, int64_t* res, int sample_len, int n) { assert(blockDim.x == WARP_SIZE); @@ -58,13 +58,13 @@ __global__ void neighbor_sample_example_v2(GpuPsCommGraph graph, curand_init(blockIdx.x, threadIdx.y * WARP_SIZE + threadIdx.x, 0, &rng); while (i < last_idx) { - if (node_index[i] == (unsigned int)(-1)) { + if (node_index[i] == -1) { actual_size[i] = 0; i += BLOCK_WARPS; continue; } - int neighbor_len = graph.node_list[node_index[i]].neighbor_size; - int data_offset = graph.node_list[node_index[i]].neighbor_offset; + int neighbor_len = (int)graph.node_list[node_index[i]].neighbor_size; + int64_t data_offset = graph.node_list[node_index[i]].neighbor_offset; int offset = i * sample_len; int64_t* data = graph.neighbor_list; if (neighbor_len <= sample_len) { @@ -86,7 +86,7 @@ __global__ void neighbor_sample_example_v2(GpuPsCommGraph graph, } __syncwarp(); for (int j = threadIdx.x; j < sample_len; j += WARP_SIZE) { - const int perm_idx = res[offset + j] + data_offset; + const int64_t perm_idx = res[offset + j] + data_offset; res[offset + j] = data[perm_idx]; } actual_size[i] = sample_len; @@ -96,23 +96,22 @@ __global__ void neighbor_sample_example_v2(GpuPsCommGraph graph, } __global__ void neighbor_sample_example(GpuPsCommGraph graph, - unsigned int* node_index, - int* actual_size, int64_t* res, - int sample_len, int* sample_status, - int n, int from) { + int64_t* node_index, int* actual_size, + int64_t* res, int sample_len, + int* sample_status, int n, int from) { int id = blockIdx.x * blockDim.y + threadIdx.y; if (id < n) { - if (node_index[id] == (unsigned int)(-1)) { + if (node_index[id] == -1) { actual_size[id] = 0; return; } curandState rng; curand_init(blockIdx.x, threadIdx.x, threadIdx.y, &rng); - int index = threadIdx.x; - int offset = id * sample_len; + int64_t index = threadIdx.x; + int64_t offset = id * sample_len; int64_t* data = graph.neighbor_list; - int data_offset = graph.node_list[node_index[id]].neighbor_offset; - int neighbor_len = graph.node_list[node_index[id]].neighbor_size; + int64_t data_offset = graph.node_list[node_index[id]].neighbor_offset; + int64_t neighbor_len = graph.node_list[node_index[id]].neighbor_size; int ac_len; if (sample_len > neighbor_len) ac_len = neighbor_len; @@ -220,6 +219,29 @@ int GpuPsGraphTable::init_cpu_table( that's what fill_dvals does. 
*/ +void GpuPsGraphTable::display_sample_res(void* key, void* val, int len, + int sample_len) { + char key_buffer[len * sizeof(int64_t)]; + char val_buffer[sample_len * sizeof(int64_t) * len + + (len + len % 2) * sizeof(int) + len * sizeof(int64_t)]; + cudaMemcpy(key_buffer, key, sizeof(int64_t) * len, cudaMemcpyDeviceToHost); + cudaMemcpy(val_buffer, val, + sample_len * sizeof(int64_t) * len + + (len + len % 2) * sizeof(int) + len * sizeof(int64_t), + cudaMemcpyDeviceToHost); + int64_t* sample_val = (int64_t*)(val_buffer + (len + len % 2) * sizeof(int) + + len * sizeof(int64_t)); + for (int i = 0; i < len; i++) { + printf("key %lld\n", *(int64_t*)(key_buffer + i * sizeof(int64_t))); + printf("index %lld\n", *(int64_t*)(val_buffer + i * sizeof(int64_t))); + int ac_size = *(int*)(val_buffer + i * sizeof(int) + len * sizeof(int64_t)); + printf("sampled %d neighbors\n", ac_size); + for (int j = 0; j < ac_size; j++) { + printf("%lld ", sample_val[i * sample_len + j]); + } + printf("\n"); + } +} void GpuPsGraphTable::move_neighbor_sample_result_to_source_gpu( int start_index, int gpu_num, int sample_size, int* h_left, int* h_right, int64_t* src_sample_res, int* actual_sample_size) { @@ -229,7 +251,7 @@ void GpuPsGraphTable::move_neighbor_sample_result_to_source_gpu( continue; } shard_len[i] = h_right[i] - h_left[i] + 1; - int cur_step = path_[start_index][i].nodes_.size() - 1; + int cur_step = (int)path_[start_index][i].nodes_.size() - 1; for (int j = cur_step; j > 0; j--) { cudaMemcpyAsync(path_[start_index][i].nodes_[j - 1].val_storage, path_[start_index][i].nodes_[j].val_storage, @@ -240,12 +262,12 @@ void GpuPsGraphTable::move_neighbor_sample_result_to_source_gpu( auto& node = path_[start_index][i].nodes_.front(); cudaMemcpyAsync( reinterpret_cast(src_sample_res + h_left[i] * sample_size), - node.val_storage + sizeof(int64_t) * shard_len[i], - node.val_bytes_len - sizeof(int64_t) * shard_len[i], cudaMemcpyDefault, + node.val_storage + sizeof(int64_t) * shard_len[i] + + sizeof(int) * (shard_len[i] + shard_len[i] % 2), + sizeof(int64_t) * shard_len[i] * sample_size, cudaMemcpyDefault, node.out_stream); - // resource_->remote_stream(i, start_index)); cudaMemcpyAsync(reinterpret_cast(actual_sample_size + h_left[i]), - node.val_storage + sizeof(int) * shard_len[i], + node.val_storage + sizeof(int64_t) * shard_len[i], sizeof(int) * shard_len[i], cudaMemcpyDefault, node.out_stream); } @@ -440,15 +462,15 @@ void GpuPsGraphTable::build_graph_on_single_gpu(GpuPsCommGraph& g, int i) { // platform::CUDADeviceGuard guard(i); gpu_graph_list[i] = GpuPsCommGraph(); sample_status[i] = NULL; - tables_[i] = new Table(std::max((unsigned int)1, g.node_size) / load_factor_); + tables_[i] = new Table(std::max((int64_t)1, g.node_size) / load_factor_); if (g.node_size > 0) { std::vector keys; - std::vector offset; + std::vector offset; cudaMalloc((void**)&gpu_graph_list[i].node_list, g.node_size * sizeof(GpuPsGraphNode)); cudaMemcpy(gpu_graph_list[i].node_list, g.node_list, g.node_size * sizeof(GpuPsGraphNode), cudaMemcpyHostToDevice); - for (unsigned int j = 0; j < g.node_size; j++) { + for (int64_t j = 0; j < g.node_size; j++) { keys.push_back(g.node_list[j].node_id); offset.push_back(j); } @@ -460,12 +482,15 @@ void GpuPsGraphTable::build_graph_on_single_gpu(GpuPsCommGraph& g, int i) { gpu_graph_list[i].node_size = 0; } if (g.neighbor_size) { - int* addr; - cudaMalloc((void**)&addr, g.neighbor_size * sizeof(int)); - cudaMemset(addr, 0, g.neighbor_size * sizeof(int)); - sample_status[i] = addr; -
cudaMalloc((void**)&gpu_graph_list[i].neighbor_list, - g.neighbor_size * sizeof(int64_t)); + cudaError_t cudaStatus = + cudaMalloc((void**)&gpu_graph_list[i].neighbor_list, + g.neighbor_size * sizeof(int64_t)); + PADDLE_ENFORCE_EQ(cudaStatus, cudaSuccess, + platform::errors::InvalidArgument( + "Failed to allocate memory for graph on gpu ")); + VLOG(0) << "successfully allocated " << g.neighbor_size * sizeof(int64_t) + << " bytes of memory for graph-edges on gpu " + << resource_->dev_id(i); cudaMemcpy(gpu_graph_list[i].neighbor_list, g.neighbor_list, g.neighbor_size * sizeof(int64_t), cudaMemcpyHostToDevice); gpu_graph_list[i].neighbor_size = g.neighbor_size; @@ -474,6 +499,27 @@ void GpuPsGraphTable::build_graph_on_single_gpu(GpuPsCommGraph& g, int i) { gpu_graph_list[i].neighbor_size = 0; } } + +void GpuPsGraphTable::init_sample_status() { + for (int i = 0; i < gpu_num; i++) { + if (gpu_graph_list[i].neighbor_size) { + platform::CUDADeviceGuard guard(resource_->dev_id(i)); + int* addr; + cudaMalloc((void**)&addr, gpu_graph_list[i].neighbor_size * sizeof(int)); + cudaMemset(addr, 0, gpu_graph_list[i].neighbor_size * sizeof(int)); + sample_status[i] = addr; + } + } +} + +void GpuPsGraphTable::free_sample_status() { + for (int i = 0; i < gpu_num; i++) { + if (sample_status[i] != NULL) { + platform::CUDADeviceGuard guard(resource_->dev_id(i)); + cudaFree(sample_status[i]); + } + } +} void GpuPsGraphTable::build_graph_from_cpu( std::vector& cpu_graph_list) { VLOG(0) << "in build_graph_from_cpu cpu_graph_list size = " @@ -485,22 +531,19 @@ void GpuPsGraphTable::build_graph_from_cpu( clear_graph_info(); for (int i = 0; i < cpu_graph_list.size(); i++) { platform::CUDADeviceGuard guard(resource_->dev_id(i)); - // platform::CUDADeviceGuard guard(i); gpu_graph_list[i] = GpuPsCommGraph(); sample_status[i] = NULL; - // auto table = - // new Table(std::max(1, cpu_graph_list[i].node_size) / load_factor_); - tables_[i] = new Table( - std::max((unsigned int)1, cpu_graph_list[i].node_size) / load_factor_); + tables_[i] = new Table(std::max((int64_t)1, cpu_graph_list[i].node_size) / + load_factor_); if (cpu_graph_list[i].node_size > 0) { std::vector keys; - std::vector offset; + std::vector offset; cudaMalloc((void**)&gpu_graph_list[i].node_list, cpu_graph_list[i].node_size * sizeof(GpuPsGraphNode)); cudaMemcpy(gpu_graph_list[i].node_list, cpu_graph_list[i].node_list, cpu_graph_list[i].node_size * sizeof(GpuPsGraphNode), cudaMemcpyHostToDevice); - for (unsigned int j = 0; j < cpu_graph_list[i].node_size; j++) { + for (int64_t j = 0; j < cpu_graph_list[i].node_size; j++) { keys.push_back(cpu_graph_list[i].node_list[j].node_id); offset.push_back(j); } @@ -512,12 +555,9 @@ void GpuPsGraphTable::build_graph_from_cpu( gpu_graph_list[i].node_size = 0; } if (cpu_graph_list[i].neighbor_size) { - int* addr; - cudaMalloc((void**)&addr, cpu_graph_list[i].neighbor_size * sizeof(int)); - cudaMemset(addr, 0, cpu_graph_list[i].neighbor_size * sizeof(int)); - sample_status[i] = addr; cudaMalloc((void**)&gpu_graph_list[i].neighbor_list, cpu_graph_list[i].neighbor_size * sizeof(int64_t)); + cudaMemcpy(gpu_graph_list[i].neighbor_list, cpu_graph_list[i].neighbor_list, cpu_graph_list[i].neighbor_size * sizeof(int64_t), cudaMemcpyHostToDevice); @@ -533,8 +573,8 @@ void GpuPsGraphTable::build_graph_from_cpu( NeighborSampleResult GpuPsGraphTable::graph_neighbor_sample_v3( NeighborSampleQuery q, bool cpu_switch) { - return graph_neighbor_sample_v2(q.gpu_id, q.key, q.sample_size, q.len, - cpu_switch); + return
graph_neighbor_sample_v2(global_device_map[q.gpu_id], q.key, + q.sample_size, q.len, cpu_switch); } NeighborSampleResult GpuPsGraphTable::graph_neighbor_sample(int gpu_id, int64_t* key, @@ -571,12 +611,9 @@ NeighborSampleResult GpuPsGraphTable::graph_neighbor_sample(int gpu_id, } platform::CUDAPlace place = platform::CUDAPlace(resource_->dev_id(gpu_id)); platform::CUDADeviceGuard guard(resource_->dev_id(gpu_id)); - // cudaMalloc((void**)&result->val, len * sample_size * sizeof(int64_t)); - // cudaMalloc((void**)&result->actual_sample_size, len * sizeof(int)); int* actual_sample_size = result.actual_sample_size; int64_t* val = result.val; int total_gpu = resource_->total_device(); - // int dev_id = resource_->dev_id(gpu_id); auto stream = resource_->local_stream(gpu_id, 0); int grid_size = (len - 1) / block_size_ + 1; @@ -605,9 +642,6 @@ NeighborSampleResult GpuPsGraphTable::graph_neighbor_sample(int gpu_id, split_input_to_shard(key, d_idx_ptr, len, d_left_ptr, d_right_ptr, gpu_id); - // fill_shard_key<<>>(d_shard_keys_ptr, - // key, - // d_idx_ptr, len); heter_comm_kernel_->fill_shard_key(d_shard_keys_ptr, key, d_idx_ptr, len, stream); cudaStreamSynchronize(stream); @@ -643,95 +677,47 @@ NeighborSampleResult GpuPsGraphTable::graph_neighbor_sample(int gpu_id, of alloc_mem_i, actual_sample_size_of_x equals ((int *)alloc_mem_i)[shard_len + x] */ + create_storage(gpu_id, i, shard_len * sizeof(int64_t), - shard_len * (1 + sample_size) * sizeof(int64_t)); - auto& node = path_[gpu_id][i].nodes_[0]; - cudaMemsetAsync(node.val_storage, -1, shard_len * sizeof(int), - node.in_stream); + shard_len * (1 + sample_size) * sizeof(int64_t) + + sizeof(int) * (shard_len + shard_len % 2)); + // auto& node = path_[gpu_id][i].nodes_[0]; } - // auto end1 = std::chrono::steady_clock::now(); - // auto tt = std::chrono::duration_cast(end1 - - // start1); - // VLOG(0)<< "create storage time " << tt.count() << " us"; walk_to_dest(gpu_id, total_gpu, h_left, h_right, d_shard_keys_ptr, NULL); for (int i = 0; i < total_gpu; ++i) { if (h_left[i] == -1) { continue; } + int shard_len = h_left[i] == -1 ? 
0 : h_right[i] - h_left[i] + 1; auto& node = path_[gpu_id][i].nodes_.back(); + cudaMemsetAsync(node.val_storage, -1, shard_len * sizeof(int64_t), + node.in_stream); cudaStreamSynchronize(node.in_stream); platform::CUDADeviceGuard guard(resource_->dev_id(i)); - // platform::CUDADeviceGuard guard(i); - // use the key-value map to update alloc_mem_i[0,shard_len) - // tables_[i]->rwlock_->RDLock(); tables_[i]->get(reinterpret_cast(node.key_storage), - reinterpret_cast(node.val_storage), + reinterpret_cast(node.val_storage), h_right[i] - h_left[i] + 1, resource_->remote_stream(i, gpu_id)); // node.in_stream); - int shard_len = h_right[i] - h_left[i] + 1; auto graph = gpu_graph_list[i]; - unsigned int* id_array = reinterpret_cast(node.val_storage); + int64_t* id_array = reinterpret_cast(node.val_storage); int* actual_size_array = (int*)(id_array + shard_len); - int64_t* sample_array = (int64_t*)(actual_size_array + shard_len); - int sample_grid_size = (shard_len - 1) / dim_y + 1; - dim3 block(parallel_sample_size, dim_y); - dim3 grid(sample_grid_size); - // int sample_grid_size = shard_len / block_size_ + 1; - // VLOG(0)<<"in sample grid_size = "< user_feature_name = {"a", "b", "c", "d"}; -std::vector item_feature_name = {"a"}; -std::vector user_feature_dtype = {"float32", "int32", "string", - "string"}; -std::vector item_feature_dtype = {"float32"}; -std::vector user_feature_shape = {1, 2, 1, 1}; -std::vector item_feature_shape = {1}; -void prepare_file(char file_name[]) { - std::ofstream ofile; - ofile.open(file_name); - - for (auto x : nodes) { - ofile << x << std::endl; - } - ofile.close(); -} +std::shared_ptr GraphGpuWrapper::s_instance_(nullptr); void GraphGpuWrapper::set_device(std::vector ids) { for (auto device_id : ids) { device_id_mapping.push_back(device_id); @@ -205,96 +172,35 @@ void GraphGpuWrapper::upload_batch(int idx, // g->build_graph_from_cpu(vec); } -void GraphGpuWrapper::initialize() { - std::vector device_id_mapping; - for (int i = 0; i < 2; i++) device_id_mapping.push_back(i); - int gpu_num = device_id_mapping.size(); - ::paddle::distributed::GraphParameter table_proto; - table_proto.add_edge_types("u2u"); - table_proto.add_node_types("user"); - table_proto.add_node_types("item"); - ::paddle::distributed::GraphFeature *g_f = table_proto.add_graph_feature(); - - for (int i = 0; i < user_feature_name.size(); i++) { - g_f->add_name(user_feature_name[i]); - g_f->add_dtype(user_feature_dtype[i]); - g_f->add_shape(user_feature_shape[i]); - } - ::paddle::distributed::GraphFeature *g_f1 = table_proto.add_graph_feature(); - for (int i = 0; i < item_feature_name.size(); i++) { - g_f1->add_name(item_feature_name[i]); - g_f1->add_dtype(item_feature_dtype[i]); - g_f1->add_shape(item_feature_shape[i]); - } - prepare_file(node_file_name); - table_proto.set_shard_num(24); +// void GraphGpuWrapper::test() { +// int64_t cpu_key[3] = {0, 1, 2}; +// void *key; +// platform::CUDADeviceGuard guard(0); +// cudaMalloc((void **)&key, 3 * sizeof(int64_t)); +// cudaMemcpy(key, cpu_key, 3 * sizeof(int64_t), cudaMemcpyHostToDevice); +// auto neighbor_sample_res = +// ((GpuPsGraphTable *)graph_table) +// ->graph_neighbor_sample(0, (int64_t *)key, 2, 3); +// int64_t *res = new int64_t[7]; +// cudaMemcpy(res, neighbor_sample_res.val, 3 * 2 * sizeof(int64_t), +// cudaMemcpyDeviceToHost); +// int *actual_sample_size = new int[3]; +// cudaMemcpy(actual_sample_size, neighbor_sample_res.actual_sample_size, +// 3 * sizeof(int), +// cudaMemcpyDeviceToHost); // 3, 1, 3 - std::shared_ptr resource = - 
std::make_shared(device_id_mapping); - resource->enable_p2p(); - GpuPsGraphTable *g = new GpuPsGraphTable(resource, 1); - g->init_cpu_table(table_proto); - graph_table = (char *)g; - g->cpu_graph_table->Load(node_file_name, "nuser"); - g->cpu_graph_table->Load(node_file_name, "nitem"); - std::remove(node_file_name); - std::vector vec; - std::vector node_ids; - node_ids.push_back(37); - node_ids.push_back(96); - std::vector> node_feat(2, - std::vector(2)); - std::vector feature_names; - feature_names.push_back(std::string("c")); - feature_names.push_back(std::string("d")); - g->cpu_graph_table->get_node_feat(0, node_ids, feature_names, node_feat); - VLOG(0) << "get_node_feat: " << node_feat[0][0]; - VLOG(0) << "get_node_feat: " << node_feat[0][1]; - VLOG(0) << "get_node_feat: " << node_feat[1][0]; - VLOG(0) << "get_node_feat: " << node_feat[1][1]; - int n = 10; - std::vector ids0, ids1; - for (int i = 0; i < n; i++) { - g->cpu_graph_table->add_comm_edge(0, i, (i + 1) % n); - g->cpu_graph_table->add_comm_edge(0, i, (i - 1 + n) % n); - if (i % 2 == 0) ids0.push_back(i); - } - g->cpu_graph_table->build_sampler(0); - ids1.push_back(5); - vec.push_back(g->cpu_graph_table->make_gpu_ps_graph(0, ids0)); - vec.push_back(g->cpu_graph_table->make_gpu_ps_graph(0, ids1)); - vec[0].display_on_cpu(); - vec[1].display_on_cpu(); - g->build_graph_from_cpu(vec); -} -void GraphGpuWrapper::test() { - int64_t cpu_key[3] = {0, 1, 2}; - void *key; - platform::CUDADeviceGuard guard(0); - cudaMalloc((void **)&key, 3 * sizeof(int64_t)); - cudaMemcpy(key, cpu_key, 3 * sizeof(int64_t), cudaMemcpyHostToDevice); - auto neighbor_sample_res = - ((GpuPsGraphTable *)graph_table) - ->graph_neighbor_sample(0, (int64_t *)key, 2, 3); - int64_t *res = new int64_t[7]; - cudaMemcpy(res, neighbor_sample_res.val, 3 * 2 * sizeof(int64_t), - cudaMemcpyDeviceToHost); - int *actual_sample_size = new int[3]; - cudaMemcpy(actual_sample_size, neighbor_sample_res.actual_sample_size, - 3 * sizeof(int), - cudaMemcpyDeviceToHost); // 3, 1, 3 - - //{0,9} or {9,0} is expected for key 0 - //{0,2} or {2,0} is expected for key 1 - //{1,3} or {3,1} is expected for key 2 - for (int i = 0; i < 3; i++) { - VLOG(0) << "actual sample size for " << i << " is " - << actual_sample_size[i]; - for (int j = 0; j < actual_sample_size[i]; j++) { - VLOG(0) << "sampled an neighbor for node" << i << " : " << res[i * 2 + j]; - } - } -} +// //{0,9} or {9,0} is expected for key 0 +// //{0,2} or {2,0} is expected for key 1 +// //{1,3} or {3,1} is expected for key 2 +// for (int i = 0; i < 3; i++) { +// VLOG(0) << "actual sample size for " << i << " is " +// << actual_sample_size[i]; +// for (int j = 0; j < actual_sample_size[i]; j++) { +// VLOG(0) << "sampled an neighbor for node" << i << " : " << res[i * 2 + +// j]; +// } +// } +// } NeighborSampleResult GraphGpuWrapper::graph_neighbor_sample_v3( NeighborSampleQuery q, bool cpu_switch) { return ((GpuPsGraphTable *)graph_table) @@ -314,7 +220,6 @@ std::vector GraphGpuWrapper::graph_neighbor_sample( auto neighbor_sample_res = ((GpuPsGraphTable *)graph_table) ->graph_neighbor_sample(gpu_id, cuda_key, sample_size, key.size()); - int *actual_sample_size = new int[key.size()]; cudaMemcpy(actual_sample_size, neighbor_sample_res.actual_sample_size, key.size() * sizeof(int), @@ -323,7 +228,6 @@ std::vector GraphGpuWrapper::graph_neighbor_sample( for (int i = 0; i < key.size(); i++) { cumsum += actual_sample_size[i]; } - /* VLOG(0) << "cumsum " << cumsum; */ std::vector cpu_key, res; cpu_key.resize(key.size() * sample_size); @@ 
-340,11 +244,18 @@ std::vector GraphGpuWrapper::graph_neighbor_sample( /* for(int i = 0;i < res.size();i ++) { */ /* VLOG(0) << i << " " << res[i]; */ /* } */ - + delete[] actual_sample_size; cudaFree(cuda_key); return res; } +void GraphGpuWrapper::init_sample_status() { + ((GpuPsGraphTable *)graph_table)->init_sample_status(); +} + +void GraphGpuWrapper::free_sample_status() { + ((GpuPsGraphTable *)graph_table)->free_sample_status(); +} NodeQueryResult GraphGpuWrapper::query_node_list(int gpu_id, int start, int query_size) { return ((GpuPsGraphTable *)graph_table) diff --git a/paddle/fluid/framework/fleet/heter_ps/graph_gpu_wrapper.h b/paddle/fluid/framework/fleet/heter_ps/graph_gpu_wrapper.h index b638311304..d8b11682bc 100644 --- a/paddle/fluid/framework/fleet/heter_ps/graph_gpu_wrapper.h +++ b/paddle/fluid/framework/fleet/heter_ps/graph_gpu_wrapper.h @@ -12,6 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. +#pragma once #include #include #include @@ -22,10 +23,13 @@ namespace framework { #ifdef PADDLE_WITH_HETERPS class GraphGpuWrapper { public: - static GraphGpuWrapper* GetInstance() { - static GraphGpuWrapper wrapper; - return &wrapper; + static std::shared_ptr GetInstance() { + if (NULL == s_instance_) { + s_instance_.reset(new paddle::framework::GraphGpuWrapper()); + } + return s_instance_; } + static std::shared_ptr s_instance_; void initialize(); void test(); void set_device(std::vector ids); @@ -53,6 +57,8 @@ class GraphGpuWrapper { std::vector& key, int sample_size); + void init_sample_status(); + void free_sample_status(); std::unordered_map edge_to_id, feature_to_id; std::vector id_to_feature, id_to_edge; std::vector> table_feat_mapping; @@ -62,7 +68,7 @@ class GraphGpuWrapper { ::paddle::distributed::GraphParameter table_proto; std::vector device_id_mapping; int search_level = 1; - char* graph_table; + void* graph_table; }; #endif } diff --git a/paddle/fluid/framework/fleet/heter_ps/hashtable_kernel.cu b/paddle/fluid/framework/fleet/heter_ps/hashtable_kernel.cu index df93f05691..5a29159aa1 100644 --- a/paddle/fluid/framework/fleet/heter_ps/hashtable_kernel.cu +++ b/paddle/fluid/framework/fleet/heter_ps/hashtable_kernel.cu @@ -320,6 +320,7 @@ void HashTable::update(const KeyType* d_keys, template class HashTable; template class HashTable; +template class HashTable; template class HashTable; template class HashTable; @@ -334,6 +335,9 @@ template void HashTable::get(const long* d_keys, template void HashTable::get( const long* d_keys, unsigned long* d_vals, size_t len, cudaStream_t stream); +template void HashTable::get(const long* d_keys, + long* d_vals, size_t len, + cudaStream_t stream); template void HashTable::get( const long* d_keys, unsigned int* d_vals, size_t len, cudaStream_t stream); // template void @@ -350,6 +354,10 @@ template void HashTable::insert(const long* d_keys, const int* d_vals, size_t len, cudaStream_t stream); +template void HashTable::insert(const long* d_keys, + const long* d_vals, + size_t len, + cudaStream_t stream); template void HashTable::insert( const long* d_keys, const unsigned long* d_vals, size_t len, diff --git a/paddle/fluid/framework/fleet/heter_ps/heter_comm_inl.h b/paddle/fluid/framework/fleet/heter_ps/heter_comm_inl.h index 2a4f535ef7..d23719ea9e 100644 --- a/paddle/fluid/framework/fleet/heter_ps/heter_comm_inl.h +++ b/paddle/fluid/framework/fleet/heter_ps/heter_comm_inl.h @@ -193,9 +193,10 @@ void HeterComm::walk_to_dest(int start_index, memory_copy(dst_place, 
node.key_storage, src_place, reinterpret_cast(src_key + h_left[i]), node.key_bytes_len, node.in_stream); -#if defined(PADDLE_WITH_CUDA) // adapt for gpu-graph - cudaMemsetAsync(node.val_storage, -1, node.val_bytes_len, node.in_stream); -#endif + // #if defined(PADDLE_WITH_CUDA) // adapt for gpu-graph + // cudaMemsetAsync(node.val_storage, -1, node.val_bytes_len, + // node.in_stream); + // #endif if (need_copy_val) { memory_copy(dst_place, node.val_storage, src_place, diff --git a/paddle/fluid/framework/fleet/heter_ps/test_cpu_query.cu b/paddle/fluid/framework/fleet/heter_ps/test_cpu_query.cu index b3a38a6dfd..ff3cd9d2d0 100644 --- a/paddle/fluid/framework/fleet/heter_ps/test_cpu_query.cu +++ b/paddle/fluid/framework/fleet/heter_ps/test_cpu_query.cu @@ -17,6 +17,7 @@ #include #include "paddle/fluid/framework/fleet/heter_ps/feature_value.h" #include "paddle/fluid/framework/fleet/heter_ps/graph_gpu_ps_table.h" +#include "paddle/fluid/framework/fleet/heter_ps/graph_gpu_wrapper.h" #include "paddle/fluid/framework/fleet/heter_ps/heter_comm.h" #include "paddle/fluid/framework/fleet/heter_ps/heter_resource.h" #include "paddle/fluid/framework/fleet/heter_ps/optimizer.cuh.h" @@ -235,4 +236,9 @@ TEST(TEST_FLEET, test_cpu_cache) { } index++; } + auto iter = paddle::framework::GraphGpuWrapper::GetInstance(); + std::vector device; + device.push_back(0); + device.push_back(1); + iter->set_device(device); } diff --git a/paddle/fluid/pybind/fleet_py.cc b/paddle/fluid/pybind/fleet_py.cc index bcf55e46ed..2549240aa1 100644 --- a/paddle/fluid/pybind/fleet_py.cc +++ b/paddle/fluid/pybind/fleet_py.cc @@ -327,16 +327,15 @@ void BindNeighborSampleResult(py::module* m) { .def("initialize", &NeighborSampleResult::initialize) .def("get_len", &NeighborSampleResult::get_len) .def("get_val", &NeighborSampleResult::get_actual_val) + .def("get_sampled_graph", &NeighborSampleResult::get_sampled_graph) .def("display", &NeighborSampleResult::display); } void BindGraphGpuWrapper(py::module* m) { - py::class_(*m, "GraphGpuWrapper") - // nit<>()) - //.def("test", &GraphGpuWrapper::test) - //.def(py::init([]() { return framework::GraphGpuWrapper::GetInstance(); - //})) - .def(py::init<>()) + py::class_>( + *m, "GraphGpuWrapper") + .def(py::init([]() { return GraphGpuWrapper::GetInstance(); })) + // .def(py::init<>()) .def("neighbor_sample", &GraphGpuWrapper::graph_neighbor_sample_v3) .def("graph_neighbor_sample", &GraphGpuWrapper::graph_neighbor_sample) .def("set_device", &GraphGpuWrapper::set_device) @@ -347,6 +346,8 @@ void BindGraphGpuWrapper(py::module* m) { .def("load_edge_file", &GraphGpuWrapper::load_edge_file) .def("upload_batch", &GraphGpuWrapper::upload_batch) .def("get_all_id", &GraphGpuWrapper::get_all_id) + .def("init_sample_status", &GraphGpuWrapper::init_sample_status) + .def("free_sample_status", &GraphGpuWrapper::free_sample_status) .def("load_next_partition", &GraphGpuWrapper::load_next_partition) .def("make_partitions", &GraphGpuWrapper::make_partitions) .def("make_complementary_graph", -- GitLab
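
A minimal usage sketch of the APIs added in this patch (assumed driver code, not part of the patch): it exercises the new shared_ptr singleton accessor, the init_sample_status/free_sample_status pair, and the host-side sampling entry point, and it assumes a graph has already been loaded and uploaded to the GPUs.

#include <cstdint>
#include <vector>
#include "paddle/fluid/framework/fleet/heter_ps/graph_gpu_wrapper.h"

// Hypothetical driver; the node ids and device list are placeholders.
void sample_demo() {
  auto wrapper = paddle::framework::GraphGpuWrapper::GetInstance();
  wrapper->set_device({0, 1});            // same device list as test_cpu_query.cu
  wrapper->init_sample_status();          // allocate per-GPU sample-status arrays
  std::vector<int64_t> keys = {0, 1, 2};  // assumed node ids
  // returns the flattened neighbor ids copied back to the host
  std::vector<int64_t> res =
      wrapper->graph_neighbor_sample(/*gpu_id=*/0, keys, /*sample_size=*/2);
  wrapper->free_sample_status();          // release the status arrays
}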