未验证 提交 2089b485 编写于 作者: Y yaoxuefeng 提交者: GitHub

change to new api in ssync mode (#41022)

* change to new api in ssync mode

* fix

* fix

* fix

* fix
上级 60c4c9cd
...@@ -532,18 +532,17 @@ std::future<int32_t> BrpcPsClient::Pull(RequestContext &pull_context) { ...@@ -532,18 +532,17 @@ std::future<int32_t> BrpcPsClient::Pull(RequestContext &pull_context) {
if (pull_context.value_type == Dense) { // pull dense if (pull_context.value_type == Dense) { // pull dense
Region *dense_region = Region *dense_region =
reinterpret_cast<Region *>(pull_context.dense_values); reinterpret_cast<Region *>(pull_context.dense_values);
pull_dense(dense_region, pull_context.num, pull_context.table); return pull_dense(dense_region, pull_context.num, pull_context.table);
} else { // pull sparse } else { // pull sparse
uint64_t *keys = reinterpret_cast<uint64_t *>(pull_context.keys);
float **select_values =
reinterpret_cast<float **>(pull_context.sparse_values);
size_t table_id = pull_context.table; size_t table_id = pull_context.table;
size_t num = pull_context.num; size_t num = pull_context.num;
bool is_training = pull_context.is_training; bool is_training = pull_context.is_training;
if (pull_context.training_mode == Geo) { // for geo if (pull_context.training_mode == Geo) { // for geo
pull_sparse_param(select_values, table_id, keys, num, is_training); return pull_sparse_param(pull_context.sparse_values, table_id,
pull_context.keys, num, is_training);
} else if (pull_context.training_mode == Async) { // for async } else if (pull_context.training_mode == Async) { // for async
pull_sparse(select_values, table_id, keys, num, is_training); return pull_sparse(pull_context.sparse_values, table_id,
pull_context.keys, num, is_training);
} }
} }
} }
...@@ -551,7 +550,7 @@ std::future<int32_t> BrpcPsClient::Pull(RequestContext &pull_context) { ...@@ -551,7 +550,7 @@ std::future<int32_t> BrpcPsClient::Pull(RequestContext &pull_context) {
std::future<int32_t> BrpcPsClient::Push(RequestContext &push_context) { std::future<int32_t> BrpcPsClient::Push(RequestContext &push_context) {
if (push_context.value_type == Dense) { // push dense if (push_context.value_type == Dense) { // push dense
const Region *dense_region = push_context.push_context.push_dense_values; const Region *dense_region = push_context.push_context.push_dense_values;
push_dense(dense_region, push_context.num, push_context.table); return push_dense(dense_region, push_context.num, push_context.table);
} else { // push sparse } else { // push sparse
size_t table_id = push_context.table; size_t table_id = push_context.table;
size_t num = push_context.num; size_t num = push_context.num;
...@@ -561,7 +560,7 @@ std::future<int32_t> BrpcPsClient::Push(RequestContext &push_context) { ...@@ -561,7 +560,7 @@ std::future<int32_t> BrpcPsClient::Push(RequestContext &push_context) {
} else if (push_context.training_mode == Async) { // for async } else if (push_context.training_mode == Async) { // for async
const uint64_t *keys = push_context.push_context.keys; const uint64_t *keys = push_context.push_context.keys;
const float **update_values = push_context.push_context.push_values; const float **update_values = push_context.push_context.push_values;
push_sparse(table_id, keys, update_values, num); return push_sparse(table_id, keys, update_values, num);
} }
} }
} }
...@@ -584,11 +583,12 @@ std::future<int32_t> BrpcPsClient::pull_geo_param(size_t table_id, ...@@ -584,11 +583,12 @@ std::future<int32_t> BrpcPsClient::pull_geo_param(size_t table_id,
io_buffer_itr.copy_and_forward(reinterpret_cast<void *>(&shard_nums), io_buffer_itr.copy_and_forward(reinterpret_cast<void *>(&shard_nums),
sizeof(uint32_t)); sizeof(uint32_t));
keys->resize(shard_nums); keys->resize(shard_nums);
values->resize(shard_nums * accessor->update_dim()); values->resize(shard_nums * accessor->GetTableInfo(UPDATE_DIM));
io_buffer_itr.copy_and_forward((void *)(keys->data()), // NOLINT io_buffer_itr.copy_and_forward((void *)(keys->data()), // NOLINT
sizeof(uint64_t) * shard_nums); sizeof(uint64_t) * shard_nums);
io_buffer_itr.copy_and_forward((void *)(values->data()), // NOLINT io_buffer_itr.copy_and_forward(
shard_nums * accessor->update_size()); (void *)(values->data()), // NOLINT
shard_nums * accessor->GetTableInfo(UPDATE_SIZE));
closure->set_promise_value(ret); closure->set_promise_value(ret);
}); });
auto promise = std::make_shared<std::promise<int32_t>>(); auto promise = std::make_shared<std::promise<int32_t>>();
...@@ -630,7 +630,7 @@ std::future<int32_t> BrpcPsClient::push_sparse_param( ...@@ -630,7 +630,7 @@ std::future<int32_t> BrpcPsClient::push_sparse_param(
auto kvs = ids[shard_idx]; auto kvs = ids[shard_idx];
auto value_ptr = value_ptrs[shard_idx]; auto value_ptr = value_ptrs[shard_idx];
size_t kv_size = kvs.size(); size_t kv_size = kvs.size();
uint32_t value_size = accessor->update_size(); uint32_t value_size = accessor->GetTableInfo(UPDATE_SIZE);
// 发送RPC请求 // 发送RPC请求
auto *push_request = closure->request(shard_idx); auto *push_request = closure->request(shard_idx);
push_request->set_cmd_id(PS_PUSH_SPARSE_PARAM); push_request->set_cmd_id(PS_PUSH_SPARSE_PARAM);
...@@ -638,13 +638,14 @@ std::future<int32_t> BrpcPsClient::push_sparse_param( ...@@ -638,13 +638,14 @@ std::future<int32_t> BrpcPsClient::push_sparse_param(
push_request->set_client_id(_client_id); push_request->set_client_id(_client_id);
push_request->add_params((char *)&kv_size, sizeof(uint32_t)); // NOLINT push_request->add_params((char *)&kv_size, sizeof(uint32_t)); // NOLINT
auto *push_data = push_request->mutable_data(); auto *push_data = push_request->mutable_data();
push_data->resize(kv_size * (sizeof(uint64_t) + accessor->update_size())); push_data->resize(kv_size *
(sizeof(uint64_t) + accessor->GetTableInfo(UPDATE_SIZE)));
char *push_data_ptr = const_cast<char *>(push_data->data()); char *push_data_ptr = const_cast<char *>(push_data->data());
memcpy(push_data_ptr, kvs.data(), kv_size * sizeof(uint64_t)); memcpy(push_data_ptr, kvs.data(), kv_size * sizeof(uint64_t));
push_data_ptr += kv_size * sizeof(uint64_t); push_data_ptr += kv_size * sizeof(uint64_t);
for (int i = 0; i < kv_size; ++i) { for (int i = 0; i < kv_size; ++i) {
memcpy(push_data_ptr, value_ptr[i], accessor->update_size()); memcpy(push_data_ptr, value_ptr[i], accessor->GetTableInfo(UPDATE_SIZE));
push_data_ptr += accessor->update_size(); push_data_ptr += accessor->GetTableInfo(UPDATE_SIZE);
} }
PsService_Stub rpc_stub(get_sparse_channel(shard_idx)); PsService_Stub rpc_stub(get_sparse_channel(shard_idx));
closure->cntl(shard_idx)->set_request_compress_type( closure->cntl(shard_idx)->set_request_compress_type(
...@@ -660,9 +661,11 @@ std::future<int32_t> BrpcPsClient::pull_dense(Region *regions, ...@@ -660,9 +661,11 @@ std::future<int32_t> BrpcPsClient::pull_dense(Region *regions,
size_t table_id) { size_t table_id) {
auto timer = std::make_shared<CostTimer>("pserver_client_pull_dense"); auto timer = std::make_shared<CostTimer>("pserver_client_pull_dense");
auto *accessor = table_accessor(table_id); auto *accessor = table_accessor(table_id);
auto fea_dim = accessor->GetTableInfo(FEA_DIM);
auto select_size = accessor->GetTableInfo(SELECT_SIZE);
size_t request_call_num = _server_channels.size(); size_t request_call_num = _server_channels.size();
uint32_t num_per_shard = uint32_t num_per_shard =
dense_dim_per_shard(accessor->fea_dim(), request_call_num); dense_dim_per_shard(accessor->GetTableInfo(FEA_DIM), request_call_num);
// callback 将各shard结果,顺序填入region // callback 将各shard结果,顺序填入region
DownpourBrpcClosure *closure = new DownpourBrpcClosure( DownpourBrpcClosure *closure = new DownpourBrpcClosure(
request_call_num, [request_call_num, num_per_shard, regions, region_num, request_call_num, [request_call_num, num_per_shard, regions, region_num,
...@@ -671,7 +674,8 @@ std::future<int32_t> BrpcPsClient::pull_dense(Region *regions, ...@@ -671,7 +674,8 @@ std::future<int32_t> BrpcPsClient::pull_dense(Region *regions,
size_t region_idx = 0; // 当前填充的region偏移 size_t region_idx = 0; // 当前填充的region偏移
size_t region_data_idx = 0; // 当前填充的region内data偏移 size_t region_data_idx = 0; // 当前填充的region内data偏移
auto *closure = reinterpret_cast<DownpourBrpcClosure *>(done); auto *closure = reinterpret_cast<DownpourBrpcClosure *>(done);
size_t shard_data_size = num_per_shard * accessor->select_size(); size_t shard_data_size =
num_per_shard * accessor->GetTableInfo(SELECT_SIZE);
for (size_t i = 0; i < request_call_num; ++i) { for (size_t i = 0; i < request_call_num; ++i) {
if (closure->check_response(i, PS_PULL_DENSE_TABLE) != 0) { if (closure->check_response(i, PS_PULL_DENSE_TABLE) != 0) {
ret = -1; ret = -1;
...@@ -739,8 +743,8 @@ std::future<int32_t> BrpcPsClient::push_dense_param(const Region *regions, ...@@ -739,8 +743,8 @@ std::future<int32_t> BrpcPsClient::push_dense_param(const Region *regions,
// 1.拆分Region数据到shard中,后续多shard并行拷贝数据 // 1.拆分Region数据到shard中,后续多shard并行拷贝数据
std::vector<std::vector<Region>> regions_partition(request_call_num); std::vector<std::vector<Region>> regions_partition(request_call_num);
uint32_t num_per_shard = uint32_t num_per_shard =
dense_dim_per_shard(accessor->fea_dim(), request_call_num); dense_dim_per_shard(accessor->GetTableInfo(FEA_DIM), request_call_num);
size_t shard_data_size = num_per_shard * accessor->update_size(); size_t shard_data_size = num_per_shard * accessor->GetTableInfo(UPDATE_SIZE);
size_t current_region_idx = 0; size_t current_region_idx = 0;
size_t current_region_data_idx = 0; size_t current_region_data_idx = 0;
for (size_t i = 0; i < request_call_num; ++i) { for (size_t i = 0; i < request_call_num; ++i) {
...@@ -847,7 +851,7 @@ std::future<int32_t> BrpcPsClient::push_sparse_raw_gradient( ...@@ -847,7 +851,7 @@ std::future<int32_t> BrpcPsClient::push_sparse_raw_gradient(
auto value_ptr = value_ptrs[shard_idx]; auto value_ptr = value_ptrs[shard_idx];
size_t kv_size = kvs.size(); size_t kv_size = kvs.size();
uint32_t value_size = accessor->update_size(); uint32_t value_size = accessor->GetTableInfo(UPDATE_SIZE);
// 发送RPC请求 // 发送RPC请求
auto *push_request = closure->request(shard_idx); auto *push_request = closure->request(shard_idx);
...@@ -856,14 +860,15 @@ std::future<int32_t> BrpcPsClient::push_sparse_raw_gradient( ...@@ -856,14 +860,15 @@ std::future<int32_t> BrpcPsClient::push_sparse_raw_gradient(
push_request->set_client_id(_client_id); push_request->set_client_id(_client_id);
push_request->add_params((char *)&kv_size, sizeof(uint32_t)); // NOLINT push_request->add_params((char *)&kv_size, sizeof(uint32_t)); // NOLINT
auto *push_data = push_request->mutable_data(); auto *push_data = push_request->mutable_data();
push_data->resize(kv_size * (sizeof(uint64_t) + accessor->update_size())); push_data->resize(kv_size *
(sizeof(uint64_t) + accessor->GetTableInfo(UPDATE_SIZE)));
char *push_data_ptr = const_cast<char *>(push_data->data()); char *push_data_ptr = const_cast<char *>(push_data->data());
memcpy(push_data_ptr, kvs.data(), kv_size * sizeof(uint64_t)); memcpy(push_data_ptr, kvs.data(), kv_size * sizeof(uint64_t));
push_data_ptr += kv_size * sizeof(uint64_t); push_data_ptr += kv_size * sizeof(uint64_t);
for (int i = 0; i < kv_size; ++i) { for (int i = 0; i < kv_size; ++i) {
memcpy(push_data_ptr, value_ptr[i], accessor->update_size()); memcpy(push_data_ptr, value_ptr[i], accessor->GetTableInfo(UPDATE_SIZE));
push_data_ptr += accessor->update_size(); push_data_ptr += accessor->GetTableInfo(UPDATE_SIZE);
} }
PsService_Stub rpc_stub(get_sparse_channel(shard_idx)); PsService_Stub rpc_stub(get_sparse_channel(shard_idx));
closure->cntl(shard_idx)->set_request_compress_type( closure->cntl(shard_idx)->set_request_compress_type(
...@@ -884,7 +889,7 @@ std::future<int32_t> BrpcPsClient::push_dense_raw_gradient( ...@@ -884,7 +889,7 @@ std::future<int32_t> BrpcPsClient::push_dense_raw_gradient(
std::future<int> fut = promise->get_future(); std::future<int> fut = promise->get_future();
auto *accessor = table_accessor(table_id); auto *accessor = table_accessor(table_id);
uint32_t num_per_shard = uint32_t num_per_shard =
dense_dim_per_shard(accessor->fea_dim(), request_call_num); dense_dim_per_shard(accessor->GetTableInfo(FEA_DIM), request_call_num);
for (size_t i = 0; i < request_call_num; ++i) { for (size_t i = 0; i < request_call_num; ++i) {
closure->request(i)->set_cmd_id(PS_PUSH_DENSE_TABLE); closure->request(i)->set_cmd_id(PS_PUSH_DENSE_TABLE);
closure->request(i)->set_table_id(table_id); closure->request(i)->set_table_id(table_id);
...@@ -962,7 +967,8 @@ std::future<int32_t> BrpcPsClient::pull_sparse(float **select_values, ...@@ -962,7 +967,8 @@ std::future<int32_t> BrpcPsClient::pull_sparse(float **select_values,
} }
auto *accessor = table_accessor(table_id); auto *accessor = table_accessor(table_id);
size_t value_size = accessor->select_size();
size_t value_size = accessor->GetTableInfo(SELECT_SIZE);
DownpourBrpcClosure *closure = new DownpourBrpcClosure( DownpourBrpcClosure *closure = new DownpourBrpcClosure(
request_call_num, [shard_sorted_kvs, value_size](void *done) { request_call_num, [shard_sorted_kvs, value_size](void *done) {
...@@ -1075,7 +1081,7 @@ std::future<int32_t> BrpcPsClient::pull_sparse_param(float **select_values, ...@@ -1075,7 +1081,7 @@ std::future<int32_t> BrpcPsClient::pull_sparse_param(float **select_values,
} }
auto *accessor = table_accessor(table_id); auto *accessor = table_accessor(table_id);
size_t value_size = accessor->select_size(); size_t value_size = accessor->GetTableInfo(SELECT_SIZE);
DownpourBrpcClosure *closure = new DownpourBrpcClosure( DownpourBrpcClosure *closure = new DownpourBrpcClosure(
request_call_num, [shard_sorted_kvs, value_size](void *done) { request_call_num, [shard_sorted_kvs, value_size](void *done) {
...@@ -1199,7 +1205,7 @@ std::future<int32_t> BrpcPsClient::push_sparse_raw_gradient_partial( ...@@ -1199,7 +1205,7 @@ std::future<int32_t> BrpcPsClient::push_sparse_raw_gradient_partial(
size_t table_id, const uint64_t *keys, const float **update_values, size_t table_id, const uint64_t *keys, const float **update_values,
uint32_t num, void *done, int pserver_idx) { uint32_t num, void *done, int pserver_idx) {
auto *accessor = table_accessor(table_id); auto *accessor = table_accessor(table_id);
size_t value_size = accessor->update_size(); size_t value_size = accessor->GetTableInfo(UPDATE_SIZE);
DownpourBrpcClosure *closure = reinterpret_cast<DownpourBrpcClosure *>(done); DownpourBrpcClosure *closure = reinterpret_cast<DownpourBrpcClosure *>(done);
auto promise = std::make_shared<std::promise<int32_t>>(); auto promise = std::make_shared<std::promise<int32_t>>();
closure->add_promise(promise); closure->add_promise(promise);
...@@ -1359,8 +1365,7 @@ std::future<int32_t> BrpcPsClient::push_sparse(size_t table_id, ...@@ -1359,8 +1365,7 @@ std::future<int32_t> BrpcPsClient::push_sparse(size_t table_id,
shard_kv_data.kv_num = 0; shard_kv_data.kv_num = 0;
continue; continue;
} }
uint32_t value_size = accessor->GetTableInfo(UPDATE_SIZE);
uint32_t value_size = accessor->update_size();
for (size_t kv_idx = 0; kv_idx < sorted_kv_size; ++kv_idx) { for (size_t kv_idx = 0; kv_idx < sorted_kv_size; ++kv_idx) {
shard_kv_data.key_list[kv_idx] = sorted_kv_list[kv_idx].first; shard_kv_data.key_list[kv_idx] = sorted_kv_list[kv_idx].first;
shard_kv_data.value_list[kv_idx].assign( shard_kv_data.value_list[kv_idx].assign(
...@@ -1506,7 +1511,7 @@ void BrpcPsClient::push_sparse_task_consume() { ...@@ -1506,7 +1511,7 @@ void BrpcPsClient::push_sparse_task_consume() {
void sparse_local_merge(ValueAccessor *accessor, float *merge_data, void sparse_local_merge(ValueAccessor *accessor, float *merge_data,
const float *another_data) { const float *another_data) {
size_t col_num = accessor->update_size() / sizeof(float); size_t col_num = accessor->GetTableInfo(UPDATE_SIZE) / sizeof(float);
float *merge_data_shell[col_num]; float *merge_data_shell[col_num];
const float *another_data_shell[col_num]; const float *another_data_shell[col_num];
for (int i = 0; i < col_num; ++i) { for (int i = 0; i < col_num; ++i) {
...@@ -1522,7 +1527,7 @@ int BrpcPsClient::push_sparse_async_shard_merge( ...@@ -1522,7 +1527,7 @@ int BrpcPsClient::push_sparse_async_shard_merge(
ValueAccessor *accessor) { ValueAccessor *accessor) {
size_t merged_kv_count = 0; size_t merged_kv_count = 0;
uint64_t min_key = UINT64_MAX; uint64_t min_key = UINT64_MAX;
uint32_t value_size = accessor->update_size(); uint32_t value_size = accessor->GetTableInfo(UPDATE_SIZE);
thread_local std::vector<std::pair<uint64_t, const float *>> sorted_kv_list; thread_local std::vector<std::pair<uint64_t, const float *>> sorted_kv_list;
sorted_kv_list.clear(); sorted_kv_list.clear();
...@@ -1628,8 +1633,9 @@ int BrpcPsClient::push_sparse_async_shard_push( ...@@ -1628,8 +1633,9 @@ int BrpcPsClient::push_sparse_async_shard_push(
push_request->add_params(reinterpret_cast<char *>(&merged_kv_count), push_request->add_params(reinterpret_cast<char *>(&merged_kv_count),
sizeof(uint32_t)); // NOLINT sizeof(uint32_t)); // NOLINT
auto *push_data = push_request->mutable_data(); auto *push_data = push_request->mutable_data();
int update_size = accessor->GetTableInfo(UPDATE_SIZE);
push_data->resize(merged_kv_count * push_data->resize(merged_kv_count *
(sizeof(uint64_t) + accessor->update_size())); (sizeof(uint64_t) + accessor->GetTableInfo(UPDATE_SIZE)));
char *push_data_ptr = const_cast<char *>(push_data->data()); char *push_data_ptr = const_cast<char *>(push_data->data());
memcpy(push_data_ptr, merged_key_list.data(), memcpy(push_data_ptr, merged_key_list.data(),
merged_kv_count * sizeof(uint64_t)); merged_kv_count * sizeof(uint64_t));
...@@ -1638,8 +1644,8 @@ int BrpcPsClient::push_sparse_async_shard_push( ...@@ -1638,8 +1644,8 @@ int BrpcPsClient::push_sparse_async_shard_push(
const char *task_data_ptr = merged_value_list[i].data(); const char *task_data_ptr = merged_value_list[i].data();
memcpy(push_data_ptr, (float *)(task_data_ptr), // NOLINT memcpy(push_data_ptr, (float *)(task_data_ptr), // NOLINT
accessor->update_size()); accessor->GetTableInfo(UPDATE_SIZE));
push_data_ptr += accessor->update_size(); push_data_ptr += accessor->GetTableInfo(UPDATE_SIZE);
} }
PsService_Stub rpc_stub(get_sparse_channel(shard_idx)); PsService_Stub rpc_stub(get_sparse_channel(shard_idx));
closure->cntl(shard_idx)->set_request_compress_type( closure->cntl(shard_idx)->set_request_compress_type(
...@@ -1654,6 +1660,8 @@ std::future<int32_t> BrpcPsClient::push_dense(const Region *regions, ...@@ -1654,6 +1660,8 @@ std::future<int32_t> BrpcPsClient::push_dense(const Region *regions,
size_t region_num, size_t region_num,
size_t table_id) { size_t table_id) {
auto *accessor = table_accessor(table_id); auto *accessor = table_accessor(table_id);
int fea_dim = accessor->GetTableInfo(FEA_DIM);
int update_dim = accessor->GetTableInfo(UPDATE_DIM);
auto push_timer = std::make_shared<CostTimer>("pserver_client_push_dense"); auto push_timer = std::make_shared<CostTimer>("pserver_client_push_dense");
auto parse_timer = auto parse_timer =
std::make_shared<CostTimer>("pserver_client_push_dense_parse"); std::make_shared<CostTimer>("pserver_client_push_dense_parse");
...@@ -1673,11 +1681,11 @@ std::future<int32_t> BrpcPsClient::push_dense(const Region *regions, ...@@ -1673,11 +1681,11 @@ std::future<int32_t> BrpcPsClient::push_dense(const Region *regions,
size_t request_call_num = _server_channels.size(); size_t request_call_num = _server_channels.size();
uint32_t num_per_shard = uint32_t num_per_shard =
dense_dim_per_shard(accessor->fea_dim(), request_call_num); dense_dim_per_shard(accessor->GetTableInfo(FEA_DIM), request_call_num);
// 将region数据拷贝到转置矩阵中 // 将region数据拷贝到转置矩阵中
async_task->data()->resize(num_per_shard * request_call_num * async_task->data()->resize(num_per_shard * request_call_num *
accessor->update_dim()); accessor->GetTableInfo(UPDATE_DIM));
float *data = async_task->data()->data(); float *data = async_task->data()->data();
size_t data_size = async_task->data()->size(); size_t data_size = async_task->data()->size();
uint32_t pos = 0; uint32_t pos = 0;
...@@ -1806,7 +1814,7 @@ void BrpcPsClient::push_dense_raw_gradient( ...@@ -1806,7 +1814,7 @@ void BrpcPsClient::push_dense_raw_gradient(
auto timer = std::make_shared<CostTimer>("pserver_client_push_dense_rpc"); auto timer = std::make_shared<CostTimer>("pserver_client_push_dense_rpc");
closure->add_timer(timer); closure->add_timer(timer);
uint32_t num_per_shard = uint32_t num_per_shard =
dense_dim_per_shard(accessor->fea_dim(), request_call_num); dense_dim_per_shard(accessor->GetTableInfo(FEA_DIM), request_call_num);
auto send_timer = auto send_timer =
std::make_shared<CostTimer>("pserver_client_push_dense_send"); std::make_shared<CostTimer>("pserver_client_push_dense_send");
for (size_t i = 0; i < request_call_num; ++i) { for (size_t i = 0; i < request_call_num; ++i) {
......
...@@ -207,7 +207,12 @@ int32_t BrpcPsService::pull_dense(Table *table, const PsRequestMessage &request, ...@@ -207,7 +207,12 @@ int32_t BrpcPsService::pull_dense(Table *table, const PsRequestMessage &request,
auto res_data = butil::get_object<std::vector<float>>(); auto res_data = butil::get_object<std::vector<float>>();
res_data->resize(num * table->value_accesor()->select_size() / sizeof(float)); res_data->resize(num * table->value_accesor()->select_size() / sizeof(float));
table->pull_dense(res_data->data(), num); TableContext table_context;
table_context.value_type = Dense;
table_context.pull_context.values = res_data->data();
table_context.num = num;
table->Pull(table_context);
// table->pull_dense(res_data->data(), num);
cntl->response_attachment().append((char *)(res_data->data()), cntl->response_attachment().append((char *)(res_data->data()),
res_data->size() * sizeof(float)); res_data->size() * sizeof(float));
...@@ -264,9 +269,15 @@ int32_t BrpcPsService::push_dense(Table *table, const PsRequestMessage &request, ...@@ -264,9 +269,15 @@ int32_t BrpcPsService::push_dense(Table *table, const PsRequestMessage &request,
|--4B---|----------------| |--4B---|----------------|
*/ */
uint32_t num = *(const uint32_t *)(request.data().data()); uint32_t num = *(const uint32_t *)(request.data().data());
const float *values = TableContext table_context;
table_context.value_type = Dense;
table_context.push_context.values =
(const float *)(request.data().data() + sizeof(uint32_t)); (const float *)(request.data().data() + sizeof(uint32_t));
if (table->push_dense(values, num) != 0) { table_context.num = num;
// const float *values = (const float *)(request.data().data() +
// sizeof(uint32_t));
if (table->Push(table_context) != 0) {
// if (table->push_dense(values, num) != 0) {
set_response_code(response, -1, "push_dense failed"); set_response_code(response, -1, "push_dense failed");
} }
...@@ -388,7 +399,12 @@ int32_t BrpcPsService::pull_sparse(Table *table, ...@@ -388,7 +399,12 @@ int32_t BrpcPsService::pull_sparse(Table *table,
auto res_data = butil::get_object<std::vector<float>>(); auto res_data = butil::get_object<std::vector<float>>();
res_data->resize(num * dim); res_data->resize(num * dim);
table->pull_sparse(res_data->data(), value); TableContext table_context;
table_context.value_type = Sparse;
table_context.pull_context.pull_value = value;
table_context.pull_context.values = res_data->data();
table->Pull(table_context);
// table->pull_sparse(res_data->data(), value);
cntl->response_attachment().append((char *)(res_data->data()), cntl->response_attachment().append((char *)(res_data->data()),
res_data->size() * sizeof(float)); res_data->size() * sizeof(float));
...@@ -421,10 +437,17 @@ int32_t BrpcPsService::push_sparse(Table *table, ...@@ -421,10 +437,17 @@ int32_t BrpcPsService::push_sparse(Table *table,
|---keysData---|---valuesData---| |---keysData---|---valuesData---|
|---8*{num}B---|----------------| |---8*{num}B---|----------------|
*/ */
const uint64_t *keys = (const uint64_t *)push_data.data(); TableContext table_context;
const float *values = table_context.value_type = Sparse;
table_context.push_context.keys = (const uint64_t *)push_data.data();
table_context.push_context.values =
(const float *)(push_data.data() + sizeof(uint64_t) * num); (const float *)(push_data.data() + sizeof(uint64_t) * num);
if (table->push_sparse(keys, values, num) != 0) { table_context.num = num;
// const uint64_t *keys = (const uint64_t *)push_data.data();
// const float *values = (const float *)(push_data.data() + sizeof(uint64_t) *
// num);
if (table->Push(table_context) != 0) {
// if (table->push_sparse(keys, values, num) != 0) {
set_response_code(response, -1, "push_sparse error"); set_response_code(response, -1, "push_sparse error");
} }
return 0; return 0;
......
...@@ -86,9 +86,9 @@ struct RequestContext { ...@@ -86,9 +86,9 @@ struct RequestContext {
TrainingMode training_mode; // 1 for async, 2 for geo, 3 for sync TrainingMode training_mode; // 1 for async, 2 for geo, 3 for sync
TrainingPhase training_phase; // 1 for init, 2 for train TrainingPhase training_phase; // 1 for init, 2 for train
ValueType value_type; // 1 for sparse, 2 for dense ValueType value_type; // 1 for sparse, 2 for dense
void *keys; uint64_t *keys;
void **sparse_values; // for sparse values float **sparse_values; // for sparse values
Region *dense_values; // for dense values Region *dense_values; // for dense values
PushContext push_context; PushContext push_context;
size_t num; size_t num;
bool is_training; bool is_training;
......
...@@ -126,11 +126,13 @@ std::future<int32_t> PsLocalClient::Load(const LoadSaveContext& load_context) { ...@@ -126,11 +126,13 @@ std::future<int32_t> PsLocalClient::Load(const LoadSaveContext& load_context) {
Region* dense_region = reinterpret_cast<Region*>(pull_context.dense_values); Region* dense_region = reinterpret_cast<Region*>(pull_context.dense_values);
pull_dense(dense_region, pull_context.num, pull_context.table); pull_dense(dense_region, pull_context.num, pull_context.table);
} else { // pull sparse } else { // pull sparse
uint64_t* keys = reinterpret_cast<uint64_t*>(pull_context.keys); // uint64_t* keys = reinterpret_cast<uint64_t*>(pull_context.keys);
char** select_values = reinterpret_cast<char**>(pull_context.sparse_values); // char** select_values =
// reinterpret_cast<char**>(pull_context.sparse_values);
size_t table_id = pull_context.table; size_t table_id = pull_context.table;
size_t num = pull_context.num; size_t num = pull_context.num;
pull_sparse_ptr(select_values, table_id, keys, num); pull_sparse_ptr(reinterpret_cast<char**>(pull_context.sparse_values),
table_id, pull_context.keys, num);
} }
} }
......
...@@ -56,6 +56,17 @@ struct AccessorInfo { ...@@ -56,6 +56,17 @@ struct AccessorInfo {
size_t fea_dim; size_t fea_dim;
}; };
enum InfoKey {
DIM = 0,
SIZE = 1,
SELECT_SIZE = 2,
SELECT_DIM = 3,
UPDATE_SIZE = 4,
UPDATE_DIM = 5,
MF_SIZE = 6,
FEA_DIM = 7
};
class ValueAccessor { class ValueAccessor {
public: public:
ValueAccessor() {} ValueAccessor() {}
...@@ -79,7 +90,8 @@ class ValueAccessor { ...@@ -79,7 +90,8 @@ class ValueAccessor {
} }
virtual int initialize() = 0; virtual int initialize() = 0;
virtual void GetTableInfo(AccessorInfo& info) = 0; virtual void SetTableInfo(AccessorInfo& info) = 0;
virtual size_t GetTableInfo(InfoKey key) = 0;
// value维度 // value维度
virtual size_t dim() = 0; virtual size_t dim() = 0;
......
...@@ -138,7 +138,7 @@ int32_t CommonDenseTable::Pull(TableContext& context) { ...@@ -138,7 +138,7 @@ int32_t CommonDenseTable::Pull(TableContext& context) {
int32_t CommonDenseTable::Push(TableContext& context) { int32_t CommonDenseTable::Push(TableContext& context) {
CHECK(context.value_type == Dense); CHECK(context.value_type == Dense);
if (context.pull_context.values != nullptr) { if (context.push_context.values != nullptr) {
const float* values = context.push_context.values; const float* values = context.push_context.values;
return push_dense(values, context.num); return push_dense(values, context.num);
} }
...@@ -220,7 +220,7 @@ int32_t CommonDenseTable::load(const std::string& path, ...@@ -220,7 +220,7 @@ int32_t CommonDenseTable::load(const std::string& path,
} }
size_t dim_num_per_file = _config.accessor().fea_dim() / file_list.size() + 1; size_t dim_num_per_file = _config.accessor().fea_dim() / file_list.size() + 1;
// param_dim_ in last node != _config.accesor().fea_dim() / _shard_num + 1 // param_dim_ in last node != _config.accesor().fea_dim() / _shard_num + 1
size_t dim_num_per_shard = _value_accesor->fea_dim() / _shard_num + 1; size_t dim_num_per_shard = _table_info.fea_dim / _shard_num + 1;
size_t start_dim_idx = dim_num_per_shard * _shard_idx; size_t start_dim_idx = dim_num_per_shard * _shard_idx;
size_t start_file_idx = start_dim_idx / dim_num_per_file; size_t start_file_idx = start_dim_idx / dim_num_per_file;
size_t end_file_idx = (start_dim_idx + param_dim_) / dim_num_per_file; size_t end_file_idx = (start_dim_idx + param_dim_) / dim_num_per_file;
......
...@@ -370,7 +370,7 @@ int32_t CommonSparseTable::Pull(TableContext& context) { ...@@ -370,7 +370,7 @@ int32_t CommonSparseTable::Pull(TableContext& context) {
int32_t CommonSparseTable::Push(TableContext& context) { int32_t CommonSparseTable::Push(TableContext& context) {
CHECK(context.value_type == Sparse); CHECK(context.value_type == Sparse);
if (context.pull_context.values != nullptr) { if (context.push_context.values != nullptr) {
const float* values = context.push_context.values; const float* values = context.push_context.values;
const uint64_t* keys = context.push_context.keys; const uint64_t* keys = context.push_context.keys;
return push_sparse(keys, values, context.num); return push_sparse(keys, values, context.num);
......
...@@ -38,16 +38,39 @@ int CtrCommonAccessor::initialize() { ...@@ -38,16 +38,39 @@ int CtrCommonAccessor::initialize() {
return 0; return 0;
} }
void CtrCommonAccessor::GetTableInfo(AccessorInfo& info) { void CtrCommonAccessor::SetTableInfo(AccessorInfo& info) {
info.dim = dim(); info.dim = dim();
info.size = size(); info.size = size();
info.select_dim = select_dim(); info.select_dim = select_dim();
info.select_size = select_size(); info.select_size = select_size();
info.update_dim = update_dim(); info.update_dim = update_dim();
info.update_size = update_size(); info.update_size = update_size();
info.mf_size = mf_size();
info.fea_dim = fea_dim(); info.fea_dim = fea_dim();
} }
size_t CtrCommonAccessor::GetTableInfo(InfoKey key) {
switch (key) {
case DIM:
return dim();
case SIZE:
return size();
case SELECT_DIM:
return select_dim();
case SELECT_SIZE:
return select_size();
case UPDATE_DIM:
return update_dim();
case UPDATE_SIZE:
return update_size();
case MF_SIZE:
return mf_size();
case FEA_DIM:
return fea_dim();
}
return 0;
}
size_t CtrCommonAccessor::dim() { return common_feature_value.dim(); } size_t CtrCommonAccessor::dim() { return common_feature_value.dim(); }
size_t CtrCommonAccessor::dim_size(size_t dim) { size_t CtrCommonAccessor::dim_size(size_t dim) {
......
...@@ -137,7 +137,8 @@ class CtrCommonAccessor : public ValueAccessor { ...@@ -137,7 +137,8 @@ class CtrCommonAccessor : public ValueAccessor {
virtual int initialize(); virtual int initialize();
virtual ~CtrCommonAccessor() {} virtual ~CtrCommonAccessor() {}
virtual void GetTableInfo(AccessorInfo& info); virtual void SetTableInfo(AccessorInfo& info);
virtual size_t GetTableInfo(InfoKey key);
// value维度 // value维度
virtual size_t dim(); virtual size_t dim();
// value各个维度的size // value各个维度的size
......
...@@ -37,16 +37,39 @@ int DownpourCtrDoubleAccessor::initialize() { ...@@ -37,16 +37,39 @@ int DownpourCtrDoubleAccessor::initialize() {
return 0; return 0;
} }
void DownpourCtrDoubleAccessor::GetTableInfo(AccessorInfo& info) { void DownpourCtrDoubleAccessor::SetTableInfo(AccessorInfo& info) {
info.dim = dim(); info.dim = dim();
info.size = size(); info.size = size();
info.select_dim = select_dim(); info.select_dim = select_dim();
info.select_size = select_size(); info.select_size = select_size();
info.update_dim = update_dim(); info.update_dim = update_dim();
info.update_size = update_size(); info.update_size = update_size();
info.mf_size = mf_size();
info.fea_dim = fea_dim(); info.fea_dim = fea_dim();
} }
size_t DownpourCtrDoubleAccessor::GetTableInfo(InfoKey key) {
switch (key) {
case DIM:
return dim();
case SIZE:
return size();
case SELECT_DIM:
return select_dim();
case SELECT_SIZE:
return select_size();
case UPDATE_DIM:
return update_dim();
case UPDATE_SIZE:
return update_size();
case MF_SIZE:
return mf_size();
case FEA_DIM:
return fea_dim();
}
return 0;
}
size_t DownpourCtrDoubleAccessor::dim() { size_t DownpourCtrDoubleAccessor::dim() {
auto embedx_dim = _config.embedx_dim(); auto embedx_dim = _config.embedx_dim();
return DownpourCtrDoubleFeatureValue::dim(embedx_dim); return DownpourCtrDoubleFeatureValue::dim(embedx_dim);
......
...@@ -168,7 +168,8 @@ class DownpourCtrDoubleAccessor : public ValueAccessor { ...@@ -168,7 +168,8 @@ class DownpourCtrDoubleAccessor : public ValueAccessor {
DownpourCtrDoubleAccessor() {} DownpourCtrDoubleAccessor() {}
virtual ~DownpourCtrDoubleAccessor() {} virtual ~DownpourCtrDoubleAccessor() {}
virtual int initialize(); virtual int initialize();
virtual void GetTableInfo(AccessorInfo& info); virtual void SetTableInfo(AccessorInfo& info);
virtual size_t GetTableInfo(InfoKey key);
// value维度 // value维度
virtual size_t dim(); virtual size_t dim();
// value各个维度的size // value各个维度的size
......
...@@ -24,6 +24,7 @@ namespace paddle { ...@@ -24,6 +24,7 @@ namespace paddle {
namespace distributed { namespace distributed {
struct PullSparseValue { struct PullSparseValue {
PullSparseValue() {}
explicit PullSparseValue(int numel, int dim) explicit PullSparseValue(int numel, int dim)
: numel_(numel), : numel_(numel),
dim_(dim), dim_(dim),
......
...@@ -37,16 +37,39 @@ int DownpourCtrAccessor::initialize() { ...@@ -37,16 +37,39 @@ int DownpourCtrAccessor::initialize() {
return 0; return 0;
} }
void DownpourCtrAccessor::GetTableInfo(AccessorInfo& info) { void DownpourCtrAccessor::SetTableInfo(AccessorInfo& info) {
info.dim = dim(); info.dim = dim();
info.size = size(); info.size = size();
info.select_dim = select_dim(); info.select_dim = select_dim();
info.select_size = select_size(); info.select_size = select_size();
info.update_dim = update_dim(); info.update_dim = update_dim();
info.update_size = update_size(); info.update_size = update_size();
info.mf_size = mf_size();
info.fea_dim = fea_dim(); info.fea_dim = fea_dim();
} }
size_t DownpourCtrAccessor::GetTableInfo(InfoKey key) {
switch (key) {
case DIM:
return dim();
case SIZE:
return size();
case SELECT_DIM:
return select_dim();
case SELECT_SIZE:
return select_size();
case UPDATE_DIM:
return update_dim();
case UPDATE_SIZE:
return update_size();
case MF_SIZE:
return mf_size();
case FEA_DIM:
return fea_dim();
}
return 0;
}
size_t DownpourCtrAccessor::dim() { size_t DownpourCtrAccessor::dim() {
auto embedx_dim = _config.embedx_dim(); auto embedx_dim = _config.embedx_dim();
return DownpourCtrFeatureValue::dim(embedx_dim); return DownpourCtrFeatureValue::dim(embedx_dim);
......
...@@ -160,7 +160,8 @@ class DownpourCtrAccessor : public ValueAccessor { ...@@ -160,7 +160,8 @@ class DownpourCtrAccessor : public ValueAccessor {
virtual ~DownpourCtrAccessor() {} virtual ~DownpourCtrAccessor() {}
virtual int initialize(); virtual int initialize();
virtual void GetTableInfo(AccessorInfo& info); virtual void SetTableInfo(AccessorInfo& info);
virtual size_t GetTableInfo(InfoKey key);
// value维度 // value维度
virtual size_t dim(); virtual size_t dim();
// value各个维度的size // value各个维度的size
......
...@@ -88,7 +88,8 @@ int32_t MemorySparseTable::load(const std::string& path, ...@@ -88,7 +88,8 @@ int32_t MemorySparseTable::load(const std::string& path,
size_t file_start_idx = _shard_idx * _avg_local_shard_num; size_t file_start_idx = _shard_idx * _avg_local_shard_num;
size_t feature_value_size = _value_accesor->size() / sizeof(float); size_t feature_value_size =
_value_accesor->GetTableInfo(SIZE) / sizeof(float);
int thread_num = _real_local_shard_num < 15 ? _real_local_shard_num : 15; int thread_num = _real_local_shard_num < 15 ? _real_local_shard_num : 15;
omp_set_num_threads(thread_num); omp_set_num_threads(thread_num);
...@@ -173,7 +174,8 @@ int32_t MemorySparseTable::load_local_fs(const std::string& path, ...@@ -173,7 +174,8 @@ int32_t MemorySparseTable::load_local_fs(const std::string& path,
size_t file_start_idx = _shard_idx * _avg_local_shard_num; size_t file_start_idx = _shard_idx * _avg_local_shard_num;
size_t feature_value_size = _value_accesor->size() / sizeof(float); size_t feature_value_size =
_value_accesor->GetTableInfo(SIZE) / sizeof(float);
int thread_num = _real_local_shard_num < 15 ? _real_local_shard_num : 15; int thread_num = _real_local_shard_num < 15 ? _real_local_shard_num : 15;
omp_set_num_threads(thread_num); omp_set_num_threads(thread_num);
...@@ -407,7 +409,7 @@ int32_t MemorySparseTable::Push(TableContext& context) { ...@@ -407,7 +409,7 @@ int32_t MemorySparseTable::Push(TableContext& context) {
CHECK(context.value_type == Sparse); CHECK(context.value_type == Sparse);
const uint64_t* keys = context.push_context.keys; const uint64_t* keys = context.push_context.keys;
return push_sparse(keys, context.push_context.ptr_values, context.num); return push_sparse(keys, context.push_context.values, context.num);
} }
int32_t MemorySparseTable::pull_sparse(float* pull_values, int32_t MemorySparseTable::pull_sparse(float* pull_values,
...@@ -415,9 +417,10 @@ int32_t MemorySparseTable::pull_sparse(float* pull_values, ...@@ -415,9 +417,10 @@ int32_t MemorySparseTable::pull_sparse(float* pull_values,
CostTimer timer("pserver_sparse_select_all"); CostTimer timer("pserver_sparse_select_all");
std::vector<std::future<int>> tasks(_real_local_shard_num); std::vector<std::future<int>> tasks(_real_local_shard_num);
const size_t value_size = _value_accesor->size() / sizeof(float); const size_t value_size = _value_accesor->GetTableInfo(SIZE) / sizeof(float);
size_t mf_value_size = _value_accesor->mf_size() / sizeof(float); size_t mf_value_size = _value_accesor->GetTableInfo(MF_SIZE) / sizeof(float);
size_t select_value_size = _value_accesor->select_size() / sizeof(float); size_t select_value_size =
_value_accesor->GetTableInfo(SELECT_SIZE) / sizeof(float);
// std::atomic<uint32_t> missed_keys{0}; // std::atomic<uint32_t> missed_keys{0};
std::vector<std::vector<std::pair<uint64_t, int>>> task_keys( std::vector<std::vector<std::pair<uint64_t, int>>> task_keys(
...@@ -475,7 +478,6 @@ int32_t MemorySparseTable::pull_sparse(float* pull_values, ...@@ -475,7 +478,6 @@ int32_t MemorySparseTable::pull_sparse(float* pull_values,
for (size_t shard_id = 0; shard_id < tasks.size(); ++shard_id) { for (size_t shard_id = 0; shard_id < tasks.size(); ++shard_id) {
tasks[shard_id].wait(); tasks[shard_id].wait();
} }
return 0; return 0;
} }
...@@ -541,9 +543,10 @@ int32_t MemorySparseTable::push_sparse(const uint64_t* keys, ...@@ -541,9 +543,10 @@ int32_t MemorySparseTable::push_sparse(const uint64_t* keys,
task_keys[shard_id].push_back({keys[i], i}); task_keys[shard_id].push_back({keys[i], i});
} }
const size_t value_col = _value_accesor->size() / sizeof(float); const size_t value_col = _value_accesor->GetTableInfo(SIZE) / sizeof(float);
size_t mf_value_col = _value_accesor->mf_size() / sizeof(float); size_t mf_value_col = _value_accesor->GetTableInfo(MF_SIZE) / sizeof(float);
size_t update_value_col = _value_accesor->update_size() / sizeof(float); size_t update_value_col =
_value_accesor->GetTableInfo(UPDATE_SIZE) / sizeof(float);
for (size_t shard_id = 0; shard_id < _real_local_shard_num; ++shard_id) { for (size_t shard_id = 0; shard_id < _real_local_shard_num; ++shard_id) {
tasks[shard_id] = _shards_task_pool[shard_id % _task_pool_size]->enqueue( tasks[shard_id] = _shards_task_pool[shard_id % _task_pool_size]->enqueue(
...@@ -618,9 +621,10 @@ int32_t MemorySparseTable::_push_sparse(const uint64_t* keys, ...@@ -618,9 +621,10 @@ int32_t MemorySparseTable::_push_sparse(const uint64_t* keys,
task_keys[shard_id].push_back({keys[i], i}); task_keys[shard_id].push_back({keys[i], i});
} }
size_t value_col = _value_accesor->size() / sizeof(float); size_t value_col = _value_accesor->GetTableInfo(SIZE) / sizeof(float);
size_t mf_value_col = _value_accesor->mf_size() / sizeof(float); size_t mf_value_col = _value_accesor->GetTableInfo(MF_SIZE) / sizeof(float);
size_t update_value_col = _value_accesor->update_size() / sizeof(float); size_t update_value_col =
_value_accesor->GetTableInfo(UPDATE_SIZE) / sizeof(float);
for (int shard_id = 0; shard_id < _real_local_shard_num; ++shard_id) { for (int shard_id = 0; shard_id < _real_local_shard_num; ++shard_id) {
tasks[shard_id] = _shards_task_pool[shard_id % _task_pool_size]->enqueue( tasks[shard_id] = _shards_task_pool[shard_id % _task_pool_size]->enqueue(
......
...@@ -38,16 +38,39 @@ int SparseAccessor::initialize() { ...@@ -38,16 +38,39 @@ int SparseAccessor::initialize() {
return 0; return 0;
} }
void SparseAccessor::GetTableInfo(AccessorInfo& info) { void SparseAccessor::SetTableInfo(AccessorInfo& info) {
info.dim = dim(); info.dim = dim();
info.size = size(); info.size = size();
info.select_dim = select_dim(); info.select_dim = select_dim();
info.select_size = select_size(); info.select_size = select_size();
info.update_dim = update_dim(); info.update_dim = update_dim();
info.update_size = update_size(); info.update_size = update_size();
info.mf_size = mf_size();
info.fea_dim = fea_dim(); info.fea_dim = fea_dim();
} }
size_t SparseAccessor::GetTableInfo(InfoKey key) {
switch (key) {
case DIM:
return dim();
case SIZE:
return size();
case SELECT_DIM:
return select_dim();
case SELECT_SIZE:
return select_size();
case UPDATE_DIM:
return update_dim();
case UPDATE_SIZE:
return update_size();
case MF_SIZE:
return mf_size();
case FEA_DIM:
return fea_dim();
}
return 0;
}
size_t SparseAccessor::dim() { return sparse_feature_value.dim(); } size_t SparseAccessor::dim() { return sparse_feature_value.dim(); }
size_t SparseAccessor::dim_size(size_t dim) { size_t SparseAccessor::dim_size(size_t dim) {
......
...@@ -123,7 +123,8 @@ class SparseAccessor : public ValueAccessor { ...@@ -123,7 +123,8 @@ class SparseAccessor : public ValueAccessor {
}; };
SparseAccessor() {} SparseAccessor() {}
virtual int initialize(); virtual int initialize();
virtual void GetTableInfo(AccessorInfo& info); virtual void SetTableInfo(AccessorInfo& info);
virtual size_t GetTableInfo(InfoKey key);
virtual ~SparseAccessor() {} virtual ~SparseAccessor() {}
// value维度 // value维度
......
...@@ -103,6 +103,7 @@ int32_t Table::initialize_accessor() { ...@@ -103,6 +103,7 @@ int32_t Table::initialize_accessor() {
return -1; return -1;
} }
_value_accesor.reset(accessor); _value_accesor.reset(accessor);
// _value_accesor->SetTableInfo(_table_info);
return 0; return 0;
} }
......
...@@ -37,7 +37,7 @@ enum ValueType { Sparse = 0, Dense = 1 }; ...@@ -37,7 +37,7 @@ enum ValueType { Sparse = 0, Dense = 1 };
struct PullContext { struct PullContext {
const uint64_t *keys; const uint64_t *keys;
const PullSparseValue pull_value; PullSparseValue pull_value;
float *values; float *values;
char **ptr_values; char **ptr_values;
}; };
...@@ -53,7 +53,7 @@ struct TableContext { ...@@ -53,7 +53,7 @@ struct TableContext {
PullContext pull_context; PullContext pull_context;
TablePushContext push_context; TablePushContext push_context;
size_t num; size_t num;
bool use_ptr; bool use_ptr = false;
}; };
class Table { class Table {
...@@ -164,6 +164,7 @@ class Table { ...@@ -164,6 +164,7 @@ class Table {
TableParameter _config; TableParameter _config;
float *_global_lr = nullptr; float *_global_lr = nullptr;
std::shared_ptr<ValueAccessor> _value_accesor; std::shared_ptr<ValueAccessor> _value_accesor;
AccessorInfo _table_info;
AfsClient _afs_client; AfsClient _afs_client;
}; };
REGISTER_PSCORE_REGISTERER(Table); REGISTER_PSCORE_REGISTERER(Table);
......
...@@ -20,16 +20,39 @@ namespace distributed { ...@@ -20,16 +20,39 @@ namespace distributed {
int CommMergeAccessor::initialize() { return 0; } int CommMergeAccessor::initialize() { return 0; }
void CommMergeAccessor::GetTableInfo(AccessorInfo &info) { void CommMergeAccessor::SetTableInfo(AccessorInfo &info) {
info.dim = dim(); info.dim = dim();
info.size = size(); info.size = size();
info.select_dim = select_dim(); info.select_dim = select_dim();
info.select_size = select_size(); info.select_size = select_size();
info.update_dim = update_dim(); info.update_dim = update_dim();
info.update_size = update_size(); info.update_size = update_size();
info.mf_size = mf_size();
info.fea_dim = fea_dim(); info.fea_dim = fea_dim();
} }
size_t CommMergeAccessor::GetTableInfo(InfoKey key) {
switch (key) {
case DIM:
return dim();
case SIZE:
return size();
case SELECT_DIM:
return select_dim();
case SELECT_SIZE:
return select_size();
case UPDATE_DIM:
return update_dim();
case UPDATE_SIZE:
return update_size();
case MF_SIZE:
return mf_size();
case FEA_DIM:
return fea_dim();
}
return 0;
}
// value 维度 // value 维度
size_t CommMergeAccessor::dim() { return 0; } size_t CommMergeAccessor::dim() { return 0; }
......
...@@ -30,7 +30,8 @@ class CommMergeAccessor : public ValueAccessor { ...@@ -30,7 +30,8 @@ class CommMergeAccessor : public ValueAccessor {
CommMergeAccessor() {} CommMergeAccessor() {}
virtual ~CommMergeAccessor() {} virtual ~CommMergeAccessor() {}
virtual int initialize(); virtual int initialize();
virtual void GetTableInfo(AccessorInfo &info); virtual void SetTableInfo(AccessorInfo &info);
virtual size_t GetTableInfo(InfoKey key);
// value维度 // value维度
virtual size_t dim(); virtual size_t dim();
// value各个维度的size // value各个维度的size
......
...@@ -337,9 +337,21 @@ void FleetWrapper::PullSparseToTensorSync(const uint64_t table_id, int fea_dim, ...@@ -337,9 +337,21 @@ void FleetWrapper::PullSparseToTensorSync(const uint64_t table_id, int fea_dim,
pull_result_ptr.push_back(output_data + output_len); pull_result_ptr.push_back(output_data + output_len);
} }
} }
auto status = // ps client pull sparse
worker_ptr_->pull_sparse(pull_result_ptr.data(), table_id, // construct client request context
fea_keys.data(), fea_keys.size(), is_training); RequestContext req_context;
req_context.value_type = Sparse;
req_context.training_mode = Async;
req_context.table = table_id;
req_context.sparse_values = pull_result_ptr.data();
req_context.keys = fea_keys.data();
req_context.num = fea_keys.size();
req_context.is_training = is_training;
auto status = worker_ptr_->Pull(req_context);
// auto status =
// worker_ptr_->pull_sparse(pull_result_ptr.data(), table_id,
// fea_keys.data(), fea_keys.size(),
// is_training);
status.wait(); status.wait();
auto ret = status.get(); auto ret = status.get();
if (ret != 0) { if (ret != 0) {
...@@ -366,7 +378,14 @@ void FleetWrapper::PullDenseVarsAsync( ...@@ -366,7 +378,14 @@ void FleetWrapper::PullDenseVarsAsync(
paddle::distributed::Region reg(w, tensor->numel()); paddle::distributed::Region reg(w, tensor->numel());
regions[i] = std::move(reg); regions[i] = std::move(reg);
} }
auto status = worker_ptr_->pull_dense(regions.data(), regions.size(), tid); RequestContext req_context;
req_context.value_type = Dense;
req_context.training_mode = Async;
req_context.table = tid;
req_context.dense_values = regions.data();
req_context.num = regions.size();
auto status = worker_ptr_->Pull(req_context);
// auto status = worker_ptr_->pull_dense(regions.data(), regions.size(), tid);
pull_dense_status->push_back(std::move(status)); pull_dense_status->push_back(std::move(status));
} }
...@@ -451,8 +470,15 @@ void FleetWrapper::PushDenseVarsAsync( ...@@ -451,8 +470,15 @@ void FleetWrapper::PushDenseVarsAsync(
<< g[tensor->numel() - 1]; << g[tensor->numel() - 1];
} }
auto push_status = RequestContext req_context;
worker_ptr_->push_dense(regions.data(), regions.size(), table_id); req_context.value_type = Dense;
req_context.training_mode = Async;
req_context.table = table_id;
req_context.push_context.push_dense_values = regions.data();
req_context.num = regions.size();
// auto push_status =
// worker_ptr_->push_dense(regions.data(), regions.size(), table_id);
auto push_status = worker_ptr_->Push(req_context);
} }
void FleetWrapper::PushSparseVarsAsync( void FleetWrapper::PushSparseVarsAsync(
...@@ -624,9 +650,19 @@ void FleetWrapper::PushSparseFromTensorAsync( ...@@ -624,9 +650,19 @@ void FleetWrapper::PushSparseFromTensorAsync(
push_g_vec[i] = push_values.at(i).data(); push_g_vec[i] = push_values.at(i).data();
} }
auto status = worker_ptr_->push_sparse(table_id, push_keys.data(), // ps client push sparse
(const float**)push_g_vec.data(), // construct request context
push_keys.size()); RequestContext req_context;
req_context.value_type = Sparse;
req_context.training_mode = Async;
req_context.table = table_id;
req_context.push_context.push_values = (const float**)push_g_vec.data();
req_context.push_context.keys = push_keys.data();
req_context.num = push_keys.size();
auto status = worker_ptr_->Push(req_context);
// auto status = worker_ptr_->push_sparse(table_id, push_keys.data(),
// (const float**)push_g_vec.data(),
// push_keys.size());
} }
void FleetWrapper::LoadModel(const std::string& path, const int mode) { void FleetWrapper::LoadModel(const std::string& path, const int mode) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册