Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MindSpore
mindinsight
提交
f803a1bd
M
mindinsight
项目概览
MindSpore
/
mindinsight
通知
8
Star
4
Fork
2
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindinsight
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
f803a1bd
编写于
6月 30, 2020
作者:
M
mindspore-ci-bot
提交者:
Gitee
6月 30, 2020
浏览文件
操作
浏览文件
下载
差异文件
!366 Improve the accuracy of the converted report.
Merge pull request !366 from ggpolar/br_wzk_0624
上级
f57e0271
ac2ad193
变更
1
隐藏空白更改
内联
并排
Showing
1 changed file
with
92 addition
and
24 deletion
+92
-24
mindinsight/mindconverter/ast_edits.py
mindinsight/mindconverter/ast_edits.py
+92
-24
未找到文件。
mindinsight/mindconverter/ast_edits.py
浏览文件 @
f803a1bd
...
...
@@ -20,6 +20,7 @@ import re
from
enum
import
Enum
import
pasta
from
pasta.base
import
formatting
as
fmt
from
mindinsight.mindconverter.code_analysis
import
CodeAnalyzer
from
mindinsight.mindconverter.code_analysis
import
APIAnalysisSpec
...
...
@@ -27,7 +28,7 @@ from mindinsight.mindconverter.config import ALL_MAPPING
from
mindinsight.mindconverter.config
import
NN_LIST
from
mindinsight.mindconverter.config
import
ALL_TORCH_APIS
from
mindinsight.mindconverter.config
import
ALL_2P_LIST
from
mindinsight.mindconverter.config
import
get_corresponding_ms_name
from
mindinsight.mindconverter.config
import
TENSOR_DOT_LIST
from
mindinsight.mindconverter.config
import
get_prompt_info
from
mindinsight.mindconverter.common.log
import
logger
from
mindinsight.mindconverter.common.exceptions
import
NodeTypeNotSupport
...
...
@@ -263,6 +264,22 @@ class AstEditVisitor(ast.NodeVisitor):
self
.
_process_log
.
info
(
base_class_node
.
lineno
,
base_class_node
.
col_offset
,
LOG_FMT_NOT_CONVERT
%
(
old_code
,
''
))
@
staticmethod
def
_modify_function_name
(
func_def_node
,
new_func_name
):
"""Modify function name"""
if
not
isinstance
(
func_def_node
,
ast
.
FunctionDef
):
raise
NodeTypeNotSupport
(
'It is not ast.FunctionDef node type.'
)
old_func_name
=
func_def_node
.
name
func_def_node
.
name
=
new_func_name
# Modify formatting information stored by pasta
old_function_def
=
fmt
.
get
(
func_def_node
,
'function_def'
)
if
old_function_def
:
new_function_def
=
old_function_def
.
replace
(
old_func_name
,
new_func_name
)
fmt
.
set
(
func_def_node
,
'function_def'
,
new_function_def
)
fmt
.
set
(
func_def_node
,
'name__src'
,
new_func_name
)
def
_update_function_def
(
self
,
func_scope
):
"""
Convert a PyTorch function into MindSpore function.
...
...
@@ -279,7 +296,7 @@ class AstEditVisitor(ast.NodeVisitor):
old_func_name
=
'forward'
new_func_name
=
'construct'
if
func_ast_node
.
name
==
old_func_name
:
func_ast_node
.
name
=
new_func_name
self
.
_modify_function_name
(
func_ast_node
,
new_func_name
)
real_line_number
=
self
.
_get_real_line_number
(
func_ast_node
)
self
.
_process_log
.
info
(
real_line_number
,
func_ast_node
.
col_offset
,
LOG_FMT_CONVERT
%
(
old_func_name
,
new_func_name
))
...
...
@@ -496,12 +513,33 @@ class AstEditVisitor(ast.NodeVisitor):
# only infer function for tensor object.
# e.g., api_call_name is out.view, .view is an api name for out which is maybe a tensor object.
# e.g., 'xxxx'.size can be not inferred to .size, because string is not a tensor object.
first_name
=
standard_name
.
split
(
'.'
)[
0
]
if
not
re
.
search
(
r
'\W'
,
first_name
)
and
len
(
name_attributes
)
>
1
:
if
self
.
_check_tensor_object
(
call_func_node
):
api_name
=
'.'
+
name_attributes
[
-
1
]
match_case
=
ApiMatchingEnum
.
API_INFER
return
api_name
,
match_case
def
_check_tensor_object
(
self
,
node
):
"""Check whether the reference object of the node is a tensor object."""
if
not
isinstance
(
node
,
(
ast
.
Attribute
,
ast
.
Name
)):
return
False
name_attributes
=
self
.
_dump_without_prefix
(
node
).
split
(
'.'
)
node_ref_name
=
name_attributes
[
0
]
if
re
.
search
(
r
'\W'
,
node_ref_name
)
or
len
(
name_attributes
)
==
1
:
return
False
func_name
=
'.'
+
name_attributes
[
-
1
]
if
func_name
not
in
TENSOR_DOT_LIST
:
return
False
is_tensor_object
=
True
if
self
.
_code_analyzer
:
# Check whether the object is external reference.
for
ref_name
in
self
.
_code_analyzer
.
external_references
:
if
node_ref_name
==
ref_name
:
is_tensor_object
=
False
break
return
is_tensor_object
@
staticmethod
def
_is_include_sub_call
(
call_func_node
):
""""Inspect a sub call in call expression.
...
...
@@ -671,6 +709,55 @@ class AstEditVisitor(ast.NodeVisitor):
return
new_code
@
staticmethod
def
_get_detail_prompt_msg
(
old_node
,
new_node
):
"""Get detail converted prompt information."""
msg
=
None
if
isinstance
(
old_node
,
ast
.
Call
)
and
isinstance
(
new_node
,
ast
.
Call
):
old_api_name
=
pasta
.
dump
(
old_node
.
func
)
new_api_name
=
pasta
.
dump
(
new_node
.
func
)
if
new_api_name
==
old_api_name
:
old_parameter_num
=
len
(
old_node
.
args
)
+
len
(
old_node
.
keywords
)
new_parameter_num
=
len
(
new_node
.
args
)
+
len
(
new_node
.
keywords
)
if
old_parameter_num
>
1
:
msg
=
'Parameters are converted.'
else
:
if
old_parameter_num
==
0
and
new_parameter_num
==
0
:
msg
=
'The API name is converted to mindspore API'
else
:
msg
=
'Parameter is converted.'
return
msg
def
_convert_call
(
self
,
node
,
matched_api_name
):
""""Convert the call node."""
new_node
=
None
code
=
pasta
.
dump
(
node
)
api_name
=
pasta
.
dump
(
node
.
func
)
warning_info
=
get_prompt_info
(
matched_api_name
)
if
warning_info
is
None
:
warning_info
=
''
if
matched_api_name
in
ALL_MAPPING
:
logger
.
info
(
"Line %3d start converting API: %s"
,
node
.
lineno
,
api_name
)
new_code
=
self
.
mapping_api
(
node
)
if
new_code
!=
code
:
try
:
new_node
=
pasta
.
parse
(
new_code
).
body
[
0
].
value
# find the first call name
new_api_name
=
new_code
[:
new_code
.
find
(
'('
)]
detail_msg
=
self
.
_get_detail_prompt_msg
(
node
,
new_node
)
if
detail_msg
:
warning_info
=
detail_msg
+
' '
+
warning_info
except
AttributeError
:
new_node
=
pasta
.
parse
(
new_code
).
body
[
0
]
new_api_name
=
new_code
self
.
_process_log
.
info
(
node
.
lineno
,
node
.
col_offset
,
LOG_FMT_CONVERT_WITH_TIPS
%
(
api_name
,
new_api_name
,
warning_info
))
else
:
logger
.
warning
(
"Line %3d: found unsupported API: %s%s"
,
node
.
lineno
,
api_name
,
warning_info
)
self
.
_process_log
.
warning
(
node
.
lineno
,
node
.
col_offset
,
LOG_FMT_NOT_CONVERT
%
(
api_name
,
warning_info
))
return
new_node
def
visit_Call
(
self
,
node
):
"""Callback function when visit AST tree"""
code
=
pasta
.
dump
(
node
)
...
...
@@ -688,26 +775,7 @@ class AstEditVisitor(ast.NodeVisitor):
new_code
=
code
matched_api_name
,
match_case
=
self
.
match_api
(
node
.
func
,
self
.
_is_forward_function
)
if
match_case
in
[
ApiMatchingEnum
.
API_INFER
,
ApiMatchingEnum
.
API_MATCHED
]:
warning_info
=
get_prompt_info
(
matched_api_name
)
if
warning_info
is
None
:
warning_info
=
''
if
matched_api_name
in
ALL_MAPPING
:
logger
.
info
(
"Line %3d start converting API: %s"
,
node
.
lineno
,
api_name
)
new_code
=
self
.
mapping_api
(
node
)
if
new_code
!=
code
:
try
:
new_node
=
pasta
.
parse
(
new_code
).
body
[
0
].
value
# find the first call name
new_api_name
=
get_corresponding_ms_name
(
matched_api_name
)
except
AttributeError
:
new_node
=
pasta
.
parse
(
new_code
).
body
[
0
]
new_api_name
=
new_code
self
.
_process_log
.
info
(
node
.
lineno
,
node
.
col_offset
,
LOG_FMT_CONVERT_WITH_TIPS
%
(
api_name
,
new_api_name
,
warning_info
))
else
:
logger
.
warning
(
"Line %3d: found unsupported API: %s%s"
,
node
.
lineno
,
api_name
,
warning_info
)
self
.
_process_log
.
warning
(
node
.
lineno
,
node
.
col_offset
,
LOG_FMT_NOT_CONVERT
%
(
api_name
,
warning_info
))
new_node
=
self
.
_convert_call
(
node
,
matched_api_name
)
elif
match_case
in
[
ApiMatchingEnum
.
API_STANDARD
,
ApiMatchingEnum
.
API_FOUND
]:
self
.
_process_log
.
warning
(
node
.
lineno
,
node
.
col_offset
,
LOG_FMT_NOT_CONVERT
%
(
api_name
,
''
))
else
:
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录