From a89cd46700524b2fd2bec7407464e344030b8158 Mon Sep 17 00:00:00 2001 From: chengduoZH Date: Thu, 10 May 2018 20:34:11 +0800 Subject: [PATCH] Wait VarDummyHandle generated --- .../fluid/framework/details/computation_op_handle.cc | 2 +- paddle/fluid/framework/details/op_handle_base.cc | 2 +- paddle/fluid/framework/details/send_op_handle.cc | 11 ++++++++++- 3 files changed, 12 insertions(+), 3 deletions(-) diff --git a/paddle/fluid/framework/details/computation_op_handle.cc b/paddle/fluid/framework/details/computation_op_handle.cc index ffbe2094a..df05bb063 100644 --- a/paddle/fluid/framework/details/computation_op_handle.cc +++ b/paddle/fluid/framework/details/computation_op_handle.cc @@ -35,7 +35,7 @@ void ComputationOpHandle::RunImpl() { bool ComputationOpHandle::NeedWait(VarHandleBase *in_var) { bool need_wait = - dynamic_cast(in_var) && in_var->generated_op_ && + in_var && in_var->generated_op_ && in_var->generated_op_->DeviceContext(place_) != dev_ctxes_[place_]; return need_wait; } diff --git a/paddle/fluid/framework/details/op_handle_base.cc b/paddle/fluid/framework/details/op_handle_base.cc index b05b9d95e..6b064650b 100644 --- a/paddle/fluid/framework/details/op_handle_base.cc +++ b/paddle/fluid/framework/details/op_handle_base.cc @@ -105,7 +105,7 @@ void OpHandleBase::WaitInputVarGenerated(const platform::Place &place) { } bool OpHandleBase::NeedWait(VarHandleBase *in_var) { - return dynamic_cast(in_var) && in_var->generated_op_; + return in_var && in_var->generated_op_; } void OpHandleBase::RunAndRecordEvent(const std::function &callback) { diff --git a/paddle/fluid/framework/details/send_op_handle.cc b/paddle/fluid/framework/details/send_op_handle.cc index 01f3a9df7..7109659dd 100644 --- a/paddle/fluid/framework/details/send_op_handle.cc +++ b/paddle/fluid/framework/details/send_op_handle.cc @@ -26,8 +26,17 @@ SendOpHandle::SendOpHandle(const framework::OpDesc &op_desc, place_(place) {} void SendOpHandle::RunImpl() { + // TODO(wuyi): need further analysis whether wait VarDummyHandle. // Wait input done - WaitInputVarGenerated(place_); + for (auto *in : inputs_) { + auto &p = static_cast(in)->place_; + if (in->DebugString() == "dummy") { // HACK + continue; + } + if (in->generated_op_) { + in->generated_op_->RecordWaitEventOnCtx(dev_ctxes_[p]); + } + } auto &tmp_scope = local_scope_->FindVar(kLocalExecScopeName)->Get(); // FIXME(wuyi): can not use RunAndRecordEvent here, for it will cause dead // lock. -- GitLab