未验证 提交 edf46919 编写于 作者: D duanyanhui 提交者: GitHub

Add support for control-flow ops on custom devices (#48259)

上级 b994c89d
......@@ -89,6 +89,13 @@ class ConditionalOp : public framework::OperatorBase {
framework::TensorCopy(*ips[0], platform::CPUPlace(), &cpu_tensor);
platform::DeviceContextPool::Instance().Get(ips[0]->place())->Wait();
res = cpu_tensor.data<bool>()[0];
#endif
} else if (platform::is_custom_place(ips[0]->place())) {
#if defined(PADDLE_WITH_CUSTOM_DEVICE)
phi::DenseTensor cpu_tensor;
framework::TensorCopy(*ips[0], platform::CPUPlace(), &cpu_tensor);
platform::DeviceContextPool::Instance().Get(ips[0]->place())->Wait();
res = cpu_tensor.data<bool>()[0];
#endif
} else {
res = ips[0]->data<bool>()[0];
......
......@@ -228,8 +228,9 @@ bool GetCondData(const phi::DenseTensor &cond) {
// platform::is_npu_place(cond.place()) or
// platform::is_xpu_place(cond.place()) or
// platform::is_custom_place(cond.place()) is true
std::unique_ptr<phi::DenseTensor> cpu_cond{new phi::DenseTensor()};
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) || \
defined(PADDLE_WITH_ASCEND_CL) || defined(PADDLE_WITH_XPU)
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) || \
defined(PADDLE_WITH_ASCEND_CL) || defined(PADDLE_WITH_XPU) || \
defined(PADDLE_WITH_CUSTOM_DEVICE)
framework::TensorCopySync(cond, platform::CPUPlace(), cpu_cond.get());
#else
PADDLE_THROW(platform::errors::PreconditionNotMet(
......
......@@ -39,7 +39,8 @@ inline int GetBranchNumber(const phi::DenseTensor &mask) {
}
// when platform::is_gpu_place(mask.place()) or
// platform::is_custom_place(mask.place()) is true
std::unique_ptr<phi::DenseTensor> cpu_mask{new phi::DenseTensor()};
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) || \
defined(PADDLE_WITH_CUSTOM_DEVICE)
framework::TensorCopySync(mask, platform::CPUPlace(), cpu_mask.get());
#else
PADDLE_THROW(platform::errors::PreconditionNotMet(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册