未验证 提交 b60f48ce 编写于 作者: R ronnywang 提交者: GitHub

[PHI CAPI] support complex dtype kernel (#52414)

* [PHI CAPI] support complex dtype kernel

* update
上级 5df1296d
......@@ -92,6 +92,11 @@ inline T GetValue(const phi::DenseTensor* x) {
if (!platform::is_cpu_place(x->place())) {
phi::DenseTensor cpu_x;
framework::TensorCopy(*x, platform::CPUPlace(), &cpu_x);
#if defined(PADDLE_WITH_CUSTOM_DEVICE)
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
const platform::DeviceContext* dev_ctx = pool.Get(x->place());
dev_ctx->Wait();
#endif
value = cpu_x.data<T>()[0];
} else {
value = x->data<T>()[0];
......
......@@ -40,6 +40,8 @@ typedef enum {
FLOAT32,
FLOAT64,
BFLOAT16,
COMPLEX64,
COMPLEX128,
} C_DataType;
typedef enum {
......
......@@ -35,7 +35,9 @@ namespace capi {
_(int8_t, PD_DataType::INT8) \
_(int16_t, PD_DataType::INT16) \
_(int32_t, PD_DataType::INT32) \
_(int64_t, PD_DataType::INT64)
_(int64_t, PD_DataType::INT64) \
_(phi::dtype::complex<float>, PD_DataType::COMPLEX64) \
_(phi::dtype::complex<double>, PD_DataType::COMPLEX128)
template <typename T>
struct CppTypeToPDType;
......
......@@ -42,6 +42,8 @@ inline PD_DataType ToPDDataType(::phi::DataType dtype) {
return_result(UINT16, UINT16);
return_result(UINT8, UINT8);
return_result(BOOL, BOOL);
return_result(COMPLEX64, COMPLEX64);
return_result(COMPLEX128, COMPLEX128);
default: {
PADDLE_THROW(
::phi::errors::Unavailable("DataType %d is not supported.", dtype));
......@@ -69,6 +71,8 @@ inline ::phi::DataType ToPhiDataType(PD_DataType dtype) {
return_result(UINT16, UINT16);
return_result(UINT8, UINT8);
return_result(BOOL, BOOL);
return_result(COMPLEX64, COMPLEX64);
return_result(COMPLEX128, COMPLEX128);
default: {
PADDLE_THROW(
::phi::errors::Unavailable("DataType %d is not supported.", dtype));
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册