Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MindSpore
mindinsight
提交
f1f3dbc4
M
mindinsight
项目概览
MindSpore
/
mindinsight
通知
8
Star
4
Fork
2
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindinsight
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
f1f3dbc4
编写于
6月 19, 2020
作者:
M
mindspore-ci-bot
提交者:
Gitee
6月 19, 2020
浏览文件
操作
浏览文件
下载
差异文件
!306 Converter: Modify the prompt message and parse more statements.
Merge pull request !306 from ggpolar/br_wzk_dev
上级
ea0c9d1f
dde197ce
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
146 addition
and
30 deletion
+146
-30
mindinsight/mindconverter/ast_edits.py
mindinsight/mindconverter/ast_edits.py
+130
-17
mindinsight/mindconverter/common/exceptions.py
mindinsight/mindconverter/common/exceptions.py
+10
-0
mindinsight/mindconverter/config.py
mindinsight/mindconverter/config.py
+6
-13
未找到文件。
mindinsight/mindconverter/ast_edits.py
浏览文件 @
f1f3dbc4
...
@@ -28,13 +28,13 @@ from mindinsight.mindconverter.config import ALL_MAPPING
...
@@ -28,13 +28,13 @@ from mindinsight.mindconverter.config import ALL_MAPPING
from
mindinsight.mindconverter.config
import
NN_LIST
from
mindinsight.mindconverter.config
import
NN_LIST
from
mindinsight.mindconverter.config
import
ALL_TORCH_APIS
from
mindinsight.mindconverter.config
import
ALL_TORCH_APIS
from
mindinsight.mindconverter.config
import
ALL_2P_LIST
from
mindinsight.mindconverter.config
import
ALL_2P_LIST
from
mindinsight.mindconverter.config
import
UNSUPPORTED_WARN_INFOS
from
mindinsight.mindconverter.config
import
get_prompt_info
from
mindinsight.mindconverter.config
import
ALL_UNSUPPORTED
from
mindinsight.mindconverter.common.log
import
logger
from
mindinsight.mindconverter.common.log
import
logger
from
mindinsight.mindconverter.common.exceptions
import
NodeTypeNotSupport
from
mindinsight.mindconverter.common.exceptions
import
NodeTypeNotSupport
from
mindinsight.mindconverter.forward_call
import
ForwardCall
from
mindinsight.mindconverter.forward_call
import
ForwardCall
LOG_FMT_CONVERT
=
"[Convert] '%s' is converted to '%s'."
LOG_FMT_CONVERT
=
"[Convert] '%s' is converted to '%s'."
LOG_FMT_CONVERT_WITH_TIPS
=
"[Convert] '%s' is converted to '%s'. %s"
LOG_FMT_NOT_CONVERT
=
"[UnConvert] '%s' didn't convert. %s"
LOG_FMT_NOT_CONVERT
=
"[UnConvert] '%s' didn't convert. %s"
LOG_FMT_PROMPT_INFO
=
"[INFO] %s"
LOG_FMT_PROMPT_INFO
=
"[INFO] %s"
LOG_SUGGESTION_MANUAL_CONVERT
=
"Please manual convert the code, along with the code associated with it."
LOG_SUGGESTION_MANUAL_CONVERT
=
"Please manual convert the code, along with the code associated with it."
...
@@ -95,6 +95,7 @@ class _LineColEditVisitor(ast.NodeVisitor):
...
@@ -95,6 +95,7 @@ class _LineColEditVisitor(ast.NodeVisitor):
class
_NodeInfo
:
class
_NodeInfo
:
"""NodeInfo class definition."""
"""NodeInfo class definition."""
def
__init__
(
self
,
node
):
def
__init__
(
self
,
node
):
self
.
node
=
node
self
.
node
=
node
self
.
call_list
=
[]
# Used to save all ast.Call node in self._node
self
.
call_list
=
[]
# Used to save all ast.Call node in self._node
...
@@ -444,19 +445,25 @@ class AstEditVisitor(ast.NodeVisitor):
...
@@ -444,19 +445,25 @@ class AstEditVisitor(ast.NodeVisitor):
is_include_call
=
False
is_include_call
=
False
return
is_include_call
return
is_include_call
def
match_api
(
self
,
call_func_node
,
is_forward
):
def
match_api
(
self
,
call_func_node
,
is_forward
,
check_context
=
True
):
"""
"""
Check api name to convert, check api name ok with a is_forward condition.
Check api name to convert, check api name ok with a is_forward condition.
Args:
Args:
call_func_node (ast.Attribute): The call.func node.
call_func_node (ast.Attribute): The call.func node.
is_forward (bool): whether api belong to forward.
is_forward (bool): whether api belong to forward.
check_context (boolean): If True, the code context will be checked. Default is True.
Returns:
Returns:
str, the standard api name used to match.
str, the standard api name used to match.
ApiMappingEnum, the match result.
ApiMappingEnum, the match result.
"""
"""
api_name
,
match_case
=
self
.
_infer_api_name
(
call_func_node
)
match_case
=
ApiMatchingEnum
.
NOT_API
api_call_name
=
pasta
.
dump
(
call_func_node
)
if
api_call_name
.
startswith
(
'self.'
):
return
api_call_name
,
match_case
api_name
,
match_case
=
self
.
_infer_api_name
(
call_func_node
,
check_context
)
api_call_name
=
pasta
.
dump
(
call_func_node
)
api_call_name
=
pasta
.
dump
(
call_func_node
)
is_tensor_obj_call
=
False
is_tensor_obj_call
=
False
if
api_name
!=
api_call_name
:
if
api_name
!=
api_call_name
:
...
@@ -466,15 +473,17 @@ class AstEditVisitor(ast.NodeVisitor):
...
@@ -466,15 +473,17 @@ class AstEditVisitor(ast.NodeVisitor):
# rewritten external module name
# rewritten external module name
# e.g., mm.ReLU will be written to nn.ReLU if 'import torch.nn as mm' in script.
# e.g., mm.ReLU will be written to nn.ReLU if 'import torch.nn as mm' in script.
if
not
is_tensor_obj_call
and
not
self
.
_code_analyzer
.
is_standard_external_ref
:
if
not
is_tensor_obj_call
:
standard_api_call_name
=
self
.
_
mapping_standard_api_name
(
api_name
)
standard_api_call_name
=
self
.
_
get_api_whole_name
(
call_func_node
,
check_context
)
if
standard_api_call_name
in
ALL_TORCH_APIS
:
if
standard_api_call_name
in
ALL_TORCH_APIS
:
match_case
=
ApiMatchingEnum
.
API_FOUND
match_case
=
ApiMatchingEnum
.
API_FOUND
if
(
not
is_forward
and
standard_api_call_name
in
NN_LIST
)
or
\
if
(
not
is_forward
and
standard_api_call_name
in
NN_LIST
)
or
\
(
is_forward
and
standard_api_call_name
in
ALL_2P_LIST
):
(
is_forward
and
standard_api_call_name
in
ALL_2P_LIST
):
match_case
=
ApiMatchingEnum
.
API_MATCHED
match_case
=
ApiMatchingEnum
.
API_MATCHED
else
:
if
standard_api_call_name
and
standard_api_call_name
.
startswith
(
'torch.nn.init'
):
match_case
=
ApiMatchingEnum
.
API_MATCHED
return
standard_api_call_name
,
match_case
return
standard_api_call_name
,
match_case
@
staticmethod
@
staticmethod
...
@@ -502,6 +511,25 @@ class AstEditVisitor(ast.NodeVisitor):
...
@@ -502,6 +511,25 @@ class AstEditVisitor(ast.NodeVisitor):
parameters_str
=
call_str
[
left_parenthesis_pos
+
1
:
right_parenthesis_pos
]
parameters_str
=
call_str
[
left_parenthesis_pos
+
1
:
right_parenthesis_pos
]
return
parameters_str
return
parameters_str
def
_get_api_whole_name
(
self
,
call_func_node
,
check_context
=
True
):
"""
Get the whole name for the call node.
Args:
call_func_node (AST): The func attribute of ast.Call.
check_context (boolean): If True, the code context will be checked. Default is True.
Returns:
str, the whole name.
"""
api_name
,
match_case
=
self
.
_infer_api_name
(
call_func_node
,
check_context
)
if
match_case
==
ApiMatchingEnum
.
API_STANDARD
:
api_name_splits
=
api_name
.
split
(
'.'
)
api_name_splits
[
0
]
=
self
.
_get_external_ref_whole_name
(
api_name_splits
[
0
])
if
api_name_splits
[
0
]:
api_name
=
'.'
.
join
(
api_name_splits
)
return
api_name
def
mapping_api
(
self
,
call_node
,
check_context
=
True
):
def
mapping_api
(
self
,
call_node
,
check_context
=
True
):
"""
"""
Convert api_name in code to MindSpore api, if api_name is a python api, code will not convert.
Convert api_name in code to MindSpore api, if api_name is a python api, code will not convert.
...
@@ -522,6 +550,26 @@ class AstEditVisitor(ast.NodeVisitor):
...
@@ -522,6 +550,26 @@ class AstEditVisitor(ast.NodeVisitor):
if
api_call_name
.
startswith
(
'self.'
):
if
api_call_name
.
startswith
(
'self.'
):
return
code
return
code
new_code
=
self
.
_mapping_api
(
call_node
,
check_context
)
return
new_code
def
_mapping_api
(
self
,
call_node
,
check_context
=
True
):
"""
Convert api_name in code to MindSpore api, if api_name is a python api, code will not convert.
If do not check context of the script, the code represented by the node must be written in the standard way.
Args:
call_node (ast.Call): The ast node to convert.
check_context (boolean): If True, the code context will be checked. Default is True.
Returns:
str, the converted code.
"""
code
=
pasta
.
dump
(
call_node
)
api_call_name
=
pasta
.
dump
(
call_node
.
func
)
# find full api expected to be converted. eg:expr="nn.Conv2d(1,2,3)" args_str="(1,2,3)"
# find full api expected to be converted. eg:expr="nn.Conv2d(1,2,3)" args_str="(1,2,3)"
args_str
=
'('
+
self
.
_get_call_parameters_str
(
call_node
)
+
')'
args_str
=
'('
+
self
.
_get_call_parameters_str
(
call_node
)
+
')'
...
@@ -551,26 +599,37 @@ class AstEditVisitor(ast.NodeVisitor):
...
@@ -551,26 +599,37 @@ class AstEditVisitor(ast.NodeVisitor):
code
=
pasta
.
dump
(
node
)
code
=
pasta
.
dump
(
node
)
api_name
=
pasta
.
dump
(
node
.
func
)
api_name
=
pasta
.
dump
(
node
.
func
)
# parent node first call is equal to this node, skip when parent node is replaced.
# The parent node first call is equal to this node, skip when parent node is replaced.
for
parent_node
in
self
.
_stack
[:
-
1
]:
# This scenario occurs, for example, when out.view(out.size(0), -1) is first converted to
# P.Reshape()(out, (out.size(0). -1)), will skip P.Reshape() in following visiting.
# Access from the penultimate element in reverse order.
for
parent_node
in
self
.
_stack
[
-
2
::
-
1
]:
if
parent_node
in
self
.
_new_call_nodes
and
pasta
.
dump
(
parent_node
).
startswith
(
api_name
):
if
parent_node
in
self
.
_new_call_nodes
and
pasta
.
dump
(
parent_node
).
startswith
(
api_name
):
return
return
parent
=
self
.
_stack
[
-
2
]
parent
=
self
.
_stack
[
-
2
]
new_node
=
None
new_node
=
None
new_code
=
code
matched_api_name
,
match_case
=
self
.
match_api
(
node
.
func
,
self
.
_is_forward_function
)
matched_api_name
,
match_case
=
self
.
match_api
(
node
.
func
,
self
.
_is_forward_function
)
if
match_case
in
[
ApiMatchingEnum
.
API_INFER
,
ApiMatchingEnum
.
API_MATCHED
]:
if
match_case
in
[
ApiMatchingEnum
.
API_INFER
,
ApiMatchingEnum
.
API_MATCHED
]:
warning_info
=
get_prompt_info
(
matched_api_name
)
if
warning_info
is
None
:
warning_info
=
''
if
matched_api_name
in
ALL_MAPPING
:
if
matched_api_name
in
ALL_MAPPING
:
logger
.
info
(
"Line %3d start converting API: %s"
,
node
.
lineno
,
api_name
)
logger
.
info
(
"Line %3d start converting API: %s"
,
node
.
lineno
,
api_name
)
new_code
=
self
.
mapping_api
(
node
)
new_code
=
self
.
mapping_api
(
node
)
if
new_code
!=
code
:
if
new_code
!=
code
:
new_node
=
pasta
.
parse
(
new_code
).
body
[
0
].
value
try
:
# find the first call name
new_node
=
pasta
.
parse
(
new_code
).
body
[
0
].
value
new_api_name
=
new_code
[:
new_code
.
find
(
'('
)]
# find the first call name
self
.
_process_log
.
info
(
node
.
lineno
,
node
.
col_offset
,
LOG_FMT_CONVERT
%
(
api_name
,
new_api_name
))
new_api_name
=
new_code
[:
new_code
.
find
(
'('
)]
if
matched_api_name
in
ALL_UNSUPPORTED
:
except
AttributeError
:
warn_info
=
UNSUPPORTED_WARN_INFOS
.
get
(
api_name
,
''
)
new_node
=
pasta
.
parse
(
new_code
).
body
[
0
]
logger
.
warning
(
"Line %3d: found unsupported API: %s%s"
,
node
.
lineno
,
api_name
,
warn_info
)
new_api_name
=
new_code
self
.
_process_log
.
warning
(
node
.
lineno
,
node
.
col_offset
,
LOG_FMT_NOT_CONVERT
%
(
api_name
,
warn_info
))
self
.
_process_log
.
info
(
node
.
lineno
,
node
.
col_offset
,
LOG_FMT_CONVERT_WITH_TIPS
%
(
api_name
,
new_api_name
,
warning_info
))
else
:
logger
.
warning
(
"Line %3d: found unsupported API: %s%s"
,
node
.
lineno
,
api_name
,
warning_info
)
self
.
_process_log
.
warning
(
node
.
lineno
,
node
.
col_offset
,
LOG_FMT_NOT_CONVERT
%
(
api_name
,
warning_info
))
elif
match_case
in
[
ApiMatchingEnum
.
API_STANDARD
,
ApiMatchingEnum
.
API_FOUND
]:
elif
match_case
in
[
ApiMatchingEnum
.
API_STANDARD
,
ApiMatchingEnum
.
API_FOUND
]:
self
.
_process_log
.
warning
(
node
.
lineno
,
node
.
col_offset
,
LOG_FMT_NOT_CONVERT
%
(
api_name
,
''
))
self
.
_process_log
.
warning
(
node
.
lineno
,
node
.
col_offset
,
LOG_FMT_NOT_CONVERT
%
(
api_name
,
''
))
...
@@ -602,3 +661,57 @@ class AstEditVisitor(ast.NodeVisitor):
...
@@ -602,3 +661,57 @@ class AstEditVisitor(ast.NodeVisitor):
elif
ref_name
!=
'F'
and
external_ref_info
.
name
==
'torch.nn.functional'
:
elif
ref_name
!=
'F'
and
external_ref_info
.
name
==
'torch.nn.functional'
:
renames
[
ref_name
]
=
'F'
renames
[
ref_name
]
=
'F'
return
renames
return
renames
def
_get_external_ref_whole_name
(
self
,
ref_name
):
"""
Find out external reference whole name.
For example:
In the parsed source code, there is import statement
import torch.nn as new_name
_get_external_ref_whole_name('new_name') will return 'torch.nn' string.
"""
external_refs
=
self
.
_code_analyzer
.
external_references
for
external_ref_name
,
ref_info
in
external_refs
.
items
():
external_ref_info
=
ref_info
[
'external_ref_info'
]
if
external_ref_name
==
ref_name
:
return
external_ref_info
.
name
return
None
def
_check_isinstance_parameter
(
self
,
node
):
"""Check whether the second parameter of isinstance function contains the torch type."""
is_isinstance_arg
=
False
# Check whether node is the second parameter of the isinstance function call.
# Access from the penultimate element in reverse order.
for
parent_node
in
self
.
_stack
[
-
2
::
-
1
]:
if
isinstance
(
parent_node
,
ast
.
Call
)
and
pasta
.
dump
(
parent_node
.
func
)
==
'isinstance'
:
isinstance_node
=
parent_node
seconde_arg_type_nodes
=
[]
if
isinstance
(
isinstance_node
.
args
[
1
],
ast
.
Tuple
):
seconde_arg_type_nodes
.
extend
(
isinstance_node
.
args
[
1
].
elts
)
else
:
seconde_arg_type_nodes
.
append
(
isinstance_node
.
args
[
1
])
if
node
in
seconde_arg_type_nodes
:
is_isinstance_arg
=
True
break
if
not
is_isinstance_arg
:
return
False
isinstance_type_arg
=
pasta
.
dump
(
node
)
check_torch_type
=
False
if
isinstance_type_arg
:
type_splits
=
isinstance_type_arg
.
split
(
'.'
)
whole_name
=
self
.
_get_external_ref_whole_name
(
type_splits
[
0
])
if
whole_name
and
whole_name
.
startswith
(
'torch'
):
check_torch_type
=
True
if
check_torch_type
:
_
,
match_case
=
self
.
match_api
(
node
,
False
)
if
match_case
!=
ApiMatchingEnum
.
NOT_API
:
warn_info
=
'Manually determine the conversion type.'
self
.
_process_log
.
warning
(
node
.
lineno
,
node
.
col_offset
,
LOG_FMT_NOT_CONVERT
%
(
isinstance_type_arg
,
warn_info
))
return
check_torch_type
def
visit_Attribute
(
self
,
node
):
"""Callback function when visit AST tree"""
self
.
_check_isinstance_parameter
(
node
)
mindinsight/mindconverter/common/exceptions.py
浏览文件 @
f1f3dbc4
...
@@ -24,6 +24,7 @@ class ConverterErrors(ScriptConverterErrors):
...
@@ -24,6 +24,7 @@ class ConverterErrors(ScriptConverterErrors):
"""Converter error codes."""
"""Converter error codes."""
SCRIPT_NOT_SUPPORT
=
1
SCRIPT_NOT_SUPPORT
=
1
NODE_TYPE_NOT_SUPPORT
=
2
NODE_TYPE_NOT_SUPPORT
=
2
CODE_SYNTAX_ERROR
=
3
class
ScriptNotSupport
(
MindInsightException
):
class
ScriptNotSupport
(
MindInsightException
):
...
@@ -42,3 +43,12 @@ class NodeTypeNotSupport(MindInsightException):
...
@@ -42,3 +43,12 @@ class NodeTypeNotSupport(MindInsightException):
super
(
NodeTypeNotSupport
,
self
).
__init__
(
ConverterErrors
.
NODE_TYPE_NOT_SUPPORT
,
super
(
NodeTypeNotSupport
,
self
).
__init__
(
ConverterErrors
.
NODE_TYPE_NOT_SUPPORT
,
msg
,
msg
,
http_code
=
400
)
http_code
=
400
)
class
CodeSyntaxError
(
MindInsightException
):
"""The CodeSyntaxError class definition."""
def
__init__
(
self
,
msg
):
super
(
CodeSyntaxError
,
self
).
__init__
(
ConverterErrors
.
CODE_SYNTAX_ERROR
,
msg
,
http_code
=
400
)
mindinsight/mindconverter/config.py
浏览文件 @
f1f3dbc4
...
@@ -22,7 +22,7 @@ import os
...
@@ -22,7 +22,7 @@ import os
import
pasta
import
pasta
from
mindinsight.mindconverter.common.log
import
logger
from
mindinsight.mindconverter.common.log
import
logger
from
mindinsight.mindconverter.common.exceptions
import
CodeSyntaxError
REQUIRED
=
'REQUIRED'
REQUIRED
=
'REQUIRED'
UNREQUIRED
=
'UNREQUIRED'
UNREQUIRED
=
'UNREQUIRED'
...
@@ -31,6 +31,7 @@ FUNC_MODULE = 'mindinsight.mindconverter.funcs'
...
@@ -31,6 +31,7 @@ FUNC_MODULE = 'mindinsight.mindconverter.funcs'
class
APIPt
:
class
APIPt
:
"""Base API for args parse, and API for one frame."""
"""Base API for args parse, and API for one frame."""
def
__init__
(
self
,
name
:
str
,
params
:
OrderedDict
):
def
__init__
(
self
,
name
:
str
,
params
:
OrderedDict
):
self
.
name
=
name
self
.
name
=
name
self
.
params
=
OrderedDict
()
self
.
params
=
OrderedDict
()
...
@@ -77,10 +78,8 @@ class APIPt:
...
@@ -77,10 +78,8 @@ class APIPt:
try
:
try
:
ast_node
=
ast
.
parse
(
"whatever_call_name"
+
args_str
)
ast_node
=
ast
.
parse
(
"whatever_call_name"
+
args_str
)
call_node
=
ast_node
.
body
[
0
].
value
call_node
=
ast_node
.
body
[
0
].
value
if
not
isinstance
(
call_node
,
ast
.
Call
):
except
SyntaxError
as
parse_error
:
raise
ValueError
(
'call name with args str [{}] not instance of ast.Call'
.
format
(
args_str
))
raise
CodeSyntaxError
(
"can't parse code:
\n
{}"
.
format
(
args_str
))
from
parse_error
except
:
raise
ValueError
(
"can't parse code:
\n
{}"
.
format
(
args_str
))
# regard all actual parameter as one parameter
# regard all actual parameter as one parameter
if
len
(
self
.
params
)
==
1
:
if
len
(
self
.
params
)
==
1
:
...
@@ -118,6 +117,7 @@ class APIPt:
...
@@ -118,6 +117,7 @@ class APIPt:
class
APIMs
(
APIPt
):
class
APIMs
(
APIPt
):
"""API for MindSpore"""
"""API for MindSpore"""
def
__init__
(
self
,
name
:
str
,
params
:
OrderedDict
,
p_attrs
=
None
):
def
__init__
(
self
,
name
:
str
,
params
:
OrderedDict
,
p_attrs
=
None
):
self
.
is_primitive
=
name
.
startswith
(
'P.'
)
self
.
is_primitive
=
name
.
startswith
(
'P.'
)
if
self
.
is_primitive
:
if
self
.
is_primitive
:
...
@@ -167,6 +167,7 @@ class APIMs(APIPt):
...
@@ -167,6 +167,7 @@ class APIMs(APIPt):
class
MappingHelper
:
class
MappingHelper
:
"""Mapping from one frame to another frame"""
"""Mapping from one frame to another frame"""
def
__init__
(
self
,
ms_api
:
APIMs
,
pt_api
:
APIPt
,
**
kwargs
):
def
__init__
(
self
,
ms_api
:
APIMs
,
pt_api
:
APIPt
,
**
kwargs
):
ms2pt_mapping
=
kwargs
.
get
(
'ms2pt_mapping'
)
ms2pt_mapping
=
kwargs
.
get
(
'ms2pt_mapping'
)
gen_explicit_map
=
kwargs
.
get
(
'gen_explicit_map'
)
gen_explicit_map
=
kwargs
.
get
(
'gen_explicit_map'
)
...
@@ -392,7 +393,6 @@ TENSOR_DOT_MAPPING = get_mapping_from_file(TENSOR_DOT_MAPPING_PATH)
...
@@ -392,7 +393,6 @@ TENSOR_DOT_MAPPING = get_mapping_from_file(TENSOR_DOT_MAPPING_PATH)
ALL_MAPPING
=
{
**
NN_MAPPING
,
**
F_MAPPING
,
**
TORCH_DOT_MAPPING
,
**
TENSOR_DOT_MAPPING
}
ALL_MAPPING
=
{
**
NN_MAPPING
,
**
F_MAPPING
,
**
TORCH_DOT_MAPPING
,
**
TENSOR_DOT_MAPPING
}
# ---------------------------- api list support or not support ----------------------------
# ---------------------------- api list support or not support ----------------------------
NN_LIST_PATH
=
os
.
path
.
realpath
(
os
.
path
.
join
(
os
.
path
.
dirname
(
__file__
),
'ops'
,
'nn_list.json'
))
NN_LIST_PATH
=
os
.
path
.
realpath
(
os
.
path
.
join
(
os
.
path
.
dirname
(
__file__
),
'ops'
,
'nn_list.json'
))
NN_LIST
=
load_json_file
(
NN_LIST_PATH
)
NN_LIST
=
load_json_file
(
NN_LIST_PATH
)
...
@@ -400,7 +400,6 @@ NN_LIST += ["torch." + name for name in NN_LIST]
...
@@ -400,7 +400,6 @@ NN_LIST += ["torch." + name for name in NN_LIST]
NN_SUPPORTED
=
[
x
for
x
in
NN_LIST
if
x
in
ALL_MAPPING
]
NN_SUPPORTED
=
[
x
for
x
in
NN_LIST
if
x
in
ALL_MAPPING
]
NN_UNSUPPORTED
=
[
x
for
x
in
NN_LIST
if
x
not
in
ALL_MAPPING
]
NN_UNSUPPORTED
=
[
x
for
x
in
NN_LIST
if
x
not
in
ALL_MAPPING
]
F_LIST_PATH
=
os
.
path
.
realpath
(
os
.
path
.
join
(
os
.
path
.
dirname
(
__file__
),
'ops'
,
'f_list.json'
))
F_LIST_PATH
=
os
.
path
.
realpath
(
os
.
path
.
join
(
os
.
path
.
dirname
(
__file__
),
'ops'
,
'f_list.json'
))
F_LIST
=
load_json_file
(
F_LIST_PATH
)
F_LIST
=
load_json_file
(
F_LIST_PATH
)
F_LIST
+=
[
"F."
+
name
[
len
(
"torch.nn.functional."
):]
for
name
in
F_LIST
]
+
\
F_LIST
+=
[
"F."
+
name
[
len
(
"torch.nn.functional."
):]
for
name
in
F_LIST
]
+
\
...
@@ -408,29 +407,23 @@ F_LIST += ["F." + name[len("torch.nn.functional."):] for name in F_LIST] + \
...
@@ -408,29 +407,23 @@ F_LIST += ["F." + name[len("torch.nn.functional."):] for name in F_LIST] + \
F_SUPPORTED
=
[
x
for
x
in
F_LIST
if
x
in
ALL_MAPPING
]
F_SUPPORTED
=
[
x
for
x
in
F_LIST
if
x
in
ALL_MAPPING
]
F_UNSUPPORTED
=
[
x
for
x
in
F_LIST
if
x
not
in
ALL_MAPPING
]
F_UNSUPPORTED
=
[
x
for
x
in
F_LIST
if
x
not
in
ALL_MAPPING
]
TORCH_DOT_LIST_PATH
=
os
.
path
.
realpath
(
os
.
path
.
join
(
os
.
path
.
dirname
(
__file__
),
'ops'
,
'torch_dot_list.json'
))
TORCH_DOT_LIST_PATH
=
os
.
path
.
realpath
(
os
.
path
.
join
(
os
.
path
.
dirname
(
__file__
),
'ops'
,
'torch_dot_list.json'
))
TORCH_DOT_LIST
=
load_json_file
(
TORCH_DOT_LIST_PATH
)
TORCH_DOT_LIST
=
load_json_file
(
TORCH_DOT_LIST_PATH
)
TORCH_DOT_SUPPORTED
=
[
x
for
x
in
TORCH_DOT_LIST
if
x
in
ALL_MAPPING
]
TORCH_DOT_SUPPORTED
=
[
x
for
x
in
TORCH_DOT_LIST
if
x
in
ALL_MAPPING
]
TORCH_DOT_UNSUPPORTED
=
[
x
for
x
in
TORCH_DOT_LIST
if
x
not
in
ALL_MAPPING
]
TORCH_DOT_UNSUPPORTED
=
[
x
for
x
in
TORCH_DOT_LIST
if
x
not
in
ALL_MAPPING
]
TENSOR_DOT_LIST_PATH
=
os
.
path
.
realpath
(
os
.
path
.
join
(
os
.
path
.
dirname
(
__file__
),
'ops'
,
'tensor_dot_list.json'
))
TENSOR_DOT_LIST_PATH
=
os
.
path
.
realpath
(
os
.
path
.
join
(
os
.
path
.
dirname
(
__file__
),
'ops'
,
'tensor_dot_list.json'
))
TENSOR_DOT_LIST
=
load_json_file
(
TENSOR_DOT_LIST_PATH
)
TENSOR_DOT_LIST
=
load_json_file
(
TENSOR_DOT_LIST_PATH
)
TENSOR_DOT_SUPPORTED
=
[
x
for
x
in
TENSOR_DOT_LIST
if
x
in
ALL_MAPPING
]
TENSOR_DOT_SUPPORTED
=
[
x
for
x
in
TENSOR_DOT_LIST
if
x
in
ALL_MAPPING
]
TENSOR_DOT_UNSUPPORTED
=
[
x
for
x
in
TENSOR_DOT_LIST
if
x
not
in
ALL_MAPPING
]
TENSOR_DOT_UNSUPPORTED
=
[
x
for
x
in
TENSOR_DOT_LIST
if
x
not
in
ALL_MAPPING
]
ALL_2P_LIST
=
F_LIST
+
TORCH_DOT_LIST
+
TENSOR_DOT_LIST
ALL_2P_LIST
=
F_LIST
+
TORCH_DOT_LIST
+
TENSOR_DOT_LIST
ALL_TORCH_APIS
=
NN_LIST
+
F_LIST
+
TORCH_DOT_LIST
+
TENSOR_DOT_LIST
ALL_TORCH_APIS
=
NN_LIST
+
F_LIST
+
TORCH_DOT_LIST
+
TENSOR_DOT_LIST
ALL_SUPPORTED
=
NN_SUPPORTED
+
F_SUPPORTED
+
TORCH_DOT_SUPPORTED
+
TENSOR_DOT_SUPPORTED
ALL_SUPPORTED
=
NN_SUPPORTED
+
F_SUPPORTED
+
TORCH_DOT_SUPPORTED
+
TENSOR_DOT_SUPPORTED
ALL_UNSUPPORTED
=
NN_UNSUPPORTED
+
F_UNSUPPORTED
+
TORCH_DOT_UNSUPPORTED
+
TENSOR_DOT_UNSUPPORTED
ALL_UNSUPPORTED
=
NN_UNSUPPORTED
+
F_UNSUPPORTED
+
TORCH_DOT_UNSUPPORTED
+
TENSOR_DOT_UNSUPPORTED
UNSUPPORTED_WARN_INFOS
=
{
UNSUPPORTED_WARN_INFOS
=
{
"nn.AdaptiveAvgPool2d"
:
"Maybe could convert to mindspore.ops.operations.ReduceMean."
,
"nn.AdaptiveAvgPool2d"
:
"Maybe could convert to mindspore.ops.operations.ReduceMean."
,
"nn.AvgPool1d"
:
"Maybe could convert to mindspore.nn.AvgPool1d."
,
"nn.AvgPool1d"
:
"Maybe could convert to mindspore.nn.AvgPool1d."
,
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录