未验证 提交 73d0b0e9 编写于 作者: S seemingwang 提交者: GitHub

fix count problem (#32415)

* graph engine demo

* upload unsaved changes

* fix dependency error

* fix shard_num problem

* py client

* remove lock and graph-type

* add load direct graph

* add load direct graph

* add load direct graph

* batch random_sample

* batch_sample_k

* fix num_nodes size

* batch brpc

* batch brpc

* add test

* add test

* add load_nodes; change add_node function

* change sample return type to pair

* resolve conflict

* resolved conflict

* resolved conflict

* separate server and client

* merge pair type

* fix

* resolved conflict

* fixed segment fault; high-level VLOG for load edges and load nodes

* random_sample return 0

* rm useless loop

* test:load edge

* fix ret -1

* test: rm sample

* rm sample

* random_sample return future

* random_sample return int

* test fake node

* fixed here

* memory leak

* remove test code

* fix return problem

* add common_graph_table

* random sample node &test & change data-structure from linkedList to vector

* add common_graph_table

* sample with srand

* add node_types

* optimize nodes sample

* recover test

* random sample

* destruct weighted sampler

* GraphEdgeBlob

* WeightedGraphEdgeBlob to GraphEdgeBlob

* WeightedGraphEdgeBlob to GraphEdgeBlob

* pybind sample nodes api

* pull nodes with step

* fixed pull_graph_list bug; add test for pull_graph_list by step

* add graph table;name

* add graph table;name

* add pybind

* add pybind

* add FeatureNode

* add FeatureNode

* add FeatureNode Serialize

* add FeatureNode Serialize

* get_feat_node

* avoid local rpc

* fix get_node_feat

* fix get_node_feat

* remove log

* get_node_feat return  py:bytes

* merge develop with graph_engine

* fix threadpool.h head

* fix

* fix typo

* resolve conflict

* fix conflict

* recover lost content

* fix pybind of FeatureNode

* recover cmake

* recover tools

* resolve conflict

* resolve linking problem

* code style

* change test_server port

* fix code problems

* remove shard_num config

* remove redundent threads

* optimize start server

* remove logs

* fix code problems by reviewers' suggestions

* move graph files into a folder

* code style change

* remove graph operations from base table

* optimize get_feat function of graph engine

* fix long long count problem
Co-authored-by: NHuang Zhengjie <270018958@qq.com>
Co-authored-by: NWeiyue Su <weiyue.su@gmail.com>
Co-authored-by: Nsuweiyue <suweiyue@baidu.com>
Co-authored-by: Nluobin06 <luobin06@baidu.com>
Co-authored-by: Nliweibin02 <liweibin02@baidu.com>
Co-authored-by: Ntangwei12 <tangwei12@baidu.com>
上级 f4d9adc7
...@@ -54,19 +54,7 @@ class GraphPyService { ...@@ -54,19 +54,7 @@ class GraphPyService {
std::vector<std::string> table_feat_conf_feat_dtype; std::vector<std::string> table_feat_conf_feat_dtype;
std::vector<int32_t> table_feat_conf_feat_shape; std::vector<int32_t> table_feat_conf_feat_shape;
// std::thread *server_thread, *client_thread;
// std::shared_ptr<paddle::distributed::PSServer> pserver_ptr;
// std::shared_ptr<paddle::distributed::PSClient> worker_ptr;
public: public:
// std::shared_ptr<paddle::distributed::PSServer> get_ps_server() {
// return pserver_ptr;
// }
// std::shared_ptr<paddle::distributed::PSClient> get_ps_client() {
// return worker_ptr;
// }
int get_shard_num() { return shard_num; } int get_shard_num() { return shard_num; }
void set_shard_num(int shard_num) { this->shard_num = shard_num; } void set_shard_num(int shard_num) { this->shard_num = shard_num; }
void GetDownpourSparseTableProto( void GetDownpourSparseTableProto(
......
...@@ -171,7 +171,7 @@ int32_t GraphTable::load_nodes(const std::string &path, std::string node_type) { ...@@ -171,7 +171,7 @@ int32_t GraphTable::load_nodes(const std::string &path, std::string node_type) {
int32_t GraphTable::load_edges(const std::string &path, bool reverse_edge) { int32_t GraphTable::load_edges(const std::string &path, bool reverse_edge) {
auto paths = paddle::string::split_string<std::string>(path, ";"); auto paths = paddle::string::split_string<std::string>(path, ";");
int count = 0; int64_t count = 0;
std::string sample_type = "random"; std::string sample_type = "random";
bool is_weighted = false; bool is_weighted = false;
int valid_count = 0; int valid_count = 0;
......
...@@ -33,26 +33,11 @@ namespace paddle { ...@@ -33,26 +33,11 @@ namespace paddle {
namespace distributed { namespace distributed {
class GraphShard { class GraphShard {
public: public:
// static int bucket_low_bound;
// static int gcd(int s, int t) {
// if (s % t == 0) return t;
// return gcd(t, s % t);
// }
size_t get_size(); size_t get_size();
GraphShard() {} GraphShard() {}
GraphShard(int shard_num) { GraphShard(int shard_num) { this->shard_num = shard_num; }
this->shard_num = shard_num;
// bucket_size = init_bucket_size(shard_num);
// bucket.resize(bucket_size);
}
std::vector<Node *> &get_bucket() { return bucket; } std::vector<Node *> &get_bucket() { return bucket; }
std::vector<Node *> get_batch(int start, int end, int step); std::vector<Node *> get_batch(int start, int end, int step);
// int init_bucket_size(int shard_num) {
// for (int i = bucket_low_bound;; i++) {
// if (gcd(i, shard_num) == 1) return i;
// }
// return -1;
// }
std::vector<uint64_t> get_ids_by_range(int start, int end) { std::vector<uint64_t> get_ids_by_range(int start, int end) {
std::vector<uint64_t> res; std::vector<uint64_t> res;
for (int i = start; i < end && i < bucket.size(); i++) { for (int i = start; i < end && i < bucket.size(); i++) {
...@@ -64,7 +49,6 @@ class GraphShard { ...@@ -64,7 +49,6 @@ class GraphShard {
FeatureNode *add_feature_node(uint64_t id); FeatureNode *add_feature_node(uint64_t id);
Node *find_node(uint64_t id); Node *find_node(uint64_t id);
void add_neighboor(uint64_t id, uint64_t dst_id, float weight); void add_neighboor(uint64_t id, uint64_t dst_id, float weight);
// std::unordered_map<uint64_t, std::list<GraphNode *>::iterator>
std::unordered_map<uint64_t, int> get_node_location() { std::unordered_map<uint64_t, int> get_node_location() {
return node_location; return node_location;
} }
...@@ -131,7 +115,7 @@ class GraphTable : public SparseTable { ...@@ -131,7 +115,7 @@ class GraphTable : public SparseTable {
protected: protected:
std::vector<GraphShard> shards; std::vector<GraphShard> shards;
size_t shard_start, shard_end, server_num, shard_num_per_table, shard_num; size_t shard_start, shard_end, server_num, shard_num_per_table, shard_num;
const int task_pool_size_ = 11; const int task_pool_size_ = 24;
const int random_sample_nodes_ranges = 3; const int random_sample_nodes_ranges = 3;
std::vector<std::string> feat_name; std::vector<std::string> feat_name;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册