提交 2e9a836c 编写于 作者: X xjqbest 提交者: dongdaxiang

add DataSet and InMemoryDataFeed, support load data into memory and shuffle data

上级 8de4d31a
...@@ -190,6 +190,30 @@ cc_library(parallel_executor SRCS parallel_executor.cc DEPS ...@@ -190,6 +190,30 @@ cc_library(parallel_executor SRCS parallel_executor.cc DEPS
graph build_strategy graph build_strategy
fast_threaded_ssa_graph_executor variable_helper) fast_threaded_ssa_graph_executor variable_helper)
# async_executor library. Both branches compile the same source list (including
# data_set.cc); the WITH_PSLIB build additionally links pslib_brpc and pslib.
# The merge-conflict markers that previously surrounded this block were removed
# because they are not valid CMake.
# NOTE(review): another cc_library(async_executor ...) call follows this block
# in the visible context -- confirm that only one definition of the target
# remains after the merge is resolved.
if(WITH_PSLIB)
  cc_library(async_executor SRCS async_executor.cc data_feed.cc data_feed_factory.cc
    executor_thread_worker.cc multi_trainer.cc dist_multi_trainer.cc
    trainer_factory.cc trainer.cc device_worker.cc hogwild_worker.cc
    downpour_worker.cc pull_dense_worker.cc device_worker_factory.cc
    data_set.cc
    DEPS op_registry device_context scope framework_proto
    trainer_desc_proto glog lod_rank_table fleet_wrapper lodtensor_printer
    feed_fetch_method graph_to_program_pass async_executor_proto
    variable_helper pslib_brpc pslib timer)
else()
  cc_library(async_executor SRCS async_executor.cc data_feed.cc data_feed_factory.cc
    executor_thread_worker.cc multi_trainer.cc dist_multi_trainer.cc
    trainer_factory.cc trainer.cc device_worker.cc hogwild_worker.cc
    downpour_worker.cc pull_dense_worker.cc device_worker_factory.cc
    data_set.cc
    DEPS op_registry device_context scope framework_proto
    trainer_desc_proto glog lod_rank_table fleet_wrapper lodtensor_printer
    feed_fetch_method graph_to_program_pass async_executor_proto
    variable_helper timer)
endif(WITH_PSLIB)
cc_library(async_executor SRCS async_executor.cc data_feed.cc data_feed_factory.cc cc_library(async_executor SRCS async_executor.cc data_feed.cc data_feed_factory.cc
executor_thread_worker.cc multi_trainer.cc dist_multi_trainer.cc executor_thread_worker.cc multi_trainer.cc dist_multi_trainer.cc
......
...@@ -242,6 +242,109 @@ void InMemoryDataFeed<T>::GlobalShuffle(int trainer_num) { ...@@ -242,6 +242,109 @@ void InMemoryDataFeed<T>::GlobalShuffle(int trainer_num) {
template class InMemoryDataFeed<std::vector<MultiSlotType>>; template class InMemoryDataFeed<std::vector<MultiSlotType>>;
// Default constructor: no instance channel is attached yet (both channel
// pointers are null) and channel 0 is selected as the current input channel.
template <typename T>
InMemoryDataFeed<T>::InMemoryDataFeed() {
  this->shuffled_ins_out_ = nullptr;
  this->shuffled_ins_ = nullptr;
  this->cur_channel_ = 0;
}
// Prepares the feed for reading.
//
// After DataFeed::CheckSetFileList() has run, any instances held in
// memory_data_ are handed over to the shuffled_ins_ channel and the vector's
// storage is released. The feed is then marked as started.
//
// Returns: always true.
template <typename T>
bool InMemoryDataFeed<T>::Start() {
  DataFeed::CheckSetFileList();
  if (!memory_data_.empty()) {
    // Data is only ever transferred while channel 0 is the input channel.
    CHECK_EQ(cur_channel_, 0);
    // The constructor leaves shuffled_ins_ null; fail with a clear check (as
    // Next() does for its channels) instead of dereferencing a null pointer.
    CHECK(shuffled_ins_ != nullptr);
    shuffled_ins_->Extend(std::move(memory_data_));
    // Swap with an empty temporary so the moved-from vector gives up its
    // capacity as well.
    std::vector<T>().swap(memory_data_);
  }
  DataFeed::finish_start_ = true;
  return true;
}
// Assembles the next batch from the current input channel.
//
// Up to DataFeed::default_batch_size_ instances are popped from the input
// channel, appended to the batch via AddInstanceToInsVec, and pushed to the
// other channel. When the input channel yields nothing, the two channels swap
// roles (cur_channel_ flips between 0 and 1) and 0 is returned.
//
// Returns: the number of instances placed in the batch.
template <typename T>
int InMemoryDataFeed<T>::Next() {
  DataFeed::CheckStart();

  using Channel = std::shared_ptr<paddle::framework::BlockingQueue<T>>;
  const bool primary_is_input = (cur_channel_ == 0);
  Channel in_channel = primary_is_input ? shuffled_ins_ : shuffled_ins_out_;
  Channel out_channel = primary_is_input ? shuffled_ins_out_ : shuffled_ins_;
  CHECK(in_channel != nullptr);
  CHECK(out_channel != nullptr);

  int index = 0;
  T instance;
  T ins_vec;
  for (; index < DataFeed::default_batch_size_; ++index) {
    if (in_channel->Size() == 0) {
      break;
    }
    in_channel->Pop(instance);
    AddInstanceToInsVec(&ins_vec, instance, index);
    // The instance is forwarded to the other channel after being batched.
    out_channel->Push(std::move(instance));
  }

  DataFeed::batch_size_ = index;
  if (DataFeed::batch_size_ == 0) {
    // Input channel was empty: switch to the other channel for later calls.
    cur_channel_ = 1 - cur_channel_;
  } else {
    PutToFeedVec(ins_vec);
  }
  return DataFeed::batch_size_;
}
// Decodes one serialized instance and pushes it onto the shuffled_ins_
// channel.
//
// ins_str: the serialized instance, passed to DeserializeIns.
template <typename T>
void InMemoryDataFeed<T>::PutInsToChannel(const std::string& ins_str) {
  T decoded;
  DeserializeIns(decoded, ins_str);
  shuffled_ins_->Push(std::move(decoded));
}
// Reads every file handed out by DataFeed::PickOneFile and appends all parsed
// instances to memory_data_.
//
// Each file is opened through fs_open_read with the configured pipe command,
// parsed instance by instance with ParseOneInstanceFromPipe into a local
// buffer, and the buffer is then moved into memory_data_ and released.
template <typename T>
void InMemoryDataFeed<T>::LoadIntoMemory() {
  std::vector<T> local_vec;
  std::string filename;
  while (DataFeed::PickOneFile(&filename)) {
    int err_no = 0;
    PrivateQueueDataFeed<T>::fp_ =
        fs_open_read(filename, &err_no, PrivateQueueDataFeed<T>::pipe_command_);
    // The handle is dereferenced on the next line; stop with a readable
    // message rather than dereferencing a null handle.
    // NOTE(review): err_no semantics are not visible here -- confirm whether
    // a non-zero value should also be treated as a failure.
    CHECK(PrivateQueueDataFeed<T>::fp_ != nullptr)
        << "failed to open " << filename << ", err_no=" << err_no;
    // The caller takes responsibility for locking the stream.
    __fsetlocking(&*PrivateQueueDataFeed<T>::fp_, FSETLOCKING_BYCALLER);
    T instance;
    while (ParseOneInstanceFromPipe(&instance)) {
      // Copy on purpose: `instance` is reused by the next parse call.
      local_vec.push_back(instance);
    }
    // local_vec is discarded right after, so move its elements instead of
    // deep-copying every instance; insertion order is preserved.
    for (auto& loaded : local_vec) {
      memory_data_.push_back(std::move(loaded));
    }
    // Release the buffer's capacity before the next file.
    std::vector<T>().swap(local_vec);
  }
}
// Randomly permutes, in place, the instances held in memory_data_.
// NOTE(review): std::random_shuffle is deprecated since C++14 and removed in
// C++17; consider std::shuffle with an explicit random engine when the
// toolchain is upgraded.
template <typename T>
void InMemoryDataFeed<T>::LocalShuffle() {
  auto first = memory_data_.begin();
  auto last = memory_data_.end();
  std::random_shuffle(first, last);
}
// TODO: implement global shuffle (the draft below is commented out).
/*
template <typename T>
void InMemoryDataFeed<T>::GlobalShuffle(int trainer_num) {
std::random_shuffle(memory_data_.begin(), memory_data_.end());
for (int64_t i = 0; i < memory_data_.size(); ++i) {
// todo get ins id
//std::string ins_id = memory_data_[i].ins_id;
// todo hash
int64_t hash_id = paddle::ps::local_random_engine()();
//int64_t hash_id = hash(ins_id);
int64_t node_id = hash_id % trainer_num_;
std::string str;
SerializeIns(memory_data_[i], str);
auto fleet_ptr = FleetWrapper::GetInstance();
auto ret = fleet_ptr->send_client2client_msg(0, node_id, str);
}
}
*/
void MultiSlotDataFeed::Init( void MultiSlotDataFeed::Init(
const paddle::framework::DataFeedDesc& data_feed_desc) { const paddle::framework::DataFeedDesc& data_feed_desc) {
finish_init_ = false; finish_init_ = false;
......
...@@ -12,6 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,6 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include <fcntl.h> #include <fcntl.h>
// To avoid conflicting definition in gcc-4.8.2 headers and pyconfig.h (2.7.3)
#ifdef _POSIX_C_SOURCE #ifdef _POSIX_C_SOURCE
#undef _POSIX_C_SOURCE #undef _POSIX_C_SOURCE
#endif #endif
...@@ -41,7 +43,7 @@ namespace paddle { ...@@ -41,7 +43,7 @@ namespace paddle {
namespace pybind { namespace pybind {
void BindDataset(py::module* m) { void BindDataset(py::module* m) {
py::class_<framework::Dataset>(*m, "Dataset") py::class_<framework::DataSet>(*m, "Dataset")
.def(py::init([]() { .def(py::init([]() {
return std::unique_ptr<framework::Dataset>(new framework::Dataset()); return std::unique_ptr<framework::Dataset>(new framework::Dataset());
})) }))
...@@ -51,7 +53,7 @@ void BindDataset(py::module* m) { ...@@ -51,7 +53,7 @@ void BindDataset(py::module* m) {
.def("set_data_feed_desc", &framework::Dataset::SetDataFeedDesc) .def("set_data_feed_desc", &framework::Dataset::SetDataFeedDesc)
.def("load_into_memory", &framework::Dataset::LoadIntoMemory) .def("load_into_memory", &framework::Dataset::LoadIntoMemory)
.def("local_shuffle", &framework::Dataset::LocalShuffle) .def("local_shuffle", &framework::Dataset::LocalShuffle)
.def("global_shuffle", &framework::Dataset::GlobalShuffle); .def("global_shuffle", &framework::Dataset::GLobalShuffle)
} }
} // end namespace pybind } // end namespace pybind
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册