diff --git a/paddle/fluid/distributed/ps/service/CMakeLists.txt b/paddle/fluid/distributed/ps/service/CMakeLists.txt index f0ac7bc6a06359b952881af1200b88ff042367cc..e7519ef4998b13024b47d148c357eadd87943b95 100755 --- a/paddle/fluid/distributed/ps/service/CMakeLists.txt +++ b/paddle/fluid/distributed/ps/service/CMakeLists.txt @@ -1,10 +1,15 @@ set(BRPC_SRCS ps_client.cc server.cc) set_source_files_properties(${BRPC_SRCS}) + if(WITH_HETERPS) + set(BRPC_DEPS brpc ssl crypto protobuf gflags glog zlib leveldb snappy gflags glog device_context rocksdb) + else() + set(BRPC_DEPS brpc ssl crypto protobuf gflags glog zlib leveldb snappy gflags glog device_context) + endif() brpc_library(sendrecv_rpc SRCS diff --git a/paddle/fluid/framework/fleet/heter_ps/gpu_graph_node.h b/paddle/fluid/framework/fleet/heter_ps/gpu_graph_node.h index c4b4064e0299e4b2d4e72b8bfd5c106dcd0433db..a8fde3f36bc6d892e564f2308802ef79a64681a6 100644 --- a/paddle/fluid/framework/fleet/heter_ps/gpu_graph_node.h +++ b/paddle/fluid/framework/fleet/heter_ps/gpu_graph_node.h @@ -64,11 +64,9 @@ struct GpuPsCommGraph { /* suppose we have a graph like this - 0----3-----5----7 \ |\ |\ 17 8 9 1 2 - we save the nodes in arbitrary order, in this example,the order is [0,5,1,2,7,3,8,9,17] @@ -83,7 +81,6 @@ we record each node's neighbors: 8:3 9:3 17:0 - by concatenating each node's neighbor_list in the order we save the node id. we get [3,17,3,7,7,7,1,2,5,0,5,8,9,3,3,0] this is the neighbor_list of GpuPsCommGraph @@ -114,6 +111,32 @@ node_list[6]-> node_id:8, neighbor_size:1, neighbor_offset:13 node_list[7]-> node_id:9, neighbor_size:1, neighbor_offset:14 node_list[8]-> node_id:17, neighbor_size:1, neighbor_offset:15 */ +struct NeighborSampleQuery { + int gpu_id; + int64_t *key; + int sample_size; + int len; + void initialize(int gpu_id, int64_t key, int sample_size, int len) { + this->gpu_id = gpu_id; + this->key = (int64_t *)key; + this->sample_size = sample_size; + this->len = len; + } + void display() { + int64_t *sample_keys = new int64_t[len]; + VLOG(0) << "device_id " << gpu_id << " sample_size = " << sample_size; + VLOG(0) << "there are " << len << " keys "; + std::string key_str; + cudaMemcpy(sample_keys, key, len * sizeof(int64_t), cudaMemcpyDeviceToHost); + + for (int i = 0; i < len; i++) { + if (key_str.size() > 0) key_str += ";"; + key_str += std::to_string(sample_keys[i]); + } + VLOG(0) << key_str; + delete[] sample_keys; + } +}; struct NeighborSampleResult { int64_t *val; int *actual_sample_size, sample_size, key_size; @@ -134,6 +157,29 @@ struct NeighborSampleResult { memory::AllocShared(place, _key_size * sizeof(int)); actual_sample_size = (int *)actual_sample_size_mem->ptr(); } + void display() { + VLOG(0) << "in node sample result display ------------------"; + int64_t *res = new int64_t[sample_size * key_size]; + cudaMemcpy(res, val, sample_size * key_size * sizeof(int64_t), + cudaMemcpyDeviceToHost); + int *ac_size = new int[key_size]; + cudaMemcpy(ac_size, actual_sample_size, key_size * sizeof(int), + cudaMemcpyDeviceToHost); // 3, 1, 3 + + for (int i = 0; i < key_size; i++) { + VLOG(0) << "actual sample size for " << i << "th key is " << ac_size[i]; + VLOG(0) << "sampled neighbors are "; + std::string neighbor; + for (int j = 0; j < ac_size[i]; j++) { + if (neighbor.size() > 0) neighbor += ";"; + neighbor += std::to_string(res[i * sample_size + j]); + } + VLOG(0) << neighbor; + } + delete[] res; + delete[] ac_size; + VLOG(0) << " ------------------"; + } NeighborSampleResult(){}; ~NeighborSampleResult() { // if (val != NULL) cudaFree(val); @@ -145,13 +191,39 @@ struct NeighborSampleResult { struct NodeQueryResult { int64_t *val; int actual_sample_size; + int64_t get_val() { return (int64_t)val; } + int get_len() { return actual_sample_size; } + std::shared_ptr val_mem; + void initialize(int query_size, int dev_id) { + platform::CUDADeviceGuard guard(dev_id); + platform::CUDAPlace place = platform::CUDAPlace(dev_id); + val_mem = memory::AllocShared(place, query_size * sizeof(int64_t)); + val = (int64_t *)val_mem->ptr(); + + // cudaMalloc((void **)&val, query_size * sizeof(int64_t)); + actual_sample_size = 0; + } + void display() { + VLOG(0) << "in node query result display ------------------"; + int64_t *res = new int64_t[actual_sample_size]; + cudaMemcpy(res, val, actual_sample_size * sizeof(int64_t), + cudaMemcpyDeviceToHost); + + VLOG(0) << "actual_sample_size =" << actual_sample_size; + std::string str; + for (int i = 0; i < actual_sample_size; i++) { + if (str.size() > 0) str += ";"; + str += std::to_string(res[i]); + } + VLOG(0) << str; + delete[] res; + VLOG(0) << " ------------------"; + } NodeQueryResult() { val = NULL; actual_sample_size = 0; }; - ~NodeQueryResult() { - if (val != NULL) cudaFree(val); - } + ~NodeQueryResult() {} }; } }; diff --git a/paddle/fluid/framework/fleet/heter_ps/graph_gpu_ps_table.h b/paddle/fluid/framework/fleet/heter_ps/graph_gpu_ps_table.h index ff36b38b5089fcc99127333324ca92cfb9660d0d..7e5aa402677674bb5fc31aed1953ec40b8db484d 100644 --- a/paddle/fluid/framework/fleet/heter_ps/graph_gpu_ps_table.h +++ b/paddle/fluid/framework/fleet/heter_ps/graph_gpu_ps_table.h @@ -83,13 +83,15 @@ class GpuPsGraphTable : public HeterComm { // } } void build_graph_from_cpu(std::vector &cpu_node_list); - NodeQueryResult *graph_node_sample(int gpu_id, int sample_size); - NeighborSampleResult *graph_neighbor_sample(int gpu_id, int64_t *key, - int sample_size, int len); - NeighborSampleResult *graph_neighbor_sample_v2(int gpu_id, int64_t *key, - int sample_size, int len, - bool cpu_query_switch); - NodeQueryResult *query_node_list(int gpu_id, int start, int query_size); + NodeQueryResult graph_node_sample(int gpu_id, int sample_size); + NeighborSampleResult graph_neighbor_sample_v3(NeighborSampleQuery q, + bool cpu_switch); + NeighborSampleResult graph_neighbor_sample(int gpu_id, int64_t *key, + int sample_size, int len); + NeighborSampleResult graph_neighbor_sample_v2(int gpu_id, int64_t *key, + int sample_size, int len, + bool cpu_query_switch); + NodeQueryResult query_node_list(int gpu_id, int start, int query_size); void clear_graph_info(); void move_neighbor_sample_result_to_source_gpu(int gpu_id, int gpu_num, int sample_size, int *h_left, diff --git a/paddle/fluid/framework/fleet/heter_ps/graph_gpu_ps_table_inl.h b/paddle/fluid/framework/fleet/heter_ps/graph_gpu_ps_table_inl.h index b119724e695da6419497702750d8b09a7de29c1d..1c59f318517d0ded14336f2095335ad493592a8d 100644 --- a/paddle/fluid/framework/fleet/heter_ps/graph_gpu_ps_table_inl.h +++ b/paddle/fluid/framework/fleet/heter_ps/graph_gpu_ps_table_inl.h @@ -13,7 +13,7 @@ // limitations under the License. #include - +#include #pragma once #ifdef PADDLE_WITH_HETERPS //#include "paddle/fluid/framework/fleet/heter_ps/graph_gpu_ps_table.h" @@ -28,7 +28,6 @@ actual_size[0,len) is to save the sample size of each node. for ith node in index, actual_size[i] = min(node i's neighbor size, sample size) sample_result is to save the neighbor sampling result, its size is len * sample_size; - */ __global__ void get_cpu_id_index(int64_t* key, int* val, int64_t* cpu_key, @@ -198,7 +197,6 @@ int GpuPsGraphTable::init_cpu_table( // } /* comment 1 - gpu i triggers a neighbor_sample task, when this task is done, this function is called to move the sample result on other gpu back @@ -211,13 +209,11 @@ int GpuPsGraphTable::init_cpu_table( smaller than sample_size, is saved on src_sample_res [x*sample_size, x*sample_size + actual_sample_size[x]) - since before each gpu runs the neighbor_sample task,the key array is shuffled, but we have the idx array to save the original order. when the gpu i gets all the sample results from other gpus, it relies on idx array to recover the original order. that's what fill_dvals does. - */ void GpuPsGraphTable::move_neighbor_sample_result_to_source_gpu( @@ -404,10 +400,8 @@ void GpuPsGraphTable::clear_graph_info() { /* the parameter std::vector cpu_graph_list is generated by cpu. it saves the graph to be saved on each gpu. - for the ith GpuPsCommGraph, any the node's key satisfies that key % gpu_number == i - In this function, memory is allocated on each gpu to save the graphs, gpu i saves the ith graph from cpu_graph_list */ @@ -468,10 +462,15 @@ void GpuPsGraphTable::build_graph_from_cpu( cudaDeviceSynchronize(); } -NeighborSampleResult* GpuPsGraphTable::graph_neighbor_sample(int gpu_id, - int64_t* key, - int sample_size, - int len) { +NeighborSampleResult GpuPsGraphTable::graph_neighbor_sample_v3( + NeighborSampleQuery q, bool cpu_switch) { + return graph_neighbor_sample_v2(q.gpu_id, q.key, q.sample_size, q.len, + cpu_switch); +} +NeighborSampleResult GpuPsGraphTable::graph_neighbor_sample(int gpu_id, + int64_t* key, + int sample_size, + int len) { /* comment 2 this function shares some kernels with heter_comm_inl.h @@ -479,7 +478,6 @@ NeighborSampleResult* GpuPsGraphTable::graph_neighbor_sample(int gpu_id, gpu_id:the id of gpu. len:how many keys are used,(the length of array key) sample_size:how many neighbors should be sampled for each node in key. - the code below shuffle the key array to make the keys that belong to a gpu-card stay together, the shuffled result is saved on d_shard_keys, @@ -489,18 +487,16 @@ NeighborSampleResult* GpuPsGraphTable::graph_neighbor_sample(int gpu_id, if keys in range [a,b] belong to ith-gpu, then h_left[i] = a, h_right[i] = b, if no keys are allocated for ith-gpu, then h_left[i] == h_right[i] == -1 - for example, suppose key = [0,1,2,3,4,5,6,7,8], gpu_num = 2 when we run this neighbor_sample function, the key is shuffled to [0,2,4,6,8,1,3,5,7] the first part (0,2,4,6,8) % 2 == 0,thus should be handled by gpu 0, the rest part should be handled by gpu1, because (1,3,5,7) % 2 == 1, h_left = [0,5],h_right = [4,8] - */ - NeighborSampleResult* result = new NeighborSampleResult(); - result->initialize(sample_size, len, resource_->dev_id(gpu_id)); + NeighborSampleResult result; + result.initialize(sample_size, len, resource_->dev_id(gpu_id)); if (len == 0) { return result; } @@ -508,8 +504,8 @@ NeighborSampleResult* GpuPsGraphTable::graph_neighbor_sample(int gpu_id, platform::CUDADeviceGuard guard(resource_->dev_id(gpu_id)); // cudaMalloc((void**)&result->val, len * sample_size * sizeof(int64_t)); // cudaMalloc((void**)&result->actual_sample_size, len * sizeof(int)); - int* actual_sample_size = result->actual_sample_size; - int64_t* val = result->val; + int* actual_sample_size = result.actual_sample_size; + int64_t* val = result.val; int total_gpu = resource_->total_device(); // int dev_id = resource_->dev_id(gpu_id); auto stream = resource_->local_stream(gpu_id, 0); @@ -686,10 +682,10 @@ NeighborSampleResult* GpuPsGraphTable::graph_neighbor_sample(int gpu_id, return result; } -NeighborSampleResult* GpuPsGraphTable::graph_neighbor_sample_v2( +NeighborSampleResult GpuPsGraphTable::graph_neighbor_sample_v2( int gpu_id, int64_t* key, int sample_size, int len, bool cpu_query_switch) { - NeighborSampleResult* result = new NeighborSampleResult(); - result->initialize(sample_size, len, resource_->dev_id(gpu_id)); + NeighborSampleResult result; + result.initialize(sample_size, len, resource_->dev_id(gpu_id)); if (len == 0) { return result; @@ -697,8 +693,8 @@ NeighborSampleResult* GpuPsGraphTable::graph_neighbor_sample_v2( platform::CUDAPlace place = platform::CUDAPlace(resource_->dev_id(gpu_id)); platform::CUDADeviceGuard guard(resource_->dev_id(gpu_id)); - int* actual_sample_size = result->actual_sample_size; - int64_t* val = result->val; + int* actual_sample_size = result.actual_sample_size; + int64_t* val = result.val; int total_gpu = resource_->total_device(); auto stream = resource_->local_stream(gpu_id, 0); @@ -861,17 +857,19 @@ NeighborSampleResult* GpuPsGraphTable::graph_neighbor_sample_v2( return result; } -NodeQueryResult* GpuPsGraphTable::graph_node_sample(int gpu_id, - int sample_size) {} +NodeQueryResult GpuPsGraphTable::graph_node_sample(int gpu_id, + int sample_size) { + return NodeQueryResult(); +} -NodeQueryResult* GpuPsGraphTable::query_node_list(int gpu_id, int start, - int query_size) { - NodeQueryResult* result = new NodeQueryResult(); +NodeQueryResult GpuPsGraphTable::query_node_list(int gpu_id, int start, + int query_size) { + NodeQueryResult result; if (query_size <= 0) return result; - int& actual_size = result->actual_sample_size; + int& actual_size = result.actual_sample_size; actual_size = 0; - cudaMalloc((void**)&result->val, query_size * sizeof(int64_t)); - int64_t* val = result->val; + result.initialize(query_size, resource_->dev_id(gpu_id)); + int64_t* val = result.val; // int dev_id = resource_->dev_id(gpu_id); // platform::CUDADeviceGuard guard(dev_id); platform::CUDADeviceGuard guard(resource_->dev_id(gpu_id)); @@ -883,7 +881,6 @@ NodeQueryResult* GpuPsGraphTable::query_node_list(int gpu_id, int start, sample_size[i] = s; then on gpu a, the nodes of positions [p1,p1 + s) should be returned and saved from the p2 position on the sample_result array - for example: suppose gpu 0 saves [0,2,4,6,8], gpu1 saves [1,3,5,7] @@ -893,23 +890,29 @@ NodeQueryResult* GpuPsGraphTable::query_node_list(int gpu_id, int start, gpu_begin_pos = [3,0] local_begin_pos = [0,3] sample_size = [2,3] - */ + std::function range_check = []( + int x, int y, int x1, int y1, int& x2, int& y2) { + if (y <= x1 || x >= y1) return 0; + y2 = min(y, y1); + x2 = max(x1, x); + return y2 - x2; + }; for (int i = 0; i < gpu_graph_list.size() && query_size != 0; i++) { auto graph = gpu_graph_list[i]; if (graph.node_size == 0) { continue; } - if (graph.node_size + size > start) { - int cur_size = min(query_size, graph.node_size + size - start); - query_size -= cur_size; - idx.emplace_back(i); - gpu_begin_pos.emplace_back(start - size); + int x2, y2; + int len = range_check(start, start + query_size, size, + size + graph.node_size, x2, y2); + if (len > 0) { + idx.push_back(i); + gpu_begin_pos.emplace_back(x2 - size); local_begin_pos.emplace_back(actual_size); - start += cur_size; - actual_size += cur_size; - sample_size.emplace_back(cur_size); - create_storage(gpu_id, i, 1, cur_size * sizeof(int64_t)); + sample_size.push_back(len); + actual_size += len; + create_storage(gpu_id, i, 1, len * sizeof(int64_t)); } size += graph.node_size; } @@ -936,6 +939,9 @@ NodeQueryResult* GpuPsGraphTable::query_node_list(int gpu_id, int start, auto& node = path_[gpu_id][idx[i]].nodes_.front(); cudaStreamSynchronize(node.out_stream); } + for (auto x : idx) { + destroy_storage(gpu_id, x); + } return result; } } diff --git a/paddle/fluid/framework/fleet/heter_ps/graph_gpu_wrapper.cu b/paddle/fluid/framework/fleet/heter_ps/graph_gpu_wrapper.cu index 2f099d09397d5aeb0b81a223ffdb86d4bdb99a8d..e99a0f4fe11c173f40f2764f49a5cbab695ea476 100644 --- a/paddle/fluid/framework/fleet/heter_ps/graph_gpu_wrapper.cu +++ b/paddle/fluid/framework/fleet/heter_ps/graph_gpu_wrapper.cu @@ -158,14 +158,16 @@ void GraphGpuWrapper::init_service() { graph_table = (char *)g; } -void GraphGpuWrapper::upload_batch(std::vector> &ids) { +void GraphGpuWrapper::upload_batch(int idx, + std::vector> &ids) { GpuPsGraphTable *g = (GpuPsGraphTable *)graph_table; std::vector vec; for (int i = 0; i < ids.size(); i++) { - vec.push_back(g->cpu_graph_table->make_gpu_ps_graph(0, ids[i])); + vec.push_back(g->cpu_graph_table->make_gpu_ps_graph(idx, ids[i])); } g->build_graph_from_cpu(vec); } + void GraphGpuWrapper::initialize() { std::vector device_id_mapping; for (int i = 0; i < 2; i++) device_id_mapping.push_back(i); @@ -238,10 +240,10 @@ void GraphGpuWrapper::test() { ((GpuPsGraphTable *)graph_table) ->graph_neighbor_sample(0, (int64_t *)key, 2, 3); int64_t *res = new int64_t[7]; - cudaMemcpy(res, neighbor_sample_res->val, 3 * 2 * sizeof(int64_t), + cudaMemcpy(res, neighbor_sample_res.val, 3 * 2 * sizeof(int64_t), cudaMemcpyDeviceToHost); int *actual_sample_size = new int[3]; - cudaMemcpy(actual_sample_size, neighbor_sample_res->actual_sample_size, + cudaMemcpy(actual_sample_size, neighbor_sample_res.actual_sample_size, 3 * sizeof(int), cudaMemcpyDeviceToHost); // 3, 1, 3 @@ -256,12 +258,60 @@ void GraphGpuWrapper::test() { } } } -NeighborSampleResult *GraphGpuWrapper::graph_neighbor_sample(int gpu_id, - int64_t *key, - int sample_size, - int len) { +NeighborSampleResult GraphGpuWrapper::graph_neighbor_sample_v3( + NeighborSampleQuery q, bool cpu_switch) { + return ((GpuPsGraphTable *)graph_table) + ->graph_neighbor_sample_v3(q, cpu_switch); +} + +// this function is contributed by Liwb5 +std::vector GraphGpuWrapper::graph_neighbor_sample( + int gpu_id, std::vector &key, int sample_size) { + int64_t *cuda_key; + platform::CUDADeviceGuard guard(gpu_id); + + cudaMalloc(&cuda_key, key.size() * sizeof(int64_t)); + cudaMemcpy(cuda_key, key.data(), key.size() * sizeof(int64_t), + cudaMemcpyHostToDevice); + + auto neighbor_sample_res = + ((GpuPsGraphTable *)graph_table) + ->graph_neighbor_sample(gpu_id, cuda_key, sample_size, key.size()); + + int *actual_sample_size = new int[key.size()]; + cudaMemcpy(actual_sample_size, neighbor_sample_res.actual_sample_size, + key.size() * sizeof(int), + cudaMemcpyDeviceToHost); // 3, 1, 3 + int cumsum = 0; + for (int i = 0; i < key.size(); i++) { + cumsum += actual_sample_size[i]; + } + /* VLOG(0) << "cumsum " << cumsum; */ + + std::vector res; + res.resize(cumsum * 2); + int count = 0; + for (int i = 0; i < key.size(); i++) { + for (int j = 0; j < actual_sample_size[i]; j++) { + res[count] = key[i]; + count += 1; + } + } + + cudaMemcpy(res.data() + cumsum, neighbor_sample_res.val, + cumsum * sizeof(int64_t), cudaMemcpyDeviceToHost); + /* for(int i = 0;i < res.size();i ++) { */ + /* VLOG(0) << i << " " << res[i]; */ + /* } */ + + cudaFree(cuda_key); + return res; +} + +NodeQueryResult GraphGpuWrapper::query_node_list(int gpu_id, int start, + int query_size) { return ((GpuPsGraphTable *)graph_table) - ->graph_neighbor_sample(gpu_id, key, sample_size, len); + ->query_node_list(gpu_id, start, query_size); } #endif } diff --git a/paddle/fluid/framework/fleet/heter_ps/graph_gpu_wrapper.h b/paddle/fluid/framework/fleet/heter_ps/graph_gpu_wrapper.h index 26ce4c8adce2108db2a760cb52410c949d30c4cb..6972551b896edab4445ff9b8d783e8f9dbd913db 100644 --- a/paddle/fluid/framework/fleet/heter_ps/graph_gpu_wrapper.h +++ b/paddle/fluid/framework/fleet/heter_ps/graph_gpu_wrapper.h @@ -29,13 +29,17 @@ class GraphGpuWrapper { void init_service(); void set_up_types(std::vector& edge_type, std::vector& node_type); - void upload_batch(std::vector>& ids); + void upload_batch(int idx, std::vector>& ids); void add_table_feat_conf(std::string table_name, std::string feat_name, std::string feat_dtype, int feat_shape); void load_edge_file(std::string name, std::string filepath, bool reverse); void load_node_file(std::string name, std::string filepath); - NeighborSampleResult* graph_neighbor_sample(int gpu_id, int64_t* key, - int sample_size, int len); + NodeQueryResult query_node_list(int gpu_id, int start, int query_size); + NeighborSampleResult graph_neighbor_sample_v3(NeighborSampleQuery q, + bool cpu_switch); + std::vector graph_neighbor_sample(int gpu_id, + std::vector& key, + int sample_size); std::unordered_map edge_to_id, feature_to_id; std::vector id_to_feature, id_to_edge; std::vector> table_feat_mapping; diff --git a/paddle/fluid/framework/fleet/heter_ps/test_cpu_query.cu b/paddle/fluid/framework/fleet/heter_ps/test_cpu_query.cu index 2e94a7f4059abc5a805368dcf64d2695a81933d4..f35a1c41bbe1d0903a1d5dfe7ee5e4e3cdc95f1f 100644 --- a/paddle/fluid/framework/fleet/heter_ps/test_cpu_query.cu +++ b/paddle/fluid/framework/fleet/heter_ps/test_cpu_query.cu @@ -139,23 +139,17 @@ TEST(TEST_FLEET, test_cpu_cache) { platform::CUDADeviceGuard guard(0); cudaMalloc((void **)&key, 3 * sizeof(int64_t)); cudaMemcpy(key, cpu_key, 3 * sizeof(int64_t), cudaMemcpyHostToDevice); - auto neighbor_sample_res = g.graph_neighbor_sample(0, (int64_t *)key, 2, 3); - int64_t *res = new int64_t[7]; - cudaMemcpy(res, neighbor_sample_res->val, 3 * 2 * sizeof(int64_t), - cudaMemcpyDeviceToHost); - int *actual_sample_size = new int[3]; - cudaMemcpy(actual_sample_size, neighbor_sample_res->actual_sample_size, - 3 * sizeof(int), - cudaMemcpyDeviceToHost); // 3, 1, 3 - - //{0,9} or {9,0} is expected for key 0 + auto neighbor_sample_res = + g.graph_neighbor_sample_v2(0, (int64_t *)key, 2, 3, true); + neighbor_sample_res.display(); + //{1,9} or {9,1} is expected for key 0 //{0,2} or {2,0} is expected for key 1 //{1,3} or {3,1} is expected for key 2 - for (int i = 0; i < 3; i++) { - VLOG(0) << "actual sample size for " << i << " is " - << actual_sample_size[i]; - for (int j = 0; j < actual_sample_size[i]; j++) { - VLOG(0) << "sampled an neighbor for node" << i << " : " << res[i * 2 + j]; - } - } + auto node_query_res = g.query_node_list(0, 0, 4); + node_query_res.display(); + NeighborSampleQuery query; + query.initialize(0, node_query_res.get_val(), 2, node_query_res.get_len()); + query.display(); + auto c = g.graph_neighbor_sample_v3(query, false); + c.display(); } diff --git a/paddle/fluid/pybind/fleet_py.cc b/paddle/fluid/pybind/fleet_py.cc index 00ceaf252dc8e4d85fd18942d823778b4fa4e6aa..4df43dc1a3a52c74a3b80d862f6c7764978ce1c9 100644 --- a/paddle/fluid/pybind/fleet_py.cc +++ b/paddle/fluid/pybind/fleet_py.cc @@ -1,11 +1,8 @@ /* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. - Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at - http://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. @@ -255,6 +252,8 @@ using paddle::distributed::IndexNode; #ifdef PADDLE_WITH_HETERPS using paddle::framework::GraphGpuWrapper; using paddle::framework::NeighborSampleResult; +using paddle::framework::NeighborSampleQuery; +using paddle::framework::NodeQueryResult; #endif void BindIndexNode(py::module* m) { @@ -307,21 +306,39 @@ void BindIndexWrapper(py::module* m) { } #ifdef PADDLE_WITH_HETERPS +void BindNodeQueryResult(py::module* m) { + py::class_(*m, "NodeQueryResult") + .def(py::init<>()) + .def("initialize", &NodeQueryResult::initialize) + .def("display", &NodeQueryResult::display) + .def("get_val", &NodeQueryResult::get_val) + .def("get_len", &NodeQueryResult::get_len); +} +void BindNeighborSampleQuery(py::module* m) { + py::class_(*m, "NeighborSampleQuery") + .def(py::init<>()) + .def("initialize", &NeighborSampleQuery::initialize) + .def("display", &NeighborSampleQuery::display); +} + void BindNeighborSampleResult(py::module* m) { py::class_(*m, "NeighborSampleResult") .def(py::init<>()) - .def("initialize", &NeighborSampleResult::initialize); + .def("initialize", &NeighborSampleResult::initialize) + .def("display", &NeighborSampleResult::display); } void BindGraphGpuWrapper(py::module* m) { py::class_(*m, "GraphGpuWrapper") .def(py::init<>()) - .def("test", &GraphGpuWrapper::test) + //.def("test", &GraphGpuWrapper::test) .def("initialize", &GraphGpuWrapper::initialize) + .def("neighbor_sample", &GraphGpuWrapper::graph_neighbor_sample_v3) .def("graph_neighbor_sample", &GraphGpuWrapper::graph_neighbor_sample) .def("set_device", &GraphGpuWrapper::set_device) .def("init_service", &GraphGpuWrapper::init_service) .def("set_up_types", &GraphGpuWrapper::set_up_types) + .def("query_node_list", &GraphGpuWrapper::query_node_list) .def("add_table_feat_conf", &GraphGpuWrapper::add_table_feat_conf) .def("load_edge_file", &GraphGpuWrapper::load_edge_file) .def("upload_batch", &GraphGpuWrapper::upload_batch) diff --git a/paddle/fluid/pybind/fleet_py.h b/paddle/fluid/pybind/fleet_py.h index 81ed25913ba1a9eb965deff6d0f053a1f15ca236..a47aec749bda56f8422712230d59d52e5e0fd1f5 100644 --- a/paddle/fluid/pybind/fleet_py.h +++ b/paddle/fluid/pybind/fleet_py.h @@ -39,6 +39,8 @@ void BindIndexSampler(py::module* m); #ifdef PADDLE_WITH_HETERPS void BindNeighborSampleResult(py::module* m); void BindGraphGpuWrapper(py::module* m); +void BindNodeQueryResult(py::module* m); +void BindNeighborSampleQuery(py::module* m); #endif } // namespace pybind } // namespace paddle diff --git a/paddle/fluid/pybind/pybind.cc b/paddle/fluid/pybind/pybind.cc index d5ee0c2a47b00224a0eecbbd8a229ed3b9327afa..843083fa0ad48e404ae0c3ffb665a4f5ca575f19 100644 --- a/paddle/fluid/pybind/pybind.cc +++ b/paddle/fluid/pybind/pybind.cc @@ -4578,6 +4578,8 @@ All parameter, weight, gradient are variables in Paddle. BindIndexWrapper(&m); BindIndexSampler(&m); #ifdef PADDLE_WITH_HETERPS + BindNodeQueryResult(&m); + BindNeighborSampleQuery(&m); BindNeighborSampleResult(&m); BindGraphGpuWrapper(&m); #endif