未验证 提交 2089b485 编写于 作者: Y yaoxuefeng 提交者: GitHub

change to new api in ssync mode (#41022)

* change to new api in ssync mode

* fix

* fix

* fix

* fix
上级 60c4c9cd
......@@ -532,18 +532,17 @@ std::future<int32_t> BrpcPsClient::Pull(RequestContext &pull_context) {
if (pull_context.value_type == Dense) { // pull dense
Region *dense_region =
reinterpret_cast<Region *>(pull_context.dense_values);
pull_dense(dense_region, pull_context.num, pull_context.table);
return pull_dense(dense_region, pull_context.num, pull_context.table);
} else { // pull sparse
uint64_t *keys = reinterpret_cast<uint64_t *>(pull_context.keys);
float **select_values =
reinterpret_cast<float **>(pull_context.sparse_values);
size_t table_id = pull_context.table;
size_t num = pull_context.num;
bool is_training = pull_context.is_training;
if (pull_context.training_mode == Geo) { // for geo
pull_sparse_param(select_values, table_id, keys, num, is_training);
return pull_sparse_param(pull_context.sparse_values, table_id,
pull_context.keys, num, is_training);
} else if (pull_context.training_mode == Async) { // for async
pull_sparse(select_values, table_id, keys, num, is_training);
return pull_sparse(pull_context.sparse_values, table_id,
pull_context.keys, num, is_training);
}
}
}
......@@ -551,7 +550,7 @@ std::future<int32_t> BrpcPsClient::Pull(RequestContext &pull_context) {
std::future<int32_t> BrpcPsClient::Push(RequestContext &push_context) {
if (push_context.value_type == Dense) { // push dense
const Region *dense_region = push_context.push_context.push_dense_values;
push_dense(dense_region, push_context.num, push_context.table);
return push_dense(dense_region, push_context.num, push_context.table);
} else { // push sparse
size_t table_id = push_context.table;
size_t num = push_context.num;
......@@ -561,7 +560,7 @@ std::future<int32_t> BrpcPsClient::Push(RequestContext &push_context) {
} else if (push_context.training_mode == Async) { // for async
const uint64_t *keys = push_context.push_context.keys;
const float **update_values = push_context.push_context.push_values;
push_sparse(table_id, keys, update_values, num);
return push_sparse(table_id, keys, update_values, num);
}
}
}
......@@ -584,11 +583,12 @@ std::future<int32_t> BrpcPsClient::pull_geo_param(size_t table_id,
io_buffer_itr.copy_and_forward(reinterpret_cast<void *>(&shard_nums),
sizeof(uint32_t));
keys->resize(shard_nums);
values->resize(shard_nums * accessor->update_dim());
values->resize(shard_nums * accessor->GetTableInfo(UPDATE_DIM));
io_buffer_itr.copy_and_forward((void *)(keys->data()), // NOLINT
sizeof(uint64_t) * shard_nums);
io_buffer_itr.copy_and_forward((void *)(values->data()), // NOLINT
shard_nums * accessor->update_size());
io_buffer_itr.copy_and_forward(
(void *)(values->data()), // NOLINT
shard_nums * accessor->GetTableInfo(UPDATE_SIZE));
closure->set_promise_value(ret);
});
auto promise = std::make_shared<std::promise<int32_t>>();
......@@ -630,7 +630,7 @@ std::future<int32_t> BrpcPsClient::push_sparse_param(
auto kvs = ids[shard_idx];
auto value_ptr = value_ptrs[shard_idx];
size_t kv_size = kvs.size();
uint32_t value_size = accessor->update_size();
uint32_t value_size = accessor->GetTableInfo(UPDATE_SIZE);
// 发送RPC请求
auto *push_request = closure->request(shard_idx);
push_request->set_cmd_id(PS_PUSH_SPARSE_PARAM);
......@@ -638,13 +638,14 @@ std::future<int32_t> BrpcPsClient::push_sparse_param(
push_request->set_client_id(_client_id);
push_request->add_params((char *)&kv_size, sizeof(uint32_t)); // NOLINT
auto *push_data = push_request->mutable_data();
push_data->resize(kv_size * (sizeof(uint64_t) + accessor->update_size()));
push_data->resize(kv_size *
(sizeof(uint64_t) + accessor->GetTableInfo(UPDATE_SIZE)));
char *push_data_ptr = const_cast<char *>(push_data->data());
memcpy(push_data_ptr, kvs.data(), kv_size * sizeof(uint64_t));
push_data_ptr += kv_size * sizeof(uint64_t);
for (int i = 0; i < kv_size; ++i) {
memcpy(push_data_ptr, value_ptr[i], accessor->update_size());
push_data_ptr += accessor->update_size();
memcpy(push_data_ptr, value_ptr[i], accessor->GetTableInfo(UPDATE_SIZE));
push_data_ptr += accessor->GetTableInfo(UPDATE_SIZE);
}
PsService_Stub rpc_stub(get_sparse_channel(shard_idx));
closure->cntl(shard_idx)->set_request_compress_type(
......@@ -660,9 +661,11 @@ std::future<int32_t> BrpcPsClient::pull_dense(Region *regions,
size_t table_id) {
auto timer = std::make_shared<CostTimer>("pserver_client_pull_dense");
auto *accessor = table_accessor(table_id);
auto fea_dim = accessor->GetTableInfo(FEA_DIM);
auto select_size = accessor->GetTableInfo(SELECT_SIZE);
size_t request_call_num = _server_channels.size();
uint32_t num_per_shard =
dense_dim_per_shard(accessor->fea_dim(), request_call_num);
dense_dim_per_shard(accessor->GetTableInfo(FEA_DIM), request_call_num);
// callback 将各shard结果,顺序填入region
DownpourBrpcClosure *closure = new DownpourBrpcClosure(
request_call_num, [request_call_num, num_per_shard, regions, region_num,
......@@ -671,7 +674,8 @@ std::future<int32_t> BrpcPsClient::pull_dense(Region *regions,
size_t region_idx = 0; // 当前填充的region偏移
size_t region_data_idx = 0; // 当前填充的region内data偏移
auto *closure = reinterpret_cast<DownpourBrpcClosure *>(done);
size_t shard_data_size = num_per_shard * accessor->select_size();
size_t shard_data_size =
num_per_shard * accessor->GetTableInfo(SELECT_SIZE);
for (size_t i = 0; i < request_call_num; ++i) {
if (closure->check_response(i, PS_PULL_DENSE_TABLE) != 0) {
ret = -1;
......@@ -739,8 +743,8 @@ std::future<int32_t> BrpcPsClient::push_dense_param(const Region *regions,
// 1.拆分Region数据到shard中,后续多shard并行拷贝数据
std::vector<std::vector<Region>> regions_partition(request_call_num);
uint32_t num_per_shard =
dense_dim_per_shard(accessor->fea_dim(), request_call_num);
size_t shard_data_size = num_per_shard * accessor->update_size();
dense_dim_per_shard(accessor->GetTableInfo(FEA_DIM), request_call_num);
size_t shard_data_size = num_per_shard * accessor->GetTableInfo(UPDATE_SIZE);
size_t current_region_idx = 0;
size_t current_region_data_idx = 0;
for (size_t i = 0; i < request_call_num; ++i) {
......@@ -847,7 +851,7 @@ std::future<int32_t> BrpcPsClient::push_sparse_raw_gradient(
auto value_ptr = value_ptrs[shard_idx];
size_t kv_size = kvs.size();
uint32_t value_size = accessor->update_size();
uint32_t value_size = accessor->GetTableInfo(UPDATE_SIZE);
// 发送RPC请求
auto *push_request = closure->request(shard_idx);
......@@ -856,14 +860,15 @@ std::future<int32_t> BrpcPsClient::push_sparse_raw_gradient(
push_request->set_client_id(_client_id);
push_request->add_params((char *)&kv_size, sizeof(uint32_t)); // NOLINT
auto *push_data = push_request->mutable_data();
push_data->resize(kv_size * (sizeof(uint64_t) + accessor->update_size()));
push_data->resize(kv_size *
(sizeof(uint64_t) + accessor->GetTableInfo(UPDATE_SIZE)));
char *push_data_ptr = const_cast<char *>(push_data->data());
memcpy(push_data_ptr, kvs.data(), kv_size * sizeof(uint64_t));
push_data_ptr += kv_size * sizeof(uint64_t);
for (int i = 0; i < kv_size; ++i) {
memcpy(push_data_ptr, value_ptr[i], accessor->update_size());
push_data_ptr += accessor->update_size();
memcpy(push_data_ptr, value_ptr[i], accessor->GetTableInfo(UPDATE_SIZE));
push_data_ptr += accessor->GetTableInfo(UPDATE_SIZE);
}
PsService_Stub rpc_stub(get_sparse_channel(shard_idx));
closure->cntl(shard_idx)->set_request_compress_type(
......@@ -884,7 +889,7 @@ std::future<int32_t> BrpcPsClient::push_dense_raw_gradient(
std::future<int> fut = promise->get_future();
auto *accessor = table_accessor(table_id);
uint32_t num_per_shard =
dense_dim_per_shard(accessor->fea_dim(), request_call_num);
dense_dim_per_shard(accessor->GetTableInfo(FEA_DIM), request_call_num);
for (size_t i = 0; i < request_call_num; ++i) {
closure->request(i)->set_cmd_id(PS_PUSH_DENSE_TABLE);
closure->request(i)->set_table_id(table_id);
......@@ -962,7 +967,8 @@ std::future<int32_t> BrpcPsClient::pull_sparse(float **select_values,
}
auto *accessor = table_accessor(table_id);
size_t value_size = accessor->select_size();
size_t value_size = accessor->GetTableInfo(SELECT_SIZE);
DownpourBrpcClosure *closure = new DownpourBrpcClosure(
request_call_num, [shard_sorted_kvs, value_size](void *done) {
......@@ -1075,7 +1081,7 @@ std::future<int32_t> BrpcPsClient::pull_sparse_param(float **select_values,
}
auto *accessor = table_accessor(table_id);
size_t value_size = accessor->select_size();
size_t value_size = accessor->GetTableInfo(SELECT_SIZE);
DownpourBrpcClosure *closure = new DownpourBrpcClosure(
request_call_num, [shard_sorted_kvs, value_size](void *done) {
......@@ -1199,7 +1205,7 @@ std::future<int32_t> BrpcPsClient::push_sparse_raw_gradient_partial(
size_t table_id, const uint64_t *keys, const float **update_values,
uint32_t num, void *done, int pserver_idx) {
auto *accessor = table_accessor(table_id);
size_t value_size = accessor->update_size();
size_t value_size = accessor->GetTableInfo(UPDATE_SIZE);
DownpourBrpcClosure *closure = reinterpret_cast<DownpourBrpcClosure *>(done);
auto promise = std::make_shared<std::promise<int32_t>>();
closure->add_promise(promise);
......@@ -1359,8 +1365,7 @@ std::future<int32_t> BrpcPsClient::push_sparse(size_t table_id,
shard_kv_data.kv_num = 0;
continue;
}
uint32_t value_size = accessor->update_size();
uint32_t value_size = accessor->GetTableInfo(UPDATE_SIZE);
for (size_t kv_idx = 0; kv_idx < sorted_kv_size; ++kv_idx) {
shard_kv_data.key_list[kv_idx] = sorted_kv_list[kv_idx].first;
shard_kv_data.value_list[kv_idx].assign(
......@@ -1506,7 +1511,7 @@ void BrpcPsClient::push_sparse_task_consume() {
void sparse_local_merge(ValueAccessor *accessor, float *merge_data,
const float *another_data) {
size_t col_num = accessor->update_size() / sizeof(float);
size_t col_num = accessor->GetTableInfo(UPDATE_SIZE) / sizeof(float);
float *merge_data_shell[col_num];
const float *another_data_shell[col_num];
for (int i = 0; i < col_num; ++i) {
......@@ -1522,7 +1527,7 @@ int BrpcPsClient::push_sparse_async_shard_merge(
ValueAccessor *accessor) {
size_t merged_kv_count = 0;
uint64_t min_key = UINT64_MAX;
uint32_t value_size = accessor->update_size();
uint32_t value_size = accessor->GetTableInfo(UPDATE_SIZE);
thread_local std::vector<std::pair<uint64_t, const float *>> sorted_kv_list;
sorted_kv_list.clear();
......@@ -1628,8 +1633,9 @@ int BrpcPsClient::push_sparse_async_shard_push(
push_request->add_params(reinterpret_cast<char *>(&merged_kv_count),
sizeof(uint32_t)); // NOLINT
auto *push_data = push_request->mutable_data();
int update_size = accessor->GetTableInfo(UPDATE_SIZE);
push_data->resize(merged_kv_count *
(sizeof(uint64_t) + accessor->update_size()));
(sizeof(uint64_t) + accessor->GetTableInfo(UPDATE_SIZE)));
char *push_data_ptr = const_cast<char *>(push_data->data());
memcpy(push_data_ptr, merged_key_list.data(),
merged_kv_count * sizeof(uint64_t));
......@@ -1638,8 +1644,8 @@ int BrpcPsClient::push_sparse_async_shard_push(
const char *task_data_ptr = merged_value_list[i].data();
memcpy(push_data_ptr, (float *)(task_data_ptr), // NOLINT
accessor->update_size());
push_data_ptr += accessor->update_size();
accessor->GetTableInfo(UPDATE_SIZE));
push_data_ptr += accessor->GetTableInfo(UPDATE_SIZE);
}
PsService_Stub rpc_stub(get_sparse_channel(shard_idx));
closure->cntl(shard_idx)->set_request_compress_type(
......@@ -1654,6 +1660,8 @@ std::future<int32_t> BrpcPsClient::push_dense(const Region *regions,
size_t region_num,
size_t table_id) {
auto *accessor = table_accessor(table_id);
int fea_dim = accessor->GetTableInfo(FEA_DIM);
int update_dim = accessor->GetTableInfo(UPDATE_DIM);
auto push_timer = std::make_shared<CostTimer>("pserver_client_push_dense");
auto parse_timer =
std::make_shared<CostTimer>("pserver_client_push_dense_parse");
......@@ -1673,11 +1681,11 @@ std::future<int32_t> BrpcPsClient::push_dense(const Region *regions,
size_t request_call_num = _server_channels.size();
uint32_t num_per_shard =
dense_dim_per_shard(accessor->fea_dim(), request_call_num);
dense_dim_per_shard(accessor->GetTableInfo(FEA_DIM), request_call_num);
// 将region数据拷贝到转置矩阵中
async_task->data()->resize(num_per_shard * request_call_num *
accessor->update_dim());
accessor->GetTableInfo(UPDATE_DIM));
float *data = async_task->data()->data();
size_t data_size = async_task->data()->size();
uint32_t pos = 0;
......@@ -1806,7 +1814,7 @@ void BrpcPsClient::push_dense_raw_gradient(
auto timer = std::make_shared<CostTimer>("pserver_client_push_dense_rpc");
closure->add_timer(timer);
uint32_t num_per_shard =
dense_dim_per_shard(accessor->fea_dim(), request_call_num);
dense_dim_per_shard(accessor->GetTableInfo(FEA_DIM), request_call_num);
auto send_timer =
std::make_shared<CostTimer>("pserver_client_push_dense_send");
for (size_t i = 0; i < request_call_num; ++i) {
......
......@@ -207,7 +207,12 @@ int32_t BrpcPsService::pull_dense(Table *table, const PsRequestMessage &request,
auto res_data = butil::get_object<std::vector<float>>();
res_data->resize(num * table->value_accesor()->select_size() / sizeof(float));
table->pull_dense(res_data->data(), num);
TableContext table_context;
table_context.value_type = Dense;
table_context.pull_context.values = res_data->data();
table_context.num = num;
table->Pull(table_context);
// table->pull_dense(res_data->data(), num);
cntl->response_attachment().append((char *)(res_data->data()),
res_data->size() * sizeof(float));
......@@ -264,9 +269,15 @@ int32_t BrpcPsService::push_dense(Table *table, const PsRequestMessage &request,
|--4B---|----------------|
*/
uint32_t num = *(const uint32_t *)(request.data().data());
const float *values =
TableContext table_context;
table_context.value_type = Dense;
table_context.push_context.values =
(const float *)(request.data().data() + sizeof(uint32_t));
if (table->push_dense(values, num) != 0) {
table_context.num = num;
// const float *values = (const float *)(request.data().data() +
// sizeof(uint32_t));
if (table->Push(table_context) != 0) {
// if (table->push_dense(values, num) != 0) {
set_response_code(response, -1, "push_dense failed");
}
......@@ -388,7 +399,12 @@ int32_t BrpcPsService::pull_sparse(Table *table,
auto res_data = butil::get_object<std::vector<float>>();
res_data->resize(num * dim);
table->pull_sparse(res_data->data(), value);
TableContext table_context;
table_context.value_type = Sparse;
table_context.pull_context.pull_value = value;
table_context.pull_context.values = res_data->data();
table->Pull(table_context);
// table->pull_sparse(res_data->data(), value);
cntl->response_attachment().append((char *)(res_data->data()),
res_data->size() * sizeof(float));
......@@ -421,10 +437,17 @@ int32_t BrpcPsService::push_sparse(Table *table,
|---keysData---|---valuesData---|
|---8*{num}B---|----------------|
*/
const uint64_t *keys = (const uint64_t *)push_data.data();
const float *values =
TableContext table_context;
table_context.value_type = Sparse;
table_context.push_context.keys = (const uint64_t *)push_data.data();
table_context.push_context.values =
(const float *)(push_data.data() + sizeof(uint64_t) * num);
if (table->push_sparse(keys, values, num) != 0) {
table_context.num = num;
// const uint64_t *keys = (const uint64_t *)push_data.data();
// const float *values = (const float *)(push_data.data() + sizeof(uint64_t) *
// num);
if (table->Push(table_context) != 0) {
// if (table->push_sparse(keys, values, num) != 0) {
set_response_code(response, -1, "push_sparse error");
}
return 0;
......
......@@ -86,9 +86,9 @@ struct RequestContext {
TrainingMode training_mode; // 1 for async, 2 for geo, 3 for sync
TrainingPhase training_phase; // 1 for init, 2 for train
ValueType value_type; // 1 for sparse, 2 for dense
void *keys;
void **sparse_values; // for sparse values
Region *dense_values; // for dense values
uint64_t *keys;
float **sparse_values; // for sparse values
Region *dense_values; // for dense values
PushContext push_context;
size_t num;
bool is_training;
......
......@@ -126,11 +126,13 @@ std::future<int32_t> PsLocalClient::Load(const LoadSaveContext& load_context) {
Region* dense_region = reinterpret_cast<Region*>(pull_context.dense_values);
pull_dense(dense_region, pull_context.num, pull_context.table);
} else { // pull sparse
uint64_t* keys = reinterpret_cast<uint64_t*>(pull_context.keys);
char** select_values = reinterpret_cast<char**>(pull_context.sparse_values);
// uint64_t* keys = reinterpret_cast<uint64_t*>(pull_context.keys);
// char** select_values =
// reinterpret_cast<char**>(pull_context.sparse_values);
size_t table_id = pull_context.table;
size_t num = pull_context.num;
pull_sparse_ptr(select_values, table_id, keys, num);
pull_sparse_ptr(reinterpret_cast<char**>(pull_context.sparse_values),
table_id, pull_context.keys, num);
}
}
......
......@@ -56,6 +56,17 @@ struct AccessorInfo {
size_t fea_dim;
};
enum InfoKey {
DIM = 0,
SIZE = 1,
SELECT_SIZE = 2,
SELECT_DIM = 3,
UPDATE_SIZE = 4,
UPDATE_DIM = 5,
MF_SIZE = 6,
FEA_DIM = 7
};
class ValueAccessor {
public:
ValueAccessor() {}
......@@ -79,7 +90,8 @@ class ValueAccessor {
}
virtual int initialize() = 0;
virtual void GetTableInfo(AccessorInfo& info) = 0;
virtual void SetTableInfo(AccessorInfo& info) = 0;
virtual size_t GetTableInfo(InfoKey key) = 0;
// value维度
virtual size_t dim() = 0;
......
......@@ -138,7 +138,7 @@ int32_t CommonDenseTable::Pull(TableContext& context) {
int32_t CommonDenseTable::Push(TableContext& context) {
CHECK(context.value_type == Dense);
if (context.pull_context.values != nullptr) {
if (context.push_context.values != nullptr) {
const float* values = context.push_context.values;
return push_dense(values, context.num);
}
......@@ -220,7 +220,7 @@ int32_t CommonDenseTable::load(const std::string& path,
}
size_t dim_num_per_file = _config.accessor().fea_dim() / file_list.size() + 1;
// param_dim_ in last node != _config.accesor().fea_dim() / _shard_num + 1
size_t dim_num_per_shard = _value_accesor->fea_dim() / _shard_num + 1;
size_t dim_num_per_shard = _table_info.fea_dim / _shard_num + 1;
size_t start_dim_idx = dim_num_per_shard * _shard_idx;
size_t start_file_idx = start_dim_idx / dim_num_per_file;
size_t end_file_idx = (start_dim_idx + param_dim_) / dim_num_per_file;
......
......@@ -370,7 +370,7 @@ int32_t CommonSparseTable::Pull(TableContext& context) {
int32_t CommonSparseTable::Push(TableContext& context) {
CHECK(context.value_type == Sparse);
if (context.pull_context.values != nullptr) {
if (context.push_context.values != nullptr) {
const float* values = context.push_context.values;
const uint64_t* keys = context.push_context.keys;
return push_sparse(keys, values, context.num);
......
......@@ -38,16 +38,39 @@ int CtrCommonAccessor::initialize() {
return 0;
}
void CtrCommonAccessor::GetTableInfo(AccessorInfo& info) {
void CtrCommonAccessor::SetTableInfo(AccessorInfo& info) {
info.dim = dim();
info.size = size();
info.select_dim = select_dim();
info.select_size = select_size();
info.update_dim = update_dim();
info.update_size = update_size();
info.mf_size = mf_size();
info.fea_dim = fea_dim();
}
size_t CtrCommonAccessor::GetTableInfo(InfoKey key) {
switch (key) {
case DIM:
return dim();
case SIZE:
return size();
case SELECT_DIM:
return select_dim();
case SELECT_SIZE:
return select_size();
case UPDATE_DIM:
return update_dim();
case UPDATE_SIZE:
return update_size();
case MF_SIZE:
return mf_size();
case FEA_DIM:
return fea_dim();
}
return 0;
}
size_t CtrCommonAccessor::dim() { return common_feature_value.dim(); }
size_t CtrCommonAccessor::dim_size(size_t dim) {
......
......@@ -137,7 +137,8 @@ class CtrCommonAccessor : public ValueAccessor {
virtual int initialize();
virtual ~CtrCommonAccessor() {}
virtual void GetTableInfo(AccessorInfo& info);
virtual void SetTableInfo(AccessorInfo& info);
virtual size_t GetTableInfo(InfoKey key);
// value维度
virtual size_t dim();
// value各个维度的size
......
......@@ -37,16 +37,39 @@ int DownpourCtrDoubleAccessor::initialize() {
return 0;
}
void DownpourCtrDoubleAccessor::GetTableInfo(AccessorInfo& info) {
void DownpourCtrDoubleAccessor::SetTableInfo(AccessorInfo& info) {
info.dim = dim();
info.size = size();
info.select_dim = select_dim();
info.select_size = select_size();
info.update_dim = update_dim();
info.update_size = update_size();
info.mf_size = mf_size();
info.fea_dim = fea_dim();
}
size_t DownpourCtrDoubleAccessor::GetTableInfo(InfoKey key) {
switch (key) {
case DIM:
return dim();
case SIZE:
return size();
case SELECT_DIM:
return select_dim();
case SELECT_SIZE:
return select_size();
case UPDATE_DIM:
return update_dim();
case UPDATE_SIZE:
return update_size();
case MF_SIZE:
return mf_size();
case FEA_DIM:
return fea_dim();
}
return 0;
}
size_t DownpourCtrDoubleAccessor::dim() {
auto embedx_dim = _config.embedx_dim();
return DownpourCtrDoubleFeatureValue::dim(embedx_dim);
......
......@@ -168,7 +168,8 @@ class DownpourCtrDoubleAccessor : public ValueAccessor {
DownpourCtrDoubleAccessor() {}
virtual ~DownpourCtrDoubleAccessor() {}
virtual int initialize();
virtual void GetTableInfo(AccessorInfo& info);
virtual void SetTableInfo(AccessorInfo& info);
virtual size_t GetTableInfo(InfoKey key);
// value维度
virtual size_t dim();
// value各个维度的size
......
......@@ -24,6 +24,7 @@ namespace paddle {
namespace distributed {
struct PullSparseValue {
PullSparseValue() {}
explicit PullSparseValue(int numel, int dim)
: numel_(numel),
dim_(dim),
......
......@@ -37,16 +37,39 @@ int DownpourCtrAccessor::initialize() {
return 0;
}
void DownpourCtrAccessor::GetTableInfo(AccessorInfo& info) {
void DownpourCtrAccessor::SetTableInfo(AccessorInfo& info) {
info.dim = dim();
info.size = size();
info.select_dim = select_dim();
info.select_size = select_size();
info.update_dim = update_dim();
info.update_size = update_size();
info.mf_size = mf_size();
info.fea_dim = fea_dim();
}
size_t DownpourCtrAccessor::GetTableInfo(InfoKey key) {
switch (key) {
case DIM:
return dim();
case SIZE:
return size();
case SELECT_DIM:
return select_dim();
case SELECT_SIZE:
return select_size();
case UPDATE_DIM:
return update_dim();
case UPDATE_SIZE:
return update_size();
case MF_SIZE:
return mf_size();
case FEA_DIM:
return fea_dim();
}
return 0;
}
size_t DownpourCtrAccessor::dim() {
auto embedx_dim = _config.embedx_dim();
return DownpourCtrFeatureValue::dim(embedx_dim);
......
......@@ -160,7 +160,8 @@ class DownpourCtrAccessor : public ValueAccessor {
virtual ~DownpourCtrAccessor() {}
virtual int initialize();
virtual void GetTableInfo(AccessorInfo& info);
virtual void SetTableInfo(AccessorInfo& info);
virtual size_t GetTableInfo(InfoKey key);
// value维度
virtual size_t dim();
// value各个维度的size
......
......@@ -88,7 +88,8 @@ int32_t MemorySparseTable::load(const std::string& path,
size_t file_start_idx = _shard_idx * _avg_local_shard_num;
size_t feature_value_size = _value_accesor->size() / sizeof(float);
size_t feature_value_size =
_value_accesor->GetTableInfo(SIZE) / sizeof(float);
int thread_num = _real_local_shard_num < 15 ? _real_local_shard_num : 15;
omp_set_num_threads(thread_num);
......@@ -173,7 +174,8 @@ int32_t MemorySparseTable::load_local_fs(const std::string& path,
size_t file_start_idx = _shard_idx * _avg_local_shard_num;
size_t feature_value_size = _value_accesor->size() / sizeof(float);
size_t feature_value_size =
_value_accesor->GetTableInfo(SIZE) / sizeof(float);
int thread_num = _real_local_shard_num < 15 ? _real_local_shard_num : 15;
omp_set_num_threads(thread_num);
......@@ -407,7 +409,7 @@ int32_t MemorySparseTable::Push(TableContext& context) {
CHECK(context.value_type == Sparse);
const uint64_t* keys = context.push_context.keys;
return push_sparse(keys, context.push_context.ptr_values, context.num);
return push_sparse(keys, context.push_context.values, context.num);
}
int32_t MemorySparseTable::pull_sparse(float* pull_values,
......@@ -415,9 +417,10 @@ int32_t MemorySparseTable::pull_sparse(float* pull_values,
CostTimer timer("pserver_sparse_select_all");
std::vector<std::future<int>> tasks(_real_local_shard_num);
const size_t value_size = _value_accesor->size() / sizeof(float);
size_t mf_value_size = _value_accesor->mf_size() / sizeof(float);
size_t select_value_size = _value_accesor->select_size() / sizeof(float);
const size_t value_size = _value_accesor->GetTableInfo(SIZE) / sizeof(float);
size_t mf_value_size = _value_accesor->GetTableInfo(MF_SIZE) / sizeof(float);
size_t select_value_size =
_value_accesor->GetTableInfo(SELECT_SIZE) / sizeof(float);
// std::atomic<uint32_t> missed_keys{0};
std::vector<std::vector<std::pair<uint64_t, int>>> task_keys(
......@@ -475,7 +478,6 @@ int32_t MemorySparseTable::pull_sparse(float* pull_values,
for (size_t shard_id = 0; shard_id < tasks.size(); ++shard_id) {
tasks[shard_id].wait();
}
return 0;
}
......@@ -541,9 +543,10 @@ int32_t MemorySparseTable::push_sparse(const uint64_t* keys,
task_keys[shard_id].push_back({keys[i], i});
}
const size_t value_col = _value_accesor->size() / sizeof(float);
size_t mf_value_col = _value_accesor->mf_size() / sizeof(float);
size_t update_value_col = _value_accesor->update_size() / sizeof(float);
const size_t value_col = _value_accesor->GetTableInfo(SIZE) / sizeof(float);
size_t mf_value_col = _value_accesor->GetTableInfo(MF_SIZE) / sizeof(float);
size_t update_value_col =
_value_accesor->GetTableInfo(UPDATE_SIZE) / sizeof(float);
for (size_t shard_id = 0; shard_id < _real_local_shard_num; ++shard_id) {
tasks[shard_id] = _shards_task_pool[shard_id % _task_pool_size]->enqueue(
......@@ -618,9 +621,10 @@ int32_t MemorySparseTable::_push_sparse(const uint64_t* keys,
task_keys[shard_id].push_back({keys[i], i});
}
size_t value_col = _value_accesor->size() / sizeof(float);
size_t mf_value_col = _value_accesor->mf_size() / sizeof(float);
size_t update_value_col = _value_accesor->update_size() / sizeof(float);
size_t value_col = _value_accesor->GetTableInfo(SIZE) / sizeof(float);
size_t mf_value_col = _value_accesor->GetTableInfo(MF_SIZE) / sizeof(float);
size_t update_value_col =
_value_accesor->GetTableInfo(UPDATE_SIZE) / sizeof(float);
for (int shard_id = 0; shard_id < _real_local_shard_num; ++shard_id) {
tasks[shard_id] = _shards_task_pool[shard_id % _task_pool_size]->enqueue(
......
......@@ -38,16 +38,39 @@ int SparseAccessor::initialize() {
return 0;
}
void SparseAccessor::GetTableInfo(AccessorInfo& info) {
void SparseAccessor::SetTableInfo(AccessorInfo& info) {
info.dim = dim();
info.size = size();
info.select_dim = select_dim();
info.select_size = select_size();
info.update_dim = update_dim();
info.update_size = update_size();
info.mf_size = mf_size();
info.fea_dim = fea_dim();
}
size_t SparseAccessor::GetTableInfo(InfoKey key) {
switch (key) {
case DIM:
return dim();
case SIZE:
return size();
case SELECT_DIM:
return select_dim();
case SELECT_SIZE:
return select_size();
case UPDATE_DIM:
return update_dim();
case UPDATE_SIZE:
return update_size();
case MF_SIZE:
return mf_size();
case FEA_DIM:
return fea_dim();
}
return 0;
}
size_t SparseAccessor::dim() { return sparse_feature_value.dim(); }
size_t SparseAccessor::dim_size(size_t dim) {
......
......@@ -123,7 +123,8 @@ class SparseAccessor : public ValueAccessor {
};
SparseAccessor() {}
virtual int initialize();
virtual void GetTableInfo(AccessorInfo& info);
virtual void SetTableInfo(AccessorInfo& info);
virtual size_t GetTableInfo(InfoKey key);
virtual ~SparseAccessor() {}
// value维度
......
......@@ -103,6 +103,7 @@ int32_t Table::initialize_accessor() {
return -1;
}
_value_accesor.reset(accessor);
// _value_accesor->SetTableInfo(_table_info);
return 0;
}
......
......@@ -37,7 +37,7 @@ enum ValueType { Sparse = 0, Dense = 1 };
struct PullContext {
const uint64_t *keys;
const PullSparseValue pull_value;
PullSparseValue pull_value;
float *values;
char **ptr_values;
};
......@@ -53,7 +53,7 @@ struct TableContext {
PullContext pull_context;
TablePushContext push_context;
size_t num;
bool use_ptr;
bool use_ptr = false;
};
class Table {
......@@ -164,6 +164,7 @@ class Table {
TableParameter _config;
float *_global_lr = nullptr;
std::shared_ptr<ValueAccessor> _value_accesor;
AccessorInfo _table_info;
AfsClient _afs_client;
};
REGISTER_PSCORE_REGISTERER(Table);
......
......@@ -20,16 +20,39 @@ namespace distributed {
int CommMergeAccessor::initialize() { return 0; }
void CommMergeAccessor::GetTableInfo(AccessorInfo &info) {
void CommMergeAccessor::SetTableInfo(AccessorInfo &info) {
info.dim = dim();
info.size = size();
info.select_dim = select_dim();
info.select_size = select_size();
info.update_dim = update_dim();
info.update_size = update_size();
info.mf_size = mf_size();
info.fea_dim = fea_dim();
}
size_t CommMergeAccessor::GetTableInfo(InfoKey key) {
switch (key) {
case DIM:
return dim();
case SIZE:
return size();
case SELECT_DIM:
return select_dim();
case SELECT_SIZE:
return select_size();
case UPDATE_DIM:
return update_dim();
case UPDATE_SIZE:
return update_size();
case MF_SIZE:
return mf_size();
case FEA_DIM:
return fea_dim();
}
return 0;
}
// value 维度
size_t CommMergeAccessor::dim() { return 0; }
......
......@@ -30,7 +30,8 @@ class CommMergeAccessor : public ValueAccessor {
CommMergeAccessor() {}
virtual ~CommMergeAccessor() {}
virtual int initialize();
virtual void GetTableInfo(AccessorInfo &info);
virtual void SetTableInfo(AccessorInfo &info);
virtual size_t GetTableInfo(InfoKey key);
// value维度
virtual size_t dim();
// value各个维度的size
......
......@@ -337,9 +337,21 @@ void FleetWrapper::PullSparseToTensorSync(const uint64_t table_id, int fea_dim,
pull_result_ptr.push_back(output_data + output_len);
}
}
auto status =
worker_ptr_->pull_sparse(pull_result_ptr.data(), table_id,
fea_keys.data(), fea_keys.size(), is_training);
// ps client pull sparse
// construct client request context
RequestContext req_context;
req_context.value_type = Sparse;
req_context.training_mode = Async;
req_context.table = table_id;
req_context.sparse_values = pull_result_ptr.data();
req_context.keys = fea_keys.data();
req_context.num = fea_keys.size();
req_context.is_training = is_training;
auto status = worker_ptr_->Pull(req_context);
// auto status =
// worker_ptr_->pull_sparse(pull_result_ptr.data(), table_id,
// fea_keys.data(), fea_keys.size(),
// is_training);
status.wait();
auto ret = status.get();
if (ret != 0) {
......@@ -366,7 +378,14 @@ void FleetWrapper::PullDenseVarsAsync(
paddle::distributed::Region reg(w, tensor->numel());
regions[i] = std::move(reg);
}
auto status = worker_ptr_->pull_dense(regions.data(), regions.size(), tid);
RequestContext req_context;
req_context.value_type = Dense;
req_context.training_mode = Async;
req_context.table = tid;
req_context.dense_values = regions.data();
req_context.num = regions.size();
auto status = worker_ptr_->Pull(req_context);
// auto status = worker_ptr_->pull_dense(regions.data(), regions.size(), tid);
pull_dense_status->push_back(std::move(status));
}
......@@ -451,8 +470,15 @@ void FleetWrapper::PushDenseVarsAsync(
<< g[tensor->numel() - 1];
}
auto push_status =
worker_ptr_->push_dense(regions.data(), regions.size(), table_id);
RequestContext req_context;
req_context.value_type = Dense;
req_context.training_mode = Async;
req_context.table = table_id;
req_context.push_context.push_dense_values = regions.data();
req_context.num = regions.size();
// auto push_status =
// worker_ptr_->push_dense(regions.data(), regions.size(), table_id);
auto push_status = worker_ptr_->Push(req_context);
}
void FleetWrapper::PushSparseVarsAsync(
......@@ -624,9 +650,19 @@ void FleetWrapper::PushSparseFromTensorAsync(
push_g_vec[i] = push_values.at(i).data();
}
auto status = worker_ptr_->push_sparse(table_id, push_keys.data(),
(const float**)push_g_vec.data(),
push_keys.size());
// ps client push sparse
// construct request context
RequestContext req_context;
req_context.value_type = Sparse;
req_context.training_mode = Async;
req_context.table = table_id;
req_context.push_context.push_values = (const float**)push_g_vec.data();
req_context.push_context.keys = push_keys.data();
req_context.num = push_keys.size();
auto status = worker_ptr_->Push(req_context);
// auto status = worker_ptr_->push_sparse(table_id, push_keys.data(),
// (const float**)push_g_vec.data(),
// push_keys.size());
}
void FleetWrapper::LoadModel(const std::string& path, const int mode) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册