未验证 提交 9ad06e06 编写于 作者: W wanghuancoder 提交者: GitHub

refine fill with tensor (#56568)

上级 a771e343
...@@ -147,4 +147,7 @@ PD_REGISTER_KERNEL(full_with_tensor, ...@@ -147,4 +147,7 @@ PD_REGISTER_KERNEL(full_with_tensor,
phi::dtype::float16, phi::dtype::float16,
phi::dtype::bfloat16, phi::dtype::bfloat16,
phi::dtype::complex<float>, phi::dtype::complex<float>,
phi::dtype::complex<double>) {} phi::dtype::complex<double>) {
kernel->InputAt(0).SetBackend(phi::Backend::CPU);
kernel->InputAt(1).SetBackend(phi::Backend::CPU);
}
...@@ -162,4 +162,7 @@ PD_REGISTER_KERNEL(full_with_tensor, ...@@ -162,4 +162,7 @@ PD_REGISTER_KERNEL(full_with_tensor,
phi::dtype::float16, phi::dtype::float16,
phi::dtype::bfloat16, phi::dtype::bfloat16,
phi::dtype::complex<float>, phi::dtype::complex<float>,
phi::dtype::complex<double>) {} phi::dtype::complex<double>) {
kernel->InputAt(0).SetBackend(phi::Backend::CPU);
kernel->InputAt(1).SetBackend(phi::Backend::CPU);
}
...@@ -109,7 +109,10 @@ PD_REGISTER_KERNEL(full_with_tensor_sr, ...@@ -109,7 +109,10 @@ PD_REGISTER_KERNEL(full_with_tensor_sr,
phi::dtype::float16, phi::dtype::float16,
phi::dtype::bfloat16, phi::dtype::bfloat16,
phi::dtype::complex<float>, phi::dtype::complex<float>,
phi::dtype::complex<double>) {} phi::dtype::complex<double>) {
kernel->InputAt(0).SetBackend(phi::Backend::CPU);
kernel->InputAt(1).SetBackend(phi::Backend::CPU);
}
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
PD_REGISTER_KERNEL(full_with_tensor_sr, PD_REGISTER_KERNEL(full_with_tensor_sr,
...@@ -125,7 +128,10 @@ PD_REGISTER_KERNEL(full_with_tensor_sr, ...@@ -125,7 +128,10 @@ PD_REGISTER_KERNEL(full_with_tensor_sr,
bool, bool,
phi::dtype::float16, phi::dtype::float16,
phi::dtype::complex<float>, phi::dtype::complex<float>,
phi::dtype::complex<double>) {} phi::dtype::complex<double>) {
kernel->InputAt(0).SetBackend(phi::Backend::CPU);
kernel->InputAt(1).SetBackend(phi::Backend::CPU);
}
#endif #endif
#if defined(PADDLE_WITH_XPU) #if defined(PADDLE_WITH_XPU)
...@@ -139,5 +145,8 @@ PD_REGISTER_KERNEL(full_with_tensor_sr, ...@@ -139,5 +145,8 @@ PD_REGISTER_KERNEL(full_with_tensor_sr,
int, int,
int64_t, int64_t,
bool, bool,
phi::dtype::float16) {} phi::dtype::float16) {
kernel->InputAt(0).SetBackend(phi::Backend::CPU);
kernel->InputAt(1).SetBackend(phi::Backend::CPU);
}
#endif #endif
...@@ -165,4 +165,7 @@ PD_REGISTER_KERNEL(full_with_tensor, ...@@ -165,4 +165,7 @@ PD_REGISTER_KERNEL(full_with_tensor,
int, int,
int64_t, int64_t,
bool, bool,
phi::dtype::float16) {} phi::dtype::float16) {
kernel->InputAt(0).SetBackend(phi::Backend::CPU);
kernel->InputAt(1).SetBackend(phi::Backend::CPU);
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册