From 7b45a46e13fe057ca12a001dac7b8d6d24d9f211 Mon Sep 17 00:00:00 2001
From: Zeng Jinle <32832641+sneaxiy@users.noreply.github.com>
Date: Mon, 11 Oct 2021 19:59:16 +0800
Subject: [PATCH] Add FLAGS_allreduce_record_one_event to remove event waiting
 number (#36263)

* add FLAGS_allreduce_record_one_event

* add more comments

* fix ut

* improve coverage

* fix ut, improve coverage
---
 .../details/computation_op_handle.cc       |  8 +-
 .../details/fused_all_reduce_op_handle.cc  | 85 +++++++++++++++++++
 .../details/fused_all_reduce_op_handle.h   |  7 ++
 paddle/fluid/platform/flags.cc             | 17 ++++
 .../unittests/test_dist_mnist_fleetapi.py  |  6 +-
 5 files changed, 120 insertions(+), 3 deletions(-)

diff --git a/paddle/fluid/framework/details/computation_op_handle.cc b/paddle/fluid/framework/details/computation_op_handle.cc
index 2256b826ed5..60b8461668f 100644
--- a/paddle/fluid/framework/details/computation_op_handle.cc
+++ b/paddle/fluid/framework/details/computation_op_handle.cc
@@ -16,6 +16,8 @@
 
 #include <string>
 
+DECLARE_bool(allreduce_record_one_event);
+
 namespace paddle {
 namespace framework {
 namespace details {
@@ -31,11 +33,13 @@ ComputationOpHandle::ComputationOpHandle(ir::Node *node, Scope *scope,
       scope_idx_(scope_idx) {}
 
 void ComputationOpHandle::RunImpl() {
-  WaitInputVarGenerated(place_);
+  if (!FLAGS_allreduce_record_one_event) {
+    WaitInputVarGenerated(place_);
+  }
 
   auto run_func = [this]() { op_->Run(*local_exec_scopes_[0], place_); };
 
-  if (is_lock_and_record_event_free_) {
+  if (is_lock_and_record_event_free_ || FLAGS_allreduce_record_one_event) {
     run_func();
   } else {
     this->RunAndRecordEvent(run_func);
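The hunk above only declares the flag; its single definition lives in paddle/fluid/platform/flags.cc later in this patch. As a minimal standalone sketch of the plain gflags pattern this relies on (not part of the patch; Paddle actually wraps the definition in PADDLE_DEFINE_EXPORTED_bool, and the file names and helper function below are illustrative only):

    // flag_owner.cc - exactly one translation unit defines the flag.
    #include <gflags/gflags.h>

    DEFINE_bool(allreduce_record_one_event, false,
                "Wait on a single recorded event instead of one per input.");

    // flag_user.cc - any other translation unit declares and then reads it.
    #include <gflags/gflags.h>

    DECLARE_bool(allreduce_record_one_event);

    bool SkipPerOpEventSync() {  // illustrative helper, not in Paddle
      return FLAGS_allreduce_record_one_event;
    }

Both computation_op_handle.cc and fused_all_reduce_op_handle.cc take the DECLARE_bool side of this pattern, so they see the same runtime value that flags.cc defines.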
diff --git a/paddle/fluid/framework/details/fused_all_reduce_op_handle.cc b/paddle/fluid/framework/details/fused_all_reduce_op_handle.cc
index 8f45c364476..94507140a81 100644
--- a/paddle/fluid/framework/details/fused_all_reduce_op_handle.cc
+++ b/paddle/fluid/framework/details/fused_all_reduce_op_handle.cc
@@ -19,6 +19,8 @@
 #include "paddle/fluid/platform/profiler.h"
 
 DEFINE_bool(skip_fused_all_reduce_check, false, "");
+DECLARE_bool(allreduce_record_one_event);
+
 namespace paddle {
 namespace framework {
 namespace details {
@@ -48,11 +50,80 @@ FusedAllReduceOpHandle::FusedAllReduceOpHandle(
       num_of_all_reduce_(num_of_all_reduce) {}
 #endif
 
+FusedAllReduceOpHandle::~FusedAllReduceOpHandle() {
+#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
+  auto destroy_event = [](gpuEvent_t event) {
+    if (event == nullptr) return;
+#ifdef PADDLE_WITH_HIP
+    PADDLE_ENFORCE_CUDA_SUCCESS(hipEventDestroy(event));
+#else
+    PADDLE_ENFORCE_CUDA_SUCCESS(cudaEventDestroy(event));
+#endif
+  };
+  destroy_event(start_event_);
+  destroy_event(end_event_);
+#endif
+}
+
 void FusedAllReduceOpHandle::RunImpl() {
   platform::RecordEvent record_event(Name());
   VLOG(4) << this->DebugString();
 
+#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
+  if (FLAGS_allreduce_record_one_event && start_event_ == nullptr) {
+    VLOG(10) << "FLAGS_allreduce_record_one_event=true";
+    PADDLE_ENFORCE_EQ(use_hierarchical_allreduce_, false,
+                      platform::errors::Unimplemented(
+                          "The hierarchical allreduce does not support "
+                          "FLAGS_allreduce_record_one_event=true"));
+    PADDLE_ENFORCE_EQ(places_.size(), 1,
+                      platform::errors::Unimplemented(
+                          "FLAGS_allreduce_record_one_event=true is only valid "
+                          "when using one GPU device per process."));
+    PADDLE_ENFORCE_EQ(platform::is_gpu_place(places_[0]), true,
+                      platform::errors::Unimplemented(
+                          "FLAGS_allreduce_record_one_event=true is only valid "
+                          "when using GPU device."));
+    auto create_event = [](gpuEvent_t *event) {
+      if (*event) return;
+#ifdef PADDLE_WITH_HIP
+      PADDLE_ENFORCE_CUDA_SUCCESS(
+          hipEventCreateWithFlags(event, hipEventDisableTiming));
+#else
+      PADDLE_ENFORCE_CUDA_SUCCESS(
+          cudaEventCreateWithFlags(event, cudaEventDisableTiming));
+#endif
+    };
+    create_event(&start_event_);
+    create_event(&end_event_);
+  }
+
+  gpuStream_t nccl_stream{nullptr};
+  gpuStream_t compute_stream{nullptr};
+
+  if (FLAGS_allreduce_record_one_event) {
+    auto gpu_place = BOOST_GET_CONST(platform::CUDAPlace, places_[0]);
+    compute_stream =
+        platform::DeviceContextPool::Instance().GetByPlace(gpu_place)->stream();
+    auto flat_nccl_ctxs = nccl_ctxs_->GetFlatCtx(run_order_);
+    auto &nccl_ctx = flat_nccl_ctxs->at(gpu_place.device);
+    nccl_stream = nccl_ctx.stream();
+#ifdef PADDLE_WITH_HIP
+    PADDLE_ENFORCE_CUDA_SUCCESS(hipEventRecord(start_event_, compute_stream));
+    PADDLE_ENFORCE_CUDA_SUCCESS(
+        hipStreamWaitEvent(nccl_stream, start_event_, 0));
+#else
+    PADDLE_ENFORCE_CUDA_SUCCESS(cudaEventRecord(start_event_, compute_stream));
+    PADDLE_ENFORCE_CUDA_SUCCESS(
+        cudaStreamWaitEvent(nccl_stream, start_event_, 0));
+#endif
+  } else {
+    WaitInputVarGenerated();
+  }
+#else
   WaitInputVarGenerated();
+#endif
+
   // The input: grad0(dev0), grad0(dev1), grad1(dev0), grad1(dev1)...
   // The output: grad0(dev0), grad0(dev1), grad1(dev0), grad1(dev1)...
   auto in_var_handles = DynamicCast<VarHandle>(this->Inputs());
@@ -94,6 +165,20 @@
   } else {
     FusedAllReduceFunc(in_var_handles, out_var_handles);
   }
+
+#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
+  if (FLAGS_allreduce_record_one_event) {
+#ifdef PADDLE_WITH_HIP
+    PADDLE_ENFORCE_CUDA_SUCCESS(hipEventRecord(end_event_, nccl_stream));
+    PADDLE_ENFORCE_CUDA_SUCCESS(
+        hipStreamWaitEvent(compute_stream, end_event_, 0));
+#else
+    PADDLE_ENFORCE_CUDA_SUCCESS(cudaEventRecord(end_event_, nccl_stream));
+    PADDLE_ENFORCE_CUDA_SUCCESS(
+        cudaStreamWaitEvent(compute_stream, end_event_, 0));
+#endif
+  }
+#endif
 }
 
 void FusedAllReduceOpHandle::FusedAllReduceFunc(
diff --git a/paddle/fluid/framework/details/fused_all_reduce_op_handle.h b/paddle/fluid/framework/details/fused_all_reduce_op_handle.h
index d22dc0a421a..8473700867c 100644
--- a/paddle/fluid/framework/details/fused_all_reduce_op_handle.h
+++ b/paddle/fluid/framework/details/fused_all_reduce_op_handle.h
@@ -67,12 +67,19 @@ struct FusedAllReduceOpHandle : public AllReduceOpHandle {
 #endif
   std::string Name() const override;
 
+  ~FusedAllReduceOpHandle();
+
  protected:
   void RunImpl() override;
 
  private:
   size_t num_of_all_reduce_;
 
+#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
+  gpuEvent_t start_event_{nullptr};
+  gpuEvent_t end_event_{nullptr};
+#endif
+
   // Check the dtype of the input
   void GetDTypeAndNumel(
       const std::vector<...> &g_tensor,
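For readers unfamiliar with the stream-ordering trick in RunImpl above, the following self-contained CUDA sketch (not part of the patch; the stream and event names are made up, and error handling is reduced to a simple macro where Paddle uses PADDLE_ENFORCE_CUDA_SUCCESS) shows the same pattern: one event recorded on the compute stream gates the communication stream, and one event recorded on the communication stream later gates the compute stream.

    // cross_stream_sync.cu - minimal sketch of single-event cross-stream ordering.
    #include <cuda_runtime.h>
    #include <cstdio>
    #include <cstdlib>

    #define CUDA_CHECK(call)                                           \
      do {                                                             \
        cudaError_t err__ = (call);                                    \
        if (err__ != cudaSuccess) {                                    \
          std::fprintf(stderr, "%s failed: %s\n", #call,               \
                       cudaGetErrorString(err__));                     \
          std::exit(1);                                                \
        }                                                              \
      } while (0)

    int main() {
      cudaStream_t compute_stream, comm_stream;
      cudaEvent_t start_event, end_event;
      CUDA_CHECK(cudaStreamCreate(&compute_stream));
      CUDA_CHECK(cudaStreamCreate(&comm_stream));
      // Timing is disabled, matching cudaEventCreateWithFlags in the patch.
      CUDA_CHECK(cudaEventCreateWithFlags(&start_event, cudaEventDisableTiming));
      CUDA_CHECK(cudaEventCreateWithFlags(&end_event, cudaEventDisableTiming));

      // ... kernels producing the fused gradients would run on compute_stream here.

      // One event marks "all compute work so far"; comm_stream waits on it.
      CUDA_CHECK(cudaEventRecord(start_event, compute_stream));
      CUDA_CHECK(cudaStreamWaitEvent(comm_stream, start_event, 0));

      // ... the fused allreduce (e.g. ncclAllReduce) would run on comm_stream here.

      // A second event marks the end of communication; compute_stream waits on it.
      CUDA_CHECK(cudaEventRecord(end_event, comm_stream));
      CUDA_CHECK(cudaStreamWaitEvent(compute_stream, end_event, 0));

      CUDA_CHECK(cudaStreamSynchronize(compute_stream));
      CUDA_CHECK(cudaEventDestroy(start_event));
      CUDA_CHECK(cudaEventDestroy(end_event));
      CUDA_CHECK(cudaStreamDestroy(comm_stream));
      CUDA_CHECK(cudaStreamDestroy(compute_stream));
      std::puts("cross-stream event sync completed");
      return 0;
    }

The point of the flag is visible in this shape: the communication stream waits on a single event covering all fused inputs instead of one event per input variable, which is what WaitInputVarGenerated() would otherwise enqueue.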
diff --git a/paddle/fluid/platform/flags.cc b/paddle/fluid/platform/flags.cc
index 18636f6f842..dd65d743fad 100644
--- a/paddle/fluid/platform/flags.cc
+++ b/paddle/fluid/platform/flags.cc
@@ -682,6 +682,23 @@ PADDLE_DEFINE_EXPORTED_bool(
     "It controls whether to apply IR pass to program when using Fleet APIs");
 
 /**
+ * Distributed related FLAG
+ * Name: FLAGS_allreduce_record_one_event
+ * Since Version: 2.2.0
+ * Value Range: bool, default=false
+ * Example: FLAGS_allreduce_record_one_event=true makes each allreduce
+ * operation wait on a single recorded event instead of multiple events.
+ * Note: When enabled, every allreduce operation waits on one recorded event
+ * rather than on one event per input. Currently only the fused allreduce
+ * supports this; otherwise the precision may be wrong.
+ */
+PADDLE_DEFINE_EXPORTED_bool(allreduce_record_one_event, false,
+                            "It controls whether allreduce operations wait "
+                            "on a single recorded event instead of multiple "
+                            "events. Currently only the fused allreduce "
+                            "supports this; otherwise precision may be wrong.");
+
+/*
  * CINN related FLAG
  * Name: FLAGS_use_cinn
  * Since Version: 2.3
diff --git a/python/paddle/fluid/tests/unittests/test_dist_mnist_fleetapi.py b/python/paddle/fluid/tests/unittests/test_dist_mnist_fleetapi.py
index 34abc5b4553..3b15b06b5ef 100644
--- a/python/paddle/fluid/tests/unittests/test_dist_mnist_fleetapi.py
+++ b/python/paddle/fluid/tests/unittests/test_dist_mnist_fleetapi.py
@@ -32,7 +32,11 @@ class TestDistMnistNCCL2FleetApi(TestDistBase):
     def test_dist_train(self):
         import paddle.fluid as fluid
         if fluid.core.is_compiled_with_cuda():
-            self.check_with_place("dist_mnist.py", delta=1e-5)
+            self.check_with_place(
+                "dist_mnist.py",
+                delta=1e-5,
+                check_error_log=True,
+                need_envs={'FLAGS_allreduce_record_one_event': '1'})
 
 
 class FleetCollectiveTest(unittest.TestCase):
-- 
GitLab
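Finally, a hedged usage sketch (not part of the patch): the flag is read from the process environment, so it is typically set before the trainer processes start, exactly as the need_envs dict in the unit test above does. The launcher invocation is standard paddle.distributed.launch; the training script name is a placeholder.

    import os
    import subprocess

    # Enable single-event waiting for the fused allreduce in the child trainers.
    env = dict(os.environ)
    env['FLAGS_allreduce_record_one_event'] = '1'

    # 'train_mnist.py' is a placeholder collective training script; processes
    # started by paddle.distributed.launch inherit the flag via the environment.
    subprocess.check_call(
        ['python', '-m', 'paddle.distributed.launch', 'train_mnist.py'],
        env=env)

Per the checks added in RunImpl, this is only expected to work with the fused allreduce, flat (non-hierarchical) allreduce, and one GPU per process.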