未验证 提交 ce72690c 编写于 作者: S seemingwang 提交者: GitHub

gpu_graph engine optimization+ (#41455)

* extract sub-graph

* graph-engine merging

* fix

* fix

* fix heter-ps config

* test performance

* test performance

* test performance

* test

* test

* update bfs

* change cmake

* test

* test gpu speed

* gpu_graph_engine optimization

* add ssd layer to graph_engine

* fix allocation

* fix syntax error

* fix syntax error

* fix pscore class

* fix

* recover test

* recover test

* fix spelling

* recover

* fix
上级 c37af19c
...@@ -28,7 +28,112 @@ namespace paddle { ...@@ -28,7 +28,112 @@ namespace paddle {
namespace distributed { namespace distributed {
#ifdef PADDLE_WITH_HETERPS #ifdef PADDLE_WITH_HETERPS
paddle::framework::GpuPsCommGraph GraphTable::make_gpu_ps_graph(
std::vector<int64_t> ids) {
std::vector<std::vector<int64_t>> bags(task_pool_size_);
for (auto x : ids) {
int location = x % shard_num % task_pool_size_;
bags[location].push_back(x);
}
std::vector<std::future<int>> tasks;
std::vector<int64_t> edge_array[task_pool_size_];
std::vector<paddle::framework::GpuPsGraphNode> node_array[task_pool_size_];
for (int i = 0; i < (int)bags.size(); i++) {
if (bags[i].size() > 0) {
tasks.push_back(_shards_task_pool[i]->enqueue([&, i, this]() -> int {
paddle::framework::GpuPsGraphNode x;
for (int j = 0; j < (int)bags[i].size(); j++) {
Node *v = find_node(bags[i][j]);
x.node_id = bags[i][j];
if (v == NULL) {
x.neighbor_size = 0;
x.neighbor_offset = 0;
node_array[i].push_back(x);
} else {
x.neighbor_size = v->get_neighbor_size();
x.neighbor_offset = edge_array[i].size();
node_array[i].push_back(x);
for (int k = 0; k < x.neighbor_size; k++) {
edge_array[i].push_back(v->get_neighbor_id(k));
}
}
}
return 0;
}));
}
}
for (int i = 0; i < (int)tasks.size(); i++) tasks[i].get();
paddle::framework::GpuPsCommGraph res;
int tot_len = 0;
for (int i = 0; i < task_pool_size_; i++) {
tot_len += (int)edge_array[i].size();
}
res.neighbor_size = tot_len;
res.node_size = ids.size();
res.neighbor_list = new int64_t[tot_len];
res.node_list = new paddle::framework::GpuPsGraphNode[ids.size()];
int offset = 0, ind = 0;
for (int i = 0; i < task_pool_size_; i++) {
for (int j = 0; j < (int)node_array[i].size(); j++) {
res.node_list[ind] = node_array[i][j];
res.node_list[ind++].neighbor_offset += offset;
}
for (int j = 0; j < (int)edge_array[i].size(); j++) {
res.neighbor_list[offset + j] = edge_array[i][j];
}
offset += edge_array[i].size();
}
return res;
}
int32_t GraphTable::add_node_to_ssd(int64_t src_id, char *data, int len) {
if (_db != NULL)
_db->put(src_id % shard_num % task_pool_size_, (char *)&src_id,
sizeof(uint64_t), (char *)data, sizeof(int64_t) * len);
return 0;
}
char *GraphTable::random_sample_neighbor_from_ssd(
int64_t id, int sample_size, const std::shared_ptr<std::mt19937_64> rng,
int &actual_size) {
if (_db == NULL) {
actual_size = 0;
return NULL;
}
std::string str;
if (_db->get(id % shard_num % task_pool_size_, (char *)&id, sizeof(uint64_t),
str) == 0) {
int64_t *data = ((int64_t *)str.c_str());
int n = str.size() / sizeof(int64_t);
std::unordered_map<int, int> m;
// std::vector<int64_t> res;
int sm_size = std::min(n, sample_size);
actual_size = sm_size * Node::id_size;
char *buff = new char[actual_size];
for (int i = 0; i < sm_size; i++) {
std::uniform_int_distribution<int> distrib(0, n - i - 1);
int t = distrib(*rng);
// int t = rand() % (n-i);
int pos = 0;
auto iter = m.find(t);
if (iter != m.end()) {
pos = iter->second;
} else {
pos = t;
}
auto iter2 = m.find(n - i - 1);
int key2 = iter2 == m.end() ? n - i - 1 : iter2->second;
m[t] = key2;
m.erase(n - i - 1);
memcpy(buff + i * Node::id_size, &data[pos], Node::id_size);
// res.push_back(data[pos]);
}
return buff;
}
actual_size = 0;
return NULL;
}
#endif
/*
int CompleteGraphSampler::run_graph_sampling() { int CompleteGraphSampler::run_graph_sampling() {
pthread_rwlock_t *rw_lock = graph_table->rw_lock.get(); pthread_rwlock_t *rw_lock = graph_table->rw_lock.get();
pthread_rwlock_rdlock(rw_lock); pthread_rwlock_rdlock(rw_lock);
...@@ -136,7 +241,8 @@ int BasicBfsGraphSampler::run_graph_sampling() { ...@@ -136,7 +241,8 @@ int BasicBfsGraphSampler::run_graph_sampling() {
int task_size = 0; int task_size = 0;
std::vector<std::future<int>> tasks; std::vector<std::future<int>> tasks;
int init_size = 0; int init_size = 0;
std::function<int(int, int64_t)> bfs = [&, this](int i, int64_t id) -> int { //__sync_fetch_and_add
std::function<int(int, int64_t)> bfs = [&, this](int i, int id) -> int {
if (this->status == GraphSamplerStatus::terminating) { if (this->status == GraphSamplerStatus::terminating) {
int task_left = __sync_sub_and_fetch(&task_size, 1); int task_left = __sync_sub_and_fetch(&task_size, 1);
if (task_left == 0) { if (task_left == 0) {
...@@ -289,6 +395,7 @@ int BasicBfsGraphSampler::run_graph_sampling() { ...@@ -289,6 +395,7 @@ int BasicBfsGraphSampler::run_graph_sampling() {
std::this_thread::sleep_for(std::chrono::seconds(1)); std::this_thread::sleep_for(std::chrono::seconds(1));
} }
} }
VLOG(0)<<"bfs returning";
} }
return 0; return 0;
} }
...@@ -304,7 +411,7 @@ void BasicBfsGraphSampler::init(size_t gpu_num, GraphTable *graph_table, ...@@ -304,7 +411,7 @@ void BasicBfsGraphSampler::init(size_t gpu_num, GraphTable *graph_table,
} }
#endif #endif
*/
std::vector<Node *> GraphShard::get_batch(int start, int end, int step) { std::vector<Node *> GraphShard::get_batch(int start, int end, int step) {
if (start < 0) start = 0; if (start < 0) start = 0;
std::vector<Node *> res; std::vector<Node *> res;
...@@ -316,6 +423,17 @@ std::vector<Node *> GraphShard::get_batch(int start, int end, int step) { ...@@ -316,6 +423,17 @@ std::vector<Node *> GraphShard::get_batch(int start, int end, int step) {
size_t GraphShard::get_size() { return bucket.size(); } size_t GraphShard::get_size() { return bucket.size(); }
int32_t GraphTable::add_comm_edge(int64_t src_id, int64_t dst_id) {
size_t src_shard_id = src_id % shard_num;
if (src_shard_id >= shard_end || src_shard_id < shard_start) {
return -1;
}
size_t index = src_shard_id - shard_start;
extra_shards[index]->add_graph_node(src_id)->build_edges(false);
extra_shards[index]->add_neighbor(src_id, dst_id, 1.0);
return 0;
}
int32_t GraphTable::add_graph_node(std::vector<int64_t> &id_list, int32_t GraphTable::add_graph_node(std::vector<int64_t> &id_list,
std::vector<bool> &is_weight_list) { std::vector<bool> &is_weight_list) {
size_t node_size = id_list.size(); size_t node_size = id_list.size();
...@@ -554,9 +672,9 @@ int32_t GraphTable::load_nodes(const std::string &path, std::string node_type) { ...@@ -554,9 +672,9 @@ int32_t GraphTable::load_nodes(const std::string &path, std::string node_type) {
} }
int32_t GraphTable::load_edges(const std::string &path, bool reverse_edge) { int32_t GraphTable::load_edges(const std::string &path, bool reverse_edge) {
#ifdef PADDLE_WITH_HETERPS // #ifdef PADDLE_WITH_HETERPS
if (gpups_mode) pthread_rwlock_rdlock(rw_lock.get()); // if (gpups_mode) pthread_rwlock_rdlock(rw_lock.get());
#endif // #endif
auto paths = paddle::string::split_string<std::string>(path, ";"); auto paths = paddle::string::split_string<std::string>(path, ";");
int64_t count = 0; int64_t count = 0;
std::string sample_type = "random"; std::string sample_type = "random";
...@@ -633,9 +751,9 @@ int32_t GraphTable::load_edges(const std::string &path, bool reverse_edge) { ...@@ -633,9 +751,9 @@ int32_t GraphTable::load_edges(const std::string &path, bool reverse_edge) {
relocate the duplicate nodes to make them distributed evenly among threads. relocate the duplicate nodes to make them distributed evenly among threads.
*/ */
if (!use_duplicate_nodes) { if (!use_duplicate_nodes) {
#ifdef PADDLE_WITH_HETERPS // #ifdef PADDLE_WITH_HETERPS
if (gpups_mode) pthread_rwlock_unlock(rw_lock.get()); // if (gpups_mode) pthread_rwlock_unlock(rw_lock.get());
#endif // #endif
return 0; return 0;
} }
...@@ -712,9 +830,9 @@ int32_t GraphTable::load_edges(const std::string &path, bool reverse_edge) { ...@@ -712,9 +830,9 @@ int32_t GraphTable::load_edges(const std::string &path, bool reverse_edge) {
delete extra_shards[i]; delete extra_shards[i];
extra_shards[i] = extra_shards_copy[i]; extra_shards[i] = extra_shards_copy[i];
} }
#ifdef PADDLE_WITH_HETERPS // #ifdef PADDLE_WITH_HETERPS
if (gpups_mode) pthread_rwlock_unlock(rw_lock.get()); // if (gpups_mode) pthread_rwlock_unlock(rw_lock.get());
#endif // #endif
return 0; return 0;
} }
...@@ -878,6 +996,17 @@ int32_t GraphTable::random_sample_neighbors( ...@@ -878,6 +996,17 @@ int32_t GraphTable::random_sample_neighbors(
idx = seq_id[i][k]; idx = seq_id[i][k];
int &actual_size = actual_sizes[idx]; int &actual_size = actual_sizes[idx];
if (node == nullptr) { if (node == nullptr) {
#ifdef PADDLE_WITH_HETERPS
if (search_level == 2) {
char *buffer_addr = random_sample_neighbor_from_ssd(
node_id, sample_size, rng, actual_size);
if (actual_size != 0) {
std::shared_ptr<char> &buffer = buffers[idx];
buffer.reset(buffer_addr, char_del);
}
continue;
}
#endif
actual_size = 0; actual_size = 0;
continue; continue;
} }
...@@ -1085,25 +1214,29 @@ int32_t GraphTable::Initialize(const TableParameter &config, ...@@ -1085,25 +1214,29 @@ int32_t GraphTable::Initialize(const TableParameter &config,
return Initialize(graph); return Initialize(graph);
} }
int32_t GraphTable::Initialize(const GraphParameter &graph) { int32_t GraphTable::Initialize(const GraphParameter &graph) {
task_pool_size_ = graph.task_pool_size();
#ifdef PADDLE_WITH_HETERPS #ifdef PADDLE_WITH_HETERPS
if (graph.gpups_mode()) { _db = NULL;
gpups_mode = true; search_level = graph.search_level();
auto *sampler = if (search_level >= 2) {
CREATE_PSCORE_CLASS(GraphSampler, graph.gpups_graph_sample_class()); _db = paddle::distributed::RocksDBHandler::GetInstance();
auto slices = _db->initialize("./temp_gpups_db", task_pool_size_);
string::split_string<std::string>(graph.gpups_graph_sample_args(), ","); }
std::cout << "slices" << std::endl; // gpups_mode = true;
for (auto x : slices) std::cout << x << std::endl; // auto *sampler =
sampler->init(graph.gpu_num(), this, slices); // CREATE_PSCORE_CLASS(GraphSampler, graph.gpups_graph_sample_class());
graph_sampler.reset(sampler); // auto slices =
} // string::split_string<std::string>(graph.gpups_graph_sample_args(), ",");
// std::cout << "slices" << std::endl;
// for (auto x : slices) std::cout << x << std::endl;
// sampler->init(graph.gpu_num(), this, slices);
// graph_sampler.reset(sampler);
#endif #endif
if (shard_num == 0) { if (shard_num == 0) {
server_num = 1; server_num = 1;
_shard_idx = 0; _shard_idx = 0;
shard_num = graph.shard_num(); shard_num = graph.shard_num();
} }
task_pool_size_ = graph.task_pool_size();
use_cache = graph.use_cache(); use_cache = graph.use_cache();
if (use_cache) { if (use_cache) {
cache_size_limit = graph.cache_size_limit(); cache_size_limit = graph.cache_size_limit();
......
...@@ -38,6 +38,7 @@ ...@@ -38,6 +38,7 @@
#include <vector> #include <vector>
#include "paddle/fluid/distributed/ps/table/accessor.h" #include "paddle/fluid/distributed/ps/table/accessor.h"
#include "paddle/fluid/distributed/ps/table/common_table.h" #include "paddle/fluid/distributed/ps/table/common_table.h"
#include "paddle/fluid/distributed/ps/table/depends/rocksdb_warpper.h"
#include "paddle/fluid/distributed/ps/table/graph/class_macro.h" #include "paddle/fluid/distributed/ps/table/graph/class_macro.h"
#include "paddle/fluid/distributed/ps/table/graph/graph_node.h" #include "paddle/fluid/distributed/ps/table/graph/graph_node.h"
#include "paddle/fluid/string/string_helper.h" #include "paddle/fluid/string/string_helper.h"
...@@ -351,6 +352,7 @@ class ScaledLRU { ...@@ -351,6 +352,7 @@ class ScaledLRU {
friend class RandomSampleLRU<K, V>; friend class RandomSampleLRU<K, V>;
}; };
/*
#ifdef PADDLE_WITH_HETERPS #ifdef PADDLE_WITH_HETERPS
enum GraphSamplerStatus { waiting = 0, running = 1, terminating = 2 }; enum GraphSamplerStatus { waiting = 0, running = 1, terminating = 2 };
class GraphTable; class GraphTable;
...@@ -363,6 +365,9 @@ class GraphSampler { ...@@ -363,6 +365,9 @@ class GraphSampler {
return; return;
}; };
} }
virtual int loadData(const std::string &path){
return 0;
}
virtual int run_graph_sampling() = 0; virtual int run_graph_sampling() = 0;
virtual int start_graph_sampling() { virtual int start_graph_sampling() {
if (status != GraphSamplerStatus::waiting) { if (status != GraphSamplerStatus::waiting) {
...@@ -403,15 +408,13 @@ class GraphSampler { ...@@ -403,15 +408,13 @@ class GraphSampler {
std::vector<paddle::framework::GpuPsCommGraph> sample_res; std::vector<paddle::framework::GpuPsCommGraph> sample_res;
}; };
#endif #endif
*/
class GraphTable : public Table { class GraphTable : public Table {
public: public:
GraphTable() { GraphTable() {
use_cache = false; use_cache = false;
shard_num = 0; shard_num = 0;
#ifdef PADDLE_WITH_HETERPS
gpups_mode = false;
#endif
rw_lock.reset(new pthread_rwlock_t()); rw_lock.reset(new pthread_rwlock_t());
} }
virtual ~GraphTable(); virtual ~GraphTable();
...@@ -516,21 +519,28 @@ class GraphTable : public Table { ...@@ -516,21 +519,28 @@ class GraphTable : public Table {
return 0; return 0;
} }
#ifdef PADDLE_WITH_HETERPS #ifdef PADDLE_WITH_HETERPS
virtual int32_t start_graph_sampling() { // virtual int32_t start_graph_sampling() {
return this->graph_sampler->start_graph_sampling(); // return this->graph_sampler->start_graph_sampling();
} // }
virtual int32_t end_graph_sampling() { // virtual int32_t end_graph_sampling() {
return this->graph_sampler->end_graph_sampling(); // return this->graph_sampler->end_graph_sampling();
} // }
virtual int32_t set_graph_sample_callback( // virtual int32_t set_graph_sample_callback(
std::function<void(std::vector<paddle::framework::GpuPsCommGraph> &)> // std::function<void(std::vector<paddle::framework::GpuPsCommGraph> &)>
callback) { // callback) {
graph_sampler->set_graph_sample_callback(callback); // graph_sampler->set_graph_sample_callback(callback);
return 0; // return 0;
} // }
// virtual GraphSampler *get_graph_sampler() { return graph_sampler.get(); } virtual char *random_sample_neighbor_from_ssd(
int64_t id, int sample_size, const std::shared_ptr<std::mt19937_64> rng,
int &actual_size);
virtual int32_t add_node_to_ssd(int64_t id, char *data, int len);
virtual paddle::framework::GpuPsCommGraph make_gpu_ps_graph(
std::vector<int64_t> ids);
// virtual GraphSampler *get_graph_sampler() { return graph_sampler.get(); }
int search_level;
#endif #endif
protected: virtual int32_t add_comm_edge(int64_t src_id, int64_t dst_id);
std::vector<GraphShard *> shards, extra_shards; std::vector<GraphShard *> shards, extra_shards;
size_t shard_start, shard_end, server_num, shard_num_per_server, shard_num; size_t shard_start, shard_end, server_num, shard_num_per_server, shard_num;
int task_pool_size_ = 24; int task_pool_size_ = 24;
...@@ -555,13 +565,14 @@ class GraphTable : public Table { ...@@ -555,13 +565,14 @@ class GraphTable : public Table {
std::shared_ptr<pthread_rwlock_t> rw_lock; std::shared_ptr<pthread_rwlock_t> rw_lock;
#ifdef PADDLE_WITH_HETERPS #ifdef PADDLE_WITH_HETERPS
// paddle::framework::GpuPsGraphTable gpu_graph_table; // paddle::framework::GpuPsGraphTable gpu_graph_table;
bool gpups_mode; paddle::distributed::RocksDBHandler *_db;
// std::shared_ptr<::ThreadPool> graph_sample_pool; // std::shared_ptr<::ThreadPool> graph_sample_pool;
std::shared_ptr<GraphSampler> graph_sampler; // std::shared_ptr<GraphSampler> graph_sampler;
REGISTER_GRAPH_FRIEND_CLASS(2, CompleteGraphSampler, BasicBfsGraphSampler) // REGISTER_GRAPH_FRIEND_CLASS(2, CompleteGraphSampler, BasicBfsGraphSampler)
#endif #endif
}; };
/*
#ifdef PADDLE_WITH_HETERPS #ifdef PADDLE_WITH_HETERPS
REGISTER_PSCORE_REGISTERER(GraphSampler); REGISTER_PSCORE_REGISTERER(GraphSampler);
class CompleteGraphSampler : public GraphSampler { class CompleteGraphSampler : public GraphSampler {
...@@ -603,6 +614,7 @@ class BasicBfsGraphSampler : public GraphSampler { ...@@ -603,6 +614,7 @@ class BasicBfsGraphSampler : public GraphSampler {
sample_neighbors_map; sample_neighbors_map;
}; };
#endif #endif
*/
} // namespace distributed } // namespace distributed
}; // namespace paddle }; // namespace paddle
......
...@@ -12,6 +12,7 @@ ...@@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#pragma once
#ifdef PADDLE_WITH_HETERPS #ifdef PADDLE_WITH_HETERPS
#include <glog/logging.h> #include <glog/logging.h>
#include <rocksdb/db.h> #include <rocksdb/db.h>
......
...@@ -31,10 +31,6 @@ namespace paddle { ...@@ -31,10 +31,6 @@ namespace paddle {
namespace distributed { namespace distributed {
REGISTER_PSCORE_CLASS(Table, GraphTable); REGISTER_PSCORE_CLASS(Table, GraphTable);
REGISTER_PSCORE_CLASS(Table, MemoryDenseTable); REGISTER_PSCORE_CLASS(Table, MemoryDenseTable);
#ifdef PADDLE_WITH_HETERPS
REGISTER_PSCORE_CLASS(GraphSampler, CompleteGraphSampler);
REGISTER_PSCORE_CLASS(GraphSampler, BasicBfsGraphSampler);
#endif
REGISTER_PSCORE_CLASS(Table, BarrierTable); REGISTER_PSCORE_CLASS(Table, BarrierTable);
REGISTER_PSCORE_CLASS(Table, TensorTable); REGISTER_PSCORE_CLASS(Table, TensorTable);
REGISTER_PSCORE_CLASS(Table, DenseTensorTable); REGISTER_PSCORE_CLASS(Table, DenseTensorTable);
......
...@@ -25,7 +25,7 @@ set_source_files_properties(graph_node_split_test.cc PROPERTIES COMPILE_FLAGS ${ ...@@ -25,7 +25,7 @@ set_source_files_properties(graph_node_split_test.cc PROPERTIES COMPILE_FLAGS ${
cc_test(graph_node_split_test SRCS graph_node_split_test.cc DEPS graph_py_service scope server client communicator ps_service boost table ps_framework_proto ${COMMON_DEPS}) cc_test(graph_node_split_test SRCS graph_node_split_test.cc DEPS graph_py_service scope server client communicator ps_service boost table ps_framework_proto ${COMMON_DEPS})
set_source_files_properties(graph_table_sample_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) set_source_files_properties(graph_table_sample_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
cc_test(graph_table_sample_test SRCS graph_table_sample_test.cc DEPS scope server communicator ps_service boost table ps_framework_proto ${COMMON_DEPS}) cc_test(graph_table_sample_test SRCS graph_table_sample_test.cc DEPS table ps_framework_proto ${COMMON_DEPS})
set_source_files_properties(feature_value_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) set_source_files_properties(feature_value_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
cc_test(feature_value_test SRCS feature_value_test.cc DEPS ${COMMON_DEPS} boost table) cc_test(feature_value_test SRCS feature_value_test.cc DEPS ${COMMON_DEPS} boost table)
......
...@@ -679,7 +679,7 @@ void testCache() { ...@@ -679,7 +679,7 @@ void testCache() {
st.query(0, &skey, 1, r); st.query(0, &skey, 1, r);
ASSERT_EQ((int)r.size(), 1); ASSERT_EQ((int)r.size(), 1);
char* p = (char*)r[0].second.buffer.get(); char* p = (char*)r[0].second.buffer.get();
for (size_t j = 0; j < r[0].second.actual_size; j++) for (int j = 0; j < (int)r[0].second.actual_size; j++)
ASSERT_EQ(p[j], str[j]); ASSERT_EQ(p[j], str[j]);
r.clear(); r.clear();
} }
......
...@@ -25,18 +25,7 @@ ...@@ -25,18 +25,7 @@
#include <chrono> #include <chrono>
#include "gtest/gtest.h" #include "gtest/gtest.h"
#include "paddle/fluid/distributed/ps.pb.h" #include "paddle/fluid/distributed/ps.pb.h"
#include "paddle/fluid/distributed/ps/service/env.h"
#include "paddle/fluid/distributed/ps/service/sendrecv.pb.h"
#include "paddle/fluid/distributed/ps/table/common_graph_table.h" #include "paddle/fluid/distributed/ps/table/common_graph_table.h"
#include "paddle/fluid/distributed/ps/table/graph/graph_node.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/framework/variable.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/fluid/string/printf.h"
#include "paddle/phi/kernels/funcs/math_function.h"
namespace framework = paddle::framework; namespace framework = paddle::framework;
namespace platform = paddle::platform; namespace platform = paddle::platform;
namespace operators = paddle::operators; namespace operators = paddle::operators;
...@@ -83,66 +72,11 @@ void prepare_file(char file_name[], std::vector<std::string> data) { ...@@ -83,66 +72,11 @@ void prepare_file(char file_name[], std::vector<std::string> data) {
} }
void testGraphSample() { void testGraphSample() {
#ifdef PADDLE_WITH_HETERPS
::paddle::distributed::GraphParameter table_proto; ::paddle::distributed::GraphParameter table_proto;
table_proto.set_gpups_mode(true); // table_proto.set_gpu_num(2);
table_proto.set_shard_num(127);
table_proto.set_gpu_num(2);
distributed::GraphTable graph_table, graph_table1; distributed::GraphTable graph_table;
graph_table.initialize(table_proto); graph_table.Initialize(table_proto);
prepare_file(edge_file_name, edges);
graph_table.load(std::string(edge_file_name), std::string("e>"));
std::vector<paddle::framework::GpuPsCommGraph> res;
std::promise<int> prom;
std::future<int> fut = prom.get_future();
graph_table.set_graph_sample_callback(
[&res, &prom](std::vector<paddle::framework::GpuPsCommGraph> &res0) {
res = res0;
prom.set_value(0);
});
graph_table.start_graph_sampling();
fut.get();
graph_table.end_graph_sampling();
ASSERT_EQ(2, res.size());
// 37 59 97
for (int i = 0; i < (int)res[1].node_size; i++) {
std::cout << res[1].node_list[i].node_id << std::endl;
}
ASSERT_EQ(3, res[1].node_size);
::paddle::distributed::GraphParameter table_proto1;
table_proto1.set_gpups_mode(true);
table_proto1.set_shard_num(127);
table_proto1.set_gpu_num(2);
table_proto1.set_gpups_graph_sample_class("BasicBfsGraphSampler");
table_proto1.set_gpups_graph_sample_args("5,5,1,1");
graph_table1.initialize(table_proto1);
graph_table1.load(std::string(edge_file_name), std::string("e>"));
std::vector<paddle::framework::GpuPsCommGraph> res1;
std::promise<int> prom1;
std::future<int> fut1 = prom1.get_future();
graph_table1.set_graph_sample_callback(
[&res1, &prom1](std::vector<paddle::framework::GpuPsCommGraph> &res0) {
res1 = res0;
prom1.set_value(0);
});
graph_table1.start_graph_sampling();
fut1.get();
graph_table1.end_graph_sampling();
// distributed::BasicBfsGraphSampler *sampler1 =
// (distributed::BasicBfsGraphSampler *)graph_table1.get_graph_sampler();
// sampler1->start_graph_sampling();
// std::this_thread::sleep_for (std::chrono::seconds(1));
// std::vector<paddle::framework::GpuPsCommGraph> res1;// =
// sampler1->fetch_sample_res();
ASSERT_EQ(2, res1.size());
// odd id:96 48 122 112
for (int i = 0; i < (int)res1[0].node_size; i++) {
std::cout << res1[0].node_list[i].node_id << std::endl;
}
ASSERT_EQ(4, res1[0].node_size);
#endif
} }
TEST(testGraphSample, Run) { testGraphSample(); } TEST(testGraphSample, Run) { testGraphSample(); }
...@@ -215,18 +215,16 @@ message SparseAdamSGDParameter { // SparseAdamSGDRule ...@@ -215,18 +215,16 @@ message SparseAdamSGDParameter { // SparseAdamSGDRule
message GraphParameter { message GraphParameter {
optional int32 task_pool_size = 1 [ default = 24 ]; optional int32 task_pool_size = 1 [ default = 24 ];
optional bool gpups_mode = 2 [ default = false ]; optional string gpups_graph_sample_class = 2
optional string gpups_graph_sample_class = 3
[ default = "CompleteGraphSampler" ]; [ default = "CompleteGraphSampler" ];
optional string gpups_graph_sample_args = 4 [ default = "" ]; optional bool use_cache = 3 [ default = false ];
optional bool use_cache = 5 [ default = false ]; optional int32 cache_size_limit = 4 [ default = 100000 ];
optional int32 cache_size_limit = 6 [ default = 100000 ]; optional int32 cache_ttl = 5 [ default = 5 ];
optional int32 cache_ttl = 7 [ default = 5 ]; optional GraphFeature graph_feature = 6;
optional GraphFeature graph_feature = 8; optional string table_name = 7 [ default = "" ];
optional string table_name = 9 [ default = "" ]; optional string table_type = 8 [ default = "" ];
optional string table_type = 10 [ default = "" ]; optional int32 shard_num = 9 [ default = 127 ];
optional int32 shard_num = 11 [ default = 127 ]; optional int32 search_level = 10 [ default = 1 ];
optional int32 gpu_num = 12 [ default = 1 ];
} }
message GraphFeature { message GraphFeature {
......
...@@ -13,13 +13,16 @@ IF(WITH_GPU) ...@@ -13,13 +13,16 @@ IF(WITH_GPU)
nv_test(test_heter_comm SRCS feature_value.h DEPS heter_comm) nv_test(test_heter_comm SRCS feature_value.h DEPS heter_comm)
nv_library(heter_ps SRCS heter_ps.cu DEPS heter_comm) nv_library(heter_ps SRCS heter_ps.cu DEPS heter_comm)
if(WITH_PSCORE) if(WITH_PSCORE)
nv_library(graph_gpu_ps SRCS graph_gpu_ps_table.h DEPS heter_comm table) nv_library(graph_gpu_ps SRCS graph_gpu_ps_table.h DEPS heter_comm table)
nv_test(test_graph_comm SRCS test_graph.cu DEPS graph_gpu_ps) nv_library(graph_sampler SRCS graph_sampler_inl.h DEPS graph_gpu_ps)
nv_test(test_cpu_graph_sample SRCS test_cpu_graph_sample.cu DEPS graph_gpu_ps) #nv_test(test_graph_comm SRCS test_graph.cu DEPS graph_gpu_ps)
#nv_test(test_sample_rate SRCS test_sample_rate.cu DEPS graph_gpu_ps) #nv_test(test_cpu_graph_sample SRCS test_cpu_graph_sample.cu DEPS graph_gpu_ps)
# ADD_EXECUTABLE(test_sample_rate test_sample_rate.cu) #nv_test(test_sample_rate SRCS test_sample_rate.cu DEPS graph_gpu_ps)
# target_link_libraries(test_sample_rate graph_gpu_ps) # ADD_EXECUTABLE(test_sample_rate test_sample_rate.cu)
# target_link_libraries(test_sample_rate graph_gpu_ps graph_sampler)
# nv_test(test_graph_xx SRCS test_xx.cu DEPS graph_gpu_ps graph_sampler)
endif() endif()
ENDIF() ENDIF()
IF(WITH_ROCM) IF(WITH_ROCM)
hip_library(heter_comm SRCS heter_comm.h feature_value.h heter_resource.cc heter_resource.h hashtable.h DEPS cub device_context) hip_library(heter_comm SRCS heter_comm.h feature_value.h heter_resource.cc heter_resource.h hashtable.h DEPS cub device_context)
......
...@@ -14,6 +14,11 @@ ...@@ -14,6 +14,11 @@
#pragma once #pragma once
#ifdef PADDLE_WITH_HETERPS #ifdef PADDLE_WITH_HETERPS
#include <iostream>
#include <memory>
#include "paddle/fluid/memory/allocation/allocator.h"
#include "paddle/fluid/memory/memory.h"
#include "paddle/fluid/platform/cuda_device_guard.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
struct GpuPsGraphNode { struct GpuPsGraphNode {
...@@ -94,16 +99,24 @@ struct NeighborSampleResult { ...@@ -94,16 +99,24 @@ struct NeighborSampleResult {
int64_t *val; int64_t *val;
int *actual_sample_size, sample_size, key_size; int *actual_sample_size, sample_size, key_size;
int *offset; int *offset;
NeighborSampleResult(int _sample_size, int _key_size) std::shared_ptr<memory::Allocation> val_mem, actual_sample_size_mem;
NeighborSampleResult(int _sample_size, int _key_size, int dev_id)
: sample_size(_sample_size), key_size(_key_size) { : sample_size(_sample_size), key_size(_key_size) {
actual_sample_size = NULL; platform::CUDADeviceGuard guard(dev_id);
val = NULL; platform::CUDAPlace place = platform::CUDAPlace(dev_id);
val_mem =
memory::AllocShared(place, _sample_size * _key_size * sizeof(int64_t));
val = (int64_t *)val_mem->ptr();
actual_sample_size_mem =
memory::AllocShared(place, _key_size * sizeof(int));
actual_sample_size = (int *)actual_sample_size_mem->ptr();
offset = NULL; offset = NULL;
}; };
~NeighborSampleResult() { ~NeighborSampleResult() {
if (val != NULL) cudaFree(val); // if (val != NULL) cudaFree(val);
if (actual_sample_size != NULL) cudaFree(actual_sample_size); // if (actual_sample_size != NULL) cudaFree(actual_sample_size);
if (offset != NULL) cudaFree(offset); // if (offset != NULL) cudaFree(offset);
} }
}; };
......
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
#pragma once #pragma once
#include <thrust/host_vector.h> #include <thrust/host_vector.h>
#include <chrono>
#include "heter_comm.h" #include "heter_comm.h"
#include "paddle/fluid/distributed/ps/table/common_graph_table.h" #include "paddle/fluid/distributed/ps/table/common_graph_table.h"
#include "paddle/fluid/framework/fleet/heter_ps/gpu_graph_node.h" #include "paddle/fluid/framework/fleet/heter_ps/gpu_graph_node.h"
...@@ -21,19 +22,64 @@ ...@@ -21,19 +22,64 @@
#ifdef PADDLE_WITH_HETERPS #ifdef PADDLE_WITH_HETERPS
namespace paddle { namespace paddle {
namespace framework { namespace framework {
class GpuPsGraphTable : public HeterComm<int64_t, int, int> { class GpuPsGraphTable : public HeterComm<int64_t, int, int> {
public: public:
GpuPsGraphTable(std::shared_ptr<HeterPsResource> resource) GpuPsGraphTable(std::shared_ptr<HeterPsResource> resource, int topo_aware)
: HeterComm<int64_t, int, int>(1, resource) { : HeterComm<int64_t, int, int>(1, resource) {
load_factor_ = 0.25; load_factor_ = 0.25;
rw_lock.reset(new pthread_rwlock_t()); rw_lock.reset(new pthread_rwlock_t());
gpu_num = resource_->total_gpu();
cpu_table_status = -1; cpu_table_status = -1;
if (topo_aware) {
int total_gpu = resource_->total_gpu();
std::map<int, int> device_map;
for (int i = 0; i < total_gpu; i++) {
device_map[resource_->dev_id(i)] = i;
VLOG(1) << " device " << resource_->dev_id(i) << " is stored on " << i;
}
path_.clear();
path_.resize(total_gpu);
VLOG(1) << "topo aware overide";
for (int i = 0; i < total_gpu; ++i) {
path_[i].resize(total_gpu);
for (int j = 0; j < total_gpu; ++j) {
auto &nodes = path_[i][j].nodes_;
nodes.clear();
int from = resource_->dev_id(i);
int to = resource_->dev_id(j);
int transfer_id = i;
if (need_transfer(from, to) &&
(device_map.find((from + 4) % 8) != device_map.end() ||
device_map.find((to + 4) % 8) != device_map.end())) {
transfer_id = (device_map.find((from + 4) % 8) != device_map.end())
? ((from + 4) % 8)
: ((to + 4) % 8);
transfer_id = device_map[transfer_id];
nodes.push_back(Node());
Node &node = nodes.back();
node.in_stream = resource_->comm_stream(i, transfer_id);
node.out_stream = resource_->comm_stream(transfer_id, i);
node.key_storage = NULL;
node.val_storage = NULL;
node.sync = 0;
node.gpu_num = transfer_id;
}
nodes.push_back(Node());
Node &node = nodes.back();
node.in_stream = resource_->comm_stream(i, transfer_id);
node.out_stream = resource_->comm_stream(transfer_id, i);
node.key_storage = NULL;
node.val_storage = NULL;
node.sync = 0;
node.gpu_num = j;
}
}
}
} }
~GpuPsGraphTable() { ~GpuPsGraphTable() {
if (cpu_table_status != -1) { // if (cpu_table_status != -1) {
end_graph_sampling(); // end_graph_sampling();
} // }
} }
void build_graph_from_cpu(std::vector<GpuPsCommGraph> &cpu_node_list); void build_graph_from_cpu(std::vector<GpuPsCommGraph> &cpu_node_list);
NodeQueryResult *graph_node_sample(int gpu_id, int sample_size); NodeQueryResult *graph_node_sample(int gpu_id, int sample_size);
...@@ -41,21 +87,28 @@ class GpuPsGraphTable : public HeterComm<int64_t, int, int> { ...@@ -41,21 +87,28 @@ class GpuPsGraphTable : public HeterComm<int64_t, int, int> {
int sample_size, int len); int sample_size, int len);
NodeQueryResult *query_node_list(int gpu_id, int start, int query_size); NodeQueryResult *query_node_list(int gpu_id, int start, int query_size);
void clear_graph_info(); void clear_graph_info();
void move_neighbor_sample_result_to_source_gpu( void move_neighbor_sample_result_to_source_gpu(int gpu_id, int gpu_num,
int gpu_id, int gpu_num, int *h_left, int *h_right, int sample_size, int *h_left,
int64_t *src_sample_res, thrust::host_vector<int> &total_sample_size); int *h_right,
void move_neighbor_sample_size_to_source_gpu(int gpu_id, int gpu_num, int64_t *src_sample_res,
int *h_left, int *h_right, int *actual_sample_size);
int *actual_sample_size, // void move_neighbor_sample_result_to_source_gpu(
int *total_sample_size); // int gpu_id, int gpu_num, int *h_left, int *h_right,
// int64_t *src_sample_res, thrust::host_vector<int> &total_sample_size);
// void move_neighbor_sample_size_to_source_gpu(int gpu_id, int gpu_num,
// int *h_left, int *h_right,
// int *actual_sample_size,
// int *total_sample_size);
int init_cpu_table(const paddle::distributed::GraphParameter &graph); int init_cpu_table(const paddle::distributed::GraphParameter &graph);
int load(const std::string &path, const std::string &param); // int load(const std::string &path, const std::string &param);
virtual int32_t end_graph_sampling() { // virtual int32_t end_graph_sampling() {
return cpu_graph_table->end_graph_sampling(); // return cpu_graph_table->end_graph_sampling();
} // }
int gpu_num;
private:
std::vector<GpuPsCommGraph> gpu_graph_list; std::vector<GpuPsCommGraph> gpu_graph_list;
std::vector<int *> sample_status;
const int parallel_sample_size = 1;
const int dim_y = 256;
std::shared_ptr<paddle::distributed::GraphTable> cpu_graph_table; std::shared_ptr<paddle::distributed::GraphTable> cpu_graph_table;
std::shared_ptr<pthread_rwlock_t> rw_lock; std::shared_ptr<pthread_rwlock_t> rw_lock;
mutable std::mutex mutex_; mutable std::mutex mutex_;
......
...@@ -13,23 +13,10 @@ ...@@ -13,23 +13,10 @@
// limitations under the License. // limitations under the License.
#pragma once #pragma once
#include <cuda_runtime.h>
#include <curand_kernel.h>
#include <thrust/copy.h>
#include <thrust/device_vector.h>
#include <thrust/host_vector.h>
#include <thrust/reduce.h>
#include <thrust/scan.h>
#include <thrust/transform.h>
#ifdef PADDLE_WITH_HETERPS #ifdef PADDLE_WITH_HETERPS
//#include "paddle/fluid/framework/fleet/heter_ps/graph_gpu_ps_table.h" //#include "paddle/fluid/framework/fleet/heter_ps/graph_gpu_ps_table.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
constexpr int WARP_SIZE = 32;
/* /*
comment 0 comment 0
this kernel just serves as an example of how to sample nodes' neighbors. this kernel just serves as an example of how to sample nodes' neighbors.
...@@ -42,116 +29,113 @@ sample_size; ...@@ -42,116 +29,113 @@ sample_size;
*/ */
struct MaxFunctor { __global__ void neighbor_sample_example(GpuPsCommGraph graph, int* node_index,
int sample_size; int* actual_size, int64_t* res,
HOSTDEVICE explicit inline MaxFunctor(int sample_size) { int sample_len, int* sample_status,
this->sample_size = sample_size; int n, int from) {
} // printf("%d %d %d\n",blockIdx.x,threadIdx.x,threadIdx.y);
HOSTDEVICE inline int operator()(int x) const { int id = blockIdx.x * blockDim.y + threadIdx.y;
if (x > sample_size) { if (id < n) {
return sample_size; curandState rng;
curand_init(blockIdx.x, threadIdx.x, threadIdx.y, &rng);
int index = threadIdx.x;
int offset = id * sample_len;
int64_t* data = graph.neighbor_list;
int data_offset = graph.node_list[node_index[id]].neighbor_offset;
int neighbor_len = graph.node_list[node_index[id]].neighbor_size;
int ac_len;
if (sample_len > neighbor_len)
ac_len = neighbor_len;
else {
ac_len = sample_len;
} }
return x; if (4 * ac_len >= 3 * neighbor_len) {
} if (index == 0) {
}; res[offset] = curand(&rng) % (neighbor_len - ac_len + 1);
struct DegreeFunctor {
GpuPsCommGraph graph;
HOSTDEVICE explicit inline DegreeFunctor(GpuPsCommGraph graph) {
this->graph = graph;
}
HOSTDEVICE inline int operator()(int i) const {
return graph.node_list[i].neighbor_size;
}
};
template <int BLOCK_WARPS, int TILE_SIZE>
__global__ void neighbor_sample(const uint64_t rand_seed, GpuPsCommGraph graph,
int sample_size, int* index, int len,
int64_t* sample_result, int* output_idx,
int* output_offset) {
assert(blockDim.x == WARP_SIZE);
assert(blockDim.y == BLOCK_WARPS);
int i = blockIdx.x * TILE_SIZE + threadIdx.y;
const int last_idx = min(static_cast<int>(blockIdx.x + 1) * TILE_SIZE, len);
curandState rng;
curand_init(rand_seed * gridDim.x + blockIdx.x,
threadIdx.y * WARP_SIZE + threadIdx.x, 0, &rng);
while (i < last_idx) {
auto node_index = index[i];
int degree = graph.node_list[node_index].neighbor_size;
const int offset = graph.node_list[node_index].neighbor_offset;
int output_start = output_offset[i];
if (degree <= sample_size) {
// Just copy
for (int j = threadIdx.x; j < degree; j += WARP_SIZE) {
sample_result[output_start + j] = graph.neighbor_list[offset + j];
}
} else {
for (int j = threadIdx.x; j < degree; j += WARP_SIZE) {
output_idx[output_start + j] = j;
} }
__syncwarp(); __syncwarp();
int start = res[offset];
for (int j = sample_size + threadIdx.x; j < degree; j += WARP_SIZE) { while (index < ac_len) {
const int num = curand(&rng) % (j + 1); res[offset + index] = data[data_offset + start + index];
if (num < sample_size) { index += blockDim.x;
atomicMax( }
reinterpret_cast<unsigned int*>(output_idx + output_start + num), actual_size[id] = ac_len;
static_cast<unsigned int>(j)); } else {
while (index < ac_len) {
int num = curand(&rng) % neighbor_len;
int* addr = sample_status + data_offset + num;
int expected = *addr;
if (!(expected & (1 << from))) {
int old = atomicCAS(addr, expected, expected | (1 << from));
if (old == expected) {
res[offset + index] = num;
index += blockDim.x;
}
} }
} }
__syncwarp(); __syncwarp();
index = threadIdx.x;
for (int j = threadIdx.x; j < sample_size; j += WARP_SIZE) { while (index < ac_len) {
const int perm_idx = output_idx[output_start + j] + offset; int* addr = sample_status + data_offset + res[offset + index];
sample_result[output_start + j] = graph.neighbor_list[perm_idx]; int expected, old = *addr;
do {
expected = old;
old = atomicCAS(addr, expected, expected & (~(1 << from)));
} while (old != expected);
res[offset + index] = data[data_offset + res[offset + index]];
index += blockDim.x;
} }
actual_size[id] = ac_len;
} }
i += BLOCK_WARPS;
} }
// const size_t i = blockIdx.x * blockDim.x + threadIdx.x;
// if (i < n) {
// auto node_index = index[i];
// actual_size[i] = graph.node_list[node_index].neighbor_size < sample_size
// ? graph.node_list[node_index].neighbor_size
// : sample_size;
// int offset = graph.node_list[node_index].neighbor_offset;
// for (int j = 0; j < actual_size[i]; j++) {
// sample_result[sample_size * i + j] = graph.neighbor_list[offset + j];
// }
// }
} }
int GpuPsGraphTable::init_cpu_table( int GpuPsGraphTable::init_cpu_table(
const paddle::distributed::GraphParameter& graph) { const paddle::distributed::GraphParameter& graph) {
cpu_graph_table.reset(new paddle::distributed::GraphTable); cpu_graph_table.reset(new paddle::distributed::GraphTable);
cpu_table_status = cpu_graph_table->initialize(graph); cpu_table_status = cpu_graph_table->Initialize(graph);
if (cpu_table_status != 0) return cpu_table_status; // if (cpu_table_status != 0) return cpu_table_status;
std::function<void(std::vector<GpuPsCommGraph>&)> callback = // std::function<void(std::vector<GpuPsCommGraph>&)> callback =
[this](std::vector<GpuPsCommGraph>& res) { // [this](std::vector<GpuPsCommGraph>& res) {
pthread_rwlock_wrlock(this->rw_lock.get()); // pthread_rwlock_wrlock(this->rw_lock.get());
this->clear_graph_info(); // this->clear_graph_info();
this->build_graph_from_cpu(res); // this->build_graph_from_cpu(res);
pthread_rwlock_unlock(this->rw_lock.get()); // pthread_rwlock_unlock(this->rw_lock.get());
cv_.notify_one(); // cv_.notify_one();
}; // };
cpu_graph_table->set_graph_sample_callback(callback); // cpu_graph_table->set_graph_sample_callback(callback);
return cpu_table_status; return cpu_table_status;
} }
int GpuPsGraphTable::load(const std::string& path, const std::string& param) { // int GpuPsGraphTable::load(const std::string& path, const std::string& param)
int status = cpu_graph_table->load(path, param); // {
if (status != 0) { // int status = cpu_graph_table->load(path, param);
return status; // if (status != 0) {
} // return status;
std::unique_lock<std::mutex> lock(mutex_); // }
cpu_graph_table->start_graph_sampling(); // std::unique_lock<std::mutex> lock(mutex_);
cv_.wait(lock); // cpu_graph_table->start_graph_sampling();
return 0; // cv_.wait(lock);
} // return 0;
// }
/* /*
comment 1 comment 1
gpu i triggers a neighbor_sample task, gpu i triggers a neighbor_sample task,
when this task is done, when this task is done,
this function is called to move the sample result on other gpu back this function is called to move the sample result on other gpu back
to gpu i and aggragate the result. to gup i and aggragate the result.
the sample_result is saved on src_sample_res and the actual sample size for the sample_result is saved on src_sample_res and the actual sample size for
each node is saved on actual_sample_size. each node is saved on actual_sample_size.
the number of actual sample_result for the number of actual sample_result for
...@@ -168,106 +152,163 @@ int GpuPsGraphTable::load(const std::string& path, const std::string& param) { ...@@ -168,106 +152,163 @@ int GpuPsGraphTable::load(const std::string& path, const std::string& param) {
that's what fill_dvals does. that's what fill_dvals does.
*/ */
void GpuPsGraphTable::move_neighbor_sample_size_to_source_gpu(
int gpu_id, int gpu_num, int* h_left, int* h_right, int* actual_sample_size, void GpuPsGraphTable::move_neighbor_sample_result_to_source_gpu(
int* total_sample_size) { int start_index, int gpu_num, int sample_size, int* h_left, int* h_right,
// This function copyed actual_sample_size to source_gpu, int64_t* src_sample_res, int* actual_sample_size) {
// and calculate total_sample_size of each gpu sample number. int shard_len[gpu_num];
for (int i = 0; i < gpu_num; i++) { for (int i = 0; i < gpu_num; i++) {
if (h_left[i] == -1 || h_right[i] == -1) { if (h_left[i] == -1 || h_right[i] == -1) {
continue; continue;
} }
auto shard_len = h_right[i] - h_left[i] + 1; shard_len[i] = h_right[i] - h_left[i] + 1;
auto& node = path_[gpu_id][i].nodes_.front(); int cur_step = path_[start_index][i].nodes_.size() - 1;
for (int j = cur_step; j > 0; j--) {
cudaMemcpyAsync(path_[start_index][i].nodes_[j - 1].val_storage,
path_[start_index][i].nodes_[j].val_storage,
path_[start_index][i].nodes_[j - 1].val_bytes_len,
cudaMemcpyDefault,
path_[start_index][i].nodes_[j - 1].out_stream);
}
auto& node = path_[start_index][i].nodes_.front();
cudaMemcpyAsync(
reinterpret_cast<char*>(src_sample_res + h_left[i] * sample_size),
node.val_storage + sizeof(int64_t) * shard_len[i],
node.val_bytes_len - sizeof(int64_t) * shard_len[i], cudaMemcpyDefault,
node.out_stream);
// resource_->remote_stream(i, start_index));
cudaMemcpyAsync(reinterpret_cast<char*>(actual_sample_size + h_left[i]), cudaMemcpyAsync(reinterpret_cast<char*>(actual_sample_size + h_left[i]),
node.val_storage + sizeof(int) * shard_len, node.val_storage + sizeof(int) * shard_len[i],
sizeof(int) * shard_len, cudaMemcpyDefault, sizeof(int) * shard_len[i], cudaMemcpyDefault,
node.out_stream); node.out_stream);
} }
for (int i = 0; i < gpu_num; ++i) { for (int i = 0; i < gpu_num; ++i) {
if (h_left[i] == -1 || h_right[i] == -1) { if (h_left[i] == -1 || h_right[i] == -1) {
total_sample_size[i] = 0;
continue; continue;
} }
auto& node = path_[gpu_id][i].nodes_.front(); auto& node = path_[start_index][i].nodes_.front();
cudaStreamSynchronize(node.out_stream); cudaStreamSynchronize(node.out_stream);
// cudaStreamSynchronize(resource_->remote_stream(i, start_index));
auto shard_len = h_right[i] - h_left[i] + 1;
thrust::device_vector<int> t_actual_sample_size(shard_len);
thrust::copy(actual_sample_size + h_left[i],
actual_sample_size + h_left[i] + shard_len,
t_actual_sample_size.begin());
total_sample_size[i] = thrust::reduce(t_actual_sample_size.begin(),
t_actual_sample_size.end());
} }
}
void GpuPsGraphTable::move_neighbor_sample_result_to_source_gpu(
int gpu_id, int gpu_num, int* h_left, int* h_right, int64_t* src_sample_res,
thrust::host_vector<int>& total_sample_size) {
/* /*
if total_sample_size is [4, 5, 1, 6], std::queue<CopyTask> que;
then cumsum_total_sample_size is [0, 4, 9, 10]; // auto& node = path_[gpu_id][i].nodes_.front();
*/ // cudaMemcpyAsync(
thrust::host_vector<int> cumsum_total_sample_size(gpu_num, 0); // reinterpret_cast<char*>(src_sample_res + h_left[i] * sample_size),
thrust::exclusive_scan(total_sample_size.begin(), total_sample_size.end(), // node.val_storage + sizeof(int64_t) * shard_len,
cumsum_total_sample_size.begin(), 0); // node.val_bytes_len - sizeof(int64_t) * shard_len, cudaMemcpyDefault,
for (int i = 0; i < gpu_num; i++) { // node.out_stream);
if (h_left[i] == -1 || h_right[i] == -1) { // cudaMemcpyAsync(reinterpret_cast<char*>(actual_sample_size + h_left[i]),
continue; // node.val_storage + sizeof(int) * shard_len,
// sizeof(int) * shard_len, cudaMemcpyDefault,
// node.out_stream);
int cur_step = path_[start_index][i].nodes_.size() - 1;
auto& node = path_[start_index][i].nodes_[cur_step];
if (cur_step == 0) {
// cudaMemcpyAsync(reinterpret_cast<char*>(src_val + h_left[i]),
// node.val_storage, node.val_bytes_len,
// cudaMemcpyDefault,
// node.out_stream);
// VLOG(0)<<"copy "<<node.gpu_num<<" to "<<start_index;
cudaMemcpyAsync(
reinterpret_cast<char*>(src_sample_res + h_left[i] * sample_size),
node.val_storage + sizeof(int64_t) * shard_len[i],
node.val_bytes_len - sizeof(int64_t) * shard_len[i],
cudaMemcpyDefault,
node.out_stream);
//resource_->remote_stream(i, start_index));
cudaMemcpyAsync(reinterpret_cast<char*>(actual_sample_size + h_left[i]),
node.val_storage + sizeof(int) * shard_len[i],
sizeof(int) * shard_len[i], cudaMemcpyDefault,
node.out_stream);
//resource_->remote_stream(i, start_index));
} else {
CopyTask t(&path_[start_index][i], cur_step - 1);
que.push(t);
// VLOG(0)<<"copy "<<node.gpu_num<<" to
"<<path_[start_index][i].nodes_[cur_step - 1].gpu_num;
cudaMemcpyAsync(path_[start_index][i].nodes_[cur_step - 1].val_storage,
node.val_storage,
path_[start_index][i].nodes_[cur_step - 1].val_bytes_len,
cudaMemcpyDefault,
path_[start_index][i].nodes_[cur_step - 1].out_stream);
//resource_->remote_stream(i, start_index));
}
}
while (!que.empty()) {
CopyTask& cur_task = que.front();
que.pop();
int cur_step = cur_task.step;
if (cur_task.path->nodes_[cur_step].sync) {
cudaStreamSynchronize(cur_task.path->nodes_[cur_step].out_stream);
//cudaStreamSynchronize(resource_->remote_stream(cur_task.path->nodes_.back().gpu_num,
start_index));
}
if (cur_step > 0) {
CopyTask c(cur_task.path, cur_step - 1);
que.push(c);
cudaMemcpyAsync(cur_task.path->nodes_[cur_step - 1].val_storage,
cur_task.path->nodes_[cur_step].val_storage,
cur_task.path->nodes_[cur_step - 1].val_bytes_len,
cudaMemcpyDefault,
cur_task.path->nodes_[cur_step - 1].out_stream);
//resource_->remote_stream(cur_task.path->nodes_.back().gpu_num,
start_index));
} else if (cur_step == 0) {
int end_index = cur_task.path->nodes_.back().gpu_num;
// cudaMemcpyAsync(reinterpret_cast<char*>(src_val + h_left[end_index]),
// cur_task.path->nodes_[cur_step].val_storage,
// cur_task.path->nodes_[cur_step].val_bytes_len,
// cudaMemcpyDefault,
// cur_task.path->nodes_[cur_step].out_stream);
//VLOG(0)<<"copy "<<cur_task.path->nodes_[cur_step].gpu_num<< " to
"<<start_index;
cudaMemcpyAsync(reinterpret_cast<char*>(src_sample_res +
h_left[end_index] * sample_size),
cur_task.path->nodes_[cur_step].val_storage +
sizeof(int64_t) * shard_len[end_index],
cur_task.path->nodes_[cur_step].val_bytes_len -
sizeof(int64_t) * shard_len[end_index],
cudaMemcpyDefault,
cur_task.path->nodes_[cur_step].out_stream);
//resource_->remote_stream(cur_task.path->nodes_.back().gpu_num,
start_index));
cudaMemcpyAsync(
reinterpret_cast<char*>(actual_sample_size + h_left[end_index]),
cur_task.path->nodes_[cur_step].val_storage +
sizeof(int) * shard_len[end_index],
sizeof(int) * shard_len[end_index], cudaMemcpyDefault,
cur_task.path->nodes_[cur_step].out_stream);
//resource_->remote_stream(cur_task.path->nodes_.back().gpu_num,
start_index));
} }
auto shard_len = h_right[i] - h_left[i] + 1;
// int cur_step = path_[gpu_id][i].nodes_.size() - 1;
// auto& node = path_[gpu_id][i].nodes_[cur_step];
auto& node = path_[gpu_id][i].nodes_.front();
cudaMemcpyAsync(
reinterpret_cast<char*>(src_sample_res + cumsum_total_sample_size[i]),
node.val_storage + sizeof(int64_t) * shard_len,
sizeof(int64_t) * total_sample_size[i], cudaMemcpyDefault,
node.out_stream);
} }
for (int i = 0; i < gpu_num; ++i) { for (int i = 0; i < gpu_num; ++i) {
if (h_left[i] == -1 || h_right[i] == -1) { if (h_left[i] == -1 || h_right[i] == -1) {
continue; continue;
} }
auto& node = path_[gpu_id][i].nodes_.front(); auto& node = path_[start_index][i].nodes_.front();
cudaStreamSynchronize(node.out_stream); cudaStreamSynchronize(node.out_stream);
//cudaStreamSynchronize(resource_->remote_stream(i, start_index));
} }
*/
} }
/* /*
TODO: TODO:
how to optimize it to eliminate the for loop how to optimize it to eliminate the for loop
*/ */
__global__ void fill_dvalues_actual_sample_size(int* d_shard_actual_sample_size, __global__ void fill_dvalues(int64_t* d_shard_vals, int64_t* d_vals,
int* d_actual_sample_size, int* d_shard_actual_sample_size,
int* idx, int len) { int* d_actual_sample_size, int* idx,
int sample_size, int len) {
const size_t i = blockIdx.x * blockDim.x + threadIdx.x; const size_t i = blockIdx.x * blockDim.x + threadIdx.x;
if (i < len) { if (i < len) {
d_actual_sample_size[idx[i]] = d_shard_actual_sample_size[i]; d_actual_sample_size[idx[i]] = d_shard_actual_sample_size[i];
} // d_vals[idx[i]] = d_shard_vals[i];
} for (int j = 0; j < sample_size; j++) {
d_vals[idx[i] * sample_size + j] = d_shard_vals[i * sample_size + j];
template <int BLOCK_WARPS, int TILE_SIZE>
__global__ void fill_dvalues_sample_result(int64_t* d_shard_vals,
int64_t* d_vals,
int* d_actual_sample_size, int* idx,
int* offset, int* d_offset,
int len) {
assert(blockDim.x == WARP_SIZE);
assert(blockDim.y == BLOCK_WARPS);
int i = blockIdx.x * TILE_SIZE + threadIdx.y;
const int last_idx = min(static_cast<int>(blockIdx.x + 1) * TILE_SIZE, len);
while (i < last_idx) {
const int sample_size = d_actual_sample_size[idx[i]];
for (int j = threadIdx.x; j < sample_size; j += WARP_SIZE) {
d_vals[offset[idx[i]] + j] = d_shard_vals[d_offset[i] + j];
} }
#ifdef PADDLE_WITH_CUDA
__syncwarp();
#endif
i += BLOCK_WARPS;
} }
} }
...@@ -307,6 +348,8 @@ gpu i saves the ith graph from cpu_graph_list ...@@ -307,6 +348,8 @@ gpu i saves the ith graph from cpu_graph_list
void GpuPsGraphTable::build_graph_from_cpu( void GpuPsGraphTable::build_graph_from_cpu(
std::vector<GpuPsCommGraph>& cpu_graph_list) { std::vector<GpuPsCommGraph>& cpu_graph_list) {
VLOG(0) << "in build_graph_from_cpu cpu_graph_list size = "
<< cpu_graph_list.size();
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
cpu_graph_list.size(), resource_->total_gpu(), cpu_graph_list.size(), resource_->total_gpu(),
platform::errors::InvalidArgument("the cpu node list size doesn't match " platform::errors::InvalidArgument("the cpu node list size doesn't match "
...@@ -314,7 +357,9 @@ void GpuPsGraphTable::build_graph_from_cpu( ...@@ -314,7 +357,9 @@ void GpuPsGraphTable::build_graph_from_cpu(
clear_graph_info(); clear_graph_info();
for (int i = 0; i < cpu_graph_list.size(); i++) { for (int i = 0; i < cpu_graph_list.size(); i++) {
platform::CUDADeviceGuard guard(resource_->dev_id(i)); platform::CUDADeviceGuard guard(resource_->dev_id(i));
// platform::CUDADeviceGuard guard(i);
gpu_graph_list.push_back(GpuPsCommGraph()); gpu_graph_list.push_back(GpuPsCommGraph());
sample_status.push_back(NULL);
auto table = auto table =
new Table(std::max(1, cpu_graph_list[i].node_size) / load_factor_); new Table(std::max(1, cpu_graph_list[i].node_size) / load_factor_);
tables_.push_back(table); tables_.push_back(table);
...@@ -337,6 +382,10 @@ void GpuPsGraphTable::build_graph_from_cpu( ...@@ -337,6 +382,10 @@ void GpuPsGraphTable::build_graph_from_cpu(
gpu_graph_list[i].node_size = 0; gpu_graph_list[i].node_size = 0;
} }
if (cpu_graph_list[i].neighbor_size) { if (cpu_graph_list[i].neighbor_size) {
int* addr;
cudaMalloc((void**)&addr, cpu_graph_list[i].neighbor_size * sizeof(int));
cudaMemset(addr, 0, cpu_graph_list[i].neighbor_size * sizeof(int));
sample_status[i] = addr;
cudaMalloc((void**)&gpu_graph_list[i].neighbor_list, cudaMalloc((void**)&gpu_graph_list[i].neighbor_list,
cpu_graph_list[i].neighbor_size * sizeof(int64_t)); cpu_graph_list[i].neighbor_size * sizeof(int64_t));
cudaMemcpy(gpu_graph_list[i].neighbor_list, cudaMemcpy(gpu_graph_list[i].neighbor_list,
...@@ -382,15 +431,19 @@ NeighborSampleResult* GpuPsGraphTable::graph_neighbor_sample(int gpu_id, ...@@ -382,15 +431,19 @@ NeighborSampleResult* GpuPsGraphTable::graph_neighbor_sample(int gpu_id,
*/ */
NeighborSampleResult* result = new NeighborSampleResult(sample_size, len); NeighborSampleResult* result =
new NeighborSampleResult(sample_size, len, resource_->dev_id(gpu_id));
if (len == 0) { if (len == 0) {
return result; return result;
} }
platform::CUDAPlace place = platform::CUDAPlace(resource_->dev_id(gpu_id));
platform::CUDADeviceGuard guard(resource_->dev_id(gpu_id));
// cudaMalloc((void**)&result->val, len * sample_size * sizeof(int64_t));
// cudaMalloc((void**)&result->actual_sample_size, len * sizeof(int));
int* actual_sample_size = result->actual_sample_size;
int64_t* val = result->val;
int total_gpu = resource_->total_gpu(); int total_gpu = resource_->total_gpu();
int dev_id = resource_->dev_id(gpu_id); // int dev_id = resource_->dev_id(gpu_id);
platform::CUDAPlace place = platform::CUDAPlace(dev_id);
platform::CUDADeviceGuard guard(dev_id);
auto stream = resource_->local_stream(gpu_id, 0); auto stream = resource_->local_stream(gpu_id, 0);
int grid_size = (len - 1) / block_size_ + 1; int grid_size = (len - 1) / block_size_ + 1;
...@@ -411,6 +464,11 @@ NeighborSampleResult* GpuPsGraphTable::graph_neighbor_sample(int gpu_id, ...@@ -411,6 +464,11 @@ NeighborSampleResult* GpuPsGraphTable::graph_neighbor_sample(int gpu_id,
auto d_shard_keys = memory::Alloc(place, len * sizeof(int64_t)); auto d_shard_keys = memory::Alloc(place, len * sizeof(int64_t));
int64_t* d_shard_keys_ptr = reinterpret_cast<int64_t*>(d_shard_keys->ptr()); int64_t* d_shard_keys_ptr = reinterpret_cast<int64_t*>(d_shard_keys->ptr());
auto d_shard_vals = memory::Alloc(place, sample_size * len * sizeof(int64_t));
int64_t* d_shard_vals_ptr = reinterpret_cast<int64_t*>(d_shard_vals->ptr());
auto d_shard_actual_sample_size = memory::Alloc(place, len * sizeof(int));
int* d_shard_actual_sample_size_ptr =
reinterpret_cast<int*>(d_shard_actual_sample_size->ptr());
split_input_to_shard(key, d_idx_ptr, len, d_left_ptr, d_right_ptr, gpu_id); split_input_to_shard(key, d_idx_ptr, len, d_left_ptr, d_right_ptr, gpu_id);
...@@ -423,7 +481,7 @@ NeighborSampleResult* GpuPsGraphTable::graph_neighbor_sample(int gpu_id, ...@@ -423,7 +481,7 @@ NeighborSampleResult* GpuPsGraphTable::graph_neighbor_sample(int gpu_id,
cudaMemcpyDeviceToHost); cudaMemcpyDeviceToHost);
cudaMemcpy(h_right, d_right_ptr, total_gpu * sizeof(int), cudaMemcpy(h_right, d_right_ptr, total_gpu * sizeof(int),
cudaMemcpyDeviceToHost); cudaMemcpyDeviceToHost);
// auto start1 = std::chrono::steady_clock::now();
for (int i = 0; i < total_gpu; ++i) { for (int i = 0; i < total_gpu; ++i) {
int shard_len = h_left[i] == -1 ? 0 : h_right[i] - h_left[i] + 1; int shard_len = h_left[i] == -1 ? 0 : h_right[i] - h_left[i] + 1;
if (shard_len == 0) { if (shard_len == 0) {
...@@ -450,138 +508,107 @@ NeighborSampleResult* GpuPsGraphTable::graph_neighbor_sample(int gpu_id, ...@@ -450,138 +508,107 @@ NeighborSampleResult* GpuPsGraphTable::graph_neighbor_sample(int gpu_id,
of alloc_mem_i, actual_sample_size_of_x equals ((int of alloc_mem_i, actual_sample_size_of_x equals ((int
*)alloc_mem_i)[shard_len + x] *)alloc_mem_i)[shard_len + x]
*/ */
create_storage(gpu_id, i, shard_len * sizeof(int64_t), create_storage(gpu_id, i, shard_len * sizeof(int64_t),
shard_len * (1 + sample_size) * sizeof(int64_t)); shard_len * (1 + sample_size) * sizeof(int64_t));
} }
// auto end1 = std::chrono::steady_clock::now();
// auto tt = std::chrono::duration_cast<std::chrono::microseconds>(end1 -
// start1);
// VLOG(0)<< "create storage time " << tt.count() << " us";
walk_to_dest(gpu_id, total_gpu, h_left, h_right, d_shard_keys_ptr, NULL); walk_to_dest(gpu_id, total_gpu, h_left, h_right, d_shard_keys_ptr, NULL);
for (int i = 0; i < total_gpu; ++i) { for (int i = 0; i < total_gpu; ++i) {
if (h_left[i] == -1) { if (h_left[i] == -1) {
continue; continue;
} }
// auto& node = path_[gpu_id][i].nodes_.back(); auto& node = path_[gpu_id][i].nodes_.back();
auto& node = path_[gpu_id][i].nodes_.front();
cudaStreamSynchronize(node.in_stream); cudaStreamSynchronize(node.in_stream);
platform::CUDADeviceGuard guard(resource_->dev_id(i)); platform::CUDADeviceGuard guard(resource_->dev_id(i));
// platform::CUDADeviceGuard guard(i);
// use the key-value map to update alloc_mem_i[0,shard_len) // use the key-value map to update alloc_mem_i[0,shard_len)
tables_[i]->rwlock_->RDLock(); // tables_[i]->rwlock_->RDLock();
tables_[i]->get(reinterpret_cast<int64_t*>(node.key_storage), tables_[i]->get(reinterpret_cast<int64_t*>(node.key_storage),
reinterpret_cast<int*>(node.val_storage), reinterpret_cast<int*>(node.val_storage),
h_right[i] - h_left[i] + 1, h_right[i] - h_left[i] + 1,
resource_->remote_stream(i, gpu_id)); resource_->remote_stream(i, gpu_id));
// node.in_stream);
auto shard_len = h_right[i] - h_left[i] + 1;
auto graph = gpu_graph_list[i];
int* id_array = reinterpret_cast<int*>(node.val_storage);
int* actual_size_array = id_array + shard_len;
int64_t* sample_array = (int64_t*)(id_array + shard_len * 2);
int sample_grid_size = (shard_len - 1) / dim_y + 1;
dim3 block(parallel_sample_size, dim_y);
dim3 grid(sample_grid_size);
// int sample_grid_size = shard_len / block_size_ + 1;
// VLOG(0)<<"in sample grid_size = "<<sample_grid_size<<" block_size
// ="<<block_size_<<" device = "<<resource_->dev_id(i)<<"len = "<<len;;
// neighbor_sample_example<<<sample_grid_size, block_size_, 0,
// resource_->remote_stream(i, gpu_id)>>>(
// graph, res_array, actual_size_array, sample_array, sample_size,
// shard_len);
neighbor_sample_example<<<grid, block, 0,
resource_->remote_stream(i, gpu_id)>>>(
graph, id_array, actual_size_array, sample_array, sample_size,
sample_status[i], shard_len, gpu_id);
} }
/*
for (int i = 0; i < total_gpu; ++i) { for (int i = 0; i < total_gpu; ++i) {
if (h_left[i] == -1) { if (h_left[i] == -1) {
continue; continue;
} }
// cudaStreamSynchronize(resource_->remote_stream(i, num)); // cudaStreamSynchronize(resource_->remote_stream(i, num));
// tables_[i]->rwlock_->UNLock(); // tables_[i]->rwlock_->UNLock();
platform::CUDADeviceGuard guard(resource_->dev_id(i)); platform::CUDADeviceGuard guard(i);
auto& node = path_[gpu_id][i].nodes_.front(); //platform::CUDADeviceGuard guard(resource_->dev_id(i));
auto& node = path_[gpu_id][i].nodes_.back();
auto shard_len = h_right[i] - h_left[i] + 1; auto shard_len = h_right[i] - h_left[i] + 1;
auto graph = gpu_graph_list[i]; auto graph = gpu_graph_list[i];
int* res_array = reinterpret_cast<int*>(node.val_storage); int* id_array = reinterpret_cast<int*>(node.val_storage);
int* actual_size_array = res_array + shard_len; int* actual_size_array = id_array + shard_len;
int64_t* sample_array = (int64_t*)(res_array + shard_len * 2); int64_t* sample_array = (int64_t*)(id_array + shard_len * 2);
int sample_grid_size = (shard_len - 1) / dim_y + 1;
// 1. get actual_size_array. dim3 block(parallel_sample_size, dim_y);
// 2. get sum of actual_size. dim3 grid(sample_grid_size);
// 3. get offset ptr // int sample_grid_size = shard_len / block_size_ + 1;
thrust::device_vector<int> t_res_array(shard_len); // VLOG(0)<<"in sample grid_size = "<<sample_grid_size<<" block_size
thrust::copy(res_array, res_array + shard_len, t_res_array.begin()); // ="<<block_size_<<" device = "<<resource_->dev_id(i)<<"len = "<<len;;
thrust::device_vector<int> t_actual_size_array(shard_len); // neighbor_sample_example<<<sample_grid_size, block_size_, 0,
thrust::transform(t_res_array.begin(), t_res_array.end(), // resource_->remote_stream(i, gpu_id)>>>(
t_actual_size_array.begin(), DegreeFunctor(graph)); // graph, res_array, actual_size_array, sample_array, sample_size,
// shard_len);
if (sample_size >= 0) { neighbor_sample_example<<<grid, block, 0,
thrust::transform(t_actual_size_array.begin(), t_actual_size_array.end(), resource_->remote_stream(i, gpu_id)>>>(
t_actual_size_array.begin(), MaxFunctor(sample_size)); graph, id_array, actual_size_array, sample_array, sample_size,
} sample_status[i], shard_len, gpu_id);
// neighbor_sample_example<<<grid, block, 0,
thrust::copy(t_actual_size_array.begin(), t_actual_size_array.end(), // node.in_stream>>>(
actual_size_array); // graph, id_array, actual_size_array, sample_array, sample_size,
// sample_status[i], shard_len, gpu_id);
int total_sample_sum =
thrust::reduce(t_actual_size_array.begin(), t_actual_size_array.end());
thrust::device_vector<int> output_idx(total_sample_sum);
thrust::device_vector<int> output_offset(shard_len);
thrust::exclusive_scan(t_actual_size_array.begin(),
t_actual_size_array.end(), output_offset.begin(), 0);
constexpr int BLOCK_WARPS = 128 / WARP_SIZE;
constexpr int TILE_SIZE = BLOCK_WARPS * 16;
const dim3 block_(WARP_SIZE, BLOCK_WARPS);
const dim3 grid_((shard_len + TILE_SIZE - 1) / TILE_SIZE);
neighbor_sample<
BLOCK_WARPS,
TILE_SIZE><<<grid_, block_, 0, resource_->remote_stream(i, gpu_id)>>>(
0, graph, sample_size, res_array, shard_len, sample_array,
thrust::raw_pointer_cast(output_idx.data()),
thrust::raw_pointer_cast(output_offset.data()));
} }
*/
for (int i = 0; i < total_gpu; ++i) { for (int i = 0; i < total_gpu; ++i) {
if (h_left[i] == -1) { if (h_left[i] == -1) {
continue; continue;
} }
// auto& node = path_[gpu_id][i].nodes_.back();
// cudaStreamSynchronize(node.in_stream);
cudaStreamSynchronize(resource_->remote_stream(i, gpu_id)); cudaStreamSynchronize(resource_->remote_stream(i, gpu_id));
tables_[i]->rwlock_->UNLock(); // tables_[i]->rwlock_->UNLock();
} }
// walk_to_src(num, total_gpu, h_left, h_right, d_shard_vals_ptr); // walk_to_src(num, total_gpu, h_left, h_right, d_shard_vals_ptr);
move_neighbor_sample_result_to_source_gpu(gpu_id, total_gpu, sample_size,
auto d_shard_actual_sample_size = memory::Alloc(place, len * sizeof(int)); h_left, h_right, d_shard_vals_ptr,
int* d_shard_actual_sample_size_ptr = d_shard_actual_sample_size_ptr);
reinterpret_cast<int*>(d_shard_actual_sample_size->ptr());
// Store total sample number of each gpu. fill_dvalues<<<grid_size, block_size_, 0, stream>>>(
thrust::host_vector<int> d_shard_total_sample_size(total_gpu, 0); d_shard_vals_ptr, val, d_shard_actual_sample_size_ptr, actual_sample_size,
move_neighbor_sample_size_to_source_gpu( d_idx_ptr, sample_size, len);
gpu_id, total_gpu, h_left, h_right, d_shard_actual_sample_size_ptr, // cudaStreamSynchronize(stream);
thrust::raw_pointer_cast(d_shard_total_sample_size.data())); // auto end2 = std::chrono::steady_clock::now();
int allocate_sample_num = 0; // tt = std::chrono::duration_cast<std::chrono::microseconds>(end2 - end1);
for (int i = 0; i < total_gpu; ++i) { // VLOG(0)<< "sample graph time " << tt.count() << " us";
allocate_sample_num += d_shard_total_sample_size[i];
}
auto d_shard_vals =
memory::Alloc(place, allocate_sample_num * sizeof(int64_t));
int64_t* d_shard_vals_ptr = reinterpret_cast<int64_t*>(d_shard_vals->ptr());
move_neighbor_sample_result_to_source_gpu(gpu_id, total_gpu, h_left, h_right,
d_shard_vals_ptr,
d_shard_total_sample_size);
cudaMalloc((void**)&result->val, allocate_sample_num * sizeof(int64_t));
cudaMalloc((void**)&result->actual_sample_size, len * sizeof(int));
cudaMalloc((void**)&result->offset, len * sizeof(int));
int64_t* val = result->val;
int* actual_sample_size = result->actual_sample_size;
int* offset = result->offset;
fill_dvalues_actual_sample_size<<<grid_size, block_size_, 0, stream>>>(
d_shard_actual_sample_size_ptr, actual_sample_size, d_idx_ptr, len);
thrust::device_vector<int> t_actual_sample_size(len);
thrust::copy(actual_sample_size, actual_sample_size + len,
t_actual_sample_size.begin());
thrust::exclusive_scan(t_actual_sample_size.begin(),
t_actual_sample_size.end(), offset, 0);
int* d_offset;
cudaMalloc(&d_offset, len * sizeof(int));
thrust::copy(d_shard_actual_sample_size_ptr,
d_shard_actual_sample_size_ptr + len,
t_actual_sample_size.begin());
thrust::exclusive_scan(t_actual_sample_size.begin(),
t_actual_sample_size.end(), d_offset, 0);
constexpr int BLOCK_WARPS_ = 128 / WARP_SIZE;
constexpr int TILE_SIZE_ = BLOCK_WARPS_ * 16;
const dim3 block__(WARP_SIZE, BLOCK_WARPS_);
const dim3 grid__((len + TILE_SIZE_ - 1) / TILE_SIZE_);
fill_dvalues_sample_result<BLOCK_WARPS_,
TILE_SIZE_><<<grid__, block__, 0, stream>>>(
d_shard_vals_ptr, val, actual_sample_size, d_idx_ptr, offset, d_offset,
len);
cudaStreamSynchronize(stream);
for (int i = 0; i < total_gpu; ++i) { for (int i = 0; i < total_gpu; ++i) {
int shard_len = h_left[i] == -1 ? 0 : h_right[i] - h_left[i] + 1; int shard_len = h_left[i] == -1 ? 0 : h_right[i] - h_left[i] + 1;
if (shard_len == 0) { if (shard_len == 0) {
...@@ -589,7 +616,7 @@ NeighborSampleResult* GpuPsGraphTable::graph_neighbor_sample(int gpu_id, ...@@ -589,7 +616,7 @@ NeighborSampleResult* GpuPsGraphTable::graph_neighbor_sample(int gpu_id,
} }
destroy_storage(gpu_id, i); destroy_storage(gpu_id, i);
} }
cudaFree(d_offset); cudaStreamSynchronize(stream);
return result; return result;
} }
...@@ -604,8 +631,9 @@ NodeQueryResult* GpuPsGraphTable::query_node_list(int gpu_id, int start, ...@@ -604,8 +631,9 @@ NodeQueryResult* GpuPsGraphTable::query_node_list(int gpu_id, int start,
actual_size = 0; actual_size = 0;
cudaMalloc((void**)&result->val, query_size * sizeof(int64_t)); cudaMalloc((void**)&result->val, query_size * sizeof(int64_t));
int64_t* val = result->val; int64_t* val = result->val;
int dev_id = resource_->dev_id(gpu_id); // int dev_id = resource_->dev_id(gpu_id);
platform::CUDADeviceGuard guard(dev_id); // platform::CUDADeviceGuard guard(dev_id);
platform::CUDADeviceGuard guard(resource_->dev_id(gpu_id));
std::vector<int> idx, gpu_begin_pos, local_begin_pos, sample_size; std::vector<int> idx, gpu_begin_pos, local_begin_pos, sample_size;
int size = 0; int size = 0;
/* /*
...@@ -647,6 +675,7 @@ NodeQueryResult* GpuPsGraphTable::query_node_list(int gpu_id, int start, ...@@ -647,6 +675,7 @@ NodeQueryResult* GpuPsGraphTable::query_node_list(int gpu_id, int start,
for (int i = 0; i < idx.size(); i++) { for (int i = 0; i < idx.size(); i++) {
int dev_id_i = resource_->dev_id(idx[i]); int dev_id_i = resource_->dev_id(idx[i]);
platform::CUDADeviceGuard guard(dev_id_i); platform::CUDADeviceGuard guard(dev_id_i);
// platform::CUDADeviceGuard guard(i);
auto& node = path_[gpu_id][idx[i]].nodes_.front(); auto& node = path_[gpu_id][idx[i]].nodes_.front();
int grid_size = (sample_size[i] - 1) / block_size_ + 1; int grid_size = (sample_size[i] - 1) / block_size_ + 1;
node_query_example<<<grid_size, block_size_, 0, node_query_example<<<grid_size, block_size_, 0,
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <time.h>
#include <algorithm>
#include <chrono>
#include <cstdlib>
#include <fstream>
#include <set>
#include <sstream>
#include <string>
#include <unordered_map>
#include <vector>
#include "paddle/fluid/distributed/ps/table/common_graph_table.h"
#include "paddle/fluid/framework/fleet/heter_ps/gpu_graph_node.h"
#include "paddle/fluid/framework/fleet/heter_ps/graph_gpu_ps_table.h"
#include "paddle/fluid/string/printf.h"
#include "paddle/fluid/string/string_helper.h"
#ifdef PADDLE_WITH_HETERPS
namespace paddle {
namespace framework {
enum GraphSamplerStatus { waiting = 0, running = 1, terminating = 2 };
class GraphSampler {
public:
GraphSampler() {
status = GraphSamplerStatus::waiting;
thread_pool.reset(new ::ThreadPool(1));
}
virtual int start_service(std::string path) {
load_from_ssd(path);
VLOG(0) << "load from ssd over";
std::promise<int> prom;
std::future<int> fut = prom.get_future();
graph_sample_task_over = thread_pool->enqueue([&prom, this]() {
VLOG(0) << " promise set ";
prom.set_value(0);
status = GraphSamplerStatus::running;
return run_graph_sampling();
});
return fut.get();
return 0;
}
virtual int end_graph_sampling() {
if (status == GraphSamplerStatus::running) {
status = GraphSamplerStatus::terminating;
return graph_sample_task_over.get();
}
return -1;
}
~GraphSampler() { end_graph_sampling(); }
virtual int load_from_ssd(std::string path) = 0;
;
virtual int run_graph_sampling() = 0;
;
virtual void init(GpuPsGraphTable *gpu_table,
std::vector<std::string> args_) = 0;
std::shared_ptr<::ThreadPool> thread_pool;
GraphSamplerStatus status;
std::future<int> graph_sample_task_over;
};
class CommonGraphSampler : public GraphSampler {
public:
CommonGraphSampler() {}
virtual ~CommonGraphSampler() {}
GpuPsGraphTable *g_table;
virtual int load_from_ssd(std::string path);
virtual int run_graph_sampling();
virtual void init(GpuPsGraphTable *g, std::vector<std::string> args);
GpuPsGraphTable *gpu_table;
paddle::distributed::GraphTable *table;
std::vector<int64_t> gpu_edges_count;
int64_t cpu_edges_count;
int64_t gpu_edges_limit, cpu_edges_limit, gpu_edges_each_limit;
std::vector<std::unordered_set<int64_t>> gpu_set;
int gpu_num;
};
class AllInGpuGraphSampler : public GraphSampler {
public:
AllInGpuGraphSampler() {}
virtual ~AllInGpuGraphSampler() {}
// virtual pthread_rwlock_t *export_rw_lock();
virtual int run_graph_sampling();
virtual int load_from_ssd(std::string path);
virtual void init(GpuPsGraphTable *g, std::vector<std::string> args_);
protected:
paddle::distributed::GraphTable *graph_table;
GpuPsGraphTable *gpu_table;
std::vector<std::vector<paddle::framework::GpuPsGraphNode>> sample_nodes;
std::vector<std::vector<int64_t>> sample_neighbors;
std::vector<GpuPsCommGraph> sample_res;
// std::shared_ptr<std::mt19937_64> random;
int gpu_num;
};
}
};
#include "paddle/fluid/framework/fleet/heter_ps/graph_sampler_inl.h"
#endif
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifdef PADDLE_WITH_HETERPS
namespace paddle {
namespace framework {
int CommonGraphSampler::load_from_ssd(std::string path) {
std::ifstream file(path);
auto _db = table->_db;
std::string line;
while (std::getline(file, line)) {
auto values = paddle::string::split_string<std::string>(line, "\t");
std::cout << values.size();
if (values.size() < 2) continue;
auto neighbors = paddle::string::split_string<std::string>(values[1], ";");
std::vector<int64_t> neighbor_data;
for (auto x : neighbors) {
neighbor_data.push_back(std::stoll(x));
}
auto src_id = std::stoll(values[0]);
_db->put(0, (char *)&src_id, sizeof(uint64_t), (char *)neighbor_data.data(),
sizeof(int64_t) * neighbor_data.size());
int gpu_shard = src_id % gpu_num;
if (gpu_edges_count[gpu_shard] + neighbor_data.size() <=
gpu_edges_each_limit) {
gpu_edges_count[gpu_shard] += neighbor_data.size();
gpu_set[gpu_shard].insert(src_id);
}
if (cpu_edges_count + neighbor_data.size() <= cpu_edges_limit) {
cpu_edges_count += neighbor_data.size();
for (auto x : neighbor_data) {
// table->add_neighbor(src_id, x);
table->shards[src_id % table->shard_num]
->add_graph_node(src_id)
->build_edges(false);
table->shards[src_id % table->shard_num]->add_neighbor(src_id, x, 1.0);
}
}
std::vector<paddle::framework::GpuPsCommGraph> graph_list;
for (int i = 0; i < gpu_num; i++) {
std::vector<int64_t> ids(gpu_set[i].begin(), gpu_set[i].end());
graph_list.push_back(table->make_gpu_ps_graph(ids));
}
gpu_table->build_graph_from_cpu(graph_list);
for (int i = 0; i < graph_list.size(); i++) {
delete[] graph_list[i].node_list;
delete[] graph_list[i].neighbor_list;
}
}
}
int CommonGraphSampler::run_graph_sampling() { return 0; }
void CommonGraphSampler::init(GpuPsGraphTable *g,
std::vector<std::string> args) {
this->gpu_table = g;
gpu_num = g->gpu_num;
gpu_edges_limit = args.size() > 0 ? std::stoll(args[0]) : 1000000000LL;
cpu_edges_limit = args.size() > 1 ? std::stoll(args[1]) : 1000000000LL;
gpu_edges_each_limit = gpu_edges_limit / gpu_num;
if (gpu_edges_each_limit > INT_MAX) gpu_edges_each_limit = INT_MAX;
table = g->cpu_graph_table.get();
gpu_edges_count = std::vector<int64_t>(gpu_num, 0);
cpu_edges_count = 0;
gpu_set = std::vector<std::unordered_set<int64_t>>(gpu_num);
}
int AllInGpuGraphSampler::run_graph_sampling() { return 0; }
int AllInGpuGraphSampler::load_from_ssd(std::string path) {
graph_table->load_edges(path, false);
sample_nodes.clear();
sample_neighbors.clear();
sample_res.clear();
sample_nodes.resize(gpu_num);
sample_neighbors.resize(gpu_num);
sample_res.resize(gpu_num);
std::vector<std::vector<std::vector<paddle::framework::GpuPsGraphNode>>>
sample_nodes_ex(graph_table->task_pool_size_);
std::vector<std::vector<std::vector<int64_t>>> sample_neighbors_ex(
graph_table->task_pool_size_);
for (int i = 0; i < graph_table->task_pool_size_; i++) {
sample_nodes_ex[i].resize(gpu_num);
sample_neighbors_ex[i].resize(gpu_num);
}
std::vector<std::future<int>> tasks;
for (size_t i = 0; i < graph_table->shards.size(); ++i) {
tasks.push_back(
graph_table->_shards_task_pool[i % graph_table->task_pool_size_]
->enqueue([&, i, this]() -> int {
if (this->status == GraphSamplerStatus::terminating) return 0;
paddle::framework::GpuPsGraphNode node;
std::vector<paddle::distributed::Node *> &v =
this->graph_table->shards[i]->get_bucket();
size_t ind = i % this->graph_table->task_pool_size_;
for (size_t j = 0; j < v.size(); j++) {
size_t location = v[j]->get_id() % this->gpu_num;
node.node_id = v[j]->get_id();
node.neighbor_size = v[j]->get_neighbor_size();
node.neighbor_offset =
(int)sample_neighbors_ex[ind][location].size();
sample_nodes_ex[ind][location].emplace_back(node);
for (int k = 0; k < node.neighbor_size; k++)
sample_neighbors_ex[ind][location].push_back(
v[j]->get_neighbor_id(k));
}
return 0;
}));
}
for (size_t i = 0; i < tasks.size(); i++) tasks[i].get();
tasks.clear();
for (size_t i = 0; i < gpu_num; i++) {
tasks.push_back(
graph_table->_shards_task_pool[i % graph_table->task_pool_size_]
->enqueue([&, i, this]() -> int {
if (this->status == GraphSamplerStatus::terminating) return 0;
int total_offset = 0;
size_t ind = i;
for (int j = 0; j < this->graph_table->task_pool_size_; j++) {
for (size_t k = 0; k < sample_nodes_ex[j][ind].size(); k++) {
sample_nodes[ind].push_back(sample_nodes_ex[j][ind][k]);
sample_nodes[ind].back().neighbor_offset += total_offset;
}
size_t neighbor_size = sample_neighbors_ex[j][ind].size();
total_offset += neighbor_size;
for (size_t k = 0; k < neighbor_size; k++) {
sample_neighbors[ind].push_back(
sample_neighbors_ex[j][ind][k]);
}
}
return 0;
}));
}
for (size_t i = 0; i < tasks.size(); i++) tasks[i].get();
for (size_t i = 0; i < gpu_num; i++) {
sample_res[i].node_list = sample_nodes[i].data();
sample_res[i].neighbor_list = sample_neighbors[i].data();
sample_res[i].node_size = sample_nodes[i].size();
sample_res[i].neighbor_size = sample_neighbors[i].size();
}
gpu_table->build_graph_from_cpu(sample_res);
return 0;
}
void AllInGpuGraphSampler::init(GpuPsGraphTable *g,
std::vector<std::string> args_) {
this->gpu_table = g;
this->gpu_num = g->gpu_num;
graph_table = g->cpu_graph_table.get();
}
}
};
#endif
...@@ -210,11 +210,11 @@ class HeterComm { ...@@ -210,11 +210,11 @@ class HeterComm {
std::vector<std::vector<Path>> path_; std::vector<std::vector<Path>> path_;
float load_factor_{0.75}; float load_factor_{0.75};
int block_size_{256}; int block_size_{256};
int topo_aware_{0};
private: private:
std::unique_ptr<HeterCommKernel> heter_comm_kernel_; std::unique_ptr<HeterCommKernel> heter_comm_kernel_;
std::vector<LocalStorage> storage_; std::vector<LocalStorage> storage_;
int topo_aware_{0};
int feanum_{1800 * 2048}; int feanum_{1800 * 2048};
int multi_node_{0}; int multi_node_{0};
int node_size_; int node_size_;
......
...@@ -66,7 +66,6 @@ TEST(TEST_FLEET, graph_sample) { ...@@ -66,7 +66,6 @@ TEST(TEST_FLEET, graph_sample) {
1,4,7 1,4,7
gpu 2: gpu 2:
2,5,8 2,5,8
query(2,6) returns nodes [6,9,1,4,7,2] query(2,6) returns nodes [6,9,1,4,7,2]
*/ */
::paddle::distributed::GraphParameter table_proto; ::paddle::distributed::GraphParameter table_proto;
......
...@@ -40,6 +40,7 @@ ...@@ -40,6 +40,7 @@
#include "paddle/fluid/framework/fleet/heter_ps/feature_value.h" #include "paddle/fluid/framework/fleet/heter_ps/feature_value.h"
#include "paddle/fluid/framework/fleet/heter_ps/graph_gpu_ps_table.h" #include "paddle/fluid/framework/fleet/heter_ps/graph_gpu_ps_table.h"
#include "paddle/fluid/framework/fleet/heter_ps/graph_sampler.h"
#include "paddle/fluid/framework/fleet/heter_ps/heter_comm.h" #include "paddle/fluid/framework/fleet/heter_ps/heter_comm.h"
#include "paddle/fluid/framework/fleet/heter_ps/heter_resource.h" #include "paddle/fluid/framework/fleet/heter_ps/heter_resource.h"
#include "paddle/fluid/framework/fleet/heter_ps/optimizer.cuh.h" #include "paddle/fluid/framework/fleet/heter_ps/optimizer.cuh.h"
...@@ -52,9 +53,13 @@ namespace memory = paddle::memory; ...@@ -52,9 +53,13 @@ namespace memory = paddle::memory;
namespace distributed = paddle::distributed; namespace distributed = paddle::distributed;
std::string input_file; std::string input_file;
int fixed_key_size = 100, sample_size = 100, int exe_count = 100;
int use_nv = 1;
int fixed_key_size = 50000, sample_size = 32,
bfs_sample_nodes_in_each_shard = 10000, init_search_size = 1, bfs_sample_nodes_in_each_shard = 10000, init_search_size = 1,
bfs_sample_edges = 20; bfs_sample_edges = 20, gpu_num1 = 8, gpu_num = 8;
std::string gpu_str = "0,1,2,3,4,5,6,7";
int64_t *key[8];
std::vector<std::string> edges = { std::vector<std::string> edges = {
std::string("37\t45\t0.34"), std::string("37\t145\t0.31"), std::string("37\t45\t0.34"), std::string("37\t145\t0.31"),
std::string("37\t112\t0.21"), std::string("96\t48\t1.4"), std::string("37\t112\t0.21"), std::string("96\t48\t1.4"),
...@@ -83,7 +88,7 @@ void testSampleRate() { ...@@ -83,7 +88,7 @@ void testSampleRate() {
pthread_rwlock_init(&rwlock, NULL); pthread_rwlock_init(&rwlock, NULL);
{ {
::paddle::distributed::GraphParameter table_proto; ::paddle::distributed::GraphParameter table_proto;
table_proto.set_gpups_mode(false); // table_proto.set_gpups_mode(false);
table_proto.set_shard_num(127); table_proto.set_shard_num(127);
table_proto.set_task_pool_size(24); table_proto.set_task_pool_size(24);
std::cerr << "initializing begin"; std::cerr << "initializing begin";
...@@ -163,25 +168,48 @@ void testSampleRate() { ...@@ -163,25 +168,48 @@ void testSampleRate() {
std::chrono::duration_cast<std::chrono::microseconds>(end1 - start1); std::chrono::duration_cast<std::chrono::microseconds>(end1 - start1);
std::cerr << "total time cost without cache is " << tt.count() << " us" std::cerr << "total time cost without cache is " << tt.count() << " us"
<< std::endl; << std::endl;
int64_t tot = 0;
for (int i = 0; i < 10; i++) {
for (auto x : sample_id[i]) tot += x;
}
VLOG(0) << "sum = " << tot;
} }
const int gpu_num = 8; gpu_num = 0;
::paddle::distributed::GraphParameter table_proto; int st = 0, u = 0;
table_proto.set_gpups_mode(true); std::vector<int> device_id_mapping;
table_proto.set_shard_num(127); while (u < gpu_str.size()) {
table_proto.set_gpu_num(gpu_num); VLOG(0) << u << " " << gpu_str[u];
table_proto.set_gpups_graph_sample_class("BasicBfsGraphSampler"); if (gpu_str[u] == ',') {
table_proto.set_gpups_graph_sample_args(std::to_string(init_search_size) + auto p = gpu_str.substr(st, u - st);
",100000000,10000000,1,1"); int id = std::stoi(p);
std::vector<int> dev_ids; VLOG(0) << "got a new device id" << id;
for (int i = 0; i < gpu_num; i++) { device_id_mapping.push_back(id);
dev_ids.push_back(i); st = u + 1;
}
u++;
} }
auto p = gpu_str.substr(st, gpu_str.size() - st);
int id = std::stoi(p);
VLOG(0) << "got a new device id" << id;
device_id_mapping.push_back(id);
gpu_num = device_id_mapping.size();
::paddle::distributed::GraphParameter table_proto;
table_proto.set_shard_num(24);
// table_proto.set_gpups_graph_sample_class("CompleteGraphSampler");
std::shared_ptr<HeterPsResource> resource = std::shared_ptr<HeterPsResource> resource =
std::make_shared<HeterPsResource>(dev_ids); std::make_shared<HeterPsResource>(device_id_mapping);
resource->enable_p2p(); resource->enable_p2p();
GpuPsGraphTable g(resource); GpuPsGraphTable g(resource, use_nv);
g.init_cpu_table(table_proto); g.init_cpu_table(table_proto);
g.load(std::string(input_file), std::string("e>")); std::vector<std::string> arg;
AllInGpuGraphSampler sampler;
sampler.init(&g, arg);
// g.load(std::string(input_file), std::string("e>"));
// sampler.start(std::string(input_file));
// sampler.load_from_ssd(std::string(input_file));
sampler.start_service(input_file);
/*
NodeQueryResult *query_node_res; NodeQueryResult *query_node_res;
query_node_res = g.query_node_list(0, 0, ids.size() + 10000); query_node_res = g.query_node_list(0, 0, ids.size() + 10000);
...@@ -209,52 +237,65 @@ void testSampleRate() { ...@@ -209,52 +237,65 @@ void testSampleRate() {
auto q = g.query_node_list(0, st, ids.size() / 20); auto q = g.query_node_list(0, st, ids.size() / 20);
VLOG(0) << " the " << i << "th iteration size = " << q->actual_sample_size; VLOG(0) << " the " << i << "th iteration size = " << q->actual_sample_size;
} }
// NodeQueryResult *query_node_list(int gpu_id, int start, int query_size); // NodeQueryResult *query_node_list(int gpu_id, int start, int query_size);
*/
/* for (int i = 0; i < gpu_num1; i++) {
void *key; platform::CUDADeviceGuard guard(device_id_mapping[i]);
cudaMalloc((void **)&key[i], ids.size() * sizeof(int64_t));
cudaMemcpy(key[i], ids.data(), ids.size() * sizeof(int64_t),
cudaMemcpyHostToDevice);
}
/*
cudaMalloc((void **)&key, ids.size() * sizeof(int64_t)); cudaMalloc((void **)&key, ids.size() * sizeof(int64_t));
cudaMemcpy(key, ids.data(), ids.size() * sizeof(int64_t), cudaMemcpy(key, ids.data(), ids.size() * sizeof(int64_t),
cudaMemcpyHostToDevice); cudaMemcpyHostToDevice);
std::vector<NeighborSampleResult *> res[gpu_num]; */
/*
std::vector<std::vector<NeighborSampleResult *>> res(gpu_num1);
for (int i = 0; i < gpu_num1; i++) {
int st = 0;
int size = ids.size();
NeighborSampleResult *result = new NeighborSampleResult(sample_size, size);
platform::CUDAPlace place = platform::CUDAPlace(device_id_mapping[i]);
platform::CUDADeviceGuard guard(device_id_mapping[i]);
cudaMalloc((void **)&result->val, size * sample_size * sizeof(int64_t));
cudaMalloc((void **)&result->actual_sample_size, size * sizeof(int));
res[i].push_back(result);
}
*/
start = 0; start = 0;
auto func = [&rwlock, &g, &res, &start, auto func = [&rwlock, &g, &start, &ids](int i) {
&gpu_num, &ids, &key](int i) { int st = 0;
while (true) { int size = ids.size();
int s, sn; for (int k = 0; k < exe_count; k++) {
bool exit = false; st = 0;
pthread_rwlock_wrlock(&rwlock); while (st < size) {
if (start < ids.size()) { int len = std::min(fixed_key_size, (int)ids.size() - st);
s = start; auto r = g.graph_neighbor_sample(i, (int64_t *)(key[i] + st),
sn = ids.size() - start; sample_size, len);
sn = min(sn, fixed_key_size); st += len;
start += sn; delete r;
} else {
exit = true;
} }
pthread_rwlock_unlock(&rwlock);
if (exit) break;
auto r =
g.graph_neighbor_sample(i, (int64_t *)(key + s), sample_size, sn);
res[i].push_back(r);
} }
}; };
auto start1 = std::chrono::steady_clock::now(); auto start1 = std::chrono::steady_clock::now();
std::thread thr[gpu_num]; std::thread thr[gpu_num1];
for (int i = 0; i < gpu_num; i++) { for (int i = 0; i < gpu_num1; i++) {
thr[i] = std::thread(func, i); thr[i] = std::thread(func, i);
} }
for (int i = 0; i < gpu_num; i++) thr[i].join(); for (int i = 0; i < gpu_num1; i++) thr[i].join();
auto end1 = std::chrono::steady_clock::now(); auto end1 = std::chrono::steady_clock::now();
auto tt = auto tt =
std::chrono::duration_cast<std::chrono::microseconds>(end1 - start1); std::chrono::duration_cast<std::chrono::microseconds>(end1 - start1);
std::cerr << "total time cost without cache is " << tt.count() << " us" std::cerr << "total time cost without cache is "
<< std::endl; << tt.count() / exe_count / gpu_num1 << " us" << std::endl;
*/ for (int i = 0; i < gpu_num1; i++) {
cudaFree(key[i]);
}
#endif #endif
} }
// TEST(testSampleRate, Run) { testSampleRate(); } TEST(TEST_FLEET, sample_rate) { testSampleRate(); }
int main(int argc, char *argv[]) { int main(int argc, char *argv[]) {
for (int i = 0; i < argc; i++) for (int i = 0; i < argc; i++)
...@@ -276,5 +317,14 @@ int main(int argc, char *argv[]) { ...@@ -276,5 +317,14 @@ int main(int argc, char *argv[]) {
VLOG(0) << "sample_size neighbor_size is " << sample_size; VLOG(0) << "sample_size neighbor_size is " << sample_size;
if (argc > 4) init_search_size = std::stoi(argv[4]); if (argc > 4) init_search_size = std::stoi(argv[4]);
VLOG(0) << " init_search_size " << init_search_size; VLOG(0) << " init_search_size " << init_search_size;
if (argc > 5) {
gpu_str = argv[5];
}
VLOG(0) << " gpu_str= " << gpu_str;
gpu_num = 0;
if (argc > 6) gpu_num1 = std::stoi(argv[6]);
VLOG(0) << " gpu_thread_num= " << gpu_num1;
if (argc > 7) use_nv = std::stoi(argv[7]);
VLOG(0) << " use_nv " << use_nv;
testSampleRate(); testSampleRate();
} }
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册