diff --git a/paddle/fluid/distributed/service/graph_brpc_client.cc b/paddle/fluid/distributed/service/graph_brpc_client.cc index 9f65a66708def030f7dfd9a9ac668c79ad60744f..13132740bb1dc708925f4cef11c54f802547e40d 100644 --- a/paddle/fluid/distributed/service/graph_brpc_client.cc +++ b/paddle/fluid/distributed/service/graph_brpc_client.cc @@ -302,7 +302,7 @@ std::future GraphBrpcClient::remove_graph_node( return fut; } // char* &buffer,int &actual_size -std::future GraphBrpcClient::batch_sample_neighboors( +std::future GraphBrpcClient::batch_sample_neighbors( uint32_t table_id, std::vector node_ids, int sample_size, std::vector>> &res, int server_index) { @@ -390,8 +390,8 @@ std::future GraphBrpcClient::batch_sample_neighboors( size_t fail_num = 0; for (size_t request_idx = 0; request_idx < request_call_num; ++request_idx) { - if (closure->check_response(request_idx, - PS_GRAPH_SAMPLE_NEIGHBOORS) != 0) { + if (closure->check_response(request_idx, PS_GRAPH_SAMPLE_NEIGHBORS) != + 0) { ++fail_num; } else { auto &res_io_buffer = @@ -435,7 +435,7 @@ std::future GraphBrpcClient::batch_sample_neighboors( for (int request_idx = 0; request_idx < request_call_num; ++request_idx) { int server_index = request2server[request_idx]; - closure->request(request_idx)->set_cmd_id(PS_GRAPH_SAMPLE_NEIGHBOORS); + closure->request(request_idx)->set_cmd_id(PS_GRAPH_SAMPLE_NEIGHBORS); closure->request(request_idx)->set_table_id(table_id); closure->request(request_idx)->set_client_id(_client_id); size_t node_num = node_id_buckets[request_idx].size(); @@ -494,6 +494,47 @@ std::future GraphBrpcClient::random_sample_nodes( closure); return fut; } + +std::future GraphBrpcClient::use_neighbors_sample_cache( + uint32_t table_id, size_t total_size_limit, size_t ttl) { + DownpourBrpcClosure *closure = new DownpourBrpcClosure( + server_size, [&, server_size = this->server_size ](void *done) { + int ret = 0; + auto *closure = (DownpourBrpcClosure *)done; + size_t fail_num = 0; + for (size_t request_idx = 0; request_idx < server_size; ++request_idx) { + if (closure->check_response( + request_idx, PS_GRAPH_USE_NEIGHBORS_SAMPLE_CACHE) != 0) { + ++fail_num; + break; + } + } + ret = fail_num == 0 ? 0 : -1; + closure->set_promise_value(ret); + }); + auto promise = std::make_shared>(); + closure->add_promise(promise); + size_t size_limit = total_size_limit / server_size + + (total_size_limit % server_size != 0 ? 1 : 0); + std::future fut = promise->get_future(); + for (size_t i = 0; i < server_size; i++) { + int server_index = i; + closure->request(server_index) + ->set_cmd_id(PS_GRAPH_USE_NEIGHBORS_SAMPLE_CACHE); + closure->request(server_index)->set_table_id(table_id); + closure->request(server_index)->set_client_id(_client_id); + closure->request(server_index) + ->add_params((char *)&size_limit, sizeof(size_t)); + closure->request(server_index)->add_params((char *)&ttl, sizeof(size_t)); + GraphPsService_Stub rpc_stub = + getServiceStub(get_cmd_channel(server_index)); + closure->cntl(server_index)->set_log_id(butil::gettimeofday_ms()); + rpc_stub.service(closure->cntl(server_index), + closure->request(server_index), + closure->response(server_index), closure); + } + return fut; +} std::future GraphBrpcClient::pull_graph_list( uint32_t table_id, int server_index, int start, int size, int step, std::vector &res) { @@ -515,7 +556,7 @@ std::future GraphBrpcClient::pull_graph_list( index += node.get_size(false); res.push_back(node); } - delete buffer; + delete[] buffer; } closure->set_promise_value(ret); }); diff --git a/paddle/fluid/distributed/service/graph_brpc_client.h b/paddle/fluid/distributed/service/graph_brpc_client.h index 1fbb3fa9b0550e2e658c98b4ca41acd3e55a440a..c1083afb71abfb4432b749ce241630fbcff53782 100644 --- a/paddle/fluid/distributed/service/graph_brpc_client.h +++ b/paddle/fluid/distributed/service/graph_brpc_client.h @@ -61,8 +61,8 @@ class GraphBrpcClient : public BrpcPsClient { public: GraphBrpcClient() {} virtual ~GraphBrpcClient() {} - // given a batch of nodes, sample graph_neighboors for each of them - virtual std::future batch_sample_neighboors( + // given a batch of nodes, sample graph_neighbors for each of them + virtual std::future batch_sample_neighbors( uint32_t table_id, std::vector node_ids, int sample_size, std::vector>>& res, int server_index = -1); @@ -89,6 +89,9 @@ class GraphBrpcClient : public BrpcPsClient { virtual std::future add_graph_node( uint32_t table_id, std::vector& node_id_list, std::vector& is_weighted_list); + virtual std::future use_neighbors_sample_cache(uint32_t table_id, + size_t size_limit, + size_t ttl); virtual std::future remove_graph_node( uint32_t table_id, std::vector& node_id_list); virtual int32_t initialize(); diff --git a/paddle/fluid/distributed/service/graph_brpc_server.cc b/paddle/fluid/distributed/service/graph_brpc_server.cc index 424cf281bf3974cfa5bea791d6e65e18290ee634..0aba2b9f44ae7cdd14cfd0bd2eb9cd3fafa1e783 100644 --- a/paddle/fluid/distributed/service/graph_brpc_server.cc +++ b/paddle/fluid/distributed/service/graph_brpc_server.cc @@ -187,8 +187,8 @@ int32_t GraphBrpcService::initialize() { _service_handler_map[PS_STOP_PROFILER] = &GraphBrpcService::stop_profiler; _service_handler_map[PS_PULL_GRAPH_LIST] = &GraphBrpcService::pull_graph_list; - _service_handler_map[PS_GRAPH_SAMPLE_NEIGHBOORS] = - &GraphBrpcService::graph_random_sample_neighboors; + _service_handler_map[PS_GRAPH_SAMPLE_NEIGHBORS] = + &GraphBrpcService::graph_random_sample_neighbors; _service_handler_map[PS_GRAPH_SAMPLE_NODES] = &GraphBrpcService::graph_random_sample_nodes; _service_handler_map[PS_GRAPH_GET_NODE_FEAT] = @@ -201,8 +201,9 @@ int32_t GraphBrpcService::initialize() { _service_handler_map[PS_GRAPH_SET_NODE_FEAT] = &GraphBrpcService::graph_set_node_feat; _service_handler_map[PS_GRAPH_SAMPLE_NODES_FROM_ONE_SERVER] = - &GraphBrpcService::sample_neighboors_across_multi_servers; - + &GraphBrpcService::sample_neighbors_across_multi_servers; + _service_handler_map[PS_GRAPH_USE_NEIGHBORS_SAMPLE_CACHE] = + &GraphBrpcService::use_neighbors_sample_cache; // shard初始化,server启动后才可从env获取到server_list的shard信息 initialize_shard_info(); @@ -373,7 +374,7 @@ int32_t GraphBrpcService::pull_graph_list(Table *table, cntl->response_attachment().append(buffer.get(), actual_size); return 0; } -int32_t GraphBrpcService::graph_random_sample_neighboors( +int32_t GraphBrpcService::graph_random_sample_neighbors( Table *table, const PsRequestMessage &request, PsResponseMessage &response, brpc::Controller *cntl) { CHECK_TABLE_EXIST(table, request, response) @@ -389,7 +390,7 @@ int32_t GraphBrpcService::graph_random_sample_neighboors( std::vector> buffers(node_num); std::vector actual_sizes(node_num, 0); ((GraphTable *)table) - ->random_sample_neighboors(node_data, sample_size, buffers, actual_sizes); + ->random_sample_neighbors(node_data, sample_size, buffers, actual_sizes); cntl->response_attachment().append(&node_num, sizeof(size_t)); cntl->response_attachment().append(actual_sizes.data(), @@ -448,7 +449,7 @@ int32_t GraphBrpcService::graph_get_node_feat(Table *table, return 0; } -int32_t GraphBrpcService::sample_neighboors_across_multi_servers( +int32_t GraphBrpcService::sample_neighbors_across_multi_servers( Table *table, const PsRequestMessage &request, PsResponseMessage &response, brpc::Controller *cntl) { // sleep(5); @@ -456,7 +457,7 @@ int32_t GraphBrpcService::sample_neighboors_across_multi_servers( if (request.params_size() < 2) { set_response_code( response, -1, - "graph_random_sample request requires at least 2 arguments"); + "graph_random_neighbors_sample request requires at least 2 arguments"); return 0; } size_t node_num = request.params(0).size() / sizeof(uint64_t), @@ -519,7 +520,7 @@ int32_t GraphBrpcService::sample_neighboors_across_multi_servers( remote_call_num); size_t fail_num = 0; for (size_t request_idx = 0; request_idx < remote_call_num; ++request_idx) { - if (closure->check_response(request_idx, PS_GRAPH_SAMPLE_NEIGHBOORS) != + if (closure->check_response(request_idx, PS_GRAPH_SAMPLE_NEIGHBORS) != 0) { ++fail_num; failed[request2server[request_idx]] = true; @@ -570,7 +571,7 @@ int32_t GraphBrpcService::sample_neighboors_across_multi_servers( for (int request_idx = 0; request_idx < remote_call_num; ++request_idx) { int server_index = request2server[request_idx]; - closure->request(request_idx)->set_cmd_id(PS_GRAPH_SAMPLE_NEIGHBOORS); + closure->request(request_idx)->set_cmd_id(PS_GRAPH_SAMPLE_NEIGHBORS); closure->request(request_idx)->set_table_id(request.table_id()); closure->request(request_idx)->set_client_id(rank); size_t node_num = node_id_buckets[request_idx].size(); @@ -590,8 +591,8 @@ int32_t GraphBrpcService::sample_neighboors_across_multi_servers( } if (server2request[rank] != -1) { ((GraphTable *)table) - ->random_sample_neighboors(node_id_buckets.back().data(), sample_size, - local_buffers, local_actual_sizes); + ->random_sample_neighbors(node_id_buckets.back().data(), sample_size, + local_buffers, local_actual_sizes); } local_promise.get()->set_value(0); if (remote_call_num == 0) func(closure); @@ -636,5 +637,20 @@ int32_t GraphBrpcService::graph_set_node_feat(Table *table, return 0; } +int32_t GraphBrpcService::use_neighbors_sample_cache( + Table *table, const PsRequestMessage &request, PsResponseMessage &response, + brpc::Controller *cntl) { + CHECK_TABLE_EXIST(table, request, response) + if (request.params_size() < 2) { + set_response_code(response, -1, + "use_neighbors_sample_cache request requires at least 2 " + "arguments[cache_size, ttl]"); + return 0; + } + size_t size_limit = *(size_t *)(request.params(0).c_str()); + size_t ttl = *(size_t *)(request.params(1).c_str()); + ((GraphTable *)table)->make_neighbor_sample_cache(size_limit, ttl); + return 0; +} } // namespace distributed } // namespace paddle diff --git a/paddle/fluid/distributed/service/graph_brpc_server.h b/paddle/fluid/distributed/service/graph_brpc_server.h index 817fe08331165daf8953154f673eca14ea9d2e72..d1a6aa63604f36b3b3715a4293568a53c410bda7 100644 --- a/paddle/fluid/distributed/service/graph_brpc_server.h +++ b/paddle/fluid/distributed/service/graph_brpc_server.h @@ -78,10 +78,10 @@ class GraphBrpcService : public PsBaseService { int32_t initialize_shard_info(); int32_t pull_graph_list(Table *table, const PsRequestMessage &request, PsResponseMessage &response, brpc::Controller *cntl); - int32_t graph_random_sample_neighboors(Table *table, - const PsRequestMessage &request, - PsResponseMessage &response, - brpc::Controller *cntl); + int32_t graph_random_sample_neighbors(Table *table, + const PsRequestMessage &request, + PsResponseMessage &response, + brpc::Controller *cntl); int32_t graph_random_sample_nodes(Table *table, const PsRequestMessage &request, PsResponseMessage &response, @@ -116,9 +116,15 @@ class GraphBrpcService : public PsBaseService { int32_t print_table_stat(Table *table, const PsRequestMessage &request, PsResponseMessage &response, brpc::Controller *cntl); - int32_t sample_neighboors_across_multi_servers( - Table *table, const PsRequestMessage &request, - PsResponseMessage &response, brpc::Controller *cntl); + int32_t sample_neighbors_across_multi_servers(Table *table, + const PsRequestMessage &request, + PsResponseMessage &response, + brpc::Controller *cntl); + + int32_t use_neighbors_sample_cache(Table *table, + const PsRequestMessage &request, + PsResponseMessage &response, + brpc::Controller *cntl); private: bool _is_initialize_shard_info; diff --git a/paddle/fluid/distributed/service/graph_py_service.cc b/paddle/fluid/distributed/service/graph_py_service.cc index 498805136417f2f6520b6952465c0d90204efb0a..78f239f80d44599513066ba8bc985137a42097fc 100644 --- a/paddle/fluid/distributed/service/graph_py_service.cc +++ b/paddle/fluid/distributed/service/graph_py_service.cc @@ -290,19 +290,29 @@ void GraphPyClient::load_node_file(std::string name, std::string filepath) { } } std::vector>> -GraphPyClient::batch_sample_neighboors(std::string name, - std::vector node_ids, - int sample_size) { +GraphPyClient::batch_sample_neighbors(std::string name, + std::vector node_ids, + int sample_size) { std::vector>> v; if (this->table_id_map.count(name)) { uint32_t table_id = this->table_id_map[name]; auto status = - worker_ptr->batch_sample_neighboors(table_id, node_ids, sample_size, v); + worker_ptr->batch_sample_neighbors(table_id, node_ids, sample_size, v); status.wait(); } return v; } +void GraphPyClient::use_neighbors_sample_cache(std::string name, + size_t total_size_limit, + size_t ttl) { + if (this->table_id_map.count(name)) { + uint32_t table_id = this->table_id_map[name]; + auto status = + worker_ptr->use_neighbors_sample_cache(table_id, total_size_limit, ttl); + status.wait(); + } +} std::vector GraphPyClient::random_sample_nodes(std::string name, int server_index, int sample_size) { diff --git a/paddle/fluid/distributed/service/graph_py_service.h b/paddle/fluid/distributed/service/graph_py_service.h index 8e03938801ce99e0fbbcafbe52174f8db86f8183..2d36edbf9c17d91bac742272c464fa2d2a39efe3 100644 --- a/paddle/fluid/distributed/service/graph_py_service.h +++ b/paddle/fluid/distributed/service/graph_py_service.h @@ -148,13 +148,15 @@ class GraphPyClient : public GraphPyService { int get_client_id() { return client_id; } void set_client_id(int client_id) { this->client_id = client_id; } void start_client(); - std::vector>> batch_sample_neighboors( + std::vector>> batch_sample_neighbors( std::string name, std::vector node_ids, int sample_size); std::vector random_sample_nodes(std::string name, int server_index, int sample_size); std::vector> get_node_feat( std::string node_type, std::vector node_ids, std::vector feature_names); + void use_neighbors_sample_cache(std::string name, size_t total_size_limit, + size_t ttl); void set_node_feat(std::string node_type, std::vector node_ids, std::vector feature_names, const std::vector> features); diff --git a/paddle/fluid/distributed/service/sendrecv.proto b/paddle/fluid/distributed/service/sendrecv.proto index 42e25258ec3fe1eb8061129576e751284a9fa30d..8ee9b3590721a06e63b0059b8b1b5fea3a02ee75 100644 --- a/paddle/fluid/distributed/service/sendrecv.proto +++ b/paddle/fluid/distributed/service/sendrecv.proto @@ -49,7 +49,7 @@ enum PsCmdID { PS_STOP_PROFILER = 28; PS_PUSH_GLOBAL_STEP = 29; PS_PULL_GRAPH_LIST = 30; - PS_GRAPH_SAMPLE_NEIGHBOORS = 31; + PS_GRAPH_SAMPLE_NEIGHBORS = 31; PS_GRAPH_SAMPLE_NODES = 32; PS_GRAPH_GET_NODE_FEAT = 33; PS_GRAPH_CLEAR = 34; @@ -57,6 +57,7 @@ enum PsCmdID { PS_GRAPH_REMOVE_GRAPH_NODE = 36; PS_GRAPH_SET_NODE_FEAT = 37; PS_GRAPH_SAMPLE_NODES_FROM_ONE_SERVER = 38; + PS_GRAPH_USE_NEIGHBORS_SAMPLE_CACHE = 39; } message PsRequestMessage { diff --git a/paddle/fluid/distributed/table/common_graph_table.cc b/paddle/fluid/distributed/table/common_graph_table.cc index 47b966182e6825c5032817bb37fa28d56f71a598..96ebf039aae773962d19c9b6b538cba802d81ae9 100644 --- a/paddle/fluid/distributed/table/common_graph_table.cc +++ b/paddle/fluid/distributed/table/common_graph_table.cc @@ -392,7 +392,7 @@ int32_t GraphTable::random_sample_nodes(int sample_size, memcpy(pointer, res.data(), actual_size); return 0; } -int32_t GraphTable::random_sample_neighboors( +int32_t GraphTable::random_sample_neighbors( uint64_t *node_ids, int sample_size, std::vector> &buffers, std::vector &actual_sizes) { diff --git a/paddle/fluid/distributed/table/common_graph_table.h b/paddle/fluid/distributed/table/common_graph_table.h index 0e2d09effeb48e94f9ef708b1716ae1c4752e36e..91f2b1c029d80505a6336aaf774a644191947523 100644 --- a/paddle/fluid/distributed/table/common_graph_table.h +++ b/paddle/fluid/distributed/table/common_graph_table.h @@ -89,12 +89,6 @@ struct SampleKey { } }; -struct SampleKeyHash { - size_t operator()(const SampleKey &s) const { - return s.node_key ^ s.sample_size; - } -}; - class SampleResult { public: size_t actual_size; @@ -121,13 +115,13 @@ class LRUNode { // time to live LRUNode *pre, *next; }; -template > +template class ScaledLRU; -template > +template class RandomSampleLRU { public: - RandomSampleLRU(ScaledLRU *_father) : father(_father) { + RandomSampleLRU(ScaledLRU *_father) : father(_father) { node_size = 0; node_head = node_end = NULL; global_ttl = father->ttl; @@ -229,15 +223,15 @@ class RandomSampleLRU { } private: - std::unordered_map *, Hash> key_map; - ScaledLRU *father; + std::unordered_map *> key_map; + ScaledLRU *father; size_t global_ttl; int node_size; LRUNode *node_head, *node_end; - friend class ScaledLRU; + friend class ScaledLRU; }; -template +template class ScaledLRU { public: ScaledLRU(size_t shard_num, size_t size_limit, size_t _ttl) @@ -246,8 +240,8 @@ class ScaledLRU { stop = false; thread_pool.reset(new ::ThreadPool(1)); global_count = 0; - lru_pool = std::vector>( - shard_num, RandomSampleLRU(this)); + lru_pool = std::vector>(shard_num, + RandomSampleLRU(this)); shrink_job = std::thread([this]() -> void { while (true) { { @@ -352,16 +346,16 @@ class ScaledLRU { size_t ttl; bool stop; std::thread shrink_job; - std::vector> lru_pool; + std::vector> lru_pool; mutable std::mutex mutex_; std::condition_variable cv_; struct RemovedNode { LRUNode *node; - RandomSampleLRU *lru_pointer; + RandomSampleLRU *lru_pointer; bool operator>(const RemovedNode &a) const { return node->ms > a.node->ms; } }; std::shared_ptr<::ThreadPool> thread_pool; - friend class RandomSampleLRU; + friend class RandomSampleLRU; }; class GraphTable : public SparseTable { @@ -373,7 +367,7 @@ class GraphTable : public SparseTable { int &actual_size, bool need_feature, int step); - virtual int32_t random_sample_neighboors( + virtual int32_t random_sample_neighbors( uint64_t *node_ids, int sample_size, std::vector> &buffers, std::vector &actual_sizes); @@ -433,11 +427,11 @@ class GraphTable : public SparseTable { size_t get_server_num() { return server_num; } - virtual int32_t make_neigh_sample_cache(size_t size_limit, size_t ttl) { + virtual int32_t make_neighbor_sample_cache(size_t size_limit, size_t ttl) { { std::unique_lock lock(mutex_); if (use_cache == false) { - scaled_lru.reset(new ScaledLRU( + scaled_lru.reset(new ScaledLRU( shard_end - shard_start, size_limit, ttl)); use_cache = true; } @@ -460,10 +454,20 @@ class GraphTable : public SparseTable { std::vector> _shards_task_pool; std::vector> _shards_task_rng_pool; - std::shared_ptr> scaled_lru; + std::shared_ptr> scaled_lru; bool use_cache; mutable std::mutex mutex_; }; } // namespace distributed }; // namespace paddle + +namespace std { + +template <> +struct hash { + size_t operator()(const paddle::distributed::SampleKey &s) const { + return s.node_key ^ s.sample_size; + } +}; +} diff --git a/paddle/fluid/distributed/test/graph_node_test.cc b/paddle/fluid/distributed/test/graph_node_test.cc index 47dc72125756fed1ba69f49923b0c4fb76a11f47..c061fe0bb909d88eb28db3f537e9843d24266575 100644 --- a/paddle/fluid/distributed/test/graph_node_test.cc +++ b/paddle/fluid/distributed/test/graph_node_test.cc @@ -111,7 +111,7 @@ void testFeatureNodeSerializeFloat64() { void testSingleSampleNeighboor( std::shared_ptr& worker_ptr_) { std::vector>> vs; - auto pull_status = worker_ptr_->batch_sample_neighboors( + auto pull_status = worker_ptr_->batch_sample_neighbors( 0, std::vector(1, 37), 4, vs); pull_status.wait(); @@ -127,7 +127,7 @@ void testSingleSampleNeighboor( s.clear(); s1.clear(); vs.clear(); - pull_status = worker_ptr_->batch_sample_neighboors( + pull_status = worker_ptr_->batch_sample_neighbors( 0, std::vector(1, 96), 4, vs); pull_status.wait(); s1 = {111, 48, 247}; @@ -139,7 +139,7 @@ void testSingleSampleNeighboor( ASSERT_EQ(true, s1.find(g) != s1.end()); } vs.clear(); - pull_status = worker_ptr_->batch_sample_neighboors(0, {96, 37}, 4, vs, 0); + pull_status = worker_ptr_->batch_sample_neighbors(0, {96, 37}, 4, vs, 0); pull_status.wait(); ASSERT_EQ(vs.size(), 2); } @@ -199,7 +199,7 @@ void testBatchSampleNeighboor( std::shared_ptr& worker_ptr_) { std::vector>> vs; std::vector v = {37, 96}; - auto pull_status = worker_ptr_->batch_sample_neighboors(0, v, 4, vs); + auto pull_status = worker_ptr_->batch_sample_neighbors(0, v, 4, vs); pull_status.wait(); std::unordered_set s; std::unordered_set s1 = {112, 45, 145}; @@ -401,7 +401,6 @@ void RunClient( } void RunBrpcPushSparse() { - std::cout << "in test cache"; testCache(); setenv("http_proxy", "", 1); setenv("https_proxy", "", 1); @@ -436,24 +435,24 @@ void RunBrpcPushSparse() { sleep(5); testSingleSampleNeighboor(worker_ptr_); testBatchSampleNeighboor(worker_ptr_); - pull_status = worker_ptr_->batch_sample_neighboors( + pull_status = worker_ptr_->batch_sample_neighbors( 0, std::vector(1, 10240001024), 4, vs); pull_status.wait(); ASSERT_EQ(0, vs[0].size()); paddle::distributed::GraphTable* g = (paddle::distributed::GraphTable*)pserver_ptr_->table(0); size_t ttl = 6; - g->make_neigh_sample_cache(4, ttl); + g->make_neighbor_sample_cache(4, ttl); int round = 5; while (round--) { vs.clear(); - pull_status = worker_ptr_->batch_sample_neighboors( + pull_status = worker_ptr_->batch_sample_neighbors( 0, std::vector(1, 37), 1, vs); pull_status.wait(); for (int i = 0; i < ttl; i++) { std::vector>> vs1; - pull_status = worker_ptr_->batch_sample_neighboors( + pull_status = worker_ptr_->batch_sample_neighbors( 0, std::vector(1, 37), 1, vs1); pull_status.wait(); ASSERT_EQ(vs[0].size(), vs1[0].size()); @@ -560,13 +559,13 @@ void RunBrpcPushSparse() { ASSERT_EQ(count_item_nodes.size(), 12); } - vs = client1.batch_sample_neighboors(std::string("user2item"), - std::vector(1, 96), 4); + vs = client1.batch_sample_neighbors(std::string("user2item"), + std::vector(1, 96), 4); ASSERT_EQ(vs[0].size(), 3); std::vector node_ids; node_ids.push_back(96); node_ids.push_back(37); - vs = client1.batch_sample_neighboors(std::string("user2item"), node_ids, 4); + vs = client1.batch_sample_neighbors(std::string("user2item"), node_ids, 4); ASSERT_EQ(vs.size(), 2); std::vector nodes_ids = client2.random_sample_nodes("user", 0, 6); @@ -635,8 +634,7 @@ void RunBrpcPushSparse() { void testCache() { ::paddle::distributed::ScaledLRU<::paddle::distributed::SampleKey, - ::paddle::distributed::SampleResult, - ::paddle::distributed::SampleKeyHash> + ::paddle::distributed::SampleResult> st(1, 2, 4); char* str = new char[7]; strcpy(str, "54321"); diff --git a/paddle/fluid/pybind/fleet_py.cc b/paddle/fluid/pybind/fleet_py.cc index 0a39f529387a2581db370a53edeba7b74f6768fc..e6b8238010a35d79e99fba356345ef164b2cf341 100644 --- a/paddle/fluid/pybind/fleet_py.cc +++ b/paddle/fluid/pybind/fleet_py.cc @@ -205,7 +205,8 @@ void BindGraphPyClient(py::module* m) { .def("add_table_feat_conf", &GraphPyClient::add_table_feat_conf) .def("pull_graph_list", &GraphPyClient::pull_graph_list) .def("start_client", &GraphPyClient::start_client) - .def("batch_sample_neighboors", &GraphPyClient::batch_sample_neighboors) + .def("batch_sample_neighboors", &GraphPyClient::batch_sample_neighbors) + .def("batch_sample_neighbors", &GraphPyClient::batch_sample_neighbors) .def("remove_graph_node", &GraphPyClient::remove_graph_node) .def("random_sample_nodes", &GraphPyClient::random_sample_nodes) .def("stop_server", &GraphPyClient::stop_server)