未验证 提交 34c7e3e3 编写于 作者: F Feiyu Chan 提交者: GitHub

set_value_op: add support for complex types (#46884)

上级 8a9d4003
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
#include "paddle/phi/kernels/set_value_grad_kernel.h" #include "paddle/phi/kernels/set_value_grad_kernel.h"
#include "paddle/phi/backends/cpu/cpu_context.h" #include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/common/complex.h"
#include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/set_value_grad_kernel_impl.h" #include "paddle/phi/kernels/impl/set_value_grad_kernel_impl.h"
...@@ -27,4 +28,6 @@ PD_REGISTER_KERNEL(set_value_grad, ...@@ -27,4 +28,6 @@ PD_REGISTER_KERNEL(set_value_grad,
int, int,
int64_t, int64_t,
bool, bool,
phi::dtype::float16) {} phi::dtype::float16,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
#include "paddle/phi/kernels/set_value_kernel.h" #include "paddle/phi/kernels/set_value_kernel.h"
#include "paddle/phi/backends/cpu/cpu_context.h" #include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/common/complex.h"
#include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/set_value_kernel_impl.h" #include "paddle/phi/kernels/impl/set_value_kernel_impl.h"
...@@ -27,7 +28,9 @@ PD_REGISTER_KERNEL(set_value, ...@@ -27,7 +28,9 @@ PD_REGISTER_KERNEL(set_value,
int, int,
int64_t, int64_t,
bool, bool,
phi::dtype::float16) {} phi::dtype::float16,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
PD_REGISTER_KERNEL(set_value_with_tensor, PD_REGISTER_KERNEL(set_value_with_tensor,
CPU, CPU,
ALL_LAYOUT, ALL_LAYOUT,
...@@ -37,4 +40,6 @@ PD_REGISTER_KERNEL(set_value_with_tensor, ...@@ -37,4 +40,6 @@ PD_REGISTER_KERNEL(set_value_with_tensor,
int, int,
int64_t, int64_t,
bool, bool,
phi::dtype::float16) {} phi::dtype::float16,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
#include "paddle/phi/kernels/set_value_grad_kernel.h" #include "paddle/phi/kernels/set_value_grad_kernel.h"
#include "paddle/phi/backends/gpu/gpu_context.h" #include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/common/complex.h"
#include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/set_value_grad_kernel_impl.h" #include "paddle/phi/kernels/impl/set_value_grad_kernel_impl.h"
...@@ -27,4 +28,6 @@ PD_REGISTER_KERNEL(set_value_grad, ...@@ -27,4 +28,6 @@ PD_REGISTER_KERNEL(set_value_grad,
int, int,
int64_t, int64_t,
bool, bool,
phi::dtype::float16) {} phi::dtype::float16,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
#include "paddle/phi/kernels/set_value_kernel.h" #include "paddle/phi/kernels/set_value_kernel.h"
#include "paddle/phi/backends/gpu/gpu_context.h" #include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/common/complex.h"
#include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/set_value_kernel_impl.h" #include "paddle/phi/kernels/impl/set_value_kernel_impl.h"
...@@ -27,7 +28,9 @@ PD_REGISTER_KERNEL(set_value, ...@@ -27,7 +28,9 @@ PD_REGISTER_KERNEL(set_value,
int, int,
int64_t, int64_t,
bool, bool,
paddle::platform::float16) {} paddle::platform::float16,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
PD_REGISTER_KERNEL(set_value_with_tensor, PD_REGISTER_KERNEL(set_value_with_tensor,
GPU, GPU,
ALL_LAYOUT, ALL_LAYOUT,
...@@ -37,4 +40,6 @@ PD_REGISTER_KERNEL(set_value_with_tensor, ...@@ -37,4 +40,6 @@ PD_REGISTER_KERNEL(set_value_with_tensor,
int, int,
int64_t, int64_t,
bool, bool,
paddle::platform::float16) {} paddle::platform::float16,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册