未验证 提交 63d2333e 编写于 作者: A Aganlengzi 提交者: GitHub

[PluggableDevice] custom kernel supports multi cpp_dtype registering (#39385)

上级 2a5d858c
......@@ -35,13 +35,12 @@ limitations under the License. */
// user kernel function
namespace custom_kernel {
// Here we use dot <CPU, ANY, UINT8> for test
// This test will fail when these two kernels are aupported in framework
// Here we use fake_dot for test
// input 3: two Tensors and one std::vector<Tensor>
// attribute 11: fake_attributes
// output 2: one Tensor* and one std::vector<Tensor*>
template <typename T>
void FakeDot(const paddle::CPUContext& dev_ctx, const paddle::Tensor& x,
template <typename T, typename Context>
void FakeDot(const Context& dev_ctx, const paddle::Tensor& x,
const paddle::Tensor& y,
const std::vector<paddle::Tensor>& fake_input_vec,
bool fake_attr_bool, int fake_attr_int, float fake_attr_float,
......@@ -93,53 +92,91 @@ void FakeDot(const paddle::CPUContext& dev_ctx, const paddle::Tensor& x,
}
} // namespace custom_kernel
PD_REGISTER_KERNEL(dot, CPU, ALL_LAYOUT, UINT8,
custom_kernel::FakeDot<uint8_t>) {
/* do some args define here
* the only param can be used is OpKernelInfo* kernel */
kernel->OutputAt(0).SetDataType(paddle::experimental::DataType::UINT8);
}
PD_REGISTER_KERNEL(fake_dot, CPU, ALL_LAYOUT, custom_kernel::FakeDot, float,
double, int, int64_t, int8_t, uint8_t) {}
// Upper code will store dot kernels info into OpKernelInfoMap
TEST(CustomKernel, custom_kernel_dot) {
std::string op_name = "dot";
std::string op_name = "fake_dot";
pten::Backend backend = pten::Backend::CPU;
pten::DataLayout layout = pten::DataLayout::ANY;
pten::DataType dtype = pten::DataType::UINT8;
pten::DataLayout layout = pten::DataLayout::ALL_LAYOUT;
// 1.custom kernel info parsed and store
EXPECT_TRUE(paddle::OpKernelInfoMap::Instance().GetMap().find("dot") !=
EXPECT_TRUE(paddle::OpKernelInfoMap::Instance().GetMap().find(op_name) !=
paddle::OpKernelInfoMap::Instance().GetMap().end());
// 2.info check
EXPECT_EQ(
1, static_cast<int>(paddle::OpKernelInfoMap::Instance()["dot"].size()));
EXPECT_TRUE(paddle::OpKernelInfoMap::Instance()["dot"][0].GetBackend() ==
6, static_cast<int>(paddle::OpKernelInfoMap::Instance()[op_name].size()));
// index 0
EXPECT_TRUE(paddle::OpKernelInfoMap::Instance()[op_name][0].GetBackend() ==
backend);
EXPECT_TRUE(paddle::OpKernelInfoMap::Instance()["dot"][0].GetDataLayout() ==
EXPECT_TRUE(paddle::OpKernelInfoMap::Instance()[op_name][0].GetDataLayout() ==
layout);
EXPECT_TRUE(paddle::OpKernelInfoMap::Instance()["dot"][0].GetDataType() ==
dtype);
// 3.register
EXPECT_TRUE(pten::KernelFactory::Instance().kernels().end() !=
pten::KernelFactory::Instance().kernels().find("dot"));
pten::KernelKey kernel_key(backend, layout, dtype);
EXPECT_TRUE(
pten::KernelFactory::Instance().kernels()["dot"].find(kernel_key) ==
pten::KernelFactory::Instance().kernels()["dot"].end());
EXPECT_TRUE(paddle::OpKernelInfoMap::Instance()[op_name][0].GetDataType() ==
pten::DataType::FLOAT32);
// index 5
EXPECT_TRUE(paddle::OpKernelInfoMap::Instance()[op_name][5].GetBackend() ==
backend);
EXPECT_TRUE(paddle::OpKernelInfoMap::Instance()[op_name][5].GetDataLayout() ==
layout);
EXPECT_TRUE(paddle::OpKernelInfoMap::Instance()[op_name][5].GetDataType() ==
pten::DataType::UINT8);
// 3.before register
auto& kernel_factory_instance = pten::KernelFactory::Instance();
auto& kernels = pten::KernelFactory::Instance().kernels();
EXPECT_TRUE(!kernel_factory_instance.HasCompatiblePtenKernel(op_name));
// mock fake_dot is supported by pten for HasCompatiblePtenKernel check while
// registering
auto& fake_dot_kernels = kernels[op_name];
EXPECT_TRUE(fake_dot_kernels.find(
pten::KernelKey(backend, layout, pten::DataType::FLOAT32)) ==
fake_dot_kernels.end());
EXPECT_TRUE(fake_dot_kernels.find(
pten::KernelKey(backend, layout, pten::DataType::FLOAT64)) ==
fake_dot_kernels.end());
EXPECT_TRUE(fake_dot_kernels.find(
pten::KernelKey(backend, layout, pten::DataType::INT32)) ==
fake_dot_kernels.end());
EXPECT_TRUE(fake_dot_kernels.find(
pten::KernelKey(backend, layout, pten::DataType::INT64)) ==
fake_dot_kernels.end());
EXPECT_TRUE(fake_dot_kernels.find(
pten::KernelKey(backend, layout, pten::DataType::INT8)) ==
fake_dot_kernels.end());
EXPECT_TRUE(fake_dot_kernels.find(
pten::KernelKey(backend, layout, pten::DataType::UINT8)) ==
fake_dot_kernels.end());
// register
paddle::framework::RegisterKernelWithMetaInfoMap(
paddle::OpKernelInfoMap::Instance());
EXPECT_TRUE(
pten::KernelFactory::Instance().kernels()["dot"].find(kernel_key) !=
pten::KernelFactory::Instance().kernels()["dot"].end());
EXPECT_TRUE(fake_dot_kernels.find(
pten::KernelKey(backend, layout, pten::DataType::FLOAT32)) !=
fake_dot_kernels.end());
EXPECT_TRUE(fake_dot_kernels.find(
pten::KernelKey(backend, layout, pten::DataType::FLOAT64)) !=
fake_dot_kernels.end());
EXPECT_TRUE(fake_dot_kernels.find(
pten::KernelKey(backend, layout, pten::DataType::INT32)) !=
fake_dot_kernels.end());
EXPECT_TRUE(fake_dot_kernels.find(
pten::KernelKey(backend, layout, pten::DataType::INT64)) !=
fake_dot_kernels.end());
EXPECT_TRUE(fake_dot_kernels.find(
pten::KernelKey(backend, layout, pten::DataType::INT8)) !=
fake_dot_kernels.end());
EXPECT_TRUE(fake_dot_kernels.find(
pten::KernelKey(backend, layout, pten::DataType::UINT8)) !=
fake_dot_kernels.end());
// 4.kernel select
auto kernel = pten::KernelFactory::Instance().SelectKernelOrThrowError(
op_name, kernel_key);
auto kernel = kernel_factory_instance.SelectKernelOrThrowError(
op_name, pten::KernelKey(backend, layout, pten::DataType::UINT8));
// 5.prepare parameters for kernel
const auto alloc = std::make_unique<paddle::experimental::DefaultAllocator>(
......@@ -252,10 +289,10 @@ TEST(CustomKernel, custom_kernel_dot) {
// test OpKernelInfoHelper
TEST(OpKernelInfoHelper, op_kernel_info_help_getters) {
using OpKernelInfoHelper = paddle::framework::OpKernelInfoHelper;
std::string op_name = "dot";
std::string op_name = "fake_dot";
pten::Backend backend = pten::Backend::CPU;
pten::DataLayout layout = pten::DataLayout::ANY;
pten::DataType dtype = pten::DataType::UINT8;
pten::DataType dtype = pten::DataType::FLOAT32;
auto op_kernel_info = paddle::OpKernelInfoMap::Instance()[op_name][0];
......@@ -268,10 +305,11 @@ TEST(OpKernelInfoHelper, op_kernel_info_help_getters) {
OpKernelInfoHelper::GetKernelKey(op_kernel_info));
paddle::CustomKernelFunc kernel_fn =
PD_PT_KERNEL(custom_kernel::FakeDot<uint8_t>);
PD_PT_KERNEL(custom_kernel::FakeDot<float, paddle::CPUContext>);
EXPECT_EQ(kernel_fn, OpKernelInfoHelper::GetKernelFn(op_kernel_info));
void* variadic_func = PD_PT_VARIADIC_KERNEL(custom_kernel::FakeDot<uint8_t>);
void* variadic_func =
PD_PT_VARIADIC_KERNEL(custom_kernel::FakeDot<float, paddle::CPUContext>);
EXPECT_EQ(variadic_func,
OpKernelInfoHelper::GetVariadicKernelFn(op_kernel_info));
......
......@@ -30,6 +30,8 @@ limitations under the License. */
#include "paddle/utils/any.h"
#include "paddle/utils/small_vector.h"
#include "paddle/pten/common/data_type.h"
/**
* Custom Kernel Info Define.
*
......@@ -635,29 +637,624 @@ void RegisterAllCustomKernel();
// register custom kernels
void LoadCustomKernelLib(const std::string& dso_name);
//////////////// Custom kernel register macro /////////////////
//////////////// Custom kernel register macro /////////////////////
// Refer to paddle/pten/core/kernel_registry.h, we can not use
// PT_REGISTER_KERNEL directly, common macros and functions are
// not ready for custom kernel now.
// Difference: custom_kernel stores all kernels' info into global
// g_custom_kernel_info_map before loading and registering into
// pten kernel management. Only providing PD_REGISTER_KERNEL which
// supports 2 template arguments.
#define PD_BACKEND(arg__) pten::Backend::arg__
#define PD_DATALAYOUT(arg__) pten::DataLayout::arg__
#define PD_DATATYPE(arg__) pten::DataType::arg__
#define PD_REGISTER_KERNEL(name, backend, layout, dtype, func) \
#define PD_NARGS(...) _PD_NARGS((__VA_ARGS__, _PD_RESQ_N()))
#define _PD_NARGS(...) _PD_ARG_N(__VA_ARGS__)
#define _PD_ARG_N_EXPAND( \
_1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, N, ...) \
N
#define _PD_ARG_N(args) _PD_ARG_N_EXPAND args
#define _PD_RESQ_N() 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0
#define PD_CONCATENATE(arg1, arg2) PD_CONCATENATE1(arg1, arg2)
#define PD_CONCATENATE1(arg1, arg2) PD_CONCATENATE2(arg1, arg2)
#define PD_CONCATENATE2(arg1, arg2) arg1##arg2
#define PD_EXPAND(x) x
#ifdef __COUNTER__
#define PD_ID __COUNTER__
#else
#define PD_ID __LINE__
#endif
#define PD_REGISTER_KERNEL(kernel_name, backend, layout, func, cpp_dtype, ...) \
STATIC_ASSERT_GLOBAL_NAMESPACE( \
__reg_kernel__##name##_##backend##_##layout##_##dtype, \
_reg_custom_kernel_ns_check_##kernel_name##_##backend##_##layout, \
"PD_REGISTER_KERNEL must be called in global namespace."); \
void __PD_USER_args_def_##name##_##backend##_##layout_##dtype( \
::paddle::OpKernelInfo* op_kernel_info); \
static ::paddle::OpKernelInfoBuilder \
__op_kernel_info_##name##_##backend##_##layout##_##dtype = \
::paddle::OpKernelInfoBuilder(#name, \
PD_BACKEND(backend), \
PD_DATALAYOUT(layout), \
PD_DATATYPE(dtype)) \
.SetKernelFn(PD_PT_KERNEL(func)) \
.SetVariadicKernelFn(PD_PT_VARIADIC_KERNEL(func)) \
.ArgsParse(PD_PT_ARGS_PARSE(func)) \
.ArgsDef( \
&__PD_USER_args_def_##name##_##backend##_##layout_##dtype); \
void __PD_USER_args_def_##name##_##backend##_##layout_##dtype( \
_PD_REGISTER_2TA_KERNEL( \
kernel_name, backend, layout, func, cpp_dtype, ##__VA_ARGS__)
// WIN32 is not supported
#define _PD_REGISTER_2TA_KERNEL( \
kernel_name, backend, layout, meta_kernel_fn, cpp_dtype, ...) \
PD_KERNEL_INSTANTIATION(meta_kernel_fn, backend, cpp_dtype, ##__VA_ARGS__); \
static void __PD_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout( \
::paddle::OpKernelInfo* kernel); \
PD_KERNEL_REGISTRAR_INIT( \
kernel_name, \
backend, \
layout, \
&__PD_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout, \
meta_kernel_fn, \
cpp_dtype, \
##__VA_ARGS__); \
void __PD_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout( \
::paddle::OpKernelInfo* kernel)
#define PD_KERNEL_INSTANTIATION(meta_kernel_fn, backend, cpp_dtype, ...) \
_PD_KERNEL_INSTANTIATION(PD_NARGS(cpp_dtype, ##__VA_ARGS__), \
meta_kernel_fn, \
backend, \
cpp_dtype, \
##__VA_ARGS__)
#define _PD_KERNEL_INSTANTIATION(N, meta_kernel_fn, backend, cpp_dtype, ...) \
PD_CONCATENATE(_PD_KERNEL_INSTANTIATION_, N) \
(meta_kernel_fn, backend, cpp_dtype, ##__VA_ARGS__)
#define _PD_KERNEL_INSTANTIATION_1(meta_kernel_fn, backend, cpp_dtype, ...) \
template decltype(meta_kernel_fn<cpp_dtype, ::paddle::backend##Context>) \
meta_kernel_fn<cpp_dtype, ::paddle::backend##Context>
#define _PD_KERNEL_INSTANTIATION_2(meta_kernel_fn, backend, cpp_dtype, ...) \
template decltype(meta_kernel_fn<cpp_dtype, ::paddle::backend##Context>) \
meta_kernel_fn<cpp_dtype, ::paddle::backend##Context>; \
PD_EXPAND(_PD_KERNEL_INSTANTIATION_1(meta_kernel_fn, backend, ##__VA_ARGS__))
#define _PD_KERNEL_INSTANTIATION_3(meta_kernel_fn, backend, cpp_dtype, ...) \
template decltype(meta_kernel_fn<cpp_dtype, ::paddle::backend##Context>) \
meta_kernel_fn<cpp_dtype, ::paddle::backend##Context>; \
PD_EXPAND(_PD_KERNEL_INSTANTIATION_2(meta_kernel_fn, backend, ##__VA_ARGS__))
#define _PD_KERNEL_INSTANTIATION_4(meta_kernel_fn, backend, cpp_dtype, ...) \
template decltype(meta_kernel_fn<cpp_dtype, ::paddle::backend##Context>) \
meta_kernel_fn<cpp_dtype, ::paddle::backend##Context>; \
PD_EXPAND(_PD_KERNEL_INSTANTIATION_3(meta_kernel_fn, backend, ##__VA_ARGS__))
#define _PD_KERNEL_INSTANTIATION_5(meta_kernel_fn, backend, cpp_dtype, ...) \
template decltype(meta_kernel_fn<cpp_dtype, ::paddle::backend##Context>) \
meta_kernel_fn<cpp_dtype, ::paddle::backend##Context>; \
PD_EXPAND(_PD_KERNEL_INSTANTIATION_4(meta_kernel_fn, backend, ##__VA_ARGS__))
#define _PD_KERNEL_INSTANTIATION_6(meta_kernel_fn, backend, cpp_dtype, ...) \
template decltype(meta_kernel_fn<cpp_dtype, ::paddle::backend##Context>) \
meta_kernel_fn<cpp_dtype, ::paddle::backend##Context>; \
PD_EXPAND(_PD_KERNEL_INSTANTIATION_5(meta_kernel_fn, backend, ##__VA_ARGS__))
#define _PD_KERNEL_INSTANTIATION_7(meta_kernel_fn, backend, cpp_dtype, ...) \
template decltype(meta_kernel_fn<cpp_dtype, ::paddle::backend##Context>) \
meta_kernel_fn<cpp_dtype, ::paddle::backend##Context>; \
PD_EXPAND(_PD_KERNEL_INSTANTIATION_6(meta_kernel_fn, backend, ##__VA_ARGS__))
#define _PD_KERNEL_INSTANTIATION_8(meta_kernel_fn, backend, cpp_dtype, ...) \
template decltype(meta_kernel_fn<cpp_dtype, ::paddle::backend##Context>) \
meta_kernel_fn<cpp_dtype, ::paddle::backend##Context>; \
PD_EXPAND(_PD_KERNEL_INSTANTIATION_7(meta_kernel_fn, backend, ##__VA_ARGS__))
#define _PD_KERNEL_INSTANTIATION_9(meta_kernel_fn, backend, cpp_dtype, ...) \
template decltype(meta_kernel_fn<cpp_dtype, ::paddle::backend##Context>) \
meta_kernel_fn<cpp_dtype, ::paddle::backend##Context>; \
PD_EXPAND(_PD_KERNEL_INSTANTIATION_8(meta_kernel_fn, backend, ##__VA_ARGS__))
#define _PD_KERNEL_INSTANTIATION_10(meta_kernel_fn, backend, cpp_dtype, ...) \
template decltype(meta_kernel_fn<cpp_dtype, ::paddle::backend##Context>) \
meta_kernel_fn<cpp_dtype, ::paddle::backend##Context>; \
PD_EXPAND(_PD_KERNEL_INSTANTIATION_9(meta_kernel_fn, backend, ##__VA_ARGS__))
#define _PD_KERNEL_INSTANTIATION_11(meta_kernel_fn, backend, cpp_dtype, ...) \
template decltype(meta_kernel_fn<cpp_dtype, ::paddle::backend##Context>) \
meta_kernel_fn<cpp_dtype, ::paddle::backend##Context>; \
PD_EXPAND(_PD_KERNEL_INSTANTIATION_10(meta_kernel_fn, backend, ##__VA_ARGS__))
#define _PD_KERNEL_INSTANTIATION_12(meta_kernel_fn, backend, cpp_dtype, ...) \
template decltype(meta_kernel_fn<cpp_dtype, ::paddle::backend##Context>) \
meta_kernel_fn<cpp_dtype, ::paddle::backend##Context>; \
PD_EXPAND(_PD_KERNEL_INSTANTIATION_11(meta_kernel_fn, backend, ##__VA_ARGS__))
#define _PD_KERNEL_INSTANTIATION_13(meta_kernel_fn, backend, cpp_dtype, ...) \
template decltype(meta_kernel_fn<cpp_dtype, ::paddle::backend##Context>) \
meta_kernel_fn<cpp_dtype, ::paddle::backend##Context>; \
PD_EXPAND(_PD_KERNEL_INSTANTIATION_12(meta_kernel_fn, backend, ##__VA_ARGS__))
#define _PD_KERNEL_INSTANTIATION_14(meta_kernel_fn, backend, cpp_dtype, ...) \
template decltype(meta_kernel_fn<cpp_dtype, ::paddle::backend##Context>) \
meta_kernel_fn<cpp_dtype, ::paddle::backend##Context>; \
PD_EXPAND(_PD_KERNEL_INSTANTIATION_13(meta_kernel_fn, backend, ##__VA_ARGS__))
#define _PD_KERNEL_INSTANTIATION_15(meta_kernel_fn, backend, cpp_dtype, ...) \
template decltype(meta_kernel_fn<cpp_dtype, ::paddle::backend##Context>) \
meta_kernel_fn<cpp_dtype, ::paddle::backend##Context>; \
PD_EXPAND(_PD_KERNEL_INSTANTIATION_14(meta_kernel_fn, backend, ##__VA_ARGS__))
#define PD_KERNEL_REGISTRAR_INIT( \
kernel_name, backend, layout, args_def_fn, meta_kernel_fn, cpp_dtype, ...) \
_PD_KERNEL_REGISTRAR_INIT(PD_NARGS(cpp_dtype, ##__VA_ARGS__), \
kernel_name, \
backend, \
layout, \
args_def_fn, \
meta_kernel_fn, \
cpp_dtype, \
##__VA_ARGS__)
// clang-format off
/* The =pre-commit always treats this macro into the wrong format,
and multi-line macros cannot be skipped with NOLINT.*/
#define _PD_KERNEL_REGISTRAR_INIT(N, \
kernel_name, \
backend, \
layout, \
args_def_fn, \
meta_kernel_fn, \
cpp_dtype, \
...) \
PD_CONCATENATE(_PD_KERNEL_REGISTRAR_INIT_, N) ( \
kernel_name, \
backend, \
layout, \
PD_ID, \
args_def_fn, \
meta_kernel_fn, \
cpp_dtype, \
##__VA_ARGS__)
// clang-format on
#define _PD_KERNEL_REGISTRAR_INIT_1(kernel_name, \
backend, \
layout, \
registrar_id, \
args_def_fn, \
meta_kernel_fn, \
cpp_dtype, \
...) \
static ::paddle::OpKernelInfoBuilder PD_CONCATENATE( \
custom_kernel_info_##kernel_name##_##backend##_##layout##_, \
registrar_id) = \
::paddle::OpKernelInfoBuilder( \
#kernel_name, \
PD_BACKEND(backend), \
PD_DATALAYOUT(layout), \
::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type()) \
.SetKernelFn(PD_PT_KERNEL( \
meta_kernel_fn<cpp_dtype, ::paddle::backend##Context>)) \
.SetVariadicKernelFn(PD_PT_VARIADIC_KERNEL( \
meta_kernel_fn<cpp_dtype, ::paddle::backend##Context>)) \
.ArgsParse(PD_PT_ARGS_PARSE( \
meta_kernel_fn<cpp_dtype, ::paddle::backend##Context>)) \
.ArgsDef(args_def_fn);
#define _PD_KERNEL_REGISTRAR_INIT_2(kernel_name, \
backend, \
layout, \
registrar_id, \
args_def_fn, \
meta_kernel_fn, \
cpp_dtype, \
...) \
static ::paddle::OpKernelInfoBuilder PD_CONCATENATE( \
custom_kernel_info_##kernel_name##_##backend##_##layout##_, \
registrar_id) = \
::paddle::OpKernelInfoBuilder( \
#kernel_name, \
PD_BACKEND(backend), \
PD_DATALAYOUT(layout), \
::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type()) \
.SetKernelFn(PD_PT_KERNEL( \
meta_kernel_fn<cpp_dtype, ::paddle::backend##Context>)) \
.SetVariadicKernelFn(PD_PT_VARIADIC_KERNEL( \
meta_kernel_fn<cpp_dtype, ::paddle::backend##Context>)) \
.ArgsParse(PD_PT_ARGS_PARSE( \
meta_kernel_fn<cpp_dtype, ::paddle::backend##Context>)) \
.ArgsDef(args_def_fn); \
PD_EXPAND(_PD_KERNEL_REGISTRAR_INIT_1(kernel_name, \
backend, \
layout, \
PD_ID, \
args_def_fn, \
meta_kernel_fn, \
##__VA_ARGS__))
#define _PD_KERNEL_REGISTRAR_INIT_3(kernel_name, \
backend, \
layout, \
registrar_id, \
args_def_fn, \
meta_kernel_fn, \
cpp_dtype, \
...) \
static ::paddle::OpKernelInfoBuilder PD_CONCATENATE( \
custom_kernel_info_##kernel_name##_##backend##_##layout##_, \
registrar_id) = \
::paddle::OpKernelInfoBuilder( \
#kernel_name, \
PD_BACKEND(backend), \
PD_DATALAYOUT(layout), \
::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type()) \
.SetKernelFn(PD_PT_KERNEL( \
meta_kernel_fn<cpp_dtype, ::paddle::backend##Context>)) \
.SetVariadicKernelFn(PD_PT_VARIADIC_KERNEL( \
meta_kernel_fn<cpp_dtype, ::paddle::backend##Context>)) \
.ArgsParse(PD_PT_ARGS_PARSE( \
meta_kernel_fn<cpp_dtype, ::paddle::backend##Context>)) \
.ArgsDef(args_def_fn); \
PD_EXPAND(_PD_KERNEL_REGISTRAR_INIT_2(kernel_name, \
backend, \
layout, \
PD_ID, \
args_def_fn, \
meta_kernel_fn, \
##__VA_ARGS__))
#define _PD_KERNEL_REGISTRAR_INIT_4(kernel_name, \
backend, \
layout, \
registrar_id, \
args_def_fn, \
meta_kernel_fn, \
cpp_dtype, \
...) \
static ::paddle::OpKernelInfoBuilder PD_CONCATENATE( \
custom_kernel_info_##kernel_name##_##backend##_##layout##_, \
registrar_id) = \
::paddle::OpKernelInfoBuilder( \
#kernel_name, \
PD_BACKEND(backend), \
PD_DATALAYOUT(layout), \
::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type()) \
.SetKernelFn(PD_PT_KERNEL( \
meta_kernel_fn<cpp_dtype, ::paddle::backend##Context>)) \
.SetVariadicKernelFn(PD_PT_VARIADIC_KERNEL( \
meta_kernel_fn<cpp_dtype, ::paddle::backend##Context>)) \
.ArgsParse(PD_PT_ARGS_PARSE( \
meta_kernel_fn<cpp_dtype, ::paddle::backend##Context>)) \
.ArgsDef(args_def_fn); \
PD_EXPAND(_PD_KERNEL_REGISTRAR_INIT_3(kernel_name, \
backend, \
layout, \
PD_ID, \
args_def_fn, \
meta_kernel_fn, \
##__VA_ARGS__))
#define _PD_KERNEL_REGISTRAR_INIT_5(kernel_name, \
backend, \
layout, \
registrar_id, \
args_def_fn, \
meta_kernel_fn, \
cpp_dtype, \
...) \
static ::paddle::OpKernelInfoBuilder PD_CONCATENATE( \
custom_kernel_info_##kernel_name##_##backend##_##layout##_, \
registrar_id) = \
::paddle::OpKernelInfoBuilder( \
#kernel_name, \
PD_BACKEND(backend), \
PD_DATALAYOUT(layout), \
::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type()) \
.SetKernelFn(PD_PT_KERNEL( \
meta_kernel_fn<cpp_dtype, ::paddle::backend##Context>)) \
.SetVariadicKernelFn(PD_PT_VARIADIC_KERNEL( \
meta_kernel_fn<cpp_dtype, ::paddle::backend##Context>)) \
.ArgsParse(PD_PT_ARGS_PARSE( \
meta_kernel_fn<cpp_dtype, ::paddle::backend##Context>)) \
.ArgsDef(args_def_fn); \
PD_EXPAND(_PD_KERNEL_REGISTRAR_INIT_4(kernel_name, \
backend, \
layout, \
PD_ID, \
args_def_fn, \
meta_kernel_fn, \
##__VA_ARGS__))
#define _PD_KERNEL_REGISTRAR_INIT_6(kernel_name, \
backend, \
layout, \
registrar_id, \
args_def_fn, \
meta_kernel_fn, \
cpp_dtype, \
...) \
static ::paddle::OpKernelInfoBuilder PD_CONCATENATE( \
custom_kernel_info_##kernel_name##_##backend##_##layout##_, \
registrar_id) = \
::paddle::OpKernelInfoBuilder( \
#kernel_name, \
PD_BACKEND(backend), \
PD_DATALAYOUT(layout), \
::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type()) \
.SetKernelFn(PD_PT_KERNEL( \
meta_kernel_fn<cpp_dtype, ::paddle::backend##Context>)) \
.SetVariadicKernelFn(PD_PT_VARIADIC_KERNEL( \
meta_kernel_fn<cpp_dtype, ::paddle::backend##Context>)) \
.ArgsParse(PD_PT_ARGS_PARSE( \
meta_kernel_fn<cpp_dtype, ::paddle::backend##Context>)) \
.ArgsDef(args_def_fn); \
PD_EXPAND(_PD_KERNEL_REGISTRAR_INIT_5(kernel_name, \
backend, \
layout, \
PD_ID, \
args_def_fn, \
meta_kernel_fn, \
##__VA_ARGS__))
#define _PD_KERNEL_REGISTRAR_INIT_7(kernel_name, \
backend, \
layout, \
registrar_id, \
args_def_fn, \
meta_kernel_fn, \
cpp_dtype, \
...) \
static ::paddle::OpKernelInfoBuilder PD_CONCATENATE( \
custom_kernel_info_##kernel_name##_##backend##_##layout##_, \
registrar_id) = \
::paddle::OpKernelInfoBuilder( \
#kernel_name, \
PD_BACKEND(backend), \
PD_DATALAYOUT(layout), \
::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type()) \
.SetKernelFn(PD_PT_KERNEL( \
meta_kernel_fn<cpp_dtype, ::paddle::backend##Context>)) \
.SetVariadicKernelFn(PD_PT_VARIADIC_KERNEL( \
meta_kernel_fn<cpp_dtype, ::paddle::backend##Context>)) \
.ArgsParse(PD_PT_ARGS_PARSE( \
meta_kernel_fn<cpp_dtype, ::paddle::backend##Context>)) \
.ArgsDef(args_def_fn); \
PD_EXPAND(_PD_KERNEL_REGISTRAR_INIT_6(kernel_name, \
backend, \
layout, \
PD_ID, \
args_def_fn, \
meta_kernel_fn, \
##__VA_ARGS__))
#define _PD_KERNEL_REGISTRAR_INIT_8(kernel_name, \
backend, \
layout, \
registrar_id, \
args_def_fn, \
meta_kernel_fn, \
cpp_dtype, \
...) \
static ::paddle::OpKernelInfoBuilder PD_CONCATENATE( \
custom_kernel_info_##kernel_name##_##backend##_##layout##_, \
registrar_id) = \
::paddle::OpKernelInfoBuilder( \
#kernel_name, \
PD_BACKEND(backend), \
PD_DATALAYOUT(layout), \
::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type()) \
.SetKernelFn(PD_PT_KERNEL( \
meta_kernel_fn<cpp_dtype, ::paddle::backend##Context>)) \
.SetVariadicKernelFn(PD_PT_VARIADIC_KERNEL( \
meta_kernel_fn<cpp_dtype, ::paddle::backend##Context>)) \
.ArgsParse(PD_PT_ARGS_PARSE( \
meta_kernel_fn<cpp_dtype, ::paddle::backend##Context>)) \
.ArgsDef(args_def_fn); \
PD_EXPAND(_PD_KERNEL_REGISTRAR_INIT_7(kernel_name, \
backend, \
layout, \
PD_ID, \
args_def_fn, \
meta_kernel_fn, \
##__VA_ARGS__))
#define _PD_KERNEL_REGISTRAR_INIT_9(kernel_name, \
backend, \
layout, \
registrar_id, \
args_def_fn, \
meta_kernel_fn, \
cpp_dtype, \
...) \
static ::paddle::OpKernelInfoBuilder PD_CONCATENATE( \
custom_kernel_info_##kernel_name##_##backend##_##layout##_, \
registrar_id) = \
::paddle::OpKernelInfoBuilder( \
#kernel_name, \
PD_BACKEND(backend), \
PD_DATALAYOUT(layout), \
::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type()) \
.SetKernelFn(PD_PT_KERNEL( \
meta_kernel_fn<cpp_dtype, ::paddle::backend##Context>)) \
.SetVariadicKernelFn(PD_PT_VARIADIC_KERNEL( \
meta_kernel_fn<cpp_dtype, ::paddle::backend##Context>)) \
.ArgsParse(PD_PT_ARGS_PARSE( \
meta_kernel_fn<cpp_dtype, ::paddle::backend##Context>)) \
.ArgsDef(args_def_fn); \
PD_EXPAND(_PD_KERNEL_REGISTRAR_INIT_8(kernel_name, \
backend, \
layout, \
PD_ID, \
args_def_fn, \
meta_kernel_fn, \
##__VA_ARGS__))
#define _PD_KERNEL_REGISTRAR_INIT_10(kernel_name, \
backend, \
layout, \
registrar_id, \
args_def_fn, \
meta_kernel_fn, \
cpp_dtype, \
...) \
static ::paddle::OpKernelInfoBuilder PD_CONCATENATE( \
custom_kernel_info_##kernel_name##_##backend##_##layout##_, \
registrar_id) = \
::paddle::OpKernelInfoBuilder( \
#kernel_name, \
PD_BACKEND(backend), \
PD_DATALAYOUT(layout), \
::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type()) \
.SetKernelFn(PD_PT_KERNEL( \
meta_kernel_fn<cpp_dtype, ::paddle::backend##Context>)) \
.SetVariadicKernelFn(PD_PT_VARIADIC_KERNEL( \
meta_kernel_fn<cpp_dtype, ::paddle::backend##Context>)) \
.ArgsParse(PD_PT_ARGS_PARSE( \
meta_kernel_fn<cpp_dtype, ::paddle::backend##Context>)) \
.ArgsDef(args_def_fn); \
PD_EXPAND(_PD_KERNEL_REGISTRAR_INIT_9(kernel_name, \
backend, \
layout, \
PD_ID, \
args_def_fn, \
meta_kernel_fn, \
##__VA_ARGS__))
#define _PD_KERNEL_REGISTRAR_INIT_11(kernel_name, \
backend, \
layout, \
registrar_id, \
args_def_fn, \
meta_kernel_fn, \
cpp_dtype, \
...) \
static ::paddle::OpKernelInfoBuilder PD_CONCATENATE( \
custom_kernel_info_##kernel_name##_##backend##_##layout##_, \
registrar_id) = \
::paddle::OpKernelInfoBuilder( \
#kernel_name, \
PD_BACKEND(backend), \
PD_DATALAYOUT(layout), \
::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type()) \
.SetKernelFn(PD_PT_KERNEL( \
meta_kernel_fn<cpp_dtype, ::paddle::backend##Context>)) \
.SetVariadicKernelFn(PD_PT_VARIADIC_KERNEL( \
meta_kernel_fn<cpp_dtype, ::paddle::backend##Context>)) \
.ArgsParse(PD_PT_ARGS_PARSE( \
meta_kernel_fn<cpp_dtype, ::paddle::backend##Context>)) \
.ArgsDef(args_def_fn); \
PD_EXPAND(_PD_KERNEL_REGISTRAR_INIT_10(kernel_name, \
backend, \
layout, \
PD_ID, \
args_def_fn, \
meta_kernel_fn, \
##__VA_ARGS__))
#define _PD_KERNEL_REGISTRAR_INIT_12(kernel_name, \
backend, \
layout, \
registrar_id, \
args_def_fn, \
meta_kernel_fn, \
cpp_dtype, \
...) \
static ::paddle::OpKernelInfoBuilder PD_CONCATENATE( \
custom_kernel_info_##kernel_name##_##backend##_##layout##_, \
registrar_id) = \
::paddle::OpKernelInfoBuilder( \
#kernel_name, \
PD_BACKEND(backend), \
PD_DATALAYOUT(layout), \
::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type()) \
.SetKernelFn(PD_PT_KERNEL( \
meta_kernel_fn<cpp_dtype, ::paddle::backend##Context>)) \
.SetVariadicKernelFn(PD_PT_VARIADIC_KERNEL( \
meta_kernel_fn<cpp_dtype, ::paddle::backend##Context>)) \
.ArgsParse(PD_PT_ARGS_PARSE( \
meta_kernel_fn<cpp_dtype, ::paddle::backend##Context>)) \
.ArgsDef(args_def_fn); \
PD_EXPAND(_PD_KERNEL_REGISTRAR_INIT_11(kernel_name, \
backend, \
layout, \
PD_ID, \
args_def_fn, \
meta_kernel_fn, \
##__VA_ARGS__))
#define _PD_KERNEL_REGISTRAR_INIT_13(kernel_name, \
backend, \
layout, \
registrar_id, \
args_def_fn, \
meta_kernel_fn, \
cpp_dtype, \
...) \
static ::paddle::OpKernelInfoBuilder PD_CONCATENATE( \
custom_kernel_info_##kernel_name##_##backend##_##layout##_, \
registrar_id) = \
::paddle::OpKernelInfoBuilder( \
#kernel_name, \
PD_BACKEND(backend), \
PD_DATALAYOUT(layout), \
::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type()) \
.SetKernelFn(PD_PT_KERNEL( \
meta_kernel_fn<cpp_dtype, ::paddle::backend##Context>)) \
.SetVariadicKernelFn(PD_PT_VARIADIC_KERNEL( \
meta_kernel_fn<cpp_dtype, ::paddle::backend##Context>)) \
.ArgsParse(PD_PT_ARGS_PARSE( \
meta_kernel_fn<cpp_dtype, ::paddle::backend##Context>)) \
.ArgsDef(args_def_fn); \
PD_EXPAND(_PD_KERNEL_REGISTRAR_INIT_12(kernel_name, \
backend, \
layout, \
PD_ID, \
args_def_fn, \
meta_kernel_fn, \
##__VA_ARGS__))
#define _PD_KERNEL_REGISTRAR_INIT_14(kernel_name, \
backend, \
layout, \
registrar_id, \
args_def_fn, \
meta_kernel_fn, \
cpp_dtype, \
...) \
static ::paddle::OpKernelInfoBuilder PD_CONCATENATE( \
custom_kernel_info_##kernel_name##_##backend##_##layout##_, \
registrar_id) = \
::paddle::OpKernelInfoBuilder( \
#kernel_name, \
PD_BACKEND(backend), \
PD_DATALAYOUT(layout), \
::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type()) \
.SetKernelFn(PD_PT_KERNEL( \
meta_kernel_fn<cpp_dtype, ::paddle::backend##Context>)) \
.SetVariadicKernelFn(PD_PT_VARIADIC_KERNEL( \
meta_kernel_fn<cpp_dtype, ::paddle::backend##Context>)) \
.ArgsParse(PD_PT_ARGS_PARSE( \
meta_kernel_fn<cpp_dtype, ::paddle::backend##Context>)) \
.ArgsDef(args_def_fn); \
PD_EXPAND(_PD_KERNEL_REGISTRAR_INIT_13(kernel_name, \
backend, \
layout, \
PD_ID, \
args_def_fn, \
meta_kernel_fn, \
##__VA_ARGS__))
#define _PD_KERNEL_REGISTRAR_INIT_15(kernel_name, \
backend, \
layout, \
registrar_id, \
args_def_fn, \
meta_kernel_fn, \
cpp_dtype, \
...) \
static ::paddle::OpKernelInfoBuilder PD_CONCATENATE( \
custom_kernel_info_##kernel_name##_##backend##_##layout##_, \
registrar_id) = \
::paddle::OpKernelInfoBuilder( \
#kernel_name, \
PD_BACKEND(backend), \
PD_DATALAYOUT(layout), \
::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type()) \
.SetKernelFn(PD_PT_KERNEL( \
meta_kernel_fn<cpp_dtype, ::paddle::backend##Context>)) \
.SetVariadicKernelFn(PD_PT_VARIADIC_KERNEL( \
meta_kernel_fn<cpp_dtype, ::paddle::backend##Context>)) \
.ArgsParse(PD_PT_ARGS_PARSE( \
meta_kernel_fn<cpp_dtype, ::paddle::backend##Context>)) \
.ArgsDef(args_def_fn); \
PD_EXPAND(_PD_KERNEL_REGISTRAR_INIT_14(kernel_name, \
backend, \
layout, \
PD_ID, \
args_def_fn, \
meta_kernel_fn, \
##__VA_ARGS__))
} // namespace paddle
......@@ -20,8 +20,8 @@ namespace custom_kernel {
// Here we use dot <CPU, ANY, INT8> for test
// This test will fail when this kernel is supported in framework
template <typename T>
void Dot(const paddle::CPUContext& dev_ctx,
template <typename T, typename Context>
void Dot(const Context& dev_ctx,
const paddle::Tensor& x,
const paddle::Tensor& y,
paddle::Tensor* out) {
......@@ -45,9 +45,6 @@ void Dot(const paddle::CPUContext& dev_ctx,
} // namespace custom_kernel
} // namespace paddle
PD_REGISTER_KERNEL(
dot, CPU, ALL_LAYOUT, INT8, paddle::custom_kernel::Dot<int8_t>) {
/* do some args define here
* the only param can be used is OpKernelInfo* kernel */
PD_REGISTER_KERNEL(dot, CPU, ALL_LAYOUT, paddle::custom_kernel::Dot, int8_t) {
kernel->OutputAt(0).SetDataType(paddle::experimental::DataType::INT8);
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册