diff --git a/paddle/fluid/distributed/ps/service/brpc_ps_client.cc b/paddle/fluid/distributed/ps/service/brpc_ps_client.cc index 5a92afb297c7e2d5cb24f4603adbd30449a9e769..893e0f9a975968b8052320de5d25d1978833a291 100755 --- a/paddle/fluid/distributed/ps/service/brpc_ps_client.cc +++ b/paddle/fluid/distributed/ps/service/brpc_ps_client.cc @@ -80,7 +80,7 @@ void DownpourPsClientService::service( const PsRequestMessage *request, PsResponseMessage *response, ::google::protobuf::Closure *done) { brpc::ClosureGuard done_guard(done); - int ret = _client->handle_client2client_msg( + int ret = _client->HandleClient2ClientMsg( request->cmd_id(), request->client_id(), request->data()); response->set_err_code(0); response->set_err_msg(""); @@ -91,8 +91,8 @@ void DownpourPsClientService::service( } // 启动client端RpcService 用于数据互发等操作 -int32_t BrpcPsClient::start_client_service() { - if (_service.configure(this, _client_id) != 0) { +int32_t BrpcPsClient::StartClientService() { + if (_service.Configure(this, _client_id) != 0) { LOG(ERROR) << "service initialize failed, service_name:DownpourPsClientService"; return -1; @@ -108,12 +108,12 @@ int32_t BrpcPsClient::start_client_service() { return -1; } _server_started = true; - _env->registe_ps_client(butil::my_ip_cstr(), _server.listen_address().port, - _client_id); + _env->RegistePsClient(butil::my_ip_cstr(), _server.listen_address().port, + _client_id); return 0; } -int32_t BrpcPsClient::create_client2client_connection( +int32_t BrpcPsClient::CreateClient2ClientConnection( int pserver_timeout_ms, int pserver_connect_timeout_ms, int max_retry) { brpc::ChannelOptions options; options.protocol = "baidu_std"; @@ -122,12 +122,12 @@ int32_t BrpcPsClient::create_client2client_connection( options.connect_timeout_ms = pserver_connect_timeout_ms; options.max_retry = max_retry; - std::vector client_list = _env->get_ps_clients(); + std::vector client_list = _env->GetPsClients(); VLOG(1) << "BrpcPsClient::create_c2c_connection client_list size: " << client_list.size(); for (auto cc : client_list) { VLOG(1) << "BrpcPsClient::create_c2c_connection client_list: " - << cc.to_string(); + << cc.ToString(); } _client_channels.resize(client_list.size()); std::ostringstream os; @@ -154,7 +154,7 @@ int32_t BrpcPsClient::create_client2client_connection( return 0; } -int32_t BrpcPsClient::initialize() { +int32_t BrpcPsClient::Initialize() { _async_call_num = 0; brpc::ChannelOptions options; @@ -169,7 +169,7 @@ int32_t BrpcPsClient::initialize() { std::string client_ip(butil::my_ip_cstr()); // 获取server列表,并连接 - std::vector server_list = _env->get_ps_servers(); + std::vector server_list = _env->GetPsServers(); _server_channels.resize(server_list.size()); for (size_t i = 0; i < server_list.size(); ++i) { server_ip_port.assign(server_list[i].ip.c_str()); @@ -194,7 +194,7 @@ int32_t BrpcPsClient::initialize() { os << server_ip_port << ","; } // 启动client探听接口, 并相互建立连接 - start_client_service(); + StartClientService(); // 异步push 请求队列初始化 const auto &worker_param = _config.worker_param().downpour_worker_param(); @@ -234,13 +234,13 @@ int32_t BrpcPsClient::initialize() { _flushing = false; // 启动异步push线程 _async_push_sparse_thread = - std::thread(std::bind(&BrpcPsClient::push_sparse_task_consume, this)); + std::thread(std::bind(&BrpcPsClient::PushSparseTaskConsume, this)); // _async_push_sparse_thread.detach(); _async_push_dense_thread = - std::thread(std::bind(&BrpcPsClient::push_dense_task_consume, this)); + std::thread(std::bind(&BrpcPsClient::PushDenseTaskConsume, this)); // for debug // _print_thread = - // std::thread(std::bind(&BrpcPsClient::print_queue_size_thread, this)); + // std::thread(std::bind(&BrpcPsClient::PrintQueueSizeThread, this)); return 0; } @@ -286,7 +286,7 @@ std::string DownpourBrpcClosure::get_response(size_t request_idx, int cmd_id) { return data; } -std::future BrpcPsClient::print_table_stat(uint32_t table_id) { +std::future BrpcPsClient::PrintTableStat(uint32_t table_id) { size_t request_call_num = _server_channels.size(); DownpourBrpcClosure *closure = new DownpourBrpcClosure( request_call_num, [request_call_num, table_id](void *done) { @@ -319,7 +319,7 @@ std::future BrpcPsClient::print_table_stat(uint32_t table_id) { closure->request(i)->set_cmd_id(PS_PRINT_TABLE_STAT); closure->request(i)->set_table_id(table_id); closure->request(i)->set_client_id(_client_id); - PsService_Stub rpc_stub(get_cmd_channel(i)); + PsService_Stub rpc_stub(GetCmdChannel(i)); closure->cntl(i)->set_timeout_ms( 10800000); // cmd msg don't limit timeout for save/load rpc_stub.service(closure->cntl(i), closure->request(i), @@ -327,7 +327,7 @@ std::future BrpcPsClient::print_table_stat(uint32_t table_id) { } return fut; } -std::future BrpcPsClient::send_cmd( +std::future BrpcPsClient::SendCmd( uint32_t table_id, int cmd_id, const std::vector ¶ms) { size_t request_call_num = _server_channels.size(); DownpourBrpcClosure *closure = new DownpourBrpcClosure( @@ -352,7 +352,7 @@ std::future BrpcPsClient::send_cmd( for (const auto ¶m : params) { closure->request(i)->add_params(param); } - PsService_Stub rpc_stub(get_cmd_channel(i)); + PsService_Stub rpc_stub(GetCmdChannel(i)); closure->cntl(i)->set_timeout_ms( 10800000 * 2); // cmd msg don't limit timeout for save/load rpc_stub.service(closure->cntl(i), closure->request(i), @@ -361,7 +361,7 @@ std::future BrpcPsClient::send_cmd( return fut; } -std::future BrpcPsClient::send_save_cmd( +std::future BrpcPsClient::SendSaveCmd( uint32_t table_id, int cmd_id, const std::vector ¶ms) { size_t request_call_num = _server_channels.size(); DownpourBrpcClosure *closure = new DownpourBrpcClosure( @@ -392,7 +392,7 @@ std::future BrpcPsClient::send_save_cmd( for (const auto ¶m : params) { closure->request(i)->add_params(param); } - PsService_Stub rpc_stub(get_cmd_channel(i)); + PsService_Stub rpc_stub(GetCmdChannel(i)); closure->cntl(i)->set_timeout_ms( 10800000); // cmd msg don't limit timeout for save/load rpc_stub.service(closure->cntl(i), closure->request(i), @@ -401,65 +401,42 @@ std::future BrpcPsClient::send_save_cmd( return fut; } -std::future BrpcPsClient::shrink(uint32_t table_id, +std::future BrpcPsClient::Shrink(uint32_t table_id, const std::string threshold) { - return send_cmd(table_id, PS_SHRINK_TABLE, {threshold}); + return SendCmd(table_id, PS_SHRINK_TABLE, {threshold}); } -std::future BrpcPsClient::load(const std::string &epoch, +std::future BrpcPsClient::Load(const std::string &epoch, const std::string &mode) { - return send_cmd(-1, PS_LOAD_ALL_TABLE, {epoch, mode}); + return SendCmd(-1, PS_LOAD_ALL_TABLE, {epoch, mode}); } -std::future BrpcPsClient::load(uint32_t table_id, +std::future BrpcPsClient::Load(uint32_t table_id, const std::string &epoch, const std::string &mode) { - return send_cmd(table_id, PS_LOAD_ONE_TABLE, {epoch, mode}); + return SendCmd(table_id, PS_LOAD_ONE_TABLE, {epoch, mode}); } -std::future BrpcPsClient::Load(const LoadSaveContext &load_context) { - if (load_context.table_id < 0) { - return send_cmd(-1, PS_LOAD_ALL_TABLE, - {load_context.epoch, load_context.mode}); - } else { - return send_cmd(load_context.table_id, PS_LOAD_ONE_TABLE, - {load_context.epoch, load_context.mode}); - } -} - -std::future BrpcPsClient::save(const std::string &epoch, +std::future BrpcPsClient::Save(const std::string &epoch, const std::string &mode) { VLOG(1) << "BrpcPsClient::save path " << epoch; - return send_save_cmd(-1, PS_SAVE_ALL_TABLE, {epoch, mode}); + return SendSaveCmd(-1, PS_SAVE_ALL_TABLE, {epoch, mode}); } -std::future BrpcPsClient::save(uint32_t table_id, +std::future BrpcPsClient::Save(uint32_t table_id, const std::string &epoch, const std::string &mode) { VLOG(1) << "BrpcPsClient::save one table path " << epoch << " table_id " << table_id; - return send_save_cmd(table_id, PS_SAVE_ONE_TABLE, {epoch, mode}); + return SendSaveCmd(table_id, PS_SAVE_ONE_TABLE, {epoch, mode}); } -std::future BrpcPsClient::Save(const LoadSaveContext &save_context) { - if (save_context.table_id < 0) { - VLOG(1) << "BrpcPsClient::save path " << save_context.epoch; - return send_save_cmd(-1, PS_SAVE_ALL_TABLE, - {save_context.epoch, save_context.mode}); - } else { - VLOG(1) << "BrpcPsClient::save one table path " << save_context.epoch - << " table_id " << save_context.table_id; - return send_save_cmd(save_context.table_id, PS_SAVE_ONE_TABLE, - {save_context.epoch, save_context.mode}); - } -} - -std::future BrpcPsClient::clear() { - return send_cmd(-1, PS_CLEAR_ALL_TABLE, {}); +std::future BrpcPsClient::Clear() { + return SendCmd(-1, PS_CLEAR_ALL_TABLE, {}); } -std::future BrpcPsClient::clear(uint32_t table_id) { - return send_cmd(table_id, PS_CLEAR_ONE_TABLE, {}); +std::future BrpcPsClient::Clear(uint32_t table_id) { + return SendCmd(table_id, PS_CLEAR_ONE_TABLE, {}); } -std::future BrpcPsClient::flush() { +std::future BrpcPsClient::Flush() { VLOG(0) << "BrpcPsClient::flush begin"; _flushing = true; std::promise promise; @@ -472,106 +449,69 @@ std::future BrpcPsClient::flush() { promise.set_value(0); _flushing = false; VLOG(0) << "BrpcPsClient::flush done"; - print_queue_size(); + PrintQueueSize(); return fut; } -void BrpcPsClient::print_queue_size() { +void BrpcPsClient::PrintQueueSize() { for (auto &push_sparse_task_itr : _push_sparse_task_queue_map) { auto table_id = push_sparse_task_itr.first; auto queue_size = push_sparse_task_itr.second->Size(); - VLOG(0) << "BrpcPsClient::print_queue_size: table " << table_id + VLOG(0) << "BrpcPsClient::PrintQueueSize: table " << table_id << " size: " << queue_size; } for (auto &task_queue_itr : _push_dense_task_queue_map) { auto table_id = task_queue_itr.first; auto queue_size = task_queue_itr.second->Size(); - VLOG(0) << "BrpcPsClient::print_queue_size: table " << table_id + VLOG(0) << "BrpcPsClient::PrintQueueSize: table " << table_id << " size: " << queue_size; } } -void BrpcPsClient::print_queue_size_thread() { +void BrpcPsClient::PrintQueueSizeThread() { while (_running) { usleep(1000000 * 60 * 2); - print_queue_size(); + PrintQueueSize(); } } -void BrpcPsClient::finalize_worker() { - flush(); - VLOG(0) << "BrpcPsClient::finalize_worker begin join thread"; +void BrpcPsClient::FinalizeWorker() { + Flush(); + VLOG(0) << "BrpcPsClient::FinalizeWorker begin join thread"; _running = false; _async_push_dense_thread.join(); _async_push_sparse_thread.join(); // _print_thread.join(); - VLOG(0) << "BrpcPsClient::finalize_worker begin join server"; + VLOG(0) << "BrpcPsClient::FinalizeWorker begin join server"; _server.Stop(1000); _server.Join(); _server_started = false; - VLOG(0) << "BrpcPsClient::finalize_worker done"; + VLOG(0) << "BrpcPsClient::FinalizeWorker done"; } -std::future BrpcPsClient::stop_server() { - return send_cmd(-1, PS_STOP_SERVER, {}); +std::future BrpcPsClient::StopServer() { + return SendCmd(-1, PS_STOP_SERVER, {}); } -std::future BrpcPsClient::start_profiler() { - return send_cmd(-1, PS_START_PROFILER, {}); +std::future BrpcPsClient::StartProfiler() { + return SendCmd(-1, PS_START_PROFILER, {}); } -std::future BrpcPsClient::stop_profiler() { - return send_cmd(-1, PS_STOP_PROFILER, {}); +std::future BrpcPsClient::StopProfiler() { + return SendCmd(-1, PS_STOP_PROFILER, {}); } -std::future BrpcPsClient::barrier(size_t table_id, +std::future BrpcPsClient::Barrier(size_t table_id, uint32_t barrier_type) { - return send_cmd(table_id, PS_BARRIER, {std::to_string(barrier_type)}); -} - -std::future BrpcPsClient::Pull(RequestContext &pull_context) { - if (pull_context.value_type == Dense) { // pull dense - Region *dense_region = - reinterpret_cast(pull_context.dense_values); - return pull_dense(dense_region, pull_context.num, pull_context.table); - } else { // pull sparse - size_t table_id = pull_context.table; - size_t num = pull_context.num; - bool is_training = pull_context.is_training; - if (pull_context.training_mode == Geo) { // for geo - return pull_sparse_param(pull_context.sparse_values, table_id, - pull_context.keys, num, is_training); - } else if (pull_context.training_mode == Async) { // for async - return pull_sparse(pull_context.sparse_values, table_id, - pull_context.keys, num, is_training); - } - } + return SendCmd(table_id, PS_BARRIER, {std::to_string(barrier_type)}); } -std::future BrpcPsClient::Push(RequestContext &push_context) { - if (push_context.value_type == Dense) { // push dense - const Region *dense_region = push_context.push_context.push_dense_values; - return push_dense(dense_region, push_context.num, push_context.table); - } else { // push sparse - size_t table_id = push_context.table; - size_t num = push_context.num; - bool is_training = push_context.is_training; - if (push_context.training_mode == Geo) { // for geo - // TODO(zhaocaibei) - } else if (push_context.training_mode == Async) { // for async - const uint64_t *keys = push_context.push_context.keys; - const float **update_values = push_context.push_context.push_values; - return push_sparse(table_id, keys, update_values, num); - } - } -} - -std::future BrpcPsClient::pull_geo_param(size_t table_id, - std::vector *values, - std::vector *keys, - int pserver_idx) { - auto *accessor = table_accessor(table_id); +std::future BrpcPsClient::PullGeoParam(size_t table_id, + std::vector *values, + std::vector *keys, + int pserver_idx) { + auto *accessor = GetTableAccessor(table_id); DownpourBrpcClosure *closure = new DownpourBrpcClosure(1, [keys, values, accessor](void *done) { int ret = 0; @@ -600,7 +540,7 @@ std::future BrpcPsClient::pull_geo_param(size_t table_id, closure->request(0)->set_cmd_id(PS_PULL_GEO_PARAM); closure->request(0)->set_table_id(table_id); closure->request(0)->set_client_id(_client_id); - PsService_Stub rpc_stub(get_cmd_channel(pserver_idx)); + PsService_Stub rpc_stub(GetCmdChannel(pserver_idx)); closure->cntl(0)->set_log_id(butil::gettimeofday_ms()); rpc_stub.service(closure->cntl(0), closure->request(0), closure->response(0), closure); @@ -608,10 +548,11 @@ std::future BrpcPsClient::pull_geo_param(size_t table_id, } // for GEO -std::future BrpcPsClient::push_sparse_param( - size_t table_id, const uint64_t *keys, const float **update_values, - size_t num, void *done) { - auto *accessor = table_accessor(table_id); +std::future BrpcPsClient::PushSparseParam(size_t table_id, + const uint64_t *keys, + const float **update_values, + size_t num, void *done) { + auto *accessor = GetTableAccessor(table_id); // 发送RPC请求 DownpourBrpcClosure *closure = reinterpret_cast(done); auto promise = std::make_shared>(); @@ -649,7 +590,7 @@ std::future BrpcPsClient::push_sparse_param( memcpy(push_data_ptr, value_ptr[i], accessor->GetTableInfo(UPDATE_SIZE)); push_data_ptr += accessor->GetTableInfo(UPDATE_SIZE); } - PsService_Stub rpc_stub(get_sparse_channel(shard_idx)); + PsService_Stub rpc_stub(GetSparseChannel(shard_idx)); closure->cntl(shard_idx)->set_request_compress_type( (brpc::CompressType)FLAGS_pserver_communicate_compress_type); rpc_stub.service(closure->cntl(shard_idx), closure->request(shard_idx), @@ -658,16 +599,15 @@ std::future BrpcPsClient::push_sparse_param( return fut; } -std::future BrpcPsClient::pull_dense(Region *regions, - size_t region_num, - size_t table_id) { +std::future BrpcPsClient::PullDense(Region *regions, size_t region_num, + size_t table_id) { auto timer = std::make_shared("pserver_client_pull_dense"); - auto *accessor = table_accessor(table_id); + auto *accessor = GetTableAccessor(table_id); auto fea_dim = accessor->GetTableInfo(FEA_DIM); auto select_size = accessor->GetTableInfo(SELECT_SIZE); size_t request_call_num = _server_channels.size(); uint32_t num_per_shard = - dense_dim_per_shard(accessor->GetTableInfo(FEA_DIM), request_call_num); + DenseDimPerShard(accessor->GetTableInfo(FEA_DIM), request_call_num); // callback 将各shard结果,顺序填入region DownpourBrpcClosure *closure = new DownpourBrpcClosure( request_call_num, [request_call_num, num_per_shard, regions, region_num, @@ -730,22 +670,22 @@ std::future BrpcPsClient::pull_dense(Region *regions, closure->request(i)->set_client_id(_client_id); closure->request(i)->add_params((char *)&num_per_shard, // NOLINT sizeof(num_per_shard)); - PsService_Stub rpc_stub(get_dense_channel(i)); + PsService_Stub rpc_stub(GetDenseChannel(i)); rpc_stub.service(closure->cntl(i), closure->request(i), closure->response(i), closure); } return fut; } -std::future BrpcPsClient::push_dense_param(const Region *regions, - size_t region_num, - size_t table_id) { - auto *accessor = table_accessor(table_id); +std::future BrpcPsClient::PushDenseParam(const Region *regions, + size_t region_num, + size_t table_id) { + auto *accessor = GetTableAccessor(table_id); size_t request_call_num = _server_channels.size(); // 1.拆分Region数据到shard中,后续多shard并行拷贝数据 std::vector> regions_partition(request_call_num); uint32_t num_per_shard = - dense_dim_per_shard(accessor->GetTableInfo(FEA_DIM), request_call_num); + DenseDimPerShard(accessor->GetTableInfo(FEA_DIM), request_call_num); size_t shard_data_size = num_per_shard * accessor->GetTableInfo(UPDATE_SIZE); size_t current_region_idx = 0; size_t current_region_data_idx = 0; @@ -809,17 +749,17 @@ std::future BrpcPsClient::push_dense_param(const Region *regions, fill_num); fill_remain_size -= fill_num; } - PsService_Stub rpc_stub(get_dense_channel(i)); + PsService_Stub rpc_stub(GetDenseChannel(i)); rpc_stub.service(closure->cntl(i), closure->request(i), closure->response(i), closure); } return fut; } -std::future BrpcPsClient::push_sparse_raw_gradient( +std::future BrpcPsClient::PushSparseRawGradient( size_t table_id, const uint64_t *keys, const float **update_values, size_t num, void *done) { - auto *accessor = table_accessor(table_id); + auto *accessor = GetTableAccessor(table_id); // 发送RPC请求 DownpourBrpcClosure *closure = reinterpret_cast(done); auto promise = std::make_shared>(); @@ -872,7 +812,7 @@ std::future BrpcPsClient::push_sparse_raw_gradient( memcpy(push_data_ptr, value_ptr[i], accessor->GetTableInfo(UPDATE_SIZE)); push_data_ptr += accessor->GetTableInfo(UPDATE_SIZE); } - PsService_Stub rpc_stub(get_sparse_channel(shard_idx)); + PsService_Stub rpc_stub(GetSparseChannel(shard_idx)); closure->cntl(shard_idx)->set_request_compress_type( (brpc::CompressType)FLAGS_pserver_communicate_compress_type); rpc_stub.service(closure->cntl(shard_idx), closure->request(shard_idx), @@ -881,7 +821,7 @@ std::future BrpcPsClient::push_sparse_raw_gradient( return fut; } -std::future BrpcPsClient::push_dense_raw_gradient( +std::future BrpcPsClient::PushDenseRawGradient( int table_id, float *total_send_data, size_t total_send_data_size, void *done) { size_t request_call_num = _server_channels.size(); @@ -889,9 +829,9 @@ std::future BrpcPsClient::push_dense_raw_gradient( auto promise = std::make_shared>(); closure->add_promise(promise); std::future fut = promise->get_future(); - auto *accessor = table_accessor(table_id); + auto *accessor = GetTableAccessor(table_id); uint32_t num_per_shard = - dense_dim_per_shard(accessor->GetTableInfo(FEA_DIM), request_call_num); + DenseDimPerShard(accessor->GetTableInfo(FEA_DIM), request_call_num); for (size_t i = 0; i < request_call_num; ++i) { closure->request(i)->set_cmd_id(PS_PUSH_DENSE_TABLE); closure->request(i)->set_table_id(table_id); @@ -905,16 +845,16 @@ std::future BrpcPsClient::push_dense_raw_gradient( total_send_data + i * num_per_shard, num_per_shard * sizeof(float)); // closure->cntl(i)->set_request_compress_type( // (brpc::CompressType)FLAGS_pserver_communicate_compress_type); - PsService_Stub rpc_stub(get_dense_channel(i)); + PsService_Stub rpc_stub(GetDenseChannel(i)); rpc_stub.service(closure->cntl(i), closure->request(i), closure->response(i), closure); } return fut; } -std::future BrpcPsClient::push_global_step(int table_id, - int64_t *total_send_data, - void *done) { +std::future BrpcPsClient::PushGlobalStep(int table_id, + int64_t *total_send_data, + void *done) { size_t request_call_num = _server_channels.size(); DownpourBrpcClosure *closure = reinterpret_cast(done); auto promise = std::make_shared>(); @@ -933,17 +873,17 @@ std::future BrpcPsClient::push_global_step(int table_id, memcpy(push_data_ptr + sizeof(uint32_t), total_send_data, num_per_shard * sizeof(int64_t)); - PsService_Stub rpc_stub(get_dense_channel(i)); + PsService_Stub rpc_stub(GetDenseChannel(i)); rpc_stub.service(closure->cntl(i), closure->request(i), closure->response(i), closure); } return fut; } -std::future BrpcPsClient::pull_sparse(float **select_values, - size_t table_id, - const uint64_t *keys, size_t num, - bool is_training) { +std::future BrpcPsClient::PullSparse(float **select_values, + size_t table_id, + const uint64_t *keys, size_t num, + bool is_training) { auto timer = std::make_shared("pserver_client_pull_sparse"); auto local_timer = std::make_shared("pserver_client_pull_sparse_local"); @@ -968,7 +908,7 @@ std::future BrpcPsClient::pull_sparse(float **select_values, shard_sorted_kvs->at(shard_id).push_back({keys[i], select_values[i]}); } - auto *accessor = table_accessor(table_id); + auto *accessor = GetTableAccessor(table_id); size_t value_size = accessor->GetTableInfo(SELECT_SIZE); @@ -1055,7 +995,7 @@ std::future BrpcPsClient::pull_sparse(float **select_values, closure->request(i)->set_client_id(_client_id); closure->request(i)->add_params((char *)&kv_request_count, // NOLINT sizeof(uint32_t)); - PsService_Stub rpc_stub(get_cmd_channel(i)); + PsService_Stub rpc_stub(GetCmdChannel(i)); closure->cntl(i)->set_log_id(butil::gettimeofday_ms()); rpc_stub.service(closure->cntl(i), closure->request(i), closure->response(i), closure); @@ -1065,11 +1005,11 @@ std::future BrpcPsClient::pull_sparse(float **select_values, } // for GEO -std::future BrpcPsClient::pull_sparse_param(float **select_values, - size_t table_id, - const uint64_t *keys, - size_t num, - bool is_training) { +std::future BrpcPsClient::PullSparseParam(float **select_values, + size_t table_id, + const uint64_t *keys, + size_t num, + bool is_training) { auto timer = std::make_shared("pserver_client_pull_sparse_param"); size_t request_call_num = _server_channels.size(); @@ -1082,7 +1022,7 @@ std::future BrpcPsClient::pull_sparse_param(float **select_values, shard_sorted_kvs->at(shard_id).push_back({keys[i], select_values[i]}); } - auto *accessor = table_accessor(table_id); + auto *accessor = GetTableAccessor(table_id); size_t value_size = accessor->GetTableInfo(SELECT_SIZE); DownpourBrpcClosure *closure = new DownpourBrpcClosure( @@ -1169,7 +1109,7 @@ std::future BrpcPsClient::pull_sparse_param(float **select_values, closure->request(i)->set_client_id(_client_id); closure->request(i)->add_params((char *)&kv_request_count, // NOLINT sizeof(uint32_t)); - PsService_Stub rpc_stub(get_cmd_channel(i)); + PsService_Stub rpc_stub(GetCmdChannel(i)); closure->cntl(i)->set_log_id(butil::gettimeofday_ms()); rpc_stub.service(closure->cntl(i), closure->request(i), closure->response(i), closure); @@ -1178,7 +1118,7 @@ std::future BrpcPsClient::pull_sparse_param(float **select_values, return fut; } -std::future BrpcPsClient::send_client2client_msg( +std::future BrpcPsClient::SendClient2ClientMsg( int msg_type, int to_client_id, const std::string &msg) { auto promise = std::make_shared>(); std::future fut = promise->get_future(); @@ -1203,10 +1143,10 @@ std::future BrpcPsClient::send_client2client_msg( return fut; } -std::future BrpcPsClient::push_sparse_raw_gradient_partial( +std::future BrpcPsClient::PushSparseRawGradientPartial( size_t table_id, const uint64_t *keys, const float **update_values, uint32_t num, void *done, int pserver_idx) { - auto *accessor = table_accessor(table_id); + auto *accessor = GetTableAccessor(table_id); size_t value_size = accessor->GetTableInfo(UPDATE_SIZE); DownpourBrpcClosure *closure = reinterpret_cast(done); auto promise = std::make_shared>(); @@ -1228,7 +1168,7 @@ std::future BrpcPsClient::push_sparse_raw_gradient_partial( memcpy(push_data_ptr, update_values[i], value_size); push_data_ptr += value_size; } - PsService_Stub rpc_stub(get_sparse_channel(pserver_idx)); + PsService_Stub rpc_stub(GetSparseChannel(pserver_idx)); closure->cntl(0)->set_request_compress_type( (brpc::CompressType)FLAGS_pserver_communicate_compress_type); rpc_stub.service(closure->cntl(0), closure->request(0), closure->response(0), @@ -1236,8 +1176,8 @@ std::future BrpcPsClient::push_sparse_raw_gradient_partial( return fut; } -int32_t BrpcPsClient::recv_and_save_table(const uint64_t table_id, - const std::string &path) { +int32_t BrpcPsClient::RecvAndSaveTable(const uint64_t table_id, + const std::string &path) { // get var information std::string var_name = ""; int64_t var_num = 0; @@ -1271,17 +1211,17 @@ int32_t BrpcPsClient::recv_and_save_table(const uint64_t table_id, save_vec.push_back(save_huge_vec.data() + i * var_shape); } - VLOG(2) << "recv_and_save_table: table_class: " << table_class; + VLOG(2) << "RecvAndSaveTable: table_class: " << table_class; // TODO(zhaocaibei123): new GeoBrpcPSClient, move this to its - // recv_and_save_table + // RecvAndSaveTable if (table_class == "MemorySparseGeoTable") { auto status = - pull_sparse_param(reinterpret_cast(save_vec.data()), table_id, - save_key.data(), save_key.size(), true); + PullSparseParam(reinterpret_cast(save_vec.data()), table_id, + save_key.data(), save_key.size(), true); status.wait(); } else { - auto status = pull_sparse(reinterpret_cast(save_vec.data()), - table_id, save_key.data(), save_key.size(), true); + auto status = PullSparse(reinterpret_cast(save_vec.data()), + table_id, save_key.data(), save_key.size(), true); status.wait(); } @@ -1315,15 +1255,15 @@ int32_t BrpcPsClient::recv_and_save_table(const uint64_t table_id, return 0; } -std::future BrpcPsClient::push_sparse(size_t table_id, - const uint64_t *keys, - const float **update_values, - size_t num) { +std::future BrpcPsClient::PushSparse(size_t table_id, + const uint64_t *keys, + const float **update_values, + size_t num) { auto push_timer = std::make_shared("pserver_client_push_sparse"); CostTimer parse_timer("pserver_client_push_sparse_parse"); int push_sparse_async_num = _push_sparse_task_queue_map[table_id]->Size(); while (push_sparse_async_num > FLAGS_pserver_max_async_call_num) { - // LOG(INFO) << "push_sparse Waiting for async_call_num comsume, + // LOG(INFO) << "PushSparse Waiting for async_call_num comsume, // task_num:" // << push_sparse_async_num // << ", max_task_limit:" << FLAGS_pserver_max_async_call_num; @@ -1333,7 +1273,7 @@ std::future BrpcPsClient::push_sparse(size_t table_id, auto put_timer = std::make_shared("client_push_sparse_put"); thread_local std::vector>> shard_sorted_kv_list; - auto *accessor = table_accessor(table_id); + auto *accessor = GetTableAccessor(table_id); size_t request_call_num = _server_channels.size(); shard_sorted_kv_list.resize(request_call_num); for (auto &x : shard_sorted_kv_list) { @@ -1381,7 +1321,7 @@ std::future BrpcPsClient::push_sparse(size_t table_id, return fut; } -void BrpcPsClient::push_sparse_task_consume() { +void BrpcPsClient::PushSparseTaskConsume() { uint64_t merge_size = FLAGS_pserver_push_sparse_merge_limit; std::vector> task_list; size_t request_call_num = _server_channels.size(); @@ -1392,7 +1332,7 @@ void BrpcPsClient::push_sparse_task_consume() { // 所有sparseTable的pushTask 进行处理 for (auto &push_sparse_task_itr : _push_sparse_task_queue_map) { auto table_id = push_sparse_task_itr.first; - auto *accessor = table_accessor(table_id); + auto *accessor = GetTableAccessor(table_id); auto &task_queue = push_sparse_task_itr.second; auto queue_size = task_queue->Size(); if (queue_size == 0) { @@ -1471,7 +1411,7 @@ void BrpcPsClient::push_sparse_task_consume() { for (int shard_idx = 0; shard_idx < request_call_num; ++shard_idx) { merge_status[shard_idx] = async_push_sparse_shard_threads.enqueue(std::bind( - &BrpcPsClient::push_sparse_async_shard_push, this, task_list, + &BrpcPsClient::PushSparseAsyncShardPush, this, task_list, request_kv_num, table_id, shard_idx, closure, accessor)); } for (int shard_idx = 0; shard_idx < request_call_num; ++shard_idx) { @@ -1487,7 +1427,7 @@ void BrpcPsClient::push_sparse_task_consume() { for (int shard_idx = 0; shard_idx < request_call_num; ++shard_idx) { merge_status[shard_idx] = async_push_sparse_shard_threads.enqueue(std::bind( - &BrpcPsClient::push_sparse_async_shard_merge, this, task_list, + &BrpcPsClient::PushSparseAsyncShardMerge, this, task_list, request_kv_num, table_id, shard_idx, accessor)); } for (int shard_idx = 0; shard_idx < request_call_num; ++shard_idx) { @@ -1523,7 +1463,7 @@ void sparse_local_merge(ValueAccessor *accessor, float *merge_data, accessor->Merge(merge_data_shell, another_data_shell, 1); } -int BrpcPsClient::push_sparse_async_shard_merge( +int BrpcPsClient::PushSparseAsyncShardMerge( std::vector> &task_list, std::vector &request_kv_num, int table_id, int shard_idx, ValueAccessor *accessor) { @@ -1615,12 +1555,12 @@ int BrpcPsClient::push_sparse_async_shard_merge( return 0; } -int BrpcPsClient::push_sparse_async_shard_push( +int BrpcPsClient::PushSparseAsyncShardPush( std::vector> &task_list, std::vector &request_kv_num, int table_id, int shard_idx, DownpourBrpcClosure *closure, ValueAccessor *accessor) { - push_sparse_async_shard_merge(task_list, request_kv_num, table_id, shard_idx, - accessor); + PushSparseAsyncShardMerge(task_list, request_kv_num, table_id, shard_idx, + accessor); size_t merged_kv_count = task_list[0]->data()->shared_data[shard_idx].kv_num; auto &merged_key_list = task_list[0]->data()->shared_data[shard_idx].key_list; @@ -1649,7 +1589,7 @@ int BrpcPsClient::push_sparse_async_shard_push( accessor->GetTableInfo(UPDATE_SIZE)); push_data_ptr += accessor->GetTableInfo(UPDATE_SIZE); } - PsService_Stub rpc_stub(get_sparse_channel(shard_idx)); + PsService_Stub rpc_stub(GetSparseChannel(shard_idx)); closure->cntl(shard_idx)->set_request_compress_type( (brpc::CompressType)FLAGS_pserver_communicate_compress_type); rpc_stub.service(closure->cntl(shard_idx), closure->request(shard_idx), @@ -1658,10 +1598,10 @@ int BrpcPsClient::push_sparse_async_shard_push( return 0; } -std::future BrpcPsClient::push_dense(const Region *regions, - size_t region_num, - size_t table_id) { - auto *accessor = table_accessor(table_id); +std::future BrpcPsClient::PushDense(const Region *regions, + size_t region_num, + size_t table_id) { + auto *accessor = GetTableAccessor(table_id); int fea_dim = accessor->GetTableInfo(FEA_DIM); int update_dim = accessor->GetTableInfo(UPDATE_DIM); auto push_timer = std::make_shared("pserver_client_push_dense"); @@ -1669,7 +1609,7 @@ std::future BrpcPsClient::push_dense(const Region *regions, std::make_shared("pserver_client_push_dense_parse"); int push_dense_async_num = _push_dense_task_queue_map[table_id]->Size(); while (push_dense_async_num > FLAGS_pserver_max_async_call_num) { - // LOG(INFO) << "push_dense Waiting for async_call_num comsume, + // LOG(INFO) << "PushDense Waiting for async_call_num comsume, // task_num:" // << push_dense_async_num // << ", max_task_limit:" << FLAGS_pserver_max_async_call_num; @@ -1683,7 +1623,7 @@ std::future BrpcPsClient::push_dense(const Region *regions, size_t request_call_num = _server_channels.size(); uint32_t num_per_shard = - dense_dim_per_shard(accessor->GetTableInfo(FEA_DIM), request_call_num); + DenseDimPerShard(accessor->GetTableInfo(FEA_DIM), request_call_num); // 将region数据拷贝到转置矩阵中 async_task->data()->resize(num_per_shard * request_call_num * @@ -1705,7 +1645,7 @@ std::future BrpcPsClient::push_dense(const Region *regions, return fut; } -void BrpcPsClient::push_dense_task_consume() { +void BrpcPsClient::PushDenseTaskConsume() { uint64_t merge_size = FLAGS_pserver_push_dense_merge_limit; static bool scale_gradient = FLAGS_pserver_scale_gradient_by_merge; ::ThreadPool async_merge_dense_threads(10); @@ -1723,7 +1663,7 @@ void BrpcPsClient::push_dense_task_consume() { ++_async_call_num; DenseAsyncTask *task; task_queue->Get(task); - auto *accessor = table_accessor(task->table_id()); + auto *accessor = GetTableAccessor(task->table_id()); // 设置请求回调 size_t request_call_num = _server_channels.size(); @@ -1774,7 +1714,7 @@ void BrpcPsClient::push_dense_task_consume() { merge_status[i].wait(); } - VLOG(3) << "BrpcPsClient::push_dense_task_consume before merge " + VLOG(3) << "BrpcPsClient::PushDenseTaskConsume before merge " "total_send_data[0]" << total_send_data[0] << " total_send_data[-2]" << total_send_data[total_send_data_size - 2] @@ -1787,7 +1727,7 @@ void BrpcPsClient::push_dense_task_consume() { mat *= (1.0 / (merge_count + 1)); } - VLOG(3) << "BrpcPsClient::push_dense_task_consume after merge " + VLOG(3) << "BrpcPsClient::PushDenseTaskConsume after merge " "total_send_data[0]" << total_send_data[0] << " total_send_data[-2]" << total_send_data[total_send_data_size - 2] @@ -1796,8 +1736,8 @@ void BrpcPsClient::push_dense_task_consume() { << merge_count; } std::shared_ptr task_ptr(task); - push_dense_raw_gradient(task_ptr, total_send_data, total_send_data_size, - closure); + PushDenseRawGradient(task_ptr, total_send_data, total_send_data_size, + closure); } auto wait_ms = FLAGS_pserver_async_push_dense_interval_ms - (butil::gettimeofday_ms() - async_start_time_ms); @@ -1807,16 +1747,17 @@ void BrpcPsClient::push_dense_task_consume() { } } -void BrpcPsClient::push_dense_raw_gradient( - std::shared_ptr &task, float *total_send_data, - size_t total_send_data_size, DownpourBrpcClosure *closure) { - auto *accessor = table_accessor(task->table_id()); +void BrpcPsClient::PushDenseRawGradient(std::shared_ptr &task, + float *total_send_data, + size_t total_send_data_size, + DownpourBrpcClosure *closure) { + auto *accessor = GetTableAccessor(task->table_id()); size_t request_call_num = _server_channels.size(); // 将数据拷贝到请求buffer区 auto timer = std::make_shared("pserver_client_push_dense_rpc"); closure->add_timer(timer); uint32_t num_per_shard = - dense_dim_per_shard(accessor->GetTableInfo(FEA_DIM), request_call_num); + DenseDimPerShard(accessor->GetTableInfo(FEA_DIM), request_call_num); auto send_timer = std::make_shared("pserver_client_push_dense_send"); for (size_t i = 0; i < request_call_num; ++i) { @@ -1832,7 +1773,7 @@ void BrpcPsClient::push_dense_raw_gradient( total_send_data + i * num_per_shard, num_per_shard * sizeof(float)); closure->cntl(i)->set_request_compress_type( (brpc::CompressType)FLAGS_pserver_communicate_compress_type); - PsService_Stub rpc_stub(get_dense_channel(i)); + PsService_Stub rpc_stub(GetDenseChannel(i)); rpc_stub.service(closure->cntl(i), closure->request(i), closure->response(i), closure); } diff --git a/paddle/fluid/distributed/ps/service/brpc_ps_client.h b/paddle/fluid/distributed/ps/service/brpc_ps_client.h index 8b0cb0741b4004fbad444a9919ec540289067f55..f109b473ca1f455140559037f05b10d2f18d8027 100644 --- a/paddle/fluid/distributed/ps/service/brpc_ps_client.h +++ b/paddle/fluid/distributed/ps/service/brpc_ps_client.h @@ -50,7 +50,7 @@ class DownpourPsClientService : public PsService { DownpourPsClientService() {} virtual ~DownpourPsClientService() {} - virtual int32_t configure(PSClient *client, size_t rank_id) { + virtual int32_t Configure(PSClient *client, size_t rank_id) { _client = client; _rank = rank_id; return 0; @@ -139,7 +139,7 @@ class BrpcPsClient : public PSClient { BrpcPsClient() {} virtual ~BrpcPsClient() { if (_running) { - flush(); + Flush(); _running = false; } if (_async_push_dense_thread.joinable()) { @@ -154,109 +154,98 @@ class BrpcPsClient : public PSClient { _server_started = false; } } - virtual int32_t create_client2client_connection( - int pserver_timeout_ms, int pserver_connect_timeout_ms, int max_retry); - std::future shrink(uint32_t table_id, + virtual int32_t CreateClient2ClientConnection(int pserver_timeout_ms, + int pserver_connect_timeout_ms, + int max_retry); + std::future Shrink(uint32_t table_id, const std::string threshold) override; - std::future load(const std::string &epoch, + std::future Load(const std::string &epoch, const std::string &mode) override; - std::future load(uint32_t table_id, const std::string &epoch, + std::future Load(uint32_t table_id, const std::string &epoch, const std::string &mode) override; - std::future Load(const LoadSaveContext &load_context) override; - - std::future save(const std::string &epoch, + std::future Save(const std::string &epoch, const std::string &mode) override; - std::future save(uint32_t table_id, const std::string &epoch, + std::future Save(uint32_t table_id, const std::string &epoch, const std::string &mode) override; - virtual std::future Save( - const LoadSaveContext &save_context) override; - - std::future clear() override; - - std::future clear(uint32_t table_id) override; + std::future Clear() override; - std::future stop_server() override; + std::future Clear(uint32_t table_id) override; - std::future start_profiler() override; - std::future stop_profiler() override; + std::future StopServer() override; - void finalize_worker() override; + std::future StartProfiler() override; + std::future StopProfiler() override; - virtual std::future pull_dense(Region *regions, size_t region_num, - size_t table_id); + void FinalizeWorker() override; - virtual std::future push_dense_param(const Region *regions, - size_t region_num, - size_t table_id); + virtual std::future PullDense(Region *regions, size_t region_num, + size_t table_id); - virtual std::future push_dense(const Region *regions, - size_t region_num, size_t table_id); - void push_dense_task_consume(); - virtual std::future pull_sparse(float **select_values, - size_t table_id, - const uint64_t *keys, size_t num, - bool is_training); - virtual std::future pull_sparse_param(float **select_values, - size_t table_id, - const uint64_t *keys, - size_t num, bool is_training); + virtual std::future PushDenseParam(const Region *regions, + size_t region_num, + size_t table_id); - virtual std::future Pull(RequestContext &pull_context) override; + virtual std::future PushDense(const Region *regions, + size_t region_num, size_t table_id); + void PushDenseTaskConsume(); + virtual std::future PullSparse(float **select_values, + size_t table_id, const uint64_t *keys, + size_t num, bool is_training); + virtual std::future PullSparseParam(float **select_values, + size_t table_id, + const uint64_t *keys, size_t num, + bool is_training); - virtual std::future Push(RequestContext &push_context) override; + virtual std::future PrintTableStat(uint32_t table_id); - virtual std::future print_table_stat(uint32_t table_id); + virtual std::future Barrier(size_t table_id, uint32_t barrier_type); - virtual std::future barrier(size_t table_id, uint32_t barrier_type); + virtual std::future PullGeoParam(size_t table_id, + std::vector *values, + std::vector *keys, + int pserver_idx); + virtual std::future PushGlobalStep(int table_id, + int64_t *total_send_data, + void *done); + virtual std::future Flush(); - virtual std::future pull_geo_param(size_t table_id, - std::vector *values, - std::vector *keys, - int pserver_idx); - virtual std::future push_global_step(int table_id, - int64_t *total_send_data, - void *done); - virtual std::future flush(); - - std::future send_client2client_msg(int msg_type, int to_client_id, - const std::string &msg) override; + std::future SendClient2ClientMsg(int msg_type, int to_client_id, + const std::string &msg) override; // for local save sparse - virtual int32_t recv_and_save_table(const uint64_t table_id, - const std::string &path); + virtual int32_t RecvAndSaveTable(const uint64_t table_id, + const std::string &path); - void print_queue_size(); - void print_queue_size_thread(); + void PrintQueueSize(); + void PrintQueueSizeThread(); protected: - virtual size_t get_server_nums() { return _server_channels.size(); } - inline brpc::Channel *get_sparse_channel(size_t server_id) { + virtual size_t GetServerNums() { return _server_channels.size(); } + inline brpc::Channel *GetSparseChannel(size_t server_id) { return _server_channels[server_id][0].get(); } - inline brpc::Channel *get_dense_channel(size_t server_id) { + inline brpc::Channel *GetDenseChannel(size_t server_id) { return _server_channels[server_id][1].get(); } - inline brpc::Channel *get_cmd_channel(size_t server_id) { + inline brpc::Channel *GetCmdChannel(size_t server_id) { return _server_channels[server_id][2].get(); } - int32_t initialize() override; + int32_t Initialize() override; private: - // virtual int32_t initialize() override; - - inline uint32_t dense_dim_per_shard(uint32_t dense_dim_total, - uint32_t shard_num) { + inline uint32_t DenseDimPerShard(uint32_t dense_dim_total, + uint32_t shard_num) { return dense_dim_total / shard_num + 1; } - std::future send_cmd(uint32_t table_id, int cmd_id, - const std::vector ¶m); + std::future SendCmd(uint32_t table_id, int cmd_id, + const std::vector ¶m); - std::future send_save_cmd(uint32_t table_id, int cmd_id, - const std::vector ¶m); + std::future SendSaveCmd(uint32_t table_id, int cmd_id, + const std::vector ¶m); bool _running = false; bool _flushing = false; @@ -276,12 +265,12 @@ class BrpcPsClient : public PSClient { std::thread _print_thread; - int push_sparse_async_shard_merge( + int PushSparseAsyncShardMerge( std::vector> &task_list, // NOLINT std::vector &request_kv_num, int table_id, int shard_idx, // NOLINT ValueAccessor *accessor); - int push_sparse_async_shard_push( + int PushSparseAsyncShardPush( std::vector> &task_list, // NOLINT std::vector &request_kv_num, int table_id, int shard_idx, // NOLINT DownpourBrpcClosure *closure, ValueAccessor *accessor); @@ -292,36 +281,36 @@ class BrpcPsClient : public PSClient { _client_channels; // client2client std::vector, 3>> _server_channels; // client2server - std::future push_dense_raw_gradient(int table_id, - float *total_send_data, - size_t total_send_data_size, - void *done) override; - - std::future push_sparse_raw_gradient(size_t table_id, - const uint64_t *keys, - const float **update_values, - size_t num, - void *done) override; - - std::future push_sparse_raw_gradient_partial( - size_t table_id, const uint64_t *keys, const float **update_values, - uint32_t num, void *done, int pserver_idx) override; - - std::future push_sparse_param(size_t table_id, const uint64_t *keys, - const float **update_values, - size_t num, void *done) override; - std::future push_sparse(size_t table_id, const uint64_t *keys, - const float **update_values, - size_t num) override; - void push_sparse_task_consume(); + std::future PushDenseRawGradient(int table_id, + float *total_send_data, + size_t total_send_data_size, + void *done) override; + + std::future PushSparseRawGradient(size_t table_id, + const uint64_t *keys, + const float **update_values, + size_t num, void *done) override; + + std::future PushSparseRawGradientPartial(size_t table_id, + const uint64_t *keys, + const float **update_values, + uint32_t num, void *done, + int pserver_idx) override; + + std::future PushSparseParam(size_t table_id, const uint64_t *keys, + const float **update_values, size_t num, + void *done) override; + std::future PushSparse(size_t table_id, const uint64_t *keys, + const float **update_values, + size_t num) override; + void PushSparseTaskConsume(); private: - int32_t start_client_service(); + int32_t StartClientService(); - void push_dense_raw_gradient(std::shared_ptr &task, // NOLINT - float *total_send_data, - size_t total_send_data_size, - DownpourBrpcClosure *closure); + void PushDenseRawGradient(std::shared_ptr &task, // NOLINT + float *total_send_data, size_t total_send_data_size, + DownpourBrpcClosure *closure); float _mae = 0; float _mse = 0; uint16_t _push_times = 0; diff --git a/paddle/fluid/distributed/ps/service/brpc_ps_server.cc b/paddle/fluid/distributed/ps/service/brpc_ps_server.cc index 2e77020c3075179330c762a9a74d40c13d190116..1d88d88ebcf140bd1e4081e82a734623574b7e27 100644 --- a/paddle/fluid/distributed/ps/service/brpc_ps_server.cc +++ b/paddle/fluid/distributed/ps/service/brpc_ps_server.cc @@ -31,7 +31,7 @@ class RpcController; namespace paddle { namespace distributed { -int32_t BrpcPsServer::initialize() { +int32_t BrpcPsServer::Initialize() { auto &service_config = _config.downpour_server_param().service_param(); if (!service_config.has_service_class()) { LOG(ERROR) << "miss service_class in ServerServiceParameter"; @@ -46,7 +46,7 @@ int32_t BrpcPsServer::initialize() { } _service.reset(service); - if (service->configure(this) != 0 || service->initialize() != 0) { + if (service->Configure(this) != 0 || service->Initialize() != 0) { LOG(ERROR) << "service initialize failed, service_name:" << service_config.service_class(); return -1; @@ -59,7 +59,7 @@ int32_t BrpcPsServer::initialize() { return 0; } -uint64_t BrpcPsServer::start(const std::string &ip, uint32_t port) { +uint64_t BrpcPsServer::Start(const std::string &ip, uint32_t port) { std::unique_lock lock(mutex_); std::string ip_port = ip + ":" + std::to_string(port); @@ -68,7 +68,7 @@ uint64_t BrpcPsServer::start(const std::string &ip, uint32_t port) { brpc::ServerOptions options; int num_threads = std::thread::hardware_concurrency(); - auto trainers = _environment->get_trainers(); + auto trainers = _environment->GetTrainers(); options.num_threads = trainers > num_threads ? trainers : num_threads; if (_server.Start(ip_port.c_str(), &options) != 0) { @@ -83,7 +83,7 @@ uint64_t BrpcPsServer::start(const std::string &ip, uint32_t port) { } } - _environment->registe_ps_server(ip, port, _rank); + _environment->RegistePsServer(ip, port, _rank); cv_.wait(lock, [&] { return stoped_; }); PSHost host; @@ -93,31 +93,30 @@ uint64_t BrpcPsServer::start(const std::string &ip, uint32_t port) { return host.rank; } -int32_t BrpcPsServer::port() { return _server.listen_address().port; } +int32_t BrpcPsServer::Port() { return _server.listen_address().port; } -int32_t BrpcPsService::initialize() { +int32_t BrpcPsService::Initialize() { _is_initialize_shard_info = false; - _service_handler_map[PS_STOP_SERVER] = &BrpcPsService::stop_server; - _service_handler_map[PS_PULL_DENSE_TABLE] = &BrpcPsService::pull_dense; - _service_handler_map[PS_PUSH_DENSE_TABLE] = &BrpcPsService::push_dense; - _service_handler_map[PS_PULL_SPARSE_TABLE] = &BrpcPsService::pull_sparse; - _service_handler_map[PS_PUSH_SPARSE_TABLE] = &BrpcPsService::push_sparse; - _service_handler_map[PS_SAVE_ONE_TABLE] = &BrpcPsService::save_one_table; - _service_handler_map[PS_SAVE_ALL_TABLE] = &BrpcPsService::save_all_table; - _service_handler_map[PS_SHRINK_TABLE] = &BrpcPsService::shrink_table; - _service_handler_map[PS_LOAD_ONE_TABLE] = &BrpcPsService::load_one_table; - _service_handler_map[PS_LOAD_ALL_TABLE] = &BrpcPsService::load_all_table; - _service_handler_map[PS_CLEAR_ONE_TABLE] = &BrpcPsService::clear_one_table; - _service_handler_map[PS_CLEAR_ALL_TABLE] = &BrpcPsService::clear_all_table; - _service_handler_map[PS_PUSH_DENSE_PARAM] = &BrpcPsService::push_dense_param; - _service_handler_map[PS_PRINT_TABLE_STAT] = &BrpcPsService::print_table_stat; - _service_handler_map[PS_PULL_GEO_PARAM] = &BrpcPsService::pull_geo_param; - _service_handler_map[PS_PUSH_SPARSE_PARAM] = - &BrpcPsService::push_sparse_param; - _service_handler_map[PS_BARRIER] = &BrpcPsService::barrier; - _service_handler_map[PS_START_PROFILER] = &BrpcPsService::start_profiler; - _service_handler_map[PS_STOP_PROFILER] = &BrpcPsService::stop_profiler; - _service_handler_map[PS_PUSH_GLOBAL_STEP] = &BrpcPsService::push_global_step; + _service_handler_map[PS_STOP_SERVER] = &BrpcPsService::StopServer; + _service_handler_map[PS_PULL_DENSE_TABLE] = &BrpcPsService::PullDense; + _service_handler_map[PS_PUSH_DENSE_TABLE] = &BrpcPsService::PushDense; + _service_handler_map[PS_PULL_SPARSE_TABLE] = &BrpcPsService::PullSparse; + _service_handler_map[PS_PUSH_SPARSE_TABLE] = &BrpcPsService::PushSparse; + _service_handler_map[PS_SAVE_ONE_TABLE] = &BrpcPsService::SaveOneTable; + _service_handler_map[PS_SAVE_ALL_TABLE] = &BrpcPsService::SaveAllTable; + _service_handler_map[PS_SHRINK_TABLE] = &BrpcPsService::ShrinkTable; + _service_handler_map[PS_LOAD_ONE_TABLE] = &BrpcPsService::LoadOneTable; + _service_handler_map[PS_LOAD_ALL_TABLE] = &BrpcPsService::LoadAllTable; + _service_handler_map[PS_CLEAR_ONE_TABLE] = &BrpcPsService::ClearOneTable; + _service_handler_map[PS_CLEAR_ALL_TABLE] = &BrpcPsService::ClearAllTable; + _service_handler_map[PS_PUSH_DENSE_PARAM] = &BrpcPsService::PushDenseParam; + _service_handler_map[PS_PRINT_TABLE_STAT] = &BrpcPsService::PrintTableStat; + _service_handler_map[PS_PULL_GEO_PARAM] = &BrpcPsService::PullGeoParam; + _service_handler_map[PS_PUSH_SPARSE_PARAM] = &BrpcPsService::PushSparseParam; + _service_handler_map[PS_BARRIER] = &BrpcPsService::Barrier; + _service_handler_map[PS_START_PROFILER] = &BrpcPsService::StartProfiler; + _service_handler_map[PS_STOP_PROFILER] = &BrpcPsService::StopProfiler; + _service_handler_map[PS_PUSH_GLOBAL_STEP] = &BrpcPsService::PushGlobalStep; auto &profiler = CostProfiler::instance(); profiler.register_profiler("pserver_server_pull_dense"); profiler.register_profiler("pserver_server_push_dense"); @@ -125,7 +124,7 @@ int32_t BrpcPsService::initialize() { profiler.register_profiler("pserver_server_push_sparse"); // shard初始化,server启动后才可从env获取到server_list的shard信息 - initialize_shard_info(); + InitializeShardInfo(); return 0; } @@ -138,16 +137,16 @@ int32_t BrpcPsService::initialize() { return -1; \ } -int32_t BrpcPsService::initialize_shard_info() { +int32_t BrpcPsService::InitializeShardInfo() { if (!_is_initialize_shard_info) { std::lock_guard guard(_initialize_shard_mutex); if (_is_initialize_shard_info) { return 0; } - size_t shard_num = _server->environment()->get_ps_servers().size(); - auto &table_map = *(_server->table()); + size_t shard_num = _server->Environment()->GetPsServers().size(); + auto &table_map = *(_server->GetTable()); for (auto itr : table_map) { - itr.second->set_shard(_rank, shard_num); + itr.second->SetShard(_rank, shard_num); } _is_initialize_shard_info = true; } @@ -167,7 +166,7 @@ void BrpcPsService::service(google::protobuf::RpcController *cntl_base, response->set_err_code(0); response->set_err_msg(""); - auto *table = _server->table(request->table_id()); + auto *table = _server->GetTable(request->table_id()); brpc::Controller *cntl = static_cast(cntl_base); auto itr = _service_handler_map.find(request->cmd_id()); if (itr == _service_handler_map.end()) { @@ -185,11 +184,11 @@ void BrpcPsService::service(google::protobuf::RpcController *cntl_base, } } -int32_t BrpcPsService::pull_dense(Table *table, const PsRequestMessage &request, - PsResponseMessage &response, - brpc::Controller *cntl) { +int32_t BrpcPsService::PullDense(Table *table, const PsRequestMessage &request, + PsResponseMessage &response, + brpc::Controller *cntl) { platform::RecordEvent record_event( - "PsService->pull_dense", platform::TracerEventType::Communication, 1); + "PsService->PullDense", platform::TracerEventType::Communication, 1); CHECK_TABLE_EXIST(table, request, response) if (request.params_size() < 1) { set_response_code( @@ -206,14 +205,15 @@ int32_t BrpcPsService::pull_dense(Table *table, const PsRequestMessage &request, } auto res_data = butil::get_object>(); - res_data->resize(num * table->value_accesor()->GetTableInfo(SELECT_SIZE) / + res_data->resize(num * table->ValueAccesor()->GetTableInfo(SELECT_SIZE) / sizeof(float)); + TableContext table_context; table_context.value_type = Dense; table_context.pull_context.values = res_data->data(); table_context.num = num; table->Pull(table_context); - // table->pull_dense(res_data->data(), num); + // table->PullDense(res_data->data(), num); cntl->response_attachment().append((char *)(res_data->data()), res_data->size() * sizeof(float)); @@ -222,13 +222,12 @@ int32_t BrpcPsService::pull_dense(Table *table, const PsRequestMessage &request, return 0; } -int32_t BrpcPsService::push_dense_param(Table *table, - const PsRequestMessage &request, - PsResponseMessage &response, - brpc::Controller *cntl) { - platform::RecordEvent record_event("PsService->push_dense_param", - platform::TracerEventType::Communication, - 1); +int32_t BrpcPsService::PushDenseParam(Table *table, + const PsRequestMessage &request, + PsResponseMessage &response, + brpc::Controller *cntl) { + platform::RecordEvent record_event( + "PsService->PushDenseParam", platform::TracerEventType::Communication, 1); CHECK_TABLE_EXIST(table, request, response) thread_local std::string push_buffer; auto &req_io_buffer = cntl->request_attachment(); @@ -245,17 +244,17 @@ int32_t BrpcPsService::push_dense_param(Table *table, uint32_t num = *(const uint32_t *)data; const float *values = (const float *)(data + sizeof(uint32_t)); - if (table->push_dense_param(values, num) != 0) { - set_response_code(response, -1, "push_dense_param failed"); + if (table->PushDenseParam(values, num) != 0) { + set_response_code(response, -1, "PushDenseParam failed"); } return 0; } -int32_t BrpcPsService::push_dense(Table *table, const PsRequestMessage &request, - PsResponseMessage &response, - brpc::Controller *cntl) { +int32_t BrpcPsService::PushDense(Table *table, const PsRequestMessage &request, + PsResponseMessage &response, + brpc::Controller *cntl) { platform::RecordEvent record_event( - "PsService->push_dense", platform::TracerEventType::Communication, 1); + "PsService->PushDense", platform::TracerEventType::Communication, 1); CHECK_TABLE_EXIST(table, request, response) auto req_buffer_size = request.data().size(); if (req_buffer_size < 1) { @@ -278,14 +277,14 @@ int32_t BrpcPsService::push_dense(Table *table, const PsRequestMessage &request, // const float *values = (const float *)(request.data().data() + // sizeof(uint32_t)); if (table->Push(table_context) != 0) { - // if (table->push_dense(values, num) != 0) { - set_response_code(response, -1, "push_dense failed"); + // if (table->PushDense(values, num) != 0) { + set_response_code(response, -1, "PushDense failed"); } return 0; } -int32_t BrpcPsService::barrier(Table *table, const PsRequestMessage &request, +int32_t BrpcPsService::Barrier(Table *table, const PsRequestMessage &request, PsResponseMessage &response, brpc::Controller *cntl) { CHECK_TABLE_EXIST(table, request, response) @@ -299,15 +298,15 @@ int32_t BrpcPsService::barrier(Table *table, const PsRequestMessage &request, auto trainer_id = request.client_id(); auto barrier_type = request.params(0); - table->barrier(trainer_id, barrier_type); + table->Barrier(trainer_id, barrier_type); return 0; } -int32_t BrpcPsService::push_sparse_param(Table *table, - const PsRequestMessage &request, - PsResponseMessage &response, - brpc::Controller *cntl) { - platform::RecordEvent record_event("PsService->push_sparse_param", +int32_t BrpcPsService::PushSparseParam(Table *table, + const PsRequestMessage &request, + PsResponseMessage &response, + brpc::Controller *cntl) { + platform::RecordEvent record_event("PsService->PushSparseParam", platform::TracerEventType::Communication, 1); CHECK_TABLE_EXIST(table, request, response) @@ -331,16 +330,16 @@ int32_t BrpcPsService::push_sparse_param(Table *table, const uint64_t *keys = (const uint64_t *)push_data.data(); const float *values = (const float *)(push_data.data() + sizeof(uint64_t) * num); - if (table->push_sparse_param(keys, values, num) != 0) { - set_response_code(response, -1, "push_sparse_param error"); + if (table->PushSparseParam(keys, values, num) != 0) { + set_response_code(response, -1, "PushSparseParam error"); } return 0; } -int32_t BrpcPsService::pull_geo_param(Table *table, - const PsRequestMessage &request, - PsResponseMessage &response, - brpc::Controller *cntl) { +int32_t BrpcPsService::PullGeoParam(Table *table, + const PsRequestMessage &request, + PsResponseMessage &response, + brpc::Controller *cntl) { platform::RecordEvent record_event( "PsService->pull_geo_param", platform::TracerEventType::Communication, 1); CHECK_TABLE_EXIST(table, request, response) @@ -350,7 +349,7 @@ int32_t BrpcPsService::pull_geo_param(Table *table, std::vector values; std::vector ids; - table->pull_geo_param(trainer_id, &values, &ids); + table->PullGeoParam(trainer_id, &values, &ids); uint32_t num = ids.size(); cntl->response_attachment().append((char *)(&num), sizeof(uint32_t)); @@ -361,12 +360,11 @@ int32_t BrpcPsService::pull_geo_param(Table *table, return 0; } -int32_t BrpcPsService::pull_sparse(Table *table, - const PsRequestMessage &request, - PsResponseMessage &response, - brpc::Controller *cntl) { +int32_t BrpcPsService::PullSparse(Table *table, const PsRequestMessage &request, + PsResponseMessage &response, + brpc::Controller *cntl) { platform::RecordEvent record_event( - "PsService->pull_sparse", platform::TracerEventType::Communication, 1); + "PsService->PullSparse", platform::TracerEventType::Communication, 1); CHECK_TABLE_EXIST(table, request, response) auto &req_io_buffer = cntl->request_attachment(); @@ -386,7 +384,7 @@ int32_t BrpcPsService::pull_sparse(Table *table, CostTimer timer("pserver_server_pull_sparse"); uint32_t num = *(uint32_t *)(request.params(0).c_str()); - auto dim = table->value_accesor()->GetTableInfo(SELECT_DIM); + auto dim = table->ValueAccesor()->GetTableInfo(SELECT_DIM); thread_local std::string req_buffer; req_buffer.reserve(req_buffer_size); @@ -405,7 +403,7 @@ int32_t BrpcPsService::pull_sparse(Table *table, table_context.pull_context.pull_value = value; table_context.pull_context.values = res_data->data(); table->Pull(table_context); - // table->pull_sparse(res_data->data(), value); + // table->PullSparse(res_data->data(), value); cntl->response_attachment().append((char *)(res_data->data()), res_data->size() * sizeof(float)); @@ -413,12 +411,11 @@ int32_t BrpcPsService::pull_sparse(Table *table, return 0; } -int32_t BrpcPsService::push_sparse(Table *table, - const PsRequestMessage &request, - PsResponseMessage &response, - brpc::Controller *cntl) { +int32_t BrpcPsService::PushSparse(Table *table, const PsRequestMessage &request, + PsResponseMessage &response, + brpc::Controller *cntl) { platform::RecordEvent record_event( - "PsService->push_sparse", platform::TracerEventType::Communication, 1); + "PsService->PushSparse", platform::TracerEventType::Communication, 1); CHECK_TABLE_EXIST(table, request, response) auto &push_data = request.data(); if (push_data.size() < 1) { @@ -448,18 +445,18 @@ int32_t BrpcPsService::push_sparse(Table *table, // const float *values = (const float *)(push_data.data() + sizeof(uint64_t) * // num); if (table->Push(table_context) != 0) { - // if (table->push_sparse(keys, values, num) != 0) { - set_response_code(response, -1, "push_sparse error"); + // if (table->PushSparse(keys, values, num) != 0) { + set_response_code(response, -1, "PushSparse error"); } return 0; } -int32_t BrpcPsService::print_table_stat(Table *table, - const PsRequestMessage &request, - PsResponseMessage &response, - brpc::Controller *cntl) { +int32_t BrpcPsService::PrintTableStat(Table *table, + const PsRequestMessage &request, + PsResponseMessage &response, + brpc::Controller *cntl) { CHECK_TABLE_EXIST(table, request, response) - std::pair ret = table->print_table_stat(); + std::pair ret = table->PrintTableStat(); paddle::framework::BinaryArchive ar; ar << ret.first << ret.second; std::string table_info(ar.Buffer(), ar.Length()); @@ -468,10 +465,10 @@ int32_t BrpcPsService::print_table_stat(Table *table, return 0; } -int32_t BrpcPsService::load_one_table(Table *table, - const PsRequestMessage &request, - PsResponseMessage &response, - brpc::Controller *cntl) { +int32_t BrpcPsService::LoadOneTable(Table *table, + const PsRequestMessage &request, + PsResponseMessage &response, + brpc::Controller *cntl) { CHECK_TABLE_EXIST(table, request, response) if (request.params_size() < 2) { set_response_code( @@ -479,20 +476,20 @@ int32_t BrpcPsService::load_one_table(Table *table, "PsRequestMessage.datas is requeired at least 2 for path & load_param"); return -1; } - if (table->load(request.params(0), request.params(1)) != 0) { + if (table->Load(request.params(0), request.params(1)) != 0) { set_response_code(response, -1, "table load failed"); return -1; } return 0; } -int32_t BrpcPsService::load_all_table(Table *table, - const PsRequestMessage &request, - PsResponseMessage &response, - brpc::Controller *cntl) { - auto &table_map = *(_server->table()); +int32_t BrpcPsService::LoadAllTable(Table *table, + const PsRequestMessage &request, + PsResponseMessage &response, + brpc::Controller *cntl) { + auto &table_map = *(_server->GetTable()); for (auto &itr : table_map) { - if (load_one_table(itr.second.get(), request, response, cntl) != 0) { + if (LoadOneTable(itr.second.get(), request, response, cntl) != 0) { LOG(ERROR) << "load table[" << itr.first << "] failed"; return -1; } @@ -500,10 +497,10 @@ int32_t BrpcPsService::load_all_table(Table *table, return 0; } -int32_t BrpcPsService::save_one_table(Table *table, - const PsRequestMessage &request, - PsResponseMessage &response, - brpc::Controller *cntl) { +int32_t BrpcPsService::SaveOneTable(Table *table, + const PsRequestMessage &request, + PsResponseMessage &response, + brpc::Controller *cntl) { CHECK_TABLE_EXIST(table, request, response) if (request.params_size() < 2) { set_response_code( @@ -511,12 +508,12 @@ int32_t BrpcPsService::save_one_table(Table *table, "PsRequestMessage.datas is requeired at least 2, path&mode"); return -1; } - table->flush(); + table->Flush(); int32_t feasign_size = 0; VLOG(3) << "save table " << request.params(0) << " " << request.params(1); - feasign_size = table->save(request.params(0), request.params(1)); + feasign_size = table->Save(request.params(0), request.params(1)); if (feasign_size < 0) { set_response_code(response, -1, "table save failed"); return -1; @@ -524,16 +521,16 @@ int32_t BrpcPsService::save_one_table(Table *table, return feasign_size; } -int32_t BrpcPsService::save_all_table(Table *table, - const PsRequestMessage &request, - PsResponseMessage &response, - brpc::Controller *cntl) { - auto &table_map = *(_server->table()); +int32_t BrpcPsService::SaveAllTable(Table *table, + const PsRequestMessage &request, + PsResponseMessage &response, + brpc::Controller *cntl) { + auto &table_map = *(_server->GetTable()); int32_t all_feasign_size = 0; int32_t feasign_size = 0; for (auto &itr : table_map) { - feasign_size = save_one_table(itr.second.get(), request, response, cntl); + feasign_size = SaveOneTable(itr.second.get(), request, response, cntl); if (feasign_size < 0) { LOG(ERROR) << "save table[" << itr.first << "] failed"; return -1; @@ -542,10 +539,10 @@ int32_t BrpcPsService::save_all_table(Table *table, return 0; } -int32_t BrpcPsService::shrink_table(Table *table, - const PsRequestMessage &request, - PsResponseMessage &response, - brpc::Controller *cntl) { +int32_t BrpcPsService::ShrinkTable(Table *table, + const PsRequestMessage &request, + PsResponseMessage &response, + brpc::Controller *cntl) { CHECK_TABLE_EXIST(table, request, response) if (request.params_size() < 1) { set_response_code( @@ -553,8 +550,8 @@ int32_t BrpcPsService::shrink_table(Table *table, "PsRequestMessage.datas is requeired at least 1, threshold"); return -1; } - table->flush(); - if (table->shrink(request.params(0)) != 0) { + table->Flush(); + if (table->Shrink(request.params(0)) != 0) { set_response_code(response, -1, "table shrink failed"); return -1; } @@ -562,63 +559,62 @@ int32_t BrpcPsService::shrink_table(Table *table, return 0; } -int32_t BrpcPsService::clear_one_table(Table *table, - const PsRequestMessage &request, - PsResponseMessage &response, - brpc::Controller *cntl) { +int32_t BrpcPsService::ClearOneTable(Table *table, + const PsRequestMessage &request, + PsResponseMessage &response, + brpc::Controller *cntl) { CHECK_TABLE_EXIST(table, request, response) - table->flush(); - table->clear(); + table->Flush(); + table->Clear(); return 0; } -int32_t BrpcPsService::clear_all_table(Table *table, - const PsRequestMessage &request, - PsResponseMessage &response, - brpc::Controller *cntl) { - auto &table_map = *(_server->table()); +int32_t BrpcPsService::ClearAllTable(Table *table, + const PsRequestMessage &request, + PsResponseMessage &response, + brpc::Controller *cntl) { + auto &table_map = *(_server->GetTable()); for (auto &itr : table_map) { - if (clear_one_table(itr.second.get(), request, response, cntl) != 0) { + if (ClearOneTable(itr.second.get(), request, response, cntl) != 0) { return -1; } } return 0; } -int32_t BrpcPsService::stop_server(Table *table, - const PsRequestMessage &request, - PsResponseMessage &response, - brpc::Controller *cntl) { +int32_t BrpcPsService::StopServer(Table *table, const PsRequestMessage &request, + PsResponseMessage &response, + brpc::Controller *cntl) { auto *p_server = _server; std::thread t_stop([p_server]() { - p_server->stop(); + p_server->Stop(); VLOG(3) << "Server Stoped"; }); t_stop.detach(); return 0; } -int32_t BrpcPsService::stop_profiler(Table *table, - const PsRequestMessage &request, - PsResponseMessage &response, - brpc::Controller *cntl) { +int32_t BrpcPsService::StopProfiler(Table *table, + const PsRequestMessage &request, + PsResponseMessage &response, + brpc::Controller *cntl) { platform::DisableProfiler(platform::EventSortingKey::kDefault, string::Sprintf("server_%s_profile", _rank)); return 0; } -int32_t BrpcPsService::start_profiler(Table *table, - const PsRequestMessage &request, - PsResponseMessage &response, - brpc::Controller *cntl) { +int32_t BrpcPsService::StartProfiler(Table *table, + const PsRequestMessage &request, + PsResponseMessage &response, + brpc::Controller *cntl) { platform::EnableProfiler(platform::ProfilerState::kCPU); return 0; } -int32_t BrpcPsService::push_global_step(Table *table, - const PsRequestMessage &request, - PsResponseMessage &response, - brpc::Controller *cntl) { +int32_t BrpcPsService::PushGlobalStep(Table *table, + const PsRequestMessage &request, + PsResponseMessage &response, + brpc::Controller *cntl) { CHECK_TABLE_EXIST(table, request, response); auto req_buffer_size = request.data().size(); if (req_buffer_size < 1) { @@ -629,7 +625,7 @@ int32_t BrpcPsService::push_global_step(Table *table, const int64_t *values = (const int64_t *)(request.data().data() + sizeof(uint32_t)); auto trainer_id = request.client_id(); - if (table->push_dense(values, trainer_id) != 0) { + if (table->PushDense(values, trainer_id) != 0) { set_response_code(response, -1, "run_program failed"); } diff --git a/paddle/fluid/distributed/ps/service/brpc_ps_server.h b/paddle/fluid/distributed/ps/service/brpc_ps_server.h index d81a3a5df07f1de534cd646138fecc4dc2c970e1..250f465d84253731df3198ca92baca022864974b 100644 --- a/paddle/fluid/distributed/ps/service/brpc_ps_server.h +++ b/paddle/fluid/distributed/ps/service/brpc_ps_server.h @@ -41,8 +41,8 @@ class BrpcPsServer : public PSServer { public: BrpcPsServer() {} virtual ~BrpcPsServer() {} - virtual uint64_t start(const std::string &ip, uint32_t port); - virtual int32_t stop() { + virtual uint64_t Start(const std::string &ip, uint32_t port); + virtual int32_t Stop() { std::unique_lock lock(mutex_); stoped_ = true; cv_.notify_all(); @@ -51,10 +51,10 @@ class BrpcPsServer : public PSServer { _server.Join(); return 0; } - int32_t port(); + int32_t Port(); private: - virtual int32_t initialize(); + virtual int32_t Initialize(); mutable std::mutex mutex_; std::condition_variable cv_; bool stoped_ = false; @@ -71,7 +71,7 @@ typedef int32_t (BrpcPsService::*serviceHandlerFunc)( class BrpcPsService : public PsBaseService { public: - virtual int32_t initialize() override; + virtual int32_t Initialize() override; virtual void service(::google::protobuf::RpcController *controller, const PsRequestMessage *request, @@ -79,50 +79,49 @@ class BrpcPsService : public PsBaseService { ::google::protobuf::Closure *done) override; private: - int32_t initialize_shard_info(); - int32_t pull_dense(Table *table, const PsRequestMessage &request, - PsResponseMessage &response, brpc::Controller *cntl); - int32_t push_dense(Table *table, const PsRequestMessage &request, - PsResponseMessage &response, brpc::Controller *cntl); - int32_t push_dense_param(Table *table, const PsRequestMessage &request, - PsResponseMessage &response, brpc::Controller *cntl); - int32_t push_sparse_param(Table *table, const PsRequestMessage &request, - PsResponseMessage &response, - brpc::Controller *cntl); - int32_t pull_sparse(Table *table, const PsRequestMessage &request, - PsResponseMessage &response, brpc::Controller *cntl); - int32_t pull_geo_param(Table *table, const PsRequestMessage &request, + int32_t InitializeShardInfo(); + int32_t PullDense(Table *table, const PsRequestMessage &request, + PsResponseMessage &response, brpc::Controller *cntl); + int32_t PushDense(Table *table, const PsRequestMessage &request, + PsResponseMessage &response, brpc::Controller *cntl); + int32_t PushDenseParam(Table *table, const PsRequestMessage &request, PsResponseMessage &response, brpc::Controller *cntl); - int32_t barrier(Table *table, const PsRequestMessage &request, + int32_t PushSparseParam(Table *table, const PsRequestMessage &request, + PsResponseMessage &response, brpc::Controller *cntl); + int32_t PullSparse(Table *table, const PsRequestMessage &request, + PsResponseMessage &response, brpc::Controller *cntl); + int32_t PullGeoParam(Table *table, const PsRequestMessage &request, + PsResponseMessage &response, brpc::Controller *cntl); + int32_t Barrier(Table *table, const PsRequestMessage &request, PsResponseMessage &response, brpc::Controller *cntl); - int32_t push_sparse(Table *table, const PsRequestMessage &request, - PsResponseMessage &response, brpc::Controller *cntl); - int32_t load_one_table(Table *table, const PsRequestMessage &request, - PsResponseMessage &response, brpc::Controller *cntl); - int32_t load_all_table(Table *table, const PsRequestMessage &request, - PsResponseMessage &response, brpc::Controller *cntl); - int32_t save_one_table(Table *table, const PsRequestMessage &request, - PsResponseMessage &response, brpc::Controller *cntl); - int32_t save_all_table(Table *table, const PsRequestMessage &request, - PsResponseMessage &response, brpc::Controller *cntl); - int32_t shrink_table(Table *table, const PsRequestMessage &request, + int32_t PushSparse(Table *table, const PsRequestMessage &request, + PsResponseMessage &response, brpc::Controller *cntl); + int32_t LoadOneTable(Table *table, const PsRequestMessage &request, PsResponseMessage &response, brpc::Controller *cntl); - int32_t clear_one_table(Table *table, const PsRequestMessage &request, - PsResponseMessage &response, brpc::Controller *cntl); - int32_t clear_all_table(Table *table, const PsRequestMessage &request, - PsResponseMessage &response, brpc::Controller *cntl); - int32_t stop_server(Table *table, const PsRequestMessage &request, + int32_t LoadAllTable(Table *table, const PsRequestMessage &request, + PsResponseMessage &response, brpc::Controller *cntl); + int32_t SaveOneTable(Table *table, const PsRequestMessage &request, + PsResponseMessage &response, brpc::Controller *cntl); + int32_t SaveAllTable(Table *table, const PsRequestMessage &request, + PsResponseMessage &response, brpc::Controller *cntl); + int32_t ShrinkTable(Table *table, const PsRequestMessage &request, PsResponseMessage &response, brpc::Controller *cntl); - int32_t start_profiler(Table *table, const PsRequestMessage &request, - PsResponseMessage &response, brpc::Controller *cntl); - int32_t stop_profiler(Table *table, const PsRequestMessage &request, + int32_t ClearOneTable(Table *table, const PsRequestMessage &request, PsResponseMessage &response, brpc::Controller *cntl); + int32_t ClearAllTable(Table *table, const PsRequestMessage &request, + PsResponseMessage &response, brpc::Controller *cntl); + int32_t StopServer(Table *table, const PsRequestMessage &request, + PsResponseMessage &response, brpc::Controller *cntl); + int32_t StartProfiler(Table *table, const PsRequestMessage &request, + PsResponseMessage &response, brpc::Controller *cntl); + int32_t StopProfiler(Table *table, const PsRequestMessage &request, + PsResponseMessage &response, brpc::Controller *cntl); - int32_t print_table_stat(Table *table, const PsRequestMessage &request, - PsResponseMessage &response, brpc::Controller *cntl); + int32_t PrintTableStat(Table *table, const PsRequestMessage &request, + PsResponseMessage &response, brpc::Controller *cntl); - int32_t push_global_step(Table *table, const PsRequestMessage &request, - PsResponseMessage &response, brpc::Controller *cntl); + int32_t PushGlobalStep(Table *table, const PsRequestMessage &request, + PsResponseMessage &response, brpc::Controller *cntl); bool _is_initialize_shard_info; std::mutex _initialize_shard_mutex; diff --git a/paddle/fluid/distributed/ps/service/communicator/communicator.cc b/paddle/fluid/distributed/ps/service/communicator/communicator.cc index 50c34bd319253aedeab7c51014db98bd655f88d7..c4b833f294e177f13fcd7e99a086f14260502010 100644 --- a/paddle/fluid/distributed/ps/service/communicator/communicator.cc +++ b/paddle/fluid/distributed/ps/service/communicator/communicator.cc @@ -39,7 +39,7 @@ inline double GetCurrentUS() { Communicator::Communicator() {} -void Communicator::init_gflag(const std::string &gflags) { +void Communicator::InitGFlag(const std::string &gflags) { VLOG(3) << "Init With Gflags:" << gflags; std::vector flags = paddle::string::split_string(gflags); if (flags.size() < 1) { @@ -73,7 +73,7 @@ void Communicator::InitBrpcClient( } std::vector Communicator::GetClientInfo() { - std::vector res = _ps_env.get_client_info(); + std::vector res = _ps_env.GetClientInfo(); for (auto rr : res) { VLOG(2) << "Communicator::GetClientInfo " << rr; } @@ -82,7 +82,7 @@ std::vector Communicator::GetClientInfo() { int Communicator::SetClients(std::vector &host_sign_list) { int node = host_sign_list.size(); - return _ps_env.set_ps_clients(host_sign_list.data(), node); + return _ps_env.SetPsClients(host_sign_list.data(), node); } void Communicator::RpcRecvDense(const std::vector &varnames, @@ -114,7 +114,7 @@ void Communicator::RpcRecvDense(const std::vector &varnames, } } auto status = - _worker_ptr->pull_dense(regions.data(), regions.size(), table_id); + _worker_ptr->PullDense(regions.data(), regions.size(), table_id); status.wait(); for (auto &t : varnames) { @@ -177,7 +177,7 @@ void Communicator::RpcSendDenseParam(const std::vector &varnames, } } auto status = - _worker_ptr->push_dense_param(regions.data(), regions.size(), table_id); + _worker_ptr->PushDenseParam(regions.data(), regions.size(), table_id); status.wait(); VLOG(4) << "RPC Send Dense Param " << table_id << " done!"; return; @@ -190,9 +190,9 @@ void Communicator::RpcSendDense(const CommContext &ctx, const Scope &scope) { auto &var_names = ctx.origin_varnames; auto &table_id = ctx.table_id; auto dense_data = std::make_shared>(); - size_t request_call_num = _worker_ptr->get_server_nums(); + size_t request_call_num = _worker_ptr->GetServerNums(); uint32_t num_per_shard = - dense_dim_per_shard(ctx.height_sections[0], request_call_num); + DenseDimPerShard(ctx.height_sections[0], request_call_num); dense_data->resize(num_per_shard * request_call_num); // accessor->update_dim() = 1 float *data = dense_data->data(); @@ -222,8 +222,8 @@ void Communicator::RpcSendDense(const CommContext &ctx, const Scope &scope) { closure->set_promise_value(ret); --_async_call_num; }); - auto status = _worker_ptr->push_dense_raw_gradient( - table_id, data, dense_data->size(), closure); + auto status = _worker_ptr->PushDenseRawGradient(table_id, data, + dense_data->size(), closure); status.wait(); return; } @@ -233,7 +233,7 @@ void Communicator::RpcSendSparseParam(const std::string &varname, int table_id, platform::RecordEvent record_event("Communicator->RpcSendSparseParam", platform::TracerEventType::Communication, 1); - size_t request_call_num = _worker_ptr->get_server_nums(); + size_t request_call_num = _worker_ptr->GetServerNums(); std::vector push_g_vec; auto *send_var = scope.FindVar(varname); @@ -260,9 +260,9 @@ void Communicator::RpcSendSparseParam(const std::string &varname, int table_id, } closure->set_promise_value(ret); }); - auto status = _worker_ptr->push_sparse_param( - table_id, sparse_push_keys.data(), (const float **)push_g_vec.data(), - sparse_push_keys.size(), closure); + auto status = _worker_ptr->PushSparseParam(table_id, sparse_push_keys.data(), + (const float **)push_g_vec.data(), + sparse_push_keys.size(), closure); status.wait(); return; } @@ -272,7 +272,7 @@ void Communicator::RpcSendSparse(const std::string &var_name, int table_id, platform::RecordEvent record_event("Communicator->RpcSendSparse", platform::TracerEventType::Communication, 1); - size_t request_call_num = _worker_ptr->get_server_nums(); + size_t request_call_num = _worker_ptr->GetServerNums(); std::vector sparse_push_keys; std::vector push_g_vec; @@ -313,7 +313,7 @@ void Communicator::RpcSendSparse(const std::string &var_name, int table_id, closure->set_promise_value(ret); --_async_call_num; }); - auto status = _worker_ptr->push_sparse_raw_gradient( + auto status = _worker_ptr->PushSparseRawGradient( table_id, sparse_push_keys.data(), (const float **)push_g_vec.data(), sparse_push_keys.size(), closure); status.wait(); @@ -340,7 +340,7 @@ void Communicator::RpcRecvSparse(const std::string &varname, int table_id, bool training = true; - auto status = _worker_ptr->pull_sparse_param( + auto status = _worker_ptr->PullSparseParam( (float **)push_g_vec.data(), table_id, // NOLINT sparse_push_keys.data(), sparse_push_keys.size(), training); status.wait(); @@ -376,11 +376,11 @@ void Communicator::RpcProfilerControl() { if (!do_server_profiler_ && platform::IsProfileEnabled()) { // send profiler start flag do_server_profiler_ = true; - auto start_status = _worker_ptr->start_profiler(); + auto start_status = _worker_ptr->StartProfiler(); start_status.wait(); } else if (do_server_profiler_ && !platform::IsProfileEnabled()) { // send profiler end flag - auto stop_status = _worker_ptr->stop_profiler(); + auto stop_status = _worker_ptr->StopProfiler(); stop_status.wait(); do_server_profiler_ = false; } @@ -396,7 +396,7 @@ void Communicator::SendGlobalStep(const CommContext &ctx, int batches, platform::TracerEventType::Communication, 1); auto &table_id = ctx.table_id; - size_t request_call_num = _worker_ptr->get_server_nums(); + size_t request_call_num = _worker_ptr->GetServerNums(); auto &var_name = STEP_COUNTER; auto *out_var = send_scope->Var(var_name); @@ -416,7 +416,7 @@ void Communicator::SendGlobalStep(const CommContext &ctx, int batches, } closure->set_promise_value(ret); }); - auto status = _worker_ptr->push_global_step(table_id, data, closure); + auto status = _worker_ptr->PushGlobalStep(table_id, data, closure); status.wait(); return; } @@ -605,8 +605,8 @@ void AsyncCommunicator::PullSparseToTensorSync( } } auto status = - _worker_ptr->pull_sparse(pull_result_ptr.data(), table_id, - fea_keys.data(), fea_keys.size(), is_training); + _worker_ptr->PullSparse(pull_result_ptr.data(), table_id, fea_keys.data(), + fea_keys.size(), is_training); status.wait(); auto ret = status.get(); if (ret != 0) { @@ -738,9 +738,9 @@ void AsyncCommunicator::PushSparseFromTensorAsync( this->Check(table_id), true, platform::errors::InvalidArgument( "can not find table: %s, please check your config", table_id)); - auto status = _worker_ptr->push_sparse(table_id, push_keys.data(), - (const float **)push_g_vec.data(), - push_keys.size()); + auto status = _worker_ptr->PushSparse(table_id, push_keys.data(), + (const float **)push_g_vec.data(), + push_keys.size()); } void HalfAsyncCommunicator::MainThread() { @@ -813,7 +813,7 @@ void AsyncCommunicator::Stop() { if (!communicator_) { VLOG(0) << "Communicator is not inited, do nothing"; } else { - // _worker_ptr->finalize_worker(); + // _worker_ptr->FinalizeWorker(); VLOG(1) << "client finalize_worker done"; if (recv_thread_) { VLOG(1) << "stop recv thread"; @@ -1327,7 +1327,7 @@ void GeoCommunicator::SendSparse(const std::string &varname, closure->set_promise_value(ret); --_async_call_num; }); - auto status = _worker_ptr->push_sparse_raw_gradient_partial( + auto status = _worker_ptr->PushSparseRawGradientPartial( table_id, (const uint64_t *)sparse_ids.data(), (const float **)push_g_vec.data(), sparse_ids.size(), closure, ep_idx); status.wait(); @@ -1345,7 +1345,7 @@ void GeoCommunicator::RecvSparse(const std::string &varname, int table_id, // 1. recv from pserver std::vector keys; std::vector values; - auto status = _worker_ptr->pull_geo_param(table_id, &values, &keys, ep_idx); + auto status = _worker_ptr->PullGeoParam(table_id, &values, &keys, ep_idx); status.wait(); std::string param = SplitedGradToParam(varname); diff --git a/paddle/fluid/distributed/ps/service/communicator/communicator.h b/paddle/fluid/distributed/ps/service/communicator/communicator.h index da4b46928d55c827a6fd2ed1e6801cd85b1098a2..75676c392435cc2fc010d736556eabce85189790 100644 --- a/paddle/fluid/distributed/ps/service/communicator/communicator.h +++ b/paddle/fluid/distributed/ps/service/communicator/communicator.h @@ -299,7 +299,7 @@ class Communicator { virtual void Barrier() {} virtual void BarrierWithTable(uint32_t barrier_type) { - auto rets = _worker_ptr->barrier(barrier_table_id_, barrier_type); + auto rets = _worker_ptr->Barrier(barrier_table_id_, barrier_type); rets.wait(); int status = rets.get(); PADDLE_ENFORCE_EQ(status, 0, @@ -310,7 +310,7 @@ class Communicator { virtual void CreateC2CConnection(int pserver_timeout_ms, int pserver_connect_timeout_ms, int max_retry) { - _worker_ptr->create_client2client_connection( + _worker_ptr->CreateClient2ClientConnection( pserver_timeout_ms, pserver_connect_timeout_ms, max_retry); } @@ -379,12 +379,12 @@ class Communicator { std::unordered_map envs; // 计算每个shard 对 dense的存储量 - inline uint32_t dense_dim_per_shard(uint32_t dense_dim_total, - uint32_t shard_num) { + inline uint32_t DenseDimPerShard(uint32_t dense_dim_total, + uint32_t shard_num) { return dense_dim_total / shard_num + 1; } - void init_gflag(const std::string &gflags); + void InitGFlag(const std::string &gflags); paddle::distributed::PSParameter _ps_param; paddle::distributed::PaddlePSEnvironment _ps_env; int servers_ = 0; diff --git a/paddle/fluid/distributed/ps/service/env.h b/paddle/fluid/distributed/ps/service/env.h index 0cc57229b7a82ddeb7eca040335aa83b18810eb9..162ee6f0984223baf10a95a2853c41e0c6821910 100644 --- a/paddle/fluid/distributed/ps/service/env.h +++ b/paddle/fluid/distributed/ps/service/env.h @@ -40,7 +40,7 @@ struct PSHost { // |---ip---|---port---|--rank--| // |-32bit--|--20bit---|--12bit-| - uint64_t serialize_to_uint64() { + uint64_t SerializeToUint64() { uint64_t host_label = 0; host_label = inet_addr(ip.c_str()); host_label = host_label << 32; @@ -49,7 +49,7 @@ struct PSHost { return host_label; } - void parse_from_uint64(uint64_t host_label) { + void ParseFromUint64(uint64_t host_label) { static uint64_t rank_label_mask = (1L << 12) - 1; static uint64_t port_label_mask = (1L << 20) - 1; rank = host_label & rank_label_mask; @@ -58,17 +58,17 @@ struct PSHost { ip = inet_ntoa(*(in_addr *)&ip_addr); // NOLINT } - std::string to_string() { + std::string ToString() { std::stringstream s; s << "host: " << ip; s << " port: " << port; s << " rank: " << rank; - s << " uint: " << serialize_to_uint64(); + s << " uint: " << SerializeToUint64(); return s.str(); } // for open source parameter server - std::string serialize_to_string() { + std::string SerializeToString() { std::stringstream s; s << ip << ":"; s << port << ":"; @@ -76,16 +76,16 @@ struct PSHost { return s.str(); } - void parse_from_string(std::string endpoint) { + void ParseFromString(std::string endpoint) { std::vector endpoint_info; - string_split(endpoint, ':', &endpoint_info); + StringSplit(endpoint, ':', &endpoint_info); ip = endpoint_info[0]; port = std::stoi(endpoint_info[1]); rank = std::stoi(endpoint_info[2]); } - void string_split(const std::string &str, char sep, - std::vector *pieces, bool ignore_null = true) { + void StringSplit(const std::string &str, char sep, + std::vector *pieces, bool ignore_null = true) { pieces->clear(); if (str.empty()) { if (!ignore_null) { @@ -111,63 +111,60 @@ class PSEnvironment { explicit PSEnvironment() {} // NOLINT virtual ~PSEnvironment() {} - virtual int32_t set_ps_servers(uint64_t *host_sign_list, int node_num) { + virtual int32_t SetPsServers(uint64_t *host_sign_list, int node_num) { return 0; } - virtual int32_t set_ps_servers( + virtual int32_t SetPsServers( const std::vector *host_endpoint_list, int node_num) { return 0; } - virtual int32_t set_ps_clients(uint64_t *host_sign_list, int node_num) { + virtual int32_t SetPsClients(uint64_t *host_sign_list, int node_num) { return 0; } - virtual int32_t set_ps_clients(std::string *host_endpoint_list, - int node_num) { + virtual int32_t SetPsClients(std::string *host_endpoint_list, int node_num) { return 0; } - virtual uint64_t get_local_host_sign() { return 0; } - virtual std::vector get_ps_servers() const { return _ps_server_list; } - virtual int32_t registe_ps_server(const std::string &ip, uint32_t port, - int32_t rank) { - return registe_ps_host(ip, port, rank, _ps_server_list, - _ps_server_sign_set); + virtual uint64_t GetLocalHostSign() { return 0; } + virtual std::vector GetPsServers() const { return _ps_server_list; } + virtual int32_t RegistePsServer(const std::string &ip, uint32_t port, + int32_t rank) { + return RegistePsHost(ip, port, rank, _ps_server_list, _ps_server_sign_set); } - virtual std::vector get_ps_clients() const { return _ps_client_list; } - virtual int32_t registe_ps_client(const std::string &ip, uint32_t port, - int32_t rank) { - return registe_ps_host(ip, port, rank, _ps_client_list, - _ps_client_sign_set); + virtual std::vector GetPsClients() const { return _ps_client_list; } + virtual int32_t RegistePsClient(const std::string &ip, uint32_t port, + int32_t rank) { + return RegistePsHost(ip, port, rank, _ps_client_list, _ps_client_sign_set); } - virtual std::vector get_client_info() { + virtual std::vector GetClientInfo() { std::vector client_info; for (auto &i : _ps_client_list) { - client_info.push_back(i.serialize_to_uint64()); + client_info.push_back(i.SerializeToUint64()); } return client_info; } - virtual std::vector get_client_info(bool use_string_endpoint) { + virtual std::vector GetClientInfo(bool use_string_endpoint) { if (use_string_endpoint) { std::vector client_info; for (auto &i : _ps_client_list) { - client_info.push_back(i.serialize_to_string()); + client_info.push_back(i.SerializeToString()); } return client_info; } return {}; } - virtual void set_trainers(int trainers) { trainers_ = trainers; } + virtual void SetTrainers(int trainers) { trainers_ = trainers; } - virtual int get_trainers() { return trainers_; } + virtual int GetTrainers() { return trainers_; } protected: //注册一个host // NOLINT - virtual int32_t registe_ps_host( + virtual int32_t RegistePsHost( const std::string &ip, uint32_t port, int32_t rank, std::vector &host_list, // NOLINT std::unordered_set &sign_set) { // NOLINT @@ -198,15 +195,15 @@ class PaddlePSEnvironment : public PSEnvironment { explicit PaddlePSEnvironment() {} // NOLINT virtual ~PaddlePSEnvironment() {} - virtual int32_t set_ps_servers(uint64_t *host_sign_list, int node_num) { + virtual int32_t SetPsServers(uint64_t *host_sign_list, int node_num) { _ps_server_list.clear(); _ps_server_sign_set.clear(); for (int i = 0; i < node_num; ++i) { if (host_sign_list[i] > 0) { PSHost host; - host.parse_from_uint64(host_sign_list[i]); + host.ParseFromUint64(host_sign_list[i]); _ps_server_list.push_back(host); - _ps_server_sign_set.insert(host.serialize_to_uint64()); + _ps_server_sign_set.insert(host.SerializeToUint64()); } } std::sort( @@ -215,14 +212,14 @@ class PaddlePSEnvironment : public PSEnvironment { return 0; } - virtual int32_t set_ps_servers(const std::vector *host_sign_list, - int node_num) { + virtual int32_t SetPsServers(const std::vector *host_sign_list, + int node_num) { _ps_server_list.clear(); _ps_server_sign_set.clear(); for (int i = 0; i < node_num; ++i) { if (host_sign_list->at(i) != "") { PSHost host; - host.parse_from_string(host_sign_list->at(i)); + host.ParseFromString(host_sign_list->at(i)); _ps_server_list.push_back(host); _ps_server_sign_set.insert(host.rank); } @@ -233,15 +230,15 @@ class PaddlePSEnvironment : public PSEnvironment { return 0; } - virtual int32_t set_ps_clients(uint64_t *host_sign_list, int node_num) { + virtual int32_t SetPsClients(uint64_t *host_sign_list, int node_num) { _ps_client_list.clear(); _ps_client_sign_set.clear(); for (int i = 0; i < node_num; ++i) { if (host_sign_list[i] > 0) { PSHost host; - host.parse_from_uint64(host_sign_list[i]); + host.ParseFromUint64(host_sign_list[i]); _ps_client_list.push_back(host); - _ps_client_sign_set.insert(host.serialize_to_uint64()); + _ps_client_sign_set.insert(host.SerializeToUint64()); } } std::sort( @@ -250,14 +247,14 @@ class PaddlePSEnvironment : public PSEnvironment { return 0; } - virtual int32_t set_ps_clients(const std::vector *host_sign_list, - int node_num) { + virtual int32_t SetPsClients(const std::vector *host_sign_list, + int node_num) { _ps_client_list.clear(); _ps_client_sign_set.clear(); for (int i = 0; i < node_num; ++i) { if (host_sign_list->at(i) != "") { PSHost host; - host.parse_from_string(host_sign_list->at(i)); + host.ParseFromString(host_sign_list->at(i)); _ps_client_list.push_back(host); _ps_client_sign_set.insert(host.rank); } @@ -269,9 +266,9 @@ class PaddlePSEnvironment : public PSEnvironment { return 0; } - virtual uint64_t get_local_host_sign() { + virtual uint64_t GetLocalHostSign() { if (_ps_client_list.size() > 0) { - return _ps_client_list[0].serialize_to_uint64(); + return _ps_client_list[0].SerializeToUint64(); } else { return 0; } diff --git a/paddle/fluid/distributed/ps/service/graph_brpc_client.cc b/paddle/fluid/distributed/ps/service/graph_brpc_client.cc index a3db88e3b679da63a9b205cc013d579cf9a4be2f..827a643ee50d682603bf27fb4e66b26d3c2c928f 100644 --- a/paddle/fluid/distributed/ps/service/graph_brpc_client.cc +++ b/paddle/fluid/distributed/ps/service/graph_brpc_client.cc @@ -135,8 +135,7 @@ std::future GraphBrpcClient::get_node_feat( closure->request(request_idx) ->add_params(joint_feature_name.c_str(), joint_feature_name.size()); - GraphPsService_Stub rpc_stub = - getServiceStub(get_cmd_channel(server_index)); + GraphPsService_Stub rpc_stub = getServiceStub(GetCmdChannel(server_index)); closure->cntl(request_idx)->set_log_id(butil::gettimeofday_ms()); rpc_stub.service(closure->cntl(request_idx), closure->request(request_idx), closure->response(request_idx), closure); @@ -169,8 +168,7 @@ std::future GraphBrpcClient::clear_nodes(uint32_t table_id) { closure->request(server_index)->set_table_id(table_id); closure->request(server_index)->set_client_id(_client_id); - GraphPsService_Stub rpc_stub = - getServiceStub(get_cmd_channel(server_index)); + GraphPsService_Stub rpc_stub = getServiceStub(GetCmdChannel(server_index)); closure->cntl(server_index)->set_log_id(butil::gettimeofday_ms()); rpc_stub.service(closure->cntl(server_index), closure->request(server_index), @@ -238,9 +236,8 @@ std::future GraphBrpcClient::add_graph_node( ->add_params((char *)weighted, sizeof(bool) * is_weighted_bucket[request_idx].size()); } - // PsService_Stub rpc_stub(get_cmd_channel(server_index)); - GraphPsService_Stub rpc_stub = - getServiceStub(get_cmd_channel(server_index)); + // PsService_Stub rpc_stub(GetCmdChannel(server_index)); + GraphPsService_Stub rpc_stub = getServiceStub(GetCmdChannel(server_index)); closure->cntl(request_idx)->set_log_id(butil::gettimeofday_ms()); rpc_stub.service(closure->cntl(request_idx), closure->request(request_idx), closure->response(request_idx), closure); @@ -292,9 +289,8 @@ std::future GraphBrpcClient::remove_graph_node( closure->request(request_idx) ->add_params((char *)request_bucket[request_idx].data(), sizeof(int64_t) * node_num); - // PsService_Stub rpc_stub(get_cmd_channel(server_index)); - GraphPsService_Stub rpc_stub = - getServiceStub(get_cmd_channel(server_index)); + // PsService_Stub rpc_stub(GetCmdChannel(server_index)); + GraphPsService_Stub rpc_stub = getServiceStub(GetCmdChannel(server_index)); closure->cntl(request_idx)->set_log_id(butil::gettimeofday_ms()); rpc_stub.service(closure->cntl(request_idx), closure->request(request_idx), closure->response(request_idx), closure); @@ -362,9 +358,8 @@ std::future GraphBrpcClient::batch_sample_neighbors( closure->request(0)->add_params((char *)&sample_size, sizeof(int)); closure->request(0)->add_params((char *)&need_weight, sizeof(bool)); ; - // PsService_Stub rpc_stub(get_cmd_channel(server_index)); - GraphPsService_Stub rpc_stub = - getServiceStub(get_cmd_channel(server_index)); + // PsService_Stub rpc_stub(GetCmdChannel(server_index)); + GraphPsService_Stub rpc_stub = getServiceStub(GetCmdChannel(server_index)); closure->cntl(0)->set_log_id(butil::gettimeofday_ms()); rpc_stub.service(closure->cntl(0), closure->request(0), closure->response(0), closure); @@ -464,9 +459,8 @@ std::future GraphBrpcClient::batch_sample_neighbors( ->add_params((char *)&sample_size, sizeof(int)); closure->request(request_idx) ->add_params((char *)&need_weight, sizeof(bool)); - // PsService_Stub rpc_stub(get_cmd_channel(server_index)); - GraphPsService_Stub rpc_stub = - getServiceStub(get_cmd_channel(server_index)); + // PsService_Stub rpc_stub(GetCmdChannel(server_index)); + GraphPsService_Stub rpc_stub = getServiceStub(GetCmdChannel(server_index)); closure->cntl(request_idx)->set_log_id(butil::gettimeofday_ms()); rpc_stub.service(closure->cntl(request_idx), closure->request(request_idx), closure->response(request_idx), closure); @@ -506,8 +500,8 @@ std::future GraphBrpcClient::random_sample_nodes( closure->request(0)->set_client_id(_client_id); closure->request(0)->add_params((char *)&sample_size, sizeof(int)); ; - // PsService_Stub rpc_stub(get_cmd_channel(server_index)); - GraphPsService_Stub rpc_stub = getServiceStub(get_cmd_channel(server_index)); + // PsService_Stub rpc_stub(GetCmdChannel(server_index)); + GraphPsService_Stub rpc_stub = getServiceStub(GetCmdChannel(server_index)); closure->cntl(0)->set_log_id(butil::gettimeofday_ms()); rpc_stub.service(closure->cntl(0), closure->request(0), closure->response(0), closure); @@ -541,8 +535,7 @@ std::future GraphBrpcClient::load_graph_split_config( closure->request(server_index)->set_table_id(table_id); closure->request(server_index)->set_client_id(_client_id); closure->request(server_index)->add_params(path); - GraphPsService_Stub rpc_stub = - getServiceStub(get_cmd_channel(server_index)); + GraphPsService_Stub rpc_stub = getServiceStub(GetCmdChannel(server_index)); closure->cntl(server_index)->set_log_id(butil::gettimeofday_ms()); rpc_stub.service(closure->cntl(server_index), closure->request(server_index), @@ -581,8 +574,7 @@ std::future GraphBrpcClient::use_neighbors_sample_cache( closure->request(server_index) ->add_params((char *)&size_limit, sizeof(size_t)); closure->request(server_index)->add_params((char *)&ttl, sizeof(size_t)); - GraphPsService_Stub rpc_stub = - getServiceStub(get_cmd_channel(server_index)); + GraphPsService_Stub rpc_stub = getServiceStub(GetCmdChannel(server_index)); closure->cntl(server_index)->set_log_id(butil::gettimeofday_ms()); rpc_stub.service(closure->cntl(server_index), closure->request(server_index), @@ -624,8 +616,8 @@ std::future GraphBrpcClient::pull_graph_list( closure->request(0)->add_params((char *)&start, sizeof(int)); closure->request(0)->add_params((char *)&size, sizeof(int)); closure->request(0)->add_params((char *)&step, sizeof(int)); - // PsService_Stub rpc_stub(get_cmd_channel(server_index)); - GraphPsService_Stub rpc_stub = getServiceStub(get_cmd_channel(server_index)); + // PsService_Stub rpc_stub(GetCmdChannel(server_index)); + GraphPsService_Stub rpc_stub = getServiceStub(GetCmdChannel(server_index)); closure->cntl(0)->set_log_id(butil::gettimeofday_ms()); rpc_stub.service(closure->cntl(0), closure->request(0), closure->response(0), closure); @@ -717,8 +709,7 @@ std::future GraphBrpcClient::set_node_feat( closure->request(request_idx) ->add_params(set_feature.c_str(), set_feature.size()); - GraphPsService_Stub rpc_stub = - getServiceStub(get_cmd_channel(server_index)); + GraphPsService_Stub rpc_stub = getServiceStub(GetCmdChannel(server_index)); closure->cntl(request_idx)->set_log_id(butil::gettimeofday_ms()); rpc_stub.service(closure->cntl(request_idx), closure->request(request_idx), closure->response(request_idx), closure); @@ -727,10 +718,10 @@ std::future GraphBrpcClient::set_node_feat( return fut; } -int32_t GraphBrpcClient::initialize() { +int32_t GraphBrpcClient::Initialize() { // set_shard_num(_config.shard_num()); - BrpcPsClient::initialize(); - server_size = get_server_nums(); + BrpcPsClient::Initialize(); + server_size = GetServerNums(); graph_service = NULL; local_channel = NULL; return 0; diff --git a/paddle/fluid/distributed/ps/service/graph_brpc_client.h b/paddle/fluid/distributed/ps/service/graph_brpc_client.h index e2b8a518615dc511a726c4be104cb03900dd2e9a..d1d3c95260df4849f301d31d44185a3e755f7428 100644 --- a/paddle/fluid/distributed/ps/service/graph_brpc_client.h +++ b/paddle/fluid/distributed/ps/service/graph_brpc_client.h @@ -97,12 +97,12 @@ class GraphBrpcClient : public BrpcPsClient { std::string path); virtual std::future remove_graph_node( uint32_t table_id, std::vector& node_id_list); - virtual int32_t initialize(); + virtual int32_t Initialize(); int get_shard_num() { return shard_num; } void set_shard_num(int shard_num) { this->shard_num = shard_num; } int get_server_index_by_id(int64_t id); void set_local_channel(int index) { - this->local_channel = get_cmd_channel(index); + this->local_channel = GetCmdChannel(index); } void set_local_graph_service(GraphBrpcService* graph_service) { this->graph_service = graph_service; diff --git a/paddle/fluid/distributed/ps/service/graph_brpc_server.cc b/paddle/fluid/distributed/ps/service/graph_brpc_server.cc index 20a55e4d11983dad37b9e2e7845923dded881d3b..21e590997b178832178732f620cee64c0e68dbc0 100644 --- a/paddle/fluid/distributed/ps/service/graph_brpc_server.cc +++ b/paddle/fluid/distributed/ps/service/graph_brpc_server.cc @@ -33,7 +33,7 @@ namespace distributed { return -1; \ } -int32_t GraphBrpcServer::initialize() { +int32_t GraphBrpcServer::Initialize() { auto &service_config = _config.downpour_server_param().service_param(); if (!service_config.has_service_class()) { LOG(ERROR) << "miss service_class in ServerServiceParameter"; @@ -48,7 +48,7 @@ int32_t GraphBrpcServer::initialize() { } _service.reset(service); - if (service->configure(this) != 0 || service->initialize() != 0) { + if (service->Configure(this) != 0 || service->Initialize() != 0) { LOG(ERROR) << "service initialize failed, service_name:" << service_config.service_class(); return -1; @@ -61,11 +61,11 @@ int32_t GraphBrpcServer::initialize() { return 0; } -brpc::Channel *GraphBrpcServer::get_cmd_channel(size_t server_index) { +brpc::Channel *GraphBrpcServer::GetCmdChannel(size_t server_index) { return _pserver_channels[server_index].get(); } -uint64_t GraphBrpcServer::start(const std::string &ip, uint32_t port) { +uint64_t GraphBrpcServer::Start(const std::string &ip, uint32_t port) { std::unique_lock lock(mutex_); std::string ip_port = ip + ":" + std::to_string(port); @@ -73,20 +73,20 @@ uint64_t GraphBrpcServer::start(const std::string &ip, uint32_t port) { brpc::ServerOptions options; int num_threads = std::thread::hardware_concurrency(); - auto trainers = _environment->get_trainers(); + auto trainers = _environment->GetTrainers(); options.num_threads = trainers > num_threads ? trainers : num_threads; if (_server.Start(ip_port.c_str(), &options) != 0) { LOG(ERROR) << "GraphBrpcServer start failed, ip_port=" << ip_port; return 0; } - _environment->registe_ps_server(ip, port, _rank); + _environment->RegistePsServer(ip, port, _rank); return 0; } int32_t GraphBrpcServer::build_peer2peer_connection(int rank) { this->rank = rank; - auto _env = environment(); + auto _env = Environment(); brpc::ChannelOptions options; options.protocol = "baidu_std"; options.timeout_ms = 500000; @@ -94,7 +94,7 @@ int32_t GraphBrpcServer::build_peer2peer_connection(int rank) { options.connect_timeout_ms = 10000; options.max_retry = 3; - std::vector server_list = _env->get_ps_servers(); + std::vector server_list = _env->GetPsServers(); _pserver_channels.resize(server_list.size()); std::ostringstream os; std::string server_ip_port; @@ -172,19 +172,18 @@ int32_t GraphBrpcService::remove_graph_node(Table *table, ((GraphTable *)table)->remove_graph_node(node_ids); return 0; } -int32_t GraphBrpcServer::port() { return _server.listen_address().port; } +int32_t GraphBrpcServer::Port() { return _server.listen_address().port; } -int32_t GraphBrpcService::initialize() { +int32_t GraphBrpcService::Initialize() { _is_initialize_shard_info = false; - _service_handler_map[PS_STOP_SERVER] = &GraphBrpcService::stop_server; - _service_handler_map[PS_LOAD_ONE_TABLE] = &GraphBrpcService::load_one_table; - _service_handler_map[PS_LOAD_ALL_TABLE] = &GraphBrpcService::load_all_table; + _service_handler_map[PS_STOP_SERVER] = &GraphBrpcService::StopServer; + _service_handler_map[PS_LOAD_ONE_TABLE] = &GraphBrpcService::LoadOneTable; + _service_handler_map[PS_LOAD_ALL_TABLE] = &GraphBrpcService::LoadAllTable; - _service_handler_map[PS_PRINT_TABLE_STAT] = - &GraphBrpcService::print_table_stat; - _service_handler_map[PS_BARRIER] = &GraphBrpcService::barrier; - _service_handler_map[PS_START_PROFILER] = &GraphBrpcService::start_profiler; - _service_handler_map[PS_STOP_PROFILER] = &GraphBrpcService::stop_profiler; + _service_handler_map[PS_PRINT_TABLE_STAT] = &GraphBrpcService::PrintTableStat; + _service_handler_map[PS_BARRIER] = &GraphBrpcService::Barrier; + _service_handler_map[PS_START_PROFILER] = &GraphBrpcService::StartProfiler; + _service_handler_map[PS_STOP_PROFILER] = &GraphBrpcService::StopProfiler; _service_handler_map[PS_PULL_GRAPH_LIST] = &GraphBrpcService::pull_graph_list; _service_handler_map[PS_GRAPH_SAMPLE_NEIGHBORS] = @@ -207,21 +206,21 @@ int32_t GraphBrpcService::initialize() { _service_handler_map[PS_GRAPH_LOAD_GRAPH_SPLIT_CONFIG] = &GraphBrpcService::load_graph_split_config; // shard初始化,server启动后才可从env获取到server_list的shard信息 - initialize_shard_info(); + InitializeShardInfo(); return 0; } -int32_t GraphBrpcService::initialize_shard_info() { +int32_t GraphBrpcService::InitializeShardInfo() { if (!_is_initialize_shard_info) { std::lock_guard guard(_initialize_shard_mutex); if (_is_initialize_shard_info) { return 0; } - server_size = _server->environment()->get_ps_servers().size(); - auto &table_map = *(_server->table()); + server_size = _server->Environment()->GetPsServers().size(); + auto &table_map = *(_server->GetTable()); for (auto itr : table_map) { - itr.second->set_shard(_rank, server_size); + itr.second->SetShard(_rank, server_size); } _is_initialize_shard_info = true; } @@ -241,7 +240,7 @@ void GraphBrpcService::service(google::protobuf::RpcController *cntl_base, response->set_err_code(0); response->set_err_msg(""); - auto *table = _server->table(request->table_id()); + auto *table = _server->GetTable(request->table_id()); brpc::Controller *cntl = static_cast(cntl_base); auto itr = _service_handler_map.find(request->cmd_id()); if (itr == _service_handler_map.end()) { @@ -261,7 +260,7 @@ void GraphBrpcService::service(google::protobuf::RpcController *cntl_base, } } -int32_t GraphBrpcService::barrier(Table *table, const PsRequestMessage &request, +int32_t GraphBrpcService::Barrier(Table *table, const PsRequestMessage &request, PsResponseMessage &response, brpc::Controller *cntl) { CHECK_TABLE_EXIST(table, request, response) @@ -275,16 +274,16 @@ int32_t GraphBrpcService::barrier(Table *table, const PsRequestMessage &request, auto trainer_id = request.client_id(); auto barrier_type = request.params(0); - table->barrier(trainer_id, barrier_type); + table->Barrier(trainer_id, barrier_type); return 0; } -int32_t GraphBrpcService::print_table_stat(Table *table, - const PsRequestMessage &request, - PsResponseMessage &response, - brpc::Controller *cntl) { +int32_t GraphBrpcService::PrintTableStat(Table *table, + const PsRequestMessage &request, + PsResponseMessage &response, + brpc::Controller *cntl) { CHECK_TABLE_EXIST(table, request, response) - std::pair ret = table->print_table_stat(); + std::pair ret = table->PrintTableStat(); paddle::framework::BinaryArchive ar; ar << ret.first << ret.second; std::string table_info(ar.Buffer(), ar.Length()); @@ -293,10 +292,10 @@ int32_t GraphBrpcService::print_table_stat(Table *table, return 0; } -int32_t GraphBrpcService::load_one_table(Table *table, - const PsRequestMessage &request, - PsResponseMessage &response, - brpc::Controller *cntl) { +int32_t GraphBrpcService::LoadOneTable(Table *table, + const PsRequestMessage &request, + PsResponseMessage &response, + brpc::Controller *cntl) { CHECK_TABLE_EXIST(table, request, response) if (request.params_size() < 2) { set_response_code( @@ -304,20 +303,20 @@ int32_t GraphBrpcService::load_one_table(Table *table, "PsRequestMessage.datas is requeired at least 2 for path & load_param"); return -1; } - if (table->load(request.params(0), request.params(1)) != 0) { + if (table->Load(request.params(0), request.params(1)) != 0) { set_response_code(response, -1, "table load failed"); return -1; } return 0; } -int32_t GraphBrpcService::load_all_table(Table *table, - const PsRequestMessage &request, - PsResponseMessage &response, - brpc::Controller *cntl) { - auto &table_map = *(_server->table()); +int32_t GraphBrpcService::LoadAllTable(Table *table, + const PsRequestMessage &request, + PsResponseMessage &response, + brpc::Controller *cntl) { + auto &table_map = *(_server->GetTable()); for (auto &itr : table_map) { - if (load_one_table(itr.second.get(), request, response, cntl) != 0) { + if (LoadOneTable(itr.second.get(), request, response, cntl) != 0) { LOG(ERROR) << "load table[" << itr.first << "] failed"; return -1; } @@ -325,13 +324,13 @@ int32_t GraphBrpcService::load_all_table(Table *table, return 0; } -int32_t GraphBrpcService::stop_server(Table *table, - const PsRequestMessage &request, - PsResponseMessage &response, - brpc::Controller *cntl) { +int32_t GraphBrpcService::StopServer(Table *table, + const PsRequestMessage &request, + PsResponseMessage &response, + brpc::Controller *cntl) { GraphBrpcServer *p_server = (GraphBrpcServer *)_server; std::thread t_stop([p_server]() { - p_server->stop(); + p_server->Stop(); LOG(INFO) << "Server Stoped"; }); p_server->export_cv()->notify_all(); @@ -339,19 +338,19 @@ int32_t GraphBrpcService::stop_server(Table *table, return 0; } -int32_t GraphBrpcService::stop_profiler(Table *table, - const PsRequestMessage &request, - PsResponseMessage &response, - brpc::Controller *cntl) { +int32_t GraphBrpcService::StopProfiler(Table *table, + const PsRequestMessage &request, + PsResponseMessage &response, + brpc::Controller *cntl) { platform::DisableProfiler(platform::EventSortingKey::kDefault, string::Sprintf("server_%s_profile", _rank)); return 0; } -int32_t GraphBrpcService::start_profiler(Table *table, - const PsRequestMessage &request, - PsResponseMessage &response, - brpc::Controller *cntl) { +int32_t GraphBrpcService::StartProfiler(Table *table, + const PsRequestMessage &request, + PsResponseMessage &response, + brpc::Controller *cntl) { platform::EnableProfiler(platform::ProfilerState::kCPU); return 0; } @@ -475,7 +474,7 @@ int32_t GraphBrpcService::sample_neighbors_across_multi_servers( std::vector server2request(server_size, -1); std::vector local_id; std::vector local_query_idx; - size_t rank = get_rank(); + size_t rank = GetRank(); for (int query_idx = 0; query_idx < node_num; ++query_idx) { int server_index = ((GraphTable *)table)->get_server_index_by_id(node_data[query_idx]); @@ -589,9 +588,9 @@ int32_t GraphBrpcService::sample_neighbors_across_multi_servers( closure->request(request_idx) ->add_params((char *)&need_weight, sizeof(bool)); PsService_Stub rpc_stub( - ((GraphBrpcServer *)get_server())->get_cmd_channel(server_index)); + ((GraphBrpcServer *)GetServer())->GetCmdChannel(server_index)); // GraphPsService_Stub rpc_stub = - // getServiceStub(get_cmd_channel(server_index)); + // getServiceStub(GetCmdChannel(server_index)); closure->cntl(request_idx)->set_log_id(butil::gettimeofday_ms()); rpc_stub.service(closure->cntl(request_idx), closure->request(request_idx), closure->response(request_idx), closure); diff --git a/paddle/fluid/distributed/ps/service/graph_brpc_server.h b/paddle/fluid/distributed/ps/service/graph_brpc_server.h index a978d97b296b0a529a121fcfb9723639421d1e5e..caf728701b289e9629a726aaee84d1f4744ff8c0 100644 --- a/paddle/fluid/distributed/ps/service/graph_brpc_server.h +++ b/paddle/fluid/distributed/ps/service/graph_brpc_server.h @@ -31,10 +31,10 @@ class GraphBrpcServer : public PSServer { GraphBrpcServer() {} virtual ~GraphBrpcServer() {} PsBaseService *get_service() { return _service.get(); } - virtual uint64_t start(const std::string &ip, uint32_t port); + virtual uint64_t Start(const std::string &ip, uint32_t port); virtual int32_t build_peer2peer_connection(int rank); - virtual brpc::Channel *get_cmd_channel(size_t server_index); - virtual int32_t stop() { + virtual brpc::Channel *GetCmdChannel(size_t server_index); + virtual int32_t Stop() { std::unique_lock lock(mutex_); if (stoped_) return 0; stoped_ = true; @@ -43,12 +43,12 @@ class GraphBrpcServer : public PSServer { _server.Join(); return 0; } - int32_t port(); + int32_t Port(); std::condition_variable *export_cv() { return &cv_; } private: - virtual int32_t initialize(); + virtual int32_t Initialize(); mutable std::mutex mutex_; std::condition_variable cv_; bool stoped_ = false; @@ -66,7 +66,7 @@ typedef int32_t (GraphBrpcService::*serviceFunc)( class GraphBrpcService : public PsBaseService { public: - virtual int32_t initialize() override; + virtual int32_t Initialize() override; virtual void service(::google::protobuf::RpcController *controller, const PsRequestMessage *request, @@ -75,7 +75,7 @@ class GraphBrpcService : public PsBaseService { protected: std::unordered_map _service_handler_map; - int32_t initialize_shard_info(); + int32_t InitializeShardInfo(); int32_t pull_graph_list(Table *table, const PsRequestMessage &request, PsResponseMessage &response, brpc::Controller *cntl); int32_t graph_random_sample_neighbors(Table *table, @@ -100,21 +100,21 @@ class GraphBrpcService : public PsBaseService { int32_t remove_graph_node(Table *table, const PsRequestMessage &request, PsResponseMessage &response, brpc::Controller *cntl); - int32_t barrier(Table *table, const PsRequestMessage &request, + int32_t Barrier(Table *table, const PsRequestMessage &request, PsResponseMessage &response, brpc::Controller *cntl); - int32_t load_one_table(Table *table, const PsRequestMessage &request, - PsResponseMessage &response, brpc::Controller *cntl); - int32_t load_all_table(Table *table, const PsRequestMessage &request, - PsResponseMessage &response, brpc::Controller *cntl); - int32_t stop_server(Table *table, const PsRequestMessage &request, - PsResponseMessage &response, brpc::Controller *cntl); - int32_t start_profiler(Table *table, const PsRequestMessage &request, - PsResponseMessage &response, brpc::Controller *cntl); - int32_t stop_profiler(Table *table, const PsRequestMessage &request, + int32_t LoadOneTable(Table *table, const PsRequestMessage &request, + PsResponseMessage &response, brpc::Controller *cntl); + int32_t LoadAllTable(Table *table, const PsRequestMessage &request, + PsResponseMessage &response, brpc::Controller *cntl); + int32_t StopServer(Table *table, const PsRequestMessage &request, + PsResponseMessage &response, brpc::Controller *cntl); + int32_t StartProfiler(Table *table, const PsRequestMessage &request, PsResponseMessage &response, brpc::Controller *cntl); + int32_t StopProfiler(Table *table, const PsRequestMessage &request, + PsResponseMessage &response, brpc::Controller *cntl); - int32_t print_table_stat(Table *table, const PsRequestMessage &request, - PsResponseMessage &response, brpc::Controller *cntl); + int32_t PrintTableStat(Table *table, const PsRequestMessage &request, + PsResponseMessage &response, brpc::Controller *cntl); int32_t sample_neighbors_across_multi_servers(Table *table, const PsRequestMessage &request, diff --git a/paddle/fluid/distributed/ps/service/ps_client.cc b/paddle/fluid/distributed/ps/service/ps_client.cc index 27f2d88fdd9fa02262c3e89a618bbe697d0a542e..f7df99ec13cdf128e22e5ed9a702f9e1983186ad 100644 --- a/paddle/fluid/distributed/ps/service/ps_client.cc +++ b/paddle/fluid/distributed/ps/service/ps_client.cc @@ -25,7 +25,7 @@ REGISTER_PSCORE_CLASS(PSClient, BrpcPsClient); REGISTER_PSCORE_CLASS(PSClient, PsLocalClient); REGISTER_PSCORE_CLASS(PSClient, GraphBrpcClient); -int32_t PSClient::configure( +int32_t PSClient::Configure( const PSParameter &config, const std::map> ®ions, PSEnvironment &env, size_t client_id) { @@ -51,10 +51,10 @@ int32_t PSClient::configure( _table_accessors[work_param.downpour_table_param(i).table_id()].reset( accessor); } - return initialize(); + return Initialize(); } -PSClient *PSClientFactory::create(const PSParameter &ps_config) { +PSClient *PSClientFactory::Create(const PSParameter &ps_config) { const auto &config = ps_config.server_param(); if (!config.has_downpour_server_param()) { LOG(ERROR) << "miss downpour_server_param in ServerParameter"; @@ -81,7 +81,7 @@ PSClient *PSClientFactory::create(const PSParameter &ps_config) { return NULL; } - TableManager::instance().initialize(); + TableManager::Instance().Initialize(); VLOG(3) << "Create PSClient[" << service_param.client_class() << "] success"; return client; } diff --git a/paddle/fluid/distributed/ps/service/ps_client.h b/paddle/fluid/distributed/ps/service/ps_client.h index 83d2aba1db44564a3314e6d1f9b07ebd2730b85e..6f27b0eb046245c722100bcfdb2e6b89d92ec488 100644 --- a/paddle/fluid/distributed/ps/service/ps_client.h +++ b/paddle/fluid/distributed/ps/service/ps_client.h @@ -26,7 +26,6 @@ #include "paddle/fluid/distributed/ps/service/sendrecv.pb.h" #include "paddle/fluid/distributed/ps/table/accessor.h" #include "paddle/fluid/distributed/ps/table/graph/graph_node.h" -#include "paddle/fluid/distributed/ps/table/table.h" #include "paddle/fluid/platform/timer.h" namespace paddle { @@ -60,41 +59,6 @@ class PSClientClosure : public google::protobuf::Closure { std::vector>> _promises; }; -struct LoadSaveContext { - int table_id; - std::string epoch; - std::string mode; -}; - -enum TrainingMode { Async = 0, Sync = 1, Geo = 3 }; - -enum TrainingPhase { Init = 0, Train = 1, Save = 2 }; - -// enum ValueType { -// Sparse = 0, -// Dense = 1 -// }; - -struct PushContext { - const uint64_t *keys; - const float **push_values; - const Region *push_dense_values; -}; - -struct RequestContext { - int table; - TrainingMode training_mode; // 1 for async, 2 for geo, 3 for sync - TrainingPhase training_phase; // 1 for init, 2 for train - ValueType value_type; // 1 for sparse, 2 for dense - uint64_t *keys; - float **sparse_values; // for sparse values - Region *dense_values; // for dense values - PushContext push_context; - size_t num; - bool is_training; - void *callback; -}; - class PSClient { public: PSClient() {} @@ -102,41 +66,37 @@ class PSClient { PSClient(PSClient &&) = delete; PSClient(const PSClient &) = delete; - virtual int32_t configure( // NOLINT + virtual int32_t Configure( // NOLINT const PSParameter &config, const std::map> ®ions, PSEnvironment &_env, size_t client_id) final; // NOLINT - virtual int32_t create_client2client_connection( - int pserver_timeout_ms, int pserver_connect_timeout_ms, - int max_retry) = 0; + virtual int32_t CreateClient2ClientConnection(int pserver_timeout_ms, + int pserver_connect_timeout_ms, + int max_retry) = 0; // 触发table数据退场 - virtual std::future shrink(uint32_t table_id, + virtual std::future Shrink(uint32_t table_id, const std::string threshold) = 0; // 全量table进行数据load - virtual std::future load(const std::string &epoch, + virtual std::future Load(const std::string &epoch, const std::string &mode) = 0; // 指定table数据load - virtual std::future load(uint32_t table_id, const std::string &epoch, + virtual std::future Load(uint32_t table_id, const std::string &epoch, const std::string &mode) = 0; - // context配置load选项 - virtual std::future Load(const LoadSaveContext &load_context) = 0; // 全量table数据save value_accessor根据mode,可能有不同的save条件 - virtual std::future save(const std::string &epoch, + virtual std::future Save(const std::string &epoch, const std::string &mode) = 0; // 指定table数据save value_accessor根据mode,可能有不同的save条件 - virtual std::future save(uint32_t table_id, const std::string &epoch, + virtual std::future Save(uint32_t table_id, const std::string &epoch, const std::string &mode) = 0; - virtual std::future Save(const LoadSaveContext &save_context) = 0; - // 清空table数据 - virtual std::future clear() = 0; - virtual std::future clear(uint32_t table_id) = 0; + virtual std::future Clear() = 0; + virtual std::future Clear(uint32_t table_id) = 0; // pull dense的参数部分,并分块填充到本地网络参数中 // start和num用于拉取部分参数 @@ -145,23 +105,19 @@ class PSClient { // sender聚集同一区块的请求,累计多个填充buffer // server将参数区块中配置的某一维提取返回 // 返回数据解包后填充到累计的多个buffer中 - virtual std::future pull_dense(Region *regions, size_t region_num, - size_t table_id) = 0; // 保留 - - virtual std::future Push(RequestContext &push_context) = 0; + virtual std::future PullDense(Region *regions, size_t region_num, + size_t table_id) = 0; // 保留 // firstly push dense param for parameter server // this is neccessary because dense weight initialized in trainer on cold // start - virtual std::future push_dense_param(const Region *regions, - size_t region_num, - size_t table_id) = 0; - - virtual std::future push_dense(const Region *regions, - size_t region_num, - size_t table_id) = 0; + virtual std::future PushDenseParam(const Region *regions, + size_t region_num, + size_t table_id) = 0; - virtual std::future Pull(RequestContext &pull_context) = 0; + virtual std::future PushDense(const Region *regions, + size_t region_num, + size_t table_id) = 0; // 使用keys进行pull请求,结果填充values // keys和values的个数均为num个,每个value占用select_size空间 @@ -169,15 +125,14 @@ class PSClient { // 整合多个线程请求的keys,聚集并分散发送到server // 返回结果后,遍历buffer并对values赋值 // is_training 用于区分请求是训练/预测,server端对于特征和准入会有不同的处理. - virtual std::future pull_sparse(float **select_values, - size_t table_id, - const uint64_t *keys, size_t num, - bool is_training) = 0; - - virtual std::future pull_sparse_param(float **select_values, - size_t table_id, - const uint64_t *keys, - size_t num, bool is_training) { + virtual std::future PullSparse(float **select_values, + size_t table_id, const uint64_t *keys, + size_t num, bool is_training) = 0; + + virtual std::future PullSparseParam(float **select_values, + size_t table_id, + const uint64_t *keys, size_t num, + bool is_training) { VLOG(0) << "Did not implement"; std::promise promise; std::future fut = promise.get_future(); @@ -185,10 +140,10 @@ class PSClient { return fut; } - virtual ::std::future pull_sparse_ptr(char **select_values, - size_t table_id, - const uint64_t *keys, - size_t num) { + virtual ::std::future PullSparsePtr(char **select_values, + size_t table_id, + const uint64_t *keys, + size_t num) { VLOG(0) << "Did not implement"; std::promise promise; std::future fut = promise.get_future(); @@ -196,38 +151,38 @@ class PSClient { return fut; } - virtual std::future print_table_stat(uint32_t table_id) = 0; + virtual std::future PrintTableStat(uint32_t table_id) = 0; // 确保所有积攒中的请求都发起发送 - virtual std::future flush() = 0; + virtual std::future Flush() = 0; // server优雅退出 - virtual std::future stop_server() = 0; + virtual std::future StopServer() = 0; // server profilera - virtual std::future start_profiler() = 0; - virtual std::future stop_profiler() = 0; + virtual std::future StartProfiler() = 0; + virtual std::future StopProfiler() = 0; - virtual std::future barrier(size_t table_id, + virtual std::future Barrier(size_t table_id, uint32_t barrier_type) = 0; - virtual std::future pull_geo_param(size_t table_id, - std::vector *values, - std::vector *keys, - int pserver_idx) = 0; + virtual std::future PullGeoParam(size_t table_id, + std::vector *values, + std::vector *keys, + int pserver_idx) = 0; - virtual std::future push_global_step(int table_id, - int64_t *total_send_data, - void *done) = 0; + virtual std::future PushGlobalStep(int table_id, + int64_t *total_send_data, + void *done) = 0; // recv table from server and save it in LodTensor - virtual int32_t recv_and_save_table(const uint64_t table_id, - const std::string &path) = 0; + virtual int32_t RecvAndSaveTable(const uint64_t table_id, + const std::string &path) = 0; - virtual void finalize_worker() = 0; + virtual void FinalizeWorker() = 0; // client to client, 消息发送 - virtual std::future send_client2client_msg(int msg_type, - int to_client_id, - const std::string &msg) { + virtual std::future SendClient2ClientMsg(int msg_type, + int to_client_id, + const std::string &msg) { VLOG(0) << "Did not implement"; std::promise promise; std::future fut = promise.get_future(); @@ -238,13 +193,13 @@ class PSClient { // client2client消息处理,std::function ret (msg_type, from_client_id, msg) typedef std::function MsgHandlerFunc; - virtual int registe_client2client_msg_handler(int msg_type, - MsgHandlerFunc handler) { + virtual int RegisteClient2ClientMsgHandler(int msg_type, + MsgHandlerFunc handler) { _msg_handler_map[msg_type] = handler; return 0; } - virtual int handle_client2client_msg(int msg_type, int from_client_id, - const std::string &msg) { + virtual int HandleClient2ClientMsg(int msg_type, int from_client_id, + const std::string &msg) { auto itr = _msg_handler_map.find(msg_type); if (itr == _msg_handler_map.end()) { LOG(WARNING) << "unknown client2client_msg type:" << msg_type; @@ -253,7 +208,7 @@ class PSClient { return itr->second(msg_type, from_client_id, msg); } - virtual ValueAccessor *table_accessor(size_t table_id) { + virtual ValueAccessor *GetTableAccessor(size_t table_id) { auto itr = _table_accessors.find(table_id); if (itr == _table_accessors.end()) { return NULL; @@ -261,31 +216,31 @@ class PSClient { return itr->second.get(); } - virtual size_t get_server_nums() = 0; + virtual size_t GetServerNums() = 0; - virtual std::future push_dense_raw_gradient( - int table_id, float *total_send_data, size_t total_send_data_size, - void *done) = 0; + virtual std::future PushDenseRawGradient(int table_id, + float *total_send_data, + size_t total_send_data_size, + void *done) = 0; - virtual std::future push_sparse_raw_gradient( + virtual std::future PushSparseRawGradient( size_t table_id, const uint64_t *keys, const float **update_values, size_t num, void *done) = 0; - virtual std::future push_sparse_raw_gradient_partial( + virtual std::future PushSparseRawGradientPartial( size_t table_id, const uint64_t *keys, const float **update_values, uint32_t num, void *done, int pserver_idx) = 0; - virtual std::future push_sparse_param(size_t table_id, - const uint64_t *keys, - const float **update_values, - size_t num, void *done) = 0; - virtual std::future push_sparse(size_t table_id, - const uint64_t *keys, - const float **update_values, - size_t num) = 0; + virtual std::future PushSparseParam(size_t table_id, + const uint64_t *keys, + const float **update_values, + size_t num, void *done) = 0; + virtual std::future PushSparse(size_t table_id, const uint64_t *keys, + const float **update_values, + size_t num) = 0; protected: - virtual int32_t initialize() = 0; + virtual int32_t Initialize() = 0; size_t _client_id; PSParameter _config; std::map> @@ -333,7 +288,7 @@ REGISTER_PSCORE_REGISTERER(PSClient); class PSClientFactory { public: - static PSClient *create(const PSParameter &config); + static PSClient *Create(const PSParameter &config); }; } // namespace distributed } // namespace paddle diff --git a/paddle/fluid/distributed/ps/service/ps_local_client.cc b/paddle/fluid/distributed/ps/service/ps_local_client.cc index dbf47f0df41161069da4b430e1aedd26bf30bca3..bb8ba223d828eacfecae674a4d73040077e11df0 100644 --- a/paddle/fluid/distributed/ps/service/ps_local_client.cc +++ b/paddle/fluid/distributed/ps/service/ps_local_client.cc @@ -19,166 +19,91 @@ namespace paddle { namespace distributed { -int32_t PsLocalClient::initialize() { +int32_t PsLocalClient::Initialize() { const auto& downpour_param = _config.server_param().downpour_server_param(); - TableManager::instance().initialize(); + TableManager::Instance().Initialize(); for (size_t i = 0; i < downpour_param.downpour_table_param_size(); ++i) { auto* table = CREATE_PSCORE_CLASS( Table, downpour_param.downpour_table_param(i).table_class()); - table->set_shard(0, 1); - table->initialize(downpour_param.downpour_table_param(i), + table->SetShard(0, 1); + table->Initialize(downpour_param.downpour_table_param(i), _config.fs_client_param()); _table_map[downpour_param.downpour_table_param(i).table_id()].reset(table); } return 0; } -::std::future PsLocalClient::shrink(uint32_t table_id, +::std::future PsLocalClient::Shrink(uint32_t table_id, const std::string threshold) { // TODO return done(); } -::std::future PsLocalClient::load(const std::string& epoch, +::std::future PsLocalClient::Load(const std::string& epoch, const std::string& mode) { // TODO for (auto& it : _table_map) { - load(it.first, epoch, mode); + Load(it.first, epoch, mode); } return done(); } -::std::future PsLocalClient::load(uint32_t table_id, +::std::future PsLocalClient::Load(uint32_t table_id, const std::string& epoch, const std::string& mode) { // TODO - auto* table_ptr = table(table_id); - table_ptr->load(epoch, mode); + auto* table_ptr = GetTable(table_id); + table_ptr->Load(epoch, mode); return done(); } -std::future PsLocalClient::Load(const LoadSaveContext& load_context) { - if (load_context.table_id < 0) { - for (auto& it : _table_map) { - load(it.first, load_context.epoch, load_context.mode); - } - return done(); - } else { - auto* table_ptr = table(load_context.table_id); - table_ptr->load(load_context.epoch, load_context.mode); - return done(); - } -} - -::std::future PsLocalClient::save(const std::string& epoch, +::std::future PsLocalClient::Save(const std::string& epoch, const std::string& mode) { // TODO for (auto& it : _table_map) { - save(it.first, epoch, mode); + Save(it.first, epoch, mode); } return done(); } -::std::future PsLocalClient::save(uint32_t table_id, +::std::future PsLocalClient::Save(uint32_t table_id, const std::string& epoch, const std::string& mode) { // TODO - auto* table_ptr = table(table_id); - table_ptr->flush(); - table_ptr->save(epoch, mode); + auto* table_ptr = GetTable(table_id); + table_ptr->Flush(); + table_ptr->Save(epoch, mode); return done(); } -::std::future PsLocalClient::Save( - const LoadSaveContext& save_context) { - if (save_context.table_id < 0) { - for (auto& it : _table_map) { - save(it.first, save_context.epoch, save_context.mode); - } - return done(); - } else { - auto* table_ptr = table(save_context.table_id); - table_ptr->flush(); - table_ptr->save(save_context.epoch, save_context.mode); - return done(); - } -} - -::std::future PsLocalClient::clear() { +::std::future PsLocalClient::Clear() { // TODO return done(); } -::std::future PsLocalClient::clear(uint32_t table_id) { +::std::future PsLocalClient::Clear(uint32_t table_id) { // TODO return done(); } -::std::future PsLocalClient::flush() { +::std::future PsLocalClient::Flush() { // no need return done(); } -::std::future PsLocalClient::stop_server() { +::std::future PsLocalClient::StopServer() { // no need return done(); } -::std::future PsLocalClient::Pull(RequestContext& pull_context) { - if (pull_context.value_type == Dense) { // pull dense - Region* dense_region = reinterpret_cast(pull_context.dense_values); - pull_dense(dense_region, pull_context.num, pull_context.table); - } else { // pull sparse - // uint64_t* keys = reinterpret_cast(pull_context.keys); - // char** select_values = - // reinterpret_cast(pull_context.sparse_values); - size_t table_id = pull_context.table; - size_t num = pull_context.num; - pull_sparse_ptr(reinterpret_cast(pull_context.sparse_values), - table_id, pull_context.keys, num); - } -} +::std::future PsLocalClient::PullDense(Region* regions, + size_t region_num, + size_t table_id) { + auto* accessor = GetTableAccessor(table_id); + auto* table_ptr = GetTable(table_id); -::std::future PsLocalClient::Push(RequestContext& push_context) { - if (push_context.value_type == Dense) { // push dense - if (push_context.training_phase == Init) { - const Region* regions = push_context.push_context.push_dense_values; - size_t region_num = push_context.num; - push_dense_param(regions, region_num, push_context.table); - } else { - if (push_context.training_mode == Geo) { // geo - float* total_send_data = - reinterpret_cast(push_context.dense_values); - size_t total_send_data_size = push_context.num; - push_dense_raw_gradient(push_context.table, total_send_data, - total_send_data_size, push_context.callback); - } else { // async and sync - const Region* regions = push_context.push_context.push_dense_values; - size_t region_num = push_context.num; - push_dense(regions, region_num, push_context.table); - } - } - } else { // push sparse - if (push_context.training_mode == Async) { - const uint64_t* keys = push_context.push_context.keys; - const float** update_values = push_context.push_context.push_values; - size_t table_id = push_context.table; - size_t num = push_context.num; - push_sparse(table_id, keys, update_values, num); - } else { - // TODO - } - } -} - -::std::future PsLocalClient::pull_dense(Region* regions, - size_t region_num, - size_t table_id) { - auto* accessor = table_accessor(table_id); - auto* table_ptr = table(table_id); + uint32_t num_per_shard = DenseDimPerShard(accessor->GetTableInfo(FEA_DIM), 1); - uint32_t num_per_shard = - dense_dim_per_shard(accessor->GetTableInfo(FEA_DIM), 1); std::vector region_buffer; region_buffer.resize(num_per_shard); - table_ptr->pull_dense(region_buffer.data(), region_buffer.size()); + table_ptr->PullDense(region_buffer.data(), region_buffer.size()); size_t region_idx = 0; size_t region_data_idx = 0; @@ -213,48 +138,49 @@ std::future PsLocalClient::Load(const LoadSaveContext& load_context) { return done(); } -::std::future PsLocalClient::push_dense_param(const Region* regions, - size_t region_num, - size_t table_id) { - auto* accessor = table_accessor(table_id); - auto* table_ptr = table(table_id); +::std::future PsLocalClient::PushDenseParam(const Region* regions, + size_t region_num, + size_t table_id) { + auto* accessor = GetTableAccessor(table_id); + auto* table_ptr = GetTable(table_id); std::vector region_buffer; - region_buffer.resize(dense_dim_per_shard(accessor->GetTableInfo(FEA_DIM), 1), - 0); + region_buffer.resize(DenseDimPerShard(accessor->GetTableInfo(FEA_DIM), 1), 0); + for (size_t i = 0, offset = 0; i < region_num; ++i) { uint32_t data_num = regions[i].size / sizeof(float); memcpy(region_buffer.data() + offset, regions[i].data, regions[i].size); offset += data_num; } - // table_ptr->push_dense_param(region_buffer.data(), region_buffer.size()); + // table_ptr->PushDenseParam(region_buffer.data(), region_buffer.size()); return done(); } -::std::future PsLocalClient::push_dense_raw_gradient( +::std::future PsLocalClient::PushDenseRawGradient( int table_id, float* total_send_data, size_t total_send_data_size, void* callback) { VLOG(1) << "wxx push_dense_raw_gradient"; PSClientClosure* closure = reinterpret_cast(callback); - auto* table_ptr = table(table_id); + auto* table_ptr = GetTable(table_id); - table_ptr->push_dense(total_send_data, total_send_data_size); + table_ptr->PushDense(total_send_data, total_send_data_size); delete closure; return done(); } -::std::future PsLocalClient::push_dense(const Region* regions, - size_t region_num, - size_t table_id) { - auto* accessor = table_accessor(table_id); - auto* table_ptr = table(table_id); +::std::future PsLocalClient::PushDense(const Region* regions, + size_t region_num, + size_t table_id) { + auto* accessor = GetTableAccessor(table_id); + auto* table_ptr = GetTable(table_id); std::vector region_buffer; - region_buffer.resize(dense_dim_per_shard(accessor->GetTableInfo(FEA_DIM), 1)); + region_buffer.resize(DenseDimPerShard(accessor->GetTableInfo(FEA_DIM), 1)); + size_t data_size = region_buffer.size(); for (size_t i = 0, offset = 0; i < region_num; ++i) { uint32_t data_num = regions[i].size / sizeof(float); @@ -267,12 +193,12 @@ std::future PsLocalClient::Load(const LoadSaveContext& load_context) { offset += data_num; } - table_ptr->push_dense(region_buffer.data(), region_buffer.size()); + table_ptr->PushDense(region_buffer.data(), region_buffer.size()); return done(); } -//::std::future PsLocalClient::pull_sparse(float** select_values, +//::std::future PsLocalClient::PullSparse(float** select_values, // size_t table_id, // const uint64_t* keys, // size_t num) { @@ -282,14 +208,14 @@ std::future PsLocalClient::Load(const LoadSaveContext& load_context) { // // auto local_timer = // // std::make_shared("pslib_downpour_client_pull_sparse_local"); // //将key拆分到各shard请求,并记录原始对应value指针 -// auto* accessor = table_accessor(table_id); -// auto* table_ptr = table(table_id); +// auto* accessor = GetTableAccessor(table_id); +// auto* table_ptr = GetTable(table_id); // size_t value_size = accessor->select_size(); // -// // table_ptr->pull_sparse(keys, num); +// // table_ptr->PullSparse(keys, num); // std::vector res_data; // res_data.resize(num * value_size / sizeof(float)); -// table_ptr->pull_sparse(res_data.data(), keys, num); +// table_ptr->PullSparse(res_data.data(), keys, num); // // memcpy(select_values[0], res_data->data(), res_data->size() * // // sizeof(float)); // size_t offset = 0; @@ -302,43 +228,43 @@ std::future PsLocalClient::Load(const LoadSaveContext& load_context) { // return done(); //} -::std::future PsLocalClient::pull_sparse_ptr(char** select_values, - size_t table_id, - const uint64_t* keys, - size_t num) { +::std::future PsLocalClient::PullSparsePtr(char** select_values, + size_t table_id, + const uint64_t* keys, + size_t num) { // FIXME // auto timer = // std::make_shared("pslib_downpour_client_pull_sparse"); // auto local_timer = // std::make_shared("pslib_downpour_client_pull_sparse_local"); //将key拆分到各shard请求,并记录原始对应value指针 - auto* table_ptr = table(table_id); + auto* table_ptr = GetTable(table_id); - table_ptr->pull_sparse_ptr(select_values, keys, num); + table_ptr->PullSparsePtr(select_values, keys, num); return done(); } -::std::future PsLocalClient::push_sparse_raw_gradient( +::std::future PsLocalClient::PushSparseRawGradient( size_t table_id, const uint64_t* keys, const float** update_values, size_t num, void* callback) { PSClientClosure* closure = reinterpret_cast(callback); - auto* accessor = table_accessor(table_id); - auto* table_ptr = table(table_id); + auto* accessor = GetTableAccessor(table_id); + auto* table_ptr = GetTable(table_id); - table_ptr->push_sparse(keys, update_values, num); + table_ptr->PushSparse(keys, update_values, num); delete closure; return done(); } -::std::future PsLocalClient::push_sparse(size_t table_id, - const uint64_t* keys, - const float** update_values, - size_t num) { - auto* accessor = table_accessor(table_id); - auto* table_ptr = table(table_id); +::std::future PsLocalClient::PushSparse(size_t table_id, + const uint64_t* keys, + const float** update_values, + size_t num) { + auto* accessor = GetTableAccessor(table_id); + auto* table_ptr = GetTable(table_id); - table_ptr->push_sparse(keys, update_values, num); + table_ptr->PushSparse(keys, update_values, num); return done(); } } diff --git a/paddle/fluid/distributed/ps/service/ps_local_client.h b/paddle/fluid/distributed/ps/service/ps_local_client.h index 83ca558e3d2cb1f62235cda06c221b0d9367b043..439ecf79f2f808c98edc9ca1c0ea8403f9266bc8 100644 --- a/paddle/fluid/distributed/ps/service/ps_local_client.h +++ b/paddle/fluid/distributed/ps/service/ps_local_client.h @@ -26,54 +26,46 @@ class PsLocalClient : public PSClient { public: PsLocalClient() {} virtual ~PsLocalClient() { _running = false; } - virtual int32_t create_client2client_connection(int pslib_timeout_ms, - int pslib_connect_timeout_ms, - int max_retry) { + virtual int32_t CreateClient2ClientConnection(int pslib_timeout_ms, + int pslib_connect_timeout_ms, + int max_retry) { return 0; } - virtual ::std::future shrink(uint32_t table_id, + virtual ::std::future Shrink(uint32_t table_id, const std::string threshold) override; - virtual ::std::future load(const std::string& epoch, + virtual ::std::future Load(const std::string& epoch, const std::string& mode) override; - virtual ::std::future load(uint32_t table_id, + virtual ::std::future Load(uint32_t table_id, const std::string& epoch, const std::string& mode) override; - virtual std::future Load( - const LoadSaveContext& load_context) override; - virtual ::std::future save(const std::string& epoch, + virtual ::std::future Save(const std::string& epoch, const std::string& mode) override; - virtual ::std::future save(uint32_t table_id, + virtual ::std::future Save(uint32_t table_id, const std::string& epoch, const std::string& mode) override; - virtual std::future Save( - const LoadSaveContext& save_context) override; - virtual ::std::future clear() override; - virtual ::std::future clear(uint32_t table_id) override; + virtual ::std::future Clear() override; + virtual ::std::future Clear(uint32_t table_id) override; - virtual ::std::future stop_server() override; + virtual ::std::future StopServer() override; - virtual void finalize_worker() override {} - virtual ::std::future pull_dense(Region* regions, size_t region_num, - size_t table_id); + virtual void FinalizeWorker() override {} + virtual ::std::future PullDense(Region* regions, size_t region_num, + size_t table_id); - virtual ::std::future Pull(RequestContext& pull_context) override; + virtual ::std::future PushDense(const Region* regions, + size_t region_num, size_t table_id); - virtual ::std::future Push(RequestContext& push_context) override; + virtual ::std::future PushDenseParam(const Region* regions, + size_t region_num, + size_t table_id); - virtual ::std::future push_dense(const Region* regions, - size_t region_num, size_t table_id); - - virtual ::std::future push_dense_param(const Region* regions, - size_t region_num, - size_t table_id); - - virtual ::std::future pull_sparse(float** select_values, - size_t table_id, - const uint64_t* keys, size_t num, - bool is_training) { + virtual ::std::future PullSparse(float** select_values, + size_t table_id, + const uint64_t* keys, size_t num, + bool is_training) { std::promise prom; std::future fut = prom.get_future(); prom.set_value(0); @@ -81,26 +73,26 @@ class PsLocalClient : public PSClient { return fut; } - virtual ::std::future pull_sparse_ptr(char** select_values, - size_t table_id, - const uint64_t* keys, - size_t num); + virtual ::std::future PullSparsePtr(char** select_values, + size_t table_id, + const uint64_t* keys, + size_t num); - virtual ::std::future print_table_stat(uint32_t table_id) { + virtual ::std::future PrintTableStat(uint32_t table_id) { std::promise prom; std::future fut = prom.get_future(); prom.set_value(0); return fut; } - virtual ::std::future push_sparse(size_t table_id, - const uint64_t* keys, - const float** update_values, - size_t num); + virtual ::std::future PushSparse(size_t table_id, + const uint64_t* keys, + const float** update_values, + size_t num); - virtual ::std::future flush(); + virtual ::std::future Flush(); // server profilera - virtual std::future start_profiler() { + virtual std::future StartProfiler() { std::promise prom; std::future fut = prom.get_future(); prom.set_value(0); @@ -108,7 +100,7 @@ class PsLocalClient : public PSClient { return fut; }; - virtual std::future stop_profiler() { + virtual std::future StopProfiler() { std::promise prom; std::future fut = prom.get_future(); prom.set_value(0); @@ -116,7 +108,7 @@ class PsLocalClient : public PSClient { return fut; } - virtual std::future barrier(size_t table_id, uint32_t barrier_type) { + virtual std::future Barrier(size_t table_id, uint32_t barrier_type) { std::promise prom; std::future fut = prom.get_future(); prom.set_value(0); @@ -124,10 +116,10 @@ class PsLocalClient : public PSClient { return fut; } - virtual std::future pull_geo_param(size_t table_id, - std::vector* values, - std::vector* keys, - int pserver_idx) { + virtual std::future PullGeoParam(size_t table_id, + std::vector* values, + std::vector* keys, + int pserver_idx) { std::promise prom; std::future fut = prom.get_future(); prom.set_value(0); @@ -135,9 +127,9 @@ class PsLocalClient : public PSClient { return fut; } - virtual std::future push_global_step(int table_id, - int64_t* total_send_data, - void* done) { + virtual std::future PushGlobalStep(int table_id, + int64_t* total_send_data, + void* done) { std::promise prom; std::future fut = prom.get_future(); prom.set_value(0); @@ -146,12 +138,12 @@ class PsLocalClient : public PSClient { } // recv table from server and save it in LodTensor - virtual int32_t recv_and_save_table(const uint64_t table_id, - const std::string& path) { + virtual int32_t RecvAndSaveTable(const uint64_t table_id, + const std::string& path) { return 0; } - virtual ::std::future send_client2client_msg( + virtual ::std::future SendClient2ClientMsg( int msg_type, int to_client_id, const std::string& msg) override { std::promise prom; std::future fut = prom.get_future(); @@ -159,17 +151,18 @@ class PsLocalClient : public PSClient { return fut; } - virtual size_t get_server_nums() { return 1; } + virtual size_t GetServerNums() { return 1; } - virtual std::future push_dense_raw_gradient( - int table_id, float* total_send_data, size_t total_send_data_size, - void* callback) override; + virtual std::future PushDenseRawGradient(int table_id, + float* total_send_data, + size_t total_send_data_size, + void* callback) override; - virtual std::future push_sparse_raw_gradient( + virtual std::future PushSparseRawGradient( size_t table_id, const uint64_t* keys, const float** update_values, size_t num, void* callback) override; - virtual std::future push_sparse_raw_gradient_partial( + virtual std::future PushSparseRawGradientPartial( size_t table_id, const uint64_t* keys, const float** update_values, uint32_t num, void* done, int pserver_idx) override { std::promise prom; @@ -179,11 +172,11 @@ class PsLocalClient : public PSClient { return fut; } - virtual std::future push_sparse_param(size_t table_id, - const uint64_t* keys, - const float** update_values, - size_t num, - void* done) override { + virtual std::future PushSparseParam(size_t table_id, + const uint64_t* keys, + const float** update_values, + size_t num, + void* done) override { std::promise prom; std::future fut = prom.get_future(); prom.set_value(0); @@ -192,7 +185,7 @@ class PsLocalClient : public PSClient { } private: - virtual int32_t initialize() override; + virtual int32_t Initialize() override; std::future done() { std::shared_ptr> prom = @@ -202,16 +195,16 @@ class PsLocalClient : public PSClient { return fut; } - inline uint32_t dense_dim_per_shard(uint32_t dense_dim_total, - uint32_t shard_num) { + inline uint32_t DenseDimPerShard(uint32_t dense_dim_total, + uint32_t shard_num) { return dense_dim_total / shard_num + 1; } - inline std::unordered_map>* table() { + inline std::unordered_map>* GetTable() { return &_table_map; } - inline Table* table(size_t table_id) { + inline Table* GetTable(size_t table_id) { auto itr = _table_map.find(table_id); if (itr != _table_map.end()) { return itr->second.get(); diff --git a/paddle/fluid/distributed/ps/service/ps_local_server.h b/paddle/fluid/distributed/ps/service/ps_local_server.h index 31b52126fc5767b445dfb605ff46b3fbc63c620c..c09f8585b659d64951ee6c522e82a8d83f43c6e6 100644 --- a/paddle/fluid/distributed/ps/service/ps_local_server.h +++ b/paddle/fluid/distributed/ps/service/ps_local_server.h @@ -25,17 +25,17 @@ class PsLocalServer : public PSServer { public: PsLocalServer() {} virtual ~PsLocalServer() {} - virtual uint64_t start() { return 0; } - virtual uint64_t start(const std::string &ip, uint32_t port) { return 0; } - virtual int32_t stop() { return 0; } - virtual int32_t configure( + virtual uint64_t Start() { return 0; } + virtual uint64_t Start(const std::string &ip, uint32_t port) { return 0; } + virtual int32_t Stop() { return 0; } + virtual int32_t Configure( const PSParameter &config, PSEnvironment &env, size_t server_rank, const std::vector &server_sub_program = {}) { return 0; } private: - virtual int32_t initialize() { return 0; } + virtual int32_t Initialize() { return 0; } }; } } diff --git a/paddle/fluid/distributed/ps/service/ps_service/graph_py_service.cc b/paddle/fluid/distributed/ps/service/ps_service/graph_py_service.cc index c8be0f797109078509eeced53920845ac4c51684..92dfeb6818a2872e201bf5f4b30d584dee3cee9f 100644 --- a/paddle/fluid/distributed/ps/service/ps_service/graph_py_service.cc +++ b/paddle/fluid/distributed/ps/service/ps_service/graph_py_service.cc @@ -70,7 +70,7 @@ void GraphPyService::set_up(std::string ips_str, int shard_num, port_list.push_back(ip_and_port[1]); uint32_t port = stoul(ip_and_port[1]); auto ph_host = paddle::distributed::PSHost(ip_and_port[0], port, index); - host_sign_list.push_back(ph_host.serialize_to_string()); + host_sign_list.push_back(ph_host.SerializeToString()); index++; } } @@ -83,11 +83,11 @@ void GraphPyClient::start_client() { paddle::distributed::PaddlePSEnvironment _ps_env; auto servers_ = host_sign_list.size(); _ps_env = paddle::distributed::PaddlePSEnvironment(); - _ps_env.set_ps_servers(&host_sign_list, servers_); + _ps_env.SetPsServers(&host_sign_list, servers_); worker_ptr = std::shared_ptr( (paddle::distributed::GraphBrpcClient*) - paddle::distributed::PSClientFactory::create(worker_proto)); - worker_ptr->configure(worker_proto, dense_regions, _ps_env, client_id); + paddle::distributed::PSClientFactory::Create(worker_proto)); + worker_ptr->Configure(worker_proto, dense_regions, _ps_env, client_id); worker_ptr->set_shard_num(get_shard_num()); } void GraphPyServer::start_server(bool block) { @@ -96,17 +96,17 @@ void GraphPyServer::start_server(bool block) { ::paddle::distributed::PSParameter server_proto = this->GetServerProto(); auto _ps_env = paddle::distributed::PaddlePSEnvironment(); - _ps_env.set_ps_servers(&this->host_sign_list, - this->host_sign_list.size()); // test + _ps_env.SetPsServers(&this->host_sign_list, + this->host_sign_list.size()); // test pserver_ptr = std::shared_ptr( (paddle::distributed::GraphBrpcServer*) - paddle::distributed::PSServerFactory::create(server_proto)); + paddle::distributed::PSServerFactory::Create(server_proto)); VLOG(0) << "pserver-ptr created "; std::vector empty_vec; framework::ProgramDesc empty_prog; empty_vec.push_back(empty_prog); - pserver_ptr->configure(server_proto, _ps_env, rank, empty_vec); - pserver_ptr->start(ip, port); + pserver_ptr->Configure(server_proto, _ps_env, rank, empty_vec); + pserver_ptr->Start(ip, port); pserver_ptr->build_peer2peer_connection(rank); std::condition_variable* cv_ = pserver_ptr->export_cv(); if (block) { @@ -246,7 +246,7 @@ void GraphPyClient::load_edge_file(std::string name, std::string filepath, VLOG(0) << "loadding data with type " << name << " from " << filepath; uint32_t table_id = this->table_id_map[name]; auto status = - get_ps_client()->load(table_id, std::string(filepath), params); + get_ps_client()->Load(table_id, std::string(filepath), params); status.wait(); } } @@ -285,7 +285,7 @@ void GraphPyClient::load_node_file(std::string name, std::string filepath) { if (this->table_id_map.count(name)) { uint32_t table_id = this->table_id_map[name]; auto status = - get_ps_client()->load(table_id, std::string(filepath), params); + get_ps_client()->Load(table_id, std::string(filepath), params); status.wait(); } } @@ -396,13 +396,13 @@ std::vector GraphPyClient::pull_graph_list(std::string name, return res; } -void GraphPyClient::stop_server() { +void GraphPyClient::StopServer() { VLOG(0) << "going to stop server"; std::unique_lock lock(mutex_); if (stoped_) return; - auto status = this->worker_ptr->stop_server(); + auto status = this->worker_ptr->StopServer(); if (status.get() == 0) stoped_ = true; } -void GraphPyClient::finalize_worker() { this->worker_ptr->finalize_worker(); } +void GraphPyClient::FinalizeWorker() { this->worker_ptr->FinalizeWorker(); } } } diff --git a/paddle/fluid/distributed/ps/service/ps_service/graph_py_service.h b/paddle/fluid/distributed/ps/service/ps_service/graph_py_service.h index 85707137c1800ed9486148584ce22a78c52a47fd..19f34dad80745715dc9a43a8a6962f2b5a7897d2 100644 --- a/paddle/fluid/distributed/ps/service/ps_service/graph_py_service.h +++ b/paddle/fluid/distributed/ps/service/ps_service/graph_py_service.h @@ -123,7 +123,7 @@ class GraphPyServer : public GraphPyService { set_rank(rank); GraphPyService::set_up(ips_str, shard_num, node_types, edge_types); } - int get_rank() { return rank; } + int GetRank() { return rank; } void set_rank(int rank) { this->rank = rank; } void start_server(bool block = true); @@ -154,8 +154,8 @@ class GraphPyClient : public GraphPyService { (paddle::distributed::GraphBrpcService*)server.get_ps_server() ->get_service()); } - void stop_server(); - void finalize_worker(); + void StopServer(); + void FinalizeWorker(); void load_edge_file(std::string name, std::string filepath, bool reverse); void load_node_file(std::string name, std::string filepath); void clear_nodes(std::string name); diff --git a/paddle/fluid/distributed/ps/service/ps_service/service.cc b/paddle/fluid/distributed/ps/service/ps_service/service.cc index 73793d2f9bd0ec8c5b485830059a730bb8d8559a..9c3a06c2212e6a797132a25ebec775abb05aeed8 100644 --- a/paddle/fluid/distributed/ps/service/ps_service/service.cc +++ b/paddle/fluid/distributed/ps/service/ps_service/service.cc @@ -46,7 +46,7 @@ paddle::distributed::PSParameter load_from_prototxt( return param; } -void PSCore::init_gflag(const std::string& gflags) { +void PSCore::InitGFlag(const std::string& gflags) { VLOG(3) << "Init With Gflags:" << gflags; std::vector flags = paddle::string::split_string(gflags); if (flags.size() < 1) { @@ -65,67 +65,67 @@ void PSCore::init_gflag(const std::string& gflags) { ::GFLAGS_NAMESPACE::ParseCommandLineFlags(¶ms_cnt, ¶ms_ptr, true); } -int PSCore::init_server( +int PSCore::InitServer( const std::string& dist_desc, const std::vector* host_sign_list, int node_num, int index, int trainers, const std::vector& server_sub_program) { google::protobuf::TextFormat::ParseFromString(dist_desc, &_ps_param); - init_gflag(_ps_param.init_gflags()); + InitGFlag(_ps_param.init_gflags()); _ps_env = paddle::distributed::PaddlePSEnvironment(); - _ps_env.set_ps_servers(host_sign_list, node_num); - _ps_env.set_trainers(trainers); + _ps_env.SetPsServers(host_sign_list, node_num); + _ps_env.SetTrainers(trainers); int ret = 0; _server_ptr = std::shared_ptr( - paddle::distributed::PSServerFactory::create(_ps_param)); - ret = _server_ptr->configure(_ps_param, _ps_env, index, server_sub_program); + paddle::distributed::PSServerFactory::Create(_ps_param)); + ret = _server_ptr->Configure(_ps_param, _ps_env, index, server_sub_program); CHECK(ret == 0) << "failed to configure server"; return ret; } -int PSCore::init_worker( +int PSCore::InitWorker( const std::string& dist_desc, const std::map>& regions, const std::vector* host_sign_list, int node_num, int index) { google::protobuf::TextFormat::ParseFromString(dist_desc, &_ps_param); - init_gflag(_ps_param.init_gflags()); + InitGFlag(_ps_param.init_gflags()); _ps_env = paddle::distributed::PaddlePSEnvironment(); - _ps_env.set_ps_servers(host_sign_list, node_num); + _ps_env.SetPsServers(host_sign_list, node_num); int ret = 0; - VLOG(1) << "PSCore::init_worker"; + VLOG(1) << "PSCore::InitWorker"; auto* communicator = Communicator::GetInstance(); - ret = communicator->GetPsClient()->configure(_ps_param, regions, _ps_env, + ret = communicator->GetPsClient()->Configure(_ps_param, regions, _ps_env, index); communicator->Start(); return ret; } -std::vector PSCore::get_client_info() { - return _ps_env.get_client_info(); +std::vector PSCore::GetClientInfo() { + return _ps_env.GetClientInfo(); } -int PSCore::create_client2client_connection(int pserver_timeout_ms, - int pserver_connect_timeout_ms, - int max_retry) { - int ret = _worker_ptr->create_client2client_connection( +int PSCore::CreateClient2ClientConnection(int pserver_timeout_ms, + int pserver_connect_timeout_ms, + int max_retry) { + int ret = _worker_ptr->CreateClient2ClientConnection( pserver_timeout_ms, pserver_connect_timeout_ms, max_retry); return ret; } -uint64_t PSCore::run_server(const std::string& ip, uint32_t port) { - return _server_ptr->start(ip, port); +uint64_t PSCore::RunServer(const std::string& ip, uint32_t port) { + return _server_ptr->Start(ip, port); } -int PSCore::finalize_worker() { - _worker_ptr->finalize_worker(); +int PSCore::FinalizeWorker() { + _worker_ptr->FinalizeWorker(); return 0; } -int PSCore::stop_server() { - auto stop_status = _worker_ptr->stop_server(); +int PSCore::StopServer() { + auto stop_status = _worker_ptr->StopServer(); stop_status.wait(); return 0; } -paddle::distributed::PSParameter* PSCore::get_param() { return &_ps_param; } +paddle::distributed::PSParameter* PSCore::GetParam() { return &_ps_param; } } // namespace distributed } // namespace paddle diff --git a/paddle/fluid/distributed/ps/service/ps_service/service.h b/paddle/fluid/distributed/ps/service/ps_service/service.h index 202c2407f15ae9fbf5087b55a65f6acd2957ddc5..112fdc3e141838d5a3d20d8e43bebc3a04dc5d6b 100644 --- a/paddle/fluid/distributed/ps/service/ps_service/service.h +++ b/paddle/fluid/distributed/ps/service/ps_service/service.h @@ -42,31 +42,31 @@ class PSCore { explicit PSCore() {} virtual ~PSCore() {} - virtual int init_server( + virtual int InitServer( const std::string& dist_desc, const std::vector* host_sign_list, int node_num, int index, int trainers, const std::vector& server_sub_program = {}); - virtual int init_worker( + virtual int InitWorker( const std::string& dist_desc, const std::map>& regions, const std::vector* host_sign_list, int node_num, int index); - virtual uint64_t run_server(const std::string& ip, uint32_t port); - virtual int stop_server(); - virtual int finalize_worker(); - virtual std::vector get_client_info(); - virtual int create_client2client_connection(int pserver_timeout_ms, - int pserver_connect_timeout_ms, - int max_retry); + virtual uint64_t RunServer(const std::string& ip, uint32_t port); + virtual int StopServer(); + virtual int FinalizeWorker(); + virtual std::vector GetClientInfo(); + virtual int CreateClient2ClientConnection(int pserver_timeout_ms, + int pserver_connect_timeout_ms, + int max_retry); std::shared_ptr _server_ptr; // pointer to server std::shared_ptr _worker_ptr; // pointer to worker - virtual paddle::distributed::PSParameter* get_param(); + virtual paddle::distributed::PSParameter* GetParam(); private: - void init_gflag(const std::string& gflags); + void InitGFlag(const std::string& gflags); paddle::distributed::PSParameter _ps_param; paddle::distributed::PaddlePSEnvironment _ps_env; }; diff --git a/paddle/fluid/distributed/ps/service/server.cc b/paddle/fluid/distributed/ps/service/server.cc index 893f671359e40ce632185c78bade16404d23afc0..65f7ae821cef1ace041711cb1bab9794935c6dfa 100644 --- a/paddle/fluid/distributed/ps/service/server.cc +++ b/paddle/fluid/distributed/ps/service/server.cc @@ -29,7 +29,7 @@ REGISTER_PSCORE_CLASS(PsBaseService, BrpcPsService); REGISTER_PSCORE_CLASS(PSServer, GraphBrpcServer); REGISTER_PSCORE_CLASS(PsBaseService, GraphBrpcService); -PSServer *PSServerFactory::create(const PSParameter &ps_config) { +PSServer *PSServerFactory::Create(const PSParameter &ps_config) { const auto &config = ps_config.server_param(); if (!config.has_downpour_server_param()) { @@ -56,18 +56,18 @@ PSServer *PSServerFactory::create(const PSParameter &ps_config) { << service_param.server_class(); return NULL; } - TableManager::instance().initialize(); + TableManager::Instance().Initialize(); return server; } -int32_t PSServer::configure( +int32_t PSServer::Configure( const PSParameter &config, PSEnvironment &env, size_t server_rank, const std::vector &server_sub_program) { scope_.reset(new framework::Scope()); _config = config.server_param(); _rank = server_rank; _environment = &env; - size_t shard_num = env.get_ps_servers().size(); + size_t shard_num = env.GetPsServers().size(); const auto &downpour_param = _config.downpour_server_param(); @@ -87,21 +87,21 @@ int32_t PSServer::configure( global_step_table = downpour_param.downpour_table_param(i).table_id(); } - table->set_program_env(scope_.get(), place_, &server_sub_program); - table->set_shard(_rank, shard_num); - table->initialize(downpour_param.downpour_table_param(i), + table->SetProgramEnv(scope_.get(), place_, &server_sub_program); + table->SetShard(_rank, shard_num); + table->Initialize(downpour_param.downpour_table_param(i), config.fs_client_param()); _table_map[downpour_param.downpour_table_param(i).table_id()].reset(table); } if (barrier_table != UINT32_MAX) { - _table_map[barrier_table]->set_table_map(&_table_map); + _table_map[barrier_table]->SetTableMap(&_table_map); } if (global_step_table != UINT32_MAX) { - _table_map[global_step_table]->set_table_map(&_table_map); + _table_map[global_step_table]->SetTableMap(&_table_map); } - return initialize(); + return Initialize(); } } // namespace distributed } // namespace paddle diff --git a/paddle/fluid/distributed/ps/service/server.h b/paddle/fluid/distributed/ps/service/server.h index d2804405b41989cbd9b5bed0afaf6d481d0658db..5da819326b05260630c22d73f074d75915130211 100644 --- a/paddle/fluid/distributed/ps/service/server.h +++ b/paddle/fluid/distributed/ps/service/server.h @@ -65,19 +65,19 @@ class PSServer { PSServer(PSServer &&) = delete; PSServer(const PSServer &) = delete; - virtual int32_t configure( + virtual int32_t Configure( const PSParameter &config, PSEnvironment &env, size_t server_rank, const std::vector &server_sub_program = {}); - virtual uint64_t start(const std::string &ip, uint32_t port) = 0; - virtual int32_t stop() = 0; + virtual uint64_t Start(const std::string &ip, uint32_t port) = 0; + virtual int32_t Stop() = 0; - inline size_t rank() const { return _rank; } + inline size_t Rank() const { return _rank; } - inline PSEnvironment *environment() { return _environment; } + inline PSEnvironment *Environment() { return _environment; } - inline const ServerParameter *config() const { return &_config; } - inline Table *table(size_t table_id) { + inline const ServerParameter *Config() const { return &_config; } + inline Table *GetTable(size_t table_id) { auto itr = _table_map.find(table_id); if (itr != _table_map.end()) { return itr->second.get(); @@ -85,12 +85,12 @@ class PSServer { return NULL; } - inline std::unordered_map> *table() { + inline std::unordered_map> *GetTable() { return &_table_map; } protected: - virtual int32_t initialize() = 0; + virtual int32_t Initialize() = 0; protected: size_t _rank; @@ -129,11 +129,11 @@ class PsBaseService : public PsService { public: PsBaseService() : _rank(0), _server(NULL), _config(NULL) {} virtual ~PsBaseService() {} - virtual size_t get_rank() { return _rank; } - virtual int32_t configure(PSServer *server) { + virtual size_t GetRank() { return _rank; } + virtual int32_t Configure(PSServer *server) { _server = server; - _rank = _server->rank(); - _config = _server->config(); + _rank = _server->Rank(); + _config = _server->Config(); return 0; } virtual void service(::google::protobuf::RpcController *controller, @@ -148,8 +148,8 @@ class PsBaseService : public PsService { LOG(WARNING) << "Resonse err_code:" << err_code << " msg:" << err_msg; } - virtual int32_t initialize() = 0; - PSServer *get_server() { return _server; } + virtual int32_t Initialize() = 0; + PSServer *GetServer() { return _server; } protected: size_t _rank; @@ -160,7 +160,7 @@ REGISTER_PSCORE_REGISTERER(PsBaseService); class PSServerFactory { public: - static PSServer *create(const PSParameter &config); + static PSServer *Create(const PSParameter &config); }; } // namespace distributed } // namespace paddle diff --git a/paddle/fluid/distributed/ps/table/barrier_table.cc b/paddle/fluid/distributed/ps/table/barrier_table.cc index 25838e7ac2f047d9ff7bf20705459c6b1d60d26f..b9d0345313cc3728f828ffecd439f3890a436cb8 100644 --- a/paddle/fluid/distributed/ps/table/barrier_table.cc +++ b/paddle/fluid/distributed/ps/table/barrier_table.cc @@ -17,7 +17,7 @@ namespace paddle { namespace distributed { -int32_t BarrierTable::initialize() { +int32_t BarrierTable::Initialize() { auto trainers = _config.common().trainer_num(); trigger_.store(trainers); @@ -29,7 +29,7 @@ int32_t BarrierTable::initialize() { } // 0: send_barrier 1: recv_barrier 2: complete -int32_t BarrierTable::barrier(const uint32_t trainer_id, +int32_t BarrierTable::Barrier(const uint32_t trainer_id, const std::string barrier_type) { std::unique_lock lock(mutex_); @@ -56,7 +56,7 @@ int32_t BarrierTable::barrier(const uint32_t trainer_id, VLOG(1) << "barrier table optimize begin"; for (auto& x : *table_map_) { auto table = x.second; - table->pour(); + table->Pour(); } VLOG(1) << "barrier table optimize done"; @@ -66,7 +66,7 @@ int32_t BarrierTable::barrier(const uint32_t trainer_id, return 0; } -int32_t BarrierTable::set_table_map( +int32_t BarrierTable::SetTableMap( std::unordered_map>* table_map) { table_map_ = table_map; return 0; diff --git a/paddle/fluid/distributed/ps/table/common_dense_table.cc b/paddle/fluid/distributed/ps/table/common_dense_table.cc index caec575e33eef16384df34e08b305ceac0619af8..f0cb586e45190660c5180f6ffeaa371487ae7ee6 100644 --- a/paddle/fluid/distributed/ps/table/common_dense_table.cc +++ b/paddle/fluid/distributed/ps/table/common_dense_table.cc @@ -21,8 +21,8 @@ namespace distributed { int FLAGS_pslib_table_save_max_retry_dense = 3; -void CommonDenseTable::create_initializer(const std::string& attr, - const std::string& name) { +void CommonDenseTable::CreateInitializer(const std::string& attr, + const std::string& name) { auto slices = string::split_string(attr, "&"); if (slices[0] == "gaussian_random") { @@ -39,7 +39,7 @@ void CommonDenseTable::create_initializer(const std::string& attr, } } -int32_t CommonDenseTable::initialize() { +int32_t CommonDenseTable::Initialize() { _shards_task_pool.resize(task_pool_size_); for (int i = 0; i < _shards_task_pool.size(); ++i) { _shards_task_pool[i].reset(new ::ThreadPool(1)); @@ -49,12 +49,12 @@ int32_t CommonDenseTable::initialize() { VLOG(1) << "table " << _config.common().table_name() << " is sync: " << sync; _global_lr = new float(1.0); - initialize_value(); - initialize_optimizer(); + InitializeValue(); + InitializeOptimizer(); return 0; } -int32_t CommonDenseTable::initialize_value() { +int32_t CommonDenseTable::InitializeValue() { auto common = _config.common(); int size = static_cast(common.params().size()); values_.resize(size); @@ -70,7 +70,7 @@ int32_t CommonDenseTable::initialize_value() { auto& initializer = common.initializers()[x]; total_dim_ += dim; - create_initializer(initializer, varname); + CreateInitializer(initializer, varname); values_[x].resize(dim); names_index_[varname] = x; @@ -92,27 +92,27 @@ int32_t CommonDenseTable::initialize_value() { param_col_ids_.insert(param_col_ids_.begin() + 1, -1); } - VLOG(1) << "CommonDenseTable::initialize_value total dim: " << total_dim_ + VLOG(1) << "CommonDenseTable::InitializeValue total dim: " << total_dim_ << " fixed_len_params_dim: " << fixed_len_params_dim_; pull_reservoir_ = ReservoirValue(param_dim_); return 0; } -int32_t CommonDenseTable::initialize_optimizer() { +int32_t CommonDenseTable::InitializeOptimizer() { auto common = _config.common(); auto name = common.name(); auto attrs = common.attributes(); if (name == "sgd") { optimizer_ = std::make_shared(common, &values_); - optimizer_->set_global_lr(_global_lr); + optimizer_->SetGlobalLR(_global_lr); } else if (name == "adam") { optimizer_ = std::make_shared(common, &values_); - optimizer_->set_global_lr(_global_lr); + optimizer_->SetGlobalLR(_global_lr); } else if (name == "adam_d2sum") { optimizer_ = std::make_shared(common, &values_); - // optimizer_->set_global_lr(_global_lr); //no use + // optimizer_->SetGlobalLR(_global_lr); //no use } else if (name == "sum") { optimizer_ = std::make_shared(common, &values_); } else if (name == "summary") { @@ -124,34 +124,34 @@ int32_t CommonDenseTable::initialize_optimizer() { return 0; } -int32_t CommonDenseTable::set_global_lr(float* lr) { +int32_t CommonDenseTable::SetGlobalLR(float* lr) { _global_lr = lr; - optimizer_->set_global_lr(_global_lr); + optimizer_->SetGlobalLR(_global_lr); return 0; } int32_t CommonDenseTable::Pull(TableContext& context) { CHECK(context.value_type == Dense); float* pull_values = context.pull_context.values; - return pull_dense(pull_values, context.num); + return PullDense(pull_values, context.num); } int32_t CommonDenseTable::Push(TableContext& context) { CHECK(context.value_type == Dense); if (context.push_context.values != nullptr) { const float* values = context.push_context.values; - return push_dense(values, context.num); + return PushDense(values, context.num); } return 0; } -int32_t CommonDenseTable::pull_dense(float* pull_values, size_t num) { +int32_t CommonDenseTable::PullDense(float* pull_values, size_t num) { std::copy(values_[param_idx_].begin(), values_[param_idx_].end(), pull_values); return 0; } -int32_t CommonDenseTable::push_dense_param(const float* values, size_t num) { +int32_t CommonDenseTable::PushDenseParam(const float* values, size_t num) { PADDLE_ENFORCE_GE( num, param_dim_, paddle::platform::errors::InvalidArgument( @@ -160,14 +160,14 @@ int32_t CommonDenseTable::push_dense_param(const float* values, size_t num) { return 0; } -int32_t CommonDenseTable::pour() { +int32_t CommonDenseTable::Pour() { pull_reservoir_.avg(); - _push_dense(pull_reservoir_.values.data(), pull_reservoir_.values.size()); + _PushDense(pull_reservoir_.values.data(), pull_reservoir_.values.size()); pull_reservoir_.reset(); return 0; } -int32_t CommonDenseTable::push_dense(const float* values, size_t num) { +int32_t CommonDenseTable::PushDense(const float* values, size_t num) { if (sync) { std::future task = _shards_task_pool[0]->enqueue([this, &values]() -> int { @@ -176,12 +176,12 @@ int32_t CommonDenseTable::push_dense(const float* values, size_t num) { }); task.wait(); } else { - _push_dense(values, num); + _PushDense(values, num); } return 0; } -int32_t CommonDenseTable::_push_dense(const float* values, size_t num) { +int32_t CommonDenseTable::_PushDense(const float* values, size_t num) { PADDLE_ENFORCE_GE( num, param_dim_, paddle::platform::errors::InvalidArgument( @@ -195,7 +195,7 @@ int32_t CommonDenseTable::_push_dense(const float* values, size_t num) { [this, shard_id, &buckets, &values]() -> int { auto begin = buckets[shard_id]; auto end = buckets[shard_id + 1]; - optimizer_->update(values, param_dim_, begin, end); + optimizer_->Update(values, param_dim_, begin, end); return 0; }); } @@ -207,12 +207,12 @@ int32_t CommonDenseTable::_push_dense(const float* values, size_t num) { return 0; } -int32_t CommonDenseTable::load(const std::string& path, +int32_t CommonDenseTable::Load(const std::string& path, const std::string& param) { if (param_dim_ <= 0) { return 0; } - std::string table_path = table_dir(path); + std::string table_path = TableDir(path); auto file_list = _afs_client.list(table_path); std::sort(file_list.begin(), file_list.end()); for (auto ff : file_list) { @@ -314,7 +314,7 @@ int32_t CommonDenseTable::load(const std::string& path, return 0; } -int32_t CommonDenseTable::save(const std::string& path, +int32_t CommonDenseTable::Save(const std::string& path, const std::string& param) { int save_param = atoi(param.c_str()); uint32_t feasign_size; @@ -323,10 +323,10 @@ int32_t CommonDenseTable::save(const std::string& path, FsChannelConfig channel_config; if (_config.compress_in_save()) { channel_config.path = paddle::string::format_string( - "%s/part-%03d.gz", table_dir(path).c_str(), _shard_idx); + "%s/part-%03d.gz", TableDir(path).c_str(), _shard_idx); } else { channel_config.path = paddle::string::format_string( - "%s/part-%03d", table_dir(path).c_str(), _shard_idx); + "%s/part-%03d", TableDir(path).c_str(), _shard_idx); } _afs_client.remove(channel_config.path); channel_config.converter = _value_accesor->Converter(save_param).converter; diff --git a/paddle/fluid/distributed/ps/table/common_dense_table.h b/paddle/fluid/distributed/ps/table/common_dense_table.h index cad49a0a449c4735a74261574436a78789694d9b..8e4ff1ecaf487df8a8c86c1fc38b6e7bc441bb81 100644 --- a/paddle/fluid/distributed/ps/table/common_dense_table.h +++ b/paddle/fluid/distributed/ps/table/common_dense_table.h @@ -34,29 +34,29 @@ class CommonDenseTable : public DenseTable { public: CommonDenseTable() {} virtual ~CommonDenseTable() {} - int32_t initialize() override; - int32_t initialize_shard() override { return 0; } - virtual void create_initializer(const std::string& attr, - const std::string& name); - virtual int32_t initialize_value(); - virtual int32_t initialize_optimizer(); + int32_t Initialize() override; + int32_t InitializeShard() override { return 0; } + virtual void CreateInitializer(const std::string& attr, + const std::string& name); + virtual int32_t InitializeValue(); + virtual int32_t InitializeOptimizer(); virtual int32_t Pull(TableContext& context); virtual int32_t Push(TableContext& context); - int32_t pull_dense(float* pull_values, size_t num) override; - int32_t push_dense_param(const float* values, size_t num) override; - int32_t push_dense(const float* values, size_t num) override; - int32_t pour() override; - int32_t set_global_lr(float* lr) override; + int32_t PullDense(float* pull_values, size_t num) override; + int32_t PushDenseParam(const float* values, size_t num) override; + int32_t PushDense(const float* values, size_t num) override; + int32_t Pour() override; + int32_t SetGlobalLR(float* lr) override; - int32_t load(const std::string& path, const std::string& param) override; - int32_t save(const std::string& path, const std::string& param) override; + int32_t Load(const std::string& path, const std::string& param) override; + int32_t Save(const std::string& path, const std::string& param) override; - int32_t flush() override { return 0; } - int32_t shrink(const std::string& param) override { return 0; } - void clear() override { return; } + int32_t Flush() override { return 0; } + int32_t Shrink(const std::string& param) override { return 0; } + void Clear() override { return; } protected: - int32_t _push_dense(const float* values, size_t num); + int32_t _PushDense(const float* values, size_t num); private: const int task_pool_size_ = 10; diff --git a/paddle/fluid/distributed/ps/table/common_graph_table.cc b/paddle/fluid/distributed/ps/table/common_graph_table.cc index dcce46270d02671d2f26ab20e5ac127a72b1a636..7aab679954709b774abe426e79cebfbc0af43b3f 100644 --- a/paddle/fluid/distributed/ps/table/common_graph_table.cc +++ b/paddle/fluid/distributed/ps/table/common_graph_table.cc @@ -448,7 +448,7 @@ int32_t GraphTable::load_graph_split_config(const std::string &path) { return 0; } -int32_t GraphTable::load(const std::string &path, const std::string ¶m) { +int32_t GraphTable::Load(const std::string &path, const std::string ¶m) { bool load_edge = (param[0] == 'e'); bool load_node = (param[0] == 'n'); if (load_edge) { @@ -1066,11 +1066,11 @@ int32_t GraphTable::pull_graph_list(int start, int total_size, int32_t GraphTable::get_server_index_by_id(int64_t id) { return id % shard_num / shard_num_per_server; } -int32_t GraphTable::initialize(const TableParameter &config, +int32_t GraphTable::Initialize(const TableParameter &config, const FsClientParameter &fs_config) { LOG(INFO) << "in graphTable initialize"; _config = config; - if (initialize_accessor() != 0) { + if (InitializeAccessor() != 0) { LOG(WARNING) << "Table accessor initialize failed"; return -1; } @@ -1082,9 +1082,9 @@ int32_t GraphTable::initialize(const TableParameter &config, auto graph = config.graph_parameter(); shard_num = _config.shard_num(); LOG(INFO) << "in graphTable initialize over"; - return initialize(graph); + return Initialize(graph); } -int32_t GraphTable::initialize(const GraphParameter &graph) { +int32_t GraphTable::Initialize(const GraphParameter &graph) { #ifdef PADDLE_WITH_HETERPS if (graph.gpups_mode()) { gpups_mode = true; diff --git a/paddle/fluid/distributed/ps/table/common_graph_table.h b/paddle/fluid/distributed/ps/table/common_graph_table.h index 72600b42b828247a84a344a606cbc08fe8e2b3ef..035a3de3eba6320bb3703a7dd402db8109d9efbe 100644 --- a/paddle/fluid/distributed/ps/table/common_graph_table.h +++ b/paddle/fluid/distributed/ps/table/common_graph_table.h @@ -280,7 +280,7 @@ class ScaledLRU { } } auto status = - thread_pool->enqueue([this]() -> int { return shrink(); }); + thread_pool->enqueue([this]() -> int { return Shrink(); }); status.wait(); } }); @@ -298,7 +298,7 @@ class ScaledLRU { LRUResponse insert(size_t index, K *keys, V *data, size_t length) { return lru_pool[index].insert(keys, data, length); } - int shrink() { + int Shrink() { int node_size = 0; for (size_t i = 0; i < lru_pool.size(); i++) { node_size += lru_pool[i].node_size - lru_pool[i].remove_count; @@ -329,7 +329,7 @@ class ScaledLRU { if (diff != 0) { __sync_fetch_and_add(&global_count, diff); if (global_count > int(1.25 * size_limit)) { - thread_pool->enqueue([this]() -> int { return shrink(); }); + thread_pool->enqueue([this]() -> int { return Shrink(); }); } } } @@ -430,11 +430,11 @@ class GraphTable : public SparseTable { virtual int32_t get_nodes_ids_by_ranges( std::vector> ranges, std::vector &res); - virtual int32_t initialize() { return 0; } - virtual int32_t initialize(const TableParameter &config, + virtual int32_t Initialize() { return 0; } + virtual int32_t Initialize(const TableParameter &config, const FsClientParameter &fs_config); - virtual int32_t initialize(const GraphParameter &config); - int32_t load(const std::string &path, const std::string ¶m); + virtual int32_t Initialize(const GraphParameter &config); + int32_t Load(const std::string &path, const std::string ¶m); int32_t load_graph_split_config(const std::string &path); int32_t load_edges(const std::string &path, bool reverse); @@ -452,26 +452,25 @@ class GraphTable : public SparseTable { virtual int32_t Pull(TableContext &context) { return 0; } virtual int32_t Push(TableContext &context) { return 0; } - virtual int32_t pull_sparse(float *values, - const PullSparseValue &pull_value) { + virtual int32_t PullSparse(float *values, const PullSparseValue &pull_value) { return 0; } - virtual int32_t push_sparse(const uint64_t *keys, const float *values, - size_t num) { + virtual int32_t PushSparse(const uint64_t *keys, const float *values, + size_t num) { return 0; } virtual int32_t clear_nodes(); - virtual void clear() {} - virtual int32_t flush() { return 0; } - virtual int32_t shrink(const std::string ¶m) { return 0; } + virtual void Clear() {} + virtual int32_t Flush() { return 0; } + virtual int32_t Shrink(const std::string ¶m) { return 0; } //指定保存路径 - virtual int32_t save(const std::string &path, const std::string &converter) { + virtual int32_t Save(const std::string &path, const std::string &converter) { return 0; } - virtual int32_t initialize_shard() { return 0; } - virtual int32_t set_shard(size_t shard_idx, size_t server_num) { + virtual int32_t InitializeShard() { return 0; } + virtual int32_t SetShard(size_t shard_idx, size_t server_num) { _shard_idx = shard_idx; /* _shard_num is not used in graph_table, this following operation is for the diff --git a/paddle/fluid/distributed/ps/table/common_sparse_table.cc b/paddle/fluid/distributed/ps/table/common_sparse_table.cc index 1fc8adc2b92ebd79544ba518382faa59989337d3..6b3d3a6ea1584d10fd950c3260a770daf894b8b0 100644 --- a/paddle/fluid/distributed/ps/table/common_sparse_table.cc +++ b/paddle/fluid/distributed/ps/table/common_sparse_table.cc @@ -167,7 +167,7 @@ int64_t CommonSparseTable::LoadFromText( return 0; } -int32_t CommonSparseTable::initialize() { +int32_t CommonSparseTable::Initialize() { _shards_task_pool.resize(task_pool_size_); for (int i = 0; i < _shards_task_pool.size(); ++i) { _shards_task_pool[i].reset(new ::ThreadPool(1)); @@ -200,15 +200,15 @@ int32_t CommonSparseTable::initialize() { offset += dim; } - initialize_value(); - initialize_optimizer(); - initialize_recorder(); + InitializeValue(); + InitializeOptimizer(); + InitializeRecorder(); return 0; } -int32_t CommonSparseTable::initialize_recorder() { return 0; } +int32_t CommonSparseTable::InitializeRecorder() { return 0; } -int32_t CommonSparseTable::initialize_value() { +int32_t CommonSparseTable::InitializeValue() { auto common = _config.common(); shard_values_.reserve(task_pool_size_); @@ -223,18 +223,18 @@ int32_t CommonSparseTable::initialize_value() { return 0; } -int32_t CommonSparseTable::initialize_optimizer() { +int32_t CommonSparseTable::InitializeOptimizer() { auto common = _config.common(); auto name = common.name(); if (name == "sgd") { optimizer_ = std::make_shared(value_names_, value_dims_, value_offsets_, value_idx_); - optimizer_->set_global_lr(_global_lr); + optimizer_->SetGlobalLR(_global_lr); } else if (name == "adam") { optimizer_ = std::make_shared(value_names_, value_dims_, value_offsets_, value_idx_); - optimizer_->set_global_lr(_global_lr); + optimizer_->SetGlobalLR(_global_lr); } else if (name == "sum") { optimizer_ = std::make_shared(value_names_, value_dims_, value_offsets_, value_idx_); @@ -246,13 +246,13 @@ int32_t CommonSparseTable::initialize_optimizer() { return 0; } -int32_t CommonSparseTable::set_global_lr(float* lr) { +int32_t CommonSparseTable::SetGlobalLR(float* lr) { _global_lr = lr; - optimizer_->set_global_lr(_global_lr); + optimizer_->SetGlobalLR(_global_lr); return 0; } -int32_t CommonSparseTable::load(const std::string& dirname, +int32_t CommonSparseTable::Load(const std::string& dirname, const std::string& param) { auto begin = GetCurrentUS(); rwlock_->WRLock(); @@ -276,7 +276,7 @@ int32_t CommonSparseTable::load(const std::string& dirname, return 0; } -int32_t CommonSparseTable::save(const std::string& dirname, +int32_t CommonSparseTable::Save(const std::string& dirname, const std::string& param) { auto begin = GetCurrentUS(); rwlock_->WRLock(); @@ -322,7 +322,7 @@ int32_t CommonSparseTable::save(const std::string& dirname, return 0; } -std::pair CommonSparseTable::print_table_stat() { +std::pair CommonSparseTable::PrintTableStat() { int64_t feasign_size = 0; int64_t mf_size = 0; @@ -335,7 +335,7 @@ std::pair CommonSparseTable::print_table_stat() { return {feasign_size, mf_size}; } -int32_t CommonSparseTable::pour() { +int32_t CommonSparseTable::Pour() { std::vector values; std::vector keys; @@ -349,7 +349,7 @@ int32_t CommonSparseTable::pour() { std::copy(reservoir.values.begin(), reservoir.values.end(), std::back_inserter(values)); } - _push_sparse(keys.data(), values.data(), pull_reservoir_.size()); + _PushSparse(keys.data(), values.data(), pull_reservoir_.size()); pull_reservoir_.clear(); return 0; @@ -360,11 +360,11 @@ int32_t CommonSparseTable::Pull(TableContext& context) { if (context.use_ptr) { char** pull_values = context.pull_context.ptr_values; const uint64_t* keys = context.pull_context.keys; - return pull_sparse_ptr(pull_values, keys, context.num); + return PullSparsePtr(pull_values, keys, context.num); } else { float* pull_values = context.pull_context.values; const PullSparseValue& pull_value = context.pull_context.pull_value; - return pull_sparse(pull_values, pull_value); + return PullSparse(pull_values, pull_value); } } @@ -373,16 +373,16 @@ int32_t CommonSparseTable::Push(TableContext& context) { if (context.push_context.values != nullptr) { const float* values = context.push_context.values; const uint64_t* keys = context.push_context.keys; - return push_sparse(keys, values, context.num); + return PushSparse(keys, values, context.num); } else { const float** values = context.push_context.ptr_values; const uint64_t* keys = context.push_context.keys; - return push_sparse(keys, values, context.num); + return PushSparse(keys, values, context.num); } } -int32_t CommonSparseTable::pull_sparse(float* pull_values, - const PullSparseValue& pull_value) { +int32_t CommonSparseTable::PullSparse(float* pull_values, + const PullSparseValue& pull_value) { auto shard_num = task_pool_size_; std::vector> tasks(shard_num); @@ -421,8 +421,8 @@ int32_t CommonSparseTable::pull_sparse(float* pull_values, return 0; } -int32_t CommonSparseTable::pull_sparse_ptr(char** pull_values, - const uint64_t* keys, size_t num) { +int32_t CommonSparseTable::PullSparsePtr(char** pull_values, + const uint64_t* keys, size_t num) { std::vector> offset_bucket; offset_bucket.resize(task_pool_size_); @@ -458,8 +458,8 @@ int32_t CommonSparseTable::pull_sparse_ptr(char** pull_values, return 0; } -int32_t CommonSparseTable::_push_sparse(const uint64_t* keys, - const float* values, size_t num) { +int32_t CommonSparseTable::_PushSparse(const uint64_t* keys, + const float* values, size_t num) { std::vector> offset_bucket; offset_bucket.resize(task_pool_size_); @@ -474,7 +474,7 @@ int32_t CommonSparseTable::_push_sparse(const uint64_t* keys, tasks[shard_id] = _shards_task_pool[shard_id]->enqueue( [this, shard_id, &keys, &values, num, &offset_bucket]() -> int { auto& offsets = offset_bucket[shard_id]; - optimizer_->update(keys, values, num, offsets, + optimizer_->Update(keys, values, num, offsets, shard_values_[shard_id].get()); return 0; }); @@ -486,8 +486,8 @@ int32_t CommonSparseTable::_push_sparse(const uint64_t* keys, return 0; } -int32_t CommonSparseTable::push_sparse(const uint64_t* keys, - const float* values, size_t num) { +int32_t CommonSparseTable::PushSparse(const uint64_t* keys, const float* values, + size_t num) { if (sync) { std::future task = _shards_task_pool[0]->enqueue([this, &keys, &values, num]() -> int { @@ -506,20 +506,20 @@ int32_t CommonSparseTable::push_sparse(const uint64_t* keys, }); task.wait(); } else { - _push_sparse(keys, values, num); + _PushSparse(keys, values, num); } return 0; } -int32_t CommonSparseTable::push_sparse(const uint64_t* keys, - const float** values, size_t num) { - _push_sparse(keys, values, num); +int32_t CommonSparseTable::PushSparse(const uint64_t* keys, + const float** values, size_t num) { + _PushSparse(keys, values, num); return 0; } -int32_t CommonSparseTable::_push_sparse(const uint64_t* keys, - const float** values, size_t num) { +int32_t CommonSparseTable::_PushSparse(const uint64_t* keys, + const float** values, size_t num) { std::vector> offset_bucket; offset_bucket.resize(task_pool_size_); @@ -536,7 +536,7 @@ int32_t CommonSparseTable::_push_sparse(const uint64_t* keys, auto& offsets = offset_bucket[shard_id]; for (size_t i = 0; i < offsets.size(); ++i) { std::vector tmp_off = {0}; - optimizer_->update(keys + offsets[i], values[offsets[i]], num, + optimizer_->Update(keys + offsets[i], values[offsets[i]], num, tmp_off, shard_values_[shard_id].get()); } return 0; @@ -549,8 +549,8 @@ int32_t CommonSparseTable::_push_sparse(const uint64_t* keys, return 0; } -int32_t CommonSparseTable::push_sparse_param(const uint64_t* keys, - const float* values, size_t num) { +int32_t CommonSparseTable::PushSparseParam(const uint64_t* keys, + const float* values, size_t num) { std::vector> offset_bucket; offset_bucket.resize(task_pool_size_); @@ -585,21 +585,21 @@ int32_t CommonSparseTable::push_sparse_param(const uint64_t* keys, return 0; } -int32_t CommonSparseTable::flush() { return 0; } +int32_t CommonSparseTable::Flush() { return 0; } -int32_t CommonSparseTable::shrink(const std::string& param) { +int32_t CommonSparseTable::Shrink(const std::string& param) { int threshold = std::stoi(param); - VLOG(3) << "sparse table shrink: " << threshold; + VLOG(3) << "sparse table Shrink: " << threshold; for (int shard_id = 0; shard_id < task_pool_size_; ++shard_id) { - // shrink - VLOG(4) << shard_id << " " << task_pool_size_ << " begin shrink"; + // Shrink + VLOG(4) << shard_id << " " << task_pool_size_ << " begin Shrink"; shard_values_[shard_id]->Shrink(threshold); } return 0; } -void CommonSparseTable::clear() { VLOG(0) << "clear coming soon"; } +void CommonSparseTable::Clear() { VLOG(0) << "clear coming soon"; } } // namespace distributed } // namespace paddle diff --git a/paddle/fluid/distributed/ps/table/common_sparse_table.h b/paddle/fluid/distributed/ps/table/common_sparse_table.h index 138c5447420663eae5ad94ea03a84360a46f8b3d..f6deaf0a82b138cebff4ffed48992c58bd54b25f 100644 --- a/paddle/fluid/distributed/ps/table/common_sparse_table.h +++ b/paddle/fluid/distributed/ps/table/common_sparse_table.h @@ -114,25 +114,23 @@ class CommonSparseTable : public SparseTable { virtual ~CommonSparseTable() {} // unused method begin - virtual int32_t pull_dense(float* pull_values, size_t num) { return 0; } - virtual int32_t push_dense_param(const float* values, size_t num) { - return 0; - } - virtual int32_t push_dense(const float* values, size_t num) { return 0; } + virtual int32_t PullDense(float* pull_values, size_t num) { return 0; } + virtual int32_t PushDenseParam(const float* values, size_t num) { return 0; } + virtual int32_t PushDense(const float* values, size_t num) { return 0; } // unused method end virtual int32_t Pull(TableContext& context); virtual int32_t Push(TableContext& context); - virtual int32_t initialize(); - virtual int32_t initialize_shard() { return 0; } - virtual int32_t initialize_value(); - virtual int32_t initialize_optimizer(); - virtual int32_t initialize_recorder(); + virtual int32_t Initialize(); + virtual int32_t InitializeShard() { return 0; } + virtual int32_t InitializeValue(); + virtual int32_t InitializeOptimizer(); + virtual int32_t InitializeRecorder(); - virtual int32_t load(const std::string& path, const std::string& param); + virtual int32_t Load(const std::string& path, const std::string& param); - virtual int32_t save(const std::string& path, const std::string& param); + virtual int32_t Save(const std::string& path, const std::string& param); void SaveMetaToText(std::ostream* os, const CommonAccessorParameter& common, const size_t shard_idx, const int64_t total); @@ -150,34 +148,34 @@ class CommonSparseTable : public SparseTable { const int pserver_id, const int pserver_num, const int local_shard_num, std::vector>* blocks); - virtual std::pair print_table_stat(); - virtual int32_t pull_sparse(float* values, const PullSparseValue& pull_value); + virtual std::pair PrintTableStat(); + virtual int32_t PullSparse(float* values, const PullSparseValue& pull_value); - virtual int32_t pull_sparse_ptr(char** pull_values, const uint64_t* keys, - size_t num); + virtual int32_t PullSparsePtr(char** pull_values, const uint64_t* keys, + size_t num); - virtual int32_t push_sparse(const uint64_t* keys, const float* values, - size_t num); + virtual int32_t PushSparse(const uint64_t* keys, const float* values, + size_t num); - virtual int32_t push_sparse(const uint64_t* keys, const float** values, - size_t num); + virtual int32_t PushSparse(const uint64_t* keys, const float** values, + size_t num); // only for sparse geo table - virtual int32_t push_sparse_param(const uint64_t* keys, const float* values, - size_t num); + virtual int32_t PushSparseParam(const uint64_t* keys, const float* values, + size_t num); - virtual int32_t set_global_lr(float* lr) override; + virtual int32_t SetGlobalLR(float* lr) override; - virtual int32_t pour(); - virtual int32_t flush(); - virtual int32_t shrink(const std::string& param); - virtual void clear(); + virtual int32_t Pour(); + virtual int32_t Flush(); + virtual int32_t Shrink(const std::string& param); + virtual void Clear(); protected: - virtual int32_t _push_sparse(const uint64_t* keys, const float* values, - size_t num); - virtual int32_t _push_sparse(const uint64_t* keys, const float** values, - size_t num); + virtual int32_t _PushSparse(const uint64_t* keys, const float* values, + size_t num); + virtual int32_t _PushSparse(const uint64_t* keys, const float** values, + size_t num); protected: const int task_pool_size_ = 11; diff --git a/paddle/fluid/distributed/ps/table/common_table.h b/paddle/fluid/distributed/ps/table/common_table.h index 3d291c0152246bffa748ea57cf1c96eff6f2f343..f5e263e8e71891b0cd2ced4942f7d914dff178d0 100644 --- a/paddle/fluid/distributed/ps/table/common_table.h +++ b/paddle/fluid/distributed/ps/table/common_table.h @@ -71,11 +71,11 @@ class SparseTable : public Table { SparseTable() {} virtual ~SparseTable() {} - virtual void *get_shard(size_t shard_idx) { return 0; } + virtual void *GetShard(size_t shard_idx) { return 0; } - int32_t pull_dense(float *values, size_t num) override { return 0; } + int32_t PullDense(float *values, size_t num) override { return 0; } - int32_t push_dense(const float *values, size_t num) override { return 0; } + int32_t PushDense(const float *values, size_t num) override { return 0; } static int32_t sparse_local_shard_num(uint32_t shard_num, uint32_t server_num) { @@ -97,19 +97,17 @@ class DenseTable : public Table { DenseTable() {} virtual ~DenseTable() {} - virtual void *get_shard(size_t shard_idx) { return 0; } - int32_t pull_sparse(float *values, - const PullSparseValue &pull_value) override { + virtual void *GetShard(size_t shard_idx) { return 0; } + int32_t PullSparse(float *values, + const PullSparseValue &pull_value) override { return 0; } - int32_t push_sparse(const uint64_t *keys, const float *values, - size_t num) override { + int32_t PushSparse(const uint64_t *keys, const float *values, + size_t num) override { return 0; } - int32_t push_dense_param(const float *values, size_t num) override { - return 0; - } - int32_t shrink(const std::string ¶m) override { return 0; } + int32_t PushDenseParam(const float *values, size_t num) override { return 0; } + int32_t Shrink(const std::string ¶m) override { return 0; } }; class BarrierTable : public Table { @@ -117,44 +115,42 @@ class BarrierTable : public Table { BarrierTable() {} virtual ~BarrierTable() {} - virtual void *get_shard(size_t shard_idx) { return 0; } + virtual void *GetShard(size_t shard_idx) { return 0; } virtual int32_t Pull(TableContext &context) { return 0; } virtual int32_t Push(TableContext &context) { return 0; } - int32_t pull_dense(float *values, size_t num) override { return 0; } + int32_t PullDense(float *values, size_t num) override { return 0; } - int32_t push_dense(const float *values, size_t num) override { return 0; } + int32_t PushDense(const float *values, size_t num) override { return 0; } - int32_t pull_sparse(float *values, - const PullSparseValue &pull_value) override { - return 0; - } - int32_t push_sparse(const uint64_t *keys, const float *values, - size_t num) override { + int32_t PullSparse(float *values, + const PullSparseValue &pull_value) override { return 0; } - int32_t push_dense_param(const float *values, size_t num) override { + int32_t PushSparse(const uint64_t *keys, const float *values, + size_t num) override { return 0; } - int32_t shrink(const std::string ¶m) override { return 0; } - virtual void clear() {} - virtual int32_t flush() { return 0; } - virtual int32_t load(const std::string &path, const std::string ¶m) { + int32_t PushDenseParam(const float *values, size_t num) override { return 0; } + int32_t Shrink(const std::string ¶m) override { return 0; } + virtual void Clear() {} + virtual int32_t Flush() { return 0; } + virtual int32_t Load(const std::string &path, const std::string ¶m) { return 0; } - virtual int32_t save(const std::string &path, const std::string ¶m) { + virtual int32_t Save(const std::string &path, const std::string ¶m) { return 0; } - virtual int32_t initialize_shard() { return 0; } + virtual int32_t InitializeShard() { return 0; } - virtual int32_t initialize() override; + virtual int32_t Initialize() override; // only for barrier // 0: send_barrier 1: recv_barrier 2: complete - virtual int32_t barrier(const uint32_t trainer_id, + virtual int32_t Barrier(const uint32_t trainer_id, const std::string barrier_type) override; - virtual int32_t set_table_map( + virtual int32_t SetTableMap( std::unordered_map> *table_map) override; private: diff --git a/paddle/fluid/distributed/ps/table/depends/dense.h b/paddle/fluid/distributed/ps/table/depends/dense.h index 8661eb1feecc83cc3d58c71c9bba8874e63d093d..258c0f4b6a4e6b96586254275711b579b80ea006 100644 --- a/paddle/fluid/distributed/ps/table/depends/dense.h +++ b/paddle/fluid/distributed/ps/table/depends/dense.h @@ -34,9 +34,9 @@ class DenseOptimizer { DenseOptimizer() {} explicit DenseOptimizer(const CommonAccessorParameter& accessor, std::vector>* values) {} - virtual void update(const float* update_values, size_t num, int begin, + virtual void Update(const float* update_values, size_t num, int begin, int end) = 0; - virtual void set_global_lr(float* lr) { global_learning_rate_ = lr; } + virtual void SetGlobalLR(float* lr) { global_learning_rate_ = lr; } protected: float* global_learning_rate_; @@ -55,7 +55,7 @@ class DSUM : public DenseOptimizer { } } - void update(const float* update_values, size_t num, int begin, + void Update(const float* update_values, size_t num, int begin, int end) override { auto update_numel = end - begin; GetBlas().VADD(update_numel, update_values + begin, param + begin, @@ -81,7 +81,7 @@ class DSGD : public DenseOptimizer { } } - void update(const float* update_values, size_t num, int begin, + void Update(const float* update_values, size_t num, int begin, int end) override { auto update_numel = end - begin; std::vector grads; @@ -134,7 +134,7 @@ class DAdam : public DenseOptimizer { // make sure common_dense_table.task_pool_size_ == 1; // otherwise, task_pool_size_ times beta1_pow/beta2_pow multiplication - void update(const float* update_values, size_t num, int begin, + void Update(const float* update_values, size_t num, int begin, int end) override { auto update_numel = end - begin; std::vector grad, grad2, tmp; @@ -214,7 +214,7 @@ class DAdamD2Sum : public DenseOptimizer { } } - void update(const float* update_values, size_t num, int begin, + void Update(const float* update_values, size_t num, int begin, int end) override { auto update_numel = end - begin; Eigen::Map mat_ada_g2sum(ada_g2sum + begin, 1, @@ -276,7 +276,7 @@ class DSummary : public DenseOptimizer { } } - void update(const float* update_values, size_t num, int begin, + void Update(const float* update_values, size_t num, int begin, int end) override { auto update_numel = end - begin; Eigen::Map mat_w(param + begin, 1, update_numel); diff --git a/paddle/fluid/distributed/ps/table/depends/sparse.h b/paddle/fluid/distributed/ps/table/depends/sparse.h index d4ea7829e45f8326fdbe33ebb1c7c9cfa3d35f6f..7eed5ab6c794bf2cdba0f34fd6911a32471d5fcb 100644 --- a/paddle/fluid/distributed/ps/table/depends/sparse.h +++ b/paddle/fluid/distributed/ps/table/depends/sparse.h @@ -40,11 +40,11 @@ class SparseOptimizer { value_offsets_(value_offsets), value_idx_(value_idx) {} - virtual void update(const uint64_t* keys, const float* update_values, + virtual void Update(const uint64_t* keys, const float* update_values, size_t num, const std::vector& offsets, ValueBlock* block) = 0; - virtual void set_global_lr(float* lr) { global_learning_rate_ = lr; } + virtual void SetGlobalLR(float* lr) { global_learning_rate_ = lr; } const std::vector& value_names_; const std::vector& value_dims_; @@ -70,7 +70,7 @@ class SSUM : public SparseOptimizer { update_numel = value_dims.at(idx); } - void update(const uint64_t* keys, const float* update_values, size_t num, + void Update(const uint64_t* keys, const float* update_values, size_t num, const std::vector& offsets, ValueBlock* block) override { auto blas = GetBlas(); @@ -100,7 +100,7 @@ class SSGD : public SparseOptimizer { lr_offset = value_offsets.at(idx); } - void update(const uint64_t* keys, const float* update_values, size_t num, + void Update(const uint64_t* keys, const float* update_values, size_t num, const std::vector& offsets, ValueBlock* block) override { auto blas = GetBlas(); @@ -156,7 +156,7 @@ class SAdam : public SparseOptimizer { epsilon = 1.0e-8; } - void update(const uint64_t* keys, const float* update_values, size_t num, + void Update(const uint64_t* keys, const float* update_values, size_t num, const std::vector& offsets, ValueBlock* block) override { auto blas = GetBlas(); diff --git a/paddle/fluid/distributed/ps/table/memory_sparse_geo_table.cc b/paddle/fluid/distributed/ps/table/memory_sparse_geo_table.cc index f16f4fc7f34a533ec9539e9f68a49aace1b27c7b..979e1c482547c67f66cfa1d3ea82ec1a1a8e78a4 100644 --- a/paddle/fluid/distributed/ps/table/memory_sparse_geo_table.cc +++ b/paddle/fluid/distributed/ps/table/memory_sparse_geo_table.cc @@ -17,11 +17,10 @@ namespace paddle { namespace distributed { -int32_t MemorySparseGeoTable::push_sparse_param(const uint64_t* keys, - const float* values, - size_t num) { - VLOG(5) << "DEBUG MemorySparseGeoTable::push_sparse_param begin " - "push_sparse_param " +int32_t MemorySparseGeoTable::PushSparseParam(const uint64_t* keys, + const float* values, size_t num) { + VLOG(5) << "DEBUG MemorySparseGeoTable::PushSparseParam begin " + "PushSparseParam " << num; auto shard_num = _task_pool_size; std::vector> offset_bucket; @@ -31,8 +30,8 @@ int32_t MemorySparseGeoTable::push_sparse_param(const uint64_t* keys, auto y = keys[x] % shard_num; offset_bucket[y].push_back(x); if (x < 10) { - VLOG(5) << "DEBUG MemorySparseGeoTable::push_sparse_param key: " - << keys[x] << " shard: " << y; + VLOG(5) << "DEBUG MemorySparseGeoTable::PushSparseParam key: " << keys[x] + << " shard: " << y; } } @@ -51,8 +50,8 @@ int32_t MemorySparseGeoTable::push_sparse_param(const uint64_t* keys, feature_value.resize(_dim); std::copy_n(values + _dim * offset, _dim, feature_value.data()); if (i < 10) { - VLOG(5) << "MemorySparseGeoTable::push_sparse_param " - "push_sparse_param key " + VLOG(5) << "MemorySparseGeoTable::PushSparseParam " + "PushSparseParam key " << id << " value[0]: " << (values + _dim * offset)[0] << " data: " << feature_value.data()[0] << " value[-1]: " << (values + _dim * offset)[_dim - 1] @@ -69,9 +68,9 @@ int32_t MemorySparseGeoTable::push_sparse_param(const uint64_t* keys, return 0; } -int32_t MemorySparseGeoTable::pull_geo_param(const uint32_t trainer_id, - std::vector* values, - std::vector* ids) { +int32_t MemorySparseGeoTable::PullGeoParam(const uint32_t trainer_id, + std::vector* values, + std::vector* ids) { _geo_recorder->GetAndClear(trainer_id, ids); VLOG(5) << "DEBUG MemorySparseGeoTable::pull_geo_param pull_geo_param trainer_id " @@ -86,23 +85,23 @@ int32_t MemorySparseGeoTable::pull_geo_param(const uint32_t trainer_id, pull_value.frequencies_ = frequencies.data(); values->resize(ids->size() * _dim); - pull_sparse(values->data(), pull_value); + PullSparse(values->data(), pull_value); return 0; } -int32_t MemorySparseGeoTable::push_sparse(const uint64_t* keys, - const float* values, size_t num) { - VLOG(5) << "DEBUG MemorySparseGeoTable::push_sparse keys[0]" << keys[0] +int32_t MemorySparseGeoTable::PushSparse(const uint64_t* keys, + const float* values, size_t num) { + VLOG(5) << "DEBUG MemorySparseGeoTable::PushSparse keys[0]" << keys[0] << " key_num: " << num; std::vector ids; ids.resize(num); std::copy_n(keys, num, ids.begin()); _geo_recorder->Update(ids); - _push_sparse(keys, values, num); + _PushSparse(keys, values, num); return 0; } -int32_t MemorySparseGeoTable::initialize() { +int32_t MemorySparseGeoTable::Initialize() { if (!_geo_recorder) { auto trainers = _config.common().trainer_num(); _geo_recorder = std::make_shared(trainers); @@ -118,8 +117,8 @@ int32_t MemorySparseGeoTable::initialize() { return 0; } -int32_t MemorySparseGeoTable::pull_sparse(float* pull_values, - const PullSparseValue& pull_value) { +int32_t MemorySparseGeoTable::PullSparse(float* pull_values, + const PullSparseValue& pull_value) { auto shard_num = _task_pool_size; std::vector> tasks(shard_num); @@ -146,13 +145,13 @@ int32_t MemorySparseGeoTable::pull_sparse(float* pull_values, auto& feature_value = local_shard[key]; feature_value.resize(_dim); memset(feature_value.data(), 0, sizeof(float) * _dim); - VLOG(0) << "MemorySparseGeoTable pull_sparse key not found!!! " + VLOG(0) << "MemorySparseGeoTable PullSparse key not found!!! " << key; itr = local_shard.find(key); } memcpy(select_data, itr.value().data(), _dim * sizeof(float)); - VLOG(5) << "DEBUG MemorySparseGeoTable::pull_sparse key: " << key + VLOG(5) << "DEBUG MemorySparseGeoTable::PullSparse key: " << key << " select_data[0] " << select_data[0] << " value[0]: " << itr.value().data()[0]; } @@ -167,8 +166,8 @@ int32_t MemorySparseGeoTable::pull_sparse(float* pull_values, return 0; } -int32_t MemorySparseGeoTable::_push_sparse(const uint64_t* keys, - const float* values, size_t num) { +int32_t MemorySparseGeoTable::_PushSparse(const uint64_t* keys, + const float* values, size_t num) { auto shard_num = _task_pool_size; std::vector> tasks(shard_num); std::vector>> task_keys(shard_num); diff --git a/paddle/fluid/distributed/ps/table/memory_sparse_geo_table.h b/paddle/fluid/distributed/ps/table/memory_sparse_geo_table.h index 3b43f99543fddabfaa24fc7da562203fc3f0d633..1a74df32db8e72473e53b7ec882b92e909ab754d 100644 --- a/paddle/fluid/distributed/ps/table/memory_sparse_geo_table.h +++ b/paddle/fluid/distributed/ps/table/memory_sparse_geo_table.h @@ -40,31 +40,31 @@ class MemorySparseGeoTable : public SparseTable { MemorySparseGeoTable() { _geo_recorder = nullptr; } virtual ~MemorySparseGeoTable() {} - virtual int32_t initialize(); - virtual int32_t initialize_shard() { return 0; } - virtual int32_t load(const std::string& path, const std::string& param) { + virtual int32_t Initialize(); + virtual int32_t InitializeShard() { return 0; } + virtual int32_t Load(const std::string& path, const std::string& param) { return 0; } - virtual int32_t save(const std::string& path, const std::string& param) { + virtual int32_t Save(const std::string& path, const std::string& param) { return 0; } virtual int32_t Pull(TableContext& context) { return 0; } virtual int32_t Push(TableContext& context) { return 0; } - virtual int32_t flush() { return 0; } - virtual int32_t shrink(const std::string& param) { return 0; } - virtual void clear() { return; } - virtual int32_t pull_sparse(float* values, const PullSparseValue& pull_value); + virtual int32_t Flush() { return 0; } + virtual int32_t Shrink(const std::string& param) { return 0; } + virtual void Clear() { return; } + virtual int32_t PullSparse(float* values, const PullSparseValue& pull_value); - int32_t push_sparse_param(const uint64_t* keys, const float* values, - size_t num); + int32_t PushSparseParam(const uint64_t* keys, const float* values, + size_t num); // TODO(zhaocaibei123): change to pull_sparse, and rename pull_sparse - int32_t pull_geo_param(const uint32_t trainer_id, std::vector* values, - std::vector* keys); + int32_t PullGeoParam(const uint32_t trainer_id, std::vector* values, + std::vector* keys); - int32_t push_sparse(const uint64_t* keys, const float* values, - size_t num) override; + int32_t PushSparse(const uint64_t* keys, const float* values, + size_t num) override; - int32_t _push_sparse(const uint64_t* keys, const float* values, size_t num); + int32_t _PushSparse(const uint64_t* keys, const float* values, size_t num); // int32_t _pull_sparse(float* pull_values, const PullSparseValue& // pull_value); diff --git a/paddle/fluid/distributed/ps/table/memory_sparse_table.cc b/paddle/fluid/distributed/ps/table/memory_sparse_table.cc index 61ea2f8f2007e7d44ac2176e08639a6022e532bf..97e3c008d9478330d91da333958ee1fb8842e9f8 100644 --- a/paddle/fluid/distributed/ps/table/memory_sparse_table.cc +++ b/paddle/fluid/distributed/ps/table/memory_sparse_table.cc @@ -31,7 +31,7 @@ bool FLAGS_pserver_create_value_when_push = true; int FLAGS_pserver_table_save_max_retry = 3; bool FLAGS_pserver_enable_create_feasign_randomly = false; -int32_t MemorySparseTable::initialize() { +int32_t MemorySparseTable::Initialize() { _shards_task_pool.resize(_task_pool_size); for (int i = 0; i < _shards_task_pool.size(); ++i) { _shards_task_pool[i].reset(new ::ThreadPool(1)); @@ -39,12 +39,12 @@ int32_t MemorySparseTable::initialize() { auto& profiler = CostProfiler::instance(); profiler.register_profiler("pserver_sparse_update_all"); profiler.register_profiler("pserver_sparse_select_all"); - initialize_value(); + InitializeValue(); VLOG(0) << "initalize MemorySparseTable succ"; return 0; } -int32_t MemorySparseTable::initialize_value() { +int32_t MemorySparseTable::InitializeValue() { _sparse_table_shard_num = static_cast(_config.shard_num()); _avg_local_shard_num = SparseTable::sparse_local_shard_num(_sparse_table_shard_num, _shard_num); @@ -64,14 +64,14 @@ int32_t MemorySparseTable::initialize_value() { return 0; } -int32_t MemorySparseTable::load(const std::string& path, +int32_t MemorySparseTable::Load(const std::string& path, const std::string& param) { - std::string table_path = table_dir(path); + std::string table_path = TableDir(path); auto file_list = _afs_client.list(table_path); std::sort(file_list.begin(), file_list.end()); for (auto file : file_list) { - VLOG(1) << "MemorySparseTable::load() file list: " << file; + VLOG(1) << "MemorySparseTable::Load() file list: " << file; } int load_param = atoi(param.c_str()); @@ -154,9 +154,9 @@ int32_t MemorySparseTable::load(const std::string& path, return 0; } -int32_t MemorySparseTable::load_local_fs(const std::string& path, - const std::string& param) { - std::string table_path = table_dir(path); +int32_t MemorySparseTable::LoadLocalFS(const std::string& path, + const std::string& param) { + std::string table_path = TableDir(path); auto file_list = paddle::framework::localfs_list(table_path); int load_param = atoi(param.c_str()); @@ -225,12 +225,12 @@ int32_t MemorySparseTable::load_local_fs(const std::string& path, return 0; } -int32_t MemorySparseTable::save(const std::string& dirname, +int32_t MemorySparseTable::Save(const std::string& dirname, const std::string& param) { VLOG(0) << "MemorySparseTable::save dirname: " << dirname; int save_param = atoi(param.c_str()); // checkpoint:0 xbox delta:1 xbox base:2 - std::string table_path = table_dir(dirname); + std::string table_path = TableDir(dirname); _afs_client.remove(paddle::string::format_string( "%s/part-%03d-*", table_path.c_str(), _shard_idx)); std::atomic feasign_size_all{0}; @@ -309,12 +309,12 @@ int32_t MemorySparseTable::save(const std::string& dirname, return 0; } -int32_t MemorySparseTable::save_local_fs(const std::string& dirname, - const std::string& param, - const std::string& prefix) { +int32_t MemorySparseTable::SaveLocalFS(const std::string& dirname, + const std::string& param, + const std::string& prefix) { int save_param = atoi(param.c_str()); // checkpoint:0 xbox delta:1 xbox base:2 - std::string table_path = table_dir(dirname); + std::string table_path = TableDir(dirname); int feasign_cnt = 0; size_t file_start_idx = _avg_local_shard_num * _shard_idx; @@ -349,7 +349,7 @@ int32_t MemorySparseTable::save_local_fs(const std::string& dirname, return 0; } -int64_t MemorySparseTable::local_size() { +int64_t MemorySparseTable::LocalSize() { int64_t local_size = 0; for (size_t i = 0; i < _real_local_shard_num; ++i) { local_size += _local_shards[i].size(); @@ -357,7 +357,7 @@ int64_t MemorySparseTable::local_size() { return local_size; } -int64_t MemorySparseTable::local_mf_size() { +int64_t MemorySparseTable::LocalMFSize() { std::vector size_arr(_real_local_shard_num, 0); std::vector> tasks(_real_local_shard_num); int64_t ret_size = 0; @@ -384,9 +384,9 @@ int64_t MemorySparseTable::local_mf_size() { return ret_size; } -std::pair MemorySparseTable::print_table_stat() { - int64_t feasign_size = local_size(); - int64_t mf_size = local_mf_size(); +std::pair MemorySparseTable::PrintTableStat() { + int64_t feasign_size = LocalSize(); + int64_t mf_size = LocalMFSize(); return {feasign_size, mf_size}; } @@ -395,11 +395,11 @@ int32_t MemorySparseTable::Pull(TableContext& context) { if (context.use_ptr) { char** pull_values = context.pull_context.ptr_values; const uint64_t* keys = context.pull_context.keys; - return pull_sparse_ptr(pull_values, keys, context.num); + return PullSparsePtr(pull_values, keys, context.num); } else { float* pull_values = context.pull_context.values; const PullSparseValue& pull_value = context.pull_context.pull_value; - return pull_sparse(pull_values, pull_value); + return PullSparse(pull_values, pull_value); } } @@ -407,11 +407,11 @@ int32_t MemorySparseTable::Push(TableContext& context) { CHECK(context.value_type == Sparse); const uint64_t* keys = context.push_context.keys; - return push_sparse(keys, context.push_context.values, context.num); + return PushSparse(keys, context.push_context.values, context.num); } -int32_t MemorySparseTable::pull_sparse(float* pull_values, - const PullSparseValue& pull_value) { +int32_t MemorySparseTable::PullSparse(float* pull_values, + const PullSparseValue& pull_value) { CostTimer timer("pserver_sparse_select_all"); std::vector> tasks(_real_local_shard_num); @@ -479,8 +479,8 @@ int32_t MemorySparseTable::pull_sparse(float* pull_values, return 0; } -int32_t MemorySparseTable::pull_sparse_ptr(char** pull_values, - const uint64_t* keys, size_t num) { +int32_t MemorySparseTable::PullSparsePtr(char** pull_values, + const uint64_t* keys, size_t num) { CostTimer timer("pscore_sparse_select_all"); size_t value_size = _value_accesor->GetTableInfo(SIZE) / sizeof(float); size_t mf_value_size = _value_accesor->GetTableInfo(MF_SIZE) / sizeof(float); @@ -530,8 +530,8 @@ int32_t MemorySparseTable::pull_sparse_ptr(char** pull_values, return 0; } -int32_t MemorySparseTable::push_sparse(const uint64_t* keys, - const float* values, size_t num) { +int32_t MemorySparseTable::PushSparse(const uint64_t* keys, const float* values, + size_t num) { CostTimer timer("pserver_sparse_update_all"); std::vector> tasks(_real_local_shard_num); std::vector>> task_keys( @@ -603,14 +603,14 @@ int32_t MemorySparseTable::push_sparse(const uint64_t* keys, return 0; } -int32_t MemorySparseTable::push_sparse(const uint64_t* keys, - const float** values, size_t num) { - _push_sparse(keys, values, num); +int32_t MemorySparseTable::PushSparse(const uint64_t* keys, + const float** values, size_t num) { + _PushSparse(keys, values, num); return 0; } -int32_t MemorySparseTable::_push_sparse(const uint64_t* keys, - const float** values, size_t num) { +int32_t MemorySparseTable::_PushSparse(const uint64_t* keys, + const float** values, size_t num) { std::vector> tasks(_real_local_shard_num); std::vector>> task_keys( _real_local_shard_num); @@ -677,13 +677,13 @@ int32_t MemorySparseTable::_push_sparse(const uint64_t* keys, return 0; } -int32_t MemorySparseTable::flush() { return 0; } +int32_t MemorySparseTable::Flush() { return 0; } -int32_t MemorySparseTable::shrink(const std::string& param) { - VLOG(0) << "MemorySparseTable::shrink"; +int32_t MemorySparseTable::Shrink(const std::string& param) { + VLOG(0) << "MemorySparseTable::Shrink"; // TODO(zhaocaibei123): implement with multi-thread for (int shard_id = 0; shard_id < _real_local_shard_num; ++shard_id) { - // shrink + // Shrink auto& shard = _local_shards[shard_id]; for (auto it = shard.begin(); it != shard.end();) { if (_value_accesor->Shrink(it.value().data())) { @@ -696,7 +696,7 @@ int32_t MemorySparseTable::shrink(const std::string& param) { return 0; } -void MemorySparseTable::clear() { VLOG(0) << "clear coming soon"; } +void MemorySparseTable::Clear() { VLOG(0) << "clear coming soon"; } } // namespace distributed } // namespace paddle diff --git a/paddle/fluid/distributed/ps/table/memory_sparse_table.h b/paddle/fluid/distributed/ps/table/memory_sparse_table.h index d26c67319760da0496ae8a1c164adf0d5b63b1f2..a4af4caa472d75a59d0e0a4ae7d313437861f41e 100644 --- a/paddle/fluid/distributed/ps/table/memory_sparse_table.h +++ b/paddle/fluid/distributed/ps/table/memory_sparse_table.h @@ -41,50 +41,48 @@ class MemorySparseTable : public SparseTable { virtual ~MemorySparseTable() {} // unused method begin - virtual int32_t pull_dense(float* pull_values, size_t num) { return 0; } - virtual int32_t push_dense_param(const float* values, size_t num) { - return 0; - } - virtual int32_t push_dense(const float* values, size_t num) { return 0; } + virtual int32_t PullDense(float* pull_values, size_t num) { return 0; } + virtual int32_t PushDenseParam(const float* values, size_t num) { return 0; } + virtual int32_t PushDense(const float* values, size_t num) { return 0; } // unused method end virtual int32_t Pull(TableContext& context); virtual int32_t Push(TableContext& context); - virtual int32_t initialize(); - virtual int32_t initialize_shard() { return 0; } - virtual int32_t initialize_value(); + virtual int32_t Initialize(); + virtual int32_t InitializeShard() { return 0; } + virtual int32_t InitializeValue(); - virtual int32_t load(const std::string& path, const std::string& param); + virtual int32_t Load(const std::string& path, const std::string& param); - virtual int32_t save(const std::string& path, const std::string& param); + virtual int32_t Save(const std::string& path, const std::string& param); - int32_t load_local_fs(const std::string& path, const std::string& param); - int32_t save_local_fs(const std::string& path, const std::string& param, - const std::string& prefix); + int32_t LoadLocalFS(const std::string& path, const std::string& param); + int32_t SaveLocalFS(const std::string& path, const std::string& param, + const std::string& prefix); - int64_t local_size(); - int64_t local_mf_size(); + int64_t LocalSize(); + int64_t LocalMFSize(); - virtual std::pair print_table_stat(); - virtual int32_t pull_sparse(float* values, const PullSparseValue& pull_value); + virtual std::pair PrintTableStat(); + virtual int32_t PullSparse(float* values, const PullSparseValue& pull_value); - virtual int32_t pull_sparse_ptr(char** pull_values, const uint64_t* keys, - size_t num); + virtual int32_t PullSparsePtr(char** pull_values, const uint64_t* keys, + size_t num); - virtual int32_t push_sparse(const uint64_t* keys, const float* values, - size_t num); + virtual int32_t PushSparse(const uint64_t* keys, const float* values, + size_t num); - virtual int32_t push_sparse(const uint64_t* keys, const float** values, - size_t num); + virtual int32_t PushSparse(const uint64_t* keys, const float** values, + size_t num); - virtual int32_t flush(); - virtual int32_t shrink(const std::string& param); - virtual void clear(); + virtual int32_t Flush(); + virtual int32_t Shrink(const std::string& param); + virtual void Clear(); protected: - virtual int32_t _push_sparse(const uint64_t* keys, const float** values, - size_t num); + virtual int32_t _PushSparse(const uint64_t* keys, const float** values, + size_t num); protected: const int _task_pool_size = 24; diff --git a/paddle/fluid/distributed/ps/table/sparse_geo_table.cc b/paddle/fluid/distributed/ps/table/sparse_geo_table.cc index 6ef4330113e8fee3d2cb0d3e541194ca7b600a82..de9628a5b52357aae4def8e9d30abeed4a9a0da0 100644 --- a/paddle/fluid/distributed/ps/table/sparse_geo_table.cc +++ b/paddle/fluid/distributed/ps/table/sparse_geo_table.cc @@ -17,9 +17,9 @@ namespace paddle { namespace distributed { -int32_t SparseGeoTable::pull_geo_param(const uint32_t trainer_id, - std::vector* values, - std::vector* ids) { +int32_t SparseGeoTable::PullGeoParam(const uint32_t trainer_id, + std::vector* values, + std::vector* ids) { geo_recorder->GetAndClear(trainer_id, ids); auto dim = _config.common().dims()[0]; @@ -32,21 +32,21 @@ int32_t SparseGeoTable::pull_geo_param(const uint32_t trainer_id, pull_value.frequencies_ = frequencies.data(); values->resize(ids->size() * dim); - CommonSparseTable::pull_sparse(values->data(), pull_value); + CommonSparseTable::PullSparse(values->data(), pull_value); return 0; } -int32_t SparseGeoTable::push_sparse(const uint64_t* keys, const float* values, - size_t num) { +int32_t SparseGeoTable::PushSparse(const uint64_t* keys, const float* values, + size_t num) { std::vector ids; ids.resize(num); std::copy_n(keys, num, ids.begin()); geo_recorder->Update(ids); - CommonSparseTable::push_sparse(keys, values, num); + CommonSparseTable::PushSparse(keys, values, num); return 0; } -int32_t SparseGeoTable::initialize_value() { +int32_t SparseGeoTable::InitializeValue() { auto common = _config.common(); shard_values_.reserve(task_pool_size_); @@ -82,7 +82,7 @@ int32_t SparseGeoTable::initialize_value() { auto pull_value = PullSparseValue(ids, fres, param_dim_); std::vector pulls; pulls.resize(bucket_feasigns * param_dim_); - pull_sparse(pulls.data(), pull_value); + PullSparse(pulls.data(), pull_value); } return 0; } diff --git a/paddle/fluid/distributed/ps/table/sparse_geo_table.h b/paddle/fluid/distributed/ps/table/sparse_geo_table.h index 1151c9f81ac978ce44e0d2dcd7bc388a43fa3f53..261338c2ba7b1800356131c4be7c3ff0dee7ca25 100644 --- a/paddle/fluid/distributed/ps/table/sparse_geo_table.h +++ b/paddle/fluid/distributed/ps/table/sparse_geo_table.h @@ -44,15 +44,15 @@ class SparseGeoTable : public CommonSparseTable { explicit SparseGeoTable() : CommonSparseTable() { geo_recorder = nullptr; } virtual ~SparseGeoTable() {} - virtual int32_t initialize_value(); + virtual int32_t InitializeValue(); - int32_t pull_geo_param(const uint32_t trainer_id, std::vector* values, - std::vector* keys); + int32_t PullGeoParam(const uint32_t trainer_id, std::vector* values, + std::vector* keys); - int32_t push_sparse(const uint64_t* keys, const float* values, - size_t num) override; + int32_t PushSparse(const uint64_t* keys, const float* values, + size_t num) override; - virtual int32_t initialize_recorder() { + virtual int32_t InitializeRecorder() { if (!geo_recorder) { auto trainers = _config.common().trainer_num(); geo_recorder = std::make_shared(trainers); diff --git a/paddle/fluid/distributed/ps/table/ssd_sparse_table.cc b/paddle/fluid/distributed/ps/table/ssd_sparse_table.cc index 5bc58bc5a1108b5f342036d9bd72c96287458401..484fa9e1c6eea946a69654466b519c1ba1b881b5 100644 --- a/paddle/fluid/distributed/ps/table/ssd_sparse_table.cc +++ b/paddle/fluid/distributed/ps/table/ssd_sparse_table.cc @@ -20,7 +20,7 @@ DEFINE_string(rocksdb_path, "database", "path of sparse table rocksdb file"); namespace paddle { namespace distributed { -int32_t SSDSparseTable::initialize() { +int32_t SSDSparseTable::Initialize() { _shards_task_pool.resize(task_pool_size_); for (int i = 0; i < _shards_task_pool.size(); ++i) { _shards_task_pool[i].reset(new ::ThreadPool(1)); @@ -53,9 +53,9 @@ int32_t SSDSparseTable::initialize() { offset += dim; } - initialize_value(); - initialize_optimizer(); - initialize_recorder(); + InitializeValue(); + InitializeOptimizer(); + InitializeRecorder(); _db = paddle::distributed::RocksDBHandler::GetInstance(); _db->initialize(FLAGS_rocksdb_path, task_pool_size_); return 0; @@ -66,18 +66,18 @@ int32_t SSDSparseTable::Pull(TableContext& context) { if (context.use_ptr) { char** pull_values = context.pull_context.ptr_values; const uint64_t* keys = context.pull_context.keys; - return pull_sparse_ptr(pull_values, keys, context.num); + return PullSparsePtr(pull_values, keys, context.num); } else { float* pull_values = context.pull_context.values; const PullSparseValue& pull_value = context.pull_context.pull_value; - return pull_sparse(pull_values, pull_value); + return PullSparse(pull_values, pull_value); } } int32_t SSDSparseTable::Push(TableContext& context) { return 0; } -int32_t SSDSparseTable::pull_sparse(float* pull_values, - const PullSparseValue& pull_value) { +int32_t SSDSparseTable::PullSparse(float* pull_values, + const PullSparseValue& pull_value) { auto shard_num = task_pool_size_; std::vector> tasks(shard_num); @@ -140,8 +140,8 @@ int32_t SSDSparseTable::pull_sparse(float* pull_values, return 0; } -int32_t SSDSparseTable::pull_sparse_ptr(char** pull_values, - const uint64_t* keys, size_t num) { +int32_t SSDSparseTable::PullSparsePtr(char** pull_values, const uint64_t* keys, + size_t num) { auto shard_num = task_pool_size_; std::vector> tasks(shard_num); @@ -201,9 +201,9 @@ int32_t SSDSparseTable::pull_sparse_ptr(char** pull_values, return 0; } -int32_t SSDSparseTable::shrink(const std::string& param) { return 0; } +int32_t SSDSparseTable::Shrink(const std::string& param) { return 0; } -int32_t SSDSparseTable::update_table() { +int32_t SSDSparseTable::UpdateTable() { int count = 0; int value_size = shard_values_[0]->value_length_; int db_size = 3 + value_size; @@ -299,7 +299,7 @@ int64_t SSDSparseTable::SaveValueToText(std::ostream* os, return save_num; } -int32_t SSDSparseTable::load(const std::string& path, +int32_t SSDSparseTable::Load(const std::string& path, const std::string& param) { rwlock_->WRLock(); VLOG(3) << "ssd sparse table load with " << path << " with meta " << param; diff --git a/paddle/fluid/distributed/ps/table/ssd_sparse_table.h b/paddle/fluid/distributed/ps/table/ssd_sparse_table.h index 3a703d7d966d3e6026d13c0658f5979120cd2073..11a776bd9e8476750270daf2644698946f375bae 100644 --- a/paddle/fluid/distributed/ps/table/ssd_sparse_table.h +++ b/paddle/fluid/distributed/ps/table/ssd_sparse_table.h @@ -23,7 +23,7 @@ class SSDSparseTable : public CommonSparseTable { SSDSparseTable() {} virtual ~SSDSparseTable() {} - virtual int32_t initialize() override; + virtual int32_t Initialize() override; void SaveMetaToText(std::ostream* os, const CommonAccessorParameter& common, const size_t shard_idx, const int64_t total); @@ -37,22 +37,22 @@ class SSDSparseTable : public CommonSparseTable { const int pserver_id, const int pserver_num, const int local_shard_num, std::vector>* blocks); - virtual int32_t load(const std::string& path, const std::string& param); + virtual int32_t Load(const std::string& path, const std::string& param); // exchange data - virtual int32_t update_table(); + virtual int32_t UpdateTable(); virtual int32_t Pull(TableContext& context); virtual int32_t Push(TableContext& context); - virtual int32_t pull_sparse(float* values, const PullSparseValue& pull_value); + virtual int32_t PullSparse(float* values, const PullSparseValue& pull_value); - virtual int32_t pull_sparse_ptr(char** pull_values, const uint64_t* keys, - size_t num); + virtual int32_t PullSparsePtr(char** pull_values, const uint64_t* keys, + size_t num); - virtual int32_t flush() override { return 0; } - virtual int32_t shrink(const std::string& param) override; - virtual void clear() override {} + virtual int32_t Flush() override { return 0; } + virtual int32_t Shrink(const std::string& param) override; + virtual void Clear() override {} private: RocksDBHandler* _db; diff --git a/paddle/fluid/distributed/ps/table/table.cc b/paddle/fluid/distributed/ps/table/table.cc index 99790606f0b31b9edfb26b0bc6a03551b2bf0013..9f17a2006d232572287792445b8a6cc565f079cb 100644 --- a/paddle/fluid/distributed/ps/table/table.cc +++ b/paddle/fluid/distributed/ps/table/table.cc @@ -56,7 +56,7 @@ REGISTER_PSCORE_CLASS(SparseValueSGDRule, SparseAdamSGDRule); REGISTER_PSCORE_CLASS(SparseValueSGDRule, SparseNaiveSGDRule); REGISTER_PSCORE_CLASS(SparseValueSGDRule, SparseAdaGradSGDRule); -int32_t TableManager::initialize() { +int32_t TableManager::Initialize() { static bool initialized = false; if (initialized) { return 0; @@ -65,10 +65,10 @@ int32_t TableManager::initialize() { return 0; } -int32_t Table::initialize(const TableParameter &config, +int32_t Table::Initialize(const TableParameter &config, const FsClientParameter &fs_config) { _config = config; - if (initialize_accessor() != 0) { + if (InitializeAccessor() != 0) { LOG(WARNING) << "Table accessor initialize failed"; return -1; } @@ -77,10 +77,10 @@ int32_t Table::initialize(const TableParameter &config, LOG(WARNING) << "Table fs_client initialize failed"; // return -1; } - return initialize(); + return Initialize(); } -int32_t Table::initialize_accessor() { +int32_t Table::InitializeAccessor() { if (!_config.has_accessor() || !_config.accessor().has_accessor_class()) { LOG(ERROR) << "missing accessor config in table, table_id:" << _config.table_id(); diff --git a/paddle/fluid/distributed/ps/table/table.h b/paddle/fluid/distributed/ps/table/table.h index bba34d89377a7d4050d0efa43c187bd8314fed39..c61efe769e2f8092fac24a280f8cf236d2aee5a4 100644 --- a/paddle/fluid/distributed/ps/table/table.h +++ b/paddle/fluid/distributed/ps/table/table.h @@ -60,101 +60,99 @@ class Table { public: Table() {} virtual ~Table() {} - virtual int32_t initialize(const TableParameter &config, + virtual int32_t Initialize(const TableParameter &config, const FsClientParameter &fs_config); virtual int32_t Pull(TableContext &context) = 0; virtual int32_t Push(TableContext &context) = 0; - virtual int32_t pull_dense(float *values, size_t num) = 0; - virtual int32_t push_dense(const float *values, size_t num) = 0; + virtual int32_t PullDense(float *values, size_t num) = 0; + virtual int32_t PushDense(const float *values, size_t num) = 0; // for push global_step - virtual int32_t push_dense(const int64_t *values, const int32_t trainer_id) { - return 0; - } - virtual int32_t push_dense_param(const float *values, size_t num) { + virtual int32_t PushDense(const int64_t *values, const int32_t trainer_id) { return 0; } + virtual int32_t PushDenseParam(const float *values, size_t num) { return 0; } - virtual int32_t pull_sparse_ptr(char **pull_values, const uint64_t *keys, - size_t num) { + virtual int32_t PullSparsePtr(char **pull_values, const uint64_t *keys, + size_t num) { VLOG(0) << "NOT IMPLEMENT"; return 0; } - virtual int32_t pull_sparse(float *values, - const PullSparseValue &pull_value) = 0; - virtual int32_t push_sparse(const uint64_t *keys, const float *values, - size_t num) = 0; - virtual int32_t push_sparse(const uint64_t *keys, const float **values, - size_t num) { + virtual int32_t PullSparse(float *values, + const PullSparseValue &pull_value) = 0; + virtual int32_t PushSparse(const uint64_t *keys, const float *values, + size_t num) = 0; + virtual int32_t PushSparse(const uint64_t *keys, const float **values, + size_t num) { return 0; } - virtual int32_t push_sparse_param(const uint64_t *keys, const float *values, - size_t num) { + virtual int32_t PushSparseParam(const uint64_t *keys, const float *values, + size_t num) { return 0; } // only for sparse geo table - virtual int32_t pull_geo_param(const uint32_t trainer_id, - std::vector *values, - std::vector *keys) { + virtual int32_t PullGeoParam(const uint32_t trainer_id, + std::vector *values, + std::vector *keys) { return 0; } // only for barrier - virtual int32_t barrier(const uint32_t trainer_id, + virtual int32_t Barrier(const uint32_t trainer_id, const std::string barrier_type) { return 0; } // only for barrier table - virtual int32_t set_table_map( + virtual int32_t SetTableMap( std::unordered_map> *table_map) { return 0; } // only for tensor table - virtual int32_t set_program_env( + virtual int32_t SetProgramEnv( framework::Scope *scope, platform::Place place, const std::vector *sub_program) { return 0; } - virtual int32_t set_global_lr(float *lr) { + virtual int32_t SetGlobalLR(float *lr) { _global_lr = lr; return 0; } - virtual int32_t pour() { return 0; } + virtual int32_t Pour() { return 0; } - virtual void clear() = 0; - virtual int32_t flush() = 0; - virtual int32_t shrink(const std::string ¶m) = 0; + virtual void Clear() = 0; + virtual int32_t Flush() = 0; + virtual int32_t Shrink(const std::string ¶m) = 0; // 指定加载路径 - virtual int32_t load(const std::string &path, + virtual int32_t Load(const std::string &path, const std::string &converter) = 0; // 指定保存路径 - virtual int32_t save(const std::string &path, + virtual int32_t Save(const std::string &path, const std::string &converter) = 0; - virtual int32_t set_shard(size_t shard_idx, size_t shard_num) { + virtual int32_t SetShard(size_t shard_idx, size_t shard_num) { _shard_idx = shard_idx; _shard_num = shard_num; - return initialize_shard(); + return InitializeShard(); } - inline std::shared_ptr value_accesor() { + inline std::shared_ptr ValueAccesor() { return _value_accesor; } - virtual void *get_shard(size_t shard_idx) = 0; - virtual std::pair print_table_stat() { return {0, 0}; } + virtual void *GetShard(size_t shard_idx) = 0; + virtual std::pair PrintTableStat() { return {0, 0}; } protected: - virtual int32_t initialize() = 0; - virtual int32_t initialize_accessor(); - virtual int32_t initialize_shard() = 0; - virtual std::string table_dir(const std::string &model_dir) { + virtual int32_t Initialize() = 0; + virtual int32_t InitializeAccessor(); + virtual int32_t InitializeShard() = 0; + virtual std::string TableDir(const std::string &model_dir) { return paddle::string::format_string("%s/%03d/", model_dir.c_str(), _config.table_id()); } @@ -171,11 +169,11 @@ REGISTER_PSCORE_REGISTERER(Table); class TableManager { public: - static TableManager &instance() { + static TableManager &Instance() { static TableManager manager; return manager; } - int32_t initialize(); + int32_t Initialize(); private: TableManager() {} diff --git a/paddle/fluid/distributed/ps/table/tensor_table.h b/paddle/fluid/distributed/ps/table/tensor_table.h index e59314923cdbc90222e0e4c66fa76755417c9453..175aa194fb80f6317a84ba4f34105e172259293c 100644 --- a/paddle/fluid/distributed/ps/table/tensor_table.h +++ b/paddle/fluid/distributed/ps/table/tensor_table.h @@ -52,42 +52,42 @@ class TensorTable : public Table { virtual int32_t Pull(TableContext &context) { return 0; } virtual int32_t Push(TableContext &context) { return 0; } - int32_t pull_dense(float *values, size_t num) override { return 0; } + int32_t PullDense(float *values, size_t num) override { return 0; } - int32_t push_dense(const float *values, size_t num) override { return 0; } + int32_t PushDense(const float *values, size_t num) override { return 0; } - int32_t pull_sparse(float *values, - const PullSparseValue &pull_value) override { + int32_t PullSparse(float *values, + const PullSparseValue &pull_value) override { return 0; } - int32_t push_sparse(const uint64_t *keys, const float *values, - size_t num) override { + int32_t PushSparse(const uint64_t *keys, const float *values, + size_t num) override { return 0; } - int32_t shrink(const std::string ¶m) override { return 0; } + int32_t Shrink(const std::string ¶m) override { return 0; } - virtual void *get_shard(size_t shard_idx) { return 0; } + virtual void *GetShard(size_t shard_idx) { return 0; } - virtual int32_t initialize_shard() { return 0; } + virtual int32_t InitializeShard() { return 0; } - virtual int32_t flush() { return 0; } + virtual int32_t Flush() { return 0; } - virtual int32_t load(const std::string &path, const std::string ¶m) { + virtual int32_t Load(const std::string &path, const std::string ¶m) { return 0; } - virtual int32_t save(const std::string &path, const std::string ¶m) { + virtual int32_t Save(const std::string &path, const std::string ¶m) { return 0; } - virtual void clear() {} + virtual void Clear() {} - int32_t initialize() override { return 0; } + int32_t Initialize() override { return 0; } - int32_t push_dense(const int64_t *values, const int32_t trainer_id) override { + int32_t PushDense(const int64_t *values, const int32_t trainer_id) override { return 0; } - int32_t set_program_env( + int32_t SetProgramEnv( framework::Scope *scope, platform::Place place, const std::vector *sub_program) override { scope_ = scope; @@ -111,48 +111,48 @@ class DenseTensorTable : public TensorTable { DenseTensorTable() {} virtual ~DenseTensorTable() {} - int32_t pull_sparse(float *values, - const PullSparseValue &pull_value) override { + int32_t PullSparse(float *values, + const PullSparseValue &pull_value) override { return 0; } - int32_t push_sparse(const uint64_t *keys, const float *values, - size_t num) override { + int32_t PushSparse(const uint64_t *keys, const float *values, + size_t num) override { return 0; } - int32_t shrink(const std::string ¶m) override { return 0; } + int32_t Shrink(const std::string ¶m) override { return 0; } - virtual void *get_shard(size_t shard_idx) { return 0; } + virtual void *GetShard(size_t shard_idx) { return 0; } - virtual int32_t initialize_shard() { return 0; } + virtual int32_t InitializeShard() { return 0; } - virtual int32_t flush() { return 0; } + virtual int32_t Flush() { return 0; } - virtual void clear() {} + virtual void Clear() {} // Todo: Support program Load & Save - virtual int32_t load(const std::string &path, const std::string ¶m) { + virtual int32_t Load(const std::string &path, const std::string ¶m) { return 0; } - virtual int32_t save(const std::string &path, const std::string ¶m) { + virtual int32_t Save(const std::string &path, const std::string ¶m) { return 0; } // Todo: Support pull dense - int32_t pull_dense(float *values, size_t num) override { return 0; } + int32_t PullDense(float *values, size_t num) override { return 0; } /*----------------------------------------------------------------------*/ - int32_t initialize() override { return 0; } + int32_t Initialize() override { return 0; } - int32_t push_dense(const float *values, size_t num) override { return 0; } + int32_t PushDense(const float *values, size_t num) override { return 0; } - int32_t push_dense(const int64_t *values, const int32_t trainer_id) { + int32_t PushDense(const int64_t *values, const int32_t trainer_id) { return 0; } protected: - virtual int32_t _run_program(const float *values, size_t num, - const uint32_t trainer_id) { + virtual int32_t _RunProgram(const float *values, size_t num, + const uint32_t trainer_id) { return 0; } @@ -167,36 +167,36 @@ class GlobalStepTable : public DenseTensorTable { GlobalStepTable() {} virtual ~GlobalStepTable() {} - int32_t pull_sparse(float *values, - const PullSparseValue &pull_value) override { + int32_t PullSparse(float *values, + const PullSparseValue &pull_value) override { return 0; } - int32_t push_sparse(const uint64_t *keys, const float *values, - size_t num) override { + int32_t PushSparse(const uint64_t *keys, const float *values, + size_t num) override { return 0; } - int32_t shrink(const std::string ¶m) override { return 0; } + int32_t Shrink(const std::string ¶m) override { return 0; } - virtual void *get_shard(size_t shard_idx) { return 0; } + virtual void *GetShard(size_t shard_idx) { return 0; } - virtual int32_t initialize_shard() { return 0; } + virtual int32_t InitializeShard() { return 0; } - virtual int32_t flush() { return 0; } + virtual int32_t Flush() { return 0; } - virtual void clear() {} + virtual void Clear() {} - virtual int32_t load(const std::string &path, const std::string ¶m) { + virtual int32_t Load(const std::string &path, const std::string ¶m) { return 0; } - virtual int32_t save(const std::string &path, const std::string ¶m) { + virtual int32_t Save(const std::string &path, const std::string ¶m) { return 0; } - int32_t pull_dense(float *values, size_t num) override { return 0; } + int32_t PullDense(float *values, size_t num) override { return 0; } /*----------------------------------------------------------------------*/ - int32_t initialize() override { + int32_t Initialize() override { auto _program_config = _config.tensor(); auto trainers_ = _config.common().trainer_num(); FLAGS_eager_delete_tensor_gb = -1; @@ -237,14 +237,14 @@ class GlobalStepTable : public DenseTensorTable { } } - int32_t push_dense(const float *values, size_t num) override { return 0; } + int32_t PushDense(const float *values, size_t num) override { return 0; } - int32_t push_dense(const int64_t *values, const int32_t trainer_id) { - return _run_program(values, trainer_id); + int32_t PushDense(const int64_t *values, const int32_t trainer_id) { + return _RunProgram(values, trainer_id); } - int32_t set_table_map(std::unordered_map> - *table_map) override { + int32_t SetTableMap(std::unordered_map> + *table_map) override { auto *lr_var = scope_->FindVar(fetch_var_name_); auto *lr_tensor = lr_var->GetMutable(); auto *lr_value = lr_tensor->mutable_data(platform::CPUPlace()); @@ -255,14 +255,14 @@ class GlobalStepTable : public DenseTensorTable { if (table_id == _config.table_id()) { continue; } - iter->second->set_global_lr(lr_value); + iter->second->SetGlobalLR(lr_value); } return 0; } private: - virtual int32_t _run_program(const int64_t *values, - const uint32_t trainer_id) { + virtual int32_t _RunProgram(const int64_t *values, + const uint32_t trainer_id) { FLAGS_eager_delete_tensor_gb = -1; auto counter = decay_counters_.at(trainer_id); counter += int(values[0]); diff --git a/paddle/fluid/distributed/ps/wrapper/fleet.cc b/paddle/fluid/distributed/ps/wrapper/fleet.cc index c9093368c693e774657e4e1f2b688774df24ebd2..7bc50a868104a0d8d459a96b768e4f426630afe7 100644 --- a/paddle/fluid/distributed/ps/wrapper/fleet.cc +++ b/paddle/fluid/distributed/ps/wrapper/fleet.cc @@ -51,32 +51,6 @@ int32_t FleetWrapper::CopyTableByFeasign( return 0; } -void FleetWrapper::Stop() { StopServer(); } - -void FleetWrapper::Load(WrapperContext& context) { - auto table_id = context.table_id; - if (table_id >= 0 && context.meta != "") { - LoadSparseOnServer(context.path, context.meta, context.table_id); - return; - } - if (table_id < 0) { // laod all - LoadModel(context.path, context.mode); - } else { // load one table - LoadModelOneTable(table_id, context.path, context.mode); - } - return; -} - -void FleetWrapper::Save(WrapperContext& context) { - auto table_id = context.table_id; - if (table_id < 0) { - SaveModel(context.path, context.mode); - } else { - SaveModelOneTable(table_id, context.path, context.mode); - } - return; -} - void FleetWrapper::SetClient2ClientConfig(int request_timeout_ms, int connect_timeout_ms, int max_retry) { @@ -90,7 +64,7 @@ void FleetWrapper::LoadSparseOnServer(const std::string& path, uint32_t table_id) { VLOG(3) << "load sparse table " << table_id << " with " << path << " meta " << meta; - pserver_ptr_->_server_ptr->table(table_id)->load(path, meta); + pserver_ptr_->_server_ptr->GetTable(table_id)->Load(path, meta); } void FleetWrapper::InitServer( @@ -101,8 +75,8 @@ void FleetWrapper::InitServer( VLOG(3) << "Going to init server"; pserver_ptr_ = std::shared_ptr( new paddle::distributed::PSCore()); - pserver_ptr_->init_server(dist_desc, &host_sign_list, host_sign_list.size(), - index, trainers, server_sub_program); + pserver_ptr_->InitServer(dist_desc, &host_sign_list, host_sign_list.size(), + index, trainers, server_sub_program); is_initialized_ = true; } else { VLOG(3) << "Server can be initialized only once"; @@ -143,10 +117,10 @@ void FleetWrapper::InitWorker(const std::string& dist_desc, google::protobuf::TextFormat::ParseFromString(dist_desc, &ps_param); InitGFlag(ps_param.init_gflags()); int servers = host_sign_list.size(); - ps_env_.set_ps_servers(&host_sign_list, servers); + ps_env_.SetPsServers(&host_sign_list, servers); worker_ptr_ = std::shared_ptr( - paddle::distributed::PSClientFactory::create(ps_param)); - worker_ptr_->configure(ps_param, dense_pull_regions, ps_env_, index); + paddle::distributed::PSClientFactory::Create(ps_param)); + worker_ptr_->Configure(ps_param, dense_pull_regions, ps_env_, index); } } else { VLOG(3) << "Client can be initialized only once"; @@ -155,13 +129,13 @@ void FleetWrapper::InitWorker(const std::string& dist_desc, void FleetWrapper::StopServer() { VLOG(3) << "Going to stop server"; - auto status = worker_ptr_->stop_server(); + auto status = worker_ptr_->StopServer(); status.wait(); } void FleetWrapper::FinalizeWorker() { VLOG(3) << "Going to finalize worker"; - worker_ptr_->finalize_worker(); + worker_ptr_->FinalizeWorker(); } void FleetWrapper::BarrierWithTable(uint32_t barrier_type) { @@ -172,13 +146,13 @@ void FleetWrapper::BarrierWithTable(uint32_t barrier_type) { uint64_t FleetWrapper::RunServer(const std::string& ip, uint32_t port) { VLOG(3) << "Going to run server with ip " << ip << " port " << port; - auto ret = pserver_ptr_->run_server(ip, port); + auto ret = pserver_ptr_->RunServer(ip, port); return ret; } std::vector FleetWrapper::GetClientsInfo() { VLOG(3) << "Going to get client info"; - std::vector res = ps_env_.get_client_info(); + std::vector res = ps_env_.GetClientInfo(); for (auto rr : res) { VLOG(2) << "FleetWrapper::GetClientInfo " << rr; } @@ -187,14 +161,14 @@ std::vector FleetWrapper::GetClientsInfo() { int FleetWrapper::SetClients(std::vector& host_sign_list) { int node = host_sign_list.size(); - return ps_env_.set_ps_clients(host_sign_list.data(), node); + return ps_env_.SetPsClients(host_sign_list.data(), node); } void FleetWrapper::CreateClient2ClientConnection() { VLOG(1) << "Going to create client2client connection"; - worker_ptr_->create_client2client_connection( - client2client_request_timeout_ms_, client2client_connect_timeout_ms_, - client2client_max_retry_); + worker_ptr_->CreateClient2ClientConnection(client2client_request_timeout_ms_, + client2client_connect_timeout_ms_, + client2client_max_retry_); } std::future FleetWrapper::PullSparseVarsAsync( @@ -230,9 +204,9 @@ std::future FleetWrapper::PullSparseVarsAsync( } bool training = true; - return pserver_ptr_->_worker_ptr->pull_sparse(pull_result_ptr.data(), - table_id, fea_keys->data(), - fea_keys->size(), training); + return pserver_ptr_->_worker_ptr->PullSparse(pull_result_ptr.data(), table_id, + fea_keys->data(), + fea_keys->size(), training); } void FleetWrapper::PullSparseVarsSync( @@ -279,7 +253,7 @@ void FleetWrapper::PullSparseVarsSync( pull_result_ptr.push_back(t.data()); } bool training = true; - auto status = pserver_ptr_->_worker_ptr->pull_sparse( + auto status = pserver_ptr_->_worker_ptr->PullSparse( pull_result_ptr.data(), table_id, fea_keys->data(), fea_keys->size(), training); pull_sparse_status.push_back(std::move(status)); @@ -337,21 +311,10 @@ void FleetWrapper::PullSparseToTensorSync(const uint64_t table_id, int fea_dim, pull_result_ptr.push_back(output_data + output_len); } } - // ps client pull sparse - // construct client request context - RequestContext req_context; - req_context.value_type = Sparse; - req_context.training_mode = Async; - req_context.table = table_id; - req_context.sparse_values = pull_result_ptr.data(); - req_context.keys = fea_keys.data(); - req_context.num = fea_keys.size(); - req_context.is_training = is_training; - auto status = worker_ptr_->Pull(req_context); - // auto status = - // worker_ptr_->pull_sparse(pull_result_ptr.data(), table_id, - // fea_keys.data(), fea_keys.size(), - // is_training); + + auto status = + worker_ptr_->PullSparse(pull_result_ptr.data(), table_id, fea_keys.data(), + fea_keys.size(), is_training); status.wait(); auto ret = status.get(); if (ret != 0) { @@ -364,7 +327,7 @@ void FleetWrapper::PullDenseVarsAsync( const Scope& scope, const uint64_t tid, const std::vector& var_names, std::vector>* pull_dense_status, bool in_cpu) { - auto& regions = _regions[tid]; + auto& regions = regions_[tid]; regions.clear(); regions.resize(var_names.size()); for (auto i = 0u; i < var_names.size(); ++i) { @@ -378,21 +341,15 @@ void FleetWrapper::PullDenseVarsAsync( paddle::distributed::Region reg(w, tensor->numel()); regions[i] = std::move(reg); } - RequestContext req_context; - req_context.value_type = Dense; - req_context.training_mode = Async; - req_context.table = tid; - req_context.dense_values = regions.data(); - req_context.num = regions.size(); - auto status = worker_ptr_->Pull(req_context); - // auto status = worker_ptr_->pull_dense(regions.data(), regions.size(), tid); + + auto status = worker_ptr_->PullDense(regions.data(), regions.size(), tid); pull_dense_status->push_back(std::move(status)); } void FleetWrapper::PullDenseVarsSync( const Scope& scope, const uint64_t tid, const std::vector& var_names) { - auto& regions = _regions[tid]; + auto& regions = regions_[tid]; regions.clear(); regions.reserve(var_names.size()); for (auto& t : var_names) { @@ -404,7 +361,7 @@ void FleetWrapper::PullDenseVarsSync( regions.emplace_back(std::move(reg)); } } - auto status = worker_ptr_->pull_dense(regions.data(), regions.size(), tid); + auto status = worker_ptr_->PullDense(regions.data(), regions.size(), tid); status.wait(); } @@ -424,7 +381,7 @@ void FleetWrapper::PushDenseParamSync( } } auto push_status = - worker_ptr_->push_dense_param(regions.data(), regions.size(), table_id); + worker_ptr_->PushDenseParam(regions.data(), regions.size(), table_id); push_status.wait(); auto status = push_status.get(); CHECK(status == 0) << "push dense param failed, status[" << status << "]"; @@ -470,15 +427,8 @@ void FleetWrapper::PushDenseVarsAsync( << g[tensor->numel() - 1]; } - RequestContext req_context; - req_context.value_type = Dense; - req_context.training_mode = Async; - req_context.table = table_id; - req_context.push_context.push_dense_values = regions.data(); - req_context.num = regions.size(); - // auto push_status = - // worker_ptr_->push_dense(regions.data(), regions.size(), table_id); - auto push_status = worker_ptr_->Push(req_context); + auto push_status = + worker_ptr_->PushDense(regions.data(), regions.size(), table_id); } void FleetWrapper::PushSparseVarsAsync( @@ -650,23 +600,13 @@ void FleetWrapper::PushSparseFromTensorAsync( push_g_vec[i] = push_values.at(i).data(); } - // ps client push sparse - // construct request context - RequestContext req_context; - req_context.value_type = Sparse; - req_context.training_mode = Async; - req_context.table = table_id; - req_context.push_context.push_values = (const float**)push_g_vec.data(); - req_context.push_context.keys = push_keys.data(); - req_context.num = push_keys.size(); - auto status = worker_ptr_->Push(req_context); - // auto status = worker_ptr_->push_sparse(table_id, push_keys.data(), - // (const float**)push_g_vec.data(), - // push_keys.size()); + auto status = worker_ptr_->PushSparse(table_id, push_keys.data(), + (const float**)push_g_vec.data(), + push_keys.size()); } void FleetWrapper::LoadModel(const std::string& path, const int mode) { - auto ret = worker_ptr_->load(path, std::to_string(mode)); + auto ret = worker_ptr_->Load(path, std::to_string(mode)); ret.wait(); if (ret.get() != 0) { LOG(ERROR) << "load model from path:" << path << " failed"; @@ -675,7 +615,7 @@ void FleetWrapper::LoadModel(const std::string& path, const int mode) { void FleetWrapper::LoadModelOneTable(const uint64_t table_id, const std::string& path, const int mode) { - auto ret = worker_ptr_->load(table_id, path, std::to_string(mode)); + auto ret = worker_ptr_->Load(table_id, path, std::to_string(mode)); ret.wait(); if (ret.get() != 0) { LOG(ERROR) << "load model of table id: " << table_id @@ -684,7 +624,7 @@ void FleetWrapper::LoadModelOneTable(const uint64_t table_id, } void FleetWrapper::SaveModel(const std::string& path, const int mode) { - auto ret = worker_ptr_->save(path, std::to_string(mode)); + auto ret = worker_ptr_->Save(path, std::to_string(mode)); ret.wait(); int32_t feasign_cnt = ret.get(); if (feasign_cnt == -1) { @@ -694,7 +634,7 @@ void FleetWrapper::SaveModel(const std::string& path, const int mode) { void FleetWrapper::SaveModelOneTable(const uint64_t table_id, const std::string& path, const int mode) { - auto ret = worker_ptr_->save(table_id, path, std::to_string(mode)); + auto ret = worker_ptr_->Save(table_id, path, std::to_string(mode)); ret.wait(); if (ret.get() != 0) { LOG(ERROR) << "save model of table id: " << table_id @@ -704,7 +644,7 @@ void FleetWrapper::SaveModelOneTable(const uint64_t table_id, void FleetWrapper::RecvAndSaveTable(const uint64_t table_id, const std::string& path) { - auto ret = worker_ptr_->recv_and_save_table(table_id, path); + auto ret = worker_ptr_->RecvAndSaveTable(table_id, path); if (ret != 0) { LOG(ERROR) << "save model of table id: " << table_id << ", to path: " << path << " failed"; @@ -712,7 +652,7 @@ void FleetWrapper::RecvAndSaveTable(const uint64_t table_id, } void FleetWrapper::PrintTableStat(const uint64_t table_id) { - auto ret = worker_ptr_->print_table_stat(table_id); + auto ret = worker_ptr_->PrintTableStat(table_id); ret.wait(); int32_t err_code = ret.get(); if (err_code == -1) { @@ -721,7 +661,7 @@ void FleetWrapper::PrintTableStat(const uint64_t table_id) { } void FleetWrapper::ShrinkSparseTable(int table_id, int threshold) { - auto ret = worker_ptr_->shrink(table_id, std::to_string(threshold)); + auto ret = worker_ptr_->Shrink(table_id, std::to_string(threshold)); ret.wait(); int32_t err_code = ret.get(); if (err_code == -1) { @@ -730,12 +670,12 @@ void FleetWrapper::ShrinkSparseTable(int table_id, int threshold) { } void FleetWrapper::ClearModel() { - auto ret = pserver_ptr_->_worker_ptr->clear(); + auto ret = pserver_ptr_->_worker_ptr->Clear(); ret.wait(); } void FleetWrapper::ClearOneTable(const uint64_t table_id) { - auto ret = pserver_ptr_->_worker_ptr->clear(table_id); + auto ret = pserver_ptr_->_worker_ptr->Clear(table_id); ret.wait(); } @@ -774,7 +714,7 @@ void FleetWrapper::ShrinkDenseTable(int table_id, Scope* scope, regions.emplace_back(std::move(reg)); } } - auto push_status = pserver_ptr_->_worker_ptr->push_dense_param( + auto push_status = pserver_ptr_->_worker_ptr->PushDenseParam( regions.data(), regions.size(), table_id); push_status.wait(); auto status = push_status.get(); @@ -791,7 +731,7 @@ void FleetWrapper::ClientFlush() { VLOG(0) << "worker_ptr null, do nothing"; return; } - auto ret = worker_ptr_->flush(); + auto ret = worker_ptr_->Flush(); ret.wait(); int32_t err_code = ret.get(); if (err_code == -1) { @@ -805,13 +745,13 @@ int FleetWrapper::RegisterClientToClientMsgHandler(int msg_type, VLOG(0) << "FleetWrapper::Client is null"; return -1; } else { - return worker_ptr_->registe_client2client_msg_handler(msg_type, handler); + return worker_ptr_->RegisteClient2ClientMsgHandler(msg_type, handler); } } std::future FleetWrapper::SendClientToClientMsg( int msg_type, int to_client_id, const std::string& msg) { - return worker_ptr_->send_client2client_msg(msg_type, to_client_id, msg); + return worker_ptr_->SendClient2ClientMsg(msg_type, to_client_id, msg); } std::default_random_engine& FleetWrapper::LocalRandomEngine() { diff --git a/paddle/fluid/distributed/ps/wrapper/fleet.h b/paddle/fluid/distributed/ps/wrapper/fleet.h index 13b7ea7609ee6a90df67756d921409359b348ade..e6ec09a12637d9d8d6da18ce45d0fd70dd45db7c 100644 --- a/paddle/fluid/distributed/ps/wrapper/fleet.h +++ b/paddle/fluid/distributed/ps/wrapper/fleet.h @@ -25,7 +25,6 @@ limitations under the License. */ #include "paddle/fluid/distributed/ps/service/communicator/communicator_common.h" #include "paddle/fluid/distributed/ps/service/ps_service/service.h" -#include "paddle/fluid/distributed/ps/wrapper/ps_wrapper.h" #include "paddle/fluid/framework/archive.h" #include "paddle/fluid/framework/io/fs.h" #include "paddle/fluid/framework/io/shell.h" @@ -55,7 +54,7 @@ using framework::Variable; using RpcCtxMap = std::unordered_map; -class FleetWrapper : public PSWrapper { +class FleetWrapper { public: virtual ~FleetWrapper() {} FleetWrapper() { @@ -69,7 +68,6 @@ class FleetWrapper : public PSWrapper { // pserver request max retry client2client_max_retry_ = 3; } - virtual int32_t Initialize(InitContext& context) { return 0; } // TODO(zhaocaibei123: later) int32_t CopyTable(const uint64_t src_table_id, const uint64_t dest_table_id); @@ -81,12 +79,6 @@ class FleetWrapper : public PSWrapper { typedef std::function HeterCallBackFunc; int RegisterHeterCallback(HeterCallBackFunc handler); - virtual void Stop() override; - - virtual void Load(WrapperContext& context) override; - - virtual void Save(WrapperContext& context) override; - // set client to client communication config void SetClient2ClientConfig(int request_timeout_ms, int connect_timeout_ms, int max_retry); @@ -278,7 +270,7 @@ class FleetWrapper : public PSWrapper { protected: static bool is_initialized_; - std::map> _regions; + std::map> regions_; bool scale_sparse_gradient_with_batch_size_; int32_t sleep_seconds_before_fail_exit_; int client2client_request_timeout_ms_; diff --git a/paddle/fluid/distributed/test/barrier_table_test.cc b/paddle/fluid/distributed/test/barrier_table_test.cc index 0715f777fa5cb286ff393190a3d94dd86e74518a..c4c5b229928049d14f494ce941aac4b2ed775415 100644 --- a/paddle/fluid/distributed/test/barrier_table_test.cc +++ b/paddle/fluid/distributed/test/barrier_table_test.cc @@ -39,19 +39,19 @@ TEST(BarrierTable, Barrier) { common_config->set_trainer_num(trainers); common_config->set_sync(sync); - auto ret = table->initialize(table_config, fs_config); + auto ret = table->Initialize(table_config, fs_config); std::unordered_map> maps = std::unordered_map>(); - table->set_table_map(&maps); + table->SetTableMap(&maps); std::shared_ptr<::ThreadPool> pool_ = std::make_shared<::ThreadPool>(trainers); std::vector> task_status; for (auto x = 0; x < trainers; x++) { - auto task = [table, x] { table->barrier(x, 0); }; + auto task = [table, x] { table->Barrier(x, 0); }; task_status.push_back(pool_->enqueue(std::move(task))); } diff --git a/paddle/fluid/distributed/test/brpc_service_dense_sgd_test.cc b/paddle/fluid/distributed/test/brpc_service_dense_sgd_test.cc index 19ff50ec2a43ba5d4d69f34c2a368b3b9720cae6..d5e196ff3219f1dcc5d26a41b48a0b8446437d37 100644 --- a/paddle/fluid/distributed/test/brpc_service_dense_sgd_test.cc +++ b/paddle/fluid/distributed/test/brpc_service_dense_sgd_test.cc @@ -155,16 +155,16 @@ void RunServer() { auto _ps_env = paddle::distributed::PaddlePSEnvironment(); LOG(INFO) << "RUN set_ps_servers"; - _ps_env.set_ps_servers(&host_sign_list_, 1); + _ps_env.SetPsServers(&host_sign_list_, 1); pserver_ptr_ = std::shared_ptr( - paddle::distributed::PSServerFactory::create(server_proto)); + paddle::distributed::PSServerFactory::Create(server_proto)); LOG(INFO) << "RUN configure"; std::vector empty_vec; framework::ProgramDesc empty_prog; empty_vec.push_back(empty_prog); - pserver_ptr_->configure(server_proto, _ps_env, 0, empty_vec); + pserver_ptr_->Configure(server_proto, _ps_env, 0, empty_vec); LOG(INFO) << "RUN start"; - pserver_ptr_->start(ip_, port_); + pserver_ptr_->Start(ip_, port_); LOG(INFO) << "End start"; } @@ -175,19 +175,19 @@ void RunClient(std::map>& auto servers_ = host_sign_list_.size(); _ps_env = paddle::distributed::PaddlePSEnvironment(); LOG(INFO) << "Run set_ps_servers"; - _ps_env.set_ps_servers(&host_sign_list_, servers_); + _ps_env.SetPsServers(&host_sign_list_, servers_); LOG(INFO) << "Run Create PSClient"; worker_ptr_ = std::shared_ptr( - paddle::distributed::PSClientFactory::create(worker_proto)); + paddle::distributed::PSClientFactory::Create(worker_proto)); LOG(INFO) << "Run configure"; - worker_ptr_->configure(worker_proto, dense_regions, _ps_env, 0); + worker_ptr_->Configure(worker_proto, dense_regions, _ps_env, 0); } void RunBrpcPushDense() { setenv("http_proxy", "", 1); setenv("https_proxy", "", 1); auto ph_host = paddle::distributed::PSHost(ip_, port_, 0); - host_sign_list_.push_back(ph_host.serialize_to_string()); + host_sign_list_.push_back(ph_host.SerializeToString()); // Srart Server std::thread server_thread(RunServer); @@ -218,7 +218,7 @@ void RunBrpcPushDense() { paddle::distributed::Region temp_reg(temp, tensor->numel()); temp_region.emplace_back(std::move(temp_reg)); auto pull_status = - worker_ptr_->pull_dense(temp_region.data(), temp_region.size(), 0); + worker_ptr_->PullDense(temp_region.data(), temp_region.size(), 0); pull_status.wait(); for (size_t idx = 0; idx < tensor->numel(); ++idx) { @@ -229,10 +229,10 @@ void RunBrpcPushDense() { LOG(INFO) << "Run push_dense_param"; auto push_status = - worker_ptr_->push_dense_param(regions.data(), regions.size(), 0); + worker_ptr_->PushDenseParam(regions.data(), regions.size(), 0); push_status.wait(); - pull_status = worker_ptr_->pull_dense(regions.data(), regions.size(), 0); + pull_status = worker_ptr_->PullDense(regions.data(), regions.size(), 0); pull_status.wait(); for (size_t idx = 0; idx < tensor->numel(); ++idx) { @@ -257,11 +257,11 @@ void RunBrpcPushDense() { LOG(INFO) << "Run pull_dense_grad"; auto push_grad_status = - worker_ptr_->push_dense_raw_gradient(0, temp, tensor->numel(), closure); + worker_ptr_->PushDenseRawGradient(0, temp, tensor->numel(), closure); push_grad_status.wait(); auto pull_update_status = - worker_ptr_->pull_dense(regions.data(), regions.size(), 0); + worker_ptr_->PullDense(regions.data(), regions.size(), 0); pull_update_status.wait(); for (size_t idx = 0; idx < tensor->numel(); ++idx) { @@ -269,9 +269,9 @@ void RunBrpcPushDense() { } LOG(INFO) << "Run stop_server"; - worker_ptr_->stop_server(); + worker_ptr_->StopServer(); LOG(INFO) << "Run finalize_worker"; - worker_ptr_->finalize_worker(); + worker_ptr_->FinalizeWorker(); server_thread.join(); } diff --git a/paddle/fluid/distributed/test/brpc_service_sparse_sgd_test.cc b/paddle/fluid/distributed/test/brpc_service_sparse_sgd_test.cc index 633f3b2f3c55006ddf0501d4288d149c535629e0..f7d287af8447296dbdf676c4fbf8a676a9f417e9 100644 --- a/paddle/fluid/distributed/test/brpc_service_sparse_sgd_test.cc +++ b/paddle/fluid/distributed/test/brpc_service_sparse_sgd_test.cc @@ -156,14 +156,14 @@ void RunServer() { ::paddle::distributed::PSParameter server_proto = GetServerProto(); auto _ps_env = paddle::distributed::PaddlePSEnvironment(); - _ps_env.set_ps_servers(&host_sign_list_, 1); + _ps_env.SetPsServers(&host_sign_list_, 1); pserver_ptr_ = std::shared_ptr( - paddle::distributed::PSServerFactory::create(server_proto)); + paddle::distributed::PSServerFactory::Create(server_proto)); std::vector empty_vec; framework::ProgramDesc empty_prog; empty_vec.push_back(empty_prog); - pserver_ptr_->configure(server_proto, _ps_env, 0, empty_vec); - pserver_ptr_->start(ip_, port_); + pserver_ptr_->Configure(server_proto, _ps_env, 0, empty_vec); + pserver_ptr_->Start(ip_, port_); } void RunClient(std::map>& @@ -172,17 +172,17 @@ void RunClient(std::map>& paddle::distributed::PaddlePSEnvironment _ps_env; auto servers_ = host_sign_list_.size(); _ps_env = paddle::distributed::PaddlePSEnvironment(); - _ps_env.set_ps_servers(&host_sign_list_, servers_); + _ps_env.SetPsServers(&host_sign_list_, servers_); worker_ptr_ = std::shared_ptr( - paddle::distributed::PSClientFactory::create(worker_proto)); - worker_ptr_->configure(worker_proto, dense_regions, _ps_env, 0); + paddle::distributed::PSClientFactory::Create(worker_proto)); + worker_ptr_->Configure(worker_proto, dense_regions, _ps_env, 0); } void RunBrpcPushSparse() { setenv("http_proxy", "", 1); setenv("https_proxy", "", 1); auto ph_host = paddle::distributed::PSHost(ip_, port_, 0); - host_sign_list_.push_back(ph_host.serialize_to_string()); + host_sign_list_.push_back(ph_host.SerializeToString()); // Srart Server std::thread server_thread(RunServer); @@ -214,7 +214,7 @@ void RunBrpcPushSparse() { /*-----------------------Test Server Init----------------------------------*/ LOG(INFO) << "Run pull_sparse_param"; - auto pull_status = worker_ptr_->pull_sparse( + auto pull_status = worker_ptr_->PullSparse( fea_value_ptr.data(), 0, fea_keys.data(), fea_keys.size(), true); pull_status.wait(); for (size_t idx = 0; idx < tensor->numel(); ++idx) { @@ -237,12 +237,12 @@ void RunBrpcPushSparse() { } closure->set_promise_value(ret); }); - auto push_status = worker_ptr_->push_sparse_param( + auto push_status = worker_ptr_->PushSparseParam( 0, fea_keys.data(), (const float**)fea_value_ptr.data(), fea_keys.size(), closure_push_param); push_status.wait(); - auto pull_param_status = worker_ptr_->pull_sparse( + auto pull_param_status = worker_ptr_->PullSparse( fea_temp_value_ptr.data(), 0, fea_keys.data(), fea_keys.size(), true); pull_param_status.wait(); @@ -271,12 +271,12 @@ void RunBrpcPushSparse() { for (auto i = 0; i < static_cast(fea_keys.size()); ++i) { push_g_vec.push_back(tensor->data() + i * 10); } - auto push_grad_status = worker_ptr_->push_sparse_raw_gradient( + auto push_grad_status = worker_ptr_->PushSparseRawGradient( 0, fea_keys.data(), (const float**)push_g_vec.data(), fea_keys.size(), closure_push_grad); push_grad_status.wait(); - auto pull_update_status = worker_ptr_->pull_sparse( + auto pull_update_status = worker_ptr_->PullSparse( fea_temp_value_ptr.data(), 0, fea_keys.data(), fea_keys.size(), true); pull_update_status.wait(); @@ -285,9 +285,9 @@ void RunBrpcPushSparse() { } LOG(INFO) << "Run stop_server"; - worker_ptr_->stop_server(); + worker_ptr_->StopServer(); LOG(INFO) << "Run finalize_worker"; - worker_ptr_->finalize_worker(); + worker_ptr_->FinalizeWorker(); server_thread.join(); } diff --git a/paddle/fluid/distributed/test/dense_table_test.cc b/paddle/fluid/distributed/test/dense_table_test.cc index c9a038e000e149f354db2bab72b48c04a721a5f6..49346c2898fc6ca32e943454036d38ffba3be1de 100644 --- a/paddle/fluid/distributed/test/dense_table_test.cc +++ b/paddle/fluid/distributed/test/dense_table_test.cc @@ -63,13 +63,13 @@ TEST(CommonDenseTable, Adam) { common_config->add_params("LearningRate"); common_config->add_dims(1); common_config->add_initializers("fill_constant&5e-6"); - auto ret = table->initialize(table_config, fs_config); + auto ret = table->Initialize(table_config, fs_config); ASSERT_EQ(ret, 0); // pull parameters for create and check std::vector init_values; init_values.resize(fea_dim); - table->pull_dense(init_values.data(), fea_dim); + table->PullDense(init_values.data(), fea_dim); // push gradient std::vector> trainer_gradient_values; @@ -85,12 +85,12 @@ TEST(CommonDenseTable, Adam) { // for adam for (int i = 0; i < trainers; i++) { auto &push_values = trainer_gradient_values[i]; - table->push_dense(push_values.data(), push_values.size()); + table->PushDense(push_values.data(), push_values.size()); } std::vector pull_values; pull_values.resize(fea_dim); - table->pull_dense(pull_values.data(), fea_dim); + table->PullDense(pull_values.data(), fea_dim); float mom_rate = 0.99; float decay_rate = 0.9999; @@ -118,6 +118,7 @@ TEST(CommonDenseTable, Adam) { } } for (int j = 0; j < fea_dim; j++) { + VLOG(0) << param[j] << " " << pull_values[j]; ASSERT_TRUE(abs(param[j] - pull_values[j]) < 1e-5); } } @@ -143,13 +144,13 @@ TEST(CommonDenseTable, SGD) { common_config->add_params("LearningRate"); common_config->add_dims(1); common_config->add_initializers("fill_constant&1.0"); - auto ret = table->initialize(table_config, fs_config); + auto ret = table->Initialize(table_config, fs_config); ASSERT_EQ(ret, 0); // pull parameters for create and check std::vector init_values; init_values.resize(fea_dim); - table->pull_dense(init_values.data(), fea_dim); + table->PullDense(init_values.data(), fea_dim); std::vector total_gradients; total_gradients.resize(fea_dim); @@ -172,7 +173,7 @@ TEST(CommonDenseTable, SGD) { for (int i = 0; i < trainers; i++) { auto &push_values = trainer_gradient_values[i]; auto task = [table, &push_values] { - table->push_dense(push_values.data(), push_values.size()); + table->PushDense(push_values.data(), push_values.size()); }; task_status.push_back(pool_->enqueue(std::move(task))); } @@ -182,7 +183,7 @@ TEST(CommonDenseTable, SGD) { std::vector pull_values; pull_values.resize(fea_dim); - table->pull_dense(pull_values.data(), fea_dim); + table->PullDense(pull_values.data(), fea_dim); for (int j = 0; j < fea_dim; j++) { auto update_val = init_values[j] - 1.0 * total_gradients[j]; ASSERT_TRUE(abs(update_val - pull_values[j]) < 1e-5); diff --git a/paddle/fluid/distributed/test/graph_node_split_test.cc b/paddle/fluid/distributed/test/graph_node_split_test.cc index a2f495de3c953a418f6e9c57a0535264eb401e65..ce4f38f6cec9f5bfdf8ddd45ec5414d9173b70bc 100644 --- a/paddle/fluid/distributed/test/graph_node_split_test.cc +++ b/paddle/fluid/distributed/test/graph_node_split_test.cc @@ -166,16 +166,16 @@ void RunServer() { ::paddle::distributed::PSParameter server_proto = GetServerProto(); auto _ps_env = paddle::distributed::PaddlePSEnvironment(); - _ps_env.set_ps_servers(&host_sign_list_, 2); // test + _ps_env.SetPsServers(&host_sign_list_, 2); // test pserver_ptr_ = std::shared_ptr( (paddle::distributed::GraphBrpcServer*) - paddle::distributed::PSServerFactory::create(server_proto)); + paddle::distributed::PSServerFactory::Create(server_proto)); std::vector empty_vec; framework::ProgramDesc empty_prog; empty_vec.push_back(empty_prog); - pserver_ptr_->configure(server_proto, _ps_env, 0, empty_vec); + pserver_ptr_->Configure(server_proto, _ps_env, 0, empty_vec); LOG(INFO) << "first server, run start(ip,port)"; - pserver_ptr_->start(ip_, port_); + pserver_ptr_->Start(ip_, port_); pserver_ptr_->build_peer2peer_connection(0); LOG(INFO) << "init first server Done"; } @@ -185,15 +185,15 @@ void RunServer2() { ::paddle::distributed::PSParameter server_proto2 = GetServerProto(); auto _ps_env2 = paddle::distributed::PaddlePSEnvironment(); - _ps_env2.set_ps_servers(&host_sign_list_, 2); // test + _ps_env2.SetPsServers(&host_sign_list_, 2); // test pserver_ptr2 = std::shared_ptr( (paddle::distributed::GraphBrpcServer*) - paddle::distributed::PSServerFactory::create(server_proto2)); + paddle::distributed::PSServerFactory::Create(server_proto2)); std::vector empty_vec2; framework::ProgramDesc empty_prog2; empty_vec2.push_back(empty_prog2); - pserver_ptr2->configure(server_proto2, _ps_env2, 1, empty_vec2); - pserver_ptr2->start(ip2, port2); + pserver_ptr2->Configure(server_proto2, _ps_env2, 1, empty_vec2); + pserver_ptr2->Start(ip2, port2); pserver_ptr2->build_peer2peer_connection(1); } @@ -204,11 +204,11 @@ void RunClient( paddle::distributed::PaddlePSEnvironment _ps_env; auto servers_ = host_sign_list_.size(); _ps_env = paddle::distributed::PaddlePSEnvironment(); - _ps_env.set_ps_servers(&host_sign_list_, servers_); + _ps_env.SetPsServers(&host_sign_list_, servers_); worker_ptr_ = std::shared_ptr( (paddle::distributed::GraphBrpcClient*) - paddle::distributed::PSClientFactory::create(worker_proto)); - worker_ptr_->configure(worker_proto, dense_regions, _ps_env, 0); + paddle::distributed::PSClientFactory::Create(worker_proto)); + worker_ptr_->Configure(worker_proto, dense_regions, _ps_env, 0); worker_ptr_->set_shard_num(127); worker_ptr_->set_local_channel(index); worker_ptr_->set_local_graph_service( @@ -222,11 +222,11 @@ void RunGraphSplit() { prepare_file(node_file_name, nodes); prepare_file(graph_split_file_name, graph_split); auto ph_host = paddle::distributed::PSHost(ip_, port_, 0); - host_sign_list_.push_back(ph_host.serialize_to_string()); + host_sign_list_.push_back(ph_host.SerializeToString()); // test-start auto ph_host2 = paddle::distributed::PSHost(ip2, port2, 1); - host_sign_list_.push_back(ph_host2.serialize_to_string()); + host_sign_list_.push_back(ph_host2.SerializeToString()); // test-end // Srart Server std::thread* server_thread = new std::thread(RunServer); @@ -247,7 +247,7 @@ void RunGraphSplit() { 0, std::string(graph_split_file_name)); pull_status.wait(); pull_status = - worker_ptr_->load(0, std::string(edge_file_name), std::string("e>")); + worker_ptr_->Load(0, std::string(edge_file_name), std::string("e>")); srand(time(0)); pull_status.wait(); std::vector> _vs; @@ -266,9 +266,9 @@ void RunGraphSplit() { std::remove(node_file_name); std::remove(graph_split_file_name); LOG(INFO) << "Run stop_server"; - worker_ptr_->stop_server(); + worker_ptr_->StopServer(); LOG(INFO) << "Run finalize_worker"; - worker_ptr_->finalize_worker(); + worker_ptr_->FinalizeWorker(); } TEST(RunGraphSplit, Run) { RunGraphSplit(); } diff --git a/paddle/fluid/distributed/test/graph_node_test.cc b/paddle/fluid/distributed/test/graph_node_test.cc index e55d39cd4834d425025a8084eb88982ef543a6f1..b2c741df7a5ddd6ac668d22fd8a160455b3222fd 100644 --- a/paddle/fluid/distributed/test/graph_node_test.cc +++ b/paddle/fluid/distributed/test/graph_node_test.cc @@ -348,16 +348,16 @@ void RunServer() { ::paddle::distributed::PSParameter server_proto = GetServerProto(); auto _ps_env = paddle::distributed::PaddlePSEnvironment(); - _ps_env.set_ps_servers(&host_sign_list_, 2); // test + _ps_env.SetPsServers(&host_sign_list_, 2); // test pserver_ptr_ = std::shared_ptr( (paddle::distributed::GraphBrpcServer*) - paddle::distributed::PSServerFactory::create(server_proto)); + paddle::distributed::PSServerFactory::Create(server_proto)); std::vector empty_vec; framework::ProgramDesc empty_prog; empty_vec.push_back(empty_prog); - pserver_ptr_->configure(server_proto, _ps_env, 0, empty_vec); + pserver_ptr_->Configure(server_proto, _ps_env, 0, empty_vec); LOG(INFO) << "first server, run start(ip,port)"; - pserver_ptr_->start(ip_, port_); + pserver_ptr_->Start(ip_, port_); pserver_ptr_->build_peer2peer_connection(0); LOG(INFO) << "init first server Done"; } @@ -367,15 +367,15 @@ void RunServer2() { ::paddle::distributed::PSParameter server_proto2 = GetServerProto(); auto _ps_env2 = paddle::distributed::PaddlePSEnvironment(); - _ps_env2.set_ps_servers(&host_sign_list_, 2); // test + _ps_env2.SetPsServers(&host_sign_list_, 2); // test pserver_ptr2 = std::shared_ptr( (paddle::distributed::GraphBrpcServer*) - paddle::distributed::PSServerFactory::create(server_proto2)); + paddle::distributed::PSServerFactory::Create(server_proto2)); std::vector empty_vec2; framework::ProgramDesc empty_prog2; empty_vec2.push_back(empty_prog2); - pserver_ptr2->configure(server_proto2, _ps_env2, 1, empty_vec2); - pserver_ptr2->start(ip2, port2); + pserver_ptr2->Configure(server_proto2, _ps_env2, 1, empty_vec2); + pserver_ptr2->Start(ip2, port2); pserver_ptr2->build_peer2peer_connection(1); } @@ -386,11 +386,11 @@ void RunClient( paddle::distributed::PaddlePSEnvironment _ps_env; auto servers_ = host_sign_list_.size(); _ps_env = paddle::distributed::PaddlePSEnvironment(); - _ps_env.set_ps_servers(&host_sign_list_, servers_); + _ps_env.SetPsServers(&host_sign_list_, servers_); worker_ptr_ = std::shared_ptr( (paddle::distributed::GraphBrpcClient*) - paddle::distributed::PSClientFactory::create(worker_proto)); - worker_ptr_->configure(worker_proto, dense_regions, _ps_env, 0); + paddle::distributed::PSClientFactory::Create(worker_proto)); + worker_ptr_->Configure(worker_proto, dense_regions, _ps_env, 0); worker_ptr_->set_shard_num(127); worker_ptr_->set_local_channel(index); worker_ptr_->set_local_graph_service( @@ -404,11 +404,11 @@ void RunBrpcPushSparse() { prepare_file(edge_file_name, 1); prepare_file(node_file_name, 0); auto ph_host = paddle::distributed::PSHost(ip_, port_, 0); - host_sign_list_.push_back(ph_host.serialize_to_string()); + host_sign_list_.push_back(ph_host.SerializeToString()); // test-start auto ph_host2 = paddle::distributed::PSHost(ip2, port2, 1); - host_sign_list_.push_back(ph_host2.serialize_to_string()); + host_sign_list_.push_back(ph_host2.SerializeToString()); // test-end // Srart Server std::thread* server_thread = new std::thread(RunServer); @@ -424,7 +424,7 @@ void RunBrpcPushSparse() { /*-----------------------Test Server Init----------------------------------*/ auto pull_status = - worker_ptr_->load(0, std::string(edge_file_name), std::string("e>")); + worker_ptr_->Load(0, std::string(edge_file_name), std::string("e>")); srand(time(0)); pull_status.wait(); std::vector> _vs; @@ -438,7 +438,7 @@ void RunBrpcPushSparse() { pull_status.wait(); ASSERT_EQ(0, _vs[0].size()); paddle::distributed::GraphTable* g = - (paddle::distributed::GraphTable*)pserver_ptr_->table(0); + (paddle::distributed::GraphTable*)pserver_ptr_->GetTable(0); size_t ttl = 6; g->make_neighbor_sample_cache(4, ttl); int round = 5; @@ -622,15 +622,15 @@ void RunBrpcPushSparse() { std::remove(node_file_name); testAddNode(worker_ptr_); LOG(INFO) << "Run stop_server"; - worker_ptr_->stop_server(); + worker_ptr_->StopServer(); LOG(INFO) << "Run finalize_worker"; - worker_ptr_->finalize_worker(); + worker_ptr_->FinalizeWorker(); testFeatureNodeSerializeInt(); testFeatureNodeSerializeInt64(); testFeatureNodeSerializeFloat32(); testFeatureNodeSerializeFloat64(); testGraphToBuffer(); - client1.stop_server(); + client1.StopServer(); } void testCache() { @@ -700,4 +700,4 @@ void testGraphToBuffer() { VLOG(0) << s1.get_feature(0); } -TEST(RunBrpcPushSparse, Run) { RunBrpcPushSparse(); } \ No newline at end of file +TEST(RunBrpcPushSparse, Run) { RunBrpcPushSparse(); } diff --git a/paddle/fluid/distributed/test/memory_geo_table_test.cc b/paddle/fluid/distributed/test/memory_geo_table_test.cc index fb48b38c76a28d67a914493c06b1865dffa988e5..965f67992d0008620e96c64deb4c54401ce75b4b 100644 --- a/paddle/fluid/distributed/test/memory_geo_table_test.cc +++ b/paddle/fluid/distributed/test/memory_geo_table_test.cc @@ -48,7 +48,7 @@ TEST(MemorySparseGeoTable, SSUM) { common_config->add_dims(emb_dim); common_config->add_initializers("fill_constant&1.0"); - auto ret = table->initialize(table_config, fs_config); + auto ret = table->Initialize(table_config, fs_config); ASSERT_EQ(ret, 0); // test push_sparse_param, and create params @@ -58,12 +58,12 @@ TEST(MemorySparseGeoTable, SSUM) { for (size_t i = 0; i < init_keys.size() * emb_dim; i++) { init_values.push_back(0.0); } - table->push_sparse_param(init_keys.data(), init_values.data(), - init_keys.size()); + table->PushSparseParam(init_keys.data(), init_values.data(), + init_keys.size()); std::vector pull_values(init_values.size()); auto value = PullSparseValue(init_keys, init_fres, emb_dim); - table->pull_sparse(pull_values.data(), value); + table->PullSparse(pull_values.data(), value); for (size_t i = 0; i < init_keys.size() * emb_dim; i++) { ASSERT_TRUE(abs(pull_values[i] - init_values[i]) < 1e-5); @@ -93,8 +93,7 @@ TEST(MemorySparseGeoTable, SSUM) { auto &push_keys = trainer_keys[i]; auto &push_values = trainer_values[i]; auto task = [table, &push_keys, &push_values] { - table->push_sparse(push_keys.data(), push_values.data(), - push_keys.size()); + table->PushSparse(push_keys.data(), push_values.data(), push_keys.size()); }; task_status.push_back(pool_->enqueue(std::move(task))); } @@ -107,7 +106,7 @@ TEST(MemorySparseGeoTable, SSUM) { geo_pull_ids.resize(trainers); geo_pull_values.resize(trainers); for (int i = 0; i < trainers; i++) { - table->pull_geo_param(i, &geo_pull_values[i], &geo_pull_ids[i]); + table->PullGeoParam(i, &geo_pull_values[i], &geo_pull_ids[i]); ASSERT_EQ(geo_pull_values[i].size(), geo_pull_ids[i].size() * emb_dim); for (size_t j = 0; j < geo_pull_ids[i].size(); ++j) { auto id = geo_pull_ids[i][j]; diff --git a/paddle/fluid/distributed/test/memory_sparse_table_test.cc b/paddle/fluid/distributed/test/memory_sparse_table_test.cc index aec02e8aec55872b734932b27994289df68de416..73fa7272280b2ce82d57c114de7027426500aeab 100644 --- a/paddle/fluid/distributed/test/memory_sparse_table_test.cc +++ b/paddle/fluid/distributed/test/memory_sparse_table_test.cc @@ -36,7 +36,7 @@ TEST(MemorySparseTable, SGD) { table_config.set_shard_num(10); FsClientParameter fs_config; Table *table = new MemorySparseTable(); - table->set_shard(0, 1); + table->SetShard(0, 1); TableAccessorParameter *accessor_config = table_config.mutable_accessor(); accessor_config->set_accessor_class("CtrCommonAccessor"); @@ -66,7 +66,7 @@ TEST(MemorySparseTable, SGD) { naive_param->add_weight_bounds(-10.0); naive_param->add_weight_bounds(10.0); - auto ret = table->initialize(table_config, fs_config); + auto ret = table->Initialize(table_config, fs_config); ASSERT_EQ(ret, 0); // pull parameters for create and check @@ -76,7 +76,7 @@ TEST(MemorySparseTable, SGD) { std::vector init_values; init_values.resize(init_keys.size() * (emb_dim + 3)); auto value = PullSparseValue(init_keys, init_fres, emb_dim); - table->pull_sparse(init_values.data(), value); + table->PullSparse(init_values.data(), value); // for check std::vector total_gradients; @@ -109,8 +109,7 @@ TEST(MemorySparseTable, SGD) { auto &push_keys = trainer_keys[i]; auto &push_values = trainer_gradient_values[i]; auto task = [table, &push_keys, &push_values] { - table->push_sparse(push_keys.data(), push_values.data(), - push_keys.size()); + table->PushSparse(push_keys.data(), push_values.data(), push_keys.size()); }; task_status.push_back(pool_->enqueue(std::move(task))); } @@ -120,7 +119,7 @@ TEST(MemorySparseTable, SGD) { std::vector pull_values; pull_values.resize(init_keys.size() * (emb_dim + 3)); - table->pull_sparse(pull_values.data(), value); + table->PullSparse(pull_values.data(), value); for (size_t i = 0; i < init_keys.size(); ++i) { for (size_t j = 2; j < emb_dim + 3; ++j) { @@ -133,7 +132,7 @@ TEST(MemorySparseTable, SGD) { } MemorySparseTable *ctr_table = dynamic_cast(table); - ctr_table->save_local_fs("./work/table.save", "0", "test"); + ctr_table->SaveLocalFS("./work/table.save", "0", "test"); } } // namespace distributed diff --git a/paddle/fluid/distributed/test/table_test.cc b/paddle/fluid/distributed/test/table_test.cc index 6a29781158b838378468b1789b9eed0408c3435d..8690aee39f69c5dc2c05d74d2dc002de07d3894f 100644 --- a/paddle/fluid/distributed/test/table_test.cc +++ b/paddle/fluid/distributed/test/table_test.cc @@ -26,7 +26,7 @@ TEST(Table, Initialize) { FsClientParameter fs_config; // case 1. no accessor Table *table = new SparseGeoTable(); - auto ret = table->initialize(table_config, fs_config); + auto ret = table->Initialize(table_config, fs_config); ASSERT_EQ(ret, -1); } } // namespace distributed diff --git a/paddle/fluid/framework/fleet/ps_gpu_wrapper.cc b/paddle/fluid/framework/fleet/ps_gpu_wrapper.cc index 72f998a772764b83eb81e9809b2ec9d48297f366..75f5c24af5a9961bafbc7296299c55bd46bce3ef 100755 --- a/paddle/fluid/framework/fleet/ps_gpu_wrapper.cc +++ b/paddle/fluid/framework/fleet/ps_gpu_wrapper.cc @@ -343,7 +343,7 @@ void PSGPUWrapper::BuildPull(std::shared_ptr gpu_task) { #ifdef PADDLE_WITH_PSCORE int32_t cnt = 0; while (true) { - auto tt = fleet_ptr->worker_ptr_->pull_sparse_ptr( + auto tt = fleet_ptr->worker_ptr_->PullSparsePtr( reinterpret_cast(local_ptr[i].data()), this->table_id_, local_keys[i].data(), key_size); bool flag = true; diff --git a/paddle/fluid/framework/multi_trainer.cc b/paddle/fluid/framework/multi_trainer.cc index 83926336cbec8306b65d8f37814bfc97a3a71b72..61cd7ad01696e1a34891b490f6bddd4713384cd0 100644 --- a/paddle/fluid/framework/multi_trainer.cc +++ b/paddle/fluid/framework/multi_trainer.cc @@ -276,7 +276,7 @@ void MultiTrainer::Finalize() { if (communicator == nullptr) { VLOG(0) << "MultiTrainer::Finalize communicator is null!"; } else { - communicator->_worker_ptr->flush(); + communicator->_worker_ptr->Flush(); VLOG(1) << "MultiTrainer::Finalize ps client flush done"; } #endif diff --git a/paddle/fluid/pybind/fleet_py.cc b/paddle/fluid/pybind/fleet_py.cc index befcf36b41c24df29a11061de11db5111744f775..330719762ae08789b3ded2b0ad4bb7ab9a91f04e 100644 --- a/paddle/fluid/pybind/fleet_py.cc +++ b/paddle/fluid/pybind/fleet_py.cc @@ -86,11 +86,11 @@ void BindDistFleetWrapper(py::module* m) { void BindPSHost(py::module* m) { py::class_(*m, "PSHost") .def(py::init()) - .def("serialize_to_string", &distributed::PSHost::serialize_to_string) - .def("parse_from_string", &distributed::PSHost::parse_from_string) - .def("to_uint64", &distributed::PSHost::serialize_to_uint64) - .def("from_uint64", &distributed::PSHost::parse_from_uint64) - .def("to_string", &distributed::PSHost::to_string); + .def("serialize_to_string", &distributed::PSHost::SerializeToString) + .def("parse_from_string", &distributed::PSHost::ParseFromString) + .def("to_uint64", &distributed::PSHost::SerializeToUint64) + .def("from_uint64", &distributed::PSHost::ParseFromUint64) + .def("to_string", &distributed::PSHost::ToString); } void BindSparseShardingTools(py::module* m) { @@ -224,7 +224,7 @@ void BindGraphPyClient(py::module* m) { &GraphPyClient::use_neighbors_sample_cache) .def("remove_graph_node", &GraphPyClient::remove_graph_node) .def("random_sample_nodes", &GraphPyClient::random_sample_nodes) - .def("stop_server", &GraphPyClient::stop_server) + .def("stop_server", &GraphPyClient::StopServer) .def("get_node_feat", [](GraphPyClient& self, std::string node_type, std::vector node_ids,