diff --git a/paddle/fluid/framework/details/broadcast_op_handle.cc b/paddle/fluid/framework/details/broadcast_op_handle.cc index 2afa47c81bead6fb104f49886713bf75dc1b4dc0..f176e4e1599a05108e25b3e1085ac3c4339d3793 100644 --- a/paddle/fluid/framework/details/broadcast_op_handle.cc +++ b/paddle/fluid/framework/details/broadcast_op_handle.cc @@ -38,9 +38,7 @@ void BroadcastOpHandle::RunImpl() { out_var_handles.size(), places_.size(), "The number of output should equal to the number of places."); - // Wait input done, this Wait is asynchronous operation platform::Place - // &in_place; - WaitInputVarGenerated(*in_var_handle); + WaitInputVarGenerated(); std::vector var_scopes; for (auto *s : local_scopes_) { @@ -147,14 +145,6 @@ void BroadcastOpHandle::RunImpl() { } } -void BroadcastOpHandle::WaitInputVarGenerated(const VarHandle &in_var) { - if (in_var.generated_op_) { - for (auto &pair : dev_ctxes_) { - in_var.generated_op_->Wait(pair.second); - } - } -} - std::string BroadcastOpHandle::Name() const { return "broadcast"; } } // namespace details } // namespace framework diff --git a/paddle/fluid/framework/details/broadcast_op_handle.h b/paddle/fluid/framework/details/broadcast_op_handle.h index 984a95008c0393eff01c2d419cc98949aed14980..48e356af4bace846563fd729f40cdd0f8b7a415e 100644 --- a/paddle/fluid/framework/details/broadcast_op_handle.h +++ b/paddle/fluid/framework/details/broadcast_op_handle.h @@ -57,7 +57,6 @@ struct BroadcastOpHandle : public OpHandleBase { protected: void RunImpl() override; - void WaitInputVarGenerated(const VarHandle &in_var); private: const std::vector &local_scopes_; diff --git a/paddle/fluid/framework/details/computation_op_handle.cc b/paddle/fluid/framework/details/computation_op_handle.cc index 7ff0efe09387b7e5d7cfe0dfe5e129ca9914d90b..ffbe2094a454d981f67431cfbdb53505d8d731a1 100644 --- a/paddle/fluid/framework/details/computation_op_handle.cc +++ b/paddle/fluid/framework/details/computation_op_handle.cc @@ -26,20 +26,20 @@ ComputationOpHandle::ComputationOpHandle(const OpDesc &op_desc, Scope *scope, place_(place) {} void ComputationOpHandle::RunImpl() { - auto *cur_ctx = dev_ctxes_[place_]; - for (auto *in : inputs_) { - bool need_wait = in->generated_op_ && - in->generated_op_->DeviceContext(place_) != cur_ctx; - if (need_wait) { - in->generated_op_->Wait(cur_ctx); - } - } + WaitInputVarGenerated(place_); this->RunAndRecordEvent([this] { op_->Run(*scope_->FindVar(kLocalExecScopeName)->Get(), place_); }); } +bool ComputationOpHandle::NeedWait(VarHandleBase *in_var) { + bool need_wait = + dynamic_cast(in_var) && in_var->generated_op_ && + in_var->generated_op_->DeviceContext(place_) != dev_ctxes_[place_]; + return need_wait; +} + std::string ComputationOpHandle::Name() const { return op_->Type(); } } // namespace details } // namespace framework diff --git a/paddle/fluid/framework/details/computation_op_handle.h b/paddle/fluid/framework/details/computation_op_handle.h index c363b973d9abbae6bea76c2458fbe82a37a342ca..36e6f1bf59a7646e1dff6c4844f2a36a5caf363a 100644 --- a/paddle/fluid/framework/details/computation_op_handle.h +++ b/paddle/fluid/framework/details/computation_op_handle.h @@ -36,6 +36,8 @@ struct ComputationOpHandle : public OpHandleBase { protected: void RunImpl() override; + virtual bool NeedWait(VarHandleBase *in_var); + private: std::unique_ptr op_; Scope *scope_; diff --git a/paddle/fluid/framework/details/fetch_op_handle.cc b/paddle/fluid/framework/details/fetch_op_handle.cc index 1e8ca20b51d43554cf1898b41b31c27b90e6c642..b1c9dd0d15223f7d1bf6ea44144589f1de927e3e 100644 --- a/paddle/fluid/framework/details/fetch_op_handle.cc +++ b/paddle/fluid/framework/details/fetch_op_handle.cc @@ -31,7 +31,7 @@ FetchOpHandle::~FetchOpHandle() { } } -void FetchOpHandle::Wait(platform::DeviceContext *waited_dev) { +void FetchOpHandle::RecordWaitEventOnCtx(platform::DeviceContext *waited_ctx) { PADDLE_THROW("Nobody should wait FetchOp. Unexpceted Error"); } @@ -45,12 +45,8 @@ void FetchOpHandle::WaitAndMergeCPUTensors() const { } void FetchOpHandle::RunImpl() { - auto cpu_ctx = - platform::DeviceContextPool::Instance().Get(platform::CPUPlace()); - for (auto *input : inputs_) { - auto *var = static_cast(input); - var->generated_op_->Wait(cpu_ctx); - } + WaitInputVarGenerated(platform::CPUPlace()); + tensors_.resize(inputs_.size()); auto *var_handle = static_cast(inputs_[0]); auto &var_name = var_handle->name_; @@ -77,6 +73,15 @@ void FetchOpHandle::RunImpl() { this->WaitAndMergeCPUTensors(); } +void FetchOpHandle::WaitInputVarGenerated(const platform::Place &place) { + auto cpu_ctx = platform::DeviceContextPool::Instance().Get(place); + for (auto *input : inputs_) { + if (input->generated_op_) { + input->generated_op_->RecordWaitEventOnCtx(cpu_ctx); + } + } +} + std::string FetchOpHandle::Name() const { return "Fetch"; } } // namespace details diff --git a/paddle/fluid/framework/details/fetch_op_handle.h b/paddle/fluid/framework/details/fetch_op_handle.h index b49f3df338dc11310a4a0c27c8aaae3602373fcc..e696a7a9ce562e7f1b7fe6633623cb940810fbe1 100644 --- a/paddle/fluid/framework/details/fetch_op_handle.h +++ b/paddle/fluid/framework/details/fetch_op_handle.h @@ -33,7 +33,7 @@ struct FetchOpHandle : public OpHandleBase { ~FetchOpHandle(); - void Wait(platform::DeviceContext *waited_dev) override; + void RecordWaitEventOnCtx(platform::DeviceContext *waited_ctx) override; void WaitAndMergeCPUTensors() const; @@ -42,6 +42,8 @@ struct FetchOpHandle : public OpHandleBase { protected: void RunImpl() override; + virtual void WaitInputVarGenerated(const platform::Place &place); + private: FeedFetchList *data_; size_t offset_; diff --git a/paddle/fluid/framework/details/gather_op_handle.cc b/paddle/fluid/framework/details/gather_op_handle.cc index 3dfc972a44c62bd2adfc1331f29ffb1cca537652..2be02304566cf5dbe348fa01fc4171990eafd158 100644 --- a/paddle/fluid/framework/details/gather_op_handle.cc +++ b/paddle/fluid/framework/details/gather_op_handle.cc @@ -55,7 +55,7 @@ void GatherOpHandle::RunImpl() { "Currently, gather_op only can gather SelectedRows."); // Wait input done, this Wait is asynchronous operation - WaitInputVarGenerated(in_var_handles); + WaitInputVarGenerated(); auto &pre_in_value = pre_in_var->Get(); std::vector out_rows; @@ -111,17 +111,6 @@ void GatherOpHandle::RunImpl() { }); } -void GatherOpHandle::WaitInputVarGenerated( - const std::vector &in_var_handles) { - for (auto *in : in_var_handles) { - if (in->generated_op_) { - for (auto pair : dev_ctxes_) { - in->generated_op_->Wait(pair.second); - } - } - } -} - std::string GatherOpHandle::Name() const { return "gather"; } } // namespace details } // namespace framework diff --git a/paddle/fluid/framework/details/gather_op_handle.h b/paddle/fluid/framework/details/gather_op_handle.h index c394dd7a14b07cb956aa1aedfc0df4fa25744dd7..d11ef8556aa8840949ca8dc7aa176413f70b9f22 100644 --- a/paddle/fluid/framework/details/gather_op_handle.h +++ b/paddle/fluid/framework/details/gather_op_handle.h @@ -39,7 +39,6 @@ struct GatherOpHandle : public OpHandleBase { protected: void RunImpl() override; - void WaitInputVarGenerated(const std::vector &in_var_handles); private: const std::vector &local_scopes_; diff --git a/paddle/fluid/framework/details/nccl_all_reduce_op_handle.cc b/paddle/fluid/framework/details/nccl_all_reduce_op_handle.cc index b055bb48f608c9fd9cc671d175cb463d25dc489b..95aa599cd3e403e9cc66b2b5ad35d0d214d1ab5b 100644 --- a/paddle/fluid/framework/details/nccl_all_reduce_op_handle.cc +++ b/paddle/fluid/framework/details/nccl_all_reduce_op_handle.cc @@ -34,10 +34,7 @@ void NCCLAllReduceOpHandle::RunImpl() { return; // No need to all reduce when GPU count = 1; } else { // Wait input done - for (auto *in : inputs_) { - auto &p = static_cast(in)->place_; - in->generated_op_->Wait(dev_ctxes_[p]); - } + WaitInputVarGenerated(); auto &var_name = static_cast(this->inputs_[0])->name_; int dtype = -1; diff --git a/paddle/fluid/framework/details/op_handle_base.cc b/paddle/fluid/framework/details/op_handle_base.cc index 534d77860f87be08c8834efd373d90eb199ed6a2..b05b9d95e73517da11ef236b00b4f5a749632e54 100644 --- a/paddle/fluid/framework/details/op_handle_base.cc +++ b/paddle/fluid/framework/details/op_handle_base.cc @@ -56,15 +56,15 @@ void OpHandleBase::Run(bool use_event) { RunImpl(); } -void OpHandleBase::Wait(platform::DeviceContext *waited_dev) { +void OpHandleBase::RecordWaitEventOnCtx(platform::DeviceContext *waited_ctx) { #ifdef PADDLE_WITH_CUDA - if (platform::is_cpu_place(waited_dev->GetPlace()) || events_.empty()) { + if (platform::is_cpu_place(waited_ctx->GetPlace()) || events_.empty()) { for (auto &dev_ctx : dev_ctxes_) { dev_ctx.second->Wait(); } } else { auto stream = - static_cast(waited_dev)->stream(); + static_cast(waited_ctx)->stream(); for (auto &ev : events_) { PADDLE_ENFORCE(cudaStreamWaitEvent(stream, ev.second, 0)); } @@ -86,6 +86,28 @@ void OpHandleBase::AddOutput(VarHandleBase *out) { out->generated_op_ = this; } +void OpHandleBase::WaitInputVarGenerated() { + for (auto in_var : inputs_) { + if (NeedWait(in_var)) { + for (auto &pair : dev_ctxes_) { + in_var->generated_op_->RecordWaitEventOnCtx(pair.second); + } + } + } +} + +void OpHandleBase::WaitInputVarGenerated(const platform::Place &place) { + for (auto *in : inputs_) { + if (NeedWait(in)) { + in->generated_op_->RecordWaitEventOnCtx(dev_ctxes_[place]); + } + } +} + +bool OpHandleBase::NeedWait(VarHandleBase *in_var) { + return dynamic_cast(in_var) && in_var->generated_op_; +} + void OpHandleBase::RunAndRecordEvent(const std::function &callback) { #ifdef PADDLE_WITH_CUDA if (!events_.empty()) { // Use event diff --git a/paddle/fluid/framework/details/op_handle_base.h b/paddle/fluid/framework/details/op_handle_base.h index 00f213f3ed294adcce7c540e3ff346de8e2be7fb..3d4a09312542b3ba9a4cf3d7697beb9c790868e9 100644 --- a/paddle/fluid/framework/details/op_handle_base.h +++ b/paddle/fluid/framework/details/op_handle_base.h @@ -38,12 +38,20 @@ class OpHandleBase { void Run(bool use_event); - virtual void Wait(platform::DeviceContext *waited_dev); + virtual void RecordWaitEventOnCtx(platform::DeviceContext *waited_ctx); void AddInput(VarHandleBase *in); void AddOutput(VarHandleBase *out); + // Wait inputs are generated, this Wait is asynchronous operation. + virtual void WaitInputVarGenerated(); + + // Wait inputs are generated, this Wait is asynchronous operation. + virtual void WaitInputVarGenerated(const platform::Place &place); + + virtual bool NeedWait(VarHandleBase *in_var); + // If the Op involves data transfer of multiple devices that // will likely block other computations. virtual bool IsMultiDeviceTransfer() { return false; } diff --git a/paddle/fluid/framework/details/reduce_op_handle.cc b/paddle/fluid/framework/details/reduce_op_handle.cc index 1bb04c1dfca107f4b7ce4c599e9aa132de3e5985..f653064ade58d7f52d771d432e53d7edac6df24b 100644 --- a/paddle/fluid/framework/details/reduce_op_handle.cc +++ b/paddle/fluid/framework/details/reduce_op_handle.cc @@ -51,7 +51,7 @@ void ReduceOpHandle::RunImpl() { PADDLE_ENFORCE_NOT_NULL(pre_in_var); // Wait input done, this Wait is asynchronous operation - WaitInputVarGenerated(in_var_handles); + WaitInputVarGenerated(); // NOTE: The Places of all input tensor must be all on CPU or all on GPU. std::vector in_places; // used to get dev_ctx @@ -157,17 +157,6 @@ std::vector ReduceOpHandle::GetInputValues( return in_selected_rows; } -void ReduceOpHandle::WaitInputVarGenerated( - const std::vector &in_var_handles) { - for (auto *in : in_var_handles) { - if (in->generated_op_) { - for (auto pair : dev_ctxes_) { - in->generated_op_->Wait(pair.second); - } - } - } -} - std::string ReduceOpHandle::Name() const { return "reduce"; } } // namespace details } // namespace framework diff --git a/paddle/fluid/framework/details/reduce_op_handle.h b/paddle/fluid/framework/details/reduce_op_handle.h index 59731d348d17755fbd8bf3b6fa29b32bdefaf71e..c652a2f4eb0f9b73cb19ebbd9d0809210b280ad3 100644 --- a/paddle/fluid/framework/details/reduce_op_handle.h +++ b/paddle/fluid/framework/details/reduce_op_handle.h @@ -60,8 +60,6 @@ struct ReduceOpHandle : public OpHandleBase { protected: void RunImpl() override; - void WaitInputVarGenerated(const std::vector &in_var_handles); - template std::vector GetInputValues( const std::vector &in_var_handles, diff --git a/paddle/fluid/framework/details/scale_loss_grad_op_handle.cc b/paddle/fluid/framework/details/scale_loss_grad_op_handle.cc index 1cd3113030086104e7fc5c4ba3364a5ff027632b..d9c387e79dc71288e7330597fed57171d447f31b 100644 --- a/paddle/fluid/framework/details/scale_loss_grad_op_handle.cc +++ b/paddle/fluid/framework/details/scale_loss_grad_op_handle.cc @@ -29,6 +29,7 @@ ScaleLossGradOpHandle::ScaleLossGradOpHandle(size_t num_dev, Scope *scope, ScaleLossGradOpHandle::~ScaleLossGradOpHandle() {} void ScaleLossGradOpHandle::RunImpl() { + // Doesn't wait any event std::string var_name = static_cast(this->outputs_[0])->name_; auto &local_scope = *scope_->FindVar(kLocalExecScopeName)->Get(); diff --git a/paddle/fluid/framework/details/send_op_handle.cc b/paddle/fluid/framework/details/send_op_handle.cc index 0763f92171e7813ec0ee8ca4f3aa42b76205130a..ee4beb5f9b91da234c013a91919caeee5cda7378 100644 --- a/paddle/fluid/framework/details/send_op_handle.cc +++ b/paddle/fluid/framework/details/send_op_handle.cc @@ -27,13 +27,7 @@ SendOpHandle::SendOpHandle(const framework::OpDesc &op_desc, void SendOpHandle::RunImpl() { // Wait input done - for (auto *in : inputs_) { - auto &p = static_cast(in)->place_; - if (in->DebugString() == "dummy") { // HACK - continue; - } - in->generated_op_->Wait(dev_ctxes_[p]); - } + WaitInputVarGenerated(); auto &tmp_scope = local_scope_->FindVar(kLocalExecScopeName)->Get(); // FIXME(wuyi): can not use RunAndRecordEvent here, for it will cause dead // lock.