Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
c751e405
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
c751e405
编写于
3月 23, 2022
作者:
王
王明冬
提交者:
GitHub
3月 23, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[infrt] add ir support for phi kernel batch_norm_infer. (#40755)
上级
8e67629c
变更
10
隐藏空白更改
内联
并排
Showing
10 changed file
with
85 addition
and
71 deletion
+85
-71
paddle/infrt/dialect/phi/pass/phi_op_convert_pass.cc
paddle/infrt/dialect/phi/pass/phi_op_convert_pass.cc
+3
-3
paddle/infrt/dialect/phi/pass/proto_arg_map_context.cc
paddle/infrt/dialect/phi/pass/proto_arg_map_context.cc
+12
-5
paddle/infrt/host_context/value.h
paddle/infrt/host_context/value.h
+1
-0
paddle/infrt/kernel/phi/infershaped/infershaped_kernel_launcher.cc
...frt/kernel/phi/infershaped/infershaped_kernel_launcher.cc
+5
-0
paddle/infrt/kernel/phi/infershaped/infershaped_kernel_launcher.h
...nfrt/kernel/phi/infershaped/infershaped_kernel_launcher.h
+3
-5
paddle/infrt/kernel/phi/infershaped/phi_kernel_launcher.h
paddle/infrt/kernel/phi/infershaped/phi_kernel_launcher.h
+26
-28
paddle/infrt/tests/dialect/phi/phi_test.mlir
paddle/infrt/tests/dialect/phi/phi_test.mlir
+17
-4
tools/infrt/generate_phi_kernel_dialect.py
tools/infrt/generate_phi_kernel_dialect.py
+3
-1
tools/infrt/get_phi_kernel_function.sh
tools/infrt/get_phi_kernel_function.sh
+8
-7
tools/infrt/get_phi_kernel_info.py
tools/infrt/get_phi_kernel_info.py
+7
-18
未找到文件。
paddle/infrt/dialect/phi/pass/phi_op_convert_pass.cc
浏览文件 @
c751e405
...
...
@@ -97,12 +97,12 @@ void PhiOpConvertPass::convertStage() {
}
auto
loc
=
getFunction
().
getLoc
();
builder
.
setInsertionPoint
(
op
);
if
(
phi
::
KernelFactory
::
Instance
().
HasCompatiblePhiKernel
(
op_name
))
{
std
::
string
kernel_name
=
phi
::
TransToPhiKernelName
(
op_name
);
op_name
=
phi
::
TransToPhiKernelName
(
op_name
);
if
(
!::
phi
::
OpUtilsMap
::
Instance
().
Contains
(
op_name
))
{
auto
kernel_op
=
builder
.
create
<
infrt
::
KernelOp
>
(
loc
,
op
->
getResultTypes
(),
op
->
getOperands
(),
kernel
_name
,
op
_name
,
op
->
getAttrDictionary
());
op
->
replaceAllUsesWith
(
kernel_op
.
getResults
());
}
else
{
...
...
paddle/infrt/dialect/phi/pass/proto_arg_map_context.cc
浏览文件 @
c751e405
...
...
@@ -32,17 +32,24 @@ bool ProtoArgumentMappingContext::HasOutput(const std::string& name) const {
}
bool
ProtoArgumentMappingContext
::
HasAttr
(
const
std
::
string
&
name
)
const
{
if
(
name
==
"is_test"
)
return
true
;
return
op_
->
hasAttr
(
name
);
}
paddle
::
any
ProtoArgumentMappingContext
::
Attr
(
const
std
::
string
&
name
)
const
{
mlir
::
Attribute
attrs
=
op_
->
getAttr
(
name
);
if
(
mlir
::
StringAttr
str_attr
=
attrs
.
dyn_cast_or_null
<
mlir
::
StringAttr
>
())
{
if
(
name
==
"is_test"
)
{
return
paddle
::
any
(
true
);
}
mlir
::
Attribute
attr
=
op_
->
getAttr
(
name
);
if
(
!
attr
)
{
return
paddle
::
any
();
}
if
(
mlir
::
StringAttr
str_attr
=
attr
.
dyn_cast
<
mlir
::
StringAttr
>
())
{
return
paddle
::
any
(
str_attr
.
str
());
}
else
{
// ToDO: implementation in the ext PR.
return
paddle
::
any
(
0
);
}
// ToDO: implementation in the ext PR.
return
paddle
::
any
(
0
);
}
size_t
ProtoArgumentMappingContext
::
InputSize
(
const
std
::
string
&
name
)
const
{
...
...
paddle/infrt/host_context/value.h
浏览文件 @
c751e405
...
...
@@ -147,6 +147,7 @@ class Value : public common::Object {
#endif
explicit
Value
(
::
phi
::
DenseTensor
&&
x
)
:
data
(
std
::
move
(
x
))
{}
explicit
Value
(
::
phi
::
MetaTensor
&&
x
)
:
data
(
std
::
move
(
x
))
{}
explicit
Value
(
::
phi
::
MetaConfig
&&
x
)
:
data
(
std
::
move
(
x
))
{}
#ifdef INFRT_WITH_TRT
explicit
Value
(
::
infrt
::
backends
::
tensorrt
::
TrtEngine
&&
x
)
:
data
(
std
::
move
(
x
))
{}
...
...
paddle/infrt/kernel/phi/infershaped/infershaped_kernel_launcher.cc
浏览文件 @
c751e405
...
...
@@ -14,6 +14,7 @@
#include "paddle/infrt/kernel/phi/infershaped/infershaped_kernel_launcher.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/meta_tensor.h"
namespace
infrt
{
namespace
kernel
{
...
...
@@ -31,6 +32,10 @@ void InferShapedKernelLauncher::CreateKernelFrameForInferShape(
infershape_kernel_frame_builder
.
AddArgument
(
value
);
}
}
if
(
infershape_kernel_frame_builder
.
GetNumArgs
()
<
arg_size_
)
{
infershape_kernel_frame_builder
.
AddArgument
(
new
host_context
::
Value
(
::
phi
::
MetaConfig
()));
}
}
void
InferShapedKernelLauncher
::
BuildInferShapeCache
(
...
...
paddle/infrt/kernel/phi/infershaped/infershaped_kernel_launcher.h
浏览文件 @
c751e405
...
...
@@ -22,11 +22,8 @@ namespace infrt {
namespace
kernel
{
struct
InferShapedKernelLauncher
{
virtual
void
Invoke
(
host_context
::
KernelFrame
*
frame
)
=
0
;
virtual
~
InferShapedKernelLauncher
()
=
default
;
protected:
explicit
InferShapedKernelLauncher
(
int
arg_size
)
:
arg_size_
(
arg_size
)
{}
~
InferShapedKernelLauncher
()
=
default
;
//! Initialize the kernel frame for InferShape kernel.
// This method will create a new KernelFrame with all the Tensors(currently
// only DenseHostTensor) converted into MetaTensors so that the infer-shape
...
...
@@ -46,6 +43,7 @@ struct InferShapedKernelLauncher {
llvm
::
SmallVector
<
host_context
::
ValueRef
,
3
>
values
;
llvm
::
SmallVector
<::
phi
::
DDim
,
3
>
tensor_shape_cache
;
host_context
::
KernelFrameBuilder
infershape_kernel_frame_builder
;
const
int
arg_size_
;
};
}
// namespace kernel
...
...
paddle/infrt/kernel/phi/infershaped/phi_kernel_launcher.h
浏览文件 @
c751e405
...
...
@@ -24,46 +24,44 @@
namespace
infrt
{
namespace
kernel
{
template
<
typename
F
>
struct
FuncArgStatics
{};
template
<
typename
Return
,
typename
...
Args
>
struct
FuncArgStatics
<
Return
(
*
)(
Args
...)
>
{
constexpr
static
int
arg_size
=
sizeof
...(
Args
);
};
template
<
typename
KernelFunc
,
KernelFunc
kernel
,
typename
InferShapedFunc
,
InferShapedFunc
infershape
>
class
KernelLauncher
:
public
InferShapedKernelLauncher
{
public:
void
KernelLauncherFunc
(
host_context
::
KernelFrame
*
frame
)
{
static
InferShapedKernelLauncher
launcher
(
FuncArgStatics
<
InferShapedFunc
>::
arg_size
);
static
const
uint16_t
num_input_tensors
{
InferShapeHelper
<
KernelFunc
>::
count
};
static
const
bool
turn_on_infer_shape_cache
{
true
};
void
Invoke
(
host_context
::
KernelFrame
*
frame
)
override
{
#ifndef NDEBUG
LOG
(
INFO
)
<<
"Kernel.frame: "
<<
frame
->
DumpArgTypes
();
LOG
(
INFO
)
<<
"Kernel.frame: "
<<
frame
->
DumpArgTypes
();
#endif
// Build the infershape KernelFrame if needed.
// TODO(Superjomn) add unlikely here.
if
(
infershape_kernel_frame_builder
.
IsEmpty
())
{
CreateKernelFrameForInferShape
(
frame
);
// Build the infershape KernelFrame if needed.
// TODO(Superjomn) add unlikely here.
if
(
launcher
.
infershape_kernel_frame_builder
.
IsEmpty
())
{
launcher
.
CreateKernelFrameForInferShape
(
frame
);
#ifndef NDEBUG
LOG
(
INFO
)
<<
"infershape.frame: "
<<
infershape_kernel_frame_builder
.
DumpArgTypes
();
LOG
(
INFO
)
<<
"infershape.frame: "
<<
launcher
.
infershape_kernel_frame_builder
.
DumpArgTypes
();
#endif
}
if
(
turn_on_infer_shape_cache
)
{
if
(
launcher
.
IsShapeChanged
(
num_input_tensors
))
{
::
infrt
::
host_context
::
KernelImpl
<
InferShapedFunc
,
infershape
>::
Invoke
(
&
launcher
.
infershape_kernel_frame_builder
);
launcher
.
BuildInferShapeCache
(
num_input_tensors
);
}
if
(
turn_on_infer_shape_cache
)
{
if
(
!
turn_on_infer_shape_cache
||
IsShapeChanged
(
num_input_tensors
))
{
::
infrt
::
host_context
::
KernelImpl
<
InferShapedFunc
,
infershape
>::
Invoke
(
&
infershape_kernel_frame_builder
);
BuildInferShapeCache
(
num_input_tensors
);
}
}
::
infrt
::
host_context
::
KernelImpl
<
KernelFunc
,
kernel
>::
Invoke
(
frame
);
}
};
template
<
typename
KernelFunc
,
KernelFunc
kernel
,
typename
InferShapedFunc
,
InferShapedFunc
infershape
>
void
KernelLauncherFunc
(
KernelLauncher
<
KernelFunc
,
kernel
,
InferShapedFunc
,
infershape
>
launcher
,
host_context
::
KernelFrame
*
frame
)
{
launcher
.
Invoke
(
frame
);
::
infrt
::
host_context
::
KernelImpl
<
KernelFunc
,
kernel
>::
Invoke
(
frame
);
}
}
// namespace kernel
...
...
paddle/infrt/tests/dialect/phi/phi_test.mlir
浏览文件 @
c751e405
// RUN: infrtexec -i %s
module {
func @predict(%arg0: !infrt.dense_tensor<CPU, FP32, NCHW>) -> !infrt.dense_tensor<CPU, FP32, NCHW> {
func @predict(%arg0: !infrt.dense_tensor<CPU, FP32, NCHW>
, %arg1: !infrt.dense_tensor<CPU, FP32, NCHW>, %arg2: !infrt.dense_tensor<CPU, FP32, NCHW>, %arg3: !infrt.dense_tensor<CPU, FP32, NCHW>, %arg4: !infrt.dense_tensor<CPU, FP32, NCHW>
) -> !infrt.dense_tensor<CPU, FP32, NCHW> {
%2 = "pd.abs"(%arg0) : (!infrt.dense_tensor<CPU, FP32, NCHW>) -> !infrt.dense_tensor<CPU, FP32, NCHW>
infrt.return %2 : !infrt.dense_tensor<CPU, FP32, NCHW>
%3 = "pd.matmul_v2"(%arg0, %2) {trans_x = false, trans_y = false} : (!infrt.dense_tensor<CPU, FP32, NCHW>, !infrt.dense_tensor<CPU, FP32, NCHW>) -> !infrt.dense_tensor<CPU, FP32, NCHW>
%Y, %MeanOut, %VarianceOut = "pd.batch_norm"(%3, %arg1, %arg2, %arg3, %arg4) {data_layout = "NCHW", epsilon = 9.99999974E-6 : f32, momentum = 0.899999976 : f32} : (!infrt.dense_tensor<CPU, FP32, NCHW>, !infrt.dense_tensor<CPU, FP32, NCHW>, !infrt.dense_tensor<CPU, FP32, NCHW>, !infrt.dense_tensor<CPU, FP32, NCHW>, !infrt.dense_tensor<CPU, FP32, NCHW>) -> (!infrt.dense_tensor<CPU, FP32, NCHW>, !infrt.dense_tensor<CPU, FP32, NCHW>, !infrt.dense_tensor<CPU, FP32, NCHW>)
infrt.return %Y : !infrt.dense_tensor<CPU, FP32, NCHW>
}
func @main() {
%ctx = "phi_dt.create_context.cpu" (): () -> !phi.context<CPU>
%t = "phi_dt.create_dense_tensor.cpu" (%ctx) {precision=#infrt.precision<FP32>, layout=#infrt.layout<NCHW>, lod=[1:i64], dims=[1:i64]}: (!phi.context<CPU>) -> (!infrt.dense_tensor<CPU, FP32, NCHW>)
%t = "phi_dt.create_dense_tensor.cpu" (%ctx) {precision=#infrt.precision<FP32>, layout=#infrt.layout<NCHW>, lod=[1:i64], dims=[1:i64
, 3:i64, 8:i64, 8:i64
]}: (!phi.context<CPU>) -> (!infrt.dense_tensor<CPU, FP32, NCHW>)
"phi_dt.fill_dense_tensor.f32"(%t) {value=[3.8:f32]} : (!infrt.dense_tensor<CPU, FP32, NCHW>) -> ()
%2 = infrt.call@predict(%t) : (!infrt.dense_tensor<CPU, FP32, NCHW>) -> !infrt.dense_tensor<CPU, FP32, NCHW>
%bias = "phi_dt.create_dense_tensor.cpu" (%ctx) {precision=#infrt.precision<FP32>, layout=#infrt.layout<NCHW>, lod=[1:i64], dims=[3:i64]}: (!phi.context<CPU>) -> (!infrt.dense_tensor<CPU, FP32, NCHW>)
"phi_dt.fill_dense_tensor.f32"(%bias) {value=[1.5:f32]} : (!infrt.dense_tensor<CPU, FP32, NCHW>) -> ()
%mean = "phi_dt.create_dense_tensor.cpu" (%ctx) {precision=#infrt.precision<FP32>, layout=#infrt.layout<NCHW>, lod=[1:i64], dims=[3:i64]}: (!phi.context<CPU>) -> (!infrt.dense_tensor<CPU, FP32, NCHW>)
"phi_dt.fill_dense_tensor.f32"(%mean) {value=[3.5:f32]} : (!infrt.dense_tensor<CPU, FP32, NCHW>) -> ()
%scale = "phi_dt.create_dense_tensor.cpu" (%ctx) {precision=#infrt.precision<FP32>, layout=#infrt.layout<NCHW>, lod=[1:i64], dims=[3:i64]}: (!phi.context<CPU>) -> (!infrt.dense_tensor<CPU, FP32, NCHW>)
"phi_dt.fill_dense_tensor.f32"(%scale) {value=[1.0:f32]} : (!infrt.dense_tensor<CPU, FP32, NCHW>) -> ()
%var = "phi_dt.create_dense_tensor.cpu" (%ctx) {precision=#infrt.precision<FP32>, layout=#infrt.layout<NCHW>, lod=[1:i64], dims=[3:i64]}: (!phi.context<CPU>) -> (!infrt.dense_tensor<CPU, FP32, NCHW>)
"phi_dt.fill_dense_tensor.f32"(%var) {value=[0.0:f32]} : (!infrt.dense_tensor<CPU, FP32, NCHW>) -> ()
%2 = infrt.call@predict(%t, %bias, %mean, %scale, %var) : (!infrt.dense_tensor<CPU, FP32, NCHW>, !infrt.dense_tensor<CPU, FP32, NCHW>,!infrt.dense_tensor<CPU, FP32, NCHW>,!infrt.dense_tensor<CPU, FP32, NCHW>,!infrt.dense_tensor<CPU, FP32, NCHW>) -> !infrt.dense_tensor<CPU, FP32, NCHW>
//phi_dt.print_tensor(%t : !infrt.dense_tensor<CPU, FP32, NCHW>)
phi_dt.print_tensor(%2 : !infrt.dense_tensor<CPU, FP32, NCHW>)
infrt.return
}
...
...
tools/infrt/generate_phi_kernel_dialect.py
浏览文件 @
c751e405
...
...
@@ -22,7 +22,9 @@ attr_type_converter = {
"i"
:
'SI32Attr'
,
"b"
:
'BoolAttr'
,
"l"
:
'SI64Attr'
,
"f"
:
'F32Attr'
"f"
:
'F32Attr'
,
"NSt7__cxx1112basic_stringIcSt11char_traitsIcESaIcEEE"
:
'StrAttr'
,
"St6vectorIiSaIiEE"
:
'I32ArrayAttr'
}
target_type_converter
=
{
"CPU"
:
"CPU"
,
"GPU"
:
"GPU"
}
...
...
tools/infrt/get_phi_kernel_function.sh
浏览文件 @
c751e405
...
...
@@ -38,35 +38,36 @@ python3 ${PADDLE_ROOT}/python/paddle/utils/code_gen/wrapped_infermeta_gen.py \
--wrapped_infermeta_header_path
${
temp_path
}
/generate.h
\
--wrapped_infermeta_source_path
${
temp_path
}
/generate.cc
grep
PD_REGISTER_INFER_META_FN
${
temp_path
}
/generate.cc
\
find
${
PADDLE_ROOT
}
/paddle/phi/
-name
"*.cc"
| xargs
grep
PD_REGISTER_INFER_META_FN
${
temp_path
}
/generate.cc
\
|
awk
-F
"
\(
|,|::|
\)
"
'{print $2, $4}'
>
${
temp_path
}
/wrap_info.txt
#step 3:get ir's attr_name.
ir_attr_name_info_file
=
`
mktemp
`
# phi_cpu attr
all_ir_name
=
`
grep
-Eo
"PDTCPU_Kernel<.*
\"
"
paddle/infrt/dialect/phi/ir/phi_cpu_kernels.td |
awk
-v
FS
=
"<"
'{gsub(/\"/,"");print $2}'
`
all_ir_name
=
`
grep
-Eo
"PDTCPU_Kernel<.*
\"
"
${
PADDLE_ROOT
}
/
paddle/infrt/dialect/phi/ir/phi_cpu_kernels.td |
awk
-v
FS
=
"<"
'{gsub(/\"/,"");print $2}'
`
for
ir
in
$all_ir_name
do
attr_name
=
`
grep
"<
\"
$ir
"
-A
3 paddle/infrt/dialect/phi/ir/phi_cpu_kernels.td |
grep
-Eo
"Attr:.*)"
\
attr_name
=
`
grep
"<
\"
$ir
"
-A
3
${
PADDLE_ROOT
}
/
paddle/infrt/dialect/phi/ir/phi_cpu_kernels.td |
grep
-Eo
"Attr:.*)"
\
|
awk
'{gsub(/F32Attr/,"");gsub(/F64Attr/,"");gsub(/StrAttr/,"");gsub(/BoolAttr/,""); \
gsub(/SI1Attr/,"");gsub(/SI8Attr/,"");gsub(/SI16Attr/,"");gsub(/SI32Attr/,"");gsub(/SI64Attr/,""); \
gsub(/UI1Attr/,"");gsub(/UI8Attr/,"");gsub(/I16Attr/,"");gsub(/I32Attr/,"");gsub(/I64Attr/,""); \
gsub(/I1Attr/,"");gsub(/I8Attr/,"");gsub(/UI16Attr/,"");gsub(/UI32Attr/,"");gsub(/UI64Attr/,""); \
gsub(/I32ArrayAttr/,"");gsub(/SI32ArrayAttr/,""); \
gsub(/Attr/,"");gsub(/\)/,""); \
gsub(/[,:]/,"");print $a}'
`
echo
phi_cpu.
$ir
$attr_name
>>
$ir_attr_name_info_file
done
# phi_gpu attr
all_ir_name
=
`
grep
-Eo
"PDTGPU_Kernel<.*
\"
"
paddle/infrt/dialect/phi/ir/phi_gpu_kernels.td |
awk
-v
FS
=
"<"
'{gsub(/\"/,"");print $2}'
`
all_ir_name
=
`
grep
-Eo
"PDTGPU_Kernel<.*
\"
"
${
PADDLE_ROOT
}
/
paddle/infrt/dialect/phi/ir/phi_gpu_kernels.td |
awk
-v
FS
=
"<"
'{gsub(/\"/,"");print $2}'
`
for
ir
in
$all_ir_name
do
attr_name
=
`
grep
"<
\"
$ir
"
-A
3 paddle/infrt/dialect/phi/ir/phi_gpu_kernels.td |
grep
-Eo
"Attr:.*)"
\
attr_name
=
`
grep
"<
\"
$ir
"
-A
3
${
PADDLE_ROOT
}
/
paddle/infrt/dialect/phi/ir/phi_gpu_kernels.td |
grep
-Eo
"Attr:.*)"
\
|
awk
'{gsub(/F32Attr/,"");gsub(/F64Attr/,"");gsub(/StrAttr/,"");gsub(/BoolAttr/,""); \
gsub(/SI1Attr/,"");gsub(/SI8Attr/,"");gsub(/SI16Attr/,"");gsub(/SI32Attr/,"");gsub(/SI64Attr/,""); \
gsub(/UI1Attr/,"");gsub(/UI8Attr/,"");gsub(/I16Attr/,"");gsub(/I32Attr/,"");gsub(/I64Attr/,""); \
gsub(/I1Attr/,"");gsub(/I8Attr/,"");gsub(/UI16Attr/,"");gsub(/UI32Attr/,"");gsub(/UI64Attr/,""); \
gsub(/Attr/,"");gsub(/\)/,""); \
gsub(/I32ArrayAttr/,"");gsub(/SI32ArrayAttr/,""); \
gsub(/Attr/,"");gsub(/\)/,"") \
gsub(/[,:]/,"");print $a}'
`
echo
phi_gpu.
$ir
$attr_name
>>
$ir_attr_name_info_file
done
...
...
tools/infrt/get_phi_kernel_info.py
浏览文件 @
c751e405
...
...
@@ -91,11 +91,10 @@ def merge(infer_meta_data, kernel_data, wrap_data):
full_kernel_data
=
[]
for
l
in
kernel_data
:
key
=
l
.
split
()[
0
]
if
key
in
meta_map
:
if
key
in
meta_map
:
full_kernel_data
.
append
((
l
+
" "
+
wrap_map
[
key
]).
split
())
else
:
full_kernel_data
.
append
((
l
+
" "
+
meta_map
[
key
]).
split
())
if
key
in
wrap_map
:
full_kernel_data
.
append
((
l
+
" "
+
wrap_map
[
key
]).
split
())
elif
key
in
meta_map
:
full_kernel_data
.
append
((
l
+
" "
+
meta_map
[
key
]).
split
())
else
:
full_kernel_data
.
append
((
l
+
" unknown"
).
split
())
...
...
@@ -246,15 +245,10 @@ def gen_register_code_info(item: List[str], attr_data: Dict[str, List[str]]):
registry->AddKernelWithAttrs("
{
ir_name
}
","""
res
+=
f
"""
std::bind(
&KernelLauncherFunc<decltype(
{
kernel_func
}
),
&KernelLauncherFunc<decltype(
{
kernel_func
}
),
{
kernel_func
}
,
decltype(
{
infer_shape_func
}
),
{
infer_shape_func
}
>,
KernelLauncher<decltype(
{
kernel_func
}
),
{
kernel_func
}
,
decltype(
{
infer_shape_func
}
),
{
infer_shape_func
}
>(),
std::placeholders::_1),
{{
{
attr_names
}
}});
"""
...
...
@@ -263,15 +257,10 @@ registry->AddKernelWithAttrs("{ir_name}","""
registry->AddKernel("
{
ir_name
}
","""
res
+=
f
"""
std::bind(&KernelLauncherFunc<decltype(
{
kernel_func
}
),
{
kernel_func
}
,
decltype(
{
infer_shape_func
}
),
{
infer_shape_func
}
>,
KernelLauncher<decltype(
{
kernel_func
}
),
&KernelLauncherFunc<decltype(
{
kernel_func
}
),
{
kernel_func
}
,
decltype(
{
infer_shape_func
}
),
{
infer_shape_func
}
>(),
std::placeholders::_1));
{
infer_shape_func
}
>);
"""
return
res
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录