From 91fc8f355e93e45204011d19e07dfceff561108d Mon Sep 17 00:00:00 2001 From: wangguibao Date: Mon, 19 Nov 2018 16:22:38 +0800 Subject: [PATCH] Interface rework --- paddle/fluid/framework/CMakeLists.txt | 35 +- paddle/fluid/framework/async_executor.cc | 409 +++--------------- paddle/fluid/framework/async_executor.h | 121 +----- paddle/fluid/framework/data_feed.cc | 57 +-- paddle/fluid/framework/data_feed.h | 16 +- paddle/fluid/framework/data_feed.proto | 3 +- paddle/fluid/framework/data_feed_factory.cc | 23 +- paddle/fluid/framework/data_feed_factory.h | 5 +- .../fluid/framework/executor_thread_worker.cc | 34 +- .../fluid/framework/executor_thread_worker.h | 9 +- paddle/fluid/pybind/async_executor_py.cc | 36 +- python/paddle/fluid/async_executor.py | 89 ++-- 12 files changed, 238 insertions(+), 599 deletions(-) diff --git a/paddle/fluid/framework/CMakeLists.txt b/paddle/fluid/framework/CMakeLists.txt index 79529724395..9247c85a522 100644 --- a/paddle/fluid/framework/CMakeLists.txt +++ b/paddle/fluid/framework/CMakeLists.txt @@ -36,7 +36,7 @@ add_subdirectory(details) endif (NOT WIN32) # ddim lib proto_library(framework_proto SRCS framework.proto) -proto_library(async_executor_param SRCS async_executor_param.proto) +proto_library(async_executor_proto SRCS data_feed.proto) cc_library(ddim SRCS ddim.cc DEPS eigen3 boost) cc_test(ddim_test SRCS ddim_test.cc DEPS ddim) @@ -138,31 +138,23 @@ cc_test(version_test SRCS version_test.cc DEPS version) cc_library(proto_desc SRCS var_desc.cc op_desc.cc block_desc.cc program_desc.cc DEPS shape_inference op_info operator glog version) cc_library(ngraph_bridge SRCS ngraph_bridge.cc DEPS operator framework_proto) -if(NOT WIN32) cc_library(ngraph_operator SRCS ngraph_operator.cc DEPS ngraph_bridge operator op_info device_context tensor scope glog shape_inference data_transform lod_tensor profiler) -endif(NOT WIN32) + cc_library(op_registry SRCS op_registry.cc DEPS op_proto_maker op_info operator glog proto_desc) nv_test(op_registry_test SRCS op_registry_test.cc DEPS op_registry) +if (NOT WIN32) py_proto_compile(framework_py_proto SRCS framework.proto) # Generate an empty __init__.py to make framework_py_proto as a valid python module. add_custom_target(framework_py_proto_init ALL COMMAND ${CMAKE_COMMAND} -E touch __init__.py) add_dependencies(framework_py_proto framework_py_proto_init) -if (NOT WIN32) - add_custom_command(TARGET framework_py_proto POST_BUILD - COMMAND ${CMAKE_COMMAND} -E make_directory ${PADDLE_BINARY_DIR}/python/paddle/fluid/proto - COMMAND cp *.py ${PADDLE_BINARY_DIR}/python/paddle/fluid/proto/ - COMMENT "Copy generated python proto into directory paddle/fluid/proto." - WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}) -else(NOT WIN32) - string(REPLACE "/" "\\" proto_dstpath "${PADDLE_BINARY_DIR}/python/paddle/fluid/proto/") - add_custom_command(TARGET framework_py_proto POST_BUILD - COMMAND ${CMAKE_COMMAND} -E make_directory ${PADDLE_BINARY_DIR}/python/paddle/fluid/proto - COMMAND copy /Y *.py ${proto_dstpath} - COMMENT "Copy generated python proto into directory paddle/fluid/proto." - WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}) +add_custom_command(TARGET framework_py_proto POST_BUILD + COMMAND ${CMAKE_COMMAND} -E make_directory ${PADDLE_BINARY_DIR}/python/paddle/fluid/proto + COMMAND cp *.py ${PADDLE_BINARY_DIR}/python/paddle/fluid/proto/ + COMMENT "Copy generated python proto into directory paddle/fluid/proto." + WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}) endif(NOT WIN32) cc_library(lod_rank_table SRCS lod_rank_table.cc DEPS lod_tensor) @@ -176,11 +168,7 @@ if(WITH_DISTRIBUTE) set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor") set_source_files_properties(executor.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) else() - if(NOT WIN32) - cc_library(executor SRCS executor.cc DEPS op_registry device_context scope framework_proto glog lod_rank_table feed_fetch_method graph_to_program_pass ngraph_operator) - else(NOT WIN32) - cc_library(executor SRCS executor.cc DEPS op_registry device_context scope framework_proto glog lod_rank_table feed_fetch_method graph_to_program_pass) - endif(NOT WIN32) + cc_library(executor SRCS executor.cc DEPS op_registry device_context scope framework_proto glog lod_rank_table feed_fetch_method graph_to_program_pass ngraph_operator) cc_test(test_naive_executor SRCS naive_executor_test.cc DEPS naive_executor elementwise_add_op) endif() @@ -192,10 +180,11 @@ cc_library(parallel_executor SRCS parallel_executor.cc DEPS endif() # NOT WIN32 cc_library(async_executor - SRCS async_executor.cc data_feed.cc datafeed_creator.cc + SRCS async_executor.cc data_feed.cc data_feed_factory.cc + executor_thread_worker.cc DEPS op_registry device_context scope framework_proto glog lod_rank_table feed_fetch_method graph_to_program_pass - async_executor_param) + async_executor_proto) cc_library(prune SRCS prune.cc DEPS framework_proto) cc_test(prune_test SRCS prune_test.cc DEPS op_info prune recurrent_op device_context) diff --git a/paddle/fluid/framework/async_executor.cc b/paddle/fluid/framework/async_executor.cc index 28185d07b9d..f49290f1aa7 100644 --- a/paddle/fluid/framework/async_executor.cc +++ b/paddle/fluid/framework/async_executor.cc @@ -36,43 +36,13 @@ limitations under the License. */ #include "paddle/fluid/framework/reader.h" #include "paddle/fluid/platform/place.h" #include "paddle/fluid/inference/io.h" +#include "paddle/fluid/framework/executor_thread_worker.h" +#include "paddle/fluid/framework/data_feed_factory.h" #include "paddle/fluid/pybind/pybind.h" namespace paddle { namespace framework { -bool AsyncExecutor::workers_initialized_ = false; - -void CreateTensor(Variable* var, proto::VarType::Type var_type) { - if (var_type == proto::VarType::LOD_TENSOR) { - var->GetMutable(); - } else if (var_type == proto::VarType::SELECTED_ROWS) { - var->GetMutable(); - } else if (var_type == proto::VarType::FEED_MINIBATCH) { - var->GetMutable(); - } else if (var_type == proto::VarType::FETCH_LIST) { - var->GetMutable(); - } else if (var_type == proto::VarType::STEP_SCOPES) { - var->GetMutable>(); - } else if (var_type == proto::VarType::LOD_RANK_TABLE) { - var->GetMutable(); - } else if (var_type == proto::VarType::LOD_TENSOR_ARRAY) { - var->GetMutable(); - } else if (var_type == proto::VarType::PLACE_LIST) { - var->GetMutable(); - } else if (var_type == proto::VarType::READER) { - var->GetMutable(); - } else if (var_type == proto::VarType::RAW) { - // GetMutable will be called in operator - } else { - PADDLE_THROW( - "Variable type %d is not in " - "[LOD_TENSOR, SELECTED_ROWS, FEED_MINIBATCH, FETCH_LIST, " - "LOD_RANK_TABLE, PLACE_LIST, READER, CHANNEL, RAW]", - var_type); - } -} - static void ReadBinaryFile(const std::string& filename, std::string* content) { std::string &contents = *content; @@ -139,343 +109,100 @@ static void SaveModel( } } // end SaveModel -void ExecutorThreadWorker::Reset() { - inspect_values_.clear(); -} -void ExecutorThreadWorker::CreateThreadOperators(const ProgramDesc& program) { - auto& block = program.Block(0); - op_names_.clear(); - for (auto& op_desc : block.AllOps()) { - std::unique_ptr local_op = OpRegistry::CreateOp(*op_desc); - op_names_.push_back(op_desc->Type()); - OperatorBase* local_op_ptr = local_op.release(); - ops_.push_back(local_op_ptr); - continue; - } -} - -void ExecutorThreadWorker::CreateThreadScope(const ProgramDesc& program) { - auto& block = program.Block(0); - thread_scope_ = &root_scope_->NewScope(); - for (auto& var : block.AllVars()) { - if (var->Persistable()) { - auto* ptr = root_scope_->Var(var->Name()); - CreateTensor(ptr, var->GetType()); - } else { - auto* ptr = thread_scope_->Var(var->Name()); - CreateTensor(ptr, var->GetType()); - } - } -} - -void ExecutorThreadWorker::SetDataFeed(DataFeed& datafeed) { - if (typeid(datafeed) == typeid(TextClassDataFeed)) { - local_reader_.reset( - new TextClassDataFeed(dynamic_cast(datafeed))); - local_reader_->SetThreadId(thread_id_); - } -} - -void ExecutorThreadWorker::BindingDataFeedMemory() { - const std::vector& input_feed = local_reader_->GetUseSlotAlias(); - for (auto name : input_feed) { - local_reader_->AddFeedVar(thread_scope_->Var(name), name); - } -} +AsyncExecutor::AsyncExecutor(Scope& scope, const platform::Place& place) + : root_scope_(scope), place_(place) {} -void ExecutorThreadWorker::SetInspectVarNames( - const std::vector& inspect_var_names) { - inspect_var_names_.clear(); - inspect_var_names_.insert(inspect_var_names_.end(), - inspect_var_names.begin(), inspect_var_names.end()); +void AsyncExecutor::CreateThreads( + ExecutorThreadWorker* worker, + const ProgramDesc& main_program, + const std::shared_ptr& reader, + const std::vector& fetch_var_names, + Scope& root_scope, + const int thread_index) { + worker->SetThreadId(thread_index); + worker->SetRootScope(&root_scope); + worker->CreateThreadResource(main_program, place_); + worker->SetDataFeed(reader); + worker->SetFetchVarNames(fetch_var_names); + worker->BindingDataFeedMemory(); } -void ExecutorThreadWorker::SetModelParamNames( - const std::vector& param_names) { - model_param_names_ = param_names; -} - -void ExecutorThreadWorker::SetDevice() { - static unsigned priority[] = { - 0, 1, 2, 3, 4, 5, - 6, 7, 8, 9, 10, 11, - 12, 13, 14, 15, 16, 17, - 18, 19, 20, 21, 22, 23, - 24, 25, 26, 27, 28, 29, - 30, 31, 32, 33, 34, 35, - 36, 37, 38, 39, 40, 41, - 42, 43, 44, 45, 46, 47 - }; - - unsigned int i = this->thread_id_; - - if (i < sizeof(priority) / sizeof(unsigned)) { - unsigned proc = priority[i]; - - cpu_set_t mask; - CPU_ZERO(&mask); - CPU_SET(proc, &mask); - - if (-1 == sched_setaffinity(0, sizeof(mask), &mask)) { - LOG(ERROR) << "WARNING: Failed to set thread affinity for thread " << i; - } else { - CPU_ZERO(&mask); - if ((0 == sched_getaffinity(0, sizeof(mask), &mask)) - && CPU_ISSET(proc, &mask)) { - LOG(ERROR) << "TRACE: Thread " << i - << " is running on processor " << proc - << "..."; - } - } - } -} - - -void ExecutorThreadWorker::Train() { - LOG(ERROR) << "begin to train"; - SetDevice(); - - int inspect_var_num = inspect_var_names_.size(); - inspect_values_.clear(); - inspect_values_.resize(inspect_var_num, 0); - - local_reader_->WaitNextEpoch(); - int epoch = local_reader_->GetCurrentEpoch(); - - LOG(ERROR) << "epoch: " << epoch; - - int batch_num = 1; - - while (true) { - const char *file = local_reader_->PickOneFile(); - if (file == NULL) { - break; - } - - if (!local_reader_->SetFile(file)) { - break; - } - - while (true) { - bool flag = local_reader_->ReadBatch(); - if (!flag) { - break; - } - - for (unsigned int i = 0; i < ops_.size(); ++i) { - ops_[i]->Run(*thread_scope_, place_); - } - batch_num++; - - float avg_inspect = 0.0; - for (int i = 0; i < inspect_var_num; ++i) { - avg_inspect = thread_scope_->FindVar(inspect_var_names_[i]) - ->GetMutable() - ->data()[0]; - inspect_values_[i] += avg_inspect; - } - thread_scope_->DropKids(); - } - - local_reader_->UpdateEpochNum(); - LOG(ERROR) << "memory used after epoch " << epoch + 1 - << " called: " << memory::memory_usage(place_); - } - - for (int i = 0; i < inspect_var_num; ++i) { - inspect_values_[i] /= batch_num; - std::string var = inspect_var_names_[i].substr( - 0, - inspect_var_names_[i].find_first_of("_")); - LOG(ERROR) << "mean " << var.c_str() - << " of epoch " << i + 1 << ": " << inspect_values_[i]; - } - - if (thread_id_ == 0) { - char modelfile[1024]; - snprintf(&modelfile[0], sizeof(modelfile), "%s_epoch%d.model", - model_prefix_.c_str(), epoch); - std::string model_filename = std::string(modelfile); - // this save_inference_model can only save imdbtask, should make this - // general - // - // currently comment it - LOG(ERROR) << "Going to save model " << modelfile; - SaveModel(main_program_, - thread_scope_, - model_param_names_, - model_filename, - true); - } -} - -void ExecutorThreadWorker::SetThreadId(int tid) { - thread_id_ = tid; -} - -void ExecutorThreadWorker::SetPlace(const platform::Place& place) { - place_ = place; -} - -void ExecutorThreadWorker::SetMainProgram( - const ProgramDesc& main_program_desc) { - main_program_.reset(new ProgramDesc(main_program_desc)); -} - -void ExecutorThreadWorker::SetRootScope(Scope* g_scope) { - root_scope_ = g_scope; -} - -void ExecutorThreadWorker::SetMaxTrainingEpoch(int max_epoch) { - max_epoch_ = max_epoch; -} - -AsyncExecutor::AsyncExecutor(ProgramDesc& main_program, - const std::vector& param_names, - TextClassDataFeed& data_feed, - unsigned int thread_num, - const platform::Place& place) - : thread_num_(thread_num), - place_(place), - main_program_(main_program), - data_feed_(data_feed) { - model_param_names_.clear(); - model_param_names_.insert(model_param_names_.end(), - param_names.begin(), - param_names.end()); -} - -void AsyncExecutor::InitRootScope(Scope* scope) { - root_scope_ = scope; -} - -void AsyncExecutor::SetMaxTrainingEpoch(int max_epoch) { - max_epoch_ = max_epoch; +void AsyncExecutor::CheckFiles( + const std::vector& files) { + // function for user to check file formats + // should be exposed to users } void AsyncExecutor::SetModelPrefix(const std::string& model_prefix) { model_prefix_ = model_prefix; } -void AsyncExecutor::RunStartupProgram(const ProgramDesc& program, - Scope* scope) { - auto& block = program.Block(0); - for (auto& var : block.AllVars()) { - if (var->Persistable()) { - auto* ptr = scope->Var(var->Name()); - CreateTensor(ptr, var->GetType()); - // LOGERR("Persistable Var Name:%s", var->Name().c_str()); - } - } +std::vector AsyncExecutor::RunFromFile( + const ProgramDesc& main_program, + const DataFeedDesc& data_feed_desc, + const std::vector& filelist, + const int thread_num, + const std::vector& fetch_var_names) { + std::vector threads; - std::map param_dict; - std::vector ops; - for (auto& op_desc : block.AllOps()) { - std::vector param_name_vec = op_desc->OutputArgumentNames(); - bool need_to_run = false; - for (auto& name : param_name_vec) { - if (param_dict.find(name) == param_dict.end()) { - param_dict[name] = 1; - need_to_run = true; - } - } - if (need_to_run) { - std::unique_ptr local_op = OpRegistry::CreateOp(*op_desc); - OperatorBase* local_op_ptr = local_op.release(); - ops.push_back(local_op_ptr); - } + /* + readerDesc: protobuf description for reader initlization + argument: class_name, batch_size, use_slot, queue_size, buffer_size, padding_index + + reader: + 1) each thread has a reader, reader will read input data and + put it into input queue + 2) each reader has a Next() iterface, that can fetch an instance + from the input queue + */ + // todo: should be factory method for creating datafeed + std::vector > readers; + readers.resize(thread_num); + for (unsigned int i = 0; i < readers.size(); ++i) { + readers[i] = DataFeedFactory::CreateDataFeed(data_feed_desc.name()); } - // LOGERR("There are %d parameters in startup program, %d op needs to run", - // param_dict.size(), ops.size()); - for (auto& op : ops) { - op->Run(*scope, place_); + std::vector > workers; + workers.resize(thread_num); + for (auto& worker : workers) { + worker.reset(new ExecutorThreadWorker); } - // LOGERR("total time for startup program: %fs", timeline.elapsed_sec()); - for (auto& op : ops) { - delete op; - } - // LOGERR("run startup program done."); -} -std::unique_ptr AsyncExecutor::LoadDescFromFile( - const std::string& f) { - std::string program_desc_str; - ReadBinaryFile(f, &program_desc_str); - std::unique_ptr program(new ProgramDesc(program_desc_str)); - return program; -} - -void AsyncExecutor::SetInspectVarNames( - const std::vector& inspect_var_names) { - inspect_var_names_.clear(); - inspect_var_names_.insert(inspect_var_names_.end(), - inspect_var_names.begin(), inspect_var_names.end()); -} - -void AsyncExecutor::PrepareThreads(const ProgramDesc& host_program) { - workers_.resize(thread_num_); - for (int i = 0; i < thread_num_; ++i) { - workers_[i].reset(new ExecutorThreadWorker); - workers_[i]->SetThreadId(i); - workers_[i]->CreateThreadOperators(host_program); - workers_[i]->SetRootScope(root_scope_); - workers_[i]->SetPlace(place_); - workers_[i]->SetMaxTrainingEpoch(max_epoch_); - workers_[i]->CreateThreadScope(host_program); - workers_[i]->SetInspectVarNames(inspect_var_names_); - workers_[i]->SetModelParamNames(model_param_names_); - workers_[i]->SetMainProgram(host_program); - workers_[i]->SetModelPrefix(model_prefix_); - // - // new a datafeed here - workers_[i]->SetDataFeed(data_feed_); - workers_[i]->BindingDataFeedMemory(); + // prepare thread resource here + for (int thidx = 0; thidx < thread_num; ++thidx) { + CreateThreads(workers[thidx].get(), main_program, + readers[thidx], fetch_var_names, root_scope_, thidx); } -} -std::vector& AsyncExecutor::Run( - const std::vector& inspect_var_names) { - SetInspectVarNames(inspect_var_names); - threads_.clear(); - - // thread binding here? - if (workers_initialized_ == false) { - PrepareThreads(main_program_); - workers_initialized_ = true; - } - - for (int i = 0; i < thread_num_; ++i) { - workers_[i]->Reset(); - workers_[i]->SetInspectVarNames(inspect_var_names); - threads_.push_back(std::thread(&ExecutorThreadWorker::Train, - workers_[i].get())); + // start executing ops in multiple threads + for (int thidx = 0; thidx < thread_num; ++thidx) { + threads.push_back(std::thread(&ExecutorThreadWorker::TrainFiles, + workers[thidx].get())); } - for (auto& th : threads_) { + for (auto& th : threads) { th.join(); } - inspect_values_.clear(); - inspect_values_.resize(inspect_var_names_.size(), 0); - + std::vector fetch_values; + fetch_values.resize(fetch_var_names.size(), 0); - std::vector*> inspect_value_vectors; - inspect_value_vectors.resize(thread_num_); - for (int i = 0; i < thread_num_; ++i) { - inspect_value_vectors[i] = &workers_[i]->GetInspectValues(); + std::vector*> fetch_value_vectors; + fetch_value_vectors.resize(thread_num); + for (int i = 0; i < thread_num; ++i) { + fetch_value_vectors[i] = &workers[i]->GetFetchValues(); } - for (unsigned int i = 0; i < inspect_var_names_.size(); ++i) { + for (unsigned int i = 0; i < fetch_var_names.size(); ++i) { float value = 0.0; - for (int j = 0; j < thread_num_; ++j) { - value += inspect_value_vectors[j]->at(i); + for (int j = 0; j < thread_num; ++j) { + value += fetch_value_vectors[j]->at(i); } - value /= thread_num_; - inspect_values_[i] = value; + value /= thread_num; + fetch_values[i] = value; } - return inspect_values_; + return fetch_values; } void AsyncExecutor::LoadInitModel() { diff --git a/paddle/fluid/framework/async_executor.h b/paddle/fluid/framework/async_executor.h index ca4cad164ac..89aa8efd9c1 100644 --- a/paddle/fluid/framework/async_executor.h +++ b/paddle/fluid/framework/async_executor.h @@ -23,7 +23,8 @@ limitations under the License. */ #include // NOLINT #include #include -#include "paddle/fluid/framework/data_feed.h" +#include "paddle/fluid/framework/data_feed.pb.h" +#include "paddle/fluid/framework/executor_thread_worker.h" #include "paddle/fluid/framework/datafeed_creator.h" #include "paddle/fluid/framework/executor.h" #include "paddle/fluid/framework/program_desc.h" @@ -31,93 +32,13 @@ limitations under the License. */ namespace paddle { namespace framework { -void CreateTensor(Variable* var, proto::VarType::Type var_type); - -class ExecutorThreadWorker { - public: - ExecutorThreadWorker() {} - ~ExecutorThreadWorker() {} - void CreateThreadScope(const ProgramDesc& program); - void SetThreadId(int tid); - void CreateThreadOperators(const ProgramDesc& program); - void SetRootScope(Scope* g_scope); - void SetDevice(); - void AddFidSet(); - void SetCommBatch(int comm_batch) { comm_batch_ = comm_batch; } - void AddTrainFile(const std::string& filename); - void SetMainProgram(const ProgramDesc& main_program_desc); - void SetPlace(const paddle::platform::Place& place); - void SetMaxTrainingEpoch(const int max_epoch); - void BindingDataFeedMemory(); - - void SetModelPrefix(const std::string& prefix) { model_prefix_ = prefix; } - - void SetInspectVarNames(const std::vector& inspect_var_names); - void SetModelParamNames(const std::vector& param_names); - void SetDataFeed(DataFeed& datafeed); // NOLINT - void Train(); - const char* PickOneFile(); - void UpdateEpochNum(); - void Reset(); - - void Initialize() {} - std::vector& GetInspectValues() {return inspect_values_;} - - protected: - // thread index - int thread_id_; - - // max epoch for each thread - unsigned int max_epoch_; - - // instances learned currently - int comm_batch_; - std::string model_prefix_; - std::vector op_names_; - - // local ops for forward and backward - std::vector ops_; - - // main program for training - std::unique_ptr main_program_; - - // binary data reader - std::unique_ptr local_reader_; - - std::vector inspect_var_names_; - std::vector model_param_names_; - - // execution place - platform::Place place_; - - // root scope for model parameters - Scope* root_scope_; - - // a thread scope, father scope is global score which is shared - Scope* thread_scope_; - - private: - std::vector inspect_values_; -}; - class AsyncExecutor { public: - explicit AsyncExecutor(ProgramDesc& main_program, // NOLINT - const std::vector& param_names, - TextClassDataFeed& data_feed, // NOLINT - unsigned int thread_num, - const platform::Place& place); + explicit AsyncExecutor(Scope& scope, const platform::Place& place); // NOLINT virtual ~AsyncExecutor() {} static std::unique_ptr LoadDescFromFile( const std::string& filename); - void InitRootScope(Scope* scope); - void SetMaxTrainingEpoch(const int max_epoch); - Scope* GetRootScope() { return root_scope_; } - void SetBatchSize(const int batch_size) { batch_size_ = batch_size; } - - void SetCommBatch(int comm_batch) { - comm_batch_ = comm_batch; - } + Scope* GetRootScope() { return &root_scope_; } void SetModelPath(const std::string& model_path) { model_path_ = model_path; @@ -132,38 +53,32 @@ class AsyncExecutor { } void SetModelPrefix(const std::string& model_prefix); - virtual void PrepareThreads(const ProgramDesc& host_program); void RunStartupProgram(const ProgramDesc& program, Scope* scope); - std::vector& Run(const std::vector& inspect_var_names); + std::vector RunFromFile(const ProgramDesc& main_program, + const DataFeedDesc& data_feed_desc, + const std::vector& filelist, + const int thread_num, + const std::vector& fetch_names); + void CheckFiles(const std::vector& files); void LoadInitModel(); private: - void SetInspectVarNames(const std::vector& inspect_var_names); + void CreateThreads(ExecutorThreadWorker* worker, + const ProgramDesc& main_program, + const std::shared_ptr& reader, + const std::vector& fetch_var_names, + Scope& root_scope, // NOLINT + const int thread_index); + public: - int thread_num_; - int max_epoch_; - int batch_size_; - int comm_batch_; - std::vector > workers_; - std::vector threads_; - std::vector inspect_var_names_; - std::vector model_param_names_; std::string model_prefix_; std::string model_path_; std::string init_prog_file_; std::string init_model_file_; - Scope* root_scope_; + Scope& root_scope_; platform::Place place_; - - private: - ProgramDesc& main_program_; - TextClassDataFeed& data_feed_; - std::vector inspect_values_; - - private: - static bool workers_initialized_; }; } // namespace framework diff --git a/paddle/fluid/framework/data_feed.cc b/paddle/fluid/framework/data_feed.cc index 26c7a012f34..0fa2154b7c5 100644 --- a/paddle/fluid/framework/data_feed.cc +++ b/paddle/fluid/framework/data_feed.cc @@ -38,17 +38,17 @@ DEFINE_bool(is_text_feed, false, "is_text_feed"); namespace paddle { namespace framework { -std::vector TextClassDataFeed::s_filelist_; -std::mutex TextClassDataFeed::s_locker_for_pick_file_; -unsigned int TextClassDataFeed::s_current_file_idx_ = 0; -size_t TextClassDataFeed::s_current_finished_file_cnt_ = 0; -unsigned int TextClassDataFeed::s_current_epoch_ = 0; -int TextClassDataFeed::s_current_save_epoch_ = 0; -std::mutex TextClassDataFeed::s_locker_epoch_start_; -std::condition_variable TextClassDataFeed::s_condition_epoch_start_; -bool TextClassDataFeed::s_epoch_start_flag_ = false; - -void TextClassDataFeed::Init() { +std::vector MultiSlotDataFeed::s_filelist_; +std::mutex MultiSlotDataFeed::s_locker_for_pick_file_; +unsigned int MultiSlotDataFeed::s_current_file_idx_ = 0; +size_t MultiSlotDataFeed::s_current_finished_file_cnt_ = 0; +unsigned int MultiSlotDataFeed::s_current_epoch_ = 0; +int MultiSlotDataFeed::s_current_save_epoch_ = 0; +std::mutex MultiSlotDataFeed::s_locker_epoch_start_; +std::condition_variable MultiSlotDataFeed::s_condition_epoch_start_; +bool MultiSlotDataFeed::s_epoch_start_flag_ = false; + +void MultiSlotDataFeed::Init() { // hard coding for a specific datafeed feed_vec_.resize(2); // feed_vec_[0].reset(new LoDTensor); @@ -73,12 +73,12 @@ void TextClassDataFeed::Init() { field_names_.clear(); } -TextClassDataFeed::TextClassDataFeed() { +MultiSlotDataFeed::MultiSlotDataFeed() { Init(); } // todo: use elegant implemention for this function -bool TextClassDataFeed::ReadBatch() { +bool MultiSlotDataFeed::ReadBatch() { paddle::framework::Vector offset; int tlen = 0; int llen = 0; @@ -142,13 +142,13 @@ bool TextClassDataFeed::ReadBatch() { return true; } -TextClassDataFeed::TextClassDataFeed(const TextClassDataFeed& data_feed) { +MultiSlotDataFeed::MultiSlotDataFeed(const MultiSlotDataFeed& data_feed) { Init(); SetBatchSize(data_feed.batch_size_); SetFieldNames(data_feed.field_names_); } -void TextClassDataFeed::AddFeedVar(Variable* feed, const std::string& name) { +void MultiSlotDataFeed::AddFeedVar(Variable* feed, const std::string& name) { for (unsigned int i = 0; i < use_slot_alias_.size(); ++i) { if (name == use_slot_alias_[i]) { feed_vec_[i] = feed->GetMutable(); @@ -156,7 +156,7 @@ void TextClassDataFeed::AddFeedVar(Variable* feed, const std::string& name) { } } -void TextClassDataFeed::SetFileList(const char* filelist) { +void MultiSlotDataFeed::SetFileList(const char* filelist) { s_filelist_.clear(); std::ifstream fin(filelist); PADDLE_ENFORCE(fin.good(), @@ -170,14 +170,14 @@ void TextClassDataFeed::SetFileList(const char* filelist) { fin.close(); } -void TextClassDataFeed::SetFieldNames( +void MultiSlotDataFeed::SetFieldNames( const std::vector& field_names) { field_names_.clear(); field_names_.insert(field_names_.end(), field_names.begin(), field_names.end()); } -bool TextClassDataFeed::SetFile(const char* filename) { +bool MultiSlotDataFeed::SetFile(const char* filename) { // termnum termid termid ... termid label std::ifstream ifs(filename, std::ios::binary); if (ifs.fail()) { @@ -198,7 +198,7 @@ bool TextClassDataFeed::SetFile(const char* filename) { return true; } -void TextClassDataFeed::UpdateEpochNum() { +void MultiSlotDataFeed::UpdateEpochNum() { s_current_finished_file_cnt_++; if (s_current_finished_file_cnt_ >= s_filelist_.size()) { @@ -214,25 +214,14 @@ void TextClassDataFeed::UpdateEpochNum() { } } -void TextClassDataFeed::StartOneEpoch() { - std::lock_guard lock(s_locker_for_pick_file_); - std::random_shuffle(s_filelist_.begin(), s_filelist_.end()); - s_current_file_idx_ = 0; - LOG(INFO) << "Beginning epoch " << s_current_epoch_; - - { - std::lock_guard lock(s_locker_epoch_start_); - s_epoch_start_flag_ = true; - } - s_condition_epoch_start_.notify_all(); +void MultiSlotDataFeed::Start() { } -void TextClassDataFeed::WaitNextEpoch() { - std::unique_lock lock(s_locker_epoch_start_); - s_condition_epoch_start_.wait(lock, []{return s_epoch_start_flag_;}); +int MultiSlotDataFeed::Next() { + return 0; } -const char* TextClassDataFeed::PickOneFile() { +const char* MultiSlotDataFeed::PickOneFile() { std::string file_to_be_processed; std::lock_guard lock(s_locker_for_pick_file_); diff --git a/paddle/fluid/framework/data_feed.h b/paddle/fluid/framework/data_feed.h index f5660357788..b32897e7505 100644 --- a/paddle/fluid/framework/data_feed.h +++ b/paddle/fluid/framework/data_feed.h @@ -81,8 +81,8 @@ class DataFeed { virtual unsigned int GetCurrentEpoch() = 0; virtual const char *PickOneFile() = 0; virtual void UpdateEpochNum() = 0; - virtual void StartOneEpoch() = 0; - virtual void WaitNextEpoch() = 0; + virtual void Start() = 0; + virtual int Next() = 0; std::vector& GetFeedVec() { return feed_vec_; @@ -106,13 +106,13 @@ class DataFeed { int thread_id_; }; -class TextClassDataFeed : public DataFeed { +class MultiSlotDataFeed : public DataFeed { public: - TextClassDataFeed(); - TextClassDataFeed(const TextClassDataFeed& data_feed); + MultiSlotDataFeed(); + MultiSlotDataFeed(const MultiSlotDataFeed& data_feed); public: - virtual ~TextClassDataFeed() {} + virtual ~MultiSlotDataFeed() {} virtual void Init(); virtual bool ReadBatch(); virtual void AddFeedVar(Variable* feed, const std::string& name); @@ -125,8 +125,8 @@ class TextClassDataFeed : public DataFeed { void SetBatchSize(int batch) {batch_size_ = batch;} unsigned int GetCurrentEpoch() {return s_current_epoch_;} void UpdateEpochNum(); - void StartOneEpoch(); - void WaitNextEpoch(); + void Start(); + int Next(); public: void SetFieldNames(const std::vector& field_names); diff --git a/paddle/fluid/framework/data_feed.proto b/paddle/fluid/framework/data_feed.proto index 284627e3525..88e576b1907 100644 --- a/paddle/fluid/framework/data_feed.proto +++ b/paddle/fluid/framework/data_feed.proto @@ -12,10 +12,11 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ syntax = "proto2"; -package paddle; +package paddle.framework; message DataFeedDesc { optional string name = 1; optional int32 batch = 2 [default = 32]; + repeated string field_names = 3; } diff --git a/paddle/fluid/framework/data_feed_factory.cc b/paddle/fluid/framework/data_feed_factory.cc index b07f770a584..45d6375739e 100644 --- a/paddle/fluid/framework/data_feed_factory.cc +++ b/paddle/fluid/framework/data_feed_factory.cc @@ -12,18 +12,23 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/framework/data_feed_factory.h" +#include "paddle/fluid/framework/data_feed_factory.h" +#include +#include +#include + +#include "paddle/fluid/framework/data_feed.h" namespace paddle { namespace framework { -typedef shared_ptr (*Createdata_feedFunction)(); +typedef std::shared_ptr (*Createdata_feedFunction)(); typedef std::unordered_map data_feedMap; data_feedMap g_data_feed_map; #define REGISTER_DATAFEED_CLASS(data_feed_class) \ namespace { \ - shared_ptr Creator_##data_feed_class() { \ - return shared_ptr(new data_feed_class); \ + std::shared_ptr Creator_##data_feed_class() { \ + return std::shared_ptr(new data_feed_class); \ } \ class __Registerer_##data_feed_class { \ public: \ @@ -35,8 +40,8 @@ data_feedMap g_data_feed_map; } // namespace -string DataFeedFactory::DataFeedTypeList() { - string data_feed_types; +std::string DataFeedFactory::DataFeedTypeList() { + std::string data_feed_types; for (auto iter = g_data_feed_map.begin(); iter != g_data_feed_map.end(); ++iter) { if (iter != g_data_feed_map.begin()) { @@ -47,9 +52,9 @@ string DataFeedFactory::DataFeedTypeList() { return data_feed_types; } -shared_ptr DataFeedFactory::CreateDataFeed( - const char* data_feed_class) { - if (g_data_feed_map.count(string(data_feed_class)) < 1) { +std::shared_ptr DataFeedFactory::CreateDataFeed( + std::string data_feed_class) { + if (g_data_feed_map.count(data_feed_class) < 1) { exit(-1); } return g_data_feed_map[data_feed_class](); diff --git a/paddle/fluid/framework/data_feed_factory.h b/paddle/fluid/framework/data_feed_factory.h index af203001c54..62e56dc58fd 100644 --- a/paddle/fluid/framework/data_feed_factory.h +++ b/paddle/fluid/framework/data_feed_factory.h @@ -16,14 +16,15 @@ limitations under the License. */ #define PADDLE_FLUID_FRAMEWORK_DATA_FEED_FACTORY_H_ #include -#include "paddle/framework/data_feed.h" +#include +#include "paddle/fluid/framework/data_feed.h" namespace paddle { namespace framework { class DataFeedFactory { public: static std::string DataFeedTypeList(); - static shared_ptr CreateDataFeed(const char* data_feed_class); + static std::shared_ptr CreateDataFeed(std::string data_feed_class); }; } // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/executor_thread_worker.cc b/paddle/fluid/framework/executor_thread_worker.cc index 36e951cd1a1..6a84136ac70 100644 --- a/paddle/fluid/framework/executor_thread_worker.cc +++ b/paddle/fluid/framework/executor_thread_worker.cc @@ -93,6 +93,11 @@ void ExecutorThreadWorker::CreateThreadResource( void ExecutorThreadWorker::CreateThreadScope(const ProgramDesc& program) { auto& block = program.Block(0); + + PADDLE_ENFORCE_NOT_NULL( + root_scope_, + "root_scope should be set before creating thread scope"); + thread_scope_ = &root_scope_->NewScope(); for (auto& var : block.AllVars()) { if (var->Persistable()) { @@ -107,17 +112,24 @@ void ExecutorThreadWorker::CreateThreadScope(const ProgramDesc& program) { void ExecutorThreadWorker::SetDataFeed( const std::shared_ptr& datafeed) { - local_reader_ = datafeed; + thread_reader_ = datafeed; } void ExecutorThreadWorker::BindingDataFeedMemory() { const std::vector& input_feed = thread_reader_->GetUseSlotAlias(); for (auto name : input_feed) { - local_reader_->AddFeedVar(thread_scope_->Var(name), name); + thread_reader_->AddFeedVar(thread_scope_->Var(name), name); } } +void ExecutorThreadWorker::SetFetchVarNames( + const std::vector& fetch_var_names) { + fetch_var_names_.clear(); + fetch_var_names_.insert(fetch_var_names_.end(), + fetch_var_names.begin(), fetch_var_names.end()); +} + void ExecutorThreadWorker::SetDevice() { // at most 48 threads binding currently static unsigned priority[] = { @@ -156,12 +168,28 @@ void ExecutorThreadWorker::SetDevice() { void ExecutorThreadWorker::TrainFiles() { // todo: configurable SetDevice(); + + int fetch_var_num = fetch_var_names_.size(); + fetch_values_.clear(); + fetch_values_.resize(fetch_var_num, 0); + thread_reader_->Start(); - while (int cur_batch = thread_reader_->Next()) { + + int cur_batch; + while ((cur_batch = thread_reader_->Next()) > 0) { // executor run here for (auto& op : ops_) { op->Run(*thread_scope_, place_); } + + float avg_inspect = 0.0; + for (int i = 0; i < fetch_var_num; ++i) { + avg_inspect = thread_scope_->FindVar(fetch_var_names_[i]) + ->GetMutable() + ->data()[0]; + fetch_values_[i] += avg_inspect; + } + thread_scope_->DropKids(); } } diff --git a/paddle/fluid/framework/executor_thread_worker.h b/paddle/fluid/framework/executor_thread_worker.h index 5b70fa5f5b7..bf4e3b04540 100644 --- a/paddle/fluid/framework/executor_thread_worker.h +++ b/paddle/fluid/framework/executor_thread_worker.h @@ -43,6 +43,9 @@ class ExecutorThreadWorker { void SetDevice(); void BindingDataFeedMemory(); void SetDataFeed(const std::shared_ptr& datafeed); + void TrainFiles(); + void SetFetchVarNames(const std::vector& fetch_var_names); + std::vector& GetFetchValues() {return fetch_values_;} private: void CreateThreadScope(const framework::ProgramDesc& program); @@ -66,9 +69,13 @@ class ExecutorThreadWorker { Scope* root_scope_; // a thread scope, father scope is global score which is shared Scope* thread_scope_; + + private: + std::vector fetch_var_names_; + std::vector fetch_values_; }; } // namespace framework } // namespace paddle -#endif // PADDLE_FLUID_FRAMEWORK_ASYNC_EXECUTOR_H_ +#endif // PADDLE_FLUID_FRAMEWORK_EXECUTOR_THREAD_WORKER_H_ /* vim: set expandtab ts=2 sw=2 sts=2 tw=100: */ diff --git a/paddle/fluid/pybind/async_executor_py.cc b/paddle/fluid/pybind/async_executor_py.cc index 29e36b7b19e..2f7bae06289 100644 --- a/paddle/fluid/pybind/async_executor_py.cc +++ b/paddle/fluid/pybind/async_executor_py.cc @@ -30,36 +30,32 @@ limitations under the License. */ #include "paddle/fluid/inference/io.h" #include "paddle/fluid/platform/variant.h" #include "paddle/fluid/platform/place.h" -#include "paddle/fluid/framework/async_executor_param.pb.h" +#include "paddle/fluid/framework/data_feed.pb.h" #include "paddle/fluid/framework/async_executor.h" #include "paddle/fluid/framework/data_feed.h" namespace py = pybind11; +namespace pd = paddle::framework; namespace paddle { namespace pybind { +using set_name_func = void (pd::DataFeedDesc::*)(const std::string&); void BindAsyncExecutor(py::module* m) { - py::class_(*m, "DataFeed"); - py::class_(*m, "TextDataFeed") - .def(py::init()) - .def("set_filelist", - [] (framework::TextClassDataFeed &self, const char *data_list_file) { - self.SetFileList(data_list_file); - }) - .def("set_batch_size", &framework::TextClassDataFeed::SetBatchSize) - .def("set_field_names", &framework::TextClassDataFeed::SetFieldNames) - .def("start_one_epoch", &framework::TextClassDataFeed::StartOneEpoch); + py::class_(*m, "DataFeedDesc") + .def(pybind11::init<>()) + .def("set_name", (set_name_func)&pd::DataFeedDesc::set_name) + .def("set_batch", &pd::DataFeedDesc::set_batch) + .def("set_field_names", + [] (pd::DataFeedDesc& self, const std::vector &fields) { + for (auto field : fields) { + self.add_field_names(field); + } + }); py::class_(*m, "AsyncExecutor") - .def(py::init&, - framework::TextClassDataFeed&, - unsigned int, - const platform::Place&>()) - .def("init_root_scope", &framework::AsyncExecutor::InitRootScope) - .def("run_startup_program", &framework::AsyncExecutor::RunStartupProgram) - .def("run", &framework::AsyncExecutor::Run); + .def(py::init()) + .def("run_from_files", &framework::AsyncExecutor::RunFromFile) + .def("check_file", &framework::AsyncExecutor::CheckFiles); } // end BindAsyncExecutor } // end namespace pybind } // end namespace paddle diff --git a/python/paddle/fluid/async_executor.py b/python/paddle/fluid/async_executor.py index 5cab03fdf06..9af7f97f18b 100644 --- a/python/paddle/fluid/async_executor.py +++ b/python/paddle/fluid/async_executor.py @@ -19,30 +19,26 @@ import contextlib import six from .framework import Program, default_main_program, Variable from . import core -from . import Executor +from .executor import global_scope -__all__ = ['TextDataFeed', 'AsyncExecutor'] +__all__ = ['MultiSlotDataFeed', 'AsyncExecutor'] g_scope = core.Scope() -class TextDataFeed(): +class DataFeedDesc(object): def __init__(self): - self.feed = core.TextDataFeed() - - def set_filelist(self, filelist): - self.feed.set_filelist(filelist) - + self.desc = core.DataFeedDesc() def set_batch_size(self, batch_size): - self.feed.set_batch_size(batch_size) - - def set_field_names(self, field_names): - if isinstance(field_names, Variable): + self.desc.set_batch(batch_size) + def set_field_name(self, field_names): + if isinstance(field_names, str): field_names = [field_names] + self.desc.set_field_names(field_names) - self.feed.set_field_names(field_names) - - def start_an_epoch(self): - self.feed.start_one_epoch() +class MultiSlotDataFeed(DataFeedDesc): + def __init__(self): + super(MultiSlotDataFeed, self).__init__() + self.desc.set_name("MultiSlotDataFeed") class AsyncExecutor(object): """ @@ -55,45 +51,19 @@ class AsyncExecutor(object): They has the exactly same arguments, and expected the same results. """ - def __init__(self, - program, - param_names, - data_feed, - thread_num, - place=None, - scope=None): - if program is None: - program = default_main_program() - program_desc = program.desc - - if not isinstance(data_feed, TextDataFeed): - raise ValueError("data_feed for AsyncExecutor.run() type error") - + def __init__(self, place=None): if place is None: place = core.CPUPlace() if not isinstance(place, core.CPUPlace): raise ValueError("AsyncExecutor only supports CPU device") - if isinstance(param_names, Variable): - param_names = [param_names] - p = core.Place() p.set_place(place) - self.executor = core.AsyncExecutor(program_desc, param_names, data_feed.feed, thread_num, p) - - def run_startup_program(self, - program=None, - scope=None): - if program is None: - program = default_startup_program() - program_desc = program._get_desc() - if scope is None: - scope = g_scope + scope = global_scope() + self.executor = core.AsyncExecutor(scope, p) - self.executor.run_startup_program(program_desc, scope) - - def run(self, inspect_vars, scope=None): + def run(self, program, data_feed, filelist, thread_num, fetch): """ Run program by this Executor. Feed data by feed map, fetch result by fetch_list. Python executor takes a program, add feed operators and fetch operators to this program according @@ -136,16 +106,27 @@ class AsyncExecutor(object): >>> feed={'X': x}, >>> fetch_list=[loss.name]) """ - if inspect_vars is not None: - if isinstance(inspect_vars, Variable): - inspect_vars = [inspect_vars] - inspect_var_names = [var.name for var in inspect_vars] + if program is None: + program = default_main_program() + program_desc = program.desc + + if data_feed is None: + raise ValueError('ValueError: data_feed should be provided') + + if filelist is None: + raise ValueError('ValueError: filelist should be provided') + + if isinstance(filelist, str): + filelist = [filelist] - if scope is None: - scope = g_scope + if not isinstance(thread_num, int): + raise TypeError('TypeError: thread_num should be a positive number') - self.executor.init_root_scope(scope) + if fetch is not None: + if isinstance(fetch, Variable): + fetch = [fetch] + fetch_var_names = [var.name for var in fetch] - evaluation = self.executor.run(inspect_var_names) + evaluation = self.executor.run_from_files(program_desc, data_feed.desc, filelist, thread_num, fetch_var_names) return evaluation -- GitLab