Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
942ff89f
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 2 年 前同步成功
通知
2325
Star
20933
Fork
5424
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
942ff89f
编写于
8月 02, 2022
作者:
W
Weilong Wu
提交者:
GitHub
8月 02, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[Phi] polish and rename, pt* -> phi* (#44697)
* polish and rename, pt* -> phi* * fix code format
上级
d985b4b1
变更
10
隐藏空白更改
内联
并排
Showing
10 changed file
with
844 addition
and
840 deletion
+844
-840
paddle/fluid/framework/new_executor/interpretercore.cc
paddle/fluid/framework/new_executor/interpretercore.cc
+3
-3
paddle/fluid/framework/new_executor/interpretercore_util.cc
paddle/fluid/framework/new_executor/interpretercore_util.cc
+12
-12
paddle/fluid/framework/new_executor/new_executor_defs.cc
paddle/fluid/framework/new_executor/new_executor_defs.cc
+3
-1
paddle/fluid/framework/new_executor/new_executor_defs.h
paddle/fluid/framework/new_executor/new_executor_defs.h
+1
-1
paddle/fluid/framework/operator.cc
paddle/fluid/framework/operator.cc
+101
-100
paddle/fluid/framework/operator.h
paddle/fluid/framework/operator.h
+4
-4
paddle/fluid/imperative/op_base.h
paddle/fluid/imperative/op_base.h
+1
-1
paddle/fluid/imperative/prepared_operator.cc
paddle/fluid/imperative/prepared_operator.cc
+27
-26
paddle/phi/capi/include/kernel_utils.h
paddle/phi/capi/include/kernel_utils.h
+249
-249
paddle/phi/core/kernel_registry.h
paddle/phi/core/kernel_registry.h
+443
-443
未找到文件。
paddle/fluid/framework/new_executor/interpretercore.cc
浏览文件 @
942ff89f
...
@@ -621,13 +621,13 @@ void InterpreterCore::RunInstruction(const Instruction& instr_node) {
...
@@ -621,13 +621,13 @@ void InterpreterCore::RunInstruction(const Instruction& instr_node) {
VLOG
(
4
)
<<
"Run phi kernel: "
<<
op
->
Type
();
VLOG
(
4
)
<<
"Run phi kernel: "
<<
op
->
Type
();
VLOG
(
4
)
<<
instr_node
.
InnerRuntimeContext
().
get
()
<<
" "
VLOG
(
4
)
<<
instr_node
.
InnerRuntimeContext
().
get
()
<<
" "
<<
&
instr_node
.
DeviceContext
();
<<
&
instr_node
.
DeviceContext
();
phi
::
KernelContext
p
t
_kernel_context
;
phi
::
KernelContext
p
hi
_kernel_context
;
op_with_kernel
->
BuildPhiKernelContext
(
op_with_kernel
->
BuildPhiKernelContext
(
*
instr_node
.
InnerRuntimeContext
().
get
(),
*
instr_node
.
InnerRuntimeContext
().
get
(),
const_cast
<
platform
::
DeviceContext
*>
(
&
instr_node
.
DeviceContext
()),
const_cast
<
platform
::
DeviceContext
*>
(
&
instr_node
.
DeviceContext
()),
&
p
t
_kernel_context
);
&
p
hi
_kernel_context
);
(
*
instr_node
.
PhiKernel
())(
&
p
t
_kernel_context
);
(
*
instr_node
.
PhiKernel
())(
&
p
hi
_kernel_context
);
}
else
{
}
else
{
instr_node
.
KernelFunc
()(
*
instr_node
.
InnerExecutionContext
().
get
());
instr_node
.
KernelFunc
()(
*
instr_node
.
InnerExecutionContext
().
get
());
...
...
paddle/fluid/framework/new_executor/interpretercore_util.cc
浏览文件 @
942ff89f
...
@@ -513,25 +513,25 @@ void build_op_func_list(const platform::Place& place,
...
@@ -513,25 +513,25 @@ void build_op_func_list(const platform::Place& place,
auto
run_phi_kernel
=
false
;
auto
run_phi_kernel
=
false
;
if
(
phi
::
KernelFactory
::
Instance
().
HasCompatiblePhiKernel
(
if
(
phi
::
KernelFactory
::
Instance
().
HasCompatiblePhiKernel
(
op_with_kernel
->
Type
()))
{
op_with_kernel
->
Type
()))
{
auto
p
t
_kernel_key
=
op_with_kernel
->
ChoosePhiKernel
(
exec_ctx
);
auto
p
hi
_kernel_key
=
op_with_kernel
->
ChoosePhiKernel
(
exec_ctx
);
auto
p
t
_kernel_name
=
op_with_kernel
->
PhiKernelSignature
()
->
name
;
auto
p
hi
_kernel_name
=
op_with_kernel
->
PhiKernelSignature
()
->
name
;
if
(
op_with_kernel
->
PhiKernel
()
->
IsValid
())
{
if
(
op_with_kernel
->
PhiKernel
()
->
IsValid
())
{
run_phi_kernel
=
true
;
run_phi_kernel
=
true
;
}
else
{
}
else
{
if
(
!
op_with_kernel
->
SupportsKernelType
(
expected_kernel_key
))
{
if
(
!
op_with_kernel
->
SupportsKernelType
(
expected_kernel_key
))
{
auto
p
t
_cpu_kernel_key
=
FallBackToCpu
(
auto
p
hi
_cpu_kernel_key
=
FallBackToCpu
(
expected_kernel_key
,
p
t
_kernel_key
,
*
op_with_kernel
);
expected_kernel_key
,
p
hi
_kernel_key
,
*
op_with_kernel
);
op_with_kernel
->
ResetPhiKernel
(
op_with_kernel
->
ResetPhiKernel
(
new
phi
::
Kernel
(
phi
::
KernelFactory
::
Instance
().
SelectKernel
(
new
phi
::
Kernel
(
phi
::
KernelFactory
::
Instance
().
SelectKernel
(
p
t_kernel_name
,
pt
_cpu_kernel_key
)));
p
hi_kernel_name
,
phi
_cpu_kernel_key
)));
if
(
op_with_kernel
->
PhiKernel
()
->
IsValid
())
{
if
(
op_with_kernel
->
PhiKernel
()
->
IsValid
())
{
VLOG
(
6
)
<<
"Static mode PrepareImpl - kernel name: "
VLOG
(
6
)
<<
"Static mode PrepareImpl - kernel name: "
<<
p
t
_kernel_name
<<
p
hi
_kernel_name
<<
" | kernel key: "
<<
p
t
_cpu_kernel_key
<<
" | kernel key: "
<<
p
hi
_cpu_kernel_key
<<
" | kernel: "
<<
*
(
op_with_kernel
->
PhiKernel
());
<<
" | kernel: "
<<
*
(
op_with_kernel
->
PhiKernel
());
op_with_kernel
->
ResetKernelType
(
new
OpKernelType
(
op_with_kernel
->
ResetKernelType
(
new
OpKernelType
(
TransPhiKernelKeyToOpKernelType
(
p
t
_cpu_kernel_key
)));
TransPhiKernelKeyToOpKernelType
(
p
hi
_cpu_kernel_key
)));
run_phi_kernel
=
true
;
run_phi_kernel
=
true
;
}
}
}
}
...
@@ -541,7 +541,7 @@ void build_op_func_list(const platform::Place& place,
...
@@ -541,7 +541,7 @@ void build_op_func_list(const platform::Place& place,
op_with_kernel
->
ChooseKernel
(
exec_ctx
);
op_with_kernel
->
ChooseKernel
(
exec_ctx
);
op_func_node
.
kernel_func_
=
*
op_with_kernel
->
kernel_func
();
op_func_node
.
kernel_func_
=
*
op_with_kernel
->
kernel_func
();
}
else
{
}
else
{
op_func_node
.
p
t
_kernel_
=
op_with_kernel
->
PhiKernel
();
op_func_node
.
p
hi
_kernel_
=
op_with_kernel
->
PhiKernel
();
}
}
auto
kernel_type
=
*
(
op_with_kernel
->
kernel_type
());
auto
kernel_type
=
*
(
op_with_kernel
->
kernel_type
());
if
(
kernel_type
.
place_
!=
dev_ctx
->
GetPlace
())
{
if
(
kernel_type
.
place_
!=
dev_ctx
->
GetPlace
())
{
...
@@ -583,10 +583,10 @@ void build_op_func_list(const platform::Place& place,
...
@@ -583,10 +583,10 @@ void build_op_func_list(const platform::Place& place,
// step 5. run kernel
// step 5. run kernel
if
(
run_phi_kernel
)
{
if
(
run_phi_kernel
)
{
phi
::
KernelContext
p
t
_kernel_context
;
phi
::
KernelContext
p
hi
_kernel_context
;
op_with_kernel
->
BuildPhiKernelContext
(
op_with_kernel
->
BuildPhiKernelContext
(
runtime_context
,
dev_ctx
,
&
p
t
_kernel_context
);
runtime_context
,
dev_ctx
,
&
p
hi
_kernel_context
);
(
*
op_func_node
.
p
t_kernel_
)(
&
pt
_kernel_context
);
(
*
op_func_node
.
p
hi_kernel_
)(
&
phi
_kernel_context
);
}
else
{
}
else
{
// the place of exec_ctx maybe has changed.
// the place of exec_ctx maybe has changed.
op_func_node
.
kernel_func_
(
ExecutionContext
(
op_func_node
.
kernel_func_
(
ExecutionContext
(
...
...
paddle/fluid/framework/new_executor/new_executor_defs.cc
浏览文件 @
942ff89f
...
@@ -705,7 +705,9 @@ OpKernelComputeFunc Instruction::KernelFunc() const {
...
@@ -705,7 +705,9 @@ OpKernelComputeFunc Instruction::KernelFunc() const {
return
op_func_node_
.
kernel_func_
;
return
op_func_node_
.
kernel_func_
;
}
}
phi
::
Kernel
*
Instruction
::
PhiKernel
()
const
{
return
op_func_node_
.
pt_kernel_
;
}
phi
::
Kernel
*
Instruction
::
PhiKernel
()
const
{
return
op_func_node_
.
phi_kernel_
;
}
OpFuncType
Instruction
::
KernelType
()
const
{
return
op_func_node_
.
type_
;
}
OpFuncType
Instruction
::
KernelType
()
const
{
return
op_func_node_
.
type_
;
}
...
...
paddle/fluid/framework/new_executor/new_executor_defs.h
浏览文件 @
942ff89f
...
@@ -283,7 +283,7 @@ struct OpFuncNode {
...
@@ -283,7 +283,7 @@ struct OpFuncNode {
platform
::
DeviceContext
*
dev_ctx_
;
// not owned
platform
::
DeviceContext
*
dev_ctx_
;
// not owned
// fit for phi kernel
// fit for phi kernel
phi
::
Kernel
*
p
t
_kernel_
{
nullptr
};
// not owned
phi
::
Kernel
*
p
hi
_kernel_
{
nullptr
};
// not owned
OpFuncType
type_
;
OpFuncType
type_
;
};
};
...
...
paddle/fluid/framework/operator.cc
浏览文件 @
942ff89f
...
@@ -1421,7 +1421,7 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
...
@@ -1421,7 +1421,7 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
!
need_prepare_phi_data_
)
{
!
need_prepare_phi_data_
)
{
if
(
!
all_kernels_must_compute_runtime_shape_
)
if
(
!
all_kernels_must_compute_runtime_shape_
)
this
->
Info
().
infer_shape_
(
impl_
->
getRuntimeInferShapeContext
());
this
->
Info
().
infer_shape_
(
impl_
->
getRuntimeInferShapeContext
());
(
*
p
t
_kernel_
)(
impl_
->
getKernelContext
());
(
*
p
hi
_kernel_
)(
impl_
->
getKernelContext
());
}
else
{
}
else
{
if
(
runtime_ctx_
.
get
()
==
nullptr
||
pre_scope_
!=
cur_scope
)
{
if
(
runtime_ctx_
.
get
()
==
nullptr
||
pre_scope_
!=
cur_scope
)
{
std
::
lock_guard
<
std
::
mutex
>
lock
(
cache_update_mutex_
);
std
::
lock_guard
<
std
::
mutex
>
lock
(
cache_update_mutex_
);
...
@@ -1467,10 +1467,10 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
...
@@ -1467,10 +1467,10 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
// TODO(chenweihang): in the first phase of project, we only support CPU, CUDA
// TODO(chenweihang): in the first phase of project, we only support CPU, CUDA
// and RCOM backend, the XPU, NPU and MKLDNN will be supported in the second
// and RCOM backend, the XPU, NPU and MKLDNN will be supported in the second
// phase
// phase
phi
::
KernelKey
p
t
_kernel_key
;
phi
::
KernelKey
p
hi
_kernel_key
;
std
::
string
p
t
_kernel_name
;
std
::
string
p
hi
_kernel_name
;
if
(
phi
::
KernelFactory
::
Instance
().
HasCompatiblePhiKernel
(
type_
))
{
if
(
phi
::
KernelFactory
::
Instance
().
HasCompatiblePhiKernel
(
type_
))
{
if
(
kernel_signature_
==
nullptr
||
p
t
_kernel_
==
nullptr
)
{
if
(
kernel_signature_
==
nullptr
||
p
hi
_kernel_
==
nullptr
)
{
kernel_signature_
.
reset
(
new
phi
::
KernelSignature
(
kernel_signature_
.
reset
(
new
phi
::
KernelSignature
(
std
::
move
(
GetExpectedPhiKernelArgs
(
exe_ctx
))));
std
::
move
(
GetExpectedPhiKernelArgs
(
exe_ctx
))));
VLOG
(
6
)
<<
*
kernel_signature_
.
get
();
VLOG
(
6
)
<<
*
kernel_signature_
.
get
();
...
@@ -1479,7 +1479,7 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
...
@@ -1479,7 +1479,7 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
new
OpKernelType
(
std
::
move
(
InnerGetExpectedKernelType
(
exe_ctx
))));
new
OpKernelType
(
std
::
move
(
InnerGetExpectedKernelType
(
exe_ctx
))));
dev_ctx
=
pool
.
Get
(
kernel_type_
->
place_
);
dev_ctx
=
pool
.
Get
(
kernel_type_
->
place_
);
p
t
_kernel_name
=
kernel_signature_
->
name
;
p
hi
_kernel_name
=
kernel_signature_
->
name
;
// NOTE(Liu-xiandong): The register kernel used KP have library_type[KP],
// NOTE(Liu-xiandong): The register kernel used KP have library_type[KP],
// But the default library_type is Plain, so we need to modify the
// But the default library_type is Plain, so we need to modify the
// library_type here, otherwise it can't work.
// library_type here, otherwise it can't work.
...
@@ -1502,38 +1502,38 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
...
@@ -1502,38 +1502,38 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
auto
expected_kernel_key_library_type
=
kernel_type_
->
library_type_
;
auto
expected_kernel_key_library_type
=
kernel_type_
->
library_type_
;
kernel_type_
->
library_type_
=
LibraryType
::
kKP
;
kernel_type_
->
library_type_
=
LibraryType
::
kKP
;
VLOG
(
3
)
<<
"modifing XPU KP kernel in static graph: "
VLOG
(
3
)
<<
"modifing XPU KP kernel in static graph: "
<<
p
t
_kernel_name
<<
p
hi
_kernel_name
<<
", using_kernel_key:"
<<
*
kernel_type_
.
get
();
<<
", using_kernel_key:"
<<
*
kernel_type_
.
get
();
auto
try_p
t
_kernel_key
=
auto
try_p
hi
_kernel_key
=
TransOpKernelTypeToPhiKernelKey
(
*
kernel_type_
.
get
());
TransOpKernelTypeToPhiKernelKey
(
*
kernel_type_
.
get
());
if
(
!
phi
::
KernelFactory
::
Instance
().
HasKernel
(
p
t
_kernel_name
,
if
(
!
phi
::
KernelFactory
::
Instance
().
HasKernel
(
p
hi
_kernel_name
,
try_p
t
_kernel_key
))
{
try_p
hi
_kernel_key
))
{
kernel_type_
->
library_type_
=
expected_kernel_key_library_type
;
kernel_type_
->
library_type_
=
expected_kernel_key_library_type
;
VLOG
(
3
)
<<
"modify XPU KP kernel in static graph: "
VLOG
(
3
)
<<
"modify XPU KP kernel in static graph: "
<<
p
t
_kernel_name
<<
" is failed "
<<
*
kernel_type_
.
get
();
<<
p
hi
_kernel_name
<<
" is failed "
<<
*
kernel_type_
.
get
();
}
else
{
}
else
{
use_phi_xpu_kp
=
true
;
use_phi_xpu_kp
=
true
;
VLOG
(
3
)
<<
"modify XPU KP kernel in static graph: "
VLOG
(
3
)
<<
"modify XPU KP kernel in static graph: "
<<
p
t
_kernel_name
<<
" is succeed "
<<
*
kernel_type_
.
get
();
<<
p
hi
_kernel_name
<<
" is succeed "
<<
*
kernel_type_
.
get
();
}
}
}
}
}
}
#endif
#endif
p
t
_kernel_key
=
TransOpKernelTypeToPhiKernelKey
(
*
kernel_type_
.
get
());
p
hi
_kernel_key
=
TransOpKernelTypeToPhiKernelKey
(
*
kernel_type_
.
get
());
p
t
_kernel_
.
reset
(
p
hi
_kernel_
.
reset
(
new
phi
::
Kernel
(
phi
::
KernelFactory
::
Instance
().
SelectKernel
(
new
phi
::
Kernel
(
phi
::
KernelFactory
::
Instance
().
SelectKernel
(
p
t_kernel_name
,
pt
_kernel_key
)));
p
hi_kernel_name
,
phi
_kernel_key
)));
if
(
p
t
_kernel_
->
IsValid
())
{
if
(
p
hi
_kernel_
->
IsValid
())
{
VLOG
(
6
)
<<
"Static mode ChoosePhiKernel - kernel name: "
VLOG
(
6
)
<<
"Static mode ChoosePhiKernel - kernel name: "
<<
p
t_kernel_name
<<
" | kernel key: "
<<
pt
_kernel_key
<<
p
hi_kernel_name
<<
" | kernel key: "
<<
phi
_kernel_key
<<
" | kernel: "
<<
*
p
t
_kernel_
;
<<
" | kernel: "
<<
*
p
hi
_kernel_
;
}
else
{
}
else
{
VLOG
(
6
)
<<
"Static mode ChoosePhiKernel - kernel `"
<<
p
t
_kernel_name
VLOG
(
6
)
<<
"Static mode ChoosePhiKernel - kernel `"
<<
p
hi
_kernel_name
<<
"` not found."
;
<<
"` not found."
;
}
}
}
else
{
}
else
{
p
t
_kernel_name
=
kernel_signature_
->
name
;
p
hi
_kernel_name
=
kernel_signature_
->
name
;
// NOTE(Liu-xiandong):In my ctest, this branch do not be executed,
// NOTE(Liu-xiandong):In my ctest, this branch do not be executed,
// I can't understand it, it's really confusing.
// I can't understand it, it's really confusing.
// But we still need to keep this to avoid errors.
// But we still need to keep this to avoid errors.
...
@@ -1556,24 +1556,24 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
...
@@ -1556,24 +1556,24 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
auto
expected_kernel_key_library_type
=
kernel_type_
->
library_type_
;
auto
expected_kernel_key_library_type
=
kernel_type_
->
library_type_
;
kernel_type_
->
library_type_
=
LibraryType
::
kKP
;
kernel_type_
->
library_type_
=
LibraryType
::
kKP
;
VLOG
(
3
)
<<
"modifing XPU KP kernel in static graph: "
VLOG
(
3
)
<<
"modifing XPU KP kernel in static graph: "
<<
p
t
_kernel_name
<<
p
hi
_kernel_name
<<
", using_kernel_key:"
<<
*
kernel_type_
.
get
();
<<
", using_kernel_key:"
<<
*
kernel_type_
.
get
();
auto
try_p
t
_kernel_key
=
auto
try_p
hi
_kernel_key
=
TransOpKernelTypeToPhiKernelKey
(
*
kernel_type_
.
get
());
TransOpKernelTypeToPhiKernelKey
(
*
kernel_type_
.
get
());
if
(
!
phi
::
KernelFactory
::
Instance
().
HasKernel
(
p
t
_kernel_name
,
if
(
!
phi
::
KernelFactory
::
Instance
().
HasKernel
(
p
hi
_kernel_name
,
try_p
t
_kernel_key
))
{
try_p
hi
_kernel_key
))
{
kernel_type_
->
library_type_
=
expected_kernel_key_library_type
;
kernel_type_
->
library_type_
=
expected_kernel_key_library_type
;
VLOG
(
3
)
<<
"modify XPU KP kernel in static graph: "
VLOG
(
3
)
<<
"modify XPU KP kernel in static graph: "
<<
p
t
_kernel_name
<<
" is failed "
<<
*
kernel_type_
.
get
();
<<
p
hi
_kernel_name
<<
" is failed "
<<
*
kernel_type_
.
get
();
}
else
{
}
else
{
use_phi_xpu_kp
=
true
;
use_phi_xpu_kp
=
true
;
VLOG
(
3
)
<<
"modify XPU KP kernel in static graph: "
VLOG
(
3
)
<<
"modify XPU KP kernel in static graph: "
<<
p
t
_kernel_name
<<
" is succeed "
<<
*
kernel_type_
.
get
();
<<
p
hi
_kernel_name
<<
" is succeed "
<<
*
kernel_type_
.
get
();
}
}
}
}
}
}
#endif
#endif
p
t
_kernel_key
=
TransOpKernelTypeToPhiKernelKey
(
*
kernel_type_
.
get
());
p
hi
_kernel_key
=
TransOpKernelTypeToPhiKernelKey
(
*
kernel_type_
.
get
());
}
}
// NOTE(Liu-xiandong): Determine whether the selected kernel is valid
// NOTE(Liu-xiandong): Determine whether the selected kernel is valid
...
@@ -1596,7 +1596,7 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
...
@@ -1596,7 +1596,7 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
bool
is_xpu_kp_support
=
(
use_xpu_kp_kernel_rt
||
use_xpu_kp_kernel_debug
);
bool
is_xpu_kp_support
=
(
use_xpu_kp_kernel_rt
||
use_xpu_kp_kernel_debug
);
#endif
#endif
if
(
p
t
_kernel_
->
IsValid
()
if
(
p
hi
_kernel_
->
IsValid
()
#if defined(PADDLE_WITH_XPU) && !defined(PADDLE_WITH_XPU_KP)
#if defined(PADDLE_WITH_XPU) && !defined(PADDLE_WITH_XPU_KP)
&&
!
is_xpu_unsupport
&&
!
is_xpu_unsupport
#endif
#endif
...
@@ -1628,17 +1628,17 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
...
@@ -1628,17 +1628,17 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
||
(
is_xpu_unsupport
&&
!
is_xpu_kp_support
)
||
(
is_xpu_unsupport
&&
!
is_xpu_kp_support
)
#endif
#endif
)
{
)
{
auto
p
t
_cpu_kernel_key
=
auto
p
hi
_cpu_kernel_key
=
FallBackToCpu
(
*
kernel_type_
.
get
(),
p
t
_kernel_key
,
*
this
);
FallBackToCpu
(
*
kernel_type_
.
get
(),
p
hi
_kernel_key
,
*
this
);
p
t
_kernel_
.
reset
(
p
hi
_kernel_
.
reset
(
new
phi
::
Kernel
(
phi
::
KernelFactory
::
Instance
().
SelectKernel
(
new
phi
::
Kernel
(
phi
::
KernelFactory
::
Instance
().
SelectKernel
(
p
t_kernel_name
,
pt
_cpu_kernel_key
)));
p
hi_kernel_name
,
phi
_cpu_kernel_key
)));
dev_ctx
=
pool
.
Get
(
platform
::
CPUPlace
());
dev_ctx
=
pool
.
Get
(
platform
::
CPUPlace
());
if
(
p
t
_kernel_
->
IsValid
())
{
if
(
p
hi
_kernel_
->
IsValid
())
{
VLOG
(
6
)
<<
"Static mode PrepareImpl - kernel name: "
<<
pt_kernel_name
VLOG
(
6
)
<<
"Static mode PrepareImpl - kernel name: "
<<
" | kernel key: "
<<
pt
_cpu_kernel_key
<<
phi_kernel_name
<<
" | kernel key: "
<<
phi
_cpu_kernel_key
<<
" | kernel: "
<<
*
p
t
_kernel_
;
<<
" | kernel: "
<<
*
p
hi
_kernel_
;
run_phi_kernel_
=
true
;
run_phi_kernel_
=
true
;
}
}
}
}
...
@@ -1692,20 +1692,20 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
...
@@ -1692,20 +1692,20 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
1
,
1
,
platform
::
EventRole
::
kInnerOp
);
platform
::
EventRole
::
kInnerOp
);
if
(
run_phi_kernel_
)
{
if
(
run_phi_kernel_
)
{
phi
::
KernelContext
p
t
_kernel_context
;
phi
::
KernelContext
p
hi
_kernel_context
;
if
(
enable_cache_runtime_context_
&&
!
need_prepare_phi_data_
&&
if
(
enable_cache_runtime_context_
&&
!
need_prepare_phi_data_
&&
!
need_prepare_data_
)
{
!
need_prepare_data_
)
{
impl_
=
impl_
=
new
CacheImpl
(
new
phi
::
KernelContext
(),
new
CacheImpl
(
new
phi
::
KernelContext
(),
new
RuntimeInferShapeContext
(
*
this
,
*
runtime_ctx
));
new
RuntimeInferShapeContext
(
*
this
,
*
runtime_ctx
));
BuildPhiKernelContext
(
*
runtime_ctx
,
dev_ctx
,
impl_
->
getKernelContext
());
BuildPhiKernelContext
(
*
runtime_ctx
,
dev_ctx
,
impl_
->
getKernelContext
());
(
*
p
t
_kernel_
)(
impl_
->
getKernelContext
());
(
*
p
hi
_kernel_
)(
impl_
->
getKernelContext
());
}
else
{
}
else
{
phi
::
KernelContext
p
t
_kernel_context
;
phi
::
KernelContext
p
hi
_kernel_context
;
// Do data transform before building KernelContext
// Do data transform before building KernelContext
// TODO(zhiqiu): support TransferInplaceVarsBack
// TODO(zhiqiu): support TransferInplaceVarsBack
BuildPhiKernelContext
(
*
runtime_ctx
,
dev_ctx
,
&
p
t
_kernel_context
);
BuildPhiKernelContext
(
*
runtime_ctx
,
dev_ctx
,
&
p
hi
_kernel_context
);
(
*
p
t_kernel_
)(
&
pt
_kernel_context
);
(
*
p
hi_kernel_
)(
&
phi
_kernel_context
);
}
}
}
else
{
}
else
{
(
*
kernel_func_
)(
(
*
kernel_func_
)(
...
@@ -1851,20 +1851,20 @@ phi::KernelKey OperatorWithKernel::ChoosePhiKernel(
...
@@ -1851,20 +1851,20 @@ phi::KernelKey OperatorWithKernel::ChoosePhiKernel(
kernel_type_
.
reset
(
kernel_type_
.
reset
(
new
OpKernelType
(
std
::
move
(
InnerGetExpectedKernelType
(
ctx
))));
new
OpKernelType
(
std
::
move
(
InnerGetExpectedKernelType
(
ctx
))));
auto
p
t
_kernel_name
=
kernel_signature_
->
name
;
auto
p
hi
_kernel_name
=
kernel_signature_
->
name
;
auto
p
t
_kernel_key
=
TransOpKernelTypeToPhiKernelKey
(
*
kernel_type_
.
get
());
auto
p
hi
_kernel_key
=
TransOpKernelTypeToPhiKernelKey
(
*
kernel_type_
.
get
());
p
t
_kernel_
.
reset
(
new
phi
::
Kernel
(
phi
::
KernelFactory
::
Instance
().
SelectKernel
(
p
hi
_kernel_
.
reset
(
new
phi
::
Kernel
(
phi
::
KernelFactory
::
Instance
().
SelectKernel
(
p
t_kernel_name
,
pt
_kernel_key
)));
p
hi_kernel_name
,
phi
_kernel_key
)));
if
(
p
t
_kernel_
->
IsValid
())
{
if
(
p
hi
_kernel_
->
IsValid
())
{
VLOG
(
6
)
<<
"Static mode ChoosePhiKernel - kernel name: "
<<
p
t
_kernel_name
VLOG
(
6
)
<<
"Static mode ChoosePhiKernel - kernel name: "
<<
p
hi
_kernel_name
<<
" | kernel key: "
<<
p
t
_kernel_key
<<
" | kernel key: "
<<
p
hi
_kernel_key
<<
" | kernel: "
<<
*
p
t
_kernel_
;
<<
" | kernel: "
<<
*
p
hi
_kernel_
;
}
else
{
}
else
{
VLOG
(
6
)
<<
"Static mode ChoosePhiKernel - kernel `"
<<
p
t
_kernel_name
VLOG
(
6
)
<<
"Static mode ChoosePhiKernel - kernel `"
<<
p
hi
_kernel_name
<<
"` not found."
;
<<
"` not found."
;
}
}
return
p
t
_kernel_key
;
return
p
hi
_kernel_key
;
}
}
void
OperatorWithKernel
::
ChooseKernel
(
const
ExecutionContext
&
ctx
)
const
{
void
OperatorWithKernel
::
ChooseKernel
(
const
ExecutionContext
&
ctx
)
const
{
...
@@ -2302,7 +2302,7 @@ Scope* OperatorWithKernel::PrepareData(
...
@@ -2302,7 +2302,7 @@ Scope* OperatorWithKernel::PrepareData(
if
(
run_phi_kernel_
)
{
if
(
run_phi_kernel_
)
{
const
auto
&
input_names
=
kernel_signature_
->
input_names
;
const
auto
&
input_names
=
kernel_signature_
->
input_names
;
const
auto
&
input_defs
=
p
t
_kernel_
->
args_def
().
input_defs
();
const
auto
&
input_defs
=
p
hi
_kernel_
->
args_def
().
input_defs
();
PADDLE_ENFORCE_EQ
(
input_names
.
size
(),
PADDLE_ENFORCE_EQ
(
input_names
.
size
(),
input_defs
.
size
(),
input_defs
.
size
(),
platform
::
errors
::
InvalidArgument
(
platform
::
errors
::
InvalidArgument
(
...
@@ -2311,7 +2311,7 @@ Scope* OperatorWithKernel::PrepareData(
...
@@ -2311,7 +2311,7 @@ Scope* OperatorWithKernel::PrepareData(
input_names
.
size
(),
input_names
.
size
(),
input_defs
.
size
()));
input_defs
.
size
()));
for
(
size_t
i
=
0
;
i
<
input_defs
.
size
();
++
i
)
{
for
(
size_t
i
=
0
;
i
<
input_defs
.
size
();
++
i
)
{
const
auto
&
input_defs
=
p
t
_kernel_
->
args_def
().
input_defs
();
const
auto
&
input_defs
=
p
hi
_kernel_
->
args_def
().
input_defs
();
auto
&
in_def
=
input_defs
.
at
(
i
);
auto
&
in_def
=
input_defs
.
at
(
i
);
std
::
string
input_name
=
input_names
[
i
];
std
::
string
input_name
=
input_names
[
i
];
auto
iter
=
ctx
->
inputs
.
find
(
input_name
);
auto
iter
=
ctx
->
inputs
.
find
(
input_name
);
...
@@ -2577,16 +2577,16 @@ phi::KernelSignature OperatorWithKernel::GetExpectedPhiKernelArgs(
...
@@ -2577,16 +2577,16 @@ phi::KernelSignature OperatorWithKernel::GetExpectedPhiKernelArgs(
void
OperatorWithKernel
::
BuildPhiKernelContext
(
void
OperatorWithKernel
::
BuildPhiKernelContext
(
const
RuntimeContext
&
ctx
,
const
RuntimeContext
&
ctx
,
platform
::
DeviceContext
*
dev_ctx
,
platform
::
DeviceContext
*
dev_ctx
,
phi
::
KernelContext
*
p
t
_kernel_context
)
const
{
phi
::
KernelContext
*
p
hi
_kernel_context
)
const
{
p
t
_kernel_context
->
SetDeviceContext
(
dev_ctx
);
p
hi
_kernel_context
->
SetDeviceContext
(
dev_ctx
);
auto
&
input_names
=
kernel_signature_
->
input_names
;
auto
&
input_names
=
kernel_signature_
->
input_names
;
auto
&
attr_names
=
kernel_signature_
->
attr_names
;
auto
&
attr_names
=
kernel_signature_
->
attr_names
;
auto
&
output_names
=
kernel_signature_
->
output_names
;
auto
&
output_names
=
kernel_signature_
->
output_names
;
auto
input_defs
=
p
t
_kernel_
->
args_def
().
input_defs
();
auto
input_defs
=
p
hi
_kernel_
->
args_def
().
input_defs
();
auto
attr_defs
=
p
t
_kernel_
->
args_def
().
attribute_defs
();
auto
attr_defs
=
p
hi
_kernel_
->
args_def
().
attribute_defs
();
auto
output_defs
=
p
t
_kernel_
->
args_def
().
output_defs
();
auto
output_defs
=
p
hi
_kernel_
->
args_def
().
output_defs
();
PADDLE_ENFORCE_EQ
(
input_names
.
size
(),
PADDLE_ENFORCE_EQ
(
input_names
.
size
(),
input_defs
.
size
(),
input_defs
.
size
(),
...
@@ -2617,7 +2617,7 @@ void OperatorWithKernel::BuildPhiKernelContext(
...
@@ -2617,7 +2617,7 @@ void OperatorWithKernel::BuildPhiKernelContext(
// calcute the start and end index of the input tensors
// calcute the start and end index of the input tensors
size_t
start_idx
=
size_t
start_idx
=
(
i
==
0
?
0
:
p
t
_kernel_context
->
InputRangeAt
(
i
-
1
).
second
);
(
i
==
0
?
0
:
p
hi
_kernel_context
->
InputRangeAt
(
i
-
1
).
second
);
// deal with optional here
// deal with optional here
if
((
it
==
ctx
.
inputs
.
end
()
||
it
->
second
.
size
()
==
0
)
&&
if
((
it
==
ctx
.
inputs
.
end
()
||
it
->
second
.
size
()
==
0
)
&&
(
input_defs
[
i
].
type_index
==
(
input_defs
[
i
].
type_index
==
...
@@ -2627,10 +2627,10 @@ void OperatorWithKernel::BuildPhiKernelContext(
...
@@ -2627,10 +2627,10 @@ void OperatorWithKernel::BuildPhiKernelContext(
input_defs
[
i
].
type_index
==
input_defs
[
i
].
type_index
==
std
::
type_index
(
typeid
(
std
::
type_index
(
typeid
(
paddle
::
optional
<
std
::
vector
<
const
phi
::
DenseTensor
*>>
))))
{
paddle
::
optional
<
std
::
vector
<
const
phi
::
DenseTensor
*>>
))))
{
p
t
_kernel_context
->
EmplaceBackInputWithoutSetRange
(
nullptr
);
p
hi
_kernel_context
->
EmplaceBackInputWithoutSetRange
(
nullptr
);
auto
end_idx
=
start_idx
+
1
;
auto
end_idx
=
start_idx
+
1
;
p
t
_kernel_context
->
AssignInputRange
(
std
::
make_pair
(
start_idx
,
end_idx
),
p
hi
_kernel_context
->
AssignInputRange
(
std
::
make_pair
(
start_idx
,
end_idx
),
i
);
i
);
continue
;
continue
;
}
}
...
@@ -2641,10 +2641,10 @@ void OperatorWithKernel::BuildPhiKernelContext(
...
@@ -2641,10 +2641,10 @@ void OperatorWithKernel::BuildPhiKernelContext(
auto
*
var
=
ins_vector
[
offset
];
auto
*
var
=
ins_vector
[
offset
];
if
(
var
->
IsType
<
framework
::
LoDTensor
>
())
{
if
(
var
->
IsType
<
framework
::
LoDTensor
>
())
{
tensor_in
=
&
(
var
->
Get
<
framework
::
LoDTensor
>
());
tensor_in
=
&
(
var
->
Get
<
framework
::
LoDTensor
>
());
p
t
_kernel_context
->
EmplaceBackInputWithoutSetRange
(
tensor_in
);
p
hi
_kernel_context
->
EmplaceBackInputWithoutSetRange
(
tensor_in
);
}
else
if
(
var
->
IsType
<
phi
::
SelectedRows
>
())
{
}
else
if
(
var
->
IsType
<
phi
::
SelectedRows
>
())
{
tensor_in
=
&
(
var
->
Get
<
phi
::
SelectedRows
>
());
tensor_in
=
&
(
var
->
Get
<
phi
::
SelectedRows
>
());
p
t
_kernel_context
->
EmplaceBackInputWithoutSetRange
(
tensor_in
);
p
hi
_kernel_context
->
EmplaceBackInputWithoutSetRange
(
tensor_in
);
}
else
if
(
var
->
IsType
<
framework
::
LoDTensorArray
>
())
{
}
else
if
(
var
->
IsType
<
framework
::
LoDTensorArray
>
())
{
need_prepare_phi_data_
=
true
;
need_prepare_phi_data_
=
true
;
paddle
::
small_vector
<
const
phi
::
TensorBase
*>
tensor_vector
;
paddle
::
small_vector
<
const
phi
::
TensorBase
*>
tensor_vector
;
...
@@ -2652,7 +2652,7 @@ void OperatorWithKernel::BuildPhiKernelContext(
...
@@ -2652,7 +2652,7 @@ void OperatorWithKernel::BuildPhiKernelContext(
for
(
auto
&
t
:
tensor_array
)
{
for
(
auto
&
t
:
tensor_array
)
{
tensor_vector
.
emplace_back
(
&
t
);
tensor_vector
.
emplace_back
(
&
t
);
}
}
p
t
_kernel_context
->
EmplaceBackInputsWithoutSetRange
(
tensor_vector
);
p
hi
_kernel_context
->
EmplaceBackInputsWithoutSetRange
(
tensor_vector
);
end_idx
+=
tensor_array
.
size
()
-
1
;
end_idx
+=
tensor_array
.
size
()
-
1
;
}
else
{
}
else
{
PADDLE_THROW
(
platform
::
errors
::
Unimplemented
(
PADDLE_THROW
(
platform
::
errors
::
Unimplemented
(
...
@@ -2661,24 +2661,24 @@ void OperatorWithKernel::BuildPhiKernelContext(
...
@@ -2661,24 +2661,24 @@ void OperatorWithKernel::BuildPhiKernelContext(
}
}
}
}
// Note: here cannot deal with vector<LoDTensorArray> input
// Note: here cannot deal with vector<LoDTensorArray> input
p
t
_kernel_context
->
AssignInputRange
(
std
::
make_pair
(
start_idx
,
end_idx
),
i
);
p
hi
_kernel_context
->
AssignInputRange
(
std
::
make_pair
(
start_idx
,
end_idx
),
i
);
}
}
VLOG
(
4
)
<<
"Done inputs"
;
VLOG
(
4
)
<<
"Done inputs"
;
for
(
size_t
i
=
0
;
i
<
output_names
.
size
();
++
i
)
{
for
(
size_t
i
=
0
;
i
<
output_names
.
size
();
++
i
)
{
auto
it
=
ctx
.
outputs
.
find
(
output_names
[
i
]);
auto
it
=
ctx
.
outputs
.
find
(
output_names
[
i
]);
size_t
start_idx
=
size_t
start_idx
=
(
i
==
0
?
0
:
p
t
_kernel_context
->
OutputRangeAt
(
i
-
1
).
second
);
(
i
==
0
?
0
:
p
hi
_kernel_context
->
OutputRangeAt
(
i
-
1
).
second
);
if
(
it
==
ctx
.
outputs
.
end
()
||
it
->
second
.
empty
())
{
if
(
it
==
ctx
.
outputs
.
end
()
||
it
->
second
.
empty
())
{
// Deal with the case that some outputs are not found or be NULL when run
// Deal with the case that some outputs are not found or be NULL when run
// the kernel.
// the kernel.
// For example : the outputs of matmul_grad are dx and dy,
// For example : the outputs of matmul_grad are dx and dy,
// sometimes dx or dy may be NULL.
// sometimes dx or dy may be NULL.
p
t
_kernel_context
->
EmplaceBackOutputWithoutSetRange
(
nullptr
);
p
hi
_kernel_context
->
EmplaceBackOutputWithoutSetRange
(
nullptr
);
auto
end_idx
=
start_idx
+
1
;
auto
end_idx
=
start_idx
+
1
;
p
t
_kernel_context
->
AssignOutputRange
(
std
::
make_pair
(
start_idx
,
end_idx
),
p
hi
_kernel_context
->
AssignOutputRange
(
std
::
make_pair
(
start_idx
,
end_idx
),
i
);
i
);
continue
;
continue
;
}
}
auto
&
outs_vector
=
it
->
second
;
auto
&
outs_vector
=
it
->
second
;
...
@@ -2691,10 +2691,10 @@ void OperatorWithKernel::BuildPhiKernelContext(
...
@@ -2691,10 +2691,10 @@ void OperatorWithKernel::BuildPhiKernelContext(
if
(
var
)
{
if
(
var
)
{
if
(
var
->
template
IsType
<
framework
::
LoDTensor
>())
{
if
(
var
->
template
IsType
<
framework
::
LoDTensor
>())
{
tensor_out
=
var
->
template
GetMutable
<
framework
::
LoDTensor
>();
tensor_out
=
var
->
template
GetMutable
<
framework
::
LoDTensor
>();
p
t
_kernel_context
->
EmplaceBackOutputWithoutSetRange
(
tensor_out
);
p
hi
_kernel_context
->
EmplaceBackOutputWithoutSetRange
(
tensor_out
);
}
else
if
(
var
->
template
IsType
<
phi
::
SelectedRows
>())
{
}
else
if
(
var
->
template
IsType
<
phi
::
SelectedRows
>())
{
tensor_out
=
var
->
template
GetMutable
<
phi
::
SelectedRows
>();
tensor_out
=
var
->
template
GetMutable
<
phi
::
SelectedRows
>();
p
t
_kernel_context
->
EmplaceBackOutputWithoutSetRange
(
tensor_out
);
p
hi
_kernel_context
->
EmplaceBackOutputWithoutSetRange
(
tensor_out
);
}
else
if
(
var
->
template
IsType
<
framework
::
LoDTensorArray
>())
{
}
else
if
(
var
->
template
IsType
<
framework
::
LoDTensorArray
>())
{
paddle
::
small_vector
<
phi
::
TensorBase
*>
tensor_vector
;
paddle
::
small_vector
<
phi
::
TensorBase
*>
tensor_vector
;
auto
*
tensor_array
=
auto
*
tensor_array
=
...
@@ -2704,7 +2704,7 @@ void OperatorWithKernel::BuildPhiKernelContext(
...
@@ -2704,7 +2704,7 @@ void OperatorWithKernel::BuildPhiKernelContext(
for
(
auto
&
t
:
*
tensor_array
)
{
for
(
auto
&
t
:
*
tensor_array
)
{
tensor_vector
.
emplace_back
(
&
t
);
tensor_vector
.
emplace_back
(
&
t
);
}
}
p
t
_kernel_context
->
EmplaceBackOutputsWithoutSetRange
(
tensor_vector
);
p
hi
_kernel_context
->
EmplaceBackOutputsWithoutSetRange
(
tensor_vector
);
end_idx
+=
tensor_array
->
size
()
-
1
;
end_idx
+=
tensor_array
->
size
()
-
1
;
}
else
{
}
else
{
PADDLE_THROW
(
platform
::
errors
::
Unimplemented
(
PADDLE_THROW
(
platform
::
errors
::
Unimplemented
(
...
@@ -2712,10 +2712,11 @@ void OperatorWithKernel::BuildPhiKernelContext(
...
@@ -2712,10 +2712,11 @@ void OperatorWithKernel::BuildPhiKernelContext(
framework
::
ToTypeName
(
var
->
Type
())));
framework
::
ToTypeName
(
var
->
Type
())));
}
}
}
else
{
}
else
{
p
t
_kernel_context
->
EmplaceBackOutputWithoutSetRange
(
tensor_out
);
p
hi
_kernel_context
->
EmplaceBackOutputWithoutSetRange
(
tensor_out
);
}
}
}
}
pt_kernel_context
->
AssignOutputRange
(
std
::
make_pair
(
start_idx
,
end_idx
),
i
);
phi_kernel_context
->
AssignOutputRange
(
std
::
make_pair
(
start_idx
,
end_idx
),
i
);
}
}
VLOG
(
4
)
<<
"Done outputs"
;
VLOG
(
4
)
<<
"Done outputs"
;
...
@@ -2729,15 +2730,15 @@ void OperatorWithKernel::BuildPhiKernelContext(
...
@@ -2729,15 +2730,15 @@ void OperatorWithKernel::BuildPhiKernelContext(
// scalar is in the attribute
// scalar is in the attribute
switch
(
AttrTypeID
(
attr_iter
->
second
))
{
switch
(
AttrTypeID
(
attr_iter
->
second
))
{
case
proto
::
AttrType
::
FLOAT
:
case
proto
::
AttrType
::
FLOAT
:
p
t
_kernel_context
->
EmplaceBackAttr
(
std
::
move
(
p
hi
_kernel_context
->
EmplaceBackAttr
(
std
::
move
(
phi
::
Scalar
(
PADDLE_GET_CONST
(
float
,
attr_iter
->
second
))));
phi
::
Scalar
(
PADDLE_GET_CONST
(
float
,
attr_iter
->
second
))));
break
;
break
;
case
proto
::
AttrType
::
INT
:
case
proto
::
AttrType
::
INT
:
p
t
_kernel_context
->
EmplaceBackAttr
(
std
::
move
(
p
hi
_kernel_context
->
EmplaceBackAttr
(
std
::
move
(
phi
::
Scalar
(
PADDLE_GET_CONST
(
int
,
attr_iter
->
second
))));
phi
::
Scalar
(
PADDLE_GET_CONST
(
int
,
attr_iter
->
second
))));
break
;
break
;
case
proto
::
AttrType
::
STRING
:
case
proto
::
AttrType
::
STRING
:
p
t
_kernel_context
->
EmplaceBackAttr
(
std
::
move
(
phi
::
Scalar
(
p
hi
_kernel_context
->
EmplaceBackAttr
(
std
::
move
(
phi
::
Scalar
(
PADDLE_GET_CONST
(
std
::
string
,
attr_iter
->
second
))));
PADDLE_GET_CONST
(
std
::
string
,
attr_iter
->
second
))));
break
;
break
;
default:
default:
...
@@ -2749,7 +2750,7 @@ void OperatorWithKernel::BuildPhiKernelContext(
...
@@ -2749,7 +2750,7 @@ void OperatorWithKernel::BuildPhiKernelContext(
}
else
{
// scalar is in the input
}
else
{
// scalar is in the input
need_prepare_phi_data_
=
true
;
need_prepare_phi_data_
=
true
;
auto
&
ins_vector
=
ctx
.
inputs
.
at
(
attr_names
[
i
]);
auto
&
ins_vector
=
ctx
.
inputs
.
at
(
attr_names
[
i
]);
p
t
_kernel_context
->
EmplaceBackAttr
(
std
::
move
(
p
hi
_kernel_context
->
EmplaceBackAttr
(
std
::
move
(
experimental
::
MakePhiScalarFromVar
(
*
ins_vector
.
front
())));
experimental
::
MakePhiScalarFromVar
(
*
ins_vector
.
front
())));
}
}
break
;
break
;
...
@@ -2757,19 +2758,19 @@ void OperatorWithKernel::BuildPhiKernelContext(
...
@@ -2757,19 +2758,19 @@ void OperatorWithKernel::BuildPhiKernelContext(
if
(
attr_iter
!=
Attrs
().
end
())
{
if
(
attr_iter
!=
Attrs
().
end
())
{
switch
(
AttrTypeID
(
attr_iter
->
second
))
{
switch
(
AttrTypeID
(
attr_iter
->
second
))
{
case
proto
::
AttrType
::
INTS
:
case
proto
::
AttrType
::
INTS
:
p
t
_kernel_context
->
EmplaceBackAttr
(
std
::
move
(
phi
::
IntArray
(
p
hi
_kernel_context
->
EmplaceBackAttr
(
std
::
move
(
phi
::
IntArray
(
PADDLE_GET_CONST
(
std
::
vector
<
int32_t
>
,
attr_iter
->
second
))));
PADDLE_GET_CONST
(
std
::
vector
<
int32_t
>
,
attr_iter
->
second
))));
break
;
break
;
case
proto
::
AttrType
::
LONGS
:
case
proto
::
AttrType
::
LONGS
:
p
t
_kernel_context
->
EmplaceBackAttr
(
std
::
move
(
phi
::
IntArray
(
p
hi
_kernel_context
->
EmplaceBackAttr
(
std
::
move
(
phi
::
IntArray
(
PADDLE_GET_CONST
(
std
::
vector
<
int64_t
>
,
attr_iter
->
second
))));
PADDLE_GET_CONST
(
std
::
vector
<
int64_t
>
,
attr_iter
->
second
))));
break
;
break
;
case
proto
::
AttrType
::
INT
:
case
proto
::
AttrType
::
INT
:
p
t
_kernel_context
->
EmplaceBackAttr
(
std
::
move
(
phi
::
IntArray
(
p
hi
_kernel_context
->
EmplaceBackAttr
(
std
::
move
(
phi
::
IntArray
(
&
PADDLE_GET_CONST
(
int32_t
,
attr_iter
->
second
),
1
)));
&
PADDLE_GET_CONST
(
int32_t
,
attr_iter
->
second
),
1
)));
break
;
break
;
case
proto
::
AttrType
::
LONG
:
case
proto
::
AttrType
::
LONG
:
p
t
_kernel_context
->
EmplaceBackAttr
(
std
::
move
(
phi
::
IntArray
(
p
hi
_kernel_context
->
EmplaceBackAttr
(
std
::
move
(
phi
::
IntArray
(
&
PADDLE_GET_CONST
(
int64_t
,
attr_iter
->
second
),
1
)));
&
PADDLE_GET_CONST
(
int64_t
,
attr_iter
->
second
),
1
)));
break
;
break
;
default:
default:
...
@@ -2782,10 +2783,10 @@ void OperatorWithKernel::BuildPhiKernelContext(
...
@@ -2782,10 +2783,10 @@ void OperatorWithKernel::BuildPhiKernelContext(
need_prepare_phi_data_
=
true
;
need_prepare_phi_data_
=
true
;
auto
&
ins_vector
=
ctx
.
inputs
.
at
(
attr_names
[
i
]);
auto
&
ins_vector
=
ctx
.
inputs
.
at
(
attr_names
[
i
]);
if
(
ins_vector
.
size
()
==
1
)
{
// ShapeTensor
if
(
ins_vector
.
size
()
==
1
)
{
// ShapeTensor
p
t
_kernel_context
->
EmplaceBackAttr
(
std
::
move
(
p
hi
_kernel_context
->
EmplaceBackAttr
(
std
::
move
(
experimental
::
MakePhiIntArrayFromVar
(
*
ins_vector
.
front
())));
experimental
::
MakePhiIntArrayFromVar
(
*
ins_vector
.
front
())));
}
else
{
// ShapeTensorList
}
else
{
// ShapeTensorList
p
t
_kernel_context
->
EmplaceBackAttr
(
std
::
move
(
p
hi
_kernel_context
->
EmplaceBackAttr
(
std
::
move
(
experimental
::
MakePhiIntArrayFromVarList
(
ins_vector
)));
experimental
::
MakePhiIntArrayFromVarList
(
ins_vector
)));
}
}
}
}
...
@@ -2806,7 +2807,7 @@ void OperatorWithKernel::BuildPhiKernelContext(
...
@@ -2806,7 +2807,7 @@ void OperatorWithKernel::BuildPhiKernelContext(
for
(
const
auto
&
val
:
vec
)
{
for
(
const
auto
&
val
:
vec
)
{
scalar_list
.
emplace_back
(
val
);
scalar_list
.
emplace_back
(
val
);
}
}
p
t
_kernel_context
->
EmplaceBackAttr
(
std
::
move
(
scalar_list
));
p
hi
_kernel_context
->
EmplaceBackAttr
(
std
::
move
(
scalar_list
));
}
break
;
}
break
;
case
proto
::
AttrType
::
LONGS
:
{
case
proto
::
AttrType
::
LONGS
:
{
const
auto
&
vec
=
const
auto
&
vec
=
...
@@ -2816,7 +2817,7 @@ void OperatorWithKernel::BuildPhiKernelContext(
...
@@ -2816,7 +2817,7 @@ void OperatorWithKernel::BuildPhiKernelContext(
for
(
const
auto
&
val
:
vec
)
{
for
(
const
auto
&
val
:
vec
)
{
scalar_list
.
emplace_back
(
val
);
scalar_list
.
emplace_back
(
val
);
}
}
p
t
_kernel_context
->
EmplaceBackAttr
(
std
::
move
(
scalar_list
));
p
hi
_kernel_context
->
EmplaceBackAttr
(
std
::
move
(
scalar_list
));
}
break
;
}
break
;
case
proto
::
AttrType
::
FLOATS
:
{
case
proto
::
AttrType
::
FLOATS
:
{
const
auto
&
vec
=
const
auto
&
vec
=
...
@@ -2826,7 +2827,7 @@ void OperatorWithKernel::BuildPhiKernelContext(
...
@@ -2826,7 +2827,7 @@ void OperatorWithKernel::BuildPhiKernelContext(
for
(
const
auto
&
val
:
vec
)
{
for
(
const
auto
&
val
:
vec
)
{
scalar_list
.
emplace_back
(
val
);
scalar_list
.
emplace_back
(
val
);
}
}
p
t
_kernel_context
->
EmplaceBackAttr
(
std
::
move
(
scalar_list
));
p
hi
_kernel_context
->
EmplaceBackAttr
(
std
::
move
(
scalar_list
));
}
break
;
}
break
;
case
proto
::
AttrType
::
FLOAT64S
:
{
case
proto
::
AttrType
::
FLOAT64S
:
{
const
auto
&
vec
=
const
auto
&
vec
=
...
@@ -2836,7 +2837,7 @@ void OperatorWithKernel::BuildPhiKernelContext(
...
@@ -2836,7 +2837,7 @@ void OperatorWithKernel::BuildPhiKernelContext(
for
(
const
auto
&
val
:
vec
)
{
for
(
const
auto
&
val
:
vec
)
{
scalar_list
.
emplace_back
(
val
);
scalar_list
.
emplace_back
(
val
);
}
}
p
t
_kernel_context
->
EmplaceBackAttr
(
std
::
move
(
scalar_list
));
p
hi
_kernel_context
->
EmplaceBackAttr
(
std
::
move
(
scalar_list
));
}
break
;
}
break
;
case
proto
::
AttrType
::
BOOLEANS
:
{
case
proto
::
AttrType
::
BOOLEANS
:
{
const
auto
&
vec
=
const
auto
&
vec
=
...
@@ -2846,7 +2847,7 @@ void OperatorWithKernel::BuildPhiKernelContext(
...
@@ -2846,7 +2847,7 @@ void OperatorWithKernel::BuildPhiKernelContext(
for
(
const
auto
&
val
:
vec
)
{
for
(
const
auto
&
val
:
vec
)
{
scalar_list
.
emplace_back
(
val
);
scalar_list
.
emplace_back
(
val
);
}
}
p
t
_kernel_context
->
EmplaceBackAttr
(
std
::
move
(
scalar_list
));
p
hi
_kernel_context
->
EmplaceBackAttr
(
std
::
move
(
scalar_list
));
}
break
;
}
break
;
default:
default:
PADDLE_THROW
(
platform
::
errors
::
Unimplemented
(
PADDLE_THROW
(
platform
::
errors
::
Unimplemented
(
...
@@ -2864,39 +2865,39 @@ void OperatorWithKernel::BuildPhiKernelContext(
...
@@ -2864,39 +2865,39 @@ void OperatorWithKernel::BuildPhiKernelContext(
attr_names
[
i
]));
attr_names
[
i
]));
switch
(
attr_defs
[
i
].
type_index
)
{
switch
(
attr_defs
[
i
].
type_index
)
{
case
phi
::
AttributeType
::
FLOAT32
:
case
phi
::
AttributeType
::
FLOAT32
:
p
t
_kernel_context
->
EmplaceBackAttr
(
p
hi
_kernel_context
->
EmplaceBackAttr
(
PADDLE_GET_CONST
(
float
,
attr_iter
->
second
));
PADDLE_GET_CONST
(
float
,
attr_iter
->
second
));
break
;
break
;
case
phi
::
AttributeType
::
INT32
:
case
phi
::
AttributeType
::
INT32
:
p
t
_kernel_context
->
EmplaceBackAttr
(
p
hi
_kernel_context
->
EmplaceBackAttr
(
PADDLE_GET_CONST
(
int
,
attr_iter
->
second
));
PADDLE_GET_CONST
(
int
,
attr_iter
->
second
));
break
;
break
;
case
phi
::
AttributeType
::
BOOL
:
case
phi
::
AttributeType
::
BOOL
:
p
t
_kernel_context
->
EmplaceBackAttr
(
p
hi
_kernel_context
->
EmplaceBackAttr
(
PADDLE_GET_CONST
(
bool
,
attr_iter
->
second
));
PADDLE_GET_CONST
(
bool
,
attr_iter
->
second
));
break
;
break
;
case
phi
::
AttributeType
::
INT64
:
case
phi
::
AttributeType
::
INT64
:
p
t
_kernel_context
->
EmplaceBackAttr
(
p
hi
_kernel_context
->
EmplaceBackAttr
(
PADDLE_GET_CONST
(
int64_t
,
attr_iter
->
second
));
PADDLE_GET_CONST
(
int64_t
,
attr_iter
->
second
));
break
;
break
;
case
phi
::
AttributeType
::
INT32S
:
case
phi
::
AttributeType
::
INT32S
:
p
t
_kernel_context
->
EmplaceBackAttr
(
p
hi
_kernel_context
->
EmplaceBackAttr
(
PADDLE_GET_CONST
(
std
::
vector
<
int
>
,
attr_iter
->
second
));
PADDLE_GET_CONST
(
std
::
vector
<
int
>
,
attr_iter
->
second
));
break
;
break
;
case
phi
::
AttributeType
::
DATA_TYPE
:
{
case
phi
::
AttributeType
::
DATA_TYPE
:
{
auto
data_type
=
framework
::
TransToPhiDataType
(
auto
data_type
=
framework
::
TransToPhiDataType
(
static_cast
<
framework
::
proto
::
VarType
::
Type
>
(
static_cast
<
framework
::
proto
::
VarType
::
Type
>
(
PADDLE_GET_CONST
(
int
,
attr_iter
->
second
)));
PADDLE_GET_CONST
(
int
,
attr_iter
->
second
)));
p
t
_kernel_context
->
EmplaceBackAttr
(
data_type
);
p
hi
_kernel_context
->
EmplaceBackAttr
(
data_type
);
}
break
;
}
break
;
case
phi
::
AttributeType
::
STRING
:
case
phi
::
AttributeType
::
STRING
:
p
t
_kernel_context
->
EmplaceBackAttr
(
p
hi
_kernel_context
->
EmplaceBackAttr
(
std
::
move
(
PADDLE_GET_CONST
(
std
::
string
,
attr_iter
->
second
)));
std
::
move
(
PADDLE_GET_CONST
(
std
::
string
,
attr_iter
->
second
)));
break
;
break
;
case
phi
::
AttributeType
::
INT64S
:
case
phi
::
AttributeType
::
INT64S
:
switch
(
AttrTypeID
(
attr_iter
->
second
))
{
switch
(
AttrTypeID
(
attr_iter
->
second
))
{
case
proto
::
AttrType
::
LONGS
:
case
proto
::
AttrType
::
LONGS
:
p
t
_kernel_context
->
EmplaceBackAttr
(
p
hi
_kernel_context
->
EmplaceBackAttr
(
PADDLE_GET_CONST
(
std
::
vector
<
int64_t
>
,
attr_iter
->
second
));
PADDLE_GET_CONST
(
std
::
vector
<
int64_t
>
,
attr_iter
->
second
));
break
;
break
;
case
proto
::
AttrType
::
INTS
:
{
case
proto
::
AttrType
::
INTS
:
{
...
@@ -2904,7 +2905,7 @@ void OperatorWithKernel::BuildPhiKernelContext(
...
@@ -2904,7 +2905,7 @@ void OperatorWithKernel::BuildPhiKernelContext(
PADDLE_GET_CONST
(
std
::
vector
<
int
>
,
attr_iter
->
second
);
PADDLE_GET_CONST
(
std
::
vector
<
int
>
,
attr_iter
->
second
);
const
std
::
vector
<
int64_t
>
vector_int64_attr
(
const
std
::
vector
<
int64_t
>
vector_int64_attr
(
vector_int_attr
.
begin
(),
vector_int_attr
.
end
());
vector_int_attr
.
begin
(),
vector_int_attr
.
end
());
p
t
_kernel_context
->
EmplaceBackAttr
(
vector_int64_attr
);
p
hi
_kernel_context
->
EmplaceBackAttr
(
vector_int64_attr
);
}
break
;
}
break
;
default:
default:
PADDLE_THROW
(
platform
::
errors
::
Unimplemented
(
PADDLE_THROW
(
platform
::
errors
::
Unimplemented
(
...
@@ -2915,11 +2916,11 @@ void OperatorWithKernel::BuildPhiKernelContext(
...
@@ -2915,11 +2916,11 @@ void OperatorWithKernel::BuildPhiKernelContext(
}
}
break
;
break
;
case
phi
::
AttributeType
::
FLOAT32S
:
case
phi
::
AttributeType
::
FLOAT32S
:
p
t
_kernel_context
->
EmplaceBackAttr
(
p
hi
_kernel_context
->
EmplaceBackAttr
(
PADDLE_GET_CONST
(
std
::
vector
<
float
>
,
attr_iter
->
second
));
PADDLE_GET_CONST
(
std
::
vector
<
float
>
,
attr_iter
->
second
));
break
;
break
;
case
phi
::
AttributeType
::
STRINGS
:
case
phi
::
AttributeType
::
STRINGS
:
p
t
_kernel_context
->
EmplaceBackAttr
(
p
hi
_kernel_context
->
EmplaceBackAttr
(
PADDLE_GET_CONST
(
std
::
vector
<
std
::
string
>
,
attr_iter
->
second
));
PADDLE_GET_CONST
(
std
::
vector
<
std
::
string
>
,
attr_iter
->
second
));
break
;
break
;
default:
default:
...
...
paddle/fluid/framework/operator.h
浏览文件 @
942ff89f
...
@@ -652,16 +652,16 @@ class OperatorWithKernel : public OperatorBase {
...
@@ -652,16 +652,16 @@ class OperatorWithKernel : public OperatorBase {
void
BuildPhiKernelContext
(
const
RuntimeContext
&
ctx
,
void
BuildPhiKernelContext
(
const
RuntimeContext
&
ctx
,
platform
::
DeviceContext
*
dev_ctx
,
platform
::
DeviceContext
*
dev_ctx
,
phi
::
KernelContext
*
p
t
_kernel_context
)
const
;
phi
::
KernelContext
*
p
hi
_kernel_context
)
const
;
phi
::
KernelSignature
*
PhiKernelSignature
()
const
{
phi
::
KernelSignature
*
PhiKernelSignature
()
const
{
return
kernel_signature_
.
get
();
return
kernel_signature_
.
get
();
}
}
phi
::
Kernel
*
PhiKernel
()
const
{
return
p
t
_kernel_
.
get
();
}
phi
::
Kernel
*
PhiKernel
()
const
{
return
p
hi
_kernel_
.
get
();
}
void
ResetPhiKernel
(
phi
::
Kernel
*
kernel
)
const
{
void
ResetPhiKernel
(
phi
::
Kernel
*
kernel
)
const
{
return
p
t
_kernel_
.
reset
(
kernel
);
return
p
hi
_kernel_
.
reset
(
kernel
);
}
}
const
OpKernelType
*
kernel_type
()
const
{
return
kernel_type_
.
get
();
}
const
OpKernelType
*
kernel_type
()
const
{
return
kernel_type_
.
get
();
}
...
@@ -730,7 +730,7 @@ class OperatorWithKernel : public OperatorBase {
...
@@ -730,7 +730,7 @@ class OperatorWithKernel : public OperatorBase {
mutable
bool
run_phi_kernel_
=
false
;
mutable
bool
run_phi_kernel_
=
false
;
mutable
bool
run_kp_kernel
=
false
;
mutable
bool
run_kp_kernel
=
false
;
mutable
std
::
unique_ptr
<
phi
::
KernelSignature
>
kernel_signature_
;
mutable
std
::
unique_ptr
<
phi
::
KernelSignature
>
kernel_signature_
;
mutable
std
::
unique_ptr
<
phi
::
Kernel
>
p
t
_kernel_
;
mutable
std
::
unique_ptr
<
phi
::
Kernel
>
p
hi
_kernel_
;
mutable
std
::
unique_ptr
<
phi
::
ArgumentMappingFn
>
arg_map_fn_
;
mutable
std
::
unique_ptr
<
phi
::
ArgumentMappingFn
>
arg_map_fn_
;
struct
CacheImpl
;
struct
CacheImpl
;
...
...
paddle/fluid/imperative/op_base.h
浏览文件 @
942ff89f
...
@@ -227,7 +227,7 @@ class OpBase {
...
@@ -227,7 +227,7 @@ class OpBase {
size_t
id_
{
-
1UL
};
size_t
id_
{
-
1UL
};
// In order to reduce the compatibility phase
// In order to reduce the compatibility phase
// performance overhead, temporarily cache KernelContext
// performance overhead, temporarily cache KernelContext
static
phi
::
KernelContext
p
t
_kernel_context_
;
static
phi
::
KernelContext
p
hi
_kernel_context_
;
std
::
vector
<
std
::
shared_ptr
<
std
::
function
<
void
()
>>>
void_function_post_hooks_
;
std
::
vector
<
std
::
shared_ptr
<
std
::
function
<
void
()
>>>
void_function_post_hooks_
;
};
};
...
...
paddle/fluid/imperative/prepared_operator.cc
浏览文件 @
942ff89f
...
@@ -183,8 +183,8 @@ PreparedOp PrepareImpl(
...
@@ -183,8 +183,8 @@ PreparedOp PrepareImpl(
const
phi
::
KernelSignature
*
default_kernel_signature
=
nullptr
;
const
phi
::
KernelSignature
*
default_kernel_signature
=
nullptr
;
phi
::
KernelSignature
kernel_signature
;
phi
::
KernelSignature
kernel_signature
;
phi
::
KernelKey
p
t
_kernel_key
;
phi
::
KernelKey
p
hi
_kernel_key
;
std
::
string
p
t
_kernel_name
;
std
::
string
p
hi
_kernel_name
;
#if defined(PADDLE_WITH_XPU)
#if defined(PADDLE_WITH_XPU)
bool
is_xpu_unsupport
=
bool
is_xpu_unsupport
=
paddle
::
platform
::
is_xpu_place
(
expected_kernel_key
.
place_
)
&&
paddle
::
platform
::
is_xpu_place
(
expected_kernel_key
.
place_
)
&&
...
@@ -213,7 +213,7 @@ PreparedOp PrepareImpl(
...
@@ -213,7 +213,7 @@ PreparedOp PrepareImpl(
if
(
has_phi_kernel
)
{
if
(
has_phi_kernel
)
{
VLOG
(
6
)
<<
kernel_signature
;
VLOG
(
6
)
<<
kernel_signature
;
p
t
_kernel_name
=
kernel_signature
.
name
;
p
hi
_kernel_name
=
kernel_signature
.
name
;
// NOTE(Liu-xiandong): The register kernel used KP have library_type[KP],
// NOTE(Liu-xiandong): The register kernel used KP have library_type[KP],
// But the default library_type is Plain, so we need to modify the
// But the default library_type is Plain, so we need to modify the
// library_type here, otherwise it can't work.
// library_type here, otherwise it can't work.
...
@@ -236,34 +236,35 @@ PreparedOp PrepareImpl(
...
@@ -236,34 +236,35 @@ PreparedOp PrepareImpl(
auto
expected_kernel_key_library_type
=
auto
expected_kernel_key_library_type
=
expected_kernel_key
.
library_type_
;
expected_kernel_key
.
library_type_
;
expected_kernel_key
.
library_type_
=
paddle
::
framework
::
LibraryType
::
kKP
;
expected_kernel_key
.
library_type_
=
paddle
::
framework
::
LibraryType
::
kKP
;
VLOG
(
3
)
<<
"modifing XPU KP kernel: "
<<
p
t
_kernel_name
VLOG
(
3
)
<<
"modifing XPU KP kernel: "
<<
p
hi
_kernel_name
<<
", using_kernel_key:"
<<
expected_kernel_key
;
<<
", using_kernel_key:"
<<
expected_kernel_key
;
phi
::
KernelKey
try_p
t
_kernel_key
=
phi
::
KernelKey
try_p
hi
_kernel_key
=
TransOpKernelTypeToPhiKernelKey
(
expected_kernel_key
);
TransOpKernelTypeToPhiKernelKey
(
expected_kernel_key
);
if
(
!
phi_kernel_factory
.
HasKernel
(
pt_kernel_name
,
try_pt_kernel_key
))
{
if
(
!
phi_kernel_factory
.
HasKernel
(
phi_kernel_name
,
try_phi_kernel_key
))
{
expected_kernel_key
.
library_type_
=
expected_kernel_key_library_type
;
expected_kernel_key
.
library_type_
=
expected_kernel_key_library_type
;
VLOG
(
3
)
<<
"modify XPU KP kernel: "
<<
p
t
_kernel_name
VLOG
(
3
)
<<
"modify XPU KP kernel: "
<<
p
hi
_kernel_name
<<
" in dynamic graph is failed "
<<
expected_kernel_key
;
<<
" in dynamic graph is failed "
<<
expected_kernel_key
;
}
else
{
}
else
{
VLOG
(
3
)
<<
"modify XPU KP kernel: "
<<
p
t
_kernel_name
VLOG
(
3
)
<<
"modify XPU KP kernel: "
<<
p
hi
_kernel_name
<<
" in dynamic graph is succeed "
<<
expected_kernel_key
;
<<
" in dynamic graph is succeed "
<<
expected_kernel_key
;
}
}
}
}
}
}
#endif
#endif
p
t
_kernel_key
=
TransOpKernelTypeToPhiKernelKey
(
expected_kernel_key
);
p
hi
_kernel_key
=
TransOpKernelTypeToPhiKernelKey
(
expected_kernel_key
);
auto
&
phi_kernel
=
auto
&
phi_kernel
=
phi_kernel_factory
.
SelectKernel
(
p
t_kernel_name
,
pt
_kernel_key
);
phi_kernel_factory
.
SelectKernel
(
p
hi_kernel_name
,
phi
_kernel_key
);
if
(
phi_kernel
.
IsValid
()
if
(
phi_kernel
.
IsValid
()
#if defined(PADDLE_WITH_XPU) && !defined(PADDLE_WITH_XPU_KP)
#if defined(PADDLE_WITH_XPU) && !defined(PADDLE_WITH_XPU_KP)
&&
!
is_xpu_unsupport
&&
!
is_xpu_unsupport
#endif
#endif
)
{
)
{
VLOG
(
6
)
<<
"Dynamic mode PrepareImpl - kernel name: "
<<
p
t
_kernel_name
VLOG
(
6
)
<<
"Dynamic mode PrepareImpl - kernel name: "
<<
p
hi
_kernel_name
<<
" | kernel key: "
<<
p
t
_kernel_key
<<
" | kernel key: "
<<
p
hi
_kernel_key
<<
" | kernel: "
<<
phi_kernel
;
<<
" | kernel: "
<<
phi_kernel
;
if
(
expected_kernel_key
.
place_
!=
place
)
{
if
(
expected_kernel_key
.
place_
!=
place
)
{
...
@@ -279,7 +280,7 @@ PreparedOp PrepareImpl(
...
@@ -279,7 +280,7 @@ PreparedOp PrepareImpl(
phi_kernel
,
phi_kernel
,
dev_ctx
);
dev_ctx
);
}
else
{
}
else
{
VLOG
(
6
)
<<
"Dynamic mode ChoosePhiKernel - kernel `"
<<
p
t
_kernel_name
VLOG
(
6
)
<<
"Dynamic mode ChoosePhiKernel - kernel `"
<<
p
hi
_kernel_name
<<
"` not found."
;
<<
"` not found."
;
}
}
}
}
...
@@ -316,23 +317,23 @@ PreparedOp PrepareImpl(
...
@@ -316,23 +317,23 @@ PreparedOp PrepareImpl(
#endif
#endif
)
{
)
{
if
(
has_phi_kernel
)
{
if
(
has_phi_kernel
)
{
auto
p
t
_cpu_kernel_key
=
auto
p
hi
_cpu_kernel_key
=
FallBackToCpu
(
expected_kernel_key
,
p
t
_kernel_key
,
op
);
FallBackToCpu
(
expected_kernel_key
,
p
hi
_kernel_key
,
op
);
auto
&
p
t
_cpu_kernel
=
auto
&
p
hi
_cpu_kernel
=
phi_kernel_factory
.
SelectKernel
(
p
t_kernel_name
,
pt
_cpu_kernel_key
);
phi_kernel_factory
.
SelectKernel
(
p
hi_kernel_name
,
phi
_cpu_kernel_key
);
if
(
p
t
_cpu_kernel
.
IsValid
())
{
if
(
p
hi
_cpu_kernel
.
IsValid
())
{
VLOG
(
6
)
<<
"Dynamic mode PrepareImpl - kernel name: "
<<
p
t
_kernel_name
VLOG
(
6
)
<<
"Dynamic mode PrepareImpl - kernel name: "
<<
p
hi
_kernel_name
<<
" | kernel key: "
<<
p
t
_cpu_kernel_key
<<
" | kernel key: "
<<
p
hi
_cpu_kernel_key
<<
" | kernel: "
<<
p
t
_cpu_kernel
;
<<
" | kernel: "
<<
p
hi
_cpu_kernel
;
auto
*
cpu_ctx
=
pool
.
Get
(
paddle
::
platform
::
CPUPlace
());
auto
*
cpu_ctx
=
pool
.
Get
(
paddle
::
platform
::
CPUPlace
());
return
PreparedOp
(
return
PreparedOp
(
op
,
op
,
empty_ctx
,
empty_ctx
,
framework
::
TransPhiKernelKeyToOpKernelType
(
p
t
_cpu_kernel_key
),
framework
::
TransPhiKernelKeyToOpKernelType
(
p
hi
_cpu_kernel_key
),
arg_map_fn
,
arg_map_fn
,
default_kernel_signature
,
default_kernel_signature
,
std
::
move
(
kernel_signature
),
std
::
move
(
kernel_signature
),
p
t
_cpu_kernel
,
p
hi
_cpu_kernel
,
cpu_ctx
);
cpu_ctx
);
}
}
}
}
...
@@ -610,7 +611,7 @@ static void PreparedOpRunPtImpl(
...
@@ -610,7 +611,7 @@ static void PreparedOpRunPtImpl(
PreparePhiData
<
VarType
>
(
phi_kernel
,
kernel_signature
,
ins
);
PreparePhiData
<
VarType
>
(
phi_kernel
,
kernel_signature
,
ins
);
phi
::
KernelContext
p
t
_kernel_context
;
phi
::
KernelContext
p
hi
_kernel_context
;
BuildDygraphPhiKernelContext
<
VarType
>
(
kernel_signature
,
BuildDygraphPhiKernelContext
<
VarType
>
(
kernel_signature
,
phi_kernel
,
phi_kernel
,
ins
,
ins
,
...
@@ -618,9 +619,9 @@ static void PreparedOpRunPtImpl(
...
@@ -618,9 +619,9 @@ static void PreparedOpRunPtImpl(
attrs
,
attrs
,
default_attrs
,
default_attrs
,
dev_ctx
,
dev_ctx
,
&
p
t
_kernel_context
);
&
p
hi
_kernel_context
);
phi_kernel
(
&
p
t
_kernel_context
);
phi_kernel
(
&
p
hi
_kernel_context
);
}
}
if
(
FLAGS_check_nan_inf
)
{
if
(
FLAGS_check_nan_inf
)
{
...
...
paddle/phi/capi/include/kernel_utils.h
浏览文件 @
942ff89f
...
@@ -130,283 +130,283 @@ namespace capi {
...
@@ -130,283 +130,283 @@ namespace capi {
PD_CUSTOM_PHI_KERNEL_EXPAND( \
PD_CUSTOM_PHI_KERNEL_EXPAND( \
_PD_BUILD_KERNEL_INSTANTIATION_14(meta_kernel_fn, backend, __VA_ARGS__))
_PD_BUILD_KERNEL_INSTANTIATION_14(meta_kernel_fn, backend, __VA_ARGS__))
#define _PD_BUILD_KERNEL_REGISTRAR_INIT_1(registrar_class, \
#define _PD_BUILD_KERNEL_REGISTRAR_INIT_1(registrar_class,
\
kernel_name, \
kernel_name,
\
backend, \
backend,
\
layout, \
layout,
\
registrar_id, \
registrar_id,
\
meta_kernel_fn, \
meta_kernel_fn,
\
cpp_dtype) \
cpp_dtype)
\
static const registrar_class<cpp_dtype> PD_CUSTOM_PHI_KERNEL_CONCATENATE( \
static const registrar_class<cpp_dtype> PD_CUSTOM_PHI_KERNEL_CONCATENATE(
\
__reg_p
t
_kernel_##kernel_name##_##backend##_##layout##_, registrar_id); \
__reg_p
hi
_kernel_##kernel_name##_##backend##_##layout##_, registrar_id); \
int TouchCustomKernelSymbolFor_##kernel_name##_##backend##_##layout() { \
int TouchCustomKernelSymbolFor_##kernel_name##_##backend##_##layout() {
\
return 0; \
return 0;
\
}
}
#define _PD_BUILD_KERNEL_REGISTRAR_INIT_2(registrar_class, \
#define _PD_BUILD_KERNEL_REGISTRAR_INIT_2(registrar_class,
\
kernel_name, \
kernel_name,
\
backend, \
backend,
\
layout, \
layout,
\
registrar_id, \
registrar_id,
\
meta_kernel_fn, \
meta_kernel_fn,
\
cpp_dtype, \
cpp_dtype,
\
...) \
...)
\
static const registrar_class<cpp_dtype> PD_CUSTOM_PHI_KERNEL_CONCATENATE( \
static const registrar_class<cpp_dtype> PD_CUSTOM_PHI_KERNEL_CONCATENATE(
\
__reg_p
t
_kernel_##kernel_name##_##backend##_##layout##_, registrar_id); \
__reg_p
hi
_kernel_##kernel_name##_##backend##_##layout##_, registrar_id); \
PD_CUSTOM_PHI_KERNEL_EXPAND( \
PD_CUSTOM_PHI_KERNEL_EXPAND(
\
_PD_BUILD_KERNEL_REGISTRAR_INIT_1(registrar_class, \
_PD_BUILD_KERNEL_REGISTRAR_INIT_1(registrar_class,
\
kernel_name, \
kernel_name,
\
backend, \
backend,
\
layout, \
layout,
\
PD_CUSTOM_PHI_KERNEL_ID, \
PD_CUSTOM_PHI_KERNEL_ID,
\
meta_kernel_fn, \
meta_kernel_fn,
\
__VA_ARGS__))
__VA_ARGS__))
#define _PD_BUILD_KERNEL_REGISTRAR_INIT_3(registrar_class, \
#define _PD_BUILD_KERNEL_REGISTRAR_INIT_3(registrar_class,
\
kernel_name, \
kernel_name,
\
backend, \
backend,
\
layout, \
layout,
\
registrar_id, \
registrar_id,
\
meta_kernel_fn, \
meta_kernel_fn,
\
cpp_dtype, \
cpp_dtype,
\
...) \
...)
\
static const registrar_class<cpp_dtype> PD_CUSTOM_PHI_KERNEL_CONCATENATE( \
static const registrar_class<cpp_dtype> PD_CUSTOM_PHI_KERNEL_CONCATENATE(
\
__reg_p
t
_kernel_##kernel_name##_##backend##_##layout##_, registrar_id); \
__reg_p
hi
_kernel_##kernel_name##_##backend##_##layout##_, registrar_id); \
PD_CUSTOM_PHI_KERNEL_EXPAND( \
PD_CUSTOM_PHI_KERNEL_EXPAND(
\
_PD_BUILD_KERNEL_REGISTRAR_INIT_2(registrar_class, \
_PD_BUILD_KERNEL_REGISTRAR_INIT_2(registrar_class,
\
kernel_name, \
kernel_name,
\
backend, \
backend,
\
layout, \
layout,
\
PD_CUSTOM_PHI_KERNEL_ID, \
PD_CUSTOM_PHI_KERNEL_ID,
\
meta_kernel_fn, \
meta_kernel_fn,
\
__VA_ARGS__))
__VA_ARGS__))
#define _PD_BUILD_KERNEL_REGISTRAR_INIT_4(registrar_class, \
#define _PD_BUILD_KERNEL_REGISTRAR_INIT_4(registrar_class,
\
kernel_name, \
kernel_name,
\
backend, \
backend,
\
layout, \
layout,
\
registrar_id, \
registrar_id,
\
meta_kernel_fn, \
meta_kernel_fn,
\
cpp_dtype, \
cpp_dtype,
\
...) \
...)
\
static const registrar_class<cpp_dtype> PD_CUSTOM_PHI_KERNEL_CONCATENATE( \
static const registrar_class<cpp_dtype> PD_CUSTOM_PHI_KERNEL_CONCATENATE(
\
__reg_p
t
_kernel_##kernel_name##_##backend##_##layout##_, registrar_id); \
__reg_p
hi
_kernel_##kernel_name##_##backend##_##layout##_, registrar_id); \
PD_CUSTOM_PHI_KERNEL_EXPAND( \
PD_CUSTOM_PHI_KERNEL_EXPAND(
\
_PD_BUILD_KERNEL_REGISTRAR_INIT_3(registrar_class, \
_PD_BUILD_KERNEL_REGISTRAR_INIT_3(registrar_class,
\
kernel_name, \
kernel_name,
\
backend, \
backend,
\
layout, \
layout,
\
PD_CUSTOM_PHI_KERNEL_ID, \
PD_CUSTOM_PHI_KERNEL_ID,
\
meta_kernel_fn, \
meta_kernel_fn,
\
__VA_ARGS__))
__VA_ARGS__))
#define _PD_BUILD_KERNEL_REGISTRAR_INIT_5(registrar_class, \
#define _PD_BUILD_KERNEL_REGISTRAR_INIT_5(registrar_class,
\
kernel_name, \
kernel_name,
\
backend, \
backend,
\
layout, \
layout,
\
registrar_id, \
registrar_id,
\
meta_kernel_fn, \
meta_kernel_fn,
\
cpp_dtype, \
cpp_dtype,
\
...) \
...)
\
static const registrar_class<cpp_dtype> PD_CUSTOM_PHI_KERNEL_CONCATENATE( \
static const registrar_class<cpp_dtype> PD_CUSTOM_PHI_KERNEL_CONCATENATE(
\
__reg_p
t
_kernel_##kernel_name##_##backend##_##layout##_, registrar_id); \
__reg_p
hi
_kernel_##kernel_name##_##backend##_##layout##_, registrar_id); \
PD_CUSTOM_PHI_KERNEL_EXPAND( \
PD_CUSTOM_PHI_KERNEL_EXPAND(
\
_PD_BUILD_KERNEL_REGISTRAR_INIT_4(registrar_class, \
_PD_BUILD_KERNEL_REGISTRAR_INIT_4(registrar_class,
\
kernel_name, \
kernel_name,
\
backend, \
backend,
\
layout, \
layout,
\
PD_CUSTOM_PHI_KERNEL_ID, \
PD_CUSTOM_PHI_KERNEL_ID,
\
meta_kernel_fn, \
meta_kernel_fn,
\
__VA_ARGS__))
__VA_ARGS__))
#define _PD_BUILD_KERNEL_REGISTRAR_INIT_6(registrar_class, \
#define _PD_BUILD_KERNEL_REGISTRAR_INIT_6(registrar_class,
\
kernel_name, \
kernel_name,
\
backend, \
backend,
\
layout, \
layout,
\
registrar_id, \
registrar_id,
\
meta_kernel_fn, \
meta_kernel_fn,
\
cpp_dtype, \
cpp_dtype,
\
...) \
...)
\
static const registrar_class<cpp_dtype> PD_CUSTOM_PHI_KERNEL_CONCATENATE( \
static const registrar_class<cpp_dtype> PD_CUSTOM_PHI_KERNEL_CONCATENATE(
\
__reg_p
t
_kernel_##kernel_name##_##backend##_##layout##_, registrar_id); \
__reg_p
hi
_kernel_##kernel_name##_##backend##_##layout##_, registrar_id); \
PD_CUSTOM_PHI_KERNEL_EXPAND( \
PD_CUSTOM_PHI_KERNEL_EXPAND(
\
_PD_BUILD_KERNEL_REGISTRAR_INIT_5(registrar_class, \
_PD_BUILD_KERNEL_REGISTRAR_INIT_5(registrar_class,
\
kernel_name, \
kernel_name,
\
backend, \
backend,
\
layout, \
layout,
\
PD_CUSTOM_PHI_KERNEL_ID, \
PD_CUSTOM_PHI_KERNEL_ID,
\
meta_kernel_fn, \
meta_kernel_fn,
\
__VA_ARGS__))
__VA_ARGS__))
#define _PD_BUILD_KERNEL_REGISTRAR_INIT_7(registrar_class, \
#define _PD_BUILD_KERNEL_REGISTRAR_INIT_7(registrar_class,
\
kernel_name, \
kernel_name,
\
backend, \
backend,
\
layout, \
layout,
\
registrar_id, \
registrar_id,
\
meta_kernel_fn, \
meta_kernel_fn,
\
cpp_dtype, \
cpp_dtype,
\
...) \
...)
\
static const registrar_class<cpp_dtype> PD_CUSTOM_PHI_KERNEL_CONCATENATE( \
static const registrar_class<cpp_dtype> PD_CUSTOM_PHI_KERNEL_CONCATENATE(
\
__reg_p
t
_kernel_##kernel_name##_##backend##_##layout##_, registrar_id); \
__reg_p
hi
_kernel_##kernel_name##_##backend##_##layout##_, registrar_id); \
PD_CUSTOM_PHI_KERNEL_EXPAND( \
PD_CUSTOM_PHI_KERNEL_EXPAND(
\
_PD_BUILD_KERNEL_REGISTRAR_INIT_6(registrar_class, \
_PD_BUILD_KERNEL_REGISTRAR_INIT_6(registrar_class,
\
kernel_name, \
kernel_name,
\
backend, \
backend,
\
layout, \
layout,
\
PD_CUSTOM_PHI_KERNEL_ID, \
PD_CUSTOM_PHI_KERNEL_ID,
\
meta_kernel_fn, \
meta_kernel_fn,
\
__VA_ARGS__))
__VA_ARGS__))
#define _PD_BUILD_KERNEL_REGISTRAR_INIT_8(registrar_class, \
#define _PD_BUILD_KERNEL_REGISTRAR_INIT_8(registrar_class,
\
kernel_name, \
kernel_name,
\
backend, \
backend,
\
layout, \
layout,
\
registrar_id, \
registrar_id,
\
meta_kernel_fn, \
meta_kernel_fn,
\
cpp_dtype, \
cpp_dtype,
\
...) \
...)
\
static const registrar_class<cpp_dtype> PD_CUSTOM_PHI_KERNEL_CONCATENATE( \
static const registrar_class<cpp_dtype> PD_CUSTOM_PHI_KERNEL_CONCATENATE(
\
__reg_p
t
_kernel_##kernel_name##_##backend##_##layout##_, registrar_id); \
__reg_p
hi
_kernel_##kernel_name##_##backend##_##layout##_, registrar_id); \
PD_CUSTOM_PHI_KERNEL_EXPAND( \
PD_CUSTOM_PHI_KERNEL_EXPAND(
\
_PD_BUILD_KERNEL_REGISTRAR_INIT_7(registrar_class, \
_PD_BUILD_KERNEL_REGISTRAR_INIT_7(registrar_class,
\
kernel_name, \
kernel_name,
\
backend, \
backend,
\
layout, \
layout,
\
PD_CUSTOM_PHI_KERNEL_ID, \
PD_CUSTOM_PHI_KERNEL_ID,
\
meta_kernel_fn, \
meta_kernel_fn,
\
__VA_ARGS__))
__VA_ARGS__))
#define _PD_BUILD_KERNEL_REGISTRAR_INIT_9(registrar_class, \
#define _PD_BUILD_KERNEL_REGISTRAR_INIT_9(registrar_class,
\
kernel_name, \
kernel_name,
\
backend, \
backend,
\
layout, \
layout,
\
registrar_id, \
registrar_id,
\
meta_kernel_fn, \
meta_kernel_fn,
\
cpp_dtype, \
cpp_dtype,
\
...) \
...)
\
static const registrar_class<cpp_dtype> PD_CUSTOM_PHI_KERNEL_CONCATENATE( \
static const registrar_class<cpp_dtype> PD_CUSTOM_PHI_KERNEL_CONCATENATE(
\
__reg_p
t
_kernel_##kernel_name##_##backend##_##layout##_, registrar_id); \
__reg_p
hi
_kernel_##kernel_name##_##backend##_##layout##_, registrar_id); \
PD_CUSTOM_PHI_KERNEL_EXPAND( \
PD_CUSTOM_PHI_KERNEL_EXPAND(
\
_PD_BUILD_KERNEL_REGISTRAR_INIT_8(registrar_class, \
_PD_BUILD_KERNEL_REGISTRAR_INIT_8(registrar_class,
\
kernel_name, \
kernel_name,
\
backend, \
backend,
\
layout, \
layout,
\
PD_CUSTOM_PHI_KERNEL_ID, \
PD_CUSTOM_PHI_KERNEL_ID,
\
meta_kernel_fn, \
meta_kernel_fn,
\
__VA_ARGS__))
__VA_ARGS__))
#define _PD_BUILD_KERNEL_REGISTRAR_INIT_10(registrar_class, \
#define _PD_BUILD_KERNEL_REGISTRAR_INIT_10(registrar_class,
\
kernel_name, \
kernel_name,
\
backend, \
backend,
\
layout, \
layout,
\
registrar_id, \
registrar_id,
\
meta_kernel_fn, \
meta_kernel_fn,
\
cpp_dtype, \
cpp_dtype,
\
...) \
...)
\
static const registrar_class<cpp_dtype> PD_CUSTOM_PHI_KERNEL_CONCATENATE( \
static const registrar_class<cpp_dtype> PD_CUSTOM_PHI_KERNEL_CONCATENATE(
\
__reg_p
t
_kernel_##kernel_name##_##backend##_##layout##_, registrar_id); \
__reg_p
hi
_kernel_##kernel_name##_##backend##_##layout##_, registrar_id); \
PD_CUSTOM_PHI_KERNEL_EXPAND( \
PD_CUSTOM_PHI_KERNEL_EXPAND(
\
_PD_BUILD_KERNEL_REGISTRAR_INIT_9(registrar_class, \
_PD_BUILD_KERNEL_REGISTRAR_INIT_9(registrar_class,
\
kernel_name, \
kernel_name,
\
backend, \
backend,
\
layout, \
layout,
\
PD_CUSTOM_PHI_KERNEL_ID, \
PD_CUSTOM_PHI_KERNEL_ID,
\
meta_kernel_fn, \
meta_kernel_fn,
\
__VA_ARGS__))
__VA_ARGS__))
#define _PD_BUILD_KERNEL_REGISTRAR_INIT_11(registrar_class, \
#define _PD_BUILD_KERNEL_REGISTRAR_INIT_11(registrar_class,
\
kernel_name, \
kernel_name,
\
backend, \
backend,
\
layout, \
layout,
\
registrar_id, \
registrar_id,
\
meta_kernel_fn, \
meta_kernel_fn,
\
cpp_dtype, \
cpp_dtype,
\
...) \
...)
\
static const registrar_class<cpp_dtype> PD_CUSTOM_PHI_KERNEL_CONCATENATE( \
static const registrar_class<cpp_dtype> PD_CUSTOM_PHI_KERNEL_CONCATENATE(
\
__reg_p
t
_kernel_##kernel_name##_##backend##_##layout##_, registrar_id); \
__reg_p
hi
_kernel_##kernel_name##_##backend##_##layout##_, registrar_id); \
PD_CUSTOM_PHI_KERNEL_EXPAND( \
PD_CUSTOM_PHI_KERNEL_EXPAND(
\
_PD_BUILD_KERNEL_REGISTRAR_INIT_10(registrar_class, \
_PD_BUILD_KERNEL_REGISTRAR_INIT_10(registrar_class,
\
kernel_name, \
kernel_name,
\
backend, \
backend,
\
layout, \
layout,
\
PD_CUSTOM_PHI_KERNEL_ID, \
PD_CUSTOM_PHI_KERNEL_ID,
\
meta_kernel_fn, \
meta_kernel_fn,
\
__VA_ARGS__))
__VA_ARGS__))
#define _PD_BUILD_KERNEL_REGISTRAR_INIT_12(registrar_class, \
#define _PD_BUILD_KERNEL_REGISTRAR_INIT_12(registrar_class,
\
kernel_name, \
kernel_name,
\
backend, \
backend,
\
layout, \
layout,
\
registrar_id, \
registrar_id,
\
meta_kernel_fn, \
meta_kernel_fn,
\
cpp_dtype, \
cpp_dtype,
\
...) \
...)
\
static const registrar_class<cpp_dtype> PD_CUSTOM_PHI_KERNEL_CONCATENATE( \
static const registrar_class<cpp_dtype> PD_CUSTOM_PHI_KERNEL_CONCATENATE(
\
__reg_p
t
_kernel_##kernel_name##_##backend##_##layout##_, registrar_id); \
__reg_p
hi
_kernel_##kernel_name##_##backend##_##layout##_, registrar_id); \
PD_CUSTOM_PHI_KERNEL_EXPAND( \
PD_CUSTOM_PHI_KERNEL_EXPAND(
\
_PD_BUILD_KERNEL_REGISTRAR_INIT_11(registrar_class, \
_PD_BUILD_KERNEL_REGISTRAR_INIT_11(registrar_class,
\
kernel_name, \
kernel_name,
\
backend, \
backend,
\
layout, \
layout,
\
PD_CUSTOM_PHI_KERNEL_ID, \
PD_CUSTOM_PHI_KERNEL_ID,
\
meta_kernel_fn, \
meta_kernel_fn,
\
__VA_ARGS__))
__VA_ARGS__))
#define _PD_BUILD_KERNEL_REGISTRAR_INIT_13(registrar_class, \
#define _PD_BUILD_KERNEL_REGISTRAR_INIT_13(registrar_class,
\
kernel_name, \
kernel_name,
\
backend, \
backend,
\
layout, \
layout,
\
registrar_id, \
registrar_id,
\
meta_kernel_fn, \
meta_kernel_fn,
\
cpp_dtype, \
cpp_dtype,
\
...) \
...)
\
static const registrar_class<cpp_dtype> PD_CUSTOM_PHI_KERNEL_CONCATENATE( \
static const registrar_class<cpp_dtype> PD_CUSTOM_PHI_KERNEL_CONCATENATE(
\
__reg_p
t
_kernel_##kernel_name##_##backend##_##layout##_, registrar_id); \
__reg_p
hi
_kernel_##kernel_name##_##backend##_##layout##_, registrar_id); \
PD_CUSTOM_PHI_KERNEL_EXPAND( \
PD_CUSTOM_PHI_KERNEL_EXPAND(
\
_PD_BUILD_KERNEL_REGISTRAR_INIT_12(registrar_class, \
_PD_BUILD_KERNEL_REGISTRAR_INIT_12(registrar_class,
\
kernel_name, \
kernel_name,
\
backend, \
backend,
\
layout, \
layout,
\
PD_CUSTOM_PHI_KERNEL_ID, \
PD_CUSTOM_PHI_KERNEL_ID,
\
meta_kernel_fn, \
meta_kernel_fn,
\
__VA_ARGS__))
__VA_ARGS__))
#define _PD_BUILD_KERNEL_REGISTRAR_INIT_14(registrar_class, \
#define _PD_BUILD_KERNEL_REGISTRAR_INIT_14(registrar_class,
\
kernel_name, \
kernel_name,
\
backend, \
backend,
\
layout, \
layout,
\
registrar_id, \
registrar_id,
\
meta_kernel_fn, \
meta_kernel_fn,
\
cpp_dtype, \
cpp_dtype,
\
...) \
...)
\
static const registrar_class<cpp_dtype> PD_CUSTOM_PHI_KERNEL_CONCATENATE( \
static const registrar_class<cpp_dtype> PD_CUSTOM_PHI_KERNEL_CONCATENATE(
\
__reg_p
t
_kernel_##kernel_name##_##backend##_##layout##_, registrar_id); \
__reg_p
hi
_kernel_##kernel_name##_##backend##_##layout##_, registrar_id); \
PD_CUSTOM_PHI_KERNEL_EXPAND( \
PD_CUSTOM_PHI_KERNEL_EXPAND(
\
_PD_BUILD_KERNEL_REGISTRAR_INIT_13(registrar_class, \
_PD_BUILD_KERNEL_REGISTRAR_INIT_13(registrar_class,
\
kernel_name, \
kernel_name,
\
backend, \
backend,
\
layout, \
layout,
\
PD_CUSTOM_PHI_KERNEL_ID, \
PD_CUSTOM_PHI_KERNEL_ID,
\
meta_kernel_fn, \
meta_kernel_fn,
\
__VA_ARGS__))
__VA_ARGS__))
#define _PD_BUILD_KERNEL_REGISTRAR_INIT_15(registrar_class, \
#define _PD_BUILD_KERNEL_REGISTRAR_INIT_15(registrar_class,
\
kernel_name, \
kernel_name,
\
backend, \
backend,
\
layout, \
layout,
\
registrar_id, \
registrar_id,
\
meta_kernel_fn, \
meta_kernel_fn,
\
cpp_dtype, \
cpp_dtype,
\
...) \
...)
\
static const registrar_class<cpp_dtype> PD_CUSTOM_PHI_KERNEL_CONCATENATE( \
static const registrar_class<cpp_dtype> PD_CUSTOM_PHI_KERNEL_CONCATENATE(
\
__reg_p
t
_kernel_##kernel_name##_##backend##_##layout##_, registrar_id); \
__reg_p
hi
_kernel_##kernel_name##_##backend##_##layout##_, registrar_id); \
PD_CUSTOM_PHI_KERNEL_EXPAND( \
PD_CUSTOM_PHI_KERNEL_EXPAND(
\
_PD_BUILD_KERNEL_REGISTRAR_INIT_14(registrar_class, \
_PD_BUILD_KERNEL_REGISTRAR_INIT_14(registrar_class,
\
kernel_name, \
kernel_name,
\
backend, \
backend,
\
layout, \
layout,
\
PD_CUSTOM_PHI_KERNEL_ID, \
PD_CUSTOM_PHI_KERNEL_ID,
\
meta_kernel_fn, \
meta_kernel_fn,
\
__VA_ARGS__))
__VA_ARGS__))
#define _PD_BUILD_KERNEL_REGISTRAR_INIT( \
#define _PD_BUILD_KERNEL_REGISTRAR_INIT( \
...
...
paddle/phi/core/kernel_registry.h
浏览文件 @
942ff89f
...
@@ -553,461 +553,461 @@ struct KernelRegistrar {
...
@@ -553,461 +553,461 @@ struct KernelRegistrar {
// clang-format on
// clang-format on
#define _PD_KERNEL_REGISTRAR_INIT_1(reg_type, \
#define _PD_KERNEL_REGISTRAR_INIT_1(reg_type,
\
kernel_name, \
kernel_name,
\
backend, \
backend,
\
context, \
context,
\
layout, \
layout,
\
registrar_id, \
registrar_id,
\
args_def_fn, \
args_def_fn,
\
meta_kernel_fn, \
meta_kernel_fn,
\
cpp_dtype) \
cpp_dtype)
\
static const ::phi::KernelRegistrar PD_CONCATENATE( \
static const ::phi::KernelRegistrar PD_CONCATENATE(
\
__reg_p
t
_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \
__reg_p
hi
_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \
reg_type, \
reg_type,
\
#kernel_name, \
#kernel_name,
\
#backend, \
#backend,
\
DATALAYOUT(layout), \
DATALAYOUT(layout),
\
::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(), \
::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(),
\
::phi::KernelArgsParseFunctor< \
::phi::KernelArgsParseFunctor<
\
decltype(&meta_kernel_fn<cpp_dtype, context>)>::Parse, \
decltype(&meta_kernel_fn<cpp_dtype, context>)>::Parse,
\
args_def_fn, \
args_def_fn,
\
PHI_KERNEL(meta_kernel_fn<cpp_dtype, context>), \
PHI_KERNEL(meta_kernel_fn<cpp_dtype, context>),
\
PHI_VARIADIC_KERNEL(meta_kernel_fn<cpp_dtype, context>)); \
PHI_VARIADIC_KERNEL(meta_kernel_fn<cpp_dtype, context>));
\
int TouchKernelSymbolFor_##kernel_name##_##backend##_##layout() { return 0; }
int TouchKernelSymbolFor_##kernel_name##_##backend##_##layout() { return 0; }
#define _PD_KERNEL_REGISTRAR_INIT_2(reg_type, \
#define _PD_KERNEL_REGISTRAR_INIT_2(reg_type,
\
kernel_name, \
kernel_name,
\
backend, \
backend,
\
context, \
context,
\
layout, \
layout,
\
registrar_id, \
registrar_id,
\
args_def_fn, \
args_def_fn,
\
meta_kernel_fn, \
meta_kernel_fn,
\
cpp_dtype, \
cpp_dtype,
\
...) \
...)
\
static const ::phi::KernelRegistrar PD_CONCATENATE( \
static const ::phi::KernelRegistrar PD_CONCATENATE(
\
__reg_p
t
_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \
__reg_p
hi
_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \
reg_type, \
reg_type,
\
#kernel_name, \
#kernel_name,
\
#backend, \
#backend,
\
DATALAYOUT(layout), \
DATALAYOUT(layout),
\
::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(), \
::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(),
\
::phi::KernelArgsParseFunctor< \
::phi::KernelArgsParseFunctor<
\
decltype(&meta_kernel_fn<cpp_dtype, context>)>::Parse, \
decltype(&meta_kernel_fn<cpp_dtype, context>)>::Parse,
\
args_def_fn, \
args_def_fn,
\
PHI_KERNEL(meta_kernel_fn<cpp_dtype, context>), \
PHI_KERNEL(meta_kernel_fn<cpp_dtype, context>),
\
PHI_VARIADIC_KERNEL(meta_kernel_fn<cpp_dtype, context>)); \
PHI_VARIADIC_KERNEL(meta_kernel_fn<cpp_dtype, context>));
\
PD_EXPAND(_PD_KERNEL_REGISTRAR_INIT_1(reg_type, \
PD_EXPAND(_PD_KERNEL_REGISTRAR_INIT_1(reg_type,
\
kernel_name, \
kernel_name,
\
backend, \
backend,
\
context, \
context,
\
layout, \
layout,
\
PD_ID, \
PD_ID,
\
args_def_fn, \
args_def_fn,
\
meta_kernel_fn, \
meta_kernel_fn,
\
__VA_ARGS__))
__VA_ARGS__))
#define _PD_KERNEL_REGISTRAR_INIT_3(reg_type, \
#define _PD_KERNEL_REGISTRAR_INIT_3(reg_type,
\
kernel_name, \
kernel_name,
\
backend, \
backend,
\
context, \
context,
\
layout, \
layout,
\
registrar_id, \
registrar_id,
\
args_def_fn, \
args_def_fn,
\
meta_kernel_fn, \
meta_kernel_fn,
\
cpp_dtype, \
cpp_dtype,
\
...) \
...)
\
static const ::phi::KernelRegistrar PD_CONCATENATE( \
static const ::phi::KernelRegistrar PD_CONCATENATE(
\
__reg_p
t
_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \
__reg_p
hi
_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \
reg_type, \
reg_type,
\
#kernel_name, \
#kernel_name,
\
#backend, \
#backend,
\
DATALAYOUT(layout), \
DATALAYOUT(layout),
\
::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(), \
::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(),
\
::phi::KernelArgsParseFunctor< \
::phi::KernelArgsParseFunctor<
\
decltype(&meta_kernel_fn<cpp_dtype, context>)>::Parse, \
decltype(&meta_kernel_fn<cpp_dtype, context>)>::Parse,
\
args_def_fn, \
args_def_fn,
\
PHI_KERNEL(meta_kernel_fn<cpp_dtype, context>), \
PHI_KERNEL(meta_kernel_fn<cpp_dtype, context>),
\
PHI_VARIADIC_KERNEL(meta_kernel_fn<cpp_dtype, context>)); \
PHI_VARIADIC_KERNEL(meta_kernel_fn<cpp_dtype, context>));
\
PD_EXPAND(_PD_KERNEL_REGISTRAR_INIT_2(reg_type, \
PD_EXPAND(_PD_KERNEL_REGISTRAR_INIT_2(reg_type,
\
kernel_name, \
kernel_name,
\
backend, \
backend,
\
context, \
context,
\
layout, \
layout,
\
PD_ID, \
PD_ID,
\
args_def_fn, \
args_def_fn,
\
meta_kernel_fn, \
meta_kernel_fn,
\
__VA_ARGS__))
__VA_ARGS__))
#define _PD_KERNEL_REGISTRAR_INIT_4(reg_type, \
#define _PD_KERNEL_REGISTRAR_INIT_4(reg_type,
\
kernel_name, \
kernel_name,
\
backend, \
backend,
\
context, \
context,
\
layout, \
layout,
\
registrar_id, \
registrar_id,
\
args_def_fn, \
args_def_fn,
\
meta_kernel_fn, \
meta_kernel_fn,
\
cpp_dtype, \
cpp_dtype,
\
...) \
...)
\
static const ::phi::KernelRegistrar PD_CONCATENATE( \
static const ::phi::KernelRegistrar PD_CONCATENATE(
\
__reg_p
t
_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \
__reg_p
hi
_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \
reg_type, \
reg_type,
\
#kernel_name, \
#kernel_name,
\
#backend, \
#backend,
\
DATALAYOUT(layout), \
DATALAYOUT(layout),
\
::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(), \
::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(),
\
::phi::KernelArgsParseFunctor< \
::phi::KernelArgsParseFunctor<
\
decltype(&meta_kernel_fn<cpp_dtype, context>)>::Parse, \
decltype(&meta_kernel_fn<cpp_dtype, context>)>::Parse,
\
args_def_fn, \
args_def_fn,
\
PHI_KERNEL(meta_kernel_fn<cpp_dtype, context>), \
PHI_KERNEL(meta_kernel_fn<cpp_dtype, context>),
\
PHI_VARIADIC_KERNEL(meta_kernel_fn<cpp_dtype, context>)); \
PHI_VARIADIC_KERNEL(meta_kernel_fn<cpp_dtype, context>));
\
PD_EXPAND(_PD_KERNEL_REGISTRAR_INIT_3(reg_type, \
PD_EXPAND(_PD_KERNEL_REGISTRAR_INIT_3(reg_type,
\
kernel_name, \
kernel_name,
\
backend, \
backend,
\
context, \
context,
\
layout, \
layout,
\
PD_ID, \
PD_ID,
\
args_def_fn, \
args_def_fn,
\
meta_kernel_fn, \
meta_kernel_fn,
\
__VA_ARGS__))
__VA_ARGS__))
#define _PD_KERNEL_REGISTRAR_INIT_5(reg_type, \
#define _PD_KERNEL_REGISTRAR_INIT_5(reg_type,
\
kernel_name, \
kernel_name,
\
backend, \
backend,
\
context, \
context,
\
layout, \
layout,
\
registrar_id, \
registrar_id,
\
args_def_fn, \
args_def_fn,
\
meta_kernel_fn, \
meta_kernel_fn,
\
cpp_dtype, \
cpp_dtype,
\
...) \
...)
\
static const ::phi::KernelRegistrar PD_CONCATENATE( \
static const ::phi::KernelRegistrar PD_CONCATENATE(
\
__reg_p
t
_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \
__reg_p
hi
_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \
reg_type, \
reg_type,
\
#kernel_name, \
#kernel_name,
\
#backend, \
#backend,
\
DATALAYOUT(layout), \
DATALAYOUT(layout),
\
::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(), \
::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(),
\
::phi::KernelArgsParseFunctor< \
::phi::KernelArgsParseFunctor<
\
decltype(&meta_kernel_fn<cpp_dtype, context>)>::Parse, \
decltype(&meta_kernel_fn<cpp_dtype, context>)>::Parse,
\
args_def_fn, \
args_def_fn,
\
PHI_KERNEL(meta_kernel_fn<cpp_dtype, context>), \
PHI_KERNEL(meta_kernel_fn<cpp_dtype, context>),
\
PHI_VARIADIC_KERNEL(meta_kernel_fn<cpp_dtype, context>)); \
PHI_VARIADIC_KERNEL(meta_kernel_fn<cpp_dtype, context>));
\
PD_EXPAND(_PD_KERNEL_REGISTRAR_INIT_4(reg_type, \
PD_EXPAND(_PD_KERNEL_REGISTRAR_INIT_4(reg_type,
\
kernel_name, \
kernel_name,
\
backend, \
backend,
\
context, \
context,
\
layout, \
layout,
\
PD_ID, \
PD_ID,
\
args_def_fn, \
args_def_fn,
\
meta_kernel_fn, \
meta_kernel_fn,
\
__VA_ARGS__))
__VA_ARGS__))
#define _PD_KERNEL_REGISTRAR_INIT_6(reg_type, \
#define _PD_KERNEL_REGISTRAR_INIT_6(reg_type,
\
kernel_name, \
kernel_name,
\
backend, \
backend,
\
context, \
context,
\
layout, \
layout,
\
registrar_id, \
registrar_id,
\
args_def_fn, \
args_def_fn,
\
meta_kernel_fn, \
meta_kernel_fn,
\
cpp_dtype, \
cpp_dtype,
\
...) \
...)
\
static const ::phi::KernelRegistrar PD_CONCATENATE( \
static const ::phi::KernelRegistrar PD_CONCATENATE(
\
__reg_p
t
_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \
__reg_p
hi
_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \
reg_type, \
reg_type,
\
#kernel_name, \
#kernel_name,
\
#backend, \
#backend,
\
DATALAYOUT(layout), \
DATALAYOUT(layout),
\
::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(), \
::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(),
\
::phi::KernelArgsParseFunctor< \
::phi::KernelArgsParseFunctor<
\
decltype(&meta_kernel_fn<cpp_dtype, context>)>::Parse, \
decltype(&meta_kernel_fn<cpp_dtype, context>)>::Parse,
\
args_def_fn, \
args_def_fn,
\
PHI_KERNEL(meta_kernel_fn<cpp_dtype, context>), \
PHI_KERNEL(meta_kernel_fn<cpp_dtype, context>),
\
PHI_VARIADIC_KERNEL(meta_kernel_fn<cpp_dtype, context>)); \
PHI_VARIADIC_KERNEL(meta_kernel_fn<cpp_dtype, context>));
\
PD_EXPAND(_PD_KERNEL_REGISTRAR_INIT_5(reg_type, \
PD_EXPAND(_PD_KERNEL_REGISTRAR_INIT_5(reg_type,
\
kernel_name, \
kernel_name,
\
backend, \
backend,
\
context, \
context,
\
layout, \
layout,
\
PD_ID, \
PD_ID,
\
args_def_fn, \
args_def_fn,
\
meta_kernel_fn, \
meta_kernel_fn,
\
__VA_ARGS__))
__VA_ARGS__))
#define _PD_KERNEL_REGISTRAR_INIT_7(reg_type, \
#define _PD_KERNEL_REGISTRAR_INIT_7(reg_type,
\
kernel_name, \
kernel_name,
\
backend, \
backend,
\
context, \
context,
\
layout, \
layout,
\
registrar_id, \
registrar_id,
\
args_def_fn, \
args_def_fn,
\
meta_kernel_fn, \
meta_kernel_fn,
\
cpp_dtype, \
cpp_dtype,
\
...) \
...)
\
static const ::phi::KernelRegistrar PD_CONCATENATE( \
static const ::phi::KernelRegistrar PD_CONCATENATE(
\
__reg_p
t
_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \
__reg_p
hi
_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \
reg_type, \
reg_type,
\
#kernel_name, \
#kernel_name,
\
#backend, \
#backend,
\
DATALAYOUT(layout), \
DATALAYOUT(layout),
\
::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(), \
::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(),
\
::phi::KernelArgsParseFunctor< \
::phi::KernelArgsParseFunctor<
\
decltype(&meta_kernel_fn<cpp_dtype, context>)>::Parse, \
decltype(&meta_kernel_fn<cpp_dtype, context>)>::Parse,
\
args_def_fn, \
args_def_fn,
\
PHI_KERNEL(meta_kernel_fn<cpp_dtype, context>), \
PHI_KERNEL(meta_kernel_fn<cpp_dtype, context>),
\
PHI_VARIADIC_KERNEL(meta_kernel_fn<cpp_dtype, context>)); \
PHI_VARIADIC_KERNEL(meta_kernel_fn<cpp_dtype, context>));
\
PD_EXPAND(_PD_KERNEL_REGISTRAR_INIT_6(reg_type, \
PD_EXPAND(_PD_KERNEL_REGISTRAR_INIT_6(reg_type,
\
kernel_name, \
kernel_name,
\
backend, \
backend,
\
context, \
context,
\
layout, \
layout,
\
PD_ID, \
PD_ID,
\
args_def_fn, \
args_def_fn,
\
meta_kernel_fn, \
meta_kernel_fn,
\
__VA_ARGS__))
__VA_ARGS__))
#define _PD_KERNEL_REGISTRAR_INIT_8(reg_type, \
#define _PD_KERNEL_REGISTRAR_INIT_8(reg_type,
\
kernel_name, \
kernel_name,
\
backend, \
backend,
\
context, \
context,
\
layout, \
layout,
\
registrar_id, \
registrar_id,
\
args_def_fn, \
args_def_fn,
\
meta_kernel_fn, \
meta_kernel_fn,
\
cpp_dtype, \
cpp_dtype,
\
...) \
...)
\
static const ::phi::KernelRegistrar PD_CONCATENATE( \
static const ::phi::KernelRegistrar PD_CONCATENATE(
\
__reg_p
t
_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \
__reg_p
hi
_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \
reg_type, \
reg_type,
\
#kernel_name, \
#kernel_name,
\
#backend, \
#backend,
\
DATALAYOUT(layout), \
DATALAYOUT(layout),
\
::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(), \
::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(),
\
::phi::KernelArgsParseFunctor< \
::phi::KernelArgsParseFunctor<
\
decltype(&meta_kernel_fn<cpp_dtype, context>)>::Parse, \
decltype(&meta_kernel_fn<cpp_dtype, context>)>::Parse,
\
args_def_fn, \
args_def_fn,
\
PHI_KERNEL(meta_kernel_fn<cpp_dtype, context>), \
PHI_KERNEL(meta_kernel_fn<cpp_dtype, context>),
\
PHI_VARIADIC_KERNEL(meta_kernel_fn<cpp_dtype, context>)); \
PHI_VARIADIC_KERNEL(meta_kernel_fn<cpp_dtype, context>));
\
PD_EXPAND(_PD_KERNEL_REGISTRAR_INIT_7(reg_type, \
PD_EXPAND(_PD_KERNEL_REGISTRAR_INIT_7(reg_type,
\
kernel_name, \
kernel_name,
\
backend, \
backend,
\
context, \
context,
\
layout, \
layout,
\
PD_ID, \
PD_ID,
\
args_def_fn, \
args_def_fn,
\
meta_kernel_fn, \
meta_kernel_fn,
\
__VA_ARGS__))
__VA_ARGS__))
#define _PD_KERNEL_REGISTRAR_INIT_9(reg_type, \
#define _PD_KERNEL_REGISTRAR_INIT_9(reg_type,
\
kernel_name, \
kernel_name,
\
backend, \
backend,
\
context, \
context,
\
layout, \
layout,
\
registrar_id, \
registrar_id,
\
args_def_fn, \
args_def_fn,
\
meta_kernel_fn, \
meta_kernel_fn,
\
cpp_dtype, \
cpp_dtype,
\
...) \
...)
\
static const ::phi::KernelRegistrar PD_CONCATENATE( \
static const ::phi::KernelRegistrar PD_CONCATENATE(
\
__reg_p
t
_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \
__reg_p
hi
_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \
reg_type, \
reg_type,
\
#kernel_name, \
#kernel_name,
\
#backend, \
#backend,
\
DATALAYOUT(layout), \
DATALAYOUT(layout),
\
::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(), \
::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(),
\
::phi::KernelArgsParseFunctor< \
::phi::KernelArgsParseFunctor<
\
decltype(&meta_kernel_fn<cpp_dtype, context>)>::Parse, \
decltype(&meta_kernel_fn<cpp_dtype, context>)>::Parse,
\
args_def_fn, \
args_def_fn,
\
PHI_KERNEL(meta_kernel_fn<cpp_dtype, context>), \
PHI_KERNEL(meta_kernel_fn<cpp_dtype, context>),
\
PHI_VARIADIC_KERNEL(meta_kernel_fn<cpp_dtype, context>)); \
PHI_VARIADIC_KERNEL(meta_kernel_fn<cpp_dtype, context>));
\
PD_EXPAND(_PD_KERNEL_REGISTRAR_INIT_8(reg_type, \
PD_EXPAND(_PD_KERNEL_REGISTRAR_INIT_8(reg_type,
\
kernel_name, \
kernel_name,
\
backend, \
backend,
\
context, \
context,
\
layout, \
layout,
\
PD_ID, \
PD_ID,
\
args_def_fn, \
args_def_fn,
\
meta_kernel_fn, \
meta_kernel_fn,
\
__VA_ARGS__))
__VA_ARGS__))
#define _PD_KERNEL_REGISTRAR_INIT_10(reg_type, \
#define _PD_KERNEL_REGISTRAR_INIT_10(reg_type,
\
kernel_name, \
kernel_name,
\
backend, \
backend,
\
context, \
context,
\
layout, \
layout,
\
registrar_id, \
registrar_id,
\
args_def_fn, \
args_def_fn,
\
meta_kernel_fn, \
meta_kernel_fn,
\
cpp_dtype, \
cpp_dtype,
\
...) \
...)
\
static const ::phi::KernelRegistrar PD_CONCATENATE( \
static const ::phi::KernelRegistrar PD_CONCATENATE(
\
__reg_p
t
_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \
__reg_p
hi
_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \
reg_type, \
reg_type,
\
#kernel_name, \
#kernel_name,
\
#backend, \
#backend,
\
DATALAYOUT(layout), \
DATALAYOUT(layout),
\
::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(), \
::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(),
\
::phi::KernelArgsParseFunctor< \
::phi::KernelArgsParseFunctor<
\
decltype(&meta_kernel_fn<cpp_dtype, context>)>::Parse, \
decltype(&meta_kernel_fn<cpp_dtype, context>)>::Parse,
\
args_def_fn, \
args_def_fn,
\
PHI_KERNEL(meta_kernel_fn<cpp_dtype, context>), \
PHI_KERNEL(meta_kernel_fn<cpp_dtype, context>),
\
PHI_VARIADIC_KERNEL(meta_kernel_fn<cpp_dtype, context>)); \
PHI_VARIADIC_KERNEL(meta_kernel_fn<cpp_dtype, context>));
\
PD_EXPAND(_PD_KERNEL_REGISTRAR_INIT_9(reg_type, \
PD_EXPAND(_PD_KERNEL_REGISTRAR_INIT_9(reg_type,
\
kernel_name, \
kernel_name,
\
backend, \
backend,
\
context, \
context,
\
layout, \
layout,
\
PD_ID, \
PD_ID,
\
args_def_fn, \
args_def_fn,
\
meta_kernel_fn, \
meta_kernel_fn,
\
__VA_ARGS__))
__VA_ARGS__))
#define _PD_KERNEL_REGISTRAR_INIT_11(reg_type, \
#define _PD_KERNEL_REGISTRAR_INIT_11(reg_type,
\
kernel_name, \
kernel_name,
\
backend, \
backend,
\
context, \
context,
\
layout, \
layout,
\
registrar_id, \
registrar_id,
\
args_def_fn, \
args_def_fn,
\
meta_kernel_fn, \
meta_kernel_fn,
\
cpp_dtype, \
cpp_dtype,
\
...) \
...)
\
static const ::phi::KernelRegistrar PD_CONCATENATE( \
static const ::phi::KernelRegistrar PD_CONCATENATE(
\
__reg_p
t
_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \
__reg_p
hi
_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \
reg_type, \
reg_type,
\
#kernel_name, \
#kernel_name,
\
#backend, \
#backend,
\
DATALAYOUT(layout), \
DATALAYOUT(layout),
\
::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(), \
::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(),
\
::phi::KernelArgsParseFunctor< \
::phi::KernelArgsParseFunctor<
\
decltype(&meta_kernel_fn<cpp_dtype, context>)>::Parse, \
decltype(&meta_kernel_fn<cpp_dtype, context>)>::Parse,
\
args_def_fn, \
args_def_fn,
\
PHI_KERNEL(meta_kernel_fn<cpp_dtype, context>), \
PHI_KERNEL(meta_kernel_fn<cpp_dtype, context>),
\
PHI_VARIADIC_KERNEL(meta_kernel_fn<cpp_dtype, context>)); \
PHI_VARIADIC_KERNEL(meta_kernel_fn<cpp_dtype, context>));
\
PD_EXPAND(_PD_KERNEL_REGISTRAR_INIT_10(reg_type, \
PD_EXPAND(_PD_KERNEL_REGISTRAR_INIT_10(reg_type,
\
kernel_name, \
kernel_name,
\
backend, \
backend,
\
context, \
context,
\
layout, \
layout,
\
PD_ID, \
PD_ID,
\
args_def_fn, \
args_def_fn,
\
meta_kernel_fn, \
meta_kernel_fn,
\
__VA_ARGS__))
__VA_ARGS__))
#define _PD_KERNEL_REGISTRAR_INIT_12(reg_type, \
#define _PD_KERNEL_REGISTRAR_INIT_12(reg_type,
\
kernel_name, \
kernel_name,
\
backend, \
backend,
\
context, \
context,
\
layout, \
layout,
\
registrar_id, \
registrar_id,
\
args_def_fn, \
args_def_fn,
\
meta_kernel_fn, \
meta_kernel_fn,
\
cpp_dtype, \
cpp_dtype,
\
...) \
...)
\
static const ::phi::KernelRegistrar PD_CONCATENATE( \
static const ::phi::KernelRegistrar PD_CONCATENATE(
\
__reg_p
t
_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \
__reg_p
hi
_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \
reg_type, \
reg_type,
\
#kernel_name, \
#kernel_name,
\
#backend, \
#backend,
\
DATALAYOUT(layout), \
DATALAYOUT(layout),
\
::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(), \
::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(),
\
::phi::KernelArgsParseFunctor< \
::phi::KernelArgsParseFunctor<
\
decltype(&meta_kernel_fn<cpp_dtype, context>)>::Parse, \
decltype(&meta_kernel_fn<cpp_dtype, context>)>::Parse,
\
args_def_fn, \
args_def_fn,
\
PHI_KERNEL(meta_kernel_fn<cpp_dtype, context>), \
PHI_KERNEL(meta_kernel_fn<cpp_dtype, context>),
\
PHI_VARIADIC_KERNEL(meta_kernel_fn<cpp_dtype, context>)); \
PHI_VARIADIC_KERNEL(meta_kernel_fn<cpp_dtype, context>));
\
PD_EXPAND(_PD_KERNEL_REGISTRAR_INIT_11(reg_type, \
PD_EXPAND(_PD_KERNEL_REGISTRAR_INIT_11(reg_type,
\
kernel_name, \
kernel_name,
\
backend, \
backend,
\
context, \
context,
\
layout, \
layout,
\
PD_ID, \
PD_ID,
\
args_def_fn, \
args_def_fn,
\
meta_kernel_fn, \
meta_kernel_fn,
\
__VA_ARGS__))
__VA_ARGS__))
#define _PD_KERNEL_REGISTRAR_INIT_13(reg_type, \
#define _PD_KERNEL_REGISTRAR_INIT_13(reg_type,
\
kernel_name, \
kernel_name,
\
backend, \
backend,
\
context, \
context,
\
layout, \
layout,
\
registrar_id, \
registrar_id,
\
args_def_fn, \
args_def_fn,
\
meta_kernel_fn, \
meta_kernel_fn,
\
cpp_dtype, \
cpp_dtype,
\
...) \
...)
\
static const ::phi::KernelRegistrar PD_CONCATENATE( \
static const ::phi::KernelRegistrar PD_CONCATENATE(
\
__reg_p
t
_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \
__reg_p
hi
_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \
reg_type, \
reg_type,
\
#kernel_name, \
#kernel_name,
\
#backend, \
#backend,
\
DATALAYOUT(layout), \
DATALAYOUT(layout),
\
::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(), \
::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(),
\
::phi::KernelArgsParseFunctor< \
::phi::KernelArgsParseFunctor<
\
decltype(&meta_kernel_fn<cpp_dtype, context>)>::Parse, \
decltype(&meta_kernel_fn<cpp_dtype, context>)>::Parse,
\
args_def_fn, \
args_def_fn,
\
PHI_KERNEL(meta_kernel_fn<cpp_dtype, context>), \
PHI_KERNEL(meta_kernel_fn<cpp_dtype, context>),
\
PHI_VARIADIC_KERNEL(meta_kernel_fn<cpp_dtype, context>)); \
PHI_VARIADIC_KERNEL(meta_kernel_fn<cpp_dtype, context>));
\
PD_EXPAND(_PD_KERNEL_REGISTRAR_INIT_12(reg_type, \
PD_EXPAND(_PD_KERNEL_REGISTRAR_INIT_12(reg_type,
\
kernel_name, \
kernel_name,
\
backend, \
backend,
\
context, \
context,
\
layout, \
layout,
\
PD_ID, \
PD_ID,
\
args_def_fn, \
args_def_fn,
\
meta_kernel_fn, \
meta_kernel_fn,
\
__VA_ARGS__))
__VA_ARGS__))
#define _PD_KERNEL_REGISTRAR_INIT_14(reg_type, \
#define _PD_KERNEL_REGISTRAR_INIT_14(reg_type,
\
kernel_name, \
kernel_name,
\
backend, \
backend,
\
context, \
context,
\
layout, \
layout,
\
registrar_id, \
registrar_id,
\
args_def_fn, \
args_def_fn,
\
meta_kernel_fn, \
meta_kernel_fn,
\
cpp_dtype, \
cpp_dtype,
\
...) \
...)
\
static const ::phi::KernelRegistrar PD_CONCATENATE( \
static const ::phi::KernelRegistrar PD_CONCATENATE(
\
__reg_p
t
_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \
__reg_p
hi
_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \
reg_type, \
reg_type,
\
#kernel_name, \
#kernel_name,
\
#backend, \
#backend,
\
DATALAYOUT(layout), \
DATALAYOUT(layout),
\
::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(), \
::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(),
\
::phi::KernelArgsParseFunctor< \
::phi::KernelArgsParseFunctor<
\
decltype(&meta_kernel_fn<cpp_dtype, context>)>::Parse, \
decltype(&meta_kernel_fn<cpp_dtype, context>)>::Parse,
\
args_def_fn, \
args_def_fn,
\
PHI_KERNEL(meta_kernel_fn<cpp_dtype, context>), \
PHI_KERNEL(meta_kernel_fn<cpp_dtype, context>),
\
PHI_VARIADIC_KERNEL(meta_kernel_fn<cpp_dtype, context>)); \
PHI_VARIADIC_KERNEL(meta_kernel_fn<cpp_dtype, context>));
\
PD_EXPAND(_PD_KERNEL_REGISTRAR_INIT_13(reg_type, \
PD_EXPAND(_PD_KERNEL_REGISTRAR_INIT_13(reg_type,
\
kernel_name, \
kernel_name,
\
backend, \
backend,
\
context, \
context,
\
layout, \
layout,
\
PD_ID, \
PD_ID,
\
args_def_fn, \
args_def_fn,
\
meta_kernel_fn, \
meta_kernel_fn,
\
__VA_ARGS__))
__VA_ARGS__))
#define _PD_KERNEL_REGISTRAR_INIT_15(reg_type, \
#define _PD_KERNEL_REGISTRAR_INIT_15(reg_type,
\
kernel_name, \
kernel_name,
\
backend, \
backend,
\
context, \
context,
\
layout, \
layout,
\
registrar_id, \
registrar_id,
\
args_def_fn, \
args_def_fn,
\
meta_kernel_fn, \
meta_kernel_fn,
\
cpp_dtype, \
cpp_dtype,
\
...) \
...)
\
static const ::phi::KernelRegistrar PD_CONCATENATE( \
static const ::phi::KernelRegistrar PD_CONCATENATE(
\
__reg_p
t
_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \
__reg_p
hi
_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \
reg_type, \
reg_type,
\
#kernel_name, \
#kernel_name,
\
#backend, \
#backend,
\
DATALAYOUT(layout), \
DATALAYOUT(layout),
\
::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(), \
::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(),
\
::phi::KernelArgsParseFunctor< \
::phi::KernelArgsParseFunctor<
\
decltype(&meta_kernel_fn<cpp_dtype, context>)>::Parse, \
decltype(&meta_kernel_fn<cpp_dtype, context>)>::Parse,
\
args_def_fn, \
args_def_fn,
\
PHI_KERNEL(meta_kernel_fn<cpp_dtype, context>), \
PHI_KERNEL(meta_kernel_fn<cpp_dtype, context>),
\
PHI_VARIADIC_KERNEL(meta_kernel_fn<cpp_dtype, context>)); \
PHI_VARIADIC_KERNEL(meta_kernel_fn<cpp_dtype, context>));
\
PD_EXPAND(_PD_KERNEL_REGISTRAR_INIT_14(reg_type, \
PD_EXPAND(_PD_KERNEL_REGISTRAR_INIT_14(reg_type,
\
kernel_name, \
kernel_name,
\
backend, \
backend,
\
context, \
context,
\
layout, \
layout,
\
PD_ID, \
PD_ID,
\
args_def_fn, \
args_def_fn,
\
meta_kernel_fn, \
meta_kernel_fn,
\
__VA_ARGS__))
__VA_ARGS__))
/** PD_REGISTER_GENERAL_KERNEL
/** PD_REGISTER_GENERAL_KERNEL
*
*
...
@@ -1035,7 +1035,7 @@ struct KernelRegistrar {
...
@@ -1035,7 +1035,7 @@ struct KernelRegistrar {
static void __PD_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout( \
static void __PD_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout( \
const ::phi::KernelKey& kernel_key, ::phi::Kernel* kernel); \
const ::phi::KernelKey& kernel_key, ::phi::Kernel* kernel); \
static const ::phi::KernelRegistrar \
static const ::phi::KernelRegistrar \
__reg_p
t_kernel_##kernel_name##_##backend##_##layout(
\
__reg_p
hi_kernel_##kernel_name##_##backend##_##layout(
\
reg_type, \
reg_type, \
#kernel_name, \
#kernel_name, \
#backend, \
#backend, \
...
@@ -1055,7 +1055,7 @@ struct KernelRegistrar {
...
@@ -1055,7 +1055,7 @@ struct KernelRegistrar {
static void __PD_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout( \
static void __PD_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout( \
const ::phi::KernelKey& kernel_key, ::phi::Kernel* kernel); \
const ::phi::KernelKey& kernel_key, ::phi::Kernel* kernel); \
static const ::phi::KernelRegistrar \
static const ::phi::KernelRegistrar \
__reg_p
t_kernel_##kernel_name##_##backend##_##layout(
\
__reg_p
hi_kernel_##kernel_name##_##backend##_##layout(
\
reg_type, \
reg_type, \
#kernel_name, \
#kernel_name, \
#backend, \
#backend, \
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录