From 39014b9f9f7b45ed3e96d3487f299d483eca3e00 Mon Sep 17 00:00:00 2001 From: dongdaxiang Date: Sat, 2 Feb 2019 16:25:52 +0800 Subject: [PATCH] fix class register problem --- paddle/fluid/framework/async_executor.cc | 22 ++----------------- .../fluid/framework/device_worker_factory.cc | 21 ++++++++++++++++++ .../fluid/framework/device_worker_factory.h | 19 ---------------- paddle/fluid/framework/dist_multi_trainer.cc | 4 ++-- paddle/fluid/framework/downpour_worker.cc | 6 ++++- paddle/fluid/framework/hogwild_worker.cc | 1 - paddle/fluid/framework/multi_trainer.cc | 2 -- paddle/fluid/framework/trainer_factory.cc | 20 +++++++++++++++++ paddle/fluid/framework/trainer_factory.h | 17 -------------- 9 files changed, 50 insertions(+), 62 deletions(-) diff --git a/paddle/fluid/framework/async_executor.cc b/paddle/fluid/framework/async_executor.cc index 610ab9f302..59d8151f1e 100644 --- a/paddle/fluid/framework/async_executor.cc +++ b/paddle/fluid/framework/async_executor.cc @@ -64,7 +64,6 @@ void AsyncExecutor::InitModel() {} void AsyncExecutor::SaveModel(const std::string& path) {} void AsyncExecutor::RunFromFile(const ProgramDesc& main_program, -<<<<<<< HEAD const std::string& data_feed_desc_str, const std::vector& filelist, const int thread_num, @@ -153,25 +152,8 @@ void AsyncExecutor::RunFromFile(const ProgramDesc& main_program, _pull_dense_thread->stop(); } #endif -======= - const std::string& trainer_desc_str, - const bool debug) { - TrainerDesc trainer_desc; - google::protobuf::TextFormat::ParseFromString(trainer_desc_str, - &trainer_desc); - std::shared_ptr trainer; - trainer = TrainerFactory::CreateTrainer(trainer_desc.class_name()); - // initialize trainer - trainer->Initialize(trainer_desc); - trainer->SetScope(root_scope_); - trainer->SetDebug(debug); - // prepare training environment and helper environment - trainer->InitTrainerEnv(main_program, place_); - trainer->InitOtherEnv(main_program); - // training and finalize training - trainer->Run(); - trainer->Finalize(); ->>>>>>> add dist_multi_trainer for distributed training, add trainer_factory and device_worker_factory so that we can easily extend new training mode, add pull dense worker which is a singleton for parameter fetching + VLOG(3) << "start to run from files in async_executor"; + VLOG(3) << "Drop current scope kids"; root_scope_->DropKids(); return; diff --git a/paddle/fluid/framework/device_worker_factory.cc b/paddle/fluid/framework/device_worker_factory.cc index 7492ae041c..2a7b368145 100644 --- a/paddle/fluid/framework/device_worker_factory.cc +++ b/paddle/fluid/framework/device_worker_factory.cc @@ -20,6 +20,25 @@ limitations under the License. */ namespace paddle { namespace framework { +typedef std::shared_ptr (*Createdevice_workerFunction)(); +typedef std::unordered_map + device_workerMap; +device_workerMap g_device_worker_map; +#define REGISTER_DEVICE_WORKER_CLASS(device_worker_class) \ + namespace { \ + std::shared_ptr Creator_##device_worker_class() { \ + return std::shared_ptr(new device_worker_class); \ + } \ + class __Registerer_##device_worker_class { \ + public: \ + __Registerer_##device_worker_class() { \ + g_device_worker_map[#device_worker_class] = \ + &Creator_##device_worker_class; \ + } \ + }; \ + __Registerer_##device_worker_class g_registerer_##device_worker_class; \ + } // namespace + std::string DeviceWorkerFactory::DeviceWorkerTypeList() { std::string device_worker_types; for (auto iter = g_device_worker_map.begin(); @@ -40,5 +59,7 @@ std::shared_ptr DeviceWorkerFactory::CreateDeviceWorker( return g_device_worker_map[device_worker_class](); } +REGISTER_DEVICE_WORKER_CLASS(HogwildWorker); +REGISTER_DEVICE_WORKER_CLASS(DownpourWorker); } // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/device_worker_factory.h b/paddle/fluid/framework/device_worker_factory.h index 9b16d61099..9d0613385e 100644 --- a/paddle/fluid/framework/device_worker_factory.h +++ b/paddle/fluid/framework/device_worker_factory.h @@ -21,25 +21,6 @@ limitations under the License. */ namespace paddle { namespace framework { -typedef std::shared_ptr (*Createdevice_workerFunction)(); -typedef std::unordered_map - device_workerMap; -device_workerMap g_device_worker_map; -#define REGISTER_DEVICE_WORKER_CLASS(device_worker_class) \ - namespace { \ - std::shared_ptr Creator_##device_worker_class() { \ - return std::shared_ptr(new device_worker_class); \ - } \ - class __Registerer_##device_worker_class { \ - public: \ - __Registerer_##device_worker_class() { \ - g_device_worker_map[#device_worker_class] = \ - &Creator_##device_worker_class; \ - } \ - }; \ - __Registerer_##device_worker_class g_registerer_##device_worker_class; \ - } // namespace - class DeviceWorkerFactory { public: static std::string DeviceWorkerTypeList(); diff --git a/paddle/fluid/framework/dist_multi_trainer.cc b/paddle/fluid/framework/dist_multi_trainer.cc index 646409d521..45eb4ae0ea 100644 --- a/paddle/fluid/framework/dist_multi_trainer.cc +++ b/paddle/fluid/framework/dist_multi_trainer.cc @@ -17,7 +17,6 @@ limitations under the License. */ #include "paddle/fluid/framework/data_feed_factory.h" #include "paddle/fluid/framework/device_worker_factory.h" #include "paddle/fluid/framework/trainer.h" -#include "paddle/fluid/framework/trainer_factory.h" namespace paddle { namespace framework { @@ -48,11 +47,13 @@ void DistMultiTrainer::Initialize(const TrainerDesc& trainer_desc) { fleet_ptr_ = FleetWrapper::GetInstance(); pull_dense_worker_ = PullDenseWorker::GetInstance(); pull_dense_worker_->Initialize(trainer_desc); + VLOG(3) << "initialize pull dense worker"; } void DistMultiTrainer::InitOtherEnv(const ProgramDesc& main_program) { pull_dense_worker_->SetRootScope(root_scope_); pull_dense_worker_->Start(); + VLOG(3) << "init other env done."; } void DistMultiTrainer::Finalize() { @@ -62,6 +63,5 @@ void DistMultiTrainer::Finalize() { pull_dense_worker_->Stop(); } -REGISTER_TRAINER_CLASS(DistMultiTrainer); } // end namespace framework } // end namespace paddle diff --git a/paddle/fluid/framework/downpour_worker.cc b/paddle/fluid/framework/downpour_worker.cc index ff2fc3f89a..62126072c8 100644 --- a/paddle/fluid/framework/downpour_worker.cc +++ b/paddle/fluid/framework/downpour_worker.cc @@ -134,6 +134,7 @@ void DownpourWorker::FillSparseValue(size_t table_idx) { } void DownpourWorker::TrainFiles() { + VLOG(3) << "Begin to train files"; platform::SetNumThreads(1); device_reader_->Start(); int batch_cnt = 0; @@ -148,6 +149,7 @@ void DownpourWorker::TrainFiles() { CollectLabelInfo(i); FillSparseValue(i); } + VLOG(3) << "fill sparse value for all sparse table done."; // do computation here for (auto& op : ops_) { @@ -179,6 +181,7 @@ void DownpourWorker::TrainFiles() { *thread_scope_, tid, dense_grad_names_[tid], &push_sparse_status_); } + VLOG(3) << "push sparse and dense gradient done."; // the following code should be more precise and clean // TODO(guru4elephant) int32_t tmp_push_dense_wait_times = -1; @@ -210,16 +213,17 @@ void DownpourWorker::TrainFiles() { push_sparse_status_.resize(0); } + /* for (size_t i = 0; i < param_.dense_table_size(); ++i) { uint64_t tid = static_cast(param_.dense_table(i).table_id()); pull_dense_worker_->IncreaseThreadVersion(thread_id_, tid); } + */ thread_scope_->DropKids(); ++batch_cnt; } } -REGISTER_DEVICE_WORKER_CLASS(DownpourWorker); } // end namespace framework } // end namespace paddle diff --git a/paddle/fluid/framework/hogwild_worker.cc b/paddle/fluid/framework/hogwild_worker.cc index a9c23fd63c..9b603d9f13 100644 --- a/paddle/fluid/framework/hogwild_worker.cc +++ b/paddle/fluid/framework/hogwild_worker.cc @@ -129,6 +129,5 @@ void HogwildWorker::TrainFiles() { } } -REGISTER_DEVICE_WORKER_CLASS(HogwildWorker); } // end namespace framework } // end namespace paddle diff --git a/paddle/fluid/framework/multi_trainer.cc b/paddle/fluid/framework/multi_trainer.cc index b8e2f0aff1..969d27c8ef 100644 --- a/paddle/fluid/framework/multi_trainer.cc +++ b/paddle/fluid/framework/multi_trainer.cc @@ -17,7 +17,6 @@ limitations under the License. */ #include "paddle/fluid/framework/data_feed_factory.h" #include "paddle/fluid/framework/device_worker_factory.h" #include "paddle/fluid/framework/trainer.h" -#include "paddle/fluid/framework/trainer_factory.h" namespace paddle { namespace framework { @@ -66,6 +65,5 @@ void MultiTrainer::Finalize() { } } -REGISTER_TRAINER_CLASS(MultiTrainer); } // end namespace framework } // end namespace paddle diff --git a/paddle/fluid/framework/trainer_factory.cc b/paddle/fluid/framework/trainer_factory.cc index 915d0c3555..6b4461c0c4 100644 --- a/paddle/fluid/framework/trainer_factory.cc +++ b/paddle/fluid/framework/trainer_factory.cc @@ -22,8 +22,24 @@ limitations under the License. */ namespace paddle { namespace framework { +typedef std::shared_ptr (*CreatetrainerFunction)(); +typedef std::unordered_map trainerMap; trainerMap g_trainer_map; +#define REGISTER_TRAINER_CLASS(trainer_class) \ + namespace { \ + std::shared_ptr Creator_##trainer_class() { \ + return std::shared_ptr(new trainer_class); \ + } \ + class __Registerer_##trainer_class { \ + public: \ + __Registerer_##trainer_class() { \ + g_trainer_map[#trainer_class] = &Creator_##trainer_class; \ + } \ + }; \ + __Registerer_##trainer_class g_registerer_##trainer_class; \ + } // namespace + std::string TrainerFactory::TrainerTypeList() { std::string trainer_types; for (auto iter = g_trainer_map.begin(); iter != g_trainer_map.end(); ++iter) { @@ -38,10 +54,14 @@ std::string TrainerFactory::TrainerTypeList() { std::shared_ptr TrainerFactory::CreateTrainer( std::string trainer_class) { if (g_trainer_map.count(trainer_class) < 1) { + LOG(WARNING) << "Trainer class: " << trainer_class << " not defined"; + LOG(WARNING) << TrainerTypeList(); exit(-1); } return g_trainer_map[trainer_class](); } +REGISTER_TRAINER_CLASS(MultiTrainer); +REGISTER_TRAINER_CLASS(DistMultiTrainer); } // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/trainer_factory.h b/paddle/fluid/framework/trainer_factory.h index 89348fd3c7..9c772a4f19 100644 --- a/paddle/fluid/framework/trainer_factory.h +++ b/paddle/fluid/framework/trainer_factory.h @@ -20,23 +20,6 @@ limitations under the License. */ namespace paddle { namespace framework { -typedef std::shared_ptr (*CreatetrainerFunction)(); -typedef std::unordered_map trainerMap; -extern trainerMap g_trainer_map; - -#define REGISTER_TRAINER_CLASS(trainer_class) \ - namespace { \ - std::shared_ptr Creator_##trainer_class() { \ - return std::shared_ptr(new trainer_class); \ - } \ - class __Registerer_##trainer_class { \ - public: \ - __Registerer_##trainer_class() { \ - g_trainer_map[#trainer_class] = &Creator_##trainer_class; \ - } \ - }; \ - __Registerer_##trainer_class g_registerer_##trainer_class; \ - } // namespace class TrainerFactory { public: -- GitLab