diff --git a/paddle/fluid/framework/custom_kernel_test.cc b/paddle/fluid/framework/custom_kernel_test.cc index ea84de700d5f3c28d20c2587ac8ea22f42d28e64..623ad6e5600347e18328d07c17ab19ef36ed6149 100644 --- a/paddle/fluid/framework/custom_kernel_test.cc +++ b/paddle/fluid/framework/custom_kernel_test.cc @@ -35,13 +35,12 @@ limitations under the License. */ // user kernel function namespace custom_kernel { -// Here we use dot for test -// This test will fail when these two kernels are aupported in framework +// Here we use fake_dot for test // input 3: two Tensors and one std::vector // attribute 11: fake_attributes // output 2: one Tensor* and one std::vector -template -void FakeDot(const paddle::CPUContext& dev_ctx, const paddle::Tensor& x, +template +void FakeDot(const Context& dev_ctx, const paddle::Tensor& x, const paddle::Tensor& y, const std::vector& fake_input_vec, bool fake_attr_bool, int fake_attr_int, float fake_attr_float, @@ -93,53 +92,91 @@ void FakeDot(const paddle::CPUContext& dev_ctx, const paddle::Tensor& x, } } // namespace custom_kernel -PD_REGISTER_KERNEL(dot, CPU, ALL_LAYOUT, UINT8, - custom_kernel::FakeDot) { - /* do some args define here - * the only param can be used is OpKernelInfo* kernel */ - kernel->OutputAt(0).SetDataType(paddle::experimental::DataType::UINT8); -} +PD_REGISTER_KERNEL(fake_dot, CPU, ALL_LAYOUT, custom_kernel::FakeDot, float, + double, int, int64_t, int8_t, uint8_t) {} // Upper code will store dot kernels info into OpKernelInfoMap TEST(CustomKernel, custom_kernel_dot) { - std::string op_name = "dot"; + std::string op_name = "fake_dot"; pten::Backend backend = pten::Backend::CPU; - pten::DataLayout layout = pten::DataLayout::ANY; - pten::DataType dtype = pten::DataType::UINT8; + pten::DataLayout layout = pten::DataLayout::ALL_LAYOUT; // 1.custom kernel info parsed and store - EXPECT_TRUE(paddle::OpKernelInfoMap::Instance().GetMap().find("dot") != + EXPECT_TRUE(paddle::OpKernelInfoMap::Instance().GetMap().find(op_name) != paddle::OpKernelInfoMap::Instance().GetMap().end()); // 2.info check EXPECT_EQ( - 1, static_cast(paddle::OpKernelInfoMap::Instance()["dot"].size())); - EXPECT_TRUE(paddle::OpKernelInfoMap::Instance()["dot"][0].GetBackend() == + 6, static_cast(paddle::OpKernelInfoMap::Instance()[op_name].size())); + // index 0 + EXPECT_TRUE(paddle::OpKernelInfoMap::Instance()[op_name][0].GetBackend() == backend); - EXPECT_TRUE(paddle::OpKernelInfoMap::Instance()["dot"][0].GetDataLayout() == + EXPECT_TRUE(paddle::OpKernelInfoMap::Instance()[op_name][0].GetDataLayout() == layout); - EXPECT_TRUE(paddle::OpKernelInfoMap::Instance()["dot"][0].GetDataType() == - dtype); - - // 3.register - EXPECT_TRUE(pten::KernelFactory::Instance().kernels().end() != - pten::KernelFactory::Instance().kernels().find("dot")); - - pten::KernelKey kernel_key(backend, layout, dtype); - EXPECT_TRUE( - pten::KernelFactory::Instance().kernels()["dot"].find(kernel_key) == - pten::KernelFactory::Instance().kernels()["dot"].end()); - + EXPECT_TRUE(paddle::OpKernelInfoMap::Instance()[op_name][0].GetDataType() == + pten::DataType::FLOAT32); + // index 5 + EXPECT_TRUE(paddle::OpKernelInfoMap::Instance()[op_name][5].GetBackend() == + backend); + EXPECT_TRUE(paddle::OpKernelInfoMap::Instance()[op_name][5].GetDataLayout() == + layout); + EXPECT_TRUE(paddle::OpKernelInfoMap::Instance()[op_name][5].GetDataType() == + pten::DataType::UINT8); + + // 3.before register + auto& kernel_factory_instance = pten::KernelFactory::Instance(); + auto& kernels = pten::KernelFactory::Instance().kernels(); + EXPECT_TRUE(!kernel_factory_instance.HasCompatiblePtenKernel(op_name)); + + // mock fake_dot is supported by pten for HasCompatiblePtenKernel check while + // registering + auto& fake_dot_kernels = kernels[op_name]; + + EXPECT_TRUE(fake_dot_kernels.find( + pten::KernelKey(backend, layout, pten::DataType::FLOAT32)) == + fake_dot_kernels.end()); + EXPECT_TRUE(fake_dot_kernels.find( + pten::KernelKey(backend, layout, pten::DataType::FLOAT64)) == + fake_dot_kernels.end()); + EXPECT_TRUE(fake_dot_kernels.find( + pten::KernelKey(backend, layout, pten::DataType::INT32)) == + fake_dot_kernels.end()); + EXPECT_TRUE(fake_dot_kernels.find( + pten::KernelKey(backend, layout, pten::DataType::INT64)) == + fake_dot_kernels.end()); + EXPECT_TRUE(fake_dot_kernels.find( + pten::KernelKey(backend, layout, pten::DataType::INT8)) == + fake_dot_kernels.end()); + EXPECT_TRUE(fake_dot_kernels.find( + pten::KernelKey(backend, layout, pten::DataType::UINT8)) == + fake_dot_kernels.end()); + + // register paddle::framework::RegisterKernelWithMetaInfoMap( paddle::OpKernelInfoMap::Instance()); - EXPECT_TRUE( - pten::KernelFactory::Instance().kernels()["dot"].find(kernel_key) != - pten::KernelFactory::Instance().kernels()["dot"].end()); + EXPECT_TRUE(fake_dot_kernels.find( + pten::KernelKey(backend, layout, pten::DataType::FLOAT32)) != + fake_dot_kernels.end()); + EXPECT_TRUE(fake_dot_kernels.find( + pten::KernelKey(backend, layout, pten::DataType::FLOAT64)) != + fake_dot_kernels.end()); + EXPECT_TRUE(fake_dot_kernels.find( + pten::KernelKey(backend, layout, pten::DataType::INT32)) != + fake_dot_kernels.end()); + EXPECT_TRUE(fake_dot_kernels.find( + pten::KernelKey(backend, layout, pten::DataType::INT64)) != + fake_dot_kernels.end()); + EXPECT_TRUE(fake_dot_kernels.find( + pten::KernelKey(backend, layout, pten::DataType::INT8)) != + fake_dot_kernels.end()); + EXPECT_TRUE(fake_dot_kernels.find( + pten::KernelKey(backend, layout, pten::DataType::UINT8)) != + fake_dot_kernels.end()); // 4.kernel select - auto kernel = pten::KernelFactory::Instance().SelectKernelOrThrowError( - op_name, kernel_key); + auto kernel = kernel_factory_instance.SelectKernelOrThrowError( + op_name, pten::KernelKey(backend, layout, pten::DataType::UINT8)); // 5.prepare parameters for kernel const auto alloc = std::make_unique( @@ -252,10 +289,10 @@ TEST(CustomKernel, custom_kernel_dot) { // test OpKernelInfoHelper TEST(OpKernelInfoHelper, op_kernel_info_help_getters) { using OpKernelInfoHelper = paddle::framework::OpKernelInfoHelper; - std::string op_name = "dot"; + std::string op_name = "fake_dot"; pten::Backend backend = pten::Backend::CPU; pten::DataLayout layout = pten::DataLayout::ANY; - pten::DataType dtype = pten::DataType::UINT8; + pten::DataType dtype = pten::DataType::FLOAT32; auto op_kernel_info = paddle::OpKernelInfoMap::Instance()[op_name][0]; @@ -268,10 +305,11 @@ TEST(OpKernelInfoHelper, op_kernel_info_help_getters) { OpKernelInfoHelper::GetKernelKey(op_kernel_info)); paddle::CustomKernelFunc kernel_fn = - PD_PT_KERNEL(custom_kernel::FakeDot); + PD_PT_KERNEL(custom_kernel::FakeDot); EXPECT_EQ(kernel_fn, OpKernelInfoHelper::GetKernelFn(op_kernel_info)); - void* variadic_func = PD_PT_VARIADIC_KERNEL(custom_kernel::FakeDot); + void* variadic_func = + PD_PT_VARIADIC_KERNEL(custom_kernel::FakeDot); EXPECT_EQ(variadic_func, OpKernelInfoHelper::GetVariadicKernelFn(op_kernel_info)); diff --git a/paddle/pten/api/ext/op_kernel_info.h b/paddle/pten/api/ext/op_kernel_info.h index bcfff61bc6fd26dfc752a59d71f9de7d9c01780d..ebecfaf924a07e78b5fd3726b0f728d4c2604fa9 100644 --- a/paddle/pten/api/ext/op_kernel_info.h +++ b/paddle/pten/api/ext/op_kernel_info.h @@ -30,6 +30,8 @@ limitations under the License. */ #include "paddle/utils/any.h" #include "paddle/utils/small_vector.h" +#include "paddle/pten/common/data_type.h" + /** * Custom Kernel Info Define. * @@ -635,29 +637,624 @@ void RegisterAllCustomKernel(); // register custom kernels void LoadCustomKernelLib(const std::string& dso_name); -//////////////// Custom kernel register macro ///////////////// +//////////////// Custom kernel register macro ///////////////////// +// Refer to paddle/pten/core/kernel_registry.h, we can not use +// PT_REGISTER_KERNEL directly, common macros and functions are +// not ready for custom kernel now. +// Difference: custom_kernel stores all kernels' info into global +// g_custom_kernel_info_map before loading and registering into +// pten kernel management. Only providing PD_REGISTER_KERNEL which +// supports 2 template arguments. + #define PD_BACKEND(arg__) pten::Backend::arg__ #define PD_DATALAYOUT(arg__) pten::DataLayout::arg__ #define PD_DATATYPE(arg__) pten::DataType::arg__ -#define PD_REGISTER_KERNEL(name, backend, layout, dtype, func) \ - STATIC_ASSERT_GLOBAL_NAMESPACE( \ - __reg_kernel__##name##_##backend##_##layout##_##dtype, \ - "PD_REGISTER_KERNEL must be called in global namespace."); \ - void __PD_USER_args_def_##name##_##backend##_##layout_##dtype( \ - ::paddle::OpKernelInfo* op_kernel_info); \ - static ::paddle::OpKernelInfoBuilder \ - __op_kernel_info_##name##_##backend##_##layout##_##dtype = \ - ::paddle::OpKernelInfoBuilder(#name, \ - PD_BACKEND(backend), \ - PD_DATALAYOUT(layout), \ - PD_DATATYPE(dtype)) \ - .SetKernelFn(PD_PT_KERNEL(func)) \ - .SetVariadicKernelFn(PD_PT_VARIADIC_KERNEL(func)) \ - .ArgsParse(PD_PT_ARGS_PARSE(func)) \ - .ArgsDef( \ - &__PD_USER_args_def_##name##_##backend##_##layout_##dtype); \ - void __PD_USER_args_def_##name##_##backend##_##layout_##dtype( \ +#define PD_NARGS(...) _PD_NARGS((__VA_ARGS__, _PD_RESQ_N())) +#define _PD_NARGS(...) _PD_ARG_N(__VA_ARGS__) +#define _PD_ARG_N_EXPAND( \ + _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, N, ...) \ + N +#define _PD_ARG_N(args) _PD_ARG_N_EXPAND args +#define _PD_RESQ_N() 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0 + +#define PD_CONCATENATE(arg1, arg2) PD_CONCATENATE1(arg1, arg2) +#define PD_CONCATENATE1(arg1, arg2) PD_CONCATENATE2(arg1, arg2) +#define PD_CONCATENATE2(arg1, arg2) arg1##arg2 + +#define PD_EXPAND(x) x + +#ifdef __COUNTER__ +#define PD_ID __COUNTER__ +#else +#define PD_ID __LINE__ +#endif + +#define PD_REGISTER_KERNEL(kernel_name, backend, layout, func, cpp_dtype, ...) \ + STATIC_ASSERT_GLOBAL_NAMESPACE( \ + _reg_custom_kernel_ns_check_##kernel_name##_##backend##_##layout, \ + "PD_REGISTER_KERNEL must be called in global namespace."); \ + _PD_REGISTER_2TA_KERNEL( \ + kernel_name, backend, layout, func, cpp_dtype, ##__VA_ARGS__) + +// WIN32 is not supported +#define _PD_REGISTER_2TA_KERNEL( \ + kernel_name, backend, layout, meta_kernel_fn, cpp_dtype, ...) \ + PD_KERNEL_INSTANTIATION(meta_kernel_fn, backend, cpp_dtype, ##__VA_ARGS__); \ + static void __PD_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout( \ + ::paddle::OpKernelInfo* kernel); \ + PD_KERNEL_REGISTRAR_INIT( \ + kernel_name, \ + backend, \ + layout, \ + &__PD_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout, \ + meta_kernel_fn, \ + cpp_dtype, \ + ##__VA_ARGS__); \ + void __PD_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout( \ ::paddle::OpKernelInfo* kernel) +#define PD_KERNEL_INSTANTIATION(meta_kernel_fn, backend, cpp_dtype, ...) \ + _PD_KERNEL_INSTANTIATION(PD_NARGS(cpp_dtype, ##__VA_ARGS__), \ + meta_kernel_fn, \ + backend, \ + cpp_dtype, \ + ##__VA_ARGS__) + +#define _PD_KERNEL_INSTANTIATION(N, meta_kernel_fn, backend, cpp_dtype, ...) \ + PD_CONCATENATE(_PD_KERNEL_INSTANTIATION_, N) \ + (meta_kernel_fn, backend, cpp_dtype, ##__VA_ARGS__) + +#define _PD_KERNEL_INSTANTIATION_1(meta_kernel_fn, backend, cpp_dtype, ...) \ + template decltype(meta_kernel_fn) \ + meta_kernel_fn +#define _PD_KERNEL_INSTANTIATION_2(meta_kernel_fn, backend, cpp_dtype, ...) \ + template decltype(meta_kernel_fn) \ + meta_kernel_fn; \ + PD_EXPAND(_PD_KERNEL_INSTANTIATION_1(meta_kernel_fn, backend, ##__VA_ARGS__)) +#define _PD_KERNEL_INSTANTIATION_3(meta_kernel_fn, backend, cpp_dtype, ...) \ + template decltype(meta_kernel_fn) \ + meta_kernel_fn; \ + PD_EXPAND(_PD_KERNEL_INSTANTIATION_2(meta_kernel_fn, backend, ##__VA_ARGS__)) +#define _PD_KERNEL_INSTANTIATION_4(meta_kernel_fn, backend, cpp_dtype, ...) \ + template decltype(meta_kernel_fn) \ + meta_kernel_fn; \ + PD_EXPAND(_PD_KERNEL_INSTANTIATION_3(meta_kernel_fn, backend, ##__VA_ARGS__)) +#define _PD_KERNEL_INSTANTIATION_5(meta_kernel_fn, backend, cpp_dtype, ...) \ + template decltype(meta_kernel_fn) \ + meta_kernel_fn; \ + PD_EXPAND(_PD_KERNEL_INSTANTIATION_4(meta_kernel_fn, backend, ##__VA_ARGS__)) +#define _PD_KERNEL_INSTANTIATION_6(meta_kernel_fn, backend, cpp_dtype, ...) \ + template decltype(meta_kernel_fn) \ + meta_kernel_fn; \ + PD_EXPAND(_PD_KERNEL_INSTANTIATION_5(meta_kernel_fn, backend, ##__VA_ARGS__)) +#define _PD_KERNEL_INSTANTIATION_7(meta_kernel_fn, backend, cpp_dtype, ...) \ + template decltype(meta_kernel_fn) \ + meta_kernel_fn; \ + PD_EXPAND(_PD_KERNEL_INSTANTIATION_6(meta_kernel_fn, backend, ##__VA_ARGS__)) +#define _PD_KERNEL_INSTANTIATION_8(meta_kernel_fn, backend, cpp_dtype, ...) \ + template decltype(meta_kernel_fn) \ + meta_kernel_fn; \ + PD_EXPAND(_PD_KERNEL_INSTANTIATION_7(meta_kernel_fn, backend, ##__VA_ARGS__)) +#define _PD_KERNEL_INSTANTIATION_9(meta_kernel_fn, backend, cpp_dtype, ...) \ + template decltype(meta_kernel_fn) \ + meta_kernel_fn; \ + PD_EXPAND(_PD_KERNEL_INSTANTIATION_8(meta_kernel_fn, backend, ##__VA_ARGS__)) +#define _PD_KERNEL_INSTANTIATION_10(meta_kernel_fn, backend, cpp_dtype, ...) \ + template decltype(meta_kernel_fn) \ + meta_kernel_fn; \ + PD_EXPAND(_PD_KERNEL_INSTANTIATION_9(meta_kernel_fn, backend, ##__VA_ARGS__)) +#define _PD_KERNEL_INSTANTIATION_11(meta_kernel_fn, backend, cpp_dtype, ...) \ + template decltype(meta_kernel_fn) \ + meta_kernel_fn; \ + PD_EXPAND(_PD_KERNEL_INSTANTIATION_10(meta_kernel_fn, backend, ##__VA_ARGS__)) +#define _PD_KERNEL_INSTANTIATION_12(meta_kernel_fn, backend, cpp_dtype, ...) \ + template decltype(meta_kernel_fn) \ + meta_kernel_fn; \ + PD_EXPAND(_PD_KERNEL_INSTANTIATION_11(meta_kernel_fn, backend, ##__VA_ARGS__)) +#define _PD_KERNEL_INSTANTIATION_13(meta_kernel_fn, backend, cpp_dtype, ...) \ + template decltype(meta_kernel_fn) \ + meta_kernel_fn; \ + PD_EXPAND(_PD_KERNEL_INSTANTIATION_12(meta_kernel_fn, backend, ##__VA_ARGS__)) +#define _PD_KERNEL_INSTANTIATION_14(meta_kernel_fn, backend, cpp_dtype, ...) \ + template decltype(meta_kernel_fn) \ + meta_kernel_fn; \ + PD_EXPAND(_PD_KERNEL_INSTANTIATION_13(meta_kernel_fn, backend, ##__VA_ARGS__)) +#define _PD_KERNEL_INSTANTIATION_15(meta_kernel_fn, backend, cpp_dtype, ...) \ + template decltype(meta_kernel_fn) \ + meta_kernel_fn; \ + PD_EXPAND(_PD_KERNEL_INSTANTIATION_14(meta_kernel_fn, backend, ##__VA_ARGS__)) + +#define PD_KERNEL_REGISTRAR_INIT( \ + kernel_name, backend, layout, args_def_fn, meta_kernel_fn, cpp_dtype, ...) \ + _PD_KERNEL_REGISTRAR_INIT(PD_NARGS(cpp_dtype, ##__VA_ARGS__), \ + kernel_name, \ + backend, \ + layout, \ + args_def_fn, \ + meta_kernel_fn, \ + cpp_dtype, \ + ##__VA_ARGS__) + +// clang-format off + +/* The =pre-commit always treats this macro into the wrong format, + and multi-line macros cannot be skipped with NOLINT.*/ +#define _PD_KERNEL_REGISTRAR_INIT(N, \ + kernel_name, \ + backend, \ + layout, \ + args_def_fn, \ + meta_kernel_fn, \ + cpp_dtype, \ + ...) \ + PD_CONCATENATE(_PD_KERNEL_REGISTRAR_INIT_, N) ( \ + kernel_name, \ + backend, \ + layout, \ + PD_ID, \ + args_def_fn, \ + meta_kernel_fn, \ + cpp_dtype, \ + ##__VA_ARGS__) + +// clang-format on + +#define _PD_KERNEL_REGISTRAR_INIT_1(kernel_name, \ + backend, \ + layout, \ + registrar_id, \ + args_def_fn, \ + meta_kernel_fn, \ + cpp_dtype, \ + ...) \ + static ::paddle::OpKernelInfoBuilder PD_CONCATENATE( \ + custom_kernel_info_##kernel_name##_##backend##_##layout##_, \ + registrar_id) = \ + ::paddle::OpKernelInfoBuilder( \ + #kernel_name, \ + PD_BACKEND(backend), \ + PD_DATALAYOUT(layout), \ + ::paddle::experimental::CppTypeToDataType::Type()) \ + .SetKernelFn(PD_PT_KERNEL( \ + meta_kernel_fn)) \ + .SetVariadicKernelFn(PD_PT_VARIADIC_KERNEL( \ + meta_kernel_fn)) \ + .ArgsParse(PD_PT_ARGS_PARSE( \ + meta_kernel_fn)) \ + .ArgsDef(args_def_fn); + +#define _PD_KERNEL_REGISTRAR_INIT_2(kernel_name, \ + backend, \ + layout, \ + registrar_id, \ + args_def_fn, \ + meta_kernel_fn, \ + cpp_dtype, \ + ...) \ + static ::paddle::OpKernelInfoBuilder PD_CONCATENATE( \ + custom_kernel_info_##kernel_name##_##backend##_##layout##_, \ + registrar_id) = \ + ::paddle::OpKernelInfoBuilder( \ + #kernel_name, \ + PD_BACKEND(backend), \ + PD_DATALAYOUT(layout), \ + ::paddle::experimental::CppTypeToDataType::Type()) \ + .SetKernelFn(PD_PT_KERNEL( \ + meta_kernel_fn)) \ + .SetVariadicKernelFn(PD_PT_VARIADIC_KERNEL( \ + meta_kernel_fn)) \ + .ArgsParse(PD_PT_ARGS_PARSE( \ + meta_kernel_fn)) \ + .ArgsDef(args_def_fn); \ + PD_EXPAND(_PD_KERNEL_REGISTRAR_INIT_1(kernel_name, \ + backend, \ + layout, \ + PD_ID, \ + args_def_fn, \ + meta_kernel_fn, \ + ##__VA_ARGS__)) + +#define _PD_KERNEL_REGISTRAR_INIT_3(kernel_name, \ + backend, \ + layout, \ + registrar_id, \ + args_def_fn, \ + meta_kernel_fn, \ + cpp_dtype, \ + ...) \ + static ::paddle::OpKernelInfoBuilder PD_CONCATENATE( \ + custom_kernel_info_##kernel_name##_##backend##_##layout##_, \ + registrar_id) = \ + ::paddle::OpKernelInfoBuilder( \ + #kernel_name, \ + PD_BACKEND(backend), \ + PD_DATALAYOUT(layout), \ + ::paddle::experimental::CppTypeToDataType::Type()) \ + .SetKernelFn(PD_PT_KERNEL( \ + meta_kernel_fn)) \ + .SetVariadicKernelFn(PD_PT_VARIADIC_KERNEL( \ + meta_kernel_fn)) \ + .ArgsParse(PD_PT_ARGS_PARSE( \ + meta_kernel_fn)) \ + .ArgsDef(args_def_fn); \ + PD_EXPAND(_PD_KERNEL_REGISTRAR_INIT_2(kernel_name, \ + backend, \ + layout, \ + PD_ID, \ + args_def_fn, \ + meta_kernel_fn, \ + ##__VA_ARGS__)) + +#define _PD_KERNEL_REGISTRAR_INIT_4(kernel_name, \ + backend, \ + layout, \ + registrar_id, \ + args_def_fn, \ + meta_kernel_fn, \ + cpp_dtype, \ + ...) \ + static ::paddle::OpKernelInfoBuilder PD_CONCATENATE( \ + custom_kernel_info_##kernel_name##_##backend##_##layout##_, \ + registrar_id) = \ + ::paddle::OpKernelInfoBuilder( \ + #kernel_name, \ + PD_BACKEND(backend), \ + PD_DATALAYOUT(layout), \ + ::paddle::experimental::CppTypeToDataType::Type()) \ + .SetKernelFn(PD_PT_KERNEL( \ + meta_kernel_fn)) \ + .SetVariadicKernelFn(PD_PT_VARIADIC_KERNEL( \ + meta_kernel_fn)) \ + .ArgsParse(PD_PT_ARGS_PARSE( \ + meta_kernel_fn)) \ + .ArgsDef(args_def_fn); \ + PD_EXPAND(_PD_KERNEL_REGISTRAR_INIT_3(kernel_name, \ + backend, \ + layout, \ + PD_ID, \ + args_def_fn, \ + meta_kernel_fn, \ + ##__VA_ARGS__)) + +#define _PD_KERNEL_REGISTRAR_INIT_5(kernel_name, \ + backend, \ + layout, \ + registrar_id, \ + args_def_fn, \ + meta_kernel_fn, \ + cpp_dtype, \ + ...) \ + static ::paddle::OpKernelInfoBuilder PD_CONCATENATE( \ + custom_kernel_info_##kernel_name##_##backend##_##layout##_, \ + registrar_id) = \ + ::paddle::OpKernelInfoBuilder( \ + #kernel_name, \ + PD_BACKEND(backend), \ + PD_DATALAYOUT(layout), \ + ::paddle::experimental::CppTypeToDataType::Type()) \ + .SetKernelFn(PD_PT_KERNEL( \ + meta_kernel_fn)) \ + .SetVariadicKernelFn(PD_PT_VARIADIC_KERNEL( \ + meta_kernel_fn)) \ + .ArgsParse(PD_PT_ARGS_PARSE( \ + meta_kernel_fn)) \ + .ArgsDef(args_def_fn); \ + PD_EXPAND(_PD_KERNEL_REGISTRAR_INIT_4(kernel_name, \ + backend, \ + layout, \ + PD_ID, \ + args_def_fn, \ + meta_kernel_fn, \ + ##__VA_ARGS__)) + +#define _PD_KERNEL_REGISTRAR_INIT_6(kernel_name, \ + backend, \ + layout, \ + registrar_id, \ + args_def_fn, \ + meta_kernel_fn, \ + cpp_dtype, \ + ...) \ + static ::paddle::OpKernelInfoBuilder PD_CONCATENATE( \ + custom_kernel_info_##kernel_name##_##backend##_##layout##_, \ + registrar_id) = \ + ::paddle::OpKernelInfoBuilder( \ + #kernel_name, \ + PD_BACKEND(backend), \ + PD_DATALAYOUT(layout), \ + ::paddle::experimental::CppTypeToDataType::Type()) \ + .SetKernelFn(PD_PT_KERNEL( \ + meta_kernel_fn)) \ + .SetVariadicKernelFn(PD_PT_VARIADIC_KERNEL( \ + meta_kernel_fn)) \ + .ArgsParse(PD_PT_ARGS_PARSE( \ + meta_kernel_fn)) \ + .ArgsDef(args_def_fn); \ + PD_EXPAND(_PD_KERNEL_REGISTRAR_INIT_5(kernel_name, \ + backend, \ + layout, \ + PD_ID, \ + args_def_fn, \ + meta_kernel_fn, \ + ##__VA_ARGS__)) + +#define _PD_KERNEL_REGISTRAR_INIT_7(kernel_name, \ + backend, \ + layout, \ + registrar_id, \ + args_def_fn, \ + meta_kernel_fn, \ + cpp_dtype, \ + ...) \ + static ::paddle::OpKernelInfoBuilder PD_CONCATENATE( \ + custom_kernel_info_##kernel_name##_##backend##_##layout##_, \ + registrar_id) = \ + ::paddle::OpKernelInfoBuilder( \ + #kernel_name, \ + PD_BACKEND(backend), \ + PD_DATALAYOUT(layout), \ + ::paddle::experimental::CppTypeToDataType::Type()) \ + .SetKernelFn(PD_PT_KERNEL( \ + meta_kernel_fn)) \ + .SetVariadicKernelFn(PD_PT_VARIADIC_KERNEL( \ + meta_kernel_fn)) \ + .ArgsParse(PD_PT_ARGS_PARSE( \ + meta_kernel_fn)) \ + .ArgsDef(args_def_fn); \ + PD_EXPAND(_PD_KERNEL_REGISTRAR_INIT_6(kernel_name, \ + backend, \ + layout, \ + PD_ID, \ + args_def_fn, \ + meta_kernel_fn, \ + ##__VA_ARGS__)) + +#define _PD_KERNEL_REGISTRAR_INIT_8(kernel_name, \ + backend, \ + layout, \ + registrar_id, \ + args_def_fn, \ + meta_kernel_fn, \ + cpp_dtype, \ + ...) \ + static ::paddle::OpKernelInfoBuilder PD_CONCATENATE( \ + custom_kernel_info_##kernel_name##_##backend##_##layout##_, \ + registrar_id) = \ + ::paddle::OpKernelInfoBuilder( \ + #kernel_name, \ + PD_BACKEND(backend), \ + PD_DATALAYOUT(layout), \ + ::paddle::experimental::CppTypeToDataType::Type()) \ + .SetKernelFn(PD_PT_KERNEL( \ + meta_kernel_fn)) \ + .SetVariadicKernelFn(PD_PT_VARIADIC_KERNEL( \ + meta_kernel_fn)) \ + .ArgsParse(PD_PT_ARGS_PARSE( \ + meta_kernel_fn)) \ + .ArgsDef(args_def_fn); \ + PD_EXPAND(_PD_KERNEL_REGISTRAR_INIT_7(kernel_name, \ + backend, \ + layout, \ + PD_ID, \ + args_def_fn, \ + meta_kernel_fn, \ + ##__VA_ARGS__)) + +#define _PD_KERNEL_REGISTRAR_INIT_9(kernel_name, \ + backend, \ + layout, \ + registrar_id, \ + args_def_fn, \ + meta_kernel_fn, \ + cpp_dtype, \ + ...) \ + static ::paddle::OpKernelInfoBuilder PD_CONCATENATE( \ + custom_kernel_info_##kernel_name##_##backend##_##layout##_, \ + registrar_id) = \ + ::paddle::OpKernelInfoBuilder( \ + #kernel_name, \ + PD_BACKEND(backend), \ + PD_DATALAYOUT(layout), \ + ::paddle::experimental::CppTypeToDataType::Type()) \ + .SetKernelFn(PD_PT_KERNEL( \ + meta_kernel_fn)) \ + .SetVariadicKernelFn(PD_PT_VARIADIC_KERNEL( \ + meta_kernel_fn)) \ + .ArgsParse(PD_PT_ARGS_PARSE( \ + meta_kernel_fn)) \ + .ArgsDef(args_def_fn); \ + PD_EXPAND(_PD_KERNEL_REGISTRAR_INIT_8(kernel_name, \ + backend, \ + layout, \ + PD_ID, \ + args_def_fn, \ + meta_kernel_fn, \ + ##__VA_ARGS__)) + +#define _PD_KERNEL_REGISTRAR_INIT_10(kernel_name, \ + backend, \ + layout, \ + registrar_id, \ + args_def_fn, \ + meta_kernel_fn, \ + cpp_dtype, \ + ...) \ + static ::paddle::OpKernelInfoBuilder PD_CONCATENATE( \ + custom_kernel_info_##kernel_name##_##backend##_##layout##_, \ + registrar_id) = \ + ::paddle::OpKernelInfoBuilder( \ + #kernel_name, \ + PD_BACKEND(backend), \ + PD_DATALAYOUT(layout), \ + ::paddle::experimental::CppTypeToDataType::Type()) \ + .SetKernelFn(PD_PT_KERNEL( \ + meta_kernel_fn)) \ + .SetVariadicKernelFn(PD_PT_VARIADIC_KERNEL( \ + meta_kernel_fn)) \ + .ArgsParse(PD_PT_ARGS_PARSE( \ + meta_kernel_fn)) \ + .ArgsDef(args_def_fn); \ + PD_EXPAND(_PD_KERNEL_REGISTRAR_INIT_9(kernel_name, \ + backend, \ + layout, \ + PD_ID, \ + args_def_fn, \ + meta_kernel_fn, \ + ##__VA_ARGS__)) + +#define _PD_KERNEL_REGISTRAR_INIT_11(kernel_name, \ + backend, \ + layout, \ + registrar_id, \ + args_def_fn, \ + meta_kernel_fn, \ + cpp_dtype, \ + ...) \ + static ::paddle::OpKernelInfoBuilder PD_CONCATENATE( \ + custom_kernel_info_##kernel_name##_##backend##_##layout##_, \ + registrar_id) = \ + ::paddle::OpKernelInfoBuilder( \ + #kernel_name, \ + PD_BACKEND(backend), \ + PD_DATALAYOUT(layout), \ + ::paddle::experimental::CppTypeToDataType::Type()) \ + .SetKernelFn(PD_PT_KERNEL( \ + meta_kernel_fn)) \ + .SetVariadicKernelFn(PD_PT_VARIADIC_KERNEL( \ + meta_kernel_fn)) \ + .ArgsParse(PD_PT_ARGS_PARSE( \ + meta_kernel_fn)) \ + .ArgsDef(args_def_fn); \ + PD_EXPAND(_PD_KERNEL_REGISTRAR_INIT_10(kernel_name, \ + backend, \ + layout, \ + PD_ID, \ + args_def_fn, \ + meta_kernel_fn, \ + ##__VA_ARGS__)) + +#define _PD_KERNEL_REGISTRAR_INIT_12(kernel_name, \ + backend, \ + layout, \ + registrar_id, \ + args_def_fn, \ + meta_kernel_fn, \ + cpp_dtype, \ + ...) \ + static ::paddle::OpKernelInfoBuilder PD_CONCATENATE( \ + custom_kernel_info_##kernel_name##_##backend##_##layout##_, \ + registrar_id) = \ + ::paddle::OpKernelInfoBuilder( \ + #kernel_name, \ + PD_BACKEND(backend), \ + PD_DATALAYOUT(layout), \ + ::paddle::experimental::CppTypeToDataType::Type()) \ + .SetKernelFn(PD_PT_KERNEL( \ + meta_kernel_fn)) \ + .SetVariadicKernelFn(PD_PT_VARIADIC_KERNEL( \ + meta_kernel_fn)) \ + .ArgsParse(PD_PT_ARGS_PARSE( \ + meta_kernel_fn)) \ + .ArgsDef(args_def_fn); \ + PD_EXPAND(_PD_KERNEL_REGISTRAR_INIT_11(kernel_name, \ + backend, \ + layout, \ + PD_ID, \ + args_def_fn, \ + meta_kernel_fn, \ + ##__VA_ARGS__)) + +#define _PD_KERNEL_REGISTRAR_INIT_13(kernel_name, \ + backend, \ + layout, \ + registrar_id, \ + args_def_fn, \ + meta_kernel_fn, \ + cpp_dtype, \ + ...) \ + static ::paddle::OpKernelInfoBuilder PD_CONCATENATE( \ + custom_kernel_info_##kernel_name##_##backend##_##layout##_, \ + registrar_id) = \ + ::paddle::OpKernelInfoBuilder( \ + #kernel_name, \ + PD_BACKEND(backend), \ + PD_DATALAYOUT(layout), \ + ::paddle::experimental::CppTypeToDataType::Type()) \ + .SetKernelFn(PD_PT_KERNEL( \ + meta_kernel_fn)) \ + .SetVariadicKernelFn(PD_PT_VARIADIC_KERNEL( \ + meta_kernel_fn)) \ + .ArgsParse(PD_PT_ARGS_PARSE( \ + meta_kernel_fn)) \ + .ArgsDef(args_def_fn); \ + PD_EXPAND(_PD_KERNEL_REGISTRAR_INIT_12(kernel_name, \ + backend, \ + layout, \ + PD_ID, \ + args_def_fn, \ + meta_kernel_fn, \ + ##__VA_ARGS__)) + +#define _PD_KERNEL_REGISTRAR_INIT_14(kernel_name, \ + backend, \ + layout, \ + registrar_id, \ + args_def_fn, \ + meta_kernel_fn, \ + cpp_dtype, \ + ...) \ + static ::paddle::OpKernelInfoBuilder PD_CONCATENATE( \ + custom_kernel_info_##kernel_name##_##backend##_##layout##_, \ + registrar_id) = \ + ::paddle::OpKernelInfoBuilder( \ + #kernel_name, \ + PD_BACKEND(backend), \ + PD_DATALAYOUT(layout), \ + ::paddle::experimental::CppTypeToDataType::Type()) \ + .SetKernelFn(PD_PT_KERNEL( \ + meta_kernel_fn)) \ + .SetVariadicKernelFn(PD_PT_VARIADIC_KERNEL( \ + meta_kernel_fn)) \ + .ArgsParse(PD_PT_ARGS_PARSE( \ + meta_kernel_fn)) \ + .ArgsDef(args_def_fn); \ + PD_EXPAND(_PD_KERNEL_REGISTRAR_INIT_13(kernel_name, \ + backend, \ + layout, \ + PD_ID, \ + args_def_fn, \ + meta_kernel_fn, \ + ##__VA_ARGS__)) + +#define _PD_KERNEL_REGISTRAR_INIT_15(kernel_name, \ + backend, \ + layout, \ + registrar_id, \ + args_def_fn, \ + meta_kernel_fn, \ + cpp_dtype, \ + ...) \ + static ::paddle::OpKernelInfoBuilder PD_CONCATENATE( \ + custom_kernel_info_##kernel_name##_##backend##_##layout##_, \ + registrar_id) = \ + ::paddle::OpKernelInfoBuilder( \ + #kernel_name, \ + PD_BACKEND(backend), \ + PD_DATALAYOUT(layout), \ + ::paddle::experimental::CppTypeToDataType::Type()) \ + .SetKernelFn(PD_PT_KERNEL( \ + meta_kernel_fn)) \ + .SetVariadicKernelFn(PD_PT_VARIADIC_KERNEL( \ + meta_kernel_fn)) \ + .ArgsParse(PD_PT_ARGS_PARSE( \ + meta_kernel_fn)) \ + .ArgsDef(args_def_fn); \ + PD_EXPAND(_PD_KERNEL_REGISTRAR_INIT_14(kernel_name, \ + backend, \ + layout, \ + PD_ID, \ + args_def_fn, \ + meta_kernel_fn, \ + ##__VA_ARGS__)) } // namespace paddle diff --git a/python/paddle/fluid/tests/custom_kernel/custom_kernel_dot.cc b/python/paddle/fluid/tests/custom_kernel/custom_kernel_dot.cc index e61b7314ef61cfd3b8587f4fb53796815da22180..3ae30c2f30577a683544014e1fbd5c93039351ef 100644 --- a/python/paddle/fluid/tests/custom_kernel/custom_kernel_dot.cc +++ b/python/paddle/fluid/tests/custom_kernel/custom_kernel_dot.cc @@ -20,8 +20,8 @@ namespace custom_kernel { // Here we use dot for test // This test will fail when this kernel is supported in framework -template -void Dot(const paddle::CPUContext& dev_ctx, +template +void Dot(const Context& dev_ctx, const paddle::Tensor& x, const paddle::Tensor& y, paddle::Tensor* out) { @@ -45,9 +45,6 @@ void Dot(const paddle::CPUContext& dev_ctx, } // namespace custom_kernel } // namespace paddle -PD_REGISTER_KERNEL( - dot, CPU, ALL_LAYOUT, INT8, paddle::custom_kernel::Dot) { - /* do some args define here - * the only param can be used is OpKernelInfo* kernel */ +PD_REGISTER_KERNEL(dot, CPU, ALL_LAYOUT, paddle::custom_kernel::Dot, int8_t) { kernel->OutputAt(0).SetDataType(paddle::experimental::DataType::INT8); }