diff --git a/paddle/fluid/framework/CMakeLists.txt b/paddle/fluid/framework/CMakeLists.txt index 5d4d0ad4b775b8d2fde71d380b9f24019d08c524..040e36b796f8ddd69dd42c53b5d8d10b5afbfde4 100644 --- a/paddle/fluid/framework/CMakeLists.txt +++ b/paddle/fluid/framework/CMakeLists.txt @@ -199,6 +199,7 @@ if(WITH_PSLIB) executor_thread_worker.cc multi_trainer.cc dist_multi_trainer.cc trainer_factory.cc trainer.cc device_worker.cc hogwild_worker.cc downpour_worker.cc pull_dense_worker.cc device_worker_factory.cc + data_set.cc DEPS op_registry device_context scope framework_proto trainer_desc_proto glog lod_rank_table fleet_wrapper lodtensor_printer feed_fetch_method graph_to_program_pass async_executor_proto @@ -208,6 +209,7 @@ else() executor_thread_worker.cc multi_trainer.cc dist_multi_trainer.cc trainer_factory.cc trainer.cc device_worker.cc hogwild_worker.cc downpour_worker.cc pull_dense_worker.cc device_worker_factory.cc + data_set.cc DEPS op_registry device_context scope framework_proto trainer_desc_proto glog lod_rank_table fleet_wrapper lodtensor_printer feed_fetch_method graph_to_program_pass async_executor_proto diff --git a/paddle/fluid/framework/async_executor.cc b/paddle/fluid/framework/async_executor.cc index 27c06f5aa14a9f54a2d81c8f4aeae25e054863ce..902f44291890b4a45a8e91a68b8fa3d0c8aecab0 100644 --- a/paddle/fluid/framework/async_executor.cc +++ b/paddle/fluid/framework/async_executor.cc @@ -154,5 +154,14 @@ void AsyncExecutor::RunFromFile(const ProgramDesc& main_program, return; } +// todo RunFromDataset +void AsyncExecutor::RunFromDataset(const ProgramDesc& main_program, + Dataset* data_set, + const std::string& trainer_desc_str, + const bool debug) { + +} + + } // einit_modelnd namespace framework } // end namespace paddle diff --git a/paddle/fluid/framework/async_executor.h b/paddle/fluid/framework/async_executor.h index 17f5a6fc0af88791b5606c03338df87f84a7e461..e54a17333d345cf34a62aea3e369d3c68023e559 100644 --- a/paddle/fluid/framework/async_executor.h +++ b/paddle/fluid/framework/async_executor.h @@ -30,6 +30,7 @@ limitations under the License. */ #include "paddle/fluid/framework/fleet/fleet_wrapper.h" #include "paddle/fluid/framework/program_desc.h" #include "paddle/fluid/framework/scope.h" +#include "paddle/fluid/framework/data_set.h" namespace paddle { namespace framework { diff --git a/paddle/fluid/framework/blocking_queue.h b/paddle/fluid/framework/blocking_queue.h index a19558c0ae59005bee575e8c469c7f95d8780ab1..e1b49986a50c672403d6ffbb49e4836cd7a11302 100644 --- a/paddle/fluid/framework/blocking_queue.h +++ b/paddle/fluid/framework/blocking_queue.h @@ -33,6 +33,14 @@ class BlockingQueue { cv_.notify_one(); } + void Push(T &&item) { + { + std::lock_guard g(mutex_); + q_.emplace_back(std::move(item)); + } + cv_.notify_one(); + } + template void Extend(const U &items) { { @@ -44,6 +52,17 @@ class BlockingQueue { cv_.notify_all(); } + template + void Extend(U &&items) { + { + std::lock_guard g(mutex_); + for (auto &item : items) { + q_.emplace_back(std::move(item)); + } + } + cv_.notify_all(); + } + std::deque PopAll(size_t ms, bool *timeout) { auto time = std::chrono::system_clock::now() + std::chrono::milliseconds(ms); @@ -64,6 +83,18 @@ class BlockingQueue { return rc; } + void Pop(T &t) { + std::unique_lock lock(mutex_); + cv_.wait(lock, [=] { return !q_.empty(); }); + t = std::move(q_.front()); + q_.pop_front(); + } + + size_t Size() { + std::lock_guard lock(mutex_); + return q_.size(); + } + private: std::mutex mutex_; std::condition_variable cv_; diff --git a/paddle/fluid/framework/data_feed.cc b/paddle/fluid/framework/data_feed.cc index 4cfd2b434b33510d24073d984e790c79c538160c..4a7793ec8113a27c9639d4283f7ef393b943265b 100644 --- a/paddle/fluid/framework/data_feed.cc +++ b/paddle/fluid/framework/data_feed.cc @@ -139,6 +139,109 @@ int PrivateQueueDataFeed::Next() { template class PrivateQueueDataFeed>; #endif +template +InMemoryDataFeed::InMemoryDataFeed() { + cur_channel_ = 0; + shuffled_ins_ = nullptr; + shuffled_ins_out_ = nullptr; +} + +template +bool InMemoryDataFeed::Start() { + DataFeed::CheckSetFileList(); + if (memory_data_.size() != 0) { + CHECK(cur_channel_ == 0); + shuffled_ins_->Extend(std::move(memory_data_)); + std::vector().swap(memory_data_); + } + DataFeed::finish_start_ = true; + return true; +} + +template +int InMemoryDataFeed::Next() { + DataFeed::CheckStart(); + std::shared_ptr> in_channel = nullptr; + std::shared_ptr> out_channel = nullptr; + if (cur_channel_ == 0) { + in_channel = shuffled_ins_; + out_channel = shuffled_ins_out_; + } else { + in_channel = shuffled_ins_out_; + out_channel = shuffled_ins_; + } + CHECK(in_channel != nullptr); + CHECK(out_channel != nullptr); + int index = 0; + T instance; + T ins_vec; + while (index < DataFeed::default_batch_size_) { + if (in_channel->Size() == 0) { + break; + } + in_channel->Pop(instance); + AddInstanceToInsVec(&ins_vec, instance, index++); + out_channel->Push(std::move(instance)); + } + DataFeed::batch_size_ = index; + if (DataFeed::batch_size_ != 0) { + PutToFeedVec(ins_vec); + } else { + cur_channel_ = 1 - cur_channel_; + } + return DataFeed::batch_size_; +} + +template +void InMemoryDataFeed::PutInsToChannel(const std::string& ins_str) { + T ins; + DeserializeIns(ins, ins_str); + shuffled_ins_->Push(std::move(ins)); +} + +template +void InMemoryDataFeed::LoadIntoMemory() { + std::vector local_vec; + std::string filename; + while (DataFeed::PickOneFile(&filename)) { + int err_no = 0; + PrivateQueueDataFeed::fp_ = fs_open_read(filename, &err_no, + PrivateQueueDataFeed::pipe_command_); + __fsetlocking(&*PrivateQueueDataFeed::fp_, FSETLOCKING_BYCALLER); + T instance; + while(ParseOneInstanceFromPipe(&instance)) { + local_vec.push_back(instance); + } + memory_data_.insert(memory_data_.end(), local_vec.begin(), local_vec.end()); + std::vector().swap(local_vec); + } +} + +template +void InMemoryDataFeed::LocalShuffle() { + std::random_shuffle(memory_data_.begin(), memory_data_.end()); +} + +// todo global shuffle +/* +template +void InMemoryDataFeed::GlobalShuffle(int trainer_num) { + std::random_shuffle(memory_data_.begin(), memory_data_.end()); + for (int64_t i = 0; i < memory_data_.size(); ++i) { + // todo get ins id + //std::string ins_id = memory_data_[i].ins_id; + // todo hash + int64_t hash_id = paddle::ps::local_random_engine()(); + //int64_t hash_id = hash(ins_id); + int64_t node_id = hash_id % trainer_num_; + std::string str; + SerializeIns(memory_data_[i], str); + auto fleet_ptr = FleetWrapper::GetInstance(); + auto ret = fleet_ptr->send_client2client_msg(0, node_id, str); + } +} +*/ + void MultiSlotDataFeed::Init( const paddle::framework::DataFeedDesc& data_feed_desc) { finish_init_ = false; @@ -445,5 +548,190 @@ void MultiSlotDataFeed::PutToFeedVec( } } +void MultiSlotInMemoryDataFeed::Init( + const paddle::framework::DataFeedDesc& data_feed_desc) { + finish_init_ = false; + finish_set_filelist_ = false; + finish_start_ = false; + + PADDLE_ENFORCE(data_feed_desc.has_multi_slot_desc(), + "Multi_slot_desc has not been set."); + paddle::framework::MultiSlotDesc multi_slot_desc = + data_feed_desc.multi_slot_desc(); + SetBatchSize(data_feed_desc.batch_size()); + SetQueueSize(data_feed_desc.batch_size()); + size_t all_slot_num = multi_slot_desc.slots_size(); + all_slots_.resize(all_slot_num); + all_slots_type_.resize(all_slot_num); + use_slots_index_.resize(all_slot_num); + use_slots_.clear(); + use_slots_is_dense_.clear(); + for (size_t i = 0; i < all_slot_num; ++i) { + const auto& slot = multi_slot_desc.slots(i); + all_slots_[i] = slot.name(); + all_slots_type_[i] = slot.type(); + use_slots_index_[i] = slot.is_used() ? use_slots_.size() : -1; + if (slot.is_used()) { + use_slots_.push_back(all_slots_[i]); + use_slots_is_dense_.push_back(slot.is_dense()); + } + } + feed_vec_.resize(use_slots_.size()); + pipe_command_ = data_feed_desc.pipe_command(); + finish_init_ = true; +} + +bool MultiSlotInMemoryDataFeed::ParseOneInstanceFromPipe( + std::vector* instance) { + thread_local string::LineFileReader reader; + + if (!reader.getline(&*(fp_.get()))) { + return false; + } else { + int use_slots_num = use_slots_.size(); + instance->resize(use_slots_num); + + const char* str = reader.get(); + std::string line = std::string(str); + VLOG(3) << line; + char* endptr = const_cast(str); + int pos = 0; + for (size_t i = 0; i < use_slots_index_.size(); ++i) { + int idx = use_slots_index_[i]; + int num = strtol(&str[pos], &endptr, 10); + PADDLE_ENFORCE( + num, + "The number of ids can not be zero, you need padding " + "it in data generator; or if there is something wrong with " + "the data, please check if the data contains unresolvable " + "characters.\nplease check this error line: %s", + str); + if (idx != -1) { + (*instance)[idx].Init(all_slots_type_[i]); + if ((*instance)[idx].GetType()[0] == 'f') { // float + for (int j = 0; j < num; ++j) { + float feasign = strtof(endptr, &endptr); + (*instance)[idx].AddValue(feasign); + } + } else if ((*instance)[idx].GetType()[0] == 'u') { // uint64 + for (int j = 0; j < num; ++j) { + uint64_t feasign = (uint64_t)strtoull(endptr, &endptr, 10); + (*instance)[idx].AddValue(feasign); + } + } + pos = endptr - str; + } else { + for (int j = 0; j <= num; ++j) { + // pos = line.find_first_of(' ', pos + 1); + while (line[pos + 1] != ' ') { + pos++; + } + } + } + } + return true; + } +} + +bool MultiSlotInMemoryDataFeed::ParseOneInstance(std::vector* instance) { + std::string line; + if (getline(file_, line)) { + int use_slots_num = use_slots_.size(); + instance->resize(use_slots_num); + // parse line + const char* str = line.c_str(); + char* endptr = const_cast(str); + int pos = 0; + for (size_t i = 0; i < use_slots_index_.size(); ++i) { + int idx = use_slots_index_[i]; + int num = strtol(&str[pos], &endptr, 10); + PADDLE_ENFORCE( + num, + "The number of ids can not be zero, you need padding " + "it in data generator; or if there is something wrong with " + "the data, please check if the data contains unresolvable " + "characters.\nplease check this error line: %s", + str); + + if (idx != -1) { + (*instance)[idx].Init(all_slots_type_[i]); + if ((*instance)[idx].GetType()[0] == 'f') { // float + for (int j = 0; j < num; ++j) { + float feasign = strtof(endptr, &endptr); + (*instance)[idx].AddValue(feasign); + } + } else if ((*instance)[idx].GetType()[0] == 'u') { // uint64 + for (int j = 0; j < num; ++j) { + uint64_t feasign = (uint64_t)strtoull(endptr, &endptr, 10); + (*instance)[idx].AddValue(feasign); + } + } + pos = endptr - str; + } else { + for (int j = 0; j <= num; ++j) { + pos = line.find_first_of(' ', pos + 1); + } + } + } + } else { + return false; + } + return true; +} + +void MultiSlotInMemoryDataFeed::AddInstanceToInsVec( + std::vector* ins_vec, + const std::vector& instance, int index) { + if (index == 0) { + ins_vec->resize(instance.size()); + for (size_t i = 0; i < instance.size(); ++i) { + (*ins_vec)[i].Init(instance[i].GetType()); + (*ins_vec)[i].InitOffset(); + } + } + + for (size_t i = 0; i < instance.size(); ++i) { + (*ins_vec)[i].AddIns(instance[i]); + } +} + +void MultiSlotInMemoryDataFeed::PutToFeedVec( + const std::vector& ins_vec) { + for (size_t i = 0; i < use_slots_.size(); ++i) { + const auto& type = ins_vec[i].GetType(); + const auto& offset = ins_vec[i].GetOffset(); + int total_instance = static_cast(offset.back()); + + if (type[0] == 'f') { // float + const auto& feasign = ins_vec[i].GetFloatData(); + float* tensor_ptr = feed_vec_[i]->mutable_data( + {total_instance, 1}, platform::CPUPlace()); + memcpy(tensor_ptr, &feasign[0], total_instance * sizeof(float)); + } else if (type[0] == 'u') { // uint64 + // no uint64_t type in paddlepaddle + const auto& feasign = ins_vec[i].GetUint64Data(); + int64_t* tensor_ptr = feed_vec_[i]->mutable_data( + {total_instance, 1}, platform::CPUPlace()); + memcpy(tensor_ptr, &feasign[0], total_instance * sizeof(int64_t)); + } + + LoD data_lod{offset}; + feed_vec_[i]->set_lod(data_lod); + if (use_slots_is_dense_[i]) { + int dim = total_instance / batch_size_; + feed_vec_[i]->Resize({batch_size_, dim}); + } + } +} + +// todo serialize ins in global shuffle +void MultiSlotInMemoryDataFeed::SerializeIns(const std::vector& ins, std::string& str) { + +} +// todo deserialize ins in global shuffle +void MultiSlotInMemoryDataFeed::DeserializeIns(std::vector& ins, const std::string& str) { + +} + } // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/data_feed.h b/paddle/fluid/framework/data_feed.h index 91793ab3997ae68bd2dad7ee924548a7403680fa..0e1ac79664fc71012728e6305e09a2762a39fbb7 100644 --- a/paddle/fluid/framework/data_feed.h +++ b/paddle/fluid/framework/data_feed.h @@ -27,6 +27,8 @@ limitations under the License. */ #include "paddle/fluid/framework/variable.h" #include "paddle/fluid/operators/reader/blocking_queue.h" #include "paddle/fluid/string/string_helper.h" +#include "paddle/fluid/framework/blocking_queue.h" +#include "paddle/fluid/framework/fleet/fleet_wrapper.h" namespace paddle { namespace framework { @@ -76,6 +78,19 @@ class DataFeed { // This function is used for binding feed_vec memory virtual void AddFeedVar(Variable* var, const std::string& name); + virtual void LoadIntoMemory() { + PADDLE_THROW("This function(LoadIntoMemory) is not implemented."); + } + virtual void LocalShuffle() { + PADDLE_THROW("This function(LocalShuffle) is not implemented."); + } + virtual void GlobalShuffle(int trainer_num) { + PADDLE_THROW("This function(GlobalShuffle) is not implemented."); + } + virtual void PutInsToChannel(const std::string& ins_str) { + PADDLE_THROW("This function(PutToChannel) is not implemented."); + } + protected: // The following three functions are used to check if it is executed in this // order: @@ -161,6 +176,35 @@ class PrivateQueueDataFeed : public DataFeed { std::unique_ptr> queue_; }; +template +class InMemoryDataFeed : public PrivateQueueDataFeed { + public: + InMemoryDataFeed(); + virtual ~InMemoryDataFeed() {} + virtual bool Start(); + virtual int Next(); + virtual void PutInsToChannel(const std::string& ins_str); + virtual void LoadIntoMemory(); + virtual void LocalShuffle(); + // todo global shuffle + //virtual void GlobalShuffle(int trainer_num); + protected: + virtual void AddInstanceToInsVec(T* vec_ins, const T& instance, int index) = 0; + virtual bool ParseOneInstance(T* instance) = 0; + virtual bool ParseOneInstanceFromPipe(T* instance) = 0; + virtual void PutToFeedVec(const T& ins_vec) = 0; + virtual void SerializeIns(const T& ins, std::string& str) = 0; + virtual void DeserializeIns(T& ins, const std::string& str) = 0; + + std::vector memory_data_; + // when read ins, we put ins from one channel to the other, + // and when finish reading, we set cur_channel = 1 - cur_channel, + // so if cur_channel=0, all data are in shuffled_ins_, else shuffled_ins_out_ + int cur_channel_; + std::shared_ptr> shuffled_ins_; + std::shared_ptr> shuffled_ins_out_; +}; + // This class define the data type of instance(ins_vec) in MultiSlotDataFeed class MultiSlotType { public: @@ -245,5 +289,23 @@ class MultiSlotDataFeed virtual bool ParseOneInstanceFromPipe(std::vector* instance); virtual void PutToFeedVec(const std::vector& ins_vec); }; + +class MultiSlotInMemoryDataFeed + : public InMemoryDataFeed> { + public: + MultiSlotInMemoryDataFeed() {} + virtual ~MultiSlotInMemoryDataFeed() {} + virtual void Init(const paddle::framework::DataFeedDesc& data_feed_desc); + protected: + virtual void AddInstanceToInsVec(std::vector* vec_ins, + const std::vector& instance, + int index); + virtual bool ParseOneInstance(std::vector* instance); + virtual bool ParseOneInstanceFromPipe(std::vector* instance); + virtual void PutToFeedVec(const std::vector& ins_vec); + virtual void SerializeIns(const std::vector& ins, std::string& str); + virtual void DeserializeIns(std::vector& ins, const std::string& str); +}; + } // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/data_feed_factory.cc b/paddle/fluid/framework/data_feed_factory.cc index 72148b9f7d343e19d60bb2be44d8270ad78d1412..2938655af57c302f1a90ea4c2f533230b1346c66 100644 --- a/paddle/fluid/framework/data_feed_factory.cc +++ b/paddle/fluid/framework/data_feed_factory.cc @@ -60,5 +60,6 @@ std::shared_ptr DataFeedFactory::CreateDataFeed( } REGISTER_DATAFEED_CLASS(MultiSlotDataFeed); +REGISTER_DATAFEED_CLASS(MultiSlotInMemoryDataFeed); } // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/data_set.cc b/paddle/fluid/framework/data_set.cc new file mode 100644 index 0000000000000000000000000000000000000000..ae34214877831910cb81b47caabfb38267156264 --- /dev/null +++ b/paddle/fluid/framework/data_set.cc @@ -0,0 +1,128 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. */ + +#include "paddle/fluid/framework/data_set.h" +#include "paddle/fluid/framework/data_feed_factory.h" + +namespace paddle { +namespace framework { + +Dataset::Dataset() { + thread_num_ = 1; +} + +void Dataset::SetFileList(const std::vector& filelist) { + filelist_ = filelist; + int file_cnt = filelist_.size(); + if (thread_num_ > file_cnt) { + VLOG(1) << "DataSet thread num = " << thread_num_ << ", file num = " << file_cnt + << ". Changing DataSet thread num = " << file_cnt; + thread_num_ = file_cnt; + } +} + +void Dataset::SetThreadNum(int thread_num) { + int file_cnt = filelist_.size(); + if (file_cnt != 0 && thread_num > file_cnt) { + VLOG(1) << "DataSet thread num = " << thread_num << ", file num = " << file_cnt + << ". Changing DataSet thread num = " << file_cnt; + thread_num = file_cnt; + } + thread_num_ = thread_num; +} + +void Dataset::SetTrainerNum(int trainer_num) { + trainer_num_ = trainer_num; +} + +void Dataset::SetDataFeedDesc(const paddle::framework::DataFeedDesc& data_feed_desc) { + data_feed_desc_ = data_feed_desc; +} + +std::vector> Dataset::GetReaders() { + return readers_; +} + +void Dataset::LoadIntoMemory() { + if (readers_.size() == 0) { + CreateReaders(); + } + std::vector load_threads; + for (int64_t i = 0; i < thread_num_; ++i) { + load_threads.push_back(std::thread(&paddle::framework::DataFeed::LoadIntoMemory, + readers_[i].get())); + } + for (std::thread& t : load_threads) { + t.join(); + } +} + +void Dataset::LocalShuffle() { + if (readers_.size() == 0) { + CreateReaders(); + } + std::vector local_shuffle_threads; + for (int64_t i = 0; i < thread_num_; ++i) { + local_shuffle_threads.push_back(std::thread(&paddle::framework::DataFeed::LocalShuffle, + readers_[i].get())); + } + for (std::thread& t : local_shuffle_threads) { + t.join(); + } +} + +// todo global shuffle +void Dataset::GlobalShuffle() { + /* + auto fleet_ptr = FleetWrapper::GetInstance(); + fleet_ptr->registe_client2client_msg_handler(0, + [this](int msg_type, int client_id, const std::string& msg) -> int { + return this->ReceiveFromClient(msg_type, client_id, msg); + }); + if (readers_.size() == 0) { + CreateReaders(); + } + std::vector global_shuffle_threads; + for (int64_t i = 0; i < thread_num_; ++i) { + global_shuffle_threads.push_back(std::thread(&paddle::framework::DataFeed::GlobalShuffle, + readers_[i].get(), trainer_num_)); + } + for (std::thread& t : global_shuffle_threads) { + t.join(); + }*/ +} + +void Dataset::CreateReaders() { + CHECK(thread_num_ > 0) << "thread_num should > 0"; + if (readers_.size() != 0) { + return; + } + for (int64_t i = 0; i < thread_num_; ++i) { + readers_.push_back(DataFeedFactory::CreateDataFeed(data_feed_desc_.name())); + readers_.back()->Init(data_feed_desc_); + } + readers_[0]->SetFileList(filelist_); +} + +int Dataset::ReceiveFromClient(int msg_type, int client_id, const std::string& msg) { + // can also use hash + // int64_t index = paddle::ps::local_random_engine()() % thread_num_; + // todo + int64_t index = 0; + readers_[index]->PutInsToChannel(msg); + return 0; +} + +} +} diff --git a/paddle/fluid/framework/data_set.h b/paddle/fluid/framework/data_set.h new file mode 100644 index 0000000000000000000000000000000000000000..f6f53f1b204fc9fba73f7fa6a7452c6f8267aee7 --- /dev/null +++ b/paddle/fluid/framework/data_set.h @@ -0,0 +1,70 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. */ + +#pragma once + +#include +#include +#include // NOLINT +#include +#include // NOLINT +#include + +#include "paddle/fluid/framework/data_feed.h" + +namespace paddle { +namespace framework { + +class Dataset { + public: + Dataset(); + virtual ~Dataset() {} + + virtual void SetFileList(const std::vector& filelist); + virtual void SetThreadNum(int thread_num); + virtual void SetTrainerNum(int trainer_num); + virtual void SetDataFeedDesc(const paddle::framework::DataFeedDesc& data_feed_desc); + + virtual const std::vector& GetFileList() { + return filelist_; + } + virtual int GetThreadNum() { + return thread_num_; + } + virtual int GetTrainerNum() { + return trainer_num_; + } + virtual const paddle::framework::DataFeedDesc& GetDataFeedDesc() { + return data_feed_desc_; + } + + virtual std::vector> GetReaders(); + virtual void LoadIntoMemory(); + virtual void LocalShuffle(); + // todo global shuffle + virtual void GlobalShuffle(); + virtual void CreateReaders(); + protected: + virtual int ReceiveFromClient(int msg_type, int client_id, const std::string& msg); + std::vector> readers_; + int thread_num_; + std::string fs_name_; + std::string fs_ugi_; + paddle::framework::DataFeedDesc data_feed_desc_; + std::vector filelist_; + int trainer_num_; +}; + +} +} diff --git a/paddle/fluid/framework/dist_multi_trainer.cc b/paddle/fluid/framework/dist_multi_trainer.cc index 45eb4ae0ea69cb8b5278b8283ec420098587c0d6..8b15a3d7a259d380139d796dd851aac2e9553ac9 100644 --- a/paddle/fluid/framework/dist_multi_trainer.cc +++ b/paddle/fluid/framework/dist_multi_trainer.cc @@ -21,7 +21,7 @@ limitations under the License. */ namespace paddle { namespace framework { -void DistMultiTrainer::Initialize(const TrainerDesc& trainer_desc) { +void DistMultiTrainer::Initialize(const TrainerDesc& trainer_desc, Dataset* data_set) { thread_num_ = trainer_desc.thread_num(); workers_.resize(thread_num_); readers_.resize(thread_num_); diff --git a/paddle/fluid/framework/executor.h b/paddle/fluid/framework/executor.h index 48aeb151d57aa27ac88419b7a83b4aafe1163c22..1b25b9938441ea4a3a0547e7682975c8e55443eb 100644 --- a/paddle/fluid/framework/executor.h +++ b/paddle/fluid/framework/executor.h @@ -25,6 +25,7 @@ limitations under the License. */ #include "paddle/fluid/framework/scope.h" #include "paddle/fluid/framework/tensor.h" #include "paddle/fluid/platform/device_context.h" +#include "paddle/fluid/framework/data_set.h" namespace paddle { namespace framework { @@ -115,7 +116,7 @@ class Executor { const std::string& trainer_desc_str, const bool debug); - void RunFromDataset(const ProgramDesc& main_program, const Dataset* dataset, + void RunFromDataset(const ProgramDesc& main_program, Dataset* dataset, const std::string& trainer_desc_str, const bool debug); public: diff --git a/paddle/fluid/framework/trainer.h b/paddle/fluid/framework/trainer.h index e1602f6c8ce461b2f15c36b96c491c7e9b4e2d0d..65425459204678b3288fd174c33e83560cc3a57a 100644 --- a/paddle/fluid/framework/trainer.h +++ b/paddle/fluid/framework/trainer.h @@ -29,6 +29,7 @@ limitations under the License. */ #include "paddle/fluid/framework/trainer_desc.pb.h" #include "paddle/fluid/framework/variable_helper.h" #include "paddle/fluid/operators/reader/blocking_queue.h" +#include "paddle/fluid/framework/data_set.h" namespace paddle { namespace framework { @@ -40,7 +41,7 @@ class TrainerBase { // model memory are hosted in root_scope void SetScope(Scope* root_scope); void SetDebug(const bool debug) { debug_ = debug; } - virtual void Initialize(const TrainerDesc& trainer_desc) = 0; + virtual void Initialize(const TrainerDesc& trainer_desc, Dataset* data_set) = 0; virtual void InitTrainerEnv(const ProgramDesc& main_program, const platform::Place& place) = 0; virtual void InitOtherEnv(const ProgramDesc& main_program) = 0; @@ -59,7 +60,7 @@ class MultiTrainer : public TrainerBase { public: MultiTrainer() {} virtual ~MultiTrainer() {} - virtual void Initialize(const TrainerDesc& trainer_desc); + virtual void Initialize(const TrainerDesc& trainer_desc, Dataset* data_set); virtual void InitTrainerEnv(const ProgramDesc& main_program, const platform::Place& place); virtual void InitOtherEnv(const ProgramDesc& main_program) {} @@ -77,7 +78,7 @@ class DistMultiTrainer : public MultiTrainer { public: DistMultiTrainer() {} virtual ~DistMultiTrainer() {} - virtual void Initialize(const TrainerDesc& trainer_desc); + virtual void Initialize(const TrainerDesc& trainer_desc, Dataset* data_set); virtual void InitOtherEnv(const ProgramDesc& main_program); virtual void Finalize(); diff --git a/paddle/fluid/pybind/async_executor_py.cc b/paddle/fluid/pybind/async_executor_py.cc index 222c128c66f37a259eb17527fe2586860f701275..6dc865e8ed0db2b9c50b7733e61aaad953e8d400 100644 --- a/paddle/fluid/pybind/async_executor_py.cc +++ b/paddle/fluid/pybind/async_executor_py.cc @@ -49,6 +49,7 @@ void BindAsyncExecutor(py::module* m) { new framework::AsyncExecutor(scope, place)); })) .def("run_from_files", &framework::AsyncExecutor::RunFromFile) + .def("run_from_dataset", &framework::AsyncExecutor::RunFromDataset) .def("init_server", &framework::AsyncExecutor::InitServer) .def("init_worker", &framework::AsyncExecutor::InitWorker) .def("start_server", &framework::AsyncExecutor::StartServer) diff --git a/paddle/fluid/pybind/data_set_py.cc b/paddle/fluid/pybind/data_set_py.cc new file mode 100644 index 0000000000000000000000000000000000000000..029cabbc701fb98ab4ed9af3d2ac39ca736e1c39 --- /dev/null +++ b/paddle/fluid/pybind/data_set_py.cc @@ -0,0 +1,61 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ +#include + +// To avoid conflicting definition in gcc-4.8.2 headers and pyconfig.h (2.7.3) +#ifdef _POSIX_C_SOURCE +#undef _POSIX_C_SOURCE +#endif + +#ifdef _XOPEN_SOURCE +#undef _XOPEN_SOURCE +#endif +#include +#include + +#include "google/protobuf/io/zero_copy_stream_impl.h" +#include "google/protobuf/text_format.h" +#include "paddle/fluid/framework/async_executor.h" +#include "paddle/fluid/framework/data_feed.h" +#include "paddle/fluid/framework/data_feed.pb.h" +#include "paddle/fluid/framework/scope.h" +#include "paddle/fluid/inference/io.h" +#include "paddle/fluid/platform/place.h" +#include "paddle/fluid/platform/variant.h" +#include "paddle/fluid/pybind/async_executor_py.h" +#include "paddle/fluid/framework/data_set.h" + +namespace py = pybind11; +namespace pd = paddle::framework; + +namespace paddle { +namespace pybind { + +void BindDataset(py::module* m) { + py::class_(*m, "Dataset") + .def(py::init([]() { + return std::unique_ptr( + new framework::Dataset()); + })) + .def("set_filelist", &framework::Dataset::SetFileList) + .def("set_thread_num", &framework::Dataset::SetThreadNum) + .def("set_trainer_num", &framework::Dataset::SetTrainerNum) + .def("set_data_feed_desc", &framework::Dataset::SetDataFeedDesc) + .def("load_into_memory", &framework::Dataset::LoadIntoMemory) + .def("local_shuffle", &framework::Dataset::LocalShuffle) + .def("global_shuffle", &framework::Dataset::GLobalShuffle) +} + +} // end namespace pybind +} // end namespace paddle diff --git a/paddle/fluid/pybind/data_set_py.h b/paddle/fluid/pybind/data_set_py.h new file mode 100644 index 0000000000000000000000000000000000000000..f60e862ce673119c7b8e8ae5981fc54e8c9bdb2e --- /dev/null +++ b/paddle/fluid/pybind/data_set_py.h @@ -0,0 +1,28 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "pybind11/pybind11.h" +#include "pybind11/stl.h" + +namespace py = pybind11; + +namespace paddle { +namespace pybind { + +void BindDataset(py::module* m); + +} // namespace pybind +} // namespace paddle diff --git a/paddle/fluid/pybind/pybind.cc b/paddle/fluid/pybind/pybind.cc index e1ef00681c741e84cad1624478285ddd0477d42d..46a8ad4d88797880b6f33cd3c6035507f5d7fc63 100644 --- a/paddle/fluid/pybind/pybind.cc +++ b/paddle/fluid/pybind/pybind.cc @@ -61,6 +61,7 @@ limitations under the License. */ #include "paddle/fluid/pybind/recordio.h" #include "paddle/fluid/pybind/tensor_py.h" #include "paddle/fluid/string/to_string.h" +#include "paddle/fluid/pybind/data_set_py.h" #ifdef PADDLE_WITH_CUDA #ifndef _WIN32 @@ -1359,6 +1360,7 @@ All parameter, weight, gradient are variables in Paddle. BindGraph(&m); BindNode(&m); BindInferenceApi(&m); + BindDataset(&m); } } // namespace pybind } // namespace paddle