Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MindSpore
mindinsight
提交
bcdc61cc
M
mindinsight
项目概览
MindSpore
/
mindinsight
通知
8
Star
4
Fork
2
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindinsight
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
bcdc61cc
编写于
6月 17, 2020
作者:
M
mindspore-ci-bot
提交者:
Gitee
6月 17, 2020
浏览文件
操作
浏览文件
下载
差异文件
!273 Converter: use the AST to analyze and modify network definition script
Merge pull request !273 from ggpolar/br_wzk_dev
上级
d200233c
7cad801d
变更
9
展开全部
显示空白变更内容
内联
并排
Showing
9 changed file
with
1178 addition
and
710 deletion
+1178
-710
mindinsight/mindconverter/ast_edits.py
mindinsight/mindconverter/ast_edits.py
+579
-0
mindinsight/mindconverter/cli.py
mindinsight/mindconverter/cli.py
+3
-5
mindinsight/mindconverter/code_analysis.py
mindinsight/mindconverter/code_analysis.py
+399
-0
mindinsight/mindconverter/common/exceptions.py
mindinsight/mindconverter/common/exceptions.py
+44
-0
mindinsight/mindconverter/converter.py
mindinsight/mindconverter/converter.py
+58
-441
mindinsight/mindconverter/forward_call.py
mindinsight/mindconverter/forward_call.py
+42
-34
mindinsight/utils/constant.py
mindinsight/utils/constant.py
+5
-0
tests/ut/mindconverter/test_converter.py
tests/ut/mindconverter/test_converter.py
+44
-223
tests/ut/mindconverter/test_forward_call.py
tests/ut/mindconverter/test_forward_call.py
+4
-7
未找到文件。
mindinsight/mindconverter/ast_edits.py
0 → 100644
浏览文件 @
bcdc61cc
此差异已折叠。
点击以展开。
mindinsight/mindconverter/cli.py
浏览文件 @
bcdc61cc
...
...
@@ -186,25 +186,23 @@ def cli_entry():
mode
=
permissions
<<
6
os
.
makedirs
(
args
.
output
,
mode
=
mode
,
exist_ok
=
True
)
os
.
makedirs
(
args
.
report
,
mode
=
mode
,
exist_ok
=
True
)
_run
(
args
.
in_file
,
args
.
output
,
''
,
args
.
report
)
_run
(
args
.
in_file
,
args
.
output
,
args
.
report
)
def
_run
(
in_files
,
out_dir
,
in_module
,
report
):
def
_run
(
in_files
,
out_dir
,
report
):
"""
Run converter command.
Args:
in_files (str): The file path or directory to convert.
out_dir (str): The output directory to save converted file.
in_module (str): The module name to convert.
report (str): The report file path.
"""
files_config
=
{
'root_path'
:
in_files
if
in_files
else
''
,
'in_files'
:
[],
'outfile_dir'
:
out_dir
,
'report_dir'
:
report
,
'in_module'
:
in_module
'report_dir'
:
report
}
if
os
.
path
.
isfile
(
in_files
):
files_config
[
'root_path'
]
=
os
.
path
.
dirname
(
in_files
)
...
...
mindinsight/mindconverter/code_analysis.py
0 → 100644
浏览文件 @
bcdc61cc
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless REQUIRED by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""code analysis module"""
import
ast
import
pasta
from
pasta.base
import
scope
from
mindinsight.mindconverter.common.exceptions
import
ScriptNotSupport
class
APIAnalysisSpec
:
"""API analysis specifications"""
import_name_mapping
=
{
'torch'
:
[
'mindspore'
,
None
],
'torch.nn'
:
[
'mindspore.nn'
,
'nn'
],
'torch.nn.functional'
:
[
'mindspore.ops.operations'
,
'P'
]}
base_name_mapping
=
{
'Module'
:
'Cell'
,
'Sequential'
:
'SequentialCell'
}
@
classmethod
def
get_convertible_external_names
(
cls
):
"""
Obtain the convertible external names.
The external name is the full dotted name being referenced.
"""
return
cls
.
import_name_mapping
.
keys
()
@
staticmethod
def
get_network_base_class_names
():
"""Obtain the base names which network class base from"""
return
[
'Module'
,
'Sequential'
,
'ModuleList'
,
'ModuleDict'
,
'ParameterList'
,
'ParameterDict'
]
@
staticmethod
def
check_external_alias_ref
(
ref_name
,
external_name
):
"""
Check 'import as' is standard.
Standard references are follow:
import torch.nn as nn
import torch.nn.functional as F
Args:
ref_name (str): The name that refers to the external_name.
external_name (str): The full dotted name being referenced. For examples:
1. 'import torch.nn as nn', torch.nn is external_name, nn is ref_name.
2. 'from torch import nn as mm, torch.nn is external_name, mm is ref_name which is not a standard name.
Returns:
boolean, True if ref_name is standard else False.
"""
if
ref_name
!=
'nn'
and
external_name
==
'torch.nn'
:
is_standard
=
False
elif
ref_name
!=
'F'
and
external_name
==
'torch.nn.functional'
:
is_standard
=
False
else
:
is_standard
=
True
return
is_standard
class
CodeAnalyzer
(
ast
.
NodeVisitor
):
"""Code analyzer that analyzes PyTorch python script by AST Visitor.
CodeAnalyzer find the codes that need to be converted to MindSpore,
and provides the attributes related to the codes.
"""
def
__init__
(
self
):
self
.
_stack
=
[]
# Used to easily access the parent node
self
.
_external_references
=
{}
self
.
_is_standard_external_ref
=
True
self
.
_root_scope
=
None
# Used to save functions that need to be converted, value type is pasta.base.scope.Scope
self
.
_network_functions
=
[]
# Used to easily trace the function node
self
.
_functions_stack
=
[]
# key type is pasta.base.scope.Scope, value type is list
self
.
_network_classes
=
{}
@
property
def
root_scope
(
self
):
"""The root scope of the python script code."""
return
self
.
_root_scope
@
property
def
is_standard_external_ref
(
self
):
"""Obtain whether the result is a standard external reference."""
return
self
.
_is_standard_external_ref
@
property
def
external_references
(
self
):
"""Obtain all external references in the analyzed code."""
return
self
.
_external_references
def
network_definitions
(
self
):
"""Obtain the network definitions which need to be converted."""
return
{
"functions"
:
self
.
_network_functions
,
"cell"
:
self
.
_network_classes
}
def
process
(
self
,
ast_tree
):
"""
Start to analyze the code.
Args:
ast_tree (AST): The root node of the source code.
"""
self
.
__init__
()
self
.
_root_scope
=
scope
.
analyze
(
ast_tree
)
self
.
_pre_process
()
self
.
visit
(
ast_tree
)
if
not
self
.
_network_classes
:
msg
=
"model definition not be found."
raise
ScriptNotSupport
(
msg
)
@
staticmethod
def
_check_external_standard
(
external_refs
):
"""Check whether all external references are standard."""
is_standard
=
True
for
external_name
,
external_ref_info
in
external_refs
.
items
():
is_standard
=
APIAnalysisSpec
.
check_external_alias_ref
(
external_name
,
external_ref_info
.
name
)
if
not
is_standard
:
break
return
is_standard
def
_is_base_from_cell
(
self
,
node
):
"""
Check whether the node bases from cell classes which are defined in APIAnalysisSpec.
Args:
node (ast.ClassDef): The node which is a class definition.
Returns:
boolean, True if the check result is Passed else False.
"""
if
self
.
_is_ref_convertible_imports
(
node
):
whole_name
=
self
.
_get_whole_name
(
node
)
if
whole_name
.
split
(
'.'
)[
-
1
]
in
APIAnalysisSpec
.
get_network_base_class_names
():
return
True
return
False
def
_pre_process
(
self
):
"""Preprocessor checks the code before analyzing."""
is_torch
=
False
# check whether the code imports torch.
for
ref_name
in
self
.
_root_scope
.
external_references
.
keys
():
if
ref_name
.
split
(
'.'
)[
0
]
in
APIAnalysisSpec
.
get_convertible_external_names
():
is_torch
=
True
break
if
not
is_torch
:
msg
=
"The source code does not import torch, model definition can not be found."
raise
ScriptNotSupport
(
msg
)
# Find out external reference in the code and save it.
external_refs
=
self
.
_analyze_import_references
(
self
.
_root_scope
)
self
.
_is_standard_external_ref
=
self
.
_check_external_standard
(
external_refs
)
self
.
_check_external_standard
(
external_refs
)
for
external_name
,
external_ref_info
in
external_refs
.
items
():
self
.
_external_references
.
update
({
external_name
:
{
'external_ref_info'
:
external_ref_info
,
'parent_node'
:
None
}
})
@
staticmethod
def
_analyze_import_references
(
root_scope
):
"""Find out all references from the import statements."""
external_name_ref
=
{}
for
node_references
in
root_scope
.
external_references
.
values
():
for
node_ref
in
node_references
:
if
node_ref
.
name_ref
:
# (from)import alias, node_ref.name_ref.id is alias name
if
node_ref
.
name_ref
.
definition
.
asname
==
node_ref
.
name_ref
.
id
:
external_name_ref
[
node_ref
.
name_ref
.
id
]
=
node_ref
# import without alias, node_ref.name_ref.definition.asname is None.
# e.g., import a.b.c, reference maybe is a, a.b or a.b.c in the root_scope.external_references.
# The reference a.b.c is really wanted.
elif
node_ref
.
name_ref
.
definition
.
name
==
node_ref
.
name_ref
.
id
:
external_name_ref
[
node_ref
.
name_ref
.
id
]
=
node_ref
else
:
pass
return
external_name_ref
def
visit
(
self
,
node
):
"""Overridden visit of the base class to maintain stack information to access parent node."""
self
.
_stack
.
append
(
node
)
super
(
CodeAnalyzer
,
self
).
visit
(
node
)
self
.
_stack
.
pop
()
@
staticmethod
def
_get_full_name
(
node
):
"""Get the full name of the node."""
if
not
isinstance
(
node
,
(
ast
.
Attribute
,
ast
.
Name
)):
return
None
return
pasta
.
dump
(
node
)
def
_get_whole_name
(
self
,
node
):
"""
Get the whole name of the node.
For example, nn.Module is spliced two nodes, nn node and Module node.
When visit ast nodes,
Module node is first visited, the full name is the same as the whole name, that is nn.Module.
And then nn node is visited, the full name is nn, the whole name is nn.Module.
"""
full_name
=
self
.
_get_full_name
(
node
)
if
not
full_name
:
return
None
# node is in stack top pos
if
node
is
self
.
_stack
[
-
1
]:
parent_index
=
-
1
while
isinstance
(
self
.
_stack
[
parent_index
],
ast
.
Attribute
):
parent_index
-=
1
whole_name
=
self
.
_get_full_name
(
self
.
_stack
[
parent_index
])
else
:
whole_name
=
full_name
return
whole_name
def
_is_ref_convertible_imports
(
self
,
node
):
"""Check whether the node references convertible imports."""
check_result
=
False
whole_name
=
self
.
_get_whole_name
(
node
)
if
whole_name
:
module_name
=
whole_name
.
split
(
'.'
)[
0
]
for
ref_name
,
ref_info
in
self
.
_external_references
.
items
():
external_ref
=
ref_info
[
'external_ref_info'
]
# external reference is convertible module
if
external_ref
.
name
in
APIAnalysisSpec
.
get_convertible_external_names
():
# import from the same external module
if
module_name
==
ref_name
.
split
(
'.'
)[
0
]:
check_result
=
True
break
return
check_result
@
staticmethod
def
_get_external_node
(
external_references
):
"""Get all external reference nodes."""
external_nodes
=
{}
for
ref_name
,
ref_info
in
external_references
.
items
():
external_nodes
.
update
({
ref_info
[
'external_ref_info'
].
node
:
ref_name
})
return
external_nodes
@
staticmethod
def
_get_convertible_external_node
(
external_name_ref
):
"""Get all convertible external reference nodes."""
convertible_external_nodes
=
{}
for
ref_name
,
ref_info
in
external_name_ref
.
items
():
if
ref_info
[
'external_ref_info'
].
name
in
APIAnalysisSpec
.
get_convertible_external_names
():
convertible_external_nodes
.
update
({
ref_info
[
'external_ref_info'
].
node
:
ref_name
})
return
convertible_external_nodes
def
_update_external_ref_parent
(
self
,
node
):
"""Set external reference parent node info."""
external_nodes
=
self
.
_get_external_node
(
self
.
_external_references
)
convertible_external_nodes
=
self
.
_get_convertible_external_node
(
self
.
_external_references
)
for
name_node
in
node
.
names
:
if
name_node
in
convertible_external_nodes
.
keys
():
if
len
(
node
.
names
)
>
1
:
msg
=
"""
\
Not support multiple imports of torch on one line in your script. line:%s: %s
"""
%
(
node
.
lineno
,
pasta
.
dump
(
node
))
raise
ScriptNotSupport
(
msg
)
if
name_node
in
external_nodes
.
keys
():
ref_name
=
external_nodes
[
name_node
]
self
.
_external_references
[
ref_name
][
'parent_node'
]
=
node
@
staticmethod
def
_get_class_scope
(
node_scope
):
"""Find the class scope of the node_scope."""
parent_scope
=
node_scope
.
parent_scope
class_scope
=
None
while
parent_scope
:
if
isinstance
(
parent_scope
.
node
,
ast
.
ClassDef
):
class_scope
=
parent_scope
break
parent_scope
=
parent_scope
.
parent_scope
return
class_scope
def
_update_convertible_functions
(
self
,
node
):
"""Update convertible functions."""
node_scope
=
self
.
_root_scope
.
lookup_scope
(
node
)
class_scope
=
self
.
_get_class_scope
(
node_scope
)
if
class_scope
:
network_classes
=
self
.
_network_classes
.
get
(
class_scope
,
[])
if
node_scope
not
in
network_classes
:
network_classes
.
append
(
node_scope
)
else
:
if
node_scope
not
in
self
.
_network_functions
:
self
.
_network_functions
.
append
(
node_scope
)
def
visit_ClassDef
(
self
,
node
):
"""Callback function when visit AST tree"""
if
not
self
.
_stack
[
-
1
]
is
node
:
return
for
base
in
node
.
bases
:
if
self
.
_is_ref_convertible_imports
(
base
):
self
.
_network_classes
[
self
.
_root_scope
.
lookup_scope
(
node
)]
=
[]
self
.
generic_visit
(
node
)
def
visit_Import
(
self
,
node
):
"""Callback function when visit AST tree"""
self
.
_update_external_ref_parent
(
node
)
self
.
generic_visit
(
node
)
def
visit_ImportFrom
(
self
,
node
):
"""Callback function when visit AST tree"""
self
.
_update_external_ref_parent
(
node
)
self
.
generic_visit
(
node
)
def
visit_Call
(
self
,
node
):
"""Callback function when visit AST tree"""
if
not
self
.
_stack
[
-
1
]
is
node
:
return
is_in_network_function
=
False
# If torch call is happened in the function, save the function for network definition.
if
self
.
_functions_stack
and
self
.
_is_ref_convertible_imports
(
node
.
func
):
self
.
_update_convertible_functions
(
self
.
_functions_stack
[
-
1
])
is_in_network_function
=
True
if
not
is_in_network_function
:
self
.
generic_visit
(
node
)
def
visit_FunctionDef
(
self
,
node
):
"""Callback function when visit AST tree"""
if
not
self
.
_stack
[
-
1
]
is
node
:
return
if
node
.
name
==
"forward"
:
self
.
_update_convertible_functions
(
node
)
self
.
_functions_stack
.
append
(
node
)
self
.
generic_visit
(
node
)
self
.
_functions_stack
.
pop
()
def
get_name
(
self
,
node
):
"""
Get the node name.
Args:
node (AST): The ast node of the source code.
Returns:
str, the name of the node
"""
if
isinstance
(
node
,
pasta
.
base
.
scope
.
Scope
):
items
=
[
self
.
get_name
(
node
.
node
)]
parent_scope
=
node
.
parent_scope
while
parent_scope
:
if
not
isinstance
(
parent_scope
.
node
,
ast
.
Module
):
items
.
append
(
self
.
get_name
(
parent_scope
.
node
))
parent_scope
=
parent_scope
.
parent_scope
return
'.'
.
join
(
reversed
(
items
))
if
isinstance
(
node
,
(
ast
.
ClassDef
,
ast
.
FunctionDef
)):
return
node
.
name
if
isinstance
(
node
,
(
ast
.
Name
,
ast
.
Attribute
)):
return
self
.
_get_full_name
(
node
)
return
str
(
node
)
def
lookup_scope
(
self
,
node
):
"""
Search the scope of the node.
Args:
node (AST): The ast node of the source code.
Returns:
scope, the scope of the node
"""
if
isinstance
(
node
,
pasta
.
base
.
scope
.
Scope
):
return
node
return
self
.
_root_scope
.
lookup_scope
(
node
)
mindinsight/mindconverter/common/exceptions.py
0 → 100644
浏览文件 @
bcdc61cc
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Define custom exception."""
from
enum
import
unique
from
mindinsight.utils.constant
import
ScriptConverterErrors
from
mindinsight.utils.exceptions
import
MindInsightException
@
unique
class
ConverterErrors
(
ScriptConverterErrors
):
"""Converter error codes."""
SCRIPT_NOT_SUPPORT
=
1
NODE_TYPE_NOT_SUPPORT
=
2
class
ScriptNotSupport
(
MindInsightException
):
"""The script can not support to process."""
def
__init__
(
self
,
msg
):
super
(
ScriptNotSupport
,
self
).
__init__
(
ConverterErrors
.
SCRIPT_NOT_SUPPORT
,
msg
,
http_code
=
400
)
class
NodeTypeNotSupport
(
MindInsightException
):
"""The astNode can not support to process."""
def
__init__
(
self
,
msg
):
super
(
NodeTypeNotSupport
,
self
).
__init__
(
ConverterErrors
.
NODE_TYPE_NOT_SUPPORT
,
msg
,
http_code
=
400
)
mindinsight/mindconverter/converter.py
浏览文件 @
bcdc61cc
此差异已折叠。
点击以展开。
mindinsight/mindconverter/forward_call.py
浏览文件 @
bcdc61cc
...
...
@@ -14,7 +14,8 @@
# ============================================================================
"""Find out forward functions of script file"""
import
ast
import
os
import
pasta
class
ForwardCall
(
ast
.
NodeVisitor
):
...
...
@@ -24,73 +25,80 @@ class ForwardCall(ast.NodeVisitor):
Find the sub functions called by the forward function in the script file.
"""
def
__init__
(
self
,
filenam
e
):
self
.
filename
=
filenam
e
self
.
module_name
=
os
.
path
.
basename
(
filename
).
replace
(
'.py'
,
''
)
self
.
name
_stack
=
[]
self
.
forward_stack
=
[]
self
.
calls
=
set
()
def
__init__
(
self
,
ast_tre
e
):
self
.
_tree
=
ast_tre
e
self
.
_name_stack
=
[]
self
.
_forward
_stack
=
[]
self
.
calls
=
{}
# key is function name, value is forward function ast node.
self
.
_function_list
=
{}
# key is function name, value is function ast node.
self
.
process
()
def
process
(
self
):
"""Parse the python source file to find the forward functions."""
with
open
(
self
.
filename
,
'rt'
,
encoding
=
'utf-8'
)
as
file
:
content
=
file
.
read
()
self
.
visit
(
ast
.
parse
(
content
,
self
.
filename
))
"""visit ast tree to find the forward functions."""
self
.
visit
(
self
.
_tree
)
# first visit to find out all functions, so restores all variables except _function_list
self
.
_name_stack
.
clear
()
self
.
_forward_stack
.
clear
()
self
.
calls
.
clear
()
self
.
visit
(
self
.
_tree
)
def
get_current_namespace
(
self
):
"""Get the namespace when visit the AST node"""
namespace
=
'.'
.
join
(
self
.
name_stack
)
namespace
=
'.'
.
join
(
self
.
_
name_stack
)
return
namespace
@
classmethod
def
get_ast_node_name
(
cls
,
node
):
"""Get AST node name."""
if
isinstance
(
node
,
ast
.
Attribute
):
return
f
'
{
cls
.
get_ast_node_name
(
node
.
value
)
}
.
{
node
.
attr
}
'
if
isinstance
(
node
,
ast
.
Name
):
return
node
.
id
def
get_call_name
(
cls
,
node
):
"""Get functional call name."""
if
not
isinstance
(
node
,
ast
.
Call
):
return
None
return
node
return
pasta
.
dump
(
node
.
func
)
def
visit_ClassDef
(
self
,
node
):
"""Callback function when visit AST tree"""
self
.
name_stack
.
append
(
node
.
name
)
self
.
_
name_stack
.
append
(
node
.
name
)
self
.
generic_visit
(
node
)
self
.
name_stack
.
pop
()
self
.
_
name_stack
.
pop
()
def
visit_FunctionDef
(
self
,
node
):
"""Callback function when visit AST tree"""
namespace
=
self
.
get_current_namespace
()
if
namespace
:
func_name
=
f
'
{
namespace
}
.
{
node
.
name
}
'
else
:
func_name
=
node
.
name
func_name
=
f
'
{
self
.
get_current_namespace
()
}
.
{
node
.
name
}
'
is_in_chain
=
func_name
in
self
.
calls
or
node
.
name
==
'forward'
if
is_in_chain
:
self
.
forward_stack
.
append
(
func_name
)
self
.
_
forward_stack
.
append
(
func_name
)
if
node
.
name
==
'forward'
:
self
.
calls
.
add
(
func_name
)
self
.
calls
.
update
({
func_name
:
node
}
)
self
.
_function_list
.
update
({
func_name
:
node
})
self
.
generic_visit
(
node
)
if
is_in_chain
:
self
.
forward_stack
.
pop
()
self
.
_
forward_stack
.
pop
()
def
visit_Call
(
self
,
node
):
"""Callback function when visit AST tree"""
for
arg
in
node
.
args
:
self
.
visit
(
arg
)
for
k
w
in
node
.
keywords
:
self
.
visit
(
k
w
.
value
)
func_name
=
self
.
get_
ast_node_name
(
node
.
func
)
for
k
eyword
in
node
.
keywords
:
self
.
visit
(
k
eyword
.
value
)
func_name
=
self
.
get_
call_name
(
node
)
if
isinstance
(
node
.
func
,
ast
.
Name
):
if
func_name
not
in
[
'super'
,
'str'
,
'repr'
]:
if
self
.
forward_stack
:
self
.
calls
.
add
(
func_name
)
if
self
.
_
forward_stack
:
self
.
calls
.
update
({
func_name
:
self
.
_function_list
.
get
(
func_name
)}
)
self
.
visit
(
node
.
func
)
else
:
if
self
.
forward_stack
:
if
'self'
in
func_name
:
self
.
calls
.
add
(
f
'
{
self
.
get_current_namespace
()
}
.
{
func_name
.
split
(
"."
)[
-
1
]
}
'
)
if
self
.
_forward_stack
:
if
func_name
.
startswith
(
'self.'
):
whole_name
=
f
'
{
self
.
get_current_namespace
()
}
.
{
func_name
.
split
(
"."
)[
-
1
]
}
'
self
.
calls
.
update
({
whole_name
:
self
.
_function_list
.
get
(
whole_name
)})
else
:
self
.
calls
.
add
(
func_name
)
self
.
calls
.
update
({
func_name
:
self
.
_function_list
.
get
(
func_name
)}
)
self
.
visit
(
node
.
func
)
mindinsight/utils/constant.py
浏览文件 @
bcdc61cc
...
...
@@ -30,6 +30,7 @@ class MindInsightModules(Enum):
LINEAGEMGR
=
2
DATAVISUAL
=
5
PROFILERMGR
=
6
SCRIPTCONVERTER
=
7
class
GeneralErrors
(
Enum
):
...
...
@@ -69,3 +70,7 @@ class DataVisualErrors(Enum):
SCALAR_NOT_EXIST
=
14
HISTOGRAM_NOT_EXIST
=
15
TRAIN_JOB_DETAIL_NOT_IN_CACHE
=
16
class
ScriptConverterErrors
(
Enum
):
"""Enum definition for mindconverter errors."""
tests/ut/mindconverter/test_converter.py
浏览文件 @
bcdc61cc
...
...
@@ -22,380 +22,201 @@ class TestConverter:
converter_ins
=
Converter
()
def
test_judge_forward
(
self
):
"""test judge_forward"""
name1
=
'conv1'
forward_list
=
{
'conv1'
,
'relu'
}
result1
=
self
.
converter_ins
.
judge_forward
(
name1
,
forward_list
)
assert
result1
is
True
name2
=
'self.forward'
result2
=
self
.
converter_ins
.
judge_forward
(
name2
,
forward_list
)
assert
result2
is
True
def
test_find_left_parentheses
(
self
):
"""test find_left_parentheses"""
code
=
'''nn.Sequential(nn.Conv2d(in_dim, 6, 5, stride=1, padding=0, ),
nn.ReLU(),
nn.ReLU(True),
nn.MaxPool2d(2, 2),
nn.Conv2d(6, 16, 5, stride=1, padding=0),
nn.ReLU(inplace=False),
nn.MaxPool2d(2, 2))'''
right_index
=
len
(
code
)
-
1
left_index
=
code
.
index
(
'nn.Conv2d'
)
result
=
self
.
converter_ins
.
find_left_parentheses
(
code
,
right_index
)
assert
result
==
left_index
-
1
def
test_find_api
(
self
):
"""test find_api"""
code
=
'''nn.Sequential(nn.Conv2d(in_dim, 6, 5, stride=1, padding=0, ),
nn.ReLU(),
nn.ReLU(True),
nn.MaxPool2d(2, 2), # TODO padding
nn.Conv2d(6, 16, 5, stride=1, padding=0),
nn.ReLU(inplace=False),
nn.MaxPool2d(2, 2))'''
index
=
0
is_forward
=
False
result
=
self
.
converter_ins
.
find_api
(
code
,
index
,
is_forward
)
assert
result
==
'nn.Sequential'
def
test_get_call_name
(
self
):
"""test get_call_name"""
code
=
'''nn.Sequential(nn.Conv2d(in_dim, 6, 5, stride=1, padding=0))'''
end
=
len
(
code
)
call_name
,
index
=
self
.
converter_ins
.
get_call_name
(
code
,
end
)
assert
call_name
==
''
assert
index
==
-
1
def
test_find_right_parentheses
(
self
):
"""test find_right_parentheses"""
code
=
'''nn.Sequential(nn.Conv2d(in_dim, 6, 5, stride=1, padding=0, ),
nn.ReLU(),
nn.ReLU(True),
nn.MaxPool2d(2, 2), # TODO padding
nn.Conv2d(6, 16, 5, stride=1, padding=0),
nn.ReLU(inplace=False),
nn.MaxPool2d(2, 2))'''
left_index
=
0
result
=
self
.
converter_ins
.
find_right_parentheses
(
code
,
left_index
)
assert_index
=
len
(
code
)
-
1
assert
result
==
assert_index
# test convert_api with nn ops
def
test_convert_api_nn_layernorm
(
self
):
"""Test convert_api function work ok when convert api nn.LayerNorm"""
code
=
"""
def __init__(self, num_classes=1000):
self.features = nn.SequentialCell([
nn.LayerNorm((5, 10, 10), elementwise_affine=False),
nn.ReLU(inplace=False)
])
"""
code
=
"nn.LayerNorm((5, 10, 10), elementwise_affine=False)"
api_name
=
'nn.LayerNorm'
start
=
code
.
find
(
api_name
)
layer_norm_info
=
NN_MAPPING
.
get
(
api_name
)
expected_ms_api_name
=
'nn.LayerNorm'
epsilon
=
layer_norm_info
.
pt_api
.
params
.
get
(
'eps'
)
replaced_code
,
new_start
=
self
.
converter_ins
.
convert_api
(
code
,
start
,
api_nam
e
)
replaced_code
=
self
.
converter_ins
.
convert_api
(
cod
e
)
assert
replaced_code
==
code
.
replace
(
'nn.LayerNorm((5, 10, 10), elementwise_affine=False)'
,
'{}(normalized_shape=(5, 10, 10), epsilon={})'
.
format
(
expected_ms_api_name
,
epsilon
))
assert
new_start
==
start
+
len
(
expected_ms_api_name
)
def
test_convert_api_nn_leaky_relu
(
self
):
"""Test convert_api function work ok when convert api nn.LeakyReLU"""
code
=
"""
def __init__(self, num_classes=1000):
self.features = nn.SequentialCell([
nn.LayerNorm((5, 10, 10), elementwise_affine=False),
nn.LeakyReLU(0.3)])
"""
api_name
=
'nn.LeakyReLU'
start
=
code
.
find
(
api_name
)
code
=
"nn.LeakyReLU(0.3)"
expected_ms_api_name
=
'nn.LeakyReLU'
replaced_code
,
new_start
=
self
.
converter_ins
.
convert_api
(
code
,
start
,
api_nam
e
)
replaced_code
=
self
.
converter_ins
.
convert_api
(
cod
e
)
assert
replaced_code
==
code
.
replace
(
'nn.LeakyReLU(0.3)'
,
'{}(alpha=0.3)'
.
format
(
expected_ms_api_name
))
assert
new_start
==
start
+
len
(
expected_ms_api_name
)
def
test_convert_api_nn_prelu
(
self
):
"""Test convert_api function work ok when convert api nn.PReLU"""
code
=
"""
input = torch.randn(2, 3, 5)
nn.PReLU()(input)
"""
api_name
=
'nn.PReLU'
start
=
code
.
find
(
api_name
)
code
=
"nn.PReLU()(input)"
expected_ms_api_name
=
'nn.PReLU'
replaced_code
,
new_start
=
self
.
converter_ins
.
convert_api
(
code
,
start
,
api_nam
e
)
replaced_code
=
self
.
converter_ins
.
convert_api
(
cod
e
)
assert
replaced_code
==
code
.
replace
(
'nn.PReLU()(input)'
,
'{}()(input)'
.
format
(
expected_ms_api_name
))
assert
new_start
==
start
+
len
(
expected_ms_api_name
)
def
test_convert_api_nn_softmax
(
self
):
"""Test convert_api function work ok when convert api nn.Softmax"""
code
=
"""
nn.Softmax(dim=1)(input)
"""
api_name
=
'nn.Softmax'
code
=
"nn.Softmax(dim=1)"
expected_ms_api_name
=
'nn.Softmax'
start
=
code
.
find
(
api_name
)
replaced_code
,
new_start
=
self
.
converter_ins
.
convert_api
(
code
,
start
,
api_name
)
assert
replaced_code
==
code
.
replace
(
'nn.Softmax(dim=1)(input)'
,
'{}(axis=1)(input)'
.
format
(
expected_ms_api_name
))
assert
new_start
==
start
+
len
(
expected_ms_api_name
)
replaced_code
=
self
.
converter_ins
.
convert_api
(
code
)
assert
replaced_code
==
code
.
replace
(
'nn.Softmax(dim=1)'
,
'{}(axis=1)'
.
format
(
expected_ms_api_name
))
# test convert_api with torch dot ops
def
test_convert_api_torch_dot_abs
(
self
):
"""Test convert_api function work ok when convert api torch.abs"""
code
=
"""
torch.abs(input)
"""
api_name
=
'torch.abs'
start
=
code
.
find
(
api_name
)
code
=
"torch.abs(input)"
expected_ms_api_name
=
'P.Abs'
replaced_code
,
new_start
=
self
.
converter_ins
.
convert_api
(
code
,
start
,
api_nam
e
)
replaced_code
=
self
.
converter_ins
.
convert_api
(
cod
e
)
assert
replaced_code
==
code
.
replace
(
'torch.abs(input)'
,
'{}()(input)'
.
format
(
expected_ms_api_name
))
assert
new_start
==
start
+
len
(
expected_ms_api_name
)
def
test_convert_api_torch_dot_acos
(
self
):
"""Test convert_api function work ok when convert api torch.acos"""
code
=
"""
torch.acos(input)
"""
api_name
=
'torch.acos'
start
=
code
.
find
(
api_name
)
code
=
"torch.acos(input)"
expected_ms_api_name
=
'P.ACos'
replaced_code
,
new_start
=
self
.
converter_ins
.
convert_api
(
code
,
start
,
api_nam
e
)
replaced_code
=
self
.
converter_ins
.
convert_api
(
cod
e
)
assert
replaced_code
==
code
.
replace
(
'torch.acos(input)'
,
'{}()(input)'
.
format
(
expected_ms_api_name
))
assert
new_start
==
start
+
len
(
expected_ms_api_name
)
def
test_convert_api_torch_dot_cos
(
self
):
"""Test convert_api function work ok when convert api torch.cos"""
code
=
"""
torch.cos(input)
"""
api_name
=
'torch.cos'
code
=
"torch.cos(input)"
expected_ms_api_name
=
'P.Cos'
start
=
code
.
find
(
api_name
)
replaced_code
,
new_start
=
self
.
converter_ins
.
convert_api
(
code
,
start
,
api_nam
e
)
replaced_code
=
self
.
converter_ins
.
convert_api
(
cod
e
)
assert
replaced_code
==
code
.
replace
(
'torch.cos(input)'
,
'{}()(input)'
.
format
(
expected_ms_api_name
))
assert
new_start
==
start
+
len
(
expected_ms_api_name
)
def
test_convert_api_torch_dot_exp
(
self
):
"""Test convert_api function work ok when convert api torch.exp"""
code
=
"""
torch.exp(input)
"""
api_name
=
'torch.exp'
code
=
"torch.exp(input)"
expected_ms_api_name
=
'P.Exp'
start
=
code
.
find
(
api_name
)
replaced_code
,
new_start
=
self
.
converter_ins
.
convert_api
(
code
,
start
,
api_nam
e
)
replaced_code
=
self
.
converter_ins
.
convert_api
(
cod
e
)
assert
replaced_code
==
code
.
replace
(
'torch.exp(input)'
,
'{}()(input)'
.
format
(
expected_ms_api_name
))
assert
new_start
==
start
+
len
(
expected_ms_api_name
)
def
test_convert_api_torch_dot_log
(
self
):
"""Test convert_api function work ok when convert api torch.log"""
code
=
"""
torch.log(input)
"""
api_name
=
'torch.log'
code
=
"torch.log(input)"
expected_ms_api_name
=
'P.Log'
start
=
code
.
find
(
api_name
)
replaced_code
,
new_start
=
self
.
converter_ins
.
convert_api
(
code
,
start
,
api_nam
e
)
replaced_code
=
self
.
converter_ins
.
convert_api
(
cod
e
)
assert
replaced_code
==
code
.
replace
(
'torch.log(input)'
,
'{}()(input)'
.
format
(
expected_ms_api_name
))
assert
new_start
==
start
+
len
(
expected_ms_api_name
)
def
test_convert_api_torch_dot_pow
(
self
):
"""Test convert_api function work ok when convert api torch.pow"""
code
=
"""
torch.pow(a, exp)
"""
api_name
=
'torch.pow'
code
=
"torch.pow(a, exp)"
expected_ms_api_name
=
'P.Pow'
start
=
code
.
find
(
api_name
)
replaced_code
,
new_start
=
self
.
converter_ins
.
convert_api
(
code
,
start
,
api_nam
e
)
replaced_code
=
self
.
converter_ins
.
convert_api
(
cod
e
)
assert
replaced_code
==
code
.
replace
(
'torch.pow(a, exp)'
,
'{}()(a, exp)'
.
format
(
expected_ms_api_name
))
assert
new_start
==
start
+
len
(
expected_ms_api_name
)
def
test_convert_api_torch_dot_div
(
self
):
"""Test convert_api function work ok when convert api torch.div"""
code
=
"""
input = torch.randn(5)
other = torch.randn(5)
torch.div(input, other)
"""
api_name
=
'torch.div'
code
=
"torch.div(input, other)"
expected_ms_api_name
=
'P.Div'
start
=
code
.
find
(
api_name
)
replaced_code
,
new_start
=
self
.
converter_ins
.
convert_api
(
code
,
start
,
api_nam
e
)
replaced_code
=
self
.
converter_ins
.
convert_api
(
cod
e
)
assert
replaced_code
==
code
.
replace
(
'torch.div(input, other)'
,
'{}()(input, other)'
.
format
(
expected_ms_api_name
))
assert
new_start
==
start
+
len
(
expected_ms_api_name
)
def
test_convert_api_torch_dot_sin
(
self
):
"""Test convert_api function work ok when convert api torch.sin"""
code
=
"""
torch.sin(input)
"""
api_name
=
'torch.sin'
code
=
"torch.sin(input)"
expected_ms_api_name
=
'P.Sin'
start
=
code
.
find
(
api_name
)
replaced_code
,
new_start
=
self
.
converter_ins
.
convert_api
(
code
,
start
,
api_nam
e
)
replaced_code
=
self
.
converter_ins
.
convert_api
(
cod
e
)
assert
replaced_code
==
code
.
replace
(
'torch.sin(input)'
,
'{}()(input)'
.
format
(
expected_ms_api_name
))
assert
new_start
==
start
+
len
(
expected_ms_api_name
)
def
test_convert_api_torch_dot_sqrt
(
self
):
"""Test convert_api function work ok when convert api torch.sqrt"""
code
=
"""
torch.sqrt(input)
"""
api_name
=
'torch.sqrt'
code
=
"torch.sqrt(input)"
expected_ms_api_name
=
'P.Sqrt'
start
=
code
.
find
(
api_name
)
replaced_code
,
new_start
=
self
.
converter_ins
.
convert_api
(
code
,
start
,
api_nam
e
)
replaced_code
=
self
.
converter_ins
.
convert_api
(
cod
e
)
assert
replaced_code
==
code
.
replace
(
'torch.sqrt(input)'
,
'{}()(input)'
.
format
(
expected_ms_api_name
))
assert
new_start
==
start
+
len
(
expected_ms_api_name
)
def
test_convert_api_torch_dot_eye_with_n
(
self
):
"""Test convert_api function work ok when convert api torch.eye"""
code
=
"""
torch.eye(3)
"""
api_name
=
'torch.eye'
code
=
"torch.eye(3)"
expected_ms_api_name
=
'P.Eye'
start
=
code
.
find
(
api_name
)
replaced_code
,
new_start
=
self
.
converter_ins
.
convert_api
(
code
,
start
,
api_nam
e
)
replaced_code
=
self
.
converter_ins
.
convert_api
(
cod
e
)
assert
replaced_code
==
code
.
replace
(
'torch.eye(3)'
,
'{}()(3, 3, mindspore.int32)'
.
format
(
expected_ms_api_name
))
assert
new_start
==
start
+
len
(
expected_ms_api_name
)
def
test_convert_api_torch_dot_eye_with_m
(
self
):
"""Test convert_api function work ok when convert api torch.eye"""
code
=
"""
torch.eye(3, 4)
"""
api_name
=
'torch.eye'
code
=
"torch.eye(3, 4)"
expected_ms_api_name
=
'P.Eye'
start
=
code
.
find
(
api_name
)
replaced_code
,
new_start
=
self
.
converter_ins
.
convert_api
(
code
,
start
,
api_nam
e
)
replaced_code
=
self
.
converter_ins
.
convert_api
(
cod
e
)
assert
replaced_code
==
code
.
replace
(
'torch.eye(3, 4)'
,
'{}()(3, 4, mindspore.int32)'
.
format
(
expected_ms_api_name
))
assert
new_start
==
start
+
len
(
expected_ms_api_name
)
def
test_convert_api_torch_dot_add_with_alpha_default
(
self
):
"""Test convert_api function work ok when convert api torch.add"""
code
=
"""
torch.add(input, value)
"""
api_name
=
'torch.add'
code
=
"torch.add(input, value)"
expected_ms_api_name
=
'P.TensorAdd'
start
=
code
.
find
(
api_name
)
replaced_code
,
new_start
=
self
.
converter_ins
.
convert_api
(
code
,
start
,
api_nam
e
)
replaced_code
=
self
.
converter_ins
.
convert_api
(
cod
e
)
assert
replaced_code
==
code
.
replace
(
'torch.add(input, value)'
,
'{}()(input, value)'
.
format
(
expected_ms_api_name
))
assert
new_start
==
start
+
len
(
expected_ms_api_name
)
def
test_convert_api_torch_dot_add_with_alpha_not_default
(
self
):
"""Test convert_api function work ok when convert api torch.add"""
code
=
"""
torch.add(input, value, 3)
"""
api_name
=
'torch.add'
code
=
"torch.add(input, value, 3)"
expected_ms_api_name
=
'P.TensorAdd'
start
=
code
.
find
(
api_name
)
replaced_code
,
new_start
=
self
.
converter_ins
.
convert_api
(
code
,
start
,
api_nam
e
)
replaced_code
=
self
.
converter_ins
.
convert_api
(
cod
e
)
assert
replaced_code
==
code
.
replace
(
'torch.add(input, value, 3)'
,
'{}()(input, value*3)'
.
format
(
expected_ms_api_name
))
assert
new_start
==
start
+
len
(
expected_ms_api_name
)
# test convert_api with F ops
def
test_convert_api_f_normalize
(
self
):
"""Test convert_api function work ok when convert api F.normalize"""
code
=
"""
input = torch.randn(2, 3, 5)
F.normalize(input)
"""
api_name
=
'F.normalize'
start
=
code
.
find
(
api_name
)
code
=
"F.normalize(input)"
expected_ms_api_name
=
'P.L2Normalize'
replaced_code
,
new_start
=
self
.
converter_ins
.
convert_api
(
code
,
start
,
api_nam
e
)
replaced_code
=
self
.
converter_ins
.
convert_api
(
cod
e
)
assert
replaced_code
==
code
.
replace
(
'F.normalize(input)'
,
'{}(1, 1e-12)(input)'
.
format
(
expected_ms_api_name
))
assert
new_start
==
start
+
len
(
expected_ms_api_name
)
def
test_convert_api_f_sigmoid
(
self
):
"""Test convert_api function work ok when convert api F.sigmoid"""
code
=
"""
input = torch.randn(2, 3, 5)
F.sigmoid(input)
"""
api_name
=
'F.sigmoid'
start
=
code
.
find
(
api_name
)
code
=
"F.sigmoid(input)"
expected_ms_api_name
=
'P.Sigmoid'
replaced_code
,
new_start
=
self
.
converter_ins
.
convert_api
(
code
,
start
,
api_nam
e
)
replaced_code
=
self
.
converter_ins
.
convert_api
(
cod
e
)
assert
replaced_code
==
code
.
replace
(
'F.sigmoid(input)'
,
'{}()(input)'
.
format
(
expected_ms_api_name
))
assert
new_start
==
start
+
len
(
expected_ms_api_name
)
# test convert_api with tensor dot ops
def
test_convert_api_tensor_dot_repeat
(
self
):
"""Test convert_api function work ok when convert api .repeat"""
code
=
"""
x.repeat(4, 2)
"""
api_name
=
'.repeat'
start
=
code
.
find
(
api_name
)
code
=
"x.repeat(4, 2)"
expected_ms_api_name
=
'P.Tile'
replaced_code
,
new_start
=
self
.
converter_ins
.
convert_api
(
code
,
start
,
api_nam
e
)
replaced_code
=
self
.
converter_ins
.
convert_api
(
cod
e
)
assert
replaced_code
==
code
.
replace
(
'x.repeat(4, 2)'
,
'{}()(x, {})'
.
format
(
expected_ms_api_name
,
'(4, 2,)'
))
assert
new_start
==
start
+
len
(
expected_ms_api_name
)
def
test_convert_api_tensor_dot_permute
(
self
):
"""Test convert_api function work ok when convert api .permute"""
code
=
"""
x.permute(2, 0, 1)
"""
api_name
=
'.permute'
start
=
code
.
find
(
api_name
)
code
=
"x.permute(2, 0, 1)"
expected_ms_api_name
=
'P.Transpose'
replaced_code
,
new_start
=
self
.
converter_ins
.
convert_api
(
code
,
start
,
api_nam
e
)
replaced_code
=
self
.
converter_ins
.
convert_api
(
cod
e
)
assert
replaced_code
==
code
.
replace
(
'x.permute(2, 0, 1)'
,
'{}()(x, (2, 0, 1,))'
.
format
(
expected_ms_api_name
))
assert
new_start
==
start
+
len
(
expected_ms_api_name
)
tests/ut/mindconverter/test_forward_call.py
浏览文件 @
bcdc61cc
...
...
@@ -15,7 +15,6 @@
"""Test forward_call module."""
import
ast
import
textwrap
from
unittest.mock
import
patch
from
mindinsight.mindconverter.forward_call
import
ForwardCall
...
...
@@ -50,12 +49,10 @@ class TestForwardCall:
return out
"""
)
@
patch
.
object
(
ForwardCall
,
'process'
)
def
test_process
(
self
,
mock_process
):
def
test_process
(
self
):
"""Test the function of visit ast tree to find out forward functions."""
mock_process
.
return_value
=
None
forward_call
=
ForwardCall
(
"mock"
)
forward_call
.
visit
(
ast
.
parse
(
self
.
source
))
ast_tree
=
ast
.
parse
(
self
.
source
)
forward_call
=
ForwardCall
(
ast_tree
)
expect_calls
=
[
'TestNet.forward'
,
'TestNet.forward1'
,
...
...
@@ -70,6 +67,6 @@ class TestForwardCall:
'TestNet.fc3'
,
]
expect_calls
.
sort
()
real_calls
=
list
(
forward_call
.
calls
)
real_calls
=
list
(
forward_call
.
calls
.
keys
()
)
real_calls
.
sort
()
assert
real_calls
==
expect_calls
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录