未验证 提交 c09d6453 编写于 作者: T Thunderbrook 提交者: GitHub

[heterps] optimize build task (#32358)

* build task cost

* return pool
上级 0dd28b8c
......@@ -77,6 +77,21 @@ class HeterContext {
mutex_[i] = new std::mutex();
}
}
void Reset() {
for (size_t i = 0; i < feature_keys_.size(); ++i) {
feature_keys_[i].clear();
}
for (size_t i = 0; i < value_ptr_.size(); ++i) {
value_ptr_[i].clear();
}
for (size_t i = 0; i < device_values_.size(); ++i) {
device_values_[i].clear();
}
for (size_t i = 0; i < device_keys_.size(); ++i) {
device_keys_[i].clear();
}
}
void batch_add_keys(
const std::vector<std::unordered_set<uint64_t>>& thread_keys) {
assert(thread_keys.size() == feature_keys_.size());
......@@ -90,6 +105,15 @@ class HeterContext {
}
}
// Appends all keys of one shard onto that shard's key vector.
// @param shard_num  index into feature_keys_ (caller must keep it in range).
// @param shard_keys keys to append; duplicates are preserved here and are
//                   expected to be removed later (see UniqueKeys()).
// vector::insert performs the grow-and-copy in a single call, replacing the
// manual resize + std::copy and its signed/unsigned index mix.
void batch_add_keys(int shard_num,
                    const std::unordered_set<uint64_t>& shard_keys) {
  auto& dst = feature_keys_[shard_num];
  dst.insert(dst.end(), shard_keys.begin(), shard_keys.end());
}
void UniqueKeys() {
std::vector<std::thread> threads;
auto unique_func = [this](int i) {
......
......@@ -103,12 +103,26 @@ void PSGPUWrapper::BuildTask(std::shared_ptr<HeterContext> gpu_task,
timeline.Start();
threads.clear();
// merge thread_keys to shard_keys
for (size_t i = 0; i < thread_keys_.size(); i++) {
gpu_task->batch_add_keys(thread_keys_[i]);
for (int j = 0; j < thread_keys_thread_num_; j++) {
thread_keys_[i][j].clear();
auto merge_ins_func = [this, gpu_task](int shard_num) {
for (int i = 0; i < thread_keys_thread_num_; ++i) {
gpu_task->batch_add_keys(shard_num, thread_keys_[i][shard_num]);
thread_keys_[i][shard_num].clear();
}
};
// for (size_t i = 0; i < thread_keys_.size(); i++) {
// gpu_task->batch_add_keys(thread_keys_[i]);
// for (int j = 0; j < thread_keys_thread_num_; j++) {
// thread_keys_[i][j].clear();
// }
//}
for (int i = 0; i < thread_keys_shard_num_; ++i) {
threads.push_back(std::thread(merge_ins_func, i));
}
for (auto& t : threads) {
t.join();
}
timeline.Pause();
......@@ -261,6 +275,7 @@ void PSGPUWrapper::BuildTask(std::shared_ptr<HeterContext> gpu_task,
void PSGPUWrapper::BuildGPUPS(uint64_t table_id, int feature_dim) {
int device_num = heter_devices_.size();
std::shared_ptr<HeterContext> gpu_task = gpu_task_pool_.Get();
gpu_task->Reset();
BuildTask(gpu_task, table_id, feature_dim);
platform::Timer timeline;
timeline.Start();
......@@ -273,8 +288,8 @@ void PSGPUWrapper::BuildGPUPS(uint64_t table_id, int feature_dim) {
size_max = std::max(size_max, feature_keys_count[i]);
}
if (HeterPs_) {
HeterPs_->show_one_table(0);
return;
delete HeterPs_;
HeterPs_ = nullptr;
}
std::vector<std::thread> threads(device_num);
HeterPs_ = HeterPsBase::get_instance(size_max, resource_);
......@@ -295,6 +310,7 @@ void PSGPUWrapper::BuildGPUPS(uint64_t table_id, int feature_dim) {
timeline.Pause();
VLOG(1) << "GpuPs build table total costs: " << timeline.ElapsedSec()
<< " s.";
gpu_task_pool_.Push(gpu_task);
}
void PSGPUWrapper::PullSparse(const paddle::platform::Place& place,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册