From 52329f6fde7a630ccb1d7076ab688cb552b60adc Mon Sep 17 00:00:00 2001 From: zmxdream Date: Fri, 24 Dec 2021 16:44:25 +0800 Subject: [PATCH] [heterps]move pre-init id logic from common_sparse_table to sparse_geo_table (#38173) * remove pre-init id in common_sparse_table.cc --- .../fluid/distributed/service/communicator.h | 4 ++ .../distributed/table/common_sparse_table.cc | 27 ------------ .../distributed/table/sparse_geo_table.cc | 41 +++++++++++++++++++ .../distributed/table/sparse_geo_table.h | 6 ++- 4 files changed, 49 insertions(+), 29 deletions(-) diff --git a/paddle/fluid/distributed/service/communicator.h b/paddle/fluid/distributed/service/communicator.h index 8714918dc8..9ea44310f3 100644 --- a/paddle/fluid/distributed/service/communicator.h +++ b/paddle/fluid/distributed/service/communicator.h @@ -300,6 +300,10 @@ class Communicator { virtual void BarrierWithTable(uint32_t barrier_type) { auto rets = _worker_ptr->barrier(barrier_table_id_, barrier_type); rets.wait(); + int status = rets.get(); + PADDLE_ENFORCE_EQ(status, 0, + platform::errors::InvalidArgument( + "The ret status must be 0 when barrier with table")); } virtual void CreateC2CConnection(int pserver_timeout_ms, diff --git a/paddle/fluid/distributed/table/common_sparse_table.cc b/paddle/fluid/distributed/table/common_sparse_table.cc index e124160e71..143b24cf32 100644 --- a/paddle/fluid/distributed/table/common_sparse_table.cc +++ b/paddle/fluid/distributed/table/common_sparse_table.cc @@ -220,33 +220,6 @@ int32_t CommonSparseTable::initialize_value() { shard_values_.emplace_back(shard); } - auto accessor = _config.accessor(); - std::vector feasigns; - - for (size_t x = 0; x < accessor.fea_dim(); ++x) { - if (x % _shard_num == _shard_idx) { - feasigns.push_back(x); - } - } - - VLOG(3) << "has " << feasigns.size() << " ids need to be pre inited"; - - auto buckets = bucket(feasigns.size(), 10); - for (int x = 0; x < 10; ++x) { - auto bucket_feasigns = buckets[x + 1] - buckets[x]; - std::vector 
ids(bucket_feasigns); - std::copy(feasigns.begin() + buckets[x], feasigns.begin() + buckets[x + 1], - ids.begin()); - - std::vector fres; - fres.resize(ids.size(), 1); - - auto pull_value = PullSparseValue(ids, fres, param_dim_); - std::vector pulls; - pulls.resize(bucket_feasigns * param_dim_); - pull_sparse(pulls.data(), pull_value); - } - return 0; } diff --git a/paddle/fluid/distributed/table/sparse_geo_table.cc b/paddle/fluid/distributed/table/sparse_geo_table.cc index 04cd113638..655c478415 100644 --- a/paddle/fluid/distributed/table/sparse_geo_table.cc +++ b/paddle/fluid/distributed/table/sparse_geo_table.cc @@ -46,5 +46,46 @@ int32_t SparseGeoTable::push_sparse(const uint64_t* keys, const float* values, return 0; } +int32_t SparseGeoTable::initialize_value() { + auto common = _config.common(); + shard_values_.reserve(task_pool_size_); + + for (int x = 0; x < task_pool_size_; ++x) { + auto shard = std::make_shared( + value_names_, value_dims_, value_offsets_, value_idx_, + initializer_attrs_, common.entry()); + + shard_values_.emplace_back(shard); + } + + auto accessor = _config.accessor(); + std::vector feasigns; + + for (size_t x = 0; x < accessor.fea_dim(); ++x) { + if (x % _shard_num == _shard_idx) { + feasigns.push_back(x); + } + } + + VLOG(3) << "has " << feasigns.size() << " ids need to be pre inited"; + + auto buckets = bucket(feasigns.size(), 10); + for (int x = 0; x < 10; ++x) { + auto bucket_feasigns = buckets[x + 1] - buckets[x]; + std::vector ids(bucket_feasigns); + std::copy(feasigns.begin() + buckets[x], feasigns.begin() + buckets[x + 1], + ids.begin()); + + std::vector fres; + fres.resize(ids.size(), 1); + + auto pull_value = PullSparseValue(ids, fres, param_dim_); + std::vector pulls; + pulls.resize(bucket_feasigns * param_dim_); + pull_sparse(pulls.data(), pull_value); + } + return 0; +} + } // namespace distributed } // namespace paddle diff --git a/paddle/fluid/distributed/table/sparse_geo_table.h 
b/paddle/fluid/distributed/table/sparse_geo_table.h index 01870615af..4ddb1fd706 100644 --- a/paddle/fluid/distributed/table/sparse_geo_table.h +++ b/paddle/fluid/distributed/table/sparse_geo_table.h @@ -44,11 +44,13 @@ class SparseGeoTable : public CommonSparseTable { explicit SparseGeoTable() : CommonSparseTable() { geo_recorder = nullptr; } virtual ~SparseGeoTable() {} + virtual int32_t initialize_value(); + int32_t pull_geo_param(const uint32_t trainer_id, std::vector* values, std::vector* keys); - virtual int32_t push_sparse(const uint64_t* keys, const float* values, - size_t num) override; + int32_t push_sparse(const uint64_t* keys, const float* values, + size_t num) override; virtual int32_t initialize_recorder() { if (!geo_recorder) { -- GitLab