diff --git a/paddle/fluid/distributed/ps/table/ssd_sparse_table.cc b/paddle/fluid/distributed/ps/table/ssd_sparse_table.cc index 071a1703e2a6d807c7eca4466fd79f8cc6eb5f7a..3e0f631ed41bcf2582b32d9fe834aac26aaf14ec 100644 --- a/paddle/fluid/distributed/ps/table/ssd_sparse_table.cc +++ b/paddle/fluid/distributed/ps/table/ssd_sparse_table.cc @@ -39,6 +39,33 @@ int32_t SSDSparseTable::Initialize() { int32_t SSDSparseTable::InitializeShard() { return 0; } +int32_t SSDSparseTable::Pull(TableContext& context) { /* Pull entry point; CHECKs that context.value_type is Sparse. When context.use_ptr is set it forwards pull_context.ptr_values, pull_context.keys and context.num to PullSparsePtr; otherwise it forwards pull_context.values with pull_value.feasigns_ and pull_value.numel_ to PullSparse. Returns whatever the selected call returns. */ + CHECK(context.value_type == Sparse); + if (context.use_ptr) { + char** pull_values = context.pull_context.ptr_values; + const uint64_t* keys = context.pull_context.keys; + return PullSparsePtr(pull_values, keys, context.num); + } else { + float* pull_values = context.pull_context.values; + const PullSparseValue& pull_value = context.pull_context.pull_value; + return PullSparse(pull_values, pull_value.feasigns_, pull_value.numel_); + } +} + +int32_t SSDSparseTable::Push(TableContext& context) { /* Push entry point; CHECKs that context.value_type is Sparse. When context.use_ptr is set it passes push_context.keys, push_context.ptr_values and context.num to PushSparse (presumably resolving to the pointer-per-key overload; confirm against the type of ptr_values); otherwise it passes push_context.keys, push_context.values and context.num. */ + CHECK(context.value_type == Sparse); + if (context.use_ptr) { + return PushSparse(context.push_context.keys, + context.push_context.ptr_values, + context.num); + } else { + const uint64_t* keys = context.push_context.keys; + const float* values = context.push_context.values; + size_t num = context.num; + return PushSparse(keys, values, num); + } +} + int32_t SSDSparseTable::PullSparse(float* pull_values, const uint64_t* keys, size_t num) { @@ -73,7 +100,7 @@ int32_t SSDSparseTable::PullSparse(float* pull_values, &missed_keys]() -> int { auto& keys = task_keys[shard_id]; auto& local_shard = _local_shards[shard_id]; - float data_buffer[value_size]; + float data_buffer[value_size]; // NOLINT float* data_buffer_ptr = data_buffer; for (size_t i = 0; i < keys.size(); ++i) { uint64_t key = keys[i].first; @@ -83,7 +110,7 @@ int32_t SSDSparseTable::PullSparse(float* pull_values, // pull rocksdb std::string tmp_string(""); if (_db->get(shard_id, - (char*)&key, + 
reinterpret_cast(&key), sizeof(uint64_t), tmp_string) > 0) { ++missed_keys; @@ -110,7 +137,9 @@ int32_t SSDSparseTable::PullSparse(float* pull_values, memcpy(const_cast(feature_value.data()), data_buffer_ptr, data_size * sizeof(float)); - _db->del_data(shard_id, (char*)&key, sizeof(uint64_t)); + _db->del_data(shard_id, + reinterpret_cast(&key), + sizeof(uint64_t)); } } else { data_size = itr.value().size(); @@ -142,6 +171,95 @@ int32_t SSDSparseTable::PullSparse(float* pull_values, return 0; } +int32_t SSDSparseTable::PullSparsePtr(char** pull_values, + const uint64_t* keys, + size_t num) { /* Pointer-returning pull: for each input position i, stores in pull_values[i] the address of the in-memory FixedFeatureValue for keys[i]. Keys are grouped by shard_id = (key % _sparse_table_shard_num) % _avg_local_shard_num and one task per shard is enqueued on _shards_task_pool; all tasks are waited on before returning. A key missing from the local shard is looked up with _db->get: when get returns a value greater than 0 it is counted in missed_keys and a value of (value_size - mf_value_size) floats is created through the accessor; otherwise the stored bytes are copied into the local shard and the entry is removed from _db with del_data. Always returns 0. */ + CostTimer timer("pserver_ssd_sparse_select_all"); + size_t value_size = _value_accesor->GetAccessorInfo().size / sizeof(float); + size_t mf_value_size = + _value_accesor->GetAccessorInfo().mf_size / sizeof(float); + + { // fetch the value from the table, or create it + std::vector> tasks(_real_local_shard_num); + std::vector>> task_keys( + _real_local_shard_num); + for (size_t i = 0; i < num; ++i) { + int shard_id = (keys[i] % _sparse_table_shard_num) % _avg_local_shard_num; + task_keys[shard_id].push_back({keys[i], i}); + } + + std::atomic missed_keys{0}; + for (int shard_id = 0; shard_id < _real_local_shard_num; ++shard_id) { + tasks[shard_id] = + _shards_task_pool[shard_id % _shards_task_pool.size()]->enqueue( + [this, + shard_id, + &task_keys, + value_size, + mf_value_size, + pull_values, + &missed_keys]() -> int { + auto& keys = task_keys[shard_id]; + auto& local_shard = _local_shards[shard_id]; + float data_buffer[value_size]; // NOLINT + float* data_buffer_ptr = data_buffer; + for (size_t i = 0; i < keys.size(); ++i) { + uint64_t key = keys[i].first; + auto itr = local_shard.find(key); + size_t data_size = value_size - mf_value_size; + FixedFeatureValue* ret = NULL; + if (itr == local_shard.end()) { + // pull rocksdb + std::string tmp_string(""); + if (_db->get(shard_id, + reinterpret_cast(&key), + sizeof(uint64_t), + tmp_string) > 0) { + ++missed_keys; + auto& feature_value = local_shard[key]; 
+ feature_value.resize(data_size); + float* data_ptr = + const_cast(feature_value.data()); + _value_accesor->Create(&data_buffer_ptr, 1); + memcpy( + data_ptr, data_buffer_ptr, data_size * sizeof(float)); + ret = &feature_value; + } else { + data_size = tmp_string.size() / sizeof(float); + memcpy(data_buffer_ptr, + paddle::string::str_to_float(tmp_string), + data_size * sizeof(float)); + // from rocksdb to mem + auto& feature_value = local_shard[key]; + feature_value.resize(data_size); + memcpy(const_cast(feature_value.data()), + data_buffer_ptr, + data_size * sizeof(float)); + _db->del_data(shard_id, + reinterpret_cast(&key), + sizeof(uint64_t)); + ret = &feature_value; + } + } else { + ret = itr.value_ptr(); + } + int pull_data_idx = keys[i].second; + pull_values[pull_data_idx] = reinterpret_cast(ret); + } + return 0; + }); + } + for (int i = 0; i < _real_local_shard_num; ++i) { + tasks[i].wait(); + } + if (FLAGS_pserver_print_missed_key_num_every_push) { + LOG(WARNING) << "total pull keys:" << num + << " missed_keys:" << missed_keys.load(); + } + } + return 0; +} + int32_t SSDSparseTable::PushSparse(const uint64_t* keys, const float* values, size_t num) { @@ -172,7 +290,7 @@ int32_t SSDSparseTable::PushSparse(const uint64_t* keys, &task_keys]() -> int { auto& keys = task_keys[shard_id]; auto& local_shard = _local_shards[shard_id]; - float data_buffer[value_col]; + float data_buffer[value_col]; // NOLINT float* data_buffer_ptr = data_buffer; for (size_t i = 0; i < keys.size(); ++i) { uint64_t key = keys[i].first; @@ -201,7 +319,8 @@ int32_t SSDSparseTable::PushSparse(const uint64_t* keys, if (value_size == value_col) { // 已拓展到最大size, 则就地update _value_accesor->Update(&value_data, &update_data, 1); - } else { // 拷入buffer区进行update,然后再回填,不需要的mf则回填时抛弃了 + } else { + // copy into the buffer to update, then write back; mf data that is not needed is dropped on write-back memcpy(data_buffer_ptr, value_data, value_size * sizeof(float)); @@ -247,6 +366,90 @@ int32_t SSDSparseTable::PushSparse(const uint64_t* keys, return 0; } +int32_t 
SSDSparseTable::PushSparse(const uint64_t* keys, + const float** values, + size_t num) { /* Pointer-per-key overload: values[i] points to the update data for keys[i]. Keys are grouped by shard_id = (key % _sparse_table_shard_num) % _avg_local_shard_num and one task per shard is enqueued on _shards_task_pool; all tasks are waited on before returning. A key missing from the local shard is created through the accessor with (value_col - mf_value_col) floats, unless FLAGS_pserver_enable_create_feasign_randomly is set and CreateValue(1, update_data) returns false, in which case the key is skipped. A value that already holds value_col floats is updated in place; otherwise it is updated through a scratch buffer and resized to value_col when NeedExtendMF reports true. update_value_col is computed and captured but not otherwise used here. Always returns 0. */ + CostTimer timer("pserver_downpour_sparse_update_all"); + // build the data pointers for value and push_value + size_t value_col = _value_accesor->GetAccessorInfo().size / sizeof(float); + size_t mf_value_col = + _value_accesor->GetAccessorInfo().mf_size / sizeof(float); + size_t update_value_col = + _value_accesor->GetAccessorInfo().update_size / sizeof(float); + { + std::vector> tasks(_real_local_shard_num); + std::vector>> task_keys( + _real_local_shard_num); + for (size_t i = 0; i < num; ++i) { + int shard_id = (keys[i] % _sparse_table_shard_num) % _avg_local_shard_num; + task_keys[shard_id].push_back({keys[i], i}); + } + for (int shard_id = 0; shard_id < _real_local_shard_num; ++shard_id) { + tasks[shard_id] = + _shards_task_pool[shard_id % _shards_task_pool.size()]->enqueue( + [this, + shard_id, + value_col, + mf_value_col, + update_value_col, + values, + &task_keys]() -> int { + auto& keys = task_keys[shard_id]; + auto& local_shard = _local_shards[shard_id]; + float data_buffer[value_col]; // NOLINT + float* data_buffer_ptr = data_buffer; + for (size_t i = 0; i < keys.size(); ++i) { + uint64_t key = keys[i].first; + uint64_t push_data_idx = keys[i].second; + const float* update_data = values[push_data_idx]; + auto itr = local_shard.find(key); + if (itr == local_shard.end()) { + if (FLAGS_pserver_enable_create_feasign_randomly && + !_value_accesor->CreateValue(1, update_data)) { + continue; + } + auto value_size = value_col - mf_value_col; + auto& feature_value = local_shard[key]; + feature_value.resize(value_size); + _value_accesor->Create(&data_buffer_ptr, 1); + memcpy(const_cast(feature_value.data()), + data_buffer_ptr, + value_size * sizeof(float)); + itr = local_shard.find(key); + } + auto& feature_value = itr.value(); + float* value_data = const_cast(feature_value.data()); + size_t value_size = feature_value.size(); + + if (value_size == + value_col) { // already extended to the maximum size, so update in place + 
_value_accesor->Update(&value_data, &update_data, 1); + } else { + // copy into the buffer to update, then write back; mf data that is not needed is dropped on write-back + memcpy(data_buffer_ptr, + value_data, + value_size * sizeof(float)); + _value_accesor->Update(&data_buffer_ptr, &update_data, 1); + if (_value_accesor->NeedExtendMF(data_buffer)) { + feature_value.resize(value_col); + value_data = const_cast(feature_value.data()); + _value_accesor->Create(&value_data, 1); + } + memcpy(value_data, + data_buffer_ptr, + value_size * sizeof(float)); + } + } + return 0; + }); + } + for (int i = 0; i < _real_local_shard_num; ++i) { + tasks[i].wait(); + } + } + return 0; +} + int32_t SSDSparseTable::Shrink(const std::string& param) { int thread_num = _real_local_shard_num < 20 ? _real_local_shard_num : 20; omp_set_num_threads(thread_num); @@ -282,7 +485,7 @@ int32_t SSDSparseTable::Shrink(const std::string& param) { delete it; LOG(INFO) << "SSDSparseTable shrink success. shard:" << i << " delete MEM[" << mem_count << "] SSD[" << ssd_count << "]"; - //_db->flush(i); + // _db->flush(i); } return 0; } diff --git a/paddle/fluid/distributed/ps/table/ssd_sparse_table.h b/paddle/fluid/distributed/ps/table/ssd_sparse_table.h index 5b38e4b3d73f75c28fe6e20ebaa01fd04d28dc1e..55a05bbab5ec2491af9393b50680e8789921dc97 100644 --- a/paddle/fluid/distributed/ps/table/ssd_sparse_table.h +++ b/paddle/fluid/distributed/ps/table/ssd_sparse_table.h @@ -33,26 +33,14 @@ class SSDSparseTable : public MemorySparseTable { // exchange data int32_t UpdateTable(); - int32_t Pull(TableContext& context) override { - CHECK(context.value_type == Sparse); - float* pull_values = context.pull_context.values; - const PullSparseValue& pull_value = context.pull_context.pull_value; - return PullSparse(pull_values, pull_value.feasigns_, pull_value.numel_); - } + int32_t Pull(TableContext& context) override; /* body is defined in ssd_sparse_table.cc */ - int32_t Push(TableContext& context) override { - const uint64_t* keys = context.push_context.keys; - const float* values = context.push_context.values; - size_t num 
= context.num; - return PushSparse(keys, values, num); - } + int32_t Push(TableContext& context) override; /* body is defined in ssd_sparse_table.cc */ - virtual int32_t PullSparse(float* pull_values, - const uint64_t* keys, - size_t num); - virtual int32_t PushSparse(const uint64_t* keys, - const float* values, - size_t num); + int32_t PullSparse(float* pull_values, const uint64_t* keys, size_t num); + int32_t PullSparsePtr(char** pull_values, const uint64_t* keys, size_t num); /* writes one value pointer per key into pull_values */ + int32_t PushSparse(const uint64_t* keys, const float* values, size_t num); + int32_t PushSparse(const uint64_t* keys, const float** values, size_t num); /* overload where values[i] points to the update data for keys[i] */ int32_t Flush() override { return 0; } virtual int32_t Shrink(const std::string& param) override;