未验证 提交 b76ab792 编写于 作者: iSerendipity's avatar iSerendipity 提交者: GitHub

[PHI] Replace paddle::experimental::DataType as phi::DataType (#51470)

* Replace paddle::experimental::DataType as phi::DataType

* restore custom_device.cc
上级 383a3f8c
...@@ -38,17 +38,17 @@ enum CCLDataType { ...@@ -38,17 +38,17 @@ enum CCLDataType {
}; };
inline CCLDataType ToCCLDataType(paddle::experimental::DataType type) { inline CCLDataType ToCCLDataType(paddle::experimental::DataType type) {
if (type == paddle::experimental::DataType::FLOAT64) { if (type == phi::DataType::FLOAT64) {
return CCL_DATA_TYPE_FP64; return CCL_DATA_TYPE_FP64;
} else if (type == paddle::experimental::DataType::FLOAT32) { } else if (type == phi::DataType::FLOAT32) {
return CCL_DATA_TYPE_FP32; return CCL_DATA_TYPE_FP32;
} else if (type == paddle::experimental::DataType::FLOAT16) { } else if (type == phi::DataType::FLOAT16) {
return CCL_DATA_TYPE_FP16; return CCL_DATA_TYPE_FP16;
} else if (type == paddle::experimental::DataType::INT64) { } else if (type == phi::DataType::INT64) {
return CCL_DATA_TYPE_INT64; return CCL_DATA_TYPE_INT64;
} else if (type == paddle::experimental::DataType::INT32) { } else if (type == phi::DataType::INT32) {
return CCL_DATA_TYPE_INT32; return CCL_DATA_TYPE_INT32;
} else if (type == paddle::experimental::DataType::INT8) { } else if (type == phi::DataType::INT8) {
return CCL_DATA_TYPE_INT8; return CCL_DATA_TYPE_INT8;
} else { } else {
PADDLE_THROW( PADDLE_THROW(
......
...@@ -24,7 +24,7 @@ namespace phi { ...@@ -24,7 +24,7 @@ namespace phi {
/* From phi::DenseTensor */ /* From phi::DenseTensor */
/* --------------------------- */ /* --------------------------- */
DenseTensor::DenseTensor() { DenseTensor::DenseTensor() {
meta_.dtype = paddle::experimental::DataType::FLOAT32; meta_.dtype = phi::DataType::FLOAT32;
meta_.offset = 0; meta_.offset = 0;
} }
......
...@@ -146,11 +146,11 @@ void KernelFactory::AddToLowPrecisionKernelList( ...@@ -146,11 +146,11 @@ void KernelFactory::AddToLowPrecisionKernelList(
auto count = OpCount(); auto count = OpCount();
low_precision_kernels_[op_name] = count; low_precision_kernels_[op_name] = count;
} }
if (kernel_key_type == paddle::experimental::DataType::FLOAT16) { if (kernel_key_type == phi::DataType::FLOAT16) {
low_precision_kernels_[op_name].fp16_called_ += 1; low_precision_kernels_[op_name].fp16_called_ += 1;
} else if (kernel_key_type == paddle::experimental::DataType::BFLOAT16) { } else if (kernel_key_type == phi::DataType::BFLOAT16) {
low_precision_kernels_[op_name].bf16_called_ += 1; low_precision_kernels_[op_name].bf16_called_ += 1;
} else if (kernel_key_type == paddle::experimental::DataType::FLOAT32) { } else if (kernel_key_type == phi::DataType::FLOAT32) {
low_precision_kernels_[op_name].fp32_called_ += 1; low_precision_kernels_[op_name].fp32_called_ += 1;
} else { } else {
low_precision_kernels_[op_name].other_called_ += 1; low_precision_kernels_[op_name].other_called_ += 1;
......
...@@ -112,7 +112,7 @@ PD_REGISTER_KERNEL(check_finite_and_unscale, ...@@ -112,7 +112,7 @@ PD_REGISTER_KERNEL(check_finite_and_unscale,
phi::CheckFiniteAndUnscaleKernel, phi::CheckFiniteAndUnscaleKernel,
float, float,
double) { double) {
kernel->OutputAt(1).SetDataType(paddle::experimental::DataType::BOOL); kernel->OutputAt(1).SetDataType(phi::DataType::BOOL);
} }
PD_REGISTER_KERNEL(update_loss_scaling, PD_REGISTER_KERNEL(update_loss_scaling,
......
...@@ -24,5 +24,5 @@ using complex128 = ::phi::dtype::complex<double>; ...@@ -24,5 +24,5 @@ using complex128 = ::phi::dtype::complex<double>;
PD_REGISTER_KERNEL( PD_REGISTER_KERNEL(
as_real, CPU, ALL_LAYOUT, phi::AsRealKernel, complex64, complex128) { as_real, CPU, ALL_LAYOUT, phi::AsRealKernel, complex64, complex128) {
kernel->OutputAt(0).SetDataType(paddle::experimental::DataType::UNDEFINED); kernel->OutputAt(0).SetDataType(phi::DataType::UNDEFINED);
} }
...@@ -47,5 +47,5 @@ PD_REGISTER_KERNEL(cast_grad, ...@@ -47,5 +47,5 @@ PD_REGISTER_KERNEL(cast_grad,
phi::dtype::bfloat16, phi::dtype::bfloat16,
phi::dtype::complex<float>, phi::dtype::complex<float>,
phi::dtype::complex<double>) { phi::dtype::complex<double>) {
kernel->OutputAt(0).SetDataType(paddle::experimental::DataType::UNDEFINED); kernel->OutputAt(0).SetDataType(phi::DataType::UNDEFINED);
} }
...@@ -48,5 +48,5 @@ PD_REGISTER_KERNEL(cast, ...@@ -48,5 +48,5 @@ PD_REGISTER_KERNEL(cast,
phi::dtype::bfloat16, phi::dtype::bfloat16,
phi::dtype::complex<float>, phi::dtype::complex<float>,
phi::dtype::complex<double>) { phi::dtype::complex<double>) {
kernel->OutputAt(0).SetDataType(paddle::experimental::DataType::UNDEFINED); kernel->OutputAt(0).SetDataType(phi::DataType::UNDEFINED);
} }
...@@ -80,35 +80,35 @@ PD_REGISTER_KERNEL(equal_all, ...@@ -80,35 +80,35 @@ PD_REGISTER_KERNEL(equal_all,
int64_t, int64_t,
float, float,
double) { double) {
kernel->OutputAt(0).SetDataType(paddle::experimental::DataType::BOOL); kernel->OutputAt(0).SetDataType(phi::DataType::BOOL);
} }
#define PD_REGISTER_COMPARE_KERNEL(name, func) \ #define PD_REGISTER_COMPARE_KERNEL(name, func) \
PD_REGISTER_KERNEL(name, \ PD_REGISTER_KERNEL(name, \
CPU, \ CPU, \
ALL_LAYOUT, \ ALL_LAYOUT, \
phi::func##Kernel, \ phi::func##Kernel, \
bool, \ bool, \
int16_t, \ int16_t, \
int, \ int, \
int64_t, \ int64_t, \
float, \ float, \
double, \ double, \
phi::dtype::float16) { \ phi::dtype::float16) { \
kernel->OutputAt(0).SetDataType(paddle::experimental::DataType::BOOL); \ kernel->OutputAt(0).SetDataType(phi::DataType::BOOL); \
} \ } \
PD_REGISTER_KERNEL(name##_raw, \ PD_REGISTER_KERNEL(name##_raw, \
CPU, \ CPU, \
ALL_LAYOUT, \ ALL_LAYOUT, \
phi::func##RawKernel, \ phi::func##RawKernel, \
bool, \ bool, \
int16_t, \ int16_t, \
int, \ int, \
int64_t, \ int64_t, \
float, \ float, \
double, \ double, \
phi::dtype::float16) { \ phi::dtype::float16) { \
kernel->OutputAt(0).SetDataType(paddle::experimental::DataType::BOOL); \ kernel->OutputAt(0).SetDataType(phi::DataType::BOOL); \
} }
PD_REGISTER_COMPARE_KERNEL(less_than, LessThan) PD_REGISTER_COMPARE_KERNEL(less_than, LessThan)
PD_REGISTER_COMPARE_KERNEL(less_equal, LessEqual) PD_REGISTER_COMPARE_KERNEL(less_equal, LessEqual)
......
...@@ -210,10 +210,10 @@ PD_REGISTER_KERNEL(dropout, ...@@ -210,10 +210,10 @@ PD_REGISTER_KERNEL(dropout,
float, float,
double, double,
phi::dtype::bfloat16) { phi::dtype::bfloat16) {
kernel->OutputAt(1).SetDataType(paddle::experimental::DataType::UINT8); kernel->OutputAt(1).SetDataType(phi::DataType::UINT8);
} }
PD_REGISTER_KERNEL( PD_REGISTER_KERNEL(
dropout_nd, CPU, ALL_LAYOUT, phi::DropoutNdKernel, float, double) { dropout_nd, CPU, ALL_LAYOUT, phi::DropoutNdKernel, float, double) {
kernel->OutputAt(1).SetDataType(paddle::experimental::DataType::UINT8); kernel->OutputAt(1).SetDataType(phi::DataType::UINT8);
} }
...@@ -258,5 +258,5 @@ PD_REGISTER_KERNEL(eigvals, ...@@ -258,5 +258,5 @@ PD_REGISTER_KERNEL(eigvals,
double, double,
phi::dtype::complex<float>, phi::dtype::complex<float>,
phi::dtype::complex<double>) { phi::dtype::complex<double>) {
kernel->OutputAt(0).SetDataType(paddle::experimental::DataType::UNDEFINED); kernel->OutputAt(0).SetDataType(phi::DataType::UNDEFINED);
} }
...@@ -95,5 +95,5 @@ void NMSKernel(const Context& dev_ctx, ...@@ -95,5 +95,5 @@ void NMSKernel(const Context& dev_ctx,
} // namespace phi } // namespace phi
PD_REGISTER_KERNEL(nms, CPU, ALL_LAYOUT, phi::NMSKernel, float, double) { PD_REGISTER_KERNEL(nms, CPU, ALL_LAYOUT, phi::NMSKernel, float, double) {
kernel->OutputAt(0).SetDataType(paddle::experimental::DataType::INT64); kernel->OutputAt(0).SetDataType(phi::DataType::INT64);
} }
...@@ -93,5 +93,5 @@ PD_REGISTER_KERNEL(nonzero, ...@@ -93,5 +93,5 @@ PD_REGISTER_KERNEL(nonzero,
bool, bool,
float, float,
double) { double) {
kernel->OutputAt(0).SetDataType(paddle::experimental::DataType::INT64); kernel->OutputAt(0).SetDataType(phi::DataType::INT64);
} }
...@@ -84,5 +84,5 @@ void OneHotRawKernel(const Context& dev_ctx, ...@@ -84,5 +84,5 @@ void OneHotRawKernel(const Context& dev_ctx,
PD_REGISTER_KERNEL( PD_REGISTER_KERNEL(
one_hot_raw, CPU, ALL_LAYOUT, phi::OneHotRawKernel, int, int64_t) { one_hot_raw, CPU, ALL_LAYOUT, phi::OneHotRawKernel, int, int64_t) {
kernel->OutputAt(0).SetDataType(paddle::experimental::DataType::UNDEFINED); kernel->OutputAt(0).SetDataType(phi::DataType::UNDEFINED);
} }
...@@ -54,5 +54,5 @@ PD_REGISTER_KERNEL(sum_raw, ...@@ -54,5 +54,5 @@ PD_REGISTER_KERNEL(sum_raw,
int64_t, int64_t,
complex64, complex64,
complex128) { complex128) {
kernel->OutputAt(0).SetDataType(paddle::experimental::DataType::UNDEFINED); kernel->OutputAt(0).SetDataType(phi::DataType::UNDEFINED);
} }
...@@ -242,5 +242,5 @@ void TopkKernel(const Context& dev_ctx, ...@@ -242,5 +242,5 @@ void TopkKernel(const Context& dev_ctx,
PD_REGISTER_KERNEL( PD_REGISTER_KERNEL(
topk, CPU, ALL_LAYOUT, phi::TopkKernel, float, double, int32_t, int64_t) { topk, CPU, ALL_LAYOUT, phi::TopkKernel, float, double, int32_t, int64_t) {
kernel->OutputAt(1).SetDataType(paddle::experimental::DataType::INT64); kernel->OutputAt(1).SetDataType(phi::DataType::INT64);
} }
...@@ -319,5 +319,5 @@ void ViterbiDecodeKernel(const Context& dev_ctx, ...@@ -319,5 +319,5 @@ void ViterbiDecodeKernel(const Context& dev_ctx,
PD_REGISTER_KERNEL( PD_REGISTER_KERNEL(
viterbi_decode, CPU, ALL_LAYOUT, phi::ViterbiDecodeKernel, float, double) { viterbi_decode, CPU, ALL_LAYOUT, phi::ViterbiDecodeKernel, float, double) {
kernel->OutputAt(1).SetDataType(paddle::experimental::DataType::INT64); kernel->OutputAt(1).SetDataType(phi::DataType::INT64);
} }
...@@ -357,7 +357,7 @@ PD_REGISTER_KERNEL(check_finite_and_unscale, ...@@ -357,7 +357,7 @@ PD_REGISTER_KERNEL(check_finite_and_unscale,
double, double,
phi::dtype::float16, phi::dtype::float16,
phi::dtype::bfloat16) { phi::dtype::bfloat16) {
kernel->OutputAt(1).SetDataType(paddle::experimental::DataType::BOOL); kernel->OutputAt(1).SetDataType(phi::DataType::BOOL);
} }
PD_REGISTER_KERNEL(update_loss_scaling, PD_REGISTER_KERNEL(update_loss_scaling,
......
...@@ -24,5 +24,5 @@ using complex128 = ::phi::dtype::complex<double>; ...@@ -24,5 +24,5 @@ using complex128 = ::phi::dtype::complex<double>;
PD_REGISTER_KERNEL( PD_REGISTER_KERNEL(
as_real, GPU, ALL_LAYOUT, phi::AsRealKernel, complex64, complex128) { as_real, GPU, ALL_LAYOUT, phi::AsRealKernel, complex64, complex128) {
kernel->OutputAt(0).SetDataType(paddle::experimental::DataType::UNDEFINED); kernel->OutputAt(0).SetDataType(phi::DataType::UNDEFINED);
} }
...@@ -32,24 +32,23 @@ void CastGradKernel(const Context& dev_ctx, ...@@ -32,24 +32,23 @@ void CastGradKernel(const Context& dev_ctx,
} // namespace phi } // namespace phi
#define PTEN_REGISTER_CAST_CUDA_BASE_TYPE(op_name, ...) \ #define PTEN_REGISTER_CAST_CUDA_BASE_TYPE(op_name, ...) \
PD_REGISTER_KERNEL(cast_grad, \ PD_REGISTER_KERNEL(cast_grad, \
GPU, \ GPU, \
ALL_LAYOUT, \ ALL_LAYOUT, \
phi::CastGradKernel, \ phi::CastGradKernel, \
float, \ float, \
double, \ double, \
int, \ int, \
int64_t, \ int64_t, \
int16_t, \ int16_t, \
bool, \ bool, \
uint8_t, \ uint8_t, \
phi::dtype::float16, \ phi::dtype::float16, \
phi::dtype::complex<float>, \ phi::dtype::complex<float>, \
phi::dtype::complex<double>, \ phi::dtype::complex<double>, \
##__VA_ARGS__) { \ ##__VA_ARGS__) { \
kernel->OutputAt(0).SetDataType( \ kernel->OutputAt(0).SetDataType(phi::DataType::UNDEFINED); \
paddle::experimental::DataType::UNDEFINED); \
} }
PTEN_REGISTER_CAST_CUDA_BASE_TYPE(cast_grad, phi::dtype::bfloat16) PTEN_REGISTER_CAST_CUDA_BASE_TYPE(cast_grad, phi::dtype::bfloat16)
...@@ -32,25 +32,24 @@ void CastKernel(const Context& dev_ctx, ...@@ -32,25 +32,24 @@ void CastKernel(const Context& dev_ctx,
} // namespace phi } // namespace phi
#define PTEN_REGISTER_CAST_CUDA_BASE_TYPE(op_name, ...) \ #define PTEN_REGISTER_CAST_CUDA_BASE_TYPE(op_name, ...) \
PD_REGISTER_KERNEL(cast, \ PD_REGISTER_KERNEL(cast, \
GPU, \ GPU, \
ALL_LAYOUT, \ ALL_LAYOUT, \
phi::CastKernel, \ phi::CastKernel, \
float, \ float, \
double, \ double, \
int, \ int, \
int64_t, \ int64_t, \
int16_t, \ int16_t, \
bool, \ bool, \
int8_t, \ int8_t, \
uint8_t, \ uint8_t, \
phi::dtype::float16, \ phi::dtype::float16, \
phi::dtype::complex<float>, \ phi::dtype::complex<float>, \
phi::dtype::complex<double>, \ phi::dtype::complex<double>, \
##__VA_ARGS__) { \ ##__VA_ARGS__) { \
kernel->OutputAt(0).SetDataType( \ kernel->OutputAt(0).SetDataType(phi::DataType::UNDEFINED); \
paddle::experimental::DataType::UNDEFINED); \
} }
PTEN_REGISTER_CAST_CUDA_BASE_TYPE(cast, phi::dtype::bfloat16) PTEN_REGISTER_CAST_CUDA_BASE_TYPE(cast, phi::dtype::bfloat16)
...@@ -90,7 +90,7 @@ PD_REGISTER_KERNEL(dropout, ...@@ -90,7 +90,7 @@ PD_REGISTER_KERNEL(dropout,
phi::dtype::bfloat16, phi::dtype::bfloat16,
phi::dtype::float16) { phi::dtype::float16) {
kernel->InputAt(1).SetBackend(phi::Backend::ALL_BACKEND); kernel->InputAt(1).SetBackend(phi::Backend::ALL_BACKEND);
kernel->OutputAt(1).SetDataType(paddle::experimental::DataType::UINT8); kernel->OutputAt(1).SetDataType(phi::DataType::UINT8);
} }
PD_REGISTER_KERNEL(dropout_nd, PD_REGISTER_KERNEL(dropout_nd,
...@@ -102,5 +102,5 @@ PD_REGISTER_KERNEL(dropout_nd, ...@@ -102,5 +102,5 @@ PD_REGISTER_KERNEL(dropout_nd,
phi::dtype::bfloat16, phi::dtype::bfloat16,
phi::dtype::float16) { phi::dtype::float16) {
kernel->InputAt(1).SetBackend(phi::Backend::ALL_BACKEND); kernel->InputAt(1).SetBackend(phi::Backend::ALL_BACKEND);
kernel->OutputAt(1).SetDataType(paddle::experimental::DataType::UINT8); kernel->OutputAt(1).SetDataType(phi::DataType::UINT8);
} }
...@@ -114,5 +114,5 @@ void NMSKernel(const Context& dev_ctx, ...@@ -114,5 +114,5 @@ void NMSKernel(const Context& dev_ctx,
} }
} // namespace phi } // namespace phi
PD_REGISTER_KERNEL(nms, GPU, ALL_LAYOUT, phi::NMSKernel, float, double) { PD_REGISTER_KERNEL(nms, GPU, ALL_LAYOUT, phi::NMSKernel, float, double) {
kernel->OutputAt(0).SetDataType(paddle::experimental::DataType::INT64); kernel->OutputAt(0).SetDataType(phi::DataType::INT64);
} }
...@@ -84,5 +84,5 @@ PD_REGISTER_KERNEL(nonzero, ...@@ -84,5 +84,5 @@ PD_REGISTER_KERNEL(nonzero,
bool, bool,
float, float,
double) { double) {
kernel->OutputAt(0).SetDataType(paddle::experimental::DataType::INT64); kernel->OutputAt(0).SetDataType(phi::DataType::INT64);
} }
...@@ -92,5 +92,5 @@ void OneHotRawKernel(const Context& dev_ctx, ...@@ -92,5 +92,5 @@ void OneHotRawKernel(const Context& dev_ctx,
PD_REGISTER_KERNEL( PD_REGISTER_KERNEL(
one_hot_raw, GPU, ALL_LAYOUT, phi::OneHotRawKernel, int, int64_t) { one_hot_raw, GPU, ALL_LAYOUT, phi::OneHotRawKernel, int, int64_t) {
kernel->OutputAt(0).SetDataType(paddle::experimental::DataType::UNDEFINED); kernel->OutputAt(0).SetDataType(phi::DataType::UNDEFINED);
} }
...@@ -349,5 +349,5 @@ PD_REGISTER_KERNEL(topk, ...@@ -349,5 +349,5 @@ PD_REGISTER_KERNEL(topk,
int, int,
int64_t, int64_t,
phi::dtype::float16) { phi::dtype::float16) {
kernel->OutputAt(1).SetDataType(paddle::experimental::DataType::INT64); kernel->OutputAt(1).SetDataType(phi::DataType::INT64);
} }
...@@ -398,5 +398,5 @@ void ViterbiDecodeKernel(const Context& dev_ctx, ...@@ -398,5 +398,5 @@ void ViterbiDecodeKernel(const Context& dev_ctx,
PD_REGISTER_KERNEL( PD_REGISTER_KERNEL(
viterbi_decode, GPU, ALL_LAYOUT, phi::ViterbiDecodeKernel, float, double) { viterbi_decode, GPU, ALL_LAYOUT, phi::ViterbiDecodeKernel, float, double) {
kernel->OutputAt(1).SetDataType(paddle::experimental::DataType::INT64); kernel->OutputAt(1).SetDataType(phi::DataType::INT64);
} }
...@@ -96,47 +96,47 @@ inline void CompareAllKernelImpl(const Context& ctx, ...@@ -96,47 +96,47 @@ inline void CompareAllKernelImpl(const Context& ctx,
#ifdef PADDLE_WITH_XPU_KP #ifdef PADDLE_WITH_XPU_KP
PD_REGISTER_KERNEL(less_than, KPS, ALL_LAYOUT, phi::LessThanKernel, int) { PD_REGISTER_KERNEL(less_than, KPS, ALL_LAYOUT, phi::LessThanKernel, int) {
kernel->OutputAt(0).SetDataType(paddle::experimental::DataType::BOOL); kernel->OutputAt(0).SetDataType(phi::DataType::BOOL);
} }
PD_REGISTER_KERNEL(less_equal, KPS, ALL_LAYOUT, phi::LessEqualKernel, int) { PD_REGISTER_KERNEL(less_equal, KPS, ALL_LAYOUT, phi::LessEqualKernel, int) {
kernel->OutputAt(0).SetDataType(paddle::experimental::DataType::BOOL); kernel->OutputAt(0).SetDataType(phi::DataType::BOOL);
} }
PD_REGISTER_KERNEL(greater_than, KPS, ALL_LAYOUT, phi::GreaterThanKernel, int) { PD_REGISTER_KERNEL(greater_than, KPS, ALL_LAYOUT, phi::GreaterThanKernel, int) {
kernel->OutputAt(0).SetDataType(paddle::experimental::DataType::BOOL); kernel->OutputAt(0).SetDataType(phi::DataType::BOOL);
} }
PD_REGISTER_KERNEL( PD_REGISTER_KERNEL(
greater_equal, KPS, ALL_LAYOUT, phi::GreaterEqualKernel, int) { greater_equal, KPS, ALL_LAYOUT, phi::GreaterEqualKernel, int) {
kernel->OutputAt(0).SetDataType(paddle::experimental::DataType::BOOL); kernel->OutputAt(0).SetDataType(phi::DataType::BOOL);
} }
PD_REGISTER_KERNEL(equal, KPS, ALL_LAYOUT, phi::EqualKernel, int) { PD_REGISTER_KERNEL(equal, KPS, ALL_LAYOUT, phi::EqualKernel, int) {
kernel->OutputAt(0).SetDataType(paddle::experimental::DataType::BOOL); kernel->OutputAt(0).SetDataType(phi::DataType::BOOL);
} }
PD_REGISTER_KERNEL(not_equal, KPS, ALL_LAYOUT, phi::NotEqualKernel, int) { PD_REGISTER_KERNEL(not_equal, KPS, ALL_LAYOUT, phi::NotEqualKernel, int) {
kernel->OutputAt(0).SetDataType(paddle::experimental::DataType::BOOL); kernel->OutputAt(0).SetDataType(phi::DataType::BOOL);
} }
PD_REGISTER_KERNEL( PD_REGISTER_KERNEL(
less_than_raw, KPS, ALL_LAYOUT, phi::LessThanRawKernel, int) { less_than_raw, KPS, ALL_LAYOUT, phi::LessThanRawKernel, int) {
kernel->OutputAt(0).SetDataType(paddle::experimental::DataType::BOOL); kernel->OutputAt(0).SetDataType(phi::DataType::BOOL);
} }
PD_REGISTER_KERNEL( PD_REGISTER_KERNEL(
less_equal_raw, KPS, ALL_LAYOUT, phi::LessEqualRawKernel, int) { less_equal_raw, KPS, ALL_LAYOUT, phi::LessEqualRawKernel, int) {
kernel->OutputAt(0).SetDataType(paddle::experimental::DataType::BOOL); kernel->OutputAt(0).SetDataType(phi::DataType::BOOL);
} }
PD_REGISTER_KERNEL( PD_REGISTER_KERNEL(
greater_than_raw, KPS, ALL_LAYOUT, phi::GreaterThanRawKernel, int) { greater_than_raw, KPS, ALL_LAYOUT, phi::GreaterThanRawKernel, int) {
kernel->OutputAt(0).SetDataType(paddle::experimental::DataType::BOOL); kernel->OutputAt(0).SetDataType(phi::DataType::BOOL);
} }
PD_REGISTER_KERNEL( PD_REGISTER_KERNEL(
greater_equal_raw, KPS, ALL_LAYOUT, phi::GreaterEqualRawKernel, int) { greater_equal_raw, KPS, ALL_LAYOUT, phi::GreaterEqualRawKernel, int) {
kernel->OutputAt(0).SetDataType(paddle::experimental::DataType::BOOL); kernel->OutputAt(0).SetDataType(phi::DataType::BOOL);
} }
PD_REGISTER_KERNEL(equal_raw, KPS, ALL_LAYOUT, phi::EqualRawKernel, int) { PD_REGISTER_KERNEL(equal_raw, KPS, ALL_LAYOUT, phi::EqualRawKernel, int) {
kernel->OutputAt(0).SetDataType(paddle::experimental::DataType::BOOL); kernel->OutputAt(0).SetDataType(phi::DataType::BOOL);
} }
PD_REGISTER_KERNEL( PD_REGISTER_KERNEL(
not_equal_raw, KPS, ALL_LAYOUT, phi::NotEqualRawKernel, int) { not_equal_raw, KPS, ALL_LAYOUT, phi::NotEqualRawKernel, int) {
kernel->OutputAt(0).SetDataType(paddle::experimental::DataType::BOOL); kernel->OutputAt(0).SetDataType(phi::DataType::BOOL);
} }
#else #else
...@@ -150,37 +150,37 @@ PD_REGISTER_KERNEL(equal_all, ...@@ -150,37 +150,37 @@ PD_REGISTER_KERNEL(equal_all,
int64_t, int64_t,
float, float,
double) { double) {
kernel->OutputAt(0).SetDataType(paddle::experimental::DataType::BOOL); kernel->OutputAt(0).SetDataType(phi::DataType::BOOL);
} }
#define PD_REGISTER_COMPARE_KERNEL(name, func) \ #define PD_REGISTER_COMPARE_KERNEL(name, func) \
PD_REGISTER_KERNEL(name, \ PD_REGISTER_KERNEL(name, \
KPS, \ KPS, \
ALL_LAYOUT, \ ALL_LAYOUT, \
phi::func##Kernel, \ phi::func##Kernel, \
bool, \ bool, \
int16_t, \ int16_t, \
int, \ int, \
int64_t, \ int64_t, \
float, \ float, \
double, \ double, \
phi::dtype::float16, \ phi::dtype::float16, \
phi::dtype::bfloat16) { \ phi::dtype::bfloat16) { \
kernel->OutputAt(0).SetDataType(paddle::experimental::DataType::BOOL); \ kernel->OutputAt(0).SetDataType(phi::DataType::BOOL); \
} \ } \
PD_REGISTER_KERNEL(name##_raw, \ PD_REGISTER_KERNEL(name##_raw, \
KPS, \ KPS, \
ALL_LAYOUT, \ ALL_LAYOUT, \
phi::func##RawKernel, \ phi::func##RawKernel, \
bool, \ bool, \
int16_t, \ int16_t, \
int, \ int, \
int64_t, \ int64_t, \
float, \ float, \
double, \ double, \
phi::dtype::float16, \ phi::dtype::float16, \
phi::dtype::bfloat16) { \ phi::dtype::bfloat16) { \
kernel->OutputAt(0).SetDataType(paddle::experimental::DataType::BOOL); \ kernel->OutputAt(0).SetDataType(phi::DataType::BOOL); \
} }
PD_REGISTER_COMPARE_KERNEL(less_than, LessThan) PD_REGISTER_COMPARE_KERNEL(less_than, LessThan)
......
...@@ -147,7 +147,7 @@ void SumRawKernel(const Context& dev_ctx, ...@@ -147,7 +147,7 @@ void SumRawKernel(const Context& dev_ctx,
#ifdef PADDLE_WITH_XPU_KP #ifdef PADDLE_WITH_XPU_KP
PD_REGISTER_KERNEL(sum_raw, KPS, ALL_LAYOUT, phi::SumRawKernel, float) { PD_REGISTER_KERNEL(sum_raw, KPS, ALL_LAYOUT, phi::SumRawKernel, float) {
kernel->OutputAt(0).SetDataType(paddle::experimental::DataType::UNDEFINED); kernel->OutputAt(0).SetDataType(phi::DataType::UNDEFINED);
} }
#else #else
using float16 = phi::dtype::float16; using float16 = phi::dtype::float16;
...@@ -169,6 +169,6 @@ PD_REGISTER_KERNEL(sum_raw, ...@@ -169,6 +169,6 @@ PD_REGISTER_KERNEL(sum_raw,
int64_t, int64_t,
complex64, complex64,
complex128) { complex128) {
kernel->OutputAt(0).SetDataType(paddle::experimental::DataType::UNDEFINED); kernel->OutputAt(0).SetDataType(phi::DataType::UNDEFINED);
} }
#endif #endif
...@@ -31,17 +31,17 @@ void OneHotKernel(const Context& dev_ctx, ...@@ -31,17 +31,17 @@ void OneHotKernel(const Context& dev_ctx,
} // namespace phi } // namespace phi
PD_REGISTER_KERNEL(one_hot, CPU, ALL_LAYOUT, phi::OneHotKernel, int, int64_t) { PD_REGISTER_KERNEL(one_hot, CPU, ALL_LAYOUT, phi::OneHotKernel, int, int64_t) {
kernel->OutputAt(0).SetDataType(paddle::experimental::DataType::FLOAT32); kernel->OutputAt(0).SetDataType(phi::DataType::FLOAT32);
} }
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
PD_REGISTER_KERNEL(one_hot, GPU, ALL_LAYOUT, phi::OneHotKernel, int, int64_t) { PD_REGISTER_KERNEL(one_hot, GPU, ALL_LAYOUT, phi::OneHotKernel, int, int64_t) {
kernel->OutputAt(0).SetDataType(paddle::experimental::DataType::FLOAT32); kernel->OutputAt(0).SetDataType(phi::DataType::FLOAT32);
} }
#endif #endif
#ifdef PADDLE_WITH_XPU #ifdef PADDLE_WITH_XPU
PD_REGISTER_KERNEL(one_hot, XPU, ALL_LAYOUT, phi::OneHotKernel, int, int64_t) { PD_REGISTER_KERNEL(one_hot, XPU, ALL_LAYOUT, phi::OneHotKernel, int, int64_t) {
kernel->OutputAt(0).SetDataType(paddle::experimental::DataType::FLOAT32); kernel->OutputAt(0).SetDataType(phi::DataType::FLOAT32);
} }
#endif #endif
...@@ -49,7 +49,7 @@ PD_REGISTER_KERNEL(sum, ...@@ -49,7 +49,7 @@ PD_REGISTER_KERNEL(sum,
int64_t, int64_t,
complex64, complex64,
complex128) { complex128) {
kernel->OutputAt(0).SetDataType(paddle::experimental::DataType::UNDEFINED); kernel->OutputAt(0).SetDataType(phi::DataType::UNDEFINED);
} }
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
...@@ -67,13 +67,13 @@ PD_REGISTER_KERNEL(sum, ...@@ -67,13 +67,13 @@ PD_REGISTER_KERNEL(sum,
int64_t, int64_t,
complex64, complex64,
complex128) { complex128) {
kernel->OutputAt(0).SetDataType(paddle::experimental::DataType::UNDEFINED); kernel->OutputAt(0).SetDataType(phi::DataType::UNDEFINED);
} }
#endif #endif
#if defined(PADDLE_WITH_XPU_KP) && !defined(PADDLE_WITH_XPU) #if defined(PADDLE_WITH_XPU_KP) && !defined(PADDLE_WITH_XPU)
PD_REGISTER_KERNEL(sum, KPS, ALL_LAYOUT, phi::SumKernel, float) { PD_REGISTER_KERNEL(sum, KPS, ALL_LAYOUT, phi::SumKernel, float) {
kernel->OutputAt(0).SetDataType(paddle::experimental::DataType::UNDEFINED); kernel->OutputAt(0).SetDataType(phi::DataType::UNDEFINED);
} }
#endif #endif
...@@ -92,6 +92,6 @@ PD_REGISTER_KERNEL(sum, ...@@ -92,6 +92,6 @@ PD_REGISTER_KERNEL(sum,
int8_t, int8_t,
int, int,
int64_t) { int64_t) {
kernel->OutputAt(0).SetDataType(paddle::experimental::DataType::UNDEFINED); kernel->OutputAt(0).SetDataType(phi::DataType::UNDEFINED);
} }
#endif #endif
...@@ -48,7 +48,7 @@ PD_REGISTER_KERNEL(shape, ...@@ -48,7 +48,7 @@ PD_REGISTER_KERNEL(shape,
phi::dtype::complex<double>) { phi::dtype::complex<double>) {
kernel->InputAt(0).SetBackend(phi::Backend::ALL_BACKEND); kernel->InputAt(0).SetBackend(phi::Backend::ALL_BACKEND);
kernel->OutputAt(0).SetBackend(phi::Backend::CPU); kernel->OutputAt(0).SetBackend(phi::Backend::CPU);
kernel->OutputAt(0).SetDataType(paddle::experimental::DataType::INT32); kernel->OutputAt(0).SetDataType(phi::DataType::INT32);
} }
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
...@@ -68,7 +68,7 @@ PD_REGISTER_KERNEL(shape, ...@@ -68,7 +68,7 @@ PD_REGISTER_KERNEL(shape,
phi::dtype::float16) { phi::dtype::float16) {
kernel->InputAt(0).SetBackend(phi::Backend::ALL_BACKEND); kernel->InputAt(0).SetBackend(phi::Backend::ALL_BACKEND);
kernel->OutputAt(0).SetBackend(phi::Backend::CPU); kernel->OutputAt(0).SetBackend(phi::Backend::CPU);
kernel->OutputAt(0).SetDataType(paddle::experimental::DataType::INT32); kernel->OutputAt(0).SetDataType(phi::DataType::INT32);
} }
#endif #endif
...@@ -85,6 +85,6 @@ PD_REGISTER_KERNEL(shape, ...@@ -85,6 +85,6 @@ PD_REGISTER_KERNEL(shape,
phi::dtype::float16) { phi::dtype::float16) {
kernel->InputAt(0).SetBackend(phi::Backend::ALL_BACKEND); kernel->InputAt(0).SetBackend(phi::Backend::ALL_BACKEND);
kernel->OutputAt(0).SetBackend(phi::Backend::CPU); kernel->OutputAt(0).SetBackend(phi::Backend::CPU);
kernel->OutputAt(0).SetDataType(paddle::experimental::DataType::INT32); kernel->OutputAt(0).SetDataType(phi::DataType::INT32);
} }
#endif #endif
...@@ -286,5 +286,5 @@ PD_REGISTER_KERNEL(check_finite_and_unscale, ...@@ -286,5 +286,5 @@ PD_REGISTER_KERNEL(check_finite_and_unscale,
phi::CheckFiniteAndUnscaleKernel, phi::CheckFiniteAndUnscaleKernel,
float, float,
phi::dtype::float16) { phi::dtype::float16) {
kernel->OutputAt(1).SetDataType(paddle::experimental::DataType::BOOL); kernel->OutputAt(1).SetDataType(phi::DataType::BOOL);
} }
...@@ -106,5 +106,5 @@ PD_REGISTER_KERNEL(cast, ...@@ -106,5 +106,5 @@ PD_REGISTER_KERNEL(cast,
bool, bool,
uint8_t, uint8_t,
double) { double) {
kernel->OutputAt(0).SetDataType(paddle::experimental::DataType::UNDEFINED); kernel->OutputAt(0).SetDataType(phi::DataType::UNDEFINED);
} }
...@@ -91,7 +91,7 @@ DEFINE_XPU_COMPARE_KERNEL(GreaterEqual, xpu::broadcast_greater_equal<XPUType>) ...@@ -91,7 +91,7 @@ DEFINE_XPU_COMPARE_KERNEL(GreaterEqual, xpu::broadcast_greater_equal<XPUType>)
PD_REGISTER_KERNEL( PD_REGISTER_KERNEL(
less_than, XPU, ALL_LAYOUT, phi::LessThanKernel, int, int64_t, float) { less_than, XPU, ALL_LAYOUT, phi::LessThanKernel, int, int64_t, float) {
kernel->OutputAt(0).SetDataType(paddle::experimental::DataType::BOOL); kernel->OutputAt(0).SetDataType(phi::DataType::BOOL);
} }
PD_REGISTER_KERNEL(less_than_raw, PD_REGISTER_KERNEL(less_than_raw,
...@@ -101,22 +101,22 @@ PD_REGISTER_KERNEL(less_than_raw, ...@@ -101,22 +101,22 @@ PD_REGISTER_KERNEL(less_than_raw,
int, int,
int64_t, int64_t,
float) { float) {
kernel->OutputAt(0).SetDataType(paddle::experimental::DataType::BOOL); kernel->OutputAt(0).SetDataType(phi::DataType::BOOL);
} }
#define PD_REGISTER_COMPARE_KERNEL(name, func) \ #define PD_REGISTER_COMPARE_KERNEL(name, func) \
PD_REGISTER_KERNEL( \ PD_REGISTER_KERNEL( \
name, XPU, ALL_LAYOUT, phi::func##Kernel, int, int64_t, float) { \ name, XPU, ALL_LAYOUT, phi::func##Kernel, int, int64_t, float) { \
kernel->OutputAt(0).SetDataType(paddle::experimental::DataType::BOOL); \ kernel->OutputAt(0).SetDataType(phi::DataType::BOOL); \
} \ } \
PD_REGISTER_KERNEL(name##_raw, \ PD_REGISTER_KERNEL(name##_raw, \
XPU, \ XPU, \
ALL_LAYOUT, \ ALL_LAYOUT, \
phi::func##RawKernel, \ phi::func##RawKernel, \
int, \ int, \
int64_t, \ int64_t, \
float) { \ float) { \
kernel->OutputAt(0).SetDataType(paddle::experimental::DataType::BOOL); \ kernel->OutputAt(0).SetDataType(phi::DataType::BOOL); \
} }
PD_REGISTER_COMPARE_KERNEL(less_equal, LessEqual) PD_REGISTER_COMPARE_KERNEL(less_equal, LessEqual)
......
...@@ -70,5 +70,5 @@ void NonZeroKernel(const Context& dev_ctx, ...@@ -70,5 +70,5 @@ void NonZeroKernel(const Context& dev_ctx,
PD_REGISTER_KERNEL( PD_REGISTER_KERNEL(
nonzero, XPU, ALL_LAYOUT, phi::NonZeroKernel, int, bool, float) { nonzero, XPU, ALL_LAYOUT, phi::NonZeroKernel, int, bool, float) {
kernel->OutputAt(0).SetDataType(paddle::experimental::DataType::INT64); kernel->OutputAt(0).SetDataType(phi::DataType::INT64);
} }
...@@ -63,5 +63,5 @@ void OneHotRawKernel(const Context& dev_ctx, ...@@ -63,5 +63,5 @@ void OneHotRawKernel(const Context& dev_ctx,
PD_REGISTER_KERNEL( PD_REGISTER_KERNEL(
one_hot_raw, XPU, ALL_LAYOUT, phi::OneHotRawKernel, int, int64_t) { one_hot_raw, XPU, ALL_LAYOUT, phi::OneHotRawKernel, int, int64_t) {
kernel->OutputAt(0).SetDataType(paddle::experimental::DataType::UNDEFINED); kernel->OutputAt(0).SetDataType(phi::DataType::UNDEFINED);
} }
...@@ -188,5 +188,5 @@ void TopkKernel(const Context& dev_ctx, ...@@ -188,5 +188,5 @@ void TopkKernel(const Context& dev_ctx,
PD_REGISTER_KERNEL( PD_REGISTER_KERNEL(
topk, XPU, ALL_LAYOUT, phi::TopkKernel, float, phi::dtype::float16) { topk, XPU, ALL_LAYOUT, phi::TopkKernel, float, phi::dtype::float16) {
kernel->OutputAt(1).SetDataType(paddle::experimental::DataType::INT64); kernel->OutputAt(1).SetDataType(phi::DataType::INT64);
} }
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册