未验证 提交 85545bbc 编写于 作者: L liuyuhui 提交者: GitHub

fix xpu pe sync, test=notest (#30095) (#30114)

上级 157ff094
...@@ -215,6 +215,13 @@ void OpHandleBase::WaitInputVarGenerated(bool wait_for_feed) { ...@@ -215,6 +215,13 @@ void OpHandleBase::WaitInputVarGenerated(bool wait_for_feed) {
#else #else
PADDLE_THROW( PADDLE_THROW(
platform::errors::PreconditionNotMet("Not compiled with CUDA.")); platform::errors::PreconditionNotMet("Not compiled with CUDA."));
#endif
} else if (platform::is_xpu_place(place)) {
#ifdef PADDLE_WITH_XPU
dev_ctxes_.at(place)->Wait();
#else
PADDLE_THROW(
platform::errors::PreconditionNotMet("Not compiled with XPU."));
#endif #endif
} }
// There are nothing to do when the place is CPUPlace. // There are nothing to do when the place is CPUPlace.
...@@ -264,6 +271,19 @@ void OpHandleBase::WaitInputVarGenerated(const platform::Place &place) { ...@@ -264,6 +271,19 @@ void OpHandleBase::WaitInputVarGenerated(const platform::Place &place) {
#else #else
PADDLE_THROW( PADDLE_THROW(
platform::errors::PreconditionNotMet("Not compiled with CUDA.")); platform::errors::PreconditionNotMet("Not compiled with CUDA."));
#endif
} else if (platform::is_xpu_place(in_var_handle->place())) {
#ifdef PADDLE_WITH_XPU
PADDLE_ENFORCE_EQ(
platform::is_same_place(place, in_var_handle->place()), true,
platform::errors::InvalidArgument(
"The place of output(%s) is not consistent with the "
"place of current op(%s).",
in_var_handle->Name(), Name()));
dev_ctxes_.at(place)->Wait();
#else
PADDLE_THROW(
platform::errors::PreconditionNotMet("Not compiled with XPU."));
#endif #endif
} }
// There are nothing to do when the place is CPUPlace. // There are nothing to do when the place is CPUPlace.
......
...@@ -23,7 +23,9 @@ namespace ir { ...@@ -23,7 +23,9 @@ namespace ir {
static bool IsLockAndRecordEventFreeComputationOpHandle( static bool IsLockAndRecordEventFreeComputationOpHandle(
details::ComputationOpHandle *op, const OpGraphView &graph_view) { details::ComputationOpHandle *op, const OpGraphView &graph_view) {
if (!platform::is_gpu_place(op->GetPlace())) return false; if (!platform::is_gpu_place(op->GetPlace()) &&
!platform::is_xpu_place(op->GetPlace()))
return false;
for (auto &pending_op : graph_view.PendingOps(op)) { for (auto &pending_op : graph_view.PendingOps(op)) {
auto *tmp = dynamic_cast<details::ComputationOpHandle *>(pending_op); auto *tmp = dynamic_cast<details::ComputationOpHandle *>(pending_op);
if (tmp == nullptr || !(tmp->GetPlace() == op->GetPlace())) { if (tmp == nullptr || !(tmp->GetPlace() == op->GetPlace())) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册