Unverified commit bd79ae8a, authored by Z zhaocaibei123, committed by GitHub

brpc_ps_client upgrade (#36943)

* test

* rm test

* add memory_sparse_table and brpc communication upgrade dependency

* fix

* add dense optimizer & fix dump bug & add some strategy fields

* fix

* fix

* remove thread_pool thread_queue

* add memory sparse table

* update memory sparse table

* update memory sparse table

* update cmake

* upgrade brpc_ps_client

* remove show/click_const in ctr_accessor

* fix destructor
Parent abd4ab9c
......@@ -19,7 +19,7 @@
#include "paddle/fluid/distributed/service/brpc_ps_client.h"
#include "paddle/fluid/framework/archive.h"
const static int max_port = 65535;
static const int max_port = 65535;
DEFINE_int32(pserver_push_dense_merge_limit, 12,
"limit max push_dense local merge requests");
......@@ -52,6 +52,9 @@ DEFINE_int32(pserver_connect_timeout_ms, 10000,
DEFINE_int32(pserver_sparse_merge_thread, 1, "pserver sparse merge thread num");
DEFINE_int32(pserver_sparse_table_shard_num, 1000,
"sparse table shard for save & load");
namespace paddle {
namespace framework {
class Scope;
......@@ -102,6 +105,7 @@ int32_t BrpcPsClient::start_client_service() {
LOG(ERROR) << "BrpcPsServer start failed";
return -1;
}
_server_started = true;
_env->registe_ps_client(butil::my_ip_cstr(), _server.listen_address().port,
_client_id);
return 0;
......@@ -117,6 +121,12 @@ int32_t BrpcPsClient::create_client2client_connection(
options.max_retry = max_retry;
std::vector<PSHost> client_list = _env->get_ps_clients();
VLOG(1) << "BrpcPsClient::create_c2c_connection client_list size: "
<< client_list.size();
for (auto cc : client_list) {
VLOG(1) << "BrpcPsClient::create_c2c_connection client_list: "
<< cc.to_string();
}
_client_channels.resize(client_list.size());
std::ostringstream os;
std::string server_ip_port;
......@@ -184,8 +194,34 @@ int32_t BrpcPsClient::initialize() {
// start the client listening service and establish client-to-client connections
start_client_service();
// initialize the async push request queues
const auto &worker_param = _config.worker_param().downpour_worker_param();
for (size_t i = 0; i < worker_param.downpour_table_param_size(); ++i) {
auto type = worker_param.downpour_table_param(i).type();
auto table_id = worker_param.downpour_table_param(i).table_id();
if (type == PS_DENSE_TABLE) {
_push_dense_task_queue_map[table_id] =
paddle::framework::MakeChannel<DenseAsyncTask *>();
}
if (type == PS_SPARSE_TABLE) {
_push_sparse_task_queue_map[table_id] =
paddle::framework::MakeChannel<SparseAsyncTask *>();
_push_sparse_merge_count_map[table_id] = 0;
}
}
_running = true;
_flushing = false;
// start the async push threads
_async_push_sparse_thread =
std::thread(std::bind(&BrpcPsClient::push_sparse_task_consume, this));
// _async_push_sparse_thread.detach();
_async_push_dense_thread =
std::thread(std::bind(&BrpcPsClient::push_dense_task_consume, this));
// for debug
// _print_thread =
// std::thread(std::bind(&BrpcPsClient::print_queue_size_thread, this));
return 0;
}
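initialize() now builds one task channel per dense/sparse table and spawns push_dense_task_consume / push_sparse_task_consume to drain them. Below is a minimal stand-in for that producer/consumer channel, assuming only the Put/Get/Size/Empty operations the surrounding code actually calls; paddle::framework::Channel is the real type, and MiniChannel is a hypothetical sketch, not its implementation.

#include <condition_variable>
#include <mutex>
#include <queue>
#include <utility>

// Hypothetical stand-in for paddle::framework::Channel<T>: push_dense()/
// push_sparse() Put() tasks, and the *_task_consume() threads Get() them.
template <class T>
class MiniChannel {
 public:
  void Put(T task) {
    std::lock_guard<std::mutex> lk(mu_);
    q_.push(std::move(task));
    cv_.notify_one();
  }
  bool Get(T &out) {  // blocks until a task is available
    std::unique_lock<std::mutex> lk(mu_);
    cv_.wait(lk, [this] { return !q_.empty(); });
    out = std::move(q_.front());
    q_.pop();
    return true;
  }
  size_t Size() {
    std::lock_guard<std::mutex> lk(mu_);
    return q_.size();
  }
  bool Empty() { return Size() == 0; }

 private:
  std::mutex mu_;
  std::condition_variable cv_;
  std::queue<T> q_;
};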
......@@ -238,7 +274,7 @@ std::future<int32_t> BrpcPsClient::print_table_stat(uint32_t table_id) {
uint64_t feasign_size = 0;
uint64_t mf_size = 0;
paddle::framework::BinaryArchive ar;
auto *closure = (DownpourBrpcClosure *)done;
auto *closure = reinterpret_cast<DownpourBrpcClosure *>(done);
for (size_t i = 0; i < request_call_num; ++i) {
if (closure->check_response(i, PS_PRINT_TABLE_STAT) != 0) {
ret = -1;
......@@ -277,7 +313,7 @@ std::future<int32_t> BrpcPsClient::send_cmd(
DownpourBrpcClosure *closure = new DownpourBrpcClosure(
request_call_num, [request_call_num, cmd_id](void *done) {
int ret = 0;
auto *closure = (DownpourBrpcClosure *)done;
auto *closure = reinterpret_cast<DownpourBrpcClosure *>(done);
for (size_t i = 0; i < request_call_num; ++i) {
if (closure->check_response(i, cmd_id) != 0) {
ret = -1;
......@@ -298,7 +334,7 @@ std::future<int32_t> BrpcPsClient::send_cmd(
}
PsService_Stub rpc_stub(get_cmd_channel(i));
closure->cntl(i)->set_timeout_ms(
10800000); // cmd msg don't limit timeout for save/load
10800000 * 2); // cmd msg don't limit timeout for save/load
rpc_stub.service(closure->cntl(i), closure->request(i),
closure->response(i), closure);
}
......@@ -312,7 +348,7 @@ std::future<int32_t> BrpcPsClient::send_save_cmd(
request_call_num, [request_call_num, cmd_id](void *done) {
int ret = 0;
uint32_t feasign_size = 0;
auto *closure = (DownpourBrpcClosure *)done;
auto *closure = reinterpret_cast<DownpourBrpcClosure *>(done);
for (size_t i = 0; i < request_call_num; ++i) {
if (closure->check_save_response(i, cmd_id) < 0) {
ret = -1;
......@@ -362,11 +398,14 @@ std::future<int32_t> BrpcPsClient::load(uint32_t table_id,
std::future<int32_t> BrpcPsClient::save(const std::string &epoch,
const std::string &mode) {
VLOG(1) << "BrpcPsClient::save path " << epoch;
return send_save_cmd(-1, PS_SAVE_ALL_TABLE, {epoch, mode});
}
std::future<int32_t> BrpcPsClient::save(uint32_t table_id,
const std::string &epoch,
const std::string &mode) {
VLOG(1) << "BrpcPsClient::save one table path " << epoch << " table_id "
<< table_id;
return send_save_cmd(table_id, PS_SAVE_ONE_TABLE, {epoch, mode});
}
......@@ -378,6 +417,7 @@ std::future<int32_t> BrpcPsClient::clear(uint32_t table_id) {
}
std::future<int32_t> BrpcPsClient::flush() {
VLOG(0) << "BrpcPsClient::flush begin";
_flushing = true;
std::promise<int> promise;
std::future<int32_t> fut = promise.get_future();
......@@ -385,16 +425,49 @@ std::future<int32_t> BrpcPsClient::flush() {
VLOG(3) << "wait _async_call_num:" << _async_call_num;
usleep(100000); // sleep 100ms wait async end
} while (_async_call_num > 0);
VLOG(1) << "flush _async_call_num = 0";
promise.set_value(0);
_flushing = false;
VLOG(0) << "BrpcPsClient::flush done";
print_queue_size();
return fut;
}
void BrpcPsClient::print_queue_size() {
for (auto &push_sparse_task_itr : _push_sparse_task_queue_map) {
auto table_id = push_sparse_task_itr.first;
auto queue_size = push_sparse_task_itr.second->Size();
VLOG(0) << "BrpcPsClient::print_queue_size: table " << table_id
<< " size: " << queue_size;
}
for (auto &task_queue_itr : _push_dense_task_queue_map) {
auto table_id = task_queue_itr.first;
auto queue_size = task_queue_itr.second->Size();
VLOG(0) << "BrpcPsClient::print_queue_size: table " << table_id
<< " size: " << queue_size;
}
}
void BrpcPsClient::print_queue_size_thread() {
while (_running) {
usleep(1000000 * 60 * 2);
print_queue_size();
}
}
void BrpcPsClient::finalize_worker() {
flush();
VLOG(0) << "BrpcPsClient::finalize_worker begin join thread";
_running = false;
_async_push_dense_thread.join();
_async_push_sparse_thread.join();
// _print_thread.join();
VLOG(0) << "BrpcPsClient::finalize_worker begin join server";
_server.Stop(1000);
_server.Join();
_server_started = false;
VLOG(0) << "BrpcPsClient::finalize_worker done";
}
std::future<int32_t> BrpcPsClient::stop_server() {
......@@ -422,19 +495,20 @@ std::future<int32_t> BrpcPsClient::pull_geo_param(size_t table_id,
DownpourBrpcClosure *closure =
new DownpourBrpcClosure(1, [keys, values, accessor](void *done) {
int ret = 0;
auto *closure = (DownpourBrpcClosure *)done;
auto *closure = reinterpret_cast<DownpourBrpcClosure *>(done);
uint32_t shard_nums;
if (closure->check_response(0, PS_PULL_GEO_PARAM) != 0) {
ret = -1;
}
auto &res_io_buffer = closure->cntl(0)->response_attachment();
butil::IOBufBytesIterator io_buffer_itr(res_io_buffer);
io_buffer_itr.copy_and_forward((void *)(&shard_nums), sizeof(uint32_t));
io_buffer_itr.copy_and_forward(reinterpret_cast<void *>(&shard_nums),
sizeof(uint32_t));
keys->resize(shard_nums);
values->resize(shard_nums * accessor->update_dim());
io_buffer_itr.copy_and_forward((void *)(keys->data()),
io_buffer_itr.copy_and_forward((void *)(keys->data()), // NOLINT
sizeof(uint64_t) * shard_nums);
io_buffer_itr.copy_and_forward((void *)(values->data()),
io_buffer_itr.copy_and_forward((void *)(values->data()), // NOLINT
shard_nums * accessor->update_size());
closure->set_promise_value(ret);
});
......@@ -466,8 +540,19 @@ std::future<int32_t> BrpcPsClient::push_sparse_param(
std::vector<std::vector<const float *>> value_ptrs;
ids.resize(request_call_num);
value_ptrs.resize(request_call_num);
const auto &server_param = _config.server_param().downpour_server_param();
uint64_t shard_num = FLAGS_pserver_sparse_table_shard_num;
for (int i = 0; i < server_param.downpour_table_param_size(); ++i) {
const auto &table_param = server_param.downpour_table_param(i);
if (table_param.table_id() == table_id) {
shard_num = table_param.shard_num();
break;
}
}
for (size_t i = 0; i < num; ++i) {
size_t pserver_idx = keys[i] % request_call_num;
size_t pserver_idx = get_sparse_shard(shard_num, request_call_num, keys[i]);
ids[pserver_idx].push_back(keys[i]);
value_ptrs[pserver_idx].push_back(update_values[i]);
}
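Key routing changes here from keys[i] % request_call_num to get_sparse_shard(shard_num, request_call_num, keys[i]), so each key is sent to the pserver that owns its logical shard instead of being hashed over servers directly. A hedged sketch of the kind of two-level mapping this implies; the real helper lives in ps_client.h and its exact formula may differ.

#include <cstddef>
#include <cstdint>

// Assumption for illustration: keys hash into shard_num logical shards, and
// shards are assigned to server_num pservers in contiguous, rounded-up blocks.
inline size_t sparse_shard_of(uint64_t shard_num, uint64_t server_num,
                              uint64_t key) {
  uint64_t shards_per_server = (shard_num + server_num - 1) / server_num;
  return static_cast<size_t>((key % shard_num) / shards_per_server);
}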
......@@ -481,7 +566,7 @@ std::future<int32_t> BrpcPsClient::push_sparse_param(
push_request->set_cmd_id(PS_PUSH_SPARSE_PARAM);
push_request->set_table_id(table_id);
push_request->set_client_id(_client_id);
push_request->add_params((char *)&kv_size, sizeof(uint32_t));
push_request->add_params((char *)&kv_size, sizeof(uint32_t)); // NOLINT
auto *push_data = push_request->mutable_data();
push_data->resize(kv_size * (sizeof(uint64_t) + accessor->update_size()));
char *push_data_ptr = const_cast<char *>(push_data->data());
......@@ -514,7 +599,7 @@ std::future<int32_t> BrpcPsClient::pull_dense(Region *regions,
int ret = 0;
size_t region_idx = 0; // index of the region currently being filled
size_t region_data_idx = 0; // data offset within the current region
auto *closure = (DownpourBrpcClosure *)done;
auto *closure = reinterpret_cast<DownpourBrpcClosure *>(done);
size_t shard_data_size = num_per_shard * accessor->select_size();
for (size_t i = 0; i < request_call_num; ++i) {
if (closure->check_response(i, PS_PULL_DENSE_TABLE) != 0) {
......@@ -537,7 +622,8 @@ std::future<int32_t> BrpcPsClient::pull_dense(Region *regions,
if (region.size - region_data_idx >= shard_buffer_remain) {
// remaining space in the region >= remaining shard buffer data: copy it all in
io_buffer_itr.copy_and_forward(
(void *)(region.data + region_data_idx), shard_buffer_remain);
reinterpret_cast<void *>(region.data + region_data_idx),
shard_buffer_remain);
region_data_idx += shard_buffer_remain;
shard_buffer_remain = 0;
} else if (region.size - region_data_idx == 0) {
......@@ -547,7 +633,7 @@ std::future<int32_t> BrpcPsClient::pull_dense(Region *regions,
} else {
// the region cannot hold all the data: copy as much as fits
io_buffer_itr.copy_and_forward(
(void *)(region.data + region_data_idx),
reinterpret_cast<void *>(region.data + region_data_idx),
region.size - region_data_idx);
shard_buffer_remain -= (region.size - region_data_idx);
++region_idx;
......@@ -564,7 +650,7 @@ std::future<int32_t> BrpcPsClient::pull_dense(Region *regions,
closure->request(i)->set_cmd_id(PS_PULL_DENSE_TABLE);
closure->request(i)->set_table_id(table_id);
closure->request(i)->set_client_id(_client_id);
closure->request(i)->add_params((char *)&num_per_shard,
closure->request(i)->add_params((char *)&num_per_shard, // NOLINT
sizeof(num_per_shard));
PsService_Stub rpc_stub(get_dense_channel(i));
rpc_stub.service(closure->cntl(i), closure->request(i),
......@@ -608,7 +694,7 @@ std::future<int32_t> BrpcPsClient::push_dense_param(const Region *regions,
DownpourBrpcClosure *closure =
new DownpourBrpcClosure(request_call_num, [request_call_num](void *done) {
int ret = 0;
auto *closure = (DownpourBrpcClosure *)done;
auto *closure = reinterpret_cast<DownpourBrpcClosure *>(done);
for (size_t i = 0; i < request_call_num; ++i) {
if (closure->check_response(i, PS_PUSH_DENSE_PARAM) != 0) {
ret = -1;
......@@ -621,26 +707,28 @@ std::future<int32_t> BrpcPsClient::push_dense_param(const Region *regions,
closure->add_promise(promise);
std::future<int> fut = promise->get_future();
static const int REGION_ASSIGN_BUFFER_SIZE = 1024 * 10;
static char region_assign_buffer[REGION_ASSIGN_BUFFER_SIZE]; //used for data padding
//start multi-shard parallel copy & requests
static char region_assign_buffer[REGION_ASSIGN_BUFFER_SIZE]; // used for data padding
// start multi-shard parallel copy & requests
for (size_t i = 0; i < request_call_num; ++i) {
closure->request(i)->set_cmd_id(PS_PUSH_DENSE_PARAM);
closure->request(i)->set_table_id(table_id);
closure->request(i)->set_client_id(_client_id);
auto &request_buffer = closure->cntl(i)->request_attachment();
request_buffer.append((void *)&num_per_shard, sizeof(uint32_t));
request_buffer.append(reinterpret_cast<void *>(&num_per_shard),
sizeof(uint32_t));
auto &region_list = regions_partition[i];
size_t fill_remain_size = shard_data_size;
for (auto &region : region_list) {
fill_remain_size -= region.size;
request_buffer.append((void *)region.data, region.size);
request_buffer.append(reinterpret_cast<void *>(region.data), region.size);
}
//keep every shard's payload the same size
// keep every shard's payload the same size
while (fill_remain_size > 0) {
size_t fill_num = fill_remain_size > REGION_ASSIGN_BUFFER_SIZE
? REGION_ASSIGN_BUFFER_SIZE
: fill_remain_size;
request_buffer.append((void *)region_assign_buffer, fill_num);
request_buffer.append(reinterpret_cast<void *>(region_assign_buffer),
fill_num);
fill_remain_size -= fill_num;
}
PsService_Stub rpc_stub(get_dense_channel(i));
......@@ -654,7 +742,7 @@ std::future<int32_t> BrpcPsClient::push_sparse_raw_gradient(
size_t table_id, const uint64_t *keys, const float **update_values,
size_t num, void *done) {
auto *accessor = table_accessor(table_id);
//send the RPC requests
// send the RPC requests
DownpourBrpcClosure *closure = reinterpret_cast<DownpourBrpcClosure *>(done);
auto promise = std::make_shared<std::promise<int32_t>>();
closure->add_promise(promise);
......@@ -666,8 +754,18 @@ std::future<int32_t> BrpcPsClient::push_sparse_raw_gradient(
ids.resize(request_call_num);
value_ptrs.resize(request_call_num);
const auto &server_param = _config.server_param().downpour_server_param();
uint64_t shard_num = FLAGS_pserver_sparse_table_shard_num;
for (int i = 0; i < server_param.downpour_table_param_size(); ++i) {
const auto &table_param = server_param.downpour_table_param(i);
if (table_param.table_id() == table_id) {
shard_num = table_param.shard_num();
break;
}
}
for (size_t i = 0; i < num; ++i) {
size_t pserver_idx = keys[i] % request_call_num;
size_t pserver_idx = get_sparse_shard(shard_num, request_call_num, keys[i]);
ids[pserver_idx].push_back(keys[i]);
value_ptrs[pserver_idx].push_back(update_values[i]);
}
......@@ -684,7 +782,7 @@ std::future<int32_t> BrpcPsClient::push_sparse_raw_gradient(
push_request->set_cmd_id(PS_PUSH_SPARSE_TABLE);
push_request->set_table_id(table_id);
push_request->set_client_id(_client_id);
push_request->add_params((char *)&kv_size, sizeof(uint32_t));
push_request->add_params((char *)&kv_size, sizeof(uint32_t)); // NOLINT
auto *push_data = push_request->mutable_data();
push_data->resize(kv_size * (sizeof(uint64_t) + accessor->update_size()));
char *push_data_ptr = const_cast<char *>(push_data->data());
......@@ -726,14 +824,11 @@ std::future<int32_t> BrpcPsClient::push_dense_raw_gradient(
memcpy(push_data_ptr, &num_per_shard, sizeof(uint32_t));
memcpy(push_data_ptr + sizeof(uint32_t),
total_send_data + i * num_per_shard, num_per_shard * sizeof(float));
VLOG(1) << "push_dense_raw_gradient finish memcpy";
// closure->cntl(i)->set_request_compress_type(
// (brpc::CompressType)FLAGS_pserver_communicate_compress_type);
PsService_Stub rpc_stub(get_dense_channel(i));
VLOG(1) << "push_dense_raw_gradient get_dense_channel " << i;
rpc_stub.service(closure->cntl(i), closure->request(i),
closure->response(i), closure);
VLOG(1) << "push_dense_raw_gradient async service " << i;
}
return fut;
}
......@@ -776,8 +871,18 @@ std::future<int32_t> BrpcPsClient::pull_sparse(float **select_values,
std::vector<std::vector<std::pair<uint64_t, float *>>>>();
shard_sorted_kvs->resize(request_call_num);
const auto &server_param = _config.server_param().downpour_server_param();
uint64_t shard_num = FLAGS_pserver_sparse_table_shard_num;
for (int i = 0; i < server_param.downpour_table_param_size(); ++i) {
const auto &table_param = server_param.downpour_table_param(i);
if (table_param.table_id() == table_id) {
shard_num = table_param.shard_num();
break;
}
}
for (size_t i = 0; i < num; ++i) {
size_t shard_id = keys[i] % request_call_num;
size_t shard_id = get_sparse_shard(shard_num, request_call_num, keys[i]);
shard_sorted_kvs->at(shard_id).push_back({keys[i], select_values[i]});
}
......@@ -787,7 +892,7 @@ std::future<int32_t> BrpcPsClient::pull_sparse(float **select_values,
DownpourBrpcClosure *closure = new DownpourBrpcClosure(
request_call_num, [shard_sorted_kvs, value_size](void *done) {
int ret = 0;
auto *closure = (DownpourBrpcClosure *)done;
auto *closure = reinterpret_cast<DownpourBrpcClosure *>(done);
for (size_t i = 0; i < shard_sorted_kvs->size(); ++i) {
if (closure->check_response(i, PS_PULL_SPARSE_TABLE) != 0) {
ret = -1;
......@@ -803,14 +908,14 @@ std::future<int32_t> BrpcPsClient::pull_sparse(float **select_values,
for (size_t kv_idx = 0; kv_idx < request_kvs.size(); ++kv_idx) {
auto *kv_pair = &(request_kvs[kv_idx]);
if (kv_pair->first == last_key) {
memcpy((void *)kv_pair->second, (void *)last_value_data,
value_size);
memcpy(reinterpret_cast<void *>(kv_pair->second),
reinterpret_cast<void *>(last_value_data), value_size);
} else {
last_key = kv_pair->first;
last_value_data = kv_pair->second;
if (value_size !=
io_buffer_itr.copy_and_forward((void *)(last_value_data),
value_size)) {
io_buffer_itr.copy_and_forward(
reinterpret_cast<void *>(last_value_data), value_size)) {
LOG(WARNING) << "res data is lack or not in format";
ret = -1;
break;
......@@ -838,7 +943,7 @@ std::future<int32_t> BrpcPsClient::pull_sparse(float **select_values,
size_t sorted_kv_size = sorted_kvs.size();
auto &request_buffer = closure->cntl(i)->request_attachment();
request_buffer.append((void *)&is_training, sizeof(bool));
request_buffer.append(reinterpret_cast<void *>(&is_training), sizeof(bool));
std::vector<uint32_t> keys_counter;
keys_counter.reserve(sorted_kv_size);
......@@ -846,7 +951,8 @@ std::future<int32_t> BrpcPsClient::pull_sparse(float **select_values,
++kv_request_count;
uint32_t keys = 1;
last_key = sorted_kvs[kv_idx].first;
request_buffer.append((void *)&last_key, sizeof(uint64_t));
request_buffer.append(reinterpret_cast<void *>(&last_key),
sizeof(uint64_t));
while (kv_idx < sorted_kv_size - 1 &&
last_key == sorted_kvs[kv_idx + 1].first) {
++kv_idx;
......@@ -855,7 +961,7 @@ std::future<int32_t> BrpcPsClient::pull_sparse(float **select_values,
keys_counter.push_back(keys);
}
request_buffer.append((void *)keys_counter.data(),
request_buffer.append(reinterpret_cast<void *>(keys_counter.data()),
sizeof(uint32_t) * keys_counter.size());
if (kv_request_count == 0) {
......@@ -864,7 +970,7 @@ std::future<int32_t> BrpcPsClient::pull_sparse(float **select_values,
closure->request(i)->set_cmd_id(PS_PULL_SPARSE_TABLE);
closure->request(i)->set_table_id(table_id);
closure->request(i)->set_client_id(_client_id);
closure->request(i)->add_params((char *)&kv_request_count,
closure->request(i)->add_params((char *)&kv_request_count, // NOLINT
sizeof(uint32_t));
PsService_Stub rpc_stub(get_cmd_channel(i));
closure->cntl(i)->set_log_id(butil::gettimeofday_ms());
......@@ -886,7 +992,7 @@ std::future<int32_t> BrpcPsClient::send_client2client_msg(
return fut;
}
auto *closure = new DownpourBrpcClosure(1, [msg_type](void *done) {
auto *closure = (DownpourBrpcClosure *)done;
auto *closure = reinterpret_cast<DownpourBrpcClosure *>(done);
int32_t ret = closure->check_response(0, msg_type + 1000);
closure->set_promise_value(ret);
});
......@@ -915,7 +1021,7 @@ std::future<int32_t> BrpcPsClient::push_sparse_raw_gradient_partial(
push_request->set_cmd_id(PS_PUSH_SPARSE_TABLE);
push_request->set_table_id(table_id);
push_request->set_client_id(_client_id);
push_request->add_params((char *)&num, sizeof(uint32_t));
push_request->add_params((char *)&num, sizeof(uint32_t)); // NOLINT
auto *push_data = push_request->mutable_data();
push_data->resize(num * (sizeof(uint64_t) + value_size));
char *push_data_ptr = const_cast<char *>(push_data->data());
......@@ -966,8 +1072,8 @@ int32_t BrpcPsClient::recv_and_save_table(const uint64_t table_id,
save_vec.push_back(save_huge_vec.data() + i * var_shape);
}
auto status = pull_sparse((float **)save_vec.data(), table_id,
save_key.data(), save_key.size(), true);
auto status = pull_sparse(reinterpret_cast<float **>(save_vec.data()),
table_id, save_key.data(), save_key.size(), true);
status.wait();
// create lod tensor
......@@ -1000,5 +1106,521 @@ int32_t BrpcPsClient::recv_and_save_table(const uint64_t table_id,
return 0;
}
std::future<int32_t> BrpcPsClient::push_sparse(size_t table_id,
const uint64_t *keys,
const float **update_values,
size_t num) {
auto push_timer =
std::make_shared<CostTimer>("pserver_client_push_sparse_parse");
int push_sparse_async_num = _push_sparse_task_queue_map[table_id]->Size();
while (push_sparse_async_num > FLAGS_pserver_max_async_call_num) {
// LOG(INFO) << "push_sparse Waiting for async_call_num comsume, task_num:"
// << push_sparse_async_num << ", max_task_limit:" <<
// FLAGS_pserver_max_async_call_num;
usleep(5000); // 5ms
// push_sparse_async_num = _push_sparse_task_queue_map[table_id]->size();
push_sparse_async_num = _push_sparse_task_queue_map[table_id]->Size();
}
thread_local std::vector<std::vector<std::pair<uint64_t, const float *>>>
shard_sorted_kv_list;
auto *accessor = table_accessor(table_id);
size_t request_call_num = _server_channels.size();
shard_sorted_kv_list.resize(request_call_num);
for (auto &x : shard_sorted_kv_list) {
x.clear();
}
const auto &server_param = _config.server_param().downpour_server_param();
uint64_t shard_num = FLAGS_pserver_sparse_table_shard_num;
for (int i = 0; i < server_param.downpour_table_param_size(); ++i) {
const auto &table_param = server_param.downpour_table_param(i);
if (table_param.table_id() == table_id) {
shard_num = table_param.shard_num();
break;
}
}
for (size_t i = 0; i < num; ++i) {
size_t shard_id = get_sparse_shard(shard_num, request_call_num, keys[i]);
shard_sorted_kv_list[shard_id].push_back({keys[i], update_values[i]});
}
auto sparse_task_data = _sparse_task_pool.get();
sparse_task_data->shared_data.resize(request_call_num);
auto async_task = new SparseAsyncTask(sparse_task_data, table_id, push_timer);
for (size_t i = 0; i < request_call_num; ++i) {
auto &sorted_kv_list = shard_sorted_kv_list[i];
size_t sorted_kv_size = sorted_kv_list.size();
auto &shard_kv_data = async_task->data()->shared_data[i];
shard_kv_data.key_list.resize(sorted_kv_size);
shard_kv_data.value_list.resize(sorted_kv_size);
if (sorted_kv_size == 0) {
shard_kv_data.kv_num = 0;
continue;
}
uint32_t value_size = accessor->update_size();
for (size_t kv_idx = 0; kv_idx < sorted_kv_size; ++kv_idx) {
shard_kv_data.key_list[kv_idx] = sorted_kv_list[kv_idx].first;
shard_kv_data.value_list[kv_idx].assign(
(const char *)sorted_kv_list[kv_idx].second, value_size);
}
shard_kv_data.kv_num = sorted_kv_size;
}
std::future<int> fut = async_task->get_future();
_push_sparse_task_queue_map[table_id]->Put(std::move(async_task));
return fut;
}
void BrpcPsClient::push_sparse_task_consume() {
uint64_t merge_size = FLAGS_pserver_push_sparse_merge_limit;
std::vector<std::shared_ptr<SparseAsyncTask>> task_list;
size_t request_call_num = _server_channels.size();
::ThreadPool async_push_sparse_shard_threads(
FLAGS_pserver_sparse_merge_thread);
while (_running) {
platform::Timer timeline;
timeline.Start();
// process the push tasks of all sparse tables
for (auto &push_sparse_task_itr : _push_sparse_task_queue_map) {
auto table_id = push_sparse_task_itr.first;
auto *accessor = table_accessor(table_id);
auto &task_queue = push_sparse_task_itr.second;
auto queue_size = task_queue->Size();
if (queue_size == 0) {
continue;
}
if (merge_size > 0 && (queue_size <= 1 && _flushing == false)) {
continue;
}
++_async_call_num;
int merge_count = 0;
for (size_t i = 0; i < task_list.size(); ++i) {
if (task_list[i]->data()) {
_sparse_task_pool.push(task_list[i]->data());
}
}
auto sparse_task_data = _sparse_task_pool.get();
task_list.clear();
int cur_meger_size = task_queue->Size();
// task_list[0] is an empty SparseAsyncTask; the per-shard async merge results are stored into it.
sparse_task_data->shared_data.resize(request_call_num);
auto push_timer =
std::make_shared<CostTimer>("pserver_client_push_sparse");
auto async_task =
new SparseAsyncTask(sparse_task_data, table_id, push_timer);
task_list.reserve(cur_meger_size + 1);
task_list.push_back(
std::move(std::shared_ptr<SparseAsyncTask>(async_task)));
while (!task_queue->Empty() && merge_count < cur_meger_size) {
++merge_count;
SparseAsyncTask *task;
task_queue->Get(task);
task_list.push_back(std::shared_ptr<SparseAsyncTask>(task));
}
_push_sparse_merge_count_map[table_id] += merge_count;
// send once the merge count reaches or exceeds merge_size
std::vector<int> request_kv_num(request_call_num, 0);
if (_push_sparse_merge_count_map[table_id] >= merge_size ||
_flushing == true) {
DownpourBrpcClosure *closure = new DownpourBrpcClosure(
request_call_num, [this, request_call_num](void *done) {
int ret = 0;
auto *closure = reinterpret_cast<DownpourBrpcClosure *>(done);
for (size_t i = 0; i < request_call_num; ++i) {
if (closure->check_response(i, PS_PUSH_SPARSE_TABLE) != 0) {
ret = -1;
break;
}
}
closure->set_promise_value(ret);
--_async_call_num;
});
for_each(task_list.begin() + 1, task_list.end(),
[&request_kv_num, request_call_num,
closure](std::shared_ptr<SparseAsyncTask> &task) {
// closure->add_timer(task->timer());
closure->add_promise(task->promise());
});
// CostTimer merge_timer("pserver_client_push_sparse_merge");
// auto rpc_timer =
// std::make_shared<CostTimer>("pserver_client_push_sparse_rpc");
// closure->add_timer(rpc_timer);
std::vector<std::future<int>> merge_status(request_call_num);
for (int shard_idx = 0; shard_idx < request_call_num; ++shard_idx) {
merge_status[shard_idx] =
async_push_sparse_shard_threads.enqueue(std::bind(
&BrpcPsClient::push_sparse_async_shard_push, this, task_list,
request_kv_num, table_id, shard_idx, closure, accessor));
}
for (int shard_idx = 0; shard_idx < request_call_num; ++shard_idx) {
merge_status[shard_idx].wait();
}
merge_status.clear();
std::vector<std::future<int>>().swap(merge_status);
_push_sparse_merge_count_map[table_id] = 0;
auto queue_size = task_queue->Size();
} else { // below the merge threshold: only do the multi-way merge
std::vector<std::future<int>> merge_status(request_call_num);
for (int shard_idx = 0; shard_idx < request_call_num; ++shard_idx) {
merge_status[shard_idx] =
async_push_sparse_shard_threads.enqueue(std::bind(
&BrpcPsClient::push_sparse_async_shard_merge, this, task_list,
request_kv_num, table_id, shard_idx, accessor));
}
for (int shard_idx = 0; shard_idx < request_call_num; ++shard_idx) {
merge_status[shard_idx].wait();
}
// merge into task_list[0]
auto async_task = new SparseAsyncTask(*(task_list[0].get()));
task_queue->Put(std::move(async_task));
--_async_call_num;
merge_status.clear();
std::vector<std::future<int>>().swap(merge_status);
}
}
auto wait_ms =
FLAGS_pserver_async_push_sparse_interval_ms - (timeline.ElapsedMS());
if (wait_ms > 0) {
usleep(wait_ms * 1000);
}
}
}
void sparse_local_merge(ValueAccessor *accessor, float *merge_data,
const float *another_data) {
size_t col_num = accessor->update_size() / sizeof(float);
float *merge_data_shell[col_num];
const float *another_data_shell[col_num];
for (int i = 0; i < col_num; ++i) {
merge_data_shell[i] = merge_data + i;
another_data_shell[i] = another_data + i;
}
accessor->merge(merge_data_shell, another_data_shell, 1);
}
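sparse_local_merge folds one pending update row into an already accumulated row through accessor->merge. Assuming merge performs an element-wise accumulation of the update columns (the usual behaviour for gradient-style accessors, possibly skipping bookkeeping slots), its local effect reduces to the sketch below:

#include <cstddef>

// Illustrative only: element-wise accumulation of one update row into the
// merged row; real accessors may exclude columns such as a slot id.
void local_merge_row(float *merged, const float *incoming, size_t col_num) {
  for (size_t i = 0; i < col_num; ++i) {
    merged[i] += incoming[i];
  }
}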
int BrpcPsClient::push_sparse_async_shard_merge(
std::vector<std::shared_ptr<SparseAsyncTask>> &task_list,
std::vector<int> &request_kv_num, int table_id, int shard_idx,
ValueAccessor *accessor) {
size_t merged_kv_count = 0;
uint64_t min_key = UINT64_MAX;
uint32_t value_size = accessor->update_size();
thread_local std::vector<std::pair<uint64_t, const float *>> sorted_kv_list;
sorted_kv_list.clear();
for (int i = 1; i < task_list.size(); ++i) {
size_t kv_num = task_list[i]->data()->shared_data[shard_idx].kv_num;
auto &key_list = task_list[i]->data()->shared_data[shard_idx].key_list;
auto &value_list = task_list[i]->data()->shared_data[shard_idx].value_list;
for (int j = 0; j < kv_num; ++j) {
if (value_list[j].size() < value_size) {
LOG(WARNING) << "value_list[" << j << "]: " << value_list[j].c_str()
<< "is invalid.";
continue;
}
char *task_data_ptr = const_cast<char *>(value_list[j].data());
sorted_kv_list.push_back(
{key_list[j], reinterpret_cast<float *>(task_data_ptr)});
}
}
// sort by key & deduplicate
std::sort(sorted_kv_list.begin(), sorted_kv_list.end(),
[](const std::pair<uint64_t, const float *> &k1,
const std::pair<uint64_t, const float *> &k2) {
return k1.first < k2.first;
});
auto &async_task = task_list[0];
size_t sorted_kv_size = sorted_kv_list.size();
auto &shard_kv_data = async_task->data()->shared_data[shard_idx];
shard_kv_data.key_list.resize(sorted_kv_size);
shard_kv_data.value_list.resize(sorted_kv_size);
// write the deduplicated data into the per-shard package
if (sorted_kv_size == 0) {
shard_kv_data.kv_num = 0;
return 0;
} else if (sorted_kv_size == 1) {
shard_kv_data.kv_num = 1;
shard_kv_data.key_list[0] = sorted_kv_list[0].first;
shard_kv_data.value_list[0].assign((const char *)(sorted_kv_list[0].second),
value_size);
return 0;
}
// deduplicate and merge locally
uint64_t last_key = sorted_kv_list[0].first;
const float *last_value_data = sorted_kv_list[0].second;
float *last_merge_data = NULL;
std::shared_ptr<char> merger_buffer(new char[value_size],
array_deleter<char>());
for (size_t kv_idx = 1; kv_idx < sorted_kv_size; ++kv_idx) {
while (kv_idx < sorted_kv_size &&
last_key == sorted_kv_list[kv_idx].first) {
if (last_merge_data == NULL) {
last_merge_data = reinterpret_cast<float *>(merger_buffer.get());
memcpy(last_merge_data, last_value_data, value_size);
}
sparse_local_merge(accessor, last_merge_data,
sorted_kv_list[kv_idx].second);
++kv_idx;
}
if (last_merge_data != NULL) {
shard_kv_data.value_list[merged_kv_count].assign(
(const char *)last_merge_data, value_size);
last_merge_data = NULL;
} else {
shard_kv_data.value_list[merged_kv_count].assign(
(const char *)sorted_kv_list[kv_idx - 1].second, value_size);
}
shard_kv_data.key_list[merged_kv_count++] = last_key;
if (kv_idx < sorted_kv_size) {
last_key = sorted_kv_list[kv_idx].first;
last_value_data = sorted_kv_list[kv_idx].second;
}
if (kv_idx == sorted_kv_size - 1) {
shard_kv_data.value_list[merged_kv_count].assign(
(const char *)last_value_data, value_size);
shard_kv_data.key_list[merged_kv_count++] = last_key;
}
}
shard_kv_data.kv_num = merged_kv_count;
return 0;
}
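push_sparse_async_shard_merge sorts each shard's pending (key, value) pairs by key and merges duplicates locally so every key is sent at most once per RPC. The same dedup-and-merge step, restated compactly with a std::map (the production loop above avoids the map to control allocations and copies):

#include <cstdint>
#include <map>
#include <utility>
#include <vector>

// One entry per key, with duplicate values accumulated. Values are reduced to
// a single float here purely for brevity; the real code merges whole rows.
std::vector<std::pair<uint64_t, float>> dedup_and_merge(
    const std::vector<std::pair<uint64_t, float>> &kvs) {
  std::map<uint64_t, float> merged;
  for (const auto &kv : kvs) {
    merged[kv.first] += kv.second;
  }
  return {merged.begin(), merged.end()};
}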
int BrpcPsClient::push_sparse_async_shard_push(
std::vector<std::shared_ptr<SparseAsyncTask>> &task_list,
std::vector<int> &request_kv_num, int table_id, int shard_idx,
DownpourBrpcClosure *closure, ValueAccessor *accessor) {
push_sparse_async_shard_merge(task_list, request_kv_num, table_id, shard_idx,
accessor);
size_t merged_kv_count = task_list[0]->data()->shared_data[shard_idx].kv_num;
auto &merged_key_list = task_list[0]->data()->shared_data[shard_idx].key_list;
auto &merged_value_list =
task_list[0]->data()->shared_data[shard_idx].value_list;
// send the RPC request
auto *push_request = closure->request(shard_idx);
push_request->set_cmd_id(PS_PUSH_SPARSE_TABLE);
push_request->set_table_id(table_id);
push_request->set_client_id(_client_id);
push_request->add_params(reinterpret_cast<char *>(&merged_kv_count),
sizeof(uint32_t)); // NOLINT
auto *push_data = push_request->mutable_data();
push_data->resize(merged_kv_count *
(sizeof(uint64_t) + accessor->update_size()));
char *push_data_ptr = const_cast<char *>(push_data->data());
memcpy(push_data_ptr, merged_key_list.data(),
merged_kv_count * sizeof(uint64_t));
push_data_ptr += merged_kv_count * sizeof(uint64_t);
for (int i = 0; i < merged_kv_count; ++i) {
const char *task_data_ptr = merged_value_list[i].data();
memcpy(push_data_ptr, (float *)(task_data_ptr), // NOLINT
accessor->update_size());
push_data_ptr += accessor->update_size();
}
PsService_Stub rpc_stub(get_sparse_channel(shard_idx));
closure->cntl(shard_idx)->set_request_compress_type(
(brpc::CompressType)FLAGS_pserver_communicate_compress_type);
rpc_stub.service(closure->cntl(shard_idx), closure->request(shard_idx),
closure->response(shard_idx), closure);
_push_sparse_merge_count_map[table_id] = 0;
return 0;
}
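The merged shard is serialized as one flat payload: the kv count goes into the request params, followed by all keys and then all fixed-size values in mutable_data(). A small sketch of that packing, with hypothetical names, to make the wire layout explicit:

#include <cstdint>
#include <cstring>
#include <string>
#include <vector>

// Layout inferred from the code above:
//   data = [key_0 .. key_{n-1}] [value_0 bytes .. value_{n-1} bytes]
// where every value occupies value_size (accessor->update_size()) bytes.
std::string pack_sparse_push(const std::vector<uint64_t> &keys,
                             const std::vector<std::string> &values,
                             size_t value_size) {
  std::string buf(keys.size() * (sizeof(uint64_t) + value_size), '\0');
  char *p = &buf[0];
  std::memcpy(p, keys.data(), keys.size() * sizeof(uint64_t));
  p += keys.size() * sizeof(uint64_t);
  for (const auto &v : values) {
    std::memcpy(p, v.data(), value_size);
    p += value_size;
  }
  return buf;
}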
std::future<int32_t> BrpcPsClient::push_dense(const Region *regions,
size_t region_num,
size_t table_id) {
auto *accessor = table_accessor(table_id);
auto push_timer = std::make_shared<CostTimer>("pserver_client_push_dense");
auto parse_timer =
std::make_shared<CostTimer>("pserver_client_push_dense_parse");
int push_dense_async_num = _push_dense_task_queue_map[table_id]->Size();
while (push_dense_async_num > FLAGS_pserver_max_async_call_num) {
LOG(INFO) << "push_dense Waiting for async_call_num comsume, task_num:"
<< push_dense_async_num
<< ", max_task_limit:" << FLAGS_pserver_max_async_call_num;
usleep(5000); // 5ms
push_dense_async_num = _push_dense_task_queue_map[table_id]->Size();
}
// auto dense_data = _dense_matrix_obj_pool.get();
auto dense_data = std::make_shared<std::vector<float>>();
auto async_task = new DenseAsyncTask(dense_data, table_id, push_timer);
size_t request_call_num = _server_channels.size();
uint32_t num_per_shard =
dense_dim_per_shard(accessor->fea_dim(), request_call_num);
// copy the region data into the transposed matrix
async_task->data()->resize(num_per_shard * request_call_num *
accessor->update_dim());
float *data = async_task->data()->data();
size_t data_size = async_task->data()->size();
uint32_t pos = 0;
for (size_t i = 0; i < region_num; ++i) {
uint32_t data_num = regions[i].size / sizeof(float);
CHECK(pos + data_num <= data_size)
<< "invalid dense size, cur pos[" << pos << "]"
<< " data_num[" << data_num << "] size[" << data_size << "]";
const float *region_data = (const float *)(regions[i].data);
memcpy(data + pos, region_data, regions[i].size);
pos += data_num;
}
std::future<int> fut = async_task->get_future();
_push_dense_task_queue_map[table_id]->Put(std::move(async_task));
return fut;
}
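push_dense sizes its send buffer with dense_dim_per_shard(accessor->fea_dim(), request_call_num), i.e. every pserver receives an equal slice of the dense parameter. A hedged guess at that helper, consistent with the zero-padding done in push_dense_param; the real definition may round differently.

#include <cstdint>

// Assumption: ceil division so all shards carry the same number of values;
// the tail shard is padded (see region_assign_buffer in push_dense_param).
inline uint32_t dense_dim_per_shard_guess(uint32_t fea_dim,
                                          uint32_t shard_num) {
  return (fea_dim + shard_num - 1) / shard_num;
}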
void BrpcPsClient::push_dense_task_consume() {
uint64_t merge_size = FLAGS_pserver_push_dense_merge_limit;
static bool scale_gradient = FLAGS_pserver_scale_gradient_by_merge;
::ThreadPool async_merge_dense_threads(10);
while (_running) {
platform::Timer timeline;
timeline.Start();
for (auto &task_queue_itr : _push_dense_task_queue_map) {
auto &task_queue = task_queue_itr.second;
auto queue_size = task_queue->Size();
if (queue_size == 0) {
continue;
}
if (queue_size <= merge_size && _flushing == false) {
continue;
}
++_async_call_num;
DenseAsyncTask *task;
task_queue->Get(task);
auto *accessor = table_accessor(task->table_id());
// set up the request callback
size_t request_call_num = _server_channels.size();
DownpourBrpcClosure *closure = new DownpourBrpcClosure(
request_call_num, [this, request_call_num](void *done) {
int ret = 0;
auto *closure = reinterpret_cast<DownpourBrpcClosure *>(done);
for (size_t i = 0; i < request_call_num; ++i) {
if (closure->check_response(i, PS_PUSH_DENSE_TABLE) != 0) {
ret = -1;
break;
}
}
closure->set_promise_value(ret);
--_async_call_num;
});
auto &total_send_data_vec = *(task->data());
float *total_send_data =
reinterpret_cast<float *>(total_send_data_vec.data());
size_t total_send_data_size = total_send_data_vec.size();
{
CostTimer merge_timer("pserver_client_push_dense_merge");
uint32_t merge_count = 0;
std::vector<std::future<int>> merge_status(merge_size);
while (!task_queue->Empty() && merge_count < merge_size) {
auto *async_task = new DenseAsyncTask();
task_queue->Get(async_task);
closure->add_timer(async_task->timer());
closure->add_promise(async_task->promise());
merge_status[merge_count] = async_merge_dense_threads.enqueue(
[closure, accessor, &total_send_data, total_send_data_size,
async_task]() -> int {
auto &tmp_task_vec = *(async_task->data());
const float *merge_data = tmp_task_vec.data();
accessor->merge(&total_send_data, &merge_data,
total_send_data_size);
#pragma optimize("", off)
auto *debug_closure = closure;
auto *debug_task = async_task;
delete async_task;
#pragma optimize("", on)
return 0;
});
++merge_count;
}
for (int i = 0; i < merge_count; ++i) {
merge_status[i].wait();
}
VLOG(3) << "BrpcPsClient::push_dense_task_consume before merge "
"total_send_data[0]"
<< total_send_data[0] << " total_send_data[-2]"
<< total_send_data[total_send_data_size - 2]
<< total_send_data[0] << " total_send_data[-1]"
<< total_send_data[total_send_data_size - 1];
if (scale_gradient && merge_count > 1) {
Eigen::Map<Eigen::MatrixXf> mat(total_send_data, 1,
total_send_data_size);
mat *= (1.0 / (merge_count + 1));
}
VLOG(3) << "BrpcPsClient::push_dense_task_consume after merge "
"total_send_data[0]"
<< total_send_data[0] << " total_send_data[-2]"
<< total_send_data[total_send_data_size - 2]
<< " total_send_data[-1]"
<< total_send_data[total_send_data_size - 1] << " merge_count "
<< merge_count;
}
std::shared_ptr<DenseAsyncTask> task_ptr(task);
push_dense_raw_gradient(task_ptr, total_send_data, total_send_data_size,
closure);
}
auto wait_ms =
FLAGS_pserver_async_push_dense_interval_ms - (timeline.ElapsedMS());
if (wait_ms > 0) {
usleep(wait_ms * 1000);
}
}
}
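When FLAGS_pserver_scale_gradient_by_merge is set, the merged dense gradient is averaged over the head task plus the merge_count tasks folded into it, which is where the 1.0 / (merge_count + 1) factor above comes from. The same scaling written without Eigen:

#include <cstddef>
#include <cstdint>

// Average the accumulated gradient over the contributing tasks: the task
// popped first plus the merge_count tasks merged behind it.
void scale_by_merge_count(float *data, size_t n, uint32_t merge_count) {
  const float scale = 1.0f / static_cast<float>(merge_count + 1);
  for (size_t i = 0; i < n; ++i) {
    data[i] *= scale;
  }
}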
void BrpcPsClient::push_dense_raw_gradient(
std::shared_ptr<DenseAsyncTask> &task, float *total_send_data,
size_t total_send_data_size, DownpourBrpcClosure *closure) {
auto *accessor = table_accessor(task->table_id());
size_t request_call_num = _server_channels.size();
// copy the data into the request buffer
auto timer = std::make_shared<CostTimer>("pserver_client_push_dense_rpc");
closure->add_timer(timer);
uint32_t num_per_shard =
dense_dim_per_shard(accessor->fea_dim(), request_call_num);
for (size_t i = 0; i < request_call_num; ++i) {
closure->request(i)->set_cmd_id(PS_PUSH_DENSE_TABLE);
closure->request(i)->set_table_id(task->table_id());
closure->request(i)->set_client_id(_client_id);
auto *push_data = closure->request(i)->mutable_data();
push_data->clear();
push_data->resize(sizeof(uint32_t) + num_per_shard * sizeof(float));
char *push_data_ptr = const_cast<char *>(push_data->data());
memcpy(push_data_ptr, &num_per_shard, sizeof(uint32_t));
memcpy(push_data_ptr + sizeof(uint32_t),
total_send_data + i * num_per_shard, num_per_shard * sizeof(float));
closure->cntl(i)->set_request_compress_type(
(brpc::CompressType)FLAGS_pserver_communicate_compress_type);
PsService_Stub rpc_stub(get_dense_channel(i));
rpc_stub.service(closure->cntl(i), closure->request(i),
closure->response(i), closure);
}
}
} // namespace distributed
} // namespace paddle
......@@ -14,6 +14,7 @@
#pragma once
#include <ThreadPool.h>
#include <memory>
#include <string>
#include <vector>
......@@ -23,6 +24,7 @@
#include "brpc/server.h"
#include "paddle/fluid/distributed/service/brpc_utils.h"
#include "paddle/fluid/distributed/service/ps_client.h"
#include "paddle/fluid/framework/channel.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/tensor_util.h"
......@@ -53,9 +55,8 @@ class DownpourPsClientService : public PsService {
_rank = rank_id;
return 0;
}
virtual void service(::google::protobuf::RpcController *controller,
const PsRequestMessage *request,
PsResponseMessage *response,
void service(::google::protobuf::RpcController *controller,
const PsRequestMessage *request, PsResponseMessage *response,
::google::protobuf::Closure *done) override;
protected:
......@@ -77,7 +78,7 @@ class DownpourBrpcClosure : public PSClientClosure {
}
}
virtual ~DownpourBrpcClosure() {}
virtual void Run() override {
void Run() override {
if (_waiting_num.fetch_sub(1) == 1) {
_callback(this);
delete this;
......@@ -97,47 +98,87 @@ class DownpourBrpcClosure : public PSClientClosure {
std::vector<std::shared_ptr<brpc::Controller>> _cntls;
};
struct SharedSparsePushData {
SharedSparsePushData() {}
~SharedSparsePushData() noexcept {}
size_t kv_num;
std::vector<uint64_t> key_list;
std::vector<std::string> value_list;
};
struct SparsePushTaskData {
std::vector<SharedSparsePushData> shared_data; // sparse data sharded by key hash
};
// object pool for push sparse
struct SparseTaskPool {
std::shared_ptr<SparsePushTaskData> get() {
std::lock_guard<std::mutex> lock(_mutex);
if (_pool.empty()) {
return std::make_shared<SparsePushTaskData>();
} else {
auto ret = _pool.back();
_pool.pop_back();
return ret;
}
}
void push(std::shared_ptr<SparsePushTaskData> data) {
std::lock_guard<std::mutex> lock(_mutex);
_pool.push_back(std::move(data));
}
std::vector<std::shared_ptr<SparsePushTaskData>> _pool;
std::mutex _mutex;
};
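SparseTaskPool is a mutex-guarded free list that recycles SparsePushTaskData buffers between push_sparse calls and the consume thread, avoiding repeated allocation of the per-shard vectors. Typical use, matching how push_sparse and push_sparse_task_consume drive it (the request_call_num of 4 is an arbitrary example, not taken from a call site):

// Sketch of the intended get/push round trip.
void sparse_task_pool_example(SparseTaskPool &pool) {
  auto data = pool.get();            // reuses a recycled buffer if available
  data->shared_data.resize(4);       // one SharedSparsePushData per pserver
  // ... fill shared_data, enqueue the task, wait for the RPC to finish ...
  pool.push(data);                   // hand the buffer back for reuse
}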
template <class T>
struct array_deleter {
void operator()(T *&x) const { delete[] x; }
void operator()(T *&x) const { delete[] x; } // NOLINT
};
class BrpcPsClient : public PSClient {
public:
BrpcPsClient() {}
virtual ~BrpcPsClient() {
// _running = false;
// try {
// _async_push_dense_thread.join();
// _async_push_sparse_thread.join();
//} catch (...) {
//}
if (_running) {
flush();
_running = false;
}
if (_async_push_dense_thread.joinable()) {
_async_push_dense_thread.join();
}
if (_async_push_sparse_thread.joinable()) {
_async_push_sparse_thread.join();
}
if (_server_started) {
_server.Stop(1000);
_server.Join();
_server_started = false;
}
}
virtual int32_t create_client2client_connection(
int pserver_timeout_ms, int pserver_connect_timeout_ms, int max_retry);
virtual std::future<int32_t> shrink(uint32_t table_id,
std::future<int32_t> shrink(uint32_t table_id,
const std::string threshold) override;
virtual std::future<int32_t> load(const std::string &epoch,
std::future<int32_t> load(const std::string &epoch,
const std::string &mode) override;
virtual std::future<int32_t> load(uint32_t table_id, const std::string &epoch,
std::future<int32_t> load(uint32_t table_id, const std::string &epoch,
const std::string &mode) override;
virtual std::future<int32_t> save(const std::string &epoch,
std::future<int32_t> save(const std::string &epoch,
const std::string &mode) override;
virtual std::future<int32_t> save(uint32_t table_id, const std::string &epoch,
std::future<int32_t> save(uint32_t table_id, const std::string &epoch,
const std::string &mode) override;
virtual std::future<int32_t> clear() override;
std::future<int32_t> clear() override;
virtual std::future<int32_t> clear(uint32_t table_id) override;
std::future<int32_t> clear(uint32_t table_id) override;
virtual std::future<int32_t> stop_server() override;
std::future<int32_t> stop_server() override;
virtual std::future<int32_t> start_profiler() override;
virtual std::future<int32_t> stop_profiler() override;
std::future<int32_t> start_profiler() override;
std::future<int32_t> stop_profiler() override;
virtual void finalize_worker() override;
void finalize_worker() override;
virtual std::future<int32_t> pull_dense(Region *regions, size_t region_num,
size_t table_id);
......@@ -146,6 +187,9 @@ class BrpcPsClient : public PSClient {
size_t region_num,
size_t table_id);
virtual std::future<int32_t> push_dense(const Region *regions,
size_t region_num, size_t table_id);
void push_dense_task_consume();
virtual std::future<int32_t> pull_sparse(float **select_values,
size_t table_id,
const uint64_t *keys, size_t num,
......@@ -164,13 +208,16 @@ class BrpcPsClient : public PSClient {
void *done);
virtual std::future<int32_t> flush();
virtual std::future<int32_t> send_client2client_msg(
int msg_type, int to_client_id, const std::string &msg) override;
std::future<int32_t> send_client2client_msg(int msg_type, int to_client_id,
const std::string &msg) override;
// for local save sparse
virtual int32_t recv_and_save_table(const uint64_t table_id,
const std::string &path);
void print_queue_size();
void print_queue_size_thread();
protected:
virtual size_t get_server_nums() { return _server_channels.size(); }
inline brpc::Channel *get_sparse_channel(size_t server_id) {
......@@ -182,7 +229,7 @@ class BrpcPsClient : public PSClient {
inline brpc::Channel *get_cmd_channel(size_t server_id) {
return _server_channels[server_id][2].get();
}
virtual int32_t initialize() override;
int32_t initialize() override;
private:
// virtual int32_t initialize() override;
......@@ -200,38 +247,74 @@ class BrpcPsClient : public PSClient {
bool _running = false;
bool _flushing = false;
std::atomic<uint32_t> _async_call_num; //async request counter
std::atomic<uint32_t> _async_call_num; // async request counter
// async push dense tasks
std::thread _async_push_dense_thread;
typedef AsyncRequestTask<std::shared_ptr<std::vector<float>>> DenseAsyncTask;
std::unordered_map<uint32_t, paddle::framework::Channel<DenseAsyncTask *>>
_push_dense_task_queue_map;
// async push sparse tasks
std::thread _async_push_sparse_thread;
typedef AsyncRequestTask<std::shared_ptr<SparsePushTaskData>> SparseAsyncTask;
std::unordered_map<uint32_t, paddle::framework::Channel<SparseAsyncTask *>>
_push_sparse_task_queue_map;
std::unordered_map<uint32_t, uint32_t> _push_sparse_merge_count_map;
std::thread _print_thread;
int push_sparse_async_shard_merge(
std::vector<std::shared_ptr<SparseAsyncTask>> &task_list, // NOLINT
std::vector<int> &request_kv_num, int table_id, int shard_idx, // NOLINT
ValueAccessor *accessor);
int push_sparse_async_shard_push(
std::vector<std::shared_ptr<SparseAsyncTask>> &task_list, // NOLINT
std::vector<int> &request_kv_num, int table_id, int shard_idx, // NOLINT
DownpourBrpcClosure *closure, ValueAccessor *accessor);
SparseTaskPool _sparse_task_pool;
std::vector<std::shared_ptr<brpc::Channel>>
_client_channels; // client2client
std::vector<std::array<std::shared_ptr<brpc::Channel>, 3>>
_server_channels; // client2server
virtual std::future<int32_t> push_dense_raw_gradient(
int table_id, float *total_send_data, size_t total_send_data_size,
std::future<int32_t> push_dense_raw_gradient(int table_id,
float *total_send_data,
size_t total_send_data_size,
void *done) override;
virtual std::future<int32_t> push_sparse_raw_gradient(
size_t table_id, const uint64_t *keys, const float **update_values,
size_t num, void *done) override;
std::future<int32_t> push_sparse_raw_gradient(size_t table_id,
const uint64_t *keys,
const float **update_values,
size_t num,
void *done) override;
virtual std::future<int32_t> push_sparse_raw_gradient_partial(
std::future<int32_t> push_sparse_raw_gradient_partial(
size_t table_id, const uint64_t *keys, const float **update_values,
uint32_t num, void *done, int pserver_idx) override;
virtual std::future<int32_t> push_sparse_param(size_t table_id,
const uint64_t *keys,
std::future<int32_t> push_sparse_param(size_t table_id, const uint64_t *keys,
const float **update_values,
size_t num,
void *done) override;
size_t num, void *done) override;
std::future<int32_t> push_sparse(size_t table_id, const uint64_t *keys,
const float **update_values,
size_t num) override;
void push_sparse_task_consume();
private:
int32_t start_client_service();
void push_dense_raw_gradient(std::shared_ptr<DenseAsyncTask> &task, // NOLINT
float *total_send_data,
size_t total_send_data_size,
DownpourBrpcClosure *closure);
float _mae = 0;
float _mse = 0;
uint16_t _push_times = 0;
brpc::Server _server;
DownpourPsClientService _service;
bool _server_started = false;
std::atomic_uint grad_num_{0};
};
} // namespace distributed
......
......@@ -628,6 +628,8 @@ void AsyncCommunicator::Start() {
void AsyncCommunicator::Stop() {
VLOG(1) << "Communicator stop";
_worker_ptr->finalize_worker();
VLOG(0) << "Communicator finalize_worker done";
running_ = false;
if (!communicator_) {
VLOG(0) << "Communicator is not inited, do nothing";
......
......@@ -114,9 +114,9 @@ class PSClient {
size_t region_num,
size_t table_id) = 0;
// virtual std::future<int32_t> push_dense(const Region *regions,
// size_t region_num,
// size_t table_id) = 0;
virtual std::future<int32_t> push_dense(const Region *regions,
size_t region_num,
size_t table_id) = 0;
// issue a pull request with keys; the results are filled into values
// keys and values each contain num entries; each value occupies select_size bytes
// the keys and values buffers must not be reused before the future completes
......@@ -222,10 +222,10 @@ class PSClient {
const uint64_t *keys,
const float **update_values,
size_t num, void *done) = 0;
// virtual std::future<int32_t> push_sparse(size_t table_id,
// const uint64_t *keys,
// const float **update_values,
// size_t num) = 0;
virtual std::future<int32_t> push_sparse(size_t table_id,
const uint64_t *keys,
const float **update_values,
size_t num) = 0;
protected:
virtual int32_t initialize() = 0;
......
......@@ -270,8 +270,8 @@ bool CtrCommonAccessor::create_value(int stage, const float* value) {
return true;
} else if (stage == 1) {
// operation
auto show = CtrCommonPushValue::show_const(value);
auto click = CtrCommonPushValue::click_const(value);
auto show = CtrCommonPushValue::show(const_cast<float*>(value));
auto click = CtrCommonPushValue::click(const_cast<float*>(value));
auto score = show_click_score(show, click);
if (score <= 0) {
return false;
......@@ -302,8 +302,8 @@ std::string CtrCommonAccessor::parse_to_string(const float* v, int param) {
i < common_feature_value.embedx_w_index(); i++) {
os << " " << v[i];
}
auto show = common_feature_value.show_const(v);
auto click = common_feature_value.click_const(v);
auto show = common_feature_value.show(const_cast<float*>(v));
auto click = common_feature_value.click(const_cast<float*>(v));
auto score = show_click_score(show, click);
if (score >= _config.embedx_threshold()) {
for (auto i = common_feature_value.embedx_w_index();
......
......@@ -61,14 +61,7 @@ class CtrCommonAccessor : public ValueAccessor {
float& embed_g2sum(float* val) { return val[embed_g2sum_index()]; }
float& embedx_w(float* val) { return val[embedx_w_index()]; }
float& embedx_g2sum(float* val) { return val[embedx_g2sum_index()]; }
float show_const(const float* val) {
float s = val[show_index()];
return s;
}
float click_const(const float* val) {
float c = val[click_index()];
return c;
}
int embed_sgd_dim;
int embedx_dim;
int embedx_sgd_dim;
......@@ -103,14 +96,6 @@ class CtrCommonAccessor : public ValueAccessor {
static float& click(float* val) {
return val[CtrCommonPushValue::click_index()];
}
static float show_const(const float* val) {
float s = val[show_index()];
return s;
}
static float click_const(const float* val) {
float c = val[click_index()];
return c;
}
static float& embed_g(float* val) {
return val[CtrCommonPushValue::embed_g_index()];
}
......