From e36bbcc87172743f0e6ec69bc50e697af7fe649d Mon Sep 17 00:00:00 2001 From: dongdaxiang Date: Thu, 7 Mar 2019 10:35:27 +0800 Subject: [PATCH] fix some typo and CMakefile.txt --- paddle/fluid/framework/data_set.cc | 36 +++++++++++++++--------------- paddle/fluid/framework/data_set.h | 28 +++++++++++------------ paddle/fluid/framework/executor.cc | 4 ++++ paddle/fluid/framework/executor.h | 8 ++----- paddle/fluid/pybind/CMakeLists.txt | 6 +---- paddle/fluid/pybind/data_set_py.cc | 29 +++++++++++------------- 6 files changed, 51 insertions(+), 60 deletions(-) diff --git a/paddle/fluid/framework/data_set.cc b/paddle/fluid/framework/data_set.cc index ae342148778..047b172df45 100644 --- a/paddle/fluid/framework/data_set.cc +++ b/paddle/fluid/framework/data_set.cc @@ -18,15 +18,14 @@ namespace paddle { namespace framework { -Dataset::Dataset() { - thread_num_ = 1; -} +Dataset::Dataset() { thread_num_ = 1; } void Dataset::SetFileList(const std::vector& filelist) { filelist_ = filelist; int file_cnt = filelist_.size(); if (thread_num_ > file_cnt) { - VLOG(1) << "DataSet thread num = " << thread_num_ << ", file num = " << file_cnt + VLOG(1) << "DataSet thread num = " << thread_num_ + << ", file num = " << file_cnt << ". Changing DataSet thread num = " << file_cnt; thread_num_ = file_cnt; } @@ -35,22 +34,23 @@ void Dataset::SetFileList(const std::vector& filelist) { void Dataset::SetThreadNum(int thread_num) { int file_cnt = filelist_.size(); if (file_cnt != 0 && thread_num > file_cnt) { - VLOG(1) << "DataSet thread num = " << thread_num << ", file num = " << file_cnt + VLOG(1) << "DataSet thread num = " << thread_num + << ", file num = " << file_cnt << ". Changing DataSet thread num = " << file_cnt; thread_num = file_cnt; } thread_num_ = thread_num; } -void Dataset::SetTrainerNum(int trainer_num) { - trainer_num_ = trainer_num; -} +void Dataset::SetTrainerNum(int trainer_num) { trainer_num_ = trainer_num; } -void Dataset::SetDataFeedDesc(const paddle::framework::DataFeedDesc& data_feed_desc) { +void Dataset::SetDataFeedDesc( + const paddle::framework::DataFeedDesc& data_feed_desc) { data_feed_desc_ = data_feed_desc; } -std::vector> Dataset::GetReaders() { +std::vector> +Dataset::GetReaders() { return readers_; } @@ -60,8 +60,8 @@ void Dataset::LoadIntoMemory() { } std::vector load_threads; for (int64_t i = 0; i < thread_num_; ++i) { - load_threads.push_back(std::thread(&paddle::framework::DataFeed::LoadIntoMemory, - readers_[i].get())); + load_threads.push_back(std::thread( + &paddle::framework::DataFeed::LoadIntoMemory, readers_[i].get())); } for (std::thread& t : load_threads) { t.join(); @@ -74,8 +74,8 @@ void Dataset::LocalShuffle() { } std::vector local_shuffle_threads; for (int64_t i = 0; i < thread_num_; ++i) { - local_shuffle_threads.push_back(std::thread(&paddle::framework::DataFeed::LocalShuffle, - readers_[i].get())); + local_shuffle_threads.push_back(std::thread( + &paddle::framework::DataFeed::LocalShuffle, readers_[i].get())); } for (std::thread& t : local_shuffle_threads) { t.join(); @@ -115,14 +115,14 @@ void Dataset::CreateReaders() { readers_[0]->SetFileList(filelist_); } -int Dataset::ReceiveFromClient(int msg_type, int client_id, const std::string& msg) { +int Dataset::ReceiveFromClient(int msg_type, int client_id, + const std::string& msg) { // can also use hash // int64_t index = paddle::ps::local_random_engine()() % thread_num_; - // todo int64_t index = 0; readers_[index]->PutInsToChannel(msg); return 0; } -} -} +} // end namespace framework +} // end namespace paddle diff --git a/paddle/fluid/framework/data_set.h b/paddle/fluid/framework/data_set.h index f6f53f1b204..91998e98ad7 100644 --- a/paddle/fluid/framework/data_set.h +++ b/paddle/fluid/framework/data_set.h @@ -34,29 +34,27 @@ class Dataset { virtual void SetFileList(const std::vector& filelist); virtual void SetThreadNum(int thread_num); virtual void SetTrainerNum(int trainer_num); - virtual void SetDataFeedDesc(const paddle::framework::DataFeedDesc& data_feed_desc); + virtual void SetDataFeedDesc( + const paddle::framework::DataFeedDesc& data_feed_desc); - virtual const std::vector& GetFileList() { - return filelist_; - } - virtual int GetThreadNum() { - return thread_num_; - } - virtual int GetTrainerNum() { - return trainer_num_; - } + virtual const std::vector& GetFileList() { return filelist_; } + virtual int GetThreadNum() { return thread_num_; } + virtual int GetTrainerNum() { return trainer_num_; } virtual const paddle::framework::DataFeedDesc& GetDataFeedDesc() { return data_feed_desc_; } - virtual std::vector> GetReaders(); + virtual std::vector> + GetReaders(); virtual void LoadIntoMemory(); virtual void LocalShuffle(); // todo global shuffle - virtual void GlobalShuffle(); + virtual void GlobalShuffle(); virtual void CreateReaders(); + protected: - virtual int ReceiveFromClient(int msg_type, int client_id, const std::string& msg); + virtual int ReceiveFromClient(int msg_type, int client_id, + const std::string& msg); std::vector> readers_; int thread_num_; std::string fs_name_; @@ -66,5 +64,5 @@ class Dataset { int trainer_num_; }; -} -} +} // end namespace framework +} // end namespace paddle diff --git a/paddle/fluid/framework/executor.cc b/paddle/fluid/framework/executor.cc index 0d4334f193d..97fd6ee15d4 100644 --- a/paddle/fluid/framework/executor.cc +++ b/paddle/fluid/framework/executor.cc @@ -115,6 +115,10 @@ void Executor::CreateVariables(const ProgramDesc& pdesc, Scope* scope, } } +void Executor::RunFromDataset(const ProgramDesc& pdesc, const Dataset& dataset, + const std::string& trainer_desc_str, + const bool debug) {} + void Executor::Run(const ProgramDesc& pdesc, Scope* scope, int block_id, bool create_local_scope, bool create_vars, const std::vector& skip_ref_cnt_vars, diff --git a/paddle/fluid/framework/executor.h b/paddle/fluid/framework/executor.h index 1b25b993844..8685ad8028a 100644 --- a/paddle/fluid/framework/executor.h +++ b/paddle/fluid/framework/executor.h @@ -19,13 +19,13 @@ limitations under the License. */ #include #include #include +#include "paddle/fluid/framework/data_set.h" #include "paddle/fluid/framework/garbage_collector.h" #include "paddle/fluid/framework/op_info.h" #include "paddle/fluid/framework/program_desc.h" #include "paddle/fluid/framework/scope.h" #include "paddle/fluid/framework/tensor.h" #include "paddle/fluid/platform/device_context.h" -#include "paddle/fluid/framework/data_set.h" namespace paddle { namespace framework { @@ -112,11 +112,7 @@ class Executor { void EnableMKLDNN(const ProgramDesc& program); - void RunFromTrainerDesc(const ProgramDesc& main_program, - const std::string& trainer_desc_str, - const bool debug); - - void RunFromDataset(const ProgramDesc& main_program, Dataset* dataset, + void RunFromDataset(const ProgramDesc& main_program, const Dataset& dataset, const std::string& trainer_desc_str, const bool debug); public: diff --git a/paddle/fluid/pybind/CMakeLists.txt b/paddle/fluid/pybind/CMakeLists.txt index 8207f2b72cf..8b82f3aad4c 100644 --- a/paddle/fluid/pybind/CMakeLists.txt +++ b/paddle/fluid/pybind/CMakeLists.txt @@ -5,11 +5,7 @@ set(PYBIND_DEPS pybind python proto_desc memory executor async_executor fleet_wr if(WITH_PYTHON) list(APPEND PYBIND_DEPS py_func_op) endif() -<<<<<<< HEAD -set(PYBIND_SRCS pybind.cc exception.cc protobuf.cc const_value.cc recordio.cc reader_py.cc async_executor_py.cc imperative.cc ir.cc inference_api.cc) -======= -set(PYBIND_SRCS pybind.cc exception.cc protobuf.cc const_value.cc recordio.cc async_executor_py.cc fleet_wrapper_py.cc imperative.cc ir.cc inference_api.cc) ->>>>>>> add pybind for fleet +set(PYBIND_SRCS pybind.cc exception.cc protobuf.cc const_value.cc recordio.cc async_executor_py.cc fleet_wrapper_py.cc data_set_py.cc imperative.cc ir.cc inference_api.cc) if(WITH_PYTHON) if(WITH_AMD_GPU) diff --git a/paddle/fluid/pybind/data_set_py.cc b/paddle/fluid/pybind/data_set_py.cc index 029cabbc701..45b90ee6c20 100644 --- a/paddle/fluid/pybind/data_set_py.cc +++ b/paddle/fluid/pybind/data_set_py.cc @@ -12,8 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include - -// To avoid conflicting definition in gcc-4.8.2 headers and pyconfig.h (2.7.3) #ifdef _POSIX_C_SOURCE #undef _POSIX_C_SOURCE #endif @@ -29,12 +27,12 @@ limitations under the License. */ #include "paddle/fluid/framework/async_executor.h" #include "paddle/fluid/framework/data_feed.h" #include "paddle/fluid/framework/data_feed.pb.h" +#include "paddle/fluid/framework/data_set.h" #include "paddle/fluid/framework/scope.h" #include "paddle/fluid/inference/io.h" #include "paddle/fluid/platform/place.h" #include "paddle/fluid/platform/variant.h" -#include "paddle/fluid/pybind/async_executor_py.h" -#include "paddle/fluid/framework/data_set.h" +#include "paddle/fluid/pybind/data_set_py.h" namespace py = pybind11; namespace pd = paddle::framework; @@ -43,18 +41,17 @@ namespace paddle { namespace pybind { void BindDataset(py::module* m) { - py::class_(*m, "Dataset") - .def(py::init([]() { - return std::unique_ptr( - new framework::Dataset()); - })) - .def("set_filelist", &framework::Dataset::SetFileList) - .def("set_thread_num", &framework::Dataset::SetThreadNum) - .def("set_trainer_num", &framework::Dataset::SetTrainerNum) - .def("set_data_feed_desc", &framework::Dataset::SetDataFeedDesc) - .def("load_into_memory", &framework::Dataset::LoadIntoMemory) - .def("local_shuffle", &framework::Dataset::LocalShuffle) - .def("global_shuffle", &framework::Dataset::GLobalShuffle) + py::class_(*m, "Dataset") + .def(py::init([]() { + return std::unique_ptr(new framework::Dataset()); + })) + .def("set_filelist", &framework::Dataset::SetFileList) + .def("set_thread_num", &framework::Dataset::SetThreadNum) + .def("set_trainer_num", &framework::Dataset::SetTrainerNum) + .def("set_data_feed_desc", &framework::Dataset::SetDataFeedDesc) + .def("load_into_memory", &framework::Dataset::LoadIntoMemory) + .def("local_shuffle", &framework::Dataset::LocalShuffle) + .def("global_shuffle", &framework::Dataset::GlobalShuffle); } } // end namespace pybind -- GitLab