From b415ec27e8791f40f2d07fed7c65e44f2804efce Mon Sep 17 00:00:00 2001 From: dongdaxiang Date: Sat, 9 Mar 2019 22:08:33 +0800 Subject: [PATCH] make Dataset* as an argument --- paddle/fluid/framework/CMakeLists.txt | 2 +- paddle/fluid/framework/data_set.cc | 2 +- paddle/fluid/framework/data_set.h | 2 +- paddle/fluid/framework/dist_multi_trainer.cc | 17 ++---- paddle/fluid/framework/executor.cc | 11 ++-- paddle/fluid/framework/executor.h | 9 +-- paddle/fluid/framework/multi_trainer.cc | 25 ++------ python/paddle/fluid/distributed/fleet.py | 63 ++++++++++++++++++++ python/paddle/fluid/executor.py | 20 +++++++ python/paddle/fluid/trainer.py | 16 ----- python/paddle/fluid/trainer_factory.py | 32 ++++++++++ 11 files changed, 134 insertions(+), 65 deletions(-) create mode 100644 python/paddle/fluid/distributed/fleet.py delete mode 100644 python/paddle/fluid/trainer.py create mode 100644 python/paddle/fluid/trainer_factory.py diff --git a/paddle/fluid/framework/CMakeLists.txt b/paddle/fluid/framework/CMakeLists.txt index e6e4a2ce4..24c181e8c 100644 --- a/paddle/fluid/framework/CMakeLists.txt +++ b/paddle/fluid/framework/CMakeLists.txt @@ -30,7 +30,7 @@ add_subdirectory(io) proto_library(framework_proto SRCS framework.proto) proto_library(data_feed_proto SRCS data_feed.proto) proto_library(async_executor_proto SRCS data_feed.proto) -proto_library(trainer_desc_proto SRCS trainer_desc.proto) +proto_library(trainer_desc_proto SRCS trainer_desc.proto data_feed.proto) cc_library(ddim SRCS ddim.cc DEPS eigen3 boost enforce) cc_test(ddim_test SRCS ddim_test.cc DEPS ddim) diff --git a/paddle/fluid/framework/data_set.cc b/paddle/fluid/framework/data_set.cc index 457ae9360..baa971cde 100644 --- a/paddle/fluid/framework/data_set.cc +++ b/paddle/fluid/framework/data_set.cc @@ -52,7 +52,7 @@ void Dataset::SetDataFeedDesc(const std::string& data_feed_desc_str) { data_feed_desc_str, &data_feed_desc_); } -std::vector> +const std::vector>& Dataset::GetReaders() { return readers_; } diff --git a/paddle/fluid/framework/data_set.h b/paddle/fluid/framework/data_set.h index 06f47da32..f99dc1470 100644 --- a/paddle/fluid/framework/data_set.h +++ b/paddle/fluid/framework/data_set.h @@ -43,7 +43,7 @@ class Dataset { return data_feed_desc_; } - virtual std::vector> + virtual const std::vector>& GetReaders(); virtual void LoadIntoMemory(); virtual void LocalShuffle(); diff --git a/paddle/fluid/framework/dist_multi_trainer.cc b/paddle/fluid/framework/dist_multi_trainer.cc index cbfd29501..9997da019 100644 --- a/paddle/fluid/framework/dist_multi_trainer.cc +++ b/paddle/fluid/framework/dist_multi_trainer.cc @@ -15,6 +15,7 @@ limitations under the License. */ #include #include #include "paddle/fluid/framework/data_feed_factory.h" +#include "paddle/fluid/framework/data_set.h" #include "paddle/fluid/framework/device_worker_factory.h" #include "paddle/fluid/framework/trainer.h" @@ -25,26 +26,18 @@ void DistMultiTrainer::Initialize(const TrainerDesc& trainer_desc, Dataset* data_set) { thread_num_ = trainer_desc.thread_num(); workers_.resize(thread_num_); - readers_.resize(thread_num_); + + const std::vector> readers = + data_set->GetReaders(); for (int i = 0; i < thread_num_; ++i) { workers_[i] = DeviceWorkerFactory::CreateDeviceWorker( trainer_desc.device_worker_name()); - readers_[i] = - DataFeedFactory::CreateDataFeed(trainer_desc.data_desc().name()); workers_[i]->SetDeviceIndex(i); - readers_[i]->Init(trainer_desc.data_desc()); - workers_[i]->SetDataFeed(readers_[i]); + workers_[i]->SetDataFeed(readers[i]); workers_[i]->Initialize(trainer_desc); } - std::vector filelist_vec; - for (unsigned i = 0; i < trainer_desc.filelist_size(); ++i) { - filelist_vec.push_back(trainer_desc.filelist(i)); - } - - readers_[0]->SetFileList(filelist_vec); - fleet_ptr_ = FleetWrapper::GetInstance(); pull_dense_worker_ = PullDenseWorker::GetInstance(); pull_dense_worker_->Initialize(trainer_desc); diff --git a/paddle/fluid/framework/executor.cc b/paddle/fluid/framework/executor.cc index 9eccea7ac..9ba50ff9e 100644 --- a/paddle/fluid/framework/executor.cc +++ b/paddle/fluid/framework/executor.cc @@ -116,10 +116,9 @@ void Executor::CreateVariables(const ProgramDesc& pdesc, Scope* scope, } } -void Executor::RunFromDataset(const ProgramDesc& main_program, +void Executor::RunFromDataset(const ProgramDesc& main_program, Scope* scope, Dataset* dataset, - const std::string& trainer_desc_str, - const bool debug) { + const std::string& trainer_desc_str) { VLOG(3) << "Start to RunFromDataset in executor"; TrainerDesc trainer_desc; google::protobuf::TextFormat::ParseFromString(trainer_desc_str, @@ -132,9 +131,7 @@ void Executor::RunFromDataset(const ProgramDesc& main_program, VLOG(3) << "Going to initialize trainer"; trainer->Initialize(trainer_desc, dataset); VLOG(3) << "Set root scope here"; - trainer->SetScope(root_scope_); - VLOG(3) << "Going to set debug"; - trainer->SetDebug(debug); + trainer->SetScope(scope); // prepare training environment and helper environment VLOG(3) << "Try to init train environment"; trainer->InitTrainerEnv(main_program, place_); @@ -146,7 +143,7 @@ void Executor::RunFromDataset(const ProgramDesc& main_program, VLOG(3) << "Trainer going to finalize"; trainer->Finalize(); VLOG(3) << "Drop current scope kids"; - root_scope_->DropKids(); + scope->DropKids(); return; } diff --git a/paddle/fluid/framework/executor.h b/paddle/fluid/framework/executor.h index 6368d9b38..1a0ae48b8 100644 --- a/paddle/fluid/framework/executor.h +++ b/paddle/fluid/framework/executor.h @@ -114,16 +114,11 @@ class Executor { void EnableMKLDNN(const ProgramDesc& program); - void RunFromDataset(const ProgramDesc& main_program, Dataset* dataset, - const std::string& trainer_desc_str, const bool debug); - - public: - std::shared_ptr fleet_ptr_; - Scope* root_scope_; + void RunFromDataset(const ProgramDesc& main_program, Scope* scope, + Dataset* dataset, const std::string& trainer_desc_str); private: const platform::Place place_; - int actual_thread_num_; }; } // namespace framework diff --git a/paddle/fluid/framework/multi_trainer.cc b/paddle/fluid/framework/multi_trainer.cc index 7d9b6839e..0da4fa863 100644 --- a/paddle/fluid/framework/multi_trainer.cc +++ b/paddle/fluid/framework/multi_trainer.cc @@ -26,31 +26,16 @@ void MultiTrainer::Initialize(const TrainerDesc& trainer_desc, thread_num_ = trainer_desc.thread_num(); // get filelist from trainer_desc here workers_.resize(thread_num_); - - /* - if (NULL == dataset) { - readers_.resize(thread_num_); - for (int i = 0; i < thread_num_; ++i) { - readers_[i] = - DataFeedFactory::CreateDataFeed(trainer_desc.data_desc().name()); - readers_[i]->Init(trainer_desc.data_desc()); - } - std::vector filelist_vec; - for (unsigned i = 0; i < trainer_desc.filelist_size(); ++i) { - filelist_vec.push_back(trainer_desc.filelist(i)); - } - readers_[0]->SetFileList(filelist_vec); - } else { - // readers_ = dataset.get_readers(); ? - } - */ - + const std::vector> readers = + dataset->GetReaders(); for (int i = 0; i < thread_num_; ++i) { workers_[i] = DeviceWorkerFactory::CreateDeviceWorker( trainer_desc.device_worker_name()); workers_[i]->SetDeviceIndex(i); - workers_[i]->SetDataFeed(readers_[i]); + workers_[i]->SetDataFeed(readers[i]); } + + // set debug here } // call only after all resources are set in current trainer diff --git a/python/paddle/fluid/distributed/fleet.py b/python/paddle/fluid/distributed/fleet.py new file mode 100644 index 000000000..386ced0ee --- /dev/null +++ b/python/paddle/fluid/distributed/fleet.py @@ -0,0 +1,63 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +from .. import core + +__all__ = ['Fleet'] + + +class Fleet(object): + """ + + """ + + def __init__(self): + self.instance_ = ps_instance.PaddlePSInstance() + self.fleet_ = core.FleetWrapper() + + def stop(self): + self.instance_.barrier_worker() + if self.instance.is_first_worker(): + self.fleet_.stop_server() + self.instance_.barrier_worker() + self.instance_.barrier_all() + self.instance.finalize() + + def init_pserver(self, dist_desc): + self.dist_desc_str_ = text_format.MessageToString(dist_desc) + self.dist_desc = dist_desc + self.fleet_.init_server(self.dist_desc_str_) + ip = self.fleet_.start_server() + self.instance_.set_ip(ip) + self.instance.barrier_all() + ips = self.instance.gather_ips() + self.fleet.gather_servers(ips, self.instance_.get_node_cnt()) + self.instance_.barrier_all() + + def init_worker(self, dist_desc): + self.dist_desc_str_ = text_format.MessageToString(dist_desc) + self.dist_desc_ = dist_desc + + self.instance_.barrier_all() + ips = self.instance.gather_ips() + self.fleet_.init_worker(self.dist_desc_str_, ips, + self.instance_.get_node_cnt(), + self.instance._rankid) + self.instance.barrier_worker() + + def init_pserver_model(self): + if self.instance_.is_first_worker(): + self.fleet_.init_model() + self.instance_.barrier_worker() + + def save_pserver_model(self, save_path): + self.fleet_.save_model(save_path) diff --git a/python/paddle/fluid/executor.py b/python/paddle/fluid/executor.py index 018e38cbb..98a16e201 100644 --- a/python/paddle/fluid/executor.py +++ b/python/paddle/fluid/executor.py @@ -610,3 +610,23 @@ class Executor(object): def _run_inference(self, exe, feed): return exe.run(feed) + + def run_from_dataset(self, + program=None, + dataset=None, + fetch_list=None, + scope=None, + opt_info=None): + if scope is None: + scope = global_scope() + if fetch_list is None: + fetch_list = [] + compiled = isinstance(program, compiler.CompiledProgram) + if not compiled: + trainer = TrainerFactory().create_trainer(opt_info) + self._default_executor.run_from_dataset(program_desc, + trainer._desc()) + else: + # For compiled program, more runtime should be implemented + print("run_from_dataset current does not support compiled program" + ", we will support this later", sys.stderr) diff --git a/python/paddle/fluid/trainer.py b/python/paddle/fluid/trainer.py deleted file mode 100644 index b495b6699..000000000 --- a/python/paddle/fluid/trainer.py +++ /dev/null @@ -1,16 +0,0 @@ -# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# NOTE: Trainer is moved into fluid.contrib.trainer. -__all__ = [] diff --git a/python/paddle/fluid/trainer_factory.py b/python/paddle/fluid/trainer_factory.py new file mode 100644 index 000000000..1b413b05d --- /dev/null +++ b/python/paddle/fluid/trainer_factory.py @@ -0,0 +1,32 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +__all__ = ["TrainerFactory"] + + +class TrainerFactory(object): + def __init__(self): + pass + + def create_trainer(self, opt_info=None): + if opt_info == None: + return MultiTrainer() + else: + if opt_info["optimizer"] == "DownpourSGD": + trainer = DistMultiTrainer() + trainer.gen_trainer_desc( + fleet_desc=opt_info["fleet"], worker="downpour") + return trainer + else: + print("Currently only support DownpourSGD") -- GitLab