Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
767647ce
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
767647ce
编写于
3月 09, 2022
作者:
H
huzhiqiang
提交者:
GitHub
3月 09, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[Infrt]Update kernel dialect (#40141)
上级
aeaf69b3
变更
10
隐藏空白更改
内联
并排
Showing
10 changed file
with
174 addition
and
76 deletion
+174
-76
.gitignore
.gitignore
+1
-0
paddle/fluid/pybind/kernel_signature_generator.cc
paddle/fluid/pybind/kernel_signature_generator.cc
+22
-16
paddle/infrt/dialect/phi/pass/phi_op_cvt_pass.cc
paddle/infrt/dialect/phi/pass/phi_op_cvt_pass.cc
+2
-4
paddle/infrt/host_context/paddle_mlir.cc
paddle/infrt/host_context/paddle_mlir.cc
+10
-7
paddle/infrt/kernel/phi/infershaped/infershape_launchers_test.cc
...infrt/kernel/phi/infershaped/infershape_launchers_test.cc
+1
-1
paddle/infrt/tests/dialect/phi/dense_tensor.mlir
paddle/infrt/tests/dialect/phi/dense_tensor.mlir
+1
-1
paddle/scripts/infrt_build.sh
paddle/scripts/infrt_build.sh
+3
-2
tools/infrt/generate_phi_kernel_dialect.py
tools/infrt/generate_phi_kernel_dialect.py
+55
-43
tools/infrt/get_compat_kernel_signature.py
tools/infrt/get_compat_kernel_signature.py
+77
-0
tools/infrt/get_phi_kernel_info.py
tools/infrt/get_phi_kernel_info.py
+2
-2
未找到文件。
.gitignore
浏览文件 @
767647ce
...
...
@@ -56,6 +56,7 @@ paddle/infrt/dialect/pd_ops.td
paddle/infrt/dialect/phi/ir/phi_cpu_kernels.td
paddle/infrt/dialect/phi/ir/phi_gpu_kernels.td
tools/infrt/kernels.json
tools/infrt/kernel_signature.json
paddle/infrt/dialect/pd_ops_info.h
.lit_test_times.txt
paddle/infrt/tests/dialect/Output
...
...
paddle/fluid/pybind/kernel_signature_generator.cc
浏览文件 @
767647ce
...
...
@@ -44,35 +44,41 @@ int main(int argc, char **argv) {
paddle
::
framework
::
InitDefaultKernelSignatureMap
();
auto
&
kernel_signature_map
=
phi
::
DefaultKernelSignatureMap
::
Instance
();
auto
&
kernel_factory
=
phi
::
KernelFactory
::
Instance
();
std
::
cout
<<
"{"
;
std
::
string
kernel_signature_map_str
{
"{"
}
;
for
(
const
auto
&
op_kernel_pair
:
kernel_factory
.
kernels
())
{
if
(
kernel_signature_map
.
Has
(
op_kernel_pair
.
first
))
{
std
::
cout
<<
"
\"
"
<<
op_kernel_pair
.
first
<<
"
\"
:{"
;
kernel_signature_map_str
=
kernel_signature_map_str
+
"
\"
"
+
op_kernel_pair
.
first
+
"
\"
:{"
;
auto
&
args
=
kernel_signature_map
.
Get
(
op_kernel_pair
.
first
).
args
;
std
::
cout
<<
"
\"
inputs
\"
:["
;
kernel_signature_map_str
+=
"
\"
inputs
\"
:["
;
auto
inputs_
=
std
::
get
<
0
>
(
args
);
if
(
inputs_
.
size
()
>
0
)
std
::
cout
<<
inputs_
[
0
];
for
(
size_t
i
=
1
;
i
<
inputs_
.
size
();
i
++
)
{
std
::
cout
<<
",
\"
"
<<
inputs_
[
i
]
<<
"
\"
"
;
for
(
size_t
i
=
0
;
i
<
inputs_
.
size
();
i
++
)
{
kernel_signature_map_str
=
kernel_signature_map_str
+
"
\"
"
+
inputs_
[
i
]
+
"
\"
,
"
;
}
if
(
inputs_
.
size
())
kernel_signature_map_str
.
pop_back
();
std
::
cout
<<
"],
\"
attrs
\"
:["
;
kernel_signature_map_str
+=
"],
\"
attrs
\"
:["
;
auto
attrs_
=
std
::
get
<
1
>
(
args
);
if
(
attrs_
.
size
()
>
0
)
std
::
cout
<<
attrs_
[
0
];
for
(
size_t
i
=
1
;
i
<
attrs_
.
size
();
i
++
)
{
std
::
cout
<<
",
\"
"
<<
attrs_
[
i
]
<<
"
\"
"
;
for
(
size_t
i
=
0
;
i
<
attrs_
.
size
();
i
++
)
{
kernel_signature_map_str
=
kernel_signature_map_str
+
"
\"
"
+
attrs_
[
i
]
+
"
\"
,
"
;
}
std
::
cout
<<
"],
\"
outputs
\"
:["
;
if
(
attrs_
.
size
())
kernel_signature_map_str
.
pop_back
();
kernel_signature_map_str
+=
"],
\"
outputs
\"
:["
;
auto
outputs_
=
std
::
get
<
2
>
(
args
);
for
(
size_t
i
=
1
;
i
<
outputs_
.
size
();
i
++
)
{
std
::
cout
<<
",
\"
"
<<
outputs_
[
i
]
<<
"
\"
"
;
for
(
size_t
i
=
0
;
i
<
outputs_
.
size
();
i
++
)
{
kernel_signature_map_str
=
kernel_signature_map_str
+
"
\"
"
+
outputs_
[
i
]
+
"
\"
,"
;
}
std
::
cout
<<
"]},"
;
if
(
outputs_
.
size
())
kernel_signature_map_str
.
pop_back
();
kernel_signature_map_str
+=
"]},"
;
}
}
std
::
cout
<<
"}"
<<
std
::
endl
;
kernel_signature_map_str
.
pop_back
();
kernel_signature_map_str
+=
"}
\n
"
;
std
::
cout
<<
kernel_signature_map_str
;
return
0
;
}
paddle/infrt/dialect/phi/pass/phi_op_cvt_pass.cc
浏览文件 @
767647ce
...
...
@@ -125,10 +125,8 @@ void phiOpCvtPass::diapatchStage() {
kernel_name
=
getPhiTargetPrefix
(
phi_kernel_desc
.
kernelType
.
target
)
+
kernel_name
+
getPhiLayoutSuffix
(
phi_kernel_desc
.
kernelType
.
layout
)
+
getPhiPrecisionSuffix
(
phi_kernel_desc
.
kernelType
.
precision
);
// mlir::OperationName operation_name = kernel_op.getOperation()->getName();
getPhiPrecisionSuffix
(
phi_kernel_desc
.
kernelType
.
precision
)
+
getPhiLayoutSuffix
(
phi_kernel_desc
.
kernelType
.
layout
);
mlir
::
OperationName
operation_name
(
kernel_name
,
kernel_op
.
getContext
());
mlir
::
OperationState
operation_state
(
kernel_op
.
getLoc
(),
operation_name
);
...
...
paddle/infrt/host_context/paddle_mlir.cc
浏览文件 @
767647ce
...
...
@@ -56,6 +56,7 @@ mlir::ModuleOp MLIRModelGenImpl::ImportPaddleModel(
UpdateModelParams
(
program
,
&
mainFunc
);
UpdateModelOps
(
program
);
UpdateModelOutputs
(
program
);
return
module_
;
}
...
...
@@ -143,13 +144,14 @@ void MLIRModelGenImpl::UpdateModelParams(
const
infrt
::
paddle
::
framework_proto
::
ProgramDesc
&
program
,
mlir
::
FuncOp
*
mainFunc
)
{
// update input vars
int
input_index
=
1
;
for
(
auto
&
op_desc
:
main_block_
.
ops
())
{
if
(
op_desc
.
type
()
==
"feed"
)
{
for
(
int
var_idx
=
0
;
var_idx
<
op_desc
.
outputs_size
();
++
var_idx
)
{
// update input variables
auto
&
in
=
op_desc
.
outputs
()[
var_idx
];
std
::
string
input_var_name
=
in
.
arguments
(
0
);
::
mlir
::
Value
input_
=
mainFunc
->
getArgument
(
1
);
::
mlir
::
Value
input_
=
mainFunc
->
getArgument
(
input_index
++
);
params_map_
.
insert
(
std
::
pair
<
std
::
string
,
mlir
::
Value
>
(
input_var_name
,
input_
));
}
...
...
@@ -211,7 +213,6 @@ void MLIRModelGenImpl::buildOperation(
const
infrt
::
paddle
::
framework_proto
::
OpDesc
&
op_
)
{
const
std
::
string
&
op_name
=
"pd."
+
op_
.
type
();
mlir
::
Location
loc
=
mlir
::
UnknownLoc
::
get
(
context_
);
llvm
::
SmallVector
<
mlir
::
Value
,
4
>
operands
=
GetOpInputValue
(
op_
);
llvm
::
SmallVector
<
mlir
::
Type
,
4
>
resultTypes
=
GetOpOutputType
(
op_
);
llvm
::
SmallVector
<
mlir
::
NamedAttribute
,
4
>
attrs
=
GetOpAttributes
(
op_
);
...
...
@@ -227,7 +228,6 @@ llvm::SmallVector<mlir::Value, 4> MLIRModelGenImpl::GetOpInputValue(
std
::
unordered_map
<
std
::
string
,
uint8_t
>
inputs_info
=
{};
if
(
pd_dialect_inputs_info_map_
.
count
(
op_
.
type
()))
inputs_info
=
pd_dialect_inputs_info_map_
.
at
(
op_
.
type
());
for
(
int
var_idx
=
0
;
var_idx
<
op_
.
inputs_size
();
++
var_idx
)
{
auto
&
var
=
op_
.
inputs
(
var_idx
);
if
(
!
var
.
arguments
().
empty
())
{
...
...
@@ -249,10 +249,8 @@ llvm::SmallVector<mlir::Type, 4> MLIRModelGenImpl::GetOpOutputType(
// update op outputs info
for
(
int
var_idx
=
0
;
var_idx
<
op_
.
outputs_size
();
++
var_idx
)
{
auto
&
var_name
=
op_
.
outputs
(
var_idx
).
arguments
()[
0
];
if
(
!
pd_dialect_outputs_info
.
count
(
op_
.
outputs
(
var_idx
).
parameter
()))
continue
;
// update persistable tensors
for
(
int
i
=
0
;
i
<
main_block_
.
vars_size
();
i
++
)
{
auto
var_desc
=
main_block_
.
vars
(
i
);
...
...
@@ -315,7 +313,6 @@ llvm::SmallVector<mlir::NamedAttribute, 4> MLIRModelGenImpl::GetOpAttributes(
llvm
::
ArrayRef
<
mlir
::
StringAttr
>
attr_names_
=
registered_op_name_
.
getAttributeNames
();
std
::
vector
<
mlir
::
StringAttr
>
attr_names_vec_
=
attr_names_
.
vec
();
// update attrs
for
(
int
attrs_num
=
0
;
attrs_num
<
op_
.
attrs_size
();
attrs_num
++
)
{
auto
attr_name_
=
op_
.
attrs
(
attrs_num
).
name
();
...
...
@@ -351,11 +348,17 @@ llvm::SmallVector<mlir::NamedAttribute, 4> MLIRModelGenImpl::GetOpAttributes(
void
MLIRModelGenImpl
::
RegisterOpOutputVars
(
const
infrt
::
paddle
::
framework_proto
::
OpDesc
&
op_
,
mlir
::
Operation
*
mlir_op_
)
{
std
::
unordered_map
<
std
::
string
,
uint8_t
>
pd_dialect_outputs_info
=
pd_dialect_outputs_info_map_
.
at
(
op_
.
type
());
// op outputs
for
(
int
var_idx
=
0
;
var_idx
<
op_
.
outputs_size
();
++
var_idx
)
{
if
(
!
pd_dialect_outputs_info
.
count
(
op_
.
outputs
(
var_idx
).
parameter
()))
continue
;
auto
&
var_name
=
op_
.
outputs
(
var_idx
).
arguments
()[
0
];
int
index
=
pd_dialect_outputs_info
[
op_
.
outputs
(
var_idx
).
parameter
()];
// output name
auto
var_
=
mlir_op_
->
getResult
(
var_id
x
);
auto
var_
=
mlir_op_
->
getResult
(
inde
x
);
params_map_
.
insert
(
std
::
pair
<
std
::
string
,
mlir
::
Value
>
(
var_name
,
var_
));
}
}
...
...
paddle/infrt/kernel/phi/infershaped/infershape_launchers_test.cc
浏览文件 @
767647ce
...
...
@@ -54,7 +54,7 @@ TEST(ElementwiseAdd, launcher_registry) {
host_context
::
KernelRegistry
registry
;
RegisterInferShapeLaunchers
(
&
registry
);
ASSERT_GE
(
registry
.
size
(),
1UL
);
auto
creator
=
registry
.
GetKernel
(
"phi_cpu.add.
any.float32
"
);
auto
creator
=
registry
.
GetKernel
(
"phi_cpu.add.
float32.any
"
);
const
phi
::
DDim
dims
({
1
,
2
});
const
phi
::
DataType
dtype
{
phi
::
DataType
::
FLOAT32
};
...
...
paddle/infrt/tests/dialect/phi/dense_tensor.mlir
浏览文件 @
767647ce
...
...
@@ -6,7 +6,7 @@ func @sign_any_float32_execute() {
%ctx = "phi_dt.create_context.cpu" (%allocator): (!phi.allocator<CPU>) -> !phi.context<CPU>
%t = "phi_dt.create_dense_tensor.cpu.f32.nchw" (%allocator) {dims=[1:i64], lod=[1:i64]}: (!phi.allocator<CPU>) -> (!infrt.dense_tensor<CPU, FP32, NCHW>)
"phi_dt.fill_dense_tensor.f32"(%t) {value=[3.8:f32]} : (!infrt.dense_tensor<CPU, FP32, NCHW>) -> ()
%e = "phi_cpu.sign.
any.float32
"(%ctx, %t) : (!phi.context<CPU>, !infrt.dense_tensor<CPU, FP32, NCHW>) -> (!infrt.dense_tensor<CPU, FP32, NCHW>)
%e = "phi_cpu.sign.
float32.any
"(%ctx, %t) : (!phi.context<CPU>, !infrt.dense_tensor<CPU, FP32, NCHW>) -> (!infrt.dense_tensor<CPU, FP32, NCHW>)
// CHECK: dense_tensor: shape=shape[1], values=[1]
"phi_dt.print_tensor" (%e) : (!infrt.dense_tensor<CPU, FP32, NCHW>) -> ()
...
...
paddle/scripts/infrt_build.sh
浏览文件 @
767647ce
...
...
@@ -33,16 +33,17 @@ function update_pd_ops() {
rm
-rf
${
PADDLE_ROOT
}
/build
&&
mkdir
-p
${
PADDLE_ROOT
}
/build
cd
${
PADDLE_ROOT
}
/build
cmake ..
-DWITH_PYTHON
=
ON
-DWITH_GPU
=
OFF
-DPYTHON_EXECUTABLE
=
`
which python3
`
-DWITH_XBYAK
=
OFF
-DWITH_NCCL
=
OFF
-DWITH_RCCL
=
OFF
-DWITH_CRYPTO
=
OFF
make
-j8
paddle_python print_pten_kernels
make
-j8
paddle_python print_pten_kernels
kernel_signature_generator
cd
${
PADDLE_ROOT
}
/build
./paddle/phi/tools/print_pten_kernels
>
../tools/infrt/kernels.json
./paddle/fluid/pybind/kernel_signature_generator
>
../tools/infrt/kernel_signature.json
cd
python/dist/
python3
-m
pip uninstall
-y
paddlepaddle
python3
-m
pip
install
*
whl
# update pd_ops.td
cd
${
PADDLE_ROOT
}
/tools/infrt/
python3 generate_pd_op_dialect_from_paddle_op_maker.py
python3 generate_phi_kernel_dialect.py
./kernels.json
python3 generate_phi_kernel_dialect.py
}
function
init
()
{
...
...
tools/infrt/generate_phi_kernel_dialect.py
浏览文件 @
767647ce
...
...
@@ -14,9 +14,16 @@
import
json
import
sys
attr_type_converter
=
{
"i"
:
'SI32Attr'
,
"b"
:
'BoolAttr'
,
"l"
:
'SI64Attr'
}
supported_kernels
=
[
'sign'
,
'dot'
,
'digamma'
,
'conj'
,
'abs'
,
'add_raw'
]
import
os
from
get_compat_kernel_signature
import
get_compat_kernels_info
#TODO @DannyIsFunny: more attr types need to be supported.
attr_type_converter
=
{
"i"
:
'SI32Attr'
,
"b"
:
'BoolAttr'
,
"l"
:
'SI64Attr'
,
"f"
:
'F32Attr'
}
target_type_converter
=
{
"CPU"
:
"CPU"
,
"GPU"
:
"GPU"
}
layout_type_converter
=
{
...
...
@@ -39,40 +46,34 @@ precision_type_converter = {
"bool"
:
"BOOL"
}
kernel_types_info_file
=
"./kernels.json"
kernel_signature_info_file
=
"./kernel_signature.json"
def
generate_kernel_name
(
op_name
,
place_str
):
[
target_
,
layout_
,
precision_
]
=
place_str
[
1
:
-
1
].
split
(
','
)
target_
=
target_type_converter
[
target_
.
strip
()]
layout_
=
layout_type_converter
[
layout_
.
strip
()]
precision_
=
precision_type_converter
[
precision_
.
strip
()]
class_name_
=
"{}{}"
.
format
(
op_name
.
replace
(
"_"
,
""
).
title
(),
""
.
join
([
target_
.
strip
().
title
(),
precision_
.
strip
(),
layout_
.
strip
().
title
()
.
title
()
]))
alias_
=
"{}.{}"
.
format
(
op_name
,
"."
.
join
(
[
target_
.
strip
(),
layout_
.
strip
(),
precision
_
.
strip
()]))
return
alias_
[
target_
.
strip
(),
precision_
.
strip
(),
layout
_
.
strip
()]))
return
alias_
,
class_name_
def
generate_attrs_info
(
op_name
,
attrs_info
):
kernel_attrs_names
=
{
'split'
:
[
'sections'
,
'num'
,
'axis'
,
'mkldnn_data_type'
],
'sign'
:
[],
'masked_select'
:
[],
'trace'
:
[
'offset'
,
'axis1'
,
'axis2'
],
'concat'
:
[
'axis'
],
'empty'
:
[
'shape'
,
'dtype'
],
'conj'
:
[],
'norm'
:
[
'axis'
,
'epsilon'
,
'is_test'
],
'histogram'
:
[
'bins'
,
'min'
,
'max'
],
'dot'
:
[],
'scale'
:
[
'scale'
,
'bias'
,
'bias_after_scale'
],
'digamma'
:
[],
'lerp'
:
[],
'cast'
:
[
'out_dtype'
,
'in_dtype'
],
'abs'
:
[],
'add_raw'
:
[
'axis'
],
}
kernel_attrs_names
=
{}
attrs_args_
=
""
if
len
(
kernel_attrs_names
[
op_name
])
==
len
(
attrs_info
):
with
open
(
kernel_signature_info_file
)
as
f
:
kernel_attrs_names
=
json
.
load
(
f
)
kernel_attrs_names
.
update
(
get_compat_kernels_info
())
if
len
(
kernel_attrs_names
[
op_name
][
"attrs"
])
==
len
(
attrs_info
):
for
index
in
range
(
len
(
attrs_info
)):
attr_name
=
kernel_attrs_names
[
op_name
][
index
]
attr_name
=
kernel_attrs_names
[
op_name
][
"attrs"
][
index
]
attr_type
=
attr_type_converter
[
attrs_info
[
index
]]
attrs_args_
+=
'{type_}:${name_},'
.
format
(
type_
=
attr_type
,
name_
=
attr_name
)
...
...
@@ -97,7 +98,11 @@ def generate_arguments_info(op_name, input_info, attr_info):
input_args
=
generate_inputs_info
(
input_info
)
attr_args
=
generate_attrs_info
(
op_name
,
attr_info
)
context_args
=
"Context:$dev_ctx"
argument_
=
"{},{},{}"
.
format
(
context_args
,
input_args
,
attr_args
)
argument_list
=
[
context_args
]
+
input_args
.
split
(
","
)
+
attr_args
.
split
(
","
)
while
(
""
in
argument_list
):
argument_list
.
remove
(
""
)
argument_
=
","
.
join
(
argument_list
)
return
((
"let arguments = (ins {});"
.
format
(
argument_
.
strip
(
","
))))
...
...
@@ -116,6 +121,10 @@ def generate_results_info(output_info):
def
generate_supported_kernel_list
(
load_dict
):
supported_kernels_list_
=
[]
kernel_attrs_names
=
{}
with
open
(
kernel_signature_info_file
)
as
f
:
kernel_attrs_names
=
json
.
load
(
f
)
kernel_attrs_names
.
update
(
get_compat_kernels_info
())
for
op_name
in
load_dict
:
kernel_list
=
load_dict
[
op_name
]
for
kernel_info
in
kernel_list
:
...
...
@@ -125,13 +134,10 @@ def generate_supported_kernel_list(load_dict):
for
attribute
in
attributes
:
if
attribute
not
in
attr_type_converter
:
flag
=
False
if
flag
:
if
flag
and
op_name
in
kernel_attrs_names
:
supported_kernels_list_
.
append
(
op_name
)
alias_
=
generate_kernel_dialect
(
op_name
,
kernel_alias_
,
kernel_info
[
kernel_alias_
])
supported_kernels_list_
=
list
(
set
(
supported_kernels_list_
))
print
(
supported_kernels_list_
)
return
supported_kernels_list_
def
scan_kernel_info
(
load_dict
):
...
...
@@ -156,16 +162,14 @@ def scan_kernel_info(load_dict):
def
generate_cpu_kernel_dialect
(
op_name
,
kernel_alias_
,
kernel_info
):
alias
=
generate_kernel_name
(
op_name
,
kernel_alias_
)
alias
,
class_name
=
generate_kernel_name
(
op_name
,
kernel_alias_
)
summary
=
'let summary = "{name}";'
.
format
(
name
=
alias
)
dialect_name
=
alias
.
split
(
"."
)
dialect_name
=
dialect_name
[
0
]
+
"."
+
dialect_name
[
2
]
+
"."
+
dialect_name
[
3
]
header
=
'def {kernel_name} : PDTCPU_Kernel<"{name}",[NoSideEffect]> {left_brace}'
.
format
(
kernel_name
=
alias
.
replace
(
"."
,
""
),
name
=
dialect_name
.
lower
(),
left_brace
=
"{"
)
kernel_name
=
class_name
,
name
=
dialect_name
.
lower
(),
left_brace
=
"{"
)
inputs_
=
kernel_info
[
"input"
]
attributes
=
kernel_info
[
"attribute"
]
...
...
@@ -185,16 +189,14 @@ def generate_cpu_kernel_dialect(op_name, kernel_alias_, kernel_info):
def
generate_gpu_kernel_dialect
(
op_name
,
kernel_alias_
,
kernel_info
):
alias
=
generate_kernel_name
(
op_name
,
kernel_alias_
)
alias
,
class_name
=
generate_kernel_name
(
op_name
,
kernel_alias_
)
summary
=
'let summary = "{name}";'
.
format
(
name
=
alias
)
dialect_name
=
alias
.
split
(
"."
)
dialect_name
=
dialect_name
[
0
]
+
"."
+
dialect_name
[
2
]
+
"."
+
dialect_name
[
3
]
header
=
'def {kernel_name} : PDTGPU_Kernel<"{name}",[NoSideEffect]> {left_brace}'
.
format
(
kernel_name
=
alias
.
replace
(
"."
,
""
),
name
=
dialect_name
.
lower
(),
left_brace
=
"{"
)
kernel_name
=
class_name
,
name
=
dialect_name
.
lower
(),
left_brace
=
"{"
)
inputs_
=
kernel_info
[
"input"
]
attributes
=
kernel_info
[
"attribute"
]
arguments
=
generate_arguments_info
(
op_name
,
inputs_
,
attributes
)
...
...
@@ -236,14 +238,17 @@ def get_kernel_target(kernel_alias_):
return
target
[
0
]
def
main
(
path_
):
with
open
(
path_
,
"r"
)
as
f
:
def
main
():
with
open
(
kernel_types_info_file
,
"r"
)
as
f
:
load_dict
=
json
.
load
(
f
)
head
=
generate_dialect_head
()
cpu_registry_
=
""
gpu_registry_
=
""
supported_kernels
=
generate_supported_kernel_list
(
load_dict
)
print
(
"Supported kernels:"
)
print
(
supported_kernels
)
for
op_name
in
load_dict
:
if
op_name
not
in
supported_kernels
:
continue
...
...
@@ -273,5 +278,12 @@ def main(path_):
if
__name__
==
'__main__'
:
path
=
sys
.
argv
[
1
]
main
(
path
)
if
not
os
.
path
.
exists
(
kernel_types_info_file
):
print
(
"Error: '{file_name}' not exist!"
.
format
(
file_name
=
kernel_types_info_file
))
if
not
os
.
path
.
exists
(
kernel_signature_info_file
):
print
(
"Error: '{file_name}' not exist!"
.
format
(
file_name
=
kernel_signature_info_file
))
if
os
.
path
.
exists
(
kernel_types_info_file
)
and
os
.
path
.
exists
(
kernel_signature_info_file
):
main
()
tools/infrt/get_compat_kernel_signature.py
0 → 100644
浏览文件 @
767647ce
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
os
import
re
import
json
def
parse_compat_registry
(
kernel_info
):
name
,
inputs_str
,
attrs_str
,
outputs_str
=
kernel_info
.
split
(
",{"
)
kernel_info
=
{}
kernel_info
[
"inputs"
]
=
inputs_str
[:
-
1
].
split
(
","
)
kernel_info
[
"attrs"
]
=
attrs_str
[:
-
1
].
split
(
","
)
kernel_info
[
"outputs"
]
=
outputs_str
[:
-
1
].
split
(
","
)
return
name
,
kernel_info
def
remove_grad_registry
(
kernels_registry
):
clean_kernel_registry
=
{}
for
registry
in
kernels_registry
:
if
(
not
"_grad"
in
registry
):
clean_kernel_registry
[
registry
]
=
kernels_registry
[
registry
]
return
clean_kernel_registry
def
get_compat_kernels_info
():
kernels_info
=
{}
compat_files
=
os
.
listdir
(
"../../paddle/phi/ops/compat"
)
for
file_
in
compat_files
:
if
not
".cc"
in
file_
:
compat_files
.
remove
(
file_
)
for
file_
in
compat_files
:
with
open
(
"../../paddle/phi/ops/compat/"
+
file_
)
as
in_file
:
txt
=
in_file
.
readlines
()
content
=
""
registry
=
False
for
line
in
txt
:
if
(
"KernelSignature("
in
line
):
content
=
""
registry
=
True
if
(
registry
):
content
+=
line
if
(
registry
and
";"
in
line
):
data
=
content
.
replace
(
"
\n
"
,
""
).
replace
(
" "
,
""
).
strip
(
"return"
).
strip
(
"KernelSignature("
).
strip
(
"\);"
).
replace
(
"
\"
"
,
""
)
registry
=
False
name
,
registry_info
=
parse_compat_registry
(
data
)
if
name
in
kernels_info
:
cur_reg
=
kernels_info
[
name
]
kernels_info
[
name
][
"inputs"
]
=
list
(
set
(
registry_info
[
"inputs"
]
+
kernels_info
[
name
][
"inputs"
]))
kernels_info
[
name
][
"attrs"
]
=
list
(
set
(
registry_info
[
"attrs"
]
+
kernels_info
[
name
][
"attrs"
]))
kernels_info
[
name
][
"outputs"
]
=
list
(
set
(
registry_info
[
"outputs"
]
+
kernels_info
[
name
][
"outputs"
]))
else
:
kernels_info
[
name
]
=
registry_info
compat_registry_
=
remove_grad_registry
(
kernels_info
)
return
compat_registry_
tools/infrt/get_phi_kernel_info.py
浏览文件 @
767647ce
...
...
@@ -219,8 +219,8 @@ def gen_register_info(resources: List[List[str]]):
for
ir_dtype
,
origin_dtype
in
zip
(
ir_dtypes
,
origin_dtypes
):
kernel_func
=
gen_kernel_func
(
update_item
[
3
],
ctx_name
,
origin_dtype
)
ir_name
=
'phi_cpu.'
+
update_item
[
0
].
lower
(
)
+
'.'
+
update_item
[
2
].
lower
()
+
'.'
+
ir_dtype
ir_name
=
'phi_cpu.'
+
update_item
[
0
].
lower
(
)
+
'.'
+
ir_dtype
+
'.'
+
update_item
[
2
].
lower
()
res
+=
f
"""
registry->AddKernel("
{
ir_name
}
","""
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录