Unverified commit effe2c11, authored by pangengzheng, committed by GitHub

Speedup datafeed (#51624)

Parent 6f86c96b
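
At a high level, this commit replaces the single pack_ member of SlotRecordInMemoryDataFeed with a pool of MiniBatchGpuPack objects: pack_thread_num_ background threads claim batch indices, pack each batch on the GPU using the pack's own CUDA stream, and push finished packs into using_pack_queue_, while Next() only drains that queue and recycles packs through free_pack_queue_. Below is a minimal, self-contained C++ sketch of that producer/consumer shape, assuming a toy Pack struct and a simple blocking queue as stand-ins for MiniBatchGpuPack and paddle's BlockingQueue; it is illustrative only and not part of the diff.

#include <atomic>
#include <chrono>
#include <condition_variable>
#include <cstdint>
#include <mutex>
#include <queue>
#include <thread>
#include <vector>

struct Pack { int batch_size = 0; };  // toy stand-in for MiniBatchGpuPack

template <typename T>
class SimpleQueue {  // toy stand-in for paddle's BlockingQueue<T*>
 public:
  void Push(T* v) {
    { std::lock_guard<std::mutex> lock(mu_); q_.push(v); }
    cv_.notify_one();
  }
  T* Pop() {  // blocks until an element is available
    std::unique_lock<std::mutex> lock(mu_);
    cv_.wait(lock, [this] { return !q_.empty(); });
    T* v = q_.front();
    q_.pop();
    return v;
  }
  size_t Size() {
    std::lock_guard<std::mutex> lock(mu_);
    return q_.size();
  }
 private:
  std::mutex mu_;
  std::condition_variable cv_;
  std::queue<T*> q_;
};

int main() {
  const int kPackThreads = 2;                 // mirrors pack_thread_num_
  std::vector<int> batch_offsets(8, 128);     // pretend batch sizes
  SimpleQueue<Pack> free_pack_queue, using_pack_queue;
  std::vector<Pack> packs(kPackThreads + 1);  // Start() allocates pack_thread_num_ + 1 packs
  for (auto& p : packs) free_pack_queue.Push(&p);

  std::atomic<uint64_t> pack_offset_index{0};
  std::atomic<int> thread_count{kPackThreads};
  std::atomic<bool> pack_is_end{false};

  std::vector<std::thread> pack_threads;
  for (int t = 0; t < kPackThreads; ++t) {
    pack_threads.emplace_back([&] {
      while (true) {
        uint64_t i = pack_offset_index.fetch_add(1);
        if (i >= batch_offsets.size()) {
          if (thread_count.fetch_sub(1) == 1) pack_is_end.store(true);
          return;
        }
        Pack* pack = free_pack_queue.Pop();    // take a recycled pack
        pack->batch_size = batch_offsets[i];   // stands in for pack_instance + BuildSlotBatchGPU
        using_pack_queue.Push(pack);           // hand the ready batch to the consumer
      }
    });
  }

  // Consumer side, mirroring Next(): drain ready packs and recycle each one afterwards.
  int batches = 0;
  while (true) {
    if (using_pack_queue.Size() != 0) {
      Pack* pack = using_pack_queue.Pop();  // real code copies the pack's tensors into the scope here
      ++batches;
      free_pack_queue.Push(pack);           // real code recycles last_pack_ on the next call
      continue;
    }
    if (pack_is_end.load() && using_pack_queue.Size() == 0) break;
    std::this_thread::sleep_for(std::chrono::microseconds(200));
  }
  for (auto& th : pack_threads) th.join();
  return batches == static_cast<int>(batch_offsets.size()) ? 0 : 1;
}
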
......@@ -2001,6 +2001,20 @@ void PaddleBoxDataFeed::PutToFeedVec(const std::vector<Record*>& ins_vec) {
#endif
}
SlotRecordInMemoryDataFeed::~SlotRecordInMemoryDataFeed() {
#if defined(PADDLE_WITH_CUDA) && defined(PADDLE_WITH_HETERPS)
stop_token_.store(true);
for (auto& thread : pack_threads_) {
if (thread.joinable()) {
thread.join();
}
}
for (auto* pack : pack_vec_) {
pack->set_use_flag(false);
}
#endif
}
template class InMemoryDataFeed<SlotRecord>;
void SlotRecordInMemoryDataFeed::Init(const DataFeedDesc& data_feed_desc) {
finish_init_ = false;
......@@ -2513,9 +2527,7 @@ void SlotRecordInMemoryDataFeed::PutToFeedVec(const SlotRecord* ins_vec,
}
}
#if defined(PADDLE_WITH_CUDA) && defined(PADDLE_WITH_HETERPS)
paddle::platform::SetDeviceId(place_.GetDeviceId());
pack_->pack_instance(ins_vec, num);
BuildSlotBatchGPU(pack_->ins_num());
// do nothing
#else
for (int j = 0; j < use_slot_size_; ++j) {
auto& feed = feed_vec_[j];
......@@ -2658,7 +2670,7 @@ void SlotRecordInMemoryDataFeed::ExpandSlotRecord(SlotRecord* rec) {
}
bool SlotRecordInMemoryDataFeed::Start() {
VLOG(4) << "entering SlotRecordInMemoryDataFeed::Start";
VLOG(3) << "entering SlotRecordInMemoryDataFeed::Start";
#ifdef _LINUX
this->CheckSetFileList();
if (input_channel_->Size() != 0) {
......@@ -2674,7 +2686,40 @@ bool SlotRecordInMemoryDataFeed::Start() {
this->finish_start_ = true;
#if defined(PADDLE_WITH_CUDA) && defined(PADDLE_WITH_HETERPS)
CHECK(paddle::platform::is_gpu_place(this->place_));
pack_ = BatchGpuPackMgr().get(this->GetPlace(), used_slots_info_);
for (int i = 0; i < pack_thread_num_ + 1; i++) {
auto pack = BatchGpuPackMgr().get(this->GetPlace(), used_slots_info_);
pack_vec_.push_back(pack);
free_pack_queue_.Push(pack);
}
pack_offset_index_.store(0);
pack_is_end_.store(false);
thread_count_.store(pack_thread_num_);
pack_threads_.reserve(pack_thread_num_);
for (int i = 0; i < pack_thread_num_; i++) {
pack_threads_.emplace_back(std::thread([this]() -> void {
while (!stop_token_.load()) {
uint64_t offset_index = pack_offset_index_.fetch_add(1);
if (offset_index >= batch_offsets_.size()) {
int thread_num = thread_count_.fetch_sub(1);
if (thread_num == 1) {
pack_is_end_.store(true);
}
return;
}
auto* pack = free_pack_queue_.Pop();
auto& batch = batch_offsets_[offset_index];
auto offset = batch.first;
auto batch_size = batch.second;
paddle::platform::SetDeviceId(place_.GetDeviceId());
pack->pack_instance(&records_[offset], batch_size);
this->BuildSlotBatchGPU(batch_size, pack);
using_pack_queue_.Push(pack);
}
}));
}
#endif
#if defined(PADDLE_WITH_GPU_GRAPH) && defined(PADDLE_WITH_HETERPS)
gpu_graph_data_generator_.SetFeedVec(feed_vec_);
......@@ -2686,6 +2731,27 @@ int SlotRecordInMemoryDataFeed::Next() {
#ifdef _LINUX
this->CheckStart();
if (!gpu_graph_mode_) {
#if defined(PADDLE_WITH_CUDA) && defined(PADDLE_WITH_HETERPS)
while (true) {
if (last_pack_ != nullptr) {
free_pack_queue_.Push(last_pack_);
last_pack_ = nullptr;
}
if (using_pack_queue_.Size() != 0) {
auto* pack = using_pack_queue_.Pop();
PackToScope(pack);
last_pack_ = pack;
return pack->ins_num();
}
bool is_end = pack_is_end_.load();
if (is_end) {
if (using_pack_queue_.Size() == 0) {
return 0;
}
}
std::this_thread::sleep_for(std::chrono::microseconds(200));
}
#else
VLOG(3) << "enable heter next: " << offset_index_
<< " batch_offsets: " << batch_offsets_.size();
if (offset_index_ >= batch_offsets_.size()) {
......@@ -2703,9 +2769,7 @@ int SlotRecordInMemoryDataFeed::Next() {
VLOG(3) << "finish reading for heterps, batch size zero, thread_id="
<< thread_id_;
}
VLOG(3) << "enable heter next: " << offset_index_
<< " batch_offsets: " << batch_offsets_.size()
<< " baych_size: " << this->batch_size_;
#endif
} else {
VLOG(3) << "datafeed in gpu graph mode";
#if defined(PADDLE_WITH_GPU_GRAPH) && defined(PADDLE_WITH_HETERPS)
......@@ -2736,47 +2800,59 @@ void SlotRecordInMemoryDataFeed::DumpWalkPath(std::string dump_path,
}
#if defined(PADDLE_WITH_CUDA) && defined(PADDLE_WITH_HETERPS)
void SlotRecordInMemoryDataFeed::BuildSlotBatchGPU(const int ins_num) {
void SlotRecordInMemoryDataFeed::BuildSlotBatchGPU(const int ins_num,
MiniBatchGpuPack* pack) {
int offset_cols_size = (ins_num + 1);
size_t slot_total_num = (use_slot_size_ * offset_cols_size);
pack_->resize_gpu_slot_offsets(slot_total_num * sizeof(size_t));
pack->resize_gpu_slot_offsets(slot_total_num * sizeof(size_t));
auto& value = pack_->value();
auto& value = pack->value();
const UsedSlotGpuType* used_slot_gpu_types =
static_cast<const UsedSlotGpuType*>(pack_->get_gpu_slots());
static_cast<const UsedSlotGpuType*>(pack->get_gpu_slots());
FillSlotValueOffset(ins_num,
use_slot_size_,
reinterpret_cast<size_t*>(pack_->gpu_slot_offsets()),
reinterpret_cast<size_t*>(pack->gpu_slot_offsets()),
value.d_uint64_offset.data(),
uint64_use_slot_size_,
value.d_float_offset.data(),
float_use_slot_size_,
used_slot_gpu_types);
size_t* d_slot_offsets = reinterpret_cast<size_t*>(pack_->gpu_slot_offsets());
used_slot_gpu_types,
pack->get_stream());
size_t* d_slot_offsets = reinterpret_cast<size_t*>(pack->gpu_slot_offsets());
HostBuffer<size_t>& offsets = pack_->offsets();
HostBuffer<size_t>& offsets = pack->offsets();
offsets.resize(slot_total_num);
HostBuffer<void*>& h_tensor_ptrs = pack_->h_tensor_ptrs();
HostBuffer<void*>& h_tensor_ptrs = pack->h_tensor_ptrs();
h_tensor_ptrs.resize(use_slot_size_);
// alloc gpu memory
pack_->resize_tensor();
pack->resize_tensor();
phi::DenseTensor& float_tensor = pack_->float_tensor();
phi::DenseTensor& uint64_tensor = pack_->uint64_tensor();
phi::DenseTensor& float_tensor = pack->float_tensor();
phi::DenseTensor& uint64_tensor = pack->uint64_tensor();
int64_t float_offset = 0;
int64_t uint64_offset = 0;
size_t float_zero_slot_index = 0;
size_t uint64_zero_slot_index = 0;
// copy index
CUDA_CHECK(cudaMemcpy(offsets.data(),
d_slot_offsets,
slot_total_num * sizeof(size_t),
cudaMemcpyDeviceToHost));
auto* dev_ctx = static_cast<phi::GPUContext*>(
platform::DeviceContextPool::Instance().Get(this->place_));
for (int j = 0; j < use_slot_size_; ++j) {
auto& feed = feed_vec_[j];
if (feed == nullptr) {
h_tensor_ptrs[j] = nullptr;
continue;
if (scpoe_feed_vec_.size() > 0) {
if (scpoe_feed_vec_.begin()->second[j] == nullptr) {
h_tensor_ptrs[j] = nullptr;
continue;
}
} else {
if (feed_vec_[j] == nullptr) {
h_tensor_ptrs[j] = nullptr;
continue;
}
}
size_t* off_start_ptr = &offsets[j * offset_cols_size];
......@@ -2786,6 +2862,85 @@ void SlotRecordInMemoryDataFeed::BuildSlotBatchGPU(const int ins_num) {
<< "slot idx:" << j << ", total instance:" << total_instance;
auto& info = used_slots_info_[j];
// fill slot value with default value 0
if (info.type[0] == 'f') { // float
if (total_instance > 0) {
h_tensor_ptrs[j] = float_tensor.data<float>() + float_offset;
float_offset += total_instance;
} else {
phi::DenseTensor& f_tensor =
pack->float_tensor_vec()[float_zero_slot_index];
f_tensor.Resize({total_instance, 1});
dev_ctx->Alloc<float>(&f_tensor);
h_tensor_ptrs[j] = f_tensor.data<float>();
float_zero_slot_index++;
}
} else if (info.type[0] == 'u') { // uint64
if (total_instance > 0) {
h_tensor_ptrs[j] = uint64_tensor.data<int64_t>() + uint64_offset;
uint64_offset += total_instance;
} else {
phi::DenseTensor& i_tensor =
pack->uint64_tensor_vec()[uint64_zero_slot_index];
i_tensor.Resize({total_instance, 1});
dev_ctx->Alloc<int64_t>(&i_tensor);
h_tensor_ptrs[j] = i_tensor.data<int64_t>();
uint64_zero_slot_index++;
}
}
}
void** dest_gpu_p = reinterpret_cast<void**>(pack->slot_buf_ptr());
CUDA_CHECK(cudaMemcpyAsync(dest_gpu_p,
h_tensor_ptrs.data(),
use_slot_size_ * sizeof(void*),
cudaMemcpyHostToDevice,
pack->get_stream()));
CopyForTensor(ins_num,
use_slot_size_,
dest_gpu_p,
(const size_t*)pack->gpu_slot_offsets(),
(const uint64_t*)value.d_uint64_keys.data(),
(const int*)value.d_uint64_offset.data(),
(const int*)value.d_uint64_lens.data(),
uint64_use_slot_size_,
(const float*)value.d_float_keys.data(),
(const int*)value.d_float_offset.data(),
(const int*)value.d_float_lens.data(),
float_use_slot_size_,
used_slot_gpu_types,
pack->get_stream());
}
void SlotRecordInMemoryDataFeed::PackToScope(MiniBatchGpuPack* pack,
const Scope* scope) {
int64_t float_offset = 0;
int64_t uint64_offset = 0;
size_t float_zero_slot_index = 0;
size_t uint64_zero_slot_index = 0;
int offset_cols_size = (pack->ins_num() + 1);
HostBuffer<size_t>& offsets = pack->offsets();
phi::DenseTensor& float_tensor = pack->float_tensor();
phi::DenseTensor& uint64_tensor = pack->uint64_tensor();
auto* feed_vec = &feed_vec_;
if (scope) {
CHECK(scpoe_feed_vec_.count(scope) > 0) << "scope not found.";
feed_vec = &scpoe_feed_vec_[scope];
}
CHECK(feed_vec != nullptr) << "feed_vec nullptr.";
for (int j = 0; j < use_slot_size_; ++j) {
auto& feed = (*feed_vec)[j];
if (feed == nullptr) {
continue;
}
size_t* off_start_ptr = &offsets[j * offset_cols_size];
int total_instance = static_cast<int>(off_start_ptr[offset_cols_size - 1]);
auto& info = used_slots_info_[j];
// fill slot value with default value 0
if (info.type[0] == 'f') { // float
if (total_instance > 0) {
......@@ -2794,10 +2949,9 @@ void SlotRecordInMemoryDataFeed::BuildSlotBatchGPU(const int ins_num) {
static_cast<int64_t>(float_offset + total_instance)));
feed->Resize({total_instance, 1});
float_offset += total_instance;
h_tensor_ptrs[j] = feed->mutable_data<float>(this->place_);
} else {
h_tensor_ptrs[j] =
feed->mutable_data<float>({total_instance, 1}, this->place_);
feed->ShareDataWith(pack->float_tensor_vec()[float_zero_slot_index++]);
feed->Resize({total_instance, 1});
}
} else if (info.type[0] == 'u') { // uint64
if (total_instance > 0) {
......@@ -2806,10 +2960,10 @@ void SlotRecordInMemoryDataFeed::BuildSlotBatchGPU(const int ins_num) {
static_cast<int64_t>(uint64_offset + total_instance)));
feed->Resize({total_instance, 1});
uint64_offset += total_instance;
h_tensor_ptrs[j] = feed->mutable_data<int64_t>(this->place_);
} else {
h_tensor_ptrs[j] =
feed->mutable_data<int64_t>({total_instance, 1}, this->place_);
feed->ShareDataWith(
pack->uint64_tensor_vec()[uint64_zero_slot_index++]);
feed->Resize({total_instance, 1});
}
}
......@@ -2829,33 +2983,14 @@ void SlotRecordInMemoryDataFeed::BuildSlotBatchGPU(const int ins_num) {
offset_cols_size * sizeof(size_t));
}
}
void** dest_gpu_p = reinterpret_cast<void**>(pack_->slot_buf_ptr());
CUDA_CHECK(cudaMemcpy(dest_gpu_p,
h_tensor_ptrs.data(),
use_slot_size_ * sizeof(void*),
cudaMemcpyHostToDevice));
CopyForTensor(ins_num,
use_slot_size_,
dest_gpu_p,
(const size_t*)pack_->gpu_slot_offsets(),
(const uint64_t*)value.d_uint64_keys.data(),
(const int*)value.d_uint64_offset.data(),
(const int*)value.d_uint64_lens.data(),
uint64_use_slot_size_,
(const float*)value.d_float_keys.data(),
(const int*)value.d_float_offset.data(),
(const int*)value.d_float_lens.data(),
float_use_slot_size_,
used_slot_gpu_types);
}
MiniBatchGpuPack::MiniBatchGpuPack(const paddle::platform::Place& place,
const std::vector<UsedSlotInfo>& infos) {
const std::vector<UsedSlotInfo>& infos,
phi::StreamId stream_id) {
place_ = place;
stream_ = dynamic_cast<phi::GPUContext*>(
platform::DeviceContextPool::Instance().Get(place))
->stream();
stream_holder_.reset(new phi::CUDAStream(place));
stream_ = stream_holder_->raw_stream();
ins_num_ = 0;
pv_num_ = 0;
......@@ -2881,15 +3016,16 @@ MiniBatchGpuPack::MiniBatchGpuPack(const paddle::platform::Place& place,
VLOG(3) << "begin get batch pack device id: " << device_id;
// sync
CUDA_CHECK(cudaStreamSynchronize(stream_));
float_tensor_vec_.resize(used_slot_size_);
uint64_tensor_vec_.resize(used_slot_size_);
}
MiniBatchGpuPack::~MiniBatchGpuPack() {}
void MiniBatchGpuPack::reset(const paddle::platform::Place& place) {
place_ = place;
stream_ = dynamic_cast<phi::GPUContext*>(
platform::DeviceContextPool::Instance().Get(place))
->stream();
stream_holder_.reset(new phi::CUDAStream(place));
stream_ = stream_holder_->raw_stream();
ins_num_ = 0;
pv_num_ = 0;
}
......
......@@ -320,11 +320,8 @@ void SlotRecordInMemoryDataFeed::FillSlotValueOffset(
const int uint64_slot_size,
const int *float_offsets,
const int float_slot_size,
const UsedSlotGpuType *used_slots) {
auto stream =
dynamic_cast<phi::GPUContext *>(
paddle::platform::DeviceContextPool::Instance().Get(this->place_))
->stream();
const UsedSlotGpuType *used_slots,
cudaStream_t stream) {
FillSlotValueOffsetKernel<<<GET_BLOCKS(used_slot_num),
CUDA_NUM_THREADS,
0,
......@@ -399,12 +396,8 @@ void SlotRecordInMemoryDataFeed::CopyForTensor(
const int *float_offsets,
const int *float_ins_lens,
const int float_slot_size,
const UsedSlotGpuType *used_slots) {
auto stream =
dynamic_cast<phi::GPUContext *>(
paddle::platform::DeviceContextPool::Instance().Get(this->place_))
->stream();
const UsedSlotGpuType *used_slots,
cudaStream_t stream) {
CopyForTensorKernel<<<GET_BLOCKS(used_slot_num * ins_num),
CUDA_NUM_THREADS,
0,
......
......@@ -46,6 +46,7 @@ limitations under the License. */
#include "paddle/fluid/framework/fleet/heter_ps/gpu_graph_utils.h"
#include "paddle/fluid/platform/cuda_device_guard.h"
#include "paddle/fluid/platform/device/gpu/gpu_info.h"
#include "paddle/phi/core/cuda_stream.h"
#endif
DECLARE_int32(record_pool_max_size);
......@@ -535,8 +536,11 @@ struct BatchGPUValue {
class MiniBatchGpuPack {
public:
MiniBatchGpuPack(const paddle::platform::Place& place,
const std::vector<UsedSlotInfo>& infos);
const std::vector<UsedSlotInfo>& infos,
phi::StreamId stream_id);
~MiniBatchGpuPack();
bool is_use() { return is_using_; }
void set_use_flag(bool is_use) { is_using_ = is_use; }
void reset(const paddle::platform::Place& place);
void pack_instance(const SlotRecord* ins_vec, int num);
int ins_num() { return ins_num_; }
......@@ -566,6 +570,12 @@ class MiniBatchGpuPack {
}
phi::DenseTensor& float_tensor(void) { return float_tensor_; }
phi::DenseTensor& uint64_tensor(void) { return uint64_tensor_; }
std::vector<phi::DenseTensor>& float_tensor_vec(void) {
return float_tensor_vec_;
}
std::vector<phi::DenseTensor>& uint64_tensor_vec(void) {
return uint64_tensor_vec_;
}
HostBuffer<size_t>& offsets(void) { return offsets_; }
HostBuffer<void*>& h_tensor_ptrs(void) { return h_tensor_ptrs_; }
......@@ -590,6 +600,8 @@ class MiniBatchGpuPack {
return batch_ins_[idx]->ins_id_;
}
cudaStream_t get_stream() { return stream_; }
private:
void transfer_to_gpu(void);
void pack_all_data(const SlotRecord* ins_vec, int num);
......@@ -612,7 +624,9 @@ class MiniBatchGpuPack {
}
private:
bool is_using_ = false;
paddle::platform::Place place_;
std::unique_ptr<phi::CUDAStream> stream_holder_;
cudaStream_t stream_;
BatchGPUValue value_;
BatchCPUValue buf_;
......@@ -631,8 +645,10 @@ class MiniBatchGpuPack {
// uint64 tensor
phi::DenseTensor uint64_tensor_;
std::vector<phi::DenseTensor> uint64_tensor_vec_;
// float tensor
phi::DenseTensor float_tensor_;
std::vector<phi::DenseTensor> float_tensor_vec_;
// batch
HostBuffer<size_t> offsets_;
HostBuffer<void*> h_tensor_ptrs_;
......@@ -645,33 +661,52 @@ class MiniBatchGpuPackMgr {
public:
MiniBatchGpuPackMgr() {
pack_list_.resize(MAX_DEIVCE_NUM);
for (int i = 0; i < MAX_DEIVCE_NUM; ++i) {
pack_list_[i] = nullptr;
pack_list_[i].clear();
}
}
~MiniBatchGpuPackMgr() {
for (int i = 0; i < MAX_DEIVCE_NUM; ++i) {
if (pack_list_[i] == nullptr) {
continue;
for (size_t j = 0; j < pack_list_[i].size(); j++) {
if (pack_list_[i][j] == nullptr) {
continue;
}
delete pack_list_[i][j];
pack_list_[i][j] = nullptr;
}
delete pack_list_[i];
pack_list_[i] = nullptr;
}
}
// one device one thread
// thread unsafe
MiniBatchGpuPack* get(const paddle::platform::Place& place,
const std::vector<UsedSlotInfo>& infos) {
int device_id = place.GetDeviceId();
if (pack_list_[device_id] == nullptr) {
pack_list_[device_id] = new MiniBatchGpuPack(place, infos);
} else {
pack_list_[device_id]->reset(place);
for (size_t i = 0; i < pack_list_[device_id].size(); i++) {
if (!pack_list_[device_id][i]->is_use()) {
pack_list_[device_id][i]->set_use_flag(true);
pack_list_[device_id][i]->reset(place);
return pack_list_[device_id][i];
}
}
return pack_list_[device_id];
{
std::lock_guard<std::mutex> lock(mutex_);
if (!alloc_stream_map_.count(device_id)) {
alloc_stream_map_.emplace(device_id, new phi::CUDAStream(place));
}
}
phi::StreamId alloc_stream_id = reinterpret_cast<phi::StreamId>(
alloc_stream_map_[device_id]->raw_stream());
auto* pack = new MiniBatchGpuPack(place, infos, alloc_stream_id);
pack->set_use_flag(true);
pack_list_[device_id].push_back(pack);
return pack;
}
private:
MiniBatchGpuPack* pack_list_[MAX_DEIVCE_NUM];
std::vector<std::vector<MiniBatchGpuPack*>> pack_list_;
std::unordered_map<int, std::unique_ptr<phi::CUDAStream>> alloc_stream_map_;
std::mutex mutex_;
};
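
A short, hedged usage sketch for the pool-based manager above, mirroring how Start() and the feed destructor in this diff acquire and release packs. AcquirePacks/ReleasePacks are hypothetical helpers written as if inside this header, not Paddle APIs.

// Illustrative only: a feed grabs pack_thread_num_ + 1 packs up front and
// returns them by clearing the use flag; the manager keeps ownership.
std::vector<MiniBatchGpuPack*> AcquirePacks(const paddle::platform::Place& place,
                                            const std::vector<UsedSlotInfo>& infos,
                                            int pack_thread_num) {
  std::vector<MiniBatchGpuPack*> packs;
  for (int i = 0; i < pack_thread_num + 1; ++i) {
    packs.push_back(BatchGpuPackMgr().get(place, infos));  // marked in-use inside get()
  }
  return packs;
}

void ReleasePacks(const std::vector<MiniBatchGpuPack*>& packs) {
  for (auto* pack : packs) {
    pack->set_use_flag(false);  // back to the per-device pool for reuse by the next get()
  }
}
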
// global mgr
inline MiniBatchGpuPackMgr& BatchGpuPackMgr() {
......@@ -1212,6 +1247,13 @@ class DataFeed {
}
virtual const paddle::platform::Place& GetPlace() const { return place_; }
#if defined(PADDLE_WITH_CUDA) && defined(PADDLE_WITH_HETERPS)
virtual void PackToScope(MiniBatchGpuPack* pack, const Scope* scope) {
PADDLE_THROW(platform::errors::Unimplemented(
"This function(PackToScope) is not implemented."));
}
#endif
virtual void DumpWalkPath(std::string dump_path, size_t dump_rate) {
PADDLE_THROW(platform::errors::Unimplemented(
"This function(DumpWalkPath) is not implemented."));
......@@ -1766,13 +1808,7 @@ class MultiSlotInMemoryDataFeed : public InMemoryDataFeed<Record> {
class SlotRecordInMemoryDataFeed : public InMemoryDataFeed<SlotRecord> {
public:
SlotRecordInMemoryDataFeed() {}
virtual ~SlotRecordInMemoryDataFeed() {
#if defined(PADDLE_WITH_CUDA) && defined(PADDLE_WITH_HETERPS)
if (pack_ != nullptr) {
pack_ = nullptr;
}
#endif
}
virtual ~SlotRecordInMemoryDataFeed();
virtual void Init(const DataFeedDesc& data_feed_desc);
virtual void LoadIntoMemory();
void ExpandSlotRecord(SlotRecord* ins);
......@@ -1797,7 +1833,11 @@ class SlotRecordInMemoryDataFeed : public InMemoryDataFeed<SlotRecord> {
virtual void PutToFeedVec(const SlotRecord* ins_vec, int num);
virtual void AssignFeedVar(const Scope& scope);
#if defined(PADDLE_WITH_CUDA) && defined(PADDLE_WITH_HETERPS)
void BuildSlotBatchGPU(const int ins_num);
void BuildSlotBatchGPU(const int ins_num, MiniBatchGpuPack* pack);
virtual void PackToScope(MiniBatchGpuPack* pack,
const Scope* scope = nullptr);
void FillSlotValueOffset(const int ins_num,
const int used_slot_num,
size_t* slot_value_offsets,
......@@ -1805,7 +1845,8 @@ class SlotRecordInMemoryDataFeed : public InMemoryDataFeed<SlotRecord> {
const int uint64_slot_size,
const int* float_offsets,
const int float_slot_size,
const UsedSlotGpuType* used_slots);
const UsedSlotGpuType* used_slots,
cudaStream_t stream);
void CopyForTensor(const int ins_num,
const int used_slot_num,
void** dest,
......@@ -1818,7 +1859,8 @@ class SlotRecordInMemoryDataFeed : public InMemoryDataFeed<SlotRecord> {
const int* float_offsets,
const int* float_ins_lens,
const int float_slot_size,
const UsedSlotGpuType* used_slots);
const UsedSlotGpuType* used_slots,
cudaStream_t stream);
#endif
#if defined(PADDLE_WITH_GPU_GRAPH) && defined(PADDLE_WITH_HETERPS)
......@@ -1838,7 +1880,20 @@ class SlotRecordInMemoryDataFeed : public InMemoryDataFeed<SlotRecord> {
std::vector<int> float_total_dims_without_inductives_;
#if defined(PADDLE_WITH_CUDA) && defined(PADDLE_WITH_HETERPS)
MiniBatchGpuPack* pack_ = nullptr;
int pack_thread_num_{5};
std::vector<std::thread> pack_threads_;
std::vector<MiniBatchGpuPack*> pack_vec_;
BlockingQueue<MiniBatchGpuPack*> free_pack_queue_;
BlockingQueue<MiniBatchGpuPack*> using_pack_queue_;
std::atomic<bool> pack_is_end_{false};
std::atomic<uint64_t> pack_offset_index_{0};
MiniBatchGpuPack* last_pack_{nullptr};
std::atomic<bool> stop_token_{false};
std::atomic<int> thread_count_{0};
std::mutex pack_mutex_;
// async infershape
std::map<const Scope*, std::vector<phi::DenseTensor*>> scpoe_feed_vec_;
#endif
};
......
......@@ -99,11 +99,13 @@ class CommonFeatureValueAccessor {
// total length computed from mf_dim
__host__ __device__ int Dim(int mf_dim) {
int tmp_embedx_sgd_dim = 1;
int tmp_embedx_sgd_dim = 1; // shared adagrad
if (optimizer_type_ == 3) { // adam
tmp_embedx_sgd_dim = mf_dim * 2 + 2;
} else if (optimizer_type_ == 4) { // shared_adam
tmp_embedx_sgd_dim = 4;
} else if (optimizer_type_ == 2) {
tmp_embedx_sgd_dim = mf_dim;
}
return 9 + embed_sgd_dim + tmp_embedx_sgd_dim + mf_dim;
}
......@@ -115,11 +117,13 @@ class CommonFeatureValueAccessor {
// mf_size in bytes, computed from mf_dim
__host__ __device__ size_t MFSize(int mf_dim) {
int tmp_embedx_sgd_dim = 1;
int tmp_embedx_sgd_dim = 1; // shared adagrad
if (optimizer_type_ == 3) { // adam
tmp_embedx_sgd_dim = mf_dim * 2 + 2;
} else if (optimizer_type_ == 4) { // shared_adam
tmp_embedx_sgd_dim = 4;
} else if (optimizer_type_ == 2) { // std adagrad
tmp_embedx_sgd_dim = mf_dim;
}
return (tmp_embedx_sgd_dim + mf_dim) * sizeof(float);
}
......@@ -127,12 +131,14 @@ class CommonFeatureValueAccessor {
__host__ __device__ int EmbedxG2SumOffsetIndex() { return 0; }
__host__ __device__ int EmbedxWOffsetIndex(float* val) {
// has mf
int tmp_embedx_sgd_dim = 1;
int tmp_embedx_sgd_dim = 1; // shared adagrad
if (static_cast<int>(MfSize(val)) > 0) {
if (optimizer_type_ == 3) { // adam
tmp_embedx_sgd_dim = MfDim(val) * 2 + 2;
} else if (optimizer_type_ == 4) { // shared_adam
tmp_embedx_sgd_dim = 4;
} else if (optimizer_type_ == 2) { // std adagrad
tmp_embedx_sgd_dim = static_cast<int>(MfDim(val));
}
return EmbedxG2SumIndex() + tmp_embedx_sgd_dim;
} else {
......
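
For reference, a hedged summary of the per-feature embedx optimizer-state widths implied by the branches above; EmbedxSgdDim is a made-up helper for illustration, not part of the diff.

// Per-feature embedx sgd state length by optimizer_type_, as selected above.
inline int EmbedxSgdDim(int optimizer_type, int mf_dim) {
  if (optimizer_type == 3) return mf_dim * 2 + 2;  // adam
  if (optimizer_type == 4) return 4;               // shared_adam
  if (optimizer_type == 2) return mf_dim;          // std adagrad (branch added here)
  return 1;                                        // shared adagrad (default)
}
// Example with mf_dim = 8: shared adagrad -> 1, std adagrad -> 8, adam -> 18,
// shared_adam -> 4; Dim(mf_dim) then adds 9 + embed_sgd_dim + mf_dim on top.
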