Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Crayon鑫
Paddle
提交
63d2333e
P
Paddle
项目概览
Crayon鑫
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
63d2333e
编写于
2月 10, 2022
作者:
A
Aganlengzi
提交者:
GitHub
2月 10, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[PluggableDevice] custom kernel supports multi cpp_dtype registering (#39385)
上级
2a5d858c
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
695 addition
and
63 deletion
+695
-63
paddle/fluid/framework/custom_kernel_test.cc
paddle/fluid/framework/custom_kernel_test.cc
+76
-38
paddle/pten/api/ext/op_kernel_info.h
paddle/pten/api/ext/op_kernel_info.h
+616
-19
python/paddle/fluid/tests/custom_kernel/custom_kernel_dot.cc
python/paddle/fluid/tests/custom_kernel/custom_kernel_dot.cc
+3
-6
未找到文件。
paddle/fluid/framework/custom_kernel_test.cc
浏览文件 @
63d2333e
...
@@ -35,13 +35,12 @@ limitations under the License. */
...
@@ -35,13 +35,12 @@ limitations under the License. */
// user kernel function
// user kernel function
namespace
custom_kernel
{
namespace
custom_kernel
{
// Here we use dot <CPU, ANY, UINT8> for test
// Here we use fake_dot for test
// This test will fail when these two kernels are aupported in framework
// input 3: two Tensors and one std::vector<Tensor>
// input 3: two Tensors and one std::vector<Tensor>
// attribute 11: fake_attributes
// attribute 11: fake_attributes
// output 2: one Tensor* and one std::vector<Tensor*>
// output 2: one Tensor* and one std::vector<Tensor*>
template
<
typename
T
>
template
<
typename
T
,
typename
Context
>
void
FakeDot
(
const
paddle
::
CPU
Context
&
dev_ctx
,
const
paddle
::
Tensor
&
x
,
void
FakeDot
(
const
Context
&
dev_ctx
,
const
paddle
::
Tensor
&
x
,
const
paddle
::
Tensor
&
y
,
const
paddle
::
Tensor
&
y
,
const
std
::
vector
<
paddle
::
Tensor
>&
fake_input_vec
,
const
std
::
vector
<
paddle
::
Tensor
>&
fake_input_vec
,
bool
fake_attr_bool
,
int
fake_attr_int
,
float
fake_attr_float
,
bool
fake_attr_bool
,
int
fake_attr_int
,
float
fake_attr_float
,
...
@@ -93,53 +92,91 @@ void FakeDot(const paddle::CPUContext& dev_ctx, const paddle::Tensor& x,
...
@@ -93,53 +92,91 @@ void FakeDot(const paddle::CPUContext& dev_ctx, const paddle::Tensor& x,
}
}
}
// namespace custom_kernel
}
// namespace custom_kernel
PD_REGISTER_KERNEL
(
dot
,
CPU
,
ALL_LAYOUT
,
UINT8
,
PD_REGISTER_KERNEL
(
fake_dot
,
CPU
,
ALL_LAYOUT
,
custom_kernel
::
FakeDot
,
float
,
custom_kernel
::
FakeDot
<
uint8_t
>
)
{
double
,
int
,
int64_t
,
int8_t
,
uint8_t
)
{}
/* do some args define here
* the only param can be used is OpKernelInfo* kernel */
kernel
->
OutputAt
(
0
).
SetDataType
(
paddle
::
experimental
::
DataType
::
UINT8
);
}
// Upper code will store dot kernels info into OpKernelInfoMap
// Upper code will store dot kernels info into OpKernelInfoMap
TEST
(
CustomKernel
,
custom_kernel_dot
)
{
TEST
(
CustomKernel
,
custom_kernel_dot
)
{
std
::
string
op_name
=
"dot"
;
std
::
string
op_name
=
"
fake_
dot"
;
pten
::
Backend
backend
=
pten
::
Backend
::
CPU
;
pten
::
Backend
backend
=
pten
::
Backend
::
CPU
;
pten
::
DataLayout
layout
=
pten
::
DataLayout
::
ANY
;
pten
::
DataLayout
layout
=
pten
::
DataLayout
::
ALL_LAYOUT
;
pten
::
DataType
dtype
=
pten
::
DataType
::
UINT8
;
// 1.custom kernel info parsed and store
// 1.custom kernel info parsed and store
EXPECT_TRUE
(
paddle
::
OpKernelInfoMap
::
Instance
().
GetMap
().
find
(
"dot"
)
!=
EXPECT_TRUE
(
paddle
::
OpKernelInfoMap
::
Instance
().
GetMap
().
find
(
op_name
)
!=
paddle
::
OpKernelInfoMap
::
Instance
().
GetMap
().
end
());
paddle
::
OpKernelInfoMap
::
Instance
().
GetMap
().
end
());
// 2.info check
// 2.info check
EXPECT_EQ
(
EXPECT_EQ
(
1
,
static_cast
<
int
>
(
paddle
::
OpKernelInfoMap
::
Instance
()[
"dot"
].
size
()));
6
,
static_cast
<
int
>
(
paddle
::
OpKernelInfoMap
::
Instance
()[
op_name
].
size
()));
EXPECT_TRUE
(
paddle
::
OpKernelInfoMap
::
Instance
()[
"dot"
][
0
].
GetBackend
()
==
// index 0
EXPECT_TRUE
(
paddle
::
OpKernelInfoMap
::
Instance
()[
op_name
][
0
].
GetBackend
()
==
backend
);
backend
);
EXPECT_TRUE
(
paddle
::
OpKernelInfoMap
::
Instance
()[
"dot"
][
0
].
GetDataLayout
()
==
EXPECT_TRUE
(
paddle
::
OpKernelInfoMap
::
Instance
()[
op_name
][
0
].
GetDataLayout
()
==
layout
);
layout
);
EXPECT_TRUE
(
paddle
::
OpKernelInfoMap
::
Instance
()[
"dot"
][
0
].
GetDataType
()
==
EXPECT_TRUE
(
paddle
::
OpKernelInfoMap
::
Instance
()[
op_name
][
0
].
GetDataType
()
==
dtype
);
pten
::
DataType
::
FLOAT32
);
// index 5
// 3.register
EXPECT_TRUE
(
paddle
::
OpKernelInfoMap
::
Instance
()[
op_name
][
5
].
GetBackend
()
==
EXPECT_TRUE
(
pten
::
KernelFactory
::
Instance
().
kernels
().
end
()
!=
backend
);
pten
::
KernelFactory
::
Instance
().
kernels
().
find
(
"dot"
));
EXPECT_TRUE
(
paddle
::
OpKernelInfoMap
::
Instance
()[
op_name
][
5
].
GetDataLayout
()
==
layout
);
pten
::
KernelKey
kernel_key
(
backend
,
layout
,
dtype
);
EXPECT_TRUE
(
paddle
::
OpKernelInfoMap
::
Instance
()[
op_name
][
5
].
GetDataType
()
==
EXPECT_TRUE
(
pten
::
DataType
::
UINT8
);
pten
::
KernelFactory
::
Instance
().
kernels
()[
"dot"
].
find
(
kernel_key
)
==
pten
::
KernelFactory
::
Instance
().
kernels
()[
"dot"
].
end
());
// 3.before register
auto
&
kernel_factory_instance
=
pten
::
KernelFactory
::
Instance
();
auto
&
kernels
=
pten
::
KernelFactory
::
Instance
().
kernels
();
EXPECT_TRUE
(
!
kernel_factory_instance
.
HasCompatiblePtenKernel
(
op_name
));
// mock fake_dot is supported by pten for HasCompatiblePtenKernel check while
// registering
auto
&
fake_dot_kernels
=
kernels
[
op_name
];
EXPECT_TRUE
(
fake_dot_kernels
.
find
(
pten
::
KernelKey
(
backend
,
layout
,
pten
::
DataType
::
FLOAT32
))
==
fake_dot_kernels
.
end
());
EXPECT_TRUE
(
fake_dot_kernels
.
find
(
pten
::
KernelKey
(
backend
,
layout
,
pten
::
DataType
::
FLOAT64
))
==
fake_dot_kernels
.
end
());
EXPECT_TRUE
(
fake_dot_kernels
.
find
(
pten
::
KernelKey
(
backend
,
layout
,
pten
::
DataType
::
INT32
))
==
fake_dot_kernels
.
end
());
EXPECT_TRUE
(
fake_dot_kernels
.
find
(
pten
::
KernelKey
(
backend
,
layout
,
pten
::
DataType
::
INT64
))
==
fake_dot_kernels
.
end
());
EXPECT_TRUE
(
fake_dot_kernels
.
find
(
pten
::
KernelKey
(
backend
,
layout
,
pten
::
DataType
::
INT8
))
==
fake_dot_kernels
.
end
());
EXPECT_TRUE
(
fake_dot_kernels
.
find
(
pten
::
KernelKey
(
backend
,
layout
,
pten
::
DataType
::
UINT8
))
==
fake_dot_kernels
.
end
());
// register
paddle
::
framework
::
RegisterKernelWithMetaInfoMap
(
paddle
::
framework
::
RegisterKernelWithMetaInfoMap
(
paddle
::
OpKernelInfoMap
::
Instance
());
paddle
::
OpKernelInfoMap
::
Instance
());
EXPECT_TRUE
(
EXPECT_TRUE
(
fake_dot_kernels
.
find
(
pten
::
KernelFactory
::
Instance
().
kernels
()[
"dot"
].
find
(
kernel_key
)
!=
pten
::
KernelKey
(
backend
,
layout
,
pten
::
DataType
::
FLOAT32
))
!=
pten
::
KernelFactory
::
Instance
().
kernels
()[
"dot"
].
end
());
fake_dot_kernels
.
end
());
EXPECT_TRUE
(
fake_dot_kernels
.
find
(
pten
::
KernelKey
(
backend
,
layout
,
pten
::
DataType
::
FLOAT64
))
!=
fake_dot_kernels
.
end
());
EXPECT_TRUE
(
fake_dot_kernels
.
find
(
pten
::
KernelKey
(
backend
,
layout
,
pten
::
DataType
::
INT32
))
!=
fake_dot_kernels
.
end
());
EXPECT_TRUE
(
fake_dot_kernels
.
find
(
pten
::
KernelKey
(
backend
,
layout
,
pten
::
DataType
::
INT64
))
!=
fake_dot_kernels
.
end
());
EXPECT_TRUE
(
fake_dot_kernels
.
find
(
pten
::
KernelKey
(
backend
,
layout
,
pten
::
DataType
::
INT8
))
!=
fake_dot_kernels
.
end
());
EXPECT_TRUE
(
fake_dot_kernels
.
find
(
pten
::
KernelKey
(
backend
,
layout
,
pten
::
DataType
::
UINT8
))
!=
fake_dot_kernels
.
end
());
// 4.kernel select
// 4.kernel select
auto
kernel
=
pten
::
KernelFactory
::
Instance
()
.
SelectKernelOrThrowError
(
auto
kernel
=
kernel_factory_instance
.
SelectKernelOrThrowError
(
op_name
,
kernel_key
);
op_name
,
pten
::
KernelKey
(
backend
,
layout
,
pten
::
DataType
::
UINT8
)
);
// 5.prepare parameters for kernel
// 5.prepare parameters for kernel
const
auto
alloc
=
std
::
make_unique
<
paddle
::
experimental
::
DefaultAllocator
>
(
const
auto
alloc
=
std
::
make_unique
<
paddle
::
experimental
::
DefaultAllocator
>
(
...
@@ -252,10 +289,10 @@ TEST(CustomKernel, custom_kernel_dot) {
...
@@ -252,10 +289,10 @@ TEST(CustomKernel, custom_kernel_dot) {
// test OpKernelInfoHelper
// test OpKernelInfoHelper
TEST
(
OpKernelInfoHelper
,
op_kernel_info_help_getters
)
{
TEST
(
OpKernelInfoHelper
,
op_kernel_info_help_getters
)
{
using
OpKernelInfoHelper
=
paddle
::
framework
::
OpKernelInfoHelper
;
using
OpKernelInfoHelper
=
paddle
::
framework
::
OpKernelInfoHelper
;
std
::
string
op_name
=
"dot"
;
std
::
string
op_name
=
"
fake_
dot"
;
pten
::
Backend
backend
=
pten
::
Backend
::
CPU
;
pten
::
Backend
backend
=
pten
::
Backend
::
CPU
;
pten
::
DataLayout
layout
=
pten
::
DataLayout
::
ANY
;
pten
::
DataLayout
layout
=
pten
::
DataLayout
::
ANY
;
pten
::
DataType
dtype
=
pten
::
DataType
::
UINT8
;
pten
::
DataType
dtype
=
pten
::
DataType
::
FLOAT32
;
auto
op_kernel_info
=
paddle
::
OpKernelInfoMap
::
Instance
()[
op_name
][
0
];
auto
op_kernel_info
=
paddle
::
OpKernelInfoMap
::
Instance
()[
op_name
][
0
];
...
@@ -268,10 +305,11 @@ TEST(OpKernelInfoHelper, op_kernel_info_help_getters) {
...
@@ -268,10 +305,11 @@ TEST(OpKernelInfoHelper, op_kernel_info_help_getters) {
OpKernelInfoHelper
::
GetKernelKey
(
op_kernel_info
));
OpKernelInfoHelper
::
GetKernelKey
(
op_kernel_info
));
paddle
::
CustomKernelFunc
kernel_fn
=
paddle
::
CustomKernelFunc
kernel_fn
=
PD_PT_KERNEL
(
custom_kernel
::
FakeDot
<
uint8_
t
>
);
PD_PT_KERNEL
(
custom_kernel
::
FakeDot
<
float
,
paddle
::
CPUContex
t
>
);
EXPECT_EQ
(
kernel_fn
,
OpKernelInfoHelper
::
GetKernelFn
(
op_kernel_info
));
EXPECT_EQ
(
kernel_fn
,
OpKernelInfoHelper
::
GetKernelFn
(
op_kernel_info
));
void
*
variadic_func
=
PD_PT_VARIADIC_KERNEL
(
custom_kernel
::
FakeDot
<
uint8_t
>
);
void
*
variadic_func
=
PD_PT_VARIADIC_KERNEL
(
custom_kernel
::
FakeDot
<
float
,
paddle
::
CPUContext
>
);
EXPECT_EQ
(
variadic_func
,
EXPECT_EQ
(
variadic_func
,
OpKernelInfoHelper
::
GetVariadicKernelFn
(
op_kernel_info
));
OpKernelInfoHelper
::
GetVariadicKernelFn
(
op_kernel_info
));
...
...
paddle/pten/api/ext/op_kernel_info.h
浏览文件 @
63d2333e
...
@@ -30,6 +30,8 @@ limitations under the License. */
...
@@ -30,6 +30,8 @@ limitations under the License. */
#include "paddle/utils/any.h"
#include "paddle/utils/any.h"
#include "paddle/utils/small_vector.h"
#include "paddle/utils/small_vector.h"
#include "paddle/pten/common/data_type.h"
/**
/**
* Custom Kernel Info Define.
* Custom Kernel Info Define.
*
*
...
@@ -635,29 +637,624 @@ void RegisterAllCustomKernel();
...
@@ -635,29 +637,624 @@ void RegisterAllCustomKernel();
// register custom kernels
// register custom kernels
void
LoadCustomKernelLib
(
const
std
::
string
&
dso_name
);
void
LoadCustomKernelLib
(
const
std
::
string
&
dso_name
);
//////////////// Custom kernel register macro /////////////////
//////////////// Custom kernel register macro /////////////////////
// Refer to paddle/pten/core/kernel_registry.h, we can not use
// PT_REGISTER_KERNEL directly, common macros and functions are
// not ready for custom kernel now.
// Difference: custom_kernel stores all kernels' info into global
// g_custom_kernel_info_map before loading and registering into
// pten kernel management. Only providing PD_REGISTER_KERNEL which
// supports 2 template arguments.
#define PD_BACKEND(arg__) pten::Backend::arg__
#define PD_BACKEND(arg__) pten::Backend::arg__
#define PD_DATALAYOUT(arg__) pten::DataLayout::arg__
#define PD_DATALAYOUT(arg__) pten::DataLayout::arg__
#define PD_DATATYPE(arg__) pten::DataType::arg__
#define PD_DATATYPE(arg__) pten::DataType::arg__
#define PD_REGISTER_KERNEL(name, backend, layout, dtype, func) \
#define PD_NARGS(...) _PD_NARGS((__VA_ARGS__, _PD_RESQ_N()))
STATIC_ASSERT_GLOBAL_NAMESPACE( \
#define _PD_NARGS(...) _PD_ARG_N(__VA_ARGS__)
__reg_kernel__##name##_##backend##_##layout##_##dtype, \
#define _PD_ARG_N_EXPAND( \
"PD_REGISTER_KERNEL must be called in global namespace."); \
_1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, N, ...) \
void __PD_USER_args_def_##name##_##backend##_##layout_##dtype( \
N
::paddle::OpKernelInfo* op_kernel_info); \
#define _PD_ARG_N(args) _PD_ARG_N_EXPAND args
static ::paddle::OpKernelInfoBuilder \
#define _PD_RESQ_N() 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0
__op_kernel_info_##name##_##backend##_##layout##_##dtype = \
::paddle::OpKernelInfoBuilder(#name, \
#define PD_CONCATENATE(arg1, arg2) PD_CONCATENATE1(arg1, arg2)
PD_BACKEND(backend), \
#define PD_CONCATENATE1(arg1, arg2) PD_CONCATENATE2(arg1, arg2)
PD_DATALAYOUT(layout), \
#define PD_CONCATENATE2(arg1, arg2) arg1##arg2
PD_DATATYPE(dtype)) \
.SetKernelFn(PD_PT_KERNEL(func)) \
#define PD_EXPAND(x) x
.SetVariadicKernelFn(PD_PT_VARIADIC_KERNEL(func)) \
.ArgsParse(PD_PT_ARGS_PARSE(func)) \
#ifdef __COUNTER__
.ArgsDef( \
#define PD_ID __COUNTER__
&__PD_USER_args_def_##name##_##backend##_##layout_##dtype); \
#else
void __PD_USER_args_def_##name##_##backend##_##layout_##dtype( \
#define PD_ID __LINE__
#endif
#define PD_REGISTER_KERNEL(kernel_name, backend, layout, func, cpp_dtype, ...) \
STATIC_ASSERT_GLOBAL_NAMESPACE( \
_reg_custom_kernel_ns_check_##kernel_name##_##backend##_##layout, \
"PD_REGISTER_KERNEL must be called in global namespace."); \
_PD_REGISTER_2TA_KERNEL( \
kernel_name, backend, layout, func, cpp_dtype, ##__VA_ARGS__)
// WIN32 is not supported
#define _PD_REGISTER_2TA_KERNEL( \
kernel_name, backend, layout, meta_kernel_fn, cpp_dtype, ...) \
PD_KERNEL_INSTANTIATION(meta_kernel_fn, backend, cpp_dtype, ##__VA_ARGS__); \
static void __PD_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout( \
::paddle::OpKernelInfo* kernel); \
PD_KERNEL_REGISTRAR_INIT( \
kernel_name, \
backend, \
layout, \
&__PD_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout, \
meta_kernel_fn, \
cpp_dtype, \
##__VA_ARGS__); \
void __PD_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout( \
::paddle::OpKernelInfo* kernel)
::paddle::OpKernelInfo* kernel)
#define PD_KERNEL_INSTANTIATION(meta_kernel_fn, backend, cpp_dtype, ...) \
_PD_KERNEL_INSTANTIATION(PD_NARGS(cpp_dtype, ##__VA_ARGS__), \
meta_kernel_fn, \
backend, \
cpp_dtype, \
##__VA_ARGS__)
#define _PD_KERNEL_INSTANTIATION(N, meta_kernel_fn, backend, cpp_dtype, ...) \
PD_CONCATENATE(_PD_KERNEL_INSTANTIATION_, N) \
(meta_kernel_fn, backend, cpp_dtype, ##__VA_ARGS__)
#define _PD_KERNEL_INSTANTIATION_1(meta_kernel_fn, backend, cpp_dtype, ...) \
template decltype(meta_kernel_fn<cpp_dtype, ::paddle::backend##Context>) \
meta_kernel_fn<cpp_dtype, ::paddle::backend##Context>
#define _PD_KERNEL_INSTANTIATION_2(meta_kernel_fn, backend, cpp_dtype, ...) \
template decltype(meta_kernel_fn<cpp_dtype, ::paddle::backend##Context>) \
meta_kernel_fn<cpp_dtype, ::paddle::backend##Context>; \
PD_EXPAND(_PD_KERNEL_INSTANTIATION_1(meta_kernel_fn, backend, ##__VA_ARGS__))
#define _PD_KERNEL_INSTANTIATION_3(meta_kernel_fn, backend, cpp_dtype, ...) \
template decltype(meta_kernel_fn<cpp_dtype, ::paddle::backend##Context>) \
meta_kernel_fn<cpp_dtype, ::paddle::backend##Context>; \
PD_EXPAND(_PD_KERNEL_INSTANTIATION_2(meta_kernel_fn, backend, ##__VA_ARGS__))
#define _PD_KERNEL_INSTANTIATION_4(meta_kernel_fn, backend, cpp_dtype, ...) \
template decltype(meta_kernel_fn<cpp_dtype, ::paddle::backend##Context>) \
meta_kernel_fn<cpp_dtype, ::paddle::backend##Context>; \
PD_EXPAND(_PD_KERNEL_INSTANTIATION_3(meta_kernel_fn, backend, ##__VA_ARGS__))
#define _PD_KERNEL_INSTANTIATION_5(meta_kernel_fn, backend, cpp_dtype, ...) \
template decltype(meta_kernel_fn<cpp_dtype, ::paddle::backend##Context>) \
meta_kernel_fn<cpp_dtype, ::paddle::backend##Context>; \
PD_EXPAND(_PD_KERNEL_INSTANTIATION_4(meta_kernel_fn, backend, ##__VA_ARGS__))
#define _PD_KERNEL_INSTANTIATION_6(meta_kernel_fn, backend, cpp_dtype, ...) \
template decltype(meta_kernel_fn<cpp_dtype, ::paddle::backend##Context>) \
meta_kernel_fn<cpp_dtype, ::paddle::backend##Context>; \
PD_EXPAND(_PD_KERNEL_INSTANTIATION_5(meta_kernel_fn, backend, ##__VA_ARGS__))
#define _PD_KERNEL_INSTANTIATION_7(meta_kernel_fn, backend, cpp_dtype, ...) \
template decltype(meta_kernel_fn<cpp_dtype, ::paddle::backend##Context>) \
meta_kernel_fn<cpp_dtype, ::paddle::backend##Context>; \
PD_EXPAND(_PD_KERNEL_INSTANTIATION_6(meta_kernel_fn, backend, ##__VA_ARGS__))
#define _PD_KERNEL_INSTANTIATION_8(meta_kernel_fn, backend, cpp_dtype, ...) \
template decltype(meta_kernel_fn<cpp_dtype, ::paddle::backend##Context>) \
meta_kernel_fn<cpp_dtype, ::paddle::backend##Context>; \
PD_EXPAND(_PD_KERNEL_INSTANTIATION_7(meta_kernel_fn, backend, ##__VA_ARGS__))
#define _PD_KERNEL_INSTANTIATION_9(meta_kernel_fn, backend, cpp_dtype, ...) \
template decltype(meta_kernel_fn<cpp_dtype, ::paddle::backend##Context>) \
meta_kernel_fn<cpp_dtype, ::paddle::backend##Context>; \
PD_EXPAND(_PD_KERNEL_INSTANTIATION_8(meta_kernel_fn, backend, ##__VA_ARGS__))
#define _PD_KERNEL_INSTANTIATION_10(meta_kernel_fn, backend, cpp_dtype, ...) \
template decltype(meta_kernel_fn<cpp_dtype, ::paddle::backend##Context>) \
meta_kernel_fn<cpp_dtype, ::paddle::backend##Context>; \
PD_EXPAND(_PD_KERNEL_INSTANTIATION_9(meta_kernel_fn, backend, ##__VA_ARGS__))
#define _PD_KERNEL_INSTANTIATION_11(meta_kernel_fn, backend, cpp_dtype, ...) \
template decltype(meta_kernel_fn<cpp_dtype, ::paddle::backend##Context>) \
meta_kernel_fn<cpp_dtype, ::paddle::backend##Context>; \
PD_EXPAND(_PD_KERNEL_INSTANTIATION_10(meta_kernel_fn, backend, ##__VA_ARGS__))
#define _PD_KERNEL_INSTANTIATION_12(meta_kernel_fn, backend, cpp_dtype, ...) \
template decltype(meta_kernel_fn<cpp_dtype, ::paddle::backend##Context>) \
meta_kernel_fn<cpp_dtype, ::paddle::backend##Context>; \
PD_EXPAND(_PD_KERNEL_INSTANTIATION_11(meta_kernel_fn, backend, ##__VA_ARGS__))
#define _PD_KERNEL_INSTANTIATION_13(meta_kernel_fn, backend, cpp_dtype, ...) \
template decltype(meta_kernel_fn<cpp_dtype, ::paddle::backend##Context>) \
meta_kernel_fn<cpp_dtype, ::paddle::backend##Context>; \
PD_EXPAND(_PD_KERNEL_INSTANTIATION_12(meta_kernel_fn, backend, ##__VA_ARGS__))
#define _PD_KERNEL_INSTANTIATION_14(meta_kernel_fn, backend, cpp_dtype, ...) \
template decltype(meta_kernel_fn<cpp_dtype, ::paddle::backend##Context>) \
meta_kernel_fn<cpp_dtype, ::paddle::backend##Context>; \
PD_EXPAND(_PD_KERNEL_INSTANTIATION_13(meta_kernel_fn, backend, ##__VA_ARGS__))
#define _PD_KERNEL_INSTANTIATION_15(meta_kernel_fn, backend, cpp_dtype, ...) \
template decltype(meta_kernel_fn<cpp_dtype, ::paddle::backend##Context>) \
meta_kernel_fn<cpp_dtype, ::paddle::backend##Context>; \
PD_EXPAND(_PD_KERNEL_INSTANTIATION_14(meta_kernel_fn, backend, ##__VA_ARGS__))
#define PD_KERNEL_REGISTRAR_INIT( \
kernel_name, backend, layout, args_def_fn, meta_kernel_fn, cpp_dtype, ...) \
_PD_KERNEL_REGISTRAR_INIT(PD_NARGS(cpp_dtype, ##__VA_ARGS__), \
kernel_name, \
backend, \
layout, \
args_def_fn, \
meta_kernel_fn, \
cpp_dtype, \
##__VA_ARGS__)
// clang-format off
/* The =pre-commit always treats this macro into the wrong format,
and multi-line macros cannot be skipped with NOLINT.*/
#define _PD_KERNEL_REGISTRAR_INIT(N, \
kernel_name, \
backend, \
layout, \
args_def_fn, \
meta_kernel_fn, \
cpp_dtype, \
...) \
PD_CONCATENATE(_PD_KERNEL_REGISTRAR_INIT_, N) ( \
kernel_name, \
backend, \
layout, \
PD_ID, \
args_def_fn, \
meta_kernel_fn, \
cpp_dtype, \
##__VA_ARGS__)
// clang-format on
#define _PD_KERNEL_REGISTRAR_INIT_1(kernel_name, \
backend, \
layout, \
registrar_id, \
args_def_fn, \
meta_kernel_fn, \
cpp_dtype, \
...) \
static ::paddle::OpKernelInfoBuilder PD_CONCATENATE( \
custom_kernel_info_##kernel_name##_##backend##_##layout##_, \
registrar_id) = \
::paddle::OpKernelInfoBuilder( \
#kernel_name, \
PD_BACKEND(backend), \
PD_DATALAYOUT(layout), \
::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type()) \
.SetKernelFn(PD_PT_KERNEL( \
meta_kernel_fn<cpp_dtype, ::paddle::backend##Context>)) \
.SetVariadicKernelFn(PD_PT_VARIADIC_KERNEL( \
meta_kernel_fn<cpp_dtype, ::paddle::backend##Context>)) \
.ArgsParse(PD_PT_ARGS_PARSE( \
meta_kernel_fn<cpp_dtype, ::paddle::backend##Context>)) \
.ArgsDef(args_def_fn);
#define _PD_KERNEL_REGISTRAR_INIT_2(kernel_name, \
backend, \
layout, \
registrar_id, \
args_def_fn, \
meta_kernel_fn, \
cpp_dtype, \
...) \
static ::paddle::OpKernelInfoBuilder PD_CONCATENATE( \
custom_kernel_info_##kernel_name##_##backend##_##layout##_, \
registrar_id) = \
::paddle::OpKernelInfoBuilder( \
#kernel_name, \
PD_BACKEND(backend), \
PD_DATALAYOUT(layout), \
::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type()) \
.SetKernelFn(PD_PT_KERNEL( \
meta_kernel_fn<cpp_dtype, ::paddle::backend##Context>)) \
.SetVariadicKernelFn(PD_PT_VARIADIC_KERNEL( \
meta_kernel_fn<cpp_dtype, ::paddle::backend##Context>)) \
.ArgsParse(PD_PT_ARGS_PARSE( \
meta_kernel_fn<cpp_dtype, ::paddle::backend##Context>)) \
.ArgsDef(args_def_fn); \
PD_EXPAND(_PD_KERNEL_REGISTRAR_INIT_1(kernel_name, \
backend, \
layout, \
PD_ID, \
args_def_fn, \
meta_kernel_fn, \
##__VA_ARGS__))
#define _PD_KERNEL_REGISTRAR_INIT_3(kernel_name, \
backend, \
layout, \
registrar_id, \
args_def_fn, \
meta_kernel_fn, \
cpp_dtype, \
...) \
static ::paddle::OpKernelInfoBuilder PD_CONCATENATE( \
custom_kernel_info_##kernel_name##_##backend##_##layout##_, \
registrar_id) = \
::paddle::OpKernelInfoBuilder( \
#kernel_name, \
PD_BACKEND(backend), \
PD_DATALAYOUT(layout), \
::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type()) \
.SetKernelFn(PD_PT_KERNEL( \
meta_kernel_fn<cpp_dtype, ::paddle::backend##Context>)) \
.SetVariadicKernelFn(PD_PT_VARIADIC_KERNEL( \
meta_kernel_fn<cpp_dtype, ::paddle::backend##Context>)) \
.ArgsParse(PD_PT_ARGS_PARSE( \
meta_kernel_fn<cpp_dtype, ::paddle::backend##Context>)) \
.ArgsDef(args_def_fn); \
PD_EXPAND(_PD_KERNEL_REGISTRAR_INIT_2(kernel_name, \
backend, \
layout, \
PD_ID, \
args_def_fn, \
meta_kernel_fn, \
##__VA_ARGS__))
#define _PD_KERNEL_REGISTRAR_INIT_4(kernel_name, \
backend, \
layout, \
registrar_id, \
args_def_fn, \
meta_kernel_fn, \
cpp_dtype, \
...) \
static ::paddle::OpKernelInfoBuilder PD_CONCATENATE( \
custom_kernel_info_##kernel_name##_##backend##_##layout##_, \
registrar_id) = \
::paddle::OpKernelInfoBuilder( \
#kernel_name, \
PD_BACKEND(backend), \
PD_DATALAYOUT(layout), \
::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type()) \
.SetKernelFn(PD_PT_KERNEL( \
meta_kernel_fn<cpp_dtype, ::paddle::backend##Context>)) \
.SetVariadicKernelFn(PD_PT_VARIADIC_KERNEL( \
meta_kernel_fn<cpp_dtype, ::paddle::backend##Context>)) \
.ArgsParse(PD_PT_ARGS_PARSE( \
meta_kernel_fn<cpp_dtype, ::paddle::backend##Context>)) \
.ArgsDef(args_def_fn); \
PD_EXPAND(_PD_KERNEL_REGISTRAR_INIT_3(kernel_name, \
backend, \
layout, \
PD_ID, \
args_def_fn, \
meta_kernel_fn, \
##__VA_ARGS__))
#define _PD_KERNEL_REGISTRAR_INIT_5(kernel_name, \
backend, \
layout, \
registrar_id, \
args_def_fn, \
meta_kernel_fn, \
cpp_dtype, \
...) \
static ::paddle::OpKernelInfoBuilder PD_CONCATENATE( \
custom_kernel_info_##kernel_name##_##backend##_##layout##_, \
registrar_id) = \
::paddle::OpKernelInfoBuilder( \
#kernel_name, \
PD_BACKEND(backend), \
PD_DATALAYOUT(layout), \
::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type()) \
.SetKernelFn(PD_PT_KERNEL( \
meta_kernel_fn<cpp_dtype, ::paddle::backend##Context>)) \
.SetVariadicKernelFn(PD_PT_VARIADIC_KERNEL( \
meta_kernel_fn<cpp_dtype, ::paddle::backend##Context>)) \
.ArgsParse(PD_PT_ARGS_PARSE( \
meta_kernel_fn<cpp_dtype, ::paddle::backend##Context>)) \
.ArgsDef(args_def_fn); \
PD_EXPAND(_PD_KERNEL_REGISTRAR_INIT_4(kernel_name, \
backend, \
layout, \
PD_ID, \
args_def_fn, \
meta_kernel_fn, \
##__VA_ARGS__))
#define _PD_KERNEL_REGISTRAR_INIT_6(kernel_name, \
backend, \
layout, \
registrar_id, \
args_def_fn, \
meta_kernel_fn, \
cpp_dtype, \
...) \
static ::paddle::OpKernelInfoBuilder PD_CONCATENATE( \
custom_kernel_info_##kernel_name##_##backend##_##layout##_, \
registrar_id) = \
::paddle::OpKernelInfoBuilder( \
#kernel_name, \
PD_BACKEND(backend), \
PD_DATALAYOUT(layout), \
::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type()) \
.SetKernelFn(PD_PT_KERNEL( \
meta_kernel_fn<cpp_dtype, ::paddle::backend##Context>)) \
.SetVariadicKernelFn(PD_PT_VARIADIC_KERNEL( \
meta_kernel_fn<cpp_dtype, ::paddle::backend##Context>)) \
.ArgsParse(PD_PT_ARGS_PARSE( \
meta_kernel_fn<cpp_dtype, ::paddle::backend##Context>)) \
.ArgsDef(args_def_fn); \
PD_EXPAND(_PD_KERNEL_REGISTRAR_INIT_5(kernel_name, \
backend, \
layout, \
PD_ID, \
args_def_fn, \
meta_kernel_fn, \
##__VA_ARGS__))
#define _PD_KERNEL_REGISTRAR_INIT_7(kernel_name, \
backend, \
layout, \
registrar_id, \
args_def_fn, \
meta_kernel_fn, \
cpp_dtype, \
...) \
static ::paddle::OpKernelInfoBuilder PD_CONCATENATE( \
custom_kernel_info_##kernel_name##_##backend##_##layout##_, \
registrar_id) = \
::paddle::OpKernelInfoBuilder( \
#kernel_name, \
PD_BACKEND(backend), \
PD_DATALAYOUT(layout), \
::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type()) \
.SetKernelFn(PD_PT_KERNEL( \
meta_kernel_fn<cpp_dtype, ::paddle::backend##Context>)) \
.SetVariadicKernelFn(PD_PT_VARIADIC_KERNEL( \
meta_kernel_fn<cpp_dtype, ::paddle::backend##Context>)) \
.ArgsParse(PD_PT_ARGS_PARSE( \
meta_kernel_fn<cpp_dtype, ::paddle::backend##Context>)) \
.ArgsDef(args_def_fn); \
PD_EXPAND(_PD_KERNEL_REGISTRAR_INIT_6(kernel_name, \
backend, \
layout, \
PD_ID, \
args_def_fn, \
meta_kernel_fn, \
##__VA_ARGS__))
#define _PD_KERNEL_REGISTRAR_INIT_8(kernel_name, \
backend, \
layout, \
registrar_id, \
args_def_fn, \
meta_kernel_fn, \
cpp_dtype, \
...) \
static ::paddle::OpKernelInfoBuilder PD_CONCATENATE( \
custom_kernel_info_##kernel_name##_##backend##_##layout##_, \
registrar_id) = \
::paddle::OpKernelInfoBuilder( \
#kernel_name, \
PD_BACKEND(backend), \
PD_DATALAYOUT(layout), \
::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type()) \
.SetKernelFn(PD_PT_KERNEL( \
meta_kernel_fn<cpp_dtype, ::paddle::backend##Context>)) \
.SetVariadicKernelFn(PD_PT_VARIADIC_KERNEL( \
meta_kernel_fn<cpp_dtype, ::paddle::backend##Context>)) \
.ArgsParse(PD_PT_ARGS_PARSE( \
meta_kernel_fn<cpp_dtype, ::paddle::backend##Context>)) \
.ArgsDef(args_def_fn); \
PD_EXPAND(_PD_KERNEL_REGISTRAR_INIT_7(kernel_name, \
backend, \
layout, \
PD_ID, \
args_def_fn, \
meta_kernel_fn, \
##__VA_ARGS__))
#define _PD_KERNEL_REGISTRAR_INIT_9(kernel_name, \
backend, \
layout, \
registrar_id, \
args_def_fn, \
meta_kernel_fn, \
cpp_dtype, \
...) \
static ::paddle::OpKernelInfoBuilder PD_CONCATENATE( \
custom_kernel_info_##kernel_name##_##backend##_##layout##_, \
registrar_id) = \
::paddle::OpKernelInfoBuilder( \
#kernel_name, \
PD_BACKEND(backend), \
PD_DATALAYOUT(layout), \
::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type()) \
.SetKernelFn(PD_PT_KERNEL( \
meta_kernel_fn<cpp_dtype, ::paddle::backend##Context>)) \
.SetVariadicKernelFn(PD_PT_VARIADIC_KERNEL( \
meta_kernel_fn<cpp_dtype, ::paddle::backend##Context>)) \
.ArgsParse(PD_PT_ARGS_PARSE( \
meta_kernel_fn<cpp_dtype, ::paddle::backend##Context>)) \
.ArgsDef(args_def_fn); \
PD_EXPAND(_PD_KERNEL_REGISTRAR_INIT_8(kernel_name, \
backend, \
layout, \
PD_ID, \
args_def_fn, \
meta_kernel_fn, \
##__VA_ARGS__))
#define _PD_KERNEL_REGISTRAR_INIT_10(kernel_name, \
backend, \
layout, \
registrar_id, \
args_def_fn, \
meta_kernel_fn, \
cpp_dtype, \
...) \
static ::paddle::OpKernelInfoBuilder PD_CONCATENATE( \
custom_kernel_info_##kernel_name##_##backend##_##layout##_, \
registrar_id) = \
::paddle::OpKernelInfoBuilder( \
#kernel_name, \
PD_BACKEND(backend), \
PD_DATALAYOUT(layout), \
::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type()) \
.SetKernelFn(PD_PT_KERNEL( \
meta_kernel_fn<cpp_dtype, ::paddle::backend##Context>)) \
.SetVariadicKernelFn(PD_PT_VARIADIC_KERNEL( \
meta_kernel_fn<cpp_dtype, ::paddle::backend##Context>)) \
.ArgsParse(PD_PT_ARGS_PARSE( \
meta_kernel_fn<cpp_dtype, ::paddle::backend##Context>)) \
.ArgsDef(args_def_fn); \
PD_EXPAND(_PD_KERNEL_REGISTRAR_INIT_9(kernel_name, \
backend, \
layout, \
PD_ID, \
args_def_fn, \
meta_kernel_fn, \
##__VA_ARGS__))
#define _PD_KERNEL_REGISTRAR_INIT_11(kernel_name, \
backend, \
layout, \
registrar_id, \
args_def_fn, \
meta_kernel_fn, \
cpp_dtype, \
...) \
static ::paddle::OpKernelInfoBuilder PD_CONCATENATE( \
custom_kernel_info_##kernel_name##_##backend##_##layout##_, \
registrar_id) = \
::paddle::OpKernelInfoBuilder( \
#kernel_name, \
PD_BACKEND(backend), \
PD_DATALAYOUT(layout), \
::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type()) \
.SetKernelFn(PD_PT_KERNEL( \
meta_kernel_fn<cpp_dtype, ::paddle::backend##Context>)) \
.SetVariadicKernelFn(PD_PT_VARIADIC_KERNEL( \
meta_kernel_fn<cpp_dtype, ::paddle::backend##Context>)) \
.ArgsParse(PD_PT_ARGS_PARSE( \
meta_kernel_fn<cpp_dtype, ::paddle::backend##Context>)) \
.ArgsDef(args_def_fn); \
PD_EXPAND(_PD_KERNEL_REGISTRAR_INIT_10(kernel_name, \
backend, \
layout, \
PD_ID, \
args_def_fn, \
meta_kernel_fn, \
##__VA_ARGS__))
#define _PD_KERNEL_REGISTRAR_INIT_12(kernel_name, \
backend, \
layout, \
registrar_id, \
args_def_fn, \
meta_kernel_fn, \
cpp_dtype, \
...) \
static ::paddle::OpKernelInfoBuilder PD_CONCATENATE( \
custom_kernel_info_##kernel_name##_##backend##_##layout##_, \
registrar_id) = \
::paddle::OpKernelInfoBuilder( \
#kernel_name, \
PD_BACKEND(backend), \
PD_DATALAYOUT(layout), \
::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type()) \
.SetKernelFn(PD_PT_KERNEL( \
meta_kernel_fn<cpp_dtype, ::paddle::backend##Context>)) \
.SetVariadicKernelFn(PD_PT_VARIADIC_KERNEL( \
meta_kernel_fn<cpp_dtype, ::paddle::backend##Context>)) \
.ArgsParse(PD_PT_ARGS_PARSE( \
meta_kernel_fn<cpp_dtype, ::paddle::backend##Context>)) \
.ArgsDef(args_def_fn); \
PD_EXPAND(_PD_KERNEL_REGISTRAR_INIT_11(kernel_name, \
backend, \
layout, \
PD_ID, \
args_def_fn, \
meta_kernel_fn, \
##__VA_ARGS__))
#define _PD_KERNEL_REGISTRAR_INIT_13(kernel_name, \
backend, \
layout, \
registrar_id, \
args_def_fn, \
meta_kernel_fn, \
cpp_dtype, \
...) \
static ::paddle::OpKernelInfoBuilder PD_CONCATENATE( \
custom_kernel_info_##kernel_name##_##backend##_##layout##_, \
registrar_id) = \
::paddle::OpKernelInfoBuilder( \
#kernel_name, \
PD_BACKEND(backend), \
PD_DATALAYOUT(layout), \
::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type()) \
.SetKernelFn(PD_PT_KERNEL( \
meta_kernel_fn<cpp_dtype, ::paddle::backend##Context>)) \
.SetVariadicKernelFn(PD_PT_VARIADIC_KERNEL( \
meta_kernel_fn<cpp_dtype, ::paddle::backend##Context>)) \
.ArgsParse(PD_PT_ARGS_PARSE( \
meta_kernel_fn<cpp_dtype, ::paddle::backend##Context>)) \
.ArgsDef(args_def_fn); \
PD_EXPAND(_PD_KERNEL_REGISTRAR_INIT_12(kernel_name, \
backend, \
layout, \
PD_ID, \
args_def_fn, \
meta_kernel_fn, \
##__VA_ARGS__))
#define _PD_KERNEL_REGISTRAR_INIT_14(kernel_name, \
backend, \
layout, \
registrar_id, \
args_def_fn, \
meta_kernel_fn, \
cpp_dtype, \
...) \
static ::paddle::OpKernelInfoBuilder PD_CONCATENATE( \
custom_kernel_info_##kernel_name##_##backend##_##layout##_, \
registrar_id) = \
::paddle::OpKernelInfoBuilder( \
#kernel_name, \
PD_BACKEND(backend), \
PD_DATALAYOUT(layout), \
::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type()) \
.SetKernelFn(PD_PT_KERNEL( \
meta_kernel_fn<cpp_dtype, ::paddle::backend##Context>)) \
.SetVariadicKernelFn(PD_PT_VARIADIC_KERNEL( \
meta_kernel_fn<cpp_dtype, ::paddle::backend##Context>)) \
.ArgsParse(PD_PT_ARGS_PARSE( \
meta_kernel_fn<cpp_dtype, ::paddle::backend##Context>)) \
.ArgsDef(args_def_fn); \
PD_EXPAND(_PD_KERNEL_REGISTRAR_INIT_13(kernel_name, \
backend, \
layout, \
PD_ID, \
args_def_fn, \
meta_kernel_fn, \
##__VA_ARGS__))
#define _PD_KERNEL_REGISTRAR_INIT_15(kernel_name, \
backend, \
layout, \
registrar_id, \
args_def_fn, \
meta_kernel_fn, \
cpp_dtype, \
...) \
static ::paddle::OpKernelInfoBuilder PD_CONCATENATE( \
custom_kernel_info_##kernel_name##_##backend##_##layout##_, \
registrar_id) = \
::paddle::OpKernelInfoBuilder( \
#kernel_name, \
PD_BACKEND(backend), \
PD_DATALAYOUT(layout), \
::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type()) \
.SetKernelFn(PD_PT_KERNEL( \
meta_kernel_fn<cpp_dtype, ::paddle::backend##Context>)) \
.SetVariadicKernelFn(PD_PT_VARIADIC_KERNEL( \
meta_kernel_fn<cpp_dtype, ::paddle::backend##Context>)) \
.ArgsParse(PD_PT_ARGS_PARSE( \
meta_kernel_fn<cpp_dtype, ::paddle::backend##Context>)) \
.ArgsDef(args_def_fn); \
PD_EXPAND(_PD_KERNEL_REGISTRAR_INIT_14(kernel_name, \
backend, \
layout, \
PD_ID, \
args_def_fn, \
meta_kernel_fn, \
##__VA_ARGS__))
}
// namespace paddle
}
// namespace paddle
python/paddle/fluid/tests/custom_kernel/custom_kernel_dot.cc
浏览文件 @
63d2333e
...
@@ -20,8 +20,8 @@ namespace custom_kernel {
...
@@ -20,8 +20,8 @@ namespace custom_kernel {
// Here we use dot <CPU, ANY, INT8> for test
// Here we use dot <CPU, ANY, INT8> for test
// This test will fail when this kernel is supported in framework
// This test will fail when this kernel is supported in framework
template
<
typename
T
>
template
<
typename
T
,
typename
Context
>
void
Dot
(
const
paddle
::
CPU
Context
&
dev_ctx
,
void
Dot
(
const
Context
&
dev_ctx
,
const
paddle
::
Tensor
&
x
,
const
paddle
::
Tensor
&
x
,
const
paddle
::
Tensor
&
y
,
const
paddle
::
Tensor
&
y
,
paddle
::
Tensor
*
out
)
{
paddle
::
Tensor
*
out
)
{
...
@@ -45,9 +45,6 @@ void Dot(const paddle::CPUContext& dev_ctx,
...
@@ -45,9 +45,6 @@ void Dot(const paddle::CPUContext& dev_ctx,
}
// namespace custom_kernel
}
// namespace custom_kernel
}
// namespace paddle
}
// namespace paddle
PD_REGISTER_KERNEL
(
PD_REGISTER_KERNEL
(
dot
,
CPU
,
ALL_LAYOUT
,
paddle
::
custom_kernel
::
Dot
,
int8_t
)
{
dot
,
CPU
,
ALL_LAYOUT
,
INT8
,
paddle
::
custom_kernel
::
Dot
<
int8_t
>
)
{
/* do some args define here
* the only param can be used is OpKernelInfo* kernel */
kernel
->
OutputAt
(
0
).
SetDataType
(
paddle
::
experimental
::
DataType
::
INT8
);
kernel
->
OutputAt
(
0
).
SetDataType
(
paddle
::
experimental
::
DataType
::
INT8
);
}
}
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录