diff --git a/paddle/phi/kernels/cpu/set_value_grad_kernel.cc b/paddle/phi/kernels/cpu/set_value_grad_kernel.cc index 882648e8c346a62f0f2b5ad3f63da29944c03cd8..dad7628dcf30a732ec6b3fddbd9890f82afb22a2 100644 --- a/paddle/phi/kernels/cpu/set_value_grad_kernel.cc +++ b/paddle/phi/kernels/cpu/set_value_grad_kernel.cc @@ -15,6 +15,7 @@ #include "paddle/phi/kernels/set_value_grad_kernel.h" #include "paddle/phi/backends/cpu/cpu_context.h" +#include "paddle/phi/common/complex.h" #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/kernels/impl/set_value_grad_kernel_impl.h" @@ -27,4 +28,6 @@ PD_REGISTER_KERNEL(set_value_grad, int, int64_t, bool, - phi::dtype::float16) {} + phi::dtype::float16, + phi::dtype::complex, + phi::dtype::complex) {} diff --git a/paddle/phi/kernels/cpu/set_value_kernel.cc b/paddle/phi/kernels/cpu/set_value_kernel.cc index be5affb4ccfbfcd53f3b0e0e85181e9131a341a8..4b0c0415e483491cce7476bc3a1211de7bf8b516 100644 --- a/paddle/phi/kernels/cpu/set_value_kernel.cc +++ b/paddle/phi/kernels/cpu/set_value_kernel.cc @@ -15,6 +15,7 @@ #include "paddle/phi/kernels/set_value_kernel.h" #include "paddle/phi/backends/cpu/cpu_context.h" +#include "paddle/phi/common/complex.h" #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/kernels/impl/set_value_kernel_impl.h" @@ -27,7 +28,9 @@ PD_REGISTER_KERNEL(set_value, int, int64_t, bool, - phi::dtype::float16) {} + phi::dtype::float16, + phi::dtype::complex, + phi::dtype::complex) {} PD_REGISTER_KERNEL(set_value_with_tensor, CPU, ALL_LAYOUT, @@ -37,4 +40,6 @@ PD_REGISTER_KERNEL(set_value_with_tensor, int, int64_t, bool, - phi::dtype::float16) {} + phi::dtype::float16, + phi::dtype::complex, + phi::dtype::complex) {} diff --git a/paddle/phi/kernels/gpu/set_value_grad_kernel.cu b/paddle/phi/kernels/gpu/set_value_grad_kernel.cu index 49a57b944187215a113ed256b48c6793425cdf9a..77e140cab14ce6deb4fa720db55a48097f34df19 100644 --- a/paddle/phi/kernels/gpu/set_value_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/set_value_grad_kernel.cu @@ -15,6 +15,7 @@ #include "paddle/phi/kernels/set_value_grad_kernel.h" #include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/common/complex.h" #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/kernels/impl/set_value_grad_kernel_impl.h" @@ -27,4 +28,6 @@ PD_REGISTER_KERNEL(set_value_grad, int, int64_t, bool, - phi::dtype::float16) {} + phi::dtype::float16, + phi::dtype::complex, + phi::dtype::complex) {} diff --git a/paddle/phi/kernels/gpu/set_value_kernel.cu b/paddle/phi/kernels/gpu/set_value_kernel.cu index 0e6c5734852b787153ff583c961f05e275ec9839..1a268c2f6b089b2dc77d1480acc0bf0ca56e0d76 100644 --- a/paddle/phi/kernels/gpu/set_value_kernel.cu +++ b/paddle/phi/kernels/gpu/set_value_kernel.cu @@ -15,6 +15,7 @@ #include "paddle/phi/kernels/set_value_kernel.h" #include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/common/complex.h" #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/kernels/impl/set_value_kernel_impl.h" @@ -27,7 +28,9 @@ PD_REGISTER_KERNEL(set_value, int, int64_t, bool, - paddle::platform::float16) {} + paddle::platform::float16, + phi::dtype::complex, + phi::dtype::complex) {} PD_REGISTER_KERNEL(set_value_with_tensor, GPU, ALL_LAYOUT, @@ -37,4 +40,6 @@ PD_REGISTER_KERNEL(set_value_with_tensor, int, int64_t, bool, - paddle::platform::float16) {} + paddle::platform::float16, + phi::dtype::complex, + phi::dtype::complex) {}