未验证 提交 4935b8e7 编写于 作者: S seemingwang 提交者: GitHub

move graph files (#32103)

* graph engine demo

* upload unsaved changes

* fix dependency error

* fix shard_num problem

* py client

* remove lock and graph-type

* add load direct graph

* add load direct graph

* add load direct graph

* batch random_sample

* batch_sample_k

* fix num_nodes size

* batch brpc

* batch brpc

* add test

* add test

* add load_nodes; change add_node function

* change sample return type to pair

* resolve conflict

* resolved conflict

* resolved conflict

* separate server and client

* merge pair type

* fix

* resolved conflict

* fixed segment fault; high-level VLOG for load edges and load nodes

* random_sample return 0

* rm useless loop

* test:load edge

* fix ret -1

* test: rm sample

* rm sample

* random_sample return future

* random_sample return int

* test fake node

* fixed here

* memory leak

* remove test code

* fix return problem

* add common_graph_table

* random sample node &test & change data-structure from linkedList to vector

* add common_graph_table

* sample with srand

* add node_types

* optimize nodes sample

* recover test

* random sample

* destruct weighted sampler

* GraphEdgeBlob

* WeightedGraphEdgeBlob to GraphEdgeBlob

* WeightedGraphEdgeBlob to GraphEdgeBlob

* pybind sample nodes api

* pull nodes with step

* fixed pull_graph_list bug; add test for pull_graph_list by step

* add graph table;name

* add graph table;name

* add pybind

* add pybind

* add FeatureNode

* add FeatureNode

* add FeatureNode Serialize

* add FeatureNode Serialize

* get_feat_node

* avoid local rpc

* fix get_node_feat

* fix get_node_feat

* remove log

* get_node_feat return  py:bytes

* merge develop with graph_engine

* fix threadpool.h head

* fix

* fix typo

* resolve conflict

* fix conflict

* recover lost content

* fix pybind of FeatureNode

* recover cmake

* recover tools

* resolve conflict

* resolve linking problem

* code style

* change test_server port

* fix code problems

* remove shard_num config

* remove redundent threads

* optimize start server

* remove logs

* fix code problems by reviewers' suggestions

* move graph files into a folder

* code style change

* remove graph operations from base table
Co-authored-by: NHuang Zhengjie <270018958@qq.com>
Co-authored-by: NWeiyue Su <weiyue.su@gmail.com>
Co-authored-by: Nsuweiyue <suweiyue@baidu.com>
Co-authored-by: Nluobin06 <luobin06@baidu.com>
Co-authored-by: Nliweibin02 <liweibin02@baidu.com>
Co-authored-by: Ntangwei12 <tangwei12@baidu.com>
上级 e09f4db9
...@@ -19,7 +19,6 @@ ...@@ -19,7 +19,6 @@
#include "butil/endpoint.h" #include "butil/endpoint.h"
#include "iomanip" #include "iomanip"
#include "paddle/fluid/distributed/service/brpc_ps_client.h" #include "paddle/fluid/distributed/service/brpc_ps_client.h"
#include "paddle/fluid/distributed/table/table.h"
#include "paddle/fluid/framework/archive.h" #include "paddle/fluid/framework/archive.h"
#include "paddle/fluid/platform/profiler.h" #include "paddle/fluid/platform/profiler.h"
namespace paddle { namespace paddle {
...@@ -265,7 +264,8 @@ int32_t GraphBrpcService::pull_graph_list(Table *table, ...@@ -265,7 +264,8 @@ int32_t GraphBrpcService::pull_graph_list(Table *table,
int step = *(int *)(request.params(2).c_str()); int step = *(int *)(request.params(2).c_str());
std::unique_ptr<char[]> buffer; std::unique_ptr<char[]> buffer;
int actual_size; int actual_size;
table->pull_graph_list(start, size, buffer, actual_size, false, step); ((GraphTable *)table)
->pull_graph_list(start, size, buffer, actual_size, false, step);
cntl->response_attachment().append(buffer.get(), actual_size); cntl->response_attachment().append(buffer.get(), actual_size);
return 0; return 0;
} }
...@@ -284,8 +284,8 @@ int32_t GraphBrpcService::graph_random_sample_neighboors( ...@@ -284,8 +284,8 @@ int32_t GraphBrpcService::graph_random_sample_neighboors(
int sample_size = *(uint64_t *)(request.params(1).c_str()); int sample_size = *(uint64_t *)(request.params(1).c_str());
std::vector<std::unique_ptr<char[]>> buffers(node_num); std::vector<std::unique_ptr<char[]>> buffers(node_num);
std::vector<int> actual_sizes(node_num, 0); std::vector<int> actual_sizes(node_num, 0);
table->random_sample_neighboors(node_data, sample_size, buffers, ((GraphTable *)table)
actual_sizes); ->random_sample_neighboors(node_data, sample_size, buffers, actual_sizes);
cntl->response_attachment().append(&node_num, sizeof(size_t)); cntl->response_attachment().append(&node_num, sizeof(size_t));
cntl->response_attachment().append(actual_sizes.data(), cntl->response_attachment().append(actual_sizes.data(),
...@@ -301,7 +301,8 @@ int32_t GraphBrpcService::graph_random_sample_nodes( ...@@ -301,7 +301,8 @@ int32_t GraphBrpcService::graph_random_sample_nodes(
size_t size = *(uint64_t *)(request.params(0).c_str()); size_t size = *(uint64_t *)(request.params(0).c_str());
std::unique_ptr<char[]> buffer; std::unique_ptr<char[]> buffer;
int actual_size; int actual_size;
if (table->random_sample_nodes(size, buffer, actual_size) == 0) { if (((GraphTable *)table)->random_sample_nodes(size, buffer, actual_size) ==
0) {
cntl->response_attachment().append(buffer.get(), actual_size); cntl->response_attachment().append(buffer.get(), actual_size);
} else } else
cntl->response_attachment().append(NULL, 0); cntl->response_attachment().append(NULL, 0);
...@@ -330,7 +331,7 @@ int32_t GraphBrpcService::graph_get_node_feat(Table *table, ...@@ -330,7 +331,7 @@ int32_t GraphBrpcService::graph_get_node_feat(Table *table,
std::vector<std::vector<std::string>> feature( std::vector<std::vector<std::string>> feature(
feature_names.size(), std::vector<std::string>(node_num)); feature_names.size(), std::vector<std::string>(node_num));
table->get_node_feat(node_ids, feature_names, feature); ((GraphTable *)table)->get_node_feat(node_ids, feature_names, feature);
for (size_t feat_idx = 0; feat_idx < feature_names.size(); ++feat_idx) { for (size_t feat_idx = 0; feat_idx < feature_names.size(); ++feat_idx) {
for (size_t node_idx = 0; node_idx < node_num; ++node_idx) { for (size_t node_idx = 0; node_idx < node_num; ++node_idx) {
......
...@@ -22,7 +22,8 @@ ...@@ -22,7 +22,8 @@
#include <vector> #include <vector>
#include "paddle/fluid/distributed/service/brpc_ps_server.h" #include "paddle/fluid/distributed/service/brpc_ps_server.h"
#include "paddle/fluid/distributed/service/server.h" #include "paddle/fluid/distributed/service/server.h"
#include "paddle/fluid/distributed/table/common_graph_table.h"
#include "paddle/fluid/distributed/table/table.h"
namespace paddle { namespace paddle {
namespace distributed { namespace distributed {
class GraphBrpcServer : public PSServer { class GraphBrpcServer : public PSServer {
......
...@@ -24,7 +24,7 @@ ...@@ -24,7 +24,7 @@
#include "paddle/fluid/distributed/service/env.h" #include "paddle/fluid/distributed/service/env.h"
#include "paddle/fluid/distributed/service/sendrecv.pb.h" #include "paddle/fluid/distributed/service/sendrecv.pb.h"
#include "paddle/fluid/distributed/table/accessor.h" #include "paddle/fluid/distributed/table/accessor.h"
#include "paddle/fluid/distributed/table/graph_node.h" #include "paddle/fluid/distributed/table/graph/graph_node.h"
namespace paddle { namespace paddle {
namespace distributed { namespace distributed {
......
set_property(GLOBAL PROPERTY TABLE_DEPS string_helper) set_property(GLOBAL PROPERTY TABLE_DEPS string_helper)
set(graphDir graph)
get_property(TABLE_DEPS GLOBAL PROPERTY TABLE_DEPS) get_property(TABLE_DEPS GLOBAL PROPERTY TABLE_DEPS)
set_source_files_properties(graph_edge.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) set_source_files_properties(${graphDir}/graph_edge.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
cc_library(graph_edge SRCS graph_edge.cc) cc_library(graph_edge SRCS ${graphDir}/graph_edge.cc)
set_source_files_properties(graph_weighted_sampler.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) set_source_files_properties(${graphDir}/graph_weighted_sampler.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
cc_library(WeightedSampler SRCS graph_weighted_sampler.cc DEPS graph_edge) cc_library(WeightedSampler SRCS ${graphDir}/graph_weighted_sampler.cc DEPS graph_edge)
set_source_files_properties(graph_node.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) set_source_files_properties(${graphDir}/graph_node.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
cc_library(graph_node SRCS graph_node.cc DEPS WeightedSampler) cc_library(graph_node SRCS ${graphDir}/graph_node.cc DEPS WeightedSampler)
set_source_files_properties(common_dense_table.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) set_source_files_properties(common_dense_table.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties(common_sparse_table.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) set_source_files_properties(common_sparse_table.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties(sparse_geo_table.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) set_source_files_properties(sparse_geo_table.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
......
...@@ -18,7 +18,7 @@ ...@@ -18,7 +18,7 @@
#include <set> #include <set>
#include <sstream> #include <sstream>
#include "paddle/fluid/distributed/common/utils.h" #include "paddle/fluid/distributed/common/utils.h"
#include "paddle/fluid/distributed/table/graph_node.h" #include "paddle/fluid/distributed/table/graph/graph_node.h"
#include "paddle/fluid/string/printf.h" #include "paddle/fluid/string/printf.h"
#include "paddle/fluid/string/string_helper.h" #include "paddle/fluid/string/string_helper.h"
namespace paddle { namespace paddle {
......
...@@ -26,7 +26,7 @@ ...@@ -26,7 +26,7 @@
#include <vector> #include <vector>
#include "paddle/fluid/distributed/table/accessor.h" #include "paddle/fluid/distributed/table/accessor.h"
#include "paddle/fluid/distributed/table/common_table.h" #include "paddle/fluid/distributed/table/common_table.h"
#include "paddle/fluid/distributed/table/graph_node.h" #include "paddle/fluid/distributed/table/graph/graph_node.h"
#include "paddle/fluid/framework/rw_lock.h" #include "paddle/fluid/framework/rw_lock.h"
#include "paddle/fluid/string/string_helper.h" #include "paddle/fluid/string/string_helper.h"
namespace paddle { namespace paddle {
......
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/distributed/table/graph/graph_edge.h"
#include <cstring>
namespace paddle {
namespace distributed {
void GraphEdgeBlob::add_edge(uint64_t id, float weight = 1) {
id_arr.push_back(id);
}
void WeightedGraphEdgeBlob::add_edge(uint64_t id, float weight = 1) {
id_arr.push_back(id);
weight_arr.push_back(weight);
}
}
}
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <cstddef>
#include <cstdint>
#include <vector>
namespace paddle {
namespace distributed {
class GraphEdgeBlob {
public:
GraphEdgeBlob() {}
virtual ~GraphEdgeBlob() {}
size_t size() { return id_arr.size(); }
virtual void add_edge(uint64_t id, float weight);
uint64_t get_id(int idx) { return id_arr[idx]; }
virtual float get_weight(int idx) { return 1; }
protected:
std::vector<uint64_t> id_arr;
};
class WeightedGraphEdgeBlob : public GraphEdgeBlob {
public:
WeightedGraphEdgeBlob() {}
virtual ~WeightedGraphEdgeBlob() {}
virtual void add_edge(uint64_t id, float weight);
virtual float get_weight(int idx) { return weight_arr[idx]; }
protected:
std::vector<float> weight_arr;
};
}
}
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/distributed/table/graph/graph_node.h"
#include <cstring>
namespace paddle {
namespace distributed {
GraphNode::~GraphNode() {
if (sampler != nullptr) {
delete sampler;
sampler = nullptr;
}
if (edges != nullptr) {
delete edges;
edges = nullptr;
}
}
int Node::weight_size = sizeof(float);
int Node::id_size = sizeof(uint64_t);
int Node::int_size = sizeof(int);
int Node::get_size(bool need_feature) { return id_size + int_size; }
void Node::to_buffer(char* buffer, bool need_feature) {
memcpy(buffer, &id, id_size);
buffer += id_size;
int feat_num = 0;
memcpy(buffer, &feat_num, sizeof(int));
}
void Node::recover_from_buffer(char* buffer) { memcpy(&id, buffer, id_size); }
int FeatureNode::get_size(bool need_feature) {
int size = id_size + int_size; // id, feat_num
if (need_feature) {
size += feature.size() * int_size;
for (const std::string& fea : feature) {
size += fea.size();
}
}
return size;
}
void GraphNode::build_edges(bool is_weighted) {
if (edges == nullptr) {
if (is_weighted == true) {
edges = new WeightedGraphEdgeBlob();
} else {
edges = new GraphEdgeBlob();
}
}
}
void GraphNode::build_sampler(std::string sample_type) {
if (sample_type == "random") {
sampler = new RandomSampler();
} else if (sample_type == "weighted") {
sampler = new WeightedSampler();
}
sampler->build(edges);
}
void FeatureNode::to_buffer(char* buffer, bool need_feature) {
memcpy(buffer, &id, id_size);
buffer += id_size;
int feat_num = 0;
int feat_len;
if (need_feature) {
feat_num += feature.size();
memcpy(buffer, &feat_num, sizeof(int));
buffer += sizeof(int);
for (int i = 0; i < feat_num; ++i) {
feat_len = feature[i].size();
memcpy(buffer, &feat_len, sizeof(int));
buffer += sizeof(int);
memcpy(buffer, feature[i].c_str(), feature[i].size());
buffer += feature[i].size();
}
} else {
memcpy(buffer, &feat_num, sizeof(int));
}
}
void FeatureNode::recover_from_buffer(char* buffer) {
int feat_num, feat_len;
memcpy(&id, buffer, id_size);
buffer += id_size;
memcpy(&feat_num, buffer, sizeof(int));
buffer += sizeof(int);
feature.clear();
for (int i = 0; i < feat_num; ++i) {
memcpy(&feat_len, buffer, sizeof(int));
buffer += sizeof(int);
char str[feat_len + 1];
memcpy(str, buffer, feat_len);
buffer += feat_len;
str[feat_len] = '\0';
feature.push_back(std::string(str));
}
}
}
}
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <cstring>
#include <iostream>
#include <sstream>
#include <vector>
#include "paddle/fluid/distributed/table/graph/graph_weighted_sampler.h"
namespace paddle {
namespace distributed {
class Node {
public:
Node() {}
Node(uint64_t id) : id(id) {}
virtual ~Node() {}
static int id_size, int_size, weight_size;
uint64_t get_id() { return id; }
void set_id(uint64_t id) { this->id = id; }
virtual void build_edges(bool is_weighted) {}
virtual void build_sampler(std::string sample_type) {}
virtual void add_edge(uint64_t id, float weight) {}
virtual std::vector<int> sample_k(int k) { return std::vector<int>(); }
virtual uint64_t get_neighbor_id(int idx) { return 0; }
virtual float get_neighbor_weight(int idx) { return 1.; }
virtual int get_size(bool need_feature);
virtual void to_buffer(char *buffer, bool need_feature);
virtual void recover_from_buffer(char *buffer);
virtual std::string get_feature(int idx) { return std::string(""); }
virtual void set_feature(int idx, std::string str) {}
virtual void set_feature_size(int size) {}
virtual int get_feature_size() { return 0; }
protected:
uint64_t id;
};
class GraphNode : public Node {
public:
GraphNode() : Node(), sampler(nullptr), edges(nullptr) {}
GraphNode(uint64_t id) : Node(id), sampler(nullptr), edges(nullptr) {}
virtual ~GraphNode();
virtual void build_edges(bool is_weighted);
virtual void build_sampler(std::string sample_type);
virtual void add_edge(uint64_t id, float weight) {
edges->add_edge(id, weight);
}
virtual std::vector<int> sample_k(int k) { return sampler->sample_k(k); }
virtual uint64_t get_neighbor_id(int idx) { return edges->get_id(idx); }
virtual float get_neighbor_weight(int idx) { return edges->get_weight(idx); }
protected:
Sampler *sampler;
GraphEdgeBlob *edges;
};
class FeatureNode : public Node {
public:
FeatureNode() : Node() {}
FeatureNode(uint64_t id) : Node(id) {}
virtual ~FeatureNode() {}
virtual int get_size(bool need_feature);
virtual void to_buffer(char *buffer, bool need_feature);
virtual void recover_from_buffer(char *buffer);
virtual std::string get_feature(int idx) {
if (idx < (int)this->feature.size()) {
return this->feature[idx];
} else {
return std::string("");
}
}
virtual void set_feature(int idx, std::string str) {
if (idx >= (int)this->feature.size()) {
this->feature.resize(idx + 1);
}
this->feature[idx] = str;
}
virtual void set_feature_size(int size) { this->feature.resize(size); }
virtual int get_feature_size() { return this->feature.size(); }
template <typename T>
static std::string parse_value_to_bytes(std::vector<std::string> feat_str) {
T v;
size_t Tsize = sizeof(T) * feat_str.size();
char buffer[Tsize];
for (size_t i = 0; i < feat_str.size(); i++) {
std::stringstream ss(feat_str[i]);
ss >> v;
std::memcpy(buffer + sizeof(T) * i, (char *)&v, sizeof(T));
}
return std::string(buffer, Tsize);
}
template <typename T>
static std::vector<T> parse_bytes_to_array(std::string feat_str) {
T v;
std::vector<T> out;
size_t start = 0;
const char *buffer = feat_str.data();
while (start < feat_str.size()) {
std::memcpy((char *)&v, buffer + start, sizeof(T));
start += sizeof(T);
out.push_back(v);
}
return out;
}
protected:
std::vector<std::string> feature;
};
}
}
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/distributed/table/graph/graph_weighted_sampler.h"
#include <iostream>
#include <unordered_map>
namespace paddle {
namespace distributed {
void RandomSampler::build(GraphEdgeBlob *edges) { this->edges = edges; }
std::vector<int> RandomSampler::sample_k(int k) {
int n = edges->size();
if (k > n) {
k = n;
}
struct timespec tn;
clock_gettime(CLOCK_REALTIME, &tn);
srand(tn.tv_nsec);
std::vector<int> sample_result;
std::unordered_map<int, int> replace_map;
while (k--) {
int rand_int = rand() % n;
auto iter = replace_map.find(rand_int);
if (iter == replace_map.end()) {
sample_result.push_back(rand_int);
} else {
sample_result.push_back(iter->second);
}
iter = replace_map.find(n - 1);
if (iter == replace_map.end()) {
replace_map[rand_int] = n - 1;
} else {
replace_map[rand_int] = iter->second;
}
--n;
}
return sample_result;
}
WeightedSampler::WeightedSampler() {
left = nullptr;
right = nullptr;
edges = nullptr;
}
WeightedSampler::~WeightedSampler() {
if (left != nullptr) {
delete left;
left = nullptr;
}
if (right != nullptr) {
delete right;
right = nullptr;
}
}
void WeightedSampler::build(GraphEdgeBlob *edges) {
if (left != nullptr) {
delete left;
left = nullptr;
}
if (right != nullptr) {
delete right;
right = nullptr;
}
return build_one((WeightedGraphEdgeBlob *)edges, 0, edges->size());
}
void WeightedSampler::build_one(WeightedGraphEdgeBlob *edges, int start,
int end) {
count = 0;
this->edges = edges;
if (start + 1 == end) {
left = right = nullptr;
idx = start;
count = 1;
weight = edges->get_weight(idx);
} else {
left = new WeightedSampler();
right = new WeightedSampler();
left->build_one(edges, start, start + (end - start) / 2);
right->build_one(edges, start + (end - start) / 2, end);
weight = left->weight + right->weight;
count = left->count + right->count;
}
}
std::vector<int> WeightedSampler::sample_k(int k) {
if (k > count) {
k = count;
}
std::vector<int> sample_result;
float subtract;
std::unordered_map<WeightedSampler *, float> subtract_weight_map;
std::unordered_map<WeightedSampler *, int> subtract_count_map;
struct timespec tn;
clock_gettime(CLOCK_REALTIME, &tn);
srand(tn.tv_nsec);
while (k--) {
float query_weight = rand() % 100000 / 100000.0;
query_weight *= weight - subtract_weight_map[this];
sample_result.push_back(sample(query_weight, subtract_weight_map,
subtract_count_map, subtract));
}
return sample_result;
}
int WeightedSampler::sample(
float query_weight,
std::unordered_map<WeightedSampler *, float> &subtract_weight_map,
std::unordered_map<WeightedSampler *, int> &subtract_count_map,
float &subtract) {
if (left == nullptr) {
subtract_weight_map[this] = weight;
subtract = weight;
subtract_count_map[this] = 1;
return idx;
}
int left_count = left->count - subtract_count_map[left];
int right_count = right->count - subtract_count_map[right];
float left_subtract = subtract_weight_map[left];
int return_idx;
if (right_count == 0 ||
left_count > 0 && left->weight - left_subtract >= query_weight) {
return_idx = left->sample(query_weight, subtract_weight_map,
subtract_count_map, subtract);
} else {
return_idx =
right->sample(query_weight - (left->weight - left_subtract),
subtract_weight_map, subtract_count_map, subtract);
}
subtract_weight_map[this] += subtract;
subtract_count_map[this]++;
return return_idx;
}
}
}
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <ctime>
#include <unordered_map>
#include <vector>
#include "paddle/fluid/distributed/table/graph/graph_edge.h"
namespace paddle {
namespace distributed {
class Sampler {
public:
virtual ~Sampler() {}
virtual void build(GraphEdgeBlob *edges) = 0;
virtual std::vector<int> sample_k(int k) = 0;
};
class RandomSampler : public Sampler {
public:
virtual ~RandomSampler() {}
virtual void build(GraphEdgeBlob *edges);
virtual std::vector<int> sample_k(int k);
GraphEdgeBlob *edges;
};
class WeightedSampler : public Sampler {
public:
WeightedSampler();
virtual ~WeightedSampler();
WeightedSampler *left, *right;
float weight;
int count;
int idx;
GraphEdgeBlob *edges;
virtual void build(GraphEdgeBlob *edges);
virtual void build_one(WeightedGraphEdgeBlob *edges, int start, int end);
virtual std::vector<int> sample_k(int k);
private:
int sample(float query_weight,
std::unordered_map<WeightedSampler *, float> &subtract_weight_map,
std::unordered_map<WeightedSampler *, int> &subtract_count_map,
float &subtract);
};
}
}
...@@ -22,7 +22,7 @@ ...@@ -22,7 +22,7 @@
#include <utility> #include <utility>
#include "paddle/fluid/distributed/table/accessor.h" #include "paddle/fluid/distributed/table/accessor.h"
#include "paddle/fluid/distributed/table/depends/sparse_utils.h" #include "paddle/fluid/distributed/table/depends/sparse_utils.h"
#include "paddle/fluid/distributed/table/graph_node.h" #include "paddle/fluid/distributed/table/graph/graph_node.h"
#include "paddle/fluid/framework/program_desc.h" #include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/scope.h" #include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/platform/device_context.h" #include "paddle/fluid/platform/device_context.h"
...@@ -88,31 +88,6 @@ class Table { ...@@ -88,31 +88,6 @@ class Table {
return 0; return 0;
} }
// only for graph table
virtual int32_t pull_graph_list(int start, int total_size,
std::unique_ptr<char[]> &buffer,
int &actual_size, bool need_feature,
int step = 1) {
return 0;
}
// only for graph table
virtual int32_t random_sample_neighboors(
uint64_t *node_ids, int sample_size,
std::vector<std::unique_ptr<char[]>> &buffers,
std::vector<int> &actual_sizes) {
return 0;
}
virtual int32_t random_sample_nodes(int sample_size,
std::unique_ptr<char[]> &buffers,
int &actual_sizes) {
return 0;
}
virtual int32_t get_node_feat(const std::vector<uint64_t> &node_ids,
const std::vector<std::string> &feature_names,
std::vector<std::vector<std::string>> &res) {
return 0;
}
virtual int32_t pour() { return 0; } virtual int32_t pour() { return 0; }
virtual void clear() = 0; virtual void clear() = 0;
......
...@@ -33,7 +33,7 @@ limitations under the License. */ ...@@ -33,7 +33,7 @@ limitations under the License. */
#include "paddle/fluid/distributed/service/ps_client.h" #include "paddle/fluid/distributed/service/ps_client.h"
#include "paddle/fluid/distributed/service/sendrecv.pb.h" #include "paddle/fluid/distributed/service/sendrecv.pb.h"
#include "paddle/fluid/distributed/service/service.h" #include "paddle/fluid/distributed/service/service.h"
#include "paddle/fluid/distributed/table/graph_node.h" #include "paddle/fluid/distributed/table/graph/graph_node.h"
#include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/program_desc.h" #include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/scope.h" #include "paddle/fluid/framework/scope.h"
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册