未验证 提交 63d2333e 编写于 作者: A Aganlengzi 提交者: GitHub

[PluggableDevice] custom kernel supports multi cpp_dtype registering (#39385)

上级 2a5d858c
...@@ -35,13 +35,12 @@ limitations under the License. */ ...@@ -35,13 +35,12 @@ limitations under the License. */
// user kernel function // user kernel function
namespace custom_kernel { namespace custom_kernel {
// Here we use dot <CPU, ANY, UINT8> for test // Here we use fake_dot for test
// This test will fail when these two kernels are supported in framework
// input 3: two Tensors and one std::vector<Tensor> // input 3: two Tensors and one std::vector<Tensor>
// attribute 11: fake_attributes // attribute 11: fake_attributes
// output 2: one Tensor* and one std::vector<Tensor*> // output 2: one Tensor* and one std::vector<Tensor*>
template <typename T> template <typename T, typename Context>
void FakeDot(const paddle::CPUContext& dev_ctx, const paddle::Tensor& x, void FakeDot(const Context& dev_ctx, const paddle::Tensor& x,
const paddle::Tensor& y, const paddle::Tensor& y,
const std::vector<paddle::Tensor>& fake_input_vec, const std::vector<paddle::Tensor>& fake_input_vec,
bool fake_attr_bool, int fake_attr_int, float fake_attr_float, bool fake_attr_bool, int fake_attr_int, float fake_attr_float,
...@@ -93,53 +92,91 @@ void FakeDot(const paddle::CPUContext& dev_ctx, const paddle::Tensor& x, ...@@ -93,53 +92,91 @@ void FakeDot(const paddle::CPUContext& dev_ctx, const paddle::Tensor& x,
} }
} // namespace custom_kernel } // namespace custom_kernel
PD_REGISTER_KERNEL(dot, CPU, ALL_LAYOUT, UINT8, PD_REGISTER_KERNEL(fake_dot, CPU, ALL_LAYOUT, custom_kernel::FakeDot, float,
custom_kernel::FakeDot<uint8_t>) { double, int, int64_t, int8_t, uint8_t) {}
/* do some args define here
* the only param can be used is OpKernelInfo* kernel */
kernel->OutputAt(0).SetDataType(paddle::experimental::DataType::UINT8);
}
// Upper code will store dot kernels info into OpKernelInfoMap // Upper code will store dot kernels info into OpKernelInfoMap
TEST(CustomKernel, custom_kernel_dot) { TEST(CustomKernel, custom_kernel_dot) {
std::string op_name = "dot"; std::string op_name = "fake_dot";
pten::Backend backend = pten::Backend::CPU; pten::Backend backend = pten::Backend::CPU;
pten::DataLayout layout = pten::DataLayout::ANY; pten::DataLayout layout = pten::DataLayout::ALL_LAYOUT;
pten::DataType dtype = pten::DataType::UINT8;
// 1.custom kernel info parsed and store // 1.custom kernel info parsed and store
EXPECT_TRUE(paddle::OpKernelInfoMap::Instance().GetMap().find("dot") != EXPECT_TRUE(paddle::OpKernelInfoMap::Instance().GetMap().find(op_name) !=
paddle::OpKernelInfoMap::Instance().GetMap().end()); paddle::OpKernelInfoMap::Instance().GetMap().end());
// 2.info check // 2.info check
EXPECT_EQ( EXPECT_EQ(
1, static_cast<int>(paddle::OpKernelInfoMap::Instance()["dot"].size())); 6, static_cast<int>(paddle::OpKernelInfoMap::Instance()[op_name].size()));
EXPECT_TRUE(paddle::OpKernelInfoMap::Instance()["dot"][0].GetBackend() == // index 0
EXPECT_TRUE(paddle::OpKernelInfoMap::Instance()[op_name][0].GetBackend() ==
backend); backend);
EXPECT_TRUE(paddle::OpKernelInfoMap::Instance()["dot"][0].GetDataLayout() == EXPECT_TRUE(paddle::OpKernelInfoMap::Instance()[op_name][0].GetDataLayout() ==
layout); layout);
EXPECT_TRUE(paddle::OpKernelInfoMap::Instance()["dot"][0].GetDataType() == EXPECT_TRUE(paddle::OpKernelInfoMap::Instance()[op_name][0].GetDataType() ==
dtype); pten::DataType::FLOAT32);
// index 5
// 3.register EXPECT_TRUE(paddle::OpKernelInfoMap::Instance()[op_name][5].GetBackend() ==
EXPECT_TRUE(pten::KernelFactory::Instance().kernels().end() != backend);
pten::KernelFactory::Instance().kernels().find("dot")); EXPECT_TRUE(paddle::OpKernelInfoMap::Instance()[op_name][5].GetDataLayout() ==
layout);
pten::KernelKey kernel_key(backend, layout, dtype); EXPECT_TRUE(paddle::OpKernelInfoMap::Instance()[op_name][5].GetDataType() ==
EXPECT_TRUE( pten::DataType::UINT8);
pten::KernelFactory::Instance().kernels()["dot"].find(kernel_key) ==
pten::KernelFactory::Instance().kernels()["dot"].end()); // 3.before register
auto& kernel_factory_instance = pten::KernelFactory::Instance();
auto& kernels = pten::KernelFactory::Instance().kernels();
EXPECT_TRUE(!kernel_factory_instance.HasCompatiblePtenKernel(op_name));
// mock fake_dot is supported by pten for HasCompatiblePtenKernel check while
// registering
auto& fake_dot_kernels = kernels[op_name];
EXPECT_TRUE(fake_dot_kernels.find(
pten::KernelKey(backend, layout, pten::DataType::FLOAT32)) ==
fake_dot_kernels.end());
EXPECT_TRUE(fake_dot_kernels.find(
pten::KernelKey(backend, layout, pten::DataType::FLOAT64)) ==
fake_dot_kernels.end());
EXPECT_TRUE(fake_dot_kernels.find(
pten::KernelKey(backend, layout, pten::DataType::INT32)) ==
fake_dot_kernels.end());
EXPECT_TRUE(fake_dot_kernels.find(
pten::KernelKey(backend, layout, pten::DataType::INT64)) ==
fake_dot_kernels.end());
EXPECT_TRUE(fake_dot_kernels.find(
pten::KernelKey(backend, layout, pten::DataType::INT8)) ==
fake_dot_kernels.end());
EXPECT_TRUE(fake_dot_kernels.find(
pten::KernelKey(backend, layout, pten::DataType::UINT8)) ==
fake_dot_kernels.end());
// register
paddle::framework::RegisterKernelWithMetaInfoMap( paddle::framework::RegisterKernelWithMetaInfoMap(
paddle::OpKernelInfoMap::Instance()); paddle::OpKernelInfoMap::Instance());
EXPECT_TRUE( EXPECT_TRUE(fake_dot_kernels.find(
pten::KernelFactory::Instance().kernels()["dot"].find(kernel_key) != pten::KernelKey(backend, layout, pten::DataType::FLOAT32)) !=
pten::KernelFactory::Instance().kernels()["dot"].end()); fake_dot_kernels.end());
EXPECT_TRUE(fake_dot_kernels.find(
pten::KernelKey(backend, layout, pten::DataType::FLOAT64)) !=
fake_dot_kernels.end());
EXPECT_TRUE(fake_dot_kernels.find(
pten::KernelKey(backend, layout, pten::DataType::INT32)) !=
fake_dot_kernels.end());
EXPECT_TRUE(fake_dot_kernels.find(
pten::KernelKey(backend, layout, pten::DataType::INT64)) !=
fake_dot_kernels.end());
EXPECT_TRUE(fake_dot_kernels.find(
pten::KernelKey(backend, layout, pten::DataType::INT8)) !=
fake_dot_kernels.end());
EXPECT_TRUE(fake_dot_kernels.find(
pten::KernelKey(backend, layout, pten::DataType::UINT8)) !=
fake_dot_kernels.end());
// 4.kernel select // 4.kernel select
auto kernel = pten::KernelFactory::Instance().SelectKernelOrThrowError( auto kernel = kernel_factory_instance.SelectKernelOrThrowError(
op_name, kernel_key); op_name, pten::KernelKey(backend, layout, pten::DataType::UINT8));
// 5.prepare parameters for kernel // 5.prepare parameters for kernel
const auto alloc = std::make_unique<paddle::experimental::DefaultAllocator>( const auto alloc = std::make_unique<paddle::experimental::DefaultAllocator>(
...@@ -252,10 +289,10 @@ TEST(CustomKernel, custom_kernel_dot) { ...@@ -252,10 +289,10 @@ TEST(CustomKernel, custom_kernel_dot) {
// test OpKernelInfoHelper // test OpKernelInfoHelper
TEST(OpKernelInfoHelper, op_kernel_info_help_getters) { TEST(OpKernelInfoHelper, op_kernel_info_help_getters) {
using OpKernelInfoHelper = paddle::framework::OpKernelInfoHelper; using OpKernelInfoHelper = paddle::framework::OpKernelInfoHelper;
std::string op_name = "dot"; std::string op_name = "fake_dot";
pten::Backend backend = pten::Backend::CPU; pten::Backend backend = pten::Backend::CPU;
pten::DataLayout layout = pten::DataLayout::ANY; pten::DataLayout layout = pten::DataLayout::ANY;
pten::DataType dtype = pten::DataType::UINT8; pten::DataType dtype = pten::DataType::FLOAT32;
auto op_kernel_info = paddle::OpKernelInfoMap::Instance()[op_name][0]; auto op_kernel_info = paddle::OpKernelInfoMap::Instance()[op_name][0];
...@@ -268,10 +305,11 @@ TEST(OpKernelInfoHelper, op_kernel_info_help_getters) { ...@@ -268,10 +305,11 @@ TEST(OpKernelInfoHelper, op_kernel_info_help_getters) {
OpKernelInfoHelper::GetKernelKey(op_kernel_info)); OpKernelInfoHelper::GetKernelKey(op_kernel_info));
paddle::CustomKernelFunc kernel_fn = paddle::CustomKernelFunc kernel_fn =
PD_PT_KERNEL(custom_kernel::FakeDot<uint8_t>); PD_PT_KERNEL(custom_kernel::FakeDot<float, paddle::CPUContext>);
EXPECT_EQ(kernel_fn, OpKernelInfoHelper::GetKernelFn(op_kernel_info)); EXPECT_EQ(kernel_fn, OpKernelInfoHelper::GetKernelFn(op_kernel_info));
void* variadic_func = PD_PT_VARIADIC_KERNEL(custom_kernel::FakeDot<uint8_t>); void* variadic_func =
PD_PT_VARIADIC_KERNEL(custom_kernel::FakeDot<float, paddle::CPUContext>);
EXPECT_EQ(variadic_func, EXPECT_EQ(variadic_func,
OpKernelInfoHelper::GetVariadicKernelFn(op_kernel_info)); OpKernelInfoHelper::GetVariadicKernelFn(op_kernel_info));
......
...@@ -30,6 +30,8 @@ limitations under the License. */ ...@@ -30,6 +30,8 @@ limitations under the License. */
#include "paddle/utils/any.h" #include "paddle/utils/any.h"
#include "paddle/utils/small_vector.h" #include "paddle/utils/small_vector.h"
#include "paddle/pten/common/data_type.h"
/** /**
* Custom Kernel Info Define. * Custom Kernel Info Define.
* *
...@@ -635,29 +637,624 @@ void RegisterAllCustomKernel(); ...@@ -635,29 +637,624 @@ void RegisterAllCustomKernel();
// register custom kernels // register custom kernels
void LoadCustomKernelLib(const std::string& dso_name); void LoadCustomKernelLib(const std::string& dso_name);
//////////////// Custom kernel register macro ///////////////// //////////////// Custom kernel register macro /////////////////////
// See paddle/pten/core/kernel_registry.h for the framework's own macros. We
// cannot use PT_REGISTER_KERNEL directly because its common macros and
// functions are not yet available to custom kernels.
// Difference: custom_kernel first stores every kernel's info in the global
// g_custom_kernel_info_map, and only afterwards loads and registers it into
// pten kernel management. Only PD_REGISTER_KERNEL is provided, and it
// supports 2 template arguments.
#define PD_BACKEND(arg__) pten::Backend::arg__ #define PD_BACKEND(arg__) pten::Backend::arg__
#define PD_DATALAYOUT(arg__) pten::DataLayout::arg__ #define PD_DATALAYOUT(arg__) pten::DataLayout::arg__
#define PD_DATATYPE(arg__) pten::DataType::arg__ #define PD_DATATYPE(arg__) pten::DataType::arg__
#define PD_REGISTER_KERNEL(name, backend, layout, dtype, func) \ #define PD_NARGS(...) _PD_NARGS((__VA_ARGS__, _PD_RESQ_N()))
STATIC_ASSERT_GLOBAL_NAMESPACE( \ #define _PD_NARGS(...) _PD_ARG_N(__VA_ARGS__)
__reg_kernel__##name##_##backend##_##layout##_##dtype, \ #define _PD_ARG_N_EXPAND( \
"PD_REGISTER_KERNEL must be called in global namespace."); \ _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, N, ...) \
void __PD_USER_args_def_##name##_##backend##_##layout_##dtype( \ N
::paddle::OpKernelInfo* op_kernel_info); \ #define _PD_ARG_N(args) _PD_ARG_N_EXPAND args
static ::paddle::OpKernelInfoBuilder \ #define _PD_RESQ_N() 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0
__op_kernel_info_##name##_##backend##_##layout##_##dtype = \
::paddle::OpKernelInfoBuilder(#name, \ #define PD_CONCATENATE(arg1, arg2) PD_CONCATENATE1(arg1, arg2)
PD_BACKEND(backend), \ #define PD_CONCATENATE1(arg1, arg2) PD_CONCATENATE2(arg1, arg2)
PD_DATALAYOUT(layout), \ #define PD_CONCATENATE2(arg1, arg2) arg1##arg2
PD_DATATYPE(dtype)) \
.SetKernelFn(PD_PT_KERNEL(func)) \ #define PD_EXPAND(x) x
.SetVariadicKernelFn(PD_PT_VARIADIC_KERNEL(func)) \
.ArgsParse(PD_PT_ARGS_PARSE(func)) \ #ifdef __COUNTER__
.ArgsDef( \ #define PD_ID __COUNTER__
&__PD_USER_args_def_##name##_##backend##_##layout_##dtype); \ #else
void __PD_USER_args_def_##name##_##backend##_##layout_##dtype( \ #define PD_ID __LINE__
#endif
// Public entry point for registering a custom kernel for one or more C++
// data types.
//   kernel_name    : kernel identifier token; pasted into generated names
//   backend        : backend token (e.g. CPU); pasted into generated names
//   layout         : data layout token (e.g. ALL_LAYOUT)
//   func           : name of the kernel function template
//   cpp_dtype, ... : one or more C++ data types to register
// The macro first emits a global-namespace check and then forwards every
// argument unchanged to _PD_REGISTER_2TA_KERNEL.
// NOTE: the namespace-check identifier is built only from kernel_name,
// backend and layout, so it does not vary with the listed data types.
#define PD_REGISTER_KERNEL(kernel_name, backend, layout, func, cpp_dtype, ...) \
STATIC_ASSERT_GLOBAL_NAMESPACE( \
_reg_custom_kernel_ns_check_##kernel_name##_##backend##_##layout, \
"PD_REGISTER_KERNEL must be called in global namespace."); \
_PD_REGISTER_2TA_KERNEL( \
kernel_name, backend, layout, func, cpp_dtype, ##__VA_ARGS__)
// WIN32 is not supported
#define _PD_REGISTER_2TA_KERNEL( \
kernel_name, backend, layout, meta_kernel_fn, cpp_dtype, ...) \
PD_KERNEL_INSTANTIATION(meta_kernel_fn, backend, cpp_dtype, ##__VA_ARGS__); \
static void __PD_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout( \
::paddle::OpKernelInfo* kernel); \
PD_KERNEL_REGISTRAR_INIT( \
kernel_name, \
backend, \
layout, \
&__PD_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout, \
meta_kernel_fn, \
cpp_dtype, \
##__VA_ARGS__); \
void __PD_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout( \
::paddle::OpKernelInfo* kernel) ::paddle::OpKernelInfo* kernel)
// Entry point for explicitly instantiating `meta_kernel_fn` for every listed
// C++ data type. PD_NARGS counts the data types, and that count is passed as
// N to _PD_KERNEL_INSTANTIATION together with the original arguments.
#define PD_KERNEL_INSTANTIATION(meta_kernel_fn, backend, cpp_dtype, ...) \
_PD_KERNEL_INSTANTIATION(PD_NARGS(cpp_dtype, ##__VA_ARGS__), \
meta_kernel_fn, \
backend, \
cpp_dtype, \
##__VA_ARGS__)
// Pastes N onto `_PD_KERNEL_INSTANTIATION_` (via PD_CONCATENATE, so N is
// expanded to its numeric value first) and invokes the resulting
// _PD_KERNEL_INSTANTIATION_<N> macro with the remaining arguments.
#define _PD_KERNEL_INSTANTIATION(N, meta_kernel_fn, backend, cpp_dtype, ...) \
PD_CONCATENATE(_PD_KERNEL_INSTANTIATION_, N) \
(meta_kernel_fn, backend, cpp_dtype, ##__VA_ARGS__)
// Emits a single explicit instantiation of
// `meta_kernel_fn<cpp_dtype, ::paddle::<backend>Context>`.
// Deliberately has no trailing semicolon: the caller supplies it, exactly as
// the previously duplicated text in every _PD_KERNEL_INSTANTIATION_<N> did.
#define _PD_KERNEL_INSTANTIATION_ONE(meta_kernel_fn, backend, cpp_dtype) \
  template decltype(meta_kernel_fn<cpp_dtype, ::paddle::backend##Context>) \
      meta_kernel_fn<cpp_dtype, ::paddle::backend##Context>

// _PD_KERNEL_INSTANTIATION_<N> instantiates the kernel for the first data
// type (cpp_dtype) and then hands the remaining N-1 data types to
// _PD_KERNEL_INSTANTIATION_<N-1>. Distinct macro names are required per level
// because a macro cannot re-expand itself; PD_EXPAND forces the nested
// invocation to be rescanned. _PD_KERNEL_INSTANTIATION_1 ends the chain and
// leaves the final semicolon to the caller.
#define _PD_KERNEL_INSTANTIATION_1(meta_kernel_fn, backend, cpp_dtype, ...) \
  _PD_KERNEL_INSTANTIATION_ONE(meta_kernel_fn, backend, cpp_dtype)
#define _PD_KERNEL_INSTANTIATION_2(meta_kernel_fn, backend, cpp_dtype, ...) \
  _PD_KERNEL_INSTANTIATION_ONE(meta_kernel_fn, backend, cpp_dtype); \
  PD_EXPAND(_PD_KERNEL_INSTANTIATION_1(meta_kernel_fn, backend, ##__VA_ARGS__))
#define _PD_KERNEL_INSTANTIATION_3(meta_kernel_fn, backend, cpp_dtype, ...) \
  _PD_KERNEL_INSTANTIATION_ONE(meta_kernel_fn, backend, cpp_dtype); \
  PD_EXPAND(_PD_KERNEL_INSTANTIATION_2(meta_kernel_fn, backend, ##__VA_ARGS__))
#define _PD_KERNEL_INSTANTIATION_4(meta_kernel_fn, backend, cpp_dtype, ...) \
  _PD_KERNEL_INSTANTIATION_ONE(meta_kernel_fn, backend, cpp_dtype); \
  PD_EXPAND(_PD_KERNEL_INSTANTIATION_3(meta_kernel_fn, backend, ##__VA_ARGS__))
#define _PD_KERNEL_INSTANTIATION_5(meta_kernel_fn, backend, cpp_dtype, ...) \
  _PD_KERNEL_INSTANTIATION_ONE(meta_kernel_fn, backend, cpp_dtype); \
  PD_EXPAND(_PD_KERNEL_INSTANTIATION_4(meta_kernel_fn, backend, ##__VA_ARGS__))
#define _PD_KERNEL_INSTANTIATION_6(meta_kernel_fn, backend, cpp_dtype, ...) \
  _PD_KERNEL_INSTANTIATION_ONE(meta_kernel_fn, backend, cpp_dtype); \
  PD_EXPAND(_PD_KERNEL_INSTANTIATION_5(meta_kernel_fn, backend, ##__VA_ARGS__))
#define _PD_KERNEL_INSTANTIATION_7(meta_kernel_fn, backend, cpp_dtype, ...) \
  _PD_KERNEL_INSTANTIATION_ONE(meta_kernel_fn, backend, cpp_dtype); \
  PD_EXPAND(_PD_KERNEL_INSTANTIATION_6(meta_kernel_fn, backend, ##__VA_ARGS__))
#define _PD_KERNEL_INSTANTIATION_8(meta_kernel_fn, backend, cpp_dtype, ...) \
  _PD_KERNEL_INSTANTIATION_ONE(meta_kernel_fn, backend, cpp_dtype); \
  PD_EXPAND(_PD_KERNEL_INSTANTIATION_7(meta_kernel_fn, backend, ##__VA_ARGS__))
#define _PD_KERNEL_INSTANTIATION_9(meta_kernel_fn, backend, cpp_dtype, ...) \
  _PD_KERNEL_INSTANTIATION_ONE(meta_kernel_fn, backend, cpp_dtype); \
  PD_EXPAND(_PD_KERNEL_INSTANTIATION_8(meta_kernel_fn, backend, ##__VA_ARGS__))
#define _PD_KERNEL_INSTANTIATION_10(meta_kernel_fn, backend, cpp_dtype, ...) \
  _PD_KERNEL_INSTANTIATION_ONE(meta_kernel_fn, backend, cpp_dtype); \
  PD_EXPAND(_PD_KERNEL_INSTANTIATION_9(meta_kernel_fn, backend, ##__VA_ARGS__))
#define _PD_KERNEL_INSTANTIATION_11(meta_kernel_fn, backend, cpp_dtype, ...) \
  _PD_KERNEL_INSTANTIATION_ONE(meta_kernel_fn, backend, cpp_dtype); \
  PD_EXPAND(_PD_KERNEL_INSTANTIATION_10(meta_kernel_fn, backend, ##__VA_ARGS__))
#define _PD_KERNEL_INSTANTIATION_12(meta_kernel_fn, backend, cpp_dtype, ...) \
  _PD_KERNEL_INSTANTIATION_ONE(meta_kernel_fn, backend, cpp_dtype); \
  PD_EXPAND(_PD_KERNEL_INSTANTIATION_11(meta_kernel_fn, backend, ##__VA_ARGS__))
#define _PD_KERNEL_INSTANTIATION_13(meta_kernel_fn, backend, cpp_dtype, ...) \
  _PD_KERNEL_INSTANTIATION_ONE(meta_kernel_fn, backend, cpp_dtype); \
  PD_EXPAND(_PD_KERNEL_INSTANTIATION_12(meta_kernel_fn, backend, ##__VA_ARGS__))
#define _PD_KERNEL_INSTANTIATION_14(meta_kernel_fn, backend, cpp_dtype, ...) \
  _PD_KERNEL_INSTANTIATION_ONE(meta_kernel_fn, backend, cpp_dtype); \
  PD_EXPAND(_PD_KERNEL_INSTANTIATION_13(meta_kernel_fn, backend, ##__VA_ARGS__))
#define _PD_KERNEL_INSTANTIATION_15(meta_kernel_fn, backend, cpp_dtype, ...) \
  _PD_KERNEL_INSTANTIATION_ONE(meta_kernel_fn, backend, cpp_dtype); \
  PD_EXPAND(_PD_KERNEL_INSTANTIATION_14(meta_kernel_fn, backend, ##__VA_ARGS__))
// Entry point for creating the registration objects of a custom kernel, one
// per listed C++ data type. PD_NARGS counts the data types, and that count is
// passed as N to _PD_KERNEL_REGISTRAR_INIT together with the original
// arguments.
//   args_def_fn    : forwarded unchanged; the _PD_KERNEL_REGISTRAR_INIT_<N>
//                    macros hand it to the builder's ArgsDef(...)
//   meta_kernel_fn : name of the kernel function template
//   cpp_dtype, ... : one or more C++ data types to register
#define PD_KERNEL_REGISTRAR_INIT( \
kernel_name, backend, layout, args_def_fn, meta_kernel_fn, cpp_dtype, ...) \
_PD_KERNEL_REGISTRAR_INIT(PD_NARGS(cpp_dtype, ##__VA_ARGS__), \
kernel_name, \
backend, \
layout, \
args_def_fn, \
meta_kernel_fn, \
cpp_dtype, \
##__VA_ARGS__)
// clang-format off
/* pre-commit always reformats this macro incorrectly, and multi-line macros
   cannot be skipped with NOLINT, so clang-format is switched off around it.

   Pastes N onto `_PD_KERNEL_REGISTRAR_INIT_` (via PD_CONCATENATE) and invokes
   the resulting _PD_KERNEL_REGISTRAR_INIT_<N> macro. PD_ID is inserted after
   `layout` as the registrar id of the first registration object; every other
   argument is forwarded unchanged. */
#define _PD_KERNEL_REGISTRAR_INIT(N, \
kernel_name, \
backend, \
layout, \
args_def_fn, \
meta_kernel_fn, \
cpp_dtype, \
...) \
PD_CONCATENATE(_PD_KERNEL_REGISTRAR_INIT_, N) ( \
kernel_name, \
backend, \
layout, \
PD_ID, \
args_def_fn, \
meta_kernel_fn, \
cpp_dtype, \
##__VA_ARGS__)
// clang-format on
// Defines one static ::paddle::OpKernelInfoBuilder for a single
// (kernel_name, backend, layout, cpp_dtype) combination.
//   registrar_id   : unique suffix (a PD_ID value) pasted onto the variable
//                    name so several data types can share one kernel_name
//   args_def_fn    : pointer handed to the builder's ArgsDef(...)
//   meta_kernel_fn : kernel function template; it is used here as
//                    meta_kernel_fn<cpp_dtype, ::paddle::<backend>Context>
// The data type enum is derived from cpp_dtype through
// ::paddle::experimental::CppTypeToDataType. The expansion ends with its own
// semicolon, exactly as the previously duplicated text in every
// _PD_KERNEL_REGISTRAR_INIT_<N> did.
#define _PD_KERNEL_REGISTRAR_INIT_ONE(kernel_name, \
                                      backend, \
                                      layout, \
                                      registrar_id, \
                                      args_def_fn, \
                                      meta_kernel_fn, \
                                      cpp_dtype) \
  static ::paddle::OpKernelInfoBuilder PD_CONCATENATE( \
      custom_kernel_info_##kernel_name##_##backend##_##layout##_, \
      registrar_id) = \
      ::paddle::OpKernelInfoBuilder( \
          #kernel_name, \
          PD_BACKEND(backend), \
          PD_DATALAYOUT(layout), \
          ::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type()) \
          .SetKernelFn(PD_PT_KERNEL( \
              meta_kernel_fn<cpp_dtype, ::paddle::backend##Context>)) \
          .SetVariadicKernelFn(PD_PT_VARIADIC_KERNEL( \
              meta_kernel_fn<cpp_dtype, ::paddle::backend##Context>)) \
          .ArgsParse(PD_PT_ARGS_PARSE( \
              meta_kernel_fn<cpp_dtype, ::paddle::backend##Context>)) \
          .ArgsDef(args_def_fn);

// _PD_KERNEL_REGISTRAR_INIT_<N> registers the first data type (cpp_dtype)
// under `registrar_id`, then hands the remaining N-1 data types to
// _PD_KERNEL_REGISTRAR_INIT_<N-1> with a fresh PD_ID so every generated
// variable gets its own suffix. Distinct macro names are required per level
// because a macro cannot re-expand itself; PD_EXPAND forces the nested
// invocation to be rescanned. _PD_KERNEL_REGISTRAR_INIT_1 ends the chain.
#define _PD_KERNEL_REGISTRAR_INIT_1( \
    kernel_name, backend, layout, registrar_id, args_def_fn, \
    meta_kernel_fn, cpp_dtype, ...) \
  _PD_KERNEL_REGISTRAR_INIT_ONE(kernel_name, backend, layout, registrar_id, \
                                args_def_fn, meta_kernel_fn, cpp_dtype)
#define _PD_KERNEL_REGISTRAR_INIT_2( \
    kernel_name, backend, layout, registrar_id, args_def_fn, \
    meta_kernel_fn, cpp_dtype, ...) \
  _PD_KERNEL_REGISTRAR_INIT_ONE(kernel_name, backend, layout, registrar_id, \
                                args_def_fn, meta_kernel_fn, cpp_dtype) \
  PD_EXPAND(_PD_KERNEL_REGISTRAR_INIT_1(kernel_name, backend, layout, PD_ID, \
                                        args_def_fn, meta_kernel_fn, \
                                        ##__VA_ARGS__))
#define _PD_KERNEL_REGISTRAR_INIT_3( \
    kernel_name, backend, layout, registrar_id, args_def_fn, \
    meta_kernel_fn, cpp_dtype, ...) \
  _PD_KERNEL_REGISTRAR_INIT_ONE(kernel_name, backend, layout, registrar_id, \
                                args_def_fn, meta_kernel_fn, cpp_dtype) \
  PD_EXPAND(_PD_KERNEL_REGISTRAR_INIT_2(kernel_name, backend, layout, PD_ID, \
                                        args_def_fn, meta_kernel_fn, \
                                        ##__VA_ARGS__))
#define _PD_KERNEL_REGISTRAR_INIT_4( \
    kernel_name, backend, layout, registrar_id, args_def_fn, \
    meta_kernel_fn, cpp_dtype, ...) \
  _PD_KERNEL_REGISTRAR_INIT_ONE(kernel_name, backend, layout, registrar_id, \
                                args_def_fn, meta_kernel_fn, cpp_dtype) \
  PD_EXPAND(_PD_KERNEL_REGISTRAR_INIT_3(kernel_name, backend, layout, PD_ID, \
                                        args_def_fn, meta_kernel_fn, \
                                        ##__VA_ARGS__))
#define _PD_KERNEL_REGISTRAR_INIT_5( \
    kernel_name, backend, layout, registrar_id, args_def_fn, \
    meta_kernel_fn, cpp_dtype, ...) \
  _PD_KERNEL_REGISTRAR_INIT_ONE(kernel_name, backend, layout, registrar_id, \
                                args_def_fn, meta_kernel_fn, cpp_dtype) \
  PD_EXPAND(_PD_KERNEL_REGISTRAR_INIT_4(kernel_name, backend, layout, PD_ID, \
                                        args_def_fn, meta_kernel_fn, \
                                        ##__VA_ARGS__))
#define _PD_KERNEL_REGISTRAR_INIT_6( \
    kernel_name, backend, layout, registrar_id, args_def_fn, \
    meta_kernel_fn, cpp_dtype, ...) \
  _PD_KERNEL_REGISTRAR_INIT_ONE(kernel_name, backend, layout, registrar_id, \
                                args_def_fn, meta_kernel_fn, cpp_dtype) \
  PD_EXPAND(_PD_KERNEL_REGISTRAR_INIT_5(kernel_name, backend, layout, PD_ID, \
                                        args_def_fn, meta_kernel_fn, \
                                        ##__VA_ARGS__))
#define _PD_KERNEL_REGISTRAR_INIT_7( \
    kernel_name, backend, layout, registrar_id, args_def_fn, \
    meta_kernel_fn, cpp_dtype, ...) \
  _PD_KERNEL_REGISTRAR_INIT_ONE(kernel_name, backend, layout, registrar_id, \
                                args_def_fn, meta_kernel_fn, cpp_dtype) \
  PD_EXPAND(_PD_KERNEL_REGISTRAR_INIT_6(kernel_name, backend, layout, PD_ID, \
                                        args_def_fn, meta_kernel_fn, \
                                        ##__VA_ARGS__))
#define _PD_KERNEL_REGISTRAR_INIT_8( \
    kernel_name, backend, layout, registrar_id, args_def_fn, \
    meta_kernel_fn, cpp_dtype, ...) \
  _PD_KERNEL_REGISTRAR_INIT_ONE(kernel_name, backend, layout, registrar_id, \
                                args_def_fn, meta_kernel_fn, cpp_dtype) \
  PD_EXPAND(_PD_KERNEL_REGISTRAR_INIT_7(kernel_name, backend, layout, PD_ID, \
                                        args_def_fn, meta_kernel_fn, \
                                        ##__VA_ARGS__))
#define _PD_KERNEL_REGISTRAR_INIT_9( \
    kernel_name, backend, layout, registrar_id, args_def_fn, \
    meta_kernel_fn, cpp_dtype, ...) \
  _PD_KERNEL_REGISTRAR_INIT_ONE(kernel_name, backend, layout, registrar_id, \
                                args_def_fn, meta_kernel_fn, cpp_dtype) \
  PD_EXPAND(_PD_KERNEL_REGISTRAR_INIT_8(kernel_name, backend, layout, PD_ID, \
                                        args_def_fn, meta_kernel_fn, \
                                        ##__VA_ARGS__))
#define _PD_KERNEL_REGISTRAR_INIT_10( \
    kernel_name, backend, layout, registrar_id, args_def_fn, \
    meta_kernel_fn, cpp_dtype, ...) \
  _PD_KERNEL_REGISTRAR_INIT_ONE(kernel_name, backend, layout, registrar_id, \
                                args_def_fn, meta_kernel_fn, cpp_dtype) \
  PD_EXPAND(_PD_KERNEL_REGISTRAR_INIT_9(kernel_name, backend, layout, PD_ID, \
                                        args_def_fn, meta_kernel_fn, \
                                        ##__VA_ARGS__))
#define _PD_KERNEL_REGISTRAR_INIT_11( \
    kernel_name, backend, layout, registrar_id, args_def_fn, \
    meta_kernel_fn, cpp_dtype, ...) \
  _PD_KERNEL_REGISTRAR_INIT_ONE(kernel_name, backend, layout, registrar_id, \
                                args_def_fn, meta_kernel_fn, cpp_dtype) \
  PD_EXPAND(_PD_KERNEL_REGISTRAR_INIT_10(kernel_name, backend, layout, PD_ID, \
                                         args_def_fn, meta_kernel_fn, \
                                         ##__VA_ARGS__))
#define _PD_KERNEL_REGISTRAR_INIT_12( \
    kernel_name, backend, layout, registrar_id, args_def_fn, \
    meta_kernel_fn, cpp_dtype, ...) \
  _PD_KERNEL_REGISTRAR_INIT_ONE(kernel_name, backend, layout, registrar_id, \
                                args_def_fn, meta_kernel_fn, cpp_dtype) \
  PD_EXPAND(_PD_KERNEL_REGISTRAR_INIT_11(kernel_name, backend, layout, PD_ID, \
                                         args_def_fn, meta_kernel_fn, \
                                         ##__VA_ARGS__))
#define _PD_KERNEL_REGISTRAR_INIT_13( \
    kernel_name, backend, layout, registrar_id, args_def_fn, \
    meta_kernel_fn, cpp_dtype, ...) \
  _PD_KERNEL_REGISTRAR_INIT_ONE(kernel_name, backend, layout, registrar_id, \
                                args_def_fn, meta_kernel_fn, cpp_dtype) \
  PD_EXPAND(_PD_KERNEL_REGISTRAR_INIT_12(kernel_name, backend, layout, PD_ID, \
                                         args_def_fn, meta_kernel_fn, \
                                         ##__VA_ARGS__))
#define _PD_KERNEL_REGISTRAR_INIT_14( \
    kernel_name, backend, layout, registrar_id, args_def_fn, \
    meta_kernel_fn, cpp_dtype, ...) \
  _PD_KERNEL_REGISTRAR_INIT_ONE(kernel_name, backend, layout, registrar_id, \
                                args_def_fn, meta_kernel_fn, cpp_dtype) \
  PD_EXPAND(_PD_KERNEL_REGISTRAR_INIT_13(kernel_name, backend, layout, PD_ID, \
                                         args_def_fn, meta_kernel_fn, \
                                         ##__VA_ARGS__))
#define _PD_KERNEL_REGISTRAR_INIT_15( \
    kernel_name, backend, layout, registrar_id, args_def_fn, \
    meta_kernel_fn, cpp_dtype, ...) \
  _PD_KERNEL_REGISTRAR_INIT_ONE(kernel_name, backend, layout, registrar_id, \
                                args_def_fn, meta_kernel_fn, cpp_dtype) \
  PD_EXPAND(_PD_KERNEL_REGISTRAR_INIT_14(kernel_name, backend, layout, PD_ID, \
                                         args_def_fn, meta_kernel_fn, \
                                         ##__VA_ARGS__))
} // namespace paddle } // namespace paddle
...@@ -20,8 +20,8 @@ namespace custom_kernel { ...@@ -20,8 +20,8 @@ namespace custom_kernel {
// Here we use dot <CPU, ANY, INT8> for test // Here we use dot <CPU, ANY, INT8> for test
// This test will fail when this kernel is supported in framework // This test will fail when this kernel is supported in framework
template <typename T> template <typename T, typename Context>
void Dot(const paddle::CPUContext& dev_ctx, void Dot(const Context& dev_ctx,
const paddle::Tensor& x, const paddle::Tensor& x,
const paddle::Tensor& y, const paddle::Tensor& y,
paddle::Tensor* out) { paddle::Tensor* out) {
...@@ -45,9 +45,6 @@ void Dot(const paddle::CPUContext& dev_ctx, ...@@ -45,9 +45,6 @@ void Dot(const paddle::CPUContext& dev_ctx,
} // namespace custom_kernel } // namespace custom_kernel
} // namespace paddle } // namespace paddle
PD_REGISTER_KERNEL( PD_REGISTER_KERNEL(dot, CPU, ALL_LAYOUT, paddle::custom_kernel::Dot, int8_t) {
dot, CPU, ALL_LAYOUT, INT8, paddle::custom_kernel::Dot<int8_t>) {
/* do some args define here
* the only param can be used is OpKernelInfo* kernel */
kernel->OutputAt(0).SetDataType(paddle::experimental::DataType::INT8); kernel->OutputAt(0).SetDataType(paddle::experimental::DataType::INT8);
} }
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册