Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MindSpore
mindinsight
提交
ebe61dc6
M
mindinsight
项目概览
MindSpore
/
mindinsight
通知
7
Star
3
Fork
2
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindinsight
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
ebe61dc6
编写于
6月 28, 2020
作者:
M
mindspore-ci-bot
提交者:
Gitee
6月 28, 2020
浏览文件
操作
浏览文件
下载
差异文件
!369 The modified information should reflect the modification difference.
Merge pull request !369 from ggpolar/r0.5
上级
f288237d
b4709a0b
变更
1
隐藏空白更改
内联
并排
Showing
1 changed file
with
50 addition
and
21 deletion
+50
-21
mindinsight/mindconverter/ast_edits.py
mindinsight/mindconverter/ast_edits.py
+50
-21
未找到文件。
mindinsight/mindconverter/ast_edits.py
浏览文件 @
ebe61dc6
...
...
@@ -27,7 +27,6 @@ from mindinsight.mindconverter.config import ALL_MAPPING
from
mindinsight.mindconverter.config
import
NN_LIST
from
mindinsight.mindconverter.config
import
ALL_TORCH_APIS
from
mindinsight.mindconverter.config
import
ALL_2P_LIST
from
mindinsight.mindconverter.config
import
get_corresponding_ms_name
from
mindinsight.mindconverter.config
import
get_prompt_info
from
mindinsight.mindconverter.common.log
import
logger
from
mindinsight.mindconverter.common.exceptions
import
NodeTypeNotSupport
...
...
@@ -671,6 +670,55 @@ class AstEditVisitor(ast.NodeVisitor):
return
new_code
@
staticmethod
def
_get_detail_prompt_msg
(
old_node
,
new_node
):
"""Get detail converted prompt information."""
msg
=
None
if
isinstance
(
old_node
,
ast
.
Call
)
and
isinstance
(
new_node
,
ast
.
Call
):
old_api_name
=
pasta
.
dump
(
old_node
.
func
)
new_api_name
=
pasta
.
dump
(
new_node
.
func
)
if
new_api_name
==
old_api_name
:
old_parameter_num
=
len
(
old_node
.
args
)
+
len
(
old_node
.
keywords
)
new_parameter_num
=
len
(
new_node
.
args
)
+
len
(
new_node
.
keywords
)
if
old_parameter_num
>
1
:
msg
=
'Parameters are converted.'
else
:
if
old_parameter_num
==
0
and
new_parameter_num
==
0
:
msg
=
'The API name is converted to mindspore API'
else
:
msg
=
'Parameter is converted.'
return
msg
def
_convert_call
(
self
,
node
,
matched_api_name
):
""""Convert the call node."""
new_node
=
None
code
=
pasta
.
dump
(
node
)
api_name
=
pasta
.
dump
(
node
.
func
)
warning_info
=
get_prompt_info
(
matched_api_name
)
if
warning_info
is
None
:
warning_info
=
''
if
matched_api_name
in
ALL_MAPPING
:
logger
.
info
(
"Line %3d start converting API: %s"
,
node
.
lineno
,
api_name
)
new_code
=
self
.
mapping_api
(
node
)
if
new_code
!=
code
:
try
:
new_node
=
pasta
.
parse
(
new_code
).
body
[
0
].
value
# find the first call name
new_api_name
=
new_code
[:
new_code
.
find
(
'('
)]
detail_msg
=
self
.
_get_detail_prompt_msg
(
node
,
new_node
)
if
detail_msg
:
warning_info
=
detail_msg
+
' '
+
warning_info
except
AttributeError
:
new_node
=
pasta
.
parse
(
new_code
).
body
[
0
]
new_api_name
=
new_code
self
.
_process_log
.
info
(
node
.
lineno
,
node
.
col_offset
,
LOG_FMT_CONVERT_WITH_TIPS
%
(
api_name
,
new_api_name
,
warning_info
))
else
:
logger
.
warning
(
"Line %3d: found unsupported API: %s%s"
,
node
.
lineno
,
api_name
,
warning_info
)
self
.
_process_log
.
warning
(
node
.
lineno
,
node
.
col_offset
,
LOG_FMT_NOT_CONVERT
%
(
api_name
,
warning_info
))
return
new_node
def
visit_Call
(
self
,
node
):
"""Callback function when visit AST tree"""
code
=
pasta
.
dump
(
node
)
...
...
@@ -688,26 +736,7 @@ class AstEditVisitor(ast.NodeVisitor):
new_code
=
code
matched_api_name
,
match_case
=
self
.
match_api
(
node
.
func
,
self
.
_is_forward_function
)
if
match_case
in
[
ApiMatchingEnum
.
API_INFER
,
ApiMatchingEnum
.
API_MATCHED
]:
warning_info
=
get_prompt_info
(
matched_api_name
)
if
warning_info
is
None
:
warning_info
=
''
if
matched_api_name
in
ALL_MAPPING
:
logger
.
info
(
"Line %3d start converting API: %s"
,
node
.
lineno
,
api_name
)
new_code
=
self
.
mapping_api
(
node
)
if
new_code
!=
code
:
try
:
new_node
=
pasta
.
parse
(
new_code
).
body
[
0
].
value
# find the first call name
new_api_name
=
get_corresponding_ms_name
(
matched_api_name
)
except
AttributeError
:
new_node
=
pasta
.
parse
(
new_code
).
body
[
0
]
new_api_name
=
new_code
self
.
_process_log
.
info
(
node
.
lineno
,
node
.
col_offset
,
LOG_FMT_CONVERT_WITH_TIPS
%
(
api_name
,
new_api_name
,
warning_info
))
else
:
logger
.
warning
(
"Line %3d: found unsupported API: %s%s"
,
node
.
lineno
,
api_name
,
warning_info
)
self
.
_process_log
.
warning
(
node
.
lineno
,
node
.
col_offset
,
LOG_FMT_NOT_CONVERT
%
(
api_name
,
warning_info
))
new_node
=
self
.
_convert_call
(
node
,
matched_api_name
)
elif
match_case
in
[
ApiMatchingEnum
.
API_STANDARD
,
ApiMatchingEnum
.
API_FOUND
]:
self
.
_process_log
.
warning
(
node
.
lineno
,
node
.
col_offset
,
LOG_FMT_NOT_CONVERT
%
(
api_name
,
''
))
else
:
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录