From a5b1a0e12b673ecc2c67199b4f6520c545f7c91e Mon Sep 17 00:00:00 2001 From: xujiaqi01 Date: Wed, 20 Mar 2019 00:28:11 +0800 Subject: [PATCH] support multi dataset && add init model && fix bug --- paddle/fluid/framework/async_executor.cc | 3 +- paddle/fluid/framework/data_feed.cc | 155 ++++++++++++------ paddle/fluid/framework/data_feed.h | 37 +++-- paddle/fluid/framework/data_set.cc | 82 +++++++-- paddle/fluid/framework/data_set.h | 8 +- paddle/fluid/framework/dist_multi_trainer.cc | 4 +- paddle/fluid/framework/fleet/fleet_wrapper.cc | 97 +++++++++-- paddle/fluid/framework/fleet/fleet_wrapper.h | 16 +- paddle/fluid/framework/multi_trainer.cc | 6 +- paddle/fluid/pybind/async_executor_py.cc | 2 +- paddle/fluid/pybind/data_set_py.cc | 1 + paddle/fluid/pybind/fleet_wrapper_py.cc | 1 + python/paddle/fluid/dataset.py | 17 +- .../fluid/incubate/fleet/base/role_maker.py | 4 +- .../fleet/parameter_server/__init__.py | 21 ++- 15 files changed, 341 insertions(+), 113 deletions(-) diff --git a/paddle/fluid/framework/async_executor.cc b/paddle/fluid/framework/async_executor.cc index b13eefba2e..b2423694d0 100644 --- a/paddle/fluid/framework/async_executor.cc +++ b/paddle/fluid/framework/async_executor.cc @@ -155,7 +155,8 @@ void AsyncExecutor::RunFromFile(const ProgramDesc& main_program, } #ifdef PADDLE_WITH_PSLIB if (mode == "mpi") { - _pull_dense_thread->stop(); + // todo ? + //_pull_dense_thread->stop(); } #endif VLOG(3) << "start to run from files in async_executor"; diff --git a/paddle/fluid/framework/data_feed.cc b/paddle/fluid/framework/data_feed.cc index 62f35f205b..62e391a3d2 100644 --- a/paddle/fluid/framework/data_feed.cc +++ b/paddle/fluid/framework/data_feed.cc @@ -23,15 +23,11 @@ limitations under the License. */ #include "io/shell.h" #include "paddle/fluid/framework/feed_fetch_method.h" #include "paddle/fluid/framework/feed_fetch_type.h" +#include "paddle/fluid/platform/timer.h" namespace paddle { namespace framework { -std::vector DataFeed::filelist_; -size_t DataFeed::file_idx_; -std::mutex DataFeed::mutex_for_pick_file_; -bool DataFeed::finish_set_filelist_; - void DataFeed::AddFeedVar(Variable* var, const std::string& name) { CheckInit(); for (size_t i = 0; i < use_slots_.size(); ++i) { @@ -42,7 +38,7 @@ void DataFeed::AddFeedVar(Variable* var, const std::string& name) { } bool DataFeed::SetFileList(const std::vector& files) { - std::unique_lock lock(mutex_for_pick_file_); + std::unique_lock lock(*mutex_for_pick_file_); CheckInit(); // Do not set finish_set_filelist_ flag, // since a user may set file many times after init reader @@ -52,9 +48,8 @@ bool DataFeed::SetFileList(const std::vector& files) { return false; } */ - PADDLE_ENFORCE(files.size(), "You have set an empty filelist."); + //PADDLE_ENFORCE(files.size(), "You have set an empty filelist."); filelist_.assign(files.begin(), files.end()); - file_idx_ = 0; finish_set_filelist_ = true; return true; @@ -66,13 +61,17 @@ void DataFeed::SetBatchSize(int batch_size) { } bool DataFeed::PickOneFile(std::string* filename) { - std::unique_lock lock(mutex_for_pick_file_); - if (file_idx_ == filelist_.size()) { + PADDLE_ENFORCE(mutex_for_pick_file_ != nullptr, + "should call SetFileListMutex before PickOneFile"); + PADDLE_ENFORCE(file_idx_ != nullptr, + "should call SetFileListIndex before PickOneFile"); + std::unique_lock lock(*mutex_for_pick_file_); + if (*file_idx_ == filelist_.size()) { VLOG(3) << "DataFeed::PickOneFile no more file to pick"; return false; } - VLOG(3) << "file_idx_=" << file_idx_; - *filename = filelist_[file_idx_++]; + VLOG(3) << "file_idx_=" << *file_idx_; + *filename = filelist_[(*file_idx_)++]; // LOG(ERROR) << "pick file:" << *filename; return true; } @@ -150,7 +149,11 @@ InMemoryDataFeed::InMemoryDataFeed() { cur_channel_ = 0; shuffled_ins_ = std::make_shared>(); shuffled_ins_out_ = std::make_shared>(); - fleet_send_batch_size_ = 10000; + fleet_send_batch_size_ = 80000; + memory_data_ = nullptr; + mutex_for_update_memory_data_ = nullptr; + this->file_idx_ = nullptr; + this->mutex_for_pick_file_ = nullptr; } template @@ -192,6 +195,8 @@ int InMemoryDataFeed::Next() { out_channel->Push(std::move(instance)); } DataFeed::batch_size_ = index; + VLOG(3) << "batch_size_=" << DataFeed::batch_size_ + << ", thread_id=" << thread_id_; if (DataFeed::batch_size_ != 0) { PutToFeedVec(ins_vec); } else { @@ -227,25 +232,22 @@ void InMemoryDataFeed::SetTrainerNum(int trainer_num) { template void InMemoryDataFeed::PutInsToChannel(const std::string& ins_str) { - T ins; + std::vector ins; DeserializeIns(&ins, ins_str); - shuffled_ins_->Push(std::move(ins)); + shuffled_ins_->Extend(std::move(ins)); + VLOG(3) << "PutInsToChannel put ins num=" << ins.size() + << " to channel, channel size=" << shuffled_ins_->Size() + << " thread_id=" << thread_id_; } template void InMemoryDataFeed::FillMemoryDataToChannel() { VLOG(3) << "FillMemoryDataToChannel, thread_id=" << thread_id_; - int64_t start = 0; - int64_t end = 0; - int64_t size = memory_data_->size(); - VLOG(3) << "memory_data size=" << size; - for (int64_t i = 0; i <= static_cast(thread_id_); ++i) { - int64_t len = size / static_cast(thread_num_) + - (i < (size % static_cast(thread_num_))); - start = end; - end += len; - } - for (int64_t i = start; i < end; ++i) { + auto interval = GetMemoryDataInterval(); + VLOG(3) << "memory data size=" << memory_data_->size() + << ", fill data from [" << interval.first << ", " + << interval.second << "), thread_id=" << thread_id_; + for (int64_t i = interval.first; i < interval.second; ++i) { T& t = (*memory_data_)[i]; shuffled_ins_->Push(std::move(t)); } @@ -256,14 +258,19 @@ void InMemoryDataFeed::FillChannelToMemoryData() { VLOG(3) << "FillChannelToMemoryData, thread_id=" << thread_id_; std::vector local_vec; std::shared_ptr> channel = nullptr; + std::shared_ptr> pre_channel = nullptr; if (cur_channel_ == 0) { channel = shuffled_ins_; + pre_channel = shuffled_ins_out_; } else { channel = shuffled_ins_out_; + pre_channel = shuffled_ins_; } CHECK(channel != nullptr); + CHECK(pre_channel != nullptr); + CHECK(pre_channel->Size() == 0); local_vec.resize(channel->Size()); - for (int64_t i = 0; i < channel->Size(); ++i) { + for (int64_t i = 0; i < local_vec.size(); ++i) { channel->Pop(local_vec[i]); } VLOG(3) << "local_vec size=" << local_vec.size() <<", thread_id=" << thread_id_; @@ -289,20 +296,32 @@ void InMemoryDataFeed::LoadIntoMemory() { int err_no = 0; PrivateQueueDataFeed::fp_ = fs_open_read(filename, &err_no, PrivateQueueDataFeed::pipe_command_); + CHECK(PrivateQueueDataFeed::fp_ != nullptr); __fsetlocking(&*PrivateQueueDataFeed::fp_, FSETLOCKING_BYCALLER); T instance; + platform::Timer timeline; + timeline.Start(); while (ParseOneInstanceFromPipe(&instance)) { local_vec.push_back(instance); } + timeline.Pause(); VLOG(3) << "LoadIntoMemory() read all lines, file=" - << filename <<", thread_id=" << thread_id_; + << filename << ", cost time=" << timeline.ElapsedSec() + << " seconds, thread_id=" << thread_id_; { std::lock_guard lock(*mutex_for_update_memory_data_); + timeline.Start(); memory_data_->insert(memory_data_->end(), - local_vec.begin(), local_vec.end()); + std::make_move_iterator(local_vec.begin()), + std::make_move_iterator(local_vec.end())); + timeline.Pause(); + VLOG(3) << "LoadIntoMemory() memory_data insert, cost time=" + << timeline.ElapsedSec() << " seconds, thread_id=" + << thread_id_; } - std::vector().swap(local_vec); + local_vec.clear(); } + std::vector().swap(local_vec); VLOG(3) << "LoadIntoMemory() end, thread_id=" << thread_id_; } @@ -315,30 +334,66 @@ void InMemoryDataFeed::LocalShuffle() { template void InMemoryDataFeed::GlobalShuffle() { - VLOG(3) << "GlobalShuffle(), thread_id=" << thread_id_; + VLOG(3) << "GlobalShuffle() begin, thread_id=" << thread_id_; auto fleet_ptr = FleetWrapper::GetInstance(); - std::vector send_str_vec(trainer_num_); - for (int64_t i = 0; i < memory_data_->size(); ++i) { - // todo get ins id + std::vector> send_vec(trainer_num_); + for (auto& vec : send_vec) { + vec.reserve(fleet_send_batch_size_); + } + std::vector> total_status; + auto interval = GetMemoryDataInterval(); + VLOG(3) << "global shuffle data from [" << interval.first << ", " + << interval.second << "), thread_id=" << thread_id_; + for (int64_t i = interval.first; i < interval.second; ++i) { + // if get ins id, can also use hash // std::string ins_id = memory_data_[i].ins_id; - // todo hash int64_t random_num = fleet_ptr->LocalRandomEngine()(); int64_t node_id = random_num % trainer_num_; - std::string str; - SerializeIns((*memory_data_)[i], &str); - send_str_vec[node_id] += str; + send_vec[node_id].push_back(&((*memory_data_)[i])); if (i % fleet_send_batch_size_ == 0 && i != 0) { - for (int j = 0; j < send_str_vec.size(); ++j) { - fleet_ptr->SendClientToClientMsg(0, j, send_str_vec[j]); - send_str_vec[j] = ""; + for (int j = 0; j < send_vec.size(); ++j) { + std::string send_str; + SerializeIns(send_vec[j], &send_str); + VLOG(3) << "send str_length=" << send_str.length() + << ", ins num=" << send_vec[j].size() << " to node_id=" + << j << ", thread_id=" << thread_id_; + auto ret = fleet_ptr->SendClientToClientMsg(0, j, send_str); + VLOG(3) << "end send, thread_id=" << thread_id_; + send_vec[j].clear(); + total_status.push_back(std::move(ret)); } } } - for (int j = 0; j < send_str_vec.size(); ++j) { - if (send_str_vec[j].length() != 0) { - fleet_ptr->SendClientToClientMsg(0, j, send_str_vec[j]); + for (int j = 0; j < send_vec.size(); ++j) { + if (send_vec[j].size() != 0) { + std::string send_str; + SerializeIns(send_vec[j], &send_str); + VLOG(3) << "send str_length=" << send_str.length() + << " to node_id=" << j << ", thread_id=" << thread_id_; + auto ret = fleet_ptr->SendClientToClientMsg(0, j, send_str); + VLOG(3) << "end send, thread_id=" << thread_id_; + total_status.push_back(std::move(ret)); } + std::vector().swap(send_vec[j]); + } + for (auto& t : total_status) { + t.wait(); } + VLOG(3) << "GlobalShuffle() end, thread_id=" << thread_id_; +} + +template +std::pair InMemoryDataFeed::GetMemoryDataInterval() { + int64_t start = 0; + int64_t end = 0; + int64_t size = memory_data_->size(); + for (int64_t i = 0; i <= static_cast(thread_id_); ++i) { + int64_t len = size / static_cast(thread_num_) + + (i < (size % static_cast(thread_num_))); + start = end; + end += len; + } + return std::make_pair(start, end); } // explicit instantiation @@ -519,7 +574,7 @@ bool MultiSlotDataFeed::ParseOneInstanceFromPipe( const char* str = reader.get(); std::string line = std::string(str); - VLOG(3) << line; + //VLOG(3) << line; char* endptr = const_cast(str); int pos = 0; for (size_t i = 0; i < use_slots_index_.size(); ++i) { @@ -695,7 +750,7 @@ bool MultiSlotInMemoryDataFeed::ParseOneInstanceFromPipe( const char* str = reader.get(); std::string line = std::string(str); - VLOG(3) << line; + //VLOG(3) << line; char* endptr = const_cast(str); int pos = 0; for (size_t i = 0; i < use_slots_index_.size(); ++i) { @@ -830,13 +885,15 @@ void MultiSlotInMemoryDataFeed::PutToFeedVec( // todo serialize ins in global shuffle void MultiSlotInMemoryDataFeed::SerializeIns( - const std::vector& ins, std::string* str) { + const std::vector*>& ins, + std::string* str) { auto fleet_ptr = FleetWrapper::GetInstance(); fleet_ptr->Serialize(ins, str); } // todo deserialize ins in global shuffle -void MultiSlotInMemoryDataFeed::DeserializeIns(std::vector* ins, - const std::string& str) { +void MultiSlotInMemoryDataFeed::DeserializeIns( + std::vector>* ins, + const std::string& str) { auto fleet_ptr = FleetWrapper::GetInstance(); fleet_ptr->Deserialize(ins, str); } diff --git a/paddle/fluid/framework/data_feed.h b/paddle/fluid/framework/data_feed.h index 8458f9e95e..cab0b431b5 100644 --- a/paddle/fluid/framework/data_feed.h +++ b/paddle/fluid/framework/data_feed.h @@ -21,6 +21,7 @@ limitations under the License. */ #include // NOLINT #include #include +#include #include "paddle/fluid/framework/data_feed.pb.h" #include "paddle/fluid/framework/lod_tensor.h" @@ -52,7 +53,10 @@ namespace framework { // } class DataFeed { public: - DataFeed() {} + DataFeed() { + mutex_for_pick_file_ = nullptr; + file_idx_ = nullptr; + } virtual ~DataFeed() {} virtual void Init(const paddle::framework::DataFeedDesc& data_feed_desc) = 0; virtual bool CheckFile(const char* filename) { @@ -89,6 +93,12 @@ class DataFeed { virtual void SetThreadNum(int thread_num) { } // This function will do nothing at default virtual void SetTrainerNum(int trainer_num) { } + virtual void SetFileListMutex(std::mutex* mutex) { + mutex_for_pick_file_ = mutex; + } + virtual void SetFileListIndex(size_t* file_index) { + file_idx_ = file_index; + } virtual void LoadIntoMemory() { PADDLE_THROW("This function(LoadIntoMemory) is not implemented."); } @@ -100,7 +110,9 @@ class DataFeed { } // This function will do nothing at default virtual void FillMemoryDataToChannel() { } + // This function will do nothing at default virtual void FillChannelToMemoryData() { } + // This function will do nothing at default virtual void PutInsToChannel(const std::string& ins_str) { } protected: @@ -116,9 +128,9 @@ class DataFeed { // safe). virtual bool PickOneFile(std::string* filename); - static std::vector filelist_; - static size_t file_idx_; - static std::mutex mutex_for_pick_file_; + std::vector filelist_; + size_t* file_idx_; + std::mutex* mutex_for_pick_file_; // the alias of used slots, and its order is determined by // data_feed_desc(proto object) @@ -141,7 +153,7 @@ class DataFeed { int batch_size_; bool finish_init_; - static bool finish_set_filelist_; + bool finish_set_filelist_; bool finish_start_; std::string pipe_command_; }; @@ -215,8 +227,9 @@ class InMemoryDataFeed : public PrivateQueueDataFeed { virtual bool ParseOneInstance(T* instance) = 0; virtual bool ParseOneInstanceFromPipe(T* instance) = 0; virtual void PutToFeedVec(const T& ins_vec) = 0; - virtual void SerializeIns(const T& ins, std::string* str) = 0; - virtual void DeserializeIns(T* ins, const std::string& str) = 0; + virtual void SerializeIns(const std::vector& ins, std::string* str) = 0; + virtual void DeserializeIns(std::vector* ins, const std::string& str) = 0; + virtual std::pair GetMemoryDataInterval(); int thread_id_; int thread_num_; @@ -284,13 +297,13 @@ class MultiSlotType { std::string DebugString() { std::stringstream ss; - ss << "type: " << type_ << "\n"; - ss << "offset:\n"; + ss << "\ntype: " << type_ << "\n"; + ss << "offset: "; ss << "["; for (const size_t& i : offset_) { ss << offset_[i] << ","; } - ss << "]\ndata:\n["; + ss << "]\ndata: ["; if (type_[0] == 'f') { for (const float& i : float_feasign_) { ss << i << ","; @@ -356,9 +369,9 @@ class MultiSlotInMemoryDataFeed virtual bool ParseOneInstance(std::vector* instance); virtual bool ParseOneInstanceFromPipe(std::vector* instance); virtual void PutToFeedVec(const std::vector& ins_vec); - virtual void SerializeIns(const std::vector& ins, + virtual void SerializeIns(const std::vector*>& ins, std::string* str); - virtual void DeserializeIns(std::vector* ins, + virtual void DeserializeIns(std::vector>* ins, const std::string& str); }; diff --git a/paddle/fluid/framework/data_set.cc b/paddle/fluid/framework/data_set.cc index 755c858bc7..b0f5d1867a 100644 --- a/paddle/fluid/framework/data_set.cc +++ b/paddle/fluid/framework/data_set.cc @@ -18,6 +18,8 @@ #include "google/protobuf/message.h" #include "google/protobuf/text_format.h" #include "paddle/fluid/framework/data_feed_factory.h" +#include "paddle/fluid/platform/timer.h" +#include "paddle/fluid/framework/io/fs.h" namespace paddle { namespace framework { @@ -25,12 +27,15 @@ namespace framework { template DatasetImpl::DatasetImpl() { thread_num_ = 1; + trainer_num_ = 1; + file_idx_ = 0; } template void DatasetImpl::SetFileList(const std::vector& filelist) { VLOG(3) << "filelist size: " << filelist.size(); filelist_ = filelist; + file_idx_ = 0; /* int file_cnt = filelist_.size(); if (thread_num_ > file_cnt) { @@ -45,19 +50,34 @@ void DatasetImpl::SetFileList(const std::vector& filelist) { // not user friendly template void DatasetImpl::SetThreadNum(int thread_num) { - int file_cnt = filelist_.size(); + VLOG(3) << "SetThreadNum thread_num=" << thread_num; + //int file_cnt = filelist_.size(); + /* if (file_cnt != 0 && thread_num > file_cnt) { VLOG(3) << "DataSet thread num = " << thread_num << ", file num = " << file_cnt << ". Changing DataSet thread num = " << file_cnt; thread_num = file_cnt; - } + }*/ thread_num_ = thread_num; } template void DatasetImpl::SetTrainerNum(int trainer_num) { trainer_num_ = trainer_num; + // should inform reader of trainer_num directly + for (auto reader : readers_) { + reader->SetTrainerNum(trainer_num); + } +} + +template +void DatasetImpl::SetHdfsConfig(const std::string& fs_name, + const std::string& fs_ugi) { + std::string cmd = std::string("hadoop fs"); + cmd += " -D fs.default.name=" + fs_name; + cmd += " -D hadoop.job.ugi=" + fs_ugi; + paddle::framework::hdfs_set_command(cmd); } template @@ -75,6 +95,8 @@ DatasetImpl::GetReaders() { template void DatasetImpl::LoadIntoMemory() { VLOG(3) << "DatasetImpl::LoadIntoMemory() begin"; + platform::Timer timeline; + timeline.Start(); if (readers_.size() == 0) { CreateReaders(); } @@ -86,12 +108,17 @@ void DatasetImpl::LoadIntoMemory() { for (std::thread& t : load_threads) { t.join(); } - VLOG(3) << "DatasetImpl::LoadIntoMemory() end"; + timeline.Pause(); + VLOG(3) << "DatasetImpl::LoadIntoMemory() end" + << ", memory data size=" << memory_data_.size() + << ", cost time=" << timeline.ElapsedSec() << " seconds"; } template void DatasetImpl::LocalShuffle() { VLOG(3) << "DatasetImpl::LocalShuffle() begin"; + platform::Timer timeline; + timeline.Start(); if (readers_.size() == 0) { CreateReaders(); } @@ -107,23 +134,27 @@ void DatasetImpl::LocalShuffle() { t.join(); } std::vector().swap(memory_data_); - VLOG(3) << "DatasetImpl::LocalShuffle() end"; + timeline.Pause(); + VLOG(3) << "DatasetImpl::LocalShuffle() end, cost time=" + << timeline.ElapsedSec() << " seconds"; } template void DatasetImpl::GlobalShuffle() { VLOG(3) << "DatasetImpl::GlobalShuffle() begin"; - if (readers_.size() == 0) { - CreateReaders(); - } - // if it is not InMemory, memory_data_ is empty - std::random_shuffle(memory_data_.begin(), memory_data_.end()); + platform::Timer timeline; + timeline.Start(); auto fleet_ptr = FleetWrapper::GetInstance(); VLOG(3) << "RegisterClientToClientMsgHandler"; fleet_ptr->RegisterClientToClientMsgHandler( 0, [this](int msg_type, int client_id, const std::string& msg) -> int { return this->ReceiveFromClient(msg_type, client_id, msg); }); + if (readers_.size() == 0) { + CreateReaders(); + } + // if it is not InMemory, memory_data_ is empty + std::random_shuffle(memory_data_.begin(), memory_data_.end()); VLOG(3) << "start global shuffle threads"; std::vector global_shuffle_threads; for (int i = 0; i < thread_num_; ++i) { @@ -133,15 +164,32 @@ void DatasetImpl::GlobalShuffle() { for (std::thread& t : global_shuffle_threads) { t.join(); } - VLOG(3) << "DatasetImpl::GlobalShuffle() end"; + std::vector().swap(memory_data_); + timeline.Pause(); + VLOG(3) << "DatasetImpl::GlobalShuffle() end, cost time=" + << timeline.ElapsedSec() << " seconds"; } template void DatasetImpl::CreateReaders() { VLOG(3) << "Calling CreateReaders()"; CHECK(thread_num_ > 0) << "thread_num should > 0"; + int file_cnt = filelist_.size(); + int memory_data_size = memory_data_.size(); + if (memory_data_size != 0 && thread_num_ > memory_data_size) { + VLOG(3) << "Dataset thread num = " << thread_num_ + << ", memory data size = " << memory_data_size + << ". Changing Dataset thread num = " << memory_data_size; + thread_num_ = memory_data_size; + } else if (file_cnt != 0 && thread_num_ > file_cnt) { + VLOG(3) << "Dataset thread num = " << thread_num_ + << ", file num = " << file_cnt + << ". Changing Dataset thread num = " << file_cnt; + thread_num_ = file_cnt; + } VLOG(3) << "thread_num in Readers: " << thread_num_; VLOG(3) << "readers size: " << readers_.size(); + VLOG(3) << "Filelist size in readers: " << filelist_.size(); if (readers_.size() != 0) { return; } @@ -154,9 +202,10 @@ void DatasetImpl::CreateReaders() { readers_.back()->SetThreadId(i); readers_.back()->SetThreadNum(thread_num_); readers_.back()->SetTrainerNum(trainer_num_); + readers_.back()->SetFileListMutex(&mutex_for_pick_file_); + readers_.back()->SetFileListIndex(&file_idx_); + readers_.back()->SetFileList(filelist_); } - VLOG(3) << "Filelist size in readers: " << filelist_.size(); - readers_[0]->SetFileList(filelist_); } template @@ -184,9 +233,12 @@ void DatasetImpl::DestroyReaders() { template int DatasetImpl::ReceiveFromClient(int msg_type, int client_id, const std::string& msg) { - // todo random - // int64_t index = paddle::ps::local_random_engine()() % thread_num_; - int64_t index = 0; + VLOG(3) << "ReceiveFromClient msg_type=" << msg_type + << ", client_id=" << client_id << ", msg length=" + << msg.length(); + auto fleet_ptr = FleetWrapper::GetInstance(); + int64_t index = fleet_ptr->LocalRandomEngine()() % thread_num_; + VLOG(3) << "ramdom index=" << index; readers_[index]->PutInsToChannel(msg); return 0; } diff --git a/paddle/fluid/framework/data_set.h b/paddle/fluid/framework/data_set.h index 41aa636c6b..02e07c5b5f 100644 --- a/paddle/fluid/framework/data_set.h +++ b/paddle/fluid/framework/data_set.h @@ -33,6 +33,8 @@ class Dataset { virtual void SetFileList(const std::vector& filelist) = 0; virtual void SetThreadNum(int thread_num) = 0; virtual void SetTrainerNum(int trainer_num) = 0; + virtual void SetHdfsConfig(const std::string& fs_name, + const std::string& fs_ugi) = 0; virtual void SetDataFeedDesc(const std::string& data_feed_desc_str) = 0; virtual const std::vector& GetFileList() = 0; virtual int GetThreadNum() = 0; @@ -60,6 +62,8 @@ class DatasetImpl : public Dataset { virtual void SetFileList(const std::vector& filelist); virtual void SetThreadNum(int thread_num); virtual void SetTrainerNum(int trainer_num); + virtual void SetHdfsConfig(const std::string& fs_name, + const std::string& fs_ugi); virtual void SetDataFeedDesc(const std::string& data_feed_desc_str); virtual const std::vector& GetFileList() { return filelist_; } @@ -85,8 +89,10 @@ class DatasetImpl : public Dataset { std::mutex mutex_for_update_memory_data_; int thread_num_; paddle::framework::DataFeedDesc data_feed_desc_; - std::vector filelist_; int trainer_num_; + std::vector filelist_; + size_t file_idx_; + std::mutex mutex_for_pick_file_; }; class MultiSlotDataset : public DatasetImpl> { diff --git a/paddle/fluid/framework/dist_multi_trainer.cc b/paddle/fluid/framework/dist_multi_trainer.cc index 0c42f5bf69..636e0a7354 100644 --- a/paddle/fluid/framework/dist_multi_trainer.cc +++ b/paddle/fluid/framework/dist_multi_trainer.cc @@ -26,12 +26,14 @@ void DistMultiTrainer::Initialize(const TrainerDesc& trainer_desc, Dataset* dataset) { thread_num_ = trainer_desc.thread_num(); SetDataset(dataset); - workers_.resize(thread_num_); dataset->CreateReaders(); const std::vector> readers = dataset->GetReaders(); + thread_num_ = readers.size(); + workers_.resize(thread_num_); + for (int i = 0; i < thread_num_; ++i) { workers_[i] = DeviceWorkerFactory::CreateDeviceWorker( trainer_desc.device_worker_name()); diff --git a/paddle/fluid/framework/fleet/fleet_wrapper.cc b/paddle/fluid/framework/fleet/fleet_wrapper.cc index 73db3cae55..1497628e64 100644 --- a/paddle/fluid/framework/fleet/fleet_wrapper.cc +++ b/paddle/fluid/framework/fleet/fleet_wrapper.cc @@ -29,6 +29,7 @@ limitations under the License. */ #include "paddle/fluid/framework/fleet/fleet_wrapper.h" #include #include "paddle/fluid/framework/data_feed.h" +#include "paddle/fluid/framework/scope.h" namespace paddle { namespace framework { @@ -203,6 +204,60 @@ void FleetWrapper::PullDenseVarsSync( #endif } +void FleetWrapper::PushDenseParamSync( + const ProgramDesc& program, const uint64_t table_id, + const std::vector& var_names) { +#ifdef PADDLE_WITH_PSLIB + paddle::framework::Scope scope; + auto& block = program.Block(0); + for (auto& var : block.AllVars()) { + if (var->Persistable()) { + auto* ptr = scope.Var(var->Name()); + InitializeVariable(ptr, var->GetType()); + } else { + auto* ptr = scope.Var(var->Name()); + InitializeVariable(ptr, var->GetType()); + } + } + auto place = platform::CPUPlace(); + std::vector regions; + for (auto& t : var_names) { + Variable* var = scope.FindVar(t); + CHECK(var != nullptr) << "var[" << t << "] not found"; + LoDTensor* tensor = var->GetMutable(); + std::vector dim; + for (auto& var : block.AllVars()) { + if (var->Name() == t) { + dim = var->GetShape(); + break; + } + } + int cnt = 1; + for (auto& i: dim) { + cnt *= i; + } + DDim d(std::vector{cnt}.data(), 1); + float* g = tensor->mutable_data(d, place); + CHECK(g != nullptr) << "var[" << t << "] value not initialized"; + float init_range = 0.2; + int rown = tensor->dims()[0]; + init_range /= sqrt(rown); + std::normal_distribution ndistr(0.0, 1.0); + for (auto i = 0u; i < tensor->numel(); ++i) { + g[i] = ndistr(LocalRandomEngine()) * init_range; + } + paddle::ps::Region reg(g, tensor->numel()); + regions.emplace_back(std::move(reg)); + auto push_status = pslib_ptr_->_worker_ptr->push_dense_param( + regions.data(), regions.size(), table_id); + push_status.wait(); + auto status = push_status.get(); + CHECK(status == 0) << "push dense param failed, status[" + << status << "]"; + } +#endif +} + void FleetWrapper::PushDenseVarsSync( Scope* scope, const uint64_t table_id, const std::vector& var_names) {} @@ -269,6 +324,8 @@ void FleetWrapper::PushSparseVarsWithLabelAsync( continue; } LOG(WARNING) << "going to memcpy"; + CHECK(fea_idx < (*push_values).size()); + CHECK(fea_idx < fea_labels.size()); memcpy((*push_values)[fea_idx].data() + offset, g, sizeof(float) * emb_dim); LOG(WARNING) << "show"; @@ -294,13 +351,13 @@ void FleetWrapper::PushSparseVarsWithLabelAsync( #endif } -int FleetWrapper::RegisterClientToClientMsgHandler(int msg_type, - MsgHandlerFunc handler) { +int FleetWrapper::RegisterClientToClientMsgHandler( + int msg_type, MsgHandlerFunc handler) { #ifdef PADDLE_WITH_PSLIB VLOG(3) << "calling FleetWrapper::RegisterClientToClientMsgHandler"; VLOG(3) << "pslib_ptr_=" << pslib_ptr_; VLOG(3) << "_worker_ptr=" << pslib_ptr_->_worker_ptr; - pslib_ptr_->_worker_ptr->registe_client2client_msg_handler(msg_type, handler); + return pslib_ptr_->_worker_ptr->registe_client2client_msg_handler(msg_type, handler); #else VLOG(0) << "FleetWrapper::RegisterClientToClientMsgHandler" << " does nothing when no pslib"; @@ -308,15 +365,15 @@ int FleetWrapper::RegisterClientToClientMsgHandler(int msg_type, return 0; } -int FleetWrapper::SendClientToClientMsg(int msg_type, int to_client_id, - const std::string& msg) { +std::future FleetWrapper::SendClientToClientMsg( + int msg_type, int to_client_id, const std::string& msg) { #ifdef PADDLE_WITH_PSLIB - pslib_ptr_->_worker_ptr->send_client2client_msg(msg_type, to_client_id, msg); + return pslib_ptr_->_worker_ptr->send_client2client_msg(msg_type, to_client_id, msg); #else VLOG(0) << "FleetWrapper::SendClientToClientMsg" << " does nothing when no pslib"; #endif - return 0; + return std::future(); } std::default_random_engine& FleetWrapper::LocalRandomEngine() { @@ -336,10 +393,12 @@ std::default_random_engine& FleetWrapper::LocalRandomEngine() { } template -void FleetWrapper::Serialize(const T& t, std::string* str) { +void FleetWrapper::Serialize(const std::vector& t, std::string* str) { #ifdef PADDLE_WITH_PSLIB paddle::ps::BinaryArchive ar; - ar << t; + for (size_t i = 0; i < t.size(); ++i) { + ar << *(t[i]); + } *str = std::string(ar.buffer(), ar.length()); #else VLOG(0) << "FleetWrapper::Serialize does nothing when no pslib"; @@ -347,20 +406,30 @@ void FleetWrapper::Serialize(const T& t, std::string* str) { } template -void FleetWrapper::Deserialize(T* t, const std::string& str) { +void FleetWrapper::Deserialize(std::vector* t, const std::string& str) { #ifdef PADDLE_WITH_PSLIB + if (str.length() == 0) { + return; + } paddle::ps::BinaryArchive ar; ar.set_read_buffer(const_cast(str.c_str()), str.length(), nullptr); - *t = ar.get(); + if (ar.cursor() == ar.finish()) { + return; + } + while (ar.cursor() < ar.finish()) { + t->push_back(ar.get()); + } + CHECK(ar.cursor() == ar.finish()); + VLOG(3) << "Deserialize size " << t->size(); #else VLOG(0) << "FleetWrapper::Deserialize does nothing when no pslib"; #endif } template void FleetWrapper::Serialize>( - const std::vector&, std::string*); -template void FleetWrapper::Deserialize(std::vector*, - const std::string&); + const std::vector*>&, std::string*); +template void FleetWrapper::Deserialize>( + std::vector>*, const std::string&); } // end namespace framework } // end namespace paddle diff --git a/paddle/fluid/framework/fleet/fleet_wrapper.h b/paddle/fluid/framework/fleet/fleet_wrapper.h index deab3bc1db..ed3217b376 100644 --- a/paddle/fluid/framework/fleet/fleet_wrapper.h +++ b/paddle/fluid/framework/fleet/fleet_wrapper.h @@ -27,6 +27,7 @@ limitations under the License. */ #include "paddle/fluid/framework/scope.h" #include "paddle/fluid/framework/variable_helper.h" #include "paddle/fluid/platform/macros.h" // for DISABLE_COPY_AND_ASSIGN +#include "paddle/fluid/framework/program_desc.h" namespace paddle { namespace framework { @@ -71,6 +72,10 @@ class FleetWrapper { const std::vector& var_names, std::vector<::std::future>* pull_dense_status); + void PushDenseParamSync( + const ProgramDesc& program, const uint64_t table_id, + const std::vector& var_names); + // Push dense variables to server in async mode // Param: scope, table_id, var_names, // Param: push_sparse_status @@ -119,16 +124,15 @@ class FleetWrapper { typedef std::function MsgHandlerFunc; int RegisterClientToClientMsgHandler(int msg_type, MsgHandlerFunc handler); - int SendClientToClientMsg(int msg_type, - int to_client_id, - const std::string& msg); + std::future SendClientToClientMsg(int msg_type, + int to_client_id, + const std::string& msg); std::default_random_engine& LocalRandomEngine(); template - void Serialize(const T& t, std::string* str); + void Serialize(const std::vector& t, std::string* str); template - void Deserialize(T* t, const std::string& str); - + void Deserialize(std::vector* t, const std::string& str); static std::shared_ptr GetInstance() { if (NULL == s_instance_) { s_instance_.reset(new paddle::framework::FleetWrapper()); diff --git a/paddle/fluid/framework/multi_trainer.cc b/paddle/fluid/framework/multi_trainer.cc index 30d6311728..7f955e3550 100644 --- a/paddle/fluid/framework/multi_trainer.cc +++ b/paddle/fluid/framework/multi_trainer.cc @@ -26,13 +26,15 @@ void MultiTrainer::Initialize(const TrainerDesc& trainer_desc, thread_num_ = trainer_desc.thread_num(); SetDataset(dataset); // get filelist from trainer_desc here - workers_.resize(thread_num_); - VLOG(3) << "worker thread num: " << thread_num_; dataset->CreateReaders(); VLOG(3) << "readers created"; const std::vector> readers = dataset->GetReaders(); VLOG(3) << "readers num: " << readers.size(); + // change thread num to readers num + thread_num_ = readers.size(); + VLOG(3) << "worker thread num: " << thread_num_; + workers_.resize(thread_num_); for (int i = 0; i < thread_num_; ++i) { workers_[i] = DeviceWorkerFactory::CreateDeviceWorker( trainer_desc.device_worker_name()); diff --git a/paddle/fluid/pybind/async_executor_py.cc b/paddle/fluid/pybind/async_executor_py.cc index 6dc865e8ed..3bb6bff236 100644 --- a/paddle/fluid/pybind/async_executor_py.cc +++ b/paddle/fluid/pybind/async_executor_py.cc @@ -49,7 +49,7 @@ void BindAsyncExecutor(py::module* m) { new framework::AsyncExecutor(scope, place)); })) .def("run_from_files", &framework::AsyncExecutor::RunFromFile) - .def("run_from_dataset", &framework::AsyncExecutor::RunFromDataset) + //.def("run_from_dataset", &framework::AsyncExecutor::RunFromDataset) .def("init_server", &framework::AsyncExecutor::InitServer) .def("init_worker", &framework::AsyncExecutor::InitWorker) .def("start_server", &framework::AsyncExecutor::StartServer) diff --git a/paddle/fluid/pybind/data_set_py.cc b/paddle/fluid/pybind/data_set_py.cc index 3ed4c01bed..2138ecab85 100644 --- a/paddle/fluid/pybind/data_set_py.cc +++ b/paddle/fluid/pybind/data_set_py.cc @@ -50,6 +50,7 @@ void BindDataset(py::module* m) { .def("set_filelist", &framework::Dataset::SetFileList) .def("set_thread_num", &framework::Dataset::SetThreadNum) .def("set_trainer_num", &framework::Dataset::SetTrainerNum) + .def("set_hdfs_config", &framework::Dataset::SetHdfsConfig) .def("set_data_feed_desc", &framework::Dataset::SetDataFeedDesc) .def("load_into_memory", &framework::Dataset::LoadIntoMemory) .def("local_shuffle", &framework::Dataset::LocalShuffle) diff --git a/paddle/fluid/pybind/fleet_wrapper_py.cc b/paddle/fluid/pybind/fleet_wrapper_py.cc index f6a2ed7a27..444a3c7f14 100644 --- a/paddle/fluid/pybind/fleet_wrapper_py.cc +++ b/paddle/fluid/pybind/fleet_wrapper_py.cc @@ -47,6 +47,7 @@ void BindFleetWrapper(py::module* m) { .def("init_server", &framework::FleetWrapper::InitServer) .def("run_server", &framework::FleetWrapper::RunServer) .def("init_worker", &framework::FleetWrapper::InitWorker) + .def("init_model", &framework::FleetWrapper::PushDenseParamSync) .def("stop_server", &framework::FleetWrapper::StopServer) .def("gather_servers", &framework::FleetWrapper::GatherServers); } // end FleetWrapper diff --git a/python/paddle/fluid/dataset.py b/python/paddle/fluid/dataset.py index 6ae1d3cf15..988272e632 100644 --- a/python/paddle/fluid/dataset.py +++ b/python/paddle/fluid/dataset.py @@ -86,6 +86,9 @@ class DatasetBase(object): "Currently, fluid.dataset only supports dtype=float32 and dtype=int64" ) + def set_hdfs_config(self, fs_name, fs_ugi): + self.dataset.set_hdfs_config(fs_name, fs_ugi) + def _prepare_to_run(self): self.dataset.set_data_feed_desc(self.desc()) @@ -115,11 +118,15 @@ class InMemoryDataset(DatasetBase): def local_shuffle(self): self.dataset.local_shuffle() - def global_shuffle(self): - from .distributed import ps_instance - instance = ps_instance.PaddlePSInstance(1, 2) - self.dataset.set_trainer_num(instance.get_worker_num()) + def global_shuffle(self, fleet=None): + trainer_num = 1 + if fleet is not None: + fleet.fleet_instance.role_maker_.barrier_worker() + trainer_num = fleet.worker_num() + self.dataset.set_trainer_num(trainer_num) self.dataset.global_shuffle() + if fleet is not None: + fleet.fleet_instance.role_maker_.barrier_worker() class QueueDataset(DatasetBase): @@ -130,5 +137,5 @@ class QueueDataset(DatasetBase): def local_shuffle(self): pass - def global_shuffle(self): + def global_shuffle(self, fleet=None): pass diff --git a/python/paddle/fluid/incubate/fleet/base/role_maker.py b/python/paddle/fluid/incubate/fleet/base/role_maker.py index 9f57b9a2e5..baaeb1abef 100644 --- a/python/paddle/fluid/incubate/fleet/base/role_maker.py +++ b/python/paddle/fluid/incubate/fleet/base/role_maker.py @@ -170,7 +170,7 @@ class MPISymetricRoleMaker(MPIRoleMaker): """ if self._check_role_generation(): if self.is_worker(): - return self.get_size() + return self.get_size() / 2; return 0 def server_num(self): @@ -179,7 +179,7 @@ class MPISymetricRoleMaker(MPIRoleMaker): """ if self._check_role_generation(): if self.is_server(): - return self.get_size() + return self.get_size() / 2; return 0 def worker_index(self): diff --git a/python/paddle/fluid/incubate/fleet/parameter_server/__init__.py b/python/paddle/fluid/incubate/fleet/parameter_server/__init__.py index d8efba432f..b0cb6a0041 100644 --- a/python/paddle/fluid/incubate/fleet/parameter_server/__init__.py +++ b/python/paddle/fluid/incubate/fleet/parameter_server/__init__.py @@ -43,7 +43,7 @@ class Fleet(object): save_pserver_model(): save model parameters in pserver, called from a server node Example: - + .. code-block:: python import paddle.fluid.incubate.fleet.parameter_server as fleet from my_model import bow_net @@ -58,7 +58,7 @@ class Fleet(object): fleet.init_worker() # init worker should be called before training # do other things like training elif fleet.is_server(): - fleet.init_pserver() + fleet.init_pserver() fleet.stop() """ @@ -75,7 +75,7 @@ class Fleet(object): """ init(): which should be called only once in user's python scripts. init() will initialize FleetWrapper in CPP, it will also initialize a RoleMaker which is used for identifying - current node's role, e.g. worker, server, etc. + current node's role, e.g. worker, server, etc. """ if not self.is_initialized_: self.role_maker_ = MPISymetricRoleMaker() @@ -122,7 +122,7 @@ class Fleet(object): print("You should run DistributedOptimizer.minimize() first") sys.exit(-1) - def init_worker(self): + def init_worker(self, program): """ init_worker(): will be called by user. When a user knows current process is_server(), he/she should call init_worker() to initialize global information about worker and connect @@ -143,6 +143,19 @@ class Fleet(object): self.role_maker_.get_rank()) self.role_maker_.barrier_all() self.role_maker_.barrier_worker() + if self.role_maker_.is_first_worker(): + tables = self._dist_desc.trainer_param.dense_table._values + for i in range(0, len(tables)): + table = tables[i]; + var_name_list = [] + for i in range(0, len(table.dense_variable_name)): + var_name_list.append(table.dense_variable_name[i]) + #print "table id ", table.table_id + #print "var_name_list ", var_name_list + self._fleet_ptr.init_model(program.desc, + int(table.table_id), + var_name_list) + self.role_maker_.barrier_worker() else: print("You should run DistributedOptimizer.minimize() first") sys.exit(-1) -- GitLab