diff --git a/paddle/fluid/distributed/ps/service/brpc_ps_client.cc b/paddle/fluid/distributed/ps/service/brpc_ps_client.cc index f4eb6c222466a3e190704f4d17e9fc6d4e33f125..9674717ffc24bcf34f7c6da92f1f097cd1a8823e 100644 --- a/paddle/fluid/distributed/ps/service/brpc_ps_client.cc +++ b/paddle/fluid/distributed/ps/service/brpc_ps_client.cc @@ -532,18 +532,17 @@ std::future BrpcPsClient::Pull(RequestContext &pull_context) { if (pull_context.value_type == Dense) { // pull dense Region *dense_region = reinterpret_cast(pull_context.dense_values); - pull_dense(dense_region, pull_context.num, pull_context.table); + return pull_dense(dense_region, pull_context.num, pull_context.table); } else { // pull sparse - uint64_t *keys = reinterpret_cast(pull_context.keys); - float **select_values = - reinterpret_cast(pull_context.sparse_values); size_t table_id = pull_context.table; size_t num = pull_context.num; bool is_training = pull_context.is_training; if (pull_context.training_mode == Geo) { // for geo - pull_sparse_param(select_values, table_id, keys, num, is_training); + return pull_sparse_param(pull_context.sparse_values, table_id, + pull_context.keys, num, is_training); } else if (pull_context.training_mode == Async) { // for async - pull_sparse(select_values, table_id, keys, num, is_training); + return pull_sparse(pull_context.sparse_values, table_id, + pull_context.keys, num, is_training); } } } @@ -551,7 +550,7 @@ std::future BrpcPsClient::Pull(RequestContext &pull_context) { std::future BrpcPsClient::Push(RequestContext &push_context) { if (push_context.value_type == Dense) { // push dense const Region *dense_region = push_context.push_context.push_dense_values; - push_dense(dense_region, push_context.num, push_context.table); + return push_dense(dense_region, push_context.num, push_context.table); } else { // push sparse size_t table_id = push_context.table; size_t num = push_context.num; @@ -561,7 +560,7 @@ std::future BrpcPsClient::Push(RequestContext &push_context) { } else if (push_context.training_mode == Async) { // for async const uint64_t *keys = push_context.push_context.keys; const float **update_values = push_context.push_context.push_values; - push_sparse(table_id, keys, update_values, num); + return push_sparse(table_id, keys, update_values, num); } } } @@ -584,11 +583,12 @@ std::future BrpcPsClient::pull_geo_param(size_t table_id, io_buffer_itr.copy_and_forward(reinterpret_cast(&shard_nums), sizeof(uint32_t)); keys->resize(shard_nums); - values->resize(shard_nums * accessor->update_dim()); + values->resize(shard_nums * accessor->GetTableInfo(UPDATE_DIM)); io_buffer_itr.copy_and_forward((void *)(keys->data()), // NOLINT sizeof(uint64_t) * shard_nums); - io_buffer_itr.copy_and_forward((void *)(values->data()), // NOLINT - shard_nums * accessor->update_size()); + io_buffer_itr.copy_and_forward( + (void *)(values->data()), // NOLINT + shard_nums * accessor->GetTableInfo(UPDATE_SIZE)); closure->set_promise_value(ret); }); auto promise = std::make_shared>(); @@ -630,7 +630,7 @@ std::future BrpcPsClient::push_sparse_param( auto kvs = ids[shard_idx]; auto value_ptr = value_ptrs[shard_idx]; size_t kv_size = kvs.size(); - uint32_t value_size = accessor->update_size(); + uint32_t value_size = accessor->GetTableInfo(UPDATE_SIZE); // 发送RPC请求 auto *push_request = closure->request(shard_idx); push_request->set_cmd_id(PS_PUSH_SPARSE_PARAM); @@ -638,13 +638,14 @@ std::future BrpcPsClient::push_sparse_param( push_request->set_client_id(_client_id); push_request->add_params((char *)&kv_size, sizeof(uint32_t)); // NOLINT auto *push_data = push_request->mutable_data(); - push_data->resize(kv_size * (sizeof(uint64_t) + accessor->update_size())); + push_data->resize(kv_size * + (sizeof(uint64_t) + accessor->GetTableInfo(UPDATE_SIZE))); char *push_data_ptr = const_cast(push_data->data()); memcpy(push_data_ptr, kvs.data(), kv_size * sizeof(uint64_t)); push_data_ptr += kv_size * sizeof(uint64_t); for (int i = 0; i < kv_size; ++i) { - memcpy(push_data_ptr, value_ptr[i], accessor->update_size()); - push_data_ptr += accessor->update_size(); + memcpy(push_data_ptr, value_ptr[i], accessor->GetTableInfo(UPDATE_SIZE)); + push_data_ptr += accessor->GetTableInfo(UPDATE_SIZE); } PsService_Stub rpc_stub(get_sparse_channel(shard_idx)); closure->cntl(shard_idx)->set_request_compress_type( @@ -660,9 +661,11 @@ std::future BrpcPsClient::pull_dense(Region *regions, size_t table_id) { auto timer = std::make_shared("pserver_client_pull_dense"); auto *accessor = table_accessor(table_id); + auto fea_dim = accessor->GetTableInfo(FEA_DIM); + auto select_size = accessor->GetTableInfo(SELECT_SIZE); size_t request_call_num = _server_channels.size(); uint32_t num_per_shard = - dense_dim_per_shard(accessor->fea_dim(), request_call_num); + dense_dim_per_shard(accessor->GetTableInfo(FEA_DIM), request_call_num); // callback 将各shard结果,顺序填入region DownpourBrpcClosure *closure = new DownpourBrpcClosure( request_call_num, [request_call_num, num_per_shard, regions, region_num, @@ -671,7 +674,8 @@ std::future BrpcPsClient::pull_dense(Region *regions, size_t region_idx = 0; // 当前填充的region偏移 size_t region_data_idx = 0; // 当前填充的region内data偏移 auto *closure = reinterpret_cast(done); - size_t shard_data_size = num_per_shard * accessor->select_size(); + size_t shard_data_size = + num_per_shard * accessor->GetTableInfo(SELECT_SIZE); for (size_t i = 0; i < request_call_num; ++i) { if (closure->check_response(i, PS_PULL_DENSE_TABLE) != 0) { ret = -1; @@ -739,8 +743,8 @@ std::future BrpcPsClient::push_dense_param(const Region *regions, // 1.拆分Region数据到shard中,后续多shard并行拷贝数据 std::vector> regions_partition(request_call_num); uint32_t num_per_shard = - dense_dim_per_shard(accessor->fea_dim(), request_call_num); - size_t shard_data_size = num_per_shard * accessor->update_size(); + dense_dim_per_shard(accessor->GetTableInfo(FEA_DIM), request_call_num); + size_t shard_data_size = num_per_shard * accessor->GetTableInfo(UPDATE_SIZE); size_t current_region_idx = 0; size_t current_region_data_idx = 0; for (size_t i = 0; i < request_call_num; ++i) { @@ -847,7 +851,7 @@ std::future BrpcPsClient::push_sparse_raw_gradient( auto value_ptr = value_ptrs[shard_idx]; size_t kv_size = kvs.size(); - uint32_t value_size = accessor->update_size(); + uint32_t value_size = accessor->GetTableInfo(UPDATE_SIZE); // 发送RPC请求 auto *push_request = closure->request(shard_idx); @@ -856,14 +860,15 @@ std::future BrpcPsClient::push_sparse_raw_gradient( push_request->set_client_id(_client_id); push_request->add_params((char *)&kv_size, sizeof(uint32_t)); // NOLINT auto *push_data = push_request->mutable_data(); - push_data->resize(kv_size * (sizeof(uint64_t) + accessor->update_size())); + push_data->resize(kv_size * + (sizeof(uint64_t) + accessor->GetTableInfo(UPDATE_SIZE))); char *push_data_ptr = const_cast(push_data->data()); memcpy(push_data_ptr, kvs.data(), kv_size * sizeof(uint64_t)); push_data_ptr += kv_size * sizeof(uint64_t); for (int i = 0; i < kv_size; ++i) { - memcpy(push_data_ptr, value_ptr[i], accessor->update_size()); - push_data_ptr += accessor->update_size(); + memcpy(push_data_ptr, value_ptr[i], accessor->GetTableInfo(UPDATE_SIZE)); + push_data_ptr += accessor->GetTableInfo(UPDATE_SIZE); } PsService_Stub rpc_stub(get_sparse_channel(shard_idx)); closure->cntl(shard_idx)->set_request_compress_type( @@ -884,7 +889,7 @@ std::future BrpcPsClient::push_dense_raw_gradient( std::future fut = promise->get_future(); auto *accessor = table_accessor(table_id); uint32_t num_per_shard = - dense_dim_per_shard(accessor->fea_dim(), request_call_num); + dense_dim_per_shard(accessor->GetTableInfo(FEA_DIM), request_call_num); for (size_t i = 0; i < request_call_num; ++i) { closure->request(i)->set_cmd_id(PS_PUSH_DENSE_TABLE); closure->request(i)->set_table_id(table_id); @@ -962,7 +967,8 @@ std::future BrpcPsClient::pull_sparse(float **select_values, } auto *accessor = table_accessor(table_id); - size_t value_size = accessor->select_size(); + + size_t value_size = accessor->GetTableInfo(SELECT_SIZE); DownpourBrpcClosure *closure = new DownpourBrpcClosure( request_call_num, [shard_sorted_kvs, value_size](void *done) { @@ -1075,7 +1081,7 @@ std::future BrpcPsClient::pull_sparse_param(float **select_values, } auto *accessor = table_accessor(table_id); - size_t value_size = accessor->select_size(); + size_t value_size = accessor->GetTableInfo(SELECT_SIZE); DownpourBrpcClosure *closure = new DownpourBrpcClosure( request_call_num, [shard_sorted_kvs, value_size](void *done) { @@ -1199,7 +1205,7 @@ std::future BrpcPsClient::push_sparse_raw_gradient_partial( size_t table_id, const uint64_t *keys, const float **update_values, uint32_t num, void *done, int pserver_idx) { auto *accessor = table_accessor(table_id); - size_t value_size = accessor->update_size(); + size_t value_size = accessor->GetTableInfo(UPDATE_SIZE); DownpourBrpcClosure *closure = reinterpret_cast(done); auto promise = std::make_shared>(); closure->add_promise(promise); @@ -1359,8 +1365,7 @@ std::future BrpcPsClient::push_sparse(size_t table_id, shard_kv_data.kv_num = 0; continue; } - - uint32_t value_size = accessor->update_size(); + uint32_t value_size = accessor->GetTableInfo(UPDATE_SIZE); for (size_t kv_idx = 0; kv_idx < sorted_kv_size; ++kv_idx) { shard_kv_data.key_list[kv_idx] = sorted_kv_list[kv_idx].first; shard_kv_data.value_list[kv_idx].assign( @@ -1506,7 +1511,7 @@ void BrpcPsClient::push_sparse_task_consume() { void sparse_local_merge(ValueAccessor *accessor, float *merge_data, const float *another_data) { - size_t col_num = accessor->update_size() / sizeof(float); + size_t col_num = accessor->GetTableInfo(UPDATE_SIZE) / sizeof(float); float *merge_data_shell[col_num]; const float *another_data_shell[col_num]; for (int i = 0; i < col_num; ++i) { @@ -1522,7 +1527,7 @@ int BrpcPsClient::push_sparse_async_shard_merge( ValueAccessor *accessor) { size_t merged_kv_count = 0; uint64_t min_key = UINT64_MAX; - uint32_t value_size = accessor->update_size(); + uint32_t value_size = accessor->GetTableInfo(UPDATE_SIZE); thread_local std::vector> sorted_kv_list; sorted_kv_list.clear(); @@ -1628,8 +1633,9 @@ int BrpcPsClient::push_sparse_async_shard_push( push_request->add_params(reinterpret_cast(&merged_kv_count), sizeof(uint32_t)); // NOLINT auto *push_data = push_request->mutable_data(); + int update_size = accessor->GetTableInfo(UPDATE_SIZE); push_data->resize(merged_kv_count * - (sizeof(uint64_t) + accessor->update_size())); + (sizeof(uint64_t) + accessor->GetTableInfo(UPDATE_SIZE))); char *push_data_ptr = const_cast(push_data->data()); memcpy(push_data_ptr, merged_key_list.data(), merged_kv_count * sizeof(uint64_t)); @@ -1638,8 +1644,8 @@ int BrpcPsClient::push_sparse_async_shard_push( const char *task_data_ptr = merged_value_list[i].data(); memcpy(push_data_ptr, (float *)(task_data_ptr), // NOLINT - accessor->update_size()); - push_data_ptr += accessor->update_size(); + accessor->GetTableInfo(UPDATE_SIZE)); + push_data_ptr += accessor->GetTableInfo(UPDATE_SIZE); } PsService_Stub rpc_stub(get_sparse_channel(shard_idx)); closure->cntl(shard_idx)->set_request_compress_type( @@ -1654,6 +1660,8 @@ std::future BrpcPsClient::push_dense(const Region *regions, size_t region_num, size_t table_id) { auto *accessor = table_accessor(table_id); + int fea_dim = accessor->GetTableInfo(FEA_DIM); + int update_dim = accessor->GetTableInfo(UPDATE_DIM); auto push_timer = std::make_shared("pserver_client_push_dense"); auto parse_timer = std::make_shared("pserver_client_push_dense_parse"); @@ -1673,11 +1681,11 @@ std::future BrpcPsClient::push_dense(const Region *regions, size_t request_call_num = _server_channels.size(); uint32_t num_per_shard = - dense_dim_per_shard(accessor->fea_dim(), request_call_num); + dense_dim_per_shard(accessor->GetTableInfo(FEA_DIM), request_call_num); // 将region数据拷贝到转置矩阵中 async_task->data()->resize(num_per_shard * request_call_num * - accessor->update_dim()); + accessor->GetTableInfo(UPDATE_DIM)); float *data = async_task->data()->data(); size_t data_size = async_task->data()->size(); uint32_t pos = 0; @@ -1806,7 +1814,7 @@ void BrpcPsClient::push_dense_raw_gradient( auto timer = std::make_shared("pserver_client_push_dense_rpc"); closure->add_timer(timer); uint32_t num_per_shard = - dense_dim_per_shard(accessor->fea_dim(), request_call_num); + dense_dim_per_shard(accessor->GetTableInfo(FEA_DIM), request_call_num); auto send_timer = std::make_shared("pserver_client_push_dense_send"); for (size_t i = 0; i < request_call_num; ++i) { diff --git a/paddle/fluid/distributed/ps/service/brpc_ps_server.cc b/paddle/fluid/distributed/ps/service/brpc_ps_server.cc index 747b0cbb325d0d6c27808a73e8af1386f557fd04..0d7624baec5806e6bf990a382fe9bdfbe2de9690 100644 --- a/paddle/fluid/distributed/ps/service/brpc_ps_server.cc +++ b/paddle/fluid/distributed/ps/service/brpc_ps_server.cc @@ -207,7 +207,12 @@ int32_t BrpcPsService::pull_dense(Table *table, const PsRequestMessage &request, auto res_data = butil::get_object>(); res_data->resize(num * table->value_accesor()->select_size() / sizeof(float)); - table->pull_dense(res_data->data(), num); + TableContext table_context; + table_context.value_type = Dense; + table_context.pull_context.values = res_data->data(); + table_context.num = num; + table->Pull(table_context); + // table->pull_dense(res_data->data(), num); cntl->response_attachment().append((char *)(res_data->data()), res_data->size() * sizeof(float)); @@ -264,9 +269,15 @@ int32_t BrpcPsService::push_dense(Table *table, const PsRequestMessage &request, |--4B---|----------------| */ uint32_t num = *(const uint32_t *)(request.data().data()); - const float *values = + TableContext table_context; + table_context.value_type = Dense; + table_context.push_context.values = (const float *)(request.data().data() + sizeof(uint32_t)); - if (table->push_dense(values, num) != 0) { + table_context.num = num; + // const float *values = (const float *)(request.data().data() + + // sizeof(uint32_t)); + if (table->Push(table_context) != 0) { + // if (table->push_dense(values, num) != 0) { set_response_code(response, -1, "push_dense failed"); } @@ -388,7 +399,12 @@ int32_t BrpcPsService::pull_sparse(Table *table, auto res_data = butil::get_object>(); res_data->resize(num * dim); - table->pull_sparse(res_data->data(), value); + TableContext table_context; + table_context.value_type = Sparse; + table_context.pull_context.pull_value = value; + table_context.pull_context.values = res_data->data(); + table->Pull(table_context); + // table->pull_sparse(res_data->data(), value); cntl->response_attachment().append((char *)(res_data->data()), res_data->size() * sizeof(float)); @@ -421,10 +437,17 @@ int32_t BrpcPsService::push_sparse(Table *table, |---keysData---|---valuesData---| |---8*{num}B---|----------------| */ - const uint64_t *keys = (const uint64_t *)push_data.data(); - const float *values = + TableContext table_context; + table_context.value_type = Sparse; + table_context.push_context.keys = (const uint64_t *)push_data.data(); + table_context.push_context.values = (const float *)(push_data.data() + sizeof(uint64_t) * num); - if (table->push_sparse(keys, values, num) != 0) { + table_context.num = num; + // const uint64_t *keys = (const uint64_t *)push_data.data(); + // const float *values = (const float *)(push_data.data() + sizeof(uint64_t) * + // num); + if (table->Push(table_context) != 0) { + // if (table->push_sparse(keys, values, num) != 0) { set_response_code(response, -1, "push_sparse error"); } return 0; diff --git a/paddle/fluid/distributed/ps/service/ps_client.h b/paddle/fluid/distributed/ps/service/ps_client.h index 8a2bfbe31602be299366fdcbeb264e45a5c4f703..83d2aba1db44564a3314e6d1f9b07ebd2730b85e 100644 --- a/paddle/fluid/distributed/ps/service/ps_client.h +++ b/paddle/fluid/distributed/ps/service/ps_client.h @@ -86,9 +86,9 @@ struct RequestContext { TrainingMode training_mode; // 1 for async, 2 for geo, 3 for sync TrainingPhase training_phase; // 1 for init, 2 for train ValueType value_type; // 1 for sparse, 2 for dense - void *keys; - void **sparse_values; // for sparse values - Region *dense_values; // for dense values + uint64_t *keys; + float **sparse_values; // for sparse values + Region *dense_values; // for dense values PushContext push_context; size_t num; bool is_training; diff --git a/paddle/fluid/distributed/ps/service/ps_local_client.cc b/paddle/fluid/distributed/ps/service/ps_local_client.cc index 55519b4f622defaf9e801144ce154c73f0f318c5..fe5cbe682ea67cffff22786f305cb50182983367 100755 --- a/paddle/fluid/distributed/ps/service/ps_local_client.cc +++ b/paddle/fluid/distributed/ps/service/ps_local_client.cc @@ -126,11 +126,13 @@ std::future PsLocalClient::Load(const LoadSaveContext& load_context) { Region* dense_region = reinterpret_cast(pull_context.dense_values); pull_dense(dense_region, pull_context.num, pull_context.table); } else { // pull sparse - uint64_t* keys = reinterpret_cast(pull_context.keys); - char** select_values = reinterpret_cast(pull_context.sparse_values); + // uint64_t* keys = reinterpret_cast(pull_context.keys); + // char** select_values = + // reinterpret_cast(pull_context.sparse_values); size_t table_id = pull_context.table; size_t num = pull_context.num; - pull_sparse_ptr(select_values, table_id, keys, num); + pull_sparse_ptr(reinterpret_cast(pull_context.sparse_values), + table_id, pull_context.keys, num); } } diff --git a/paddle/fluid/distributed/ps/table/accessor.h b/paddle/fluid/distributed/ps/table/accessor.h index 07c211bb9c12866e3646a0dbdebfba189eb2507e..207cc94b4cb15427a231b1478891dd7185a8514b 100644 --- a/paddle/fluid/distributed/ps/table/accessor.h +++ b/paddle/fluid/distributed/ps/table/accessor.h @@ -56,6 +56,17 @@ struct AccessorInfo { size_t fea_dim; }; +enum InfoKey { + DIM = 0, + SIZE = 1, + SELECT_SIZE = 2, + SELECT_DIM = 3, + UPDATE_SIZE = 4, + UPDATE_DIM = 5, + MF_SIZE = 6, + FEA_DIM = 7 +}; + class ValueAccessor { public: ValueAccessor() {} @@ -79,7 +90,8 @@ class ValueAccessor { } virtual int initialize() = 0; - virtual void GetTableInfo(AccessorInfo& info) = 0; + virtual void SetTableInfo(AccessorInfo& info) = 0; + virtual size_t GetTableInfo(InfoKey key) = 0; // value维度 virtual size_t dim() = 0; diff --git a/paddle/fluid/distributed/ps/table/common_dense_table.cc b/paddle/fluid/distributed/ps/table/common_dense_table.cc index b0394a4dab6dab299606e3f264b104b4af160eef..a462fc50aeb7219256890ab16469e9825701b3ca 100644 --- a/paddle/fluid/distributed/ps/table/common_dense_table.cc +++ b/paddle/fluid/distributed/ps/table/common_dense_table.cc @@ -138,7 +138,7 @@ int32_t CommonDenseTable::Pull(TableContext& context) { int32_t CommonDenseTable::Push(TableContext& context) { CHECK(context.value_type == Dense); - if (context.pull_context.values != nullptr) { + if (context.push_context.values != nullptr) { const float* values = context.push_context.values; return push_dense(values, context.num); } @@ -220,7 +220,7 @@ int32_t CommonDenseTable::load(const std::string& path, } size_t dim_num_per_file = _config.accessor().fea_dim() / file_list.size() + 1; // param_dim_ in last node != _config.accesor().fea_dim() / _shard_num + 1 - size_t dim_num_per_shard = _value_accesor->fea_dim() / _shard_num + 1; + size_t dim_num_per_shard = _table_info.fea_dim / _shard_num + 1; size_t start_dim_idx = dim_num_per_shard * _shard_idx; size_t start_file_idx = start_dim_idx / dim_num_per_file; size_t end_file_idx = (start_dim_idx + param_dim_) / dim_num_per_file; diff --git a/paddle/fluid/distributed/ps/table/common_sparse_table.cc b/paddle/fluid/distributed/ps/table/common_sparse_table.cc index 45be53335e1a181f7c1e2abb7326ac6b9800703f..1fc8adc2b92ebd79544ba518382faa59989337d3 100644 --- a/paddle/fluid/distributed/ps/table/common_sparse_table.cc +++ b/paddle/fluid/distributed/ps/table/common_sparse_table.cc @@ -370,7 +370,7 @@ int32_t CommonSparseTable::Pull(TableContext& context) { int32_t CommonSparseTable::Push(TableContext& context) { CHECK(context.value_type == Sparse); - if (context.pull_context.values != nullptr) { + if (context.push_context.values != nullptr) { const float* values = context.push_context.values; const uint64_t* keys = context.push_context.keys; return push_sparse(keys, values, context.num); diff --git a/paddle/fluid/distributed/ps/table/ctr_accessor.cc b/paddle/fluid/distributed/ps/table/ctr_accessor.cc index 4974f004caa43bb01809dd58b94f1826135e7414..ffb97914fb8c02867a26b12b5ea8c602cf679a89 100644 --- a/paddle/fluid/distributed/ps/table/ctr_accessor.cc +++ b/paddle/fluid/distributed/ps/table/ctr_accessor.cc @@ -38,16 +38,39 @@ int CtrCommonAccessor::initialize() { return 0; } -void CtrCommonAccessor::GetTableInfo(AccessorInfo& info) { +void CtrCommonAccessor::SetTableInfo(AccessorInfo& info) { info.dim = dim(); info.size = size(); info.select_dim = select_dim(); info.select_size = select_size(); info.update_dim = update_dim(); info.update_size = update_size(); + info.mf_size = mf_size(); info.fea_dim = fea_dim(); } +size_t CtrCommonAccessor::GetTableInfo(InfoKey key) { + switch (key) { + case DIM: + return dim(); + case SIZE: + return size(); + case SELECT_DIM: + return select_dim(); + case SELECT_SIZE: + return select_size(); + case UPDATE_DIM: + return update_dim(); + case UPDATE_SIZE: + return update_size(); + case MF_SIZE: + return mf_size(); + case FEA_DIM: + return fea_dim(); + } + return 0; +} + size_t CtrCommonAccessor::dim() { return common_feature_value.dim(); } size_t CtrCommonAccessor::dim_size(size_t dim) { diff --git a/paddle/fluid/distributed/ps/table/ctr_accessor.h b/paddle/fluid/distributed/ps/table/ctr_accessor.h index 6cf18aa5e4632e2c82a03d1c05722f3c7b361414..a2121b21d9fe6ff47f11059436d122fa15d02f65 100644 --- a/paddle/fluid/distributed/ps/table/ctr_accessor.h +++ b/paddle/fluid/distributed/ps/table/ctr_accessor.h @@ -137,7 +137,8 @@ class CtrCommonAccessor : public ValueAccessor { virtual int initialize(); virtual ~CtrCommonAccessor() {} - virtual void GetTableInfo(AccessorInfo& info); + virtual void SetTableInfo(AccessorInfo& info); + virtual size_t GetTableInfo(InfoKey key); // value维度 virtual size_t dim(); // value各个维度的size diff --git a/paddle/fluid/distributed/ps/table/ctr_double_accessor.cc b/paddle/fluid/distributed/ps/table/ctr_double_accessor.cc index bccf1fdebafa03442047048825ef85207711b6b3..0e3df6e82521deec657764df4479b3ea8d028cd9 100644 --- a/paddle/fluid/distributed/ps/table/ctr_double_accessor.cc +++ b/paddle/fluid/distributed/ps/table/ctr_double_accessor.cc @@ -37,16 +37,39 @@ int DownpourCtrDoubleAccessor::initialize() { return 0; } -void DownpourCtrDoubleAccessor::GetTableInfo(AccessorInfo& info) { +void DownpourCtrDoubleAccessor::SetTableInfo(AccessorInfo& info) { info.dim = dim(); info.size = size(); info.select_dim = select_dim(); info.select_size = select_size(); info.update_dim = update_dim(); info.update_size = update_size(); + info.mf_size = mf_size(); info.fea_dim = fea_dim(); } +size_t DownpourCtrDoubleAccessor::GetTableInfo(InfoKey key) { + switch (key) { + case DIM: + return dim(); + case SIZE: + return size(); + case SELECT_DIM: + return select_dim(); + case SELECT_SIZE: + return select_size(); + case UPDATE_DIM: + return update_dim(); + case UPDATE_SIZE: + return update_size(); + case MF_SIZE: + return mf_size(); + case FEA_DIM: + return fea_dim(); + } + return 0; +} + size_t DownpourCtrDoubleAccessor::dim() { auto embedx_dim = _config.embedx_dim(); return DownpourCtrDoubleFeatureValue::dim(embedx_dim); diff --git a/paddle/fluid/distributed/ps/table/ctr_double_accessor.h b/paddle/fluid/distributed/ps/table/ctr_double_accessor.h index d7942634e86003c484710aad1d969e4d6371cb7f..fb8b27ecfd98549c175ac6e2ac2b4f96972a73f5 100644 --- a/paddle/fluid/distributed/ps/table/ctr_double_accessor.h +++ b/paddle/fluid/distributed/ps/table/ctr_double_accessor.h @@ -168,7 +168,8 @@ class DownpourCtrDoubleAccessor : public ValueAccessor { DownpourCtrDoubleAccessor() {} virtual ~DownpourCtrDoubleAccessor() {} virtual int initialize(); - virtual void GetTableInfo(AccessorInfo& info); + virtual void SetTableInfo(AccessorInfo& info); + virtual size_t GetTableInfo(InfoKey key); // value维度 virtual size_t dim(); // value各个维度的size diff --git a/paddle/fluid/distributed/ps/table/depends/sparse_utils.h b/paddle/fluid/distributed/ps/table/depends/sparse_utils.h index 98e0250acc4d686dbde561ffb03edeb96444c406..5aef6b8cfbc74ff01ffffd2bf5a9d20dec0b80d6 100644 --- a/paddle/fluid/distributed/ps/table/depends/sparse_utils.h +++ b/paddle/fluid/distributed/ps/table/depends/sparse_utils.h @@ -24,6 +24,7 @@ namespace paddle { namespace distributed { struct PullSparseValue { + PullSparseValue() {} explicit PullSparseValue(int numel, int dim) : numel_(numel), dim_(dim), diff --git a/paddle/fluid/distributed/ps/table/downpour_ctr_accessor.cc b/paddle/fluid/distributed/ps/table/downpour_ctr_accessor.cc index e8ca7430351de7cbdc1e98607d6d9b884b6a376a..2fff81b1a4dc612a1e84659ee55e262542cd406c 100644 --- a/paddle/fluid/distributed/ps/table/downpour_ctr_accessor.cc +++ b/paddle/fluid/distributed/ps/table/downpour_ctr_accessor.cc @@ -37,16 +37,39 @@ int DownpourCtrAccessor::initialize() { return 0; } -void DownpourCtrAccessor::GetTableInfo(AccessorInfo& info) { +void DownpourCtrAccessor::SetTableInfo(AccessorInfo& info) { info.dim = dim(); info.size = size(); info.select_dim = select_dim(); info.select_size = select_size(); info.update_dim = update_dim(); info.update_size = update_size(); + info.mf_size = mf_size(); info.fea_dim = fea_dim(); } +size_t DownpourCtrAccessor::GetTableInfo(InfoKey key) { + switch (key) { + case DIM: + return dim(); + case SIZE: + return size(); + case SELECT_DIM: + return select_dim(); + case SELECT_SIZE: + return select_size(); + case UPDATE_DIM: + return update_dim(); + case UPDATE_SIZE: + return update_size(); + case MF_SIZE: + return mf_size(); + case FEA_DIM: + return fea_dim(); + } + return 0; +} + size_t DownpourCtrAccessor::dim() { auto embedx_dim = _config.embedx_dim(); return DownpourCtrFeatureValue::dim(embedx_dim); diff --git a/paddle/fluid/distributed/ps/table/downpour_ctr_accessor.h b/paddle/fluid/distributed/ps/table/downpour_ctr_accessor.h index 11991ad044ff63353c9a898469ec915163c2dea9..6ff6c0438310e348ce3edfa2409aed8ddc1c083a 100644 --- a/paddle/fluid/distributed/ps/table/downpour_ctr_accessor.h +++ b/paddle/fluid/distributed/ps/table/downpour_ctr_accessor.h @@ -160,7 +160,8 @@ class DownpourCtrAccessor : public ValueAccessor { virtual ~DownpourCtrAccessor() {} virtual int initialize(); - virtual void GetTableInfo(AccessorInfo& info); + virtual void SetTableInfo(AccessorInfo& info); + virtual size_t GetTableInfo(InfoKey key); // value维度 virtual size_t dim(); // value各个维度的size diff --git a/paddle/fluid/distributed/ps/table/memory_sparse_table.cc b/paddle/fluid/distributed/ps/table/memory_sparse_table.cc index ea61ca444f7fdeb6e454f3f0f48fb856b3d671f1..3f5c484eab82525f25c4ee6b52355f0c5f063d02 100644 --- a/paddle/fluid/distributed/ps/table/memory_sparse_table.cc +++ b/paddle/fluid/distributed/ps/table/memory_sparse_table.cc @@ -88,7 +88,8 @@ int32_t MemorySparseTable::load(const std::string& path, size_t file_start_idx = _shard_idx * _avg_local_shard_num; - size_t feature_value_size = _value_accesor->size() / sizeof(float); + size_t feature_value_size = + _value_accesor->GetTableInfo(SIZE) / sizeof(float); int thread_num = _real_local_shard_num < 15 ? _real_local_shard_num : 15; omp_set_num_threads(thread_num); @@ -173,7 +174,8 @@ int32_t MemorySparseTable::load_local_fs(const std::string& path, size_t file_start_idx = _shard_idx * _avg_local_shard_num; - size_t feature_value_size = _value_accesor->size() / sizeof(float); + size_t feature_value_size = + _value_accesor->GetTableInfo(SIZE) / sizeof(float); int thread_num = _real_local_shard_num < 15 ? _real_local_shard_num : 15; omp_set_num_threads(thread_num); @@ -407,7 +409,7 @@ int32_t MemorySparseTable::Push(TableContext& context) { CHECK(context.value_type == Sparse); const uint64_t* keys = context.push_context.keys; - return push_sparse(keys, context.push_context.ptr_values, context.num); + return push_sparse(keys, context.push_context.values, context.num); } int32_t MemorySparseTable::pull_sparse(float* pull_values, @@ -415,9 +417,10 @@ int32_t MemorySparseTable::pull_sparse(float* pull_values, CostTimer timer("pserver_sparse_select_all"); std::vector> tasks(_real_local_shard_num); - const size_t value_size = _value_accesor->size() / sizeof(float); - size_t mf_value_size = _value_accesor->mf_size() / sizeof(float); - size_t select_value_size = _value_accesor->select_size() / sizeof(float); + const size_t value_size = _value_accesor->GetTableInfo(SIZE) / sizeof(float); + size_t mf_value_size = _value_accesor->GetTableInfo(MF_SIZE) / sizeof(float); + size_t select_value_size = + _value_accesor->GetTableInfo(SELECT_SIZE) / sizeof(float); // std::atomic missed_keys{0}; std::vector>> task_keys( @@ -475,7 +478,6 @@ int32_t MemorySparseTable::pull_sparse(float* pull_values, for (size_t shard_id = 0; shard_id < tasks.size(); ++shard_id) { tasks[shard_id].wait(); } - return 0; } @@ -541,9 +543,10 @@ int32_t MemorySparseTable::push_sparse(const uint64_t* keys, task_keys[shard_id].push_back({keys[i], i}); } - const size_t value_col = _value_accesor->size() / sizeof(float); - size_t mf_value_col = _value_accesor->mf_size() / sizeof(float); - size_t update_value_col = _value_accesor->update_size() / sizeof(float); + const size_t value_col = _value_accesor->GetTableInfo(SIZE) / sizeof(float); + size_t mf_value_col = _value_accesor->GetTableInfo(MF_SIZE) / sizeof(float); + size_t update_value_col = + _value_accesor->GetTableInfo(UPDATE_SIZE) / sizeof(float); for (size_t shard_id = 0; shard_id < _real_local_shard_num; ++shard_id) { tasks[shard_id] = _shards_task_pool[shard_id % _task_pool_size]->enqueue( @@ -618,9 +621,10 @@ int32_t MemorySparseTable::_push_sparse(const uint64_t* keys, task_keys[shard_id].push_back({keys[i], i}); } - size_t value_col = _value_accesor->size() / sizeof(float); - size_t mf_value_col = _value_accesor->mf_size() / sizeof(float); - size_t update_value_col = _value_accesor->update_size() / sizeof(float); + size_t value_col = _value_accesor->GetTableInfo(SIZE) / sizeof(float); + size_t mf_value_col = _value_accesor->GetTableInfo(MF_SIZE) / sizeof(float); + size_t update_value_col = + _value_accesor->GetTableInfo(UPDATE_SIZE) / sizeof(float); for (int shard_id = 0; shard_id < _real_local_shard_num; ++shard_id) { tasks[shard_id] = _shards_task_pool[shard_id % _task_pool_size]->enqueue( diff --git a/paddle/fluid/distributed/ps/table/sparse_accessor.cc b/paddle/fluid/distributed/ps/table/sparse_accessor.cc index e971138c6cbf6b0cb9af891df89935f7b1416d17..651ff9d00e49ac653e606d00e2184deb0f811dc1 100644 --- a/paddle/fluid/distributed/ps/table/sparse_accessor.cc +++ b/paddle/fluid/distributed/ps/table/sparse_accessor.cc @@ -38,16 +38,39 @@ int SparseAccessor::initialize() { return 0; } -void SparseAccessor::GetTableInfo(AccessorInfo& info) { +void SparseAccessor::SetTableInfo(AccessorInfo& info) { info.dim = dim(); info.size = size(); info.select_dim = select_dim(); info.select_size = select_size(); info.update_dim = update_dim(); info.update_size = update_size(); + info.mf_size = mf_size(); info.fea_dim = fea_dim(); } +size_t SparseAccessor::GetTableInfo(InfoKey key) { + switch (key) { + case DIM: + return dim(); + case SIZE: + return size(); + case SELECT_DIM: + return select_dim(); + case SELECT_SIZE: + return select_size(); + case UPDATE_DIM: + return update_dim(); + case UPDATE_SIZE: + return update_size(); + case MF_SIZE: + return mf_size(); + case FEA_DIM: + return fea_dim(); + } + return 0; +} + size_t SparseAccessor::dim() { return sparse_feature_value.dim(); } size_t SparseAccessor::dim_size(size_t dim) { diff --git a/paddle/fluid/distributed/ps/table/sparse_accessor.h b/paddle/fluid/distributed/ps/table/sparse_accessor.h index 368e6bbcd3f5745135de480f71feef1462986826..cdc4c1dc6200e91d4b617814b3dc6595fae8e71f 100644 --- a/paddle/fluid/distributed/ps/table/sparse_accessor.h +++ b/paddle/fluid/distributed/ps/table/sparse_accessor.h @@ -123,7 +123,8 @@ class SparseAccessor : public ValueAccessor { }; SparseAccessor() {} virtual int initialize(); - virtual void GetTableInfo(AccessorInfo& info); + virtual void SetTableInfo(AccessorInfo& info); + virtual size_t GetTableInfo(InfoKey key); virtual ~SparseAccessor() {} // value维度 diff --git a/paddle/fluid/distributed/ps/table/table.cc b/paddle/fluid/distributed/ps/table/table.cc index 54e3576fd4ee0f46f09c026cd6c780d320949b1c..6faa3e2632e28cd5a60784240d97058f186a7003 100644 --- a/paddle/fluid/distributed/ps/table/table.cc +++ b/paddle/fluid/distributed/ps/table/table.cc @@ -103,6 +103,7 @@ int32_t Table::initialize_accessor() { return -1; } _value_accesor.reset(accessor); + // _value_accesor->SetTableInfo(_table_info); return 0; } diff --git a/paddle/fluid/distributed/ps/table/table.h b/paddle/fluid/distributed/ps/table/table.h index 2bd2a42b6c58f0753de86aa4e60ac7e0611bd7f7..bba34d89377a7d4050d0efa43c187bd8314fed39 100644 --- a/paddle/fluid/distributed/ps/table/table.h +++ b/paddle/fluid/distributed/ps/table/table.h @@ -37,7 +37,7 @@ enum ValueType { Sparse = 0, Dense = 1 }; struct PullContext { const uint64_t *keys; - const PullSparseValue pull_value; + PullSparseValue pull_value; float *values; char **ptr_values; }; @@ -53,7 +53,7 @@ struct TableContext { PullContext pull_context; TablePushContext push_context; size_t num; - bool use_ptr; + bool use_ptr = false; }; class Table { @@ -164,6 +164,7 @@ class Table { TableParameter _config; float *_global_lr = nullptr; std::shared_ptr _value_accesor; + AccessorInfo _table_info; AfsClient _afs_client; }; REGISTER_PSCORE_REGISTERER(Table); diff --git a/paddle/fluid/distributed/ps/table/tensor_accessor.cc b/paddle/fluid/distributed/ps/table/tensor_accessor.cc index 8c5349bff832caaa0a1b411723df8b3e9bcdcd4f..77014141783c39ac067e5065c32aeae8853bdd47 100644 --- a/paddle/fluid/distributed/ps/table/tensor_accessor.cc +++ b/paddle/fluid/distributed/ps/table/tensor_accessor.cc @@ -20,16 +20,39 @@ namespace distributed { int CommMergeAccessor::initialize() { return 0; } -void CommMergeAccessor::GetTableInfo(AccessorInfo &info) { +void CommMergeAccessor::SetTableInfo(AccessorInfo &info) { info.dim = dim(); info.size = size(); info.select_dim = select_dim(); info.select_size = select_size(); info.update_dim = update_dim(); info.update_size = update_size(); + info.mf_size = mf_size(); info.fea_dim = fea_dim(); } +size_t CommMergeAccessor::GetTableInfo(InfoKey key) { + switch (key) { + case DIM: + return dim(); + case SIZE: + return size(); + case SELECT_DIM: + return select_dim(); + case SELECT_SIZE: + return select_size(); + case UPDATE_DIM: + return update_dim(); + case UPDATE_SIZE: + return update_size(); + case MF_SIZE: + return mf_size(); + case FEA_DIM: + return fea_dim(); + } + return 0; +} + // value 维度 size_t CommMergeAccessor::dim() { return 0; } diff --git a/paddle/fluid/distributed/ps/table/tensor_accessor.h b/paddle/fluid/distributed/ps/table/tensor_accessor.h index 1873b743b44ec736f0470c3eff1f5b0280c235bf..6f5b69a392bc584701feec68821c87b826aa7db4 100644 --- a/paddle/fluid/distributed/ps/table/tensor_accessor.h +++ b/paddle/fluid/distributed/ps/table/tensor_accessor.h @@ -30,7 +30,8 @@ class CommMergeAccessor : public ValueAccessor { CommMergeAccessor() {} virtual ~CommMergeAccessor() {} virtual int initialize(); - virtual void GetTableInfo(AccessorInfo &info); + virtual void SetTableInfo(AccessorInfo &info); + virtual size_t GetTableInfo(InfoKey key); // value维度 virtual size_t dim(); // value各个维度的size diff --git a/paddle/fluid/distributed/ps/wrapper/fleet.cc b/paddle/fluid/distributed/ps/wrapper/fleet.cc index 8a129c6cc54dc4048abd0f3cf3b4964b5b6b2ac3..c9093368c693e774657e4e1f2b688774df24ebd2 100644 --- a/paddle/fluid/distributed/ps/wrapper/fleet.cc +++ b/paddle/fluid/distributed/ps/wrapper/fleet.cc @@ -337,9 +337,21 @@ void FleetWrapper::PullSparseToTensorSync(const uint64_t table_id, int fea_dim, pull_result_ptr.push_back(output_data + output_len); } } - auto status = - worker_ptr_->pull_sparse(pull_result_ptr.data(), table_id, - fea_keys.data(), fea_keys.size(), is_training); + // ps client pull sparse + // construct client request context + RequestContext req_context; + req_context.value_type = Sparse; + req_context.training_mode = Async; + req_context.table = table_id; + req_context.sparse_values = pull_result_ptr.data(); + req_context.keys = fea_keys.data(); + req_context.num = fea_keys.size(); + req_context.is_training = is_training; + auto status = worker_ptr_->Pull(req_context); + // auto status = + // worker_ptr_->pull_sparse(pull_result_ptr.data(), table_id, + // fea_keys.data(), fea_keys.size(), + // is_training); status.wait(); auto ret = status.get(); if (ret != 0) { @@ -366,7 +378,14 @@ void FleetWrapper::PullDenseVarsAsync( paddle::distributed::Region reg(w, tensor->numel()); regions[i] = std::move(reg); } - auto status = worker_ptr_->pull_dense(regions.data(), regions.size(), tid); + RequestContext req_context; + req_context.value_type = Dense; + req_context.training_mode = Async; + req_context.table = tid; + req_context.dense_values = regions.data(); + req_context.num = regions.size(); + auto status = worker_ptr_->Pull(req_context); + // auto status = worker_ptr_->pull_dense(regions.data(), regions.size(), tid); pull_dense_status->push_back(std::move(status)); } @@ -451,8 +470,15 @@ void FleetWrapper::PushDenseVarsAsync( << g[tensor->numel() - 1]; } - auto push_status = - worker_ptr_->push_dense(regions.data(), regions.size(), table_id); + RequestContext req_context; + req_context.value_type = Dense; + req_context.training_mode = Async; + req_context.table = table_id; + req_context.push_context.push_dense_values = regions.data(); + req_context.num = regions.size(); + // auto push_status = + // worker_ptr_->push_dense(regions.data(), regions.size(), table_id); + auto push_status = worker_ptr_->Push(req_context); } void FleetWrapper::PushSparseVarsAsync( @@ -624,9 +650,19 @@ void FleetWrapper::PushSparseFromTensorAsync( push_g_vec[i] = push_values.at(i).data(); } - auto status = worker_ptr_->push_sparse(table_id, push_keys.data(), - (const float**)push_g_vec.data(), - push_keys.size()); + // ps client push sparse + // construct request context + RequestContext req_context; + req_context.value_type = Sparse; + req_context.training_mode = Async; + req_context.table = table_id; + req_context.push_context.push_values = (const float**)push_g_vec.data(); + req_context.push_context.keys = push_keys.data(); + req_context.num = push_keys.size(); + auto status = worker_ptr_->Push(req_context); + // auto status = worker_ptr_->push_sparse(table_id, push_keys.data(), + // (const float**)push_g_vec.data(), + // push_keys.size()); } void FleetWrapper::LoadModel(const std::string& path, const int mode) {