Unverified commit caf2008b, authored by zmxdream, committed by GitHub

[Pglbox] merge gpugraph to develop (#50091)

* add dump_walk_path (#193)

* add dump_walk_path; test=develop

* add dump_walk_path; test=develop

* add dump_walk_path; test=develop

* Add multi-CPU communication, parameter query, and merging functions; support batch alignment across multiple cards (#194)

* compatible with edge_type of src2dst and src2etype2dst (#195)

* do not merge_feature_shard when using metapath_split_opt (#198)

* support only load reverse_edge (#199)

* refactor GraphTable (#201)

* fix

* fix

* fix code style

* fix code style

* fix test_dataset

* fix hogwild worker

* fix code style

* fix code style

* fix code style

* fix code style

* fix code style.

* fix code style.

---------
Co-authored-by: danleifeng <52735331+danleifeng@users.noreply.github.com>
Co-authored-by: qingshui <qshuihu@gmail.com>
Co-authored-by: Webbley <liwb5@foxmail.com>
Co-authored-by: huwei02 <53012141+huwei02@users.noreply.github.com>
Parent 5a13280a
@@ -49,6 +49,10 @@ brpc_library(
get_property(RPC_DEPS GLOBAL PROPERTY RPC_DEPS)
proto_library(simple_brpc_proto SRCS simple_brpc.proto)
set_source_files_properties(
simple_rpc/rpc_server.cc simple_rpc/baidu_rpc_server.cc
PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties(
communicator/communicator.cc PROPERTIES COMPILE_FLAGS
${DISTRIBUTE_COMPILE_FLAGS})
@@ -60,6 +64,8 @@ set_source_files_properties(
brpc_ps_client.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties(
ps_local_client.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties(
ps_graph_client.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties(
brpc_utils.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
@@ -85,11 +91,17 @@ set_source_files_properties(
set_source_files_properties(
ps_service/graph_py_service.cc PROPERTIES COMPILE_FLAGS
${DISTRIBUTE_COMPILE_FLAGS})
cc_library(
brpc_utils
SRCS brpc_utils.cc
DEPS tensor device_context ${COMMON_DEPS} ${RPC_DEPS})
cc_library(
simple_rpc
SRCS simple_rpc/rpc_server.cc simple_rpc/baidu_rpc_server.cc
DEPS simple_brpc_proto ${RPC_DEPS})
cc_library(
ps_service
SRCS graph_brpc_server.cc
@@ -98,6 +110,7 @@ cc_library(
graph_brpc_client.cc
brpc_ps_client.cc
ps_local_client.cc
ps_graph_client.cc
coordinator_client.cc
ps_client.cc
communicator/communicator.cc
@@ -107,11 +120,42 @@ cc_library(
table
brpc_utils
simple_threadpool
simple_rpc
scope
math_function
selected_rows_functor
ps_gpu_wrapper
${RPC_DEPS})
#cc_library(
# downpour_server
# SRCS graph_brpc_server.cc brpc_ps_server.cc
# DEPS eigen3 table brpc_utils simple_threadpool ${RPC_DEPS})
#cc_library(
# downpour_client
# SRCS graph_brpc_client.cc brpc_ps_client.cc ps_local_client.cc
# ps_graph_client.cc coordinator_client.cc
# DEPS eigen3 table brpc_utils simple_threadpool ps_gpu_wrapper simple_rpc ${RPC_DEPS})
#cc_library(
# client
# SRCS ps_client.cc
# DEPS downpour_client ${RPC_DEPS})
#cc_library(
# server
# SRCS server.cc
# DEPS downpour_server ${RPC_DEPS})
#cc_library(
# communicator
# SRCS communicator/communicator.cc
# DEPS scope client table math_function selected_rows_functor ${RPC_DEPS})
#cc_library(
# ps_service
# SRCS ps_service/service.cc
# DEPS communicator client server ${RPC_DEPS})
cc_library(
heter_client
SRCS heter_client.cc
@@ -120,3 +164,8 @@ cc_library(
heter_server
SRCS heter_server.cc
DEPS heter_client brpc_utils ${COMMON_DEPS} ${RPC_DEPS})
#cc_library(
# graph_py_service
# SRCS ps_service/graph_py_service.cc
# DEPS ps_service)
@@ -126,9 +126,9 @@ int32_t GraphBrpcService::clear_nodes(Table *table,
const PsRequestMessage &request,
PsResponseMessage &response,
brpc::Controller *cntl) {
int type_id = std::stoi(request.params(0).c_str());
int idx_ = std::stoi(request.params(1).c_str());
(reinterpret_cast<GraphTable *>(table))->clear_nodes(type_id, idx_);
GraphTableType type_id = *(GraphTableType *)(request.params(0).c_str());
int idx_ = *(int *)(request.params(1).c_str());
((GraphTable *)table)->clear_nodes(type_id, idx_);
return 0;
}
@@ -380,11 +380,11 @@ int32_t GraphBrpcService::pull_graph_list(Table *table,
response, -1, "pull_graph_list request requires at least 5 arguments");
return 0;
}
int type_id = std::stoi(request.params(0).c_str());
int idx = std::stoi(request.params(1).c_str());
int start = std::stoi(request.params(2).c_str());
int size = std::stoi(request.params(3).c_str());
int step = std::stoi(request.params(4).c_str());
GraphTableType type_id = *(GraphTableType *)(request.params(0).c_str());
int idx = *(int *)(request.params(1).c_str());
int start = *(int *)(request.params(2).c_str());
int size = *(int *)(request.params(3).c_str());
int step = *(int *)(request.params(4).c_str());
std::unique_ptr<char[]> buffer;
int actual_size;
(reinterpret_cast<GraphTable *>(table))
@@ -432,9 +432,9 @@ int32_t GraphBrpcService::graph_random_sample_nodes(
const PsRequestMessage &request,
PsResponseMessage &response,
brpc::Controller *cntl) {
int type_id = std::stoi(request.params(0).c_str());
int idx_ = std::stoi(request.params(1).c_str());
size_t size = std::stoull(request.params(2).c_str());
GraphTableType type_id = *(GraphTableType *)(request.params(0).c_str());
int idx_ = *(int *)(request.params(1).c_str());
size_t size = *(uint64_t *)(request.params(2).c_str());
// size_t size = *(int64_t *)(request.params(0).c_str());
std::unique_ptr<char[]> buffer;
int actual_size;
......
@@ -20,6 +20,10 @@
#include "paddle/fluid/distributed/ps/service/graph_brpc_client.h"
#include "paddle/fluid/distributed/ps/service/ps_local_client.h"
#include "paddle/fluid/distributed/ps/table/table.h"
#if defined(PADDLE_WITH_GLOO) && defined(PADDLE_WITH_GPU_GRAPH)
#include "paddle/fluid/distributed/ps/service/ps_graph_client.h"
#include "paddle/fluid/framework/fleet/gloo_wrapper.h"
#endif
namespace paddle {
namespace distributed {
@@ -27,6 +31,9 @@ REGISTER_PSCORE_CLASS(PSClient, BrpcPsClient);
REGISTER_PSCORE_CLASS(PSClient, PsLocalClient);
REGISTER_PSCORE_CLASS(PSClient, GraphBrpcClient);
REGISTER_PSCORE_CLASS(PSClient, CoordinatorClient);
#if defined(PADDLE_WITH_GLOO) && defined(PADDLE_WITH_GPU_GRAPH)
REGISTER_PSCORE_CLASS(PSClient, PsGraphClient);
#endif
int32_t PSClient::Configure( // called in FleetWrapper::InitWorker
const PSParameter &config,
@@ -77,8 +84,20 @@ PSClient *PSClientFactory::Create(const PSParameter &ps_config) {
}
const auto &service_param = config.downpour_server_param().service_param();
PSClient *client =
CREATE_PSCORE_CLASS(PSClient, service_param.client_class());
const auto &client_name = service_param.client_class();
PSClient *client = NULL;
#if defined(PADDLE_WITH_GLOO) && defined(PADDLE_WITH_GPU_GRAPH)
auto gloo = paddle::framework::GlooWrapper::GetInstance();
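// Explanatory note (not in the original commit): PsLocalClient can only serve
// keys held by the local rank, so with more than one rank the factory swaps in
// PsGraphClient, which pulls remote keys over the simple RPC layer.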
if (client_name == "PsLocalClient" && gloo->Size() > 1) {
client = CREATE_PSCORE_CLASS(PSClient, "PsGraphClient");
LOG(WARNING) << "change PsLocalClient to PsGraphClient";
} else {
client = CREATE_PSCORE_CLASS(PSClient, client_name);
}
#else
client = CREATE_PSCORE_CLASS(PSClient, client_name);
#endif
if (client == NULL) {
LOG(ERROR) << "client is not registered, server_name:"
<< service_param.client_class();
......
@@ -24,6 +24,7 @@
#include "paddle/fluid/distributed/common/cost_timer.h"
#include "paddle/fluid/distributed/ps/service/env.h"
#include "paddle/fluid/distributed/ps/service/sendrecv.pb.h"
#include "paddle/fluid/distributed/ps/service/sparse_shard_value.h"
#include "paddle/fluid/distributed/ps/table/accessor.h"
#include "paddle/fluid/distributed/ps/table/graph/graph_node.h"
#include "paddle/fluid/distributed/the_one_ps.pb.h"
@@ -72,7 +73,7 @@ class PSClient {
const std::map<uint64_t, std::vector<paddle::distributed::Region>>
&regions,
PSEnvironment &_env, // NOLINT
size_t client_id) final;
size_t client_id);
virtual int32_t CreateClient2ClientConnection(int pserver_timeout_ms,
int pserver_connect_timeout_ms,
@@ -153,7 +154,8 @@ class PSClient {
size_t table_id,
const uint64_t *keys,
size_t num,
uint16_t pass_id) {
uint16_t pass_id,
const uint16_t &dim_id = 0) {
VLOG(0) << "Did not implement";
std::promise<int32_t> promise;
std::future<int> fut = promise.get_future();
@@ -329,6 +331,12 @@ class PSClient {
promise.set_value(-1);
return fut;
}
// Take pass-level sparse values pulled on behalf of remote ranks (see PsGraphClient).
virtual std::shared_ptr<SparseShardValues> TakePassSparseReferedValues(
const size_t &table_id, const uint16_t &pass_id, const uint16_t &dim_id) {
VLOG(0) << "Did not implement";
return nullptr;
}
protected:
virtual int32_t Initialize() = 0;
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#if defined(PADDLE_WITH_GLOO) && defined(PADDLE_WITH_GPU_GRAPH)
#include "paddle/fluid/distributed/ps/service/ps_graph_client.h"
#include "paddle/fluid/distributed/ps/service/simple_rpc/rpc_server.h"
#include "paddle/fluid/distributed/ps/table/table.h"
#include "paddle/fluid/framework/fleet/gloo_wrapper.h"
#include "paddle/fluid/framework/fleet/ps_gpu_wrapper.h"
// #include "paddle/fluid/framework/threadpool.h"
namespace paddle {
namespace distributed {
PsGraphClient::PsGraphClient() {
simple::global_rpc_server().initialize();
auto gloo = paddle::framework::GlooWrapper::GetInstance();
_rank_id = gloo->Rank();
_rank_num = gloo->Size();
_service = simple::global_rpc_server().add_service(
[this](const simple::RpcMessageHead &head,
paddle::framework::BinaryArchive &iar) {
request_handler(head, iar);
});
}
PsGraphClient::~PsGraphClient() {}
int32_t PsGraphClient::Initialize() {
const auto &downpour_param = _config.server_param().downpour_server_param();
uint32_t max_shard_num = 0;
for (int i = 0; i < downpour_param.downpour_table_param_size(); ++i) {
auto &param = downpour_param.downpour_table_param(i);
uint32_t table_id = param.table_id();
uint32_t shard_num = param.shard_num();
_table_info[table_id] = std::make_shared<SparseTableInfo>();
_table_info[table_id]->shard_num = shard_num;
if (max_shard_num < shard_num) {
max_shard_num = shard_num;
}
}
for (uint32_t k = 0; k < max_shard_num; ++k) {
_thread_pools.push_back(std::make_shared<paddle::framework::ThreadPool>(1));
}
_local_shard_keys.resize(max_shard_num);
_shard_ars.resize(max_shard_num);
return PsLocalClient::Initialize();
}
void PsGraphClient::FinalizeWorker() {
if (_service != nullptr) {
simple::global_rpc_server().remove_service(_service);
_service = nullptr;
fprintf(stdout, "FinalizeWorker remove rpc service");
}
simple::global_rpc_server().finalize();
}
// Macros packing (dim_id, pass_id) into a single 32-bit id.
#define DIM_PASS_ID(dim_id, pass_id) \
uint32_t((uint32_t(dim_id) << 16) | pass_id)
#define GET_PASS_ID(id) (id & 0xffff)
#define GET_DIM_ID(id) ((id >> 16) & 0xffff)
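// Illustrative note (not in the original commit): dim_id occupies the high 16
// bits and pass_id the low 16, e.g. DIM_PASS_ID(3, 7) == 0x00030007,
// GET_PASS_ID(0x00030007) == 7, GET_DIM_ID(0x00030007) == 3.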
::std::future<int32_t> PsGraphClient::PullSparsePtr(int shard_id,
char **select_values,
size_t table_id,
const uint64_t *keys,
size_t num,
uint16_t pass_id,
const uint16_t &dim_id) {
platform::Timer timeline;
timeline.Start();
// ps_gpu_wrapper
auto ps_wrapper = paddle::framework::PSGPUWrapper::GetInstance();
std::vector<uint64_t> &local_keys = _local_shard_keys[shard_id];
local_keys.clear();
auto &ars = _shard_ars[shard_id];
ars.resize(_rank_num);
for (int rank = 0; rank < _rank_num; ++rank) {
ars[rank].Clear();
}
// split keys by owning rank
for (size_t i = 0; i < num; ++i) {
auto &k = keys[i];
int rank = ps_wrapper->PartitionKeyForRank(k);
if (rank == _rank_id) {
local_keys.push_back(k);
} else {
ars[rank].PutRaw(k);
}
}
paddle::framework::WaitGroup wg;
wg.add(_rank_num);
uint32_t id = DIM_PASS_ID(dim_id, pass_id);
// send to remote
for (int rank = 0; rank < _rank_num; ++rank) {
if (rank == _rank_id) {
wg.done();
continue;
}
auto &ar = ars[rank];
size_t n = ar.Length() / sizeof(uint64_t);
ar.PutRaw(n);
ar.PutRaw(shard_id);
ar.PutRaw(id);
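// Resulting wire layout (keys first, fixed-size trailer last):
//   [uint64 keys...][size_t n][int shard_id][uint32 id]
// request_handler() below ReadBacks the trailer in reverse order (id,
// shard_id, n) before reading the keys.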
simple::global_rpc_server().send_request_consumer(
rank,
table_id,
_service,
ar,
[this, &wg](const simple::RpcMessageHead & /**head*/,
framework::BinaryArchive & /**ar*/) { wg.done(); });
}
// pull values for local keys, if any
if (!local_keys.empty()) {
auto f = _thread_pools[shard_id]->Run(
[this, table_id, pass_id, shard_id, &local_keys, &select_values](void) {
// local pull values
Table *table_ptr = GetTable(table_id);
TableContext table_context;
table_context.value_type = Sparse;
table_context.pull_context.keys = &local_keys[0];
table_context.pull_context.ptr_values = select_values;
table_context.use_ptr = true;
table_context.num = local_keys.size();
table_context.shard_id = shard_id;
table_context.pass_id = pass_id;
table_ptr->Pull(table_context);
});
f.get();
}
wg.wait();
timeline.Pause();
VLOG(3) << "PullSparsePtr local table id=" << table_id
<< ", pass id=" << pass_id << ", shard_id=" << shard_id
<< ", dim_id=" << dim_id << ", keys count=" << num
<< ", span=" << timeline.ElapsedSec();
return done();
}
// server side: pull values for keys received from remote ranks
void PsGraphClient::request_handler(const simple::RpcMessageHead &head,
paddle::framework::BinaryArchive &iar) {
size_t table_id = head.consumer_id;
uint32_t id = 0;
iar.ReadBack(&id, sizeof(uint32_t));
int shard_id = 0;
iar.ReadBack(&shard_id, sizeof(int));
size_t num = 0;
iar.ReadBack(&num, sizeof(size_t));
SparsePassValues *pass_refered = nullptr;
SparseTableInfo &info = get_table_info(table_id);
info.pass_mutex.lock();
auto it = info.refered_feas.find(id);
if (it == info.refered_feas.end()) {
pass_refered = new SparsePassValues;
pass_refered->wg.clear();
int total_ref = info.shard_num * (_rank_num - 1);
pass_refered->wg.add(total_ref);
pass_refered->values = new SparseShardValues;
pass_refered->values->resize(info.shard_num);
info.refered_feas[id].reset(pass_refered);
VLOG(0) << "add request_handler table id=" << table_id
<< ", pass id=" << GET_PASS_ID(id) << ", shard_id=" << shard_id
<< ", total_ref=" << total_ref;
} else {
pass_refered = it->second.get();
}
auto &shard_values = (*pass_refered->values)[shard_id];
size_t shard_size = shard_values.keys.size();
shard_values.offsets.push_back(shard_size);
if (num > 0) {
shard_values.keys.resize(num + shard_size);
iar.Read(&shard_values.keys[shard_size], num * sizeof(uint64_t));
shard_values.values.resize(num + shard_size);
}
info.pass_mutex.unlock();
if (num > 0) {
auto f = _thread_pools[shard_id]->Run(
[this, table_id, id, shard_id, num, shard_size, pass_refered](void) {
platform::Timer timeline;
timeline.Start();
auto &shard_values = (*pass_refered->values)[shard_id];
auto *table_ptr = GetTable(table_id);
TableContext table_context;
table_context.value_type = Sparse;
table_context.pull_context.keys = &shard_values.keys[shard_size];
table_context.pull_context.ptr_values =
&shard_values.values[shard_size];
table_context.use_ptr = true;
table_context.num = num;
table_context.shard_id = shard_id;
table_context.pass_id = GET_PASS_ID(id);
table_ptr->Pull(table_context);
timeline.Pause();
VLOG(3) << "end pull remote table id=" << table_id
<< ", pass id=" << GET_PASS_ID(id)
<< ", shard_id=" << shard_id << ", keys count=" << num
<< ", span=" << timeline.ElapsedSec();
// notify done
pass_refered->wg.done();
});
} else {
// nothing to pull for this batch; mark it done
pass_refered->wg.done();
}
// send response
paddle::framework::BinaryArchive oar;
simple::global_rpc_server().send_response(head, oar);
}
// look up per-table info (shard num, refered feature map)
PsGraphClient::SparseTableInfo &PsGraphClient::get_table_info(
const size_t &table_id) {
return (*_table_info[table_id].get());
}
// take the keys/values kept for a pass; blocks until all remote pulls finish
std::shared_ptr<SparseShardValues> PsGraphClient::TakePassSparseReferedValues(
const size_t &table_id, const uint16_t &pass_id, const uint16_t &dim_id) {
SparseTableInfo &info = get_table_info(table_id);
uint32_t id = DIM_PASS_ID(dim_id, pass_id);
SparsePassValues *pass_refered = nullptr;
info.pass_mutex.lock();
auto it = info.refered_feas.find(id);
if (it == info.refered_feas.end()) {
info.pass_mutex.unlock();
VLOG(0) << "table_id=" << table_id
<< ", TakePassSparseReferedValues pass_id=" << pass_id
<< ", dim_id=" << dim_id << " is nullptr";
return nullptr;
}
pass_refered = it->second.get();
info.pass_mutex.unlock();
int cnt = pass_refered->wg.count();
VLOG(0) << "table_id=" << table_id
<< ", begin TakePassSparseReferedValues pass_id=" << pass_id
<< ", dim_id=" << dim_id << " wait count=" << cnt;
pass_refered->wg.wait();
std::shared_ptr<SparseShardValues> shard_ptr;
shard_ptr.reset(pass_refered->values);
pass_refered->values = nullptr;
info.pass_mutex.lock();
info.refered_feas.erase(id);
info.pass_mutex.unlock();
return shard_ptr;
}
} // namespace distributed
} // namespace paddle
#endif
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#if defined(PADDLE_WITH_GLOO) && defined(PADDLE_WITH_GPU_GRAPH)
#pragma once
#include "paddle/fluid/distributed/ps/service/ps_local_client.h"
#include "paddle/fluid/framework/archive.h"
#include "paddle/fluid/framework/barrier.h"
#include "paddle/fluid/framework/threadpool.h"
namespace paddle {
// namespace framework {
// class ThreadPool;
// };
namespace distributed {
namespace simple {
struct RpcMessageHead;
};
struct SparsePassValues {
paddle::framework::WaitGroup wg;
SparseShardValues *values;
};
class PsGraphClient : public PsLocalClient {
typedef std::unordered_map<uint32_t, std::shared_ptr<SparsePassValues>>
SparseFeasReferedMap;
struct SparseTableInfo {
uint32_t shard_num;
std::mutex pass_mutex;
SparseFeasReferedMap refered_feas;
};
public:
PsGraphClient();
virtual ~PsGraphClient();
virtual int32_t Initialize();
virtual void FinalizeWorker();
virtual ::std::future<int32_t> PullSparsePtr(int shard_id,
char **select_values,
size_t table_id,
const uint64_t *keys,
size_t num,
uint16_t pass_id,
const uint16_t &dim_id = 0);
virtual std::shared_ptr<SparseShardValues> TakePassSparseReferedValues(
const size_t &table_id, const uint16_t &pass_id, const uint16_t &dim_id);
public:
void request_handler(const simple::RpcMessageHead &head,
paddle::framework::BinaryArchive &iar); // NOLINT
SparseTableInfo &get_table_info(const size_t &table_id);
private:
std::map<uint32_t, std::shared_ptr<SparseTableInfo>> _table_info;
void *_service = nullptr;
int _rank_id = 0;
int _rank_num = 0;
std::vector<std::shared_ptr<framework::ThreadPool>> _thread_pools;
std::vector<std::vector<uint64_t>> _local_shard_keys;
std::vector<std::vector<paddle::framework::BinaryArchive>> _shard_ars;
};
} // namespace distributed
} // namespace paddle
#endif
@@ -13,11 +13,8 @@
// limitations under the License.
#include "paddle/fluid/distributed/ps/service/ps_local_client.h"
#include "paddle/fluid/distributed/ps/table/table.h"
// #define pslib_debug_dense_compress
namespace paddle {
namespace distributed {
int32_t PsLocalClient::Initialize() {
@@ -36,13 +33,11 @@ int32_t PsLocalClient::Initialize() {
::std::future<int32_t> PsLocalClient::Shrink(uint32_t table_id,
const std::string threshold) {
// TODO // NOLINT
return done();
}
::std::future<int32_t> PsLocalClient::Load(const std::string& epoch,
const std::string& mode) {
// TODO // NOLINT
for (auto& it : _table_map) {
Load(it.first, epoch, mode);
}
@@ -51,7 +46,6 @@ int32_t PsLocalClient::Initialize() {
::std::future<int32_t> PsLocalClient::Load(uint32_t table_id,
const std::string& epoch,
const std::string& mode) {
// TODO // NOLINT
auto* table_ptr = GetTable(table_id);
table_ptr->Load(epoch, mode);
return done();
@@ -59,7 +53,6 @@ int32_t PsLocalClient::Initialize() {
::std::future<int32_t> PsLocalClient::Save(const std::string& epoch,
const std::string& mode) {
// TODO // NOLINT
for (auto& it : _table_map) {
Save(it.first, epoch, mode);
}
@@ -68,19 +61,14 @@ int32_t PsLocalClient::Initialize() {
::std::future<int32_t> PsLocalClient::Save(uint32_t table_id,
const std::string& epoch,
const std::string& mode) {
// TODO // NOLINT
auto* table_ptr = GetTable(table_id);
table_ptr->Flush();
table_ptr->Save(epoch, mode);
return done();
}
::std::future<int32_t> PsLocalClient::Clear() {
// TODO // NOLINT
return done();
}
::std::future<int32_t> PsLocalClient::Clear() { return done(); }
::std::future<int32_t> PsLocalClient::Clear(uint32_t table_id) {
// TODO // NOLINT
return done();
}
@@ -234,42 +222,14 @@ int32_t PsLocalClient::Initialize() {
return done();
}
// ::std::future<int32_t> PsLocalClient::PullSparse(float** select_values,
// size_t table_id,
// const uint64_t* keys,
// size_t num) {
// // FIXME
// // auto timer =
// // std::make_shared<CostTimer>("pslib_downpour_client_pull_sparse");
// // auto local_timer =
// // std::make_shared<CostTimer>("pslib_downpour_client_pull_sparse_local");
// // Split keys into per-shard requests and record the original value pointers
// auto* accessor = GetTableAccessor(table_id);
// auto* table_ptr = GetTable(table_id);
// size_t value_size = accessor->select_size();
//
// // table_ptr->PullSparse(keys, num);
// std::vector<float> res_data;
// res_data.resize(num * value_size / sizeof(float));
// table_ptr->PullSparse(res_data.data(), keys, num);
// // memcpy(select_values[0], res_data->data(), res_data->size() *
// // sizeof(float));
// size_t offset = 0;
// for (int i = 0; i < num; ++i) {
// memcpy(select_values[i], (char*)res_data.data() + offset, value_size);
// offset += value_size;
// }
//
// // return fut;
// return done();
//}
::std::future<int32_t> PsLocalClient::PullSparsePtr(int shard_id,
char** select_values,
size_t table_id,
const uint64_t* keys,
size_t num,
uint16_t pass_id) {
::std::future<int32_t> PsLocalClient::PullSparsePtr(
int shard_id,
char** select_values,
size_t table_id,
const uint64_t* keys,
size_t num,
uint16_t pass_id,
const uint16_t& /**dim_id*/) {
// FIXME
// auto timer =
// std::make_shared<CostTimer>("pslib_downpour_client_pull_sparse");
......
@@ -32,26 +32,26 @@ class PsLocalClient : public PSClient {
return 0;
}
::std::future<int32_t> Shrink(uint32_t table_id,
const std::string threshold) override;
::std::future<int32_t> Load(const std::string& epoch,
const std::string& mode) override;
::std::future<int32_t> Load(uint32_t table_id,
const std::string& epoch,
const std::string& mode) override;
::std::future<int32_t> Save(const std::string& epoch,
const std::string& mode) override;
::std::future<int32_t> Save(uint32_t table_id,
const std::string& epoch,
const std::string& mode) override;
::std::future<int32_t> Clear() override;
::std::future<int32_t> Clear(uint32_t table_id) override;
::std::future<int32_t> StopServer() override;
void FinalizeWorker() override {}
virtual ::std::future<int32_t> Shrink(uint32_t table_id,
const std::string threshold);
virtual ::std::future<int32_t> Load(const std::string& epoch,
const std::string& mode);
virtual ::std::future<int32_t> Load(uint32_t table_id,
const std::string& epoch,
const std::string& mode);
virtual ::std::future<int32_t> Save(const std::string& epoch,
const std::string& mode);
virtual ::std::future<int32_t> Save(uint32_t table_id,
const std::string& epoch,
const std::string& mode);
virtual ::std::future<int32_t> Clear();
virtual ::std::future<int32_t> Clear(uint32_t table_id);
virtual ::std::future<int32_t> StopServer();
virtual void FinalizeWorker() {}
virtual ::std::future<int32_t> PullDense(Region* regions,
size_t region_num,
size_t table_id);
@@ -76,12 +76,13 @@ class PsLocalClient : public PSClient {
return fut;
}
virtual ::std::future<int32_t> PullSparsePtr(int shard_id,
virtual ::std::future<int32_t> PullSparsePtr(const int shard_id,
char** select_values,
size_t table_id,
const uint64_t* keys,
size_t num,
uint16_t pass_id);
uint16_t pass_id,
const uint16_t& dim_id = 0);
virtual ::std::future<int32_t> PrintTableStat(uint32_t table_id);
@@ -147,9 +148,9 @@ class PsLocalClient : public PSClient {
return 0;
}
::std::future<int32_t> SendClient2ClientMsg(int msg_type,
int to_client_id,
const std::string& msg) override {
virtual ::std::future<int32_t> SendClient2ClientMsg(int msg_type,
int to_client_id,
const std::string& msg) {
std::promise<int32_t> prom;
std::future<int32_t> fut = prom.get_future();
prom.set_value(0);
@@ -158,23 +159,25 @@ class PsLocalClient : public PSClient {
}
virtual size_t GetServerNums() { return 1; }
std::future<int32_t> PushDenseRawGradient(int table_id,
float* total_send_data,
size_t total_send_data_size,
void* callback) override;
std::future<int32_t> PushSparseRawGradient(size_t table_id,
const uint64_t* keys,
const float** update_values,
size_t num,
void* callback) override;
std::future<int32_t> PushSparseRawGradientPartial(size_t table_id,
const uint64_t* keys,
const float** update_values,
uint32_t num,
void* done,
int pserver_idx) override {
virtual std::future<int32_t> PushDenseRawGradient(int table_id,
float* total_send_data,
size_t total_send_data_size,
void* callback);
virtual std::future<int32_t> PushSparseRawGradient(
size_t table_id,
const uint64_t* keys,
const float** update_values,
size_t num,
void* callback);
virtual std::future<int32_t> PushSparseRawGradientPartial(
size_t table_id,
const uint64_t* keys,
const float** update_values,
uint32_t num,
void* done,
int pserver_idx) {
std::promise<int32_t> prom;
std::future<int32_t> fut = prom.get_future();
prom.set_value(0);
@@ -182,11 +185,11 @@ class PsLocalClient : public PSClient {
return fut;
}
std::future<int32_t> PushSparseParam(size_t table_id,
const uint64_t* keys,
const float** update_values,
size_t num,
void* done) override {
virtual std::future<int32_t> PushSparseParam(size_t table_id,
const uint64_t* keys,
const float** update_values,
size_t num,
void* done) {
std::promise<int32_t> prom;
std::future<int32_t> fut = prom.get_future();
prom.set_value(0);
@@ -194,8 +197,8 @@ class PsLocalClient : public PSClient {
return fut;
}
private:
int32_t Initialize() override;
protected:
virtual int32_t Initialize();
std::future<int32_t> done() {
std::shared_ptr<std::promise<int32_t>> prom =
......
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
syntax = "proto2";
package paddle.distributed.simple;
option cc_generic_services = true;
message SimpleRpcRequest {
required int64 archive_size = 1;
};
message SimpleRpcResponse {
required int64 archive_size = 1;
};
service SimpleRpcService {
rpc handle_request(SimpleRpcRequest) returns (SimpleRpcResponse);
};
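// Note (added for clarity): the serialized BinaryArchive travels in the rpc
// attachment; this proto message only carries its byte size (see
// BaiduRpcServer::send_message and BRpcServiceImpl::handle_request).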
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#if defined(PADDLE_WITH_GLOO) && defined(PADDLE_WITH_GPU_GRAPH)
#include "paddle/fluid/distributed/ps/service/simple_rpc/baidu_rpc_server.h"
#include <brpc/channel.h>
#include <brpc/server.h>
#include "paddle/fluid/framework/fleet/gloo_wrapper.h"
#include "paddle/phi/core/enforce.h"
namespace brpc {
DECLARE_uint64(max_body_size);
DECLARE_int64(socket_max_unwritten_bytes);
} // namespace brpc
namespace paddle {
namespace distributed {
namespace simple {
static const int MIN_SERVER_LISTEN_PORT = 20000;
static const int MAX_SERVER_LISTEN_PORT = 65535;
static const int64_t MAX_RPC_BODY_SIZE = 10 * 1024 * 1024 * 1024L;
class BRpcReqService : public RpcService {
public:
BRpcReqService(RpcCallback callback, bool simplex)
: RpcService(callback), _simplex(simplex) {}
void set_handler(brpc::Controller *cntl,
google::protobuf::Closure *done,
SimpleRpcResponse *response) {
_cntl = cntl;
_response = response;
_done = done;
}
bool is_simplex(void) { return _simplex; }
butil::IOBuf &response_attachment(void) {
return _cntl->response_attachment();
}
void done(int64_t size) {
_response->set_archive_size(size);
_done->Run();
}
private:
bool _simplex = true;
brpc::Controller *_cntl = nullptr;
SimpleRpcResponse *_response = nullptr;
google::protobuf::Closure *_done = nullptr;
};
/**
* @brief Service request handler
*/
class BRpcServiceImpl : public SimpleRpcService {
public:
explicit BRpcServiceImpl(int rank_id) : _rank_id(rank_id) {}
virtual ~BRpcServiceImpl() {}
virtual void handle_request(google::protobuf::RpcController *cntl_base,
const SimpleRpcRequest *baidu_rpc_request,
SimpleRpcResponse *baidu_rpc_response,
google::protobuf::Closure *done) {
brpc::Controller *cntl = static_cast<brpc::Controller *>(cntl_base);
uint64_t size = baidu_rpc_request->archive_size();
butil::IOBuf &attach = cntl->request_attachment();
BinaryArchive iar;
iar.Reserve(size);
uint64_t attach_size = attach.cutn(iar.Buffer(), size);
PADDLE_ENFORCE_EQ(
(attach_size == size),
true,
phi::errors::PreconditionNotMet("Request size is wrong."));
iar.AdvanceFinish(size);
RpcMessageHead head;
iar.ReadBack(&head, sizeof(RpcMessageHead));
if (head.message_type == RpcMessageHead::REQUEST) {
PADDLE_ENFORCE_EQ(
(head.server_id == _rank_id),
true,
phi::errors::PreconditionNotMet(
"Server id %d not equal rank id %d.", head.server_id, _rank_id));
BRpcReqService *service =
reinterpret_cast<BRpcReqService *>(head.service);
service->set_handler(cntl, done, baidu_rpc_response);
service->callback()(head, iar);
// One-way client->server traffic only: acknowledge immediately with size 0
if (service->is_simplex()) {
baidu_rpc_response->set_archive_size(0);
done->Run();
}
return;
}
if (head.message_type == RpcMessageHead::RESPONSE) {
PADDLE_ENFORCE_EQ(
(head.client_id == _rank_id),
true,
phi::errors::PreconditionNotMet(
"Client id %d not equal rank id %d.", head.client_id, _rank_id));
head.request->callback()(head, iar);
delete head.request;
PADDLE_ENFORCE_NE(
head.service,
0,
phi::errors::PreconditionNotMet("Service should not be nullptr."));
head.service->decrease_request();
} else {
LOG(FATAL) << "Unknown message type";
}
baidu_rpc_response->set_archive_size(0);
done->Run();
}
private:
int _rank_id = 0;
};
BaiduRpcServer::BaiduRpcServer() : RpcServer(), _server(nullptr) {
/** This RPC path mainly carries pull-sparse and data-shuffle traffic, which
* is large: a single pass can hold hundreds of millions of keys, so a single
* packet may exceed 1 GB. Raise baidu-rpc's maximum sendable body size.
*/
if (brpc::FLAGS_max_body_size < MAX_RPC_BODY_SIZE) {
brpc::FLAGS_max_body_size = MAX_RPC_BODY_SIZE;
}
if (brpc::FLAGS_socket_max_unwritten_bytes < MAX_RPC_BODY_SIZE) {
brpc::FLAGS_socket_max_unwritten_bytes = MAX_RPC_BODY_SIZE;
}
_server.reset(new brpc::Server);
_ref = 0;
}
BaiduRpcServer::~BaiduRpcServer() {}
/**
* @brief Initialize the service
*/
void BaiduRpcServer::initialize() {
if (++_ref > 1) {
LOG(WARNING) << "already initialize rpc server";
return;
}
PADDLE_ENFORCE_NE(
_gloo, NULL, phi::errors::PreconditionNotMet("Gloo not allow nullptr."));
_gloo->Barrier();
_server->set_version(google::VersionString());
brpc::ServerOptions option;
option.idle_timeout_sec = _connection_idle_timeout_sec;
option.auth = nullptr;
option.num_threads = _thread_num;
_service_impl = std::make_shared<BRpcServiceImpl>(_gloo->Rank());
int ret =
_server->AddService(_service_impl.get(), brpc::SERVER_DOESNT_OWN_SERVICE);
PADDLE_ENFORCE_EQ(
(ret == 0),
true,
phi::errors::PreconditionNotMet("Failed to add BRpcServiceImpl."));
brpc::PortRange range(MIN_SERVER_LISTEN_PORT, MAX_SERVER_LISTEN_PORT);
auto server_ip = butil::ip2str(butil::int2ip(_ips[_gloo->Rank()]));
ret = _server->Start(server_ip.c_str(), range, &option);
PADDLE_ENFORCE_EQ(
(ret == 0),
true,
phi::errors::PreconditionNotMet("Fail to start BaiduRpcServer."));
butil::EndPoint ep = _server->listen_address();
std::vector<int> ports = _gloo->AllGather(ep.port);
auto new_channel = [this, &ports](int i) {
brpc::Channel *channel_ptr = new brpc::Channel();
brpc::ChannelOptions option;
option.connection_type = _connection_type;
option.auth = nullptr;
option.timeout_ms = _client_timeout_ms;
option.connect_timeout_ms = _connect_timeout_ms;
option.max_retry = _max_retry;
butil::EndPoint cep;
cep.ip = butil::int2ip(_ips[i]);
cep.port = ports[i];
if (channel_ptr->Init(cep, &option) != 0) {
LOG(FATAL) << "Failed to initialize channel";
}
LOG(INFO) << "connected to " << butil::endpoint2str(cep).c_str();
return channel_ptr;
};
for (int i = 0; i < _gloo->Size(); i++) {
_senders.emplace_back(new SimpleRpcService_Stub(
new_channel(i), google::protobuf::Service::STUB_OWNS_CHANNEL));
}
_gloo->Barrier();
LOG(WARNING) << "initialize rpc server : " << butil::endpoint2str(ep).c_str();
}
/**
* @brief Stop the service
*/
void BaiduRpcServer::finalize() {
if (--_ref > 0) {
LOG(WARNING) << "finalize running by other";
return;
}
_gloo->Barrier();
_server->Stop(60000);
_server->Join();
_gloo->Barrier();
LOG(INFO) << "finalize rpc server";
}
/**
* @brief Handle the response sent back to the client
*/
static void handle_baidu_rpc_response(brpc::Controller *cntl,
SimpleRpcResponse *baidu_rpc_response) {
size_t size = baidu_rpc_response->archive_size();
if (size > 0) {
BinaryArchive iar;
iar.Reserve(size);
size_t attach_size = cntl->response_attachment().cutn(iar.Buffer(), size);
PADDLE_ENFORCE_EQ(
(attach_size == size),
true,
phi::errors::PreconditionNotMet("Request size is wrong."));
iar.AdvanceFinish(size);
RpcMessageHead head;
iar.ReadBack(&head, sizeof(RpcMessageHead));
if (head.message_type == RpcMessageHead::RESPONSE) {
head.request->callback()(head, iar);
delete head.request;
PADDLE_ENFORCE_NE(
head.service,
0,
phi::errors::PreconditionNotMet("Service should not be nullptr."));
head.service->decrease_request();
} else {
LOG(FATAL) << "Unknown message type";
}
}
delete baidu_rpc_response;
delete cntl;
}
void BaiduRpcServer::send_request(int server_id,
void *service_,
const size_t n,
BinaryArchive *oars,
RpcCallback callback) {
send_request_ex(server_id, 0, service_, n, oars, callback);
}
void BaiduRpcServer::send_request_ex(int server_id,
int consumer_id,
void *service_,
const size_t n,
BinaryArchive *oars,
RpcCallback callback) {
RpcService *service = reinterpret_cast<RpcService *>(service_);
service->increase_request();
RpcMessageHead head;
head.service = service->remote_pointer(server_id);
head.request = new RpcRequest(callback);
head.client_id = _gloo->Rank();
head.server_id = server_id;
head.message_type = RpcMessageHead::REQUEST;
head.consumer_id = consumer_id;
send_message(server_id, head, n, oars);
}
void BaiduRpcServer::send_response(RpcMessageHead head,
const size_t n,
BinaryArchive *oars) {
PADDLE_ENFORCE_EQ(
(head.server_id == _gloo->Rank()),
true,
phi::errors::PreconditionNotMet("Server_id not equal rank id."));
PADDLE_ENFORCE_EQ((head.client_id >= 0 && head.client_id < _gloo->Size()),
true,
phi::errors::PreconditionNotMet("The client id is error."));
BRpcReqService *service = reinterpret_cast<BRpcReqService *>(head.service);
head.service = head.service->remote_pointer(head.client_id);
head.message_type = RpcMessageHead::RESPONSE;
// One-way client->server traffic only: send via the normal data-send path
if (service->is_simplex()) {
send_message(head.client_id, head, n, oars);
} else {
// This branch only applies when send_response is invoked directly inside the callback
auto &ar = service->response_attachment();
for (size_t i = 0; i < n; i++) {
auto &oar = oars[i];
if (oar.Length() == 0) {
continue;
}
ar.append(oar.Buffer(), oar.Length());
}
ar.append(&head, sizeof(head));
service->done(ar.length());
}
}
void BaiduRpcServer::send_message(int send_id,
const RpcMessageHead &head,
const size_t n,
BinaryArchive *oars) {
brpc::Controller *cntl = new brpc::Controller();
cntl->ignore_eovercrowded();
auto &ar = cntl->request_attachment();
for (size_t i = 0; i < n; i++) {
auto &oar = oars[i];
if (oar.Length() == 0) {
continue;
}
ar.append(oar.Buffer(), oar.Length());
}
ar.append(&head, sizeof(head));
SimpleRpcRequest baidu_rpc_request;
baidu_rpc_request.set_archive_size(ar.length());
cntl->set_log_id(_gloo->Rank());
SimpleRpcResponse *baidu_rpc_response = new SimpleRpcResponse();
google::protobuf::Closure *done = google::protobuf::NewCallback(
&handle_baidu_rpc_response, cntl, baidu_rpc_response);
_senders[send_id]->handle_request(
cntl, &baidu_rpc_request, baidu_rpc_response, done);
}
/**
* @brief Mainly handles asynchronous baidu-rpc responses
*/
void *BaiduRpcServer::add_service(RpcCallback callback, bool simplex) {
return new BRpcReqService(std::move(callback), simplex);
}
} // namespace simple
} // namespace distributed
} // namespace paddle
#endif
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#if defined(PADDLE_WITH_GLOO) && defined(PADDLE_WITH_GPU_GRAPH)
#pragma once
#include <memory> // std::unique_ptr
#include <string> // std::string
#include <vector> // std::vector
#include "paddle/fluid/distributed/ps/service/simple_brpc.pb.h" // RpcRequest
#include "paddle/fluid/distributed/ps/service/simple_rpc/rpc_server.h" // RpcServerCallBack
namespace brpc {
class Channel;
class Controller;
class Server;
} // namespace brpc
namespace google {
namespace protobuf {
class Closure;
class RpcController;
} // namespace protobuf
} // namespace google
namespace paddle {
namespace distributed {
namespace simple {
/**
* @brief Service request handler
*/
class BRpcServiceImpl;
/**
* @brief baidu-rpc based RpcServer implementation
*/
class BaiduRpcServer : public RpcServer {
public:
BaiduRpcServer();
~BaiduRpcServer();
void initialize();
void finalize();
void send_request(int server_id,
void *service_,
const size_t n,
BinaryArchive *oars,
RpcCallback callback);
void send_response(RpcMessageHead head, const size_t n, BinaryArchive *oars);
void send_request_ex(int server_id,
int consumer_id,
void *service_,
const size_t n,
BinaryArchive *oars,
RpcCallback callback);
public:
/**
* @brief Mainly handles asynchronous baidu-rpc responses
*/
virtual void *add_service(RpcCallback callback, bool simplex = true);
private:
void send_message(int send_id,
const RpcMessageHead &head,
const size_t n,
BinaryArchive *oars);
private:
std::shared_ptr<BRpcServiceImpl> _service_impl;
std::shared_ptr<brpc::Server> _server;
std::vector<std::unique_ptr<SimpleRpcService_Stub>> _senders;
std::atomic<int> _ref;
};
} // namespace simple
} // namespace distributed
} // namespace paddle
#endif
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#if defined(PADDLE_WITH_GLOO) && defined(PADDLE_WITH_GPU_GRAPH)
#include "paddle/fluid/distributed/ps/service/simple_rpc/rpc_server.h"
#include <arpa/inet.h>
#include <net/if.h>
#include <netinet/in.h>
#include <sys/ioctl.h>
#include <sys/socket.h>
#include "paddle/fluid/distributed/ps/service/simple_rpc/baidu_rpc_server.h"
#include "paddle/fluid/framework/fleet/gloo_wrapper.h"
#include "paddle/phi/core/enforce.h"
namespace paddle {
namespace distributed {
namespace simple {
RpcService::RpcService(RpcCallback callback) : _callback(std::move(callback)) {
auto gloo = paddle::framework::GlooWrapper::GetInstance();
void* my_ptr = reinterpret_cast<void*>(this);
std::vector<void*> ids = gloo->AllGather(my_ptr);
_remote_ptrs.assign(gloo->Size(), NULL);
for (int i = 0; i < gloo->Size(); ++i) {
_remote_ptrs[i] = reinterpret_cast<RpcService*>(ids[i]);
}
gloo->Barrier();
}
RpcService::~RpcService() {
paddle::framework::GlooWrapper::GetInstance()->Barrier();
if (_request_counter != 0) {
fprintf(stderr, "check request counter is not zero");
}
}
inline uint32_t get_broadcast_ip(char* ethname) {
struct ifreq ifr;
int sockfd = socket(AF_INET, SOCK_DGRAM, 0);
strncpy(ifr.ifr_name, ethname, IFNAMSIZ - 1);
if (ioctl(sockfd, SIOCGIFBRDADDR, &ifr) == -1) {
close(sockfd);  // avoid leaking the fd on failure
return 0;
}
close(sockfd);
return ((struct sockaddr_in*)&ifr.ifr_addr)->sin_addr.s_addr;
}
inline std::string get_local_ip_internal() {
int sockfd = -1;
char buf[512];
struct ifconf ifconf;
struct ifreq* ifreq;
ifconf.ifc_len = 512;
ifconf.ifc_buf = buf;
sockfd = socket(AF_INET, SOCK_DGRAM, 0);
PADDLE_ENFORCE_EQ((sockfd >= 0),
true,
phi::errors::PreconditionNotMet("Socket should be >= 0."));
int ret = ioctl(sockfd, SIOCGIFCONF, &ifconf);
PADDLE_ENFORCE_EQ(
(ret >= 0),
true,
phi::errors::PreconditionNotMet("Ioctl ret should be >= 0."));
ret = close(sockfd);
PADDLE_ENFORCE_EQ(
(0 == ret),
true,
phi::errors::PreconditionNotMet("Close call should return 0."));
ifreq = (struct ifreq*)buf;
for (int i = 0; i < static_cast<int>(ifconf.ifc_len / sizeof(struct ifreq));
i++) {
std::string ip =
inet_ntoa(((struct sockaddr_in*)&ifreq->ifr_addr)->sin_addr);
if (strncmp(ifreq->ifr_name, "lo", 2) == 0 ||
strncmp(ifreq->ifr_name, "docker", 6) == 0) {
fprintf(stdout,
"skip interface: [%s], ip: %s\n",
ifreq->ifr_name,
ip.c_str());
ifreq++;
continue;
}
if (get_broadcast_ip(ifreq->ifr_name) == 0) {
fprintf(stdout,
"skip interface: [%s], ip: %s\n",
ifreq->ifr_name,
ip.c_str());
ifreq++;
continue;
}
if (ip != "127.0.0.1") {
fprintf(stdout,
"used interface: [%s], ip: %s\n",
ifreq->ifr_name,
ip.c_str());
return ip;
}
ifreq++;
}
fprintf(stdout, "not found, use ip: 127.0.0.1\n");
return "127.0.0.1";
}
RpcServer::RpcServer() {
_gloo = paddle::framework::GlooWrapper::GetInstance().get();
std::string ip = get_local_ip_internal();
uint32_t int_ip = inet_addr(ip.c_str());
_ips = _gloo->AllGather(int_ip);
}
RpcServer::~RpcServer() {
if (_gloo != NULL) {
_gloo = NULL;
}
}
void RpcServer::set_connection_num(int n) {
_gloo->Barrier();
if (n < _gloo->Size()) {
n = _gloo->Size();
}
PADDLE_ENFORCE_EQ(
(n >= 1),
true,
phi::errors::InvalidArgument("Connect num need more than 1."));
_conn_num = n;
}
void RpcServer::set_thread_num(int n) {
if (n < _gloo->Size()) {
n = _gloo->Size();
}
PADDLE_ENFORCE_EQ(
(n >= 1),
true,
phi::errors::InvalidArgument("Thread num need more than 1."));
_thread_num = n;
}
void* RpcServer::add_service(RpcCallback callback, bool simplex) {
return new RpcService(std::move(callback));
}
void RpcServer::remove_service(void* service) {
delete reinterpret_cast<RpcService*>(service);
}
RpcServer& global_rpc_server() {
static BaiduRpcServer server;
return server;
}
} // namespace simple
} // namespace distributed
} // namespace paddle
#endif
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#if defined(PADDLE_WITH_GLOO) && defined(PADDLE_WITH_GPU_GRAPH)
#pragma once
#include <glog/logging.h>
#include <atomic>
#include <functional>
#include <mutex>
#include <thread>
#include <vector>
#include "paddle/fluid/framework/archive.h"
namespace paddle {
namespace framework {
class GlooWrapper;
}
namespace distributed {
namespace simple {
using BinaryArchive = paddle::framework::BinaryArchive;
class RpcService;
class RpcRequest;
struct RpcMessageHead {
RpcService *service;
RpcRequest *request;
int client_id;
int server_id;
enum { REQUEST, RESPONSE } message_type;
int consumer_id;
};
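// Note (added for clarity): the head is appended to the tail of each
// BinaryArchive by the sender (see BaiduRpcServer::send_message) and recovered
// with ReadBack() on receipt, so the payload bytes in front of it can be
// consumed without re-framing.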
typedef std::function<void(const RpcMessageHead &, BinaryArchive &)>
RpcCallback; // NOLINT
class RpcService {
public:
RpcService() {}
explicit RpcService(RpcCallback callback);
~RpcService();
RpcService *remote_pointer(int rank) { return _remote_ptrs[rank]; }
RpcCallback &callback() { return _callback; }
void increase_request() { ++_request_counter; }
void decrease_request() { --_request_counter; }
protected:
std::vector<RpcService *> _remote_ptrs;
RpcCallback _callback;
std::atomic<int> _request_counter{0};
};
class RpcRequest {
public:
explicit RpcRequest(RpcCallback callback) : _callback(std::move(callback)) {}
RpcCallback &callback() { return _callback; }
protected:
RpcCallback _callback;
};
class RpcServer {
public:
RpcServer();
virtual ~RpcServer();
public:
void set_connection_num(int n);
void set_thread_num(int n);
void set_connection_idle_timeout_sec(int timeout_sec) {
_connection_idle_timeout_sec = timeout_sec;
}
void set_max_retry(int retry_cnt) { _max_retry = retry_cnt; }
void set_connect_timeout_ms(int timeout_ms) {
_connect_timeout_ms = timeout_ms;
}
void set_connection_type(const std::string &conn_type) {
_connection_type = conn_type;
}
void set_client_timeout_ms(int timeout_ms) {
_client_timeout_ms = timeout_ms;
}
public:
virtual void initialize() = 0;
virtual void finalize() = 0;
virtual void send_request(int server_id,
void *service_,
const size_t n,
BinaryArchive *oars,
RpcCallback callback) = 0;
virtual void send_response(RpcMessageHead head,
const size_t n,
BinaryArchive *oars) = 0;
virtual void send_request_ex(int server_id,
int consumer_id,
void *service_,
const size_t n,
BinaryArchive *oars,
RpcCallback callback) = 0;
public:
virtual void *add_service(RpcCallback callback, bool simplex = true);
virtual void remove_service(void *service);
public:
void send_request_wrapper(int server_id,
void *service,
BinaryArchive &oar, // NOLINT
RpcCallback callback) {
send_request(server_id, service, 1, &oar, std::move(callback));
}
void send_request_consumer(int server_id,
int consumer_id,
void *service,
BinaryArchive &oar, // NOLINT
RpcCallback callback) {
send_request_ex(
server_id, consumer_id, service, 1, &oar, std::move(callback));
}
void send_response(RpcMessageHead head, BinaryArchive &oar) { // NOLINT
send_response(head, 1, &oar);
}
protected:
int _conn_num = 1;
int _thread_num = 10;
std::vector<uint32_t> _ips;
paddle::framework::GlooWrapper *_gloo = NULL;
// configure for rpc
int _connection_idle_timeout_sec = 3600;
int _max_retry = 1000;
int _connect_timeout_ms = -1;
std::string _connection_type = "pooled";
int _client_timeout_ms = -1;
};
extern RpcServer &global_rpc_server();
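// A minimal usage sketch (illustrative only, assuming a Gloo-initialized
// environment; `peer_rank` is a hypothetical variable):
//   auto &rpc = global_rpc_server();
//   rpc.initialize();
//   void *svc = rpc.add_service(
//       [](const RpcMessageHead &head, BinaryArchive &iar) { /* handle */ });
//   BinaryArchive oar;
//   rpc.send_request_wrapper(peer_rank, svc, oar,
//       [](const RpcMessageHead &, BinaryArchive &) { /* on response */ });
//   rpc.remove_service(svc);
//   rpc.finalize();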
} // namespace simple
} // namespace distributed
} // namespace paddle
#endif
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <vector>
namespace paddle {
namespace distributed {
struct GraphPsShardValues {
std::vector<size_t> offsets;
std::vector<uint64_t> keys;
std::vector<char*> values;
void clear() {
offsets.clear();
keys.clear();
values.clear();
offsets.shrink_to_fit();
keys.shrink_to_fit();
values.shrink_to_fit();
}
};
typedef std::vector<GraphPsShardValues> SparseShardValues;
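// A hedged reading of this layout, inferred from PsGraphClient::request_handler
// rather than stated in the commit: each incoming batch for a shard pushes its
// start index into `offsets` and then extends `keys`/`values` by the batch
// size, so offsets[k] marks where the k-th remote batch begins.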
} // namespace distributed
} // namespace paddle
@@ -36,6 +36,7 @@ DECLARE_bool(graph_load_in_parallel);
DECLARE_bool(graph_get_neighbor_id);
DECLARE_int32(gpugraph_storage_mode);
DECLARE_uint64(gpugraph_slot_feasign_max_num);
DECLARE_bool(graph_metapath_split_opt);
namespace paddle {
namespace distributed {
@@ -94,8 +95,7 @@ paddle::framework::GpuPsCommGraphFea GraphTable::make_gpu_ps_graph_fea(
paddle::framework::GpuPsFeaInfo x;
std::vector<uint64_t> feature_ids;
for (size_t j = 0; j < bags[i].size(); j++) {
// TODO(danleifeng): use FEATURE_TABLE instead
Node *v = find_node(1, bags[i][j]);
Node *v = find_node(GraphTableType::FEATURE_TABLE, bags[i][j]);
node_id = bags[i][j];
if (v == NULL) {
x.feature_size = 0;
@@ -192,7 +192,7 @@ paddle::framework::GpuPsCommGraph GraphTable::make_gpu_ps_graph(
for (size_t j = 0; j < bags[i].size(); j++) {
auto node_id = bags[i][j];
node_array[i][j] = node_id;
Node *v = find_node(0, idx, node_id);
Node *v = find_node(GraphTableType::EDGE_TABLE, idx, node_id);
if (v != nullptr) {
info_array[i][j].neighbor_offset = edge_array[i].size();
info_array[i][j].neighbor_size = v->get_neighbor_size();
@@ -540,14 +540,18 @@ void GraphTable::release_graph_edge() {
void GraphTable::release_graph_node() {
build_graph_type_keys();
if (FLAGS_gpugraph_storage_mode != paddle::framework::GpuGraphStorageMode::
MEM_EMB_FEATURE_AND_GPU_GRAPH &&
FLAGS_gpugraph_storage_mode != paddle::framework::GpuGraphStorageMode::
SSD_EMB_AND_MEM_FEATURE_GPU_GRAPH) {
if (FLAGS_graph_metapath_split_opt) {
clear_feature_shard();
} else {
merge_feature_shard();
feature_shrink_to_fit();
if (FLAGS_gpugraph_storage_mode != paddle::framework::GpuGraphStorageMode::
MEM_EMB_FEATURE_AND_GPU_GRAPH &&
FLAGS_gpugraph_storage_mode != paddle::framework::GpuGraphStorageMode::
SSD_EMB_AND_MEM_FEATURE_GPU_GRAPH) {
clear_feature_shard();
} else {
merge_feature_shard();
feature_shrink_to_fit();
}
}
}
#endif
@@ -1264,10 +1268,12 @@ int32_t GraphTable::parse_type_to_typepath(
return 0;
}
int32_t GraphTable::parse_edge_and_load(std::string etype2files,
std::string graph_data_local_path,
int part_num,
bool reverse) {
int32_t GraphTable::parse_edge_and_load(
std::string etype2files,
std::string graph_data_local_path,
int part_num,
bool reverse,
const std::vector<bool> &is_reverse_edge_map) {
std::vector<std::string> etypes;
std::unordered_map<std::string, std::string> edge_to_edgedir;
int res = parse_type_to_typepath(
@@ -1287,6 +1293,17 @@ int32_t GraphTable::parse_edge_and_load(std::string etype2files,
tasks.push_back(
_shards_task_pool[i % task_pool_size_]->enqueue([&, i, this]() -> int {
std::string etype_path = edge_to_edgedir[etypes[i]];
bool only_load_reverse_edge = false;
if (!reverse) {
only_load_reverse_edge = is_reverse_edge_map[i];
}
if (only_load_reverse_edge) {
VLOG(1) << "only_load_reverse_edge is True, etype[" << etypes[i]
<< "], file_path[" << etype_path << "]";
} else {
VLOG(1) << "only_load_reverse_edge is False, etype[" << etypes[i]
<< "], file_path[" << etype_path << "]";
}
auto etype_path_list = paddle::framework::localfs_list(etype_path);
std::string etype_path_str;
if (part_num > 0 &&
@@ -1299,10 +1316,14 @@
etype_path_str =
paddle::string::join_strings(etype_path_list, delim);
}
this->load_edges(etype_path_str, false, etypes[i]);
if (reverse) {
std::string r_etype = get_inverse_etype(etypes[i]);
this->load_edges(etype_path_str, true, r_etype);
if (!only_load_reverse_edge) {
this->load_edges(etype_path_str, false, etypes[i]);
if (reverse) {
std::string r_etype = get_inverse_etype(etypes[i]);
this->load_edges(etype_path_str, true, r_etype);
}
} else {
this->load_edges(etype_path_str, true, etypes[i]);
}
return 0;
}));
@@ -1357,11 +1378,13 @@ int32_t GraphTable::parse_node_and_load(std::string ntype2files,
return 0;
}
int32_t GraphTable::load_node_and_edge_file(std::string etype2files,
std::string ntype2files,
std::string graph_data_local_path,
int part_num,
bool reverse) {
int32_t GraphTable::load_node_and_edge_file(
std::string etype2files,
std::string ntype2files,
std::string graph_data_local_path,
int part_num,
bool reverse,
const std::vector<bool> &is_reverse_edge_map) {
std::vector<std::string> etypes;
std::unordered_map<std::string, std::string> edge_to_edgedir;
int res = parse_type_to_typepath(
@@ -1391,6 +1414,17 @@ int32_t GraphTable::load_node_and_edge_file(std::string etype2files,
_shards_task_pool[i % task_pool_size_]->enqueue([&, i, this]() -> int {
if (i < etypes.size()) {
std::string etype_path = edge_to_edgedir[etypes[i]];
bool only_load_reverse_edge = false;
if (!reverse) {
only_load_reverse_edge = is_reverse_edge_map[i];
}
if (only_load_reverse_edge) {
VLOG(1) << "only_load_reverse_edge is True, etype[" << etypes[i]
<< "], file_path[" << etype_path << "]";
} else {
VLOG(1) << "only_load_reverse_edge is False, etype[" << etypes[i]
<< "], file_path[" << etype_path << "]";
}
auto etype_path_list = paddle::framework::localfs_list(etype_path);
std::string etype_path_str;
if (part_num > 0 &&
@@ -1403,10 +1437,14 @@ int32_t GraphTable::load_node_and_edge_file(std::string etype2files,
etype_path_str =
paddle::string::join_strings(etype_path_list, delim);
}
this->load_edges(etype_path_str, false, etypes[i]);
if (reverse) {
std::string r_etype = get_inverse_etype(etypes[i]);
this->load_edges(etype_path_str, true, r_etype);
if (!only_load_reverse_edge) {
this->load_edges(etype_path_str, false, etypes[i]);
if (reverse) {
std::string r_etype = get_inverse_etype(etypes[i]);
this->load_edges(etype_path_str, true, r_etype);
}
} else {
this->load_edges(etype_path_str, true, etypes[i]);
}
} else {
std::string npath = node_to_nodedir[ntypes[0]];
@@ -1454,14 +1492,15 @@ int32_t GraphTable::load_node_and_edge_file(std::string etype2files,
}
int32_t GraphTable::get_nodes_ids_by_ranges(
int type_id,
GraphTableType table_type,
int idx,
std::vector<std::pair<int, int>> ranges,
std::vector<uint64_t> &res) {
std::mutex mutex;
int start = 0, end, index = 0, total_size = 0;
res.clear();
auto &shards = type_id == 0 ? edge_shards[idx] : feature_shards[idx];
auto &shards = table_type == GraphTableType::EDGE_TABLE ? edge_shards[idx]
: feature_shards[idx];
std::vector<std::future<size_t>> tasks;
for (size_t i = 0;
i < shards.size() && index < static_cast<int>(ranges.size());
@@ -1730,7 +1769,8 @@ std::pair<uint64_t, uint64_t> GraphTable::parse_edge_file(
local_valid_count++;
}
VLOG(2) << local_count << " edges are loaded from filepath->" << path;
VLOG(2) << local_valid_count << "/" << local_count
<< " edges are loaded from filepath->" << path;
return {local_count, local_valid_count};
}
@@ -1814,14 +1854,15 @@ int32_t GraphTable::load_edges(const std::string &path,
return 0;
}
Node *GraphTable::find_node(int type_id, uint64_t id) {
Node *GraphTable::find_node(GraphTableType table_type, uint64_t id) {
size_t shard_id = id % shard_num;
if (shard_id >= shard_end || shard_id < shard_start) {
return nullptr;
}
Node *node = nullptr;
size_t index = shard_id - shard_start;
auto &search_shards = type_id == 0 ? edge_shards : feature_shards;
auto &search_shards =
table_type == GraphTableType::EDGE_TABLE ? edge_shards : feature_shards;
for (auto &search_shard : search_shards) {
PADDLE_ENFORCE_NOT_NULL(search_shard[index],
paddle::platform::errors::InvalidArgument(
@@ -1834,13 +1875,15 @@ Node *GraphTable::find_node(int type_id, uint64_t id) {
return node;
}
Node *GraphTable::find_node(int type_id, int idx, uint64_t id) {
Node *GraphTable::find_node(GraphTableType table_type, int idx, uint64_t id) {
size_t shard_id = id % shard_num;
if (shard_id >= shard_end || shard_id < shard_start) {
return nullptr;
}
size_t index = shard_id - shard_start;
auto &search_shards = type_id == 0 ? edge_shards[idx] : feature_shards[idx];
auto &search_shards = table_type == GraphTableType::EDGE_TABLE
? edge_shards[idx]
: feature_shards[idx];
PADDLE_ENFORCE_NOT_NULL(search_shards[index],
paddle::platform::errors::InvalidArgument(
"search_shard[%d] should not be null.", index));
......@@ -1856,22 +1899,25 @@ uint32_t GraphTable::get_thread_pool_index_by_shard_index(
return shard_index % shard_num_per_server % task_pool_size_;
}
int32_t GraphTable::clear_nodes(int type_id, int idx) {
auto &search_shards = type_id == 0 ? edge_shards[idx] : feature_shards[idx];
int32_t GraphTable::clear_nodes(GraphTableType table_type, int idx) {
auto &search_shards = table_type == GraphTableType::EDGE_TABLE
? edge_shards[idx]
: feature_shards[idx];
for (size_t i = 0; i < search_shards.size(); i++) {
search_shards[i]->clear();
}
return 0;
}
int32_t GraphTable::random_sample_nodes(int type_id,
int32_t GraphTable::random_sample_nodes(GraphTableType table_type,
int idx,
int sample_size,
std::unique_ptr<char[]> &buffer,
int &actual_size) {
int total_size = 0;
auto &shards = type_id == 0 ? edge_shards[idx] : feature_shards[idx];
for (size_t i = 0; i < shards.size(); i++) {
auto &shards = table_type == GraphTableType::EDGE_TABLE ? edge_shards[idx]
: feature_shards[idx];
for (int i = 0; i < (int)shards.size(); i++) {
total_size += shards[i]->get_size();
}
if (sample_size > total_size) sample_size = total_size;
......@@ -1926,7 +1972,7 @@ int32_t GraphTable::random_sample_nodes(int type_id,
}
for (auto &pair : first_half) second_half.push_back(pair);
std::vector<uint64_t> res;
get_nodes_ids_by_ranges(type_id, idx, second_half, res);
get_nodes_ids_by_ranges(table_type, idx, second_half, res);
actual_size = res.size() * sizeof(uint64_t);
buffer.reset(new char[actual_size]);
char *pointer = buffer.get();
......@@ -1975,7 +2021,7 @@ int32_t GraphTable::random_sample_neighbors(
index++;
} else {
node_id = id_list[i][k].node_key;
Node *node = find_node(0, idx, node_id);
Node *node = find_node(GraphTableType::EDGE_TABLE, idx, node_id);
int idy = seq_id[i][k];
int &actual_size = actual_sizes[idy];
if (node == nullptr) {
......@@ -2046,7 +2092,7 @@ int32_t GraphTable::get_node_feat(int idx,
uint64_t node_id = node_ids[idy];
tasks.push_back(_shards_task_pool[get_thread_pool_index(node_id)]->enqueue(
[&, idx, idy, node_id]() -> int {
Node *node = find_node(1, idx, node_id);
Node *node = find_node(GraphTableType::FEATURE_TABLE, idx, node_id);
if (node == nullptr) {
return 0;
......@@ -2205,7 +2251,7 @@ int GraphTable::parse_feature(int idx,
return 0;
}
} else {
VLOG(2) << "feature_name[" << name << "] is not in feat_id_map, ntype_id["
VLOG(4) << "feature_name[" << name << "] is not in feat_id_map, ntype_id["
<< idx << "] feat_id_map_size[" << feat_id_map.size() << "]";
}
......@@ -2245,11 +2291,12 @@ class MergeShardVector {
std::vector<std::vector<uint64_t>> *_shard_keys;
};
int GraphTable::get_all_id(int type_id,
int GraphTable::get_all_id(GraphTableType table_type,
int slice_num,
std::vector<std::vector<uint64_t>> *output) {
MergeShardVector shard_merge(output, slice_num);
auto &search_shards = type_id == 0 ? edge_shards : feature_shards;
auto &search_shards =
table_type == GraphTableType::EDGE_TABLE ? edge_shards : feature_shards;
std::vector<std::future<size_t>> tasks;
for (size_t idx = 0; idx < search_shards.size(); idx++) {
for (size_t j = 0; j < search_shards[idx].size(); j++) {
......@@ -2271,9 +2318,12 @@ int GraphTable::get_all_id(int type_id,
}
int GraphTable::get_all_neighbor_id(
int type_id, int slice_num, std::vector<std::vector<uint64_t>> *output) {
GraphTableType table_type,
int slice_num,
std::vector<std::vector<uint64_t>> *output) {
MergeShardVector shard_merge(output, slice_num);
auto &search_shards = type_id == 0 ? edge_shards : feature_shards;
auto &search_shards =
table_type == GraphTableType::EDGE_TABLE ? edge_shards : feature_shards;
std::vector<std::future<size_t>> tasks;
for (size_t idx = 0; idx < search_shards.size(); idx++) {
for (size_t j = 0; j < search_shards[idx].size(); j++) {
......@@ -2294,12 +2344,14 @@ int GraphTable::get_all_neighbor_id(
return 0;
}
int GraphTable::get_all_id(int type_id,
int GraphTable::get_all_id(GraphTableType table_type,
int idx,
int slice_num,
std::vector<std::vector<uint64_t>> *output) {
MergeShardVector shard_merge(output, slice_num);
auto &search_shards = type_id == 0 ? edge_shards[idx] : feature_shards[idx];
auto &search_shards = table_type == GraphTableType::EDGE_TABLE
? edge_shards[idx]
: feature_shards[idx];
std::vector<std::future<size_t>> tasks;
VLOG(3) << "begin task, task_pool_size_[" << task_pool_size_ << "]";
for (size_t i = 0; i < search_shards.size(); i++) {
......@@ -2320,12 +2372,14 @@ int GraphTable::get_all_id(int type_id,
}
int GraphTable::get_all_neighbor_id(
int type_id,
GraphTableType table_type,
int idx,
int slice_num,
std::vector<std::vector<uint64_t>> *output) {
MergeShardVector shard_merge(output, slice_num);
auto &search_shards = type_id == 0 ? edge_shards[idx] : feature_shards[idx];
auto &search_shards = table_type == GraphTableType::EDGE_TABLE
? edge_shards[idx]
: feature_shards[idx];
std::vector<std::future<size_t>> tasks;
VLOG(3) << "begin task, task_pool_size_[" << task_pool_size_ << "]";
for (size_t i = 0; i < search_shards.size(); i++) {
......@@ -2347,12 +2401,14 @@ int GraphTable::get_all_neighbor_id(
}
int GraphTable::get_all_feature_ids(
int type_id,
GraphTableType table_type,
int idx,
int slice_num,
std::vector<std::vector<uint64_t>> *output) {
MergeShardVector shard_merge(output, slice_num);
auto &search_shards = type_id == 0 ? edge_shards[idx] : feature_shards[idx];
auto &search_shards = table_type == GraphTableType::EDGE_TABLE
? edge_shards[idx]
: feature_shards[idx];
std::vector<std::future<size_t>> tasks;
for (size_t i = 0; i < search_shards.size(); i++) {
tasks.push_back(_shards_task_pool[i % task_pool_size_]->enqueue(
......@@ -2373,15 +2429,15 @@ int GraphTable::get_all_feature_ids(
int GraphTable::get_node_embedding_ids(
int slice_num, std::vector<std::vector<uint64_t>> *output) {
if (is_load_reverse_edge && !FLAGS_graph_get_neighbor_id) {
return get_all_id(0, slice_num, output);
if (is_load_reverse_edge and !FLAGS_graph_get_neighbor_id) {
return get_all_id(GraphTableType::EDGE_TABLE, slice_num, output);
} else {
get_all_id(0, slice_num, output);
return get_all_neighbor_id(0, slice_num, output);
get_all_id(GraphTableType::EDGE_TABLE, slice_num, output);
return get_all_neighbor_id(GraphTableType::EDGE_TABLE, slice_num, output);
}
}
int32_t GraphTable::pull_graph_list(int type_id,
int32_t GraphTable::pull_graph_list(GraphTableType table_type,
int idx,
int start,
int total_size,
......@@ -2391,7 +2447,9 @@ int32_t GraphTable::pull_graph_list(int type_id,
int step) {
if (start < 0) start = 0;
int size = 0, cur_size;
auto &search_shards = type_id == 0 ? edge_shards[idx] : feature_shards[idx];
auto &search_shards = table_type == GraphTableType::EDGE_TABLE
? edge_shards[idx]
: feature_shards[idx];
std::vector<std::future<std::vector<Node *>>> tasks;
for (size_t i = 0; i < search_shards.size() && total_size > 0; i++) {
cur_size = search_shards[i]->get_size();
......@@ -2523,7 +2581,7 @@ int32_t GraphTable::Initialize(const GraphParameter &graph) {
auto graph_feature = graph.graph_feature();
auto node_types = graph.node_types();
auto edge_types = graph.edge_types();
VLOG(0) << "got " << edge_types.size() << "edge types in total";
VLOG(0) << "got " << edge_types.size() << " edge types in total";
feat_id_map.resize(node_types.size());
for (int k = 0; k < edge_types.size(); k++) {
VLOG(0) << "in initialize: get a edge_type " << edge_types[k];
......@@ -2620,7 +2678,7 @@ void GraphTable::build_graph_type_keys() {
for (auto &it : this->feature_to_id) {
auto node_idx = it.second;
std::vector<std::vector<uint64_t>> keys;
this->get_all_id(1, node_idx, 1, &keys);
this->get_all_id(GraphTableType::FEATURE_TABLE, node_idx, 1, &keys);
type_to_index_[node_idx] = cnt;
graph_type_keys_[cnt++] = std::move(keys[0]);
}
......@@ -2631,7 +2689,8 @@ void GraphTable::build_graph_type_keys() {
for (auto &it : this->feature_to_id) {
auto node_idx = it.second;
std::vector<std::vector<uint64_t>> keys;
this->get_all_feature_ids(1, node_idx, 1, &keys);
this->get_all_feature_ids(
GraphTableType::FEATURE_TABLE, node_idx, 1, &keys);
graph_total_keys_.insert(
graph_total_keys_.end(), keys[0].begin(), keys[0].end());
}
......
......@@ -496,6 +496,8 @@ class GraphSampler {
#endif
*/
enum GraphTableType { EDGE_TABLE, FEATURE_TABLE };
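// A minimal call-site sketch, assuming a GraphTable *table, an edge index
// idx and a uint64_t node_id; the enum replaces the old convention
// "type_id == 0 selects edge_shards, 1 selects feature_shards":
//   Node *edge_node = table->find_node(GraphTableType::EDGE_TABLE, node_id);
//   Node *feat_node = table->find_node(GraphTableType::FEATURE_TABLE, idx, node_id);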
class GraphTable : public Table {
public:
GraphTable() {
......@@ -526,7 +528,7 @@ class GraphTable : public Table {
return (key % shard_num) / sparse_local_shard_num(shard_num, server_num);
}
virtual int32_t pull_graph_list(int type_id,
virtual int32_t pull_graph_list(GraphTableType table_type,
int idx,
int start,
int size,
......@@ -543,14 +545,14 @@ class GraphTable : public Table {
std::vector<int> &actual_sizes, // NOLINT
bool need_weight);
int32_t random_sample_nodes(int type_id,
int32_t random_sample_nodes(GraphTableType table_type,
int idx,
int sample_size,
std::unique_ptr<char[]> &buffers, // NOLINT
int &actual_sizes); // NOLINT
virtual int32_t get_nodes_ids_by_ranges(
int type_id,
GraphTableType table_type,
int idx,
std::vector<std::pair<int, int>> ranges,
std::vector<uint64_t> &res); // NOLINT
......@@ -564,11 +566,13 @@ class GraphTable : public Table {
std::string ntype2files,
std::string graph_data_local_path,
int part_num,
bool reverse);
bool reverse,
const std::vector<bool> &is_reverse_edge_map);
int32_t parse_edge_and_load(std::string etype2files,
std::string graph_data_local_path,
int part_num,
bool reverse);
bool reverse,
const std::vector<bool> &is_reverse_edge_map);
int32_t parse_node_and_load(std::string ntype2files,
std::string graph_data_local_path,
int part_num);
......@@ -581,21 +585,21 @@ class GraphTable : public Table {
int32_t load_edges(const std::string &path,
bool reverse,
const std::string &edge_type);
int get_all_id(int type,
int get_all_id(GraphTableType table_type,
int slice_num,
std::vector<std::vector<uint64_t>> *output);
int get_all_neighbor_id(int type,
int get_all_neighbor_id(GraphTableType table_type,
int slice_num,
std::vector<std::vector<uint64_t>> *output);
int get_all_id(int type,
int get_all_id(GraphTableType table_type,
int idx,
int slice_num,
std::vector<std::vector<uint64_t>> *output);
int get_all_neighbor_id(int type_id,
int get_all_neighbor_id(GraphTableType table_type,
int id,
int slice_num,
std::vector<std::vector<uint64_t>> *output);
int get_all_feature_ids(int type,
int get_all_feature_ids(GraphTableType table_type,
int idx,
int slice_num,
std::vector<std::vector<uint64_t>> *output);
......@@ -617,13 +621,13 @@ class GraphTable : public Table {
int32_t remove_graph_node(int idx, std::vector<uint64_t> &id_list); // NOLINT
int32_t get_server_index_by_id(uint64_t id);
Node *find_node(int type_id, int idx, uint64_t id);
Node *find_node(int type_id, uint64_t id);
Node *find_node(GraphTableType table_type, int idx, uint64_t id);
Node *find_node(GraphTableType table_type, uint64_t id);
virtual int32_t Pull(TableContext &context) { return 0; } // NOLINT
virtual int32_t Push(TableContext &context) { return 0; } // NOLINT
virtual int32_t clear_nodes(int type, int idx);
virtual int32_t clear_nodes(GraphTableType table_type, int idx);
virtual void Clear() {}
virtual int32_t Flush() { return 0; }
virtual int32_t Shrink(const std::string &param) { return 0; }
......
......@@ -16,13 +16,15 @@
#if defined _WIN32 || defined __APPLE__
#else
#define __LINUX__
#define _LINUX
#endif
#ifdef __LINUX__
#ifdef _LINUX
#include <pthread.h>
#include <semaphore.h>
#endif
#include <condition_variable>
#include <mutex>
#include "paddle/fluid/platform/enforce.h"
namespace paddle {
......@@ -30,37 +32,38 @@ namespace framework {
class Barrier {
public:
explicit Barrier(int count = 1) {
#ifdef __LINUX__
#ifdef _LINUX
CHECK_GE(count, 1);
CHECK_EQ(pthread_barrier_init(&_barrier, NULL, count), 0);
int ret = pthread_barrier_init(&_barrier, NULL, count);
CHECK_EQ(0, ret);
#endif
}
~Barrier() {
#ifdef __LINUX__
CHECK_EQ(pthread_barrier_destroy(&_barrier), 0);
#ifdef _LINUX
int ret = pthread_barrier_destroy(&_barrier);
CHECK_EQ(0, ret);
#endif
}
void reset(int count) {
#ifdef __LINUX__
#ifdef _LINUX
CHECK_GE(count, 1);
CHECK_EQ(pthread_barrier_destroy(&_barrier), 0);
CHECK_EQ(pthread_barrier_init(&_barrier, NULL, count), 0);
int ret = pthread_barrier_destroy(&_barrier);
CHECK_EQ(0, ret);
ret = pthread_barrier_init(&_barrier, NULL, count);
CHECK_EQ(0, ret);
#endif
}
void wait() {
#ifdef __LINUX__
#ifdef _LINUX
int err = pthread_barrier_wait(&_barrier);
if (err != 0 && err != PTHREAD_BARRIER_SERIAL_THREAD) {
CHECK_EQ(1, 0);
}
int err = pthread_barrier_wait(&_barrier);
CHECK_EQ(true, (err == 0 || err == PTHREAD_BARRIER_SERIAL_THREAD));
#endif
}
private:
#ifdef __LINUX__
#ifdef _LINUX
pthread_barrier_t _barrier;
#endif
};
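// A minimal usage sketch for Barrier, assuming num_threads worker threads:
//   Barrier barrier(num_threads);
//   // in each worker:
//   //   ... fill this thread's shard ...
//   //   barrier.wait();  // all workers rendezvous here before the merge step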
......@@ -81,38 +84,79 @@ auto ignore_signal_call(FUNC &&func, ARGS &&...args) ->
class Semaphore {
public:
Semaphore() {
#ifdef __LINUX__
CHECK_EQ(sem_init(&_sem, 0, 0), 0);
#ifdef _LINUX
int ret = sem_init(&_sem, 0, 0);
CHECK_EQ(0, ret);
#endif
}
~Semaphore() {
#ifdef __LINUX__
CHECK_EQ(sem_destroy(&_sem), 0);
#ifdef _LINUX
int ret = sem_destroy(&_sem);
CHECK_EQ(0, ret);
#endif
}
void post() {
#ifdef __LINUX__
CHECK_EQ(sem_post(&_sem), 0);
#ifdef _LINUX
int ret = sem_post(&_sem);
CHECK_EQ(0, ret);
#endif
}
void wait() {
#ifdef __LINUX__
CHECK_EQ(ignore_signal_call(sem_wait, &_sem), 0);
#ifdef _LINUX
int ret = ignore_signal_call(sem_wait, &_sem);
CHECK_EQ(0, ret);
#endif
}
bool try_wait() {
int err = 0;
#ifdef __LINUX__
CHECK((err = ignore_signal_call(sem_trywait, &_sem),
err == 0 || errno == EAGAIN));
#ifdef _LINUX
err = ignore_signal_call(sem_trywait, &_sem);
CHECK_EQ(true, (err == 0 || errno == EAGAIN));
#endif
return err == 0;
}
private:
#ifdef __LINUX__
#ifdef _LINUX
sem_t _sem;
#endif
};
class WaitGroup {
public:
WaitGroup() {}
void clear() {
std::lock_guard<std::mutex> lock(mutex_);
counter_ = 0;
cond_.notify_all();
}
void add(int delta) {
if (delta == 0) {
return;
}
std::lock_guard<std::mutex> lock(mutex_);
counter_ += delta;
if (counter_ == 0) {
cond_.notify_all();
}
}
void done() { add(-1); }
void wait() {
std::unique_lock<std::mutex> lock(mutex_);
while (counter_ != 0) {
cond_.wait(lock);
}
}
int count() {
std::unique_lock<std::mutex> lock(mutex_);
return counter_;
}
private:
std::mutex mutex_;
std::condition_variable cond_;
int counter_ = 0;
};
} // namespace framework
} // namespace paddle
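// A minimal usage sketch for the new WaitGroup above, assuming a thread
// pool object with an enqueue() method: call add() before scheduling,
// done() at the end of each task, and wait() to block until the counter
// drops back to zero.
//   paddle::framework::WaitGroup wg;
//   wg.add(num_tasks);
//   for (int i = 0; i < num_tasks; ++i) {
//     pool->enqueue([&wg, i]() {
//       // ... process shard i ...
//       wg.done();
//     });
//   }
//   wg.wait();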
......@@ -2717,6 +2717,16 @@ void SlotRecordInMemoryDataFeed::DoWalkandSage() {
}
#endif
void SlotRecordInMemoryDataFeed::DumpWalkPath(std::string dump_path,
size_t dump_rate) {
VLOG(3) << "INTO SlotRecordInMemoryDataFeed::DumpWalkPath";
#if defined(PADDLE_WITH_GPU_GRAPH) && defined(PADDLE_WITH_HETERPS)
std::string path =
string::format_string("%s/part-%03d", dump_path.c_str(), thread_id_);
gpu_graph_data_generator_.DumpWalkPath(path, dump_rate);
#endif
}
#if defined(PADDLE_WITH_CUDA) && defined(PADDLE_WITH_HETERPS)
void SlotRecordInMemoryDataFeed::BuildSlotBatchGPU(const int ins_num) {
int offset_cols_size = (ins_num + 1);
......
......@@ -28,6 +28,7 @@ limitations under the License. */
#include "paddle/fluid/framework/fleet/heter_ps/graph_gpu_wrapper.h"
#include "paddle/fluid/framework/fleet/heter_ps/hashtable.h"
#include "paddle/fluid/framework/fleet/ps_gpu_wrapper.h"
#include "paddle/fluid/framework/io/fs.h"
#include "paddle/phi/kernels/gpu/graph_reindex_funcs.h"
#include "paddle/phi/kernels/graph_reindex_kernel.h"
......@@ -2620,12 +2621,12 @@ int GraphDataGenerator::FillWalkBufMultiPath() {
if (!sage_mode_) {
uint64_t h_uniq_node_num = CopyUniqueNodes();
VLOG(0) << "sample_times:" << sample_times << ", d_walk_size:" << buf_size_
VLOG(1) << "sample_times:" << sample_times << ", d_walk_size:" << buf_size_
<< ", d_walk_offset:" << i << ", total_rows:" << total_row_
<< ", h_uniq_node_num:" << h_uniq_node_num
<< ", total_samples:" << total_samples;
} else {
VLOG(0) << "sample_times:" << sample_times << ", d_walk_size:" << buf_size_
VLOG(1) << "sample_times:" << sample_times << ", d_walk_size:" << buf_size_
<< ", d_walk_offset:" << i << ", total_rows:" << total_row_
<< ", total_samples:" << total_samples;
}
......@@ -2938,6 +2939,42 @@ void GraphDataGenerator::SetConfig(
}
}
void GraphDataGenerator::DumpWalkPath(std::string dump_path, size_t dump_rate) {
#ifdef _LINUX
PADDLE_ENFORCE_LT(
dump_rate,
10000000,
platform::errors::InvalidArgument(
"dump_rate can't be large than 10000000. Please check the dump "
"rate[1, 10000000]"));
PADDLE_ENFORCE_GT(dump_rate,
1,
platform::errors::InvalidArgument(
"dump_rate can't be less than 1. Please check "
"the dump rate[1, 10000000]"));
int err_no = 0;
std::shared_ptr<FILE> fp = fs_open_append_write(dump_path, &err_no, "");
uint64_t *h_walk = new uint64_t[buf_size_];
uint64_t *walk = reinterpret_cast<uint64_t *>(d_walk_->ptr());
CUDA_CHECK(cudaMemcpy(
h_walk, walk, buf_size_ * sizeof(uint64_t), cudaMemcpyDeviceToHost));
VLOG(1) << "DumpWalkPath all buf_size_:" << buf_size_;
std::string ss = "";
size_t write_count = 0;
for (int xx = 0; xx < buf_size_ / dump_rate; xx += walk_len_) {
ss = "";
for (int yy = 0; yy < walk_len_; yy++) {
ss += std::to_string(h_walk[xx + yy]) + "-";
}
write_count = fwrite_unlocked(ss.data(), 1, ss.length(), fp.get());
if (write_count != ss.length()) {
VLOG(1) << "dump walk path" << ss << " failed";
}
write_count = fwrite_unlocked("\n", 1, 1, fp.get());
}
delete[] h_walk;
#endif
}
} // namespace framework
} // namespace paddle
#endif
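A quick worked example of the dump format above, assuming buf_size_ = 12, walk_len_ = 4 and dump_rate = 2: the loop bound becomes 12 / 2 = 6, so xx takes the values 0 and 4, and two of the three walks in the device buffer are written, each line being walk_len_ node ids each followed by '-' (e.g. 8-15-3-42-). Larger dump_rate values therefore sample a smaller prefix of the walk buffer, and each data-feed thread writes its own part-%03d file.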
......@@ -940,6 +940,7 @@ class GraphDataGenerator {
void ResetPathNum() { total_row_ = 0; }
void ResetEpochFinish() { epoch_finish_ = false; }
void ClearSampleState();
void DumpWalkPath(std::string dump_path, size_t dump_rate);
void SetDeviceKeys(std::vector<uint64_t>* device_keys, int type) {
// type_to_index_[type] = h_device_keys_.size();
// h_device_keys_.push_back(device_keys);
......@@ -1211,6 +1212,11 @@ class DataFeed {
}
virtual const paddle::platform::Place& GetPlace() const { return place_; }
virtual void DumpWalkPath(std::string dump_path, size_t dump_rate) {
PADDLE_THROW(platform::errors::Unimplemented(
"This function(DumpWalkPath) is not implemented."));
}
protected:
// The following three functions are used to check if it is executed in this
// order:
......@@ -1820,6 +1826,7 @@ class SlotRecordInMemoryDataFeed : public InMemoryDataFeed<SlotRecord> {
virtual void InitGraphTrainResource(void);
virtual void DoWalkandSage();
#endif
virtual void DumpWalkPath(std::string dump_path, size_t dump_rate);
float sample_rate_ = 1.0f;
int use_slot_size_ = 0;
......
......@@ -657,6 +657,26 @@ void DatasetImpl<T>::LocalShuffle() {
<< timeline.ElapsedSec() << " seconds";
}
template <typename T>
void DatasetImpl<T>::DumpWalkPath(std::string dump_path, size_t dump_rate) {
VLOG(3) << "DatasetImpl<T>::DumpWalkPath() begin";
#if defined(PADDLE_WITH_GPU_GRAPH) && defined(PADDLE_WITH_HETERPS)
std::vector<std::thread> dump_threads;
if (gpu_graph_mode_) {
for (int64_t i = 0; i < thread_num_; ++i) {
dump_threads.push_back(
std::thread(&paddle::framework::DataFeed::DumpWalkPath,
readers_[i].get(),
dump_path,
dump_rate));
}
for (std::thread& t : dump_threads) {
t.join();
}
}
#endif
}
// do tdm sample
void MultiSlotDataset::TDMSample(const std::string tree_name,
const std::string tree_path,
......
......@@ -172,6 +172,8 @@ class Dataset {
virtual void SetPassId(uint32_t pass_id) = 0;
virtual uint32_t GetPassID() = 0;
virtual void DumpWalkPath(std::string dump_path, size_t dump_rate) = 0;
protected:
virtual int ReceiveFromClient(int msg_type,
int client_id,
......@@ -265,6 +267,7 @@ class DatasetImpl : public Dataset {
virtual void SetFleetSendSleepSeconds(int seconds);
virtual std::vector<std::string> GetSlots();
virtual bool GetEpochFinish();
virtual void DumpWalkPath(std::string dump_path, size_t dump_rate);
std::vector<paddle::framework::Channel<T>>& GetMultiOutputChannel() {
return multi_output_channel_;
......
......@@ -285,6 +285,8 @@ class HogwildWorker : public CPUWorkerBase {
protected:
void CreateThreadOperators(const ProgramDesc& program);
void CreateThreadScope(const ProgramDesc& program);
// check batch num
bool CheckBatchNum(int flag);
std::vector<std::string> op_names_;
std::vector<OperatorBase*> ops_;
......@@ -294,7 +296,7 @@ class HogwildWorker : public CPUWorkerBase {
std::vector<std::string> skip_ops_;
std::map<std::string, int> stat_var_name_map_;
static std::atomic<bool> quit_flag_;
// static bool quit_flag_2;
phi::DenseTensor sync_stat_;
};
class DownpourWorker : public HogwildWorker {
......
......@@ -29,7 +29,7 @@ if(WITH_HETERPS)
nv_library(
ps_gpu_wrapper
SRCS ps_gpu_wrapper.cu ps_gpu_wrapper.cc
DEPS heter_ps gloo_wrapper ps_framework_proto graph_gpu_wrapper
DEPS heter_ps gloo_wrapper ps_framework_proto graph_gpu_wrapper fleet
${BRPC_DEPS})
else()
nv_library(
......
......@@ -35,6 +35,14 @@ limitations under the License. */
#include "paddle/fluid/framework/fleet/heter_ps/feature_value.h"
#include "paddle/fluid/framework/scope.h"
#ifdef PADDLE_WITH_PSLIB
#define CONV2FEATURE_PTR(ptr) \
reinterpret_cast<paddle::ps::DownpourFixedFeatureValue**>(ptr)
#else
#define CONV2FEATURE_PTR(ptr) \
reinterpret_cast<paddle::distributed::FixedFeatureValue*>(ptr)
#endif
namespace paddle {
namespace framework {
......
......@@ -46,7 +46,8 @@ if(WITH_GPU)
hashtable_kernel
heter_ps
${HETERPS_DEPS}
graph_gpu_ps)
graph_gpu_ps
fleet_wrapper)
nv_test(
test_cpu_query
SRCS test_cpu_query.cu
......
......@@ -28,15 +28,24 @@ DECLARE_double(gpugraph_hbm_table_load_factor);
namespace paddle {
namespace framework {
enum GraphTableType { EDGE_TABLE, FEATURE_TABLE };
typedef paddle::distributed::GraphTableType GraphTableType;
class GpuPsGraphTable
: public HeterComm<uint64_t, uint64_t, int, CommonFeatureValueAccessor> {
public:
int get_table_offset(int gpu_id, GraphTableType type, int idx) const {
inline int get_table_offset(int gpu_id, GraphTableType type, int idx) const {
int type_id = type;
return gpu_id * (graph_table_num_ + feature_table_num_) +
type_id * graph_table_num_ + idx;
}
inline int get_graph_list_offset(int gpu_id, int edge_idx) const {
return gpu_id * graph_table_num_ + edge_idx;
}
inline int get_graph_fea_list_offset(int gpu_id) const {
return gpu_id * feature_table_num_;
}
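// A worked example of the layout these offsets assume: with
// graph_table_num_ = 3 and feature_table_num_ = 1, tables_ holds
// [edge0, edge1, edge2, fea0] per GPU, so for gpu_id = 1:
//   get_table_offset(1, GraphTableType::EDGE_TABLE, 2)    == 1 * 4 + 2 == 6
//   get_table_offset(1, GraphTableType::FEATURE_TABLE, 0) == 1 * 4 + 3 == 7
//   get_graph_list_offset(1, 2)     == 1 * 3 + 2 == 5
//   get_graph_fea_list_offset(1)    == 1 * 1     == 1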
GpuPsGraphTable(std::shared_ptr<HeterPsResource> resource,
int graph_table_num)
: HeterComm<uint64_t, uint64_t, int, CommonFeatureValueAccessor>(
......@@ -83,8 +92,6 @@ class GpuPsGraphTable
void clear_feature_info(int index);
void build_graph_from_cpu(const std::vector<GpuPsCommGraph> &cpu_node_list,
int idx);
void build_graph_fea_from_cpu(
const std::vector<GpuPsCommGraphFea> &cpu_node_list, int idx);
NodeQueryResult graph_node_sample(int gpu_id, int sample_size);
NeighborSampleResult graph_neighbor_sample_v3(NeighborSampleQuery q,
bool cpu_switch,
......
......@@ -686,17 +686,13 @@ __global__ void node_query_example(GpuPsCommGraph graph,
void GpuPsGraphTable::clear_feature_info(int gpu_id) {
int idx = 0;
if (idx >= feature_table_num_) return;
int offset = get_table_offset(gpu_id, GraphTableType::FEATURE_TABLE, idx);
if (offset < tables_.size()) {
delete tables_[offset];
tables_[offset] = NULL;
}
int graph_fea_idx = gpu_id * feature_table_num_ + idx;
if (graph_fea_idx >= gpu_graph_fea_list_.size()) {
return;
}
int graph_fea_idx = get_graph_fea_list_offset(gpu_id);
auto& graph = gpu_graph_fea_list_[graph_fea_idx];
if (graph.feature_list != NULL) {
cudaFree(graph.feature_list);
......@@ -714,16 +710,12 @@ void GpuPsGraphTable::reset_feature_info(int gpu_id,
size_t capacity,
size_t feature_size) {
int idx = 0;
if (idx >= feature_table_num_) return;
int offset = get_table_offset(gpu_id, GraphTableType::FEATURE_TABLE, idx);
if (offset < tables_.size()) {
delete tables_[offset];
tables_[offset] = new Table(capacity);
}
int graph_fea_idx = gpu_id * feature_table_num_ + idx;
if (graph_fea_idx >= gpu_graph_fea_list_.size()) {
return;
}
int graph_fea_idx = get_graph_fea_list_offset(gpu_id);
auto& graph = gpu_graph_fea_list_[graph_fea_idx];
graph.node_list = NULL;
if (graph.feature_list == NULL) {
......@@ -753,7 +745,7 @@ void GpuPsGraphTable::clear_graph_info(int gpu_id, int idx) {
delete tables_[offset];
tables_[offset] = NULL;
}
auto& graph = gpu_graph_list_[gpu_id * graph_table_num_ + idx];
auto& graph = gpu_graph_list_[get_graph_list_offset(gpu_id, idx)];
if (graph.neighbor_list != NULL) {
cudaFree(graph.neighbor_list);
graph.neighbor_list = nullptr;
......@@ -780,7 +772,7 @@ void GpuPsGraphTable::build_graph_fea_on_single_gpu(const GpuPsCommGraphFea& g,
size_t capacity = std::max((uint64_t)1, g.node_size) / load_factor_;
reset_feature_info(gpu_id, capacity, g.feature_size);
int ntype_id = 0;
int offset = gpu_id * feature_table_num_ + ntype_id;
int offset = get_graph_fea_list_offset(gpu_id);
int table_offset =
get_table_offset(gpu_id, GraphTableType::FEATURE_TABLE, ntype_id);
if (g.node_size > 0) {
......@@ -828,7 +820,7 @@ GpuPsGraphTable::get_edge_type_graph(int gpu_id, int edge_type_len) {
GpuPsCommGraph graphs[edge_type_len]; // NOLINT
for (int idx = 0; idx < edge_type_len; idx++) {
int table_offset = get_table_offset(i, GraphTableType::EDGE_TABLE, idx);
int offset = i * graph_table_num_ + idx;
int offset = get_graph_list_offset(i, idx);
graphs[idx] = gpu_graph_list_[offset];
}
auto d_commgraph_mem = memory::AllocShared(
......@@ -856,13 +848,14 @@ In this function, memory is allocated on each gpu to save the graphs,
gpu i saves the ith graph from cpu_graph_list
*/
void GpuPsGraphTable::build_graph_on_single_gpu(const GpuPsCommGraph& g,
int i,
int idx) {
clear_graph_info(i, idx);
platform::CUDADeviceGuard guard(resource_->dev_id(i));
int offset = i * graph_table_num_ + idx;
int gpu_id,
int edge_idx) {
clear_graph_info(gpu_id, edge_idx);
platform::CUDADeviceGuard guard(resource_->dev_id(gpu_id));
int offset = get_graph_list_offset(gpu_id, edge_idx);
gpu_graph_list_[offset] = GpuPsCommGraph();
int table_offset = get_table_offset(i, GraphTableType::EDGE_TABLE, idx);
int table_offset =
get_table_offset(gpu_id, GraphTableType::EDGE_TABLE, edge_idx);
size_t capacity = std::max((uint64_t)1, (uint64_t)g.node_size) / load_factor_;
tables_[table_offset] = new Table(capacity);
if (g.node_size > 0) {
......@@ -875,7 +868,7 @@ void GpuPsGraphTable::build_graph_on_single_gpu(const GpuPsCommGraph& g,
cudaMemcpyHostToDevice));
}
build_ps(i,
build_ps(gpu_id,
g.node_list,
reinterpret_cast<uint64_t*>(g.node_info_list),
g.node_size,
......@@ -884,7 +877,7 @@ void GpuPsGraphTable::build_graph_on_single_gpu(const GpuPsCommGraph& g,
table_offset);
gpu_graph_list_[offset].node_size = g.node_size;
} else {
build_ps(i, NULL, NULL, 0, 1024, 8, table_offset);
build_ps(gpu_id, NULL, NULL, 0, 1024, 8, table_offset);
gpu_graph_list_[offset].node_list = NULL;
gpu_graph_list_[offset].node_size = 0;
}
......@@ -897,7 +890,7 @@ void GpuPsGraphTable::build_graph_on_single_gpu(const GpuPsCommGraph& g,
"ailed to allocate memory for graph on gpu "));
VLOG(0) << "sucessfully allocate " << g.neighbor_size * sizeof(uint64_t)
<< " bytes of memory for graph-edges on gpu "
<< resource_->dev_id(i);
<< resource_->dev_id(gpu_id);
CUDA_CHECK(cudaMemcpy(gpu_graph_list_[offset].neighbor_list,
g.neighbor_list,
g.neighbor_size * sizeof(uint64_t),
......@@ -907,78 +900,13 @@ void GpuPsGraphTable::build_graph_on_single_gpu(const GpuPsCommGraph& g,
gpu_graph_list_[offset].neighbor_list = NULL;
gpu_graph_list_[offset].neighbor_size = 0;
}
VLOG(0) << " gpu node_neighbor info card: " << i << " ,node_size is "
VLOG(0) << " gpu node_neighbor info card: " << gpu_id << " ,node_size is "
<< gpu_graph_list_[offset].node_size << ", neighbor_size is "
<< gpu_graph_list_[offset].neighbor_size;
}
void GpuPsGraphTable::build_graph_fea_from_cpu(
const std::vector<GpuPsCommGraphFea>& cpu_graph_fea_list, int ntype_id) {
PADDLE_ENFORCE_EQ(
cpu_graph_fea_list.size(),
resource_->total_device(),
platform::errors::InvalidArgument("the cpu node list size doesn't match "
"the number of gpu on your machine."));
clear_feature_info(ntype_id);
for (int i = 0; i < cpu_graph_fea_list.size(); i++) {
int table_offset =
get_table_offset(i, GraphTableType::FEATURE_TABLE, ntype_id);
int offset = i * feature_table_num_ + ntype_id;
platform::CUDADeviceGuard guard(resource_->dev_id(i));
gpu_graph_fea_list_[offset] = GpuPsCommGraphFea();
tables_[table_offset] = new Table(
std::max((uint64_t)1, (uint64_t)cpu_graph_fea_list[i].node_size) /
load_factor_);
if (cpu_graph_fea_list[i].node_size > 0) {
build_ps(i,
cpu_graph_fea_list[i].node_list,
reinterpret_cast<uint64_t*>(cpu_graph_fea_list[i].fea_info_list),
cpu_graph_fea_list[i].node_size,
1024,
8,
table_offset);
gpu_graph_fea_list_[offset].node_size = cpu_graph_fea_list[i].node_size;
} else {
build_ps(i, NULL, NULL, 0, 1024, 8, table_offset);
gpu_graph_fea_list_[offset].node_list = NULL;
gpu_graph_fea_list_[offset].node_size = 0;
}
if (cpu_graph_fea_list[i].feature_size) {
// TODO
CUDA_CHECK(
cudaMalloc(&gpu_graph_fea_list_[offset].feature_list,
cpu_graph_fea_list[i].feature_size * sizeof(uint64_t)));
CUDA_CHECK(
cudaMemcpy(gpu_graph_fea_list_[offset].feature_list,
cpu_graph_fea_list[i].feature_list,
cpu_graph_fea_list[i].feature_size * sizeof(uint64_t),
cudaMemcpyHostToDevice));
// TODO
CUDA_CHECK(
cudaMalloc(&gpu_graph_fea_list_[offset].slot_id_list,
cpu_graph_fea_list[i].feature_size * sizeof(uint8_t)));
CUDA_CHECK(
cudaMemcpy(gpu_graph_fea_list_[offset].slot_id_list,
cpu_graph_fea_list[i].slot_id_list,
cpu_graph_fea_list[i].feature_size * sizeof(uint8_t),
cudaMemcpyHostToDevice));
gpu_graph_fea_list_[offset].feature_size =
cpu_graph_fea_list[i].feature_size;
} else {
gpu_graph_fea_list_[offset].feature_list = NULL;
gpu_graph_fea_list_[offset].slot_id_list = NULL;
gpu_graph_fea_list_[offset].feature_size = 0;
}
}
cudaDeviceSynchronize();
}
void GpuPsGraphTable::build_graph_from_cpu(
const std::vector<GpuPsCommGraph>& cpu_graph_list, int idx) {
const std::vector<GpuPsCommGraph>& cpu_graph_list, int edge_idx) {
VLOG(0) << "in build_graph_from_cpu cpu_graph_list size = "
<< cpu_graph_list.size();
PADDLE_ENFORCE_EQ(
......@@ -986,10 +914,11 @@ void GpuPsGraphTable::build_graph_from_cpu(
resource_->total_device(),
platform::errors::InvalidArgument("the cpu node list size doesn't match "
"the number of gpu on your machine."));
clear_graph_info(idx);
clear_graph_info(edge_idx);
for (int i = 0; i < cpu_graph_list.size(); i++) {
int table_offset = get_table_offset(i, GraphTableType::EDGE_TABLE, idx);
int offset = i * graph_table_num_ + idx;
int table_offset =
get_table_offset(i, GraphTableType::EDGE_TABLE, edge_idx);
int offset = get_graph_list_offset(i, edge_idx);
platform::CUDADeviceGuard guard(resource_->dev_id(i));
gpu_graph_list_[offset] = GpuPsCommGraph();
tables_[table_offset] =
......@@ -1178,7 +1107,7 @@ NeighborSampleResult GpuPsGraphTable::graph_neighbor_sample_v2(
platform::CUDADeviceGuard guard(resource_->dev_id(i));
// If not found, val is -1.
int table_offset = get_table_offset(i, GraphTableType::EDGE_TABLE, idx);
int offset = i * graph_table_num_ + idx;
int offset = get_graph_list_offset(i, idx);
tables_[table_offset]->get(reinterpret_cast<uint64_t*>(node.key_storage),
reinterpret_cast<uint64_t*>(node.val_storage),
static_cast<size_t>(h_right[i] - h_left[i] + 1),
......@@ -1520,7 +1449,10 @@ NeighborSampleResultV2 GpuPsGraphTable::graph_neighbor_sample_all_edge_type(
reinterpret_cast<GpuPsNodeInfo*>(node.val_storage);
for (int idx = 0; idx < edge_type_len; idx++) {
int table_offset = get_table_offset(i, GraphTableType::EDGE_TABLE, idx);
int offset = i * graph_table_num_ + idx;
int offset = get_graph_list_offset(i, idx);
if (tables_[table_offset] == NULL) {
continue;
}
tables_[table_offset]->get(
reinterpret_cast<uint64_t*>(node.key_storage),
reinterpret_cast<uint64_t*>(node_info_base + idx * shard_len),
......@@ -1732,7 +1664,7 @@ NodeQueryResult GpuPsGraphTable::query_node_list(int gpu_id,
return y2 - x2;
};
int offset = gpu_id * graph_table_num_ + idx;
int offset = get_graph_list_offset(gpu_id, idx);
const auto& graph = gpu_graph_list_[offset];
if (graph.node_size == 0) {
return result;
......@@ -1932,7 +1864,7 @@ int GpuPsGraphTable::get_feature_info_of_nodes(
sizeof(uint32_t) * shard_len[i],
cudaMemcpyDeviceToDevice,
resource_->remote_stream(i, gpu_id)));
int offset = i * feature_table_num_;
int offset = get_graph_fea_list_offset(i);
auto graph = gpu_graph_fea_list_[offset];
uint64_t* feature_array = reinterpret_cast<uint64_t*>(
......@@ -2185,7 +2117,7 @@ int GpuPsGraphTable::get_feature_of_nodes(int gpu_id,
static_cast<size_t>(h_right[i] - h_left[i] + 1),
resource_->remote_stream(i, gpu_id));
int offset = i * feature_table_num_;
int offset = get_graph_fea_list_offset(i);
auto graph = gpu_graph_fea_list_[offset];
GpuPsFeaInfo* val_array = reinterpret_cast<GpuPsFeaInfo*>(node.val_storage);
......
......@@ -69,7 +69,7 @@ void GraphGpuWrapper::init_conf(const std::string &first_node_type,
VLOG(2) << "edge_to_id[" << edge << "] = " << iter->second;
meta_path_[i].push_back(iter->second);
if (edge_to_node_map_.find(iter->second) == edge_to_node_map_.end()) {
auto nodes = paddle::string::split_string<std::string>(edge, "2");
auto nodes = get_ntype_from_etype(edge);
uint64_t src_node_id = node_to_id.find(nodes[0])->second;
uint64_t dst_node_id = node_to_id.find(nodes[1])->second;
edge_to_node_map_[iter->second] = src_node_id << 32 | dst_node_id;
......@@ -81,7 +81,7 @@ void GraphGpuWrapper::init_conf(const std::string &first_node_type,
paddle::string::split_string<std::string>(excluded_train_pair, ";");
VLOG(2) << "excluded_train_pair[" << excluded_train_pair << "]";
for (auto &path : paths) {
auto nodes = paddle::string::split_string<std::string>(path, "2");
auto nodes = get_ntype_from_etype(path);
for (auto &node : nodes) {
auto iter = node_to_id.find(node);
PADDLE_ENFORCE_NE(iter,
......@@ -189,8 +189,7 @@ void GraphGpuWrapper::init_metapath(std::string cur_metapath,
edge_to_id.end(),
platform::errors::NotFound("(%s) is not found in edge_to_id.", node));
cur_parse_metapath_.push_back(iter->second);
auto etype_split = paddle::string::split_string<std::string>(node, "2");
std::string reverse_type = etype_split[1] + "2" + etype_split[0];
std::string reverse_type = get_reverse_etype(node);
iter = edge_to_id.find(reverse_type);
PADDLE_ENFORCE_NE(iter,
edge_to_id.end(),
......@@ -210,8 +209,7 @@ void GraphGpuWrapper::init_metapath(std::string cur_metapath,
std::vector<std::vector<uint64_t>> tmp_keys;
tmp_keys.resize(thread_num);
int first_node_idx;
std::string first_node =
paddle::string::split_string<std::string>(cur_metapath_, "2")[0];
std::string first_node = get_ntype_from_etype(nodes[0])[0];
auto it = node_to_id.find(first_node);
first_node_idx = it->second;
d_graph_train_total_keys_.resize(thread_num);
......@@ -285,53 +283,102 @@ void GraphGpuWrapper::clear_metapath_state() {
}
}
int GraphGpuWrapper::get_all_id(int type,
int GraphGpuWrapper::get_all_id(int table_type,
int slice_num,
std::vector<std::vector<uint64_t>> *output) {
return reinterpret_cast<GpuPsGraphTable *>(graph_table)
->cpu_graph_table_->get_all_id(type, slice_num, output);
->cpu_graph_table_->get_all_id(
(GraphTableType)table_type, slice_num, output);
}
int GraphGpuWrapper::get_all_neighbor_id(
int type, int slice_num, std::vector<std::vector<uint64_t>> *output) {
GraphTableType table_type,
int slice_num,
std::vector<std::vector<uint64_t>> *output) {
return reinterpret_cast<GpuPsGraphTable *>(graph_table)
->cpu_graph_table_->get_all_neighbor_id(type, slice_num, output);
->cpu_graph_table_->get_all_neighbor_id(table_type, slice_num, output);
}
int GraphGpuWrapper::get_all_id(int type,
int GraphGpuWrapper::get_all_id(int table_type,
int idx,
int slice_num,
std::vector<std::vector<uint64_t>> *output) {
return reinterpret_cast<GpuPsGraphTable *>(graph_table)
->cpu_graph_table_->get_all_id(type, idx, slice_num, output);
->cpu_graph_table_->get_all_id(
(GraphTableType)table_type, idx, slice_num, output);
}
int GraphGpuWrapper::get_all_neighbor_id(
int type,
GraphTableType table_type,
int idx,
int slice_num,
std::vector<std::vector<uint64_t>> *output) {
return reinterpret_cast<GpuPsGraphTable *>(graph_table)
->cpu_graph_table_->get_all_neighbor_id(type, idx, slice_num, output);
->cpu_graph_table_->get_all_neighbor_id(
table_type, idx, slice_num, output);
}
int GraphGpuWrapper::get_all_feature_ids(
int type,
GraphTableType table_type,
int idx,
int slice_num,
std::vector<std::vector<uint64_t>> *output) {
return reinterpret_cast<GpuPsGraphTable *>(graph_table)
->cpu_graph_table_->get_all_feature_ids(type, idx, slice_num, output);
->cpu_graph_table_->get_all_feature_ids(
table_type, idx, slice_num, output);
}
int GraphGpuWrapper::get_node_embedding_ids(
int slice_num, std::vector<std::vector<uint64_t>> *output) {
return (reinterpret_cast<GpuPsGraphTable *>(graph_table))
->cpu_graph_table_->get_node_embedding_ids(slice_num, output);
}
std::string GraphGpuWrapper::get_reverse_etype(std::string etype) {
auto etype_split = paddle::string::split_string<std::string>(etype, "2");
if (etype_split.size() == 2) {
std::string reverse_type = etype_split[1] + "2" + etype_split[0];
return reverse_type;
} else if (etype_split.size() == 3) {
std::string reverse_type =
etype_split[2] + "2" + etype_split[1] + "2" + etype_split[0];
return reverse_type;
} else {
PADDLE_THROW(platform::errors::Fatal(
"The format of edge type should be [src2dst] or [src2etype2dst], "
"but got [%s].",
etype));
}
}
std::vector<std::string> GraphGpuWrapper::get_ntype_from_etype(
std::string etype) {
std::vector<std::string> etype_split =
paddle::string::split_string<std::string>(etype, "2");
if (etype_split.size() == 2) {
return etype_split;
} else if (etype_split.size() == 3) {
etype_split.erase(etype_split.begin() + 1);
return etype_split;
} else {
PADDLE_THROW(platform::errors::Fatal(
"The format of edge type should be [src2dst] or [src2etype2dst], "
"but got [%s].",
etype));
}
}
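// A behavior sketch for the two helpers above, using hypothetical node
// types user/item and edge type click:
//   get_reverse_etype("user2item")           -> "item2user"
//   get_reverse_etype("user2click2item")     -> "item2click2user"
//   get_ntype_from_etype("user2click2item")  -> {"user", "item"}
// Any other arity raises the Fatal error shown above.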
void GraphGpuWrapper::set_up_types(const std::vector<std::string> &edge_types,
const std::vector<std::string> &node_types) {
id_to_edge = edge_types;
edge_to_id.clear();
for (size_t table_id = 0; table_id < edge_types.size(); table_id++) {
int res = edge_to_id.size();
edge_to_id[edge_types[table_id]] = res;
}
id_to_feature = node_types;
node_to_id.clear();
for (size_t table_id = 0; table_id < node_types.size(); table_id++) {
int res = node_to_id.size();
node_to_id[node_types[table_id]] = res;
......@@ -404,13 +451,18 @@ void GraphGpuWrapper::load_edge_file(std::string name,
}
}
void GraphGpuWrapper::load_edge_file(std::string etype2files,
std::string graph_data_local_path,
int part_num,
bool reverse) {
void GraphGpuWrapper::load_edge_file(
std::string etype2files,
std::string graph_data_local_path,
int part_num,
bool reverse,
const std::vector<bool> &is_reverse_edge_map) {
reinterpret_cast<GpuPsGraphTable *>(graph_table)
->cpu_graph_table_->parse_edge_and_load(
etype2files, graph_data_local_path, part_num, reverse);
->cpu_graph_table_->parse_edge_and_load(etype2files,
graph_data_local_path,
part_num,
reverse,
is_reverse_edge_map);
}
int GraphGpuWrapper::load_node_file(std::string name, std::string filepath) {
......@@ -433,14 +485,20 @@ int GraphGpuWrapper::load_node_file(std::string ntype2files,
ntype2files, graph_data_local_path, part_num);
}
void GraphGpuWrapper::load_node_and_edge(std::string etype2files,
std::string ntype2files,
std::string graph_data_local_path,
int part_num,
bool reverse) {
void GraphGpuWrapper::load_node_and_edge(
std::string etype2files,
std::string ntype2files,
std::string graph_data_local_path,
int part_num,
bool reverse,
const std::vector<bool> &is_reverse_edge_map) {
reinterpret_cast<GpuPsGraphTable *>(graph_table)
->cpu_graph_table_->load_node_and_edge_file(
etype2files, ntype2files, graph_data_local_path, part_num, reverse);
->cpu_graph_table_->load_node_and_edge_file(etype2files,
ntype2files,
graph_data_local_path,
part_num,
reverse,
is_reverse_edge_map);
}
void GraphGpuWrapper::add_table_feat_conf(std::string table_name,
......@@ -514,28 +572,29 @@ void GraphGpuWrapper::finalize() {
reinterpret_cast<GpuPsGraphTable *>(graph_table)->show_table_collisions();
}
void GraphGpuWrapper::upload_batch(int type,
int idx,
// edge table
void GraphGpuWrapper::upload_batch(int table_type,
int slice_num,
const std::string &edge_type) {
VLOG(0) << "begin upload edge, type[" << edge_type << "]";
VLOG(0) << "begin upload edge, etype[" << edge_type << "]";
auto iter = edge_to_id.find(edge_type);
idx = iter->second;
VLOG(2) << "cur edge: " << edge_type << ",idx: " << idx;
int edge_idx = iter->second;
VLOG(2) << "cur edge: " << edge_type << ", edge_idx: " << edge_idx;
std::vector<std::vector<uint64_t>> ids;
reinterpret_cast<GpuPsGraphTable *>(graph_table)
->cpu_graph_table_->get_all_id(type, idx, slice_num, &ids);
->cpu_graph_table_->get_all_id(
(GraphTableType)table_type, edge_idx, slice_num, &ids);
debug_gpu_memory_info("upload_batch node start");
GpuPsGraphTable *g = reinterpret_cast<GpuPsGraphTable *>(graph_table);
std::vector<std::future<int>> tasks;
for (int i = 0; i < ids.size(); i++) {
tasks.push_back(upload_task_pool->enqueue([&, i, idx, this]() -> int {
for (int i = 0; i < slice_num; i++) {
tasks.push_back(upload_task_pool->enqueue([&, i, edge_idx, this]() -> int {
VLOG(0) << "begin make_gpu_ps_graph, node_id[" << i << "]_size["
<< ids[i].size() << "]";
GpuPsCommGraph sub_graph =
g->cpu_graph_table_->make_gpu_ps_graph(idx, ids[i]);
g->build_graph_on_single_gpu(sub_graph, i, idx);
g->cpu_graph_table_->make_gpu_ps_graph(edge_idx, ids[i]);
g->build_graph_on_single_gpu(sub_graph, i, edge_idx);
sub_graph.release_on_cpu();
VLOG(1) << "sub graph on gpu " << i << " is built";
return 0;
......@@ -546,8 +605,10 @@ void GraphGpuWrapper::upload_batch(int type,
}
// feature table
void GraphGpuWrapper::upload_batch(int type, int slice_num, int slot_num) {
if (type == 1 &&
void GraphGpuWrapper::upload_batch(int table_type,
int slice_num,
int slot_num) {
if (table_type == GraphTableType::FEATURE_TABLE &&
(FLAGS_gpugraph_storage_mode == paddle::framework::GpuGraphStorageMode::
MEM_EMB_FEATURE_AND_GPU_GRAPH ||
FLAGS_gpugraph_storage_mode == paddle::framework::GpuGraphStorageMode::
......@@ -556,11 +617,12 @@ void GraphGpuWrapper::upload_batch(int type, int slice_num, int slot_num) {
}
std::vector<std::vector<uint64_t>> node_ids;
reinterpret_cast<GpuPsGraphTable *>(graph_table)
->cpu_graph_table_->get_all_id(type, slice_num, &node_ids);
->cpu_graph_table_->get_all_id(
(GraphTableType)table_type, slice_num, &node_ids);
debug_gpu_memory_info("upload_batch feature start");
GpuPsGraphTable *g = reinterpret_cast<GpuPsGraphTable *>(graph_table);
std::vector<std::future<int>> tasks;
for (int i = 0; i < node_ids.size(); i++) {
for (int i = 0; i < slice_num; i++) {
tasks.push_back(upload_task_pool->enqueue([&, i, this]() -> int {
VLOG(0) << "begin make_gpu_ps_graph_fea, node_ids[" << i << "]_size["
<< node_ids[i].size() << "]";
......@@ -638,7 +700,7 @@ void GraphGpuWrapper::get_node_degree(
uint64_t *key,
int len,
std::shared_ptr<phi::Allocation> node_degree) {
return ((GpuPsGraphTable *)graph_table)
return (reinterpret_cast<GpuPsGraphTable *>(graph_table))
->get_node_degree(gpu_id, edge_idx, key, len, node_degree);
}
......@@ -830,7 +892,6 @@ std::string &GraphGpuWrapper::get_edge_type_size() {
->cpu_graph_table_->edge_type_size;
std::string delim = ";";
edge_type_size_str_ = paddle::string::join_strings(edge_type_size, delim);
std::cout << "edge_type_size_str: " << edge_type_size_str_ << std::endl;
return edge_type_size_str_;
}
......
......@@ -22,8 +22,11 @@
#include "paddle/fluid/framework/fleet/heter_ps/gpu_graph_node.h"
namespace paddle {
namespace framework {
#ifdef PADDLE_WITH_HETERPS
typedef paddle::distributed::GraphTableType GraphTableType;
enum GpuGraphStorageMode {
WHOLE_HBM = 1,
MEM_EMB_AND_GPU_GRAPH,
......@@ -47,13 +50,14 @@ class GraphGpuWrapper {
void finalize();
void set_device(std::vector<int> ids);
void init_service();
std::string get_reverse_etype(std::string etype);
std::vector<std::string> get_ntype_from_etype(std::string etype);
void set_up_types(const std::vector<std::string>& edge_type,
const std::vector<std::string>& node_type);
void upload_batch(int type,
int idx,
void upload_batch(int table_type,
int slice_num,
const std::string& edge_type);
void upload_batch(int type, int slice_num, int slot_num);
void upload_batch(int table_type, int slice_num, int slot_num);
std::vector<GpuPsCommGraphFea> get_sub_graph_fea(
std::vector<std::vector<uint64_t>>& node_ids, int slot_num); // NOLINT
void build_gpu_graph_fea(GpuPsCommGraphFea& sub_graph_fea, int i); // NOLINT
......@@ -65,7 +69,8 @@ class GraphGpuWrapper {
void load_edge_file(std::string etype2files,
std::string graph_data_local_path,
int part_num,
bool reverse);
bool reverse,
const std::vector<bool>& is_reverse_edge_map);
int load_node_file(std::string name, std::string filepath);
int load_node_file(std::string ntype2files,
......@@ -75,7 +80,8 @@ class GraphGpuWrapper {
std::string ntype2files,
std::string graph_data_local_path,
int part_num,
bool reverse);
bool reverse,
const std::vector<bool>& is_reverse_edge_map);
int32_t load_next_partition(int idx);
int32_t get_partition_num(int idx);
void load_node_weight(int type_id, int idx, std::string path);
......@@ -85,24 +91,26 @@ class GraphGpuWrapper {
void make_complementary_graph(int idx, int64_t byte_size);
void set_search_level(int level);
void init_search_level(int level);
int get_all_id(int type,
int get_all_id(int table_type,
int slice_num,
std::vector<std::vector<uint64_t>>* output);
int get_all_neighbor_id(int type,
int get_all_neighbor_id(GraphTableType table_type,
int slice_num,
std::vector<std::vector<uint64_t>>* output);
int get_all_id(int type,
int get_all_id(int table_type,
int idx,
int slice_num,
std::vector<std::vector<uint64_t>>* output);
int get_all_neighbor_id(int type,
int get_all_neighbor_id(GraphTableType table_type,
int idx,
int slice_num,
std::vector<std::vector<uint64_t>>* output);
int get_all_feature_ids(int type,
int get_all_feature_ids(GraphTableType table_type,
int idx,
int slice_num,
std::vector<std::vector<uint64_t>>* output);
int get_node_embedding_ids(int slice_num,
std::vector<std::vector<uint64_t>>* output);
NodeQueryResult query_node_list(int gpu_id,
int idx,
int start,
......
......@@ -13,9 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <thread>
#include <memory>
#include <vector>
#include "cub/cub.cuh"
#include "cub/util_allocator.cuh"
#if defined(PADDLE_WITH_CUDA)
......@@ -60,7 +59,7 @@ class HeterComm {
HeterComm(size_t capacity, std::shared_ptr<HeterPsResource> resource);
HeterComm(size_t capacity,
std::shared_ptr<HeterPsResource> resource,
const GPUAccessor& gpu_accessor);
GPUAccessor& gpu_accessor); // NOLINT
virtual ~HeterComm();
HeterComm(const HeterComm&) = delete;
HeterComm& operator=(const HeterComm&) = delete;
......@@ -299,10 +298,11 @@ class HeterComm {
struct LocalStorage {
LocalStorage() { sem_wait = std::make_unique<Semaphore>(); }
void init(int device_num, int dev_id) {
void init(int device_num, int dev_id, phi::Stream stream) {
place_ = platform::CUDAPlace(dev_id);
h_recv_offsets.resize(device_num);
h_fea_sizes.resize(device_num);
stream_ = stream;
}
template <typename T>
T* alloc_cache(const size_t& len,
......@@ -310,20 +310,31 @@ class HeterComm {
bool need_copy = false) {
size_t need_mem = len * sizeof(T);
if (alloc.get() == nullptr) {
alloc = memory::Alloc(place_, need_mem);
alloc = memory::Alloc(place_, need_mem, stream_);
} else if (need_mem > alloc->size()) {
if (need_copy) {
std::shared_ptr<memory::Allocation> tmp =
memory::Alloc(place_, need_mem);
cudaMemcpy(tmp->ptr(),
alloc->ptr(),
alloc->size(),
cudaMemcpyDeviceToDevice);
memory::Alloc(place_, need_mem, stream_);
#if defined(PADDLE_WITH_CUDA)
PADDLE_ENFORCE_GPU_SUCCESS(
cudaMemcpyAsync(tmp->ptr(), // output
alloc->ptr(),
alloc->size(),
cudaMemcpyDeviceToDevice,
reinterpret_cast<cudaStream_t>(stream_.id())));
#else
memory::Copy(place_,
tmp->ptr(),
place_,
alloc->ptr(),
alloc->size(),
reinterpret_cast<void*>(stream_.id()));
#endif
alloc.reset();
alloc = tmp;
} else {
alloc.reset();
alloc = memory::Alloc(place_, need_mem);
alloc = memory::Alloc(place_, need_mem, stream_);
}
}
return reinterpret_cast<T*>(alloc->ptr());
......@@ -344,6 +355,11 @@ class HeterComm {
d_merged_vals = all_grads;
d_merged_push_vals = local_grads;
}
void check(const size_t& len,
const size_t& value_bytes = sizeof(GradType)) {
CHECK_GE(all_keys_mem->size(), len);
CHECK_GE(all_grads_mem->size(), len * value_bytes);
}
void init_pull(const size_t& len) {
pull_res.h_recv_fea_num = len;
pull_res.d_restore_keys_idx = alloc_cache<uint32_t>(len, local_pull_idx);
......@@ -375,6 +391,7 @@ class HeterComm {
#elif defined(PADDLE_WITH_XPU_KP)
platform::XPUPlace place_;
#endif
phi::Stream stream_;
std::shared_ptr<memory::Allocation> all_keys_mem = nullptr;
std::shared_ptr<memory::Allocation> all_grads_mem = nullptr;
......@@ -554,8 +571,6 @@ class HeterComm {
size_t gather_sparse_keys_by_all2all(const int& gpu_id,
const size_t& fea_size,
const KeyType* d_in_keys,
KeyType* d_out_keys,
KeyType* d_tmp_keys,
const cudaStream_t& stream);
void scatter_sparse_vals_by_all2all(const int& gpu_id,
const size_t& fea_size,
......@@ -642,6 +657,19 @@ class HeterComm {
const cudaStream_t& stream);
// debug time
void print_debug_time(const int& gpu_id, bool force = false);
// alloc temp memory
template <typename T, typename TPlace, typename StreamType>
T* AllocCache(std::shared_ptr<memory::Allocation>* alloc,
const TPlace& place,
const size_t& byte_len,
const StreamType& stream) {
if (alloc->get() == nullptr || byte_len > (*alloc)->size()) {
alloc->reset();
auto id = phi::Stream(reinterpret_cast<phi::StreamId>(stream));
*alloc = memory::Alloc(place, byte_len, id);
}
return reinterpret_cast<T*>((*alloc)->ptr());
}
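// A minimal usage sketch for AllocCache, assuming place and stream are in
// scope: the cached allocation is reused across calls and only regrown
// when byte_len exceeds the current size.
//   std::shared_ptr<memory::Allocation> cache = nullptr;
//   uint64_t* d_buf =
//       AllocCache<uint64_t>(&cache, place, len * sizeof(uint64_t), stream);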
using Table = HashTable<KeyType, ValType>;
using PtrTable = HashTable<KeyType, float*>;
......
......@@ -54,7 +54,7 @@ template <typename GPUAccessor, template <typename T> class GPUOptimizer>
HeterPs<GPUAccessor, GPUOptimizer>::HeterPs(
size_t capacity,
std::shared_ptr<HeterPsResource> resource,
const GPUAccessor& gpu_accessor) {
GPUAccessor& gpu_accessor) { // NOLINT
comm_ = std::make_shared<HeterComm<FeatureKey, float*, float*, GPUAccessor>>(
capacity, resource, gpu_accessor);
opt_ = GPUOptimizer<GPUAccessor>(gpu_accessor);
......
......@@ -32,7 +32,7 @@ class HeterPs : public HeterPsBase {
HeterPs() {}
HeterPs(size_t capacity,
std::shared_ptr<HeterPsResource> resource,
const GPUAccessor& gpu_accessor);
GPUAccessor& gpu_accessor); // NOLINT
virtual ~HeterPs();
HeterPs(const HeterPs&) = delete;
HeterPs& operator=(const HeterPs&) = delete;
......
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
......@@ -377,6 +377,9 @@ void BindDataset(py::module *m) {
py::call_guard<py::gil_scoped_release>())
.def("get_pass_id",
&framework::Dataset::GetPassID,
py::call_guard<py::gil_scoped_release>())
.def("dump_walk_path",
&framework::Dataset::DumpWalkPath,
py::call_guard<py::gil_scoped_release>());
py::class_<IterableDatasetWrapper>(*m, "IterableDatasetWrapper")
......
This diff is collapsed.
This diff is collapsed.