diff --git a/paddle/fluid/framework/details/computation_op_handle.cc b/paddle/fluid/framework/details/computation_op_handle.cc index ffbe2094a454d981f67431cfbdb53505d8d731a1..df05bb06333d6b964f2f5434c3d43214e5d2cb7a 100644 --- a/paddle/fluid/framework/details/computation_op_handle.cc +++ b/paddle/fluid/framework/details/computation_op_handle.cc @@ -35,7 +35,7 @@ void ComputationOpHandle::RunImpl() { bool ComputationOpHandle::NeedWait(VarHandleBase *in_var) { bool need_wait = - dynamic_cast(in_var) && in_var->generated_op_ && + in_var && in_var->generated_op_ && in_var->generated_op_->DeviceContext(place_) != dev_ctxes_[place_]; return need_wait; } diff --git a/paddle/fluid/framework/details/op_handle_base.cc b/paddle/fluid/framework/details/op_handle_base.cc index b05b9d95e73517da11ef236b00b4f5a749632e54..6b064650b4f09737836bda4a43fa421720077929 100644 --- a/paddle/fluid/framework/details/op_handle_base.cc +++ b/paddle/fluid/framework/details/op_handle_base.cc @@ -105,7 +105,7 @@ void OpHandleBase::WaitInputVarGenerated(const platform::Place &place) { } bool OpHandleBase::NeedWait(VarHandleBase *in_var) { - return dynamic_cast(in_var) && in_var->generated_op_; + return in_var && in_var->generated_op_; } void OpHandleBase::RunAndRecordEvent(const std::function &callback) { diff --git a/paddle/fluid/framework/details/send_op_handle.cc b/paddle/fluid/framework/details/send_op_handle.cc index 01f3a9df76532cbdd68aa94059a19f7e0ec1d461..7109659dd7001f91e7674ac7bebbe3a59794cfc0 100644 --- a/paddle/fluid/framework/details/send_op_handle.cc +++ b/paddle/fluid/framework/details/send_op_handle.cc @@ -26,8 +26,17 @@ SendOpHandle::SendOpHandle(const framework::OpDesc &op_desc, place_(place) {} void SendOpHandle::RunImpl() { + // TODO(wuyi): need further analysis whether wait VarDummyHandle. // Wait input done - WaitInputVarGenerated(place_); + for (auto *in : inputs_) { + auto &p = static_cast(in)->place_; + if (in->DebugString() == "dummy") { // HACK + continue; + } + if (in->generated_op_) { + in->generated_op_->RecordWaitEventOnCtx(dev_ctxes_[p]); + } + } auto &tmp_scope = local_scope_->FindVar(kLocalExecScopeName)->Get(); // FIXME(wuyi): can not use RunAndRecordEvent here, for it will cause dead // lock.