未验证 · 提交 b3270adf · 作者: zhaocaibei123 · 提交者: GitHub

统一ps refine (#41234)

* update name

* update name

* fix test

* fix fleet bind

* update name

* update name

* fix test

* fix gpups wrapper

* remove Push/Pull/Load/Save with context in client and wrapper base class

* fix

* fix

Co-authored-by: Nesythan <esythan@126.com>
上级 cb124156
...@@ -80,7 +80,7 @@ void DownpourPsClientService::service( ...@@ -80,7 +80,7 @@ void DownpourPsClientService::service(
const PsRequestMessage *request, PsResponseMessage *response, const PsRequestMessage *request, PsResponseMessage *response,
::google::protobuf::Closure *done) { ::google::protobuf::Closure *done) {
brpc::ClosureGuard done_guard(done); brpc::ClosureGuard done_guard(done);
int ret = _client->handle_client2client_msg( int ret = _client->HandleClient2ClientMsg(
request->cmd_id(), request->client_id(), request->data()); request->cmd_id(), request->client_id(), request->data());
response->set_err_code(0); response->set_err_code(0);
response->set_err_msg(""); response->set_err_msg("");
...@@ -91,8 +91,8 @@ void DownpourPsClientService::service( ...@@ -91,8 +91,8 @@ void DownpourPsClientService::service(
} }
// 启动client端RpcService 用于数据互发等操作 // 启动client端RpcService 用于数据互发等操作
int32_t BrpcPsClient::start_client_service() { int32_t BrpcPsClient::StartClientService() {
if (_service.configure(this, _client_id) != 0) { if (_service.Configure(this, _client_id) != 0) {
LOG(ERROR) LOG(ERROR)
<< "service initialize failed, service_name:DownpourPsClientService"; << "service initialize failed, service_name:DownpourPsClientService";
return -1; return -1;
...@@ -108,12 +108,12 @@ int32_t BrpcPsClient::start_client_service() { ...@@ -108,12 +108,12 @@ int32_t BrpcPsClient::start_client_service() {
return -1; return -1;
} }
_server_started = true; _server_started = true;
_env->registe_ps_client(butil::my_ip_cstr(), _server.listen_address().port, _env->RegistePsClient(butil::my_ip_cstr(), _server.listen_address().port,
_client_id); _client_id);
return 0; return 0;
} }
int32_t BrpcPsClient::create_client2client_connection( int32_t BrpcPsClient::CreateClient2ClientConnection(
int pserver_timeout_ms, int pserver_connect_timeout_ms, int max_retry) { int pserver_timeout_ms, int pserver_connect_timeout_ms, int max_retry) {
brpc::ChannelOptions options; brpc::ChannelOptions options;
options.protocol = "baidu_std"; options.protocol = "baidu_std";
...@@ -122,12 +122,12 @@ int32_t BrpcPsClient::create_client2client_connection( ...@@ -122,12 +122,12 @@ int32_t BrpcPsClient::create_client2client_connection(
options.connect_timeout_ms = pserver_connect_timeout_ms; options.connect_timeout_ms = pserver_connect_timeout_ms;
options.max_retry = max_retry; options.max_retry = max_retry;
std::vector<PSHost> client_list = _env->get_ps_clients(); std::vector<PSHost> client_list = _env->GetPsClients();
VLOG(1) << "BrpcPsClient::create_c2c_connection client_list size: " VLOG(1) << "BrpcPsClient::create_c2c_connection client_list size: "
<< client_list.size(); << client_list.size();
for (auto cc : client_list) { for (auto cc : client_list) {
VLOG(1) << "BrpcPsClient::create_c2c_connection client_list: " VLOG(1) << "BrpcPsClient::create_c2c_connection client_list: "
<< cc.to_string(); << cc.ToString();
} }
_client_channels.resize(client_list.size()); _client_channels.resize(client_list.size());
std::ostringstream os; std::ostringstream os;
...@@ -154,7 +154,7 @@ int32_t BrpcPsClient::create_client2client_connection( ...@@ -154,7 +154,7 @@ int32_t BrpcPsClient::create_client2client_connection(
return 0; return 0;
} }
int32_t BrpcPsClient::initialize() { int32_t BrpcPsClient::Initialize() {
_async_call_num = 0; _async_call_num = 0;
brpc::ChannelOptions options; brpc::ChannelOptions options;
...@@ -169,7 +169,7 @@ int32_t BrpcPsClient::initialize() { ...@@ -169,7 +169,7 @@ int32_t BrpcPsClient::initialize() {
std::string client_ip(butil::my_ip_cstr()); std::string client_ip(butil::my_ip_cstr());
// 获取server列表,并连接 // 获取server列表,并连接
std::vector<PSHost> server_list = _env->get_ps_servers(); std::vector<PSHost> server_list = _env->GetPsServers();
_server_channels.resize(server_list.size()); _server_channels.resize(server_list.size());
for (size_t i = 0; i < server_list.size(); ++i) { for (size_t i = 0; i < server_list.size(); ++i) {
server_ip_port.assign(server_list[i].ip.c_str()); server_ip_port.assign(server_list[i].ip.c_str());
...@@ -194,7 +194,7 @@ int32_t BrpcPsClient::initialize() { ...@@ -194,7 +194,7 @@ int32_t BrpcPsClient::initialize() {
os << server_ip_port << ","; os << server_ip_port << ",";
} }
// 启动client探听接口, 并相互建立连接 // 启动client探听接口, 并相互建立连接
start_client_service(); StartClientService();
// 异步push 请求队列初始化 // 异步push 请求队列初始化
const auto &worker_param = _config.worker_param().downpour_worker_param(); const auto &worker_param = _config.worker_param().downpour_worker_param();
...@@ -234,13 +234,13 @@ int32_t BrpcPsClient::initialize() { ...@@ -234,13 +234,13 @@ int32_t BrpcPsClient::initialize() {
_flushing = false; _flushing = false;
// 启动异步push线程 // 启动异步push线程
_async_push_sparse_thread = _async_push_sparse_thread =
std::thread(std::bind(&BrpcPsClient::push_sparse_task_consume, this)); std::thread(std::bind(&BrpcPsClient::PushSparseTaskConsume, this));
// _async_push_sparse_thread.detach(); // _async_push_sparse_thread.detach();
_async_push_dense_thread = _async_push_dense_thread =
std::thread(std::bind(&BrpcPsClient::push_dense_task_consume, this)); std::thread(std::bind(&BrpcPsClient::PushDenseTaskConsume, this));
// for debug // for debug
// _print_thread = // _print_thread =
// std::thread(std::bind(&BrpcPsClient::print_queue_size_thread, this)); // std::thread(std::bind(&BrpcPsClient::PrintQueueSizeThread, this));
return 0; return 0;
} }
...@@ -286,7 +286,7 @@ std::string DownpourBrpcClosure::get_response(size_t request_idx, int cmd_id) { ...@@ -286,7 +286,7 @@ std::string DownpourBrpcClosure::get_response(size_t request_idx, int cmd_id) {
return data; return data;
} }
std::future<int32_t> BrpcPsClient::print_table_stat(uint32_t table_id) { std::future<int32_t> BrpcPsClient::PrintTableStat(uint32_t table_id) {
size_t request_call_num = _server_channels.size(); size_t request_call_num = _server_channels.size();
DownpourBrpcClosure *closure = new DownpourBrpcClosure( DownpourBrpcClosure *closure = new DownpourBrpcClosure(
request_call_num, [request_call_num, table_id](void *done) { request_call_num, [request_call_num, table_id](void *done) {
...@@ -319,7 +319,7 @@ std::future<int32_t> BrpcPsClient::print_table_stat(uint32_t table_id) { ...@@ -319,7 +319,7 @@ std::future<int32_t> BrpcPsClient::print_table_stat(uint32_t table_id) {
closure->request(i)->set_cmd_id(PS_PRINT_TABLE_STAT); closure->request(i)->set_cmd_id(PS_PRINT_TABLE_STAT);
closure->request(i)->set_table_id(table_id); closure->request(i)->set_table_id(table_id);
closure->request(i)->set_client_id(_client_id); closure->request(i)->set_client_id(_client_id);
PsService_Stub rpc_stub(get_cmd_channel(i)); PsService_Stub rpc_stub(GetCmdChannel(i));
closure->cntl(i)->set_timeout_ms( closure->cntl(i)->set_timeout_ms(
10800000); // cmd msg don't limit timeout for save/load 10800000); // cmd msg don't limit timeout for save/load
rpc_stub.service(closure->cntl(i), closure->request(i), rpc_stub.service(closure->cntl(i), closure->request(i),
...@@ -327,7 +327,7 @@ std::future<int32_t> BrpcPsClient::print_table_stat(uint32_t table_id) { ...@@ -327,7 +327,7 @@ std::future<int32_t> BrpcPsClient::print_table_stat(uint32_t table_id) {
} }
return fut; return fut;
} }
std::future<int32_t> BrpcPsClient::send_cmd( std::future<int32_t> BrpcPsClient::SendCmd(
uint32_t table_id, int cmd_id, const std::vector<std::string> &params) { uint32_t table_id, int cmd_id, const std::vector<std::string> &params) {
size_t request_call_num = _server_channels.size(); size_t request_call_num = _server_channels.size();
DownpourBrpcClosure *closure = new DownpourBrpcClosure( DownpourBrpcClosure *closure = new DownpourBrpcClosure(
...@@ -352,7 +352,7 @@ std::future<int32_t> BrpcPsClient::send_cmd( ...@@ -352,7 +352,7 @@ std::future<int32_t> BrpcPsClient::send_cmd(
for (const auto &param : params) { for (const auto &param : params) {
closure->request(i)->add_params(param); closure->request(i)->add_params(param);
} }
PsService_Stub rpc_stub(get_cmd_channel(i)); PsService_Stub rpc_stub(GetCmdChannel(i));
closure->cntl(i)->set_timeout_ms( closure->cntl(i)->set_timeout_ms(
10800000 * 2); // cmd msg don't limit timeout for save/load 10800000 * 2); // cmd msg don't limit timeout for save/load
rpc_stub.service(closure->cntl(i), closure->request(i), rpc_stub.service(closure->cntl(i), closure->request(i),
...@@ -361,7 +361,7 @@ std::future<int32_t> BrpcPsClient::send_cmd( ...@@ -361,7 +361,7 @@ std::future<int32_t> BrpcPsClient::send_cmd(
return fut; return fut;
} }
std::future<int32_t> BrpcPsClient::send_save_cmd( std::future<int32_t> BrpcPsClient::SendSaveCmd(
uint32_t table_id, int cmd_id, const std::vector<std::string> &params) { uint32_t table_id, int cmd_id, const std::vector<std::string> &params) {
size_t request_call_num = _server_channels.size(); size_t request_call_num = _server_channels.size();
DownpourBrpcClosure *closure = new DownpourBrpcClosure( DownpourBrpcClosure *closure = new DownpourBrpcClosure(
...@@ -392,7 +392,7 @@ std::future<int32_t> BrpcPsClient::send_save_cmd( ...@@ -392,7 +392,7 @@ std::future<int32_t> BrpcPsClient::send_save_cmd(
for (const auto &param : params) { for (const auto &param : params) {
closure->request(i)->add_params(param); closure->request(i)->add_params(param);
} }
PsService_Stub rpc_stub(get_cmd_channel(i)); PsService_Stub rpc_stub(GetCmdChannel(i));
closure->cntl(i)->set_timeout_ms( closure->cntl(i)->set_timeout_ms(
10800000); // cmd msg don't limit timeout for save/load 10800000); // cmd msg don't limit timeout for save/load
rpc_stub.service(closure->cntl(i), closure->request(i), rpc_stub.service(closure->cntl(i), closure->request(i),
...@@ -401,65 +401,42 @@ std::future<int32_t> BrpcPsClient::send_save_cmd( ...@@ -401,65 +401,42 @@ std::future<int32_t> BrpcPsClient::send_save_cmd(
return fut; return fut;
} }
std::future<int32_t> BrpcPsClient::shrink(uint32_t table_id, std::future<int32_t> BrpcPsClient::Shrink(uint32_t table_id,
const std::string threshold) { const std::string threshold) {
return send_cmd(table_id, PS_SHRINK_TABLE, {threshold}); return SendCmd(table_id, PS_SHRINK_TABLE, {threshold});
} }
std::future<int32_t> BrpcPsClient::load(const std::string &epoch, std::future<int32_t> BrpcPsClient::Load(const std::string &epoch,
const std::string &mode) { const std::string &mode) {
return send_cmd(-1, PS_LOAD_ALL_TABLE, {epoch, mode}); return SendCmd(-1, PS_LOAD_ALL_TABLE, {epoch, mode});
} }
std::future<int32_t> BrpcPsClient::load(uint32_t table_id, std::future<int32_t> BrpcPsClient::Load(uint32_t table_id,
const std::string &epoch, const std::string &epoch,
const std::string &mode) { const std::string &mode) {
return send_cmd(table_id, PS_LOAD_ONE_TABLE, {epoch, mode}); return SendCmd(table_id, PS_LOAD_ONE_TABLE, {epoch, mode});
} }
std::future<int32_t> BrpcPsClient::Load(const LoadSaveContext &load_context) { std::future<int32_t> BrpcPsClient::Save(const std::string &epoch,
if (load_context.table_id < 0) {
return send_cmd(-1, PS_LOAD_ALL_TABLE,
{load_context.epoch, load_context.mode});
} else {
return send_cmd(load_context.table_id, PS_LOAD_ONE_TABLE,
{load_context.epoch, load_context.mode});
}
}
std::future<int32_t> BrpcPsClient::save(const std::string &epoch,
const std::string &mode) { const std::string &mode) {
VLOG(1) << "BrpcPsClient::save path " << epoch; VLOG(1) << "BrpcPsClient::save path " << epoch;
return send_save_cmd(-1, PS_SAVE_ALL_TABLE, {epoch, mode}); return SendSaveCmd(-1, PS_SAVE_ALL_TABLE, {epoch, mode});
} }
std::future<int32_t> BrpcPsClient::save(uint32_t table_id, std::future<int32_t> BrpcPsClient::Save(uint32_t table_id,
const std::string &epoch, const std::string &epoch,
const std::string &mode) { const std::string &mode) {
VLOG(1) << "BrpcPsClient::save one table path " << epoch << " table_id " VLOG(1) << "BrpcPsClient::save one table path " << epoch << " table_id "
<< table_id; << table_id;
return send_save_cmd(table_id, PS_SAVE_ONE_TABLE, {epoch, mode}); return SendSaveCmd(table_id, PS_SAVE_ONE_TABLE, {epoch, mode});
} }
std::future<int32_t> BrpcPsClient::Save(const LoadSaveContext &save_context) { std::future<int32_t> BrpcPsClient::Clear() {
if (save_context.table_id < 0) { return SendCmd(-1, PS_CLEAR_ALL_TABLE, {});
VLOG(1) << "BrpcPsClient::save path " << save_context.epoch;
return send_save_cmd(-1, PS_SAVE_ALL_TABLE,
{save_context.epoch, save_context.mode});
} else {
VLOG(1) << "BrpcPsClient::save one table path " << save_context.epoch
<< " table_id " << save_context.table_id;
return send_save_cmd(save_context.table_id, PS_SAVE_ONE_TABLE,
{save_context.epoch, save_context.mode});
}
}
std::future<int32_t> BrpcPsClient::clear() {
return send_cmd(-1, PS_CLEAR_ALL_TABLE, {});
} }
std::future<int32_t> BrpcPsClient::clear(uint32_t table_id) { std::future<int32_t> BrpcPsClient::Clear(uint32_t table_id) {
return send_cmd(table_id, PS_CLEAR_ONE_TABLE, {}); return SendCmd(table_id, PS_CLEAR_ONE_TABLE, {});
} }
std::future<int32_t> BrpcPsClient::flush() { std::future<int32_t> BrpcPsClient::Flush() {
VLOG(0) << "BrpcPsClient::flush begin"; VLOG(0) << "BrpcPsClient::flush begin";
_flushing = true; _flushing = true;
std::promise<int> promise; std::promise<int> promise;
...@@ -472,106 +449,69 @@ std::future<int32_t> BrpcPsClient::flush() { ...@@ -472,106 +449,69 @@ std::future<int32_t> BrpcPsClient::flush() {
promise.set_value(0); promise.set_value(0);
_flushing = false; _flushing = false;
VLOG(0) << "BrpcPsClient::flush done"; VLOG(0) << "BrpcPsClient::flush done";
print_queue_size(); PrintQueueSize();
return fut; return fut;
} }
void BrpcPsClient::print_queue_size() { void BrpcPsClient::PrintQueueSize() {
for (auto &push_sparse_task_itr : _push_sparse_task_queue_map) { for (auto &push_sparse_task_itr : _push_sparse_task_queue_map) {
auto table_id = push_sparse_task_itr.first; auto table_id = push_sparse_task_itr.first;
auto queue_size = push_sparse_task_itr.second->Size(); auto queue_size = push_sparse_task_itr.second->Size();
VLOG(0) << "BrpcPsClient::print_queue_size: table " << table_id VLOG(0) << "BrpcPsClient::PrintQueueSize: table " << table_id
<< " size: " << queue_size; << " size: " << queue_size;
} }
for (auto &task_queue_itr : _push_dense_task_queue_map) { for (auto &task_queue_itr : _push_dense_task_queue_map) {
auto table_id = task_queue_itr.first; auto table_id = task_queue_itr.first;
auto queue_size = task_queue_itr.second->Size(); auto queue_size = task_queue_itr.second->Size();
VLOG(0) << "BrpcPsClient::print_queue_size: table " << table_id VLOG(0) << "BrpcPsClient::PrintQueueSize: table " << table_id
<< " size: " << queue_size; << " size: " << queue_size;
} }
} }
void BrpcPsClient::print_queue_size_thread() { void BrpcPsClient::PrintQueueSizeThread() {
while (_running) { while (_running) {
usleep(1000000 * 60 * 2); usleep(1000000 * 60 * 2);
print_queue_size(); PrintQueueSize();
} }
} }
void BrpcPsClient::finalize_worker() { void BrpcPsClient::FinalizeWorker() {
flush(); Flush();
VLOG(0) << "BrpcPsClient::finalize_worker begin join thread"; VLOG(0) << "BrpcPsClient::FinalizeWorker begin join thread";
_running = false; _running = false;
_async_push_dense_thread.join(); _async_push_dense_thread.join();
_async_push_sparse_thread.join(); _async_push_sparse_thread.join();
// _print_thread.join(); // _print_thread.join();
VLOG(0) << "BrpcPsClient::finalize_worker begin join server"; VLOG(0) << "BrpcPsClient::FinalizeWorker begin join server";
_server.Stop(1000); _server.Stop(1000);
_server.Join(); _server.Join();
_server_started = false; _server_started = false;
VLOG(0) << "BrpcPsClient::finalize_worker done"; VLOG(0) << "BrpcPsClient::FinalizeWorker done";
} }
std::future<int32_t> BrpcPsClient::stop_server() { std::future<int32_t> BrpcPsClient::StopServer() {
return send_cmd(-1, PS_STOP_SERVER, {}); return SendCmd(-1, PS_STOP_SERVER, {});
} }
std::future<int32_t> BrpcPsClient::start_profiler() { std::future<int32_t> BrpcPsClient::StartProfiler() {
return send_cmd(-1, PS_START_PROFILER, {}); return SendCmd(-1, PS_START_PROFILER, {});
} }
std::future<int32_t> BrpcPsClient::stop_profiler() { std::future<int32_t> BrpcPsClient::StopProfiler() {
return send_cmd(-1, PS_STOP_PROFILER, {}); return SendCmd(-1, PS_STOP_PROFILER, {});
} }
std::future<int32_t> BrpcPsClient::barrier(size_t table_id, std::future<int32_t> BrpcPsClient::Barrier(size_t table_id,
uint32_t barrier_type) { uint32_t barrier_type) {
return send_cmd(table_id, PS_BARRIER, {std::to_string(barrier_type)}); return SendCmd(table_id, PS_BARRIER, {std::to_string(barrier_type)});
}
std::future<int32_t> BrpcPsClient::Pull(RequestContext &pull_context) {
if (pull_context.value_type == Dense) { // pull dense
Region *dense_region =
reinterpret_cast<Region *>(pull_context.dense_values);
return pull_dense(dense_region, pull_context.num, pull_context.table);
} else { // pull sparse
size_t table_id = pull_context.table;
size_t num = pull_context.num;
bool is_training = pull_context.is_training;
if (pull_context.training_mode == Geo) { // for geo
return pull_sparse_param(pull_context.sparse_values, table_id,
pull_context.keys, num, is_training);
} else if (pull_context.training_mode == Async) { // for async
return pull_sparse(pull_context.sparse_values, table_id,
pull_context.keys, num, is_training);
}
}
} }
std::future<int32_t> BrpcPsClient::Push(RequestContext &push_context) { std::future<int32_t> BrpcPsClient::PullGeoParam(size_t table_id,
if (push_context.value_type == Dense) { // push dense std::vector<float> *values,
const Region *dense_region = push_context.push_context.push_dense_values; std::vector<uint64_t> *keys,
return push_dense(dense_region, push_context.num, push_context.table); int pserver_idx) {
} else { // push sparse auto *accessor = GetTableAccessor(table_id);
size_t table_id = push_context.table;
size_t num = push_context.num;
bool is_training = push_context.is_training;
if (push_context.training_mode == Geo) { // for geo
// TODO(zhaocaibei)
} else if (push_context.training_mode == Async) { // for async
const uint64_t *keys = push_context.push_context.keys;
const float **update_values = push_context.push_context.push_values;
return push_sparse(table_id, keys, update_values, num);
}
}
}
std::future<int32_t> BrpcPsClient::pull_geo_param(size_t table_id,
std::vector<float> *values,
std::vector<uint64_t> *keys,
int pserver_idx) {
auto *accessor = table_accessor(table_id);
DownpourBrpcClosure *closure = DownpourBrpcClosure *closure =
new DownpourBrpcClosure(1, [keys, values, accessor](void *done) { new DownpourBrpcClosure(1, [keys, values, accessor](void *done) {
int ret = 0; int ret = 0;
...@@ -600,7 +540,7 @@ std::future<int32_t> BrpcPsClient::pull_geo_param(size_t table_id, ...@@ -600,7 +540,7 @@ std::future<int32_t> BrpcPsClient::pull_geo_param(size_t table_id,
closure->request(0)->set_cmd_id(PS_PULL_GEO_PARAM); closure->request(0)->set_cmd_id(PS_PULL_GEO_PARAM);
closure->request(0)->set_table_id(table_id); closure->request(0)->set_table_id(table_id);
closure->request(0)->set_client_id(_client_id); closure->request(0)->set_client_id(_client_id);
PsService_Stub rpc_stub(get_cmd_channel(pserver_idx)); PsService_Stub rpc_stub(GetCmdChannel(pserver_idx));
closure->cntl(0)->set_log_id(butil::gettimeofday_ms()); closure->cntl(0)->set_log_id(butil::gettimeofday_ms());
rpc_stub.service(closure->cntl(0), closure->request(0), closure->response(0), rpc_stub.service(closure->cntl(0), closure->request(0), closure->response(0),
closure); closure);
...@@ -608,10 +548,11 @@ std::future<int32_t> BrpcPsClient::pull_geo_param(size_t table_id, ...@@ -608,10 +548,11 @@ std::future<int32_t> BrpcPsClient::pull_geo_param(size_t table_id,
} }
// for GEO // for GEO
std::future<int32_t> BrpcPsClient::push_sparse_param( std::future<int32_t> BrpcPsClient::PushSparseParam(size_t table_id,
size_t table_id, const uint64_t *keys, const float **update_values, const uint64_t *keys,
size_t num, void *done) { const float **update_values,
auto *accessor = table_accessor(table_id); size_t num, void *done) {
auto *accessor = GetTableAccessor(table_id);
// 发送RPC请求 // 发送RPC请求
DownpourBrpcClosure *closure = reinterpret_cast<DownpourBrpcClosure *>(done); DownpourBrpcClosure *closure = reinterpret_cast<DownpourBrpcClosure *>(done);
auto promise = std::make_shared<std::promise<int32_t>>(); auto promise = std::make_shared<std::promise<int32_t>>();
...@@ -649,7 +590,7 @@ std::future<int32_t> BrpcPsClient::push_sparse_param( ...@@ -649,7 +590,7 @@ std::future<int32_t> BrpcPsClient::push_sparse_param(
memcpy(push_data_ptr, value_ptr[i], accessor->GetTableInfo(UPDATE_SIZE)); memcpy(push_data_ptr, value_ptr[i], accessor->GetTableInfo(UPDATE_SIZE));
push_data_ptr += accessor->GetTableInfo(UPDATE_SIZE); push_data_ptr += accessor->GetTableInfo(UPDATE_SIZE);
} }
PsService_Stub rpc_stub(get_sparse_channel(shard_idx)); PsService_Stub rpc_stub(GetSparseChannel(shard_idx));
closure->cntl(shard_idx)->set_request_compress_type( closure->cntl(shard_idx)->set_request_compress_type(
(brpc::CompressType)FLAGS_pserver_communicate_compress_type); (brpc::CompressType)FLAGS_pserver_communicate_compress_type);
rpc_stub.service(closure->cntl(shard_idx), closure->request(shard_idx), rpc_stub.service(closure->cntl(shard_idx), closure->request(shard_idx),
...@@ -658,16 +599,15 @@ std::future<int32_t> BrpcPsClient::push_sparse_param( ...@@ -658,16 +599,15 @@ std::future<int32_t> BrpcPsClient::push_sparse_param(
return fut; return fut;
} }
std::future<int32_t> BrpcPsClient::pull_dense(Region *regions, std::future<int32_t> BrpcPsClient::PullDense(Region *regions, size_t region_num,
size_t region_num, size_t table_id) {
size_t table_id) {
auto timer = std::make_shared<CostTimer>("pserver_client_pull_dense"); auto timer = std::make_shared<CostTimer>("pserver_client_pull_dense");
auto *accessor = table_accessor(table_id); auto *accessor = GetTableAccessor(table_id);
auto fea_dim = accessor->GetTableInfo(FEA_DIM); auto fea_dim = accessor->GetTableInfo(FEA_DIM);
auto select_size = accessor->GetTableInfo(SELECT_SIZE); auto select_size = accessor->GetTableInfo(SELECT_SIZE);
size_t request_call_num = _server_channels.size(); size_t request_call_num = _server_channels.size();
uint32_t num_per_shard = uint32_t num_per_shard =
dense_dim_per_shard(accessor->GetTableInfo(FEA_DIM), request_call_num); DenseDimPerShard(accessor->GetTableInfo(FEA_DIM), request_call_num);
// callback 将各shard结果,顺序填入region // callback 将各shard结果,顺序填入region
DownpourBrpcClosure *closure = new DownpourBrpcClosure( DownpourBrpcClosure *closure = new DownpourBrpcClosure(
request_call_num, [request_call_num, num_per_shard, regions, region_num, request_call_num, [request_call_num, num_per_shard, regions, region_num,
...@@ -730,22 +670,22 @@ std::future<int32_t> BrpcPsClient::pull_dense(Region *regions, ...@@ -730,22 +670,22 @@ std::future<int32_t> BrpcPsClient::pull_dense(Region *regions,
closure->request(i)->set_client_id(_client_id); closure->request(i)->set_client_id(_client_id);
closure->request(i)->add_params((char *)&num_per_shard, // NOLINT closure->request(i)->add_params((char *)&num_per_shard, // NOLINT
sizeof(num_per_shard)); sizeof(num_per_shard));
PsService_Stub rpc_stub(get_dense_channel(i)); PsService_Stub rpc_stub(GetDenseChannel(i));
rpc_stub.service(closure->cntl(i), closure->request(i), rpc_stub.service(closure->cntl(i), closure->request(i),
closure->response(i), closure); closure->response(i), closure);
} }
return fut; return fut;
} }
std::future<int32_t> BrpcPsClient::push_dense_param(const Region *regions, std::future<int32_t> BrpcPsClient::PushDenseParam(const Region *regions,
size_t region_num, size_t region_num,
size_t table_id) { size_t table_id) {
auto *accessor = table_accessor(table_id); auto *accessor = GetTableAccessor(table_id);
size_t request_call_num = _server_channels.size(); size_t request_call_num = _server_channels.size();
// 1.拆分Region数据到shard中,后续多shard并行拷贝数据 // 1.拆分Region数据到shard中,后续多shard并行拷贝数据
std::vector<std::vector<Region>> regions_partition(request_call_num); std::vector<std::vector<Region>> regions_partition(request_call_num);
uint32_t num_per_shard = uint32_t num_per_shard =
dense_dim_per_shard(accessor->GetTableInfo(FEA_DIM), request_call_num); DenseDimPerShard(accessor->GetTableInfo(FEA_DIM), request_call_num);
size_t shard_data_size = num_per_shard * accessor->GetTableInfo(UPDATE_SIZE); size_t shard_data_size = num_per_shard * accessor->GetTableInfo(UPDATE_SIZE);
size_t current_region_idx = 0; size_t current_region_idx = 0;
size_t current_region_data_idx = 0; size_t current_region_data_idx = 0;
...@@ -809,17 +749,17 @@ std::future<int32_t> BrpcPsClient::push_dense_param(const Region *regions, ...@@ -809,17 +749,17 @@ std::future<int32_t> BrpcPsClient::push_dense_param(const Region *regions,
fill_num); fill_num);
fill_remain_size -= fill_num; fill_remain_size -= fill_num;
} }
PsService_Stub rpc_stub(get_dense_channel(i)); PsService_Stub rpc_stub(GetDenseChannel(i));
rpc_stub.service(closure->cntl(i), closure->request(i), rpc_stub.service(closure->cntl(i), closure->request(i),
closure->response(i), closure); closure->response(i), closure);
} }
return fut; return fut;
} }
std::future<int32_t> BrpcPsClient::push_sparse_raw_gradient( std::future<int32_t> BrpcPsClient::PushSparseRawGradient(
size_t table_id, const uint64_t *keys, const float **update_values, size_t table_id, const uint64_t *keys, const float **update_values,
size_t num, void *done) { size_t num, void *done) {
auto *accessor = table_accessor(table_id); auto *accessor = GetTableAccessor(table_id);
// 发送RPC请求 // 发送RPC请求
DownpourBrpcClosure *closure = reinterpret_cast<DownpourBrpcClosure *>(done); DownpourBrpcClosure *closure = reinterpret_cast<DownpourBrpcClosure *>(done);
auto promise = std::make_shared<std::promise<int32_t>>(); auto promise = std::make_shared<std::promise<int32_t>>();
...@@ -872,7 +812,7 @@ std::future<int32_t> BrpcPsClient::push_sparse_raw_gradient( ...@@ -872,7 +812,7 @@ std::future<int32_t> BrpcPsClient::push_sparse_raw_gradient(
memcpy(push_data_ptr, value_ptr[i], accessor->GetTableInfo(UPDATE_SIZE)); memcpy(push_data_ptr, value_ptr[i], accessor->GetTableInfo(UPDATE_SIZE));
push_data_ptr += accessor->GetTableInfo(UPDATE_SIZE); push_data_ptr += accessor->GetTableInfo(UPDATE_SIZE);
} }
PsService_Stub rpc_stub(get_sparse_channel(shard_idx)); PsService_Stub rpc_stub(GetSparseChannel(shard_idx));
closure->cntl(shard_idx)->set_request_compress_type( closure->cntl(shard_idx)->set_request_compress_type(
(brpc::CompressType)FLAGS_pserver_communicate_compress_type); (brpc::CompressType)FLAGS_pserver_communicate_compress_type);
rpc_stub.service(closure->cntl(shard_idx), closure->request(shard_idx), rpc_stub.service(closure->cntl(shard_idx), closure->request(shard_idx),
...@@ -881,7 +821,7 @@ std::future<int32_t> BrpcPsClient::push_sparse_raw_gradient( ...@@ -881,7 +821,7 @@ std::future<int32_t> BrpcPsClient::push_sparse_raw_gradient(
return fut; return fut;
} }
std::future<int32_t> BrpcPsClient::push_dense_raw_gradient( std::future<int32_t> BrpcPsClient::PushDenseRawGradient(
int table_id, float *total_send_data, size_t total_send_data_size, int table_id, float *total_send_data, size_t total_send_data_size,
void *done) { void *done) {
size_t request_call_num = _server_channels.size(); size_t request_call_num = _server_channels.size();
...@@ -889,9 +829,9 @@ std::future<int32_t> BrpcPsClient::push_dense_raw_gradient( ...@@ -889,9 +829,9 @@ std::future<int32_t> BrpcPsClient::push_dense_raw_gradient(
auto promise = std::make_shared<std::promise<int32_t>>(); auto promise = std::make_shared<std::promise<int32_t>>();
closure->add_promise(promise); closure->add_promise(promise);
std::future<int> fut = promise->get_future(); std::future<int> fut = promise->get_future();
auto *accessor = table_accessor(table_id); auto *accessor = GetTableAccessor(table_id);
uint32_t num_per_shard = uint32_t num_per_shard =
dense_dim_per_shard(accessor->GetTableInfo(FEA_DIM), request_call_num); DenseDimPerShard(accessor->GetTableInfo(FEA_DIM), request_call_num);
for (size_t i = 0; i < request_call_num; ++i) { for (size_t i = 0; i < request_call_num; ++i) {
closure->request(i)->set_cmd_id(PS_PUSH_DENSE_TABLE); closure->request(i)->set_cmd_id(PS_PUSH_DENSE_TABLE);
closure->request(i)->set_table_id(table_id); closure->request(i)->set_table_id(table_id);
...@@ -905,16 +845,16 @@ std::future<int32_t> BrpcPsClient::push_dense_raw_gradient( ...@@ -905,16 +845,16 @@ std::future<int32_t> BrpcPsClient::push_dense_raw_gradient(
total_send_data + i * num_per_shard, num_per_shard * sizeof(float)); total_send_data + i * num_per_shard, num_per_shard * sizeof(float));
// closure->cntl(i)->set_request_compress_type( // closure->cntl(i)->set_request_compress_type(
// (brpc::CompressType)FLAGS_pserver_communicate_compress_type); // (brpc::CompressType)FLAGS_pserver_communicate_compress_type);
PsService_Stub rpc_stub(get_dense_channel(i)); PsService_Stub rpc_stub(GetDenseChannel(i));
rpc_stub.service(closure->cntl(i), closure->request(i), rpc_stub.service(closure->cntl(i), closure->request(i),
closure->response(i), closure); closure->response(i), closure);
} }
return fut; return fut;
} }
std::future<int32_t> BrpcPsClient::push_global_step(int table_id, std::future<int32_t> BrpcPsClient::PushGlobalStep(int table_id,
int64_t *total_send_data, int64_t *total_send_data,
void *done) { void *done) {
size_t request_call_num = _server_channels.size(); size_t request_call_num = _server_channels.size();
DownpourBrpcClosure *closure = reinterpret_cast<DownpourBrpcClosure *>(done); DownpourBrpcClosure *closure = reinterpret_cast<DownpourBrpcClosure *>(done);
auto promise = std::make_shared<std::promise<int32_t>>(); auto promise = std::make_shared<std::promise<int32_t>>();
...@@ -933,17 +873,17 @@ std::future<int32_t> BrpcPsClient::push_global_step(int table_id, ...@@ -933,17 +873,17 @@ std::future<int32_t> BrpcPsClient::push_global_step(int table_id,
memcpy(push_data_ptr + sizeof(uint32_t), total_send_data, memcpy(push_data_ptr + sizeof(uint32_t), total_send_data,
num_per_shard * sizeof(int64_t)); num_per_shard * sizeof(int64_t));
PsService_Stub rpc_stub(get_dense_channel(i)); PsService_Stub rpc_stub(GetDenseChannel(i));
rpc_stub.service(closure->cntl(i), closure->request(i), rpc_stub.service(closure->cntl(i), closure->request(i),
closure->response(i), closure); closure->response(i), closure);
} }
return fut; return fut;
} }
std::future<int32_t> BrpcPsClient::pull_sparse(float **select_values, std::future<int32_t> BrpcPsClient::PullSparse(float **select_values,
size_t table_id, size_t table_id,
const uint64_t *keys, size_t num, const uint64_t *keys, size_t num,
bool is_training) { bool is_training) {
auto timer = std::make_shared<CostTimer>("pserver_client_pull_sparse"); auto timer = std::make_shared<CostTimer>("pserver_client_pull_sparse");
auto local_timer = auto local_timer =
std::make_shared<CostTimer>("pserver_client_pull_sparse_local"); std::make_shared<CostTimer>("pserver_client_pull_sparse_local");
...@@ -968,7 +908,7 @@ std::future<int32_t> BrpcPsClient::pull_sparse(float **select_values, ...@@ -968,7 +908,7 @@ std::future<int32_t> BrpcPsClient::pull_sparse(float **select_values,
shard_sorted_kvs->at(shard_id).push_back({keys[i], select_values[i]}); shard_sorted_kvs->at(shard_id).push_back({keys[i], select_values[i]});
} }
auto *accessor = table_accessor(table_id); auto *accessor = GetTableAccessor(table_id);
size_t value_size = accessor->GetTableInfo(SELECT_SIZE); size_t value_size = accessor->GetTableInfo(SELECT_SIZE);
...@@ -1055,7 +995,7 @@ std::future<int32_t> BrpcPsClient::pull_sparse(float **select_values, ...@@ -1055,7 +995,7 @@ std::future<int32_t> BrpcPsClient::pull_sparse(float **select_values,
closure->request(i)->set_client_id(_client_id); closure->request(i)->set_client_id(_client_id);
closure->request(i)->add_params((char *)&kv_request_count, // NOLINT closure->request(i)->add_params((char *)&kv_request_count, // NOLINT
sizeof(uint32_t)); sizeof(uint32_t));
PsService_Stub rpc_stub(get_cmd_channel(i)); PsService_Stub rpc_stub(GetCmdChannel(i));
closure->cntl(i)->set_log_id(butil::gettimeofday_ms()); closure->cntl(i)->set_log_id(butil::gettimeofday_ms());
rpc_stub.service(closure->cntl(i), closure->request(i), rpc_stub.service(closure->cntl(i), closure->request(i),
closure->response(i), closure); closure->response(i), closure);
...@@ -1065,11 +1005,11 @@ std::future<int32_t> BrpcPsClient::pull_sparse(float **select_values, ...@@ -1065,11 +1005,11 @@ std::future<int32_t> BrpcPsClient::pull_sparse(float **select_values,
} }
// for GEO // for GEO
std::future<int32_t> BrpcPsClient::pull_sparse_param(float **select_values, std::future<int32_t> BrpcPsClient::PullSparseParam(float **select_values,
size_t table_id, size_t table_id,
const uint64_t *keys, const uint64_t *keys,
size_t num, size_t num,
bool is_training) { bool is_training) {
auto timer = std::make_shared<CostTimer>("pserver_client_pull_sparse_param"); auto timer = std::make_shared<CostTimer>("pserver_client_pull_sparse_param");
size_t request_call_num = _server_channels.size(); size_t request_call_num = _server_channels.size();
...@@ -1082,7 +1022,7 @@ std::future<int32_t> BrpcPsClient::pull_sparse_param(float **select_values, ...@@ -1082,7 +1022,7 @@ std::future<int32_t> BrpcPsClient::pull_sparse_param(float **select_values,
shard_sorted_kvs->at(shard_id).push_back({keys[i], select_values[i]}); shard_sorted_kvs->at(shard_id).push_back({keys[i], select_values[i]});
} }
auto *accessor = table_accessor(table_id); auto *accessor = GetTableAccessor(table_id);
size_t value_size = accessor->GetTableInfo(SELECT_SIZE); size_t value_size = accessor->GetTableInfo(SELECT_SIZE);
DownpourBrpcClosure *closure = new DownpourBrpcClosure( DownpourBrpcClosure *closure = new DownpourBrpcClosure(
...@@ -1169,7 +1109,7 @@ std::future<int32_t> BrpcPsClient::pull_sparse_param(float **select_values, ...@@ -1169,7 +1109,7 @@ std::future<int32_t> BrpcPsClient::pull_sparse_param(float **select_values,
closure->request(i)->set_client_id(_client_id); closure->request(i)->set_client_id(_client_id);
closure->request(i)->add_params((char *)&kv_request_count, // NOLINT closure->request(i)->add_params((char *)&kv_request_count, // NOLINT
sizeof(uint32_t)); sizeof(uint32_t));
PsService_Stub rpc_stub(get_cmd_channel(i)); PsService_Stub rpc_stub(GetCmdChannel(i));
closure->cntl(i)->set_log_id(butil::gettimeofday_ms()); closure->cntl(i)->set_log_id(butil::gettimeofday_ms());
rpc_stub.service(closure->cntl(i), closure->request(i), rpc_stub.service(closure->cntl(i), closure->request(i),
closure->response(i), closure); closure->response(i), closure);
...@@ -1178,7 +1118,7 @@ std::future<int32_t> BrpcPsClient::pull_sparse_param(float **select_values, ...@@ -1178,7 +1118,7 @@ std::future<int32_t> BrpcPsClient::pull_sparse_param(float **select_values,
return fut; return fut;
} }
std::future<int32_t> BrpcPsClient::send_client2client_msg( std::future<int32_t> BrpcPsClient::SendClient2ClientMsg(
int msg_type, int to_client_id, const std::string &msg) { int msg_type, int to_client_id, const std::string &msg) {
auto promise = std::make_shared<std::promise<int32_t>>(); auto promise = std::make_shared<std::promise<int32_t>>();
std::future<int> fut = promise->get_future(); std::future<int> fut = promise->get_future();
...@@ -1203,10 +1143,10 @@ std::future<int32_t> BrpcPsClient::send_client2client_msg( ...@@ -1203,10 +1143,10 @@ std::future<int32_t> BrpcPsClient::send_client2client_msg(
return fut; return fut;
} }
std::future<int32_t> BrpcPsClient::push_sparse_raw_gradient_partial( std::future<int32_t> BrpcPsClient::PushSparseRawGradientPartial(
size_t table_id, const uint64_t *keys, const float **update_values, size_t table_id, const uint64_t *keys, const float **update_values,
uint32_t num, void *done, int pserver_idx) { uint32_t num, void *done, int pserver_idx) {
auto *accessor = table_accessor(table_id); auto *accessor = GetTableAccessor(table_id);
size_t value_size = accessor->GetTableInfo(UPDATE_SIZE); size_t value_size = accessor->GetTableInfo(UPDATE_SIZE);
DownpourBrpcClosure *closure = reinterpret_cast<DownpourBrpcClosure *>(done); DownpourBrpcClosure *closure = reinterpret_cast<DownpourBrpcClosure *>(done);
auto promise = std::make_shared<std::promise<int32_t>>(); auto promise = std::make_shared<std::promise<int32_t>>();
...@@ -1228,7 +1168,7 @@ std::future<int32_t> BrpcPsClient::push_sparse_raw_gradient_partial( ...@@ -1228,7 +1168,7 @@ std::future<int32_t> BrpcPsClient::push_sparse_raw_gradient_partial(
memcpy(push_data_ptr, update_values[i], value_size); memcpy(push_data_ptr, update_values[i], value_size);
push_data_ptr += value_size; push_data_ptr += value_size;
} }
PsService_Stub rpc_stub(get_sparse_channel(pserver_idx)); PsService_Stub rpc_stub(GetSparseChannel(pserver_idx));
closure->cntl(0)->set_request_compress_type( closure->cntl(0)->set_request_compress_type(
(brpc::CompressType)FLAGS_pserver_communicate_compress_type); (brpc::CompressType)FLAGS_pserver_communicate_compress_type);
rpc_stub.service(closure->cntl(0), closure->request(0), closure->response(0), rpc_stub.service(closure->cntl(0), closure->request(0), closure->response(0),
...@@ -1236,8 +1176,8 @@ std::future<int32_t> BrpcPsClient::push_sparse_raw_gradient_partial( ...@@ -1236,8 +1176,8 @@ std::future<int32_t> BrpcPsClient::push_sparse_raw_gradient_partial(
return fut; return fut;
} }
int32_t BrpcPsClient::recv_and_save_table(const uint64_t table_id, int32_t BrpcPsClient::RecvAndSaveTable(const uint64_t table_id,
const std::string &path) { const std::string &path) {
// get var information // get var information
std::string var_name = ""; std::string var_name = "";
int64_t var_num = 0; int64_t var_num = 0;
...@@ -1271,17 +1211,17 @@ int32_t BrpcPsClient::recv_and_save_table(const uint64_t table_id, ...@@ -1271,17 +1211,17 @@ int32_t BrpcPsClient::recv_and_save_table(const uint64_t table_id,
save_vec.push_back(save_huge_vec.data() + i * var_shape); save_vec.push_back(save_huge_vec.data() + i * var_shape);
} }
VLOG(2) << "recv_and_save_table: table_class: " << table_class; VLOG(2) << "RecvAndSaveTable: table_class: " << table_class;
// TODO(zhaocaibei123): new GeoBrpcPSClient, move this to its // TODO(zhaocaibei123): new GeoBrpcPSClient, move this to its
// recv_and_save_table // RecvAndSaveTable
if (table_class == "MemorySparseGeoTable") { if (table_class == "MemorySparseGeoTable") {
auto status = auto status =
pull_sparse_param(reinterpret_cast<float **>(save_vec.data()), table_id, PullSparseParam(reinterpret_cast<float **>(save_vec.data()), table_id,
save_key.data(), save_key.size(), true); save_key.data(), save_key.size(), true);
status.wait(); status.wait();
} else { } else {
auto status = pull_sparse(reinterpret_cast<float **>(save_vec.data()), auto status = PullSparse(reinterpret_cast<float **>(save_vec.data()),
table_id, save_key.data(), save_key.size(), true); table_id, save_key.data(), save_key.size(), true);
status.wait(); status.wait();
} }
...@@ -1315,15 +1255,15 @@ int32_t BrpcPsClient::recv_and_save_table(const uint64_t table_id, ...@@ -1315,15 +1255,15 @@ int32_t BrpcPsClient::recv_and_save_table(const uint64_t table_id,
return 0; return 0;
} }
std::future<int32_t> BrpcPsClient::push_sparse(size_t table_id, std::future<int32_t> BrpcPsClient::PushSparse(size_t table_id,
const uint64_t *keys, const uint64_t *keys,
const float **update_values, const float **update_values,
size_t num) { size_t num) {
auto push_timer = std::make_shared<CostTimer>("pserver_client_push_sparse"); auto push_timer = std::make_shared<CostTimer>("pserver_client_push_sparse");
CostTimer parse_timer("pserver_client_push_sparse_parse"); CostTimer parse_timer("pserver_client_push_sparse_parse");
int push_sparse_async_num = _push_sparse_task_queue_map[table_id]->Size(); int push_sparse_async_num = _push_sparse_task_queue_map[table_id]->Size();
while (push_sparse_async_num > FLAGS_pserver_max_async_call_num) { while (push_sparse_async_num > FLAGS_pserver_max_async_call_num) {
// LOG(INFO) << "push_sparse Waiting for async_call_num comsume, // LOG(INFO) << "PushSparse Waiting for async_call_num comsume,
// task_num:" // task_num:"
// << push_sparse_async_num // << push_sparse_async_num
// << ", max_task_limit:" << FLAGS_pserver_max_async_call_num; // << ", max_task_limit:" << FLAGS_pserver_max_async_call_num;
...@@ -1333,7 +1273,7 @@ std::future<int32_t> BrpcPsClient::push_sparse(size_t table_id, ...@@ -1333,7 +1273,7 @@ std::future<int32_t> BrpcPsClient::push_sparse(size_t table_id,
auto put_timer = std::make_shared<CostTimer>("client_push_sparse_put"); auto put_timer = std::make_shared<CostTimer>("client_push_sparse_put");
thread_local std::vector<std::vector<std::pair<uint64_t, const float *>>> thread_local std::vector<std::vector<std::pair<uint64_t, const float *>>>
shard_sorted_kv_list; shard_sorted_kv_list;
auto *accessor = table_accessor(table_id); auto *accessor = GetTableAccessor(table_id);
size_t request_call_num = _server_channels.size(); size_t request_call_num = _server_channels.size();
shard_sorted_kv_list.resize(request_call_num); shard_sorted_kv_list.resize(request_call_num);
for (auto &x : shard_sorted_kv_list) { for (auto &x : shard_sorted_kv_list) {
...@@ -1381,7 +1321,7 @@ std::future<int32_t> BrpcPsClient::push_sparse(size_t table_id, ...@@ -1381,7 +1321,7 @@ std::future<int32_t> BrpcPsClient::push_sparse(size_t table_id,
return fut; return fut;
} }
void BrpcPsClient::push_sparse_task_consume() { void BrpcPsClient::PushSparseTaskConsume() {
uint64_t merge_size = FLAGS_pserver_push_sparse_merge_limit; uint64_t merge_size = FLAGS_pserver_push_sparse_merge_limit;
std::vector<std::shared_ptr<SparseAsyncTask>> task_list; std::vector<std::shared_ptr<SparseAsyncTask>> task_list;
size_t request_call_num = _server_channels.size(); size_t request_call_num = _server_channels.size();
...@@ -1392,7 +1332,7 @@ void BrpcPsClient::push_sparse_task_consume() { ...@@ -1392,7 +1332,7 @@ void BrpcPsClient::push_sparse_task_consume() {
// 所有sparseTable的pushTask 进行处理 // 所有sparseTable的pushTask 进行处理
for (auto &push_sparse_task_itr : _push_sparse_task_queue_map) { for (auto &push_sparse_task_itr : _push_sparse_task_queue_map) {
auto table_id = push_sparse_task_itr.first; auto table_id = push_sparse_task_itr.first;
auto *accessor = table_accessor(table_id); auto *accessor = GetTableAccessor(table_id);
auto &task_queue = push_sparse_task_itr.second; auto &task_queue = push_sparse_task_itr.second;
auto queue_size = task_queue->Size(); auto queue_size = task_queue->Size();
if (queue_size == 0) { if (queue_size == 0) {
...@@ -1471,7 +1411,7 @@ void BrpcPsClient::push_sparse_task_consume() { ...@@ -1471,7 +1411,7 @@ void BrpcPsClient::push_sparse_task_consume() {
for (int shard_idx = 0; shard_idx < request_call_num; ++shard_idx) { for (int shard_idx = 0; shard_idx < request_call_num; ++shard_idx) {
merge_status[shard_idx] = merge_status[shard_idx] =
async_push_sparse_shard_threads.enqueue(std::bind( async_push_sparse_shard_threads.enqueue(std::bind(
&BrpcPsClient::push_sparse_async_shard_push, this, task_list, &BrpcPsClient::PushSparseAsyncShardPush, this, task_list,
request_kv_num, table_id, shard_idx, closure, accessor)); request_kv_num, table_id, shard_idx, closure, accessor));
} }
for (int shard_idx = 0; shard_idx < request_call_num; ++shard_idx) { for (int shard_idx = 0; shard_idx < request_call_num; ++shard_idx) {
...@@ -1487,7 +1427,7 @@ void BrpcPsClient::push_sparse_task_consume() { ...@@ -1487,7 +1427,7 @@ void BrpcPsClient::push_sparse_task_consume() {
for (int shard_idx = 0; shard_idx < request_call_num; ++shard_idx) { for (int shard_idx = 0; shard_idx < request_call_num; ++shard_idx) {
merge_status[shard_idx] = merge_status[shard_idx] =
async_push_sparse_shard_threads.enqueue(std::bind( async_push_sparse_shard_threads.enqueue(std::bind(
&BrpcPsClient::push_sparse_async_shard_merge, this, task_list, &BrpcPsClient::PushSparseAsyncShardMerge, this, task_list,
request_kv_num, table_id, shard_idx, accessor)); request_kv_num, table_id, shard_idx, accessor));
} }
for (int shard_idx = 0; shard_idx < request_call_num; ++shard_idx) { for (int shard_idx = 0; shard_idx < request_call_num; ++shard_idx) {
...@@ -1523,7 +1463,7 @@ void sparse_local_merge(ValueAccessor *accessor, float *merge_data, ...@@ -1523,7 +1463,7 @@ void sparse_local_merge(ValueAccessor *accessor, float *merge_data,
accessor->Merge(merge_data_shell, another_data_shell, 1); accessor->Merge(merge_data_shell, another_data_shell, 1);
} }
int BrpcPsClient::push_sparse_async_shard_merge( int BrpcPsClient::PushSparseAsyncShardMerge(
std::vector<std::shared_ptr<SparseAsyncTask>> &task_list, std::vector<std::shared_ptr<SparseAsyncTask>> &task_list,
std::vector<int> &request_kv_num, int table_id, int shard_idx, std::vector<int> &request_kv_num, int table_id, int shard_idx,
ValueAccessor *accessor) { ValueAccessor *accessor) {
...@@ -1615,12 +1555,12 @@ int BrpcPsClient::push_sparse_async_shard_merge( ...@@ -1615,12 +1555,12 @@ int BrpcPsClient::push_sparse_async_shard_merge(
return 0; return 0;
} }
int BrpcPsClient::push_sparse_async_shard_push( int BrpcPsClient::PushSparseAsyncShardPush(
std::vector<std::shared_ptr<SparseAsyncTask>> &task_list, std::vector<std::shared_ptr<SparseAsyncTask>> &task_list,
std::vector<int> &request_kv_num, int table_id, int shard_idx, std::vector<int> &request_kv_num, int table_id, int shard_idx,
DownpourBrpcClosure *closure, ValueAccessor *accessor) { DownpourBrpcClosure *closure, ValueAccessor *accessor) {
push_sparse_async_shard_merge(task_list, request_kv_num, table_id, shard_idx, PushSparseAsyncShardMerge(task_list, request_kv_num, table_id, shard_idx,
accessor); accessor);
size_t merged_kv_count = task_list[0]->data()->shared_data[shard_idx].kv_num; size_t merged_kv_count = task_list[0]->data()->shared_data[shard_idx].kv_num;
auto &merged_key_list = task_list[0]->data()->shared_data[shard_idx].key_list; auto &merged_key_list = task_list[0]->data()->shared_data[shard_idx].key_list;
...@@ -1649,7 +1589,7 @@ int BrpcPsClient::push_sparse_async_shard_push( ...@@ -1649,7 +1589,7 @@ int BrpcPsClient::push_sparse_async_shard_push(
accessor->GetTableInfo(UPDATE_SIZE)); accessor->GetTableInfo(UPDATE_SIZE));
push_data_ptr += accessor->GetTableInfo(UPDATE_SIZE); push_data_ptr += accessor->GetTableInfo(UPDATE_SIZE);
} }
PsService_Stub rpc_stub(get_sparse_channel(shard_idx)); PsService_Stub rpc_stub(GetSparseChannel(shard_idx));
closure->cntl(shard_idx)->set_request_compress_type( closure->cntl(shard_idx)->set_request_compress_type(
(brpc::CompressType)FLAGS_pserver_communicate_compress_type); (brpc::CompressType)FLAGS_pserver_communicate_compress_type);
rpc_stub.service(closure->cntl(shard_idx), closure->request(shard_idx), rpc_stub.service(closure->cntl(shard_idx), closure->request(shard_idx),
...@@ -1658,10 +1598,10 @@ int BrpcPsClient::push_sparse_async_shard_push( ...@@ -1658,10 +1598,10 @@ int BrpcPsClient::push_sparse_async_shard_push(
return 0; return 0;
} }
std::future<int32_t> BrpcPsClient::push_dense(const Region *regions, std::future<int32_t> BrpcPsClient::PushDense(const Region *regions,
size_t region_num, size_t region_num,
size_t table_id) { size_t table_id) {
auto *accessor = table_accessor(table_id); auto *accessor = GetTableAccessor(table_id);
int fea_dim = accessor->GetTableInfo(FEA_DIM); int fea_dim = accessor->GetTableInfo(FEA_DIM);
int update_dim = accessor->GetTableInfo(UPDATE_DIM); int update_dim = accessor->GetTableInfo(UPDATE_DIM);
auto push_timer = std::make_shared<CostTimer>("pserver_client_push_dense"); auto push_timer = std::make_shared<CostTimer>("pserver_client_push_dense");
...@@ -1669,7 +1609,7 @@ std::future<int32_t> BrpcPsClient::push_dense(const Region *regions, ...@@ -1669,7 +1609,7 @@ std::future<int32_t> BrpcPsClient::push_dense(const Region *regions,
std::make_shared<CostTimer>("pserver_client_push_dense_parse"); std::make_shared<CostTimer>("pserver_client_push_dense_parse");
int push_dense_async_num = _push_dense_task_queue_map[table_id]->Size(); int push_dense_async_num = _push_dense_task_queue_map[table_id]->Size();
while (push_dense_async_num > FLAGS_pserver_max_async_call_num) { while (push_dense_async_num > FLAGS_pserver_max_async_call_num) {
// LOG(INFO) << "push_dense Waiting for async_call_num comsume, // LOG(INFO) << "PushDense Waiting for async_call_num comsume,
// task_num:" // task_num:"
// << push_dense_async_num // << push_dense_async_num
// << ", max_task_limit:" << FLAGS_pserver_max_async_call_num; // << ", max_task_limit:" << FLAGS_pserver_max_async_call_num;
...@@ -1683,7 +1623,7 @@ std::future<int32_t> BrpcPsClient::push_dense(const Region *regions, ...@@ -1683,7 +1623,7 @@ std::future<int32_t> BrpcPsClient::push_dense(const Region *regions,
size_t request_call_num = _server_channels.size(); size_t request_call_num = _server_channels.size();
uint32_t num_per_shard = uint32_t num_per_shard =
dense_dim_per_shard(accessor->GetTableInfo(FEA_DIM), request_call_num); DenseDimPerShard(accessor->GetTableInfo(FEA_DIM), request_call_num);
// 将region数据拷贝到转置矩阵中 // 将region数据拷贝到转置矩阵中
async_task->data()->resize(num_per_shard * request_call_num * async_task->data()->resize(num_per_shard * request_call_num *
...@@ -1705,7 +1645,7 @@ std::future<int32_t> BrpcPsClient::push_dense(const Region *regions, ...@@ -1705,7 +1645,7 @@ std::future<int32_t> BrpcPsClient::push_dense(const Region *regions,
return fut; return fut;
} }
void BrpcPsClient::push_dense_task_consume() { void BrpcPsClient::PushDenseTaskConsume() {
uint64_t merge_size = FLAGS_pserver_push_dense_merge_limit; uint64_t merge_size = FLAGS_pserver_push_dense_merge_limit;
static bool scale_gradient = FLAGS_pserver_scale_gradient_by_merge; static bool scale_gradient = FLAGS_pserver_scale_gradient_by_merge;
::ThreadPool async_merge_dense_threads(10); ::ThreadPool async_merge_dense_threads(10);
...@@ -1723,7 +1663,7 @@ void BrpcPsClient::push_dense_task_consume() { ...@@ -1723,7 +1663,7 @@ void BrpcPsClient::push_dense_task_consume() {
++_async_call_num; ++_async_call_num;
DenseAsyncTask *task; DenseAsyncTask *task;
task_queue->Get(task); task_queue->Get(task);
auto *accessor = table_accessor(task->table_id()); auto *accessor = GetTableAccessor(task->table_id());
// 设置请求回调 // 设置请求回调
size_t request_call_num = _server_channels.size(); size_t request_call_num = _server_channels.size();
...@@ -1774,7 +1714,7 @@ void BrpcPsClient::push_dense_task_consume() { ...@@ -1774,7 +1714,7 @@ void BrpcPsClient::push_dense_task_consume() {
merge_status[i].wait(); merge_status[i].wait();
} }
VLOG(3) << "BrpcPsClient::push_dense_task_consume before merge " VLOG(3) << "BrpcPsClient::PushDenseTaskConsume before merge "
"total_send_data[0]" "total_send_data[0]"
<< total_send_data[0] << " total_send_data[-2]" << total_send_data[0] << " total_send_data[-2]"
<< total_send_data[total_send_data_size - 2] << total_send_data[total_send_data_size - 2]
...@@ -1787,7 +1727,7 @@ void BrpcPsClient::push_dense_task_consume() { ...@@ -1787,7 +1727,7 @@ void BrpcPsClient::push_dense_task_consume() {
mat *= (1.0 / (merge_count + 1)); mat *= (1.0 / (merge_count + 1));
} }
VLOG(3) << "BrpcPsClient::push_dense_task_consume after merge " VLOG(3) << "BrpcPsClient::PushDenseTaskConsume after merge "
"total_send_data[0]" "total_send_data[0]"
<< total_send_data[0] << " total_send_data[-2]" << total_send_data[0] << " total_send_data[-2]"
<< total_send_data[total_send_data_size - 2] << total_send_data[total_send_data_size - 2]
...@@ -1796,8 +1736,8 @@ void BrpcPsClient::push_dense_task_consume() { ...@@ -1796,8 +1736,8 @@ void BrpcPsClient::push_dense_task_consume() {
<< merge_count; << merge_count;
} }
std::shared_ptr<DenseAsyncTask> task_ptr(task); std::shared_ptr<DenseAsyncTask> task_ptr(task);
push_dense_raw_gradient(task_ptr, total_send_data, total_send_data_size, PushDenseRawGradient(task_ptr, total_send_data, total_send_data_size,
closure); closure);
} }
auto wait_ms = FLAGS_pserver_async_push_dense_interval_ms - auto wait_ms = FLAGS_pserver_async_push_dense_interval_ms -
(butil::gettimeofday_ms() - async_start_time_ms); (butil::gettimeofday_ms() - async_start_time_ms);
...@@ -1807,16 +1747,17 @@ void BrpcPsClient::push_dense_task_consume() { ...@@ -1807,16 +1747,17 @@ void BrpcPsClient::push_dense_task_consume() {
} }
} }
void BrpcPsClient::push_dense_raw_gradient( void BrpcPsClient::PushDenseRawGradient(std::shared_ptr<DenseAsyncTask> &task,
std::shared_ptr<DenseAsyncTask> &task, float *total_send_data, float *total_send_data,
size_t total_send_data_size, DownpourBrpcClosure *closure) { size_t total_send_data_size,
auto *accessor = table_accessor(task->table_id()); DownpourBrpcClosure *closure) {
auto *accessor = GetTableAccessor(task->table_id());
size_t request_call_num = _server_channels.size(); size_t request_call_num = _server_channels.size();
// 将数据拷贝到请求buffer区 // 将数据拷贝到请求buffer区
auto timer = std::make_shared<CostTimer>("pserver_client_push_dense_rpc"); auto timer = std::make_shared<CostTimer>("pserver_client_push_dense_rpc");
closure->add_timer(timer); closure->add_timer(timer);
uint32_t num_per_shard = uint32_t num_per_shard =
dense_dim_per_shard(accessor->GetTableInfo(FEA_DIM), request_call_num); DenseDimPerShard(accessor->GetTableInfo(FEA_DIM), request_call_num);
auto send_timer = auto send_timer =
std::make_shared<CostTimer>("pserver_client_push_dense_send"); std::make_shared<CostTimer>("pserver_client_push_dense_send");
for (size_t i = 0; i < request_call_num; ++i) { for (size_t i = 0; i < request_call_num; ++i) {
...@@ -1832,7 +1773,7 @@ void BrpcPsClient::push_dense_raw_gradient( ...@@ -1832,7 +1773,7 @@ void BrpcPsClient::push_dense_raw_gradient(
total_send_data + i * num_per_shard, num_per_shard * sizeof(float)); total_send_data + i * num_per_shard, num_per_shard * sizeof(float));
closure->cntl(i)->set_request_compress_type( closure->cntl(i)->set_request_compress_type(
(brpc::CompressType)FLAGS_pserver_communicate_compress_type); (brpc::CompressType)FLAGS_pserver_communicate_compress_type);
PsService_Stub rpc_stub(get_dense_channel(i)); PsService_Stub rpc_stub(GetDenseChannel(i));
rpc_stub.service(closure->cntl(i), closure->request(i), rpc_stub.service(closure->cntl(i), closure->request(i),
closure->response(i), closure); closure->response(i), closure);
} }
......
...@@ -50,7 +50,7 @@ class DownpourPsClientService : public PsService { ...@@ -50,7 +50,7 @@ class DownpourPsClientService : public PsService {
DownpourPsClientService() {} DownpourPsClientService() {}
virtual ~DownpourPsClientService() {} virtual ~DownpourPsClientService() {}
virtual int32_t configure(PSClient *client, size_t rank_id) { virtual int32_t Configure(PSClient *client, size_t rank_id) {
_client = client; _client = client;
_rank = rank_id; _rank = rank_id;
return 0; return 0;
...@@ -139,7 +139,7 @@ class BrpcPsClient : public PSClient { ...@@ -139,7 +139,7 @@ class BrpcPsClient : public PSClient {
BrpcPsClient() {} BrpcPsClient() {}
virtual ~BrpcPsClient() { virtual ~BrpcPsClient() {
if (_running) { if (_running) {
flush(); Flush();
_running = false; _running = false;
} }
if (_async_push_dense_thread.joinable()) { if (_async_push_dense_thread.joinable()) {
...@@ -154,109 +154,98 @@ class BrpcPsClient : public PSClient { ...@@ -154,109 +154,98 @@ class BrpcPsClient : public PSClient {
_server_started = false; _server_started = false;
} }
} }
virtual int32_t create_client2client_connection( virtual int32_t CreateClient2ClientConnection(int pserver_timeout_ms,
int pserver_timeout_ms, int pserver_connect_timeout_ms, int max_retry); int pserver_connect_timeout_ms,
std::future<int32_t> shrink(uint32_t table_id, int max_retry);
std::future<int32_t> Shrink(uint32_t table_id,
const std::string threshold) override; const std::string threshold) override;
std::future<int32_t> load(const std::string &epoch, std::future<int32_t> Load(const std::string &epoch,
const std::string &mode) override; const std::string &mode) override;
std::future<int32_t> load(uint32_t table_id, const std::string &epoch, std::future<int32_t> Load(uint32_t table_id, const std::string &epoch,
const std::string &mode) override; const std::string &mode) override;
std::future<int32_t> Load(const LoadSaveContext &load_context) override; std::future<int32_t> Save(const std::string &epoch,
std::future<int32_t> save(const std::string &epoch,
const std::string &mode) override; const std::string &mode) override;
std::future<int32_t> save(uint32_t table_id, const std::string &epoch, std::future<int32_t> Save(uint32_t table_id, const std::string &epoch,
const std::string &mode) override; const std::string &mode) override;
virtual std::future<int32_t> Save( std::future<int32_t> Clear() override;
const LoadSaveContext &save_context) override;
std::future<int32_t> clear() override;
std::future<int32_t> clear(uint32_t table_id) override;
std::future<int32_t> stop_server() override; std::future<int32_t> Clear(uint32_t table_id) override;
std::future<int32_t> start_profiler() override; std::future<int32_t> StopServer() override;
std::future<int32_t> stop_profiler() override;
void finalize_worker() override; std::future<int32_t> StartProfiler() override;
std::future<int32_t> StopProfiler() override;
virtual std::future<int32_t> pull_dense(Region *regions, size_t region_num, void FinalizeWorker() override;
size_t table_id);
virtual std::future<int32_t> push_dense_param(const Region *regions, virtual std::future<int32_t> PullDense(Region *regions, size_t region_num,
size_t region_num, size_t table_id);
size_t table_id);
virtual std::future<int32_t> push_dense(const Region *regions, virtual std::future<int32_t> PushDenseParam(const Region *regions,
size_t region_num, size_t table_id); size_t region_num,
void push_dense_task_consume(); size_t table_id);
virtual std::future<int32_t> pull_sparse(float **select_values,
size_t table_id,
const uint64_t *keys, size_t num,
bool is_training);
virtual std::future<int32_t> pull_sparse_param(float **select_values,
size_t table_id,
const uint64_t *keys,
size_t num, bool is_training);
virtual std::future<int32_t> Pull(RequestContext &pull_context) override; virtual std::future<int32_t> PushDense(const Region *regions,
size_t region_num, size_t table_id);
void PushDenseTaskConsume();
virtual std::future<int32_t> PullSparse(float **select_values,
size_t table_id, const uint64_t *keys,
size_t num, bool is_training);
virtual std::future<int32_t> PullSparseParam(float **select_values,
size_t table_id,
const uint64_t *keys, size_t num,
bool is_training);
virtual std::future<int32_t> Push(RequestContext &push_context) override; virtual std::future<int32_t> PrintTableStat(uint32_t table_id);
virtual std::future<int32_t> print_table_stat(uint32_t table_id); virtual std::future<int32_t> Barrier(size_t table_id, uint32_t barrier_type);
virtual std::future<int32_t> barrier(size_t table_id, uint32_t barrier_type); virtual std::future<int32_t> PullGeoParam(size_t table_id,
std::vector<float> *values,
std::vector<uint64_t> *keys,
int pserver_idx);
virtual std::future<int32_t> PushGlobalStep(int table_id,
int64_t *total_send_data,
void *done);
virtual std::future<int32_t> Flush();
virtual std::future<int32_t> pull_geo_param(size_t table_id, std::future<int32_t> SendClient2ClientMsg(int msg_type, int to_client_id,
std::vector<float> *values, const std::string &msg) override;
std::vector<uint64_t> *keys,
int pserver_idx);
virtual std::future<int32_t> push_global_step(int table_id,
int64_t *total_send_data,
void *done);
virtual std::future<int32_t> flush();
std::future<int32_t> send_client2client_msg(int msg_type, int to_client_id,
const std::string &msg) override;
// for local save sparse // for local save sparse
virtual int32_t recv_and_save_table(const uint64_t table_id, virtual int32_t RecvAndSaveTable(const uint64_t table_id,
const std::string &path); const std::string &path);
void print_queue_size(); void PrintQueueSize();
void print_queue_size_thread(); void PrintQueueSizeThread();
protected: protected:
virtual size_t get_server_nums() { return _server_channels.size(); } virtual size_t GetServerNums() { return _server_channels.size(); }
inline brpc::Channel *get_sparse_channel(size_t server_id) { inline brpc::Channel *GetSparseChannel(size_t server_id) {
return _server_channels[server_id][0].get(); return _server_channels[server_id][0].get();
} }
inline brpc::Channel *get_dense_channel(size_t server_id) { inline brpc::Channel *GetDenseChannel(size_t server_id) {
return _server_channels[server_id][1].get(); return _server_channels[server_id][1].get();
} }
inline brpc::Channel *get_cmd_channel(size_t server_id) { inline brpc::Channel *GetCmdChannel(size_t server_id) {
return _server_channels[server_id][2].get(); return _server_channels[server_id][2].get();
} }
int32_t initialize() override; int32_t Initialize() override;
private: private:
// virtual int32_t initialize() override; inline uint32_t DenseDimPerShard(uint32_t dense_dim_total,
uint32_t shard_num) {
inline uint32_t dense_dim_per_shard(uint32_t dense_dim_total,
uint32_t shard_num) {
return dense_dim_total / shard_num + 1; return dense_dim_total / shard_num + 1;
} }
std::future<int32_t> send_cmd(uint32_t table_id, int cmd_id, std::future<int32_t> SendCmd(uint32_t table_id, int cmd_id,
const std::vector<std::string> &param); const std::vector<std::string> &param);
std::future<int32_t> send_save_cmd(uint32_t table_id, int cmd_id, std::future<int32_t> SendSaveCmd(uint32_t table_id, int cmd_id,
const std::vector<std::string> &param); const std::vector<std::string> &param);
bool _running = false; bool _running = false;
bool _flushing = false; bool _flushing = false;
...@@ -276,12 +265,12 @@ class BrpcPsClient : public PSClient { ...@@ -276,12 +265,12 @@ class BrpcPsClient : public PSClient {
std::thread _print_thread; std::thread _print_thread;
int push_sparse_async_shard_merge( int PushSparseAsyncShardMerge(
std::vector<std::shared_ptr<SparseAsyncTask>> &task_list, // NOLINT std::vector<std::shared_ptr<SparseAsyncTask>> &task_list, // NOLINT
std::vector<int> &request_kv_num, int table_id, int shard_idx, // NOLINT std::vector<int> &request_kv_num, int table_id, int shard_idx, // NOLINT
ValueAccessor *accessor); ValueAccessor *accessor);
int push_sparse_async_shard_push( int PushSparseAsyncShardPush(
std::vector<std::shared_ptr<SparseAsyncTask>> &task_list, // NOLINT std::vector<std::shared_ptr<SparseAsyncTask>> &task_list, // NOLINT
std::vector<int> &request_kv_num, int table_id, int shard_idx, // NOLINT std::vector<int> &request_kv_num, int table_id, int shard_idx, // NOLINT
DownpourBrpcClosure *closure, ValueAccessor *accessor); DownpourBrpcClosure *closure, ValueAccessor *accessor);
...@@ -292,36 +281,36 @@ class BrpcPsClient : public PSClient { ...@@ -292,36 +281,36 @@ class BrpcPsClient : public PSClient {
_client_channels; // client2client _client_channels; // client2client
std::vector<std::array<std::shared_ptr<brpc::Channel>, 3>> std::vector<std::array<std::shared_ptr<brpc::Channel>, 3>>
_server_channels; // client2server _server_channels; // client2server
std::future<int32_t> push_dense_raw_gradient(int table_id, std::future<int32_t> PushDenseRawGradient(int table_id,
float *total_send_data, float *total_send_data,
size_t total_send_data_size, size_t total_send_data_size,
void *done) override; void *done) override;
std::future<int32_t> push_sparse_raw_gradient(size_t table_id, std::future<int32_t> PushSparseRawGradient(size_t table_id,
const uint64_t *keys, const uint64_t *keys,
const float **update_values, const float **update_values,
size_t num, size_t num, void *done) override;
void *done) override;
std::future<int32_t> PushSparseRawGradientPartial(size_t table_id,
std::future<int32_t> push_sparse_raw_gradient_partial( const uint64_t *keys,
size_t table_id, const uint64_t *keys, const float **update_values, const float **update_values,
uint32_t num, void *done, int pserver_idx) override; uint32_t num, void *done,
int pserver_idx) override;
std::future<int32_t> push_sparse_param(size_t table_id, const uint64_t *keys,
const float **update_values, std::future<int32_t> PushSparseParam(size_t table_id, const uint64_t *keys,
size_t num, void *done) override; const float **update_values, size_t num,
std::future<int32_t> push_sparse(size_t table_id, const uint64_t *keys, void *done) override;
const float **update_values, std::future<int32_t> PushSparse(size_t table_id, const uint64_t *keys,
size_t num) override; const float **update_values,
void push_sparse_task_consume(); size_t num) override;
void PushSparseTaskConsume();
private: private:
int32_t start_client_service(); int32_t StartClientService();
void push_dense_raw_gradient(std::shared_ptr<DenseAsyncTask> &task, // NOLINT void PushDenseRawGradient(std::shared_ptr<DenseAsyncTask> &task, // NOLINT
float *total_send_data, float *total_send_data, size_t total_send_data_size,
size_t total_send_data_size, DownpourBrpcClosure *closure);
DownpourBrpcClosure *closure);
float _mae = 0; float _mae = 0;
float _mse = 0; float _mse = 0;
uint16_t _push_times = 0; uint16_t _push_times = 0;
......
...@@ -31,7 +31,7 @@ class RpcController; ...@@ -31,7 +31,7 @@ class RpcController;
namespace paddle { namespace paddle {
namespace distributed { namespace distributed {
int32_t BrpcPsServer::initialize() { int32_t BrpcPsServer::Initialize() {
auto &service_config = _config.downpour_server_param().service_param(); auto &service_config = _config.downpour_server_param().service_param();
if (!service_config.has_service_class()) { if (!service_config.has_service_class()) {
LOG(ERROR) << "miss service_class in ServerServiceParameter"; LOG(ERROR) << "miss service_class in ServerServiceParameter";
...@@ -46,7 +46,7 @@ int32_t BrpcPsServer::initialize() { ...@@ -46,7 +46,7 @@ int32_t BrpcPsServer::initialize() {
} }
_service.reset(service); _service.reset(service);
if (service->configure(this) != 0 || service->initialize() != 0) { if (service->Configure(this) != 0 || service->Initialize() != 0) {
LOG(ERROR) << "service initialize failed, service_name:" LOG(ERROR) << "service initialize failed, service_name:"
<< service_config.service_class(); << service_config.service_class();
return -1; return -1;
...@@ -59,7 +59,7 @@ int32_t BrpcPsServer::initialize() { ...@@ -59,7 +59,7 @@ int32_t BrpcPsServer::initialize() {
return 0; return 0;
} }
uint64_t BrpcPsServer::start(const std::string &ip, uint32_t port) { uint64_t BrpcPsServer::Start(const std::string &ip, uint32_t port) {
std::unique_lock<std::mutex> lock(mutex_); std::unique_lock<std::mutex> lock(mutex_);
std::string ip_port = ip + ":" + std::to_string(port); std::string ip_port = ip + ":" + std::to_string(port);
...@@ -68,7 +68,7 @@ uint64_t BrpcPsServer::start(const std::string &ip, uint32_t port) { ...@@ -68,7 +68,7 @@ uint64_t BrpcPsServer::start(const std::string &ip, uint32_t port) {
brpc::ServerOptions options; brpc::ServerOptions options;
int num_threads = std::thread::hardware_concurrency(); int num_threads = std::thread::hardware_concurrency();
auto trainers = _environment->get_trainers(); auto trainers = _environment->GetTrainers();
options.num_threads = trainers > num_threads ? trainers : num_threads; options.num_threads = trainers > num_threads ? trainers : num_threads;
if (_server.Start(ip_port.c_str(), &options) != 0) { if (_server.Start(ip_port.c_str(), &options) != 0) {
...@@ -83,7 +83,7 @@ uint64_t BrpcPsServer::start(const std::string &ip, uint32_t port) { ...@@ -83,7 +83,7 @@ uint64_t BrpcPsServer::start(const std::string &ip, uint32_t port) {
} }
} }
_environment->registe_ps_server(ip, port, _rank); _environment->RegistePsServer(ip, port, _rank);
cv_.wait(lock, [&] { return stoped_; }); cv_.wait(lock, [&] { return stoped_; });
PSHost host; PSHost host;
...@@ -93,31 +93,30 @@ uint64_t BrpcPsServer::start(const std::string &ip, uint32_t port) { ...@@ -93,31 +93,30 @@ uint64_t BrpcPsServer::start(const std::string &ip, uint32_t port) {
return host.rank; return host.rank;
} }
int32_t BrpcPsServer::port() { return _server.listen_address().port; } int32_t BrpcPsServer::Port() { return _server.listen_address().port; }
int32_t BrpcPsService::initialize() { int32_t BrpcPsService::Initialize() {
_is_initialize_shard_info = false; _is_initialize_shard_info = false;
_service_handler_map[PS_STOP_SERVER] = &BrpcPsService::stop_server; _service_handler_map[PS_STOP_SERVER] = &BrpcPsService::StopServer;
_service_handler_map[PS_PULL_DENSE_TABLE] = &BrpcPsService::pull_dense; _service_handler_map[PS_PULL_DENSE_TABLE] = &BrpcPsService::PullDense;
_service_handler_map[PS_PUSH_DENSE_TABLE] = &BrpcPsService::push_dense; _service_handler_map[PS_PUSH_DENSE_TABLE] = &BrpcPsService::PushDense;
_service_handler_map[PS_PULL_SPARSE_TABLE] = &BrpcPsService::pull_sparse; _service_handler_map[PS_PULL_SPARSE_TABLE] = &BrpcPsService::PullSparse;
_service_handler_map[PS_PUSH_SPARSE_TABLE] = &BrpcPsService::push_sparse; _service_handler_map[PS_PUSH_SPARSE_TABLE] = &BrpcPsService::PushSparse;
_service_handler_map[PS_SAVE_ONE_TABLE] = &BrpcPsService::save_one_table; _service_handler_map[PS_SAVE_ONE_TABLE] = &BrpcPsService::SaveOneTable;
_service_handler_map[PS_SAVE_ALL_TABLE] = &BrpcPsService::save_all_table; _service_handler_map[PS_SAVE_ALL_TABLE] = &BrpcPsService::SaveAllTable;
_service_handler_map[PS_SHRINK_TABLE] = &BrpcPsService::shrink_table; _service_handler_map[PS_SHRINK_TABLE] = &BrpcPsService::ShrinkTable;
_service_handler_map[PS_LOAD_ONE_TABLE] = &BrpcPsService::load_one_table; _service_handler_map[PS_LOAD_ONE_TABLE] = &BrpcPsService::LoadOneTable;
_service_handler_map[PS_LOAD_ALL_TABLE] = &BrpcPsService::load_all_table; _service_handler_map[PS_LOAD_ALL_TABLE] = &BrpcPsService::LoadAllTable;
_service_handler_map[PS_CLEAR_ONE_TABLE] = &BrpcPsService::clear_one_table; _service_handler_map[PS_CLEAR_ONE_TABLE] = &BrpcPsService::ClearOneTable;
_service_handler_map[PS_CLEAR_ALL_TABLE] = &BrpcPsService::clear_all_table; _service_handler_map[PS_CLEAR_ALL_TABLE] = &BrpcPsService::ClearAllTable;
_service_handler_map[PS_PUSH_DENSE_PARAM] = &BrpcPsService::push_dense_param; _service_handler_map[PS_PUSH_DENSE_PARAM] = &BrpcPsService::PushDenseParam;
_service_handler_map[PS_PRINT_TABLE_STAT] = &BrpcPsService::print_table_stat; _service_handler_map[PS_PRINT_TABLE_STAT] = &BrpcPsService::PrintTableStat;
_service_handler_map[PS_PULL_GEO_PARAM] = &BrpcPsService::pull_geo_param; _service_handler_map[PS_PULL_GEO_PARAM] = &BrpcPsService::PullGeoParam;
_service_handler_map[PS_PUSH_SPARSE_PARAM] = _service_handler_map[PS_PUSH_SPARSE_PARAM] = &BrpcPsService::PushSparseParam;
&BrpcPsService::push_sparse_param; _service_handler_map[PS_BARRIER] = &BrpcPsService::Barrier;
_service_handler_map[PS_BARRIER] = &BrpcPsService::barrier; _service_handler_map[PS_START_PROFILER] = &BrpcPsService::StartProfiler;
_service_handler_map[PS_START_PROFILER] = &BrpcPsService::start_profiler; _service_handler_map[PS_STOP_PROFILER] = &BrpcPsService::StopProfiler;
_service_handler_map[PS_STOP_PROFILER] = &BrpcPsService::stop_profiler; _service_handler_map[PS_PUSH_GLOBAL_STEP] = &BrpcPsService::PushGlobalStep;
_service_handler_map[PS_PUSH_GLOBAL_STEP] = &BrpcPsService::push_global_step;
auto &profiler = CostProfiler::instance(); auto &profiler = CostProfiler::instance();
profiler.register_profiler("pserver_server_pull_dense"); profiler.register_profiler("pserver_server_pull_dense");
profiler.register_profiler("pserver_server_push_dense"); profiler.register_profiler("pserver_server_push_dense");
...@@ -125,7 +124,7 @@ int32_t BrpcPsService::initialize() { ...@@ -125,7 +124,7 @@ int32_t BrpcPsService::initialize() {
profiler.register_profiler("pserver_server_push_sparse"); profiler.register_profiler("pserver_server_push_sparse");
// shard初始化,server启动后才可从env获取到server_list的shard信息 // shard初始化,server启动后才可从env获取到server_list的shard信息
initialize_shard_info(); InitializeShardInfo();
return 0; return 0;
} }
...@@ -138,16 +137,16 @@ int32_t BrpcPsService::initialize() { ...@@ -138,16 +137,16 @@ int32_t BrpcPsService::initialize() {
return -1; \ return -1; \
} }
int32_t BrpcPsService::initialize_shard_info() { int32_t BrpcPsService::InitializeShardInfo() {
if (!_is_initialize_shard_info) { if (!_is_initialize_shard_info) {
std::lock_guard<std::mutex> guard(_initialize_shard_mutex); std::lock_guard<std::mutex> guard(_initialize_shard_mutex);
if (_is_initialize_shard_info) { if (_is_initialize_shard_info) {
return 0; return 0;
} }
size_t shard_num = _server->environment()->get_ps_servers().size(); size_t shard_num = _server->Environment()->GetPsServers().size();
auto &table_map = *(_server->table()); auto &table_map = *(_server->GetTable());
for (auto itr : table_map) { for (auto itr : table_map) {
itr.second->set_shard(_rank, shard_num); itr.second->SetShard(_rank, shard_num);
} }
_is_initialize_shard_info = true; _is_initialize_shard_info = true;
} }
...@@ -167,7 +166,7 @@ void BrpcPsService::service(google::protobuf::RpcController *cntl_base, ...@@ -167,7 +166,7 @@ void BrpcPsService::service(google::protobuf::RpcController *cntl_base,
response->set_err_code(0); response->set_err_code(0);
response->set_err_msg(""); response->set_err_msg("");
auto *table = _server->table(request->table_id()); auto *table = _server->GetTable(request->table_id());
brpc::Controller *cntl = static_cast<brpc::Controller *>(cntl_base); brpc::Controller *cntl = static_cast<brpc::Controller *>(cntl_base);
auto itr = _service_handler_map.find(request->cmd_id()); auto itr = _service_handler_map.find(request->cmd_id());
if (itr == _service_handler_map.end()) { if (itr == _service_handler_map.end()) {
...@@ -185,11 +184,11 @@ void BrpcPsService::service(google::protobuf::RpcController *cntl_base, ...@@ -185,11 +184,11 @@ void BrpcPsService::service(google::protobuf::RpcController *cntl_base,
} }
} }
int32_t BrpcPsService::pull_dense(Table *table, const PsRequestMessage &request, int32_t BrpcPsService::PullDense(Table *table, const PsRequestMessage &request,
PsResponseMessage &response, PsResponseMessage &response,
brpc::Controller *cntl) { brpc::Controller *cntl) {
platform::RecordEvent record_event( platform::RecordEvent record_event(
"PsService->pull_dense", platform::TracerEventType::Communication, 1); "PsService->PullDense", platform::TracerEventType::Communication, 1);
CHECK_TABLE_EXIST(table, request, response) CHECK_TABLE_EXIST(table, request, response)
if (request.params_size() < 1) { if (request.params_size() < 1) {
set_response_code( set_response_code(
...@@ -206,14 +205,15 @@ int32_t BrpcPsService::pull_dense(Table *table, const PsRequestMessage &request, ...@@ -206,14 +205,15 @@ int32_t BrpcPsService::pull_dense(Table *table, const PsRequestMessage &request,
} }
auto res_data = butil::get_object<std::vector<float>>(); auto res_data = butil::get_object<std::vector<float>>();
res_data->resize(num * table->value_accesor()->GetTableInfo(SELECT_SIZE) / res_data->resize(num * table->ValueAccesor()->GetTableInfo(SELECT_SIZE) /
sizeof(float)); sizeof(float));
TableContext table_context; TableContext table_context;
table_context.value_type = Dense; table_context.value_type = Dense;
table_context.pull_context.values = res_data->data(); table_context.pull_context.values = res_data->data();
table_context.num = num; table_context.num = num;
table->Pull(table_context); table->Pull(table_context);
// table->pull_dense(res_data->data(), num); // table->PullDense(res_data->data(), num);
cntl->response_attachment().append((char *)(res_data->data()), cntl->response_attachment().append((char *)(res_data->data()),
res_data->size() * sizeof(float)); res_data->size() * sizeof(float));
...@@ -222,13 +222,12 @@ int32_t BrpcPsService::pull_dense(Table *table, const PsRequestMessage &request, ...@@ -222,13 +222,12 @@ int32_t BrpcPsService::pull_dense(Table *table, const PsRequestMessage &request,
return 0; return 0;
} }
int32_t BrpcPsService::push_dense_param(Table *table, int32_t BrpcPsService::PushDenseParam(Table *table,
const PsRequestMessage &request, const PsRequestMessage &request,
PsResponseMessage &response, PsResponseMessage &response,
brpc::Controller *cntl) { brpc::Controller *cntl) {
platform::RecordEvent record_event("PsService->push_dense_param", platform::RecordEvent record_event(
platform::TracerEventType::Communication, "PsService->PushDenseParam", platform::TracerEventType::Communication, 1);
1);
CHECK_TABLE_EXIST(table, request, response) CHECK_TABLE_EXIST(table, request, response)
thread_local std::string push_buffer; thread_local std::string push_buffer;
auto &req_io_buffer = cntl->request_attachment(); auto &req_io_buffer = cntl->request_attachment();
...@@ -245,17 +244,17 @@ int32_t BrpcPsService::push_dense_param(Table *table, ...@@ -245,17 +244,17 @@ int32_t BrpcPsService::push_dense_param(Table *table,
uint32_t num = *(const uint32_t *)data; uint32_t num = *(const uint32_t *)data;
const float *values = (const float *)(data + sizeof(uint32_t)); const float *values = (const float *)(data + sizeof(uint32_t));
if (table->push_dense_param(values, num) != 0) { if (table->PushDenseParam(values, num) != 0) {
set_response_code(response, -1, "push_dense_param failed"); set_response_code(response, -1, "PushDenseParam failed");
} }
return 0; return 0;
} }
int32_t BrpcPsService::push_dense(Table *table, const PsRequestMessage &request, int32_t BrpcPsService::PushDense(Table *table, const PsRequestMessage &request,
PsResponseMessage &response, PsResponseMessage &response,
brpc::Controller *cntl) { brpc::Controller *cntl) {
platform::RecordEvent record_event( platform::RecordEvent record_event(
"PsService->push_dense", platform::TracerEventType::Communication, 1); "PsService->PushDense", platform::TracerEventType::Communication, 1);
CHECK_TABLE_EXIST(table, request, response) CHECK_TABLE_EXIST(table, request, response)
auto req_buffer_size = request.data().size(); auto req_buffer_size = request.data().size();
if (req_buffer_size < 1) { if (req_buffer_size < 1) {
...@@ -278,14 +277,14 @@ int32_t BrpcPsService::push_dense(Table *table, const PsRequestMessage &request, ...@@ -278,14 +277,14 @@ int32_t BrpcPsService::push_dense(Table *table, const PsRequestMessage &request,
// const float *values = (const float *)(request.data().data() + // const float *values = (const float *)(request.data().data() +
// sizeof(uint32_t)); // sizeof(uint32_t));
if (table->Push(table_context) != 0) { if (table->Push(table_context) != 0) {
// if (table->push_dense(values, num) != 0) { // if (table->PushDense(values, num) != 0) {
set_response_code(response, -1, "push_dense failed"); set_response_code(response, -1, "PushDense failed");
} }
return 0; return 0;
} }
int32_t BrpcPsService::barrier(Table *table, const PsRequestMessage &request, int32_t BrpcPsService::Barrier(Table *table, const PsRequestMessage &request,
PsResponseMessage &response, PsResponseMessage &response,
brpc::Controller *cntl) { brpc::Controller *cntl) {
CHECK_TABLE_EXIST(table, request, response) CHECK_TABLE_EXIST(table, request, response)
...@@ -299,15 +298,15 @@ int32_t BrpcPsService::barrier(Table *table, const PsRequestMessage &request, ...@@ -299,15 +298,15 @@ int32_t BrpcPsService::barrier(Table *table, const PsRequestMessage &request,
auto trainer_id = request.client_id(); auto trainer_id = request.client_id();
auto barrier_type = request.params(0); auto barrier_type = request.params(0);
table->barrier(trainer_id, barrier_type); table->Barrier(trainer_id, barrier_type);
return 0; return 0;
} }
int32_t BrpcPsService::push_sparse_param(Table *table, int32_t BrpcPsService::PushSparseParam(Table *table,
const PsRequestMessage &request, const PsRequestMessage &request,
PsResponseMessage &response, PsResponseMessage &response,
brpc::Controller *cntl) { brpc::Controller *cntl) {
platform::RecordEvent record_event("PsService->push_sparse_param", platform::RecordEvent record_event("PsService->PushSparseParam",
platform::TracerEventType::Communication, platform::TracerEventType::Communication,
1); 1);
CHECK_TABLE_EXIST(table, request, response) CHECK_TABLE_EXIST(table, request, response)
...@@ -331,16 +330,16 @@ int32_t BrpcPsService::push_sparse_param(Table *table, ...@@ -331,16 +330,16 @@ int32_t BrpcPsService::push_sparse_param(Table *table,
const uint64_t *keys = (const uint64_t *)push_data.data(); const uint64_t *keys = (const uint64_t *)push_data.data();
const float *values = const float *values =
(const float *)(push_data.data() + sizeof(uint64_t) * num); (const float *)(push_data.data() + sizeof(uint64_t) * num);
if (table->push_sparse_param(keys, values, num) != 0) { if (table->PushSparseParam(keys, values, num) != 0) {
set_response_code(response, -1, "push_sparse_param error"); set_response_code(response, -1, "PushSparseParam error");
} }
return 0; return 0;
} }
int32_t BrpcPsService::pull_geo_param(Table *table, int32_t BrpcPsService::PullGeoParam(Table *table,
const PsRequestMessage &request, const PsRequestMessage &request,
PsResponseMessage &response, PsResponseMessage &response,
brpc::Controller *cntl) { brpc::Controller *cntl) {
platform::RecordEvent record_event( platform::RecordEvent record_event(
"PsService->pull_geo_param", platform::TracerEventType::Communication, 1); "PsService->pull_geo_param", platform::TracerEventType::Communication, 1);
CHECK_TABLE_EXIST(table, request, response) CHECK_TABLE_EXIST(table, request, response)
...@@ -350,7 +349,7 @@ int32_t BrpcPsService::pull_geo_param(Table *table, ...@@ -350,7 +349,7 @@ int32_t BrpcPsService::pull_geo_param(Table *table,
std::vector<float> values; std::vector<float> values;
std::vector<uint64_t> ids; std::vector<uint64_t> ids;
table->pull_geo_param(trainer_id, &values, &ids); table->PullGeoParam(trainer_id, &values, &ids);
uint32_t num = ids.size(); uint32_t num = ids.size();
cntl->response_attachment().append((char *)(&num), sizeof(uint32_t)); cntl->response_attachment().append((char *)(&num), sizeof(uint32_t));
...@@ -361,12 +360,11 @@ int32_t BrpcPsService::pull_geo_param(Table *table, ...@@ -361,12 +360,11 @@ int32_t BrpcPsService::pull_geo_param(Table *table,
return 0; return 0;
} }
int32_t BrpcPsService::pull_sparse(Table *table, int32_t BrpcPsService::PullSparse(Table *table, const PsRequestMessage &request,
const PsRequestMessage &request, PsResponseMessage &response,
PsResponseMessage &response, brpc::Controller *cntl) {
brpc::Controller *cntl) {
platform::RecordEvent record_event( platform::RecordEvent record_event(
"PsService->pull_sparse", platform::TracerEventType::Communication, 1); "PsService->PullSparse", platform::TracerEventType::Communication, 1);
CHECK_TABLE_EXIST(table, request, response) CHECK_TABLE_EXIST(table, request, response)
auto &req_io_buffer = cntl->request_attachment(); auto &req_io_buffer = cntl->request_attachment();
...@@ -386,7 +384,7 @@ int32_t BrpcPsService::pull_sparse(Table *table, ...@@ -386,7 +384,7 @@ int32_t BrpcPsService::pull_sparse(Table *table,
CostTimer timer("pserver_server_pull_sparse"); CostTimer timer("pserver_server_pull_sparse");
uint32_t num = *(uint32_t *)(request.params(0).c_str()); uint32_t num = *(uint32_t *)(request.params(0).c_str());
auto dim = table->value_accesor()->GetTableInfo(SELECT_DIM); auto dim = table->ValueAccesor()->GetTableInfo(SELECT_DIM);
thread_local std::string req_buffer; thread_local std::string req_buffer;
req_buffer.reserve(req_buffer_size); req_buffer.reserve(req_buffer_size);
...@@ -405,7 +403,7 @@ int32_t BrpcPsService::pull_sparse(Table *table, ...@@ -405,7 +403,7 @@ int32_t BrpcPsService::pull_sparse(Table *table,
table_context.pull_context.pull_value = value; table_context.pull_context.pull_value = value;
table_context.pull_context.values = res_data->data(); table_context.pull_context.values = res_data->data();
table->Pull(table_context); table->Pull(table_context);
// table->pull_sparse(res_data->data(), value); // table->PullSparse(res_data->data(), value);
cntl->response_attachment().append((char *)(res_data->data()), cntl->response_attachment().append((char *)(res_data->data()),
res_data->size() * sizeof(float)); res_data->size() * sizeof(float));
...@@ -413,12 +411,11 @@ int32_t BrpcPsService::pull_sparse(Table *table, ...@@ -413,12 +411,11 @@ int32_t BrpcPsService::pull_sparse(Table *table,
return 0; return 0;
} }
int32_t BrpcPsService::push_sparse(Table *table, int32_t BrpcPsService::PushSparse(Table *table, const PsRequestMessage &request,
const PsRequestMessage &request, PsResponseMessage &response,
PsResponseMessage &response, brpc::Controller *cntl) {
brpc::Controller *cntl) {
platform::RecordEvent record_event( platform::RecordEvent record_event(
"PsService->push_sparse", platform::TracerEventType::Communication, 1); "PsService->PushSparse", platform::TracerEventType::Communication, 1);
CHECK_TABLE_EXIST(table, request, response) CHECK_TABLE_EXIST(table, request, response)
auto &push_data = request.data(); auto &push_data = request.data();
if (push_data.size() < 1) { if (push_data.size() < 1) {
...@@ -448,18 +445,18 @@ int32_t BrpcPsService::push_sparse(Table *table, ...@@ -448,18 +445,18 @@ int32_t BrpcPsService::push_sparse(Table *table,
// const float *values = (const float *)(push_data.data() + sizeof(uint64_t) * // const float *values = (const float *)(push_data.data() + sizeof(uint64_t) *
// num); // num);
if (table->Push(table_context) != 0) { if (table->Push(table_context) != 0) {
// if (table->push_sparse(keys, values, num) != 0) { // if (table->PushSparse(keys, values, num) != 0) {
set_response_code(response, -1, "push_sparse error"); set_response_code(response, -1, "PushSparse error");
} }
return 0; return 0;
} }
int32_t BrpcPsService::print_table_stat(Table *table, int32_t BrpcPsService::PrintTableStat(Table *table,
const PsRequestMessage &request, const PsRequestMessage &request,
PsResponseMessage &response, PsResponseMessage &response,
brpc::Controller *cntl) { brpc::Controller *cntl) {
CHECK_TABLE_EXIST(table, request, response) CHECK_TABLE_EXIST(table, request, response)
std::pair<int64_t, int64_t> ret = table->print_table_stat(); std::pair<int64_t, int64_t> ret = table->PrintTableStat();
paddle::framework::BinaryArchive ar; paddle::framework::BinaryArchive ar;
ar << ret.first << ret.second; ar << ret.first << ret.second;
std::string table_info(ar.Buffer(), ar.Length()); std::string table_info(ar.Buffer(), ar.Length());
...@@ -468,10 +465,10 @@ int32_t BrpcPsService::print_table_stat(Table *table, ...@@ -468,10 +465,10 @@ int32_t BrpcPsService::print_table_stat(Table *table,
return 0; return 0;
} }
int32_t BrpcPsService::load_one_table(Table *table, int32_t BrpcPsService::LoadOneTable(Table *table,
const PsRequestMessage &request, const PsRequestMessage &request,
PsResponseMessage &response, PsResponseMessage &response,
brpc::Controller *cntl) { brpc::Controller *cntl) {
CHECK_TABLE_EXIST(table, request, response) CHECK_TABLE_EXIST(table, request, response)
if (request.params_size() < 2) { if (request.params_size() < 2) {
set_response_code( set_response_code(
...@@ -479,20 +476,20 @@ int32_t BrpcPsService::load_one_table(Table *table, ...@@ -479,20 +476,20 @@ int32_t BrpcPsService::load_one_table(Table *table,
"PsRequestMessage.datas is requeired at least 2 for path & load_param"); "PsRequestMessage.datas is requeired at least 2 for path & load_param");
return -1; return -1;
} }
if (table->load(request.params(0), request.params(1)) != 0) { if (table->Load(request.params(0), request.params(1)) != 0) {
set_response_code(response, -1, "table load failed"); set_response_code(response, -1, "table load failed");
return -1; return -1;
} }
return 0; return 0;
} }
int32_t BrpcPsService::load_all_table(Table *table, int32_t BrpcPsService::LoadAllTable(Table *table,
const PsRequestMessage &request, const PsRequestMessage &request,
PsResponseMessage &response, PsResponseMessage &response,
brpc::Controller *cntl) { brpc::Controller *cntl) {
auto &table_map = *(_server->table()); auto &table_map = *(_server->GetTable());
for (auto &itr : table_map) { for (auto &itr : table_map) {
if (load_one_table(itr.second.get(), request, response, cntl) != 0) { if (LoadOneTable(itr.second.get(), request, response, cntl) != 0) {
LOG(ERROR) << "load table[" << itr.first << "] failed"; LOG(ERROR) << "load table[" << itr.first << "] failed";
return -1; return -1;
} }
...@@ -500,10 +497,10 @@ int32_t BrpcPsService::load_all_table(Table *table, ...@@ -500,10 +497,10 @@ int32_t BrpcPsService::load_all_table(Table *table,
return 0; return 0;
} }
int32_t BrpcPsService::save_one_table(Table *table, int32_t BrpcPsService::SaveOneTable(Table *table,
const PsRequestMessage &request, const PsRequestMessage &request,
PsResponseMessage &response, PsResponseMessage &response,
brpc::Controller *cntl) { brpc::Controller *cntl) {
CHECK_TABLE_EXIST(table, request, response) CHECK_TABLE_EXIST(table, request, response)
if (request.params_size() < 2) { if (request.params_size() < 2) {
set_response_code( set_response_code(
...@@ -511,12 +508,12 @@ int32_t BrpcPsService::save_one_table(Table *table, ...@@ -511,12 +508,12 @@ int32_t BrpcPsService::save_one_table(Table *table,
"PsRequestMessage.datas is requeired at least 2, path&mode"); "PsRequestMessage.datas is requeired at least 2, path&mode");
return -1; return -1;
} }
table->flush(); table->Flush();
int32_t feasign_size = 0; int32_t feasign_size = 0;
VLOG(3) << "save table " << request.params(0) << " " << request.params(1); VLOG(3) << "save table " << request.params(0) << " " << request.params(1);
feasign_size = table->save(request.params(0), request.params(1)); feasign_size = table->Save(request.params(0), request.params(1));
if (feasign_size < 0) { if (feasign_size < 0) {
set_response_code(response, -1, "table save failed"); set_response_code(response, -1, "table save failed");
return -1; return -1;
...@@ -524,16 +521,16 @@ int32_t BrpcPsService::save_one_table(Table *table, ...@@ -524,16 +521,16 @@ int32_t BrpcPsService::save_one_table(Table *table,
return feasign_size; return feasign_size;
} }
int32_t BrpcPsService::save_all_table(Table *table, int32_t BrpcPsService::SaveAllTable(Table *table,
const PsRequestMessage &request, const PsRequestMessage &request,
PsResponseMessage &response, PsResponseMessage &response,
brpc::Controller *cntl) { brpc::Controller *cntl) {
auto &table_map = *(_server->table()); auto &table_map = *(_server->GetTable());
int32_t all_feasign_size = 0; int32_t all_feasign_size = 0;
int32_t feasign_size = 0; int32_t feasign_size = 0;
for (auto &itr : table_map) { for (auto &itr : table_map) {
feasign_size = save_one_table(itr.second.get(), request, response, cntl); feasign_size = SaveOneTable(itr.second.get(), request, response, cntl);
if (feasign_size < 0) { if (feasign_size < 0) {
LOG(ERROR) << "save table[" << itr.first << "] failed"; LOG(ERROR) << "save table[" << itr.first << "] failed";
return -1; return -1;
...@@ -542,10 +539,10 @@ int32_t BrpcPsService::save_all_table(Table *table, ...@@ -542,10 +539,10 @@ int32_t BrpcPsService::save_all_table(Table *table,
return 0; return 0;
} }
int32_t BrpcPsService::shrink_table(Table *table, int32_t BrpcPsService::ShrinkTable(Table *table,
const PsRequestMessage &request, const PsRequestMessage &request,
PsResponseMessage &response, PsResponseMessage &response,
brpc::Controller *cntl) { brpc::Controller *cntl) {
CHECK_TABLE_EXIST(table, request, response) CHECK_TABLE_EXIST(table, request, response)
if (request.params_size() < 1) { if (request.params_size() < 1) {
set_response_code( set_response_code(
...@@ -553,8 +550,8 @@ int32_t BrpcPsService::shrink_table(Table *table, ...@@ -553,8 +550,8 @@ int32_t BrpcPsService::shrink_table(Table *table,
"PsRequestMessage.datas is requeired at least 1, threshold"); "PsRequestMessage.datas is requeired at least 1, threshold");
return -1; return -1;
} }
table->flush(); table->Flush();
if (table->shrink(request.params(0)) != 0) { if (table->Shrink(request.params(0)) != 0) {
set_response_code(response, -1, "table shrink failed"); set_response_code(response, -1, "table shrink failed");
return -1; return -1;
} }
...@@ -562,63 +559,62 @@ int32_t BrpcPsService::shrink_table(Table *table, ...@@ -562,63 +559,62 @@ int32_t BrpcPsService::shrink_table(Table *table,
return 0; return 0;
} }
int32_t BrpcPsService::clear_one_table(Table *table, int32_t BrpcPsService::ClearOneTable(Table *table,
const PsRequestMessage &request, const PsRequestMessage &request,
PsResponseMessage &response, PsResponseMessage &response,
brpc::Controller *cntl) { brpc::Controller *cntl) {
CHECK_TABLE_EXIST(table, request, response) CHECK_TABLE_EXIST(table, request, response)
table->flush(); table->Flush();
table->clear(); table->Clear();
return 0; return 0;
} }
int32_t BrpcPsService::clear_all_table(Table *table, int32_t BrpcPsService::ClearAllTable(Table *table,
const PsRequestMessage &request, const PsRequestMessage &request,
PsResponseMessage &response, PsResponseMessage &response,
brpc::Controller *cntl) { brpc::Controller *cntl) {
auto &table_map = *(_server->table()); auto &table_map = *(_server->GetTable());
for (auto &itr : table_map) { for (auto &itr : table_map) {
if (clear_one_table(itr.second.get(), request, response, cntl) != 0) { if (ClearOneTable(itr.second.get(), request, response, cntl) != 0) {
return -1; return -1;
} }
} }
return 0; return 0;
} }
int32_t BrpcPsService::stop_server(Table *table, int32_t BrpcPsService::StopServer(Table *table, const PsRequestMessage &request,
const PsRequestMessage &request, PsResponseMessage &response,
PsResponseMessage &response, brpc::Controller *cntl) {
brpc::Controller *cntl) {
auto *p_server = _server; auto *p_server = _server;
std::thread t_stop([p_server]() { std::thread t_stop([p_server]() {
p_server->stop(); p_server->Stop();
VLOG(3) << "Server Stoped"; VLOG(3) << "Server Stoped";
}); });
t_stop.detach(); t_stop.detach();
return 0; return 0;
} }
int32_t BrpcPsService::stop_profiler(Table *table, int32_t BrpcPsService::StopProfiler(Table *table,
const PsRequestMessage &request, const PsRequestMessage &request,
PsResponseMessage &response, PsResponseMessage &response,
brpc::Controller *cntl) { brpc::Controller *cntl) {
platform::DisableProfiler(platform::EventSortingKey::kDefault, platform::DisableProfiler(platform::EventSortingKey::kDefault,
string::Sprintf("server_%s_profile", _rank)); string::Sprintf("server_%s_profile", _rank));
return 0; return 0;
} }
int32_t BrpcPsService::start_profiler(Table *table, int32_t BrpcPsService::StartProfiler(Table *table,
const PsRequestMessage &request, const PsRequestMessage &request,
PsResponseMessage &response, PsResponseMessage &response,
brpc::Controller *cntl) { brpc::Controller *cntl) {
platform::EnableProfiler(platform::ProfilerState::kCPU); platform::EnableProfiler(platform::ProfilerState::kCPU);
return 0; return 0;
} }
int32_t BrpcPsService::push_global_step(Table *table, int32_t BrpcPsService::PushGlobalStep(Table *table,
const PsRequestMessage &request, const PsRequestMessage &request,
PsResponseMessage &response, PsResponseMessage &response,
brpc::Controller *cntl) { brpc::Controller *cntl) {
CHECK_TABLE_EXIST(table, request, response); CHECK_TABLE_EXIST(table, request, response);
auto req_buffer_size = request.data().size(); auto req_buffer_size = request.data().size();
if (req_buffer_size < 1) { if (req_buffer_size < 1) {
...@@ -629,7 +625,7 @@ int32_t BrpcPsService::push_global_step(Table *table, ...@@ -629,7 +625,7 @@ int32_t BrpcPsService::push_global_step(Table *table,
const int64_t *values = const int64_t *values =
(const int64_t *)(request.data().data() + sizeof(uint32_t)); (const int64_t *)(request.data().data() + sizeof(uint32_t));
auto trainer_id = request.client_id(); auto trainer_id = request.client_id();
if (table->push_dense(values, trainer_id) != 0) { if (table->PushDense(values, trainer_id) != 0) {
set_response_code(response, -1, "run_program failed"); set_response_code(response, -1, "run_program failed");
} }
......
...@@ -41,8 +41,8 @@ class BrpcPsServer : public PSServer { ...@@ -41,8 +41,8 @@ class BrpcPsServer : public PSServer {
public: public:
BrpcPsServer() {} BrpcPsServer() {}
virtual ~BrpcPsServer() {} virtual ~BrpcPsServer() {}
virtual uint64_t start(const std::string &ip, uint32_t port); virtual uint64_t Start(const std::string &ip, uint32_t port);
virtual int32_t stop() { virtual int32_t Stop() {
std::unique_lock<std::mutex> lock(mutex_); std::unique_lock<std::mutex> lock(mutex_);
stoped_ = true; stoped_ = true;
cv_.notify_all(); cv_.notify_all();
...@@ -51,10 +51,10 @@ class BrpcPsServer : public PSServer { ...@@ -51,10 +51,10 @@ class BrpcPsServer : public PSServer {
_server.Join(); _server.Join();
return 0; return 0;
} }
int32_t port(); int32_t Port();
private: private:
virtual int32_t initialize(); virtual int32_t Initialize();
mutable std::mutex mutex_; mutable std::mutex mutex_;
std::condition_variable cv_; std::condition_variable cv_;
bool stoped_ = false; bool stoped_ = false;
...@@ -71,7 +71,7 @@ typedef int32_t (BrpcPsService::*serviceHandlerFunc)( ...@@ -71,7 +71,7 @@ typedef int32_t (BrpcPsService::*serviceHandlerFunc)(
class BrpcPsService : public PsBaseService { class BrpcPsService : public PsBaseService {
public: public:
virtual int32_t initialize() override; virtual int32_t Initialize() override;
virtual void service(::google::protobuf::RpcController *controller, virtual void service(::google::protobuf::RpcController *controller,
const PsRequestMessage *request, const PsRequestMessage *request,
...@@ -79,50 +79,49 @@ class BrpcPsService : public PsBaseService { ...@@ -79,50 +79,49 @@ class BrpcPsService : public PsBaseService {
::google::protobuf::Closure *done) override; ::google::protobuf::Closure *done) override;
private: private:
int32_t initialize_shard_info(); int32_t InitializeShardInfo();
int32_t pull_dense(Table *table, const PsRequestMessage &request, int32_t PullDense(Table *table, const PsRequestMessage &request,
PsResponseMessage &response, brpc::Controller *cntl); PsResponseMessage &response, brpc::Controller *cntl);
int32_t push_dense(Table *table, const PsRequestMessage &request, int32_t PushDense(Table *table, const PsRequestMessage &request,
PsResponseMessage &response, brpc::Controller *cntl); PsResponseMessage &response, brpc::Controller *cntl);
int32_t push_dense_param(Table *table, const PsRequestMessage &request, int32_t PushDenseParam(Table *table, const PsRequestMessage &request,
PsResponseMessage &response, brpc::Controller *cntl);
int32_t push_sparse_param(Table *table, const PsRequestMessage &request,
PsResponseMessage &response,
brpc::Controller *cntl);
int32_t pull_sparse(Table *table, const PsRequestMessage &request,
PsResponseMessage &response, brpc::Controller *cntl);
int32_t pull_geo_param(Table *table, const PsRequestMessage &request,
PsResponseMessage &response, brpc::Controller *cntl); PsResponseMessage &response, brpc::Controller *cntl);
int32_t barrier(Table *table, const PsRequestMessage &request, int32_t PushSparseParam(Table *table, const PsRequestMessage &request,
PsResponseMessage &response, brpc::Controller *cntl);
int32_t PullSparse(Table *table, const PsRequestMessage &request,
PsResponseMessage &response, brpc::Controller *cntl);
int32_t PullGeoParam(Table *table, const PsRequestMessage &request,
PsResponseMessage &response, brpc::Controller *cntl);
int32_t Barrier(Table *table, const PsRequestMessage &request,
PsResponseMessage &response, brpc::Controller *cntl); PsResponseMessage &response, brpc::Controller *cntl);
int32_t push_sparse(Table *table, const PsRequestMessage &request, int32_t PushSparse(Table *table, const PsRequestMessage &request,
PsResponseMessage &response, brpc::Controller *cntl); PsResponseMessage &response, brpc::Controller *cntl);
int32_t load_one_table(Table *table, const PsRequestMessage &request, int32_t LoadOneTable(Table *table, const PsRequestMessage &request,
PsResponseMessage &response, brpc::Controller *cntl);
int32_t load_all_table(Table *table, const PsRequestMessage &request,
PsResponseMessage &response, brpc::Controller *cntl);
int32_t save_one_table(Table *table, const PsRequestMessage &request,
PsResponseMessage &response, brpc::Controller *cntl);
int32_t save_all_table(Table *table, const PsRequestMessage &request,
PsResponseMessage &response, brpc::Controller *cntl);
int32_t shrink_table(Table *table, const PsRequestMessage &request,
PsResponseMessage &response, brpc::Controller *cntl); PsResponseMessage &response, brpc::Controller *cntl);
int32_t clear_one_table(Table *table, const PsRequestMessage &request, int32_t LoadAllTable(Table *table, const PsRequestMessage &request,
PsResponseMessage &response, brpc::Controller *cntl); PsResponseMessage &response, brpc::Controller *cntl);
int32_t clear_all_table(Table *table, const PsRequestMessage &request, int32_t SaveOneTable(Table *table, const PsRequestMessage &request,
PsResponseMessage &response, brpc::Controller *cntl); PsResponseMessage &response, brpc::Controller *cntl);
int32_t stop_server(Table *table, const PsRequestMessage &request, int32_t SaveAllTable(Table *table, const PsRequestMessage &request,
PsResponseMessage &response, brpc::Controller *cntl);
int32_t ShrinkTable(Table *table, const PsRequestMessage &request,
PsResponseMessage &response, brpc::Controller *cntl); PsResponseMessage &response, brpc::Controller *cntl);
int32_t start_profiler(Table *table, const PsRequestMessage &request, int32_t ClearOneTable(Table *table, const PsRequestMessage &request,
PsResponseMessage &response, brpc::Controller *cntl);
int32_t stop_profiler(Table *table, const PsRequestMessage &request,
PsResponseMessage &response, brpc::Controller *cntl); PsResponseMessage &response, brpc::Controller *cntl);
int32_t ClearAllTable(Table *table, const PsRequestMessage &request,
PsResponseMessage &response, brpc::Controller *cntl);
int32_t StopServer(Table *table, const PsRequestMessage &request,
PsResponseMessage &response, brpc::Controller *cntl);
int32_t StartProfiler(Table *table, const PsRequestMessage &request,
PsResponseMessage &response, brpc::Controller *cntl);
int32_t StopProfiler(Table *table, const PsRequestMessage &request,
PsResponseMessage &response, brpc::Controller *cntl);
int32_t print_table_stat(Table *table, const PsRequestMessage &request, int32_t PrintTableStat(Table *table, const PsRequestMessage &request,
PsResponseMessage &response, brpc::Controller *cntl); PsResponseMessage &response, brpc::Controller *cntl);
int32_t push_global_step(Table *table, const PsRequestMessage &request, int32_t PushGlobalStep(Table *table, const PsRequestMessage &request,
PsResponseMessage &response, brpc::Controller *cntl); PsResponseMessage &response, brpc::Controller *cntl);
bool _is_initialize_shard_info; bool _is_initialize_shard_info;
std::mutex _initialize_shard_mutex; std::mutex _initialize_shard_mutex;
......
...@@ -39,7 +39,7 @@ inline double GetCurrentUS() { ...@@ -39,7 +39,7 @@ inline double GetCurrentUS() {
Communicator::Communicator() {} Communicator::Communicator() {}
void Communicator::init_gflag(const std::string &gflags) { void Communicator::InitGFlag(const std::string &gflags) {
VLOG(3) << "Init With Gflags:" << gflags; VLOG(3) << "Init With Gflags:" << gflags;
std::vector<std::string> flags = paddle::string::split_string(gflags); std::vector<std::string> flags = paddle::string::split_string(gflags);
if (flags.size() < 1) { if (flags.size() < 1) {
...@@ -73,7 +73,7 @@ void Communicator::InitBrpcClient( ...@@ -73,7 +73,7 @@ void Communicator::InitBrpcClient(
} }
std::vector<uint64_t> Communicator::GetClientInfo() { std::vector<uint64_t> Communicator::GetClientInfo() {
std::vector<uint64_t> res = _ps_env.get_client_info(); std::vector<uint64_t> res = _ps_env.GetClientInfo();
for (auto rr : res) { for (auto rr : res) {
VLOG(2) << "Communicator::GetClientInfo " << rr; VLOG(2) << "Communicator::GetClientInfo " << rr;
} }
...@@ -82,7 +82,7 @@ std::vector<uint64_t> Communicator::GetClientInfo() { ...@@ -82,7 +82,7 @@ std::vector<uint64_t> Communicator::GetClientInfo() {
int Communicator::SetClients(std::vector<uint64_t> &host_sign_list) { int Communicator::SetClients(std::vector<uint64_t> &host_sign_list) {
int node = host_sign_list.size(); int node = host_sign_list.size();
return _ps_env.set_ps_clients(host_sign_list.data(), node); return _ps_env.SetPsClients(host_sign_list.data(), node);
} }
void Communicator::RpcRecvDense(const std::vector<std::string> &varnames, void Communicator::RpcRecvDense(const std::vector<std::string> &varnames,
...@@ -114,7 +114,7 @@ void Communicator::RpcRecvDense(const std::vector<std::string> &varnames, ...@@ -114,7 +114,7 @@ void Communicator::RpcRecvDense(const std::vector<std::string> &varnames,
} }
} }
auto status = auto status =
_worker_ptr->pull_dense(regions.data(), regions.size(), table_id); _worker_ptr->PullDense(regions.data(), regions.size(), table_id);
status.wait(); status.wait();
for (auto &t : varnames) { for (auto &t : varnames) {
...@@ -177,7 +177,7 @@ void Communicator::RpcSendDenseParam(const std::vector<std::string> &varnames, ...@@ -177,7 +177,7 @@ void Communicator::RpcSendDenseParam(const std::vector<std::string> &varnames,
} }
} }
auto status = auto status =
_worker_ptr->push_dense_param(regions.data(), regions.size(), table_id); _worker_ptr->PushDenseParam(regions.data(), regions.size(), table_id);
status.wait(); status.wait();
VLOG(4) << "RPC Send Dense Param " << table_id << " done!"; VLOG(4) << "RPC Send Dense Param " << table_id << " done!";
return; return;
...@@ -190,9 +190,9 @@ void Communicator::RpcSendDense(const CommContext &ctx, const Scope &scope) { ...@@ -190,9 +190,9 @@ void Communicator::RpcSendDense(const CommContext &ctx, const Scope &scope) {
auto &var_names = ctx.origin_varnames; auto &var_names = ctx.origin_varnames;
auto &table_id = ctx.table_id; auto &table_id = ctx.table_id;
auto dense_data = std::make_shared<std::vector<float>>(); auto dense_data = std::make_shared<std::vector<float>>();
size_t request_call_num = _worker_ptr->get_server_nums(); size_t request_call_num = _worker_ptr->GetServerNums();
uint32_t num_per_shard = uint32_t num_per_shard =
dense_dim_per_shard(ctx.height_sections[0], request_call_num); DenseDimPerShard(ctx.height_sections[0], request_call_num);
dense_data->resize(num_per_shard * dense_data->resize(num_per_shard *
request_call_num); // accessor->update_dim() = 1 request_call_num); // accessor->update_dim() = 1
float *data = dense_data->data(); float *data = dense_data->data();
...@@ -222,8 +222,8 @@ void Communicator::RpcSendDense(const CommContext &ctx, const Scope &scope) { ...@@ -222,8 +222,8 @@ void Communicator::RpcSendDense(const CommContext &ctx, const Scope &scope) {
closure->set_promise_value(ret); closure->set_promise_value(ret);
--_async_call_num; --_async_call_num;
}); });
auto status = _worker_ptr->push_dense_raw_gradient( auto status = _worker_ptr->PushDenseRawGradient(table_id, data,
table_id, data, dense_data->size(), closure); dense_data->size(), closure);
status.wait(); status.wait();
return; return;
} }
...@@ -233,7 +233,7 @@ void Communicator::RpcSendSparseParam(const std::string &varname, int table_id, ...@@ -233,7 +233,7 @@ void Communicator::RpcSendSparseParam(const std::string &varname, int table_id,
platform::RecordEvent record_event("Communicator->RpcSendSparseParam", platform::RecordEvent record_event("Communicator->RpcSendSparseParam",
platform::TracerEventType::Communication, platform::TracerEventType::Communication,
1); 1);
size_t request_call_num = _worker_ptr->get_server_nums(); size_t request_call_num = _worker_ptr->GetServerNums();
std::vector<float *> push_g_vec; std::vector<float *> push_g_vec;
auto *send_var = scope.FindVar(varname); auto *send_var = scope.FindVar(varname);
...@@ -260,9 +260,9 @@ void Communicator::RpcSendSparseParam(const std::string &varname, int table_id, ...@@ -260,9 +260,9 @@ void Communicator::RpcSendSparseParam(const std::string &varname, int table_id,
} }
closure->set_promise_value(ret); closure->set_promise_value(ret);
}); });
auto status = _worker_ptr->push_sparse_param( auto status = _worker_ptr->PushSparseParam(table_id, sparse_push_keys.data(),
table_id, sparse_push_keys.data(), (const float **)push_g_vec.data(), (const float **)push_g_vec.data(),
sparse_push_keys.size(), closure); sparse_push_keys.size(), closure);
status.wait(); status.wait();
return; return;
} }
...@@ -272,7 +272,7 @@ void Communicator::RpcSendSparse(const std::string &var_name, int table_id, ...@@ -272,7 +272,7 @@ void Communicator::RpcSendSparse(const std::string &var_name, int table_id,
platform::RecordEvent record_event("Communicator->RpcSendSparse", platform::RecordEvent record_event("Communicator->RpcSendSparse",
platform::TracerEventType::Communication, platform::TracerEventType::Communication,
1); 1);
size_t request_call_num = _worker_ptr->get_server_nums(); size_t request_call_num = _worker_ptr->GetServerNums();
std::vector<uint64_t> sparse_push_keys; std::vector<uint64_t> sparse_push_keys;
std::vector<float *> push_g_vec; std::vector<float *> push_g_vec;
...@@ -313,7 +313,7 @@ void Communicator::RpcSendSparse(const std::string &var_name, int table_id, ...@@ -313,7 +313,7 @@ void Communicator::RpcSendSparse(const std::string &var_name, int table_id,
closure->set_promise_value(ret); closure->set_promise_value(ret);
--_async_call_num; --_async_call_num;
}); });
auto status = _worker_ptr->push_sparse_raw_gradient( auto status = _worker_ptr->PushSparseRawGradient(
table_id, sparse_push_keys.data(), (const float **)push_g_vec.data(), table_id, sparse_push_keys.data(), (const float **)push_g_vec.data(),
sparse_push_keys.size(), closure); sparse_push_keys.size(), closure);
status.wait(); status.wait();
...@@ -340,7 +340,7 @@ void Communicator::RpcRecvSparse(const std::string &varname, int table_id, ...@@ -340,7 +340,7 @@ void Communicator::RpcRecvSparse(const std::string &varname, int table_id,
bool training = true; bool training = true;
auto status = _worker_ptr->pull_sparse_param( auto status = _worker_ptr->PullSparseParam(
(float **)push_g_vec.data(), table_id, // NOLINT (float **)push_g_vec.data(), table_id, // NOLINT
sparse_push_keys.data(), sparse_push_keys.size(), training); sparse_push_keys.data(), sparse_push_keys.size(), training);
status.wait(); status.wait();
...@@ -376,11 +376,11 @@ void Communicator::RpcProfilerControl() { ...@@ -376,11 +376,11 @@ void Communicator::RpcProfilerControl() {
if (!do_server_profiler_ && platform::IsProfileEnabled()) { if (!do_server_profiler_ && platform::IsProfileEnabled()) {
// send profiler start flag // send profiler start flag
do_server_profiler_ = true; do_server_profiler_ = true;
auto start_status = _worker_ptr->start_profiler(); auto start_status = _worker_ptr->StartProfiler();
start_status.wait(); start_status.wait();
} else if (do_server_profiler_ && !platform::IsProfileEnabled()) { } else if (do_server_profiler_ && !platform::IsProfileEnabled()) {
// send profiler end flag // send profiler end flag
auto stop_status = _worker_ptr->stop_profiler(); auto stop_status = _worker_ptr->StopProfiler();
stop_status.wait(); stop_status.wait();
do_server_profiler_ = false; do_server_profiler_ = false;
} }
...@@ -396,7 +396,7 @@ void Communicator::SendGlobalStep(const CommContext &ctx, int batches, ...@@ -396,7 +396,7 @@ void Communicator::SendGlobalStep(const CommContext &ctx, int batches,
platform::TracerEventType::Communication, platform::TracerEventType::Communication,
1); 1);
auto &table_id = ctx.table_id; auto &table_id = ctx.table_id;
size_t request_call_num = _worker_ptr->get_server_nums(); size_t request_call_num = _worker_ptr->GetServerNums();
auto &var_name = STEP_COUNTER; auto &var_name = STEP_COUNTER;
auto *out_var = send_scope->Var(var_name); auto *out_var = send_scope->Var(var_name);
...@@ -416,7 +416,7 @@ void Communicator::SendGlobalStep(const CommContext &ctx, int batches, ...@@ -416,7 +416,7 @@ void Communicator::SendGlobalStep(const CommContext &ctx, int batches,
} }
closure->set_promise_value(ret); closure->set_promise_value(ret);
}); });
auto status = _worker_ptr->push_global_step(table_id, data, closure); auto status = _worker_ptr->PushGlobalStep(table_id, data, closure);
status.wait(); status.wait();
return; return;
} }
...@@ -605,8 +605,8 @@ void AsyncCommunicator::PullSparseToTensorSync( ...@@ -605,8 +605,8 @@ void AsyncCommunicator::PullSparseToTensorSync(
} }
} }
auto status = auto status =
_worker_ptr->pull_sparse(pull_result_ptr.data(), table_id, _worker_ptr->PullSparse(pull_result_ptr.data(), table_id, fea_keys.data(),
fea_keys.data(), fea_keys.size(), is_training); fea_keys.size(), is_training);
status.wait(); status.wait();
auto ret = status.get(); auto ret = status.get();
if (ret != 0) { if (ret != 0) {
...@@ -738,9 +738,9 @@ void AsyncCommunicator::PushSparseFromTensorAsync( ...@@ -738,9 +738,9 @@ void AsyncCommunicator::PushSparseFromTensorAsync(
this->Check(table_id), true, this->Check(table_id), true,
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"can not find table: %s, please check your config", table_id)); "can not find table: %s, please check your config", table_id));
auto status = _worker_ptr->push_sparse(table_id, push_keys.data(), auto status = _worker_ptr->PushSparse(table_id, push_keys.data(),
(const float **)push_g_vec.data(), (const float **)push_g_vec.data(),
push_keys.size()); push_keys.size());
} }
void HalfAsyncCommunicator::MainThread() { void HalfAsyncCommunicator::MainThread() {
...@@ -813,7 +813,7 @@ void AsyncCommunicator::Stop() { ...@@ -813,7 +813,7 @@ void AsyncCommunicator::Stop() {
if (!communicator_) { if (!communicator_) {
VLOG(0) << "Communicator is not inited, do nothing"; VLOG(0) << "Communicator is not inited, do nothing";
} else { } else {
// _worker_ptr->finalize_worker(); // _worker_ptr->FinalizeWorker();
VLOG(1) << "client finalize_worker done"; VLOG(1) << "client finalize_worker done";
if (recv_thread_) { if (recv_thread_) {
VLOG(1) << "stop recv thread"; VLOG(1) << "stop recv thread";
...@@ -1327,7 +1327,7 @@ void GeoCommunicator::SendSparse(const std::string &varname, ...@@ -1327,7 +1327,7 @@ void GeoCommunicator::SendSparse(const std::string &varname,
closure->set_promise_value(ret); closure->set_promise_value(ret);
--_async_call_num; --_async_call_num;
}); });
auto status = _worker_ptr->push_sparse_raw_gradient_partial( auto status = _worker_ptr->PushSparseRawGradientPartial(
table_id, (const uint64_t *)sparse_ids.data(), table_id, (const uint64_t *)sparse_ids.data(),
(const float **)push_g_vec.data(), sparse_ids.size(), closure, ep_idx); (const float **)push_g_vec.data(), sparse_ids.size(), closure, ep_idx);
status.wait(); status.wait();
...@@ -1345,7 +1345,7 @@ void GeoCommunicator::RecvSparse(const std::string &varname, int table_id, ...@@ -1345,7 +1345,7 @@ void GeoCommunicator::RecvSparse(const std::string &varname, int table_id,
// 1. recv from pserver // 1. recv from pserver
std::vector<uint64_t> keys; std::vector<uint64_t> keys;
std::vector<float> values; std::vector<float> values;
auto status = _worker_ptr->pull_geo_param(table_id, &values, &keys, ep_idx); auto status = _worker_ptr->PullGeoParam(table_id, &values, &keys, ep_idx);
status.wait(); status.wait();
std::string param = SplitedGradToParam(varname); std::string param = SplitedGradToParam(varname);
......
...@@ -299,7 +299,7 @@ class Communicator { ...@@ -299,7 +299,7 @@ class Communicator {
virtual void Barrier() {} virtual void Barrier() {}
virtual void BarrierWithTable(uint32_t barrier_type) { virtual void BarrierWithTable(uint32_t barrier_type) {
auto rets = _worker_ptr->barrier(barrier_table_id_, barrier_type); auto rets = _worker_ptr->Barrier(barrier_table_id_, barrier_type);
rets.wait(); rets.wait();
int status = rets.get(); int status = rets.get();
PADDLE_ENFORCE_EQ(status, 0, PADDLE_ENFORCE_EQ(status, 0,
...@@ -310,7 +310,7 @@ class Communicator { ...@@ -310,7 +310,7 @@ class Communicator {
virtual void CreateC2CConnection(int pserver_timeout_ms, virtual void CreateC2CConnection(int pserver_timeout_ms,
int pserver_connect_timeout_ms, int pserver_connect_timeout_ms,
int max_retry) { int max_retry) {
_worker_ptr->create_client2client_connection( _worker_ptr->CreateClient2ClientConnection(
pserver_timeout_ms, pserver_connect_timeout_ms, max_retry); pserver_timeout_ms, pserver_connect_timeout_ms, max_retry);
} }
...@@ -379,12 +379,12 @@ class Communicator { ...@@ -379,12 +379,12 @@ class Communicator {
std::unordered_map<std::string, std::string> envs; std::unordered_map<std::string, std::string> envs;
// 计算每个shard 对 dense的存储量 // 计算每个shard 对 dense的存储量
inline uint32_t dense_dim_per_shard(uint32_t dense_dim_total, inline uint32_t DenseDimPerShard(uint32_t dense_dim_total,
uint32_t shard_num) { uint32_t shard_num) {
return dense_dim_total / shard_num + 1; return dense_dim_total / shard_num + 1;
} }
void init_gflag(const std::string &gflags); void InitGFlag(const std::string &gflags);
paddle::distributed::PSParameter _ps_param; paddle::distributed::PSParameter _ps_param;
paddle::distributed::PaddlePSEnvironment _ps_env; paddle::distributed::PaddlePSEnvironment _ps_env;
int servers_ = 0; int servers_ = 0;
......
...@@ -40,7 +40,7 @@ struct PSHost { ...@@ -40,7 +40,7 @@ struct PSHost {
// |---ip---|---port---|--rank--| // |---ip---|---port---|--rank--|
// |-32bit--|--20bit---|--12bit-| // |-32bit--|--20bit---|--12bit-|
uint64_t serialize_to_uint64() { uint64_t SerializeToUint64() {
uint64_t host_label = 0; uint64_t host_label = 0;
host_label = inet_addr(ip.c_str()); host_label = inet_addr(ip.c_str());
host_label = host_label << 32; host_label = host_label << 32;
...@@ -49,7 +49,7 @@ struct PSHost { ...@@ -49,7 +49,7 @@ struct PSHost {
return host_label; return host_label;
} }
void parse_from_uint64(uint64_t host_label) { void ParseFromUint64(uint64_t host_label) {
static uint64_t rank_label_mask = (1L << 12) - 1; static uint64_t rank_label_mask = (1L << 12) - 1;
static uint64_t port_label_mask = (1L << 20) - 1; static uint64_t port_label_mask = (1L << 20) - 1;
rank = host_label & rank_label_mask; rank = host_label & rank_label_mask;
...@@ -58,17 +58,17 @@ struct PSHost { ...@@ -58,17 +58,17 @@ struct PSHost {
ip = inet_ntoa(*(in_addr *)&ip_addr); // NOLINT ip = inet_ntoa(*(in_addr *)&ip_addr); // NOLINT
} }
std::string to_string() { std::string ToString() {
std::stringstream s; std::stringstream s;
s << "host: " << ip; s << "host: " << ip;
s << " port: " << port; s << " port: " << port;
s << " rank: " << rank; s << " rank: " << rank;
s << " uint: " << serialize_to_uint64(); s << " uint: " << SerializeToUint64();
return s.str(); return s.str();
} }
// for open source parameter server // for open source parameter server
std::string serialize_to_string() { std::string SerializeToString() {
std::stringstream s; std::stringstream s;
s << ip << ":"; s << ip << ":";
s << port << ":"; s << port << ":";
...@@ -76,16 +76,16 @@ struct PSHost { ...@@ -76,16 +76,16 @@ struct PSHost {
return s.str(); return s.str();
} }
void parse_from_string(std::string endpoint) { void ParseFromString(std::string endpoint) {
std::vector<std::string> endpoint_info; std::vector<std::string> endpoint_info;
string_split(endpoint, ':', &endpoint_info); StringSplit(endpoint, ':', &endpoint_info);
ip = endpoint_info[0]; ip = endpoint_info[0];
port = std::stoi(endpoint_info[1]); port = std::stoi(endpoint_info[1]);
rank = std::stoi(endpoint_info[2]); rank = std::stoi(endpoint_info[2]);
} }
void string_split(const std::string &str, char sep, void StringSplit(const std::string &str, char sep,
std::vector<std::string> *pieces, bool ignore_null = true) { std::vector<std::string> *pieces, bool ignore_null = true) {
pieces->clear(); pieces->clear();
if (str.empty()) { if (str.empty()) {
if (!ignore_null) { if (!ignore_null) {
...@@ -111,63 +111,60 @@ class PSEnvironment { ...@@ -111,63 +111,60 @@ class PSEnvironment {
explicit PSEnvironment() {} // NOLINT explicit PSEnvironment() {} // NOLINT
virtual ~PSEnvironment() {} virtual ~PSEnvironment() {}
virtual int32_t set_ps_servers(uint64_t *host_sign_list, int node_num) { virtual int32_t SetPsServers(uint64_t *host_sign_list, int node_num) {
return 0; return 0;
} }
virtual int32_t set_ps_servers( virtual int32_t SetPsServers(
const std::vector<std::string> *host_endpoint_list, int node_num) { const std::vector<std::string> *host_endpoint_list, int node_num) {
return 0; return 0;
} }
virtual int32_t set_ps_clients(uint64_t *host_sign_list, int node_num) { virtual int32_t SetPsClients(uint64_t *host_sign_list, int node_num) {
return 0; return 0;
} }
virtual int32_t set_ps_clients(std::string *host_endpoint_list, virtual int32_t SetPsClients(std::string *host_endpoint_list, int node_num) {
int node_num) {
return 0; return 0;
} }
virtual uint64_t get_local_host_sign() { return 0; } virtual uint64_t GetLocalHostSign() { return 0; }
virtual std::vector<PSHost> get_ps_servers() const { return _ps_server_list; } virtual std::vector<PSHost> GetPsServers() const { return _ps_server_list; }
virtual int32_t registe_ps_server(const std::string &ip, uint32_t port, virtual int32_t RegistePsServer(const std::string &ip, uint32_t port,
int32_t rank) { int32_t rank) {
return registe_ps_host(ip, port, rank, _ps_server_list, return RegistePsHost(ip, port, rank, _ps_server_list, _ps_server_sign_set);
_ps_server_sign_set);
} }
virtual std::vector<PSHost> get_ps_clients() const { return _ps_client_list; } virtual std::vector<PSHost> GetPsClients() const { return _ps_client_list; }
virtual int32_t registe_ps_client(const std::string &ip, uint32_t port, virtual int32_t RegistePsClient(const std::string &ip, uint32_t port,
int32_t rank) { int32_t rank) {
return registe_ps_host(ip, port, rank, _ps_client_list, return RegistePsHost(ip, port, rank, _ps_client_list, _ps_client_sign_set);
_ps_client_sign_set);
} }
virtual std::vector<uint64_t> get_client_info() { virtual std::vector<uint64_t> GetClientInfo() {
std::vector<uint64_t> client_info; std::vector<uint64_t> client_info;
for (auto &i : _ps_client_list) { for (auto &i : _ps_client_list) {
client_info.push_back(i.serialize_to_uint64()); client_info.push_back(i.SerializeToUint64());
} }
return client_info; return client_info;
} }
virtual std::vector<std::string> get_client_info(bool use_string_endpoint) { virtual std::vector<std::string> GetClientInfo(bool use_string_endpoint) {
if (use_string_endpoint) { if (use_string_endpoint) {
std::vector<std::string> client_info; std::vector<std::string> client_info;
for (auto &i : _ps_client_list) { for (auto &i : _ps_client_list) {
client_info.push_back(i.serialize_to_string()); client_info.push_back(i.SerializeToString());
} }
return client_info; return client_info;
} }
return {}; return {};
} }
virtual void set_trainers(int trainers) { trainers_ = trainers; } virtual void SetTrainers(int trainers) { trainers_ = trainers; }
virtual int get_trainers() { return trainers_; } virtual int GetTrainers() { return trainers_; }
protected: protected:
//注册一个host // NOLINT //注册一个host // NOLINT
virtual int32_t registe_ps_host( virtual int32_t RegistePsHost(
const std::string &ip, uint32_t port, int32_t rank, const std::string &ip, uint32_t port, int32_t rank,
std::vector<PSHost> &host_list, // NOLINT std::vector<PSHost> &host_list, // NOLINT
std::unordered_set<uint64_t> &sign_set) { // NOLINT std::unordered_set<uint64_t> &sign_set) { // NOLINT
...@@ -198,15 +195,15 @@ class PaddlePSEnvironment : public PSEnvironment { ...@@ -198,15 +195,15 @@ class PaddlePSEnvironment : public PSEnvironment {
explicit PaddlePSEnvironment() {} // NOLINT explicit PaddlePSEnvironment() {} // NOLINT
virtual ~PaddlePSEnvironment() {} virtual ~PaddlePSEnvironment() {}
virtual int32_t set_ps_servers(uint64_t *host_sign_list, int node_num) { virtual int32_t SetPsServers(uint64_t *host_sign_list, int node_num) {
_ps_server_list.clear(); _ps_server_list.clear();
_ps_server_sign_set.clear(); _ps_server_sign_set.clear();
for (int i = 0; i < node_num; ++i) { for (int i = 0; i < node_num; ++i) {
if (host_sign_list[i] > 0) { if (host_sign_list[i] > 0) {
PSHost host; PSHost host;
host.parse_from_uint64(host_sign_list[i]); host.ParseFromUint64(host_sign_list[i]);
_ps_server_list.push_back(host); _ps_server_list.push_back(host);
_ps_server_sign_set.insert(host.serialize_to_uint64()); _ps_server_sign_set.insert(host.SerializeToUint64());
} }
} }
std::sort( std::sort(
...@@ -215,14 +212,14 @@ class PaddlePSEnvironment : public PSEnvironment { ...@@ -215,14 +212,14 @@ class PaddlePSEnvironment : public PSEnvironment {
return 0; return 0;
} }
virtual int32_t set_ps_servers(const std::vector<std::string> *host_sign_list, virtual int32_t SetPsServers(const std::vector<std::string> *host_sign_list,
int node_num) { int node_num) {
_ps_server_list.clear(); _ps_server_list.clear();
_ps_server_sign_set.clear(); _ps_server_sign_set.clear();
for (int i = 0; i < node_num; ++i) { for (int i = 0; i < node_num; ++i) {
if (host_sign_list->at(i) != "") { if (host_sign_list->at(i) != "") {
PSHost host; PSHost host;
host.parse_from_string(host_sign_list->at(i)); host.ParseFromString(host_sign_list->at(i));
_ps_server_list.push_back(host); _ps_server_list.push_back(host);
_ps_server_sign_set.insert(host.rank); _ps_server_sign_set.insert(host.rank);
} }
...@@ -233,15 +230,15 @@ class PaddlePSEnvironment : public PSEnvironment { ...@@ -233,15 +230,15 @@ class PaddlePSEnvironment : public PSEnvironment {
return 0; return 0;
} }
virtual int32_t set_ps_clients(uint64_t *host_sign_list, int node_num) { virtual int32_t SetPsClients(uint64_t *host_sign_list, int node_num) {
_ps_client_list.clear(); _ps_client_list.clear();
_ps_client_sign_set.clear(); _ps_client_sign_set.clear();
for (int i = 0; i < node_num; ++i) { for (int i = 0; i < node_num; ++i) {
if (host_sign_list[i] > 0) { if (host_sign_list[i] > 0) {
PSHost host; PSHost host;
host.parse_from_uint64(host_sign_list[i]); host.ParseFromUint64(host_sign_list[i]);
_ps_client_list.push_back(host); _ps_client_list.push_back(host);
_ps_client_sign_set.insert(host.serialize_to_uint64()); _ps_client_sign_set.insert(host.SerializeToUint64());
} }
} }
std::sort( std::sort(
...@@ -250,14 +247,14 @@ class PaddlePSEnvironment : public PSEnvironment { ...@@ -250,14 +247,14 @@ class PaddlePSEnvironment : public PSEnvironment {
return 0; return 0;
} }
virtual int32_t set_ps_clients(const std::vector<std::string> *host_sign_list, virtual int32_t SetPsClients(const std::vector<std::string> *host_sign_list,
int node_num) { int node_num) {
_ps_client_list.clear(); _ps_client_list.clear();
_ps_client_sign_set.clear(); _ps_client_sign_set.clear();
for (int i = 0; i < node_num; ++i) { for (int i = 0; i < node_num; ++i) {
if (host_sign_list->at(i) != "") { if (host_sign_list->at(i) != "") {
PSHost host; PSHost host;
host.parse_from_string(host_sign_list->at(i)); host.ParseFromString(host_sign_list->at(i));
_ps_client_list.push_back(host); _ps_client_list.push_back(host);
_ps_client_sign_set.insert(host.rank); _ps_client_sign_set.insert(host.rank);
} }
...@@ -269,9 +266,9 @@ class PaddlePSEnvironment : public PSEnvironment { ...@@ -269,9 +266,9 @@ class PaddlePSEnvironment : public PSEnvironment {
return 0; return 0;
} }
virtual uint64_t get_local_host_sign() { virtual uint64_t GetLocalHostSign() {
if (_ps_client_list.size() > 0) { if (_ps_client_list.size() > 0) {
return _ps_client_list[0].serialize_to_uint64(); return _ps_client_list[0].SerializeToUint64();
} else { } else {
return 0; return 0;
} }
......
...@@ -135,8 +135,7 @@ std::future<int32_t> GraphBrpcClient::get_node_feat( ...@@ -135,8 +135,7 @@ std::future<int32_t> GraphBrpcClient::get_node_feat(
closure->request(request_idx) closure->request(request_idx)
->add_params(joint_feature_name.c_str(), joint_feature_name.size()); ->add_params(joint_feature_name.c_str(), joint_feature_name.size());
GraphPsService_Stub rpc_stub = GraphPsService_Stub rpc_stub = getServiceStub(GetCmdChannel(server_index));
getServiceStub(get_cmd_channel(server_index));
closure->cntl(request_idx)->set_log_id(butil::gettimeofday_ms()); closure->cntl(request_idx)->set_log_id(butil::gettimeofday_ms());
rpc_stub.service(closure->cntl(request_idx), closure->request(request_idx), rpc_stub.service(closure->cntl(request_idx), closure->request(request_idx),
closure->response(request_idx), closure); closure->response(request_idx), closure);
...@@ -169,8 +168,7 @@ std::future<int32_t> GraphBrpcClient::clear_nodes(uint32_t table_id) { ...@@ -169,8 +168,7 @@ std::future<int32_t> GraphBrpcClient::clear_nodes(uint32_t table_id) {
closure->request(server_index)->set_table_id(table_id); closure->request(server_index)->set_table_id(table_id);
closure->request(server_index)->set_client_id(_client_id); closure->request(server_index)->set_client_id(_client_id);
GraphPsService_Stub rpc_stub = GraphPsService_Stub rpc_stub = getServiceStub(GetCmdChannel(server_index));
getServiceStub(get_cmd_channel(server_index));
closure->cntl(server_index)->set_log_id(butil::gettimeofday_ms()); closure->cntl(server_index)->set_log_id(butil::gettimeofday_ms());
rpc_stub.service(closure->cntl(server_index), rpc_stub.service(closure->cntl(server_index),
closure->request(server_index), closure->request(server_index),
...@@ -238,9 +236,8 @@ std::future<int32_t> GraphBrpcClient::add_graph_node( ...@@ -238,9 +236,8 @@ std::future<int32_t> GraphBrpcClient::add_graph_node(
->add_params((char *)weighted, ->add_params((char *)weighted,
sizeof(bool) * is_weighted_bucket[request_idx].size()); sizeof(bool) * is_weighted_bucket[request_idx].size());
} }
// PsService_Stub rpc_stub(get_cmd_channel(server_index)); // PsService_Stub rpc_stub(GetCmdChannel(server_index));
GraphPsService_Stub rpc_stub = GraphPsService_Stub rpc_stub = getServiceStub(GetCmdChannel(server_index));
getServiceStub(get_cmd_channel(server_index));
closure->cntl(request_idx)->set_log_id(butil::gettimeofday_ms()); closure->cntl(request_idx)->set_log_id(butil::gettimeofday_ms());
rpc_stub.service(closure->cntl(request_idx), closure->request(request_idx), rpc_stub.service(closure->cntl(request_idx), closure->request(request_idx),
closure->response(request_idx), closure); closure->response(request_idx), closure);
...@@ -292,9 +289,8 @@ std::future<int32_t> GraphBrpcClient::remove_graph_node( ...@@ -292,9 +289,8 @@ std::future<int32_t> GraphBrpcClient::remove_graph_node(
closure->request(request_idx) closure->request(request_idx)
->add_params((char *)request_bucket[request_idx].data(), ->add_params((char *)request_bucket[request_idx].data(),
sizeof(int64_t) * node_num); sizeof(int64_t) * node_num);
// PsService_Stub rpc_stub(get_cmd_channel(server_index)); // PsService_Stub rpc_stub(GetCmdChannel(server_index));
GraphPsService_Stub rpc_stub = GraphPsService_Stub rpc_stub = getServiceStub(GetCmdChannel(server_index));
getServiceStub(get_cmd_channel(server_index));
closure->cntl(request_idx)->set_log_id(butil::gettimeofday_ms()); closure->cntl(request_idx)->set_log_id(butil::gettimeofday_ms());
rpc_stub.service(closure->cntl(request_idx), closure->request(request_idx), rpc_stub.service(closure->cntl(request_idx), closure->request(request_idx),
closure->response(request_idx), closure); closure->response(request_idx), closure);
...@@ -362,9 +358,8 @@ std::future<int32_t> GraphBrpcClient::batch_sample_neighbors( ...@@ -362,9 +358,8 @@ std::future<int32_t> GraphBrpcClient::batch_sample_neighbors(
closure->request(0)->add_params((char *)&sample_size, sizeof(int)); closure->request(0)->add_params((char *)&sample_size, sizeof(int));
closure->request(0)->add_params((char *)&need_weight, sizeof(bool)); closure->request(0)->add_params((char *)&need_weight, sizeof(bool));
; ;
// PsService_Stub rpc_stub(get_cmd_channel(server_index)); // PsService_Stub rpc_stub(GetCmdChannel(server_index));
GraphPsService_Stub rpc_stub = GraphPsService_Stub rpc_stub = getServiceStub(GetCmdChannel(server_index));
getServiceStub(get_cmd_channel(server_index));
closure->cntl(0)->set_log_id(butil::gettimeofday_ms()); closure->cntl(0)->set_log_id(butil::gettimeofday_ms());
rpc_stub.service(closure->cntl(0), closure->request(0), rpc_stub.service(closure->cntl(0), closure->request(0),
closure->response(0), closure); closure->response(0), closure);
...@@ -464,9 +459,8 @@ std::future<int32_t> GraphBrpcClient::batch_sample_neighbors( ...@@ -464,9 +459,8 @@ std::future<int32_t> GraphBrpcClient::batch_sample_neighbors(
->add_params((char *)&sample_size, sizeof(int)); ->add_params((char *)&sample_size, sizeof(int));
closure->request(request_idx) closure->request(request_idx)
->add_params((char *)&need_weight, sizeof(bool)); ->add_params((char *)&need_weight, sizeof(bool));
// PsService_Stub rpc_stub(get_cmd_channel(server_index)); // PsService_Stub rpc_stub(GetCmdChannel(server_index));
GraphPsService_Stub rpc_stub = GraphPsService_Stub rpc_stub = getServiceStub(GetCmdChannel(server_index));
getServiceStub(get_cmd_channel(server_index));
closure->cntl(request_idx)->set_log_id(butil::gettimeofday_ms()); closure->cntl(request_idx)->set_log_id(butil::gettimeofday_ms());
rpc_stub.service(closure->cntl(request_idx), closure->request(request_idx), rpc_stub.service(closure->cntl(request_idx), closure->request(request_idx),
closure->response(request_idx), closure); closure->response(request_idx), closure);
...@@ -506,8 +500,8 @@ std::future<int32_t> GraphBrpcClient::random_sample_nodes( ...@@ -506,8 +500,8 @@ std::future<int32_t> GraphBrpcClient::random_sample_nodes(
closure->request(0)->set_client_id(_client_id); closure->request(0)->set_client_id(_client_id);
closure->request(0)->add_params((char *)&sample_size, sizeof(int)); closure->request(0)->add_params((char *)&sample_size, sizeof(int));
; ;
// PsService_Stub rpc_stub(get_cmd_channel(server_index)); // PsService_Stub rpc_stub(GetCmdChannel(server_index));
GraphPsService_Stub rpc_stub = getServiceStub(get_cmd_channel(server_index)); GraphPsService_Stub rpc_stub = getServiceStub(GetCmdChannel(server_index));
closure->cntl(0)->set_log_id(butil::gettimeofday_ms()); closure->cntl(0)->set_log_id(butil::gettimeofday_ms());
rpc_stub.service(closure->cntl(0), closure->request(0), closure->response(0), rpc_stub.service(closure->cntl(0), closure->request(0), closure->response(0),
closure); closure);
...@@ -541,8 +535,7 @@ std::future<int32_t> GraphBrpcClient::load_graph_split_config( ...@@ -541,8 +535,7 @@ std::future<int32_t> GraphBrpcClient::load_graph_split_config(
closure->request(server_index)->set_table_id(table_id); closure->request(server_index)->set_table_id(table_id);
closure->request(server_index)->set_client_id(_client_id); closure->request(server_index)->set_client_id(_client_id);
closure->request(server_index)->add_params(path); closure->request(server_index)->add_params(path);
GraphPsService_Stub rpc_stub = GraphPsService_Stub rpc_stub = getServiceStub(GetCmdChannel(server_index));
getServiceStub(get_cmd_channel(server_index));
closure->cntl(server_index)->set_log_id(butil::gettimeofday_ms()); closure->cntl(server_index)->set_log_id(butil::gettimeofday_ms());
rpc_stub.service(closure->cntl(server_index), rpc_stub.service(closure->cntl(server_index),
closure->request(server_index), closure->request(server_index),
...@@ -581,8 +574,7 @@ std::future<int32_t> GraphBrpcClient::use_neighbors_sample_cache( ...@@ -581,8 +574,7 @@ std::future<int32_t> GraphBrpcClient::use_neighbors_sample_cache(
closure->request(server_index) closure->request(server_index)
->add_params((char *)&size_limit, sizeof(size_t)); ->add_params((char *)&size_limit, sizeof(size_t));
closure->request(server_index)->add_params((char *)&ttl, sizeof(size_t)); closure->request(server_index)->add_params((char *)&ttl, sizeof(size_t));
GraphPsService_Stub rpc_stub = GraphPsService_Stub rpc_stub = getServiceStub(GetCmdChannel(server_index));
getServiceStub(get_cmd_channel(server_index));
closure->cntl(server_index)->set_log_id(butil::gettimeofday_ms()); closure->cntl(server_index)->set_log_id(butil::gettimeofday_ms());
rpc_stub.service(closure->cntl(server_index), rpc_stub.service(closure->cntl(server_index),
closure->request(server_index), closure->request(server_index),
...@@ -624,8 +616,8 @@ std::future<int32_t> GraphBrpcClient::pull_graph_list( ...@@ -624,8 +616,8 @@ std::future<int32_t> GraphBrpcClient::pull_graph_list(
closure->request(0)->add_params((char *)&start, sizeof(int)); closure->request(0)->add_params((char *)&start, sizeof(int));
closure->request(0)->add_params((char *)&size, sizeof(int)); closure->request(0)->add_params((char *)&size, sizeof(int));
closure->request(0)->add_params((char *)&step, sizeof(int)); closure->request(0)->add_params((char *)&step, sizeof(int));
// PsService_Stub rpc_stub(get_cmd_channel(server_index)); // PsService_Stub rpc_stub(GetCmdChannel(server_index));
GraphPsService_Stub rpc_stub = getServiceStub(get_cmd_channel(server_index)); GraphPsService_Stub rpc_stub = getServiceStub(GetCmdChannel(server_index));
closure->cntl(0)->set_log_id(butil::gettimeofday_ms()); closure->cntl(0)->set_log_id(butil::gettimeofday_ms());
rpc_stub.service(closure->cntl(0), closure->request(0), closure->response(0), rpc_stub.service(closure->cntl(0), closure->request(0), closure->response(0),
closure); closure);
...@@ -717,8 +709,7 @@ std::future<int32_t> GraphBrpcClient::set_node_feat( ...@@ -717,8 +709,7 @@ std::future<int32_t> GraphBrpcClient::set_node_feat(
closure->request(request_idx) closure->request(request_idx)
->add_params(set_feature.c_str(), set_feature.size()); ->add_params(set_feature.c_str(), set_feature.size());
GraphPsService_Stub rpc_stub = GraphPsService_Stub rpc_stub = getServiceStub(GetCmdChannel(server_index));
getServiceStub(get_cmd_channel(server_index));
closure->cntl(request_idx)->set_log_id(butil::gettimeofday_ms()); closure->cntl(request_idx)->set_log_id(butil::gettimeofday_ms());
rpc_stub.service(closure->cntl(request_idx), closure->request(request_idx), rpc_stub.service(closure->cntl(request_idx), closure->request(request_idx),
closure->response(request_idx), closure); closure->response(request_idx), closure);
...@@ -727,10 +718,10 @@ std::future<int32_t> GraphBrpcClient::set_node_feat( ...@@ -727,10 +718,10 @@ std::future<int32_t> GraphBrpcClient::set_node_feat(
return fut; return fut;
} }
int32_t GraphBrpcClient::initialize() { int32_t GraphBrpcClient::Initialize() {
// set_shard_num(_config.shard_num()); // set_shard_num(_config.shard_num());
BrpcPsClient::initialize(); BrpcPsClient::Initialize();
server_size = get_server_nums(); server_size = GetServerNums();
graph_service = NULL; graph_service = NULL;
local_channel = NULL; local_channel = NULL;
return 0; return 0;
......
...@@ -97,12 +97,12 @@ class GraphBrpcClient : public BrpcPsClient { ...@@ -97,12 +97,12 @@ class GraphBrpcClient : public BrpcPsClient {
std::string path); std::string path);
virtual std::future<int32_t> remove_graph_node( virtual std::future<int32_t> remove_graph_node(
uint32_t table_id, std::vector<int64_t>& node_id_list); uint32_t table_id, std::vector<int64_t>& node_id_list);
virtual int32_t initialize(); virtual int32_t Initialize();
int get_shard_num() { return shard_num; } int get_shard_num() { return shard_num; }
void set_shard_num(int shard_num) { this->shard_num = shard_num; } void set_shard_num(int shard_num) { this->shard_num = shard_num; }
int get_server_index_by_id(int64_t id); int get_server_index_by_id(int64_t id);
// Caches the command channel for server `index` in `local_channel`.
// The value stored is whatever the channel getter returns for that index
// (left column: old name get_cmd_channel; right column: renamed GetCmdChannel).
void set_local_channel(int index) { void set_local_channel(int index) {
this->local_channel = get_cmd_channel(index); this->local_channel = GetCmdChannel(index);
} }
void set_local_graph_service(GraphBrpcService* graph_service) { void set_local_graph_service(GraphBrpcService* graph_service) {
this->graph_service = graph_service; this->graph_service = graph_service;
......
...@@ -33,7 +33,7 @@ namespace distributed { ...@@ -33,7 +33,7 @@ namespace distributed {
return -1; \ return -1; \
} }
int32_t GraphBrpcServer::initialize() { int32_t GraphBrpcServer::Initialize() {
auto &service_config = _config.downpour_server_param().service_param(); auto &service_config = _config.downpour_server_param().service_param();
if (!service_config.has_service_class()) { if (!service_config.has_service_class()) {
LOG(ERROR) << "miss service_class in ServerServiceParameter"; LOG(ERROR) << "miss service_class in ServerServiceParameter";
...@@ -48,7 +48,7 @@ int32_t GraphBrpcServer::initialize() { ...@@ -48,7 +48,7 @@ int32_t GraphBrpcServer::initialize() {
} }
_service.reset(service); _service.reset(service);
if (service->configure(this) != 0 || service->initialize() != 0) { if (service->Configure(this) != 0 || service->Initialize() != 0) {
LOG(ERROR) << "service initialize failed, service_name:" LOG(ERROR) << "service initialize failed, service_name:"
<< service_config.service_class(); << service_config.service_class();
return -1; return -1;
...@@ -61,11 +61,11 @@ int32_t GraphBrpcServer::initialize() { ...@@ -61,11 +61,11 @@ int32_t GraphBrpcServer::initialize() {
return 0; return 0;
} }
// Returns the brpc channel pointer held at _pserver_channels[server_index].
// The index is used as-is with operator[]; no bounds check is performed here.
// NOTE(review): the element presumably keeps ownership (the pointer comes from
// its .get()) - confirm against the declaration of _pserver_channels.
brpc::Channel *GraphBrpcServer::get_cmd_channel(size_t server_index) { brpc::Channel *GraphBrpcServer::GetCmdChannel(size_t server_index) {
return _pserver_channels[server_index].get(); return _pserver_channels[server_index].get();
} }
uint64_t GraphBrpcServer::start(const std::string &ip, uint32_t port) { uint64_t GraphBrpcServer::Start(const std::string &ip, uint32_t port) {
std::unique_lock<std::mutex> lock(mutex_); std::unique_lock<std::mutex> lock(mutex_);
std::string ip_port = ip + ":" + std::to_string(port); std::string ip_port = ip + ":" + std::to_string(port);
...@@ -73,20 +73,20 @@ uint64_t GraphBrpcServer::start(const std::string &ip, uint32_t port) { ...@@ -73,20 +73,20 @@ uint64_t GraphBrpcServer::start(const std::string &ip, uint32_t port) {
brpc::ServerOptions options; brpc::ServerOptions options;
int num_threads = std::thread::hardware_concurrency(); int num_threads = std::thread::hardware_concurrency();
auto trainers = _environment->get_trainers(); auto trainers = _environment->GetTrainers();
options.num_threads = trainers > num_threads ? trainers : num_threads; options.num_threads = trainers > num_threads ? trainers : num_threads;
if (_server.Start(ip_port.c_str(), &options) != 0) { if (_server.Start(ip_port.c_str(), &options) != 0) {
LOG(ERROR) << "GraphBrpcServer start failed, ip_port=" << ip_port; LOG(ERROR) << "GraphBrpcServer start failed, ip_port=" << ip_port;
return 0; return 0;
} }
_environment->registe_ps_server(ip, port, _rank); _environment->RegistePsServer(ip, port, _rank);
return 0; return 0;
} }
int32_t GraphBrpcServer::build_peer2peer_connection(int rank) { int32_t GraphBrpcServer::build_peer2peer_connection(int rank) {
this->rank = rank; this->rank = rank;
auto _env = environment(); auto _env = Environment();
brpc::ChannelOptions options; brpc::ChannelOptions options;
options.protocol = "baidu_std"; options.protocol = "baidu_std";
options.timeout_ms = 500000; options.timeout_ms = 500000;
...@@ -94,7 +94,7 @@ int32_t GraphBrpcServer::build_peer2peer_connection(int rank) { ...@@ -94,7 +94,7 @@ int32_t GraphBrpcServer::build_peer2peer_connection(int rank) {
options.connect_timeout_ms = 10000; options.connect_timeout_ms = 10000;
options.max_retry = 3; options.max_retry = 3;
std::vector<PSHost> server_list = _env->get_ps_servers(); std::vector<PSHost> server_list = _env->GetPsServers();
_pserver_channels.resize(server_list.size()); _pserver_channels.resize(server_list.size());
std::ostringstream os; std::ostringstream os;
std::string server_ip_port; std::string server_ip_port;
...@@ -172,19 +172,18 @@ int32_t GraphBrpcService::remove_graph_node(Table *table, ...@@ -172,19 +172,18 @@ int32_t GraphBrpcService::remove_graph_node(Table *table,
((GraphTable *)table)->remove_graph_node(node_ids); ((GraphTable *)table)->remove_graph_node(node_ids);
return 0; return 0;
} }
// Returns the port field of the address reported by _server.listen_address().
int32_t GraphBrpcServer::port() { return _server.listen_address().port; } int32_t GraphBrpcServer::Port() { return _server.listen_address().port; }
int32_t GraphBrpcService::initialize() { int32_t GraphBrpcService::Initialize() {
_is_initialize_shard_info = false; _is_initialize_shard_info = false;
_service_handler_map[PS_STOP_SERVER] = &GraphBrpcService::stop_server; _service_handler_map[PS_STOP_SERVER] = &GraphBrpcService::StopServer;
_service_handler_map[PS_LOAD_ONE_TABLE] = &GraphBrpcService::load_one_table; _service_handler_map[PS_LOAD_ONE_TABLE] = &GraphBrpcService::LoadOneTable;
_service_handler_map[PS_LOAD_ALL_TABLE] = &GraphBrpcService::load_all_table; _service_handler_map[PS_LOAD_ALL_TABLE] = &GraphBrpcService::LoadAllTable;
_service_handler_map[PS_PRINT_TABLE_STAT] = _service_handler_map[PS_PRINT_TABLE_STAT] = &GraphBrpcService::PrintTableStat;
&GraphBrpcService::print_table_stat; _service_handler_map[PS_BARRIER] = &GraphBrpcService::Barrier;
_service_handler_map[PS_BARRIER] = &GraphBrpcService::barrier; _service_handler_map[PS_START_PROFILER] = &GraphBrpcService::StartProfiler;
_service_handler_map[PS_START_PROFILER] = &GraphBrpcService::start_profiler; _service_handler_map[PS_STOP_PROFILER] = &GraphBrpcService::StopProfiler;
_service_handler_map[PS_STOP_PROFILER] = &GraphBrpcService::stop_profiler;
_service_handler_map[PS_PULL_GRAPH_LIST] = &GraphBrpcService::pull_graph_list; _service_handler_map[PS_PULL_GRAPH_LIST] = &GraphBrpcService::pull_graph_list;
_service_handler_map[PS_GRAPH_SAMPLE_NEIGHBORS] = _service_handler_map[PS_GRAPH_SAMPLE_NEIGHBORS] =
...@@ -207,21 +206,21 @@ int32_t GraphBrpcService::initialize() { ...@@ -207,21 +206,21 @@ int32_t GraphBrpcService::initialize() {
_service_handler_map[PS_GRAPH_LOAD_GRAPH_SPLIT_CONFIG] = _service_handler_map[PS_GRAPH_LOAD_GRAPH_SPLIT_CONFIG] =
&GraphBrpcService::load_graph_split_config; &GraphBrpcService::load_graph_split_config;
// shard初始化,server启动后才可从env获取到server_list的shard信息 // shard初始化,server启动后才可从env获取到server_list的shard信息
initialize_shard_info(); InitializeShardInfo();
return 0; return 0;
} }
int32_t GraphBrpcService::initialize_shard_info() { int32_t GraphBrpcService::InitializeShardInfo() {
if (!_is_initialize_shard_info) { if (!_is_initialize_shard_info) {
std::lock_guard<std::mutex> guard(_initialize_shard_mutex); std::lock_guard<std::mutex> guard(_initialize_shard_mutex);
if (_is_initialize_shard_info) { if (_is_initialize_shard_info) {
return 0; return 0;
} }
server_size = _server->environment()->get_ps_servers().size(); server_size = _server->Environment()->GetPsServers().size();
auto &table_map = *(_server->table()); auto &table_map = *(_server->GetTable());
for (auto itr : table_map) { for (auto itr : table_map) {
itr.second->set_shard(_rank, server_size); itr.second->SetShard(_rank, server_size);
} }
_is_initialize_shard_info = true; _is_initialize_shard_info = true;
} }
...@@ -241,7 +240,7 @@ void GraphBrpcService::service(google::protobuf::RpcController *cntl_base, ...@@ -241,7 +240,7 @@ void GraphBrpcService::service(google::protobuf::RpcController *cntl_base,
response->set_err_code(0); response->set_err_code(0);
response->set_err_msg(""); response->set_err_msg("");
auto *table = _server->table(request->table_id()); auto *table = _server->GetTable(request->table_id());
brpc::Controller *cntl = static_cast<brpc::Controller *>(cntl_base); brpc::Controller *cntl = static_cast<brpc::Controller *>(cntl_base);
auto itr = _service_handler_map.find(request->cmd_id()); auto itr = _service_handler_map.find(request->cmd_id());
if (itr == _service_handler_map.end()) { if (itr == _service_handler_map.end()) {
...@@ -261,7 +260,7 @@ void GraphBrpcService::service(google::protobuf::RpcController *cntl_base, ...@@ -261,7 +260,7 @@ void GraphBrpcService::service(google::protobuf::RpcController *cntl_base,
} }
} }
int32_t GraphBrpcService::barrier(Table *table, const PsRequestMessage &request, int32_t GraphBrpcService::Barrier(Table *table, const PsRequestMessage &request,
PsResponseMessage &response, PsResponseMessage &response,
brpc::Controller *cntl) { brpc::Controller *cntl) {
CHECK_TABLE_EXIST(table, request, response) CHECK_TABLE_EXIST(table, request, response)
...@@ -275,16 +274,16 @@ int32_t GraphBrpcService::barrier(Table *table, const PsRequestMessage &request, ...@@ -275,16 +274,16 @@ int32_t GraphBrpcService::barrier(Table *table, const PsRequestMessage &request,
auto trainer_id = request.client_id(); auto trainer_id = request.client_id();
auto barrier_type = request.params(0); auto barrier_type = request.params(0);
table->barrier(trainer_id, barrier_type); table->Barrier(trainer_id, barrier_type);
return 0; return 0;
} }
int32_t GraphBrpcService::print_table_stat(Table *table, int32_t GraphBrpcService::PrintTableStat(Table *table,
const PsRequestMessage &request, const PsRequestMessage &request,
PsResponseMessage &response, PsResponseMessage &response,
brpc::Controller *cntl) { brpc::Controller *cntl) {
CHECK_TABLE_EXIST(table, request, response) CHECK_TABLE_EXIST(table, request, response)
std::pair<int64_t, int64_t> ret = table->print_table_stat(); std::pair<int64_t, int64_t> ret = table->PrintTableStat();
paddle::framework::BinaryArchive ar; paddle::framework::BinaryArchive ar;
ar << ret.first << ret.second; ar << ret.first << ret.second;
std::string table_info(ar.Buffer(), ar.Length()); std::string table_info(ar.Buffer(), ar.Length());
...@@ -293,10 +292,10 @@ int32_t GraphBrpcService::print_table_stat(Table *table, ...@@ -293,10 +292,10 @@ int32_t GraphBrpcService::print_table_stat(Table *table,
return 0; return 0;
} }
int32_t GraphBrpcService::load_one_table(Table *table, int32_t GraphBrpcService::LoadOneTable(Table *table,
const PsRequestMessage &request, const PsRequestMessage &request,
PsResponseMessage &response, PsResponseMessage &response,
brpc::Controller *cntl) { brpc::Controller *cntl) {
CHECK_TABLE_EXIST(table, request, response) CHECK_TABLE_EXIST(table, request, response)
if (request.params_size() < 2) { if (request.params_size() < 2) {
set_response_code( set_response_code(
...@@ -304,20 +303,20 @@ int32_t GraphBrpcService::load_one_table(Table *table, ...@@ -304,20 +303,20 @@ int32_t GraphBrpcService::load_one_table(Table *table,
"PsRequestMessage.datas is requeired at least 2 for path & load_param"); "PsRequestMessage.datas is requeired at least 2 for path & load_param");
return -1; return -1;
} }
if (table->load(request.params(0), request.params(1)) != 0) { if (table->Load(request.params(0), request.params(1)) != 0) {
set_response_code(response, -1, "table load failed"); set_response_code(response, -1, "table load failed");
return -1; return -1;
} }
return 0; return 0;
} }
int32_t GraphBrpcService::load_all_table(Table *table, int32_t GraphBrpcService::LoadAllTable(Table *table,
const PsRequestMessage &request, const PsRequestMessage &request,
PsResponseMessage &response, PsResponseMessage &response,
brpc::Controller *cntl) { brpc::Controller *cntl) {
auto &table_map = *(_server->table()); auto &table_map = *(_server->GetTable());
for (auto &itr : table_map) { for (auto &itr : table_map) {
if (load_one_table(itr.second.get(), request, response, cntl) != 0) { if (LoadOneTable(itr.second.get(), request, response, cntl) != 0) {
LOG(ERROR) << "load table[" << itr.first << "] failed"; LOG(ERROR) << "load table[" << itr.first << "] failed";
return -1; return -1;
} }
...@@ -325,13 +324,13 @@ int32_t GraphBrpcService::load_all_table(Table *table, ...@@ -325,13 +324,13 @@ int32_t GraphBrpcService::load_all_table(Table *table,
return 0; return 0;
} }
int32_t GraphBrpcService::stop_server(Table *table, int32_t GraphBrpcService::StopServer(Table *table,
const PsRequestMessage &request, const PsRequestMessage &request,
PsResponseMessage &response, PsResponseMessage &response,
brpc::Controller *cntl) { brpc::Controller *cntl) {
GraphBrpcServer *p_server = (GraphBrpcServer *)_server; GraphBrpcServer *p_server = (GraphBrpcServer *)_server;
std::thread t_stop([p_server]() { std::thread t_stop([p_server]() {
p_server->stop(); p_server->Stop();
LOG(INFO) << "Server Stoped"; LOG(INFO) << "Server Stoped";
}); });
p_server->export_cv()->notify_all(); p_server->export_cv()->notify_all();
...@@ -339,19 +338,19 @@ int32_t GraphBrpcService::stop_server(Table *table, ...@@ -339,19 +338,19 @@ int32_t GraphBrpcService::stop_server(Table *table,
return 0; return 0;
} }
// Service handler registered under PS_STOP_PROFILER in the handler map.
// Calls platform::DisableProfiler with the kDefault sorting key and a name
// built from the format "server_%s_profile" and this service's _rank.
// The table, request, response and cntl arguments are not used.
// Always returns 0.
int32_t GraphBrpcService::stop_profiler(Table *table, int32_t GraphBrpcService::StopProfiler(Table *table,
const PsRequestMessage &request, const PsRequestMessage &request,
PsResponseMessage &response, PsResponseMessage &response,
brpc::Controller *cntl) { brpc::Controller *cntl) {
platform::DisableProfiler(platform::EventSortingKey::kDefault, platform::DisableProfiler(platform::EventSortingKey::kDefault,
string::Sprintf("server_%s_profile", _rank)); string::Sprintf("server_%s_profile", _rank));
return 0; return 0;
} }
// Service handler registered under PS_START_PROFILER in the handler map.
// Calls platform::EnableProfiler with ProfilerState::kCPU.
// The table, request, response and cntl arguments are not used.
// Always returns 0.
int32_t GraphBrpcService::start_profiler(Table *table, int32_t GraphBrpcService::StartProfiler(Table *table,
const PsRequestMessage &request, const PsRequestMessage &request,
PsResponseMessage &response, PsResponseMessage &response,
brpc::Controller *cntl) { brpc::Controller *cntl) {
platform::EnableProfiler(platform::ProfilerState::kCPU); platform::EnableProfiler(platform::ProfilerState::kCPU);
return 0; return 0;
} }
...@@ -475,7 +474,7 @@ int32_t GraphBrpcService::sample_neighbors_across_multi_servers( ...@@ -475,7 +474,7 @@ int32_t GraphBrpcService::sample_neighbors_across_multi_servers(
std::vector<int> server2request(server_size, -1); std::vector<int> server2request(server_size, -1);
std::vector<int64_t> local_id; std::vector<int64_t> local_id;
std::vector<int> local_query_idx; std::vector<int> local_query_idx;
size_t rank = get_rank(); size_t rank = GetRank();
for (int query_idx = 0; query_idx < node_num; ++query_idx) { for (int query_idx = 0; query_idx < node_num; ++query_idx) {
int server_index = int server_index =
((GraphTable *)table)->get_server_index_by_id(node_data[query_idx]); ((GraphTable *)table)->get_server_index_by_id(node_data[query_idx]);
...@@ -589,9 +588,9 @@ int32_t GraphBrpcService::sample_neighbors_across_multi_servers( ...@@ -589,9 +588,9 @@ int32_t GraphBrpcService::sample_neighbors_across_multi_servers(
closure->request(request_idx) closure->request(request_idx)
->add_params((char *)&need_weight, sizeof(bool)); ->add_params((char *)&need_weight, sizeof(bool));
PsService_Stub rpc_stub( PsService_Stub rpc_stub(
((GraphBrpcServer *)get_server())->get_cmd_channel(server_index)); ((GraphBrpcServer *)GetServer())->GetCmdChannel(server_index));
// GraphPsService_Stub rpc_stub = // GraphPsService_Stub rpc_stub =
// getServiceStub(get_cmd_channel(server_index)); // getServiceStub(GetCmdChannel(server_index));
closure->cntl(request_idx)->set_log_id(butil::gettimeofday_ms()); closure->cntl(request_idx)->set_log_id(butil::gettimeofday_ms());
rpc_stub.service(closure->cntl(request_idx), closure->request(request_idx), rpc_stub.service(closure->cntl(request_idx), closure->request(request_idx),
closure->response(request_idx), closure); closure->response(request_idx), closure);
......
...@@ -31,10 +31,10 @@ class GraphBrpcServer : public PSServer { ...@@ -31,10 +31,10 @@ class GraphBrpcServer : public PSServer {
GraphBrpcServer() {} GraphBrpcServer() {}
virtual ~GraphBrpcServer() {} virtual ~GraphBrpcServer() {}
PsBaseService *get_service() { return _service.get(); } PsBaseService *get_service() { return _service.get(); }
virtual uint64_t start(const std::string &ip, uint32_t port); virtual uint64_t Start(const std::string &ip, uint32_t port);
virtual int32_t build_peer2peer_connection(int rank); virtual int32_t build_peer2peer_connection(int rank);
virtual brpc::Channel *get_cmd_channel(size_t server_index); virtual brpc::Channel *GetCmdChannel(size_t server_index);
virtual int32_t stop() { virtual int32_t Stop() {
std::unique_lock<std::mutex> lock(mutex_); std::unique_lock<std::mutex> lock(mutex_);
if (stoped_) return 0; if (stoped_) return 0;
stoped_ = true; stoped_ = true;
...@@ -43,12 +43,12 @@ class GraphBrpcServer : public PSServer { ...@@ -43,12 +43,12 @@ class GraphBrpcServer : public PSServer {
_server.Join(); _server.Join();
return 0; return 0;
} }
int32_t port(); int32_t Port();
std::condition_variable *export_cv() { return &cv_; } std::condition_variable *export_cv() { return &cv_; }
private: private:
virtual int32_t initialize(); virtual int32_t Initialize();
mutable std::mutex mutex_; mutable std::mutex mutex_;
std::condition_variable cv_; std::condition_variable cv_;
bool stoped_ = false; bool stoped_ = false;
...@@ -66,7 +66,7 @@ typedef int32_t (GraphBrpcService::*serviceFunc)( ...@@ -66,7 +66,7 @@ typedef int32_t (GraphBrpcService::*serviceFunc)(
class GraphBrpcService : public PsBaseService { class GraphBrpcService : public PsBaseService {
public: public:
virtual int32_t initialize() override; virtual int32_t Initialize() override;
virtual void service(::google::protobuf::RpcController *controller, virtual void service(::google::protobuf::RpcController *controller,
const PsRequestMessage *request, const PsRequestMessage *request,
...@@ -75,7 +75,7 @@ class GraphBrpcService : public PsBaseService { ...@@ -75,7 +75,7 @@ class GraphBrpcService : public PsBaseService {
protected: protected:
std::unordered_map<int32_t, serviceFunc> _service_handler_map; std::unordered_map<int32_t, serviceFunc> _service_handler_map;
int32_t initialize_shard_info(); int32_t InitializeShardInfo();
int32_t pull_graph_list(Table *table, const PsRequestMessage &request, int32_t pull_graph_list(Table *table, const PsRequestMessage &request,
PsResponseMessage &response, brpc::Controller *cntl); PsResponseMessage &response, brpc::Controller *cntl);
int32_t graph_random_sample_neighbors(Table *table, int32_t graph_random_sample_neighbors(Table *table,
...@@ -100,21 +100,21 @@ class GraphBrpcService : public PsBaseService { ...@@ -100,21 +100,21 @@ class GraphBrpcService : public PsBaseService {
int32_t remove_graph_node(Table *table, const PsRequestMessage &request, int32_t remove_graph_node(Table *table, const PsRequestMessage &request,
PsResponseMessage &response, PsResponseMessage &response,
brpc::Controller *cntl); brpc::Controller *cntl);
int32_t barrier(Table *table, const PsRequestMessage &request, int32_t Barrier(Table *table, const PsRequestMessage &request,
PsResponseMessage &response, brpc::Controller *cntl); PsResponseMessage &response, brpc::Controller *cntl);
int32_t load_one_table(Table *table, const PsRequestMessage &request, int32_t LoadOneTable(Table *table, const PsRequestMessage &request,
PsResponseMessage &response, brpc::Controller *cntl); PsResponseMessage &response, brpc::Controller *cntl);
int32_t load_all_table(Table *table, const PsRequestMessage &request, int32_t LoadAllTable(Table *table, const PsRequestMessage &request,
PsResponseMessage &response, brpc::Controller *cntl); PsResponseMessage &response, brpc::Controller *cntl);
int32_t stop_server(Table *table, const PsRequestMessage &request, int32_t StopServer(Table *table, const PsRequestMessage &request,
PsResponseMessage &response, brpc::Controller *cntl); PsResponseMessage &response, brpc::Controller *cntl);
int32_t start_profiler(Table *table, const PsRequestMessage &request, int32_t StartProfiler(Table *table, const PsRequestMessage &request,
PsResponseMessage &response, brpc::Controller *cntl);
int32_t stop_profiler(Table *table, const PsRequestMessage &request,
PsResponseMessage &response, brpc::Controller *cntl); PsResponseMessage &response, brpc::Controller *cntl);
int32_t StopProfiler(Table *table, const PsRequestMessage &request,
PsResponseMessage &response, brpc::Controller *cntl);
int32_t print_table_stat(Table *table, const PsRequestMessage &request, int32_t PrintTableStat(Table *table, const PsRequestMessage &request,
PsResponseMessage &response, brpc::Controller *cntl); PsResponseMessage &response, brpc::Controller *cntl);
int32_t sample_neighbors_across_multi_servers(Table *table, int32_t sample_neighbors_across_multi_servers(Table *table,
const PsRequestMessage &request, const PsRequestMessage &request,
......
...@@ -25,7 +25,7 @@ REGISTER_PSCORE_CLASS(PSClient, BrpcPsClient); ...@@ -25,7 +25,7 @@ REGISTER_PSCORE_CLASS(PSClient, BrpcPsClient);
REGISTER_PSCORE_CLASS(PSClient, PsLocalClient); REGISTER_PSCORE_CLASS(PSClient, PsLocalClient);
REGISTER_PSCORE_CLASS(PSClient, GraphBrpcClient); REGISTER_PSCORE_CLASS(PSClient, GraphBrpcClient);
int32_t PSClient::configure( int32_t PSClient::Configure(
const PSParameter &config, const PSParameter &config,
const std::map<uint64_t, std::vector<paddle::distributed::Region>> &regions, const std::map<uint64_t, std::vector<paddle::distributed::Region>> &regions,
PSEnvironment &env, size_t client_id) { PSEnvironment &env, size_t client_id) {
...@@ -51,10 +51,10 @@ int32_t PSClient::configure( ...@@ -51,10 +51,10 @@ int32_t PSClient::configure(
_table_accessors[work_param.downpour_table_param(i).table_id()].reset( _table_accessors[work_param.downpour_table_param(i).table_id()].reset(
accessor); accessor);
} }
return initialize(); return Initialize();
} }
PSClient *PSClientFactory::create(const PSParameter &ps_config) { PSClient *PSClientFactory::Create(const PSParameter &ps_config) {
const auto &config = ps_config.server_param(); const auto &config = ps_config.server_param();
if (!config.has_downpour_server_param()) { if (!config.has_downpour_server_param()) {
LOG(ERROR) << "miss downpour_server_param in ServerParameter"; LOG(ERROR) << "miss downpour_server_param in ServerParameter";
...@@ -81,7 +81,7 @@ PSClient *PSClientFactory::create(const PSParameter &ps_config) { ...@@ -81,7 +81,7 @@ PSClient *PSClientFactory::create(const PSParameter &ps_config) {
return NULL; return NULL;
} }
TableManager::instance().initialize(); TableManager::Instance().Initialize();
VLOG(3) << "Create PSClient[" << service_param.client_class() << "] success"; VLOG(3) << "Create PSClient[" << service_param.client_class() << "] success";
return client; return client;
} }
......
...@@ -26,7 +26,6 @@ ...@@ -26,7 +26,6 @@
#include "paddle/fluid/distributed/ps/service/sendrecv.pb.h" #include "paddle/fluid/distributed/ps/service/sendrecv.pb.h"
#include "paddle/fluid/distributed/ps/table/accessor.h" #include "paddle/fluid/distributed/ps/table/accessor.h"
#include "paddle/fluid/distributed/ps/table/graph/graph_node.h" #include "paddle/fluid/distributed/ps/table/graph/graph_node.h"
#include "paddle/fluid/distributed/ps/table/table.h"
#include "paddle/fluid/platform/timer.h" #include "paddle/fluid/platform/timer.h"
namespace paddle { namespace paddle {
...@@ -60,41 +59,6 @@ class PSClientClosure : public google::protobuf::Closure { ...@@ -60,41 +59,6 @@ class PSClientClosure : public google::protobuf::Closure {
std::vector<std::shared_ptr<std::promise<int32_t>>> _promises; std::vector<std::shared_ptr<std::promise<int32_t>>> _promises;
}; };
struct LoadSaveContext {
int table_id;
std::string epoch;
std::string mode;
};
enum TrainingMode { Async = 0, Sync = 1, Geo = 3 };
enum TrainingPhase { Init = 0, Train = 1, Save = 2 };
// enum ValueType {
// Sparse = 0,
// Dense = 1
// };
struct PushContext {
const uint64_t *keys;
const float **push_values;
const Region *push_dense_values;
};
struct RequestContext {
int table;
TrainingMode training_mode; // 1 for async, 2 for geo, 3 for sync
TrainingPhase training_phase; // 1 for init, 2 for train
ValueType value_type; // 1 for sparse, 2 for dense
uint64_t *keys;
float **sparse_values; // for sparse values
Region *dense_values; // for dense values
PushContext push_context;
size_t num;
bool is_training;
void *callback;
};
class PSClient { class PSClient {
public: public:
PSClient() {} PSClient() {}
...@@ -102,41 +66,37 @@ class PSClient { ...@@ -102,41 +66,37 @@ class PSClient {
PSClient(PSClient &&) = delete; PSClient(PSClient &&) = delete;
PSClient(const PSClient &) = delete; PSClient(const PSClient &) = delete;
virtual int32_t configure( // NOLINT virtual int32_t Configure( // NOLINT
const PSParameter &config, const PSParameter &config,
const std::map<uint64_t, std::vector<paddle::distributed::Region>> const std::map<uint64_t, std::vector<paddle::distributed::Region>>
&regions, &regions,
PSEnvironment &_env, size_t client_id) final; // NOLINT PSEnvironment &_env, size_t client_id) final; // NOLINT
virtual int32_t create_client2client_connection( virtual int32_t CreateClient2ClientConnection(int pserver_timeout_ms,
int pserver_timeout_ms, int pserver_connect_timeout_ms, int pserver_connect_timeout_ms,
int max_retry) = 0; int max_retry) = 0;
// 触发table数据退场 // 触发table数据退场
virtual std::future<int32_t> shrink(uint32_t table_id, virtual std::future<int32_t> Shrink(uint32_t table_id,
const std::string threshold) = 0; const std::string threshold) = 0;
// 全量table进行数据load // 全量table进行数据load
virtual std::future<int32_t> load(const std::string &epoch, virtual std::future<int32_t> Load(const std::string &epoch,
const std::string &mode) = 0; const std::string &mode) = 0;
// 指定table数据load // 指定table数据load
virtual std::future<int32_t> load(uint32_t table_id, const std::string &epoch, virtual std::future<int32_t> Load(uint32_t table_id, const std::string &epoch,
const std::string &mode) = 0; const std::string &mode) = 0;
// context配置load选项
virtual std::future<int32_t> Load(const LoadSaveContext &load_context) = 0;
// 全量table数据save value_accessor根据mode,可能有不同的save条件 // 全量table数据save value_accessor根据mode,可能有不同的save条件
virtual std::future<int32_t> save(const std::string &epoch, virtual std::future<int32_t> Save(const std::string &epoch,
const std::string &mode) = 0; const std::string &mode) = 0;
// 指定table数据save value_accessor根据mode,可能有不同的save条件 // 指定table数据save value_accessor根据mode,可能有不同的save条件
virtual std::future<int32_t> save(uint32_t table_id, const std::string &epoch, virtual std::future<int32_t> Save(uint32_t table_id, const std::string &epoch,
const std::string &mode) = 0; const std::string &mode) = 0;
virtual std::future<int32_t> Save(const LoadSaveContext &save_context) = 0;
// 清空table数据 // 清空table数据
virtual std::future<int32_t> clear() = 0; virtual std::future<int32_t> Clear() = 0;
virtual std::future<int32_t> clear(uint32_t table_id) = 0; virtual std::future<int32_t> Clear(uint32_t table_id) = 0;
// pull dense的参数部分,并分块填充到本地网络参数中 // pull dense的参数部分,并分块填充到本地网络参数中
// start和num用于拉取部分参数 // start和num用于拉取部分参数
...@@ -145,23 +105,19 @@ class PSClient { ...@@ -145,23 +105,19 @@ class PSClient {
// sender聚集同一区块的请求,累计多个填充buffer // sender聚集同一区块的请求,累计多个填充buffer
// server将参数区块中配置的某一维提取返回 // server将参数区块中配置的某一维提取返回
// 返回数据解包后填充到累计的多个buffer中 // 返回数据解包后填充到累计的多个buffer中
virtual std::future<int32_t> pull_dense(Region *regions, size_t region_num, virtual std::future<int32_t> PullDense(Region *regions, size_t region_num,
size_t table_id) = 0; // 保留 size_t table_id) = 0; // 保留
virtual std::future<int32_t> Push(RequestContext &push_context) = 0;
// firstly push dense param for parameter server // firstly push dense param for parameter server
// this is neccessary because dense weight initialized in trainer on cold // this is neccessary because dense weight initialized in trainer on cold
// start // start
virtual std::future<int32_t> push_dense_param(const Region *regions, virtual std::future<int32_t> PushDenseParam(const Region *regions,
size_t region_num, size_t region_num,
size_t table_id) = 0; size_t table_id) = 0;
virtual std::future<int32_t> push_dense(const Region *regions,
size_t region_num,
size_t table_id) = 0;
virtual std::future<int32_t> Pull(RequestContext &pull_context) = 0; virtual std::future<int32_t> PushDense(const Region *regions,
size_t region_num,
size_t table_id) = 0;
// 使用keys进行pull请求,结果填充values // 使用keys进行pull请求,结果填充values
// keys和values的个数均为num个,每个value占用select_size空间 // keys和values的个数均为num个,每个value占用select_size空间
...@@ -169,15 +125,14 @@ class PSClient { ...@@ -169,15 +125,14 @@ class PSClient {
// 整合多个线程请求的keys,聚集并分散发送到server // 整合多个线程请求的keys,聚集并分散发送到server
// 返回结果后,遍历buffer并对values赋值 // 返回结果后,遍历buffer并对values赋值
// is_training 用于区分请求是训练/预测,server端对于特征和准入会有不同的处理. // is_training 用于区分请求是训练/预测,server端对于特征和准入会有不同的处理.
virtual std::future<int32_t> pull_sparse(float **select_values, virtual std::future<int32_t> PullSparse(float **select_values,
size_t table_id, size_t table_id, const uint64_t *keys,
const uint64_t *keys, size_t num, size_t num, bool is_training) = 0;
bool is_training) = 0;
virtual std::future<int32_t> PullSparseParam(float **select_values,
virtual std::future<int32_t> pull_sparse_param(float **select_values, size_t table_id,
size_t table_id, const uint64_t *keys, size_t num,
const uint64_t *keys, bool is_training) {
size_t num, bool is_training) {
VLOG(0) << "Did not implement"; VLOG(0) << "Did not implement";
std::promise<int32_t> promise; std::promise<int32_t> promise;
std::future<int> fut = promise.get_future(); std::future<int> fut = promise.get_future();
...@@ -185,10 +140,10 @@ class PSClient { ...@@ -185,10 +140,10 @@ class PSClient {
return fut; return fut;
} }
virtual ::std::future<int32_t> pull_sparse_ptr(char **select_values, virtual ::std::future<int32_t> PullSparsePtr(char **select_values,
size_t table_id, size_t table_id,
const uint64_t *keys, const uint64_t *keys,
size_t num) { size_t num) {
VLOG(0) << "Did not implement"; VLOG(0) << "Did not implement";
std::promise<int32_t> promise; std::promise<int32_t> promise;
std::future<int> fut = promise.get_future(); std::future<int> fut = promise.get_future();
...@@ -196,38 +151,38 @@ class PSClient { ...@@ -196,38 +151,38 @@ class PSClient {
return fut; return fut;
} }
virtual std::future<int32_t> print_table_stat(uint32_t table_id) = 0; virtual std::future<int32_t> PrintTableStat(uint32_t table_id) = 0;
// 确保所有积攒中的请求都发起发送 // 确保所有积攒中的请求都发起发送
virtual std::future<int32_t> flush() = 0; virtual std::future<int32_t> Flush() = 0;
// server优雅退出 // server优雅退出
virtual std::future<int32_t> stop_server() = 0; virtual std::future<int32_t> StopServer() = 0;
// server profilera // server profilera
virtual std::future<int32_t> start_profiler() = 0; virtual std::future<int32_t> StartProfiler() = 0;
virtual std::future<int32_t> stop_profiler() = 0; virtual std::future<int32_t> StopProfiler() = 0;
virtual std::future<int32_t> barrier(size_t table_id, virtual std::future<int32_t> Barrier(size_t table_id,
uint32_t barrier_type) = 0; uint32_t barrier_type) = 0;
virtual std::future<int32_t> pull_geo_param(size_t table_id, virtual std::future<int32_t> PullGeoParam(size_t table_id,
std::vector<float> *values, std::vector<float> *values,
std::vector<uint64_t> *keys, std::vector<uint64_t> *keys,
int pserver_idx) = 0; int pserver_idx) = 0;
virtual std::future<int32_t> push_global_step(int table_id, virtual std::future<int32_t> PushGlobalStep(int table_id,
int64_t *total_send_data, int64_t *total_send_data,
void *done) = 0; void *done) = 0;
// recv table from server and save it in LodTensor // recv table from server and save it in LodTensor
virtual int32_t recv_and_save_table(const uint64_t table_id, virtual int32_t RecvAndSaveTable(const uint64_t table_id,
const std::string &path) = 0; const std::string &path) = 0;
virtual void finalize_worker() = 0; virtual void FinalizeWorker() = 0;
// client to client, 消息发送 // client to client, 消息发送
virtual std::future<int32_t> send_client2client_msg(int msg_type, virtual std::future<int32_t> SendClient2ClientMsg(int msg_type,
int to_client_id, int to_client_id,
const std::string &msg) { const std::string &msg) {
VLOG(0) << "Did not implement"; VLOG(0) << "Did not implement";
std::promise<int32_t> promise; std::promise<int32_t> promise;
std::future<int> fut = promise.get_future(); std::future<int> fut = promise.get_future();
...@@ -238,13 +193,13 @@ class PSClient { ...@@ -238,13 +193,13 @@ class PSClient {
// client2client消息处理,std::function<int32_t (int, int, const std::string&) // client2client消息处理,std::function<int32_t (int, int, const std::string&)
// -> ret (msg_type, from_client_id, msg) // -> ret (msg_type, from_client_id, msg)
typedef std::function<int32_t(int, int, const std::string &)> MsgHandlerFunc; typedef std::function<int32_t(int, int, const std::string &)> MsgHandlerFunc;
virtual int registe_client2client_msg_handler(int msg_type, virtual int RegisteClient2ClientMsgHandler(int msg_type,
MsgHandlerFunc handler) { MsgHandlerFunc handler) {
_msg_handler_map[msg_type] = handler; _msg_handler_map[msg_type] = handler;
return 0; return 0;
} }
virtual int handle_client2client_msg(int msg_type, int from_client_id, virtual int HandleClient2ClientMsg(int msg_type, int from_client_id,
const std::string &msg) { const std::string &msg) {
auto itr = _msg_handler_map.find(msg_type); auto itr = _msg_handler_map.find(msg_type);
if (itr == _msg_handler_map.end()) { if (itr == _msg_handler_map.end()) {
LOG(WARNING) << "unknown client2client_msg type:" << msg_type; LOG(WARNING) << "unknown client2client_msg type:" << msg_type;
...@@ -253,7 +208,7 @@ class PSClient { ...@@ -253,7 +208,7 @@ class PSClient {
return itr->second(msg_type, from_client_id, msg); return itr->second(msg_type, from_client_id, msg);
} }
virtual ValueAccessor *table_accessor(size_t table_id) { virtual ValueAccessor *GetTableAccessor(size_t table_id) {
auto itr = _table_accessors.find(table_id); auto itr = _table_accessors.find(table_id);
if (itr == _table_accessors.end()) { if (itr == _table_accessors.end()) {
return NULL; return NULL;
...@@ -261,31 +216,31 @@ class PSClient { ...@@ -261,31 +216,31 @@ class PSClient {
return itr->second.get(); return itr->second.get();
} }
virtual size_t get_server_nums() = 0; virtual size_t GetServerNums() = 0;
virtual std::future<int32_t> push_dense_raw_gradient( virtual std::future<int32_t> PushDenseRawGradient(int table_id,
int table_id, float *total_send_data, size_t total_send_data_size, float *total_send_data,
void *done) = 0; size_t total_send_data_size,
void *done) = 0;
virtual std::future<int32_t> push_sparse_raw_gradient( virtual std::future<int32_t> PushSparseRawGradient(
size_t table_id, const uint64_t *keys, const float **update_values, size_t table_id, const uint64_t *keys, const float **update_values,
size_t num, void *done) = 0; size_t num, void *done) = 0;
virtual std::future<int32_t> push_sparse_raw_gradient_partial( virtual std::future<int32_t> PushSparseRawGradientPartial(
size_t table_id, const uint64_t *keys, const float **update_values, size_t table_id, const uint64_t *keys, const float **update_values,
uint32_t num, void *done, int pserver_idx) = 0; uint32_t num, void *done, int pserver_idx) = 0;
virtual std::future<int32_t> push_sparse_param(size_t table_id, virtual std::future<int32_t> PushSparseParam(size_t table_id,
const uint64_t *keys, const uint64_t *keys,
const float **update_values, const float **update_values,
size_t num, void *done) = 0; size_t num, void *done) = 0;
virtual std::future<int32_t> push_sparse(size_t table_id, virtual std::future<int32_t> PushSparse(size_t table_id, const uint64_t *keys,
const uint64_t *keys, const float **update_values,
const float **update_values, size_t num) = 0;
size_t num) = 0;
protected: protected:
virtual int32_t initialize() = 0; virtual int32_t Initialize() = 0;
size_t _client_id; size_t _client_id;
PSParameter _config; PSParameter _config;
std::map<uint64_t, std::vector<paddle::distributed::Region>> std::map<uint64_t, std::vector<paddle::distributed::Region>>
...@@ -333,7 +288,7 @@ REGISTER_PSCORE_REGISTERER(PSClient); ...@@ -333,7 +288,7 @@ REGISTER_PSCORE_REGISTERER(PSClient);
class PSClientFactory { class PSClientFactory {
public: public:
static PSClient *create(const PSParameter &config); static PSClient *Create(const PSParameter &config);
}; };
} // namespace distributed } // namespace distributed
} // namespace paddle } // namespace paddle
...@@ -19,166 +19,91 @@ ...@@ -19,166 +19,91 @@
namespace paddle { namespace paddle {
namespace distributed { namespace distributed {
int32_t PsLocalClient::initialize() { int32_t PsLocalClient::Initialize() {
const auto& downpour_param = _config.server_param().downpour_server_param(); const auto& downpour_param = _config.server_param().downpour_server_param();
TableManager::instance().initialize(); TableManager::Instance().Initialize();
for (size_t i = 0; i < downpour_param.downpour_table_param_size(); ++i) { for (size_t i = 0; i < downpour_param.downpour_table_param_size(); ++i) {
auto* table = CREATE_PSCORE_CLASS( auto* table = CREATE_PSCORE_CLASS(
Table, downpour_param.downpour_table_param(i).table_class()); Table, downpour_param.downpour_table_param(i).table_class());
table->set_shard(0, 1); table->SetShard(0, 1);
table->initialize(downpour_param.downpour_table_param(i), table->Initialize(downpour_param.downpour_table_param(i),
_config.fs_client_param()); _config.fs_client_param());
_table_map[downpour_param.downpour_table_param(i).table_id()].reset(table); _table_map[downpour_param.downpour_table_param(i).table_id()].reset(table);
} }
return 0; return 0;
} }
::std::future<int32_t> PsLocalClient::shrink(uint32_t table_id, ::std::future<int32_t> PsLocalClient::Shrink(uint32_t table_id,
const std::string threshold) { const std::string threshold) {
// TODO // TODO
return done(); return done();
} }
::std::future<int32_t> PsLocalClient::load(const std::string& epoch, ::std::future<int32_t> PsLocalClient::Load(const std::string& epoch,
const std::string& mode) { const std::string& mode) {
// TODO // TODO
for (auto& it : _table_map) { for (auto& it : _table_map) {
load(it.first, epoch, mode); Load(it.first, epoch, mode);
} }
return done(); return done();
} }
::std::future<int32_t> PsLocalClient::load(uint32_t table_id, ::std::future<int32_t> PsLocalClient::Load(uint32_t table_id,
const std::string& epoch, const std::string& epoch,
const std::string& mode) { const std::string& mode) {
// TODO // TODO
auto* table_ptr = table(table_id); auto* table_ptr = GetTable(table_id);
table_ptr->load(epoch, mode); table_ptr->Load(epoch, mode);
return done(); return done();
} }
std::future<int32_t> PsLocalClient::Load(const LoadSaveContext& load_context) { ::std::future<int32_t> PsLocalClient::Save(const std::string& epoch,
if (load_context.table_id < 0) {
for (auto& it : _table_map) {
load(it.first, load_context.epoch, load_context.mode);
}
return done();
} else {
auto* table_ptr = table(load_context.table_id);
table_ptr->load(load_context.epoch, load_context.mode);
return done();
}
}
::std::future<int32_t> PsLocalClient::save(const std::string& epoch,
const std::string& mode) { const std::string& mode) {
// TODO // TODO
for (auto& it : _table_map) { for (auto& it : _table_map) {
save(it.first, epoch, mode); Save(it.first, epoch, mode);
} }
return done(); return done();
} }
::std::future<int32_t> PsLocalClient::save(uint32_t table_id, ::std::future<int32_t> PsLocalClient::Save(uint32_t table_id,
const std::string& epoch, const std::string& epoch,
const std::string& mode) { const std::string& mode) {
// TODO // TODO
auto* table_ptr = table(table_id); auto* table_ptr = GetTable(table_id);
table_ptr->flush(); table_ptr->Flush();
table_ptr->save(epoch, mode); table_ptr->Save(epoch, mode);
return done(); return done();
} }
::std::future<int32_t> PsLocalClient::Save( ::std::future<int32_t> PsLocalClient::Clear() {
const LoadSaveContext& save_context) {
if (save_context.table_id < 0) {
for (auto& it : _table_map) {
save(it.first, save_context.epoch, save_context.mode);
}
return done();
} else {
auto* table_ptr = table(save_context.table_id);
table_ptr->flush();
table_ptr->save(save_context.epoch, save_context.mode);
return done();
}
}
::std::future<int32_t> PsLocalClient::clear() {
// TODO // TODO
return done(); return done();
} }
::std::future<int32_t> PsLocalClient::clear(uint32_t table_id) { ::std::future<int32_t> PsLocalClient::Clear(uint32_t table_id) {
// TODO // TODO
return done(); return done();
} }
::std::future<int32_t> PsLocalClient::flush() { ::std::future<int32_t> PsLocalClient::Flush() {
// no need // no need
return done(); return done();
} }
::std::future<int32_t> PsLocalClient::stop_server() { ::std::future<int32_t> PsLocalClient::StopServer() {
// no need // no need
return done(); return done();
} }
::std::future<int32_t> PsLocalClient::Pull(RequestContext& pull_context) { ::std::future<int32_t> PsLocalClient::PullDense(Region* regions,
if (pull_context.value_type == Dense) { // pull dense size_t region_num,
Region* dense_region = reinterpret_cast<Region*>(pull_context.dense_values); size_t table_id) {
pull_dense(dense_region, pull_context.num, pull_context.table); auto* accessor = GetTableAccessor(table_id);
} else { // pull sparse auto* table_ptr = GetTable(table_id);
// uint64_t* keys = reinterpret_cast<uint64_t*>(pull_context.keys);
// char** select_values =
// reinterpret_cast<char**>(pull_context.sparse_values);
size_t table_id = pull_context.table;
size_t num = pull_context.num;
pull_sparse_ptr(reinterpret_cast<char**>(pull_context.sparse_values),
table_id, pull_context.keys, num);
}
}
::std::future<int32_t> PsLocalClient::Push(RequestContext& push_context) { uint32_t num_per_shard = DenseDimPerShard(accessor->GetTableInfo(FEA_DIM), 1);
if (push_context.value_type == Dense) { // push dense
if (push_context.training_phase == Init) {
const Region* regions = push_context.push_context.push_dense_values;
size_t region_num = push_context.num;
push_dense_param(regions, region_num, push_context.table);
} else {
if (push_context.training_mode == Geo) { // geo
float* total_send_data =
reinterpret_cast<float*>(push_context.dense_values);
size_t total_send_data_size = push_context.num;
push_dense_raw_gradient(push_context.table, total_send_data,
total_send_data_size, push_context.callback);
} else { // async and sync
const Region* regions = push_context.push_context.push_dense_values;
size_t region_num = push_context.num;
push_dense(regions, region_num, push_context.table);
}
}
} else { // push sparse
if (push_context.training_mode == Async) {
const uint64_t* keys = push_context.push_context.keys;
const float** update_values = push_context.push_context.push_values;
size_t table_id = push_context.table;
size_t num = push_context.num;
push_sparse(table_id, keys, update_values, num);
} else {
// TODO
}
}
}
::std::future<int32_t> PsLocalClient::pull_dense(Region* regions,
size_t region_num,
size_t table_id) {
auto* accessor = table_accessor(table_id);
auto* table_ptr = table(table_id);
uint32_t num_per_shard =
dense_dim_per_shard(accessor->GetTableInfo(FEA_DIM), 1);
std::vector<float> region_buffer; std::vector<float> region_buffer;
region_buffer.resize(num_per_shard); region_buffer.resize(num_per_shard);
table_ptr->pull_dense(region_buffer.data(), region_buffer.size()); table_ptr->PullDense(region_buffer.data(), region_buffer.size());
size_t region_idx = 0; size_t region_idx = 0;
size_t region_data_idx = 0; size_t region_data_idx = 0;
...@@ -213,48 +138,49 @@ std::future<int32_t> PsLocalClient::Load(const LoadSaveContext& load_context) { ...@@ -213,48 +138,49 @@ std::future<int32_t> PsLocalClient::Load(const LoadSaveContext& load_context) {
return done(); return done();
} }
::std::future<int32_t> PsLocalClient::push_dense_param(const Region* regions, ::std::future<int32_t> PsLocalClient::PushDenseParam(const Region* regions,
size_t region_num, size_t region_num,
size_t table_id) { size_t table_id) {
auto* accessor = table_accessor(table_id); auto* accessor = GetTableAccessor(table_id);
auto* table_ptr = table(table_id); auto* table_ptr = GetTable(table_id);
std::vector<float> region_buffer; std::vector<float> region_buffer;
region_buffer.resize(dense_dim_per_shard(accessor->GetTableInfo(FEA_DIM), 1), region_buffer.resize(DenseDimPerShard(accessor->GetTableInfo(FEA_DIM), 1), 0);
0);
for (size_t i = 0, offset = 0; i < region_num; ++i) { for (size_t i = 0, offset = 0; i < region_num; ++i) {
uint32_t data_num = regions[i].size / sizeof(float); uint32_t data_num = regions[i].size / sizeof(float);
memcpy(region_buffer.data() + offset, regions[i].data, regions[i].size); memcpy(region_buffer.data() + offset, regions[i].data, regions[i].size);
offset += data_num; offset += data_num;
} }
// table_ptr->push_dense_param(region_buffer.data(), region_buffer.size()); // table_ptr->PushDenseParam(region_buffer.data(), region_buffer.size());
return done(); return done();
} }
::std::future<int32_t> PsLocalClient::push_dense_raw_gradient( ::std::future<int32_t> PsLocalClient::PushDenseRawGradient(
int table_id, float* total_send_data, size_t total_send_data_size, int table_id, float* total_send_data, size_t total_send_data_size,
void* callback) { void* callback) {
VLOG(1) << "wxx push_dense_raw_gradient"; VLOG(1) << "wxx push_dense_raw_gradient";
PSClientClosure* closure = reinterpret_cast<PSClientClosure*>(callback); PSClientClosure* closure = reinterpret_cast<PSClientClosure*>(callback);
auto* table_ptr = table(table_id); auto* table_ptr = GetTable(table_id);
table_ptr->push_dense(total_send_data, total_send_data_size); table_ptr->PushDense(total_send_data, total_send_data_size);
delete closure; delete closure;
return done(); return done();
} }
::std::future<int32_t> PsLocalClient::push_dense(const Region* regions, ::std::future<int32_t> PsLocalClient::PushDense(const Region* regions,
size_t region_num, size_t region_num,
size_t table_id) { size_t table_id) {
auto* accessor = table_accessor(table_id); auto* accessor = GetTableAccessor(table_id);
auto* table_ptr = table(table_id); auto* table_ptr = GetTable(table_id);
std::vector<float> region_buffer; std::vector<float> region_buffer;
region_buffer.resize(dense_dim_per_shard(accessor->GetTableInfo(FEA_DIM), 1)); region_buffer.resize(DenseDimPerShard(accessor->GetTableInfo(FEA_DIM), 1));
size_t data_size = region_buffer.size(); size_t data_size = region_buffer.size();
for (size_t i = 0, offset = 0; i < region_num; ++i) { for (size_t i = 0, offset = 0; i < region_num; ++i) {
uint32_t data_num = regions[i].size / sizeof(float); uint32_t data_num = regions[i].size / sizeof(float);
...@@ -267,12 +193,12 @@ std::future<int32_t> PsLocalClient::Load(const LoadSaveContext& load_context) { ...@@ -267,12 +193,12 @@ std::future<int32_t> PsLocalClient::Load(const LoadSaveContext& load_context) {
offset += data_num; offset += data_num;
} }
table_ptr->push_dense(region_buffer.data(), region_buffer.size()); table_ptr->PushDense(region_buffer.data(), region_buffer.size());
return done(); return done();
} }
//::std::future<int32_t> PsLocalClient::pull_sparse(float** select_values, //::std::future<int32_t> PsLocalClient::PullSparse(float** select_values,
// size_t table_id, // size_t table_id,
// const uint64_t* keys, // const uint64_t* keys,
// size_t num) { // size_t num) {
...@@ -282,14 +208,14 @@ std::future<int32_t> PsLocalClient::Load(const LoadSaveContext& load_context) { ...@@ -282,14 +208,14 @@ std::future<int32_t> PsLocalClient::Load(const LoadSaveContext& load_context) {
// // auto local_timer = // // auto local_timer =
// // std::make_shared<CostTimer>("pslib_downpour_client_pull_sparse_local"); // // std::make_shared<CostTimer>("pslib_downpour_client_pull_sparse_local");
// //将key拆分到各shard请求,并记录原始对应value指针 // //将key拆分到各shard请求,并记录原始对应value指针
// auto* accessor = table_accessor(table_id); // auto* accessor = GetTableAccessor(table_id);
// auto* table_ptr = table(table_id); // auto* table_ptr = GetTable(table_id);
// size_t value_size = accessor->select_size(); // size_t value_size = accessor->select_size();
// //
// // table_ptr->pull_sparse(keys, num); // // table_ptr->PullSparse(keys, num);
// std::vector<float> res_data; // std::vector<float> res_data;
// res_data.resize(num * value_size / sizeof(float)); // res_data.resize(num * value_size / sizeof(float));
// table_ptr->pull_sparse(res_data.data(), keys, num); // table_ptr->PullSparse(res_data.data(), keys, num);
// // memcpy(select_values[0], res_data->data(), res_data->size() * // // memcpy(select_values[0], res_data->data(), res_data->size() *
// // sizeof(float)); // // sizeof(float));
// size_t offset = 0; // size_t offset = 0;
...@@ -302,43 +228,43 @@ std::future<int32_t> PsLocalClient::Load(const LoadSaveContext& load_context) { ...@@ -302,43 +228,43 @@ std::future<int32_t> PsLocalClient::Load(const LoadSaveContext& load_context) {
// return done(); // return done();
//} //}
::std::future<int32_t> PsLocalClient::pull_sparse_ptr(char** select_values, ::std::future<int32_t> PsLocalClient::PullSparsePtr(char** select_values,
size_t table_id, size_t table_id,
const uint64_t* keys, const uint64_t* keys,
size_t num) { size_t num) {
// FIXME // FIXME
// auto timer = // auto timer =
// std::make_shared<CostTimer>("pslib_downpour_client_pull_sparse"); // std::make_shared<CostTimer>("pslib_downpour_client_pull_sparse");
// auto local_timer = // auto local_timer =
// std::make_shared<CostTimer>("pslib_downpour_client_pull_sparse_local"); // std::make_shared<CostTimer>("pslib_downpour_client_pull_sparse_local");
//将key拆分到各shard请求,并记录原始对应value指针 //将key拆分到各shard请求,并记录原始对应value指针
auto* table_ptr = table(table_id); auto* table_ptr = GetTable(table_id);
table_ptr->pull_sparse_ptr(select_values, keys, num); table_ptr->PullSparsePtr(select_values, keys, num);
return done(); return done();
} }
::std::future<int32_t> PsLocalClient::push_sparse_raw_gradient( ::std::future<int32_t> PsLocalClient::PushSparseRawGradient(
size_t table_id, const uint64_t* keys, const float** update_values, size_t table_id, const uint64_t* keys, const float** update_values,
size_t num, void* callback) { size_t num, void* callback) {
PSClientClosure* closure = reinterpret_cast<PSClientClosure*>(callback); PSClientClosure* closure = reinterpret_cast<PSClientClosure*>(callback);
auto* accessor = table_accessor(table_id); auto* accessor = GetTableAccessor(table_id);
auto* table_ptr = table(table_id); auto* table_ptr = GetTable(table_id);
table_ptr->push_sparse(keys, update_values, num); table_ptr->PushSparse(keys, update_values, num);
delete closure; delete closure;
return done(); return done();
} }
::std::future<int32_t> PsLocalClient::push_sparse(size_t table_id, ::std::future<int32_t> PsLocalClient::PushSparse(size_t table_id,
const uint64_t* keys, const uint64_t* keys,
const float** update_values, const float** update_values,
size_t num) { size_t num) {
auto* accessor = table_accessor(table_id); auto* accessor = GetTableAccessor(table_id);
auto* table_ptr = table(table_id); auto* table_ptr = GetTable(table_id);
table_ptr->push_sparse(keys, update_values, num); table_ptr->PushSparse(keys, update_values, num);
return done(); return done();
} }
} }
......
...@@ -26,54 +26,46 @@ class PsLocalClient : public PSClient { ...@@ -26,54 +26,46 @@ class PsLocalClient : public PSClient {
public: public:
PsLocalClient() {} PsLocalClient() {}
virtual ~PsLocalClient() { _running = false; } virtual ~PsLocalClient() { _running = false; }
virtual int32_t create_client2client_connection(int pslib_timeout_ms, virtual int32_t CreateClient2ClientConnection(int pslib_timeout_ms,
int pslib_connect_timeout_ms, int pslib_connect_timeout_ms,
int max_retry) { int max_retry) {
return 0; return 0;
} }
virtual ::std::future<int32_t> shrink(uint32_t table_id, virtual ::std::future<int32_t> Shrink(uint32_t table_id,
const std::string threshold) override; const std::string threshold) override;
virtual ::std::future<int32_t> load(const std::string& epoch, virtual ::std::future<int32_t> Load(const std::string& epoch,
const std::string& mode) override; const std::string& mode) override;
virtual ::std::future<int32_t> load(uint32_t table_id, virtual ::std::future<int32_t> Load(uint32_t table_id,
const std::string& epoch, const std::string& epoch,
const std::string& mode) override; const std::string& mode) override;
virtual std::future<int32_t> Load(
const LoadSaveContext& load_context) override;
virtual ::std::future<int32_t> save(const std::string& epoch, virtual ::std::future<int32_t> Save(const std::string& epoch,
const std::string& mode) override; const std::string& mode) override;
virtual ::std::future<int32_t> save(uint32_t table_id, virtual ::std::future<int32_t> Save(uint32_t table_id,
const std::string& epoch, const std::string& epoch,
const std::string& mode) override; const std::string& mode) override;
virtual std::future<int32_t> Save(
const LoadSaveContext& save_context) override;
virtual ::std::future<int32_t> clear() override; virtual ::std::future<int32_t> Clear() override;
virtual ::std::future<int32_t> clear(uint32_t table_id) override; virtual ::std::future<int32_t> Clear(uint32_t table_id) override;
virtual ::std::future<int32_t> stop_server() override; virtual ::std::future<int32_t> StopServer() override;
virtual void finalize_worker() override {} virtual void FinalizeWorker() override {}
virtual ::std::future<int32_t> pull_dense(Region* regions, size_t region_num, virtual ::std::future<int32_t> PullDense(Region* regions, size_t region_num,
size_t table_id); size_t table_id);
virtual ::std::future<int32_t> Pull(RequestContext& pull_context) override; virtual ::std::future<int32_t> PushDense(const Region* regions,
size_t region_num, size_t table_id);
virtual ::std::future<int32_t> Push(RequestContext& push_context) override; virtual ::std::future<int32_t> PushDenseParam(const Region* regions,
size_t region_num,
size_t table_id);
virtual ::std::future<int32_t> push_dense(const Region* regions, virtual ::std::future<int32_t> PullSparse(float** select_values,
size_t region_num, size_t table_id); size_t table_id,
const uint64_t* keys, size_t num,
virtual ::std::future<int32_t> push_dense_param(const Region* regions, bool is_training) {
size_t region_num,
size_t table_id);
virtual ::std::future<int32_t> pull_sparse(float** select_values,
size_t table_id,
const uint64_t* keys, size_t num,
bool is_training) {
std::promise<int32_t> prom; std::promise<int32_t> prom;
std::future<int32_t> fut = prom.get_future(); std::future<int32_t> fut = prom.get_future();
prom.set_value(0); prom.set_value(0);
...@@ -81,26 +73,26 @@ class PsLocalClient : public PSClient { ...@@ -81,26 +73,26 @@ class PsLocalClient : public PSClient {
return fut; return fut;
} }
virtual ::std::future<int32_t> pull_sparse_ptr(char** select_values, virtual ::std::future<int32_t> PullSparsePtr(char** select_values,
size_t table_id, size_t table_id,
const uint64_t* keys, const uint64_t* keys,
size_t num); size_t num);
virtual ::std::future<int32_t> print_table_stat(uint32_t table_id) { virtual ::std::future<int32_t> PrintTableStat(uint32_t table_id) {
std::promise<int32_t> prom; std::promise<int32_t> prom;
std::future<int32_t> fut = prom.get_future(); std::future<int32_t> fut = prom.get_future();
prom.set_value(0); prom.set_value(0);
return fut; return fut;
} }
virtual ::std::future<int32_t> push_sparse(size_t table_id, virtual ::std::future<int32_t> PushSparse(size_t table_id,
const uint64_t* keys, const uint64_t* keys,
const float** update_values, const float** update_values,
size_t num); size_t num);
virtual ::std::future<int32_t> flush(); virtual ::std::future<int32_t> Flush();
// server profilera // server profilera
virtual std::future<int32_t> start_profiler() { virtual std::future<int32_t> StartProfiler() {
std::promise<int32_t> prom; std::promise<int32_t> prom;
std::future<int32_t> fut = prom.get_future(); std::future<int32_t> fut = prom.get_future();
prom.set_value(0); prom.set_value(0);
...@@ -108,7 +100,7 @@ class PsLocalClient : public PSClient { ...@@ -108,7 +100,7 @@ class PsLocalClient : public PSClient {
return fut; return fut;
}; };
virtual std::future<int32_t> stop_profiler() { virtual std::future<int32_t> StopProfiler() {
std::promise<int32_t> prom; std::promise<int32_t> prom;
std::future<int32_t> fut = prom.get_future(); std::future<int32_t> fut = prom.get_future();
prom.set_value(0); prom.set_value(0);
...@@ -116,7 +108,7 @@ class PsLocalClient : public PSClient { ...@@ -116,7 +108,7 @@ class PsLocalClient : public PSClient {
return fut; return fut;
} }
virtual std::future<int32_t> barrier(size_t table_id, uint32_t barrier_type) { virtual std::future<int32_t> Barrier(size_t table_id, uint32_t barrier_type) {
std::promise<int32_t> prom; std::promise<int32_t> prom;
std::future<int32_t> fut = prom.get_future(); std::future<int32_t> fut = prom.get_future();
prom.set_value(0); prom.set_value(0);
...@@ -124,10 +116,10 @@ class PsLocalClient : public PSClient { ...@@ -124,10 +116,10 @@ class PsLocalClient : public PSClient {
return fut; return fut;
} }
virtual std::future<int32_t> pull_geo_param(size_t table_id, virtual std::future<int32_t> PullGeoParam(size_t table_id,
std::vector<float>* values, std::vector<float>* values,
std::vector<uint64_t>* keys, std::vector<uint64_t>* keys,
int pserver_idx) { int pserver_idx) {
std::promise<int32_t> prom; std::promise<int32_t> prom;
std::future<int32_t> fut = prom.get_future(); std::future<int32_t> fut = prom.get_future();
prom.set_value(0); prom.set_value(0);
...@@ -135,9 +127,9 @@ class PsLocalClient : public PSClient { ...@@ -135,9 +127,9 @@ class PsLocalClient : public PSClient {
return fut; return fut;
} }
virtual std::future<int32_t> push_global_step(int table_id, virtual std::future<int32_t> PushGlobalStep(int table_id,
int64_t* total_send_data, int64_t* total_send_data,
void* done) { void* done) {
std::promise<int32_t> prom; std::promise<int32_t> prom;
std::future<int32_t> fut = prom.get_future(); std::future<int32_t> fut = prom.get_future();
prom.set_value(0); prom.set_value(0);
...@@ -146,12 +138,12 @@ class PsLocalClient : public PSClient { ...@@ -146,12 +138,12 @@ class PsLocalClient : public PSClient {
} }
// recv table from server and save it in LodTensor // recv table from server and save it in LodTensor
virtual int32_t recv_and_save_table(const uint64_t table_id, virtual int32_t RecvAndSaveTable(const uint64_t table_id,
const std::string& path) { const std::string& path) {
return 0; return 0;
} }
virtual ::std::future<int32_t> send_client2client_msg( virtual ::std::future<int32_t> SendClient2ClientMsg(
int msg_type, int to_client_id, const std::string& msg) override { int msg_type, int to_client_id, const std::string& msg) override {
std::promise<int32_t> prom; std::promise<int32_t> prom;
std::future<int32_t> fut = prom.get_future(); std::future<int32_t> fut = prom.get_future();
...@@ -159,17 +151,18 @@ class PsLocalClient : public PSClient { ...@@ -159,17 +151,18 @@ class PsLocalClient : public PSClient {
return fut; return fut;
} }
virtual size_t get_server_nums() { return 1; } virtual size_t GetServerNums() { return 1; }
virtual std::future<int32_t> push_dense_raw_gradient( virtual std::future<int32_t> PushDenseRawGradient(int table_id,
int table_id, float* total_send_data, size_t total_send_data_size, float* total_send_data,
void* callback) override; size_t total_send_data_size,
void* callback) override;
virtual std::future<int32_t> push_sparse_raw_gradient( virtual std::future<int32_t> PushSparseRawGradient(
size_t table_id, const uint64_t* keys, const float** update_values, size_t table_id, const uint64_t* keys, const float** update_values,
size_t num, void* callback) override; size_t num, void* callback) override;
virtual std::future<int32_t> push_sparse_raw_gradient_partial( virtual std::future<int32_t> PushSparseRawGradientPartial(
size_t table_id, const uint64_t* keys, const float** update_values, size_t table_id, const uint64_t* keys, const float** update_values,
uint32_t num, void* done, int pserver_idx) override { uint32_t num, void* done, int pserver_idx) override {
std::promise<int32_t> prom; std::promise<int32_t> prom;
...@@ -179,11 +172,11 @@ class PsLocalClient : public PSClient { ...@@ -179,11 +172,11 @@ class PsLocalClient : public PSClient {
return fut; return fut;
} }
virtual std::future<int32_t> push_sparse_param(size_t table_id, virtual std::future<int32_t> PushSparseParam(size_t table_id,
const uint64_t* keys, const uint64_t* keys,
const float** update_values, const float** update_values,
size_t num, size_t num,
void* done) override { void* done) override {
std::promise<int32_t> prom; std::promise<int32_t> prom;
std::future<int32_t> fut = prom.get_future(); std::future<int32_t> fut = prom.get_future();
prom.set_value(0); prom.set_value(0);
...@@ -192,7 +185,7 @@ class PsLocalClient : public PSClient { ...@@ -192,7 +185,7 @@ class PsLocalClient : public PSClient {
} }
private: private:
virtual int32_t initialize() override; virtual int32_t Initialize() override;
std::future<int32_t> done() { std::future<int32_t> done() {
std::shared_ptr<std::promise<int32_t>> prom = std::shared_ptr<std::promise<int32_t>> prom =
...@@ -202,16 +195,16 @@ class PsLocalClient : public PSClient { ...@@ -202,16 +195,16 @@ class PsLocalClient : public PSClient {
return fut; return fut;
} }
inline uint32_t dense_dim_per_shard(uint32_t dense_dim_total, inline uint32_t DenseDimPerShard(uint32_t dense_dim_total,
uint32_t shard_num) { uint32_t shard_num) {
return dense_dim_total / shard_num + 1; return dense_dim_total / shard_num + 1;
} }
inline std::unordered_map<uint32_t, std::shared_ptr<Table>>* table() { inline std::unordered_map<uint32_t, std::shared_ptr<Table>>* GetTable() {
return &_table_map; return &_table_map;
} }
inline Table* table(size_t table_id) { inline Table* GetTable(size_t table_id) {
auto itr = _table_map.find(table_id); auto itr = _table_map.find(table_id);
if (itr != _table_map.end()) { if (itr != _table_map.end()) {
return itr->second.get(); return itr->second.get();
......
...@@ -25,17 +25,17 @@ class PsLocalServer : public PSServer { ...@@ -25,17 +25,17 @@ class PsLocalServer : public PSServer {
public: public:
PsLocalServer() {} PsLocalServer() {}
virtual ~PsLocalServer() {} virtual ~PsLocalServer() {}
virtual uint64_t start() { return 0; } virtual uint64_t Start() { return 0; }
virtual uint64_t start(const std::string &ip, uint32_t port) { return 0; } virtual uint64_t Start(const std::string &ip, uint32_t port) { return 0; }
virtual int32_t stop() { return 0; } virtual int32_t Stop() { return 0; }
virtual int32_t configure( virtual int32_t Configure(
const PSParameter &config, PSEnvironment &env, size_t server_rank, const PSParameter &config, PSEnvironment &env, size_t server_rank,
const std::vector<framework::ProgramDesc> &server_sub_program = {}) { const std::vector<framework::ProgramDesc> &server_sub_program = {}) {
return 0; return 0;
} }
private: private:
virtual int32_t initialize() { return 0; } virtual int32_t Initialize() { return 0; }
}; };
} }
} }
...@@ -70,7 +70,7 @@ void GraphPyService::set_up(std::string ips_str, int shard_num, ...@@ -70,7 +70,7 @@ void GraphPyService::set_up(std::string ips_str, int shard_num,
port_list.push_back(ip_and_port[1]); port_list.push_back(ip_and_port[1]);
uint32_t port = stoul(ip_and_port[1]); uint32_t port = stoul(ip_and_port[1]);
auto ph_host = paddle::distributed::PSHost(ip_and_port[0], port, index); auto ph_host = paddle::distributed::PSHost(ip_and_port[0], port, index);
host_sign_list.push_back(ph_host.serialize_to_string()); host_sign_list.push_back(ph_host.SerializeToString());
index++; index++;
} }
} }
...@@ -83,11 +83,11 @@ void GraphPyClient::start_client() { ...@@ -83,11 +83,11 @@ void GraphPyClient::start_client() {
paddle::distributed::PaddlePSEnvironment _ps_env; paddle::distributed::PaddlePSEnvironment _ps_env;
auto servers_ = host_sign_list.size(); auto servers_ = host_sign_list.size();
_ps_env = paddle::distributed::PaddlePSEnvironment(); _ps_env = paddle::distributed::PaddlePSEnvironment();
_ps_env.set_ps_servers(&host_sign_list, servers_); _ps_env.SetPsServers(&host_sign_list, servers_);
worker_ptr = std::shared_ptr<paddle::distributed::GraphBrpcClient>( worker_ptr = std::shared_ptr<paddle::distributed::GraphBrpcClient>(
(paddle::distributed::GraphBrpcClient*) (paddle::distributed::GraphBrpcClient*)
paddle::distributed::PSClientFactory::create(worker_proto)); paddle::distributed::PSClientFactory::Create(worker_proto));
worker_ptr->configure(worker_proto, dense_regions, _ps_env, client_id); worker_ptr->Configure(worker_proto, dense_regions, _ps_env, client_id);
worker_ptr->set_shard_num(get_shard_num()); worker_ptr->set_shard_num(get_shard_num());
} }
void GraphPyServer::start_server(bool block) { void GraphPyServer::start_server(bool block) {
...@@ -96,17 +96,17 @@ void GraphPyServer::start_server(bool block) { ...@@ -96,17 +96,17 @@ void GraphPyServer::start_server(bool block) {
::paddle::distributed::PSParameter server_proto = this->GetServerProto(); ::paddle::distributed::PSParameter server_proto = this->GetServerProto();
auto _ps_env = paddle::distributed::PaddlePSEnvironment(); auto _ps_env = paddle::distributed::PaddlePSEnvironment();
_ps_env.set_ps_servers(&this->host_sign_list, _ps_env.SetPsServers(&this->host_sign_list,
this->host_sign_list.size()); // test this->host_sign_list.size()); // test
pserver_ptr = std::shared_ptr<paddle::distributed::GraphBrpcServer>( pserver_ptr = std::shared_ptr<paddle::distributed::GraphBrpcServer>(
(paddle::distributed::GraphBrpcServer*) (paddle::distributed::GraphBrpcServer*)
paddle::distributed::PSServerFactory::create(server_proto)); paddle::distributed::PSServerFactory::Create(server_proto));
VLOG(0) << "pserver-ptr created "; VLOG(0) << "pserver-ptr created ";
std::vector<framework::ProgramDesc> empty_vec; std::vector<framework::ProgramDesc> empty_vec;
framework::ProgramDesc empty_prog; framework::ProgramDesc empty_prog;
empty_vec.push_back(empty_prog); empty_vec.push_back(empty_prog);
pserver_ptr->configure(server_proto, _ps_env, rank, empty_vec); pserver_ptr->Configure(server_proto, _ps_env, rank, empty_vec);
pserver_ptr->start(ip, port); pserver_ptr->Start(ip, port);
pserver_ptr->build_peer2peer_connection(rank); pserver_ptr->build_peer2peer_connection(rank);
std::condition_variable* cv_ = pserver_ptr->export_cv(); std::condition_variable* cv_ = pserver_ptr->export_cv();
if (block) { if (block) {
...@@ -246,7 +246,7 @@ void GraphPyClient::load_edge_file(std::string name, std::string filepath, ...@@ -246,7 +246,7 @@ void GraphPyClient::load_edge_file(std::string name, std::string filepath,
VLOG(0) << "loadding data with type " << name << " from " << filepath; VLOG(0) << "loadding data with type " << name << " from " << filepath;
uint32_t table_id = this->table_id_map[name]; uint32_t table_id = this->table_id_map[name];
auto status = auto status =
get_ps_client()->load(table_id, std::string(filepath), params); get_ps_client()->Load(table_id, std::string(filepath), params);
status.wait(); status.wait();
} }
} }
...@@ -285,7 +285,7 @@ void GraphPyClient::load_node_file(std::string name, std::string filepath) { ...@@ -285,7 +285,7 @@ void GraphPyClient::load_node_file(std::string name, std::string filepath) {
if (this->table_id_map.count(name)) { if (this->table_id_map.count(name)) {
uint32_t table_id = this->table_id_map[name]; uint32_t table_id = this->table_id_map[name];
auto status = auto status =
get_ps_client()->load(table_id, std::string(filepath), params); get_ps_client()->Load(table_id, std::string(filepath), params);
status.wait(); status.wait();
} }
} }
...@@ -396,13 +396,13 @@ std::vector<FeatureNode> GraphPyClient::pull_graph_list(std::string name, ...@@ -396,13 +396,13 @@ std::vector<FeatureNode> GraphPyClient::pull_graph_list(std::string name,
return res; return res;
} }
void GraphPyClient::stop_server() { void GraphPyClient::StopServer() {
VLOG(0) << "going to stop server"; VLOG(0) << "going to stop server";
std::unique_lock<std::mutex> lock(mutex_); std::unique_lock<std::mutex> lock(mutex_);
if (stoped_) return; if (stoped_) return;
auto status = this->worker_ptr->stop_server(); auto status = this->worker_ptr->StopServer();
if (status.get() == 0) stoped_ = true; if (status.get() == 0) stoped_ = true;
} }
void GraphPyClient::finalize_worker() { this->worker_ptr->finalize_worker(); } void GraphPyClient::FinalizeWorker() { this->worker_ptr->FinalizeWorker(); }
} }
} }
...@@ -123,7 +123,7 @@ class GraphPyServer : public GraphPyService { ...@@ -123,7 +123,7 @@ class GraphPyServer : public GraphPyService {
set_rank(rank); set_rank(rank);
GraphPyService::set_up(ips_str, shard_num, node_types, edge_types); GraphPyService::set_up(ips_str, shard_num, node_types, edge_types);
} }
int get_rank() { return rank; } int GetRank() { return rank; }
void set_rank(int rank) { this->rank = rank; } void set_rank(int rank) { this->rank = rank; }
void start_server(bool block = true); void start_server(bool block = true);
...@@ -154,8 +154,8 @@ class GraphPyClient : public GraphPyService { ...@@ -154,8 +154,8 @@ class GraphPyClient : public GraphPyService {
(paddle::distributed::GraphBrpcService*)server.get_ps_server() (paddle::distributed::GraphBrpcService*)server.get_ps_server()
->get_service()); ->get_service());
} }
void stop_server(); void StopServer();
void finalize_worker(); void FinalizeWorker();
void load_edge_file(std::string name, std::string filepath, bool reverse); void load_edge_file(std::string name, std::string filepath, bool reverse);
void load_node_file(std::string name, std::string filepath); void load_node_file(std::string name, std::string filepath);
void clear_nodes(std::string name); void clear_nodes(std::string name);
......
...@@ -46,7 +46,7 @@ paddle::distributed::PSParameter load_from_prototxt( ...@@ -46,7 +46,7 @@ paddle::distributed::PSParameter load_from_prototxt(
return param; return param;
} }
void PSCore::init_gflag(const std::string& gflags) { void PSCore::InitGFlag(const std::string& gflags) {
VLOG(3) << "Init With Gflags:" << gflags; VLOG(3) << "Init With Gflags:" << gflags;
std::vector<std::string> flags = paddle::string::split_string(gflags); std::vector<std::string> flags = paddle::string::split_string(gflags);
if (flags.size() < 1) { if (flags.size() < 1) {
...@@ -65,67 +65,67 @@ void PSCore::init_gflag(const std::string& gflags) { ...@@ -65,67 +65,67 @@ void PSCore::init_gflag(const std::string& gflags) {
::GFLAGS_NAMESPACE::ParseCommandLineFlags(&params_cnt, &params_ptr, true); ::GFLAGS_NAMESPACE::ParseCommandLineFlags(&params_cnt, &params_ptr, true);
} }
int PSCore::init_server( int PSCore::InitServer(
const std::string& dist_desc, const std::string& dist_desc,
const std::vector<std::string>* host_sign_list, int node_num, int index, const std::vector<std::string>* host_sign_list, int node_num, int index,
int trainers, int trainers,
const std::vector<framework::ProgramDesc>& server_sub_program) { const std::vector<framework::ProgramDesc>& server_sub_program) {
google::protobuf::TextFormat::ParseFromString(dist_desc, &_ps_param); google::protobuf::TextFormat::ParseFromString(dist_desc, &_ps_param);
init_gflag(_ps_param.init_gflags()); InitGFlag(_ps_param.init_gflags());
_ps_env = paddle::distributed::PaddlePSEnvironment(); _ps_env = paddle::distributed::PaddlePSEnvironment();
_ps_env.set_ps_servers(host_sign_list, node_num); _ps_env.SetPsServers(host_sign_list, node_num);
_ps_env.set_trainers(trainers); _ps_env.SetTrainers(trainers);
int ret = 0; int ret = 0;
_server_ptr = std::shared_ptr<paddle::distributed::PSServer>( _server_ptr = std::shared_ptr<paddle::distributed::PSServer>(
paddle::distributed::PSServerFactory::create(_ps_param)); paddle::distributed::PSServerFactory::Create(_ps_param));
ret = _server_ptr->configure(_ps_param, _ps_env, index, server_sub_program); ret = _server_ptr->Configure(_ps_param, _ps_env, index, server_sub_program);
CHECK(ret == 0) << "failed to configure server"; CHECK(ret == 0) << "failed to configure server";
return ret; return ret;
} }
int PSCore::init_worker( int PSCore::InitWorker(
const std::string& dist_desc, const std::string& dist_desc,
const std::map<uint64_t, std::vector<paddle::distributed::Region>>& regions, const std::map<uint64_t, std::vector<paddle::distributed::Region>>& regions,
const std::vector<std::string>* host_sign_list, int node_num, int index) { const std::vector<std::string>* host_sign_list, int node_num, int index) {
google::protobuf::TextFormat::ParseFromString(dist_desc, &_ps_param); google::protobuf::TextFormat::ParseFromString(dist_desc, &_ps_param);
init_gflag(_ps_param.init_gflags()); InitGFlag(_ps_param.init_gflags());
_ps_env = paddle::distributed::PaddlePSEnvironment(); _ps_env = paddle::distributed::PaddlePSEnvironment();
_ps_env.set_ps_servers(host_sign_list, node_num); _ps_env.SetPsServers(host_sign_list, node_num);
int ret = 0; int ret = 0;
VLOG(1) << "PSCore::init_worker"; VLOG(1) << "PSCore::InitWorker";
auto* communicator = Communicator::GetInstance(); auto* communicator = Communicator::GetInstance();
ret = communicator->GetPsClient()->configure(_ps_param, regions, _ps_env, ret = communicator->GetPsClient()->Configure(_ps_param, regions, _ps_env,
index); index);
communicator->Start(); communicator->Start();
return ret; return ret;
} }
std::vector<uint64_t> PSCore::get_client_info() { std::vector<uint64_t> PSCore::GetClientInfo() {
return _ps_env.get_client_info(); return _ps_env.GetClientInfo();
} }
int PSCore::create_client2client_connection(int pserver_timeout_ms, int PSCore::CreateClient2ClientConnection(int pserver_timeout_ms,
int pserver_connect_timeout_ms, int pserver_connect_timeout_ms,
int max_retry) { int max_retry) {
int ret = _worker_ptr->create_client2client_connection( int ret = _worker_ptr->CreateClient2ClientConnection(
pserver_timeout_ms, pserver_connect_timeout_ms, max_retry); pserver_timeout_ms, pserver_connect_timeout_ms, max_retry);
return ret; return ret;
} }
uint64_t PSCore::run_server(const std::string& ip, uint32_t port) { uint64_t PSCore::RunServer(const std::string& ip, uint32_t port) {
return _server_ptr->start(ip, port); return _server_ptr->Start(ip, port);
} }
int PSCore::finalize_worker() { int PSCore::FinalizeWorker() {
_worker_ptr->finalize_worker(); _worker_ptr->FinalizeWorker();
return 0; return 0;
} }
int PSCore::stop_server() { int PSCore::StopServer() {
auto stop_status = _worker_ptr->stop_server(); auto stop_status = _worker_ptr->StopServer();
stop_status.wait(); stop_status.wait();
return 0; return 0;
} }
paddle::distributed::PSParameter* PSCore::get_param() { return &_ps_param; } paddle::distributed::PSParameter* PSCore::GetParam() { return &_ps_param; }
} // namespace distributed } // namespace distributed
} // namespace paddle } // namespace paddle
...@@ -42,31 +42,31 @@ class PSCore { ...@@ -42,31 +42,31 @@ class PSCore {
explicit PSCore() {} explicit PSCore() {}
virtual ~PSCore() {} virtual ~PSCore() {}
virtual int init_server( virtual int InitServer(
const std::string& dist_desc, const std::string& dist_desc,
const std::vector<std::string>* host_sign_list, int node_num, int index, const std::vector<std::string>* host_sign_list, int node_num, int index,
int trainers, int trainers,
const std::vector<framework::ProgramDesc>& server_sub_program = {}); const std::vector<framework::ProgramDesc>& server_sub_program = {});
virtual int init_worker( virtual int InitWorker(
const std::string& dist_desc, const std::string& dist_desc,
const std::map<uint64_t, std::vector<paddle::distributed::Region>>& const std::map<uint64_t, std::vector<paddle::distributed::Region>>&
regions, regions,
const std::vector<std::string>* host_sign_list, int node_num, int index); const std::vector<std::string>* host_sign_list, int node_num, int index);
virtual uint64_t run_server(const std::string& ip, uint32_t port); virtual uint64_t RunServer(const std::string& ip, uint32_t port);
virtual int stop_server(); virtual int StopServer();
virtual int finalize_worker(); virtual int FinalizeWorker();
virtual std::vector<uint64_t> get_client_info(); virtual std::vector<uint64_t> GetClientInfo();
virtual int create_client2client_connection(int pserver_timeout_ms, virtual int CreateClient2ClientConnection(int pserver_timeout_ms,
int pserver_connect_timeout_ms, int pserver_connect_timeout_ms,
int max_retry); int max_retry);
std::shared_ptr<paddle::distributed::PSServer> std::shared_ptr<paddle::distributed::PSServer>
_server_ptr; // pointer to server _server_ptr; // pointer to server
std::shared_ptr<paddle::distributed::PSClient> std::shared_ptr<paddle::distributed::PSClient>
_worker_ptr; // pointer to worker _worker_ptr; // pointer to worker
virtual paddle::distributed::PSParameter* get_param(); virtual paddle::distributed::PSParameter* GetParam();
private: private:
void init_gflag(const std::string& gflags); void InitGFlag(const std::string& gflags);
paddle::distributed::PSParameter _ps_param; paddle::distributed::PSParameter _ps_param;
paddle::distributed::PaddlePSEnvironment _ps_env; paddle::distributed::PaddlePSEnvironment _ps_env;
}; };
......
...@@ -29,7 +29,7 @@ REGISTER_PSCORE_CLASS(PsBaseService, BrpcPsService); ...@@ -29,7 +29,7 @@ REGISTER_PSCORE_CLASS(PsBaseService, BrpcPsService);
REGISTER_PSCORE_CLASS(PSServer, GraphBrpcServer); REGISTER_PSCORE_CLASS(PSServer, GraphBrpcServer);
REGISTER_PSCORE_CLASS(PsBaseService, GraphBrpcService); REGISTER_PSCORE_CLASS(PsBaseService, GraphBrpcService);
PSServer *PSServerFactory::create(const PSParameter &ps_config) { PSServer *PSServerFactory::Create(const PSParameter &ps_config) {
const auto &config = ps_config.server_param(); const auto &config = ps_config.server_param();
if (!config.has_downpour_server_param()) { if (!config.has_downpour_server_param()) {
...@@ -56,18 +56,18 @@ PSServer *PSServerFactory::create(const PSParameter &ps_config) { ...@@ -56,18 +56,18 @@ PSServer *PSServerFactory::create(const PSParameter &ps_config) {
<< service_param.server_class(); << service_param.server_class();
return NULL; return NULL;
} }
TableManager::instance().initialize(); TableManager::Instance().Initialize();
return server; return server;
} }
int32_t PSServer::configure( int32_t PSServer::Configure(
const PSParameter &config, PSEnvironment &env, size_t server_rank, const PSParameter &config, PSEnvironment &env, size_t server_rank,
const std::vector<framework::ProgramDesc> &server_sub_program) { const std::vector<framework::ProgramDesc> &server_sub_program) {
scope_.reset(new framework::Scope()); scope_.reset(new framework::Scope());
_config = config.server_param(); _config = config.server_param();
_rank = server_rank; _rank = server_rank;
_environment = &env; _environment = &env;
size_t shard_num = env.get_ps_servers().size(); size_t shard_num = env.GetPsServers().size();
const auto &downpour_param = _config.downpour_server_param(); const auto &downpour_param = _config.downpour_server_param();
...@@ -87,21 +87,21 @@ int32_t PSServer::configure( ...@@ -87,21 +87,21 @@ int32_t PSServer::configure(
global_step_table = downpour_param.downpour_table_param(i).table_id(); global_step_table = downpour_param.downpour_table_param(i).table_id();
} }
table->set_program_env(scope_.get(), place_, &server_sub_program); table->SetProgramEnv(scope_.get(), place_, &server_sub_program);
table->set_shard(_rank, shard_num); table->SetShard(_rank, shard_num);
table->initialize(downpour_param.downpour_table_param(i), table->Initialize(downpour_param.downpour_table_param(i),
config.fs_client_param()); config.fs_client_param());
_table_map[downpour_param.downpour_table_param(i).table_id()].reset(table); _table_map[downpour_param.downpour_table_param(i).table_id()].reset(table);
} }
if (barrier_table != UINT32_MAX) { if (barrier_table != UINT32_MAX) {
_table_map[barrier_table]->set_table_map(&_table_map); _table_map[barrier_table]->SetTableMap(&_table_map);
} }
if (global_step_table != UINT32_MAX) { if (global_step_table != UINT32_MAX) {
_table_map[global_step_table]->set_table_map(&_table_map); _table_map[global_step_table]->SetTableMap(&_table_map);
} }
return initialize(); return Initialize();
} }
} // namespace distributed } // namespace distributed
} // namespace paddle } // namespace paddle
...@@ -65,19 +65,19 @@ class PSServer { ...@@ -65,19 +65,19 @@ class PSServer {
PSServer(PSServer &&) = delete; PSServer(PSServer &&) = delete;
PSServer(const PSServer &) = delete; PSServer(const PSServer &) = delete;
virtual int32_t configure( virtual int32_t Configure(
const PSParameter &config, PSEnvironment &env, size_t server_rank, const PSParameter &config, PSEnvironment &env, size_t server_rank,
const std::vector<framework::ProgramDesc> &server_sub_program = {}); const std::vector<framework::ProgramDesc> &server_sub_program = {});
virtual uint64_t start(const std::string &ip, uint32_t port) = 0; virtual uint64_t Start(const std::string &ip, uint32_t port) = 0;
virtual int32_t stop() = 0; virtual int32_t Stop() = 0;
inline size_t rank() const { return _rank; } inline size_t Rank() const { return _rank; }
inline PSEnvironment *environment() { return _environment; } inline PSEnvironment *Environment() { return _environment; }
inline const ServerParameter *config() const { return &_config; } inline const ServerParameter *Config() const { return &_config; }
inline Table *table(size_t table_id) { inline Table *GetTable(size_t table_id) {
auto itr = _table_map.find(table_id); auto itr = _table_map.find(table_id);
if (itr != _table_map.end()) { if (itr != _table_map.end()) {
return itr->second.get(); return itr->second.get();
...@@ -85,12 +85,12 @@ class PSServer { ...@@ -85,12 +85,12 @@ class PSServer {
return NULL; return NULL;
} }
inline std::unordered_map<uint32_t, std::shared_ptr<Table>> *table() { inline std::unordered_map<uint32_t, std::shared_ptr<Table>> *GetTable() {
return &_table_map; return &_table_map;
} }
protected: protected:
virtual int32_t initialize() = 0; virtual int32_t Initialize() = 0;
protected: protected:
size_t _rank; size_t _rank;
...@@ -129,11 +129,11 @@ class PsBaseService : public PsService { ...@@ -129,11 +129,11 @@ class PsBaseService : public PsService {
public: public:
PsBaseService() : _rank(0), _server(NULL), _config(NULL) {} PsBaseService() : _rank(0), _server(NULL), _config(NULL) {}
virtual ~PsBaseService() {} virtual ~PsBaseService() {}
virtual size_t get_rank() { return _rank; } virtual size_t GetRank() { return _rank; }
virtual int32_t configure(PSServer *server) { virtual int32_t Configure(PSServer *server) {
_server = server; _server = server;
_rank = _server->rank(); _rank = _server->Rank();
_config = _server->config(); _config = _server->Config();
return 0; return 0;
} }
virtual void service(::google::protobuf::RpcController *controller, virtual void service(::google::protobuf::RpcController *controller,
...@@ -148,8 +148,8 @@ class PsBaseService : public PsService { ...@@ -148,8 +148,8 @@ class PsBaseService : public PsService {
LOG(WARNING) << "Resonse err_code:" << err_code << " msg:" << err_msg; LOG(WARNING) << "Resonse err_code:" << err_code << " msg:" << err_msg;
} }
virtual int32_t initialize() = 0; virtual int32_t Initialize() = 0;
PSServer *get_server() { return _server; } PSServer *GetServer() { return _server; }
protected: protected:
size_t _rank; size_t _rank;
...@@ -160,7 +160,7 @@ REGISTER_PSCORE_REGISTERER(PsBaseService); ...@@ -160,7 +160,7 @@ REGISTER_PSCORE_REGISTERER(PsBaseService);
class PSServerFactory { class PSServerFactory {
public: public:
static PSServer *create(const PSParameter &config); static PSServer *Create(const PSParameter &config);
}; };
} // namespace distributed } // namespace distributed
} // namespace paddle } // namespace paddle
...@@ -17,7 +17,7 @@ ...@@ -17,7 +17,7 @@
namespace paddle { namespace paddle {
namespace distributed { namespace distributed {
int32_t BarrierTable::initialize() { int32_t BarrierTable::Initialize() {
auto trainers = _config.common().trainer_num(); auto trainers = _config.common().trainer_num();
trigger_.store(trainers); trigger_.store(trainers);
...@@ -29,7 +29,7 @@ int32_t BarrierTable::initialize() { ...@@ -29,7 +29,7 @@ int32_t BarrierTable::initialize() {
} }
// 0: send_barrier 1: recv_barrier 2: complete // 0: send_barrier 1: recv_barrier 2: complete
int32_t BarrierTable::barrier(const uint32_t trainer_id, int32_t BarrierTable::Barrier(const uint32_t trainer_id,
const std::string barrier_type) { const std::string barrier_type) {
std::unique_lock<std::mutex> lock(mutex_); std::unique_lock<std::mutex> lock(mutex_);
...@@ -56,7 +56,7 @@ int32_t BarrierTable::barrier(const uint32_t trainer_id, ...@@ -56,7 +56,7 @@ int32_t BarrierTable::barrier(const uint32_t trainer_id,
VLOG(1) << "barrier table optimize begin"; VLOG(1) << "barrier table optimize begin";
for (auto& x : *table_map_) { for (auto& x : *table_map_) {
auto table = x.second; auto table = x.second;
table->pour(); table->Pour();
} }
VLOG(1) << "barrier table optimize done"; VLOG(1) << "barrier table optimize done";
...@@ -66,7 +66,7 @@ int32_t BarrierTable::barrier(const uint32_t trainer_id, ...@@ -66,7 +66,7 @@ int32_t BarrierTable::barrier(const uint32_t trainer_id,
return 0; return 0;
} }
int32_t BarrierTable::set_table_map( int32_t BarrierTable::SetTableMap(
std::unordered_map<uint32_t, std::shared_ptr<Table>>* table_map) { std::unordered_map<uint32_t, std::shared_ptr<Table>>* table_map) {
table_map_ = table_map; table_map_ = table_map;
return 0; return 0;
......
...@@ -21,8 +21,8 @@ namespace distributed { ...@@ -21,8 +21,8 @@ namespace distributed {
int FLAGS_pslib_table_save_max_retry_dense = 3; int FLAGS_pslib_table_save_max_retry_dense = 3;
void CommonDenseTable::create_initializer(const std::string& attr, void CommonDenseTable::CreateInitializer(const std::string& attr,
const std::string& name) { const std::string& name) {
auto slices = string::split_string<std::string>(attr, "&"); auto slices = string::split_string<std::string>(attr, "&");
if (slices[0] == "gaussian_random") { if (slices[0] == "gaussian_random") {
...@@ -39,7 +39,7 @@ void CommonDenseTable::create_initializer(const std::string& attr, ...@@ -39,7 +39,7 @@ void CommonDenseTable::create_initializer(const std::string& attr,
} }
} }
int32_t CommonDenseTable::initialize() { int32_t CommonDenseTable::Initialize() {
_shards_task_pool.resize(task_pool_size_); _shards_task_pool.resize(task_pool_size_);
for (int i = 0; i < _shards_task_pool.size(); ++i) { for (int i = 0; i < _shards_task_pool.size(); ++i) {
_shards_task_pool[i].reset(new ::ThreadPool(1)); _shards_task_pool[i].reset(new ::ThreadPool(1));
...@@ -49,12 +49,12 @@ int32_t CommonDenseTable::initialize() { ...@@ -49,12 +49,12 @@ int32_t CommonDenseTable::initialize() {
VLOG(1) << "table " << _config.common().table_name() << " is sync: " << sync; VLOG(1) << "table " << _config.common().table_name() << " is sync: " << sync;
_global_lr = new float(1.0); _global_lr = new float(1.0);
initialize_value(); InitializeValue();
initialize_optimizer(); InitializeOptimizer();
return 0; return 0;
} }
int32_t CommonDenseTable::initialize_value() { int32_t CommonDenseTable::InitializeValue() {
auto common = _config.common(); auto common = _config.common();
int size = static_cast<int>(common.params().size()); int size = static_cast<int>(common.params().size());
values_.resize(size); values_.resize(size);
...@@ -70,7 +70,7 @@ int32_t CommonDenseTable::initialize_value() { ...@@ -70,7 +70,7 @@ int32_t CommonDenseTable::initialize_value() {
auto& initializer = common.initializers()[x]; auto& initializer = common.initializers()[x];
total_dim_ += dim; total_dim_ += dim;
create_initializer(initializer, varname); CreateInitializer(initializer, varname);
values_[x].resize(dim); values_[x].resize(dim);
names_index_[varname] = x; names_index_[varname] = x;
...@@ -92,27 +92,27 @@ int32_t CommonDenseTable::initialize_value() { ...@@ -92,27 +92,27 @@ int32_t CommonDenseTable::initialize_value() {
param_col_ids_.insert(param_col_ids_.begin() + 1, -1); param_col_ids_.insert(param_col_ids_.begin() + 1, -1);
} }
VLOG(1) << "CommonDenseTable::initialize_value total dim: " << total_dim_ VLOG(1) << "CommonDenseTable::InitializeValue total dim: " << total_dim_
<< " fixed_len_params_dim: " << fixed_len_params_dim_; << " fixed_len_params_dim: " << fixed_len_params_dim_;
pull_reservoir_ = ReservoirValue<float>(param_dim_); pull_reservoir_ = ReservoirValue<float>(param_dim_);
return 0; return 0;
} }
int32_t CommonDenseTable::initialize_optimizer() { int32_t CommonDenseTable::InitializeOptimizer() {
auto common = _config.common(); auto common = _config.common();
auto name = common.name(); auto name = common.name();
auto attrs = common.attributes(); auto attrs = common.attributes();
if (name == "sgd") { if (name == "sgd") {
optimizer_ = std::make_shared<DSGD>(common, &values_); optimizer_ = std::make_shared<DSGD>(common, &values_);
optimizer_->set_global_lr(_global_lr); optimizer_->SetGlobalLR(_global_lr);
} else if (name == "adam") { } else if (name == "adam") {
optimizer_ = std::make_shared<DAdam>(common, &values_); optimizer_ = std::make_shared<DAdam>(common, &values_);
optimizer_->set_global_lr(_global_lr); optimizer_->SetGlobalLR(_global_lr);
} else if (name == "adam_d2sum") { } else if (name == "adam_d2sum") {
optimizer_ = std::make_shared<DAdamD2Sum>(common, &values_); optimizer_ = std::make_shared<DAdamD2Sum>(common, &values_);
// optimizer_->set_global_lr(_global_lr); //no use // optimizer_->SetGlobalLR(_global_lr); //no use
} else if (name == "sum") { } else if (name == "sum") {
optimizer_ = std::make_shared<DSUM>(common, &values_); optimizer_ = std::make_shared<DSUM>(common, &values_);
} else if (name == "summary") { } else if (name == "summary") {
...@@ -124,34 +124,34 @@ int32_t CommonDenseTable::initialize_optimizer() { ...@@ -124,34 +124,34 @@ int32_t CommonDenseTable::initialize_optimizer() {
return 0; return 0;
} }
int32_t CommonDenseTable::set_global_lr(float* lr) { int32_t CommonDenseTable::SetGlobalLR(float* lr) {
_global_lr = lr; _global_lr = lr;
optimizer_->set_global_lr(_global_lr); optimizer_->SetGlobalLR(_global_lr);
return 0; return 0;
} }
int32_t CommonDenseTable::Pull(TableContext& context) { int32_t CommonDenseTable::Pull(TableContext& context) {
CHECK(context.value_type == Dense); CHECK(context.value_type == Dense);
float* pull_values = context.pull_context.values; float* pull_values = context.pull_context.values;
return pull_dense(pull_values, context.num); return PullDense(pull_values, context.num);
} }
int32_t CommonDenseTable::Push(TableContext& context) { int32_t CommonDenseTable::Push(TableContext& context) {
CHECK(context.value_type == Dense); CHECK(context.value_type == Dense);
if (context.push_context.values != nullptr) { if (context.push_context.values != nullptr) {
const float* values = context.push_context.values; const float* values = context.push_context.values;
return push_dense(values, context.num); return PushDense(values, context.num);
} }
return 0; return 0;
} }
int32_t CommonDenseTable::pull_dense(float* pull_values, size_t num) { int32_t CommonDenseTable::PullDense(float* pull_values, size_t num) {
std::copy(values_[param_idx_].begin(), values_[param_idx_].end(), std::copy(values_[param_idx_].begin(), values_[param_idx_].end(),
pull_values); pull_values);
return 0; return 0;
} }
int32_t CommonDenseTable::push_dense_param(const float* values, size_t num) { int32_t CommonDenseTable::PushDenseParam(const float* values, size_t num) {
PADDLE_ENFORCE_GE( PADDLE_ENFORCE_GE(
num, param_dim_, num, param_dim_,
paddle::platform::errors::InvalidArgument( paddle::platform::errors::InvalidArgument(
...@@ -160,14 +160,14 @@ int32_t CommonDenseTable::push_dense_param(const float* values, size_t num) { ...@@ -160,14 +160,14 @@ int32_t CommonDenseTable::push_dense_param(const float* values, size_t num) {
return 0; return 0;
} }
int32_t CommonDenseTable::pour() { int32_t CommonDenseTable::Pour() {
pull_reservoir_.avg(); pull_reservoir_.avg();
_push_dense(pull_reservoir_.values.data(), pull_reservoir_.values.size()); _PushDense(pull_reservoir_.values.data(), pull_reservoir_.values.size());
pull_reservoir_.reset(); pull_reservoir_.reset();
return 0; return 0;
} }
int32_t CommonDenseTable::push_dense(const float* values, size_t num) { int32_t CommonDenseTable::PushDense(const float* values, size_t num) {
if (sync) { if (sync) {
std::future<int> task = std::future<int> task =
_shards_task_pool[0]->enqueue([this, &values]() -> int { _shards_task_pool[0]->enqueue([this, &values]() -> int {
...@@ -176,12 +176,12 @@ int32_t CommonDenseTable::push_dense(const float* values, size_t num) { ...@@ -176,12 +176,12 @@ int32_t CommonDenseTable::push_dense(const float* values, size_t num) {
}); });
task.wait(); task.wait();
} else { } else {
_push_dense(values, num); _PushDense(values, num);
} }
return 0; return 0;
} }
int32_t CommonDenseTable::_push_dense(const float* values, size_t num) { int32_t CommonDenseTable::_PushDense(const float* values, size_t num) {
PADDLE_ENFORCE_GE( PADDLE_ENFORCE_GE(
num, param_dim_, num, param_dim_,
paddle::platform::errors::InvalidArgument( paddle::platform::errors::InvalidArgument(
...@@ -195,7 +195,7 @@ int32_t CommonDenseTable::_push_dense(const float* values, size_t num) { ...@@ -195,7 +195,7 @@ int32_t CommonDenseTable::_push_dense(const float* values, size_t num) {
[this, shard_id, &buckets, &values]() -> int { [this, shard_id, &buckets, &values]() -> int {
auto begin = buckets[shard_id]; auto begin = buckets[shard_id];
auto end = buckets[shard_id + 1]; auto end = buckets[shard_id + 1];
optimizer_->update(values, param_dim_, begin, end); optimizer_->Update(values, param_dim_, begin, end);
return 0; return 0;
}); });
} }
...@@ -207,12 +207,12 @@ int32_t CommonDenseTable::_push_dense(const float* values, size_t num) { ...@@ -207,12 +207,12 @@ int32_t CommonDenseTable::_push_dense(const float* values, size_t num) {
return 0; return 0;
} }
int32_t CommonDenseTable::load(const std::string& path, int32_t CommonDenseTable::Load(const std::string& path,
const std::string& param) { const std::string& param) {
if (param_dim_ <= 0) { if (param_dim_ <= 0) {
return 0; return 0;
} }
std::string table_path = table_dir(path); std::string table_path = TableDir(path);
auto file_list = _afs_client.list(table_path); auto file_list = _afs_client.list(table_path);
std::sort(file_list.begin(), file_list.end()); std::sort(file_list.begin(), file_list.end());
for (auto ff : file_list) { for (auto ff : file_list) {
...@@ -314,7 +314,7 @@ int32_t CommonDenseTable::load(const std::string& path, ...@@ -314,7 +314,7 @@ int32_t CommonDenseTable::load(const std::string& path,
return 0; return 0;
} }
int32_t CommonDenseTable::save(const std::string& path, int32_t CommonDenseTable::Save(const std::string& path,
const std::string& param) { const std::string& param) {
int save_param = atoi(param.c_str()); int save_param = atoi(param.c_str());
uint32_t feasign_size; uint32_t feasign_size;
...@@ -323,10 +323,10 @@ int32_t CommonDenseTable::save(const std::string& path, ...@@ -323,10 +323,10 @@ int32_t CommonDenseTable::save(const std::string& path,
FsChannelConfig channel_config; FsChannelConfig channel_config;
if (_config.compress_in_save()) { if (_config.compress_in_save()) {
channel_config.path = paddle::string::format_string( channel_config.path = paddle::string::format_string(
"%s/part-%03d.gz", table_dir(path).c_str(), _shard_idx); "%s/part-%03d.gz", TableDir(path).c_str(), _shard_idx);
} else { } else {
channel_config.path = paddle::string::format_string( channel_config.path = paddle::string::format_string(
"%s/part-%03d", table_dir(path).c_str(), _shard_idx); "%s/part-%03d", TableDir(path).c_str(), _shard_idx);
} }
_afs_client.remove(channel_config.path); _afs_client.remove(channel_config.path);
channel_config.converter = _value_accesor->Converter(save_param).converter; channel_config.converter = _value_accesor->Converter(save_param).converter;
......
...@@ -34,29 +34,29 @@ class CommonDenseTable : public DenseTable { ...@@ -34,29 +34,29 @@ class CommonDenseTable : public DenseTable {
public: public:
CommonDenseTable() {} CommonDenseTable() {}
virtual ~CommonDenseTable() {} virtual ~CommonDenseTable() {}
int32_t initialize() override; int32_t Initialize() override;
int32_t initialize_shard() override { return 0; } int32_t InitializeShard() override { return 0; }
virtual void create_initializer(const std::string& attr, virtual void CreateInitializer(const std::string& attr,
const std::string& name); const std::string& name);
virtual int32_t initialize_value(); virtual int32_t InitializeValue();
virtual int32_t initialize_optimizer(); virtual int32_t InitializeOptimizer();
virtual int32_t Pull(TableContext& context); virtual int32_t Pull(TableContext& context);
virtual int32_t Push(TableContext& context); virtual int32_t Push(TableContext& context);
int32_t pull_dense(float* pull_values, size_t num) override; int32_t PullDense(float* pull_values, size_t num) override;
int32_t push_dense_param(const float* values, size_t num) override; int32_t PushDenseParam(const float* values, size_t num) override;
int32_t push_dense(const float* values, size_t num) override; int32_t PushDense(const float* values, size_t num) override;
int32_t pour() override; int32_t Pour() override;
int32_t set_global_lr(float* lr) override; int32_t SetGlobalLR(float* lr) override;
int32_t load(const std::string& path, const std::string& param) override; int32_t Load(const std::string& path, const std::string& param) override;
int32_t save(const std::string& path, const std::string& param) override; int32_t Save(const std::string& path, const std::string& param) override;
int32_t flush() override { return 0; } int32_t Flush() override { return 0; }
int32_t shrink(const std::string& param) override { return 0; } int32_t Shrink(const std::string& param) override { return 0; }
void clear() override { return; } void Clear() override { return; }
protected: protected:
int32_t _push_dense(const float* values, size_t num); int32_t _PushDense(const float* values, size_t num);
private: private:
const int task_pool_size_ = 10; const int task_pool_size_ = 10;
......
...@@ -448,7 +448,7 @@ int32_t GraphTable::load_graph_split_config(const std::string &path) { ...@@ -448,7 +448,7 @@ int32_t GraphTable::load_graph_split_config(const std::string &path) {
return 0; return 0;
} }
int32_t GraphTable::load(const std::string &path, const std::string &param) { int32_t GraphTable::Load(const std::string &path, const std::string &param) {
bool load_edge = (param[0] == 'e'); bool load_edge = (param[0] == 'e');
bool load_node = (param[0] == 'n'); bool load_node = (param[0] == 'n');
if (load_edge) { if (load_edge) {
...@@ -1066,11 +1066,11 @@ int32_t GraphTable::pull_graph_list(int start, int total_size, ...@@ -1066,11 +1066,11 @@ int32_t GraphTable::pull_graph_list(int start, int total_size,
int32_t GraphTable::get_server_index_by_id(int64_t id) { int32_t GraphTable::get_server_index_by_id(int64_t id) {
return id % shard_num / shard_num_per_server; return id % shard_num / shard_num_per_server;
} }
int32_t GraphTable::initialize(const TableParameter &config, int32_t GraphTable::Initialize(const TableParameter &config,
const FsClientParameter &fs_config) { const FsClientParameter &fs_config) {
LOG(INFO) << "in graphTable initialize"; LOG(INFO) << "in graphTable initialize";
_config = config; _config = config;
if (initialize_accessor() != 0) { if (InitializeAccessor() != 0) {
LOG(WARNING) << "Table accessor initialize failed"; LOG(WARNING) << "Table accessor initialize failed";
return -1; return -1;
} }
...@@ -1082,9 +1082,9 @@ int32_t GraphTable::initialize(const TableParameter &config, ...@@ -1082,9 +1082,9 @@ int32_t GraphTable::initialize(const TableParameter &config,
auto graph = config.graph_parameter(); auto graph = config.graph_parameter();
shard_num = _config.shard_num(); shard_num = _config.shard_num();
LOG(INFO) << "in graphTable initialize over"; LOG(INFO) << "in graphTable initialize over";
return initialize(graph); return Initialize(graph);
} }
int32_t GraphTable::initialize(const GraphParameter &graph) { int32_t GraphTable::Initialize(const GraphParameter &graph) {
#ifdef PADDLE_WITH_HETERPS #ifdef PADDLE_WITH_HETERPS
if (graph.gpups_mode()) { if (graph.gpups_mode()) {
gpups_mode = true; gpups_mode = true;
......
...@@ -280,7 +280,7 @@ class ScaledLRU { ...@@ -280,7 +280,7 @@ class ScaledLRU {
} }
} }
auto status = auto status =
thread_pool->enqueue([this]() -> int { return shrink(); }); thread_pool->enqueue([this]() -> int { return Shrink(); });
status.wait(); status.wait();
} }
}); });
...@@ -298,7 +298,7 @@ class ScaledLRU { ...@@ -298,7 +298,7 @@ class ScaledLRU {
LRUResponse insert(size_t index, K *keys, V *data, size_t length) { LRUResponse insert(size_t index, K *keys, V *data, size_t length) {
return lru_pool[index].insert(keys, data, length); return lru_pool[index].insert(keys, data, length);
} }
int shrink() { int Shrink() {
int node_size = 0; int node_size = 0;
for (size_t i = 0; i < lru_pool.size(); i++) { for (size_t i = 0; i < lru_pool.size(); i++) {
node_size += lru_pool[i].node_size - lru_pool[i].remove_count; node_size += lru_pool[i].node_size - lru_pool[i].remove_count;
...@@ -329,7 +329,7 @@ class ScaledLRU { ...@@ -329,7 +329,7 @@ class ScaledLRU {
if (diff != 0) { if (diff != 0) {
__sync_fetch_and_add(&global_count, diff); __sync_fetch_and_add(&global_count, diff);
if (global_count > int(1.25 * size_limit)) { if (global_count > int(1.25 * size_limit)) {
thread_pool->enqueue([this]() -> int { return shrink(); }); thread_pool->enqueue([this]() -> int { return Shrink(); });
} }
} }
} }
...@@ -430,11 +430,11 @@ class GraphTable : public SparseTable { ...@@ -430,11 +430,11 @@ class GraphTable : public SparseTable {
virtual int32_t get_nodes_ids_by_ranges( virtual int32_t get_nodes_ids_by_ranges(
std::vector<std::pair<int, int>> ranges, std::vector<int64_t> &res); std::vector<std::pair<int, int>> ranges, std::vector<int64_t> &res);
virtual int32_t initialize() { return 0; } virtual int32_t Initialize() { return 0; }
virtual int32_t initialize(const TableParameter &config, virtual int32_t Initialize(const TableParameter &config,
const FsClientParameter &fs_config); const FsClientParameter &fs_config);
virtual int32_t initialize(const GraphParameter &config); virtual int32_t Initialize(const GraphParameter &config);
int32_t load(const std::string &path, const std::string &param); int32_t Load(const std::string &path, const std::string &param);
int32_t load_graph_split_config(const std::string &path); int32_t load_graph_split_config(const std::string &path);
int32_t load_edges(const std::string &path, bool reverse); int32_t load_edges(const std::string &path, bool reverse);
...@@ -452,26 +452,25 @@ class GraphTable : public SparseTable { ...@@ -452,26 +452,25 @@ class GraphTable : public SparseTable {
virtual int32_t Pull(TableContext &context) { return 0; } virtual int32_t Pull(TableContext &context) { return 0; }
virtual int32_t Push(TableContext &context) { return 0; } virtual int32_t Push(TableContext &context) { return 0; }
virtual int32_t pull_sparse(float *values, virtual int32_t PullSparse(float *values, const PullSparseValue &pull_value) {
const PullSparseValue &pull_value) {
return 0; return 0;
} }
virtual int32_t push_sparse(const uint64_t *keys, const float *values, virtual int32_t PushSparse(const uint64_t *keys, const float *values,
size_t num) { size_t num) {
return 0; return 0;
} }
virtual int32_t clear_nodes(); virtual int32_t clear_nodes();
virtual void clear() {} virtual void Clear() {}
virtual int32_t flush() { return 0; } virtual int32_t Flush() { return 0; }
virtual int32_t shrink(const std::string &param) { return 0; } virtual int32_t Shrink(const std::string &param) { return 0; }
//指定保存路径 //指定保存路径
virtual int32_t save(const std::string &path, const std::string &converter) { virtual int32_t Save(const std::string &path, const std::string &converter) {
return 0; return 0;
} }
virtual int32_t initialize_shard() { return 0; } virtual int32_t InitializeShard() { return 0; }
virtual int32_t set_shard(size_t shard_idx, size_t server_num) { virtual int32_t SetShard(size_t shard_idx, size_t server_num) {
_shard_idx = shard_idx; _shard_idx = shard_idx;
/* /*
_shard_num is not used in graph_table, this following operation is for the _shard_num is not used in graph_table, this following operation is for the
......
...@@ -167,7 +167,7 @@ int64_t CommonSparseTable::LoadFromText( ...@@ -167,7 +167,7 @@ int64_t CommonSparseTable::LoadFromText(
return 0; return 0;
} }
int32_t CommonSparseTable::initialize() { int32_t CommonSparseTable::Initialize() {
_shards_task_pool.resize(task_pool_size_); _shards_task_pool.resize(task_pool_size_);
for (int i = 0; i < _shards_task_pool.size(); ++i) { for (int i = 0; i < _shards_task_pool.size(); ++i) {
_shards_task_pool[i].reset(new ::ThreadPool(1)); _shards_task_pool[i].reset(new ::ThreadPool(1));
...@@ -200,15 +200,15 @@ int32_t CommonSparseTable::initialize() { ...@@ -200,15 +200,15 @@ int32_t CommonSparseTable::initialize() {
offset += dim; offset += dim;
} }
initialize_value(); InitializeValue();
initialize_optimizer(); InitializeOptimizer();
initialize_recorder(); InitializeRecorder();
return 0; return 0;
} }
int32_t CommonSparseTable::initialize_recorder() { return 0; } int32_t CommonSparseTable::InitializeRecorder() { return 0; }
int32_t CommonSparseTable::initialize_value() { int32_t CommonSparseTable::InitializeValue() {
auto common = _config.common(); auto common = _config.common();
shard_values_.reserve(task_pool_size_); shard_values_.reserve(task_pool_size_);
...@@ -223,18 +223,18 @@ int32_t CommonSparseTable::initialize_value() { ...@@ -223,18 +223,18 @@ int32_t CommonSparseTable::initialize_value() {
return 0; return 0;
} }
int32_t CommonSparseTable::initialize_optimizer() { int32_t CommonSparseTable::InitializeOptimizer() {
auto common = _config.common(); auto common = _config.common();
auto name = common.name(); auto name = common.name();
if (name == "sgd") { if (name == "sgd") {
optimizer_ = std::make_shared<SSGD>(value_names_, value_dims_, optimizer_ = std::make_shared<SSGD>(value_names_, value_dims_,
value_offsets_, value_idx_); value_offsets_, value_idx_);
optimizer_->set_global_lr(_global_lr); optimizer_->SetGlobalLR(_global_lr);
} else if (name == "adam") { } else if (name == "adam") {
optimizer_ = std::make_shared<SAdam>(value_names_, value_dims_, optimizer_ = std::make_shared<SAdam>(value_names_, value_dims_,
value_offsets_, value_idx_); value_offsets_, value_idx_);
optimizer_->set_global_lr(_global_lr); optimizer_->SetGlobalLR(_global_lr);
} else if (name == "sum") { } else if (name == "sum") {
optimizer_ = std::make_shared<SSUM>(value_names_, value_dims_, optimizer_ = std::make_shared<SSUM>(value_names_, value_dims_,
value_offsets_, value_idx_); value_offsets_, value_idx_);
...@@ -246,13 +246,13 @@ int32_t CommonSparseTable::initialize_optimizer() { ...@@ -246,13 +246,13 @@ int32_t CommonSparseTable::initialize_optimizer() {
return 0; return 0;
} }
int32_t CommonSparseTable::set_global_lr(float* lr) { int32_t CommonSparseTable::SetGlobalLR(float* lr) {
_global_lr = lr; _global_lr = lr;
optimizer_->set_global_lr(_global_lr); optimizer_->SetGlobalLR(_global_lr);
return 0; return 0;
} }
int32_t CommonSparseTable::load(const std::string& dirname, int32_t CommonSparseTable::Load(const std::string& dirname,
const std::string& param) { const std::string& param) {
auto begin = GetCurrentUS(); auto begin = GetCurrentUS();
rwlock_->WRLock(); rwlock_->WRLock();
...@@ -276,7 +276,7 @@ int32_t CommonSparseTable::load(const std::string& dirname, ...@@ -276,7 +276,7 @@ int32_t CommonSparseTable::load(const std::string& dirname,
return 0; return 0;
} }
int32_t CommonSparseTable::save(const std::string& dirname, int32_t CommonSparseTable::Save(const std::string& dirname,
const std::string& param) { const std::string& param) {
auto begin = GetCurrentUS(); auto begin = GetCurrentUS();
rwlock_->WRLock(); rwlock_->WRLock();
...@@ -322,7 +322,7 @@ int32_t CommonSparseTable::save(const std::string& dirname, ...@@ -322,7 +322,7 @@ int32_t CommonSparseTable::save(const std::string& dirname,
return 0; return 0;
} }
std::pair<int64_t, int64_t> CommonSparseTable::print_table_stat() { std::pair<int64_t, int64_t> CommonSparseTable::PrintTableStat() {
int64_t feasign_size = 0; int64_t feasign_size = 0;
int64_t mf_size = 0; int64_t mf_size = 0;
...@@ -335,7 +335,7 @@ std::pair<int64_t, int64_t> CommonSparseTable::print_table_stat() { ...@@ -335,7 +335,7 @@ std::pair<int64_t, int64_t> CommonSparseTable::print_table_stat() {
return {feasign_size, mf_size}; return {feasign_size, mf_size};
} }
int32_t CommonSparseTable::pour() { int32_t CommonSparseTable::Pour() {
std::vector<float> values; std::vector<float> values;
std::vector<uint64_t> keys; std::vector<uint64_t> keys;
...@@ -349,7 +349,7 @@ int32_t CommonSparseTable::pour() { ...@@ -349,7 +349,7 @@ int32_t CommonSparseTable::pour() {
std::copy(reservoir.values.begin(), reservoir.values.end(), std::copy(reservoir.values.begin(), reservoir.values.end(),
std::back_inserter(values)); std::back_inserter(values));
} }
_push_sparse(keys.data(), values.data(), pull_reservoir_.size()); _PushSparse(keys.data(), values.data(), pull_reservoir_.size());
pull_reservoir_.clear(); pull_reservoir_.clear();
return 0; return 0;
...@@ -360,11 +360,11 @@ int32_t CommonSparseTable::Pull(TableContext& context) { ...@@ -360,11 +360,11 @@ int32_t CommonSparseTable::Pull(TableContext& context) {
if (context.use_ptr) { if (context.use_ptr) {
char** pull_values = context.pull_context.ptr_values; char** pull_values = context.pull_context.ptr_values;
const uint64_t* keys = context.pull_context.keys; const uint64_t* keys = context.pull_context.keys;
return pull_sparse_ptr(pull_values, keys, context.num); return PullSparsePtr(pull_values, keys, context.num);
} else { } else {
float* pull_values = context.pull_context.values; float* pull_values = context.pull_context.values;
const PullSparseValue& pull_value = context.pull_context.pull_value; const PullSparseValue& pull_value = context.pull_context.pull_value;
return pull_sparse(pull_values, pull_value); return PullSparse(pull_values, pull_value);
} }
} }
...@@ -373,16 +373,16 @@ int32_t CommonSparseTable::Push(TableContext& context) { ...@@ -373,16 +373,16 @@ int32_t CommonSparseTable::Push(TableContext& context) {
if (context.push_context.values != nullptr) { if (context.push_context.values != nullptr) {
const float* values = context.push_context.values; const float* values = context.push_context.values;
const uint64_t* keys = context.push_context.keys; const uint64_t* keys = context.push_context.keys;
return push_sparse(keys, values, context.num); return PushSparse(keys, values, context.num);
} else { } else {
const float** values = context.push_context.ptr_values; const float** values = context.push_context.ptr_values;
const uint64_t* keys = context.push_context.keys; const uint64_t* keys = context.push_context.keys;
return push_sparse(keys, values, context.num); return PushSparse(keys, values, context.num);
} }
} }
int32_t CommonSparseTable::pull_sparse(float* pull_values, int32_t CommonSparseTable::PullSparse(float* pull_values,
const PullSparseValue& pull_value) { const PullSparseValue& pull_value) {
auto shard_num = task_pool_size_; auto shard_num = task_pool_size_;
std::vector<std::future<int>> tasks(shard_num); std::vector<std::future<int>> tasks(shard_num);
...@@ -421,8 +421,8 @@ int32_t CommonSparseTable::pull_sparse(float* pull_values, ...@@ -421,8 +421,8 @@ int32_t CommonSparseTable::pull_sparse(float* pull_values,
return 0; return 0;
} }
int32_t CommonSparseTable::pull_sparse_ptr(char** pull_values, int32_t CommonSparseTable::PullSparsePtr(char** pull_values,
const uint64_t* keys, size_t num) { const uint64_t* keys, size_t num) {
std::vector<std::vector<uint64_t>> offset_bucket; std::vector<std::vector<uint64_t>> offset_bucket;
offset_bucket.resize(task_pool_size_); offset_bucket.resize(task_pool_size_);
...@@ -458,8 +458,8 @@ int32_t CommonSparseTable::pull_sparse_ptr(char** pull_values, ...@@ -458,8 +458,8 @@ int32_t CommonSparseTable::pull_sparse_ptr(char** pull_values,
return 0; return 0;
} }
int32_t CommonSparseTable::_push_sparse(const uint64_t* keys, int32_t CommonSparseTable::_PushSparse(const uint64_t* keys,
const float* values, size_t num) { const float* values, size_t num) {
std::vector<std::vector<uint64_t>> offset_bucket; std::vector<std::vector<uint64_t>> offset_bucket;
offset_bucket.resize(task_pool_size_); offset_bucket.resize(task_pool_size_);
...@@ -474,7 +474,7 @@ int32_t CommonSparseTable::_push_sparse(const uint64_t* keys, ...@@ -474,7 +474,7 @@ int32_t CommonSparseTable::_push_sparse(const uint64_t* keys,
tasks[shard_id] = _shards_task_pool[shard_id]->enqueue( tasks[shard_id] = _shards_task_pool[shard_id]->enqueue(
[this, shard_id, &keys, &values, num, &offset_bucket]() -> int { [this, shard_id, &keys, &values, num, &offset_bucket]() -> int {
auto& offsets = offset_bucket[shard_id]; auto& offsets = offset_bucket[shard_id];
optimizer_->update(keys, values, num, offsets, optimizer_->Update(keys, values, num, offsets,
shard_values_[shard_id].get()); shard_values_[shard_id].get());
return 0; return 0;
}); });
...@@ -486,8 +486,8 @@ int32_t CommonSparseTable::_push_sparse(const uint64_t* keys, ...@@ -486,8 +486,8 @@ int32_t CommonSparseTable::_push_sparse(const uint64_t* keys,
return 0; return 0;
} }
int32_t CommonSparseTable::push_sparse(const uint64_t* keys, int32_t CommonSparseTable::PushSparse(const uint64_t* keys, const float* values,
const float* values, size_t num) { size_t num) {
if (sync) { if (sync) {
std::future<int> task = std::future<int> task =
_shards_task_pool[0]->enqueue([this, &keys, &values, num]() -> int { _shards_task_pool[0]->enqueue([this, &keys, &values, num]() -> int {
...@@ -506,20 +506,20 @@ int32_t CommonSparseTable::push_sparse(const uint64_t* keys, ...@@ -506,20 +506,20 @@ int32_t CommonSparseTable::push_sparse(const uint64_t* keys,
}); });
task.wait(); task.wait();
} else { } else {
_push_sparse(keys, values, num); _PushSparse(keys, values, num);
} }
return 0; return 0;
} }
int32_t CommonSparseTable::push_sparse(const uint64_t* keys, int32_t CommonSparseTable::PushSparse(const uint64_t* keys,
const float** values, size_t num) { const float** values, size_t num) {
_push_sparse(keys, values, num); _PushSparse(keys, values, num);
return 0; return 0;
} }
int32_t CommonSparseTable::_push_sparse(const uint64_t* keys, int32_t CommonSparseTable::_PushSparse(const uint64_t* keys,
const float** values, size_t num) { const float** values, size_t num) {
std::vector<std::vector<uint64_t>> offset_bucket; std::vector<std::vector<uint64_t>> offset_bucket;
offset_bucket.resize(task_pool_size_); offset_bucket.resize(task_pool_size_);
...@@ -536,7 +536,7 @@ int32_t CommonSparseTable::_push_sparse(const uint64_t* keys, ...@@ -536,7 +536,7 @@ int32_t CommonSparseTable::_push_sparse(const uint64_t* keys,
auto& offsets = offset_bucket[shard_id]; auto& offsets = offset_bucket[shard_id];
for (size_t i = 0; i < offsets.size(); ++i) { for (size_t i = 0; i < offsets.size(); ++i) {
std::vector<uint64_t> tmp_off = {0}; std::vector<uint64_t> tmp_off = {0};
optimizer_->update(keys + offsets[i], values[offsets[i]], num, optimizer_->Update(keys + offsets[i], values[offsets[i]], num,
tmp_off, shard_values_[shard_id].get()); tmp_off, shard_values_[shard_id].get());
} }
return 0; return 0;
...@@ -549,8 +549,8 @@ int32_t CommonSparseTable::_push_sparse(const uint64_t* keys, ...@@ -549,8 +549,8 @@ int32_t CommonSparseTable::_push_sparse(const uint64_t* keys,
return 0; return 0;
} }
int32_t CommonSparseTable::push_sparse_param(const uint64_t* keys, int32_t CommonSparseTable::PushSparseParam(const uint64_t* keys,
const float* values, size_t num) { const float* values, size_t num) {
std::vector<std::vector<uint64_t>> offset_bucket; std::vector<std::vector<uint64_t>> offset_bucket;
offset_bucket.resize(task_pool_size_); offset_bucket.resize(task_pool_size_);
...@@ -585,21 +585,21 @@ int32_t CommonSparseTable::push_sparse_param(const uint64_t* keys, ...@@ -585,21 +585,21 @@ int32_t CommonSparseTable::push_sparse_param(const uint64_t* keys,
return 0; return 0;
} }
int32_t CommonSparseTable::flush() { return 0; } int32_t CommonSparseTable::Flush() { return 0; }
// Shrinks every shard of the table.
//
// `param` is parsed with std::stoi into an integer threshold, which is then
// forwarded to Shrink(threshold) on each of the task_pool_size_ entries of
// shard_values_. Always returns 0 once the loop completes.
// NOTE: std::stoi throws (std::invalid_argument / std::out_of_range) when
// `param` is not a representable integer; nothing here catches that.
int32_t CommonSparseTable::shrink(const std::string& param) { int32_t CommonSparseTable::Shrink(const std::string& param) {
int threshold = std::stoi(param); int threshold = std::stoi(param);
VLOG(3) << "sparse table shrink: " << threshold; VLOG(3) << "sparse table Shrink: " << threshold;
for (int shard_id = 0; shard_id < task_pool_size_; ++shard_id) { for (int shard_id = 0; shard_id < task_pool_size_; ++shard_id) {
// shrink // Shrink
VLOG(4) << shard_id << " " << task_pool_size_ << " begin shrink"; VLOG(4) << shard_id << " " << task_pool_size_ << " begin Shrink";
shard_values_[shard_id]->Shrink(threshold); shard_values_[shard_id]->Shrink(threshold);
} }
return 0; return 0;
} }
// Clear is not implemented yet: it only emits the "clear coming soon" message
// at VLOG(0) and does not touch any table state.
void CommonSparseTable::clear() { VLOG(0) << "clear coming soon"; } void CommonSparseTable::Clear() { VLOG(0) << "clear coming soon"; }
} // namespace distributed } // namespace distributed
} // namespace paddle } // namespace paddle
...@@ -114,25 +114,23 @@ class CommonSparseTable : public SparseTable { ...@@ -114,25 +114,23 @@ class CommonSparseTable : public SparseTable {
virtual ~CommonSparseTable() {} virtual ~CommonSparseTable() {}
// unused method begin // unused method begin
virtual int32_t pull_dense(float* pull_values, size_t num) { return 0; } virtual int32_t PullDense(float* pull_values, size_t num) { return 0; }
virtual int32_t push_dense_param(const float* values, size_t num) { virtual int32_t PushDenseParam(const float* values, size_t num) { return 0; }
return 0; virtual int32_t PushDense(const float* values, size_t num) { return 0; }
}
virtual int32_t push_dense(const float* values, size_t num) { return 0; }
// unused method end // unused method end
virtual int32_t Pull(TableContext& context); virtual int32_t Pull(TableContext& context);
virtual int32_t Push(TableContext& context); virtual int32_t Push(TableContext& context);
virtual int32_t initialize(); virtual int32_t Initialize();
virtual int32_t initialize_shard() { return 0; } virtual int32_t InitializeShard() { return 0; }
virtual int32_t initialize_value(); virtual int32_t InitializeValue();
virtual int32_t initialize_optimizer(); virtual int32_t InitializeOptimizer();
virtual int32_t initialize_recorder(); virtual int32_t InitializeRecorder();
virtual int32_t load(const std::string& path, const std::string& param); virtual int32_t Load(const std::string& path, const std::string& param);
virtual int32_t save(const std::string& path, const std::string& param); virtual int32_t Save(const std::string& path, const std::string& param);
void SaveMetaToText(std::ostream* os, const CommonAccessorParameter& common, void SaveMetaToText(std::ostream* os, const CommonAccessorParameter& common,
const size_t shard_idx, const int64_t total); const size_t shard_idx, const int64_t total);
...@@ -150,34 +148,34 @@ class CommonSparseTable : public SparseTable { ...@@ -150,34 +148,34 @@ class CommonSparseTable : public SparseTable {
const int pserver_id, const int pserver_num, const int local_shard_num, const int pserver_id, const int pserver_num, const int local_shard_num,
std::vector<std::shared_ptr<ValueBlock>>* blocks); std::vector<std::shared_ptr<ValueBlock>>* blocks);
virtual std::pair<int64_t, int64_t> print_table_stat(); virtual std::pair<int64_t, int64_t> PrintTableStat();
virtual int32_t pull_sparse(float* values, const PullSparseValue& pull_value); virtual int32_t PullSparse(float* values, const PullSparseValue& pull_value);
virtual int32_t pull_sparse_ptr(char** pull_values, const uint64_t* keys, virtual int32_t PullSparsePtr(char** pull_values, const uint64_t* keys,
size_t num); size_t num);
virtual int32_t push_sparse(const uint64_t* keys, const float* values, virtual int32_t PushSparse(const uint64_t* keys, const float* values,
size_t num); size_t num);
virtual int32_t push_sparse(const uint64_t* keys, const float** values, virtual int32_t PushSparse(const uint64_t* keys, const float** values,
size_t num); size_t num);
// only for sparse geo table // only for sparse geo table
virtual int32_t push_sparse_param(const uint64_t* keys, const float* values, virtual int32_t PushSparseParam(const uint64_t* keys, const float* values,
size_t num); size_t num);
virtual int32_t set_global_lr(float* lr) override; virtual int32_t SetGlobalLR(float* lr) override;
virtual int32_t pour(); virtual int32_t Pour();
virtual int32_t flush(); virtual int32_t Flush();
virtual int32_t shrink(const std::string& param); virtual int32_t Shrink(const std::string& param);
virtual void clear(); virtual void Clear();
protected: protected:
virtual int32_t _push_sparse(const uint64_t* keys, const float* values, virtual int32_t _PushSparse(const uint64_t* keys, const float* values,
size_t num); size_t num);
virtual int32_t _push_sparse(const uint64_t* keys, const float** values, virtual int32_t _PushSparse(const uint64_t* keys, const float** values,
size_t num); size_t num);
protected: protected:
const int task_pool_size_ = 11; const int task_pool_size_ = 11;
......
...@@ -71,11 +71,11 @@ class SparseTable : public Table { ...@@ -71,11 +71,11 @@ class SparseTable : public Table {
SparseTable() {} SparseTable() {}
virtual ~SparseTable() {} virtual ~SparseTable() {}
virtual void *get_shard(size_t shard_idx) { return 0; } virtual void *GetShard(size_t shard_idx) { return 0; }
int32_t pull_dense(float *values, size_t num) override { return 0; } int32_t PullDense(float *values, size_t num) override { return 0; }
int32_t push_dense(const float *values, size_t num) override { return 0; } int32_t PushDense(const float *values, size_t num) override { return 0; }
static int32_t sparse_local_shard_num(uint32_t shard_num, static int32_t sparse_local_shard_num(uint32_t shard_num,
uint32_t server_num) { uint32_t server_num) {
...@@ -97,19 +97,17 @@ class DenseTable : public Table { ...@@ -97,19 +97,17 @@ class DenseTable : public Table {
DenseTable() {} DenseTable() {}
virtual ~DenseTable() {} virtual ~DenseTable() {}
virtual void *get_shard(size_t shard_idx) { return 0; } virtual void *GetShard(size_t shard_idx) { return 0; }
int32_t pull_sparse(float *values, int32_t PullSparse(float *values,
const PullSparseValue &pull_value) override { const PullSparseValue &pull_value) override {
return 0; return 0;
} }
int32_t push_sparse(const uint64_t *keys, const float *values, int32_t PushSparse(const uint64_t *keys, const float *values,
size_t num) override { size_t num) override {
return 0; return 0;
} }
int32_t push_dense_param(const float *values, size_t num) override { int32_t PushDenseParam(const float *values, size_t num) override { return 0; }
return 0; int32_t Shrink(const std::string &param) override { return 0; }
}
int32_t shrink(const std::string &param) override { return 0; }
}; };
class BarrierTable : public Table { class BarrierTable : public Table {
...@@ -117,44 +115,42 @@ class BarrierTable : public Table { ...@@ -117,44 +115,42 @@ class BarrierTable : public Table {
BarrierTable() {} BarrierTable() {}
virtual ~BarrierTable() {} virtual ~BarrierTable() {}
virtual void *get_shard(size_t shard_idx) { return 0; } virtual void *GetShard(size_t shard_idx) { return 0; }
virtual int32_t Pull(TableContext &context) { return 0; } virtual int32_t Pull(TableContext &context) { return 0; }
virtual int32_t Push(TableContext &context) { return 0; } virtual int32_t Push(TableContext &context) { return 0; }
int32_t pull_dense(float *values, size_t num) override { return 0; } int32_t PullDense(float *values, size_t num) override { return 0; }
int32_t push_dense(const float *values, size_t num) override { return 0; } int32_t PushDense(const float *values, size_t num) override { return 0; }
int32_t pull_sparse(float *values, int32_t PullSparse(float *values,
const PullSparseValue &pull_value) override { const PullSparseValue &pull_value) override {
return 0;
}
int32_t push_sparse(const uint64_t *keys, const float *values,
size_t num) override {
return 0; return 0;
} }
int32_t push_dense_param(const float *values, size_t num) override { int32_t PushSparse(const uint64_t *keys, const float *values,
size_t num) override {
return 0; return 0;
} }
int32_t shrink(const std::string &param) override { return 0; } int32_t PushDenseParam(const float *values, size_t num) override { return 0; }
virtual void clear() {} int32_t Shrink(const std::string &param) override { return 0; }
virtual int32_t flush() { return 0; } virtual void Clear() {}
virtual int32_t load(const std::string &path, const std::string &param) { virtual int32_t Flush() { return 0; }
virtual int32_t Load(const std::string &path, const std::string &param) {
return 0; return 0;
} }
virtual int32_t save(const std::string &path, const std::string &param) { virtual int32_t Save(const std::string &path, const std::string &param) {
return 0; return 0;
} }
virtual int32_t initialize_shard() { return 0; } virtual int32_t InitializeShard() { return 0; }
virtual int32_t initialize() override; virtual int32_t Initialize() override;
// only for barrier // only for barrier
// 0: send_barrier 1: recv_barrier 2: complete // 0: send_barrier 1: recv_barrier 2: complete
virtual int32_t barrier(const uint32_t trainer_id, virtual int32_t Barrier(const uint32_t trainer_id,
const std::string barrier_type) override; const std::string barrier_type) override;
virtual int32_t set_table_map( virtual int32_t SetTableMap(
std::unordered_map<uint32_t, std::shared_ptr<Table>> *table_map) override; std::unordered_map<uint32_t, std::shared_ptr<Table>> *table_map) override;
private: private:
......
...@@ -34,9 +34,9 @@ class DenseOptimizer { ...@@ -34,9 +34,9 @@ class DenseOptimizer {
DenseOptimizer() {} DenseOptimizer() {}
explicit DenseOptimizer(const CommonAccessorParameter& accessor, explicit DenseOptimizer(const CommonAccessorParameter& accessor,
std::vector<std::vector<float>>* values) {} std::vector<std::vector<float>>* values) {}
virtual void update(const float* update_values, size_t num, int begin, virtual void Update(const float* update_values, size_t num, int begin,
int end) = 0; int end) = 0;
virtual void set_global_lr(float* lr) { global_learning_rate_ = lr; } virtual void SetGlobalLR(float* lr) { global_learning_rate_ = lr; }
protected: protected:
float* global_learning_rate_; float* global_learning_rate_;
...@@ -55,7 +55,7 @@ class DSUM : public DenseOptimizer { ...@@ -55,7 +55,7 @@ class DSUM : public DenseOptimizer {
} }
} }
void update(const float* update_values, size_t num, int begin, void Update(const float* update_values, size_t num, int begin,
int end) override { int end) override {
auto update_numel = end - begin; auto update_numel = end - begin;
GetBlas<float>().VADD(update_numel, update_values + begin, param + begin, GetBlas<float>().VADD(update_numel, update_values + begin, param + begin,
...@@ -81,7 +81,7 @@ class DSGD : public DenseOptimizer { ...@@ -81,7 +81,7 @@ class DSGD : public DenseOptimizer {
} }
} }
void update(const float* update_values, size_t num, int begin, void Update(const float* update_values, size_t num, int begin,
int end) override { int end) override {
auto update_numel = end - begin; auto update_numel = end - begin;
std::vector<float> grads; std::vector<float> grads;
...@@ -134,7 +134,7 @@ class DAdam : public DenseOptimizer { ...@@ -134,7 +134,7 @@ class DAdam : public DenseOptimizer {
// make sure common_dense_table.task_pool_size_ == 1; // make sure common_dense_table.task_pool_size_ == 1;
// otherwise, task_pool_size_ times beta1_pow/beta2_pow multiplication // otherwise, task_pool_size_ times beta1_pow/beta2_pow multiplication
void update(const float* update_values, size_t num, int begin, void Update(const float* update_values, size_t num, int begin,
int end) override { int end) override {
auto update_numel = end - begin; auto update_numel = end - begin;
std::vector<float> grad, grad2, tmp; std::vector<float> grad, grad2, tmp;
...@@ -214,7 +214,7 @@ class DAdamD2Sum : public DenseOptimizer { ...@@ -214,7 +214,7 @@ class DAdamD2Sum : public DenseOptimizer {
} }
} }
void update(const float* update_values, size_t num, int begin, void Update(const float* update_values, size_t num, int begin,
int end) override { int end) override {
auto update_numel = end - begin; auto update_numel = end - begin;
Eigen::Map<Eigen::MatrixXf> mat_ada_g2sum(ada_g2sum + begin, 1, Eigen::Map<Eigen::MatrixXf> mat_ada_g2sum(ada_g2sum + begin, 1,
...@@ -276,7 +276,7 @@ class DSummary : public DenseOptimizer { ...@@ -276,7 +276,7 @@ class DSummary : public DenseOptimizer {
} }
} }
void update(const float* update_values, size_t num, int begin, void Update(const float* update_values, size_t num, int begin,
int end) override { int end) override {
auto update_numel = end - begin; auto update_numel = end - begin;
Eigen::Map<Eigen::MatrixXf> mat_w(param + begin, 1, update_numel); Eigen::Map<Eigen::MatrixXf> mat_w(param + begin, 1, update_numel);
......
...@@ -40,11 +40,11 @@ class SparseOptimizer { ...@@ -40,11 +40,11 @@ class SparseOptimizer {
value_offsets_(value_offsets), value_offsets_(value_offsets),
value_idx_(value_idx) {} value_idx_(value_idx) {}
virtual void update(const uint64_t* keys, const float* update_values, virtual void Update(const uint64_t* keys, const float* update_values,
size_t num, const std::vector<uint64_t>& offsets, size_t num, const std::vector<uint64_t>& offsets,
ValueBlock* block) = 0; ValueBlock* block) = 0;
virtual void set_global_lr(float* lr) { global_learning_rate_ = lr; } virtual void SetGlobalLR(float* lr) { global_learning_rate_ = lr; }
const std::vector<std::string>& value_names_; const std::vector<std::string>& value_names_;
const std::vector<int>& value_dims_; const std::vector<int>& value_dims_;
...@@ -70,7 +70,7 @@ class SSUM : public SparseOptimizer { ...@@ -70,7 +70,7 @@ class SSUM : public SparseOptimizer {
update_numel = value_dims.at(idx); update_numel = value_dims.at(idx);
} }
void update(const uint64_t* keys, const float* update_values, size_t num, void Update(const uint64_t* keys, const float* update_values, size_t num,
const std::vector<uint64_t>& offsets, const std::vector<uint64_t>& offsets,
ValueBlock* block) override { ValueBlock* block) override {
auto blas = GetBlas<float>(); auto blas = GetBlas<float>();
...@@ -100,7 +100,7 @@ class SSGD : public SparseOptimizer { ...@@ -100,7 +100,7 @@ class SSGD : public SparseOptimizer {
lr_offset = value_offsets.at(idx); lr_offset = value_offsets.at(idx);
} }
void update(const uint64_t* keys, const float* update_values, size_t num, void Update(const uint64_t* keys, const float* update_values, size_t num,
const std::vector<uint64_t>& offsets, const std::vector<uint64_t>& offsets,
ValueBlock* block) override { ValueBlock* block) override {
auto blas = GetBlas<float>(); auto blas = GetBlas<float>();
...@@ -156,7 +156,7 @@ class SAdam : public SparseOptimizer { ...@@ -156,7 +156,7 @@ class SAdam : public SparseOptimizer {
epsilon = 1.0e-8; epsilon = 1.0e-8;
} }
void update(const uint64_t* keys, const float* update_values, size_t num, void Update(const uint64_t* keys, const float* update_values, size_t num,
const std::vector<uint64_t>& offsets, const std::vector<uint64_t>& offsets,
ValueBlock* block) override { ValueBlock* block) override {
auto blas = GetBlas<float>(); auto blas = GetBlas<float>();
......
...@@ -17,11 +17,10 @@ ...@@ -17,11 +17,10 @@
namespace paddle { namespace paddle {
namespace distributed { namespace distributed {
int32_t MemorySparseGeoTable::push_sparse_param(const uint64_t* keys, int32_t MemorySparseGeoTable::PushSparseParam(const uint64_t* keys,
const float* values, const float* values, size_t num) {
size_t num) { VLOG(5) << "DEBUG MemorySparseGeoTable::PushSparseParam begin "
VLOG(5) << "DEBUG MemorySparseGeoTable::push_sparse_param begin " "PushSparseParam "
"push_sparse_param "
<< num; << num;
auto shard_num = _task_pool_size; auto shard_num = _task_pool_size;
std::vector<std::vector<uint64_t>> offset_bucket; std::vector<std::vector<uint64_t>> offset_bucket;
...@@ -31,8 +30,8 @@ int32_t MemorySparseGeoTable::push_sparse_param(const uint64_t* keys, ...@@ -31,8 +30,8 @@ int32_t MemorySparseGeoTable::push_sparse_param(const uint64_t* keys,
auto y = keys[x] % shard_num; auto y = keys[x] % shard_num;
offset_bucket[y].push_back(x); offset_bucket[y].push_back(x);
if (x < 10) { if (x < 10) {
VLOG(5) << "DEBUG MemorySparseGeoTable::push_sparse_param key: " VLOG(5) << "DEBUG MemorySparseGeoTable::PushSparseParam key: " << keys[x]
<< keys[x] << " shard: " << y; << " shard: " << y;
} }
} }
...@@ -51,8 +50,8 @@ int32_t MemorySparseGeoTable::push_sparse_param(const uint64_t* keys, ...@@ -51,8 +50,8 @@ int32_t MemorySparseGeoTable::push_sparse_param(const uint64_t* keys,
feature_value.resize(_dim); feature_value.resize(_dim);
std::copy_n(values + _dim * offset, _dim, feature_value.data()); std::copy_n(values + _dim * offset, _dim, feature_value.data());
if (i < 10) { if (i < 10) {
VLOG(5) << "MemorySparseGeoTable::push_sparse_param " VLOG(5) << "MemorySparseGeoTable::PushSparseParam "
"push_sparse_param key " "PushSparseParam key "
<< id << " value[0]: " << (values + _dim * offset)[0] << id << " value[0]: " << (values + _dim * offset)[0]
<< " data: " << feature_value.data()[0] << " data: " << feature_value.data()[0]
<< " value[-1]: " << (values + _dim * offset)[_dim - 1] << " value[-1]: " << (values + _dim * offset)[_dim - 1]
...@@ -69,9 +68,9 @@ int32_t MemorySparseGeoTable::push_sparse_param(const uint64_t* keys, ...@@ -69,9 +68,9 @@ int32_t MemorySparseGeoTable::push_sparse_param(const uint64_t* keys,
return 0; return 0;
} }
int32_t MemorySparseGeoTable::pull_geo_param(const uint32_t trainer_id, int32_t MemorySparseGeoTable::PullGeoParam(const uint32_t trainer_id,
std::vector<float>* values, std::vector<float>* values,
std::vector<uint64_t>* ids) { std::vector<uint64_t>* ids) {
_geo_recorder->GetAndClear(trainer_id, ids); _geo_recorder->GetAndClear(trainer_id, ids);
VLOG(5) VLOG(5)
<< "DEBUG MemorySparseGeoTable::pull_geo_param pull_geo_param trainer_id " << "DEBUG MemorySparseGeoTable::pull_geo_param pull_geo_param trainer_id "
...@@ -86,23 +85,23 @@ int32_t MemorySparseGeoTable::pull_geo_param(const uint32_t trainer_id, ...@@ -86,23 +85,23 @@ int32_t MemorySparseGeoTable::pull_geo_param(const uint32_t trainer_id,
pull_value.frequencies_ = frequencies.data(); pull_value.frequencies_ = frequencies.data();
values->resize(ids->size() * _dim); values->resize(ids->size() * _dim);
pull_sparse(values->data(), pull_value); PullSparse(values->data(), pull_value);
return 0; return 0;
} }
// Pushes `num` sparse updates and records which keys were touched.
//
// keys   - array of `num` feature ids.
// values - update values, passed through untouched to the internal push.
// num    - number of entries in `keys`.
//
// Returns 0 unconditionally; the return value of the internal push is ignored.
// NOTE(review): the debug log statement reads keys[0]; presumably callers never
// pass num == 0 — confirm.
int32_t MemorySparseGeoTable::push_sparse(const uint64_t* keys, int32_t MemorySparseGeoTable::PushSparse(const uint64_t* keys,
const float* values, size_t num) { const float* values, size_t num) {
VLOG(5) << "DEBUG MemorySparseGeoTable::push_sparse keys[0]" << keys[0] VLOG(5) << "DEBUG MemorySparseGeoTable::PushSparse keys[0]" << keys[0]
<< " key_num: " << num; << " key_num: " << num;
// Copy the raw key array into a vector, the form handed to the geo recorder.
std::vector<uint64_t> ids; std::vector<uint64_t> ids;
ids.resize(num); ids.resize(num);
std::copy_n(keys, num, ids.begin()); std::copy_n(keys, num, ids.begin());
// Register the pushed ids with the recorder before applying the update.
_geo_recorder->Update(ids); _geo_recorder->Update(ids);
// Delegate the actual value update to the internal push routine.
_push_sparse(keys, values, num); _PushSparse(keys, values, num);
return 0; return 0;
} }
int32_t MemorySparseGeoTable::initialize() { int32_t MemorySparseGeoTable::Initialize() {
if (!_geo_recorder) { if (!_geo_recorder) {
auto trainers = _config.common().trainer_num(); auto trainers = _config.common().trainer_num();
_geo_recorder = std::make_shared<GeoRecorder>(trainers); _geo_recorder = std::make_shared<GeoRecorder>(trainers);
...@@ -118,8 +117,8 @@ int32_t MemorySparseGeoTable::initialize() { ...@@ -118,8 +117,8 @@ int32_t MemorySparseGeoTable::initialize() {
return 0; return 0;
} }
int32_t MemorySparseGeoTable::pull_sparse(float* pull_values, int32_t MemorySparseGeoTable::PullSparse(float* pull_values,
const PullSparseValue& pull_value) { const PullSparseValue& pull_value) {
auto shard_num = _task_pool_size; auto shard_num = _task_pool_size;
std::vector<std::future<int>> tasks(shard_num); std::vector<std::future<int>> tasks(shard_num);
...@@ -146,13 +145,13 @@ int32_t MemorySparseGeoTable::pull_sparse(float* pull_values, ...@@ -146,13 +145,13 @@ int32_t MemorySparseGeoTable::pull_sparse(float* pull_values,
auto& feature_value = local_shard[key]; auto& feature_value = local_shard[key];
feature_value.resize(_dim); feature_value.resize(_dim);
memset(feature_value.data(), 0, sizeof(float) * _dim); memset(feature_value.data(), 0, sizeof(float) * _dim);
VLOG(0) << "MemorySparseGeoTable pull_sparse key not found!!! " VLOG(0) << "MemorySparseGeoTable PullSparse key not found!!! "
<< key; << key;
itr = local_shard.find(key); itr = local_shard.find(key);
} }
memcpy(select_data, itr.value().data(), _dim * sizeof(float)); memcpy(select_data, itr.value().data(), _dim * sizeof(float));
VLOG(5) << "DEBUG MemorySparseGeoTable::pull_sparse key: " << key VLOG(5) << "DEBUG MemorySparseGeoTable::PullSparse key: " << key
<< " select_data[0] " << select_data[0] << " select_data[0] " << select_data[0]
<< " value[0]: " << itr.value().data()[0]; << " value[0]: " << itr.value().data()[0];
} }
...@@ -167,8 +166,8 @@ int32_t MemorySparseGeoTable::pull_sparse(float* pull_values, ...@@ -167,8 +166,8 @@ int32_t MemorySparseGeoTable::pull_sparse(float* pull_values,
return 0; return 0;
} }
int32_t MemorySparseGeoTable::_push_sparse(const uint64_t* keys, int32_t MemorySparseGeoTable::_PushSparse(const uint64_t* keys,
const float* values, size_t num) { const float* values, size_t num) {
auto shard_num = _task_pool_size; auto shard_num = _task_pool_size;
std::vector<std::future<int>> tasks(shard_num); std::vector<std::future<int>> tasks(shard_num);
std::vector<std::vector<std::pair<uint64_t, int>>> task_keys(shard_num); std::vector<std::vector<std::pair<uint64_t, int>>> task_keys(shard_num);
......
...@@ -40,31 +40,31 @@ class MemorySparseGeoTable : public SparseTable { ...@@ -40,31 +40,31 @@ class MemorySparseGeoTable : public SparseTable {
MemorySparseGeoTable() { _geo_recorder = nullptr; } MemorySparseGeoTable() { _geo_recorder = nullptr; }
virtual ~MemorySparseGeoTable() {} virtual ~MemorySparseGeoTable() {}
virtual int32_t initialize(); virtual int32_t Initialize();
virtual int32_t initialize_shard() { return 0; } virtual int32_t InitializeShard() { return 0; }
virtual int32_t load(const std::string& path, const std::string& param) { virtual int32_t Load(const std::string& path, const std::string& param) {
return 0; return 0;
} }
virtual int32_t save(const std::string& path, const std::string& param) { virtual int32_t Save(const std::string& path, const std::string& param) {
return 0; return 0;
} }
virtual int32_t Pull(TableContext& context) { return 0; } virtual int32_t Pull(TableContext& context) { return 0; }
virtual int32_t Push(TableContext& context) { return 0; } virtual int32_t Push(TableContext& context) { return 0; }
virtual int32_t flush() { return 0; } virtual int32_t Flush() { return 0; }
virtual int32_t shrink(const std::string& param) { return 0; } virtual int32_t Shrink(const std::string& param) { return 0; }
virtual void clear() { return; } virtual void Clear() { return; }
virtual int32_t pull_sparse(float* values, const PullSparseValue& pull_value); virtual int32_t PullSparse(float* values, const PullSparseValue& pull_value);
int32_t push_sparse_param(const uint64_t* keys, const float* values, int32_t PushSparseParam(const uint64_t* keys, const float* values,
size_t num); size_t num);
// TODO(zhaocaibei123): change to pull_sparse, and rename pull_sparse // TODO(zhaocaibei123): change to pull_sparse, and rename pull_sparse
int32_t pull_geo_param(const uint32_t trainer_id, std::vector<float>* values, int32_t PullGeoParam(const uint32_t trainer_id, std::vector<float>* values,
std::vector<uint64_t>* keys); std::vector<uint64_t>* keys);
int32_t push_sparse(const uint64_t* keys, const float* values, int32_t PushSparse(const uint64_t* keys, const float* values,
size_t num) override; size_t num) override;
int32_t _push_sparse(const uint64_t* keys, const float* values, size_t num); int32_t _PushSparse(const uint64_t* keys, const float* values, size_t num);
// int32_t _pull_sparse(float* pull_values, const PullSparseValue& // int32_t _pull_sparse(float* pull_values, const PullSparseValue&
// pull_value); // pull_value);
......
...@@ -31,7 +31,7 @@ bool FLAGS_pserver_create_value_when_push = true; ...@@ -31,7 +31,7 @@ bool FLAGS_pserver_create_value_when_push = true;
int FLAGS_pserver_table_save_max_retry = 3; int FLAGS_pserver_table_save_max_retry = 3;
bool FLAGS_pserver_enable_create_feasign_randomly = false; bool FLAGS_pserver_enable_create_feasign_randomly = false;
int32_t MemorySparseTable::initialize() { int32_t MemorySparseTable::Initialize() {
_shards_task_pool.resize(_task_pool_size); _shards_task_pool.resize(_task_pool_size);
for (int i = 0; i < _shards_task_pool.size(); ++i) { for (int i = 0; i < _shards_task_pool.size(); ++i) {
_shards_task_pool[i].reset(new ::ThreadPool(1)); _shards_task_pool[i].reset(new ::ThreadPool(1));
...@@ -39,12 +39,12 @@ int32_t MemorySparseTable::initialize() { ...@@ -39,12 +39,12 @@ int32_t MemorySparseTable::initialize() {
auto& profiler = CostProfiler::instance(); auto& profiler = CostProfiler::instance();
profiler.register_profiler("pserver_sparse_update_all"); profiler.register_profiler("pserver_sparse_update_all");
profiler.register_profiler("pserver_sparse_select_all"); profiler.register_profiler("pserver_sparse_select_all");
initialize_value(); InitializeValue();
VLOG(0) << "initalize MemorySparseTable succ"; VLOG(0) << "initalize MemorySparseTable succ";
return 0; return 0;
} }
int32_t MemorySparseTable::initialize_value() { int32_t MemorySparseTable::InitializeValue() {
_sparse_table_shard_num = static_cast<int>(_config.shard_num()); _sparse_table_shard_num = static_cast<int>(_config.shard_num());
_avg_local_shard_num = _avg_local_shard_num =
SparseTable::sparse_local_shard_num(_sparse_table_shard_num, _shard_num); SparseTable::sparse_local_shard_num(_sparse_table_shard_num, _shard_num);
...@@ -64,14 +64,14 @@ int32_t MemorySparseTable::initialize_value() { ...@@ -64,14 +64,14 @@ int32_t MemorySparseTable::initialize_value() {
return 0; return 0;
} }
int32_t MemorySparseTable::load(const std::string& path, int32_t MemorySparseTable::Load(const std::string& path,
const std::string& param) { const std::string& param) {
std::string table_path = table_dir(path); std::string table_path = TableDir(path);
auto file_list = _afs_client.list(table_path); auto file_list = _afs_client.list(table_path);
std::sort(file_list.begin(), file_list.end()); std::sort(file_list.begin(), file_list.end());
for (auto file : file_list) { for (auto file : file_list) {
VLOG(1) << "MemorySparseTable::load() file list: " << file; VLOG(1) << "MemorySparseTable::Load() file list: " << file;
} }
int load_param = atoi(param.c_str()); int load_param = atoi(param.c_str());
...@@ -154,9 +154,9 @@ int32_t MemorySparseTable::load(const std::string& path, ...@@ -154,9 +154,9 @@ int32_t MemorySparseTable::load(const std::string& path,
return 0; return 0;
} }
int32_t MemorySparseTable::load_local_fs(const std::string& path, int32_t MemorySparseTable::LoadLocalFS(const std::string& path,
const std::string& param) { const std::string& param) {
std::string table_path = table_dir(path); std::string table_path = TableDir(path);
auto file_list = paddle::framework::localfs_list(table_path); auto file_list = paddle::framework::localfs_list(table_path);
int load_param = atoi(param.c_str()); int load_param = atoi(param.c_str());
...@@ -225,12 +225,12 @@ int32_t MemorySparseTable::load_local_fs(const std::string& path, ...@@ -225,12 +225,12 @@ int32_t MemorySparseTable::load_local_fs(const std::string& path,
return 0; return 0;
} }
int32_t MemorySparseTable::save(const std::string& dirname, int32_t MemorySparseTable::Save(const std::string& dirname,
const std::string& param) { const std::string& param) {
VLOG(0) << "MemorySparseTable::save dirname: " << dirname; VLOG(0) << "MemorySparseTable::save dirname: " << dirname;
int save_param = int save_param =
atoi(param.c_str()); // checkpoint:0 xbox delta:1 xbox base:2 atoi(param.c_str()); // checkpoint:0 xbox delta:1 xbox base:2
std::string table_path = table_dir(dirname); std::string table_path = TableDir(dirname);
_afs_client.remove(paddle::string::format_string( _afs_client.remove(paddle::string::format_string(
"%s/part-%03d-*", table_path.c_str(), _shard_idx)); "%s/part-%03d-*", table_path.c_str(), _shard_idx));
std::atomic<uint32_t> feasign_size_all{0}; std::atomic<uint32_t> feasign_size_all{0};
...@@ -309,12 +309,12 @@ int32_t MemorySparseTable::save(const std::string& dirname, ...@@ -309,12 +309,12 @@ int32_t MemorySparseTable::save(const std::string& dirname,
return 0; return 0;
} }
int32_t MemorySparseTable::save_local_fs(const std::string& dirname, int32_t MemorySparseTable::SaveLocalFS(const std::string& dirname,
const std::string& param, const std::string& param,
const std::string& prefix) { const std::string& prefix) {
int save_param = int save_param =
atoi(param.c_str()); // checkpoint:0 xbox delta:1 xbox base:2 atoi(param.c_str()); // checkpoint:0 xbox delta:1 xbox base:2
std::string table_path = table_dir(dirname); std::string table_path = TableDir(dirname);
int feasign_cnt = 0; int feasign_cnt = 0;
size_t file_start_idx = _avg_local_shard_num * _shard_idx; size_t file_start_idx = _avg_local_shard_num * _shard_idx;
...@@ -349,7 +349,7 @@ int32_t MemorySparseTable::save_local_fs(const std::string& dirname, ...@@ -349,7 +349,7 @@ int32_t MemorySparseTable::save_local_fs(const std::string& dirname,
return 0; return 0;
} }
int64_t MemorySparseTable::local_size() { int64_t MemorySparseTable::LocalSize() {
int64_t local_size = 0; int64_t local_size = 0;
for (size_t i = 0; i < _real_local_shard_num; ++i) { for (size_t i = 0; i < _real_local_shard_num; ++i) {
local_size += _local_shards[i].size(); local_size += _local_shards[i].size();
...@@ -357,7 +357,7 @@ int64_t MemorySparseTable::local_size() { ...@@ -357,7 +357,7 @@ int64_t MemorySparseTable::local_size() {
return local_size; return local_size;
} }
int64_t MemorySparseTable::local_mf_size() { int64_t MemorySparseTable::LocalMFSize() {
std::vector<int64_t> size_arr(_real_local_shard_num, 0); std::vector<int64_t> size_arr(_real_local_shard_num, 0);
std::vector<std::future<int>> tasks(_real_local_shard_num); std::vector<std::future<int>> tasks(_real_local_shard_num);
int64_t ret_size = 0; int64_t ret_size = 0;
...@@ -384,9 +384,9 @@ int64_t MemorySparseTable::local_mf_size() { ...@@ -384,9 +384,9 @@ int64_t MemorySparseTable::local_mf_size() {
return ret_size; return ret_size;
} }
// Returns the table's size statistics as a pair:
//   first  - the value reported by the local-size query (feasign count),
//   second - the value reported by the local-mf-size query.
// Despite the name, nothing is printed here; the pair is only computed and
// returned.
std::pair<int64_t, int64_t> MemorySparseTable::print_table_stat() { std::pair<int64_t, int64_t> MemorySparseTable::PrintTableStat() {
int64_t feasign_size = local_size(); int64_t feasign_size = LocalSize();
int64_t mf_size = local_mf_size(); int64_t mf_size = LocalMFSize();
return {feasign_size, mf_size}; return {feasign_size, mf_size};
} }
...@@ -395,11 +395,11 @@ int32_t MemorySparseTable::Pull(TableContext& context) { ...@@ -395,11 +395,11 @@ int32_t MemorySparseTable::Pull(TableContext& context) {
if (context.use_ptr) { if (context.use_ptr) {
char** pull_values = context.pull_context.ptr_values; char** pull_values = context.pull_context.ptr_values;
const uint64_t* keys = context.pull_context.keys; const uint64_t* keys = context.pull_context.keys;
return pull_sparse_ptr(pull_values, keys, context.num); return PullSparsePtr(pull_values, keys, context.num);
} else { } else {
float* pull_values = context.pull_context.values; float* pull_values = context.pull_context.values;
const PullSparseValue& pull_value = context.pull_context.pull_value; const PullSparseValue& pull_value = context.pull_context.pull_value;
return pull_sparse(pull_values, pull_value); return PullSparse(pull_values, pull_value);
} }
} }
...@@ -407,11 +407,11 @@ int32_t MemorySparseTable::Push(TableContext& context) { ...@@ -407,11 +407,11 @@ int32_t MemorySparseTable::Push(TableContext& context) {
CHECK(context.value_type == Sparse); CHECK(context.value_type == Sparse);
const uint64_t* keys = context.push_context.keys; const uint64_t* keys = context.push_context.keys;
return push_sparse(keys, context.push_context.values, context.num); return PushSparse(keys, context.push_context.values, context.num);
} }
int32_t MemorySparseTable::pull_sparse(float* pull_values, int32_t MemorySparseTable::PullSparse(float* pull_values,
const PullSparseValue& pull_value) { const PullSparseValue& pull_value) {
CostTimer timer("pserver_sparse_select_all"); CostTimer timer("pserver_sparse_select_all");
std::vector<std::future<int>> tasks(_real_local_shard_num); std::vector<std::future<int>> tasks(_real_local_shard_num);
...@@ -479,8 +479,8 @@ int32_t MemorySparseTable::pull_sparse(float* pull_values, ...@@ -479,8 +479,8 @@ int32_t MemorySparseTable::pull_sparse(float* pull_values,
return 0; return 0;
} }
int32_t MemorySparseTable::pull_sparse_ptr(char** pull_values, int32_t MemorySparseTable::PullSparsePtr(char** pull_values,
const uint64_t* keys, size_t num) { const uint64_t* keys, size_t num) {
CostTimer timer("pscore_sparse_select_all"); CostTimer timer("pscore_sparse_select_all");
size_t value_size = _value_accesor->GetTableInfo(SIZE) / sizeof(float); size_t value_size = _value_accesor->GetTableInfo(SIZE) / sizeof(float);
size_t mf_value_size = _value_accesor->GetTableInfo(MF_SIZE) / sizeof(float); size_t mf_value_size = _value_accesor->GetTableInfo(MF_SIZE) / sizeof(float);
...@@ -530,8 +530,8 @@ int32_t MemorySparseTable::pull_sparse_ptr(char** pull_values, ...@@ -530,8 +530,8 @@ int32_t MemorySparseTable::pull_sparse_ptr(char** pull_values,
return 0; return 0;
} }
int32_t MemorySparseTable::push_sparse(const uint64_t* keys, int32_t MemorySparseTable::PushSparse(const uint64_t* keys, const float* values,
const float* values, size_t num) { size_t num) {
CostTimer timer("pserver_sparse_update_all"); CostTimer timer("pserver_sparse_update_all");
std::vector<std::future<int>> tasks(_real_local_shard_num); std::vector<std::future<int>> tasks(_real_local_shard_num);
std::vector<std::vector<std::pair<uint64_t, int>>> task_keys( std::vector<std::vector<std::pair<uint64_t, int>>> task_keys(
...@@ -603,14 +603,14 @@ int32_t MemorySparseTable::push_sparse(const uint64_t* keys, ...@@ -603,14 +603,14 @@ int32_t MemorySparseTable::push_sparse(const uint64_t* keys,
return 0; return 0;
} }
int32_t MemorySparseTable::push_sparse(const uint64_t* keys, int32_t MemorySparseTable::PushSparse(const uint64_t* keys,
const float** values, size_t num) { const float** values, size_t num) {
_push_sparse(keys, values, num); _PushSparse(keys, values, num);
return 0; return 0;
} }
int32_t MemorySparseTable::_push_sparse(const uint64_t* keys, int32_t MemorySparseTable::_PushSparse(const uint64_t* keys,
const float** values, size_t num) { const float** values, size_t num) {
std::vector<std::future<int>> tasks(_real_local_shard_num); std::vector<std::future<int>> tasks(_real_local_shard_num);
std::vector<std::vector<std::pair<uint64_t, int>>> task_keys( std::vector<std::vector<std::pair<uint64_t, int>>> task_keys(
_real_local_shard_num); _real_local_shard_num);
...@@ -677,13 +677,13 @@ int32_t MemorySparseTable::_push_sparse(const uint64_t* keys, ...@@ -677,13 +677,13 @@ int32_t MemorySparseTable::_push_sparse(const uint64_t* keys,
return 0; return 0;
} }
int32_t MemorySparseTable::flush() { return 0; } int32_t MemorySparseTable::Flush() { return 0; }
int32_t MemorySparseTable::shrink(const std::string& param) { int32_t MemorySparseTable::Shrink(const std::string& param) {
VLOG(0) << "MemorySparseTable::shrink"; VLOG(0) << "MemorySparseTable::Shrink";
// TODO(zhaocaibei123): implement with multi-thread // TODO(zhaocaibei123): implement with multi-thread
for (int shard_id = 0; shard_id < _real_local_shard_num; ++shard_id) { for (int shard_id = 0; shard_id < _real_local_shard_num; ++shard_id) {
// shrink // Shrink
auto& shard = _local_shards[shard_id]; auto& shard = _local_shards[shard_id];
for (auto it = shard.begin(); it != shard.end();) { for (auto it = shard.begin(); it != shard.end();) {
if (_value_accesor->Shrink(it.value().data())) { if (_value_accesor->Shrink(it.value().data())) {
...@@ -696,7 +696,7 @@ int32_t MemorySparseTable::shrink(const std::string& param) { ...@@ -696,7 +696,7 @@ int32_t MemorySparseTable::shrink(const std::string& param) {
return 0; return 0;
} }
void MemorySparseTable::clear() { VLOG(0) << "clear coming soon"; } void MemorySparseTable::Clear() { VLOG(0) << "clear coming soon"; }
} // namespace distributed } // namespace distributed
} // namespace paddle } // namespace paddle
...@@ -41,50 +41,48 @@ class MemorySparseTable : public SparseTable { ...@@ -41,50 +41,48 @@ class MemorySparseTable : public SparseTable {
virtual ~MemorySparseTable() {} virtual ~MemorySparseTable() {}
// unused method begin // unused method begin
virtual int32_t pull_dense(float* pull_values, size_t num) { return 0; } virtual int32_t PullDense(float* pull_values, size_t num) { return 0; }
virtual int32_t push_dense_param(const float* values, size_t num) { virtual int32_t PushDenseParam(const float* values, size_t num) { return 0; }
return 0; virtual int32_t PushDense(const float* values, size_t num) { return 0; }
}
virtual int32_t push_dense(const float* values, size_t num) { return 0; }
// unused method end // unused method end
virtual int32_t Pull(TableContext& context); virtual int32_t Pull(TableContext& context);
virtual int32_t Push(TableContext& context); virtual int32_t Push(TableContext& context);
virtual int32_t initialize(); virtual int32_t Initialize();
virtual int32_t initialize_shard() { return 0; } virtual int32_t InitializeShard() { return 0; }
virtual int32_t initialize_value(); virtual int32_t InitializeValue();
virtual int32_t load(const std::string& path, const std::string& param); virtual int32_t Load(const std::string& path, const std::string& param);
virtual int32_t save(const std::string& path, const std::string& param); virtual int32_t Save(const std::string& path, const std::string& param);
int32_t load_local_fs(const std::string& path, const std::string& param); int32_t LoadLocalFS(const std::string& path, const std::string& param);
int32_t save_local_fs(const std::string& path, const std::string& param, int32_t SaveLocalFS(const std::string& path, const std::string& param,
const std::string& prefix); const std::string& prefix);
int64_t local_size(); int64_t LocalSize();
int64_t local_mf_size(); int64_t LocalMFSize();
virtual std::pair<int64_t, int64_t> print_table_stat(); virtual std::pair<int64_t, int64_t> PrintTableStat();
virtual int32_t pull_sparse(float* values, const PullSparseValue& pull_value); virtual int32_t PullSparse(float* values, const PullSparseValue& pull_value);
virtual int32_t pull_sparse_ptr(char** pull_values, const uint64_t* keys, virtual int32_t PullSparsePtr(char** pull_values, const uint64_t* keys,
size_t num); size_t num);
virtual int32_t push_sparse(const uint64_t* keys, const float* values, virtual int32_t PushSparse(const uint64_t* keys, const float* values,
size_t num); size_t num);
virtual int32_t push_sparse(const uint64_t* keys, const float** values, virtual int32_t PushSparse(const uint64_t* keys, const float** values,
size_t num); size_t num);
virtual int32_t flush(); virtual int32_t Flush();
virtual int32_t shrink(const std::string& param); virtual int32_t Shrink(const std::string& param);
virtual void clear(); virtual void Clear();
protected: protected:
virtual int32_t _push_sparse(const uint64_t* keys, const float** values, virtual int32_t _PushSparse(const uint64_t* keys, const float** values,
size_t num); size_t num);
protected: protected:
const int _task_pool_size = 24; const int _task_pool_size = 24;
......
...@@ -17,9 +17,9 @@ ...@@ -17,9 +17,9 @@
namespace paddle { namespace paddle {
namespace distributed { namespace distributed {
int32_t SparseGeoTable::pull_geo_param(const uint32_t trainer_id, int32_t SparseGeoTable::PullGeoParam(const uint32_t trainer_id,
std::vector<float>* values, std::vector<float>* values,
std::vector<uint64_t>* ids) { std::vector<uint64_t>* ids) {
geo_recorder->GetAndClear(trainer_id, ids); geo_recorder->GetAndClear(trainer_id, ids);
auto dim = _config.common().dims()[0]; auto dim = _config.common().dims()[0];
...@@ -32,21 +32,21 @@ int32_t SparseGeoTable::pull_geo_param(const uint32_t trainer_id, ...@@ -32,21 +32,21 @@ int32_t SparseGeoTable::pull_geo_param(const uint32_t trainer_id,
pull_value.frequencies_ = frequencies.data(); pull_value.frequencies_ = frequencies.data();
values->resize(ids->size() * dim); values->resize(ids->size() * dim);
CommonSparseTable::pull_sparse(values->data(), pull_value); CommonSparseTable::PullSparse(values->data(), pull_value);
return 0; return 0;
} }
int32_t SparseGeoTable::push_sparse(const uint64_t* keys, const float* values, int32_t SparseGeoTable::PushSparse(const uint64_t* keys, const float* values,
size_t num) { size_t num) {
std::vector<uint64_t> ids; std::vector<uint64_t> ids;
ids.resize(num); ids.resize(num);
std::copy_n(keys, num, ids.begin()); std::copy_n(keys, num, ids.begin());
geo_recorder->Update(ids); geo_recorder->Update(ids);
CommonSparseTable::push_sparse(keys, values, num); CommonSparseTable::PushSparse(keys, values, num);
return 0; return 0;
} }
int32_t SparseGeoTable::initialize_value() { int32_t SparseGeoTable::InitializeValue() {
auto common = _config.common(); auto common = _config.common();
shard_values_.reserve(task_pool_size_); shard_values_.reserve(task_pool_size_);
...@@ -82,7 +82,7 @@ int32_t SparseGeoTable::initialize_value() { ...@@ -82,7 +82,7 @@ int32_t SparseGeoTable::initialize_value() {
auto pull_value = PullSparseValue(ids, fres, param_dim_); auto pull_value = PullSparseValue(ids, fres, param_dim_);
std::vector<float> pulls; std::vector<float> pulls;
pulls.resize(bucket_feasigns * param_dim_); pulls.resize(bucket_feasigns * param_dim_);
pull_sparse(pulls.data(), pull_value); PullSparse(pulls.data(), pull_value);
} }
return 0; return 0;
} }
......
...@@ -44,15 +44,15 @@ class SparseGeoTable : public CommonSparseTable { ...@@ -44,15 +44,15 @@ class SparseGeoTable : public CommonSparseTable {
explicit SparseGeoTable() : CommonSparseTable() { geo_recorder = nullptr; } explicit SparseGeoTable() : CommonSparseTable() { geo_recorder = nullptr; }
virtual ~SparseGeoTable() {} virtual ~SparseGeoTable() {}
virtual int32_t initialize_value(); virtual int32_t InitializeValue();
int32_t pull_geo_param(const uint32_t trainer_id, std::vector<float>* values, int32_t PullGeoParam(const uint32_t trainer_id, std::vector<float>* values,
std::vector<uint64_t>* keys); std::vector<uint64_t>* keys);
int32_t push_sparse(const uint64_t* keys, const float* values, int32_t PushSparse(const uint64_t* keys, const float* values,
size_t num) override; size_t num) override;
virtual int32_t initialize_recorder() { virtual int32_t InitializeRecorder() {
if (!geo_recorder) { if (!geo_recorder) {
auto trainers = _config.common().trainer_num(); auto trainers = _config.common().trainer_num();
geo_recorder = std::make_shared<GeoRecorder>(trainers); geo_recorder = std::make_shared<GeoRecorder>(trainers);
......
...@@ -20,7 +20,7 @@ DEFINE_string(rocksdb_path, "database", "path of sparse table rocksdb file"); ...@@ -20,7 +20,7 @@ DEFINE_string(rocksdb_path, "database", "path of sparse table rocksdb file");
namespace paddle { namespace paddle {
namespace distributed { namespace distributed {
int32_t SSDSparseTable::initialize() { int32_t SSDSparseTable::Initialize() {
_shards_task_pool.resize(task_pool_size_); _shards_task_pool.resize(task_pool_size_);
for (int i = 0; i < _shards_task_pool.size(); ++i) { for (int i = 0; i < _shards_task_pool.size(); ++i) {
_shards_task_pool[i].reset(new ::ThreadPool(1)); _shards_task_pool[i].reset(new ::ThreadPool(1));
...@@ -53,9 +53,9 @@ int32_t SSDSparseTable::initialize() { ...@@ -53,9 +53,9 @@ int32_t SSDSparseTable::initialize() {
offset += dim; offset += dim;
} }
initialize_value(); InitializeValue();
initialize_optimizer(); InitializeOptimizer();
initialize_recorder(); InitializeRecorder();
_db = paddle::distributed::RocksDBHandler::GetInstance(); _db = paddle::distributed::RocksDBHandler::GetInstance();
_db->initialize(FLAGS_rocksdb_path, task_pool_size_); _db->initialize(FLAGS_rocksdb_path, task_pool_size_);
return 0; return 0;
...@@ -66,18 +66,18 @@ int32_t SSDSparseTable::Pull(TableContext& context) { ...@@ -66,18 +66,18 @@ int32_t SSDSparseTable::Pull(TableContext& context) {
if (context.use_ptr) { if (context.use_ptr) {
char** pull_values = context.pull_context.ptr_values; char** pull_values = context.pull_context.ptr_values;
const uint64_t* keys = context.pull_context.keys; const uint64_t* keys = context.pull_context.keys;
return pull_sparse_ptr(pull_values, keys, context.num); return PullSparsePtr(pull_values, keys, context.num);
} else { } else {
float* pull_values = context.pull_context.values; float* pull_values = context.pull_context.values;
const PullSparseValue& pull_value = context.pull_context.pull_value; const PullSparseValue& pull_value = context.pull_context.pull_value;
return pull_sparse(pull_values, pull_value); return PullSparse(pull_values, pull_value);
} }
} }
int32_t SSDSparseTable::Push(TableContext& context) { return 0; } int32_t SSDSparseTable::Push(TableContext& context) { return 0; }
int32_t SSDSparseTable::pull_sparse(float* pull_values, int32_t SSDSparseTable::PullSparse(float* pull_values,
const PullSparseValue& pull_value) { const PullSparseValue& pull_value) {
auto shard_num = task_pool_size_; auto shard_num = task_pool_size_;
std::vector<std::future<int>> tasks(shard_num); std::vector<std::future<int>> tasks(shard_num);
...@@ -140,8 +140,8 @@ int32_t SSDSparseTable::pull_sparse(float* pull_values, ...@@ -140,8 +140,8 @@ int32_t SSDSparseTable::pull_sparse(float* pull_values,
return 0; return 0;
} }
int32_t SSDSparseTable::pull_sparse_ptr(char** pull_values, int32_t SSDSparseTable::PullSparsePtr(char** pull_values, const uint64_t* keys,
const uint64_t* keys, size_t num) { size_t num) {
auto shard_num = task_pool_size_; auto shard_num = task_pool_size_;
std::vector<std::future<int>> tasks(shard_num); std::vector<std::future<int>> tasks(shard_num);
...@@ -201,9 +201,9 @@ int32_t SSDSparseTable::pull_sparse_ptr(char** pull_values, ...@@ -201,9 +201,9 @@ int32_t SSDSparseTable::pull_sparse_ptr(char** pull_values,
return 0; return 0;
} }
int32_t SSDSparseTable::shrink(const std::string& param) { return 0; } int32_t SSDSparseTable::Shrink(const std::string& param) { return 0; }
int32_t SSDSparseTable::update_table() { int32_t SSDSparseTable::UpdateTable() {
int count = 0; int count = 0;
int value_size = shard_values_[0]->value_length_; int value_size = shard_values_[0]->value_length_;
int db_size = 3 + value_size; int db_size = 3 + value_size;
...@@ -299,7 +299,7 @@ int64_t SSDSparseTable::SaveValueToText(std::ostream* os, ...@@ -299,7 +299,7 @@ int64_t SSDSparseTable::SaveValueToText(std::ostream* os,
return save_num; return save_num;
} }
int32_t SSDSparseTable::load(const std::string& path, int32_t SSDSparseTable::Load(const std::string& path,
const std::string& param) { const std::string& param) {
rwlock_->WRLock(); rwlock_->WRLock();
VLOG(3) << "ssd sparse table load with " << path << " with meta " << param; VLOG(3) << "ssd sparse table load with " << path << " with meta " << param;
......
...@@ -23,7 +23,7 @@ class SSDSparseTable : public CommonSparseTable { ...@@ -23,7 +23,7 @@ class SSDSparseTable : public CommonSparseTable {
SSDSparseTable() {} SSDSparseTable() {}
virtual ~SSDSparseTable() {} virtual ~SSDSparseTable() {}
virtual int32_t initialize() override; virtual int32_t Initialize() override;
void SaveMetaToText(std::ostream* os, const CommonAccessorParameter& common, void SaveMetaToText(std::ostream* os, const CommonAccessorParameter& common,
const size_t shard_idx, const int64_t total); const size_t shard_idx, const int64_t total);
...@@ -37,22 +37,22 @@ class SSDSparseTable : public CommonSparseTable { ...@@ -37,22 +37,22 @@ class SSDSparseTable : public CommonSparseTable {
const int pserver_id, const int pserver_num, const int local_shard_num, const int pserver_id, const int pserver_num, const int local_shard_num,
std::vector<std::shared_ptr<ValueBlock>>* blocks); std::vector<std::shared_ptr<ValueBlock>>* blocks);
virtual int32_t load(const std::string& path, const std::string& param); virtual int32_t Load(const std::string& path, const std::string& param);
// exchange data // exchange data
virtual int32_t update_table(); virtual int32_t UpdateTable();
virtual int32_t Pull(TableContext& context); virtual int32_t Pull(TableContext& context);
virtual int32_t Push(TableContext& context); virtual int32_t Push(TableContext& context);
virtual int32_t pull_sparse(float* values, const PullSparseValue& pull_value); virtual int32_t PullSparse(float* values, const PullSparseValue& pull_value);
virtual int32_t pull_sparse_ptr(char** pull_values, const uint64_t* keys, virtual int32_t PullSparsePtr(char** pull_values, const uint64_t* keys,
size_t num); size_t num);
virtual int32_t flush() override { return 0; } virtual int32_t Flush() override { return 0; }
virtual int32_t shrink(const std::string& param) override; virtual int32_t Shrink(const std::string& param) override;
virtual void clear() override {} virtual void Clear() override {}
private: private:
RocksDBHandler* _db; RocksDBHandler* _db;
......
...@@ -56,7 +56,7 @@ REGISTER_PSCORE_CLASS(SparseValueSGDRule, SparseAdamSGDRule); ...@@ -56,7 +56,7 @@ REGISTER_PSCORE_CLASS(SparseValueSGDRule, SparseAdamSGDRule);
REGISTER_PSCORE_CLASS(SparseValueSGDRule, SparseNaiveSGDRule); REGISTER_PSCORE_CLASS(SparseValueSGDRule, SparseNaiveSGDRule);
REGISTER_PSCORE_CLASS(SparseValueSGDRule, SparseAdaGradSGDRule); REGISTER_PSCORE_CLASS(SparseValueSGDRule, SparseAdaGradSGDRule);
int32_t TableManager::initialize() { int32_t TableManager::Initialize() {
static bool initialized = false; static bool initialized = false;
if (initialized) { if (initialized) {
return 0; return 0;
...@@ -65,10 +65,10 @@ int32_t TableManager::initialize() { ...@@ -65,10 +65,10 @@ int32_t TableManager::initialize() {
return 0; return 0;
} }
int32_t Table::initialize(const TableParameter &config, int32_t Table::Initialize(const TableParameter &config,
const FsClientParameter &fs_config) { const FsClientParameter &fs_config) {
_config = config; _config = config;
if (initialize_accessor() != 0) { if (InitializeAccessor() != 0) {
LOG(WARNING) << "Table accessor initialize failed"; LOG(WARNING) << "Table accessor initialize failed";
return -1; return -1;
} }
...@@ -77,10 +77,10 @@ int32_t Table::initialize(const TableParameter &config, ...@@ -77,10 +77,10 @@ int32_t Table::initialize(const TableParameter &config,
LOG(WARNING) << "Table fs_client initialize failed"; LOG(WARNING) << "Table fs_client initialize failed";
// return -1; // return -1;
} }
return initialize(); return Initialize();
} }
int32_t Table::initialize_accessor() { int32_t Table::InitializeAccessor() {
if (!_config.has_accessor() || !_config.accessor().has_accessor_class()) { if (!_config.has_accessor() || !_config.accessor().has_accessor_class()) {
LOG(ERROR) << "missing accessor config in table, table_id:" LOG(ERROR) << "missing accessor config in table, table_id:"
<< _config.table_id(); << _config.table_id();
......
...@@ -60,101 +60,99 @@ class Table { ...@@ -60,101 +60,99 @@ class Table {
public: public:
Table() {} Table() {}
virtual ~Table() {} virtual ~Table() {}
virtual int32_t initialize(const TableParameter &config, virtual int32_t Initialize(const TableParameter &config,
const FsClientParameter &fs_config); const FsClientParameter &fs_config);
virtual int32_t Pull(TableContext &context) = 0; virtual int32_t Pull(TableContext &context) = 0;
virtual int32_t Push(TableContext &context) = 0; virtual int32_t Push(TableContext &context) = 0;
virtual int32_t pull_dense(float *values, size_t num) = 0; virtual int32_t PullDense(float *values, size_t num) = 0;
virtual int32_t push_dense(const float *values, size_t num) = 0; virtual int32_t PushDense(const float *values, size_t num) = 0;
// for push global_step // for push global_step
virtual int32_t push_dense(const int64_t *values, const int32_t trainer_id) { virtual int32_t PushDense(const int64_t *values, const int32_t trainer_id) {
return 0;
}
virtual int32_t push_dense_param(const float *values, size_t num) {
return 0; return 0;
} }
virtual int32_t PushDenseParam(const float *values, size_t num) { return 0; }
virtual int32_t pull_sparse_ptr(char **pull_values, const uint64_t *keys, virtual int32_t PullSparsePtr(char **pull_values, const uint64_t *keys,
size_t num) { size_t num) {
VLOG(0) << "NOT IMPLEMENT"; VLOG(0) << "NOT IMPLEMENT";
return 0; return 0;
} }
virtual int32_t pull_sparse(float *values, virtual int32_t PullSparse(float *values,
const PullSparseValue &pull_value) = 0; const PullSparseValue &pull_value) = 0;
virtual int32_t push_sparse(const uint64_t *keys, const float *values, virtual int32_t PushSparse(const uint64_t *keys, const float *values,
size_t num) = 0; size_t num) = 0;
virtual int32_t push_sparse(const uint64_t *keys, const float **values, virtual int32_t PushSparse(const uint64_t *keys, const float **values,
size_t num) { size_t num) {
return 0; return 0;
} }
virtual int32_t push_sparse_param(const uint64_t *keys, const float *values, virtual int32_t PushSparseParam(const uint64_t *keys, const float *values,
size_t num) { size_t num) {
return 0; return 0;
} }
// only for sparse geo table // only for sparse geo table
virtual int32_t pull_geo_param(const uint32_t trainer_id, virtual int32_t PullGeoParam(const uint32_t trainer_id,
std::vector<float> *values, std::vector<float> *values,
std::vector<uint64_t> *keys) { std::vector<uint64_t> *keys) {
return 0; return 0;
} }
// only for barrier // only for barrier
virtual int32_t barrier(const uint32_t trainer_id, virtual int32_t Barrier(const uint32_t trainer_id,
const std::string barrier_type) { const std::string barrier_type) {
return 0; return 0;
} }
// only for barrier table // only for barrier table
virtual int32_t set_table_map( virtual int32_t SetTableMap(
std::unordered_map<uint32_t, std::shared_ptr<Table>> *table_map) { std::unordered_map<uint32_t, std::shared_ptr<Table>> *table_map) {
return 0; return 0;
} }
// only for tensor table // only for tensor table
virtual int32_t set_program_env( virtual int32_t SetProgramEnv(
framework::Scope *scope, platform::Place place, framework::Scope *scope, platform::Place place,
const std::vector<framework::ProgramDesc> *sub_program) { const std::vector<framework::ProgramDesc> *sub_program) {
return 0; return 0;
} }
virtual int32_t set_global_lr(float *lr) { virtual int32_t SetGlobalLR(float *lr) {
_global_lr = lr; _global_lr = lr;
return 0; return 0;
} }
virtual int32_t pour() { return 0; } virtual int32_t Pour() { return 0; }
virtual void clear() = 0; virtual void Clear() = 0;
virtual int32_t flush() = 0; virtual int32_t Flush() = 0;
virtual int32_t shrink(const std::string &param) = 0; virtual int32_t Shrink(const std::string &param) = 0;
// 指定加载路径 // 指定加载路径
virtual int32_t load(const std::string &path, virtual int32_t Load(const std::string &path,
const std::string &converter) = 0; const std::string &converter) = 0;
// 指定保存路径 // 指定保存路径
virtual int32_t save(const std::string &path, virtual int32_t Save(const std::string &path,
const std::string &converter) = 0; const std::string &converter) = 0;
virtual int32_t set_shard(size_t shard_idx, size_t shard_num) { virtual int32_t SetShard(size_t shard_idx, size_t shard_num) {
_shard_idx = shard_idx; _shard_idx = shard_idx;
_shard_num = shard_num; _shard_num = shard_num;
return initialize_shard(); return InitializeShard();
} }
inline std::shared_ptr<ValueAccessor> value_accesor() { inline std::shared_ptr<ValueAccessor> ValueAccesor() {
return _value_accesor; return _value_accesor;
} }
virtual void *get_shard(size_t shard_idx) = 0; virtual void *GetShard(size_t shard_idx) = 0;
virtual std::pair<int64_t, int64_t> print_table_stat() { return {0, 0}; } virtual std::pair<int64_t, int64_t> PrintTableStat() { return {0, 0}; }
protected: protected:
virtual int32_t initialize() = 0; virtual int32_t Initialize() = 0;
virtual int32_t initialize_accessor(); virtual int32_t InitializeAccessor();
virtual int32_t initialize_shard() = 0; virtual int32_t InitializeShard() = 0;
virtual std::string table_dir(const std::string &model_dir) { virtual std::string TableDir(const std::string &model_dir) {
return paddle::string::format_string("%s/%03d/", model_dir.c_str(), return paddle::string::format_string("%s/%03d/", model_dir.c_str(),
_config.table_id()); _config.table_id());
} }
...@@ -171,11 +169,11 @@ REGISTER_PSCORE_REGISTERER(Table); ...@@ -171,11 +169,11 @@ REGISTER_PSCORE_REGISTERER(Table);
class TableManager { class TableManager {
public: public:
static TableManager &instance() { static TableManager &Instance() {
static TableManager manager; static TableManager manager;
return manager; return manager;
} }
int32_t initialize(); int32_t Initialize();
private: private:
TableManager() {} TableManager() {}
......
...@@ -52,42 +52,42 @@ class TensorTable : public Table { ...@@ -52,42 +52,42 @@ class TensorTable : public Table {
virtual int32_t Pull(TableContext &context) { return 0; } virtual int32_t Pull(TableContext &context) { return 0; }
virtual int32_t Push(TableContext &context) { return 0; } virtual int32_t Push(TableContext &context) { return 0; }
int32_t pull_dense(float *values, size_t num) override { return 0; } int32_t PullDense(float *values, size_t num) override { return 0; }
int32_t push_dense(const float *values, size_t num) override { return 0; } int32_t PushDense(const float *values, size_t num) override { return 0; }
int32_t pull_sparse(float *values, int32_t PullSparse(float *values,
const PullSparseValue &pull_value) override { const PullSparseValue &pull_value) override {
return 0; return 0;
} }
int32_t push_sparse(const uint64_t *keys, const float *values, int32_t PushSparse(const uint64_t *keys, const float *values,
size_t num) override { size_t num) override {
return 0; return 0;
} }
int32_t shrink(const std::string &param) override { return 0; } int32_t Shrink(const std::string &param) override { return 0; }
virtual void *get_shard(size_t shard_idx) { return 0; } virtual void *GetShard(size_t shard_idx) { return 0; }
virtual int32_t initialize_shard() { return 0; } virtual int32_t InitializeShard() { return 0; }
virtual int32_t flush() { return 0; } virtual int32_t Flush() { return 0; }
virtual int32_t load(const std::string &path, const std::string &param) { virtual int32_t Load(const std::string &path, const std::string &param) {
return 0; return 0;
} }
virtual int32_t save(const std::string &path, const std::string &param) { virtual int32_t Save(const std::string &path, const std::string &param) {
return 0; return 0;
} }
virtual void clear() {} virtual void Clear() {}
int32_t initialize() override { return 0; } int32_t Initialize() override { return 0; }
int32_t push_dense(const int64_t *values, const int32_t trainer_id) override { int32_t PushDense(const int64_t *values, const int32_t trainer_id) override {
return 0; return 0;
} }
int32_t set_program_env( int32_t SetProgramEnv(
framework::Scope *scope, platform::Place place, framework::Scope *scope, platform::Place place,
const std::vector<framework::ProgramDesc> *sub_program) override { const std::vector<framework::ProgramDesc> *sub_program) override {
scope_ = scope; scope_ = scope;
...@@ -111,48 +111,48 @@ class DenseTensorTable : public TensorTable { ...@@ -111,48 +111,48 @@ class DenseTensorTable : public TensorTable {
DenseTensorTable() {} DenseTensorTable() {}
virtual ~DenseTensorTable() {} virtual ~DenseTensorTable() {}
int32_t pull_sparse(float *values, int32_t PullSparse(float *values,
const PullSparseValue &pull_value) override { const PullSparseValue &pull_value) override {
return 0; return 0;
} }
int32_t push_sparse(const uint64_t *keys, const float *values, int32_t PushSparse(const uint64_t *keys, const float *values,
size_t num) override { size_t num) override {
return 0; return 0;
} }
int32_t shrink(const std::string &param) override { return 0; } int32_t Shrink(const std::string &param) override { return 0; }
virtual void *get_shard(size_t shard_idx) { return 0; } virtual void *GetShard(size_t shard_idx) { return 0; }
virtual int32_t initialize_shard() { return 0; } virtual int32_t InitializeShard() { return 0; }
virtual int32_t flush() { return 0; } virtual int32_t Flush() { return 0; }
virtual void clear() {} virtual void Clear() {}
// Todo: Support program Load & Save // Todo: Support program Load & Save
virtual int32_t load(const std::string &path, const std::string &param) { virtual int32_t Load(const std::string &path, const std::string &param) {
return 0; return 0;
} }
virtual int32_t save(const std::string &path, const std::string &param) { virtual int32_t Save(const std::string &path, const std::string &param) {
return 0; return 0;
} }
// Todo: Support pull dense // Todo: Support pull dense
int32_t pull_dense(float *values, size_t num) override { return 0; } int32_t PullDense(float *values, size_t num) override { return 0; }
/*----------------------------------------------------------------------*/ /*----------------------------------------------------------------------*/
int32_t initialize() override { return 0; } int32_t Initialize() override { return 0; }
int32_t push_dense(const float *values, size_t num) override { return 0; } int32_t PushDense(const float *values, size_t num) override { return 0; }
int32_t push_dense(const int64_t *values, const int32_t trainer_id) { int32_t PushDense(const int64_t *values, const int32_t trainer_id) {
return 0; return 0;
} }
protected: protected:
virtual int32_t _run_program(const float *values, size_t num, virtual int32_t _RunProgram(const float *values, size_t num,
const uint32_t trainer_id) { const uint32_t trainer_id) {
return 0; return 0;
} }
...@@ -167,36 +167,36 @@ class GlobalStepTable : public DenseTensorTable { ...@@ -167,36 +167,36 @@ class GlobalStepTable : public DenseTensorTable {
GlobalStepTable() {} GlobalStepTable() {}
virtual ~GlobalStepTable() {} virtual ~GlobalStepTable() {}
int32_t pull_sparse(float *values, int32_t PullSparse(float *values,
const PullSparseValue &pull_value) override { const PullSparseValue &pull_value) override {
return 0; return 0;
} }
int32_t push_sparse(const uint64_t *keys, const float *values, int32_t PushSparse(const uint64_t *keys, const float *values,
size_t num) override { size_t num) override {
return 0; return 0;
} }
int32_t shrink(const std::string &param) override { return 0; } int32_t Shrink(const std::string &param) override { return 0; }
virtual void *get_shard(size_t shard_idx) { return 0; } virtual void *GetShard(size_t shard_idx) { return 0; }
virtual int32_t initialize_shard() { return 0; } virtual int32_t InitializeShard() { return 0; }
virtual int32_t flush() { return 0; } virtual int32_t Flush() { return 0; }
virtual void clear() {} virtual void Clear() {}
virtual int32_t load(const std::string &path, const std::string &param) { virtual int32_t Load(const std::string &path, const std::string &param) {
return 0; return 0;
} }
virtual int32_t save(const std::string &path, const std::string &param) { virtual int32_t Save(const std::string &path, const std::string &param) {
return 0; return 0;
} }
int32_t pull_dense(float *values, size_t num) override { return 0; } int32_t PullDense(float *values, size_t num) override { return 0; }
/*----------------------------------------------------------------------*/ /*----------------------------------------------------------------------*/
int32_t initialize() override { int32_t Initialize() override {
auto _program_config = _config.tensor(); auto _program_config = _config.tensor();
auto trainers_ = _config.common().trainer_num(); auto trainers_ = _config.common().trainer_num();
FLAGS_eager_delete_tensor_gb = -1; FLAGS_eager_delete_tensor_gb = -1;
...@@ -237,14 +237,14 @@ class GlobalStepTable : public DenseTensorTable { ...@@ -237,14 +237,14 @@ class GlobalStepTable : public DenseTensorTable {
} }
} }
int32_t push_dense(const float *values, size_t num) override { return 0; } int32_t PushDense(const float *values, size_t num) override { return 0; }
int32_t push_dense(const int64_t *values, const int32_t trainer_id) { int32_t PushDense(const int64_t *values, const int32_t trainer_id) {
return _run_program(values, trainer_id); return _RunProgram(values, trainer_id);
} }
int32_t set_table_map(std::unordered_map<uint32_t, std::shared_ptr<Table>> int32_t SetTableMap(std::unordered_map<uint32_t, std::shared_ptr<Table>>
*table_map) override { *table_map) override {
auto *lr_var = scope_->FindVar(fetch_var_name_); auto *lr_var = scope_->FindVar(fetch_var_name_);
auto *lr_tensor = lr_var->GetMutable<framework::LoDTensor>(); auto *lr_tensor = lr_var->GetMutable<framework::LoDTensor>();
auto *lr_value = lr_tensor->mutable_data<float>(platform::CPUPlace()); auto *lr_value = lr_tensor->mutable_data<float>(platform::CPUPlace());
...@@ -255,14 +255,14 @@ class GlobalStepTable : public DenseTensorTable { ...@@ -255,14 +255,14 @@ class GlobalStepTable : public DenseTensorTable {
if (table_id == _config.table_id()) { if (table_id == _config.table_id()) {
continue; continue;
} }
iter->second->set_global_lr(lr_value); iter->second->SetGlobalLR(lr_value);
} }
return 0; return 0;
} }
private: private:
virtual int32_t _run_program(const int64_t *values, virtual int32_t _RunProgram(const int64_t *values,
const uint32_t trainer_id) { const uint32_t trainer_id) {
FLAGS_eager_delete_tensor_gb = -1; FLAGS_eager_delete_tensor_gb = -1;
auto counter = decay_counters_.at(trainer_id); auto counter = decay_counters_.at(trainer_id);
counter += int(values[0]); counter += int(values[0]);
......
...@@ -51,32 +51,6 @@ int32_t FleetWrapper::CopyTableByFeasign( ...@@ -51,32 +51,6 @@ int32_t FleetWrapper::CopyTableByFeasign(
return 0; return 0;
} }
void FleetWrapper::Stop() { StopServer(); }
void FleetWrapper::Load(WrapperContext& context) {
auto table_id = context.table_id;
if (table_id >= 0 && context.meta != "") {
LoadSparseOnServer(context.path, context.meta, context.table_id);
return;
}
if (table_id < 0) { // laod all
LoadModel(context.path, context.mode);
} else { // load one table
LoadModelOneTable(table_id, context.path, context.mode);
}
return;
}
void FleetWrapper::Save(WrapperContext& context) {
auto table_id = context.table_id;
if (table_id < 0) {
SaveModel(context.path, context.mode);
} else {
SaveModelOneTable(table_id, context.path, context.mode);
}
return;
}
void FleetWrapper::SetClient2ClientConfig(int request_timeout_ms, void FleetWrapper::SetClient2ClientConfig(int request_timeout_ms,
int connect_timeout_ms, int connect_timeout_ms,
int max_retry) { int max_retry) {
...@@ -90,7 +64,7 @@ void FleetWrapper::LoadSparseOnServer(const std::string& path, ...@@ -90,7 +64,7 @@ void FleetWrapper::LoadSparseOnServer(const std::string& path,
uint32_t table_id) { uint32_t table_id) {
VLOG(3) << "load sparse table " << table_id << " with " << path << " meta " VLOG(3) << "load sparse table " << table_id << " with " << path << " meta "
<< meta; << meta;
pserver_ptr_->_server_ptr->table(table_id)->load(path, meta); pserver_ptr_->_server_ptr->GetTable(table_id)->Load(path, meta);
} }
void FleetWrapper::InitServer( void FleetWrapper::InitServer(
...@@ -101,8 +75,8 @@ void FleetWrapper::InitServer( ...@@ -101,8 +75,8 @@ void FleetWrapper::InitServer(
VLOG(3) << "Going to init server"; VLOG(3) << "Going to init server";
pserver_ptr_ = std::shared_ptr<paddle::distributed::PSCore>( pserver_ptr_ = std::shared_ptr<paddle::distributed::PSCore>(
new paddle::distributed::PSCore()); new paddle::distributed::PSCore());
pserver_ptr_->init_server(dist_desc, &host_sign_list, host_sign_list.size(), pserver_ptr_->InitServer(dist_desc, &host_sign_list, host_sign_list.size(),
index, trainers, server_sub_program); index, trainers, server_sub_program);
is_initialized_ = true; is_initialized_ = true;
} else { } else {
VLOG(3) << "Server can be initialized only once"; VLOG(3) << "Server can be initialized only once";
...@@ -143,10 +117,10 @@ void FleetWrapper::InitWorker(const std::string& dist_desc, ...@@ -143,10 +117,10 @@ void FleetWrapper::InitWorker(const std::string& dist_desc,
google::protobuf::TextFormat::ParseFromString(dist_desc, &ps_param); google::protobuf::TextFormat::ParseFromString(dist_desc, &ps_param);
InitGFlag(ps_param.init_gflags()); InitGFlag(ps_param.init_gflags());
int servers = host_sign_list.size(); int servers = host_sign_list.size();
ps_env_.set_ps_servers(&host_sign_list, servers); ps_env_.SetPsServers(&host_sign_list, servers);
worker_ptr_ = std::shared_ptr<paddle::distributed::PSClient>( worker_ptr_ = std::shared_ptr<paddle::distributed::PSClient>(
paddle::distributed::PSClientFactory::create(ps_param)); paddle::distributed::PSClientFactory::Create(ps_param));
worker_ptr_->configure(ps_param, dense_pull_regions, ps_env_, index); worker_ptr_->Configure(ps_param, dense_pull_regions, ps_env_, index);
} }
} else { } else {
VLOG(3) << "Client can be initialized only once"; VLOG(3) << "Client can be initialized only once";
...@@ -155,13 +129,13 @@ void FleetWrapper::InitWorker(const std::string& dist_desc, ...@@ -155,13 +129,13 @@ void FleetWrapper::InitWorker(const std::string& dist_desc,
void FleetWrapper::StopServer() { void FleetWrapper::StopServer() {
VLOG(3) << "Going to stop server"; VLOG(3) << "Going to stop server";
auto status = worker_ptr_->stop_server(); auto status = worker_ptr_->StopServer();
status.wait(); status.wait();
} }
void FleetWrapper::FinalizeWorker() { void FleetWrapper::FinalizeWorker() {
VLOG(3) << "Going to finalize worker"; VLOG(3) << "Going to finalize worker";
worker_ptr_->finalize_worker(); worker_ptr_->FinalizeWorker();
} }
void FleetWrapper::BarrierWithTable(uint32_t barrier_type) { void FleetWrapper::BarrierWithTable(uint32_t barrier_type) {
...@@ -172,13 +146,13 @@ void FleetWrapper::BarrierWithTable(uint32_t barrier_type) { ...@@ -172,13 +146,13 @@ void FleetWrapper::BarrierWithTable(uint32_t barrier_type) {
uint64_t FleetWrapper::RunServer(const std::string& ip, uint32_t port) { uint64_t FleetWrapper::RunServer(const std::string& ip, uint32_t port) {
VLOG(3) << "Going to run server with ip " << ip << " port " << port; VLOG(3) << "Going to run server with ip " << ip << " port " << port;
auto ret = pserver_ptr_->run_server(ip, port); auto ret = pserver_ptr_->RunServer(ip, port);
return ret; return ret;
} }
std::vector<uint64_t> FleetWrapper::GetClientsInfo() { std::vector<uint64_t> FleetWrapper::GetClientsInfo() {
VLOG(3) << "Going to get client info"; VLOG(3) << "Going to get client info";
std::vector<uint64_t> res = ps_env_.get_client_info(); std::vector<uint64_t> res = ps_env_.GetClientInfo();
for (auto rr : res) { for (auto rr : res) {
VLOG(2) << "FleetWrapper::GetClientInfo " << rr; VLOG(2) << "FleetWrapper::GetClientInfo " << rr;
} }
...@@ -187,14 +161,14 @@ std::vector<uint64_t> FleetWrapper::GetClientsInfo() { ...@@ -187,14 +161,14 @@ std::vector<uint64_t> FleetWrapper::GetClientsInfo() {
int FleetWrapper::SetClients(std::vector<uint64_t>& host_sign_list) { int FleetWrapper::SetClients(std::vector<uint64_t>& host_sign_list) {
int node = host_sign_list.size(); int node = host_sign_list.size();
return ps_env_.set_ps_clients(host_sign_list.data(), node); return ps_env_.SetPsClients(host_sign_list.data(), node);
} }
void FleetWrapper::CreateClient2ClientConnection() { void FleetWrapper::CreateClient2ClientConnection() {
VLOG(1) << "Going to create client2client connection"; VLOG(1) << "Going to create client2client connection";
worker_ptr_->create_client2client_connection( worker_ptr_->CreateClient2ClientConnection(client2client_request_timeout_ms_,
client2client_request_timeout_ms_, client2client_connect_timeout_ms_, client2client_connect_timeout_ms_,
client2client_max_retry_); client2client_max_retry_);
} }
std::future<int32_t> FleetWrapper::PullSparseVarsAsync( std::future<int32_t> FleetWrapper::PullSparseVarsAsync(
...@@ -230,9 +204,9 @@ std::future<int32_t> FleetWrapper::PullSparseVarsAsync( ...@@ -230,9 +204,9 @@ std::future<int32_t> FleetWrapper::PullSparseVarsAsync(
} }
bool training = true; bool training = true;
return pserver_ptr_->_worker_ptr->pull_sparse(pull_result_ptr.data(), return pserver_ptr_->_worker_ptr->PullSparse(pull_result_ptr.data(), table_id,
table_id, fea_keys->data(), fea_keys->data(),
fea_keys->size(), training); fea_keys->size(), training);
} }
void FleetWrapper::PullSparseVarsSync( void FleetWrapper::PullSparseVarsSync(
...@@ -279,7 +253,7 @@ void FleetWrapper::PullSparseVarsSync( ...@@ -279,7 +253,7 @@ void FleetWrapper::PullSparseVarsSync(
pull_result_ptr.push_back(t.data()); pull_result_ptr.push_back(t.data());
} }
bool training = true; bool training = true;
auto status = pserver_ptr_->_worker_ptr->pull_sparse( auto status = pserver_ptr_->_worker_ptr->PullSparse(
pull_result_ptr.data(), table_id, fea_keys->data(), fea_keys->size(), pull_result_ptr.data(), table_id, fea_keys->data(), fea_keys->size(),
training); training);
pull_sparse_status.push_back(std::move(status)); pull_sparse_status.push_back(std::move(status));
...@@ -337,21 +311,10 @@ void FleetWrapper::PullSparseToTensorSync(const uint64_t table_id, int fea_dim, ...@@ -337,21 +311,10 @@ void FleetWrapper::PullSparseToTensorSync(const uint64_t table_id, int fea_dim,
pull_result_ptr.push_back(output_data + output_len); pull_result_ptr.push_back(output_data + output_len);
} }
} }
// ps client pull sparse
// construct client request context auto status =
RequestContext req_context; worker_ptr_->PullSparse(pull_result_ptr.data(), table_id, fea_keys.data(),
req_context.value_type = Sparse; fea_keys.size(), is_training);
req_context.training_mode = Async;
req_context.table = table_id;
req_context.sparse_values = pull_result_ptr.data();
req_context.keys = fea_keys.data();
req_context.num = fea_keys.size();
req_context.is_training = is_training;
auto status = worker_ptr_->Pull(req_context);
// auto status =
// worker_ptr_->pull_sparse(pull_result_ptr.data(), table_id,
// fea_keys.data(), fea_keys.size(),
// is_training);
status.wait(); status.wait();
auto ret = status.get(); auto ret = status.get();
if (ret != 0) { if (ret != 0) {
...@@ -364,7 +327,7 @@ void FleetWrapper::PullDenseVarsAsync( ...@@ -364,7 +327,7 @@ void FleetWrapper::PullDenseVarsAsync(
const Scope& scope, const uint64_t tid, const Scope& scope, const uint64_t tid,
const std::vector<std::string>& var_names, const std::vector<std::string>& var_names,
std::vector<std::future<int32_t>>* pull_dense_status, bool in_cpu) { std::vector<std::future<int32_t>>* pull_dense_status, bool in_cpu) {
auto& regions = _regions[tid]; auto& regions = regions_[tid];
regions.clear(); regions.clear();
regions.resize(var_names.size()); regions.resize(var_names.size());
for (auto i = 0u; i < var_names.size(); ++i) { for (auto i = 0u; i < var_names.size(); ++i) {
...@@ -378,21 +341,15 @@ void FleetWrapper::PullDenseVarsAsync( ...@@ -378,21 +341,15 @@ void FleetWrapper::PullDenseVarsAsync(
paddle::distributed::Region reg(w, tensor->numel()); paddle::distributed::Region reg(w, tensor->numel());
regions[i] = std::move(reg); regions[i] = std::move(reg);
} }
RequestContext req_context;
req_context.value_type = Dense; auto status = worker_ptr_->PullDense(regions.data(), regions.size(), tid);
req_context.training_mode = Async;
req_context.table = tid;
req_context.dense_values = regions.data();
req_context.num = regions.size();
auto status = worker_ptr_->Pull(req_context);
// auto status = worker_ptr_->pull_dense(regions.data(), regions.size(), tid);
pull_dense_status->push_back(std::move(status)); pull_dense_status->push_back(std::move(status));
} }
void FleetWrapper::PullDenseVarsSync( void FleetWrapper::PullDenseVarsSync(
const Scope& scope, const uint64_t tid, const Scope& scope, const uint64_t tid,
const std::vector<std::string>& var_names) { const std::vector<std::string>& var_names) {
auto& regions = _regions[tid]; auto& regions = regions_[tid];
regions.clear(); regions.clear();
regions.reserve(var_names.size()); regions.reserve(var_names.size());
for (auto& t : var_names) { for (auto& t : var_names) {
...@@ -404,7 +361,7 @@ void FleetWrapper::PullDenseVarsSync( ...@@ -404,7 +361,7 @@ void FleetWrapper::PullDenseVarsSync(
regions.emplace_back(std::move(reg)); regions.emplace_back(std::move(reg));
} }
} }
auto status = worker_ptr_->pull_dense(regions.data(), regions.size(), tid); auto status = worker_ptr_->PullDense(regions.data(), regions.size(), tid);
status.wait(); status.wait();
} }
...@@ -424,7 +381,7 @@ void FleetWrapper::PushDenseParamSync( ...@@ -424,7 +381,7 @@ void FleetWrapper::PushDenseParamSync(
} }
} }
auto push_status = auto push_status =
worker_ptr_->push_dense_param(regions.data(), regions.size(), table_id); worker_ptr_->PushDenseParam(regions.data(), regions.size(), table_id);
push_status.wait(); push_status.wait();
auto status = push_status.get(); auto status = push_status.get();
CHECK(status == 0) << "push dense param failed, status[" << status << "]"; CHECK(status == 0) << "push dense param failed, status[" << status << "]";
...@@ -470,15 +427,8 @@ void FleetWrapper::PushDenseVarsAsync( ...@@ -470,15 +427,8 @@ void FleetWrapper::PushDenseVarsAsync(
<< g[tensor->numel() - 1]; << g[tensor->numel() - 1];
} }
RequestContext req_context; auto push_status =
req_context.value_type = Dense; worker_ptr_->PushDense(regions.data(), regions.size(), table_id);
req_context.training_mode = Async;
req_context.table = table_id;
req_context.push_context.push_dense_values = regions.data();
req_context.num = regions.size();
// auto push_status =
// worker_ptr_->push_dense(regions.data(), regions.size(), table_id);
auto push_status = worker_ptr_->Push(req_context);
} }
void FleetWrapper::PushSparseVarsAsync( void FleetWrapper::PushSparseVarsAsync(
...@@ -650,23 +600,13 @@ void FleetWrapper::PushSparseFromTensorAsync( ...@@ -650,23 +600,13 @@ void FleetWrapper::PushSparseFromTensorAsync(
push_g_vec[i] = push_values.at(i).data(); push_g_vec[i] = push_values.at(i).data();
} }
// ps client push sparse auto status = worker_ptr_->PushSparse(table_id, push_keys.data(),
// construct request context (const float**)push_g_vec.data(),
RequestContext req_context; push_keys.size());
req_context.value_type = Sparse;
req_context.training_mode = Async;
req_context.table = table_id;
req_context.push_context.push_values = (const float**)push_g_vec.data();
req_context.push_context.keys = push_keys.data();
req_context.num = push_keys.size();
auto status = worker_ptr_->Push(req_context);
// auto status = worker_ptr_->push_sparse(table_id, push_keys.data(),
// (const float**)push_g_vec.data(),
// push_keys.size());
} }
void FleetWrapper::LoadModel(const std::string& path, const int mode) { void FleetWrapper::LoadModel(const std::string& path, const int mode) {
auto ret = worker_ptr_->load(path, std::to_string(mode)); auto ret = worker_ptr_->Load(path, std::to_string(mode));
ret.wait(); ret.wait();
if (ret.get() != 0) { if (ret.get() != 0) {
LOG(ERROR) << "load model from path:" << path << " failed"; LOG(ERROR) << "load model from path:" << path << " failed";
...@@ -675,7 +615,7 @@ void FleetWrapper::LoadModel(const std::string& path, const int mode) { ...@@ -675,7 +615,7 @@ void FleetWrapper::LoadModel(const std::string& path, const int mode) {
void FleetWrapper::LoadModelOneTable(const uint64_t table_id, void FleetWrapper::LoadModelOneTable(const uint64_t table_id,
const std::string& path, const int mode) { const std::string& path, const int mode) {
auto ret = worker_ptr_->load(table_id, path, std::to_string(mode)); auto ret = worker_ptr_->Load(table_id, path, std::to_string(mode));
ret.wait(); ret.wait();
if (ret.get() != 0) { if (ret.get() != 0) {
LOG(ERROR) << "load model of table id: " << table_id LOG(ERROR) << "load model of table id: " << table_id
...@@ -684,7 +624,7 @@ void FleetWrapper::LoadModelOneTable(const uint64_t table_id, ...@@ -684,7 +624,7 @@ void FleetWrapper::LoadModelOneTable(const uint64_t table_id,
} }
void FleetWrapper::SaveModel(const std::string& path, const int mode) { void FleetWrapper::SaveModel(const std::string& path, const int mode) {
auto ret = worker_ptr_->save(path, std::to_string(mode)); auto ret = worker_ptr_->Save(path, std::to_string(mode));
ret.wait(); ret.wait();
int32_t feasign_cnt = ret.get(); int32_t feasign_cnt = ret.get();
if (feasign_cnt == -1) { if (feasign_cnt == -1) {
...@@ -694,7 +634,7 @@ void FleetWrapper::SaveModel(const std::string& path, const int mode) { ...@@ -694,7 +634,7 @@ void FleetWrapper::SaveModel(const std::string& path, const int mode) {
void FleetWrapper::SaveModelOneTable(const uint64_t table_id, void FleetWrapper::SaveModelOneTable(const uint64_t table_id,
const std::string& path, const int mode) { const std::string& path, const int mode) {
auto ret = worker_ptr_->save(table_id, path, std::to_string(mode)); auto ret = worker_ptr_->Save(table_id, path, std::to_string(mode));
ret.wait(); ret.wait();
if (ret.get() != 0) { if (ret.get() != 0) {
LOG(ERROR) << "save model of table id: " << table_id LOG(ERROR) << "save model of table id: " << table_id
...@@ -704,7 +644,7 @@ void FleetWrapper::SaveModelOneTable(const uint64_t table_id, ...@@ -704,7 +644,7 @@ void FleetWrapper::SaveModelOneTable(const uint64_t table_id,
void FleetWrapper::RecvAndSaveTable(const uint64_t table_id, void FleetWrapper::RecvAndSaveTable(const uint64_t table_id,
const std::string& path) { const std::string& path) {
auto ret = worker_ptr_->recv_and_save_table(table_id, path); auto ret = worker_ptr_->RecvAndSaveTable(table_id, path);
if (ret != 0) { if (ret != 0) {
LOG(ERROR) << "save model of table id: " << table_id LOG(ERROR) << "save model of table id: " << table_id
<< ", to path: " << path << " failed"; << ", to path: " << path << " failed";
...@@ -712,7 +652,7 @@ void FleetWrapper::RecvAndSaveTable(const uint64_t table_id, ...@@ -712,7 +652,7 @@ void FleetWrapper::RecvAndSaveTable(const uint64_t table_id,
} }
void FleetWrapper::PrintTableStat(const uint64_t table_id) { void FleetWrapper::PrintTableStat(const uint64_t table_id) {
auto ret = worker_ptr_->print_table_stat(table_id); auto ret = worker_ptr_->PrintTableStat(table_id);
ret.wait(); ret.wait();
int32_t err_code = ret.get(); int32_t err_code = ret.get();
if (err_code == -1) { if (err_code == -1) {
...@@ -721,7 +661,7 @@ void FleetWrapper::PrintTableStat(const uint64_t table_id) { ...@@ -721,7 +661,7 @@ void FleetWrapper::PrintTableStat(const uint64_t table_id) {
} }
void FleetWrapper::ShrinkSparseTable(int table_id, int threshold) { void FleetWrapper::ShrinkSparseTable(int table_id, int threshold) {
auto ret = worker_ptr_->shrink(table_id, std::to_string(threshold)); auto ret = worker_ptr_->Shrink(table_id, std::to_string(threshold));
ret.wait(); ret.wait();
int32_t err_code = ret.get(); int32_t err_code = ret.get();
if (err_code == -1) { if (err_code == -1) {
...@@ -730,12 +670,12 @@ void FleetWrapper::ShrinkSparseTable(int table_id, int threshold) { ...@@ -730,12 +670,12 @@ void FleetWrapper::ShrinkSparseTable(int table_id, int threshold) {
} }
void FleetWrapper::ClearModel() { void FleetWrapper::ClearModel() {
auto ret = pserver_ptr_->_worker_ptr->clear(); auto ret = pserver_ptr_->_worker_ptr->Clear();
ret.wait(); ret.wait();
} }
void FleetWrapper::ClearOneTable(const uint64_t table_id) { void FleetWrapper::ClearOneTable(const uint64_t table_id) {
auto ret = pserver_ptr_->_worker_ptr->clear(table_id); auto ret = pserver_ptr_->_worker_ptr->Clear(table_id);
ret.wait(); ret.wait();
} }
...@@ -774,7 +714,7 @@ void FleetWrapper::ShrinkDenseTable(int table_id, Scope* scope, ...@@ -774,7 +714,7 @@ void FleetWrapper::ShrinkDenseTable(int table_id, Scope* scope,
regions.emplace_back(std::move(reg)); regions.emplace_back(std::move(reg));
} }
} }
auto push_status = pserver_ptr_->_worker_ptr->push_dense_param( auto push_status = pserver_ptr_->_worker_ptr->PushDenseParam(
regions.data(), regions.size(), table_id); regions.data(), regions.size(), table_id);
push_status.wait(); push_status.wait();
auto status = push_status.get(); auto status = push_status.get();
...@@ -791,7 +731,7 @@ void FleetWrapper::ClientFlush() { ...@@ -791,7 +731,7 @@ void FleetWrapper::ClientFlush() {
VLOG(0) << "worker_ptr null, do nothing"; VLOG(0) << "worker_ptr null, do nothing";
return; return;
} }
auto ret = worker_ptr_->flush(); auto ret = worker_ptr_->Flush();
ret.wait(); ret.wait();
int32_t err_code = ret.get(); int32_t err_code = ret.get();
if (err_code == -1) { if (err_code == -1) {
...@@ -805,13 +745,13 @@ int FleetWrapper::RegisterClientToClientMsgHandler(int msg_type, ...@@ -805,13 +745,13 @@ int FleetWrapper::RegisterClientToClientMsgHandler(int msg_type,
VLOG(0) << "FleetWrapper::Client is null"; VLOG(0) << "FleetWrapper::Client is null";
return -1; return -1;
} else { } else {
return worker_ptr_->registe_client2client_msg_handler(msg_type, handler); return worker_ptr_->RegisteClient2ClientMsgHandler(msg_type, handler);
} }
} }
std::future<int32_t> FleetWrapper::SendClientToClientMsg( std::future<int32_t> FleetWrapper::SendClientToClientMsg(
int msg_type, int to_client_id, const std::string& msg) { int msg_type, int to_client_id, const std::string& msg) {
return worker_ptr_->send_client2client_msg(msg_type, to_client_id, msg); return worker_ptr_->SendClient2ClientMsg(msg_type, to_client_id, msg);
} }
std::default_random_engine& FleetWrapper::LocalRandomEngine() { std::default_random_engine& FleetWrapper::LocalRandomEngine() {
......
...@@ -25,7 +25,6 @@ limitations under the License. */ ...@@ -25,7 +25,6 @@ limitations under the License. */
#include "paddle/fluid/distributed/ps/service/communicator/communicator_common.h" #include "paddle/fluid/distributed/ps/service/communicator/communicator_common.h"
#include "paddle/fluid/distributed/ps/service/ps_service/service.h" #include "paddle/fluid/distributed/ps/service/ps_service/service.h"
#include "paddle/fluid/distributed/ps/wrapper/ps_wrapper.h"
#include "paddle/fluid/framework/archive.h" #include "paddle/fluid/framework/archive.h"
#include "paddle/fluid/framework/io/fs.h" #include "paddle/fluid/framework/io/fs.h"
#include "paddle/fluid/framework/io/shell.h" #include "paddle/fluid/framework/io/shell.h"
...@@ -55,7 +54,7 @@ using framework::Variable; ...@@ -55,7 +54,7 @@ using framework::Variable;
using RpcCtxMap = std::unordered_map<std::string, CommContext>; using RpcCtxMap = std::unordered_map<std::string, CommContext>;
class FleetWrapper : public PSWrapper { class FleetWrapper {
public: public:
virtual ~FleetWrapper() {} virtual ~FleetWrapper() {}
FleetWrapper() { FleetWrapper() {
...@@ -69,7 +68,6 @@ class FleetWrapper : public PSWrapper { ...@@ -69,7 +68,6 @@ class FleetWrapper : public PSWrapper {
// pserver request max retry // pserver request max retry
client2client_max_retry_ = 3; client2client_max_retry_ = 3;
} }
virtual int32_t Initialize(InitContext& context) { return 0; }
// TODO(zhaocaibei123: later) // TODO(zhaocaibei123: later)
int32_t CopyTable(const uint64_t src_table_id, const uint64_t dest_table_id); int32_t CopyTable(const uint64_t src_table_id, const uint64_t dest_table_id);
...@@ -81,12 +79,6 @@ class FleetWrapper : public PSWrapper { ...@@ -81,12 +79,6 @@ class FleetWrapper : public PSWrapper {
typedef std::function<void(int, int)> HeterCallBackFunc; typedef std::function<void(int, int)> HeterCallBackFunc;
int RegisterHeterCallback(HeterCallBackFunc handler); int RegisterHeterCallback(HeterCallBackFunc handler);
virtual void Stop() override;
virtual void Load(WrapperContext& context) override;
virtual void Save(WrapperContext& context) override;
// set client to client communication config // set client to client communication config
void SetClient2ClientConfig(int request_timeout_ms, int connect_timeout_ms, void SetClient2ClientConfig(int request_timeout_ms, int connect_timeout_ms,
int max_retry); int max_retry);
...@@ -278,7 +270,7 @@ class FleetWrapper : public PSWrapper { ...@@ -278,7 +270,7 @@ class FleetWrapper : public PSWrapper {
protected: protected:
static bool is_initialized_; static bool is_initialized_;
std::map<uint64_t, std::vector<paddle::distributed::Region>> _regions; std::map<uint64_t, std::vector<paddle::distributed::Region>> regions_;
bool scale_sparse_gradient_with_batch_size_; bool scale_sparse_gradient_with_batch_size_;
int32_t sleep_seconds_before_fail_exit_; int32_t sleep_seconds_before_fail_exit_;
int client2client_request_timeout_ms_; int client2client_request_timeout_ms_;
......
...@@ -39,19 +39,19 @@ TEST(BarrierTable, Barrier) { ...@@ -39,19 +39,19 @@ TEST(BarrierTable, Barrier) {
common_config->set_trainer_num(trainers); common_config->set_trainer_num(trainers);
common_config->set_sync(sync); common_config->set_sync(sync);
auto ret = table->initialize(table_config, fs_config); auto ret = table->Initialize(table_config, fs_config);
std::unordered_map<uint32_t, std::shared_ptr<Table>> maps = std::unordered_map<uint32_t, std::shared_ptr<Table>> maps =
std::unordered_map<uint32_t, std::shared_ptr<Table>>(); std::unordered_map<uint32_t, std::shared_ptr<Table>>();
table->set_table_map(&maps); table->SetTableMap(&maps);
std::shared_ptr<::ThreadPool> pool_ = std::shared_ptr<::ThreadPool> pool_ =
std::make_shared<::ThreadPool>(trainers); std::make_shared<::ThreadPool>(trainers);
std::vector<std::future<void>> task_status; std::vector<std::future<void>> task_status;
for (auto x = 0; x < trainers; x++) { for (auto x = 0; x < trainers; x++) {
auto task = [table, x] { table->barrier(x, 0); }; auto task = [table, x] { table->Barrier(x, 0); };
task_status.push_back(pool_->enqueue(std::move(task))); task_status.push_back(pool_->enqueue(std::move(task)));
} }
......
...@@ -155,16 +155,16 @@ void RunServer() { ...@@ -155,16 +155,16 @@ void RunServer() {
auto _ps_env = paddle::distributed::PaddlePSEnvironment(); auto _ps_env = paddle::distributed::PaddlePSEnvironment();
LOG(INFO) << "RUN set_ps_servers"; LOG(INFO) << "RUN set_ps_servers";
_ps_env.set_ps_servers(&host_sign_list_, 1); _ps_env.SetPsServers(&host_sign_list_, 1);
pserver_ptr_ = std::shared_ptr<paddle::distributed::PSServer>( pserver_ptr_ = std::shared_ptr<paddle::distributed::PSServer>(
paddle::distributed::PSServerFactory::create(server_proto)); paddle::distributed::PSServerFactory::Create(server_proto));
LOG(INFO) << "RUN configure"; LOG(INFO) << "RUN configure";
std::vector<framework::ProgramDesc> empty_vec; std::vector<framework::ProgramDesc> empty_vec;
framework::ProgramDesc empty_prog; framework::ProgramDesc empty_prog;
empty_vec.push_back(empty_prog); empty_vec.push_back(empty_prog);
pserver_ptr_->configure(server_proto, _ps_env, 0, empty_vec); pserver_ptr_->Configure(server_proto, _ps_env, 0, empty_vec);
LOG(INFO) << "RUN start"; LOG(INFO) << "RUN start";
pserver_ptr_->start(ip_, port_); pserver_ptr_->Start(ip_, port_);
LOG(INFO) << "End start"; LOG(INFO) << "End start";
} }
...@@ -175,19 +175,19 @@ void RunClient(std::map<uint64_t, std::vector<paddle::distributed::Region>>& ...@@ -175,19 +175,19 @@ void RunClient(std::map<uint64_t, std::vector<paddle::distributed::Region>>&
auto servers_ = host_sign_list_.size(); auto servers_ = host_sign_list_.size();
_ps_env = paddle::distributed::PaddlePSEnvironment(); _ps_env = paddle::distributed::PaddlePSEnvironment();
LOG(INFO) << "Run set_ps_servers"; LOG(INFO) << "Run set_ps_servers";
_ps_env.set_ps_servers(&host_sign_list_, servers_); _ps_env.SetPsServers(&host_sign_list_, servers_);
LOG(INFO) << "Run Create PSClient"; LOG(INFO) << "Run Create PSClient";
worker_ptr_ = std::shared_ptr<paddle::distributed::PSClient>( worker_ptr_ = std::shared_ptr<paddle::distributed::PSClient>(
paddle::distributed::PSClientFactory::create(worker_proto)); paddle::distributed::PSClientFactory::Create(worker_proto));
LOG(INFO) << "Run configure"; LOG(INFO) << "Run configure";
worker_ptr_->configure(worker_proto, dense_regions, _ps_env, 0); worker_ptr_->Configure(worker_proto, dense_regions, _ps_env, 0);
} }
void RunBrpcPushDense() { void RunBrpcPushDense() {
setenv("http_proxy", "", 1); setenv("http_proxy", "", 1);
setenv("https_proxy", "", 1); setenv("https_proxy", "", 1);
auto ph_host = paddle::distributed::PSHost(ip_, port_, 0); auto ph_host = paddle::distributed::PSHost(ip_, port_, 0);
host_sign_list_.push_back(ph_host.serialize_to_string()); host_sign_list_.push_back(ph_host.SerializeToString());
// Srart Server // Srart Server
std::thread server_thread(RunServer); std::thread server_thread(RunServer);
...@@ -218,7 +218,7 @@ void RunBrpcPushDense() { ...@@ -218,7 +218,7 @@ void RunBrpcPushDense() {
paddle::distributed::Region temp_reg(temp, tensor->numel()); paddle::distributed::Region temp_reg(temp, tensor->numel());
temp_region.emplace_back(std::move(temp_reg)); temp_region.emplace_back(std::move(temp_reg));
auto pull_status = auto pull_status =
worker_ptr_->pull_dense(temp_region.data(), temp_region.size(), 0); worker_ptr_->PullDense(temp_region.data(), temp_region.size(), 0);
pull_status.wait(); pull_status.wait();
for (size_t idx = 0; idx < tensor->numel(); ++idx) { for (size_t idx = 0; idx < tensor->numel(); ++idx) {
...@@ -229,10 +229,10 @@ void RunBrpcPushDense() { ...@@ -229,10 +229,10 @@ void RunBrpcPushDense() {
LOG(INFO) << "Run push_dense_param"; LOG(INFO) << "Run push_dense_param";
auto push_status = auto push_status =
worker_ptr_->push_dense_param(regions.data(), regions.size(), 0); worker_ptr_->PushDenseParam(regions.data(), regions.size(), 0);
push_status.wait(); push_status.wait();
pull_status = worker_ptr_->pull_dense(regions.data(), regions.size(), 0); pull_status = worker_ptr_->PullDense(regions.data(), regions.size(), 0);
pull_status.wait(); pull_status.wait();
for (size_t idx = 0; idx < tensor->numel(); ++idx) { for (size_t idx = 0; idx < tensor->numel(); ++idx) {
...@@ -257,11 +257,11 @@ void RunBrpcPushDense() { ...@@ -257,11 +257,11 @@ void RunBrpcPushDense() {
LOG(INFO) << "Run pull_dense_grad"; LOG(INFO) << "Run pull_dense_grad";
auto push_grad_status = auto push_grad_status =
worker_ptr_->push_dense_raw_gradient(0, temp, tensor->numel(), closure); worker_ptr_->PushDenseRawGradient(0, temp, tensor->numel(), closure);
push_grad_status.wait(); push_grad_status.wait();
auto pull_update_status = auto pull_update_status =
worker_ptr_->pull_dense(regions.data(), regions.size(), 0); worker_ptr_->PullDense(regions.data(), regions.size(), 0);
pull_update_status.wait(); pull_update_status.wait();
for (size_t idx = 0; idx < tensor->numel(); ++idx) { for (size_t idx = 0; idx < tensor->numel(); ++idx) {
...@@ -269,9 +269,9 @@ void RunBrpcPushDense() { ...@@ -269,9 +269,9 @@ void RunBrpcPushDense() {
} }
LOG(INFO) << "Run stop_server"; LOG(INFO) << "Run stop_server";
worker_ptr_->stop_server(); worker_ptr_->StopServer();
LOG(INFO) << "Run finalize_worker"; LOG(INFO) << "Run finalize_worker";
worker_ptr_->finalize_worker(); worker_ptr_->FinalizeWorker();
server_thread.join(); server_thread.join();
} }
......
...@@ -156,14 +156,14 @@ void RunServer() { ...@@ -156,14 +156,14 @@ void RunServer() {
::paddle::distributed::PSParameter server_proto = GetServerProto(); ::paddle::distributed::PSParameter server_proto = GetServerProto();
auto _ps_env = paddle::distributed::PaddlePSEnvironment(); auto _ps_env = paddle::distributed::PaddlePSEnvironment();
_ps_env.set_ps_servers(&host_sign_list_, 1); _ps_env.SetPsServers(&host_sign_list_, 1);
pserver_ptr_ = std::shared_ptr<paddle::distributed::PSServer>( pserver_ptr_ = std::shared_ptr<paddle::distributed::PSServer>(
paddle::distributed::PSServerFactory::create(server_proto)); paddle::distributed::PSServerFactory::Create(server_proto));
std::vector<framework::ProgramDesc> empty_vec; std::vector<framework::ProgramDesc> empty_vec;
framework::ProgramDesc empty_prog; framework::ProgramDesc empty_prog;
empty_vec.push_back(empty_prog); empty_vec.push_back(empty_prog);
pserver_ptr_->configure(server_proto, _ps_env, 0, empty_vec); pserver_ptr_->Configure(server_proto, _ps_env, 0, empty_vec);
pserver_ptr_->start(ip_, port_); pserver_ptr_->Start(ip_, port_);
} }
void RunClient(std::map<uint64_t, std::vector<paddle::distributed::Region>>& void RunClient(std::map<uint64_t, std::vector<paddle::distributed::Region>>&
...@@ -172,17 +172,17 @@ void RunClient(std::map<uint64_t, std::vector<paddle::distributed::Region>>& ...@@ -172,17 +172,17 @@ void RunClient(std::map<uint64_t, std::vector<paddle::distributed::Region>>&
paddle::distributed::PaddlePSEnvironment _ps_env; paddle::distributed::PaddlePSEnvironment _ps_env;
auto servers_ = host_sign_list_.size(); auto servers_ = host_sign_list_.size();
_ps_env = paddle::distributed::PaddlePSEnvironment(); _ps_env = paddle::distributed::PaddlePSEnvironment();
_ps_env.set_ps_servers(&host_sign_list_, servers_); _ps_env.SetPsServers(&host_sign_list_, servers_);
worker_ptr_ = std::shared_ptr<paddle::distributed::PSClient>( worker_ptr_ = std::shared_ptr<paddle::distributed::PSClient>(
paddle::distributed::PSClientFactory::create(worker_proto)); paddle::distributed::PSClientFactory::Create(worker_proto));
worker_ptr_->configure(worker_proto, dense_regions, _ps_env, 0); worker_ptr_->Configure(worker_proto, dense_regions, _ps_env, 0);
} }
void RunBrpcPushSparse() { void RunBrpcPushSparse() {
setenv("http_proxy", "", 1); setenv("http_proxy", "", 1);
setenv("https_proxy", "", 1); setenv("https_proxy", "", 1);
auto ph_host = paddle::distributed::PSHost(ip_, port_, 0); auto ph_host = paddle::distributed::PSHost(ip_, port_, 0);
host_sign_list_.push_back(ph_host.serialize_to_string()); host_sign_list_.push_back(ph_host.SerializeToString());
// Srart Server // Srart Server
std::thread server_thread(RunServer); std::thread server_thread(RunServer);
...@@ -214,7 +214,7 @@ void RunBrpcPushSparse() { ...@@ -214,7 +214,7 @@ void RunBrpcPushSparse() {
/*-----------------------Test Server Init----------------------------------*/ /*-----------------------Test Server Init----------------------------------*/
LOG(INFO) << "Run pull_sparse_param"; LOG(INFO) << "Run pull_sparse_param";
auto pull_status = worker_ptr_->pull_sparse( auto pull_status = worker_ptr_->PullSparse(
fea_value_ptr.data(), 0, fea_keys.data(), fea_keys.size(), true); fea_value_ptr.data(), 0, fea_keys.data(), fea_keys.size(), true);
pull_status.wait(); pull_status.wait();
for (size_t idx = 0; idx < tensor->numel(); ++idx) { for (size_t idx = 0; idx < tensor->numel(); ++idx) {
...@@ -237,12 +237,12 @@ void RunBrpcPushSparse() { ...@@ -237,12 +237,12 @@ void RunBrpcPushSparse() {
} }
closure->set_promise_value(ret); closure->set_promise_value(ret);
}); });
auto push_status = worker_ptr_->push_sparse_param( auto push_status = worker_ptr_->PushSparseParam(
0, fea_keys.data(), (const float**)fea_value_ptr.data(), fea_keys.size(), 0, fea_keys.data(), (const float**)fea_value_ptr.data(), fea_keys.size(),
closure_push_param); closure_push_param);
push_status.wait(); push_status.wait();
auto pull_param_status = worker_ptr_->pull_sparse( auto pull_param_status = worker_ptr_->PullSparse(
fea_temp_value_ptr.data(), 0, fea_keys.data(), fea_keys.size(), true); fea_temp_value_ptr.data(), 0, fea_keys.data(), fea_keys.size(), true);
pull_param_status.wait(); pull_param_status.wait();
...@@ -271,12 +271,12 @@ void RunBrpcPushSparse() { ...@@ -271,12 +271,12 @@ void RunBrpcPushSparse() {
for (auto i = 0; i < static_cast<int>(fea_keys.size()); ++i) { for (auto i = 0; i < static_cast<int>(fea_keys.size()); ++i) {
push_g_vec.push_back(tensor->data<float>() + i * 10); push_g_vec.push_back(tensor->data<float>() + i * 10);
} }
auto push_grad_status = worker_ptr_->push_sparse_raw_gradient( auto push_grad_status = worker_ptr_->PushSparseRawGradient(
0, fea_keys.data(), (const float**)push_g_vec.data(), fea_keys.size(), 0, fea_keys.data(), (const float**)push_g_vec.data(), fea_keys.size(),
closure_push_grad); closure_push_grad);
push_grad_status.wait(); push_grad_status.wait();
auto pull_update_status = worker_ptr_->pull_sparse( auto pull_update_status = worker_ptr_->PullSparse(
fea_temp_value_ptr.data(), 0, fea_keys.data(), fea_keys.size(), true); fea_temp_value_ptr.data(), 0, fea_keys.data(), fea_keys.size(), true);
pull_update_status.wait(); pull_update_status.wait();
...@@ -285,9 +285,9 @@ void RunBrpcPushSparse() { ...@@ -285,9 +285,9 @@ void RunBrpcPushSparse() {
} }
LOG(INFO) << "Run stop_server"; LOG(INFO) << "Run stop_server";
worker_ptr_->stop_server(); worker_ptr_->StopServer();
LOG(INFO) << "Run finalize_worker"; LOG(INFO) << "Run finalize_worker";
worker_ptr_->finalize_worker(); worker_ptr_->FinalizeWorker();
server_thread.join(); server_thread.join();
} }
......
...@@ -63,13 +63,13 @@ TEST(CommonDenseTable, Adam) { ...@@ -63,13 +63,13 @@ TEST(CommonDenseTable, Adam) {
common_config->add_params("LearningRate"); common_config->add_params("LearningRate");
common_config->add_dims(1); common_config->add_dims(1);
common_config->add_initializers("fill_constant&5e-6"); common_config->add_initializers("fill_constant&5e-6");
auto ret = table->initialize(table_config, fs_config); auto ret = table->Initialize(table_config, fs_config);
ASSERT_EQ(ret, 0); ASSERT_EQ(ret, 0);
// pull parameters for create and check // pull parameters for create and check
std::vector<float> init_values; std::vector<float> init_values;
init_values.resize(fea_dim); init_values.resize(fea_dim);
table->pull_dense(init_values.data(), fea_dim); table->PullDense(init_values.data(), fea_dim);
// push gradient // push gradient
std::vector<std::vector<float>> trainer_gradient_values; std::vector<std::vector<float>> trainer_gradient_values;
...@@ -85,12 +85,12 @@ TEST(CommonDenseTable, Adam) { ...@@ -85,12 +85,12 @@ TEST(CommonDenseTable, Adam) {
// for adam // for adam
for (int i = 0; i < trainers; i++) { for (int i = 0; i < trainers; i++) {
auto &push_values = trainer_gradient_values[i]; auto &push_values = trainer_gradient_values[i];
table->push_dense(push_values.data(), push_values.size()); table->PushDense(push_values.data(), push_values.size());
} }
std::vector<float> pull_values; std::vector<float> pull_values;
pull_values.resize(fea_dim); pull_values.resize(fea_dim);
table->pull_dense(pull_values.data(), fea_dim); table->PullDense(pull_values.data(), fea_dim);
float mom_rate = 0.99; float mom_rate = 0.99;
float decay_rate = 0.9999; float decay_rate = 0.9999;
...@@ -118,6 +118,7 @@ TEST(CommonDenseTable, Adam) { ...@@ -118,6 +118,7 @@ TEST(CommonDenseTable, Adam) {
} }
} }
for (int j = 0; j < fea_dim; j++) { for (int j = 0; j < fea_dim; j++) {
VLOG(0) << param[j] << " " << pull_values[j];
ASSERT_TRUE(abs(param[j] - pull_values[j]) < 1e-5); ASSERT_TRUE(abs(param[j] - pull_values[j]) < 1e-5);
} }
} }
...@@ -143,13 +144,13 @@ TEST(CommonDenseTable, SGD) { ...@@ -143,13 +144,13 @@ TEST(CommonDenseTable, SGD) {
common_config->add_params("LearningRate"); common_config->add_params("LearningRate");
common_config->add_dims(1); common_config->add_dims(1);
common_config->add_initializers("fill_constant&1.0"); common_config->add_initializers("fill_constant&1.0");
auto ret = table->initialize(table_config, fs_config); auto ret = table->Initialize(table_config, fs_config);
ASSERT_EQ(ret, 0); ASSERT_EQ(ret, 0);
// pull parameters for create and check // pull parameters for create and check
std::vector<float> init_values; std::vector<float> init_values;
init_values.resize(fea_dim); init_values.resize(fea_dim);
table->pull_dense(init_values.data(), fea_dim); table->PullDense(init_values.data(), fea_dim);
std::vector<float> total_gradients; std::vector<float> total_gradients;
total_gradients.resize(fea_dim); total_gradients.resize(fea_dim);
...@@ -172,7 +173,7 @@ TEST(CommonDenseTable, SGD) { ...@@ -172,7 +173,7 @@ TEST(CommonDenseTable, SGD) {
for (int i = 0; i < trainers; i++) { for (int i = 0; i < trainers; i++) {
auto &push_values = trainer_gradient_values[i]; auto &push_values = trainer_gradient_values[i];
auto task = [table, &push_values] { auto task = [table, &push_values] {
table->push_dense(push_values.data(), push_values.size()); table->PushDense(push_values.data(), push_values.size());
}; };
task_status.push_back(pool_->enqueue(std::move(task))); task_status.push_back(pool_->enqueue(std::move(task)));
} }
...@@ -182,7 +183,7 @@ TEST(CommonDenseTable, SGD) { ...@@ -182,7 +183,7 @@ TEST(CommonDenseTable, SGD) {
std::vector<float> pull_values; std::vector<float> pull_values;
pull_values.resize(fea_dim); pull_values.resize(fea_dim);
table->pull_dense(pull_values.data(), fea_dim); table->PullDense(pull_values.data(), fea_dim);
for (int j = 0; j < fea_dim; j++) { for (int j = 0; j < fea_dim; j++) {
auto update_val = init_values[j] - 1.0 * total_gradients[j]; auto update_val = init_values[j] - 1.0 * total_gradients[j];
ASSERT_TRUE(abs(update_val - pull_values[j]) < 1e-5); ASSERT_TRUE(abs(update_val - pull_values[j]) < 1e-5);
......
...@@ -166,16 +166,16 @@ void RunServer() { ...@@ -166,16 +166,16 @@ void RunServer() {
::paddle::distributed::PSParameter server_proto = GetServerProto(); ::paddle::distributed::PSParameter server_proto = GetServerProto();
auto _ps_env = paddle::distributed::PaddlePSEnvironment(); auto _ps_env = paddle::distributed::PaddlePSEnvironment();
_ps_env.set_ps_servers(&host_sign_list_, 2); // test _ps_env.SetPsServers(&host_sign_list_, 2); // test
pserver_ptr_ = std::shared_ptr<paddle::distributed::GraphBrpcServer>( pserver_ptr_ = std::shared_ptr<paddle::distributed::GraphBrpcServer>(
(paddle::distributed::GraphBrpcServer*) (paddle::distributed::GraphBrpcServer*)
paddle::distributed::PSServerFactory::create(server_proto)); paddle::distributed::PSServerFactory::Create(server_proto));
std::vector<framework::ProgramDesc> empty_vec; std::vector<framework::ProgramDesc> empty_vec;
framework::ProgramDesc empty_prog; framework::ProgramDesc empty_prog;
empty_vec.push_back(empty_prog); empty_vec.push_back(empty_prog);
pserver_ptr_->configure(server_proto, _ps_env, 0, empty_vec); pserver_ptr_->Configure(server_proto, _ps_env, 0, empty_vec);
LOG(INFO) << "first server, run start(ip,port)"; LOG(INFO) << "first server, run start(ip,port)";
pserver_ptr_->start(ip_, port_); pserver_ptr_->Start(ip_, port_);
pserver_ptr_->build_peer2peer_connection(0); pserver_ptr_->build_peer2peer_connection(0);
LOG(INFO) << "init first server Done"; LOG(INFO) << "init first server Done";
} }
...@@ -185,15 +185,15 @@ void RunServer2() { ...@@ -185,15 +185,15 @@ void RunServer2() {
::paddle::distributed::PSParameter server_proto2 = GetServerProto(); ::paddle::distributed::PSParameter server_proto2 = GetServerProto();
auto _ps_env2 = paddle::distributed::PaddlePSEnvironment(); auto _ps_env2 = paddle::distributed::PaddlePSEnvironment();
_ps_env2.set_ps_servers(&host_sign_list_, 2); // test _ps_env2.SetPsServers(&host_sign_list_, 2); // test
pserver_ptr2 = std::shared_ptr<paddle::distributed::GraphBrpcServer>( pserver_ptr2 = std::shared_ptr<paddle::distributed::GraphBrpcServer>(
(paddle::distributed::GraphBrpcServer*) (paddle::distributed::GraphBrpcServer*)
paddle::distributed::PSServerFactory::create(server_proto2)); paddle::distributed::PSServerFactory::Create(server_proto2));
std::vector<framework::ProgramDesc> empty_vec2; std::vector<framework::ProgramDesc> empty_vec2;
framework::ProgramDesc empty_prog2; framework::ProgramDesc empty_prog2;
empty_vec2.push_back(empty_prog2); empty_vec2.push_back(empty_prog2);
pserver_ptr2->configure(server_proto2, _ps_env2, 1, empty_vec2); pserver_ptr2->Configure(server_proto2, _ps_env2, 1, empty_vec2);
pserver_ptr2->start(ip2, port2); pserver_ptr2->Start(ip2, port2);
pserver_ptr2->build_peer2peer_connection(1); pserver_ptr2->build_peer2peer_connection(1);
} }
...@@ -204,11 +204,11 @@ void RunClient( ...@@ -204,11 +204,11 @@ void RunClient(
paddle::distributed::PaddlePSEnvironment _ps_env; paddle::distributed::PaddlePSEnvironment _ps_env;
auto servers_ = host_sign_list_.size(); auto servers_ = host_sign_list_.size();
_ps_env = paddle::distributed::PaddlePSEnvironment(); _ps_env = paddle::distributed::PaddlePSEnvironment();
_ps_env.set_ps_servers(&host_sign_list_, servers_); _ps_env.SetPsServers(&host_sign_list_, servers_);
worker_ptr_ = std::shared_ptr<paddle::distributed::GraphBrpcClient>( worker_ptr_ = std::shared_ptr<paddle::distributed::GraphBrpcClient>(
(paddle::distributed::GraphBrpcClient*) (paddle::distributed::GraphBrpcClient*)
paddle::distributed::PSClientFactory::create(worker_proto)); paddle::distributed::PSClientFactory::Create(worker_proto));
worker_ptr_->configure(worker_proto, dense_regions, _ps_env, 0); worker_ptr_->Configure(worker_proto, dense_regions, _ps_env, 0);
worker_ptr_->set_shard_num(127); worker_ptr_->set_shard_num(127);
worker_ptr_->set_local_channel(index); worker_ptr_->set_local_channel(index);
worker_ptr_->set_local_graph_service( worker_ptr_->set_local_graph_service(
...@@ -222,11 +222,11 @@ void RunGraphSplit() { ...@@ -222,11 +222,11 @@ void RunGraphSplit() {
prepare_file(node_file_name, nodes); prepare_file(node_file_name, nodes);
prepare_file(graph_split_file_name, graph_split); prepare_file(graph_split_file_name, graph_split);
auto ph_host = paddle::distributed::PSHost(ip_, port_, 0); auto ph_host = paddle::distributed::PSHost(ip_, port_, 0);
host_sign_list_.push_back(ph_host.serialize_to_string()); host_sign_list_.push_back(ph_host.SerializeToString());
// test-start // test-start
auto ph_host2 = paddle::distributed::PSHost(ip2, port2, 1); auto ph_host2 = paddle::distributed::PSHost(ip2, port2, 1);
host_sign_list_.push_back(ph_host2.serialize_to_string()); host_sign_list_.push_back(ph_host2.SerializeToString());
// test-end // test-end
// Srart Server // Srart Server
std::thread* server_thread = new std::thread(RunServer); std::thread* server_thread = new std::thread(RunServer);
...@@ -247,7 +247,7 @@ void RunGraphSplit() { ...@@ -247,7 +247,7 @@ void RunGraphSplit() {
0, std::string(graph_split_file_name)); 0, std::string(graph_split_file_name));
pull_status.wait(); pull_status.wait();
pull_status = pull_status =
worker_ptr_->load(0, std::string(edge_file_name), std::string("e>")); worker_ptr_->Load(0, std::string(edge_file_name), std::string("e>"));
srand(time(0)); srand(time(0));
pull_status.wait(); pull_status.wait();
std::vector<std::vector<int64_t>> _vs; std::vector<std::vector<int64_t>> _vs;
...@@ -266,9 +266,9 @@ void RunGraphSplit() { ...@@ -266,9 +266,9 @@ void RunGraphSplit() {
std::remove(node_file_name); std::remove(node_file_name);
std::remove(graph_split_file_name); std::remove(graph_split_file_name);
LOG(INFO) << "Run stop_server"; LOG(INFO) << "Run stop_server";
worker_ptr_->stop_server(); worker_ptr_->StopServer();
LOG(INFO) << "Run finalize_worker"; LOG(INFO) << "Run finalize_worker";
worker_ptr_->finalize_worker(); worker_ptr_->FinalizeWorker();
} }
TEST(RunGraphSplit, Run) { RunGraphSplit(); } TEST(RunGraphSplit, Run) { RunGraphSplit(); }
...@@ -348,16 +348,16 @@ void RunServer() { ...@@ -348,16 +348,16 @@ void RunServer() {
::paddle::distributed::PSParameter server_proto = GetServerProto(); ::paddle::distributed::PSParameter server_proto = GetServerProto();
auto _ps_env = paddle::distributed::PaddlePSEnvironment(); auto _ps_env = paddle::distributed::PaddlePSEnvironment();
_ps_env.set_ps_servers(&host_sign_list_, 2); // test _ps_env.SetPsServers(&host_sign_list_, 2); // test
pserver_ptr_ = std::shared_ptr<paddle::distributed::GraphBrpcServer>( pserver_ptr_ = std::shared_ptr<paddle::distributed::GraphBrpcServer>(
(paddle::distributed::GraphBrpcServer*) (paddle::distributed::GraphBrpcServer*)
paddle::distributed::PSServerFactory::create(server_proto)); paddle::distributed::PSServerFactory::Create(server_proto));
std::vector<framework::ProgramDesc> empty_vec; std::vector<framework::ProgramDesc> empty_vec;
framework::ProgramDesc empty_prog; framework::ProgramDesc empty_prog;
empty_vec.push_back(empty_prog); empty_vec.push_back(empty_prog);
pserver_ptr_->configure(server_proto, _ps_env, 0, empty_vec); pserver_ptr_->Configure(server_proto, _ps_env, 0, empty_vec);
LOG(INFO) << "first server, run start(ip,port)"; LOG(INFO) << "first server, run start(ip,port)";
pserver_ptr_->start(ip_, port_); pserver_ptr_->Start(ip_, port_);
pserver_ptr_->build_peer2peer_connection(0); pserver_ptr_->build_peer2peer_connection(0);
LOG(INFO) << "init first server Done"; LOG(INFO) << "init first server Done";
} }
...@@ -367,15 +367,15 @@ void RunServer2() { ...@@ -367,15 +367,15 @@ void RunServer2() {
::paddle::distributed::PSParameter server_proto2 = GetServerProto(); ::paddle::distributed::PSParameter server_proto2 = GetServerProto();
auto _ps_env2 = paddle::distributed::PaddlePSEnvironment(); auto _ps_env2 = paddle::distributed::PaddlePSEnvironment();
_ps_env2.set_ps_servers(&host_sign_list_, 2); // test _ps_env2.SetPsServers(&host_sign_list_, 2); // test
pserver_ptr2 = std::shared_ptr<paddle::distributed::GraphBrpcServer>( pserver_ptr2 = std::shared_ptr<paddle::distributed::GraphBrpcServer>(
(paddle::distributed::GraphBrpcServer*) (paddle::distributed::GraphBrpcServer*)
paddle::distributed::PSServerFactory::create(server_proto2)); paddle::distributed::PSServerFactory::Create(server_proto2));
std::vector<framework::ProgramDesc> empty_vec2; std::vector<framework::ProgramDesc> empty_vec2;
framework::ProgramDesc empty_prog2; framework::ProgramDesc empty_prog2;
empty_vec2.push_back(empty_prog2); empty_vec2.push_back(empty_prog2);
pserver_ptr2->configure(server_proto2, _ps_env2, 1, empty_vec2); pserver_ptr2->Configure(server_proto2, _ps_env2, 1, empty_vec2);
pserver_ptr2->start(ip2, port2); pserver_ptr2->Start(ip2, port2);
pserver_ptr2->build_peer2peer_connection(1); pserver_ptr2->build_peer2peer_connection(1);
} }
...@@ -386,11 +386,11 @@ void RunClient( ...@@ -386,11 +386,11 @@ void RunClient(
paddle::distributed::PaddlePSEnvironment _ps_env; paddle::distributed::PaddlePSEnvironment _ps_env;
auto servers_ = host_sign_list_.size(); auto servers_ = host_sign_list_.size();
_ps_env = paddle::distributed::PaddlePSEnvironment(); _ps_env = paddle::distributed::PaddlePSEnvironment();
_ps_env.set_ps_servers(&host_sign_list_, servers_); _ps_env.SetPsServers(&host_sign_list_, servers_);
worker_ptr_ = std::shared_ptr<paddle::distributed::GraphBrpcClient>( worker_ptr_ = std::shared_ptr<paddle::distributed::GraphBrpcClient>(
(paddle::distributed::GraphBrpcClient*) (paddle::distributed::GraphBrpcClient*)
paddle::distributed::PSClientFactory::create(worker_proto)); paddle::distributed::PSClientFactory::Create(worker_proto));
worker_ptr_->configure(worker_proto, dense_regions, _ps_env, 0); worker_ptr_->Configure(worker_proto, dense_regions, _ps_env, 0);
worker_ptr_->set_shard_num(127); worker_ptr_->set_shard_num(127);
worker_ptr_->set_local_channel(index); worker_ptr_->set_local_channel(index);
worker_ptr_->set_local_graph_service( worker_ptr_->set_local_graph_service(
...@@ -404,11 +404,11 @@ void RunBrpcPushSparse() { ...@@ -404,11 +404,11 @@ void RunBrpcPushSparse() {
prepare_file(edge_file_name, 1); prepare_file(edge_file_name, 1);
prepare_file(node_file_name, 0); prepare_file(node_file_name, 0);
auto ph_host = paddle::distributed::PSHost(ip_, port_, 0); auto ph_host = paddle::distributed::PSHost(ip_, port_, 0);
host_sign_list_.push_back(ph_host.serialize_to_string()); host_sign_list_.push_back(ph_host.SerializeToString());
// test-start // test-start
auto ph_host2 = paddle::distributed::PSHost(ip2, port2, 1); auto ph_host2 = paddle::distributed::PSHost(ip2, port2, 1);
host_sign_list_.push_back(ph_host2.serialize_to_string()); host_sign_list_.push_back(ph_host2.SerializeToString());
// test-end // test-end
// Srart Server // Srart Server
std::thread* server_thread = new std::thread(RunServer); std::thread* server_thread = new std::thread(RunServer);
...@@ -424,7 +424,7 @@ void RunBrpcPushSparse() { ...@@ -424,7 +424,7 @@ void RunBrpcPushSparse() {
/*-----------------------Test Server Init----------------------------------*/ /*-----------------------Test Server Init----------------------------------*/
auto pull_status = auto pull_status =
worker_ptr_->load(0, std::string(edge_file_name), std::string("e>")); worker_ptr_->Load(0, std::string(edge_file_name), std::string("e>"));
srand(time(0)); srand(time(0));
pull_status.wait(); pull_status.wait();
std::vector<std::vector<int64_t>> _vs; std::vector<std::vector<int64_t>> _vs;
...@@ -438,7 +438,7 @@ void RunBrpcPushSparse() { ...@@ -438,7 +438,7 @@ void RunBrpcPushSparse() {
pull_status.wait(); pull_status.wait();
ASSERT_EQ(0, _vs[0].size()); ASSERT_EQ(0, _vs[0].size());
paddle::distributed::GraphTable* g = paddle::distributed::GraphTable* g =
(paddle::distributed::GraphTable*)pserver_ptr_->table(0); (paddle::distributed::GraphTable*)pserver_ptr_->GetTable(0);
size_t ttl = 6; size_t ttl = 6;
g->make_neighbor_sample_cache(4, ttl); g->make_neighbor_sample_cache(4, ttl);
int round = 5; int round = 5;
...@@ -622,15 +622,15 @@ void RunBrpcPushSparse() { ...@@ -622,15 +622,15 @@ void RunBrpcPushSparse() {
std::remove(node_file_name); std::remove(node_file_name);
testAddNode(worker_ptr_); testAddNode(worker_ptr_);
LOG(INFO) << "Run stop_server"; LOG(INFO) << "Run stop_server";
worker_ptr_->stop_server(); worker_ptr_->StopServer();
LOG(INFO) << "Run finalize_worker"; LOG(INFO) << "Run finalize_worker";
worker_ptr_->finalize_worker(); worker_ptr_->FinalizeWorker();
testFeatureNodeSerializeInt(); testFeatureNodeSerializeInt();
testFeatureNodeSerializeInt64(); testFeatureNodeSerializeInt64();
testFeatureNodeSerializeFloat32(); testFeatureNodeSerializeFloat32();
testFeatureNodeSerializeFloat64(); testFeatureNodeSerializeFloat64();
testGraphToBuffer(); testGraphToBuffer();
client1.stop_server(); client1.StopServer();
} }
void testCache() { void testCache() {
...@@ -700,4 +700,4 @@ void testGraphToBuffer() { ...@@ -700,4 +700,4 @@ void testGraphToBuffer() {
VLOG(0) << s1.get_feature(0); VLOG(0) << s1.get_feature(0);
} }
TEST(RunBrpcPushSparse, Run) { RunBrpcPushSparse(); } TEST(RunBrpcPushSparse, Run) { RunBrpcPushSparse(); }
\ No newline at end of file
...@@ -48,7 +48,7 @@ TEST(MemorySparseGeoTable, SSUM) { ...@@ -48,7 +48,7 @@ TEST(MemorySparseGeoTable, SSUM) {
common_config->add_dims(emb_dim); common_config->add_dims(emb_dim);
common_config->add_initializers("fill_constant&1.0"); common_config->add_initializers("fill_constant&1.0");
auto ret = table->initialize(table_config, fs_config); auto ret = table->Initialize(table_config, fs_config);
ASSERT_EQ(ret, 0); ASSERT_EQ(ret, 0);
// test push_sparse_param, and create params // test push_sparse_param, and create params
...@@ -58,12 +58,12 @@ TEST(MemorySparseGeoTable, SSUM) { ...@@ -58,12 +58,12 @@ TEST(MemorySparseGeoTable, SSUM) {
for (size_t i = 0; i < init_keys.size() * emb_dim; i++) { for (size_t i = 0; i < init_keys.size() * emb_dim; i++) {
init_values.push_back(0.0); init_values.push_back(0.0);
} }
table->push_sparse_param(init_keys.data(), init_values.data(), table->PushSparseParam(init_keys.data(), init_values.data(),
init_keys.size()); init_keys.size());
std::vector<float> pull_values(init_values.size()); std::vector<float> pull_values(init_values.size());
auto value = PullSparseValue(init_keys, init_fres, emb_dim); auto value = PullSparseValue(init_keys, init_fres, emb_dim);
table->pull_sparse(pull_values.data(), value); table->PullSparse(pull_values.data(), value);
for (size_t i = 0; i < init_keys.size() * emb_dim; i++) { for (size_t i = 0; i < init_keys.size() * emb_dim; i++) {
ASSERT_TRUE(abs(pull_values[i] - init_values[i]) < 1e-5); ASSERT_TRUE(abs(pull_values[i] - init_values[i]) < 1e-5);
...@@ -93,8 +93,7 @@ TEST(MemorySparseGeoTable, SSUM) { ...@@ -93,8 +93,7 @@ TEST(MemorySparseGeoTable, SSUM) {
auto &push_keys = trainer_keys[i]; auto &push_keys = trainer_keys[i];
auto &push_values = trainer_values[i]; auto &push_values = trainer_values[i];
auto task = [table, &push_keys, &push_values] { auto task = [table, &push_keys, &push_values] {
table->push_sparse(push_keys.data(), push_values.data(), table->PushSparse(push_keys.data(), push_values.data(), push_keys.size());
push_keys.size());
}; };
task_status.push_back(pool_->enqueue(std::move(task))); task_status.push_back(pool_->enqueue(std::move(task)));
} }
...@@ -107,7 +106,7 @@ TEST(MemorySparseGeoTable, SSUM) { ...@@ -107,7 +106,7 @@ TEST(MemorySparseGeoTable, SSUM) {
geo_pull_ids.resize(trainers); geo_pull_ids.resize(trainers);
geo_pull_values.resize(trainers); geo_pull_values.resize(trainers);
for (int i = 0; i < trainers; i++) { for (int i = 0; i < trainers; i++) {
table->pull_geo_param(i, &geo_pull_values[i], &geo_pull_ids[i]); table->PullGeoParam(i, &geo_pull_values[i], &geo_pull_ids[i]);
ASSERT_EQ(geo_pull_values[i].size(), geo_pull_ids[i].size() * emb_dim); ASSERT_EQ(geo_pull_values[i].size(), geo_pull_ids[i].size() * emb_dim);
for (size_t j = 0; j < geo_pull_ids[i].size(); ++j) { for (size_t j = 0; j < geo_pull_ids[i].size(); ++j) {
auto id = geo_pull_ids[i][j]; auto id = geo_pull_ids[i][j];
......
...@@ -36,7 +36,7 @@ TEST(MemorySparseTable, SGD) { ...@@ -36,7 +36,7 @@ TEST(MemorySparseTable, SGD) {
table_config.set_shard_num(10); table_config.set_shard_num(10);
FsClientParameter fs_config; FsClientParameter fs_config;
Table *table = new MemorySparseTable(); Table *table = new MemorySparseTable();
table->set_shard(0, 1); table->SetShard(0, 1);
TableAccessorParameter *accessor_config = table_config.mutable_accessor(); TableAccessorParameter *accessor_config = table_config.mutable_accessor();
accessor_config->set_accessor_class("CtrCommonAccessor"); accessor_config->set_accessor_class("CtrCommonAccessor");
...@@ -66,7 +66,7 @@ TEST(MemorySparseTable, SGD) { ...@@ -66,7 +66,7 @@ TEST(MemorySparseTable, SGD) {
naive_param->add_weight_bounds(-10.0); naive_param->add_weight_bounds(-10.0);
naive_param->add_weight_bounds(10.0); naive_param->add_weight_bounds(10.0);
auto ret = table->initialize(table_config, fs_config); auto ret = table->Initialize(table_config, fs_config);
ASSERT_EQ(ret, 0); ASSERT_EQ(ret, 0);
// pull parameters for create and check // pull parameters for create and check
...@@ -76,7 +76,7 @@ TEST(MemorySparseTable, SGD) { ...@@ -76,7 +76,7 @@ TEST(MemorySparseTable, SGD) {
std::vector<float> init_values; std::vector<float> init_values;
init_values.resize(init_keys.size() * (emb_dim + 3)); init_values.resize(init_keys.size() * (emb_dim + 3));
auto value = PullSparseValue(init_keys, init_fres, emb_dim); auto value = PullSparseValue(init_keys, init_fres, emb_dim);
table->pull_sparse(init_values.data(), value); table->PullSparse(init_values.data(), value);
// for check // for check
std::vector<float> total_gradients; std::vector<float> total_gradients;
...@@ -109,8 +109,7 @@ TEST(MemorySparseTable, SGD) { ...@@ -109,8 +109,7 @@ TEST(MemorySparseTable, SGD) {
auto &push_keys = trainer_keys[i]; auto &push_keys = trainer_keys[i];
auto &push_values = trainer_gradient_values[i]; auto &push_values = trainer_gradient_values[i];
auto task = [table, &push_keys, &push_values] { auto task = [table, &push_keys, &push_values] {
table->push_sparse(push_keys.data(), push_values.data(), table->PushSparse(push_keys.data(), push_values.data(), push_keys.size());
push_keys.size());
}; };
task_status.push_back(pool_->enqueue(std::move(task))); task_status.push_back(pool_->enqueue(std::move(task)));
} }
...@@ -120,7 +119,7 @@ TEST(MemorySparseTable, SGD) { ...@@ -120,7 +119,7 @@ TEST(MemorySparseTable, SGD) {
std::vector<float> pull_values; std::vector<float> pull_values;
pull_values.resize(init_keys.size() * (emb_dim + 3)); pull_values.resize(init_keys.size() * (emb_dim + 3));
table->pull_sparse(pull_values.data(), value); table->PullSparse(pull_values.data(), value);
for (size_t i = 0; i < init_keys.size(); ++i) { for (size_t i = 0; i < init_keys.size(); ++i) {
for (size_t j = 2; j < emb_dim + 3; ++j) { for (size_t j = 2; j < emb_dim + 3; ++j) {
...@@ -133,7 +132,7 @@ TEST(MemorySparseTable, SGD) { ...@@ -133,7 +132,7 @@ TEST(MemorySparseTable, SGD) {
} }
MemorySparseTable *ctr_table = dynamic_cast<MemorySparseTable *>(table); MemorySparseTable *ctr_table = dynamic_cast<MemorySparseTable *>(table);
ctr_table->save_local_fs("./work/table.save", "0", "test"); ctr_table->SaveLocalFS("./work/table.save", "0", "test");
} }
} // namespace distributed } // namespace distributed
......
...@@ -26,7 +26,7 @@ TEST(Table, Initialize) { ...@@ -26,7 +26,7 @@ TEST(Table, Initialize) {
FsClientParameter fs_config; FsClientParameter fs_config;
// case 1. no accessor // case 1. no accessor
Table *table = new SparseGeoTable(); Table *table = new SparseGeoTable();
auto ret = table->initialize(table_config, fs_config); auto ret = table->Initialize(table_config, fs_config);
ASSERT_EQ(ret, -1); ASSERT_EQ(ret, -1);
} }
} // namespace distributed } // namespace distributed
......
...@@ -343,7 +343,7 @@ void PSGPUWrapper::BuildPull(std::shared_ptr<HeterContext> gpu_task) { ...@@ -343,7 +343,7 @@ void PSGPUWrapper::BuildPull(std::shared_ptr<HeterContext> gpu_task) {
#ifdef PADDLE_WITH_PSCORE #ifdef PADDLE_WITH_PSCORE
int32_t cnt = 0; int32_t cnt = 0;
while (true) { while (true) {
auto tt = fleet_ptr->worker_ptr_->pull_sparse_ptr( auto tt = fleet_ptr->worker_ptr_->PullSparsePtr(
reinterpret_cast<char**>(local_ptr[i].data()), this->table_id_, reinterpret_cast<char**>(local_ptr[i].data()), this->table_id_,
local_keys[i].data(), key_size); local_keys[i].data(), key_size);
bool flag = true; bool flag = true;
......
...@@ -276,7 +276,7 @@ void MultiTrainer::Finalize() { ...@@ -276,7 +276,7 @@ void MultiTrainer::Finalize() {
if (communicator == nullptr) { if (communicator == nullptr) {
VLOG(0) << "MultiTrainer::Finalize communicator is null!"; VLOG(0) << "MultiTrainer::Finalize communicator is null!";
} else { } else {
communicator->_worker_ptr->flush(); communicator->_worker_ptr->Flush();
VLOG(1) << "MultiTrainer::Finalize ps client flush done"; VLOG(1) << "MultiTrainer::Finalize ps client flush done";
} }
#endif #endif
......
...@@ -86,11 +86,11 @@ void BindDistFleetWrapper(py::module* m) { ...@@ -86,11 +86,11 @@ void BindDistFleetWrapper(py::module* m) {
void BindPSHost(py::module* m) { void BindPSHost(py::module* m) {
py::class_<distributed::PSHost>(*m, "PSHost") py::class_<distributed::PSHost>(*m, "PSHost")
.def(py::init<const std::string&, uint32_t, uint32_t>()) .def(py::init<const std::string&, uint32_t, uint32_t>())
.def("serialize_to_string", &distributed::PSHost::serialize_to_string) .def("serialize_to_string", &distributed::PSHost::SerializeToString)
.def("parse_from_string", &distributed::PSHost::parse_from_string) .def("parse_from_string", &distributed::PSHost::ParseFromString)
.def("to_uint64", &distributed::PSHost::serialize_to_uint64) .def("to_uint64", &distributed::PSHost::SerializeToUint64)
.def("from_uint64", &distributed::PSHost::parse_from_uint64) .def("from_uint64", &distributed::PSHost::ParseFromUint64)
.def("to_string", &distributed::PSHost::to_string); .def("to_string", &distributed::PSHost::ToString);
} }
void BindSparseShardingTools(py::module* m) { void BindSparseShardingTools(py::module* m) {
...@@ -224,7 +224,7 @@ void BindGraphPyClient(py::module* m) { ...@@ -224,7 +224,7 @@ void BindGraphPyClient(py::module* m) {
&GraphPyClient::use_neighbors_sample_cache) &GraphPyClient::use_neighbors_sample_cache)
.def("remove_graph_node", &GraphPyClient::remove_graph_node) .def("remove_graph_node", &GraphPyClient::remove_graph_node)
.def("random_sample_nodes", &GraphPyClient::random_sample_nodes) .def("random_sample_nodes", &GraphPyClient::random_sample_nodes)
.def("stop_server", &GraphPyClient::stop_server) .def("stop_server", &GraphPyClient::StopServer)
.def("get_node_feat", .def("get_node_feat",
[](GraphPyClient& self, std::string node_type, [](GraphPyClient& self, std::string node_type,
std::vector<int64_t> node_ids, std::vector<int64_t> node_ids,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册