未验证 提交 e8530a35 编写于 作者: Z zyfncg 提交者: GitHub

Register custom kernel for some all_bakcend kernel (#51639)

* register some custom kernel

* fix bug
上级 1d5cad23
...@@ -33,3 +33,21 @@ PD_REGISTER_KERNEL(numel, ...@@ -33,3 +33,21 @@ PD_REGISTER_KERNEL(numel,
bool) { bool) {
kernel->OutputAt(0).SetDataType(phi::DataType::INT64); kernel->OutputAt(0).SetDataType(phi::DataType::INT64);
} }
#ifdef PADDLE_WITH_CUSTOM_DEVICE
PD_REGISTER_KERNEL(numel,
Custom,
ALL_LAYOUT,
phi::NumelKernel,
uint8_t,
int16_t,
int,
int64_t,
phi::dtype::float16,
phi::dtype::bfloat16,
float,
double,
bool) {
kernel->OutputAt(0).SetDataType(phi::DataType::INT64);
}
#endif
...@@ -74,3 +74,18 @@ PD_REGISTER_KERNEL(flatten_grad, ...@@ -74,3 +74,18 @@ PD_REGISTER_KERNEL(flatten_grad,
int64_t) {} int64_t) {}
#endif #endif
#ifdef PADDLE_WITH_CUSTOM_DEVICE
PD_REGISTER_KERNEL(flatten_grad,
Custom,
ALL_LAYOUT,
phi::FlattenGradKernel,
float,
phi::dtype::float16,
double,
uint8_t,
int8_t,
int16_t,
int,
int64_t) {}
#endif
...@@ -128,3 +128,31 @@ PD_REGISTER_KERNEL(flatten, ...@@ -128,3 +128,31 @@ PD_REGISTER_KERNEL(flatten,
int, int,
int64_t) {} int64_t) {}
#endif #endif
#ifdef PADDLE_WITH_CUSTOM_DEVICE
PD_REGISTER_KERNEL(flatten_infer,
Custom,
ALL_LAYOUT,
phi::FlattenInferKernel,
float,
phi::dtype::float16,
double,
uint8_t,
int8_t,
int16_t,
int,
int64_t) {}
PD_REGISTER_KERNEL(flatten,
Custom,
ALL_LAYOUT,
phi::FlattenKernel,
float,
phi::dtype::float16,
double,
uint8_t,
int8_t,
int16_t,
int,
int64_t) {}
#endif
...@@ -97,3 +97,16 @@ PD_REGISTER_GENERAL_KERNEL(reshape_double_grad, ...@@ -97,3 +97,16 @@ PD_REGISTER_GENERAL_KERNEL(reshape_double_grad,
phi::ReshapeDoubleGradKernel<phi::XPUContext>, phi::ReshapeDoubleGradKernel<phi::XPUContext>,
ALL_DTYPE) {} ALL_DTYPE) {}
#endif #endif
#ifdef PADDLE_WITH_CUSTOM_DEVICE
PD_REGISTER_GENERAL_KERNEL(reshape_grad,
Custom,
ALL_LAYOUT,
phi::ReshapeGradKernel<phi::CustomContext>,
ALL_DTYPE) {}
PD_REGISTER_GENERAL_KERNEL(reshape_double_grad,
Custom,
ALL_LAYOUT,
phi::ReshapeDoubleGradKernel<phi::CustomContext>,
ALL_DTYPE) {}
#endif
...@@ -114,3 +114,16 @@ PD_REGISTER_GENERAL_KERNEL(reshape_infer, ...@@ -114,3 +114,16 @@ PD_REGISTER_GENERAL_KERNEL(reshape_infer,
PD_REGISTER_GENERAL_KERNEL( PD_REGISTER_GENERAL_KERNEL(
reshape, XPU, ALL_LAYOUT, phi::ReshapeKernel<phi::XPUContext>, ALL_DTYPE) {} reshape, XPU, ALL_LAYOUT, phi::ReshapeKernel<phi::XPUContext>, ALL_DTYPE) {}
#endif #endif
#ifdef PADDLE_WITH_CUSTOM_DEVICE
PD_REGISTER_GENERAL_KERNEL(reshape_infer,
Custom,
ALL_LAYOUT,
phi::ReshapeInferKernel<phi::CustomContext>,
ALL_DTYPE) {}
PD_REGISTER_GENERAL_KERNEL(reshape,
Custom,
ALL_LAYOUT,
phi::ReshapeKernel<phi::CustomContext>,
ALL_DTYPE) {}
#endif
...@@ -45,7 +45,11 @@ PD_REGISTER_KERNEL(shape_sr, ...@@ -45,7 +45,11 @@ PD_REGISTER_KERNEL(shape_sr,
float, float,
double, double,
phi::dtype::complex<float>, phi::dtype::complex<float>,
phi::dtype::complex<double>) {} phi::dtype::complex<double>) {
kernel->InputAt(0).SetBackend(phi::Backend::ALL_BACKEND);
kernel->OutputAt(0).SetBackend(phi::Backend::CPU);
kernel->OutputAt(0).SetDataType(phi::DataType::INT32);
}
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
PD_REGISTER_KERNEL(shape_sr, PD_REGISTER_KERNEL(shape_sr,
...@@ -60,5 +64,29 @@ PD_REGISTER_KERNEL(shape_sr, ...@@ -60,5 +64,29 @@ PD_REGISTER_KERNEL(shape_sr,
float, float,
double, double,
phi::dtype::complex<float>, phi::dtype::complex<float>,
phi::dtype::complex<double>) {} phi::dtype::complex<double>) {
kernel->InputAt(0).SetBackend(phi::Backend::ALL_BACKEND);
kernel->OutputAt(0).SetBackend(phi::Backend::CPU);
kernel->OutputAt(0).SetDataType(phi::DataType::INT32);
}
#endif
#ifdef PADDLE_WITH_CUSTOM_DEVICE
PD_REGISTER_KERNEL(shape_sr,
Custom,
ALL_LAYOUT,
phi::sr::ShapeKernel,
bool,
int,
int8_t,
uint8_t,
int64_t,
float,
double,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {
kernel->InputAt(0).SetBackend(phi::Backend::ALL_BACKEND);
kernel->OutputAt(0).SetBackend(phi::Backend::CPU);
kernel->OutputAt(0).SetDataType(phi::DataType::INT32);
}
#endif #endif
...@@ -89,3 +89,24 @@ PD_REGISTER_KERNEL(shape, ...@@ -89,3 +89,24 @@ PD_REGISTER_KERNEL(shape,
kernel->OutputAt(0).SetDataType(phi::DataType::INT32); kernel->OutputAt(0).SetDataType(phi::DataType::INT32);
} }
#endif #endif
#ifdef PADDLE_WITH_CUSTOM_DEVICE
PD_REGISTER_KERNEL(shape,
Custom,
ALL_LAYOUT,
phi::ShapeKernel,
bool,
int,
int8_t,
uint8_t,
int64_t,
float,
double,
phi::dtype::complex<float>,
phi::dtype::complex<double>,
phi::dtype::float16) {
kernel->InputAt(0).SetBackend(phi::Backend::ALL_BACKEND);
kernel->OutputAt(0).SetBackend(phi::Backend::CPU);
kernel->OutputAt(0).SetDataType(phi::DataType::INT32);
}
#endif
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册