Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
a2957388
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
a2957388
编写于
5月 25, 2020
作者:
W
wenchunjiang
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add param_type in json
上级
57c1da12
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
67 addition
and
27 deletion
+67
-27
mindspore/_extends/parallel_compile/tbe_compiler/common.py
mindspore/_extends/parallel_compile/tbe_compiler/common.py
+65
-27
mindspore/ccsrc/kernel/tbe/tbe_kernel_build.cc
mindspore/ccsrc/kernel/tbe/tbe_kernel_build.cc
+2
-0
未找到文件。
mindspore/_extends/parallel_compile/tbe_compiler/common.py
浏览文件 @
a2957388
...
...
@@ -15,6 +15,12 @@
"""tbe common"""
import
json
import
os
from
attrdict
import
AttrDict
class
ParamType
(
AttrDict
):
Required
=
"required"
Dynamic
=
"dynamic"
Optional
=
"optional"
class
TBEException
(
Exception
):
...
...
@@ -80,7 +86,62 @@ def _check_arg_info(item):
raise
ValueError
(
"Json string Errors, key:ori_format not found."
)
if
'dtype'
not
in
item
or
not
item
[
'dtype'
]:
raise
ValueError
(
"Json string Errors, key:dtype not found."
)
if
'param_type'
not
in
item
or
not
item
[
'param_type'
]:
raise
ValueError
(
"Json string Errors, key:param_type not found."
)
def
get_input_output
(
io_info
,
args
):
"""
Parse args.
Args:
io_info (dict): input or output info dict.
args (list): the arguments list.
Raises:
Exception: If specific keyword is not found.
"""
for
item
in
io_info
:
arg
=
[]
for
info
in
item
:
if
'valid'
not
in
info
:
raise
ValueError
(
"Json string Errors, key:valid not found."
)
if
info
[
'valid'
]:
_check_arg_info
(
info
)
del
info
[
'valid'
]
del
info
[
'name'
]
if
len
(
item
)
>
1
:
arg
.
append
(
info
)
else
:
if
info
[
'param_type'
]
==
ParamType
.
Dynamic
:
arg
.
append
(
info
)
args
.
append
(
arg
)
else
:
args
.
append
(
info
)
else
:
if
len
(
item
)
>
1
:
arg
.
append
(
None
)
else
:
args
.
append
(
None
)
if
len
(
item
)
>
1
:
args
.
append
(
arg
)
def
get_attr
(
attr_info
,
args
):
"""
Parse args.
Args:
attr_info (dict): input or output info dict.
args (list): the arguments list.
Raises:
Exception: If specific keyword is not found.
"""
for
item
in
attr_info
:
if
item
[
"valid"
]:
if
'value'
not
in
item
:
raise
ValueError
(
"Json string Errors, attr key:value not found."
)
if
item
[
"name"
]
!=
"isRef"
:
args
.
append
(
item
[
'value'
])
def
get_args
(
op_info
,
arg_type
):
"""
...
...
@@ -98,35 +159,12 @@ def get_args(op_info, arg_type):
args
=
[]
if
not
op_info
[
arg_type
]:
return
args
if
arg_type
in
[
'inputs'
,
'outputs'
]:
for
item
in
op_info
[
arg_type
]:
arg
=
[]
for
info
in
item
:
if
'valid'
not
in
info
:
raise
ValueError
(
"Json string Errors, key:valid not found."
)
if
info
[
'valid'
]:
_check_arg_info
(
info
)
del
info
[
'valid'
]
del
info
[
'name'
]
if
len
(
item
)
>
1
:
arg
.
append
(
info
)
else
:
args
.
append
(
info
)
else
:
if
len
(
item
)
>
1
:
arg
.
append
(
None
)
else
:
args
.
append
(
None
)
if
len
(
item
)
>
1
:
args
.
append
(
arg
)
arg_info
=
op_info
[
arg_type
]
if
arg_type
in
[
'inputs'
,
'outputs'
]:
get_input_output
(
arg_info
,
args
)
elif
arg_type
==
'attrs'
:
for
item
in
op_info
[
arg_type
]:
if
item
[
"valid"
]:
if
'value'
not
in
item
:
raise
ValueError
(
"Json string Errors, attr key:value not found."
)
if
item
[
"name"
]
!=
"isRef"
:
args
.
append
(
item
[
'value'
])
get_attr
(
arg_info
,
args
)
return
args
...
...
mindspore/ccsrc/kernel/tbe/tbe_kernel_build.cc
浏览文件 @
a2957388
...
...
@@ -147,6 +147,7 @@ bool TbeKernelJsonCreator::GenInputDescJson(const shared_ptr<AnfNode> &anf_node,
input_desc_json
[
"format"
]
=
format
;
}
input_desc_json
[
"valid"
]
=
value
;
input_desc_json
[
"param_type"
]
=
input_ptr
->
param_type
();
input_list
->
emplace_back
(
input_desc_json
);
}
return
true
;
...
...
@@ -356,6 +357,7 @@ void TbeKernelJsonCreator::GenOutputList(const shared_ptr<AnfNode> &anf_node, co
output_obj
[
"ori_format"
]
=
kOpFormat_NCHW
;
output_obj
[
"name"
]
=
output_ptr
->
name
();
output_obj
[
"valid"
]
=
true
;
output_obj
[
"param_type"
]
=
output_ptr
->
param_type
();
output_list
->
emplace_back
(
output_obj
);
(
*
output_idx
)
++
;
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录