未验证 提交 9ad06e06 编写于 作者: W wanghuancoder 提交者: GitHub

refine fill with tensor (#56568)

上级 a771e343
......@@ -147,4 +147,7 @@ PD_REGISTER_KERNEL(full_with_tensor,
phi::dtype::float16,
phi::dtype::bfloat16,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
phi::dtype::complex<double>) {
kernel->InputAt(0).SetBackend(phi::Backend::CPU);
kernel->InputAt(1).SetBackend(phi::Backend::CPU);
}
......@@ -162,4 +162,7 @@ PD_REGISTER_KERNEL(full_with_tensor,
phi::dtype::float16,
phi::dtype::bfloat16,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
phi::dtype::complex<double>) {
kernel->InputAt(0).SetBackend(phi::Backend::CPU);
kernel->InputAt(1).SetBackend(phi::Backend::CPU);
}
......@@ -109,7 +109,10 @@ PD_REGISTER_KERNEL(full_with_tensor_sr,
phi::dtype::float16,
phi::dtype::bfloat16,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
phi::dtype::complex<double>) {
kernel->InputAt(0).SetBackend(phi::Backend::CPU);
kernel->InputAt(1).SetBackend(phi::Backend::CPU);
}
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
PD_REGISTER_KERNEL(full_with_tensor_sr,
......@@ -125,7 +128,10 @@ PD_REGISTER_KERNEL(full_with_tensor_sr,
bool,
phi::dtype::float16,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
phi::dtype::complex<double>) {
kernel->InputAt(0).SetBackend(phi::Backend::CPU);
kernel->InputAt(1).SetBackend(phi::Backend::CPU);
}
#endif
#if defined(PADDLE_WITH_XPU)
......@@ -139,5 +145,8 @@ PD_REGISTER_KERNEL(full_with_tensor_sr,
int,
int64_t,
bool,
phi::dtype::float16) {}
phi::dtype::float16) {
kernel->InputAt(0).SetBackend(phi::Backend::CPU);
kernel->InputAt(1).SetBackend(phi::Backend::CPU);
}
#endif
......@@ -165,4 +165,7 @@ PD_REGISTER_KERNEL(full_with_tensor,
int,
int64_t,
bool,
phi::dtype::float16) {}
phi::dtype::float16) {
kernel->InputAt(0).SetBackend(phi::Backend::CPU);
kernel->InputAt(1).SetBackend(phi::Backend::CPU);
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册