Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
5fb9cf60
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
未验证
提交
5fb9cf60
编写于
2月 17, 2022
作者:
C
Chen Weihang
提交者:
GitHub
2月 17, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
support set fp32 input for fp16 kernel (#39625)
上级
d63ece1f
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
49 addition
and
10 deletion
+49
-10
paddle/pten/core/kernel_registry.h
paddle/pten/core/kernel_registry.h
+9
-9
paddle/pten/core/type_defs.h
paddle/pten/core/type_defs.h
+1
-1
paddle/pten/tests/core/test_kernel_factory.cc
paddle/pten/tests/core/test_kernel_factory.cc
+39
-0
未找到文件。
paddle/pten/core/kernel_registry.h
浏览文件 @
5fb9cf60
...
...
@@ -184,7 +184,7 @@ struct KernelRegistrar {
KernelKey
kernel_key
(
backend
,
layout
,
dtype
);
Kernel
kernel
(
kernel_fn
,
variadic_kernel_fn
);
args_parse_fn
(
kernel_key
,
kernel
.
mutable_args_def
());
args_def_fn
(
&
kernel
);
args_def_fn
(
kernel_key
,
&
kernel
);
KernelFactory
::
Instance
().
kernels
()[
kernel_name
][
kernel_key
]
=
kernel
;
}
};
...
...
@@ -231,7 +231,7 @@ struct KernelRegistrar {
kernel_name, backend, layout, meta_kernel_fn, ...) \
PT_KERNEL_INSTANTIATION(meta_kernel_fn, backend, __VA_ARGS__); \
static void __PT_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout( \
::pten::Kernel*);
\
const ::pten::KernelKey& kernel_key, ::pten::Kernel* kernel);
\
PT_KERNEL_REGISTRAR_INIT( \
kernel_name, \
backend, \
...
...
@@ -240,7 +240,7 @@ struct KernelRegistrar {
meta_kernel_fn, \
__VA_ARGS__); \
void __PT_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout( \
::pten::Kernel* kernel)
const ::pten::KernelKey& kernel_key,
::pten::Kernel* kernel)
#else
/**
* `template decltype(fn) fn` can work on gcc and clang,
...
...
@@ -257,7 +257,7 @@ struct KernelRegistrar {
#define _PT_REGISTER_2TA_KERNEL( \
kernel_name, backend, layout, meta_kernel_fn, ...) \
static void __PT_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout( \
::pten::Kernel*);
\
const ::pten::KernelKey& kernel_key, ::pten::Kernel* kernel);
\
PT_EXPAND(PT_KERNEL_REGISTRAR_INIT( \
kernel_name, \
backend, \
...
...
@@ -266,7 +266,7 @@ struct KernelRegistrar {
meta_kernel_fn, \
__VA_ARGS__)); \
void __PT_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout( \
::pten::Kernel* kernel)
const ::pten::KernelKey& kernel_key,
::pten::Kernel* kernel)
#endif
#define PT_KERNEL_INSTANTIATION(meta_kernel_fn, backend, ...) \
...
...
@@ -786,7 +786,7 @@ struct KernelRegistrar {
kernel_name, backend, layout, kernel_fn, dtype) \
template decltype(kernel_fn) kernel_fn; \
static void __PT_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout( \
::pten::Kernel*);
\
const ::pten::KernelKey& kernel_key, ::pten::Kernel* kernel);
\
static const ::pten::KernelRegistrar \
__reg_pt_kernel_##kernel_name##_##backend##_##layout( \
#kernel_name, \
...
...
@@ -800,12 +800,12 @@ struct KernelRegistrar {
return 0; \
} \
void __PT_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout( \
::pten::Kernel* kernel)
const ::pten::KernelKey& kernel_key,
::pten::Kernel* kernel)
#else
#define _PT_REGISTER_GENERAL_KERNEL( \
kernel_name, backend, layout, kernel_fn, dtype) \
static void __PT_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout( \
::pten::Kernel*);
\
const ::pten::KernelKey& kernel_key, ::pten::Kernel* kernel);
\
static const ::pten::KernelRegistrar \
__reg_pt_kernel_##kernel_name##_##backend##_##layout( \
#kernel_name, \
...
...
@@ -819,7 +819,7 @@ struct KernelRegistrar {
return 0; \
} \
void __PT_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout( \
::pten::Kernel* kernel)
const ::pten::KernelKey& kernel_key,
::pten::Kernel* kernel)
#endif
/** PT_DECLARE_KERNEL
...
...
paddle/pten/core/type_defs.h
浏览文件 @
5fb9cf60
...
...
@@ -27,7 +27,7 @@ class ArgumentMappingContext;
class
InferMetaContext
;
using
KernelFn
=
std
::
function
<
void
(
KernelContext
*
ctx
)
>
;
using
KernelArgsDefFn
=
void
(
*
)(
Kernel
*
kernel
);
using
KernelArgsDefFn
=
void
(
*
)(
const
KernelKey
&
kernel_key
,
Kernel
*
kernel
);
using
KernelArgsParseFn
=
void
(
*
)(
const
KernelKey
&
default_key
,
KernelArgsDef
*
args_def
);
...
...
paddle/pten/tests/core/test_kernel_factory.cc
浏览文件 @
5fb9cf60
...
...
@@ -15,6 +15,8 @@ limitations under the License. */
#include <iostream>
#include <sstream>
#include "paddle/pten/common/float16.h"
#include "paddle/pten/core/dense_tensor.h"
#include "paddle/pten/core/kernel_factory.h"
#include "paddle/pten/core/kernel_registry.h"
...
...
@@ -47,5 +49,42 @@ TEST(KernelFactory, SelectedKernelMap) {
}
}
template
<
typename
T
,
typename
Context
>
void
TestKernel
(
const
Context
&
dev_ctx
,
const
DenseTensor
&
x
,
const
DenseTensor
&
param
,
DenseTensor
*
out
)
{}
TEST
(
KernelRegistry
,
SetFP32Input
)
{
pten
::
KernelKey
kernel_key
(
pten
::
Backend
::
CPU
,
pten
::
DataLayout
::
ALL_LAYOUT
,
pten
::
DataType
::
FLOAT16
);
auto
test_kernel
=
pten
::
KernelFactory
::
Instance
().
SelectKernel
(
"test"
,
kernel_key
);
EXPECT_TRUE
(
test_kernel
.
IsValid
());
auto
&
arg_defs
=
test_kernel
.
args_def
();
auto
&
input_defs
=
arg_defs
.
input_defs
();
auto
&
attr_defs
=
arg_defs
.
attribute_defs
();
auto
&
output_defs
=
arg_defs
.
output_defs
();
EXPECT_EQ
(
input_defs
.
size
(),
2UL
);
EXPECT_EQ
(
attr_defs
.
size
(),
0UL
);
EXPECT_EQ
(
output_defs
.
size
(),
1UL
);
EXPECT_EQ
(
input_defs
.
at
(
0
).
dtype
,
pten
::
DataType
::
FLOAT16
);
EXPECT_EQ
(
input_defs
.
at
(
1
).
dtype
,
pten
::
DataType
::
FLOAT32
);
EXPECT_EQ
(
output_defs
.
at
(
0
).
dtype
,
pten
::
DataType
::
FLOAT16
);
}
}
// namespace tests
}
// namespace pten
PT_REGISTER_KERNEL
(
test
,
CPU
,
ALL_LAYOUT
,
pten
::
tests
::
TestKernel
,
float
,
double
,
pten
::
dtype
::
float16
)
{
if
(
kernel_key
.
dtype
()
==
pten
::
DataType
::
FLOAT16
)
{
kernel
->
InputAt
(
1
).
SetDataType
(
pten
::
DataType
::
FLOAT32
);
}
}
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录