未验证 提交 52329f6f 编写于 作者: Z zmxdream 提交者: GitHub

[heterps]move pre-init id logic from common_sparse_table to sparse_geo_table (#38173)

* remove pre-init id in common_sparse_tabl.cc
上级 33185000
...@@ -300,6 +300,10 @@ class Communicator { ...@@ -300,6 +300,10 @@ class Communicator {
virtual void BarrierWithTable(uint32_t barrier_type) { virtual void BarrierWithTable(uint32_t barrier_type) {
auto rets = _worker_ptr->barrier(barrier_table_id_, barrier_type); auto rets = _worker_ptr->barrier(barrier_table_id_, barrier_type);
rets.wait(); rets.wait();
int status = rets.get();
PADDLE_ENFORCE_EQ(status, 0,
platform::errors::InvalidArgument(
"The ret status must be 0 when barrier with table"));
} }
virtual void CreateC2CConnection(int pserver_timeout_ms, virtual void CreateC2CConnection(int pserver_timeout_ms,
......
...@@ -220,33 +220,6 @@ int32_t CommonSparseTable::initialize_value() { ...@@ -220,33 +220,6 @@ int32_t CommonSparseTable::initialize_value() {
shard_values_.emplace_back(shard); shard_values_.emplace_back(shard);
} }
auto accessor = _config.accessor();
std::vector<uint64_t> feasigns;
for (size_t x = 0; x < accessor.fea_dim(); ++x) {
if (x % _shard_num == _shard_idx) {
feasigns.push_back(x);
}
}
VLOG(3) << "has " << feasigns.size() << " ids need to be pre inited";
auto buckets = bucket(feasigns.size(), 10);
for (int x = 0; x < 10; ++x) {
auto bucket_feasigns = buckets[x + 1] - buckets[x];
std::vector<uint64_t> ids(bucket_feasigns);
std::copy(feasigns.begin() + buckets[x], feasigns.begin() + buckets[x + 1],
ids.begin());
std::vector<uint32_t> fres;
fres.resize(ids.size(), 1);
auto pull_value = PullSparseValue(ids, fres, param_dim_);
std::vector<float> pulls;
pulls.resize(bucket_feasigns * param_dim_);
pull_sparse(pulls.data(), pull_value);
}
return 0; return 0;
} }
......
...@@ -46,5 +46,46 @@ int32_t SparseGeoTable::push_sparse(const uint64_t* keys, const float* values, ...@@ -46,5 +46,46 @@ int32_t SparseGeoTable::push_sparse(const uint64_t* keys, const float* values,
return 0; return 0;
} }
int32_t SparseGeoTable::initialize_value() {
auto common = _config.common();
shard_values_.reserve(task_pool_size_);
for (int x = 0; x < task_pool_size_; ++x) {
auto shard = std::make_shared<ValueBlock>(
value_names_, value_dims_, value_offsets_, value_idx_,
initializer_attrs_, common.entry());
shard_values_.emplace_back(shard);
}
auto accessor = _config.accessor();
std::vector<uint64_t> feasigns;
for (size_t x = 0; x < accessor.fea_dim(); ++x) {
if (x % _shard_num == _shard_idx) {
feasigns.push_back(x);
}
}
VLOG(3) << "has " << feasigns.size() << " ids need to be pre inited";
auto buckets = bucket(feasigns.size(), 10);
for (int x = 0; x < 10; ++x) {
auto bucket_feasigns = buckets[x + 1] - buckets[x];
std::vector<uint64_t> ids(bucket_feasigns);
std::copy(feasigns.begin() + buckets[x], feasigns.begin() + buckets[x + 1],
ids.begin());
std::vector<uint32_t> fres;
fres.resize(ids.size(), 1);
auto pull_value = PullSparseValue(ids, fres, param_dim_);
std::vector<float> pulls;
pulls.resize(bucket_feasigns * param_dim_);
pull_sparse(pulls.data(), pull_value);
}
return 0;
}
} // namespace distributed } // namespace distributed
} // namespace paddle } // namespace paddle
...@@ -44,10 +44,12 @@ class SparseGeoTable : public CommonSparseTable { ...@@ -44,10 +44,12 @@ class SparseGeoTable : public CommonSparseTable {
explicit SparseGeoTable() : CommonSparseTable() { geo_recorder = nullptr; } explicit SparseGeoTable() : CommonSparseTable() { geo_recorder = nullptr; }
virtual ~SparseGeoTable() {} virtual ~SparseGeoTable() {}
virtual int32_t initialize_value();
int32_t pull_geo_param(const uint32_t trainer_id, std::vector<float>* values, int32_t pull_geo_param(const uint32_t trainer_id, std::vector<float>* values,
std::vector<uint64_t>* keys); std::vector<uint64_t>* keys);
virtual int32_t push_sparse(const uint64_t* keys, const float* values, int32_t push_sparse(const uint64_t* keys, const float* values,
size_t num) override; size_t num) override;
virtual int32_t initialize_recorder() { virtual int32_t initialize_recorder() {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册