diff --git a/paddle/fluid/framework/blocking_queue.h b/paddle/fluid/framework/blocking_queue.h
index e1b49986a50c672403d6ffbb49e4836cd7a11302..cc5b4e8c4b8e114668f472ea2af9de96835720d0 100644
--- a/paddle/fluid/framework/blocking_queue.h
+++ b/paddle/fluid/framework/blocking_queue.h
@@ -83,10 +83,10 @@ class BlockingQueue {
     return rc;
   }
 
-  void Pop(T &t) {
+  void Pop(T *t) {
     std::unique_lock<std::mutex> lock(mutex_);
     cv_.wait(lock, [=] { return !q_.empty(); });
-    t = std::move(q_.front());
+    *t = std::move(q_.front());
     q_.pop_front();
   }
 
diff --git a/paddle/fluid/framework/data_feed.cc b/paddle/fluid/framework/data_feed.cc
index 62e391a3d278a85a1ab30f611f35f38b6a11c7d0..4f8fa005d7b75440b6964bde7cd7b4d9af66fae7 100644
--- a/paddle/fluid/framework/data_feed.cc
+++ b/paddle/fluid/framework/data_feed.cc
@@ -48,7 +48,7 @@ bool DataFeed::SetFileList(const std::vector<std::string>& files) {
     return false;
   }
   */
-  //PADDLE_ENFORCE(files.size(), "You have set an empty filelist.");
+  // PADDLE_ENFORCE(files.size(), "You have set an empty filelist.");
   filelist_.assign(files.begin(), files.end());
 
   finish_set_filelist_ = true;
@@ -190,7 +190,8 @@ int InMemoryDataFeed<T>::Next() {
     if (in_channel->Size() == 0) {
       break;
     }
-    in_channel->Pop(instance);
+    in_channel->Pop(&instance);
+
     AddInstanceToInsVec(&ins_vec, instance, index++);
     out_channel->Push(std::move(instance));
   }
@@ -268,17 +269,19 @@ void InMemoryDataFeed<T>::FillChannelToMemoryData() {
   }
   CHECK(channel != nullptr);
   CHECK(pre_channel != nullptr);
-  CHECK(pre_channel->Size() == 0);
+  CHECK_EQ(pre_channel->Size(), 0);
   local_vec.resize(channel->Size());
   for (int64_t i = 0; i < local_vec.size(); ++i) {
-    channel->Pop(local_vec[i]);
+    channel->Pop(&local_vec[i]);
   }
-  VLOG(3) << "local_vec size=" << local_vec.size() <<", thread_id=" << thread_id_;
+  VLOG(3) << "local_vec size=" << local_vec.size()
+          <<", thread_id=" << thread_id_;
   {
     std::lock_guard<std::mutex> g(*mutex_for_update_memory_data_);
     VLOG(3) << "before insert, memory_data_ size=" << memory_data_->size()
             << ", thread_id=" << thread_id_;
-    memory_data_->insert(memory_data_->end(), local_vec.begin(), local_vec.end());
+    memory_data_->insert(memory_data_->end(), local_vec.begin(),
+                         local_vec.end());
     VLOG(3) << "after insert memory_data_ size=" << memory_data_->size()
             << ", thread_id=" << thread_id_;
   }
@@ -574,7 +577,7 @@ bool MultiSlotDataFeed::ParseOneInstanceFromPipe(
 
     const char* str = reader.get();
     std::string line = std::string(str);
-    //VLOG(3) << line;
+    // VLOG(3) << line;
     char* endptr = const_cast<char*>(str);
     int pos = 0;
     for (size_t i = 0; i < use_slots_index_.size(); ++i) {
@@ -750,7 +753,7 @@ bool MultiSlotInMemoryDataFeed::ParseOneInstanceFromPipe(
 
     const char* str = reader.get();
     std::string line = std::string(str);
-    //VLOG(3) << line;
+    // VLOG(3) << line;
     char* endptr = const_cast<char*>(str);
     int pos = 0;
     for (size_t i = 0; i < use_slots_index_.size(); ++i) {
diff --git a/paddle/fluid/framework/data_feed.h b/paddle/fluid/framework/data_feed.h
index cab0b431b5dd247502b2982627bda7e45419eb16..1c6c44242dbed82b99f8560673699aaaddc08b81 100644
--- a/paddle/fluid/framework/data_feed.h
+++ b/paddle/fluid/framework/data_feed.h
@@ -21,7 +21,8 @@ limitations under the License. */
 #include   // NOLINT
 #include 
 #include 
-#include 
+#include   // NOLINT
+#include 
 
 #include "paddle/fluid/framework/data_feed.pb.h"
 #include "paddle/fluid/framework/lod_tensor.h"
diff --git a/paddle/fluid/framework/data_set.cc b/paddle/fluid/framework/data_set.cc
index c2e8bff348ade126b48cc8f54ba6c355360e665c..62001c24df8af0058a8ba4db38d97ed6d0464e85 100644
--- a/paddle/fluid/framework/data_set.cc
+++ b/paddle/fluid/framework/data_set.cc
@@ -82,6 +82,18 @@ DatasetImpl<T>::GetReaders() {
   return readers_;
 }
 
+// if messages are sent between workers, this function should be called first
+template <typename T>
+void DatasetImpl<T>::RegisterClientToClientMsgHandler() {
+  auto fleet_ptr = FleetWrapper::GetInstance();
+  VLOG(3) << "RegisterClientToClientMsgHandler";
+  fleet_ptr->RegisterClientToClientMsgHandler(
+      0, [this](int msg_type, int client_id, const std::string& msg) -> int {
+        return this->ReceiveFromClient(msg_type, client_id, msg);
+      });
+  VLOG(3) << "RegisterClientToClientMsgHandler done";
+}
+
 // load data into memory, Dataset hold this memory,
 // which will later be fed into readers' channel
 template <typename T>
@@ -106,6 +118,14 @@ void DatasetImpl<T>::LoadIntoMemory() {
           << ", cost time=" << timeline.ElapsedSec() << " seconds";
 }
 
+// release memory data
+template <typename T>
+void DatasetImpl<T>::ReleaseMemory() {
+  VLOG(3) << "DatasetImpl::ReleaseMemory() begin";
+  std::vector<T>().swap(memory_data_);
+  VLOG(3) << "DatasetImpl::ReleaseMemory() end";
+}
+
 // do local shuffle
 template <typename T>
 void DatasetImpl<T>::LocalShuffle() {
@@ -137,12 +157,6 @@ void DatasetImpl<T>::GlobalShuffle() {
   VLOG(3) << "DatasetImpl::GlobalShuffle() begin";
   platform::Timer timeline;
   timeline.Start();
-  auto fleet_ptr = FleetWrapper::GetInstance();
-  VLOG(3) << "RegisterClientToClientMsgHandler";
-  fleet_ptr->RegisterClientToClientMsgHandler(
-      0, [this](int msg_type, int client_id, const std::string& msg) -> int {
-        return this->ReceiveFromClient(msg_type, client_id, msg);
-      });
   if (readers_.size() == 0) {
     CreateReaders();
   }
diff --git a/paddle/fluid/framework/data_set.h b/paddle/fluid/framework/data_set.h
index a13d0f869d4fa9cf7cb97a8ecf3e406495a5c567..4bbcc6d06a999eef619136febf129bbf7b66c0ad 100644
--- a/paddle/fluid/framework/data_set.h
+++ b/paddle/fluid/framework/data_set.h
@@ -40,22 +40,43 @@ class Dataset {
  public:
   Dataset() {}
   virtual ~Dataset() {}
+  // set file list
   virtual void SetFileList(const std::vector<std::string>& filelist) = 0;
+  // set readers' num
   virtual void SetThreadNum(int thread_num) = 0;
+  // set workers' num
   virtual void SetTrainerNum(int trainer_num) = 0;
+  // set fs name and ugi
   virtual void SetHdfsConfig(const std::string& fs_name,
                              const std::string& fs_ugi) = 0;
+  // set data feed desc, which contains:
+  //   data feed name, batch size, slots
   virtual void SetDataFeedDesc(const std::string& data_feed_desc_str) = 0;
+  // get file list
   virtual const std::vector<std::string>& GetFileList() = 0;
+  // get thread num
   virtual int GetThreadNum() = 0;
+  // get worker num
   virtual int GetTrainerNum() = 0;
+  // get data feed desc
   virtual const paddle::framework::DataFeedDesc& GetDataFeedDesc() = 0;
+  // get readers, the reader num depends both on thread num
+  // and filelist size
   virtual std::vector<std::shared_ptr<paddle::framework::DataFeed>>&
       GetReaders() = 0;
+  // register message handler between workers
+  virtual void RegisterClientToClientMsgHandler() = 0;
+  // load all data into memory
   virtual void LoadIntoMemory() = 0;
+  // release all memory data
+  virtual void ReleaseMemory() = 0;
+  // local shuffle data
   virtual void LocalShuffle() = 0;
+  // global shuffle data
   virtual void GlobalShuffle() = 0;
+  // create readers
   virtual void CreateReaders() = 0;
+  // destroy readers
   virtual void DestroyReaders() = 0;
 
  protected:
@@ -84,10 +105,12 @@ class DatasetImpl : public Dataset {
   virtual const paddle::framework::DataFeedDesc& GetDataFeedDesc() {
     return data_feed_desc_;
   }
-  virtual std::vector<std::shared_ptr<paddle::framework::DataFeed>>& GetReaders();
+
+  virtual void RegisterClientToClientMsgHandler();
   virtual void LoadIntoMemory();
+  virtual void ReleaseMemory();
   virtual void LocalShuffle();
   virtual void GlobalShuffle();
   virtual void CreateReaders();
diff --git a/paddle/fluid/pybind/async_executor_py.cc b/paddle/fluid/pybind/async_executor_py.cc
index 3bb6bff23638620ef282988d27ef059b87a6ae38..b0951f0ccd16394a8baf3c901440f566e9664ab0 100644
--- a/paddle/fluid/pybind/async_executor_py.cc
+++ b/paddle/fluid/pybind/async_executor_py.cc
@@ -23,6 +23,7 @@ limitations under the License. */
 #endif
 #include 
 #include 
+#include 
 
 #include "google/protobuf/io/zero_copy_stream_impl.h"
 #include "google/protobuf/text_format.h"
@@ -49,7 +50,6 @@ void BindAsyncExecutor(py::module* m) {
                    new framework::AsyncExecutor(scope, place));
       }))
       .def("run_from_files", &framework::AsyncExecutor::RunFromFile)
-      //.def("run_from_dataset", &framework::AsyncExecutor::RunFromDataset)
       .def("init_server", &framework::AsyncExecutor::InitServer)
       .def("init_worker", &framework::AsyncExecutor::InitWorker)
       .def("start_server", &framework::AsyncExecutor::StartServer)
diff --git a/paddle/fluid/pybind/data_set_py.cc b/paddle/fluid/pybind/data_set_py.cc
index 2138ecab8526fc87912acc81d3d61f54bc12f93b..30d1d185cf1c4942c75ad48d5d40040bd06b1d1d 100644
--- a/paddle/fluid/pybind/data_set_py.cc
+++ b/paddle/fluid/pybind/data_set_py.cc
@@ -52,7 +52,10 @@ void BindDataset(py::module* m) {
       .def("set_trainer_num", &framework::Dataset::SetTrainerNum)
       .def("set_hdfs_config", &framework::Dataset::SetHdfsConfig)
       .def("set_data_feed_desc", &framework::Dataset::SetDataFeedDesc)
+      .def("register_client2client_msg_handler",
+           &framework::Dataset::RegisterClientToClientMsgHandler)
       .def("load_into_memory", &framework::Dataset::LoadIntoMemory)
+      .def("release_memory", &framework::Dataset::ReleaseMemory)
       .def("local_shuffle", &framework::Dataset::LocalShuffle)
       .def("global_shuffle", &framework::Dataset::GlobalShuffle);
 }
diff --git a/python/paddle/fluid/dataset.py b/python/paddle/fluid/dataset.py
index 34a3e5d8ec1a51958381d52205eb6a2f56ca80f7..cf487fdfe201f61c6c6abb80355d838b71cd71a6 100644
--- a/python/paddle/fluid/dataset.py
+++ b/python/paddle/fluid/dataset.py
@@ -237,7 +237,10 @@ class InMemoryDataset(DatasetBase):
         if fleet is not None:
             fleet.fleet_instance.role_maker_.barrier_worker()
             trainer_num = fleet.worker_num()
+        self.dataset.register_client2client_msg_handler()
         self.dataset.set_trainer_num(trainer_num)
+        if fleet is not None:
+            fleet.fleet_instance.role_maker_.barrier_worker()
         self.dataset.global_shuffle()
         if fleet is not None:
             fleet.fleet_instance.role_maker_.barrier_worker()
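For context, a rough sketch of how the pieces touched by this patch fit together when driven from Python. It is illustrative only: DatasetFactory, create_dataset, set_filelist, set_thread, set_use_var, load_into_memory and the fluid.layers.data slots are assumed from the existing python/paddle/fluid/dataset.py and layers APIs, the file names and slot names are placeholders, and only register_client2client_msg_handler, release_memory and the barrier ordering inside global_shuffle come from this diff.

    import paddle.fluid as fluid

    # Placeholder slot variables; a real job defines its full feed list.
    label = fluid.layers.data(name="click", shape=[1], dtype="int64", lod_level=1)
    slot = fluid.layers.data(name="slot0", shape=[1], dtype="int64", lod_level=1)

    dataset = fluid.DatasetFactory().create_dataset("InMemoryDataset")
    dataset.set_filelist(["train_data/part-000", "train_data/part-001"])
    dataset.set_thread(4)
    dataset.set_use_var([label, slot])

    # Fill the in-memory channels from the filelist.
    dataset.load_into_memory()

    # global_shuffle() now registers the client-to-client message handler before
    # shuffling; when a fleet instance is passed, it also barriers the workers
    # before and after the shuffle (see the dataset.py hunk above). Without a
    # fleet argument the barriers are skipped and the default trainer_num is used.
    dataset.global_shuffle()

    # ... run training on the shuffled dataset ...

    # New binding from this patch: free the instances cached by load_into_memory().
    # It is exposed on the core dataset object via data_set_py.cc.
    dataset.dataset.release_memory()

The release_memory() call matters for long-running jobs that keep the trainer process alive after training on one in-memory dataset, since the cached instances are otherwise held until the Dataset object is destroyed.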