Unverified commit b3270adf authored by zhaocaibei123, committed by GitHub

Unify PS refine (#41234)

* update name

* update name

* fix test

* fix fleet bind

* update name

* update name

* fix test

* fix gpups wrapper

* remove Push/Pull/Load/Save with context in client and wrapper base class

* fix

* fix
Co-authored-by: esythan <esythan@126.com>
Parent cb124156
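This commit is a mechanical snake_case-to-CamelCase rename across the PS client/server interface, plus removal of the context-based Push/Pull/Load/Save overloads from the client and wrapper base class. A minimal sketch of the call-site impact, using method names taken from the diff below (the surrounding client setup is hypothetical):

// Before this commit:                       // After this commit:
client->pull_dense(regions, n, table_id);    client->PullDense(regions, n, table_id);
client->push_sparse(tid, keys, vals, n);     client->PushSparse(tid, keys, vals, n);
client->finalize_worker();                   client->FinalizeWorker();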
......@@ -80,7 +80,7 @@ void DownpourPsClientService::service(
const PsRequestMessage *request, PsResponseMessage *response,
::google::protobuf::Closure *done) {
brpc::ClosureGuard done_guard(done);
int ret = _client->handle_client2client_msg(
int ret = _client->HandleClient2ClientMsg(
request->cmd_id(), request->client_id(), request->data());
response->set_err_code(0);
response->set_err_msg("");
......@@ -91,8 +91,8 @@ void DownpourPsClientService::service(
}
// Start the client-side RpcService, used for client-to-client data exchange
int32_t BrpcPsClient::start_client_service() {
if (_service.configure(this, _client_id) != 0) {
int32_t BrpcPsClient::StartClientService() {
if (_service.Configure(this, _client_id) != 0) {
LOG(ERROR)
<< "service initialize failed, service_name:DownpourPsClientService";
return -1;
......@@ -108,12 +108,12 @@ int32_t BrpcPsClient::start_client_service() {
return -1;
}
_server_started = true;
_env->registe_ps_client(butil::my_ip_cstr(), _server.listen_address().port,
_env->RegistePsClient(butil::my_ip_cstr(), _server.listen_address().port,
_client_id);
return 0;
}
int32_t BrpcPsClient::create_client2client_connection(
int32_t BrpcPsClient::CreateClient2ClientConnection(
int pserver_timeout_ms, int pserver_connect_timeout_ms, int max_retry) {
brpc::ChannelOptions options;
options.protocol = "baidu_std";
......@@ -122,12 +122,12 @@ int32_t BrpcPsClient::create_client2client_connection(
options.connect_timeout_ms = pserver_connect_timeout_ms;
options.max_retry = max_retry;
std::vector<PSHost> client_list = _env->get_ps_clients();
std::vector<PSHost> client_list = _env->GetPsClients();
VLOG(1) << "BrpcPsClient::create_c2c_connection client_list size: "
<< client_list.size();
for (auto cc : client_list) {
VLOG(1) << "BrpcPsClient::create_c2c_connection client_list: "
<< cc.to_string();
<< cc.ToString();
}
_client_channels.resize(client_list.size());
std::ostringstream os;
......@@ -154,7 +154,7 @@ int32_t BrpcPsClient::create_client2client_connection(
return 0;
}
int32_t BrpcPsClient::initialize() {
int32_t BrpcPsClient::Initialize() {
_async_call_num = 0;
brpc::ChannelOptions options;
......@@ -169,7 +169,7 @@ int32_t BrpcPsClient::initialize() {
std::string client_ip(butil::my_ip_cstr());
// Fetch the server list and connect to each server
std::vector<PSHost> server_list = _env->get_ps_servers();
std::vector<PSHost> server_list = _env->GetPsServers();
_server_channels.resize(server_list.size());
for (size_t i = 0; i < server_list.size(); ++i) {
server_ip_port.assign(server_list[i].ip.c_str());
......@@ -194,7 +194,7 @@ int32_t BrpcPsClient::initialize() {
os << server_ip_port << ",";
}
// Start the client listening service and establish connections between clients
start_client_service();
StartClientService();
// Initialize the async push request queues
const auto &worker_param = _config.worker_param().downpour_worker_param();
......@@ -234,13 +234,13 @@ int32_t BrpcPsClient::initialize() {
_flushing = false;
// Start the async push threads
_async_push_sparse_thread =
std::thread(std::bind(&BrpcPsClient::push_sparse_task_consume, this));
std::thread(std::bind(&BrpcPsClient::PushSparseTaskConsume, this));
// _async_push_sparse_thread.detach();
_async_push_dense_thread =
std::thread(std::bind(&BrpcPsClient::push_dense_task_consume, this));
std::thread(std::bind(&BrpcPsClient::PushDenseTaskConsume, this));
// for debug
// _print_thread =
// std::thread(std::bind(&BrpcPsClient::print_queue_size_thread, this));
// std::thread(std::bind(&BrpcPsClient::PrintQueueSizeThread, this));
return 0;
}
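Initialize ends by spawning PushSparseTaskConsume and PushDenseTaskConsume as dedicated consumer threads over the per-table task queues; FinalizeWorker later joins them. A self-contained sketch of that producer/consumer shape (names are hypothetical stand-ins, and the real consumers poll the queues on a timed interval, see FLAGS_pserver_async_push_dense_interval_ms further down, rather than using a condition variable):

#include <condition_variable>
#include <mutex>
#include <queue>
#include <thread>

std::queue<int> tasks;   // stand-in for a per-table push task queue
std::mutex mu;
std::condition_variable cv;
bool running = true;     // mirrors _running in BrpcPsClient

void Consume() {         // analog of Push*TaskConsume
  std::unique_lock<std::mutex> lk(mu);
  while (running) {
    cv.wait(lk, [] { return !tasks.empty() || !running; });
    while (!tasks.empty()) { tasks.pop(); /* merge + send RPC here */ }
  }
}

int main() {
  std::thread consumer(Consume);            // like _async_push_sparse_thread
  { std::lock_guard<std::mutex> lk(mu); tasks.push(1); }
  cv.notify_one();
  { std::lock_guard<std::mutex> lk(mu); running = false; }
  cv.notify_one();
  consumer.join();                          // FinalizeWorker joins these threads
}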
......@@ -286,7 +286,7 @@ std::string DownpourBrpcClosure::get_response(size_t request_idx, int cmd_id) {
return data;
}
std::future<int32_t> BrpcPsClient::print_table_stat(uint32_t table_id) {
std::future<int32_t> BrpcPsClient::PrintTableStat(uint32_t table_id) {
size_t request_call_num = _server_channels.size();
DownpourBrpcClosure *closure = new DownpourBrpcClosure(
request_call_num, [request_call_num, table_id](void *done) {
......@@ -319,7 +319,7 @@ std::future<int32_t> BrpcPsClient::print_table_stat(uint32_t table_id) {
closure->request(i)->set_cmd_id(PS_PRINT_TABLE_STAT);
closure->request(i)->set_table_id(table_id);
closure->request(i)->set_client_id(_client_id);
PsService_Stub rpc_stub(get_cmd_channel(i));
PsService_Stub rpc_stub(GetCmdChannel(i));
closure->cntl(i)->set_timeout_ms(
10800000); // cmd msg don't limit timeout for save/load
rpc_stub.service(closure->cntl(i), closure->request(i),
......@@ -327,7 +327,7 @@ std::future<int32_t> BrpcPsClient::print_table_stat(uint32_t table_id) {
}
return fut;
}
std::future<int32_t> BrpcPsClient::send_cmd(
std::future<int32_t> BrpcPsClient::SendCmd(
uint32_t table_id, int cmd_id, const std::vector<std::string> &params) {
size_t request_call_num = _server_channels.size();
DownpourBrpcClosure *closure = new DownpourBrpcClosure(
......@@ -352,7 +352,7 @@ std::future<int32_t> BrpcPsClient::send_cmd(
for (const auto &param : params) {
closure->request(i)->add_params(param);
}
PsService_Stub rpc_stub(get_cmd_channel(i));
PsService_Stub rpc_stub(GetCmdChannel(i));
closure->cntl(i)->set_timeout_ms(
10800000 * 2); // cmd msg don't limit timeout for save/load
rpc_stub.service(closure->cntl(i), closure->request(i),
......@@ -361,7 +361,7 @@ std::future<int32_t> BrpcPsClient::send_cmd(
return fut;
}
std::future<int32_t> BrpcPsClient::send_save_cmd(
std::future<int32_t> BrpcPsClient::SendSaveCmd(
uint32_t table_id, int cmd_id, const std::vector<std::string> &params) {
size_t request_call_num = _server_channels.size();
DownpourBrpcClosure *closure = new DownpourBrpcClosure(
......@@ -392,7 +392,7 @@ std::future<int32_t> BrpcPsClient::send_save_cmd(
for (const auto &param : params) {
closure->request(i)->add_params(param);
}
PsService_Stub rpc_stub(get_cmd_channel(i));
PsService_Stub rpc_stub(GetCmdChannel(i));
closure->cntl(i)->set_timeout_ms(
10800000); // cmd msg don't limit timeout for save/load
rpc_stub.service(closure->cntl(i), closure->request(i),
......@@ -401,65 +401,42 @@ std::future<int32_t> BrpcPsClient::send_save_cmd(
return fut;
}
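Every control helper here (SendCmd, SendSaveCmd, and the Shrink/Load/Save/Clear wrappers below) fans one request out to every pserver and returns a std::future<int32_t> that the DownpourBrpcClosure completes once all request_call_num responses have arrived, so a synchronous caller just waits on it. A usage sketch (the epoch path and mode string are placeholders, not values from this diff):

std::future<int32_t> status = client->Save("path/to/epoch", "0");
status.wait();            // returns once the closure has seen every response
if (status.get() != 0) {  // a non-zero err_code from any server fails the call
  LOG(ERROR) << "save failed";
}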
std::future<int32_t> BrpcPsClient::shrink(uint32_t table_id,
std::future<int32_t> BrpcPsClient::Shrink(uint32_t table_id,
const std::string threshold) {
return send_cmd(table_id, PS_SHRINK_TABLE, {threshold});
return SendCmd(table_id, PS_SHRINK_TABLE, {threshold});
}
std::future<int32_t> BrpcPsClient::load(const std::string &epoch,
std::future<int32_t> BrpcPsClient::Load(const std::string &epoch,
const std::string &mode) {
return send_cmd(-1, PS_LOAD_ALL_TABLE, {epoch, mode});
return SendCmd(-1, PS_LOAD_ALL_TABLE, {epoch, mode});
}
std::future<int32_t> BrpcPsClient::load(uint32_t table_id,
std::future<int32_t> BrpcPsClient::Load(uint32_t table_id,
const std::string &epoch,
const std::string &mode) {
return send_cmd(table_id, PS_LOAD_ONE_TABLE, {epoch, mode});
return SendCmd(table_id, PS_LOAD_ONE_TABLE, {epoch, mode});
}
std::future<int32_t> BrpcPsClient::Load(const LoadSaveContext &load_context) {
if (load_context.table_id < 0) {
return send_cmd(-1, PS_LOAD_ALL_TABLE,
{load_context.epoch, load_context.mode});
} else {
return send_cmd(load_context.table_id, PS_LOAD_ONE_TABLE,
{load_context.epoch, load_context.mode});
}
}
std::future<int32_t> BrpcPsClient::save(const std::string &epoch,
std::future<int32_t> BrpcPsClient::Save(const std::string &epoch,
const std::string &mode) {
VLOG(1) << "BrpcPsClient::save path " << epoch;
return send_save_cmd(-1, PS_SAVE_ALL_TABLE, {epoch, mode});
return SendSaveCmd(-1, PS_SAVE_ALL_TABLE, {epoch, mode});
}
std::future<int32_t> BrpcPsClient::save(uint32_t table_id,
std::future<int32_t> BrpcPsClient::Save(uint32_t table_id,
const std::string &epoch,
const std::string &mode) {
VLOG(1) << "BrpcPsClient::save one table path " << epoch << " table_id "
<< table_id;
return send_save_cmd(table_id, PS_SAVE_ONE_TABLE, {epoch, mode});
}
std::future<int32_t> BrpcPsClient::Save(const LoadSaveContext &save_context) {
if (save_context.table_id < 0) {
VLOG(1) << "BrpcPsClient::save path " << save_context.epoch;
return send_save_cmd(-1, PS_SAVE_ALL_TABLE,
{save_context.epoch, save_context.mode});
} else {
VLOG(1) << "BrpcPsClient::save one table path " << save_context.epoch
<< " table_id " << save_context.table_id;
return send_save_cmd(save_context.table_id, PS_SAVE_ONE_TABLE,
{save_context.epoch, save_context.mode});
}
return SendSaveCmd(table_id, PS_SAVE_ONE_TABLE, {epoch, mode});
}
std::future<int32_t> BrpcPsClient::clear() {
return send_cmd(-1, PS_CLEAR_ALL_TABLE, {});
std::future<int32_t> BrpcPsClient::Clear() {
return SendCmd(-1, PS_CLEAR_ALL_TABLE, {});
}
std::future<int32_t> BrpcPsClient::clear(uint32_t table_id) {
return send_cmd(table_id, PS_CLEAR_ONE_TABLE, {});
std::future<int32_t> BrpcPsClient::Clear(uint32_t table_id) {
return SendCmd(table_id, PS_CLEAR_ONE_TABLE, {});
}
std::future<int32_t> BrpcPsClient::flush() {
std::future<int32_t> BrpcPsClient::Flush() {
VLOG(0) << "BrpcPsClient::flush begin";
_flushing = true;
std::promise<int> promise;
......@@ -472,106 +449,69 @@ std::future<int32_t> BrpcPsClient::flush() {
promise.set_value(0);
_flushing = false;
VLOG(0) << "BrpcPsClient::flush done";
print_queue_size();
PrintQueueSize();
return fut;
}
void BrpcPsClient::print_queue_size() {
void BrpcPsClient::PrintQueueSize() {
for (auto &push_sparse_task_itr : _push_sparse_task_queue_map) {
auto table_id = push_sparse_task_itr.first;
auto queue_size = push_sparse_task_itr.second->Size();
VLOG(0) << "BrpcPsClient::print_queue_size: table " << table_id
VLOG(0) << "BrpcPsClient::PrintQueueSize: table " << table_id
<< " size: " << queue_size;
}
for (auto &task_queue_itr : _push_dense_task_queue_map) {
auto table_id = task_queue_itr.first;
auto queue_size = task_queue_itr.second->Size();
VLOG(0) << "BrpcPsClient::print_queue_size: table " << table_id
VLOG(0) << "BrpcPsClient::PrintQueueSize: table " << table_id
<< " size: " << queue_size;
}
}
void BrpcPsClient::print_queue_size_thread() {
void BrpcPsClient::PrintQueueSizeThread() {
while (_running) {
usleep(1000000 * 60 * 2);
print_queue_size();
PrintQueueSize();
}
}
void BrpcPsClient::finalize_worker() {
flush();
VLOG(0) << "BrpcPsClient::finalize_worker begin join thread";
void BrpcPsClient::FinalizeWorker() {
Flush();
VLOG(0) << "BrpcPsClient::FinalizeWorker begin join thread";
_running = false;
_async_push_dense_thread.join();
_async_push_sparse_thread.join();
// _print_thread.join();
VLOG(0) << "BrpcPsClient::finalize_worker begin join server";
VLOG(0) << "BrpcPsClient::FinalizeWorker begin join server";
_server.Stop(1000);
_server.Join();
_server_started = false;
VLOG(0) << "BrpcPsClient::finalize_worker done";
VLOG(0) << "BrpcPsClient::FinalizeWorker done";
}
std::future<int32_t> BrpcPsClient::stop_server() {
return send_cmd(-1, PS_STOP_SERVER, {});
std::future<int32_t> BrpcPsClient::StopServer() {
return SendCmd(-1, PS_STOP_SERVER, {});
}
std::future<int32_t> BrpcPsClient::start_profiler() {
return send_cmd(-1, PS_START_PROFILER, {});
std::future<int32_t> BrpcPsClient::StartProfiler() {
return SendCmd(-1, PS_START_PROFILER, {});
}
std::future<int32_t> BrpcPsClient::stop_profiler() {
return send_cmd(-1, PS_STOP_PROFILER, {});
std::future<int32_t> BrpcPsClient::StopProfiler() {
return SendCmd(-1, PS_STOP_PROFILER, {});
}
std::future<int32_t> BrpcPsClient::barrier(size_t table_id,
std::future<int32_t> BrpcPsClient::Barrier(size_t table_id,
uint32_t barrier_type) {
return send_cmd(table_id, PS_BARRIER, {std::to_string(barrier_type)});
}
std::future<int32_t> BrpcPsClient::Pull(RequestContext &pull_context) {
if (pull_context.value_type == Dense) { // pull dense
Region *dense_region =
reinterpret_cast<Region *>(pull_context.dense_values);
return pull_dense(dense_region, pull_context.num, pull_context.table);
} else { // pull sparse
size_t table_id = pull_context.table;
size_t num = pull_context.num;
bool is_training = pull_context.is_training;
if (pull_context.training_mode == Geo) { // for geo
return pull_sparse_param(pull_context.sparse_values, table_id,
pull_context.keys, num, is_training);
} else if (pull_context.training_mode == Async) { // for async
return pull_sparse(pull_context.sparse_values, table_id,
pull_context.keys, num, is_training);
}
}
}
std::future<int32_t> BrpcPsClient::Push(RequestContext &push_context) {
if (push_context.value_type == Dense) { // push dense
const Region *dense_region = push_context.push_context.push_dense_values;
return push_dense(dense_region, push_context.num, push_context.table);
} else { // push sparse
size_t table_id = push_context.table;
size_t num = push_context.num;
bool is_training = push_context.is_training;
if (push_context.training_mode == Geo) { // for geo
// TODO(zhaocaibei)
} else if (push_context.training_mode == Async) { // for async
const uint64_t *keys = push_context.push_context.keys;
const float **update_values = push_context.push_context.push_values;
return push_sparse(table_id, keys, update_values, num);
}
}
return SendCmd(table_id, PS_BARRIER, {std::to_string(barrier_type)});
}
std::future<int32_t> BrpcPsClient::pull_geo_param(size_t table_id,
std::future<int32_t> BrpcPsClient::PullGeoParam(size_t table_id,
std::vector<float> *values,
std::vector<uint64_t> *keys,
int pserver_idx) {
auto *accessor = table_accessor(table_id);
auto *accessor = GetTableAccessor(table_id);
DownpourBrpcClosure *closure =
new DownpourBrpcClosure(1, [keys, values, accessor](void *done) {
int ret = 0;
......@@ -600,7 +540,7 @@ std::future<int32_t> BrpcPsClient::pull_geo_param(size_t table_id,
closure->request(0)->set_cmd_id(PS_PULL_GEO_PARAM);
closure->request(0)->set_table_id(table_id);
closure->request(0)->set_client_id(_client_id);
PsService_Stub rpc_stub(get_cmd_channel(pserver_idx));
PsService_Stub rpc_stub(GetCmdChannel(pserver_idx));
closure->cntl(0)->set_log_id(butil::gettimeofday_ms());
rpc_stub.service(closure->cntl(0), closure->request(0), closure->response(0),
closure);
......@@ -608,10 +548,11 @@ std::future<int32_t> BrpcPsClient::pull_geo_param(size_t table_id,
}
// for GEO
std::future<int32_t> BrpcPsClient::push_sparse_param(
size_t table_id, const uint64_t *keys, const float **update_values,
std::future<int32_t> BrpcPsClient::PushSparseParam(size_t table_id,
const uint64_t *keys,
const float **update_values,
size_t num, void *done) {
auto *accessor = table_accessor(table_id);
auto *accessor = GetTableAccessor(table_id);
// Send the RPC request
DownpourBrpcClosure *closure = reinterpret_cast<DownpourBrpcClosure *>(done);
auto promise = std::make_shared<std::promise<int32_t>>();
......@@ -649,7 +590,7 @@ std::future<int32_t> BrpcPsClient::push_sparse_param(
memcpy(push_data_ptr, value_ptr[i], accessor->GetTableInfo(UPDATE_SIZE));
push_data_ptr += accessor->GetTableInfo(UPDATE_SIZE);
}
PsService_Stub rpc_stub(get_sparse_channel(shard_idx));
PsService_Stub rpc_stub(GetSparseChannel(shard_idx));
closure->cntl(shard_idx)->set_request_compress_type(
(brpc::CompressType)FLAGS_pserver_communicate_compress_type);
rpc_stub.service(closure->cntl(shard_idx), closure->request(shard_idx),
......@@ -658,16 +599,15 @@ std::future<int32_t> BrpcPsClient::push_sparse_param(
return fut;
}
std::future<int32_t> BrpcPsClient::pull_dense(Region *regions,
size_t region_num,
std::future<int32_t> BrpcPsClient::PullDense(Region *regions, size_t region_num,
size_t table_id) {
auto timer = std::make_shared<CostTimer>("pserver_client_pull_dense");
auto *accessor = table_accessor(table_id);
auto *accessor = GetTableAccessor(table_id);
auto fea_dim = accessor->GetTableInfo(FEA_DIM);
auto select_size = accessor->GetTableInfo(SELECT_SIZE);
size_t request_call_num = _server_channels.size();
uint32_t num_per_shard =
dense_dim_per_shard(accessor->GetTableInfo(FEA_DIM), request_call_num);
DenseDimPerShard(accessor->GetTableInfo(FEA_DIM), request_call_num);
// The callback fills each shard's result into the regions, in order
DownpourBrpcClosure *closure = new DownpourBrpcClosure(
request_call_num, [request_call_num, num_per_shard, regions, region_num,
......@@ -730,22 +670,22 @@ std::future<int32_t> BrpcPsClient::pull_dense(Region *regions,
closure->request(i)->set_client_id(_client_id);
closure->request(i)->add_params((char *)&num_per_shard, // NOLINT
sizeof(num_per_shard));
PsService_Stub rpc_stub(get_dense_channel(i));
PsService_Stub rpc_stub(GetDenseChannel(i));
rpc_stub.service(closure->cntl(i), closure->request(i),
closure->response(i), closure);
}
return fut;
}
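PullDense splits the fea_dim-wide dense table evenly across the pservers via DenseDimPerShard (defined in the header further down), and the closure copies each shard's num_per_shard-value slice back into the caller's regions in order. A worked example of the split, with the helper's body copied from this diff:

#include <cstdint>
#include <cstdio>

// Copied from BrpcPsClient::DenseDimPerShard below.
inline uint32_t DenseDimPerShard(uint32_t dense_dim_total, uint32_t shard_num) {
  return dense_dim_total / shard_num + 1;
}

int main() {
  // fea_dim = 10 over 3 servers -> 4 dims per shard; shards 0 and 1 are full
  // and shard 2 holds the remaining 2 dims. Note the +1 over-allocates even
  // when dense_dim_total divides shard_num exactly (e.g. 9/3 + 1 = 4).
  std::printf("%u\n", DenseDimPerShard(10, 3));  // prints 4
  return 0;
}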
std::future<int32_t> BrpcPsClient::push_dense_param(const Region *regions,
std::future<int32_t> BrpcPsClient::PushDenseParam(const Region *regions,
size_t region_num,
size_t table_id) {
auto *accessor = table_accessor(table_id);
auto *accessor = GetTableAccessor(table_id);
size_t request_call_num = _server_channels.size();
// 1. Split the Region data into shards; the shards then copy data in parallel
std::vector<std::vector<Region>> regions_partition(request_call_num);
uint32_t num_per_shard =
dense_dim_per_shard(accessor->GetTableInfo(FEA_DIM), request_call_num);
DenseDimPerShard(accessor->GetTableInfo(FEA_DIM), request_call_num);
size_t shard_data_size = num_per_shard * accessor->GetTableInfo(UPDATE_SIZE);
size_t current_region_idx = 0;
size_t current_region_data_idx = 0;
......@@ -809,17 +749,17 @@ std::future<int32_t> BrpcPsClient::push_dense_param(const Region *regions,
fill_num);
fill_remain_size -= fill_num;
}
PsService_Stub rpc_stub(get_dense_channel(i));
PsService_Stub rpc_stub(GetDenseChannel(i));
rpc_stub.service(closure->cntl(i), closure->request(i),
closure->response(i), closure);
}
return fut;
}
std::future<int32_t> BrpcPsClient::push_sparse_raw_gradient(
std::future<int32_t> BrpcPsClient::PushSparseRawGradient(
size_t table_id, const uint64_t *keys, const float **update_values,
size_t num, void *done) {
auto *accessor = table_accessor(table_id);
auto *accessor = GetTableAccessor(table_id);
// Send the RPC request
DownpourBrpcClosure *closure = reinterpret_cast<DownpourBrpcClosure *>(done);
auto promise = std::make_shared<std::promise<int32_t>>();
......@@ -872,7 +812,7 @@ std::future<int32_t> BrpcPsClient::push_sparse_raw_gradient(
memcpy(push_data_ptr, value_ptr[i], accessor->GetTableInfo(UPDATE_SIZE));
push_data_ptr += accessor->GetTableInfo(UPDATE_SIZE);
}
PsService_Stub rpc_stub(get_sparse_channel(shard_idx));
PsService_Stub rpc_stub(GetSparseChannel(shard_idx));
closure->cntl(shard_idx)->set_request_compress_type(
(brpc::CompressType)FLAGS_pserver_communicate_compress_type);
rpc_stub.service(closure->cntl(shard_idx), closure->request(shard_idx),
......@@ -881,7 +821,7 @@ std::future<int32_t> BrpcPsClient::push_sparse_raw_gradient(
return fut;
}
std::future<int32_t> BrpcPsClient::push_dense_raw_gradient(
std::future<int32_t> BrpcPsClient::PushDenseRawGradient(
int table_id, float *total_send_data, size_t total_send_data_size,
void *done) {
size_t request_call_num = _server_channels.size();
......@@ -889,9 +829,9 @@ std::future<int32_t> BrpcPsClient::push_dense_raw_gradient(
auto promise = std::make_shared<std::promise<int32_t>>();
closure->add_promise(promise);
std::future<int> fut = promise->get_future();
auto *accessor = table_accessor(table_id);
auto *accessor = GetTableAccessor(table_id);
uint32_t num_per_shard =
dense_dim_per_shard(accessor->GetTableInfo(FEA_DIM), request_call_num);
DenseDimPerShard(accessor->GetTableInfo(FEA_DIM), request_call_num);
for (size_t i = 0; i < request_call_num; ++i) {
closure->request(i)->set_cmd_id(PS_PUSH_DENSE_TABLE);
closure->request(i)->set_table_id(table_id);
......@@ -905,14 +845,14 @@ std::future<int32_t> BrpcPsClient::push_dense_raw_gradient(
total_send_data + i * num_per_shard, num_per_shard * sizeof(float));
// closure->cntl(i)->set_request_compress_type(
// (brpc::CompressType)FLAGS_pserver_communicate_compress_type);
PsService_Stub rpc_stub(get_dense_channel(i));
PsService_Stub rpc_stub(GetDenseChannel(i));
rpc_stub.service(closure->cntl(i), closure->request(i),
closure->response(i), closure);
}
return fut;
}
std::future<int32_t> BrpcPsClient::push_global_step(int table_id,
std::future<int32_t> BrpcPsClient::PushGlobalStep(int table_id,
int64_t *total_send_data,
void *done) {
size_t request_call_num = _server_channels.size();
......@@ -933,14 +873,14 @@ std::future<int32_t> BrpcPsClient::push_global_step(int table_id,
memcpy(push_data_ptr + sizeof(uint32_t), total_send_data,
num_per_shard * sizeof(int64_t));
PsService_Stub rpc_stub(get_dense_channel(i));
PsService_Stub rpc_stub(GetDenseChannel(i));
rpc_stub.service(closure->cntl(i), closure->request(i),
closure->response(i), closure);
}
return fut;
}
std::future<int32_t> BrpcPsClient::pull_sparse(float **select_values,
std::future<int32_t> BrpcPsClient::PullSparse(float **select_values,
size_t table_id,
const uint64_t *keys, size_t num,
bool is_training) {
......@@ -968,7 +908,7 @@ std::future<int32_t> BrpcPsClient::pull_sparse(float **select_values,
shard_sorted_kvs->at(shard_id).push_back({keys[i], select_values[i]});
}
auto *accessor = table_accessor(table_id);
auto *accessor = GetTableAccessor(table_id);
size_t value_size = accessor->GetTableInfo(SELECT_SIZE);
......@@ -1055,7 +995,7 @@ std::future<int32_t> BrpcPsClient::pull_sparse(float **select_values,
closure->request(i)->set_client_id(_client_id);
closure->request(i)->add_params((char *)&kv_request_count, // NOLINT
sizeof(uint32_t));
PsService_Stub rpc_stub(get_cmd_channel(i));
PsService_Stub rpc_stub(GetCmdChannel(i));
closure->cntl(i)->set_log_id(butil::gettimeofday_ms());
rpc_stub.service(closure->cntl(i), closure->request(i),
closure->response(i), closure);
......@@ -1065,7 +1005,7 @@ std::future<int32_t> BrpcPsClient::pull_sparse(float **select_values,
}
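PullSparse first buckets the requested keys by destination shard into shard_sorted_kvs and then issues one PS_PULL_SPARSE_TABLE request per server, prefixing each with its kv_request_count. A minimal sketch of that grouping step; the modulo key-to-shard policy is an assumption here, since the shard_id computation falls in an elided part of the hunk:

#include <cstddef>
#include <cstdint>
#include <utility>
#include <vector>

std::vector<std::vector<std::pair<uint64_t, float *>>> GroupByShard(
    const uint64_t *keys, float **values, size_t num, size_t shard_num) {
  std::vector<std::vector<std::pair<uint64_t, float *>>> shards(shard_num);
  for (size_t i = 0; i < num; ++i) {
    size_t shard_id = keys[i] % shard_num;  // assumed policy: key mod servers
    shards[shard_id].push_back({keys[i], values[i]});
  }
  return shards;
}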
// for GEO
std::future<int32_t> BrpcPsClient::pull_sparse_param(float **select_values,
std::future<int32_t> BrpcPsClient::PullSparseParam(float **select_values,
size_t table_id,
const uint64_t *keys,
size_t num,
......@@ -1082,7 +1022,7 @@ std::future<int32_t> BrpcPsClient::pull_sparse_param(float **select_values,
shard_sorted_kvs->at(shard_id).push_back({keys[i], select_values[i]});
}
auto *accessor = table_accessor(table_id);
auto *accessor = GetTableAccessor(table_id);
size_t value_size = accessor->GetTableInfo(SELECT_SIZE);
DownpourBrpcClosure *closure = new DownpourBrpcClosure(
......@@ -1169,7 +1109,7 @@ std::future<int32_t> BrpcPsClient::pull_sparse_param(float **select_values,
closure->request(i)->set_client_id(_client_id);
closure->request(i)->add_params((char *)&kv_request_count, // NOLINT
sizeof(uint32_t));
PsService_Stub rpc_stub(get_cmd_channel(i));
PsService_Stub rpc_stub(GetCmdChannel(i));
closure->cntl(i)->set_log_id(butil::gettimeofday_ms());
rpc_stub.service(closure->cntl(i), closure->request(i),
closure->response(i), closure);
......@@ -1178,7 +1118,7 @@ std::future<int32_t> BrpcPsClient::pull_sparse_param(float **select_values,
return fut;
}
std::future<int32_t> BrpcPsClient::send_client2client_msg(
std::future<int32_t> BrpcPsClient::SendClient2ClientMsg(
int msg_type, int to_client_id, const std::string &msg) {
auto promise = std::make_shared<std::promise<int32_t>>();
std::future<int> fut = promise->get_future();
......@@ -1203,10 +1143,10 @@ std::future<int32_t> BrpcPsClient::send_client2client_msg(
return fut;
}
std::future<int32_t> BrpcPsClient::push_sparse_raw_gradient_partial(
std::future<int32_t> BrpcPsClient::PushSparseRawGradientPartial(
size_t table_id, const uint64_t *keys, const float **update_values,
uint32_t num, void *done, int pserver_idx) {
auto *accessor = table_accessor(table_id);
auto *accessor = GetTableAccessor(table_id);
size_t value_size = accessor->GetTableInfo(UPDATE_SIZE);
DownpourBrpcClosure *closure = reinterpret_cast<DownpourBrpcClosure *>(done);
auto promise = std::make_shared<std::promise<int32_t>>();
......@@ -1228,7 +1168,7 @@ std::future<int32_t> BrpcPsClient::push_sparse_raw_gradient_partial(
memcpy(push_data_ptr, update_values[i], value_size);
push_data_ptr += value_size;
}
PsService_Stub rpc_stub(get_sparse_channel(pserver_idx));
PsService_Stub rpc_stub(GetSparseChannel(pserver_idx));
closure->cntl(0)->set_request_compress_type(
(brpc::CompressType)FLAGS_pserver_communicate_compress_type);
rpc_stub.service(closure->cntl(0), closure->request(0), closure->response(0),
......@@ -1236,7 +1176,7 @@ std::future<int32_t> BrpcPsClient::push_sparse_raw_gradient_partial(
return fut;
}
int32_t BrpcPsClient::recv_and_save_table(const uint64_t table_id,
int32_t BrpcPsClient::RecvAndSaveTable(const uint64_t table_id,
const std::string &path) {
// get var information
std::string var_name = "";
......@@ -1271,16 +1211,16 @@ int32_t BrpcPsClient::recv_and_save_table(const uint64_t table_id,
save_vec.push_back(save_huge_vec.data() + i * var_shape);
}
VLOG(2) << "recv_and_save_table: table_class: " << table_class;
VLOG(2) << "RecvAndSaveTable: table_class: " << table_class;
// TODO(zhaocaibei123): new GeoBrpcPSClient, move this to its
// recv_and_save_table
// RecvAndSaveTable
if (table_class == "MemorySparseGeoTable") {
auto status =
pull_sparse_param(reinterpret_cast<float **>(save_vec.data()), table_id,
PullSparseParam(reinterpret_cast<float **>(save_vec.data()), table_id,
save_key.data(), save_key.size(), true);
status.wait();
} else {
auto status = pull_sparse(reinterpret_cast<float **>(save_vec.data()),
auto status = PullSparse(reinterpret_cast<float **>(save_vec.data()),
table_id, save_key.data(), save_key.size(), true);
status.wait();
}
......@@ -1315,7 +1255,7 @@ int32_t BrpcPsClient::recv_and_save_table(const uint64_t table_id,
return 0;
}
std::future<int32_t> BrpcPsClient::push_sparse(size_t table_id,
std::future<int32_t> BrpcPsClient::PushSparse(size_t table_id,
const uint64_t *keys,
const float **update_values,
size_t num) {
......@@ -1323,7 +1263,7 @@ std::future<int32_t> BrpcPsClient::push_sparse(size_t table_id,
CostTimer parse_timer("pserver_client_push_sparse_parse");
int push_sparse_async_num = _push_sparse_task_queue_map[table_id]->Size();
while (push_sparse_async_num > FLAGS_pserver_max_async_call_num) {
// LOG(INFO) << "push_sparse Waiting for async_call_num comsume,
// LOG(INFO) << "PushSparse Waiting for async_call_num comsume,
// task_num:"
// << push_sparse_async_num
// << ", max_task_limit:" << FLAGS_pserver_max_async_call_num;
......@@ -1333,7 +1273,7 @@ std::future<int32_t> BrpcPsClient::push_sparse(size_t table_id,
auto put_timer = std::make_shared<CostTimer>("client_push_sparse_put");
thread_local std::vector<std::vector<std::pair<uint64_t, const float *>>>
shard_sorted_kv_list;
auto *accessor = table_accessor(table_id);
auto *accessor = GetTableAccessor(table_id);
size_t request_call_num = _server_channels.size();
shard_sorted_kv_list.resize(request_call_num);
for (auto &x : shard_sorted_kv_list) {
......@@ -1381,7 +1321,7 @@ std::future<int32_t> BrpcPsClient::push_sparse(size_t table_id,
return fut;
}
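PushSparse applies backpressure before enqueueing: while the table's queue holds more than FLAGS_pserver_max_async_call_num pending tasks, the producer waits for PushSparseTaskConsume to drain it (the loop body is elided in this hunk). A self-contained analog with hypothetical names; the sleep interval is an assumption:

#include <atomic>
#include <unistd.h>

std::atomic<int> g_queue_size{0};     // stand-in for the per-table queue Size()
constexpr int kMaxAsyncCallNum = 13;  // stand-in for FLAGS_pserver_max_async_call_num

void WaitForRoom() {
  // Same shape as the while-loop above: sleep until the consumer thread
  // drains the queue below the limit, then enqueue the new task.
  while (g_queue_size.load() > kMaxAsyncCallNum) {
    usleep(5000);  // interval is an assumption; the diff elides the body
  }
}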
void BrpcPsClient::push_sparse_task_consume() {
void BrpcPsClient::PushSparseTaskConsume() {
uint64_t merge_size = FLAGS_pserver_push_sparse_merge_limit;
std::vector<std::shared_ptr<SparseAsyncTask>> task_list;
size_t request_call_num = _server_channels.size();
......@@ -1392,7 +1332,7 @@ void BrpcPsClient::push_sparse_task_consume() {
// Process the push tasks of every sparse table
for (auto &push_sparse_task_itr : _push_sparse_task_queue_map) {
auto table_id = push_sparse_task_itr.first;
auto *accessor = table_accessor(table_id);
auto *accessor = GetTableAccessor(table_id);
auto &task_queue = push_sparse_task_itr.second;
auto queue_size = task_queue->Size();
if (queue_size == 0) {
......@@ -1471,7 +1411,7 @@ void BrpcPsClient::push_sparse_task_consume() {
for (int shard_idx = 0; shard_idx < request_call_num; ++shard_idx) {
merge_status[shard_idx] =
async_push_sparse_shard_threads.enqueue(std::bind(
&BrpcPsClient::push_sparse_async_shard_push, this, task_list,
&BrpcPsClient::PushSparseAsyncShardPush, this, task_list,
request_kv_num, table_id, shard_idx, closure, accessor));
}
for (int shard_idx = 0; shard_idx < request_call_num; ++shard_idx) {
......@@ -1487,7 +1427,7 @@ void BrpcPsClient::push_sparse_task_consume() {
for (int shard_idx = 0; shard_idx < request_call_num; ++shard_idx) {
merge_status[shard_idx] =
async_push_sparse_shard_threads.enqueue(std::bind(
&BrpcPsClient::push_sparse_async_shard_merge, this, task_list,
&BrpcPsClient::PushSparseAsyncShardMerge, this, task_list,
request_kv_num, table_id, shard_idx, accessor));
}
for (int shard_idx = 0; shard_idx < request_call_num; ++shard_idx) {
......@@ -1523,7 +1463,7 @@ void sparse_local_merge(ValueAccessor *accessor, float *merge_data,
accessor->Merge(merge_data_shell, another_data_shell, 1);
}
int BrpcPsClient::push_sparse_async_shard_merge(
int BrpcPsClient::PushSparseAsyncShardMerge(
std::vector<std::shared_ptr<SparseAsyncTask>> &task_list,
std::vector<int> &request_kv_num, int table_id, int shard_idx,
ValueAccessor *accessor) {
......@@ -1615,11 +1555,11 @@ int BrpcPsClient::push_sparse_async_shard_merge(
return 0;
}
int BrpcPsClient::push_sparse_async_shard_push(
int BrpcPsClient::PushSparseAsyncShardPush(
std::vector<std::shared_ptr<SparseAsyncTask>> &task_list,
std::vector<int> &request_kv_num, int table_id, int shard_idx,
DownpourBrpcClosure *closure, ValueAccessor *accessor) {
push_sparse_async_shard_merge(task_list, request_kv_num, table_id, shard_idx,
PushSparseAsyncShardMerge(task_list, request_kv_num, table_id, shard_idx,
accessor);
size_t merged_kv_count = task_list[0]->data()->shared_data[shard_idx].kv_num;
......@@ -1649,7 +1589,7 @@ int BrpcPsClient::push_sparse_async_shard_push(
accessor->GetTableInfo(UPDATE_SIZE));
push_data_ptr += accessor->GetTableInfo(UPDATE_SIZE);
}
PsService_Stub rpc_stub(get_sparse_channel(shard_idx));
PsService_Stub rpc_stub(GetSparseChannel(shard_idx));
closure->cntl(shard_idx)->set_request_compress_type(
(brpc::CompressType)FLAGS_pserver_communicate_compress_type);
rpc_stub.service(closure->cntl(shard_idx), closure->request(shard_idx),
......@@ -1658,10 +1598,10 @@ int BrpcPsClient::push_sparse_async_shard_push(
return 0;
}
std::future<int32_t> BrpcPsClient::push_dense(const Region *regions,
std::future<int32_t> BrpcPsClient::PushDense(const Region *regions,
size_t region_num,
size_t table_id) {
auto *accessor = table_accessor(table_id);
auto *accessor = GetTableAccessor(table_id);
int fea_dim = accessor->GetTableInfo(FEA_DIM);
int update_dim = accessor->GetTableInfo(UPDATE_DIM);
auto push_timer = std::make_shared<CostTimer>("pserver_client_push_dense");
......@@ -1669,7 +1609,7 @@ std::future<int32_t> BrpcPsClient::push_dense(const Region *regions,
std::make_shared<CostTimer>("pserver_client_push_dense_parse");
int push_dense_async_num = _push_dense_task_queue_map[table_id]->Size();
while (push_dense_async_num > FLAGS_pserver_max_async_call_num) {
// LOG(INFO) << "push_dense Waiting for async_call_num comsume,
// LOG(INFO) << "PushDense Waiting for async_call_num comsume,
// task_num:"
// << push_dense_async_num
// << ", max_task_limit:" << FLAGS_pserver_max_async_call_num;
......@@ -1683,7 +1623,7 @@ std::future<int32_t> BrpcPsClient::push_dense(const Region *regions,
size_t request_call_num = _server_channels.size();
uint32_t num_per_shard =
dense_dim_per_shard(accessor->GetTableInfo(FEA_DIM), request_call_num);
DenseDimPerShard(accessor->GetTableInfo(FEA_DIM), request_call_num);
// Copy the region data into the transposed matrix
async_task->data()->resize(num_per_shard * request_call_num *
......@@ -1705,7 +1645,7 @@ std::future<int32_t> BrpcPsClient::push_dense(const Region *regions,
return fut;
}
void BrpcPsClient::push_dense_task_consume() {
void BrpcPsClient::PushDenseTaskConsume() {
uint64_t merge_size = FLAGS_pserver_push_dense_merge_limit;
static bool scale_gradient = FLAGS_pserver_scale_gradient_by_merge;
::ThreadPool async_merge_dense_threads(10);
......@@ -1723,7 +1663,7 @@ void BrpcPsClient::push_dense_task_consume() {
++_async_call_num;
DenseAsyncTask *task;
task_queue->Get(task);
auto *accessor = table_accessor(task->table_id());
auto *accessor = GetTableAccessor(task->table_id());
// Set up the request callback
size_t request_call_num = _server_channels.size();
......@@ -1774,7 +1714,7 @@ void BrpcPsClient::push_dense_task_consume() {
merge_status[i].wait();
}
VLOG(3) << "BrpcPsClient::push_dense_task_consume before merge "
VLOG(3) << "BrpcPsClient::PushDenseTaskConsume before merge "
"total_send_data[0]"
<< total_send_data[0] << " total_send_data[-2]"
<< total_send_data[total_send_data_size - 2]
......@@ -1787,7 +1727,7 @@ void BrpcPsClient::push_dense_task_consume() {
mat *= (1.0 / (merge_count + 1));
}
VLOG(3) << "BrpcPsClient::push_dense_task_consume after merge "
VLOG(3) << "BrpcPsClient::PushDenseTaskConsume after merge "
"total_send_data[0]"
<< total_send_data[0] << " total_send_data[-2]"
<< total_send_data[total_send_data_size - 2]
......@@ -1796,7 +1736,7 @@ void BrpcPsClient::push_dense_task_consume() {
<< merge_count;
}
std::shared_ptr<DenseAsyncTask> task_ptr(task);
push_dense_raw_gradient(task_ptr, total_send_data, total_send_data_size,
PushDenseRawGradient(task_ptr, total_send_data, total_send_data_size,
closure);
}
auto wait_ms = FLAGS_pserver_async_push_dense_interval_ms -
......@@ -1807,16 +1747,17 @@ void BrpcPsClient::push_dense_task_consume() {
}
}
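When FLAGS_pserver_scale_gradient_by_merge is set, PushDenseTaskConsume adds the gradients of the merge_count extra queued tasks into total_send_data and then rescales once by 1/(merge_count + 1), i.e. it averages the merged updates instead of summing them. A plain-loop sketch of that step (the real code scales a matrix view, mat, over the buffer):

#include <cstddef>
#include <vector>

// total: the first task's gradient buffer; pending: the merge_count tasks
// merged into it.
void MergeAndScale(std::vector<float> &total,
                   const std::vector<std::vector<float>> &pending) {
  for (const auto &g : pending)
    for (size_t i = 0; i < total.size(); ++i) total[i] += g[i];
  const float scale = 1.0f / (pending.size() + 1);  // 1 / (merge_count + 1)
  for (float &v : total) v *= scale;
}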
void BrpcPsClient::push_dense_raw_gradient(
std::shared_ptr<DenseAsyncTask> &task, float *total_send_data,
size_t total_send_data_size, DownpourBrpcClosure *closure) {
auto *accessor = table_accessor(task->table_id());
void BrpcPsClient::PushDenseRawGradient(std::shared_ptr<DenseAsyncTask> &task,
float *total_send_data,
size_t total_send_data_size,
DownpourBrpcClosure *closure) {
auto *accessor = GetTableAccessor(task->table_id());
size_t request_call_num = _server_channels.size();
// Copy the data into the request buffer
auto timer = std::make_shared<CostTimer>("pserver_client_push_dense_rpc");
closure->add_timer(timer);
uint32_t num_per_shard =
dense_dim_per_shard(accessor->GetTableInfo(FEA_DIM), request_call_num);
DenseDimPerShard(accessor->GetTableInfo(FEA_DIM), request_call_num);
auto send_timer =
std::make_shared<CostTimer>("pserver_client_push_dense_send");
for (size_t i = 0; i < request_call_num; ++i) {
......@@ -1832,7 +1773,7 @@ void BrpcPsClient::push_dense_raw_gradient(
total_send_data + i * num_per_shard, num_per_shard * sizeof(float));
closure->cntl(i)->set_request_compress_type(
(brpc::CompressType)FLAGS_pserver_communicate_compress_type);
PsService_Stub rpc_stub(get_dense_channel(i));
PsService_Stub rpc_stub(GetDenseChannel(i));
rpc_stub.service(closure->cntl(i), closure->request(i),
closure->response(i), closure);
}
......
......@@ -50,7 +50,7 @@ class DownpourPsClientService : public PsService {
DownpourPsClientService() {}
virtual ~DownpourPsClientService() {}
virtual int32_t configure(PSClient *client, size_t rank_id) {
virtual int32_t Configure(PSClient *client, size_t rank_id) {
_client = client;
_rank = rank_id;
return 0;
......@@ -139,7 +139,7 @@ class BrpcPsClient : public PSClient {
BrpcPsClient() {}
virtual ~BrpcPsClient() {
if (_running) {
flush();
Flush();
_running = false;
}
if (_async_push_dense_thread.joinable()) {
......@@ -154,108 +154,97 @@ class BrpcPsClient : public PSClient {
_server_started = false;
}
}
virtual int32_t create_client2client_connection(
int pserver_timeout_ms, int pserver_connect_timeout_ms, int max_retry);
std::future<int32_t> shrink(uint32_t table_id,
virtual int32_t CreateClient2ClientConnection(int pserver_timeout_ms,
int pserver_connect_timeout_ms,
int max_retry);
std::future<int32_t> Shrink(uint32_t table_id,
const std::string threshold) override;
std::future<int32_t> load(const std::string &epoch,
std::future<int32_t> Load(const std::string &epoch,
const std::string &mode) override;
std::future<int32_t> load(uint32_t table_id, const std::string &epoch,
std::future<int32_t> Load(uint32_t table_id, const std::string &epoch,
const std::string &mode) override;
std::future<int32_t> Load(const LoadSaveContext &load_context) override;
std::future<int32_t> save(const std::string &epoch,
std::future<int32_t> Save(const std::string &epoch,
const std::string &mode) override;
std::future<int32_t> save(uint32_t table_id, const std::string &epoch,
std::future<int32_t> Save(uint32_t table_id, const std::string &epoch,
const std::string &mode) override;
virtual std::future<int32_t> Save(
const LoadSaveContext &save_context) override;
std::future<int32_t> clear() override;
std::future<int32_t> Clear() override;
std::future<int32_t> clear(uint32_t table_id) override;
std::future<int32_t> Clear(uint32_t table_id) override;
std::future<int32_t> stop_server() override;
std::future<int32_t> StopServer() override;
std::future<int32_t> start_profiler() override;
std::future<int32_t> stop_profiler() override;
std::future<int32_t> StartProfiler() override;
std::future<int32_t> StopProfiler() override;
void finalize_worker() override;
void FinalizeWorker() override;
virtual std::future<int32_t> pull_dense(Region *regions, size_t region_num,
virtual std::future<int32_t> PullDense(Region *regions, size_t region_num,
size_t table_id);
virtual std::future<int32_t> push_dense_param(const Region *regions,
virtual std::future<int32_t> PushDenseParam(const Region *regions,
size_t region_num,
size_t table_id);
virtual std::future<int32_t> push_dense(const Region *regions,
virtual std::future<int32_t> PushDense(const Region *regions,
size_t region_num, size_t table_id);
void push_dense_task_consume();
virtual std::future<int32_t> pull_sparse(float **select_values,
void PushDenseTaskConsume();
virtual std::future<int32_t> PullSparse(float **select_values,
size_t table_id, const uint64_t *keys,
size_t num, bool is_training);
virtual std::future<int32_t> PullSparseParam(float **select_values,
size_t table_id,
const uint64_t *keys, size_t num,
bool is_training);
virtual std::future<int32_t> pull_sparse_param(float **select_values,
size_t table_id,
const uint64_t *keys,
size_t num, bool is_training);
virtual std::future<int32_t> Pull(RequestContext &pull_context) override;
virtual std::future<int32_t> Push(RequestContext &push_context) override;
virtual std::future<int32_t> PrintTableStat(uint32_t table_id);
virtual std::future<int32_t> print_table_stat(uint32_t table_id);
virtual std::future<int32_t> Barrier(size_t table_id, uint32_t barrier_type);
virtual std::future<int32_t> barrier(size_t table_id, uint32_t barrier_type);
virtual std::future<int32_t> pull_geo_param(size_t table_id,
virtual std::future<int32_t> PullGeoParam(size_t table_id,
std::vector<float> *values,
std::vector<uint64_t> *keys,
int pserver_idx);
virtual std::future<int32_t> push_global_step(int table_id,
virtual std::future<int32_t> PushGlobalStep(int table_id,
int64_t *total_send_data,
void *done);
virtual std::future<int32_t> flush();
virtual std::future<int32_t> Flush();
std::future<int32_t> send_client2client_msg(int msg_type, int to_client_id,
std::future<int32_t> SendClient2ClientMsg(int msg_type, int to_client_id,
const std::string &msg) override;
// for local save sparse
virtual int32_t recv_and_save_table(const uint64_t table_id,
virtual int32_t RecvAndSaveTable(const uint64_t table_id,
const std::string &path);
void print_queue_size();
void print_queue_size_thread();
void PrintQueueSize();
void PrintQueueSizeThread();
protected:
virtual size_t get_server_nums() { return _server_channels.size(); }
inline brpc::Channel *get_sparse_channel(size_t server_id) {
virtual size_t GetServerNums() { return _server_channels.size(); }
inline brpc::Channel *GetSparseChannel(size_t server_id) {
return _server_channels[server_id][0].get();
}
inline brpc::Channel *get_dense_channel(size_t server_id) {
inline brpc::Channel *GetDenseChannel(size_t server_id) {
return _server_channels[server_id][1].get();
}
inline brpc::Channel *get_cmd_channel(size_t server_id) {
inline brpc::Channel *GetCmdChannel(size_t server_id) {
return _server_channels[server_id][2].get();
}
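// Channel fan-out per server, as implied by the three accessors above and the
// std::array<std::shared_ptr<brpc::Channel>, 3> member below: [0] carries
// sparse pushes (GetSparseChannel), [1] dense pulls/pushes (GetDenseChannel),
// and [2] control commands plus sparse pulls (GetCmdChannel), matching the
// call sites in the .cc diff above.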
int32_t initialize() override;
int32_t Initialize() override;
private:
// virtual int32_t initialize() override;
inline uint32_t dense_dim_per_shard(uint32_t dense_dim_total,
inline uint32_t DenseDimPerShard(uint32_t dense_dim_total,
uint32_t shard_num) {
return dense_dim_total / shard_num + 1;
}
std::future<int32_t> send_cmd(uint32_t table_id, int cmd_id,
std::future<int32_t> SendCmd(uint32_t table_id, int cmd_id,
const std::vector<std::string> &param);
std::future<int32_t> send_save_cmd(uint32_t table_id, int cmd_id,
std::future<int32_t> SendSaveCmd(uint32_t table_id, int cmd_id,
const std::vector<std::string> &param);
bool _running = false;
......@@ -276,12 +265,12 @@ class BrpcPsClient : public PSClient {
std::thread _print_thread;
int push_sparse_async_shard_merge(
int PushSparseAsyncShardMerge(
std::vector<std::shared_ptr<SparseAsyncTask>> &task_list, // NOLINT
std::vector<int> &request_kv_num, int table_id, int shard_idx, // NOLINT
ValueAccessor *accessor);
int push_sparse_async_shard_push(
int PushSparseAsyncShardPush(
std::vector<std::shared_ptr<SparseAsyncTask>> &task_list, // NOLINT
std::vector<int> &request_kv_num, int table_id, int shard_idx, // NOLINT
DownpourBrpcClosure *closure, ValueAccessor *accessor);
......@@ -292,35 +281,35 @@ class BrpcPsClient : public PSClient {
_client_channels; // client2client
std::vector<std::array<std::shared_ptr<brpc::Channel>, 3>>
_server_channels; // client2server
std::future<int32_t> push_dense_raw_gradient(int table_id,
std::future<int32_t> PushDenseRawGradient(int table_id,
float *total_send_data,
size_t total_send_data_size,
void *done) override;
std::future<int32_t> push_sparse_raw_gradient(size_t table_id,
std::future<int32_t> PushSparseRawGradient(size_t table_id,
const uint64_t *keys,
const float **update_values,
size_t num,
void *done) override;
std::future<int32_t> push_sparse_raw_gradient_partial(
size_t table_id, const uint64_t *keys, const float **update_values,
uint32_t num, void *done, int pserver_idx) override;
size_t num, void *done) override;
std::future<int32_t> push_sparse_param(size_t table_id, const uint64_t *keys,
std::future<int32_t> PushSparseRawGradientPartial(size_t table_id,
const uint64_t *keys,
const float **update_values,
size_t num, void *done) override;
std::future<int32_t> push_sparse(size_t table_id, const uint64_t *keys,
uint32_t num, void *done,
int pserver_idx) override;
std::future<int32_t> PushSparseParam(size_t table_id, const uint64_t *keys,
const float **update_values, size_t num,
void *done) override;
std::future<int32_t> PushSparse(size_t table_id, const uint64_t *keys,
const float **update_values,
size_t num) override;
void push_sparse_task_consume();
void PushSparseTaskConsume();
private:
int32_t start_client_service();
int32_t StartClientService();
void push_dense_raw_gradient(std::shared_ptr<DenseAsyncTask> &task, // NOLINT
float *total_send_data,
size_t total_send_data_size,
void PushDenseRawGradient(std::shared_ptr<DenseAsyncTask> &task, // NOLINT
float *total_send_data, size_t total_send_data_size,
DownpourBrpcClosure *closure);
float _mae = 0;
float _mse = 0;
......
......@@ -31,7 +31,7 @@ class RpcController;
namespace paddle {
namespace distributed {
int32_t BrpcPsServer::initialize() {
int32_t BrpcPsServer::Initialize() {
auto &service_config = _config.downpour_server_param().service_param();
if (!service_config.has_service_class()) {
LOG(ERROR) << "miss service_class in ServerServiceParameter";
......@@ -46,7 +46,7 @@ int32_t BrpcPsServer::initialize() {
}
_service.reset(service);
if (service->configure(this) != 0 || service->initialize() != 0) {
if (service->Configure(this) != 0 || service->Initialize() != 0) {
LOG(ERROR) << "service initialize failed, service_name:"
<< service_config.service_class();
return -1;
......@@ -59,7 +59,7 @@ int32_t BrpcPsServer::initialize() {
return 0;
}
uint64_t BrpcPsServer::start(const std::string &ip, uint32_t port) {
uint64_t BrpcPsServer::Start(const std::string &ip, uint32_t port) {
std::unique_lock<std::mutex> lock(mutex_);
std::string ip_port = ip + ":" + std::to_string(port);
......@@ -68,7 +68,7 @@ uint64_t BrpcPsServer::start(const std::string &ip, uint32_t port) {
brpc::ServerOptions options;
int num_threads = std::thread::hardware_concurrency();
auto trainers = _environment->get_trainers();
auto trainers = _environment->GetTrainers();
options.num_threads = trainers > num_threads ? trainers : num_threads;
if (_server.Start(ip_port.c_str(), &options) != 0) {
......@@ -83,7 +83,7 @@ uint64_t BrpcPsServer::start(const std::string &ip, uint32_t port) {
}
}
_environment->registe_ps_server(ip, port, _rank);
_environment->RegistePsServer(ip, port, _rank);
cv_.wait(lock, [&] { return stoped_; });
PSHost host;
......@@ -93,31 +93,30 @@ uint64_t BrpcPsServer::start(const std::string &ip, uint32_t port) {
return host.rank;
}
int32_t BrpcPsServer::port() { return _server.listen_address().port; }
int32_t BrpcPsServer::Port() { return _server.listen_address().port; }
int32_t BrpcPsService::initialize() {
int32_t BrpcPsService::Initialize() {
_is_initialize_shard_info = false;
_service_handler_map[PS_STOP_SERVER] = &BrpcPsService::stop_server;
_service_handler_map[PS_PULL_DENSE_TABLE] = &BrpcPsService::pull_dense;
_service_handler_map[PS_PUSH_DENSE_TABLE] = &BrpcPsService::push_dense;
_service_handler_map[PS_PULL_SPARSE_TABLE] = &BrpcPsService::pull_sparse;
_service_handler_map[PS_PUSH_SPARSE_TABLE] = &BrpcPsService::push_sparse;
_service_handler_map[PS_SAVE_ONE_TABLE] = &BrpcPsService::save_one_table;
_service_handler_map[PS_SAVE_ALL_TABLE] = &BrpcPsService::save_all_table;
_service_handler_map[PS_SHRINK_TABLE] = &BrpcPsService::shrink_table;
_service_handler_map[PS_LOAD_ONE_TABLE] = &BrpcPsService::load_one_table;
_service_handler_map[PS_LOAD_ALL_TABLE] = &BrpcPsService::load_all_table;
_service_handler_map[PS_CLEAR_ONE_TABLE] = &BrpcPsService::clear_one_table;
_service_handler_map[PS_CLEAR_ALL_TABLE] = &BrpcPsService::clear_all_table;
_service_handler_map[PS_PUSH_DENSE_PARAM] = &BrpcPsService::push_dense_param;
_service_handler_map[PS_PRINT_TABLE_STAT] = &BrpcPsService::print_table_stat;
_service_handler_map[PS_PULL_GEO_PARAM] = &BrpcPsService::pull_geo_param;
_service_handler_map[PS_PUSH_SPARSE_PARAM] =
&BrpcPsService::push_sparse_param;
_service_handler_map[PS_BARRIER] = &BrpcPsService::barrier;
_service_handler_map[PS_START_PROFILER] = &BrpcPsService::start_profiler;
_service_handler_map[PS_STOP_PROFILER] = &BrpcPsService::stop_profiler;
_service_handler_map[PS_PUSH_GLOBAL_STEP] = &BrpcPsService::push_global_step;
_service_handler_map[PS_STOP_SERVER] = &BrpcPsService::StopServer;
_service_handler_map[PS_PULL_DENSE_TABLE] = &BrpcPsService::PullDense;
_service_handler_map[PS_PUSH_DENSE_TABLE] = &BrpcPsService::PushDense;
_service_handler_map[PS_PULL_SPARSE_TABLE] = &BrpcPsService::PullSparse;
_service_handler_map[PS_PUSH_SPARSE_TABLE] = &BrpcPsService::PushSparse;
_service_handler_map[PS_SAVE_ONE_TABLE] = &BrpcPsService::SaveOneTable;
_service_handler_map[PS_SAVE_ALL_TABLE] = &BrpcPsService::SaveAllTable;
_service_handler_map[PS_SHRINK_TABLE] = &BrpcPsService::ShrinkTable;
_service_handler_map[PS_LOAD_ONE_TABLE] = &BrpcPsService::LoadOneTable;
_service_handler_map[PS_LOAD_ALL_TABLE] = &BrpcPsService::LoadAllTable;
_service_handler_map[PS_CLEAR_ONE_TABLE] = &BrpcPsService::ClearOneTable;
_service_handler_map[PS_CLEAR_ALL_TABLE] = &BrpcPsService::ClearAllTable;
_service_handler_map[PS_PUSH_DENSE_PARAM] = &BrpcPsService::PushDenseParam;
_service_handler_map[PS_PRINT_TABLE_STAT] = &BrpcPsService::PrintTableStat;
_service_handler_map[PS_PULL_GEO_PARAM] = &BrpcPsService::PullGeoParam;
_service_handler_map[PS_PUSH_SPARSE_PARAM] = &BrpcPsService::PushSparseParam;
_service_handler_map[PS_BARRIER] = &BrpcPsService::Barrier;
_service_handler_map[PS_START_PROFILER] = &BrpcPsService::StartProfiler;
_service_handler_map[PS_STOP_PROFILER] = &BrpcPsService::StopProfiler;
_service_handler_map[PS_PUSH_GLOBAL_STEP] = &BrpcPsService::PushGlobalStep;
auto &profiler = CostProfiler::instance();
profiler.register_profiler("pserver_server_pull_dense");
profiler.register_profiler("pserver_server_push_dense");
......@@ -125,7 +124,7 @@ int32_t BrpcPsService::initialize() {
profiler.register_profiler("pserver_server_push_sparse");
// Shard initialization: the shard info of server_list can only be obtained from env after the servers have started
initialize_shard_info();
InitializeShardInfo();
return 0;
}
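Initialize fills _service_handler_map with one BrpcPsService member-function pointer per PS_* command id; service() (next hunk) then dispatches on request->cmd_id() through it. A self-contained analog of the pattern, with hypothetical names:

#include <cstdint>
#include <unordered_map>

struct Service {
  using Handler = int32_t (Service::*)(int);
  std::unordered_map<int, Handler> handlers;
  int32_t Ping(int x) { return x; }
  int32_t Dispatch(int cmd, int arg) {
    auto it = handlers.find(cmd);
    return it == handlers.end() ? -1 : (this->*(it->second))(arg);
  }
};

int main() {
  Service s;
  s.handlers[7] = &Service::Ping;  // like _service_handler_map[PS_...] = &BrpcPsService::...
  return s.Dispatch(7, 0);
}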
......@@ -138,16 +137,16 @@ int32_t BrpcPsService::initialize() {
return -1; \
}
int32_t BrpcPsService::initialize_shard_info() {
int32_t BrpcPsService::InitializeShardInfo() {
if (!_is_initialize_shard_info) {
std::lock_guard<std::mutex> guard(_initialize_shard_mutex);
if (_is_initialize_shard_info) {
return 0;
}
size_t shard_num = _server->environment()->get_ps_servers().size();
auto &table_map = *(_server->table());
size_t shard_num = _server->Environment()->GetPsServers().size();
auto &table_map = *(_server->GetTable());
for (auto itr : table_map) {
itr.second->set_shard(_rank, shard_num);
itr.second->SetShard(_rank, shard_num);
}
_is_initialize_shard_info = true;
}
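InitializeShardInfo uses double-checked locking: the unsynchronized first read of _is_initialize_shard_info skips the mutex on the common path, and the re-check under _initialize_shard_mutex guards against a racing initializer. A minimal analog (a strictly data-race-free C++11 version would make the flag std::atomic<bool>; the diff keeps a plain bool):

#include <mutex>

bool inited = false;  // mirrors _is_initialize_shard_info
std::mutex init_mu;   // mirrors _initialize_shard_mutex

void EnsureInited() {
  if (!inited) {                            // fast path, no lock taken
    std::lock_guard<std::mutex> guard(init_mu);
    if (inited) return;                     // lost the race: already done
    /* one-time shard setup goes here */
    inited = true;
  }
}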
......@@ -167,7 +166,7 @@ void BrpcPsService::service(google::protobuf::RpcController *cntl_base,
response->set_err_code(0);
response->set_err_msg("");
auto *table = _server->table(request->table_id());
auto *table = _server->GetTable(request->table_id());
brpc::Controller *cntl = static_cast<brpc::Controller *>(cntl_base);
auto itr = _service_handler_map.find(request->cmd_id());
if (itr == _service_handler_map.end()) {
......@@ -185,11 +184,11 @@ void BrpcPsService::service(google::protobuf::RpcController *cntl_base,
}
}
int32_t BrpcPsService::pull_dense(Table *table, const PsRequestMessage &request,
int32_t BrpcPsService::PullDense(Table *table, const PsRequestMessage &request,
PsResponseMessage &response,
brpc::Controller *cntl) {
platform::RecordEvent record_event(
"PsService->pull_dense", platform::TracerEventType::Communication, 1);
"PsService->PullDense", platform::TracerEventType::Communication, 1);
CHECK_TABLE_EXIST(table, request, response)
if (request.params_size() < 1) {
set_response_code(
......@@ -206,14 +205,15 @@ int32_t BrpcPsService::pull_dense(Table *table, const PsRequestMessage &request,
}
auto res_data = butil::get_object<std::vector<float>>();
res_data->resize(num * table->value_accesor()->GetTableInfo(SELECT_SIZE) /
res_data->resize(num * table->ValueAccesor()->GetTableInfo(SELECT_SIZE) /
sizeof(float));
TableContext table_context;
table_context.value_type = Dense;
table_context.pull_context.values = res_data->data();
table_context.num = num;
table->Pull(table_context);
// table->pull_dense(res_data->data(), num);
// table->PullDense(res_data->data(), num);
cntl->response_attachment().append((char *)(res_data->data()),
res_data->size() * sizeof(float));
......@@ -222,13 +222,12 @@ int32_t BrpcPsService::pull_dense(Table *table, const PsRequestMessage &request,
return 0;
}
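The handlers now funnel table access through a single TableContext instead of per-op table methods; dense and sparse pulls differ only in which fields they set. Both shapes condensed from this diff (the Sparse value_type line is an assumption mirroring the Dense one, since the sparse hunk elides it):

// Dense pull (this hunk):
TableContext dense_ctx;
dense_ctx.value_type = Dense;
dense_ctx.pull_context.values = res_data->data();  // destination buffer
dense_ctx.num = num;
table->Pull(dense_ctx);

// Sparse pull (PullSparse hunk below): a PullSparseValue carries the keys.
TableContext sparse_ctx;
sparse_ctx.value_type = Sparse;  // assumption: sparse counterpart of Dense
sparse_ctx.pull_context.pull_value = value;
sparse_ctx.pull_context.values = res_data->data();
table->Pull(sparse_ctx);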
int32_t BrpcPsService::push_dense_param(Table *table,
int32_t BrpcPsService::PushDenseParam(Table *table,
const PsRequestMessage &request,
PsResponseMessage &response,
brpc::Controller *cntl) {
platform::RecordEvent record_event("PsService->push_dense_param",
platform::TracerEventType::Communication,
1);
platform::RecordEvent record_event(
"PsService->PushDenseParam", platform::TracerEventType::Communication, 1);
CHECK_TABLE_EXIST(table, request, response)
thread_local std::string push_buffer;
auto &req_io_buffer = cntl->request_attachment();
......@@ -245,17 +244,17 @@ int32_t BrpcPsService::push_dense_param(Table *table,
uint32_t num = *(const uint32_t *)data;
const float *values = (const float *)(data + sizeof(uint32_t));
if (table->push_dense_param(values, num) != 0) {
set_response_code(response, -1, "push_dense_param failed");
if (table->PushDenseParam(values, num) != 0) {
set_response_code(response, -1, "PushDenseParam failed");
}
return 0;
}
int32_t BrpcPsService::push_dense(Table *table, const PsRequestMessage &request,
int32_t BrpcPsService::PushDense(Table *table, const PsRequestMessage &request,
PsResponseMessage &response,
brpc::Controller *cntl) {
platform::RecordEvent record_event(
"PsService->push_dense", platform::TracerEventType::Communication, 1);
"PsService->PushDense", platform::TracerEventType::Communication, 1);
CHECK_TABLE_EXIST(table, request, response)
auto req_buffer_size = request.data().size();
if (req_buffer_size < 1) {
......@@ -278,14 +277,14 @@ int32_t BrpcPsService::push_dense(Table *table, const PsRequestMessage &request,
// const float *values = (const float *)(request.data().data() +
// sizeof(uint32_t));
if (table->Push(table_context) != 0) {
// if (table->push_dense(values, num) != 0) {
set_response_code(response, -1, "push_dense failed");
// if (table->PushDense(values, num) != 0) {
set_response_code(response, -1, "PushDense failed");
}
return 0;
}
int32_t BrpcPsService::barrier(Table *table, const PsRequestMessage &request,
int32_t BrpcPsService::Barrier(Table *table, const PsRequestMessage &request,
PsResponseMessage &response,
brpc::Controller *cntl) {
CHECK_TABLE_EXIST(table, request, response)
......@@ -299,15 +298,15 @@ int32_t BrpcPsService::barrier(Table *table, const PsRequestMessage &request,
auto trainer_id = request.client_id();
auto barrier_type = request.params(0);
table->barrier(trainer_id, barrier_type);
table->Barrier(trainer_id, barrier_type);
return 0;
}
int32_t BrpcPsService::push_sparse_param(Table *table,
int32_t BrpcPsService::PushSparseParam(Table *table,
const PsRequestMessage &request,
PsResponseMessage &response,
brpc::Controller *cntl) {
platform::RecordEvent record_event("PsService->push_sparse_param",
platform::RecordEvent record_event("PsService->PushSparseParam",
platform::TracerEventType::Communication,
1);
CHECK_TABLE_EXIST(table, request, response)
......@@ -331,13 +330,13 @@ int32_t BrpcPsService::push_sparse_param(Table *table,
const uint64_t *keys = (const uint64_t *)push_data.data();
const float *values =
(const float *)(push_data.data() + sizeof(uint64_t) * num);
if (table->push_sparse_param(keys, values, num) != 0) {
set_response_code(response, -1, "push_sparse_param error");
if (table->PushSparseParam(keys, values, num) != 0) {
set_response_code(response, -1, "PushSparseParam error");
}
return 0;
}
int32_t BrpcPsService::pull_geo_param(Table *table,
int32_t BrpcPsService::PullGeoParam(Table *table,
const PsRequestMessage &request,
PsResponseMessage &response,
brpc::Controller *cntl) {
......@@ -350,7 +349,7 @@ int32_t BrpcPsService::pull_geo_param(Table *table,
std::vector<float> values;
std::vector<uint64_t> ids;
table->pull_geo_param(trainer_id, &values, &ids);
table->PullGeoParam(trainer_id, &values, &ids);
uint32_t num = ids.size();
cntl->response_attachment().append((char *)(&num), sizeof(uint32_t));
......@@ -361,12 +360,11 @@ int32_t BrpcPsService::pull_geo_param(Table *table,
return 0;
}
int32_t BrpcPsService::pull_sparse(Table *table,
const PsRequestMessage &request,
int32_t BrpcPsService::PullSparse(Table *table, const PsRequestMessage &request,
PsResponseMessage &response,
brpc::Controller *cntl) {
platform::RecordEvent record_event(
"PsService->pull_sparse", platform::TracerEventType::Communication, 1);
"PsService->PullSparse", platform::TracerEventType::Communication, 1);
CHECK_TABLE_EXIST(table, request, response)
auto &req_io_buffer = cntl->request_attachment();
......@@ -386,7 +384,7 @@ int32_t BrpcPsService::pull_sparse(Table *table,
CostTimer timer("pserver_server_pull_sparse");
uint32_t num = *(uint32_t *)(request.params(0).c_str());
auto dim = table->value_accesor()->GetTableInfo(SELECT_DIM);
auto dim = table->ValueAccesor()->GetTableInfo(SELECT_DIM);
thread_local std::string req_buffer;
req_buffer.reserve(req_buffer_size);
......@@ -405,7 +403,7 @@ int32_t BrpcPsService::pull_sparse(Table *table,
table_context.pull_context.pull_value = value;
table_context.pull_context.values = res_data->data();
table->Pull(table_context);
// table->pull_sparse(res_data->data(), value);
// table->PullSparse(res_data->data(), value);
cntl->response_attachment().append((char *)(res_data->data()),
res_data->size() * sizeof(float));
......@@ -413,12 +411,11 @@ int32_t BrpcPsService::pull_sparse(Table *table,
return 0;
}
int32_t BrpcPsService::push_sparse(Table *table,
const PsRequestMessage &request,
int32_t BrpcPsService::PushSparse(Table *table, const PsRequestMessage &request,
PsResponseMessage &response,
brpc::Controller *cntl) {
platform::RecordEvent record_event(
"PsService->push_sparse", platform::TracerEventType::Communication, 1);
"PsService->PushSparse", platform::TracerEventType::Communication, 1);
CHECK_TABLE_EXIST(table, request, response)
auto &push_data = request.data();
if (push_data.size() < 1) {
......@@ -448,18 +445,18 @@ int32_t BrpcPsService::push_sparse(Table *table,
// const float *values = (const float *)(push_data.data() + sizeof(uint64_t) *
// num);
if (table->Push(table_context) != 0) {
// if (table->push_sparse(keys, values, num) != 0) {
set_response_code(response, -1, "push_sparse error");
// if (table->PushSparse(keys, values, num) != 0) {
set_response_code(response, -1, "PushSparse error");
}
return 0;
}
int32_t BrpcPsService::print_table_stat(Table *table,
int32_t BrpcPsService::PrintTableStat(Table *table,
const PsRequestMessage &request,
PsResponseMessage &response,
brpc::Controller *cntl) {
CHECK_TABLE_EXIST(table, request, response)
std::pair<int64_t, int64_t> ret = table->print_table_stat();
std::pair<int64_t, int64_t> ret = table->PrintTableStat();
paddle::framework::BinaryArchive ar;
ar << ret.first << ret.second;
std::string table_info(ar.Buffer(), ar.Length());
......@@ -468,7 +465,7 @@ int32_t BrpcPsService::print_table_stat(Table *table,
return 0;
}
int32_t BrpcPsService::load_one_table(Table *table,
int32_t BrpcPsService::LoadOneTable(Table *table,
const PsRequestMessage &request,
PsResponseMessage &response,
brpc::Controller *cntl) {
......@@ -479,20 +476,20 @@ int32_t BrpcPsService::load_one_table(Table *table,
"PsRequestMessage.datas is requeired at least 2 for path & load_param");
return -1;
}
if (table->load(request.params(0), request.params(1)) != 0) {
if (table->Load(request.params(0), request.params(1)) != 0) {
set_response_code(response, -1, "table load failed");
return -1;
}
return 0;
}
int32_t BrpcPsService::load_all_table(Table *table,
int32_t BrpcPsService::LoadAllTable(Table *table,
const PsRequestMessage &request,
PsResponseMessage &response,
brpc::Controller *cntl) {
auto &table_map = *(_server->table());
auto &table_map = *(_server->GetTable());
for (auto &itr : table_map) {
if (load_one_table(itr.second.get(), request, response, cntl) != 0) {
if (LoadOneTable(itr.second.get(), request, response, cntl) != 0) {
LOG(ERROR) << "load table[" << itr.first << "] failed";
return -1;
}
......@@ -500,7 +497,7 @@ int32_t BrpcPsService::load_all_table(Table *table,
return 0;
}
int32_t BrpcPsService::save_one_table(Table *table,
int32_t BrpcPsService::SaveOneTable(Table *table,
const PsRequestMessage &request,
PsResponseMessage &response,
brpc::Controller *cntl) {
......@@ -511,12 +508,12 @@ int32_t BrpcPsService::save_one_table(Table *table,
"PsRequestMessage.datas is requeired at least 2, path&mode");
return -1;
}
table->flush();
table->Flush();
int32_t feasign_size = 0;
VLOG(3) << "save table " << request.params(0) << " " << request.params(1);
feasign_size = table->save(request.params(0), request.params(1));
feasign_size = table->Save(request.params(0), request.params(1));
if (feasign_size < 0) {
set_response_code(response, -1, "table save failed");
return -1;
......@@ -524,16 +521,16 @@ int32_t BrpcPsService::save_one_table(Table *table,
return feasign_size;
}
int32_t BrpcPsService::save_all_table(Table *table,
int32_t BrpcPsService::SaveAllTable(Table *table,
const PsRequestMessage &request,
PsResponseMessage &response,
brpc::Controller *cntl) {
auto &table_map = *(_server->table());
auto &table_map = *(_server->GetTable());
int32_t all_feasign_size = 0;
int32_t feasign_size = 0;
for (auto &itr : table_map) {
feasign_size = save_one_table(itr.second.get(), request, response, cntl);
feasign_size = SaveOneTable(itr.second.get(), request, response, cntl);
if (feasign_size < 0) {
LOG(ERROR) << "save table[" << itr.first << "] failed";
return -1;
......@@ -542,7 +539,7 @@ int32_t BrpcPsService::save_all_table(Table *table,
return 0;
}
int32_t BrpcPsService::shrink_table(Table *table,
int32_t BrpcPsService::ShrinkTable(Table *table,
const PsRequestMessage &request,
PsResponseMessage &response,
brpc::Controller *cntl) {
......@@ -553,8 +550,8 @@ int32_t BrpcPsService::shrink_table(Table *table,
"PsRequestMessage.datas is requeired at least 1, threshold");
return -1;
}
table->flush();
if (table->shrink(request.params(0)) != 0) {
table->Flush();
if (table->Shrink(request.params(0)) != 0) {
set_response_code(response, -1, "table shrink failed");
return -1;
}
......@@ -562,43 +559,42 @@ int32_t BrpcPsService::shrink_table(Table *table,
return 0;
}
int32_t BrpcPsService::clear_one_table(Table *table,
int32_t BrpcPsService::ClearOneTable(Table *table,
const PsRequestMessage &request,
PsResponseMessage &response,
brpc::Controller *cntl) {
CHECK_TABLE_EXIST(table, request, response)
table->flush();
table->clear();
table->Flush();
table->Clear();
return 0;
}
int32_t BrpcPsService::clear_all_table(Table *table,
int32_t BrpcPsService::ClearAllTable(Table *table,
const PsRequestMessage &request,
PsResponseMessage &response,
brpc::Controller *cntl) {
auto &table_map = *(_server->table());
auto &table_map = *(_server->GetTable());
for (auto &itr : table_map) {
if (clear_one_table(itr.second.get(), request, response, cntl) != 0) {
if (ClearOneTable(itr.second.get(), request, response, cntl) != 0) {
return -1;
}
}
return 0;
}
int32_t BrpcPsService::stop_server(Table *table,
const PsRequestMessage &request,
int32_t BrpcPsService::StopServer(Table *table, const PsRequestMessage &request,
PsResponseMessage &response,
brpc::Controller *cntl) {
auto *p_server = _server;
std::thread t_stop([p_server]() {
p_server->stop();
p_server->Stop();
VLOG(3) << "Server Stoped";
});
t_stop.detach();
return 0;
}
int32_t BrpcPsService::stop_profiler(Table *table,
int32_t BrpcPsService::StopProfiler(Table *table,
const PsRequestMessage &request,
PsResponseMessage &response,
brpc::Controller *cntl) {
......@@ -607,7 +603,7 @@ int32_t BrpcPsService::stop_profiler(Table *table,
return 0;
}
int32_t BrpcPsService::start_profiler(Table *table,
int32_t BrpcPsService::StartProfiler(Table *table,
const PsRequestMessage &request,
PsResponseMessage &response,
brpc::Controller *cntl) {
......@@ -615,7 +611,7 @@ int32_t BrpcPsService::start_profiler(Table *table,
return 0;
}
int32_t BrpcPsService::push_global_step(Table *table,
int32_t BrpcPsService::PushGlobalStep(Table *table,
const PsRequestMessage &request,
PsResponseMessage &response,
brpc::Controller *cntl) {
......@@ -629,7 +625,7 @@ int32_t BrpcPsService::push_global_step(Table *table,
const int64_t *values =
(const int64_t *)(request.data().data() + sizeof(uint32_t));
auto trainer_id = request.client_id();
if (table->push_dense(values, trainer_id) != 0) {
if (table->PushDense(values, trainer_id) != 0) {
set_response_code(response, -1, "run_program failed");
}
......
......@@ -41,8 +41,8 @@ class BrpcPsServer : public PSServer {
public:
BrpcPsServer() {}
virtual ~BrpcPsServer() {}
virtual uint64_t start(const std::string &ip, uint32_t port);
virtual int32_t stop() {
virtual uint64_t Start(const std::string &ip, uint32_t port);
virtual int32_t Stop() {
std::unique_lock<std::mutex> lock(mutex_);
stoped_ = true;
cv_.notify_all();
......@@ -51,10 +51,10 @@ class BrpcPsServer : public PSServer {
_server.Join();
return 0;
}
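// A standalone illustration of the stop handshake above: Stop() flips the
// flag under the mutex and notifies; the elided lines presumably have the
// thread blocked in Start() wake up on cv_ and shut the brpc server down
// before Join() returns. (Generic sketch, not the Paddle classes.)
#include <condition_variable>
#include <mutex>
#include <thread>

class StoppableLoop {
 public:
  void Run() {  // stands in for the thread parked in Start()
    std::unique_lock<std::mutex> lock(mutex_);
    cv_.wait(lock, [this] { return stopped_; });
    // ...tear the server down here...
  }
  void Stop() {  // mirrors BrpcPsServer::Stop above
    std::unique_lock<std::mutex> lock(mutex_);
    stopped_ = true;
    cv_.notify_all();
  }

 private:
  std::mutex mutex_;
  std::condition_variable cv_;
  bool stopped_ = false;
};

int main() {
  StoppableLoop loop;
  std::thread t([&] { loop.Run(); });
  loop.Stop();
  t.join();
  return 0;
}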
int32_t port();
int32_t Port();
private:
virtual int32_t initialize();
virtual int32_t Initialize();
mutable std::mutex mutex_;
std::condition_variable cv_;
bool stoped_ = false;
......@@ -71,7 +71,7 @@ typedef int32_t (BrpcPsService::*serviceHandlerFunc)(
class BrpcPsService : public PsBaseService {
public:
virtual int32_t initialize() override;
virtual int32_t Initialize() override;
virtual void service(::google::protobuf::RpcController *controller,
const PsRequestMessage *request,
......@@ -79,49 +79,48 @@ class BrpcPsService : public PsBaseService {
::google::protobuf::Closure *done) override;
private:
int32_t initialize_shard_info();
int32_t pull_dense(Table *table, const PsRequestMessage &request,
int32_t InitializeShardInfo();
int32_t PullDense(Table *table, const PsRequestMessage &request,
PsResponseMessage &response, brpc::Controller *cntl);
int32_t push_dense(Table *table, const PsRequestMessage &request,
int32_t PushDense(Table *table, const PsRequestMessage &request,
PsResponseMessage &response, brpc::Controller *cntl);
int32_t push_dense_param(Table *table, const PsRequestMessage &request,
int32_t PushDenseParam(Table *table, const PsRequestMessage &request,
PsResponseMessage &response, brpc::Controller *cntl);
int32_t push_sparse_param(Table *table, const PsRequestMessage &request,
PsResponseMessage &response,
brpc::Controller *cntl);
int32_t pull_sparse(Table *table, const PsRequestMessage &request,
int32_t PushSparseParam(Table *table, const PsRequestMessage &request,
PsResponseMessage &response, brpc::Controller *cntl);
int32_t PullSparse(Table *table, const PsRequestMessage &request,
PsResponseMessage &response, brpc::Controller *cntl);
int32_t pull_geo_param(Table *table, const PsRequestMessage &request,
int32_t PullGeoParam(Table *table, const PsRequestMessage &request,
PsResponseMessage &response, brpc::Controller *cntl);
int32_t barrier(Table *table, const PsRequestMessage &request,
int32_t Barrier(Table *table, const PsRequestMessage &request,
PsResponseMessage &response, brpc::Controller *cntl);
int32_t push_sparse(Table *table, const PsRequestMessage &request,
int32_t PushSparse(Table *table, const PsRequestMessage &request,
PsResponseMessage &response, brpc::Controller *cntl);
int32_t load_one_table(Table *table, const PsRequestMessage &request,
int32_t LoadOneTable(Table *table, const PsRequestMessage &request,
PsResponseMessage &response, brpc::Controller *cntl);
int32_t load_all_table(Table *table, const PsRequestMessage &request,
int32_t LoadAllTable(Table *table, const PsRequestMessage &request,
PsResponseMessage &response, brpc::Controller *cntl);
int32_t save_one_table(Table *table, const PsRequestMessage &request,
int32_t SaveOneTable(Table *table, const PsRequestMessage &request,
PsResponseMessage &response, brpc::Controller *cntl);
int32_t save_all_table(Table *table, const PsRequestMessage &request,
int32_t SaveAllTable(Table *table, const PsRequestMessage &request,
PsResponseMessage &response, brpc::Controller *cntl);
int32_t shrink_table(Table *table, const PsRequestMessage &request,
int32_t ShrinkTable(Table *table, const PsRequestMessage &request,
PsResponseMessage &response, brpc::Controller *cntl);
int32_t clear_one_table(Table *table, const PsRequestMessage &request,
int32_t ClearOneTable(Table *table, const PsRequestMessage &request,
PsResponseMessage &response, brpc::Controller *cntl);
int32_t clear_all_table(Table *table, const PsRequestMessage &request,
int32_t ClearAllTable(Table *table, const PsRequestMessage &request,
PsResponseMessage &response, brpc::Controller *cntl);
int32_t stop_server(Table *table, const PsRequestMessage &request,
int32_t StopServer(Table *table, const PsRequestMessage &request,
PsResponseMessage &response, brpc::Controller *cntl);
int32_t start_profiler(Table *table, const PsRequestMessage &request,
int32_t StartProfiler(Table *table, const PsRequestMessage &request,
PsResponseMessage &response, brpc::Controller *cntl);
int32_t stop_profiler(Table *table, const PsRequestMessage &request,
int32_t StopProfiler(Table *table, const PsRequestMessage &request,
PsResponseMessage &response, brpc::Controller *cntl);
int32_t print_table_stat(Table *table, const PsRequestMessage &request,
int32_t PrintTableStat(Table *table, const PsRequestMessage &request,
PsResponseMessage &response, brpc::Controller *cntl);
int32_t push_global_step(Table *table, const PsRequestMessage &request,
int32_t PushGlobalStep(Table *table, const PsRequestMessage &request,
PsResponseMessage &response, brpc::Controller *cntl);
bool _is_initialize_shard_info;
......
......@@ -39,7 +39,7 @@ inline double GetCurrentUS() {
Communicator::Communicator() {}
void Communicator::init_gflag(const std::string &gflags) {
void Communicator::InitGFlag(const std::string &gflags) {
VLOG(3) << "Init With Gflags:" << gflags;
std::vector<std::string> flags = paddle::string::split_string(gflags);
if (flags.size() < 1) {
......@@ -73,7 +73,7 @@ void Communicator::InitBrpcClient(
}
std::vector<uint64_t> Communicator::GetClientInfo() {
std::vector<uint64_t> res = _ps_env.get_client_info();
std::vector<uint64_t> res = _ps_env.GetClientInfo();
for (auto rr : res) {
VLOG(2) << "Communicator::GetClientInfo " << rr;
}
......@@ -82,7 +82,7 @@ std::vector<uint64_t> Communicator::GetClientInfo() {
int Communicator::SetClients(std::vector<uint64_t> &host_sign_list) {
int node = host_sign_list.size();
return _ps_env.set_ps_clients(host_sign_list.data(), node);
return _ps_env.SetPsClients(host_sign_list.data(), node);
}
void Communicator::RpcRecvDense(const std::vector<std::string> &varnames,
......@@ -114,7 +114,7 @@ void Communicator::RpcRecvDense(const std::vector<std::string> &varnames,
}
}
auto status =
_worker_ptr->pull_dense(regions.data(), regions.size(), table_id);
_worker_ptr->PullDense(regions.data(), regions.size(), table_id);
status.wait();
for (auto &t : varnames) {
......@@ -177,7 +177,7 @@ void Communicator::RpcSendDenseParam(const std::vector<std::string> &varnames,
}
}
auto status =
_worker_ptr->push_dense_param(regions.data(), regions.size(), table_id);
_worker_ptr->PushDenseParam(regions.data(), regions.size(), table_id);
status.wait();
VLOG(4) << "RPC Send Dense Param " << table_id << " done!";
return;
......@@ -190,9 +190,9 @@ void Communicator::RpcSendDense(const CommContext &ctx, const Scope &scope) {
auto &var_names = ctx.origin_varnames;
auto &table_id = ctx.table_id;
auto dense_data = std::make_shared<std::vector<float>>();
size_t request_call_num = _worker_ptr->get_server_nums();
size_t request_call_num = _worker_ptr->GetServerNums();
uint32_t num_per_shard =
dense_dim_per_shard(ctx.height_sections[0], request_call_num);
DenseDimPerShard(ctx.height_sections[0], request_call_num);
dense_data->resize(num_per_shard *
request_call_num); // accessor->update_dim() = 1
float *data = dense_data->data();
......@@ -222,8 +222,8 @@ void Communicator::RpcSendDense(const CommContext &ctx, const Scope &scope) {
closure->set_promise_value(ret);
--_async_call_num;
});
auto status = _worker_ptr->push_dense_raw_gradient(
table_id, data, dense_data->size(), closure);
auto status = _worker_ptr->PushDenseRawGradient(table_id, data,
dense_data->size(), closure);
status.wait();
return;
}
......@@ -233,7 +233,7 @@ void Communicator::RpcSendSparseParam(const std::string &varname, int table_id,
platform::RecordEvent record_event("Communicator->RpcSendSparseParam",
platform::TracerEventType::Communication,
1);
size_t request_call_num = _worker_ptr->get_server_nums();
size_t request_call_num = _worker_ptr->GetServerNums();
std::vector<float *> push_g_vec;
auto *send_var = scope.FindVar(varname);
......@@ -260,8 +260,8 @@ void Communicator::RpcSendSparseParam(const std::string &varname, int table_id,
}
closure->set_promise_value(ret);
});
auto status = _worker_ptr->push_sparse_param(
table_id, sparse_push_keys.data(), (const float **)push_g_vec.data(),
auto status = _worker_ptr->PushSparseParam(table_id, sparse_push_keys.data(),
(const float **)push_g_vec.data(),
sparse_push_keys.size(), closure);
status.wait();
return;
......@@ -272,7 +272,7 @@ void Communicator::RpcSendSparse(const std::string &var_name, int table_id,
platform::RecordEvent record_event("Communicator->RpcSendSparse",
platform::TracerEventType::Communication,
1);
size_t request_call_num = _worker_ptr->get_server_nums();
size_t request_call_num = _worker_ptr->GetServerNums();
std::vector<uint64_t> sparse_push_keys;
std::vector<float *> push_g_vec;
......@@ -313,7 +313,7 @@ void Communicator::RpcSendSparse(const std::string &var_name, int table_id,
closure->set_promise_value(ret);
--_async_call_num;
});
auto status = _worker_ptr->push_sparse_raw_gradient(
auto status = _worker_ptr->PushSparseRawGradient(
table_id, sparse_push_keys.data(), (const float **)push_g_vec.data(),
sparse_push_keys.size(), closure);
status.wait();
......@@ -340,7 +340,7 @@ void Communicator::RpcRecvSparse(const std::string &varname, int table_id,
bool training = true;
auto status = _worker_ptr->pull_sparse_param(
auto status = _worker_ptr->PullSparseParam(
(float **)push_g_vec.data(), table_id, // NOLINT
sparse_push_keys.data(), sparse_push_keys.size(), training);
status.wait();
......@@ -376,11 +376,11 @@ void Communicator::RpcProfilerControl() {
if (!do_server_profiler_ && platform::IsProfileEnabled()) {
// send profiler start flag
do_server_profiler_ = true;
auto start_status = _worker_ptr->start_profiler();
auto start_status = _worker_ptr->StartProfiler();
start_status.wait();
} else if (do_server_profiler_ && !platform::IsProfileEnabled()) {
// send profiler end flag
auto stop_status = _worker_ptr->stop_profiler();
auto stop_status = _worker_ptr->StopProfiler();
stop_status.wait();
do_server_profiler_ = false;
}
......@@ -396,7 +396,7 @@ void Communicator::SendGlobalStep(const CommContext &ctx, int batches,
platform::TracerEventType::Communication,
1);
auto &table_id = ctx.table_id;
size_t request_call_num = _worker_ptr->get_server_nums();
size_t request_call_num = _worker_ptr->GetServerNums();
auto &var_name = STEP_COUNTER;
auto *out_var = send_scope->Var(var_name);
......@@ -416,7 +416,7 @@ void Communicator::SendGlobalStep(const CommContext &ctx, int batches,
}
closure->set_promise_value(ret);
});
auto status = _worker_ptr->push_global_step(table_id, data, closure);
auto status = _worker_ptr->PushGlobalStep(table_id, data, closure);
status.wait();
return;
}
......@@ -605,8 +605,8 @@ void AsyncCommunicator::PullSparseToTensorSync(
}
}
auto status =
_worker_ptr->pull_sparse(pull_result_ptr.data(), table_id,
fea_keys.data(), fea_keys.size(), is_training);
_worker_ptr->PullSparse(pull_result_ptr.data(), table_id, fea_keys.data(),
fea_keys.size(), is_training);
status.wait();
auto ret = status.get();
if (ret != 0) {
......@@ -738,7 +738,7 @@ void AsyncCommunicator::PushSparseFromTensorAsync(
this->Check(table_id), true,
platform::errors::InvalidArgument(
"can not find table: %s, please check your config", table_id));
auto status = _worker_ptr->push_sparse(table_id, push_keys.data(),
auto status = _worker_ptr->PushSparse(table_id, push_keys.data(),
(const float **)push_g_vec.data(),
push_keys.size());
}
......@@ -813,7 +813,7 @@ void AsyncCommunicator::Stop() {
if (!communicator_) {
VLOG(0) << "Communicator is not inited, do nothing";
} else {
// _worker_ptr->finalize_worker();
// _worker_ptr->FinalizeWorker();
VLOG(1) << "client finalize_worker done";
if (recv_thread_) {
VLOG(1) << "stop recv thread";
......@@ -1327,7 +1327,7 @@ void GeoCommunicator::SendSparse(const std::string &varname,
closure->set_promise_value(ret);
--_async_call_num;
});
auto status = _worker_ptr->push_sparse_raw_gradient_partial(
auto status = _worker_ptr->PushSparseRawGradientPartial(
table_id, (const uint64_t *)sparse_ids.data(),
(const float **)push_g_vec.data(), sparse_ids.size(), closure, ep_idx);
status.wait();
......@@ -1345,7 +1345,7 @@ void GeoCommunicator::RecvSparse(const std::string &varname, int table_id,
// 1. recv from pserver
std::vector<uint64_t> keys;
std::vector<float> values;
auto status = _worker_ptr->pull_geo_param(table_id, &values, &keys, ep_idx);
auto status = _worker_ptr->PullGeoParam(table_id, &values, &keys, ep_idx);
status.wait();
std::string param = SplitedGradToParam(varname);
......
......@@ -299,7 +299,7 @@ class Communicator {
virtual void Barrier() {}
virtual void BarrierWithTable(uint32_t barrier_type) {
auto rets = _worker_ptr->barrier(barrier_table_id_, barrier_type);
auto rets = _worker_ptr->Barrier(barrier_table_id_, barrier_type);
rets.wait();
int status = rets.get();
PADDLE_ENFORCE_EQ(status, 0,
......@@ -310,7 +310,7 @@ class Communicator {
virtual void CreateC2CConnection(int pserver_timeout_ms,
int pserver_connect_timeout_ms,
int max_retry) {
_worker_ptr->create_client2client_connection(
_worker_ptr->CreateClient2ClientConnection(
pserver_timeout_ms, pserver_connect_timeout_ms, max_retry);
}
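// Nearly every call in this interface hands back a std::future<int32_t> whose
// value is an aggregated error code (0 on success), resolved by the RPC
// closure via set_promise_value as seen in RpcSendDense. A standalone sketch
// of the wait/check idiom used by BarrierWithTable above:
#include <cstdint>
#include <future>
#include <iostream>

std::future<int32_t> FakeRpc(bool ok) {
  std::promise<int32_t> p;
  p.set_value(ok ? 0 : -1);  // the closure would call set_promise_value(ret)
  return p.get_future();
}

int main() {
  auto rets = FakeRpc(true);
  rets.wait();  // block until the closure has run
  int status = rets.get();
  if (status != 0) {  // non-zero: at least one shard/request failed
    std::cerr << "rpc failed with " << status << std::endl;
  }
  return status;
}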
......@@ -379,12 +379,12 @@ class Communicator {
std::unordered_map<std::string, std::string> envs;
// compute how many dense dims each shard stores
inline uint32_t dense_dim_per_shard(uint32_t dense_dim_total,
inline uint32_t DenseDimPerShard(uint32_t dense_dim_total,
uint32_t shard_num) {
return dense_dim_total / shard_num + 1;
}
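// Worked example (standalone, illustrative numbers): the +1 rounds up so that
// shard_num * per_shard always covers dense_dim_total, at the cost of at most
// shard_num padding slots -- e.g. 1000 dims over 4 shards gives 251 per
// shard, 1004 slots in total. RpcSendDense above sizes its send buffer the
// same way.
#include <cassert>
#include <cstdint>

inline uint32_t DenseDimPerShardDemo(uint32_t dense_dim_total,
                                     uint32_t shard_num) {
  return dense_dim_total / shard_num + 1;  // same formula as above
}

int main() {
  assert(DenseDimPerShardDemo(1000, 4) == 251);
  assert(251u * 4u >= 1000u);  // every dim has a slot
  return 0;
}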
void init_gflag(const std::string &gflags);
void InitGFlag(const std::string &gflags);
paddle::distributed::PSParameter _ps_param;
paddle::distributed::PaddlePSEnvironment _ps_env;
int servers_ = 0;
......
......@@ -40,7 +40,7 @@ struct PSHost {
// |---ip---|---port---|--rank--|
// |-32bit--|--20bit---|--12bit-|
uint64_t serialize_to_uint64() {
uint64_t SerializeToUint64() {
uint64_t host_label = 0;
host_label = inet_addr(ip.c_str());
host_label = host_label << 32;
......@@ -49,7 +49,7 @@ struct PSHost {
return host_label;
}
void parse_from_uint64(uint64_t host_label) {
void ParseFromUint64(uint64_t host_label) {
static uint64_t rank_label_mask = (1L << 12) - 1;
static uint64_t port_label_mask = (1L << 20) - 1;
rank = host_label & rank_label_mask;
......@@ -58,17 +58,17 @@ struct PSHost {
ip = inet_ntoa(*(in_addr *)&ip_addr); // NOLINT
}
std::string to_string() {
std::string ToString() {
std::stringstream s;
s << "host: " << ip;
s << " port: " << port;
s << " rank: " << rank;
s << " uint: " << serialize_to_uint64();
s << " uint: " << SerializeToUint64();
return s.str();
}
// for open source parameter server
std::string serialize_to_string() {
std::string SerializeToString() {
std::stringstream s;
s << ip << ":";
s << port << ":";
......@@ -76,15 +76,15 @@ struct PSHost {
return s.str();
}
void parse_from_string(std::string endpoint) {
void ParseFromString(std::string endpoint) {
std::vector<std::string> endpoint_info;
string_split(endpoint, ':', &endpoint_info);
StringSplit(endpoint, ':', &endpoint_info);
ip = endpoint_info[0];
port = std::stoi(endpoint_info[1]);
rank = std::stoi(endpoint_info[2]);
}
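// A standalone round-trip check of the |ip:32|port:20|rank:12| layout from
// the comment above (plain integers in place of inet_addr; the fill order of
// the elided SerializeToUint64 lines is inferred from ParseFromUint64):
#include <cassert>
#include <cstdint>

uint64_t PackHost(uint32_t ip, uint32_t port, uint32_t rank) {
  return (uint64_t(ip) << 32) | (uint64_t(port & 0xFFFFF) << 12) |
         uint64_t(rank & 0xFFF);
}

void UnpackHost(uint64_t label, uint32_t *ip, uint32_t *port, uint32_t *rank) {
  *rank = label & ((1u << 12) - 1);          // low 12 bits
  *port = (label >> 12) & ((1u << 20) - 1);  // next 20 bits
  *ip = static_cast<uint32_t>(label >> 32);  // high 32 bits
}

int main() {
  uint32_t ip, port, rank;
  UnpackHost(PackHost(0x7F000001 /*127.0.0.1*/, 8200, 3), &ip, &port, &rank);
  assert(ip == 0x7F000001 && port == 8200 && rank == 3);
  return 0;
}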
void string_split(const std::string &str, char sep,
void StringSplit(const std::string &str, char sep,
std::vector<std::string> *pieces, bool ignore_null = true) {
pieces->clear();
if (str.empty()) {
......@@ -111,63 +111,60 @@ class PSEnvironment {
explicit PSEnvironment() {} // NOLINT
virtual ~PSEnvironment() {}
virtual int32_t set_ps_servers(uint64_t *host_sign_list, int node_num) {
virtual int32_t SetPsServers(uint64_t *host_sign_list, int node_num) {
return 0;
}
virtual int32_t set_ps_servers(
virtual int32_t SetPsServers(
const std::vector<std::string> *host_endpoint_list, int node_num) {
return 0;
}
virtual int32_t set_ps_clients(uint64_t *host_sign_list, int node_num) {
virtual int32_t SetPsClients(uint64_t *host_sign_list, int node_num) {
return 0;
}
virtual int32_t set_ps_clients(std::string *host_endpoint_list,
int node_num) {
virtual int32_t SetPsClients(std::string *host_endpoint_list, int node_num) {
return 0;
}
virtual uint64_t get_local_host_sign() { return 0; }
virtual std::vector<PSHost> get_ps_servers() const { return _ps_server_list; }
virtual int32_t registe_ps_server(const std::string &ip, uint32_t port,
virtual uint64_t GetLocalHostSign() { return 0; }
virtual std::vector<PSHost> GetPsServers() const { return _ps_server_list; }
virtual int32_t RegistePsServer(const std::string &ip, uint32_t port,
int32_t rank) {
return registe_ps_host(ip, port, rank, _ps_server_list,
_ps_server_sign_set);
return RegistePsHost(ip, port, rank, _ps_server_list, _ps_server_sign_set);
}
virtual std::vector<PSHost> get_ps_clients() const { return _ps_client_list; }
virtual int32_t registe_ps_client(const std::string &ip, uint32_t port,
virtual std::vector<PSHost> GetPsClients() const { return _ps_client_list; }
virtual int32_t RegistePsClient(const std::string &ip, uint32_t port,
int32_t rank) {
return registe_ps_host(ip, port, rank, _ps_client_list,
_ps_client_sign_set);
return RegistePsHost(ip, port, rank, _ps_client_list, _ps_client_sign_set);
}
virtual std::vector<uint64_t> get_client_info() {
virtual std::vector<uint64_t> GetClientInfo() {
std::vector<uint64_t> client_info;
for (auto &i : _ps_client_list) {
client_info.push_back(i.serialize_to_uint64());
client_info.push_back(i.SerializeToUint64());
}
return client_info;
}
virtual std::vector<std::string> get_client_info(bool use_string_endpoint) {
virtual std::vector<std::string> GetClientInfo(bool use_string_endpoint) {
if (use_string_endpoint) {
std::vector<std::string> client_info;
for (auto &i : _ps_client_list) {
client_info.push_back(i.serialize_to_string());
client_info.push_back(i.SerializeToString());
}
return client_info;
}
return {};
}
virtual void set_trainers(int trainers) { trainers_ = trainers; }
virtual void SetTrainers(int trainers) { trainers_ = trainers; }
virtual int get_trainers() { return trainers_; }
virtual int GetTrainers() { return trainers_; }
protected:
// register a host  // NOLINT
virtual int32_t registe_ps_host(
virtual int32_t RegistePsHost(
const std::string &ip, uint32_t port, int32_t rank,
std::vector<PSHost> &host_list, // NOLINT
std::unordered_set<uint64_t> &sign_set) { // NOLINT
......@@ -198,15 +195,15 @@ class PaddlePSEnvironment : public PSEnvironment {
explicit PaddlePSEnvironment() {} // NOLINT
virtual ~PaddlePSEnvironment() {}
virtual int32_t set_ps_servers(uint64_t *host_sign_list, int node_num) {
virtual int32_t SetPsServers(uint64_t *host_sign_list, int node_num) {
_ps_server_list.clear();
_ps_server_sign_set.clear();
for (int i = 0; i < node_num; ++i) {
if (host_sign_list[i] > 0) {
PSHost host;
host.parse_from_uint64(host_sign_list[i]);
host.ParseFromUint64(host_sign_list[i]);
_ps_server_list.push_back(host);
_ps_server_sign_set.insert(host.serialize_to_uint64());
_ps_server_sign_set.insert(host.SerializeToUint64());
}
}
std::sort(
......@@ -215,14 +212,14 @@ class PaddlePSEnvironment : public PSEnvironment {
return 0;
}
virtual int32_t set_ps_servers(const std::vector<std::string> *host_sign_list,
virtual int32_t SetPsServers(const std::vector<std::string> *host_sign_list,
int node_num) {
_ps_server_list.clear();
_ps_server_sign_set.clear();
for (int i = 0; i < node_num; ++i) {
if (host_sign_list->at(i) != "") {
PSHost host;
host.parse_from_string(host_sign_list->at(i));
host.ParseFromString(host_sign_list->at(i));
_ps_server_list.push_back(host);
_ps_server_sign_set.insert(host.rank);
}
......@@ -233,15 +230,15 @@ class PaddlePSEnvironment : public PSEnvironment {
return 0;
}
virtual int32_t set_ps_clients(uint64_t *host_sign_list, int node_num) {
virtual int32_t SetPsClients(uint64_t *host_sign_list, int node_num) {
_ps_client_list.clear();
_ps_client_sign_set.clear();
for (int i = 0; i < node_num; ++i) {
if (host_sign_list[i] > 0) {
PSHost host;
host.parse_from_uint64(host_sign_list[i]);
host.ParseFromUint64(host_sign_list[i]);
_ps_client_list.push_back(host);
_ps_client_sign_set.insert(host.serialize_to_uint64());
_ps_client_sign_set.insert(host.SerializeToUint64());
}
}
std::sort(
......@@ -250,14 +247,14 @@ class PaddlePSEnvironment : public PSEnvironment {
return 0;
}
virtual int32_t set_ps_clients(const std::vector<std::string> *host_sign_list,
virtual int32_t SetPsClients(const std::vector<std::string> *host_sign_list,
int node_num) {
_ps_client_list.clear();
_ps_client_sign_set.clear();
for (int i = 0; i < node_num; ++i) {
if (host_sign_list->at(i) != "") {
PSHost host;
host.parse_from_string(host_sign_list->at(i));
host.ParseFromString(host_sign_list->at(i));
_ps_client_list.push_back(host);
_ps_client_sign_set.insert(host.rank);
}
......@@ -269,9 +266,9 @@ class PaddlePSEnvironment : public PSEnvironment {
return 0;
}
virtual uint64_t get_local_host_sign() {
virtual uint64_t GetLocalHostSign() {
if (_ps_client_list.size() > 0) {
return _ps_client_list[0].serialize_to_uint64();
return _ps_client_list[0].SerializeToUint64();
} else {
return 0;
}
......
......@@ -135,8 +135,7 @@ std::future<int32_t> GraphBrpcClient::get_node_feat(
closure->request(request_idx)
->add_params(joint_feature_name.c_str(), joint_feature_name.size());
GraphPsService_Stub rpc_stub =
getServiceStub(get_cmd_channel(server_index));
GraphPsService_Stub rpc_stub = getServiceStub(GetCmdChannel(server_index));
closure->cntl(request_idx)->set_log_id(butil::gettimeofday_ms());
rpc_stub.service(closure->cntl(request_idx), closure->request(request_idx),
closure->response(request_idx), closure);
......@@ -169,8 +168,7 @@ std::future<int32_t> GraphBrpcClient::clear_nodes(uint32_t table_id) {
closure->request(server_index)->set_table_id(table_id);
closure->request(server_index)->set_client_id(_client_id);
GraphPsService_Stub rpc_stub =
getServiceStub(get_cmd_channel(server_index));
GraphPsService_Stub rpc_stub = getServiceStub(GetCmdChannel(server_index));
closure->cntl(server_index)->set_log_id(butil::gettimeofday_ms());
rpc_stub.service(closure->cntl(server_index),
closure->request(server_index),
......@@ -238,9 +236,8 @@ std::future<int32_t> GraphBrpcClient::add_graph_node(
->add_params((char *)weighted,
sizeof(bool) * is_weighted_bucket[request_idx].size());
}
// PsService_Stub rpc_stub(get_cmd_channel(server_index));
GraphPsService_Stub rpc_stub =
getServiceStub(get_cmd_channel(server_index));
// PsService_Stub rpc_stub(GetCmdChannel(server_index));
GraphPsService_Stub rpc_stub = getServiceStub(GetCmdChannel(server_index));
closure->cntl(request_idx)->set_log_id(butil::gettimeofday_ms());
rpc_stub.service(closure->cntl(request_idx), closure->request(request_idx),
closure->response(request_idx), closure);
......@@ -292,9 +289,8 @@ std::future<int32_t> GraphBrpcClient::remove_graph_node(
closure->request(request_idx)
->add_params((char *)request_bucket[request_idx].data(),
sizeof(int64_t) * node_num);
// PsService_Stub rpc_stub(get_cmd_channel(server_index));
GraphPsService_Stub rpc_stub =
getServiceStub(get_cmd_channel(server_index));
// PsService_Stub rpc_stub(GetCmdChannel(server_index));
GraphPsService_Stub rpc_stub = getServiceStub(GetCmdChannel(server_index));
closure->cntl(request_idx)->set_log_id(butil::gettimeofday_ms());
rpc_stub.service(closure->cntl(request_idx), closure->request(request_idx),
closure->response(request_idx), closure);
......@@ -362,9 +358,8 @@ std::future<int32_t> GraphBrpcClient::batch_sample_neighbors(
closure->request(0)->add_params((char *)&sample_size, sizeof(int));
closure->request(0)->add_params((char *)&need_weight, sizeof(bool));
// PsService_Stub rpc_stub(get_cmd_channel(server_index));
GraphPsService_Stub rpc_stub =
getServiceStub(get_cmd_channel(server_index));
// PsService_Stub rpc_stub(GetCmdChannel(server_index));
GraphPsService_Stub rpc_stub = getServiceStub(GetCmdChannel(server_index));
closure->cntl(0)->set_log_id(butil::gettimeofday_ms());
rpc_stub.service(closure->cntl(0), closure->request(0),
closure->response(0), closure);
......@@ -464,9 +459,8 @@ std::future<int32_t> GraphBrpcClient::batch_sample_neighbors(
->add_params((char *)&sample_size, sizeof(int));
closure->request(request_idx)
->add_params((char *)&need_weight, sizeof(bool));
// PsService_Stub rpc_stub(get_cmd_channel(server_index));
GraphPsService_Stub rpc_stub =
getServiceStub(get_cmd_channel(server_index));
// PsService_Stub rpc_stub(GetCmdChannel(server_index));
GraphPsService_Stub rpc_stub = getServiceStub(GetCmdChannel(server_index));
closure->cntl(request_idx)->set_log_id(butil::gettimeofday_ms());
rpc_stub.service(closure->cntl(request_idx), closure->request(request_idx),
closure->response(request_idx), closure);
......@@ -506,8 +500,8 @@ std::future<int32_t> GraphBrpcClient::random_sample_nodes(
closure->request(0)->set_client_id(_client_id);
closure->request(0)->add_params((char *)&sample_size, sizeof(int));
// PsService_Stub rpc_stub(get_cmd_channel(server_index));
GraphPsService_Stub rpc_stub = getServiceStub(get_cmd_channel(server_index));
// PsService_Stub rpc_stub(GetCmdChannel(server_index));
GraphPsService_Stub rpc_stub = getServiceStub(GetCmdChannel(server_index));
closure->cntl(0)->set_log_id(butil::gettimeofday_ms());
rpc_stub.service(closure->cntl(0), closure->request(0), closure->response(0),
closure);
......@@ -541,8 +535,7 @@ std::future<int32_t> GraphBrpcClient::load_graph_split_config(
closure->request(server_index)->set_table_id(table_id);
closure->request(server_index)->set_client_id(_client_id);
closure->request(server_index)->add_params(path);
GraphPsService_Stub rpc_stub =
getServiceStub(get_cmd_channel(server_index));
GraphPsService_Stub rpc_stub = getServiceStub(GetCmdChannel(server_index));
closure->cntl(server_index)->set_log_id(butil::gettimeofday_ms());
rpc_stub.service(closure->cntl(server_index),
closure->request(server_index),
......@@ -581,8 +574,7 @@ std::future<int32_t> GraphBrpcClient::use_neighbors_sample_cache(
closure->request(server_index)
->add_params((char *)&size_limit, sizeof(size_t));
closure->request(server_index)->add_params((char *)&ttl, sizeof(size_t));
GraphPsService_Stub rpc_stub =
getServiceStub(get_cmd_channel(server_index));
GraphPsService_Stub rpc_stub = getServiceStub(GetCmdChannel(server_index));
closure->cntl(server_index)->set_log_id(butil::gettimeofday_ms());
rpc_stub.service(closure->cntl(server_index),
closure->request(server_index),
......@@ -624,8 +616,8 @@ std::future<int32_t> GraphBrpcClient::pull_graph_list(
closure->request(0)->add_params((char *)&start, sizeof(int));
closure->request(0)->add_params((char *)&size, sizeof(int));
closure->request(0)->add_params((char *)&step, sizeof(int));
// PsService_Stub rpc_stub(get_cmd_channel(server_index));
GraphPsService_Stub rpc_stub = getServiceStub(get_cmd_channel(server_index));
// PsService_Stub rpc_stub(GetCmdChannel(server_index));
GraphPsService_Stub rpc_stub = getServiceStub(GetCmdChannel(server_index));
closure->cntl(0)->set_log_id(butil::gettimeofday_ms());
rpc_stub.service(closure->cntl(0), closure->request(0), closure->response(0),
closure);
......@@ -717,8 +709,7 @@ std::future<int32_t> GraphBrpcClient::set_node_feat(
closure->request(request_idx)
->add_params(set_feature.c_str(), set_feature.size());
GraphPsService_Stub rpc_stub =
getServiceStub(get_cmd_channel(server_index));
GraphPsService_Stub rpc_stub = getServiceStub(GetCmdChannel(server_index));
closure->cntl(request_idx)->set_log_id(butil::gettimeofday_ms());
rpc_stub.service(closure->cntl(request_idx), closure->request(request_idx),
closure->response(request_idx), closure);
......@@ -727,10 +718,10 @@ std::future<int32_t> GraphBrpcClient::set_node_feat(
return fut;
}
int32_t GraphBrpcClient::initialize() {
int32_t GraphBrpcClient::Initialize() {
// set_shard_num(_config.shard_num());
BrpcPsClient::initialize();
server_size = get_server_nums();
BrpcPsClient::Initialize();
server_size = GetServerNums();
graph_service = NULL;
local_channel = NULL;
return 0;
......
......@@ -97,12 +97,12 @@ class GraphBrpcClient : public BrpcPsClient {
std::string path);
virtual std::future<int32_t> remove_graph_node(
uint32_t table_id, std::vector<int64_t>& node_id_list);
virtual int32_t initialize();
virtual int32_t Initialize();
int get_shard_num() { return shard_num; }
void set_shard_num(int shard_num) { this->shard_num = shard_num; }
int get_server_index_by_id(int64_t id);
void set_local_channel(int index) {
this->local_channel = get_cmd_channel(index);
this->local_channel = GetCmdChannel(index);
}
void set_local_graph_service(GraphBrpcService* graph_service) {
this->graph_service = graph_service;
......
......@@ -33,7 +33,7 @@ namespace distributed {
return -1; \
}
int32_t GraphBrpcServer::initialize() {
int32_t GraphBrpcServer::Initialize() {
auto &service_config = _config.downpour_server_param().service_param();
if (!service_config.has_service_class()) {
LOG(ERROR) << "miss service_class in ServerServiceParameter";
......@@ -48,7 +48,7 @@ int32_t GraphBrpcServer::initialize() {
}
_service.reset(service);
if (service->configure(this) != 0 || service->initialize() != 0) {
if (service->Configure(this) != 0 || service->Initialize() != 0) {
LOG(ERROR) << "service initialize failed, service_name:"
<< service_config.service_class();
return -1;
......@@ -61,11 +61,11 @@ int32_t GraphBrpcServer::initialize() {
return 0;
}
brpc::Channel *GraphBrpcServer::get_cmd_channel(size_t server_index) {
brpc::Channel *GraphBrpcServer::GetCmdChannel(size_t server_index) {
return _pserver_channels[server_index].get();
}
uint64_t GraphBrpcServer::start(const std::string &ip, uint32_t port) {
uint64_t GraphBrpcServer::Start(const std::string &ip, uint32_t port) {
std::unique_lock<std::mutex> lock(mutex_);
std::string ip_port = ip + ":" + std::to_string(port);
......@@ -73,20 +73,20 @@ uint64_t GraphBrpcServer::start(const std::string &ip, uint32_t port) {
brpc::ServerOptions options;
int num_threads = std::thread::hardware_concurrency();
auto trainers = _environment->get_trainers();
auto trainers = _environment->GetTrainers();
options.num_threads = trainers > num_threads ? trainers : num_threads;
if (_server.Start(ip_port.c_str(), &options) != 0) {
LOG(ERROR) << "GraphBrpcServer start failed, ip_port=" << ip_port;
return 0;
}
_environment->registe_ps_server(ip, port, _rank);
_environment->RegistePsServer(ip, port, _rank);
return 0;
}
int32_t GraphBrpcServer::build_peer2peer_connection(int rank) {
this->rank = rank;
auto _env = environment();
auto _env = Environment();
brpc::ChannelOptions options;
options.protocol = "baidu_std";
options.timeout_ms = 500000;
......@@ -94,7 +94,7 @@ int32_t GraphBrpcServer::build_peer2peer_connection(int rank) {
options.connect_timeout_ms = 10000;
options.max_retry = 3;
std::vector<PSHost> server_list = _env->get_ps_servers();
std::vector<PSHost> server_list = _env->GetPsServers();
_pserver_channels.resize(server_list.size());
std::ostringstream os;
std::string server_ip_port;
......@@ -172,19 +172,18 @@ int32_t GraphBrpcService::remove_graph_node(Table *table,
((GraphTable *)table)->remove_graph_node(node_ids);
return 0;
}
int32_t GraphBrpcServer::port() { return _server.listen_address().port; }
int32_t GraphBrpcServer::Port() { return _server.listen_address().port; }
int32_t GraphBrpcService::initialize() {
int32_t GraphBrpcService::Initialize() {
_is_initialize_shard_info = false;
_service_handler_map[PS_STOP_SERVER] = &GraphBrpcService::stop_server;
_service_handler_map[PS_LOAD_ONE_TABLE] = &GraphBrpcService::load_one_table;
_service_handler_map[PS_LOAD_ALL_TABLE] = &GraphBrpcService::load_all_table;
_service_handler_map[PS_STOP_SERVER] = &GraphBrpcService::StopServer;
_service_handler_map[PS_LOAD_ONE_TABLE] = &GraphBrpcService::LoadOneTable;
_service_handler_map[PS_LOAD_ALL_TABLE] = &GraphBrpcService::LoadAllTable;
_service_handler_map[PS_PRINT_TABLE_STAT] =
&GraphBrpcService::print_table_stat;
_service_handler_map[PS_BARRIER] = &GraphBrpcService::barrier;
_service_handler_map[PS_START_PROFILER] = &GraphBrpcService::start_profiler;
_service_handler_map[PS_STOP_PROFILER] = &GraphBrpcService::stop_profiler;
_service_handler_map[PS_PRINT_TABLE_STAT] = &GraphBrpcService::PrintTableStat;
_service_handler_map[PS_BARRIER] = &GraphBrpcService::Barrier;
_service_handler_map[PS_START_PROFILER] = &GraphBrpcService::StartProfiler;
_service_handler_map[PS_STOP_PROFILER] = &GraphBrpcService::StopProfiler;
_service_handler_map[PS_PULL_GRAPH_LIST] = &GraphBrpcService::pull_graph_list;
_service_handler_map[PS_GRAPH_SAMPLE_NEIGHBORS] =
......@@ -207,21 +206,21 @@ int32_t GraphBrpcService::initialize() {
_service_handler_map[PS_GRAPH_LOAD_GRAPH_SPLIT_CONFIG] =
&GraphBrpcService::load_graph_split_config;
// shard init: the shard info of server_list only becomes available from env after the server has started
initialize_shard_info();
InitializeShardInfo();
return 0;
}
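// The handler map above dispatches a request's cmd_id to a pointer-to-member
// function. A standalone illustration of that idiom (made-up class and ids,
// not the Paddle service):
#include <cstdint>
#include <unordered_map>

class MiniService {
 public:
  MiniService() {
    handlers_[1] = &MiniService::StopServer;
    handlers_[2] = &MiniService::PrintTableStat;
  }
  int32_t Dispatch(int cmd_id) {
    auto itr = handlers_.find(cmd_id);
    if (itr == handlers_.end()) return -1;  // unknown cmd_id
    return (this->*(itr->second))();        // invoke the member handler
  }

 private:
  typedef int32_t (MiniService::*HandlerFunc)();
  int32_t StopServer() { return 0; }
  int32_t PrintTableStat() { return 0; }
  std::unordered_map<int, HandlerFunc> handlers_;
};

int main() { return MiniService().Dispatch(1); }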
int32_t GraphBrpcService::initialize_shard_info() {
int32_t GraphBrpcService::InitializeShardInfo() {
if (!_is_initialize_shard_info) {
std::lock_guard<std::mutex> guard(_initialize_shard_mutex);
if (_is_initialize_shard_info) {
return 0;
}
server_size = _server->environment()->get_ps_servers().size();
auto &table_map = *(_server->table());
server_size = _server->Environment()->GetPsServers().size();
auto &table_map = *(_server->GetTable());
for (auto itr : table_map) {
itr.second->set_shard(_rank, server_size);
itr.second->SetShard(_rank, server_size);
}
_is_initialize_shard_info = true;
}
......@@ -241,7 +240,7 @@ void GraphBrpcService::service(google::protobuf::RpcController *cntl_base,
response->set_err_code(0);
response->set_err_msg("");
auto *table = _server->table(request->table_id());
auto *table = _server->GetTable(request->table_id());
brpc::Controller *cntl = static_cast<brpc::Controller *>(cntl_base);
auto itr = _service_handler_map.find(request->cmd_id());
if (itr == _service_handler_map.end()) {
......@@ -261,7 +260,7 @@ void GraphBrpcService::service(google::protobuf::RpcController *cntl_base,
}
}
int32_t GraphBrpcService::barrier(Table *table, const PsRequestMessage &request,
int32_t GraphBrpcService::Barrier(Table *table, const PsRequestMessage &request,
PsResponseMessage &response,
brpc::Controller *cntl) {
CHECK_TABLE_EXIST(table, request, response)
......@@ -275,16 +274,16 @@ int32_t GraphBrpcService::barrier(Table *table, const PsRequestMessage &request,
auto trainer_id = request.client_id();
auto barrier_type = request.params(0);
table->barrier(trainer_id, barrier_type);
table->Barrier(trainer_id, barrier_type);
return 0;
}
int32_t GraphBrpcService::print_table_stat(Table *table,
int32_t GraphBrpcService::PrintTableStat(Table *table,
const PsRequestMessage &request,
PsResponseMessage &response,
brpc::Controller *cntl) {
CHECK_TABLE_EXIST(table, request, response)
std::pair<int64_t, int64_t> ret = table->print_table_stat();
std::pair<int64_t, int64_t> ret = table->PrintTableStat();
paddle::framework::BinaryArchive ar;
ar << ret.first << ret.second;
std::string table_info(ar.Buffer(), ar.Length());
......@@ -293,7 +292,7 @@ int32_t GraphBrpcService::print_table_stat(Table *table,
return 0;
}
int32_t GraphBrpcService::load_one_table(Table *table,
int32_t GraphBrpcService::LoadOneTable(Table *table,
const PsRequestMessage &request,
PsResponseMessage &response,
brpc::Controller *cntl) {
......@@ -304,20 +303,20 @@ int32_t GraphBrpcService::load_one_table(Table *table,
"PsRequestMessage.datas is requeired at least 2 for path & load_param");
return -1;
}
if (table->load(request.params(0), request.params(1)) != 0) {
if (table->Load(request.params(0), request.params(1)) != 0) {
set_response_code(response, -1, "table load failed");
return -1;
}
return 0;
}
int32_t GraphBrpcService::load_all_table(Table *table,
int32_t GraphBrpcService::LoadAllTable(Table *table,
const PsRequestMessage &request,
PsResponseMessage &response,
brpc::Controller *cntl) {
auto &table_map = *(_server->table());
auto &table_map = *(_server->GetTable());
for (auto &itr : table_map) {
if (load_one_table(itr.second.get(), request, response, cntl) != 0) {
if (LoadOneTable(itr.second.get(), request, response, cntl) != 0) {
LOG(ERROR) << "load table[" << itr.first << "] failed";
return -1;
}
......@@ -325,13 +324,13 @@ int32_t GraphBrpcService::load_all_table(Table *table,
return 0;
}
int32_t GraphBrpcService::stop_server(Table *table,
int32_t GraphBrpcService::StopServer(Table *table,
const PsRequestMessage &request,
PsResponseMessage &response,
brpc::Controller *cntl) {
GraphBrpcServer *p_server = (GraphBrpcServer *)_server;
std::thread t_stop([p_server]() {
p_server->stop();
p_server->Stop();
LOG(INFO) << "Server Stoped";
});
p_server->export_cv()->notify_all();
......@@ -339,7 +338,7 @@ int32_t GraphBrpcService::stop_server(Table *table,
return 0;
}
int32_t GraphBrpcService::stop_profiler(Table *table,
int32_t GraphBrpcService::StopProfiler(Table *table,
const PsRequestMessage &request,
PsResponseMessage &response,
brpc::Controller *cntl) {
......@@ -348,7 +347,7 @@ int32_t GraphBrpcService::stop_profiler(Table *table,
return 0;
}
int32_t GraphBrpcService::start_profiler(Table *table,
int32_t GraphBrpcService::StartProfiler(Table *table,
const PsRequestMessage &request,
PsResponseMessage &response,
brpc::Controller *cntl) {
......@@ -475,7 +474,7 @@ int32_t GraphBrpcService::sample_neighbors_across_multi_servers(
std::vector<int> server2request(server_size, -1);
std::vector<int64_t> local_id;
std::vector<int> local_query_idx;
size_t rank = get_rank();
size_t rank = GetRank();
for (int query_idx = 0; query_idx < node_num; ++query_idx) {
int server_index =
((GraphTable *)table)->get_server_index_by_id(node_data[query_idx]);
......@@ -589,9 +588,9 @@ int32_t GraphBrpcService::sample_neighbors_across_multi_servers(
closure->request(request_idx)
->add_params((char *)&need_weight, sizeof(bool));
PsService_Stub rpc_stub(
((GraphBrpcServer *)get_server())->get_cmd_channel(server_index));
((GraphBrpcServer *)GetServer())->GetCmdChannel(server_index));
// GraphPsService_Stub rpc_stub =
// getServiceStub(get_cmd_channel(server_index));
// getServiceStub(GetCmdChannel(server_index));
closure->cntl(request_idx)->set_log_id(butil::gettimeofday_ms());
rpc_stub.service(closure->cntl(request_idx), closure->request(request_idx),
closure->response(request_idx), closure);
......
......@@ -31,10 +31,10 @@ class GraphBrpcServer : public PSServer {
GraphBrpcServer() {}
virtual ~GraphBrpcServer() {}
PsBaseService *get_service() { return _service.get(); }
virtual uint64_t start(const std::string &ip, uint32_t port);
virtual uint64_t Start(const std::string &ip, uint32_t port);
virtual int32_t build_peer2peer_connection(int rank);
virtual brpc::Channel *get_cmd_channel(size_t server_index);
virtual int32_t stop() {
virtual brpc::Channel *GetCmdChannel(size_t server_index);
virtual int32_t Stop() {
std::unique_lock<std::mutex> lock(mutex_);
if (stoped_) return 0;
stoped_ = true;
......@@ -43,12 +43,12 @@ class GraphBrpcServer : public PSServer {
_server.Join();
return 0;
}
int32_t port();
int32_t Port();
std::condition_variable *export_cv() { return &cv_; }
private:
virtual int32_t initialize();
virtual int32_t Initialize();
mutable std::mutex mutex_;
std::condition_variable cv_;
bool stoped_ = false;
......@@ -66,7 +66,7 @@ typedef int32_t (GraphBrpcService::*serviceFunc)(
class GraphBrpcService : public PsBaseService {
public:
virtual int32_t initialize() override;
virtual int32_t Initialize() override;
virtual void service(::google::protobuf::RpcController *controller,
const PsRequestMessage *request,
......@@ -75,7 +75,7 @@ class GraphBrpcService : public PsBaseService {
protected:
std::unordered_map<int32_t, serviceFunc> _service_handler_map;
int32_t initialize_shard_info();
int32_t InitializeShardInfo();
int32_t pull_graph_list(Table *table, const PsRequestMessage &request,
PsResponseMessage &response, brpc::Controller *cntl);
int32_t graph_random_sample_neighbors(Table *table,
......@@ -100,20 +100,20 @@ class GraphBrpcService : public PsBaseService {
int32_t remove_graph_node(Table *table, const PsRequestMessage &request,
PsResponseMessage &response,
brpc::Controller *cntl);
int32_t barrier(Table *table, const PsRequestMessage &request,
int32_t Barrier(Table *table, const PsRequestMessage &request,
PsResponseMessage &response, brpc::Controller *cntl);
int32_t load_one_table(Table *table, const PsRequestMessage &request,
int32_t LoadOneTable(Table *table, const PsRequestMessage &request,
PsResponseMessage &response, brpc::Controller *cntl);
int32_t load_all_table(Table *table, const PsRequestMessage &request,
int32_t LoadAllTable(Table *table, const PsRequestMessage &request,
PsResponseMessage &response, brpc::Controller *cntl);
int32_t stop_server(Table *table, const PsRequestMessage &request,
int32_t StopServer(Table *table, const PsRequestMessage &request,
PsResponseMessage &response, brpc::Controller *cntl);
int32_t start_profiler(Table *table, const PsRequestMessage &request,
int32_t StartProfiler(Table *table, const PsRequestMessage &request,
PsResponseMessage &response, brpc::Controller *cntl);
int32_t stop_profiler(Table *table, const PsRequestMessage &request,
int32_t StopProfiler(Table *table, const PsRequestMessage &request,
PsResponseMessage &response, brpc::Controller *cntl);
int32_t print_table_stat(Table *table, const PsRequestMessage &request,
int32_t PrintTableStat(Table *table, const PsRequestMessage &request,
PsResponseMessage &response, brpc::Controller *cntl);
int32_t sample_neighbors_across_multi_servers(Table *table,
......
......@@ -25,7 +25,7 @@ REGISTER_PSCORE_CLASS(PSClient, BrpcPsClient);
REGISTER_PSCORE_CLASS(PSClient, PsLocalClient);
REGISTER_PSCORE_CLASS(PSClient, GraphBrpcClient);
int32_t PSClient::configure(
int32_t PSClient::Configure(
const PSParameter &config,
const std::map<uint64_t, std::vector<paddle::distributed::Region>> &regions,
PSEnvironment &env, size_t client_id) {
......@@ -51,10 +51,10 @@ int32_t PSClient::configure(
_table_accessors[work_param.downpour_table_param(i).table_id()].reset(
accessor);
}
return initialize();
return Initialize();
}
PSClient *PSClientFactory::create(const PSParameter &ps_config) {
PSClient *PSClientFactory::Create(const PSParameter &ps_config) {
const auto &config = ps_config.server_param();
if (!config.has_downpour_server_param()) {
LOG(ERROR) << "miss downpour_server_param in ServerParameter";
......@@ -81,7 +81,7 @@ PSClient *PSClientFactory::create(const PSParameter &ps_config) {
return NULL;
}
TableManager::instance().initialize();
TableManager::Instance().Initialize();
VLOG(3) << "Create PSClient[" << service_param.client_class() << "] success";
return client;
}
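// Hedged usage sketch of the factory path above: Create() instantiates the
// concrete class named in service_param().client_class() (one of the
// REGISTER_PSCORE_CLASS entries), and Configure() then runs the subclass
// Initialize(). The header path and the empty-regions shortcut are
// assumptions of this illustration.
#include <cstdint>
#include <map>
#include <memory>
#include <vector>

#include "paddle/fluid/distributed/ps/service/ps_client.h"  // assumed path

std::unique_ptr<paddle::distributed::PSClient> MakeClient(
    const paddle::distributed::PSParameter &ps_param,
    paddle::distributed::PSEnvironment &env, size_t client_id) {
  std::unique_ptr<paddle::distributed::PSClient> client(
      paddle::distributed::PSClientFactory::Create(ps_param));
  if (client != nullptr) {
    std::map<uint64_t, std::vector<paddle::distributed::Region>> regions;
    client->Configure(ps_param, regions, env, client_id);  // runs Initialize()
  }
  return client;
}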
......
......@@ -26,7 +26,6 @@
#include "paddle/fluid/distributed/ps/service/sendrecv.pb.h"
#include "paddle/fluid/distributed/ps/table/accessor.h"
#include "paddle/fluid/distributed/ps/table/graph/graph_node.h"
#include "paddle/fluid/distributed/ps/table/table.h"
#include "paddle/fluid/platform/timer.h"
namespace paddle {
......@@ -60,41 +59,6 @@ class PSClientClosure : public google::protobuf::Closure {
std::vector<std::shared_ptr<std::promise<int32_t>>> _promises;
};
struct LoadSaveContext {
int table_id;
std::string epoch;
std::string mode;
};
enum TrainingMode { Async = 0, Sync = 1, Geo = 3 };
enum TrainingPhase { Init = 0, Train = 1, Save = 2 };
// enum ValueType {
// Sparse = 0,
// Dense = 1
// };
struct PushContext {
const uint64_t *keys;
const float **push_values;
const Region *push_dense_values;
};
struct RequestContext {
int table;
TrainingMode training_mode;    // Async = 0, Sync = 1, Geo = 3 (see enum above)
TrainingPhase training_phase;  // Init = 0, Train = 1, Save = 2
ValueType value_type;          // Sparse = 0, Dense = 1
uint64_t *keys;
float **sparse_values; // for sparse values
Region *dense_values; // for dense values
PushContext push_context;
size_t num;
bool is_training;
void *callback;
};
class PSClient {
public:
PSClient() {}
......@@ -102,41 +66,37 @@ class PSClient {
PSClient(PSClient &&) = delete;
PSClient(const PSClient &) = delete;
virtual int32_t configure( // NOLINT
virtual int32_t Configure( // NOLINT
const PSParameter &config,
const std::map<uint64_t, std::vector<paddle::distributed::Region>>
&regions,
PSEnvironment &_env, size_t client_id) final; // NOLINT
virtual int32_t create_client2client_connection(
int pserver_timeout_ms, int pserver_connect_timeout_ms,
virtual int32_t CreateClient2ClientConnection(int pserver_timeout_ms,
int pserver_connect_timeout_ms,
int max_retry) = 0;
// trigger data eviction (shrink) on the table
virtual std::future<int32_t> shrink(uint32_t table_id,
virtual std::future<int32_t> Shrink(uint32_t table_id,
const std::string threshold) = 0;
// load data for all tables
virtual std::future<int32_t> load(const std::string &epoch,
virtual std::future<int32_t> Load(const std::string &epoch,
const std::string &mode) = 0;
// load data for one given table
virtual std::future<int32_t> load(uint32_t table_id, const std::string &epoch,
virtual std::future<int32_t> Load(uint32_t table_id, const std::string &epoch,
const std::string &mode) = 0;
// load options are configured through the context
virtual std::future<int32_t> Load(const LoadSaveContext &load_context) = 0;
// save all tables; depending on mode, value_accessor may apply different save conditions
virtual std::future<int32_t> save(const std::string &epoch,
virtual std::future<int32_t> Save(const std::string &epoch,
const std::string &mode) = 0;
// save one given table; depending on mode, value_accessor may apply different save conditions
virtual std::future<int32_t> save(uint32_t table_id, const std::string &epoch,
virtual std::future<int32_t> Save(uint32_t table_id, const std::string &epoch,
const std::string &mode) = 0;
virtual std::future<int32_t> Save(const LoadSaveContext &save_context) = 0;
// clear table data
virtual std::future<int32_t> clear() = 0;
virtual std::future<int32_t> clear(uint32_t table_id) = 0;
virtual std::future<int32_t> Clear() = 0;
virtual std::future<int32_t> Clear(uint32_t table_id) = 0;
// pull the dense part of the parameters and fill it block-wise into the local network parameters
// start and num allow pulling only part of the parameters
......@@ -145,39 +105,34 @@ class PSClient {
// the sender aggregates requests that hit the same block, accumulating multiple buffers to fill
// the server extracts and returns the configured dimension of the parameter block
// after the returned data is unpacked, it is filled into the accumulated buffers
virtual std::future<int32_t> pull_dense(Region *regions, size_t region_num,
virtual std::future<int32_t> PullDense(Region *regions, size_t region_num,
size_t table_id) = 0;  // reserved
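// Hedged sketch of a dense pull: the caller supplies Region descriptors over
// its own buffers and blocks on the returned future. Assumption of this
// illustration: Region can be built from a raw char* plus a byte length, as
// its use around tensor buffers in the Communicator suggests.
#include <vector>

#include "paddle/fluid/distributed/ps/service/ps_client.h"  // assumed path

void PullDenseExample(paddle::distributed::PSClient *client, size_t table_id,
                      size_t dense_dim) {
  std::vector<float> buffer(dense_dim, 0.0f);
  paddle::distributed::Region region(
      reinterpret_cast<char *>(buffer.data()),  // assumed (char*, size) ctor
      buffer.size() * sizeof(float));
  auto status = client->PullDense(&region, /*region_num=*/1, table_id);
  status.wait();  // buffer is filled once the future resolves
}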
virtual std::future<int32_t> Push(RequestContext &push_context) = 0;
// first push the dense params to the parameter server;
// this is necessary because dense weights are initialized in the trainer on
// cold start
virtual std::future<int32_t> push_dense_param(const Region *regions,
virtual std::future<int32_t> PushDenseParam(const Region *regions,
size_t region_num,
size_t table_id) = 0;
virtual std::future<int32_t> push_dense(const Region *regions,
virtual std::future<int32_t> PushDense(const Region *regions,
size_t region_num,
size_t table_id) = 0;
virtual std::future<int32_t> Pull(RequestContext &pull_context) = 0;
// issue a pull request with keys and fill the results into values
// keys and values both hold num entries; each value occupies select_size space
// the keys and values buffers must not be reused before the future completes
// keys requested by multiple threads are merged, aggregated, and scattered to the servers
// once results return, the buffers are traversed and values are assigned
// is_training distinguishes training from inference requests; the server handles features and admission differently for each.
virtual std::future<int32_t> pull_sparse(float **select_values,
size_t table_id,
const uint64_t *keys, size_t num,
bool is_training) = 0;
virtual std::future<int32_t> PullSparse(float **select_values,
size_t table_id, const uint64_t *keys,
size_t num, bool is_training) = 0;
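// Hedged usage sketch of PullSparse: per the comments above, the caller owns
// the key array and one float buffer per key, and may reuse neither until the
// future resolves. dim = 8 is an invented figure; the real per-key dim comes
// from the table accessor (see SELECT_DIM in the pserver handler).
#include <cstdint>
#include <vector>

#include "paddle/fluid/distributed/ps/service/ps_client.h"  // assumed path

void PullSparseExample(paddle::distributed::PSClient *client,
                       size_t table_id) {
  std::vector<uint64_t> keys = {1, 2, 42};
  const size_t dim = 8;
  std::vector<float> storage(keys.size() * dim);
  std::vector<float *> value_ptrs;
  for (size_t i = 0; i < keys.size(); ++i) {
    value_ptrs.push_back(storage.data() + i * dim);
  }
  auto status = client->PullSparse(value_ptrs.data(), table_id, keys.data(),
                                   keys.size(), /*is_training=*/true);
  status.wait();  // only now may keys/storage be reused
}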
virtual std::future<int32_t> pull_sparse_param(float **select_values,
virtual std::future<int32_t> PullSparseParam(float **select_values,
size_t table_id,
const uint64_t *keys,
size_t num, bool is_training) {
const uint64_t *keys, size_t num,
bool is_training) {
VLOG(0) << "Did not implement";
std::promise<int32_t> promise;
std::future<int> fut = promise.get_future();
......@@ -185,7 +140,7 @@ class PSClient {
return fut;
}
virtual ::std::future<int32_t> pull_sparse_ptr(char **select_values,
virtual ::std::future<int32_t> PullSparsePtr(char **select_values,
size_t table_id,
const uint64_t *keys,
size_t num) {
......@@ -196,36 +151,36 @@ class PSClient {
return fut;
}
virtual std::future<int32_t> print_table_stat(uint32_t table_id) = 0;
virtual std::future<int32_t> PrintTableStat(uint32_t table_id) = 0;
// make sure every accumulated pending request is actually sent
virtual std::future<int32_t> flush() = 0;
virtual std::future<int32_t> Flush() = 0;
// graceful server shutdown
virtual std::future<int32_t> stop_server() = 0;
virtual std::future<int32_t> StopServer() = 0;
// server profiler
virtual std::future<int32_t> start_profiler() = 0;
virtual std::future<int32_t> stop_profiler() = 0;
virtual std::future<int32_t> StartProfiler() = 0;
virtual std::future<int32_t> StopProfiler() = 0;
virtual std::future<int32_t> barrier(size_t table_id,
virtual std::future<int32_t> Barrier(size_t table_id,
uint32_t barrier_type) = 0;
virtual std::future<int32_t> pull_geo_param(size_t table_id,
virtual std::future<int32_t> PullGeoParam(size_t table_id,
std::vector<float> *values,
std::vector<uint64_t> *keys,
int pserver_idx) = 0;
virtual std::future<int32_t> push_global_step(int table_id,
virtual std::future<int32_t> PushGlobalStep(int table_id,
int64_t *total_send_data,
void *done) = 0;
// recv table from server and save it in LodTensor
virtual int32_t recv_and_save_table(const uint64_t table_id,
virtual int32_t RecvAndSaveTable(const uint64_t table_id,
const std::string &path) = 0;
virtual void finalize_worker() = 0;
virtual void FinalizeWorker() = 0;
// client-to-client message sending
virtual std::future<int32_t> send_client2client_msg(int msg_type,
virtual std::future<int32_t> SendClient2ClientMsg(int msg_type,
int to_client_id,
const std::string &msg) {
VLOG(0) << "Did not implement";
......@@ -238,12 +193,12 @@ class PSClient {
// client2client message handling: std::function<int32_t(int, int, const std::string&)>
// -> ret(msg_type, from_client_id, msg)
typedef std::function<int32_t(int, int, const std::string &)> MsgHandlerFunc;
virtual int registe_client2client_msg_handler(int msg_type,
virtual int RegisteClient2ClientMsgHandler(int msg_type,
MsgHandlerFunc handler) {
_msg_handler_map[msg_type] = handler;
return 0;
}
virtual int handle_client2client_msg(int msg_type, int from_client_id,
virtual int HandleClient2ClientMsg(int msg_type, int from_client_id,
const std::string &msg) {
auto itr = _msg_handler_map.find(msg_type);
if (itr == _msg_handler_map.end()) {
......@@ -253,7 +208,7 @@ class PSClient {
return itr->second(msg_type, from_client_id, msg);
}
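// Hedged sketch of wiring both ends of the client<->client path: register a
// handler for an application-chosen msg_type on every client, then send from
// one client to another; the receiving side dispatches into
// HandleClient2ClientMsg above. kDemoMsgType is a made-up constant.
#include <cstdint>
#include <iostream>
#include <string>

#include "paddle/fluid/distributed/ps/service/ps_client.h"  // assumed path

constexpr int kDemoMsgType = 100;  // assumed application-defined type

void WireUpClient2Client(paddle::distributed::PSClient *client,
                         int peer_client_id) {
  client->RegisteClient2ClientMsgHandler(
      kDemoMsgType,
      [](int msg_type, int from_client_id, const std::string &msg) -> int32_t {
        std::cout << "msg " << msg_type << " from client " << from_client_id
                  << ": " << msg << std::endl;
        return 0;
      });
  client->SendClient2ClientMsg(kDemoMsgType, peer_client_id, "hello");
}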
virtual ValueAccessor *table_accessor(size_t table_id) {
virtual ValueAccessor *GetTableAccessor(size_t table_id) {
auto itr = _table_accessors.find(table_id);
if (itr == _table_accessors.end()) {
return NULL;
......@@ -261,31 +216,31 @@ class PSClient {
return itr->second.get();
}
virtual size_t get_server_nums() = 0;
virtual size_t GetServerNums() = 0;
virtual std::future<int32_t> push_dense_raw_gradient(
int table_id, float *total_send_data, size_t total_send_data_size,
virtual std::future<int32_t> PushDenseRawGradient(int table_id,
float *total_send_data,
size_t total_send_data_size,
void *done) = 0;
virtual std::future<int32_t> push_sparse_raw_gradient(
virtual std::future<int32_t> PushSparseRawGradient(
size_t table_id, const uint64_t *keys, const float **update_values,
size_t num, void *done) = 0;
virtual std::future<int32_t> push_sparse_raw_gradient_partial(
virtual std::future<int32_t> PushSparseRawGradientPartial(
size_t table_id, const uint64_t *keys, const float **update_values,
uint32_t num, void *done, int pserver_idx) = 0;
virtual std::future<int32_t> push_sparse_param(size_t table_id,
virtual std::future<int32_t> PushSparseParam(size_t table_id,
const uint64_t *keys,
const float **update_values,
size_t num, void *done) = 0;
virtual std::future<int32_t> push_sparse(size_t table_id,
const uint64_t *keys,
virtual std::future<int32_t> PushSparse(size_t table_id, const uint64_t *keys,
const float **update_values,
size_t num) = 0;
protected:
virtual int32_t initialize() = 0;
virtual int32_t Initialize() = 0;
size_t _client_id;
PSParameter _config;
std::map<uint64_t, std::vector<paddle::distributed::Region>>
......@@ -333,7 +288,7 @@ REGISTER_PSCORE_REGISTERER(PSClient);
class PSClientFactory {
public:
static PSClient *create(const PSParameter &config);
static PSClient *Create(const PSParameter &config);
};
} // namespace distributed
} // namespace paddle
......@@ -19,166 +19,91 @@
namespace paddle {
namespace distributed {
int32_t PsLocalClient::initialize() {
int32_t PsLocalClient::Initialize() {
const auto& downpour_param = _config.server_param().downpour_server_param();
TableManager::instance().initialize();
TableManager::Instance().Initialize();
for (size_t i = 0; i < downpour_param.downpour_table_param_size(); ++i) {
auto* table = CREATE_PSCORE_CLASS(
Table, downpour_param.downpour_table_param(i).table_class());
table->set_shard(0, 1);
table->initialize(downpour_param.downpour_table_param(i),
table->SetShard(0, 1);
table->Initialize(downpour_param.downpour_table_param(i),
_config.fs_client_param());
_table_map[downpour_param.downpour_table_param(i).table_id()].reset(table);
}
return 0;
}
::std::future<int32_t> PsLocalClient::shrink(uint32_t table_id,
::std::future<int32_t> PsLocalClient::Shrink(uint32_t table_id,
const std::string threshold) {
// TODO
return done();
}
::std::future<int32_t> PsLocalClient::load(const std::string& epoch,
::std::future<int32_t> PsLocalClient::Load(const std::string& epoch,
const std::string& mode) {
// TODO
for (auto& it : _table_map) {
load(it.first, epoch, mode);
Load(it.first, epoch, mode);
}
return done();
}
::std::future<int32_t> PsLocalClient::load(uint32_t table_id,
::std::future<int32_t> PsLocalClient::Load(uint32_t table_id,
const std::string& epoch,
const std::string& mode) {
// TODO
auto* table_ptr = table(table_id);
table_ptr->load(epoch, mode);
auto* table_ptr = GetTable(table_id);
table_ptr->Load(epoch, mode);
return done();
}
std::future<int32_t> PsLocalClient::Load(const LoadSaveContext& load_context) {
if (load_context.table_id < 0) {
for (auto& it : _table_map) {
load(it.first, load_context.epoch, load_context.mode);
}
return done();
} else {
auto* table_ptr = table(load_context.table_id);
table_ptr->load(load_context.epoch, load_context.mode);
return done();
}
}
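A minimal sketch of how the context-based overload above is driven, assuming LoadSaveContext carries only the three fields this function reads (table_id, epoch, mode); a negative table_id fans out to every registered table:

    LoadSaveContext ctx;
    ctx.table_id = -1;         // < 0: load all tables in _table_map
    ctx.epoch = "0";           // checkpoint tag, forwarded to Table::Load
    ctx.mode = "0";            // load mode, forwarded to Table::Load
    client->Load(ctx).wait();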
::std::future<int32_t> PsLocalClient::save(const std::string& epoch,
::std::future<int32_t> PsLocalClient::Save(const std::string& epoch,
const std::string& mode) {
// TODO
for (auto& it : _table_map) {
save(it.first, epoch, mode);
Save(it.first, epoch, mode);
}
return done();
}
::std::future<int32_t> PsLocalClient::save(uint32_t table_id,
::std::future<int32_t> PsLocalClient::Save(uint32_t table_id,
const std::string& epoch,
const std::string& mode) {
// TODO
auto* table_ptr = table(table_id);
table_ptr->flush();
table_ptr->save(epoch, mode);
auto* table_ptr = GetTable(table_id);
table_ptr->Flush();
table_ptr->Save(epoch, mode);
return done();
}
::std::future<int32_t> PsLocalClient::Save(
const LoadSaveContext& save_context) {
if (save_context.table_id < 0) {
for (auto& it : _table_map) {
save(it.first, save_context.epoch, save_context.mode);
}
return done();
} else {
auto* table_ptr = table(save_context.table_id);
table_ptr->flush();
table_ptr->save(save_context.epoch, save_context.mode);
return done();
}
}
::std::future<int32_t> PsLocalClient::clear() {
::std::future<int32_t> PsLocalClient::Clear() {
// TODO
return done();
}
::std::future<int32_t> PsLocalClient::clear(uint32_t table_id) {
::std::future<int32_t> PsLocalClient::Clear(uint32_t table_id) {
// TODO
return done();
}
::std::future<int32_t> PsLocalClient::flush() {
::std::future<int32_t> PsLocalClient::Flush() {
// no need
return done();
}
::std::future<int32_t> PsLocalClient::stop_server() {
::std::future<int32_t> PsLocalClient::StopServer() {
// no need
return done();
}
::std::future<int32_t> PsLocalClient::Pull(RequestContext& pull_context) {
if (pull_context.value_type == Dense) { // pull dense
Region* dense_region = reinterpret_cast<Region*>(pull_context.dense_values);
pull_dense(dense_region, pull_context.num, pull_context.table);
} else { // pull sparse
// uint64_t* keys = reinterpret_cast<uint64_t*>(pull_context.keys);
// char** select_values =
// reinterpret_cast<char**>(pull_context.sparse_values);
size_t table_id = pull_context.table;
size_t num = pull_context.num;
pull_sparse_ptr(reinterpret_cast<char**>(pull_context.sparse_values),
table_id, pull_context.keys, num);
}
  return done();
}
::std::future<int32_t> PsLocalClient::Push(RequestContext& push_context) {
if (push_context.value_type == Dense) { // push dense
if (push_context.training_phase == Init) {
const Region* regions = push_context.push_context.push_dense_values;
size_t region_num = push_context.num;
push_dense_param(regions, region_num, push_context.table);
} else {
if (push_context.training_mode == Geo) { // geo
float* total_send_data =
reinterpret_cast<float*>(push_context.dense_values);
size_t total_send_data_size = push_context.num;
push_dense_raw_gradient(push_context.table, total_send_data,
total_send_data_size, push_context.callback);
} else { // async and sync
const Region* regions = push_context.push_context.push_dense_values;
size_t region_num = push_context.num;
push_dense(regions, region_num, push_context.table);
}
}
} else { // push sparse
if (push_context.training_mode == Async) {
const uint64_t* keys = push_context.push_context.keys;
const float** update_values = push_context.push_context.push_values;
size_t table_id = push_context.table;
size_t num = push_context.num;
push_sparse(table_id, keys, update_values, num);
} else {
// TODO
}
}
  return done();
}
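To make the routing above concrete, a hedged sketch of a dense asynchronous push; only the fields this dispatcher actually reads are set, and the full RequestContext layout is not shown in this diff:

    RequestContext ctx;
    ctx.value_type = Dense;        // take the dense branch
    ctx.training_mode = Async;     // non-Geo: plain dense-push path
    // ctx.training_phase is assumed to be any non-Init phase here.
    ctx.table = table_id;
    ctx.num = region_num;
    ctx.push_context.push_dense_values = regions;  // const Region*
    client->Push(ctx);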
::std::future<int32_t> PsLocalClient::pull_dense(Region* regions,
::std::future<int32_t> PsLocalClient::PullDense(Region* regions,
size_t region_num,
size_t table_id) {
auto* accessor = table_accessor(table_id);
auto* table_ptr = table(table_id);
auto* accessor = GetTableAccessor(table_id);
auto* table_ptr = GetTable(table_id);
uint32_t num_per_shard = DenseDimPerShard(accessor->GetTableInfo(FEA_DIM), 1);
uint32_t num_per_shard =
dense_dim_per_shard(accessor->GetTableInfo(FEA_DIM), 1);
std::vector<float> region_buffer;
region_buffer.resize(num_per_shard);
table_ptr->pull_dense(region_buffer.data(), region_buffer.size());
table_ptr->PullDense(region_buffer.data(), region_buffer.size());
size_t region_idx = 0;
size_t region_data_idx = 0;
......@@ -213,48 +138,49 @@ std::future<int32_t> PsLocalClient::Load(const LoadSaveContext& load_context) {
return done();
}
::std::future<int32_t> PsLocalClient::push_dense_param(const Region* regions,
::std::future<int32_t> PsLocalClient::PushDenseParam(const Region* regions,
size_t region_num,
size_t table_id) {
auto* accessor = table_accessor(table_id);
auto* table_ptr = table(table_id);
auto* accessor = GetTableAccessor(table_id);
auto* table_ptr = GetTable(table_id);
std::vector<float> region_buffer;
region_buffer.resize(dense_dim_per_shard(accessor->GetTableInfo(FEA_DIM), 1),
0);
region_buffer.resize(DenseDimPerShard(accessor->GetTableInfo(FEA_DIM), 1), 0);
for (size_t i = 0, offset = 0; i < region_num; ++i) {
uint32_t data_num = regions[i].size / sizeof(float);
memcpy(region_buffer.data() + offset, regions[i].data, regions[i].size);
offset += data_num;
}
// table_ptr->push_dense_param(region_buffer.data(), region_buffer.size());
// table_ptr->PushDenseParam(region_buffer.data(), region_buffer.size());
return done();
}
::std::future<int32_t> PsLocalClient::push_dense_raw_gradient(
::std::future<int32_t> PsLocalClient::PushDenseRawGradient(
int table_id, float* total_send_data, size_t total_send_data_size,
void* callback) {
VLOG(1) << "wxx push_dense_raw_gradient";
PSClientClosure* closure = reinterpret_cast<PSClientClosure*>(callback);
auto* table_ptr = table(table_id);
auto* table_ptr = GetTable(table_id);
table_ptr->push_dense(total_send_data, total_send_data_size);
table_ptr->PushDense(total_send_data, total_send_data_size);
delete closure;
return done();
}
::std::future<int32_t> PsLocalClient::push_dense(const Region* regions,
::std::future<int32_t> PsLocalClient::PushDense(const Region* regions,
size_t region_num,
size_t table_id) {
auto* accessor = table_accessor(table_id);
auto* table_ptr = table(table_id);
auto* accessor = GetTableAccessor(table_id);
auto* table_ptr = GetTable(table_id);
std::vector<float> region_buffer;
region_buffer.resize(dense_dim_per_shard(accessor->GetTableInfo(FEA_DIM), 1));
region_buffer.resize(DenseDimPerShard(accessor->GetTableInfo(FEA_DIM), 1));
size_t data_size = region_buffer.size();
for (size_t i = 0, offset = 0; i < region_num; ++i) {
uint32_t data_num = regions[i].size / sizeof(float);
......@@ -267,12 +193,12 @@ std::future<int32_t> PsLocalClient::Load(const LoadSaveContext& load_context) {
offset += data_num;
}
table_ptr->push_dense(region_buffer.data(), region_buffer.size());
table_ptr->PushDense(region_buffer.data(), region_buffer.size());
return done();
}
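For context, Region is used here as a raw {data, size-in-bytes} view over a float buffer. A minimal sketch of feeding PushDense with a single region (buffer contents and dimensions hypothetical):

    std::vector<float> grad(param_dim, 0.0f);   // param_dim assumed known
    Region region;
    region.data = reinterpret_cast<char*>(grad.data());
    region.size = grad.size() * sizeof(float);  // size is in bytes
    client->PushDense(&region, /*region_num=*/1, table_id).wait();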
//::std::future<int32_t> PsLocalClient::pull_sparse(float** select_values,
//::std::future<int32_t> PsLocalClient::PullSparse(float** select_values,
// size_t table_id,
// const uint64_t* keys,
// size_t num) {
......@@ -282,14 +208,14 @@ std::future<int32_t> PsLocalClient::Load(const LoadSaveContext& load_context) {
// // auto local_timer =
// // std::make_shared<CostTimer>("pslib_downpour_client_pull_sparse_local");
// // Split keys into per-shard requests and record the original value pointers
// auto* accessor = table_accessor(table_id);
// auto* table_ptr = table(table_id);
// auto* accessor = GetTableAccessor(table_id);
// auto* table_ptr = GetTable(table_id);
// size_t value_size = accessor->select_size();
//
// // table_ptr->pull_sparse(keys, num);
// // table_ptr->PullSparse(keys, num);
// std::vector<float> res_data;
// res_data.resize(num * value_size / sizeof(float));
// table_ptr->pull_sparse(res_data.data(), keys, num);
// table_ptr->PullSparse(res_data.data(), keys, num);
// // memcpy(select_values[0], res_data->data(), res_data->size() *
// // sizeof(float));
// size_t offset = 0;
......@@ -302,7 +228,7 @@ std::future<int32_t> PsLocalClient::Load(const LoadSaveContext& load_context) {
// return done();
//}
::std::future<int32_t> PsLocalClient::pull_sparse_ptr(char** select_values,
::std::future<int32_t> PsLocalClient::PullSparsePtr(char** select_values,
size_t table_id,
const uint64_t* keys,
size_t num) {
......@@ -312,33 +238,33 @@ std::future<int32_t> PsLocalClient::Load(const LoadSaveContext& load_context) {
// auto local_timer =
// std::make_shared<CostTimer>("pslib_downpour_client_pull_sparse_local");
// Split keys into per-shard requests and record the original value pointers
auto* table_ptr = table(table_id);
auto* table_ptr = GetTable(table_id);
table_ptr->pull_sparse_ptr(select_values, keys, num);
table_ptr->PullSparsePtr(select_values, keys, num);
return done();
}
::std::future<int32_t> PsLocalClient::push_sparse_raw_gradient(
::std::future<int32_t> PsLocalClient::PushSparseRawGradient(
size_t table_id, const uint64_t* keys, const float** update_values,
size_t num, void* callback) {
PSClientClosure* closure = reinterpret_cast<PSClientClosure*>(callback);
auto* accessor = table_accessor(table_id);
auto* table_ptr = table(table_id);
auto* accessor = GetTableAccessor(table_id);
auto* table_ptr = GetTable(table_id);
table_ptr->push_sparse(keys, update_values, num);
table_ptr->PushSparse(keys, update_values, num);
delete closure;
return done();
}
::std::future<int32_t> PsLocalClient::push_sparse(size_t table_id,
::std::future<int32_t> PsLocalClient::PushSparse(size_t table_id,
const uint64_t* keys,
const float** update_values,
size_t num) {
auto* accessor = table_accessor(table_id);
auto* table_ptr = table(table_id);
auto* accessor = GetTableAccessor(table_id);
auto* table_ptr = GetTable(table_id);
table_ptr->push_sparse(keys, update_values, num);
table_ptr->PushSparse(keys, update_values, num);
return done();
}
}
......
......@@ -26,51 +26,43 @@ class PsLocalClient : public PSClient {
public:
PsLocalClient() {}
virtual ~PsLocalClient() { _running = false; }
virtual int32_t create_client2client_connection(int pslib_timeout_ms,
virtual int32_t CreateClient2ClientConnection(int pslib_timeout_ms,
int pslib_connect_timeout_ms,
int max_retry) {
return 0;
}
virtual ::std::future<int32_t> shrink(uint32_t table_id,
virtual ::std::future<int32_t> Shrink(uint32_t table_id,
const std::string threshold) override;
virtual ::std::future<int32_t> load(const std::string& epoch,
virtual ::std::future<int32_t> Load(const std::string& epoch,
const std::string& mode) override;
virtual ::std::future<int32_t> load(uint32_t table_id,
virtual ::std::future<int32_t> Load(uint32_t table_id,
const std::string& epoch,
const std::string& mode) override;
virtual std::future<int32_t> Load(
const LoadSaveContext& load_context) override;
virtual ::std::future<int32_t> save(const std::string& epoch,
virtual ::std::future<int32_t> Save(const std::string& epoch,
const std::string& mode) override;
virtual ::std::future<int32_t> save(uint32_t table_id,
virtual ::std::future<int32_t> Save(uint32_t table_id,
const std::string& epoch,
const std::string& mode) override;
virtual std::future<int32_t> Save(
const LoadSaveContext& save_context) override;
virtual ::std::future<int32_t> clear() override;
virtual ::std::future<int32_t> clear(uint32_t table_id) override;
virtual ::std::future<int32_t> Clear() override;
virtual ::std::future<int32_t> Clear(uint32_t table_id) override;
virtual ::std::future<int32_t> stop_server() override;
virtual ::std::future<int32_t> StopServer() override;
virtual void finalize_worker() override {}
virtual ::std::future<int32_t> pull_dense(Region* regions, size_t region_num,
virtual void FinalizeWorker() override {}
virtual ::std::future<int32_t> PullDense(Region* regions, size_t region_num,
size_t table_id);
virtual ::std::future<int32_t> Pull(RequestContext& pull_context) override;
virtual ::std::future<int32_t> Push(RequestContext& push_context) override;
virtual ::std::future<int32_t> push_dense(const Region* regions,
virtual ::std::future<int32_t> PushDense(const Region* regions,
size_t region_num, size_t table_id);
virtual ::std::future<int32_t> push_dense_param(const Region* regions,
virtual ::std::future<int32_t> PushDenseParam(const Region* regions,
size_t region_num,
size_t table_id);
virtual ::std::future<int32_t> pull_sparse(float** select_values,
virtual ::std::future<int32_t> PullSparse(float** select_values,
size_t table_id,
const uint64_t* keys, size_t num,
bool is_training) {
......@@ -81,26 +73,26 @@ class PsLocalClient : public PSClient {
return fut;
}
virtual ::std::future<int32_t> pull_sparse_ptr(char** select_values,
virtual ::std::future<int32_t> PullSparsePtr(char** select_values,
size_t table_id,
const uint64_t* keys,
size_t num);
virtual ::std::future<int32_t> print_table_stat(uint32_t table_id) {
virtual ::std::future<int32_t> PrintTableStat(uint32_t table_id) {
std::promise<int32_t> prom;
std::future<int32_t> fut = prom.get_future();
prom.set_value(0);
return fut;
}
virtual ::std::future<int32_t> push_sparse(size_t table_id,
virtual ::std::future<int32_t> PushSparse(size_t table_id,
const uint64_t* keys,
const float** update_values,
size_t num);
virtual ::std::future<int32_t> flush();
virtual ::std::future<int32_t> Flush();
// server profiler
virtual std::future<int32_t> start_profiler() {
virtual std::future<int32_t> StartProfiler() {
std::promise<int32_t> prom;
std::future<int32_t> fut = prom.get_future();
prom.set_value(0);
......@@ -108,7 +100,7 @@ class PsLocalClient : public PSClient {
return fut;
};
virtual std::future<int32_t> stop_profiler() {
virtual std::future<int32_t> StopProfiler() {
std::promise<int32_t> prom;
std::future<int32_t> fut = prom.get_future();
prom.set_value(0);
......@@ -116,7 +108,7 @@ class PsLocalClient : public PSClient {
return fut;
}
virtual std::future<int32_t> barrier(size_t table_id, uint32_t barrier_type) {
virtual std::future<int32_t> Barrier(size_t table_id, uint32_t barrier_type) {
std::promise<int32_t> prom;
std::future<int32_t> fut = prom.get_future();
prom.set_value(0);
......@@ -124,7 +116,7 @@ class PsLocalClient : public PSClient {
return fut;
}
virtual std::future<int32_t> pull_geo_param(size_t table_id,
virtual std::future<int32_t> PullGeoParam(size_t table_id,
std::vector<float>* values,
std::vector<uint64_t>* keys,
int pserver_idx) {
......@@ -135,7 +127,7 @@ class PsLocalClient : public PSClient {
return fut;
}
virtual std::future<int32_t> push_global_step(int table_id,
virtual std::future<int32_t> PushGlobalStep(int table_id,
int64_t* total_send_data,
void* done) {
std::promise<int32_t> prom;
......@@ -146,12 +138,12 @@ class PsLocalClient : public PSClient {
}
// recv table from server and save it in a LoDTensor
virtual int32_t recv_and_save_table(const uint64_t table_id,
virtual int32_t RecvAndSaveTable(const uint64_t table_id,
const std::string& path) {
return 0;
}
virtual ::std::future<int32_t> send_client2client_msg(
virtual ::std::future<int32_t> SendClient2ClientMsg(
int msg_type, int to_client_id, const std::string& msg) override {
std::promise<int32_t> prom;
std::future<int32_t> fut = prom.get_future();
......@@ -159,17 +151,18 @@ class PsLocalClient : public PSClient {
return fut;
}
virtual size_t get_server_nums() { return 1; }
virtual size_t GetServerNums() { return 1; }
virtual std::future<int32_t> push_dense_raw_gradient(
int table_id, float* total_send_data, size_t total_send_data_size,
virtual std::future<int32_t> PushDenseRawGradient(int table_id,
float* total_send_data,
size_t total_send_data_size,
void* callback) override;
virtual std::future<int32_t> push_sparse_raw_gradient(
virtual std::future<int32_t> PushSparseRawGradient(
size_t table_id, const uint64_t* keys, const float** update_values,
size_t num, void* callback) override;
virtual std::future<int32_t> push_sparse_raw_gradient_partial(
virtual std::future<int32_t> PushSparseRawGradientPartial(
size_t table_id, const uint64_t* keys, const float** update_values,
uint32_t num, void* done, int pserver_idx) override {
std::promise<int32_t> prom;
......@@ -179,7 +172,7 @@ class PsLocalClient : public PSClient {
return fut;
}
virtual std::future<int32_t> push_sparse_param(size_t table_id,
virtual std::future<int32_t> PushSparseParam(size_t table_id,
const uint64_t* keys,
const float** update_values,
size_t num,
......@@ -192,7 +185,7 @@ class PsLocalClient : public PSClient {
}
private:
virtual int32_t initialize() override;
virtual int32_t Initialize() override;
std::future<int32_t> done() {
std::shared_ptr<std::promise<int32_t>> prom =
......@@ -202,16 +195,16 @@ class PsLocalClient : public PSClient {
return fut;
}
inline uint32_t dense_dim_per_shard(uint32_t dense_dim_total,
inline uint32_t DenseDimPerShard(uint32_t dense_dim_total,
uint32_t shard_num) {
return dense_dim_total / shard_num + 1;
}
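A quick worked example of the helper above: the trailing +1 rounds the per-shard slice up, so the shard buffers always cover the full dimension at the cost of a little padding on the last shard:

    // dense_dim_total = 10, shard_num = 3:
    //   10 / 3 + 1 = 4 floats per shard; 3 shards * 4 = 12 >= 10.
    uint32_t per_shard = DenseDimPerShard(10, 3);  // == 4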
inline std::unordered_map<uint32_t, std::shared_ptr<Table>>* table() {
inline std::unordered_map<uint32_t, std::shared_ptr<Table>>* GetTable() {
return &_table_map;
}
inline Table* table(size_t table_id) {
inline Table* GetTable(size_t table_id) {
auto itr = _table_map.find(table_id);
if (itr != _table_map.end()) {
return itr->second.get();
......
......@@ -25,17 +25,17 @@ class PsLocalServer : public PSServer {
public:
PsLocalServer() {}
virtual ~PsLocalServer() {}
virtual uint64_t start() { return 0; }
virtual uint64_t start(const std::string &ip, uint32_t port) { return 0; }
virtual int32_t stop() { return 0; }
virtual int32_t configure(
virtual uint64_t Start() { return 0; }
virtual uint64_t Start(const std::string &ip, uint32_t port) { return 0; }
virtual int32_t Stop() { return 0; }
virtual int32_t Configure(
const PSParameter &config, PSEnvironment &env, size_t server_rank,
const std::vector<framework::ProgramDesc> &server_sub_program = {}) {
return 0;
}
private:
virtual int32_t initialize() { return 0; }
virtual int32_t Initialize() { return 0; }
};
}
}
......@@ -70,7 +70,7 @@ void GraphPyService::set_up(std::string ips_str, int shard_num,
port_list.push_back(ip_and_port[1]);
uint32_t port = stoul(ip_and_port[1]);
auto ph_host = paddle::distributed::PSHost(ip_and_port[0], port, index);
host_sign_list.push_back(ph_host.serialize_to_string());
host_sign_list.push_back(ph_host.SerializeToString());
index++;
}
}
......@@ -83,11 +83,11 @@ void GraphPyClient::start_client() {
paddle::distributed::PaddlePSEnvironment _ps_env;
auto servers_ = host_sign_list.size();
_ps_env = paddle::distributed::PaddlePSEnvironment();
_ps_env.set_ps_servers(&host_sign_list, servers_);
_ps_env.SetPsServers(&host_sign_list, servers_);
worker_ptr = std::shared_ptr<paddle::distributed::GraphBrpcClient>(
(paddle::distributed::GraphBrpcClient*)
paddle::distributed::PSClientFactory::create(worker_proto));
worker_ptr->configure(worker_proto, dense_regions, _ps_env, client_id);
paddle::distributed::PSClientFactory::Create(worker_proto));
worker_ptr->Configure(worker_proto, dense_regions, _ps_env, client_id);
worker_ptr->set_shard_num(get_shard_num());
}
void GraphPyServer::start_server(bool block) {
......@@ -96,17 +96,17 @@ void GraphPyServer::start_server(bool block) {
::paddle::distributed::PSParameter server_proto = this->GetServerProto();
auto _ps_env = paddle::distributed::PaddlePSEnvironment();
_ps_env.set_ps_servers(&this->host_sign_list,
_ps_env.SetPsServers(&this->host_sign_list,
this->host_sign_list.size()); // test
pserver_ptr = std::shared_ptr<paddle::distributed::GraphBrpcServer>(
(paddle::distributed::GraphBrpcServer*)
paddle::distributed::PSServerFactory::create(server_proto));
paddle::distributed::PSServerFactory::Create(server_proto));
VLOG(0) << "pserver-ptr created ";
std::vector<framework::ProgramDesc> empty_vec;
framework::ProgramDesc empty_prog;
empty_vec.push_back(empty_prog);
pserver_ptr->configure(server_proto, _ps_env, rank, empty_vec);
pserver_ptr->start(ip, port);
pserver_ptr->Configure(server_proto, _ps_env, rank, empty_vec);
pserver_ptr->Start(ip, port);
pserver_ptr->build_peer2peer_connection(rank);
std::condition_variable* cv_ = pserver_ptr->export_cv();
if (block) {
......@@ -246,7 +246,7 @@ void GraphPyClient::load_edge_file(std::string name, std::string filepath,
VLOG(0) << "loadding data with type " << name << " from " << filepath;
uint32_t table_id = this->table_id_map[name];
auto status =
get_ps_client()->load(table_id, std::string(filepath), params);
get_ps_client()->Load(table_id, std::string(filepath), params);
status.wait();
}
}
......@@ -285,7 +285,7 @@ void GraphPyClient::load_node_file(std::string name, std::string filepath) {
if (this->table_id_map.count(name)) {
uint32_t table_id = this->table_id_map[name];
auto status =
get_ps_client()->load(table_id, std::string(filepath), params);
get_ps_client()->Load(table_id, std::string(filepath), params);
status.wait();
}
}
......@@ -396,13 +396,13 @@ std::vector<FeatureNode> GraphPyClient::pull_graph_list(std::string name,
return res;
}
void GraphPyClient::stop_server() {
void GraphPyClient::StopServer() {
VLOG(0) << "going to stop server";
std::unique_lock<std::mutex> lock(mutex_);
if (stoped_) return;
auto status = this->worker_ptr->stop_server();
auto status = this->worker_ptr->StopServer();
if (status.get() == 0) stoped_ = true;
}
void GraphPyClient::finalize_worker() { this->worker_ptr->finalize_worker(); }
void GraphPyClient::FinalizeWorker() { this->worker_ptr->FinalizeWorker(); }
}
}
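A short usage note on the renamed shutdown pair; the order assumed here is to stop the remote servers before finalizing the local worker:

    GraphPyClient client;
    // ... set_up / start_client / training elided ...
    client.StopServer();      // idempotent: guarded by stoped_ above
    client.FinalizeWorker();  // releases the underlying worker_ptr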
......@@ -123,7 +123,7 @@ class GraphPyServer : public GraphPyService {
set_rank(rank);
GraphPyService::set_up(ips_str, shard_num, node_types, edge_types);
}
int get_rank() { return rank; }
int GetRank() { return rank; }
void set_rank(int rank) { this->rank = rank; }
void start_server(bool block = true);
......@@ -154,8 +154,8 @@ class GraphPyClient : public GraphPyService {
(paddle::distributed::GraphBrpcService*)server.get_ps_server()
->get_service());
}
void stop_server();
void finalize_worker();
void StopServer();
void FinalizeWorker();
void load_edge_file(std::string name, std::string filepath, bool reverse);
void load_node_file(std::string name, std::string filepath);
void clear_nodes(std::string name);
......
......@@ -46,7 +46,7 @@ paddle::distributed::PSParameter load_from_prototxt(
return param;
}
void PSCore::init_gflag(const std::string& gflags) {
void PSCore::InitGFlag(const std::string& gflags) {
VLOG(3) << "Init With Gflags:" << gflags;
std::vector<std::string> flags = paddle::string::split_string(gflags);
if (flags.size() < 1) {
......@@ -65,67 +65,67 @@ void PSCore::init_gflag(const std::string& gflags) {
::GFLAGS_NAMESPACE::ParseCommandLineFlags(&params_cnt, &params_ptr, true);
}
int PSCore::init_server(
int PSCore::InitServer(
const std::string& dist_desc,
const std::vector<std::string>* host_sign_list, int node_num, int index,
int trainers,
const std::vector<framework::ProgramDesc>& server_sub_program) {
google::protobuf::TextFormat::ParseFromString(dist_desc, &_ps_param);
init_gflag(_ps_param.init_gflags());
InitGFlag(_ps_param.init_gflags());
_ps_env = paddle::distributed::PaddlePSEnvironment();
_ps_env.set_ps_servers(host_sign_list, node_num);
_ps_env.set_trainers(trainers);
_ps_env.SetPsServers(host_sign_list, node_num);
_ps_env.SetTrainers(trainers);
int ret = 0;
_server_ptr = std::shared_ptr<paddle::distributed::PSServer>(
paddle::distributed::PSServerFactory::create(_ps_param));
ret = _server_ptr->configure(_ps_param, _ps_env, index, server_sub_program);
paddle::distributed::PSServerFactory::Create(_ps_param));
ret = _server_ptr->Configure(_ps_param, _ps_env, index, server_sub_program);
CHECK(ret == 0) << "failed to configure server";
return ret;
}
int PSCore::init_worker(
int PSCore::InitWorker(
const std::string& dist_desc,
const std::map<uint64_t, std::vector<paddle::distributed::Region>>& regions,
const std::vector<std::string>* host_sign_list, int node_num, int index) {
google::protobuf::TextFormat::ParseFromString(dist_desc, &_ps_param);
init_gflag(_ps_param.init_gflags());
InitGFlag(_ps_param.init_gflags());
_ps_env = paddle::distributed::PaddlePSEnvironment();
_ps_env.set_ps_servers(host_sign_list, node_num);
_ps_env.SetPsServers(host_sign_list, node_num);
int ret = 0;
VLOG(1) << "PSCore::init_worker";
VLOG(1) << "PSCore::InitWorker";
auto* communicator = Communicator::GetInstance();
ret = communicator->GetPsClient()->configure(_ps_param, regions, _ps_env,
ret = communicator->GetPsClient()->Configure(_ps_param, regions, _ps_env,
index);
communicator->Start();
return ret;
}
std::vector<uint64_t> PSCore::get_client_info() {
return _ps_env.get_client_info();
std::vector<uint64_t> PSCore::GetClientInfo() {
return _ps_env.GetClientInfo();
}
int PSCore::create_client2client_connection(int pserver_timeout_ms,
int PSCore::CreateClient2ClientConnection(int pserver_timeout_ms,
int pserver_connect_timeout_ms,
int max_retry) {
int ret = _worker_ptr->create_client2client_connection(
int ret = _worker_ptr->CreateClient2ClientConnection(
pserver_timeout_ms, pserver_connect_timeout_ms, max_retry);
return ret;
}
uint64_t PSCore::run_server(const std::string& ip, uint32_t port) {
return _server_ptr->start(ip, port);
uint64_t PSCore::RunServer(const std::string& ip, uint32_t port) {
return _server_ptr->Start(ip, port);
}
int PSCore::finalize_worker() {
_worker_ptr->finalize_worker();
int PSCore::FinalizeWorker() {
_worker_ptr->FinalizeWorker();
return 0;
}
int PSCore::stop_server() {
auto stop_status = _worker_ptr->stop_server();
int PSCore::StopServer() {
auto stop_status = _worker_ptr->StopServer();
stop_status.wait();
return 0;
}
paddle::distributed::PSParameter* PSCore::get_param() { return &_ps_param; }
paddle::distributed::PSParameter* PSCore::GetParam() { return &_ps_param; }
} // namespace distributed
} // namespace paddle
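Putting the renamed PSCore entry points together, a hedged bring-up sketch for one server process; dist_desc and the host list are placeholders that would normally come from the fleet configuration:

    paddle::distributed::PSCore core;
    std::vector<std::string> host_sign_list = {/* PSHost::SerializeToString() */};
    core.InitServer(dist_desc, &host_sign_list, /*node_num=*/1, /*index=*/0,
                    /*trainers=*/1);
    uint64_t port = core.RunServer("127.0.0.1", 8200);  // starts the brpc server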
......@@ -42,31 +42,31 @@ class PSCore {
explicit PSCore() {}
virtual ~PSCore() {}
virtual int init_server(
virtual int InitServer(
const std::string& dist_desc,
const std::vector<std::string>* host_sign_list, int node_num, int index,
int trainers,
const std::vector<framework::ProgramDesc>& server_sub_program = {});
virtual int init_worker(
virtual int InitWorker(
const std::string& dist_desc,
const std::map<uint64_t, std::vector<paddle::distributed::Region>>&
regions,
const std::vector<std::string>* host_sign_list, int node_num, int index);
virtual uint64_t run_server(const std::string& ip, uint32_t port);
virtual int stop_server();
virtual int finalize_worker();
virtual std::vector<uint64_t> get_client_info();
virtual int create_client2client_connection(int pserver_timeout_ms,
virtual uint64_t RunServer(const std::string& ip, uint32_t port);
virtual int StopServer();
virtual int FinalizeWorker();
virtual std::vector<uint64_t> GetClientInfo();
virtual int CreateClient2ClientConnection(int pserver_timeout_ms,
int pserver_connect_timeout_ms,
int max_retry);
std::shared_ptr<paddle::distributed::PSServer>
_server_ptr; // pointer to server
std::shared_ptr<paddle::distributed::PSClient>
_worker_ptr; // pointer to worker
virtual paddle::distributed::PSParameter* get_param();
virtual paddle::distributed::PSParameter* GetParam();
private:
void init_gflag(const std::string& gflags);
void InitGFlag(const std::string& gflags);
paddle::distributed::PSParameter _ps_param;
paddle::distributed::PaddlePSEnvironment _ps_env;
};
......
......@@ -29,7 +29,7 @@ REGISTER_PSCORE_CLASS(PsBaseService, BrpcPsService);
REGISTER_PSCORE_CLASS(PSServer, GraphBrpcServer);
REGISTER_PSCORE_CLASS(PsBaseService, GraphBrpcService);
PSServer *PSServerFactory::create(const PSParameter &ps_config) {
PSServer *PSServerFactory::Create(const PSParameter &ps_config) {
const auto &config = ps_config.server_param();
if (!config.has_downpour_server_param()) {
......@@ -56,18 +56,18 @@ PSServer *PSServerFactory::create(const PSParameter &ps_config) {
<< service_param.server_class();
return NULL;
}
TableManager::instance().initialize();
TableManager::Instance().Initialize();
return server;
}
int32_t PSServer::configure(
int32_t PSServer::Configure(
const PSParameter &config, PSEnvironment &env, size_t server_rank,
const std::vector<framework::ProgramDesc> &server_sub_program) {
scope_.reset(new framework::Scope());
_config = config.server_param();
_rank = server_rank;
_environment = &env;
size_t shard_num = env.get_ps_servers().size();
size_t shard_num = env.GetPsServers().size();
const auto &downpour_param = _config.downpour_server_param();
......@@ -87,21 +87,21 @@ int32_t PSServer::configure(
global_step_table = downpour_param.downpour_table_param(i).table_id();
}
table->set_program_env(scope_.get(), place_, &server_sub_program);
table->set_shard(_rank, shard_num);
table->initialize(downpour_param.downpour_table_param(i),
table->SetProgramEnv(scope_.get(), place_, &server_sub_program);
table->SetShard(_rank, shard_num);
table->Initialize(downpour_param.downpour_table_param(i),
config.fs_client_param());
_table_map[downpour_param.downpour_table_param(i).table_id()].reset(table);
}
if (barrier_table != UINT32_MAX) {
_table_map[barrier_table]->set_table_map(&_table_map);
_table_map[barrier_table]->SetTableMap(&_table_map);
}
if (global_step_table != UINT32_MAX) {
_table_map[global_step_table]->set_table_map(&_table_map);
_table_map[global_step_table]->SetTableMap(&_table_map);
}
return initialize();
return Initialize();
}
} // namespace distributed
} // namespace paddle
......@@ -65,19 +65,19 @@ class PSServer {
PSServer(PSServer &&) = delete;
PSServer(const PSServer &) = delete;
virtual int32_t configure(
virtual int32_t Configure(
const PSParameter &config, PSEnvironment &env, size_t server_rank,
const std::vector<framework::ProgramDesc> &server_sub_program = {});
virtual uint64_t start(const std::string &ip, uint32_t port) = 0;
virtual int32_t stop() = 0;
virtual uint64_t Start(const std::string &ip, uint32_t port) = 0;
virtual int32_t Stop() = 0;
inline size_t rank() const { return _rank; }
inline size_t Rank() const { return _rank; }
inline PSEnvironment *environment() { return _environment; }
inline PSEnvironment *Environment() { return _environment; }
inline const ServerParameter *config() const { return &_config; }
inline Table *table(size_t table_id) {
inline const ServerParameter *Config() const { return &_config; }
inline Table *GetTable(size_t table_id) {
auto itr = _table_map.find(table_id);
if (itr != _table_map.end()) {
return itr->second.get();
......@@ -85,12 +85,12 @@ class PSServer {
return NULL;
}
inline std::unordered_map<uint32_t, std::shared_ptr<Table>> *table() {
inline std::unordered_map<uint32_t, std::shared_ptr<Table>> *GetTable() {
return &_table_map;
}
protected:
virtual int32_t initialize() = 0;
virtual int32_t Initialize() = 0;
protected:
size_t _rank;
......@@ -129,11 +129,11 @@ class PsBaseService : public PsService {
public:
PsBaseService() : _rank(0), _server(NULL), _config(NULL) {}
virtual ~PsBaseService() {}
virtual size_t get_rank() { return _rank; }
virtual int32_t configure(PSServer *server) {
virtual size_t GetRank() { return _rank; }
virtual int32_t Configure(PSServer *server) {
_server = server;
_rank = _server->rank();
_config = _server->config();
_rank = _server->Rank();
_config = _server->Config();
return 0;
}
virtual void service(::google::protobuf::RpcController *controller,
......@@ -148,8 +148,8 @@ class PsBaseService : public PsService {
LOG(WARNING) << "Resonse err_code:" << err_code << " msg:" << err_msg;
}
virtual int32_t initialize() = 0;
PSServer *get_server() { return _server; }
virtual int32_t Initialize() = 0;
PSServer *GetServer() { return _server; }
protected:
size_t _rank;
......@@ -160,7 +160,7 @@ REGISTER_PSCORE_REGISTERER(PsBaseService);
class PSServerFactory {
public:
static PSServer *create(const PSParameter &config);
static PSServer *Create(const PSParameter &config);
};
} // namespace distributed
} // namespace paddle
......@@ -17,7 +17,7 @@
namespace paddle {
namespace distributed {
int32_t BarrierTable::initialize() {
int32_t BarrierTable::Initialize() {
auto trainers = _config.common().trainer_num();
trigger_.store(trainers);
......@@ -29,7 +29,7 @@ int32_t BarrierTable::initialize() {
}
// 0: send_barrier 1: recv_barrier 2: complete
int32_t BarrierTable::barrier(const uint32_t trainer_id,
int32_t BarrierTable::Barrier(const uint32_t trainer_id,
const std::string barrier_type) {
std::unique_lock<std::mutex> lock(mutex_);
......@@ -56,7 +56,7 @@ int32_t BarrierTable::barrier(const uint32_t trainer_id,
VLOG(1) << "barrier table optimize begin";
for (auto& x : *table_map_) {
auto table = x.second;
table->pour();
table->Pour();
}
VLOG(1) << "barrier table optimize done";
......@@ -66,7 +66,7 @@ int32_t BarrierTable::barrier(const uint32_t trainer_id,
return 0;
}
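As the comment above notes, barrier_type selects the phase. A hedged sketch of the trainer side, going through the PSClient::Barrier interface renamed earlier in this diff (table id hypothetical): the last of trainer_num() arrivals triggers Pour() on every table, then all waiters are released.

    auto status = worker_ptr->Barrier(barrier_table_id, /*barrier_type=*/0);
    status.wait();  // 0: send_barrier, per the comment above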
int32_t BarrierTable::set_table_map(
int32_t BarrierTable::SetTableMap(
std::unordered_map<uint32_t, std::shared_ptr<Table>>* table_map) {
table_map_ = table_map;
return 0;
......
......@@ -21,7 +21,7 @@ namespace distributed {
int FLAGS_pslib_table_save_max_retry_dense = 3;
void CommonDenseTable::create_initializer(const std::string& attr,
void CommonDenseTable::CreateInitializer(const std::string& attr,
const std::string& name) {
auto slices = string::split_string<std::string>(attr, "&");
......@@ -39,7 +39,7 @@ void CommonDenseTable::create_initializer(const std::string& attr,
}
}
int32_t CommonDenseTable::initialize() {
int32_t CommonDenseTable::Initialize() {
_shards_task_pool.resize(task_pool_size_);
for (int i = 0; i < _shards_task_pool.size(); ++i) {
_shards_task_pool[i].reset(new ::ThreadPool(1));
......@@ -49,12 +49,12 @@ int32_t CommonDenseTable::initialize() {
VLOG(1) << "table " << _config.common().table_name() << " is sync: " << sync;
_global_lr = new float(1.0);
initialize_value();
initialize_optimizer();
InitializeValue();
InitializeOptimizer();
return 0;
}
int32_t CommonDenseTable::initialize_value() {
int32_t CommonDenseTable::InitializeValue() {
auto common = _config.common();
int size = static_cast<int>(common.params().size());
values_.resize(size);
......@@ -70,7 +70,7 @@ int32_t CommonDenseTable::initialize_value() {
auto& initializer = common.initializers()[x];
total_dim_ += dim;
create_initializer(initializer, varname);
CreateInitializer(initializer, varname);
values_[x].resize(dim);
names_index_[varname] = x;
......@@ -92,27 +92,27 @@ int32_t CommonDenseTable::initialize_value() {
param_col_ids_.insert(param_col_ids_.begin() + 1, -1);
}
VLOG(1) << "CommonDenseTable::initialize_value total dim: " << total_dim_
VLOG(1) << "CommonDenseTable::InitializeValue total dim: " << total_dim_
<< " fixed_len_params_dim: " << fixed_len_params_dim_;
pull_reservoir_ = ReservoirValue<float>(param_dim_);
return 0;
}
int32_t CommonDenseTable::initialize_optimizer() {
int32_t CommonDenseTable::InitializeOptimizer() {
auto common = _config.common();
auto name = common.name();
auto attrs = common.attributes();
if (name == "sgd") {
optimizer_ = std::make_shared<DSGD>(common, &values_);
optimizer_->set_global_lr(_global_lr);
optimizer_->SetGlobalLR(_global_lr);
} else if (name == "adam") {
optimizer_ = std::make_shared<DAdam>(common, &values_);
optimizer_->set_global_lr(_global_lr);
optimizer_->SetGlobalLR(_global_lr);
} else if (name == "adam_d2sum") {
optimizer_ = std::make_shared<DAdamD2Sum>(common, &values_);
// optimizer_->set_global_lr(_global_lr); //no use
// optimizer_->SetGlobalLR(_global_lr); //no use
} else if (name == "sum") {
optimizer_ = std::make_shared<DSUM>(common, &values_);
} else if (name == "summary") {
......@@ -124,34 +124,34 @@ int32_t CommonDenseTable::initialize_optimizer() {
return 0;
}
int32_t CommonDenseTable::set_global_lr(float* lr) {
int32_t CommonDenseTable::SetGlobalLR(float* lr) {
_global_lr = lr;
optimizer_->set_global_lr(_global_lr);
optimizer_->SetGlobalLR(_global_lr);
return 0;
}
int32_t CommonDenseTable::Pull(TableContext& context) {
CHECK(context.value_type == Dense);
float* pull_values = context.pull_context.values;
return pull_dense(pull_values, context.num);
return PullDense(pull_values, context.num);
}
int32_t CommonDenseTable::Push(TableContext& context) {
CHECK(context.value_type == Dense);
if (context.push_context.values != nullptr) {
const float* values = context.push_context.values;
return push_dense(values, context.num);
return PushDense(values, context.num);
}
return 0;
}
int32_t CommonDenseTable::pull_dense(float* pull_values, size_t num) {
int32_t CommonDenseTable::PullDense(float* pull_values, size_t num) {
std::copy(values_[param_idx_].begin(), values_[param_idx_].end(),
pull_values);
return 0;
}
int32_t CommonDenseTable::push_dense_param(const float* values, size_t num) {
int32_t CommonDenseTable::PushDenseParam(const float* values, size_t num) {
PADDLE_ENFORCE_GE(
num, param_dim_,
paddle::platform::errors::InvalidArgument(
......@@ -160,14 +160,14 @@ int32_t CommonDenseTable::push_dense_param(const float* values, size_t num) {
return 0;
}
int32_t CommonDenseTable::pour() {
int32_t CommonDenseTable::Pour() {
pull_reservoir_.avg();
_push_dense(pull_reservoir_.values.data(), pull_reservoir_.values.size());
_PushDense(pull_reservoir_.values.data(), pull_reservoir_.values.size());
pull_reservoir_.reset();
return 0;
}
int32_t CommonDenseTable::push_dense(const float* values, size_t num) {
int32_t CommonDenseTable::PushDense(const float* values, size_t num) {
if (sync) {
std::future<int> task =
_shards_task_pool[0]->enqueue([this, &values]() -> int {
......@@ -176,12 +176,12 @@ int32_t CommonDenseTable::push_dense(const float* values, size_t num) {
});
task.wait();
} else {
_push_dense(values, num);
_PushDense(values, num);
}
return 0;
}
int32_t CommonDenseTable::_push_dense(const float* values, size_t num) {
int32_t CommonDenseTable::_PushDense(const float* values, size_t num) {
PADDLE_ENFORCE_GE(
num, param_dim_,
paddle::platform::errors::InvalidArgument(
......@@ -195,7 +195,7 @@ int32_t CommonDenseTable::_push_dense(const float* values, size_t num) {
[this, shard_id, &buckets, &values]() -> int {
auto begin = buckets[shard_id];
auto end = buckets[shard_id + 1];
optimizer_->update(values, param_dim_, begin, end);
optimizer_->Update(values, param_dim_, begin, end);
return 0;
});
}
......@@ -207,12 +207,12 @@ int32_t CommonDenseTable::_push_dense(const float* values, size_t num) {
return 0;
}
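The enqueue above splits [0, param_dim_) into task_pool_size_ contiguous buckets, one per shard thread, and each shard runs optimizer_->Update on its own slice. A standalone sketch of that partitioning, assuming the usual ceil-divided even split:

    #include <algorithm>
    #include <vector>
    // E.g. dim = 10, shards = 3 -> {0, 4, 8, 10}:
    // shard 0 owns [0,4), shard 1 [4,8), shard 2 [8,10).
    std::vector<int> BucketBounds(int dim, int shards) {
      std::vector<int> b(shards + 1, 0);
      int per = dim / shards + (dim % shards != 0 ? 1 : 0);  // ceil division
      for (int i = 1; i <= shards; ++i) b[i] = std::min(dim, b[i - 1] + per);
      return b;
    }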
int32_t CommonDenseTable::load(const std::string& path,
int32_t CommonDenseTable::Load(const std::string& path,
const std::string& param) {
if (param_dim_ <= 0) {
return 0;
}
std::string table_path = table_dir(path);
std::string table_path = TableDir(path);
auto file_list = _afs_client.list(table_path);
std::sort(file_list.begin(), file_list.end());
for (auto ff : file_list) {
......@@ -314,7 +314,7 @@ int32_t CommonDenseTable::load(const std::string& path,
return 0;
}
int32_t CommonDenseTable::save(const std::string& path,
int32_t CommonDenseTable::Save(const std::string& path,
const std::string& param) {
int save_param = atoi(param.c_str());
uint32_t feasign_size;
......@@ -323,10 +323,10 @@ int32_t CommonDenseTable::save(const std::string& path,
FsChannelConfig channel_config;
if (_config.compress_in_save()) {
channel_config.path = paddle::string::format_string(
"%s/part-%03d.gz", table_dir(path).c_str(), _shard_idx);
"%s/part-%03d.gz", TableDir(path).c_str(), _shard_idx);
} else {
channel_config.path = paddle::string::format_string(
"%s/part-%03d", table_dir(path).c_str(), _shard_idx);
"%s/part-%03d", TableDir(path).c_str(), _shard_idx);
}
_afs_client.remove(channel_config.path);
channel_config.converter = _value_accesor->Converter(save_param).converter;
......
......@@ -34,29 +34,29 @@ class CommonDenseTable : public DenseTable {
public:
CommonDenseTable() {}
virtual ~CommonDenseTable() {}
int32_t initialize() override;
int32_t initialize_shard() override { return 0; }
virtual void create_initializer(const std::string& attr,
int32_t Initialize() override;
int32_t InitializeShard() override { return 0; }
virtual void CreateInitializer(const std::string& attr,
const std::string& name);
virtual int32_t initialize_value();
virtual int32_t initialize_optimizer();
virtual int32_t InitializeValue();
virtual int32_t InitializeOptimizer();
virtual int32_t Pull(TableContext& context);
virtual int32_t Push(TableContext& context);
int32_t pull_dense(float* pull_values, size_t num) override;
int32_t push_dense_param(const float* values, size_t num) override;
int32_t push_dense(const float* values, size_t num) override;
int32_t pour() override;
int32_t set_global_lr(float* lr) override;
int32_t PullDense(float* pull_values, size_t num) override;
int32_t PushDenseParam(const float* values, size_t num) override;
int32_t PushDense(const float* values, size_t num) override;
int32_t Pour() override;
int32_t SetGlobalLR(float* lr) override;
int32_t load(const std::string& path, const std::string& param) override;
int32_t save(const std::string& path, const std::string& param) override;
int32_t Load(const std::string& path, const std::string& param) override;
int32_t Save(const std::string& path, const std::string& param) override;
int32_t flush() override { return 0; }
int32_t shrink(const std::string& param) override { return 0; }
void clear() override { return; }
int32_t Flush() override { return 0; }
int32_t Shrink(const std::string& param) override { return 0; }
void Clear() override { return; }
protected:
int32_t _push_dense(const float* values, size_t num);
int32_t _PushDense(const float* values, size_t num);
private:
const int task_pool_size_ = 10;
......
......@@ -448,7 +448,7 @@ int32_t GraphTable::load_graph_split_config(const std::string &path) {
return 0;
}
int32_t GraphTable::load(const std::string &path, const std::string &param) {
int32_t GraphTable::Load(const std::string &path, const std::string &param) {
bool load_edge = (param[0] == 'e');
bool load_node = (param[0] == 'n');
if (load_edge) {
......@@ -1066,11 +1066,11 @@ int32_t GraphTable::pull_graph_list(int start, int total_size,
int32_t GraphTable::get_server_index_by_id(int64_t id) {
return id % shard_num / shard_num_per_server;
}
int32_t GraphTable::initialize(const TableParameter &config,
int32_t GraphTable::Initialize(const TableParameter &config,
const FsClientParameter &fs_config) {
LOG(INFO) << "in graphTable initialize";
_config = config;
if (initialize_accessor() != 0) {
if (InitializeAccessor() != 0) {
LOG(WARNING) << "Table accessor initialize failed";
return -1;
}
......@@ -1082,9 +1082,9 @@ int32_t GraphTable::initialize(const TableParameter &config,
auto graph = config.graph_parameter();
shard_num = _config.shard_num();
LOG(INFO) << "in graphTable initialize over";
return initialize(graph);
return Initialize(graph);
}
int32_t GraphTable::initialize(const GraphParameter &graph) {
int32_t GraphTable::Initialize(const GraphParameter &graph) {
#ifdef PADDLE_WITH_HETERPS
if (graph.gpups_mode()) {
gpups_mode = true;
......
......@@ -280,7 +280,7 @@ class ScaledLRU {
}
}
auto status =
thread_pool->enqueue([this]() -> int { return shrink(); });
thread_pool->enqueue([this]() -> int { return Shrink(); });
status.wait();
}
});
......@@ -298,7 +298,7 @@ class ScaledLRU {
LRUResponse insert(size_t index, K *keys, V *data, size_t length) {
return lru_pool[index].insert(keys, data, length);
}
int shrink() {
int Shrink() {
int node_size = 0;
for (size_t i = 0; i < lru_pool.size(); i++) {
node_size += lru_pool[i].node_size - lru_pool[i].remove_count;
......@@ -329,7 +329,7 @@ class ScaledLRU {
if (diff != 0) {
__sync_fetch_and_add(&global_count, diff);
if (global_count > int(1.25 * size_limit)) {
thread_pool->enqueue([this]() -> int { return shrink(); });
thread_pool->enqueue([this]() -> int { return Shrink(); });
}
}
}
......@@ -430,11 +430,11 @@ class GraphTable : public SparseTable {
virtual int32_t get_nodes_ids_by_ranges(
std::vector<std::pair<int, int>> ranges, std::vector<int64_t> &res);
virtual int32_t initialize() { return 0; }
virtual int32_t initialize(const TableParameter &config,
virtual int32_t Initialize() { return 0; }
virtual int32_t Initialize(const TableParameter &config,
const FsClientParameter &fs_config);
virtual int32_t initialize(const GraphParameter &config);
int32_t load(const std::string &path, const std::string &param);
virtual int32_t Initialize(const GraphParameter &config);
int32_t Load(const std::string &path, const std::string &param);
int32_t load_graph_split_config(const std::string &path);
int32_t load_edges(const std::string &path, bool reverse);
......@@ -452,26 +452,25 @@ class GraphTable : public SparseTable {
virtual int32_t Pull(TableContext &context) { return 0; }
virtual int32_t Push(TableContext &context) { return 0; }
virtual int32_t pull_sparse(float *values,
const PullSparseValue &pull_value) {
virtual int32_t PullSparse(float *values, const PullSparseValue &pull_value) {
return 0;
}
virtual int32_t push_sparse(const uint64_t *keys, const float *values,
virtual int32_t PushSparse(const uint64_t *keys, const float *values,
size_t num) {
return 0;
}
virtual int32_t clear_nodes();
virtual void clear() {}
virtual int32_t flush() { return 0; }
virtual int32_t shrink(const std::string &param) { return 0; }
virtual void Clear() {}
virtual int32_t Flush() { return 0; }
virtual int32_t Shrink(const std::string &param) { return 0; }
// specify the save path
virtual int32_t save(const std::string &path, const std::string &converter) {
virtual int32_t Save(const std::string &path, const std::string &converter) {
return 0;
}
virtual int32_t initialize_shard() { return 0; }
virtual int32_t set_shard(size_t shard_idx, size_t server_num) {
virtual int32_t InitializeShard() { return 0; }
virtual int32_t SetShard(size_t shard_idx, size_t server_num) {
_shard_idx = shard_idx;
/*
_shard_num is not used in graph_table; the following operation is for the
......
......@@ -167,7 +167,7 @@ int64_t CommonSparseTable::LoadFromText(
return 0;
}
int32_t CommonSparseTable::initialize() {
int32_t CommonSparseTable::Initialize() {
_shards_task_pool.resize(task_pool_size_);
for (int i = 0; i < _shards_task_pool.size(); ++i) {
_shards_task_pool[i].reset(new ::ThreadPool(1));
......@@ -200,15 +200,15 @@ int32_t CommonSparseTable::initialize() {
offset += dim;
}
initialize_value();
initialize_optimizer();
initialize_recorder();
InitializeValue();
InitializeOptimizer();
InitializeRecorder();
return 0;
}
int32_t CommonSparseTable::initialize_recorder() { return 0; }
int32_t CommonSparseTable::InitializeRecorder() { return 0; }
int32_t CommonSparseTable::initialize_value() {
int32_t CommonSparseTable::InitializeValue() {
auto common = _config.common();
shard_values_.reserve(task_pool_size_);
......@@ -223,18 +223,18 @@ int32_t CommonSparseTable::initialize_value() {
return 0;
}
int32_t CommonSparseTable::initialize_optimizer() {
int32_t CommonSparseTable::InitializeOptimizer() {
auto common = _config.common();
auto name = common.name();
if (name == "sgd") {
optimizer_ = std::make_shared<SSGD>(value_names_, value_dims_,
value_offsets_, value_idx_);
optimizer_->set_global_lr(_global_lr);
optimizer_->SetGlobalLR(_global_lr);
} else if (name == "adam") {
optimizer_ = std::make_shared<SAdam>(value_names_, value_dims_,
value_offsets_, value_idx_);
optimizer_->set_global_lr(_global_lr);
optimizer_->SetGlobalLR(_global_lr);
} else if (name == "sum") {
optimizer_ = std::make_shared<SSUM>(value_names_, value_dims_,
value_offsets_, value_idx_);
......@@ -246,13 +246,13 @@ int32_t CommonSparseTable::initialize_optimizer() {
return 0;
}
int32_t CommonSparseTable::set_global_lr(float* lr) {
int32_t CommonSparseTable::SetGlobalLR(float* lr) {
_global_lr = lr;
optimizer_->set_global_lr(_global_lr);
optimizer_->SetGlobalLR(_global_lr);
return 0;
}
int32_t CommonSparseTable::load(const std::string& dirname,
int32_t CommonSparseTable::Load(const std::string& dirname,
const std::string& param) {
auto begin = GetCurrentUS();
rwlock_->WRLock();
......@@ -276,7 +276,7 @@ int32_t CommonSparseTable::load(const std::string& dirname,
return 0;
}
int32_t CommonSparseTable::save(const std::string& dirname,
int32_t CommonSparseTable::Save(const std::string& dirname,
const std::string& param) {
auto begin = GetCurrentUS();
rwlock_->WRLock();
......@@ -322,7 +322,7 @@ int32_t CommonSparseTable::save(const std::string& dirname,
return 0;
}
std::pair<int64_t, int64_t> CommonSparseTable::print_table_stat() {
std::pair<int64_t, int64_t> CommonSparseTable::PrintTableStat() {
int64_t feasign_size = 0;
int64_t mf_size = 0;
......@@ -335,7 +335,7 @@ std::pair<int64_t, int64_t> CommonSparseTable::print_table_stat() {
return {feasign_size, mf_size};
}
int32_t CommonSparseTable::pour() {
int32_t CommonSparseTable::Pour() {
std::vector<float> values;
std::vector<uint64_t> keys;
......@@ -349,7 +349,7 @@ int32_t CommonSparseTable::pour() {
std::copy(reservoir.values.begin(), reservoir.values.end(),
std::back_inserter(values));
}
_push_sparse(keys.data(), values.data(), pull_reservoir_.size());
_PushSparse(keys.data(), values.data(), pull_reservoir_.size());
pull_reservoir_.clear();
return 0;
......@@ -360,11 +360,11 @@ int32_t CommonSparseTable::Pull(TableContext& context) {
if (context.use_ptr) {
char** pull_values = context.pull_context.ptr_values;
const uint64_t* keys = context.pull_context.keys;
return pull_sparse_ptr(pull_values, keys, context.num);
return PullSparsePtr(pull_values, keys, context.num);
} else {
float* pull_values = context.pull_context.values;
const PullSparseValue& pull_value = context.pull_context.pull_value;
return pull_sparse(pull_values, pull_value);
return PullSparse(pull_values, pull_value);
}
}
......@@ -373,15 +373,15 @@ int32_t CommonSparseTable::Push(TableContext& context) {
if (context.push_context.values != nullptr) {
const float* values = context.push_context.values;
const uint64_t* keys = context.push_context.keys;
return push_sparse(keys, values, context.num);
return PushSparse(keys, values, context.num);
} else {
const float** values = context.push_context.ptr_values;
const uint64_t* keys = context.push_context.keys;
return push_sparse(keys, values, context.num);
return PushSparse(keys, values, context.num);
}
}
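Tying the two dispatchers together, a hedged sketch of a pointer-style sparse pull through TableContext; only the fields this code reads are set, and the Sparse enum value is assumed as the counterpart of Dense:

    TableContext ctx;
    ctx.value_type = Sparse;               // sparse table
    ctx.use_ptr = true;                    // route Pull to PullSparsePtr
    ctx.num = key_count;
    ctx.pull_context.keys = keys;          // const uint64_t*
    ctx.pull_context.ptr_values = values;  // char** output slots
    table->Pull(ctx);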
int32_t CommonSparseTable::pull_sparse(float* pull_values,
int32_t CommonSparseTable::PullSparse(float* pull_values,
const PullSparseValue& pull_value) {
auto shard_num = task_pool_size_;
std::vector<std::future<int>> tasks(shard_num);
......@@ -421,7 +421,7 @@ int32_t CommonSparseTable::pull_sparse(float* pull_values,
return 0;
}
int32_t CommonSparseTable::pull_sparse_ptr(char** pull_values,
int32_t CommonSparseTable::PullSparsePtr(char** pull_values,
const uint64_t* keys, size_t num) {
std::vector<std::vector<uint64_t>> offset_bucket;
offset_bucket.resize(task_pool_size_);
......@@ -458,7 +458,7 @@ int32_t CommonSparseTable::pull_sparse_ptr(char** pull_values,
return 0;
}
int32_t CommonSparseTable::_push_sparse(const uint64_t* keys,
int32_t CommonSparseTable::_PushSparse(const uint64_t* keys,
const float* values, size_t num) {
std::vector<std::vector<uint64_t>> offset_bucket;
offset_bucket.resize(task_pool_size_);
......@@ -474,7 +474,7 @@ int32_t CommonSparseTable::_push_sparse(const uint64_t* keys,
tasks[shard_id] = _shards_task_pool[shard_id]->enqueue(
[this, shard_id, &keys, &values, num, &offset_bucket]() -> int {
auto& offsets = offset_bucket[shard_id];
optimizer_->update(keys, values, num, offsets,
optimizer_->Update(keys, values, num, offsets,
shard_values_[shard_id].get());
return 0;
});
......@@ -486,8 +486,8 @@ int32_t CommonSparseTable::_push_sparse(const uint64_t* keys,
return 0;
}
int32_t CommonSparseTable::push_sparse(const uint64_t* keys,
const float* values, size_t num) {
int32_t CommonSparseTable::PushSparse(const uint64_t* keys, const float* values,
size_t num) {
if (sync) {
std::future<int> task =
_shards_task_pool[0]->enqueue([this, &keys, &values, num]() -> int {
......@@ -506,19 +506,19 @@ int32_t CommonSparseTable::push_sparse(const uint64_t* keys,
});
task.wait();
} else {
_push_sparse(keys, values, num);
_PushSparse(keys, values, num);
}
return 0;
}
int32_t CommonSparseTable::push_sparse(const uint64_t* keys,
int32_t CommonSparseTable::PushSparse(const uint64_t* keys,
const float** values, size_t num) {
_push_sparse(keys, values, num);
_PushSparse(keys, values, num);
return 0;
}
int32_t CommonSparseTable::_push_sparse(const uint64_t* keys,
int32_t CommonSparseTable::_PushSparse(const uint64_t* keys,
const float** values, size_t num) {
std::vector<std::vector<uint64_t>> offset_bucket;
offset_bucket.resize(task_pool_size_);
......@@ -536,7 +536,7 @@ int32_t CommonSparseTable::_push_sparse(const uint64_t* keys,
auto& offsets = offset_bucket[shard_id];
for (size_t i = 0; i < offsets.size(); ++i) {
std::vector<uint64_t> tmp_off = {0};
optimizer_->update(keys + offsets[i], values[offsets[i]], num,
optimizer_->Update(keys + offsets[i], values[offsets[i]], num,
tmp_off, shard_values_[shard_id].get());
}
return 0;
......@@ -549,7 +549,7 @@ int32_t CommonSparseTable::_push_sparse(const uint64_t* keys,
return 0;
}
int32_t CommonSparseTable::push_sparse_param(const uint64_t* keys,
int32_t CommonSparseTable::PushSparseParam(const uint64_t* keys,
const float* values, size_t num) {
std::vector<std::vector<uint64_t>> offset_bucket;
offset_bucket.resize(task_pool_size_);
......@@ -585,21 +585,21 @@ int32_t CommonSparseTable::push_sparse_param(const uint64_t* keys,
return 0;
}
int32_t CommonSparseTable::flush() { return 0; }
int32_t CommonSparseTable::Flush() { return 0; }
int32_t CommonSparseTable::shrink(const std::string& param) {
int32_t CommonSparseTable::Shrink(const std::string& param) {
int threshold = std::stoi(param);
VLOG(3) << "sparse table shrink: " << threshold;
VLOG(3) << "sparse table Shrink: " << threshold;
for (int shard_id = 0; shard_id < task_pool_size_; ++shard_id) {
// shrink
VLOG(4) << shard_id << " " << task_pool_size_ << " begin shrink";
// Shrink
VLOG(4) << shard_id << " " << task_pool_size_ << " begin Shrink";
shard_values_[shard_id]->Shrink(threshold);
}
return 0;
}
void CommonSparseTable::clear() { VLOG(0) << "clear coming soon"; }
void CommonSparseTable::Clear() { VLOG(0) << "clear coming soon"; }
} // namespace distributed
} // namespace paddle
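The _PushSparse and PushSparseParam hunks above share one pattern: each key is bucketed into a shard by key % shard_num, and every shard then applies its own bucket on a single-threaded task, so the updates stay race-free without locks. A minimal, self-contained sketch of that pattern, with std::async standing in for the real ::ThreadPool (names here are illustrative, not Paddle's):

#include <cstddef>
#include <cstdint>
#include <future>
#include <iostream>
#include <vector>

int main() {
  const size_t shard_num = 4;
  const uint64_t keys[] = {11, 42, 7, 1024, 99};
  const size_t num = sizeof(keys) / sizeof(keys[0]);

  // Bucket key offsets per shard, like offset_bucket in the hunks above.
  std::vector<std::vector<uint64_t>> offset_bucket(shard_num);
  for (size_t i = 0; i < num; ++i) {
    offset_bucket[keys[i] % shard_num].push_back(i);
  }

  // One task per shard; each task touches only its own bucket.
  std::vector<std::future<int>> tasks;
  for (size_t shard_id = 0; shard_id < shard_num; ++shard_id) {
    tasks.push_back(std::async(std::launch::async, [&, shard_id]() -> int {
      for (auto offset : offset_bucket[shard_id]) {
        // optimizer_->Update(...) runs here in the real table.
        std::cout << "shard " << shard_id << " <- key " << keys[offset] << "\n";
      }
      return 0;
    }));
  }
  for (auto& t : tasks) t.wait();
  return 0;
}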
......@@ -114,25 +114,23 @@ class CommonSparseTable : public SparseTable {
virtual ~CommonSparseTable() {}
// unused method begin
virtual int32_t pull_dense(float* pull_values, size_t num) { return 0; }
virtual int32_t push_dense_param(const float* values, size_t num) {
return 0;
}
virtual int32_t push_dense(const float* values, size_t num) { return 0; }
virtual int32_t PullDense(float* pull_values, size_t num) { return 0; }
virtual int32_t PushDenseParam(const float* values, size_t num) { return 0; }
virtual int32_t PushDense(const float* values, size_t num) { return 0; }
// unused method end
virtual int32_t Pull(TableContext& context);
virtual int32_t Push(TableContext& context);
virtual int32_t initialize();
virtual int32_t initialize_shard() { return 0; }
virtual int32_t initialize_value();
virtual int32_t initialize_optimizer();
virtual int32_t initialize_recorder();
virtual int32_t Initialize();
virtual int32_t InitializeShard() { return 0; }
virtual int32_t InitializeValue();
virtual int32_t InitializeOptimizer();
virtual int32_t InitializeRecorder();
virtual int32_t load(const std::string& path, const std::string& param);
virtual int32_t Load(const std::string& path, const std::string& param);
virtual int32_t save(const std::string& path, const std::string& param);
virtual int32_t Save(const std::string& path, const std::string& param);
void SaveMetaToText(std::ostream* os, const CommonAccessorParameter& common,
const size_t shard_idx, const int64_t total);
......@@ -150,33 +148,33 @@ class CommonSparseTable : public SparseTable {
const int pserver_id, const int pserver_num, const int local_shard_num,
std::vector<std::shared_ptr<ValueBlock>>* blocks);
virtual std::pair<int64_t, int64_t> print_table_stat();
virtual int32_t pull_sparse(float* values, const PullSparseValue& pull_value);
virtual std::pair<int64_t, int64_t> PrintTableStat();
virtual int32_t PullSparse(float* values, const PullSparseValue& pull_value);
virtual int32_t pull_sparse_ptr(char** pull_values, const uint64_t* keys,
virtual int32_t PullSparsePtr(char** pull_values, const uint64_t* keys,
size_t num);
virtual int32_t push_sparse(const uint64_t* keys, const float* values,
virtual int32_t PushSparse(const uint64_t* keys, const float* values,
size_t num);
virtual int32_t push_sparse(const uint64_t* keys, const float** values,
virtual int32_t PushSparse(const uint64_t* keys, const float** values,
size_t num);
// only for sparse geo table
virtual int32_t push_sparse_param(const uint64_t* keys, const float* values,
virtual int32_t PushSparseParam(const uint64_t* keys, const float* values,
size_t num);
virtual int32_t set_global_lr(float* lr) override;
virtual int32_t SetGlobalLR(float* lr) override;
virtual int32_t pour();
virtual int32_t flush();
virtual int32_t shrink(const std::string& param);
virtual void clear();
virtual int32_t Pour();
virtual int32_t Flush();
virtual int32_t Shrink(const std::string& param);
virtual void Clear();
protected:
virtual int32_t _push_sparse(const uint64_t* keys, const float* values,
virtual int32_t _PushSparse(const uint64_t* keys, const float* values,
size_t num);
virtual int32_t _push_sparse(const uint64_t* keys, const float** values,
virtual int32_t _PushSparse(const uint64_t* keys, const float** values,
size_t num);
protected:
......
......@@ -71,11 +71,11 @@ class SparseTable : public Table {
SparseTable() {}
virtual ~SparseTable() {}
virtual void *get_shard(size_t shard_idx) { return 0; }
virtual void *GetShard(size_t shard_idx) { return 0; }
int32_t pull_dense(float *values, size_t num) override { return 0; }
int32_t PullDense(float *values, size_t num) override { return 0; }
int32_t push_dense(const float *values, size_t num) override { return 0; }
int32_t PushDense(const float *values, size_t num) override { return 0; }
static int32_t sparse_local_shard_num(uint32_t shard_num,
uint32_t server_num) {
......@@ -97,19 +97,17 @@ class DenseTable : public Table {
DenseTable() {}
virtual ~DenseTable() {}
virtual void *get_shard(size_t shard_idx) { return 0; }
int32_t pull_sparse(float *values,
virtual void *GetShard(size_t shard_idx) { return 0; }
int32_t PullSparse(float *values,
const PullSparseValue &pull_value) override {
return 0;
}
int32_t push_sparse(const uint64_t *keys, const float *values,
int32_t PushSparse(const uint64_t *keys, const float *values,
size_t num) override {
return 0;
}
int32_t push_dense_param(const float *values, size_t num) override {
return 0;
}
int32_t shrink(const std::string &param) override { return 0; }
int32_t PushDenseParam(const float *values, size_t num) override { return 0; }
int32_t Shrink(const std::string &param) override { return 0; }
};
class BarrierTable : public Table {
......@@ -117,44 +115,42 @@ class BarrierTable : public Table {
BarrierTable() {}
virtual ~BarrierTable() {}
virtual void *get_shard(size_t shard_idx) { return 0; }
virtual void *GetShard(size_t shard_idx) { return 0; }
virtual int32_t Pull(TableContext &context) { return 0; }
virtual int32_t Push(TableContext &context) { return 0; }
int32_t pull_dense(float *values, size_t num) override { return 0; }
int32_t PullDense(float *values, size_t num) override { return 0; }
int32_t push_dense(const float *values, size_t num) override { return 0; }
int32_t PushDense(const float *values, size_t num) override { return 0; }
int32_t pull_sparse(float *values,
int32_t PullSparse(float *values,
const PullSparseValue &pull_value) override {
return 0;
}
int32_t push_sparse(const uint64_t *keys, const float *values,
int32_t PushSparse(const uint64_t *keys, const float *values,
size_t num) override {
return 0;
}
int32_t push_dense_param(const float *values, size_t num) override {
return 0;
}
int32_t shrink(const std::string &param) override { return 0; }
virtual void clear() {}
virtual int32_t flush() { return 0; }
virtual int32_t load(const std::string &path, const std::string &param) {
int32_t PushDenseParam(const float *values, size_t num) override { return 0; }
int32_t Shrink(const std::string &param) override { return 0; }
virtual void Clear() {}
virtual int32_t Flush() { return 0; }
virtual int32_t Load(const std::string &path, const std::string &param) {
return 0;
}
virtual int32_t save(const std::string &path, const std::string &param) {
virtual int32_t Save(const std::string &path, const std::string &param) {
return 0;
}
virtual int32_t initialize_shard() { return 0; }
virtual int32_t InitializeShard() { return 0; }
virtual int32_t initialize() override;
virtual int32_t Initialize() override;
// only for barrier
// 0: send_barrier 1: recv_barrier 2: complete
virtual int32_t barrier(const uint32_t trainer_id,
virtual int32_t Barrier(const uint32_t trainer_id,
const std::string barrier_type) override;
virtual int32_t set_table_map(
virtual int32_t SetTableMap(
std::unordered_map<uint32_t, std::shared_ptr<Table>> *table_map) override;
private:
......
......@@ -34,9 +34,9 @@ class DenseOptimizer {
DenseOptimizer() {}
explicit DenseOptimizer(const CommonAccessorParameter& accessor,
std::vector<std::vector<float>>* values) {}
virtual void update(const float* update_values, size_t num, int begin,
virtual void Update(const float* update_values, size_t num, int begin,
int end) = 0;
virtual void set_global_lr(float* lr) { global_learning_rate_ = lr; }
virtual void SetGlobalLR(float* lr) { global_learning_rate_ = lr; }
protected:
float* global_learning_rate_;
......@@ -55,7 +55,7 @@ class DSUM : public DenseOptimizer {
}
}
void update(const float* update_values, size_t num, int begin,
void Update(const float* update_values, size_t num, int begin,
int end) override {
auto update_numel = end - begin;
GetBlas<float>().VADD(update_numel, update_values + begin, param + begin,
......@@ -81,7 +81,7 @@ class DSGD : public DenseOptimizer {
}
}
void update(const float* update_values, size_t num, int begin,
void Update(const float* update_values, size_t num, int begin,
int end) override {
auto update_numel = end - begin;
std::vector<float> grads;
......@@ -134,7 +134,7 @@ class DAdam : public DenseOptimizer {
// make sure common_dense_table.task_pool_size_ == 1;
// otherwise beta1_pow/beta2_pow would be multiplied task_pool_size_ times
void update(const float* update_values, size_t num, int begin,
void Update(const float* update_values, size_t num, int begin,
int end) override {
auto update_numel = end - begin;
std::vector<float> grad, grad2, tmp;
......@@ -214,7 +214,7 @@ class DAdamD2Sum : public DenseOptimizer {
}
}
void update(const float* update_values, size_t num, int begin,
void Update(const float* update_values, size_t num, int begin,
int end) override {
auto update_numel = end - begin;
Eigen::Map<Eigen::MatrixXf> mat_ada_g2sum(ada_g2sum + begin, 1,
......@@ -276,7 +276,7 @@ class DSummary : public DenseOptimizer {
}
}
void update(const float* update_values, size_t num, int begin,
void Update(const float* update_values, size_t num, int begin,
int end) override {
auto update_numel = end - begin;
Eigen::Map<Eigen::MatrixXf> mat_w(param + begin, 1, update_numel);
......
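Every dense rule above (DSUM, DSGD, DAdam, DAdamD2Sum, DSummary) implements the renamed Update(update_values, num, begin, end) contract: an optimizer mutates only the [begin, end) slice of the dense block it owns. A hedged sketch of the simplest rule, DSUM, using a plain loop where the real code calls a BLAS VADD (DSumSketch is an assumed name):

#include <cstddef>
#include <iostream>
#include <vector>

struct DSumSketch {
  std::vector<float> param;  // stands in for the table's dense block
  explicit DSumSketch(size_t n) : param(n, 0.f) {}

  // Same shape as the renamed contract: Update(values, num, begin, end).
  void Update(const float* update_values, size_t /*num*/, int begin, int end) {
    for (int i = begin; i < end; ++i) {
      param[i] += update_values[i];  // the real DSUM uses a BLAS VADD here
    }
  }
};

int main() {
  DSumSketch opt(4);
  const float grads[] = {0.5f, 1.0f, 1.5f, 2.0f};
  opt.Update(grads, 4, 1, 3);                       // only indices 1 and 2 move
  for (float v : opt.param) std::cout << v << " ";  // prints: 0 1 1.5 0
  std::cout << "\n";
  return 0;
}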
......@@ -40,11 +40,11 @@ class SparseOptimizer {
value_offsets_(value_offsets),
value_idx_(value_idx) {}
virtual void update(const uint64_t* keys, const float* update_values,
virtual void Update(const uint64_t* keys, const float* update_values,
size_t num, const std::vector<uint64_t>& offsets,
ValueBlock* block) = 0;
virtual void set_global_lr(float* lr) { global_learning_rate_ = lr; }
virtual void SetGlobalLR(float* lr) { global_learning_rate_ = lr; }
const std::vector<std::string>& value_names_;
const std::vector<int>& value_dims_;
......@@ -70,7 +70,7 @@ class SSUM : public SparseOptimizer {
update_numel = value_dims.at(idx);
}
void update(const uint64_t* keys, const float* update_values, size_t num,
void Update(const uint64_t* keys, const float* update_values, size_t num,
const std::vector<uint64_t>& offsets,
ValueBlock* block) override {
auto blas = GetBlas<float>();
......@@ -100,7 +100,7 @@ class SSGD : public SparseOptimizer {
lr_offset = value_offsets.at(idx);
}
void update(const uint64_t* keys, const float* update_values, size_t num,
void Update(const uint64_t* keys, const float* update_values, size_t num,
const std::vector<uint64_t>& offsets,
ValueBlock* block) override {
auto blas = GetBlas<float>();
......@@ -156,7 +156,7 @@ class SAdam : public SparseOptimizer {
epsilon = 1.0e-8;
}
void update(const uint64_t* keys, const float* update_values, size_t num,
void Update(const uint64_t* keys, const float* update_values, size_t num,
const std::vector<uint64_t>& offsets,
ValueBlock* block) override {
auto blas = GetBlas<float>();
......
......@@ -17,11 +17,10 @@
namespace paddle {
namespace distributed {
int32_t MemorySparseGeoTable::push_sparse_param(const uint64_t* keys,
const float* values,
size_t num) {
VLOG(5) << "DEBUG MemorySparseGeoTable::push_sparse_param begin "
"push_sparse_param "
int32_t MemorySparseGeoTable::PushSparseParam(const uint64_t* keys,
const float* values, size_t num) {
VLOG(5) << "DEBUG MemorySparseGeoTable::PushSparseParam begin "
"PushSparseParam "
<< num;
auto shard_num = _task_pool_size;
std::vector<std::vector<uint64_t>> offset_bucket;
......@@ -31,8 +30,8 @@ int32_t MemorySparseGeoTable::push_sparse_param(const uint64_t* keys,
auto y = keys[x] % shard_num;
offset_bucket[y].push_back(x);
if (x < 10) {
VLOG(5) << "DEBUG MemorySparseGeoTable::push_sparse_param key: "
<< keys[x] << " shard: " << y;
VLOG(5) << "DEBUG MemorySparseGeoTable::PushSparseParam key: " << keys[x]
<< " shard: " << y;
}
}
......@@ -51,8 +50,8 @@ int32_t MemorySparseGeoTable::push_sparse_param(const uint64_t* keys,
feature_value.resize(_dim);
std::copy_n(values + _dim * offset, _dim, feature_value.data());
if (i < 10) {
VLOG(5) << "MemorySparseGeoTable::push_sparse_param "
"push_sparse_param key "
VLOG(5) << "MemorySparseGeoTable::PushSparseParam "
"PushSparseParam key "
<< id << " value[0]: " << (values + _dim * offset)[0]
<< " data: " << feature_value.data()[0]
<< " value[-1]: " << (values + _dim * offset)[_dim - 1]
......@@ -69,7 +68,7 @@ int32_t MemorySparseGeoTable::push_sparse_param(const uint64_t* keys,
return 0;
}
int32_t MemorySparseGeoTable::pull_geo_param(const uint32_t trainer_id,
int32_t MemorySparseGeoTable::PullGeoParam(const uint32_t trainer_id,
std::vector<float>* values,
std::vector<uint64_t>* ids) {
_geo_recorder->GetAndClear(trainer_id, ids);
......@@ -86,23 +85,23 @@ int32_t MemorySparseGeoTable::pull_geo_param(const uint32_t trainer_id,
pull_value.frequencies_ = frequencies.data();
values->resize(ids->size() * _dim);
pull_sparse(values->data(), pull_value);
PullSparse(values->data(), pull_value);
return 0;
}
int32_t MemorySparseGeoTable::push_sparse(const uint64_t* keys,
int32_t MemorySparseGeoTable::PushSparse(const uint64_t* keys,
const float* values, size_t num) {
VLOG(5) << "DEBUG MemorySparseGeoTable::push_sparse keys[0]" << keys[0]
VLOG(5) << "DEBUG MemorySparseGeoTable::PushSparse keys[0]" << keys[0]
<< " key_num: " << num;
std::vector<uint64_t> ids;
ids.resize(num);
std::copy_n(keys, num, ids.begin());
_geo_recorder->Update(ids);
_push_sparse(keys, values, num);
_PushSparse(keys, values, num);
return 0;
}
int32_t MemorySparseGeoTable::initialize() {
int32_t MemorySparseGeoTable::Initialize() {
if (!_geo_recorder) {
auto trainers = _config.common().trainer_num();
_geo_recorder = std::make_shared<GeoRecorder>(trainers);
......@@ -118,7 +117,7 @@ int32_t MemorySparseGeoTable::initialize() {
return 0;
}
int32_t MemorySparseGeoTable::pull_sparse(float* pull_values,
int32_t MemorySparseGeoTable::PullSparse(float* pull_values,
const PullSparseValue& pull_value) {
auto shard_num = _task_pool_size;
std::vector<std::future<int>> tasks(shard_num);
......@@ -146,13 +145,13 @@ int32_t MemorySparseGeoTable::pull_sparse(float* pull_values,
auto& feature_value = local_shard[key];
feature_value.resize(_dim);
memset(feature_value.data(), 0, sizeof(float) * _dim);
VLOG(0) << "MemorySparseGeoTable pull_sparse key not found!!! "
VLOG(0) << "MemorySparseGeoTable PullSparse key not found!!! "
<< key;
itr = local_shard.find(key);
}
memcpy(select_data, itr.value().data(), _dim * sizeof(float));
VLOG(5) << "DEBUG MemorySparseGeoTable::pull_sparse key: " << key
VLOG(5) << "DEBUG MemorySparseGeoTable::PullSparse key: " << key
<< " select_data[0] " << select_data[0]
<< " value[0]: " << itr.value().data()[0];
}
......@@ -167,7 +166,7 @@ int32_t MemorySparseGeoTable::pull_sparse(float* pull_values,
return 0;
}
int32_t MemorySparseGeoTable::_push_sparse(const uint64_t* keys,
int32_t MemorySparseGeoTable::_PushSparse(const uint64_t* keys,
const float* values, size_t num) {
auto shard_num = _task_pool_size;
std::vector<std::future<int>> tasks(shard_num);
......
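The geo table above works record-then-drain: PushSparse logs every touched key in _geo_recorder before applying the update, and PullGeoParam later drains one trainer's pending set with GetAndClear, then reads the current values via PullSparse. A minimal sketch of that recorder semantics (GeoRecorderSketch is an assumption, not the Paddle GeoRecorder):

#include <cstdint>
#include <iostream>
#include <set>
#include <vector>

class GeoRecorderSketch {
 public:
  explicit GeoRecorderSketch(int trainer_num) : pending_(trainer_num) {}

  // Called from PushSparse: remember the touched keys for every trainer.
  void Update(const std::vector<uint64_t>& ids) {
    for (auto& s : pending_) s.insert(ids.begin(), ids.end());
  }

  // Called from PullGeoParam: move one trainer's pending keys out.
  void GetAndClear(uint32_t trainer_id, std::vector<uint64_t>* ids) {
    auto& s = pending_[trainer_id];
    ids->assign(s.begin(), s.end());
    s.clear();
  }

 private:
  std::vector<std::set<uint64_t>> pending_;  // one pending-key set per trainer
};

int main() {
  GeoRecorderSketch recorder(2);
  recorder.Update({7, 42});
  std::vector<uint64_t> ids;
  recorder.GetAndClear(0, &ids);    // trainer 0 drains {7, 42}
  std::cout << ids.size() << "\n";  // 2
  recorder.GetAndClear(0, &ids);    // nothing new since the last pull
  std::cout << ids.size() << "\n";  // 0
  return 0;
}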
......@@ -40,31 +40,31 @@ class MemorySparseGeoTable : public SparseTable {
MemorySparseGeoTable() { _geo_recorder = nullptr; }
virtual ~MemorySparseGeoTable() {}
virtual int32_t initialize();
virtual int32_t initialize_shard() { return 0; }
virtual int32_t load(const std::string& path, const std::string& param) {
virtual int32_t Initialize();
virtual int32_t InitializeShard() { return 0; }
virtual int32_t Load(const std::string& path, const std::string& param) {
return 0;
}
virtual int32_t save(const std::string& path, const std::string& param) {
virtual int32_t Save(const std::string& path, const std::string& param) {
return 0;
}
virtual int32_t Pull(TableContext& context) { return 0; }
virtual int32_t Push(TableContext& context) { return 0; }
virtual int32_t flush() { return 0; }
virtual int32_t shrink(const std::string& param) { return 0; }
virtual void clear() { return; }
virtual int32_t pull_sparse(float* values, const PullSparseValue& pull_value);
virtual int32_t Flush() { return 0; }
virtual int32_t Shrink(const std::string& param) { return 0; }
virtual void Clear() { return; }
virtual int32_t PullSparse(float* values, const PullSparseValue& pull_value);
int32_t push_sparse_param(const uint64_t* keys, const float* values,
int32_t PushSparseParam(const uint64_t* keys, const float* values,
size_t num);
// TODO(zhaocaibei123): change to pull_sparse, and rename pull_sparse
int32_t pull_geo_param(const uint32_t trainer_id, std::vector<float>* values,
int32_t PullGeoParam(const uint32_t trainer_id, std::vector<float>* values,
std::vector<uint64_t>* keys);
int32_t push_sparse(const uint64_t* keys, const float* values,
int32_t PushSparse(const uint64_t* keys, const float* values,
size_t num) override;
int32_t _push_sparse(const uint64_t* keys, const float* values, size_t num);
int32_t _PushSparse(const uint64_t* keys, const float* values, size_t num);
// int32_t _pull_sparse(float* pull_values, const PullSparseValue&
// pull_value);
......
......@@ -31,7 +31,7 @@ bool FLAGS_pserver_create_value_when_push = true;
int FLAGS_pserver_table_save_max_retry = 3;
bool FLAGS_pserver_enable_create_feasign_randomly = false;
int32_t MemorySparseTable::initialize() {
int32_t MemorySparseTable::Initialize() {
_shards_task_pool.resize(_task_pool_size);
for (int i = 0; i < _shards_task_pool.size(); ++i) {
_shards_task_pool[i].reset(new ::ThreadPool(1));
......@@ -39,12 +39,12 @@ int32_t MemorySparseTable::initialize() {
auto& profiler = CostProfiler::instance();
profiler.register_profiler("pserver_sparse_update_all");
profiler.register_profiler("pserver_sparse_select_all");
initialize_value();
InitializeValue();
VLOG(0) << "initalize MemorySparseTable succ";
return 0;
}
int32_t MemorySparseTable::initialize_value() {
int32_t MemorySparseTable::InitializeValue() {
_sparse_table_shard_num = static_cast<int>(_config.shard_num());
_avg_local_shard_num =
SparseTable::sparse_local_shard_num(_sparse_table_shard_num, _shard_num);
......@@ -64,14 +64,14 @@ int32_t MemorySparseTable::initialize_value() {
return 0;
}
int32_t MemorySparseTable::load(const std::string& path,
int32_t MemorySparseTable::Load(const std::string& path,
const std::string& param) {
std::string table_path = table_dir(path);
std::string table_path = TableDir(path);
auto file_list = _afs_client.list(table_path);
std::sort(file_list.begin(), file_list.end());
for (auto file : file_list) {
VLOG(1) << "MemorySparseTable::load() file list: " << file;
VLOG(1) << "MemorySparseTable::Load() file list: " << file;
}
int load_param = atoi(param.c_str());
......@@ -154,9 +154,9 @@ int32_t MemorySparseTable::load(const std::string& path,
return 0;
}
int32_t MemorySparseTable::load_local_fs(const std::string& path,
int32_t MemorySparseTable::LoadLocalFS(const std::string& path,
const std::string& param) {
std::string table_path = table_dir(path);
std::string table_path = TableDir(path);
auto file_list = paddle::framework::localfs_list(table_path);
int load_param = atoi(param.c_str());
......@@ -225,12 +225,12 @@ int32_t MemorySparseTable::load_local_fs(const std::string& path,
return 0;
}
int32_t MemorySparseTable::save(const std::string& dirname,
int32_t MemorySparseTable::Save(const std::string& dirname,
const std::string& param) {
VLOG(0) << "MemorySparseTable::save dirname: " << dirname;
int save_param =
atoi(param.c_str()); // checkpoint:0 xbox delta:1 xbox base:2
std::string table_path = table_dir(dirname);
std::string table_path = TableDir(dirname);
_afs_client.remove(paddle::string::format_string(
"%s/part-%03d-*", table_path.c_str(), _shard_idx));
std::atomic<uint32_t> feasign_size_all{0};
......@@ -309,12 +309,12 @@ int32_t MemorySparseTable::save(const std::string& dirname,
return 0;
}
int32_t MemorySparseTable::save_local_fs(const std::string& dirname,
int32_t MemorySparseTable::SaveLocalFS(const std::string& dirname,
const std::string& param,
const std::string& prefix) {
int save_param =
atoi(param.c_str()); // checkpoint:0 xbox delta:1 xbox base:2
std::string table_path = table_dir(dirname);
std::string table_path = TableDir(dirname);
int feasign_cnt = 0;
size_t file_start_idx = _avg_local_shard_num * _shard_idx;
......@@ -349,7 +349,7 @@ int32_t MemorySparseTable::save_local_fs(const std::string& dirname,
return 0;
}
int64_t MemorySparseTable::local_size() {
int64_t MemorySparseTable::LocalSize() {
int64_t local_size = 0;
for (size_t i = 0; i < _real_local_shard_num; ++i) {
local_size += _local_shards[i].size();
......@@ -357,7 +357,7 @@ int64_t MemorySparseTable::local_size() {
return local_size;
}
int64_t MemorySparseTable::local_mf_size() {
int64_t MemorySparseTable::LocalMFSize() {
std::vector<int64_t> size_arr(_real_local_shard_num, 0);
std::vector<std::future<int>> tasks(_real_local_shard_num);
int64_t ret_size = 0;
......@@ -384,9 +384,9 @@ int64_t MemorySparseTable::local_mf_size() {
return ret_size;
}
std::pair<int64_t, int64_t> MemorySparseTable::print_table_stat() {
int64_t feasign_size = local_size();
int64_t mf_size = local_mf_size();
std::pair<int64_t, int64_t> MemorySparseTable::PrintTableStat() {
int64_t feasign_size = LocalSize();
int64_t mf_size = LocalMFSize();
return {feasign_size, mf_size};
}
......@@ -395,11 +395,11 @@ int32_t MemorySparseTable::Pull(TableContext& context) {
if (context.use_ptr) {
char** pull_values = context.pull_context.ptr_values;
const uint64_t* keys = context.pull_context.keys;
return pull_sparse_ptr(pull_values, keys, context.num);
return PullSparsePtr(pull_values, keys, context.num);
} else {
float* pull_values = context.pull_context.values;
const PullSparseValue& pull_value = context.pull_context.pull_value;
return pull_sparse(pull_values, pull_value);
return PullSparse(pull_values, pull_value);
}
}
......@@ -407,10 +407,10 @@ int32_t MemorySparseTable::Push(TableContext& context) {
CHECK(context.value_type == Sparse);
const uint64_t* keys = context.push_context.keys;
return push_sparse(keys, context.push_context.values, context.num);
return PushSparse(keys, context.push_context.values, context.num);
}
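Pull and Push above are thin dispatchers over the new TableContext: a single entry point, with context.use_ptr choosing between the flat-value path (PullSparse) and the pointer path (PullSparsePtr). A self-contained sketch of that dispatch, with assumed stand-in types rather than the Paddle headers:

#include <cstddef>
#include <cstdint>
#include <iostream>

struct TableContextSketch {
  bool use_ptr = false;  // true: pointer values, false: flat float values
  size_t num = 0;
};

struct TableSketch {
  int32_t PullSparse(const TableContextSketch& ctx) {
    std::cout << "flat pull of " << ctx.num << " keys\n";
    return 0;
  }
  int32_t PullSparsePtr(const TableContextSketch& ctx) {
    std::cout << "ptr pull of " << ctx.num << " keys\n";
    return 0;
  }
  // Mirrors MemorySparseTable::Pull: one entry point, two value layouts.
  int32_t Pull(const TableContextSketch& ctx) {
    return ctx.use_ptr ? PullSparsePtr(ctx) : PullSparse(ctx);
  }
};

int main() {
  TableSketch table;
  TableContextSketch ctx;
  ctx.num = 8;
  table.Pull(ctx);    // flat pull of 8 keys
  ctx.use_ptr = true;
  table.Pull(ctx);    // ptr pull of 8 keys
  return 0;
}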
int32_t MemorySparseTable::pull_sparse(float* pull_values,
int32_t MemorySparseTable::PullSparse(float* pull_values,
const PullSparseValue& pull_value) {
CostTimer timer("pserver_sparse_select_all");
std::vector<std::future<int>> tasks(_real_local_shard_num);
......@@ -479,7 +479,7 @@ int32_t MemorySparseTable::pull_sparse(float* pull_values,
return 0;
}
int32_t MemorySparseTable::pull_sparse_ptr(char** pull_values,
int32_t MemorySparseTable::PullSparsePtr(char** pull_values,
const uint64_t* keys, size_t num) {
CostTimer timer("pscore_sparse_select_all");
size_t value_size = _value_accesor->GetTableInfo(SIZE) / sizeof(float);
......@@ -530,8 +530,8 @@ int32_t MemorySparseTable::pull_sparse_ptr(char** pull_values,
return 0;
}
int32_t MemorySparseTable::push_sparse(const uint64_t* keys,
const float* values, size_t num) {
int32_t MemorySparseTable::PushSparse(const uint64_t* keys, const float* values,
size_t num) {
CostTimer timer("pserver_sparse_update_all");
std::vector<std::future<int>> tasks(_real_local_shard_num);
std::vector<std::vector<std::pair<uint64_t, int>>> task_keys(
......@@ -603,13 +603,13 @@ int32_t MemorySparseTable::push_sparse(const uint64_t* keys,
return 0;
}
int32_t MemorySparseTable::push_sparse(const uint64_t* keys,
int32_t MemorySparseTable::PushSparse(const uint64_t* keys,
const float** values, size_t num) {
_push_sparse(keys, values, num);
_PushSparse(keys, values, num);
return 0;
}
int32_t MemorySparseTable::_push_sparse(const uint64_t* keys,
int32_t MemorySparseTable::_PushSparse(const uint64_t* keys,
const float** values, size_t num) {
std::vector<std::future<int>> tasks(_real_local_shard_num);
std::vector<std::vector<std::pair<uint64_t, int>>> task_keys(
......@@ -677,13 +677,13 @@ int32_t MemorySparseTable::_push_sparse(const uint64_t* keys,
return 0;
}
int32_t MemorySparseTable::flush() { return 0; }
int32_t MemorySparseTable::Flush() { return 0; }
int32_t MemorySparseTable::shrink(const std::string& param) {
VLOG(0) << "MemorySparseTable::shrink";
int32_t MemorySparseTable::Shrink(const std::string& param) {
VLOG(0) << "MemorySparseTable::Shrink";
// TODO(zhaocaibei123): implement with multi-thread
for (int shard_id = 0; shard_id < _real_local_shard_num; ++shard_id) {
// shrink
// Shrink
auto& shard = _local_shards[shard_id];
for (auto it = shard.begin(); it != shard.end();) {
if (_value_accesor->Shrink(it.value().data())) {
......@@ -696,7 +696,7 @@ int32_t MemorySparseTable::shrink(const std::string& param) {
return 0;
}
void MemorySparseTable::clear() { VLOG(0) << "clear coming soon"; }
void MemorySparseTable::Clear() { VLOG(0) << "clear coming soon"; }
} // namespace distributed
} // namespace paddle
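InitializeValue above sizes the shards with SparseTable::sparse_local_shard_num: ceiling division spreads shard_num global shards across server_num pservers, so _avg_local_shard_num is the per-server slot count and the last pserver's _real_local_shard_num may be smaller. A worked sketch of that arithmetic (the numbers are illustrative):

#include <algorithm>
#include <cstdint>
#include <iostream>

// Ceiling division, as in SparseTable::sparse_local_shard_num.
uint32_t SparseLocalShardNum(uint32_t shard_num, uint32_t server_num) {
  return (shard_num + server_num - 1) / server_num;
}

int main() {
  // 1000 global shards on 3 pservers: every pserver reserves 334 slots,
  // but pserver 2 really owns only 1000 - 2 * 334 = 332 of them.
  const uint32_t shard_num = 1000, server_num = 3, shard_idx = 2;
  const uint32_t avg = SparseLocalShardNum(shard_num, server_num);
  const uint32_t real = std::min(avg, shard_num - shard_idx * avg);
  std::cout << "avg=" << avg << " real=" << real << "\n";  // avg=334 real=332
  return 0;
}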
......@@ -41,49 +41,47 @@ class MemorySparseTable : public SparseTable {
virtual ~MemorySparseTable() {}
// unused method begin
virtual int32_t pull_dense(float* pull_values, size_t num) { return 0; }
virtual int32_t push_dense_param(const float* values, size_t num) {
return 0;
}
virtual int32_t push_dense(const float* values, size_t num) { return 0; }
virtual int32_t PullDense(float* pull_values, size_t num) { return 0; }
virtual int32_t PushDenseParam(const float* values, size_t num) { return 0; }
virtual int32_t PushDense(const float* values, size_t num) { return 0; }
// unused method end
virtual int32_t Pull(TableContext& context);
virtual int32_t Push(TableContext& context);
virtual int32_t initialize();
virtual int32_t initialize_shard() { return 0; }
virtual int32_t initialize_value();
virtual int32_t Initialize();
virtual int32_t InitializeShard() { return 0; }
virtual int32_t InitializeValue();
virtual int32_t load(const std::string& path, const std::string& param);
virtual int32_t Load(const std::string& path, const std::string& param);
virtual int32_t save(const std::string& path, const std::string& param);
virtual int32_t Save(const std::string& path, const std::string& param);
int32_t load_local_fs(const std::string& path, const std::string& param);
int32_t save_local_fs(const std::string& path, const std::string& param,
int32_t LoadLocalFS(const std::string& path, const std::string& param);
int32_t SaveLocalFS(const std::string& path, const std::string& param,
const std::string& prefix);
int64_t local_size();
int64_t local_mf_size();
int64_t LocalSize();
int64_t LocalMFSize();
virtual std::pair<int64_t, int64_t> print_table_stat();
virtual int32_t pull_sparse(float* values, const PullSparseValue& pull_value);
virtual std::pair<int64_t, int64_t> PrintTableStat();
virtual int32_t PullSparse(float* values, const PullSparseValue& pull_value);
virtual int32_t pull_sparse_ptr(char** pull_values, const uint64_t* keys,
virtual int32_t PullSparsePtr(char** pull_values, const uint64_t* keys,
size_t num);
virtual int32_t push_sparse(const uint64_t* keys, const float* values,
virtual int32_t PushSparse(const uint64_t* keys, const float* values,
size_t num);
virtual int32_t push_sparse(const uint64_t* keys, const float** values,
virtual int32_t PushSparse(const uint64_t* keys, const float** values,
size_t num);
virtual int32_t flush();
virtual int32_t shrink(const std::string& param);
virtual void clear();
virtual int32_t Flush();
virtual int32_t Shrink(const std::string& param);
virtual void Clear();
protected:
virtual int32_t _push_sparse(const uint64_t* keys, const float** values,
virtual int32_t _PushSparse(const uint64_t* keys, const float** values,
size_t num);
protected:
......
......@@ -17,7 +17,7 @@
namespace paddle {
namespace distributed {
int32_t SparseGeoTable::pull_geo_param(const uint32_t trainer_id,
int32_t SparseGeoTable::PullGeoParam(const uint32_t trainer_id,
std::vector<float>* values,
std::vector<uint64_t>* ids) {
geo_recorder->GetAndClear(trainer_id, ids);
......@@ -32,21 +32,21 @@ int32_t SparseGeoTable::pull_geo_param(const uint32_t trainer_id,
pull_value.frequencies_ = frequencies.data();
values->resize(ids->size() * dim);
CommonSparseTable::pull_sparse(values->data(), pull_value);
CommonSparseTable::PullSparse(values->data(), pull_value);
return 0;
}
int32_t SparseGeoTable::push_sparse(const uint64_t* keys, const float* values,
int32_t SparseGeoTable::PushSparse(const uint64_t* keys, const float* values,
size_t num) {
std::vector<uint64_t> ids;
ids.resize(num);
std::copy_n(keys, num, ids.begin());
geo_recorder->Update(ids);
CommonSparseTable::push_sparse(keys, values, num);
CommonSparseTable::PushSparse(keys, values, num);
return 0;
}
int32_t SparseGeoTable::initialize_value() {
int32_t SparseGeoTable::InitializeValue() {
auto common = _config.common();
shard_values_.reserve(task_pool_size_);
......@@ -82,7 +82,7 @@ int32_t SparseGeoTable::initialize_value() {
auto pull_value = PullSparseValue(ids, fres, param_dim_);
std::vector<float> pulls;
pulls.resize(bucket_feasigns * param_dim_);
pull_sparse(pulls.data(), pull_value);
PullSparse(pulls.data(), pull_value);
}
return 0;
}
......
......@@ -44,15 +44,15 @@ class SparseGeoTable : public CommonSparseTable {
explicit SparseGeoTable() : CommonSparseTable() { geo_recorder = nullptr; }
virtual ~SparseGeoTable() {}
virtual int32_t initialize_value();
virtual int32_t InitializeValue();
int32_t pull_geo_param(const uint32_t trainer_id, std::vector<float>* values,
int32_t PullGeoParam(const uint32_t trainer_id, std::vector<float>* values,
std::vector<uint64_t>* keys);
int32_t push_sparse(const uint64_t* keys, const float* values,
int32_t PushSparse(const uint64_t* keys, const float* values,
size_t num) override;
virtual int32_t initialize_recorder() {
virtual int32_t InitializeRecorder() {
if (!geo_recorder) {
auto trainers = _config.common().trainer_num();
geo_recorder = std::make_shared<GeoRecorder>(trainers);
......
......@@ -20,7 +20,7 @@ DEFINE_string(rocksdb_path, "database", "path of sparse table rocksdb file");
namespace paddle {
namespace distributed {
int32_t SSDSparseTable::initialize() {
int32_t SSDSparseTable::Initialize() {
_shards_task_pool.resize(task_pool_size_);
for (int i = 0; i < _shards_task_pool.size(); ++i) {
_shards_task_pool[i].reset(new ::ThreadPool(1));
......@@ -53,9 +53,9 @@ int32_t SSDSparseTable::initialize() {
offset += dim;
}
initialize_value();
initialize_optimizer();
initialize_recorder();
InitializeValue();
InitializeOptimizer();
InitializeRecorder();
_db = paddle::distributed::RocksDBHandler::GetInstance();
_db->initialize(FLAGS_rocksdb_path, task_pool_size_);
return 0;
......@@ -66,17 +66,17 @@ int32_t SSDSparseTable::Pull(TableContext& context) {
if (context.use_ptr) {
char** pull_values = context.pull_context.ptr_values;
const uint64_t* keys = context.pull_context.keys;
return pull_sparse_ptr(pull_values, keys, context.num);
return PullSparsePtr(pull_values, keys, context.num);
} else {
float* pull_values = context.pull_context.values;
const PullSparseValue& pull_value = context.pull_context.pull_value;
return pull_sparse(pull_values, pull_value);
return PullSparse(pull_values, pull_value);
}
}
int32_t SSDSparseTable::Push(TableContext& context) { return 0; }
int32_t SSDSparseTable::pull_sparse(float* pull_values,
int32_t SSDSparseTable::PullSparse(float* pull_values,
const PullSparseValue& pull_value) {
auto shard_num = task_pool_size_;
std::vector<std::future<int>> tasks(shard_num);
......@@ -140,8 +140,8 @@ int32_t SSDSparseTable::pull_sparse(float* pull_values,
return 0;
}
int32_t SSDSparseTable::pull_sparse_ptr(char** pull_values,
const uint64_t* keys, size_t num) {
int32_t SSDSparseTable::PullSparsePtr(char** pull_values, const uint64_t* keys,
size_t num) {
auto shard_num = task_pool_size_;
std::vector<std::future<int>> tasks(shard_num);
......@@ -201,9 +201,9 @@ int32_t SSDSparseTable::pull_sparse_ptr(char** pull_values,
return 0;
}
int32_t SSDSparseTable::shrink(const std::string& param) { return 0; }
int32_t SSDSparseTable::Shrink(const std::string& param) { return 0; }
int32_t SSDSparseTable::update_table() {
int32_t SSDSparseTable::UpdateTable() {
int count = 0;
int value_size = shard_values_[0]->value_length_;
int db_size = 3 + value_size;
......@@ -299,7 +299,7 @@ int64_t SSDSparseTable::SaveValueToText(std::ostream* os,
return save_num;
}
int32_t SSDSparseTable::load(const std::string& path,
int32_t SSDSparseTable::Load(const std::string& path,
const std::string& param) {
rwlock_->WRLock();
VLOG(3) << "ssd sparse table load with " << path << " with meta " << param;
......
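SSDSparseTable::PullSparse above is a two-tier read: check the in-memory shard first and only fall back to RocksDB on a miss, promoting the cold value back into memory. A hedged sketch of that lookup, with a std::map standing in for RocksDB and the promotion step mirroring the pattern rather than the exact serialization:

#include <cstdint>
#include <iostream>
#include <map>
#include <unordered_map>
#include <vector>

using Value = std::vector<float>;

// Memory first, disk second; a miss in memory promotes the disk value.
Value* Lookup(std::unordered_map<uint64_t, Value>& mem,
              std::map<uint64_t, Value>& disk, uint64_t key) {
  auto it = mem.find(key);
  if (it != mem.end()) return &it->second;  // hot path: memory hit
  auto dit = disk.find(key);
  if (dit == disk.end()) return nullptr;    // truly missing
  // Promote the cold value into the memory shard, as the SSD table does
  // after reading the serialized bytes back from RocksDB.
  auto inserted = mem.emplace(key, dit->second).first;
  return &inserted->second;
}

int main() {
  std::unordered_map<uint64_t, Value> mem{{1, {0.1f}}};
  std::map<uint64_t, Value> disk{{2, {0.2f}}};
  std::cout << (Lookup(mem, disk, 1) ? "memory hit" : "miss") << "\n";
  std::cout << (Lookup(mem, disk, 2) ? "disk hit, promoted" : "miss") << "\n";
  std::cout << (Lookup(mem, disk, 3) ? "hit" : "miss") << "\n";
  return 0;
}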
......@@ -23,7 +23,7 @@ class SSDSparseTable : public CommonSparseTable {
SSDSparseTable() {}
virtual ~SSDSparseTable() {}
virtual int32_t initialize() override;
virtual int32_t Initialize() override;
void SaveMetaToText(std::ostream* os, const CommonAccessorParameter& common,
const size_t shard_idx, const int64_t total);
......@@ -37,22 +37,22 @@ class SSDSparseTable : public CommonSparseTable {
const int pserver_id, const int pserver_num, const int local_shard_num,
std::vector<std::shared_ptr<ValueBlock>>* blocks);
virtual int32_t load(const std::string& path, const std::string& param);
virtual int32_t Load(const std::string& path, const std::string& param);
// exchange data
virtual int32_t update_table();
virtual int32_t UpdateTable();
virtual int32_t Pull(TableContext& context);
virtual int32_t Push(TableContext& context);
virtual int32_t pull_sparse(float* values, const PullSparseValue& pull_value);
virtual int32_t PullSparse(float* values, const PullSparseValue& pull_value);
virtual int32_t pull_sparse_ptr(char** pull_values, const uint64_t* keys,
virtual int32_t PullSparsePtr(char** pull_values, const uint64_t* keys,
size_t num);
virtual int32_t flush() override { return 0; }
virtual int32_t shrink(const std::string& param) override;
virtual void clear() override {}
virtual int32_t Flush() override { return 0; }
virtual int32_t Shrink(const std::string& param) override;
virtual void Clear() override {}
private:
RocksDBHandler* _db;
......
......@@ -56,7 +56,7 @@ REGISTER_PSCORE_CLASS(SparseValueSGDRule, SparseAdamSGDRule);
REGISTER_PSCORE_CLASS(SparseValueSGDRule, SparseNaiveSGDRule);
REGISTER_PSCORE_CLASS(SparseValueSGDRule, SparseAdaGradSGDRule);
int32_t TableManager::initialize() {
int32_t TableManager::Initialize() {
static bool initialized = false;
if (initialized) {
return 0;
......@@ -65,10 +65,10 @@ int32_t TableManager::initialize() {
return 0;
}
int32_t Table::initialize(const TableParameter &config,
int32_t Table::Initialize(const TableParameter &config,
const FsClientParameter &fs_config) {
_config = config;
if (initialize_accessor() != 0) {
if (InitializeAccessor() != 0) {
LOG(WARNING) << "Table accessor initialize failed";
return -1;
}
......@@ -77,10 +77,10 @@ int32_t Table::initialize(const TableParameter &config,
LOG(WARNING) << "Table fs_client initialize failed";
// return -1;
}
return initialize();
return Initialize();
}
int32_t Table::initialize_accessor() {
int32_t Table::InitializeAccessor() {
if (!_config.has_accessor() || !_config.accessor().has_accessor_class()) {
LOG(ERROR) << "missing accessor config in table, table_id:"
<< _config.table_id();
......
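Table::Initialize above is a template method: the base class stores the config, wires up the accessor (and, best-effort, the fs client), then hands control to the subclass's protected Initialize(). A minimal sketch of that shape; the subclass hook is named InitializeImpl here only to keep the sketch unambiguous, and the fs-client step is omitted:

#include <cstdint>
#include <iostream>
#include <string>

class TableBaseSketch {
 public:
  int32_t Initialize(const std::string& config) {
    config_ = config;
    if (InitializeAccessor() != 0) return -1;  // accessor must come up first
    return InitializeImpl();                   // then the subclass specifics
  }
  virtual ~TableBaseSketch() = default;

 protected:
  virtual int32_t InitializeAccessor() {
    std::cout << "accessor ready for " << config_ << "\n";
    return 0;
  }
  virtual int32_t InitializeImpl() = 0;
  std::string config_;
};

class MemoryTableSketch : public TableBaseSketch {
 protected:
  int32_t InitializeImpl() override {
    std::cout << "memory shards ready\n";
    return 0;
  }
};

int main() {
  MemoryTableSketch table;
  return table.Initialize("demo_table");
}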
......@@ -60,101 +60,99 @@ class Table {
public:
Table() {}
virtual ~Table() {}
virtual int32_t initialize(const TableParameter &config,
virtual int32_t Initialize(const TableParameter &config,
const FsClientParameter &fs_config);
virtual int32_t Pull(TableContext &context) = 0;
virtual int32_t Push(TableContext &context) = 0;
virtual int32_t pull_dense(float *values, size_t num) = 0;
virtual int32_t push_dense(const float *values, size_t num) = 0;
virtual int32_t PullDense(float *values, size_t num) = 0;
virtual int32_t PushDense(const float *values, size_t num) = 0;
// for push global_step
virtual int32_t push_dense(const int64_t *values, const int32_t trainer_id) {
return 0;
}
virtual int32_t push_dense_param(const float *values, size_t num) {
virtual int32_t PushDense(const int64_t *values, const int32_t trainer_id) {
return 0;
}
virtual int32_t PushDenseParam(const float *values, size_t num) { return 0; }
virtual int32_t pull_sparse_ptr(char **pull_values, const uint64_t *keys,
virtual int32_t PullSparsePtr(char **pull_values, const uint64_t *keys,
size_t num) {
VLOG(0) << "NOT IMPLEMENT";
return 0;
}
virtual int32_t pull_sparse(float *values,
virtual int32_t PullSparse(float *values,
const PullSparseValue &pull_value) = 0;
virtual int32_t push_sparse(const uint64_t *keys, const float *values,
virtual int32_t PushSparse(const uint64_t *keys, const float *values,
size_t num) = 0;
virtual int32_t push_sparse(const uint64_t *keys, const float **values,
virtual int32_t PushSparse(const uint64_t *keys, const float **values,
size_t num) {
return 0;
}
virtual int32_t push_sparse_param(const uint64_t *keys, const float *values,
virtual int32_t PushSparseParam(const uint64_t *keys, const float *values,
size_t num) {
return 0;
}
// only for sparse geo table
virtual int32_t pull_geo_param(const uint32_t trainer_id,
virtual int32_t PullGeoParam(const uint32_t trainer_id,
std::vector<float> *values,
std::vector<uint64_t> *keys) {
return 0;
}
// only for barrier
virtual int32_t barrier(const uint32_t trainer_id,
virtual int32_t Barrier(const uint32_t trainer_id,
const std::string barrier_type) {
return 0;
}
// only for barrier table
virtual int32_t set_table_map(
virtual int32_t SetTableMap(
std::unordered_map<uint32_t, std::shared_ptr<Table>> *table_map) {
return 0;
}
// only for tensor table
virtual int32_t set_program_env(
virtual int32_t SetProgramEnv(
framework::Scope *scope, platform::Place place,
const std::vector<framework::ProgramDesc> *sub_program) {
return 0;
}
virtual int32_t set_global_lr(float *lr) {
virtual int32_t SetGlobalLR(float *lr) {
_global_lr = lr;
return 0;
}
virtual int32_t pour() { return 0; }
virtual int32_t Pour() { return 0; }
virtual void clear() = 0;
virtual int32_t flush() = 0;
virtual int32_t shrink(const std::string &param) = 0;
virtual void Clear() = 0;
virtual int32_t Flush() = 0;
virtual int32_t Shrink(const std::string &param) = 0;
// specify the load path
virtual int32_t load(const std::string &path,
virtual int32_t Load(const std::string &path,
const std::string &converter) = 0;
// specify the save path
virtual int32_t save(const std::string &path,
virtual int32_t Save(const std::string &path,
const std::string &converter) = 0;
virtual int32_t set_shard(size_t shard_idx, size_t shard_num) {
virtual int32_t SetShard(size_t shard_idx, size_t shard_num) {
_shard_idx = shard_idx;
_shard_num = shard_num;
return initialize_shard();
return InitializeShard();
}
inline std::shared_ptr<ValueAccessor> value_accesor() {
inline std::shared_ptr<ValueAccessor> ValueAccesor() {
return _value_accesor;
}
virtual void *get_shard(size_t shard_idx) = 0;
virtual std::pair<int64_t, int64_t> print_table_stat() { return {0, 0}; }
virtual void *GetShard(size_t shard_idx) = 0;
virtual std::pair<int64_t, int64_t> PrintTableStat() { return {0, 0}; }
protected:
virtual int32_t initialize() = 0;
virtual int32_t initialize_accessor();
virtual int32_t initialize_shard() = 0;
virtual std::string table_dir(const std::string &model_dir) {
virtual int32_t Initialize() = 0;
virtual int32_t InitializeAccessor();
virtual int32_t InitializeShard() = 0;
virtual std::string TableDir(const std::string &model_dir) {
return paddle::string::format_string("%s/%03d/", model_dir.c_str(),
_config.table_id());
}
......@@ -171,11 +169,11 @@ REGISTER_PSCORE_REGISTERER(Table);
class TableManager {
public:
static TableManager &instance() {
static TableManager &Instance() {
static TableManager manager;
return manager;
}
int32_t initialize();
int32_t Initialize();
private:
TableManager() {}
......
......@@ -52,42 +52,42 @@ class TensorTable : public Table {
virtual int32_t Pull(TableContext &context) { return 0; }
virtual int32_t Push(TableContext &context) { return 0; }
int32_t pull_dense(float *values, size_t num) override { return 0; }
int32_t PullDense(float *values, size_t num) override { return 0; }
int32_t push_dense(const float *values, size_t num) override { return 0; }
int32_t PushDense(const float *values, size_t num) override { return 0; }
int32_t pull_sparse(float *values,
int32_t PullSparse(float *values,
const PullSparseValue &pull_value) override {
return 0;
}
int32_t push_sparse(const uint64_t *keys, const float *values,
int32_t PushSparse(const uint64_t *keys, const float *values,
size_t num) override {
return 0;
}
int32_t shrink(const std::string &param) override { return 0; }
int32_t Shrink(const std::string &param) override { return 0; }
virtual void *get_shard(size_t shard_idx) { return 0; }
virtual void *GetShard(size_t shard_idx) { return 0; }
virtual int32_t initialize_shard() { return 0; }
virtual int32_t InitializeShard() { return 0; }
virtual int32_t flush() { return 0; }
virtual int32_t Flush() { return 0; }
virtual int32_t load(const std::string &path, const std::string &param) {
virtual int32_t Load(const std::string &path, const std::string &param) {
return 0;
}
virtual int32_t save(const std::string &path, const std::string &param) {
virtual int32_t Save(const std::string &path, const std::string &param) {
return 0;
}
virtual void clear() {}
virtual void Clear() {}
int32_t initialize() override { return 0; }
int32_t Initialize() override { return 0; }
int32_t push_dense(const int64_t *values, const int32_t trainer_id) override {
int32_t PushDense(const int64_t *values, const int32_t trainer_id) override {
return 0;
}
int32_t set_program_env(
int32_t SetProgramEnv(
framework::Scope *scope, platform::Place place,
const std::vector<framework::ProgramDesc> *sub_program) override {
scope_ = scope;
......@@ -111,47 +111,47 @@ class DenseTensorTable : public TensorTable {
DenseTensorTable() {}
virtual ~DenseTensorTable() {}
int32_t pull_sparse(float *values,
int32_t PullSparse(float *values,
const PullSparseValue &pull_value) override {
return 0;
}
int32_t push_sparse(const uint64_t *keys, const float *values,
int32_t PushSparse(const uint64_t *keys, const float *values,
size_t num) override {
return 0;
}
int32_t shrink(const std::string &param) override { return 0; }
int32_t Shrink(const std::string &param) override { return 0; }
virtual void *get_shard(size_t shard_idx) { return 0; }
virtual void *GetShard(size_t shard_idx) { return 0; }
virtual int32_t initialize_shard() { return 0; }
virtual int32_t InitializeShard() { return 0; }
virtual int32_t flush() { return 0; }
virtual int32_t Flush() { return 0; }
virtual void clear() {}
virtual void Clear() {}
// Todo: Support program Load & Save
virtual int32_t load(const std::string &path, const std::string &param) {
virtual int32_t Load(const std::string &path, const std::string &param) {
return 0;
}
virtual int32_t save(const std::string &path, const std::string &param) {
virtual int32_t Save(const std::string &path, const std::string &param) {
return 0;
}
// Todo: Support pull dense
int32_t pull_dense(float *values, size_t num) override { return 0; }
int32_t PullDense(float *values, size_t num) override { return 0; }
/*----------------------------------------------------------------------*/
int32_t initialize() override { return 0; }
int32_t Initialize() override { return 0; }
int32_t push_dense(const float *values, size_t num) override { return 0; }
int32_t PushDense(const float *values, size_t num) override { return 0; }
int32_t push_dense(const int64_t *values, const int32_t trainer_id) {
int32_t PushDense(const int64_t *values, const int32_t trainer_id) {
return 0;
}
protected:
virtual int32_t _run_program(const float *values, size_t num,
virtual int32_t _RunProgram(const float *values, size_t num,
const uint32_t trainer_id) {
return 0;
}
......@@ -167,36 +167,36 @@ class GlobalStepTable : public DenseTensorTable {
GlobalStepTable() {}
virtual ~GlobalStepTable() {}
int32_t pull_sparse(float *values,
int32_t PullSparse(float *values,
const PullSparseValue &pull_value) override {
return 0;
}
int32_t push_sparse(const uint64_t *keys, const float *values,
int32_t PushSparse(const uint64_t *keys, const float *values,
size_t num) override {
return 0;
}
int32_t shrink(const std::string &param) override { return 0; }
int32_t Shrink(const std::string &param) override { return 0; }
virtual void *get_shard(size_t shard_idx) { return 0; }
virtual void *GetShard(size_t shard_idx) { return 0; }
virtual int32_t initialize_shard() { return 0; }
virtual int32_t InitializeShard() { return 0; }
virtual int32_t flush() { return 0; }
virtual int32_t Flush() { return 0; }
virtual void clear() {}
virtual void Clear() {}
virtual int32_t load(const std::string &path, const std::string &param) {
virtual int32_t Load(const std::string &path, const std::string &param) {
return 0;
}
virtual int32_t save(const std::string &path, const std::string &param) {
virtual int32_t Save(const std::string &path, const std::string &param) {
return 0;
}
int32_t pull_dense(float *values, size_t num) override { return 0; }
int32_t PullDense(float *values, size_t num) override { return 0; }
/*----------------------------------------------------------------------*/
int32_t initialize() override {
int32_t Initialize() override {
auto _program_config = _config.tensor();
auto trainers_ = _config.common().trainer_num();
FLAGS_eager_delete_tensor_gb = -1;
......@@ -237,13 +237,13 @@ class GlobalStepTable : public DenseTensorTable {
}
}
int32_t push_dense(const float *values, size_t num) override { return 0; }
int32_t PushDense(const float *values, size_t num) override { return 0; }
int32_t push_dense(const int64_t *values, const int32_t trainer_id) {
return _run_program(values, trainer_id);
int32_t PushDense(const int64_t *values, const int32_t trainer_id) {
return _RunProgram(values, trainer_id);
}
int32_t set_table_map(std::unordered_map<uint32_t, std::shared_ptr<Table>>
int32_t SetTableMap(std::unordered_map<uint32_t, std::shared_ptr<Table>>
*table_map) override {
auto *lr_var = scope_->FindVar(fetch_var_name_);
auto *lr_tensor = lr_var->GetMutable<framework::LoDTensor>();
......@@ -255,13 +255,13 @@ class GlobalStepTable : public DenseTensorTable {
if (table_id == _config.table_id()) {
continue;
}
iter->second->set_global_lr(lr_value);
iter->second->SetGlobalLR(lr_value);
}
return 0;
}
private:
virtual int32_t _run_program(const int64_t *values,
virtual int32_t _RunProgram(const int64_t *values,
const uint32_t trainer_id) {
FLAGS_eager_delete_tensor_gb = -1;
auto counter = decay_counters_.at(trainer_id);
......
......@@ -51,32 +51,6 @@ int32_t FleetWrapper::CopyTableByFeasign(
return 0;
}
void FleetWrapper::Stop() { StopServer(); }
void FleetWrapper::Load(WrapperContext& context) {
auto table_id = context.table_id;
if (table_id >= 0 && context.meta != "") {
LoadSparseOnServer(context.path, context.meta, context.table_id);
return;
}
if (table_id < 0) { // load all
LoadModel(context.path, context.mode);
} else { // load one table
LoadModelOneTable(table_id, context.path, context.mode);
}
return;
}
void FleetWrapper::Save(WrapperContext& context) {
auto table_id = context.table_id;
if (table_id < 0) {
SaveModel(context.path, context.mode);
} else {
SaveModelOneTable(table_id, context.path, context.mode);
}
return;
}
void FleetWrapper::SetClient2ClientConfig(int request_timeout_ms,
int connect_timeout_ms,
int max_retry) {
......@@ -90,7 +64,7 @@ void FleetWrapper::LoadSparseOnServer(const std::string& path,
uint32_t table_id) {
VLOG(3) << "load sparse table " << table_id << " with " << path << " meta "
<< meta;
pserver_ptr_->_server_ptr->table(table_id)->load(path, meta);
pserver_ptr_->_server_ptr->GetTable(table_id)->Load(path, meta);
}
void FleetWrapper::InitServer(
......@@ -101,7 +75,7 @@ void FleetWrapper::InitServer(
VLOG(3) << "Going to init server";
pserver_ptr_ = std::shared_ptr<paddle::distributed::PSCore>(
new paddle::distributed::PSCore());
pserver_ptr_->init_server(dist_desc, &host_sign_list, host_sign_list.size(),
pserver_ptr_->InitServer(dist_desc, &host_sign_list, host_sign_list.size(),
index, trainers, server_sub_program);
is_initialized_ = true;
} else {
......@@ -143,10 +117,10 @@ void FleetWrapper::InitWorker(const std::string& dist_desc,
google::protobuf::TextFormat::ParseFromString(dist_desc, &ps_param);
InitGFlag(ps_param.init_gflags());
int servers = host_sign_list.size();
ps_env_.set_ps_servers(&host_sign_list, servers);
ps_env_.SetPsServers(&host_sign_list, servers);
worker_ptr_ = std::shared_ptr<paddle::distributed::PSClient>(
paddle::distributed::PSClientFactory::create(ps_param));
worker_ptr_->configure(ps_param, dense_pull_regions, ps_env_, index);
paddle::distributed::PSClientFactory::Create(ps_param));
worker_ptr_->Configure(ps_param, dense_pull_regions, ps_env_, index);
}
} else {
VLOG(3) << "Client can be initialized only once";
......@@ -155,13 +129,13 @@ void FleetWrapper::InitWorker(const std::string& dist_desc,
void FleetWrapper::StopServer() {
VLOG(3) << "Going to stop server";
auto status = worker_ptr_->stop_server();
auto status = worker_ptr_->StopServer();
status.wait();
}
void FleetWrapper::FinalizeWorker() {
VLOG(3) << "Going to finalize worker";
worker_ptr_->finalize_worker();
worker_ptr_->FinalizeWorker();
}
void FleetWrapper::BarrierWithTable(uint32_t barrier_type) {
......@@ -172,13 +146,13 @@ void FleetWrapper::BarrierWithTable(uint32_t barrier_type) {
uint64_t FleetWrapper::RunServer(const std::string& ip, uint32_t port) {
VLOG(3) << "Going to run server with ip " << ip << " port " << port;
auto ret = pserver_ptr_->run_server(ip, port);
auto ret = pserver_ptr_->RunServer(ip, port);
return ret;
}
std::vector<uint64_t> FleetWrapper::GetClientsInfo() {
VLOG(3) << "Going to get client info";
std::vector<uint64_t> res = ps_env_.get_client_info();
std::vector<uint64_t> res = ps_env_.GetClientInfo();
for (auto rr : res) {
VLOG(2) << "FleetWrapper::GetClientInfo " << rr;
}
......@@ -187,13 +161,13 @@ std::vector<uint64_t> FleetWrapper::GetClientsInfo() {
int FleetWrapper::SetClients(std::vector<uint64_t>& host_sign_list) {
int node = host_sign_list.size();
return ps_env_.set_ps_clients(host_sign_list.data(), node);
return ps_env_.SetPsClients(host_sign_list.data(), node);
}
void FleetWrapper::CreateClient2ClientConnection() {
VLOG(1) << "Going to create client2client connection";
worker_ptr_->create_client2client_connection(
client2client_request_timeout_ms_, client2client_connect_timeout_ms_,
worker_ptr_->CreateClient2ClientConnection(client2client_request_timeout_ms_,
client2client_connect_timeout_ms_,
client2client_max_retry_);
}
......@@ -230,8 +204,8 @@ std::future<int32_t> FleetWrapper::PullSparseVarsAsync(
}
bool training = true;
return pserver_ptr_->_worker_ptr->pull_sparse(pull_result_ptr.data(),
table_id, fea_keys->data(),
return pserver_ptr_->_worker_ptr->PullSparse(pull_result_ptr.data(), table_id,
fea_keys->data(),
fea_keys->size(), training);
}
......@@ -279,7 +253,7 @@ void FleetWrapper::PullSparseVarsSync(
pull_result_ptr.push_back(t.data());
}
bool training = true;
auto status = pserver_ptr_->_worker_ptr->pull_sparse(
auto status = pserver_ptr_->_worker_ptr->PullSparse(
pull_result_ptr.data(), table_id, fea_keys->data(), fea_keys->size(),
training);
pull_sparse_status.push_back(std::move(status));
......@@ -337,21 +311,10 @@ void FleetWrapper::PullSparseToTensorSync(const uint64_t table_id, int fea_dim,
pull_result_ptr.push_back(output_data + output_len);
}
}
// ps client pull sparse
// construct client request context
RequestContext req_context;
req_context.value_type = Sparse;
req_context.training_mode = Async;
req_context.table = table_id;
req_context.sparse_values = pull_result_ptr.data();
req_context.keys = fea_keys.data();
req_context.num = fea_keys.size();
req_context.is_training = is_training;
auto status = worker_ptr_->Pull(req_context);
// auto status =
// worker_ptr_->pull_sparse(pull_result_ptr.data(), table_id,
// fea_keys.data(), fea_keys.size(),
// is_training);
auto status =
worker_ptr_->PullSparse(pull_result_ptr.data(), table_id, fea_keys.data(),
fea_keys.size(), is_training);
status.wait();
auto ret = status.get();
if (ret != 0) {
......@@ -364,7 +327,7 @@ void FleetWrapper::PullDenseVarsAsync(
const Scope& scope, const uint64_t tid,
const std::vector<std::string>& var_names,
std::vector<std::future<int32_t>>* pull_dense_status, bool in_cpu) {
auto& regions = _regions[tid];
auto& regions = regions_[tid];
regions.clear();
regions.resize(var_names.size());
for (auto i = 0u; i < var_names.size(); ++i) {
......@@ -378,21 +341,15 @@ void FleetWrapper::PullDenseVarsAsync(
paddle::distributed::Region reg(w, tensor->numel());
regions[i] = std::move(reg);
}
RequestContext req_context;
req_context.value_type = Dense;
req_context.training_mode = Async;
req_context.table = tid;
req_context.dense_values = regions.data();
req_context.num = regions.size();
auto status = worker_ptr_->Pull(req_context);
// auto status = worker_ptr_->pull_dense(regions.data(), regions.size(), tid);
auto status = worker_ptr_->PullDense(regions.data(), regions.size(), tid);
pull_dense_status->push_back(std::move(status));
}
void FleetWrapper::PullDenseVarsSync(
const Scope& scope, const uint64_t tid,
const std::vector<std::string>& var_names) {
auto& regions = _regions[tid];
auto& regions = regions_[tid];
regions.clear();
regions.reserve(var_names.size());
for (auto& t : var_names) {
......@@ -404,7 +361,7 @@ void FleetWrapper::PullDenseVarsSync(
regions.emplace_back(std::move(reg));
}
}
auto status = worker_ptr_->pull_dense(regions.data(), regions.size(), tid);
auto status = worker_ptr_->PullDense(regions.data(), regions.size(), tid);
status.wait();
}
......@@ -424,7 +381,7 @@ void FleetWrapper::PushDenseParamSync(
}
}
auto push_status =
worker_ptr_->push_dense_param(regions.data(), regions.size(), table_id);
worker_ptr_->PushDenseParam(regions.data(), regions.size(), table_id);
push_status.wait();
auto status = push_status.get();
CHECK(status == 0) << "push dense param failed, status[" << status << "]";
......@@ -470,15 +427,8 @@ void FleetWrapper::PushDenseVarsAsync(
<< g[tensor->numel() - 1];
}
RequestContext req_context;
req_context.value_type = Dense;
req_context.training_mode = Async;
req_context.table = table_id;
req_context.push_context.push_dense_values = regions.data();
req_context.num = regions.size();
// auto push_status =
// worker_ptr_->push_dense(regions.data(), regions.size(), table_id);
auto push_status = worker_ptr_->Push(req_context);
auto push_status =
worker_ptr_->PushDense(regions.data(), regions.size(), table_id);
}
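The wrapper hunks above drop the intermediate RequestContext and call the renamed client methods directly; each call still returns a std::future<int32_t> error code, so async paths fire and forget while sync paths wait() and then check get(). A small sketch of that calling convention (FakePushDense is a stand-in, not a Paddle API):

#include <cstdint>
#include <future>
#include <iostream>

// Stands in for worker_ptr_->PushDense(regions.data(), regions.size(), tid).
std::future<int32_t> FakePushDense() {
  return std::async(std::launch::async, []() -> int32_t { return 0; });
}

int main() {
  auto status = FakePushDense();  // PushDenseVarsAsync stops here, fire and forget
  status.wait();                  // the sync paths block instead...
  const int32_t err = status.get();
  std::cout << (err == 0 ? "push ok" : "push failed") << "\n";  // ...and check
  return 0;
}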
void FleetWrapper::PushSparseVarsAsync(
......@@ -650,23 +600,13 @@ void FleetWrapper::PushSparseFromTensorAsync(
push_g_vec[i] = push_values.at(i).data();
}
// ps client push sparse
// construct request context
RequestContext req_context;
req_context.value_type = Sparse;
req_context.training_mode = Async;
req_context.table = table_id;
req_context.push_context.push_values = (const float**)push_g_vec.data();
req_context.push_context.keys = push_keys.data();
req_context.num = push_keys.size();
auto status = worker_ptr_->Push(req_context);
// auto status = worker_ptr_->push_sparse(table_id, push_keys.data(),
// (const float**)push_g_vec.data(),
// push_keys.size());
auto status = worker_ptr_->PushSparse(table_id, push_keys.data(),
(const float**)push_g_vec.data(),
push_keys.size());
}
void FleetWrapper::LoadModel(const std::string& path, const int mode) {
auto ret = worker_ptr_->load(path, std::to_string(mode));
auto ret = worker_ptr_->Load(path, std::to_string(mode));
ret.wait();
if (ret.get() != 0) {
LOG(ERROR) << "load model from path:" << path << " failed";
......@@ -675,7 +615,7 @@ void FleetWrapper::LoadModel(const std::string& path, const int mode) {
void FleetWrapper::LoadModelOneTable(const uint64_t table_id,
const std::string& path, const int mode) {
auto ret = worker_ptr_->load(table_id, path, std::to_string(mode));
auto ret = worker_ptr_->Load(table_id, path, std::to_string(mode));
ret.wait();
if (ret.get() != 0) {
LOG(ERROR) << "load model of table id: " << table_id
......@@ -684,7 +624,7 @@ void FleetWrapper::LoadModelOneTable(const uint64_t table_id,
}
void FleetWrapper::SaveModel(const std::string& path, const int mode) {
auto ret = worker_ptr_->save(path, std::to_string(mode));
auto ret = worker_ptr_->Save(path, std::to_string(mode));
ret.wait();
int32_t feasign_cnt = ret.get();
if (feasign_cnt == -1) {
......@@ -694,7 +634,7 @@ void FleetWrapper::SaveModel(const std::string& path, const int mode) {
void FleetWrapper::SaveModelOneTable(const uint64_t table_id,
const std::string& path, const int mode) {
auto ret = worker_ptr_->save(table_id, path, std::to_string(mode));
auto ret = worker_ptr_->Save(table_id, path, std::to_string(mode));
ret.wait();
if (ret.get() != 0) {
LOG(ERROR) << "save model of table id: " << table_id
......@@ -704,7 +644,7 @@ void FleetWrapper::SaveModelOneTable(const uint64_t table_id,
void FleetWrapper::RecvAndSaveTable(const uint64_t table_id,
const std::string& path) {
auto ret = worker_ptr_->recv_and_save_table(table_id, path);
auto ret = worker_ptr_->RecvAndSaveTable(table_id, path);
if (ret != 0) {
LOG(ERROR) << "save model of table id: " << table_id
<< ", to path: " << path << " failed";
......@@ -712,7 +652,7 @@ void FleetWrapper::RecvAndSaveTable(const uint64_t table_id,
}
void FleetWrapper::PrintTableStat(const uint64_t table_id) {
auto ret = worker_ptr_->print_table_stat(table_id);
auto ret = worker_ptr_->PrintTableStat(table_id);
ret.wait();
int32_t err_code = ret.get();
if (err_code == -1) {
......@@ -721,7 +661,7 @@ void FleetWrapper::PrintTableStat(const uint64_t table_id) {
}
void FleetWrapper::ShrinkSparseTable(int table_id, int threshold) {
auto ret = worker_ptr_->shrink(table_id, std::to_string(threshold));
auto ret = worker_ptr_->Shrink(table_id, std::to_string(threshold));
ret.wait();
int32_t err_code = ret.get();
if (err_code == -1) {
......@@ -730,12 +670,12 @@ void FleetWrapper::ShrinkSparseTable(int table_id, int threshold) {
}
void FleetWrapper::ClearModel() {
auto ret = pserver_ptr_->_worker_ptr->clear();
auto ret = pserver_ptr_->_worker_ptr->Clear();
ret.wait();
}
void FleetWrapper::ClearOneTable(const uint64_t table_id) {
auto ret = pserver_ptr_->_worker_ptr->clear(table_id);
auto ret = pserver_ptr_->_worker_ptr->Clear(table_id);
ret.wait();
}
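
Almost every wrapper method in this file follows the same idiom: fire an RPC, receive a std::future<int32_t>, then wait() and compare get() against -1. A hedged, self-contained sketch of that pattern in isolation (all names invented for illustration):

#include <cstdint>
#include <future>
#include <iostream>

// Illustrative helper only: block on a pending status future and map -1 to
// a failure log, mirroring the ret.wait(); ret.get() idiom in this file.
int32_t WaitAndCheck(std::future<int32_t> ret, const char* what) {
  ret.wait();
  int32_t err_code = ret.get();
  if (err_code == -1) {
    std::cerr << what << " failed" << std::endl;
  }
  return err_code;
}

int main() {
  std::promise<int32_t> p;
  p.set_value(0);  // simulate a successfully completed RPC
  return WaitAndCheck(p.get_future(), "clear table");
}
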
......@@ -774,7 +714,7 @@ void FleetWrapper::ShrinkDenseTable(int table_id, Scope* scope,
regions.emplace_back(std::move(reg));
}
}
auto push_status = pserver_ptr_->_worker_ptr->push_dense_param(
auto push_status = pserver_ptr_->_worker_ptr->PushDenseParam(
regions.data(), regions.size(), table_id);
push_status.wait();
auto status = push_status.get();
......@@ -791,7 +731,7 @@ void FleetWrapper::ClientFlush() {
VLOG(0) << "worker_ptr null, do nothing";
return;
}
auto ret = worker_ptr_->flush();
auto ret = worker_ptr_->Flush();
ret.wait();
int32_t err_code = ret.get();
if (err_code == -1) {
......@@ -805,13 +745,13 @@ int FleetWrapper::RegisterClientToClientMsgHandler(int msg_type,
VLOG(0) << "FleetWrapper::Client is null";
return -1;
} else {
return worker_ptr_->registe_client2client_msg_handler(msg_type, handler);
return worker_ptr_->RegisteClient2ClientMsgHandler(msg_type, handler);
}
}
std::future<int32_t> FleetWrapper::SendClientToClientMsg(
int msg_type, int to_client_id, const std::string& msg) {
return worker_ptr_->send_client2client_msg(msg_type, to_client_id, msg);
return worker_ptr_->SendClient2ClientMsg(msg_type, to_client_id, msg);
}
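
The two client-to-client entry points above pair a handler registry with a send call. A minimal sketch of the dispatch side, assuming a handler signature of (msg_type, sender client id, payload) — the real MsgHandlerFunc typedef may differ:

#include <cstdint>
#include <functional>
#include <iostream>
#include <map>
#include <string>

// Assumed handler shape: msg_type, sender client id, payload.
using MsgHandlerFunc = std::function<int32_t(int, int, const std::string&)>;

// Hypothetical registry mirroring the register/handle pair above; it shows
// only the dispatch-by-msg_type idea, not the brpc transport.
struct MiniMsgBus {
  std::map<int, MsgHandlerFunc> handlers;
  int32_t Registe(int msg_type, MsgHandlerFunc h) {
    handlers[msg_type] = std::move(h);
    return 0;
  }
  int32_t Handle(int msg_type, int from_id, const std::string& msg) {
    auto it = handlers.find(msg_type);
    return it == handlers.end() ? -1 : it->second(msg_type, from_id, msg);
  }
};

int main() {
  MiniMsgBus bus;
  bus.Registe(0, [](int type, int from, const std::string& msg) -> int32_t {
    std::cout << "msg " << type << " from client " << from << ": " << msg
              << std::endl;
    return 0;
  });
  return bus.Handle(0, /*from_id=*/1, "barrier");
}
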
std::default_random_engine& FleetWrapper::LocalRandomEngine() {
......
......@@ -25,7 +25,6 @@ limitations under the License. */
#include "paddle/fluid/distributed/ps/service/communicator/communicator_common.h"
#include "paddle/fluid/distributed/ps/service/ps_service/service.h"
#include "paddle/fluid/distributed/ps/wrapper/ps_wrapper.h"
#include "paddle/fluid/framework/archive.h"
#include "paddle/fluid/framework/io/fs.h"
#include "paddle/fluid/framework/io/shell.h"
......@@ -55,7 +54,7 @@ using framework::Variable;
using RpcCtxMap = std::unordered_map<std::string, CommContext>;
class FleetWrapper : public PSWrapper {
class FleetWrapper {
public:
virtual ~FleetWrapper() {}
FleetWrapper() {
......@@ -69,7 +68,6 @@ class FleetWrapper : public PSWrapper {
// pserver request max retry
client2client_max_retry_ = 3;
}
virtual int32_t Initialize(InitContext& context) { return 0; }
// TODO(zhaocaibei123: later)
int32_t CopyTable(const uint64_t src_table_id, const uint64_t dest_table_id);
......@@ -81,12 +79,6 @@ class FleetWrapper : public PSWrapper {
typedef std::function<void(int, int)> HeterCallBackFunc;
int RegisterHeterCallback(HeterCallBackFunc handler);
virtual void Stop() override;
virtual void Load(WrapperContext& context) override;
virtual void Save(WrapperContext& context) override;
// set client to client communication config
void SetClient2ClientConfig(int request_timeout_ms, int connect_timeout_ms,
int max_retry);
......@@ -278,7 +270,7 @@ class FleetWrapper : public PSWrapper {
protected:
static bool is_initialized_;
std::map<uint64_t, std::vector<paddle::distributed::Region>> _regions;
std::map<uint64_t, std::vector<paddle::distributed::Region>> regions_;
bool scale_sparse_gradient_with_batch_size_;
int32_t sleep_seconds_before_fail_exit_;
int client2client_request_timeout_ms_;
......
......@@ -39,19 +39,19 @@ TEST(BarrierTable, Barrier) {
common_config->set_trainer_num(trainers);
common_config->set_sync(sync);
auto ret = table->initialize(table_config, fs_config);
auto ret = table->Initialize(table_config, fs_config);
std::unordered_map<uint32_t, std::shared_ptr<Table>> maps =
std::unordered_map<uint32_t, std::shared_ptr<Table>>();
table->set_table_map(&maps);
table->SetTableMap(&maps);
std::shared_ptr<::ThreadPool> pool_ =
std::make_shared<::ThreadPool>(trainers);
std::vector<std::future<void>> task_status;
for (auto x = 0; x < trainers; x++) {
auto task = [table, x] { table->barrier(x, 0); };
auto task = [table, x] { table->Barrier(x, 0); };
task_status.push_back(pool_->enqueue(std::move(task)));
}
......
......@@ -155,16 +155,16 @@ void RunServer() {
auto _ps_env = paddle::distributed::PaddlePSEnvironment();
LOG(INFO) << "RUN set_ps_servers";
_ps_env.set_ps_servers(&host_sign_list_, 1);
_ps_env.SetPsServers(&host_sign_list_, 1);
pserver_ptr_ = std::shared_ptr<paddle::distributed::PSServer>(
paddle::distributed::PSServerFactory::create(server_proto));
paddle::distributed::PSServerFactory::Create(server_proto));
LOG(INFO) << "RUN configure";
std::vector<framework::ProgramDesc> empty_vec;
framework::ProgramDesc empty_prog;
empty_vec.push_back(empty_prog);
pserver_ptr_->configure(server_proto, _ps_env, 0, empty_vec);
pserver_ptr_->Configure(server_proto, _ps_env, 0, empty_vec);
LOG(INFO) << "RUN start";
pserver_ptr_->start(ip_, port_);
pserver_ptr_->Start(ip_, port_);
LOG(INFO) << "End start";
}
......@@ -175,19 +175,19 @@ void RunClient(std::map<uint64_t, std::vector<paddle::distributed::Region>>&
auto servers_ = host_sign_list_.size();
_ps_env = paddle::distributed::PaddlePSEnvironment();
LOG(INFO) << "Run set_ps_servers";
_ps_env.set_ps_servers(&host_sign_list_, servers_);
_ps_env.SetPsServers(&host_sign_list_, servers_);
LOG(INFO) << "Run Create PSClient";
worker_ptr_ = std::shared_ptr<paddle::distributed::PSClient>(
paddle::distributed::PSClientFactory::create(worker_proto));
paddle::distributed::PSClientFactory::Create(worker_proto));
LOG(INFO) << "Run configure";
worker_ptr_->configure(worker_proto, dense_regions, _ps_env, 0);
worker_ptr_->Configure(worker_proto, dense_regions, _ps_env, 0);
}
void RunBrpcPushDense() {
setenv("http_proxy", "", 1);
setenv("https_proxy", "", 1);
auto ph_host = paddle::distributed::PSHost(ip_, port_, 0);
host_sign_list_.push_back(ph_host.serialize_to_string());
host_sign_list_.push_back(ph_host.SerializeToString());
// Start Server
std::thread server_thread(RunServer);
......@@ -218,7 +218,7 @@ void RunBrpcPushDense() {
paddle::distributed::Region temp_reg(temp, tensor->numel());
temp_region.emplace_back(std::move(temp_reg));
auto pull_status =
worker_ptr_->pull_dense(temp_region.data(), temp_region.size(), 0);
worker_ptr_->PullDense(temp_region.data(), temp_region.size(), 0);
pull_status.wait();
for (size_t idx = 0; idx < tensor->numel(); ++idx) {
......@@ -229,10 +229,10 @@ void RunBrpcPushDense() {
LOG(INFO) << "Run push_dense_param";
auto push_status =
worker_ptr_->push_dense_param(regions.data(), regions.size(), 0);
worker_ptr_->PushDenseParam(regions.data(), regions.size(), 0);
push_status.wait();
pull_status = worker_ptr_->pull_dense(regions.data(), regions.size(), 0);
pull_status = worker_ptr_->PullDense(regions.data(), regions.size(), 0);
pull_status.wait();
for (size_t idx = 0; idx < tensor->numel(); ++idx) {
......@@ -257,11 +257,11 @@ void RunBrpcPushDense() {
LOG(INFO) << "Run pull_dense_grad";
auto push_grad_status =
worker_ptr_->push_dense_raw_gradient(0, temp, tensor->numel(), closure);
worker_ptr_->PushDenseRawGradient(0, temp, tensor->numel(), closure);
push_grad_status.wait();
auto pull_update_status =
worker_ptr_->pull_dense(regions.data(), regions.size(), 0);
worker_ptr_->PullDense(regions.data(), regions.size(), 0);
pull_update_status.wait();
for (size_t idx = 0; idx < tensor->numel(); ++idx) {
......@@ -269,9 +269,9 @@ void RunBrpcPushDense() {
}
LOG(INFO) << "Run stop_server";
worker_ptr_->stop_server();
worker_ptr_->StopServer();
LOG(INFO) << "Run finalize_worker";
worker_ptr_->finalize_worker();
worker_ptr_->FinalizeWorker();
server_thread.join();
}
......
......@@ -156,14 +156,14 @@ void RunServer() {
::paddle::distributed::PSParameter server_proto = GetServerProto();
auto _ps_env = paddle::distributed::PaddlePSEnvironment();
_ps_env.set_ps_servers(&host_sign_list_, 1);
_ps_env.SetPsServers(&host_sign_list_, 1);
pserver_ptr_ = std::shared_ptr<paddle::distributed::PSServer>(
paddle::distributed::PSServerFactory::create(server_proto));
paddle::distributed::PSServerFactory::Create(server_proto));
std::vector<framework::ProgramDesc> empty_vec;
framework::ProgramDesc empty_prog;
empty_vec.push_back(empty_prog);
pserver_ptr_->configure(server_proto, _ps_env, 0, empty_vec);
pserver_ptr_->start(ip_, port_);
pserver_ptr_->Configure(server_proto, _ps_env, 0, empty_vec);
pserver_ptr_->Start(ip_, port_);
}
void RunClient(std::map<uint64_t, std::vector<paddle::distributed::Region>>&
......@@ -172,17 +172,17 @@ void RunClient(std::map<uint64_t, std::vector<paddle::distributed::Region>>&
paddle::distributed::PaddlePSEnvironment _ps_env;
auto servers_ = host_sign_list_.size();
_ps_env = paddle::distributed::PaddlePSEnvironment();
_ps_env.set_ps_servers(&host_sign_list_, servers_);
_ps_env.SetPsServers(&host_sign_list_, servers_);
worker_ptr_ = std::shared_ptr<paddle::distributed::PSClient>(
paddle::distributed::PSClientFactory::create(worker_proto));
worker_ptr_->configure(worker_proto, dense_regions, _ps_env, 0);
paddle::distributed::PSClientFactory::Create(worker_proto));
worker_ptr_->Configure(worker_proto, dense_regions, _ps_env, 0);
}
void RunBrpcPushSparse() {
setenv("http_proxy", "", 1);
setenv("https_proxy", "", 1);
auto ph_host = paddle::distributed::PSHost(ip_, port_, 0);
host_sign_list_.push_back(ph_host.serialize_to_string());
host_sign_list_.push_back(ph_host.SerializeToString());
// Start Server
std::thread server_thread(RunServer);
......@@ -214,7 +214,7 @@ void RunBrpcPushSparse() {
/*-----------------------Test Server Init----------------------------------*/
LOG(INFO) << "Run pull_sparse_param";
auto pull_status = worker_ptr_->pull_sparse(
auto pull_status = worker_ptr_->PullSparse(
fea_value_ptr.data(), 0, fea_keys.data(), fea_keys.size(), true);
pull_status.wait();
for (size_t idx = 0; idx < tensor->numel(); ++idx) {
......@@ -237,12 +237,12 @@ void RunBrpcPushSparse() {
}
closure->set_promise_value(ret);
});
auto push_status = worker_ptr_->push_sparse_param(
auto push_status = worker_ptr_->PushSparseParam(
0, fea_keys.data(), (const float**)fea_value_ptr.data(), fea_keys.size(),
closure_push_param);
push_status.wait();
auto pull_param_status = worker_ptr_->pull_sparse(
auto pull_param_status = worker_ptr_->PullSparse(
fea_temp_value_ptr.data(), 0, fea_keys.data(), fea_keys.size(), true);
pull_param_status.wait();
......@@ -271,12 +271,12 @@ void RunBrpcPushSparse() {
for (auto i = 0; i < static_cast<int>(fea_keys.size()); ++i) {
push_g_vec.push_back(tensor->data<float>() + i * 10);
}
auto push_grad_status = worker_ptr_->push_sparse_raw_gradient(
auto push_grad_status = worker_ptr_->PushSparseRawGradient(
0, fea_keys.data(), (const float**)push_g_vec.data(), fea_keys.size(),
closure_push_grad);
push_grad_status.wait();
auto pull_update_status = worker_ptr_->pull_sparse(
auto pull_update_status = worker_ptr_->PullSparse(
fea_temp_value_ptr.data(), 0, fea_keys.data(), fea_keys.size(), true);
pull_update_status.wait();
......@@ -285,9 +285,9 @@ void RunBrpcPushSparse() {
}
LOG(INFO) << "Run stop_server";
worker_ptr_->stop_server();
worker_ptr_->StopServer();
LOG(INFO) << "Run finalize_worker";
worker_ptr_->finalize_worker();
worker_ptr_->FinalizeWorker();
server_thread.join();
}
......
......@@ -63,13 +63,13 @@ TEST(CommonDenseTable, Adam) {
common_config->add_params("LearningRate");
common_config->add_dims(1);
common_config->add_initializers("fill_constant&5e-6");
auto ret = table->initialize(table_config, fs_config);
auto ret = table->Initialize(table_config, fs_config);
ASSERT_EQ(ret, 0);
// pull parameters for create and check
std::vector<float> init_values;
init_values.resize(fea_dim);
table->pull_dense(init_values.data(), fea_dim);
table->PullDense(init_values.data(), fea_dim);
// push gradient
std::vector<std::vector<float>> trainer_gradient_values;
......@@ -85,12 +85,12 @@ TEST(CommonDenseTable, Adam) {
// for adam
for (int i = 0; i < trainers; i++) {
auto &push_values = trainer_gradient_values[i];
table->push_dense(push_values.data(), push_values.size());
table->PushDense(push_values.data(), push_values.size());
}
std::vector<float> pull_values;
pull_values.resize(fea_dim);
table->pull_dense(pull_values.data(), fea_dim);
table->PullDense(pull_values.data(), fea_dim);
float mom_rate = 0.99;
float decay_rate = 0.9999;
......@@ -118,6 +118,7 @@ TEST(CommonDenseTable, Adam) {
}
}
for (int j = 0; j < fea_dim; j++) {
VLOG(0) << param[j] << " " << pull_values[j];
ASSERT_TRUE(abs(param[j] - pull_values[j]) < 1e-5);
}
}
......@@ -143,13 +144,13 @@ TEST(CommonDenseTable, SGD) {
common_config->add_params("LearningRate");
common_config->add_dims(1);
common_config->add_initializers("fill_constant&1.0");
auto ret = table->initialize(table_config, fs_config);
auto ret = table->Initialize(table_config, fs_config);
ASSERT_EQ(ret, 0);
// pull parameters for create and check
std::vector<float> init_values;
init_values.resize(fea_dim);
table->pull_dense(init_values.data(), fea_dim);
table->PullDense(init_values.data(), fea_dim);
std::vector<float> total_gradients;
total_gradients.resize(fea_dim);
......@@ -172,7 +173,7 @@ TEST(CommonDenseTable, SGD) {
for (int i = 0; i < trainers; i++) {
auto &push_values = trainer_gradient_values[i];
auto task = [table, &push_values] {
table->push_dense(push_values.data(), push_values.size());
table->PushDense(push_values.data(), push_values.size());
};
task_status.push_back(pool_->enqueue(std::move(task)));
}
......@@ -182,7 +183,7 @@ TEST(CommonDenseTable, SGD) {
std::vector<float> pull_values;
pull_values.resize(fea_dim);
table->pull_dense(pull_values.data(), fea_dim);
table->PullDense(pull_values.data(), fea_dim);
for (int j = 0; j < fea_dim; j++) {
auto update_val = init_values[j] - 1.0 * total_gradients[j];
ASSERT_TRUE(abs(update_val - pull_values[j]) < 1e-5);
......
......@@ -166,16 +166,16 @@ void RunServer() {
::paddle::distributed::PSParameter server_proto = GetServerProto();
auto _ps_env = paddle::distributed::PaddlePSEnvironment();
_ps_env.set_ps_servers(&host_sign_list_, 2); // test
_ps_env.SetPsServers(&host_sign_list_, 2); // test
pserver_ptr_ = std::shared_ptr<paddle::distributed::GraphBrpcServer>(
(paddle::distributed::GraphBrpcServer*)
paddle::distributed::PSServerFactory::create(server_proto));
paddle::distributed::PSServerFactory::Create(server_proto));
std::vector<framework::ProgramDesc> empty_vec;
framework::ProgramDesc empty_prog;
empty_vec.push_back(empty_prog);
pserver_ptr_->configure(server_proto, _ps_env, 0, empty_vec);
pserver_ptr_->Configure(server_proto, _ps_env, 0, empty_vec);
LOG(INFO) << "first server, run start(ip,port)";
pserver_ptr_->start(ip_, port_);
pserver_ptr_->Start(ip_, port_);
pserver_ptr_->build_peer2peer_connection(0);
LOG(INFO) << "init first server Done";
}
......@@ -185,15 +185,15 @@ void RunServer2() {
::paddle::distributed::PSParameter server_proto2 = GetServerProto();
auto _ps_env2 = paddle::distributed::PaddlePSEnvironment();
_ps_env2.set_ps_servers(&host_sign_list_, 2); // test
_ps_env2.SetPsServers(&host_sign_list_, 2); // test
pserver_ptr2 = std::shared_ptr<paddle::distributed::GraphBrpcServer>(
(paddle::distributed::GraphBrpcServer*)
paddle::distributed::PSServerFactory::create(server_proto2));
paddle::distributed::PSServerFactory::Create(server_proto2));
std::vector<framework::ProgramDesc> empty_vec2;
framework::ProgramDesc empty_prog2;
empty_vec2.push_back(empty_prog2);
pserver_ptr2->configure(server_proto2, _ps_env2, 1, empty_vec2);
pserver_ptr2->start(ip2, port2);
pserver_ptr2->Configure(server_proto2, _ps_env2, 1, empty_vec2);
pserver_ptr2->Start(ip2, port2);
pserver_ptr2->build_peer2peer_connection(1);
}
......@@ -204,11 +204,11 @@ void RunClient(
paddle::distributed::PaddlePSEnvironment _ps_env;
auto servers_ = host_sign_list_.size();
_ps_env = paddle::distributed::PaddlePSEnvironment();
_ps_env.set_ps_servers(&host_sign_list_, servers_);
_ps_env.SetPsServers(&host_sign_list_, servers_);
worker_ptr_ = std::shared_ptr<paddle::distributed::GraphBrpcClient>(
(paddle::distributed::GraphBrpcClient*)
paddle::distributed::PSClientFactory::create(worker_proto));
worker_ptr_->configure(worker_proto, dense_regions, _ps_env, 0);
paddle::distributed::PSClientFactory::Create(worker_proto));
worker_ptr_->Configure(worker_proto, dense_regions, _ps_env, 0);
worker_ptr_->set_shard_num(127);
worker_ptr_->set_local_channel(index);
worker_ptr_->set_local_graph_service(
......@@ -222,11 +222,11 @@ void RunGraphSplit() {
prepare_file(node_file_name, nodes);
prepare_file(graph_split_file_name, graph_split);
auto ph_host = paddle::distributed::PSHost(ip_, port_, 0);
host_sign_list_.push_back(ph_host.serialize_to_string());
host_sign_list_.push_back(ph_host.SerializeToString());
// test-start
auto ph_host2 = paddle::distributed::PSHost(ip2, port2, 1);
host_sign_list_.push_back(ph_host2.serialize_to_string());
host_sign_list_.push_back(ph_host2.SerializeToString());
// test-end
// Start Server
std::thread* server_thread = new std::thread(RunServer);
......@@ -247,7 +247,7 @@ void RunGraphSplit() {
0, std::string(graph_split_file_name));
pull_status.wait();
pull_status =
worker_ptr_->load(0, std::string(edge_file_name), std::string("e>"));
worker_ptr_->Load(0, std::string(edge_file_name), std::string("e>"));
srand(time(0));
pull_status.wait();
std::vector<std::vector<int64_t>> _vs;
......@@ -266,9 +266,9 @@ void RunGraphSplit() {
std::remove(node_file_name);
std::remove(graph_split_file_name);
LOG(INFO) << "Run stop_server";
worker_ptr_->stop_server();
worker_ptr_->StopServer();
LOG(INFO) << "Run finalize_worker";
worker_ptr_->finalize_worker();
worker_ptr_->FinalizeWorker();
}
TEST(RunGraphSplit, Run) { RunGraphSplit(); }
......@@ -348,16 +348,16 @@ void RunServer() {
::paddle::distributed::PSParameter server_proto = GetServerProto();
auto _ps_env = paddle::distributed::PaddlePSEnvironment();
_ps_env.set_ps_servers(&host_sign_list_, 2); // test
_ps_env.SetPsServers(&host_sign_list_, 2); // test
pserver_ptr_ = std::shared_ptr<paddle::distributed::GraphBrpcServer>(
(paddle::distributed::GraphBrpcServer*)
paddle::distributed::PSServerFactory::create(server_proto));
paddle::distributed::PSServerFactory::Create(server_proto));
std::vector<framework::ProgramDesc> empty_vec;
framework::ProgramDesc empty_prog;
empty_vec.push_back(empty_prog);
pserver_ptr_->configure(server_proto, _ps_env, 0, empty_vec);
pserver_ptr_->Configure(server_proto, _ps_env, 0, empty_vec);
LOG(INFO) << "first server, run start(ip,port)";
pserver_ptr_->start(ip_, port_);
pserver_ptr_->Start(ip_, port_);
pserver_ptr_->build_peer2peer_connection(0);
LOG(INFO) << "init first server Done";
}
......@@ -367,15 +367,15 @@ void RunServer2() {
::paddle::distributed::PSParameter server_proto2 = GetServerProto();
auto _ps_env2 = paddle::distributed::PaddlePSEnvironment();
_ps_env2.set_ps_servers(&host_sign_list_, 2); // test
_ps_env2.SetPsServers(&host_sign_list_, 2); // test
pserver_ptr2 = std::shared_ptr<paddle::distributed::GraphBrpcServer>(
(paddle::distributed::GraphBrpcServer*)
paddle::distributed::PSServerFactory::create(server_proto2));
paddle::distributed::PSServerFactory::Create(server_proto2));
std::vector<framework::ProgramDesc> empty_vec2;
framework::ProgramDesc empty_prog2;
empty_vec2.push_back(empty_prog2);
pserver_ptr2->configure(server_proto2, _ps_env2, 1, empty_vec2);
pserver_ptr2->start(ip2, port2);
pserver_ptr2->Configure(server_proto2, _ps_env2, 1, empty_vec2);
pserver_ptr2->Start(ip2, port2);
pserver_ptr2->build_peer2peer_connection(1);
}
......@@ -386,11 +386,11 @@ void RunClient(
paddle::distributed::PaddlePSEnvironment _ps_env;
auto servers_ = host_sign_list_.size();
_ps_env = paddle::distributed::PaddlePSEnvironment();
_ps_env.set_ps_servers(&host_sign_list_, servers_);
_ps_env.SetPsServers(&host_sign_list_, servers_);
worker_ptr_ = std::shared_ptr<paddle::distributed::GraphBrpcClient>(
(paddle::distributed::GraphBrpcClient*)
paddle::distributed::PSClientFactory::create(worker_proto));
worker_ptr_->configure(worker_proto, dense_regions, _ps_env, 0);
paddle::distributed::PSClientFactory::Create(worker_proto));
worker_ptr_->Configure(worker_proto, dense_regions, _ps_env, 0);
worker_ptr_->set_shard_num(127);
worker_ptr_->set_local_channel(index);
worker_ptr_->set_local_graph_service(
......@@ -404,11 +404,11 @@ void RunBrpcPushSparse() {
prepare_file(edge_file_name, 1);
prepare_file(node_file_name, 0);
auto ph_host = paddle::distributed::PSHost(ip_, port_, 0);
host_sign_list_.push_back(ph_host.serialize_to_string());
host_sign_list_.push_back(ph_host.SerializeToString());
// test-start
auto ph_host2 = paddle::distributed::PSHost(ip2, port2, 1);
host_sign_list_.push_back(ph_host2.serialize_to_string());
host_sign_list_.push_back(ph_host2.SerializeToString());
// test-end
// Start Server
std::thread* server_thread = new std::thread(RunServer);
......@@ -424,7 +424,7 @@ void RunBrpcPushSparse() {
/*-----------------------Test Server Init----------------------------------*/
auto pull_status =
worker_ptr_->load(0, std::string(edge_file_name), std::string("e>"));
worker_ptr_->Load(0, std::string(edge_file_name), std::string("e>"));
srand(time(0));
pull_status.wait();
std::vector<std::vector<int64_t>> _vs;
......@@ -438,7 +438,7 @@ void RunBrpcPushSparse() {
pull_status.wait();
ASSERT_EQ(0, _vs[0].size());
paddle::distributed::GraphTable* g =
(paddle::distributed::GraphTable*)pserver_ptr_->table(0);
(paddle::distributed::GraphTable*)pserver_ptr_->GetTable(0);
size_t ttl = 6;
g->make_neighbor_sample_cache(4, ttl);
int round = 5;
......@@ -622,15 +622,15 @@ void RunBrpcPushSparse() {
std::remove(node_file_name);
testAddNode(worker_ptr_);
LOG(INFO) << "Run stop_server";
worker_ptr_->stop_server();
worker_ptr_->StopServer();
LOG(INFO) << "Run finalize_worker";
worker_ptr_->finalize_worker();
worker_ptr_->FinalizeWorker();
testFeatureNodeSerializeInt();
testFeatureNodeSerializeInt64();
testFeatureNodeSerializeFloat32();
testFeatureNodeSerializeFloat64();
testGraphToBuffer();
client1.stop_server();
client1.StopServer();
}
void testCache() {
......
......@@ -48,7 +48,7 @@ TEST(MemorySparseGeoTable, SSUM) {
common_config->add_dims(emb_dim);
common_config->add_initializers("fill_constant&1.0");
auto ret = table->initialize(table_config, fs_config);
auto ret = table->Initialize(table_config, fs_config);
ASSERT_EQ(ret, 0);
// test push_sparse_param, and create params
......@@ -58,12 +58,12 @@ TEST(MemorySparseGeoTable, SSUM) {
for (size_t i = 0; i < init_keys.size() * emb_dim; i++) {
init_values.push_back(0.0);
}
table->push_sparse_param(init_keys.data(), init_values.data(),
table->PushSparseParam(init_keys.data(), init_values.data(),
init_keys.size());
std::vector<float> pull_values(init_values.size());
auto value = PullSparseValue(init_keys, init_fres, emb_dim);
table->pull_sparse(pull_values.data(), value);
table->PullSparse(pull_values.data(), value);
for (size_t i = 0; i < init_keys.size() * emb_dim; i++) {
ASSERT_TRUE(abs(pull_values[i] - init_values[i]) < 1e-5);
......@@ -93,8 +93,7 @@ TEST(MemorySparseGeoTable, SSUM) {
auto &push_keys = trainer_keys[i];
auto &push_values = trainer_values[i];
auto task = [table, &push_keys, &push_values] {
table->push_sparse(push_keys.data(), push_values.data(),
push_keys.size());
table->PushSparse(push_keys.data(), push_values.data(), push_keys.size());
};
task_status.push_back(pool_->enqueue(std::move(task)));
}
......@@ -107,7 +106,7 @@ TEST(MemorySparseGeoTable, SSUM) {
geo_pull_ids.resize(trainers);
geo_pull_values.resize(trainers);
for (int i = 0; i < trainers; i++) {
table->pull_geo_param(i, &geo_pull_values[i], &geo_pull_ids[i]);
table->PullGeoParam(i, &geo_pull_values[i], &geo_pull_ids[i]);
ASSERT_EQ(geo_pull_values[i].size(), geo_pull_ids[i].size() * emb_dim);
for (size_t j = 0; j < geo_pull_ids[i].size(); ++j) {
auto id = geo_pull_ids[i][j];
......
......@@ -36,7 +36,7 @@ TEST(MemorySparseTable, SGD) {
table_config.set_shard_num(10);
FsClientParameter fs_config;
Table *table = new MemorySparseTable();
table->set_shard(0, 1);
table->SetShard(0, 1);
TableAccessorParameter *accessor_config = table_config.mutable_accessor();
accessor_config->set_accessor_class("CtrCommonAccessor");
......@@ -66,7 +66,7 @@ TEST(MemorySparseTable, SGD) {
naive_param->add_weight_bounds(-10.0);
naive_param->add_weight_bounds(10.0);
auto ret = table->initialize(table_config, fs_config);
auto ret = table->Initialize(table_config, fs_config);
ASSERT_EQ(ret, 0);
// pull parameters for create and check
......@@ -76,7 +76,7 @@ TEST(MemorySparseTable, SGD) {
std::vector<float> init_values;
init_values.resize(init_keys.size() * (emb_dim + 3));
auto value = PullSparseValue(init_keys, init_fres, emb_dim);
table->pull_sparse(init_values.data(), value);
table->PullSparse(init_values.data(), value);
// for check
std::vector<float> total_gradients;
......@@ -109,8 +109,7 @@ TEST(MemorySparseTable, SGD) {
auto &push_keys = trainer_keys[i];
auto &push_values = trainer_gradient_values[i];
auto task = [table, &push_keys, &push_values] {
table->push_sparse(push_keys.data(), push_values.data(),
push_keys.size());
table->PushSparse(push_keys.data(), push_values.data(), push_keys.size());
};
task_status.push_back(pool_->enqueue(std::move(task)));
}
......@@ -120,7 +119,7 @@ TEST(MemorySparseTable, SGD) {
std::vector<float> pull_values;
pull_values.resize(init_keys.size() * (emb_dim + 3));
table->pull_sparse(pull_values.data(), value);
table->PullSparse(pull_values.data(), value);
for (size_t i = 0; i < init_keys.size(); ++i) {
for (size_t j = 2; j < emb_dim + 3; ++j) {
......@@ -133,7 +132,7 @@ TEST(MemorySparseTable, SGD) {
}
MemorySparseTable *ctr_table = dynamic_cast<MemorySparseTable *>(table);
ctr_table->save_local_fs("./work/table.save", "0", "test");
ctr_table->SaveLocalFS("./work/table.save", "0", "test");
}
} // namespace distributed
......
......@@ -26,7 +26,7 @@ TEST(Table, Initialize) {
FsClientParameter fs_config;
// case 1. no accessor
Table *table = new SparseGeoTable();
auto ret = table->initialize(table_config, fs_config);
auto ret = table->Initialize(table_config, fs_config);
ASSERT_EQ(ret, -1);
}
} // namespace distributed
......
......@@ -343,7 +343,7 @@ void PSGPUWrapper::BuildPull(std::shared_ptr<HeterContext> gpu_task) {
#ifdef PADDLE_WITH_PSCORE
int32_t cnt = 0;
while (true) {
auto tt = fleet_ptr->worker_ptr_->pull_sparse_ptr(
auto tt = fleet_ptr->worker_ptr_->PullSparsePtr(
reinterpret_cast<char**>(local_ptr[i].data()), this->table_id_,
local_keys[i].data(), key_size);
bool flag = true;
......
......@@ -276,7 +276,7 @@ void MultiTrainer::Finalize() {
if (communicator == nullptr) {
VLOG(0) << "MultiTrainer::Finalize communicator is null!";
} else {
communicator->_worker_ptr->flush();
communicator->_worker_ptr->Flush();
VLOG(1) << "MultiTrainer::Finalize ps client flush done";
}
#endif
......
......@@ -86,11 +86,11 @@ void BindDistFleetWrapper(py::module* m) {
void BindPSHost(py::module* m) {
py::class_<distributed::PSHost>(*m, "PSHost")
.def(py::init<const std::string&, uint32_t, uint32_t>())
.def("serialize_to_string", &distributed::PSHost::serialize_to_string)
.def("parse_from_string", &distributed::PSHost::parse_from_string)
.def("to_uint64", &distributed::PSHost::serialize_to_uint64)
.def("from_uint64", &distributed::PSHost::parse_from_uint64)
.def("to_string", &distributed::PSHost::to_string);
.def("serialize_to_string", &distributed::PSHost::SerializeToString)
.def("parse_from_string", &distributed::PSHost::ParseFromString)
.def("to_uint64", &distributed::PSHost::SerializeToUint64)
.def("from_uint64", &distributed::PSHost::ParseFromUint64)
.def("to_string", &distributed::PSHost::ToString);
}
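
Note the design choice visible in this hunk: the C++ methods move to CamelCase while the Python-facing names passed to .def(...) stay snake_case, so existing Python callers are unaffected. A toy pybind11 sketch of that pattern, with a hypothetical Host class standing in for distributed::PSHost:

#include <pybind11/pybind11.h>
#include <string>

namespace py = pybind11;

// Toy stand-in for distributed::PSHost: C++ side uses CamelCase.
struct Host {
  std::string ip;
  std::string ToString() const { return "host:" + ip; }
};

PYBIND11_MODULE(ps_demo, m) {
  py::class_<Host>(m, "Host")
      .def(py::init([](const std::string& ip) { return Host{ip}; }))
      // Python keeps the old snake_case name; only the C++ target changed.
      .def("to_string", &Host::ToString);
}
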
void BindSparseShardingTools(py::module* m) {
......@@ -224,7 +224,7 @@ void BindGraphPyClient(py::module* m) {
&GraphPyClient::use_neighbors_sample_cache)
.def("remove_graph_node", &GraphPyClient::remove_graph_node)
.def("random_sample_nodes", &GraphPyClient::random_sample_nodes)
.def("stop_server", &GraphPyClient::stop_server)
.def("stop_server", &GraphPyClient::StopServer)
.def("get_node_feat",
[](GraphPyClient& self, std::string node_type,
std::vector<int64_t> node_ids,
......