From 9b3b53ba4f7080ab012d870a84a7b0ea96cfc4c9 Mon Sep 17 00:00:00 2001 From: zhaocaibei123 <48509226+zhaocaibei123@users.noreply.github.com> Date: Sun, 30 Jan 2022 20:09:56 +0800 Subject: [PATCH] geo memory sparse table (#39250) * geo depends * add memory geo table * fix --- .../distributed/ps/service/brpc_ps_client.cc | 146 ++++++++++-- .../distributed/ps/service/brpc_ps_client.h | 4 + .../ps/service/communicator/communicator.cc | 32 ++- .../ps/service/communicator/communicator.h | 6 +- .../fluid/distributed/ps/service/ps_client.h | 11 + .../fluid/distributed/ps/table/CMakeLists.txt | 5 +- .../ps/table/depends/geo_recorder.h | 4 - .../ps/table/memory_sparse_geo_table.cc | 220 ++++++++++++++++++ .../ps/table/memory_sparse_geo_table.h | 78 +++++++ paddle/fluid/distributed/ps/table/table.cc | 2 + paddle/fluid/distributed/test/CMakeLists.txt | 3 + .../distributed/test/memory_geo_table_test.cc | 123 ++++++++++ .../distributed/fleet/runtime/the_one_ps.py | 3 +- 13 files changed, 604 insertions(+), 33 deletions(-) create mode 100644 paddle/fluid/distributed/ps/table/memory_sparse_geo_table.cc create mode 100644 paddle/fluid/distributed/ps/table/memory_sparse_geo_table.h create mode 100644 paddle/fluid/distributed/test/memory_geo_table_test.cc diff --git a/paddle/fluid/distributed/ps/service/brpc_ps_client.cc b/paddle/fluid/distributed/ps/service/brpc_ps_client.cc index e855fcbd02..301136794d 100644 --- a/paddle/fluid/distributed/ps/service/brpc_ps_client.cc +++ b/paddle/fluid/distributed/ps/service/brpc_ps_client.cc @@ -213,6 +213,7 @@ int32_t BrpcPsClient::initialize() { auto &profiler = CostProfiler::instance(); profiler.register_profiler("pserver_client_pull_dense"); profiler.register_profiler("pserver_client_pull_sparse"); + profiler.register_profiler("pserver_client_pull_sparse_param"); profiler.register_profiler("pserver_client_pull_sparse_local"); profiler.register_profiler("pserver_client_push_sparse"); profiler.register_profiler("pserver_client_push_sparse_parse"); @@ -543,6 +544,7 @@ std::future BrpcPsClient::pull_geo_param(size_t table_id, return fut; } +// for GEO std::future BrpcPsClient::push_sparse_param( size_t table_id, const uint64_t *keys, const float **update_values, size_t num, void *done) { @@ -558,18 +560,8 @@ std::future BrpcPsClient::push_sparse_param( ids.resize(request_call_num); value_ptrs.resize(request_call_num); - const auto &server_param = _config.server_param().downpour_server_param(); - uint64_t shard_num = FLAGS_pserver_sparse_table_shard_num; - for (int i = 0; i < server_param.downpour_table_param_size(); ++i) { - const auto &table_param = server_param.downpour_table_param(i); - if (table_param.table_id() == table_id) { - shard_num = table_param.shard_num(); - break; - } - } - for (size_t i = 0; i < num; ++i) { - size_t pserver_idx = get_sparse_shard(shard_num, request_call_num, keys[i]); + size_t pserver_idx = keys[i] % request_call_num; ids[pserver_idx].push_back(keys[i]); value_ptrs[pserver_idx].push_back(update_values[i]); } @@ -1003,6 +995,120 @@ std::future BrpcPsClient::pull_sparse(float **select_values, return fut; } +// for GEO +std::future BrpcPsClient::pull_sparse_param(float **select_values, + size_t table_id, + const uint64_t *keys, + size_t num, + bool is_training) { + auto timer = std::make_shared("pserver_client_pull_sparse_param"); + size_t request_call_num = _server_channels.size(); + + auto shard_sorted_kvs = std::make_shared< + std::vector>>>(); + shard_sorted_kvs->resize(request_call_num); + + for (size_t i = 0; i < num; ++i) { + size_t shard_id = keys[i] % request_call_num; + shard_sorted_kvs->at(shard_id).push_back({keys[i], select_values[i]}); + } + + auto *accessor = table_accessor(table_id); + size_t value_size = accessor->select_size(); + + DownpourBrpcClosure *closure = new DownpourBrpcClosure( + request_call_num, [shard_sorted_kvs, value_size](void *done) { + int ret = 0; + auto *closure = reinterpret_cast(done); + for (size_t i = 0; i < shard_sorted_kvs->size(); ++i) { + if (closure->check_response(i, PS_PULL_SPARSE_TABLE) != 0) { + ret = -1; + break; + } + + auto &request_kvs = shard_sorted_kvs->at(i); + auto &res_io_buffer = closure->cntl(i)->response_attachment(); + butil::IOBufBytesIterator io_buffer_itr(res_io_buffer); + uint64_t last_key = UINT64_MAX; + float *last_value_data = NULL; + + // can remove sort&unique + for (size_t kv_idx = 0; kv_idx < request_kvs.size(); ++kv_idx) { + auto *kv_pair = &(request_kvs[kv_idx]); + if (kv_pair->first == last_key) { + memcpy(reinterpret_cast(kv_pair->second), + reinterpret_cast(last_value_data), value_size); + } else { + last_key = kv_pair->first; + last_value_data = kv_pair->second; + if (value_size != + io_buffer_itr.copy_and_forward( + reinterpret_cast(last_value_data), value_size)) { + LOG(WARNING) << "res data is lack or not in format"; + ret = -1; + break; + } + } + } + } + closure->set_promise_value(ret); + }); + closure->add_timer(timer); + auto promise = std::make_shared>(); + closure->add_promise(promise); + std::future fut = promise->get_future(); + + for (size_t i = 0; i < request_call_num; ++i) { + auto &sorted_kvs = shard_sorted_kvs->at(i); + std::sort(sorted_kvs.begin(), sorted_kvs.end(), + [](const std::pair &k1, + const std::pair &k2) { + return k1.first < k2.first; + }); + + uint64_t last_key = UINT64_MAX; + uint32_t kv_request_count = 0; + size_t sorted_kv_size = sorted_kvs.size(); + auto &request_buffer = closure->cntl(i)->request_attachment(); + + request_buffer.append(reinterpret_cast(&is_training), sizeof(bool)); + std::vector keys_counter; + keys_counter.reserve(sorted_kv_size); + + for (size_t kv_idx = 0; kv_idx < sorted_kv_size; ++kv_idx) { + ++kv_request_count; + uint32_t keys = 1; + last_key = sorted_kvs[kv_idx].first; + request_buffer.append(reinterpret_cast(&last_key), + sizeof(uint64_t)); + while (kv_idx < sorted_kv_size - 1 && + last_key == sorted_kvs[kv_idx + 1].first) { + ++kv_idx; + ++keys; + } + keys_counter.push_back(keys); + } + + request_buffer.append(reinterpret_cast(keys_counter.data()), + sizeof(uint32_t) * keys_counter.size()); + + if (kv_request_count == 0) { + closure->Run(); + } else { + closure->request(i)->set_cmd_id(PS_PULL_SPARSE_TABLE); + closure->request(i)->set_table_id(table_id); + closure->request(i)->set_client_id(_client_id); + closure->request(i)->add_params((char *)&kv_request_count, // NOLINT + sizeof(uint32_t)); + PsService_Stub rpc_stub(get_cmd_channel(i)); + closure->cntl(i)->set_log_id(butil::gettimeofday_ms()); + rpc_stub.service(closure->cntl(i), closure->request(i), + closure->response(i), closure); + } + } + return fut; +} + std::future BrpcPsClient::send_client2client_msg( int msg_type, int to_client_id, const std::string &msg) { auto promise = std::make_shared>(); @@ -1067,12 +1173,14 @@ int32_t BrpcPsClient::recv_and_save_table(const uint64_t table_id, std::string var_name = ""; int64_t var_num = 0; int64_t var_shape = 0; + std::string table_class; const auto &worker_param = _config.worker_param().downpour_worker_param(); for (size_t i = 0; i < worker_param.downpour_table_param_size(); ++i) { if (worker_param.downpour_table_param(i).table_id() == table_id) { var_name = worker_param.downpour_table_param(i).common().table_name(); var_num = worker_param.downpour_table_param(i).common().table_num(); var_shape = worker_param.downpour_table_param(i).common().table_dim(); + table_class = worker_param.downpour_table_param(i).table_class(); break; } } @@ -1094,9 +1202,19 @@ int32_t BrpcPsClient::recv_and_save_table(const uint64_t table_id, save_vec.push_back(save_huge_vec.data() + i * var_shape); } - auto status = pull_sparse(reinterpret_cast(save_vec.data()), - table_id, save_key.data(), save_key.size(), true); - status.wait(); + VLOG(2) << "recv_and_save_table: table_class: " << table_class; + // TODO(zhaocaibei123): new GeoBrpcPSClient, move this to its + // recv_and_save_table + if (table_class == "MemorySparseGeoTable") { + auto status = + pull_sparse_param(reinterpret_cast(save_vec.data()), table_id, + save_key.data(), save_key.size(), true); + status.wait(); + } else { + auto status = pull_sparse(reinterpret_cast(save_vec.data()), + table_id, save_key.data(), save_key.size(), true); + status.wait(); + } // create lod tensor std::shared_ptr scope; diff --git a/paddle/fluid/distributed/ps/service/brpc_ps_client.h b/paddle/fluid/distributed/ps/service/brpc_ps_client.h index 70f406ee24..59ed59933d 100644 --- a/paddle/fluid/distributed/ps/service/brpc_ps_client.h +++ b/paddle/fluid/distributed/ps/service/brpc_ps_client.h @@ -194,6 +194,10 @@ class BrpcPsClient : public PSClient { size_t table_id, const uint64_t *keys, size_t num, bool is_training); + virtual std::future pull_sparse_param(float **select_values, + size_t table_id, + const uint64_t *keys, + size_t num, bool is_training); virtual std::future print_table_stat(uint32_t table_id); diff --git a/paddle/fluid/distributed/ps/service/communicator/communicator.cc b/paddle/fluid/distributed/ps/service/communicator/communicator.cc index a73f87c1d8..3f1667e534 100644 --- a/paddle/fluid/distributed/ps/service/communicator/communicator.cc +++ b/paddle/fluid/distributed/ps/service/communicator/communicator.cc @@ -354,7 +354,7 @@ void Communicator::RpcRecvSparse(const std::string &varname, int table_id, bool training = true; - auto status = _worker_ptr->pull_sparse( + auto status = _worker_ptr->pull_sparse_param( (float **)push_g_vec.data(), table_id, // NOLINT sparse_push_keys.data(), sparse_push_keys.size(), training); status.wait(); @@ -1029,7 +1029,7 @@ void GeoCommunicator::Send(const std::vector &var_names, auto &sparse_ids_set = iter.second; auto sparse_ids_vec = std::make_shared>(); sparse_ids_vec->assign(sparse_ids_set.begin(), sparse_ids_set.end()); - sparse_id_queues_.at(key)->Push(sparse_ids_vec); + sparse_id_queues_.at(key)->Put(sparse_ids_vec); VLOG(3) << "push " << sparse_ids_vec->size() << " ids to " << key << "'s queue"; } @@ -1051,7 +1051,10 @@ void GeoCommunicator::InitImpl(const RpcCtxMap &send_varname_to_ctx, for (auto &iter : send_varname_to_ctx_) { auto &ctx = iter.second; - if (!ctx.is_sparse) continue; + if (!ctx.is_sparse) { + parallel_task_nums_ += 1; + continue; + } auto &varnames = ctx.origin_varnames; PADDLE_ENFORCE_EQ( varnames.size(), 1, @@ -1060,12 +1063,11 @@ void GeoCommunicator::InitImpl(const RpcCtxMap &send_varname_to_ctx, for (auto &splited_var : ctx.splited_varnames) { parallel_task_nums_ += 1; sparse_id_queues_.insert( - std::pair>>>>( + std::pair>>>( splited_var, - std::make_shared< - BlockingQueue>>>( - send_queue_size_))); + paddle::framework::MakeChannel< + std::shared_ptr>>(send_queue_size_))); } } @@ -1242,8 +1244,8 @@ std::vector GeoCommunicator::MergeSparseIds( VLOG(3) << "Merge Number of " << send_varname << " = " << merge_num; if (sparse_id_queues_.at(send_varname)->Size() > 0) { wait_times = 0; - std::shared_ptr> pop_ids = - sparse_id_queues_.at(send_varname)->Pop(); + std::shared_ptr> pop_ids = nullptr; + sparse_id_queues_.at(send_varname)->Get(pop_ids); for (size_t j = 0; j < pop_ids->size(); j++) { sparse_ids.insert(pop_ids->at(j)); } @@ -1268,6 +1270,9 @@ void GeoCommunicator::SendSparse(const std::string &varname, std::vector &sparse_ids, int table_id, int ep_idx) { platform::RecordEvent record_event("GeoCommunicator->SendSparse"); + if (sparse_ids.size() == 0) { + return; + } std::string param_name = SplitedGradToParam(varname); VLOG(1) << "In GeoCommunicator::SendSparse(" << varname << " " << param_name << ", ids.size = " << sparse_ids.size() << ", table_id: " << table_id @@ -1313,6 +1318,10 @@ void GeoCommunicator::SendSparse(const std::string &varname, t_value + j * dims1, t_old->data() + sparse_ids[j] * dims1); push_g_vec.push_back(t_value + j * dims1); + + VLOG(5) << "DEBUG GeoCommunicator::SendSparse send sparse key " + << sparse_ids[j] << " value[0] " << push_g_vec[j][0] + << " value[-1] " << push_g_vec[j][dims1 - 1]; } ++_async_call_num; @@ -1367,6 +1376,9 @@ void GeoCommunicator::RecvSparse(const std::string &varname, int table_id, cpu_ctx); for (auto j = 0; j < static_cast(keys.size()); ++j) { + VLOG(5) << "DEBUG GeoCommunicator::RecvSparse recv sparse key" << keys[j] + << "value[0] " << values[j * dims1] << " value[-1] " + << values[j * dims1 + dims1 - 1]; float *latest_data = t_latest->data() + keys[j] * dims1; float *old_data = t_old->data() + keys[j] * dims1; // pserver - old => delta diff --git a/paddle/fluid/distributed/ps/service/communicator/communicator.h b/paddle/fluid/distributed/ps/service/communicator/communicator.h index 570e668d9d..c63f341607 100644 --- a/paddle/fluid/distributed/ps/service/communicator/communicator.h +++ b/paddle/fluid/distributed/ps/service/communicator/communicator.h @@ -30,6 +30,7 @@ limitations under the License. */ #include "gflags/gflags.h" #include "paddle/fluid/distributed/ps/service/communicator/communicator_common.h" +#include "paddle/fluid/framework/channel.h" #include "paddle/fluid/framework/scope.h" #include "paddle/fluid/framework/variable.h" #include "paddle/fluid/framework/variable_helper.h" @@ -626,9 +627,8 @@ class GeoCommunicator : public AsyncCommunicator { // parameter on pserver std::shared_ptr pserver_scope_; - std::unordered_map< - std::string, - std::shared_ptr>>>> + std::unordered_map>>> sparse_id_queues_; }; diff --git a/paddle/fluid/distributed/ps/service/ps_client.h b/paddle/fluid/distributed/ps/service/ps_client.h index 7db8b0c124..21719fbdbf 100644 --- a/paddle/fluid/distributed/ps/service/ps_client.h +++ b/paddle/fluid/distributed/ps/service/ps_client.h @@ -128,6 +128,17 @@ class PSClient { const uint64_t *keys, size_t num, bool is_training) = 0; + virtual std::future pull_sparse_param(float **select_values, + size_t table_id, + const uint64_t *keys, + size_t num, bool is_training) { + VLOG(0) << "Did not implement"; + std::promise promise; + std::future fut = promise.get_future(); + promise.set_value(-1); + return fut; + } + virtual ::std::future pull_sparse_ptr(char **select_values, size_t table_id, const uint64_t *keys, diff --git a/paddle/fluid/distributed/ps/table/CMakeLists.txt b/paddle/fluid/distributed/ps/table/CMakeLists.txt index b0a553f210..9aa9ecc2af 100644 --- a/paddle/fluid/distributed/ps/table/CMakeLists.txt +++ b/paddle/fluid/distributed/ps/table/CMakeLists.txt @@ -47,6 +47,9 @@ cc_library(sparse_sgd_rule SRCS sparse_sgd_rule.cc DEPS ${TABLE_DEPS} ps_framewo cc_library(ctr_accessor SRCS ctr_accessor.cc DEPS ${TABLE_DEPS} ps_framework_proto sparse_sgd_rule) cc_library(memory_sparse_table SRCS memory_sparse_table.cc DEPS ps_framework_proto ${TABLE_DEPS} fs afs_wrapper ctr_accessor common_table) -cc_library(table SRCS table.cc DEPS memory_sparse_table common_table tensor_accessor tensor_table ps_framework_proto string_helper device_context gflags glog boost) +set_source_files_properties(memory_sparse_geo_table.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) +cc_library(memory_sparse_geo_table SRCS memory_sparse_geo_table.cc DEPS ps_framework_proto ${TABLE_DEPS} common_table) + +cc_library(table SRCS table.cc DEPS memory_sparse_table memory_sparse_geo_table common_table tensor_accessor tensor_table ps_framework_proto string_helper device_context gflags glog boost) target_link_libraries(table -fopenmp) diff --git a/paddle/fluid/distributed/ps/table/depends/geo_recorder.h b/paddle/fluid/distributed/ps/table/depends/geo_recorder.h index ad094f0dfb..adab0ee344 100644 --- a/paddle/fluid/distributed/ps/table/depends/geo_recorder.h +++ b/paddle/fluid/distributed/ps/table/depends/geo_recorder.h @@ -15,13 +15,9 @@ #pragma once #include -#include #include // NOLINT #include -#include -#include #include -#include #include namespace paddle { diff --git a/paddle/fluid/distributed/ps/table/memory_sparse_geo_table.cc b/paddle/fluid/distributed/ps/table/memory_sparse_geo_table.cc new file mode 100644 index 0000000000..f16f4fc7f3 --- /dev/null +++ b/paddle/fluid/distributed/ps/table/memory_sparse_geo_table.cc @@ -0,0 +1,220 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/distributed/ps/table/memory_sparse_geo_table.h" + +namespace paddle { +namespace distributed { + +int32_t MemorySparseGeoTable::push_sparse_param(const uint64_t* keys, + const float* values, + size_t num) { + VLOG(5) << "DEBUG MemorySparseGeoTable::push_sparse_param begin " + "push_sparse_param " + << num; + auto shard_num = _task_pool_size; + std::vector> offset_bucket; + offset_bucket.resize(shard_num); + + for (int x = 0; x < num; ++x) { + auto y = keys[x] % shard_num; + offset_bucket[y].push_back(x); + if (x < 10) { + VLOG(5) << "DEBUG MemorySparseGeoTable::push_sparse_param key: " + << keys[x] << " shard: " << y; + } + } + + std::vector> tasks(shard_num); + + for (int shard_id = 0; shard_id < shard_num; ++shard_id) { + tasks[shard_id] = _shards_task_pool[shard_id]->enqueue( + [this, shard_id, &keys, &offset_bucket, &values]() -> int { + auto& local_shard = _local_shards[shard_id]; + auto& offsets = offset_bucket[shard_id]; + + for (int i = 0; i < offsets.size(); ++i) { + auto offset = offsets[i]; + auto id = keys[offset]; + auto& feature_value = local_shard[id]; + feature_value.resize(_dim); + std::copy_n(values + _dim * offset, _dim, feature_value.data()); + if (i < 10) { + VLOG(5) << "MemorySparseGeoTable::push_sparse_param " + "push_sparse_param key " + << id << " value[0]: " << (values + _dim * offset)[0] + << " data: " << feature_value.data()[0] + << " value[-1]: " << (values + _dim * offset)[_dim - 1] + << " data: " << feature_value.data()[_dim - 1]; + } + } + return 0; + }); + } + + for (size_t shard_id = 0; shard_id < tasks.size(); ++shard_id) { + tasks[shard_id].wait(); + } + return 0; +} + +int32_t MemorySparseGeoTable::pull_geo_param(const uint32_t trainer_id, + std::vector* values, + std::vector* ids) { + _geo_recorder->GetAndClear(trainer_id, ids); + VLOG(5) + << "DEBUG MemorySparseGeoTable::pull_geo_param pull_geo_param trainer_id " + << trainer_id << " id_num: " << ids->size(); + + std::vector frequencies; + frequencies.resize(ids->size(), 1); + + auto pull_value = PullSparseValue(ids->size(), _dim); + pull_value.is_training_ = true; + pull_value.feasigns_ = ids->data(); + pull_value.frequencies_ = frequencies.data(); + + values->resize(ids->size() * _dim); + pull_sparse(values->data(), pull_value); + return 0; +} + +int32_t MemorySparseGeoTable::push_sparse(const uint64_t* keys, + const float* values, size_t num) { + VLOG(5) << "DEBUG MemorySparseGeoTable::push_sparse keys[0]" << keys[0] + << " key_num: " << num; + std::vector ids; + ids.resize(num); + std::copy_n(keys, num, ids.begin()); + _geo_recorder->Update(ids); + _push_sparse(keys, values, num); + return 0; +} + +int32_t MemorySparseGeoTable::initialize() { + if (!_geo_recorder) { + auto trainers = _config.common().trainer_num(); + _geo_recorder = std::make_shared(trainers); + } + + _dim = _config.common().dims()[0]; + _shards_task_pool.resize(_task_pool_size); + for (int i = 0; i < _shards_task_pool.size(); ++i) { + _shards_task_pool[i].reset(new ::ThreadPool(1)); + } + + _local_shards.reset(new shard_type[_task_pool_size]); + return 0; +} + +int32_t MemorySparseGeoTable::pull_sparse(float* pull_values, + const PullSparseValue& pull_value) { + auto shard_num = _task_pool_size; + std::vector> tasks(shard_num); + + std::vector>> task_keys(shard_num); + size_t num = pull_value.numel_; + for (size_t i = 0; i < num; ++i) { + int shard_id = pull_value.feasigns_[i] % shard_num; + task_keys[shard_id].push_back({pull_value.feasigns_[i], i}); + } + + for (int shard_id = 0; shard_id < shard_num; ++shard_id) { + tasks[shard_id] = _shards_task_pool[shard_id]->enqueue( + [this, shard_id, &task_keys, pull_values]() -> int { + auto& local_shard = _local_shards[shard_id]; + auto& keys = task_keys[shard_id]; + for (size_t i = 0; i < keys.size(); i++) { + uint64_t key = keys[i].first; + auto offset = keys[i].second; + float* select_data = pull_values + _dim * offset; + + auto itr = local_shard.find(key); + if (itr == local_shard.end()) { + // ++missed_keys; + auto& feature_value = local_shard[key]; + feature_value.resize(_dim); + memset(feature_value.data(), 0, sizeof(float) * _dim); + VLOG(0) << "MemorySparseGeoTable pull_sparse key not found!!! " + << key; + itr = local_shard.find(key); + } + memcpy(select_data, itr.value().data(), _dim * sizeof(float)); + + VLOG(5) << "DEBUG MemorySparseGeoTable::pull_sparse key: " << key + << " select_data[0] " << select_data[0] + << " value[0]: " << itr.value().data()[0]; + } + return 0; + }); + } + + for (size_t shard_id = 0; shard_id < tasks.size(); ++shard_id) { + tasks[shard_id].wait(); + } + + return 0; +} + +int32_t MemorySparseGeoTable::_push_sparse(const uint64_t* keys, + const float* values, size_t num) { + auto shard_num = _task_pool_size; + std::vector> tasks(shard_num); + std::vector>> task_keys(shard_num); + for (size_t i = 0; i < num; ++i) { + int shard_id = keys[i] % shard_num; + task_keys[shard_id].push_back({keys[i], i}); + } + + for (size_t shard_id = 0; shard_id < shard_num; ++shard_id) { + tasks[shard_id] = _shards_task_pool[shard_id]->enqueue( + [this, shard_id, values, &task_keys]() -> int { + auto& keys = task_keys[shard_id]; + auto& local_shard = _local_shards[shard_id]; + auto blas = GetBlas(); + + for (int i = 0; i < keys.size(); ++i) { + uint64_t key = keys[i].first; + uint64_t push_data_idx = keys[i].second; + const float* update_data = values + push_data_idx * _dim; + auto itr = local_shard.find(key); + if (itr == local_shard.end()) { + VLOG(0) << "sparse geo table push not found key!!! " << key; + auto& feature_value = local_shard[key]; + feature_value.resize(_dim); + memset(feature_value.data(), 0, sizeof(float) * _dim); + itr = local_shard.find(key); + } + + auto& feature_value = itr.value(); + float* value_data = feature_value.data(); + VLOG(5) << "DEBUG MemorySparseGeoTable::_push_sparse before key: " + << key << " update_data[0] " << update_data[0] + << " value[0]: " << value_data[0]; + blas.VADD(_dim, update_data, value_data, value_data); + VLOG(5) << "DEBUG MemorySparseGeoTable::_push_sparse after key: " + << key << " value[0]: " << value_data[0]; + } + return 0; + }); + } + + for (size_t shard_id = 0; shard_id < tasks.size(); ++shard_id) { + tasks[shard_id].wait(); + } + return 0; +} + +} // namespace distributed +} // namespace paddle diff --git a/paddle/fluid/distributed/ps/table/memory_sparse_geo_table.h b/paddle/fluid/distributed/ps/table/memory_sparse_geo_table.h new file mode 100644 index 0000000000..89c4fc15ae --- /dev/null +++ b/paddle/fluid/distributed/ps/table/memory_sparse_geo_table.h @@ -0,0 +1,78 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +// #include +#include +#include +#include // NOLINT +#include +#include +#include + +#include "paddle/fluid/distributed/ps/table/accessor.h" +#include "paddle/fluid/distributed/ps/table/common_table.h" +#include "paddle/fluid/distributed/ps/table/depends/feature_value.h" +#include "paddle/fluid/distributed/ps/table/depends/geo_recorder.h" +#include "paddle/fluid/string/string_helper.h" + +namespace paddle { +namespace distributed { + +class GeoRecorder; + +class MemorySparseGeoTable : public SparseTable { + public: + typedef SparseTableShard shard_type; + MemorySparseGeoTable() { _geo_recorder = nullptr; } + virtual ~MemorySparseGeoTable() {} + + virtual int32_t initialize(); + virtual int32_t initialize_shard() { return 0; } + virtual int32_t load(const std::string& path, const std::string& param) { + return 0; + } + virtual int32_t save(const std::string& path, const std::string& param) { + return 0; + } + virtual int32_t flush() { return 0; } + virtual int32_t shrink(const std::string& param) { return 0; } + virtual void clear() { return; } + virtual int32_t pull_sparse(float* values, const PullSparseValue& pull_value); + + int32_t push_sparse_param(const uint64_t* keys, const float* values, + size_t num); + // TODO(zhaocaibei123): change to pull_sparse, and rename pull_sparse + int32_t pull_geo_param(const uint32_t trainer_id, std::vector* values, + std::vector* keys); + + int32_t push_sparse(const uint64_t* keys, const float* values, + size_t num) override; + + int32_t _push_sparse(const uint64_t* keys, const float* values, size_t num); + // int32_t _pull_sparse(float* pull_values, const PullSparseValue& + // pull_value); + + private: + std::shared_ptr _geo_recorder; + const int _task_pool_size = 10; + std::vector> _shards_task_pool; + std::unique_ptr _local_shards; + int _dim; +}; + +} // namespace distributed +} // namespace paddle diff --git a/paddle/fluid/distributed/ps/table/table.cc b/paddle/fluid/distributed/ps/table/table.cc index b9b5ff12fc..fa8169da07 100644 --- a/paddle/fluid/distributed/ps/table/table.cc +++ b/paddle/fluid/distributed/ps/table/table.cc @@ -20,6 +20,7 @@ #include "paddle/fluid/distributed/ps/table/common_dense_table.h" #include "paddle/fluid/distributed/ps/table/common_graph_table.h" #include "paddle/fluid/distributed/ps/table/common_sparse_table.h" +#include "paddle/fluid/distributed/ps/table/memory_sparse_geo_table.h" #include "paddle/fluid/distributed/ps/table/sparse_geo_table.h" #ifdef PADDLE_WITH_HETERPS #include "paddle/fluid/distributed/ps/table/ssd_sparse_table.h" @@ -43,6 +44,7 @@ REGISTER_PSCORE_CLASS(Table, TensorTable); REGISTER_PSCORE_CLASS(Table, DenseTensorTable); REGISTER_PSCORE_CLASS(Table, GlobalStepTable); REGISTER_PSCORE_CLASS(Table, MemorySparseTable); +REGISTER_PSCORE_CLASS(Table, MemorySparseGeoTable); REGISTER_PSCORE_CLASS(ValueAccessor, CommMergeAccessor); REGISTER_PSCORE_CLASS(ValueAccessor, CtrCommonAccessor); REGISTER_PSCORE_CLASS(SparseValueSGDRule, StdAdaGradSGDRule); diff --git a/paddle/fluid/distributed/test/CMakeLists.txt b/paddle/fluid/distributed/test/CMakeLists.txt index 62de82832e..2223334ccc 100644 --- a/paddle/fluid/distributed/test/CMakeLists.txt +++ b/paddle/fluid/distributed/test/CMakeLists.txt @@ -35,3 +35,6 @@ cc_test(ctr_accessor_test SRCS ctr_accessor_test.cc DEPS ${COMMON_DEPS} boost ta set_source_files_properties(memory_sparse_table_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) cc_test(memory_sparse_table_test SRCS memory_sparse_table_test.cc DEPS ${COMMON_DEPS} boost table) + +set_source_files_properties(memory_geo_table_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) +cc_test(memory_sparse_geo_table_test SRCS memory_geo_table_test.cc DEPS ${COMMON_DEPS} boost table) diff --git a/paddle/fluid/distributed/test/memory_geo_table_test.cc b/paddle/fluid/distributed/test/memory_geo_table_test.cc new file mode 100644 index 0000000000..fb48b38c76 --- /dev/null +++ b/paddle/fluid/distributed/test/memory_geo_table_test.cc @@ -0,0 +1,123 @@ +/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include + +#include +#include +#include // NOLINT + +#include "google/protobuf/text_format.h" +#include "gtest/gtest.h" +#include "paddle/fluid/distributed/ps.pb.h" +#include "paddle/fluid/distributed/ps/table/depends/sparse_utils.h" +#include "paddle/fluid/distributed/ps/table/memory_sparse_geo_table.h" +#include "paddle/fluid/distributed/ps/table/table.h" + +namespace paddle { +namespace distributed { + +// MemorySparseGeoTable +TEST(MemorySparseGeoTable, SSUM) { + int emb_dim = 10; + int trainers = 2; + + TableParameter table_config; + table_config.set_table_class("MemorySparseGeoTable"); + FsClientParameter fs_config; + Table *table = new MemorySparseGeoTable(); + TableAccessorParameter *accessor_config = table_config.mutable_accessor(); + accessor_config->set_accessor_class("CommMergeAccessor"); + accessor_config->set_fea_dim(10); + CommonAccessorParameter *common_config = table_config.mutable_common(); + common_config->set_name("sum"); + common_config->set_table_name("ssum_test_table"); + common_config->set_trainer_num(trainers); + common_config->add_params("Param"); + common_config->add_dims(emb_dim); + common_config->add_initializers("fill_constant&1.0"); + + auto ret = table->initialize(table_config, fs_config); + ASSERT_EQ(ret, 0); + + // test push_sparse_param, and create params + std::vector init_keys = {0, 1, 2, 3, 4}; + std::vector init_fres = {1, 1, 1, 1, 1}; + std::vector init_values; + for (size_t i = 0; i < init_keys.size() * emb_dim; i++) { + init_values.push_back(0.0); + } + table->push_sparse_param(init_keys.data(), init_values.data(), + init_keys.size()); + + std::vector pull_values(init_values.size()); + auto value = PullSparseValue(init_keys, init_fres, emb_dim); + table->pull_sparse(pull_values.data(), value); + + for (size_t i = 0; i < init_keys.size() * emb_dim; i++) { + ASSERT_TRUE(abs(pull_values[i] - init_values[i]) < 1e-5); + } + + std::vector> trainer_keys; + std::vector> trainer_values; + trainer_keys.resize(trainers); + trainer_values.resize(trainers); + float start = 0.0; + for (int i = 0; i < trainers; i++) { + trainer_keys[i] = init_keys; + for (size_t j = 0; j < trainer_keys[i].size(); j++) { + auto id = trainer_keys[i][j]; + for (int k = 0; k < emb_dim; k++) { + trainer_values[i].push_back(start); + pull_values[id * emb_dim + k] += start; + start += 0.1; + } + } + } + + std::shared_ptr<::ThreadPool> pool_ = + std::make_shared<::ThreadPool>(trainers); + std::vector> task_status; + for (int i = 0; i < trainers; i++) { + auto &push_keys = trainer_keys[i]; + auto &push_values = trainer_values[i]; + auto task = [table, &push_keys, &push_values] { + table->push_sparse(push_keys.data(), push_values.data(), + push_keys.size()); + }; + task_status.push_back(pool_->enqueue(std::move(task))); + } + for (auto &status : task_status) { + status.wait(); + } + + std::vector> geo_pull_ids; + std::vector> geo_pull_values; + geo_pull_ids.resize(trainers); + geo_pull_values.resize(trainers); + for (int i = 0; i < trainers; i++) { + table->pull_geo_param(i, &geo_pull_values[i], &geo_pull_ids[i]); + ASSERT_EQ(geo_pull_values[i].size(), geo_pull_ids[i].size() * emb_dim); + for (size_t j = 0; j < geo_pull_ids[i].size(); ++j) { + auto id = geo_pull_ids[i][j]; + for (int k = 0; k < emb_dim; k++) { + ASSERT_TRUE(abs(geo_pull_values[i][j * emb_dim + k] - + pull_values[id * emb_dim + k]) < 1e-5); + } + } + } +} + +} // namespace distributed +} // namespace paddle diff --git a/python/paddle/distributed/fleet/runtime/the_one_ps.py b/python/paddle/distributed/fleet/runtime/the_one_ps.py index c561c25067..cc81f8b3e9 100644 --- a/python/paddle/distributed/fleet/runtime/the_one_ps.py +++ b/python/paddle/distributed/fleet/runtime/the_one_ps.py @@ -943,7 +943,7 @@ class TheOnePSRuntime(RuntimeBase): ctx.origin_varnames()[0]] if self.compiled_strategy.is_geo_mode(): - table.table_class = "SparseGeoTable" + table.table_class = "MemorySparseGeoTable" else: all_table_proto = self.context[ "user_defined_strategy"].sparse_table_configs @@ -1306,6 +1306,7 @@ class TheOnePSRuntime(RuntimeBase): is_dense=True, split_dense_table=self.role_maker._is_heter_parameter_server_mode, use_origin_program=True) + # TODO(zhaocaibei123): for GEO: should call GeoCommunicator::RecvDense self._communicator.pull_dense(denses) generate_vars = self.context[ -- GitLab