From b051455f47eaa2cc33237042c220d9724032b210 Mon Sep 17 00:00:00 2001 From: Feiyu Chan Date: Tue, 11 Oct 2022 21:39:20 +0800 Subject: [PATCH] set_value_op: add support for complex types (#46885) --- paddle/phi/kernels/cpu/set_value_grad_kernel.cc | 5 ++++- paddle/phi/kernels/cpu/set_value_kernel.cc | 9 +++++++-- paddle/phi/kernels/gpu/set_value_grad_kernel.cu | 5 ++++- paddle/phi/kernels/gpu/set_value_kernel.cu | 9 +++++++-- 4 files changed, 22 insertions(+), 6 deletions(-) diff --git a/paddle/phi/kernels/cpu/set_value_grad_kernel.cc b/paddle/phi/kernels/cpu/set_value_grad_kernel.cc index 882648e8c34..dad7628dcf3 100644 --- a/paddle/phi/kernels/cpu/set_value_grad_kernel.cc +++ b/paddle/phi/kernels/cpu/set_value_grad_kernel.cc @@ -15,6 +15,7 @@ #include "paddle/phi/kernels/set_value_grad_kernel.h" #include "paddle/phi/backends/cpu/cpu_context.h" +#include "paddle/phi/common/complex.h" #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/kernels/impl/set_value_grad_kernel_impl.h" @@ -27,4 +28,6 @@ PD_REGISTER_KERNEL(set_value_grad, int, int64_t, bool, - phi::dtype::float16) {} + phi::dtype::float16, + phi::dtype::complex, + phi::dtype::complex) {} diff --git a/paddle/phi/kernels/cpu/set_value_kernel.cc b/paddle/phi/kernels/cpu/set_value_kernel.cc index be5affb4ccf..4b0c0415e48 100644 --- a/paddle/phi/kernels/cpu/set_value_kernel.cc +++ b/paddle/phi/kernels/cpu/set_value_kernel.cc @@ -15,6 +15,7 @@ #include "paddle/phi/kernels/set_value_kernel.h" #include "paddle/phi/backends/cpu/cpu_context.h" +#include "paddle/phi/common/complex.h" #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/kernels/impl/set_value_kernel_impl.h" @@ -27,7 +28,9 @@ PD_REGISTER_KERNEL(set_value, int, int64_t, bool, - phi::dtype::float16) {} + phi::dtype::float16, + phi::dtype::complex, + phi::dtype::complex) {} PD_REGISTER_KERNEL(set_value_with_tensor, CPU, ALL_LAYOUT, @@ -37,4 +40,6 @@ PD_REGISTER_KERNEL(set_value_with_tensor, int, int64_t, bool, - phi::dtype::float16) {} + phi::dtype::float16, + phi::dtype::complex, + phi::dtype::complex) {} diff --git a/paddle/phi/kernels/gpu/set_value_grad_kernel.cu b/paddle/phi/kernels/gpu/set_value_grad_kernel.cu index 49a57b94418..77e140cab14 100644 --- a/paddle/phi/kernels/gpu/set_value_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/set_value_grad_kernel.cu @@ -15,6 +15,7 @@ #include "paddle/phi/kernels/set_value_grad_kernel.h" #include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/common/complex.h" #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/kernels/impl/set_value_grad_kernel_impl.h" @@ -27,4 +28,6 @@ PD_REGISTER_KERNEL(set_value_grad, int, int64_t, bool, - phi::dtype::float16) {} + phi::dtype::float16, + phi::dtype::complex, + phi::dtype::complex) {} diff --git a/paddle/phi/kernels/gpu/set_value_kernel.cu b/paddle/phi/kernels/gpu/set_value_kernel.cu index 0e6c5734852..1a268c2f6b0 100644 --- a/paddle/phi/kernels/gpu/set_value_kernel.cu +++ b/paddle/phi/kernels/gpu/set_value_kernel.cu @@ -15,6 +15,7 @@ #include "paddle/phi/kernels/set_value_kernel.h" #include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/common/complex.h" #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/kernels/impl/set_value_kernel_impl.h" @@ -27,7 +28,9 @@ PD_REGISTER_KERNEL(set_value, int, int64_t, bool, - paddle::platform::float16) {} + paddle::platform::float16, + phi::dtype::complex, + phi::dtype::complex) {} PD_REGISTER_KERNEL(set_value_with_tensor, GPU, ALL_LAYOUT, @@ -37,4 +40,6 @@ PD_REGISTER_KERNEL(set_value_with_tensor, int, int64_t, bool, - paddle::platform::float16) {} + paddle::platform::float16, + phi::dtype::complex, + phi::dtype::complex) {} -- GitLab