// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include #include #include #include #include "Eigen/Dense" #include "paddle/fluid/distributed/service/brpc_ps_client.h" #include "paddle/fluid/distributed/table/table.h" #include "paddle/fluid/framework/archive.h" const static int max_port = 65535; DEFINE_int32(pserver_push_dense_merge_limit, 12, "limit max push_dense local merge requests"); DEFINE_int32(pserver_push_sparse_merge_limit, 12, "limit max push_sparse local merge requests"); DEFINE_int32(pserver_pull_dense_limit, 12, "limit max push_sparse local merge requests"); DEFINE_int32(pserver_async_push_dense_interval_ms, 10, "async push_dense to server interval"); DEFINE_int32(pserver_async_push_sparse_interval_ms, 10, "async push_sparse to server interval"); DEFINE_bool(pserver_scale_gradient_by_merge, false, "scale dense gradient when merged"); DEFINE_int32(pserver_communicate_compress_type, 0, "none:0 snappy:1 gzip:2 zlib:3 lz4:4"); DEFINE_int32(pserver_max_async_call_num, 13, "max task num in async_call_server"); DEFINE_int32(pserver_timeout_ms, 500000, "pserver request server timeout_ms"); DEFINE_int32(pserver_connect_timeout_ms, 10000, "pserver connect server timeout_ms"); DEFINE_int32(pserver_sparse_merge_thread, 1, "pserver sparse merge thread num"); namespace paddle { namespace distributed { inline size_t get_sparse_shard(uint32_t shard_num, uint32_t server_num, uint64_t key) { size_t remind = shard_num % server_num; size_t local_shard_num = remind == 0 ? shard_num / server_num : shard_num / server_num + 1; return (key % shard_num) / local_shard_num; } void DownpourPsClientService::service( ::google::protobuf::RpcController *controller, const ::paddle::PsRequestMessage *request, ::paddle::PsResponseMessage *response, ::google::protobuf::Closure *done) { brpc::ClosureGuard done_guard(done); int ret = _client->handle_client2client_msg( request->cmd_id(), request->client_id(), request->data()); response->set_err_code(0); response->set_err_msg(""); if (ret != 0) { response->set_err_code(-1); response->set_err_msg("handle_client2client_msg failed"); } } // 启动client端RpcService 用于数据互发等操作 int32_t BrpcPsClient::start_client_service() { if (_service.configure(this, _client_id) != 0) { LOG(ERROR) << "service initialize failed, service_name:DownpourPsClientService"; return -1; } _server.AddService(&_service, brpc::SERVER_DOESNT_OWN_SERVICE); brpc::ServerOptions options; int start_port = 8500; options.num_threads = 24; if (_server.Start(butil::my_ip_cstr(), brpc::PortRange(start_port, max_port), &options) != 0) { LOG(ERROR) << "BrpcPsServer start failed"; return -1; } _env->registe_ps_client(butil::my_ip_cstr(), _server.listen_address().port, _client_id); return 0; } int32_t BrpcPsClient::create_client2client_connection( int pserver_timeout_ms, int pserver_connect_timeout_ms, int max_retry) { brpc::ChannelOptions options; options.protocol = "baidu_std"; options.timeout_ms = pserver_timeout_ms; options.connection_type = "pooled"; options.connect_timeout_ms = pserver_connect_timeout_ms; options.max_retry = max_retry; std::vector client_list = _env->get_ps_clients(); _client_channels.resize(client_list.size()); std::ostringstream os; std::string server_ip_port; for (size_t i = 0; i < client_list.size(); ++i) { server_ip_port.assign(client_list[i].ip.c_str()); server_ip_port.append(":"); server_ip_port.append(std::to_string(client_list[i].port)); _client_channels[i].reset(new brpc::Channel()); if (_client_channels[i]->Init(server_ip_port.c_str(), "", &options) != 0) { LOG(ERROR) << "psclient connect to client:" << server_ip_port << " Failed!"; } os << server_ip_port << ","; } LOG(INFO) << "Client connect success:" << os.str(); return 0; } int32_t BrpcPsClient::initialize() { _async_call_num = 0; brpc::ChannelOptions options; options.protocol = "baidu_std"; options.timeout_ms = FLAGS_pserver_timeout_ms; options.connection_type = "pooled"; options.connect_timeout_ms = FLAGS_pserver_connect_timeout_ms; options.max_retry = 3; std::ostringstream os; std::string server_ip_port; std::string client_ip(butil::my_ip_cstr()); // 获取server列表,并连接 std::vector server_list = _env->get_ps_servers(); _server_channels.resize(server_list.size()); for (size_t i = 0; i < server_list.size(); ++i) { server_ip_port.assign(server_list[i].ip.c_str()); server_ip_port.append(":"); server_ip_port.append(std::to_string(server_list[i].port)); for (size_t j = 0; j < _server_channels[i].size(); ++j) { _server_channels[i][j].reset(new brpc::Channel()); if (_server_channels[i][j]->Init(server_ip_port.c_str(), "", &options) != 0) { LOG(ERROR) << "psclient connect to server:" << server_ip_port << " Failed!"; return -1; } } os << server_ip_port << ","; } // 启动client探听接口, 并相互建立连接 start_client_service(); _running = true; _flushing = false; return 0; } int DownpourBrpcClosure::check_response(size_t request_idx, int cmd_id) { if (_cntls[request_idx]->Failed()) { LOG(ERROR) << "resquest cmd_id:" << cmd_id << " failed, " "err:" << _cntls[request_idx]->ErrorText(); return -1; } if (_responses[request_idx].err_code() != 0) { LOG(ERROR) << "response ret bad, server_idx:" << request_idx << "cmd_id:" << cmd_id << " err_code:" << _responses[request_idx].err_code() << " err_msg:" << _responses[request_idx].err_msg(); return -1; } return 0; } int DownpourBrpcClosure::check_save_response(size_t request_idx, int cmd_id) { uint32_t feasign_size = 0; if (_cntls[request_idx]->Failed()) { LOG(ERROR) << "resquest cmd_id:" << cmd_id << " failed, " "err:" << _cntls[request_idx]->ErrorText(); return -1; } feasign_size = _responses[request_idx].err_code(); if (feasign_size < 0) { LOG(ERROR) << "response ret bad, server_idx:" << request_idx << "cmd_id:" << cmd_id << " err_code:" << _responses[request_idx].err_code() << " err_msg:" << _responses[request_idx].err_msg(); return -1; } return feasign_size; } std::string DownpourBrpcClosure::get_response(size_t request_idx, int cmd_id) { std::string data = _responses[request_idx].data(); return data; } std::future BrpcPsClient::print_table_stat(uint32_t table_id) { size_t request_call_num = _server_channels.size(); DownpourBrpcClosure *closure = new DownpourBrpcClosure( request_call_num, [request_call_num, table_id](void *done) { int ret = 0; uint64_t feasign_size = 0; uint64_t mf_size = 0; paddle::framework::BinaryArchive ar; auto *closure = (DownpourBrpcClosure *)done; for (size_t i = 0; i < request_call_num; ++i) { if (closure->check_response(i, PS_PRINT_TABLE_STAT) != 0) { ret = -1; break; } std::string resp = closure->get_response(i, PS_PRINT_TABLE_STAT); ar.SetReadBuffer(const_cast(resp.c_str()), resp.length(), nullptr); feasign_size += ar.Get(); mf_size += ar.Get(); } closure->set_promise_value(ret); std::cout << "table id: " << table_id << ", feasign size: " << feasign_size << ", mf size: " << mf_size << std::endl; }); auto promise = std::make_shared>(); closure->add_promise(promise); std::future fut = promise->get_future(); for (size_t i = 0; i < request_call_num; ++i) { closure->request(i)->set_cmd_id(PS_PRINT_TABLE_STAT); closure->request(i)->set_table_id(table_id); closure->request(i)->set_client_id(_client_id); PsService_Stub rpc_stub(get_cmd_channel(i)); closure->cntl(i)->set_timeout_ms( 10800000); // cmd msg don't limit timeout for save/load rpc_stub.service(closure->cntl(i), closure->request(i), closure->response(i), closure); } return fut; } std::future BrpcPsClient::send_cmd( uint32_t table_id, int cmd_id, const std::vector ¶ms) { size_t request_call_num = _server_channels.size(); DownpourBrpcClosure *closure = new DownpourBrpcClosure( request_call_num, [request_call_num, cmd_id](void *done) { int ret = 0; auto *closure = (DownpourBrpcClosure *)done; for (size_t i = 0; i < request_call_num; ++i) { if (closure->check_response(i, cmd_id) != 0) { ret = -1; break; } } closure->set_promise_value(ret); }); auto promise = std::make_shared>(); closure->add_promise(promise); std::future fut = promise->get_future(); for (size_t i = 0; i < request_call_num; ++i) { closure->request(i)->set_cmd_id(cmd_id); closure->request(i)->set_table_id(table_id); closure->request(i)->set_client_id(_client_id); for (const auto ¶m : params) { closure->request(i)->add_params(param); } PsService_Stub rpc_stub(get_cmd_channel(i)); closure->cntl(i)->set_timeout_ms( 10800000); // cmd msg don't limit timeout for save/load rpc_stub.service(closure->cntl(i), closure->request(i), closure->response(i), closure); } return fut; } std::future BrpcPsClient::send_save_cmd( uint32_t table_id, int cmd_id, const std::vector ¶ms) { size_t request_call_num = _server_channels.size(); DownpourBrpcClosure *closure = new DownpourBrpcClosure( request_call_num, [request_call_num, cmd_id](void *done) { int ret = 0; uint32_t feasign_size = 0; auto *closure = (DownpourBrpcClosure *)done; for (size_t i = 0; i < request_call_num; ++i) { if (closure->check_save_response(i, cmd_id) < 0) { ret = -1; break; } feasign_size += closure->check_save_response(i, cmd_id); } if (ret == 0) { closure->set_promise_value(feasign_size); } else { closure->set_promise_value(ret); } }); auto promise = std::make_shared>(); closure->add_promise(promise); std::future fut = promise->get_future(); for (size_t i = 0; i < request_call_num; ++i) { closure->request(i)->set_cmd_id(cmd_id); closure->request(i)->set_table_id(table_id); closure->request(i)->set_client_id(_client_id); for (const auto ¶m : params) { closure->request(i)->add_params(param); } PsService_Stub rpc_stub(get_cmd_channel(i)); closure->cntl(i)->set_timeout_ms( 10800000); // cmd msg don't limit timeout for save/load rpc_stub.service(closure->cntl(i), closure->request(i), closure->response(i), closure); } return fut; } std::future BrpcPsClient::shrink(uint32_t table_id) { return send_cmd(table_id, PS_SHRINK_TABLE, {std::string("1")}); } std::future BrpcPsClient::load(const std::string &epoch, const std::string &mode) { return send_cmd(-1, PS_LOAD_ALL_TABLE, {epoch, mode}); } std::future BrpcPsClient::load(uint32_t table_id, const std::string &epoch, const std::string &mode) { return send_cmd(table_id, PS_LOAD_ONE_TABLE, {epoch, mode}); } std::future BrpcPsClient::save(const std::string &epoch, const std::string &mode) { return send_save_cmd(-1, PS_SAVE_ALL_TABLE, {epoch, mode}); } std::future BrpcPsClient::save(uint32_t table_id, const std::string &epoch, const std::string &mode) { return send_save_cmd(table_id, PS_SAVE_ONE_TABLE, {epoch, mode}); } std::future BrpcPsClient::clear() { return send_cmd(-1, PS_CLEAR_ALL_TABLE, {}); } std::future BrpcPsClient::clear(uint32_t table_id) { return send_cmd(table_id, PS_CLEAR_ONE_TABLE, {}); } std::future BrpcPsClient::flush() { _flushing = true; std::promise promise; std::future fut = promise.get_future(); do { VLOG(3) << "wait _async_call_num:" << _async_call_num; usleep(100000); // sleep 100ms wait async end } while (_async_call_num > 0); promise.set_value(0); _flushing = false; return fut; } void BrpcPsClient::finalize_worker() { flush(); _running = false; _server.Stop(1000); _server.Join(); } std::future BrpcPsClient::stop_server() { return send_cmd(-1, PS_STOP_SERVER, {}); } std::future BrpcPsClient::start_profiler() { return send_cmd(-1, PS_START_PROFILER, {}); } std::future BrpcPsClient::stop_profiler() { return send_cmd(-1, PS_STOP_PROFILER, {}); } std::future BrpcPsClient::barrier(size_t table_id, uint32_t barrier_type) { return send_cmd(table_id, PS_BARRIER, {std::to_string(barrier_type)}); } std::future BrpcPsClient::pull_geo_param(size_t table_id, std::vector *values, std::vector *keys, int pserver_idx) { auto *accessor = table_accessor(table_id); DownpourBrpcClosure *closure = new DownpourBrpcClosure(1, [keys, values, accessor](void *done) { int ret = 0; auto *closure = (DownpourBrpcClosure *)done; uint32_t shard_nums; if (closure->check_response(0, PS_PULL_GEO_PARAM) != 0) { ret = -1; } auto &res_io_buffer = closure->cntl(0)->response_attachment(); butil::IOBufBytesIterator io_buffer_itr(res_io_buffer); io_buffer_itr.copy_and_forward((void *)(&shard_nums), sizeof(uint32_t)); keys->resize(shard_nums); values->resize(shard_nums * accessor->update_dim()); io_buffer_itr.copy_and_forward((void *)(keys->data()), sizeof(uint64_t) * shard_nums); io_buffer_itr.copy_and_forward((void *)(values->data()), shard_nums * accessor->update_size()); closure->set_promise_value(ret); }); auto promise = std::make_shared>(); closure->add_promise(promise); std::future fut = promise->get_future(); closure->request(0)->set_cmd_id(PS_PULL_GEO_PARAM); closure->request(0)->set_table_id(table_id); closure->request(0)->set_client_id(_client_id); PsService_Stub rpc_stub(get_cmd_channel(pserver_idx)); closure->cntl(0)->set_log_id(butil::gettimeofday_ms()); rpc_stub.service(closure->cntl(0), closure->request(0), closure->response(0), closure); return fut; } std::future BrpcPsClient::push_sparse_param( size_t table_id, const uint64_t *keys, const float **update_values, size_t num, void *done) { auto *accessor = table_accessor(table_id); // 发送RPC请求 DownpourBrpcClosure *closure = reinterpret_cast(done); auto promise = std::make_shared>(); closure->add_promise(promise); std::future fut = promise->get_future(); size_t request_call_num = _server_channels.size(); std::vector> ids; std::vector> value_ptrs; ids.resize(request_call_num); value_ptrs.resize(request_call_num); for (size_t i = 0; i < num; ++i) { size_t pserver_idx = keys[i] % request_call_num; ids[pserver_idx].push_back(keys[i]); value_ptrs[pserver_idx].push_back(update_values[i]); } for (size_t shard_idx = 0; shard_idx < request_call_num; ++shard_idx) { auto kvs = ids[shard_idx]; auto value_ptr = value_ptrs[shard_idx]; size_t kv_size = kvs.size(); uint32_t value_size = accessor->update_size(); // 发送RPC请求 auto *push_request = closure->request(shard_idx); push_request->set_cmd_id(PS_PUSH_SPARSE_PARAM); push_request->set_table_id(table_id); push_request->set_client_id(_client_id); push_request->add_params((char *)&kv_size, sizeof(uint32_t)); auto *push_data = push_request->mutable_data(); push_data->resize(kv_size * (sizeof(uint64_t) + accessor->update_size())); char *push_data_ptr = const_cast(push_data->data()); memcpy(push_data_ptr, kvs.data(), kv_size * sizeof(uint64_t)); push_data_ptr += kv_size * sizeof(uint64_t); for (int i = 0; i < kv_size; ++i) { memcpy(push_data_ptr, value_ptr[i], accessor->update_size()); push_data_ptr += accessor->update_size(); } PsService_Stub rpc_stub(get_sparse_channel(shard_idx)); closure->cntl(shard_idx)->set_request_compress_type( (brpc::CompressType)FLAGS_pserver_communicate_compress_type); rpc_stub.service(closure->cntl(shard_idx), closure->request(shard_idx), closure->response(shard_idx), closure); } return fut; } std::future BrpcPsClient::pull_dense(Region *regions, size_t region_num, size_t table_id) { auto *accessor = table_accessor(table_id); size_t request_call_num = _server_channels.size(); uint32_t num_per_shard = dense_dim_per_shard(accessor->fea_dim(), request_call_num); // callback 将各shard结果,顺序填入region DownpourBrpcClosure *closure = new DownpourBrpcClosure( request_call_num, [request_call_num, num_per_shard, regions, region_num, accessor](void *done) { int ret = 0; size_t region_idx = 0; // 当前填充的region偏移 size_t region_data_idx = 0; // 当前填充的region内data偏移 auto *closure = (DownpourBrpcClosure *)done; size_t shard_data_size = num_per_shard * accessor->select_size(); for (size_t i = 0; i < request_call_num; ++i) { if (closure->check_response(i, PS_PULL_DENSE_TABLE) != 0) { ret = -1; break; } auto &res_io_buffer = closure->cntl(i)->response_attachment(); butil::IOBufBytesIterator io_buffer_itr(res_io_buffer); size_t shard_buffer_remain = res_io_buffer.size(); if (shard_buffer_remain != shard_data_size) { LOG(ERROR) << "expect res_size:" << shard_data_size << ", but size:" << shard_buffer_remain << ", ignore this response"; ret = -1; break; } while (shard_buffer_remain > 0 && region_idx < region_num) { auto ®ion = regions[region_idx]; if (region.size - region_data_idx >= shard_buffer_remain) { // region待填充空间 >= 分片buffer数据, 直接拷贝置入 io_buffer_itr.copy_and_forward( (void *)(region.data + region_data_idx), shard_buffer_remain); region_data_idx += shard_buffer_remain; shard_buffer_remain = 0; } else if (region.size - region_data_idx == 0) { // region填满,切换到下一个region ++region_idx; region_data_idx = 0; } else { // region不足以容纳所有数据,则能放多少 拷贝多少 io_buffer_itr.copy_and_forward( (void *)(region.data + region_data_idx), region.size - region_data_idx); shard_buffer_remain -= (region.size - region_data_idx); ++region_idx; region_data_idx = 0; } } } closure->set_promise_value(ret); }); auto promise = std::make_shared>(); closure->add_promise(promise); std::future fut = promise->get_future(); for (size_t i = 0; i < request_call_num; ++i) { closure->request(i)->set_cmd_id(PS_PULL_DENSE_TABLE); closure->request(i)->set_table_id(table_id); closure->request(i)->set_client_id(_client_id); closure->request(i)->add_params((char *)&num_per_shard, sizeof(num_per_shard)); PsService_Stub rpc_stub(get_dense_channel(i)); rpc_stub.service(closure->cntl(i), closure->request(i), closure->response(i), closure); } return fut; } std::future BrpcPsClient::push_dense_param(const Region *regions, size_t region_num, size_t table_id) { auto *accessor = table_accessor(table_id); size_t request_call_num = _server_channels.size(); // 1.拆分Region数据到shard中,后续多shard并行拷贝数据 std::vector> regions_partition(request_call_num); uint32_t num_per_shard = dense_dim_per_shard(accessor->fea_dim(), request_call_num); size_t shard_data_size = num_per_shard * accessor->update_size(); size_t current_region_idx = 0; size_t current_region_data_idx = 0; for (size_t i = 0; i < request_call_num; ++i) { size_t shard_data_remain_size = shard_data_size; while (shard_data_remain_size > 0 && current_region_idx < region_num) { const auto ®ion = regions[current_region_idx]; size_t region_remain_size = region.size - current_region_data_idx; if (shard_data_remain_size >= region_remain_size) { regions_partition[i].push_back( Region(region.data + current_region_data_idx, region_remain_size)); ++current_region_idx; current_region_data_idx = 0; shard_data_remain_size -= region_remain_size; } else { regions_partition[i].push_back(Region( region.data + current_region_data_idx, shard_data_remain_size)); current_region_data_idx += shard_data_remain_size; shard_data_remain_size = 0; } } } DownpourBrpcClosure *closure = new DownpourBrpcClosure(request_call_num, [request_call_num](void *done) { int ret = 0; auto *closure = (DownpourBrpcClosure *)done; for (size_t i = 0; i < request_call_num; ++i) { if (closure->check_response(i, PS_PUSH_DENSE_PARAM) != 0) { ret = -1; break; } } closure->set_promise_value(ret); }); auto promise = std::make_shared>(); closure->add_promise(promise); std::future fut = promise->get_future(); static const int REGION_ASSIGN_BUFFER_SIZE = 1024 * 10; static char region_assign_buffer[REGION_ASSIGN_BUFFER_SIZE]; //用于数据补齐 //开始多shard并行拷贝&请求 for (size_t i = 0; i < request_call_num; ++i) { closure->request(i)->set_cmd_id(PS_PUSH_DENSE_PARAM); closure->request(i)->set_table_id(table_id); closure->request(i)->set_client_id(_client_id); auto &request_buffer = closure->cntl(i)->request_attachment(); request_buffer.append((void *)&num_per_shard, sizeof(uint32_t)); auto ®ion_list = regions_partition[i]; size_t fill_remain_size = shard_data_size; for (auto ®ion : region_list) { fill_remain_size -= region.size; request_buffer.append((void *)region.data, region.size); } //保证各分片数据对齐 while (fill_remain_size > 0) { size_t fill_num = fill_remain_size > REGION_ASSIGN_BUFFER_SIZE ? REGION_ASSIGN_BUFFER_SIZE : fill_remain_size; request_buffer.append((void *)region_assign_buffer, fill_num); fill_remain_size -= fill_num; } PsService_Stub rpc_stub(get_dense_channel(i)); rpc_stub.service(closure->cntl(i), closure->request(i), closure->response(i), closure); } return fut; } std::future BrpcPsClient::push_sparse_raw_gradient( size_t table_id, const uint64_t *keys, const float **update_values, size_t num, void *done) { auto *accessor = table_accessor(table_id); //发送RPC请求 DownpourBrpcClosure *closure = reinterpret_cast(done); auto promise = std::make_shared>(); closure->add_promise(promise); std::future fut = promise->get_future(); size_t request_call_num = _server_channels.size(); std::vector> ids; std::vector> value_ptrs; ids.resize(request_call_num); value_ptrs.resize(request_call_num); for (size_t i = 0; i < num; ++i) { size_t pserver_idx = keys[i] % request_call_num; ids[pserver_idx].push_back(keys[i]); value_ptrs[pserver_idx].push_back(update_values[i]); } for (size_t shard_idx = 0; shard_idx < request_call_num; ++shard_idx) { auto kvs = ids[shard_idx]; auto value_ptr = value_ptrs[shard_idx]; size_t kv_size = kvs.size(); uint32_t value_size = accessor->update_size(); // 发送RPC请求 auto *push_request = closure->request(shard_idx); push_request->set_cmd_id(PS_PUSH_SPARSE_TABLE); push_request->set_table_id(table_id); push_request->set_client_id(_client_id); push_request->add_params((char *)&kv_size, sizeof(uint32_t)); auto *push_data = push_request->mutable_data(); push_data->resize(kv_size * (sizeof(uint64_t) + accessor->update_size())); char *push_data_ptr = const_cast(push_data->data()); memcpy(push_data_ptr, kvs.data(), kv_size * sizeof(uint64_t)); push_data_ptr += kv_size * sizeof(uint64_t); for (int i = 0; i < kv_size; ++i) { memcpy(push_data_ptr, value_ptr[i], accessor->update_size()); push_data_ptr += accessor->update_size(); } PsService_Stub rpc_stub(get_sparse_channel(shard_idx)); closure->cntl(shard_idx)->set_request_compress_type( (brpc::CompressType)FLAGS_pserver_communicate_compress_type); rpc_stub.service(closure->cntl(shard_idx), closure->request(shard_idx), closure->response(shard_idx), closure); } return fut; } std::future BrpcPsClient::push_dense_raw_gradient( int table_id, float *total_send_data, size_t total_send_data_size, void *done) { size_t request_call_num = _server_channels.size(); DownpourBrpcClosure *closure = reinterpret_cast(done); auto promise = std::make_shared>(); closure->add_promise(promise); std::future fut = promise->get_future(); auto *accessor = table_accessor(table_id); uint32_t num_per_shard = dense_dim_per_shard(accessor->fea_dim(), request_call_num); for (size_t i = 0; i < request_call_num; ++i) { closure->request(i)->set_cmd_id(PS_PUSH_DENSE_TABLE); closure->request(i)->set_table_id(table_id); closure->request(i)->set_client_id(_client_id); auto *push_data = closure->request(i)->mutable_data(); push_data->clear(); push_data->resize(sizeof(uint32_t) + num_per_shard * sizeof(float)); char *push_data_ptr = const_cast(push_data->data()); memcpy(push_data_ptr, &num_per_shard, sizeof(uint32_t)); memcpy(push_data_ptr + sizeof(uint32_t), total_send_data + i * num_per_shard, num_per_shard * sizeof(float)); VLOG(1) << "push_dense_raw_gradient finish memcpy"; // closure->cntl(i)->set_request_compress_type( // (brpc::CompressType)FLAGS_pserver_communicate_compress_type); PsService_Stub rpc_stub(get_dense_channel(i)); VLOG(1) << "push_dense_raw_gradient get_dense_channel " << i; rpc_stub.service(closure->cntl(i), closure->request(i), closure->response(i), closure); VLOG(1) << "push_dense_raw_gradient async service " << i; } return fut; } std::future BrpcPsClient::pull_sparse(float **select_values, size_t table_id, const uint64_t *keys, size_t num) { size_t request_call_num = _server_channels.size(); auto shard_sorted_kvs = std::make_shared< std::vector>>>(); shard_sorted_kvs->resize(request_call_num); for (size_t i = 0; i < num; ++i) { size_t shard_id = keys[i] % request_call_num; shard_sorted_kvs->at(shard_id).push_back({keys[i], select_values[i]}); } auto *accessor = table_accessor(table_id); size_t value_size = accessor->select_size(); DownpourBrpcClosure *closure = new DownpourBrpcClosure( request_call_num, [shard_sorted_kvs, value_size](void *done) { int ret = 0; auto *closure = (DownpourBrpcClosure *)done; for (size_t i = 0; i < ids.size(); ++i) { if (closure->check_response(i, PS_PULL_SPARSE_TABLE) != 0) { ret = -1; break; } auto &request_kvs = shard_sorted_kvs->at(i); auto &res_io_buffer = closure->cntl(i)->response_attachment(); butil::IOBufBytesIterator io_buffer_itr(res_io_buffer); uint64_t last_key = UINT64_MAX; float *last_value_data = NULL; for (size_t kv_idx = 0; kv_idx < request_kvs.size(); ++kv_idx) { auto *kv_pair = &(request_kvs[kv_idx]); if (kv_pair->first == last_key) { memcpy((void *)kv_pair->second, (void *)last_value_data, value_size); } else { last_key = kv_pair->first; last_value_data = kv_pair->second; if (value_size != io_buffer_itr.copy_and_forward((void *)(last_value_data), value_size)) { LOG(WARNING) << "res data is lack or not in format"; ret = -1; break; } } } } closure->set_promise_value(ret); }); auto promise = std::make_shared>(); closure->add_promise(promise); std::future fut = promise->get_future(); for (size_t i = 0; i < request_call_num; ++i) { auto &sorted_kvs = shard_sorted_kvs->at(i); std::sort(sorted_kvs.begin(), sorted_kvs.end(), [](const std::pair &k1, const std::pair &k2) { return k1.first < k2.first; }); uint64_t last_key = UINT64_MAX; uint32_t kv_request_count = 0; size_t sorted_kv_size = sorted_kvs.size(); auto &request_buffer = closure->cntl(i)->request_attachment(); for (size_t kv_idx = 0; kv_idx < sorted_kv_size; ++kv_idx) { ++kv_request_count; last_key = sorted_kvs[kv_idx].first; request_buffer.append((void *)&last_key, sizeof(uint64_t)); while (kv_idx < sorted_kv_size - 1 && last_key == sorted_kvs[kv_idx + 1].first) { ++kv_idx; } } if (kv_request_count == 0) { closure->Run(); } else { closure->request(i)->set_cmd_id(PS_PULL_SPARSE_TABLE); closure->request(i)->set_table_id(table_id); closure->request(i)->set_client_id(_client_id); closure->request(i)->add_params((char *)&kv_request_count, sizeof(uint32_t)); PsService_Stub rpc_stub(get_cmd_channel(i)); closure->cntl(i)->set_log_id(butil::gettimeofday_ms()); rpc_stub.service(closure->cntl(i), closure->request(i), closure->response(i), closure); } } return fut; } std::future BrpcPsClient::send_client2client_msg( int msg_type, int to_client_id, const std::string &msg) { auto promise = std::make_shared>(); std::future fut = promise->get_future(); if (to_client_id >= _client_channels.size()) { LOG(FATAL) << "to_client_id is out of range clients, which size is " << _client_channels.size(); promise->set_value(-1); return fut; } auto *closure = new DownpourBrpcClosure(1, [msg_type](void *done) { auto *closure = (DownpourBrpcClosure *)done; int32_t ret = closure->check_response(0, msg_type + 1000); closure->set_promise_value(ret); }); closure->add_promise(promise); closure->request(0)->set_cmd_id(msg_type); closure->request(0)->set_client_id(_client_id); closure->request(0)->set_data(msg); PsService_Stub rpc_stub(_client_channels[to_client_id].get()); rpc_stub.service(closure->cntl(0), closure->request(0), closure->response(0), closure); return fut; } std::future BrpcPsClient::push_sparse_raw_gradient_partial( size_t table_id, const uint64_t *keys, const float **update_values, uint32_t num, void *done, int pserver_idx) { auto *accessor = table_accessor(table_id); size_t value_size = accessor->update_size(); DownpourBrpcClosure *closure = reinterpret_cast(done); auto promise = std::make_shared>(); closure->add_promise(promise); std::future fut = promise->get_future(); // 发送RPC请求 auto *push_request = closure->request(0); push_request->set_cmd_id(PS_PUSH_SPARSE_TABLE); push_request->set_table_id(table_id); push_request->set_client_id(_client_id); push_request->add_params((char *)&num, sizeof(uint32_t)); auto *push_data = push_request->mutable_data(); push_data->resize(num * (sizeof(uint64_t) + value_size)); char *push_data_ptr = const_cast(push_data->data()); memcpy(push_data_ptr, keys, num * sizeof(uint64_t)); push_data_ptr += num * sizeof(uint64_t); for (int i = 0; i < num; ++i) { memcpy(push_data_ptr, update_values[i], value_size); push_data_ptr += value_size; } PsService_Stub rpc_stub(get_sparse_channel(pserver_idx)); closure->cntl(0)->set_request_compress_type( (brpc::CompressType)FLAGS_pserver_communicate_compress_type); rpc_stub.service(closure->cntl(0), closure->request(0), closure->response(0), closure); return fut; } } // namespace distributed } // namespace paddle