From 4e5cb7d81c7d7a4a704efbc04d645ab7ae3b3f81 Mon Sep 17 00:00:00 2001 From: danleifeng <52735331+danleifeng@users.noreply.github.com> Date: Thu, 15 Jul 2021 19:51:22 +0800 Subject: [PATCH] psgpu:optimize build_cpu hashset; test=develop (#34175) --- paddle/fluid/framework/fleet/heter_context.h | 3 ++- paddle/fluid/framework/fleet/ps_gpu_wrapper.h | 3 ++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/paddle/fluid/framework/fleet/heter_context.h b/paddle/fluid/framework/fleet/heter_context.h index 1fb2f0fab4a..68868f447b5 100644 --- a/paddle/fluid/framework/fleet/heter_context.h +++ b/paddle/fluid/framework/fleet/heter_context.h @@ -29,6 +29,7 @@ limitations under the License. */ #include "paddle/fluid/distributed/table/depends/large_scale_kv.h" #endif +#include "paddle/fluid/distributed/thirdparty/round_robin.h" #include "paddle/fluid/framework/fleet/heter_ps/feature_value.h" #include "paddle/fluid/framework/scope.h" @@ -106,7 +107,7 @@ class HeterContext { } void batch_add_keys(int shard_num, - const std::unordered_set& shard_keys) { + const robin_hood::unordered_set& shard_keys) { int idx = feature_keys_[shard_num].size(); feature_keys_[shard_num].resize(feature_keys_[shard_num].size() + shard_keys.size()); diff --git a/paddle/fluid/framework/fleet/ps_gpu_wrapper.h b/paddle/fluid/framework/fleet/ps_gpu_wrapper.h index 2bbe5954190..b7e8bbb3694 100644 --- a/paddle/fluid/framework/fleet/ps_gpu_wrapper.h +++ b/paddle/fluid/framework/fleet/ps_gpu_wrapper.h @@ -29,6 +29,7 @@ limitations under the License. */ #include #include "paddle/fluid/framework/fleet/gloo_wrapper.h" #endif +#include "paddle/fluid/distributed/thirdparty/round_robin.h" #include "paddle/fluid/framework/data_set.h" #include "paddle/fluid/framework/fleet/heter_context.h" #include "paddle/fluid/framework/fleet/heter_ps/heter_ps_base.h" @@ -270,7 +271,7 @@ class PSGPUWrapper { std::vector heter_devices_; std::unordered_set gpu_ps_config_keys_; HeterObjectPool gpu_task_pool_; - std::vector>> thread_keys_; + std::vector>> thread_keys_; int thread_keys_thread_num_ = 37; int thread_keys_shard_num_ = 37; uint64_t max_fea_num_per_pass_ = 5000000000; -- GitLab