未验证 提交 c6e0cedc 编写于 作者: Y yaoxuefeng 提交者: GitHub

support multi-node (#35396)

上级 8307b0cb
......@@ -354,10 +354,10 @@ cc_library(executor_cache SRCS executor_cache.cc DEPS parallel_executor)
if(WITH_PSCORE)
get_property(RPC_DEPS GLOBAL PROPERTY RPC_DEPS)
cc_test(dist_multi_trainer_test SRCS dist_multi_trainer_test.cc DEPS
conditional_block_op executor ${RPC_DEPS})
conditional_block_op executor gloo_wrapper ${RPC_DEPS})
else()
cc_test(dist_multi_trainer_test SRCS dist_multi_trainer_test.cc DEPS
conditional_block_op executor)
conditional_block_op executor gloo_wrapper)
endif()
cc_library(prune SRCS prune.cc DEPS framework_proto boost)
cc_test(prune_test SRCS prune_test.cc DEPS op_info prune recurrent_op device_context)
......
......@@ -257,6 +257,11 @@ bool InMemoryDataFeed<T>::Start() {
output_channel_->Write(std::move(data));
}
#endif
if (batch_offsets_.size() > 0) {
VLOG(3) << "batch_size offsets: " << batch_offsets_.size();
enable_heterps_ = true;
this->offset_index_ = 0;
}
this->finish_start_ = true;
return true;
}
......@@ -265,6 +270,7 @@ template <typename T>
int InMemoryDataFeed<T>::Next() {
#ifdef _LINUX
this->CheckStart();
if (!enable_heterps_) {
CHECK(output_channel_ != nullptr);
CHECK(consume_channel_ != nullptr);
VLOG(3) << "output_channel_ size=" << output_channel_->Size()
......@@ -294,6 +300,35 @@ int InMemoryDataFeed<T>::Next() {
<< ", consume_channel_ size=" << consume_channel_->Size()
<< ", thread_id=" << thread_id_;
}
} else {
VLOG(3) << "enable heter NEXT: " << offset_index_
<< " batch_offsets: " << batch_offsets_.size();
if (offset_index_ >= batch_offsets_.size()) {
VLOG(3) << "offset_index: " << offset_index_
<< " batch_offsets: " << batch_offsets_.size();
return 0;
}
auto& batch = batch_offsets_[offset_index_++];
this->batch_size_ = batch.second;
VLOG(3) << "batch_size_=" << this->batch_size_
<< ", thread_id=" << thread_id_;
if (this->batch_size_ != 0) {
PutToFeedVec(&records_[batch.first], this->batch_size_);
} else {
VLOG(3) << "finish reading for heterps, batch size zero, thread_id="
<< thread_id_;
}
/*
if (offset_index_ == batch_offsets_.size() - 1) {
std::vector<Record> data;
output_channel_->ReadAll(data);
consume_channel_->Write(std::move(data));
}
*/
VLOG(3) << "#15 enable heter NEXT: " << offset_index_
<< " batch_offsets: " << batch_offsets_.size()
<< " baych_size: " << this->batch_size_;
}
return this->batch_size_;
#else
return 0;
......@@ -1141,6 +1176,103 @@ bool MultiSlotInMemoryDataFeed::ParseOneInstance(Record* instance) {
return false;
}
void MultiSlotInMemoryDataFeed::PutToFeedVec(const Record* ins_vec, int num) {
#ifdef _LINUX
for (size_t i = 0; i < batch_float_feasigns_.size(); ++i) {
batch_float_feasigns_[i].clear();
batch_uint64_feasigns_[i].clear();
offset_[i].clear();
offset_[i].push_back(0);
}
ins_content_vec_.clear();
ins_content_vec_.reserve(num);
ins_id_vec_.clear();
ins_id_vec_.reserve(num);
for (int i = 0; i < num; ++i) {
auto& r = ins_vec[i];
ins_id_vec_.push_back(r.ins_id_);
ins_content_vec_.push_back(r.content_);
for (auto& item : r.float_feasigns_) {
batch_float_feasigns_[item.slot()].push_back(item.sign().float_feasign_);
visit_[item.slot()] = true;
}
for (auto& item : r.uint64_feasigns_) {
batch_uint64_feasigns_[item.slot()].push_back(
item.sign().uint64_feasign_);
visit_[item.slot()] = true;
}
for (size_t j = 0; j < use_slots_.size(); ++j) {
const auto& type = all_slots_type_[j];
if (visit_[j]) {
visit_[j] = false;
} else {
// fill slot value with default value 0
if (type[0] == 'f') { // float
batch_float_feasigns_[j].push_back(0.0);
} else if (type[0] == 'u') { // uint64
batch_uint64_feasigns_[j].push_back(0);
}
}
// get offset of this ins in this slot
if (type[0] == 'f') { // float
offset_[j].push_back(batch_float_feasigns_[j].size());
} else if (type[0] == 'u') { // uint64
offset_[j].push_back(batch_uint64_feasigns_[j].size());
}
}
}
for (size_t i = 0; i < use_slots_.size(); ++i) {
if (feed_vec_[i] == nullptr) {
continue;
}
int total_instance = offset_[i].back();
const auto& type = all_slots_type_[i];
if (type[0] == 'f') { // float
float* feasign = batch_float_feasigns_[i].data();
float* tensor_ptr =
feed_vec_[i]->mutable_data<float>({total_instance, 1}, this->place_);
CopyToFeedTensor(tensor_ptr, feasign, total_instance * sizeof(float));
} else if (type[0] == 'u') { // uint64
// no uint64_t type in paddlepaddle
uint64_t* feasign = batch_uint64_feasigns_[i].data();
int64_t* tensor_ptr = feed_vec_[i]->mutable_data<int64_t>(
{total_instance, 1}, this->place_);
CopyToFeedTensor(tensor_ptr, feasign, total_instance * sizeof(int64_t));
}
auto& slot_offset = offset_[i];
if (this->input_type_ == 0) {
LoD data_lod{slot_offset};
feed_vec_[i]->set_lod(data_lod);
} else if (this->input_type_ == 1) {
if (!use_slots_is_dense_[i]) {
std::vector<size_t> tmp_offset;
PADDLE_ENFORCE_EQ(slot_offset.size(), 2,
platform::errors::InvalidArgument(
"In batch reader, the sparse tensor lod size "
"must be 2, but received %d.",
slot_offset.size()));
const auto& max_size = slot_offset[1];
tmp_offset.reserve(max_size + 1);
for (unsigned int k = 0; k <= max_size; k++) {
tmp_offset.emplace_back(k);
}
slot_offset = tmp_offset;
LoD data_lod{slot_offset};
feed_vec_[i]->set_lod(data_lod);
}
}
if (use_slots_is_dense_[i]) {
if (inductive_shape_index_[i] != -1) {
use_slots_shape_[i][inductive_shape_index_[i]] =
total_instance / total_dims_without_inductive_[i];
}
feed_vec_[i]->Resize(framework::make_ddim(use_slots_shape_[i]));
}
}
#endif
}
void MultiSlotInMemoryDataFeed::PutToFeedVec(
const std::vector<Record>& ins_vec) {
#ifdef _LINUX
......
......@@ -167,7 +167,7 @@ class DLManager {
}
paddle::framework::CustomParser* Load(const std::string& name,
std::vector<SlotConf>& conf) {
const std::vector<SlotConf>& conf) {
#ifdef _LINUX
std::lock_guard<std::mutex> lock(mutex_);
DLHandle handle;
......@@ -195,7 +195,7 @@ class DLManager {
}
paddle::framework::CustomParser* ReLoad(const std::string& name,
std::vector<SlotConf>& conf) {
const std::vector<SlotConf>& conf) {
Close(name);
return Load(name, conf);
}
......@@ -422,6 +422,7 @@ class InMemoryDataFeed : public DataFeed {
virtual void ParseOneInstanceFromSo(const char* str, T* instance,
CustomParser* parser) {}
virtual void PutToFeedVec(const std::vector<T>& ins_vec) = 0;
virtual void PutToFeedVec(const T* ins_vec, int num) = 0;
int thread_id_;
int thread_num_;
......@@ -439,6 +440,11 @@ class InMemoryDataFeed : public DataFeed {
paddle::framework::ChannelObject<PvInstance>* input_pv_channel_;
paddle::framework::ChannelObject<PvInstance>* output_pv_channel_;
paddle::framework::ChannelObject<PvInstance>* consume_pv_channel_;
std::vector<std::pair<int, int>> batch_offsets_;
uint64_t offset_index_ = 0;
bool enable_heterps_ = false;
T* records_ = nullptr;
};
// This class define the data type of instance(ins_vec) in MultiSlotDataFeed
......@@ -601,7 +607,7 @@ paddle::framework::Archive<AR>& operator>>(paddle::framework::Archive<AR>& ar,
for (size_t& x : offset) {
uint64_t t;
ar >> t;
x = (size_t)t;
x = static_cast<size_t>(t);
}
#endif
ar >> ins.MutableFloatData();
......@@ -777,6 +783,11 @@ class MultiSlotInMemoryDataFeed : public InMemoryDataFeed<Record> {
MultiSlotInMemoryDataFeed() {}
virtual ~MultiSlotInMemoryDataFeed() {}
virtual void Init(const DataFeedDesc& data_feed_desc);
void SetRecord(Record* records) { records_ = records; }
int GetDefaultBatchSize() { return default_batch_size_; }
void AddBatchOffset(const std::pair<int, int>& offset) {
batch_offsets_.push_back(offset);
}
protected:
virtual bool ParseOneInstance(Record* instance);
......@@ -786,6 +797,7 @@ class MultiSlotInMemoryDataFeed : public InMemoryDataFeed<Record> {
virtual void PutToFeedVec(const std::vector<Record>& ins_vec);
virtual void GetMsgFromLogKey(const std::string& log_key, uint64_t* search_id,
uint32_t* cmatch, uint32_t* rank);
virtual void PutToFeedVec(const Record* ins_vec, int num);
std::vector<std::vector<float>> batch_float_feasigns_;
std::vector<std::vector<uint64_t>> batch_uint64_feasigns_;
std::vector<std::vector<size_t>> offset_;
......
......@@ -216,6 +216,180 @@ void DatasetImpl<T>::RegisterClientToClientMsgHandler() {
});
VLOG(3) << "RegisterClientToClientMsgHandler done";
}
static void compute_left_batch_num(const int ins_num, const int thread_num,
std::vector<std::pair<int, int>>* offset,
const int start_pos) {
int cur_pos = start_pos;
int batch_size = ins_num / thread_num;
int left_num = ins_num % thread_num;
for (int i = 0; i < thread_num; ++i) {
int batch_num_size = batch_size;
if (i == 0) {
batch_num_size = batch_num_size + left_num;
}
offset->push_back(std::make_pair(cur_pos, batch_num_size));
cur_pos += batch_num_size;
}
}
static void compute_batch_num(const int64_t ins_num, const int batch_size,
const int thread_num,
std::vector<std::pair<int, int>>* offset) {
int thread_batch_num = batch_size * thread_num;
// less data
if (static_cast<int64_t>(thread_batch_num) > ins_num) {
compute_left_batch_num(ins_num, thread_num, offset, 0);
return;
}
int cur_pos = 0;
int offset_num = static_cast<int>(ins_num / thread_batch_num) * thread_num;
int left_ins_num = static_cast<int>(ins_num % thread_batch_num);
if (left_ins_num > 0 && left_ins_num < thread_num) {
offset_num = offset_num - thread_num;
left_ins_num = left_ins_num + thread_batch_num;
for (int i = 0; i < offset_num; ++i) {
offset->push_back(std::make_pair(cur_pos, batch_size));
cur_pos += batch_size;
}
// split data to thread avg two rounds
compute_left_batch_num(left_ins_num, thread_num * 2, offset, cur_pos);
} else {
for (int i = 0; i < offset_num; ++i) {
offset->push_back(std::make_pair(cur_pos, batch_size));
cur_pos += batch_size;
}
if (left_ins_num > 0) {
compute_left_batch_num(left_ins_num, thread_num, offset, cur_pos);
}
}
}
static int compute_thread_batch_nccl(
const int thr_num, const int64_t total_instance_num,
const int minibatch_size, std::vector<std::pair<int, int>>* nccl_offsets) {
int thread_avg_batch_num = 0;
if (total_instance_num < static_cast<int64_t>(thr_num)) {
LOG(WARNING) << "compute_thread_batch_nccl total ins num:["
<< total_instance_num << "], less thread num:[" << thr_num
<< "]";
return thread_avg_batch_num;
}
auto& offset = (*nccl_offsets);
// split data avg by thread num
compute_batch_num(total_instance_num, minibatch_size, thr_num, &offset);
thread_avg_batch_num = static_cast<int>(offset.size() / thr_num);
#ifdef PADDLE_WITH_GLOO
auto gloo_wrapper = paddle::framework::GlooWrapper::GetInstance();
if (!gloo_wrapper->IsInitialized()) {
VLOG(0) << "GLOO is not inited";
gloo_wrapper->Init();
}
if (gloo_wrapper->Size() > 1) {
// adjust batch num per thread for NCCL
std::vector<int> thread_avg_batch_num_vec(1, thread_avg_batch_num);
std::vector<int64_t> total_instance_num_vec(1, total_instance_num);
auto thread_max_batch_num_vec =
gloo_wrapper->AllReduce(thread_avg_batch_num_vec, "max");
auto sum_total_ins_num_vec =
gloo_wrapper->AllReduce(total_instance_num_vec, "sum");
int thread_max_batch_num = thread_max_batch_num_vec[0];
int64_t sum_total_ins_num = sum_total_ins_num_vec[0];
int diff_batch_num = thread_max_batch_num - thread_avg_batch_num;
VLOG(3) << "diff batch num: " << diff_batch_num
<< " thread max batch num: " << thread_max_batch_num
<< " thread avg batch num: " << thread_avg_batch_num;
if (diff_batch_num == 0) {
LOG(WARNING) << "total sum ins " << sum_total_ins_num << ", thread_num "
<< thr_num << ", ins num " << total_instance_num
<< ", batch num " << offset.size()
<< ", thread avg batch num " << thread_avg_batch_num;
return thread_avg_batch_num;
}
int need_ins_num = thread_max_batch_num * thr_num;
// data is too less
if ((int64_t)need_ins_num > total_instance_num) {
PADDLE_THROW(platform::errors::InvalidArgument(
"error instance num:[%d] less need ins num:[%d]", total_instance_num,
need_ins_num));
return thread_avg_batch_num;
}
int need_batch_num = (diff_batch_num + 1) * thr_num;
int offset_split_index = static_cast<int>(offset.size() - thr_num);
int split_left_num = total_instance_num - offset[offset_split_index].first;
while (split_left_num < need_batch_num) {
need_batch_num += thr_num;
offset_split_index -= thr_num;
split_left_num = total_instance_num - offset[offset_split_index].first;
}
int split_start = offset[offset_split_index].first;
offset.resize(offset_split_index);
compute_left_batch_num(split_left_num, need_batch_num, &offset,
split_start);
LOG(WARNING) << "total sum ins " << sum_total_ins_num << ", thread_num "
<< thr_num << ", ins num " << total_instance_num
<< ", batch num " << offset.size() << ", thread avg batch num "
<< thread_avg_batch_num << ", thread max batch num "
<< thread_max_batch_num
<< ", need batch num: " << (need_batch_num / thr_num)
<< "split begin (" << split_start << ")" << split_start
<< ", num " << split_left_num;
thread_avg_batch_num = thread_max_batch_num;
} else {
LOG(WARNING) << "thread_num " << thr_num << ", ins num "
<< total_instance_num << ", batch num " << offset.size()
<< ", thread avg batch num " << thread_avg_batch_num;
}
#else
PADDLE_THROW(platform::errors::Unavailable(
"dataset compute nccl batch number need compile with GLOO"));
#endif
return thread_avg_batch_num;
}
template <typename T>
void DatasetImpl<T>::SetHeterPs(bool enable_heterps) {
#ifdef PADDLE_WITH_GLOO
enable_heterps_ = enable_heterps;
if (enable_heterps_) {
if (input_records_.size() == 0 && input_channel_ != nullptr &&
input_channel_->Size() != 0) {
input_channel_->ReadAll(input_records_);
VLOG(3) << "read from channel to records with records size: "
<< input_records_.size();
}
VLOG(3) << "input records size: " << input_records_.size();
int64_t total_ins_num = input_records_.size();
std::vector<std::pair<int, int>> offset;
int default_batch_size =
reinterpret_cast<MultiSlotInMemoryDataFeed*>(readers_[0].get())
->GetDefaultBatchSize();
VLOG(3) << "thread_num: " << thread_num_
<< " memory size: " << total_ins_num
<< " default batch_size: " << default_batch_size;
compute_thread_batch_nccl(thread_num_, total_ins_num, default_batch_size,
&offset);
VLOG(3) << "offset size: " << offset.size();
for (int i = 0; i < thread_num_; i++) {
reinterpret_cast<MultiSlotInMemoryDataFeed*>(readers_[i].get())
->SetRecord(&input_records_[0]);
}
for (size_t i = 0; i < offset.size(); i++) {
reinterpret_cast<MultiSlotInMemoryDataFeed*>(
readers_[i % thread_num_].get())
->AddBatchOffset(offset[i]);
}
}
#else
PADDLE_THROW(platform::errors::Unavailable(
"dataset set heterps need compile with GLOO"));
#endif
return;
}
// load data into memory, Dataset hold this memory,
// which will later be fed into readers' channel
......@@ -319,6 +493,13 @@ void DatasetImpl<T>::ReleaseMemory() {
multi_pv_consume_[i]->Clear();
multi_pv_consume_[i] = nullptr;
}
if (enable_heterps_) {
input_records_.clear();
input_records_.shrink_to_fit();
std::vector<T>().swap(input_records_);
VLOG(3) << "release heterps input records records size: "
<< input_records_.size();
}
std::vector<paddle::framework::Channel<PvInstance>>().swap(multi_pv_consume_);
std::vector<std::shared_ptr<paddle::framework::DataFeed>>().swap(readers_);
......@@ -654,6 +835,9 @@ void DatasetImpl<T>::CreateReaders() {
channel_idx = 0;
}
}
if (enable_heterps_) {
SetHeterPs(true);
}
VLOG(3) << "readers size: " << readers_.size();
}
......
......@@ -24,6 +24,10 @@
#include <unordered_set>
#include <utility>
#include <vector>
#ifdef PADDLE_WITH_GLOO
#include <gloo/broadcast.h>
#include "paddle/fluid/framework/fleet/gloo_wrapper.h"
#endif
#include "paddle/fluid/framework/data_feed.h"
......@@ -145,6 +149,7 @@ class Dataset {
virtual void DynamicAdjustReadersNum(int thread_num) = 0;
// set fleet send sleep seconds
virtual void SetFleetSendSleepSeconds(int seconds) = 0;
virtual void SetHeterPs(bool enable_heterps) = 0;
protected:
virtual int ReceiveFromClient(int msg_type, int client_id,
......@@ -228,6 +233,7 @@ class DatasetImpl : public Dataset {
bool discard_remaining_ins = false);
virtual void DynamicAdjustReadersNum(int thread_num);
virtual void SetFleetSendSleepSeconds(int seconds);
virtual void SetHeterPs(bool enable_heterps);
std::vector<paddle::framework::Channel<T>>& GetMultiOutputChannel() {
return multi_output_channel_;
......@@ -292,6 +298,7 @@ class DatasetImpl : public Dataset {
int64_t global_index_ = 0;
std::vector<std::shared_ptr<ThreadPool>> consume_task_pool_;
std::vector<T> input_records_; // only for paddleboxdatafeed
bool enable_heterps_ = false;
};
// use std::vector<MultiSlotType> or Record as data type
......
......@@ -40,6 +40,7 @@ static std::unordered_set<std::string> kMultiDeviceOps{
"c_broadcast",
"c_comm_init",
"c_comm_init_all",
"c_comm_init_multitrainer",
"c_gen_nccl_id",
"c_sync_comm_stream",
"send",
......
......@@ -14,7 +14,9 @@
#include "gtest/gtest.h"
#include "paddle/fluid/framework/trainer.h"
#ifdef PADDLE_WITH_GLOO
#include "paddle/fluid/framework/fleet/gloo_wrapper.h"
#endif
#if defined _WIN32 || defined __APPLE__
#else
#define _LINUX
......
......@@ -12,15 +12,15 @@ endif(WITH_PSLIB)
if(WITH_HETERPS)
if(WITH_NCCL)
nv_library(ps_gpu_wrapper SRCS ps_gpu_wrapper.cu ps_gpu_wrapper.cc
DEPS heter_ps ${BRPC_DEPS})
DEPS heter_ps gloo_wrapper ${BRPC_DEPS})
add_subdirectory(heter_ps)
elseif(WITH_RCCL)
hip_library(ps_gpu_wrapper SRCS ps_gpu_wrapper.cu ps_gpu_wrapper.cc
DEPS heter_ps ${BRPC_DEPS})
DEPS heter_ps gloo_wrapper ${BRPC_DEPS})
add_subdirectory(heter_ps)
endif(WITH_NCCL)
else()
cc_library(ps_gpu_wrapper SRCS ps_gpu_wrapper.cc)
cc_library(ps_gpu_wrapper SRCS ps_gpu_wrapper.cc DEPS gloo_wrapper)
endif(WITH_HETERPS)
if(WITH_NCCL OR WITH_RCCL)
......
......@@ -123,7 +123,7 @@ void PSGPUWrapper::BuildTask(std::shared_ptr<HeterContext> gpu_task) {
}
timeline.Pause();
VLOG(1) << "GpuPs task unique11111 cost " << timeline.ElapsedSec()
VLOG(1) << "GpuPs task add keys cost " << timeline.ElapsedSec()
<< " seconds.";
timeline.Start();
gpu_task->UniqueKeys();
......@@ -138,19 +138,74 @@ void PSGPUWrapper::BuildTask(std::shared_ptr<HeterContext> gpu_task) {
timeline.Start();
auto ptl_func = [this, &local_keys, &local_ptr, &fleet_ptr](int i) {
size_t key_size = local_keys[i].size();
int32_t status = -1;
#ifdef PADDLE_WITH_PSLIB
// auto tt = fleet_ptr->pslib_ptr_->_worker_ptr->pull_sparse_ptr(
// reinterpret_cast<char**>(local_ptr[i].data()), this->table_id_,
// local_keys[i].data(), key_size);
int32_t cnt = 0;
while (true) {
auto tt = fleet_ptr->pslib_ptr_->_worker_ptr->pull_sparse_ptr(
reinterpret_cast<char**>(local_ptr[i].data()), this->table_id_,
local_keys[i].data(), key_size);
bool flag = true;
tt.wait();
try {
status = tt.get();
} catch (const std::future_error& e) {
VLOG(0) << "Caught a future_error with code" << e.code()
<< ", Message:" << e.what();
}
if (status != 0) {
VLOG(0) << "fleet pull sparse failed, status[" << status << "]";
sleep(sleep_seconds_before_fail_exit_);
flag = false;
cnt++;
}
if (cnt > 3) {
VLOG(0) << "fleet pull sparse failed, retry 3 times";
exit(-1);
}
if (flag) {
break;
}
}
#endif
#ifdef PADDLE_WITH_PSCORE
int32_t cnt = 0;
while (true) {
auto tt = fleet_ptr->_worker_ptr->pull_sparse_ptr(
reinterpret_cast<char**>(local_ptr[i].data()), this->table_id_,
local_keys[i].data(), key_size);
#endif
bool flag = true;
tt.wait();
auto status = tt.get();
// auto status = 0;
try {
status = tt.get();
} catch (const std::future_error& e) {
VLOG(0) << "Caught a future_error with code" << e.code()
<< ", Message:" << e.what();
}
if (status != 0) {
VLOG(0) << "fleet pull sparse failed, status[" << status << "]";
sleep(sleep_seconds_before_fail_exit_);
flag = false;
cnt++;
}
if (cnt > 3) {
VLOG(0) << "fleet pull sparse failed, retry 3 times";
exit(-1);
}
if (flag) {
break;
}
}
#endif
if (status != 0) {
LOG(ERROR) << "fleet pull sparse failed, status[" << status << "]";
sleep(300);
......@@ -169,10 +224,27 @@ void PSGPUWrapper::BuildTask(std::shared_ptr<HeterContext> gpu_task) {
timeline.Pause();
VLOG(1) << "pull sparse from CpuPS into GpuPS cost " << timeline.ElapsedSec()
<< " seconds.";
if (multi_node_) {
auto gloo_wrapper = paddle::framework::GlooWrapper::GetInstance();
if (!gloo_wrapper->IsInitialized()) {
VLOG(0) << "GLOO is not inited";
gloo_wrapper->Init();
}
gloo_wrapper->Barrier();
}
timeline.Start();
auto build_func = [device_num, &local_keys, &local_ptr, &device_keys,
&device_vals, &device_mutex](int i) {
std::vector<std::vector<std::pair<uint64_t, char*>>> pass_values;
uint16_t pass_id = 0;
bool record_status = false;
if (multi_node_) {
record_status = fleet_ptr->pslib_ptr_->_worker_ptr->take_sparse_record(
table_id_, pass_id, pass_values);
}
auto build_func = [device_num, record_status, &pass_values, &local_keys,
&local_ptr, &device_keys, &device_vals,
&device_mutex](int i) {
std::vector<std::vector<FeatureKey>> task_keys(device_num);
#ifdef PADDLE_WITH_PSLIB
std::vector<std::vector<paddle::ps::DownpourFixedFeatureValue*>> task_ptrs(
......@@ -188,7 +260,21 @@ void PSGPUWrapper::BuildTask(std::shared_ptr<HeterContext> gpu_task) {
task_keys[shard].push_back(local_keys[i][j]);
task_ptrs[shard].push_back(local_ptr[i][j]);
}
if (record_status) {
size_t local_keys_size = local_keys.size();
size_t pass_values_size = pass_values.size();
for (size_t j = 0; j < pass_values_size; j += local_keys_size) {
auto& shard_values = pass_values[j];
for (size_t pair_idx = 0; pair_idx < pass_values[j].size();
pair_idx++) {
auto& cur_pair = shard_values[pair_idx];
int shard = cur_pair.first % device_num;
task_keys[shard].push_back(cur_pair.first);
task_ptrs[shard].push_back(
(paddle::ps::DownpourFixedFeatureValue*)cur_pair.second);
}
}
}
for (int dev = 0; dev < device_num; dev++) {
device_mutex[dev]->lock();
......
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#if defined(PADDLE_WITH_NCCL)
#include <nccl.h>
#endif
#include <stdint.h>
#include <ostream>
#include <string>
#include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/threadpool.h"
// #include "paddle/fluid/operators/distributed/distributed.h"
// #include "paddle/fluid/operators/distributed/request_handler_impl.h"
#if defined(PADDLE_WITH_NCCL)
#include "paddle/fluid/platform/collective_helper.h"
#include "paddle/fluid/platform/nccl_helper.h"
#endif
namespace paddle {
namespace operators {
class CCommInitMultiTrainerInferShape : public framework::InferShapeBase {
public:
~CCommInitMultiTrainerInferShape() {}
void operator()(framework::InferShapeContext* ctx) const override{};
};
class CCommInitMultiTrainerOp : public framework::OperatorBase {
public:
CCommInitMultiTrainerOp(const std::string& type,
const framework::VariableNameMap& inputs,
const framework::VariableNameMap& outputs,
const framework::AttributeMap& attrs)
: OperatorBase(type, inputs, outputs, attrs) {}
void RunImpl(const framework::Scope& scope,
const platform::Place& place) const override {
auto var = scope.FindVar(Input("X"));
PADDLE_ENFORCE_NOT_NULL(
var, platform::errors::InvalidArgument("Input X must be provided."));
#if defined(PADDLE_WITH_NCCL)
ncclUniqueId* nccl_id = var->GetMutable<ncclUniqueId>();
int ntrainers = Attr<int>("ntrainers");
int train_id = Attr<int>("trainer_id");
int rid = Attr<int>("ring_id");
std::vector<int> devices = Attr<std::vector<int>>("devices");
if (devices.empty()) {
devices = platform::GetSelectedDevices();
}
platform::NCCLCommContext::Instance().CreateNCCLCommMultiTrainer(
devices, nccl_id, ntrainers, train_id, rid);
#else
PADDLE_THROW(platform::errors::Unimplemented(
"PaddlePaddle should compile with GPU."));
#endif
}
};
class CCommInitMultiTrainerOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X", "Raw variable contains a NCCL UniqueId instaces.");
AddComment(R"DOC(
CCommInitMultiTrainer operator
Initialize collective communicatoin context within this trainer
)DOC");
AddAttr<int>("ntrainers",
"(int) The number of trainers of distributed trainers");
AddAttr<int>("trainer_id",
"(int) The id of the trainer in distributed training.");
AddAttr<std::vector<int>>("devices",
"(std::vector<int>) which devices does the nccl "
"comm initialized on in each trainer")
.SetDefault({});
AddAttr<int>("ring_id", "(int default 0) user specified ring id")
.SetDefault(0);
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(c_comm_init_multitrainer, ops::CCommInitMultiTrainerOp,
ops::CCommInitMultiTrainerInferShape,
ops::CCommInitMultiTrainerOpMaker);
......@@ -140,6 +140,50 @@ void NCCLCommContext::CreateAllNCCLComms(const std::vector<int>& dev_ids,
});
}
void NCCLCommContext::CreateNCCLCommMultiTrainer(
const std::vector<int>& dev_ids, ncclUniqueId* nccl_id, int ntrainers,
int train_id, int ring_id) {
PADDLE_ENFORCE_GT(
dev_ids.size(), 0,
paddle::platform::errors::InvalidArgument(
"dev ids = [%d], it should greater than 0.", dev_ids.size()));
const int kDevices = dev_ids.size();
VLOG(3) << "Begin CreateNCCLCommMultiTrainer. device number: " << kDevices
<< ", ntrainers: " << ntrainers << ", train_id: " << train_id
<< ", rind_id: " << ring_id;
ncclComm_t comms[kDevices];
{
PADDLE_ENFORCE_CUDA_SUCCESS(dynload::ncclGroupStart());
for (int i = 0; i < kDevices; i++) {
#ifdef PADDLE_WITH_HIP
PADDLE_ENFORCE_CUDA_SUCCESS(hipSetDevice(i));
#else
PADDLE_ENFORCE_CUDA_SUCCESS(cudaSetDevice(i));
#endif
platform::dynload::ncclCommInitRank(comms + i, kDevices * ntrainers,
*nccl_id, train_id * kDevices + i);
VLOG(3) << "ncclCommInitRank: " << i;
}
PADDLE_ENFORCE_CUDA_SUCCESS(dynload::ncclGroupEnd());
VLOG(3) << "nccl group end seccessss";
}
PADDLE_ENFORCE_EQ(comm_map_.count(ring_id), 0,
platform::errors::InvalidArgument(
"comm_map_ of ring_id: %s should be 0. %s is provided",
ring_id, comm_map_.count(ring_id)));
for (int i = 0; i < kDevices; ++i) {
AssignNCCLComm(comms[i], kDevices * ntrainers, train_id * kDevices + i,
dev_ids[i], ring_id);
VLOG(3) << "nccl communicator of train_id " << train_id * kDevices + i
<< " in ring " << ring_id << " has been created on device "
<< dev_ids[i];
}
std::call_once(once_flag_, []() {
std::atexit([]() { NCCLCommContext::Instance().ReleaseNCCLComms(); });
});
}
NCCLComm* NCCLCommContext::AssignNCCLComm(ncclComm_t comm, int nranks, int rank,
int dev_id, int ring_id) {
std::unique_ptr<CUDADeviceContext> dev_ctx(
......
......@@ -77,6 +77,10 @@ class NCCLCommContext {
void CreateAllNCCLComms(const std::vector<int>& dev_ids, int ring_id = 0);
void CreateNCCLCommMultiTrainer(const std::vector<int>& dev_ids,
ncclUniqueId* nccl_id, int nranks, int rank,
int ring_id);
// a latter comm with the same dev_id and the same ring_id
// will override the former
NCCLComm* AssignNCCLComm(ncclComm_t comm, int nranks, int rank, int dev_id,
......
......@@ -309,6 +309,8 @@ void BindDataset(py::module *m) {
&framework::Dataset::SetFleetSendSleepSeconds,
py::call_guard<py::gil_scoped_release>())
.def("enable_pv_merge", &framework::Dataset::EnablePvMerge,
py::call_guard<py::gil_scoped_release>())
.def("set_heter_ps", &framework::Dataset::SetHeterPs,
py::call_guard<py::gil_scoped_release>());
py::class_<IterableDatasetWrapper>(*m, "IterableDatasetWrapper")
......
......@@ -985,6 +985,13 @@ class InMemoryDataset(DatasetBase):
return global_data_size[0]
return local_data_size[0]
def _set_heter_ps(self, enable_heter_ps=False):
"""
Set heter ps mode
user no need to call this function.
"""
self.dataset.set_heter_ps(enable_heter_ps)
class QueueDataset(DatasetBase):
"""
......
......@@ -101,10 +101,11 @@ class PSLib(Fleet):
# barrier_all for init_worker
self._role_maker._barrier_all()
# prepare for client to client communication
if not self._opt_info["use_ps_gpu"]:
if self._role_maker.is_worker():
info = self._fleet_ptr.get_clients_info()
print("IIIIFO: {}".format(info))
all_info = self._role_maker._worker_gather(info[0])
print("ALL info: {}".format(all_info))
self._fleet_ptr.gather_clients(all_info)
self._fleet_ptr.set_client2client_config(
self._client2client_request_timeout_ms,
......@@ -1120,14 +1121,14 @@ class DownpourOptimizer(DistributedOptimizer):
fleet._main_programs = programs
fleet._scopes = scopes
if opt_info["use_ps_gpu"]:
from paddle.fluid.transpiler.collective import SingleProcessMultiThread
from paddle.fluid.transpiler.collective import MultiThread
# check start program
env = self.get_dist_env()
if not isinstance(losses, list):
startup_programs = [startup_programs]
for i in range(0, len(startup_programs)):
t = SingleProcessMultiThread()
t = MultiThread()
start_program = startup_programs[i]
main_program = programs[i]
t.transpile(
......
......@@ -29,7 +29,7 @@ from .. import core, unique_name
from ..framework import Program, default_main_program, default_startup_program
from .details import wait_server_ready
__all__ = ['GradAllReduce', 'LocalSGD']
__all__ = ['GradAllReduce', 'LocalSGD', 'MultiThread']
OpRole = core.op_proto_and_checker_maker.OpRole
......@@ -97,8 +97,14 @@ class Collective(object):
self.wait_port)
self._broadcast_params()
def _init_communicator(self, program, current_endpoint, endpoints, rank,
ring_id, wait_port):
def _init_communicator(self,
program,
current_endpoint,
endpoints,
rank,
ring_id,
wait_port,
has_multitrainer=False):
nranks = len(endpoints)
other_endpoints = endpoints[:]
other_endpoints.remove(current_endpoint)
......@@ -150,6 +156,7 @@ class Collective(object):
'other_endpoints': other_endpoints,
self.op_role_key: OpRole.Forward
})
if not has_multitrainer:
block.append_op(
type='c_comm_init',
inputs={'X': nccl_id_var},
......@@ -160,6 +167,17 @@ class Collective(object):
'ring_id': ring_id,
self.op_role_key: OpRole.Forward
})
else:
block.append_op(
type='c_comm_init_multitrainer',
inputs={'X': nccl_id_var},
outputs={},
attrs={
'ntrainers': nranks,
'trainer_id': rank,
'ring_id': ring_id,
self.op_role_key: OpRole.Forward
})
def _broadcast_params(self):
block = self.startup_program.global_block()
......@@ -425,7 +443,7 @@ class MultiThread(GradAllReduce):
def __init__(self, nrings=1):
GradAllReduce.__init__(self, nrings)
self.mode = "box"
self.mode = "single_process_multi_thread"
def _transpile_startup_program(self):
if len(self.endpoints) > 1:
......@@ -434,9 +452,9 @@ class MultiThread(GradAllReduce):
print("total endpoints: ", self.endpoints)
print("rank: %d, ring_id: %d" % (self.rank, self.nrings))
for ring_id in range(self.nrings):
self._init_communicator(self.startup_program,
self.current_endpoint, self.endpoints,
self.rank, ring_id, self.wait_port)
self._init_communicator(
self.startup_program, self.current_endpoint, self.endpoints,
self.rank, ring_id, self.wait_port, True)
else:
print("begin to _transpile_startup_program for single-node")
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册