From be74de2c61e96d3df6d93ace8a5a096553de5cd3 Mon Sep 17 00:00:00 2001
From: xjqbest <173596896@qq.com>
Date: Sun, 24 Mar 2019 01:43:54 +0800
Subject: [PATCH] fix code style & fix register bug & add release_memory
 test=develop

---
 paddle/fluid/framework/blocking_queue.h  |  4 ++--
 paddle/fluid/framework/data_feed.cc      | 19 +++++++++--------
 paddle/fluid/framework/data_feed.h       |  3 ++-
 paddle/fluid/framework/data_set.cc       | 26 ++++++++++++++++++------
 paddle/fluid/framework/data_set.h        | 25 ++++++++++++++++++++++-
 paddle/fluid/pybind/async_executor_py.cc |  2 +-
 paddle/fluid/pybind/data_set_py.cc       |  3 +++
 python/paddle/fluid/dataset.py           |  3 +++
 8 files changed, 66 insertions(+), 19 deletions(-)

diff --git a/paddle/fluid/framework/blocking_queue.h b/paddle/fluid/framework/blocking_queue.h
index e1b49986a50..cc5b4e8c4b8 100644
--- a/paddle/fluid/framework/blocking_queue.h
+++ b/paddle/fluid/framework/blocking_queue.h
@@ -83,10 +83,10 @@ class BlockingQueue {
     return rc;
   }
 
-  void Pop(T &t) {
+  void Pop(T *t) {
     std::unique_lock<std::mutex> lock(mutex_);
     cv_.wait(lock, [=] { return !q_.empty(); });
-    t = std::move(q_.front());
+    *t = std::move(q_.front());
     q_.pop_front();
   }
 
diff --git a/paddle/fluid/framework/data_feed.cc b/paddle/fluid/framework/data_feed.cc
index 62e391a3d27..4f8fa005d7b 100644
--- a/paddle/fluid/framework/data_feed.cc
+++ b/paddle/fluid/framework/data_feed.cc
@@ -48,7 +48,7 @@ bool DataFeed::SetFileList(const std::vector<std::string>& files) {
     return false;
   }
   */
-  //PADDLE_ENFORCE(files.size(), "You have set an empty filelist.");
+  // PADDLE_ENFORCE(files.size(), "You have set an empty filelist.");
   filelist_.assign(files.begin(), files.end());
 
   finish_set_filelist_ = true;
@@ -190,7 +190,8 @@ int InMemoryDataFeed<T>::Next() {
     if (in_channel->Size() == 0) {
       break;
     }
-    in_channel->Pop(instance);
+    in_channel->Pop(&instance);
+
     AddInstanceToInsVec(&ins_vec, instance, index++);
     out_channel->Push(std::move(instance));
   }
@@ -268,17 +269,19 @@ void InMemoryDataFeed<T>::FillChannelToMemoryData() {
   }
   CHECK(channel != nullptr);
   CHECK(pre_channel != nullptr);
-  CHECK(pre_channel->Size() == 0);
+  CHECK_EQ(pre_channel->Size(), 0);
   local_vec.resize(channel->Size());
   for (int64_t i = 0; i < local_vec.size(); ++i) {
-    channel->Pop(local_vec[i]);
+    channel->Pop(&local_vec[i]);
   }
-  VLOG(3) << "local_vec size=" << local_vec.size() <<", thread_id=" << thread_id_;
+  VLOG(3) << "local_vec size=" << local_vec.size()
+          <<", thread_id=" << thread_id_;
   {
     std::lock_guard<std::mutex> g(*mutex_for_update_memory_data_);
     VLOG(3) << "before insert, memory_data_ size=" << memory_data_->size()
             << ", thread_id=" << thread_id_;
-    memory_data_->insert(memory_data_->end(), local_vec.begin(), local_vec.end());
+    memory_data_->insert(memory_data_->end(), local_vec.begin(),
+                         local_vec.end());
     VLOG(3) << "after insert memory_data_ size=" << memory_data_->size()
             << ", thread_id=" << thread_id_;
   }
@@ -574,7 +577,7 @@ bool MultiSlotDataFeed::ParseOneInstanceFromPipe(
 
     const char* str = reader.get();
     std::string line = std::string(str);
-    //VLOG(3) << line;
+    // VLOG(3) << line;
     char* endptr = const_cast<char*>(str);
     int pos = 0;
     for (size_t i = 0; i < use_slots_index_.size(); ++i) {
@@ -750,7 +753,7 @@ bool MultiSlotInMemoryDataFeed::ParseOneInstanceFromPipe(
 
     const char* str = reader.get();
     std::string line = std::string(str);
-    //VLOG(3) << line;
+    // VLOG(3) << line;
    char* endptr = const_cast<char*>(str);
     int pos = 0;
     for (size_t i = 0; i < use_slots_index_.size(); ++i) {
diff --git a/paddle/fluid/framework/data_feed.h b/paddle/fluid/framework/data_feed.h
index cab0b431b5d..1c6c44242db 100644
--- a/paddle/fluid/framework/data_feed.h
+++ b/paddle/fluid/framework/data_feed.h
@@ -21,7 +21,8 @@ limitations under the License. */
 #include   // NOLINT
 #include 
 #include 
-#include 
+#include   // NOLINT
+#include 
 
 #include "paddle/fluid/framework/data_feed.pb.h"
 #include "paddle/fluid/framework/lod_tensor.h"
diff --git a/paddle/fluid/framework/data_set.cc b/paddle/fluid/framework/data_set.cc
index c2e8bff348a..62001c24df8 100644
--- a/paddle/fluid/framework/data_set.cc
+++ b/paddle/fluid/framework/data_set.cc
@@ -82,6 +82,18 @@ DatasetImpl<T>::GetReaders() {
   return readers_;
 }
 
+// if messages are sent between workers, this function should be called first
+template <typename T>
+void DatasetImpl<T>::RegisterClientToClientMsgHandler() {
+  auto fleet_ptr = FleetWrapper::GetInstance();
+  VLOG(3) << "RegisterClientToClientMsgHandler";
+  fleet_ptr->RegisterClientToClientMsgHandler(
+      0, [this](int msg_type, int client_id, const std::string& msg) -> int {
+        return this->ReceiveFromClient(msg_type, client_id, msg);
+      });
+  VLOG(3) << "RegisterClientToClientMsgHandler done";
+}
+
 // load data into memory, Dataset hold this memory,
 // which will later be fed into readers' channel
 template <typename T>
@@ -106,6 +118,14 @@ void DatasetImpl<T>::LoadIntoMemory() {
           << ", cost time=" << timeline.ElapsedSec() << " seconds";
 }
 
+// release memory data
+template <typename T>
+void DatasetImpl<T>::ReleaseMemory() {
+  VLOG(3) << "DatasetImpl::ReleaseMemory() begin";
+  std::vector<T>().swap(memory_data_);
+  VLOG(3) << "DatasetImpl::ReleaseMemory() end";
+}
+
 // do local shuffle
 template <typename T>
 void DatasetImpl<T>::LocalShuffle() {
@@ -137,12 +157,6 @@ void DatasetImpl<T>::GlobalShuffle() {
   VLOG(3) << "DatasetImpl::GlobalShuffle() begin";
   platform::Timer timeline;
   timeline.Start();
-  auto fleet_ptr = FleetWrapper::GetInstance();
-  VLOG(3) << "RegisterClientToClientMsgHandler";
-  fleet_ptr->RegisterClientToClientMsgHandler(
-      0, [this](int msg_type, int client_id, const std::string& msg) -> int {
-        return this->ReceiveFromClient(msg_type, client_id, msg);
-      });
   if (readers_.size() == 0) {
     CreateReaders();
   }
diff --git a/paddle/fluid/framework/data_set.h b/paddle/fluid/framework/data_set.h
index a13d0f869d4..4bbcc6d06a9 100644
--- a/paddle/fluid/framework/data_set.h
+++ b/paddle/fluid/framework/data_set.h
@@ -40,22 +40,43 @@ class Dataset {
  public:
   Dataset() {}
   virtual ~Dataset() {}
+  // set file list
   virtual void SetFileList(const std::vector<std::string>& filelist) = 0;
+  // set readers' num
   virtual void SetThreadNum(int thread_num) = 0;
+  // set workers' num
   virtual void SetTrainerNum(int trainer_num) = 0;
+  // set fs name and ugi
   virtual void SetHdfsConfig(const std::string& fs_name,
                              const std::string& fs_ugi) = 0;
+  // set data feed desc, which contains:
+  // data feed name, batch size, slots
   virtual void SetDataFeedDesc(const std::string& data_feed_desc_str) = 0;
+  // get file list
   virtual const std::vector<std::string>& GetFileList() = 0;
+  // get thread num
   virtual int GetThreadNum() = 0;
+  // get worker num
   virtual int GetTrainerNum() = 0;
+  // get data feed desc
   virtual const paddle::framework::DataFeedDesc& GetDataFeedDesc() = 0;
+  // get readers, the reader num depends on both thread num
+  // and filelist size
   virtual std::vector<std::shared_ptr<paddle::framework::DataFeed>>&
       GetReaders() = 0;
+  // register message handler between workers
+  virtual void RegisterClientToClientMsgHandler() = 0;
+  // load all data into memory
   virtual void LoadIntoMemory() = 0;
+  // release all memory data
+  virtual void ReleaseMemory() = 0;
+  // local shuffle data
   virtual void LocalShuffle() = 0;
+  // global shuffle data
   virtual void GlobalShuffle() = 0;
+  // create readers
   virtual void CreateReaders() = 0;
+  // destroy readers
   virtual void DestroyReaders() = 0;
 
  protected:
@@ -84,10 +105,12 @@ class DatasetImpl : public Dataset {
   virtual const paddle::framework::DataFeedDesc& GetDataFeedDesc() {
     return data_feed_desc_;
   }
-
   virtual std::vector<std::shared_ptr<paddle::framework::DataFeed>>&
       GetReaders();
+
+  virtual void RegisterClientToClientMsgHandler();
   virtual void LoadIntoMemory();
+  virtual void ReleaseMemory();
   virtual void LocalShuffle();
   virtual void GlobalShuffle();
   virtual void CreateReaders();
diff --git a/paddle/fluid/pybind/async_executor_py.cc b/paddle/fluid/pybind/async_executor_py.cc
index 3bb6bff2363..b0951f0ccd1 100644
--- a/paddle/fluid/pybind/async_executor_py.cc
+++ b/paddle/fluid/pybind/async_executor_py.cc
@@ -23,6 +23,7 @@ limitations under the License. */
 #endif
 #include 
 #include 
+#include 
 
 #include "google/protobuf/io/zero_copy_stream_impl.h"
 #include "google/protobuf/text_format.h"
@@ -49,7 +50,6 @@ void BindAsyncExecutor(py::module* m) {
                     new framework::AsyncExecutor(scope, place));
               }))
       .def("run_from_files", &framework::AsyncExecutor::RunFromFile)
-      //.def("run_from_dataset", &framework::AsyncExecutor::RunFromDataset)
       .def("init_server", &framework::AsyncExecutor::InitServer)
       .def("init_worker", &framework::AsyncExecutor::InitWorker)
       .def("start_server", &framework::AsyncExecutor::StartServer)
diff --git a/paddle/fluid/pybind/data_set_py.cc b/paddle/fluid/pybind/data_set_py.cc
index 2138ecab852..30d1d185cf1 100644
--- a/paddle/fluid/pybind/data_set_py.cc
+++ b/paddle/fluid/pybind/data_set_py.cc
@@ -52,7 +52,10 @@ void BindDataset(py::module* m) {
       .def("set_trainer_num", &framework::Dataset::SetTrainerNum)
       .def("set_hdfs_config", &framework::Dataset::SetHdfsConfig)
       .def("set_data_feed_desc", &framework::Dataset::SetDataFeedDesc)
+      .def("register_client2client_msg_handler",
+           &framework::Dataset::RegisterClientToClientMsgHandler)
       .def("load_into_memory", &framework::Dataset::LoadIntoMemory)
+      .def("release_memory", &framework::Dataset::ReleaseMemory)
      .def("local_shuffle", &framework::Dataset::LocalShuffle)
       .def("global_shuffle", &framework::Dataset::GlobalShuffle);
 }
diff --git a/python/paddle/fluid/dataset.py b/python/paddle/fluid/dataset.py
index 34a3e5d8ec1..cf487fdfe20 100644
--- a/python/paddle/fluid/dataset.py
+++ b/python/paddle/fluid/dataset.py
@@ -237,7 +237,10 @@ class InMemoryDataset(DatasetBase):
         if fleet is not None:
             fleet.fleet_instance.role_maker_.barrier_worker()
             trainer_num = fleet.worker_num()
+        self.dataset.register_client2client_msg_handler()
         self.dataset.set_trainer_num(trainer_num)
+        if fleet is not None:
+            fleet.fleet_instance.role_maker_.barrier_worker()
         self.dataset.global_shuffle()
         if fleet is not None:
             fleet.fleet_instance.role_maker_.barrier_worker()
-- 
GitLab
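
Usage sketch (not part of the patch itself): the snippet below shows how the new
bindings are meant to be driven from the Python side. It is a minimal sketch
under assumptions: the slot variables, shapes, and file paths are made up;
DatasetFactory and the set_* wrappers are assumed to be the ones defined in
python/paddle/fluid/dataset.py on this branch; and because the patch adds
release_memory only as a C++ binding (there is no Python wrapper in dataset.py
yet), the example reaches it through the wrapped self.dataset object.

# Minimal usage sketch; all concrete names below are illustrative.
import paddle.fluid as fluid
from paddle.fluid.dataset import DatasetFactory  # assumed location of DatasetFactory

# hypothetical slot variables matching the slots in the data feed desc
slots = [
    fluid.layers.data(
        name=name, shape=[1], dtype="int64", lod_level=1)
    for name in ["click", "slot1"]
]

dataset = DatasetFactory().create_dataset("InMemoryDataset")
dataset.set_batch_size(32)
dataset.set_thread(4)
dataset.set_use_var(slots)
dataset.set_filelist(["train_data/part-000", "train_data/part-001"])

dataset.load_into_memory()  # fills memory_data_ on this node
# with this patch, global_shuffle() first registers the client-to-client msg
# handler; in a distributed job pass the fleet instance so workers barrier
# before shuffling: dataset.global_shuffle(fleet)
dataset.global_shuffle()
# ... run training that consumes the shuffled dataset here ...
dataset.dataset.release_memory()  # new C++ binding: frees memory_data_ after training

Two notes on the change itself: moving the RegisterClientToClientMsgHandler call
out of GlobalShuffle() means the handler is registered once, before the extra
pre-shuffle barrier added in dataset.py, instead of on every shuffle call, which
appears to be the "register bug" named in the subject line. ReleaseMemory() uses
the swap-with-a-temporary idiom, std::vector<T>().swap(memory_data_), so the
vector's capacity is actually returned to the allocator rather than only being
cleared.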