From 4a4ccac1d060ccf5758b7ff0d32dfb90ab3c5b7f Mon Sep 17 00:00:00 2001 From: Yancey1989 Date: Fri, 14 Dec 2018 15:53:13 +0800 Subject: [PATCH] update by comment test=develop --- .../framework/details/all_reduce_op_handle.cc | 14 ++++++-------- .../framework/details/multi_devices_graph_pass.cc | 4 ++-- paddle/fluid/framework/details/op_handle_base.cc | 1 + .../details/threaded_ssa_graph_executor.cc | 1 + paddle/fluid/framework/parallel_executor.cc | 14 ++++++++++---- paddle/fluid/framework/threadpool.h | 1 + .../reader/create_double_buffer_reader_op.cc | 1 + paddle/fluid/platform/nccl_helper.h | 5 +---- .../unittests/test_parallel_executor_mnist.py | 2 +- 9 files changed, 24 insertions(+), 19 deletions(-) diff --git a/paddle/fluid/framework/details/all_reduce_op_handle.cc b/paddle/fluid/framework/details/all_reduce_op_handle.cc index 6b7bbf9003a..5a4f218077d 100644 --- a/paddle/fluid/framework/details/all_reduce_op_handle.cc +++ b/paddle/fluid/framework/details/all_reduce_op_handle.cc @@ -107,22 +107,20 @@ void AllReduceOpHandle::RunImpl() { PADDLE_ENFORCE(platform::dynload::ncclAllReduce( buffer, buffer, numel, static_cast(dtype), ncclSum, comm, stream)); - if (!nccl_ctxs_->need_group_call_) cudaStreamSynchronize(stream); + // TODO(Yancey1989): synchronize here can get better performance + // if don't use NCCL group call, but need more profileing. + if (local_scopes_.size() == 1UL) cudaStreamSynchronize(stream); }); } this->RunAndRecordEvent([&] { - // TODO(Yancey1989): need allreduce operator to avoid this flag - if (nccl_ctxs_->need_group_call_) { + if (all_reduce_calls.size() == 1UL) { + all_reduce_calls[0](); + } else { platform::NCCLGroupGuard guard; for (auto &call : all_reduce_calls) { call(); } - } else { - // only used in executor_type == ParallalGraph, one thread one GPU - // TODO(Yancey1989): use allreduce operator to avoid this tricky. - PADDLE_ENFORCE(all_reduce_calls.size() == 1UL); - all_reduce_calls[0](); } }); diff --git a/paddle/fluid/framework/details/multi_devices_graph_pass.cc b/paddle/fluid/framework/details/multi_devices_graph_pass.cc index 6e8cf86fcc9..5b82805ad93 100644 --- a/paddle/fluid/framework/details/multi_devices_graph_pass.cc +++ b/paddle/fluid/framework/details/multi_devices_graph_pass.cc @@ -386,8 +386,8 @@ std::unique_ptr MultiDevSSAGraphBuilder::ApplyImpl( CreateComputationalOps(&result, node, places_.size()); } -// insert synchronous ops at the backpropagation; and -// insert synchronous ops if the graph contains mutilple places. +// insert collective ops at the backpropagation; and +// insert collective ops if the graph contains mutilple places. #if defined(PADDLE_WITH_CUDA) && !defined(_WIN32) if (!is_forwarding && diff --git a/paddle/fluid/framework/details/op_handle_base.cc b/paddle/fluid/framework/details/op_handle_base.cc index 4914e0a5ad3..4822627ac3b 100644 --- a/paddle/fluid/framework/details/op_handle_base.cc +++ b/paddle/fluid/framework/details/op_handle_base.cc @@ -52,6 +52,7 @@ void OpHandleBase::Run(bool use_cuda) { #else PADDLE_ENFORCE(!use_cuda); #endif + RunImpl(); } diff --git a/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc b/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc index cebf63364da..677a2937945 100644 --- a/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc +++ b/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc @@ -216,6 +216,7 @@ void ThreadedSSAGraphExecutor::RunOp( if (LIKELY(!strategy_.dry_run_)) { op->Run(strategy_.use_cuda_); } + VLOG(10) << op << " " << op->Name() << " Done "; running_ops_--; ready_var_q->Extend(op->Outputs()); VLOG(10) << op << " " << op->Name() << "Signal posted"; diff --git a/paddle/fluid/framework/parallel_executor.cc b/paddle/fluid/framework/parallel_executor.cc index 2604e41045b..63f3ef0eacc 100644 --- a/paddle/fluid/framework/parallel_executor.cc +++ b/paddle/fluid/framework/parallel_executor.cc @@ -231,7 +231,6 @@ ParallelExecutor::ParallelExecutor( #if defined(PADDLE_WITH_CUDA) && !defined(_WIN32) auto *nccl_id_var = scope->FindVar(NCCL_ID_VARNAME); ncclUniqueId *nccl_id = nullptr; - bool need_group_call = true; if (exec_strategy.type_ == ExecutionStrategy::kParallelGraph) { // parallel graph mode should initialize nccl by ncclCommInitRank since // it call nccl operator per device per thread. @@ -243,17 +242,16 @@ ParallelExecutor::ParallelExecutor( } else { nccl_id = nccl_id_var->GetMutable(); } - need_group_call = false; } else if (nccl_id_var != nullptr) { // the other executor type. // the distributed training with nccl mode would initialize the nccl id in // startup_program. nccl_id = nccl_id_var->GetMutable(); } else { - // initlize NCCL by ncclCommInitAll, do not need nccl_id. + // initlize NCCL by ncclCommInitAll, do not need to intialize the nccl_id. } member_->nccl_ctxs_.reset(new platform::NCCLContextMap( - member_->places_, nccl_id, num_trainers, trainer_id, need_group_call)); + member_->places_, nccl_id, num_trainers, trainer_id)); #else PADDLE_THROW("Not compiled with CUDA"); #endif @@ -288,6 +286,14 @@ ParallelExecutor::ParallelExecutor( graphs.push_back(std::move(graph)); #endif + auto max_memory_size = GetEagerDeletionThreshold(); + // TODO(Yancey1989): fix gc failed on ParallelGraph executor. + if (max_memory_size >= 0 && + exec_strategy.type_ != ExecutionStrategy::kParallelGraph) { + graphs[0] = member_->PrepareGCAndRefCnts( + std::move(graphs[0]), static_cast(max_memory_size)); + } + // Step 3. Create vars in each scope. Passes may also create new vars. // skip control vars and empty vars std::vector var_infos; diff --git a/paddle/fluid/framework/threadpool.h b/paddle/fluid/framework/threadpool.h index 8fd834be9ac..7a51d18fbbf 100644 --- a/paddle/fluid/framework/threadpool.h +++ b/paddle/fluid/framework/threadpool.h @@ -27,6 +27,7 @@ limitations under the License. */ namespace paddle { namespace framework { + struct ExceptionHandler { mutable std::future> future_; explicit ExceptionHandler( diff --git a/paddle/fluid/operators/reader/create_double_buffer_reader_op.cc b/paddle/fluid/operators/reader/create_double_buffer_reader_op.cc index 440b16cf915..ed719f91d09 100644 --- a/paddle/fluid/operators/reader/create_double_buffer_reader_op.cc +++ b/paddle/fluid/operators/reader/create_double_buffer_reader_op.cc @@ -46,6 +46,7 @@ class CreateDoubleBufferReaderOp : public framework::OperatorBase { sin >> num; place = platform::CUDAPlace(static_cast(num)); } + out->Reset(framework::MakeDecoratedReader(underlying_reader, place, 2)); } diff --git a/paddle/fluid/platform/nccl_helper.h b/paddle/fluid/platform/nccl_helper.h index 23a0222239a..8d062dcdb47 100644 --- a/paddle/fluid/platform/nccl_helper.h +++ b/paddle/fluid/platform/nccl_helper.h @@ -82,15 +82,12 @@ struct NCCLContext { struct NCCLContextMap { std::unordered_map contexts_; std::vector order_; - bool need_group_call_; explicit NCCLContextMap(const std::vector &places, ncclUniqueId *nccl_id = nullptr, - size_t num_trainers = 1, size_t trainer_id = 0, - bool need_group_call = true) { + size_t num_trainers = 1, size_t trainer_id = 0) { PADDLE_ENFORCE(!places.empty()); order_.reserve(places.size()); - need_group_call_ = need_group_call; for (auto &p : places) { int dev_id = boost::get(p).device; order_.emplace_back(dev_id); diff --git a/python/paddle/fluid/tests/unittests/test_parallel_executor_mnist.py b/python/paddle/fluid/tests/unittests/test_parallel_executor_mnist.py index 0ff079b4e2c..fffe8bee580 100644 --- a/python/paddle/fluid/tests/unittests/test_parallel_executor_mnist.py +++ b/python/paddle/fluid/tests/unittests/test_parallel_executor_mnist.py @@ -123,7 +123,7 @@ class TestMNIST(TestParallelExecutorBase): self.check_simple_fc_convergence(False) def test_simple_fc_with_new_strategy(self): - # use_cuda, use_reducea + # use_cuda, use_reduce self._compare_reduce_and_allreduce(simple_fc_net, True) self._compare_reduce_and_allreduce(simple_fc_net, False) -- GitLab