未验证 提交 b60f48ce 编写于 作者: R ronnywang 提交者: GitHub

[PHI CAPI] support complex dtype kernel (#52414)

* [PHI CAPI] support complex dtype kernel

* update
上级 5df1296d
......@@ -92,6 +92,11 @@ inline T GetValue(const phi::DenseTensor* x) {
if (!platform::is_cpu_place(x->place())) {
phi::DenseTensor cpu_x;
framework::TensorCopy(*x, platform::CPUPlace(), &cpu_x);
#if defined(PADDLE_WITH_CUSTOM_DEVICE)
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
const platform::DeviceContext* dev_ctx = pool.Get(x->place());
dev_ctx->Wait();
#endif
value = cpu_x.data<T>()[0];
} else {
value = x->data<T>()[0];
......
......@@ -40,6 +40,8 @@ typedef enum {
FLOAT32,
FLOAT64,
BFLOAT16,
COMPLEX64,
COMPLEX128,
} C_DataType;
typedef enum {
......
......@@ -22,20 +22,22 @@ namespace phi {
namespace capi {
#define CPP_TYPE_TO_PD_DTYPE_REGISTER(_) \
_(bool, PD_DataType::BOOL) \
_(phi::dtype::bfloat16, PD_DataType::BFLOAT16) \
_(phi::dtype::float16, PD_DataType::FLOAT16) \
_(float, PD_DataType::FLOAT32) \
_(double, PD_DataType::FLOAT64) \
_(uint8_t, PD_DataType::UINT8) \
_(uint16_t, PD_DataType::UINT16) \
_(uint32_t, PD_DataType::UINT32) \
_(uint64_t, PD_DataType::UINT64) \
_(int8_t, PD_DataType::INT8) \
_(int16_t, PD_DataType::INT16) \
_(int32_t, PD_DataType::INT32) \
_(int64_t, PD_DataType::INT64)
#define CPP_TYPE_TO_PD_DTYPE_REGISTER(_) \
_(bool, PD_DataType::BOOL) \
_(phi::dtype::bfloat16, PD_DataType::BFLOAT16) \
_(phi::dtype::float16, PD_DataType::FLOAT16) \
_(float, PD_DataType::FLOAT32) \
_(double, PD_DataType::FLOAT64) \
_(uint8_t, PD_DataType::UINT8) \
_(uint16_t, PD_DataType::UINT16) \
_(uint32_t, PD_DataType::UINT32) \
_(uint64_t, PD_DataType::UINT64) \
_(int8_t, PD_DataType::INT8) \
_(int16_t, PD_DataType::INT16) \
_(int32_t, PD_DataType::INT32) \
_(int64_t, PD_DataType::INT64) \
_(phi::dtype::complex<float>, PD_DataType::COMPLEX64) \
_(phi::dtype::complex<double>, PD_DataType::COMPLEX128)
template <typename T>
struct CppTypeToPDType;
......
......@@ -42,6 +42,8 @@ inline PD_DataType ToPDDataType(::phi::DataType dtype) {
return_result(UINT16, UINT16);
return_result(UINT8, UINT8);
return_result(BOOL, BOOL);
return_result(COMPLEX64, COMPLEX64);
return_result(COMPLEX128, COMPLEX128);
default: {
PADDLE_THROW(
::phi::errors::Unavailable("DataType %d is not supported.", dtype));
......@@ -69,6 +71,8 @@ inline ::phi::DataType ToPhiDataType(PD_DataType dtype) {
return_result(UINT16, UINT16);
return_result(UINT8, UINT8);
return_result(BOOL, BOOL);
return_result(COMPLEX64, COMPLEX64);
return_result(COMPLEX128, COMPLEX128);
default: {
PADDLE_THROW(
::phi::errors::Unavailable("DataType %d is not supported.", dtype));
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册