Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
63d2333e
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
63d2333e
编写于
2月 10, 2022
作者:
A
Aganlengzi
提交者:
GitHub
2月 10, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[PluggableDevice] custom kernel supports multi cpp_dtype registering (#39385)
上级
2a5d858c
变更
3
展开全部
隐藏空白更改
内联
并排
Showing
3 changed file
with
695 addition
and
63 deletion
+695
-63
paddle/fluid/framework/custom_kernel_test.cc
paddle/fluid/framework/custom_kernel_test.cc
+76
-38
paddle/pten/api/ext/op_kernel_info.h
paddle/pten/api/ext/op_kernel_info.h
+616
-19
python/paddle/fluid/tests/custom_kernel/custom_kernel_dot.cc
python/paddle/fluid/tests/custom_kernel/custom_kernel_dot.cc
+3
-6
未找到文件。
paddle/fluid/framework/custom_kernel_test.cc
浏览文件 @
63d2333e
...
...
@@ -35,13 +35,12 @@ limitations under the License. */
// user kernel function
namespace
custom_kernel
{
// Here we use dot <CPU, ANY, UINT8> for test
// This test will fail when these two kernels are aupported in framework
// Here we use fake_dot for test
// input 3: two Tensors and one std::vector<Tensor>
// attribute 11: fake_attributes
// output 2: one Tensor* and one std::vector<Tensor*>
template
<
typename
T
>
void
FakeDot
(
const
paddle
::
CPU
Context
&
dev_ctx
,
const
paddle
::
Tensor
&
x
,
template
<
typename
T
,
typename
Context
>
void
FakeDot
(
const
Context
&
dev_ctx
,
const
paddle
::
Tensor
&
x
,
const
paddle
::
Tensor
&
y
,
const
std
::
vector
<
paddle
::
Tensor
>&
fake_input_vec
,
bool
fake_attr_bool
,
int
fake_attr_int
,
float
fake_attr_float
,
...
...
@@ -93,53 +92,91 @@ void FakeDot(const paddle::CPUContext& dev_ctx, const paddle::Tensor& x,
}
}
// namespace custom_kernel
PD_REGISTER_KERNEL
(
dot
,
CPU
,
ALL_LAYOUT
,
UINT8
,
custom_kernel
::
FakeDot
<
uint8_t
>
)
{
/* do some args define here
* the only param can be used is OpKernelInfo* kernel */
kernel
->
OutputAt
(
0
).
SetDataType
(
paddle
::
experimental
::
DataType
::
UINT8
);
}
PD_REGISTER_KERNEL
(
fake_dot
,
CPU
,
ALL_LAYOUT
,
custom_kernel
::
FakeDot
,
float
,
double
,
int
,
int64_t
,
int8_t
,
uint8_t
)
{}
// Upper code will store dot kernels info into OpKernelInfoMap
TEST
(
CustomKernel
,
custom_kernel_dot
)
{
std
::
string
op_name
=
"dot"
;
std
::
string
op_name
=
"
fake_
dot"
;
pten
::
Backend
backend
=
pten
::
Backend
::
CPU
;
pten
::
DataLayout
layout
=
pten
::
DataLayout
::
ANY
;
pten
::
DataType
dtype
=
pten
::
DataType
::
UINT8
;
pten
::
DataLayout
layout
=
pten
::
DataLayout
::
ALL_LAYOUT
;
// 1.custom kernel info parsed and store
EXPECT_TRUE
(
paddle
::
OpKernelInfoMap
::
Instance
().
GetMap
().
find
(
"dot"
)
!=
EXPECT_TRUE
(
paddle
::
OpKernelInfoMap
::
Instance
().
GetMap
().
find
(
op_name
)
!=
paddle
::
OpKernelInfoMap
::
Instance
().
GetMap
().
end
());
// 2.info check
EXPECT_EQ
(
1
,
static_cast
<
int
>
(
paddle
::
OpKernelInfoMap
::
Instance
()[
"dot"
].
size
()));
EXPECT_TRUE
(
paddle
::
OpKernelInfoMap
::
Instance
()[
"dot"
][
0
].
GetBackend
()
==
6
,
static_cast
<
int
>
(
paddle
::
OpKernelInfoMap
::
Instance
()[
op_name
].
size
()));
// index 0
EXPECT_TRUE
(
paddle
::
OpKernelInfoMap
::
Instance
()[
op_name
][
0
].
GetBackend
()
==
backend
);
EXPECT_TRUE
(
paddle
::
OpKernelInfoMap
::
Instance
()[
"dot"
][
0
].
GetDataLayout
()
==
EXPECT_TRUE
(
paddle
::
OpKernelInfoMap
::
Instance
()[
op_name
][
0
].
GetDataLayout
()
==
layout
);
EXPECT_TRUE
(
paddle
::
OpKernelInfoMap
::
Instance
()[
"dot"
][
0
].
GetDataType
()
==
dtype
);
// 3.register
EXPECT_TRUE
(
pten
::
KernelFactory
::
Instance
().
kernels
().
end
()
!=
pten
::
KernelFactory
::
Instance
().
kernels
().
find
(
"dot"
));
pten
::
KernelKey
kernel_key
(
backend
,
layout
,
dtype
);
EXPECT_TRUE
(
pten
::
KernelFactory
::
Instance
().
kernels
()[
"dot"
].
find
(
kernel_key
)
==
pten
::
KernelFactory
::
Instance
().
kernels
()[
"dot"
].
end
());
EXPECT_TRUE
(
paddle
::
OpKernelInfoMap
::
Instance
()[
op_name
][
0
].
GetDataType
()
==
pten
::
DataType
::
FLOAT32
);
// index 5
EXPECT_TRUE
(
paddle
::
OpKernelInfoMap
::
Instance
()[
op_name
][
5
].
GetBackend
()
==
backend
);
EXPECT_TRUE
(
paddle
::
OpKernelInfoMap
::
Instance
()[
op_name
][
5
].
GetDataLayout
()
==
layout
);
EXPECT_TRUE
(
paddle
::
OpKernelInfoMap
::
Instance
()[
op_name
][
5
].
GetDataType
()
==
pten
::
DataType
::
UINT8
);
// 3.before register
auto
&
kernel_factory_instance
=
pten
::
KernelFactory
::
Instance
();
auto
&
kernels
=
pten
::
KernelFactory
::
Instance
().
kernels
();
EXPECT_TRUE
(
!
kernel_factory_instance
.
HasCompatiblePtenKernel
(
op_name
));
// mock fake_dot is supported by pten for HasCompatiblePtenKernel check while
// registering
auto
&
fake_dot_kernels
=
kernels
[
op_name
];
EXPECT_TRUE
(
fake_dot_kernels
.
find
(
pten
::
KernelKey
(
backend
,
layout
,
pten
::
DataType
::
FLOAT32
))
==
fake_dot_kernels
.
end
());
EXPECT_TRUE
(
fake_dot_kernels
.
find
(
pten
::
KernelKey
(
backend
,
layout
,
pten
::
DataType
::
FLOAT64
))
==
fake_dot_kernels
.
end
());
EXPECT_TRUE
(
fake_dot_kernels
.
find
(
pten
::
KernelKey
(
backend
,
layout
,
pten
::
DataType
::
INT32
))
==
fake_dot_kernels
.
end
());
EXPECT_TRUE
(
fake_dot_kernels
.
find
(
pten
::
KernelKey
(
backend
,
layout
,
pten
::
DataType
::
INT64
))
==
fake_dot_kernels
.
end
());
EXPECT_TRUE
(
fake_dot_kernels
.
find
(
pten
::
KernelKey
(
backend
,
layout
,
pten
::
DataType
::
INT8
))
==
fake_dot_kernels
.
end
());
EXPECT_TRUE
(
fake_dot_kernels
.
find
(
pten
::
KernelKey
(
backend
,
layout
,
pten
::
DataType
::
UINT8
))
==
fake_dot_kernels
.
end
());
// register
paddle
::
framework
::
RegisterKernelWithMetaInfoMap
(
paddle
::
OpKernelInfoMap
::
Instance
());
EXPECT_TRUE
(
pten
::
KernelFactory
::
Instance
().
kernels
()[
"dot"
].
find
(
kernel_key
)
!=
pten
::
KernelFactory
::
Instance
().
kernels
()[
"dot"
].
end
());
EXPECT_TRUE
(
fake_dot_kernels
.
find
(
pten
::
KernelKey
(
backend
,
layout
,
pten
::
DataType
::
FLOAT32
))
!=
fake_dot_kernels
.
end
());
EXPECT_TRUE
(
fake_dot_kernels
.
find
(
pten
::
KernelKey
(
backend
,
layout
,
pten
::
DataType
::
FLOAT64
))
!=
fake_dot_kernels
.
end
());
EXPECT_TRUE
(
fake_dot_kernels
.
find
(
pten
::
KernelKey
(
backend
,
layout
,
pten
::
DataType
::
INT32
))
!=
fake_dot_kernels
.
end
());
EXPECT_TRUE
(
fake_dot_kernels
.
find
(
pten
::
KernelKey
(
backend
,
layout
,
pten
::
DataType
::
INT64
))
!=
fake_dot_kernels
.
end
());
EXPECT_TRUE
(
fake_dot_kernels
.
find
(
pten
::
KernelKey
(
backend
,
layout
,
pten
::
DataType
::
INT8
))
!=
fake_dot_kernels
.
end
());
EXPECT_TRUE
(
fake_dot_kernels
.
find
(
pten
::
KernelKey
(
backend
,
layout
,
pten
::
DataType
::
UINT8
))
!=
fake_dot_kernels
.
end
());
// 4.kernel select
auto
kernel
=
pten
::
KernelFactory
::
Instance
()
.
SelectKernelOrThrowError
(
op_name
,
kernel_key
);
auto
kernel
=
kernel_factory_instance
.
SelectKernelOrThrowError
(
op_name
,
pten
::
KernelKey
(
backend
,
layout
,
pten
::
DataType
::
UINT8
)
);
// 5.prepare parameters for kernel
const
auto
alloc
=
std
::
make_unique
<
paddle
::
experimental
::
DefaultAllocator
>
(
...
...
@@ -252,10 +289,10 @@ TEST(CustomKernel, custom_kernel_dot) {
// test OpKernelInfoHelper
TEST
(
OpKernelInfoHelper
,
op_kernel_info_help_getters
)
{
using
OpKernelInfoHelper
=
paddle
::
framework
::
OpKernelInfoHelper
;
std
::
string
op_name
=
"dot"
;
std
::
string
op_name
=
"
fake_
dot"
;
pten
::
Backend
backend
=
pten
::
Backend
::
CPU
;
pten
::
DataLayout
layout
=
pten
::
DataLayout
::
ANY
;
pten
::
DataType
dtype
=
pten
::
DataType
::
UINT8
;
pten
::
DataType
dtype
=
pten
::
DataType
::
FLOAT32
;
auto
op_kernel_info
=
paddle
::
OpKernelInfoMap
::
Instance
()[
op_name
][
0
];
...
...
@@ -268,10 +305,11 @@ TEST(OpKernelInfoHelper, op_kernel_info_help_getters) {
OpKernelInfoHelper
::
GetKernelKey
(
op_kernel_info
));
paddle
::
CustomKernelFunc
kernel_fn
=
PD_PT_KERNEL
(
custom_kernel
::
FakeDot
<
uint8_
t
>
);
PD_PT_KERNEL
(
custom_kernel
::
FakeDot
<
float
,
paddle
::
CPUContex
t
>
);
EXPECT_EQ
(
kernel_fn
,
OpKernelInfoHelper
::
GetKernelFn
(
op_kernel_info
));
void
*
variadic_func
=
PD_PT_VARIADIC_KERNEL
(
custom_kernel
::
FakeDot
<
uint8_t
>
);
void
*
variadic_func
=
PD_PT_VARIADIC_KERNEL
(
custom_kernel
::
FakeDot
<
float
,
paddle
::
CPUContext
>
);
EXPECT_EQ
(
variadic_func
,
OpKernelInfoHelper
::
GetVariadicKernelFn
(
op_kernel_info
));
...
...
paddle/pten/api/ext/op_kernel_info.h
浏览文件 @
63d2333e
此差异已折叠。
点击以展开。
python/paddle/fluid/tests/custom_kernel/custom_kernel_dot.cc
浏览文件 @
63d2333e
...
...
@@ -20,8 +20,8 @@ namespace custom_kernel {
// Here we use dot <CPU, ANY, INT8> for test
// This test will fail when this kernel is supported in framework
template
<
typename
T
>
void
Dot
(
const
paddle
::
CPU
Context
&
dev_ctx
,
template
<
typename
T
,
typename
Context
>
void
Dot
(
const
Context
&
dev_ctx
,
const
paddle
::
Tensor
&
x
,
const
paddle
::
Tensor
&
y
,
paddle
::
Tensor
*
out
)
{
...
...
@@ -45,9 +45,6 @@ void Dot(const paddle::CPUContext& dev_ctx,
}
// namespace custom_kernel
}
// namespace paddle
PD_REGISTER_KERNEL
(
dot
,
CPU
,
ALL_LAYOUT
,
INT8
,
paddle
::
custom_kernel
::
Dot
<
int8_t
>
)
{
/* do some args define here
* the only param can be used is OpKernelInfo* kernel */
PD_REGISTER_KERNEL
(
dot
,
CPU
,
ALL_LAYOUT
,
paddle
::
custom_kernel
::
Dot
,
int8_t
)
{
kernel
->
OutputAt
(
0
).
SetDataType
(
paddle
::
experimental
::
DataType
::
INT8
);
}
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录