Unverified commit c0001a24, authored by yaoxuefeng, committed by GitHub

Acc name (#42906)

Add dymf (dynamic mf dim) support to GPUPS.
Parent 3b488bae
......@@ -95,24 +95,6 @@ class HeterContext {
}
void SetShardNum(uint32_t shard_num) { shard_num_ = shard_num; }
uint32_t ShardNum() { return shard_num_; }
void init(int shard_num, int device_num) {
shard_num_ = shard_num;
feature_keys_.resize(shard_num_);
value_ptr_.resize(shard_num_);
device_task_ptr_.resize(shard_num_);
device_task_keys_.resize(shard_num_);
for (size_t i = 0; i < device_task_ptr_.size(); i++) {
device_task_ptr_[i].resize(device_num);
device_task_keys_[i].resize(device_num);
}
device_values_.resize(device_num);
device_keys_.resize(device_num);
mutex_.resize(device_num);
for (size_t i = 0; i < mutex_.size(); ++i) {
mutex_[i] = new std::mutex();
}
}
void init(int shard_num, int device_num, int dim_num) {
shard_num_ = shard_num;
......
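The three-argument init above is the dymf entry point: key and value containers gain a per-dim level so each (shard, dim) bucket can be built and merged independently. A minimal sketch of that layout, with illustrative names (the real HeterContext also keeps per-device containers and mutexes):

#include <cstdint>
#include <vector>

// Sketch only: dim_keys[shard][dim] collects the feasign keys for one
// shard at one embedding dim, mirroring the 3-arg init above.
struct DymfContextSketch {
  std::vector<std::vector<std::vector<uint64_t>>> dim_keys;
  void init(int shard_num, int dim_num) {
    dim_keys.assign(shard_num, std::vector<std::vector<uint64_t>>(dim_num));
  }
};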
......@@ -69,6 +69,20 @@ struct FeaturePushValue {
int mf_dim;
float mf_g[0];
__device__ __forceinline__ FeaturePushValue
operator+(const FeaturePushValue& a) const {
FeaturePushValue out;
out.slot = a.slot;
out.mf_dim = a.mf_dim;
out.show = a.show + show;
out.clk = a.clk + clk;
out.lr_g = a.lr_g + lr_g;
// Sum the variable-length mf gradients element-wise; the record length is
// carried at runtime in mf_dim.
for (int i = 0; i < out.mf_dim; ++i) {
out.mf_g[i] = a.mf_g[i] + mf_g[i];
}
return out;
}
__device__ __forceinline__ void operator=(const FeaturePushValue& in) {
show = in.show;
clk = in.clk;
......
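FeaturePushValue ends in a zero-length array (a GNU extension), so a record's true size is only known at runtime and operator+ has to sum the trailing mf_g floats one by one rather than copy a fixed-size struct. A host-side sketch of the same pattern, assuming an 8-byte alignment helper in the spirit of Paddle's TYPEALIGN; all names are illustrative:

#include <cstddef>
#include <cstdio>
#include <cstdlib>

struct PushValueSketch {
  float show, clk, lr_g;
  int slot, mf_dim;
  float mf_g[0];  // flexible trailing array, mf_dim floats long (GNU extension)
};

// Round a byte count up to an 8-byte boundary, like TYPEALIGN(8, x).
inline size_t AlignUp8(size_t n) { return (n + 7) & ~size_t(7); }

PushValueSketch* AllocPushValue(int mf_dim) {
  size_t bytes = AlignUp8(sizeof(PushValueSketch) + mf_dim * sizeof(float));
  auto* v = static_cast<PushValueSketch*>(std::calloc(1, bytes));
  v->mf_dim = mf_dim;
  return v;
}

// Element-wise merge, mirroring operator+ above: the trailing gradients are
// summed slot by slot, never copied as a pointer.
void MergeInto(PushValueSketch* dst, const PushValueSketch* src) {
  dst->show += src->show;
  dst->clk += src->clk;
  dst->lr_g += src->lr_g;
  for (int i = 0; i < dst->mf_dim; ++i) dst->mf_g[i] += src->mf_g[i];
}

int main() {
  PushValueSketch* a = AllocPushValue(8);
  PushValueSketch* b = AllocPushValue(8);
  a->mf_g[0] = 1.f;
  b->mf_g[0] = 2.f;
  MergeInto(a, b);
  std::printf("mf_g[0] = %.1f\n", a->mf_g[0]);  // prints 3.0
  std::free(a);
  std::free(b);
  return 0;
}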
......@@ -86,13 +86,26 @@ __global__ void dy_mf_search_kernel(Table* table,
char* vals, size_t len,
size_t pull_feature_value_size) {
const size_t i = blockIdx.x * blockDim.x + threadIdx.x;
if (i < len) {
auto it = table->find(keys[i]);
if (it != table->end()) {
uint64_t offset = i * pull_feature_value_size;
// Copy field by field: FeatureValue ends in a flexible mf array, so a
// whole-struct assignment would drop the trailing payload.
FeatureValue* cur = (FeatureValue*)(vals + offset);
FeatureValue& input = *(FeatureValue*)(it->second);
cur->slot = input.slot;
cur->show = input.show;
cur->clk = input.clk;
cur->mf_dim = input.mf_dim;
cur->lr = input.lr;
cur->mf_size = input.mf_size;
cur->cpu_ptr = input.cpu_ptr;
cur->delta_score = input.delta_score;
cur->lr_g2sum = input.lr_g2sum;
for (int j = 0; j < cur->mf_dim + 1; ++j) {
cur->mf[j] = input.mf[j];
}
}
}
}
......
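The switch from a FeatureValue reference to a pointer with per-field copies follows from the same layout: sizeof(FeatureValue) excludes the trailing mf array, and consecutive records sit pull_feature_value_size bytes apart, so whole-struct assignment would silently drop the mf payload. A reduced CUDA sketch of the pattern (struct and names illustrative, not the real FeatureValue):

#include <cstddef>
#include <cuda_runtime.h>

struct ValueSketch {
  int mf_dim;
  float mf[0];  // mf_dim + 1 floats in the real FeatureValue
};

// Copy variable-size records laid out every value_size bytes; indexing is
// done in bytes, and the trailing floats are copied explicitly.
__global__ void copy_values(const char* src, char* dst, size_t n,
                            size_t value_size) {
  size_t i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i >= n) return;
  const ValueSketch* in =
      reinterpret_cast<const ValueSketch*>(src + i * value_size);
  ValueSketch* out = reinterpret_cast<ValueSketch*>(dst + i * value_size);
  out->mf_dim = in->mf_dim;
  for (int j = 0; j < in->mf_dim + 1; ++j) out->mf[j] = in->mf[j];
}

int main() {
  const size_t n = 4, dim = 8;
  const size_t value_size = sizeof(ValueSketch) + (dim + 1) * sizeof(float);
  char *d_src, *d_dst;
  cudaMalloc(&d_src, n * value_size);
  cudaMalloc(&d_dst, n * value_size);
  cudaMemset(d_src, 0, n * value_size);
  copy_values<<<1, 64>>>(d_src, d_dst, n, value_size);
  cudaDeviceSynchronize();
  cudaFree(d_src);
  cudaFree(d_dst);
  return 0;
}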
......@@ -26,6 +26,7 @@ namespace framework {
template <typename KeyType, typename ValType, typename GradType>
HeterComm<KeyType, ValType, GradType>::HeterComm(
size_t capacity, std::shared_ptr<HeterPsResource> resource) {
VLOG(1) << "Construct new HeterComm";
resource_ = resource;
storage_.resize(resource_->total_device());
multi_mf_dim_ = resource->multi_mf();
......@@ -364,6 +365,10 @@ HeterComm<KeyType, ValType, GradType>::~HeterComm() {
delete table;
table = nullptr;
}
for (auto& table : tables_) {
delete table;
table = nullptr;
}
}
}
......@@ -473,17 +478,23 @@ void HeterComm<KeyType, ValType, GradType>::build_ps(int num, KeyType* h_keys,
return;
}
int dev_id = resource_->dev_id(num);
DevPlace place = DevPlace(dev_id);
AnyDeviceGuard guard(dev_id);
// use hbm pool
std::vector<memory::allocation::AllocationPtr> d_key_bufs;
ppStream streams[stream_num]; // NOLINT
for (int i = 0; i < stream_num; ++i) {
create_stream(&(streams[i]));
auto d_k_buf = memory::Alloc(place, chunk_size * sizeof(KeyType));
d_key_bufs.push_back(std::move(d_k_buf));
}
int cur_len = 0;
int cur_stream = 0;
while (cur_len < len) {
cur_stream = cur_stream % stream_num;
auto cur_use_stream = streams[cur_stream];
......@@ -491,8 +502,10 @@ void HeterComm<KeyType, ValType, GradType>::build_ps(int num, KeyType* h_keys,
cur_use_stream = 0;
#endif
int tmp_len = cur_len + chunk_size > len ? len - cur_len : chunk_size;
auto dst_place = place;
auto src_place = platform::CPUPlace();
memory_copy(
dst_place, reinterpret_cast<char*>(d_key_bufs[cur_stream]->ptr()),
src_place, h_keys + cur_len, sizeof(KeyType) * tmp_len, cur_use_stream);
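build_ps streams the host keys to the device in fixed-size chunks, rotating across stream_num streams so the next chunk can be issued while earlier copies are in flight. A standalone CUDA sketch of that round-robin pattern, simplified to copy straight into the destination rather than staging through per-stream buffers and inserting into the hash table:

#include <algorithm>
#include <cstdint>
#include <vector>
#include <cuda_runtime.h>

void CopyKeysChunked(const uint64_t* h_keys, size_t len, uint64_t* d_keys,
                     size_t chunk_size, int stream_num) {
  std::vector<cudaStream_t> streams(stream_num);
  for (auto& s : streams) cudaStreamCreate(&s);
  size_t cur_len = 0;
  int cur_stream = 0;
  while (cur_len < len) {
    // Clamp the last chunk, like the tmp_len computation above.
    size_t tmp_len = std::min(chunk_size, len - cur_len);
    // For real copy/compute overlap h_keys should be pinned
    // (cudaMallocHost); pageable memory makes this effectively synchronous.
    cudaMemcpyAsync(d_keys + cur_len, h_keys + cur_len,
                    tmp_len * sizeof(uint64_t), cudaMemcpyHostToDevice,
                    streams[cur_stream]);
    cur_len += tmp_len;
    cur_stream = (cur_stream + 1) % stream_num;
  }
  for (auto& s : streams) {
    cudaStreamSynchronize(s);
    cudaStreamDestroy(s);
  }
}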
......@@ -557,14 +570,20 @@ void HeterComm<KeyType, ValType, GradType>::dynamic_merge_grad(
platform::CUDAPlace place = platform::CUDAPlace(dev_id);
platform::CUDADeviceGuard guard(dev_id);
auto stream = resource_->local_stream(gpu_num, 0);
size_t temp_storage_bytes;
size_t grad_value_size =
TYPEALIGN(8, sizeof(FeaturePushValue) + (max_mf_dim_ * sizeof(float)));
auto d_merge_keys = memory::Alloc(place, len * sizeof(KeyType));
KeyType* d_merge_keys_ptr = reinterpret_cast<KeyType*>(d_merge_keys->ptr());
auto d_merge_grads = memory::Alloc(place, len * grad_value_size);
GradType* d_merge_grads_ptr =
reinterpret_cast<GradType*>(d_merge_grads->ptr());
auto d_fea_num_info = memory::Alloc(place, sizeof(uint32_t) * (len * 3 + 1));
uint32_t* d_fea_num_info_ptr =
reinterpret_cast<uint32_t*>(d_fea_num_info->ptr());
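grad_value_size pads every dymf gradient record up to an 8-byte boundary so the byte-indexed kernels can assume a uniform, aligned stride. A worked instance of the arithmetic, assuming a 20-byte FeaturePushValue header and max_mf_dim_ = 64 (both numbers illustrative):

#include <cstddef>

// Simplified form of Paddle's TYPEALIGN: round size up to a power-of-two
// alignment.
#define TYPEALIGN(alignment, size) \
  (((size) + (alignment)-1) & ~size_t((alignment)-1))

// 20-byte header + 64 * 4 gradient bytes = 276, padded up to 280.
static_assert(TYPEALIGN(8, 20 + 64 * sizeof(float)) == 280,
              "each merged dymf grad record would occupy 280 bytes");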
......@@ -836,9 +855,16 @@ void HeterComm<KeyType, ValType, GradType>::push_sparse(int dev_num,
auto d_shard_keys = memory::Alloc(place, len * sizeof(KeyType));
KeyType* d_shard_keys_ptr = reinterpret_cast<KeyType*>(d_shard_keys->ptr());
GradType* d_shard_grads_ptr;
// Size the shard-grad buffer once at function scope so the allocation
// outlives its use: dymf records are sized in bytes (grad_value_size),
// fixed-dim grads by sizeof(GradType).
size_t shard_grads_bytes =
multi_mf_dim_ ? len * grad_value_size : len * sizeof(GradType);
auto d_shard_grads = memory::Alloc(place, shard_grads_bytes);
d_shard_grads_ptr = reinterpret_cast<GradType*>(d_shard_grads->ptr());
int uniq_len = len;
dynamic_merge_grad(dev_num, d_keys, d_grads, len, uniq_len);
......@@ -846,9 +872,16 @@ void HeterComm<KeyType, ValType, GradType>::push_sparse(int dev_num,
split_input_to_shard(d_keys, d_idx_ptr, uniq_len, d_left_ptr, d_right_ptr,
dev_num);
if (!multi_mf_dim_) {
heter_comm_kernel_->fill_shard_grads(d_shard_keys_ptr, d_keys,
d_shard_grads_ptr, d_grads, d_idx_ptr,
uniq_len, stream);
} else {
heter_comm_kernel_->dy_mf_fill_shard_grads(
d_shard_keys_ptr, d_keys, d_shard_grads_ptr, d_grads, d_idx_ptr,
uniq_len, grad_value_size, stream);
}
sync_stream(stream);
......
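dy_mf_fill_shard_grads differs from the fixed-dim fill_shard_grads mainly in addressing: with runtime-sized records the kernel must step in bytes of grad_value_size rather than in units of GradType. A reduced CUDA sketch of that gather, copying byte-wise for brevity where the real kernel copies field by field (launch with roughly n/256 blocks of 256 threads):

#include <cstddef>

// Gather permuted grad records: record idx[i] of src becomes record i of
// dst, with both sides addressed by the byte stride grad_value_size.
__global__ void gather_grads_by_bytes(char* dst, const char* src,
                                      const int* idx, size_t n,
                                      size_t grad_value_size) {
  size_t i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i >= n) return;
  const char* in = src + static_cast<size_t>(idx[i]) * grad_value_size;
  char* out = dst + i * grad_value_size;
  for (size_t b = 0; b < grad_value_size; ++b) out[b] = in[b];
}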
......@@ -136,6 +136,7 @@ __global__ void merge_gradients_kernel(const uint32_t* offset,
size_t grad_value_size,
DynamicGradMerger& merger_) {
const size_t i = blockIdx.x * blockDim.x + threadIdx.x;
if (i < n) {
uint32_t start = offset[i];
uint32_t num = fea_num[i];
......
......@@ -106,25 +106,17 @@ void PSGPUWrapper::PreBuildTask(std::shared_ptr<HeterContext> gpu_task) {
platform::Timer timeline;
timeline.Start();
int device_num = heter_devices_.size();
gpu_task->init(thread_keys_shard_num_, device_num, multi_mf_dim_);
std::vector<std::thread> threads;
// data should be in input channel
thread_dim_keys_.resize(thread_keys_thread_num_);
for (int i = 0; i < thread_keys_thread_num_; i++) {
thread_dim_keys_[i].resize(thread_keys_shard_num_);
for (int j = 0; j < thread_keys_shard_num_; j++) {
thread_dim_keys_[i][j].resize(multi_mf_dim_);
}
}
......@@ -144,18 +136,6 @@ void PSGPUWrapper::PreBuildTask(std::shared_ptr<HeterContext> gpu_task) {
len_per_thread = total_len / thread_keys_thread_num_;
remain = total_len % thread_keys_thread_num_;
VLOG(0) << "total len: " << total_len;
auto gen_dynamic_mf_func = [this](const std::deque<SlotRecord>& total_data,
int begin_index, int end_index, int i) {
for (auto iter = total_data.begin() + begin_index;
......@@ -177,17 +157,10 @@ void PSGPUWrapper::PreBuildTask(std::shared_ptr<HeterContext> gpu_task) {
}
};
for (int i = 0; i < thread_keys_thread_num_; i++) {
threads.push_back(
std::thread(gen_dynamic_mf_func, std::ref(vec_data), begin,
begin + len_per_thread + (i < remain ? 1 : 0), i));
begin += len_per_thread + (i < remain ? 1 : 0);
}
for (std::thread& t : threads) {
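The spawning loop above gives each worker a contiguous slice of len_per_thread records (plus one extra for the first remain workers), and gen_dynamic_mf_func buckets every feasign by feasign % thread_keys_shard_num_. A self-contained sketch of that sharding scheme, with container shapes simplified and names illustrative:

#include <cstdint>
#include <set>
#include <thread>
#include <vector>

// Each worker scans [begin, begin + n) and buckets every feasign into
// shard_id = feasign % shard_num; per-worker sets keep this lock-free.
void ShardKeys(const std::vector<uint64_t>& feasigns, int thread_num,
               int shard_num,
               std::vector<std::vector<std::set<uint64_t>>>* out) {
  out->assign(thread_num, std::vector<std::set<uint64_t>>(shard_num));
  std::vector<std::thread> workers;
  size_t len_per_thread = feasigns.size() / thread_num;
  size_t remain = feasigns.size() % thread_num;
  size_t begin = 0;
  for (int t = 0; t < thread_num; ++t) {
    size_t n = len_per_thread + (static_cast<size_t>(t) < remain ? 1 : 0);
    workers.emplace_back([&feasigns, out, shard_num, t, begin, n] {
      for (size_t k = begin; k < begin + n; ++k) {
        uint64_t f = feasigns[k];
        (*out)[t][f % shard_num].insert(f);
      }
    });
    begin += n;
  }
  for (auto& w : workers) w.join();
}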
......@@ -235,12 +208,6 @@ void PSGPUWrapper::PreBuildTask(std::shared_ptr<HeterContext> gpu_task) {
threads.clear();
// merge thread_keys to shard_keys
auto merge_ins_dynamic_mf_func = [this, gpu_task](int shard_num, int dim_id) {
for (int i = 0; i < thread_keys_thread_num_; ++i) {
gpu_task->batch_add_keys(shard_num, dim_id,
......@@ -249,12 +216,8 @@ void PSGPUWrapper::PreBuildTask(std::shared_ptr<HeterContext> gpu_task) {
}
};
for (int i = 0; i < thread_keys_shard_num_; ++i) {
for (int j = 0; j < multi_mf_dim_; j++) {
threads.push_back(std::thread(merge_ins_dynamic_mf_func, i, j));
}
}
for (auto& t : threads) {
......@@ -297,12 +260,12 @@ void PSGPUWrapper::BuildPull(std::shared_ptr<HeterContext> gpu_task) {
auto& device_dim_keys = gpu_task->device_dim_keys_;
auto& device_dim_ptr = gpu_task->device_dim_ptr_;
auto& device_dim_mutex = gpu_task->dim_mutex_;
for (size_t dev = 0; dev < device_dim_keys.size(); dev++) {
device_dim_keys[dev].resize(multi_mf_dim_);
device_dim_ptr[dev].resize(multi_mf_dim_);
}
std::vector<std::thread> threads(thread_keys_shard_num_);
......@@ -415,6 +378,7 @@ void PSGPUWrapper::BuildPull(std::shared_ptr<HeterContext> gpu_task) {
task_keys[shard].push_back(local_dim_keys[i][j][k]);
task_ptrs[shard].push_back(local_dim_ptr[i][j][k]);
}
// allocate local keys to devices
for (int dev = 0; dev < device_num; dev++) {
device_dim_mutex[dev][j]->lock();
int len = task_keys[dev].size();
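Once task_keys is bucketed by destination device, each shard thread appends into the shared per-(device, dim) vectors under that slot's own mutex, keeping contention local to one device and dim. A minimal sketch of the guarded append, with illustrative names:

#include <cstdint>
#include <mutex>
#include <vector>

// Append one shard's bucket for a given (device, dim) slot; the mutex is
// per slot, mirroring device_dim_mutex[dev][j] above.
void AppendKeys(std::vector<uint64_t>* device_keys, std::mutex* mu,
                const std::vector<uint64_t>& task_keys) {
  std::lock_guard<std::mutex> lock(*mu);
  device_keys->insert(device_keys->end(), task_keys.begin(),
                      task_keys.end());
}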
......@@ -619,6 +583,7 @@ void PSGPUWrapper::BuildGPUTask(std::shared_ptr<HeterContext> gpu_task) {
<< feature_keys_count[i];
size_max = std::max(size_max, feature_keys_count[i]);
}
if (HeterPs_) {
delete HeterPs_;
HeterPs_ = nullptr;
......@@ -665,6 +630,8 @@ void PSGPUWrapper::BuildGPUTask(std::shared_ptr<HeterContext> gpu_task) {
ptr_val[paddle::ps::DownpourCtrDymfAccessor::
DownpourCtrDymfFeatureValue::embed_g2sum_index()];
val->cpu_ptr = (uint64_t)(device_dim_ptrs[k]);
// TODO(xuefeng) set mf_dim while using DownpourCtrDymfAccessor
ptr_val[paddle::ps::DownpourCtrDymfAccessor::DownpourCtrDymfFeatureValue::
mf_dim_index()] = float(mf_dim);
val->mf_dim = mf_dim;
......@@ -681,11 +648,15 @@ void PSGPUWrapper::BuildGPUTask(std::shared_ptr<HeterContext> gpu_task) {
}
}
}
platform::CUDADeviceGuard guard(resource_->dev_id(i));
this->hbm_pools_[i * this->multi_mf_dim_ + j] = new HBMMemoryPool(mem_pool);
auto& cur_pool = this->hbm_pools_[i * this->multi_mf_dim_ + j];
this->HeterPs_->build_ps(i, device_dim_keys.data(), cur_pool->mem(), len,
feature_value_size, 500000, 2);
if (device_dim_keys.size() > 0) {
VLOG(0) << "show ptr table: " << i
<< " table kv size: " << device_dim_keys.size()
......@@ -700,6 +671,7 @@ void PSGPUWrapper::BuildGPUTask(std::shared_ptr<HeterContext> gpu_task) {
threads[i + j * device_num] = std::thread(build_dynamic_mf_func, i, j);
}
}
for (std::thread& t : threads) {
t.join();
}
......@@ -723,7 +695,9 @@ void PSGPUWrapper::LoadIntoMemory(bool is_shuffle) {
InitSlotInfo();
std::shared_ptr<HeterContext> gpu_task = gpu_task_pool_.Get();
gpu_task->Reset();
data_ready_channel_->Put(gpu_task);
VLOG(3) << "End LoadIntoMemory(), dataset[" << dataset_ << "]";
}
......@@ -805,6 +779,7 @@ void PSGPUWrapper::EndPass() {
timer.Start();
size_t keysize_max = 0;
// in case of feasign_num = 0, skip dump_to_cpu
for (size_t i = 0; i < heter_devices_.size(); i++) {
for (int j = 0; j < multi_mf_dim_; j++) {
keysize_max =
......@@ -821,9 +796,11 @@ void PSGPUWrapper::EndPass() {
VLOG(0) << "dump pool to cpu table: " << i << "with mf dim: " << mf_dim;
size_t feature_value_size =
TYPEALIGN(8, sizeof(FeatureValue) + ((mf_dim + 1) * sizeof(float)));
char* test_build_values = (char*)malloc(feature_value_size * len);
cudaMemcpy(test_build_values, hbm_pool->mem(), feature_value_size * len,
cudaMemcpyDeviceToHost);
CHECK(len == hbm_pool->capacity());
#ifdef PADDLE_WITH_PSLIB
uint64_t unuse_key = std::numeric_limits<uint64_t>::max();
......@@ -972,7 +949,6 @@ void PSGPUWrapper::PullSparse(const paddle::platform::Place& place,
feature_value_size = TYPEALIGN(
8, sizeof(FeatureValue) + sizeof(float) * (index_dim_vec_.back() + 1));
VLOG(0) << "yxf pull sparse feature_value_size: " << feature_value_size;
#ifdef PADDLE_WITH_CUDA
VLOG(3) << "Begine Gpu Ps PullSparse";
......@@ -1159,6 +1135,8 @@ void PSGPUWrapper::PushSparseGrad(const paddle::platform::Place& place,
"GPUPS: PushSparseGrad Only Support CUDAPlace Now."));
}
all_timer.Pause();
time_3 += all_timer.ElapsedSec();
time_4 += push_gpups_timer.ElapsedSec();
VLOG(3) << "PushSparseGrad total cost: " << all_timer.ElapsedSec()
<< " s, of which GPUPS cost: " << push_gpups_timer.ElapsedSec()
<< " s";
......
......@@ -333,6 +333,11 @@ class PSGPUWrapper {
void SetSlotOffsetVector(const std::vector<int>& slot_offset_vector) {
slot_offset_vector_ = slot_offset_vector;
std::cout << "yxf set: ";
for (auto s : slot_offset_vector_) {
std::cout << s << " | ";
}
std::cout << " end " << std::endl;
}
#ifdef PADDLE_WITH_CUDA
......@@ -431,6 +436,12 @@ class PSGPUWrapper {
int max_mf_dim_{0};
size_t val_type_size_{0};
size_t grad_type_size_{0};
double time_1 = 0.0;
double time_2 = 0.0;
double time_3 = 0.0;
double time_4 = 0.0;
int multi_node_{0};
int node_size_;
uint64_t table_id_;
......