diff --git a/paddle/fluid/distributed/table/common_graph_table.cc b/paddle/fluid/distributed/table/common_graph_table.cc index 92f8304a8bf62178988fef447fcf8c309d8589ea..29bcc04d9c1dfb3f3a5d32040162c4f5c6371672 100644 --- a/paddle/fluid/distributed/table/common_graph_table.cc +++ b/paddle/fluid/distributed/table/common_graph_table.cc @@ -15,12 +15,15 @@ #include "paddle/fluid/distributed/table/common_graph_table.h" #include #include +#include #include #include #include "paddle/fluid/distributed/common/utils.h" #include "paddle/fluid/distributed/table/graph/graph_node.h" +#include "paddle/fluid/framework/generator.h" #include "paddle/fluid/string/printf.h" #include "paddle/fluid/string/string_helper.h" + namespace paddle { namespace distributed { @@ -399,31 +402,34 @@ int32_t GraphTable::random_sample_neighboors( uint64_t &node_id = node_ids[idx]; std::unique_ptr &buffer = buffers[idx]; int &actual_size = actual_sizes[idx]; - tasks.push_back(_shards_task_pool[get_thread_pool_index(node_id)]->enqueue( - [&]() -> int { - Node *node = find_node(node_id); - if (node == nullptr) { - actual_size = 0; - return 0; - } - std::vector res = node->sample_k(sample_size); - actual_size = res.size() * (Node::id_size + Node::weight_size); - int offset = 0; - uint64_t id; - float weight; - char *buffer_addr = new char[actual_size]; - buffer.reset(buffer_addr); - for (int &x : res) { - id = node->get_neighbor_id(x); - weight = node->get_neighbor_weight(x); - memcpy(buffer_addr + offset, &id, Node::id_size); - offset += Node::id_size; - memcpy(buffer_addr + offset, &weight, Node::weight_size); - offset += Node::weight_size; - } - return 0; - })); + int thread_pool_index = get_thread_pool_index(node_id); + auto rng = _shards_task_rng_pool[thread_pool_index]; + + tasks.push_back(_shards_task_pool[thread_pool_index]->enqueue([&]() -> int { + Node *node = find_node(node_id); + + if (node == nullptr) { + actual_size = 0; + return 0; + } + std::vector res = node->sample_k(sample_size, rng); + actual_size = res.size() * (Node::id_size + Node::weight_size); + int offset = 0; + uint64_t id; + float weight; + char *buffer_addr = new char[actual_size]; + buffer.reset(buffer_addr); + for (int &x : res) { + id = node->get_neighbor_id(x); + weight = node->get_neighbor_weight(x); + memcpy(buffer_addr + offset, &id, Node::id_size); + offset += Node::id_size; + memcpy(buffer_addr + offset, &weight, Node::weight_size); + offset += Node::weight_size; + } + return 0; + })); } for (size_t idx = 0; idx < node_num; ++idx) { tasks[idx].get(); @@ -512,7 +518,6 @@ int32_t GraphTable::pull_graph_list(int start, int total_size, int end = start + (count - 1) * step + 1; tasks.push_back(_shards_task_pool[i % task_pool_size_]->enqueue( [this, i, start, end, step, size]() -> std::vector { - return this->shards[i].get_batch(start - size, end - size, step); })); start += count * step; @@ -546,6 +551,7 @@ int32_t GraphTable::initialize() { _shards_task_pool.resize(task_pool_size_); for (size_t i = 0; i < _shards_task_pool.size(); ++i) { _shards_task_pool[i].reset(new ::ThreadPool(1)); + _shards_task_rng_pool.push_back(paddle::framework::GetCPURandomEngine(0)); } server_num = _shard_num; // VLOG(0) << "in init graph table server num = " << server_num; @@ -586,5 +592,5 @@ int32_t GraphTable::initialize() { shards = std::vector(shard_num_per_table, GraphShard(shard_num)); return 0; } -} -}; +} // namespace distributed +}; // namespace paddle diff --git a/paddle/fluid/distributed/table/common_graph_table.h b/paddle/fluid/distributed/table/common_graph_table.h index 5eeb3915f5b1f251dd5edf1a5199621a7cd0069b..6ccce44c7ead6983efb57718999f1b36499b34e8 100644 --- a/paddle/fluid/distributed/table/common_graph_table.h +++ b/paddle/fluid/distributed/table/common_graph_table.h @@ -136,6 +136,7 @@ class GraphTable : public SparseTable { std::string table_type; std::vector> _shards_task_pool; + std::vector> _shards_task_rng_pool; }; } // namespace distributed diff --git a/paddle/fluid/distributed/table/graph/graph_node.cc b/paddle/fluid/distributed/table/graph/graph_node.cc index 816d31b979072c3f1679df1ea75cd9dc75c55b0a..e2311cc307b6057937408c94c0093f3af1f0882e 100644 --- a/paddle/fluid/distributed/table/graph/graph_node.cc +++ b/paddle/fluid/distributed/table/graph/graph_node.cc @@ -113,5 +113,5 @@ void FeatureNode::recover_from_buffer(char* buffer) { feature.push_back(std::string(str)); } } -} -} +} // namespace distributed +} // namespace paddle diff --git a/paddle/fluid/distributed/table/graph/graph_node.h b/paddle/fluid/distributed/table/graph/graph_node.h index 8ad795ac97b5499c7b10361760f7ac16494c154b..62c101ec02a935b4f29948c1e8c53823592e8fdf 100644 --- a/paddle/fluid/distributed/table/graph/graph_node.h +++ b/paddle/fluid/distributed/table/graph/graph_node.h @@ -15,6 +15,7 @@ #pragma once #include #include +#include #include #include #include "paddle/fluid/distributed/table/graph/graph_weighted_sampler.h" @@ -33,7 +34,10 @@ class Node { virtual void build_edges(bool is_weighted) {} virtual void build_sampler(std::string sample_type) {} virtual void add_edge(uint64_t id, float weight) {} - virtual std::vector sample_k(int k) { return std::vector(); } + virtual std::vector sample_k( + int k, const std::shared_ptr rng) { + return std::vector(); + } virtual uint64_t get_neighbor_id(int idx) { return 0; } virtual float get_neighbor_weight(int idx) { return 1.; } @@ -59,7 +63,10 @@ class GraphNode : public Node { virtual void add_edge(uint64_t id, float weight) { edges->add_edge(id, weight); } - virtual std::vector sample_k(int k) { return sampler->sample_k(k); } + virtual std::vector sample_k( + int k, const std::shared_ptr rng) { + return sampler->sample_k(k, rng); + } virtual uint64_t get_neighbor_id(int idx) { return edges->get_id(idx); } virtual float get_neighbor_weight(int idx) { return edges->get_weight(idx); } @@ -123,5 +130,5 @@ class FeatureNode : public Node { protected: std::vector feature; }; -} -} +} // namespace distributed +} // namespace paddle diff --git a/paddle/fluid/distributed/table/graph/graph_weighted_sampler.cc b/paddle/fluid/distributed/table/graph/graph_weighted_sampler.cc index 3a680875e3df4a9cd60f8fe1921b877dbb23c8a2..7a46433e3defbd51b68ed9f25e9e92f64b6d1afa 100644 --- a/paddle/fluid/distributed/table/graph/graph_weighted_sampler.cc +++ b/paddle/fluid/distributed/table/graph/graph_weighted_sampler.cc @@ -14,24 +14,30 @@ #include "paddle/fluid/distributed/table/graph/graph_weighted_sampler.h" #include +#include #include +#include "paddle/fluid/framework/generator.h" namespace paddle { namespace distributed { void RandomSampler::build(GraphEdgeBlob *edges) { this->edges = edges; } -std::vector RandomSampler::sample_k(int k) { +std::vector RandomSampler::sample_k( + int k, const std::shared_ptr rng) { int n = edges->size(); - if (k > n) { + if (k >= n) { k = n; + std::vector sample_result; + for (int i = 0; i < k; i++) { + sample_result.push_back(i); + } + return sample_result; } - struct timespec tn; - clock_gettime(CLOCK_REALTIME, &tn); - srand(tn.tv_nsec); std::vector sample_result; std::unordered_map replace_map; while (k--) { - int rand_int = rand() % n; + std::uniform_int_distribution distrib(0, n - 1); + int rand_int = distrib(*rng); auto iter = replace_map.find(rand_int); if (iter == replace_map.end()) { sample_result.push_back(rand_int); @@ -98,19 +104,23 @@ void WeightedSampler::build_one(WeightedGraphEdgeBlob *edges, int start, count = left->count + right->count; } } -std::vector WeightedSampler::sample_k(int k) { - if (k > count) { +std::vector WeightedSampler::sample_k( + int k, const std::shared_ptr rng) { + if (k >= count) { k = count; + std::vector sample_result; + for (int i = 0; i < k; i++) { + sample_result.push_back(i); + } + return sample_result; } std::vector sample_result; float subtract; std::unordered_map subtract_weight_map; std::unordered_map subtract_count_map; - struct timespec tn; - clock_gettime(CLOCK_REALTIME, &tn); - srand(tn.tv_nsec); + std::uniform_real_distribution distrib(0, 1.0); while (k--) { - float query_weight = rand() % 100000 / 100000.0; + float query_weight = distrib(*rng); query_weight *= weight - subtract_weight_map[this]; sample_result.push_back(sample(query_weight, subtract_weight_map, subtract_count_map, subtract)); @@ -146,5 +156,5 @@ int WeightedSampler::sample( subtract_count_map[this]++; return return_idx; } -} -} +} // namespace distributed +} // namespace paddle diff --git a/paddle/fluid/distributed/table/graph/graph_weighted_sampler.h b/paddle/fluid/distributed/table/graph/graph_weighted_sampler.h index 1787ab23b04316de9ad0622ff5524bc88bd51fe1..4a75a112697d322a2eb49a57d379889d34b6009f 100644 --- a/paddle/fluid/distributed/table/graph/graph_weighted_sampler.h +++ b/paddle/fluid/distributed/table/graph/graph_weighted_sampler.h @@ -14,6 +14,8 @@ #pragma once #include +#include +#include #include #include #include "paddle/fluid/distributed/table/graph/graph_edge.h" @@ -24,14 +26,16 @@ class Sampler { public: virtual ~Sampler() {} virtual void build(GraphEdgeBlob *edges) = 0; - virtual std::vector sample_k(int k) = 0; + virtual std::vector sample_k( + int k, const std::shared_ptr rng) = 0; }; class RandomSampler : public Sampler { public: virtual ~RandomSampler() {} virtual void build(GraphEdgeBlob *edges); - virtual std::vector sample_k(int k); + virtual std::vector sample_k(int k, + const std::shared_ptr rng); GraphEdgeBlob *edges; }; @@ -46,7 +50,8 @@ class WeightedSampler : public Sampler { GraphEdgeBlob *edges; virtual void build(GraphEdgeBlob *edges); virtual void build_one(WeightedGraphEdgeBlob *edges, int start, int end); - virtual std::vector sample_k(int k); + virtual std::vector sample_k(int k, + const std::shared_ptr rng); private: int sample(float query_weight, @@ -54,5 +59,5 @@ class WeightedSampler : public Sampler { std::unordered_map &subtract_count_map, float &subtract); }; -} -} +} // namespace distributed +} // namespace paddle