diff --git a/paddle/fluid/operators/distributed/communicator.cc b/paddle/fluid/operators/distributed/communicator.cc index 13e89092ca7f8443df65b040096d36f61c1258e8..50224028abc2fcff43918c695be2cbab8997a997 100644 --- a/paddle/fluid/operators/distributed/communicator.cc +++ b/paddle/fluid/operators/distributed/communicator.cc @@ -63,6 +63,43 @@ inline void VSUB(int n, const T *x, const T *y, T *z) { } } +void Communicator::SetEnvFlagsDefault() { + env_flags_dict.clear(); + env_flags_dict.insert(std::pair( + "independent_recv_thread", FLAGS_communicator_independent_recv_thread)); + env_flags_dict.insert(std::pair( + "send_queue_size", FLAGS_communicator_send_queue_size)); + env_flags_dict.insert(std::pair( + "min_send_grad_num_before_recv", + FLAGS_communicator_min_send_grad_num_before_recv)); + env_flags_dict.insert(std::pair( + "thread_pool_size", FLAGS_communicator_thread_pool_size)); + env_flags_dict.insert(std::pair( + "send_wait_times", FLAGS_communicator_send_wait_times)); + env_flags_dict.insert(std::pair( + "max_merge_var_num", FLAGS_communicator_max_merge_var_num)); + env_flags_dict.insert( + std::pair("fake_rpc", FLAGS_communicator_fake_rpc)); + env_flags_dict.insert(std::pair( + "merge_sparse_grad", FLAGS_communicator_merge_sparse_grad)); + env_flags_dict.insert(std::pair( + "is_sgd_optimizer", FLAGS_communicator_is_sgd_optimizer)); + + return; +} + +Communicator::Communicator() { SetEnvFlagsDefault(); } + +Communicator::Communicator(const std::map &env_flags) { + SetEnvFlagsDefault(); + for (auto &iter : env_flags) { + std::string flag_name = iter.first; + int val_ = iter.second; + env_flags_dict.at(flag_name) = val_; + } + return; +} + std::once_flag Communicator::init_flag_; std::shared_ptr Communicator::communicator_(nullptr); @@ -73,25 +110,6 @@ void AsyncCommunicator::InitImpl(const RpcCtxMap &send_varname_to_ctx, recv_varname_to_ctx_ = std::move(recv_varname_to_ctx); recv_scope_ = std::move(recv_scope); - // get all send information from graph, build vars_to_send - VLOG(0) << "communicator_independent_recv_thread: " - << FLAGS_communicator_independent_recv_thread; - VLOG(0) << "communicator_send_queue_size: " - << FLAGS_communicator_send_queue_size; - VLOG(0) << "communicator_min_send_grad_num_before_recv: " - << FLAGS_communicator_min_send_grad_num_before_recv; - VLOG(0) << "communicator_thread_pool_size: " - << FLAGS_communicator_thread_pool_size; - VLOG(0) << "communicator_send_wait_times: " - << FLAGS_communicator_send_wait_times; - VLOG(0) << "communicator_max_merge_var_num: " - << FLAGS_communicator_max_merge_var_num; - VLOG(0) << "communicator_fake_rpc: " << FLAGS_communicator_fake_rpc; - VLOG(0) << "communicator_merge_sparse_grad: " - << FLAGS_communicator_merge_sparse_grad; - VLOG(0) << "communicator_is_sgd_optimizer: " - << FLAGS_communicator_is_sgd_optimizer; - if (send_varname_to_ctx.size() == 0) { VLOG(0) << "nothing need to be send, will not start send_thread"; } else { @@ -99,17 +117,17 @@ void AsyncCommunicator::InitImpl(const RpcCtxMap &send_varname_to_ctx, for (auto &iter : send_varname_to_ctx_) { send_varname_to_queue_[iter.first] = std::make_shared>>( - FLAGS_communicator_send_queue_size); + env_flags_dict["send_queue_size"]); } send_threadpool_.reset( - new ::ThreadPool(FLAGS_communicator_thread_pool_size)); + new ::ThreadPool(env_flags_dict["thread_pool_size"])); } if (recv_varname_to_ctx.size() == 0) { VLOG(0) << "nothing need to be received, will not start recv_thread"; } else { recv_threadpool_.reset( - new ::ThreadPool(FLAGS_communicator_thread_pool_size)); + new ::ThreadPool(env_flags_dict["thread_pool_size"])); } } @@ -132,7 +150,7 @@ void AsyncCommunicator::InitImpl(const paddle::framework::ProgramDesc &program, auto trainer_id = boost::get(op->GetNullableAttr("trainer_id")); auto merge_add = boost::get(op->GetNullableAttr("merge_add")); if (!merge_add) { - merge_add = FLAGS_communicator_is_sgd_optimizer; + merge_add = static_cast(env_flags_dict["is_sgd_optimizer"]); } auto use_send_handler = boost::get(op->GetNullableAttr("use_send_handler")); @@ -194,10 +212,10 @@ void AsyncCommunicator::SendThread() { std::vector> vars; int merged_var_num = 0; int wait_times = 0; - while (merged_var_num < FLAGS_communicator_max_merge_var_num) { + while (merged_var_num < env_flags_dict["max_merge_var_num"]) { if (var_queue->Size() == 0) { VLOG(4) << "wait_times -> " << wait_times; - if (wait_times >= FLAGS_communicator_send_wait_times) { + if (wait_times >= env_flags_dict["send_wait_times"]) { break; } std::this_thread::sleep_for(std::chrono::milliseconds(10)); @@ -226,7 +244,7 @@ void AsyncCommunicator::SendThread() { VLOG(4) << "merge " << merged_var_num << " " << var_name << " use time " << after_merge - before_merge; auto send_functor = distributed::ParameterSend(); - if (!FLAGS_communicator_fake_rpc) { + if (!env_flags_dict["fake_rpc"]) { send_functor(ctx, *send_scope_, true, 1); } auto after_send = GetCurrentUS(); @@ -255,7 +273,7 @@ void AsyncCommunicator::RecvThread() { VLOG(3) << "RecvThread start!"; while (running_) { int grad_num = grad_num_.load(); - if (grad_num > FLAGS_communicator_min_send_grad_num_before_recv) { + if (grad_num > env_flags_dict["min_send_grad_num_before_recv"]) { VLOG(1) << "current grad num " << grad_num; RecvAll(); grad_num_.store(0); @@ -273,10 +291,10 @@ void AsyncCommunicator::Send(const std::string &var_name, auto *grad_var = scope.FindVar(var_name); PADDLE_ENFORCE(grad_var->IsInitialized(), "grad var should be inited"); if (grad_var->IsType() && - !FLAGS_communicator_merge_sparse_grad) { + !env_flags_dict["merge_sparse_grad"]) { auto send_functor = distributed::ParameterSend(); auto &ctx = send_varname_to_ctx_.at(var_name); - if (!FLAGS_communicator_fake_rpc) { + if (!env_flags_dict["fake_rpc"]) { send_functor(ctx, scope, true, 1); } } else { @@ -289,7 +307,7 @@ void AsyncCommunicator::Send(const std::string &var_name, } void AsyncCommunicator::Recv() { - if (FLAGS_communicator_independent_recv_thread) { + if (env_flags_dict["independent_recv_thread"]) { return; } @@ -313,7 +331,7 @@ void AsyncCommunicator::RecvAll() { auto &var_name = iter.first; VLOG(4) << "recv var " << var_name; auto recv_functor = distributed::ParameterRecv(); - if (!FLAGS_communicator_fake_rpc) { + if (!env_flags_dict["fake_rpc"]) { recv_functor(iter.second, *recv_scope_); } }; @@ -336,7 +354,7 @@ void AsyncCommunicator::Start() { // start send and recv thread send_thread_.reset( new std::thread(std::bind(&AsyncCommunicator::SendThread, this))); - if (FLAGS_communicator_independent_recv_thread) { + if (env_flags_dict["independent_recv_thread"]) { recv_thread_.reset( new std::thread(std::bind(&AsyncCommunicator::RecvThread, this))); } @@ -396,25 +414,8 @@ void GeoSgdCommunicator::InitImpl( geo_need_push_nums_ = std::move(geo_need_push_nums); // get all send information from graph, build vars_to_send - VLOG(0) << "communicator_independent_recv_thread: " - << FLAGS_communicator_independent_recv_thread; - VLOG(0) << "communicator_send_queue_size: " - << FLAGS_communicator_send_queue_size; - VLOG(0) << "communicator_min_send_grad_num_before_recv: " - << FLAGS_communicator_min_send_grad_num_before_recv; - VLOG(0) << "communicator_thread_pool_size: " - << FLAGS_communicator_thread_pool_size; - VLOG(0) << "communicator_send_wait_times: " - << FLAGS_communicator_send_wait_times; - VLOG(0) << "communicator_max_merge_var_num: " - << FLAGS_communicator_max_merge_var_num; - VLOG(0) << "communicator_fake_rpc: " << FLAGS_communicator_fake_rpc; - VLOG(0) << "communicator_merge_sparse_grad: " - << FLAGS_communicator_merge_sparse_grad; VLOG(0) << "Trainer nums: " << trainer_nums_; VLOG(0) << "geo_sgd_push_before_local_train_nums: " << geo_need_push_nums_; - VLOG(0) << "communicator_merge_sparse_bucket " - << FLAGS_communicator_merge_sparse_bucket; // process var info from transpiler for (auto &iter : vars_info) { @@ -461,7 +462,7 @@ void GeoSgdCommunicator::InitImpl( LOG(WARNING) << "no var need to send and recv!!"; } - send_threadpool_.reset(new ::ThreadPool(FLAGS_communicator_thread_pool_size)); + send_threadpool_.reset(new ::ThreadPool(env_flags_dict["thread_pool_size"])); need_push_queue_ = std::make_shared>>( geo_need_push_nums); @@ -570,7 +571,7 @@ void GeoSgdCommunicator::SendThread() { VLOG(4) << "ids_send_vec_ pushed"; } else if (need_push_queue_->Size() == 0) { VLOG(4) << "wait_times -> " << wait_times; - if (wait_times >= FLAGS_communicator_send_wait_times) { + if (wait_times >= env_flags_dict["send_wait_times"]) { break; } std::this_thread::sleep_for(std::chrono::milliseconds(10)); diff --git a/paddle/fluid/operators/distributed/communicator.h b/paddle/fluid/operators/distributed/communicator.h index c22ab9a798c6703d860f96865b34323a8cbc7444..b37163640a6f0cff67f42b3e8adcdcac851f5adf 100644 --- a/paddle/fluid/operators/distributed/communicator.h +++ b/paddle/fluid/operators/distributed/communicator.h @@ -174,9 +174,12 @@ using RpcCtxMap = std::unordered_map; class Communicator { public: - Communicator() {} + Communicator(); + explicit Communicator(const std::map& env_flags); virtual ~Communicator() {} + virtual void SetEnvFlagsDefault(); + virtual void Start() = 0; virtual void Stop() = 0; virtual bool IsRunning() { return running_; } @@ -221,9 +224,10 @@ class Communicator { template static Communicator* InitInstance( - const paddle::framework::ProgramDesc& program, Scope* recv_scope) { + const paddle::framework::ProgramDesc& program, Scope* recv_scope, + const std::map& env_flags) { std::call_once(init_flag_, &Communicator::InitWithProgram, program, - recv_scope); + recv_scope, std::ref(env_flags)); return communicator_.get(); } @@ -232,10 +236,12 @@ class Communicator { const paddle::framework::ProgramDesc& program, Scope* training_scope, std::map>>& vars_info, - const int& trainers, const int& geo_need_push_nums) { + const int& trainers, const int& geo_need_push_nums, + const std::map& env_flags) { std::call_once(init_flag_, &Communicator::InitWithTranspilerInfo, program, training_scope, std::ref(vars_info), - std::ref(trainers), std::ref(geo_need_push_nums)); + std::ref(trainers), std::ref(geo_need_push_nums), + std::ref(env_flags)); return communicator_.get(); } @@ -253,9 +259,10 @@ class Communicator { template static void InitWithProgram(const paddle::framework::ProgramDesc& program, - Scope* recv_scope) { + Scope* recv_scope, + const std::map& env_flags) { if (communicator_.get() == nullptr) { - communicator_.reset(new T()); + communicator_.reset(new T(std::ref(env_flags))); communicator_->InitImpl(program, recv_scope); } } @@ -265,9 +272,10 @@ class Communicator { const paddle::framework::ProgramDesc& program, Scope* training_scope, std::map>>& vars_info, - const int& trainers, const int& geo_need_push_nums) { + const int& trainers, const int& geo_need_push_nums, + const std::map& env_flags) { if (communicator_.get() == nullptr) { - communicator_.reset(new T()); + communicator_.reset(new T(std::ref(env_flags))); communicator_->InitImpl(program, training_scope, std::ref(vars_info), std::ref(trainers), std::ref(geo_need_push_nums)); } @@ -277,6 +285,7 @@ class Communicator { bool running_ = false; static std::shared_ptr communicator_; static std::once_flag init_flag_; + std::unordered_map env_flags_dict; }; using SparseIdsMap = @@ -284,7 +293,9 @@ using SparseIdsMap = class AsyncCommunicator : public Communicator { public: - AsyncCommunicator() {} + AsyncCommunicator() : Communicator() {} + explicit AsyncCommunicator(const std::map& env_flags) + : Communicator(env_flags) {} ~AsyncCommunicator(); void Start() override; void Stop() override; @@ -331,7 +342,9 @@ class AsyncCommunicator : public Communicator { class GeoSgdCommunicator : public Communicator { public: - GeoSgdCommunicator() {} + GeoSgdCommunicator() : Communicator() {} + explicit GeoSgdCommunicator(const std::map& env_flags) + : Communicator(env_flags) {} ~GeoSgdCommunicator(); void InitImpl( const paddle::framework::ProgramDesc& program, Scope* training_scope, diff --git a/paddle/fluid/operators/distributed_ops/listen_and_serv_op.cc b/paddle/fluid/operators/distributed_ops/listen_and_serv_op.cc index 05b95dbd4e6e995a3d2bb5df595294d2d1236c89..05df65135cfca10d11b210a366f74d0999bb3227 100644 --- a/paddle/fluid/operators/distributed_ops/listen_and_serv_op.cc +++ b/paddle/fluid/operators/distributed_ops/listen_and_serv_op.cc @@ -356,6 +356,10 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope, rpc_service_.reset(new RPCSERVER_T(endpoint, fan_in)); + auto rpc_get_thread_num = Attr("rpc_get_thread_num"); + auto rpc_send_thread_num = Attr("rpc_send_thread_num"); + auto rpc_prefetch_thread_num = Attr("rpc_prefetch_thread_num"); + request_send_handler_.reset( new distributed::RequestSendHandler(sync_mode, dc_sgd)); request_get_handler_.reset( @@ -370,21 +374,18 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope, new distributed::RequestNotifyHandler(sync_mode, lr_decay_block_id)); rpc_service_->RegisterRPC(distributed::kRequestSend, - request_send_handler_.get(), - FLAGS_rpc_send_thread_num); + request_send_handler_.get(), rpc_send_thread_num); rpc_service_->RegisterRPC(distributed::kRequestGet, - request_get_handler_.get(), - FLAGS_rpc_get_thread_num); + request_get_handler_.get(), rpc_get_thread_num); rpc_service_->RegisterRPC(distributed::kRequestPrefetch, request_prefetch_handler_.get(), - FLAGS_rpc_prefetch_thread_num); + rpc_prefetch_thread_num); rpc_service_->RegisterRPC(distributed::kRequestCheckpoint, request_checkpoint_handler_.get()); rpc_service_->RegisterRPC(distributed::kRequestGetNoBarrier, request_get_no_barrier_handler_.get()); rpc_service_->RegisterRPC(distributed::kRequestNotify, - request_notify_handler_.get(), - FLAGS_rpc_send_thread_num); + request_notify_handler_.get(), rpc_send_thread_num); auto optimize_blocks = Attr>(kOptimizeBlocks); @@ -549,6 +550,11 @@ class ListenAndServOpMaker : public framework::OpProtoAndCheckerMaker { .SetDefault(-1); AddAttr(kLRDecayBlockId, "BolckID to run lr decay on pserer.") .SetDefault(-1); + AddAttr("rpc_get_thread_num", "pserver get thread num.").SetDefault(1); + AddAttr("rpc_send_thread_num", "pserver send thread num.") + .SetDefault(1); + AddAttr("rpc_prefetch_thread_num", "pserver prefetch thread num.") + .SetDefault(1); } }; diff --git a/paddle/fluid/pybind/communicator_py.cc b/paddle/fluid/pybind/communicator_py.cc index 45e8b8ad93ef8c9b478fd5a9df5e15b93be82568..ad507ec111f203fc538cb5de4604647bde029db4 100644 --- a/paddle/fluid/pybind/communicator_py.cc +++ b/paddle/fluid/pybind/communicator_py.cc @@ -39,19 +39,23 @@ void BindCommunicator(py::module* m) { // Communicator is already used by nccl, change to DistCommunicator py::class_>(*m, "DistCommunicator") - .def(py::init([](const ProgramDesc& program, Scope* param_scope) { + .def(py::init([](const ProgramDesc& program, Scope* param_scope, + std::map& env_flags) { VLOG(0) << "using communicator"; - Communicator::InitInstance(program, param_scope); + Communicator::InitInstance(program, param_scope, + env_flags); return Communicator::GetInstantcePtr(); })) .def(py::init([]( const ProgramDesc& program, Scope* training_scope, std::map>>& vars_info, - int& trainers, int& geo_need_push_nums) { + int& trainers, int& geo_need_push_nums, + std::map& env_flags) { VLOG(0) << "using geo sgd communicator"; Communicator::InitInstance( - program, training_scope, vars_info, trainers, geo_need_push_nums); + program, training_scope, vars_info, trainers, geo_need_push_nums, + env_flags); return Communicator::GetInstantcePtr(); })) .def("stop", &Communicator::Stop) diff --git a/python/paddle/fluid/communicator.py b/python/paddle/fluid/communicator.py index e6caeb006caacc106a21fee40fd28125e436ae6c..8b27b774954f1a7dc06fb7b9393098c63d9f6219 100644 --- a/python/paddle/fluid/communicator.py +++ b/python/paddle/fluid/communicator.py @@ -28,7 +28,8 @@ class Communicator(object): program, vars_info=None, trainers=None, - geo_sgd_need_push_nums=None): + geo_sgd_need_push_nums=None, + env_flags=None): """ Communicator is used for async distribute training in distribute_transpiler mode. It's a wrapper of a cpp class Communicator and should be used inside fleet API. @@ -56,14 +57,19 @@ class Communicator(object): if op.type == "recv": op._set_attr('do_not_run', True) # Todo: Add check + if env_flags is None: + env_flags = {} + if vars_info and trainers and geo_sgd_need_push_nums: # for geo sgd self.communicator_ = core.DistCommunicator( program.desc, - global_scope(), vars_info, trainers, geo_sgd_need_push_nums) + global_scope(), vars_info, trainers, geo_sgd_need_push_nums, + env_flags) else: self.communicator_ = core.DistCommunicator(program.desc, - global_scope()) + global_scope(), + env_flags) def start(self): """ diff --git a/python/paddle/fluid/incubate/fleet/parameter_server/distribute_transpiler/__init__.py b/python/paddle/fluid/incubate/fleet/parameter_server/distribute_transpiler/__init__.py index 81e6ecc11993084be7dd37e642cf9b2a3efff0c7..e334c48ed3962c3593c70a39d2bd789334befc88 100644 --- a/python/paddle/fluid/incubate/fleet/parameter_server/distribute_transpiler/__init__.py +++ b/python/paddle/fluid/incubate/fleet/parameter_server/distribute_transpiler/__init__.py @@ -11,11 +11,13 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + import os import warnings """ Convert the fluid program to distributed data-parallelism programs. """ +from .distributed_strategy import * import paddle.fluid.io as io from paddle.fluid.communicator import Communicator from paddle.fluid.framework import default_main_program @@ -26,8 +28,7 @@ from paddle.fluid.executor import Executor from paddle.fluid.parallel_executor import ParallelExecutor from paddle.fluid.optimizer import Optimizer from paddle.fluid.transpiler.distribute_transpiler import DistributeTranspiler as OriginTranspiler -from paddle.fluid.transpiler.geo_sgd_transpiler import GeoSgdTranspiler -from paddle.fluid.transpiler.distribute_transpiler import DistributeTranspilerConfig +from paddle.fluid.transpiler.distribute_transpiler import DistributeTranspilerConfig, ServerRuntimeConfig from paddle.fluid.incubate.fleet.base.fleet_base import DistributedOptimizer from paddle.fluid.incubate.fleet.base.fleet_base import Fleet @@ -66,15 +67,24 @@ class DistributedTranspiler(Fleet): from paddle.fluid.transpiler.details.checkport import wait_server_ready wait_server_ready(fleet.server_endpoints(to_string=False)) - if not self._transpile_config.sync_mode: - if self._transpile_config.geo_sgd_mode: - self._communicator = Communicator( - self.main_program, self.vars_info, - fleet.worker_num(), - self._transpile_config.geo_sgd_need_push_nums) - else: - self._communicator = Communicator(self.main_program) - + program_config = self._transpile_config.get_program_config() + trainer_communicator_config = self._transpile_config.get_trainer_runtime_config( + ) + print(trainer_communicator_config) + + need_communicator_flag = False + if isinstance(self._transpile_config, GeoStrategy): + need_communicator_flag = True + self._communicator = Communicator( + self.main_program, self.vars_info, + fleet.worker_num(), program_config.geo_sgd_need_push_nums, + trainer_communicator_config.get_communicator_flags()) + elif isinstance(self._transpile_config, AsyncStrategy): + need_communicator_flag = True + self._communicator = Communicator( + self.main_program, + env_flags=trainer_communicator_config.get_communicator_flags()) + if need_communicator_flag: if not self._communicator.is_running(): self._communicator.start() else: @@ -129,7 +139,8 @@ class DistributedTranspiler(Fleet): Returns: None """ - if not self._transpile_config.sync_mode: + if isinstance(self._transpile_config, GeoStrategy) or isinstance( + self._transpile_config, AsyncStrategy): self._communicator.stop() self._executor.close() if isinstance(self._role_maker, MPISymetricRoleMaker): @@ -239,36 +250,44 @@ class DistributedTranspiler(Fleet): io.save_persistables(executor, dirname, main_program, None) def _transpile(self, config): - if not isinstance(config, DistributeTranspilerConfig): + if isinstance(config, DistributeTranspilerConfig): + self._transpile_config = DistributedStrategy() + self._transpile_config.set_program_config(config) + elif isinstance(config, DistributedStrategy): + self._transpile_config = config + else: raise TypeError( - "config must be an instance of DistributeTranspilerConfig") + "config must be an instance of DistributeTranspilerConfig or DistributedStrategy" + ) - if not config.sync_mode: - config.runtime_split_send_recv = True + program_config = self._transpile_config.get_program_config() # _origin_program is a deep copy for default_main_program, for inference self._origin_program = default_main_program().clone(for_test=False) - self._transpile_config = config - if config.geo_sgd_mode: - self._transpiler = GeoSgdTranspiler(config) + if program_config.geo_sgd_mode: + from paddle.fluid.transpiler.geo_sgd_transpiler import GeoSgdTranspiler + self._transpiler = GeoSgdTranspiler(program_config) else: - self._transpiler = OriginTranspiler(config) + self._transpiler = OriginTranspiler(program_config) + self._transpiler._set_server_config( + self._transpile_config.get_server_runtime_config()) if self.is_worker(): self._transpiler.transpile( trainer_id=fleet.worker_index(), pservers=fleet.server_endpoints(to_string=True), trainers=fleet.worker_num(), - sync_mode=config.sync_mode) + sync_mode=program_config.sync_mode) if isinstance(self._role_maker, MPISymetricRoleMaker): - config.wait_port = False + program_config.wait_port = False + self._transpile_config.set_program_config(program_config) self.main_program = self._transpiler.get_trainer_program( - wait_port=config.wait_port) + wait_port=program_config.wait_port) self.startup_program = default_startup_program() - if self._transpile_config.geo_sgd_mode: + if program_config.geo_sgd_mode: self.vars_info = self._transpiler._get_vars_info() self.startup_program = self._transpiler.trainer_startup_program else: @@ -276,7 +295,7 @@ class DistributedTranspiler(Fleet): trainer_id=fleet.worker_index(), pservers=fleet.server_endpoints(to_string=True), trainers=fleet.worker_num(), - sync_mode=config.sync_mode, + sync_mode=program_config.sync_mode, current_endpoint=self.server_endpoints()[self.server_index()]) self.main_program, self.startup_program = \ self._transpiler.get_pserver_programs( @@ -308,14 +327,17 @@ class TranspilerOptimizer(DistributedOptimizer): super(TranspilerOptimizer, self).__init__(optimizer, strategy) if strategy: - if not isinstance(strategy, DistributeTranspilerConfig): + if isinstance(strategy, DistributedStrategy): + self._strategy = strategy + elif isinstance(strategy, DistributeTranspilerConfig): + self._strategy = DistributedStrategy() + self._strategy.set_program_config(strategy) + else: raise TypeError( - "In {} mode, strategy must be an instance of DistributeTranspilerConfig". + "In {} mode, strategy must be an instance of DistributeTranspilerConfig or DistributedStrategy". format(fleet._mode)) - else: - self._strategy = strategy else: - self._strategy = DistributeTranspilerConfig() + self._strategy = DistributedStrategy() def backward(self, loss, diff --git a/python/paddle/fluid/incubate/fleet/parameter_server/distribute_transpiler/distributed_strategy.py b/python/paddle/fluid/incubate/fleet/parameter_server/distribute_transpiler/distributed_strategy.py new file mode 100644 index 0000000000000000000000000000000000000000..0b7277d82dd3f34760f50e3a7fca6f2659166335 --- /dev/null +++ b/python/paddle/fluid/incubate/fleet/parameter_server/distribute_transpiler/distributed_strategy.py @@ -0,0 +1,228 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +__all__ = [ + "TrainerRuntimeConfig", "DistributedStrategy", "SyncStrategy", + "AsyncStrategy", "HalfAsyncStrategy", "GeoStrategy", "StrategyFactory" +] + +import os +import paddle.fluid as fluid +from paddle.fluid.transpiler.distribute_transpiler import DistributeTranspilerConfig, ServerRuntimeConfig + + +class TrainerRuntimeConfig(object): + def __init__(self): + self.max_merge_var_num = int( + os.getenv("FLAGS_communicator_max_merge_var_num", "20")) + self.send_queue_size = int( + os.getenv("FLAGS_communicator_send_queue_size", "20")) + self.independent_recv_thread = int( + os.getenv("FLAGS_communicator_independent_recv_thread", "1")) + self.min_send_grad_num_before_recv = int( + os.getenv("FLAGS_communicator_min_send_grad_num_before_recv", "20")) + self.thread_pool_size = int( + os.getenv("FLAGS_communicator_thread_pool_size", "5")) + self.send_wait_times = int( + os.getenv("FLAGS_communicator_send_wait_times", "5")) + self.fake_rpc = int(os.getenv("FLAGS_communicator_fake_rpc", "0")) + self.merge_sparse_grad = int( + os.getenv("FLAGS_communicator_merge_sparse_grad", "1")) + self.is_sgd_optimizer = int( + os.getenv("FLAGS_communicator_is_sgd_optimizer", "1")) + + # not used + self._rpc_deadline = int(os.getenv("FLAGS_rpc_deadline", "180000")) + self._rpc_retry_times = int(os.getenv("FLAGS_rpc_retry_times", "3")) + + def get_communicator_flags(self): + _communicator_flags = dict() + _communicator_flags["max_merge_var_num"] = self.max_merge_var_num + _communicator_flags["send_queue_size"] = self.send_queue_size + _communicator_flags[ + "independent_recv_thread"] = self.independent_recv_thread + _communicator_flags[ + "min_send_grad_num_before_recv"] = self.min_send_grad_num_before_recv + _communicator_flags["thread_pool_size"] = self.thread_pool_size + _communicator_flags["send_wait_times"] = self.send_wait_times + _communicator_flags["fake_rpc"] = self.fake_rpc + _communicator_flags["merge_sparse_grad"] = self.merge_sparse_grad + _communicator_flags["is_sgd_optimizer"] = self.is_sgd_optimizer + return _communicator_flags + + def __repr__(self): + _str = "please check that TrainerRuntimeConfig is as expected:\n" + _communicator_flags = self.get_communicator_flags() + for key in _communicator_flags: + _str += "communicator_{}: {}\n".format(key, + _communicator_flags[key]) + return _str + + +class DistributedStrategy(object): + def __init__(self): + self._program_config = DistributeTranspilerConfig() + self._trainer_runtime_config = TrainerRuntimeConfig() + self._server_runtime_config = ServerRuntimeConfig() + self._execute_strategy = fluid.ExecutionStrategy() + self._build_strategy = fluid.BuildStrategy() + num_threads = int(os.getenv("CPU_NUM", "1")) + self._execute_strategy.num_threads = num_threads + if num_threads > 1: + self._build_strategy.reduce_strategy = fluid.BuildStrategy.ReduceStrategy.Reduce + + def get_program_config(self): + return self._program_config + + def set_program_config(self, config): + if isinstance(config, DistributeTranspilerConfig): + self._program_config = config + elif isinstance(config, dict): + for key in config: + if hasattr(self._program_config, key): + setattr(self._program_config, key, config[key]) + else: + raise ValueError( + "DistributeTranspilerConfig doesn't have key: {}". + format(key)) + else: + raise TypeError( + "program_config only accept input type: dict or DistributeTranspilerConfig" + ) + + def get_trainer_runtime_config(self): + return self._trainer_runtime_config + + def set_trainer_runtime_config(self, config): + if isinstance(config, TrainerRuntimeConfig): + self._trainer_runtime_config = config + elif isinstance(config, dict): + for key in config: + if hasattr(self._trainer_runtime_config, key): + setattr(self._trainer_runtime_config, key, config[key]) + else: + raise ValueError( + "TrainerRuntimeConfig doesn't have key: {}".format(key)) + else: + raise TypeError( + "trainer_runtime_config only accept input type: dict or TrainerRuntimeConfig" + ) + + def get_server_runtime_config(self): + return self._server_runtime_config + + def set_server_runtime_config(self, config): + if isinstance(config, ServerRuntimeConfig): + self._server_runtime_config = config + elif isinstance(config, dict): + for key in config: + if hasattr(self._server_runtime_config, key): + setattr(self._server_runtime_config, key, config[key]) + else: + raise ValueError( + "ServerRuntimeConfig doesn't have key: {}".format(key)) + else: + raise TypeError( + "server_runtime_config only accept input type: dict or ServerRuntimeConfig" + ) + + def get_execute_strategy(self): + return self._execute_strategy + + def set_execute_strategy(self, config): + if isinstance(config, fluid.ExecutionStrategy): + self._execute_strategy = config + elif isinstance(config, dict): + for key in config: + if hasattr(self._execute_strategy, key): + setattr(self._execute_strategy, key, config[key]) + else: + raise ValueError( + "ExecutionStrategy doesn't have key: {}".format(key)) + else: + raise TypeError( + "execute_strategy only accept input type: dict or ExecutionStrategy" + ) + + def get_build_strategy(self): + return self._build_strategy + + def set_build_strategy(self, config): + if isinstance(config, fluid.BuildStrategy): + self._build_strategy = config + elif isinstance(config, dict): + for key in config: + if hasattr(self._build_strategy, key): + setattr(self._build_strategy, key, config[key]) + else: + raise ValueError( + "BuildStrategy doesn't have key: {}".format(key)) + else: + raise TypeError( + "build_strategy only accept input type: dict or BuildStrategy") + + +class SyncStrategy(DistributedStrategy): + def __init__(self): + super(SyncStrategy, self).__init__() + self._program_config.sync_mode = True + self._program_config.runtime_split_send_recv = False + self._build_strategy.async_mode = False + + +class AsyncStrategy(DistributedStrategy): + def __init__(self): + super(AsyncStrategy, self).__init__() + self._program_config.sync_mode = False + self._program_config.runtime_split_send_recv = True + self._build_strategy.async_mode = True + + +class HalfAsyncStrategy(DistributedStrategy): + def __init__(self): + super(HalfAsyncStrategy, self).__init__() + self._program_config.sync_mode = False + self._program_config.runtime_split_send_recv = False + self._build_strategy.async_mode = False + + +class GeoStrategy(DistributedStrategy): + def __init__(self, update_frequency=100): + super(GeoStrategy, self).__init__() + self._program_config.sync_mode = False + self._program_config.runtime_split_send_recv = True + self._program_config.geo_sgd_mode = True + self._program_config.geo_sgd_need_push_nums = update_frequency + self._build_strategy.async_mode = True + + +class StrategyFactory(object): + def __init_(self): + pass + + @staticmethod + def create_sync_strategy(): + return SyncStrategy() + + @staticmethod + def create_half_async_strategy(): + return HalfAsyncStrategy() + + @staticmethod + def create_async_strategy(): + return AsyncStrategy() + + @staticmethod + def create_geo_strategy(update_frequency=100): + return GeoStrategy(update_frequency) diff --git a/python/paddle/fluid/tests/unittests/ctr_dataset_reader.py b/python/paddle/fluid/tests/unittests/ctr_dataset_reader.py index 1c4d9703fafed0d0f251be814a8664ded658e9e0..d5a29b925649be0752ac8d2b14d6119ec88618be 100644 --- a/python/paddle/fluid/tests/unittests/ctr_dataset_reader.py +++ b/python/paddle/fluid/tests/unittests/ctr_dataset_reader.py @@ -61,6 +61,24 @@ def load_lr_input_record(sent): return res +class CtrReader(object): + def __init__(self): + pass + + def _reader_creator(self, filelist): + def reader(): + for file in filelist: + with open(file, 'r') as f: + for line in f: + fs = line.strip().split('\t') + dnn_input = load_dnn_input_record(fs[0]) + lr_input = load_lr_input_record(fs[1]) + click = [int(fs[2])] + yield [dnn_input] + [lr_input] + [click] + + return reader + + class DatasetCtrReader(data_generator.MultiSlotDataGenerator): def generate_sample(self, line): def get_rand(low=0.0, high=1.0): diff --git a/python/paddle/fluid/tests/unittests/dist_fleet_ctr.py b/python/paddle/fluid/tests/unittests/dist_fleet_ctr.py index 41a689f849e9c7033f6599fca7ed3ec90bec42c6..3988c6900189553948cf2dbed4316df669bbb451 100644 --- a/python/paddle/fluid/tests/unittests/dist_fleet_ctr.py +++ b/python/paddle/fluid/tests/unittests/dist_fleet_ctr.py @@ -21,8 +21,10 @@ import shutil import tempfile import time +import paddle import paddle.fluid as fluid import os +import numpy as np import ctr_dataset_reader from test_dist_fleet_base import runtime_main, FleetDistRunnerBase @@ -131,7 +133,7 @@ class TestDistCTR2x2(FleetDistRunnerBase): with open(os.path.join(dirname, "__model__.proto"), "w") as wn: wn.write(str(program)) - def do_training(self, fleet): + def do_pyreader_training(self, fleet): """ do training using dataset, using fetch handler to catch variable Args: @@ -146,13 +148,63 @@ class TestDistCTR2x2(FleetDistRunnerBase): exe.run(fleet.startup_program) thread_num = 2 + batch_size = 128 + filelist = [] + for _ in range(thread_num): + filelist.append(train_file_path) + + train_reader = paddle.batch( + paddle.reader.shuffle( + ctr_dataset_reader.CtrReader()._reader_creator(filelist), + buf_size=batch_size * 100), + batch_size=batch_size) + self.reader.decorate_sample_list_generator(train_reader) + + compiled_prog = fluid.compiler.CompiledProgram( + fleet.main_program).with_data_parallel( + loss_name=self.avg_cost.name, + build_strategy=self.strategy.get_build_strategy(), + exec_strategy=self.strategy.get_execute_strategy()) + + for epoch_id in range(1): + self.reader.start() + try: + pass_start = time.time() + while True: + loss_val = exe.run(program=compiled_prog, + fetch_list=[self.avg_cost.name]) + loss_val = np.mean(loss_val) + print("TRAIN ---> pass: {} loss: {}\n".format(epoch_id, + loss_val)) + pass_time = time.time() - pass_start + except fluid.core.EOFException: + self.reader.reset() + + model_dir = tempfile.mkdtemp() + fleet.save_inference_model( + exe, model_dir, [feed.name for feed in self.feeds], self.avg_cost) + self.check_model_right(model_dir) + shutil.rmtree(model_dir) + fleet.stop_worker() + + def do_dataset_training(self, fleet): + dnn_input_dim, lr_input_dim, train_file_path = ctr_dataset_reader.prepare_data( + ) + + exe = fluid.Executor(fluid.CPUPlace()) + + fleet.init_worker() + exe.run(fleet.startup_program) + + thread_num = 2 + batch_size = 128 filelist = [] for _ in range(thread_num): filelist.append(train_file_path) # config dataset dataset = fluid.DatasetFactory().create_dataset() - dataset.set_batch_size(128) + dataset.set_batch_size(batch_size) dataset.set_use_var(self.feeds) pipe_command = 'python ctr_dataset_reader.py' dataset.set_pipe_command(pipe_command) @@ -172,11 +224,14 @@ class TestDistCTR2x2(FleetDistRunnerBase): debug=False) pass_time = time.time() - pass_start + res_dict = dict() + res_dict['loss'] = self.avg_cost + class FH(fluid.executor.FetchHandler): - def handler(self, fetch_target_vars): - for i in range(len(fetch_target_vars)): - print("{}: \n {}\n".format(self.fetch_target_names[0], - fetch_target_vars[0])) + def handle(self, res_dict): + for key in res_dict: + v = res_dict[key] + print("{}: \n {}\n".format(key, v)) for epoch_id in range(1): pass_start = time.time() @@ -184,7 +239,7 @@ class TestDistCTR2x2(FleetDistRunnerBase): exe.train_from_dataset( program=fleet.main_program, dataset=dataset, - fetch_handler=FH([self.avg_cost.name], period_secs=2), + fetch_handler=FH(var_dict=res_dict, period_secs=2), debug=False) pass_time = time.time() - pass_start diff --git a/python/paddle/fluid/tests/unittests/test_dist_fleet_base.py b/python/paddle/fluid/tests/unittests/test_dist_fleet_base.py index 3733d4cfad0dffa8bd38602774286661dd022709..9227eb651faabaf068b64745877a8fa6073f80d2 100644 --- a/python/paddle/fluid/tests/unittests/test_dist_fleet_base.py +++ b/python/paddle/fluid/tests/unittests/test_dist_fleet_base.py @@ -37,6 +37,7 @@ import paddle.fluid as fluid import paddle.fluid.incubate.fleet.base.role_maker as role_maker from paddle.fluid.incubate.fleet.parameter_server.distribute_transpiler import fleet from paddle.fluid.transpiler.distribute_transpiler import DistributeTranspilerConfig +from paddle.fluid.incubate.fleet.parameter_server.distribute_transpiler.distributed_strategy import StrategyFactory RUN_STEP = 5 LEARNING_RATE = 0.01 @@ -50,6 +51,19 @@ class FleetDistRunnerBase(object): do training : exe run program """ + def generate_strategy(self, args): + self.strategy = None + if args.mode == "async": + self.strategy = StrategyFactory.create_async_strategy() + elif args.mode == "sync": + self.strategy = StrategyFactory.create_sync_strategy() + elif args.mode == "half_async": + self.strategy = StrategyFactory.create_half_async_strategy() + elif args.mode == "geo": + self.strategy = StrategyFactory.create_geo_strategy( + args.geo_sgd_need_push_nums) + return self.strategy + def run_pserver(self, args): if args.role.upper() != "PSERVER": raise ValueError("args role must be PSERVER") @@ -62,10 +76,7 @@ class FleetDistRunnerBase(object): fleet.init(role) - strategy = DistributeTranspilerConfig() - strategy.sync_mode = args.sync_mode - strategy.geo_sgd_mode = args.geo_sgd_mode - strategy.geo_sgd_need_push_nums = args.geo_sgd_need_push_nums + strategy = self.generate_strategy(args) avg_cost = self.net() @@ -76,7 +87,28 @@ class FleetDistRunnerBase(object): fleet.init_server() fleet.run_server() - def run_trainer(self, args): + def run_dataset_trainer(self, args): + if args.role.upper() != "TRAINER": + raise ValueError("args role must be TRAINER") + + role = role_maker.UserDefinedRoleMaker( + current_id=args.current_id, + role=role_maker.Role.WORKER, + worker_num=args.trainers, + server_endpoints=args.endpoints.split(",")) + + fleet.init(role) + + strategy = self.generate_strategy(args) + + avg_cost = self.net() + optimizer = fluid.optimizer.SGD(LEARNING_RATE) + optimizer = fleet.distributed_optimizer(optimizer, strategy) + optimizer.minimize(avg_cost) + + out = self.do_dataset_training(fleet) + + def run_pyreader_trainer(self, args): if args.role.upper() != "TRAINER": raise ValueError("args role must be TRAINER") @@ -88,26 +120,33 @@ class FleetDistRunnerBase(object): fleet.init(role) - strategy = DistributeTranspilerConfig() - strategy.sync_mode = args.sync_mode - strategy.geo_sgd_mode = args.geo_sgd_mode - strategy.geo_sgd_need_push_nums = args.geo_sgd_need_push_nums + strategy = self.generate_strategy(args) avg_cost = self.net() + self.reader = fluid.io.PyReader( + feed_list=self.feeds, + capacity=64, + iterable=False, + use_double_buffer=False) + optimizer = fluid.optimizer.SGD(LEARNING_RATE) optimizer = fleet.distributed_optimizer(optimizer, strategy) optimizer.minimize(avg_cost) - out = self.do_training(fleet) + out = self.do_pyreader_training(fleet) def net(self, batch_size=4, lr=0.01): raise NotImplementedError( "get_model should be implemented by child classes.") - def do_training(self, fleet): + def do_dataset_training(self, fleet): raise NotImplementedError( - "do_training should be implemented by child classes.") + "do_dataset_training should be implemented by child classes.") + + def do_pyreader_training(self, fleet): + raise NotImplementedError( + "do_pyreader_training should be implemented by child classes.") class TestFleetBase(unittest.TestCase): @@ -120,7 +159,8 @@ class TestFleetBase(unittest.TestCase): raise NotImplementedError("tests should have _setup_config implemented") def setUp(self): - self._sync_mode = True + self._mode = "sync" + self._reader = "pyreader" self._trainers = 2 self._pservers = 2 self._port_set = set() @@ -139,7 +179,6 @@ class TestFleetBase(unittest.TestCase): self._find_free_port(), self._find_free_port()) self._python_interp = sys.executable - self._geo_sgd = False self._geo_sgd_need_push_nums = 5 self._setup_config() @@ -203,21 +242,13 @@ class TestFleetBase(unittest.TestCase): envs['COVERAGE_FILE'] = os.getenv('COVERAGE_FILE', '') python_path += " -m coverage run --branch -p" - tr_cmd = "{0} {1} --role trainer --endpoints {2} --current_id {{}} --trainers {3}".format( - python_path, model, self._ps_endpoints, self._trainers) - - ps_cmd = "{0} {1} --role pserver --endpoints {2} --current_id {{}} --trainers {3}".format( - python_path, model, self._ps_endpoints, self._trainers) + tr_cmd = "{0} {1} --role trainer --endpoints {2} --current_id {{}} --trainers {3} --mode {4} --geo_sgd_need_push_nums {5} --reader {6}".format( + python_path, model, self._ps_endpoints, self._trainers, self._mode, + self._geo_sgd_need_push_nums, self._reader) - if self._sync_mode: - tr_cmd += " --sync_mode" - ps_cmd += " --sync_mode" - - if self._geo_sgd: - tr_cmd += " --geo_sgd_mode {0} --geo_sgd_need_push_nums {1}".format( - self._geo_sgd, self._geo_sgd_need_push_nums) - ps_cmd += " --geo_sgd_mode {0} --geo_sgd_need_push_nums {1}".format( - self._geo_sgd, self._geo_sgd_need_push_nums) + ps_cmd = "{0} {1} --role pserver --endpoints {2} --current_id {{}} --trainers {3} --mode {4} --geo_sgd_need_push_nums {5} --reader {6}".format( + python_path, model, self._ps_endpoints, self._trainers, self._mode, + self._geo_sgd_need_push_nums, self._reader) # Run dist train to compare with local results ps0, ps1, ps0_pipe, ps1_pipe = self._start_pserver(ps_cmd, env) @@ -301,15 +332,17 @@ def runtime_main(test_class): parser.add_argument('--endpoints', type=str, required=False, default="") parser.add_argument('--current_id', type=int, required=False, default=0) parser.add_argument('--trainers', type=int, required=False, default=1) - parser.add_argument('--sync_mode', action='store_true') - parser.add_argument( - '--geo_sgd_mode', type=bool, required=False, default=False) + parser.add_argument('--mode', type=str, required=False, default='geo') parser.add_argument( '--geo_sgd_need_push_nums', type=int, required=False, default=2) + parser.add_argument('--reader', type=str, required=False, default='dataset') args = parser.parse_args() model = test_class() if args.role == "pserver": model.run_pserver(args) else: - model.run_trainer(args) + if args.reader == "dataset": + model.run_dataset_trainer(args) + else: + model.run_pyreader_trainer(args) diff --git a/python/paddle/fluid/tests/unittests/test_dist_fleet_ctr.py b/python/paddle/fluid/tests/unittests/test_dist_fleet_ctr.py index 5d3c0fbdd0c9aebf7b229f77aadafea5fb8a23c6..548e053105d60e6439cbec7b0f1f589813f03d14 100644 --- a/python/paddle/fluid/tests/unittests/test_dist_fleet_ctr.py +++ b/python/paddle/fluid/tests/unittests/test_dist_fleet_ctr.py @@ -19,9 +19,103 @@ import unittest from test_dist_fleet_base import TestFleetBase -class TestDistMnist2x2(TestFleetBase): +class TestDistMnistSync2x2(TestFleetBase): def _setup_config(self): - self._sync_mode = False + self._mode = "sync" + self._reader = "pyreader" + + def check_with_place(self, + model_file, + delta=1e-3, + check_error_log=False, + need_envs={}): + required_envs = { + "PATH": os.getenv("PATH", ""), + "PYTHONPATH": os.getenv("PYTHONPATH", ""), + "LD_LIBRARY_PATH": os.getenv("LD_LIBRARY_PATH", ""), + "FLAGS_rpc_deadline": "5000", # 5sec to fail fast + "http_proxy": "" + } + + required_envs.update(need_envs) + + if check_error_log: + required_envs["GLOG_v"] = "3" + required_envs["GLOG_logtostderr"] = "1" + + tr0_losses, tr1_losses = self._run_cluster(model_file, required_envs) + + def test_dist_train(self): + self.check_with_place( + "dist_fleet_ctr.py", delta=1e-5, check_error_log=True) + + +class TestDistMnistHalfAsync2x2(TestFleetBase): + def _setup_config(self): + self._mode = "half_async" + self._reader = "pyreader" + + def check_with_place(self, + model_file, + delta=1e-3, + check_error_log=False, + need_envs={}): + required_envs = { + "PATH": os.getenv("PATH", ""), + "PYTHONPATH": os.getenv("PYTHONPATH", ""), + "LD_LIBRARY_PATH": os.getenv("LD_LIBRARY_PATH", ""), + "FLAGS_rpc_deadline": "5000", # 5sec to fail fast + "http_proxy": "" + } + + required_envs.update(need_envs) + + if check_error_log: + required_envs["GLOG_v"] = "3" + required_envs["GLOG_logtostderr"] = "1" + + tr0_losses, tr1_losses = self._run_cluster(model_file, required_envs) + + def test_dist_train(self): + self.check_with_place( + "dist_fleet_ctr.py", delta=1e-5, check_error_log=True) + + +class TestDistMnistAsync2x2(TestFleetBase): + def _setup_config(self): + self._mode = "async" + self._reader = "pyreader" + + def check_with_place(self, + model_file, + delta=1e-3, + check_error_log=False, + need_envs={}): + required_envs = { + "PATH": os.getenv("PATH", ""), + "PYTHONPATH": os.getenv("PYTHONPATH", ""), + "LD_LIBRARY_PATH": os.getenv("LD_LIBRARY_PATH", ""), + "FLAGS_rpc_deadline": "5000", # 5sec to fail fast + "http_proxy": "" + } + + required_envs.update(need_envs) + + if check_error_log: + required_envs["GLOG_v"] = "3" + required_envs["GLOG_logtostderr"] = "1" + + tr0_losses, tr1_losses = self._run_cluster(model_file, required_envs) + + def test_dist_train(self): + self.check_with_place( + "dist_fleet_ctr.py", delta=1e-5, check_error_log=True) + + +class TestDistMnistAsyncDataset2x2(TestFleetBase): + def _setup_config(self): + self._mode = "async" + self._reader = "dataset" def check_with_place(self, model_file, diff --git a/python/paddle/fluid/tests/unittests/test_dist_fleet_geo.py b/python/paddle/fluid/tests/unittests/test_dist_fleet_geo.py index 402ae15e9c43423da3619a3debaa7edb7d2d8bdb..ee0600d31054630d01d0b352297051b7ae78ada4 100644 --- a/python/paddle/fluid/tests/unittests/test_dist_fleet_geo.py +++ b/python/paddle/fluid/tests/unittests/test_dist_fleet_geo.py @@ -19,15 +19,16 @@ import unittest import paddle.fluid as fluid import paddle.fluid.incubate.fleet.base.role_maker as role_maker from paddle.fluid.incubate.fleet.parameter_server.distribute_transpiler import fleet -from paddle.fluid.transpiler.distribute_transpiler import DistributeTranspilerConfig +from paddle.fluid.transpiler.distribute_transpiler import DistributeTranspilerConfig, ServerRuntimeConfig +from paddle.fluid.transpiler.geo_sgd_transpiler import GeoSgdTranspiler from test_dist_fleet_base import TestFleetBase from dist_simnet_bow import train_network class TestDistGeoCtr_2x2(TestFleetBase): def _setup_config(self): - self._sync_mode = False - self._geo_sgd = True + self._mode = "geo" + self._reader = "dataset" self._geo_sgd_need_push_nums = 5 def check_with_place(self, diff --git a/python/paddle/fluid/tests/unittests/test_distributed_strategy.py b/python/paddle/fluid/tests/unittests/test_distributed_strategy.py new file mode 100644 index 0000000000000000000000000000000000000000..e2355b1e8ce8b65f3f7197b1de535f248ade4837 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_distributed_strategy.py @@ -0,0 +1,169 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +import paddle.fluid as fluid +from paddle.fluid.transpiler.distribute_transpiler import DistributeTranspilerConfig, ServerRuntimeConfig +from paddle.fluid.incubate.fleet.parameter_server.distribute_transpiler.distributed_strategy import TrainerRuntimeConfig, StrategyFactory +import os + + +class TestStrategyFactor(unittest.TestCase): + def test_sync_strategy(self): + os.environ['CPU_NUM'] = "2" + strategy = StrategyFactory.create_sync_strategy() + self.assertEqual(strategy._program_config.sync_mode, True) + self.assertEqual(strategy._program_config.runtime_split_send_recv, + False) + self.assertEqual(strategy._build_strategy.async_mode, False) + self.assertEqual(strategy._execute_strategy.num_threads, 2) + + # test set_program_config using DistributeTranspilerConfig() + program_config_class = DistributeTranspilerConfig() + program_config_class.min_block_size = 81920 + strategy.set_program_config(program_config_class) + program_config = strategy.get_program_config() + self.assertEqual(program_config.min_block_size, 81920) + + # test set_program_config using dict + program_config_dict = dict() + program_config_dict['min_block_size'] = 8192 + strategy.set_program_config(program_config_dict) + program_config = strategy.get_program_config() + self.assertEqual(program_config.min_block_size, 8192) + + # test set_program_config exception + program_config_dict['unknown'] = None + self.assertRaises(Exception, strategy.set_program_config, + program_config_dict) + program_config_illegal = None + self.assertRaises(Exception, strategy.set_program_config, + program_config_illegal) + + def test_geo_strategy(self): + strategy = StrategyFactory.create_geo_strategy(5) + self.assertEqual(strategy._program_config.sync_mode, False) + self.assertEqual(strategy._program_config.runtime_split_send_recv, True) + self.assertEqual(strategy._program_config.geo_sgd_mode, True) + self.assertEqual(strategy._program_config.geo_sgd_need_push_nums, 5) + self.assertEqual(strategy._build_strategy.async_mode, True) + + # test set_build_strategy using fluid.BuildStrategy + build_strategy_class = fluid.BuildStrategy() + build_strategy_class.memory_optimize = False + strategy.set_build_strategy(build_strategy_class) + build_strategy = strategy.get_build_strategy() + self.assertEqual(build_strategy.memory_optimize, False) + + # test set_build_strategy using dict + build_strategy_dict = dict() + build_strategy_dict['memory_optimize'] = True + strategy.set_build_strategy(build_strategy_dict) + build_strategy = strategy.get_build_strategy() + self.assertEqual(build_strategy.memory_optimize, True) + + # test set_build_strategy exception + build_strategy_dict['unknown'] = None + self.assertRaises(Exception, strategy.set_build_strategy, + build_strategy_dict) + build_strategy_illegal = None + self.assertRaises(Exception, strategy.set_build_strategy, + build_strategy_illegal) + + def test_async_strategy(self): + strategy = StrategyFactory.create_async_strategy() + self.assertEqual(strategy._program_config.sync_mode, False) + self.assertEqual(strategy._program_config.runtime_split_send_recv, True) + self.assertEqual(strategy._build_strategy.async_mode, True) + + # test set_trainer_runtime_config using TrainerRuntimeConfig + trainer_runtime_config_class = TrainerRuntimeConfig() + trainer_runtime_config_class.send_queue_size = 50 + print(trainer_runtime_config_class) + strategy.set_trainer_runtime_config(trainer_runtime_config_class) + trainer_runtime_config = strategy.get_trainer_runtime_config() + self.assertEqual(trainer_runtime_config.send_queue_size, 50) + + # test set_trainer_runtime_config using dict + trainer_runtime_config_dict = dict() + trainer_runtime_config_dict['send_queue_size'] = 100 + strategy.set_trainer_runtime_config(trainer_runtime_config_dict) + trainer_runtime_config = strategy.get_trainer_runtime_config() + trainer_communicator_flags = trainer_runtime_config.get_communicator_flags( + ) + self.assertIn('send_queue_size', trainer_communicator_flags) + self.assertEqual(trainer_communicator_flags['send_queue_size'], 100) + + # test set_trainer_runtime_config exception + trainer_runtime_config_dict['unknown'] = None + self.assertRaises(Exception, strategy.set_trainer_runtime_config, + trainer_runtime_config_dict) + trainer_runtime_config_illegal = None + self.assertRaises(Exception, strategy.set_trainer_runtime_config, + trainer_runtime_config_illegal) + + # test set_execute_strategy using fluid.ExecutionStrategy + exec_strategy_class = fluid.ExecutionStrategy() + exec_strategy_class.num_threads = 4 + strategy.set_execute_strategy(exec_strategy_class) + exec_strategy = strategy.get_execute_strategy() + self.assertEqual(exec_strategy.num_threads, 4) + + # test set_execute_strategy using dict + exec_strategy_dict = dict() + exec_strategy_dict['num_threads'] = 8 + strategy.set_execute_strategy(exec_strategy_dict) + exec_strategy = strategy.get_execute_strategy() + self.assertEqual(exec_strategy.num_threads, 8) + + # test set_execute_strategy exception + exec_strategy_dict['unknown'] = None + self.assertRaises(Exception, strategy.set_execute_strategy, + exec_strategy_dict) + exec_strategy_illegal = None + self.assertRaises(Exception, strategy.set_execute_strategy, + exec_strategy_illegal) + + def test_half_async_strategy(self): + strategy = StrategyFactory.create_half_async_strategy() + self.assertEqual(strategy._program_config.sync_mode, False) + self.assertEqual(strategy._program_config.runtime_split_send_recv, + False) + self.assertEqual(strategy._build_strategy.async_mode, False) + + # test set_server_runtime_config using ServerRuntimeConfig + server_runtime_config_class = ServerRuntimeConfig() + server_runtime_config_class._rpc_send_thread_num = 24 + strategy.set_server_runtime_config(server_runtime_config_class) + server_runtime_config = strategy.get_server_runtime_config() + self.assertEqual(server_runtime_config._rpc_send_thread_num, 24) + + # test set_server_runtime_config using dict + server_runtime_config_dict = dict() + server_runtime_config_dict['_rpc_send_thread_num'] = 20 + strategy.set_server_runtime_config(server_runtime_config_dict) + server_runtime_config = strategy.get_server_runtime_config() + self.assertEqual(server_runtime_config._rpc_send_thread_num, 20) + + # test set_server_runtime_config exception + server_runtime_config_dict['unknown'] = None + self.assertRaises(Exception, strategy.set_server_runtime_config, + server_runtime_config_dict) + server_runtime_config_illegal = None + self.assertRaises(Exception, strategy.set_server_runtime_config, + server_runtime_config_illegal) + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/fluid/transpiler/distribute_transpiler.py b/python/paddle/fluid/transpiler/distribute_transpiler.py index 3579536c96f2a98453876a5bb978241458630cb9..dbc63d5cd675cbbd4b0d840f815f6ae0f3cb52c5 100644 --- a/python/paddle/fluid/transpiler/distribute_transpiler.py +++ b/python/paddle/fluid/transpiler/distribute_transpiler.py @@ -30,6 +30,7 @@ Steps to transpile pserver: 5. add listen_and_serv op """ +import os import sys import math from functools import reduce @@ -177,8 +178,8 @@ class DistributeTranspilerConfig(object): print_log = False wait_port = True # split the send recv var in runtime - _runtime_split_send_recv = False - _sync_mode = True + __runtime_split_send_recv = False + __sync_mode = True # Geo-sgd algorithm geo_sgd_mode = False @@ -200,31 +201,41 @@ class DistributeTranspilerConfig(object): @property def runtime_split_send_recv(self): - return self._runtime_split_send_recv + return self.__runtime_split_send_recv @runtime_split_send_recv.setter def runtime_split_send_recv(self, value): if value is None: raise ValueError("runtime_split_send_recv can't be None") - if value and self._sync_mode: + if value and self.__sync_mode: raise ValueError( "if you want to set runtime_split_send_recv to be true, make ensure config.sync_mode is false at first" ) - self._runtime_split_send_recv = value + self.__runtime_split_send_recv = value @property def sync_mode(self): - return self._sync_mode + return self.__sync_mode @sync_mode.setter def sync_mode(self, value): if value is None: raise ValueError("sync_mode can't be None") - if value and self._runtime_split_send_recv: + if value and self.__runtime_split_send_recv: raise ValueError( "if you want to set sync_mode to be true, make ensure config.runtime_split_send_recv is false at first" ) - self._sync_mode = value + self.__sync_mode = value + + +class ServerRuntimeConfig(object): + def __init__(self): + self._rpc_send_thread_num = int( + os.getenv("FLAGS_rpc_send_thread_num", "12")) + self._rpc_get_thread_num = int( + os.getenv("FLAGS_rpc_get_thread_num", "12")) + self._rpc_prefetch_thread_num = int( + os.getenv("FLAGS_rpc_prefetch_thread_num", "12")) class DistributeTranspiler(object): @@ -295,6 +306,7 @@ class DistributeTranspiler(object): self.config = config else: self.config = DistributeTranspilerConfig() + self._set_server_config() if self.config.split_method is None: self.config.split_method = RoundRobin @@ -306,6 +318,16 @@ class DistributeTranspiler(object): assert (self.config.split_method.__bases__[0] == PSDispatcher) self.counter_var = None + def _set_server_config(self, server_config=None): + if server_config is None: + self.server_config = ServerRuntimeConfig() + elif isinstance(server_config, ServerRuntimeConfig): + self.server_config = server_config + else: + raise TypeError( + "In DistributeTranspiler, server_config must be an instance of ServerRuntimeConfig" + ) + def _transpile_nccl2(self, trainer_id, trainers, @@ -1313,6 +1335,10 @@ class DistributeTranspiler(object): "grad_to_block_id": grad_to_block_id, "sparse_grad_to_param": sparse_grad_to_param, "lr_decay_block_id": lr_decay_block_id, + "rpc_get_thread_num": self.server_config._rpc_get_thread_num, + "rpc_send_thread_num": self.server_config._rpc_send_thread_num, + "rpc_prefetch_thread_num": + self.server_config._rpc_prefetch_thread_num } if self.has_distributed_lookup_table: diff --git a/python/paddle/fluid/transpiler/geo_sgd_transpiler.py b/python/paddle/fluid/transpiler/geo_sgd_transpiler.py index bd791e50b61c0af4ec4f90a09dde99902494e04c..4bd1aa2d15dfa624eb5b8bf153af49e0d7dc2473 100644 --- a/python/paddle/fluid/transpiler/geo_sgd_transpiler.py +++ b/python/paddle/fluid/transpiler/geo_sgd_transpiler.py @@ -38,7 +38,7 @@ from ..framework import Program, default_main_program, \ from .details import wait_server_ready, VarsDistributed from .details import delete_ops from ..distribute_lookup_table import find_distributed_lookup_table -from .distribute_transpiler import DistributeTranspiler, DistributeTranspilerConfig, slice_variable, same_or_split_var +from .distribute_transpiler import DistributeTranspiler, DistributeTranspilerConfig, slice_variable, same_or_split_var, ServerRuntimeConfig RPC_OP_ROLE_ATTR_NAME = op_role_attr_name = core.op_proto_and_checker_maker.kOpRoleAttrName( ) @@ -51,6 +51,7 @@ class GeoSgdTranspiler(DistributeTranspiler): self.config = config else: self.config = DistributeTranspilerConfig() + self._set_server_config() if self.config.split_method is None: self.config.split_method = RoundRobin @@ -241,7 +242,11 @@ class GeoSgdTranspiler(DistributeTranspiler): "Fanin": self.trainer_num, "sync_mode": self.sync_mode, "grad_to_block_id": param_to_block_id, - "sparse_grad_to_param": sparse_grad_to_param + "sparse_grad_to_param": sparse_grad_to_param, + "rpc_get_thread_num": self.server_config._rpc_get_thread_num, + "rpc_send_thread_num": self.server_config._rpc_send_thread_num, + "rpc_prefetch_thread_num": + self.server_config._rpc_prefetch_thread_num } # step5 append the listen_and_serv op