diff --git a/paddle/fluid/framework/details/op_handle_base.cc b/paddle/fluid/framework/details/op_handle_base.cc index eeff0f3d46d633c8f834dba96e0ada2e09dd86a0..e2f4f453ccfe35e6cd6fc98e2dc23c87cc6b8e1a 100644 --- a/paddle/fluid/framework/details/op_handle_base.cc +++ b/paddle/fluid/framework/details/op_handle_base.cc @@ -215,6 +215,13 @@ void OpHandleBase::WaitInputVarGenerated(bool wait_for_feed) { #else PADDLE_THROW( platform::errors::PreconditionNotMet("Not compiled with CUDA.")); +#endif + } else if (platform::is_xpu_place(place)) { +#ifdef PADDLE_WITH_XPU + dev_ctxes_.at(place)->Wait(); +#else + PADDLE_THROW( + platform::errors::PreconditionNotMet("Not compiled with XPU.")); #endif } // There are nothing to do when the place is CPUPlace. @@ -264,6 +271,19 @@ void OpHandleBase::WaitInputVarGenerated(const platform::Place &place) { #else PADDLE_THROW( platform::errors::PreconditionNotMet("Not compiled with CUDA.")); +#endif + } else if (platform::is_xpu_place(in_var_handle->place())) { +#ifdef PADDLE_WITH_XPU + PADDLE_ENFORCE_EQ( + platform::is_same_place(place, in_var_handle->place()), true, + platform::errors::InvalidArgument( + "The place of output(%s) is not consistent with the " + "place of current op(%s).", + in_var_handle->Name(), Name())); + dev_ctxes_.at(place)->Wait(); +#else + PADDLE_THROW( + platform::errors::PreconditionNotMet("Not compiled with XPU.")); #endif } // There are nothing to do when the place is CPUPlace. diff --git a/paddle/fluid/framework/ir/multi_devices_graph_pass/modify_op_lock_and_record_event_pass.cc b/paddle/fluid/framework/ir/multi_devices_graph_pass/modify_op_lock_and_record_event_pass.cc index e9b35aefc94e8544455e9559746990cdb4362ebb..70b95c9154fd300c358ce6d05c56e565c5ed9689 100644 --- a/paddle/fluid/framework/ir/multi_devices_graph_pass/modify_op_lock_and_record_event_pass.cc +++ b/paddle/fluid/framework/ir/multi_devices_graph_pass/modify_op_lock_and_record_event_pass.cc @@ -23,7 +23,9 @@ namespace ir { static bool IsLockAndRecordEventFreeComputationOpHandle( details::ComputationOpHandle *op, const OpGraphView &graph_view) { - if (!platform::is_gpu_place(op->GetPlace())) return false; + if (!platform::is_gpu_place(op->GetPlace()) && + !platform::is_xpu_place(op->GetPlace())) + return false; for (auto &pending_op : graph_view.PendingOps(op)) { auto *tmp = dynamic_cast(pending_op); if (tmp == nullptr || !(tmp->GetPlace() == op->GetPlace())) {