未验证 提交 449ea33d 编写于 作者: W wangguanqun 提交者: GitHub

[GPUPS]SSDSparseTable add PullSparsePtr (#44137)

* ssd pullsparseptr

* update codestyle
上级 826e2781
...@@ -39,6 +39,33 @@ int32_t SSDSparseTable::Initialize() { ...@@ -39,6 +39,33 @@ int32_t SSDSparseTable::Initialize() {
int32_t SSDSparseTable::InitializeShard() { return 0; } int32_t SSDSparseTable::InitializeShard() { return 0; }
int32_t SSDSparseTable::Pull(TableContext& context) {
CHECK(context.value_type == Sparse);
if (context.use_ptr) {
char** pull_values = context.pull_context.ptr_values;
const uint64_t* keys = context.pull_context.keys;
return PullSparsePtr(pull_values, keys, context.num);
} else {
float* pull_values = context.pull_context.values;
const PullSparseValue& pull_value = context.pull_context.pull_value;
return PullSparse(pull_values, pull_value.feasigns_, pull_value.numel_);
}
}
int32_t SSDSparseTable::Push(TableContext& context) {
CHECK(context.value_type == Sparse);
if (context.use_ptr) {
return PushSparse(context.push_context.keys,
context.push_context.ptr_values,
context.num);
} else {
const uint64_t* keys = context.push_context.keys;
const float* values = context.push_context.values;
size_t num = context.num;
return PushSparse(keys, values, num);
}
}
int32_t SSDSparseTable::PullSparse(float* pull_values, int32_t SSDSparseTable::PullSparse(float* pull_values,
const uint64_t* keys, const uint64_t* keys,
size_t num) { size_t num) {
...@@ -73,7 +100,7 @@ int32_t SSDSparseTable::PullSparse(float* pull_values, ...@@ -73,7 +100,7 @@ int32_t SSDSparseTable::PullSparse(float* pull_values,
&missed_keys]() -> int { &missed_keys]() -> int {
auto& keys = task_keys[shard_id]; auto& keys = task_keys[shard_id];
auto& local_shard = _local_shards[shard_id]; auto& local_shard = _local_shards[shard_id];
float data_buffer[value_size]; float data_buffer[value_size]; // NOLINT
float* data_buffer_ptr = data_buffer; float* data_buffer_ptr = data_buffer;
for (size_t i = 0; i < keys.size(); ++i) { for (size_t i = 0; i < keys.size(); ++i) {
uint64_t key = keys[i].first; uint64_t key = keys[i].first;
...@@ -83,7 +110,7 @@ int32_t SSDSparseTable::PullSparse(float* pull_values, ...@@ -83,7 +110,7 @@ int32_t SSDSparseTable::PullSparse(float* pull_values,
// pull rocksdb // pull rocksdb
std::string tmp_string(""); std::string tmp_string("");
if (_db->get(shard_id, if (_db->get(shard_id,
(char*)&key, reinterpret_cast<char*>(&key),
sizeof(uint64_t), sizeof(uint64_t),
tmp_string) > 0) { tmp_string) > 0) {
++missed_keys; ++missed_keys;
...@@ -110,7 +137,9 @@ int32_t SSDSparseTable::PullSparse(float* pull_values, ...@@ -110,7 +137,9 @@ int32_t SSDSparseTable::PullSparse(float* pull_values,
memcpy(const_cast<float*>(feature_value.data()), memcpy(const_cast<float*>(feature_value.data()),
data_buffer_ptr, data_buffer_ptr,
data_size * sizeof(float)); data_size * sizeof(float));
_db->del_data(shard_id, (char*)&key, sizeof(uint64_t)); _db->del_data(shard_id,
reinterpret_cast<char*>(&key),
sizeof(uint64_t));
} }
} else { } else {
data_size = itr.value().size(); data_size = itr.value().size();
...@@ -142,6 +171,95 @@ int32_t SSDSparseTable::PullSparse(float* pull_values, ...@@ -142,6 +171,95 @@ int32_t SSDSparseTable::PullSparse(float* pull_values,
return 0; return 0;
} }
int32_t SSDSparseTable::PullSparsePtr(char** pull_values,
const uint64_t* keys,
size_t num) {
CostTimer timer("pserver_ssd_sparse_select_all");
size_t value_size = _value_accesor->GetAccessorInfo().size / sizeof(float);
size_t mf_value_size =
_value_accesor->GetAccessorInfo().mf_size / sizeof(float);
{ // 从table取值 or create
std::vector<std::future<int>> tasks(_real_local_shard_num);
std::vector<std::vector<std::pair<uint64_t, int>>> task_keys(
_real_local_shard_num);
for (size_t i = 0; i < num; ++i) {
int shard_id = (keys[i] % _sparse_table_shard_num) % _avg_local_shard_num;
task_keys[shard_id].push_back({keys[i], i});
}
std::atomic<uint32_t> missed_keys{0};
for (int shard_id = 0; shard_id < _real_local_shard_num; ++shard_id) {
tasks[shard_id] =
_shards_task_pool[shard_id % _shards_task_pool.size()]->enqueue(
[this,
shard_id,
&task_keys,
value_size,
mf_value_size,
pull_values,
&missed_keys]() -> int {
auto& keys = task_keys[shard_id];
auto& local_shard = _local_shards[shard_id];
float data_buffer[value_size]; // NOLINT
float* data_buffer_ptr = data_buffer;
for (size_t i = 0; i < keys.size(); ++i) {
uint64_t key = keys[i].first;
auto itr = local_shard.find(key);
size_t data_size = value_size - mf_value_size;
FixedFeatureValue* ret = NULL;
if (itr == local_shard.end()) {
// pull rocksdb
std::string tmp_string("");
if (_db->get(shard_id,
reinterpret_cast<char*>(&key),
sizeof(uint64_t),
tmp_string) > 0) {
++missed_keys;
auto& feature_value = local_shard[key];
feature_value.resize(data_size);
float* data_ptr =
const_cast<float*>(feature_value.data());
_value_accesor->Create(&data_buffer_ptr, 1);
memcpy(
data_ptr, data_buffer_ptr, data_size * sizeof(float));
ret = &feature_value;
} else {
data_size = tmp_string.size() / sizeof(float);
memcpy(data_buffer_ptr,
paddle::string::str_to_float(tmp_string),
data_size * sizeof(float));
// from rocksdb to mem
auto& feature_value = local_shard[key];
feature_value.resize(data_size);
memcpy(const_cast<float*>(feature_value.data()),
data_buffer_ptr,
data_size * sizeof(float));
_db->del_data(shard_id,
reinterpret_cast<char*>(&key),
sizeof(uint64_t));
ret = &feature_value;
}
} else {
ret = itr.value_ptr();
}
int pull_data_idx = keys[i].second;
pull_values[pull_data_idx] = reinterpret_cast<char*>(ret);
}
return 0;
});
}
for (int i = 0; i < _real_local_shard_num; ++i) {
tasks[i].wait();
}
if (FLAGS_pserver_print_missed_key_num_every_push) {
LOG(WARNING) << "total pull keys:" << num
<< " missed_keys:" << missed_keys.load();
}
}
return 0;
}
int32_t SSDSparseTable::PushSparse(const uint64_t* keys, int32_t SSDSparseTable::PushSparse(const uint64_t* keys,
const float* values, const float* values,
size_t num) { size_t num) {
...@@ -172,7 +290,7 @@ int32_t SSDSparseTable::PushSparse(const uint64_t* keys, ...@@ -172,7 +290,7 @@ int32_t SSDSparseTable::PushSparse(const uint64_t* keys,
&task_keys]() -> int { &task_keys]() -> int {
auto& keys = task_keys[shard_id]; auto& keys = task_keys[shard_id];
auto& local_shard = _local_shards[shard_id]; auto& local_shard = _local_shards[shard_id];
float data_buffer[value_col]; float data_buffer[value_col]; // NOLINT
float* data_buffer_ptr = data_buffer; float* data_buffer_ptr = data_buffer;
for (size_t i = 0; i < keys.size(); ++i) { for (size_t i = 0; i < keys.size(); ++i) {
uint64_t key = keys[i].first; uint64_t key = keys[i].first;
...@@ -201,7 +319,8 @@ int32_t SSDSparseTable::PushSparse(const uint64_t* keys, ...@@ -201,7 +319,8 @@ int32_t SSDSparseTable::PushSparse(const uint64_t* keys,
if (value_size == if (value_size ==
value_col) { // 已拓展到最大size, 则就地update value_col) { // 已拓展到最大size, 则就地update
_value_accesor->Update(&value_data, &update_data, 1); _value_accesor->Update(&value_data, &update_data, 1);
} else { // 拷入buffer区进行update,然后再回填,不需要的mf则回填时抛弃了 } else {
// 拷入buffer区进行update,然后再回填,不需要的mf则回填时抛弃了
memcpy(data_buffer_ptr, memcpy(data_buffer_ptr,
value_data, value_data,
value_size * sizeof(float)); value_size * sizeof(float));
...@@ -247,6 +366,90 @@ int32_t SSDSparseTable::PushSparse(const uint64_t* keys, ...@@ -247,6 +366,90 @@ int32_t SSDSparseTable::PushSparse(const uint64_t* keys,
return 0; return 0;
} }
int32_t SSDSparseTable::PushSparse(const uint64_t* keys,
const float** values,
size_t num) {
CostTimer timer("pserver_downpour_sparse_update_all");
// 构造value push_value的数据指针
size_t value_col = _value_accesor->GetAccessorInfo().size / sizeof(float);
size_t mf_value_col =
_value_accesor->GetAccessorInfo().mf_size / sizeof(float);
size_t update_value_col =
_value_accesor->GetAccessorInfo().update_size / sizeof(float);
{
std::vector<std::future<int>> tasks(_real_local_shard_num);
std::vector<std::vector<std::pair<uint64_t, int>>> task_keys(
_real_local_shard_num);
for (size_t i = 0; i < num; ++i) {
int shard_id = (keys[i] % _sparse_table_shard_num) % _avg_local_shard_num;
task_keys[shard_id].push_back({keys[i], i});
}
for (int shard_id = 0; shard_id < _real_local_shard_num; ++shard_id) {
tasks[shard_id] =
_shards_task_pool[shard_id % _shards_task_pool.size()]->enqueue(
[this,
shard_id,
value_col,
mf_value_col,
update_value_col,
values,
&task_keys]() -> int {
auto& keys = task_keys[shard_id];
auto& local_shard = _local_shards[shard_id];
float data_buffer[value_col]; // NOLINT
float* data_buffer_ptr = data_buffer;
for (size_t i = 0; i < keys.size(); ++i) {
uint64_t key = keys[i].first;
uint64_t push_data_idx = keys[i].second;
const float* update_data = values[push_data_idx];
auto itr = local_shard.find(key);
if (itr == local_shard.end()) {
if (FLAGS_pserver_enable_create_feasign_randomly &&
!_value_accesor->CreateValue(1, update_data)) {
continue;
}
auto value_size = value_col - mf_value_col;
auto& feature_value = local_shard[key];
feature_value.resize(value_size);
_value_accesor->Create(&data_buffer_ptr, 1);
memcpy(const_cast<float*>(feature_value.data()),
data_buffer_ptr,
value_size * sizeof(float));
itr = local_shard.find(key);
}
auto& feature_value = itr.value();
float* value_data = const_cast<float*>(feature_value.data());
size_t value_size = feature_value.size();
if (value_size ==
value_col) { // 已拓展到最大size, 则就地update
_value_accesor->Update(&value_data, &update_data, 1);
} else {
// 拷入buffer区进行update,然后再回填,不需要的mf则回填时抛弃了
memcpy(data_buffer_ptr,
value_data,
value_size * sizeof(float));
_value_accesor->Update(&data_buffer_ptr, &update_data, 1);
if (_value_accesor->NeedExtendMF(data_buffer)) {
feature_value.resize(value_col);
value_data = const_cast<float*>(feature_value.data());
_value_accesor->Create(&value_data, 1);
}
memcpy(value_data,
data_buffer_ptr,
value_size * sizeof(float));
}
}
return 0;
});
}
for (int i = 0; i < _real_local_shard_num; ++i) {
tasks[i].wait();
}
}
return 0;
}
int32_t SSDSparseTable::Shrink(const std::string& param) { int32_t SSDSparseTable::Shrink(const std::string& param) {
int thread_num = _real_local_shard_num < 20 ? _real_local_shard_num : 20; int thread_num = _real_local_shard_num < 20 ? _real_local_shard_num : 20;
omp_set_num_threads(thread_num); omp_set_num_threads(thread_num);
...@@ -282,7 +485,7 @@ int32_t SSDSparseTable::Shrink(const std::string& param) { ...@@ -282,7 +485,7 @@ int32_t SSDSparseTable::Shrink(const std::string& param) {
delete it; delete it;
LOG(INFO) << "SSDSparseTable shrink success. shard:" << i << " delete MEM[" LOG(INFO) << "SSDSparseTable shrink success. shard:" << i << " delete MEM["
<< mem_count << "] SSD[" << ssd_count << "]"; << mem_count << "] SSD[" << ssd_count << "]";
//_db->flush(i); // _db->flush(i);
} }
return 0; return 0;
} }
......
...@@ -33,26 +33,14 @@ class SSDSparseTable : public MemorySparseTable { ...@@ -33,26 +33,14 @@ class SSDSparseTable : public MemorySparseTable {
// exchange data // exchange data
int32_t UpdateTable(); int32_t UpdateTable();
int32_t Pull(TableContext& context) override { int32_t Pull(TableContext& context) override;
CHECK(context.value_type == Sparse);
float* pull_values = context.pull_context.values;
const PullSparseValue& pull_value = context.pull_context.pull_value;
return PullSparse(pull_values, pull_value.feasigns_, pull_value.numel_);
}
int32_t Push(TableContext& context) override { int32_t Push(TableContext& context) override;
const uint64_t* keys = context.push_context.keys;
const float* values = context.push_context.values;
size_t num = context.num;
return PushSparse(keys, values, num);
}
virtual int32_t PullSparse(float* pull_values, int32_t PullSparse(float* pull_values, const uint64_t* keys, size_t num);
const uint64_t* keys, int32_t PullSparsePtr(char** pull_values, const uint64_t* keys, size_t num);
size_t num); int32_t PushSparse(const uint64_t* keys, const float* values, size_t num);
virtual int32_t PushSparse(const uint64_t* keys, int32_t PushSparse(const uint64_t* keys, const float** values, size_t num);
const float* values,
size_t num);
int32_t Flush() override { return 0; } int32_t Flush() override { return 0; }
virtual int32_t Shrink(const std::string& param) override; virtual int32_t Shrink(const std::string& param) override;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册