From 5a3c8bf8133698d940c3dc6435cc717b6cfc8b9b Mon Sep 17 00:00:00 2001 From: chengduoZH Date: Sat, 9 Jun 2018 23:16:09 +0800 Subject: [PATCH] fix in c++ side --- paddle/fluid/framework/details/CMakeLists.txt | 2 + .../framework/details/graph_builder_factory.h | 6 ++- .../details/multi_devices_graph_builder.cc | 44 ++++++++++++++----- .../details/nccl_all_reduce_op_handle.cc | 23 ++++++++-- .../details/nccl_all_reduce_op_handle.h | 14 ++++-- paddle/fluid/framework/parallel_executor.cc | 40 ++++++++++------- .../test_parallel_executor_fetch_feed.py | 24 ++++++---- .../unittests/test_parallel_executor_mnist.py | 2 + ...test_parallel_executor_test_while_train.py | 18 +++++--- .../test_parallel_executor_transformer.py | 3 +- 10 files changed, 127 insertions(+), 49 deletions(-) diff --git a/paddle/fluid/framework/details/CMakeLists.txt b/paddle/fluid/framework/details/CMakeLists.txt index c43826b64cc..207dd1b93a0 100644 --- a/paddle/fluid/framework/details/CMakeLists.txt +++ b/paddle/fluid/framework/details/CMakeLists.txt @@ -19,6 +19,8 @@ if(WITH_GPU) nv_library(broadcast_op_handle SRCS broadcast_op_handle.cc DEPS op_handle_base scope ddim memory variable_visitor dynload_cuda) else() + cc_library(nccl_all_reduce_op_handle SRCS nccl_all_reduce_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory + variable_visitor) set(multi_devices_graph_builder_deps) cc_library(reduce_op_handle SRCS reduce_op_handle.cc DEPS op_handle_base variable_visitor scope ddim) cc_library(broadcast_op_handle SRCS broadcast_op_handle.cc DEPS op_handle_base scope ddim memory variable_visitor) diff --git a/paddle/fluid/framework/details/graph_builder_factory.h b/paddle/fluid/framework/details/graph_builder_factory.h index 857ab12d684..91a119de83e 100644 --- a/paddle/fluid/framework/details/graph_builder_factory.h +++ b/paddle/fluid/framework/details/graph_builder_factory.h @@ -40,7 +40,11 @@ class SSAGraphBuilderFactory { loss_var_name_(loss_var_name), param_names_(param_names), local_scopes_(local_scopes), - strategy_(strategy) {} + strategy_(strategy) { +#ifdef PADDLE_WITH_CUDA + nccl_ctxs_ = nullptr; +#endif + } #ifdef PADDLE_WITH_CUDA void SetNCCLContextMap(platform::NCCLContextMap* nccl_ctxs) { diff --git a/paddle/fluid/framework/details/multi_devices_graph_builder.cc b/paddle/fluid/framework/details/multi_devices_graph_builder.cc index 0c4d369e889..cd7dda143cf 100644 --- a/paddle/fluid/framework/details/multi_devices_graph_builder.cc +++ b/paddle/fluid/framework/details/multi_devices_graph_builder.cc @@ -20,16 +20,13 @@ #include "paddle/fluid/framework/details/broadcast_op_handle.h" #include "paddle/fluid/framework/details/computation_op_handle.h" #include "paddle/fluid/framework/details/multi_devices_graph_builder.h" +#include "paddle/fluid/framework/details/nccl_all_reduce_op_handle.h" #include "paddle/fluid/framework/details/reduce_op_handle.h" #include "paddle/fluid/framework/details/rpc_op_handle.h" #include "paddle/fluid/framework/details/scale_loss_grad_op_handle.h" #include "paddle/fluid/framework/op_info.h" #include "paddle/fluid/framework/scope.h" -#ifdef PADDLE_WITH_CUDA -#include "paddle/fluid/framework/details/nccl_all_reduce_op_handle.h" -#endif - namespace paddle { namespace framework { namespace details { @@ -305,7 +302,12 @@ void MultiDevSSAGraphBuilder::CreateBroadcastOp(SSAGraph *result, auto *out_var = new VarHandle(vars.size(), i, p_name, p); vars.emplace_back(out_var); op_handle->AddOutput(out_var); -#ifndef ADDLE_WITH_CUDA +#ifdef PADDLE_WITH_CUDA + if (nccl_ctxs_ == nullptr) { + 
op_handle->SetDeviceContext( + p, platform::DeviceContextPool::Instance().Get(p)); + } +#else op_handle->SetDeviceContext(p, platform::DeviceContextPool::Instance().Get(p)); #endif @@ -324,7 +326,10 @@ void MultiDevSSAGraphBuilder::InsertNCCLAllReduceOp( SSAGraph *result, const std::string &og) const { #ifdef PADDLE_WITH_CUDA result->ops_.emplace_back( - new NCCLAllReduceOpHandle(local_scopes_, places_, *nccl_ctxs_)); + new NCCLAllReduceOpHandle(local_scopes_, places_, nccl_ctxs_)); +#else + result->ops_.emplace_back(new NCCLAllReduceOpHandle(local_scopes_, places_)); +#endif auto *op_handle = result->ops_.back().get(); for (size_t i = 0; i < places_.size(); ++i) { @@ -334,13 +339,23 @@ void MultiDevSSAGraphBuilder::InsertNCCLAllReduceOp( auto &prev_grad = vars.back(); op_handle->AddInput(prev_grad.get()); +#ifdef PADDLE_WITH_CUDA + if (nccl_ctxs_ == nullptr) { + op_handle->SetDeviceContext( + p, platform::DeviceContextPool::Instance().Get(p)); + } +#else + op_handle->SetDeviceContext(p, + platform::DeviceContextPool::Instance().Get(p)); +#endif + + VLOG(4) << "NCCL - - - " << p; + op_handle->DeviceContext(p)->Wait(); + VLOG(4) << "NCCL - - - " << p << " " << op_handle->DeviceContext(p); auto var = new VarHandle(vars.size() - 1, i, og, p); vars.emplace_back(var); op_handle->AddOutput(var); } -#else - PADDLE_ENFORCE("Not implemented"); -#endif } bool MultiDevSSAGraphBuilder::IsParameterGradientOnce( @@ -379,7 +394,9 @@ void MultiDevSSAGraphBuilder::CreateScaleLossGradOp(SSAGraph *result) const { for (size_t i = 0; i < places_.size(); ++i) { // Insert ScaleCost OpHandle #ifdef PADDLE_WITH_CUDA - auto *communication_dev_ctx = nccl_ctxs_->DevCtx(places_[i]); + auto *communication_dev_ctx = + nccl_ctxs_ ? nccl_ctxs_->DevCtx(places_[i]) + : platform::DeviceContextPool::Instance().Get(places_[i]); #else auto *communication_dev_ctx = platform::DeviceContextPool::Instance().Get(platform::CPUPlace()); @@ -425,8 +442,13 @@ VarHandle *MultiDevSSAGraphBuilder::CreateReduceOp(SSAGraph *result, for (size_t i = 0; i < places_.size(); ++i) { auto &vars = result->vars_[i][og]; -#ifndef PADDLE_WITH_CUDA auto &p = places_[i]; +#ifdef PADDLE_WITH_CUDA + if (nccl_ctxs_ == nullptr) { + op_handle->SetDeviceContext( + p, platform::DeviceContextPool::Instance().Get(p)); + } +#else op_handle->SetDeviceContext(p, platform::DeviceContextPool::Instance().Get(p)); #endif diff --git a/paddle/fluid/framework/details/nccl_all_reduce_op_handle.cc b/paddle/fluid/framework/details/nccl_all_reduce_op_handle.cc index 5bba089ade8..ab5dc676137 100644 --- a/paddle/fluid/framework/details/nccl_all_reduce_op_handle.cc +++ b/paddle/fluid/framework/details/nccl_all_reduce_op_handle.cc @@ -21,15 +21,25 @@ namespace paddle { namespace framework { namespace details { + +#ifdef PADDLE_WITH_CUDA NCCLAllReduceOpHandle::NCCLAllReduceOpHandle( const std::vector &local_scopes, const std::vector &places, - const platform::NCCLContextMap &ctxs) + const platform::NCCLContextMap *ctxs) : local_scopes_(local_scopes), places_(places), nccl_ctxs_(ctxs) { - for (auto &p : places_) { - this->dev_ctxes_[p] = nccl_ctxs_.DevCtx(p); + if (ctxs) { + for (auto &p : places_) { + this->dev_ctxes_[p] = nccl_ctxs_->DevCtx(p); + } } } +#else +NCCLAllReduceOpHandle::NCCLAllReduceOpHandle( + const std::vector &local_scopes, + const std::vector &places) + : local_scopes_(local_scopes), places_(places) {} +#endif void NCCLAllReduceOpHandle::RunImpl() { if (NoDummyInputSize() == 1) { @@ -58,6 +68,8 @@ void NCCLAllReduceOpHandle::RunImpl() { } if 
(platform::is_gpu_place(lod_tensors[0]->place())) { +#ifdef PADDLE_WITH_CUDA + PADDLE_ENFORCE(nccl_ctxs_); int dtype = -1; size_t numel = 0; std::vector> all_reduce_calls; @@ -75,7 +87,7 @@ void NCCLAllReduceOpHandle::RunImpl() { } int dev_id = boost::get(p).device; - auto &nccl_ctx = nccl_ctxs_.at(dev_id); + auto &nccl_ctx = nccl_ctxs_->at(dev_id); auto stream = nccl_ctx.stream(); auto comm = nccl_ctx.comm_; all_reduce_calls.emplace_back([=] { @@ -90,6 +102,9 @@ void NCCLAllReduceOpHandle::RunImpl() { call(); } }); +#else + PADDLE_THROW("Not compiled with CUDA"); +#endif } else { // Special handle CPU only Operator's gradient. Like CRF auto &trg = *this->local_scopes_[0] ->FindVar(kLocalExecScopeName) diff --git a/paddle/fluid/framework/details/nccl_all_reduce_op_handle.h b/paddle/fluid/framework/details/nccl_all_reduce_op_handle.h index 8e98d894b82..e0d206bd937 100644 --- a/paddle/fluid/framework/details/nccl_all_reduce_op_handle.h +++ b/paddle/fluid/framework/details/nccl_all_reduce_op_handle.h @@ -20,17 +20,23 @@ #include "paddle/fluid/framework/details/op_handle_base.h" #include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/scope.h" +#ifdef PADDLE_WITH_CUDA #include "paddle/fluid/platform/nccl_helper.h" +#endif namespace paddle { namespace framework { namespace details { struct NCCLAllReduceOpHandle : public OpHandleBase { +#ifdef PADDLE_WITH_CUDA NCCLAllReduceOpHandle(const std::vector &local_scopes, const std::vector &places, - const platform::NCCLContextMap &ctxs); - + const platform::NCCLContextMap *ctxs); +#else + NCCLAllReduceOpHandle(const std::vector &local_scopes, + const std::vector &places); +#endif std::string Name() const override; // Delay and buffer nccl_all_reduce together can significantly increase @@ -43,7 +49,9 @@ struct NCCLAllReduceOpHandle : public OpHandleBase { private: std::vector local_scopes_; std::vector places_; - const platform::NCCLContextMap &nccl_ctxs_; +#ifdef PADDLE_WITH_CUDA + const platform::NCCLContextMap *nccl_ctxs_; +#endif }; } // namespace details diff --git a/paddle/fluid/framework/parallel_executor.cc b/paddle/fluid/framework/parallel_executor.cc index ce56f55e419..4c98c54b32e 100644 --- a/paddle/fluid/framework/parallel_executor.cc +++ b/paddle/fluid/framework/parallel_executor.cc @@ -44,6 +44,7 @@ class ParallelExecutorPrivate { std::unique_ptr nccl_ctxs_; #endif bool own_local_scope; + bool use_cuda; }; std::vector &ParallelExecutor::GetLocalScopes() { @@ -60,6 +61,7 @@ ParallelExecutor::ParallelExecutor( size_t num_trainers, size_t trainer_id) : member_(new ParallelExecutorPrivate(places)) { member_->global_scope_ = scope; + member_->use_cuda = exec_strategy.use_event_; // Step 1. Bcast the params to devs. 
// Create local scopes @@ -77,18 +79,22 @@ ParallelExecutor::ParallelExecutor( } } + if (member_->use_cuda) { // Bcast Parameters to all GPUs #ifdef PADDLE_WITH_CUDA - auto *nccl_id_var = scope->FindVar(NCCL_ID_VARNAME); - ncclUniqueId *nccl_id = nullptr; - if (nccl_id_var != nullptr) { - nccl_id = nccl_id_var->GetMutable(); - } - member_->nccl_ctxs_.reset(new platform::NCCLContextMap( - member_->places_, nccl_id, num_trainers, trainer_id)); + auto *nccl_id_var = scope->FindVar(NCCL_ID_VARNAME); + ncclUniqueId *nccl_id = nullptr; + if (nccl_id_var != nullptr) { + nccl_id = nccl_id_var->GetMutable(); + } + member_->nccl_ctxs_.reset(new platform::NCCLContextMap( + member_->places_, nccl_id, num_trainers, trainer_id)); +#else + PADDLE_THROW("Not compiled with CUDA"); #endif - if (platform::is_gpu_place(places[0]) && member_->local_scopes_.size() != 1 && - local_scopes.empty()) { // Is CUDA + } + + if (member_->local_scopes_.size() != 1 && local_scopes.empty()) { BCastParamsToGPUs(bcast_vars); } // Startup Program has been run. All local scopes has correct parameters. @@ -108,9 +114,13 @@ ParallelExecutor::ParallelExecutor( details::SSAGraphBuilderFactory builder_factory( member_->places_, loss_var_name, params, member_->local_scopes_, build_strategy); + if (member_->use_cuda) { #ifdef PADDLE_WITH_CUDA - builder_factory.SetNCCLContextMap(member_->nccl_ctxs_.get()); + builder_factory.SetNCCLContextMap(member_->nccl_ctxs_.get()); +#else + PADDLE_THROW("Not compiled with CUDA"); #endif + } member_->executor_.reset(new details::ThreadedSSAGraphExecutor( exec_strategy, member_->local_scopes_, places, @@ -123,7 +133,6 @@ ParallelExecutor::ParallelExecutor( void ParallelExecutor::BCastParamsToGPUs( const std::unordered_set &vars) const { -#ifdef PADDLE_WITH_CUDA auto *main_scope = member_->local_scopes_[0]; for (auto &var : vars) { @@ -135,6 +144,7 @@ void ParallelExecutor::BCastParamsToGPUs( auto &main_tensor = main_var->Get(); auto &dims = main_tensor.dims(); if (paddle::platform::is_gpu_place(main_tensor.place())) { +#ifdef PADDLE_WITH_CUDA size_t numel = main_tensor.numel(); ncclDataType_t data_type = platform::ToNCCLDataType(main_tensor.type()); platform::NCCLGroupGuard guard; @@ -153,6 +163,10 @@ void ParallelExecutor::BCastParamsToGPUs( platform::dynload::ncclBcast(buffer, numel, data_type, 0, nccl_ctx.comm_, nccl_ctx.stream()); } + member_->nccl_ctxs_->WaitAll(); +#else + PADDLE_THROW("Not compiled with CUDA"); +#endif } else { platform::CPUPlace cpu; for (size_t i = 1; i < member_->places_.size(); ++i) { @@ -163,11 +177,7 @@ void ParallelExecutor::BCastParamsToGPUs( paddle::framework::TensorCopy(main_tensor, cpu, t); } } - member_->nccl_ctxs_->WaitAll(); } -#else - PADDLE_THROW("Not compiled with CUDA"); -#endif } void ParallelExecutor::Run(const std::vector &fetch_tensors, diff --git a/python/paddle/fluid/tests/unittests/test_parallel_executor_fetch_feed.py b/python/paddle/fluid/tests/unittests/test_parallel_executor_fetch_feed.py index 24f8d28c030..5c26cd4894c 100644 --- a/python/paddle/fluid/tests/unittests/test_parallel_executor_fetch_feed.py +++ b/python/paddle/fluid/tests/unittests/test_parallel_executor_fetch_feed.py @@ -35,7 +35,7 @@ def Lenet(data, class_dim): class TestFetchOp(unittest.TestCase): - def parallel_exe(self, train_inputs, seed): + def parallel_exe(self, train_inputs, seed, use_cuda): main = fluid.Program() startup = fluid.Program() startup.random_seed = seed @@ -59,13 +59,13 @@ class TestFetchOp(unittest.TestCase): # conv2d_1.b_0@GRAD. 
Those variables should not be pruned. # fluid.memory_optimize(main) - place = fluid.CUDAPlace(0) + place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace() exe = fluid.Executor(place) exe.run(startup) feeder = fluid.DataFeeder(place=place, feed_list=[data, label]) pe = fluid.ParallelExecutor( - use_cuda=True, loss_name=loss.name, main_program=main) + use_cuda=use_cuda, loss_name=loss.name, main_program=main) fetch_list = [] all_vars = main.global_block().vars @@ -88,14 +88,15 @@ class TestFetchOp(unittest.TestCase): for i in range(iters): train_inputs.append(tst_reader_iter.next()) - self.parallel_exe(train_inputs, seed=1) + self.parallel_exe(train_inputs, seed=1, use_cuda=True) + self.parallel_exe(train_inputs, seed=1, use_cuda=False) class TestFeedParallel(unittest.TestCase): - def test_main(self): + def parallel_exe(self, use_cuda, seed): main = fluid.Program() startup = fluid.Program() - startup.random_seed = 1 + startup.random_seed = seed with fluid.scope_guard(fluid.core.Scope()): with fluid.program_guard(main, startup): data = fluid.layers.data( @@ -111,15 +112,18 @@ class TestFeedParallel(unittest.TestCase): regularization=fluid.regularizer.L2Decay(1e-4)) opt.minimize(loss) - place = fluid.CUDAPlace(0) + + place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace() feeder = fluid.DataFeeder(place=place, feed_list=[data, label]) reader = feeder.decorate_reader( paddle.batch( flowers.train(), batch_size=16), multi_devices=True) + exe = fluid.Executor(place) exe.run(startup) + pe = fluid.ParallelExecutor( - use_cuda=True, loss_name=loss.name, main_program=main) + use_cuda=use_cuda, loss_name=loss.name, main_program=main) for batch_id, data in enumerate(reader()): loss_np = np.array(pe.run(feed=data, fetch_list=[loss.name])[0]) @@ -127,6 +131,10 @@ class TestFeedParallel(unittest.TestCase): if batch_id == 2: break + def test_feed_op(self): + self.parallel_exe(use_cuda=True, seed=1) + self.parallel_exe(use_cuda=False, seed=1) + if __name__ == '__main__': unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_parallel_executor_mnist.py b/python/paddle/fluid/tests/unittests/test_parallel_executor_mnist.py index 52dfb9620f8..3bc846d125e 100644 --- a/python/paddle/fluid/tests/unittests/test_parallel_executor_mnist.py +++ b/python/paddle/fluid/tests/unittests/test_parallel_executor_mnist.py @@ -117,9 +117,11 @@ class TestMNIST(TestParallelExecutorBase): def test_simple_fc(self): self.check_simple_fc_convergence(False, use_cuda=True) + self.check_simple_fc_convergence(False, use_cuda=False) def test_simple_fc_with_new_strategy(self): self.check_simple_fc_convergence(True, use_cuda=True) + self.check_simple_fc_convergence(True, use_cuda=False) def check_simple_fc_parallel_accuracy(self, balance_parameter_opt_between_cards, diff --git a/python/paddle/fluid/tests/unittests/test_parallel_executor_test_while_train.py b/python/paddle/fluid/tests/unittests/test_parallel_executor_test_while_train.py index 93a5f767867..1f2555d972f 100644 --- a/python/paddle/fluid/tests/unittests/test_parallel_executor_test_while_train.py +++ b/python/paddle/fluid/tests/unittests/test_parallel_executor_test_while_train.py @@ -35,7 +35,7 @@ def simple_fc_net(): class ParallelExecutorTestingDuringTraining(unittest.TestCase): - def check_network_convergence(self, build_strategy=None): + def check_network_convergence(self, use_cuda, build_strategy=None): main = fluid.Program() startup = fluid.Program() with fluid.program_guard(main, startup): @@ -49,19 +49,19 @@ class 
ParallelExecutorTestingDuringTraining(unittest.TestCase): image = np.random.normal(size=(batch_size, 784)).astype('float32') label = np.random.randint(0, 10, (batch_size, 1), dtype="int64") - place = fluid.CUDAPlace(0) + place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace() exe = fluid.Executor(place) exe.run(startup) feed_dict = {'image': image, 'label': label} train_exe = fluid.ParallelExecutor( - use_cuda=True, + use_cuda=use_cuda, loss_name=loss.name, main_program=main, build_strategy=build_strategy) test_exe = fluid.ParallelExecutor( - use_cuda=True, + use_cuda=use_cuda, main_program=test_program, share_vars_from=train_exe, build_strategy=build_strategy) @@ -81,12 +81,18 @@ class ParallelExecutorTestingDuringTraining(unittest.TestCase): def test_parallel_testing(self): build_strategy = fluid.BuildStrategy() build_strategy.reduce_strategy = fluid.BuildStrategy.ReduceStrategy.AllReduce - self.check_network_convergence(build_strategy) + self.check_network_convergence( + use_cuda=True, build_strategy=build_strategy) + self.check_network_convergence( + use_cuda=False, build_strategy=build_strategy) def test_parallel_testing_with_new_strategy(self): build_strategy = fluid.BuildStrategy() build_strategy.reduce_strategy = fluid.BuildStrategy.ReduceStrategy.Reduce - self.check_network_convergence(build_strategy) + self.check_network_convergence( + use_cuda=True, build_strategy=build_strategy) + self.check_network_convergence( + use_cuda=False, build_strategy=build_strategy) if __name__ == '__main__': diff --git a/python/paddle/fluid/tests/unittests/test_parallel_executor_transformer.py b/python/paddle/fluid/tests/unittests/test_parallel_executor_transformer.py index c81df66d987..3e2c37191fe 100644 --- a/python/paddle/fluid/tests/unittests/test_parallel_executor_transformer.py +++ b/python/paddle/fluid/tests/unittests/test_parallel_executor_transformer.py @@ -167,7 +167,8 @@ class TestTransformer(TestParallelExecutorBase): @unittest.skip("transformer is buggy in multi gpu") def test_main(self): - self.check_network_convergence(transformer) + self.check_network_convergence(transformer, use_cuda=True) + self.check_network_convergence(transformer, use_cuda=False) if __name__ == '__main__': -- GitLab
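
Note on what the patch enables (illustrative, not part of the diff): `NCCLAllReduceOpHandle` now takes a nullable `platform::NCCLContextMap*` and falls back to the ordinary per-place device contexts when it is null, and `ParallelExecutor` only builds the NCCL context map when `use_cuda` is set, so `fluid.ParallelExecutor` can run on CPU places. The sketch below is modelled on the updated unit tests in this patch; the one-layer network, batch size, and random feed data are made up for the example and are not taken from the patch itself.

```python
import numpy as np
import paddle.fluid as fluid

# Build a tiny classification network (hypothetical example network).
main = fluid.Program()
startup = fluid.Program()
with fluid.program_guard(main, startup):
    img = fluid.layers.data(name='image', shape=[784], dtype='float32')
    label = fluid.layers.data(name='label', shape=[1], dtype='int64')
    prediction = fluid.layers.fc(input=img, size=10, act='softmax')
    loss = fluid.layers.mean(
        fluid.layers.cross_entropy(input=prediction, label=label))
    fluid.optimizer.SGD(learning_rate=1e-3).minimize(loss)

# CPU place: no CUDA device or NCCL context is required after this patch.
place = fluid.CPUPlace()
exe = fluid.Executor(place)
exe.run(startup)

# With use_cuda=False the SSA graph is built without an NCCLContextMap and
# gradient aggregation goes through the CPU device contexts instead.
pe = fluid.ParallelExecutor(
    use_cuda=False, loss_name=loss.name, main_program=main)

batch = {
    'image': np.random.random((32, 784)).astype('float32'),
    'label': np.random.randint(0, 10, (32, 1)).astype('int64'),
}
loss_val = pe.run(feed=batch, fetch_list=[loss.name])[0]
print(np.array(loss_val))
```

The same construction with `use_cuda=True` keeps the previous NCCL-based all-reduce path, which is why the updated tests run each case twice, once per device type.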