From ca8c77d966c963c4afafe4750391de63014dea0f Mon Sep 17 00:00:00 2001
From: Yancey1989
Date: Fri, 28 Dec 2018 17:08:29 +0800
Subject: [PATCH] select execution according to strategy test=develop

---
 .../fluid/framework/details/build_strategy.cc |  7 +-
 .../fluid/framework/details/build_strategy.h  | 11 ++-
 .../details/multi_devices_graph_pass.cc       | 12 +--
 paddle/fluid/framework/parallel_executor.cc   | 77 ++++++++++++-------
 paddle/fluid/framework/parallel_executor.h    |  3 +
 paddle/fluid/pybind/pybind.cc                 |  8 --
 python/paddle/fluid/__init__.py               |  3 +-
 .../unittests/parallel_executor_test_base.py  |  2 -
 .../unittests/test_parallel_executor_crf.py   |  8 +-
 .../unittests/test_parallel_executor_mnist.py | 39 +++-------
 .../test_parallel_executor_seresnext.py       | 15 +---
 .../test_parallel_executor_transformer.py     |  2 -
 12 files changed, 86 insertions(+), 101 deletions(-)

diff --git a/paddle/fluid/framework/details/build_strategy.cc b/paddle/fluid/framework/details/build_strategy.cc
index 504265260..9a092104e 100644
--- a/paddle/fluid/framework/details/build_strategy.cc
+++ b/paddle/fluid/framework/details/build_strategy.cc
@@ -134,7 +134,7 @@ std::shared_ptr<ir::PassBuilder> BuildStrategy::CreatePassesFromStrategy(
 std::unique_ptr<ir::Graph> BuildStrategy::Apply(
     const ProgramDesc &main_program, const std::vector<platform::Place> &places,
     const std::string &loss_var_name, const std::vector<Scope *> &local_scopes,
-    const size_t &num_parallel_devices,
+    const size_t &nranks,
 #if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
     const bool use_cuda, platform::NCCLContextMap *nccl_ctxs) const {
 #else
@@ -153,9 +153,8 @@ std::unique_ptr<ir::Graph> BuildStrategy::Apply(
       pass->Erase("local_scopes");
       pass->SetNotOwned<const std::vector<Scope *>>("local_scopes",
                                                     &local_scopes);
-      pass->Erase("num_parallel_devices");
-      pass->Set<size_t>("num_parallel_devices",
-                        new size_t(num_parallel_devices));
+      pass->Erase("nranks");
+      pass->Set<size_t>("nranks", new size_t(nranks));
 
 #if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
       platform::NCCLContextMap *nctx = use_cuda ? nccl_ctxs : nullptr;
diff --git a/paddle/fluid/framework/details/build_strategy.h b/paddle/fluid/framework/details/build_strategy.h
index b31e60ad8..b75c01c48 100644
--- a/paddle/fluid/framework/details/build_strategy.h
+++ b/paddle/fluid/framework/details/build_strategy.h
@@ -84,8 +84,6 @@ struct BuildStrategy {
 
   bool fuse_broadcast_op_{false};
 
-  bool enable_parallel_graph_{false};
-
   int num_trainers_{1};
   int trainer_id_{0};
   std::vector<std::string> trainers_endpoints_;
@@ -112,7 +110,7 @@ struct BuildStrategy {
                                    const std::vector<platform::Place> &places,
                                    const std::string &loss_var_name,
                                    const std::vector<Scope *> &local_scopes,
-                                   const size_t &num_parallel_devices_,
+                                   const size_t &nranks,
 #if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
                                    const bool use_cuda,
                                    platform::NCCLContextMap *nccl_ctxs) const;
@@ -120,6 +118,13 @@ struct BuildStrategy {
                                    const bool use_cuda) const;
 #endif
 
+  // If set true, ParallelExecutor would build the main_program into multiple
+  // graphs, and each of the graphs would run with one device. This approach
+  // can achieve better performance in some scenarios. ParallelExecutor sets
+  // this flag automatically according to the build strategy and the
+  // execution strategy, so it is no longer exposed through the Python API.
+  mutable bool enable_parallel_graph_ = false;
+
  private:
   mutable bool is_finalized_ = false;
   mutable std::shared_ptr<ir::PassBuilder> pass_builder_;
diff --git a/paddle/fluid/framework/details/multi_devices_graph_pass.cc b/paddle/fluid/framework/details/multi_devices_graph_pass.cc
index 211668b87..761c9ab90 100644
--- a/paddle/fluid/framework/details/multi_devices_graph_pass.cc
+++ b/paddle/fluid/framework/details/multi_devices_graph_pass.cc
@@ -138,7 +138,7 @@ static const char kLossVarName[] = "loss_var_name";
 static const char kPlaces[] = "places";
 static const char kLocalScopes[] = "local_scopes";
 static const char kStrategy[] = "strategy";
-static const char kNumParallelDevices[] = "num_parallel_devices";
+static const char kNRanks[] = "nranks";
 
 void MultiDevSSAGraphBuilder::Init() const {
   all_vars_.clear();
@@ -174,7 +174,7 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::ApplyImpl(
   auto nodes = graph->ReleaseNodes();
   ir::Graph &result = *graph;
 
-  size_t num_parallel_devices = Get<size_t>(kNumParallelDevices);
+  size_t nranks = Get<size_t>(kNRanks);
 
   for (auto &node : nodes) {
     if (node->IsVar() && node->Var()) {
@@ -251,7 +251,7 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::ApplyImpl(
         CreateComputationalOps(&result, node, places_.size());
       }
 
-      if (!is_forwarding && num_parallel_devices > 1UL) {
+      if (!is_forwarding && nranks > 1UL) {
         bool is_bk_op =
             static_cast<bool>(boost::get<int>(node->Op()->GetAttr(
                                   OpProtoAndCheckerMaker::OpRoleAttrName())) &
@@ -649,13 +649,13 @@ int MultiDevSSAGraphBuilder::GetVarDeviceID(
 void MultiDevSSAGraphBuilder::CreateScaleLossGradOp(
     ir::Graph *result, const std::string &loss_grad_name,
     ir::Node *out_var_node, proto::VarType::Type dtype) const {
-  size_t num_parallel_devices = Get<size_t>("num_parallel_devices");
+  size_t nranks = Get<size_t>("nranks");
   for (size_t i = 0; i < places_.size(); ++i) {
     // Insert ScaleCost OpHandle
     auto *dev_ctx = platform::DeviceContextPool::Instance().Get(places_[i]);
     auto *op_handle = new ScaleLossGradOpHandle(
         result->CreateEmptyNode("scale_loss_grad", ir::Node::Type::kOperation),
-        num_parallel_devices, local_scopes_[i], places_[i], dev_ctx, dtype);
+        nranks, local_scopes_[i], places_[i], dev_ctx, dtype);
     result->Get<GraphOps>(kGraphOps).emplace_back(op_handle);
 
     // FIXME: Currently ScaleLossGradOp only use device_count as scale
@@ -888,4 +888,4 @@ REGISTER_PASS(multi_devices_pass,
     .RequirePassAttr(paddle::framework::details::kPlaces)
     .RequirePassAttr(paddle::framework::details::kLocalScopes)
     .RequirePassAttr(paddle::framework::details::kStrategy)
-    .RequirePassAttr(paddle::framework::details::kNumParallelDevices);
+    .RequirePassAttr(paddle::framework::details::kNRanks);
diff --git a/paddle/fluid/framework/parallel_executor.cc b/paddle/fluid/framework/parallel_executor.cc
index fd566be44..934cf34cb 100644
--- a/paddle/fluid/framework/parallel_executor.cc
+++ b/paddle/fluid/framework/parallel_executor.cc
@@ -107,7 +107,7 @@ class ParallelExecutorPrivate {
   bool own_local_scope_;
   bool use_cuda_;
   bool use_all_reduce_;
-  size_t num_parallel_devices_;
+  size_t nranks_;
 
   // global_ref_cnts_ is only initialized when ParallelExecutor constructs, and
   // then keeps unchanged
@@ -203,7 +203,7 @@ ParallelExecutor::ParallelExecutor(
   member_->build_strategy_ = build_strategy;
   member_->use_all_reduce_ =
       build_strategy.reduce_ == BuildStrategy::ReduceStrategy::kAllReduce;
-  member_->num_parallel_devices_ = num_trainers * places.size();
+  member_->nranks_ = num_trainers * places.size();
 
   if (!member_->use_all_reduce_) {
     PADDLE_ENFORCE(places.size() > 1,
@@ -211,16 +211,14 @@ ParallelExecutor::ParallelExecutor(
                    "the number of places must be greater than 1.");
   }
 
-  if (build_strategy.enable_parallel_graph_) {
-    PADDLE_ENFORCE(
-        member_->use_all_reduce_,
-        "build_strategy.reduce should be `AllReduce` if you want to enable"
-        "ParallelGraph.");
-    PADDLE_ENFORCE(
-        member_->use_cuda_,
-        "execution_strategy.use_cuda should be True if you want to enable "
-        "ParallelGraph.");
-  }
+  // FIXME(Yancey1989): parallel graph mode gets better performance
+  // in GPU allreduce distributed training. Need an elegant way to
+  // choose the execution strategy.
+  build_strategy.enable_parallel_graph_ =
+      EnableParallelGraphExecution(main_program, exec_strategy, build_strategy);
+
+  VLOG(1) << "Enable ParallelGraph Execution: "
+          << build_strategy.enable_parallel_graph_;
 
   // Step 1. Bcast the bcast_vars to devs.
   // Create local scopes
@@ -242,20 +240,20 @@ ParallelExecutor::ParallelExecutor(
     // Bcast Parameters to all GPUs
 #if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
     auto *nccl_id_var = scope->FindVar(NCCL_ID_VARNAME);
-    ncclUniqueId *nccl_id = nullptr;
-    // nccl collective would broadcast nccl id by gen_nccl_id operator.
+    std::unique_ptr<ncclUniqueId> nccl_id;
+    // nccl collective would broadcast ncclUniqueId by gen_nccl_id operator.
     if (nccl_id_var != nullptr) {
-      nccl_id = nccl_id_var->GetMutable<ncclUniqueId>();
+      nccl_id.reset(nccl_id_var->GetMutable<ncclUniqueId>());
     }
-    if (build_strategy.enable_parallel_graph_ && places.size() > 1) {
-      if (nccl_id == nullptr) {
-        nccl_id = new ncclUniqueId();
-        PADDLE_ENFORCE(platform::dynload::ncclGetUniqueId(nccl_id));
+    if (build_strategy.enable_parallel_graph_ && member_->nranks_ > 1UL) {
+      if (nccl_id.get() == nullptr) {
+        nccl_id.reset(new ncclUniqueId());
+        platform::dynload::ncclGetUniqueId(nccl_id.get());
       }
     }
 
     member_->nccl_ctxs_.reset(new platform::NCCLContextMap(
-        member_->places_, nccl_id, num_trainers, trainer_id));
+        member_->places_, nccl_id.get(), num_trainers, trainer_id));
 #else
     PADDLE_THROW("Not compiled with CUDA");
 #endif
@@ -268,27 +266,25 @@ ParallelExecutor::ParallelExecutor(
   // Step 2. Convert main_program to SSA form and dependency graph. Also, insert
   // ncclOp
   std::vector<std::unique_ptr<ir::Graph>> graphs;
-  member_->num_parallel_devices_ = member_->places_.size() * num_trainers;
 #if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
   if (build_strategy.enable_parallel_graph_) {
    for (size_t i = 0; i < member_->places_.size(); ++i) {
      std::unique_ptr<ir::Graph> graph = build_strategy.Apply(
          main_program, {member_->places_[i]}, loss_var_name,
-          {member_->local_scopes_[i]}, member_->num_parallel_devices_,
-          member_->use_cuda_, member_->nccl_ctxs_.get());
+          {member_->local_scopes_[i]}, member_->nranks_, member_->use_cuda_,
+          member_->nccl_ctxs_.get());
      graphs.push_back(std::move(graph));
    }
  } else {
    std::unique_ptr<ir::Graph> graph = build_strategy.Apply(
        main_program, member_->places_, loss_var_name, member_->local_scopes_,
-        member_->num_parallel_devices_, member_->use_cuda_,
-        member_->nccl_ctxs_.get());
+        member_->nranks_, member_->use_cuda_, member_->nccl_ctxs_.get());
    graphs.push_back(std::move(graph));
  }
 #else
  std::unique_ptr<ir::Graph> graph = build_strategy.Apply(
      main_program, member_->places_, loss_var_name, member_->local_scopes_,
-      member_->num_parallel_devices_, member_->use_cuda_);
+      member_->nranks_, member_->use_cuda_);
  graphs.push_back(std::move(graph));
 #endif
  auto max_memory_size = GetEagerDeletionThreshold();
@@ -470,6 +466,35 @@ void ParallelExecutor::FeedAndSplitTensorIntoLocalScopes(
   }
 }
 
+bool ParallelExecutor::EnableParallelGraphExecution(
+    const ProgramDesc &main_program, const ExecutionStrategy &exec_strategy,
+    const BuildStrategy &build_strategy) const {
+  bool enable_parallel_graph = true;
+
+  // TODO(Yancey1989): support sparse update in ParallelGraph mode.
+  for (auto &var_desc : main_program.Block(0).AllVars()) {
+    if (var_desc->GetType() == proto::VarType::SELECTED_ROWS) {
+      enable_parallel_graph = false;
+    }
+  }
+
+  // TODO(Yancey1989): support pserver mode
+  for (auto &op_desc : main_program.Block(0).AllOps()) {
+    if (op_desc->Type() == "send" || op_desc->Type() == "recv") {
+      enable_parallel_graph = false;
+      break;
+    }
+  }
+
+  if (!member_->use_all_reduce_ || !member_->use_cuda_)
+    enable_parallel_graph = false;
+
+  if (build_strategy.enable_sequential_execution_ ||
+      exec_strategy.type_ == ExecutionStrategy::ExecutorType::kExperimental)
+    enable_parallel_graph = false;
+  return enable_parallel_graph;
+}
+
 ParallelExecutor::~ParallelExecutor() {
   for (auto &p : member_->places_) {
     platform::DeviceContextPool::Instance().Get(p)->Wait();
diff --git a/paddle/fluid/framework/parallel_executor.h b/paddle/fluid/framework/parallel_executor.h
index 5f6c2159a..dc70894db 100644
--- a/paddle/fluid/framework/parallel_executor.h
+++ b/paddle/fluid/framework/parallel_executor.h
@@ -68,6 +68,9 @@ class ParallelExecutor {
  private:
   void BCastParamsToDevices(const std::unordered_set<std::string> &vars) const;
+  bool EnableParallelGraphExecution(const ProgramDesc &main_program,
+                                    const ExecutionStrategy &exec_strategy,
+                                    const BuildStrategy &build_strategy) const;
 
   ParallelExecutorPrivate *member_;
 };
 
diff --git a/paddle/fluid/pybind/pybind.cc b/paddle/fluid/pybind/pybind.cc
index 3bb08cbeb..81d63aace 100644
--- a/paddle/fluid/pybind/pybind.cc
+++ b/paddle/fluid/pybind/pybind.cc
@@ -980,14 +980,6 @@ All parameter, weight, gradient are variables in Paddle.
           R"DOC(The type is BOOL, fuse_elewise_add_act_ops indicate whether
                 to fuse elementwise_add_op and activation_op,
                 it may make the execution faster. Default False)DOC")
-      .def_property(
-          "enable_parallel_graph",
-          [](const BuildStrategy &self) { return self.enable_parallel_graph_; },
-          [](BuildStrategy &self, bool b) { self.enable_parallel_graph_ = b; },
-          R"DOC(The type is BOOL, if set True, ParallelExecutor would build the main_program into multiple graphs,
-          each of the graphs would run with one device. This approach can achieve better performance in
-          some scenarios. Please note, this approach only supports all-reduce mode
-          on GPU device)DOC")
      .def_property(
          "memory_optimize",
          [](const BuildStrategy &self) { return self.memory_optimize_; },
diff --git a/python/paddle/fluid/__init__.py b/python/paddle/fluid/__init__.py
index e0078e531..cdc631860 100644
--- a/python/paddle/fluid/__init__.py
+++ b/python/paddle/fluid/__init__.py
@@ -156,7 +156,8 @@ def __bootstrap__():
         read_env_flags += [
             'fraction_of_gpu_memory_to_use', 'cudnn_deterministic',
             'enable_cublas_tensor_op_math', 'conv_workspace_size_limit',
-            'cudnn_exhaustive_search', 'memory_optimize_debug', 'selected_gpus'
+            'cudnn_exhaustive_search', 'memory_optimize_debug', 'selected_gpus',
+            'sync_nccl_allreduce'
         ]
 
     core.init_gflags([sys.argv[0]] +
diff --git a/python/paddle/fluid/tests/unittests/parallel_executor_test_base.py b/python/paddle/fluid/tests/unittests/parallel_executor_test_base.py
index 36b13d455..2b0ab0cc3 100644
--- a/python/paddle/fluid/tests/unittests/parallel_executor_test_base.py
+++ b/python/paddle/fluid/tests/unittests/parallel_executor_test_base.py
@@ -39,7 +39,6 @@ class TestParallelExecutorBase(unittest.TestCase):
                                   seed=None,
                                   use_parallel_executor=True,
                                   use_reduce=False,
-                                  use_parallel_graph=False,
                                   use_ir_memory_optimize=False,
                                   fuse_elewise_add_act_ops=False,
                                   optimizer=fluid.optimizer.Adam,
@@ -80,7 +79,6 @@ class TestParallelExecutorBase(unittest.TestCase):
         if use_fast_executor:
             exec_strategy.use_experimental_executor = True
         build_strategy = fluid.BuildStrategy()
-        build_strategy.enable_parallel_graph = use_parallel_graph
         build_strategy.reduce_strategy = fluid.BuildStrategy.ReduceStrategy.Reduce \
             if use_reduce else fluid.BuildStrategy.ReduceStrategy.AllReduce
         build_strategy.fuse_elewise_add_act_ops = fuse_elewise_add_act_ops
diff --git a/python/paddle/fluid/tests/unittests/test_parallel_executor_crf.py b/python/paddle/fluid/tests/unittests/test_parallel_executor_crf.py
index 41286ba08..1c6cfce0c 100644
--- a/python/paddle/fluid/tests/unittests/test_parallel_executor_crf.py
+++ b/python/paddle/fluid/tests/unittests/test_parallel_executor_crf.py
@@ -175,14 +175,13 @@ class TestCRFModel(unittest.TestCase):
             print(pe.run(feed=feeder.feed(cur_batch),
                          fetch_list=[avg_cost.name])[0])
 
-    def _new_build_strategy(self, use_reduce=False, use_parallel_graph=False):
+    def _new_build_strategy(self, use_reduce=False):
         build_strategy = fluid.BuildStrategy()
 
         if use_reduce:
             build_strategy.reduce_strategy = fluid.BuildStrategy.ReduceStrategy.Reduce
         else:
             build_strategy.reduce_strategy = fluid.BuildStrategy.ReduceStrategy.AllReduce
-        build_strategy.enable_parallel_graph = use_parallel_graph
 
         return build_strategy
 
@@ -204,11 +203,6 @@ class TestCRFModel(unittest.TestCase):
                 is_sparse=False,
                 build_strategy=self._new_build_strategy(),
                 use_cuda=True)
-            self.check_network_convergence(
-                is_sparse=False,
-                build_strategy=self._new_build_strategy(
-                    use_parallel_graph=True),
-                use_cuda=True)
 
             self.check_network_convergence(
                 is_sparse=False,
diff --git a/python/paddle/fluid/tests/unittests/test_parallel_executor_mnist.py b/python/paddle/fluid/tests/unittests/test_parallel_executor_mnist.py
index 7d2349fad..0ff7b7312 100644
--- a/python/paddle/fluid/tests/unittests/test_parallel_executor_mnist.py
+++ b/python/paddle/fluid/tests/unittests/test_parallel_executor_mnist.py
@@ -100,10 +100,7 @@ class TestMNIST(TestParallelExecutorBase):
             self.assertAlmostEqual(loss[0], loss[1], delta=1e-4)
 
     # simple_fc
-    def check_simple_fc_convergence(self,
-                                    use_cuda,
-                                    use_reduce=False,
-                                    use_parallel_graph=False):
+    def check_simple_fc_convergence(self, use_cuda, use_reduce=False):
         if use_cuda and not core.is_compiled_with_cuda():
             return
 
@@ -114,15 +111,13 @@ class TestMNIST(TestParallelExecutorBase):
             feed_dict={"image": img,
                        "label": label},
             use_cuda=use_cuda,
-            use_reduce=use_reduce,
-            use_parallel_graph=use_parallel_graph)
+            use_reduce=use_reduce)
 
     def test_simple_fc(self):
         # use_cuda
         if core.is_compiled_with_cuda():
             self.check_simple_fc_convergence(True)
-            self.check_simple_fc_convergence(
-                True, use_reduce=False, use_parallel_graph=True)
+            self.check_simple_fc_convergence(True, use_reduce=False)
         self.check_simple_fc_convergence(False)
 
     def test_simple_fc_with_new_strategy(self):
@@ -130,9 +125,7 @@ class TestMNIST(TestParallelExecutorBase):
         self._compare_reduce_and_allreduce(simple_fc_net, True)
         self._compare_reduce_and_allreduce(simple_fc_net, False)
 
-    def check_simple_fc_parallel_accuracy(self,
-                                          use_cuda,
-                                          use_parallel_graph=False):
+    def check_simple_fc_parallel_accuracy(self, use_cuda):
         if use_cuda and not core.is_compiled_with_cuda():
             return
 
@@ -144,16 +137,7 @@ class TestMNIST(TestParallelExecutorBase):
             feed_dict={"image": img,
                        "label": label},
             use_cuda=use_cuda,
-            use_parallel_executor=False,
-            use_parallel_graph=use_parallel_graph)
-        parallel_first_loss, parallel_last_loss = self.check_network_convergence(
-            method=simple_fc_net,
-            seed=1,
-            feed_dict={"image": img,
-                       "label": label},
-            use_cuda=use_cuda,
-            use_parallel_executor=True,
-            use_parallel_graph=use_parallel_graph)
+            use_parallel_executor=False)
 
         self.assertAlmostEquals(
             np.mean(parallel_first_loss),
@@ -165,15 +149,11 @@ class TestMNIST(TestParallelExecutorBase):
     def test_simple_fc_parallel_accuracy(self):
         if core.is_compiled_with_cuda():
             self.check_simple_fc_parallel_accuracy(True)
-            self.check_simple_fc_parallel_accuracy(
-                True, use_parallel_graph=True)
+            self.check_simple_fc_parallel_accuracy(True)
         # FIXME(Yancey1989): ParallelGraph executor type support CPU mode
         self.check_simple_fc_parallel_accuracy(False)
 
-    def check_batchnorm_fc_convergence(self,
-                                       use_cuda,
-                                       use_fast_executor,
-                                       use_parallel_graph=False):
+    def check_batchnorm_fc_convergence(self, use_cuda, use_fast_executor):
         if use_cuda and not core.is_compiled_with_cuda():
             return
 
@@ -184,8 +164,7 @@ class TestMNIST(TestParallelExecutorBase):
             feed_dict={"image": img,
                        "label": label},
             use_cuda=use_cuda,
-            use_fast_executor=use_fast_executor,
-            use_parallel_graph=use_parallel_graph)
+            use_fast_executor=use_fast_executor)
 
     def test_batchnorm_fc(self):
         for use_cuda in (False, True):
@@ -193,7 +172,7 @@ class TestMNIST(TestParallelExecutorBase):
                 self.check_batchnorm_fc_convergence(use_cuda, use_fast_executor)
 
         self.check_batchnorm_fc_convergence(
-            use_cuda=True, use_fast_executor=False, use_parallel_graph=True)
+            use_cuda=True, use_fast_executor=False)
 
     def test_batchnorm_fc_with_new_strategy(self):
         # FIXME(zcd): close this test temporally.
diff --git a/python/paddle/fluid/tests/unittests/test_parallel_executor_seresnext.py b/python/paddle/fluid/tests/unittests/test_parallel_executor_seresnext.py
index 9bdaab162..4f1d902f5 100644
--- a/python/paddle/fluid/tests/unittests/test_parallel_executor_seresnext.py
+++ b/python/paddle/fluid/tests/unittests/test_parallel_executor_seresnext.py
@@ -277,9 +277,7 @@ class TestResnet(TestParallelExecutorBase):
                                   use_cuda=True,
                                   use_reduce=False,
                                   iter=20,
-                                  delta2=1e-6,
-                                  use_parallel_graph=False,
-                                  lr_scale=1.0):
+                                  delta2=1e-6):
         if use_cuda and not core.is_compiled_with_cuda():
             return
 
@@ -298,8 +296,7 @@ class TestResnet(TestParallelExecutorBase):
             use_cuda=use_cuda,
             use_reduce=use_reduce,
             optimizer=optimizer,
-            use_parallel_executor=False,
-            use_parallel_graph=use_parallel_graph)
+            use_parallel_executor=False)
         parallel_first_loss, parallel_last_loss = self.check_network_convergence(
             model,
             feed_dict={"image": img,
@@ -308,8 +305,7 @@ class TestResnet(TestParallelExecutorBase):
             batch_size=batch_size,
             use_cuda=use_cuda,
             use_reduce=use_reduce,
-            optimizer=optimizer,
-            use_parallel_graph=use_parallel_graph)
+            optimizer=optimizer)
 
         self.assertAlmostEquals(
             np.mean(parallel_first_loss), single_first_loss[0], delta=1e-6)
@@ -320,11 +316,6 @@ class TestResnet(TestParallelExecutorBase):
         if core.is_compiled_with_cuda():
             self._check_resnet_convergence(
                 model=SE_ResNeXt50Small, use_cuda=True)
-            self._check_resnet_convergence(
-                model=SE_ResNeXt50Small,
-                use_cuda=True,
-                use_parallel_graph=True,
-                lr_scale=core.get_cuda_device_count())
 
         self._check_resnet_convergence(
             model=SE_ResNeXt50Small, use_cuda=False, iter=2, delta2=1e-3)
diff --git a/python/paddle/fluid/tests/unittests/test_parallel_executor_transformer.py b/python/paddle/fluid/tests/unittests/test_parallel_executor_transformer.py
index c3ac9d92b..382774390 100644
--- a/python/paddle/fluid/tests/unittests/test_parallel_executor_transformer.py
+++ b/python/paddle/fluid/tests/unittests/test_parallel_executor_transformer.py
@@ -175,8 +175,6 @@ class TestTransformer(TestParallelExecutorBase):
             self.check_network_convergence(transformer, use_cuda=True)
             self.check_network_convergence(
                 transformer, use_cuda=True, enable_sequential_execution=True)
-            self.check_network_convergence(
-                transformer, use_cuda=True, use_parallel_graph=True)
             self.check_network_convergence(transformer, use_cuda=False, iter=5)
 
 
-- 
GitLab
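Note (not part of the patch): a minimal, hedged sketch of what this change means on the Python side. Before this commit a user opted in with build_strategy.enable_parallel_graph; after it, that pybind property is gone and ParallelExecutor::EnableParallelGraphExecution() chooses the mode internally — roughly, CUDA with AllReduce, no SELECTED_ROWS variables, no send/recv ops, and neither sequential execution nor the experimental executor. The toy network below is a placeholder for illustration only and is not taken from this repository.

    import paddle.fluid as fluid

    # Placeholder program: any loss built from fluid layers works here.
    img = fluid.layers.data(name='image', shape=[784], dtype='float32')
    label = fluid.layers.data(name='label', shape=[1], dtype='int64')
    prediction = fluid.layers.fc(input=img, size=10, act='softmax')
    loss = fluid.layers.mean(
        fluid.layers.cross_entropy(input=prediction, label=label))
    fluid.optimizer.Adam().minimize(loss)

    # Initialize parameters once with a plain Executor.
    place = fluid.CUDAPlace(0)
    fluid.Executor(place).run(fluid.default_startup_program())

    exec_strategy = fluid.ExecutionStrategy()
    build_strategy = fluid.BuildStrategy()
    # AllReduce + CUDA is the combination the new heuristic accepts for the
    # parallel-graph executor; enable_parallel_graph itself is no longer
    # settable from Python after this patch.
    build_strategy.reduce_strategy = fluid.BuildStrategy.ReduceStrategy.AllReduce

    pe = fluid.ParallelExecutor(
        use_cuda=True,
        loss_name=loss.name,
        exec_strategy=exec_strategy,
        build_strategy=build_strategy)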