未验证 提交 50d92531 编写于 作者: R ronnywang 提交者: GitHub

[CustomDevice] registering feed_dense_tensor, feed_sparse_coo_tensor,...

[CustomDevice] registering feed_dense_tensor, feed_sparse_coo_tensor, feed_strings kernels for custom device (#50042)

* [CustomDevice] registering feed_dense_tensor, feed_sparse_coo_tensor, feed_strings kernels for custom device

* update

* update

* update
上级 decbb588
...@@ -305,24 +305,15 @@ PD_REGISTER_GENERAL_KERNEL( ...@@ -305,24 +305,15 @@ PD_REGISTER_GENERAL_KERNEL(
ALL_LAYOUT, ALL_LAYOUT,
paddle::operators::FeedStringsKernel<phi::CustomContext>, paddle::operators::FeedStringsKernel<phi::CustomContext>,
ALL_DTYPE) {} ALL_DTYPE) {}
#endif
#elif defined(PADDLE_WITH_CUSTOM_DEVICE) #ifdef PADDLE_WITH_CUSTOM_DEVICE
PD_REGISTER_GENERAL_KERNEL( namespace paddle {
feed_dense_tensor, namespace operators {
custom_cpu, template void FeedDenseTensorKernel<phi::CustomContext>(
ALL_LAYOUT, const phi::CustomContext& dev_ctx,
paddle::operators::FeedDenseTensorKernel<phi::CustomContext>, const phi::ExtendedTensor& x,
ALL_DTYPE) {} int col,
PD_REGISTER_GENERAL_KERNEL( phi::DenseTensor* out);
feed_sparse_coo_tensor, } // namespace operators
custom_cpu, } // namespace paddle
ALL_LAYOUT,
paddle::operators::FeedSparseCooTensorKernel<phi::CustomContext>,
ALL_DTYPE) {}
PD_REGISTER_GENERAL_KERNEL(
feed_strings,
custom_cpu,
ALL_LAYOUT,
paddle::operators::FeedStringsKernel<phi::CustomContext>,
ALL_DTYPE) {}
#endif #endif
...@@ -17,6 +17,7 @@ limitations under the License. */ ...@@ -17,6 +17,7 @@ limitations under the License. */
#include "paddle/fluid/operators/run_program_op.h" #include "paddle/fluid/operators/run_program_op.h"
#include "paddle/fluid/operators/save_combine_op.h" #include "paddle/fluid/operators/save_combine_op.h"
#include "paddle/phi/backends/device_manager.h" #include "paddle/phi/backends/device_manager.h"
#include "paddle/phi/core/kernel_registry.h"
#define REGISTER_OP_CUSTOM_DEVICE_KERNEL(op_type, dev_type, ...) \ #define REGISTER_OP_CUSTOM_DEVICE_KERNEL(op_type, dev_type, ...) \
static paddle::framework::OpKernelRegistrar<phi::CustomPlace, __VA_ARGS__> \ static paddle::framework::OpKernelRegistrar<phi::CustomPlace, __VA_ARGS__> \
...@@ -26,10 +27,30 @@ limitations under the License. */ ...@@ -26,10 +27,30 @@ limitations under the License. */
paddle::framework::OpKernelType::kDefaultCustomizedTypeValue); \ paddle::framework::OpKernelType::kDefaultCustomizedTypeValue); \
__op_custom_device_kernel_registrar_##op_type##_##__acosf##__.Touch(); __op_custom_device_kernel_registrar_##op_type##_##__acosf##__.Touch();
#define REGISTER_CUSTOM_DEVICE_GENERAL_KERNEL( \
kernel_name, dev_type, layout, kernel_fn) \
static phi::KernelRegistrar \
__reg_custom_device_phi_kernel_##kernel_name##_##backend##_##layout( \
phi::RegType::INNER, \
#kernel_name, \
dev_type, \
DATALAYOUT(layout), \
::phi::KernelArgsParseFunctor<decltype(&kernel_fn)>::Parse, \
[](const phi::KernelKey& kernel_key, phi::Kernel* kernel) {}, \
PHI_KERNEL(kernel_fn), \
PHI_VARIADIC_KERNEL(kernel_fn))
namespace paddle { namespace paddle {
namespace operators { namespace operators {
template <typename Context>
void FeedDenseTensorKernel(const Context& dev_ctx,
const phi::ExtendedTensor& x,
int col,
phi::DenseTensor* out);
void RegisterCustomDeviceCommonKernel(const std::string& dev_type) { void RegisterCustomDeviceCommonKernel(const std::string& dev_type) {
#ifdef PADDLE_WITH_CUSTOM_DEVICE
auto device_type = dev_type.c_str(); auto device_type = dev_type.c_str();
/* see [Why use single type kernel] */ /* see [Why use single type kernel] */
REGISTER_OP_CUSTOM_DEVICE_KERNEL( REGISTER_OP_CUSTOM_DEVICE_KERNEL(
...@@ -66,9 +87,16 @@ void RegisterCustomDeviceCommonKernel(const std::string& dev_type) { ...@@ -66,9 +87,16 @@ void RegisterCustomDeviceCommonKernel(const std::string& dev_type) {
LoadCombineOpKernel<paddle::platform::CustomDeviceContext, int8_t>, LoadCombineOpKernel<paddle::platform::CustomDeviceContext, int8_t>,
paddle::operators:: paddle::operators::
LoadCombineOpKernel<paddle::platform::CustomDeviceContext, int64_t>); LoadCombineOpKernel<paddle::platform::CustomDeviceContext, int64_t>);
REGISTER_CUSTOM_DEVICE_GENERAL_KERNEL(
feed_dense_tensor,
device_type,
ALL_LAYOUT,
paddle::operators::FeedDenseTensorKernel<phi::CustomContext>);
#endif
} }
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
#undef REGISTER_OP_CUSTOM_DEVICE_KERNEL #undef REGISTER_OP_CUSTOM_DEVICE_KERNEL
#undef REGISTER_CUSTOM_DEVICE_GENERAL_KERNEL
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册