提交 dde197ce 编写于 作者: G ggpolar

Modify the prompt message and parse more statements.

1. More detailed reports are added to the conversion report.
2. The conversion prompt is provided for the 'isinstance' statement in the conversion script.
上级 0da8d2a8
...@@ -28,13 +28,13 @@ from mindinsight.mindconverter.config import ALL_MAPPING ...@@ -28,13 +28,13 @@ from mindinsight.mindconverter.config import ALL_MAPPING
from mindinsight.mindconverter.config import NN_LIST from mindinsight.mindconverter.config import NN_LIST
from mindinsight.mindconverter.config import ALL_TORCH_APIS from mindinsight.mindconverter.config import ALL_TORCH_APIS
from mindinsight.mindconverter.config import ALL_2P_LIST from mindinsight.mindconverter.config import ALL_2P_LIST
from mindinsight.mindconverter.config import UNSUPPORTED_WARN_INFOS from mindinsight.mindconverter.config import get_prompt_info
from mindinsight.mindconverter.config import ALL_UNSUPPORTED
from mindinsight.mindconverter.common.log import logger from mindinsight.mindconverter.common.log import logger
from mindinsight.mindconverter.common.exceptions import NodeTypeNotSupport from mindinsight.mindconverter.common.exceptions import NodeTypeNotSupport
from mindinsight.mindconverter.forward_call import ForwardCall from mindinsight.mindconverter.forward_call import ForwardCall
LOG_FMT_CONVERT = "[Convert] '%s' is converted to '%s'." LOG_FMT_CONVERT = "[Convert] '%s' is converted to '%s'."
LOG_FMT_CONVERT_WITH_TIPS = "[Convert] '%s' is converted to '%s'. %s"
LOG_FMT_NOT_CONVERT = "[UnConvert] '%s' didn't convert. %s" LOG_FMT_NOT_CONVERT = "[UnConvert] '%s' didn't convert. %s"
LOG_FMT_PROMPT_INFO = "[INFO] %s" LOG_FMT_PROMPT_INFO = "[INFO] %s"
LOG_SUGGESTION_MANUAL_CONVERT = "Please manual convert the code, along with the code associated with it." LOG_SUGGESTION_MANUAL_CONVERT = "Please manual convert the code, along with the code associated with it."
...@@ -95,6 +95,7 @@ class _LineColEditVisitor(ast.NodeVisitor): ...@@ -95,6 +95,7 @@ class _LineColEditVisitor(ast.NodeVisitor):
class _NodeInfo: class _NodeInfo:
"""NodeInfo class definition.""" """NodeInfo class definition."""
def __init__(self, node): def __init__(self, node):
self.node = node self.node = node
self.call_list = [] # Used to save all ast.Call node in self._node self.call_list = [] # Used to save all ast.Call node in self._node
...@@ -444,19 +445,25 @@ class AstEditVisitor(ast.NodeVisitor): ...@@ -444,19 +445,25 @@ class AstEditVisitor(ast.NodeVisitor):
is_include_call = False is_include_call = False
return is_include_call return is_include_call
def match_api(self, call_func_node, is_forward): def match_api(self, call_func_node, is_forward, check_context=True):
""" """
Check api name to convert, check api name ok with a is_forward condition. Check api name to convert, check api name ok with a is_forward condition.
Args: Args:
call_func_node (ast.Attribute): The call.func node. call_func_node (ast.Attribute): The call.func node.
is_forward (bool): whether api belong to forward. is_forward (bool): whether api belong to forward.
check_context (boolean): If True, the code context will be checked. Default is True.
Returns: Returns:
str, the standard api name used to match. str, the standard api name used to match.
ApiMappingEnum, the match result. ApiMappingEnum, the match result.
""" """
api_name, match_case = self._infer_api_name(call_func_node) match_case = ApiMatchingEnum.NOT_API
api_call_name = pasta.dump(call_func_node)
if api_call_name.startswith('self.'):
return api_call_name, match_case
api_name, match_case = self._infer_api_name(call_func_node, check_context)
api_call_name = pasta.dump(call_func_node) api_call_name = pasta.dump(call_func_node)
is_tensor_obj_call = False is_tensor_obj_call = False
if api_name != api_call_name: if api_name != api_call_name:
...@@ -466,15 +473,17 @@ class AstEditVisitor(ast.NodeVisitor): ...@@ -466,15 +473,17 @@ class AstEditVisitor(ast.NodeVisitor):
# rewritten external module name # rewritten external module name
# e.g., mm.ReLU will be written to nn.ReLU if 'import torch.nn as mm' in script. # e.g., mm.ReLU will be written to nn.ReLU if 'import torch.nn as mm' in script.
if not is_tensor_obj_call and not self._code_analyzer.is_standard_external_ref: if not is_tensor_obj_call:
standard_api_call_name = self._mapping_standard_api_name(api_name) standard_api_call_name = self._get_api_whole_name(call_func_node, check_context)
if standard_api_call_name in ALL_TORCH_APIS: if standard_api_call_name in ALL_TORCH_APIS:
match_case = ApiMatchingEnum.API_FOUND match_case = ApiMatchingEnum.API_FOUND
if (not is_forward and standard_api_call_name in NN_LIST) or \ if (not is_forward and standard_api_call_name in NN_LIST) or \
(is_forward and standard_api_call_name in ALL_2P_LIST): (is_forward and standard_api_call_name in ALL_2P_LIST):
match_case = ApiMatchingEnum.API_MATCHED match_case = ApiMatchingEnum.API_MATCHED
else:
if standard_api_call_name and standard_api_call_name.startswith('torch.nn.init'):
match_case = ApiMatchingEnum.API_MATCHED
return standard_api_call_name, match_case return standard_api_call_name, match_case
@staticmethod @staticmethod
...@@ -502,6 +511,25 @@ class AstEditVisitor(ast.NodeVisitor): ...@@ -502,6 +511,25 @@ class AstEditVisitor(ast.NodeVisitor):
parameters_str = call_str[left_parenthesis_pos + 1:right_parenthesis_pos] parameters_str = call_str[left_parenthesis_pos + 1:right_parenthesis_pos]
return parameters_str return parameters_str
def _get_api_whole_name(self, call_func_node, check_context=True):
    """
    Resolve the fully qualified name of the API referenced by a call node.

    The name is first inferred with ``self._infer_api_name``. When the
    inferred match case is ``ApiMatchingEnum.API_STANDARD``, the leading
    segment of the dotted name is looked up with
    ``self._get_external_ref_whole_name`` and, if a non-empty whole name is
    found, replaced by it.

    Args:
        call_func_node (AST): The func attribute of ast.Call.
        check_context (boolean): If True, the code context will be checked. Default is True.

    Returns:
        str, the whole name.
    """
    inferred_name, match_case = self._infer_api_name(call_func_node, check_context)
    if match_case != ApiMatchingEnum.API_STANDARD:
        # Only standard-style names get their leading reference expanded.
        return inferred_name

    head, *tail = inferred_name.split('.')
    whole_head = self._get_external_ref_whole_name(head)
    if not whole_head:
        # Unknown leading reference: keep the inferred name untouched.
        return inferred_name
    return '.'.join([whole_head, *tail])
def mapping_api(self, call_node, check_context=True): def mapping_api(self, call_node, check_context=True):
""" """
Convert api_name in code to MindSpore api, if api_name is a python api, code will not convert. Convert api_name in code to MindSpore api, if api_name is a python api, code will not convert.
...@@ -522,6 +550,26 @@ class AstEditVisitor(ast.NodeVisitor): ...@@ -522,6 +550,26 @@ class AstEditVisitor(ast.NodeVisitor):
if api_call_name.startswith('self.'): if api_call_name.startswith('self.'):
return code return code
new_code = self._mapping_api(call_node, check_context)
return new_code
def _mapping_api(self, call_node, check_context=True):
"""
Convert api_name in code to MindSpore api, if api_name is a python api, code will not convert.
If do not check context of the script, the code represented by the node must be written in the standard way.
Args:
call_node (ast.Call): The ast node to convert.
check_context (boolean): If True, the code context will be checked. Default is True.
Returns:
str, the converted code.
"""
code = pasta.dump(call_node)
api_call_name = pasta.dump(call_node.func)
# find full api expected to be converted. eg:expr="nn.Conv2d(1,2,3)" args_str="(1,2,3)" # find full api expected to be converted. eg:expr="nn.Conv2d(1,2,3)" args_str="(1,2,3)"
args_str = '(' + self._get_call_parameters_str(call_node) + ')' args_str = '(' + self._get_call_parameters_str(call_node) + ')'
...@@ -551,26 +599,37 @@ class AstEditVisitor(ast.NodeVisitor): ...@@ -551,26 +599,37 @@ class AstEditVisitor(ast.NodeVisitor):
code = pasta.dump(node) code = pasta.dump(node)
api_name = pasta.dump(node.func) api_name = pasta.dump(node.func)
# parent node first call is equal to this node, skip when parent node is replaced. # The parent node first call is equal to this node, skip when parent node is replaced.
for parent_node in self._stack[:-1]: # This scenario occurs, for example, when out.view(out.size(0), -1) is first converted to
# P.Reshape()(out, (out.size(0), -1)); P.Reshape() is then skipped during the following visits.
# Access from the penultimate element in reverse order.
for parent_node in self._stack[-2::-1]:
if parent_node in self._new_call_nodes and pasta.dump(parent_node).startswith(api_name): if parent_node in self._new_call_nodes and pasta.dump(parent_node).startswith(api_name):
return return
parent = self._stack[-2] parent = self._stack[-2]
new_node = None new_node = None
new_code = code
matched_api_name, match_case = self.match_api(node.func, self._is_forward_function) matched_api_name, match_case = self.match_api(node.func, self._is_forward_function)
if match_case in [ApiMatchingEnum.API_INFER, ApiMatchingEnum.API_MATCHED]: if match_case in [ApiMatchingEnum.API_INFER, ApiMatchingEnum.API_MATCHED]:
warning_info = get_prompt_info(matched_api_name)
if warning_info is None:
warning_info = ''
if matched_api_name in ALL_MAPPING: if matched_api_name in ALL_MAPPING:
logger.info("Line %3d start converting API: %s", node.lineno, api_name) logger.info("Line %3d start converting API: %s", node.lineno, api_name)
new_code = self.mapping_api(node) new_code = self.mapping_api(node)
if new_code != code: if new_code != code:
new_node = pasta.parse(new_code).body[0].value try:
# find the first call name new_node = pasta.parse(new_code).body[0].value
new_api_name = new_code[:new_code.find('(')] # find the first call name
self._process_log.info(node.lineno, node.col_offset, LOG_FMT_CONVERT % (api_name, new_api_name)) new_api_name = new_code[:new_code.find('(')]
if matched_api_name in ALL_UNSUPPORTED: except AttributeError:
warn_info = UNSUPPORTED_WARN_INFOS.get(api_name, '') new_node = pasta.parse(new_code).body[0]
logger.warning("Line %3d: found unsupported API: %s%s", node.lineno, api_name, warn_info) new_api_name = new_code
self._process_log.warning(node.lineno, node.col_offset, LOG_FMT_NOT_CONVERT % (api_name, warn_info)) self._process_log.info(node.lineno, node.col_offset,
LOG_FMT_CONVERT_WITH_TIPS % (api_name, new_api_name, warning_info))
else:
logger.warning("Line %3d: found unsupported API: %s%s", node.lineno, api_name, warning_info)
self._process_log.warning(node.lineno, node.col_offset, LOG_FMT_NOT_CONVERT % (api_name, warning_info))
elif match_case in [ApiMatchingEnum.API_STANDARD, ApiMatchingEnum.API_FOUND]: elif match_case in [ApiMatchingEnum.API_STANDARD, ApiMatchingEnum.API_FOUND]:
self._process_log.warning(node.lineno, node.col_offset, LOG_FMT_NOT_CONVERT % (api_name, '')) self._process_log.warning(node.lineno, node.col_offset, LOG_FMT_NOT_CONVERT % (api_name, ''))
...@@ -602,3 +661,57 @@ class AstEditVisitor(ast.NodeVisitor): ...@@ -602,3 +661,57 @@ class AstEditVisitor(ast.NodeVisitor):
elif ref_name != 'F' and external_ref_info.name == 'torch.nn.functional': elif ref_name != 'F' and external_ref_info.name == 'torch.nn.functional':
renames[ref_name] = 'F' renames[ref_name] = 'F'
return renames return renames
def _get_external_ref_whole_name(self, ref_name):
    """
    Find out external reference whole name.

    For example:
    In the parsed source code, there is import statement
        import torch.nn as new_name
    _get_external_ref_whole_name('new_name') will return 'torch.nn' string.

    Args:
        ref_name (str): The reference name as it is written in the script.

    Returns:
        str, the ``name`` attribute of the matching entry's
        'external_ref_info', or None if ``ref_name`` is not a key of the
        code analyzer's external references.
    """
    # Direct key lookup: no need to walk every entry (and touch the
    # 'external_ref_info' of unrelated entries) to find a single key.
    ref_info = self._code_analyzer.external_references.get(ref_name)
    if ref_info is None:
        return None
    return ref_info['external_ref_info'].name
def _check_isinstance_parameter(self, node):
    """
    Check whether the second parameter of isinstance function contains the torch type.

    The node qualifies when it is the second positional argument (or an
    element of a tuple given as the second positional argument) of an
    ``isinstance`` call found in ``self._stack``, and the first segment of
    its dotted name resolves to an external reference whose whole name
    starts with 'torch'. For such a node a "not converted" warning is
    recorded in the process log if ``self.match_api`` does not report
    ``ApiMatchingEnum.NOT_API``.

    Args:
        node (ast.AST): The node being visited.

    Returns:
        bool, True if the node is a torch type used as the type argument of
        an isinstance call, otherwise False.
    """
    is_isinstance_arg = False
    # Check whether node is the second parameter of the isinstance function call.
    # Access from the penultimate element in reverse order.
    # NOTE(review): presumably self._stack holds the chain of nodes currently
    # being visited, with this node last - confirm against the visitor.
    for parent_node in self._stack[-2::-1]:
        if not isinstance(parent_node, ast.Call) or pasta.dump(parent_node.func) != 'isinstance':
            continue
        # An isinstance call without two positional arguments (e.g.
        # isinstance(*pair)) has no type argument to inspect; indexing
        # args[1] there would raise IndexError.
        if len(parent_node.args) < 2:
            continue
        type_arg = parent_node.args[1]
        if isinstance(type_arg, ast.Tuple):
            second_arg_type_nodes = list(type_arg.elts)
        else:
            second_arg_type_nodes = [type_arg]
        if node in second_arg_type_nodes:
            is_isinstance_arg = True
            break
    if not is_isinstance_arg:
        return False

    isinstance_type_arg = pasta.dump(node)
    check_torch_type = False
    if isinstance_type_arg:
        # Only the leading reference (e.g. 'nn' in 'nn.Module') decides
        # whether the type comes from a torch import.
        type_splits = isinstance_type_arg.split('.')
        whole_name = self._get_external_ref_whole_name(type_splits[0])
        if whole_name and whole_name.startswith('torch'):
            check_torch_type = True
    if check_torch_type:
        _, match_case = self.match_api(node, False)
        if match_case != ApiMatchingEnum.NOT_API:
            warn_info = 'Manually determine the conversion type.'
            self._process_log.warning(node.lineno, node.col_offset,
                                      LOG_FMT_NOT_CONVERT % (isinstance_type_arg, warn_info))
    return check_torch_type
def visit_Attribute(self, node):
"""Callback function when visit AST tree"""
self._check_isinstance_parameter(node)
...@@ -24,6 +24,7 @@ class ConverterErrors(ScriptConverterErrors): ...@@ -24,6 +24,7 @@ class ConverterErrors(ScriptConverterErrors):
"""Converter error codes.""" """Converter error codes."""
SCRIPT_NOT_SUPPORT = 1 SCRIPT_NOT_SUPPORT = 1
NODE_TYPE_NOT_SUPPORT = 2 NODE_TYPE_NOT_SUPPORT = 2
CODE_SYNTAX_ERROR = 3
class ScriptNotSupport(MindInsightException): class ScriptNotSupport(MindInsightException):
...@@ -42,3 +43,12 @@ class NodeTypeNotSupport(MindInsightException): ...@@ -42,3 +43,12 @@ class NodeTypeNotSupport(MindInsightException):
super(NodeTypeNotSupport, self).__init__(ConverterErrors.NODE_TYPE_NOT_SUPPORT, super(NodeTypeNotSupport, self).__init__(ConverterErrors.NODE_TYPE_NOT_SUPPORT,
msg, msg,
http_code=400) http_code=400)
class CodeSyntaxError(MindInsightException):
    """
    The CodeSyntaxError class definition.

    Initialised with the ConverterErrors.CODE_SYNTAX_ERROR error code and an
    HTTP status code of 400.
    """

    def __init__(self, msg):
        error_code = ConverterErrors.CODE_SYNTAX_ERROR
        super().__init__(error_code, msg, http_code=400)
...@@ -22,7 +22,7 @@ import os ...@@ -22,7 +22,7 @@ import os
import pasta import pasta
from mindinsight.mindconverter.common.log import logger from mindinsight.mindconverter.common.log import logger
from mindinsight.mindconverter.common.exceptions import CodeSyntaxError
REQUIRED = 'REQUIRED' REQUIRED = 'REQUIRED'
UNREQUIRED = 'UNREQUIRED' UNREQUIRED = 'UNREQUIRED'
...@@ -31,6 +31,7 @@ FUNC_MODULE = 'mindinsight.mindconverter.funcs' ...@@ -31,6 +31,7 @@ FUNC_MODULE = 'mindinsight.mindconverter.funcs'
class APIPt: class APIPt:
"""Base API for args parse, and API for one frame.""" """Base API for args parse, and API for one frame."""
def __init__(self, name: str, params: OrderedDict): def __init__(self, name: str, params: OrderedDict):
self.name = name self.name = name
self.params = OrderedDict() self.params = OrderedDict()
...@@ -77,10 +78,8 @@ class APIPt: ...@@ -77,10 +78,8 @@ class APIPt:
try: try:
ast_node = ast.parse("whatever_call_name" + args_str) ast_node = ast.parse("whatever_call_name" + args_str)
call_node = ast_node.body[0].value call_node = ast_node.body[0].value
if not isinstance(call_node, ast.Call): except SyntaxError as parse_error:
raise ValueError('call name with args str [{}] not instance of ast.Call'.format(args_str)) raise CodeSyntaxError("can't parse code:\n{}".format(args_str)) from parse_error
except:
raise ValueError("can't parse code:\n{}".format(args_str))
# regard all actual parameter as one parameter # regard all actual parameter as one parameter
if len(self.params) == 1: if len(self.params) == 1:
...@@ -118,6 +117,7 @@ class APIPt: ...@@ -118,6 +117,7 @@ class APIPt:
class APIMs(APIPt): class APIMs(APIPt):
"""API for MindSpore""" """API for MindSpore"""
def __init__(self, name: str, params: OrderedDict, p_attrs=None): def __init__(self, name: str, params: OrderedDict, p_attrs=None):
self.is_primitive = name.startswith('P.') self.is_primitive = name.startswith('P.')
if self.is_primitive: if self.is_primitive:
...@@ -167,6 +167,7 @@ class APIMs(APIPt): ...@@ -167,6 +167,7 @@ class APIMs(APIPt):
class MappingHelper: class MappingHelper:
"""Mapping from one frame to another frame""" """Mapping from one frame to another frame"""
def __init__(self, ms_api: APIMs, pt_api: APIPt, **kwargs): def __init__(self, ms_api: APIMs, pt_api: APIPt, **kwargs):
ms2pt_mapping = kwargs.get('ms2pt_mapping') ms2pt_mapping = kwargs.get('ms2pt_mapping')
gen_explicit_map = kwargs.get('gen_explicit_map') gen_explicit_map = kwargs.get('gen_explicit_map')
...@@ -392,7 +393,6 @@ TENSOR_DOT_MAPPING = get_mapping_from_file(TENSOR_DOT_MAPPING_PATH) ...@@ -392,7 +393,6 @@ TENSOR_DOT_MAPPING = get_mapping_from_file(TENSOR_DOT_MAPPING_PATH)
ALL_MAPPING = {**NN_MAPPING, **F_MAPPING, **TORCH_DOT_MAPPING, **TENSOR_DOT_MAPPING} ALL_MAPPING = {**NN_MAPPING, **F_MAPPING, **TORCH_DOT_MAPPING, **TENSOR_DOT_MAPPING}
# ---------------------------- api list support or not support ---------------------------- # ---------------------------- api list support or not support ----------------------------
NN_LIST_PATH = os.path.realpath(os.path.join(os.path.dirname(__file__), 'ops', 'nn_list.json')) NN_LIST_PATH = os.path.realpath(os.path.join(os.path.dirname(__file__), 'ops', 'nn_list.json'))
NN_LIST = load_json_file(NN_LIST_PATH) NN_LIST = load_json_file(NN_LIST_PATH)
...@@ -400,7 +400,6 @@ NN_LIST += ["torch." + name for name in NN_LIST] ...@@ -400,7 +400,6 @@ NN_LIST += ["torch." + name for name in NN_LIST]
NN_SUPPORTED = [x for x in NN_LIST if x in ALL_MAPPING] NN_SUPPORTED = [x for x in NN_LIST if x in ALL_MAPPING]
NN_UNSUPPORTED = [x for x in NN_LIST if x not in ALL_MAPPING] NN_UNSUPPORTED = [x for x in NN_LIST if x not in ALL_MAPPING]
F_LIST_PATH = os.path.realpath(os.path.join(os.path.dirname(__file__), 'ops', 'f_list.json')) F_LIST_PATH = os.path.realpath(os.path.join(os.path.dirname(__file__), 'ops', 'f_list.json'))
F_LIST = load_json_file(F_LIST_PATH) F_LIST = load_json_file(F_LIST_PATH)
F_LIST += ["F." + name[len("torch.nn.functional."):] for name in F_LIST] + \ F_LIST += ["F." + name[len("torch.nn.functional."):] for name in F_LIST] + \
...@@ -408,29 +407,23 @@ F_LIST += ["F." + name[len("torch.nn.functional."):] for name in F_LIST] + \ ...@@ -408,29 +407,23 @@ F_LIST += ["F." + name[len("torch.nn.functional."):] for name in F_LIST] + \
F_SUPPORTED = [x for x in F_LIST if x in ALL_MAPPING] F_SUPPORTED = [x for x in F_LIST if x in ALL_MAPPING]
F_UNSUPPORTED = [x for x in F_LIST if x not in ALL_MAPPING] F_UNSUPPORTED = [x for x in F_LIST if x not in ALL_MAPPING]
TORCH_DOT_LIST_PATH = os.path.realpath(os.path.join(os.path.dirname(__file__), 'ops', 'torch_dot_list.json')) TORCH_DOT_LIST_PATH = os.path.realpath(os.path.join(os.path.dirname(__file__), 'ops', 'torch_dot_list.json'))
TORCH_DOT_LIST = load_json_file(TORCH_DOT_LIST_PATH) TORCH_DOT_LIST = load_json_file(TORCH_DOT_LIST_PATH)
TORCH_DOT_SUPPORTED = [x for x in TORCH_DOT_LIST if x in ALL_MAPPING] TORCH_DOT_SUPPORTED = [x for x in TORCH_DOT_LIST if x in ALL_MAPPING]
TORCH_DOT_UNSUPPORTED = [x for x in TORCH_DOT_LIST if x not in ALL_MAPPING] TORCH_DOT_UNSUPPORTED = [x for x in TORCH_DOT_LIST if x not in ALL_MAPPING]
TENSOR_DOT_LIST_PATH = os.path.realpath(os.path.join(os.path.dirname(__file__), 'ops', 'tensor_dot_list.json')) TENSOR_DOT_LIST_PATH = os.path.realpath(os.path.join(os.path.dirname(__file__), 'ops', 'tensor_dot_list.json'))
TENSOR_DOT_LIST = load_json_file(TENSOR_DOT_LIST_PATH) TENSOR_DOT_LIST = load_json_file(TENSOR_DOT_LIST_PATH)
TENSOR_DOT_SUPPORTED = [x for x in TENSOR_DOT_LIST if x in ALL_MAPPING] TENSOR_DOT_SUPPORTED = [x for x in TENSOR_DOT_LIST if x in ALL_MAPPING]
TENSOR_DOT_UNSUPPORTED = [x for x in TENSOR_DOT_LIST if x not in ALL_MAPPING] TENSOR_DOT_UNSUPPORTED = [x for x in TENSOR_DOT_LIST if x not in ALL_MAPPING]
ALL_2P_LIST = F_LIST + TORCH_DOT_LIST + TENSOR_DOT_LIST ALL_2P_LIST = F_LIST + TORCH_DOT_LIST + TENSOR_DOT_LIST
ALL_TORCH_APIS = NN_LIST + F_LIST + TORCH_DOT_LIST + TENSOR_DOT_LIST ALL_TORCH_APIS = NN_LIST + F_LIST + TORCH_DOT_LIST + TENSOR_DOT_LIST
ALL_SUPPORTED = NN_SUPPORTED + F_SUPPORTED + TORCH_DOT_SUPPORTED + TENSOR_DOT_SUPPORTED ALL_SUPPORTED = NN_SUPPORTED + F_SUPPORTED + TORCH_DOT_SUPPORTED + TENSOR_DOT_SUPPORTED
ALL_UNSUPPORTED = NN_UNSUPPORTED + F_UNSUPPORTED + TORCH_DOT_UNSUPPORTED + TENSOR_DOT_UNSUPPORTED ALL_UNSUPPORTED = NN_UNSUPPORTED + F_UNSUPPORTED + TORCH_DOT_UNSUPPORTED + TENSOR_DOT_UNSUPPORTED
UNSUPPORTED_WARN_INFOS = { UNSUPPORTED_WARN_INFOS = {
"nn.AdaptiveAvgPool2d": "Maybe could convert to mindspore.ops.operations.ReduceMean.", "nn.AdaptiveAvgPool2d": "Maybe could convert to mindspore.ops.operations.ReduceMean.",
"nn.AvgPool1d": "Maybe could convert to mindspore.nn.AvgPool1d.", "nn.AvgPool1d": "Maybe could convert to mindspore.nn.AvgPool1d.",
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册