未验证 提交 b60f48ce 编写于 作者: R ronnywang 提交者: GitHub

[PHI CAPI] support complex dtype kernel (#52414)

* [PHI CAPI] support complex dtype kernel

* update
上级 5df1296d
...@@ -92,6 +92,11 @@ inline T GetValue(const phi::DenseTensor* x) { ...@@ -92,6 +92,11 @@ inline T GetValue(const phi::DenseTensor* x) {
if (!platform::is_cpu_place(x->place())) { if (!platform::is_cpu_place(x->place())) {
phi::DenseTensor cpu_x; phi::DenseTensor cpu_x;
framework::TensorCopy(*x, platform::CPUPlace(), &cpu_x); framework::TensorCopy(*x, platform::CPUPlace(), &cpu_x);
#if defined(PADDLE_WITH_CUSTOM_DEVICE)
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
const platform::DeviceContext* dev_ctx = pool.Get(x->place());
dev_ctx->Wait();
#endif
value = cpu_x.data<T>()[0]; value = cpu_x.data<T>()[0];
} else { } else {
value = x->data<T>()[0]; value = x->data<T>()[0];
......
...@@ -40,6 +40,8 @@ typedef enum { ...@@ -40,6 +40,8 @@ typedef enum {
FLOAT32, FLOAT32,
FLOAT64, FLOAT64,
BFLOAT16, BFLOAT16,
COMPLEX64,
COMPLEX128,
} C_DataType; } C_DataType;
typedef enum { typedef enum {
......
...@@ -22,20 +22,22 @@ namespace phi { ...@@ -22,20 +22,22 @@ namespace phi {
namespace capi { namespace capi {
#define CPP_TYPE_TO_PD_DTYPE_REGISTER(_) \ #define CPP_TYPE_TO_PD_DTYPE_REGISTER(_) \
_(bool, PD_DataType::BOOL) \ _(bool, PD_DataType::BOOL) \
_(phi::dtype::bfloat16, PD_DataType::BFLOAT16) \ _(phi::dtype::bfloat16, PD_DataType::BFLOAT16) \
_(phi::dtype::float16, PD_DataType::FLOAT16) \ _(phi::dtype::float16, PD_DataType::FLOAT16) \
_(float, PD_DataType::FLOAT32) \ _(float, PD_DataType::FLOAT32) \
_(double, PD_DataType::FLOAT64) \ _(double, PD_DataType::FLOAT64) \
_(uint8_t, PD_DataType::UINT8) \ _(uint8_t, PD_DataType::UINT8) \
_(uint16_t, PD_DataType::UINT16) \ _(uint16_t, PD_DataType::UINT16) \
_(uint32_t, PD_DataType::UINT32) \ _(uint32_t, PD_DataType::UINT32) \
_(uint64_t, PD_DataType::UINT64) \ _(uint64_t, PD_DataType::UINT64) \
_(int8_t, PD_DataType::INT8) \ _(int8_t, PD_DataType::INT8) \
_(int16_t, PD_DataType::INT16) \ _(int16_t, PD_DataType::INT16) \
_(int32_t, PD_DataType::INT32) \ _(int32_t, PD_DataType::INT32) \
_(int64_t, PD_DataType::INT64) _(int64_t, PD_DataType::INT64) \
_(phi::dtype::complex<float>, PD_DataType::COMPLEX64) \
_(phi::dtype::complex<double>, PD_DataType::COMPLEX128)
template <typename T> template <typename T>
struct CppTypeToPDType; struct CppTypeToPDType;
......
...@@ -42,6 +42,8 @@ inline PD_DataType ToPDDataType(::phi::DataType dtype) { ...@@ -42,6 +42,8 @@ inline PD_DataType ToPDDataType(::phi::DataType dtype) {
return_result(UINT16, UINT16); return_result(UINT16, UINT16);
return_result(UINT8, UINT8); return_result(UINT8, UINT8);
return_result(BOOL, BOOL); return_result(BOOL, BOOL);
return_result(COMPLEX64, COMPLEX64);
return_result(COMPLEX128, COMPLEX128);
default: { default: {
PADDLE_THROW( PADDLE_THROW(
::phi::errors::Unavailable("DataType %d is not supported.", dtype)); ::phi::errors::Unavailable("DataType %d is not supported.", dtype));
...@@ -69,6 +71,8 @@ inline ::phi::DataType ToPhiDataType(PD_DataType dtype) { ...@@ -69,6 +71,8 @@ inline ::phi::DataType ToPhiDataType(PD_DataType dtype) {
return_result(UINT16, UINT16); return_result(UINT16, UINT16);
return_result(UINT8, UINT8); return_result(UINT8, UINT8);
return_result(BOOL, BOOL); return_result(BOOL, BOOL);
return_result(COMPLEX64, COMPLEX64);
return_result(COMPLEX128, COMPLEX128);
default: { default: {
PADDLE_THROW( PADDLE_THROW(
::phi::errors::Unavailable("DataType %d is not supported.", dtype)); ::phi::errors::Unavailable("DataType %d is not supported.", dtype));
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册