Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MindSpore
mindinsight
提交
bcdc61cc
M
mindinsight
项目概览
MindSpore
/
mindinsight
通知
8
Star
4
Fork
2
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindinsight
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
bcdc61cc
编写于
6月 17, 2020
作者:
M
mindspore-ci-bot
提交者:
Gitee
6月 17, 2020
浏览文件
操作
浏览文件
下载
差异文件
!273 Converter: use the AST to analyze and modify network definition script
Merge pull request !273 from ggpolar/br_wzk_dev
上级
d200233c
7cad801d
变更
9
显示空白变更内容
内联
并排
Showing
9 changed file
with
1178 addition
and
710 deletion
+1178
-710
mindinsight/mindconverter/ast_edits.py
mindinsight/mindconverter/ast_edits.py
+579
-0
mindinsight/mindconverter/cli.py
mindinsight/mindconverter/cli.py
+3
-5
mindinsight/mindconverter/code_analysis.py
mindinsight/mindconverter/code_analysis.py
+399
-0
mindinsight/mindconverter/common/exceptions.py
mindinsight/mindconverter/common/exceptions.py
+44
-0
mindinsight/mindconverter/converter.py
mindinsight/mindconverter/converter.py
+58
-441
mindinsight/mindconverter/forward_call.py
mindinsight/mindconverter/forward_call.py
+42
-34
mindinsight/utils/constant.py
mindinsight/utils/constant.py
+5
-0
tests/ut/mindconverter/test_converter.py
tests/ut/mindconverter/test_converter.py
+44
-223
tests/ut/mindconverter/test_forward_call.py
tests/ut/mindconverter/test_forward_call.py
+4
-7
未找到文件。
mindinsight/mindconverter/ast_edits.py
0 → 100644
浏览文件 @
bcdc61cc
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless REQUIRED by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Convert for Python scripts according API mapping information."""
import
ast
import
logging
import
re
from
enum
import
Enum
import
pasta
from
pasta.augment
import
import_utils
from
mindinsight.mindconverter.code_analysis
import
CodeAnalyzer
from
mindinsight.mindconverter.code_analysis
import
APIAnalysisSpec
from
mindinsight.mindconverter.config
import
ALL_MAPPING
from
mindinsight.mindconverter.config
import
NN_LIST
from
mindinsight.mindconverter.config
import
ALL_TORCH_APIS
from
mindinsight.mindconverter.config
import
ALL_2P_LIST
from
mindinsight.mindconverter.config
import
UNSUPPORTED_WARN_INFOS
from
mindinsight.mindconverter.config
import
ALL_UNSUPPORTED
from
mindinsight.mindconverter.common.log
import
logger
from
mindinsight.mindconverter.common.exceptions
import
NodeTypeNotSupport
from
mindinsight.mindconverter.forward_call
import
ForwardCall
LOG_FMT_CONVERT
=
"[Convert] '%s' is converted to '%s'."
LOG_FMT_NOT_CONVERT
=
"[UnConvert] '%s' didn't convert. %s"
LOG_FMT_PROMPT_INFO
=
"[INFO] %s"
LOG_SUGGESTION_MANUAL_CONVERT
=
"Please manual convert the code, along with the code associated with it."
class
ApiMatchingEnum
(
Enum
):
"""Node edge type enum."""
NOT_API
=
'not an api name'
API_INFER
=
'infer api name to map'
API_STANDARD
=
'api name in the correct format'
API_FOUND
=
'found an api name in api list'
API_MATCHED
=
'api is matched to map'
class
_ConvertReport
:
"""Report log of converting source code."""
def
__init__
(
self
,
is_stub
=
False
):
self
.
_is_stub
=
is_stub
self
.
_max_line
=
0
self
.
_log
=
[]
# report log, type is (severity, line, col, msg)
def
_add_log
(
self
,
severity
,
line
,
col
,
msg
):
"""Add log."""
if
self
.
_is_stub
:
return
if
isinstance
(
line
,
int
)
and
isinstance
(
col
,
int
):
self
.
_log
.
append
((
severity
,
line
,
col
,
msg
))
if
self
.
_max_line
<
line
:
self
.
_max_line
=
line
def
info
(
self
,
line
,
col
,
msg
):
"""Interface to add infer log"""
self
.
_add_log
(
logging
.
INFO
,
line
,
col
,
msg
)
def
warning
(
self
,
line
,
col
,
msg
):
"""Interface to add warning log"""
self
.
_add_log
(
logging
.
WARNING
,
line
,
col
,
msg
)
def
get_logs
(
self
):
"""Get convert logs"""
logs
=
[]
# sort rule: line * self._max_line + col
self
.
_log
.
sort
(
key
=
lambda
log
:
log
[
1
]
*
self
.
_max_line
+
log
[
2
])
for
log_info
in
self
.
_log
:
log_info
=
"line %d:%d: %s"
%
(
log_info
[
1
],
log_info
[
2
],
log_info
[
3
])
logs
.
append
(
log_info
)
return
logs
class
_LineColEditVisitor
(
ast
.
NodeVisitor
):
"""
Update line number and col offset of ast node.
Use the line and column number of the original code to update
the line and column number of the new code replaced with the original code.
"""
class
_NodeInfo
:
"""NodeInfo class definition."""
def
__init__
(
self
,
node
):
self
.
node
=
node
self
.
call_list
=
[]
# Used to save all ast.Call node in self._node
def
__init__
(
self
):
self
.
_dst_node_info
=
None
self
.
_src_node_info
=
None
self
.
_visiting
=
self
.
_src_node_info
# Used to point to the visiting node
def
update
(
self
,
replace_with_node
,
src_node
):
"""Update the line and column number of the new code replaced with the original code."""
replace_with_node
.
lineno
=
src_node
.
lineno
replace_with_node
.
col_offset
=
src_node
.
col_offset
self
.
_dst_node_info
=
self
.
_NodeInfo
(
replace_with_node
)
self
.
_src_node_info
=
self
.
_NodeInfo
(
src_node
)
self
.
_visiting
=
self
.
_src_node_info
self
.
visit
(
self
.
_visiting
.
node
)
self
.
_visiting
=
self
.
_dst_node_info
self
.
visit
(
self
.
_visiting
.
node
)
self
.
_update_line_col
()
def
visit_Call
(
self
,
node
):
"""Callback function when visit AST tree"""
self
.
_visiting
.
call_list
.
append
(
node
)
self
.
generic_visit
(
node
)
def
_update_line_col
(
self
):
"""Update the line and column number information for all ast.Call node."""
dst_call_list
=
list
(
self
.
_dst_node_info
.
call_list
)
src_call_list
=
list
(
self
.
_src_node_info
.
call_list
)
len_diff
=
len
(
dst_call_list
)
-
len
(
src_call_list
)
# After MindSpore api replaces Torch api, more calls are generated.
# For example, out.view() is replaced with P.Reshape()(out).
# out.view() has only one call, but P.Reshape()(out) has two calls.
# To match the replaced calls, the calls of out.view is padded to the same quantity.
if
len_diff
>
0
:
src_call_list
=
[
src_call_list
[
0
]]
*
len_diff
+
src_call_list
for
dst_call
,
src_call
in
zip
(
dst_call_list
,
src_call_list
):
dst_call
.
lineno
=
src_call
.
lineno
dst_call
.
col_offset
=
src_call
.
col_offset
if
not
dst_call
.
args
:
continue
# When out.size().view(1, ...) transforms to P.Reshape()(out.size(), 1, ...),
# in this case, the column of parameter out.size() will be bigger than the following parameters.
# To ensure the sequence of parameters, adjust the column of the second parameter.
args
=
[]
for
arg
in
dst_call
.
args
:
if
self
.
_check_arg2update
(
arg
):
args
.
append
(
arg
)
for
arg
in
args
:
arg
.
lineno
=
dst_call
.
lineno
arg
.
col_offset
+=
dst_call
.
col_offset
@
staticmethod
def
_check_arg2update
(
arg
):
# Only the col_offset of the first line code is re-counted, needs to be corrected.
# When the arg is a function call, its col_offset is handled separately.
if
not
isinstance
(
arg
,
ast
.
Call
)
and
arg
.
lineno
==
1
:
return
True
return
False
class
AstEditVisitor
(
ast
.
NodeVisitor
):
"""AST Visitor that process function calls.
Converts function calls from torch api to MindSpore api using api mapping information.
"""
def
__init__
(
self
):
self
.
_process_log
=
_ConvertReport
()
self
.
_tree
=
None
self
.
_code_analyzer
=
None
self
.
_stack
=
[]
# Used to easily access the parent node
self
.
_forward_list
=
{}
self
.
_is_forward_function
=
False
# Used to allow access the visiting function forward attribute
self
.
_new_call_nodes
=
[]
# Used to save new ast.call nodes
def
process
(
self
,
ast_tree
):
"""
Convert source code to MindSpore code.
Args:
ast_tree (AST): The root node of the source code.
"""
self
.
__init__
()
self
.
_tree
=
ast_tree
self
.
_code_analyzer
=
CodeAnalyzer
()
self
.
_code_analyzer
.
process
(
self
.
_tree
)
self
.
_forward_list
=
ForwardCall
(
self
.
_tree
).
calls
# replace python function under nn.Module
self
.
_convert_api
()
# replace external reference statements
self
.
_convert_external_reference
()
def
get_logs
(
self
):
"""Get conversion report."""
return
self
.
_process_log
.
get_logs
()
def
_convert_cell
(
self
,
cell_scope
):
"""
Convert a PyTorch Module class into MindSpore Cell class.
Args:
cell_scope (pasta.base.Scope): The network class definition node inherits from torch.nn.Module.
"""
cell_ast_node
=
cell_scope
.
node
line_no
=
cell_ast_node
.
lineno
logger
.
info
(
"Line %3d: start converting nn.Module %s"
,
line_no
,
self
.
_code_analyzer
.
get_name
(
cell_ast_node
))
class_elements
=
self
.
_code_analyzer
.
network_definitions
()[
'cell'
]
# step1. update function definition
for
func_scope
in
class_elements
.
get
(
cell_scope
,
[]):
self
.
_update_function_def
(
func_scope
)
# step2. update base name of class
self
.
_update_base_name
(
cell_scope
)
def
_update_base_name
(
self
,
class_def_scope
):
"""
Update base name of class.
Args:
class_def_scope (ast.ClassDef): Class definition node.
"""
base_name_mapping
=
APIAnalysisSpec
.
base_name_mapping
class_def_node
=
class_def_scope
.
node
base_class_nodes
=
class_def_scope
.
node
.
bases
# update base class name
for
base_class_node
in
base_class_nodes
:
base_name
=
base_class_node
.
attr
if
base_name
in
APIAnalysisSpec
.
get_network_base_class_names
():
old_code
=
pasta
.
dump
(
base_class_node
)
if
base_name
in
base_name_mapping
:
new_code
=
'nn.'
+
base_name_mapping
[
base_class_node
.
attr
]
new_node
=
pasta
.
parse
(
new_code
)
pasta
.
ast_utils
.
replace_child
(
class_def_node
,
base_class_node
,
new_node
)
self
.
_process_log
.
info
(
base_class_node
.
lineno
,
base_class_node
.
col_offset
,
LOG_FMT_CONVERT
%
(
old_code
,
new_code
))
else
:
self
.
_process_log
.
info
(
base_class_node
.
lineno
,
base_class_node
.
col_offset
,
LOG_FMT_NOT_CONVERT
%
(
old_code
,
''
))
def
_update_function_def
(
self
,
func_scope
):
"""
Convert a PyTorch function into MindSpore function.
Args:
func_scope (pasta.base.scope.Scope): The node scope of function definition.
"""
is_forward
=
self
.
_judge_forward
(
func_scope
)
# step1. convert the content of the function.
self
.
_convert_function
(
func_scope
,
is_forward
)
# step2. replace function name if name is forward
func_ast_node
=
func_scope
.
node
old_func_name
=
'forward'
new_func_name
=
'construct'
if
func_ast_node
.
name
==
old_func_name
:
func_ast_node
.
name
=
new_func_name
self
.
_process_log
.
info
(
func_ast_node
.
lineno
,
func_ast_node
.
col_offset
,
LOG_FMT_CONVERT
%
(
old_func_name
,
new_func_name
))
def
_convert_api
(
self
):
"""Convert PyTorch api call to MindSpore api call in a function."""
tasks
=
[]
convert_elements
=
self
.
_code_analyzer
.
network_definitions
()
for
func_node_scope
in
convert_elements
.
get
(
"functions"
,
[]):
is_forward
=
self
.
_judge_forward
(
func_node_scope
)
tasks
.
append
((
self
.
_convert_function
,
(
func_node_scope
,
is_forward
)))
for
class_scope
in
convert_elements
.
get
(
"cell"
,
[]).
keys
():
tasks
.
append
((
self
.
_convert_cell
,
(
class_scope
,)))
for
convert_fun
,
args
in
tasks
:
convert_fun
(
*
args
)
def
_convert_external_reference
(
self
):
"""Convert import statements."""
name_replace
=
APIAnalysisSpec
.
import_name_mapping
replace_imports
=
list
(
name_replace
.
values
())
for
ref_info
in
self
.
_code_analyzer
.
external_references
.
values
():
external_ref_info
=
ref_info
[
'external_ref_info'
]
parent_node
=
ref_info
[
'parent_node'
]
if
parent_node
is
None
:
continue
code
=
pasta
.
dump
(
parent_node
)
if
external_ref_info
.
name
in
APIAnalysisSpec
.
get_convertible_external_names
():
external_ref_info
=
ref_info
[
'external_ref_info'
]
if
external_ref_info
.
name
in
name_replace
.
keys
():
import_utils
.
remove_import_alias_node
(
self
.
_code_analyzer
.
root_scope
,
external_ref_info
.
node
)
replace_info
=
name_replace
[
external_ref_info
.
name
]
new_ref_name
=
replace_info
[
1
]
new_external_name
=
replace_info
[
0
]
if
new_ref_name
:
new_code
=
f
'import
{
new_external_name
}
as
{
new_ref_name
}
'
else
:
new_code
=
f
'import
{
new_external_name
}
'
self
.
_process_log
.
info
(
parent_node
.
lineno
,
parent_node
.
col_offset
,
LOG_FMT_CONVERT
%
(
code
.
strip
(),
new_code
.
strip
()))
elif
external_ref_info
.
name
.
startswith
(
'torch.'
):
self
.
_process_log
.
warning
(
parent_node
.
lineno
,
parent_node
.
col_offset
,
LOG_FMT_NOT_CONVERT
%
(
code
.
strip
(),
LOG_SUGGESTION_MANUAL_CONVERT
))
else
:
pass
# Insert import in reverse order, display in forward order.
for
idx
in
range
(
len
(
replace_imports
)
-
1
,
-
1
,
-
1
):
replace_import
=
replace_imports
[
idx
]
if
replace_import
[
1
]:
self
.
_add_import
(
name_to_import
=
replace_import
[
0
],
as_name
=
replace_import
[
1
])
else
:
self
.
_add_import
(
name_to_import
=
replace_import
[
0
])
def
_add_import
(
self
,
name_to_import
,
as_name
=
None
):
"""
Adds an import to the ast tree.
Args:
name_to_import: (string) The absolute name to import.
as_name: (string) The alias for the import ("import name_to_import as asname")
"""
new_alias
=
ast
.
alias
(
name
=
name_to_import
,
asname
=
as_name
)
import_node
=
ast
.
Import
(
names
=
[
new_alias
])
# Insert the node at the top of the module
self
.
_tree
.
body
.
insert
(
1
if
pasta
.
base
.
ast_utils
.
has_docstring
(
self
.
_tree
)
else
0
,
import_node
)
def
_convert_function
(
self
,
func_scope
,
is_forward
):
"""
Convert a PyTorch function into MindSpore function.
Args:
func_scope (pasta.base.scope.Scope): The node scope of function definition.
is_forward (boolean): If the function is defined in forward function in nn.Module in torch.
"""
func_ast_node
=
func_scope
.
node
line_no
=
func_ast_node
.
lineno
logger
.
info
(
"Line %3d: start converting function %s()"
,
line_no
,
func_ast_node
.
name
)
parent
=
func_scope
.
parent_scope
.
node
self
.
_stack
.
clear
()
self
.
_new_call_nodes
.
clear
()
if
parent
:
self
.
_stack
.
append
(
parent
)
self
.
_is_forward_function
=
is_forward
self
.
visit
(
func_scope
.
node
)
def
_judge_forward
(
self
,
func_scope
):
"""
Check if function is a forward function.
Args:
func_scope (pasta.base.scope.Scope): The node scope of function definition.
Returns:
boolean, True or False
"""
is_forward
=
func_scope
.
node
in
self
.
_forward_list
.
values
()
if
is_forward
:
logger
.
debug
(
"%s is a forward function"
,
self
.
_code_analyzer
.
get_name
(
func_scope
))
return
is_forward
# Overridden to maintain stack information to access parent node
def
visit
(
self
,
node
):
"""Visit a ast tree."""
self
.
_stack
.
append
(
node
)
super
(
AstEditVisitor
,
self
).
visit
(
node
)
self
.
_stack
.
pop
()
def
_mapping_standard_api_name
(
self
,
api_name
):
"""Get mapping from external reference name to standard external reference name"""
standard_name
=
api_name
if
not
self
.
_code_analyzer
.
is_standard_external_ref
:
# key is real ref name, value is standard ref name.
mapping_names
=
self
.
_mapping_standard_external_ref
()
api_name_parts
=
api_name
.
split
(
'.'
)
api_name_parts
[
0
]
=
mapping_names
.
get
(
api_name_parts
[
0
],
api_name_parts
[
0
])
standard_name
=
'.'
.
join
(
api_name_parts
)
return
standard_name
def
_infer_api_name
(
self
,
call_func_node
,
check_context
=
True
):
"""Infer the call name.
Examples:
1. nn.Sequential inferred to nn.Sequential
2. mmm.size inferred to .size if import torch.nn as nn
3. mmm.size inferred to mmm.size if import torch.nn as mmm
"""
match_case
=
ApiMatchingEnum
.
NOT_API
api_name
=
None
call_name
=
pasta
.
dump
(
call_func_node
)
is_include_sub_call
=
self
.
_is_include_sub_call
(
call_func_node
)
if
is_include_sub_call
:
name_attributes
=
call_name
.
rsplit
(
'.'
,
1
)
else
:
name_attributes
=
call_name
.
split
(
'.'
)
# rewritten external module name
# e.g., mm.ReLU will be written to nn.ReLU if 'import torch.nn as mm' in script.
if
check_context
and
not
self
.
_code_analyzer
.
is_standard_external_ref
:
standard_name
=
self
.
_mapping_standard_api_name
(
name_attributes
[
0
])
else
:
standard_name
=
name_attributes
[
0
]
if
standard_name
in
[
"nn"
,
"F"
,
"torch"
]:
match_case
=
ApiMatchingEnum
.
API_STANDARD
api_name
=
call_name
else
:
# only infer function for tensor object.
# e.g., api_call_name is out.view, .view is an api name for out which is maybe a tensor object.
# e.g., 'xxxx'.size can be not inferred to .size, because string is not a tensor object.
first_name
=
standard_name
.
split
(
'.'
)[
0
]
if
not
re
.
search
(
r
'\W'
,
first_name
)
and
len
(
name_attributes
)
>
1
:
api_name
=
'.'
+
name_attributes
[
-
1
]
match_case
=
ApiMatchingEnum
.
API_INFER
return
api_name
,
match_case
@
staticmethod
def
_is_include_sub_call
(
call_func_node
):
""""Inspect a sub call in call expression.
Examples:
1. nn.functional.relu() return False
2. nn.functional.relu(out).size() return True. nn.functional.relu(out) is sub call.
3. nn.functional.relu(out=out.size()).size() return False. out.size() is not sub call of argument.
"""
is_include_call
=
False
try
:
sub_node
=
call_func_node
while
sub_node
and
not
isinstance
(
sub_node
,
ast
.
Call
):
sub_node
=
sub_node
.
value
if
isinstance
(
sub_node
,
ast
.
Call
):
is_include_call
=
True
except
AttributeError
:
is_include_call
=
False
return
is_include_call
def
match_api
(
self
,
call_func_node
,
is_forward
):
"""
Check api name to convert, check api name ok with a is_forward condition.
Args:
call_func_node (ast.Attribute): The call.func node.
is_forward (bool): whether api belong to forward.
Returns:
str, the standard api name used to match.
ApiMappingEnum, the match result.
"""
api_name
,
match_case
=
self
.
_infer_api_name
(
call_func_node
)
api_call_name
=
pasta
.
dump
(
call_func_node
)
is_tensor_obj_call
=
False
if
api_name
!=
api_call_name
:
is_tensor_obj_call
=
True
standard_api_call_name
=
api_name
# rewritten external module name
# e.g., mm.ReLU will be written to nn.ReLU if 'import torch.nn as mm' in script.
if
not
is_tensor_obj_call
and
not
self
.
_code_analyzer
.
is_standard_external_ref
:
standard_api_call_name
=
self
.
_mapping_standard_api_name
(
api_name
)
if
standard_api_call_name
in
ALL_TORCH_APIS
:
match_case
=
ApiMatchingEnum
.
API_FOUND
if
(
not
is_forward
and
standard_api_call_name
in
NN_LIST
)
or
\
(
is_forward
and
standard_api_call_name
in
ALL_2P_LIST
):
match_case
=
ApiMatchingEnum
.
API_MATCHED
return
standard_api_call_name
,
match_case
def
mapping_api
(
self
,
call_node
,
check_context
=
True
):
"""
Convert api_name in code to MindSpore api, if api_name is a python api, code will not convert.
If do not check context of the script, the code represented by the node must be written in the standard way.
Args:
call_node (ast.Call): The ast node to convert.
check_context (boolean): If True, the code context will be checked. Default is True.
Returns:
str, the converted code.
"""
if
not
isinstance
(
call_node
,
ast
.
Call
):
raise
NodeTypeNotSupport
(
"It is not ast.Call node."
)
code
=
pasta
.
dump
(
call_node
)
api_call_name
=
pasta
.
dump
(
call_node
.
func
)
if
api_call_name
.
startswith
(
'self.'
):
return
code
# find full api expected to be converted. eg:expr="nn.Conv2d(1,2,3)" args_str="(1,2,3)"
args_str
=
code
[
len
(
api_call_name
):].
strip
()
try
:
api_name
,
_
=
self
.
_infer_api_name
(
call_node
.
func
,
check_context
)
standard_api_call_name
=
api_call_name
if
api_name
!=
api_call_name
:
# api name .view inferred from out.view, split tensor object name is out
tensor_obj_name
=
api_call_name
[:
-
len
(
api_name
)]
map_helper
=
ALL_MAPPING
[
api_name
]
new_code
=
map_helper
.
convert
(
tensor_obj_name
,
args_str
)
else
:
# change to external ref name
# e.g., mm.ReLU will be changed to nn.ReLU if 'import torch.nn as mm' in script.
if
check_context
and
not
self
.
_code_analyzer
.
is_standard_external_ref
:
standard_api_call_name
=
self
.
_mapping_standard_api_name
(
api_name
)
map_helper
=
ALL_MAPPING
[
standard_api_call_name
]
new_code
=
map_helper
.
convert
(
standard_api_call_name
,
args_str
)
except
KeyError
:
return
code
return
new_code
def
visit_Call
(
self
,
node
):
"""Callback function when visit AST tree"""
code
=
pasta
.
dump
(
node
)
api_name
=
pasta
.
dump
(
node
.
func
)
# parent node first call is equal to this node, skip when parent node is replaced.
for
parent_node
in
self
.
_stack
[:
-
1
]:
if
parent_node
in
self
.
_new_call_nodes
and
pasta
.
dump
(
parent_node
).
startswith
(
api_name
):
return
parent
=
self
.
_stack
[
-
2
]
new_node
=
None
matched_api_name
,
match_case
=
self
.
match_api
(
node
.
func
,
self
.
_is_forward_function
)
if
match_case
in
[
ApiMatchingEnum
.
API_INFER
,
ApiMatchingEnum
.
API_MATCHED
]:
if
matched_api_name
in
ALL_MAPPING
:
logger
.
info
(
"Line %3d start converting API: %s"
,
node
.
lineno
,
api_name
)
new_code
=
self
.
mapping_api
(
node
)
if
new_code
!=
code
:
new_node
=
pasta
.
parse
(
new_code
).
body
[
0
].
value
# find the first call name
new_api_name
=
new_code
[:
new_code
.
find
(
'('
)]
self
.
_process_log
.
info
(
node
.
lineno
,
node
.
col_offset
,
LOG_FMT_CONVERT
%
(
api_name
,
new_api_name
))
if
matched_api_name
in
ALL_UNSUPPORTED
:
warn_info
=
UNSUPPORTED_WARN_INFOS
.
get
(
api_name
,
''
)
logger
.
warning
(
"Line %3d: found unsupported API: %s%s"
,
node
.
lineno
,
api_name
,
warn_info
)
self
.
_process_log
.
warning
(
node
.
lineno
,
node
.
col_offset
,
LOG_FMT_NOT_CONVERT
%
(
api_name
,
warn_info
))
elif
match_case
in
[
ApiMatchingEnum
.
API_STANDARD
,
ApiMatchingEnum
.
API_FOUND
]:
self
.
_process_log
.
warning
(
node
.
lineno
,
node
.
col_offset
,
LOG_FMT_NOT_CONVERT
%
(
api_name
,
''
))
else
:
pass
if
parent
and
new_node
:
update_line_col
=
_LineColEditVisitor
()
update_line_col
.
update
(
new_node
,
node
)
pasta
.
ast_utils
.
replace_child
(
parent
,
node
,
new_node
)
self
.
_new_call_nodes
.
append
(
new_node
)
node
=
new_node
self
.
_stack
[
-
1
]
=
node
try
:
self
.
generic_visit
(
node
)
except
Exception
:
logger
.
error
(
'original code:%s, new code:%s'
,
code
,
new_code
,
exc_info
=
True
)
raise
def
_mapping_standard_external_ref
(
self
):
"""Obtain the mapping dict of mapping the external references to standard external references."""
renames
=
{}
external_refs
=
self
.
_code_analyzer
.
external_references
for
ref_name
,
ref_info
in
external_refs
.
items
():
external_ref_info
=
ref_info
[
'external_ref_info'
]
if
ref_name
!=
'nn'
and
external_ref_info
.
name
==
'torch.nn'
:
renames
[
ref_name
]
=
'nn'
elif
ref_name
!=
'F'
and
external_ref_info
.
name
==
'torch.nn.functional'
:
renames
[
ref_name
]
=
'F'
return
renames
mindinsight/mindconverter/cli.py
浏览文件 @
bcdc61cc
...
@@ -186,25 +186,23 @@ def cli_entry():
...
@@ -186,25 +186,23 @@ def cli_entry():
mode
=
permissions
<<
6
mode
=
permissions
<<
6
os
.
makedirs
(
args
.
output
,
mode
=
mode
,
exist_ok
=
True
)
os
.
makedirs
(
args
.
output
,
mode
=
mode
,
exist_ok
=
True
)
os
.
makedirs
(
args
.
report
,
mode
=
mode
,
exist_ok
=
True
)
os
.
makedirs
(
args
.
report
,
mode
=
mode
,
exist_ok
=
True
)
_run
(
args
.
in_file
,
args
.
output
,
''
,
args
.
report
)
_run
(
args
.
in_file
,
args
.
output
,
args
.
report
)
def
_run
(
in_files
,
out_dir
,
in_module
,
report
):
def
_run
(
in_files
,
out_dir
,
report
):
"""
"""
Run converter command.
Run converter command.
Args:
Args:
in_files (str): The file path or directory to convert.
in_files (str): The file path or directory to convert.
out_dir (str): The output directory to save converted file.
out_dir (str): The output directory to save converted file.
in_module (str): The module name to convert.
report (str): The report file path.
report (str): The report file path.
"""
"""
files_config
=
{
files_config
=
{
'root_path'
:
in_files
if
in_files
else
''
,
'root_path'
:
in_files
if
in_files
else
''
,
'in_files'
:
[],
'in_files'
:
[],
'outfile_dir'
:
out_dir
,
'outfile_dir'
:
out_dir
,
'report_dir'
:
report
,
'report_dir'
:
report
'in_module'
:
in_module
}
}
if
os
.
path
.
isfile
(
in_files
):
if
os
.
path
.
isfile
(
in_files
):
files_config
[
'root_path'
]
=
os
.
path
.
dirname
(
in_files
)
files_config
[
'root_path'
]
=
os
.
path
.
dirname
(
in_files
)
...
...
mindinsight/mindconverter/code_analysis.py
0 → 100644
浏览文件 @
bcdc61cc
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless REQUIRED by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""code analysis module"""
import
ast
import
pasta
from
pasta.base
import
scope
from
mindinsight.mindconverter.common.exceptions
import
ScriptNotSupport
class
APIAnalysisSpec
:
"""API analysis specifications"""
import_name_mapping
=
{
'torch'
:
[
'mindspore'
,
None
],
'torch.nn'
:
[
'mindspore.nn'
,
'nn'
],
'torch.nn.functional'
:
[
'mindspore.ops.operations'
,
'P'
]}
base_name_mapping
=
{
'Module'
:
'Cell'
,
'Sequential'
:
'SequentialCell'
}
@
classmethod
def
get_convertible_external_names
(
cls
):
"""
Obtain the convertible external names.
The external name is the full dotted name being referenced.
"""
return
cls
.
import_name_mapping
.
keys
()
@
staticmethod
def
get_network_base_class_names
():
"""Obtain the base names which network class base from"""
return
[
'Module'
,
'Sequential'
,
'ModuleList'
,
'ModuleDict'
,
'ParameterList'
,
'ParameterDict'
]
@
staticmethod
def
check_external_alias_ref
(
ref_name
,
external_name
):
"""
Check 'import as' is standard.
Standard references are follow:
import torch.nn as nn
import torch.nn.functional as F
Args:
ref_name (str): The name that refers to the external_name.
external_name (str): The full dotted name being referenced. For examples:
1. 'import torch.nn as nn', torch.nn is external_name, nn is ref_name.
2. 'from torch import nn as mm, torch.nn is external_name, mm is ref_name which is not a standard name.
Returns:
boolean, True if ref_name is standard else False.
"""
if
ref_name
!=
'nn'
and
external_name
==
'torch.nn'
:
is_standard
=
False
elif
ref_name
!=
'F'
and
external_name
==
'torch.nn.functional'
:
is_standard
=
False
else
:
is_standard
=
True
return
is_standard
class
CodeAnalyzer
(
ast
.
NodeVisitor
):
"""Code analyzer that analyzes PyTorch python script by AST Visitor.
CodeAnalyzer find the codes that need to be converted to MindSpore,
and provides the attributes related to the codes.
"""
def
__init__
(
self
):
self
.
_stack
=
[]
# Used to easily access the parent node
self
.
_external_references
=
{}
self
.
_is_standard_external_ref
=
True
self
.
_root_scope
=
None
# Used to save functions that need to be converted, value type is pasta.base.scope.Scope
self
.
_network_functions
=
[]
# Used to easily trace the function node
self
.
_functions_stack
=
[]
# key type is pasta.base.scope.Scope, value type is list
self
.
_network_classes
=
{}
@
property
def
root_scope
(
self
):
"""The root scope of the python script code."""
return
self
.
_root_scope
@
property
def
is_standard_external_ref
(
self
):
"""Obtain whether the result is a standard external reference."""
return
self
.
_is_standard_external_ref
@
property
def
external_references
(
self
):
"""Obtain all external references in the analyzed code."""
return
self
.
_external_references
def
network_definitions
(
self
):
"""Obtain the network definitions which need to be converted."""
return
{
"functions"
:
self
.
_network_functions
,
"cell"
:
self
.
_network_classes
}
def
process
(
self
,
ast_tree
):
"""
Start to analyze the code.
Args:
ast_tree (AST): The root node of the source code.
"""
self
.
__init__
()
self
.
_root_scope
=
scope
.
analyze
(
ast_tree
)
self
.
_pre_process
()
self
.
visit
(
ast_tree
)
if
not
self
.
_network_classes
:
msg
=
"model definition not be found."
raise
ScriptNotSupport
(
msg
)
@
staticmethod
def
_check_external_standard
(
external_refs
):
"""Check whether all external references are standard."""
is_standard
=
True
for
external_name
,
external_ref_info
in
external_refs
.
items
():
is_standard
=
APIAnalysisSpec
.
check_external_alias_ref
(
external_name
,
external_ref_info
.
name
)
if
not
is_standard
:
break
return
is_standard
def
_is_base_from_cell
(
self
,
node
):
"""
Check whether the node bases from cell classes which are defined in APIAnalysisSpec.
Args:
node (ast.ClassDef): The node which is a class definition.
Returns:
boolean, True if the check result is Passed else False.
"""
if
self
.
_is_ref_convertible_imports
(
node
):
whole_name
=
self
.
_get_whole_name
(
node
)
if
whole_name
.
split
(
'.'
)[
-
1
]
in
APIAnalysisSpec
.
get_network_base_class_names
():
return
True
return
False
def
_pre_process
(
self
):
"""Preprocessor checks the code before analyzing."""
is_torch
=
False
# check whether the code imports torch.
for
ref_name
in
self
.
_root_scope
.
external_references
.
keys
():
if
ref_name
.
split
(
'.'
)[
0
]
in
APIAnalysisSpec
.
get_convertible_external_names
():
is_torch
=
True
break
if
not
is_torch
:
msg
=
"The source code does not import torch, model definition can not be found."
raise
ScriptNotSupport
(
msg
)
# Find out external reference in the code and save it.
external_refs
=
self
.
_analyze_import_references
(
self
.
_root_scope
)
self
.
_is_standard_external_ref
=
self
.
_check_external_standard
(
external_refs
)
self
.
_check_external_standard
(
external_refs
)
for
external_name
,
external_ref_info
in
external_refs
.
items
():
self
.
_external_references
.
update
({
external_name
:
{
'external_ref_info'
:
external_ref_info
,
'parent_node'
:
None
}
})
@
staticmethod
def
_analyze_import_references
(
root_scope
):
"""Find out all references from the import statements."""
external_name_ref
=
{}
for
node_references
in
root_scope
.
external_references
.
values
():
for
node_ref
in
node_references
:
if
node_ref
.
name_ref
:
# (from)import alias, node_ref.name_ref.id is alias name
if
node_ref
.
name_ref
.
definition
.
asname
==
node_ref
.
name_ref
.
id
:
external_name_ref
[
node_ref
.
name_ref
.
id
]
=
node_ref
# import without alias, node_ref.name_ref.definition.asname is None.
# e.g., import a.b.c, reference maybe is a, a.b or a.b.c in the root_scope.external_references.
# The reference a.b.c is really wanted.
elif
node_ref
.
name_ref
.
definition
.
name
==
node_ref
.
name_ref
.
id
:
external_name_ref
[
node_ref
.
name_ref
.
id
]
=
node_ref
else
:
pass
return
external_name_ref
def
visit
(
self
,
node
):
"""Overridden visit of the base class to maintain stack information to access parent node."""
self
.
_stack
.
append
(
node
)
super
(
CodeAnalyzer
,
self
).
visit
(
node
)
self
.
_stack
.
pop
()
@
staticmethod
def
_get_full_name
(
node
):
"""Get the full name of the node."""
if
not
isinstance
(
node
,
(
ast
.
Attribute
,
ast
.
Name
)):
return
None
return
pasta
.
dump
(
node
)
def
_get_whole_name
(
self
,
node
):
"""
Get the whole name of the node.
For example, nn.Module is spliced two nodes, nn node and Module node.
When visit ast nodes,
Module node is first visited, the full name is the same as the whole name, that is nn.Module.
And then nn node is visited, the full name is nn, the whole name is nn.Module.
"""
full_name
=
self
.
_get_full_name
(
node
)
if
not
full_name
:
return
None
# node is in stack top pos
if
node
is
self
.
_stack
[
-
1
]:
parent_index
=
-
1
while
isinstance
(
self
.
_stack
[
parent_index
],
ast
.
Attribute
):
parent_index
-=
1
whole_name
=
self
.
_get_full_name
(
self
.
_stack
[
parent_index
])
else
:
whole_name
=
full_name
return
whole_name
def
_is_ref_convertible_imports
(
self
,
node
):
"""Check whether the node references convertible imports."""
check_result
=
False
whole_name
=
self
.
_get_whole_name
(
node
)
if
whole_name
:
module_name
=
whole_name
.
split
(
'.'
)[
0
]
for
ref_name
,
ref_info
in
self
.
_external_references
.
items
():
external_ref
=
ref_info
[
'external_ref_info'
]
# external reference is convertible module
if
external_ref
.
name
in
APIAnalysisSpec
.
get_convertible_external_names
():
# import from the same external module
if
module_name
==
ref_name
.
split
(
'.'
)[
0
]:
check_result
=
True
break
return
check_result
@
staticmethod
def
_get_external_node
(
external_references
):
"""Get all external reference nodes."""
external_nodes
=
{}
for
ref_name
,
ref_info
in
external_references
.
items
():
external_nodes
.
update
({
ref_info
[
'external_ref_info'
].
node
:
ref_name
})
return
external_nodes
@
staticmethod
def
_get_convertible_external_node
(
external_name_ref
):
"""Get all convertible external reference nodes."""
convertible_external_nodes
=
{}
for
ref_name
,
ref_info
in
external_name_ref
.
items
():
if
ref_info
[
'external_ref_info'
].
name
in
APIAnalysisSpec
.
get_convertible_external_names
():
convertible_external_nodes
.
update
({
ref_info
[
'external_ref_info'
].
node
:
ref_name
})
return
convertible_external_nodes
def
_update_external_ref_parent
(
self
,
node
):
"""Set external reference parent node info."""
external_nodes
=
self
.
_get_external_node
(
self
.
_external_references
)
convertible_external_nodes
=
self
.
_get_convertible_external_node
(
self
.
_external_references
)
for
name_node
in
node
.
names
:
if
name_node
in
convertible_external_nodes
.
keys
():
if
len
(
node
.
names
)
>
1
:
msg
=
"""
\
Not support multiple imports of torch on one line in your script. line:%s: %s
"""
%
(
node
.
lineno
,
pasta
.
dump
(
node
))
raise
ScriptNotSupport
(
msg
)
if
name_node
in
external_nodes
.
keys
():
ref_name
=
external_nodes
[
name_node
]
self
.
_external_references
[
ref_name
][
'parent_node'
]
=
node
@
staticmethod
def
_get_class_scope
(
node_scope
):
"""Find the class scope of the node_scope."""
parent_scope
=
node_scope
.
parent_scope
class_scope
=
None
while
parent_scope
:
if
isinstance
(
parent_scope
.
node
,
ast
.
ClassDef
):
class_scope
=
parent_scope
break
parent_scope
=
parent_scope
.
parent_scope
return
class_scope
def
_update_convertible_functions
(
self
,
node
):
"""Update convertible functions."""
node_scope
=
self
.
_root_scope
.
lookup_scope
(
node
)
class_scope
=
self
.
_get_class_scope
(
node_scope
)
if
class_scope
:
network_classes
=
self
.
_network_classes
.
get
(
class_scope
,
[])
if
node_scope
not
in
network_classes
:
network_classes
.
append
(
node_scope
)
else
:
if
node_scope
not
in
self
.
_network_functions
:
self
.
_network_functions
.
append
(
node_scope
)
def
visit_ClassDef
(
self
,
node
):
"""Callback function when visit AST tree"""
if
not
self
.
_stack
[
-
1
]
is
node
:
return
for
base
in
node
.
bases
:
if
self
.
_is_ref_convertible_imports
(
base
):
self
.
_network_classes
[
self
.
_root_scope
.
lookup_scope
(
node
)]
=
[]
self
.
generic_visit
(
node
)
def
visit_Import
(
self
,
node
):
"""Callback function when visit AST tree"""
self
.
_update_external_ref_parent
(
node
)
self
.
generic_visit
(
node
)
def
visit_ImportFrom
(
self
,
node
):
"""Callback function when visit AST tree"""
self
.
_update_external_ref_parent
(
node
)
self
.
generic_visit
(
node
)
def
visit_Call
(
self
,
node
):
"""Callback function when visit AST tree"""
if
not
self
.
_stack
[
-
1
]
is
node
:
return
is_in_network_function
=
False
# If torch call is happened in the function, save the function for network definition.
if
self
.
_functions_stack
and
self
.
_is_ref_convertible_imports
(
node
.
func
):
self
.
_update_convertible_functions
(
self
.
_functions_stack
[
-
1
])
is_in_network_function
=
True
if
not
is_in_network_function
:
self
.
generic_visit
(
node
)
def
visit_FunctionDef
(
self
,
node
):
"""Callback function when visit AST tree"""
if
not
self
.
_stack
[
-
1
]
is
node
:
return
if
node
.
name
==
"forward"
:
self
.
_update_convertible_functions
(
node
)
self
.
_functions_stack
.
append
(
node
)
self
.
generic_visit
(
node
)
self
.
_functions_stack
.
pop
()
def
get_name
(
self
,
node
):
"""
Get the node name.
Args:
node (AST): The ast node of the source code.
Returns:
str, the name of the node
"""
if
isinstance
(
node
,
pasta
.
base
.
scope
.
Scope
):
items
=
[
self
.
get_name
(
node
.
node
)]
parent_scope
=
node
.
parent_scope
while
parent_scope
:
if
not
isinstance
(
parent_scope
.
node
,
ast
.
Module
):
items
.
append
(
self
.
get_name
(
parent_scope
.
node
))
parent_scope
=
parent_scope
.
parent_scope
return
'.'
.
join
(
reversed
(
items
))
if
isinstance
(
node
,
(
ast
.
ClassDef
,
ast
.
FunctionDef
)):
return
node
.
name
if
isinstance
(
node
,
(
ast
.
Name
,
ast
.
Attribute
)):
return
self
.
_get_full_name
(
node
)
return
str
(
node
)
def
lookup_scope
(
self
,
node
):
"""
Search the scope of the node.
Args:
node (AST): The ast node of the source code.
Returns:
scope, the scope of the node
"""
if
isinstance
(
node
,
pasta
.
base
.
scope
.
Scope
):
return
node
return
self
.
_root_scope
.
lookup_scope
(
node
)
mindinsight/mindconverter/common/exceptions.py
0 → 100644
浏览文件 @
bcdc61cc
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Define custom exception."""
from
enum
import
unique
from
mindinsight.utils.constant
import
ScriptConverterErrors
from
mindinsight.utils.exceptions
import
MindInsightException
@
unique
class
ConverterErrors
(
ScriptConverterErrors
):
"""Converter error codes."""
SCRIPT_NOT_SUPPORT
=
1
NODE_TYPE_NOT_SUPPORT
=
2
class
ScriptNotSupport
(
MindInsightException
):
"""The script can not support to process."""
def
__init__
(
self
,
msg
):
super
(
ScriptNotSupport
,
self
).
__init__
(
ConverterErrors
.
SCRIPT_NOT_SUPPORT
,
msg
,
http_code
=
400
)
class
NodeTypeNotSupport
(
MindInsightException
):
"""The astNode can not support to process."""
def
__init__
(
self
,
msg
):
super
(
NodeTypeNotSupport
,
self
).
__init__
(
ConverterErrors
.
NODE_TYPE_NOT_SUPPORT
,
msg
,
http_code
=
400
)
mindinsight/mindconverter/converter.py
浏览文件 @
bcdc61cc
...
@@ -13,464 +13,89 @@
...
@@ -13,464 +13,89 @@
# limitations under the License.
# limitations under the License.
# ============================================================================
# ============================================================================
"""converter module"""
"""converter module"""
import
copy
import
importlib
import
inspect
import
os
import
os
import
stat
import
stat
from
mindinsight.mindconverter.config
import
ALL_MAPPING
import
pasta
from
mindinsight.mindconverter.config
import
NN_LIST
from
mindinsight.mindconverter.config
import
ALL_TORCH_APIS
from
mindinsight.mindconverter.config
import
ALL_2P_LIST
from
mindinsight.mindconverter.config
import
UNSUPPORTED_WARN_INFOS
from
mindinsight.mindconverter.config
import
ALL_UNSUPPORTED
from
mindinsight.mindconverter.common.log
import
logger
from
mindinsight.mindconverter.forward_call
import
ForwardCall
LINE_NO_INDEX_DIFF
=
1
from
mindinsight.mindconverter.common.exceptions
import
ScriptNotSupport
from
mindinsight.mindconverter.common.log
import
logger
from
mindinsight.mindconverter.ast_edits
import
AstEditVisitor
class
Converter
:
class
Converter
:
"""Convert class"""
"""Convert class"""
convert_info
=
''
flags
=
os
.
O_WRONLY
|
os
.
O_CREAT
|
os
.
O_EXCL
flags
=
os
.
O_WRONLY
|
os
.
O_CREAT
|
os
.
O_EXCL
modes
=
stat
.
S_IWUSR
|
stat
.
S_IRUSR
modes
=
stat
.
S_IWUSR
|
stat
.
S_IRUSR
@
staticmethod
def
__init__
(
self
):
def
is_local_defined
(
obj
,
member
):
self
.
_tree
=
None
"""
self
.
_infile
=
None
Check if obj and member are both defined in the same source file.
self
.
_code_analyzer
=
None
self
.
_ast_editor
=
None
Args:
self
.
_report
=
[]
obj (Union[object, module]): A module or a class.
member (func): A function of obj.
Returns:
bool, True or False.
"""
srcfile
=
inspect
.
getsourcefile
(
obj
)
return
inspect
.
getsourcefile
(
member
)
==
srcfile
@
classmethod
def
is_valid_module
(
cls
,
obj
,
member
):
"""
Check if obj and member defined in same source file and member is inherited from torch.nn.Module.
Args:
obj (Union[object, module]): A module or a class.
member (func): A function.
Returns:
bool, True or False.
"""
if
inspect
.
isclass
(
member
):
is_subclass
=
member
.
__base__
.
__name__
in
[
'Module'
,
'Sequential'
,
'ModuleList'
,
'ModuleDict'
,
'ParameterList'
,
'ParameterDict'
]
return
is_subclass
and
cls
.
is_local_defined
(
obj
,
member
)
return
False
@
classmethod
def
is_valid_function
(
cls
,
obj
,
member
):
"""
Check if member is function and defined in the file same as obj.
Args:
obj (Union[object, module]: The obj.
member (func): The func.
Returns:
bool, True or False.
"""
return
inspect
.
isfunction
(
member
)
and
cls
.
is_local_defined
(
obj
,
member
)
@
staticmethod
def
find_left_parentheses
(
string
,
right
):
"""
Find index of the first left parenthesis.
Args:
string (str): A line of code.
right (int): The right index for string to find from.
Returns:
int, index of the first parenthesis.
Raises:
ValueError: If line of code doesn't contain any pair of `()` or `(` and `)` are not paired.
"""
if
string
[
right
]
!=
')'
:
raise
ValueError
(
'code [{}] at index {} not ")".'
.
format
(
string
,
right
))
stack
=
[]
for
i
in
range
(
right
,
-
1
,
-
1
):
if
string
[
i
]
==
')'
:
stack
.
append
(
')'
)
elif
string
[
i
]
==
'('
:
stack
.
pop
()
if
not
stack
:
return
i
raise
ValueError
(
"{} should contain ()"
.
format
(
string
))
@
staticmethod
def
find_right_parentheses
(
string
,
left
):
"""
Find first index of right parenthesis which make all left parenthesis make sense.
Args:
string (str): A line of code.
left (int): Start index of string to find from.
Returns:
int, index of the found right parenthesis.
Raises:
ValueError: If line of code doesn't contain any pair of `()` or `(` and `)` are not paired.
"""
stack
=
[]
for
i
in
range
(
left
,
len
(
string
)):
if
string
[
i
]
==
'('
:
stack
.
append
(
'('
)
elif
string
[
i
]
==
')'
:
stack
.
pop
()
if
not
stack
:
return
i
raise
ValueError
(
"{} should contain ()"
.
format
(
string
))
@
staticmethod
def
get_call_name
(
code
,
end
):
"""
Traverse code in a reversed function from index end and get the call name and start index of the call name,
if call name not found, return a null character string and -1
Args:
code (str): The str of code to find from.
end (int): Start index to find.
Returns:
tuple(str, int), one is founded api name if found, else a null character string, the other is start index
of founded api name, -1 if api name not found
"""
stack
=
[]
for
i
in
range
(
end
-
1
,
-
1
,
-
1
):
if
code
[
i
]
in
[
"("
,
"["
,
"{"
]:
if
stack
:
stack
.
pop
()
else
:
return
code
[
i
+
1
:
end
],
i
+
1
elif
code
[
i
]
in
[
")"
,
"]"
,
"}"
]:
stack
.
append
(
code
[
i
])
elif
stack
:
continue
elif
not
(
code
[
i
].
isalpha
()
or
code
[
i
].
isdigit
()
or
code
[
i
]
==
'_'
or
code
[
i
]
==
'.'
):
return
code
[
i
+
1
:
end
],
i
+
1
return
""
,
-
1
def
convert_api
(
self
,
code
,
start
,
api_name
=
""
):
"""
Convert api_name in code to MindSpore api with start as a start index, if api_name is a python api,
code will not convert.
Args:
code (str): The str code to convert.
start (int): The index of code to start convert from.
api_name (str): The api name to convert.
Returns:
str, the converted code.
int, index of converted api_name in code.
"""
# handle format like .shape(
if
api_name
.
startswith
(
'.'
):
call_name
,
new_start
=
self
.
get_call_name
(
code
,
start
)
if
start
==
-
1
or
call_name
==
"self"
:
return
code
,
start
+
1
else
:
call_name
=
api_name
new_start
=
start
# find full api expected to be converted. eg:expr="nn.Conv2d(1,2,3)" args_str="(1,2,3)"
left
=
code
.
find
(
"("
,
start
)
if
left
==
-
1
:
raise
ValueError
(
'"(" not found, {} should work with "("'
.
format
(
call_name
))
right
=
self
.
find_right_parentheses
(
code
,
left
)
end
=
right
expr
=
code
[
start
:
end
+
1
]
args_str
=
code
[
left
:
right
+
1
]
map_helper
=
ALL_MAPPING
[
api_name
]
new_expr
=
map_helper
.
convert
(
call_name
,
args_str
)
next_newline
=
code
.
find
(
"
\n
"
,
end
+
1
)
fill_num
=
(
expr
.
count
(
"
\n
"
)
-
new_expr
.
count
(
"
\n
"
))
if
next_newline
!=
-
1
:
code
=
code
[:
new_start
]
+
new_expr
+
code
[
end
+
1
:
next_newline
]
+
(
"
\n
"
*
fill_num
)
+
code
[
next_newline
:]
else
:
code
=
code
[:
new_start
]
+
new_expr
+
")"
+
(
"
\n
"
*
fill_num
)
+
code
[
end
+
2
:]
return
code
,
start
+
len
(
map_helper
.
ms_api
.
name
)
@
staticmethod
def
find_api
(
code
,
i
,
is_forward
):
"""
Find api name from code with a start index i, check api name ok with a is_forward condition.
Args:
code (str): The code from which to find api name.
i (int): The start index to find.
is_forward (bool): Check if the found api name ok.
Returns:
str, api name if find api name and check ok with is_forward condition, else a null character string.
"""
if
code
[
i
:].
startswith
(
"nn."
)
\
or
code
[
i
:].
startswith
(
"F."
)
\
or
code
[
i
:].
startswith
(
"torch."
)
\
or
code
[
i
:].
startswith
(
'.'
):
j
=
code
.
find
(
'('
,
i
)
if
j
!=
-
1
and
code
[
i
:
j
]
in
ALL_TORCH_APIS
:
api_name
=
code
[
i
:
j
]
if
(
not
is_forward
and
api_name
in
NN_LIST
)
or
(
is_forward
and
api_name
in
ALL_2P_LIST
):
return
api_name
return
""
def
convert_function
(
self
,
fun_name
,
fun
,
is_forward
):
"""
Convert a PyTorch function into MindSpore function.
Args:
fun_name (str): The str of function name.
fun (func): The function to convert.
is_forward (bool): If the function is defined in forward function in nn.Module in torch.
Returns:
dict, old code and converted code map if convert happens, else {}.
"""
_
,
line_no
=
inspect
.
getsourcelines
(
fun
)
logger
.
info
(
"Line %3d: start converting function %s()"
,
line_no
,
fun_name
)
code
=
inspect
.
getsource
(
fun
)
code_saved
=
copy
.
copy
(
code
)
i
=
0
while
i
<
len
(
code
):
api_name
=
self
.
find_api
(
code
,
i
,
is_forward
)
if
api_name
:
line_no1
=
line_no
+
code
[:
i
].
count
(
'
\n
'
)
if
api_name
in
ALL_MAPPING
:
logger
.
info
(
"Line %3d start converting API: %s"
,
line_no1
,
api_name
)
code
,
i
=
self
.
convert_api
(
code
,
i
,
api_name
)
self
.
convert_info
+=
"[Convert][Line{:3d}] {} is converted.
\n
"
.
format
(
line_no1
,
api_name
)
continue
if
api_name
in
ALL_UNSUPPORTED
:
warn_info
=
". "
+
UNSUPPORTED_WARN_INFOS
[
api_name
]
if
api_name
in
UNSUPPORTED_WARN_INFOS
else
""
logger
.
warning
(
"Line %3d: found unsupported API: %s%s"
,
line_no1
,
api_name
,
warn_info
)
self
.
convert_info
+=
"[Unconvert][Line{:3d}] {} didn't convert{}
\n
"
.
format
(
line_no1
,
api_name
,
warn_info
)
i
+=
1
return
{
code_saved
:
code
}
if
code_saved
!=
code
else
{}
@
staticmethod
def
judge_forward
(
name
,
forward_list
):
"""
Check if function is a forward function.
Args:
name (str): The function name.
forward_list (set): A set of forward function.
Returns:
bool, True or False
"""
is_forward
=
name
in
forward_list
or
name
.
split
(
"."
)[
-
1
]
==
"forward"
if
is_forward
:
logger
.
debug
(
"%s is a forward function"
,
name
)
return
is_forward
def
convert_module
(
self
,
module_name
,
module
,
forward_list
):
"""
Convert a PyTorch module code into MindSpore module code.
Args:
module_name (str): The module's name.
module (module): The module to convert.
forward_list (set): A set of forward function.
Returns:
dict, map of old code and converted code.
"""
_
,
line_no
=
inspect
.
getsourcelines
(
module
)
logger
.
info
(
"Line {:3d}: start converting nn.Module {}"
.
format
(
line_no
,
module_name
))
mapped
=
{}
for
name
,
member
in
inspect
.
getmembers
(
module
):
if
self
.
is_valid_function
(
module
,
member
):
is_forward
=
self
.
judge_forward
(
"{}.{}"
.
format
(
module_name
,
name
),
forward_list
)
mapped
.
update
(
self
.
convert_function
(
name
,
member
,
is_forward
))
return
mapped
def
get_mapping
(
self
,
import_mod
,
forward_list
):
"""
Convert code of a module and get mapping of old code and convert code.
Args:
import_mod (module): The module to convert.
forward_list (set): A set of forward function.
Returns:
dict, mapping for old code and converted code of the module
"""
mapping
=
{}
tasks
=
[]
for
name
,
member
in
inspect
.
getmembers
(
import_mod
):
if
self
.
is_valid_module
(
import_mod
,
member
):
_
,
line_no
=
inspect
.
getsourcelines
(
member
)
tasks
.
append
((
line_no
,
self
.
convert_module
,
(
name
,
member
,
forward_list
)))
elif
self
.
is_valid_function
(
import_mod
,
member
):
_
,
line_no
=
inspect
.
getsourcelines
(
member
)
is_forward
=
self
.
judge_forward
(
"{}.{}"
.
format
(
import_mod
,
name
),
forward_list
)
tasks
.
append
((
line_no
,
self
.
convert_function
,
(
name
,
member
,
is_forward
)))
tasks
.
sort
()
for
_
,
convert_fun
,
args
in
tasks
:
mapping
.
update
(
convert_fun
(
*
args
))
return
mapping
@
staticmethod
def
get_code_start_line_num
(
source_lines
):
"""
Get the start code line number exclude comments.
Args:
source_lines (list[str]): Split results of original code.
Returns:
int, the start line number.
"""
stack
=
[]
index
=
0
for
i
,
line
in
enumerate
(
source_lines
):
if
line
.
strip
().
startswith
(
'#'
):
continue
if
line
.
strip
().
startswith
(
'"""'
):
if
not
line
.
endswith
(
'"""
\n
'
):
stack
.
append
(
'"""'
)
continue
if
line
.
strip
().
startswith
(
"'''"
):
if
not
line
.
endswith
(
"'''
\n
"
):
stack
.
append
(
"'''"
)
continue
if
line
.
endswith
(
'"""
\n
'
)
or
line
.
endswith
(
"'''
\n
"
):
stack
.
pop
()
continue
if
line
.
strip
()
!=
''
and
not
stack
:
index
=
i
break
return
index
def
update_code_and_convert_info
(
self
,
code
,
mapping
):
"""
Replace code according to mapping, and update convert info.
Args:
code (str): The code to replace.
mapping (dict): Mapping for original code and the replaced code.
Returns:
str, the replaced code.
"""
for
key
,
value
in
mapping
.
items
():
code
=
code
.
replace
(
key
,
value
)
source_lines
=
code
.
splitlines
(
keepends
=
True
)
start_line_number
=
self
.
get_code_start_line_num
(
source_lines
)
add_import_infos
=
[
'import mindspore
\n
'
,
'import mindspore.nn as nn
\n
'
,
'import mindspore.ops.operations as P
\n
'
]
for
i
,
add_import_info
in
enumerate
(
add_import_infos
):
source_lines
.
insert
(
start_line_number
+
i
,
add_import_info
)
self
.
convert_info
+=
'[Add Import] {}.
\n
'
.
format
(
add_import_info
.
strip
())
insert_count
=
len
(
add_import_infos
)
line_diff
=
insert_count
-
LINE_NO_INDEX_DIFF
for
i
in
range
(
start_line_number
+
insert_count
,
len
(
source_lines
)):
def
convert
(
self
,
infile
,
output_dir
,
report_dir
):
line
=
source_lines
[
i
]
if
(
line
.
startswith
(
'from torch'
)
and
'import'
in
line
)
or
line
.
startswith
(
'import torch'
):
new_line
=
'# '
+
line
source_lines
[
i
]
=
new_line
self
.
convert_info
+=
'[Annotate][Line{:3d}] {} is annotated.
\n
'
.
format
(
i
-
line_diff
,
line
.
strip
())
if
line
.
strip
().
startswith
(
'class'
)
and
'(nn.Module)'
in
line
:
new_line
=
line
.
replace
(
'nn.Module'
,
'nn.Cell'
)
source_lines
[
i
]
=
new_line
self
.
convert_info
+=
'[Convert][Line{:3d}] nn.Module is converted.
\n
'
.
format
(
i
-
line_diff
)
if
line
.
strip
().
startswith
(
'def forward('
):
new_line
=
line
.
replace
(
'forward'
,
'construct'
)
source_lines
[
i
]
=
new_line
self
.
convert_info
+=
'[Convert][Line{:3d}] forward is converted.
\n
'
.
format
(
i
-
line_diff
)
if
'nn.Linear'
in
line
:
new_line
=
line
.
replace
(
'nn.Linear'
,
'nn.Dense'
)
source_lines
[
i
]
=
new_line
self
.
convert_info
+=
'[Convert][Line{:3d}] nn.Linear is converted.
\n
'
.
format
(
i
-
line_diff
)
if
'(nn.Sequential)'
in
line
:
new_line
=
line
.
replace
(
'nn.Sequential'
,
'nn.SequentialCell'
)
source_lines
[
i
]
=
new_line
self
.
convert_info
+=
'[Convert][Line{:3d}] nn.Sequential is converted.
\n
'
.
format
(
i
-
line_diff
)
if
'nn.init.'
in
line
:
new_line
=
line
.
replace
(
'nn.init'
,
'pass # nn.init'
)
source_lines
[
i
]
=
new_line
self
.
convert_info
+=
'[Annotate][Line{:3d}] {} is annotated.
\n
'
.
format
(
i
-
line_diff
,
'nn.init'
)
code
=
''
.
join
(
source_lines
)
return
code
def
convert
(
self
,
import_name
,
output_dir
,
report_dir
):
"""
"""
Convert a module's code, code converted will be save in output_dir, and a report will be save in report_dir.
Convert a module's code, code converted will be save in output_dir, and a report will be save in report_dir.
Args:
Args:
i
mport_name (str): The module from which to import the module
to convert.
i
nfile (str): The script
to convert.
output_dir (str): The path to save converted file.
output_dir (str): The path to save converted file.
report_dir (str): The path to save report file.
report_dir (str): The path to save report file.
"""
"""
logger
.
info
(
"Start converting %s"
,
import_name
)
in_file_split
=
_path_split
(
infile
)
start_info
=
'[Start Convert]
\n
'
in_file_split
[
-
1
],
_
=
_get_name_ext
(
in_file_split
[
-
1
])
module_info
=
'The module is {}.
\n
'
.
format
(
import_name
)
module_name
=
'.'
.
join
(
in_file_split
)
with
open
(
infile
,
'r'
)
as
file
:
import_mod
=
importlib
.
import_module
(
import_name
)
content
=
''
.
join
(
file
.
readlines
())
srcfile
=
inspect
.
getsourcefile
(
import_mod
)
logger
.
info
(
"Script file is %s"
,
srcfile
)
self
.
_infile
=
infile
self
.
_tree
=
pasta
.
parse
(
content
)
forward_list
=
set
(
ForwardCall
(
srcfile
).
calls
)
self
.
_report
.
clear
()
logger
.
debug
(
"Forward_list: %s"
,
forward_list
)
try
:
logger
.
info
(
"Script file is %s"
,
infile
)
# replace python function under nn.Module
logger
.
info
(
"Start converting %s"
,
module_name
)
mapping
=
self
.
get_mapping
(
import_mod
,
forward_list
)
self
.
_report
.
append
(
'[Start Convert]'
)
code
=
inspect
.
getsource
(
import_mod
)
self
.
_ast_editor
=
AstEditVisitor
()
code
=
self
.
update_code_and_convert_info
(
code
,
mapping
)
self
.
_ast_editor
.
process
(
self
.
_tree
)
convert_info_split
=
self
.
convert_info
.
splitlines
(
keepends
=
True
)
self
.
_report
.
extend
(
self
.
_ast_editor
.
get_logs
())
convert_info_split
=
sorted
(
convert_info_split
)
self
.
_report
.
append
(
'[Convert Over]'
)
convert_info_split
.
insert
(
0
,
start_info
)
dest_file
=
os
.
path
.
join
(
output_dir
,
os
.
path
.
basename
(
infile
))
convert_info_split
.
insert
(
1
,
module_info
)
convert_info_split
.
append
(
'[Convert Over]'
)
self
.
convert_info
=
''
.
join
(
convert_info_split
)
dest_file
=
os
.
path
.
join
(
output_dir
,
os
.
path
.
basename
(
srcfile
))
with
os
.
fdopen
(
os
.
open
(
dest_file
,
self
.
flags
,
self
.
modes
),
'w'
)
as
file
:
with
os
.
fdopen
(
os
.
open
(
dest_file
,
self
.
flags
,
self
.
modes
),
'w'
)
as
file
:
file
.
write
(
code
)
file
.
write
(
pasta
.
dump
(
self
.
_tree
)
)
logger
.
info
(
"Convert success. Result is wrote to %s."
,
dest_file
)
logger
.
info
(
"Convert success. Result is wrote to %s."
,
dest_file
)
except
ScriptNotSupport
as
error
:
self
.
_report
.
append
(
'[ScriptNotSupport] '
+
error
.
message
)
self
.
_report
.
append
(
'[Convert failed]'
)
raise
error
except
Exception
as
error
:
self
.
_report
.
clear
()
raise
error
finally
:
if
self
.
_report
:
dest_report_file
=
os
.
path
.
join
(
report_dir
,
dest_report_file
=
os
.
path
.
join
(
report_dir
,
'_'
.
join
(
os
.
path
.
basename
(
src
file
).
split
(
'.'
)[:
-
1
])
+
'_report.txt'
)
'_'
.
join
(
os
.
path
.
basename
(
in
file
).
split
(
'.'
)[:
-
1
])
+
'_report.txt'
)
with
os
.
fdopen
(
os
.
open
(
dest_report_file
,
self
.
flags
,
self
.
modes
),
'a'
)
as
file
:
with
os
.
fdopen
(
os
.
open
(
dest_report_file
,
self
.
flags
,
self
.
modes
),
'a'
)
as
file
:
file
.
write
(
self
.
convert_info
)
file
.
write
(
'
\n
'
.
join
(
self
.
_report
)
)
logger
.
info
(
"Convert report is saved in %s"
,
dest_report_file
)
logger
.
info
(
"Convert report is saved in %s"
,
dest_report_file
)
@
staticmethod
def
convert_api
(
source_code
):
"""
Convert api_name in code to MindSpore api, if api_name is a python api, code will not convert.
Args:
source_code (ast.Call): The ast node to convert.
Returns:
str, the converted code.
"""
ast_node
=
pasta
.
parse
(
source_code
).
body
[
0
].
value
check_context
=
False
replaced_code
=
AstEditVisitor
().
mapping_api
(
ast_node
,
check_context
)
return
replaced_code
def
_get_name_ext
(
file
):
def
_get_name_ext
(
file
):
"""
"""
...
@@ -514,14 +139,6 @@ def main(files_config):
...
@@ -514,14 +139,6 @@ def main(files_config):
files_config (dict): The config of files which to convert.
files_config (dict): The config of files which to convert.
"""
"""
convert_ins
=
Converter
()
convert_ins
=
Converter
()
root_path
=
files_config
[
'root_path'
]
in_files
=
files_config
[
'in_files'
]
in_files
=
files_config
[
'in_files'
]
for
in_file
in
in_files
:
for
in_file
in
in_files
:
in_file_split
=
_path_split
(
in_file
[
len
(
root_path
):])
convert_ins
.
convert
(
in_file
,
files_config
[
'outfile_dir'
],
files_config
[
'report_dir'
])
in_file_split
[
-
1
],
_
=
_get_name_ext
(
in_file_split
[
-
1
])
module_name
=
'.'
.
join
(
in_file_split
)
convert_ins
.
convert
(
module_name
,
files_config
[
'outfile_dir'
],
files_config
[
'report_dir'
])
in_module
=
files_config
.
get
(
'in_module'
)
if
in_module
:
convert_ins
.
convert
(
in_module
,
files_config
[
'outfile_dir'
],
files_config
[
'report_dir'
])
mindinsight/mindconverter/forward_call.py
浏览文件 @
bcdc61cc
...
@@ -14,7 +14,8 @@
...
@@ -14,7 +14,8 @@
# ============================================================================
# ============================================================================
"""Find out forward functions of script file"""
"""Find out forward functions of script file"""
import
ast
import
ast
import
os
import
pasta
class
ForwardCall
(
ast
.
NodeVisitor
):
class
ForwardCall
(
ast
.
NodeVisitor
):
...
@@ -24,73 +25,80 @@ class ForwardCall(ast.NodeVisitor):
...
@@ -24,73 +25,80 @@ class ForwardCall(ast.NodeVisitor):
Find the sub functions called by the forward function in the script file.
Find the sub functions called by the forward function in the script file.
"""
"""
def
__init__
(
self
,
filenam
e
):
def
__init__
(
self
,
ast_tre
e
):
self
.
filename
=
filenam
e
self
.
_tree
=
ast_tre
e
self
.
module_name
=
os
.
path
.
basename
(
filename
).
replace
(
'.py'
,
''
)
self
.
_name_stack
=
[]
self
.
name
_stack
=
[]
self
.
_forward
_stack
=
[]
self
.
forward_stack
=
[]
self
.
calls
=
{}
# key is function name, value is forward function ast node.
self
.
calls
=
set
()
self
.
_function_list
=
{}
# key is function name, value is function ast node.
self
.
process
()
self
.
process
()
def
process
(
self
):
def
process
(
self
):
"""Parse the python source file to find the forward functions."""
"""visit ast tree to find the forward functions."""
with
open
(
self
.
filename
,
'rt'
,
encoding
=
'utf-8'
)
as
file
:
self
.
visit
(
self
.
_tree
)
content
=
file
.
read
()
# first visit to find out all functions, so restores all variables except _function_list
self
.
visit
(
ast
.
parse
(
content
,
self
.
filename
))
self
.
_name_stack
.
clear
()
self
.
_forward_stack
.
clear
()
self
.
calls
.
clear
()
self
.
visit
(
self
.
_tree
)
def
get_current_namespace
(
self
):
def
get_current_namespace
(
self
):
"""Get the namespace when visit the AST node"""
"""Get the namespace when visit the AST node"""
namespace
=
'.'
.
join
(
self
.
name_stack
)
namespace
=
'.'
.
join
(
self
.
_
name_stack
)
return
namespace
return
namespace
@
classmethod
@
classmethod
def
get_ast_node_name
(
cls
,
node
):
def
get_call_name
(
cls
,
node
):
"""Get AST node name."""
"""Get functional call name."""
if
isinstance
(
node
,
ast
.
Attribute
):
if
not
isinstance
(
node
,
ast
.
Call
):
return
f
'
{
cls
.
get_ast_node_name
(
node
.
value
)
}
.
{
node
.
attr
}
'
return
None
if
isinstance
(
node
,
ast
.
Name
):
return
node
.
id
return
node
return
pasta
.
dump
(
node
.
func
)
def
visit_ClassDef
(
self
,
node
):
def
visit_ClassDef
(
self
,
node
):
"""Callback function when visit AST tree"""
"""Callback function when visit AST tree"""
self
.
name_stack
.
append
(
node
.
name
)
self
.
_
name_stack
.
append
(
node
.
name
)
self
.
generic_visit
(
node
)
self
.
generic_visit
(
node
)
self
.
name_stack
.
pop
()
self
.
_
name_stack
.
pop
()
def
visit_FunctionDef
(
self
,
node
):
def
visit_FunctionDef
(
self
,
node
):
"""Callback function when visit AST tree"""
"""Callback function when visit AST tree"""
namespace
=
self
.
get_current_namespace
()
if
namespace
:
func_name
=
f
'
{
namespace
}
.
{
node
.
name
}
'
else
:
func_name
=
node
.
name
func_name
=
f
'
{
self
.
get_current_namespace
()
}
.
{
node
.
name
}
'
func_name
=
f
'
{
self
.
get_current_namespace
()
}
.
{
node
.
name
}
'
is_in_chain
=
func_name
in
self
.
calls
or
node
.
name
==
'forward'
is_in_chain
=
func_name
in
self
.
calls
or
node
.
name
==
'forward'
if
is_in_chain
:
if
is_in_chain
:
self
.
forward_stack
.
append
(
func_name
)
self
.
_
forward_stack
.
append
(
func_name
)
if
node
.
name
==
'forward'
:
if
node
.
name
==
'forward'
:
self
.
calls
.
add
(
func_name
)
self
.
calls
.
update
({
func_name
:
node
}
)
self
.
_function_list
.
update
({
func_name
:
node
})
self
.
generic_visit
(
node
)
self
.
generic_visit
(
node
)
if
is_in_chain
:
if
is_in_chain
:
self
.
forward_stack
.
pop
()
self
.
_
forward_stack
.
pop
()
def
visit_Call
(
self
,
node
):
def
visit_Call
(
self
,
node
):
"""Callback function when visit AST tree"""
"""Callback function when visit AST tree"""
for
arg
in
node
.
args
:
for
arg
in
node
.
args
:
self
.
visit
(
arg
)
self
.
visit
(
arg
)
for
k
w
in
node
.
keywords
:
for
k
eyword
in
node
.
keywords
:
self
.
visit
(
k
w
.
value
)
self
.
visit
(
k
eyword
.
value
)
func_name
=
self
.
get_
ast_node_name
(
node
.
func
)
func_name
=
self
.
get_
call_name
(
node
)
if
isinstance
(
node
.
func
,
ast
.
Name
):
if
isinstance
(
node
.
func
,
ast
.
Name
):
if
func_name
not
in
[
'super'
,
'str'
,
'repr'
]:
if
func_name
not
in
[
'super'
,
'str'
,
'repr'
]:
if
self
.
forward_stack
:
if
self
.
_
forward_stack
:
self
.
calls
.
add
(
func_name
)
self
.
calls
.
update
({
func_name
:
self
.
_function_list
.
get
(
func_name
)}
)
self
.
visit
(
node
.
func
)
self
.
visit
(
node
.
func
)
else
:
else
:
if
self
.
forward_stack
:
if
self
.
_forward_stack
:
if
'self'
in
func_name
:
if
func_name
.
startswith
(
'self.'
):
self
.
calls
.
add
(
f
'
{
self
.
get_current_namespace
()
}
.
{
func_name
.
split
(
"."
)[
-
1
]
}
'
)
whole_name
=
f
'
{
self
.
get_current_namespace
()
}
.
{
func_name
.
split
(
"."
)[
-
1
]
}
'
self
.
calls
.
update
({
whole_name
:
self
.
_function_list
.
get
(
whole_name
)})
else
:
else
:
self
.
calls
.
add
(
func_name
)
self
.
calls
.
update
({
func_name
:
self
.
_function_list
.
get
(
func_name
)}
)
self
.
visit
(
node
.
func
)
self
.
visit
(
node
.
func
)
mindinsight/utils/constant.py
浏览文件 @
bcdc61cc
...
@@ -30,6 +30,7 @@ class MindInsightModules(Enum):
...
@@ -30,6 +30,7 @@ class MindInsightModules(Enum):
LINEAGEMGR
=
2
LINEAGEMGR
=
2
DATAVISUAL
=
5
DATAVISUAL
=
5
PROFILERMGR
=
6
PROFILERMGR
=
6
SCRIPTCONVERTER
=
7
class
GeneralErrors
(
Enum
):
class
GeneralErrors
(
Enum
):
...
@@ -69,3 +70,7 @@ class DataVisualErrors(Enum):
...
@@ -69,3 +70,7 @@ class DataVisualErrors(Enum):
SCALAR_NOT_EXIST
=
14
SCALAR_NOT_EXIST
=
14
HISTOGRAM_NOT_EXIST
=
15
HISTOGRAM_NOT_EXIST
=
15
TRAIN_JOB_DETAIL_NOT_IN_CACHE
=
16
TRAIN_JOB_DETAIL_NOT_IN_CACHE
=
16
class
ScriptConverterErrors
(
Enum
):
"""Enum definition for mindconverter errors."""
tests/ut/mindconverter/test_converter.py
浏览文件 @
bcdc61cc
...
@@ -22,380 +22,201 @@ class TestConverter:
...
@@ -22,380 +22,201 @@ class TestConverter:
converter_ins
=
Converter
()
converter_ins
=
Converter
()
def
test_judge_forward
(
self
):
"""test judge_forward"""
name1
=
'conv1'
forward_list
=
{
'conv1'
,
'relu'
}
result1
=
self
.
converter_ins
.
judge_forward
(
name1
,
forward_list
)
assert
result1
is
True
name2
=
'self.forward'
result2
=
self
.
converter_ins
.
judge_forward
(
name2
,
forward_list
)
assert
result2
is
True
def
test_find_left_parentheses
(
self
):
"""test find_left_parentheses"""
code
=
'''nn.Sequential(nn.Conv2d(in_dim, 6, 5, stride=1, padding=0, ),
nn.ReLU(),
nn.ReLU(True),
nn.MaxPool2d(2, 2),
nn.Conv2d(6, 16, 5, stride=1, padding=0),
nn.ReLU(inplace=False),
nn.MaxPool2d(2, 2))'''
right_index
=
len
(
code
)
-
1
left_index
=
code
.
index
(
'nn.Conv2d'
)
result
=
self
.
converter_ins
.
find_left_parentheses
(
code
,
right_index
)
assert
result
==
left_index
-
1
def
test_find_api
(
self
):
"""test find_api"""
code
=
'''nn.Sequential(nn.Conv2d(in_dim, 6, 5, stride=1, padding=0, ),
nn.ReLU(),
nn.ReLU(True),
nn.MaxPool2d(2, 2), # TODO padding
nn.Conv2d(6, 16, 5, stride=1, padding=0),
nn.ReLU(inplace=False),
nn.MaxPool2d(2, 2))'''
index
=
0
is_forward
=
False
result
=
self
.
converter_ins
.
find_api
(
code
,
index
,
is_forward
)
assert
result
==
'nn.Sequential'
def
test_get_call_name
(
self
):
"""test get_call_name"""
code
=
'''nn.Sequential(nn.Conv2d(in_dim, 6, 5, stride=1, padding=0))'''
end
=
len
(
code
)
call_name
,
index
=
self
.
converter_ins
.
get_call_name
(
code
,
end
)
assert
call_name
==
''
assert
index
==
-
1
def
test_find_right_parentheses
(
self
):
"""test find_right_parentheses"""
code
=
'''nn.Sequential(nn.Conv2d(in_dim, 6, 5, stride=1, padding=0, ),
nn.ReLU(),
nn.ReLU(True),
nn.MaxPool2d(2, 2), # TODO padding
nn.Conv2d(6, 16, 5, stride=1, padding=0),
nn.ReLU(inplace=False),
nn.MaxPool2d(2, 2))'''
left_index
=
0
result
=
self
.
converter_ins
.
find_right_parentheses
(
code
,
left_index
)
assert_index
=
len
(
code
)
-
1
assert
result
==
assert_index
# test convert_api with nn ops
# test convert_api with nn ops
def
test_convert_api_nn_layernorm
(
self
):
def
test_convert_api_nn_layernorm
(
self
):
"""Test convert_api function work ok when convert api nn.LayerNorm"""
"""Test convert_api function work ok when convert api nn.LayerNorm"""
code
=
"""
code
=
"nn.LayerNorm((5, 10, 10), elementwise_affine=False)"
def __init__(self, num_classes=1000):
self.features = nn.SequentialCell([
nn.LayerNorm((5, 10, 10), elementwise_affine=False),
nn.ReLU(inplace=False)
])
"""
api_name
=
'nn.LayerNorm'
api_name
=
'nn.LayerNorm'
start
=
code
.
find
(
api_name
)
layer_norm_info
=
NN_MAPPING
.
get
(
api_name
)
layer_norm_info
=
NN_MAPPING
.
get
(
api_name
)
expected_ms_api_name
=
'nn.LayerNorm'
expected_ms_api_name
=
'nn.LayerNorm'
epsilon
=
layer_norm_info
.
pt_api
.
params
.
get
(
'eps'
)
epsilon
=
layer_norm_info
.
pt_api
.
params
.
get
(
'eps'
)
replaced_code
,
new_start
=
self
.
converter_ins
.
convert_api
(
code
,
start
,
api_nam
e
)
replaced_code
=
self
.
converter_ins
.
convert_api
(
cod
e
)
assert
replaced_code
==
code
.
replace
(
'nn.LayerNorm((5, 10, 10), elementwise_affine=False)'
,
assert
replaced_code
==
code
.
replace
(
'nn.LayerNorm((5, 10, 10), elementwise_affine=False)'
,
'{}(normalized_shape=(5, 10, 10), epsilon={})'
.
format
(
'{}(normalized_shape=(5, 10, 10), epsilon={})'
.
format
(
expected_ms_api_name
,
epsilon
))
expected_ms_api_name
,
epsilon
))
assert
new_start
==
start
+
len
(
expected_ms_api_name
)
def
test_convert_api_nn_leaky_relu
(
self
):
def
test_convert_api_nn_leaky_relu
(
self
):
"""Test convert_api function work ok when convert api nn.LeakyReLU"""
"""Test convert_api function work ok when convert api nn.LeakyReLU"""
code
=
"""
code
=
"nn.LeakyReLU(0.3)"
def __init__(self, num_classes=1000):
self.features = nn.SequentialCell([
nn.LayerNorm((5, 10, 10), elementwise_affine=False),
nn.LeakyReLU(0.3)])
"""
api_name
=
'nn.LeakyReLU'
start
=
code
.
find
(
api_name
)
expected_ms_api_name
=
'nn.LeakyReLU'
expected_ms_api_name
=
'nn.LeakyReLU'
replaced_code
,
new_start
=
self
.
converter_ins
.
convert_api
(
code
,
start
,
api_nam
e
)
replaced_code
=
self
.
converter_ins
.
convert_api
(
cod
e
)
assert
replaced_code
==
code
.
replace
(
'nn.LeakyReLU(0.3)'
,
assert
replaced_code
==
code
.
replace
(
'nn.LeakyReLU(0.3)'
,
'{}(alpha=0.3)'
.
format
(
expected_ms_api_name
))
'{}(alpha=0.3)'
.
format
(
expected_ms_api_name
))
assert
new_start
==
start
+
len
(
expected_ms_api_name
)
def
test_convert_api_nn_prelu
(
self
):
def
test_convert_api_nn_prelu
(
self
):
"""Test convert_api function work ok when convert api nn.PReLU"""
"""Test convert_api function work ok when convert api nn.PReLU"""
code
=
"""
code
=
"nn.PReLU()(input)"
input = torch.randn(2, 3, 5)
nn.PReLU()(input)
"""
api_name
=
'nn.PReLU'
start
=
code
.
find
(
api_name
)
expected_ms_api_name
=
'nn.PReLU'
expected_ms_api_name
=
'nn.PReLU'
replaced_code
,
new_start
=
self
.
converter_ins
.
convert_api
(
code
,
start
,
api_nam
e
)
replaced_code
=
self
.
converter_ins
.
convert_api
(
cod
e
)
assert
replaced_code
==
code
.
replace
(
'nn.PReLU()(input)'
,
assert
replaced_code
==
code
.
replace
(
'nn.PReLU()(input)'
,
'{}()(input)'
.
format
(
expected_ms_api_name
))
'{}()(input)'
.
format
(
expected_ms_api_name
))
assert
new_start
==
start
+
len
(
expected_ms_api_name
)
def
test_convert_api_nn_softmax
(
self
):
def
test_convert_api_nn_softmax
(
self
):
"""Test convert_api function work ok when convert api nn.Softmax"""
"""Test convert_api function work ok when convert api nn.Softmax"""
code
=
"""
code
=
"nn.Softmax(dim=1)"
nn.Softmax(dim=1)(input)
"""
api_name
=
'nn.Softmax'
expected_ms_api_name
=
'nn.Softmax'
expected_ms_api_name
=
'nn.Softmax'
start
=
code
.
find
(
api_name
)
replaced_code
,
new_start
=
self
.
converter_ins
.
convert_api
(
code
,
start
,
api_name
)
replaced_code
=
self
.
converter_ins
.
convert_api
(
code
)
assert
replaced_code
==
code
.
replace
(
'nn.Softmax(dim=1)(input)'
,
assert
replaced_code
==
code
.
replace
(
'nn.Softmax(dim=1)'
,
'{}(axis=1)(input)'
.
format
(
expected_ms_api_name
))
'{}(axis=1)'
.
format
(
expected_ms_api_name
))
assert
new_start
==
start
+
len
(
expected_ms_api_name
)
# test convert_api with torch dot ops
# test convert_api with torch dot ops
def
test_convert_api_torch_dot_abs
(
self
):
def
test_convert_api_torch_dot_abs
(
self
):
"""Test convert_api function work ok when convert api torch.abs"""
"""Test convert_api function work ok when convert api torch.abs"""
code
=
"""
code
=
"torch.abs(input)"
torch.abs(input)
"""
api_name
=
'torch.abs'
start
=
code
.
find
(
api_name
)
expected_ms_api_name
=
'P.Abs'
expected_ms_api_name
=
'P.Abs'
replaced_code
,
new_start
=
self
.
converter_ins
.
convert_api
(
code
,
start
,
api_nam
e
)
replaced_code
=
self
.
converter_ins
.
convert_api
(
cod
e
)
assert
replaced_code
==
code
.
replace
(
'torch.abs(input)'
,
assert
replaced_code
==
code
.
replace
(
'torch.abs(input)'
,
'{}()(input)'
.
format
(
expected_ms_api_name
))
'{}()(input)'
.
format
(
expected_ms_api_name
))
assert
new_start
==
start
+
len
(
expected_ms_api_name
)
def
test_convert_api_torch_dot_acos
(
self
):
def
test_convert_api_torch_dot_acos
(
self
):
"""Test convert_api function work ok when convert api torch.acos"""
"""Test convert_api function work ok when convert api torch.acos"""
code
=
"""
code
=
"torch.acos(input)"
torch.acos(input)
"""
api_name
=
'torch.acos'
start
=
code
.
find
(
api_name
)
expected_ms_api_name
=
'P.ACos'
expected_ms_api_name
=
'P.ACos'
replaced_code
,
new_start
=
self
.
converter_ins
.
convert_api
(
code
,
start
,
api_nam
e
)
replaced_code
=
self
.
converter_ins
.
convert_api
(
cod
e
)
assert
replaced_code
==
code
.
replace
(
'torch.acos(input)'
,
assert
replaced_code
==
code
.
replace
(
'torch.acos(input)'
,
'{}()(input)'
.
format
(
expected_ms_api_name
))
'{}()(input)'
.
format
(
expected_ms_api_name
))
assert
new_start
==
start
+
len
(
expected_ms_api_name
)
def
test_convert_api_torch_dot_cos
(
self
):
def
test_convert_api_torch_dot_cos
(
self
):
"""Test convert_api function work ok when convert api torch.cos"""
"""Test convert_api function work ok when convert api torch.cos"""
code
=
"""
code
=
"torch.cos(input)"
torch.cos(input)
"""
api_name
=
'torch.cos'
expected_ms_api_name
=
'P.Cos'
expected_ms_api_name
=
'P.Cos'
start
=
code
.
find
(
api_name
)
replaced_code
,
new_start
=
self
.
converter_ins
.
convert_api
(
code
,
start
,
api_nam
e
)
replaced_code
=
self
.
converter_ins
.
convert_api
(
cod
e
)
assert
replaced_code
==
code
.
replace
(
'torch.cos(input)'
,
assert
replaced_code
==
code
.
replace
(
'torch.cos(input)'
,
'{}()(input)'
.
format
(
expected_ms_api_name
))
'{}()(input)'
.
format
(
expected_ms_api_name
))
assert
new_start
==
start
+
len
(
expected_ms_api_name
)
def
test_convert_api_torch_dot_exp
(
self
):
def
test_convert_api_torch_dot_exp
(
self
):
"""Test convert_api function work ok when convert api torch.exp"""
"""Test convert_api function work ok when convert api torch.exp"""
code
=
"""
code
=
"torch.exp(input)"
torch.exp(input)
"""
api_name
=
'torch.exp'
expected_ms_api_name
=
'P.Exp'
expected_ms_api_name
=
'P.Exp'
start
=
code
.
find
(
api_name
)
replaced_code
,
new_start
=
self
.
converter_ins
.
convert_api
(
code
,
start
,
api_nam
e
)
replaced_code
=
self
.
converter_ins
.
convert_api
(
cod
e
)
assert
replaced_code
==
code
.
replace
(
'torch.exp(input)'
,
assert
replaced_code
==
code
.
replace
(
'torch.exp(input)'
,
'{}()(input)'
.
format
(
expected_ms_api_name
))
'{}()(input)'
.
format
(
expected_ms_api_name
))
assert
new_start
==
start
+
len
(
expected_ms_api_name
)
def
test_convert_api_torch_dot_log
(
self
):
def
test_convert_api_torch_dot_log
(
self
):
"""Test convert_api function work ok when convert api torch.log"""
"""Test convert_api function work ok when convert api torch.log"""
code
=
"""
code
=
"torch.log(input)"
torch.log(input)
"""
api_name
=
'torch.log'
expected_ms_api_name
=
'P.Log'
expected_ms_api_name
=
'P.Log'
start
=
code
.
find
(
api_name
)
replaced_code
,
new_start
=
self
.
converter_ins
.
convert_api
(
code
,
start
,
api_nam
e
)
replaced_code
=
self
.
converter_ins
.
convert_api
(
cod
e
)
assert
replaced_code
==
code
.
replace
(
'torch.log(input)'
,
assert
replaced_code
==
code
.
replace
(
'torch.log(input)'
,
'{}()(input)'
.
format
(
expected_ms_api_name
))
'{}()(input)'
.
format
(
expected_ms_api_name
))
assert
new_start
==
start
+
len
(
expected_ms_api_name
)
def
test_convert_api_torch_dot_pow
(
self
):
def
test_convert_api_torch_dot_pow
(
self
):
"""Test convert_api function work ok when convert api torch.pow"""
"""Test convert_api function work ok when convert api torch.pow"""
code
=
"""
code
=
"torch.pow(a, exp)"
torch.pow(a, exp)
"""
api_name
=
'torch.pow'
expected_ms_api_name
=
'P.Pow'
expected_ms_api_name
=
'P.Pow'
start
=
code
.
find
(
api_name
)
replaced_code
,
new_start
=
self
.
converter_ins
.
convert_api
(
code
,
start
,
api_nam
e
)
replaced_code
=
self
.
converter_ins
.
convert_api
(
cod
e
)
assert
replaced_code
==
code
.
replace
(
'torch.pow(a, exp)'
,
assert
replaced_code
==
code
.
replace
(
'torch.pow(a, exp)'
,
'{}()(a, exp)'
.
format
(
expected_ms_api_name
))
'{}()(a, exp)'
.
format
(
expected_ms_api_name
))
assert
new_start
==
start
+
len
(
expected_ms_api_name
)
def
test_convert_api_torch_dot_div
(
self
):
def
test_convert_api_torch_dot_div
(
self
):
"""Test convert_api function work ok when convert api torch.div"""
"""Test convert_api function work ok when convert api torch.div"""
code
=
"""
code
=
"torch.div(input, other)"
input = torch.randn(5)
other = torch.randn(5)
torch.div(input, other)
"""
api_name
=
'torch.div'
expected_ms_api_name
=
'P.Div'
expected_ms_api_name
=
'P.Div'
start
=
code
.
find
(
api_name
)
replaced_code
,
new_start
=
self
.
converter_ins
.
convert_api
(
code
,
start
,
api_nam
e
)
replaced_code
=
self
.
converter_ins
.
convert_api
(
cod
e
)
assert
replaced_code
==
code
.
replace
(
'torch.div(input, other)'
,
assert
replaced_code
==
code
.
replace
(
'torch.div(input, other)'
,
'{}()(input, other)'
.
format
(
expected_ms_api_name
))
'{}()(input, other)'
.
format
(
expected_ms_api_name
))
assert
new_start
==
start
+
len
(
expected_ms_api_name
)
def
test_convert_api_torch_dot_sin
(
self
):
def
test_convert_api_torch_dot_sin
(
self
):
"""Test convert_api function work ok when convert api torch.sin"""
"""Test convert_api function work ok when convert api torch.sin"""
code
=
"""
code
=
"torch.sin(input)"
torch.sin(input)
"""
api_name
=
'torch.sin'
expected_ms_api_name
=
'P.Sin'
expected_ms_api_name
=
'P.Sin'
start
=
code
.
find
(
api_name
)
replaced_code
,
new_start
=
self
.
converter_ins
.
convert_api
(
code
,
start
,
api_nam
e
)
replaced_code
=
self
.
converter_ins
.
convert_api
(
cod
e
)
assert
replaced_code
==
code
.
replace
(
'torch.sin(input)'
,
assert
replaced_code
==
code
.
replace
(
'torch.sin(input)'
,
'{}()(input)'
.
format
(
expected_ms_api_name
))
'{}()(input)'
.
format
(
expected_ms_api_name
))
assert
new_start
==
start
+
len
(
expected_ms_api_name
)
def
test_convert_api_torch_dot_sqrt
(
self
):
def
test_convert_api_torch_dot_sqrt
(
self
):
"""Test convert_api function work ok when convert api torch.sqrt"""
"""Test convert_api function work ok when convert api torch.sqrt"""
code
=
"""
code
=
"torch.sqrt(input)"
torch.sqrt(input)
"""
api_name
=
'torch.sqrt'
expected_ms_api_name
=
'P.Sqrt'
expected_ms_api_name
=
'P.Sqrt'
start
=
code
.
find
(
api_name
)
replaced_code
,
new_start
=
self
.
converter_ins
.
convert_api
(
code
,
start
,
api_nam
e
)
replaced_code
=
self
.
converter_ins
.
convert_api
(
cod
e
)
assert
replaced_code
==
code
.
replace
(
'torch.sqrt(input)'
,
assert
replaced_code
==
code
.
replace
(
'torch.sqrt(input)'
,
'{}()(input)'
.
format
(
expected_ms_api_name
))
'{}()(input)'
.
format
(
expected_ms_api_name
))
assert
new_start
==
start
+
len
(
expected_ms_api_name
)
def
test_convert_api_torch_dot_eye_with_n
(
self
):
def
test_convert_api_torch_dot_eye_with_n
(
self
):
"""Test convert_api function work ok when convert api torch.eye"""
"""Test convert_api function work ok when convert api torch.eye"""
code
=
"""
code
=
"torch.eye(3)"
torch.eye(3)
"""
api_name
=
'torch.eye'
expected_ms_api_name
=
'P.Eye'
expected_ms_api_name
=
'P.Eye'
start
=
code
.
find
(
api_name
)
replaced_code
,
new_start
=
self
.
converter_ins
.
convert_api
(
code
,
start
,
api_nam
e
)
replaced_code
=
self
.
converter_ins
.
convert_api
(
cod
e
)
assert
replaced_code
==
code
.
replace
(
'torch.eye(3)'
,
assert
replaced_code
==
code
.
replace
(
'torch.eye(3)'
,
'{}()(3, 3, mindspore.int32)'
.
format
(
expected_ms_api_name
))
'{}()(3, 3, mindspore.int32)'
.
format
(
expected_ms_api_name
))
assert
new_start
==
start
+
len
(
expected_ms_api_name
)
def
test_convert_api_torch_dot_eye_with_m
(
self
):
def
test_convert_api_torch_dot_eye_with_m
(
self
):
"""Test convert_api function work ok when convert api torch.eye"""
"""Test convert_api function work ok when convert api torch.eye"""
code
=
"""
code
=
"torch.eye(3, 4)"
torch.eye(3, 4)
"""
api_name
=
'torch.eye'
expected_ms_api_name
=
'P.Eye'
expected_ms_api_name
=
'P.Eye'
start
=
code
.
find
(
api_name
)
replaced_code
,
new_start
=
self
.
converter_ins
.
convert_api
(
code
,
start
,
api_nam
e
)
replaced_code
=
self
.
converter_ins
.
convert_api
(
cod
e
)
assert
replaced_code
==
code
.
replace
(
'torch.eye(3, 4)'
,
assert
replaced_code
==
code
.
replace
(
'torch.eye(3, 4)'
,
'{}()(3, 4, mindspore.int32)'
.
format
(
expected_ms_api_name
))
'{}()(3, 4, mindspore.int32)'
.
format
(
expected_ms_api_name
))
assert
new_start
==
start
+
len
(
expected_ms_api_name
)
def
test_convert_api_torch_dot_add_with_alpha_default
(
self
):
def
test_convert_api_torch_dot_add_with_alpha_default
(
self
):
"""Test convert_api function work ok when convert api torch.add"""
"""Test convert_api function work ok when convert api torch.add"""
code
=
"""
code
=
"torch.add(input, value)"
torch.add(input, value)
"""
api_name
=
'torch.add'
expected_ms_api_name
=
'P.TensorAdd'
expected_ms_api_name
=
'P.TensorAdd'
start
=
code
.
find
(
api_name
)
replaced_code
,
new_start
=
self
.
converter_ins
.
convert_api
(
code
,
start
,
api_nam
e
)
replaced_code
=
self
.
converter_ins
.
convert_api
(
cod
e
)
assert
replaced_code
==
code
.
replace
(
'torch.add(input, value)'
,
assert
replaced_code
==
code
.
replace
(
'torch.add(input, value)'
,
'{}()(input, value)'
.
format
(
expected_ms_api_name
))
'{}()(input, value)'
.
format
(
expected_ms_api_name
))
assert
new_start
==
start
+
len
(
expected_ms_api_name
)
def
test_convert_api_torch_dot_add_with_alpha_not_default
(
self
):
def
test_convert_api_torch_dot_add_with_alpha_not_default
(
self
):
"""Test convert_api function work ok when convert api torch.add"""
"""Test convert_api function work ok when convert api torch.add"""
code
=
"""
code
=
"torch.add(input, value, 3)"
torch.add(input, value, 3)
"""
api_name
=
'torch.add'
expected_ms_api_name
=
'P.TensorAdd'
expected_ms_api_name
=
'P.TensorAdd'
start
=
code
.
find
(
api_name
)
replaced_code
,
new_start
=
self
.
converter_ins
.
convert_api
(
code
,
start
,
api_nam
e
)
replaced_code
=
self
.
converter_ins
.
convert_api
(
cod
e
)
assert
replaced_code
==
code
.
replace
(
'torch.add(input, value, 3)'
,
assert
replaced_code
==
code
.
replace
(
'torch.add(input, value, 3)'
,
'{}()(input, value*3)'
.
format
(
expected_ms_api_name
))
'{}()(input, value*3)'
.
format
(
expected_ms_api_name
))
assert
new_start
==
start
+
len
(
expected_ms_api_name
)
# test convert_api with F ops
# test convert_api with F ops
def
test_convert_api_f_normalize
(
self
):
def
test_convert_api_f_normalize
(
self
):
"""Test convert_api function work ok when convert api F.normalize"""
"""Test convert_api function work ok when convert api F.normalize"""
code
=
"""
code
=
"F.normalize(input)"
input = torch.randn(2, 3, 5)
F.normalize(input)
"""
api_name
=
'F.normalize'
start
=
code
.
find
(
api_name
)
expected_ms_api_name
=
'P.L2Normalize'
expected_ms_api_name
=
'P.L2Normalize'
replaced_code
,
new_start
=
self
.
converter_ins
.
convert_api
(
code
,
start
,
api_nam
e
)
replaced_code
=
self
.
converter_ins
.
convert_api
(
cod
e
)
assert
replaced_code
==
code
.
replace
(
'F.normalize(input)'
,
assert
replaced_code
==
code
.
replace
(
'F.normalize(input)'
,
'{}(1, 1e-12)(input)'
.
format
(
expected_ms_api_name
))
'{}(1, 1e-12)(input)'
.
format
(
expected_ms_api_name
))
assert
new_start
==
start
+
len
(
expected_ms_api_name
)
def
test_convert_api_f_sigmoid
(
self
):
def
test_convert_api_f_sigmoid
(
self
):
"""Test convert_api function work ok when convert api F.sigmoid"""
"""Test convert_api function work ok when convert api F.sigmoid"""
code
=
"""
code
=
"F.sigmoid(input)"
input = torch.randn(2, 3, 5)
F.sigmoid(input)
"""
api_name
=
'F.sigmoid'
start
=
code
.
find
(
api_name
)
expected_ms_api_name
=
'P.Sigmoid'
expected_ms_api_name
=
'P.Sigmoid'
replaced_code
,
new_start
=
self
.
converter_ins
.
convert_api
(
code
,
start
,
api_nam
e
)
replaced_code
=
self
.
converter_ins
.
convert_api
(
cod
e
)
assert
replaced_code
==
code
.
replace
(
'F.sigmoid(input)'
,
assert
replaced_code
==
code
.
replace
(
'F.sigmoid(input)'
,
'{}()(input)'
.
format
(
expected_ms_api_name
))
'{}()(input)'
.
format
(
expected_ms_api_name
))
assert
new_start
==
start
+
len
(
expected_ms_api_name
)
# test convert_api with tensor dot ops
# test convert_api with tensor dot ops
def
test_convert_api_tensor_dot_repeat
(
self
):
def
test_convert_api_tensor_dot_repeat
(
self
):
"""Test convert_api function work ok when convert api .repeat"""
"""Test convert_api function work ok when convert api .repeat"""
code
=
"""
code
=
"x.repeat(4, 2)"
x.repeat(4, 2)
"""
api_name
=
'.repeat'
start
=
code
.
find
(
api_name
)
expected_ms_api_name
=
'P.Tile'
expected_ms_api_name
=
'P.Tile'
replaced_code
,
new_start
=
self
.
converter_ins
.
convert_api
(
code
,
start
,
api_nam
e
)
replaced_code
=
self
.
converter_ins
.
convert_api
(
cod
e
)
assert
replaced_code
==
code
.
replace
(
'x.repeat(4, 2)'
,
assert
replaced_code
==
code
.
replace
(
'x.repeat(4, 2)'
,
'{}()(x, {})'
.
format
(
expected_ms_api_name
,
'(4, 2,)'
))
'{}()(x, {})'
.
format
(
expected_ms_api_name
,
'(4, 2,)'
))
assert
new_start
==
start
+
len
(
expected_ms_api_name
)
def
test_convert_api_tensor_dot_permute
(
self
):
def
test_convert_api_tensor_dot_permute
(
self
):
"""Test convert_api function work ok when convert api .permute"""
"""Test convert_api function work ok when convert api .permute"""
code
=
"""
code
=
"x.permute(2, 0, 1)"
x.permute(2, 0, 1)
"""
api_name
=
'.permute'
start
=
code
.
find
(
api_name
)
expected_ms_api_name
=
'P.Transpose'
expected_ms_api_name
=
'P.Transpose'
replaced_code
,
new_start
=
self
.
converter_ins
.
convert_api
(
code
,
start
,
api_nam
e
)
replaced_code
=
self
.
converter_ins
.
convert_api
(
cod
e
)
assert
replaced_code
==
code
.
replace
(
'x.permute(2, 0, 1)'
,
assert
replaced_code
==
code
.
replace
(
'x.permute(2, 0, 1)'
,
'{}()(x, (2, 0, 1,))'
.
format
(
expected_ms_api_name
))
'{}()(x, (2, 0, 1,))'
.
format
(
expected_ms_api_name
))
assert
new_start
==
start
+
len
(
expected_ms_api_name
)
tests/ut/mindconverter/test_forward_call.py
浏览文件 @
bcdc61cc
...
@@ -15,7 +15,6 @@
...
@@ -15,7 +15,6 @@
"""Test forward_call module."""
"""Test forward_call module."""
import
ast
import
ast
import
textwrap
import
textwrap
from
unittest.mock
import
patch
from
mindinsight.mindconverter.forward_call
import
ForwardCall
from
mindinsight.mindconverter.forward_call
import
ForwardCall
...
@@ -50,12 +49,10 @@ class TestForwardCall:
...
@@ -50,12 +49,10 @@ class TestForwardCall:
return out
return out
"""
)
"""
)
@
patch
.
object
(
ForwardCall
,
'process'
)
def
test_process
(
self
):
def
test_process
(
self
,
mock_process
):
"""Test the function of visit ast tree to find out forward functions."""
"""Test the function of visit ast tree to find out forward functions."""
mock_process
.
return_value
=
None
ast_tree
=
ast
.
parse
(
self
.
source
)
forward_call
=
ForwardCall
(
"mock"
)
forward_call
=
ForwardCall
(
ast_tree
)
forward_call
.
visit
(
ast
.
parse
(
self
.
source
))
expect_calls
=
[
'TestNet.forward'
,
expect_calls
=
[
'TestNet.forward'
,
'TestNet.forward1'
,
'TestNet.forward1'
,
...
@@ -70,6 +67,6 @@ class TestForwardCall:
...
@@ -70,6 +67,6 @@ class TestForwardCall:
'TestNet.fc3'
,
'TestNet.fc3'
,
]
]
expect_calls
.
sort
()
expect_calls
.
sort
()
real_calls
=
list
(
forward_call
.
calls
)
real_calls
=
list
(
forward_call
.
calls
.
keys
()
)
real_calls
.
sort
()
real_calls
.
sort
()
assert
real_calls
==
expect_calls
assert
real_calls
==
expect_calls
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录