Unverified commit b3270adf, authored by zhaocaibei123, committed by GitHub

Unified PS refine (#41234)

* update name

* update name

* fix test

* fix fleet bind

* update name

* update name

* fix test

* fix gpups wrapper

* remove Push/Pull/Load/Save with context in client and wrapper base class

* fix

* fix
Co-authored-by: esythan <esythan@126.com>
Parent cb124156
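
This commit applies one mechanical change across the PS module: snake_case method names become PascalCase (pull_dense → PullDense, save → Save, and so on), with signatures otherwise unchanged; per the message above, the context-based Push/Pull/Load/Save overloads are also removed from the client and wrapper base classes. A representative call site, as it appears in Communicator::RpcRecvDense below:

    // before
    auto status = _worker_ptr->pull_dense(regions.data(), regions.size(), table_id);
    // after
    auto status = _worker_ptr->PullDense(regions.data(), regions.size(), table_id);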
......@@ -50,7 +50,7 @@ class DownpourPsClientService : public PsService {
DownpourPsClientService() {}
virtual ~DownpourPsClientService() {}
virtual int32_t configure(PSClient *client, size_t rank_id) {
virtual int32_t Configure(PSClient *client, size_t rank_id) {
_client = client;
_rank = rank_id;
return 0;
......@@ -139,7 +139,7 @@ class BrpcPsClient : public PSClient {
BrpcPsClient() {}
virtual ~BrpcPsClient() {
if (_running) {
flush();
Flush();
_running = false;
}
if (_async_push_dense_thread.joinable()) {
......@@ -154,109 +154,98 @@ class BrpcPsClient : public PSClient {
_server_started = false;
}
}
virtual int32_t create_client2client_connection(
int pserver_timeout_ms, int pserver_connect_timeout_ms, int max_retry);
std::future<int32_t> shrink(uint32_t table_id,
virtual int32_t CreateClient2ClientConnection(int pserver_timeout_ms,
int pserver_connect_timeout_ms,
int max_retry);
std::future<int32_t> Shrink(uint32_t table_id,
const std::string threshold) override;
std::future<int32_t> load(const std::string &epoch,
std::future<int32_t> Load(const std::string &epoch,
const std::string &mode) override;
std::future<int32_t> load(uint32_t table_id, const std::string &epoch,
std::future<int32_t> Load(uint32_t table_id, const std::string &epoch,
const std::string &mode) override;
std::future<int32_t> Load(const LoadSaveContext &load_context) override;
std::future<int32_t> save(const std::string &epoch,
std::future<int32_t> Save(const std::string &epoch,
const std::string &mode) override;
std::future<int32_t> save(uint32_t table_id, const std::string &epoch,
std::future<int32_t> Save(uint32_t table_id, const std::string &epoch,
const std::string &mode) override;
virtual std::future<int32_t> Save(
const LoadSaveContext &save_context) override;
std::future<int32_t> clear() override;
std::future<int32_t> clear(uint32_t table_id) override;
std::future<int32_t> Clear() override;
std::future<int32_t> stop_server() override;
std::future<int32_t> Clear(uint32_t table_id) override;
std::future<int32_t> start_profiler() override;
std::future<int32_t> stop_profiler() override;
std::future<int32_t> StopServer() override;
void finalize_worker() override;
std::future<int32_t> StartProfiler() override;
std::future<int32_t> StopProfiler() override;
virtual std::future<int32_t> pull_dense(Region *regions, size_t region_num,
size_t table_id);
void FinalizeWorker() override;
virtual std::future<int32_t> push_dense_param(const Region *regions,
size_t region_num,
size_t table_id);
virtual std::future<int32_t> PullDense(Region *regions, size_t region_num,
size_t table_id);
virtual std::future<int32_t> push_dense(const Region *regions,
size_t region_num, size_t table_id);
void push_dense_task_consume();
virtual std::future<int32_t> pull_sparse(float **select_values,
size_t table_id,
const uint64_t *keys, size_t num,
bool is_training);
virtual std::future<int32_t> pull_sparse_param(float **select_values,
size_t table_id,
const uint64_t *keys,
size_t num, bool is_training);
virtual std::future<int32_t> PushDenseParam(const Region *regions,
size_t region_num,
size_t table_id);
virtual std::future<int32_t> Pull(RequestContext &pull_context) override;
virtual std::future<int32_t> PushDense(const Region *regions,
size_t region_num, size_t table_id);
void PushDenseTaskConsume();
virtual std::future<int32_t> PullSparse(float **select_values,
size_t table_id, const uint64_t *keys,
size_t num, bool is_training);
virtual std::future<int32_t> PullSparseParam(float **select_values,
size_t table_id,
const uint64_t *keys, size_t num,
bool is_training);
virtual std::future<int32_t> Push(RequestContext &push_context) override;
virtual std::future<int32_t> PrintTableStat(uint32_t table_id);
virtual std::future<int32_t> print_table_stat(uint32_t table_id);
virtual std::future<int32_t> Barrier(size_t table_id, uint32_t barrier_type);
virtual std::future<int32_t> barrier(size_t table_id, uint32_t barrier_type);
virtual std::future<int32_t> PullGeoParam(size_t table_id,
std::vector<float> *values,
std::vector<uint64_t> *keys,
int pserver_idx);
virtual std::future<int32_t> PushGlobalStep(int table_id,
int64_t *total_send_data,
void *done);
virtual std::future<int32_t> Flush();
virtual std::future<int32_t> pull_geo_param(size_t table_id,
std::vector<float> *values,
std::vector<uint64_t> *keys,
int pserver_idx);
virtual std::future<int32_t> push_global_step(int table_id,
int64_t *total_send_data,
void *done);
virtual std::future<int32_t> flush();
std::future<int32_t> send_client2client_msg(int msg_type, int to_client_id,
const std::string &msg) override;
std::future<int32_t> SendClient2ClientMsg(int msg_type, int to_client_id,
const std::string &msg) override;
// for local save sparse
virtual int32_t recv_and_save_table(const uint64_t table_id,
const std::string &path);
virtual int32_t RecvAndSaveTable(const uint64_t table_id,
const std::string &path);
void print_queue_size();
void print_queue_size_thread();
void PrintQueueSize();
void PrintQueueSizeThread();
protected:
virtual size_t get_server_nums() { return _server_channels.size(); }
inline brpc::Channel *get_sparse_channel(size_t server_id) {
virtual size_t GetServerNums() { return _server_channels.size(); }
inline brpc::Channel *GetSparseChannel(size_t server_id) {
return _server_channels[server_id][0].get();
}
inline brpc::Channel *get_dense_channel(size_t server_id) {
inline brpc::Channel *GetDenseChannel(size_t server_id) {
return _server_channels[server_id][1].get();
}
inline brpc::Channel *get_cmd_channel(size_t server_id) {
inline brpc::Channel *GetCmdChannel(size_t server_id) {
return _server_channels[server_id][2].get();
}
int32_t initialize() override;
int32_t Initialize() override;
private:
// virtual int32_t initialize() override;
inline uint32_t dense_dim_per_shard(uint32_t dense_dim_total,
uint32_t shard_num) {
inline uint32_t DenseDimPerShard(uint32_t dense_dim_total,
uint32_t shard_num) {
return dense_dim_total / shard_num + 1;
}
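
Note that DenseDimPerShard rounds up by one, so shard_num * DenseDimPerShard(total, shard_num) always covers total; Communicator::RpcSendDense later in this diff relies on this when sizing dense_data. A worked example:

    // DenseDimPerShard(100, 8) == 100 / 8 + 1 == 13;  8 * 13 == 104 >= 100
    // Exact multiples still pad by one slot per shard:
    // DenseDimPerShard(96, 8) == 96 / 8 + 1 == 13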
std::future<int32_t> send_cmd(uint32_t table_id, int cmd_id,
const std::vector<std::string> &param);
std::future<int32_t> SendCmd(uint32_t table_id, int cmd_id,
const std::vector<std::string> &param);
std::future<int32_t> send_save_cmd(uint32_t table_id, int cmd_id,
const std::vector<std::string> &param);
std::future<int32_t> SendSaveCmd(uint32_t table_id, int cmd_id,
const std::vector<std::string> &param);
bool _running = false;
bool _flushing = false;
......@@ -276,12 +265,12 @@ class BrpcPsClient : public PSClient {
std::thread _print_thread;
int push_sparse_async_shard_merge(
int PushSparseAsyncShardMerge(
std::vector<std::shared_ptr<SparseAsyncTask>> &task_list, // NOLINT
std::vector<int> &request_kv_num, int table_id, int shard_idx, // NOLINT
ValueAccessor *accessor);
int push_sparse_async_shard_push(
int PushSparseAsyncShardPush(
std::vector<std::shared_ptr<SparseAsyncTask>> &task_list, // NOLINT
std::vector<int> &request_kv_num, int table_id, int shard_idx, // NOLINT
DownpourBrpcClosure *closure, ValueAccessor *accessor);
......@@ -292,36 +281,36 @@ class BrpcPsClient : public PSClient {
_client_channels; // client2client
std::vector<std::array<std::shared_ptr<brpc::Channel>, 3>>
_server_channels; // client2server
std::future<int32_t> push_dense_raw_gradient(int table_id,
float *total_send_data,
size_t total_send_data_size,
void *done) override;
std::future<int32_t> push_sparse_raw_gradient(size_t table_id,
const uint64_t *keys,
const float **update_values,
size_t num,
void *done) override;
std::future<int32_t> push_sparse_raw_gradient_partial(
size_t table_id, const uint64_t *keys, const float **update_values,
uint32_t num, void *done, int pserver_idx) override;
std::future<int32_t> push_sparse_param(size_t table_id, const uint64_t *keys,
const float **update_values,
size_t num, void *done) override;
std::future<int32_t> push_sparse(size_t table_id, const uint64_t *keys,
const float **update_values,
size_t num) override;
void push_sparse_task_consume();
std::future<int32_t> PushDenseRawGradient(int table_id,
float *total_send_data,
size_t total_send_data_size,
void *done) override;
std::future<int32_t> PushSparseRawGradient(size_t table_id,
const uint64_t *keys,
const float **update_values,
size_t num, void *done) override;
std::future<int32_t> PushSparseRawGradientPartial(size_t table_id,
const uint64_t *keys,
const float **update_values,
uint32_t num, void *done,
int pserver_idx) override;
std::future<int32_t> PushSparseParam(size_t table_id, const uint64_t *keys,
const float **update_values, size_t num,
void *done) override;
std::future<int32_t> PushSparse(size_t table_id, const uint64_t *keys,
const float **update_values,
size_t num) override;
void PushSparseTaskConsume();
private:
int32_t start_client_service();
int32_t StartClientService();
void push_dense_raw_gradient(std::shared_ptr<DenseAsyncTask> &task, // NOLINT
float *total_send_data,
size_t total_send_data_size,
DownpourBrpcClosure *closure);
void PushDenseRawGradient(std::shared_ptr<DenseAsyncTask> &task, // NOLINT
float *total_send_data, size_t total_send_data_size,
DownpourBrpcClosure *closure);
float _mae = 0;
float _mse = 0;
uint16_t _push_times = 0;
......
......@@ -41,8 +41,8 @@ class BrpcPsServer : public PSServer {
public:
BrpcPsServer() {}
virtual ~BrpcPsServer() {}
virtual uint64_t start(const std::string &ip, uint32_t port);
virtual int32_t stop() {
virtual uint64_t Start(const std::string &ip, uint32_t port);
virtual int32_t Stop() {
std::unique_lock<std::mutex> lock(mutex_);
stoped_ = true;
cv_.notify_all();
......@@ -51,10 +51,10 @@ class BrpcPsServer : public PSServer {
_server.Join();
return 0;
}
int32_t port();
int32_t Port();
private:
virtual int32_t initialize();
virtual int32_t Initialize();
mutable std::mutex mutex_;
std::condition_variable cv_;
bool stoped_ = false;
......@@ -71,7 +71,7 @@ typedef int32_t (BrpcPsService::*serviceHandlerFunc)(
class BrpcPsService : public PsBaseService {
public:
virtual int32_t initialize() override;
virtual int32_t Initialize() override;
virtual void service(::google::protobuf::RpcController *controller,
const PsRequestMessage *request,
......@@ -79,50 +79,49 @@ class BrpcPsService : public PsBaseService {
::google::protobuf::Closure *done) override;
private:
int32_t initialize_shard_info();
int32_t pull_dense(Table *table, const PsRequestMessage &request,
PsResponseMessage &response, brpc::Controller *cntl);
int32_t push_dense(Table *table, const PsRequestMessage &request,
PsResponseMessage &response, brpc::Controller *cntl);
int32_t push_dense_param(Table *table, const PsRequestMessage &request,
PsResponseMessage &response, brpc::Controller *cntl);
int32_t push_sparse_param(Table *table, const PsRequestMessage &request,
PsResponseMessage &response,
brpc::Controller *cntl);
int32_t pull_sparse(Table *table, const PsRequestMessage &request,
PsResponseMessage &response, brpc::Controller *cntl);
int32_t pull_geo_param(Table *table, const PsRequestMessage &request,
int32_t InitializeShardInfo();
int32_t PullDense(Table *table, const PsRequestMessage &request,
PsResponseMessage &response, brpc::Controller *cntl);
int32_t PushDense(Table *table, const PsRequestMessage &request,
PsResponseMessage &response, brpc::Controller *cntl);
int32_t PushDenseParam(Table *table, const PsRequestMessage &request,
PsResponseMessage &response, brpc::Controller *cntl);
int32_t barrier(Table *table, const PsRequestMessage &request,
int32_t PushSparseParam(Table *table, const PsRequestMessage &request,
PsResponseMessage &response, brpc::Controller *cntl);
int32_t PullSparse(Table *table, const PsRequestMessage &request,
PsResponseMessage &response, brpc::Controller *cntl);
int32_t PullGeoParam(Table *table, const PsRequestMessage &request,
PsResponseMessage &response, brpc::Controller *cntl);
int32_t Barrier(Table *table, const PsRequestMessage &request,
PsResponseMessage &response, brpc::Controller *cntl);
int32_t push_sparse(Table *table, const PsRequestMessage &request,
PsResponseMessage &response, brpc::Controller *cntl);
int32_t load_one_table(Table *table, const PsRequestMessage &request,
PsResponseMessage &response, brpc::Controller *cntl);
int32_t load_all_table(Table *table, const PsRequestMessage &request,
PsResponseMessage &response, brpc::Controller *cntl);
int32_t save_one_table(Table *table, const PsRequestMessage &request,
PsResponseMessage &response, brpc::Controller *cntl);
int32_t save_all_table(Table *table, const PsRequestMessage &request,
PsResponseMessage &response, brpc::Controller *cntl);
int32_t shrink_table(Table *table, const PsRequestMessage &request,
int32_t PushSparse(Table *table, const PsRequestMessage &request,
PsResponseMessage &response, brpc::Controller *cntl);
int32_t LoadOneTable(Table *table, const PsRequestMessage &request,
PsResponseMessage &response, brpc::Controller *cntl);
int32_t clear_one_table(Table *table, const PsRequestMessage &request,
PsResponseMessage &response, brpc::Controller *cntl);
int32_t clear_all_table(Table *table, const PsRequestMessage &request,
PsResponseMessage &response, brpc::Controller *cntl);
int32_t stop_server(Table *table, const PsRequestMessage &request,
int32_t LoadAllTable(Table *table, const PsRequestMessage &request,
PsResponseMessage &response, brpc::Controller *cntl);
int32_t SaveOneTable(Table *table, const PsRequestMessage &request,
PsResponseMessage &response, brpc::Controller *cntl);
int32_t SaveAllTable(Table *table, const PsRequestMessage &request,
PsResponseMessage &response, brpc::Controller *cntl);
int32_t ShrinkTable(Table *table, const PsRequestMessage &request,
PsResponseMessage &response, brpc::Controller *cntl);
int32_t start_profiler(Table *table, const PsRequestMessage &request,
PsResponseMessage &response, brpc::Controller *cntl);
int32_t stop_profiler(Table *table, const PsRequestMessage &request,
int32_t ClearOneTable(Table *table, const PsRequestMessage &request,
PsResponseMessage &response, brpc::Controller *cntl);
int32_t ClearAllTable(Table *table, const PsRequestMessage &request,
PsResponseMessage &response, brpc::Controller *cntl);
int32_t StopServer(Table *table, const PsRequestMessage &request,
PsResponseMessage &response, brpc::Controller *cntl);
int32_t StartProfiler(Table *table, const PsRequestMessage &request,
PsResponseMessage &response, brpc::Controller *cntl);
int32_t StopProfiler(Table *table, const PsRequestMessage &request,
PsResponseMessage &response, brpc::Controller *cntl);
int32_t print_table_stat(Table *table, const PsRequestMessage &request,
PsResponseMessage &response, brpc::Controller *cntl);
int32_t PrintTableStat(Table *table, const PsRequestMessage &request,
PsResponseMessage &response, brpc::Controller *cntl);
int32_t push_global_step(Table *table, const PsRequestMessage &request,
PsResponseMessage &response, brpc::Controller *cntl);
int32_t PushGlobalStep(Table *table, const PsRequestMessage &request,
PsResponseMessage &response, brpc::Controller *cntl);
bool _is_initialize_shard_info;
std::mutex _initialize_shard_mutex;
......
......@@ -39,7 +39,7 @@ inline double GetCurrentUS() {
Communicator::Communicator() {}
void Communicator::init_gflag(const std::string &gflags) {
void Communicator::InitGFlag(const std::string &gflags) {
VLOG(3) << "Init With Gflags:" << gflags;
std::vector<std::string> flags = paddle::string::split_string(gflags);
if (flags.size() < 1) {
......@@ -73,7 +73,7 @@ void Communicator::InitBrpcClient(
}
std::vector<uint64_t> Communicator::GetClientInfo() {
std::vector<uint64_t> res = _ps_env.get_client_info();
std::vector<uint64_t> res = _ps_env.GetClientInfo();
for (auto rr : res) {
VLOG(2) << "Communicator::GetClientInfo " << rr;
}
......@@ -82,7 +82,7 @@ std::vector<uint64_t> Communicator::GetClientInfo() {
int Communicator::SetClients(std::vector<uint64_t> &host_sign_list) {
int node = host_sign_list.size();
return _ps_env.set_ps_clients(host_sign_list.data(), node);
return _ps_env.SetPsClients(host_sign_list.data(), node);
}
void Communicator::RpcRecvDense(const std::vector<std::string> &varnames,
......@@ -114,7 +114,7 @@ void Communicator::RpcRecvDense(const std::vector<std::string> &varnames,
}
}
auto status =
_worker_ptr->pull_dense(regions.data(), regions.size(), table_id);
_worker_ptr->PullDense(regions.data(), regions.size(), table_id);
status.wait();
for (auto &t : varnames) {
......@@ -177,7 +177,7 @@ void Communicator::RpcSendDenseParam(const std::vector<std::string> &varnames,
}
}
auto status =
_worker_ptr->push_dense_param(regions.data(), regions.size(), table_id);
_worker_ptr->PushDenseParam(regions.data(), regions.size(), table_id);
status.wait();
VLOG(4) << "RPC Send Dense Param " << table_id << " done!";
return;
......@@ -190,9 +190,9 @@ void Communicator::RpcSendDense(const CommContext &ctx, const Scope &scope) {
auto &var_names = ctx.origin_varnames;
auto &table_id = ctx.table_id;
auto dense_data = std::make_shared<std::vector<float>>();
size_t request_call_num = _worker_ptr->get_server_nums();
size_t request_call_num = _worker_ptr->GetServerNums();
uint32_t num_per_shard =
dense_dim_per_shard(ctx.height_sections[0], request_call_num);
DenseDimPerShard(ctx.height_sections[0], request_call_num);
dense_data->resize(num_per_shard *
request_call_num); // accessor->update_dim() = 1
float *data = dense_data->data();
......@@ -222,8 +222,8 @@ void Communicator::RpcSendDense(const CommContext &ctx, const Scope &scope) {
closure->set_promise_value(ret);
--_async_call_num;
});
auto status = _worker_ptr->push_dense_raw_gradient(
table_id, data, dense_data->size(), closure);
auto status = _worker_ptr->PushDenseRawGradient(table_id, data,
dense_data->size(), closure);
status.wait();
return;
}
......@@ -233,7 +233,7 @@ void Communicator::RpcSendSparseParam(const std::string &varname, int table_id,
platform::RecordEvent record_event("Communicator->RpcSendSparseParam",
platform::TracerEventType::Communication,
1);
size_t request_call_num = _worker_ptr->get_server_nums();
size_t request_call_num = _worker_ptr->GetServerNums();
std::vector<float *> push_g_vec;
auto *send_var = scope.FindVar(varname);
......@@ -260,9 +260,9 @@ void Communicator::RpcSendSparseParam(const std::string &varname, int table_id,
}
closure->set_promise_value(ret);
});
auto status = _worker_ptr->push_sparse_param(
table_id, sparse_push_keys.data(), (const float **)push_g_vec.data(),
sparse_push_keys.size(), closure);
auto status = _worker_ptr->PushSparseParam(table_id, sparse_push_keys.data(),
(const float **)push_g_vec.data(),
sparse_push_keys.size(), closure);
status.wait();
return;
}
......@@ -272,7 +272,7 @@ void Communicator::RpcSendSparse(const std::string &var_name, int table_id,
platform::RecordEvent record_event("Communicator->RpcSendSparse",
platform::TracerEventType::Communication,
1);
size_t request_call_num = _worker_ptr->get_server_nums();
size_t request_call_num = _worker_ptr->GetServerNums();
std::vector<uint64_t> sparse_push_keys;
std::vector<float *> push_g_vec;
......@@ -313,7 +313,7 @@ void Communicator::RpcSendSparse(const std::string &var_name, int table_id,
closure->set_promise_value(ret);
--_async_call_num;
});
auto status = _worker_ptr->push_sparse_raw_gradient(
auto status = _worker_ptr->PushSparseRawGradient(
table_id, sparse_push_keys.data(), (const float **)push_g_vec.data(),
sparse_push_keys.size(), closure);
status.wait();
......@@ -340,7 +340,7 @@ void Communicator::RpcRecvSparse(const std::string &varname, int table_id,
bool training = true;
auto status = _worker_ptr->pull_sparse_param(
auto status = _worker_ptr->PullSparseParam(
(float **)push_g_vec.data(), table_id, // NOLINT
sparse_push_keys.data(), sparse_push_keys.size(), training);
status.wait();
......@@ -376,11 +376,11 @@ void Communicator::RpcProfilerControl() {
if (!do_server_profiler_ && platform::IsProfileEnabled()) {
// send profiler start flag
do_server_profiler_ = true;
auto start_status = _worker_ptr->start_profiler();
auto start_status = _worker_ptr->StartProfiler();
start_status.wait();
} else if (do_server_profiler_ && !platform::IsProfileEnabled()) {
// send profiler end flag
auto stop_status = _worker_ptr->stop_profiler();
auto stop_status = _worker_ptr->StopProfiler();
stop_status.wait();
do_server_profiler_ = false;
}
......@@ -396,7 +396,7 @@ void Communicator::SendGlobalStep(const CommContext &ctx, int batches,
platform::TracerEventType::Communication,
1);
auto &table_id = ctx.table_id;
size_t request_call_num = _worker_ptr->get_server_nums();
size_t request_call_num = _worker_ptr->GetServerNums();
auto &var_name = STEP_COUNTER;
auto *out_var = send_scope->Var(var_name);
......@@ -416,7 +416,7 @@ void Communicator::SendGlobalStep(const CommContext &ctx, int batches,
}
closure->set_promise_value(ret);
});
auto status = _worker_ptr->push_global_step(table_id, data, closure);
auto status = _worker_ptr->PushGlobalStep(table_id, data, closure);
status.wait();
return;
}
......@@ -605,8 +605,8 @@ void AsyncCommunicator::PullSparseToTensorSync(
}
}
auto status =
_worker_ptr->pull_sparse(pull_result_ptr.data(), table_id,
fea_keys.data(), fea_keys.size(), is_training);
_worker_ptr->PullSparse(pull_result_ptr.data(), table_id, fea_keys.data(),
fea_keys.size(), is_training);
status.wait();
auto ret = status.get();
if (ret != 0) {
......@@ -738,9 +738,9 @@ void AsyncCommunicator::PushSparseFromTensorAsync(
this->Check(table_id), true,
platform::errors::InvalidArgument(
"can not find table: %s, please check your config", table_id));
auto status = _worker_ptr->push_sparse(table_id, push_keys.data(),
(const float **)push_g_vec.data(),
push_keys.size());
auto status = _worker_ptr->PushSparse(table_id, push_keys.data(),
(const float **)push_g_vec.data(),
push_keys.size());
}
void HalfAsyncCommunicator::MainThread() {
......@@ -813,7 +813,7 @@ void AsyncCommunicator::Stop() {
if (!communicator_) {
VLOG(0) << "Communicator is not inited, do nothing";
} else {
// _worker_ptr->finalize_worker();
// _worker_ptr->FinalizeWorker();
VLOG(1) << "client finalize_worker done";
if (recv_thread_) {
VLOG(1) << "stop recv thread";
......@@ -1327,7 +1327,7 @@ void GeoCommunicator::SendSparse(const std::string &varname,
closure->set_promise_value(ret);
--_async_call_num;
});
auto status = _worker_ptr->push_sparse_raw_gradient_partial(
auto status = _worker_ptr->PushSparseRawGradientPartial(
table_id, (const uint64_t *)sparse_ids.data(),
(const float **)push_g_vec.data(), sparse_ids.size(), closure, ep_idx);
status.wait();
......@@ -1345,7 +1345,7 @@ void GeoCommunicator::RecvSparse(const std::string &varname, int table_id,
// 1. recv from pserver
std::vector<uint64_t> keys;
std::vector<float> values;
auto status = _worker_ptr->pull_geo_param(table_id, &values, &keys, ep_idx);
auto status = _worker_ptr->PullGeoParam(table_id, &values, &keys, ep_idx);
status.wait();
std::string param = SplitedGradToParam(varname);
......
......@@ -299,7 +299,7 @@ class Communicator {
virtual void Barrier() {}
virtual void BarrierWithTable(uint32_t barrier_type) {
auto rets = _worker_ptr->barrier(barrier_table_id_, barrier_type);
auto rets = _worker_ptr->Barrier(barrier_table_id_, barrier_type);
rets.wait();
int status = rets.get();
PADDLE_ENFORCE_EQ(status, 0,
......@@ -310,7 +310,7 @@ class Communicator {
virtual void CreateC2CConnection(int pserver_timeout_ms,
int pserver_connect_timeout_ms,
int max_retry) {
_worker_ptr->create_client2client_connection(
_worker_ptr->CreateClient2ClientConnection(
pserver_timeout_ms, pserver_connect_timeout_ms, max_retry);
}
......@@ -379,12 +379,12 @@ class Communicator {
std::unordered_map<std::string, std::string> envs;
// compute how much dense storage each shard holds
inline uint32_t dense_dim_per_shard(uint32_t dense_dim_total,
uint32_t shard_num) {
inline uint32_t DenseDimPerShard(uint32_t dense_dim_total,
uint32_t shard_num) {
return dense_dim_total / shard_num + 1;
}
void init_gflag(const std::string &gflags);
void InitGFlag(const std::string &gflags);
paddle::distributed::PSParameter _ps_param;
paddle::distributed::PaddlePSEnvironment _ps_env;
int servers_ = 0;
......
......@@ -40,7 +40,7 @@ struct PSHost {
// |---ip---|---port---|--rank--|
// |-32bit--|--20bit---|--12bit-|
uint64_t serialize_to_uint64() {
uint64_t SerializeToUint64() {
uint64_t host_label = 0;
host_label = inet_addr(ip.c_str());
host_label = host_label << 32;
......@@ -49,7 +49,7 @@ struct PSHost {
return host_label;
}
void parse_from_uint64(uint64_t host_label) {
void ParseFromUint64(uint64_t host_label) {
static uint64_t rank_label_mask = (1L << 12) - 1;
static uint64_t port_label_mask = (1L << 20) - 1;
rank = host_label & rank_label_mask;
......@@ -58,17 +58,17 @@ struct PSHost {
ip = inet_ntoa(*(in_addr *)&ip_addr); // NOLINT
}
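
The packing above follows the |ip:32|port:20|rank:12| layout in the comment; since the middle of SerializeToUint64 is collapsed in this view, here is a minimal self-contained sketch of the same layout (sample values are arbitrary):

    #include <cassert>
    #include <cstdint>
    int main() {
      uint64_t ip = 0x7f000001, port = 8000, rank = 3;       // e.g. 127.0.0.1:8000, rank 3
      uint64_t label = (ip << 32) | (port << 12) | rank;     // |ip 32|port 20|rank 12|
      assert((label & ((1ULL << 12) - 1)) == rank);          // rank: low 12 bits
      assert(((label >> 12) & ((1ULL << 20) - 1)) == port);  // port: next 20 bits
      assert((label >> 32) == ip);                           // ip: high 32 bits
      return 0;
    }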
std::string to_string() {
std::string ToString() {
std::stringstream s;
s << "host: " << ip;
s << " port: " << port;
s << " rank: " << rank;
s << " uint: " << serialize_to_uint64();
s << " uint: " << SerializeToUint64();
return s.str();
}
// for open source parameter server
std::string serialize_to_string() {
std::string SerializeToString() {
std::stringstream s;
s << ip << ":";
s << port << ":";
......@@ -76,16 +76,16 @@ struct PSHost {
return s.str();
}
void parse_from_string(std::string endpoint) {
void ParseFromString(std::string endpoint) {
std::vector<std::string> endpoint_info;
string_split(endpoint, ':', &endpoint_info);
StringSplit(endpoint, ':', &endpoint_info);
ip = endpoint_info[0];
port = std::stoi(endpoint_info[1]);
rank = std::stoi(endpoint_info[2]);
}
void string_split(const std::string &str, char sep,
std::vector<std::string> *pieces, bool ignore_null = true) {
void StringSplit(const std::string &str, char sep,
std::vector<std::string> *pieces, bool ignore_null = true) {
pieces->clear();
if (str.empty()) {
if (!ignore_null) {
......@@ -111,63 +111,60 @@ class PSEnvironment {
explicit PSEnvironment() {} // NOLINT
virtual ~PSEnvironment() {}
virtual int32_t set_ps_servers(uint64_t *host_sign_list, int node_num) {
virtual int32_t SetPsServers(uint64_t *host_sign_list, int node_num) {
return 0;
}
virtual int32_t set_ps_servers(
virtual int32_t SetPsServers(
const std::vector<std::string> *host_endpoint_list, int node_num) {
return 0;
}
virtual int32_t set_ps_clients(uint64_t *host_sign_list, int node_num) {
virtual int32_t SetPsClients(uint64_t *host_sign_list, int node_num) {
return 0;
}
virtual int32_t set_ps_clients(std::string *host_endpoint_list,
int node_num) {
virtual int32_t SetPsClients(std::string *host_endpoint_list, int node_num) {
return 0;
}
virtual uint64_t get_local_host_sign() { return 0; }
virtual std::vector<PSHost> get_ps_servers() const { return _ps_server_list; }
virtual int32_t registe_ps_server(const std::string &ip, uint32_t port,
int32_t rank) {
return registe_ps_host(ip, port, rank, _ps_server_list,
_ps_server_sign_set);
virtual uint64_t GetLocalHostSign() { return 0; }
virtual std::vector<PSHost> GetPsServers() const { return _ps_server_list; }
virtual int32_t RegistePsServer(const std::string &ip, uint32_t port,
int32_t rank) {
return RegistePsHost(ip, port, rank, _ps_server_list, _ps_server_sign_set);
}
virtual std::vector<PSHost> get_ps_clients() const { return _ps_client_list; }
virtual int32_t registe_ps_client(const std::string &ip, uint32_t port,
int32_t rank) {
return registe_ps_host(ip, port, rank, _ps_client_list,
_ps_client_sign_set);
virtual std::vector<PSHost> GetPsClients() const { return _ps_client_list; }
virtual int32_t RegistePsClient(const std::string &ip, uint32_t port,
int32_t rank) {
return RegistePsHost(ip, port, rank, _ps_client_list, _ps_client_sign_set);
}
virtual std::vector<uint64_t> get_client_info() {
virtual std::vector<uint64_t> GetClientInfo() {
std::vector<uint64_t> client_info;
for (auto &i : _ps_client_list) {
client_info.push_back(i.serialize_to_uint64());
client_info.push_back(i.SerializeToUint64());
}
return client_info;
}
virtual std::vector<std::string> get_client_info(bool use_string_endpoint) {
virtual std::vector<std::string> GetClientInfo(bool use_string_endpoint) {
if (use_string_endpoint) {
std::vector<std::string> client_info;
for (auto &i : _ps_client_list) {
client_info.push_back(i.serialize_to_string());
client_info.push_back(i.SerializeToString());
}
return client_info;
}
return {};
}
virtual void set_trainers(int trainers) { trainers_ = trainers; }
virtual void SetTrainers(int trainers) { trainers_ = trainers; }
virtual int get_trainers() { return trainers_; }
virtual int GetTrainers() { return trainers_; }
protected:
// register a host // NOLINT
virtual int32_t registe_ps_host(
virtual int32_t RegistePsHost(
const std::string &ip, uint32_t port, int32_t rank,
std::vector<PSHost> &host_list, // NOLINT
std::unordered_set<uint64_t> &sign_set) { // NOLINT
......@@ -198,15 +195,15 @@ class PaddlePSEnvironment : public PSEnvironment {
explicit PaddlePSEnvironment() {} // NOLINT
virtual ~PaddlePSEnvironment() {}
virtual int32_t set_ps_servers(uint64_t *host_sign_list, int node_num) {
virtual int32_t SetPsServers(uint64_t *host_sign_list, int node_num) {
_ps_server_list.clear();
_ps_server_sign_set.clear();
for (int i = 0; i < node_num; ++i) {
if (host_sign_list[i] > 0) {
PSHost host;
host.parse_from_uint64(host_sign_list[i]);
host.ParseFromUint64(host_sign_list[i]);
_ps_server_list.push_back(host);
_ps_server_sign_set.insert(host.serialize_to_uint64());
_ps_server_sign_set.insert(host.SerializeToUint64());
}
}
std::sort(
......@@ -215,14 +212,14 @@ class PaddlePSEnvironment : public PSEnvironment {
return 0;
}
virtual int32_t set_ps_servers(const std::vector<std::string> *host_sign_list,
int node_num) {
virtual int32_t SetPsServers(const std::vector<std::string> *host_sign_list,
int node_num) {
_ps_server_list.clear();
_ps_server_sign_set.clear();
for (int i = 0; i < node_num; ++i) {
if (host_sign_list->at(i) != "") {
PSHost host;
host.parse_from_string(host_sign_list->at(i));
host.ParseFromString(host_sign_list->at(i));
_ps_server_list.push_back(host);
_ps_server_sign_set.insert(host.rank);
}
......@@ -233,15 +230,15 @@ class PaddlePSEnvironment : public PSEnvironment {
return 0;
}
virtual int32_t set_ps_clients(uint64_t *host_sign_list, int node_num) {
virtual int32_t SetPsClients(uint64_t *host_sign_list, int node_num) {
_ps_client_list.clear();
_ps_client_sign_set.clear();
for (int i = 0; i < node_num; ++i) {
if (host_sign_list[i] > 0) {
PSHost host;
host.parse_from_uint64(host_sign_list[i]);
host.ParseFromUint64(host_sign_list[i]);
_ps_client_list.push_back(host);
_ps_client_sign_set.insert(host.serialize_to_uint64());
_ps_client_sign_set.insert(host.SerializeToUint64());
}
}
std::sort(
......@@ -250,14 +247,14 @@ class PaddlePSEnvironment : public PSEnvironment {
return 0;
}
virtual int32_t set_ps_clients(const std::vector<std::string> *host_sign_list,
int node_num) {
virtual int32_t SetPsClients(const std::vector<std::string> *host_sign_list,
int node_num) {
_ps_client_list.clear();
_ps_client_sign_set.clear();
for (int i = 0; i < node_num; ++i) {
if (host_sign_list->at(i) != "") {
PSHost host;
host.parse_from_string(host_sign_list->at(i));
host.ParseFromString(host_sign_list->at(i));
_ps_client_list.push_back(host);
_ps_client_sign_set.insert(host.rank);
}
......@@ -269,9 +266,9 @@ class PaddlePSEnvironment : public PSEnvironment {
return 0;
}
virtual uint64_t get_local_host_sign() {
virtual uint64_t GetLocalHostSign() {
if (_ps_client_list.size() > 0) {
return _ps_client_list[0].serialize_to_uint64();
return _ps_client_list[0].SerializeToUint64();
} else {
return 0;
}
......
......@@ -135,8 +135,7 @@ std::future<int32_t> GraphBrpcClient::get_node_feat(
closure->request(request_idx)
->add_params(joint_feature_name.c_str(), joint_feature_name.size());
GraphPsService_Stub rpc_stub =
getServiceStub(get_cmd_channel(server_index));
GraphPsService_Stub rpc_stub = getServiceStub(GetCmdChannel(server_index));
closure->cntl(request_idx)->set_log_id(butil::gettimeofday_ms());
rpc_stub.service(closure->cntl(request_idx), closure->request(request_idx),
closure->response(request_idx), closure);
......@@ -169,8 +168,7 @@ std::future<int32_t> GraphBrpcClient::clear_nodes(uint32_t table_id) {
closure->request(server_index)->set_table_id(table_id);
closure->request(server_index)->set_client_id(_client_id);
GraphPsService_Stub rpc_stub =
getServiceStub(get_cmd_channel(server_index));
GraphPsService_Stub rpc_stub = getServiceStub(GetCmdChannel(server_index));
closure->cntl(server_index)->set_log_id(butil::gettimeofday_ms());
rpc_stub.service(closure->cntl(server_index),
closure->request(server_index),
......@@ -238,9 +236,8 @@ std::future<int32_t> GraphBrpcClient::add_graph_node(
->add_params((char *)weighted,
sizeof(bool) * is_weighted_bucket[request_idx].size());
}
// PsService_Stub rpc_stub(get_cmd_channel(server_index));
GraphPsService_Stub rpc_stub =
getServiceStub(get_cmd_channel(server_index));
// PsService_Stub rpc_stub(GetCmdChannel(server_index));
GraphPsService_Stub rpc_stub = getServiceStub(GetCmdChannel(server_index));
closure->cntl(request_idx)->set_log_id(butil::gettimeofday_ms());
rpc_stub.service(closure->cntl(request_idx), closure->request(request_idx),
closure->response(request_idx), closure);
......@@ -292,9 +289,8 @@ std::future<int32_t> GraphBrpcClient::remove_graph_node(
closure->request(request_idx)
->add_params((char *)request_bucket[request_idx].data(),
sizeof(int64_t) * node_num);
// PsService_Stub rpc_stub(get_cmd_channel(server_index));
GraphPsService_Stub rpc_stub =
getServiceStub(get_cmd_channel(server_index));
// PsService_Stub rpc_stub(GetCmdChannel(server_index));
GraphPsService_Stub rpc_stub = getServiceStub(GetCmdChannel(server_index));
closure->cntl(request_idx)->set_log_id(butil::gettimeofday_ms());
rpc_stub.service(closure->cntl(request_idx), closure->request(request_idx),
closure->response(request_idx), closure);
......@@ -362,9 +358,8 @@ std::future<int32_t> GraphBrpcClient::batch_sample_neighbors(
closure->request(0)->add_params((char *)&sample_size, sizeof(int));
closure->request(0)->add_params((char *)&need_weight, sizeof(bool));
;
// PsService_Stub rpc_stub(get_cmd_channel(server_index));
GraphPsService_Stub rpc_stub =
getServiceStub(get_cmd_channel(server_index));
// PsService_Stub rpc_stub(GetCmdChannel(server_index));
GraphPsService_Stub rpc_stub = getServiceStub(GetCmdChannel(server_index));
closure->cntl(0)->set_log_id(butil::gettimeofday_ms());
rpc_stub.service(closure->cntl(0), closure->request(0),
closure->response(0), closure);
......@@ -464,9 +459,8 @@ std::future<int32_t> GraphBrpcClient::batch_sample_neighbors(
->add_params((char *)&sample_size, sizeof(int));
closure->request(request_idx)
->add_params((char *)&need_weight, sizeof(bool));
// PsService_Stub rpc_stub(get_cmd_channel(server_index));
GraphPsService_Stub rpc_stub =
getServiceStub(get_cmd_channel(server_index));
// PsService_Stub rpc_stub(GetCmdChannel(server_index));
GraphPsService_Stub rpc_stub = getServiceStub(GetCmdChannel(server_index));
closure->cntl(request_idx)->set_log_id(butil::gettimeofday_ms());
rpc_stub.service(closure->cntl(request_idx), closure->request(request_idx),
closure->response(request_idx), closure);
......@@ -506,8 +500,8 @@ std::future<int32_t> GraphBrpcClient::random_sample_nodes(
closure->request(0)->set_client_id(_client_id);
closure->request(0)->add_params((char *)&sample_size, sizeof(int));
;
// PsService_Stub rpc_stub(get_cmd_channel(server_index));
GraphPsService_Stub rpc_stub = getServiceStub(get_cmd_channel(server_index));
// PsService_Stub rpc_stub(GetCmdChannel(server_index));
GraphPsService_Stub rpc_stub = getServiceStub(GetCmdChannel(server_index));
closure->cntl(0)->set_log_id(butil::gettimeofday_ms());
rpc_stub.service(closure->cntl(0), closure->request(0), closure->response(0),
closure);
......@@ -541,8 +535,7 @@ std::future<int32_t> GraphBrpcClient::load_graph_split_config(
closure->request(server_index)->set_table_id(table_id);
closure->request(server_index)->set_client_id(_client_id);
closure->request(server_index)->add_params(path);
GraphPsService_Stub rpc_stub =
getServiceStub(get_cmd_channel(server_index));
GraphPsService_Stub rpc_stub = getServiceStub(GetCmdChannel(server_index));
closure->cntl(server_index)->set_log_id(butil::gettimeofday_ms());
rpc_stub.service(closure->cntl(server_index),
closure->request(server_index),
......@@ -581,8 +574,7 @@ std::future<int32_t> GraphBrpcClient::use_neighbors_sample_cache(
closure->request(server_index)
->add_params((char *)&size_limit, sizeof(size_t));
closure->request(server_index)->add_params((char *)&ttl, sizeof(size_t));
GraphPsService_Stub rpc_stub =
getServiceStub(get_cmd_channel(server_index));
GraphPsService_Stub rpc_stub = getServiceStub(GetCmdChannel(server_index));
closure->cntl(server_index)->set_log_id(butil::gettimeofday_ms());
rpc_stub.service(closure->cntl(server_index),
closure->request(server_index),
......@@ -624,8 +616,8 @@ std::future<int32_t> GraphBrpcClient::pull_graph_list(
closure->request(0)->add_params((char *)&start, sizeof(int));
closure->request(0)->add_params((char *)&size, sizeof(int));
closure->request(0)->add_params((char *)&step, sizeof(int));
// PsService_Stub rpc_stub(get_cmd_channel(server_index));
GraphPsService_Stub rpc_stub = getServiceStub(get_cmd_channel(server_index));
// PsService_Stub rpc_stub(GetCmdChannel(server_index));
GraphPsService_Stub rpc_stub = getServiceStub(GetCmdChannel(server_index));
closure->cntl(0)->set_log_id(butil::gettimeofday_ms());
rpc_stub.service(closure->cntl(0), closure->request(0), closure->response(0),
closure);
......@@ -717,8 +709,7 @@ std::future<int32_t> GraphBrpcClient::set_node_feat(
closure->request(request_idx)
->add_params(set_feature.c_str(), set_feature.size());
GraphPsService_Stub rpc_stub =
getServiceStub(get_cmd_channel(server_index));
GraphPsService_Stub rpc_stub = getServiceStub(GetCmdChannel(server_index));
closure->cntl(request_idx)->set_log_id(butil::gettimeofday_ms());
rpc_stub.service(closure->cntl(request_idx), closure->request(request_idx),
closure->response(request_idx), closure);
......@@ -727,10 +718,10 @@ std::future<int32_t> GraphBrpcClient::set_node_feat(
return fut;
}
int32_t GraphBrpcClient::initialize() {
int32_t GraphBrpcClient::Initialize() {
// set_shard_num(_config.shard_num());
BrpcPsClient::initialize();
server_size = get_server_nums();
BrpcPsClient::Initialize();
server_size = GetServerNums();
graph_service = NULL;
local_channel = NULL;
return 0;
......
......@@ -97,12 +97,12 @@ class GraphBrpcClient : public BrpcPsClient {
std::string path);
virtual std::future<int32_t> remove_graph_node(
uint32_t table_id, std::vector<int64_t>& node_id_list);
virtual int32_t initialize();
virtual int32_t Initialize();
int get_shard_num() { return shard_num; }
void set_shard_num(int shard_num) { this->shard_num = shard_num; }
int get_server_index_by_id(int64_t id);
void set_local_channel(int index) {
this->local_channel = get_cmd_channel(index);
this->local_channel = GetCmdChannel(index);
}
void set_local_graph_service(GraphBrpcService* graph_service) {
this->graph_service = graph_service;
......
......@@ -33,7 +33,7 @@ namespace distributed {
return -1; \
}
int32_t GraphBrpcServer::initialize() {
int32_t GraphBrpcServer::Initialize() {
auto &service_config = _config.downpour_server_param().service_param();
if (!service_config.has_service_class()) {
LOG(ERROR) << "miss service_class in ServerServiceParameter";
......@@ -48,7 +48,7 @@ int32_t GraphBrpcServer::initialize() {
}
_service.reset(service);
if (service->configure(this) != 0 || service->initialize() != 0) {
if (service->Configure(this) != 0 || service->Initialize() != 0) {
LOG(ERROR) << "service initialize failed, service_name:"
<< service_config.service_class();
return -1;
......@@ -61,11 +61,11 @@ int32_t GraphBrpcServer::initialize() {
return 0;
}
brpc::Channel *GraphBrpcServer::get_cmd_channel(size_t server_index) {
brpc::Channel *GraphBrpcServer::GetCmdChannel(size_t server_index) {
return _pserver_channels[server_index].get();
}
uint64_t GraphBrpcServer::start(const std::string &ip, uint32_t port) {
uint64_t GraphBrpcServer::Start(const std::string &ip, uint32_t port) {
std::unique_lock<std::mutex> lock(mutex_);
std::string ip_port = ip + ":" + std::to_string(port);
......@@ -73,20 +73,20 @@ uint64_t GraphBrpcServer::start(const std::string &ip, uint32_t port) {
brpc::ServerOptions options;
int num_threads = std::thread::hardware_concurrency();
auto trainers = _environment->get_trainers();
auto trainers = _environment->GetTrainers();
options.num_threads = trainers > num_threads ? trainers : num_threads;
if (_server.Start(ip_port.c_str(), &options) != 0) {
LOG(ERROR) << "GraphBrpcServer start failed, ip_port=" << ip_port;
return 0;
}
_environment->registe_ps_server(ip, port, _rank);
_environment->RegistePsServer(ip, port, _rank);
return 0;
}
int32_t GraphBrpcServer::build_peer2peer_connection(int rank) {
this->rank = rank;
auto _env = environment();
auto _env = Environment();
brpc::ChannelOptions options;
options.protocol = "baidu_std";
options.timeout_ms = 500000;
......@@ -94,7 +94,7 @@ int32_t GraphBrpcServer::build_peer2peer_connection(int rank) {
options.connect_timeout_ms = 10000;
options.max_retry = 3;
std::vector<PSHost> server_list = _env->get_ps_servers();
std::vector<PSHost> server_list = _env->GetPsServers();
_pserver_channels.resize(server_list.size());
std::ostringstream os;
std::string server_ip_port;
......@@ -172,19 +172,18 @@ int32_t GraphBrpcService::remove_graph_node(Table *table,
((GraphTable *)table)->remove_graph_node(node_ids);
return 0;
}
int32_t GraphBrpcServer::port() { return _server.listen_address().port; }
int32_t GraphBrpcServer::Port() { return _server.listen_address().port; }
int32_t GraphBrpcService::initialize() {
int32_t GraphBrpcService::Initialize() {
_is_initialize_shard_info = false;
_service_handler_map[PS_STOP_SERVER] = &GraphBrpcService::stop_server;
_service_handler_map[PS_LOAD_ONE_TABLE] = &GraphBrpcService::load_one_table;
_service_handler_map[PS_LOAD_ALL_TABLE] = &GraphBrpcService::load_all_table;
_service_handler_map[PS_STOP_SERVER] = &GraphBrpcService::StopServer;
_service_handler_map[PS_LOAD_ONE_TABLE] = &GraphBrpcService::LoadOneTable;
_service_handler_map[PS_LOAD_ALL_TABLE] = &GraphBrpcService::LoadAllTable;
_service_handler_map[PS_PRINT_TABLE_STAT] =
&GraphBrpcService::print_table_stat;
_service_handler_map[PS_BARRIER] = &GraphBrpcService::barrier;
_service_handler_map[PS_START_PROFILER] = &GraphBrpcService::start_profiler;
_service_handler_map[PS_STOP_PROFILER] = &GraphBrpcService::stop_profiler;
_service_handler_map[PS_PRINT_TABLE_STAT] = &GraphBrpcService::PrintTableStat;
_service_handler_map[PS_BARRIER] = &GraphBrpcService::Barrier;
_service_handler_map[PS_START_PROFILER] = &GraphBrpcService::StartProfiler;
_service_handler_map[PS_STOP_PROFILER] = &GraphBrpcService::StopProfiler;
_service_handler_map[PS_PULL_GRAPH_LIST] = &GraphBrpcService::pull_graph_list;
_service_handler_map[PS_GRAPH_SAMPLE_NEIGHBORS] =
......@@ -207,21 +206,21 @@ int32_t GraphBrpcService::initialize() {
_service_handler_map[PS_GRAPH_LOAD_GRAPH_SPLIT_CONFIG] =
&GraphBrpcService::load_graph_split_config;
// shard initialization; the server_list shard info can only be fetched from env after the server starts
initialize_shard_info();
InitializeShardInfo();
return 0;
}
int32_t GraphBrpcService::initialize_shard_info() {
int32_t GraphBrpcService::InitializeShardInfo() {
if (!_is_initialize_shard_info) {
std::lock_guard<std::mutex> guard(_initialize_shard_mutex);
if (_is_initialize_shard_info) {
return 0;
}
server_size = _server->environment()->get_ps_servers().size();
auto &table_map = *(_server->table());
server_size = _server->Environment()->GetPsServers().size();
auto &table_map = *(_server->GetTable());
for (auto itr : table_map) {
itr.second->set_shard(_rank, server_size);
itr.second->SetShard(_rank, server_size);
}
_is_initialize_shard_info = true;
}
......@@ -241,7 +240,7 @@ void GraphBrpcService::service(google::protobuf::RpcController *cntl_base,
response->set_err_code(0);
response->set_err_msg("");
auto *table = _server->table(request->table_id());
auto *table = _server->GetTable(request->table_id());
brpc::Controller *cntl = static_cast<brpc::Controller *>(cntl_base);
auto itr = _service_handler_map.find(request->cmd_id());
if (itr == _service_handler_map.end()) {
......@@ -261,7 +260,7 @@ void GraphBrpcService::service(google::protobuf::RpcController *cntl_base,
}
}
int32_t GraphBrpcService::barrier(Table *table, const PsRequestMessage &request,
int32_t GraphBrpcService::Barrier(Table *table, const PsRequestMessage &request,
PsResponseMessage &response,
brpc::Controller *cntl) {
CHECK_TABLE_EXIST(table, request, response)
......@@ -275,16 +274,16 @@ int32_t GraphBrpcService::barrier(Table *table, const PsRequestMessage &request,
auto trainer_id = request.client_id();
auto barrier_type = request.params(0);
table->barrier(trainer_id, barrier_type);
table->Barrier(trainer_id, barrier_type);
return 0;
}
int32_t GraphBrpcService::print_table_stat(Table *table,
const PsRequestMessage &request,
PsResponseMessage &response,
brpc::Controller *cntl) {
int32_t GraphBrpcService::PrintTableStat(Table *table,
const PsRequestMessage &request,
PsResponseMessage &response,
brpc::Controller *cntl) {
CHECK_TABLE_EXIST(table, request, response)
std::pair<int64_t, int64_t> ret = table->print_table_stat();
std::pair<int64_t, int64_t> ret = table->PrintTableStat();
paddle::framework::BinaryArchive ar;
ar << ret.first << ret.second;
std::string table_info(ar.Buffer(), ar.Length());
......@@ -293,10 +292,10 @@ int32_t GraphBrpcService::print_table_stat(Table *table,
return 0;
}
int32_t GraphBrpcService::load_one_table(Table *table,
const PsRequestMessage &request,
PsResponseMessage &response,
brpc::Controller *cntl) {
int32_t GraphBrpcService::LoadOneTable(Table *table,
const PsRequestMessage &request,
PsResponseMessage &response,
brpc::Controller *cntl) {
CHECK_TABLE_EXIST(table, request, response)
if (request.params_size() < 2) {
set_response_code(
......@@ -304,20 +303,20 @@ int32_t GraphBrpcService::load_one_table(Table *table,
"PsRequestMessage.datas is requeired at least 2 for path & load_param");
return -1;
}
if (table->load(request.params(0), request.params(1)) != 0) {
if (table->Load(request.params(0), request.params(1)) != 0) {
set_response_code(response, -1, "table load failed");
return -1;
}
return 0;
}
int32_t GraphBrpcService::load_all_table(Table *table,
const PsRequestMessage &request,
PsResponseMessage &response,
brpc::Controller *cntl) {
auto &table_map = *(_server->table());
int32_t GraphBrpcService::LoadAllTable(Table *table,
const PsRequestMessage &request,
PsResponseMessage &response,
brpc::Controller *cntl) {
auto &table_map = *(_server->GetTable());
for (auto &itr : table_map) {
if (load_one_table(itr.second.get(), request, response, cntl) != 0) {
if (LoadOneTable(itr.second.get(), request, response, cntl) != 0) {
LOG(ERROR) << "load table[" << itr.first << "] failed";
return -1;
}
......@@ -325,13 +324,13 @@ int32_t GraphBrpcService::load_all_table(Table *table,
return 0;
}
int32_t GraphBrpcService::stop_server(Table *table,
const PsRequestMessage &request,
PsResponseMessage &response,
brpc::Controller *cntl) {
int32_t GraphBrpcService::StopServer(Table *table,
const PsRequestMessage &request,
PsResponseMessage &response,
brpc::Controller *cntl) {
GraphBrpcServer *p_server = (GraphBrpcServer *)_server;
std::thread t_stop([p_server]() {
p_server->stop();
p_server->Stop();
LOG(INFO) << "Server Stopped";
});
p_server->export_cv()->notify_all();
......@@ -339,19 +338,19 @@ int32_t GraphBrpcService::stop_server(Table *table,
return 0;
}
int32_t GraphBrpcService::stop_profiler(Table *table,
const PsRequestMessage &request,
PsResponseMessage &response,
brpc::Controller *cntl) {
int32_t GraphBrpcService::StopProfiler(Table *table,
const PsRequestMessage &request,
PsResponseMessage &response,
brpc::Controller *cntl) {
platform::DisableProfiler(platform::EventSortingKey::kDefault,
string::Sprintf("server_%s_profile", _rank));
return 0;
}
int32_t GraphBrpcService::start_profiler(Table *table,
const PsRequestMessage &request,
PsResponseMessage &response,
brpc::Controller *cntl) {
int32_t GraphBrpcService::StartProfiler(Table *table,
const PsRequestMessage &request,
PsResponseMessage &response,
brpc::Controller *cntl) {
platform::EnableProfiler(platform::ProfilerState::kCPU);
return 0;
}
......@@ -475,7 +474,7 @@ int32_t GraphBrpcService::sample_neighbors_across_multi_servers(
std::vector<int> server2request(server_size, -1);
std::vector<int64_t> local_id;
std::vector<int> local_query_idx;
size_t rank = get_rank();
size_t rank = GetRank();
for (int query_idx = 0; query_idx < node_num; ++query_idx) {
int server_index =
((GraphTable *)table)->get_server_index_by_id(node_data[query_idx]);
......@@ -589,9 +588,9 @@ int32_t GraphBrpcService::sample_neighbors_across_multi_servers(
closure->request(request_idx)
->add_params((char *)&need_weight, sizeof(bool));
PsService_Stub rpc_stub(
((GraphBrpcServer *)get_server())->get_cmd_channel(server_index));
((GraphBrpcServer *)GetServer())->GetCmdChannel(server_index));
// GraphPsService_Stub rpc_stub =
// getServiceStub(get_cmd_channel(server_index));
// getServiceStub(GetCmdChannel(server_index));
closure->cntl(request_idx)->set_log_id(butil::gettimeofday_ms());
rpc_stub.service(closure->cntl(request_idx), closure->request(request_idx),
closure->response(request_idx), closure);
......
......@@ -31,10 +31,10 @@ class GraphBrpcServer : public PSServer {
GraphBrpcServer() {}
virtual ~GraphBrpcServer() {}
PsBaseService *get_service() { return _service.get(); }
virtual uint64_t start(const std::string &ip, uint32_t port);
virtual uint64_t Start(const std::string &ip, uint32_t port);
virtual int32_t build_peer2peer_connection(int rank);
virtual brpc::Channel *get_cmd_channel(size_t server_index);
virtual int32_t stop() {
virtual brpc::Channel *GetCmdChannel(size_t server_index);
virtual int32_t Stop() {
std::unique_lock<std::mutex> lock(mutex_);
if (stoped_) return 0;
stoped_ = true;
......@@ -43,12 +43,12 @@ class GraphBrpcServer : public PSServer {
_server.Join();
return 0;
}
int32_t port();
int32_t Port();
std::condition_variable *export_cv() { return &cv_; }
private:
virtual int32_t initialize();
virtual int32_t Initialize();
mutable std::mutex mutex_;
std::condition_variable cv_;
bool stoped_ = false;
......@@ -66,7 +66,7 @@ typedef int32_t (GraphBrpcService::*serviceFunc)(
class GraphBrpcService : public PsBaseService {
public:
virtual int32_t initialize() override;
virtual int32_t Initialize() override;
virtual void service(::google::protobuf::RpcController *controller,
const PsRequestMessage *request,
......@@ -75,7 +75,7 @@ class GraphBrpcService : public PsBaseService {
protected:
std::unordered_map<int32_t, serviceFunc> _service_handler_map;
int32_t initialize_shard_info();
int32_t InitializeShardInfo();
int32_t pull_graph_list(Table *table, const PsRequestMessage &request,
PsResponseMessage &response, brpc::Controller *cntl);
int32_t graph_random_sample_neighbors(Table *table,
......@@ -100,21 +100,21 @@ class GraphBrpcService : public PsBaseService {
int32_t remove_graph_node(Table *table, const PsRequestMessage &request,
PsResponseMessage &response,
brpc::Controller *cntl);
int32_t barrier(Table *table, const PsRequestMessage &request,
int32_t Barrier(Table *table, const PsRequestMessage &request,
PsResponseMessage &response, brpc::Controller *cntl);
int32_t load_one_table(Table *table, const PsRequestMessage &request,
PsResponseMessage &response, brpc::Controller *cntl);
int32_t load_all_table(Table *table, const PsRequestMessage &request,
PsResponseMessage &response, brpc::Controller *cntl);
int32_t stop_server(Table *table, const PsRequestMessage &request,
PsResponseMessage &response, brpc::Controller *cntl);
int32_t start_profiler(Table *table, const PsRequestMessage &request,
PsResponseMessage &response, brpc::Controller *cntl);
int32_t stop_profiler(Table *table, const PsRequestMessage &request,
int32_t LoadOneTable(Table *table, const PsRequestMessage &request,
PsResponseMessage &response, brpc::Controller *cntl);
int32_t LoadAllTable(Table *table, const PsRequestMessage &request,
PsResponseMessage &response, brpc::Controller *cntl);
int32_t StopServer(Table *table, const PsRequestMessage &request,
PsResponseMessage &response, brpc::Controller *cntl);
int32_t StartProfiler(Table *table, const PsRequestMessage &request,
PsResponseMessage &response, brpc::Controller *cntl);
int32_t StopProfiler(Table *table, const PsRequestMessage &request,
PsResponseMessage &response, brpc::Controller *cntl);
int32_t print_table_stat(Table *table, const PsRequestMessage &request,
PsResponseMessage &response, brpc::Controller *cntl);
int32_t PrintTableStat(Table *table, const PsRequestMessage &request,
PsResponseMessage &response, brpc::Controller *cntl);
int32_t sample_neighbors_across_multi_servers(Table *table,
const PsRequestMessage &request,
......
......@@ -25,7 +25,7 @@ REGISTER_PSCORE_CLASS(PSClient, BrpcPsClient);
REGISTER_PSCORE_CLASS(PSClient, PsLocalClient);
REGISTER_PSCORE_CLASS(PSClient, GraphBrpcClient);
int32_t PSClient::configure(
int32_t PSClient::Configure(
const PSParameter &config,
const std::map<uint64_t, std::vector<paddle::distributed::Region>> &regions,
PSEnvironment &env, size_t client_id) {
......@@ -51,10 +51,10 @@ int32_t PSClient::configure(
_table_accessors[work_param.downpour_table_param(i).table_id()].reset(
accessor);
}
return initialize();
return Initialize();
}
PSClient *PSClientFactory::create(const PSParameter &ps_config) {
PSClient *PSClientFactory::Create(const PSParameter &ps_config) {
const auto &config = ps_config.server_param();
if (!config.has_downpour_server_param()) {
LOG(ERROR) << "miss downpour_server_param in ServerParameter";
......@@ -81,7 +81,7 @@ PSClient *PSClientFactory::create(const PSParameter &ps_config) {
return NULL;
}
TableManager::instance().initialize();
TableManager::Instance().Initialize();
VLOG(3) << "Create PSClient[" << service_param.client_class() << "] success";
return client;
}
......
......@@ -26,7 +26,6 @@
#include "paddle/fluid/distributed/ps/service/sendrecv.pb.h"
#include "paddle/fluid/distributed/ps/table/accessor.h"
#include "paddle/fluid/distributed/ps/table/graph/graph_node.h"
#include "paddle/fluid/distributed/ps/table/table.h"
#include "paddle/fluid/platform/timer.h"
namespace paddle {
......@@ -60,41 +59,6 @@ class PSClientClosure : public google::protobuf::Closure {
std::vector<std::shared_ptr<std::promise<int32_t>>> _promises;
};
struct LoadSaveContext {
int table_id;
std::string epoch;
std::string mode;
};
enum TrainingMode { Async = 0, Sync = 1, Geo = 3 };
enum TrainingPhase { Init = 0, Train = 1, Save = 2 };
// enum ValueType {
// Sparse = 0,
// Dense = 1
// };
struct PushContext {
const uint64_t *keys;
const float **push_values;
const Region *push_dense_values;
};
struct RequestContext {
int table;
TrainingMode training_mode; // 1 for async, 2 for geo, 3 for sync
TrainingPhase training_phase; // 1 for init, 2 for train
ValueType value_type; // 1 for sparse, 2 for dense
uint64_t *keys;
float **sparse_values; // for sparse values
Region *dense_values; // for dense values
PushContext push_context;
size_t num;
bool is_training;
void *callback;
};
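
For orientation, a hypothetical fill of RequestContext for a sparse pull through PSClient::Pull (enum values come from the definitions above; the 0 for value_type follows the commented-out ValueType enum, and client, keys, and value_ptrs are assumed to exist):

    RequestContext ctx;
    ctx.table = 0;                               // target table id
    ctx.training_mode = Async;                   // TrainingMode, defined above
    ctx.training_phase = Train;                  // TrainingPhase, defined above
    ctx.value_type = static_cast<ValueType>(0);  // Sparse, per the commented-out enum
    ctx.keys = keys.data();                      // uint64_t* keys to pull
    ctx.sparse_values = value_ptrs.data();       // float** output buffers
    ctx.num = keys.size();
    ctx.is_training = true;
    auto fut = client->Pull(ctx);                // client: a configured PSClient*
    fut.wait();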
class PSClient {
public:
PSClient() {}
......@@ -102,41 +66,37 @@ class PSClient {
PSClient(PSClient &&) = delete;
PSClient(const PSClient &) = delete;
virtual int32_t configure( // NOLINT
virtual int32_t Configure( // NOLINT
const PSParameter &config,
const std::map<uint64_t, std::vector<paddle::distributed::Region>>
&regions,
PSEnvironment &_env, size_t client_id) final; // NOLINT
virtual int32_t create_client2client_connection(
int pserver_timeout_ms, int pserver_connect_timeout_ms,
int max_retry) = 0;
virtual int32_t CreateClient2ClientConnection(int pserver_timeout_ms,
int pserver_connect_timeout_ms,
int max_retry) = 0;
// trigger eviction of table data
virtual std::future<int32_t> shrink(uint32_t table_id,
virtual std::future<int32_t> Shrink(uint32_t table_id,
const std::string threshold) = 0;
// load data into all tables
virtual std::future<int32_t> load(const std::string &epoch,
virtual std::future<int32_t> Load(const std::string &epoch,
const std::string &mode) = 0;
// load data into the specified table
virtual std::future<int32_t> load(uint32_t table_id, const std::string &epoch,
virtual std::future<int32_t> Load(uint32_t table_id, const std::string &epoch,
const std::string &mode) = 0;
// load options configured via context
virtual std::future<int32_t> Load(const LoadSaveContext &load_context) = 0;
// save all table data; depending on mode, value_accessor may apply different save conditions
virtual std::future<int32_t> save(const std::string &epoch,
virtual std::future<int32_t> Save(const std::string &epoch,
const std::string &mode) = 0;
// Save the specified table; depending on mode, value_accessor may apply different save conditions
virtual std::future<int32_t> save(uint32_t table_id, const std::string &epoch,
virtual std::future<int32_t> Save(uint32_t table_id, const std::string &epoch,
const std::string &mode) = 0;
virtual std::future<int32_t> Save(const LoadSaveContext &save_context) = 0;
// Clear table data
virtual std::future<int32_t> clear() = 0;
virtual std::future<int32_t> clear(uint32_t table_id) = 0;
virtual std::future<int32_t> Clear() = 0;
virtual std::future<int32_t> Clear(uint32_t table_id) = 0;
// Pull the dense part of the parameters and fill it, chunk by chunk, into the local network parameters
// start and num are used to pull a subset of the parameters
......@@ -145,23 +105,19 @@ class PSClient {
// The sender aggregates requests for the same block, accumulating multiple fill buffers
// The server extracts and returns the configured dimension of the parameter block
// After the response is unpacked, the data is filled into the accumulated buffers
virtual std::future<int32_t> pull_dense(Region *regions, size_t region_num,
size_t table_id) = 0; // 保留
virtual std::future<int32_t> Push(RequestContext &push_context) = 0;
virtual std::future<int32_t> PullDense(Region *regions, size_t region_num,
size_t table_id) = 0; // 保留
// firstly push dense param for parameter server
// this is neccessary because dense weight initialized in trainer on cold
// start
virtual std::future<int32_t> push_dense_param(const Region *regions,
size_t region_num,
size_t table_id) = 0;
virtual std::future<int32_t> push_dense(const Region *regions,
size_t region_num,
size_t table_id) = 0;
virtual std::future<int32_t> PushDenseParam(const Region *regions,
size_t region_num,
size_t table_id) = 0;
virtual std::future<int32_t> Pull(RequestContext &pull_context) = 0;
virtual std::future<int32_t> PushDense(const Region *regions,
size_t region_num,
size_t table_id) = 0;
// Issue a pull request with keys; the results fill values
// keys and values both hold num entries, and each value occupies select_size bytes
......@@ -169,15 +125,14 @@ class PSClient {
// Keys requested by multiple threads are merged, batched, and scattered to the servers
// Once results return, the buffers are traversed and values are assigned
// is_training distinguishes training from inference requests; the server handles features and admission differently for each.
virtual std::future<int32_t> pull_sparse(float **select_values,
size_t table_id,
const uint64_t *keys, size_t num,
bool is_training) = 0;
virtual std::future<int32_t> pull_sparse_param(float **select_values,
size_t table_id,
const uint64_t *keys,
size_t num, bool is_training) {
virtual std::future<int32_t> PullSparse(float **select_values,
size_t table_id, const uint64_t *keys,
size_t num, bool is_training) = 0;
virtual std::future<int32_t> PullSparseParam(float **select_values,
size_t table_id,
const uint64_t *keys, size_t num,
bool is_training) {
VLOG(0) << "Did not implement";
std::promise<int32_t> promise;
std::future<int> fut = promise.get_future();
......@@ -185,10 +140,10 @@ class PSClient {
return fut;
}
virtual ::std::future<int32_t> pull_sparse_ptr(char **select_values,
size_t table_id,
const uint64_t *keys,
size_t num) {
virtual ::std::future<int32_t> PullSparsePtr(char **select_values,
size_t table_id,
const uint64_t *keys,
size_t num) {
VLOG(0) << "Did not implement";
std::promise<int32_t> promise;
std::future<int> fut = promise.get_future();
......@@ -196,38 +151,38 @@ class PSClient {
return fut;
}
virtual std::future<int32_t> print_table_stat(uint32_t table_id) = 0;
virtual std::future<int32_t> PrintTableStat(uint32_t table_id) = 0;
// Ensure all accumulated requests are sent out
virtual std::future<int32_t> flush() = 0;
virtual std::future<int32_t> Flush() = 0;
// Graceful server shutdown
virtual std::future<int32_t> stop_server() = 0;
virtual std::future<int32_t> StopServer() = 0;
// server profiler
virtual std::future<int32_t> start_profiler() = 0;
virtual std::future<int32_t> stop_profiler() = 0;
virtual std::future<int32_t> StartProfiler() = 0;
virtual std::future<int32_t> StopProfiler() = 0;
virtual std::future<int32_t> barrier(size_t table_id,
virtual std::future<int32_t> Barrier(size_t table_id,
uint32_t barrier_type) = 0;
virtual std::future<int32_t> pull_geo_param(size_t table_id,
std::vector<float> *values,
std::vector<uint64_t> *keys,
int pserver_idx) = 0;
virtual std::future<int32_t> PullGeoParam(size_t table_id,
std::vector<float> *values,
std::vector<uint64_t> *keys,
int pserver_idx) = 0;
virtual std::future<int32_t> push_global_step(int table_id,
int64_t *total_send_data,
void *done) = 0;
virtual std::future<int32_t> PushGlobalStep(int table_id,
int64_t *total_send_data,
void *done) = 0;
// recv table from server and save it in LodTensor
virtual int32_t recv_and_save_table(const uint64_t table_id,
const std::string &path) = 0;
virtual int32_t RecvAndSaveTable(const uint64_t table_id,
const std::string &path) = 0;
virtual void finalize_worker() = 0;
virtual void FinalizeWorker() = 0;
// client-to-client message sending
virtual std::future<int32_t> send_client2client_msg(int msg_type,
int to_client_id,
const std::string &msg) {
virtual std::future<int32_t> SendClient2ClientMsg(int msg_type,
int to_client_id,
const std::string &msg) {
VLOG(0) << "Did not implement";
std::promise<int32_t> promise;
std::future<int> fut = promise.get_future();
......@@ -238,13 +193,13 @@ class PSClient {
// client2client message handling: std::function<int32_t(int, int, const std::string&)
// -> ret (msg_type, from_client_id, msg)
typedef std::function<int32_t(int, int, const std::string &)> MsgHandlerFunc;
virtual int registe_client2client_msg_handler(int msg_type,
MsgHandlerFunc handler) {
virtual int RegisteClient2ClientMsgHandler(int msg_type,
MsgHandlerFunc handler) {
_msg_handler_map[msg_type] = handler;
return 0;
}
virtual int handle_client2client_msg(int msg_type, int from_client_id,
const std::string &msg) {
virtual int HandleClient2ClientMsg(int msg_type, int from_client_id,
const std::string &msg) {
auto itr = _msg_handler_map.find(msg_type);
if (itr == _msg_handler_map.end()) {
LOG(WARNING) << "unknown client2client_msg type:" << msg_type;
......@@ -253,7 +208,7 @@ class PSClient {
return itr->second(msg_type, from_client_id, msg);
}
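As an illustration, a hedged sketch of installing a handler through the two methods above; the message type and client instance are assumptions:

// Hypothetical sketch: register a handler for message type 0 that logs
// the sender and payload size, then reports success.
client->RegisteClient2ClientMsgHandler(
    0, [](int msg_type, int from_client_id, const std::string &msg) -> int32_t {
      VLOG(1) << "client2client msg type " << msg_type << " from client "
              << from_client_id << ", " << msg.size() << " bytes";
      return 0;
    });

Incoming messages of type 0 are then routed by HandleClient2ClientMsg to this lambda.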
virtual ValueAccessor *table_accessor(size_t table_id) {
virtual ValueAccessor *GetTableAccessor(size_t table_id) {
auto itr = _table_accessors.find(table_id);
if (itr == _table_accessors.end()) {
return NULL;
......@@ -261,31 +216,31 @@ class PSClient {
return itr->second.get();
}
virtual size_t get_server_nums() = 0;
virtual size_t GetServerNums() = 0;
virtual std::future<int32_t> push_dense_raw_gradient(
int table_id, float *total_send_data, size_t total_send_data_size,
void *done) = 0;
virtual std::future<int32_t> PushDenseRawGradient(int table_id,
float *total_send_data,
size_t total_send_data_size,
void *done) = 0;
virtual std::future<int32_t> push_sparse_raw_gradient(
virtual std::future<int32_t> PushSparseRawGradient(
size_t table_id, const uint64_t *keys, const float **update_values,
size_t num, void *done) = 0;
virtual std::future<int32_t> push_sparse_raw_gradient_partial(
virtual std::future<int32_t> PushSparseRawGradientPartial(
size_t table_id, const uint64_t *keys, const float **update_values,
uint32_t num, void *done, int pserver_idx) = 0;
virtual std::future<int32_t> push_sparse_param(size_t table_id,
const uint64_t *keys,
const float **update_values,
size_t num, void *done) = 0;
virtual std::future<int32_t> push_sparse(size_t table_id,
const uint64_t *keys,
const float **update_values,
size_t num) = 0;
virtual std::future<int32_t> PushSparseParam(size_t table_id,
const uint64_t *keys,
const float **update_values,
size_t num, void *done) = 0;
virtual std::future<int32_t> PushSparse(size_t table_id, const uint64_t *keys,
const float **update_values,
size_t num) = 0;
protected:
virtual int32_t initialize() = 0;
virtual int32_t Initialize() = 0;
size_t _client_id;
PSParameter _config;
std::map<uint64_t, std::vector<paddle::distributed::Region>>
......@@ -333,7 +288,7 @@ REGISTER_PSCORE_REGISTERER(PSClient);
class PSClientFactory {
public:
static PSClient *create(const PSParameter &config);
static PSClient *Create(const PSParameter &config);
};
} // namespace distributed
} // namespace paddle
......@@ -19,166 +19,91 @@
namespace paddle {
namespace distributed {
int32_t PsLocalClient::initialize() {
int32_t PsLocalClient::Initialize() {
const auto& downpour_param = _config.server_param().downpour_server_param();
TableManager::instance().initialize();
TableManager::Instance().Initialize();
for (size_t i = 0; i < downpour_param.downpour_table_param_size(); ++i) {
auto* table = CREATE_PSCORE_CLASS(
Table, downpour_param.downpour_table_param(i).table_class());
table->set_shard(0, 1);
table->initialize(downpour_param.downpour_table_param(i),
table->SetShard(0, 1);
table->Initialize(downpour_param.downpour_table_param(i),
_config.fs_client_param());
_table_map[downpour_param.downpour_table_param(i).table_id()].reset(table);
}
return 0;
}
::std::future<int32_t> PsLocalClient::shrink(uint32_t table_id,
::std::future<int32_t> PsLocalClient::Shrink(uint32_t table_id,
const std::string threshold) {
// TODO
return done();
}
::std::future<int32_t> PsLocalClient::load(const std::string& epoch,
::std::future<int32_t> PsLocalClient::Load(const std::string& epoch,
const std::string& mode) {
// TODO
for (auto& it : _table_map) {
load(it.first, epoch, mode);
Load(it.first, epoch, mode);
}
return done();
}
::std::future<int32_t> PsLocalClient::load(uint32_t table_id,
::std::future<int32_t> PsLocalClient::Load(uint32_t table_id,
const std::string& epoch,
const std::string& mode) {
// TODO
auto* table_ptr = table(table_id);
table_ptr->load(epoch, mode);
auto* table_ptr = GetTable(table_id);
table_ptr->Load(epoch, mode);
return done();
}
std::future<int32_t> PsLocalClient::Load(const LoadSaveContext& load_context) {
if (load_context.table_id < 0) {
for (auto& it : _table_map) {
load(it.first, load_context.epoch, load_context.mode);
}
return done();
} else {
auto* table_ptr = table(load_context.table_id);
table_ptr->load(load_context.epoch, load_context.mode);
return done();
}
}
::std::future<int32_t> PsLocalClient::save(const std::string& epoch,
::std::future<int32_t> PsLocalClient::Save(const std::string& epoch,
const std::string& mode) {
// TODO
for (auto& it : _table_map) {
save(it.first, epoch, mode);
Save(it.first, epoch, mode);
}
return done();
}
::std::future<int32_t> PsLocalClient::save(uint32_t table_id,
::std::future<int32_t> PsLocalClient::Save(uint32_t table_id,
const std::string& epoch,
const std::string& mode) {
// TODO
auto* table_ptr = table(table_id);
table_ptr->flush();
table_ptr->save(epoch, mode);
auto* table_ptr = GetTable(table_id);
table_ptr->Flush();
table_ptr->Save(epoch, mode);
return done();
}
::std::future<int32_t> PsLocalClient::Save(
const LoadSaveContext& save_context) {
if (save_context.table_id < 0) {
for (auto& it : _table_map) {
save(it.first, save_context.epoch, save_context.mode);
}
return done();
} else {
auto* table_ptr = table(save_context.table_id);
table_ptr->flush();
table_ptr->save(save_context.epoch, save_context.mode);
return done();
}
}
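In both context overloads above, a negative table_id serves as an all-tables sentinel. A minimal sketch, with the epoch/mode strings and table id below as assumptions, of driving save and load through the context API:

// Hypothetical sketch: save every table, then reload only table 1.
LoadSaveContext save_ctx;
save_ctx.table_id = -1;  // negative id = every table in _table_map
save_ctx.epoch = "0";    // assumed epoch tag
save_ctx.mode = "0";     // assumed save mode
client->Save(save_ctx).wait();

LoadSaveContext load_ctx;
load_ctx.table_id = 1;   // assumed single table id
load_ctx.epoch = "0";
load_ctx.mode = "0";
client->Load(load_ctx).wait();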
::std::future<int32_t> PsLocalClient::clear() {
::std::future<int32_t> PsLocalClient::Clear() {
// TODO
return done();
}
::std::future<int32_t> PsLocalClient::clear(uint32_t table_id) {
::std::future<int32_t> PsLocalClient::Clear(uint32_t table_id) {
// TODO
return done();
}
::std::future<int32_t> PsLocalClient::flush() {
::std::future<int32_t> PsLocalClient::Flush() {
// no need
return done();
}
::std::future<int32_t> PsLocalClient::stop_server() {
::std::future<int32_t> PsLocalClient::StopServer() {
// no need
return done();
}
::std::future<int32_t> PsLocalClient::Pull(RequestContext& pull_context) {
if (pull_context.value_type == Dense) { // pull dense
Region* dense_region = reinterpret_cast<Region*>(pull_context.dense_values);
pull_dense(dense_region, pull_context.num, pull_context.table);
} else { // pull sparse
// uint64_t* keys = reinterpret_cast<uint64_t*>(pull_context.keys);
// char** select_values =
// reinterpret_cast<char**>(pull_context.sparse_values);
size_t table_id = pull_context.table;
size_t num = pull_context.num;
pull_sparse_ptr(reinterpret_cast<char**>(pull_context.sparse_values),
table_id, pull_context.keys, num);
  }
  return done();
}
::std::future<int32_t> PsLocalClient::PullDense(Region* regions,
size_t region_num,
size_t table_id) {
auto* accessor = GetTableAccessor(table_id);
auto* table_ptr = GetTable(table_id);
::std::future<int32_t> PsLocalClient::Push(RequestContext& push_context) {
if (push_context.value_type == Dense) { // push dense
if (push_context.training_phase == Init) {
const Region* regions = push_context.push_context.push_dense_values;
size_t region_num = push_context.num;
push_dense_param(regions, region_num, push_context.table);
} else {
if (push_context.training_mode == Geo) { // geo
float* total_send_data =
reinterpret_cast<float*>(push_context.dense_values);
size_t total_send_data_size = push_context.num;
push_dense_raw_gradient(push_context.table, total_send_data,
total_send_data_size, push_context.callback);
} else { // async and sync
const Region* regions = push_context.push_context.push_dense_values;
size_t region_num = push_context.num;
push_dense(regions, region_num, push_context.table);
}
}
} else { // push sparse
if (push_context.training_mode == Async) {
const uint64_t* keys = push_context.push_context.keys;
const float** update_values = push_context.push_context.push_values;
size_t table_id = push_context.table;
size_t num = push_context.num;
push_sparse(table_id, keys, update_values, num);
} else {
// TODO
}
  }
  return done();
}
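The dispatch above routes (Dense, Init) to PushDenseParam, (Dense, Geo) to PushDenseRawGradient, other dense pushes to PushDense, and (Sparse, Async) to PushSparse. A hedged sketch of the async sparse path, with keys, grads, num, and the table id as assumed caller-side data:

// Hypothetical sketch: async sparse gradient push via the unified API.
RequestContext ctx;
ctx.table = sparse_table_id;           // assumed table id
ctx.value_type = Sparse;
ctx.training_mode = Async;
ctx.push_context.keys = keys;          // const uint64_t[num]
ctx.push_context.push_values = grads;  // const float*[num]
ctx.num = num;
client->Push(ctx);                     // ends up in PushSparse(table, keys, grads, num)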
::std::future<int32_t> PsLocalClient::pull_dense(Region* regions,
size_t region_num,
size_t table_id) {
auto* accessor = table_accessor(table_id);
auto* table_ptr = table(table_id);
uint32_t num_per_shard = DenseDimPerShard(accessor->GetTableInfo(FEA_DIM), 1);
uint32_t num_per_shard =
dense_dim_per_shard(accessor->GetTableInfo(FEA_DIM), 1);
std::vector<float> region_buffer;
region_buffer.resize(num_per_shard);
table_ptr->pull_dense(region_buffer.data(), region_buffer.size());
table_ptr->PullDense(region_buffer.data(), region_buffer.size());
size_t region_idx = 0;
size_t region_data_idx = 0;
......@@ -213,48 +138,49 @@ std::future<int32_t> PsLocalClient::Load(const LoadSaveContext& load_context) {
return done();
}
::std::future<int32_t> PsLocalClient::push_dense_param(const Region* regions,
size_t region_num,
size_t table_id) {
auto* accessor = table_accessor(table_id);
auto* table_ptr = table(table_id);
::std::future<int32_t> PsLocalClient::PushDenseParam(const Region* regions,
size_t region_num,
size_t table_id) {
auto* accessor = GetTableAccessor(table_id);
auto* table_ptr = GetTable(table_id);
std::vector<float> region_buffer;
region_buffer.resize(dense_dim_per_shard(accessor->GetTableInfo(FEA_DIM), 1),
0);
region_buffer.resize(DenseDimPerShard(accessor->GetTableInfo(FEA_DIM), 1), 0);
for (size_t i = 0, offset = 0; i < region_num; ++i) {
uint32_t data_num = regions[i].size / sizeof(float);
memcpy(region_buffer.data() + offset, regions[i].data, regions[i].size);
offset += data_num;
}
// table_ptr->push_dense_param(region_buffer.data(), region_buffer.size());
// table_ptr->PushDenseParam(region_buffer.data(), region_buffer.size());
return done();
}
::std::future<int32_t> PsLocalClient::push_dense_raw_gradient(
::std::future<int32_t> PsLocalClient::PushDenseRawGradient(
int table_id, float* total_send_data, size_t total_send_data_size,
void* callback) {
VLOG(1) << "wxx push_dense_raw_gradient";
PSClientClosure* closure = reinterpret_cast<PSClientClosure*>(callback);
auto* table_ptr = table(table_id);
auto* table_ptr = GetTable(table_id);
table_ptr->push_dense(total_send_data, total_send_data_size);
table_ptr->PushDense(total_send_data, total_send_data_size);
delete closure;
return done();
}
::std::future<int32_t> PsLocalClient::push_dense(const Region* regions,
size_t region_num,
size_t table_id) {
auto* accessor = table_accessor(table_id);
auto* table_ptr = table(table_id);
::std::future<int32_t> PsLocalClient::PushDense(const Region* regions,
size_t region_num,
size_t table_id) {
auto* accessor = GetTableAccessor(table_id);
auto* table_ptr = GetTable(table_id);
std::vector<float> region_buffer;
region_buffer.resize(dense_dim_per_shard(accessor->GetTableInfo(FEA_DIM), 1));
region_buffer.resize(DenseDimPerShard(accessor->GetTableInfo(FEA_DIM), 1));
size_t data_size = region_buffer.size();
for (size_t i = 0, offset = 0; i < region_num; ++i) {
uint32_t data_num = regions[i].size / sizeof(float);
......@@ -267,12 +193,12 @@ std::future<int32_t> PsLocalClient::Load(const LoadSaveContext& load_context) {
offset += data_num;
}
table_ptr->push_dense(region_buffer.data(), region_buffer.size());
table_ptr->PushDense(region_buffer.data(), region_buffer.size());
return done();
}
//::std::future<int32_t> PsLocalClient::pull_sparse(float** select_values,
//::std::future<int32_t> PsLocalClient::PullSparse(float** select_values,
// size_t table_id,
// const uint64_t* keys,
// size_t num) {
......@@ -282,14 +208,14 @@ std::future<int32_t> PsLocalClient::Load(const LoadSaveContext& load_context) {
// // auto local_timer =
// // std::make_shared<CostTimer>("pslib_downpour_client_pull_sparse_local");
// // Split keys into per-shard requests and record the original value pointers
// auto* accessor = table_accessor(table_id);
// auto* table_ptr = table(table_id);
// auto* accessor = GetTableAccessor(table_id);
// auto* table_ptr = GetTable(table_id);
// size_t value_size = accessor->select_size();
//
// // table_ptr->pull_sparse(keys, num);
// // table_ptr->PullSparse(keys, num);
// std::vector<float> res_data;
// res_data.resize(num * value_size / sizeof(float));
// table_ptr->pull_sparse(res_data.data(), keys, num);
// table_ptr->PullSparse(res_data.data(), keys, num);
// // memcpy(select_values[0], res_data->data(), res_data->size() *
// // sizeof(float));
// size_t offset = 0;
......@@ -302,43 +228,43 @@ std::future<int32_t> PsLocalClient::Load(const LoadSaveContext& load_context) {
// return done();
//}
::std::future<int32_t> PsLocalClient::pull_sparse_ptr(char** select_values,
size_t table_id,
const uint64_t* keys,
size_t num) {
::std::future<int32_t> PsLocalClient::PullSparsePtr(char** select_values,
size_t table_id,
const uint64_t* keys,
size_t num) {
// FIXME
// auto timer =
// std::make_shared<CostTimer>("pslib_downpour_client_pull_sparse");
// auto local_timer =
// std::make_shared<CostTimer>("pslib_downpour_client_pull_sparse_local");
// Split keys into per-shard requests and record the original value pointers
auto* table_ptr = table(table_id);
auto* table_ptr = GetTable(table_id);
table_ptr->pull_sparse_ptr(select_values, keys, num);
table_ptr->PullSparsePtr(select_values, keys, num);
return done();
}
::std::future<int32_t> PsLocalClient::push_sparse_raw_gradient(
::std::future<int32_t> PsLocalClient::PushSparseRawGradient(
size_t table_id, const uint64_t* keys, const float** update_values,
size_t num, void* callback) {
PSClientClosure* closure = reinterpret_cast<PSClientClosure*>(callback);
auto* accessor = table_accessor(table_id);
auto* table_ptr = table(table_id);
auto* accessor = GetTableAccessor(table_id);
auto* table_ptr = GetTable(table_id);
table_ptr->push_sparse(keys, update_values, num);
table_ptr->PushSparse(keys, update_values, num);
delete closure;
return done();
}
::std::future<int32_t> PsLocalClient::push_sparse(size_t table_id,
const uint64_t* keys,
const float** update_values,
size_t num) {
auto* accessor = table_accessor(table_id);
auto* table_ptr = table(table_id);
::std::future<int32_t> PsLocalClient::PushSparse(size_t table_id,
const uint64_t* keys,
const float** update_values,
size_t num) {
auto* accessor = GetTableAccessor(table_id);
auto* table_ptr = GetTable(table_id);
table_ptr->push_sparse(keys, update_values, num);
table_ptr->PushSparse(keys, update_values, num);
return done();
}
}
......
......@@ -26,54 +26,46 @@ class PsLocalClient : public PSClient {
public:
PsLocalClient() {}
virtual ~PsLocalClient() { _running = false; }
virtual int32_t create_client2client_connection(int pslib_timeout_ms,
int pslib_connect_timeout_ms,
int max_retry) {
virtual int32_t CreateClient2ClientConnection(int pslib_timeout_ms,
int pslib_connect_timeout_ms,
int max_retry) {
return 0;
}
virtual ::std::future<int32_t> shrink(uint32_t table_id,
virtual ::std::future<int32_t> Shrink(uint32_t table_id,
const std::string threshold) override;
virtual ::std::future<int32_t> load(const std::string& epoch,
virtual ::std::future<int32_t> Load(const std::string& epoch,
const std::string& mode) override;
virtual ::std::future<int32_t> load(uint32_t table_id,
virtual ::std::future<int32_t> Load(uint32_t table_id,
const std::string& epoch,
const std::string& mode) override;
virtual std::future<int32_t> Load(
const LoadSaveContext& load_context) override;
virtual ::std::future<int32_t> save(const std::string& epoch,
virtual ::std::future<int32_t> Save(const std::string& epoch,
const std::string& mode) override;
virtual ::std::future<int32_t> save(uint32_t table_id,
virtual ::std::future<int32_t> Save(uint32_t table_id,
const std::string& epoch,
const std::string& mode) override;
virtual std::future<int32_t> Save(
const LoadSaveContext& save_context) override;
virtual ::std::future<int32_t> clear() override;
virtual ::std::future<int32_t> clear(uint32_t table_id) override;
virtual ::std::future<int32_t> Clear() override;
virtual ::std::future<int32_t> Clear(uint32_t table_id) override;
virtual ::std::future<int32_t> stop_server() override;
virtual ::std::future<int32_t> StopServer() override;
virtual void finalize_worker() override {}
virtual ::std::future<int32_t> pull_dense(Region* regions, size_t region_num,
size_t table_id);
virtual void FinalizeWorker() override {}
virtual ::std::future<int32_t> PullDense(Region* regions, size_t region_num,
size_t table_id);
virtual ::std::future<int32_t> Pull(RequestContext& pull_context) override;
virtual ::std::future<int32_t> PushDense(const Region* regions,
size_t region_num, size_t table_id);
virtual ::std::future<int32_t> Push(RequestContext& push_context) override;
virtual ::std::future<int32_t> PushDenseParam(const Region* regions,
size_t region_num,
size_t table_id);
virtual ::std::future<int32_t> push_dense(const Region* regions,
size_t region_num, size_t table_id);
virtual ::std::future<int32_t> push_dense_param(const Region* regions,
size_t region_num,
size_t table_id);
virtual ::std::future<int32_t> pull_sparse(float** select_values,
size_t table_id,
const uint64_t* keys, size_t num,
bool is_training) {
virtual ::std::future<int32_t> PullSparse(float** select_values,
size_t table_id,
const uint64_t* keys, size_t num,
bool is_training) {
std::promise<int32_t> prom;
std::future<int32_t> fut = prom.get_future();
prom.set_value(0);
......@@ -81,26 +73,26 @@ class PsLocalClient : public PSClient {
return fut;
}
virtual ::std::future<int32_t> pull_sparse_ptr(char** select_values,
size_t table_id,
const uint64_t* keys,
size_t num);
virtual ::std::future<int32_t> PullSparsePtr(char** select_values,
size_t table_id,
const uint64_t* keys,
size_t num);
virtual ::std::future<int32_t> print_table_stat(uint32_t table_id) {
virtual ::std::future<int32_t> PrintTableStat(uint32_t table_id) {
std::promise<int32_t> prom;
std::future<int32_t> fut = prom.get_future();
prom.set_value(0);
return fut;
}
virtual ::std::future<int32_t> push_sparse(size_t table_id,
const uint64_t* keys,
const float** update_values,
size_t num);
virtual ::std::future<int32_t> PushSparse(size_t table_id,
const uint64_t* keys,
const float** update_values,
size_t num);
virtual ::std::future<int32_t> flush();
virtual ::std::future<int32_t> Flush();
// server profiler
virtual std::future<int32_t> start_profiler() {
virtual std::future<int32_t> StartProfiler() {
std::promise<int32_t> prom;
std::future<int32_t> fut = prom.get_future();
prom.set_value(0);
......@@ -108,7 +100,7 @@ class PsLocalClient : public PSClient {
return fut;
};
virtual std::future<int32_t> stop_profiler() {
virtual std::future<int32_t> StopProfiler() {
std::promise<int32_t> prom;
std::future<int32_t> fut = prom.get_future();
prom.set_value(0);
......@@ -116,7 +108,7 @@ class PsLocalClient : public PSClient {
return fut;
}
virtual std::future<int32_t> barrier(size_t table_id, uint32_t barrier_type) {
virtual std::future<int32_t> Barrier(size_t table_id, uint32_t barrier_type) {
std::promise<int32_t> prom;
std::future<int32_t> fut = prom.get_future();
prom.set_value(0);
......@@ -124,10 +116,10 @@ class PsLocalClient : public PSClient {
return fut;
}
virtual std::future<int32_t> pull_geo_param(size_t table_id,
std::vector<float>* values,
std::vector<uint64_t>* keys,
int pserver_idx) {
virtual std::future<int32_t> PullGeoParam(size_t table_id,
std::vector<float>* values,
std::vector<uint64_t>* keys,
int pserver_idx) {
std::promise<int32_t> prom;
std::future<int32_t> fut = prom.get_future();
prom.set_value(0);
......@@ -135,9 +127,9 @@ class PsLocalClient : public PSClient {
return fut;
}
virtual std::future<int32_t> push_global_step(int table_id,
int64_t* total_send_data,
void* done) {
virtual std::future<int32_t> PushGlobalStep(int table_id,
int64_t* total_send_data,
void* done) {
std::promise<int32_t> prom;
std::future<int32_t> fut = prom.get_future();
prom.set_value(0);
......@@ -146,12 +138,12 @@ class PsLocalClient : public PSClient {
}
// recv table from server and save it in LodTensor
virtual int32_t recv_and_save_table(const uint64_t table_id,
const std::string& path) {
virtual int32_t RecvAndSaveTable(const uint64_t table_id,
const std::string& path) {
return 0;
}
virtual ::std::future<int32_t> send_client2client_msg(
virtual ::std::future<int32_t> SendClient2ClientMsg(
int msg_type, int to_client_id, const std::string& msg) override {
std::promise<int32_t> prom;
std::future<int32_t> fut = prom.get_future();
......@@ -159,17 +151,18 @@ class PsLocalClient : public PSClient {
return fut;
}
virtual size_t get_server_nums() { return 1; }
virtual size_t GetServerNums() { return 1; }
virtual std::future<int32_t> push_dense_raw_gradient(
int table_id, float* total_send_data, size_t total_send_data_size,
void* callback) override;
virtual std::future<int32_t> PushDenseRawGradient(int table_id,
float* total_send_data,
size_t total_send_data_size,
void* callback) override;
virtual std::future<int32_t> push_sparse_raw_gradient(
virtual std::future<int32_t> PushSparseRawGradient(
size_t table_id, const uint64_t* keys, const float** update_values,
size_t num, void* callback) override;
virtual std::future<int32_t> push_sparse_raw_gradient_partial(
virtual std::future<int32_t> PushSparseRawGradientPartial(
size_t table_id, const uint64_t* keys, const float** update_values,
uint32_t num, void* done, int pserver_idx) override {
std::promise<int32_t> prom;
......@@ -179,11 +172,11 @@ class PsLocalClient : public PSClient {
return fut;
}
virtual std::future<int32_t> push_sparse_param(size_t table_id,
const uint64_t* keys,
const float** update_values,
size_t num,
void* done) override {
virtual std::future<int32_t> PushSparseParam(size_t table_id,
const uint64_t* keys,
const float** update_values,
size_t num,
void* done) override {
std::promise<int32_t> prom;
std::future<int32_t> fut = prom.get_future();
prom.set_value(0);
......@@ -192,7 +185,7 @@ class PsLocalClient : public PSClient {
}
private:
virtual int32_t initialize() override;
virtual int32_t Initialize() override;
std::future<int32_t> done() {
std::shared_ptr<std::promise<int32_t>> prom =
......@@ -202,16 +195,16 @@ class PsLocalClient : public PSClient {
return fut;
}
inline uint32_t dense_dim_per_shard(uint32_t dense_dim_total,
uint32_t shard_num) {
inline uint32_t DenseDimPerShard(uint32_t dense_dim_total,
uint32_t shard_num) {
return dense_dim_total / shard_num + 1;
}
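A quick worked example of the padding above: with dense_dim_total = 100 and shard_num = 1 (the local client always runs as a single shard), DenseDimPerShard returns 100 / 1 + 1 = 101, so the shard buffer is allocated one slot beyond the raw dimension.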
inline std::unordered_map<uint32_t, std::shared_ptr<Table>>* table() {
inline std::unordered_map<uint32_t, std::shared_ptr<Table>>* GetTable() {
return &_table_map;
}
inline Table* table(size_t table_id) {
inline Table* GetTable(size_t table_id) {
auto itr = _table_map.find(table_id);
if (itr != _table_map.end()) {
return itr->second.get();
......
......@@ -25,17 +25,17 @@ class PsLocalServer : public PSServer {
public:
PsLocalServer() {}
virtual ~PsLocalServer() {}
virtual uint64_t start() { return 0; }
virtual uint64_t start(const std::string &ip, uint32_t port) { return 0; }
virtual int32_t stop() { return 0; }
virtual int32_t configure(
virtual uint64_t Start() { return 0; }
virtual uint64_t Start(const std::string &ip, uint32_t port) { return 0; }
virtual int32_t Stop() { return 0; }
virtual int32_t Configure(
const PSParameter &config, PSEnvironment &env, size_t server_rank,
const std::vector<framework::ProgramDesc> &server_sub_program = {}) {
return 0;
}
private:
virtual int32_t initialize() { return 0; }
virtual int32_t Initialize() { return 0; }
};
}
}
......@@ -70,7 +70,7 @@ void GraphPyService::set_up(std::string ips_str, int shard_num,
port_list.push_back(ip_and_port[1]);
uint32_t port = stoul(ip_and_port[1]);
auto ph_host = paddle::distributed::PSHost(ip_and_port[0], port, index);
host_sign_list.push_back(ph_host.serialize_to_string());
host_sign_list.push_back(ph_host.SerializeToString());
index++;
}
}
......@@ -83,11 +83,11 @@ void GraphPyClient::start_client() {
paddle::distributed::PaddlePSEnvironment _ps_env;
auto servers_ = host_sign_list.size();
_ps_env = paddle::distributed::PaddlePSEnvironment();
_ps_env.set_ps_servers(&host_sign_list, servers_);
_ps_env.SetPsServers(&host_sign_list, servers_);
worker_ptr = std::shared_ptr<paddle::distributed::GraphBrpcClient>(
(paddle::distributed::GraphBrpcClient*)
paddle::distributed::PSClientFactory::create(worker_proto));
worker_ptr->configure(worker_proto, dense_regions, _ps_env, client_id);
paddle::distributed::PSClientFactory::Create(worker_proto));
worker_ptr->Configure(worker_proto, dense_regions, _ps_env, client_id);
worker_ptr->set_shard_num(get_shard_num());
}
void GraphPyServer::start_server(bool block) {
......@@ -96,17 +96,17 @@ void GraphPyServer::start_server(bool block) {
::paddle::distributed::PSParameter server_proto = this->GetServerProto();
auto _ps_env = paddle::distributed::PaddlePSEnvironment();
_ps_env.set_ps_servers(&this->host_sign_list,
this->host_sign_list.size()); // test
_ps_env.SetPsServers(&this->host_sign_list,
this->host_sign_list.size()); // test
pserver_ptr = std::shared_ptr<paddle::distributed::GraphBrpcServer>(
(paddle::distributed::GraphBrpcServer*)
paddle::distributed::PSServerFactory::create(server_proto));
paddle::distributed::PSServerFactory::Create(server_proto));
VLOG(0) << "pserver-ptr created ";
std::vector<framework::ProgramDesc> empty_vec;
framework::ProgramDesc empty_prog;
empty_vec.push_back(empty_prog);
pserver_ptr->configure(server_proto, _ps_env, rank, empty_vec);
pserver_ptr->start(ip, port);
pserver_ptr->Configure(server_proto, _ps_env, rank, empty_vec);
pserver_ptr->Start(ip, port);
pserver_ptr->build_peer2peer_connection(rank);
std::condition_variable* cv_ = pserver_ptr->export_cv();
if (block) {
......@@ -246,7 +246,7 @@ void GraphPyClient::load_edge_file(std::string name, std::string filepath,
VLOG(0) << "loadding data with type " << name << " from " << filepath;
uint32_t table_id = this->table_id_map[name];
auto status =
get_ps_client()->load(table_id, std::string(filepath), params);
get_ps_client()->Load(table_id, std::string(filepath), params);
status.wait();
}
}
......@@ -285,7 +285,7 @@ void GraphPyClient::load_node_file(std::string name, std::string filepath) {
if (this->table_id_map.count(name)) {
uint32_t table_id = this->table_id_map[name];
auto status =
get_ps_client()->load(table_id, std::string(filepath), params);
get_ps_client()->Load(table_id, std::string(filepath), params);
status.wait();
}
}
......@@ -396,13 +396,13 @@ std::vector<FeatureNode> GraphPyClient::pull_graph_list(std::string name,
return res;
}
void GraphPyClient::stop_server() {
void GraphPyClient::StopServer() {
VLOG(0) << "going to stop server";
std::unique_lock<std::mutex> lock(mutex_);
if (stoped_) return;
auto status = this->worker_ptr->stop_server();
auto status = this->worker_ptr->StopServer();
if (status.get() == 0) stoped_ = true;
}
void GraphPyClient::finalize_worker() { this->worker_ptr->finalize_worker(); }
void GraphPyClient::FinalizeWorker() { this->worker_ptr->FinalizeWorker(); }
}
}
......@@ -123,7 +123,7 @@ class GraphPyServer : public GraphPyService {
set_rank(rank);
GraphPyService::set_up(ips_str, shard_num, node_types, edge_types);
}
int get_rank() { return rank; }
int GetRank() { return rank; }
void set_rank(int rank) { this->rank = rank; }
void start_server(bool block = true);
......@@ -154,8 +154,8 @@ class GraphPyClient : public GraphPyService {
(paddle::distributed::GraphBrpcService*)server.get_ps_server()
->get_service());
}
void stop_server();
void finalize_worker();
void StopServer();
void FinalizeWorker();
void load_edge_file(std::string name, std::string filepath, bool reverse);
void load_node_file(std::string name, std::string filepath);
void clear_nodes(std::string name);
......
......@@ -46,7 +46,7 @@ paddle::distributed::PSParameter load_from_prototxt(
return param;
}
void PSCore::init_gflag(const std::string& gflags) {
void PSCore::InitGFlag(const std::string& gflags) {
VLOG(3) << "Init With Gflags:" << gflags;
std::vector<std::string> flags = paddle::string::split_string(gflags);
if (flags.size() < 1) {
......@@ -65,67 +65,67 @@ void PSCore::init_gflag(const std::string& gflags) {
::GFLAGS_NAMESPACE::ParseCommandLineFlags(&params_cnt, &params_ptr, true);
}
int PSCore::init_server(
int PSCore::InitServer(
const std::string& dist_desc,
const std::vector<std::string>* host_sign_list, int node_num, int index,
int trainers,
const std::vector<framework::ProgramDesc>& server_sub_program) {
google::protobuf::TextFormat::ParseFromString(dist_desc, &_ps_param);
init_gflag(_ps_param.init_gflags());
InitGFlag(_ps_param.init_gflags());
_ps_env = paddle::distributed::PaddlePSEnvironment();
_ps_env.set_ps_servers(host_sign_list, node_num);
_ps_env.set_trainers(trainers);
_ps_env.SetPsServers(host_sign_list, node_num);
_ps_env.SetTrainers(trainers);
int ret = 0;
_server_ptr = std::shared_ptr<paddle::distributed::PSServer>(
paddle::distributed::PSServerFactory::create(_ps_param));
ret = _server_ptr->configure(_ps_param, _ps_env, index, server_sub_program);
paddle::distributed::PSServerFactory::Create(_ps_param));
ret = _server_ptr->Configure(_ps_param, _ps_env, index, server_sub_program);
CHECK(ret == 0) << "failed to configure server";
return ret;
}
int PSCore::init_worker(
int PSCore::InitWorker(
const std::string& dist_desc,
const std::map<uint64_t, std::vector<paddle::distributed::Region>>& regions,
const std::vector<std::string>* host_sign_list, int node_num, int index) {
google::protobuf::TextFormat::ParseFromString(dist_desc, &_ps_param);
init_gflag(_ps_param.init_gflags());
InitGFlag(_ps_param.init_gflags());
_ps_env = paddle::distributed::PaddlePSEnvironment();
_ps_env.set_ps_servers(host_sign_list, node_num);
_ps_env.SetPsServers(host_sign_list, node_num);
int ret = 0;
VLOG(1) << "PSCore::init_worker";
VLOG(1) << "PSCore::InitWorker";
auto* communicator = Communicator::GetInstance();
ret = communicator->GetPsClient()->configure(_ps_param, regions, _ps_env,
ret = communicator->GetPsClient()->Configure(_ps_param, regions, _ps_env,
index);
communicator->Start();
return ret;
}
std::vector<uint64_t> PSCore::get_client_info() {
return _ps_env.get_client_info();
std::vector<uint64_t> PSCore::GetClientInfo() {
return _ps_env.GetClientInfo();
}
int PSCore::create_client2client_connection(int pserver_timeout_ms,
int pserver_connect_timeout_ms,
int max_retry) {
int ret = _worker_ptr->create_client2client_connection(
int PSCore::CreateClient2ClientConnection(int pserver_timeout_ms,
int pserver_connect_timeout_ms,
int max_retry) {
int ret = _worker_ptr->CreateClient2ClientConnection(
pserver_timeout_ms, pserver_connect_timeout_ms, max_retry);
return ret;
}
uint64_t PSCore::run_server(const std::string& ip, uint32_t port) {
return _server_ptr->start(ip, port);
uint64_t PSCore::RunServer(const std::string& ip, uint32_t port) {
return _server_ptr->Start(ip, port);
}
int PSCore::finalize_worker() {
_worker_ptr->finalize_worker();
int PSCore::FinalizeWorker() {
_worker_ptr->FinalizeWorker();
return 0;
}
int PSCore::stop_server() {
auto stop_status = _worker_ptr->stop_server();
int PSCore::StopServer() {
auto stop_status = _worker_ptr->StopServer();
stop_status.wait();
return 0;
}
paddle::distributed::PSParameter* PSCore::get_param() { return &_ps_param; }
paddle::distributed::PSParameter* PSCore::GetParam() { return &_ps_param; }
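Taken together, a hedged single-node lifecycle sketch for PSCore; the descriptor string, host list, dense regions, and address are all assumed to be prepared by the caller:

// Hypothetical sketch: bring up one server and one worker, then shut down.
PSCore core;
core.InitServer(dist_desc, &host_sign_list, /*node_num=*/1, /*index=*/0,
                /*trainers=*/1);
core.RunServer("127.0.0.1", 8200);  // assumed ip/port
core.InitWorker(dist_desc, regions, &host_sign_list, /*node_num=*/1,
                /*index=*/0);
// ... training through the communicator ...
core.StopServer();
core.FinalizeWorker();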
} // namespace distributed
} // namespace paddle
......@@ -42,31 +42,31 @@ class PSCore {
explicit PSCore() {}
virtual ~PSCore() {}
virtual int init_server(
virtual int InitServer(
const std::string& dist_desc,
const std::vector<std::string>* host_sign_list, int node_num, int index,
int trainers,
const std::vector<framework::ProgramDesc>& server_sub_program = {});
virtual int init_worker(
virtual int InitWorker(
const std::string& dist_desc,
const std::map<uint64_t, std::vector<paddle::distributed::Region>>&
regions,
const std::vector<std::string>* host_sign_list, int node_num, int index);
virtual uint64_t run_server(const std::string& ip, uint32_t port);
virtual int stop_server();
virtual int finalize_worker();
virtual std::vector<uint64_t> get_client_info();
virtual int create_client2client_connection(int pserver_timeout_ms,
int pserver_connect_timeout_ms,
int max_retry);
virtual uint64_t RunServer(const std::string& ip, uint32_t port);
virtual int StopServer();
virtual int FinalizeWorker();
virtual std::vector<uint64_t> GetClientInfo();
virtual int CreateClient2ClientConnection(int pserver_timeout_ms,
int pserver_connect_timeout_ms,
int max_retry);
std::shared_ptr<paddle::distributed::PSServer>
_server_ptr; // pointer to server
std::shared_ptr<paddle::distributed::PSClient>
_worker_ptr; // pointer to worker
virtual paddle::distributed::PSParameter* get_param();
virtual paddle::distributed::PSParameter* GetParam();
private:
void init_gflag(const std::string& gflags);
void InitGFlag(const std::string& gflags);
paddle::distributed::PSParameter _ps_param;
paddle::distributed::PaddlePSEnvironment _ps_env;
};
......
......@@ -29,7 +29,7 @@ REGISTER_PSCORE_CLASS(PsBaseService, BrpcPsService);
REGISTER_PSCORE_CLASS(PSServer, GraphBrpcServer);
REGISTER_PSCORE_CLASS(PsBaseService, GraphBrpcService);
PSServer *PSServerFactory::create(const PSParameter &ps_config) {
PSServer *PSServerFactory::Create(const PSParameter &ps_config) {
const auto &config = ps_config.server_param();
if (!config.has_downpour_server_param()) {
......@@ -56,18 +56,18 @@ PSServer *PSServerFactory::create(const PSParameter &ps_config) {
<< service_param.server_class();
return NULL;
}
TableManager::instance().initialize();
TableManager::Instance().Initialize();
return server;
}
int32_t PSServer::configure(
int32_t PSServer::Configure(
const PSParameter &config, PSEnvironment &env, size_t server_rank,
const std::vector<framework::ProgramDesc> &server_sub_program) {
scope_.reset(new framework::Scope());
_config = config.server_param();
_rank = server_rank;
_environment = &env;
size_t shard_num = env.get_ps_servers().size();
size_t shard_num = env.GetPsServers().size();
const auto &downpour_param = _config.downpour_server_param();
......@@ -87,21 +87,21 @@ int32_t PSServer::configure(
global_step_table = downpour_param.downpour_table_param(i).table_id();
}
table->set_program_env(scope_.get(), place_, &server_sub_program);
table->set_shard(_rank, shard_num);
table->initialize(downpour_param.downpour_table_param(i),
table->SetProgramEnv(scope_.get(), place_, &server_sub_program);
table->SetShard(_rank, shard_num);
table->Initialize(downpour_param.downpour_table_param(i),
config.fs_client_param());
_table_map[downpour_param.downpour_table_param(i).table_id()].reset(table);
}
if (barrier_table != UINT32_MAX) {
_table_map[barrier_table]->set_table_map(&_table_map);
_table_map[barrier_table]->SetTableMap(&_table_map);
}
if (global_step_table != UINT32_MAX) {
_table_map[global_step_table]->set_table_map(&_table_map);
_table_map[global_step_table]->SetTableMap(&_table_map);
}
return initialize();
return Initialize();
}
} // namespace distributed
} // namespace paddle
......@@ -65,19 +65,19 @@ class PSServer {
PSServer(PSServer &&) = delete;
PSServer(const PSServer &) = delete;
virtual int32_t configure(
virtual int32_t Configure(
const PSParameter &config, PSEnvironment &env, size_t server_rank,
const std::vector<framework::ProgramDesc> &server_sub_program = {});
virtual uint64_t start(const std::string &ip, uint32_t port) = 0;
virtual int32_t stop() = 0;
virtual uint64_t Start(const std::string &ip, uint32_t port) = 0;
virtual int32_t Stop() = 0;
inline size_t rank() const { return _rank; }
inline size_t Rank() const { return _rank; }
inline PSEnvironment *environment() { return _environment; }
inline PSEnvironment *Environment() { return _environment; }
inline const ServerParameter *config() const { return &_config; }
inline Table *table(size_t table_id) {
inline const ServerParameter *Config() const { return &_config; }
inline Table *GetTable(size_t table_id) {
auto itr = _table_map.find(table_id);
if (itr != _table_map.end()) {
return itr->second.get();
......@@ -85,12 +85,12 @@ class PSServer {
return NULL;
}
inline std::unordered_map<uint32_t, std::shared_ptr<Table>> *table() {
inline std::unordered_map<uint32_t, std::shared_ptr<Table>> *GetTable() {
return &_table_map;
}
protected:
virtual int32_t initialize() = 0;
virtual int32_t Initialize() = 0;
protected:
size_t _rank;
......@@ -129,11 +129,11 @@ class PsBaseService : public PsService {
public:
PsBaseService() : _rank(0), _server(NULL), _config(NULL) {}
virtual ~PsBaseService() {}
virtual size_t get_rank() { return _rank; }
virtual int32_t configure(PSServer *server) {
virtual size_t GetRank() { return _rank; }
virtual int32_t Configure(PSServer *server) {
_server = server;
_rank = _server->rank();
_config = _server->config();
_rank = _server->Rank();
_config = _server->Config();
return 0;
}
virtual void service(::google::protobuf::RpcController *controller,
......@@ -148,8 +148,8 @@ class PsBaseService : public PsService {
LOG(WARNING) << "Resonse err_code:" << err_code << " msg:" << err_msg;
}
virtual int32_t initialize() = 0;
PSServer *get_server() { return _server; }
virtual int32_t Initialize() = 0;
PSServer *GetServer() { return _server; }
protected:
size_t _rank;
......@@ -160,7 +160,7 @@ REGISTER_PSCORE_REGISTERER(PsBaseService);
class PSServerFactory {
public:
static PSServer *create(const PSParameter &config);
static PSServer *Create(const PSParameter &config);
};
} // namespace distributed
} // namespace paddle
......@@ -17,7 +17,7 @@
namespace paddle {
namespace distributed {
int32_t BarrierTable::initialize() {
int32_t BarrierTable::Initialize() {
auto trainers = _config.common().trainer_num();
trigger_.store(trainers);
......@@ -29,7 +29,7 @@ int32_t BarrierTable::initialize() {
}
// 0: send_barrier 1: recv_barrier 2: complete
int32_t BarrierTable::barrier(const uint32_t trainer_id,
int32_t BarrierTable::Barrier(const uint32_t trainer_id,
const std::string barrier_type) {
std::unique_lock<std::mutex> lock(mutex_);
......@@ -56,7 +56,7 @@ int32_t BarrierTable::barrier(const uint32_t trainer_id,
VLOG(1) << "barrier table optimize begin";
for (auto& x : *table_map_) {
auto table = x.second;
table->pour();
table->Pour();
}
VLOG(1) << "barrier table optimize done";
......@@ -66,7 +66,7 @@ int32_t BarrierTable::barrier(const uint32_t trainer_id,
return 0;
}
int32_t BarrierTable::set_table_map(
int32_t BarrierTable::SetTableMap(
std::unordered_map<uint32_t, std::shared_ptr<Table>>* table_map) {
table_map_ = table_map;
return 0;
......
......@@ -21,8 +21,8 @@ namespace distributed {
int FLAGS_pslib_table_save_max_retry_dense = 3;
void CommonDenseTable::create_initializer(const std::string& attr,
const std::string& name) {
void CommonDenseTable::CreateInitializer(const std::string& attr,
const std::string& name) {
auto slices = string::split_string<std::string>(attr, "&");
if (slices[0] == "gaussian_random") {
......@@ -39,7 +39,7 @@ void CommonDenseTable::create_initializer(const std::string& attr,
}
}
int32_t CommonDenseTable::initialize() {
int32_t CommonDenseTable::Initialize() {
_shards_task_pool.resize(task_pool_size_);
for (int i = 0; i < _shards_task_pool.size(); ++i) {
_shards_task_pool[i].reset(new ::ThreadPool(1));
......@@ -49,12 +49,12 @@ int32_t CommonDenseTable::initialize() {
VLOG(1) << "table " << _config.common().table_name() << " is sync: " << sync;
_global_lr = new float(1.0);
initialize_value();
initialize_optimizer();
InitializeValue();
InitializeOptimizer();
return 0;
}
int32_t CommonDenseTable::initialize_value() {
int32_t CommonDenseTable::InitializeValue() {
auto common = _config.common();
int size = static_cast<int>(common.params().size());
values_.resize(size);
......@@ -70,7 +70,7 @@ int32_t CommonDenseTable::initialize_value() {
auto& initializer = common.initializers()[x];
total_dim_ += dim;
create_initializer(initializer, varname);
CreateInitializer(initializer, varname);
values_[x].resize(dim);
names_index_[varname] = x;
......@@ -92,27 +92,27 @@ int32_t CommonDenseTable::initialize_value() {
param_col_ids_.insert(param_col_ids_.begin() + 1, -1);
}
VLOG(1) << "CommonDenseTable::initialize_value total dim: " << total_dim_
VLOG(1) << "CommonDenseTable::InitializeValue total dim: " << total_dim_
<< " fixed_len_params_dim: " << fixed_len_params_dim_;
pull_reservoir_ = ReservoirValue<float>(param_dim_);
return 0;
}
int32_t CommonDenseTable::initialize_optimizer() {
int32_t CommonDenseTable::InitializeOptimizer() {
auto common = _config.common();
auto name = common.name();
auto attrs = common.attributes();
if (name == "sgd") {
optimizer_ = std::make_shared<DSGD>(common, &values_);
optimizer_->set_global_lr(_global_lr);
optimizer_->SetGlobalLR(_global_lr);
} else if (name == "adam") {
optimizer_ = std::make_shared<DAdam>(common, &values_);
optimizer_->set_global_lr(_global_lr);
optimizer_->SetGlobalLR(_global_lr);
} else if (name == "adam_d2sum") {
optimizer_ = std::make_shared<DAdamD2Sum>(common, &values_);
// optimizer_->set_global_lr(_global_lr); // not used
// optimizer_->SetGlobalLR(_global_lr); // not used
} else if (name == "sum") {
optimizer_ = std::make_shared<DSUM>(common, &values_);
} else if (name == "summary") {
......@@ -124,34 +124,34 @@ int32_t CommonDenseTable::initialize_optimizer() {
return 0;
}
int32_t CommonDenseTable::set_global_lr(float* lr) {
int32_t CommonDenseTable::SetGlobalLR(float* lr) {
_global_lr = lr;
optimizer_->set_global_lr(_global_lr);
optimizer_->SetGlobalLR(_global_lr);
return 0;
}
int32_t CommonDenseTable::Pull(TableContext& context) {
CHECK(context.value_type == Dense);
float* pull_values = context.pull_context.values;
return pull_dense(pull_values, context.num);
return PullDense(pull_values, context.num);
}
int32_t CommonDenseTable::Push(TableContext& context) {
CHECK(context.value_type == Dense);
if (context.push_context.values != nullptr) {
const float* values = context.push_context.values;
return push_dense(values, context.num);
return PushDense(values, context.num);
}
return 0;
}
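For context, a minimal sketch of exercising this table through the two context entry points above; param_dim and the two buffers are assumed to match the table's param_dim_:

// Hypothetical sketch: one dense pull followed by one gradient push.
TableContext ctx;
ctx.value_type = Dense;
ctx.num = param_dim;                 // assumed = table's param_dim_
ctx.pull_context.values = pull_buf;  // float[param_dim], filled by Pull
table.Pull(ctx);                     // forwards to PullDense
ctx.push_context.values = grad_buf;  // float[param_dim], consumed by Push
table.Push(ctx);                     // forwards to PushDense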
int32_t CommonDenseTable::pull_dense(float* pull_values, size_t num) {
int32_t CommonDenseTable::PullDense(float* pull_values, size_t num) {
std::copy(values_[param_idx_].begin(), values_[param_idx_].end(),
pull_values);
return 0;
}
int32_t CommonDenseTable::push_dense_param(const float* values, size_t num) {
int32_t CommonDenseTable::PushDenseParam(const float* values, size_t num) {
PADDLE_ENFORCE_GE(
num, param_dim_,
paddle::platform::errors::InvalidArgument(
......@@ -160,14 +160,14 @@ int32_t CommonDenseTable::push_dense_param(const float* values, size_t num) {
return 0;
}
int32_t CommonDenseTable::pour() {
int32_t CommonDenseTable::Pour() {
pull_reservoir_.avg();
_push_dense(pull_reservoir_.values.data(), pull_reservoir_.values.size());
_PushDense(pull_reservoir_.values.data(), pull_reservoir_.values.size());
pull_reservoir_.reset();
return 0;
}
int32_t CommonDenseTable::push_dense(const float* values, size_t num) {
int32_t CommonDenseTable::PushDense(const float* values, size_t num) {
if (sync) {
std::future<int> task =
_shards_task_pool[0]->enqueue([this, &values]() -> int {
......@@ -176,12 +176,12 @@ int32_t CommonDenseTable::push_dense(const float* values, size_t num) {
});
task.wait();
} else {
_push_dense(values, num);
_PushDense(values, num);
}
return 0;
}
int32_t CommonDenseTable::_push_dense(const float* values, size_t num) {
int32_t CommonDenseTable::_PushDense(const float* values, size_t num) {
PADDLE_ENFORCE_GE(
num, param_dim_,
paddle::platform::errors::InvalidArgument(
......@@ -195,7 +195,7 @@ int32_t CommonDenseTable::_push_dense(const float* values, size_t num) {
[this, shard_id, &buckets, &values]() -> int {
auto begin = buckets[shard_id];
auto end = buckets[shard_id + 1];
optimizer_->update(values, param_dim_, begin, end);
optimizer_->Update(values, param_dim_, begin, end);
return 0;
});
}
......@@ -207,12 +207,12 @@ int32_t CommonDenseTable::_push_dense(const float* values, size_t num) {
return 0;
}
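The buckets above carve param_dim_ into task_pool_size_ contiguous [begin, end) slices, one per worker thread. A hedged sketch of an even-split helper consistent with that usage (the real bucket helper in the codebase may differ in detail):

// Hypothetical even-split: shards + 1 boundaries over dim elements.
std::vector<int> EvenBuckets(int dim, int shards) {
  std::vector<int> b(shards + 1, 0);
  for (int i = 0; i <= shards; ++i) {
    b[i] = static_cast<int>(static_cast<int64_t>(dim) * i / shards);
  }
  return b;
}
// EvenBuckets(10, 4) -> {0, 2, 5, 7, 10}: four near-equal slices.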
int32_t CommonDenseTable::load(const std::string& path,
int32_t CommonDenseTable::Load(const std::string& path,
const std::string& param) {
if (param_dim_ <= 0) {
return 0;
}
std::string table_path = table_dir(path);
std::string table_path = TableDir(path);
auto file_list = _afs_client.list(table_path);
std::sort(file_list.begin(), file_list.end());
for (auto ff : file_list) {
......@@ -314,7 +314,7 @@ int32_t CommonDenseTable::load(const std::string& path,
return 0;
}
int32_t CommonDenseTable::save(const std::string& path,
int32_t CommonDenseTable::Save(const std::string& path,
const std::string& param) {
int save_param = atoi(param.c_str());
uint32_t feasign_size;
......@@ -323,10 +323,10 @@ int32_t CommonDenseTable::save(const std::string& path,
FsChannelConfig channel_config;
if (_config.compress_in_save()) {
channel_config.path = paddle::string::format_string(
"%s/part-%03d.gz", table_dir(path).c_str(), _shard_idx);
"%s/part-%03d.gz", TableDir(path).c_str(), _shard_idx);
} else {
channel_config.path = paddle::string::format_string(
"%s/part-%03d", table_dir(path).c_str(), _shard_idx);
"%s/part-%03d", TableDir(path).c_str(), _shard_idx);
}
_afs_client.remove(channel_config.path);
channel_config.converter = _value_accesor->Converter(save_param).converter;
......
......@@ -34,29 +34,29 @@ class CommonDenseTable : public DenseTable {
public:
CommonDenseTable() {}
virtual ~CommonDenseTable() {}
int32_t initialize() override;
int32_t initialize_shard() override { return 0; }
virtual void create_initializer(const std::string& attr,
const std::string& name);
virtual int32_t initialize_value();
virtual int32_t initialize_optimizer();
int32_t Initialize() override;
int32_t InitializeShard() override { return 0; }
virtual void CreateInitializer(const std::string& attr,
const std::string& name);
virtual int32_t InitializeValue();
virtual int32_t InitializeOptimizer();
virtual int32_t Pull(TableContext& context);
virtual int32_t Push(TableContext& context);
int32_t pull_dense(float* pull_values, size_t num) override;
int32_t push_dense_param(const float* values, size_t num) override;
int32_t push_dense(const float* values, size_t num) override;
int32_t pour() override;
int32_t set_global_lr(float* lr) override;
int32_t PullDense(float* pull_values, size_t num) override;
int32_t PushDenseParam(const float* values, size_t num) override;
int32_t PushDense(const float* values, size_t num) override;
int32_t Pour() override;
int32_t SetGlobalLR(float* lr) override;
int32_t load(const std::string& path, const std::string& param) override;
int32_t save(const std::string& path, const std::string& param) override;
int32_t Load(const std::string& path, const std::string& param) override;
int32_t Save(const std::string& path, const std::string& param) override;
int32_t flush() override { return 0; }
int32_t shrink(const std::string& param) override { return 0; }
void clear() override { return; }
int32_t Flush() override { return 0; }
int32_t Shrink(const std::string& param) override { return 0; }
void Clear() override { return; }
protected:
int32_t _push_dense(const float* values, size_t num);
int32_t _PushDense(const float* values, size_t num);
private:
const int task_pool_size_ = 10;
......
......@@ -448,7 +448,7 @@ int32_t GraphTable::load_graph_split_config(const std::string &path) {
return 0;
}
int32_t GraphTable::load(const std::string &path, const std::string &param) {
int32_t GraphTable::Load(const std::string &path, const std::string &param) {
bool load_edge = (param[0] == 'e');
bool load_node = (param[0] == 'n');
if (load_edge) {
......@@ -1066,11 +1066,11 @@ int32_t GraphTable::pull_graph_list(int start, int total_size,
int32_t GraphTable::get_server_index_by_id(int64_t id) {
return id % shard_num / shard_num_per_server;
}
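A quick worked example of the mapping above: with shard_num = 8 and shard_num_per_server = 2 (so four servers), id = 11 lands in shard 11 % 8 = 3, which lives on server 3 / 2 = 1.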
int32_t GraphTable::initialize(const TableParameter &config,
int32_t GraphTable::Initialize(const TableParameter &config,
const FsClientParameter &fs_config) {
LOG(INFO) << "in graphTable initialize";
_config = config;
if (initialize_accessor() != 0) {
if (InitializeAccessor() != 0) {
LOG(WARNING) << "Table accessor initialize failed";
return -1;
}
......@@ -1082,9 +1082,9 @@ int32_t GraphTable::initialize(const TableParameter &config,
auto graph = config.graph_parameter();
shard_num = _config.shard_num();
LOG(INFO) << "in graphTable initialize over";
return initialize(graph);
return Initialize(graph);
}
int32_t GraphTable::initialize(const GraphParameter &graph) {
int32_t GraphTable::Initialize(const GraphParameter &graph) {
#ifdef PADDLE_WITH_HETERPS
if (graph.gpups_mode()) {
gpups_mode = true;
......
......@@ -280,7 +280,7 @@ class ScaledLRU {
}
}
auto status =
thread_pool->enqueue([this]() -> int { return shrink(); });
thread_pool->enqueue([this]() -> int { return Shrink(); });
status.wait();
}
});
......@@ -298,7 +298,7 @@ class ScaledLRU {
LRUResponse insert(size_t index, K *keys, V *data, size_t length) {
return lru_pool[index].insert(keys, data, length);
}
int shrink() {
int Shrink() {
int node_size = 0;
for (size_t i = 0; i < lru_pool.size(); i++) {
node_size += lru_pool[i].node_size - lru_pool[i].remove_count;
......@@ -329,7 +329,7 @@ class ScaledLRU {
if (diff != 0) {
__sync_fetch_and_add(&global_count, diff);
if (global_count > int(1.25 * size_limit)) {
thread_pool->enqueue([this]() -> int { return shrink(); });
thread_pool->enqueue([this]() -> int { return Shrink(); });
}
}
}
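A quick worked example of the trigger above: with size_limit = 1,000,000, a Shrink task is enqueued once the global node count exceeds int(1.25 * 1000000) = 1,250,000, i.e. the LRU pool tolerates 25% overshoot before eviction work is scheduled.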
......@@ -430,11 +430,11 @@ class GraphTable : public SparseTable {
virtual int32_t get_nodes_ids_by_ranges(
std::vector<std::pair<int, int>> ranges, std::vector<int64_t> &res);
virtual int32_t initialize() { return 0; }
virtual int32_t initialize(const TableParameter &config,
virtual int32_t Initialize() { return 0; }
virtual int32_t Initialize(const TableParameter &config,
const FsClientParameter &fs_config);
virtual int32_t initialize(const GraphParameter &config);
int32_t load(const std::string &path, const std::string &param);
virtual int32_t Initialize(const GraphParameter &config);
int32_t Load(const std::string &path, const std::string &param);
int32_t load_graph_split_config(const std::string &path);
int32_t load_edges(const std::string &path, bool reverse);
......@@ -452,26 +452,25 @@ class GraphTable : public SparseTable {
virtual int32_t Pull(TableContext &context) { return 0; }
virtual int32_t Push(TableContext &context) { return 0; }
virtual int32_t pull_sparse(float *values,
const PullSparseValue &pull_value) {
virtual int32_t PullSparse(float *values, const PullSparseValue &pull_value) {
return 0;
}
virtual int32_t push_sparse(const uint64_t *keys, const float *values,
size_t num) {
virtual int32_t PushSparse(const uint64_t *keys, const float *values,
size_t num) {
return 0;
}
virtual int32_t clear_nodes();
virtual void clear() {}
virtual int32_t flush() { return 0; }
virtual int32_t shrink(const std::string &param) { return 0; }
virtual void Clear() {}
virtual int32_t Flush() { return 0; }
virtual int32_t Shrink(const std::string &param) { return 0; }
// Save to the specified path
virtual int32_t save(const std::string &path, const std::string &converter) {
virtual int32_t Save(const std::string &path, const std::string &converter) {
return 0;
}
virtual int32_t initialize_shard() { return 0; }
virtual int32_t set_shard(size_t shard_idx, size_t server_num) {
virtual int32_t InitializeShard() { return 0; }
virtual int32_t SetShard(size_t shard_idx, size_t server_num) {
_shard_idx = shard_idx;
/*
_shard_num is not used in graph_table, this following operation is for the
......
......@@ -167,7 +167,7 @@ int64_t CommonSparseTable::LoadFromText(
return 0;
}
int32_t CommonSparseTable::initialize() {
int32_t CommonSparseTable::Initialize() {
_shards_task_pool.resize(task_pool_size_);
for (int i = 0; i < _shards_task_pool.size(); ++i) {
_shards_task_pool[i].reset(new ::ThreadPool(1));
......@@ -200,15 +200,15 @@ int32_t CommonSparseTable::initialize() {
offset += dim;
}
initialize_value();
initialize_optimizer();
initialize_recorder();
InitializeValue();
InitializeOptimizer();
InitializeRecorder();
return 0;
}
int32_t CommonSparseTable::initialize_recorder() { return 0; }
int32_t CommonSparseTable::InitializeRecorder() { return 0; }
int32_t CommonSparseTable::initialize_value() {
int32_t CommonSparseTable::InitializeValue() {
auto common = _config.common();
shard_values_.reserve(task_pool_size_);
......@@ -223,18 +223,18 @@ int32_t CommonSparseTable::initialize_value() {
return 0;
}
int32_t CommonSparseTable::initialize_optimizer() {
int32_t CommonSparseTable::InitializeOptimizer() {
auto common = _config.common();
auto name = common.name();
if (name == "sgd") {
optimizer_ = std::make_shared<SSGD>(value_names_, value_dims_,
value_offsets_, value_idx_);
optimizer_->set_global_lr(_global_lr);
optimizer_->SetGlobalLR(_global_lr);
} else if (name == "adam") {
optimizer_ = std::make_shared<SAdam>(value_names_, value_dims_,
value_offsets_, value_idx_);
optimizer_->set_global_lr(_global_lr);
optimizer_->SetGlobalLR(_global_lr);
} else if (name == "sum") {
optimizer_ = std::make_shared<SSUM>(value_names_, value_dims_,
value_offsets_, value_idx_);
......@@ -246,13 +246,13 @@ int32_t CommonSparseTable::initialize_optimizer() {
return 0;
}
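
`InitializeOptimizer` above picks the sparse optimizer from the config's `name` field and propagates the global learning rate into it. A hedged sketch of that dispatch, with `SOpt` and `MakeOptimizer` as hypothetical stand-ins for the real `SparseOptimizer`/`SSGD`/`SAdam`/`SSUM` types:

```cpp
#include <iostream>
#include <memory>
#include <string>

struct SOpt {
  virtual ~SOpt() = default;
  virtual void SetGlobalLR(float* lr) { lr_ = lr; }
  virtual const char* Name() const = 0;
  float* lr_ = nullptr;
};
struct SSgdLike : SOpt { const char* Name() const override { return "sgd"; } };
struct SAdamLike : SOpt { const char* Name() const override { return "adam"; } };
struct SSumLike : SOpt { const char* Name() const override { return "sum"; } };

std::shared_ptr<SOpt> MakeOptimizer(const std::string& name, float* global_lr) {
  std::shared_ptr<SOpt> opt;
  if (name == "sgd") opt = std::make_shared<SSgdLike>();
  else if (name == "adam") opt = std::make_shared<SAdamLike>();
  else if (name == "sum") opt = std::make_shared<SSumLike>();
  else return nullptr;  // the real code reports an unsupported optimizer
  // "sum" carries no learning rate, so only sgd/adam receive the global LR,
  // mirroring the SetGlobalLR calls in the diff.
  if (name != "sum") opt->SetGlobalLR(global_lr);
  return opt;
}

int main() {
  float lr = 0.01f;
  auto opt = MakeOptimizer("adam", &lr);
  std::cout << opt->Name() << "\n";
}
```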
int32_t CommonSparseTable::set_global_lr(float* lr) {
int32_t CommonSparseTable::SetGlobalLR(float* lr) {
_global_lr = lr;
optimizer_->set_global_lr(_global_lr);
optimizer_->SetGlobalLR(_global_lr);
return 0;
}
int32_t CommonSparseTable::load(const std::string& dirname,
int32_t CommonSparseTable::Load(const std::string& dirname,
const std::string& param) {
auto begin = GetCurrentUS();
rwlock_->WRLock();
......@@ -276,7 +276,7 @@ int32_t CommonSparseTable::load(const std::string& dirname,
return 0;
}
int32_t CommonSparseTable::save(const std::string& dirname,
int32_t CommonSparseTable::Save(const std::string& dirname,
const std::string& param) {
auto begin = GetCurrentUS();
rwlock_->WRLock();
......@@ -322,7 +322,7 @@ int32_t CommonSparseTable::save(const std::string& dirname,
return 0;
}
std::pair<int64_t, int64_t> CommonSparseTable::print_table_stat() {
std::pair<int64_t, int64_t> CommonSparseTable::PrintTableStat() {
int64_t feasign_size = 0;
int64_t mf_size = 0;
......@@ -335,7 +335,7 @@ std::pair<int64_t, int64_t> CommonSparseTable::print_table_stat() {
return {feasign_size, mf_size};
}
int32_t CommonSparseTable::pour() {
int32_t CommonSparseTable::Pour() {
std::vector<float> values;
std::vector<uint64_t> keys;
......@@ -349,7 +349,7 @@ int32_t CommonSparseTable::pour() {
std::copy(reservoir.values.begin(), reservoir.values.end(),
std::back_inserter(values));
}
_push_sparse(keys.data(), values.data(), pull_reservoir_.size());
_PushSparse(keys.data(), values.data(), pull_reservoir_.size());
pull_reservoir_.clear();
return 0;
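
`Pour` drains the reservoir accumulated during synchronous pushes into a single `_PushSparse`-style call. A standalone sketch of the same flatten-then-apply shape; `ReservoirValue` and `ApplyPush` are illustrative, and the real reservoir type is Paddle-internal:

```cpp
#include <algorithm>
#include <cstdint>
#include <iostream>
#include <iterator>
#include <map>
#include <vector>

struct ReservoirValue { std::vector<float> values; };

void ApplyPush(const uint64_t* keys, const float* values, size_t num) {
  for (size_t i = 0; i < num; ++i) std::cout << "apply key " << keys[i] << "\n";
}

int main() {
  std::map<uint64_t, ReservoirValue> pull_reservoir{
      {7, {{0.1f, 0.2f}}}, {42, {{0.3f, 0.4f}}}};

  std::vector<uint64_t> keys;
  std::vector<float> values;
  for (auto& kv : pull_reservoir) {
    keys.push_back(kv.first);
    auto& reservoir = kv.second;
    std::copy(reservoir.values.begin(), reservoir.values.end(),
              std::back_inserter(values));
  }
  ApplyPush(keys.data(), values.data(), pull_reservoir.size());
  pull_reservoir.clear();  // the reservoir is drained after Pour
}
```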
......@@ -360,11 +360,11 @@ int32_t CommonSparseTable::Pull(TableContext& context) {
if (context.use_ptr) {
char** pull_values = context.pull_context.ptr_values;
const uint64_t* keys = context.pull_context.keys;
return pull_sparse_ptr(pull_values, keys, context.num);
return PullSparsePtr(pull_values, keys, context.num);
} else {
float* pull_values = context.pull_context.values;
const PullSparseValue& pull_value = context.pull_context.pull_value;
return pull_sparse(pull_values, pull_value);
return PullSparse(pull_values, pull_value);
}
}
......@@ -373,16 +373,16 @@ int32_t CommonSparseTable::Push(TableContext& context) {
if (context.push_context.values != nullptr) {
const float* values = context.push_context.values;
const uint64_t* keys = context.push_context.keys;
return push_sparse(keys, values, context.num);
return PushSparse(keys, values, context.num);
} else {
const float** values = context.push_context.ptr_values;
const uint64_t* keys = context.push_context.keys;
return push_sparse(keys, values, context.num);
return PushSparse(keys, values, context.num);
}
}
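
`Pull`/`Push` above are thin dispatchers over `TableContext`: the pointer path is taken when `use_ptr` (or `ptr_values`) is set, the value path otherwise. A simplified sketch of that dispatch; the struct here is pared down, and the real value path takes a `PullSparseValue` rather than a raw `float*`:

```cpp
#include <cstdint>
#include <cstdio>

struct PullCtx {
  char** ptr_values = nullptr;   // pointer path output
  float* values = nullptr;       // value path output
  const uint64_t* keys = nullptr;
};
struct TableContext { PullCtx pull_context; size_t num = 0; bool use_ptr = false; };

int PullSparsePtr(char**, const uint64_t*, size_t n) { std::printf("ptr path %zu\n", n); return 0; }
int PullSparse(float*, const uint64_t*, size_t n) { std::printf("value path %zu\n", n); return 0; }

int Pull(TableContext& ctx) {
  if (ctx.use_ptr)
    return PullSparsePtr(ctx.pull_context.ptr_values, ctx.pull_context.keys, ctx.num);
  return PullSparse(ctx.pull_context.values, ctx.pull_context.keys, ctx.num);
}

int main() {
  uint64_t keys[2] = {1, 2};
  float out[2] = {0, 0};
  TableContext ctx;
  ctx.pull_context.keys = keys;
  ctx.pull_context.values = out;
  ctx.num = 2;
  Pull(ctx);  // use_ptr is false, so the value path runs
}
```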
int32_t CommonSparseTable::pull_sparse(float* pull_values,
const PullSparseValue& pull_value) {
int32_t CommonSparseTable::PullSparse(float* pull_values,
const PullSparseValue& pull_value) {
auto shard_num = task_pool_size_;
std::vector<std::future<int>> tasks(shard_num);
......@@ -421,8 +421,8 @@ int32_t CommonSparseTable::pull_sparse(float* pull_values,
return 0;
}
int32_t CommonSparseTable::pull_sparse_ptr(char** pull_values,
const uint64_t* keys, size_t num) {
int32_t CommonSparseTable::PullSparsePtr(char** pull_values,
const uint64_t* keys, size_t num) {
std::vector<std::vector<uint64_t>> offset_bucket;
offset_bucket.resize(task_pool_size_);
......@@ -458,8 +458,8 @@ int32_t CommonSparseTable::pull_sparse_ptr(char** pull_values,
return 0;
}
int32_t CommonSparseTable::_push_sparse(const uint64_t* keys,
const float* values, size_t num) {
int32_t CommonSparseTable::_PushSparse(const uint64_t* keys,
const float* values, size_t num) {
std::vector<std::vector<uint64_t>> offset_bucket;
offset_bucket.resize(task_pool_size_);
......@@ -474,7 +474,7 @@ int32_t CommonSparseTable::_push_sparse(const uint64_t* keys,
tasks[shard_id] = _shards_task_pool[shard_id]->enqueue(
[this, shard_id, &keys, &values, num, &offset_bucket]() -> int {
auto& offsets = offset_bucket[shard_id];
optimizer_->update(keys, values, num, offsets,
optimizer_->Update(keys, values, num, offsets,
shard_values_[shard_id].get());
return 0;
});
......@@ -486,8 +486,8 @@ int32_t CommonSparseTable::_push_sparse(const uint64_t* keys,
return 0;
}
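
`_PushSparse` shards work by bucketing each key's offset under `key % task_pool_size_`, so every shard's task touches only its own value block. A runnable sketch of that bucketing, using `std::async` where the real code uses per-shard `ThreadPool`s:

```cpp
#include <cstdint>
#include <future>
#include <iostream>
#include <vector>

int main() {
  const int task_pool_size = 4;
  std::vector<uint64_t> keys = {12, 7, 42, 3, 8};

  // Route each key's offset to its shard, exactly as the diff does.
  std::vector<std::vector<uint64_t>> offset_bucket(task_pool_size);
  for (size_t i = 0; i < keys.size(); ++i)
    offset_bucket[keys[i] % task_pool_size].push_back(i);

  // One task per shard; no two tasks share a bucket.
  std::vector<std::future<int>> tasks;
  for (int shard_id = 0; shard_id < task_pool_size; ++shard_id) {
    tasks.push_back(std::async(std::launch::async, [&, shard_id]() -> int {
      for (uint64_t off : offset_bucket[shard_id])
        std::cout << "shard " << shard_id << " updates key " << keys[off] << "\n";
      return 0;
    }));
  }
  for (auto& t : tasks) t.wait();  // same barrier as the real task.wait() loop
}
```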
int32_t CommonSparseTable::push_sparse(const uint64_t* keys,
const float* values, size_t num) {
int32_t CommonSparseTable::PushSparse(const uint64_t* keys, const float* values,
size_t num) {
if (sync) {
std::future<int> task =
_shards_task_pool[0]->enqueue([this, &keys, &values, num]() -> int {
......@@ -506,20 +506,20 @@ int32_t CommonSparseTable::push_sparse(const uint64_t* keys,
});
task.wait();
} else {
_push_sparse(keys, values, num);
_PushSparse(keys, values, num);
}
return 0;
}
int32_t CommonSparseTable::push_sparse(const uint64_t* keys,
const float** values, size_t num) {
_push_sparse(keys, values, num);
int32_t CommonSparseTable::PushSparse(const uint64_t* keys,
const float** values, size_t num) {
_PushSparse(keys, values, num);
return 0;
}
int32_t CommonSparseTable::_push_sparse(const uint64_t* keys,
const float** values, size_t num) {
int32_t CommonSparseTable::_PushSparse(const uint64_t* keys,
const float** values, size_t num) {
std::vector<std::vector<uint64_t>> offset_bucket;
offset_bucket.resize(task_pool_size_);
......@@ -536,7 +536,7 @@ int32_t CommonSparseTable::_push_sparse(const uint64_t* keys,
auto& offsets = offset_bucket[shard_id];
for (size_t i = 0; i < offsets.size(); ++i) {
std::vector<uint64_t> tmp_off = {0};
optimizer_->update(keys + offsets[i], values[offsets[i]], num,
optimizer_->Update(keys + offsets[i], values[offsets[i]], num,
tmp_off, shard_values_[shard_id].get());
}
return 0;
......@@ -549,8 +549,8 @@ int32_t CommonSparseTable::_push_sparse(const uint64_t* keys,
return 0;
}
int32_t CommonSparseTable::push_sparse_param(const uint64_t* keys,
const float* values, size_t num) {
int32_t CommonSparseTable::PushSparseParam(const uint64_t* keys,
const float* values, size_t num) {
std::vector<std::vector<uint64_t>> offset_bucket;
offset_bucket.resize(task_pool_size_);
......@@ -585,21 +585,21 @@ int32_t CommonSparseTable::push_sparse_param(const uint64_t* keys,
return 0;
}
int32_t CommonSparseTable::flush() { return 0; }
int32_t CommonSparseTable::Flush() { return 0; }
int32_t CommonSparseTable::shrink(const std::string& param) {
int32_t CommonSparseTable::Shrink(const std::string& param) {
int threshold = std::stoi(param);
VLOG(3) << "sparse table shrink: " << threshold;
VLOG(3) << "sparse table Shrink: " << threshold;
for (int shard_id = 0; shard_id < task_pool_size_; ++shard_id) {
// shrink
VLOG(4) << shard_id << " " << task_pool_size_ << " begin shrink";
// Shrink
VLOG(4) << shard_id << " " << task_pool_size_ << " begin Shrink";
shard_values_[shard_id]->Shrink(threshold);
}
return 0;
}
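
`Shrink` receives its threshold as a string and parses it with `std::stoi` before sweeping every shard. A toy version of that parse-then-sweep shape; `ShardToy` and its `unseen_days` eviction rule are assumptions, not the real `ValueBlock::Shrink` logic:

```cpp
#include <algorithm>
#include <cstdint>
#include <iostream>
#include <string>
#include <vector>

struct ShardToy {
  std::vector<int> unseen_days;  // one staleness counter per feature
  void Shrink(int threshold) {
    // Drop features unseen for at least `threshold` days.
    unseen_days.erase(
        std::remove_if(unseen_days.begin(), unseen_days.end(),
                       [threshold](int d) { return d >= threshold; }),
        unseen_days.end());
  }
};

int32_t Shrink(std::vector<ShardToy>& shards, const std::string& param) {
  int threshold = std::stoi(param);  // same parse as CommonSparseTable::Shrink
  for (auto& shard : shards) shard.Shrink(threshold);
  return 0;
}

int main() {
  std::vector<ShardToy> shards(2);
  shards[0].unseen_days = {1, 9, 3};
  Shrink(shards, "5");
  std::cout << shards[0].unseen_days.size() << "\n";  // 2
}
```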
void CommonSparseTable::clear() { VLOG(0) << "clear coming soon"; }
void CommonSparseTable::Clear() { VLOG(0) << "clear coming soon"; }
} // namespace distributed
} // namespace paddle
......@@ -114,25 +114,23 @@ class CommonSparseTable : public SparseTable {
virtual ~CommonSparseTable() {}
// unused method begin
virtual int32_t pull_dense(float* pull_values, size_t num) { return 0; }
virtual int32_t push_dense_param(const float* values, size_t num) {
return 0;
}
virtual int32_t push_dense(const float* values, size_t num) { return 0; }
virtual int32_t PullDense(float* pull_values, size_t num) { return 0; }
virtual int32_t PushDenseParam(const float* values, size_t num) { return 0; }
virtual int32_t PushDense(const float* values, size_t num) { return 0; }
// unused method end
virtual int32_t Pull(TableContext& context);
virtual int32_t Push(TableContext& context);
virtual int32_t initialize();
virtual int32_t initialize_shard() { return 0; }
virtual int32_t initialize_value();
virtual int32_t initialize_optimizer();
virtual int32_t initialize_recorder();
virtual int32_t Initialize();
virtual int32_t InitializeShard() { return 0; }
virtual int32_t InitializeValue();
virtual int32_t InitializeOptimizer();
virtual int32_t InitializeRecorder();
virtual int32_t load(const std::string& path, const std::string& param);
virtual int32_t Load(const std::string& path, const std::string& param);
virtual int32_t save(const std::string& path, const std::string& param);
virtual int32_t Save(const std::string& path, const std::string& param);
void SaveMetaToText(std::ostream* os, const CommonAccessorParameter& common,
const size_t shard_idx, const int64_t total);
......@@ -150,34 +148,34 @@ class CommonSparseTable : public SparseTable {
const int pserver_id, const int pserver_num, const int local_shard_num,
std::vector<std::shared_ptr<ValueBlock>>* blocks);
virtual std::pair<int64_t, int64_t> print_table_stat();
virtual int32_t pull_sparse(float* values, const PullSparseValue& pull_value);
virtual std::pair<int64_t, int64_t> PrintTableStat();
virtual int32_t PullSparse(float* values, const PullSparseValue& pull_value);
virtual int32_t pull_sparse_ptr(char** pull_values, const uint64_t* keys,
size_t num);
virtual int32_t PullSparsePtr(char** pull_values, const uint64_t* keys,
size_t num);
virtual int32_t push_sparse(const uint64_t* keys, const float* values,
size_t num);
virtual int32_t PushSparse(const uint64_t* keys, const float* values,
size_t num);
virtual int32_t push_sparse(const uint64_t* keys, const float** values,
size_t num);
virtual int32_t PushSparse(const uint64_t* keys, const float** values,
size_t num);
// only for sparse geo table
virtual int32_t push_sparse_param(const uint64_t* keys, const float* values,
size_t num);
virtual int32_t PushSparseParam(const uint64_t* keys, const float* values,
size_t num);
virtual int32_t set_global_lr(float* lr) override;
virtual int32_t SetGlobalLR(float* lr) override;
virtual int32_t pour();
virtual int32_t flush();
virtual int32_t shrink(const std::string& param);
virtual void clear();
virtual int32_t Pour();
virtual int32_t Flush();
virtual int32_t Shrink(const std::string& param);
virtual void Clear();
protected:
virtual int32_t _push_sparse(const uint64_t* keys, const float* values,
size_t num);
virtual int32_t _push_sparse(const uint64_t* keys, const float** values,
size_t num);
virtual int32_t _PushSparse(const uint64_t* keys, const float* values,
size_t num);
virtual int32_t _PushSparse(const uint64_t* keys, const float** values,
size_t num);
protected:
const int task_pool_size_ = 11;
......
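
Taken together, the header hunks above leave the table hierarchy with a uniform PascalCase virtual surface. A toy subclass written against that convention (simplified signatures, not the real Paddle `Table` base class):

```cpp
#include <cstdint>
#include <string>

class Table {
 public:
  virtual ~Table() = default;
  virtual int32_t Initialize() = 0;
  virtual int32_t InitializeShard() { return 0; }
  virtual int32_t Load(const std::string& path, const std::string& param) = 0;
  virtual int32_t Save(const std::string& path, const std::string& param) = 0;
  virtual int32_t Flush() { return 0; }
  virtual int32_t Shrink(const std::string& param) { return 0; }
  virtual void Clear() {}
};

class NoopTable : public Table {
 public:
  int32_t Initialize() override { return 0; }
  int32_t Load(const std::string&, const std::string&) override { return 0; }
  int32_t Save(const std::string&, const std::string&) override { return 0; }
};

int main() { NoopTable t; return t.Initialize(); }
```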
......@@ -71,11 +71,11 @@ class SparseTable : public Table {
SparseTable() {}
virtual ~SparseTable() {}
virtual void *get_shard(size_t shard_idx) { return 0; }
virtual void *GetShard(size_t shard_idx) { return 0; }
int32_t pull_dense(float *values, size_t num) override { return 0; }
int32_t PullDense(float *values, size_t num) override { return 0; }
int32_t push_dense(const float *values, size_t num) override { return 0; }
int32_t PushDense(const float *values, size_t num) override { return 0; }
static int32_t sparse_local_shard_num(uint32_t shard_num,
uint32_t server_num) {
......@@ -97,19 +97,17 @@ class DenseTable : public Table {
DenseTable() {}
virtual ~DenseTable() {}
virtual void *get_shard(size_t shard_idx) { return 0; }
int32_t pull_sparse(float *values,
const PullSparseValue &pull_value) override {
virtual void *GetShard(size_t shard_idx) { return 0; }
int32_t PullSparse(float *values,
const PullSparseValue &pull_value) override {
return 0;
}
int32_t push_sparse(const uint64_t *keys, const float *values,
size_t num) override {
int32_t PushSparse(const uint64_t *keys, const float *values,
size_t num) override {
return 0;
}
int32_t push_dense_param(const float *values, size_t num) override {
return 0;
}
int32_t shrink(const std::string &param) override { return 0; }
int32_t PushDenseParam(const float *values, size_t num) override { return 0; }
int32_t Shrink(const std::string &param) override { return 0; }
};
class BarrierTable : public Table {
......@@ -117,44 +115,42 @@ class BarrierTable : public Table {
BarrierTable() {}
virtual ~BarrierTable() {}
virtual void *get_shard(size_t shard_idx) { return 0; }
virtual void *GetShard(size_t shard_idx) { return 0; }
virtual int32_t Pull(TableContext &context) { return 0; }
virtual int32_t Push(TableContext &context) { return 0; }
int32_t pull_dense(float *values, size_t num) override { return 0; }
int32_t PullDense(float *values, size_t num) override { return 0; }
int32_t push_dense(const float *values, size_t num) override { return 0; }
int32_t PushDense(const float *values, size_t num) override { return 0; }
int32_t pull_sparse(float *values,
const PullSparseValue &pull_value) override {
return 0;
}
int32_t push_sparse(const uint64_t *keys, const float *values,
size_t num) override {
int32_t PullSparse(float *values,
const PullSparseValue &pull_value) override {
return 0;
}
int32_t push_dense_param(const float *values, size_t num) override {
int32_t PushSparse(const uint64_t *keys, const float *values,
size_t num) override {
return 0;
}
int32_t shrink(const std::string &param) override { return 0; }
virtual void clear() {}
virtual int32_t flush() { return 0; }
virtual int32_t load(const std::string &path, const std::string &param) {
int32_t PushDenseParam(const float *values, size_t num) override { return 0; }
int32_t Shrink(const std::string &param) override { return 0; }
virtual void Clear() {}
virtual int32_t Flush() { return 0; }
virtual int32_t Load(const std::string &path, const std::string &param) {
return 0;
}
virtual int32_t save(const std::string &path, const std::string &param) {
virtual int32_t Save(const std::string &path, const std::string &param) {
return 0;
}
virtual int32_t initialize_shard() { return 0; }
virtual int32_t InitializeShard() { return 0; }
virtual int32_t initialize() override;
virtual int32_t Initialize() override;
// only for barrier
// 0: send_barrier 1: recv_barrier 2: complete
virtual int32_t barrier(const uint32_t trainer_id,
virtual int32_t Barrier(const uint32_t trainer_id,
const std::string barrier_type) override;
virtual int32_t set_table_map(
virtual int32_t SetTableMap(
std::unordered_map<uint32_t, std::shared_ptr<Table>> *table_map) override;
private:
......
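
The `Barrier` comment above distinguishes send/recv/complete phases across trainers. A condition-variable toy showing the basic arrive-and-wait idea; the phase handling and RPC plumbing of the real `BarrierTable` are omitted, and `ToyBarrier` is purely illustrative:

```cpp
#include <condition_variable>
#include <cstdint>
#include <iostream>
#include <mutex>
#include <set>
#include <string>
#include <thread>
#include <vector>

class ToyBarrier {
 public:
  explicit ToyBarrier(uint32_t trainer_num) : trainer_num_(trainer_num) {}

  int32_t Barrier(uint32_t trainer_id, const std::string& /*barrier_type*/) {
    std::unique_lock<std::mutex> lk(mu_);
    arrived_.insert(trainer_id);
    if (arrived_.size() == trainer_num_) {
      arrived_.clear();  // last arrival completes the round and releases all
      cv_.notify_all();
    } else {
      cv_.wait(lk, [&] { return arrived_.empty(); });
    }
    return 0;
  }

 private:
  uint32_t trainer_num_;
  std::mutex mu_;
  std::condition_variable cv_;
  std::set<uint32_t> arrived_;
};

int main() {
  ToyBarrier barrier(3);
  std::vector<std::thread> trainers;
  for (uint32_t id = 0; id < 3; ++id)
    trainers.emplace_back([&, id] {
      barrier.Barrier(id, "send_barrier");
      std::cout << "trainer " << id << " passed\n";
    });
  for (auto& t : trainers) t.join();
}
```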
......@@ -34,9 +34,9 @@ class DenseOptimizer {
DenseOptimizer() {}
explicit DenseOptimizer(const CommonAccessorParameter& accessor,
std::vector<std::vector<float>>* values) {}
virtual void update(const float* update_values, size_t num, int begin,
virtual void Update(const float* update_values, size_t num, int begin,
int end) = 0;
virtual void set_global_lr(float* lr) { global_learning_rate_ = lr; }
virtual void SetGlobalLR(float* lr) { global_learning_rate_ = lr; }
protected:
float* global_learning_rate_;
......@@ -55,7 +55,7 @@ class DSUM : public DenseOptimizer {
}
}
void update(const float* update_values, size_t num, int begin,
void Update(const float* update_values, size_t num, int begin,
int end) override {
auto update_numel = end - begin;
GetBlas<float>().VADD(update_numel, update_values + begin, param + begin,
......@@ -81,7 +81,7 @@ class DSGD : public DenseOptimizer {
}
}
void update(const float* update_values, size_t num, int begin,
void Update(const float* update_values, size_t num, int begin,
int end) override {
auto update_numel = end - begin;
std::vector<float> grads;
......@@ -134,7 +134,7 @@ class DAdam : public DenseOptimizer {
// make sure common_dense_table.task_pool_size_ == 1;
// otherwise, task_pool_size_ times beta1_pow/beta2_pow multiplication
void update(const float* update_values, size_t num, int begin,
void Update(const float* update_values, size_t num, int begin,
int end) override {
auto update_numel = end - begin;
std::vector<float> grad, grad2, tmp;
......@@ -214,7 +214,7 @@ class DAdamD2Sum : public DenseOptimizer {
}
}
void update(const float* update_values, size_t num, int begin,
void Update(const float* update_values, size_t num, int begin,
int end) override {
auto update_numel = end - begin;
Eigen::Map<Eigen::MatrixXf> mat_ada_g2sum(ada_g2sum + begin, 1,
......@@ -276,7 +276,7 @@ class DSummary : public DenseOptimizer {
}
}
void update(const float* update_values, size_t num, int begin,
void Update(const float* update_values, size_t num, int begin,
int end) override {
auto update_numel = end - begin;
Eigen::Map<Eigen::MatrixXf> mat_w(param + begin, 1, update_numel);
......
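
Every dense optimizer above now implements `Update(update_values, num, begin, end)` over a `[begin, end)` slice plus `SetGlobalLR`. A minimal subclass under that interface; `DSGDToy` is illustrative and skips the per-parameter learning rates the real `DSGD` supports:

```cpp
#include <cstddef>
#include <iostream>
#include <vector>

class DenseOptimizerToy {
 public:
  virtual ~DenseOptimizerToy() = default;
  virtual void Update(const float* update_values, size_t num, int begin,
                      int end) = 0;
  virtual void SetGlobalLR(float* lr) { global_learning_rate_ = lr; }

 protected:
  float* global_learning_rate_ = nullptr;
};

class DSGDToy : public DenseOptimizerToy {
 public:
  explicit DSGDToy(std::vector<float>* param) : param_(param) {}

  void Update(const float* update_values, size_t /*num*/, int begin,
              int end) override {
    float lr = global_learning_rate_ ? *global_learning_rate_ : 1.0f;
    for (int i = begin; i < end; ++i)
      (*param_)[i] -= lr * update_values[i];  // plain SGD step on the slice
  }

 private:
  std::vector<float>* param_;
};

int main() {
  std::vector<float> param(4, 1.0f);
  std::vector<float> grad = {0.5f, 0.5f, 0.5f, 0.5f};
  float lr = 0.1f;
  DSGDToy opt(&param);
  opt.SetGlobalLR(&lr);
  opt.Update(grad.data(), grad.size(), 0, 4);
  std::cout << param[0] << "\n";  // 0.95
}
```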
......@@ -40,11 +40,11 @@ class SparseOptimizer {
value_offsets_(value_offsets),
value_idx_(value_idx) {}
virtual void update(const uint64_t* keys, const float* update_values,
virtual void Update(const uint64_t* keys, const float* update_values,
size_t num, const std::vector<uint64_t>& offsets,
ValueBlock* block) = 0;
virtual void set_global_lr(float* lr) { global_learning_rate_ = lr; }
virtual void SetGlobalLR(float* lr) { global_learning_rate_ = lr; }
const std::vector<std::string>& value_names_;
const std::vector<int>& value_dims_;
......@@ -70,7 +70,7 @@ class SSUM : public SparseOptimizer {
update_numel = value_dims.at(idx);
}
void update(const uint64_t* keys, const float* update_values, size_t num,
void Update(const uint64_t* keys, const float* update_values, size_t num,
const std::vector<uint64_t>& offsets,
ValueBlock* block) override {
auto blas = GetBlas<float>();
......@@ -100,7 +100,7 @@ class SSGD : public SparseOptimizer {
lr_offset = value_offsets.at(idx);
}
void update(const uint64_t* keys, const float* update_values, size_t num,
void Update(const uint64_t* keys, const float* update_values, size_t num,
const std::vector<uint64_t>& offsets,
ValueBlock* block) override {
auto blas = GetBlas<float>();
......@@ -156,7 +156,7 @@ class SAdam : public SparseOptimizer {
epsilon = 1.0e-8;
}
void update(const uint64_t* keys, const float* update_values, size_t num,
void Update(const uint64_t* keys, const float* update_values, size_t num,
const std::vector<uint64_t>& offsets,
ValueBlock* block) override {
auto blas = GetBlas<float>();
......
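
The sparse `Update` differs from the dense one in that `offsets` selects which keys (and which slices of `update_values`) to touch. A free-function sketch of that access pattern with an `SSUM`-style accumulate; `ValueMap` stands in for the real `ValueBlock`, and the `dim` parameter here replaces the original `num`:

```cpp
#include <cstdint>
#include <iostream>
#include <unordered_map>
#include <vector>

using ValueMap = std::unordered_map<uint64_t, std::vector<float>>;

void Update(const uint64_t* keys, const float* update_values, size_t dim,
            const std::vector<uint64_t>& offsets, ValueMap* block) {
  for (uint64_t off : offsets) {
    auto& row = (*block)[keys[off]];
    row.resize(dim, 0.0f);
    const float* g = update_values + off * dim;      // this key's slice
    for (size_t d = 0; d < dim; ++d) row[d] += g[d];  // SSUM-style accumulate
  }
}

int main() {
  ValueMap block;
  uint64_t keys[3] = {11, 22, 33};
  float grads[6] = {1, 1, 2, 2, 3, 3};     // dim = 2 per key
  Update(keys, grads, 2, {0, 2}, &block);  // touch keys 11 and 33 only
  std::cout << block[33][1] << "\n";       // 3
}
```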
......@@ -17,11 +17,10 @@
namespace paddle {
namespace distributed {
int32_t MemorySparseGeoTable::push_sparse_param(const uint64_t* keys,
const float* values,
size_t num) {
VLOG(5) << "DEBUG MemorySparseGeoTable::push_sparse_param begin "
"push_sparse_param "
int32_t MemorySparseGeoTable::PushSparseParam(const uint64_t* keys,
const float* values, size_t num) {
VLOG(5) << "DEBUG MemorySparseGeoTable::PushSparseParam begin "
"PushSparseParam "
<< num;
auto shard_num = _task_pool_size;
std::vector<std::vector<uint64_t>> offset_bucket;
......@@ -31,8 +30,8 @@ int32_t MemorySparseGeoTable::push_sparse_param(const uint64_t* keys,
auto y = keys[x] % shard_num;
offset_bucket[y].push_back(x);
if (x < 10) {
VLOG(5) << "DEBUG MemorySparseGeoTable::push_sparse_param key: "
<< keys[x] << " shard: " << y;
VLOG(5) << "DEBUG MemorySparseGeoTable::PushSparseParam key: " << keys[x]
<< " shard: " << y;
}
}
......@@ -51,8 +50,8 @@ int32_t MemorySparseGeoTable::push_sparse_param(const uint64_t* keys,
feature_value.resize(_dim);
std::copy_n(values + _dim * offset, _dim, feature_value.data());
if (i < 10) {
VLOG(5) << "MemorySparseGeoTable::push_sparse_param "
"push_sparse_param key "
VLOG(5) << "MemorySparseGeoTable::PushSparseParam "
"PushSparseParam key "
<< id << " value[0]: " << (values + _dim * offset)[0]
<< " data: " << feature_value.data()[0]
<< " value[-1]: " << (values + _dim * offset)[_dim - 1]
......@@ -69,9 +68,9 @@ int32_t MemorySparseGeoTable::push_sparse_param(const uint64_t* keys,
return 0;
}
int32_t MemorySparseGeoTable::pull_geo_param(const uint32_t trainer_id,
std::vector<float>* values,
std::vector<uint64_t>* ids) {
int32_t MemorySparseGeoTable::PullGeoParam(const uint32_t trainer_id,
std::vector<float>* values,
std::vector<uint64_t>* ids) {
_geo_recorder->GetAndClear(trainer_id, ids);
VLOG(5)
<< "DEBUG MemorySparseGeoTable::pull_geo_param pull_geo_param trainer_id "
......@@ -86,23 +85,23 @@ int32_t MemorySparseGeoTable::pull_geo_param(const uint32_t trainer_id,
pull_value.frequencies_ = frequencies.data();
values->resize(ids->size() * _dim);
pull_sparse(values->data(), pull_value);
PullSparse(values->data(), pull_value);
return 0;
}
int32_t MemorySparseGeoTable::push_sparse(const uint64_t* keys,
const float* values, size_t num) {
VLOG(5) << "DEBUG MemorySparseGeoTable::push_sparse keys[0]" << keys[0]
int32_t MemorySparseGeoTable::PushSparse(const uint64_t* keys,
const float* values, size_t num) {
VLOG(5) << "DEBUG MemorySparseGeoTable::PushSparse keys[0]" << keys[0]
<< " key_num: " << num;
std::vector<uint64_t> ids;
ids.resize(num);
std::copy_n(keys, num, ids.begin());
_geo_recorder->Update(ids);
_push_sparse(keys, values, num);
_PushSparse(keys, values, num);
return 0;
}
int32_t MemorySparseGeoTable::initialize() {
int32_t MemorySparseGeoTable::Initialize() {
if (!_geo_recorder) {
auto trainers = _config.common().trainer_num();
_geo_recorder = std::make_shared<GeoRecorder>(trainers);
......@@ -118,8 +117,8 @@ int32_t MemorySparseGeoTable::initialize() {
return 0;
}
int32_t MemorySparseGeoTable::pull_sparse(float* pull_values,
const PullSparseValue& pull_value) {
int32_t MemorySparseGeoTable::PullSparse(float* pull_values,
const PullSparseValue& pull_value) {
auto shard_num = _task_pool_size;
std::vector<std::future<int>> tasks(shard_num);
......@@ -146,13 +145,13 @@ int32_t MemorySparseGeoTable::pull_sparse(float* pull_values,
auto& feature_value = local_shard[key];
feature_value.resize(_dim);
memset(feature_value.data(), 0, sizeof(float) * _dim);
VLOG(0) << "MemorySparseGeoTable pull_sparse key not found!!! "
VLOG(0) << "MemorySparseGeoTable PullSparse key not found!!! "
<< key;
itr = local_shard.find(key);
}
memcpy(select_data, itr.value().data(), _dim * sizeof(float));
VLOG(5) << "DEBUG MemorySparseGeoTable::pull_sparse key: " << key
VLOG(5) << "DEBUG MemorySparseGeoTable::PullSparse key: " << key
<< " select_data[0] " << select_data[0]
<< " value[0]: " << itr.value().data()[0];
}
......@@ -167,8 +166,8 @@ int32_t MemorySparseGeoTable::pull_sparse(float* pull_values,
return 0;
}
int32_t MemorySparseGeoTable::_push_sparse(const uint64_t* keys,
const float* values, size_t num) {
int32_t MemorySparseGeoTable::_PushSparse(const uint64_t* keys,
const float* values, size_t num) {
auto shard_num = _task_pool_size;
std::vector<std::future<int>> tasks(shard_num);
std::vector<std::vector<std::pair<uint64_t, int>>> task_keys(shard_num);
......
......@@ -40,31 +40,31 @@ class MemorySparseGeoTable : public SparseTable {
MemorySparseGeoTable() { _geo_recorder = nullptr; }
virtual ~MemorySparseGeoTable() {}
virtual int32_t initialize();
virtual int32_t initialize_shard() { return 0; }
virtual int32_t load(const std::string& path, const std::string& param) {
virtual int32_t Initialize();
virtual int32_t InitializeShard() { return 0; }
virtual int32_t Load(const std::string& path, const std::string& param) {
return 0;
}
virtual int32_t save(const std::string& path, const std::string& param) {
virtual int32_t Save(const std::string& path, const std::string& param) {
return 0;
}
virtual int32_t Pull(TableContext& context) { return 0; }
virtual int32_t Push(TableContext& context) { return 0; }
virtual int32_t flush() { return 0; }
virtual int32_t shrink(const std::string& param) { return 0; }
virtual void clear() { return; }
virtual int32_t pull_sparse(float* values, const PullSparseValue& pull_value);
virtual int32_t Flush() { return 0; }
virtual int32_t Shrink(const std::string& param) { return 0; }
virtual void Clear() { return; }
virtual int32_t PullSparse(float* values, const PullSparseValue& pull_value);
int32_t push_sparse_param(const uint64_t* keys, const float* values,
size_t num);
int32_t PushSparseParam(const uint64_t* keys, const float* values,
size_t num);
// TODO(zhaocaibei123): change to pull_sparse, and rename pull_sparse
int32_t pull_geo_param(const uint32_t trainer_id, std::vector<float>* values,
std::vector<uint64_t>* keys);
int32_t PullGeoParam(const uint32_t trainer_id, std::vector<float>* values,
std::vector<uint64_t>* keys);
int32_t push_sparse(const uint64_t* keys, const float* values,
size_t num) override;
int32_t PushSparse(const uint64_t* keys, const float* values,
size_t num) override;
int32_t _push_sparse(const uint64_t* keys, const float* values, size_t num);
int32_t _PushSparse(const uint64_t* keys, const float* values, size_t num);
// int32_t _pull_sparse(float* pull_values, const PullSparseValue&
// pull_value);
......
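
The geo table's contract, as renamed above: `PushSparseParam` seeds values, `PushSparse` both applies updates and records the touched keys, and `PullGeoParam` drains that record (`GetAndClear`) and returns values for exactly those keys. A single-threaded toy of that round trip; every type here is a stand-in, the real table shards by key with `_dim`-wide rows, and the per-trainer bookkeeping is simplified:

```cpp
#include <cstdint>
#include <iostream>
#include <set>
#include <unordered_map>
#include <vector>

class ToyGeoTable {
 public:
  explicit ToyGeoTable(uint32_t trainer_num) : pending_(trainer_num) {}

  void PushSparseParam(const uint64_t* keys, const float* values, size_t num) {
    for (size_t i = 0; i < num; ++i) table_[keys[i]] = values[i];
  }

  void PushSparse(const uint64_t* keys, const float* values, size_t num) {
    for (size_t i = 0; i < num; ++i) {
      table_[keys[i]] += values[i];  // apply the update
      // Mark the key dirty for every trainer (the real GeoRecorder
      // tracks per-trainer deltas).
      for (auto& per_trainer : pending_) per_trainer.insert(keys[i]);
    }
  }

  void PullGeoParam(uint32_t trainer_id, std::vector<float>* values,
                    std::vector<uint64_t>* ids) {
    ids->assign(pending_[trainer_id].begin(), pending_[trainer_id].end());
    pending_[trainer_id].clear();  // GetAndClear semantics
    values->clear();
    for (uint64_t id : *ids) values->push_back(table_[id]);
  }

 private:
  std::unordered_map<uint64_t, float> table_;  // dim = 1 for brevity
  std::vector<std::set<uint64_t>> pending_;    // per-trainer changed keys
};

int main() {
  ToyGeoTable geo(/*trainer_num=*/2);
  uint64_t k[2] = {5, 9};
  float init[2] = {0.f, 0.f}, grad[2] = {0.5f, 0.25f};
  geo.PushSparseParam(k, init, 2);
  geo.PushSparse(k, grad, 2);
  std::vector<float> vals; std::vector<uint64_t> ids;
  geo.PullGeoParam(0, &vals, &ids);
  std::cout << ids.size() << " keys, first value " << vals[0] << "\n";
}
```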
This diff has been collapsed.