未验证 提交 ef78c9c2 编写于 作者: S seemingwang 提交者: GitHub

cherry pick recent updates in graph-engine to release2.3 (#42027)

* gpu_graph engine optimization+ (#41455)

* extract sub-graph

* graph-engine merging

* fix

* fix

* fix heter-ps config

* test performance

* test performance

* test performance

* test

* test

* update bfs

* change cmake

* test

* test gpu speed

* gpu_graph_engine optimization

* add ssd layer to graph_engine

* fix allocation

* fix syntax error

* fix syntax error

* fix pscore class

* fix

* recover test

* recover test

* fix spelling

* recover

* fix

* Cpu gpu graph engine (#41942)

* extract sub-graph

* graph-engine merging

* fix

* fix

* fix heter-ps config

* test performance

* test performance

* test performance

* test

* test

* update bfs

* change cmake

* test

* test gpu speed

* gpu_graph_engine optimization

* add ssd layer to graph_engine

* fix allocation

* fix syntax error

* fix syntax error

* fix pscore class

* fix

* recover test

* recover test

* fix spelling

* recover

* fix

* fix linking problem

* remove comment
上级 80992253
......@@ -28,7 +28,112 @@ namespace paddle {
namespace distributed {
#ifdef PADDLE_WITH_HETERPS
paddle::framework::GpuPsCommGraph GraphTable::make_gpu_ps_graph(
std::vector<int64_t> ids) {
std::vector<std::vector<int64_t>> bags(task_pool_size_);
for (auto x : ids) {
int location = x % shard_num % task_pool_size_;
bags[location].push_back(x);
}
std::vector<std::future<int>> tasks;
std::vector<int64_t> edge_array[task_pool_size_];
std::vector<paddle::framework::GpuPsGraphNode> node_array[task_pool_size_];
for (int i = 0; i < (int)bags.size(); i++) {
if (bags[i].size() > 0) {
tasks.push_back(_shards_task_pool[i]->enqueue([&, i, this]() -> int {
paddle::framework::GpuPsGraphNode x;
for (int j = 0; j < (int)bags[i].size(); j++) {
Node *v = find_node(bags[i][j]);
x.node_id = bags[i][j];
if (v == NULL) {
x.neighbor_size = 0;
x.neighbor_offset = 0;
node_array[i].push_back(x);
} else {
x.neighbor_size = v->get_neighbor_size();
x.neighbor_offset = edge_array[i].size();
node_array[i].push_back(x);
for (int k = 0; k < x.neighbor_size; k++) {
edge_array[i].push_back(v->get_neighbor_id(k));
}
}
}
return 0;
}));
}
}
for (int i = 0; i < (int)tasks.size(); i++) tasks[i].get();
paddle::framework::GpuPsCommGraph res;
int tot_len = 0;
for (int i = 0; i < task_pool_size_; i++) {
tot_len += (int)edge_array[i].size();
}
res.neighbor_size = tot_len;
res.node_size = ids.size();
res.neighbor_list = new int64_t[tot_len];
res.node_list = new paddle::framework::GpuPsGraphNode[ids.size()];
int offset = 0, ind = 0;
for (int i = 0; i < task_pool_size_; i++) {
for (int j = 0; j < (int)node_array[i].size(); j++) {
res.node_list[ind] = node_array[i][j];
res.node_list[ind++].neighbor_offset += offset;
}
for (int j = 0; j < (int)edge_array[i].size(); j++) {
res.neighbor_list[offset + j] = edge_array[i][j];
}
offset += edge_array[i].size();
}
return res;
}
int32_t GraphTable::add_node_to_ssd(int64_t src_id, char *data, int len) {
if (_db != NULL)
_db->put(src_id % shard_num % task_pool_size_, (char *)&src_id,
sizeof(uint64_t), (char *)data, sizeof(int64_t) * len);
return 0;
}
char *GraphTable::random_sample_neighbor_from_ssd(
int64_t id, int sample_size, const std::shared_ptr<std::mt19937_64> rng,
int &actual_size) {
if (_db == NULL) {
actual_size = 0;
return NULL;
}
std::string str;
if (_db->get(id % shard_num % task_pool_size_, (char *)&id, sizeof(uint64_t),
str) == 0) {
int64_t *data = ((int64_t *)str.c_str());
int n = str.size() / sizeof(int64_t);
std::unordered_map<int, int> m;
// std::vector<int64_t> res;
int sm_size = std::min(n, sample_size);
actual_size = sm_size * Node::id_size;
char *buff = new char[actual_size];
for (int i = 0; i < sm_size; i++) {
std::uniform_int_distribution<int> distrib(0, n - i - 1);
int t = distrib(*rng);
// int t = rand() % (n-i);
int pos = 0;
auto iter = m.find(t);
if (iter != m.end()) {
pos = iter->second;
} else {
pos = t;
}
auto iter2 = m.find(n - i - 1);
int key2 = iter2 == m.end() ? n - i - 1 : iter2->second;
m[t] = key2;
m.erase(n - i - 1);
memcpy(buff + i * Node::id_size, &data[pos], Node::id_size);
// res.push_back(data[pos]);
}
return buff;
}
actual_size = 0;
return NULL;
}
#endif
/*
int CompleteGraphSampler::run_graph_sampling() {
pthread_rwlock_t *rw_lock = graph_table->rw_lock.get();
pthread_rwlock_rdlock(rw_lock);
......@@ -136,7 +241,8 @@ int BasicBfsGraphSampler::run_graph_sampling() {
int task_size = 0;
std::vector<std::future<int>> tasks;
int init_size = 0;
std::function<int(int, int64_t)> bfs = [&, this](int i, int64_t id) -> int {
//__sync_fetch_and_add
std::function<int(int, int64_t)> bfs = [&, this](int i, int id) -> int {
if (this->status == GraphSamplerStatus::terminating) {
int task_left = __sync_sub_and_fetch(&task_size, 1);
if (task_left == 0) {
......@@ -289,6 +395,7 @@ int BasicBfsGraphSampler::run_graph_sampling() {
std::this_thread::sleep_for(std::chrono::seconds(1));
}
}
VLOG(0)<<"bfs returning";
}
return 0;
}
......@@ -304,7 +411,7 @@ void BasicBfsGraphSampler::init(size_t gpu_num, GraphTable *graph_table,
}
#endif
*/
std::vector<Node *> GraphShard::get_batch(int start, int end, int step) {
if (start < 0) start = 0;
std::vector<Node *> res;
......@@ -316,6 +423,18 @@ std::vector<Node *> GraphShard::get_batch(int start, int end, int step) {
size_t GraphShard::get_size() { return bucket.size(); }
int32_t GraphTable::add_comm_edge(int64_t src_id, int64_t dst_id) {
size_t src_shard_id = src_id % shard_num;
if (src_shard_id >= shard_end || src_shard_id < shard_start) {
return -1;
}
size_t index = src_shard_id - shard_start;
VLOG(0) << "index add edge " << src_id << " " << dst_id;
shards[index]->add_graph_node(src_id)->build_edges(false);
shards[index]->add_neighbor(src_id, dst_id, 1.0);
return 0;
}
int32_t GraphTable::add_graph_node(std::vector<int64_t> &id_list,
std::vector<bool> &is_weight_list) {
size_t node_size = id_list.size();
......@@ -554,9 +673,9 @@ int32_t GraphTable::load_nodes(const std::string &path, std::string node_type) {
}
int32_t GraphTable::load_edges(const std::string &path, bool reverse_edge) {
#ifdef PADDLE_WITH_HETERPS
if (gpups_mode) pthread_rwlock_rdlock(rw_lock.get());
#endif
// #ifdef PADDLE_WITH_HETERPS
// if (gpups_mode) pthread_rwlock_rdlock(rw_lock.get());
// #endif
auto paths = paddle::string::split_string<std::string>(path, ";");
int64_t count = 0;
std::string sample_type = "random";
......@@ -633,9 +752,9 @@ int32_t GraphTable::load_edges(const std::string &path, bool reverse_edge) {
relocate the duplicate nodes to make them distributed evenly among threads.
*/
if (!use_duplicate_nodes) {
#ifdef PADDLE_WITH_HETERPS
if (gpups_mode) pthread_rwlock_unlock(rw_lock.get());
#endif
// #ifdef PADDLE_WITH_HETERPS
// if (gpups_mode) pthread_rwlock_unlock(rw_lock.get());
// #endif
return 0;
}
......@@ -712,9 +831,9 @@ int32_t GraphTable::load_edges(const std::string &path, bool reverse_edge) {
delete extra_shards[i];
extra_shards[i] = extra_shards_copy[i];
}
#ifdef PADDLE_WITH_HETERPS
if (gpups_mode) pthread_rwlock_unlock(rw_lock.get());
#endif
// #ifdef PADDLE_WITH_HETERPS
// if (gpups_mode) pthread_rwlock_unlock(rw_lock.get());
// #endif
return 0;
}
......@@ -878,6 +997,17 @@ int32_t GraphTable::random_sample_neighbors(
idx = seq_id[i][k];
int &actual_size = actual_sizes[idx];
if (node == nullptr) {
#ifdef PADDLE_WITH_HETERPS
if (search_level == 2) {
char *buffer_addr = random_sample_neighbor_from_ssd(
node_id, sample_size, rng, actual_size);
if (actual_size != 0) {
std::shared_ptr<char> &buffer = buffers[idx];
buffer.reset(buffer_addr, char_del);
}
continue;
}
#endif
actual_size = 0;
continue;
}
......@@ -1085,25 +1215,29 @@ int32_t GraphTable::Initialize(const TableParameter &config,
return Initialize(graph);
}
int32_t GraphTable::Initialize(const GraphParameter &graph) {
task_pool_size_ = graph.task_pool_size();
#ifdef PADDLE_WITH_HETERPS
if (graph.gpups_mode()) {
gpups_mode = true;
auto *sampler =
CREATE_PSCORE_CLASS(GraphSampler, graph.gpups_graph_sample_class());
auto slices =
string::split_string<std::string>(graph.gpups_graph_sample_args(), ",");
std::cout << "slices" << std::endl;
for (auto x : slices) std::cout << x << std::endl;
sampler->init(graph.gpu_num(), this, slices);
graph_sampler.reset(sampler);
}
_db = NULL;
search_level = graph.search_level();
if (search_level >= 2) {
_db = paddle::distributed::RocksDBHandler::GetInstance();
_db->initialize("./temp_gpups_db", task_pool_size_);
}
// gpups_mode = true;
// auto *sampler =
// CREATE_PSCORE_CLASS(GraphSampler, graph.gpups_graph_sample_class());
// auto slices =
// string::split_string<std::string>(graph.gpups_graph_sample_args(), ",");
// std::cout << "slices" << std::endl;
// for (auto x : slices) std::cout << x << std::endl;
// sampler->init(graph.gpu_num(), this, slices);
// graph_sampler.reset(sampler);
#endif
if (shard_num == 0) {
server_num = 1;
_shard_idx = 0;
shard_num = graph.shard_num();
}
task_pool_size_ = graph.task_pool_size();
use_cache = graph.use_cache();
if (use_cache) {
cache_size_limit = graph.cache_size_limit();
......
......@@ -38,6 +38,7 @@
#include <vector>
#include "paddle/fluid/distributed/ps/table/accessor.h"
#include "paddle/fluid/distributed/ps/table/common_table.h"
#include "paddle/fluid/distributed/ps/table/depends/rocksdb_warpper.h"
#include "paddle/fluid/distributed/ps/table/graph/class_macro.h"
#include "paddle/fluid/distributed/ps/table/graph/graph_node.h"
#include "paddle/fluid/string/string_helper.h"
......@@ -351,6 +352,7 @@ class ScaledLRU {
friend class RandomSampleLRU<K, V>;
};
/*
#ifdef PADDLE_WITH_HETERPS
enum GraphSamplerStatus { waiting = 0, running = 1, terminating = 2 };
class GraphTable;
......@@ -363,6 +365,9 @@ class GraphSampler {
return;
};
}
virtual int loadData(const std::string &path){
return 0;
}
virtual int run_graph_sampling() = 0;
virtual int start_graph_sampling() {
if (status != GraphSamplerStatus::waiting) {
......@@ -403,15 +408,13 @@ class GraphSampler {
std::vector<paddle::framework::GpuPsCommGraph> sample_res;
};
#endif
*/
class GraphTable : public Table {
public:
GraphTable() {
use_cache = false;
shard_num = 0;
#ifdef PADDLE_WITH_HETERPS
gpups_mode = false;
#endif
rw_lock.reset(new pthread_rwlock_t());
}
virtual ~GraphTable();
......@@ -516,21 +519,28 @@ class GraphTable : public Table {
return 0;
}
#ifdef PADDLE_WITH_HETERPS
virtual int32_t start_graph_sampling() {
return this->graph_sampler->start_graph_sampling();
}
virtual int32_t end_graph_sampling() {
return this->graph_sampler->end_graph_sampling();
}
virtual int32_t set_graph_sample_callback(
std::function<void(std::vector<paddle::framework::GpuPsCommGraph> &)>
callback) {
graph_sampler->set_graph_sample_callback(callback);
return 0;
}
// virtual GraphSampler *get_graph_sampler() { return graph_sampler.get(); }
// virtual int32_t start_graph_sampling() {
// return this->graph_sampler->start_graph_sampling();
// }
// virtual int32_t end_graph_sampling() {
// return this->graph_sampler->end_graph_sampling();
// }
// virtual int32_t set_graph_sample_callback(
// std::function<void(std::vector<paddle::framework::GpuPsCommGraph> &)>
// callback) {
// graph_sampler->set_graph_sample_callback(callback);
// return 0;
// }
virtual char *random_sample_neighbor_from_ssd(
int64_t id, int sample_size, const std::shared_ptr<std::mt19937_64> rng,
int &actual_size);
virtual int32_t add_node_to_ssd(int64_t id, char *data, int len);
virtual paddle::framework::GpuPsCommGraph make_gpu_ps_graph(
std::vector<int64_t> ids);
// virtual GraphSampler *get_graph_sampler() { return graph_sampler.get(); }
int search_level;
#endif
protected:
virtual int32_t add_comm_edge(int64_t src_id, int64_t dst_id);
std::vector<GraphShard *> shards, extra_shards;
size_t shard_start, shard_end, server_num, shard_num_per_server, shard_num;
int task_pool_size_ = 24;
......@@ -555,13 +565,14 @@ class GraphTable : public Table {
std::shared_ptr<pthread_rwlock_t> rw_lock;
#ifdef PADDLE_WITH_HETERPS
// paddle::framework::GpuPsGraphTable gpu_graph_table;
bool gpups_mode;
// std::shared_ptr<::ThreadPool> graph_sample_pool;
std::shared_ptr<GraphSampler> graph_sampler;
REGISTER_GRAPH_FRIEND_CLASS(2, CompleteGraphSampler, BasicBfsGraphSampler)
paddle::distributed::RocksDBHandler *_db;
// std::shared_ptr<::ThreadPool> graph_sample_pool;
// std::shared_ptr<GraphSampler> graph_sampler;
// REGISTER_GRAPH_FRIEND_CLASS(2, CompleteGraphSampler, BasicBfsGraphSampler)
#endif
};
/*
#ifdef PADDLE_WITH_HETERPS
REGISTER_PSCORE_REGISTERER(GraphSampler);
class CompleteGraphSampler : public GraphSampler {
......@@ -603,6 +614,7 @@ class BasicBfsGraphSampler : public GraphSampler {
sample_neighbors_map;
};
#endif
*/
} // namespace distributed
}; // namespace paddle
......
......@@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#ifdef PADDLE_WITH_HETERPS
#include <glog/logging.h>
#include <rocksdb/db.h>
......
......@@ -31,10 +31,6 @@ namespace paddle {
namespace distributed {
REGISTER_PSCORE_CLASS(Table, GraphTable);
REGISTER_PSCORE_CLASS(Table, MemoryDenseTable);
#ifdef PADDLE_WITH_HETERPS
REGISTER_PSCORE_CLASS(GraphSampler, CompleteGraphSampler);
REGISTER_PSCORE_CLASS(GraphSampler, BasicBfsGraphSampler);
#endif
REGISTER_PSCORE_CLASS(Table, BarrierTable);
REGISTER_PSCORE_CLASS(Table, TensorTable);
REGISTER_PSCORE_CLASS(Table, DenseTensorTable);
......
......@@ -25,7 +25,7 @@ set_source_files_properties(graph_node_split_test.cc PROPERTIES COMPILE_FLAGS ${
cc_test(graph_node_split_test SRCS graph_node_split_test.cc DEPS graph_py_service scope server client communicator ps_service boost table ps_framework_proto ${COMMON_DEPS})
set_source_files_properties(graph_table_sample_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
cc_test(graph_table_sample_test SRCS graph_table_sample_test.cc DEPS scope server communicator ps_service boost table ps_framework_proto ${COMMON_DEPS})
cc_test(graph_table_sample_test SRCS graph_table_sample_test.cc DEPS table ps_framework_proto ${COMMON_DEPS})
set_source_files_properties(feature_value_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
cc_test(feature_value_test SRCS feature_value_test.cc DEPS ${COMMON_DEPS} boost table)
......
......@@ -679,7 +679,7 @@ void testCache() {
st.query(0, &skey, 1, r);
ASSERT_EQ((int)r.size(), 1);
char* p = (char*)r[0].second.buffer.get();
for (size_t j = 0; j < r[0].second.actual_size; j++)
for (int j = 0; j < (int)r[0].second.actual_size; j++)
ASSERT_EQ(p[j], str[j]);
r.clear();
}
......
......@@ -25,18 +25,7 @@
#include <chrono>
#include "gtest/gtest.h"
#include "paddle/fluid/distributed/ps.pb.h"
#include "paddle/fluid/distributed/ps/service/env.h"
#include "paddle/fluid/distributed/ps/service/sendrecv.pb.h"
#include "paddle/fluid/distributed/ps/table/common_graph_table.h"
#include "paddle/fluid/distributed/ps/table/graph/graph_node.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/framework/variable.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/fluid/string/printf.h"
#include "paddle/phi/kernels/funcs/math_function.h"
namespace framework = paddle::framework;
namespace platform = paddle::platform;
namespace operators = paddle::operators;
......@@ -83,66 +72,11 @@ void prepare_file(char file_name[], std::vector<std::string> data) {
}
void testGraphSample() {
#ifdef PADDLE_WITH_HETERPS
::paddle::distributed::GraphParameter table_proto;
table_proto.set_gpups_mode(true);
table_proto.set_shard_num(127);
table_proto.set_gpu_num(2);
// table_proto.set_gpu_num(2);
distributed::GraphTable graph_table, graph_table1;
graph_table.initialize(table_proto);
prepare_file(edge_file_name, edges);
graph_table.load(std::string(edge_file_name), std::string("e>"));
std::vector<paddle::framework::GpuPsCommGraph> res;
std::promise<int> prom;
std::future<int> fut = prom.get_future();
graph_table.set_graph_sample_callback(
[&res, &prom](std::vector<paddle::framework::GpuPsCommGraph> &res0) {
res = res0;
prom.set_value(0);
});
graph_table.start_graph_sampling();
fut.get();
graph_table.end_graph_sampling();
ASSERT_EQ(2, res.size());
// 37 59 97
for (int i = 0; i < (int)res[1].node_size; i++) {
std::cout << res[1].node_list[i].node_id << std::endl;
}
ASSERT_EQ(3, res[1].node_size);
::paddle::distributed::GraphParameter table_proto1;
table_proto1.set_gpups_mode(true);
table_proto1.set_shard_num(127);
table_proto1.set_gpu_num(2);
table_proto1.set_gpups_graph_sample_class("BasicBfsGraphSampler");
table_proto1.set_gpups_graph_sample_args("5,5,1,1");
graph_table1.initialize(table_proto1);
graph_table1.load(std::string(edge_file_name), std::string("e>"));
std::vector<paddle::framework::GpuPsCommGraph> res1;
std::promise<int> prom1;
std::future<int> fut1 = prom1.get_future();
graph_table1.set_graph_sample_callback(
[&res1, &prom1](std::vector<paddle::framework::GpuPsCommGraph> &res0) {
res1 = res0;
prom1.set_value(0);
});
graph_table1.start_graph_sampling();
fut1.get();
graph_table1.end_graph_sampling();
// distributed::BasicBfsGraphSampler *sampler1 =
// (distributed::BasicBfsGraphSampler *)graph_table1.get_graph_sampler();
// sampler1->start_graph_sampling();
// std::this_thread::sleep_for (std::chrono::seconds(1));
// std::vector<paddle::framework::GpuPsCommGraph> res1;// =
// sampler1->fetch_sample_res();
ASSERT_EQ(2, res1.size());
// odd id:96 48 122 112
for (int i = 0; i < (int)res1[0].node_size; i++) {
std::cout << res1[0].node_list[i].node_id << std::endl;
}
ASSERT_EQ(4, res1[0].node_size);
#endif
distributed::GraphTable graph_table;
graph_table.Initialize(table_proto);
}
TEST(testGraphSample, Run) { testGraphSample(); }
......@@ -215,18 +215,16 @@ message SparseAdamSGDParameter { // SparseAdamSGDRule
message GraphParameter {
optional int32 task_pool_size = 1 [ default = 24 ];
optional bool gpups_mode = 2 [ default = false ];
optional string gpups_graph_sample_class = 3
optional string gpups_graph_sample_class = 2
[ default = "CompleteGraphSampler" ];
optional string gpups_graph_sample_args = 4 [ default = "" ];
optional bool use_cache = 5 [ default = false ];
optional int32 cache_size_limit = 6 [ default = 100000 ];
optional int32 cache_ttl = 7 [ default = 5 ];
optional GraphFeature graph_feature = 8;
optional string table_name = 9 [ default = "" ];
optional string table_type = 10 [ default = "" ];
optional int32 shard_num = 11 [ default = 127 ];
optional int32 gpu_num = 12 [ default = 1 ];
optional bool use_cache = 3 [ default = false ];
optional int32 cache_size_limit = 4 [ default = 100000 ];
optional int32 cache_ttl = 5 [ default = 5 ];
optional GraphFeature graph_feature = 6;
optional string table_name = 7 [ default = "" ];
optional string table_type = 8 [ default = "" ];
optional int32 shard_num = 9 [ default = 127 ];
optional int32 search_level = 10 [ default = 1 ];
}
message GraphFeature {
......
......@@ -13,13 +13,17 @@ IF(WITH_GPU)
nv_test(test_heter_comm SRCS feature_value.h DEPS heter_comm)
nv_library(heter_ps SRCS heter_ps.cu DEPS heter_comm)
if(WITH_PSCORE)
nv_library(graph_gpu_ps SRCS graph_gpu_ps_table.h DEPS heter_comm table)
nv_test(test_graph_comm SRCS test_graph.cu DEPS graph_gpu_ps)
nv_test(test_cpu_graph_sample SRCS test_cpu_graph_sample.cu DEPS graph_gpu_ps)
#nv_test(test_sample_rate SRCS test_sample_rate.cu DEPS graph_gpu_ps)
# ADD_EXECUTABLE(test_sample_rate test_sample_rate.cu)
# target_link_libraries(test_sample_rate graph_gpu_ps)
nv_library(graph_gpu_ps SRCS graph_gpu_ps_table.h DEPS heter_comm table hashtable_kernel)
nv_library(graph_sampler SRCS graph_sampler_inl.h DEPS graph_gpu_ps)
nv_test(test_cpu_query SRCS test_cpu_query.cu DEPS heter_comm table heter_comm_kernel hashtable_kernel heter_ps ${HETERPS_DEPS})
#ADD_EXECUTABLE(test_sample_rate test_sample_rate.cu)
#target_link_libraries(test_sample_rate heter_comm table heter_comm_kernel hashtable_kernel heter_ps ${HETERPS_DEPS})
#nv_test(test_sample_rate SRCS test_sample_rate.cu DEPS heter_comm table heter_comm_kernel hashtable_kernel heter_ps ${HETERPS_DEPS})
#ADD_EXECUTABLE(test_cpu_query test_cpu_query.cu)
#target_link_libraries(test_cpu_query graph_gpu_ps)
endif()
ENDIF()
IF(WITH_XPU_KP)
SET(HETERPS_DEPS device_context)
......
......@@ -14,6 +14,12 @@
#pragma once
#ifdef PADDLE_WITH_HETERPS
#include <iostream>
#include <memory>
#include <string>
#include "paddle/fluid/memory/allocation/allocator.h"
#include "paddle/fluid/memory/memory.h"
#include "paddle/fluid/platform/cuda_device_guard.h"
namespace paddle {
namespace framework {
struct GpuPsGraphNode {
......@@ -36,6 +42,24 @@ struct GpuPsCommGraph {
node_list(node_list_),
neighbor_size(neighbor_size_),
node_size(node_size_) {}
void display_on_cpu() {
VLOG(0) << "neighbor_size = " << neighbor_size;
VLOG(0) << "node_size = " << node_size;
for (int i = 0; i < neighbor_size; i++) {
VLOG(0) << "neighbor " << i << " " << neighbor_list[i];
}
for (int i = 0; i < node_size; i++) {
VLOG(0) << "node i " << node_list[i].node_id
<< " neighbor_size = " << node_list[i].neighbor_size;
std::string str;
int offset = node_list[i].neighbor_offset;
for (int j = 0; j < node_list[i].neighbor_size; j++) {
if (j > 0) str += ",";
str += std::to_string(neighbor_list[j + offset]);
}
VLOG(0) << str;
}
}
};
/*
......@@ -94,16 +118,24 @@ struct NeighborSampleResult {
int64_t *val;
int *actual_sample_size, sample_size, key_size;
int *offset;
NeighborSampleResult(int _sample_size, int _key_size)
std::shared_ptr<memory::Allocation> val_mem, actual_sample_size_mem;
NeighborSampleResult(int _sample_size, int _key_size, int dev_id)
: sample_size(_sample_size), key_size(_key_size) {
actual_sample_size = NULL;
val = NULL;
platform::CUDADeviceGuard guard(dev_id);
platform::CUDAPlace place = platform::CUDAPlace(dev_id);
val_mem =
memory::AllocShared(place, _sample_size * _key_size * sizeof(int64_t));
val = (int64_t *)val_mem->ptr();
actual_sample_size_mem =
memory::AllocShared(place, _key_size * sizeof(int));
actual_sample_size = (int *)actual_sample_size_mem->ptr();
offset = NULL;
};
~NeighborSampleResult() {
if (val != NULL) cudaFree(val);
if (actual_sample_size != NULL) cudaFree(actual_sample_size);
if (offset != NULL) cudaFree(offset);
// if (val != NULL) cudaFree(val);
// if (actual_sample_size != NULL) cudaFree(actual_sample_size);
// if (offset != NULL) cudaFree(offset);
}
};
......
......@@ -14,26 +14,73 @@
#pragma once
#include <thrust/host_vector.h>
#include <chrono>
#include "heter_comm.h"
#include "paddle/fluid/distributed/ps/table/common_graph_table.h"
#include "paddle/fluid/framework/fleet/heter_ps/gpu_graph_node.h"
#include "paddle/fluid/framework/fleet/heter_ps/heter_comm_kernel.h"
#include "paddle/fluid/platform/enforce.h"
#ifdef PADDLE_WITH_HETERPS
namespace paddle {
namespace framework {
class GpuPsGraphTable : public HeterComm<int64_t, int, int> {
public:
GpuPsGraphTable(std::shared_ptr<HeterPsResource> resource)
GpuPsGraphTable(std::shared_ptr<HeterPsResource> resource, int topo_aware)
: HeterComm<int64_t, int, int>(1, resource) {
load_factor_ = 0.25;
rw_lock.reset(new pthread_rwlock_t());
gpu_num = resource_->total_device();
cpu_table_status = -1;
if (topo_aware) {
int total_gpu = resource_->total_device();
std::map<int, int> device_map;
for (int i = 0; i < total_gpu; i++) {
device_map[resource_->dev_id(i)] = i;
VLOG(1) << " device " << resource_->dev_id(i) << " is stored on " << i;
}
path_.clear();
path_.resize(total_gpu);
VLOG(1) << "topo aware overide";
for (int i = 0; i < total_gpu; ++i) {
path_[i].resize(total_gpu);
for (int j = 0; j < total_gpu; ++j) {
auto &nodes = path_[i][j].nodes_;
nodes.clear();
int from = resource_->dev_id(i);
int to = resource_->dev_id(j);
int transfer_id = i;
if (need_transfer(from, to) &&
(device_map.find((from + 4) % 8) != device_map.end() ||
device_map.find((to + 4) % 8) != device_map.end())) {
transfer_id = (device_map.find((from + 4) % 8) != device_map.end())
? ((from + 4) % 8)
: ((to + 4) % 8);
transfer_id = device_map[transfer_id];
nodes.push_back(Node());
Node &node = nodes.back();
node.in_stream = resource_->comm_stream(i, transfer_id);
node.out_stream = resource_->comm_stream(transfer_id, i);
node.key_storage = NULL;
node.val_storage = NULL;
node.sync = 0;
node.dev_num = transfer_id;
}
nodes.push_back(Node());
Node &node = nodes.back();
node.in_stream = resource_->comm_stream(i, transfer_id);
node.out_stream = resource_->comm_stream(transfer_id, i);
node.key_storage = NULL;
node.val_storage = NULL;
node.sync = 0;
node.dev_num = j;
}
}
}
}
~GpuPsGraphTable() {
if (cpu_table_status != -1) {
end_graph_sampling();
}
// if (cpu_table_status != -1) {
// end_graph_sampling();
// }
}
void build_graph_from_cpu(std::vector<GpuPsCommGraph> &cpu_node_list);
NodeQueryResult *graph_node_sample(int gpu_id, int sample_size);
......@@ -41,21 +88,28 @@ class GpuPsGraphTable : public HeterComm<int64_t, int, int> {
int sample_size, int len);
NodeQueryResult *query_node_list(int gpu_id, int start, int query_size);
void clear_graph_info();
void move_neighbor_sample_result_to_source_gpu(
int gpu_id, int gpu_num, int *h_left, int *h_right,
int64_t *src_sample_res, thrust::host_vector<int> &total_sample_size);
void move_neighbor_sample_size_to_source_gpu(int gpu_id, int gpu_num,
int *h_left, int *h_right,
int *actual_sample_size,
int *total_sample_size);
void move_neighbor_sample_result_to_source_gpu(int gpu_id, int gpu_num,
int sample_size, int *h_left,
int *h_right,
int64_t *src_sample_res,
int *actual_sample_size);
// void move_neighbor_sample_result_to_source_gpu(
// int gpu_id, int gpu_num, int *h_left, int *h_right,
// int64_t *src_sample_res, thrust::host_vector<int> &total_sample_size);
// void move_neighbor_sample_size_to_source_gpu(int gpu_id, int gpu_num,
// int *h_left, int *h_right,
// int *actual_sample_size,
// int *total_sample_size);
int init_cpu_table(const paddle::distributed::GraphParameter &graph);
int load(const std::string &path, const std::string &param);
virtual int32_t end_graph_sampling() {
return cpu_graph_table->end_graph_sampling();
}
private:
// int load(const std::string &path, const std::string &param);
// virtual int32_t end_graph_sampling() {
// return cpu_graph_table->end_graph_sampling();
// }
int gpu_num;
std::vector<GpuPsCommGraph> gpu_graph_list;
std::vector<int *> sample_status;
const int parallel_sample_size = 1;
const int dim_y = 256;
std::shared_ptr<paddle::distributed::GraphTable> cpu_graph_table;
std::shared_ptr<pthread_rwlock_t> rw_lock;
mutable std::mutex mutex_;
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <time.h>
#include <algorithm>
#include <chrono>
#include <cstdlib>
#include <fstream>
#include <set>
#include <sstream>
#include <string>
#include <unordered_map>
#include <vector>
#include "paddle/fluid/distributed/ps/table/common_graph_table.h"
#include "paddle/fluid/framework/fleet/heter_ps/gpu_graph_node.h"
#include "paddle/fluid/framework/fleet/heter_ps/graph_gpu_ps_table.h"
#include "paddle/fluid/string/printf.h"
#include "paddle/fluid/string/string_helper.h"
#ifdef PADDLE_WITH_HETERPS
namespace paddle {
namespace framework {
enum GraphSamplerStatus { waiting = 0, running = 1, terminating = 2 };
class GraphSampler {
public:
GraphSampler() {
status = GraphSamplerStatus::waiting;
thread_pool.reset(new ::ThreadPool(1));
}
virtual int start_service(std::string path) {
load_from_ssd(path);
VLOG(0) << "load from ssd over";
std::promise<int> prom;
std::future<int> fut = prom.get_future();
graph_sample_task_over = thread_pool->enqueue([&prom, this]() {
VLOG(0) << " promise set ";
prom.set_value(0);
status = GraphSamplerStatus::running;
return run_graph_sampling();
});
return fut.get();
return 0;
}
virtual int end_graph_sampling() {
if (status == GraphSamplerStatus::running) {
status = GraphSamplerStatus::terminating;
return graph_sample_task_over.get();
}
return -1;
}
~GraphSampler() { end_graph_sampling(); }
virtual int load_from_ssd(std::string path) = 0;
;
virtual int run_graph_sampling() = 0;
;
virtual void init(GpuPsGraphTable *gpu_table,
std::vector<std::string> args_) = 0;
std::shared_ptr<::ThreadPool> thread_pool;
GraphSamplerStatus status;
std::future<int> graph_sample_task_over;
};
class CommonGraphSampler : public GraphSampler {
public:
CommonGraphSampler() {}
virtual ~CommonGraphSampler() {}
GpuPsGraphTable *g_table;
virtual int load_from_ssd(std::string path);
virtual int run_graph_sampling();
virtual void init(GpuPsGraphTable *g, std::vector<std::string> args);
GpuPsGraphTable *gpu_table;
paddle::distributed::GraphTable *table;
std::vector<int64_t> gpu_edges_count;
int64_t cpu_edges_count;
int64_t gpu_edges_limit, cpu_edges_limit, gpu_edges_each_limit;
std::vector<std::unordered_set<int64_t>> gpu_set;
int gpu_num;
};
class AllInGpuGraphSampler : public GraphSampler {
public:
AllInGpuGraphSampler() {}
virtual ~AllInGpuGraphSampler() {}
// virtual pthread_rwlock_t *export_rw_lock();
virtual int run_graph_sampling();
virtual int load_from_ssd(std::string path);
virtual void init(GpuPsGraphTable *g, std::vector<std::string> args_);
protected:
paddle::distributed::GraphTable *graph_table;
GpuPsGraphTable *gpu_table;
std::vector<std::vector<paddle::framework::GpuPsGraphNode>> sample_nodes;
std::vector<std::vector<int64_t>> sample_neighbors;
std::vector<GpuPsCommGraph> sample_res;
// std::shared_ptr<std::mt19937_64> random;
int gpu_num;
};
}
};
#include "paddle/fluid/framework/fleet/heter_ps/graph_sampler_inl.h"
#endif
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifdef PADDLE_WITH_HETERPS
namespace paddle {
namespace framework {
int CommonGraphSampler::load_from_ssd(std::string path) {
std::ifstream file(path);
auto _db = table->_db;
std::string line;
while (std::getline(file, line)) {
auto values = paddle::string::split_string<std::string>(line, "\t");
std::cout << values.size();
if (values.size() < 2) continue;
auto neighbors = paddle::string::split_string<std::string>(values[1], ";");
std::vector<int64_t> neighbor_data;
for (auto x : neighbors) {
neighbor_data.push_back(std::stoll(x));
}
auto src_id = std::stoll(values[0]);
_db->put(0, (char *)&src_id, sizeof(uint64_t), (char *)neighbor_data.data(),
sizeof(int64_t) * neighbor_data.size());
int gpu_shard = src_id % gpu_num;
if (gpu_edges_count[gpu_shard] + neighbor_data.size() <=
gpu_edges_each_limit) {
gpu_edges_count[gpu_shard] += neighbor_data.size();
gpu_set[gpu_shard].insert(src_id);
}
if (cpu_edges_count + neighbor_data.size() <= cpu_edges_limit) {
cpu_edges_count += neighbor_data.size();
for (auto x : neighbor_data) {
// table->add_neighbor(src_id, x);
table->shards[src_id % table->shard_num]
->add_graph_node(src_id)
->build_edges(false);
table->shards[src_id % table->shard_num]->add_neighbor(src_id, x, 1.0);
}
}
std::vector<paddle::framework::GpuPsCommGraph> graph_list;
for (int i = 0; i < gpu_num; i++) {
std::vector<int64_t> ids(gpu_set[i].begin(), gpu_set[i].end());
graph_list.push_back(table->make_gpu_ps_graph(ids));
}
gpu_table->build_graph_from_cpu(graph_list);
for (int i = 0; i < graph_list.size(); i++) {
delete[] graph_list[i].node_list;
delete[] graph_list[i].neighbor_list;
}
}
}
int CommonGraphSampler::run_graph_sampling() { return 0; }
void CommonGraphSampler::init(GpuPsGraphTable *g,
std::vector<std::string> args) {
this->gpu_table = g;
gpu_num = g->gpu_num;
gpu_edges_limit = args.size() > 0 ? std::stoll(args[0]) : 1000000000LL;
cpu_edges_limit = args.size() > 1 ? std::stoll(args[1]) : 1000000000LL;
gpu_edges_each_limit = gpu_edges_limit / gpu_num;
if (gpu_edges_each_limit > INT_MAX) gpu_edges_each_limit = INT_MAX;
table = g->cpu_graph_table.get();
gpu_edges_count = std::vector<int64_t>(gpu_num, 0);
cpu_edges_count = 0;
gpu_set = std::vector<std::unordered_set<int64_t>>(gpu_num);
}
int AllInGpuGraphSampler::run_graph_sampling() { return 0; }
int AllInGpuGraphSampler::load_from_ssd(std::string path) {
graph_table->load_edges(path, false);
sample_nodes.clear();
sample_neighbors.clear();
sample_res.clear();
sample_nodes.resize(gpu_num);
sample_neighbors.resize(gpu_num);
sample_res.resize(gpu_num);
std::vector<std::vector<std::vector<paddle::framework::GpuPsGraphNode>>>
sample_nodes_ex(graph_table->task_pool_size_);
std::vector<std::vector<std::vector<int64_t>>> sample_neighbors_ex(
graph_table->task_pool_size_);
for (int i = 0; i < graph_table->task_pool_size_; i++) {
sample_nodes_ex[i].resize(gpu_num);
sample_neighbors_ex[i].resize(gpu_num);
}
std::vector<std::future<int>> tasks;
for (size_t i = 0; i < graph_table->shards.size(); ++i) {
tasks.push_back(
graph_table->_shards_task_pool[i % graph_table->task_pool_size_]
->enqueue([&, i, this]() -> int {
if (this->status == GraphSamplerStatus::terminating) return 0;
paddle::framework::GpuPsGraphNode node;
std::vector<paddle::distributed::Node *> &v =
this->graph_table->shards[i]->get_bucket();
size_t ind = i % this->graph_table->task_pool_size_;
for (size_t j = 0; j < v.size(); j++) {
size_t location = v[j]->get_id() % this->gpu_num;
node.node_id = v[j]->get_id();
node.neighbor_size = v[j]->get_neighbor_size();
node.neighbor_offset =
(int)sample_neighbors_ex[ind][location].size();
sample_nodes_ex[ind][location].emplace_back(node);
for (int k = 0; k < node.neighbor_size; k++)
sample_neighbors_ex[ind][location].push_back(
v[j]->get_neighbor_id(k));
}
return 0;
}));
}
for (size_t i = 0; i < tasks.size(); i++) tasks[i].get();
tasks.clear();
for (size_t i = 0; i < gpu_num; i++) {
tasks.push_back(
graph_table->_shards_task_pool[i % graph_table->task_pool_size_]
->enqueue([&, i, this]() -> int {
if (this->status == GraphSamplerStatus::terminating) return 0;
int total_offset = 0;
size_t ind = i;
for (int j = 0; j < this->graph_table->task_pool_size_; j++) {
for (size_t k = 0; k < sample_nodes_ex[j][ind].size(); k++) {
sample_nodes[ind].push_back(sample_nodes_ex[j][ind][k]);
sample_nodes[ind].back().neighbor_offset += total_offset;
}
size_t neighbor_size = sample_neighbors_ex[j][ind].size();
total_offset += neighbor_size;
for (size_t k = 0; k < neighbor_size; k++) {
sample_neighbors[ind].push_back(
sample_neighbors_ex[j][ind][k]);
}
}
return 0;
}));
}
for (size_t i = 0; i < tasks.size(); i++) tasks[i].get();
for (size_t i = 0; i < gpu_num; i++) {
sample_res[i].node_list = sample_nodes[i].data();
sample_res[i].neighbor_list = sample_neighbors[i].data();
sample_res[i].node_size = sample_nodes[i].size();
sample_res[i].neighbor_size = sample_neighbors[i].size();
}
gpu_table->build_graph_from_cpu(sample_res);
return 0;
}
void AllInGpuGraphSampler::init(GpuPsGraphTable *g,
std::vector<std::string> args_) {
this->gpu_table = g;
this->gpu_num = g->gpu_num;
graph_table = g->cpu_graph_table.get();
}
}
};
#endif
......@@ -297,12 +297,17 @@ void HashTable<KeyType, ValType>::update(const KeyType* d_keys,
}
template class HashTable<unsigned long, paddle::framework::FeatureValue>;
template class HashTable<long, int>;
template void HashTable<unsigned long, paddle::framework::FeatureValue>::get<
cudaStream_t>(const unsigned long* d_keys,
paddle::framework::FeatureValue* d_vals, size_t len,
cudaStream_t stream);
template void HashTable<long, int>::get<cudaStream_t>(const long* d_keys,
int* d_vals, size_t len,
cudaStream_t stream);
// template void
// HashTable<unsigned long, paddle::framework::FeatureValue>::get<cudaStream_t>(
// const unsigned long* d_keys, char* d_vals, size_t len, cudaStream_t
......@@ -313,6 +318,11 @@ template void HashTable<unsigned long, paddle::framework::FeatureValue>::insert<
const paddle::framework::FeatureValue* d_vals, size_t len,
cudaStream_t stream);
template void HashTable<long, int>::insert<cudaStream_t>(const long* d_keys,
const int* d_vals,
size_t len,
cudaStream_t stream);
// template void HashTable<unsigned long,
// paddle::framework::FeatureValue>::insert<
// cudaStream_t>(const unsigned long* d_keys, size_t len, char* pool,
......
......@@ -211,11 +211,11 @@ class HeterComm {
std::vector<std::vector<Path>> path_;
float load_factor_{0.75};
int block_size_{256};
std::unique_ptr<HeterCommKernel> heter_comm_kernel_;
private:
std::unique_ptr<HeterCommKernel> heter_comm_kernel_;
std::vector<LocalStorage> storage_;
int topo_aware_{0};
std::vector<LocalStorage> storage_;
int feanum_{1800 * 2048};
int multi_node_{0};
int node_size_;
......
......@@ -218,6 +218,14 @@ template void HeterCommKernel::calc_shard_index<
int* shard_index, int total_devs,
const cudaStream_t& stream);
template void HeterCommKernel::calc_shard_index<long, int, cudaStream_t>(
long* d_keys, long long len, int* shard_index, int total_devs,
const cudaStream_t& stream);
template void HeterCommKernel::fill_shard_key<long, int, cudaStream_t>(
long* d_shard_keys, long* d_keys, int* idx, long long len,
const cudaStream_t& stream);
template void HeterCommKernel::fill_shard_key<unsigned long, int, cudaStream_t>(
unsigned long* d_shard_keys, unsigned long* d_keys, int* idx, long long len,
const cudaStream_t& stream);
......
......@@ -66,7 +66,6 @@ TEST(TEST_FLEET, graph_sample) {
1,4,7
gpu 2:
2,5,8
query(2,6) returns nodes [6,9,1,4,7,2]
*/
::paddle::distributed::GraphParameter table_proto;
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <gtest/gtest.h>
#include <algorithm>
#include <vector>
#include "paddle/fluid/framework/fleet/heter_ps/feature_value.h"
#include "paddle/fluid/framework/fleet/heter_ps/graph_gpu_ps_table.h"
#include "paddle/fluid/framework/fleet/heter_ps/heter_comm.h"
#include "paddle/fluid/framework/fleet/heter_ps/heter_resource.h"
#include "paddle/fluid/framework/fleet/heter_ps/optimizer.cuh.h"
#include "paddle/fluid/platform/cuda_device_guard.h"
using namespace paddle::framework;
namespace platform = paddle::platform;
// paddle::framework::GpuPsCommGraph GraphTable::make_gpu_ps_graph
// paddle::framework::GpuPsCommGraph GraphTable::make_gpu_ps_graph(
// std::vector<int64_t> ids)
TEST(TEST_FLEET, test_cpu_cache) {
int gpu_num = 0;
int st = 0, u = 0;
std::vector<int> device_id_mapping;
for (int i = 0; i < 2; i++) device_id_mapping.push_back(i);
gpu_num = device_id_mapping.size();
::paddle::distributed::GraphParameter table_proto;
table_proto.set_shard_num(24);
std::shared_ptr<HeterPsResource> resource =
std::make_shared<HeterPsResource>(device_id_mapping);
resource->enable_p2p();
int use_nv = 1;
GpuPsGraphTable g(resource, use_nv);
g.init_cpu_table(table_proto);
std::vector<paddle::framework::GpuPsCommGraph> vec;
int n = 10;
std::vector<int64_t> ids0, ids1;
for (int i = 0; i < n; i++) {
g.cpu_graph_table->add_comm_edge(i, (i + 1) % n);
g.cpu_graph_table->add_comm_edge(i, (i - 1 + n) % n);
if (i % 2 == 0) ids0.push_back(i);
}
ids1.push_back(5);
vec.push_back(g.cpu_graph_table->make_gpu_ps_graph(ids0));
vec.push_back(g.cpu_graph_table->make_gpu_ps_graph(ids1));
vec[0].display_on_cpu();
vec[1].display_on_cpu();
g.build_graph_from_cpu(vec);
int64_t cpu_key[3] = {0, 1, 2};
void *key;
platform::CUDADeviceGuard guard(0);
cudaMalloc((void **)&key, 3 * sizeof(int64_t));
cudaMemcpy(key, cpu_key, 3 * sizeof(int64_t), cudaMemcpyHostToDevice);
auto neighbor_sample_res = g.graph_neighbor_sample(0, (int64_t *)key, 2, 3);
int64_t *res = new int64_t[7];
cudaMemcpy(res, neighbor_sample_res->val, 3 * 2 * sizeof(int64_t),
cudaMemcpyDeviceToHost);
int *actual_sample_size = new int[3];
cudaMemcpy(actual_sample_size, neighbor_sample_res->actual_sample_size,
3 * sizeof(int),
cudaMemcpyDeviceToHost); // 3, 1, 3
//{0,9} or {9,0} is expected for key 0
//{0,2} or {2,0} is expected for key 1
//{1,3} or {3,1} is expected for key 2
for (int i = 0; i < 3; i++) {
VLOG(0) << "actual sample size for " << i << " is "
<< actual_sample_size[i];
for (int j = 0; j < actual_sample_size[i]; j++) {
VLOG(0) << "sampled an neighbor for node" << i << " : " << res[i * 2 + j];
}
}
}
......@@ -40,6 +40,7 @@
#include "paddle/fluid/framework/fleet/heter_ps/feature_value.h"
#include "paddle/fluid/framework/fleet/heter_ps/graph_gpu_ps_table.h"
#include "paddle/fluid/framework/fleet/heter_ps/graph_sampler.h"
#include "paddle/fluid/framework/fleet/heter_ps/heter_comm.h"
#include "paddle/fluid/framework/fleet/heter_ps/heter_resource.h"
#include "paddle/fluid/framework/fleet/heter_ps/optimizer.cuh.h"
......@@ -52,9 +53,13 @@ namespace memory = paddle::memory;
namespace distributed = paddle::distributed;
std::string input_file;
int fixed_key_size = 100, sample_size = 100,
int exe_count = 100;
int use_nv = 1;
int fixed_key_size = 50000, sample_size = 32,
bfs_sample_nodes_in_each_shard = 10000, init_search_size = 1,
bfs_sample_edges = 20;
bfs_sample_edges = 20, gpu_num1 = 8, gpu_num = 8;
std::string gpu_str = "0,1,2,3,4,5,6,7";
int64_t *key[8];
std::vector<std::string> edges = {
std::string("37\t45\t0.34"), std::string("37\t145\t0.31"),
std::string("37\t112\t0.21"), std::string("96\t48\t1.4"),
......@@ -81,16 +86,17 @@ void testSampleRate() {
int start = 0;
pthread_rwlock_t rwlock;
pthread_rwlock_init(&rwlock, NULL);
{
::paddle::distributed::GraphParameter table_proto;
table_proto.set_gpups_mode(false);
// table_proto.set_gpups_mode(false);
table_proto.set_shard_num(127);
table_proto.set_task_pool_size(24);
std::cerr << "initializing begin";
distributed::GraphTable graph_table;
graph_table.initialize(table_proto);
graph_table.Initialize(table_proto);
std::cerr << "initializing done";
graph_table.load(input_file, std::string("e>"));
graph_table.Load(input_file, std::string("e>"));
int sample_actual_size = -1;
int step = fixed_key_size, cur = 0;
while (sample_actual_size != 0) {
......@@ -163,25 +169,48 @@ void testSampleRate() {
std::chrono::duration_cast<std::chrono::microseconds>(end1 - start1);
std::cerr << "total time cost without cache is " << tt.count() << " us"
<< std::endl;
int64_t tot = 0;
for (int i = 0; i < 10; i++) {
for (auto x : sample_id[i]) tot += x;
}
VLOG(0) << "sum = " << tot;
}
const int gpu_num = 8;
::paddle::distributed::GraphParameter table_proto;
table_proto.set_gpups_mode(true);
table_proto.set_shard_num(127);
table_proto.set_gpu_num(gpu_num);
table_proto.set_gpups_graph_sample_class("BasicBfsGraphSampler");
table_proto.set_gpups_graph_sample_args(std::to_string(init_search_size) +
",100000000,10000000,1,1");
std::vector<int> dev_ids;
for (int i = 0; i < gpu_num; i++) {
dev_ids.push_back(i);
gpu_num = 0;
int st = 0, u = 0;
std::vector<int> device_id_mapping;
while (u < gpu_str.size()) {
VLOG(0) << u << " " << gpu_str[u];
if (gpu_str[u] == ',') {
auto p = gpu_str.substr(st, u - st);
int id = std::stoi(p);
VLOG(0) << "got a new device id" << id;
device_id_mapping.push_back(id);
st = u + 1;
}
u++;
}
auto p = gpu_str.substr(st, gpu_str.size() - st);
int id = std::stoi(p);
VLOG(0) << "got a new device id" << id;
device_id_mapping.push_back(id);
gpu_num = device_id_mapping.size();
::paddle::distributed::GraphParameter table_proto;
table_proto.set_shard_num(24);
// table_proto.set_gpups_graph_sample_class("CompleteGraphSampler");
std::shared_ptr<HeterPsResource> resource =
std::make_shared<HeterPsResource>(dev_ids);
std::make_shared<HeterPsResource>(device_id_mapping);
resource->enable_p2p();
GpuPsGraphTable g(resource);
GpuPsGraphTable g(resource, use_nv);
g.init_cpu_table(table_proto);
g.load(std::string(input_file), std::string("e>"));
std::vector<std::string> arg;
AllInGpuGraphSampler sampler;
sampler.init(&g, arg);
// g.load(std::string(input_file), std::string("e>"));
// sampler.start(std::string(input_file));
// sampler.load_from_ssd(std::string(input_file));
sampler.start_service(input_file);
/*
NodeQueryResult *query_node_res;
query_node_res = g.query_node_list(0, 0, ids.size() + 10000);
......@@ -209,52 +238,65 @@ void testSampleRate() {
auto q = g.query_node_list(0, st, ids.size() / 20);
VLOG(0) << " the " << i << "th iteration size = " << q->actual_sample_size;
}
// NodeQueryResult *query_node_list(int gpu_id, int start, int query_size);
/*
void *key;
// NodeQueryResult *query_node_list(int gpu_id, int start, int query_size);
*/
for (int i = 0; i < gpu_num1; i++) {
platform::CUDADeviceGuard guard(device_id_mapping[i]);
cudaMalloc((void **)&key[i], ids.size() * sizeof(int64_t));
cudaMemcpy(key[i], ids.data(), ids.size() * sizeof(int64_t),
cudaMemcpyHostToDevice);
}
/*
cudaMalloc((void **)&key, ids.size() * sizeof(int64_t));
cudaMemcpy(key, ids.data(), ids.size() * sizeof(int64_t),
cudaMemcpyHostToDevice);
std::vector<NeighborSampleResult *> res[gpu_num];
*/
/*
std::vector<std::vector<NeighborSampleResult *>> res(gpu_num1);
for (int i = 0; i < gpu_num1; i++) {
int st = 0;
int size = ids.size();
NeighborSampleResult *result = new NeighborSampleResult(sample_size, size);
platform::CUDAPlace place = platform::CUDAPlace(device_id_mapping[i]);
platform::CUDADeviceGuard guard(device_id_mapping[i]);
cudaMalloc((void **)&result->val, size * sample_size * sizeof(int64_t));
cudaMalloc((void **)&result->actual_sample_size, size * sizeof(int));
res[i].push_back(result);
}
*/
start = 0;
auto func = [&rwlock, &g, &res, &start,
&gpu_num, &ids, &key](int i) {
while (true) {
int s, sn;
bool exit = false;
pthread_rwlock_wrlock(&rwlock);
if (start < ids.size()) {
s = start;
sn = ids.size() - start;
sn = min(sn, fixed_key_size);
start += sn;
} else {
exit = true;
auto func = [&rwlock, &g, &start, &ids](int i) {
int st = 0;
int size = ids.size();
for (int k = 0; k < exe_count; k++) {
st = 0;
while (st < size) {
int len = std::min(fixed_key_size, (int)ids.size() - st);
auto r = g.graph_neighbor_sample(i, (int64_t *)(key[i] + st),
sample_size, len);
st += len;
delete r;
}
pthread_rwlock_unlock(&rwlock);
if (exit) break;
auto r =
g.graph_neighbor_sample(i, (int64_t *)(key + s), sample_size, sn);
res[i].push_back(r);
}
};
auto start1 = std::chrono::steady_clock::now();
std::thread thr[gpu_num];
for (int i = 0; i < gpu_num; i++) {
std::thread thr[gpu_num1];
for (int i = 0; i < gpu_num1; i++) {
thr[i] = std::thread(func, i);
}
for (int i = 0; i < gpu_num; i++) thr[i].join();
for (int i = 0; i < gpu_num1; i++) thr[i].join();
auto end1 = std::chrono::steady_clock::now();
auto tt =
std::chrono::duration_cast<std::chrono::microseconds>(end1 - start1);
std::cerr << "total time cost without cache is " << tt.count() << " us"
<< std::endl;
*/
std::cerr << "total time cost without cache is "
<< tt.count() / exe_count / gpu_num1 << " us" << std::endl;
for (int i = 0; i < gpu_num1; i++) {
cudaFree(key[i]);
}
#endif
}
// TEST(testSampleRate, Run) { testSampleRate(); }
TEST(TEST_FLEET, sample_rate) { testSampleRate(); }
int main(int argc, char *argv[]) {
for (int i = 0; i < argc; i++)
......@@ -276,5 +318,14 @@ int main(int argc, char *argv[]) {
VLOG(0) << "sample_size neighbor_size is " << sample_size;
if (argc > 4) init_search_size = std::stoi(argv[4]);
VLOG(0) << " init_search_size " << init_search_size;
if (argc > 5) {
gpu_str = argv[5];
}
VLOG(0) << " gpu_str= " << gpu_str;
gpu_num = 0;
if (argc > 6) gpu_num1 = std::stoi(argv[6]);
VLOG(0) << " gpu_thread_num= " << gpu_num1;
if (argc > 7) use_nv = std::stoi(argv[7]);
VLOG(0) << " use_nv " << use_nv;
testSampleRate();
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册