From c202a613a991161d35cfb8c218c04905b5c5ede8 Mon Sep 17 00:00:00 2001
From: danleifeng <52735331+danleifeng@users.noreply.github.com>
Date: Tue, 12 Apr 2022 13:04:18 +0800
Subject: [PATCH] [heterps] datafeed puttofeedvec performance (#40168)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

* perform SlotRecordInMemoryDataFeed feedvec;test=develop
---
 paddle/fluid/framework/CMakeLists.txt | 10 +-
 paddle/fluid/framework/data_feed.cc | 331 +++++++++++++++++-
 paddle/fluid/framework/data_feed.cu | 149 ++++++++
 paddle/fluid/framework/data_feed.h | 293 +++++++++++++++-
 .../fluid/framework/fleet/ps_gpu_wrapper.cc | 18 +-
 paddle/fluid/framework/ps_gpu_worker.cc | 14 +-
 6 files changed, 793 insertions(+), 22 deletions(-)
 create mode 100644 paddle/fluid/framework/data_feed.cu

diff --git a/paddle/fluid/framework/CMakeLists.txt b/paddle/fluid/framework/CMakeLists.txt
index fb4c9937611..1b9943df1b0 100755
--- a/paddle/fluid/framework/CMakeLists.txt
+++ b/paddle/fluid/framework/CMakeLists.txt
@@ -295,7 +295,7 @@ if(WITH_DISTRIBUTE) dist_multi_trainer.cc trainer_factory.cc trainer.cc data_feed_factory.cc heterxpu_trainer.cc data_feed.cc device_worker.cc hogwild_worker.cc hetercpu_worker.cc ps_gpu_worker.cc - ps_gpu_trainer.cc downpour_worker.cc downpour_worker_opt.cc + ps_gpu_trainer.cc downpour_worker.cc downpour_worker_opt.cc data_feed.cu pull_dense_worker.cc section_worker.cc device_worker_factory.cc data_set.cc DEPS op_registry device_context scope framework_proto trainer_desc_proto glog fs shell fleet_wrapper heter_wrapper ps_gpu_wrapper box_wrapper metrics lodtensor_printer
@@ -316,7 +316,7 @@ if(WITH_DISTRIBUTE) dist_multi_trainer.cc trainer_factory.cc trainer.cc data_feed_factory.cc heterxpu_trainer.cc heter_pipeline_trainer.cc data_feed.cc device_worker.cc hogwild_worker.cc hetercpu_worker.cc - downpour_worker.cc downpour_lite_worker.cc downpour_worker_opt.cc + downpour_worker.cc downpour_lite_worker.cc downpour_worker_opt.cc data_feed.cu pull_dense_worker.cc section_worker.cc heter_section_worker.cc device_worker_factory.cc data_set.cc DEPS op_registry device_context scope framework_proto data_feed_proto heter_service_proto trainer_desc_proto glog index_sampler index_wrapper sampler index_dataset_proto
@@ -339,7 +339,7 @@ if(WITH_DISTRIBUTE) dist_multi_trainer.cc trainer_factory.cc trainer.cc data_feed_factory.cc heterxpu_trainer.cc data_feed.cc device_worker.cc hogwild_worker.cc hetercpu_worker.cc ps_gpu_worker.cc - ps_gpu_trainer.cc downpour_worker.cc downpour_worker_opt.cc + ps_gpu_trainer.cc downpour_worker.cc downpour_worker_opt.cc data_feed.cu pull_dense_worker.cc section_worker.cc device_worker_factory.cc data_set.cc DEPS op_registry device_context scope framework_proto data_feed_proto heter_service_proto trainer_desc_proto glog lod_rank_table fs shell fleet_wrapper heter_wrapper ps_gpu_wrapper box_wrapper metrics lodtensor_printer feed_fetch_method
@@ -359,7 +359,7 @@ elseif(WITH_PSLIB) dist_multi_trainer.cc trainer_factory.cc trainer.cc data_feed_factory.cc heterxpu_trainer.cc data_feed.cc device_worker.cc hogwild_worker.cc hetercpu_worker.cc ps_gpu_worker.cc - ps_gpu_trainer.cc downpour_worker.cc downpour_worker_opt.cc + ps_gpu_trainer.cc downpour_worker.cc downpour_worker_opt.cc data_feed.cu pull_dense_worker.cc section_worker.cc device_worker_factory.cc data_set.cc DEPS op_registry device_context scope framework_proto data_feed_proto heter_service_proto trainer_desc_proto glog lod_rank_table fs
shell fleet_wrapper heter_wrapper ps_gpu_wrapper box_wrapper lodtensor_printer feed_fetch_method @@ -369,7 +369,7 @@ else() dist_multi_trainer.cc trainer_factory.cc trainer.cc data_feed_factory.cc heterxpu_trainer.cc data_feed.cc device_worker.cc hogwild_worker.cc hetercpu_worker.cc ps_gpu_worker.cc - ps_gpu_trainer.cc downpour_worker.cc downpour_worker_opt.cc + ps_gpu_trainer.cc downpour_worker.cc downpour_worker_opt.cc data_feed.cu pull_dense_worker.cc section_worker.cc device_worker_factory.cc data_set.cc DEPS op_registry device_context scope framework_proto data_feed_proto heter_service_proto trainer_desc_proto glog lod_rank_table fs shell fleet_wrapper heter_wrapper ps_gpu_wrapper box_wrapper lodtensor_printer feed_fetch_method diff --git a/paddle/fluid/framework/data_feed.cc b/paddle/fluid/framework/data_feed.cc index 330f5ea5295..3b6370e1185 100644 --- a/paddle/fluid/framework/data_feed.cc +++ b/paddle/fluid/framework/data_feed.cc @@ -2394,9 +2394,6 @@ bool SlotRecordInMemoryDataFeed::ParseOneInstance(const std::string& line, for (int j = 0; j < num; ++j) { uint64_t feasign = static_cast(strtoull(endptr, &endptr, 10)); - if (feasign == 0 && !used_slots_info_[info.used_idx].dense) { - continue; - } slot_fea.push_back(feasign); ++uint64_total_slot_num; } @@ -2419,8 +2416,21 @@ bool SlotRecordInMemoryDataFeed::ParseOneInstance(const std::string& line, return (uint64_total_slot_num > 0); } +void SlotRecordInMemoryDataFeed::AssignFeedVar(const Scope& scope) { + CheckInit(); + for (int i = 0; i < use_slot_size_; ++i) { + feed_vec_[i] = + scope.FindVar(used_slots_info_[i].slot)->GetMutable(); + } +} + void SlotRecordInMemoryDataFeed::PutToFeedVec(const SlotRecord* ins_vec, int num) { +#if defined(PADDLE_WITH_CUDA) && defined(PADDLE_WITH_HETERPS) + paddle::platform::SetDeviceId(place_.GetDeviceId()); + pack_->pack_instance(ins_vec, num); + BuildSlotBatchGPU(pack_->ins_num()); +#else for (int j = 0; j < use_slot_size_; ++j) { auto& feed = feed_vec_[j]; if (feed == nullptr) { @@ -2497,6 +2507,7 @@ void SlotRecordInMemoryDataFeed::PutToFeedVec(const SlotRecord* ins_vec, feed_vec_[j]->set_lod(data_lod); } } +#endif } void SlotRecordInMemoryDataFeed::ExpandSlotRecord(SlotRecord* rec) { @@ -2573,6 +2584,10 @@ bool SlotRecordInMemoryDataFeed::Start() { this->offset_index_ = 0; } this->finish_start_ = true; +#if defined(PADDLE_WITH_CUDA) && defined(PADDLE_WITH_HETERPS) + CHECK(paddle::platform::is_gpu_place(this->place_)); + pack_ = BatchGpuPackMgr().get(this->GetPlace(), used_slots_info_); +#endif return true; } @@ -2607,5 +2622,315 @@ int SlotRecordInMemoryDataFeed::Next() { #endif } +#if defined(PADDLE_WITH_CUDA) && defined(PADDLE_WITH_HETERPS) +void SlotRecordInMemoryDataFeed::BuildSlotBatchGPU(const int ins_num) { + int offset_cols_size = (ins_num + 1); + size_t slot_total_num = (use_slot_size_ * offset_cols_size); + pack_->resize_gpu_slot_offsets(slot_total_num * sizeof(size_t)); + + auto& value = pack_->value(); + const UsedSlotGpuType* used_slot_gpu_types = + static_cast(pack_->get_gpu_slots()); + FillSlotValueOffset(ins_num, use_slot_size_, + reinterpret_cast(pack_->gpu_slot_offsets()), + value.d_uint64_offset.data(), uint64_use_slot_size_, + value.d_float_offset.data(), float_use_slot_size_, + used_slot_gpu_types); + size_t* d_slot_offsets = reinterpret_cast(pack_->gpu_slot_offsets()); + + HostBuffer& offsets = pack_->offsets(); + offsets.resize(slot_total_num); + HostBuffer& h_tensor_ptrs = pack_->h_tensor_ptrs(); + h_tensor_ptrs.resize(use_slot_size_); + // alloc gpu memory + 
pack_->resize_tensor(); + + LoDTensor& float_tensor = pack_->float_tensor(); + LoDTensor& uint64_tensor = pack_->uint64_tensor(); + + int64_t float_offset = 0; + int64_t uint64_offset = 0; + + // copy index + CUDA_CHECK(cudaMemcpy(offsets.data(), d_slot_offsets, + slot_total_num * sizeof(size_t), + cudaMemcpyDeviceToHost)); + for (int j = 0; j < use_slot_size_; ++j) { + auto& feed = feed_vec_[j]; + if (feed == nullptr) { + h_tensor_ptrs[j] = nullptr; + continue; + } + + size_t* off_start_ptr = &offsets[j * offset_cols_size]; + + int total_instance = static_cast(off_start_ptr[offset_cols_size - 1]); + CHECK(total_instance >= 0) << "slot idx:" << j + << ", total instance:" << total_instance; + auto& info = used_slots_info_[j]; + + // fill slot value with default value 0 + if (info.type[0] == 'f') { // float + if (total_instance > 0) { + feed->ShareDataWith(float_tensor.Slice( + static_cast(float_offset), + static_cast(float_offset + total_instance))); + feed->Resize({total_instance, 1}); + float_offset += total_instance; + h_tensor_ptrs[j] = feed->mutable_data(this->place_); + } else { + h_tensor_ptrs[j] = + feed->mutable_data({total_instance, 1}, this->place_); + } + } else if (info.type[0] == 'u') { // uint64 + if (total_instance > 0) { + feed->ShareDataWith(uint64_tensor.Slice( + static_cast(uint64_offset), + static_cast(uint64_offset + total_instance))); + feed->Resize({total_instance, 1}); + uint64_offset += total_instance; + h_tensor_ptrs[j] = feed->mutable_data(this->place_); + } else { + h_tensor_ptrs[j] = + feed->mutable_data({total_instance, 1}, this->place_); + } + } + + if (info.dense) { + if (info.inductive_shape_index != -1) { + info.local_shape[info.inductive_shape_index] = + total_instance / info.total_dims_without_inductive; + } + feed->Resize(phi::make_ddim(info.local_shape)); + } else { + LoD& lod = (*feed->mutable_lod()); + lod.resize(1); + lod[0].resize(offset_cols_size); + paddle::framework::MixVector mixv_lod(&lod[0]); + memcpy(mixv_lod.MutableData(platform::CPUPlace()), off_start_ptr, + offset_cols_size * sizeof(size_t)); + } + } + void** dest_gpu_p = reinterpret_cast(pack_->slot_buf_ptr()); + CUDA_CHECK(cudaMemcpy(dest_gpu_p, h_tensor_ptrs.data(), + use_slot_size_ * sizeof(void*), + cudaMemcpyHostToDevice)); + + CopyForTensor(ins_num, use_slot_size_, dest_gpu_p, + (const size_t*)pack_->gpu_slot_offsets(), + (const uint64_t*)value.d_uint64_keys.data(), + (const int*)value.d_uint64_offset.data(), + (const int*)value.d_uint64_lens.data(), uint64_use_slot_size_, + (const float*)value.d_float_keys.data(), + (const int*)value.d_float_offset.data(), + (const int*)value.d_float_lens.data(), float_use_slot_size_, + used_slot_gpu_types); +} + +MiniBatchGpuPack::MiniBatchGpuPack(const paddle::platform::Place& place, + const std::vector& infos) { + place_ = place; + stream_ = dynamic_cast( + platform::DeviceContextPool::Instance().Get(place)) + ->stream(); + + ins_num_ = 0; + pv_num_ = 0; + used_float_num_ = 0; + used_uint64_num_ = 0; + + used_slot_size_ = static_cast(infos.size()); + for (int i = 0; i < used_slot_size_; ++i) { + auto& info = infos[i]; + if (info.type[0] == 'u') { + gpu_used_slots_.push_back({1, info.slot_value_idx}); + ++used_uint64_num_; + } else { + gpu_used_slots_.push_back({0, info.slot_value_idx}); + ++used_float_num_; + } + } + copy_host2device(&gpu_slots_, gpu_used_slots_.data(), gpu_used_slots_.size()); + + slot_buf_ptr_ = memory::AllocShared(place_, used_slot_size_ * sizeof(void*)); + + int device_id = place_.GetDeviceId(); + VLOG(3) << "begin get 
batch pack device id: " << device_id; + // sync + CUDA_CHECK(cudaStreamSynchronize(stream_)); +} + +MiniBatchGpuPack::~MiniBatchGpuPack() {} + +void MiniBatchGpuPack::reset(const paddle::platform::Place& place) { + place_ = place; + stream_ = dynamic_cast( + platform::DeviceContextPool::Instance().Get(place)) + ->stream(); + ins_num_ = 0; + pv_num_ = 0; +} + +void MiniBatchGpuPack::pack_all_data(const SlotRecord* ins_vec, int num) { + int uint64_total_num = 0; + int float_total_num = 0; + + buf_.h_uint64_lens.resize(num + 1); + buf_.h_uint64_lens[0] = 0; + buf_.h_float_lens.resize(num + 1); + buf_.h_float_lens[0] = 0; + + for (int i = 0; i < num; ++i) { + auto r = ins_vec[i]; + uint64_total_num += r->slot_uint64_feasigns_.slot_values.size(); + buf_.h_uint64_lens[i + 1] = uint64_total_num; + float_total_num += r->slot_float_feasigns_.slot_values.size(); + buf_.h_float_lens[i + 1] = float_total_num; + } + + int uint64_cols = (used_uint64_num_ + 1); + buf_.h_uint64_offset.resize(uint64_cols * num); + buf_.h_uint64_keys.resize(uint64_total_num); + + int float_cols = (used_float_num_ + 1); + buf_.h_float_offset.resize(float_cols * num); + buf_.h_float_keys.resize(float_total_num); + + size_t fea_num = 0; + uint64_total_num = 0; + float_total_num = 0; + for (int i = 0; i < num; ++i) { + auto r = ins_vec[i]; + auto& uint64_feasigns = r->slot_uint64_feasigns_; + fea_num = uint64_feasigns.slot_values.size(); + if (fea_num > 0) { + memcpy(&buf_.h_uint64_keys[uint64_total_num], + uint64_feasigns.slot_values.data(), fea_num * sizeof(uint64_t)); + } + uint64_total_num += fea_num; + // copy uint64 offset + memcpy(&buf_.h_uint64_offset[i * uint64_cols], + uint64_feasigns.slot_offsets.data(), sizeof(int) * uint64_cols); + + auto& float_feasigns = r->slot_float_feasigns_; + fea_num = float_feasigns.slot_values.size(); + memcpy(&buf_.h_float_keys[float_total_num], + float_feasigns.slot_values.data(), fea_num * sizeof(float)); + float_total_num += fea_num; + + // copy float offset + memcpy(&buf_.h_float_offset[i * float_cols], + float_feasigns.slot_offsets.data(), sizeof(int) * float_cols); + } + + CHECK(uint64_total_num == static_cast(buf_.h_uint64_lens.back())) + << "uint64 value length error"; + CHECK(float_total_num == static_cast(buf_.h_float_lens.back())) + << "float value length error"; +} +void MiniBatchGpuPack::pack_uint64_data(const SlotRecord* ins_vec, int num) { + int uint64_total_num = 0; + + buf_.h_float_lens.clear(); + buf_.h_float_keys.clear(); + buf_.h_float_offset.clear(); + + buf_.h_uint64_lens.resize(num + 1); + buf_.h_uint64_lens[0] = 0; + + for (int i = 0; i < num; ++i) { + auto r = ins_vec[i]; + uint64_total_num += r->slot_uint64_feasigns_.slot_values.size(); + buf_.h_uint64_lens[i + 1] = uint64_total_num; + } + + int uint64_cols = (used_uint64_num_ + 1); + buf_.h_uint64_offset.resize(uint64_cols * num); + buf_.h_uint64_keys.resize(uint64_total_num); + + size_t fea_num = 0; + uint64_total_num = 0; + for (int i = 0; i < num; ++i) { + auto r = ins_vec[i]; + auto& uint64_feasigns = r->slot_uint64_feasigns_; + fea_num = uint64_feasigns.slot_values.size(); + if (fea_num > 0) { + memcpy(&buf_.h_uint64_keys[uint64_total_num], + uint64_feasigns.slot_values.data(), fea_num * sizeof(uint64_t)); + } + uint64_total_num += fea_num; + // copy uint64 offset + memcpy(&buf_.h_uint64_offset[i * uint64_cols], + uint64_feasigns.slot_offsets.data(), sizeof(int) * uint64_cols); + } + CHECK(uint64_total_num == static_cast(buf_.h_uint64_lens.back())) + << "uint64 value length error"; +} +void 
MiniBatchGpuPack::pack_float_data(const SlotRecord* ins_vec, int num) { + int float_total_num = 0; + + buf_.h_uint64_lens.clear(); + buf_.h_uint64_offset.clear(); + buf_.h_uint64_keys.clear(); + + buf_.h_float_lens.resize(num + 1); + buf_.h_float_lens[0] = 0; + + for (int i = 0; i < num; ++i) { + auto r = ins_vec[i]; + float_total_num += r->slot_float_feasigns_.slot_values.size(); + buf_.h_float_lens[i + 1] = float_total_num; + } + + int float_cols = (used_float_num_ + 1); + buf_.h_float_offset.resize(float_cols * num); + buf_.h_float_keys.resize(float_total_num); + + size_t fea_num = 0; + float_total_num = 0; + for (int i = 0; i < num; ++i) { + auto r = ins_vec[i]; + auto& float_feasigns = r->slot_float_feasigns_; + fea_num = float_feasigns.slot_values.size(); + memcpy(&buf_.h_float_keys[float_total_num], + float_feasigns.slot_values.data(), fea_num * sizeof(float)); + float_total_num += fea_num; + + // copy float offset + memcpy(&buf_.h_float_offset[i * float_cols], + float_feasigns.slot_offsets.data(), sizeof(int) * float_cols); + } + CHECK(float_total_num == static_cast(buf_.h_float_lens.back())) + << "float value length error"; +} + +void MiniBatchGpuPack::pack_instance(const SlotRecord* ins_vec, int num) { + ins_num_ = num; + batch_ins_ = ins_vec; + CHECK(used_uint64_num_ > 0 || used_float_num_ > 0); + // uint64 and float + if (used_uint64_num_ > 0 && used_float_num_ > 0) { + pack_all_data(ins_vec, num); + } else if (used_uint64_num_ > 0) { // uint64 + pack_uint64_data(ins_vec, num); + } else { // only float + pack_float_data(ins_vec, num); + } + // to gpu + transfer_to_gpu(); +} + +void MiniBatchGpuPack::transfer_to_gpu(void) { + copy_host2device(&value_.d_uint64_lens, buf_.h_uint64_lens); + copy_host2device(&value_.d_uint64_keys, buf_.h_uint64_keys); + copy_host2device(&value_.d_uint64_offset, buf_.h_uint64_offset); + + copy_host2device(&value_.d_float_lens, buf_.h_float_lens); + copy_host2device(&value_.d_float_keys, buf_.h_float_keys); + copy_host2device(&value_.d_float_offset, buf_.h_float_offset); + CUDA_CHECK(cudaStreamSynchronize(stream_)); +} +#endif + } // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/data_feed.cu b/paddle/fluid/framework/data_feed.cu new file mode 100644 index 00000000000..f9435ec2a32 --- /dev/null +++ b/paddle/fluid/framework/data_feed.cu @@ -0,0 +1,149 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#if defined _WIN32 || defined __APPLE__ +#else +#define _LINUX +#endif +#if defined(PADDLE_WITH_CUDA) && defined(PADDLE_WITH_HETERPS) + +#include "paddle/fluid/framework/data_feed.h" + +namespace paddle { +namespace framework { + +#define CUDA_KERNEL_LOOP(i, n) \ + for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \ + i += blockDim.x * gridDim.x) + +// CUDA: use 512 threads per block +const int CUDA_NUM_THREADS = 512; +// CUDA: number of blocks for threads. 
+inline int GET_BLOCKS(const int N) { + return (N + CUDA_NUM_THREADS - 1) / CUDA_NUM_THREADS; +} +// fill slot values +__global__ void FillSlotValueOffsetKernel( + const int ins_num, const int used_slot_num, size_t *slot_value_offsets, + const int *uint64_offsets, const int uint64_slot_size, + const int *float_offsets, const int float_slot_size, + const UsedSlotGpuType *used_slots) { + int col_num = ins_num + 1; + int uint64_cols = uint64_slot_size + 1; + int float_cols = float_slot_size + 1; + + CUDA_KERNEL_LOOP(slot_idx, used_slot_num) { + int value_off = slot_idx * col_num; + slot_value_offsets[value_off] = 0; + + auto &info = used_slots[slot_idx]; + if (info.is_uint64_value) { + for (int k = 0; k < ins_num; ++k) { + int pos = k * uint64_cols + info.slot_value_idx; + int num = uint64_offsets[pos + 1] - uint64_offsets[pos]; + PADDLE_ENFORCE(num >= 0, "The number of slot size must be ge 0."); + slot_value_offsets[value_off + k + 1] = + slot_value_offsets[value_off + k] + num; + } + } else { + for (int k = 0; k < ins_num; ++k) { + int pos = k * float_cols + info.slot_value_idx; + int num = float_offsets[pos + 1] - float_offsets[pos]; + PADDLE_ENFORCE(num >= 0, "The number of slot size must be ge 0."); + slot_value_offsets[value_off + k + 1] = + slot_value_offsets[value_off + k] + num; + } + } + } +} + +void SlotRecordInMemoryDataFeed::FillSlotValueOffset( + const int ins_num, const int used_slot_num, size_t *slot_value_offsets, + const int *uint64_offsets, const int uint64_slot_size, + const int *float_offsets, const int float_slot_size, + const UsedSlotGpuType *used_slots) { + auto stream = + dynamic_cast( + paddle::platform::DeviceContextPool::Instance().Get(this->place_)) + ->stream(); + FillSlotValueOffsetKernel<<>>( + ins_num, used_slot_num, slot_value_offsets, uint64_offsets, + uint64_slot_size, float_offsets, float_slot_size, used_slots); + cudaStreamSynchronize(stream); +} + +__global__ void CopyForTensorKernel( + const int used_slot_num, const int ins_num, void **dest, + const size_t *slot_value_offsets, const uint64_t *uint64_feas, + const int *uint64_offsets, const int *uint64_ins_lens, + const int uint64_slot_size, const float *float_feas, + const int *float_offsets, const int *float_ins_lens, + const int float_slot_size, const UsedSlotGpuType *used_slots) { + int col_num = ins_num + 1; + int uint64_cols = uint64_slot_size + 1; + int float_cols = float_slot_size + 1; + + CUDA_KERNEL_LOOP(i, ins_num * used_slot_num) { + int slot_idx = i / ins_num; + int ins_idx = i % ins_num; + + uint32_t value_offset = slot_value_offsets[slot_idx * col_num + ins_idx]; + auto &info = used_slots[slot_idx]; + if (info.is_uint64_value) { + uint64_t *up = reinterpret_cast(dest[slot_idx]); + int index = info.slot_value_idx + uint64_cols * ins_idx; + int old_off = uint64_offsets[index]; + int num = uint64_offsets[index + 1] - old_off; + PADDLE_ENFORCE(num >= 0, "The number of slot size must be ge 0."); + int uint64_value_offset = uint64_ins_lens[ins_idx]; + for (int k = 0; k < num; ++k) { + up[k + value_offset] = uint64_feas[k + old_off + uint64_value_offset]; + } + } else { + float *fp = reinterpret_cast(dest[slot_idx]); + int index = info.slot_value_idx + float_cols * ins_idx; + int old_off = float_offsets[index]; + int num = float_offsets[index + 1] - old_off; + PADDLE_ENFORCE(num >= 0, "The number of slot size must be ge 0."); + int float_value_offset = float_ins_lens[ins_idx]; + for (int k = 0; k < num; ++k) { + fp[k + value_offset] = float_feas[k + old_off + float_value_offset]; + } + } + } +} + 
+void SlotRecordInMemoryDataFeed::CopyForTensor( + const int ins_num, const int used_slot_num, void **dest, + const size_t *slot_value_offsets, const uint64_t *uint64_feas, + const int *uint64_offsets, const int *uint64_ins_lens, + const int uint64_slot_size, const float *float_feas, + const int *float_offsets, const int *float_ins_lens, + const int float_slot_size, const UsedSlotGpuType *used_slots) { + auto stream = + dynamic_cast( + paddle::platform::DeviceContextPool::Instance().Get(this->place_)) + ->stream(); + + CopyForTensorKernel<<>>( + used_slot_num, ins_num, dest, slot_value_offsets, uint64_feas, + uint64_offsets, uint64_ins_lens, uint64_slot_size, float_feas, + float_offsets, float_ins_lens, float_slot_size, used_slots); + cudaStreamSynchronize(stream); +} + +} // namespace framework +} // namespace paddle +#endif diff --git a/paddle/fluid/framework/data_feed.h b/paddle/fluid/framework/data_feed.h index eb6ed268809..6f7f1dac528 100644 --- a/paddle/fluid/framework/data_feed.h +++ b/paddle/fluid/framework/data_feed.h @@ -41,6 +41,10 @@ limitations under the License. */ #include "paddle/fluid/framework/variable.h" #include "paddle/fluid/platform/timer.h" #include "paddle/fluid/string/string_helper.h" +#if defined(PADDLE_WITH_CUDA) +#include "paddle/fluid/platform/cuda_device_guard.h" +#include "paddle/fluid/platform/device/gpu/gpu_info.h" +#endif DECLARE_int32(record_pool_max_size); DECLARE_int32(slotpool_thread_num); @@ -409,6 +413,266 @@ class CustomParser { } }; +struct UsedSlotGpuType { + int is_uint64_value; + int slot_value_idx; +}; + +#if defined(PADDLE_WITH_CUDA) && defined(PADDLE_WITH_HETERPS) +#define CUDA_CHECK(val) CHECK(val == gpuSuccess) +template +struct CudaBuffer { + T* cu_buffer; + uint64_t buf_size; + + CudaBuffer() { + cu_buffer = NULL; + buf_size = 0; + } + ~CudaBuffer() { free(); } + T* data() { return cu_buffer; } + uint64_t size() { return buf_size; } + void malloc(uint64_t size) { + buf_size = size; + CUDA_CHECK( + cudaMalloc(reinterpret_cast(&cu_buffer), size * sizeof(T))); + } + void free() { + if (cu_buffer != NULL) { + CUDA_CHECK(cudaFree(cu_buffer)); + cu_buffer = NULL; + } + buf_size = 0; + } + void resize(uint64_t size) { + if (size <= buf_size) { + return; + } + free(); + malloc(size); + } +}; +template +struct HostBuffer { + T* host_buffer; + size_t buf_size; + size_t data_len; + + HostBuffer() { + host_buffer = NULL; + buf_size = 0; + data_len = 0; + } + ~HostBuffer() { free(); } + + T* data() { return host_buffer; } + const T* data() const { return host_buffer; } + size_t size() const { return data_len; } + void clear() { free(); } + T& back() { return host_buffer[data_len - 1]; } + + T& operator[](size_t i) { return host_buffer[i]; } + const T& operator[](size_t i) const { return host_buffer[i]; } + void malloc(size_t len) { + buf_size = len; + CUDA_CHECK(cudaHostAlloc(reinterpret_cast(&host_buffer), + buf_size * sizeof(T), cudaHostAllocDefault)); + CHECK(host_buffer != NULL); + } + void free() { + if (host_buffer != NULL) { + CUDA_CHECK(cudaFreeHost(host_buffer)); + host_buffer = NULL; + } + buf_size = 0; + } + void resize(size_t size) { + if (size <= buf_size) { + data_len = size; + return; + } + data_len = size; + free(); + malloc(size); + } +}; + +struct BatchCPUValue { + HostBuffer h_uint64_lens; + HostBuffer h_uint64_keys; + HostBuffer h_uint64_offset; + + HostBuffer h_float_lens; + HostBuffer h_float_keys; + HostBuffer h_float_offset; + + HostBuffer h_rank; + HostBuffer h_cmatch; + HostBuffer h_ad_offset; +}; + +struct BatchGPUValue { + 
CudaBuffer d_uint64_lens; + CudaBuffer d_uint64_keys; + CudaBuffer d_uint64_offset; + + CudaBuffer d_float_lens; + CudaBuffer d_float_keys; + CudaBuffer d_float_offset; + + CudaBuffer d_rank; + CudaBuffer d_cmatch; + CudaBuffer d_ad_offset; +}; + +class MiniBatchGpuPack { + public: + MiniBatchGpuPack(const paddle::platform::Place& place, + const std::vector& infos); + ~MiniBatchGpuPack(); + void reset(const paddle::platform::Place& place); + void pack_instance(const SlotRecord* ins_vec, int num); + int ins_num() { return ins_num_; } + int pv_num() { return pv_num_; } + BatchGPUValue& value() { return value_; } + BatchCPUValue& cpu_value() { return buf_; } + UsedSlotGpuType* get_gpu_slots(void) { + return reinterpret_cast(gpu_slots_.data()); + } + SlotRecord* get_records(void) { return &ins_vec_[0]; } + + // tensor gpu memory reused + void resize_tensor(void) { + if (used_float_num_ > 0) { + int float_total_len = buf_.h_float_lens.back(); + if (float_total_len > 0) { + float_tensor_.mutable_data({float_total_len, 1}, this->place_); + } + } + if (used_uint64_num_ > 0) { + int uint64_total_len = buf_.h_uint64_lens.back(); + if (uint64_total_len > 0) { + uint64_tensor_.mutable_data({uint64_total_len, 1}, + this->place_); + } + } + } + LoDTensor& float_tensor(void) { return float_tensor_; } + LoDTensor& uint64_tensor(void) { return uint64_tensor_; } + + HostBuffer& offsets(void) { return offsets_; } + HostBuffer& h_tensor_ptrs(void) { return h_tensor_ptrs_; } + + void* gpu_slot_offsets(void) { return gpu_slot_offsets_->ptr(); } + + void* slot_buf_ptr(void) { return slot_buf_ptr_->ptr(); } + + void resize_gpu_slot_offsets(const size_t slot_total_bytes) { + if (gpu_slot_offsets_ == nullptr) { + gpu_slot_offsets_ = memory::AllocShared(place_, slot_total_bytes); + } else if (gpu_slot_offsets_->size() < slot_total_bytes) { + auto buf = memory::AllocShared(place_, slot_total_bytes); + gpu_slot_offsets_.swap(buf); + buf = nullptr; + } + } + const std::string& get_lineid(int idx) { + if (enable_pv_) { + return ins_vec_[idx]->ins_id_; + } + return batch_ins_[idx]->ins_id_; + } + + private: + void transfer_to_gpu(void); + void pack_all_data(const SlotRecord* ins_vec, int num); + void pack_uint64_data(const SlotRecord* ins_vec, int num); + void pack_float_data(const SlotRecord* ins_vec, int num); + + public: + template + void copy_host2device(CudaBuffer* buf, const T* val, size_t size) { + if (size == 0) { + return; + } + buf->resize(size); + CUDA_CHECK(cudaMemcpyAsync(buf->data(), val, size * sizeof(T), + cudaMemcpyHostToDevice, stream_)); + } + template + void copy_host2device(CudaBuffer* buf, const HostBuffer& val) { + copy_host2device(buf, val.data(), val.size()); + } + + private: + paddle::platform::Place place_; + cudaStream_t stream_; + BatchGPUValue value_; + BatchCPUValue buf_; + int ins_num_ = 0; + int pv_num_ = 0; + + bool enable_pv_ = false; + int used_float_num_ = 0; + int used_uint64_num_ = 0; + int used_slot_size_ = 0; + + CudaBuffer gpu_slots_; + std::vector gpu_used_slots_; + std::vector ins_vec_; + const SlotRecord* batch_ins_ = nullptr; + + // uint64 tensor + LoDTensor uint64_tensor_; + // float tensor + LoDTensor float_tensor_; + // batch + HostBuffer offsets_; + HostBuffer h_tensor_ptrs_; + + std::shared_ptr gpu_slot_offsets_ = nullptr; + std::shared_ptr slot_buf_ptr_ = nullptr; +}; +class MiniBatchGpuPackMgr { + static const int MAX_DEIVCE_NUM = 16; + + public: + MiniBatchGpuPackMgr() { + for (int i = 0; i < MAX_DEIVCE_NUM; ++i) { + pack_list_[i] = nullptr; + } + } + 
~MiniBatchGpuPackMgr() { + for (int i = 0; i < MAX_DEIVCE_NUM; ++i) { + if (pack_list_[i] == nullptr) { + continue; + } + delete pack_list_[i]; + pack_list_[i] = nullptr; + } + } + // one device one thread + MiniBatchGpuPack* get(const paddle::platform::Place& place, + const std::vector& infos) { + int device_id = place.GetDeviceId(); + if (pack_list_[device_id] == nullptr) { + pack_list_[device_id] = new MiniBatchGpuPack(place, infos); + } else { + pack_list_[device_id]->reset(place); + } + return pack_list_[device_id]; + } + + private: + MiniBatchGpuPack* pack_list_[MAX_DEIVCE_NUM]; +}; +// global mgr +inline MiniBatchGpuPackMgr& BatchGpuPackMgr() { + static MiniBatchGpuPackMgr mgr; + return mgr; +} +#endif + typedef paddle::framework::CustomParser* (*CreateParserObjectFunc)(); class DLManager { @@ -1126,7 +1390,13 @@ class MultiSlotInMemoryDataFeed : public InMemoryDataFeed { class SlotRecordInMemoryDataFeed : public InMemoryDataFeed { public: SlotRecordInMemoryDataFeed() {} - virtual ~SlotRecordInMemoryDataFeed() {} + virtual ~SlotRecordInMemoryDataFeed() { +#if defined(PADDLE_WITH_CUDA) && defined(PADDLE_WITH_HETERPS) + if (pack_ != nullptr) { + pack_ = nullptr; + } +#endif + } virtual void Init(const DataFeedDesc& data_feed_desc); virtual void LoadIntoMemory(); void ExpandSlotRecord(SlotRecord* ins); @@ -1149,6 +1419,23 @@ class SlotRecordInMemoryDataFeed : public InMemoryDataFeed { } bool ParseOneInstance(const std::string& line, SlotRecord* rec); virtual void PutToFeedVec(const SlotRecord* ins_vec, int num); + virtual void AssignFeedVar(const Scope& scope); +#if defined(PADDLE_WITH_CUDA) && defined(PADDLE_WITH_HETERPS) + void BuildSlotBatchGPU(const int ins_num); + void FillSlotValueOffset(const int ins_num, const int used_slot_num, + size_t* slot_value_offsets, + const int* uint64_offsets, + const int uint64_slot_size, const int* float_offsets, + const int float_slot_size, + const UsedSlotGpuType* used_slots); + void CopyForTensor(const int ins_num, const int used_slot_num, void** dest, + const size_t* slot_value_offsets, + const uint64_t* uint64_feas, const int* uint64_offsets, + const int* uint64_ins_lens, const int uint64_slot_size, + const float* float_feas, const int* float_offsets, + const int* float_ins_lens, const int float_slot_size, + const UsedSlotGpuType* used_slots); +#endif float sample_rate_ = 1.0f; int use_slot_size_ = 0; int float_use_slot_size_ = 0; @@ -1157,6 +1444,10 @@ class SlotRecordInMemoryDataFeed : public InMemoryDataFeed { std::vector used_slots_info_; size_t float_total_dims_size_ = 0; std::vector float_total_dims_without_inductives_; + +#if defined(PADDLE_WITH_CUDA) && defined(PADDLE_WITH_HETERPS) + MiniBatchGpuPack* pack_ = nullptr; +#endif }; class PaddleBoxDataFeed : public MultiSlotInMemoryDataFeed { diff --git a/paddle/fluid/framework/fleet/ps_gpu_wrapper.cc b/paddle/fluid/framework/fleet/ps_gpu_wrapper.cc index c7852de00a1..e167a39caa5 100755 --- a/paddle/fluid/framework/fleet/ps_gpu_wrapper.cc +++ b/paddle/fluid/framework/fleet/ps_gpu_wrapper.cc @@ -271,13 +271,13 @@ void PSGPUWrapper::PreBuildTask(std::shared_ptr gpu_task) { } timeline.Pause(); - VLOG(1) << "GpuPs task add keys cost " << timeline.ElapsedSec() + VLOG(0) << "GpuPs task add keys cost " << timeline.ElapsedSec() << " seconds."; timeline.Start(); gpu_task->UniqueKeys(); timeline.Pause(); - VLOG(1) << "GpuPs task unique cost " << timeline.ElapsedSec() << " seconds."; + VLOG(0) << "GpuPs task unique cost " << timeline.ElapsedSec() << " seconds."; if (!multi_mf_dim_) { for (int i = 0; i 
< thread_keys_shard_num_; i++) { @@ -667,7 +667,7 @@ void PSGPUWrapper::BuildGPUTask(std::shared_ptr gpu_task) { if (!multi_mf_dim_) { for (int i = 0; i < device_num; i++) { feature_keys_count[i] = gpu_task->device_keys_[i].size(); - VLOG(1) << i << " card contains feasign nums: " << feature_keys_count[i]; + VLOG(0) << i << " card contains feasign nums: " << feature_keys_count[i]; size_max = std::max(size_max, feature_keys_count[i]); } } else { @@ -675,7 +675,7 @@ void PSGPUWrapper::BuildGPUTask(std::shared_ptr gpu_task) { for (int j = 0; j < multi_mf_dim_; j++) { feature_keys_count[i] += gpu_task->device_dim_ptr_[i][j].size(); } - VLOG(1) << i << " card with dynamic mf contains feasign nums: " + VLOG(0) << i << " card with dynamic mf contains feasign nums: " << feature_keys_count[i]; size_max = std::max(size_max, feature_keys_count[i]); } @@ -685,7 +685,7 @@ void PSGPUWrapper::BuildGPUTask(std::shared_ptr gpu_task) { HeterPs_ = nullptr; } if (size_max <= 0) { - VLOG(1) << "Skip build gpu ps cause feasign nums = " << size_max; + VLOG(0) << "Skip build gpu ps cause feasign nums = " << size_max; return; } std::vector threads(device_num); @@ -707,7 +707,7 @@ void PSGPUWrapper::BuildGPUTask(std::shared_ptr gpu_task) { t.join(); } timeline.Pause(); - VLOG(1) << "GpuPs build table total costs: " << timeline.ElapsedSec() + VLOG(0) << "GpuPs build table total costs: " << timeline.ElapsedSec() << " s."; } @@ -749,7 +749,7 @@ void PSGPUWrapper::pre_build_thread() { // build cpu ps data process PreBuildTask(gpu_task); timer.Pause(); - VLOG(1) << "thread PreBuildTask end, cost time: " << timer.ElapsedSec() + VLOG(0) << "thread PreBuildTask end, cost time: " << timer.ElapsedSec() << "s"; buildcpu_ready_channel_->Put(gpu_task); } @@ -768,13 +768,13 @@ void PSGPUWrapper::build_task() { return; } - VLOG(1) << "BuildPull start."; + VLOG(0) << "BuildPull start."; platform::Timer timer; timer.Start(); BuildPull(gpu_task); BuildGPUTask(gpu_task); timer.Pause(); - VLOG(1) << "BuildPull + BuildGPUTask end, cost time: " << timer.ElapsedSec() + VLOG(0) << "BuildPull + BuildGPUTask end, cost time: " << timer.ElapsedSec() << "s"; current_task_ = gpu_task; diff --git a/paddle/fluid/framework/ps_gpu_worker.cc b/paddle/fluid/framework/ps_gpu_worker.cc index dc8935587e9..d98deb0f188 100644 --- a/paddle/fluid/framework/ps_gpu_worker.cc +++ b/paddle/fluid/framework/ps_gpu_worker.cc @@ -119,6 +119,7 @@ void PSGPUWorker::SetChannelWriter(ChannelObject* queue) { } void PSGPUWorker::TrainFiles() { + VLOG(0) << "Begin to train files"; platform::SetNumThreads(1); platform::Timer timeline; timeline.Start(); @@ -129,6 +130,8 @@ void PSGPUWorker::TrainFiles() { device_reader_->Start(); int cur_batch; int batch_cnt = 0; + + platform::SetDeviceId(thread_id_); while ((cur_batch = device_reader_->Next()) > 0) { total_ins_num += cur_batch; for (auto& op : ops_) { @@ -190,14 +193,14 @@ void PSGPUWorker::TrainFiles() { writer_.Flush(); } timeline.Pause(); - VLOG(1) << "GpuPs worker " << thread_id_ << " train cost " + VLOG(0) << "GpuPs worker " << thread_id_ << " train cost " << timeline.ElapsedSec() << " seconds, ins_num: " << total_ins_num; return; } void PSGPUWorker::TrainFilesWithProfiler() { platform::SetNumThreads(1); - VLOG(1) << "Begin to train files with profiler"; + VLOG(0) << "Begin to train files with profiler"; device_reader_->Start(); std::vector op_total_time; std::vector op_name; @@ -225,6 +228,7 @@ void PSGPUWorker::TrainFilesWithProfiler() { int total_ins_num = 0; int cur_batch; timeline.Start(); + 
platform::SetDeviceId(thread_id_); while ((cur_batch = device_reader_->Next()) > 0) { total_ins_num += cur_batch; timeline.Pause(); @@ -260,13 +264,15 @@ void PSGPUWorker::TrainFilesWithProfiler() { total_time += timeline.ElapsedSec(); timeline.Start(); } - VLOG(1) << "GpuPs worker " << thread_id_ << " train cost " << total_time + VLOG(0) << "GpuPs worker " << thread_id_ << " train cost " << total_time << " seconds, ins_num: " << total_ins_num; for (size_t i = 0; i < op_name.size(); ++i) { - VLOG(1) << "card:" << thread_id_ << ", op: " << op_name[i] + VLOG(0) << "card:" << thread_id_ << ", op: " << op_name[i] << ", mean time: " << op_total_time[i] / total_ins_num << "s, totol time:" << op_total_time[i] << "sec"; } + VLOG(0) << "card: " << thread_id_ << " read time: " << read_time + << ", percent: " << read_time / total_time * 100; return; } -- GitLab
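
Note (illustration only, not part of the patch): the patch itself carries no prose, so below is a small host-side sketch of the two-step layout transform that FillSlotValueOffsetKernel and CopyForTensorKernel perform on the GPU. The toy batch and all names in this sketch are invented for the example; it only mirrors the data layout used by MiniBatchGpuPack (per-instance slot offsets with `slot_num + 1` columns, plus cumulative instance lengths). Step one builds, for every slot, an `(ins_num + 1)`-entry offset row over the whole batch; step two uses those offsets to scatter each instance's feature values into one contiguous per-slot buffer that the feed tensors then share.

// Standalone C++ sketch (assumptions: one uint64 slot type only, toy data).
#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
  // 3 instances, 2 uint64 slots.
  // instance 0: slot0 -> {11,12}, slot1 -> {21}
  // instance 1: slot0 -> {},      slot1 -> {22,23,24}
  // instance 2: slot0 -> {13},    slot1 -> {}
  const int ins_num = 3, slot_num = 2, cols = slot_num + 1;
  std::vector<uint64_t> keys = {11, 12, 21, 22, 23, 24, 13};
  // Per-instance slot offsets (within each instance), like h_uint64_offset.
  std::vector<int> offsets = {0, 2, 3,   // instance 0
                              0, 0, 3,   // instance 1
                              0, 1, 1};  // instance 2
  // Cumulative instance lengths, like h_uint64_lens; only the starts are read.
  std::vector<int> ins_lens = {0, 3, 6, 7};

  // Step 1: per-slot value offsets over the batch (FillSlotValueOffsetKernel).
  std::vector<size_t> slot_value_offsets(slot_num * (ins_num + 1), 0);
  for (int s = 0; s < slot_num; ++s) {
    for (int k = 0; k < ins_num; ++k) {
      int pos = k * cols + s;
      int num = offsets[pos + 1] - offsets[pos];
      slot_value_offsets[s * (ins_num + 1) + k + 1] =
          slot_value_offsets[s * (ins_num + 1) + k] + num;
    }
  }

  // Step 2: scatter values into one contiguous buffer per slot
  // (CopyForTensorKernel).
  std::vector<std::vector<uint64_t>> slot_tensors(slot_num);
  for (int s = 0; s < slot_num; ++s) {
    slot_tensors[s].resize(slot_value_offsets[s * (ins_num + 1) + ins_num]);
  }
  for (int s = 0; s < slot_num; ++s) {
    for (int k = 0; k < ins_num; ++k) {
      size_t dst = slot_value_offsets[s * (ins_num + 1) + k];
      int pos = k * cols + s;
      int old_off = offsets[pos];
      int num = offsets[pos + 1] - old_off;
      for (int j = 0; j < num; ++j) {
        slot_tensors[s][dst + j] = keys[ins_lens[k] + old_off + j];
      }
    }
  }

  // Prints: slot 0: 11 12 13 / slot 1: 21 22 23 24
  for (int s = 0; s < slot_num; ++s) {
    std::printf("slot %d:", s);
    for (uint64_t v : slot_tensors[s]) {
      std::printf(" %llu", static_cast<unsigned long long>(v));
    }
    std::printf("\n");
  }
  return 0;
}

In the patch the same work is parallelized with one CUDA thread per slot for the offset pass and one thread per (slot, instance) pair for the copy pass, which is what replaces the per-slot CPU loops of the old PutToFeedVec path on the HeterPS build.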