From 254ad6195999c2ba0c064d79a232fc7422ef37ef Mon Sep 17 00:00:00 2001 From: liuyuhui Date: Tue, 5 Jan 2021 15:09:19 +0800 Subject: [PATCH] fix xpu pe sync, test=notest (#30095) --- .../fluid/framework/details/op_handle_base.cc | 20 +++++++++++++++++++ .../modify_op_lock_and_record_event_pass.cc | 4 +++- 2 files changed, 23 insertions(+), 1 deletion(-) diff --git a/paddle/fluid/framework/details/op_handle_base.cc b/paddle/fluid/framework/details/op_handle_base.cc index eeff0f3d46..e2f4f453cc 100644 --- a/paddle/fluid/framework/details/op_handle_base.cc +++ b/paddle/fluid/framework/details/op_handle_base.cc @@ -215,6 +215,13 @@ void OpHandleBase::WaitInputVarGenerated(bool wait_for_feed) { #else PADDLE_THROW( platform::errors::PreconditionNotMet("Not compiled with CUDA.")); +#endif + } else if (platform::is_xpu_place(place)) { +#ifdef PADDLE_WITH_XPU + dev_ctxes_.at(place)->Wait(); +#else + PADDLE_THROW( + platform::errors::PreconditionNotMet("Not compiled with XPU.")); #endif } // There are nothing to do when the place is CPUPlace. @@ -264,6 +271,19 @@ void OpHandleBase::WaitInputVarGenerated(const platform::Place &place) { #else PADDLE_THROW( platform::errors::PreconditionNotMet("Not compiled with CUDA.")); +#endif + } else if (platform::is_xpu_place(in_var_handle->place())) { +#ifdef PADDLE_WITH_XPU + PADDLE_ENFORCE_EQ( + platform::is_same_place(place, in_var_handle->place()), true, + platform::errors::InvalidArgument( + "The place of output(%s) is not consistent with the " + "place of current op(%s).", + in_var_handle->Name(), Name())); + dev_ctxes_.at(place)->Wait(); +#else + PADDLE_THROW( + platform::errors::PreconditionNotMet("Not compiled with XPU.")); #endif } // There are nothing to do when the place is CPUPlace. diff --git a/paddle/fluid/framework/ir/multi_devices_graph_pass/modify_op_lock_and_record_event_pass.cc b/paddle/fluid/framework/ir/multi_devices_graph_pass/modify_op_lock_and_record_event_pass.cc index e9b35aefc9..70b95c9154 100644 --- a/paddle/fluid/framework/ir/multi_devices_graph_pass/modify_op_lock_and_record_event_pass.cc +++ b/paddle/fluid/framework/ir/multi_devices_graph_pass/modify_op_lock_and_record_event_pass.cc @@ -23,7 +23,9 @@ namespace ir { static bool IsLockAndRecordEventFreeComputationOpHandle( details::ComputationOpHandle *op, const OpGraphView &graph_view) { - if (!platform::is_gpu_place(op->GetPlace())) return false; + if (!platform::is_gpu_place(op->GetPlace()) && + !platform::is_xpu_place(op->GetPlace())) + return false; for (auto &pending_op : graph_view.PendingOps(op)) { auto *tmp = dynamic_cast(pending_op); if (tmp == nullptr || !(tmp->GetPlace() == op->GetPlace())) { -- GitLab