Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MindSpore
mindinsight
提交
25d8d1e9
M
mindinsight
项目概览
MindSpore
/
mindinsight
通知
7
Star
3
Fork
2
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindinsight
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
25d8d1e9
编写于
6月 27, 2020
作者:
M
mindspore-ci-bot
提交者:
Gitee
6月 27, 2020
浏览文件
操作
浏览文件
下载
差异文件
!352 The conversion report is adjusted to make the the report more reasonable and accurate.
Merge pull request !352 from ggpolar/br_wzk_dev
上级
ee821301
205d43c6
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
43 addition
and
11 deletion
+43
-11
mindinsight/mindconverter/ast_edits.py
mindinsight/mindconverter/ast_edits.py
+40
-7
mindinsight/mindconverter/config.py
mindinsight/mindconverter/config.py
+3
-4
未找到文件。
mindinsight/mindconverter/ast_edits.py
浏览文件 @
25d8d1e9
...
...
@@ -27,11 +27,13 @@ from mindinsight.mindconverter.config import ALL_MAPPING
from
mindinsight.mindconverter.config
import
NN_LIST
from
mindinsight.mindconverter.config
import
ALL_TORCH_APIS
from
mindinsight.mindconverter.config
import
ALL_2P_LIST
from
mindinsight.mindconverter.config
import
get_corresponding_ms_name
from
mindinsight.mindconverter.config
import
get_prompt_info
from
mindinsight.mindconverter.common.log
import
logger
from
mindinsight.mindconverter.common.exceptions
import
NodeTypeNotSupport
from
mindinsight.mindconverter.forward_call
import
ForwardCall
LOG_FMT_INSERT
=
"[Insert] '%s' is inserted to the converted file."
LOG_FMT_CONVERT
=
"[Convert] '%s' is converted to '%s'."
LOG_FMT_CONVERT_WITH_TIPS
=
"[Convert] '%s' is converted to '%s'. %s"
LOG_FMT_NOT_CONVERT
=
"[UnConvert] '%s' didn't convert. %s"
...
...
@@ -54,16 +56,22 @@ class _ConvertReport:
def
__init__
(
self
,
is_stub
=
False
):
self
.
_is_stub
=
is_stub
self
.
_max_line
=
0
self
.
_log
=
[]
# report log, type is (severity, line, col, msg)
self
.
_log_head
=
[]
self
.
_log_body
=
[]
# report log, type is (severity, line, col, msg)
def
_add_log
(
self
,
severity
,
line
,
col
,
msg
):
"""Add log."""
if
self
.
_is_stub
:
return
if
line
is
None
and
col
is
None
:
self
.
_log_head
.
append
(
msg
)
return
if
isinstance
(
line
,
int
)
and
isinstance
(
col
,
int
):
self
.
_log
.
append
((
severity
,
line
,
col
,
msg
))
self
.
_log
_body
.
append
((
severity
,
line
,
col
,
msg
))
if
self
.
_max_line
<
line
:
self
.
_max_line
=
line
else
:
raise
TypeError
(
'The parameter type is incorrect.'
)
def
info
(
self
,
line
,
col
,
msg
):
"""Interface to add infer log"""
...
...
@@ -73,14 +81,24 @@ class _ConvertReport:
"""Interface to add warning log"""
self
.
_add_log
(
logging
.
WARNING
,
line
,
col
,
msg
)
def
header_msg
(
self
,
msg
):
"""Interface to add header message log"""
self
.
_add_log
(
logging
.
INFO
,
None
,
None
,
msg
)
def
get_logs
(
self
):
"""Get convert logs"""
logs
=
[]
logs
.
extend
(
self
.
_log_head
)
# sort rule: line * self._max_line + col
self
.
_log
.
sort
(
key
=
lambda
log
:
log
[
1
]
*
self
.
_max_line
+
log
[
2
])
for
log_info
in
self
.
_log
:
self
.
_log
_body
.
sort
(
key
=
lambda
log
:
log
[
1
]
*
self
.
_max_line
+
log
[
2
])
for
log_info
in
self
.
_log
_body
:
log_info
=
"line %d:%d: %s"
%
(
log_info
[
1
],
log_info
[
2
],
log_info
[
3
])
logs
.
append
(
log_info
)
if
logs
:
# Deduplication for logs
if
logs
[
-
1
]
!=
log_info
:
logs
.
append
(
log_info
)
else
:
logs
.
append
(
log_info
)
return
logs
...
...
@@ -262,7 +280,8 @@ class AstEditVisitor(ast.NodeVisitor):
new_func_name
=
'construct'
if
func_ast_node
.
name
==
old_func_name
:
func_ast_node
.
name
=
new_func_name
self
.
_process_log
.
info
(
func_ast_node
.
lineno
,
func_ast_node
.
col_offset
,
real_line_number
=
self
.
_get_real_line_number
(
func_ast_node
)
self
.
_process_log
.
info
(
real_line_number
,
func_ast_node
.
col_offset
,
LOG_FMT_CONVERT
%
(
old_func_name
,
new_func_name
))
def
_convert_api
(
self
):
...
...
@@ -299,6 +318,15 @@ class AstEditVisitor(ast.NodeVisitor):
source_code
=
pasta
.
dump
(
node
)
return
source_code
[
pos
:]
@
staticmethod
def
_get_real_line_number
(
node
):
"""Get the real line number of the node."""
try
:
line_number
=
node
.
lineno
+
len
(
node
.
decorator_list
)
except
AttributeError
:
line_number
=
node
.
lineno
return
line_number
def
_replace_external_reference
(
self
):
"""
Replace external reference statements.
...
...
@@ -349,6 +377,7 @@ class AstEditVisitor(ast.NodeVisitor):
insert_pos
+=
1
else
:
try
:
# insert pos after the last one, if last one name is replaced.
replaced_with_node
=
names_replaced_with
[
src_name
]
insert_pos
=
self
.
_tree
.
body
.
index
(
replaced_with_node
)
+
1
except
ValueError
:
...
...
@@ -359,6 +388,8 @@ class AstEditVisitor(ast.NodeVisitor):
for
insert_pos
,
new_node
in
new_import_node
.
items
():
# Insert the node into the module
self
.
_tree
.
body
.
insert
(
insert_pos
+
insert_cnt
,
new_node
)
new_code
=
self
.
_dump_without_prefix
(
new_node
)
self
.
_process_log
.
header_msg
(
LOG_FMT_INSERT
%
new_code
.
strip
())
insert_cnt
+=
1
@
staticmethod
...
...
@@ -445,8 +476,10 @@ class AstEditVisitor(ast.NodeVisitor):
is_include_sub_call
=
self
.
_is_include_sub_call
(
call_func_node
)
if
is_include_sub_call
:
# x.y().z splits to ['x.y()', 'z']
name_attributes
=
call_name
.
rsplit
(
'.'
,
1
)
else
:
# x.y.z splits to ['x', 'y', 'z']
name_attributes
=
call_name
.
split
(
'.'
)
# rewritten external module name
...
...
@@ -665,7 +698,7 @@ class AstEditVisitor(ast.NodeVisitor):
try
:
new_node
=
pasta
.
parse
(
new_code
).
body
[
0
].
value
# find the first call name
new_api_name
=
new_code
[:
new_code
.
find
(
'('
)]
new_api_name
=
get_corresponding_ms_name
(
matched_api_name
)
except
AttributeError
:
new_node
=
pasta
.
parse
(
new_code
).
body
[
0
]
new_api_name
=
new_code
...
...
mindinsight/mindconverter/config.py
浏览文件 @
25d8d1e9
...
...
@@ -32,7 +32,7 @@ FUNC_MODULE = 'mindinsight.mindconverter.funcs'
class
APIPt
:
"""Base API for args parse, and API for one frame."""
def
__init__
(
self
,
name
:
str
,
params
:
OrderedD
ict
):
def
__init__
(
self
,
name
:
str
,
params
:
d
ict
):
self
.
name
=
name
self
.
params
=
OrderedDict
()
...
...
@@ -45,7 +45,7 @@ class APIPt:
Trans value to str.
Args:
value (Union[str,Number,int]):
Each value for params of OrderedDic
t.
value (Union[str,Number,int]):
The value to conver
t.
Returns:
str, str type of value.
...
...
@@ -118,7 +118,7 @@ class APIPt:
class
APIMs
(
APIPt
):
"""API for MindSpore"""
def
__init__
(
self
,
name
:
str
,
params
:
OrderedD
ict
,
p_attrs
=
None
):
def
__init__
(
self
,
name
:
str
,
params
:
d
ict
,
p_attrs
=
None
):
self
.
is_primitive
=
name
.
startswith
(
'P.'
)
if
self
.
is_primitive
:
self
.
p_attrs
=
p_attrs
if
p_attrs
else
set
()
...
...
@@ -450,7 +450,6 @@ UNSUPPORTED_WARN_INFOS = {
"F.one_hot"
:
"Maybe could convert to mindspore.ops.operations.OneHot."
,
"torch.bmm"
:
"Maybe could convert to mindspore.ops.operations.BatchMatMul."
,
"torch.cumsum"
:
"Maybe could convert to mindspore.ops.operations.CumSum."
,
"F.relu"
:
"Maybe could convert to mindspore.ops.operations.ReLU."
,
"F.pad"
:
"Maybe could convert to mindspore.ops.operations.Pad."
,
"F.softmax"
:
"Maybe could convert to mindspore.ops.operations.Softmax."
,
"torch.clamp"
:
"Maybe could convert to mindspore.ops.composite.clip_by_value."
,
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录