From 50d92531ca43b6f95efc7de772187edab433e466 Mon Sep 17 00:00:00 2001 From: ronnywang Date: Sun, 29 Jan 2023 10:38:33 +0800 Subject: [PATCH] [CustomDevice] registering feed_dense_tensor, feed_sparse_coo_tensor, feed_strings kernels for custom device (#50042) * [CustomDevice] registering feed_dense_tensor, feed_sparse_coo_tensor, feed_strings kernels for custom device * update * update * update --- paddle/fluid/operators/controlflow/feed_op.cc | 31 +++++++------------ .../custom_device_common_op_registry.cc | 28 +++++++++++++++++ 2 files changed, 39 insertions(+), 20 deletions(-) diff --git a/paddle/fluid/operators/controlflow/feed_op.cc b/paddle/fluid/operators/controlflow/feed_op.cc index 194dccb0e6e..09684b8d737 100644 --- a/paddle/fluid/operators/controlflow/feed_op.cc +++ b/paddle/fluid/operators/controlflow/feed_op.cc @@ -305,24 +305,15 @@ PD_REGISTER_GENERAL_KERNEL( ALL_LAYOUT, paddle::operators::FeedStringsKernel, ALL_DTYPE) {} - -#elif defined(PADDLE_WITH_CUSTOM_DEVICE) -PD_REGISTER_GENERAL_KERNEL( - feed_dense_tensor, - custom_cpu, - ALL_LAYOUT, - paddle::operators::FeedDenseTensorKernel, - ALL_DTYPE) {} -PD_REGISTER_GENERAL_KERNEL( - feed_sparse_coo_tensor, - custom_cpu, - ALL_LAYOUT, - paddle::operators::FeedSparseCooTensorKernel, - ALL_DTYPE) {} -PD_REGISTER_GENERAL_KERNEL( - feed_strings, - custom_cpu, - ALL_LAYOUT, - paddle::operators::FeedStringsKernel, - ALL_DTYPE) {} +#endif +#ifdef PADDLE_WITH_CUSTOM_DEVICE +namespace paddle { +namespace operators { +template void FeedDenseTensorKernel( + const phi::CustomContext& dev_ctx, + const phi::ExtendedTensor& x, + int col, + phi::DenseTensor* out); +} // namespace operators +} // namespace paddle #endif diff --git a/paddle/fluid/operators/custom_device_common_op_registry.cc b/paddle/fluid/operators/custom_device_common_op_registry.cc index 69625c03dba..bbb75d41833 100644 --- a/paddle/fluid/operators/custom_device_common_op_registry.cc +++ b/paddle/fluid/operators/custom_device_common_op_registry.cc @@ -17,6 +17,7 @@ limitations under the License. */ #include "paddle/fluid/operators/run_program_op.h" #include "paddle/fluid/operators/save_combine_op.h" #include "paddle/phi/backends/device_manager.h" +#include "paddle/phi/core/kernel_registry.h" #define REGISTER_OP_CUSTOM_DEVICE_KERNEL(op_type, dev_type, ...) \ static paddle::framework::OpKernelRegistrar \ @@ -26,10 +27,30 @@ limitations under the License. */ paddle::framework::OpKernelType::kDefaultCustomizedTypeValue); \ __op_custom_device_kernel_registrar_##op_type##_##__acosf##__.Touch(); +#define REGISTER_CUSTOM_DEVICE_GENERAL_KERNEL( \ + kernel_name, dev_type, layout, kernel_fn) \ + static phi::KernelRegistrar \ + __reg_custom_device_phi_kernel_##kernel_name##_##backend##_##layout( \ + phi::RegType::INNER, \ + #kernel_name, \ + dev_type, \ + DATALAYOUT(layout), \ + ::phi::KernelArgsParseFunctor::Parse, \ + [](const phi::KernelKey& kernel_key, phi::Kernel* kernel) {}, \ + PHI_KERNEL(kernel_fn), \ + PHI_VARIADIC_KERNEL(kernel_fn)) + namespace paddle { namespace operators { +template +void FeedDenseTensorKernel(const Context& dev_ctx, + const phi::ExtendedTensor& x, + int col, + phi::DenseTensor* out); + void RegisterCustomDeviceCommonKernel(const std::string& dev_type) { +#ifdef PADDLE_WITH_CUSTOM_DEVICE auto device_type = dev_type.c_str(); /* see [Why use single type kernel] */ REGISTER_OP_CUSTOM_DEVICE_KERNEL( @@ -66,9 +87,16 @@ void RegisterCustomDeviceCommonKernel(const std::string& dev_type) { LoadCombineOpKernel, paddle::operators:: LoadCombineOpKernel); + REGISTER_CUSTOM_DEVICE_GENERAL_KERNEL( + feed_dense_tensor, + device_type, + ALL_LAYOUT, + paddle::operators::FeedDenseTensorKernel); +#endif } } // namespace operators } // namespace paddle #undef REGISTER_OP_CUSTOM_DEVICE_KERNEL +#undef REGISTER_CUSTOM_DEVICE_GENERAL_KERNEL -- GitLab