Unverified · Commit 7b45a46e authored by Zeng Jinle, committed by GitHub

Add FLAGS_allreduce_record_one_event to reduce the number of event waits (#36263)

* add FLAGS_allreduce_record_one_event

* add more comments

* fix ut

* improve coverage

* fix ut, improve coverage
Parent 85b77232
@@ -16,6 +16,8 @@
#include <string>

DECLARE_bool(allreduce_record_one_event);

namespace paddle {
namespace framework {
namespace details {
@@ -31,11 +33,13 @@ ComputationOpHandle::ComputationOpHandle(ir::Node *node, Scope *scope,
      scope_idx_(scope_idx) {}

void ComputationOpHandle::RunImpl() {
  if (!FLAGS_allreduce_record_one_event) {
    WaitInputVarGenerated(place_);
  }

  auto run_func = [this]() { op_->Run(*local_exec_scopes_[0], place_); };

  if (is_lock_and_record_event_free_ || FLAGS_allreduce_record_one_event) {
    run_func();
  } else {
    this->RunAndRecordEvent(run_func);
......
@@ -19,6 +19,8 @@
#include "paddle/fluid/platform/profiler.h"

DEFINE_bool(skip_fused_all_reduce_check, false, "");
DECLARE_bool(allreduce_record_one_event);

namespace paddle {
namespace framework {
namespace details {
@@ -48,11 +50,80 @@ FusedAllReduceOpHandle::FusedAllReduceOpHandle(
      num_of_all_reduce_(num_of_all_reduce) {}
#endif

FusedAllReduceOpHandle::~FusedAllReduceOpHandle() {
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
  auto destroy_event = [](gpuEvent_t event) {
    if (event == nullptr) return;
#ifdef PADDLE_WITH_HIP
    PADDLE_ENFORCE_CUDA_SUCCESS(hipEventDestroy(event));
#else
    PADDLE_ENFORCE_CUDA_SUCCESS(cudaEventDestroy(event));
#endif
  };
  destroy_event(start_event_);
  destroy_event(end_event_);
#endif
}
void FusedAllReduceOpHandle::RunImpl() {
  platform::RecordEvent record_event(Name());

  VLOG(4) << this->DebugString();
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
  if (FLAGS_allreduce_record_one_event && start_event_ == nullptr) {
    VLOG(10) << "FLAGS_allreduce_record_one_event=true";
    PADDLE_ENFORCE_EQ(use_hierarchical_allreduce_, false,
                      platform::errors::Unimplemented(
                          "The hierarchical allreduce does not support "
                          "FLAGS_allreduce_record_one_event=true"));
    PADDLE_ENFORCE_EQ(places_.size(), 1,
                      platform::errors::Unimplemented(
                          "FLAGS_allreduce_record_one_event=true is only valid "
                          "when using one GPU device per process."));
    PADDLE_ENFORCE_EQ(platform::is_gpu_place(places_[0]), true,
                      platform::errors::Unimplemented(
                          "FLAGS_allreduce_record_one_event=true is only valid "
                          "when using GPU device."));

    auto create_event = [](gpuEvent_t *event) {
      if (*event) return;
#ifdef PADDLE_WITH_HIP
      PADDLE_ENFORCE_CUDA_SUCCESS(
          hipEventCreateWithFlags(event, hipEventDisableTiming));
#else
      PADDLE_ENFORCE_CUDA_SUCCESS(
          cudaEventCreateWithFlags(event, cudaEventDisableTiming));
#endif
    };
    create_event(&start_event_);
    create_event(&end_event_);
  }

  gpuStream_t nccl_stream{nullptr};
  gpuStream_t compute_stream{nullptr};

  if (FLAGS_allreduce_record_one_event) {
    auto gpu_place = BOOST_GET_CONST(platform::CUDAPlace, places_[0]);
    compute_stream =
        platform::DeviceContextPool::Instance().GetByPlace(gpu_place)->stream();
    auto flat_nccl_ctxs = nccl_ctxs_->GetFlatCtx(run_order_);
    auto &nccl_ctx = flat_nccl_ctxs->at(gpu_place.device);
    nccl_stream = nccl_ctx.stream();
#ifdef PADDLE_WITH_HIP
    PADDLE_ENFORCE_CUDA_SUCCESS(hipEventRecord(start_event_, compute_stream));
    PADDLE_ENFORCE_CUDA_SUCCESS(
        hipStreamWaitEvent(nccl_stream, start_event_, 0));
#else
    PADDLE_ENFORCE_CUDA_SUCCESS(cudaEventRecord(start_event_, compute_stream));
    PADDLE_ENFORCE_CUDA_SUCCESS(
        cudaStreamWaitEvent(nccl_stream, start_event_, 0));
#endif
  } else {
    WaitInputVarGenerated();
  }
#else
  WaitInputVarGenerated();
#endif
  // The input: grad0(dev0), grad0(dev1), grad1(dev0), grad1(dev1)...
  // The output: grad0(dev0), grad0(dev1), grad1(dev0), grad1(dev1)...
  auto in_var_handles = DynamicCast<VarHandle>(this->Inputs());
@@ -94,6 +165,20 @@ void FusedAllReduceOpHandle::RunImpl() {
  } else {
    FusedAllReduceFunc(in_var_handles, out_var_handles);
  }
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
  if (FLAGS_allreduce_record_one_event) {
#ifdef PADDLE_WITH_HIP
    PADDLE_ENFORCE_CUDA_SUCCESS(hipEventRecord(end_event_, nccl_stream));
    PADDLE_ENFORCE_CUDA_SUCCESS(
        hipStreamWaitEvent(compute_stream, end_event_, 0));
#else
    PADDLE_ENFORCE_CUDA_SUCCESS(cudaEventRecord(end_event_, nccl_stream));
    PADDLE_ENFORCE_CUDA_SUCCESS(
        cudaStreamWaitEvent(compute_stream, end_event_, 0));
#endif
  }
#endif
}

void FusedAllReduceOpHandle::FusedAllReduceFunc(
......
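The change above boils down to two cross-stream dependencies, each expressed with a single event: before the fused allreduce, an event recorded on the compute stream makes the NCCL stream wait until all previously launched compute work has finished; after the allreduce, an event recorded on the NCCL stream makes the compute stream wait for the allreduce result. The following standalone sketch illustrates that pattern with the plain CUDA runtime API; the producer/consumer kernels, the stream names, and the CHECK_CUDA macro are hypothetical stand-ins, while Paddle itself goes through gpuEvent_t/gpuStream_t and PADDLE_ENFORCE_CUDA_SUCCESS as in the diff.

// Standalone sketch of the "record one event, wait on one event" pattern,
// assuming plain CUDA (compile with nvcc). Not Paddle code.
#include <cstdio>
#include <cstdlib>
#include <cuda_runtime.h>

#define CHECK_CUDA(call)                                      \
  do {                                                        \
    cudaError_t err__ = (call);                               \
    if (err__ != cudaSuccess) {                               \
      std::fprintf(stderr, "CUDA error: %s\n",                \
                   cudaGetErrorString(err__));                \
      std::exit(1);                                           \
    }                                                         \
  } while (0)

// Stands in for the compute ops that produce the gradients.
__global__ void producer(float *data, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) data[i] = 1.0f;
}

// Stands in for the fused allreduce launched on the NCCL stream.
__global__ void consumer(float *data, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) data[i] *= 2.0f;
}

int main() {
  const int n = 1 << 20;
  float *data = nullptr;
  CHECK_CUDA(cudaMalloc(&data, n * sizeof(float)));

  cudaStream_t compute_stream, nccl_stream;  // "nccl_stream" is just a second stream here
  CHECK_CUDA(cudaStreamCreate(&compute_stream));
  CHECK_CUDA(cudaStreamCreate(&nccl_stream));

  // Created once with timing disabled, mirroring start_event_/end_event_.
  cudaEvent_t start_event, end_event;
  CHECK_CUDA(cudaEventCreateWithFlags(&start_event, cudaEventDisableTiming));
  CHECK_CUDA(cudaEventCreateWithFlags(&end_event, cudaEventDisableTiming));

  producer<<<(n + 255) / 256, 256, 0, compute_stream>>>(data, n);

  // One event instead of one wait per input: nccl_stream may not run
  // "consumer" until compute_stream has reached this point.
  CHECK_CUDA(cudaEventRecord(start_event, compute_stream));
  CHECK_CUDA(cudaStreamWaitEvent(nccl_stream, start_event, 0));

  consumer<<<(n + 255) / 256, 256, 0, nccl_stream>>>(data, n);

  // Reverse dependency: later work on compute_stream sees the result.
  CHECK_CUDA(cudaEventRecord(end_event, nccl_stream));
  CHECK_CUDA(cudaStreamWaitEvent(compute_stream, end_event, 0));

  CHECK_CUDA(cudaStreamSynchronize(compute_stream));

  CHECK_CUDA(cudaEventDestroy(start_event));
  CHECK_CUDA(cudaEventDestroy(end_event));
  CHECK_CUDA(cudaStreamDestroy(compute_stream));
  CHECK_CUDA(cudaStreamDestroy(nccl_stream));
  CHECK_CUDA(cudaFree(data));
  return 0;
}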
@@ -67,12 +67,19 @@ struct FusedAllReduceOpHandle : public AllReduceOpHandle {
#endif

  std::string Name() const override;
  ~FusedAllReduceOpHandle();

 protected:
  void RunImpl() override;

 private:
  size_t num_of_all_reduce_;

#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
  gpuEvent_t start_event_{nullptr};
  gpuEvent_t end_event_{nullptr};
#endif

  // Check the dtype of the input
  void GetDTypeAndNumel(
      const std::vector<std::pair<std::string, const LoDTensor *>> &g_tensor,
......
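The header change gives the op handle ownership of the two GPU events: both members start as nullptr, are created lazily in RunImpl the first time the flag is enabled, and are destroyed by the new destructor only if they were ever created. Below is a compact, hypothetical illustration of that lifecycle in plain CUDA; the EventPair class is not Paddle code.

#include <cstdio>
#include <cstdlib>
#include <cuda_runtime.h>

// Hypothetical helper mirroring start_event_/end_event_ in the handle:
// nullptr until first use, created with timing disabled, released on teardown.
class EventPair {
 public:
  ~EventPair() {
    Destroy(start_);
    Destroy(end_);
  }

  void EnsureCreated() {
    Create(&start_);
    Create(&end_);
  }

  cudaEvent_t start() const { return start_; }
  cudaEvent_t end() const { return end_; }

 private:
  static void Create(cudaEvent_t *event) {
    if (*event) return;  // lazy: only create once
    if (cudaEventCreateWithFlags(event, cudaEventDisableTiming) != cudaSuccess) {
      std::fprintf(stderr, "failed to create CUDA event\n");
      std::exit(1);
    }
  }

  static void Destroy(cudaEvent_t event) {
    if (event == nullptr) return;  // never created, nothing to release
    cudaEventDestroy(event);
  }

  cudaEvent_t start_{nullptr};
  cudaEvent_t end_{nullptr};
};

int main() {
  EventPair events;
  events.EnsureCreated();  // analogous to the lazy creation in RunImpl
  return 0;                // destructor releases both events
}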
@@ -682,6 +682,23 @@ PADDLE_DEFINE_EXPORTED_bool(
    "It controls whether to apply IR pass to program when using Fleet APIs");

/**
 * Distributed related FLAG
 * Name: FLAGS_allreduce_record_one_event
 * Since Version: 2.2.0
 * Value Range: bool, default=false
 * Example: FLAGS_allreduce_record_one_event=true makes the allreduce
 *          operations wait for only one event instead of multiple events.
 * Note: Make the allreduce operations wait for only one event instead of
 *       multiple events. Currently, only fused allreduce supports this.
 *       Otherwise, the precision may be wrong.
 */
PADDLE_DEFINE_EXPORTED_bool(allreduce_record_one_event, false,
                            "It controls whether the allreduce operations "
                            "wait for only one event instead of multiple "
                            "events. Currently, only fused allreduce supports "
                            "this. Otherwise, the precision may be wrong.");
/*
 * CINN related FLAG
 * Name: FLAGS_use_cinn
 * Since Version: 2.3
......
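The flag itself follows the usual gflags split visible in the diff: one definition site (PADDLE_DEFINE_EXPORTED_bool in the flags file, assumed here to be a thin wrapper over gflags' DEFINE_bool that additionally exports the symbol) plus a DECLARE_bool in every translation unit that reads it. A minimal single-file sketch with plain gflags, with the multi-file layout only indicated in comments:

#include <gflags/gflags.h>
#include <iostream>

// Definition site: normally lives in exactly one .cc file (flags.cc above).
DEFINE_bool(allreduce_record_one_event, false,
            "Wait on a single recorded event instead of one event per input.");

// Any other .cc file would pull in the flag with
//   DECLARE_bool(allreduce_record_one_event);
// and then branch on FLAGS_allreduce_record_one_event at run time.

int main(int argc, char **argv) {
  gflags::ParseCommandLineFlags(&argc, &argv, /*remove_flags=*/true);
  std::cout << "allreduce_record_one_event = " << std::boolalpha
            << FLAGS_allreduce_record_one_event << std::endl;
  return 0;
}

At run time the flag can be switched on from the command line (--allreduce_record_one_event=true in the plain-gflags sketch), or, as the updated test below does, by passing the FLAGS_allreduce_record_one_event environment variable to the trainer processes.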
@@ -32,7 +32,11 @@ class TestDistMnistNCCL2FleetApi(TestDistBase):
    def test_dist_train(self):
        import paddle.fluid as fluid
        if fluid.core.is_compiled_with_cuda():
            self.check_with_place(
                "dist_mnist.py",
                delta=1e-5,
                check_error_log=True,
                need_envs={'FLAGS_allreduce_record_one_event': '1'})


class FleetCollectiveTest(unittest.TestCase):
......