提交 2e9a836c 编写于 作者: X xjqbest 提交者: dongdaxiang

add DataSet and InMemoryDataFeed, support load data into memory and shuffle data

上级 8de4d31a
......@@ -190,6 +190,30 @@ cc_library(parallel_executor SRCS parallel_executor.cc DEPS
graph build_strategy
fast_threaded_ssa_graph_executor variable_helper)
<<<<<<< HEAD
=======
if(WITH_PSLIB)
cc_library(async_executor SRCS async_executor.cc data_feed.cc data_feed_factory.cc
executor_thread_worker.cc multi_trainer.cc dist_multi_trainer.cc
trainer_factory.cc trainer.cc device_worker.cc hogwild_worker.cc
downpour_worker.cc pull_dense_worker.cc device_worker_factory.cc
data_set.cc
DEPS op_registry device_context scope framework_proto
trainer_desc_proto glog lod_rank_table fleet_wrapper lodtensor_printer
feed_fetch_method graph_to_program_pass async_executor_proto
variable_helper pslib_brpc pslib timer)
else()
cc_library(async_executor SRCS async_executor.cc data_feed.cc data_feed_factory.cc
executor_thread_worker.cc multi_trainer.cc dist_multi_trainer.cc
trainer_factory.cc trainer.cc device_worker.cc hogwild_worker.cc
downpour_worker.cc pull_dense_worker.cc device_worker_factory.cc
data_set.cc
DEPS op_registry device_context scope framework_proto
trainer_desc_proto glog lod_rank_table fleet_wrapper lodtensor_printer
feed_fetch_method graph_to_program_pass async_executor_proto
variable_helper timer)
endif(WITH_PSLIB)
>>>>>>> 870b88bbd7... add DataSet and InMemoryDataFeed, support load data into memory and shuffle data
cc_library(async_executor SRCS async_executor.cc data_feed.cc data_feed_factory.cc
executor_thread_worker.cc multi_trainer.cc dist_multi_trainer.cc
......
......@@ -242,6 +242,109 @@ void InMemoryDataFeed<T>::GlobalShuffle(int trainer_num) {
template class InMemoryDataFeed<std::vector<MultiSlotType>>;
template <typename T>
InMemoryDataFeed<T>::InMemoryDataFeed() {
cur_channel_ = 0;
shuffled_ins_ = nullptr;
shuffled_ins_out_ = nullptr;
}
template <typename T>
bool InMemoryDataFeed<T>::Start() {
DataFeed::CheckSetFileList();
if (memory_data_.size() != 0) {
CHECK_EQ(cur_channel_, 0);
shuffled_ins_->Extend(std::move(memory_data_));
std::vector<T>().swap(memory_data_);
}
DataFeed::finish_start_ = true;
return true;
}
template <typename T>
int InMemoryDataFeed<T>::Next() {
DataFeed::CheckStart();
std::shared_ptr<paddle::framework::BlockingQueue<T>> in_channel = nullptr;
std::shared_ptr<paddle::framework::BlockingQueue<T>> out_channel = nullptr;
if (cur_channel_ == 0) {
in_channel = shuffled_ins_;
out_channel = shuffled_ins_out_;
} else {
in_channel = shuffled_ins_out_;
out_channel = shuffled_ins_;
}
CHECK(in_channel != nullptr);
CHECK(out_channel != nullptr);
int index = 0;
T instance;
T ins_vec;
while (index < DataFeed::default_batch_size_) {
if (in_channel->Size() == 0) {
break;
}
in_channel->Pop(instance);
AddInstanceToInsVec(&ins_vec, instance, index++);
out_channel->Push(std::move(instance));
}
DataFeed::batch_size_ = index;
if (DataFeed::batch_size_ != 0) {
PutToFeedVec(ins_vec);
} else {
cur_channel_ = 1 - cur_channel_;
}
return DataFeed::batch_size_;
}
template <typename T>
void InMemoryDataFeed<T>::PutInsToChannel(const std::string& ins_str) {
T ins;
DeserializeIns(ins, ins_str);
shuffled_ins_->Push(std::move(ins));
}
template <typename T>
void InMemoryDataFeed<T>::LoadIntoMemory() {
std::vector<T> local_vec;
std::string filename;
while (DataFeed::PickOneFile(&filename)) {
int err_no = 0;
PrivateQueueDataFeed<T>::fp_ =
fs_open_read(filename, &err_no, PrivateQueueDataFeed<T>::pipe_command_);
__fsetlocking(&*PrivateQueueDataFeed<T>::fp_, FSETLOCKING_BYCALLER);
T instance;
while (ParseOneInstanceFromPipe(&instance)) {
local_vec.push_back(instance);
}
memory_data_.insert(memory_data_.end(), local_vec.begin(), local_vec.end());
std::vector<T>().swap(local_vec);
}
}
template <typename T>
void InMemoryDataFeed<T>::LocalShuffle() {
std::random_shuffle(memory_data_.begin(), memory_data_.end());
}
// todo global shuffle
/*
template <typename T>
void InMemoryDataFeed<T>::GlobalShuffle(int trainer_num) {
std::random_shuffle(memory_data_.begin(), memory_data_.end());
for (int64_t i = 0; i < memory_data_.size(); ++i) {
// todo get ins id
//std::string ins_id = memory_data_[i].ins_id;
// todo hash
int64_t hash_id = paddle::ps::local_random_engine()();
//int64_t hash_id = hash(ins_id);
int64_t node_id = hash_id % trainer_num_;
std::string str;
SerializeIns(memory_data_[i], str);
auto fleet_ptr = FleetWrapper::GetInstance();
auto ret = fleet_ptr->send_client2client_msg(0, node_id, str);
}
}
*/
void MultiSlotDataFeed::Init(
const paddle::framework::DataFeedDesc& data_feed_desc) {
finish_init_ = false;
......
......@@ -12,6 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <fcntl.h>
// To avoid conflicting definition in gcc-4.8.2 headers and pyconfig.h (2.7.3)
#ifdef _POSIX_C_SOURCE
#undef _POSIX_C_SOURCE
#endif
......@@ -41,7 +43,7 @@ namespace paddle {
namespace pybind {
void BindDataset(py::module* m) {
py::class_<framework::Dataset>(*m, "Dataset")
py::class_<framework::DataSet>(*m, "Dataset")
.def(py::init([]() {
return std::unique_ptr<framework::Dataset>(new framework::Dataset());
}))
......@@ -51,7 +53,7 @@ void BindDataset(py::module* m) {
.def("set_data_feed_desc", &framework::Dataset::SetDataFeedDesc)
.def("load_into_memory", &framework::Dataset::LoadIntoMemory)
.def("local_shuffle", &framework::Dataset::LocalShuffle)
.def("global_shuffle", &framework::Dataset::GlobalShuffle);
.def("global_shuffle", &framework::Dataset::GLobalShuffle)
}
} // end namespace pybind
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册