diff --git a/paddle/fluid/distributed/service/graph_brpc_server.cc b/paddle/fluid/distributed/service/graph_brpc_server.cc index 4f6cc1143e96b65fb7244041b6fc6b88a429160f..bdd926278b624b9e9bfdf19a4f293784bef6e28f 100644 --- a/paddle/fluid/distributed/service/graph_brpc_server.cc +++ b/paddle/fluid/distributed/service/graph_brpc_server.cc @@ -19,7 +19,6 @@ #include "butil/endpoint.h" #include "iomanip" #include "paddle/fluid/distributed/service/brpc_ps_client.h" -#include "paddle/fluid/distributed/table/table.h" #include "paddle/fluid/framework/archive.h" #include "paddle/fluid/platform/profiler.h" namespace paddle { @@ -265,7 +264,8 @@ int32_t GraphBrpcService::pull_graph_list(Table *table, int step = *(int *)(request.params(2).c_str()); std::unique_ptr buffer; int actual_size; - table->pull_graph_list(start, size, buffer, actual_size, false, step); + ((GraphTable *)table) + ->pull_graph_list(start, size, buffer, actual_size, false, step); cntl->response_attachment().append(buffer.get(), actual_size); return 0; } @@ -284,8 +284,8 @@ int32_t GraphBrpcService::graph_random_sample_neighboors( int sample_size = *(uint64_t *)(request.params(1).c_str()); std::vector> buffers(node_num); std::vector actual_sizes(node_num, 0); - table->random_sample_neighboors(node_data, sample_size, buffers, - actual_sizes); + ((GraphTable *)table) + ->random_sample_neighboors(node_data, sample_size, buffers, actual_sizes); cntl->response_attachment().append(&node_num, sizeof(size_t)); cntl->response_attachment().append(actual_sizes.data(), @@ -301,7 +301,8 @@ int32_t GraphBrpcService::graph_random_sample_nodes( size_t size = *(uint64_t *)(request.params(0).c_str()); std::unique_ptr buffer; int actual_size; - if (table->random_sample_nodes(size, buffer, actual_size) == 0) { + if (((GraphTable *)table)->random_sample_nodes(size, buffer, actual_size) == + 0) { cntl->response_attachment().append(buffer.get(), actual_size); } else cntl->response_attachment().append(NULL, 0); @@ -330,7 +331,7 @@ int32_t GraphBrpcService::graph_get_node_feat(Table *table, std::vector> feature( feature_names.size(), std::vector(node_num)); - table->get_node_feat(node_ids, feature_names, feature); + ((GraphTable *)table)->get_node_feat(node_ids, feature_names, feature); for (size_t feat_idx = 0; feat_idx < feature_names.size(); ++feat_idx) { for (size_t node_idx = 0; node_idx < node_num; ++node_idx) { diff --git a/paddle/fluid/distributed/service/graph_brpc_server.h b/paddle/fluid/distributed/service/graph_brpc_server.h index af63bf5d99ef2e03cd7264f1cf95f3c77e8a5ec0..32c572f9e6c2bf759c59190679bcf7570a807f2d 100644 --- a/paddle/fluid/distributed/service/graph_brpc_server.h +++ b/paddle/fluid/distributed/service/graph_brpc_server.h @@ -22,7 +22,8 @@ #include #include "paddle/fluid/distributed/service/brpc_ps_server.h" #include "paddle/fluid/distributed/service/server.h" - +#include "paddle/fluid/distributed/table/common_graph_table.h" +#include "paddle/fluid/distributed/table/table.h" namespace paddle { namespace distributed { class GraphBrpcServer : public PSServer { diff --git a/paddle/fluid/distributed/service/ps_client.h b/paddle/fluid/distributed/service/ps_client.h index 3ff4b9d063f33a35418d6393edb010923caae838..1c8abc6c2e8dcd18ed64788f132f9e5ccfd83f12 100644 --- a/paddle/fluid/distributed/service/ps_client.h +++ b/paddle/fluid/distributed/service/ps_client.h @@ -24,7 +24,7 @@ #include "paddle/fluid/distributed/service/env.h" #include "paddle/fluid/distributed/service/sendrecv.pb.h" #include "paddle/fluid/distributed/table/accessor.h" -#include "paddle/fluid/distributed/table/graph_node.h" +#include "paddle/fluid/distributed/table/graph/graph_node.h" namespace paddle { namespace distributed { diff --git a/paddle/fluid/distributed/table/CMakeLists.txt b/paddle/fluid/distributed/table/CMakeLists.txt index 33873abc5f7f51e3c4fdf2619c51190bbd82085a..dde1f5ae8ee3a1d683c805896a470612de6e2aba 100644 --- a/paddle/fluid/distributed/table/CMakeLists.txt +++ b/paddle/fluid/distributed/table/CMakeLists.txt @@ -1,12 +1,12 @@ set_property(GLOBAL PROPERTY TABLE_DEPS string_helper) - +set(graphDir graph) get_property(TABLE_DEPS GLOBAL PROPERTY TABLE_DEPS) -set_source_files_properties(graph_edge.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) -cc_library(graph_edge SRCS graph_edge.cc) -set_source_files_properties(graph_weighted_sampler.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) -cc_library(WeightedSampler SRCS graph_weighted_sampler.cc DEPS graph_edge) -set_source_files_properties(graph_node.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) -cc_library(graph_node SRCS graph_node.cc DEPS WeightedSampler) +set_source_files_properties(${graphDir}/graph_edge.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) +cc_library(graph_edge SRCS ${graphDir}/graph_edge.cc) +set_source_files_properties(${graphDir}/graph_weighted_sampler.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) +cc_library(WeightedSampler SRCS ${graphDir}/graph_weighted_sampler.cc DEPS graph_edge) +set_source_files_properties(${graphDir}/graph_node.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) +cc_library(graph_node SRCS ${graphDir}/graph_node.cc DEPS WeightedSampler) set_source_files_properties(common_dense_table.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) set_source_files_properties(common_sparse_table.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) set_source_files_properties(sparse_geo_table.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) diff --git a/paddle/fluid/distributed/table/common_graph_table.cc b/paddle/fluid/distributed/table/common_graph_table.cc index 995a39a654312fc76677373b04a8896d85703b7d..020bcdcc52ef4b023fbb7b263517f67ef4abaf0b 100644 --- a/paddle/fluid/distributed/table/common_graph_table.cc +++ b/paddle/fluid/distributed/table/common_graph_table.cc @@ -18,7 +18,7 @@ #include #include #include "paddle/fluid/distributed/common/utils.h" -#include "paddle/fluid/distributed/table/graph_node.h" +#include "paddle/fluid/distributed/table/graph/graph_node.h" #include "paddle/fluid/string/printf.h" #include "paddle/fluid/string/string_helper.h" namespace paddle { diff --git a/paddle/fluid/distributed/table/common_graph_table.h b/paddle/fluid/distributed/table/common_graph_table.h index ab28961846297457187b92346f304f68c6bd514c..8ddf3c8f904a6cab0e5826118ce9650bf8f6e2af 100644 --- a/paddle/fluid/distributed/table/common_graph_table.h +++ b/paddle/fluid/distributed/table/common_graph_table.h @@ -26,7 +26,7 @@ #include #include "paddle/fluid/distributed/table/accessor.h" #include "paddle/fluid/distributed/table/common_table.h" -#include "paddle/fluid/distributed/table/graph_node.h" +#include "paddle/fluid/distributed/table/graph/graph_node.h" #include "paddle/fluid/framework/rw_lock.h" #include "paddle/fluid/string/string_helper.h" namespace paddle { diff --git a/paddle/fluid/distributed/table/graph/graph_edge.cc b/paddle/fluid/distributed/table/graph/graph_edge.cc new file mode 100644 index 0000000000000000000000000000000000000000..0ab0d5a76d6715401dd55ce7487634b72d452ddf --- /dev/null +++ b/paddle/fluid/distributed/table/graph/graph_edge.cc @@ -0,0 +1,29 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/distributed/table/graph/graph_edge.h" +#include +namespace paddle { +namespace distributed { + +void GraphEdgeBlob::add_edge(uint64_t id, float weight = 1) { + id_arr.push_back(id); +} + +void WeightedGraphEdgeBlob::add_edge(uint64_t id, float weight = 1) { + id_arr.push_back(id); + weight_arr.push_back(weight); +} +} +} diff --git a/paddle/fluid/distributed/table/graph/graph_edge.h b/paddle/fluid/distributed/table/graph/graph_edge.h new file mode 100644 index 0000000000000000000000000000000000000000..3dfe5a6f357a7cd7d79834a20b6411995665f4fa --- /dev/null +++ b/paddle/fluid/distributed/table/graph/graph_edge.h @@ -0,0 +1,46 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include +#include +#include +namespace paddle { +namespace distributed { + +class GraphEdgeBlob { + public: + GraphEdgeBlob() {} + virtual ~GraphEdgeBlob() {} + size_t size() { return id_arr.size(); } + virtual void add_edge(uint64_t id, float weight); + uint64_t get_id(int idx) { return id_arr[idx]; } + virtual float get_weight(int idx) { return 1; } + + protected: + std::vector id_arr; +}; + +class WeightedGraphEdgeBlob : public GraphEdgeBlob { + public: + WeightedGraphEdgeBlob() {} + virtual ~WeightedGraphEdgeBlob() {} + virtual void add_edge(uint64_t id, float weight); + virtual float get_weight(int idx) { return weight_arr[idx]; } + + protected: + std::vector weight_arr; +}; +} +} diff --git a/paddle/fluid/distributed/table/graph/graph_node.cc b/paddle/fluid/distributed/table/graph/graph_node.cc new file mode 100644 index 0000000000000000000000000000000000000000..816d31b979072c3f1679df1ea75cd9dc75c55b0a --- /dev/null +++ b/paddle/fluid/distributed/table/graph/graph_node.cc @@ -0,0 +1,117 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/distributed/table/graph/graph_node.h" +#include +namespace paddle { +namespace distributed { + +GraphNode::~GraphNode() { + if (sampler != nullptr) { + delete sampler; + sampler = nullptr; + } + if (edges != nullptr) { + delete edges; + edges = nullptr; + } +} + +int Node::weight_size = sizeof(float); +int Node::id_size = sizeof(uint64_t); +int Node::int_size = sizeof(int); + +int Node::get_size(bool need_feature) { return id_size + int_size; } + +void Node::to_buffer(char* buffer, bool need_feature) { + memcpy(buffer, &id, id_size); + buffer += id_size; + + int feat_num = 0; + memcpy(buffer, &feat_num, sizeof(int)); +} + +void Node::recover_from_buffer(char* buffer) { memcpy(&id, buffer, id_size); } + +int FeatureNode::get_size(bool need_feature) { + int size = id_size + int_size; // id, feat_num + if (need_feature) { + size += feature.size() * int_size; + for (const std::string& fea : feature) { + size += fea.size(); + } + } + return size; +} + +void GraphNode::build_edges(bool is_weighted) { + if (edges == nullptr) { + if (is_weighted == true) { + edges = new WeightedGraphEdgeBlob(); + } else { + edges = new GraphEdgeBlob(); + } + } +} +void GraphNode::build_sampler(std::string sample_type) { + if (sample_type == "random") { + sampler = new RandomSampler(); + } else if (sample_type == "weighted") { + sampler = new WeightedSampler(); + } + sampler->build(edges); +} +void FeatureNode::to_buffer(char* buffer, bool need_feature) { + memcpy(buffer, &id, id_size); + buffer += id_size; + + int feat_num = 0; + int feat_len; + if (need_feature) { + feat_num += feature.size(); + memcpy(buffer, &feat_num, sizeof(int)); + buffer += sizeof(int); + for (int i = 0; i < feat_num; ++i) { + feat_len = feature[i].size(); + memcpy(buffer, &feat_len, sizeof(int)); + buffer += sizeof(int); + memcpy(buffer, feature[i].c_str(), feature[i].size()); + buffer += feature[i].size(); + } + } else { + memcpy(buffer, &feat_num, sizeof(int)); + } +} +void FeatureNode::recover_from_buffer(char* buffer) { + int feat_num, feat_len; + memcpy(&id, buffer, id_size); + buffer += id_size; + + memcpy(&feat_num, buffer, sizeof(int)); + buffer += sizeof(int); + + feature.clear(); + for (int i = 0; i < feat_num; ++i) { + memcpy(&feat_len, buffer, sizeof(int)); + buffer += sizeof(int); + + char str[feat_len + 1]; + memcpy(str, buffer, feat_len); + buffer += feat_len; + str[feat_len] = '\0'; + feature.push_back(std::string(str)); + } +} +} +} diff --git a/paddle/fluid/distributed/table/graph/graph_node.h b/paddle/fluid/distributed/table/graph/graph_node.h new file mode 100644 index 0000000000000000000000000000000000000000..8ad795ac97b5499c7b10361760f7ac16494c154b --- /dev/null +++ b/paddle/fluid/distributed/table/graph/graph_node.h @@ -0,0 +1,127 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include +#include +#include +#include +#include "paddle/fluid/distributed/table/graph/graph_weighted_sampler.h" +namespace paddle { +namespace distributed { + +class Node { + public: + Node() {} + Node(uint64_t id) : id(id) {} + virtual ~Node() {} + static int id_size, int_size, weight_size; + uint64_t get_id() { return id; } + void set_id(uint64_t id) { this->id = id; } + + virtual void build_edges(bool is_weighted) {} + virtual void build_sampler(std::string sample_type) {} + virtual void add_edge(uint64_t id, float weight) {} + virtual std::vector sample_k(int k) { return std::vector(); } + virtual uint64_t get_neighbor_id(int idx) { return 0; } + virtual float get_neighbor_weight(int idx) { return 1.; } + + virtual int get_size(bool need_feature); + virtual void to_buffer(char *buffer, bool need_feature); + virtual void recover_from_buffer(char *buffer); + virtual std::string get_feature(int idx) { return std::string(""); } + virtual void set_feature(int idx, std::string str) {} + virtual void set_feature_size(int size) {} + virtual int get_feature_size() { return 0; } + + protected: + uint64_t id; +}; + +class GraphNode : public Node { + public: + GraphNode() : Node(), sampler(nullptr), edges(nullptr) {} + GraphNode(uint64_t id) : Node(id), sampler(nullptr), edges(nullptr) {} + virtual ~GraphNode(); + virtual void build_edges(bool is_weighted); + virtual void build_sampler(std::string sample_type); + virtual void add_edge(uint64_t id, float weight) { + edges->add_edge(id, weight); + } + virtual std::vector sample_k(int k) { return sampler->sample_k(k); } + virtual uint64_t get_neighbor_id(int idx) { return edges->get_id(idx); } + virtual float get_neighbor_weight(int idx) { return edges->get_weight(idx); } + + protected: + Sampler *sampler; + GraphEdgeBlob *edges; +}; + +class FeatureNode : public Node { + public: + FeatureNode() : Node() {} + FeatureNode(uint64_t id) : Node(id) {} + virtual ~FeatureNode() {} + virtual int get_size(bool need_feature); + virtual void to_buffer(char *buffer, bool need_feature); + virtual void recover_from_buffer(char *buffer); + virtual std::string get_feature(int idx) { + if (idx < (int)this->feature.size()) { + return this->feature[idx]; + } else { + return std::string(""); + } + } + + virtual void set_feature(int idx, std::string str) { + if (idx >= (int)this->feature.size()) { + this->feature.resize(idx + 1); + } + this->feature[idx] = str; + } + virtual void set_feature_size(int size) { this->feature.resize(size); } + virtual int get_feature_size() { return this->feature.size(); } + + template + static std::string parse_value_to_bytes(std::vector feat_str) { + T v; + size_t Tsize = sizeof(T) * feat_str.size(); + char buffer[Tsize]; + for (size_t i = 0; i < feat_str.size(); i++) { + std::stringstream ss(feat_str[i]); + ss >> v; + std::memcpy(buffer + sizeof(T) * i, (char *)&v, sizeof(T)); + } + return std::string(buffer, Tsize); + } + + template + static std::vector parse_bytes_to_array(std::string feat_str) { + T v; + std::vector out; + size_t start = 0; + const char *buffer = feat_str.data(); + while (start < feat_str.size()) { + std::memcpy((char *)&v, buffer + start, sizeof(T)); + start += sizeof(T); + out.push_back(v); + } + return out; + } + + protected: + std::vector feature; +}; +} +} diff --git a/paddle/fluid/distributed/table/graph/graph_weighted_sampler.cc b/paddle/fluid/distributed/table/graph/graph_weighted_sampler.cc new file mode 100644 index 0000000000000000000000000000000000000000..3a680875e3df4a9cd60f8fe1921b877dbb23c8a2 --- /dev/null +++ b/paddle/fluid/distributed/table/graph/graph_weighted_sampler.cc @@ -0,0 +1,150 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/distributed/table/graph/graph_weighted_sampler.h" +#include +#include +namespace paddle { +namespace distributed { + +void RandomSampler::build(GraphEdgeBlob *edges) { this->edges = edges; } + +std::vector RandomSampler::sample_k(int k) { + int n = edges->size(); + if (k > n) { + k = n; + } + struct timespec tn; + clock_gettime(CLOCK_REALTIME, &tn); + srand(tn.tv_nsec); + std::vector sample_result; + std::unordered_map replace_map; + while (k--) { + int rand_int = rand() % n; + auto iter = replace_map.find(rand_int); + if (iter == replace_map.end()) { + sample_result.push_back(rand_int); + } else { + sample_result.push_back(iter->second); + } + + iter = replace_map.find(n - 1); + if (iter == replace_map.end()) { + replace_map[rand_int] = n - 1; + } else { + replace_map[rand_int] = iter->second; + } + --n; + } + return sample_result; +} + +WeightedSampler::WeightedSampler() { + left = nullptr; + right = nullptr; + edges = nullptr; +} + +WeightedSampler::~WeightedSampler() { + if (left != nullptr) { + delete left; + left = nullptr; + } + if (right != nullptr) { + delete right; + right = nullptr; + } +} + +void WeightedSampler::build(GraphEdgeBlob *edges) { + if (left != nullptr) { + delete left; + left = nullptr; + } + if (right != nullptr) { + delete right; + right = nullptr; + } + return build_one((WeightedGraphEdgeBlob *)edges, 0, edges->size()); +} + +void WeightedSampler::build_one(WeightedGraphEdgeBlob *edges, int start, + int end) { + count = 0; + this->edges = edges; + if (start + 1 == end) { + left = right = nullptr; + idx = start; + count = 1; + weight = edges->get_weight(idx); + + } else { + left = new WeightedSampler(); + right = new WeightedSampler(); + left->build_one(edges, start, start + (end - start) / 2); + right->build_one(edges, start + (end - start) / 2, end); + weight = left->weight + right->weight; + count = left->count + right->count; + } +} +std::vector WeightedSampler::sample_k(int k) { + if (k > count) { + k = count; + } + std::vector sample_result; + float subtract; + std::unordered_map subtract_weight_map; + std::unordered_map subtract_count_map; + struct timespec tn; + clock_gettime(CLOCK_REALTIME, &tn); + srand(tn.tv_nsec); + while (k--) { + float query_weight = rand() % 100000 / 100000.0; + query_weight *= weight - subtract_weight_map[this]; + sample_result.push_back(sample(query_weight, subtract_weight_map, + subtract_count_map, subtract)); + } + return sample_result; +} + +int WeightedSampler::sample( + float query_weight, + std::unordered_map &subtract_weight_map, + std::unordered_map &subtract_count_map, + float &subtract) { + if (left == nullptr) { + subtract_weight_map[this] = weight; + subtract = weight; + subtract_count_map[this] = 1; + return idx; + } + int left_count = left->count - subtract_count_map[left]; + int right_count = right->count - subtract_count_map[right]; + float left_subtract = subtract_weight_map[left]; + int return_idx; + if (right_count == 0 || + left_count > 0 && left->weight - left_subtract >= query_weight) { + return_idx = left->sample(query_weight, subtract_weight_map, + subtract_count_map, subtract); + } else { + return_idx = + right->sample(query_weight - (left->weight - left_subtract), + subtract_weight_map, subtract_count_map, subtract); + } + subtract_weight_map[this] += subtract; + subtract_count_map[this]++; + return return_idx; +} +} +} diff --git a/paddle/fluid/distributed/table/graph/graph_weighted_sampler.h b/paddle/fluid/distributed/table/graph/graph_weighted_sampler.h new file mode 100644 index 0000000000000000000000000000000000000000..1787ab23b04316de9ad0622ff5524bc88bd51fe1 --- /dev/null +++ b/paddle/fluid/distributed/table/graph/graph_weighted_sampler.h @@ -0,0 +1,58 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include +#include +#include +#include "paddle/fluid/distributed/table/graph/graph_edge.h" +namespace paddle { +namespace distributed { + +class Sampler { + public: + virtual ~Sampler() {} + virtual void build(GraphEdgeBlob *edges) = 0; + virtual std::vector sample_k(int k) = 0; +}; + +class RandomSampler : public Sampler { + public: + virtual ~RandomSampler() {} + virtual void build(GraphEdgeBlob *edges); + virtual std::vector sample_k(int k); + GraphEdgeBlob *edges; +}; + +class WeightedSampler : public Sampler { + public: + WeightedSampler(); + virtual ~WeightedSampler(); + WeightedSampler *left, *right; + float weight; + int count; + int idx; + GraphEdgeBlob *edges; + virtual void build(GraphEdgeBlob *edges); + virtual void build_one(WeightedGraphEdgeBlob *edges, int start, int end); + virtual std::vector sample_k(int k); + + private: + int sample(float query_weight, + std::unordered_map &subtract_weight_map, + std::unordered_map &subtract_count_map, + float &subtract); +}; +} +} diff --git a/paddle/fluid/distributed/table/table.h b/paddle/fluid/distributed/table/table.h index 8f014ac98ba4bbcfe0b90d774733406b68d94ee6..5bc818ff4741fd3ad6ee181fd464f71103be2600 100644 --- a/paddle/fluid/distributed/table/table.h +++ b/paddle/fluid/distributed/table/table.h @@ -22,7 +22,7 @@ #include #include "paddle/fluid/distributed/table/accessor.h" #include "paddle/fluid/distributed/table/depends/sparse_utils.h" -#include "paddle/fluid/distributed/table/graph_node.h" +#include "paddle/fluid/distributed/table/graph/graph_node.h" #include "paddle/fluid/framework/program_desc.h" #include "paddle/fluid/framework/scope.h" #include "paddle/fluid/platform/device_context.h" @@ -88,31 +88,6 @@ class Table { return 0; } - // only for graph table - virtual int32_t pull_graph_list(int start, int total_size, - std::unique_ptr &buffer, - int &actual_size, bool need_feature, - int step = 1) { - return 0; - } - // only for graph table - virtual int32_t random_sample_neighboors( - uint64_t *node_ids, int sample_size, - std::vector> &buffers, - std::vector &actual_sizes) { - return 0; - } - - virtual int32_t random_sample_nodes(int sample_size, - std::unique_ptr &buffers, - int &actual_sizes) { - return 0; - } - virtual int32_t get_node_feat(const std::vector &node_ids, - const std::vector &feature_names, - std::vector> &res) { - return 0; - } virtual int32_t pour() { return 0; } virtual void clear() = 0; diff --git a/paddle/fluid/distributed/test/graph_node_test.cc b/paddle/fluid/distributed/test/graph_node_test.cc index 79ab27959638445a46342599592575c324dce135..b268bb449e14619048e89c8933dbae7daf66537b 100644 --- a/paddle/fluid/distributed/test/graph_node_test.cc +++ b/paddle/fluid/distributed/test/graph_node_test.cc @@ -33,7 +33,7 @@ limitations under the License. */ #include "paddle/fluid/distributed/service/ps_client.h" #include "paddle/fluid/distributed/service/sendrecv.pb.h" #include "paddle/fluid/distributed/service/service.h" -#include "paddle/fluid/distributed/table/graph_node.h" +#include "paddle/fluid/distributed/table/graph/graph_node.h" #include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/program_desc.h" #include "paddle/fluid/framework/scope.h"