From 24863897935860f9b2c7e9a1c0c3c4e68be111cc Mon Sep 17 00:00:00 2001 From: dongdaxiang Date: Fri, 8 Mar 2019 15:36:26 +0800 Subject: [PATCH] add RunFromDataset in executor --- paddle/fluid/framework/CMakeLists.txt | 39 ++++-------- paddle/fluid/framework/async_executor.cc | 11 +--- paddle/fluid/framework/data_feed.cc | 63 ++++++++++---------- paddle/fluid/framework/dist_multi_trainer.cc | 3 +- paddle/fluid/framework/executor.cc | 41 +++++++++++-- paddle/fluid/framework/multi_trainer.cc | 4 +- paddle/fluid/framework/trainer.h | 11 ++-- 7 files changed, 94 insertions(+), 78 deletions(-) diff --git a/paddle/fluid/framework/CMakeLists.txt b/paddle/fluid/framework/CMakeLists.txt index 040e36b79..d4a9ca5fb 100644 --- a/paddle/fluid/framework/CMakeLists.txt +++ b/paddle/fluid/framework/CMakeLists.txt @@ -28,7 +28,7 @@ add_subdirectory(common) add_subdirectory(io) #ddim lib proto_library(framework_proto SRCS framework.proto) -proto_library(async_executor_proto SRCS data_feed.proto) +proto_library(data_feed_proto SRCS data_feed.proto) proto_library(trainer_desc_proto SRCS trainer_desc.proto) cc_library(ddim SRCS ddim.cc DEPS eigen3 boost enforce) @@ -175,15 +175,11 @@ cc_library(executor_gc_helper SRCS executor_gc_helper.cc DEPS scope proto_desc o if(WITH_DISTRIBUTE) cc_library(executor SRCS executor.cc DEPS op_registry device_context scope framework_proto glog - lod_rank_table feed_fetch_method sendrecvop_rpc ${GLOB_DISTRIBUTE_DEPS} graph_to_program_pass variable_helper ${NGRAPH_EXE_DEPS}) + lod_rank_table feed_fetch_method sendrecvop_rpc ${GLOB_DISTRIBUTE_DEPS} graph_to_program_pass variable_helper ${NGRAPH_EXE_DEPS} trainer_library) set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor") set_source_files_properties(executor.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) else() - if(WITH_NGRAPH) - cc_library(executor SRCS executor.cc DEPS op_registry device_context scope framework_proto glog lod_rank_table feed_fetch_method graph_to_program_pass ngraph_operator variable_helper) - else(WITH_NGRAPH) - cc_library(executor SRCS executor.cc DEPS op_registry device_context scope framework_proto glog lod_rank_table feed_fetch_method graph_to_program_pass variable_helper) - endif(WITH_NGRAPH) + cc_library(executor SRCS executor.cc multi_trainer.cc dist_multi_trainer.cc trainer_factory.cc trainer.cc data_feed_factory.cc data_feed.cc device_worker.cc hogwild_worker.cc downpour_worker.cc pull_dense_worker.cc device_worker_factory.cc data_set.cc DEPS op_registry device_context scope framework_proto data_feed_proto trainer_desc_proto glog lod_rank_table fs shell fleet_wrapper lodtensor_printer feed_fetch_method graph_to_program_pass variable_helper ${NGRAPH_EXE_DEPS} timer) cc_test(test_naive_executor SRCS naive_executor_test.cc DEPS naive_executor elementwise_add_op) endif() @@ -194,28 +190,15 @@ cc_library(parallel_executor SRCS parallel_executor.cc DEPS graph build_strategy fast_threaded_ssa_graph_executor variable_helper) -if(WITH_PSLIB) - cc_library(async_executor SRCS async_executor.cc data_feed.cc data_feed_factory.cc - executor_thread_worker.cc multi_trainer.cc dist_multi_trainer.cc - trainer_factory.cc trainer.cc device_worker.cc hogwild_worker.cc - downpour_worker.cc pull_dense_worker.cc device_worker_factory.cc - data_set.cc - DEPS op_registry device_context scope framework_proto - trainer_desc_proto glog lod_rank_table fleet_wrapper lodtensor_printer - feed_fetch_method graph_to_program_pass async_executor_proto - variable_helper pslib_brpc pslib timer) -else() - cc_library(async_executor SRCS async_executor.cc data_feed.cc data_feed_factory.cc - executor_thread_worker.cc multi_trainer.cc dist_multi_trainer.cc - trainer_factory.cc trainer.cc device_worker.cc hogwild_worker.cc - downpour_worker.cc pull_dense_worker.cc device_worker_factory.cc - data_set.cc - DEPS op_registry device_context scope framework_proto - trainer_desc_proto glog lod_rank_table fleet_wrapper lodtensor_printer - feed_fetch_method graph_to_program_pass async_executor_proto - variable_helper timer) -endif(WITH_PSLIB) +cc_library(async_executor SRCS async_executor.cc data_feed.cc data_feed_factory.cc + executor_thread_worker.cc multi_trainer.cc dist_multi_trainer.cc + trainer_factory.cc trainer.cc device_worker.cc hogwild_worker.cc + downpour_worker.cc pull_dense_worker.cc device_worker_factory.cc + data_set.cc DEPS op_registry device_context scope framework_proto + trainer_desc_proto glog lod_rank_table fleet_wrapper lodtensor_printer + feed_fetch_method graph_to_program_pass data_feed_proto + variable_helper timer) cc_test(data_feed_test SRCS data_feed_test.cc DEPS async_executor) cc_library(prune SRCS prune.cc DEPS framework_proto) diff --git a/paddle/fluid/framework/async_executor.cc b/paddle/fluid/framework/async_executor.cc index 902f44291..d1a086f71 100644 --- a/paddle/fluid/framework/async_executor.cc +++ b/paddle/fluid/framework/async_executor.cc @@ -154,14 +154,5 @@ void AsyncExecutor::RunFromFile(const ProgramDesc& main_program, return; } -// todo RunFromDataset -void AsyncExecutor::RunFromDataset(const ProgramDesc& main_program, - Dataset* data_set, - const std::string& trainer_desc_str, - const bool debug) { - -} - - -} // einit_modelnd namespace framework +} // end namespace framework } // end namespace paddle diff --git a/paddle/fluid/framework/data_feed.cc b/paddle/fluid/framework/data_feed.cc index 4a7793ec8..e93683cb7 100644 --- a/paddle/fluid/framework/data_feed.cc +++ b/paddle/fluid/framework/data_feed.cc @@ -135,9 +135,7 @@ int PrivateQueueDataFeed::Next() { return batch_size_; } -#ifdef _WIN32 template class PrivateQueueDataFeed>; -#endif template InMemoryDataFeed::InMemoryDataFeed() { @@ -150,7 +148,7 @@ template bool InMemoryDataFeed::Start() { DataFeed::CheckSetFileList(); if (memory_data_.size() != 0) { - CHECK(cur_channel_ == 0); + CHECK_EQ(cur_channel_, 0); shuffled_ins_->Extend(std::move(memory_data_)); std::vector().swap(memory_data_); } @@ -173,30 +171,30 @@ int InMemoryDataFeed::Next() { CHECK(in_channel != nullptr); CHECK(out_channel != nullptr); int index = 0; - T instance; - T ins_vec; - while (index < DataFeed::default_batch_size_) { - if (in_channel->Size() == 0) { - break; - } - in_channel->Pop(instance); - AddInstanceToInsVec(&ins_vec, instance, index++); - out_channel->Push(std::move(instance)); - } - DataFeed::batch_size_ = index; - if (DataFeed::batch_size_ != 0) { - PutToFeedVec(ins_vec); - } else { - cur_channel_ = 1 - cur_channel_; + T instance; + T ins_vec; + while (index < DataFeed::default_batch_size_) { + if (in_channel->Size() == 0) { + break; } - return DataFeed::batch_size_; + in_channel->Pop(instance); + AddInstanceToInsVec(&ins_vec, instance, index++); + out_channel->Push(std::move(instance)); + } + DataFeed::batch_size_ = index; + if (DataFeed::batch_size_ != 0) { + PutToFeedVec(ins_vec); + } else { + cur_channel_ = 1 - cur_channel_; + } + return DataFeed::batch_size_; } template void InMemoryDataFeed::PutInsToChannel(const std::string& ins_str) { - T ins; - DeserializeIns(ins, ins_str); - shuffled_ins_->Push(std::move(ins)); + T ins; + DeserializeIns(ins, ins_str); + shuffled_ins_->Push(std::move(ins)); } template @@ -205,11 +203,11 @@ void InMemoryDataFeed::LoadIntoMemory() { std::string filename; while (DataFeed::PickOneFile(&filename)) { int err_no = 0; - PrivateQueueDataFeed::fp_ = fs_open_read(filename, &err_no, - PrivateQueueDataFeed::pipe_command_); + PrivateQueueDataFeed::fp_ = + fs_open_read(filename, &err_no, PrivateQueueDataFeed::pipe_command_); __fsetlocking(&*PrivateQueueDataFeed::fp_, FSETLOCKING_BYCALLER); T instance; - while(ParseOneInstanceFromPipe(&instance)) { + while (ParseOneInstanceFromPipe(&instance)) { local_vec.push_back(instance); } memory_data_.insert(memory_data_.end(), local_vec.begin(), local_vec.end()); @@ -242,6 +240,8 @@ void InMemoryDataFeed::GlobalShuffle(int trainer_num) { } */ +template class InMemoryDataFeed>; + void MultiSlotDataFeed::Init( const paddle::framework::DataFeedDesc& data_feed_desc) { finish_init_ = false; @@ -633,7 +633,8 @@ bool MultiSlotInMemoryDataFeed::ParseOneInstanceFromPipe( } } -bool MultiSlotInMemoryDataFeed::ParseOneInstance(std::vector* instance) { +bool MultiSlotInMemoryDataFeed::ParseOneInstance( + std::vector* instance) { std::string line; if (getline(file_, line)) { int use_slots_num = use_slots_.size(); @@ -725,12 +726,14 @@ void MultiSlotInMemoryDataFeed::PutToFeedVec( } // todo serialize ins in global shuffle -void MultiSlotInMemoryDataFeed::SerializeIns(const std::vector& ins, std::string& str) { - +void MultiSlotInMemoryDataFeed::SerializeIns( + const std::vector& ins, std::string& str) { + return; } // todo deserialize ins in global shuffle -void MultiSlotInMemoryDataFeed::DeserializeIns(std::vector& ins, const std::string& str) { - +void MultiSlotInMemoryDataFeed::DeserializeIns(std::vector& ins, + const std::string& str) { + return; } } // namespace framework diff --git a/paddle/fluid/framework/dist_multi_trainer.cc b/paddle/fluid/framework/dist_multi_trainer.cc index 8b15a3d7a..44509486c 100644 --- a/paddle/fluid/framework/dist_multi_trainer.cc +++ b/paddle/fluid/framework/dist_multi_trainer.cc @@ -21,7 +21,8 @@ limitations under the License. */ namespace paddle { namespace framework { -void DistMultiTrainer::Initialize(const TrainerDesc& trainer_desc, Dataset* data_set) { +void DistMultiTrainer::Initialize(const TrainerDesc& trainer_desc, + const Dataset& data_set) { thread_num_ = trainer_desc.thread_num(); workers_.resize(thread_num_); readers_.resize(thread_num_); diff --git a/paddle/fluid/framework/executor.cc b/paddle/fluid/framework/executor.cc index 97fd6ee15..ef84d3876 100644 --- a/paddle/fluid/framework/executor.cc +++ b/paddle/fluid/framework/executor.cc @@ -19,13 +19,16 @@ limitations under the License. */ #include #include -#include "paddle/fluid/framework/executor_gc_helper.h" +#include "google/protobuf/io/zero_copy_stream_impl.h" +#include "google/protobuf/message.h" +#include "google/protobuf/text_format.h" #include "paddle/fluid/framework/feed_fetch_method.h" #include "paddle/fluid/framework/lod_rank_table.h" #include "paddle/fluid/framework/lod_tensor_array.h" #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/reader.h" -#include "paddle/fluid/framework/threadpool.h" +#include "paddle/fluid/framework/trainer_desc.pb.h" +#include "paddle/fluid/framework/trainer_factory.h" #include "paddle/fluid/framework/transfer_scope_cache.h" #include "paddle/fluid/framework/variable_helper.h" #include "paddle/fluid/operators/controlflow/while_op_helper.h" @@ -115,9 +118,39 @@ void Executor::CreateVariables(const ProgramDesc& pdesc, Scope* scope, } } -void Executor::RunFromDataset(const ProgramDesc& pdesc, const Dataset& dataset, +void Executor::RunFromDataset(const ProgramDesc& main_program, + const Dataset& dataset, const std::string& trainer_desc_str, - const bool debug) {} + const bool debug) { + VLOG(3) << "Start to RunFromDataset in executor"; + TrainerDesc trainer_desc; + google::protobuf::TextFormat::ParseFromString(trainer_desc_str, + &trainer_desc); + VLOG(3) << "Going to create trainer, trainer class is " + << trainer_desc.class_name(); + std::shared_ptr trainer; + trainer = TrainerFactory::CreateTrainer(trainer_desc.class_name()); + // initialize trainer + VLOG(3) << "Going to initialize trainer"; + trainer->Initialize(trainer_desc, dataset); + VLOG(3) << "Set root scope here"; + trainer->SetScope(root_scope_); + VLOG(3) << "Going to set debug"; + trainer->SetDebug(debug); + // prepare training environment and helper environment + VLOG(3) << "Try to init train environment"; + trainer->InitTrainerEnv(main_program, place_); + VLOG(3) << "Try to init other environment"; + trainer->InitOtherEnv(main_program); + // training and finalize training + VLOG(3) << "Trainer starts to run"; + trainer->Run(); + VLOG(3) << "Trainer going to finalize"; + trainer->Finalize(); + VLOG(3) << "Drop current scope kids"; + root_scope_->DropKids(); + return; +} void Executor::Run(const ProgramDesc& pdesc, Scope* scope, int block_id, bool create_local_scope, bool create_vars, diff --git a/paddle/fluid/framework/multi_trainer.cc b/paddle/fluid/framework/multi_trainer.cc index d1ade19f5..dd52d3608 100644 --- a/paddle/fluid/framework/multi_trainer.cc +++ b/paddle/fluid/framework/multi_trainer.cc @@ -22,11 +22,12 @@ namespace paddle { namespace framework { void MultiTrainer::Initialize(const TrainerDesc& trainer_desc, - Dataset* dataset) { + const Dataset& dataset) { thread_num_ = trainer_desc.thread_num(); // get filelist from trainer_desc here workers_.resize(thread_num_); + /* if (NULL == dataset) { readers_.resize(thread_num_); for (int i = 0; i < thread_num_; ++i) { @@ -42,6 +43,7 @@ void MultiTrainer::Initialize(const TrainerDesc& trainer_desc, } else { // readers_ = dataset.get_readers(); ? } + */ for (int i = 0; i < thread_num_; ++i) { workers_[i] = DeviceWorkerFactory::CreateDeviceWorker( diff --git a/paddle/fluid/framework/trainer.h b/paddle/fluid/framework/trainer.h index 654254592..2de4d93cb 100644 --- a/paddle/fluid/framework/trainer.h +++ b/paddle/fluid/framework/trainer.h @@ -22,6 +22,7 @@ limitations under the License. */ #include #include "paddle/fluid/framework/data_feed.h" +#include "paddle/fluid/framework/data_set.h" #include "paddle/fluid/framework/device_worker.h" #include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/program_desc.h" @@ -29,7 +30,6 @@ limitations under the License. */ #include "paddle/fluid/framework/trainer_desc.pb.h" #include "paddle/fluid/framework/variable_helper.h" #include "paddle/fluid/operators/reader/blocking_queue.h" -#include "paddle/fluid/framework/data_set.h" namespace paddle { namespace framework { @@ -41,7 +41,8 @@ class TrainerBase { // model memory are hosted in root_scope void SetScope(Scope* root_scope); void SetDebug(const bool debug) { debug_ = debug; } - virtual void Initialize(const TrainerDesc& trainer_desc, Dataset* data_set) = 0; + virtual void Initialize(const TrainerDesc& trainer_desc, + const Dataset& data_set) = 0; virtual void InitTrainerEnv(const ProgramDesc& main_program, const platform::Place& place) = 0; virtual void InitOtherEnv(const ProgramDesc& main_program) = 0; @@ -60,7 +61,8 @@ class MultiTrainer : public TrainerBase { public: MultiTrainer() {} virtual ~MultiTrainer() {} - virtual void Initialize(const TrainerDesc& trainer_desc, Dataset* data_set); + virtual void Initialize(const TrainerDesc& trainer_desc, + const Dataset& data_set); virtual void InitTrainerEnv(const ProgramDesc& main_program, const platform::Place& place); virtual void InitOtherEnv(const ProgramDesc& main_program) {} @@ -78,7 +80,8 @@ class DistMultiTrainer : public MultiTrainer { public: DistMultiTrainer() {} virtual ~DistMultiTrainer() {} - virtual void Initialize(const TrainerDesc& trainer_desc, Dataset* data_set); + virtual void Initialize(const TrainerDesc& trainer_desc, + const Dataset& data_set); virtual void InitOtherEnv(const ProgramDesc& main_program); virtual void Finalize(); -- GitLab