Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
1840349a
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2299
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
1840349a
编写于
3月 30, 2022
作者:
H
huzhiqiang
提交者:
GitHub
3月 29, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[Infrt] add skip method for inferShape codegen (#41014)
上级
cc52501e
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
56 addition
and
2 deletion
+56
-2
tools/infrt/generate_phi_kernel_dialect.py
tools/infrt/generate_phi_kernel_dialect.py
+28
-0
tools/infrt/get_phi_kernel_info.py
tools/infrt/get_phi_kernel_info.py
+24
-2
tools/infrt/skipped_phi_api.json
tools/infrt/skipped_phi_api.json
+4
-0
未找到文件。
tools/infrt/generate_phi_kernel_dialect.py
浏览文件 @
1840349a
...
@@ -13,6 +13,7 @@
...
@@ -13,6 +13,7 @@
# limitations under the License.
# limitations under the License.
import
json
import
json
import
yaml
import
sys
import
sys
import
os
import
os
from
get_compat_kernel_signature
import
get_compat_kernels_info
from
get_compat_kernel_signature
import
get_compat_kernels_info
...
@@ -52,6 +53,28 @@ precision_type_converter = {
...
@@ -52,6 +53,28 @@ precision_type_converter = {
kernel_types_info_file
=
"./kernels.json"
kernel_types_info_file
=
"./kernels.json"
kernel_signature_info_file
=
"./kernel_signature.json"
kernel_signature_info_file
=
"./kernel_signature.json"
skipped_phi_api_list_file
=
"./skipped_phi_api.json"
def
get_skipped_kernel_list
():
skiped_kernel_list
=
[]
with
open
(
skipped_phi_api_list_file
,
'r'
)
as
f
:
skiped_api_list
=
json
.
load
(
f
)
infer_meta_data
=
get_api_yaml_info
(
"../../"
)
for
api
in
infer_meta_data
:
if
"kernel"
not
in
api
or
"infer_meta"
not
in
api
:
continue
if
api
[
"api"
]
in
skiped_api_list
[
"phi_apis"
]:
skiped_kernel_list
.
append
(
api
[
"kernel"
][
"func"
])
skiped_kernel_list
+=
skiped_api_list
[
"phi_kernels"
]
return
skiped_kernel_list
def
get_api_yaml_info
(
file_path
):
f
=
open
(
file_path
+
"/python/paddle/utils/code_gen/api.yaml"
,
"r"
)
cont
=
f
.
read
()
return
yaml
.
load
(
cont
,
Loader
=
yaml
.
FullLoader
)
def
generate_kernel_name
(
op_name
,
place_str
):
def
generate_kernel_name
(
op_name
,
place_str
):
[
target_
,
layout_
,
precision_
]
=
place_str
[
1
:
-
1
].
split
(
','
)
[
target_
,
layout_
,
precision_
]
=
place_str
[
1
:
-
1
].
split
(
','
)
...
@@ -140,6 +163,10 @@ def generate_supported_kernel_list(load_dict):
...
@@ -140,6 +163,10 @@ def generate_supported_kernel_list(load_dict):
if
flag
and
op_name
in
kernel_attrs_names
:
if
flag
and
op_name
in
kernel_attrs_names
:
supported_kernels_list_
.
append
(
op_name
)
supported_kernels_list_
.
append
(
op_name
)
supported_kernels_list_
=
list
(
set
(
supported_kernels_list_
))
supported_kernels_list_
=
list
(
set
(
supported_kernels_list_
))
skipped_kernel_list
=
get_skipped_kernel_list
()
for
skipped_kernel
in
skipped_kernel_list
:
if
skipped_kernel
in
skipped_kernel_list
:
supported_kernels_list_
.
remove
(
skipped_kernel
)
return
supported_kernels_list_
return
supported_kernels_list_
...
@@ -250,6 +277,7 @@ def main():
...
@@ -250,6 +277,7 @@ def main():
cpu_registry_
=
""
cpu_registry_
=
""
gpu_registry_
=
""
gpu_registry_
=
""
supported_kernels
=
generate_supported_kernel_list
(
load_dict
)
supported_kernels
=
generate_supported_kernel_list
(
load_dict
)
print
(
"Supported kernels:"
)
print
(
"Supported kernels:"
)
print
(
supported_kernels
)
print
(
supported_kernels
)
for
op_name
in
load_dict
:
for
op_name
in
load_dict
:
...
...
tools/infrt/get_phi_kernel_info.py
浏览文件 @
1840349a
...
@@ -19,6 +19,23 @@ import json
...
@@ -19,6 +19,23 @@ import json
import
yaml
import
yaml
from
typing
import
List
,
Dict
,
Any
from
typing
import
List
,
Dict
,
Any
skipped_phi_api_list_file
=
"/tools/infrt/skipped_phi_api.json"
api_yaml_file
=
"/python/paddle/utils/code_gen/api.yaml"
def
get_skipped_kernel_list
():
skiped_kernel_list
=
[]
with
open
(
skipped_phi_api_list_file
,
'r'
)
as
f
:
skiped_api_list
=
json
.
load
(
f
)
infer_meta_data
=
get_api_yaml_info
(
api_yaml_file
)
for
api
in
infer_meta_data
:
if
"kernel"
not
in
api
or
"infer_meta"
not
in
api
:
continue
if
api
[
"api"
]
in
skiped_api_list
[
"phi_apis"
]:
skiped_kernel_list
.
append
(
api
[
"kernel"
][
"func"
])
skiped_kernel_list
+=
skiped_api_list
[
"phi_kernels"
]
return
skiped_kernel_list
def
parse_args
():
def
parse_args
():
parser
=
argparse
.
ArgumentParser
(
"gather phi kernel and infermate info"
)
parser
=
argparse
.
ArgumentParser
(
"gather phi kernel and infermate info"
)
...
@@ -50,7 +67,7 @@ def parse_args():
...
@@ -50,7 +67,7 @@ def parse_args():
def
get_api_yaml_info
(
file_path
):
def
get_api_yaml_info
(
file_path
):
f
=
open
(
file_path
+
"/python/paddle/utils/code_gen/api.yaml"
,
"r"
)
f
=
open
(
file_path
,
"r"
)
cont
=
f
.
read
()
cont
=
f
.
read
()
return
yaml
.
load
(
cont
,
Loader
=
yaml
.
FullLoader
)
return
yaml
.
load
(
cont
,
Loader
=
yaml
.
FullLoader
)
...
@@ -259,8 +276,11 @@ def gen_register_code_info(item: List[str], attr_data: Dict[str, List[str]]):
...
@@ -259,8 +276,11 @@ def gen_register_code_info(item: List[str], attr_data: Dict[str, List[str]]):
# TODO(wilber): handle the unknown inferShape func.
# TODO(wilber): handle the unknown inferShape func.
return
""
return
""
skipped_kernel_list
=
get_skipped_kernel_list
()
for
ir_dtype
,
origin_dtype
in
zip
(
ir_dtypes
,
origin_dtypes
):
for
ir_dtype
,
origin_dtype
in
zip
(
ir_dtypes
,
origin_dtypes
):
kernel_func
=
gen_kernel_func
(
item
[
3
],
ctx_name
,
origin_dtype
)
kernel_func
=
gen_kernel_func
(
item
[
3
],
ctx_name
,
origin_dtype
)
if
item
[
0
].
lower
()
in
skipped_kernel_list
:
continue
ir_name
=
ir_ctx_name
+
'.'
+
item
[
0
].
lower
(
ir_name
=
ir_ctx_name
+
'.'
+
item
[
0
].
lower
(
)
+
'.'
+
ir_dtype
+
'.'
+
item
[
2
].
lower
()
)
+
'.'
+
ir_dtype
+
'.'
+
item
[
2
].
lower
()
if
ir_name
in
attr_data
.
keys
()
and
attr_data
[
ir_name
]
is
not
None
:
if
ir_name
in
attr_data
.
keys
()
and
attr_data
[
ir_name
]
is
not
None
:
...
@@ -342,7 +362,9 @@ def gen_phi_kernel_register_code(resources: List[List[str]],
...
@@ -342,7 +362,9 @@ def gen_phi_kernel_register_code(resources: List[List[str]],
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
args
=
parse_args
()
args
=
parse_args
()
infer_meta_data
=
get_api_yaml_info
(
args
.
paddle_root_path
)
skipped_phi_api_list_file
=
args
.
paddle_root_path
+
skipped_phi_api_list_file
api_yaml_file
=
args
.
paddle_root_path
+
api_yaml_file
infer_meta_data
=
get_api_yaml_info
(
api_yaml_file
)
kernel_data
=
get_kernel_info
(
args
.
kernel_info_file
)
kernel_data
=
get_kernel_info
(
args
.
kernel_info_file
)
info_meta_wrap_data
=
get_infermeta_info
(
args
.
infermeta_wrap_file
)
info_meta_wrap_data
=
get_infermeta_info
(
args
.
infermeta_wrap_file
)
attr_data
=
get_attr_info
(
args
.
attr_info_file
)
attr_data
=
get_attr_info
(
args
.
attr_info_file
)
...
...
tools/infrt/skipped_phi_api.json
0 → 100644
浏览文件 @
1840349a
{
"phi_apis"
:[
"conj"
],
"phi_kernels"
:[
"equal_all"
]
}
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录