未验证 提交 9e32a387 编写于 作者: S seemingwang 提交者: GitHub

speed up random sample of graph engine (#34088)

上级 75fc32e2
......@@ -15,12 +15,15 @@
#include "paddle/fluid/distributed/table/common_graph_table.h"
#include <time.h>
#include <algorithm>
#include <chrono>
#include <set>
#include <sstream>
#include "paddle/fluid/distributed/common/utils.h"
#include "paddle/fluid/distributed/table/graph/graph_node.h"
#include "paddle/fluid/framework/generator.h"
#include "paddle/fluid/string/printf.h"
#include "paddle/fluid/string/string_helper.h"
namespace paddle {
namespace distributed {
......@@ -399,15 +402,18 @@ int32_t GraphTable::random_sample_neighboors(
uint64_t &node_id = node_ids[idx];
std::unique_ptr<char[]> &buffer = buffers[idx];
int &actual_size = actual_sizes[idx];
tasks.push_back(_shards_task_pool[get_thread_pool_index(node_id)]->enqueue(
[&]() -> int {
int thread_pool_index = get_thread_pool_index(node_id);
auto rng = _shards_task_rng_pool[thread_pool_index];
tasks.push_back(_shards_task_pool[thread_pool_index]->enqueue([&]() -> int {
Node *node = find_node(node_id);
if (node == nullptr) {
actual_size = 0;
return 0;
}
std::vector<int> res = node->sample_k(sample_size);
std::vector<int> res = node->sample_k(sample_size, rng);
actual_size = res.size() * (Node::id_size + Node::weight_size);
int offset = 0;
uint64_t id;
......@@ -512,7 +518,6 @@ int32_t GraphTable::pull_graph_list(int start, int total_size,
int end = start + (count - 1) * step + 1;
tasks.push_back(_shards_task_pool[i % task_pool_size_]->enqueue(
[this, i, start, end, step, size]() -> std::vector<Node *> {
return this->shards[i].get_batch(start - size, end - size, step);
}));
start += count * step;
......@@ -546,6 +551,7 @@ int32_t GraphTable::initialize() {
_shards_task_pool.resize(task_pool_size_);
for (size_t i = 0; i < _shards_task_pool.size(); ++i) {
_shards_task_pool[i].reset(new ::ThreadPool(1));
_shards_task_rng_pool.push_back(paddle::framework::GetCPURandomEngine(0));
}
server_num = _shard_num;
// VLOG(0) << "in init graph table server num = " << server_num;
......@@ -586,5 +592,5 @@ int32_t GraphTable::initialize() {
shards = std::vector<GraphShard>(shard_num_per_table, GraphShard(shard_num));
return 0;
}
}
};
} // namespace distributed
}; // namespace paddle
......@@ -136,6 +136,7 @@ class GraphTable : public SparseTable {
std::string table_type;
std::vector<std::shared_ptr<::ThreadPool>> _shards_task_pool;
std::vector<std::shared_ptr<std::mt19937_64>> _shards_task_rng_pool;
};
} // namespace distributed
......
......@@ -113,5 +113,5 @@ void FeatureNode::recover_from_buffer(char* buffer) {
feature.push_back(std::string(str));
}
}
}
}
} // namespace distributed
} // namespace paddle
......@@ -15,6 +15,7 @@
#pragma once
#include <cstring>
#include <iostream>
#include <memory>
#include <sstream>
#include <vector>
#include "paddle/fluid/distributed/table/graph/graph_weighted_sampler.h"
......@@ -33,7 +34,10 @@ class Node {
virtual void build_edges(bool is_weighted) {}
virtual void build_sampler(std::string sample_type) {}
virtual void add_edge(uint64_t id, float weight) {}
virtual std::vector<int> sample_k(int k) { return std::vector<int>(); }
virtual std::vector<int> sample_k(
int k, const std::shared_ptr<std::mt19937_64> rng) {
return std::vector<int>();
}
virtual uint64_t get_neighbor_id(int idx) { return 0; }
virtual float get_neighbor_weight(int idx) { return 1.; }
......@@ -59,7 +63,10 @@ class GraphNode : public Node {
virtual void add_edge(uint64_t id, float weight) {
edges->add_edge(id, weight);
}
virtual std::vector<int> sample_k(int k) { return sampler->sample_k(k); }
virtual std::vector<int> sample_k(
int k, const std::shared_ptr<std::mt19937_64> rng) {
return sampler->sample_k(k, rng);
}
virtual uint64_t get_neighbor_id(int idx) { return edges->get_id(idx); }
virtual float get_neighbor_weight(int idx) { return edges->get_weight(idx); }
......@@ -123,5 +130,5 @@ class FeatureNode : public Node {
protected:
std::vector<std::string> feature;
};
}
}
} // namespace distributed
} // namespace paddle
......@@ -14,24 +14,30 @@
#include "paddle/fluid/distributed/table/graph/graph_weighted_sampler.h"
#include <iostream>
#include <memory>
#include <unordered_map>
#include "paddle/fluid/framework/generator.h"
namespace paddle {
namespace distributed {
void RandomSampler::build(GraphEdgeBlob *edges) { this->edges = edges; }
std::vector<int> RandomSampler::sample_k(int k) {
std::vector<int> RandomSampler::sample_k(
int k, const std::shared_ptr<std::mt19937_64> rng) {
int n = edges->size();
if (k > n) {
if (k >= n) {
k = n;
std::vector<int> sample_result;
for (int i = 0; i < k; i++) {
sample_result.push_back(i);
}
return sample_result;
}
struct timespec tn;
clock_gettime(CLOCK_REALTIME, &tn);
srand(tn.tv_nsec);
std::vector<int> sample_result;
std::unordered_map<int, int> replace_map;
while (k--) {
int rand_int = rand() % n;
std::uniform_int_distribution<int> distrib(0, n - 1);
int rand_int = distrib(*rng);
auto iter = replace_map.find(rand_int);
if (iter == replace_map.end()) {
sample_result.push_back(rand_int);
......@@ -98,19 +104,23 @@ void WeightedSampler::build_one(WeightedGraphEdgeBlob *edges, int start,
count = left->count + right->count;
}
}
std::vector<int> WeightedSampler::sample_k(int k) {
if (k > count) {
std::vector<int> WeightedSampler::sample_k(
int k, const std::shared_ptr<std::mt19937_64> rng) {
if (k >= count) {
k = count;
std::vector<int> sample_result;
for (int i = 0; i < k; i++) {
sample_result.push_back(i);
}
return sample_result;
}
std::vector<int> sample_result;
float subtract;
std::unordered_map<WeightedSampler *, float> subtract_weight_map;
std::unordered_map<WeightedSampler *, int> subtract_count_map;
struct timespec tn;
clock_gettime(CLOCK_REALTIME, &tn);
srand(tn.tv_nsec);
std::uniform_real_distribution<float> distrib(0, 1.0);
while (k--) {
float query_weight = rand() % 100000 / 100000.0;
float query_weight = distrib(*rng);
query_weight *= weight - subtract_weight_map[this];
sample_result.push_back(sample(query_weight, subtract_weight_map,
subtract_count_map, subtract));
......@@ -146,5 +156,5 @@ int WeightedSampler::sample(
subtract_count_map[this]++;
return return_idx;
}
}
}
} // namespace distributed
} // namespace paddle
......@@ -14,6 +14,8 @@
#pragma once
#include <ctime>
#include <memory>
#include <random>
#include <unordered_map>
#include <vector>
#include "paddle/fluid/distributed/table/graph/graph_edge.h"
......@@ -24,14 +26,16 @@ class Sampler {
public:
virtual ~Sampler() {}
virtual void build(GraphEdgeBlob *edges) = 0;
virtual std::vector<int> sample_k(int k) = 0;
virtual std::vector<int> sample_k(
int k, const std::shared_ptr<std::mt19937_64> rng) = 0;
};
class RandomSampler : public Sampler {
public:
virtual ~RandomSampler() {}
virtual void build(GraphEdgeBlob *edges);
virtual std::vector<int> sample_k(int k);
virtual std::vector<int> sample_k(int k,
const std::shared_ptr<std::mt19937_64> rng);
GraphEdgeBlob *edges;
};
......@@ -46,7 +50,8 @@ class WeightedSampler : public Sampler {
GraphEdgeBlob *edges;
virtual void build(GraphEdgeBlob *edges);
virtual void build_one(WeightedGraphEdgeBlob *edges, int start, int end);
virtual std::vector<int> sample_k(int k);
virtual std::vector<int> sample_k(int k,
const std::shared_ptr<std::mt19937_64> rng);
private:
int sample(float query_weight,
......@@ -54,5 +59,5 @@ class WeightedSampler : public Sampler {
std::unordered_map<WeightedSampler *, int> &subtract_count_map,
float &subtract);
};
}
}
} // namespace distributed
} // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册