diff --git a/paddle/fluid/framework/CMakeLists.txt b/paddle/fluid/framework/CMakeLists.txt index d4a9ca5fbfeb59d97dc396e1857efcda2971c0ef..7a546b7b0ce274635f42cccf43187cdee0b68f81 100644 --- a/paddle/fluid/framework/CMakeLists.txt +++ b/paddle/fluid/framework/CMakeLists.txt @@ -190,6 +190,30 @@ cc_library(parallel_executor SRCS parallel_executor.cc DEPS graph build_strategy fast_threaded_ssa_graph_executor variable_helper) +<<<<<<< HEAD +======= +if(WITH_PSLIB) + cc_library(async_executor SRCS async_executor.cc data_feed.cc data_feed_factory.cc + executor_thread_worker.cc multi_trainer.cc dist_multi_trainer.cc + trainer_factory.cc trainer.cc device_worker.cc hogwild_worker.cc + downpour_worker.cc pull_dense_worker.cc device_worker_factory.cc + data_set.cc + DEPS op_registry device_context scope framework_proto + trainer_desc_proto glog lod_rank_table fleet_wrapper lodtensor_printer + feed_fetch_method graph_to_program_pass async_executor_proto + variable_helper pslib_brpc pslib timer) +else() + cc_library(async_executor SRCS async_executor.cc data_feed.cc data_feed_factory.cc + executor_thread_worker.cc multi_trainer.cc dist_multi_trainer.cc + trainer_factory.cc trainer.cc device_worker.cc hogwild_worker.cc + downpour_worker.cc pull_dense_worker.cc device_worker_factory.cc + data_set.cc + DEPS op_registry device_context scope framework_proto + trainer_desc_proto glog lod_rank_table fleet_wrapper lodtensor_printer + feed_fetch_method graph_to_program_pass async_executor_proto + variable_helper timer) +endif(WITH_PSLIB) +>>>>>>> 870b88bbd7... add DataSet and InMemoryDataFeed, support load data into memory and shuffle data cc_library(async_executor SRCS async_executor.cc data_feed.cc data_feed_factory.cc executor_thread_worker.cc multi_trainer.cc dist_multi_trainer.cc diff --git a/paddle/fluid/framework/data_feed.cc b/paddle/fluid/framework/data_feed.cc index e93683cb7f577789078dd051d263079b5f373f17..7f1993dbc3e993681b2ab7449c79a52dcd8ddb37 100644 --- a/paddle/fluid/framework/data_feed.cc +++ b/paddle/fluid/framework/data_feed.cc @@ -242,6 +242,109 @@ void InMemoryDataFeed::GlobalShuffle(int trainer_num) { template class InMemoryDataFeed>; +template +InMemoryDataFeed::InMemoryDataFeed() { + cur_channel_ = 0; + shuffled_ins_ = nullptr; + shuffled_ins_out_ = nullptr; +} + +template +bool InMemoryDataFeed::Start() { + DataFeed::CheckSetFileList(); + if (memory_data_.size() != 0) { + CHECK_EQ(cur_channel_, 0); + shuffled_ins_->Extend(std::move(memory_data_)); + std::vector().swap(memory_data_); + } + DataFeed::finish_start_ = true; + return true; +} + +template +int InMemoryDataFeed::Next() { + DataFeed::CheckStart(); + std::shared_ptr> in_channel = nullptr; + std::shared_ptr> out_channel = nullptr; + if (cur_channel_ == 0) { + in_channel = shuffled_ins_; + out_channel = shuffled_ins_out_; + } else { + in_channel = shuffled_ins_out_; + out_channel = shuffled_ins_; + } + CHECK(in_channel != nullptr); + CHECK(out_channel != nullptr); + int index = 0; + T instance; + T ins_vec; + while (index < DataFeed::default_batch_size_) { + if (in_channel->Size() == 0) { + break; + } + in_channel->Pop(instance); + AddInstanceToInsVec(&ins_vec, instance, index++); + out_channel->Push(std::move(instance)); + } + DataFeed::batch_size_ = index; + if (DataFeed::batch_size_ != 0) { + PutToFeedVec(ins_vec); + } else { + cur_channel_ = 1 - cur_channel_; + } + return DataFeed::batch_size_; +} + +template +void InMemoryDataFeed::PutInsToChannel(const std::string& ins_str) { + T ins; + DeserializeIns(ins, ins_str); + shuffled_ins_->Push(std::move(ins)); +} + +template +void InMemoryDataFeed::LoadIntoMemory() { + std::vector local_vec; + std::string filename; + while (DataFeed::PickOneFile(&filename)) { + int err_no = 0; + PrivateQueueDataFeed::fp_ = + fs_open_read(filename, &err_no, PrivateQueueDataFeed::pipe_command_); + __fsetlocking(&*PrivateQueueDataFeed::fp_, FSETLOCKING_BYCALLER); + T instance; + while (ParseOneInstanceFromPipe(&instance)) { + local_vec.push_back(instance); + } + memory_data_.insert(memory_data_.end(), local_vec.begin(), local_vec.end()); + std::vector().swap(local_vec); + } +} + +template +void InMemoryDataFeed::LocalShuffle() { + std::random_shuffle(memory_data_.begin(), memory_data_.end()); +} + +// todo global shuffle +/* +template +void InMemoryDataFeed::GlobalShuffle(int trainer_num) { + std::random_shuffle(memory_data_.begin(), memory_data_.end()); + for (int64_t i = 0; i < memory_data_.size(); ++i) { + // todo get ins id + //std::string ins_id = memory_data_[i].ins_id; + // todo hash + int64_t hash_id = paddle::ps::local_random_engine()(); + //int64_t hash_id = hash(ins_id); + int64_t node_id = hash_id % trainer_num_; + std::string str; + SerializeIns(memory_data_[i], str); + auto fleet_ptr = FleetWrapper::GetInstance(); + auto ret = fleet_ptr->send_client2client_msg(0, node_id, str); + } +} +*/ + void MultiSlotDataFeed::Init( const paddle::framework::DataFeedDesc& data_feed_desc) { finish_init_ = false; diff --git a/paddle/fluid/pybind/data_set_py.cc b/paddle/fluid/pybind/data_set_py.cc index 45b90ee6c2030ffc3f5bb66cc806a6fa4976956c..8a0af0654249a5485f3d9aa1e8bb046e33570758 100644 --- a/paddle/fluid/pybind/data_set_py.cc +++ b/paddle/fluid/pybind/data_set_py.cc @@ -12,6 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include + +// To avoid conflicting definition in gcc-4.8.2 headers and pyconfig.h (2.7.3) #ifdef _POSIX_C_SOURCE #undef _POSIX_C_SOURCE #endif @@ -41,7 +43,7 @@ namespace paddle { namespace pybind { void BindDataset(py::module* m) { - py::class_(*m, "Dataset") + py::class_(*m, "Dataset") .def(py::init([]() { return std::unique_ptr(new framework::Dataset()); })) @@ -51,7 +53,7 @@ void BindDataset(py::module* m) { .def("set_data_feed_desc", &framework::Dataset::SetDataFeedDesc) .def("load_into_memory", &framework::Dataset::LoadIntoMemory) .def("local_shuffle", &framework::Dataset::LocalShuffle) - .def("global_shuffle", &framework::Dataset::GlobalShuffle); + .def("global_shuffle", &framework::Dataset::GLobalShuffle) } } // end namespace pybind