From 9eec2c75098708e95f2207f89c5e8002c09c557a Mon Sep 17 00:00:00 2001 From: chengduoZH Date: Wed, 9 May 2018 11:13:10 +0800 Subject: [PATCH] refine pe --- .../framework/details/broadcast_op_handle.cc | 12 +------- .../framework/details/broadcast_op_handle.h | 1 - .../details/computation_op_handle.cc | 16 +++++------ .../framework/details/computation_op_handle.h | 2 ++ .../framework/details/fetch_op_handle.cc | 19 ++++++++----- .../fluid/framework/details/fetch_op_handle.h | 4 ++- .../framework/details/gather_op_handle.cc | 13 +-------- .../framework/details/gather_op_handle.h | 1 - .../details/nccl_all_reduce_op_handle.cc | 5 +--- .../fluid/framework/details/op_handle_base.cc | 28 +++++++++++++++++-- .../fluid/framework/details/op_handle_base.h | 10 ++++++- .../framework/details/reduce_op_handle.cc | 13 +-------- .../framework/details/reduce_op_handle.h | 2 -- .../details/scale_loss_grad_op_handle.cc | 1 + .../fluid/framework/details/send_op_handle.cc | 8 +----- 15 files changed, 65 insertions(+), 70 deletions(-) diff --git a/paddle/fluid/framework/details/broadcast_op_handle.cc b/paddle/fluid/framework/details/broadcast_op_handle.cc index 2afa47c81be..f176e4e1599 100644 --- a/paddle/fluid/framework/details/broadcast_op_handle.cc +++ b/paddle/fluid/framework/details/broadcast_op_handle.cc @@ -38,9 +38,7 @@ void BroadcastOpHandle::RunImpl() { out_var_handles.size(), places_.size(), "The number of output should equal to the number of places."); - // Wait input done, this Wait is asynchronous operation platform::Place - // &in_place; - WaitInputVarGenerated(*in_var_handle); + WaitInputVarGenerated(); std::vector var_scopes; for (auto *s : local_scopes_) { @@ -147,14 +145,6 @@ void BroadcastOpHandle::RunImpl() { } } -void BroadcastOpHandle::WaitInputVarGenerated(const VarHandle &in_var) { - if (in_var.generated_op_) { - for (auto &pair : dev_ctxes_) { - in_var.generated_op_->Wait(pair.second); - } - } -} - std::string BroadcastOpHandle::Name() const { return "broadcast"; } } // namespace details } // namespace framework diff --git a/paddle/fluid/framework/details/broadcast_op_handle.h b/paddle/fluid/framework/details/broadcast_op_handle.h index 984a95008c0..48e356af4ba 100644 --- a/paddle/fluid/framework/details/broadcast_op_handle.h +++ b/paddle/fluid/framework/details/broadcast_op_handle.h @@ -57,7 +57,6 @@ struct BroadcastOpHandle : public OpHandleBase { protected: void RunImpl() override; - void WaitInputVarGenerated(const VarHandle &in_var); private: const std::vector &local_scopes_; diff --git a/paddle/fluid/framework/details/computation_op_handle.cc b/paddle/fluid/framework/details/computation_op_handle.cc index 7ff0efe0938..ffbe2094a45 100644 --- a/paddle/fluid/framework/details/computation_op_handle.cc +++ b/paddle/fluid/framework/details/computation_op_handle.cc @@ -26,20 +26,20 @@ ComputationOpHandle::ComputationOpHandle(const OpDesc &op_desc, Scope *scope, place_(place) {} void ComputationOpHandle::RunImpl() { - auto *cur_ctx = dev_ctxes_[place_]; - for (auto *in : inputs_) { - bool need_wait = in->generated_op_ && - in->generated_op_->DeviceContext(place_) != cur_ctx; - if (need_wait) { - in->generated_op_->Wait(cur_ctx); - } - } + WaitInputVarGenerated(place_); this->RunAndRecordEvent([this] { op_->Run(*scope_->FindVar(kLocalExecScopeName)->Get(), place_); }); } +bool ComputationOpHandle::NeedWait(VarHandleBase *in_var) { + bool need_wait = + dynamic_cast(in_var) && in_var->generated_op_ && + in_var->generated_op_->DeviceContext(place_) != dev_ctxes_[place_]; + return need_wait; +} + std::string ComputationOpHandle::Name() const { return op_->Type(); } } // namespace details } // namespace framework diff --git a/paddle/fluid/framework/details/computation_op_handle.h b/paddle/fluid/framework/details/computation_op_handle.h index c363b973d9a..36e6f1bf59a 100644 --- a/paddle/fluid/framework/details/computation_op_handle.h +++ b/paddle/fluid/framework/details/computation_op_handle.h @@ -36,6 +36,8 @@ struct ComputationOpHandle : public OpHandleBase { protected: void RunImpl() override; + virtual bool NeedWait(VarHandleBase *in_var); + private: std::unique_ptr op_; Scope *scope_; diff --git a/paddle/fluid/framework/details/fetch_op_handle.cc b/paddle/fluid/framework/details/fetch_op_handle.cc index 1e8ca20b51d..b1c9dd0d152 100644 --- a/paddle/fluid/framework/details/fetch_op_handle.cc +++ b/paddle/fluid/framework/details/fetch_op_handle.cc @@ -31,7 +31,7 @@ FetchOpHandle::~FetchOpHandle() { } } -void FetchOpHandle::Wait(platform::DeviceContext *waited_dev) { +void FetchOpHandle::RecordWaitEventOnCtx(platform::DeviceContext *waited_ctx) { PADDLE_THROW("Nobody should wait FetchOp. Unexpceted Error"); } @@ -45,12 +45,8 @@ void FetchOpHandle::WaitAndMergeCPUTensors() const { } void FetchOpHandle::RunImpl() { - auto cpu_ctx = - platform::DeviceContextPool::Instance().Get(platform::CPUPlace()); - for (auto *input : inputs_) { - auto *var = static_cast(input); - var->generated_op_->Wait(cpu_ctx); - } + WaitInputVarGenerated(platform::CPUPlace()); + tensors_.resize(inputs_.size()); auto *var_handle = static_cast(inputs_[0]); auto &var_name = var_handle->name_; @@ -77,6 +73,15 @@ void FetchOpHandle::RunImpl() { this->WaitAndMergeCPUTensors(); } +void FetchOpHandle::WaitInputVarGenerated(const platform::Place &place) { + auto cpu_ctx = platform::DeviceContextPool::Instance().Get(place); + for (auto *input : inputs_) { + if (input->generated_op_) { + input->generated_op_->RecordWaitEventOnCtx(cpu_ctx); + } + } +} + std::string FetchOpHandle::Name() const { return "Fetch"; } } // namespace details diff --git a/paddle/fluid/framework/details/fetch_op_handle.h b/paddle/fluid/framework/details/fetch_op_handle.h index b49f3df338d..e696a7a9ce5 100644 --- a/paddle/fluid/framework/details/fetch_op_handle.h +++ b/paddle/fluid/framework/details/fetch_op_handle.h @@ -33,7 +33,7 @@ struct FetchOpHandle : public OpHandleBase { ~FetchOpHandle(); - void Wait(platform::DeviceContext *waited_dev) override; + void RecordWaitEventOnCtx(platform::DeviceContext *waited_ctx) override; void WaitAndMergeCPUTensors() const; @@ -42,6 +42,8 @@ struct FetchOpHandle : public OpHandleBase { protected: void RunImpl() override; + virtual void WaitInputVarGenerated(const platform::Place &place); + private: FeedFetchList *data_; size_t offset_; diff --git a/paddle/fluid/framework/details/gather_op_handle.cc b/paddle/fluid/framework/details/gather_op_handle.cc index 3dfc972a44c..2be02304566 100644 --- a/paddle/fluid/framework/details/gather_op_handle.cc +++ b/paddle/fluid/framework/details/gather_op_handle.cc @@ -55,7 +55,7 @@ void GatherOpHandle::RunImpl() { "Currently, gather_op only can gather SelectedRows."); // Wait input done, this Wait is asynchronous operation - WaitInputVarGenerated(in_var_handles); + WaitInputVarGenerated(); auto &pre_in_value = pre_in_var->Get(); std::vector out_rows; @@ -111,17 +111,6 @@ void GatherOpHandle::RunImpl() { }); } -void GatherOpHandle::WaitInputVarGenerated( - const std::vector &in_var_handles) { - for (auto *in : in_var_handles) { - if (in->generated_op_) { - for (auto pair : dev_ctxes_) { - in->generated_op_->Wait(pair.second); - } - } - } -} - std::string GatherOpHandle::Name() const { return "gather"; } } // namespace details } // namespace framework diff --git a/paddle/fluid/framework/details/gather_op_handle.h b/paddle/fluid/framework/details/gather_op_handle.h index c394dd7a14b..d11ef8556aa 100644 --- a/paddle/fluid/framework/details/gather_op_handle.h +++ b/paddle/fluid/framework/details/gather_op_handle.h @@ -39,7 +39,6 @@ struct GatherOpHandle : public OpHandleBase { protected: void RunImpl() override; - void WaitInputVarGenerated(const std::vector &in_var_handles); private: const std::vector &local_scopes_; diff --git a/paddle/fluid/framework/details/nccl_all_reduce_op_handle.cc b/paddle/fluid/framework/details/nccl_all_reduce_op_handle.cc index b055bb48f60..95aa599cd3e 100644 --- a/paddle/fluid/framework/details/nccl_all_reduce_op_handle.cc +++ b/paddle/fluid/framework/details/nccl_all_reduce_op_handle.cc @@ -34,10 +34,7 @@ void NCCLAllReduceOpHandle::RunImpl() { return; // No need to all reduce when GPU count = 1; } else { // Wait input done - for (auto *in : inputs_) { - auto &p = static_cast(in)->place_; - in->generated_op_->Wait(dev_ctxes_[p]); - } + WaitInputVarGenerated(); auto &var_name = static_cast(this->inputs_[0])->name_; int dtype = -1; diff --git a/paddle/fluid/framework/details/op_handle_base.cc b/paddle/fluid/framework/details/op_handle_base.cc index 534d77860f8..b05b9d95e73 100644 --- a/paddle/fluid/framework/details/op_handle_base.cc +++ b/paddle/fluid/framework/details/op_handle_base.cc @@ -56,15 +56,15 @@ void OpHandleBase::Run(bool use_event) { RunImpl(); } -void OpHandleBase::Wait(platform::DeviceContext *waited_dev) { +void OpHandleBase::RecordWaitEventOnCtx(platform::DeviceContext *waited_ctx) { #ifdef PADDLE_WITH_CUDA - if (platform::is_cpu_place(waited_dev->GetPlace()) || events_.empty()) { + if (platform::is_cpu_place(waited_ctx->GetPlace()) || events_.empty()) { for (auto &dev_ctx : dev_ctxes_) { dev_ctx.second->Wait(); } } else { auto stream = - static_cast(waited_dev)->stream(); + static_cast(waited_ctx)->stream(); for (auto &ev : events_) { PADDLE_ENFORCE(cudaStreamWaitEvent(stream, ev.second, 0)); } @@ -86,6 +86,28 @@ void OpHandleBase::AddOutput(VarHandleBase *out) { out->generated_op_ = this; } +void OpHandleBase::WaitInputVarGenerated() { + for (auto in_var : inputs_) { + if (NeedWait(in_var)) { + for (auto &pair : dev_ctxes_) { + in_var->generated_op_->RecordWaitEventOnCtx(pair.second); + } + } + } +} + +void OpHandleBase::WaitInputVarGenerated(const platform::Place &place) { + for (auto *in : inputs_) { + if (NeedWait(in)) { + in->generated_op_->RecordWaitEventOnCtx(dev_ctxes_[place]); + } + } +} + +bool OpHandleBase::NeedWait(VarHandleBase *in_var) { + return dynamic_cast(in_var) && in_var->generated_op_; +} + void OpHandleBase::RunAndRecordEvent(const std::function &callback) { #ifdef PADDLE_WITH_CUDA if (!events_.empty()) { // Use event diff --git a/paddle/fluid/framework/details/op_handle_base.h b/paddle/fluid/framework/details/op_handle_base.h index 00f213f3ed2..3d4a0931254 100644 --- a/paddle/fluid/framework/details/op_handle_base.h +++ b/paddle/fluid/framework/details/op_handle_base.h @@ -38,12 +38,20 @@ class OpHandleBase { void Run(bool use_event); - virtual void Wait(platform::DeviceContext *waited_dev); + virtual void RecordWaitEventOnCtx(platform::DeviceContext *waited_ctx); void AddInput(VarHandleBase *in); void AddOutput(VarHandleBase *out); + // Wait inputs are generated, this Wait is asynchronous operation. + virtual void WaitInputVarGenerated(); + + // Wait inputs are generated, this Wait is asynchronous operation. + virtual void WaitInputVarGenerated(const platform::Place &place); + + virtual bool NeedWait(VarHandleBase *in_var); + // If the Op involves data transfer of multiple devices that // will likely block other computations. virtual bool IsMultiDeviceTransfer() { return false; } diff --git a/paddle/fluid/framework/details/reduce_op_handle.cc b/paddle/fluid/framework/details/reduce_op_handle.cc index 1bb04c1dfca..f653064ade5 100644 --- a/paddle/fluid/framework/details/reduce_op_handle.cc +++ b/paddle/fluid/framework/details/reduce_op_handle.cc @@ -51,7 +51,7 @@ void ReduceOpHandle::RunImpl() { PADDLE_ENFORCE_NOT_NULL(pre_in_var); // Wait input done, this Wait is asynchronous operation - WaitInputVarGenerated(in_var_handles); + WaitInputVarGenerated(); // NOTE: The Places of all input tensor must be all on CPU or all on GPU. std::vector in_places; // used to get dev_ctx @@ -157,17 +157,6 @@ std::vector ReduceOpHandle::GetInputValues( return in_selected_rows; } -void ReduceOpHandle::WaitInputVarGenerated( - const std::vector &in_var_handles) { - for (auto *in : in_var_handles) { - if (in->generated_op_) { - for (auto pair : dev_ctxes_) { - in->generated_op_->Wait(pair.second); - } - } - } -} - std::string ReduceOpHandle::Name() const { return "reduce"; } } // namespace details } // namespace framework diff --git a/paddle/fluid/framework/details/reduce_op_handle.h b/paddle/fluid/framework/details/reduce_op_handle.h index 59731d348d1..c652a2f4eb0 100644 --- a/paddle/fluid/framework/details/reduce_op_handle.h +++ b/paddle/fluid/framework/details/reduce_op_handle.h @@ -60,8 +60,6 @@ struct ReduceOpHandle : public OpHandleBase { protected: void RunImpl() override; - void WaitInputVarGenerated(const std::vector &in_var_handles); - template std::vector GetInputValues( const std::vector &in_var_handles, diff --git a/paddle/fluid/framework/details/scale_loss_grad_op_handle.cc b/paddle/fluid/framework/details/scale_loss_grad_op_handle.cc index 1cd31130300..d9c387e79dc 100644 --- a/paddle/fluid/framework/details/scale_loss_grad_op_handle.cc +++ b/paddle/fluid/framework/details/scale_loss_grad_op_handle.cc @@ -29,6 +29,7 @@ ScaleLossGradOpHandle::ScaleLossGradOpHandle(size_t num_dev, Scope *scope, ScaleLossGradOpHandle::~ScaleLossGradOpHandle() {} void ScaleLossGradOpHandle::RunImpl() { + // Doesn't wait any event std::string var_name = static_cast(this->outputs_[0])->name_; auto &local_scope = *scope_->FindVar(kLocalExecScopeName)->Get(); diff --git a/paddle/fluid/framework/details/send_op_handle.cc b/paddle/fluid/framework/details/send_op_handle.cc index 0763f92171e..ee4beb5f9b9 100644 --- a/paddle/fluid/framework/details/send_op_handle.cc +++ b/paddle/fluid/framework/details/send_op_handle.cc @@ -27,13 +27,7 @@ SendOpHandle::SendOpHandle(const framework::OpDesc &op_desc, void SendOpHandle::RunImpl() { // Wait input done - for (auto *in : inputs_) { - auto &p = static_cast(in)->place_; - if (in->DebugString() == "dummy") { // HACK - continue; - } - in->generated_op_->Wait(dev_ctxes_[p]); - } + WaitInputVarGenerated(); auto &tmp_scope = local_scope_->FindVar(kLocalExecScopeName)->Get(); // FIXME(wuyi): can not use RunAndRecordEvent here, for it will cause dead // lock. -- GitLab