未验证 提交 b3270adf 编写于 作者: Z zhaocaibei123 提交者: GitHub

统一ps refine (#41234)

* update name

* update name

* fix test

* fix fleet bind

* update name

* update name

* fix test

* fix gpups wrapper

* remove Push/Pull/Load/Save with context in client and wrapper base class

* fix

* fix
Co-authored-by: esythan <esythan@126.com>
上级 cb124156
...@@ -50,7 +50,7 @@ class DownpourPsClientService : public PsService { ...@@ -50,7 +50,7 @@ class DownpourPsClientService : public PsService {
DownpourPsClientService() {} DownpourPsClientService() {}
virtual ~DownpourPsClientService() {} virtual ~DownpourPsClientService() {}
virtual int32_t configure(PSClient *client, size_t rank_id) { virtual int32_t Configure(PSClient *client, size_t rank_id) {
_client = client; _client = client;
_rank = rank_id; _rank = rank_id;
return 0; return 0;
...@@ -139,7 +139,7 @@ class BrpcPsClient : public PSClient { ...@@ -139,7 +139,7 @@ class BrpcPsClient : public PSClient {
BrpcPsClient() {} BrpcPsClient() {}
virtual ~BrpcPsClient() { virtual ~BrpcPsClient() {
if (_running) { if (_running) {
flush(); Flush();
_running = false; _running = false;
} }
if (_async_push_dense_thread.joinable()) { if (_async_push_dense_thread.joinable()) {
...@@ -154,109 +154,98 @@ class BrpcPsClient : public PSClient { ...@@ -154,109 +154,98 @@ class BrpcPsClient : public PSClient {
_server_started = false; _server_started = false;
} }
} }
virtual int32_t create_client2client_connection( virtual int32_t CreateClient2ClientConnection(int pserver_timeout_ms,
int pserver_timeout_ms, int pserver_connect_timeout_ms, int max_retry); int pserver_connect_timeout_ms,
std::future<int32_t> shrink(uint32_t table_id, int max_retry);
std::future<int32_t> Shrink(uint32_t table_id,
const std::string threshold) override; const std::string threshold) override;
std::future<int32_t> load(const std::string &epoch, std::future<int32_t> Load(const std::string &epoch,
const std::string &mode) override; const std::string &mode) override;
std::future<int32_t> load(uint32_t table_id, const std::string &epoch, std::future<int32_t> Load(uint32_t table_id, const std::string &epoch,
const std::string &mode) override; const std::string &mode) override;
std::future<int32_t> Load(const LoadSaveContext &load_context) override; std::future<int32_t> Save(const std::string &epoch,
std::future<int32_t> save(const std::string &epoch,
const std::string &mode) override; const std::string &mode) override;
std::future<int32_t> save(uint32_t table_id, const std::string &epoch, std::future<int32_t> Save(uint32_t table_id, const std::string &epoch,
const std::string &mode) override; const std::string &mode) override;
virtual std::future<int32_t> Save( std::future<int32_t> Clear() override;
const LoadSaveContext &save_context) override;
std::future<int32_t> clear() override;
std::future<int32_t> clear(uint32_t table_id) override;
std::future<int32_t> stop_server() override; std::future<int32_t> Clear(uint32_t table_id) override;
std::future<int32_t> start_profiler() override; std::future<int32_t> StopServer() override;
std::future<int32_t> stop_profiler() override;
void finalize_worker() override; std::future<int32_t> StartProfiler() override;
std::future<int32_t> StopProfiler() override;
virtual std::future<int32_t> pull_dense(Region *regions, size_t region_num, void FinalizeWorker() override;
size_t table_id);
virtual std::future<int32_t> push_dense_param(const Region *regions, virtual std::future<int32_t> PullDense(Region *regions, size_t region_num,
size_t region_num, size_t table_id);
size_t table_id);
virtual std::future<int32_t> push_dense(const Region *regions, virtual std::future<int32_t> PushDenseParam(const Region *regions,
size_t region_num, size_t table_id); size_t region_num,
void push_dense_task_consume(); size_t table_id);
virtual std::future<int32_t> pull_sparse(float **select_values,
size_t table_id,
const uint64_t *keys, size_t num,
bool is_training);
virtual std::future<int32_t> pull_sparse_param(float **select_values,
size_t table_id,
const uint64_t *keys,
size_t num, bool is_training);
virtual std::future<int32_t> Pull(RequestContext &pull_context) override; virtual std::future<int32_t> PushDense(const Region *regions,
size_t region_num, size_t table_id);
void PushDenseTaskConsume();
virtual std::future<int32_t> PullSparse(float **select_values,
size_t table_id, const uint64_t *keys,
size_t num, bool is_training);
virtual std::future<int32_t> PullSparseParam(float **select_values,
size_t table_id,
const uint64_t *keys, size_t num,
bool is_training);
virtual std::future<int32_t> Push(RequestContext &push_context) override; virtual std::future<int32_t> PrintTableStat(uint32_t table_id);
virtual std::future<int32_t> print_table_stat(uint32_t table_id); virtual std::future<int32_t> Barrier(size_t table_id, uint32_t barrier_type);
virtual std::future<int32_t> barrier(size_t table_id, uint32_t barrier_type); virtual std::future<int32_t> PullGeoParam(size_t table_id,
std::vector<float> *values,
std::vector<uint64_t> *keys,
int pserver_idx);
virtual std::future<int32_t> PushGlobalStep(int table_id,
int64_t *total_send_data,
void *done);
virtual std::future<int32_t> Flush();
virtual std::future<int32_t> pull_geo_param(size_t table_id, std::future<int32_t> SendClient2ClientMsg(int msg_type, int to_client_id,
std::vector<float> *values, const std::string &msg) override;
std::vector<uint64_t> *keys,
int pserver_idx);
virtual std::future<int32_t> push_global_step(int table_id,
int64_t *total_send_data,
void *done);
virtual std::future<int32_t> flush();
std::future<int32_t> send_client2client_msg(int msg_type, int to_client_id,
const std::string &msg) override;
// for local save sparse // for local save sparse
virtual int32_t recv_and_save_table(const uint64_t table_id, virtual int32_t RecvAndSaveTable(const uint64_t table_id,
const std::string &path); const std::string &path);
void print_queue_size(); void PrintQueueSize();
void print_queue_size_thread(); void PrintQueueSizeThread();
protected: protected:
virtual size_t get_server_nums() { return _server_channels.size(); } virtual size_t GetServerNums() { return _server_channels.size(); }
inline brpc::Channel *get_sparse_channel(size_t server_id) { inline brpc::Channel *GetSparseChannel(size_t server_id) {
return _server_channels[server_id][0].get(); return _server_channels[server_id][0].get();
} }
inline brpc::Channel *get_dense_channel(size_t server_id) { inline brpc::Channel *GetDenseChannel(size_t server_id) {
return _server_channels[server_id][1].get(); return _server_channels[server_id][1].get();
} }
inline brpc::Channel *get_cmd_channel(size_t server_id) { inline brpc::Channel *GetCmdChannel(size_t server_id) {
return _server_channels[server_id][2].get(); return _server_channels[server_id][2].get();
} }
int32_t initialize() override; int32_t Initialize() override;
private: private:
// virtual int32_t initialize() override; inline uint32_t DenseDimPerShard(uint32_t dense_dim_total,
uint32_t shard_num) {
inline uint32_t dense_dim_per_shard(uint32_t dense_dim_total,
uint32_t shard_num) {
return dense_dim_total / shard_num + 1; return dense_dim_total / shard_num + 1;
} }
std::future<int32_t> send_cmd(uint32_t table_id, int cmd_id, std::future<int32_t> SendCmd(uint32_t table_id, int cmd_id,
const std::vector<std::string> &param); const std::vector<std::string> &param);
std::future<int32_t> send_save_cmd(uint32_t table_id, int cmd_id, std::future<int32_t> SendSaveCmd(uint32_t table_id, int cmd_id,
const std::vector<std::string> &param); const std::vector<std::string> &param);
bool _running = false; bool _running = false;
bool _flushing = false; bool _flushing = false;
...@@ -276,12 +265,12 @@ class BrpcPsClient : public PSClient { ...@@ -276,12 +265,12 @@ class BrpcPsClient : public PSClient {
std::thread _print_thread; std::thread _print_thread;
int push_sparse_async_shard_merge( int PushSparseAsyncShardMerge(
std::vector<std::shared_ptr<SparseAsyncTask>> &task_list, // NOLINT std::vector<std::shared_ptr<SparseAsyncTask>> &task_list, // NOLINT
std::vector<int> &request_kv_num, int table_id, int shard_idx, // NOLINT std::vector<int> &request_kv_num, int table_id, int shard_idx, // NOLINT
ValueAccessor *accessor); ValueAccessor *accessor);
int push_sparse_async_shard_push( int PushSparseAsyncShardPush(
std::vector<std::shared_ptr<SparseAsyncTask>> &task_list, // NOLINT std::vector<std::shared_ptr<SparseAsyncTask>> &task_list, // NOLINT
std::vector<int> &request_kv_num, int table_id, int shard_idx, // NOLINT std::vector<int> &request_kv_num, int table_id, int shard_idx, // NOLINT
DownpourBrpcClosure *closure, ValueAccessor *accessor); DownpourBrpcClosure *closure, ValueAccessor *accessor);
...@@ -292,36 +281,36 @@ class BrpcPsClient : public PSClient { ...@@ -292,36 +281,36 @@ class BrpcPsClient : public PSClient {
_client_channels; // client2client _client_channels; // client2client
std::vector<std::array<std::shared_ptr<brpc::Channel>, 3>> std::vector<std::array<std::shared_ptr<brpc::Channel>, 3>>
_server_channels; // client2server _server_channels; // client2server
std::future<int32_t> push_dense_raw_gradient(int table_id, std::future<int32_t> PushDenseRawGradient(int table_id,
float *total_send_data, float *total_send_data,
size_t total_send_data_size, size_t total_send_data_size,
void *done) override; void *done) override;
std::future<int32_t> push_sparse_raw_gradient(size_t table_id, std::future<int32_t> PushSparseRawGradient(size_t table_id,
const uint64_t *keys, const uint64_t *keys,
const float **update_values, const float **update_values,
size_t num, size_t num, void *done) override;
void *done) override;
std::future<int32_t> PushSparseRawGradientPartial(size_t table_id,
std::future<int32_t> push_sparse_raw_gradient_partial( const uint64_t *keys,
size_t table_id, const uint64_t *keys, const float **update_values, const float **update_values,
uint32_t num, void *done, int pserver_idx) override; uint32_t num, void *done,
int pserver_idx) override;
std::future<int32_t> push_sparse_param(size_t table_id, const uint64_t *keys,
const float **update_values, std::future<int32_t> PushSparseParam(size_t table_id, const uint64_t *keys,
size_t num, void *done) override; const float **update_values, size_t num,
std::future<int32_t> push_sparse(size_t table_id, const uint64_t *keys, void *done) override;
const float **update_values, std::future<int32_t> PushSparse(size_t table_id, const uint64_t *keys,
size_t num) override; const float **update_values,
void push_sparse_task_consume(); size_t num) override;
void PushSparseTaskConsume();
private: private:
int32_t start_client_service(); int32_t StartClientService();
void push_dense_raw_gradient(std::shared_ptr<DenseAsyncTask> &task, // NOLINT void PushDenseRawGradient(std::shared_ptr<DenseAsyncTask> &task, // NOLINT
float *total_send_data, float *total_send_data, size_t total_send_data_size,
size_t total_send_data_size, DownpourBrpcClosure *closure);
DownpourBrpcClosure *closure);
float _mae = 0; float _mae = 0;
float _mse = 0; float _mse = 0;
uint16_t _push_times = 0; uint16_t _push_times = 0;
......
...@@ -41,8 +41,8 @@ class BrpcPsServer : public PSServer { ...@@ -41,8 +41,8 @@ class BrpcPsServer : public PSServer {
public: public:
BrpcPsServer() {} BrpcPsServer() {}
virtual ~BrpcPsServer() {} virtual ~BrpcPsServer() {}
virtual uint64_t start(const std::string &ip, uint32_t port); virtual uint64_t Start(const std::string &ip, uint32_t port);
virtual int32_t stop() { virtual int32_t Stop() {
std::unique_lock<std::mutex> lock(mutex_); std::unique_lock<std::mutex> lock(mutex_);
stoped_ = true; stoped_ = true;
cv_.notify_all(); cv_.notify_all();
...@@ -51,10 +51,10 @@ class BrpcPsServer : public PSServer { ...@@ -51,10 +51,10 @@ class BrpcPsServer : public PSServer {
_server.Join(); _server.Join();
return 0; return 0;
} }
int32_t port(); int32_t Port();
private: private:
virtual int32_t initialize(); virtual int32_t Initialize();
mutable std::mutex mutex_; mutable std::mutex mutex_;
std::condition_variable cv_; std::condition_variable cv_;
bool stoped_ = false; bool stoped_ = false;
...@@ -71,7 +71,7 @@ typedef int32_t (BrpcPsService::*serviceHandlerFunc)( ...@@ -71,7 +71,7 @@ typedef int32_t (BrpcPsService::*serviceHandlerFunc)(
class BrpcPsService : public PsBaseService { class BrpcPsService : public PsBaseService {
public: public:
virtual int32_t initialize() override; virtual int32_t Initialize() override;
virtual void service(::google::protobuf::RpcController *controller, virtual void service(::google::protobuf::RpcController *controller,
const PsRequestMessage *request, const PsRequestMessage *request,
...@@ -79,50 +79,49 @@ class BrpcPsService : public PsBaseService { ...@@ -79,50 +79,49 @@ class BrpcPsService : public PsBaseService {
::google::protobuf::Closure *done) override; ::google::protobuf::Closure *done) override;
private: private:
int32_t initialize_shard_info(); int32_t InitializeShardInfo();
int32_t pull_dense(Table *table, const PsRequestMessage &request, int32_t PullDense(Table *table, const PsRequestMessage &request,
PsResponseMessage &response, brpc::Controller *cntl); PsResponseMessage &response, brpc::Controller *cntl);
int32_t push_dense(Table *table, const PsRequestMessage &request, int32_t PushDense(Table *table, const PsRequestMessage &request,
PsResponseMessage &response, brpc::Controller *cntl); PsResponseMessage &response, brpc::Controller *cntl);
int32_t push_dense_param(Table *table, const PsRequestMessage &request, int32_t PushDenseParam(Table *table, const PsRequestMessage &request,
PsResponseMessage &response, brpc::Controller *cntl);
int32_t push_sparse_param(Table *table, const PsRequestMessage &request,
PsResponseMessage &response,
brpc::Controller *cntl);
int32_t pull_sparse(Table *table, const PsRequestMessage &request,
PsResponseMessage &response, brpc::Controller *cntl);
int32_t pull_geo_param(Table *table, const PsRequestMessage &request,
PsResponseMessage &response, brpc::Controller *cntl); PsResponseMessage &response, brpc::Controller *cntl);
int32_t barrier(Table *table, const PsRequestMessage &request, int32_t PushSparseParam(Table *table, const PsRequestMessage &request,
PsResponseMessage &response, brpc::Controller *cntl);
int32_t PullSparse(Table *table, const PsRequestMessage &request,
PsResponseMessage &response, brpc::Controller *cntl);
int32_t PullGeoParam(Table *table, const PsRequestMessage &request,
PsResponseMessage &response, brpc::Controller *cntl);
int32_t Barrier(Table *table, const PsRequestMessage &request,
PsResponseMessage &response, brpc::Controller *cntl); PsResponseMessage &response, brpc::Controller *cntl);
int32_t push_sparse(Table *table, const PsRequestMessage &request, int32_t PushSparse(Table *table, const PsRequestMessage &request,
PsResponseMessage &response, brpc::Controller *cntl); PsResponseMessage &response, brpc::Controller *cntl);
int32_t load_one_table(Table *table, const PsRequestMessage &request, int32_t LoadOneTable(Table *table, const PsRequestMessage &request,
PsResponseMessage &response, brpc::Controller *cntl);
int32_t load_all_table(Table *table, const PsRequestMessage &request,
PsResponseMessage &response, brpc::Controller *cntl);
int32_t save_one_table(Table *table, const PsRequestMessage &request,
PsResponseMessage &response, brpc::Controller *cntl);
int32_t save_all_table(Table *table, const PsRequestMessage &request,
PsResponseMessage &response, brpc::Controller *cntl);
int32_t shrink_table(Table *table, const PsRequestMessage &request,
PsResponseMessage &response, brpc::Controller *cntl); PsResponseMessage &response, brpc::Controller *cntl);
int32_t clear_one_table(Table *table, const PsRequestMessage &request, int32_t LoadAllTable(Table *table, const PsRequestMessage &request,
PsResponseMessage &response, brpc::Controller *cntl); PsResponseMessage &response, brpc::Controller *cntl);
int32_t clear_all_table(Table *table, const PsRequestMessage &request, int32_t SaveOneTable(Table *table, const PsRequestMessage &request,
PsResponseMessage &response, brpc::Controller *cntl); PsResponseMessage &response, brpc::Controller *cntl);
int32_t stop_server(Table *table, const PsRequestMessage &request, int32_t SaveAllTable(Table *table, const PsRequestMessage &request,
PsResponseMessage &response, brpc::Controller *cntl);
int32_t ShrinkTable(Table *table, const PsRequestMessage &request,
PsResponseMessage &response, brpc::Controller *cntl); PsResponseMessage &response, brpc::Controller *cntl);
int32_t start_profiler(Table *table, const PsRequestMessage &request, int32_t ClearOneTable(Table *table, const PsRequestMessage &request,
PsResponseMessage &response, brpc::Controller *cntl);
int32_t stop_profiler(Table *table, const PsRequestMessage &request,
PsResponseMessage &response, brpc::Controller *cntl); PsResponseMessage &response, brpc::Controller *cntl);
int32_t ClearAllTable(Table *table, const PsRequestMessage &request,
PsResponseMessage &response, brpc::Controller *cntl);
int32_t StopServer(Table *table, const PsRequestMessage &request,
PsResponseMessage &response, brpc::Controller *cntl);
int32_t StartProfiler(Table *table, const PsRequestMessage &request,
PsResponseMessage &response, brpc::Controller *cntl);
int32_t StopProfiler(Table *table, const PsRequestMessage &request,
PsResponseMessage &response, brpc::Controller *cntl);
int32_t print_table_stat(Table *table, const PsRequestMessage &request, int32_t PrintTableStat(Table *table, const PsRequestMessage &request,
PsResponseMessage &response, brpc::Controller *cntl); PsResponseMessage &response, brpc::Controller *cntl);
int32_t push_global_step(Table *table, const PsRequestMessage &request, int32_t PushGlobalStep(Table *table, const PsRequestMessage &request,
PsResponseMessage &response, brpc::Controller *cntl); PsResponseMessage &response, brpc::Controller *cntl);
bool _is_initialize_shard_info; bool _is_initialize_shard_info;
std::mutex _initialize_shard_mutex; std::mutex _initialize_shard_mutex;
......
...@@ -39,7 +39,7 @@ inline double GetCurrentUS() { ...@@ -39,7 +39,7 @@ inline double GetCurrentUS() {
Communicator::Communicator() {} Communicator::Communicator() {}
void Communicator::init_gflag(const std::string &gflags) { void Communicator::InitGFlag(const std::string &gflags) {
VLOG(3) << "Init With Gflags:" << gflags; VLOG(3) << "Init With Gflags:" << gflags;
std::vector<std::string> flags = paddle::string::split_string(gflags); std::vector<std::string> flags = paddle::string::split_string(gflags);
if (flags.size() < 1) { if (flags.size() < 1) {
...@@ -73,7 +73,7 @@ void Communicator::InitBrpcClient( ...@@ -73,7 +73,7 @@ void Communicator::InitBrpcClient(
} }
std::vector<uint64_t> Communicator::GetClientInfo() { std::vector<uint64_t> Communicator::GetClientInfo() {
std::vector<uint64_t> res = _ps_env.get_client_info(); std::vector<uint64_t> res = _ps_env.GetClientInfo();
for (auto rr : res) { for (auto rr : res) {
VLOG(2) << "Communicator::GetClientInfo " << rr; VLOG(2) << "Communicator::GetClientInfo " << rr;
} }
...@@ -82,7 +82,7 @@ std::vector<uint64_t> Communicator::GetClientInfo() { ...@@ -82,7 +82,7 @@ std::vector<uint64_t> Communicator::GetClientInfo() {
int Communicator::SetClients(std::vector<uint64_t> &host_sign_list) { int Communicator::SetClients(std::vector<uint64_t> &host_sign_list) {
int node = host_sign_list.size(); int node = host_sign_list.size();
return _ps_env.set_ps_clients(host_sign_list.data(), node); return _ps_env.SetPsClients(host_sign_list.data(), node);
} }
void Communicator::RpcRecvDense(const std::vector<std::string> &varnames, void Communicator::RpcRecvDense(const std::vector<std::string> &varnames,
...@@ -114,7 +114,7 @@ void Communicator::RpcRecvDense(const std::vector<std::string> &varnames, ...@@ -114,7 +114,7 @@ void Communicator::RpcRecvDense(const std::vector<std::string> &varnames,
} }
} }
auto status = auto status =
_worker_ptr->pull_dense(regions.data(), regions.size(), table_id); _worker_ptr->PullDense(regions.data(), regions.size(), table_id);
status.wait(); status.wait();
for (auto &t : varnames) { for (auto &t : varnames) {
...@@ -177,7 +177,7 @@ void Communicator::RpcSendDenseParam(const std::vector<std::string> &varnames, ...@@ -177,7 +177,7 @@ void Communicator::RpcSendDenseParam(const std::vector<std::string> &varnames,
} }
} }
auto status = auto status =
_worker_ptr->push_dense_param(regions.data(), regions.size(), table_id); _worker_ptr->PushDenseParam(regions.data(), regions.size(), table_id);
status.wait(); status.wait();
VLOG(4) << "RPC Send Dense Param " << table_id << " done!"; VLOG(4) << "RPC Send Dense Param " << table_id << " done!";
return; return;
...@@ -190,9 +190,9 @@ void Communicator::RpcSendDense(const CommContext &ctx, const Scope &scope) { ...@@ -190,9 +190,9 @@ void Communicator::RpcSendDense(const CommContext &ctx, const Scope &scope) {
auto &var_names = ctx.origin_varnames; auto &var_names = ctx.origin_varnames;
auto &table_id = ctx.table_id; auto &table_id = ctx.table_id;
auto dense_data = std::make_shared<std::vector<float>>(); auto dense_data = std::make_shared<std::vector<float>>();
size_t request_call_num = _worker_ptr->get_server_nums(); size_t request_call_num = _worker_ptr->GetServerNums();
uint32_t num_per_shard = uint32_t num_per_shard =
dense_dim_per_shard(ctx.height_sections[0], request_call_num); DenseDimPerShard(ctx.height_sections[0], request_call_num);
dense_data->resize(num_per_shard * dense_data->resize(num_per_shard *
request_call_num); // accessor->update_dim() = 1 request_call_num); // accessor->update_dim() = 1
float *data = dense_data->data(); float *data = dense_data->data();
...@@ -222,8 +222,8 @@ void Communicator::RpcSendDense(const CommContext &ctx, const Scope &scope) { ...@@ -222,8 +222,8 @@ void Communicator::RpcSendDense(const CommContext &ctx, const Scope &scope) {
closure->set_promise_value(ret); closure->set_promise_value(ret);
--_async_call_num; --_async_call_num;
}); });
auto status = _worker_ptr->push_dense_raw_gradient( auto status = _worker_ptr->PushDenseRawGradient(table_id, data,
table_id, data, dense_data->size(), closure); dense_data->size(), closure);
status.wait(); status.wait();
return; return;
} }
...@@ -233,7 +233,7 @@ void Communicator::RpcSendSparseParam(const std::string &varname, int table_id, ...@@ -233,7 +233,7 @@ void Communicator::RpcSendSparseParam(const std::string &varname, int table_id,
platform::RecordEvent record_event("Communicator->RpcSendSparseParam", platform::RecordEvent record_event("Communicator->RpcSendSparseParam",
platform::TracerEventType::Communication, platform::TracerEventType::Communication,
1); 1);
size_t request_call_num = _worker_ptr->get_server_nums(); size_t request_call_num = _worker_ptr->GetServerNums();
std::vector<float *> push_g_vec; std::vector<float *> push_g_vec;
auto *send_var = scope.FindVar(varname); auto *send_var = scope.FindVar(varname);
...@@ -260,9 +260,9 @@ void Communicator::RpcSendSparseParam(const std::string &varname, int table_id, ...@@ -260,9 +260,9 @@ void Communicator::RpcSendSparseParam(const std::string &varname, int table_id,
} }
closure->set_promise_value(ret); closure->set_promise_value(ret);
}); });
auto status = _worker_ptr->push_sparse_param( auto status = _worker_ptr->PushSparseParam(table_id, sparse_push_keys.data(),
table_id, sparse_push_keys.data(), (const float **)push_g_vec.data(), (const float **)push_g_vec.data(),
sparse_push_keys.size(), closure); sparse_push_keys.size(), closure);
status.wait(); status.wait();
return; return;
} }
...@@ -272,7 +272,7 @@ void Communicator::RpcSendSparse(const std::string &var_name, int table_id, ...@@ -272,7 +272,7 @@ void Communicator::RpcSendSparse(const std::string &var_name, int table_id,
platform::RecordEvent record_event("Communicator->RpcSendSparse", platform::RecordEvent record_event("Communicator->RpcSendSparse",
platform::TracerEventType::Communication, platform::TracerEventType::Communication,
1); 1);
size_t request_call_num = _worker_ptr->get_server_nums(); size_t request_call_num = _worker_ptr->GetServerNums();
std::vector<uint64_t> sparse_push_keys; std::vector<uint64_t> sparse_push_keys;
std::vector<float *> push_g_vec; std::vector<float *> push_g_vec;
...@@ -313,7 +313,7 @@ void Communicator::RpcSendSparse(const std::string &var_name, int table_id, ...@@ -313,7 +313,7 @@ void Communicator::RpcSendSparse(const std::string &var_name, int table_id,
closure->set_promise_value(ret); closure->set_promise_value(ret);
--_async_call_num; --_async_call_num;
}); });
auto status = _worker_ptr->push_sparse_raw_gradient( auto status = _worker_ptr->PushSparseRawGradient(
table_id, sparse_push_keys.data(), (const float **)push_g_vec.data(), table_id, sparse_push_keys.data(), (const float **)push_g_vec.data(),
sparse_push_keys.size(), closure); sparse_push_keys.size(), closure);
status.wait(); status.wait();
...@@ -340,7 +340,7 @@ void Communicator::RpcRecvSparse(const std::string &varname, int table_id, ...@@ -340,7 +340,7 @@ void Communicator::RpcRecvSparse(const std::string &varname, int table_id,
bool training = true; bool training = true;
auto status = _worker_ptr->pull_sparse_param( auto status = _worker_ptr->PullSparseParam(
(float **)push_g_vec.data(), table_id, // NOLINT (float **)push_g_vec.data(), table_id, // NOLINT
sparse_push_keys.data(), sparse_push_keys.size(), training); sparse_push_keys.data(), sparse_push_keys.size(), training);
status.wait(); status.wait();
...@@ -376,11 +376,11 @@ void Communicator::RpcProfilerControl() { ...@@ -376,11 +376,11 @@ void Communicator::RpcProfilerControl() {
if (!do_server_profiler_ && platform::IsProfileEnabled()) { if (!do_server_profiler_ && platform::IsProfileEnabled()) {
// send profiler start flag // send profiler start flag
do_server_profiler_ = true; do_server_profiler_ = true;
auto start_status = _worker_ptr->start_profiler(); auto start_status = _worker_ptr->StartProfiler();
start_status.wait(); start_status.wait();
} else if (do_server_profiler_ && !platform::IsProfileEnabled()) { } else if (do_server_profiler_ && !platform::IsProfileEnabled()) {
// send profiler end flag // send profiler end flag
auto stop_status = _worker_ptr->stop_profiler(); auto stop_status = _worker_ptr->StopProfiler();
stop_status.wait(); stop_status.wait();
do_server_profiler_ = false; do_server_profiler_ = false;
} }
...@@ -396,7 +396,7 @@ void Communicator::SendGlobalStep(const CommContext &ctx, int batches, ...@@ -396,7 +396,7 @@ void Communicator::SendGlobalStep(const CommContext &ctx, int batches,
platform::TracerEventType::Communication, platform::TracerEventType::Communication,
1); 1);
auto &table_id = ctx.table_id; auto &table_id = ctx.table_id;
size_t request_call_num = _worker_ptr->get_server_nums(); size_t request_call_num = _worker_ptr->GetServerNums();
auto &var_name = STEP_COUNTER; auto &var_name = STEP_COUNTER;
auto *out_var = send_scope->Var(var_name); auto *out_var = send_scope->Var(var_name);
...@@ -416,7 +416,7 @@ void Communicator::SendGlobalStep(const CommContext &ctx, int batches, ...@@ -416,7 +416,7 @@ void Communicator::SendGlobalStep(const CommContext &ctx, int batches,
} }
closure->set_promise_value(ret); closure->set_promise_value(ret);
}); });
auto status = _worker_ptr->push_global_step(table_id, data, closure); auto status = _worker_ptr->PushGlobalStep(table_id, data, closure);
status.wait(); status.wait();
return; return;
} }
...@@ -605,8 +605,8 @@ void AsyncCommunicator::PullSparseToTensorSync( ...@@ -605,8 +605,8 @@ void AsyncCommunicator::PullSparseToTensorSync(
} }
} }
auto status = auto status =
_worker_ptr->pull_sparse(pull_result_ptr.data(), table_id, _worker_ptr->PullSparse(pull_result_ptr.data(), table_id, fea_keys.data(),
fea_keys.data(), fea_keys.size(), is_training); fea_keys.size(), is_training);
status.wait(); status.wait();
auto ret = status.get(); auto ret = status.get();
if (ret != 0) { if (ret != 0) {
...@@ -738,9 +738,9 @@ void AsyncCommunicator::PushSparseFromTensorAsync( ...@@ -738,9 +738,9 @@ void AsyncCommunicator::PushSparseFromTensorAsync(
this->Check(table_id), true, this->Check(table_id), true,
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"can not find table: %s, please check your config", table_id)); "can not find table: %s, please check your config", table_id));
auto status = _worker_ptr->push_sparse(table_id, push_keys.data(), auto status = _worker_ptr->PushSparse(table_id, push_keys.data(),
(const float **)push_g_vec.data(), (const float **)push_g_vec.data(),
push_keys.size()); push_keys.size());
} }
void HalfAsyncCommunicator::MainThread() { void HalfAsyncCommunicator::MainThread() {
...@@ -813,7 +813,7 @@ void AsyncCommunicator::Stop() { ...@@ -813,7 +813,7 @@ void AsyncCommunicator::Stop() {
if (!communicator_) { if (!communicator_) {
VLOG(0) << "Communicator is not inited, do nothing"; VLOG(0) << "Communicator is not inited, do nothing";
} else { } else {
// _worker_ptr->finalize_worker(); // _worker_ptr->FinalizeWorker();
VLOG(1) << "client finalize_worker done"; VLOG(1) << "client finalize_worker done";
if (recv_thread_) { if (recv_thread_) {
VLOG(1) << "stop recv thread"; VLOG(1) << "stop recv thread";
...@@ -1327,7 +1327,7 @@ void GeoCommunicator::SendSparse(const std::string &varname, ...@@ -1327,7 +1327,7 @@ void GeoCommunicator::SendSparse(const std::string &varname,
closure->set_promise_value(ret); closure->set_promise_value(ret);
--_async_call_num; --_async_call_num;
}); });
auto status = _worker_ptr->push_sparse_raw_gradient_partial( auto status = _worker_ptr->PushSparseRawGradientPartial(
table_id, (const uint64_t *)sparse_ids.data(), table_id, (const uint64_t *)sparse_ids.data(),
(const float **)push_g_vec.data(), sparse_ids.size(), closure, ep_idx); (const float **)push_g_vec.data(), sparse_ids.size(), closure, ep_idx);
status.wait(); status.wait();
...@@ -1345,7 +1345,7 @@ void GeoCommunicator::RecvSparse(const std::string &varname, int table_id, ...@@ -1345,7 +1345,7 @@ void GeoCommunicator::RecvSparse(const std::string &varname, int table_id,
// 1. recv from pserver // 1. recv from pserver
std::vector<uint64_t> keys; std::vector<uint64_t> keys;
std::vector<float> values; std::vector<float> values;
auto status = _worker_ptr->pull_geo_param(table_id, &values, &keys, ep_idx); auto status = _worker_ptr->PullGeoParam(table_id, &values, &keys, ep_idx);
status.wait(); status.wait();
std::string param = SplitedGradToParam(varname); std::string param = SplitedGradToParam(varname);
......
...@@ -299,7 +299,7 @@ class Communicator { ...@@ -299,7 +299,7 @@ class Communicator {
virtual void Barrier() {} virtual void Barrier() {}
virtual void BarrierWithTable(uint32_t barrier_type) { virtual void BarrierWithTable(uint32_t barrier_type) {
auto rets = _worker_ptr->barrier(barrier_table_id_, barrier_type); auto rets = _worker_ptr->Barrier(barrier_table_id_, barrier_type);
rets.wait(); rets.wait();
int status = rets.get(); int status = rets.get();
PADDLE_ENFORCE_EQ(status, 0, PADDLE_ENFORCE_EQ(status, 0,
...@@ -310,7 +310,7 @@ class Communicator { ...@@ -310,7 +310,7 @@ class Communicator {
virtual void CreateC2CConnection(int pserver_timeout_ms, virtual void CreateC2CConnection(int pserver_timeout_ms,
int pserver_connect_timeout_ms, int pserver_connect_timeout_ms,
int max_retry) { int max_retry) {
_worker_ptr->create_client2client_connection( _worker_ptr->CreateClient2ClientConnection(
pserver_timeout_ms, pserver_connect_timeout_ms, max_retry); pserver_timeout_ms, pserver_connect_timeout_ms, max_retry);
} }
...@@ -379,12 +379,12 @@ class Communicator { ...@@ -379,12 +379,12 @@ class Communicator {
std::unordered_map<std::string, std::string> envs; std::unordered_map<std::string, std::string> envs;
// 计算每个shard 对 dense的存储量 // 计算每个shard 对 dense的存储量
inline uint32_t dense_dim_per_shard(uint32_t dense_dim_total, inline uint32_t DenseDimPerShard(uint32_t dense_dim_total,
uint32_t shard_num) { uint32_t shard_num) {
return dense_dim_total / shard_num + 1; return dense_dim_total / shard_num + 1;
} }
void init_gflag(const std::string &gflags); void InitGFlag(const std::string &gflags);
paddle::distributed::PSParameter _ps_param; paddle::distributed::PSParameter _ps_param;
paddle::distributed::PaddlePSEnvironment _ps_env; paddle::distributed::PaddlePSEnvironment _ps_env;
int servers_ = 0; int servers_ = 0;
......
...@@ -40,7 +40,7 @@ struct PSHost { ...@@ -40,7 +40,7 @@ struct PSHost {
// |---ip---|---port---|--rank--| // |---ip---|---port---|--rank--|
// |-32bit--|--20bit---|--12bit-| // |-32bit--|--20bit---|--12bit-|
uint64_t serialize_to_uint64() { uint64_t SerializeToUint64() {
uint64_t host_label = 0; uint64_t host_label = 0;
host_label = inet_addr(ip.c_str()); host_label = inet_addr(ip.c_str());
host_label = host_label << 32; host_label = host_label << 32;
...@@ -49,7 +49,7 @@ struct PSHost { ...@@ -49,7 +49,7 @@ struct PSHost {
return host_label; return host_label;
} }
void parse_from_uint64(uint64_t host_label) { void ParseFromUint64(uint64_t host_label) {
static uint64_t rank_label_mask = (1L << 12) - 1; static uint64_t rank_label_mask = (1L << 12) - 1;
static uint64_t port_label_mask = (1L << 20) - 1; static uint64_t port_label_mask = (1L << 20) - 1;
rank = host_label & rank_label_mask; rank = host_label & rank_label_mask;
...@@ -58,17 +58,17 @@ struct PSHost { ...@@ -58,17 +58,17 @@ struct PSHost {
ip = inet_ntoa(*(in_addr *)&ip_addr); // NOLINT ip = inet_ntoa(*(in_addr *)&ip_addr); // NOLINT
} }
std::string to_string() { std::string ToString() {
std::stringstream s; std::stringstream s;
s << "host: " << ip; s << "host: " << ip;
s << " port: " << port; s << " port: " << port;
s << " rank: " << rank; s << " rank: " << rank;
s << " uint: " << serialize_to_uint64(); s << " uint: " << SerializeToUint64();
return s.str(); return s.str();
} }
// for open source parameter server // for open source parameter server
std::string serialize_to_string() { std::string SerializeToString() {
std::stringstream s; std::stringstream s;
s << ip << ":"; s << ip << ":";
s << port << ":"; s << port << ":";
...@@ -76,16 +76,16 @@ struct PSHost { ...@@ -76,16 +76,16 @@ struct PSHost {
return s.str(); return s.str();
} }
void parse_from_string(std::string endpoint) { void ParseFromString(std::string endpoint) {
std::vector<std::string> endpoint_info; std::vector<std::string> endpoint_info;
string_split(endpoint, ':', &endpoint_info); StringSplit(endpoint, ':', &endpoint_info);
ip = endpoint_info[0]; ip = endpoint_info[0];
port = std::stoi(endpoint_info[1]); port = std::stoi(endpoint_info[1]);
rank = std::stoi(endpoint_info[2]); rank = std::stoi(endpoint_info[2]);
} }
void string_split(const std::string &str, char sep, void StringSplit(const std::string &str, char sep,
std::vector<std::string> *pieces, bool ignore_null = true) { std::vector<std::string> *pieces, bool ignore_null = true) {
pieces->clear(); pieces->clear();
if (str.empty()) { if (str.empty()) {
if (!ignore_null) { if (!ignore_null) {
...@@ -111,63 +111,60 @@ class PSEnvironment { ...@@ -111,63 +111,60 @@ class PSEnvironment {
explicit PSEnvironment() {} // NOLINT explicit PSEnvironment() {} // NOLINT
virtual ~PSEnvironment() {} virtual ~PSEnvironment() {}
virtual int32_t set_ps_servers(uint64_t *host_sign_list, int node_num) { virtual int32_t SetPsServers(uint64_t *host_sign_list, int node_num) {
return 0; return 0;
} }
virtual int32_t set_ps_servers( virtual int32_t SetPsServers(
const std::vector<std::string> *host_endpoint_list, int node_num) { const std::vector<std::string> *host_endpoint_list, int node_num) {
return 0; return 0;
} }
virtual int32_t set_ps_clients(uint64_t *host_sign_list, int node_num) { virtual int32_t SetPsClients(uint64_t *host_sign_list, int node_num) {
return 0; return 0;
} }
virtual int32_t set_ps_clients(std::string *host_endpoint_list, virtual int32_t SetPsClients(std::string *host_endpoint_list, int node_num) {
int node_num) {
return 0; return 0;
} }
virtual uint64_t get_local_host_sign() { return 0; } virtual uint64_t GetLocalHostSign() { return 0; }
virtual std::vector<PSHost> get_ps_servers() const { return _ps_server_list; } virtual std::vector<PSHost> GetPsServers() const { return _ps_server_list; }
virtual int32_t registe_ps_server(const std::string &ip, uint32_t port, virtual int32_t RegistePsServer(const std::string &ip, uint32_t port,
int32_t rank) { int32_t rank) {
return registe_ps_host(ip, port, rank, _ps_server_list, return RegistePsHost(ip, port, rank, _ps_server_list, _ps_server_sign_set);
_ps_server_sign_set);
} }
virtual std::vector<PSHost> get_ps_clients() const { return _ps_client_list; } virtual std::vector<PSHost> GetPsClients() const { return _ps_client_list; }
virtual int32_t registe_ps_client(const std::string &ip, uint32_t port, virtual int32_t RegistePsClient(const std::string &ip, uint32_t port,
int32_t rank) { int32_t rank) {
return registe_ps_host(ip, port, rank, _ps_client_list, return RegistePsHost(ip, port, rank, _ps_client_list, _ps_client_sign_set);
_ps_client_sign_set);
} }
virtual std::vector<uint64_t> get_client_info() { virtual std::vector<uint64_t> GetClientInfo() {
std::vector<uint64_t> client_info; std::vector<uint64_t> client_info;
for (auto &i : _ps_client_list) { for (auto &i : _ps_client_list) {
client_info.push_back(i.serialize_to_uint64()); client_info.push_back(i.SerializeToUint64());
} }
return client_info; return client_info;
} }
virtual std::vector<std::string> get_client_info(bool use_string_endpoint) { virtual std::vector<std::string> GetClientInfo(bool use_string_endpoint) {
if (use_string_endpoint) { if (use_string_endpoint) {
std::vector<std::string> client_info; std::vector<std::string> client_info;
for (auto &i : _ps_client_list) { for (auto &i : _ps_client_list) {
client_info.push_back(i.serialize_to_string()); client_info.push_back(i.SerializeToString());
} }
return client_info; return client_info;
} }
return {}; return {};
} }
virtual void set_trainers(int trainers) { trainers_ = trainers; } virtual void SetTrainers(int trainers) { trainers_ = trainers; }
virtual int get_trainers() { return trainers_; } virtual int GetTrainers() { return trainers_; }
protected: protected:
//注册一个host // NOLINT //注册一个host // NOLINT
virtual int32_t registe_ps_host( virtual int32_t RegistePsHost(
const std::string &ip, uint32_t port, int32_t rank, const std::string &ip, uint32_t port, int32_t rank,
std::vector<PSHost> &host_list, // NOLINT std::vector<PSHost> &host_list, // NOLINT
std::unordered_set<uint64_t> &sign_set) { // NOLINT std::unordered_set<uint64_t> &sign_set) { // NOLINT
...@@ -198,15 +195,15 @@ class PaddlePSEnvironment : public PSEnvironment { ...@@ -198,15 +195,15 @@ class PaddlePSEnvironment : public PSEnvironment {
explicit PaddlePSEnvironment() {} // NOLINT explicit PaddlePSEnvironment() {} // NOLINT
virtual ~PaddlePSEnvironment() {} virtual ~PaddlePSEnvironment() {}
virtual int32_t set_ps_servers(uint64_t *host_sign_list, int node_num) { virtual int32_t SetPsServers(uint64_t *host_sign_list, int node_num) {
_ps_server_list.clear(); _ps_server_list.clear();
_ps_server_sign_set.clear(); _ps_server_sign_set.clear();
for (int i = 0; i < node_num; ++i) { for (int i = 0; i < node_num; ++i) {
if (host_sign_list[i] > 0) { if (host_sign_list[i] > 0) {
PSHost host; PSHost host;
host.parse_from_uint64(host_sign_list[i]); host.ParseFromUint64(host_sign_list[i]);
_ps_server_list.push_back(host); _ps_server_list.push_back(host);
_ps_server_sign_set.insert(host.serialize_to_uint64()); _ps_server_sign_set.insert(host.SerializeToUint64());
} }
} }
std::sort( std::sort(
...@@ -215,14 +212,14 @@ class PaddlePSEnvironment : public PSEnvironment { ...@@ -215,14 +212,14 @@ class PaddlePSEnvironment : public PSEnvironment {
return 0; return 0;
} }
virtual int32_t set_ps_servers(const std::vector<std::string> *host_sign_list, virtual int32_t SetPsServers(const std::vector<std::string> *host_sign_list,
int node_num) { int node_num) {
_ps_server_list.clear(); _ps_server_list.clear();
_ps_server_sign_set.clear(); _ps_server_sign_set.clear();
for (int i = 0; i < node_num; ++i) { for (int i = 0; i < node_num; ++i) {
if (host_sign_list->at(i) != "") { if (host_sign_list->at(i) != "") {
PSHost host; PSHost host;
host.parse_from_string(host_sign_list->at(i)); host.ParseFromString(host_sign_list->at(i));
_ps_server_list.push_back(host); _ps_server_list.push_back(host);
_ps_server_sign_set.insert(host.rank); _ps_server_sign_set.insert(host.rank);
} }
...@@ -233,15 +230,15 @@ class PaddlePSEnvironment : public PSEnvironment { ...@@ -233,15 +230,15 @@ class PaddlePSEnvironment : public PSEnvironment {
return 0; return 0;
} }
virtual int32_t set_ps_clients(uint64_t *host_sign_list, int node_num) { virtual int32_t SetPsClients(uint64_t *host_sign_list, int node_num) {
_ps_client_list.clear(); _ps_client_list.clear();
_ps_client_sign_set.clear(); _ps_client_sign_set.clear();
for (int i = 0; i < node_num; ++i) { for (int i = 0; i < node_num; ++i) {
if (host_sign_list[i] > 0) { if (host_sign_list[i] > 0) {
PSHost host; PSHost host;
host.parse_from_uint64(host_sign_list[i]); host.ParseFromUint64(host_sign_list[i]);
_ps_client_list.push_back(host); _ps_client_list.push_back(host);
_ps_client_sign_set.insert(host.serialize_to_uint64()); _ps_client_sign_set.insert(host.SerializeToUint64());
} }
} }
std::sort( std::sort(
...@@ -250,14 +247,14 @@ class PaddlePSEnvironment : public PSEnvironment { ...@@ -250,14 +247,14 @@ class PaddlePSEnvironment : public PSEnvironment {
return 0; return 0;
} }
virtual int32_t set_ps_clients(const std::vector<std::string> *host_sign_list, virtual int32_t SetPsClients(const std::vector<std::string> *host_sign_list,
int node_num) { int node_num) {
_ps_client_list.clear(); _ps_client_list.clear();
_ps_client_sign_set.clear(); _ps_client_sign_set.clear();
for (int i = 0; i < node_num; ++i) { for (int i = 0; i < node_num; ++i) {
if (host_sign_list->at(i) != "") { if (host_sign_list->at(i) != "") {
PSHost host; PSHost host;
host.parse_from_string(host_sign_list->at(i)); host.ParseFromString(host_sign_list->at(i));
_ps_client_list.push_back(host); _ps_client_list.push_back(host);
_ps_client_sign_set.insert(host.rank); _ps_client_sign_set.insert(host.rank);
} }
...@@ -269,9 +266,9 @@ class PaddlePSEnvironment : public PSEnvironment { ...@@ -269,9 +266,9 @@ class PaddlePSEnvironment : public PSEnvironment {
return 0; return 0;
} }
virtual uint64_t get_local_host_sign() { virtual uint64_t GetLocalHostSign() {
if (_ps_client_list.size() > 0) { if (_ps_client_list.size() > 0) {
return _ps_client_list[0].serialize_to_uint64(); return _ps_client_list[0].SerializeToUint64();
} else { } else {
return 0; return 0;
} }
......
...@@ -135,8 +135,7 @@ std::future<int32_t> GraphBrpcClient::get_node_feat( ...@@ -135,8 +135,7 @@ std::future<int32_t> GraphBrpcClient::get_node_feat(
closure->request(request_idx) closure->request(request_idx)
->add_params(joint_feature_name.c_str(), joint_feature_name.size()); ->add_params(joint_feature_name.c_str(), joint_feature_name.size());
GraphPsService_Stub rpc_stub = GraphPsService_Stub rpc_stub = getServiceStub(GetCmdChannel(server_index));
getServiceStub(get_cmd_channel(server_index));
closure->cntl(request_idx)->set_log_id(butil::gettimeofday_ms()); closure->cntl(request_idx)->set_log_id(butil::gettimeofday_ms());
rpc_stub.service(closure->cntl(request_idx), closure->request(request_idx), rpc_stub.service(closure->cntl(request_idx), closure->request(request_idx),
closure->response(request_idx), closure); closure->response(request_idx), closure);
...@@ -169,8 +168,7 @@ std::future<int32_t> GraphBrpcClient::clear_nodes(uint32_t table_id) { ...@@ -169,8 +168,7 @@ std::future<int32_t> GraphBrpcClient::clear_nodes(uint32_t table_id) {
closure->request(server_index)->set_table_id(table_id); closure->request(server_index)->set_table_id(table_id);
closure->request(server_index)->set_client_id(_client_id); closure->request(server_index)->set_client_id(_client_id);
GraphPsService_Stub rpc_stub = GraphPsService_Stub rpc_stub = getServiceStub(GetCmdChannel(server_index));
getServiceStub(get_cmd_channel(server_index));
closure->cntl(server_index)->set_log_id(butil::gettimeofday_ms()); closure->cntl(server_index)->set_log_id(butil::gettimeofday_ms());
rpc_stub.service(closure->cntl(server_index), rpc_stub.service(closure->cntl(server_index),
closure->request(server_index), closure->request(server_index),
...@@ -238,9 +236,8 @@ std::future<int32_t> GraphBrpcClient::add_graph_node( ...@@ -238,9 +236,8 @@ std::future<int32_t> GraphBrpcClient::add_graph_node(
->add_params((char *)weighted, ->add_params((char *)weighted,
sizeof(bool) * is_weighted_bucket[request_idx].size()); sizeof(bool) * is_weighted_bucket[request_idx].size());
} }
// PsService_Stub rpc_stub(get_cmd_channel(server_index)); // PsService_Stub rpc_stub(GetCmdChannel(server_index));
GraphPsService_Stub rpc_stub = GraphPsService_Stub rpc_stub = getServiceStub(GetCmdChannel(server_index));
getServiceStub(get_cmd_channel(server_index));
closure->cntl(request_idx)->set_log_id(butil::gettimeofday_ms()); closure->cntl(request_idx)->set_log_id(butil::gettimeofday_ms());
rpc_stub.service(closure->cntl(request_idx), closure->request(request_idx), rpc_stub.service(closure->cntl(request_idx), closure->request(request_idx),
closure->response(request_idx), closure); closure->response(request_idx), closure);
...@@ -292,9 +289,8 @@ std::future<int32_t> GraphBrpcClient::remove_graph_node( ...@@ -292,9 +289,8 @@ std::future<int32_t> GraphBrpcClient::remove_graph_node(
closure->request(request_idx) closure->request(request_idx)
->add_params((char *)request_bucket[request_idx].data(), ->add_params((char *)request_bucket[request_idx].data(),
sizeof(int64_t) * node_num); sizeof(int64_t) * node_num);
// PsService_Stub rpc_stub(get_cmd_channel(server_index)); // PsService_Stub rpc_stub(GetCmdChannel(server_index));
GraphPsService_Stub rpc_stub = GraphPsService_Stub rpc_stub = getServiceStub(GetCmdChannel(server_index));
getServiceStub(get_cmd_channel(server_index));
closure->cntl(request_idx)->set_log_id(butil::gettimeofday_ms()); closure->cntl(request_idx)->set_log_id(butil::gettimeofday_ms());
rpc_stub.service(closure->cntl(request_idx), closure->request(request_idx), rpc_stub.service(closure->cntl(request_idx), closure->request(request_idx),
closure->response(request_idx), closure); closure->response(request_idx), closure);
...@@ -362,9 +358,8 @@ std::future<int32_t> GraphBrpcClient::batch_sample_neighbors( ...@@ -362,9 +358,8 @@ std::future<int32_t> GraphBrpcClient::batch_sample_neighbors(
closure->request(0)->add_params((char *)&sample_size, sizeof(int)); closure->request(0)->add_params((char *)&sample_size, sizeof(int));
closure->request(0)->add_params((char *)&need_weight, sizeof(bool)); closure->request(0)->add_params((char *)&need_weight, sizeof(bool));
; ;
// PsService_Stub rpc_stub(get_cmd_channel(server_index)); // PsService_Stub rpc_stub(GetCmdChannel(server_index));
GraphPsService_Stub rpc_stub = GraphPsService_Stub rpc_stub = getServiceStub(GetCmdChannel(server_index));
getServiceStub(get_cmd_channel(server_index));
closure->cntl(0)->set_log_id(butil::gettimeofday_ms()); closure->cntl(0)->set_log_id(butil::gettimeofday_ms());
rpc_stub.service(closure->cntl(0), closure->request(0), rpc_stub.service(closure->cntl(0), closure->request(0),
closure->response(0), closure); closure->response(0), closure);
...@@ -464,9 +459,8 @@ std::future<int32_t> GraphBrpcClient::batch_sample_neighbors( ...@@ -464,9 +459,8 @@ std::future<int32_t> GraphBrpcClient::batch_sample_neighbors(
->add_params((char *)&sample_size, sizeof(int)); ->add_params((char *)&sample_size, sizeof(int));
closure->request(request_idx) closure->request(request_idx)
->add_params((char *)&need_weight, sizeof(bool)); ->add_params((char *)&need_weight, sizeof(bool));
// PsService_Stub rpc_stub(get_cmd_channel(server_index)); // PsService_Stub rpc_stub(GetCmdChannel(server_index));
GraphPsService_Stub rpc_stub = GraphPsService_Stub rpc_stub = getServiceStub(GetCmdChannel(server_index));
getServiceStub(get_cmd_channel(server_index));
closure->cntl(request_idx)->set_log_id(butil::gettimeofday_ms()); closure->cntl(request_idx)->set_log_id(butil::gettimeofday_ms());
rpc_stub.service(closure->cntl(request_idx), closure->request(request_idx), rpc_stub.service(closure->cntl(request_idx), closure->request(request_idx),
closure->response(request_idx), closure); closure->response(request_idx), closure);
...@@ -506,8 +500,8 @@ std::future<int32_t> GraphBrpcClient::random_sample_nodes( ...@@ -506,8 +500,8 @@ std::future<int32_t> GraphBrpcClient::random_sample_nodes(
closure->request(0)->set_client_id(_client_id); closure->request(0)->set_client_id(_client_id);
closure->request(0)->add_params((char *)&sample_size, sizeof(int)); closure->request(0)->add_params((char *)&sample_size, sizeof(int));
; ;
// PsService_Stub rpc_stub(get_cmd_channel(server_index)); // PsService_Stub rpc_stub(GetCmdChannel(server_index));
GraphPsService_Stub rpc_stub = getServiceStub(get_cmd_channel(server_index)); GraphPsService_Stub rpc_stub = getServiceStub(GetCmdChannel(server_index));
closure->cntl(0)->set_log_id(butil::gettimeofday_ms()); closure->cntl(0)->set_log_id(butil::gettimeofday_ms());
rpc_stub.service(closure->cntl(0), closure->request(0), closure->response(0), rpc_stub.service(closure->cntl(0), closure->request(0), closure->response(0),
closure); closure);
...@@ -541,8 +535,7 @@ std::future<int32_t> GraphBrpcClient::load_graph_split_config( ...@@ -541,8 +535,7 @@ std::future<int32_t> GraphBrpcClient::load_graph_split_config(
closure->request(server_index)->set_table_id(table_id); closure->request(server_index)->set_table_id(table_id);
closure->request(server_index)->set_client_id(_client_id); closure->request(server_index)->set_client_id(_client_id);
closure->request(server_index)->add_params(path); closure->request(server_index)->add_params(path);
GraphPsService_Stub rpc_stub = GraphPsService_Stub rpc_stub = getServiceStub(GetCmdChannel(server_index));
getServiceStub(get_cmd_channel(server_index));
closure->cntl(server_index)->set_log_id(butil::gettimeofday_ms()); closure->cntl(server_index)->set_log_id(butil::gettimeofday_ms());
rpc_stub.service(closure->cntl(server_index), rpc_stub.service(closure->cntl(server_index),
closure->request(server_index), closure->request(server_index),
...@@ -581,8 +574,7 @@ std::future<int32_t> GraphBrpcClient::use_neighbors_sample_cache( ...@@ -581,8 +574,7 @@ std::future<int32_t> GraphBrpcClient::use_neighbors_sample_cache(
closure->request(server_index) closure->request(server_index)
->add_params((char *)&size_limit, sizeof(size_t)); ->add_params((char *)&size_limit, sizeof(size_t));
closure->request(server_index)->add_params((char *)&ttl, sizeof(size_t)); closure->request(server_index)->add_params((char *)&ttl, sizeof(size_t));
GraphPsService_Stub rpc_stub = GraphPsService_Stub rpc_stub = getServiceStub(GetCmdChannel(server_index));
getServiceStub(get_cmd_channel(server_index));
closure->cntl(server_index)->set_log_id(butil::gettimeofday_ms()); closure->cntl(server_index)->set_log_id(butil::gettimeofday_ms());
rpc_stub.service(closure->cntl(server_index), rpc_stub.service(closure->cntl(server_index),
closure->request(server_index), closure->request(server_index),
...@@ -624,8 +616,8 @@ std::future<int32_t> GraphBrpcClient::pull_graph_list( ...@@ -624,8 +616,8 @@ std::future<int32_t> GraphBrpcClient::pull_graph_list(
closure->request(0)->add_params((char *)&start, sizeof(int)); closure->request(0)->add_params((char *)&start, sizeof(int));
closure->request(0)->add_params((char *)&size, sizeof(int)); closure->request(0)->add_params((char *)&size, sizeof(int));
closure->request(0)->add_params((char *)&step, sizeof(int)); closure->request(0)->add_params((char *)&step, sizeof(int));
// PsService_Stub rpc_stub(get_cmd_channel(server_index)); // PsService_Stub rpc_stub(GetCmdChannel(server_index));
GraphPsService_Stub rpc_stub = getServiceStub(get_cmd_channel(server_index)); GraphPsService_Stub rpc_stub = getServiceStub(GetCmdChannel(server_index));
closure->cntl(0)->set_log_id(butil::gettimeofday_ms()); closure->cntl(0)->set_log_id(butil::gettimeofday_ms());
rpc_stub.service(closure->cntl(0), closure->request(0), closure->response(0), rpc_stub.service(closure->cntl(0), closure->request(0), closure->response(0),
closure); closure);
...@@ -717,8 +709,7 @@ std::future<int32_t> GraphBrpcClient::set_node_feat( ...@@ -717,8 +709,7 @@ std::future<int32_t> GraphBrpcClient::set_node_feat(
closure->request(request_idx) closure->request(request_idx)
->add_params(set_feature.c_str(), set_feature.size()); ->add_params(set_feature.c_str(), set_feature.size());
GraphPsService_Stub rpc_stub = GraphPsService_Stub rpc_stub = getServiceStub(GetCmdChannel(server_index));
getServiceStub(get_cmd_channel(server_index));
closure->cntl(request_idx)->set_log_id(butil::gettimeofday_ms()); closure->cntl(request_idx)->set_log_id(butil::gettimeofday_ms());
rpc_stub.service(closure->cntl(request_idx), closure->request(request_idx), rpc_stub.service(closure->cntl(request_idx), closure->request(request_idx),
closure->response(request_idx), closure); closure->response(request_idx), closure);
...@@ -727,10 +718,10 @@ std::future<int32_t> GraphBrpcClient::set_node_feat( ...@@ -727,10 +718,10 @@ std::future<int32_t> GraphBrpcClient::set_node_feat(
return fut; return fut;
} }
int32_t GraphBrpcClient::initialize() { int32_t GraphBrpcClient::Initialize() {
// set_shard_num(_config.shard_num()); // set_shard_num(_config.shard_num());
BrpcPsClient::initialize(); BrpcPsClient::Initialize();
server_size = get_server_nums(); server_size = GetServerNums();
graph_service = NULL; graph_service = NULL;
local_channel = NULL; local_channel = NULL;
return 0; return 0;
......
...@@ -97,12 +97,12 @@ class GraphBrpcClient : public BrpcPsClient { ...@@ -97,12 +97,12 @@ class GraphBrpcClient : public BrpcPsClient {
std::string path); std::string path);
virtual std::future<int32_t> remove_graph_node( virtual std::future<int32_t> remove_graph_node(
uint32_t table_id, std::vector<int64_t>& node_id_list); uint32_t table_id, std::vector<int64_t>& node_id_list);
virtual int32_t initialize(); virtual int32_t Initialize();
int get_shard_num() { return shard_num; } int get_shard_num() { return shard_num; }
void set_shard_num(int shard_num) { this->shard_num = shard_num; } void set_shard_num(int shard_num) { this->shard_num = shard_num; }
int get_server_index_by_id(int64_t id); int get_server_index_by_id(int64_t id);
void set_local_channel(int index) { void set_local_channel(int index) {
this->local_channel = get_cmd_channel(index); this->local_channel = GetCmdChannel(index);
} }
void set_local_graph_service(GraphBrpcService* graph_service) { void set_local_graph_service(GraphBrpcService* graph_service) {
this->graph_service = graph_service; this->graph_service = graph_service;
......
...@@ -33,7 +33,7 @@ namespace distributed { ...@@ -33,7 +33,7 @@ namespace distributed {
return -1; \ return -1; \
} }
int32_t GraphBrpcServer::initialize() { int32_t GraphBrpcServer::Initialize() {
auto &service_config = _config.downpour_server_param().service_param(); auto &service_config = _config.downpour_server_param().service_param();
if (!service_config.has_service_class()) { if (!service_config.has_service_class()) {
LOG(ERROR) << "miss service_class in ServerServiceParameter"; LOG(ERROR) << "miss service_class in ServerServiceParameter";
...@@ -48,7 +48,7 @@ int32_t GraphBrpcServer::initialize() { ...@@ -48,7 +48,7 @@ int32_t GraphBrpcServer::initialize() {
} }
_service.reset(service); _service.reset(service);
if (service->configure(this) != 0 || service->initialize() != 0) { if (service->Configure(this) != 0 || service->Initialize() != 0) {
LOG(ERROR) << "service initialize failed, service_name:" LOG(ERROR) << "service initialize failed, service_name:"
<< service_config.service_class(); << service_config.service_class();
return -1; return -1;
...@@ -61,11 +61,11 @@ int32_t GraphBrpcServer::initialize() { ...@@ -61,11 +61,11 @@ int32_t GraphBrpcServer::initialize() {
return 0; return 0;
} }
brpc::Channel *GraphBrpcServer::get_cmd_channel(size_t server_index) { brpc::Channel *GraphBrpcServer::GetCmdChannel(size_t server_index) {
return _pserver_channels[server_index].get(); return _pserver_channels[server_index].get();
} }
uint64_t GraphBrpcServer::start(const std::string &ip, uint32_t port) { uint64_t GraphBrpcServer::Start(const std::string &ip, uint32_t port) {
std::unique_lock<std::mutex> lock(mutex_); std::unique_lock<std::mutex> lock(mutex_);
std::string ip_port = ip + ":" + std::to_string(port); std::string ip_port = ip + ":" + std::to_string(port);
...@@ -73,20 +73,20 @@ uint64_t GraphBrpcServer::start(const std::string &ip, uint32_t port) { ...@@ -73,20 +73,20 @@ uint64_t GraphBrpcServer::start(const std::string &ip, uint32_t port) {
brpc::ServerOptions options; brpc::ServerOptions options;
int num_threads = std::thread::hardware_concurrency(); int num_threads = std::thread::hardware_concurrency();
auto trainers = _environment->get_trainers(); auto trainers = _environment->GetTrainers();
options.num_threads = trainers > num_threads ? trainers : num_threads; options.num_threads = trainers > num_threads ? trainers : num_threads;
if (_server.Start(ip_port.c_str(), &options) != 0) { if (_server.Start(ip_port.c_str(), &options) != 0) {
LOG(ERROR) << "GraphBrpcServer start failed, ip_port=" << ip_port; LOG(ERROR) << "GraphBrpcServer start failed, ip_port=" << ip_port;
return 0; return 0;
} }
_environment->registe_ps_server(ip, port, _rank); _environment->RegistePsServer(ip, port, _rank);
return 0; return 0;
} }
int32_t GraphBrpcServer::build_peer2peer_connection(int rank) { int32_t GraphBrpcServer::build_peer2peer_connection(int rank) {
this->rank = rank; this->rank = rank;
auto _env = environment(); auto _env = Environment();
brpc::ChannelOptions options; brpc::ChannelOptions options;
options.protocol = "baidu_std"; options.protocol = "baidu_std";
options.timeout_ms = 500000; options.timeout_ms = 500000;
...@@ -94,7 +94,7 @@ int32_t GraphBrpcServer::build_peer2peer_connection(int rank) { ...@@ -94,7 +94,7 @@ int32_t GraphBrpcServer::build_peer2peer_connection(int rank) {
options.connect_timeout_ms = 10000; options.connect_timeout_ms = 10000;
options.max_retry = 3; options.max_retry = 3;
std::vector<PSHost> server_list = _env->get_ps_servers(); std::vector<PSHost> server_list = _env->GetPsServers();
_pserver_channels.resize(server_list.size()); _pserver_channels.resize(server_list.size());
std::ostringstream os; std::ostringstream os;
std::string server_ip_port; std::string server_ip_port;
...@@ -172,19 +172,18 @@ int32_t GraphBrpcService::remove_graph_node(Table *table, ...@@ -172,19 +172,18 @@ int32_t GraphBrpcService::remove_graph_node(Table *table,
((GraphTable *)table)->remove_graph_node(node_ids); ((GraphTable *)table)->remove_graph_node(node_ids);
return 0; return 0;
} }
int32_t GraphBrpcServer::port() { return _server.listen_address().port; } int32_t GraphBrpcServer::Port() { return _server.listen_address().port; }
int32_t GraphBrpcService::initialize() { int32_t GraphBrpcService::Initialize() {
_is_initialize_shard_info = false; _is_initialize_shard_info = false;
_service_handler_map[PS_STOP_SERVER] = &GraphBrpcService::stop_server; _service_handler_map[PS_STOP_SERVER] = &GraphBrpcService::StopServer;
_service_handler_map[PS_LOAD_ONE_TABLE] = &GraphBrpcService::load_one_table; _service_handler_map[PS_LOAD_ONE_TABLE] = &GraphBrpcService::LoadOneTable;
_service_handler_map[PS_LOAD_ALL_TABLE] = &GraphBrpcService::load_all_table; _service_handler_map[PS_LOAD_ALL_TABLE] = &GraphBrpcService::LoadAllTable;
_service_handler_map[PS_PRINT_TABLE_STAT] = _service_handler_map[PS_PRINT_TABLE_STAT] = &GraphBrpcService::PrintTableStat;
&GraphBrpcService::print_table_stat; _service_handler_map[PS_BARRIER] = &GraphBrpcService::Barrier;
_service_handler_map[PS_BARRIER] = &GraphBrpcService::barrier; _service_handler_map[PS_START_PROFILER] = &GraphBrpcService::StartProfiler;
_service_handler_map[PS_START_PROFILER] = &GraphBrpcService::start_profiler; _service_handler_map[PS_STOP_PROFILER] = &GraphBrpcService::StopProfiler;
_service_handler_map[PS_STOP_PROFILER] = &GraphBrpcService::stop_profiler;
_service_handler_map[PS_PULL_GRAPH_LIST] = &GraphBrpcService::pull_graph_list; _service_handler_map[PS_PULL_GRAPH_LIST] = &GraphBrpcService::pull_graph_list;
_service_handler_map[PS_GRAPH_SAMPLE_NEIGHBORS] = _service_handler_map[PS_GRAPH_SAMPLE_NEIGHBORS] =
...@@ -207,21 +206,21 @@ int32_t GraphBrpcService::initialize() { ...@@ -207,21 +206,21 @@ int32_t GraphBrpcService::initialize() {
_service_handler_map[PS_GRAPH_LOAD_GRAPH_SPLIT_CONFIG] = _service_handler_map[PS_GRAPH_LOAD_GRAPH_SPLIT_CONFIG] =
&GraphBrpcService::load_graph_split_config; &GraphBrpcService::load_graph_split_config;
// shard初始化,server启动后才可从env获取到server_list的shard信息 // shard初始化,server启动后才可从env获取到server_list的shard信息
initialize_shard_info(); InitializeShardInfo();
return 0; return 0;
} }
int32_t GraphBrpcService::initialize_shard_info() { int32_t GraphBrpcService::InitializeShardInfo() {
if (!_is_initialize_shard_info) { if (!_is_initialize_shard_info) {
std::lock_guard<std::mutex> guard(_initialize_shard_mutex); std::lock_guard<std::mutex> guard(_initialize_shard_mutex);
if (_is_initialize_shard_info) { if (_is_initialize_shard_info) {
return 0; return 0;
} }
server_size = _server->environment()->get_ps_servers().size(); server_size = _server->Environment()->GetPsServers().size();
auto &table_map = *(_server->table()); auto &table_map = *(_server->GetTable());
for (auto itr : table_map) { for (auto itr : table_map) {
itr.second->set_shard(_rank, server_size); itr.second->SetShard(_rank, server_size);
} }
_is_initialize_shard_info = true; _is_initialize_shard_info = true;
} }
...@@ -241,7 +240,7 @@ void GraphBrpcService::service(google::protobuf::RpcController *cntl_base, ...@@ -241,7 +240,7 @@ void GraphBrpcService::service(google::protobuf::RpcController *cntl_base,
response->set_err_code(0); response->set_err_code(0);
response->set_err_msg(""); response->set_err_msg("");
auto *table = _server->table(request->table_id()); auto *table = _server->GetTable(request->table_id());
brpc::Controller *cntl = static_cast<brpc::Controller *>(cntl_base); brpc::Controller *cntl = static_cast<brpc::Controller *>(cntl_base);
auto itr = _service_handler_map.find(request->cmd_id()); auto itr = _service_handler_map.find(request->cmd_id());
if (itr == _service_handler_map.end()) { if (itr == _service_handler_map.end()) {
...@@ -261,7 +260,7 @@ void GraphBrpcService::service(google::protobuf::RpcController *cntl_base, ...@@ -261,7 +260,7 @@ void GraphBrpcService::service(google::protobuf::RpcController *cntl_base,
} }
} }
int32_t GraphBrpcService::barrier(Table *table, const PsRequestMessage &request, int32_t GraphBrpcService::Barrier(Table *table, const PsRequestMessage &request,
PsResponseMessage &response, PsResponseMessage &response,
brpc::Controller *cntl) { brpc::Controller *cntl) {
CHECK_TABLE_EXIST(table, request, response) CHECK_TABLE_EXIST(table, request, response)
...@@ -275,16 +274,16 @@ int32_t GraphBrpcService::barrier(Table *table, const PsRequestMessage &request, ...@@ -275,16 +274,16 @@ int32_t GraphBrpcService::barrier(Table *table, const PsRequestMessage &request,
auto trainer_id = request.client_id(); auto trainer_id = request.client_id();
auto barrier_type = request.params(0); auto barrier_type = request.params(0);
table->barrier(trainer_id, barrier_type); table->Barrier(trainer_id, barrier_type);
return 0; return 0;
} }
int32_t GraphBrpcService::print_table_stat(Table *table, int32_t GraphBrpcService::PrintTableStat(Table *table,
const PsRequestMessage &request, const PsRequestMessage &request,
PsResponseMessage &response, PsResponseMessage &response,
brpc::Controller *cntl) { brpc::Controller *cntl) {
CHECK_TABLE_EXIST(table, request, response) CHECK_TABLE_EXIST(table, request, response)
std::pair<int64_t, int64_t> ret = table->print_table_stat(); std::pair<int64_t, int64_t> ret = table->PrintTableStat();
paddle::framework::BinaryArchive ar; paddle::framework::BinaryArchive ar;
ar << ret.first << ret.second; ar << ret.first << ret.second;
std::string table_info(ar.Buffer(), ar.Length()); std::string table_info(ar.Buffer(), ar.Length());
...@@ -293,10 +292,10 @@ int32_t GraphBrpcService::print_table_stat(Table *table, ...@@ -293,10 +292,10 @@ int32_t GraphBrpcService::print_table_stat(Table *table,
return 0; return 0;
} }
int32_t GraphBrpcService::load_one_table(Table *table, int32_t GraphBrpcService::LoadOneTable(Table *table,
const PsRequestMessage &request, const PsRequestMessage &request,
PsResponseMessage &response, PsResponseMessage &response,
brpc::Controller *cntl) { brpc::Controller *cntl) {
CHECK_TABLE_EXIST(table, request, response) CHECK_TABLE_EXIST(table, request, response)
if (request.params_size() < 2) { if (request.params_size() < 2) {
set_response_code( set_response_code(
...@@ -304,20 +303,20 @@ int32_t GraphBrpcService::load_one_table(Table *table, ...@@ -304,20 +303,20 @@ int32_t GraphBrpcService::load_one_table(Table *table,
"PsRequestMessage.datas is requeired at least 2 for path & load_param"); "PsRequestMessage.datas is requeired at least 2 for path & load_param");
return -1; return -1;
} }
if (table->load(request.params(0), request.params(1)) != 0) { if (table->Load(request.params(0), request.params(1)) != 0) {
set_response_code(response, -1, "table load failed"); set_response_code(response, -1, "table load failed");
return -1; return -1;
} }
return 0; return 0;
} }
int32_t GraphBrpcService::load_all_table(Table *table, int32_t GraphBrpcService::LoadAllTable(Table *table,
const PsRequestMessage &request, const PsRequestMessage &request,
PsResponseMessage &response, PsResponseMessage &response,
brpc::Controller *cntl) { brpc::Controller *cntl) {
auto &table_map = *(_server->table()); auto &table_map = *(_server->GetTable());
for (auto &itr : table_map) { for (auto &itr : table_map) {
if (load_one_table(itr.second.get(), request, response, cntl) != 0) { if (LoadOneTable(itr.second.get(), request, response, cntl) != 0) {
LOG(ERROR) << "load table[" << itr.first << "] failed"; LOG(ERROR) << "load table[" << itr.first << "] failed";
return -1; return -1;
} }
...@@ -325,13 +324,13 @@ int32_t GraphBrpcService::load_all_table(Table *table, ...@@ -325,13 +324,13 @@ int32_t GraphBrpcService::load_all_table(Table *table,
return 0; return 0;
} }
int32_t GraphBrpcService::stop_server(Table *table, int32_t GraphBrpcService::StopServer(Table *table,
const PsRequestMessage &request, const PsRequestMessage &request,
PsResponseMessage &response, PsResponseMessage &response,
brpc::Controller *cntl) { brpc::Controller *cntl) {
GraphBrpcServer *p_server = (GraphBrpcServer *)_server; GraphBrpcServer *p_server = (GraphBrpcServer *)_server;
std::thread t_stop([p_server]() { std::thread t_stop([p_server]() {
p_server->stop(); p_server->Stop();
LOG(INFO) << "Server Stoped"; LOG(INFO) << "Server Stoped";
}); });
p_server->export_cv()->notify_all(); p_server->export_cv()->notify_all();
...@@ -339,19 +338,19 @@ int32_t GraphBrpcService::stop_server(Table *table, ...@@ -339,19 +338,19 @@ int32_t GraphBrpcService::stop_server(Table *table,
return 0; return 0;
} }
int32_t GraphBrpcService::stop_profiler(Table *table, int32_t GraphBrpcService::StopProfiler(Table *table,
const PsRequestMessage &request, const PsRequestMessage &request,
PsResponseMessage &response, PsResponseMessage &response,
brpc::Controller *cntl) { brpc::Controller *cntl) {
platform::DisableProfiler(platform::EventSortingKey::kDefault, platform::DisableProfiler(platform::EventSortingKey::kDefault,
string::Sprintf("server_%s_profile", _rank)); string::Sprintf("server_%s_profile", _rank));
return 0; return 0;
} }
int32_t GraphBrpcService::start_profiler(Table *table, int32_t GraphBrpcService::StartProfiler(Table *table,
const PsRequestMessage &request, const PsRequestMessage &request,
PsResponseMessage &response, PsResponseMessage &response,
brpc::Controller *cntl) { brpc::Controller *cntl) {
platform::EnableProfiler(platform::ProfilerState::kCPU); platform::EnableProfiler(platform::ProfilerState::kCPU);
return 0; return 0;
} }
...@@ -475,7 +474,7 @@ int32_t GraphBrpcService::sample_neighbors_across_multi_servers( ...@@ -475,7 +474,7 @@ int32_t GraphBrpcService::sample_neighbors_across_multi_servers(
std::vector<int> server2request(server_size, -1); std::vector<int> server2request(server_size, -1);
std::vector<int64_t> local_id; std::vector<int64_t> local_id;
std::vector<int> local_query_idx; std::vector<int> local_query_idx;
size_t rank = get_rank(); size_t rank = GetRank();
for (int query_idx = 0; query_idx < node_num; ++query_idx) { for (int query_idx = 0; query_idx < node_num; ++query_idx) {
int server_index = int server_index =
((GraphTable *)table)->get_server_index_by_id(node_data[query_idx]); ((GraphTable *)table)->get_server_index_by_id(node_data[query_idx]);
...@@ -589,9 +588,9 @@ int32_t GraphBrpcService::sample_neighbors_across_multi_servers( ...@@ -589,9 +588,9 @@ int32_t GraphBrpcService::sample_neighbors_across_multi_servers(
closure->request(request_idx) closure->request(request_idx)
->add_params((char *)&need_weight, sizeof(bool)); ->add_params((char *)&need_weight, sizeof(bool));
PsService_Stub rpc_stub( PsService_Stub rpc_stub(
((GraphBrpcServer *)get_server())->get_cmd_channel(server_index)); ((GraphBrpcServer *)GetServer())->GetCmdChannel(server_index));
// GraphPsService_Stub rpc_stub = // GraphPsService_Stub rpc_stub =
// getServiceStub(get_cmd_channel(server_index)); // getServiceStub(GetCmdChannel(server_index));
closure->cntl(request_idx)->set_log_id(butil::gettimeofday_ms()); closure->cntl(request_idx)->set_log_id(butil::gettimeofday_ms());
rpc_stub.service(closure->cntl(request_idx), closure->request(request_idx), rpc_stub.service(closure->cntl(request_idx), closure->request(request_idx),
closure->response(request_idx), closure); closure->response(request_idx), closure);
......
...@@ -31,10 +31,10 @@ class GraphBrpcServer : public PSServer { ...@@ -31,10 +31,10 @@ class GraphBrpcServer : public PSServer {
GraphBrpcServer() {} GraphBrpcServer() {}
virtual ~GraphBrpcServer() {} virtual ~GraphBrpcServer() {}
PsBaseService *get_service() { return _service.get(); } PsBaseService *get_service() { return _service.get(); }
virtual uint64_t start(const std::string &ip, uint32_t port); virtual uint64_t Start(const std::string &ip, uint32_t port);
virtual int32_t build_peer2peer_connection(int rank); virtual int32_t build_peer2peer_connection(int rank);
virtual brpc::Channel *get_cmd_channel(size_t server_index); virtual brpc::Channel *GetCmdChannel(size_t server_index);
virtual int32_t stop() { virtual int32_t Stop() {
std::unique_lock<std::mutex> lock(mutex_); std::unique_lock<std::mutex> lock(mutex_);
if (stoped_) return 0; if (stoped_) return 0;
stoped_ = true; stoped_ = true;
...@@ -43,12 +43,12 @@ class GraphBrpcServer : public PSServer { ...@@ -43,12 +43,12 @@ class GraphBrpcServer : public PSServer {
_server.Join(); _server.Join();
return 0; return 0;
} }
int32_t port(); int32_t Port();
std::condition_variable *export_cv() { return &cv_; } std::condition_variable *export_cv() { return &cv_; }
private: private:
virtual int32_t initialize(); virtual int32_t Initialize();
mutable std::mutex mutex_; mutable std::mutex mutex_;
std::condition_variable cv_; std::condition_variable cv_;
bool stoped_ = false; bool stoped_ = false;
...@@ -66,7 +66,7 @@ typedef int32_t (GraphBrpcService::*serviceFunc)( ...@@ -66,7 +66,7 @@ typedef int32_t (GraphBrpcService::*serviceFunc)(
class GraphBrpcService : public PsBaseService { class GraphBrpcService : public PsBaseService {
public: public:
virtual int32_t initialize() override; virtual int32_t Initialize() override;
virtual void service(::google::protobuf::RpcController *controller, virtual void service(::google::protobuf::RpcController *controller,
const PsRequestMessage *request, const PsRequestMessage *request,
...@@ -75,7 +75,7 @@ class GraphBrpcService : public PsBaseService { ...@@ -75,7 +75,7 @@ class GraphBrpcService : public PsBaseService {
protected: protected:
std::unordered_map<int32_t, serviceFunc> _service_handler_map; std::unordered_map<int32_t, serviceFunc> _service_handler_map;
int32_t initialize_shard_info(); int32_t InitializeShardInfo();
int32_t pull_graph_list(Table *table, const PsRequestMessage &request, int32_t pull_graph_list(Table *table, const PsRequestMessage &request,
PsResponseMessage &response, brpc::Controller *cntl); PsResponseMessage &response, brpc::Controller *cntl);
int32_t graph_random_sample_neighbors(Table *table, int32_t graph_random_sample_neighbors(Table *table,
...@@ -100,21 +100,21 @@ class GraphBrpcService : public PsBaseService { ...@@ -100,21 +100,21 @@ class GraphBrpcService : public PsBaseService {
int32_t remove_graph_node(Table *table, const PsRequestMessage &request, int32_t remove_graph_node(Table *table, const PsRequestMessage &request,
PsResponseMessage &response, PsResponseMessage &response,
brpc::Controller *cntl); brpc::Controller *cntl);
int32_t barrier(Table *table, const PsRequestMessage &request, int32_t Barrier(Table *table, const PsRequestMessage &request,
PsResponseMessage &response, brpc::Controller *cntl); PsResponseMessage &response, brpc::Controller *cntl);
int32_t load_one_table(Table *table, const PsRequestMessage &request, int32_t LoadOneTable(Table *table, const PsRequestMessage &request,
PsResponseMessage &response, brpc::Controller *cntl); PsResponseMessage &response, brpc::Controller *cntl);
int32_t load_all_table(Table *table, const PsRequestMessage &request, int32_t LoadAllTable(Table *table, const PsRequestMessage &request,
PsResponseMessage &response, brpc::Controller *cntl); PsResponseMessage &response, brpc::Controller *cntl);
int32_t stop_server(Table *table, const PsRequestMessage &request, int32_t StopServer(Table *table, const PsRequestMessage &request,
PsResponseMessage &response, brpc::Controller *cntl); PsResponseMessage &response, brpc::Controller *cntl);
int32_t start_profiler(Table *table, const PsRequestMessage &request, int32_t StartProfiler(Table *table, const PsRequestMessage &request,
PsResponseMessage &response, brpc::Controller *cntl);
int32_t stop_profiler(Table *table, const PsRequestMessage &request,
PsResponseMessage &response, brpc::Controller *cntl); PsResponseMessage &response, brpc::Controller *cntl);
int32_t StopProfiler(Table *table, const PsRequestMessage &request,
PsResponseMessage &response, brpc::Controller *cntl);
int32_t print_table_stat(Table *table, const PsRequestMessage &request, int32_t PrintTableStat(Table *table, const PsRequestMessage &request,
PsResponseMessage &response, brpc::Controller *cntl); PsResponseMessage &response, brpc::Controller *cntl);
int32_t sample_neighbors_across_multi_servers(Table *table, int32_t sample_neighbors_across_multi_servers(Table *table,
const PsRequestMessage &request, const PsRequestMessage &request,
......
...@@ -25,7 +25,7 @@ REGISTER_PSCORE_CLASS(PSClient, BrpcPsClient); ...@@ -25,7 +25,7 @@ REGISTER_PSCORE_CLASS(PSClient, BrpcPsClient);
REGISTER_PSCORE_CLASS(PSClient, PsLocalClient); REGISTER_PSCORE_CLASS(PSClient, PsLocalClient);
REGISTER_PSCORE_CLASS(PSClient, GraphBrpcClient); REGISTER_PSCORE_CLASS(PSClient, GraphBrpcClient);
int32_t PSClient::configure( int32_t PSClient::Configure(
const PSParameter &config, const PSParameter &config,
const std::map<uint64_t, std::vector<paddle::distributed::Region>> &regions, const std::map<uint64_t, std::vector<paddle::distributed::Region>> &regions,
PSEnvironment &env, size_t client_id) { PSEnvironment &env, size_t client_id) {
...@@ -51,10 +51,10 @@ int32_t PSClient::configure( ...@@ -51,10 +51,10 @@ int32_t PSClient::configure(
_table_accessors[work_param.downpour_table_param(i).table_id()].reset( _table_accessors[work_param.downpour_table_param(i).table_id()].reset(
accessor); accessor);
} }
return initialize(); return Initialize();
} }
PSClient *PSClientFactory::create(const PSParameter &ps_config) { PSClient *PSClientFactory::Create(const PSParameter &ps_config) {
const auto &config = ps_config.server_param(); const auto &config = ps_config.server_param();
if (!config.has_downpour_server_param()) { if (!config.has_downpour_server_param()) {
LOG(ERROR) << "miss downpour_server_param in ServerParameter"; LOG(ERROR) << "miss downpour_server_param in ServerParameter";
...@@ -81,7 +81,7 @@ PSClient *PSClientFactory::create(const PSParameter &ps_config) { ...@@ -81,7 +81,7 @@ PSClient *PSClientFactory::create(const PSParameter &ps_config) {
return NULL; return NULL;
} }
TableManager::instance().initialize(); TableManager::Instance().Initialize();
VLOG(3) << "Create PSClient[" << service_param.client_class() << "] success"; VLOG(3) << "Create PSClient[" << service_param.client_class() << "] success";
return client; return client;
} }
......
...@@ -26,7 +26,6 @@ ...@@ -26,7 +26,6 @@
#include "paddle/fluid/distributed/ps/service/sendrecv.pb.h" #include "paddle/fluid/distributed/ps/service/sendrecv.pb.h"
#include "paddle/fluid/distributed/ps/table/accessor.h" #include "paddle/fluid/distributed/ps/table/accessor.h"
#include "paddle/fluid/distributed/ps/table/graph/graph_node.h" #include "paddle/fluid/distributed/ps/table/graph/graph_node.h"
#include "paddle/fluid/distributed/ps/table/table.h"
#include "paddle/fluid/platform/timer.h" #include "paddle/fluid/platform/timer.h"
namespace paddle { namespace paddle {
...@@ -60,41 +59,6 @@ class PSClientClosure : public google::protobuf::Closure { ...@@ -60,41 +59,6 @@ class PSClientClosure : public google::protobuf::Closure {
std::vector<std::shared_ptr<std::promise<int32_t>>> _promises; std::vector<std::shared_ptr<std::promise<int32_t>>> _promises;
}; };
struct LoadSaveContext {
int table_id;
std::string epoch;
std::string mode;
};
enum TrainingMode { Async = 0, Sync = 1, Geo = 3 };
enum TrainingPhase { Init = 0, Train = 1, Save = 2 };
// enum ValueType {
// Sparse = 0,
// Dense = 1
// };
struct PushContext {
const uint64_t *keys;
const float **push_values;
const Region *push_dense_values;
};
struct RequestContext {
int table;
TrainingMode training_mode; // 1 for async, 2 for geo, 3 for sync
TrainingPhase training_phase; // 1 for init, 2 for train
ValueType value_type; // 1 for sparse, 2 for dense
uint64_t *keys;
float **sparse_values; // for sparse values
Region *dense_values; // for dense values
PushContext push_context;
size_t num;
bool is_training;
void *callback;
};
class PSClient { class PSClient {
public: public:
PSClient() {} PSClient() {}
...@@ -102,41 +66,37 @@ class PSClient { ...@@ -102,41 +66,37 @@ class PSClient {
PSClient(PSClient &&) = delete; PSClient(PSClient &&) = delete;
PSClient(const PSClient &) = delete; PSClient(const PSClient &) = delete;
virtual int32_t configure( // NOLINT virtual int32_t Configure( // NOLINT
const PSParameter &config, const PSParameter &config,
const std::map<uint64_t, std::vector<paddle::distributed::Region>> const std::map<uint64_t, std::vector<paddle::distributed::Region>>
&regions, &regions,
PSEnvironment &_env, size_t client_id) final; // NOLINT PSEnvironment &_env, size_t client_id) final; // NOLINT
virtual int32_t create_client2client_connection( virtual int32_t CreateClient2ClientConnection(int pserver_timeout_ms,
int pserver_timeout_ms, int pserver_connect_timeout_ms, int pserver_connect_timeout_ms,
int max_retry) = 0; int max_retry) = 0;
// 触发table数据退场 // 触发table数据退场
virtual std::future<int32_t> shrink(uint32_t table_id, virtual std::future<int32_t> Shrink(uint32_t table_id,
const std::string threshold) = 0; const std::string threshold) = 0;
// 全量table进行数据load // 全量table进行数据load
virtual std::future<int32_t> load(const std::string &epoch, virtual std::future<int32_t> Load(const std::string &epoch,
const std::string &mode) = 0; const std::string &mode) = 0;
// 指定table数据load // 指定table数据load
virtual std::future<int32_t> load(uint32_t table_id, const std::string &epoch, virtual std::future<int32_t> Load(uint32_t table_id, const std::string &epoch,
const std::string &mode) = 0; const std::string &mode) = 0;
// context配置load选项
virtual std::future<int32_t> Load(const LoadSaveContext &load_context) = 0;
// 全量table数据save value_accessor根据mode,可能有不同的save条件 // 全量table数据save value_accessor根据mode,可能有不同的save条件
virtual std::future<int32_t> save(const std::string &epoch, virtual std::future<int32_t> Save(const std::string &epoch,
const std::string &mode) = 0; const std::string &mode) = 0;
// 指定table数据save value_accessor根据mode,可能有不同的save条件 // 指定table数据save value_accessor根据mode,可能有不同的save条件
virtual std::future<int32_t> save(uint32_t table_id, const std::string &epoch, virtual std::future<int32_t> Save(uint32_t table_id, const std::string &epoch,
const std::string &mode) = 0; const std::string &mode) = 0;
virtual std::future<int32_t> Save(const LoadSaveContext &save_context) = 0;
// 清空table数据 // 清空table数据
virtual std::future<int32_t> clear() = 0; virtual std::future<int32_t> Clear() = 0;
virtual std::future<int32_t> clear(uint32_t table_id) = 0; virtual std::future<int32_t> Clear(uint32_t table_id) = 0;
// pull dense的参数部分,并分块填充到本地网络参数中 // pull dense的参数部分,并分块填充到本地网络参数中
// start和num用于拉取部分参数 // start和num用于拉取部分参数
...@@ -145,23 +105,19 @@ class PSClient { ...@@ -145,23 +105,19 @@ class PSClient {
// sender聚集同一区块的请求,累计多个填充buffer // sender聚集同一区块的请求,累计多个填充buffer
// server将参数区块中配置的某一维提取返回 // server将参数区块中配置的某一维提取返回
// 返回数据解包后填充到累计的多个buffer中 // 返回数据解包后填充到累计的多个buffer中
virtual std::future<int32_t> pull_dense(Region *regions, size_t region_num, virtual std::future<int32_t> PullDense(Region *regions, size_t region_num,
size_t table_id) = 0; // 保留 size_t table_id) = 0; // 保留
virtual std::future<int32_t> Push(RequestContext &push_context) = 0;
// firstly push dense param for parameter server // firstly push dense param for parameter server
// this is neccessary because dense weight initialized in trainer on cold // this is neccessary because dense weight initialized in trainer on cold
// start // start
virtual std::future<int32_t> push_dense_param(const Region *regions, virtual std::future<int32_t> PushDenseParam(const Region *regions,
size_t region_num, size_t region_num,
size_t table_id) = 0; size_t table_id) = 0;
virtual std::future<int32_t> push_dense(const Region *regions,
size_t region_num,
size_t table_id) = 0;
virtual std::future<int32_t> Pull(RequestContext &pull_context) = 0; virtual std::future<int32_t> PushDense(const Region *regions,
size_t region_num,
size_t table_id) = 0;
// 使用keys进行pull请求,结果填充values // 使用keys进行pull请求,结果填充values
// keys和values的个数均为num个,每个value占用select_size空间 // keys和values的个数均为num个,每个value占用select_size空间
...@@ -169,15 +125,14 @@ class PSClient { ...@@ -169,15 +125,14 @@ class PSClient {
// 整合多个线程请求的keys,聚集并分散发送到server // 整合多个线程请求的keys,聚集并分散发送到server
// 返回结果后,遍历buffer并对values赋值 // 返回结果后,遍历buffer并对values赋值
// is_training 用于区分请求是训练/预测,server端对于特征和准入会有不同的处理. // is_training 用于区分请求是训练/预测,server端对于特征和准入会有不同的处理.
virtual std::future<int32_t> pull_sparse(float **select_values, virtual std::future<int32_t> PullSparse(float **select_values,
size_t table_id, size_t table_id, const uint64_t *keys,
const uint64_t *keys, size_t num, size_t num, bool is_training) = 0;
bool is_training) = 0;
virtual std::future<int32_t> PullSparseParam(float **select_values,
virtual std::future<int32_t> pull_sparse_param(float **select_values, size_t table_id,
size_t table_id, const uint64_t *keys, size_t num,
const uint64_t *keys, bool is_training) {
size_t num, bool is_training) {
VLOG(0) << "Did not implement"; VLOG(0) << "Did not implement";
std::promise<int32_t> promise; std::promise<int32_t> promise;
std::future<int> fut = promise.get_future(); std::future<int> fut = promise.get_future();
...@@ -185,10 +140,10 @@ class PSClient { ...@@ -185,10 +140,10 @@ class PSClient {
return fut; return fut;
} }
virtual ::std::future<int32_t> pull_sparse_ptr(char **select_values, virtual ::std::future<int32_t> PullSparsePtr(char **select_values,
size_t table_id, size_t table_id,
const uint64_t *keys, const uint64_t *keys,
size_t num) { size_t num) {
VLOG(0) << "Did not implement"; VLOG(0) << "Did not implement";
std::promise<int32_t> promise; std::promise<int32_t> promise;
std::future<int> fut = promise.get_future(); std::future<int> fut = promise.get_future();
...@@ -196,38 +151,38 @@ class PSClient { ...@@ -196,38 +151,38 @@ class PSClient {
return fut; return fut;
} }
virtual std::future<int32_t> print_table_stat(uint32_t table_id) = 0; virtual std::future<int32_t> PrintTableStat(uint32_t table_id) = 0;
// 确保所有积攒中的请求都发起发送 // 确保所有积攒中的请求都发起发送
virtual std::future<int32_t> flush() = 0; virtual std::future<int32_t> Flush() = 0;
// server优雅退出 // server优雅退出
virtual std::future<int32_t> stop_server() = 0; virtual std::future<int32_t> StopServer() = 0;
// server profilera // server profilera
virtual std::future<int32_t> start_profiler() = 0; virtual std::future<int32_t> StartProfiler() = 0;
virtual std::future<int32_t> stop_profiler() = 0; virtual std::future<int32_t> StopProfiler() = 0;
virtual std::future<int32_t> barrier(size_t table_id, virtual std::future<int32_t> Barrier(size_t table_id,
uint32_t barrier_type) = 0; uint32_t barrier_type) = 0;
virtual std::future<int32_t> pull_geo_param(size_t table_id, virtual std::future<int32_t> PullGeoParam(size_t table_id,
std::vector<float> *values, std::vector<float> *values,
std::vector<uint64_t> *keys, std::vector<uint64_t> *keys,
int pserver_idx) = 0; int pserver_idx) = 0;
virtual std::future<int32_t> push_global_step(int table_id, virtual std::future<int32_t> PushGlobalStep(int table_id,
int64_t *total_send_data, int64_t *total_send_data,
void *done) = 0; void *done) = 0;
// recv table from server and save it in LodTensor // recv table from server and save it in LodTensor
virtual int32_t recv_and_save_table(const uint64_t table_id, virtual int32_t RecvAndSaveTable(const uint64_t table_id,
const std::string &path) = 0; const std::string &path) = 0;
virtual void finalize_worker() = 0; virtual void FinalizeWorker() = 0;
// client to client, 消息发送 // client to client, 消息发送
virtual std::future<int32_t> send_client2client_msg(int msg_type, virtual std::future<int32_t> SendClient2ClientMsg(int msg_type,
int to_client_id, int to_client_id,
const std::string &msg) { const std::string &msg) {
VLOG(0) << "Did not implement"; VLOG(0) << "Did not implement";
std::promise<int32_t> promise; std::promise<int32_t> promise;
std::future<int> fut = promise.get_future(); std::future<int> fut = promise.get_future();
...@@ -238,13 +193,13 @@ class PSClient { ...@@ -238,13 +193,13 @@ class PSClient {
// client2client消息处理,std::function<int32_t (int, int, const std::string&) // client2client消息处理,std::function<int32_t (int, int, const std::string&)
// -> ret (msg_type, from_client_id, msg) // -> ret (msg_type, from_client_id, msg)
typedef std::function<int32_t(int, int, const std::string &)> MsgHandlerFunc; typedef std::function<int32_t(int, int, const std::string &)> MsgHandlerFunc;
virtual int registe_client2client_msg_handler(int msg_type, virtual int RegisteClient2ClientMsgHandler(int msg_type,
MsgHandlerFunc handler) { MsgHandlerFunc handler) {
_msg_handler_map[msg_type] = handler; _msg_handler_map[msg_type] = handler;
return 0; return 0;
} }
virtual int handle_client2client_msg(int msg_type, int from_client_id, virtual int HandleClient2ClientMsg(int msg_type, int from_client_id,
const std::string &msg) { const std::string &msg) {
auto itr = _msg_handler_map.find(msg_type); auto itr = _msg_handler_map.find(msg_type);
if (itr == _msg_handler_map.end()) { if (itr == _msg_handler_map.end()) {
LOG(WARNING) << "unknown client2client_msg type:" << msg_type; LOG(WARNING) << "unknown client2client_msg type:" << msg_type;
...@@ -253,7 +208,7 @@ class PSClient { ...@@ -253,7 +208,7 @@ class PSClient {
return itr->second(msg_type, from_client_id, msg); return itr->second(msg_type, from_client_id, msg);
} }
virtual ValueAccessor *table_accessor(size_t table_id) { virtual ValueAccessor *GetTableAccessor(size_t table_id) {
auto itr = _table_accessors.find(table_id); auto itr = _table_accessors.find(table_id);
if (itr == _table_accessors.end()) { if (itr == _table_accessors.end()) {
return NULL; return NULL;
...@@ -261,31 +216,31 @@ class PSClient { ...@@ -261,31 +216,31 @@ class PSClient {
return itr->second.get(); return itr->second.get();
} }
virtual size_t get_server_nums() = 0; virtual size_t GetServerNums() = 0;
virtual std::future<int32_t> push_dense_raw_gradient( virtual std::future<int32_t> PushDenseRawGradient(int table_id,
int table_id, float *total_send_data, size_t total_send_data_size, float *total_send_data,
void *done) = 0; size_t total_send_data_size,
void *done) = 0;
virtual std::future<int32_t> push_sparse_raw_gradient( virtual std::future<int32_t> PushSparseRawGradient(
size_t table_id, const uint64_t *keys, const float **update_values, size_t table_id, const uint64_t *keys, const float **update_values,
size_t num, void *done) = 0; size_t num, void *done) = 0;
virtual std::future<int32_t> push_sparse_raw_gradient_partial( virtual std::future<int32_t> PushSparseRawGradientPartial(
size_t table_id, const uint64_t *keys, const float **update_values, size_t table_id, const uint64_t *keys, const float **update_values,
uint32_t num, void *done, int pserver_idx) = 0; uint32_t num, void *done, int pserver_idx) = 0;
virtual std::future<int32_t> push_sparse_param(size_t table_id, virtual std::future<int32_t> PushSparseParam(size_t table_id,
const uint64_t *keys, const uint64_t *keys,
const float **update_values, const float **update_values,
size_t num, void *done) = 0; size_t num, void *done) = 0;
virtual std::future<int32_t> push_sparse(size_t table_id, virtual std::future<int32_t> PushSparse(size_t table_id, const uint64_t *keys,
const uint64_t *keys, const float **update_values,
const float **update_values, size_t num) = 0;
size_t num) = 0;
protected: protected:
virtual int32_t initialize() = 0; virtual int32_t Initialize() = 0;
size_t _client_id; size_t _client_id;
PSParameter _config; PSParameter _config;
std::map<uint64_t, std::vector<paddle::distributed::Region>> std::map<uint64_t, std::vector<paddle::distributed::Region>>
...@@ -333,7 +288,7 @@ REGISTER_PSCORE_REGISTERER(PSClient); ...@@ -333,7 +288,7 @@ REGISTER_PSCORE_REGISTERER(PSClient);
class PSClientFactory { class PSClientFactory {
public: public:
static PSClient *create(const PSParameter &config); static PSClient *Create(const PSParameter &config);
}; };
} // namespace distributed } // namespace distributed
} // namespace paddle } // namespace paddle
...@@ -19,166 +19,91 @@ ...@@ -19,166 +19,91 @@
namespace paddle { namespace paddle {
namespace distributed { namespace distributed {
int32_t PsLocalClient::initialize() { int32_t PsLocalClient::Initialize() {
const auto& downpour_param = _config.server_param().downpour_server_param(); const auto& downpour_param = _config.server_param().downpour_server_param();
TableManager::instance().initialize(); TableManager::Instance().Initialize();
for (size_t i = 0; i < downpour_param.downpour_table_param_size(); ++i) { for (size_t i = 0; i < downpour_param.downpour_table_param_size(); ++i) {
auto* table = CREATE_PSCORE_CLASS( auto* table = CREATE_PSCORE_CLASS(
Table, downpour_param.downpour_table_param(i).table_class()); Table, downpour_param.downpour_table_param(i).table_class());
table->set_shard(0, 1); table->SetShard(0, 1);
table->initialize(downpour_param.downpour_table_param(i), table->Initialize(downpour_param.downpour_table_param(i),
_config.fs_client_param()); _config.fs_client_param());
_table_map[downpour_param.downpour_table_param(i).table_id()].reset(table); _table_map[downpour_param.downpour_table_param(i).table_id()].reset(table);
} }
return 0; return 0;
} }
::std::future<int32_t> PsLocalClient::shrink(uint32_t table_id, ::std::future<int32_t> PsLocalClient::Shrink(uint32_t table_id,
const std::string threshold) { const std::string threshold) {
// TODO // TODO
return done(); return done();
} }
::std::future<int32_t> PsLocalClient::load(const std::string& epoch, ::std::future<int32_t> PsLocalClient::Load(const std::string& epoch,
const std::string& mode) { const std::string& mode) {
// TODO // TODO
for (auto& it : _table_map) { for (auto& it : _table_map) {
load(it.first, epoch, mode); Load(it.first, epoch, mode);
} }
return done(); return done();
} }
::std::future<int32_t> PsLocalClient::load(uint32_t table_id, ::std::future<int32_t> PsLocalClient::Load(uint32_t table_id,
const std::string& epoch, const std::string& epoch,
const std::string& mode) { const std::string& mode) {
// TODO // TODO
auto* table_ptr = table(table_id); auto* table_ptr = GetTable(table_id);
table_ptr->load(epoch, mode); table_ptr->Load(epoch, mode);
return done(); return done();
} }
std::future<int32_t> PsLocalClient::Load(const LoadSaveContext& load_context) { ::std::future<int32_t> PsLocalClient::Save(const std::string& epoch,
if (load_context.table_id < 0) {
for (auto& it : _table_map) {
load(it.first, load_context.epoch, load_context.mode);
}
return done();
} else {
auto* table_ptr = table(load_context.table_id);
table_ptr->load(load_context.epoch, load_context.mode);
return done();
}
}
::std::future<int32_t> PsLocalClient::save(const std::string& epoch,
const std::string& mode) { const std::string& mode) {
// TODO // TODO
for (auto& it : _table_map) { for (auto& it : _table_map) {
save(it.first, epoch, mode); Save(it.first, epoch, mode);
} }
return done(); return done();
} }
::std::future<int32_t> PsLocalClient::save(uint32_t table_id, ::std::future<int32_t> PsLocalClient::Save(uint32_t table_id,
const std::string& epoch, const std::string& epoch,
const std::string& mode) { const std::string& mode) {
// TODO // TODO
auto* table_ptr = table(table_id); auto* table_ptr = GetTable(table_id);
table_ptr->flush(); table_ptr->Flush();
table_ptr->save(epoch, mode); table_ptr->Save(epoch, mode);
return done(); return done();
} }
::std::future<int32_t> PsLocalClient::Save( ::std::future<int32_t> PsLocalClient::Clear() {
const LoadSaveContext& save_context) {
if (save_context.table_id < 0) {
for (auto& it : _table_map) {
save(it.first, save_context.epoch, save_context.mode);
}
return done();
} else {
auto* table_ptr = table(save_context.table_id);
table_ptr->flush();
table_ptr->save(save_context.epoch, save_context.mode);
return done();
}
}
::std::future<int32_t> PsLocalClient::clear() {
// TODO // TODO
return done(); return done();
} }
::std::future<int32_t> PsLocalClient::clear(uint32_t table_id) { ::std::future<int32_t> PsLocalClient::Clear(uint32_t table_id) {
// TODO // TODO
return done(); return done();
} }
::std::future<int32_t> PsLocalClient::flush() { ::std::future<int32_t> PsLocalClient::Flush() {
// no need // no need
return done(); return done();
} }
::std::future<int32_t> PsLocalClient::stop_server() { ::std::future<int32_t> PsLocalClient::StopServer() {
// no need // no need
return done(); return done();
} }
::std::future<int32_t> PsLocalClient::Pull(RequestContext& pull_context) { ::std::future<int32_t> PsLocalClient::PullDense(Region* regions,
if (pull_context.value_type == Dense) { // pull dense size_t region_num,
Region* dense_region = reinterpret_cast<Region*>(pull_context.dense_values); size_t table_id) {
pull_dense(dense_region, pull_context.num, pull_context.table); auto* accessor = GetTableAccessor(table_id);
} else { // pull sparse auto* table_ptr = GetTable(table_id);
// uint64_t* keys = reinterpret_cast<uint64_t*>(pull_context.keys);
// char** select_values =
// reinterpret_cast<char**>(pull_context.sparse_values);
size_t table_id = pull_context.table;
size_t num = pull_context.num;
pull_sparse_ptr(reinterpret_cast<char**>(pull_context.sparse_values),
table_id, pull_context.keys, num);
}
}
::std::future<int32_t> PsLocalClient::Push(RequestContext& push_context) { uint32_t num_per_shard = DenseDimPerShard(accessor->GetTableInfo(FEA_DIM), 1);
if (push_context.value_type == Dense) { // push dense
if (push_context.training_phase == Init) {
const Region* regions = push_context.push_context.push_dense_values;
size_t region_num = push_context.num;
push_dense_param(regions, region_num, push_context.table);
} else {
if (push_context.training_mode == Geo) { // geo
float* total_send_data =
reinterpret_cast<float*>(push_context.dense_values);
size_t total_send_data_size = push_context.num;
push_dense_raw_gradient(push_context.table, total_send_data,
total_send_data_size, push_context.callback);
} else { // async and sync
const Region* regions = push_context.push_context.push_dense_values;
size_t region_num = push_context.num;
push_dense(regions, region_num, push_context.table);
}
}
} else { // push sparse
if (push_context.training_mode == Async) {
const uint64_t* keys = push_context.push_context.keys;
const float** update_values = push_context.push_context.push_values;
size_t table_id = push_context.table;
size_t num = push_context.num;
push_sparse(table_id, keys, update_values, num);
} else {
// TODO
}
}
}
::std::future<int32_t> PsLocalClient::pull_dense(Region* regions,
size_t region_num,
size_t table_id) {
auto* accessor = table_accessor(table_id);
auto* table_ptr = table(table_id);
uint32_t num_per_shard =
dense_dim_per_shard(accessor->GetTableInfo(FEA_DIM), 1);
std::vector<float> region_buffer; std::vector<float> region_buffer;
region_buffer.resize(num_per_shard); region_buffer.resize(num_per_shard);
table_ptr->pull_dense(region_buffer.data(), region_buffer.size()); table_ptr->PullDense(region_buffer.data(), region_buffer.size());
size_t region_idx = 0; size_t region_idx = 0;
size_t region_data_idx = 0; size_t region_data_idx = 0;
...@@ -213,48 +138,49 @@ std::future<int32_t> PsLocalClient::Load(const LoadSaveContext& load_context) { ...@@ -213,48 +138,49 @@ std::future<int32_t> PsLocalClient::Load(const LoadSaveContext& load_context) {
return done(); return done();
} }
::std::future<int32_t> PsLocalClient::push_dense_param(const Region* regions, ::std::future<int32_t> PsLocalClient::PushDenseParam(const Region* regions,
size_t region_num, size_t region_num,
size_t table_id) { size_t table_id) {
auto* accessor = table_accessor(table_id); auto* accessor = GetTableAccessor(table_id);
auto* table_ptr = table(table_id); auto* table_ptr = GetTable(table_id);
std::vector<float> region_buffer; std::vector<float> region_buffer;
region_buffer.resize(dense_dim_per_shard(accessor->GetTableInfo(FEA_DIM), 1), region_buffer.resize(DenseDimPerShard(accessor->GetTableInfo(FEA_DIM), 1), 0);
0);
for (size_t i = 0, offset = 0; i < region_num; ++i) { for (size_t i = 0, offset = 0; i < region_num; ++i) {
uint32_t data_num = regions[i].size / sizeof(float); uint32_t data_num = regions[i].size / sizeof(float);
memcpy(region_buffer.data() + offset, regions[i].data, regions[i].size); memcpy(region_buffer.data() + offset, regions[i].data, regions[i].size);
offset += data_num; offset += data_num;
} }
// table_ptr->push_dense_param(region_buffer.data(), region_buffer.size()); // table_ptr->PushDenseParam(region_buffer.data(), region_buffer.size());
return done(); return done();
} }
::std::future<int32_t> PsLocalClient::push_dense_raw_gradient( ::std::future<int32_t> PsLocalClient::PushDenseRawGradient(
int table_id, float* total_send_data, size_t total_send_data_size, int table_id, float* total_send_data, size_t total_send_data_size,
void* callback) { void* callback) {
VLOG(1) << "wxx push_dense_raw_gradient"; VLOG(1) << "wxx push_dense_raw_gradient";
PSClientClosure* closure = reinterpret_cast<PSClientClosure*>(callback); PSClientClosure* closure = reinterpret_cast<PSClientClosure*>(callback);
auto* table_ptr = table(table_id); auto* table_ptr = GetTable(table_id);
table_ptr->push_dense(total_send_data, total_send_data_size); table_ptr->PushDense(total_send_data, total_send_data_size);
delete closure; delete closure;
return done(); return done();
} }
::std::future<int32_t> PsLocalClient::push_dense(const Region* regions, ::std::future<int32_t> PsLocalClient::PushDense(const Region* regions,
size_t region_num, size_t region_num,
size_t table_id) { size_t table_id) {
auto* accessor = table_accessor(table_id); auto* accessor = GetTableAccessor(table_id);
auto* table_ptr = table(table_id); auto* table_ptr = GetTable(table_id);
std::vector<float> region_buffer; std::vector<float> region_buffer;
region_buffer.resize(dense_dim_per_shard(accessor->GetTableInfo(FEA_DIM), 1)); region_buffer.resize(DenseDimPerShard(accessor->GetTableInfo(FEA_DIM), 1));
size_t data_size = region_buffer.size(); size_t data_size = region_buffer.size();
for (size_t i = 0, offset = 0; i < region_num; ++i) { for (size_t i = 0, offset = 0; i < region_num; ++i) {
uint32_t data_num = regions[i].size / sizeof(float); uint32_t data_num = regions[i].size / sizeof(float);
...@@ -267,12 +193,12 @@ std::future<int32_t> PsLocalClient::Load(const LoadSaveContext& load_context) { ...@@ -267,12 +193,12 @@ std::future<int32_t> PsLocalClient::Load(const LoadSaveContext& load_context) {
offset += data_num; offset += data_num;
} }
table_ptr->push_dense(region_buffer.data(), region_buffer.size()); table_ptr->PushDense(region_buffer.data(), region_buffer.size());
return done(); return done();
} }
//::std::future<int32_t> PsLocalClient::pull_sparse(float** select_values, //::std::future<int32_t> PsLocalClient::PullSparse(float** select_values,
// size_t table_id, // size_t table_id,
// const uint64_t* keys, // const uint64_t* keys,
// size_t num) { // size_t num) {
...@@ -282,14 +208,14 @@ std::future<int32_t> PsLocalClient::Load(const LoadSaveContext& load_context) { ...@@ -282,14 +208,14 @@ std::future<int32_t> PsLocalClient::Load(const LoadSaveContext& load_context) {
// // auto local_timer = // // auto local_timer =
// // std::make_shared<CostTimer>("pslib_downpour_client_pull_sparse_local"); // // std::make_shared<CostTimer>("pslib_downpour_client_pull_sparse_local");
// //将key拆分到各shard请求,并记录原始对应value指针 // //将key拆分到各shard请求,并记录原始对应value指针
// auto* accessor = table_accessor(table_id); // auto* accessor = GetTableAccessor(table_id);
// auto* table_ptr = table(table_id); // auto* table_ptr = GetTable(table_id);
// size_t value_size = accessor->select_size(); // size_t value_size = accessor->select_size();
// //
// // table_ptr->pull_sparse(keys, num); // // table_ptr->PullSparse(keys, num);
// std::vector<float> res_data; // std::vector<float> res_data;
// res_data.resize(num * value_size / sizeof(float)); // res_data.resize(num * value_size / sizeof(float));
// table_ptr->pull_sparse(res_data.data(), keys, num); // table_ptr->PullSparse(res_data.data(), keys, num);
// // memcpy(select_values[0], res_data->data(), res_data->size() * // // memcpy(select_values[0], res_data->data(), res_data->size() *
// // sizeof(float)); // // sizeof(float));
// size_t offset = 0; // size_t offset = 0;
...@@ -302,43 +228,43 @@ std::future<int32_t> PsLocalClient::Load(const LoadSaveContext& load_context) { ...@@ -302,43 +228,43 @@ std::future<int32_t> PsLocalClient::Load(const LoadSaveContext& load_context) {
// return done(); // return done();
//} //}
::std::future<int32_t> PsLocalClient::pull_sparse_ptr(char** select_values, ::std::future<int32_t> PsLocalClient::PullSparsePtr(char** select_values,
size_t table_id, size_t table_id,
const uint64_t* keys, const uint64_t* keys,
size_t num) { size_t num) {
// FIXME // FIXME
// auto timer = // auto timer =
// std::make_shared<CostTimer>("pslib_downpour_client_pull_sparse"); // std::make_shared<CostTimer>("pslib_downpour_client_pull_sparse");
// auto local_timer = // auto local_timer =
// std::make_shared<CostTimer>("pslib_downpour_client_pull_sparse_local"); // std::make_shared<CostTimer>("pslib_downpour_client_pull_sparse_local");
//将key拆分到各shard请求,并记录原始对应value指针 //将key拆分到各shard请求,并记录原始对应value指针
auto* table_ptr = table(table_id); auto* table_ptr = GetTable(table_id);
table_ptr->pull_sparse_ptr(select_values, keys, num); table_ptr->PullSparsePtr(select_values, keys, num);
return done(); return done();
} }
::std::future<int32_t> PsLocalClient::push_sparse_raw_gradient( ::std::future<int32_t> PsLocalClient::PushSparseRawGradient(
size_t table_id, const uint64_t* keys, const float** update_values, size_t table_id, const uint64_t* keys, const float** update_values,
size_t num, void* callback) { size_t num, void* callback) {
PSClientClosure* closure = reinterpret_cast<PSClientClosure*>(callback); PSClientClosure* closure = reinterpret_cast<PSClientClosure*>(callback);
auto* accessor = table_accessor(table_id); auto* accessor = GetTableAccessor(table_id);
auto* table_ptr = table(table_id); auto* table_ptr = GetTable(table_id);
table_ptr->push_sparse(keys, update_values, num); table_ptr->PushSparse(keys, update_values, num);
delete closure; delete closure;
return done(); return done();
} }
::std::future<int32_t> PsLocalClient::push_sparse(size_t table_id, ::std::future<int32_t> PsLocalClient::PushSparse(size_t table_id,
const uint64_t* keys, const uint64_t* keys,
const float** update_values, const float** update_values,
size_t num) { size_t num) {
auto* accessor = table_accessor(table_id); auto* accessor = GetTableAccessor(table_id);
auto* table_ptr = table(table_id); auto* table_ptr = GetTable(table_id);
table_ptr->push_sparse(keys, update_values, num); table_ptr->PushSparse(keys, update_values, num);
return done(); return done();
} }
} }
......
...@@ -26,54 +26,46 @@ class PsLocalClient : public PSClient { ...@@ -26,54 +26,46 @@ class PsLocalClient : public PSClient {
public: public:
PsLocalClient() {} PsLocalClient() {}
virtual ~PsLocalClient() { _running = false; } virtual ~PsLocalClient() { _running = false; }
virtual int32_t create_client2client_connection(int pslib_timeout_ms, virtual int32_t CreateClient2ClientConnection(int pslib_timeout_ms,
int pslib_connect_timeout_ms, int pslib_connect_timeout_ms,
int max_retry) { int max_retry) {
return 0; return 0;
} }
virtual ::std::future<int32_t> shrink(uint32_t table_id, virtual ::std::future<int32_t> Shrink(uint32_t table_id,
const std::string threshold) override; const std::string threshold) override;
virtual ::std::future<int32_t> load(const std::string& epoch, virtual ::std::future<int32_t> Load(const std::string& epoch,
const std::string& mode) override; const std::string& mode) override;
virtual ::std::future<int32_t> load(uint32_t table_id, virtual ::std::future<int32_t> Load(uint32_t table_id,
const std::string& epoch, const std::string& epoch,
const std::string& mode) override; const std::string& mode) override;
virtual std::future<int32_t> Load(
const LoadSaveContext& load_context) override;
virtual ::std::future<int32_t> save(const std::string& epoch, virtual ::std::future<int32_t> Save(const std::string& epoch,
const std::string& mode) override; const std::string& mode) override;
virtual ::std::future<int32_t> save(uint32_t table_id, virtual ::std::future<int32_t> Save(uint32_t table_id,
const std::string& epoch, const std::string& epoch,
const std::string& mode) override; const std::string& mode) override;
virtual std::future<int32_t> Save(
const LoadSaveContext& save_context) override;
virtual ::std::future<int32_t> clear() override; virtual ::std::future<int32_t> Clear() override;
virtual ::std::future<int32_t> clear(uint32_t table_id) override; virtual ::std::future<int32_t> Clear(uint32_t table_id) override;
virtual ::std::future<int32_t> stop_server() override; virtual ::std::future<int32_t> StopServer() override;
virtual void finalize_worker() override {} virtual void FinalizeWorker() override {}
virtual ::std::future<int32_t> pull_dense(Region* regions, size_t region_num, virtual ::std::future<int32_t> PullDense(Region* regions, size_t region_num,
size_t table_id); size_t table_id);
virtual ::std::future<int32_t> Pull(RequestContext& pull_context) override; virtual ::std::future<int32_t> PushDense(const Region* regions,
size_t region_num, size_t table_id);
virtual ::std::future<int32_t> Push(RequestContext& push_context) override; virtual ::std::future<int32_t> PushDenseParam(const Region* regions,
size_t region_num,
size_t table_id);
virtual ::std::future<int32_t> push_dense(const Region* regions, virtual ::std::future<int32_t> PullSparse(float** select_values,
size_t region_num, size_t table_id); size_t table_id,
const uint64_t* keys, size_t num,
virtual ::std::future<int32_t> push_dense_param(const Region* regions, bool is_training) {
size_t region_num,
size_t table_id);
virtual ::std::future<int32_t> pull_sparse(float** select_values,
size_t table_id,
const uint64_t* keys, size_t num,
bool is_training) {
std::promise<int32_t> prom; std::promise<int32_t> prom;
std::future<int32_t> fut = prom.get_future(); std::future<int32_t> fut = prom.get_future();
prom.set_value(0); prom.set_value(0);
...@@ -81,26 +73,26 @@ class PsLocalClient : public PSClient { ...@@ -81,26 +73,26 @@ class PsLocalClient : public PSClient {
return fut; return fut;
} }
virtual ::std::future<int32_t> pull_sparse_ptr(char** select_values, virtual ::std::future<int32_t> PullSparsePtr(char** select_values,
size_t table_id, size_t table_id,
const uint64_t* keys, const uint64_t* keys,
size_t num); size_t num);
virtual ::std::future<int32_t> print_table_stat(uint32_t table_id) { virtual ::std::future<int32_t> PrintTableStat(uint32_t table_id) {
std::promise<int32_t> prom; std::promise<int32_t> prom;
std::future<int32_t> fut = prom.get_future(); std::future<int32_t> fut = prom.get_future();
prom.set_value(0); prom.set_value(0);
return fut; return fut;
} }
virtual ::std::future<int32_t> push_sparse(size_t table_id, virtual ::std::future<int32_t> PushSparse(size_t table_id,
const uint64_t* keys, const uint64_t* keys,
const float** update_values, const float** update_values,
size_t num); size_t num);
virtual ::std::future<int32_t> flush(); virtual ::std::future<int32_t> Flush();
// server profilera // server profilera
virtual std::future<int32_t> start_profiler() { virtual std::future<int32_t> StartProfiler() {
std::promise<int32_t> prom; std::promise<int32_t> prom;
std::future<int32_t> fut = prom.get_future(); std::future<int32_t> fut = prom.get_future();
prom.set_value(0); prom.set_value(0);
...@@ -108,7 +100,7 @@ class PsLocalClient : public PSClient { ...@@ -108,7 +100,7 @@ class PsLocalClient : public PSClient {
return fut; return fut;
}; };
virtual std::future<int32_t> stop_profiler() { virtual std::future<int32_t> StopProfiler() {
std::promise<int32_t> prom; std::promise<int32_t> prom;
std::future<int32_t> fut = prom.get_future(); std::future<int32_t> fut = prom.get_future();
prom.set_value(0); prom.set_value(0);
...@@ -116,7 +108,7 @@ class PsLocalClient : public PSClient { ...@@ -116,7 +108,7 @@ class PsLocalClient : public PSClient {
return fut; return fut;
} }
virtual std::future<int32_t> barrier(size_t table_id, uint32_t barrier_type) { virtual std::future<int32_t> Barrier(size_t table_id, uint32_t barrier_type) {
std::promise<int32_t> prom; std::promise<int32_t> prom;
std::future<int32_t> fut = prom.get_future(); std::future<int32_t> fut = prom.get_future();
prom.set_value(0); prom.set_value(0);
...@@ -124,10 +116,10 @@ class PsLocalClient : public PSClient { ...@@ -124,10 +116,10 @@ class PsLocalClient : public PSClient {
return fut; return fut;
} }
virtual std::future<int32_t> pull_geo_param(size_t table_id, virtual std::future<int32_t> PullGeoParam(size_t table_id,
std::vector<float>* values, std::vector<float>* values,
std::vector<uint64_t>* keys, std::vector<uint64_t>* keys,
int pserver_idx) { int pserver_idx) {
std::promise<int32_t> prom; std::promise<int32_t> prom;
std::future<int32_t> fut = prom.get_future(); std::future<int32_t> fut = prom.get_future();
prom.set_value(0); prom.set_value(0);
...@@ -135,9 +127,9 @@ class PsLocalClient : public PSClient { ...@@ -135,9 +127,9 @@ class PsLocalClient : public PSClient {
return fut; return fut;
} }
virtual std::future<int32_t> push_global_step(int table_id, virtual std::future<int32_t> PushGlobalStep(int table_id,
int64_t* total_send_data, int64_t* total_send_data,
void* done) { void* done) {
std::promise<int32_t> prom; std::promise<int32_t> prom;
std::future<int32_t> fut = prom.get_future(); std::future<int32_t> fut = prom.get_future();
prom.set_value(0); prom.set_value(0);
...@@ -146,12 +138,12 @@ class PsLocalClient : public PSClient { ...@@ -146,12 +138,12 @@ class PsLocalClient : public PSClient {
} }
// recv table from server and save it in LodTensor // recv table from server and save it in LodTensor
virtual int32_t recv_and_save_table(const uint64_t table_id, virtual int32_t RecvAndSaveTable(const uint64_t table_id,
const std::string& path) { const std::string& path) {
return 0; return 0;
} }
virtual ::std::future<int32_t> send_client2client_msg( virtual ::std::future<int32_t> SendClient2ClientMsg(
int msg_type, int to_client_id, const std::string& msg) override { int msg_type, int to_client_id, const std::string& msg) override {
std::promise<int32_t> prom; std::promise<int32_t> prom;
std::future<int32_t> fut = prom.get_future(); std::future<int32_t> fut = prom.get_future();
...@@ -159,17 +151,18 @@ class PsLocalClient : public PSClient { ...@@ -159,17 +151,18 @@ class PsLocalClient : public PSClient {
return fut; return fut;
} }
virtual size_t get_server_nums() { return 1; } virtual size_t GetServerNums() { return 1; }
virtual std::future<int32_t> push_dense_raw_gradient( virtual std::future<int32_t> PushDenseRawGradient(int table_id,
int table_id, float* total_send_data, size_t total_send_data_size, float* total_send_data,
void* callback) override; size_t total_send_data_size,
void* callback) override;
virtual std::future<int32_t> push_sparse_raw_gradient( virtual std::future<int32_t> PushSparseRawGradient(
size_t table_id, const uint64_t* keys, const float** update_values, size_t table_id, const uint64_t* keys, const float** update_values,
size_t num, void* callback) override; size_t num, void* callback) override;
virtual std::future<int32_t> push_sparse_raw_gradient_partial( virtual std::future<int32_t> PushSparseRawGradientPartial(
size_t table_id, const uint64_t* keys, const float** update_values, size_t table_id, const uint64_t* keys, const float** update_values,
uint32_t num, void* done, int pserver_idx) override { uint32_t num, void* done, int pserver_idx) override {
std::promise<int32_t> prom; std::promise<int32_t> prom;
...@@ -179,11 +172,11 @@ class PsLocalClient : public PSClient { ...@@ -179,11 +172,11 @@ class PsLocalClient : public PSClient {
return fut; return fut;
} }
virtual std::future<int32_t> push_sparse_param(size_t table_id, virtual std::future<int32_t> PushSparseParam(size_t table_id,
const uint64_t* keys, const uint64_t* keys,
const float** update_values, const float** update_values,
size_t num, size_t num,
void* done) override { void* done) override {
std::promise<int32_t> prom; std::promise<int32_t> prom;
std::future<int32_t> fut = prom.get_future(); std::future<int32_t> fut = prom.get_future();
prom.set_value(0); prom.set_value(0);
...@@ -192,7 +185,7 @@ class PsLocalClient : public PSClient { ...@@ -192,7 +185,7 @@ class PsLocalClient : public PSClient {
} }
private: private:
virtual int32_t initialize() override; virtual int32_t Initialize() override;
std::future<int32_t> done() { std::future<int32_t> done() {
std::shared_ptr<std::promise<int32_t>> prom = std::shared_ptr<std::promise<int32_t>> prom =
...@@ -202,16 +195,16 @@ class PsLocalClient : public PSClient { ...@@ -202,16 +195,16 @@ class PsLocalClient : public PSClient {
return fut; return fut;
} }
inline uint32_t dense_dim_per_shard(uint32_t dense_dim_total, inline uint32_t DenseDimPerShard(uint32_t dense_dim_total,
uint32_t shard_num) { uint32_t shard_num) {
return dense_dim_total / shard_num + 1; return dense_dim_total / shard_num + 1;
} }
inline std::unordered_map<uint32_t, std::shared_ptr<Table>>* table() { inline std::unordered_map<uint32_t, std::shared_ptr<Table>>* GetTable() {
return &_table_map; return &_table_map;
} }
inline Table* table(size_t table_id) { inline Table* GetTable(size_t table_id) {
auto itr = _table_map.find(table_id); auto itr = _table_map.find(table_id);
if (itr != _table_map.end()) { if (itr != _table_map.end()) {
return itr->second.get(); return itr->second.get();
......
...@@ -25,17 +25,17 @@ class PsLocalServer : public PSServer { ...@@ -25,17 +25,17 @@ class PsLocalServer : public PSServer {
public: public:
PsLocalServer() {} PsLocalServer() {}
virtual ~PsLocalServer() {} virtual ~PsLocalServer() {}
virtual uint64_t start() { return 0; } virtual uint64_t Start() { return 0; }
virtual uint64_t start(const std::string &ip, uint32_t port) { return 0; } virtual uint64_t Start(const std::string &ip, uint32_t port) { return 0; }
virtual int32_t stop() { return 0; } virtual int32_t Stop() { return 0; }
virtual int32_t configure( virtual int32_t Configure(
const PSParameter &config, PSEnvironment &env, size_t server_rank, const PSParameter &config, PSEnvironment &env, size_t server_rank,
const std::vector<framework::ProgramDesc> &server_sub_program = {}) { const std::vector<framework::ProgramDesc> &server_sub_program = {}) {
return 0; return 0;
} }
private: private:
virtual int32_t initialize() { return 0; } virtual int32_t Initialize() { return 0; }
}; };
} }
} }
...@@ -70,7 +70,7 @@ void GraphPyService::set_up(std::string ips_str, int shard_num, ...@@ -70,7 +70,7 @@ void GraphPyService::set_up(std::string ips_str, int shard_num,
port_list.push_back(ip_and_port[1]); port_list.push_back(ip_and_port[1]);
uint32_t port = stoul(ip_and_port[1]); uint32_t port = stoul(ip_and_port[1]);
auto ph_host = paddle::distributed::PSHost(ip_and_port[0], port, index); auto ph_host = paddle::distributed::PSHost(ip_and_port[0], port, index);
host_sign_list.push_back(ph_host.serialize_to_string()); host_sign_list.push_back(ph_host.SerializeToString());
index++; index++;
} }
} }
...@@ -83,11 +83,11 @@ void GraphPyClient::start_client() { ...@@ -83,11 +83,11 @@ void GraphPyClient::start_client() {
paddle::distributed::PaddlePSEnvironment _ps_env; paddle::distributed::PaddlePSEnvironment _ps_env;
auto servers_ = host_sign_list.size(); auto servers_ = host_sign_list.size();
_ps_env = paddle::distributed::PaddlePSEnvironment(); _ps_env = paddle::distributed::PaddlePSEnvironment();
_ps_env.set_ps_servers(&host_sign_list, servers_); _ps_env.SetPsServers(&host_sign_list, servers_);
worker_ptr = std::shared_ptr<paddle::distributed::GraphBrpcClient>( worker_ptr = std::shared_ptr<paddle::distributed::GraphBrpcClient>(
(paddle::distributed::GraphBrpcClient*) (paddle::distributed::GraphBrpcClient*)
paddle::distributed::PSClientFactory::create(worker_proto)); paddle::distributed::PSClientFactory::Create(worker_proto));
worker_ptr->configure(worker_proto, dense_regions, _ps_env, client_id); worker_ptr->Configure(worker_proto, dense_regions, _ps_env, client_id);
worker_ptr->set_shard_num(get_shard_num()); worker_ptr->set_shard_num(get_shard_num());
} }
void GraphPyServer::start_server(bool block) { void GraphPyServer::start_server(bool block) {
...@@ -96,17 +96,17 @@ void GraphPyServer::start_server(bool block) { ...@@ -96,17 +96,17 @@ void GraphPyServer::start_server(bool block) {
::paddle::distributed::PSParameter server_proto = this->GetServerProto(); ::paddle::distributed::PSParameter server_proto = this->GetServerProto();
auto _ps_env = paddle::distributed::PaddlePSEnvironment(); auto _ps_env = paddle::distributed::PaddlePSEnvironment();
_ps_env.set_ps_servers(&this->host_sign_list, _ps_env.SetPsServers(&this->host_sign_list,
this->host_sign_list.size()); // test this->host_sign_list.size()); // test
pserver_ptr = std::shared_ptr<paddle::distributed::GraphBrpcServer>( pserver_ptr = std::shared_ptr<paddle::distributed::GraphBrpcServer>(
(paddle::distributed::GraphBrpcServer*) (paddle::distributed::GraphBrpcServer*)
paddle::distributed::PSServerFactory::create(server_proto)); paddle::distributed::PSServerFactory::Create(server_proto));
VLOG(0) << "pserver-ptr created "; VLOG(0) << "pserver-ptr created ";
std::vector<framework::ProgramDesc> empty_vec; std::vector<framework::ProgramDesc> empty_vec;
framework::ProgramDesc empty_prog; framework::ProgramDesc empty_prog;
empty_vec.push_back(empty_prog); empty_vec.push_back(empty_prog);
pserver_ptr->configure(server_proto, _ps_env, rank, empty_vec); pserver_ptr->Configure(server_proto, _ps_env, rank, empty_vec);
pserver_ptr->start(ip, port); pserver_ptr->Start(ip, port);
pserver_ptr->build_peer2peer_connection(rank); pserver_ptr->build_peer2peer_connection(rank);
std::condition_variable* cv_ = pserver_ptr->export_cv(); std::condition_variable* cv_ = pserver_ptr->export_cv();
if (block) { if (block) {
...@@ -246,7 +246,7 @@ void GraphPyClient::load_edge_file(std::string name, std::string filepath, ...@@ -246,7 +246,7 @@ void GraphPyClient::load_edge_file(std::string name, std::string filepath,
VLOG(0) << "loadding data with type " << name << " from " << filepath; VLOG(0) << "loadding data with type " << name << " from " << filepath;
uint32_t table_id = this->table_id_map[name]; uint32_t table_id = this->table_id_map[name];
auto status = auto status =
get_ps_client()->load(table_id, std::string(filepath), params); get_ps_client()->Load(table_id, std::string(filepath), params);
status.wait(); status.wait();
} }
} }
...@@ -285,7 +285,7 @@ void GraphPyClient::load_node_file(std::string name, std::string filepath) { ...@@ -285,7 +285,7 @@ void GraphPyClient::load_node_file(std::string name, std::string filepath) {
if (this->table_id_map.count(name)) { if (this->table_id_map.count(name)) {
uint32_t table_id = this->table_id_map[name]; uint32_t table_id = this->table_id_map[name];
auto status = auto status =
get_ps_client()->load(table_id, std::string(filepath), params); get_ps_client()->Load(table_id, std::string(filepath), params);
status.wait(); status.wait();
} }
} }
...@@ -396,13 +396,13 @@ std::vector<FeatureNode> GraphPyClient::pull_graph_list(std::string name, ...@@ -396,13 +396,13 @@ std::vector<FeatureNode> GraphPyClient::pull_graph_list(std::string name,
return res; return res;
} }
void GraphPyClient::stop_server() { void GraphPyClient::StopServer() {
VLOG(0) << "going to stop server"; VLOG(0) << "going to stop server";
std::unique_lock<std::mutex> lock(mutex_); std::unique_lock<std::mutex> lock(mutex_);
if (stoped_) return; if (stoped_) return;
auto status = this->worker_ptr->stop_server(); auto status = this->worker_ptr->StopServer();
if (status.get() == 0) stoped_ = true; if (status.get() == 0) stoped_ = true;
} }
void GraphPyClient::finalize_worker() { this->worker_ptr->finalize_worker(); } void GraphPyClient::FinalizeWorker() { this->worker_ptr->FinalizeWorker(); }
} }
} }
...@@ -123,7 +123,7 @@ class GraphPyServer : public GraphPyService { ...@@ -123,7 +123,7 @@ class GraphPyServer : public GraphPyService {
set_rank(rank); set_rank(rank);
GraphPyService::set_up(ips_str, shard_num, node_types, edge_types); GraphPyService::set_up(ips_str, shard_num, node_types, edge_types);
} }
int get_rank() { return rank; } int GetRank() { return rank; }
void set_rank(int rank) { this->rank = rank; } void set_rank(int rank) { this->rank = rank; }
void start_server(bool block = true); void start_server(bool block = true);
...@@ -154,8 +154,8 @@ class GraphPyClient : public GraphPyService { ...@@ -154,8 +154,8 @@ class GraphPyClient : public GraphPyService {
(paddle::distributed::GraphBrpcService*)server.get_ps_server() (paddle::distributed::GraphBrpcService*)server.get_ps_server()
->get_service()); ->get_service());
} }
void stop_server(); void StopServer();
void finalize_worker(); void FinalizeWorker();
void load_edge_file(std::string name, std::string filepath, bool reverse); void load_edge_file(std::string name, std::string filepath, bool reverse);
void load_node_file(std::string name, std::string filepath); void load_node_file(std::string name, std::string filepath);
void clear_nodes(std::string name); void clear_nodes(std::string name);
......
...@@ -46,7 +46,7 @@ paddle::distributed::PSParameter load_from_prototxt( ...@@ -46,7 +46,7 @@ paddle::distributed::PSParameter load_from_prototxt(
return param; return param;
} }
void PSCore::init_gflag(const std::string& gflags) { void PSCore::InitGFlag(const std::string& gflags) {
VLOG(3) << "Init With Gflags:" << gflags; VLOG(3) << "Init With Gflags:" << gflags;
std::vector<std::string> flags = paddle::string::split_string(gflags); std::vector<std::string> flags = paddle::string::split_string(gflags);
if (flags.size() < 1) { if (flags.size() < 1) {
...@@ -65,67 +65,67 @@ void PSCore::init_gflag(const std::string& gflags) { ...@@ -65,67 +65,67 @@ void PSCore::init_gflag(const std::string& gflags) {
::GFLAGS_NAMESPACE::ParseCommandLineFlags(&params_cnt, &params_ptr, true); ::GFLAGS_NAMESPACE::ParseCommandLineFlags(&params_cnt, &params_ptr, true);
} }
int PSCore::init_server( int PSCore::InitServer(
const std::string& dist_desc, const std::string& dist_desc,
const std::vector<std::string>* host_sign_list, int node_num, int index, const std::vector<std::string>* host_sign_list, int node_num, int index,
int trainers, int trainers,
const std::vector<framework::ProgramDesc>& server_sub_program) { const std::vector<framework::ProgramDesc>& server_sub_program) {
google::protobuf::TextFormat::ParseFromString(dist_desc, &_ps_param); google::protobuf::TextFormat::ParseFromString(dist_desc, &_ps_param);
init_gflag(_ps_param.init_gflags()); InitGFlag(_ps_param.init_gflags());
_ps_env = paddle::distributed::PaddlePSEnvironment(); _ps_env = paddle::distributed::PaddlePSEnvironment();
_ps_env.set_ps_servers(host_sign_list, node_num); _ps_env.SetPsServers(host_sign_list, node_num);
_ps_env.set_trainers(trainers); _ps_env.SetTrainers(trainers);
int ret = 0; int ret = 0;
_server_ptr = std::shared_ptr<paddle::distributed::PSServer>( _server_ptr = std::shared_ptr<paddle::distributed::PSServer>(
paddle::distributed::PSServerFactory::create(_ps_param)); paddle::distributed::PSServerFactory::Create(_ps_param));
ret = _server_ptr->configure(_ps_param, _ps_env, index, server_sub_program); ret = _server_ptr->Configure(_ps_param, _ps_env, index, server_sub_program);
CHECK(ret == 0) << "failed to configure server"; CHECK(ret == 0) << "failed to configure server";
return ret; return ret;
} }
int PSCore::init_worker( int PSCore::InitWorker(
const std::string& dist_desc, const std::string& dist_desc,
const std::map<uint64_t, std::vector<paddle::distributed::Region>>& regions, const std::map<uint64_t, std::vector<paddle::distributed::Region>>& regions,
const std::vector<std::string>* host_sign_list, int node_num, int index) { const std::vector<std::string>* host_sign_list, int node_num, int index) {
google::protobuf::TextFormat::ParseFromString(dist_desc, &_ps_param); google::protobuf::TextFormat::ParseFromString(dist_desc, &_ps_param);
init_gflag(_ps_param.init_gflags()); InitGFlag(_ps_param.init_gflags());
_ps_env = paddle::distributed::PaddlePSEnvironment(); _ps_env = paddle::distributed::PaddlePSEnvironment();
_ps_env.set_ps_servers(host_sign_list, node_num); _ps_env.SetPsServers(host_sign_list, node_num);
int ret = 0; int ret = 0;
VLOG(1) << "PSCore::init_worker"; VLOG(1) << "PSCore::InitWorker";
auto* communicator = Communicator::GetInstance(); auto* communicator = Communicator::GetInstance();
ret = communicator->GetPsClient()->configure(_ps_param, regions, _ps_env, ret = communicator->GetPsClient()->Configure(_ps_param, regions, _ps_env,
index); index);
communicator->Start(); communicator->Start();
return ret; return ret;
} }
std::vector<uint64_t> PSCore::get_client_info() { std::vector<uint64_t> PSCore::GetClientInfo() {
return _ps_env.get_client_info(); return _ps_env.GetClientInfo();
} }
int PSCore::create_client2client_connection(int pserver_timeout_ms, int PSCore::CreateClient2ClientConnection(int pserver_timeout_ms,
int pserver_connect_timeout_ms, int pserver_connect_timeout_ms,
int max_retry) { int max_retry) {
int ret = _worker_ptr->create_client2client_connection( int ret = _worker_ptr->CreateClient2ClientConnection(
pserver_timeout_ms, pserver_connect_timeout_ms, max_retry); pserver_timeout_ms, pserver_connect_timeout_ms, max_retry);
return ret; return ret;
} }
uint64_t PSCore::run_server(const std::string& ip, uint32_t port) { uint64_t PSCore::RunServer(const std::string& ip, uint32_t port) {
return _server_ptr->start(ip, port); return _server_ptr->Start(ip, port);
} }
int PSCore::finalize_worker() { int PSCore::FinalizeWorker() {
_worker_ptr->finalize_worker(); _worker_ptr->FinalizeWorker();
return 0; return 0;
} }
int PSCore::stop_server() { int PSCore::StopServer() {
auto stop_status = _worker_ptr->stop_server(); auto stop_status = _worker_ptr->StopServer();
stop_status.wait(); stop_status.wait();
return 0; return 0;
} }
paddle::distributed::PSParameter* PSCore::get_param() { return &_ps_param; } paddle::distributed::PSParameter* PSCore::GetParam() { return &_ps_param; }
} // namespace distributed } // namespace distributed
} // namespace paddle } // namespace paddle
...@@ -42,31 +42,31 @@ class PSCore { ...@@ -42,31 +42,31 @@ class PSCore {
explicit PSCore() {} explicit PSCore() {}
virtual ~PSCore() {} virtual ~PSCore() {}
virtual int init_server( virtual int InitServer(
const std::string& dist_desc, const std::string& dist_desc,
const std::vector<std::string>* host_sign_list, int node_num, int index, const std::vector<std::string>* host_sign_list, int node_num, int index,
int trainers, int trainers,
const std::vector<framework::ProgramDesc>& server_sub_program = {}); const std::vector<framework::ProgramDesc>& server_sub_program = {});
virtual int init_worker( virtual int InitWorker(
const std::string& dist_desc, const std::string& dist_desc,
const std::map<uint64_t, std::vector<paddle::distributed::Region>>& const std::map<uint64_t, std::vector<paddle::distributed::Region>>&
regions, regions,
const std::vector<std::string>* host_sign_list, int node_num, int index); const std::vector<std::string>* host_sign_list, int node_num, int index);
virtual uint64_t run_server(const std::string& ip, uint32_t port); virtual uint64_t RunServer(const std::string& ip, uint32_t port);
virtual int stop_server(); virtual int StopServer();
virtual int finalize_worker(); virtual int FinalizeWorker();
virtual std::vector<uint64_t> get_client_info(); virtual std::vector<uint64_t> GetClientInfo();
virtual int create_client2client_connection(int pserver_timeout_ms, virtual int CreateClient2ClientConnection(int pserver_timeout_ms,
int pserver_connect_timeout_ms, int pserver_connect_timeout_ms,
int max_retry); int max_retry);
std::shared_ptr<paddle::distributed::PSServer> std::shared_ptr<paddle::distributed::PSServer>
_server_ptr; // pointer to server _server_ptr; // pointer to server
std::shared_ptr<paddle::distributed::PSClient> std::shared_ptr<paddle::distributed::PSClient>
_worker_ptr; // pointer to worker _worker_ptr; // pointer to worker
virtual paddle::distributed::PSParameter* get_param(); virtual paddle::distributed::PSParameter* GetParam();
private: private:
void init_gflag(const std::string& gflags); void InitGFlag(const std::string& gflags);
paddle::distributed::PSParameter _ps_param; paddle::distributed::PSParameter _ps_param;
paddle::distributed::PaddlePSEnvironment _ps_env; paddle::distributed::PaddlePSEnvironment _ps_env;
}; };
......
...@@ -29,7 +29,7 @@ REGISTER_PSCORE_CLASS(PsBaseService, BrpcPsService); ...@@ -29,7 +29,7 @@ REGISTER_PSCORE_CLASS(PsBaseService, BrpcPsService);
REGISTER_PSCORE_CLASS(PSServer, GraphBrpcServer); REGISTER_PSCORE_CLASS(PSServer, GraphBrpcServer);
REGISTER_PSCORE_CLASS(PsBaseService, GraphBrpcService); REGISTER_PSCORE_CLASS(PsBaseService, GraphBrpcService);
PSServer *PSServerFactory::create(const PSParameter &ps_config) { PSServer *PSServerFactory::Create(const PSParameter &ps_config) {
const auto &config = ps_config.server_param(); const auto &config = ps_config.server_param();
if (!config.has_downpour_server_param()) { if (!config.has_downpour_server_param()) {
...@@ -56,18 +56,18 @@ PSServer *PSServerFactory::create(const PSParameter &ps_config) { ...@@ -56,18 +56,18 @@ PSServer *PSServerFactory::create(const PSParameter &ps_config) {
<< service_param.server_class(); << service_param.server_class();
return NULL; return NULL;
} }
TableManager::instance().initialize(); TableManager::Instance().Initialize();
return server; return server;
} }
int32_t PSServer::configure( int32_t PSServer::Configure(
const PSParameter &config, PSEnvironment &env, size_t server_rank, const PSParameter &config, PSEnvironment &env, size_t server_rank,
const std::vector<framework::ProgramDesc> &server_sub_program) { const std::vector<framework::ProgramDesc> &server_sub_program) {
scope_.reset(new framework::Scope()); scope_.reset(new framework::Scope());
_config = config.server_param(); _config = config.server_param();
_rank = server_rank; _rank = server_rank;
_environment = &env; _environment = &env;
size_t shard_num = env.get_ps_servers().size(); size_t shard_num = env.GetPsServers().size();
const auto &downpour_param = _config.downpour_server_param(); const auto &downpour_param = _config.downpour_server_param();
...@@ -87,21 +87,21 @@ int32_t PSServer::configure( ...@@ -87,21 +87,21 @@ int32_t PSServer::configure(
global_step_table = downpour_param.downpour_table_param(i).table_id(); global_step_table = downpour_param.downpour_table_param(i).table_id();
} }
table->set_program_env(scope_.get(), place_, &server_sub_program); table->SetProgramEnv(scope_.get(), place_, &server_sub_program);
table->set_shard(_rank, shard_num); table->SetShard(_rank, shard_num);
table->initialize(downpour_param.downpour_table_param(i), table->Initialize(downpour_param.downpour_table_param(i),
config.fs_client_param()); config.fs_client_param());
_table_map[downpour_param.downpour_table_param(i).table_id()].reset(table); _table_map[downpour_param.downpour_table_param(i).table_id()].reset(table);
} }
if (barrier_table != UINT32_MAX) { if (barrier_table != UINT32_MAX) {
_table_map[barrier_table]->set_table_map(&_table_map); _table_map[barrier_table]->SetTableMap(&_table_map);
} }
if (global_step_table != UINT32_MAX) { if (global_step_table != UINT32_MAX) {
_table_map[global_step_table]->set_table_map(&_table_map); _table_map[global_step_table]->SetTableMap(&_table_map);
} }
return initialize(); return Initialize();
} }
} // namespace distributed } // namespace distributed
} // namespace paddle } // namespace paddle
...@@ -65,19 +65,19 @@ class PSServer { ...@@ -65,19 +65,19 @@ class PSServer {
PSServer(PSServer &&) = delete; PSServer(PSServer &&) = delete;
PSServer(const PSServer &) = delete; PSServer(const PSServer &) = delete;
virtual int32_t configure( virtual int32_t Configure(
const PSParameter &config, PSEnvironment &env, size_t server_rank, const PSParameter &config, PSEnvironment &env, size_t server_rank,
const std::vector<framework::ProgramDesc> &server_sub_program = {}); const std::vector<framework::ProgramDesc> &server_sub_program = {});
virtual uint64_t start(const std::string &ip, uint32_t port) = 0; virtual uint64_t Start(const std::string &ip, uint32_t port) = 0;
virtual int32_t stop() = 0; virtual int32_t Stop() = 0;
inline size_t rank() const { return _rank; } inline size_t Rank() const { return _rank; }
inline PSEnvironment *environment() { return _environment; } inline PSEnvironment *Environment() { return _environment; }
inline const ServerParameter *config() const { return &_config; } inline const ServerParameter *Config() const { return &_config; }
inline Table *table(size_t table_id) { inline Table *GetTable(size_t table_id) {
auto itr = _table_map.find(table_id); auto itr = _table_map.find(table_id);
if (itr != _table_map.end()) { if (itr != _table_map.end()) {
return itr->second.get(); return itr->second.get();
...@@ -85,12 +85,12 @@ class PSServer { ...@@ -85,12 +85,12 @@ class PSServer {
return NULL; return NULL;
} }
inline std::unordered_map<uint32_t, std::shared_ptr<Table>> *table() { inline std::unordered_map<uint32_t, std::shared_ptr<Table>> *GetTable() {
return &_table_map; return &_table_map;
} }
protected: protected:
virtual int32_t initialize() = 0; virtual int32_t Initialize() = 0;
protected: protected:
size_t _rank; size_t _rank;
...@@ -129,11 +129,11 @@ class PsBaseService : public PsService { ...@@ -129,11 +129,11 @@ class PsBaseService : public PsService {
public: public:
PsBaseService() : _rank(0), _server(NULL), _config(NULL) {} PsBaseService() : _rank(0), _server(NULL), _config(NULL) {}
virtual ~PsBaseService() {} virtual ~PsBaseService() {}
virtual size_t get_rank() { return _rank; } virtual size_t GetRank() { return _rank; }
virtual int32_t configure(PSServer *server) { virtual int32_t Configure(PSServer *server) {
_server = server; _server = server;
_rank = _server->rank(); _rank = _server->Rank();
_config = _server->config(); _config = _server->Config();
return 0; return 0;
} }
virtual void service(::google::protobuf::RpcController *controller, virtual void service(::google::protobuf::RpcController *controller,
...@@ -148,8 +148,8 @@ class PsBaseService : public PsService { ...@@ -148,8 +148,8 @@ class PsBaseService : public PsService {
LOG(WARNING) << "Resonse err_code:" << err_code << " msg:" << err_msg; LOG(WARNING) << "Resonse err_code:" << err_code << " msg:" << err_msg;
} }
virtual int32_t initialize() = 0; virtual int32_t Initialize() = 0;
PSServer *get_server() { return _server; } PSServer *GetServer() { return _server; }
protected: protected:
size_t _rank; size_t _rank;
...@@ -160,7 +160,7 @@ REGISTER_PSCORE_REGISTERER(PsBaseService); ...@@ -160,7 +160,7 @@ REGISTER_PSCORE_REGISTERER(PsBaseService);
class PSServerFactory { class PSServerFactory {
public: public:
static PSServer *create(const PSParameter &config); static PSServer *Create(const PSParameter &config);
}; };
} // namespace distributed } // namespace distributed
} // namespace paddle } // namespace paddle
...@@ -17,7 +17,7 @@ ...@@ -17,7 +17,7 @@
namespace paddle { namespace paddle {
namespace distributed { namespace distributed {
int32_t BarrierTable::initialize() { int32_t BarrierTable::Initialize() {
auto trainers = _config.common().trainer_num(); auto trainers = _config.common().trainer_num();
trigger_.store(trainers); trigger_.store(trainers);
...@@ -29,7 +29,7 @@ int32_t BarrierTable::initialize() { ...@@ -29,7 +29,7 @@ int32_t BarrierTable::initialize() {
} }
// 0: send_barrier 1: recv_barrier 2: complete // 0: send_barrier 1: recv_barrier 2: complete
int32_t BarrierTable::barrier(const uint32_t trainer_id, int32_t BarrierTable::Barrier(const uint32_t trainer_id,
const std::string barrier_type) { const std::string barrier_type) {
std::unique_lock<std::mutex> lock(mutex_); std::unique_lock<std::mutex> lock(mutex_);
...@@ -56,7 +56,7 @@ int32_t BarrierTable::barrier(const uint32_t trainer_id, ...@@ -56,7 +56,7 @@ int32_t BarrierTable::barrier(const uint32_t trainer_id,
VLOG(1) << "barrier table optimize begin"; VLOG(1) << "barrier table optimize begin";
for (auto& x : *table_map_) { for (auto& x : *table_map_) {
auto table = x.second; auto table = x.second;
table->pour(); table->Pour();
} }
VLOG(1) << "barrier table optimize done"; VLOG(1) << "barrier table optimize done";
...@@ -66,7 +66,7 @@ int32_t BarrierTable::barrier(const uint32_t trainer_id, ...@@ -66,7 +66,7 @@ int32_t BarrierTable::barrier(const uint32_t trainer_id,
return 0; return 0;
} }
int32_t BarrierTable::set_table_map( int32_t BarrierTable::SetTableMap(
std::unordered_map<uint32_t, std::shared_ptr<Table>>* table_map) { std::unordered_map<uint32_t, std::shared_ptr<Table>>* table_map) {
table_map_ = table_map; table_map_ = table_map;
return 0; return 0;
......
...@@ -21,8 +21,8 @@ namespace distributed { ...@@ -21,8 +21,8 @@ namespace distributed {
int FLAGS_pslib_table_save_max_retry_dense = 3; int FLAGS_pslib_table_save_max_retry_dense = 3;
void CommonDenseTable::create_initializer(const std::string& attr, void CommonDenseTable::CreateInitializer(const std::string& attr,
const std::string& name) { const std::string& name) {
auto slices = string::split_string<std::string>(attr, "&"); auto slices = string::split_string<std::string>(attr, "&");
if (slices[0] == "gaussian_random") { if (slices[0] == "gaussian_random") {
...@@ -39,7 +39,7 @@ void CommonDenseTable::create_initializer(const std::string& attr, ...@@ -39,7 +39,7 @@ void CommonDenseTable::create_initializer(const std::string& attr,
} }
} }
int32_t CommonDenseTable::initialize() { int32_t CommonDenseTable::Initialize() {
_shards_task_pool.resize(task_pool_size_); _shards_task_pool.resize(task_pool_size_);
for (int i = 0; i < _shards_task_pool.size(); ++i) { for (int i = 0; i < _shards_task_pool.size(); ++i) {
_shards_task_pool[i].reset(new ::ThreadPool(1)); _shards_task_pool[i].reset(new ::ThreadPool(1));
...@@ -49,12 +49,12 @@ int32_t CommonDenseTable::initialize() { ...@@ -49,12 +49,12 @@ int32_t CommonDenseTable::initialize() {
VLOG(1) << "table " << _config.common().table_name() << " is sync: " << sync; VLOG(1) << "table " << _config.common().table_name() << " is sync: " << sync;
_global_lr = new float(1.0); _global_lr = new float(1.0);
initialize_value(); InitializeValue();
initialize_optimizer(); InitializeOptimizer();
return 0; return 0;
} }
int32_t CommonDenseTable::initialize_value() { int32_t CommonDenseTable::InitializeValue() {
auto common = _config.common(); auto common = _config.common();
int size = static_cast<int>(common.params().size()); int size = static_cast<int>(common.params().size());
values_.resize(size); values_.resize(size);
...@@ -70,7 +70,7 @@ int32_t CommonDenseTable::initialize_value() { ...@@ -70,7 +70,7 @@ int32_t CommonDenseTable::initialize_value() {
auto& initializer = common.initializers()[x]; auto& initializer = common.initializers()[x];
total_dim_ += dim; total_dim_ += dim;
create_initializer(initializer, varname); CreateInitializer(initializer, varname);
values_[x].resize(dim); values_[x].resize(dim);
names_index_[varname] = x; names_index_[varname] = x;
...@@ -92,27 +92,27 @@ int32_t CommonDenseTable::initialize_value() { ...@@ -92,27 +92,27 @@ int32_t CommonDenseTable::initialize_value() {
param_col_ids_.insert(param_col_ids_.begin() + 1, -1); param_col_ids_.insert(param_col_ids_.begin() + 1, -1);
} }
VLOG(1) << "CommonDenseTable::initialize_value total dim: " << total_dim_ VLOG(1) << "CommonDenseTable::InitializeValue total dim: " << total_dim_
<< " fixed_len_params_dim: " << fixed_len_params_dim_; << " fixed_len_params_dim: " << fixed_len_params_dim_;
pull_reservoir_ = ReservoirValue<float>(param_dim_); pull_reservoir_ = ReservoirValue<float>(param_dim_);
return 0; return 0;
} }
int32_t CommonDenseTable::initialize_optimizer() { int32_t CommonDenseTable::InitializeOptimizer() {
auto common = _config.common(); auto common = _config.common();
auto name = common.name(); auto name = common.name();
auto attrs = common.attributes(); auto attrs = common.attributes();
if (name == "sgd") { if (name == "sgd") {
optimizer_ = std::make_shared<DSGD>(common, &values_); optimizer_ = std::make_shared<DSGD>(common, &values_);
optimizer_->set_global_lr(_global_lr); optimizer_->SetGlobalLR(_global_lr);
} else if (name == "adam") { } else if (name == "adam") {
optimizer_ = std::make_shared<DAdam>(common, &values_); optimizer_ = std::make_shared<DAdam>(common, &values_);
optimizer_->set_global_lr(_global_lr); optimizer_->SetGlobalLR(_global_lr);
} else if (name == "adam_d2sum") { } else if (name == "adam_d2sum") {
optimizer_ = std::make_shared<DAdamD2Sum>(common, &values_); optimizer_ = std::make_shared<DAdamD2Sum>(common, &values_);
// optimizer_->set_global_lr(_global_lr); //no use // optimizer_->SetGlobalLR(_global_lr); //no use
} else if (name == "sum") { } else if (name == "sum") {
optimizer_ = std::make_shared<DSUM>(common, &values_); optimizer_ = std::make_shared<DSUM>(common, &values_);
} else if (name == "summary") { } else if (name == "summary") {
...@@ -124,34 +124,34 @@ int32_t CommonDenseTable::initialize_optimizer() { ...@@ -124,34 +124,34 @@ int32_t CommonDenseTable::initialize_optimizer() {
return 0; return 0;
} }
int32_t CommonDenseTable::set_global_lr(float* lr) { int32_t CommonDenseTable::SetGlobalLR(float* lr) {
_global_lr = lr; _global_lr = lr;
optimizer_->set_global_lr(_global_lr); optimizer_->SetGlobalLR(_global_lr);
return 0; return 0;
} }
int32_t CommonDenseTable::Pull(TableContext& context) { int32_t CommonDenseTable::Pull(TableContext& context) {
CHECK(context.value_type == Dense); CHECK(context.value_type == Dense);
float* pull_values = context.pull_context.values; float* pull_values = context.pull_context.values;
return pull_dense(pull_values, context.num); return PullDense(pull_values, context.num);
} }
int32_t CommonDenseTable::Push(TableContext& context) { int32_t CommonDenseTable::Push(TableContext& context) {
CHECK(context.value_type == Dense); CHECK(context.value_type == Dense);
if (context.push_context.values != nullptr) { if (context.push_context.values != nullptr) {
const float* values = context.push_context.values; const float* values = context.push_context.values;
return push_dense(values, context.num); return PushDense(values, context.num);
} }
return 0; return 0;
} }
int32_t CommonDenseTable::pull_dense(float* pull_values, size_t num) { int32_t CommonDenseTable::PullDense(float* pull_values, size_t num) {
std::copy(values_[param_idx_].begin(), values_[param_idx_].end(), std::copy(values_[param_idx_].begin(), values_[param_idx_].end(),
pull_values); pull_values);
return 0; return 0;
} }
int32_t CommonDenseTable::push_dense_param(const float* values, size_t num) { int32_t CommonDenseTable::PushDenseParam(const float* values, size_t num) {
PADDLE_ENFORCE_GE( PADDLE_ENFORCE_GE(
num, param_dim_, num, param_dim_,
paddle::platform::errors::InvalidArgument( paddle::platform::errors::InvalidArgument(
...@@ -160,14 +160,14 @@ int32_t CommonDenseTable::push_dense_param(const float* values, size_t num) { ...@@ -160,14 +160,14 @@ int32_t CommonDenseTable::push_dense_param(const float* values, size_t num) {
return 0; return 0;
} }
int32_t CommonDenseTable::pour() { int32_t CommonDenseTable::Pour() {
pull_reservoir_.avg(); pull_reservoir_.avg();
_push_dense(pull_reservoir_.values.data(), pull_reservoir_.values.size()); _PushDense(pull_reservoir_.values.data(), pull_reservoir_.values.size());
pull_reservoir_.reset(); pull_reservoir_.reset();
return 0; return 0;
} }
int32_t CommonDenseTable::push_dense(const float* values, size_t num) { int32_t CommonDenseTable::PushDense(const float* values, size_t num) {
if (sync) { if (sync) {
std::future<int> task = std::future<int> task =
_shards_task_pool[0]->enqueue([this, &values]() -> int { _shards_task_pool[0]->enqueue([this, &values]() -> int {
...@@ -176,12 +176,12 @@ int32_t CommonDenseTable::push_dense(const float* values, size_t num) { ...@@ -176,12 +176,12 @@ int32_t CommonDenseTable::push_dense(const float* values, size_t num) {
}); });
task.wait(); task.wait();
} else { } else {
_push_dense(values, num); _PushDense(values, num);
} }
return 0; return 0;
} }
int32_t CommonDenseTable::_push_dense(const float* values, size_t num) { int32_t CommonDenseTable::_PushDense(const float* values, size_t num) {
PADDLE_ENFORCE_GE( PADDLE_ENFORCE_GE(
num, param_dim_, num, param_dim_,
paddle::platform::errors::InvalidArgument( paddle::platform::errors::InvalidArgument(
...@@ -195,7 +195,7 @@ int32_t CommonDenseTable::_push_dense(const float* values, size_t num) { ...@@ -195,7 +195,7 @@ int32_t CommonDenseTable::_push_dense(const float* values, size_t num) {
[this, shard_id, &buckets, &values]() -> int { [this, shard_id, &buckets, &values]() -> int {
auto begin = buckets[shard_id]; auto begin = buckets[shard_id];
auto end = buckets[shard_id + 1]; auto end = buckets[shard_id + 1];
optimizer_->update(values, param_dim_, begin, end); optimizer_->Update(values, param_dim_, begin, end);
return 0; return 0;
}); });
} }
...@@ -207,12 +207,12 @@ int32_t CommonDenseTable::_push_dense(const float* values, size_t num) { ...@@ -207,12 +207,12 @@ int32_t CommonDenseTable::_push_dense(const float* values, size_t num) {
return 0; return 0;
} }
int32_t CommonDenseTable::load(const std::string& path, int32_t CommonDenseTable::Load(const std::string& path,
const std::string& param) { const std::string& param) {
if (param_dim_ <= 0) { if (param_dim_ <= 0) {
return 0; return 0;
} }
std::string table_path = table_dir(path); std::string table_path = TableDir(path);
auto file_list = _afs_client.list(table_path); auto file_list = _afs_client.list(table_path);
std::sort(file_list.begin(), file_list.end()); std::sort(file_list.begin(), file_list.end());
for (auto ff : file_list) { for (auto ff : file_list) {
...@@ -314,7 +314,7 @@ int32_t CommonDenseTable::load(const std::string& path, ...@@ -314,7 +314,7 @@ int32_t CommonDenseTable::load(const std::string& path,
return 0; return 0;
} }
int32_t CommonDenseTable::save(const std::string& path, int32_t CommonDenseTable::Save(const std::string& path,
const std::string& param) { const std::string& param) {
int save_param = atoi(param.c_str()); int save_param = atoi(param.c_str());
uint32_t feasign_size; uint32_t feasign_size;
...@@ -323,10 +323,10 @@ int32_t CommonDenseTable::save(const std::string& path, ...@@ -323,10 +323,10 @@ int32_t CommonDenseTable::save(const std::string& path,
FsChannelConfig channel_config; FsChannelConfig channel_config;
if (_config.compress_in_save()) { if (_config.compress_in_save()) {
channel_config.path = paddle::string::format_string( channel_config.path = paddle::string::format_string(
"%s/part-%03d.gz", table_dir(path).c_str(), _shard_idx); "%s/part-%03d.gz", TableDir(path).c_str(), _shard_idx);
} else { } else {
channel_config.path = paddle::string::format_string( channel_config.path = paddle::string::format_string(
"%s/part-%03d", table_dir(path).c_str(), _shard_idx); "%s/part-%03d", TableDir(path).c_str(), _shard_idx);
} }
_afs_client.remove(channel_config.path); _afs_client.remove(channel_config.path);
channel_config.converter = _value_accesor->Converter(save_param).converter; channel_config.converter = _value_accesor->Converter(save_param).converter;
......
...@@ -34,29 +34,29 @@ class CommonDenseTable : public DenseTable { ...@@ -34,29 +34,29 @@ class CommonDenseTable : public DenseTable {
public: public:
CommonDenseTable() {} CommonDenseTable() {}
virtual ~CommonDenseTable() {} virtual ~CommonDenseTable() {}
int32_t initialize() override; int32_t Initialize() override;
int32_t initialize_shard() override { return 0; } int32_t InitializeShard() override { return 0; }
virtual void create_initializer(const std::string& attr, virtual void CreateInitializer(const std::string& attr,
const std::string& name); const std::string& name);
virtual int32_t initialize_value(); virtual int32_t InitializeValue();
virtual int32_t initialize_optimizer(); virtual int32_t InitializeOptimizer();
virtual int32_t Pull(TableContext& context); virtual int32_t Pull(TableContext& context);
virtual int32_t Push(TableContext& context); virtual int32_t Push(TableContext& context);
int32_t pull_dense(float* pull_values, size_t num) override; int32_t PullDense(float* pull_values, size_t num) override;
int32_t push_dense_param(const float* values, size_t num) override; int32_t PushDenseParam(const float* values, size_t num) override;
int32_t push_dense(const float* values, size_t num) override; int32_t PushDense(const float* values, size_t num) override;
int32_t pour() override; int32_t Pour() override;
int32_t set_global_lr(float* lr) override; int32_t SetGlobalLR(float* lr) override;
int32_t load(const std::string& path, const std::string& param) override; int32_t Load(const std::string& path, const std::string& param) override;
int32_t save(const std::string& path, const std::string& param) override; int32_t Save(const std::string& path, const std::string& param) override;
int32_t flush() override { return 0; } int32_t Flush() override { return 0; }
int32_t shrink(const std::string& param) override { return 0; } int32_t Shrink(const std::string& param) override { return 0; }
void clear() override { return; } void Clear() override { return; }
protected: protected:
int32_t _push_dense(const float* values, size_t num); int32_t _PushDense(const float* values, size_t num);
private: private:
const int task_pool_size_ = 10; const int task_pool_size_ = 10;
......
...@@ -448,7 +448,7 @@ int32_t GraphTable::load_graph_split_config(const std::string &path) { ...@@ -448,7 +448,7 @@ int32_t GraphTable::load_graph_split_config(const std::string &path) {
return 0; return 0;
} }
int32_t GraphTable::load(const std::string &path, const std::string &param) { int32_t GraphTable::Load(const std::string &path, const std::string &param) {
bool load_edge = (param[0] == 'e'); bool load_edge = (param[0] == 'e');
bool load_node = (param[0] == 'n'); bool load_node = (param[0] == 'n');
if (load_edge) { if (load_edge) {
...@@ -1066,11 +1066,11 @@ int32_t GraphTable::pull_graph_list(int start, int total_size, ...@@ -1066,11 +1066,11 @@ int32_t GraphTable::pull_graph_list(int start, int total_size,
int32_t GraphTable::get_server_index_by_id(int64_t id) { int32_t GraphTable::get_server_index_by_id(int64_t id) {
return id % shard_num / shard_num_per_server; return id % shard_num / shard_num_per_server;
} }
int32_t GraphTable::initialize(const TableParameter &config, int32_t GraphTable::Initialize(const TableParameter &config,
const FsClientParameter &fs_config) { const FsClientParameter &fs_config) {
LOG(INFO) << "in graphTable initialize"; LOG(INFO) << "in graphTable initialize";
_config = config; _config = config;
if (initialize_accessor() != 0) { if (InitializeAccessor() != 0) {
LOG(WARNING) << "Table accessor initialize failed"; LOG(WARNING) << "Table accessor initialize failed";
return -1; return -1;
} }
...@@ -1082,9 +1082,9 @@ int32_t GraphTable::initialize(const TableParameter &config, ...@@ -1082,9 +1082,9 @@ int32_t GraphTable::initialize(const TableParameter &config,
auto graph = config.graph_parameter(); auto graph = config.graph_parameter();
shard_num = _config.shard_num(); shard_num = _config.shard_num();
LOG(INFO) << "in graphTable initialize over"; LOG(INFO) << "in graphTable initialize over";
return initialize(graph); return Initialize(graph);
} }
int32_t GraphTable::initialize(const GraphParameter &graph) { int32_t GraphTable::Initialize(const GraphParameter &graph) {
#ifdef PADDLE_WITH_HETERPS #ifdef PADDLE_WITH_HETERPS
if (graph.gpups_mode()) { if (graph.gpups_mode()) {
gpups_mode = true; gpups_mode = true;
......
...@@ -280,7 +280,7 @@ class ScaledLRU { ...@@ -280,7 +280,7 @@ class ScaledLRU {
} }
} }
auto status = auto status =
thread_pool->enqueue([this]() -> int { return shrink(); }); thread_pool->enqueue([this]() -> int { return Shrink(); });
status.wait(); status.wait();
} }
}); });
...@@ -298,7 +298,7 @@ class ScaledLRU { ...@@ -298,7 +298,7 @@ class ScaledLRU {
LRUResponse insert(size_t index, K *keys, V *data, size_t length) { LRUResponse insert(size_t index, K *keys, V *data, size_t length) {
return lru_pool[index].insert(keys, data, length); return lru_pool[index].insert(keys, data, length);
} }
int shrink() { int Shrink() {
int node_size = 0; int node_size = 0;
for (size_t i = 0; i < lru_pool.size(); i++) { for (size_t i = 0; i < lru_pool.size(); i++) {
node_size += lru_pool[i].node_size - lru_pool[i].remove_count; node_size += lru_pool[i].node_size - lru_pool[i].remove_count;
...@@ -329,7 +329,7 @@ class ScaledLRU { ...@@ -329,7 +329,7 @@ class ScaledLRU {
if (diff != 0) { if (diff != 0) {
__sync_fetch_and_add(&global_count, diff); __sync_fetch_and_add(&global_count, diff);
if (global_count > int(1.25 * size_limit)) { if (global_count > int(1.25 * size_limit)) {
thread_pool->enqueue([this]() -> int { return shrink(); }); thread_pool->enqueue([this]() -> int { return Shrink(); });
} }
} }
} }
...@@ -430,11 +430,11 @@ class GraphTable : public SparseTable { ...@@ -430,11 +430,11 @@ class GraphTable : public SparseTable {
virtual int32_t get_nodes_ids_by_ranges( virtual int32_t get_nodes_ids_by_ranges(
std::vector<std::pair<int, int>> ranges, std::vector<int64_t> &res); std::vector<std::pair<int, int>> ranges, std::vector<int64_t> &res);
virtual int32_t initialize() { return 0; } virtual int32_t Initialize() { return 0; }
virtual int32_t initialize(const TableParameter &config, virtual int32_t Initialize(const TableParameter &config,
const FsClientParameter &fs_config); const FsClientParameter &fs_config);
virtual int32_t initialize(const GraphParameter &config); virtual int32_t Initialize(const GraphParameter &config);
int32_t load(const std::string &path, const std::string &param); int32_t Load(const std::string &path, const std::string &param);
int32_t load_graph_split_config(const std::string &path); int32_t load_graph_split_config(const std::string &path);
int32_t load_edges(const std::string &path, bool reverse); int32_t load_edges(const std::string &path, bool reverse);
...@@ -452,26 +452,25 @@ class GraphTable : public SparseTable { ...@@ -452,26 +452,25 @@ class GraphTable : public SparseTable {
virtual int32_t Pull(TableContext &context) { return 0; } virtual int32_t Pull(TableContext &context) { return 0; }
virtual int32_t Push(TableContext &context) { return 0; } virtual int32_t Push(TableContext &context) { return 0; }
virtual int32_t pull_sparse(float *values, virtual int32_t PullSparse(float *values, const PullSparseValue &pull_value) {
const PullSparseValue &pull_value) {
return 0; return 0;
} }
virtual int32_t push_sparse(const uint64_t *keys, const float *values, virtual int32_t PushSparse(const uint64_t *keys, const float *values,
size_t num) { size_t num) {
return 0; return 0;
} }
virtual int32_t clear_nodes(); virtual int32_t clear_nodes();
virtual void clear() {} virtual void Clear() {}
virtual int32_t flush() { return 0; } virtual int32_t Flush() { return 0; }
virtual int32_t shrink(const std::string &param) { return 0; } virtual int32_t Shrink(const std::string &param) { return 0; }
//指定保存路径 //指定保存路径
virtual int32_t save(const std::string &path, const std::string &converter) { virtual int32_t Save(const std::string &path, const std::string &converter) {
return 0; return 0;
} }
virtual int32_t initialize_shard() { return 0; } virtual int32_t InitializeShard() { return 0; }
virtual int32_t set_shard(size_t shard_idx, size_t server_num) { virtual int32_t SetShard(size_t shard_idx, size_t server_num) {
_shard_idx = shard_idx; _shard_idx = shard_idx;
/* /*
_shard_num is not used in graph_table, this following operation is for the _shard_num is not used in graph_table, this following operation is for the
......
...@@ -167,7 +167,7 @@ int64_t CommonSparseTable::LoadFromText( ...@@ -167,7 +167,7 @@ int64_t CommonSparseTable::LoadFromText(
return 0; return 0;
} }
int32_t CommonSparseTable::initialize() { int32_t CommonSparseTable::Initialize() {
_shards_task_pool.resize(task_pool_size_); _shards_task_pool.resize(task_pool_size_);
for (int i = 0; i < _shards_task_pool.size(); ++i) { for (int i = 0; i < _shards_task_pool.size(); ++i) {
_shards_task_pool[i].reset(new ::ThreadPool(1)); _shards_task_pool[i].reset(new ::ThreadPool(1));
...@@ -200,15 +200,15 @@ int32_t CommonSparseTable::initialize() { ...@@ -200,15 +200,15 @@ int32_t CommonSparseTable::initialize() {
offset += dim; offset += dim;
} }
initialize_value(); InitializeValue();
initialize_optimizer(); InitializeOptimizer();
initialize_recorder(); InitializeRecorder();
return 0; return 0;
} }
int32_t CommonSparseTable::initialize_recorder() { return 0; } int32_t CommonSparseTable::InitializeRecorder() { return 0; }
int32_t CommonSparseTable::initialize_value() { int32_t CommonSparseTable::InitializeValue() {
auto common = _config.common(); auto common = _config.common();
shard_values_.reserve(task_pool_size_); shard_values_.reserve(task_pool_size_);
...@@ -223,18 +223,18 @@ int32_t CommonSparseTable::initialize_value() { ...@@ -223,18 +223,18 @@ int32_t CommonSparseTable::initialize_value() {
return 0; return 0;
} }
int32_t CommonSparseTable::initialize_optimizer() { int32_t CommonSparseTable::InitializeOptimizer() {
auto common = _config.common(); auto common = _config.common();
auto name = common.name(); auto name = common.name();
if (name == "sgd") { if (name == "sgd") {
optimizer_ = std::make_shared<SSGD>(value_names_, value_dims_, optimizer_ = std::make_shared<SSGD>(value_names_, value_dims_,
value_offsets_, value_idx_); value_offsets_, value_idx_);
optimizer_->set_global_lr(_global_lr); optimizer_->SetGlobalLR(_global_lr);
} else if (name == "adam") { } else if (name == "adam") {
optimizer_ = std::make_shared<SAdam>(value_names_, value_dims_, optimizer_ = std::make_shared<SAdam>(value_names_, value_dims_,
value_offsets_, value_idx_); value_offsets_, value_idx_);
optimizer_->set_global_lr(_global_lr); optimizer_->SetGlobalLR(_global_lr);
} else if (name == "sum") { } else if (name == "sum") {
optimizer_ = std::make_shared<SSUM>(value_names_, value_dims_, optimizer_ = std::make_shared<SSUM>(value_names_, value_dims_,
value_offsets_, value_idx_); value_offsets_, value_idx_);
...@@ -246,13 +246,13 @@ int32_t CommonSparseTable::initialize_optimizer() { ...@@ -246,13 +246,13 @@ int32_t CommonSparseTable::initialize_optimizer() {
return 0; return 0;
} }
int32_t CommonSparseTable::set_global_lr(float* lr) { int32_t CommonSparseTable::SetGlobalLR(float* lr) {
_global_lr = lr; _global_lr = lr;
optimizer_->set_global_lr(_global_lr); optimizer_->SetGlobalLR(_global_lr);
return 0; return 0;
} }
int32_t CommonSparseTable::load(const std::string& dirname, int32_t CommonSparseTable::Load(const std::string& dirname,
const std::string& param) { const std::string& param) {
auto begin = GetCurrentUS(); auto begin = GetCurrentUS();
rwlock_->WRLock(); rwlock_->WRLock();
...@@ -276,7 +276,7 @@ int32_t CommonSparseTable::load(const std::string& dirname, ...@@ -276,7 +276,7 @@ int32_t CommonSparseTable::load(const std::string& dirname,
return 0; return 0;
} }
int32_t CommonSparseTable::save(const std::string& dirname, int32_t CommonSparseTable::Save(const std::string& dirname,
const std::string& param) { const std::string& param) {
auto begin = GetCurrentUS(); auto begin = GetCurrentUS();
rwlock_->WRLock(); rwlock_->WRLock();
...@@ -322,7 +322,7 @@ int32_t CommonSparseTable::save(const std::string& dirname, ...@@ -322,7 +322,7 @@ int32_t CommonSparseTable::save(const std::string& dirname,
return 0; return 0;
} }
std::pair<int64_t, int64_t> CommonSparseTable::print_table_stat() { std::pair<int64_t, int64_t> CommonSparseTable::PrintTableStat() {
int64_t feasign_size = 0; int64_t feasign_size = 0;
int64_t mf_size = 0; int64_t mf_size = 0;
...@@ -335,7 +335,7 @@ std::pair<int64_t, int64_t> CommonSparseTable::print_table_stat() { ...@@ -335,7 +335,7 @@ std::pair<int64_t, int64_t> CommonSparseTable::print_table_stat() {
return {feasign_size, mf_size}; return {feasign_size, mf_size};
} }
int32_t CommonSparseTable::pour() { int32_t CommonSparseTable::Pour() {
std::vector<float> values; std::vector<float> values;
std::vector<uint64_t> keys; std::vector<uint64_t> keys;
...@@ -349,7 +349,7 @@ int32_t CommonSparseTable::pour() { ...@@ -349,7 +349,7 @@ int32_t CommonSparseTable::pour() {
std::copy(reservoir.values.begin(), reservoir.values.end(), std::copy(reservoir.values.begin(), reservoir.values.end(),
std::back_inserter(values)); std::back_inserter(values));
} }
_push_sparse(keys.data(), values.data(), pull_reservoir_.size()); _PushSparse(keys.data(), values.data(), pull_reservoir_.size());
pull_reservoir_.clear(); pull_reservoir_.clear();
return 0; return 0;
...@@ -360,11 +360,11 @@ int32_t CommonSparseTable::Pull(TableContext& context) { ...@@ -360,11 +360,11 @@ int32_t CommonSparseTable::Pull(TableContext& context) {
if (context.use_ptr) { if (context.use_ptr) {
char** pull_values = context.pull_context.ptr_values; char** pull_values = context.pull_context.ptr_values;
const uint64_t* keys = context.pull_context.keys; const uint64_t* keys = context.pull_context.keys;
return pull_sparse_ptr(pull_values, keys, context.num); return PullSparsePtr(pull_values, keys, context.num);
} else { } else {
float* pull_values = context.pull_context.values; float* pull_values = context.pull_context.values;
const PullSparseValue& pull_value = context.pull_context.pull_value; const PullSparseValue& pull_value = context.pull_context.pull_value;
return pull_sparse(pull_values, pull_value); return PullSparse(pull_values, pull_value);
} }
} }
...@@ -373,16 +373,16 @@ int32_t CommonSparseTable::Push(TableContext& context) { ...@@ -373,16 +373,16 @@ int32_t CommonSparseTable::Push(TableContext& context) {
if (context.push_context.values != nullptr) { if (context.push_context.values != nullptr) {
const float* values = context.push_context.values; const float* values = context.push_context.values;
const uint64_t* keys = context.push_context.keys; const uint64_t* keys = context.push_context.keys;
return push_sparse(keys, values, context.num); return PushSparse(keys, values, context.num);
} else { } else {
const float** values = context.push_context.ptr_values; const float** values = context.push_context.ptr_values;
const uint64_t* keys = context.push_context.keys; const uint64_t* keys = context.push_context.keys;
return push_sparse(keys, values, context.num); return PushSparse(keys, values, context.num);
} }
} }
int32_t CommonSparseTable::pull_sparse(float* pull_values, int32_t CommonSparseTable::PullSparse(float* pull_values,
const PullSparseValue& pull_value) { const PullSparseValue& pull_value) {
auto shard_num = task_pool_size_; auto shard_num = task_pool_size_;
std::vector<std::future<int>> tasks(shard_num); std::vector<std::future<int>> tasks(shard_num);
...@@ -421,8 +421,8 @@ int32_t CommonSparseTable::pull_sparse(float* pull_values, ...@@ -421,8 +421,8 @@ int32_t CommonSparseTable::pull_sparse(float* pull_values,
return 0; return 0;
} }
int32_t CommonSparseTable::pull_sparse_ptr(char** pull_values, int32_t CommonSparseTable::PullSparsePtr(char** pull_values,
const uint64_t* keys, size_t num) { const uint64_t* keys, size_t num) {
std::vector<std::vector<uint64_t>> offset_bucket; std::vector<std::vector<uint64_t>> offset_bucket;
offset_bucket.resize(task_pool_size_); offset_bucket.resize(task_pool_size_);
...@@ -458,8 +458,8 @@ int32_t CommonSparseTable::pull_sparse_ptr(char** pull_values, ...@@ -458,8 +458,8 @@ int32_t CommonSparseTable::pull_sparse_ptr(char** pull_values,
return 0; return 0;
} }
int32_t CommonSparseTable::_push_sparse(const uint64_t* keys, int32_t CommonSparseTable::_PushSparse(const uint64_t* keys,
const float* values, size_t num) { const float* values, size_t num) {
std::vector<std::vector<uint64_t>> offset_bucket; std::vector<std::vector<uint64_t>> offset_bucket;
offset_bucket.resize(task_pool_size_); offset_bucket.resize(task_pool_size_);
...@@ -474,7 +474,7 @@ int32_t CommonSparseTable::_push_sparse(const uint64_t* keys, ...@@ -474,7 +474,7 @@ int32_t CommonSparseTable::_push_sparse(const uint64_t* keys,
tasks[shard_id] = _shards_task_pool[shard_id]->enqueue( tasks[shard_id] = _shards_task_pool[shard_id]->enqueue(
[this, shard_id, &keys, &values, num, &offset_bucket]() -> int { [this, shard_id, &keys, &values, num, &offset_bucket]() -> int {
auto& offsets = offset_bucket[shard_id]; auto& offsets = offset_bucket[shard_id];
optimizer_->update(keys, values, num, offsets, optimizer_->Update(keys, values, num, offsets,
shard_values_[shard_id].get()); shard_values_[shard_id].get());
return 0; return 0;
}); });
...@@ -486,8 +486,8 @@ int32_t CommonSparseTable::_push_sparse(const uint64_t* keys, ...@@ -486,8 +486,8 @@ int32_t CommonSparseTable::_push_sparse(const uint64_t* keys,
return 0; return 0;
} }
int32_t CommonSparseTable::push_sparse(const uint64_t* keys, int32_t CommonSparseTable::PushSparse(const uint64_t* keys, const float* values,
const float* values, size_t num) { size_t num) {
if (sync) { if (sync) {
std::future<int> task = std::future<int> task =
_shards_task_pool[0]->enqueue([this, &keys, &values, num]() -> int { _shards_task_pool[0]->enqueue([this, &keys, &values, num]() -> int {
...@@ -506,20 +506,20 @@ int32_t CommonSparseTable::push_sparse(const uint64_t* keys, ...@@ -506,20 +506,20 @@ int32_t CommonSparseTable::push_sparse(const uint64_t* keys,
}); });
task.wait(); task.wait();
} else { } else {
_push_sparse(keys, values, num); _PushSparse(keys, values, num);
} }
return 0; return 0;
} }
int32_t CommonSparseTable::push_sparse(const uint64_t* keys, int32_t CommonSparseTable::PushSparse(const uint64_t* keys,
const float** values, size_t num) { const float** values, size_t num) {
_push_sparse(keys, values, num); _PushSparse(keys, values, num);
return 0; return 0;
} }
int32_t CommonSparseTable::_push_sparse(const uint64_t* keys, int32_t CommonSparseTable::_PushSparse(const uint64_t* keys,
const float** values, size_t num) { const float** values, size_t num) {
std::vector<std::vector<uint64_t>> offset_bucket; std::vector<std::vector<uint64_t>> offset_bucket;
offset_bucket.resize(task_pool_size_); offset_bucket.resize(task_pool_size_);
...@@ -536,7 +536,7 @@ int32_t CommonSparseTable::_push_sparse(const uint64_t* keys, ...@@ -536,7 +536,7 @@ int32_t CommonSparseTable::_push_sparse(const uint64_t* keys,
auto& offsets = offset_bucket[shard_id]; auto& offsets = offset_bucket[shard_id];
for (size_t i = 0; i < offsets.size(); ++i) { for (size_t i = 0; i < offsets.size(); ++i) {
std::vector<uint64_t> tmp_off = {0}; std::vector<uint64_t> tmp_off = {0};
optimizer_->update(keys + offsets[i], values[offsets[i]], num, optimizer_->Update(keys + offsets[i], values[offsets[i]], num,
tmp_off, shard_values_[shard_id].get()); tmp_off, shard_values_[shard_id].get());
} }
return 0; return 0;
...@@ -549,8 +549,8 @@ int32_t CommonSparseTable::_push_sparse(const uint64_t* keys, ...@@ -549,8 +549,8 @@ int32_t CommonSparseTable::_push_sparse(const uint64_t* keys,
return 0; return 0;
} }
int32_t CommonSparseTable::push_sparse_param(const uint64_t* keys, int32_t CommonSparseTable::PushSparseParam(const uint64_t* keys,
const float* values, size_t num) { const float* values, size_t num) {
std::vector<std::vector<uint64_t>> offset_bucket; std::vector<std::vector<uint64_t>> offset_bucket;
offset_bucket.resize(task_pool_size_); offset_bucket.resize(task_pool_size_);
...@@ -585,21 +585,21 @@ int32_t CommonSparseTable::push_sparse_param(const uint64_t* keys, ...@@ -585,21 +585,21 @@ int32_t CommonSparseTable::push_sparse_param(const uint64_t* keys,
return 0; return 0;
} }
int32_t CommonSparseTable::flush() { return 0; } int32_t CommonSparseTable::Flush() { return 0; }
int32_t CommonSparseTable::shrink(const std::string& param) { int32_t CommonSparseTable::Shrink(const std::string& param) {
int threshold = std::stoi(param); int threshold = std::stoi(param);
VLOG(3) << "sparse table shrink: " << threshold; VLOG(3) << "sparse table Shrink: " << threshold;
for (int shard_id = 0; shard_id < task_pool_size_; ++shard_id) { for (int shard_id = 0; shard_id < task_pool_size_; ++shard_id) {
// shrink // Shrink
VLOG(4) << shard_id << " " << task_pool_size_ << " begin shrink"; VLOG(4) << shard_id << " " << task_pool_size_ << " begin Shrink";
shard_values_[shard_id]->Shrink(threshold); shard_values_[shard_id]->Shrink(threshold);
} }
return 0; return 0;
} }
void CommonSparseTable::clear() { VLOG(0) << "clear coming soon"; } void CommonSparseTable::Clear() { VLOG(0) << "clear coming soon"; }
} // namespace distributed } // namespace distributed
} // namespace paddle } // namespace paddle
...@@ -114,25 +114,23 @@ class CommonSparseTable : public SparseTable { ...@@ -114,25 +114,23 @@ class CommonSparseTable : public SparseTable {
virtual ~CommonSparseTable() {} virtual ~CommonSparseTable() {}
// unused method begin // unused method begin
virtual int32_t pull_dense(float* pull_values, size_t num) { return 0; } virtual int32_t PullDense(float* pull_values, size_t num) { return 0; }
virtual int32_t push_dense_param(const float* values, size_t num) { virtual int32_t PushDenseParam(const float* values, size_t num) { return 0; }
return 0; virtual int32_t PushDense(const float* values, size_t num) { return 0; }
}
virtual int32_t push_dense(const float* values, size_t num) { return 0; }
// unused method end // unused method end
virtual int32_t Pull(TableContext& context); virtual int32_t Pull(TableContext& context);
virtual int32_t Push(TableContext& context); virtual int32_t Push(TableContext& context);
virtual int32_t initialize(); virtual int32_t Initialize();
virtual int32_t initialize_shard() { return 0; } virtual int32_t InitializeShard() { return 0; }
virtual int32_t initialize_value(); virtual int32_t InitializeValue();
virtual int32_t initialize_optimizer(); virtual int32_t InitializeOptimizer();
virtual int32_t initialize_recorder(); virtual int32_t InitializeRecorder();
virtual int32_t load(const std::string& path, const std::string& param); virtual int32_t Load(const std::string& path, const std::string& param);
virtual int32_t save(const std::string& path, const std::string& param); virtual int32_t Save(const std::string& path, const std::string& param);
void SaveMetaToText(std::ostream* os, const CommonAccessorParameter& common, void SaveMetaToText(std::ostream* os, const CommonAccessorParameter& common,
const size_t shard_idx, const int64_t total); const size_t shard_idx, const int64_t total);
...@@ -150,34 +148,34 @@ class CommonSparseTable : public SparseTable { ...@@ -150,34 +148,34 @@ class CommonSparseTable : public SparseTable {
const int pserver_id, const int pserver_num, const int local_shard_num, const int pserver_id, const int pserver_num, const int local_shard_num,
std::vector<std::shared_ptr<ValueBlock>>* blocks); std::vector<std::shared_ptr<ValueBlock>>* blocks);
virtual std::pair<int64_t, int64_t> print_table_stat(); virtual std::pair<int64_t, int64_t> PrintTableStat();
virtual int32_t pull_sparse(float* values, const PullSparseValue& pull_value); virtual int32_t PullSparse(float* values, const PullSparseValue& pull_value);
virtual int32_t pull_sparse_ptr(char** pull_values, const uint64_t* keys, virtual int32_t PullSparsePtr(char** pull_values, const uint64_t* keys,
size_t num); size_t num);
virtual int32_t push_sparse(const uint64_t* keys, const float* values, virtual int32_t PushSparse(const uint64_t* keys, const float* values,
size_t num); size_t num);
virtual int32_t push_sparse(const uint64_t* keys, const float** values, virtual int32_t PushSparse(const uint64_t* keys, const float** values,
size_t num); size_t num);
// only for sparse geo table // only for sparse geo table
virtual int32_t push_sparse_param(const uint64_t* keys, const float* values, virtual int32_t PushSparseParam(const uint64_t* keys, const float* values,
size_t num); size_t num);
virtual int32_t set_global_lr(float* lr) override; virtual int32_t SetGlobalLR(float* lr) override;
virtual int32_t pour(); virtual int32_t Pour();
virtual int32_t flush(); virtual int32_t Flush();
virtual int32_t shrink(const std::string& param); virtual int32_t Shrink(const std::string& param);
virtual void clear(); virtual void Clear();
protected: protected:
virtual int32_t _push_sparse(const uint64_t* keys, const float* values, virtual int32_t _PushSparse(const uint64_t* keys, const float* values,
size_t num); size_t num);
virtual int32_t _push_sparse(const uint64_t* keys, const float** values, virtual int32_t _PushSparse(const uint64_t* keys, const float** values,
size_t num); size_t num);
protected: protected:
const int task_pool_size_ = 11; const int task_pool_size_ = 11;
......
...@@ -71,11 +71,11 @@ class SparseTable : public Table { ...@@ -71,11 +71,11 @@ class SparseTable : public Table {
SparseTable() {} SparseTable() {}
virtual ~SparseTable() {} virtual ~SparseTable() {}
virtual void *get_shard(size_t shard_idx) { return 0; } virtual void *GetShard(size_t shard_idx) { return 0; }
int32_t pull_dense(float *values, size_t num) override { return 0; } int32_t PullDense(float *values, size_t num) override { return 0; }
int32_t push_dense(const float *values, size_t num) override { return 0; } int32_t PushDense(const float *values, size_t num) override { return 0; }
static int32_t sparse_local_shard_num(uint32_t shard_num, static int32_t sparse_local_shard_num(uint32_t shard_num,
uint32_t server_num) { uint32_t server_num) {
...@@ -97,19 +97,17 @@ class DenseTable : public Table { ...@@ -97,19 +97,17 @@ class DenseTable : public Table {
DenseTable() {} DenseTable() {}
virtual ~DenseTable() {} virtual ~DenseTable() {}
virtual void *get_shard(size_t shard_idx) { return 0; } virtual void *GetShard(size_t shard_idx) { return 0; }
int32_t pull_sparse(float *values, int32_t PullSparse(float *values,
const PullSparseValue &pull_value) override { const PullSparseValue &pull_value) override {
return 0; return 0;
} }
int32_t push_sparse(const uint64_t *keys, const float *values, int32_t PushSparse(const uint64_t *keys, const float *values,
size_t num) override { size_t num) override {
return 0; return 0;
} }
int32_t push_dense_param(const float *values, size_t num) override { int32_t PushDenseParam(const float *values, size_t num) override { return 0; }
return 0; int32_t Shrink(const std::string &param) override { return 0; }
}
int32_t shrink(const std::string &param) override { return 0; }
}; };
class BarrierTable : public Table { class BarrierTable : public Table {
...@@ -117,44 +115,42 @@ class BarrierTable : public Table { ...@@ -117,44 +115,42 @@ class BarrierTable : public Table {
BarrierTable() {} BarrierTable() {}
virtual ~BarrierTable() {} virtual ~BarrierTable() {}
virtual void *get_shard(size_t shard_idx) { return 0; } virtual void *GetShard(size_t shard_idx) { return 0; }
virtual int32_t Pull(TableContext &context) { return 0; } virtual int32_t Pull(TableContext &context) { return 0; }
virtual int32_t Push(TableContext &context) { return 0; } virtual int32_t Push(TableContext &context) { return 0; }
int32_t pull_dense(float *values, size_t num) override { return 0; } int32_t PullDense(float *values, size_t num) override { return 0; }
int32_t push_dense(const float *values, size_t num) override { return 0; } int32_t PushDense(const float *values, size_t num) override { return 0; }
int32_t pull_sparse(float *values, int32_t PullSparse(float *values,
const PullSparseValue &pull_value) override { const PullSparseValue &pull_value) override {
return 0;
}
int32_t push_sparse(const uint64_t *keys, const float *values,
size_t num) override {
return 0; return 0;
} }
int32_t push_dense_param(const float *values, size_t num) override { int32_t PushSparse(const uint64_t *keys, const float *values,
size_t num) override {
return 0; return 0;
} }
int32_t shrink(const std::string &param) override { return 0; } int32_t PushDenseParam(const float *values, size_t num) override { return 0; }
virtual void clear() {} int32_t Shrink(const std::string &param) override { return 0; }
virtual int32_t flush() { return 0; } virtual void Clear() {}
virtual int32_t load(const std::string &path, const std::string &param) { virtual int32_t Flush() { return 0; }
virtual int32_t Load(const std::string &path, const std::string &param) {
return 0; return 0;
} }
virtual int32_t save(const std::string &path, const std::string &param) { virtual int32_t Save(const std::string &path, const std::string &param) {
return 0; return 0;
} }
virtual int32_t initialize_shard() { return 0; } virtual int32_t InitializeShard() { return 0; }
virtual int32_t initialize() override; virtual int32_t Initialize() override;
// only for barrier // only for barrier
// 0: send_barrier 1: recv_barrier 2: complete // 0: send_barrier 1: recv_barrier 2: complete
virtual int32_t barrier(const uint32_t trainer_id, virtual int32_t Barrier(const uint32_t trainer_id,
const std::string barrier_type) override; const std::string barrier_type) override;
virtual int32_t set_table_map( virtual int32_t SetTableMap(
std::unordered_map<uint32_t, std::shared_ptr<Table>> *table_map) override; std::unordered_map<uint32_t, std::shared_ptr<Table>> *table_map) override;
private: private:
......
...@@ -34,9 +34,9 @@ class DenseOptimizer { ...@@ -34,9 +34,9 @@ class DenseOptimizer {
DenseOptimizer() {} DenseOptimizer() {}
explicit DenseOptimizer(const CommonAccessorParameter& accessor, explicit DenseOptimizer(const CommonAccessorParameter& accessor,
std::vector<std::vector<float>>* values) {} std::vector<std::vector<float>>* values) {}
virtual void update(const float* update_values, size_t num, int begin, virtual void Update(const float* update_values, size_t num, int begin,
int end) = 0; int end) = 0;
virtual void set_global_lr(float* lr) { global_learning_rate_ = lr; } virtual void SetGlobalLR(float* lr) { global_learning_rate_ = lr; }
protected: protected:
float* global_learning_rate_; float* global_learning_rate_;
...@@ -55,7 +55,7 @@ class DSUM : public DenseOptimizer { ...@@ -55,7 +55,7 @@ class DSUM : public DenseOptimizer {
} }
} }
void update(const float* update_values, size_t num, int begin, void Update(const float* update_values, size_t num, int begin,
int end) override { int end) override {
auto update_numel = end - begin; auto update_numel = end - begin;
GetBlas<float>().VADD(update_numel, update_values + begin, param + begin, GetBlas<float>().VADD(update_numel, update_values + begin, param + begin,
...@@ -81,7 +81,7 @@ class DSGD : public DenseOptimizer { ...@@ -81,7 +81,7 @@ class DSGD : public DenseOptimizer {
} }
} }
void update(const float* update_values, size_t num, int begin, void Update(const float* update_values, size_t num, int begin,
int end) override { int end) override {
auto update_numel = end - begin; auto update_numel = end - begin;
std::vector<float> grads; std::vector<float> grads;
...@@ -134,7 +134,7 @@ class DAdam : public DenseOptimizer { ...@@ -134,7 +134,7 @@ class DAdam : public DenseOptimizer {
// make sure common_dense_table.task_pool_size_ == 1; // make sure common_dense_table.task_pool_size_ == 1;
// otherwise, task_pool_size_ times beta1_pow/beta2_pow multiplication // otherwise, task_pool_size_ times beta1_pow/beta2_pow multiplication
void update(const float* update_values, size_t num, int begin, void Update(const float* update_values, size_t num, int begin,
int end) override { int end) override {
auto update_numel = end - begin; auto update_numel = end - begin;
std::vector<float> grad, grad2, tmp; std::vector<float> grad, grad2, tmp;
...@@ -214,7 +214,7 @@ class DAdamD2Sum : public DenseOptimizer { ...@@ -214,7 +214,7 @@ class DAdamD2Sum : public DenseOptimizer {
} }
} }
void update(const float* update_values, size_t num, int begin, void Update(const float* update_values, size_t num, int begin,
int end) override { int end) override {
auto update_numel = end - begin; auto update_numel = end - begin;
Eigen::Map<Eigen::MatrixXf> mat_ada_g2sum(ada_g2sum + begin, 1, Eigen::Map<Eigen::MatrixXf> mat_ada_g2sum(ada_g2sum + begin, 1,
...@@ -276,7 +276,7 @@ class DSummary : public DenseOptimizer { ...@@ -276,7 +276,7 @@ class DSummary : public DenseOptimizer {
} }
} }
void update(const float* update_values, size_t num, int begin, void Update(const float* update_values, size_t num, int begin,
int end) override { int end) override {
auto update_numel = end - begin; auto update_numel = end - begin;
Eigen::Map<Eigen::MatrixXf> mat_w(param + begin, 1, update_numel); Eigen::Map<Eigen::MatrixXf> mat_w(param + begin, 1, update_numel);
......
...@@ -40,11 +40,11 @@ class SparseOptimizer { ...@@ -40,11 +40,11 @@ class SparseOptimizer {
value_offsets_(value_offsets), value_offsets_(value_offsets),
value_idx_(value_idx) {} value_idx_(value_idx) {}
virtual void update(const uint64_t* keys, const float* update_values, virtual void Update(const uint64_t* keys, const float* update_values,
size_t num, const std::vector<uint64_t>& offsets, size_t num, const std::vector<uint64_t>& offsets,
ValueBlock* block) = 0; ValueBlock* block) = 0;
virtual void set_global_lr(float* lr) { global_learning_rate_ = lr; } virtual void SetGlobalLR(float* lr) { global_learning_rate_ = lr; }
const std::vector<std::string>& value_names_; const std::vector<std::string>& value_names_;
const std::vector<int>& value_dims_; const std::vector<int>& value_dims_;
...@@ -70,7 +70,7 @@ class SSUM : public SparseOptimizer { ...@@ -70,7 +70,7 @@ class SSUM : public SparseOptimizer {
update_numel = value_dims.at(idx); update_numel = value_dims.at(idx);
} }
void update(const uint64_t* keys, const float* update_values, size_t num, void Update(const uint64_t* keys, const float* update_values, size_t num,
const std::vector<uint64_t>& offsets, const std::vector<uint64_t>& offsets,
ValueBlock* block) override { ValueBlock* block) override {
auto blas = GetBlas<float>(); auto blas = GetBlas<float>();
...@@ -100,7 +100,7 @@ class SSGD : public SparseOptimizer { ...@@ -100,7 +100,7 @@ class SSGD : public SparseOptimizer {
lr_offset = value_offsets.at(idx); lr_offset = value_offsets.at(idx);
} }
void update(const uint64_t* keys, const float* update_values, size_t num, void Update(const uint64_t* keys, const float* update_values, size_t num,
const std::vector<uint64_t>& offsets, const std::vector<uint64_t>& offsets,
ValueBlock* block) override { ValueBlock* block) override {
auto blas = GetBlas<float>(); auto blas = GetBlas<float>();
...@@ -156,7 +156,7 @@ class SAdam : public SparseOptimizer { ...@@ -156,7 +156,7 @@ class SAdam : public SparseOptimizer {
epsilon = 1.0e-8; epsilon = 1.0e-8;
} }
void update(const uint64_t* keys, const float* update_values, size_t num, void Update(const uint64_t* keys, const float* update_values, size_t num,
const std::vector<uint64_t>& offsets, const std::vector<uint64_t>& offsets,
ValueBlock* block) override { ValueBlock* block) override {
auto blas = GetBlas<float>(); auto blas = GetBlas<float>();
......
...@@ -17,11 +17,10 @@ ...@@ -17,11 +17,10 @@
namespace paddle { namespace paddle {
namespace distributed { namespace distributed {
int32_t MemorySparseGeoTable::push_sparse_param(const uint64_t* keys, int32_t MemorySparseGeoTable::PushSparseParam(const uint64_t* keys,
const float* values, const float* values, size_t num) {
size_t num) { VLOG(5) << "DEBUG MemorySparseGeoTable::PushSparseParam begin "
VLOG(5) << "DEBUG MemorySparseGeoTable::push_sparse_param begin " "PushSparseParam "
"push_sparse_param "
<< num; << num;
auto shard_num = _task_pool_size; auto shard_num = _task_pool_size;
std::vector<std::vector<uint64_t>> offset_bucket; std::vector<std::vector<uint64_t>> offset_bucket;
...@@ -31,8 +30,8 @@ int32_t MemorySparseGeoTable::push_sparse_param(const uint64_t* keys, ...@@ -31,8 +30,8 @@ int32_t MemorySparseGeoTable::push_sparse_param(const uint64_t* keys,
auto y = keys[x] % shard_num; auto y = keys[x] % shard_num;
offset_bucket[y].push_back(x); offset_bucket[y].push_back(x);
if (x < 10) { if (x < 10) {
VLOG(5) << "DEBUG MemorySparseGeoTable::push_sparse_param key: " VLOG(5) << "DEBUG MemorySparseGeoTable::PushSparseParam key: " << keys[x]
<< keys[x] << " shard: " << y; << " shard: " << y;
} }
} }
...@@ -51,8 +50,8 @@ int32_t MemorySparseGeoTable::push_sparse_param(const uint64_t* keys, ...@@ -51,8 +50,8 @@ int32_t MemorySparseGeoTable::push_sparse_param(const uint64_t* keys,
feature_value.resize(_dim); feature_value.resize(_dim);
std::copy_n(values + _dim * offset, _dim, feature_value.data()); std::copy_n(values + _dim * offset, _dim, feature_value.data());
if (i < 10) { if (i < 10) {
VLOG(5) << "MemorySparseGeoTable::push_sparse_param " VLOG(5) << "MemorySparseGeoTable::PushSparseParam "
"push_sparse_param key " "PushSparseParam key "
<< id << " value[0]: " << (values + _dim * offset)[0] << id << " value[0]: " << (values + _dim * offset)[0]
<< " data: " << feature_value.data()[0] << " data: " << feature_value.data()[0]
<< " value[-1]: " << (values + _dim * offset)[_dim - 1] << " value[-1]: " << (values + _dim * offset)[_dim - 1]
...@@ -69,9 +68,9 @@ int32_t MemorySparseGeoTable::push_sparse_param(const uint64_t* keys, ...@@ -69,9 +68,9 @@ int32_t MemorySparseGeoTable::push_sparse_param(const uint64_t* keys,
return 0; return 0;
} }
int32_t MemorySparseGeoTable::pull_geo_param(const uint32_t trainer_id, int32_t MemorySparseGeoTable::PullGeoParam(const uint32_t trainer_id,
std::vector<float>* values, std::vector<float>* values,
std::vector<uint64_t>* ids) { std::vector<uint64_t>* ids) {
_geo_recorder->GetAndClear(trainer_id, ids); _geo_recorder->GetAndClear(trainer_id, ids);
VLOG(5) VLOG(5)
<< "DEBUG MemorySparseGeoTable::pull_geo_param pull_geo_param trainer_id " << "DEBUG MemorySparseGeoTable::pull_geo_param pull_geo_param trainer_id "
...@@ -86,23 +85,23 @@ int32_t MemorySparseGeoTable::pull_geo_param(const uint32_t trainer_id, ...@@ -86,23 +85,23 @@ int32_t MemorySparseGeoTable::pull_geo_param(const uint32_t trainer_id,
pull_value.frequencies_ = frequencies.data(); pull_value.frequencies_ = frequencies.data();
values->resize(ids->size() * _dim); values->resize(ids->size() * _dim);
pull_sparse(values->data(), pull_value); PullSparse(values->data(), pull_value);
return 0; return 0;
} }
int32_t MemorySparseGeoTable::push_sparse(const uint64_t* keys, int32_t MemorySparseGeoTable::PushSparse(const uint64_t* keys,
const float* values, size_t num) { const float* values, size_t num) {
VLOG(5) << "DEBUG MemorySparseGeoTable::push_sparse keys[0]" << keys[0] VLOG(5) << "DEBUG MemorySparseGeoTable::PushSparse keys[0]" << keys[0]
<< " key_num: " << num; << " key_num: " << num;
std::vector<uint64_t> ids; std::vector<uint64_t> ids;
ids.resize(num); ids.resize(num);
std::copy_n(keys, num, ids.begin()); std::copy_n(keys, num, ids.begin());
_geo_recorder->Update(ids); _geo_recorder->Update(ids);
_push_sparse(keys, values, num); _PushSparse(keys, values, num);
return 0; return 0;
} }
int32_t MemorySparseGeoTable::initialize() { int32_t MemorySparseGeoTable::Initialize() {
if (!_geo_recorder) { if (!_geo_recorder) {
auto trainers = _config.common().trainer_num(); auto trainers = _config.common().trainer_num();
_geo_recorder = std::make_shared<GeoRecorder>(trainers); _geo_recorder = std::make_shared<GeoRecorder>(trainers);
...@@ -118,8 +117,8 @@ int32_t MemorySparseGeoTable::initialize() { ...@@ -118,8 +117,8 @@ int32_t MemorySparseGeoTable::initialize() {
return 0; return 0;
} }
int32_t MemorySparseGeoTable::pull_sparse(float* pull_values, int32_t MemorySparseGeoTable::PullSparse(float* pull_values,
const PullSparseValue& pull_value) { const PullSparseValue& pull_value) {
auto shard_num = _task_pool_size; auto shard_num = _task_pool_size;
std::vector<std::future<int>> tasks(shard_num); std::vector<std::future<int>> tasks(shard_num);
...@@ -146,13 +145,13 @@ int32_t MemorySparseGeoTable::pull_sparse(float* pull_values, ...@@ -146,13 +145,13 @@ int32_t MemorySparseGeoTable::pull_sparse(float* pull_values,
auto& feature_value = local_shard[key]; auto& feature_value = local_shard[key];
feature_value.resize(_dim); feature_value.resize(_dim);
memset(feature_value.data(), 0, sizeof(float) * _dim); memset(feature_value.data(), 0, sizeof(float) * _dim);
VLOG(0) << "MemorySparseGeoTable pull_sparse key not found!!! " VLOG(0) << "MemorySparseGeoTable PullSparse key not found!!! "
<< key; << key;
itr = local_shard.find(key); itr = local_shard.find(key);
} }
memcpy(select_data, itr.value().data(), _dim * sizeof(float)); memcpy(select_data, itr.value().data(), _dim * sizeof(float));
VLOG(5) << "DEBUG MemorySparseGeoTable::pull_sparse key: " << key VLOG(5) << "DEBUG MemorySparseGeoTable::PullSparse key: " << key
<< " select_data[0] " << select_data[0] << " select_data[0] " << select_data[0]
<< " value[0]: " << itr.value().data()[0]; << " value[0]: " << itr.value().data()[0];
} }
...@@ -167,8 +166,8 @@ int32_t MemorySparseGeoTable::pull_sparse(float* pull_values, ...@@ -167,8 +166,8 @@ int32_t MemorySparseGeoTable::pull_sparse(float* pull_values,
return 0; return 0;
} }
int32_t MemorySparseGeoTable::_push_sparse(const uint64_t* keys, int32_t MemorySparseGeoTable::_PushSparse(const uint64_t* keys,
const float* values, size_t num) { const float* values, size_t num) {
auto shard_num = _task_pool_size; auto shard_num = _task_pool_size;
std::vector<std::future<int>> tasks(shard_num); std::vector<std::future<int>> tasks(shard_num);
std::vector<std::vector<std::pair<uint64_t, int>>> task_keys(shard_num); std::vector<std::vector<std::pair<uint64_t, int>>> task_keys(shard_num);
......
...@@ -40,31 +40,31 @@ class MemorySparseGeoTable : public SparseTable { ...@@ -40,31 +40,31 @@ class MemorySparseGeoTable : public SparseTable {
MemorySparseGeoTable() { _geo_recorder = nullptr; } MemorySparseGeoTable() { _geo_recorder = nullptr; }
virtual ~MemorySparseGeoTable() {} virtual ~MemorySparseGeoTable() {}
virtual int32_t initialize(); virtual int32_t Initialize();
virtual int32_t initialize_shard() { return 0; } virtual int32_t InitializeShard() { return 0; }
virtual int32_t load(const std::string& path, const std::string& param) { virtual int32_t Load(const std::string& path, const std::string& param) {
return 0; return 0;
} }
virtual int32_t save(const std::string& path, const std::string& param) { virtual int32_t Save(const std::string& path, const std::string& param) {
return 0; return 0;
} }
virtual int32_t Pull(TableContext& context) { return 0; } virtual int32_t Pull(TableContext& context) { return 0; }
virtual int32_t Push(TableContext& context) { return 0; } virtual int32_t Push(TableContext& context) { return 0; }
virtual int32_t flush() { return 0; } virtual int32_t Flush() { return 0; }
virtual int32_t shrink(const std::string& param) { return 0; } virtual int32_t Shrink(const std::string& param) { return 0; }
virtual void clear() { return; } virtual void Clear() { return; }
virtual int32_t pull_sparse(float* values, const PullSparseValue& pull_value); virtual int32_t PullSparse(float* values, const PullSparseValue& pull_value);
int32_t push_sparse_param(const uint64_t* keys, const float* values, int32_t PushSparseParam(const uint64_t* keys, const float* values,
size_t num); size_t num);
// TODO(zhaocaibei123): change to pull_sparse, and rename pull_sparse // TODO(zhaocaibei123): change to pull_sparse, and rename pull_sparse
int32_t pull_geo_param(const uint32_t trainer_id, std::vector<float>* values, int32_t PullGeoParam(const uint32_t trainer_id, std::vector<float>* values,
std::vector<uint64_t>* keys); std::vector<uint64_t>* keys);
int32_t push_sparse(const uint64_t* keys, const float* values, int32_t PushSparse(const uint64_t* keys, const float* values,
size_t num) override; size_t num) override;
int32_t _push_sparse(const uint64_t* keys, const float* values, size_t num); int32_t _PushSparse(const uint64_t* keys, const float* values, size_t num);
// int32_t _pull_sparse(float* pull_values, const PullSparseValue& // int32_t _pull_sparse(float* pull_values, const PullSparseValue&
// pull_value); // pull_value);
......
此差异已折叠。
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册