未验证 提交 254ad619 编写于 作者: L liuyuhui 提交者: GitHub

fix xpu pe sync, test=notest (#30095)

上级 0b8e1fad
......@@ -215,6 +215,13 @@ void OpHandleBase::WaitInputVarGenerated(bool wait_for_feed) {
#else
PADDLE_THROW(
platform::errors::PreconditionNotMet("Not compiled with CUDA."));
#endif
} else if (platform::is_xpu_place(place)) {
#ifdef PADDLE_WITH_XPU
dev_ctxes_.at(place)->Wait();
#else
PADDLE_THROW(
platform::errors::PreconditionNotMet("Not compiled with XPU."));
#endif
}
// There are nothing to do when the place is CPUPlace.
......@@ -264,6 +271,19 @@ void OpHandleBase::WaitInputVarGenerated(const platform::Place &place) {
#else
PADDLE_THROW(
platform::errors::PreconditionNotMet("Not compiled with CUDA."));
#endif
} else if (platform::is_xpu_place(in_var_handle->place())) {
#ifdef PADDLE_WITH_XPU
PADDLE_ENFORCE_EQ(
platform::is_same_place(place, in_var_handle->place()), true,
platform::errors::InvalidArgument(
"The place of output(%s) is not consistent with the "
"place of current op(%s).",
in_var_handle->Name(), Name()));
dev_ctxes_.at(place)->Wait();
#else
PADDLE_THROW(
platform::errors::PreconditionNotMet("Not compiled with XPU."));
#endif
}
// There are nothing to do when the place is CPUPlace.
......
......@@ -23,7 +23,9 @@ namespace ir {
static bool IsLockAndRecordEventFreeComputationOpHandle(
details::ComputationOpHandle *op, const OpGraphView &graph_view) {
if (!platform::is_gpu_place(op->GetPlace())) return false;
if (!platform::is_gpu_place(op->GetPlace()) &&
!platform::is_xpu_place(op->GetPlace()))
return false;
for (auto &pending_op : graph_view.PendingOps(op)) {
auto *tmp = dynamic_cast<details::ComputationOpHandle *>(pending_op);
if (tmp == nullptr || !(tmp->GetPlace() == op->GetPlace())) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册