From 3cea00bd52a05a788195ba9588515761c9194221 Mon Sep 17 00:00:00 2001 From: xujiaqi01 Date: Tue, 12 Mar 2019 14:26:44 +0800 Subject: [PATCH] store memory data in Dataset && fix bug --- paddle/fluid/framework/data_feed.cc | 131 +++++++++++++++--- paddle/fluid/framework/data_feed.h | 43 +++++- paddle/fluid/framework/data_set.cc | 102 ++++++++++---- paddle/fluid/framework/data_set.h | 48 ++++++- paddle/fluid/framework/fleet/fleet_wrapper.cc | 62 +++++++++ paddle/fluid/framework/fleet/fleet_wrapper.h | 14 ++ paddle/fluid/pybind/data_set_py.cc | 18 +-- python/paddle/fluid/__init__.py | 4 +- python/paddle/fluid/dataset.py | 4 +- 9 files changed, 356 insertions(+), 70 deletions(-) diff --git a/paddle/fluid/framework/data_feed.cc b/paddle/fluid/framework/data_feed.cc index fcba99d5f3f..8ee625b5c6d 100644 --- a/paddle/fluid/framework/data_feed.cc +++ b/paddle/fluid/framework/data_feed.cc @@ -68,8 +68,10 @@ void DataFeed::SetBatchSize(int batch_size) { bool DataFeed::PickOneFile(std::string* filename) { std::unique_lock lock(mutex_for_pick_file_); if (file_idx_ == filelist_.size()) { + VLOG(3) << "DataFeed::PickOneFile no more file to pick"; return false; } + VLOG(3) << "file_idx_=" << file_idx_; *filename = filelist_[file_idx_++]; // LOG(ERROR) << "pick file:" << *filename; return true; @@ -146,17 +148,18 @@ template class PrivateQueueDataFeed>; template InMemoryDataFeed::InMemoryDataFeed() { cur_channel_ = 0; - shuffled_ins_ = nullptr; - shuffled_ins_out_ = nullptr; + shuffled_ins_ = std::make_shared>(); + shuffled_ins_out_ = std::make_shared>(); + fleet_send_batch_size_ = 10000; } template bool InMemoryDataFeed::Start() { DataFeed::CheckSetFileList(); - if (memory_data_.size() != 0) { - CHECK_EQ(cur_channel_, 0); - shuffled_ins_->Extend(std::move(memory_data_)); - std::vector().swap(memory_data_); + if (shuffled_ins_->Size() == 0 && shuffled_ins_out_->Size() == 0) { + FillMemoryDataToChannel(); + //std::unique_lock lock(*mutex_for_update_memory_data_); + //std::vector().swap(memory_data_); } DataFeed::finish_start_ = true; return true; @@ -196,6 +199,31 @@ int InMemoryDataFeed::Next() { return DataFeed::batch_size_; } +template +void InMemoryDataFeed::SetMemoryData(void* memory_data) { + memory_data_ = static_cast*>(memory_data); +} + +template +void InMemoryDataFeed::SetMemoryDataMutex(std::mutex* mutex) { + mutex_for_update_memory_data_ = mutex; +} + +template +void InMemoryDataFeed::SetThreadId(int thread_id) { + thread_id_ = thread_id; +} + +template +void InMemoryDataFeed::SetThreadNum(int thread_num) { + thread_num_ = thread_num; +} + +template +void InMemoryDataFeed::SetTrainerNum(int trainer_num) { + trainer_num_ = trainer_num; +} + template void InMemoryDataFeed::PutInsToChannel(const std::string& ins_str) { T ins; @@ -203,11 +231,54 @@ void InMemoryDataFeed::PutInsToChannel(const std::string& ins_str) { shuffled_ins_->Push(std::move(ins)); } +template +void InMemoryDataFeed::FillMemoryDataToChannel() { + VLOG(3) << "InMemoryDataFeed::FillMemoryDataToChannel, thread_id=" << thread_id_; + int64_t start = 0; + int64_t end = 0; + int64_t size = memory_data_->size(); + VLOG(3) << "memory_data size=" << size; + for (int64_t i = 0; i <= static_cast(thread_id_); ++i) { + int64_t len = size / static_cast(thread_num_) + + (i < (size % static_cast(thread_num_))); + start = end; + end += len; + } + for (int64_t i = start; i < end; ++i) { + T& t = (*memory_data_)[i]; + shuffled_ins_->Push(std::move(t)); + } +} + +template +void InMemoryDataFeed::FillChannelToMemoryData() { + VLOG(3) << "InMemoryDataFeed::FillChannelToMemoryData, thread_id=" << thread_id_; + std::vector local_vec; + std::shared_ptr> channel = nullptr; + if (cur_channel_ == 0) { + channel = shuffled_ins_; + } else { + channel = shuffled_ins_out_; + } + CHECK(channel != nullptr); + local_vec.reserve(channel->Size()); + for (int64_t i = 0; i < channel->Size(); ++i) { + channel->Pop(local_vec[i]); + } + std::unique_lock lock(*mutex_for_update_memory_data_); + lock.lock(); + memory_data_->insert(memory_data_->end(), local_vec.begin(), local_vec.end()); + lock.unlock(); + std::vector().swap(local_vec); +} + template void InMemoryDataFeed::LoadIntoMemory() { + VLOG(3) << "InMemoryDataFeed::LoadIntoMemory() begin, thread_id=" << thread_id_; std::vector local_vec; std::string filename; while (DataFeed::PickOneFile(&filename)) { + VLOG(3) << "PickOneFile, filename=" << filename << ", thread_id=" << thread_id_; int err_no = 0; PrivateQueueDataFeed::fp_ = fs_open_read(filename, &err_no, PrivateQueueDataFeed::pipe_command_); @@ -216,35 +287,50 @@ void InMemoryDataFeed::LoadIntoMemory() { while (ParseOneInstanceFromPipe(&instance)) { local_vec.push_back(instance); } - memory_data_.insert(memory_data_.end(), local_vec.begin(), local_vec.end()); + VLOG(3) << "InMemoryDataFeed::LoadIntoMemory() read all lines, thread_id=" << thread_id_; + { + std::lock_guard lock(*mutex_for_update_memory_data_); + memory_data_->insert(memory_data_->end(), local_vec.begin(), local_vec.end()); + } std::vector().swap(local_vec); } + VLOG(3) << "InMemoryDataFeed::LoadIntoMemory() end, thread_id=" << thread_id_; } template void InMemoryDataFeed::LocalShuffle() { - std::random_shuffle(memory_data_.begin(), memory_data_.end()); + VLOG(3) << "InMemoryDataFeed::LocalShuffle() begin, thread_id=" << thread_id_; + FillMemoryDataToChannel(); + VLOG(3) << "InMemoryDataFeed::LocalShuffle() end, thread_id=" << thread_id_; } -// todo global shuffle -/* template -void InMemoryDataFeed::GlobalShuffle(int trainer_num) { - std::random_shuffle(memory_data_.begin(), memory_data_.end()); - for (int64_t i = 0; i < memory_data_.size(); ++i) { +void InMemoryDataFeed::GlobalShuffle() { + auto fleet_ptr = FleetWrapper::GetInstance(); + std::vector send_str_vec(trainer_num_); + for (int64_t i = 0; i < memory_data_->size(); ++i) { // todo get ins id //std::string ins_id = memory_data_[i].ins_id; // todo hash - int64_t hash_id = paddle::ps::local_random_engine()(); - //int64_t hash_id = hash(ins_id); + //int64_t hash_id = paddle::ps::local_random_engine()(); + int64_t hash_id = 0; int64_t node_id = hash_id % trainer_num_; std::string str; - SerializeIns(memory_data_[i], str); - auto fleet_ptr = FleetWrapper::GetInstance(); - auto ret = fleet_ptr->send_client2client_msg(0, node_id, str); + SerializeIns((*memory_data_)[i], str); + send_str_vec[node_id] += str; + if (i % fleet_send_batch_size_ == 0 && i != 0) { + for (int j = 0; j < send_str_vec.size(); ++j) { + fleet_ptr->send_client2client_msg(0, j, send_str_vec[j]); + send_str_vec[j] = ""; + } + } + } + for (int j = 0; j < send_str_vec.size(); ++j) { + if (send_str_vec[j].length() != 0) { + fleet_ptr->send_client2client_msg(0, j, send_str_vec[j]); + } } } -*/ // explicit instantiation template class InMemoryDataFeed>; @@ -646,6 +732,7 @@ bool MultiSlotInMemoryDataFeed::ParseOneInstance( if (getline(file_, line)) { int use_slots_num = use_slots_.size(); instance->resize(use_slots_num); + VLOG(3) << line; // parse line const char* str = line.c_str(); char* endptr = const_cast(str); @@ -735,12 +822,14 @@ void MultiSlotInMemoryDataFeed::PutToFeedVec( // todo serialize ins in global shuffle void MultiSlotInMemoryDataFeed::SerializeIns( const std::vector& ins, std::string& str) { - return; + auto fleet_ptr = FleetWrapper::GetInstance(); + fleet_ptr->Serialize(ins, str); } // todo deserialize ins in global shuffle void MultiSlotInMemoryDataFeed::DeserializeIns(std::vector& ins, const std::string& str) { - return; + auto fleet_ptr = FleetWrapper::GetInstance(); + fleet_ptr->Deserialize(ins, str); } } // namespace framework diff --git a/paddle/fluid/framework/data_feed.h b/paddle/fluid/framework/data_feed.h index 0e1ac79664f..98aeb4b1f93 100644 --- a/paddle/fluid/framework/data_feed.h +++ b/paddle/fluid/framework/data_feed.h @@ -20,6 +20,7 @@ limitations under the License. */ #include #include // NOLINT #include +#include #include "paddle/fluid/framework/data_feed.pb.h" #include "paddle/fluid/framework/lod_tensor.h" @@ -78,17 +79,33 @@ class DataFeed { // This function is used for binding feed_vec memory virtual void AddFeedVar(Variable* var, const std::string& name); + // This function will do nothing at default + virtual void SetMemoryData(void* memory_data) { } + // This function will do nothing at default + virtual void SetMemoryDataMutex(std::mutex* mutex) { } + // This function will do nothing at default + virtual void SetThreadId(int thread_id) { } + // This function will do nothing at default + virtual void SetThreadNum(int thread_num) { } + // This function will do nothing at default + virtual void SetTrainerNum(int trainer_num) { } virtual void LoadIntoMemory() { PADDLE_THROW("This function(LoadIntoMemory) is not implemented."); } virtual void LocalShuffle() { PADDLE_THROW("This function(LocalShuffle) is not implemented."); } - virtual void GlobalShuffle(int trainer_num) { + virtual void GlobalShuffle() { PADDLE_THROW("This function(GlobalShuffle) is not implemented."); } + virtual void FillMemoryDataToChannel() { + PADDLE_THROW("This function(FillMemoryDataToChannel) is not implemented."); + } + virtual void FillChannelToMemoryData() { + PADDLE_THROW("This function(FillChannelToMemoryData) is not implemented."); + } virtual void PutInsToChannel(const std::string& ins_str) { - PADDLE_THROW("This function(PutToChannel) is not implemented."); + PADDLE_THROW("This function(PutInsToChannel) is not implemented."); } protected: @@ -181,13 +198,20 @@ class InMemoryDataFeed : public PrivateQueueDataFeed { public: InMemoryDataFeed(); virtual ~InMemoryDataFeed() {} + virtual void Init(const paddle::framework::DataFeedDesc& data_feed_desc) = 0; virtual bool Start(); virtual int Next(); + virtual void SetMemoryData(void* memory_data); + virtual void SetMemoryDataMutex(std::mutex* mutex); + virtual void SetThreadId(int thread_id); + virtual void SetThreadNum(int thread_num); + virtual void SetTrainerNum(int trainer_num); virtual void PutInsToChannel(const std::string& ins_str); + virtual void FillMemoryDataToChannel(); + virtual void FillChannelToMemoryData(); virtual void LoadIntoMemory(); virtual void LocalShuffle(); - // todo global shuffle - //virtual void GlobalShuffle(int trainer_num); + virtual void GlobalShuffle(); protected: virtual void AddInstanceToInsVec(T* vec_ins, const T& instance, int index) = 0; virtual bool ParseOneInstance(T* instance) = 0; @@ -196,13 +220,18 @@ class InMemoryDataFeed : public PrivateQueueDataFeed { virtual void SerializeIns(const T& ins, std::string& str) = 0; virtual void DeserializeIns(T& ins, const std::string& str) = 0; - std::vector memory_data_; + int thread_id_; + int thread_num_; + int trainer_num_; + std::vector* memory_data_; + std::mutex* mutex_for_update_memory_data_; // when read ins, we put ins from one channel to the other, // and when finish reading, we set cur_channel = 1 - cur_channel, // so if cur_channel=0, all data are in shuffled_ins_, else shuffled_ins_out_ int cur_channel_; std::shared_ptr> shuffled_ins_; std::shared_ptr> shuffled_ins_out_; + int64_t fleet_send_batch_size_; }; // This class define the data type of instance(ins_vec) in MultiSlotDataFeed @@ -226,6 +255,7 @@ class MultiSlotType { offset_[0] = 0; } const std::vector& GetOffset() const { return offset_; } + std::vector& MutableOffset() { return offset_; } void AddValue(const float v) { CheckFloat(); float_feasign_.push_back(v); @@ -248,8 +278,11 @@ class MultiSlotType { } } const std::vector& GetFloatData() const { return float_feasign_; } + std::vector& MutableFloatData() { return float_feasign_; } const std::vector& GetUint64Data() const { return uint64_feasign_; } + std::vector& MutableUint64Data() { return uint64_feasign_; } const std::string& GetType() const { return type_; } + std::string& MutableType() { return type_; } private: void CheckType(const std::string& type) const { diff --git a/paddle/fluid/framework/data_set.cc b/paddle/fluid/framework/data_set.cc index ce59bdff8fa..7497e4c9afb 100644 --- a/paddle/fluid/framework/data_set.cc +++ b/paddle/fluid/framework/data_set.cc @@ -12,6 +12,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ +#include #include "paddle/fluid/framework/data_set.h" #include "google/protobuf/io/zero_copy_stream_impl.h" #include "google/protobuf/message.h" @@ -21,23 +22,27 @@ namespace paddle { namespace framework { -Dataset::Dataset() { thread_num_ = 1; } +template +DatasetImpl::DatasetImpl() { thread_num_ = 1; } -void Dataset::SetFileList(const std::vector& filelist) { +template +void DatasetImpl::SetFileList(const std::vector& filelist) { VLOG(3) << "filelist size: " << filelist.size(); filelist_ = filelist; + /* int file_cnt = filelist_.size(); if (thread_num_ > file_cnt) { VLOG(1) << "DataSet thread num = " << thread_num_ << ", file num = " << file_cnt << ". Changing DataSet thread num = " << file_cnt; thread_num_ = file_cnt; - } + }*/ } // buggy here, a user should set filelist first before this function // not user friendly -void Dataset::SetThreadNum(int thread_num) { +template +void DatasetImpl::SetThreadNum(int thread_num) { int file_cnt = filelist_.size(); if (file_cnt != 0 && thread_num > file_cnt) { VLOG(1) << "DataSet thread num = " << thread_num @@ -48,19 +53,24 @@ void Dataset::SetThreadNum(int thread_num) { thread_num_ = thread_num; } -void Dataset::SetTrainerNum(int trainer_num) { trainer_num_ = trainer_num; } +template +void DatasetImpl::SetTrainerNum(int trainer_num) { trainer_num_ = trainer_num; } -void Dataset::SetDataFeedDesc(const std::string& data_feed_desc_str) { +template +void DatasetImpl::SetDataFeedDesc(const std::string& data_feed_desc_str) { google::protobuf::TextFormat::ParseFromString(data_feed_desc_str, &data_feed_desc_); } -const std::vector>& -Dataset::GetReaders() { +template +std::vector>& + DatasetImpl::GetReaders() { return readers_; } -void Dataset::LoadIntoMemory() { +template +void DatasetImpl::LoadIntoMemory() { + VLOG(3) << "DatasetImpl::LoadIntoMemory() begin"; if (readers_.size() == 0) { CreateReaders(); } @@ -72,12 +82,18 @@ void Dataset::LoadIntoMemory() { for (std::thread& t : load_threads) { t.join(); } + VLOG(3) << "DatasetImpl::LoadIntoMemory() end"; } -void Dataset::LocalShuffle() { +template +void DatasetImpl::LocalShuffle() { + VLOG(3) << "DatasetImpl::LocalShuffle() begin"; if (readers_.size() == 0) { CreateReaders(); } + // if it is not InMemory, memory_data_ is empty + std::random_shuffle(memory_data_.begin(), memory_data_.end()); + std::vector local_shuffle_threads; for (int64_t i = 0; i < thread_num_; ++i) { local_shuffle_threads.push_back(std::thread( @@ -86,30 +102,37 @@ void Dataset::LocalShuffle() { for (std::thread& t : local_shuffle_threads) { t.join(); } + std::vector().swap(memory_data_); + VLOG(3) << "DatasetImpl::LocalShuffle() end"; } -// todo global shuffle -void Dataset::GlobalShuffle() { - /* +template +void DatasetImpl::GlobalShuffle() { + VLOG(3) << "DatasetImpl::GlobalShuffle() begin"; + if (readers_.size() == 0) { + CreateReaders(); + } + // if it is not InMemory, memory_data_ is empty + std::random_shuffle(memory_data_.begin(), memory_data_.end()); auto fleet_ptr = FleetWrapper::GetInstance(); fleet_ptr->registe_client2client_msg_handler(0, [this](int msg_type, int client_id, const std::string& msg) -> int { return this->ReceiveFromClient(msg_type, client_id, msg); }); - if (readers_.size() == 0) { - CreateReaders(); - } std::vector global_shuffle_threads; - for (int64_t i = 0; i < thread_num_; ++i) { - global_shuffle_threads.push_back(std::thread(&paddle::framework::DataFeed::GlobalShuffle, - readers_[i].get(), trainer_num_)); + for (int i = 0; i < thread_num_; ++i) { + global_shuffle_threads.push_back( + std::thread(&paddle::framework::DataFeed::GlobalShuffle, + readers_[i].get())); } for (std::thread& t : global_shuffle_threads) { t.join(); - }*/ + } + VLOG(3) << "DatasetImpl::GlobalShuffle() end"; } -void Dataset::CreateReaders() { +template +void DatasetImpl::CreateReaders() { VLOG(3) << "Calling CreateReaders()"; CHECK(thread_num_ > 0) << "thread_num should > 0"; VLOG(3) << "thread_num in Readers: " << thread_num_; @@ -118,22 +141,53 @@ void Dataset::CreateReaders() { return; } VLOG(3) << "data feed class name: " << data_feed_desc_.name(); - for (int64_t i = 0; i < thread_num_; ++i) { + for (int i = 0; i < thread_num_; ++i) { readers_.push_back(DataFeedFactory::CreateDataFeed(data_feed_desc_.name())); readers_.back()->Init(data_feed_desc_); + readers_.back()->SetMemoryData(&memory_data_); + readers_.back()->SetMemoryDataMutex(&mutex_for_update_memory_data_); + readers_.back()->SetThreadId(i); + readers_.back()->SetThreadNum(thread_num_); + readers_.back()->SetTrainerNum(trainer_num_); } VLOG(3) << "Filelist size in readers: " << filelist_.size(); readers_[0]->SetFileList(filelist_); } -int Dataset::ReceiveFromClient(int msg_type, int client_id, +template +void DatasetImpl::DestroyReaders() { + VLOG(3) << "Calling DestroyReaders()"; + // clear memory_data_ before fill it + // because if LoadIntoMemory but no Shuffle, + // memory_data_ has empty data which has been std::move to channel + if (memory_data_.size() != 0) { + std::vector().swap(memory_data_); + } + std::vector fill_threads; + for (int i = 0; i < thread_num_; ++i) { + fill_threads.push_back(std::thread( + &paddle::framework::DataFeed::FillChannelToMemoryData, + readers_[i].get())); + } + for (std::thread& t : fill_threads) { + t.join(); + } + std::vector().swap(filelist_); + std::vector>().swap(readers_); +} + +template +int DatasetImpl::ReceiveFromClient(int msg_type, int client_id, const std::string& msg) { - // can also use hash + // todo random // int64_t index = paddle::ps::local_random_engine()() % thread_num_; int64_t index = 0; readers_[index]->PutInsToChannel(msg); return 0; } +// explicit instantiation +template class DatasetImpl>; + } // end namespace framework } // end namespace paddle diff --git a/paddle/fluid/framework/data_set.h b/paddle/fluid/framework/data_set.h index f99dc1470c5..c103fc49a7d 100644 --- a/paddle/fluid/framework/data_set.h +++ b/paddle/fluid/framework/data_set.h @@ -28,8 +28,33 @@ namespace framework { class Dataset { public: - Dataset(); - virtual ~Dataset() {} + Dataset() {}; + virtual ~Dataset() {}; + virtual void SetFileList(const std::vector& filelist) = 0; + virtual void SetThreadNum(int thread_num) = 0; + virtual void SetTrainerNum(int trainer_num) = 0; + virtual void SetDataFeedDesc(const std::string& data_feed_desc_str) = 0; + virtual const std::vector& GetFileList() = 0; + virtual int GetThreadNum() = 0; + virtual int GetTrainerNum() = 0; + virtual const paddle::framework::DataFeedDesc& GetDataFeedDesc() = 0; + virtual std::vector>& + GetReaders() = 0; + virtual void LoadIntoMemory() = 0; + virtual void LocalShuffle() = 0; + virtual void GlobalShuffle() = 0; + virtual void CreateReaders() = 0; + virtual void DestroyReaders() = 0; + protected: + virtual int ReceiveFromClient(int msg_type, int client_id, + const std::string& msg) = 0; +}; + +template +class DatasetImpl : public Dataset { + public: + DatasetImpl(); + virtual ~DatasetImpl() {} virtual void SetFileList(const std::vector& filelist); virtual void SetThreadNum(int thread_num); @@ -43,25 +68,34 @@ class Dataset { return data_feed_desc_; } - virtual const std::vector>& - GetReaders(); + virtual std::vector>& + GetReaders(); virtual void LoadIntoMemory(); virtual void LocalShuffle(); - // todo global shuffle virtual void GlobalShuffle(); virtual void CreateReaders(); + virtual void DestroyReaders(); protected: virtual int ReceiveFromClient(int msg_type, int client_id, const std::string& msg); std::vector> readers_; + std::vector memory_data_; + std::mutex mutex_for_update_memory_data_; + std::vector>> shuffled_ins_vec_; + std::vector>> shuffled_ins_out_vec_; int thread_num_; - std::string fs_name_; - std::string fs_ugi_; paddle::framework::DataFeedDesc data_feed_desc_; std::vector filelist_; int trainer_num_; }; +class MultiSlotDataset : public DatasetImpl> { + public: + MultiSlotDataset() {} + virtual ~MultiSlotDataset() {} +}; + + } // end namespace framework } // end namespace paddle diff --git a/paddle/fluid/framework/fleet/fleet_wrapper.cc b/paddle/fluid/framework/fleet/fleet_wrapper.cc index f4522fd34d2..a2d60927fc8 100644 --- a/paddle/fluid/framework/fleet/fleet_wrapper.cc +++ b/paddle/fluid/framework/fleet/fleet_wrapper.cc @@ -27,6 +27,7 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/fluid/framework/fleet/fleet_wrapper.h" +#include "paddle/fluid/framework/data_feed.h" namespace paddle { namespace framework { @@ -35,6 +36,30 @@ const uint32_t MAX_FEASIGN_NUM = 1024 * 100 * 100; std::shared_ptr FleetWrapper::s_instance_ = NULL; bool FleetWrapper::is_initialized_ = false; +#ifdef PADDLE_WITH_PSLIB +template +paddle::ps::Archive& operator << ( + paddle::ps::Archive& ar, + const MultiSlotType& ins) { + ar << ins.GetType(); + ar << ins.GetOffset(); + ar << ins.GetFloatData(); + ar << ins.GetUint64Data(); +return ar; +} + +template +paddle::ps::Archive& operator >> ( + paddle::ps::Archive& ar, + MultiSlotType& ins) { + ar >> ins.MutableType(); + ar >> ins.MutableOffset(); + ar >> ins.MutableFloatData(); + ar >> ins.MutableUint64Data(); +return ar; +} +#endif + #ifdef PADDLE_WITH_PSLIB std::shared_ptr FleetWrapper::pslib_ptr_ = NULL; #endif @@ -266,5 +291,42 @@ void FleetWrapper::PushSparseVarsWithLabelAsync( #endif } +// todo registe_client2client_msg_handler +int FleetWrapper::registe_client2client_msg_handler(int msg_type, MsgHandlerFunc handler) { + return 0; +} + +// todo send_client2client_msg +int FleetWrapper::send_client2client_msg(int msg_type, int to_client_id, const std::string& msg) { + return 0; +} + +template +void FleetWrapper::Serialize(const T& t, std::string& str) { +#ifdef PADDLE_WITH_PSLIB + paddle::ps::BinaryArchive ar; + ar << t; + str = std::string(ar.buffer(), ar.length()); +#else + VLOG(0) << "FleetWrapper::Serialize do nothing when no pslib"; +#endif +} + +template +void FleetWrapper::Deserialize(T& t, const std::string& str) { +#ifdef PADDLE_WITH_PSLIB + paddle::ps::BinaryArchive ar; + ar.set_read_buffer(const_cast(str.c_str()), str.length(), nullptr); + t = ar.get(); +#else + VLOG(0) << "FleetWrapper::Deserialize do nothing when no pslib"; +#endif +} + +template void FleetWrapper::Serialize>( + const std::vector&, std::string&); +template void FleetWrapper::Deserialize( + std::vector&, const std::string&); + } // end namespace framework } // end namespace paddle diff --git a/paddle/fluid/framework/fleet/fleet_wrapper.h b/paddle/fluid/framework/fleet/fleet_wrapper.h index edac3e41414..f98db1fe8fd 100644 --- a/paddle/fluid/framework/fleet/fleet_wrapper.h +++ b/paddle/fluid/framework/fleet/fleet_wrapper.h @@ -17,7 +17,11 @@ limitations under the License. */ #include #ifdef PADDLE_WITH_PSLIB #include +#include #endif +#include +#include +#include #include #include #include "paddle/fluid/framework/scope.h" @@ -110,6 +114,16 @@ class FleetWrapper { uint64_t RunServer(); void GatherServers(const std::vector& host_sign_list, int node_num); + typedef std::function MsgHandlerFunc; + int registe_client2client_msg_handler(int msg_type, MsgHandlerFunc handler); + int send_client2client_msg(int msg_type, int to_client_id, const std::string& msg); + std::default_random_engine& local_random_engine(); + + template + void Serialize(const T& t, std::string& str); + template + void Deserialize(T& t, const std::string& str); + static std::shared_ptr GetInstance() { if (NULL == s_instance_) { s_instance_.reset(new paddle::framework::FleetWrapper()); diff --git a/paddle/fluid/pybind/data_set_py.cc b/paddle/fluid/pybind/data_set_py.cc index 45b90ee6c20..ca05451292b 100644 --- a/paddle/fluid/pybind/data_set_py.cc +++ b/paddle/fluid/pybind/data_set_py.cc @@ -41,17 +41,17 @@ namespace paddle { namespace pybind { void BindDataset(py::module* m) { - py::class_(*m, "Dataset") + py::class_(*m, "MultiSlotDataset") .def(py::init([]() { - return std::unique_ptr(new framework::Dataset()); + return std::unique_ptr(new framework::MultiSlotDataset()); })) - .def("set_filelist", &framework::Dataset::SetFileList) - .def("set_thread_num", &framework::Dataset::SetThreadNum) - .def("set_trainer_num", &framework::Dataset::SetTrainerNum) - .def("set_data_feed_desc", &framework::Dataset::SetDataFeedDesc) - .def("load_into_memory", &framework::Dataset::LoadIntoMemory) - .def("local_shuffle", &framework::Dataset::LocalShuffle) - .def("global_shuffle", &framework::Dataset::GlobalShuffle); + .def("set_filelist", &framework::MultiSlotDataset::SetFileList) + .def("set_thread_num", &framework::MultiSlotDataset::SetThreadNum) + .def("set_trainer_num", &framework::MultiSlotDataset::SetTrainerNum) + .def("set_data_feed_desc", &framework::MultiSlotDataset::SetDataFeedDesc) + .def("load_into_memory", &framework::MultiSlotDataset::LoadIntoMemory) + .def("local_shuffle", &framework::MultiSlotDataset::LocalShuffle) + .def("global_shuffle", &framework::MultiSlotDataset::GlobalShuffle); } } // end namespace pybind diff --git a/python/paddle/fluid/__init__.py b/python/paddle/fluid/__init__.py index b67651bf310..37320f12245 100644 --- a/python/paddle/fluid/__init__.py +++ b/python/paddle/fluid/__init__.py @@ -30,7 +30,7 @@ from .dataset import * from . import async_executor from .async_executor import * -from . import trainer +from . import trainer_desc from . import inferencer from . import io @@ -67,7 +67,7 @@ from . import install_check Tensor = LoDTensor __all__ = framework.__all__ + executor.__all__ + \ - trainer.__all__ + inferencer.__all__ + transpiler.__all__ + \ + trainer_desc.__all__ + inferencer.__all__ + transpiler.__all__ + \ parallel_executor.__all__ + lod_tensor.__all__ + \ data_feed_desc.__all__ + async_executor.__all__ + compiler.__all__ + [ 'io', diff --git a/python/paddle/fluid/dataset.py b/python/paddle/fluid/dataset.py index 31cb0555875..932fb64290c 100644 --- a/python/paddle/fluid/dataset.py +++ b/python/paddle/fluid/dataset.py @@ -37,7 +37,7 @@ class DatasetBase(object): # to decide whether we need create in memory instance self.proto_desc = data_feed_pb2.DataFeedDesc() self.proto_desc.pipe_command = "cat" - self.dataset = core.Dataset() + self.dataset = core.MultiSlotDataset() self.thread_num = 0 def set_pipe_command(self, pipe_command): @@ -109,7 +109,7 @@ class InMemoryDataset(DatasetBase): self.proto_desc.name = "MultiSlotInMemoryDataFeed" def load_into_memory(self): - _prepare_to_run() + self._prepare_to_run() self.dataset.load_into_memory() def local_shuffle(self): -- GitLab