未验证 提交 34c7e3e3 编写于 作者: F Feiyu Chan 提交者: GitHub

set_value_op: add support for complex types (#46884)

上级 8a9d4003
......@@ -15,6 +15,7 @@
#include "paddle/phi/kernels/set_value_grad_kernel.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/common/complex.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/set_value_grad_kernel_impl.h"
......@@ -27,4 +28,6 @@ PD_REGISTER_KERNEL(set_value_grad,
int,
int64_t,
bool,
phi::dtype::float16) {}
phi::dtype::float16,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
......@@ -15,6 +15,7 @@
#include "paddle/phi/kernels/set_value_kernel.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/common/complex.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/set_value_kernel_impl.h"
......@@ -27,7 +28,9 @@ PD_REGISTER_KERNEL(set_value,
int,
int64_t,
bool,
phi::dtype::float16) {}
phi::dtype::float16,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
PD_REGISTER_KERNEL(set_value_with_tensor,
CPU,
ALL_LAYOUT,
......@@ -37,4 +40,6 @@ PD_REGISTER_KERNEL(set_value_with_tensor,
int,
int64_t,
bool,
phi::dtype::float16) {}
phi::dtype::float16,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
......@@ -15,6 +15,7 @@
#include "paddle/phi/kernels/set_value_grad_kernel.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/common/complex.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/set_value_grad_kernel_impl.h"
......@@ -27,4 +28,6 @@ PD_REGISTER_KERNEL(set_value_grad,
int,
int64_t,
bool,
phi::dtype::float16) {}
phi::dtype::float16,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
......@@ -15,6 +15,7 @@
#include "paddle/phi/kernels/set_value_kernel.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/common/complex.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/set_value_kernel_impl.h"
......@@ -27,7 +28,9 @@ PD_REGISTER_KERNEL(set_value,
int,
int64_t,
bool,
paddle::platform::float16) {}
paddle::platform::float16,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
PD_REGISTER_KERNEL(set_value_with_tensor,
GPU,
ALL_LAYOUT,
......@@ -37,4 +40,6 @@ PD_REGISTER_KERNEL(set_value_with_tensor,
int,
int64_t,
bool,
paddle::platform::float16) {}
paddle::platform::float16,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册