未验证 提交 d08a1a0d 编写于 作者: D duanyanhui 提交者: GitHub

add xpu controlflow (#51488)

上级 4a484973
...@@ -89,6 +89,13 @@ class ConditionalOp : public framework::OperatorBase { ...@@ -89,6 +89,13 @@ class ConditionalOp : public framework::OperatorBase {
framework::TensorCopy(*ips[0], platform::CPUPlace(), &cpu_tensor); framework::TensorCopy(*ips[0], platform::CPUPlace(), &cpu_tensor);
platform::DeviceContextPool::Instance().Get(ips[0]->place())->Wait(); platform::DeviceContextPool::Instance().Get(ips[0]->place())->Wait();
res = cpu_tensor.data<bool>()[0]; res = cpu_tensor.data<bool>()[0];
#endif
} else if (platform::is_xpu_place(ips[0]->place())) {
#ifdef PADDLE_WITH_XPU
phi::DenseTensor cpu_tensor;
framework::TensorCopy(*ips[0], platform::CPUPlace(), &cpu_tensor);
platform::DeviceContextPool::Instance().Get(ips[0]->place())->Wait();
res = cpu_tensor.data<bool>()[0];
#endif #endif
} else if (platform::is_custom_place(ips[0]->place())) { } else if (platform::is_custom_place(ips[0]->place())) {
#if defined(PADDLE_WITH_CUSTOM_DEVICE) #if defined(PADDLE_WITH_CUSTOM_DEVICE)
......
...@@ -40,7 +40,7 @@ inline int GetBranchNumber(const phi::DenseTensor &mask) { ...@@ -40,7 +40,7 @@ inline int GetBranchNumber(const phi::DenseTensor &mask) {
// when platform::is_gpu_place(mask.place()) is true // when platform::is_gpu_place(mask.place()) is true
std::unique_ptr<phi::DenseTensor> cpu_mask{new phi::DenseTensor()}; std::unique_ptr<phi::DenseTensor> cpu_mask{new phi::DenseTensor()};
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) || \ #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) || \
defined(PADDLE_WITH_CUSTOM_DEVICE) defined(PADDLE_WITH_CUSTOM_DEVICE) || defined(PADDLE_WITH_XPU)
framework::TensorCopySync(mask, platform::CPUPlace(), cpu_mask.get()); framework::TensorCopySync(mask, platform::CPUPlace(), cpu_mask.get());
#else #else
PADDLE_THROW(platform::errors::PreconditionNotMet( PADDLE_THROW(platform::errors::PreconditionNotMet(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册