未验证 提交 b76ab792 编写于 作者: iSerendipity's avatar iSerendipity 提交者: GitHub

[PHI] Replace paddle::experimental::DataType as phi::DataType (#51470)

* Replace paddle::experimental::DataType as phi::DataType

* restore custom_device.cc
上级 383a3f8c
......@@ -38,17 +38,17 @@ enum CCLDataType {
};
inline CCLDataType ToCCLDataType(paddle::experimental::DataType type) {
if (type == paddle::experimental::DataType::FLOAT64) {
if (type == phi::DataType::FLOAT64) {
return CCL_DATA_TYPE_FP64;
} else if (type == paddle::experimental::DataType::FLOAT32) {
} else if (type == phi::DataType::FLOAT32) {
return CCL_DATA_TYPE_FP32;
} else if (type == paddle::experimental::DataType::FLOAT16) {
} else if (type == phi::DataType::FLOAT16) {
return CCL_DATA_TYPE_FP16;
} else if (type == paddle::experimental::DataType::INT64) {
} else if (type == phi::DataType::INT64) {
return CCL_DATA_TYPE_INT64;
} else if (type == paddle::experimental::DataType::INT32) {
} else if (type == phi::DataType::INT32) {
return CCL_DATA_TYPE_INT32;
} else if (type == paddle::experimental::DataType::INT8) {
} else if (type == phi::DataType::INT8) {
return CCL_DATA_TYPE_INT8;
} else {
PADDLE_THROW(
......
......@@ -24,7 +24,7 @@ namespace phi {
/* From phi::DenseTensor */
/* --------------------------- */
DenseTensor::DenseTensor() {
meta_.dtype = paddle::experimental::DataType::FLOAT32;
meta_.dtype = phi::DataType::FLOAT32;
meta_.offset = 0;
}
......
......@@ -146,11 +146,11 @@ void KernelFactory::AddToLowPrecisionKernelList(
auto count = OpCount();
low_precision_kernels_[op_name] = count;
}
if (kernel_key_type == paddle::experimental::DataType::FLOAT16) {
if (kernel_key_type == phi::DataType::FLOAT16) {
low_precision_kernels_[op_name].fp16_called_ += 1;
} else if (kernel_key_type == paddle::experimental::DataType::BFLOAT16) {
} else if (kernel_key_type == phi::DataType::BFLOAT16) {
low_precision_kernels_[op_name].bf16_called_ += 1;
} else if (kernel_key_type == paddle::experimental::DataType::FLOAT32) {
} else if (kernel_key_type == phi::DataType::FLOAT32) {
low_precision_kernels_[op_name].fp32_called_ += 1;
} else {
low_precision_kernels_[op_name].other_called_ += 1;
......
......@@ -112,7 +112,7 @@ PD_REGISTER_KERNEL(check_finite_and_unscale,
phi::CheckFiniteAndUnscaleKernel,
float,
double) {
kernel->OutputAt(1).SetDataType(paddle::experimental::DataType::BOOL);
kernel->OutputAt(1).SetDataType(phi::DataType::BOOL);
}
PD_REGISTER_KERNEL(update_loss_scaling,
......
......@@ -24,5 +24,5 @@ using complex128 = ::phi::dtype::complex<double>;
PD_REGISTER_KERNEL(
as_real, CPU, ALL_LAYOUT, phi::AsRealKernel, complex64, complex128) {
kernel->OutputAt(0).SetDataType(paddle::experimental::DataType::UNDEFINED);
kernel->OutputAt(0).SetDataType(phi::DataType::UNDEFINED);
}
......@@ -47,5 +47,5 @@ PD_REGISTER_KERNEL(cast_grad,
phi::dtype::bfloat16,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {
kernel->OutputAt(0).SetDataType(paddle::experimental::DataType::UNDEFINED);
kernel->OutputAt(0).SetDataType(phi::DataType::UNDEFINED);
}
......@@ -48,5 +48,5 @@ PD_REGISTER_KERNEL(cast,
phi::dtype::bfloat16,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {
kernel->OutputAt(0).SetDataType(paddle::experimental::DataType::UNDEFINED);
kernel->OutputAt(0).SetDataType(phi::DataType::UNDEFINED);
}
......@@ -80,7 +80,7 @@ PD_REGISTER_KERNEL(equal_all,
int64_t,
float,
double) {
kernel->OutputAt(0).SetDataType(paddle::experimental::DataType::BOOL);
kernel->OutputAt(0).SetDataType(phi::DataType::BOOL);
}
#define PD_REGISTER_COMPARE_KERNEL(name, func) \
......@@ -95,7 +95,7 @@ PD_REGISTER_KERNEL(equal_all,
float, \
double, \
phi::dtype::float16) { \
kernel->OutputAt(0).SetDataType(paddle::experimental::DataType::BOOL); \
kernel->OutputAt(0).SetDataType(phi::DataType::BOOL); \
} \
PD_REGISTER_KERNEL(name##_raw, \
CPU, \
......@@ -108,7 +108,7 @@ PD_REGISTER_KERNEL(equal_all,
float, \
double, \
phi::dtype::float16) { \
kernel->OutputAt(0).SetDataType(paddle::experimental::DataType::BOOL); \
kernel->OutputAt(0).SetDataType(phi::DataType::BOOL); \
}
PD_REGISTER_COMPARE_KERNEL(less_than, LessThan)
PD_REGISTER_COMPARE_KERNEL(less_equal, LessEqual)
......
......@@ -210,10 +210,10 @@ PD_REGISTER_KERNEL(dropout,
float,
double,
phi::dtype::bfloat16) {
kernel->OutputAt(1).SetDataType(paddle::experimental::DataType::UINT8);
kernel->OutputAt(1).SetDataType(phi::DataType::UINT8);
}
PD_REGISTER_KERNEL(
dropout_nd, CPU, ALL_LAYOUT, phi::DropoutNdKernel, float, double) {
kernel->OutputAt(1).SetDataType(paddle::experimental::DataType::UINT8);
kernel->OutputAt(1).SetDataType(phi::DataType::UINT8);
}
......@@ -258,5 +258,5 @@ PD_REGISTER_KERNEL(eigvals,
double,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {
kernel->OutputAt(0).SetDataType(paddle::experimental::DataType::UNDEFINED);
kernel->OutputAt(0).SetDataType(phi::DataType::UNDEFINED);
}
......@@ -95,5 +95,5 @@ void NMSKernel(const Context& dev_ctx,
} // namespace phi
PD_REGISTER_KERNEL(nms, CPU, ALL_LAYOUT, phi::NMSKernel, float, double) {
kernel->OutputAt(0).SetDataType(paddle::experimental::DataType::INT64);
kernel->OutputAt(0).SetDataType(phi::DataType::INT64);
}
......@@ -93,5 +93,5 @@ PD_REGISTER_KERNEL(nonzero,
bool,
float,
double) {
kernel->OutputAt(0).SetDataType(paddle::experimental::DataType::INT64);
kernel->OutputAt(0).SetDataType(phi::DataType::INT64);
}
......@@ -84,5 +84,5 @@ void OneHotRawKernel(const Context& dev_ctx,
PD_REGISTER_KERNEL(
one_hot_raw, CPU, ALL_LAYOUT, phi::OneHotRawKernel, int, int64_t) {
kernel->OutputAt(0).SetDataType(paddle::experimental::DataType::UNDEFINED);
kernel->OutputAt(0).SetDataType(phi::DataType::UNDEFINED);
}
......@@ -54,5 +54,5 @@ PD_REGISTER_KERNEL(sum_raw,
int64_t,
complex64,
complex128) {
kernel->OutputAt(0).SetDataType(paddle::experimental::DataType::UNDEFINED);
kernel->OutputAt(0).SetDataType(phi::DataType::UNDEFINED);
}
......@@ -242,5 +242,5 @@ void TopkKernel(const Context& dev_ctx,
PD_REGISTER_KERNEL(
topk, CPU, ALL_LAYOUT, phi::TopkKernel, float, double, int32_t, int64_t) {
kernel->OutputAt(1).SetDataType(paddle::experimental::DataType::INT64);
kernel->OutputAt(1).SetDataType(phi::DataType::INT64);
}
......@@ -319,5 +319,5 @@ void ViterbiDecodeKernel(const Context& dev_ctx,
PD_REGISTER_KERNEL(
viterbi_decode, CPU, ALL_LAYOUT, phi::ViterbiDecodeKernel, float, double) {
kernel->OutputAt(1).SetDataType(paddle::experimental::DataType::INT64);
kernel->OutputAt(1).SetDataType(phi::DataType::INT64);
}
......@@ -357,7 +357,7 @@ PD_REGISTER_KERNEL(check_finite_and_unscale,
double,
phi::dtype::float16,
phi::dtype::bfloat16) {
kernel->OutputAt(1).SetDataType(paddle::experimental::DataType::BOOL);
kernel->OutputAt(1).SetDataType(phi::DataType::BOOL);
}
PD_REGISTER_KERNEL(update_loss_scaling,
......
......@@ -24,5 +24,5 @@ using complex128 = ::phi::dtype::complex<double>;
PD_REGISTER_KERNEL(
as_real, GPU, ALL_LAYOUT, phi::AsRealKernel, complex64, complex128) {
kernel->OutputAt(0).SetDataType(paddle::experimental::DataType::UNDEFINED);
kernel->OutputAt(0).SetDataType(phi::DataType::UNDEFINED);
}
......@@ -48,8 +48,7 @@ void CastGradKernel(const Context& dev_ctx,
phi::dtype::complex<float>, \
phi::dtype::complex<double>, \
##__VA_ARGS__) { \
kernel->OutputAt(0).SetDataType( \
paddle::experimental::DataType::UNDEFINED); \
kernel->OutputAt(0).SetDataType(phi::DataType::UNDEFINED); \
}
PTEN_REGISTER_CAST_CUDA_BASE_TYPE(cast_grad, phi::dtype::bfloat16)
......@@ -49,8 +49,7 @@ void CastKernel(const Context& dev_ctx,
phi::dtype::complex<float>, \
phi::dtype::complex<double>, \
##__VA_ARGS__) { \
kernel->OutputAt(0).SetDataType( \
paddle::experimental::DataType::UNDEFINED); \
kernel->OutputAt(0).SetDataType(phi::DataType::UNDEFINED); \
}
PTEN_REGISTER_CAST_CUDA_BASE_TYPE(cast, phi::dtype::bfloat16)
......@@ -90,7 +90,7 @@ PD_REGISTER_KERNEL(dropout,
phi::dtype::bfloat16,
phi::dtype::float16) {
kernel->InputAt(1).SetBackend(phi::Backend::ALL_BACKEND);
kernel->OutputAt(1).SetDataType(paddle::experimental::DataType::UINT8);
kernel->OutputAt(1).SetDataType(phi::DataType::UINT8);
}
PD_REGISTER_KERNEL(dropout_nd,
......@@ -102,5 +102,5 @@ PD_REGISTER_KERNEL(dropout_nd,
phi::dtype::bfloat16,
phi::dtype::float16) {
kernel->InputAt(1).SetBackend(phi::Backend::ALL_BACKEND);
kernel->OutputAt(1).SetDataType(paddle::experimental::DataType::UINT8);
kernel->OutputAt(1).SetDataType(phi::DataType::UINT8);
}
......@@ -114,5 +114,5 @@ void NMSKernel(const Context& dev_ctx,
}
} // namespace phi
PD_REGISTER_KERNEL(nms, GPU, ALL_LAYOUT, phi::NMSKernel, float, double) {
kernel->OutputAt(0).SetDataType(paddle::experimental::DataType::INT64);
kernel->OutputAt(0).SetDataType(phi::DataType::INT64);
}
......@@ -84,5 +84,5 @@ PD_REGISTER_KERNEL(nonzero,
bool,
float,
double) {
kernel->OutputAt(0).SetDataType(paddle::experimental::DataType::INT64);
kernel->OutputAt(0).SetDataType(phi::DataType::INT64);
}
......@@ -92,5 +92,5 @@ void OneHotRawKernel(const Context& dev_ctx,
PD_REGISTER_KERNEL(
one_hot_raw, GPU, ALL_LAYOUT, phi::OneHotRawKernel, int, int64_t) {
kernel->OutputAt(0).SetDataType(paddle::experimental::DataType::UNDEFINED);
kernel->OutputAt(0).SetDataType(phi::DataType::UNDEFINED);
}
......@@ -349,5 +349,5 @@ PD_REGISTER_KERNEL(topk,
int,
int64_t,
phi::dtype::float16) {
kernel->OutputAt(1).SetDataType(paddle::experimental::DataType::INT64);
kernel->OutputAt(1).SetDataType(phi::DataType::INT64);
}
......@@ -398,5 +398,5 @@ void ViterbiDecodeKernel(const Context& dev_ctx,
PD_REGISTER_KERNEL(
viterbi_decode, GPU, ALL_LAYOUT, phi::ViterbiDecodeKernel, float, double) {
kernel->OutputAt(1).SetDataType(paddle::experimental::DataType::INT64);
kernel->OutputAt(1).SetDataType(phi::DataType::INT64);
}
......@@ -96,47 +96,47 @@ inline void CompareAllKernelImpl(const Context& ctx,
#ifdef PADDLE_WITH_XPU_KP
PD_REGISTER_KERNEL(less_than, KPS, ALL_LAYOUT, phi::LessThanKernel, int) {
kernel->OutputAt(0).SetDataType(paddle::experimental::DataType::BOOL);
kernel->OutputAt(0).SetDataType(phi::DataType::BOOL);
}
PD_REGISTER_KERNEL(less_equal, KPS, ALL_LAYOUT, phi::LessEqualKernel, int) {
kernel->OutputAt(0).SetDataType(paddle::experimental::DataType::BOOL);
kernel->OutputAt(0).SetDataType(phi::DataType::BOOL);
}
PD_REGISTER_KERNEL(greater_than, KPS, ALL_LAYOUT, phi::GreaterThanKernel, int) {
kernel->OutputAt(0).SetDataType(paddle::experimental::DataType::BOOL);
kernel->OutputAt(0).SetDataType(phi::DataType::BOOL);
}
PD_REGISTER_KERNEL(
greater_equal, KPS, ALL_LAYOUT, phi::GreaterEqualKernel, int) {
kernel->OutputAt(0).SetDataType(paddle::experimental::DataType::BOOL);
kernel->OutputAt(0).SetDataType(phi::DataType::BOOL);
}
PD_REGISTER_KERNEL(equal, KPS, ALL_LAYOUT, phi::EqualKernel, int) {
kernel->OutputAt(0).SetDataType(paddle::experimental::DataType::BOOL);
kernel->OutputAt(0).SetDataType(phi::DataType::BOOL);
}
PD_REGISTER_KERNEL(not_equal, KPS, ALL_LAYOUT, phi::NotEqualKernel, int) {
kernel->OutputAt(0).SetDataType(paddle::experimental::DataType::BOOL);
kernel->OutputAt(0).SetDataType(phi::DataType::BOOL);
}
PD_REGISTER_KERNEL(
less_than_raw, KPS, ALL_LAYOUT, phi::LessThanRawKernel, int) {
kernel->OutputAt(0).SetDataType(paddle::experimental::DataType::BOOL);
kernel->OutputAt(0).SetDataType(phi::DataType::BOOL);
}
PD_REGISTER_KERNEL(
less_equal_raw, KPS, ALL_LAYOUT, phi::LessEqualRawKernel, int) {
kernel->OutputAt(0).SetDataType(paddle::experimental::DataType::BOOL);
kernel->OutputAt(0).SetDataType(phi::DataType::BOOL);
}
PD_REGISTER_KERNEL(
greater_than_raw, KPS, ALL_LAYOUT, phi::GreaterThanRawKernel, int) {
kernel->OutputAt(0).SetDataType(paddle::experimental::DataType::BOOL);
kernel->OutputAt(0).SetDataType(phi::DataType::BOOL);
}
PD_REGISTER_KERNEL(
greater_equal_raw, KPS, ALL_LAYOUT, phi::GreaterEqualRawKernel, int) {
kernel->OutputAt(0).SetDataType(paddle::experimental::DataType::BOOL);
kernel->OutputAt(0).SetDataType(phi::DataType::BOOL);
}
PD_REGISTER_KERNEL(equal_raw, KPS, ALL_LAYOUT, phi::EqualRawKernel, int) {
kernel->OutputAt(0).SetDataType(paddle::experimental::DataType::BOOL);
kernel->OutputAt(0).SetDataType(phi::DataType::BOOL);
}
PD_REGISTER_KERNEL(
not_equal_raw, KPS, ALL_LAYOUT, phi::NotEqualRawKernel, int) {
kernel->OutputAt(0).SetDataType(paddle::experimental::DataType::BOOL);
kernel->OutputAt(0).SetDataType(phi::DataType::BOOL);
}
#else
......@@ -150,7 +150,7 @@ PD_REGISTER_KERNEL(equal_all,
int64_t,
float,
double) {
kernel->OutputAt(0).SetDataType(paddle::experimental::DataType::BOOL);
kernel->OutputAt(0).SetDataType(phi::DataType::BOOL);
}
#define PD_REGISTER_COMPARE_KERNEL(name, func) \
......@@ -166,7 +166,7 @@ PD_REGISTER_KERNEL(equal_all,
double, \
phi::dtype::float16, \
phi::dtype::bfloat16) { \
kernel->OutputAt(0).SetDataType(paddle::experimental::DataType::BOOL); \
kernel->OutputAt(0).SetDataType(phi::DataType::BOOL); \
} \
PD_REGISTER_KERNEL(name##_raw, \
KPS, \
......@@ -180,7 +180,7 @@ PD_REGISTER_KERNEL(equal_all,
double, \
phi::dtype::float16, \
phi::dtype::bfloat16) { \
kernel->OutputAt(0).SetDataType(paddle::experimental::DataType::BOOL); \
kernel->OutputAt(0).SetDataType(phi::DataType::BOOL); \
}
PD_REGISTER_COMPARE_KERNEL(less_than, LessThan)
......
......@@ -147,7 +147,7 @@ void SumRawKernel(const Context& dev_ctx,
#ifdef PADDLE_WITH_XPU_KP
PD_REGISTER_KERNEL(sum_raw, KPS, ALL_LAYOUT, phi::SumRawKernel, float) {
kernel->OutputAt(0).SetDataType(paddle::experimental::DataType::UNDEFINED);
kernel->OutputAt(0).SetDataType(phi::DataType::UNDEFINED);
}
#else
using float16 = phi::dtype::float16;
......@@ -169,6 +169,6 @@ PD_REGISTER_KERNEL(sum_raw,
int64_t,
complex64,
complex128) {
kernel->OutputAt(0).SetDataType(paddle::experimental::DataType::UNDEFINED);
kernel->OutputAt(0).SetDataType(phi::DataType::UNDEFINED);
}
#endif
......@@ -31,17 +31,17 @@ void OneHotKernel(const Context& dev_ctx,
} // namespace phi
PD_REGISTER_KERNEL(one_hot, CPU, ALL_LAYOUT, phi::OneHotKernel, int, int64_t) {
kernel->OutputAt(0).SetDataType(paddle::experimental::DataType::FLOAT32);
kernel->OutputAt(0).SetDataType(phi::DataType::FLOAT32);
}
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
PD_REGISTER_KERNEL(one_hot, GPU, ALL_LAYOUT, phi::OneHotKernel, int, int64_t) {
kernel->OutputAt(0).SetDataType(paddle::experimental::DataType::FLOAT32);
kernel->OutputAt(0).SetDataType(phi::DataType::FLOAT32);
}
#endif
#ifdef PADDLE_WITH_XPU
PD_REGISTER_KERNEL(one_hot, XPU, ALL_LAYOUT, phi::OneHotKernel, int, int64_t) {
kernel->OutputAt(0).SetDataType(paddle::experimental::DataType::FLOAT32);
kernel->OutputAt(0).SetDataType(phi::DataType::FLOAT32);
}
#endif
......@@ -49,7 +49,7 @@ PD_REGISTER_KERNEL(sum,
int64_t,
complex64,
complex128) {
kernel->OutputAt(0).SetDataType(paddle::experimental::DataType::UNDEFINED);
kernel->OutputAt(0).SetDataType(phi::DataType::UNDEFINED);
}
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
......@@ -67,13 +67,13 @@ PD_REGISTER_KERNEL(sum,
int64_t,
complex64,
complex128) {
kernel->OutputAt(0).SetDataType(paddle::experimental::DataType::UNDEFINED);
kernel->OutputAt(0).SetDataType(phi::DataType::UNDEFINED);
}
#endif
#if defined(PADDLE_WITH_XPU_KP) && !defined(PADDLE_WITH_XPU)
PD_REGISTER_KERNEL(sum, KPS, ALL_LAYOUT, phi::SumKernel, float) {
kernel->OutputAt(0).SetDataType(paddle::experimental::DataType::UNDEFINED);
kernel->OutputAt(0).SetDataType(phi::DataType::UNDEFINED);
}
#endif
......@@ -92,6 +92,6 @@ PD_REGISTER_KERNEL(sum,
int8_t,
int,
int64_t) {
kernel->OutputAt(0).SetDataType(paddle::experimental::DataType::UNDEFINED);
kernel->OutputAt(0).SetDataType(phi::DataType::UNDEFINED);
}
#endif
......@@ -48,7 +48,7 @@ PD_REGISTER_KERNEL(shape,
phi::dtype::complex<double>) {
kernel->InputAt(0).SetBackend(phi::Backend::ALL_BACKEND);
kernel->OutputAt(0).SetBackend(phi::Backend::CPU);
kernel->OutputAt(0).SetDataType(paddle::experimental::DataType::INT32);
kernel->OutputAt(0).SetDataType(phi::DataType::INT32);
}
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
......@@ -68,7 +68,7 @@ PD_REGISTER_KERNEL(shape,
phi::dtype::float16) {
kernel->InputAt(0).SetBackend(phi::Backend::ALL_BACKEND);
kernel->OutputAt(0).SetBackend(phi::Backend::CPU);
kernel->OutputAt(0).SetDataType(paddle::experimental::DataType::INT32);
kernel->OutputAt(0).SetDataType(phi::DataType::INT32);
}
#endif
......@@ -85,6 +85,6 @@ PD_REGISTER_KERNEL(shape,
phi::dtype::float16) {
kernel->InputAt(0).SetBackend(phi::Backend::ALL_BACKEND);
kernel->OutputAt(0).SetBackend(phi::Backend::CPU);
kernel->OutputAt(0).SetDataType(paddle::experimental::DataType::INT32);
kernel->OutputAt(0).SetDataType(phi::DataType::INT32);
}
#endif
......@@ -286,5 +286,5 @@ PD_REGISTER_KERNEL(check_finite_and_unscale,
phi::CheckFiniteAndUnscaleKernel,
float,
phi::dtype::float16) {
kernel->OutputAt(1).SetDataType(paddle::experimental::DataType::BOOL);
kernel->OutputAt(1).SetDataType(phi::DataType::BOOL);
}
......@@ -106,5 +106,5 @@ PD_REGISTER_KERNEL(cast,
bool,
uint8_t,
double) {
kernel->OutputAt(0).SetDataType(paddle::experimental::DataType::UNDEFINED);
kernel->OutputAt(0).SetDataType(phi::DataType::UNDEFINED);
}
......@@ -91,7 +91,7 @@ DEFINE_XPU_COMPARE_KERNEL(GreaterEqual, xpu::broadcast_greater_equal<XPUType>)
PD_REGISTER_KERNEL(
less_than, XPU, ALL_LAYOUT, phi::LessThanKernel, int, int64_t, float) {
kernel->OutputAt(0).SetDataType(paddle::experimental::DataType::BOOL);
kernel->OutputAt(0).SetDataType(phi::DataType::BOOL);
}
PD_REGISTER_KERNEL(less_than_raw,
......@@ -101,13 +101,13 @@ PD_REGISTER_KERNEL(less_than_raw,
int,
int64_t,
float) {
kernel->OutputAt(0).SetDataType(paddle::experimental::DataType::BOOL);
kernel->OutputAt(0).SetDataType(phi::DataType::BOOL);
}
#define PD_REGISTER_COMPARE_KERNEL(name, func) \
PD_REGISTER_KERNEL( \
name, XPU, ALL_LAYOUT, phi::func##Kernel, int, int64_t, float) { \
kernel->OutputAt(0).SetDataType(paddle::experimental::DataType::BOOL); \
kernel->OutputAt(0).SetDataType(phi::DataType::BOOL); \
} \
PD_REGISTER_KERNEL(name##_raw, \
XPU, \
......@@ -116,7 +116,7 @@ PD_REGISTER_KERNEL(less_than_raw,
int, \
int64_t, \
float) { \
kernel->OutputAt(0).SetDataType(paddle::experimental::DataType::BOOL); \
kernel->OutputAt(0).SetDataType(phi::DataType::BOOL); \
}
PD_REGISTER_COMPARE_KERNEL(less_equal, LessEqual)
......
......@@ -70,5 +70,5 @@ void NonZeroKernel(const Context& dev_ctx,
PD_REGISTER_KERNEL(
nonzero, XPU, ALL_LAYOUT, phi::NonZeroKernel, int, bool, float) {
kernel->OutputAt(0).SetDataType(paddle::experimental::DataType::INT64);
kernel->OutputAt(0).SetDataType(phi::DataType::INT64);
}
......@@ -63,5 +63,5 @@ void OneHotRawKernel(const Context& dev_ctx,
PD_REGISTER_KERNEL(
one_hot_raw, XPU, ALL_LAYOUT, phi::OneHotRawKernel, int, int64_t) {
kernel->OutputAt(0).SetDataType(paddle::experimental::DataType::UNDEFINED);
kernel->OutputAt(0).SetDataType(phi::DataType::UNDEFINED);
}
......@@ -188,5 +188,5 @@ void TopkKernel(const Context& dev_ctx,
PD_REGISTER_KERNEL(
topk, XPU, ALL_LAYOUT, phi::TopkKernel, float, phi::dtype::float16) {
kernel->OutputAt(1).SetDataType(paddle::experimental::DataType::INT64);
kernel->OutputAt(1).SetDataType(phi::DataType::INT64);
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册