未验证 提交 9e32a387 编写于 作者: S seemingwang 提交者: GitHub

speed up random sample of graph engine (#34088)

上级 75fc32e2
...@@ -15,12 +15,15 @@ ...@@ -15,12 +15,15 @@
#include "paddle/fluid/distributed/table/common_graph_table.h" #include "paddle/fluid/distributed/table/common_graph_table.h"
#include <time.h> #include <time.h>
#include <algorithm> #include <algorithm>
#include <chrono>
#include <set> #include <set>
#include <sstream> #include <sstream>
#include "paddle/fluid/distributed/common/utils.h" #include "paddle/fluid/distributed/common/utils.h"
#include "paddle/fluid/distributed/table/graph/graph_node.h" #include "paddle/fluid/distributed/table/graph/graph_node.h"
#include "paddle/fluid/framework/generator.h"
#include "paddle/fluid/string/printf.h" #include "paddle/fluid/string/printf.h"
#include "paddle/fluid/string/string_helper.h" #include "paddle/fluid/string/string_helper.h"
namespace paddle { namespace paddle {
namespace distributed { namespace distributed {
...@@ -399,15 +402,18 @@ int32_t GraphTable::random_sample_neighboors( ...@@ -399,15 +402,18 @@ int32_t GraphTable::random_sample_neighboors(
uint64_t &node_id = node_ids[idx]; uint64_t &node_id = node_ids[idx];
std::unique_ptr<char[]> &buffer = buffers[idx]; std::unique_ptr<char[]> &buffer = buffers[idx];
int &actual_size = actual_sizes[idx]; int &actual_size = actual_sizes[idx];
tasks.push_back(_shards_task_pool[get_thread_pool_index(node_id)]->enqueue(
[&]() -> int { int thread_pool_index = get_thread_pool_index(node_id);
auto rng = _shards_task_rng_pool[thread_pool_index];
tasks.push_back(_shards_task_pool[thread_pool_index]->enqueue([&]() -> int {
Node *node = find_node(node_id); Node *node = find_node(node_id);
if (node == nullptr) { if (node == nullptr) {
actual_size = 0; actual_size = 0;
return 0; return 0;
} }
std::vector<int> res = node->sample_k(sample_size); std::vector<int> res = node->sample_k(sample_size, rng);
actual_size = res.size() * (Node::id_size + Node::weight_size); actual_size = res.size() * (Node::id_size + Node::weight_size);
int offset = 0; int offset = 0;
uint64_t id; uint64_t id;
...@@ -512,7 +518,6 @@ int32_t GraphTable::pull_graph_list(int start, int total_size, ...@@ -512,7 +518,6 @@ int32_t GraphTable::pull_graph_list(int start, int total_size,
int end = start + (count - 1) * step + 1; int end = start + (count - 1) * step + 1;
tasks.push_back(_shards_task_pool[i % task_pool_size_]->enqueue( tasks.push_back(_shards_task_pool[i % task_pool_size_]->enqueue(
[this, i, start, end, step, size]() -> std::vector<Node *> { [this, i, start, end, step, size]() -> std::vector<Node *> {
return this->shards[i].get_batch(start - size, end - size, step); return this->shards[i].get_batch(start - size, end - size, step);
})); }));
start += count * step; start += count * step;
...@@ -546,6 +551,7 @@ int32_t GraphTable::initialize() { ...@@ -546,6 +551,7 @@ int32_t GraphTable::initialize() {
_shards_task_pool.resize(task_pool_size_); _shards_task_pool.resize(task_pool_size_);
for (size_t i = 0; i < _shards_task_pool.size(); ++i) { for (size_t i = 0; i < _shards_task_pool.size(); ++i) {
_shards_task_pool[i].reset(new ::ThreadPool(1)); _shards_task_pool[i].reset(new ::ThreadPool(1));
_shards_task_rng_pool.push_back(paddle::framework::GetCPURandomEngine(0));
} }
server_num = _shard_num; server_num = _shard_num;
// VLOG(0) << "in init graph table server num = " << server_num; // VLOG(0) << "in init graph table server num = " << server_num;
...@@ -586,5 +592,5 @@ int32_t GraphTable::initialize() { ...@@ -586,5 +592,5 @@ int32_t GraphTable::initialize() {
shards = std::vector<GraphShard>(shard_num_per_table, GraphShard(shard_num)); shards = std::vector<GraphShard>(shard_num_per_table, GraphShard(shard_num));
return 0; return 0;
} }
} } // namespace distributed
}; }; // namespace paddle
...@@ -136,6 +136,7 @@ class GraphTable : public SparseTable { ...@@ -136,6 +136,7 @@ class GraphTable : public SparseTable {
std::string table_type; std::string table_type;
std::vector<std::shared_ptr<::ThreadPool>> _shards_task_pool; std::vector<std::shared_ptr<::ThreadPool>> _shards_task_pool;
std::vector<std::shared_ptr<std::mt19937_64>> _shards_task_rng_pool;
}; };
} // namespace distributed } // namespace distributed
......
...@@ -113,5 +113,5 @@ void FeatureNode::recover_from_buffer(char* buffer) { ...@@ -113,5 +113,5 @@ void FeatureNode::recover_from_buffer(char* buffer) {
feature.push_back(std::string(str)); feature.push_back(std::string(str));
} }
} }
} } // namespace distributed
} } // namespace paddle
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
#pragma once #pragma once
#include <cstring> #include <cstring>
#include <iostream> #include <iostream>
#include <memory>
#include <sstream> #include <sstream>
#include <vector> #include <vector>
#include "paddle/fluid/distributed/table/graph/graph_weighted_sampler.h" #include "paddle/fluid/distributed/table/graph/graph_weighted_sampler.h"
...@@ -33,7 +34,10 @@ class Node { ...@@ -33,7 +34,10 @@ class Node {
virtual void build_edges(bool is_weighted) {} virtual void build_edges(bool is_weighted) {}
virtual void build_sampler(std::string sample_type) {} virtual void build_sampler(std::string sample_type) {}
virtual void add_edge(uint64_t id, float weight) {} virtual void add_edge(uint64_t id, float weight) {}
virtual std::vector<int> sample_k(int k) { return std::vector<int>(); } virtual std::vector<int> sample_k(
int k, const std::shared_ptr<std::mt19937_64> rng) {
return std::vector<int>();
}
virtual uint64_t get_neighbor_id(int idx) { return 0; } virtual uint64_t get_neighbor_id(int idx) { return 0; }
virtual float get_neighbor_weight(int idx) { return 1.; } virtual float get_neighbor_weight(int idx) { return 1.; }
...@@ -59,7 +63,10 @@ class GraphNode : public Node { ...@@ -59,7 +63,10 @@ class GraphNode : public Node {
virtual void add_edge(uint64_t id, float weight) { virtual void add_edge(uint64_t id, float weight) {
edges->add_edge(id, weight); edges->add_edge(id, weight);
} }
virtual std::vector<int> sample_k(int k) { return sampler->sample_k(k); } virtual std::vector<int> sample_k(
int k, const std::shared_ptr<std::mt19937_64> rng) {
return sampler->sample_k(k, rng);
}
virtual uint64_t get_neighbor_id(int idx) { return edges->get_id(idx); } virtual uint64_t get_neighbor_id(int idx) { return edges->get_id(idx); }
virtual float get_neighbor_weight(int idx) { return edges->get_weight(idx); } virtual float get_neighbor_weight(int idx) { return edges->get_weight(idx); }
...@@ -123,5 +130,5 @@ class FeatureNode : public Node { ...@@ -123,5 +130,5 @@ class FeatureNode : public Node {
protected: protected:
std::vector<std::string> feature; std::vector<std::string> feature;
}; };
} } // namespace distributed
} } // namespace paddle
...@@ -14,24 +14,30 @@ ...@@ -14,24 +14,30 @@
#include "paddle/fluid/distributed/table/graph/graph_weighted_sampler.h" #include "paddle/fluid/distributed/table/graph/graph_weighted_sampler.h"
#include <iostream> #include <iostream>
#include <memory>
#include <unordered_map> #include <unordered_map>
#include "paddle/fluid/framework/generator.h"
namespace paddle { namespace paddle {
namespace distributed { namespace distributed {
void RandomSampler::build(GraphEdgeBlob *edges) { this->edges = edges; } void RandomSampler::build(GraphEdgeBlob *edges) { this->edges = edges; }
std::vector<int> RandomSampler::sample_k(int k) { std::vector<int> RandomSampler::sample_k(
int k, const std::shared_ptr<std::mt19937_64> rng) {
int n = edges->size(); int n = edges->size();
if (k > n) { if (k >= n) {
k = n; k = n;
std::vector<int> sample_result;
for (int i = 0; i < k; i++) {
sample_result.push_back(i);
}
return sample_result;
} }
struct timespec tn;
clock_gettime(CLOCK_REALTIME, &tn);
srand(tn.tv_nsec);
std::vector<int> sample_result; std::vector<int> sample_result;
std::unordered_map<int, int> replace_map; std::unordered_map<int, int> replace_map;
while (k--) { while (k--) {
int rand_int = rand() % n; std::uniform_int_distribution<int> distrib(0, n - 1);
int rand_int = distrib(*rng);
auto iter = replace_map.find(rand_int); auto iter = replace_map.find(rand_int);
if (iter == replace_map.end()) { if (iter == replace_map.end()) {
sample_result.push_back(rand_int); sample_result.push_back(rand_int);
...@@ -98,19 +104,23 @@ void WeightedSampler::build_one(WeightedGraphEdgeBlob *edges, int start, ...@@ -98,19 +104,23 @@ void WeightedSampler::build_one(WeightedGraphEdgeBlob *edges, int start,
count = left->count + right->count; count = left->count + right->count;
} }
} }
std::vector<int> WeightedSampler::sample_k(int k) { std::vector<int> WeightedSampler::sample_k(
if (k > count) { int k, const std::shared_ptr<std::mt19937_64> rng) {
if (k >= count) {
k = count; k = count;
std::vector<int> sample_result;
for (int i = 0; i < k; i++) {
sample_result.push_back(i);
}
return sample_result;
} }
std::vector<int> sample_result; std::vector<int> sample_result;
float subtract; float subtract;
std::unordered_map<WeightedSampler *, float> subtract_weight_map; std::unordered_map<WeightedSampler *, float> subtract_weight_map;
std::unordered_map<WeightedSampler *, int> subtract_count_map; std::unordered_map<WeightedSampler *, int> subtract_count_map;
struct timespec tn; std::uniform_real_distribution<float> distrib(0, 1.0);
clock_gettime(CLOCK_REALTIME, &tn);
srand(tn.tv_nsec);
while (k--) { while (k--) {
float query_weight = rand() % 100000 / 100000.0; float query_weight = distrib(*rng);
query_weight *= weight - subtract_weight_map[this]; query_weight *= weight - subtract_weight_map[this];
sample_result.push_back(sample(query_weight, subtract_weight_map, sample_result.push_back(sample(query_weight, subtract_weight_map,
subtract_count_map, subtract)); subtract_count_map, subtract));
...@@ -146,5 +156,5 @@ int WeightedSampler::sample( ...@@ -146,5 +156,5 @@ int WeightedSampler::sample(
subtract_count_map[this]++; subtract_count_map[this]++;
return return_idx; return return_idx;
} }
} } // namespace distributed
} } // namespace paddle
...@@ -14,6 +14,8 @@ ...@@ -14,6 +14,8 @@
#pragma once #pragma once
#include <ctime> #include <ctime>
#include <memory>
#include <random>
#include <unordered_map> #include <unordered_map>
#include <vector> #include <vector>
#include "paddle/fluid/distributed/table/graph/graph_edge.h" #include "paddle/fluid/distributed/table/graph/graph_edge.h"
...@@ -24,14 +26,16 @@ class Sampler { ...@@ -24,14 +26,16 @@ class Sampler {
public: public:
virtual ~Sampler() {} virtual ~Sampler() {}
virtual void build(GraphEdgeBlob *edges) = 0; virtual void build(GraphEdgeBlob *edges) = 0;
virtual std::vector<int> sample_k(int k) = 0; virtual std::vector<int> sample_k(
int k, const std::shared_ptr<std::mt19937_64> rng) = 0;
}; };
class RandomSampler : public Sampler { class RandomSampler : public Sampler {
public: public:
virtual ~RandomSampler() {} virtual ~RandomSampler() {}
virtual void build(GraphEdgeBlob *edges); virtual void build(GraphEdgeBlob *edges);
virtual std::vector<int> sample_k(int k); virtual std::vector<int> sample_k(int k,
const std::shared_ptr<std::mt19937_64> rng);
GraphEdgeBlob *edges; GraphEdgeBlob *edges;
}; };
...@@ -46,7 +50,8 @@ class WeightedSampler : public Sampler { ...@@ -46,7 +50,8 @@ class WeightedSampler : public Sampler {
GraphEdgeBlob *edges; GraphEdgeBlob *edges;
virtual void build(GraphEdgeBlob *edges); virtual void build(GraphEdgeBlob *edges);
virtual void build_one(WeightedGraphEdgeBlob *edges, int start, int end); virtual void build_one(WeightedGraphEdgeBlob *edges, int start, int end);
virtual std::vector<int> sample_k(int k); virtual std::vector<int> sample_k(int k,
const std::shared_ptr<std::mt19937_64> rng);
private: private:
int sample(float query_weight, int sample(float query_weight,
...@@ -54,5 +59,5 @@ class WeightedSampler : public Sampler { ...@@ -54,5 +59,5 @@ class WeightedSampler : public Sampler {
std::unordered_map<WeightedSampler *, int> &subtract_count_map, std::unordered_map<WeightedSampler *, int> &subtract_count_map,
float &subtract); float &subtract);
}; };
} } // namespace distributed
} } // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册