From bd79ae8a319403e519e0e3a0052a8ccd6e4ed82e Mon Sep 17 00:00:00 2001 From: zhaocaibei123 <48509226+zhaocaibei123@users.noreply.github.com> Date: Fri, 12 Nov 2021 11:51:09 +0800 Subject: [PATCH] brpc_ps_client upgrade (#36943) * test * rm test * add memory_sparse_table and brpc communication upgrade dependency * fix * add dense optimizer & fix dump bug & add some strategy fields * fix * fix * remove thread_pool thread_queue * add memory sparse table * update memory sparse table * update memory sparse table * update cmake * upgrade brpc_ps_client * remove show/click_const in ctr_accessor * fix deconstructor --- .../distributed/service/brpc_ps_client.cc | 706 ++++++++++++++++-- .../distributed/service/brpc_ps_client.h | 175 +++-- .../fluid/distributed/service/communicator.cc | 2 + paddle/fluid/distributed/service/ps_client.h | 14 +- .../fluid/distributed/table/ctr_accessor.cc | 8 +- paddle/fluid/distributed/table/ctr_accessor.h | 17 +- 6 files changed, 807 insertions(+), 115 deletions(-) diff --git a/paddle/fluid/distributed/service/brpc_ps_client.cc b/paddle/fluid/distributed/service/brpc_ps_client.cc index a6ad9d08f52..f6b544d22b2 100644 --- a/paddle/fluid/distributed/service/brpc_ps_client.cc +++ b/paddle/fluid/distributed/service/brpc_ps_client.cc @@ -19,7 +19,7 @@ #include "paddle/fluid/distributed/service/brpc_ps_client.h" #include "paddle/fluid/framework/archive.h" -const static int max_port = 65535; +static const int max_port = 65535; DEFINE_int32(pserver_push_dense_merge_limit, 12, "limit max push_dense local merge requests"); @@ -52,6 +52,9 @@ DEFINE_int32(pserver_connect_timeout_ms, 10000, DEFINE_int32(pserver_sparse_merge_thread, 1, "pserver sparse merge thread num"); +DEFINE_int32(pserver_sparse_table_shard_num, 1000, + "sparse table shard for save & load"); + namespace paddle { namespace framework { class Scope; @@ -102,6 +105,7 @@ int32_t BrpcPsClient::start_client_service() { LOG(ERROR) << "BrpcPsServer start failed"; return -1; } + _server_started = true; _env->registe_ps_client(butil::my_ip_cstr(), _server.listen_address().port, _client_id); return 0; @@ -117,6 +121,12 @@ int32_t BrpcPsClient::create_client2client_connection( options.max_retry = max_retry; std::vector client_list = _env->get_ps_clients(); + VLOG(1) << "BrpcPsClient::create_c2c_connection client_list size: " + << client_list.size(); + for (auto cc : client_list) { + VLOG(1) << "BrpcPsClient::create_c2c_connection client_list: " + << cc.to_string(); + } _client_channels.resize(client_list.size()); std::ostringstream os; std::string server_ip_port; @@ -184,8 +194,34 @@ int32_t BrpcPsClient::initialize() { // 启动client探听接口, 并相互建立连接 start_client_service(); + // 异步push 请求队列初始化 + const auto &worker_param = _config.worker_param().downpour_worker_param(); + for (size_t i = 0; i < worker_param.downpour_table_param_size(); ++i) { + auto type = worker_param.downpour_table_param(i).type(); + auto table_id = worker_param.downpour_table_param(i).table_id(); + if (type == PS_DENSE_TABLE) { + _push_dense_task_queue_map[table_id] = + paddle::framework::MakeChannel(); + } + if (type == PS_SPARSE_TABLE) { + _push_sparse_task_queue_map[table_id] = + paddle::framework::MakeChannel(); + _push_sparse_merge_count_map[table_id] = 0; + } + } + _running = true; _flushing = false; + // 启动异步push线程 + _async_push_sparse_thread = + std::thread(std::bind(&BrpcPsClient::push_sparse_task_consume, this)); + // _async_push_sparse_thread.detach(); + _async_push_dense_thread = + std::thread(std::bind(&BrpcPsClient::push_dense_task_consume, 
this)); + // for debug + // _print_thread = + // std::thread(std::bind(&BrpcPsClient::print_queue_size_thread, this)); + return 0; } @@ -238,7 +274,7 @@ std::future BrpcPsClient::print_table_stat(uint32_t table_id) { uint64_t feasign_size = 0; uint64_t mf_size = 0; paddle::framework::BinaryArchive ar; - auto *closure = (DownpourBrpcClosure *)done; + auto *closure = reinterpret_cast(done); for (size_t i = 0; i < request_call_num; ++i) { if (closure->check_response(i, PS_PRINT_TABLE_STAT) != 0) { ret = -1; @@ -277,7 +313,7 @@ std::future BrpcPsClient::send_cmd( DownpourBrpcClosure *closure = new DownpourBrpcClosure( request_call_num, [request_call_num, cmd_id](void *done) { int ret = 0; - auto *closure = (DownpourBrpcClosure *)done; + auto *closure = reinterpret_cast(done); for (size_t i = 0; i < request_call_num; ++i) { if (closure->check_response(i, cmd_id) != 0) { ret = -1; @@ -298,7 +334,7 @@ std::future BrpcPsClient::send_cmd( } PsService_Stub rpc_stub(get_cmd_channel(i)); closure->cntl(i)->set_timeout_ms( - 10800000); // cmd msg don't limit timeout for save/load + 10800000 * 2); // cmd msg don't limit timeout for save/load rpc_stub.service(closure->cntl(i), closure->request(i), closure->response(i), closure); } @@ -312,7 +348,7 @@ std::future BrpcPsClient::send_save_cmd( request_call_num, [request_call_num, cmd_id](void *done) { int ret = 0; uint32_t feasign_size = 0; - auto *closure = (DownpourBrpcClosure *)done; + auto *closure = reinterpret_cast(done); for (size_t i = 0; i < request_call_num; ++i) { if (closure->check_save_response(i, cmd_id) < 0) { ret = -1; @@ -362,11 +398,14 @@ std::future BrpcPsClient::load(uint32_t table_id, std::future BrpcPsClient::save(const std::string &epoch, const std::string &mode) { + VLOG(1) << "BrpcPsClient::save path " << epoch; return send_save_cmd(-1, PS_SAVE_ALL_TABLE, {epoch, mode}); } std::future BrpcPsClient::save(uint32_t table_id, const std::string &epoch, const std::string &mode) { + VLOG(1) << "BrpcPsClient::save one table path " << epoch << " table_id " + << table_id; return send_save_cmd(table_id, PS_SAVE_ONE_TABLE, {epoch, mode}); } @@ -378,6 +417,7 @@ std::future BrpcPsClient::clear(uint32_t table_id) { } std::future BrpcPsClient::flush() { + VLOG(0) << "BrpcPsClient::flush begin"; _flushing = true; std::promise promise; std::future fut = promise.get_future(); @@ -385,16 +425,49 @@ std::future BrpcPsClient::flush() { VLOG(3) << "wait _async_call_num:" << _async_call_num; usleep(100000); // sleep 100ms wait async end } while (_async_call_num > 0); + VLOG(1) << "flush _async_call_num = 0"; promise.set_value(0); _flushing = false; + VLOG(0) << "BrpcPsClient::flush done"; + print_queue_size(); return fut; } +void BrpcPsClient::print_queue_size() { + for (auto &push_sparse_task_itr : _push_sparse_task_queue_map) { + auto table_id = push_sparse_task_itr.first; + auto queue_size = push_sparse_task_itr.second->Size(); + VLOG(0) << "BrpcPsClient::print_queue_size: table " << table_id + << " size: " << queue_size; + } + + for (auto &task_queue_itr : _push_dense_task_queue_map) { + auto table_id = task_queue_itr.first; + auto queue_size = task_queue_itr.second->Size(); + VLOG(0) << "BrpcPsClient::print_queue_size: table " << table_id + << " size: " << queue_size; + } +} + +void BrpcPsClient::print_queue_size_thread() { + while (_running) { + usleep(1000000 * 60 * 2); + print_queue_size(); + } +} + void BrpcPsClient::finalize_worker() { flush(); + VLOG(0) << "BrpcPsClient::finalize_worker begin join thread"; _running = false; + 
_async_push_dense_thread.join(); + _async_push_sparse_thread.join(); + // _print_thread.join(); + VLOG(0) << "BrpcPsClient::finalize_worker begin join server"; _server.Stop(1000); _server.Join(); + _server_started = false; + VLOG(0) << "BrpcPsClient::finalize_worker done"; } std::future BrpcPsClient::stop_server() { @@ -422,19 +495,20 @@ std::future BrpcPsClient::pull_geo_param(size_t table_id, DownpourBrpcClosure *closure = new DownpourBrpcClosure(1, [keys, values, accessor](void *done) { int ret = 0; - auto *closure = (DownpourBrpcClosure *)done; + auto *closure = reinterpret_cast(done); uint32_t shard_nums; if (closure->check_response(0, PS_PULL_GEO_PARAM) != 0) { ret = -1; } auto &res_io_buffer = closure->cntl(0)->response_attachment(); butil::IOBufBytesIterator io_buffer_itr(res_io_buffer); - io_buffer_itr.copy_and_forward((void *)(&shard_nums), sizeof(uint32_t)); + io_buffer_itr.copy_and_forward(reinterpret_cast(&shard_nums), + sizeof(uint32_t)); keys->resize(shard_nums); values->resize(shard_nums * accessor->update_dim()); - io_buffer_itr.copy_and_forward((void *)(keys->data()), + io_buffer_itr.copy_and_forward((void *)(keys->data()), // NOLINT sizeof(uint64_t) * shard_nums); - io_buffer_itr.copy_and_forward((void *)(values->data()), + io_buffer_itr.copy_and_forward((void *)(values->data()), // NOLINT shard_nums * accessor->update_size()); closure->set_promise_value(ret); }); @@ -466,8 +540,19 @@ std::future BrpcPsClient::push_sparse_param( std::vector> value_ptrs; ids.resize(request_call_num); value_ptrs.resize(request_call_num); + + const auto &server_param = _config.server_param().downpour_server_param(); + uint64_t shard_num = FLAGS_pserver_sparse_table_shard_num; + for (int i = 0; i < server_param.downpour_table_param_size(); ++i) { + const auto &table_param = server_param.downpour_table_param(i); + if (table_param.table_id() == table_id) { + shard_num = table_param.shard_num(); + break; + } + } + for (size_t i = 0; i < num; ++i) { - size_t pserver_idx = keys[i] % request_call_num; + size_t pserver_idx = get_sparse_shard(shard_num, request_call_num, keys[i]); ids[pserver_idx].push_back(keys[i]); value_ptrs[pserver_idx].push_back(update_values[i]); } @@ -481,7 +566,7 @@ std::future BrpcPsClient::push_sparse_param( push_request->set_cmd_id(PS_PUSH_SPARSE_PARAM); push_request->set_table_id(table_id); push_request->set_client_id(_client_id); - push_request->add_params((char *)&kv_size, sizeof(uint32_t)); + push_request->add_params((char *)&kv_size, sizeof(uint32_t)); // NOLINT auto *push_data = push_request->mutable_data(); push_data->resize(kv_size * (sizeof(uint64_t) + accessor->update_size())); char *push_data_ptr = const_cast(push_data->data()); @@ -514,7 +599,7 @@ std::future BrpcPsClient::pull_dense(Region *regions, int ret = 0; size_t region_idx = 0; // 当前填充的region偏移 size_t region_data_idx = 0; // 当前填充的region内data偏移 - auto *closure = (DownpourBrpcClosure *)done; + auto *closure = reinterpret_cast(done); size_t shard_data_size = num_per_shard * accessor->select_size(); for (size_t i = 0; i < request_call_num; ++i) { if (closure->check_response(i, PS_PULL_DENSE_TABLE) != 0) { @@ -537,7 +622,8 @@ std::future BrpcPsClient::pull_dense(Region *regions, if (region.size - region_data_idx >= shard_buffer_remain) { // region待填充空间 >= 分片buffer数据, 直接拷贝置入 io_buffer_itr.copy_and_forward( - (void *)(region.data + region_data_idx), shard_buffer_remain); + reinterpret_cast(region.data + region_data_idx), + shard_buffer_remain); region_data_idx += shard_buffer_remain; shard_buffer_remain = 0; 
} else if (region.size - region_data_idx == 0) { @@ -547,7 +633,7 @@ std::future BrpcPsClient::pull_dense(Region *regions, } else { // region不足以容纳所有数据,则能放多少 拷贝多少 io_buffer_itr.copy_and_forward( - (void *)(region.data + region_data_idx), + reinterpret_cast(region.data + region_data_idx), region.size - region_data_idx); shard_buffer_remain -= (region.size - region_data_idx); ++region_idx; @@ -564,7 +650,7 @@ std::future BrpcPsClient::pull_dense(Region *regions, closure->request(i)->set_cmd_id(PS_PULL_DENSE_TABLE); closure->request(i)->set_table_id(table_id); closure->request(i)->set_client_id(_client_id); - closure->request(i)->add_params((char *)&num_per_shard, + closure->request(i)->add_params((char *)&num_per_shard, // NOLINT sizeof(num_per_shard)); PsService_Stub rpc_stub(get_dense_channel(i)); rpc_stub.service(closure->cntl(i), closure->request(i), @@ -608,7 +694,7 @@ std::future BrpcPsClient::push_dense_param(const Region *regions, DownpourBrpcClosure *closure = new DownpourBrpcClosure(request_call_num, [request_call_num](void *done) { int ret = 0; - auto *closure = (DownpourBrpcClosure *)done; + auto *closure = reinterpret_cast(done); for (size_t i = 0; i < request_call_num; ++i) { if (closure->check_response(i, PS_PUSH_DENSE_PARAM) != 0) { ret = -1; @@ -621,26 +707,28 @@ std::future BrpcPsClient::push_dense_param(const Region *regions, closure->add_promise(promise); std::future fut = promise->get_future(); static const int REGION_ASSIGN_BUFFER_SIZE = 1024 * 10; - static char region_assign_buffer[REGION_ASSIGN_BUFFER_SIZE]; //用于数据补齐 - //开始多shard并行拷贝&请求 + static char region_assign_buffer[REGION_ASSIGN_BUFFER_SIZE]; // 用于数据补齐 + // 开始多shard并行拷贝&请求 for (size_t i = 0; i < request_call_num; ++i) { closure->request(i)->set_cmd_id(PS_PUSH_DENSE_PARAM); closure->request(i)->set_table_id(table_id); closure->request(i)->set_client_id(_client_id); auto &request_buffer = closure->cntl(i)->request_attachment(); - request_buffer.append((void *)&num_per_shard, sizeof(uint32_t)); + request_buffer.append(reinterpret_cast(&num_per_shard), + sizeof(uint32_t)); auto ®ion_list = regions_partition[i]; size_t fill_remain_size = shard_data_size; for (auto ®ion : region_list) { fill_remain_size -= region.size; - request_buffer.append((void *)region.data, region.size); + request_buffer.append(reinterpret_cast(region.data), region.size); } - //保证各分片数据对齐 + // 保证各分片数据对齐 while (fill_remain_size > 0) { size_t fill_num = fill_remain_size > REGION_ASSIGN_BUFFER_SIZE ? 
REGION_ASSIGN_BUFFER_SIZE : fill_remain_size; - request_buffer.append((void *)region_assign_buffer, fill_num); + request_buffer.append(reinterpret_cast(region_assign_buffer), + fill_num); fill_remain_size -= fill_num; } PsService_Stub rpc_stub(get_dense_channel(i)); @@ -654,7 +742,7 @@ std::future BrpcPsClient::push_sparse_raw_gradient( size_t table_id, const uint64_t *keys, const float **update_values, size_t num, void *done) { auto *accessor = table_accessor(table_id); - //发送RPC请求 + // 发送RPC请求 DownpourBrpcClosure *closure = reinterpret_cast(done); auto promise = std::make_shared>(); closure->add_promise(promise); @@ -666,8 +754,18 @@ std::future BrpcPsClient::push_sparse_raw_gradient( ids.resize(request_call_num); value_ptrs.resize(request_call_num); + const auto &server_param = _config.server_param().downpour_server_param(); + uint64_t shard_num = FLAGS_pserver_sparse_table_shard_num; + for (int i = 0; i < server_param.downpour_table_param_size(); ++i) { + const auto &table_param = server_param.downpour_table_param(i); + if (table_param.table_id() == table_id) { + shard_num = table_param.shard_num(); + break; + } + } + for (size_t i = 0; i < num; ++i) { - size_t pserver_idx = keys[i] % request_call_num; + size_t pserver_idx = get_sparse_shard(shard_num, request_call_num, keys[i]); ids[pserver_idx].push_back(keys[i]); value_ptrs[pserver_idx].push_back(update_values[i]); } @@ -684,7 +782,7 @@ std::future BrpcPsClient::push_sparse_raw_gradient( push_request->set_cmd_id(PS_PUSH_SPARSE_TABLE); push_request->set_table_id(table_id); push_request->set_client_id(_client_id); - push_request->add_params((char *)&kv_size, sizeof(uint32_t)); + push_request->add_params((char *)&kv_size, sizeof(uint32_t)); // NOLINT auto *push_data = push_request->mutable_data(); push_data->resize(kv_size * (sizeof(uint64_t) + accessor->update_size())); char *push_data_ptr = const_cast(push_data->data()); @@ -726,14 +824,11 @@ std::future BrpcPsClient::push_dense_raw_gradient( memcpy(push_data_ptr, &num_per_shard, sizeof(uint32_t)); memcpy(push_data_ptr + sizeof(uint32_t), total_send_data + i * num_per_shard, num_per_shard * sizeof(float)); - VLOG(1) << "push_dense_raw_gradient finish memcpy"; // closure->cntl(i)->set_request_compress_type( // (brpc::CompressType)FLAGS_pserver_communicate_compress_type); PsService_Stub rpc_stub(get_dense_channel(i)); - VLOG(1) << "push_dense_raw_gradient get_dense_channel " << i; rpc_stub.service(closure->cntl(i), closure->request(i), closure->response(i), closure); - VLOG(1) << "push_dense_raw_gradient async service " << i; } return fut; } @@ -776,8 +871,18 @@ std::future BrpcPsClient::pull_sparse(float **select_values, std::vector>>>(); shard_sorted_kvs->resize(request_call_num); + const auto &server_param = _config.server_param().downpour_server_param(); + uint64_t shard_num = FLAGS_pserver_sparse_table_shard_num; + for (int i = 0; i < server_param.downpour_table_param_size(); ++i) { + const auto &table_param = server_param.downpour_table_param(i); + if (table_param.table_id() == table_id) { + shard_num = table_param.shard_num(); + break; + } + } + for (size_t i = 0; i < num; ++i) { - size_t shard_id = keys[i] % request_call_num; + size_t shard_id = get_sparse_shard(shard_num, request_call_num, keys[i]); shard_sorted_kvs->at(shard_id).push_back({keys[i], select_values[i]}); } @@ -787,7 +892,7 @@ std::future BrpcPsClient::pull_sparse(float **select_values, DownpourBrpcClosure *closure = new DownpourBrpcClosure( request_call_num, [shard_sorted_kvs, value_size](void *done) { int ret = 
0; - auto *closure = (DownpourBrpcClosure *)done; + auto *closure = reinterpret_cast(done); for (size_t i = 0; i < shard_sorted_kvs->size(); ++i) { if (closure->check_response(i, PS_PULL_SPARSE_TABLE) != 0) { ret = -1; @@ -803,14 +908,14 @@ std::future BrpcPsClient::pull_sparse(float **select_values, for (size_t kv_idx = 0; kv_idx < request_kvs.size(); ++kv_idx) { auto *kv_pair = &(request_kvs[kv_idx]); if (kv_pair->first == last_key) { - memcpy((void *)kv_pair->second, (void *)last_value_data, - value_size); + memcpy(reinterpret_cast(kv_pair->second), + reinterpret_cast(last_value_data), value_size); } else { last_key = kv_pair->first; last_value_data = kv_pair->second; if (value_size != - io_buffer_itr.copy_and_forward((void *)(last_value_data), - value_size)) { + io_buffer_itr.copy_and_forward( + reinterpret_cast(last_value_data), value_size)) { LOG(WARNING) << "res data is lack or not in format"; ret = -1; break; @@ -838,7 +943,7 @@ std::future BrpcPsClient::pull_sparse(float **select_values, size_t sorted_kv_size = sorted_kvs.size(); auto &request_buffer = closure->cntl(i)->request_attachment(); - request_buffer.append((void *)&is_training, sizeof(bool)); + request_buffer.append(reinterpret_cast(&is_training), sizeof(bool)); std::vector keys_counter; keys_counter.reserve(sorted_kv_size); @@ -846,7 +951,8 @@ std::future BrpcPsClient::pull_sparse(float **select_values, ++kv_request_count; uint32_t keys = 1; last_key = sorted_kvs[kv_idx].first; - request_buffer.append((void *)&last_key, sizeof(uint64_t)); + request_buffer.append(reinterpret_cast(&last_key), + sizeof(uint64_t)); while (kv_idx < sorted_kv_size - 1 && last_key == sorted_kvs[kv_idx + 1].first) { ++kv_idx; @@ -855,7 +961,7 @@ std::future BrpcPsClient::pull_sparse(float **select_values, keys_counter.push_back(keys); } - request_buffer.append((void *)keys_counter.data(), + request_buffer.append(reinterpret_cast(keys_counter.data()), sizeof(uint32_t) * keys_counter.size()); if (kv_request_count == 0) { @@ -864,7 +970,7 @@ std::future BrpcPsClient::pull_sparse(float **select_values, closure->request(i)->set_cmd_id(PS_PULL_SPARSE_TABLE); closure->request(i)->set_table_id(table_id); closure->request(i)->set_client_id(_client_id); - closure->request(i)->add_params((char *)&kv_request_count, + closure->request(i)->add_params((char *)&kv_request_count, // NOLINT sizeof(uint32_t)); PsService_Stub rpc_stub(get_cmd_channel(i)); closure->cntl(i)->set_log_id(butil::gettimeofday_ms()); @@ -886,7 +992,7 @@ std::future BrpcPsClient::send_client2client_msg( return fut; } auto *closure = new DownpourBrpcClosure(1, [msg_type](void *done) { - auto *closure = (DownpourBrpcClosure *)done; + auto *closure = reinterpret_cast(done); int32_t ret = closure->check_response(0, msg_type + 1000); closure->set_promise_value(ret); }); @@ -915,7 +1021,7 @@ std::future BrpcPsClient::push_sparse_raw_gradient_partial( push_request->set_cmd_id(PS_PUSH_SPARSE_TABLE); push_request->set_table_id(table_id); push_request->set_client_id(_client_id); - push_request->add_params((char *)&num, sizeof(uint32_t)); + push_request->add_params((char *)&num, sizeof(uint32_t)); // NOLINT auto *push_data = push_request->mutable_data(); push_data->resize(num * (sizeof(uint64_t) + value_size)); char *push_data_ptr = const_cast(push_data->data()); @@ -966,8 +1072,8 @@ int32_t BrpcPsClient::recv_and_save_table(const uint64_t table_id, save_vec.push_back(save_huge_vec.data() + i * var_shape); } - auto status = pull_sparse((float **)save_vec.data(), table_id, - save_key.data(), 
save_key.size(), true); + auto status = pull_sparse(reinterpret_cast(save_vec.data()), + table_id, save_key.data(), save_key.size(), true); status.wait(); // create lod tensor @@ -1000,5 +1106,521 @@ int32_t BrpcPsClient::recv_and_save_table(const uint64_t table_id, return 0; } +std::future BrpcPsClient::push_sparse(size_t table_id, + const uint64_t *keys, + const float **update_values, + size_t num) { + auto push_timer = + std::make_shared("pserver_client_push_sparse_parse"); + int push_sparse_async_num = _push_sparse_task_queue_map[table_id]->Size(); + while (push_sparse_async_num > FLAGS_pserver_max_async_call_num) { + // LOG(INFO) << "push_sparse Waiting for async_call_num comsume, task_num:" + // << push_sparse_async_num << ", max_task_limit:" << + // FLAGS_pserver_max_async_call_num; + usleep(5000); // 5ms + // push_sparse_async_num = _push_sparse_task_queue_map[table_id]->size(); + push_sparse_async_num = _push_sparse_task_queue_map[table_id]->Size(); + } + thread_local std::vector>> + shard_sorted_kv_list; + auto *accessor = table_accessor(table_id); + size_t request_call_num = _server_channels.size(); + shard_sorted_kv_list.resize(request_call_num); + for (auto &x : shard_sorted_kv_list) { + x.clear(); + } + const auto &server_param = _config.server_param().downpour_server_param(); + uint64_t shard_num = FLAGS_pserver_sparse_table_shard_num; + for (int i = 0; i < server_param.downpour_table_param_size(); ++i) { + const auto &table_param = server_param.downpour_table_param(i); + if (table_param.table_id() == table_id) { + shard_num = table_param.shard_num(); + break; + } + } + for (size_t i = 0; i < num; ++i) { + size_t shard_id = get_sparse_shard(shard_num, request_call_num, keys[i]); + shard_sorted_kv_list[shard_id].push_back({keys[i], update_values[i]}); + } + auto sparse_task_data = _sparse_task_pool.get(); + sparse_task_data->shared_data.resize(request_call_num); + auto async_task = new SparseAsyncTask(sparse_task_data, table_id, push_timer); + + for (size_t i = 0; i < request_call_num; ++i) { + auto &sorted_kv_list = shard_sorted_kv_list[i]; + size_t sorted_kv_size = sorted_kv_list.size(); + auto &shard_kv_data = async_task->data()->shared_data[i]; + shard_kv_data.key_list.resize(sorted_kv_size); + shard_kv_data.value_list.resize(sorted_kv_size); + + if (sorted_kv_size == 0) { + shard_kv_data.kv_num = 0; + continue; + } + + uint32_t value_size = accessor->update_size(); + for (size_t kv_idx = 0; kv_idx < sorted_kv_size; ++kv_idx) { + shard_kv_data.key_list[kv_idx] = sorted_kv_list[kv_idx].first; + shard_kv_data.value_list[kv_idx].assign( + (const char *)sorted_kv_list[kv_idx].second, value_size); + } + shard_kv_data.kv_num = sorted_kv_size; + } + + std::future fut = async_task->get_future(); + _push_sparse_task_queue_map[table_id]->Put(std::move(async_task)); + return fut; +} + +void BrpcPsClient::push_sparse_task_consume() { + uint64_t merge_size = FLAGS_pserver_push_sparse_merge_limit; + std::vector> task_list; + size_t request_call_num = _server_channels.size(); + ::ThreadPool async_push_sparse_shard_threads( + FLAGS_pserver_sparse_merge_thread); + while (_running) { + platform::Timer timeline; + timeline.Start(); + // 所有sparseTable的pushTask 进行处理 + for (auto &push_sparse_task_itr : _push_sparse_task_queue_map) { + auto table_id = push_sparse_task_itr.first; + auto *accessor = table_accessor(table_id); + auto &task_queue = push_sparse_task_itr.second; + auto queue_size = task_queue->Size(); + if (queue_size == 0) { + continue; + } + if (merge_size > 0 && (queue_size <= 1 && 
_flushing == false)) { + continue; + } + ++_async_call_num; + + int merge_count = 0; + for (size_t i = 0; i < task_list.size(); ++i) { + if (task_list[i]->data()) { + _sparse_task_pool.push(task_list[i]->data()); + } + } + auto sparse_task_data = _sparse_task_pool.get(); + + task_list.clear(); + int cur_meger_size = task_queue->Size(); + + // task_list[0] 为一个空SparseAsyncTask, 分shard异步merge结果存入此结构。 + sparse_task_data->shared_data.resize(request_call_num); + auto push_timer = + std::make_shared("pserver_client_push_sparse"); + + auto async_task = + new SparseAsyncTask(sparse_task_data, table_id, push_timer); + + task_list.reserve(cur_meger_size + 1); + + task_list.push_back( + std::move(std::shared_ptr(async_task))); + + while (!task_queue->Empty() && merge_count < cur_meger_size) { + ++merge_count; + SparseAsyncTask *task; + task_queue->Get(task); + task_list.push_back(std::shared_ptr(task)); + } + + _push_sparse_merge_count_map[table_id] += merge_count; + + // 达到或大于 merge_size发送, 发送过程中 + std::vector request_kv_num(request_call_num, 0); + + if (_push_sparse_merge_count_map[table_id] >= merge_size || + _flushing == true) { + DownpourBrpcClosure *closure = new DownpourBrpcClosure( + request_call_num, [this, request_call_num](void *done) { + int ret = 0; + auto *closure = reinterpret_cast(done); + for (size_t i = 0; i < request_call_num; ++i) { + if (closure->check_response(i, PS_PUSH_SPARSE_TABLE) != 0) { + ret = -1; + break; + } + } + closure->set_promise_value(ret); + --_async_call_num; + }); + + for_each(task_list.begin() + 1, task_list.end(), + [&request_kv_num, request_call_num, + closure](std::shared_ptr &task) { + // closure->add_timer(task->timer()); + closure->add_promise(task->promise()); + }); + + // CostTimer merge_timer("pserver_client_push_sparse_merge"); + // auto rpc_timer = + // std::make_shared("pserver_client_push_sparse_rpc"); + // closure->add_timer(rpc_timer); + + std::vector> merge_status(request_call_num); + for (int shard_idx = 0; shard_idx < request_call_num; ++shard_idx) { + merge_status[shard_idx] = + async_push_sparse_shard_threads.enqueue(std::bind( + &BrpcPsClient::push_sparse_async_shard_push, this, task_list, + request_kv_num, table_id, shard_idx, closure, accessor)); + } + for (int shard_idx = 0; shard_idx < request_call_num; ++shard_idx) { + merge_status[shard_idx].wait(); + } + merge_status.clear(); + std::vector>().swap(merge_status); + _push_sparse_merge_count_map[table_id] = 0; + + auto queue_size = task_queue->Size(); + } else { // 未达到阈值 只做多路归并 + std::vector> merge_status(request_call_num); + for (int shard_idx = 0; shard_idx < request_call_num; ++shard_idx) { + merge_status[shard_idx] = + async_push_sparse_shard_threads.enqueue(std::bind( + &BrpcPsClient::push_sparse_async_shard_merge, this, task_list, + request_kv_num, table_id, shard_idx, accessor)); + } + for (int shard_idx = 0; shard_idx < request_call_num; ++shard_idx) { + merge_status[shard_idx].wait(); + } + + // meger到task_list[0] + auto async_task = new SparseAsyncTask(*(task_list[0].get())); + + task_queue->Put(std::move(async_task)); + --_async_call_num; + merge_status.clear(); + std::vector>().swap(merge_status); + } + } + auto wait_ms = + FLAGS_pserver_async_push_sparse_interval_ms - (timeline.ElapsedMS()); + if (wait_ms > 0) { + usleep(wait_ms * 1000); + } + } +} + +void sparse_local_merge(ValueAccessor *accessor, float *merge_data, + const float *another_data) { + size_t col_num = accessor->update_size() / sizeof(float); + float *merge_data_shell[col_num]; + const float 
*another_data_shell[col_num]; + for (int i = 0; i < col_num; ++i) { + merge_data_shell[i] = merge_data + i; + another_data_shell[i] = another_data + i; + } + accessor->merge(merge_data_shell, another_data_shell, 1); +} + +int BrpcPsClient::push_sparse_async_shard_merge( + std::vector> &task_list, + std::vector &request_kv_num, int table_id, int shard_idx, + ValueAccessor *accessor) { + size_t merged_kv_count = 0; + uint64_t min_key = UINT64_MAX; + uint32_t value_size = accessor->update_size(); + + thread_local std::vector> sorted_kv_list; + sorted_kv_list.clear(); + for (int i = 1; i < task_list.size(); ++i) { + size_t kv_num = task_list[i]->data()->shared_data[shard_idx].kv_num; + auto &key_list = task_list[i]->data()->shared_data[shard_idx].key_list; + auto &value_list = task_list[i]->data()->shared_data[shard_idx].value_list; + + for (int j = 0; j < kv_num; ++j) { + if (value_list[j].size() < value_size) { + LOG(WARNING) << "value_list[" << j << "]: " << value_list[j].c_str() + << "is invalid."; + continue; + } + char *task_data_ptr = const_cast(value_list[j].data()); + sorted_kv_list.push_back( + {key_list[j], reinterpret_cast(task_data_ptr)}); + } + } + + // 按key排序&去重 + std::sort(sorted_kv_list.begin(), sorted_kv_list.end(), + [](const std::pair &k1, + const std::pair &k2) { + return k1.first < k2.first; + }); + + auto &async_task = task_list[0]; + size_t sorted_kv_size = sorted_kv_list.size(); + auto &shard_kv_data = async_task->data()->shared_data[shard_idx]; + shard_kv_data.key_list.resize(sorted_kv_size); + shard_kv_data.value_list.resize(sorted_kv_size); + + // 将去重后数据写入分shard包 + if (sorted_kv_size == 0) { + shard_kv_data.kv_num = 0; + return 0; + } else if (sorted_kv_size == 1) { + shard_kv_data.kv_num = 1; + shard_kv_data.key_list[0] = sorted_kv_list[0].first; + shard_kv_data.value_list[0].assign((const char *)(sorted_kv_list[0].second), + value_size); + return 0; + } + + // 去重 本地merge + uint64_t last_key = sorted_kv_list[0].first; + const float *last_value_data = sorted_kv_list[0].second; + float *last_merge_data = NULL; + std::shared_ptr merger_buffer(new char[value_size], + array_deleter()); + for (size_t kv_idx = 1; kv_idx < sorted_kv_size; ++kv_idx) { + while (kv_idx < sorted_kv_size && + last_key == sorted_kv_list[kv_idx].first) { + if (last_merge_data == NULL) { + last_merge_data = reinterpret_cast(merger_buffer.get()); + memcpy(last_merge_data, last_value_data, value_size); + } + sparse_local_merge(accessor, last_merge_data, + sorted_kv_list[kv_idx].second); + ++kv_idx; + } + if (last_merge_data != NULL) { + shard_kv_data.value_list[merged_kv_count].assign( + (const char *)last_merge_data, value_size); + last_merge_data = NULL; + } else { + shard_kv_data.value_list[merged_kv_count].assign( + (const char *)sorted_kv_list[kv_idx - 1].second, value_size); + } + shard_kv_data.key_list[merged_kv_count++] = last_key; + if (kv_idx < sorted_kv_size) { + last_key = sorted_kv_list[kv_idx].first; + last_value_data = sorted_kv_list[kv_idx].second; + } + if (kv_idx == sorted_kv_size - 1) { + shard_kv_data.value_list[merged_kv_count].assign( + (const char *)last_value_data, value_size); + shard_kv_data.key_list[merged_kv_count++] = last_key; + } + } + shard_kv_data.kv_num = merged_kv_count; + return 0; +} + +int BrpcPsClient::push_sparse_async_shard_push( + std::vector> &task_list, + std::vector &request_kv_num, int table_id, int shard_idx, + DownpourBrpcClosure *closure, ValueAccessor *accessor) { + push_sparse_async_shard_merge(task_list, request_kv_num, table_id, shard_idx, + 
accessor); + size_t merged_kv_count = task_list[0]->data()->shared_data[shard_idx].kv_num; + + auto &merged_key_list = task_list[0]->data()->shared_data[shard_idx].key_list; + auto &merged_value_list = + task_list[0]->data()->shared_data[shard_idx].value_list; + + // 发送RPC请求 + auto *push_request = closure->request(shard_idx); + push_request->set_cmd_id(PS_PUSH_SPARSE_TABLE); + push_request->set_table_id(table_id); + push_request->set_client_id(_client_id); + push_request->add_params(reinterpret_cast(&merged_kv_count), + sizeof(uint32_t)); // NOLINT + auto *push_data = push_request->mutable_data(); + push_data->resize(merged_kv_count * + (sizeof(uint64_t) + accessor->update_size())); + char *push_data_ptr = const_cast(push_data->data()); + memcpy(push_data_ptr, merged_key_list.data(), + merged_kv_count * sizeof(uint64_t)); + push_data_ptr += merged_kv_count * sizeof(uint64_t); + for (int i = 0; i < merged_kv_count; ++i) { + const char *task_data_ptr = merged_value_list[i].data(); + + memcpy(push_data_ptr, (float *)(task_data_ptr), // NOLINT + accessor->update_size()); + push_data_ptr += accessor->update_size(); + } + PsService_Stub rpc_stub(get_sparse_channel(shard_idx)); + closure->cntl(shard_idx)->set_request_compress_type( + (brpc::CompressType)FLAGS_pserver_communicate_compress_type); + rpc_stub.service(closure->cntl(shard_idx), closure->request(shard_idx), + closure->response(shard_idx), closure); + _push_sparse_merge_count_map[table_id] = 0; + return 0; +} + +std::future BrpcPsClient::push_dense(const Region *regions, + size_t region_num, + size_t table_id) { + auto *accessor = table_accessor(table_id); + auto push_timer = std::make_shared("pserver_client_push_dense"); + auto parse_timer = + std::make_shared("pserver_client_push_dense_parse"); + int push_dense_async_num = _push_dense_task_queue_map[table_id]->Size(); + while (push_dense_async_num > FLAGS_pserver_max_async_call_num) { + LOG(INFO) << "push_dense Waiting for async_call_num comsume, task_num:" + << push_dense_async_num + << ", max_task_limit:" << FLAGS_pserver_max_async_call_num; + usleep(5000); // 5ms + push_dense_async_num = _push_dense_task_queue_map[table_id]->Size(); + } + // auto dense_data = _dense_matrix_obj_pool.get(); + auto dense_data = std::make_shared>(); + auto async_task = new DenseAsyncTask(dense_data, table_id, push_timer); + size_t request_call_num = _server_channels.size(); + uint32_t num_per_shard = + dense_dim_per_shard(accessor->fea_dim(), request_call_num); + + // 将region数据拷贝到转置矩阵中 + async_task->data()->resize(num_per_shard * request_call_num * + accessor->update_dim()); + float *data = async_task->data()->data(); + size_t data_size = async_task->data()->size(); + uint32_t pos = 0; + for (size_t i = 0; i < region_num; ++i) { + uint32_t data_num = regions[i].size / sizeof(float); + CHECK(pos + data_num <= data_size) + << "invalid dense size, cur pos[" << pos << "]" + << " data_num[" << data_num << "] size[" << data_size << "]"; + const float *region_data = (const float *)(regions[i].data); + memcpy(data + pos, region_data, regions[i].size); + pos += data_num; + } + std::future fut = async_task->get_future(); + _push_dense_task_queue_map[table_id]->Put(std::move(async_task)); + return fut; +} + +void BrpcPsClient::push_dense_task_consume() { + uint64_t merge_size = FLAGS_pserver_push_dense_merge_limit; + static bool scale_gradient = FLAGS_pserver_scale_gradient_by_merge; + ::ThreadPool async_merge_dense_threads(10); + while (_running) { + platform::Timer timeline; + timeline.Start(); + for (auto 
&task_queue_itr : _push_dense_task_queue_map) { + auto &task_queue = task_queue_itr.second; + auto queue_size = task_queue->Size(); + if (queue_size == 0) { + continue; + } + if (queue_size <= merge_size && _flushing == false) { + continue; + } + ++_async_call_num; + DenseAsyncTask *task; + task_queue->Get(task); + auto *accessor = table_accessor(task->table_id()); + // 设置请求回调 + size_t request_call_num = _server_channels.size(); + + DownpourBrpcClosure *closure = new DownpourBrpcClosure( + request_call_num, [this, request_call_num](void *done) { + int ret = 0; + auto *closure = reinterpret_cast(done); + for (size_t i = 0; i < request_call_num; ++i) { + if (closure->check_response(i, PS_PUSH_DENSE_TABLE) != 0) { + ret = -1; + break; + } + } + closure->set_promise_value(ret); + --_async_call_num; + }); + + auto &total_send_data_vec = *(task->data()); + float *total_send_data = + reinterpret_cast(total_send_data_vec.data()); + size_t total_send_data_size = total_send_data_vec.size(); + { + CostTimer merge_timer("pserver_client_push_dense_merge"); + uint32_t merge_count = 0; + std::vector> merge_status(merge_size); + while (!task_queue->Empty() && merge_count < merge_size) { + auto *async_task = new DenseAsyncTask(); + task_queue->Get(async_task); + closure->add_timer(async_task->timer()); + closure->add_promise(async_task->promise()); + merge_status[merge_count] = async_merge_dense_threads.enqueue( + [closure, accessor, &total_send_data, total_send_data_size, + async_task]() -> int { + auto &tmp_task_vec = *(async_task->data()); + const float *merge_data = tmp_task_vec.data(); + accessor->merge(&total_send_data, &merge_data, + total_send_data_size); +#pragma optimize("", off) + auto *debug_closure = closure; + auto *debug_task = async_task; + delete async_task; +#pragma optimize("", on) + return 0; + }); + ++merge_count; + } + for (int i = 0; i < merge_count; ++i) { + merge_status[i].wait(); + } + + VLOG(3) << "BrpcPsClient::push_dense_task_consume before merge " + "total_send_data[0]" + << total_send_data[0] << " total_send_data[-2]" + << total_send_data[total_send_data_size - 2] + << total_send_data[0] << " total_send_data[-1]" + << total_send_data[total_send_data_size - 1]; + if (scale_gradient && merge_count > 1) { + Eigen::Map mat(total_send_data, 1, + total_send_data_size); + mat *= (1.0 / (merge_count + 1)); + } + + VLOG(3) << "BrpcPsClient::push_dense_task_consume after merge " + "total_send_data[0]" + << total_send_data[0] << " total_send_data[-2]" + << total_send_data[total_send_data_size - 2] + << " total_send_data[-1]" + << total_send_data[total_send_data_size - 1] << " merge_count " + << merge_count; + } + std::shared_ptr task_ptr(task); + push_dense_raw_gradient(task_ptr, total_send_data, total_send_data_size, + closure); + } + auto wait_ms = + FLAGS_pserver_async_push_dense_interval_ms - (timeline.ElapsedMS()); + if (wait_ms > 0) { + usleep(wait_ms * 1000); + } + } +} + +void BrpcPsClient::push_dense_raw_gradient( + std::shared_ptr &task, float *total_send_data, + size_t total_send_data_size, DownpourBrpcClosure *closure) { + auto *accessor = table_accessor(task->table_id()); + size_t request_call_num = _server_channels.size(); + // 将数据拷贝到请求buffer区 + auto timer = std::make_shared("pserver_client_push_dense_rpc"); + closure->add_timer(timer); + uint32_t num_per_shard = + dense_dim_per_shard(accessor->fea_dim(), request_call_num); + for (size_t i = 0; i < request_call_num; ++i) { + closure->request(i)->set_cmd_id(PS_PUSH_DENSE_TABLE); + 
closure->request(i)->set_table_id(task->table_id()); + closure->request(i)->set_client_id(_client_id); + auto *push_data = closure->request(i)->mutable_data(); + push_data->clear(); + push_data->resize(sizeof(uint32_t) + num_per_shard * sizeof(float)); + char *push_data_ptr = const_cast(push_data->data()); + memcpy(push_data_ptr, &num_per_shard, sizeof(uint32_t)); + memcpy(push_data_ptr + sizeof(uint32_t), + total_send_data + i * num_per_shard, num_per_shard * sizeof(float)); + closure->cntl(i)->set_request_compress_type( + (brpc::CompressType)FLAGS_pserver_communicate_compress_type); + PsService_Stub rpc_stub(get_dense_channel(i)); + rpc_stub.service(closure->cntl(i), closure->request(i), + closure->response(i), closure); + } +} + } // namespace distributed } // namespace paddle diff --git a/paddle/fluid/distributed/service/brpc_ps_client.h b/paddle/fluid/distributed/service/brpc_ps_client.h index 5192356e4b5..d5388a5cd07 100644 --- a/paddle/fluid/distributed/service/brpc_ps_client.h +++ b/paddle/fluid/distributed/service/brpc_ps_client.h @@ -14,6 +14,7 @@ #pragma once +#include #include #include #include @@ -23,6 +24,7 @@ #include "brpc/server.h" #include "paddle/fluid/distributed/service/brpc_utils.h" #include "paddle/fluid/distributed/service/ps_client.h" +#include "paddle/fluid/framework/channel.h" #include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/scope.h" #include "paddle/fluid/framework/tensor_util.h" @@ -53,10 +55,9 @@ class DownpourPsClientService : public PsService { _rank = rank_id; return 0; } - virtual void service(::google::protobuf::RpcController *controller, - const PsRequestMessage *request, - PsResponseMessage *response, - ::google::protobuf::Closure *done) override; + void service(::google::protobuf::RpcController *controller, + const PsRequestMessage *request, PsResponseMessage *response, + ::google::protobuf::Closure *done) override; protected: size_t _rank; @@ -77,7 +78,7 @@ class DownpourBrpcClosure : public PSClientClosure { } } virtual ~DownpourBrpcClosure() {} - virtual void Run() override { + void Run() override { if (_waiting_num.fetch_sub(1) == 1) { _callback(this); delete this; @@ -97,47 +98,87 @@ class DownpourBrpcClosure : public PSClientClosure { std::vector> _cntls; }; +struct SharedSparsePushData { + SharedSparsePushData() {} + ~SharedSparsePushData() noexcept {} + size_t kv_num; + std::vector key_list; + std::vector value_list; +}; +struct SparsePushTaskData { + std::vector shared_data; // sparse数据按key hash分片 +}; + +// push sparse 对象池 +struct SparseTaskPool { + std::shared_ptr get() { + std::lock_guard lock(_mutex); + if (_pool.empty()) { + return std::make_shared(); + } else { + auto ret = _pool.back(); + _pool.pop_back(); + return ret; + } + } + void push(std::shared_ptr data) { + std::lock_guard lock(_mutex); + _pool.push_back(std::move(data)); + } + std::vector> _pool; + std::mutex _mutex; +}; + template struct array_deleter { - void operator()(T *&x) const { delete[] x; } + void operator()(T *&x) const { delete[] x; } // NOLINT }; class BrpcPsClient : public PSClient { public: BrpcPsClient() {} virtual ~BrpcPsClient() { - // _running = false; - // try { - // _async_push_dense_thread.join(); - // _async_push_sparse_thread.join(); - //} catch (...) 
{ - //} + if (_running) { + flush(); + _running = false; + } + if (_async_push_dense_thread.joinable()) { + _async_push_dense_thread.join(); + } + if (_async_push_sparse_thread.joinable()) { + _async_push_sparse_thread.join(); + } + if (_server_started) { + _server.Stop(1000); + _server.Join(); + _server_started = false; + } } virtual int32_t create_client2client_connection( int pserver_timeout_ms, int pserver_connect_timeout_ms, int max_retry); - virtual std::future shrink(uint32_t table_id, - const std::string threshold) override; - virtual std::future load(const std::string &epoch, - const std::string &mode) override; - virtual std::future load(uint32_t table_id, const std::string &epoch, - const std::string &mode) override; + std::future shrink(uint32_t table_id, + const std::string threshold) override; + std::future load(const std::string &epoch, + const std::string &mode) override; + std::future load(uint32_t table_id, const std::string &epoch, + const std::string &mode) override; - virtual std::future save(const std::string &epoch, - const std::string &mode) override; + std::future save(const std::string &epoch, + const std::string &mode) override; - virtual std::future save(uint32_t table_id, const std::string &epoch, - const std::string &mode) override; + std::future save(uint32_t table_id, const std::string &epoch, + const std::string &mode) override; - virtual std::future clear() override; + std::future clear() override; - virtual std::future clear(uint32_t table_id) override; + std::future clear(uint32_t table_id) override; - virtual std::future stop_server() override; + std::future stop_server() override; - virtual std::future start_profiler() override; - virtual std::future stop_profiler() override; + std::future start_profiler() override; + std::future stop_profiler() override; - virtual void finalize_worker() override; + void finalize_worker() override; virtual std::future pull_dense(Region *regions, size_t region_num, size_t table_id); @@ -146,6 +187,9 @@ class BrpcPsClient : public PSClient { size_t region_num, size_t table_id); + virtual std::future push_dense(const Region *regions, + size_t region_num, size_t table_id); + void push_dense_task_consume(); virtual std::future pull_sparse(float **select_values, size_t table_id, const uint64_t *keys, size_t num, @@ -164,13 +208,16 @@ class BrpcPsClient : public PSClient { void *done); virtual std::future flush(); - virtual std::future send_client2client_msg( - int msg_type, int to_client_id, const std::string &msg) override; + std::future send_client2client_msg(int msg_type, int to_client_id, + const std::string &msg) override; // for local save sparse virtual int32_t recv_and_save_table(const uint64_t table_id, const std::string &path); + void print_queue_size(); + void print_queue_size_thread(); + protected: virtual size_t get_server_nums() { return _server_channels.size(); } inline brpc::Channel *get_sparse_channel(size_t server_id) { @@ -182,7 +229,7 @@ class BrpcPsClient : public PSClient { inline brpc::Channel *get_cmd_channel(size_t server_id) { return _server_channels[server_id][2].get(); } - virtual int32_t initialize() override; + int32_t initialize() override; private: // virtual int32_t initialize() override; @@ -200,38 +247,74 @@ class BrpcPsClient : public PSClient { bool _running = false; bool _flushing = false; - std::atomic _async_call_num; //异步请求计数 + std::atomic _async_call_num; // 异步请求计数 + + // 异步push dense task + std::thread _async_push_dense_thread; + typedef AsyncRequestTask>> DenseAsyncTask; + 
std::unordered_map> + _push_dense_task_queue_map; + // 异步push sparse task + std::thread _async_push_sparse_thread; + typedef AsyncRequestTask> SparseAsyncTask; + std::unordered_map> + _push_sparse_task_queue_map; + std::unordered_map _push_sparse_merge_count_map; + + std::thread _print_thread; + + int push_sparse_async_shard_merge( + std::vector> &task_list, // NOLINT + std::vector &request_kv_num, int table_id, int shard_idx, // NOLINT + ValueAccessor *accessor); + + int push_sparse_async_shard_push( + std::vector> &task_list, // NOLINT + std::vector &request_kv_num, int table_id, int shard_idx, // NOLINT + DownpourBrpcClosure *closure, ValueAccessor *accessor); + + SparseTaskPool _sparse_task_pool; std::vector> _client_channels; // client2client std::vector, 3>> _server_channels; // client2server - virtual std::future push_dense_raw_gradient( - int table_id, float *total_send_data, size_t total_send_data_size, - void *done) override; - - virtual std::future push_sparse_raw_gradient( - size_t table_id, const uint64_t *keys, const float **update_values, - size_t num, void *done) override; - - virtual std::future push_sparse_raw_gradient_partial( + std::future push_dense_raw_gradient(int table_id, + float *total_send_data, + size_t total_send_data_size, + void *done) override; + + std::future push_sparse_raw_gradient(size_t table_id, + const uint64_t *keys, + const float **update_values, + size_t num, + void *done) override; + + std::future push_sparse_raw_gradient_partial( size_t table_id, const uint64_t *keys, const float **update_values, uint32_t num, void *done, int pserver_idx) override; - virtual std::future push_sparse_param(size_t table_id, - const uint64_t *keys, - const float **update_values, - size_t num, - void *done) override; + std::future push_sparse_param(size_t table_id, const uint64_t *keys, + const float **update_values, + size_t num, void *done) override; + std::future push_sparse(size_t table_id, const uint64_t *keys, + const float **update_values, + size_t num) override; + void push_sparse_task_consume(); private: int32_t start_client_service(); + void push_dense_raw_gradient(std::shared_ptr &task, // NOLINT + float *total_send_data, + size_t total_send_data_size, + DownpourBrpcClosure *closure); float _mae = 0; float _mse = 0; uint16_t _push_times = 0; brpc::Server _server; DownpourPsClientService _service; + bool _server_started = false; std::atomic_uint grad_num_{0}; }; } // namespace distributed diff --git a/paddle/fluid/distributed/service/communicator.cc b/paddle/fluid/distributed/service/communicator.cc index a016c478846..00fae6e276e 100644 --- a/paddle/fluid/distributed/service/communicator.cc +++ b/paddle/fluid/distributed/service/communicator.cc @@ -628,6 +628,8 @@ void AsyncCommunicator::Start() { void AsyncCommunicator::Stop() { VLOG(1) << "Communicator stop"; + _worker_ptr->finalize_worker(); + VLOG(0) << "Communicator finalize_worker done"; running_ = false; if (!communicator_) { VLOG(0) << "Communicator is not inited, do nothing"; diff --git a/paddle/fluid/distributed/service/ps_client.h b/paddle/fluid/distributed/service/ps_client.h index 3be83436cec..a408a0cc24f 100644 --- a/paddle/fluid/distributed/service/ps_client.h +++ b/paddle/fluid/distributed/service/ps_client.h @@ -114,9 +114,9 @@ class PSClient { size_t region_num, size_t table_id) = 0; - // virtual std::future push_dense(const Region *regions, - // size_t region_num, - // size_t table_id) = 0; + virtual std::future push_dense(const Region *regions, + size_t region_num, + size_t table_id) = 0; 
// 使用keys进行pull请求,结果填充values // keys和values的个数均为num个,每个value占用select_size空间 // future结束前keys和values缓冲区不能再次使用 @@ -222,10 +222,10 @@ class PSClient { const uint64_t *keys, const float **update_values, size_t num, void *done) = 0; - // virtual std::future push_sparse(size_t table_id, - // const uint64_t *keys, - // const float **update_values, - // size_t num) = 0; + virtual std::future push_sparse(size_t table_id, + const uint64_t *keys, + const float **update_values, + size_t num) = 0; protected: virtual int32_t initialize() = 0; diff --git a/paddle/fluid/distributed/table/ctr_accessor.cc b/paddle/fluid/distributed/table/ctr_accessor.cc index 1ef8c9e1527..68bd6eb9f27 100644 --- a/paddle/fluid/distributed/table/ctr_accessor.cc +++ b/paddle/fluid/distributed/table/ctr_accessor.cc @@ -270,8 +270,8 @@ bool CtrCommonAccessor::create_value(int stage, const float* value) { return true; } else if (stage == 1) { // operation - auto show = CtrCommonPushValue::show_const(value); - auto click = CtrCommonPushValue::click_const(value); + auto show = CtrCommonPushValue::show(const_cast(value)); + auto click = CtrCommonPushValue::click(const_cast(value)); auto score = show_click_score(show, click); if (score <= 0) { return false; @@ -302,8 +302,8 @@ std::string CtrCommonAccessor::parse_to_string(const float* v, int param) { i < common_feature_value.embedx_w_index(); i++) { os << " " << v[i]; } - auto show = common_feature_value.show_const(v); - auto click = common_feature_value.click_const(v); + auto show = common_feature_value.show(const_cast(v)); + auto click = common_feature_value.click(const_cast(v)); auto score = show_click_score(show, click); if (score >= _config.embedx_threshold()) { for (auto i = common_feature_value.embedx_w_index(); diff --git a/paddle/fluid/distributed/table/ctr_accessor.h b/paddle/fluid/distributed/table/ctr_accessor.h index 3c2ac7189f7..8be672e8e0d 100644 --- a/paddle/fluid/distributed/table/ctr_accessor.h +++ b/paddle/fluid/distributed/table/ctr_accessor.h @@ -61,14 +61,7 @@ class CtrCommonAccessor : public ValueAccessor { float& embed_g2sum(float* val) { return val[embed_g2sum_index()]; } float& embedx_w(float* val) { return val[embedx_w_index()]; } float& embedx_g2sum(float* val) { return val[embedx_g2sum_index()]; } - float show_const(const float* val) { - float s = val[show_index()]; - return s; - } - float click_const(const float* val) { - float c = val[click_index()]; - return c; - } + int embed_sgd_dim; int embedx_dim; int embedx_sgd_dim; @@ -103,14 +96,6 @@ class CtrCommonAccessor : public ValueAccessor { static float& click(float* val) { return val[CtrCommonPushValue::click_index()]; } - static float show_const(const float* val) { - float s = val[show_index()]; - return s; - } - static float click_const(const float* val) { - float c = val[click_index()]; - return c; - } static float& embed_g(float* val) { return val[CtrCommonPushValue::embed_g_index()]; } -- GitLab
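
For reference: throughout this patch the per-key routing `keys[i] % request_call_num` is replaced by `get_sparse_shard(shard_num, request_call_num, keys[i])`, where `shard_num` is taken from the table's `shard_num()` when configured and otherwise from the new `pserver_sparse_table_shard_num` flag (default 1000). The helper's definition is not shown in this diff; the following is a minimal C++ sketch of the intended mapping, assuming keys are first hashed into `shard_num` table shards and contiguous blocks of shards are then assigned to servers. The names and rounding here are assumptions for illustration, not code taken from the patch.

    #include <cstddef>
    #include <cstdint>

    // Hypothetical sketch: route a feature key to a pserver index.
    // 1) key -> table shard in [0, shard_num)
    // 2) table shard -> the server that owns that contiguous block of shards
    inline size_t get_sparse_shard(uint32_t shard_num, uint32_t server_num,
                                   uint64_t key) {
      // Shards per server, rounded up so every shard has an owner.
      size_t shards_per_server = (shard_num + server_num - 1) / server_num;
      return (key % shard_num) / shards_per_server;
    }

Under this sketch, with shard_num = 1000 and, say, 4 pservers, every key that hashes to the same table shard is always pushed to and pulled from the same server, which keeps the shard layout stable across save/load independently of the number of servers handling a given request.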