提交 f1f3dbc4 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!306 Converter: Modify the prompt message and parse more statements.

Merge pull request !306 from ggpolar/br_wzk_dev
...@@ -28,13 +28,13 @@ from mindinsight.mindconverter.config import ALL_MAPPING ...@@ -28,13 +28,13 @@ from mindinsight.mindconverter.config import ALL_MAPPING
from mindinsight.mindconverter.config import NN_LIST from mindinsight.mindconverter.config import NN_LIST
from mindinsight.mindconverter.config import ALL_TORCH_APIS from mindinsight.mindconverter.config import ALL_TORCH_APIS
from mindinsight.mindconverter.config import ALL_2P_LIST from mindinsight.mindconverter.config import ALL_2P_LIST
from mindinsight.mindconverter.config import UNSUPPORTED_WARN_INFOS from mindinsight.mindconverter.config import get_prompt_info
from mindinsight.mindconverter.config import ALL_UNSUPPORTED
from mindinsight.mindconverter.common.log import logger from mindinsight.mindconverter.common.log import logger
from mindinsight.mindconverter.common.exceptions import NodeTypeNotSupport from mindinsight.mindconverter.common.exceptions import NodeTypeNotSupport
from mindinsight.mindconverter.forward_call import ForwardCall from mindinsight.mindconverter.forward_call import ForwardCall
LOG_FMT_CONVERT = "[Convert] '%s' is converted to '%s'." LOG_FMT_CONVERT = "[Convert] '%s' is converted to '%s'."
LOG_FMT_CONVERT_WITH_TIPS = "[Convert] '%s' is converted to '%s'. %s"
LOG_FMT_NOT_CONVERT = "[UnConvert] '%s' didn't convert. %s" LOG_FMT_NOT_CONVERT = "[UnConvert] '%s' didn't convert. %s"
LOG_FMT_PROMPT_INFO = "[INFO] %s" LOG_FMT_PROMPT_INFO = "[INFO] %s"
LOG_SUGGESTION_MANUAL_CONVERT = "Please manual convert the code, along with the code associated with it." LOG_SUGGESTION_MANUAL_CONVERT = "Please manual convert the code, along with the code associated with it."
...@@ -95,6 +95,7 @@ class _LineColEditVisitor(ast.NodeVisitor): ...@@ -95,6 +95,7 @@ class _LineColEditVisitor(ast.NodeVisitor):
class _NodeInfo: class _NodeInfo:
"""NodeInfo class definition.""" """NodeInfo class definition."""
def __init__(self, node): def __init__(self, node):
self.node = node self.node = node
self.call_list = [] # Used to save all ast.Call node in self._node self.call_list = [] # Used to save all ast.Call node in self._node
...@@ -444,19 +445,25 @@ class AstEditVisitor(ast.NodeVisitor): ...@@ -444,19 +445,25 @@ class AstEditVisitor(ast.NodeVisitor):
is_include_call = False is_include_call = False
return is_include_call return is_include_call
def match_api(self, call_func_node, is_forward): def match_api(self, call_func_node, is_forward, check_context=True):
""" """
Check api name to convert, check api name ok with a is_forward condition. Check api name to convert, check api name ok with a is_forward condition.
Args: Args:
call_func_node (ast.Attribute): The call.func node. call_func_node (ast.Attribute): The call.func node.
is_forward (bool): whether api belong to forward. is_forward (bool): whether api belong to forward.
check_context (boolean): If True, the code context will be checked. Default is True.
Returns: Returns:
str, the standard api name used to match. str, the standard api name used to match.
ApiMappingEnum, the match result. ApiMappingEnum, the match result.
""" """
api_name, match_case = self._infer_api_name(call_func_node) match_case = ApiMatchingEnum.NOT_API
api_call_name = pasta.dump(call_func_node)
if api_call_name.startswith('self.'):
return api_call_name, match_case
api_name, match_case = self._infer_api_name(call_func_node, check_context)
api_call_name = pasta.dump(call_func_node) api_call_name = pasta.dump(call_func_node)
is_tensor_obj_call = False is_tensor_obj_call = False
if api_name != api_call_name: if api_name != api_call_name:
...@@ -466,15 +473,17 @@ class AstEditVisitor(ast.NodeVisitor): ...@@ -466,15 +473,17 @@ class AstEditVisitor(ast.NodeVisitor):
# rewritten external module name # rewritten external module name
# e.g., mm.ReLU will be written to nn.ReLU if 'import torch.nn as mm' in script. # e.g., mm.ReLU will be written to nn.ReLU if 'import torch.nn as mm' in script.
if not is_tensor_obj_call and not self._code_analyzer.is_standard_external_ref: if not is_tensor_obj_call:
standard_api_call_name = self._mapping_standard_api_name(api_name) standard_api_call_name = self._get_api_whole_name(call_func_node, check_context)
if standard_api_call_name in ALL_TORCH_APIS: if standard_api_call_name in ALL_TORCH_APIS:
match_case = ApiMatchingEnum.API_FOUND match_case = ApiMatchingEnum.API_FOUND
if (not is_forward and standard_api_call_name in NN_LIST) or \ if (not is_forward and standard_api_call_name in NN_LIST) or \
(is_forward and standard_api_call_name in ALL_2P_LIST): (is_forward and standard_api_call_name in ALL_2P_LIST):
match_case = ApiMatchingEnum.API_MATCHED match_case = ApiMatchingEnum.API_MATCHED
else:
if standard_api_call_name and standard_api_call_name.startswith('torch.nn.init'):
match_case = ApiMatchingEnum.API_MATCHED
return standard_api_call_name, match_case return standard_api_call_name, match_case
@staticmethod @staticmethod
...@@ -502,6 +511,25 @@ class AstEditVisitor(ast.NodeVisitor): ...@@ -502,6 +511,25 @@ class AstEditVisitor(ast.NodeVisitor):
parameters_str = call_str[left_parenthesis_pos + 1:right_parenthesis_pos] parameters_str = call_str[left_parenthesis_pos + 1:right_parenthesis_pos]
return parameters_str return parameters_str
def _get_api_whole_name(self, call_func_node, check_context=True):
    """
    Resolve the fully qualified API name for the func part of a call.

    The name is first inferred through ``_infer_api_name``. Only when that
    inference reports ``ApiMatchingEnum.API_STANDARD`` is the leading
    component of the dotted name swapped for the whole name of the external
    reference it denotes; in every other case the inferred name is returned
    untouched.

    Args:
        call_func_node (AST): The func attribute of ast.Call.
        check_context (boolean): If True, the code context will be checked. Default is True.

    Returns:
        str, the whole name.
    """
    inferred_name, case = self._infer_api_name(call_func_node, check_context)
    if case != ApiMatchingEnum.API_STANDARD:
        return inferred_name

    head, *tail = inferred_name.split('.')
    resolved_head = self._get_external_ref_whole_name(head)
    if not resolved_head:
        # The leading component is not a known external reference: keep the inferred name.
        return inferred_name
    return '.'.join([resolved_head, *tail])
def mapping_api(self, call_node, check_context=True): def mapping_api(self, call_node, check_context=True):
""" """
Convert api_name in code to MindSpore api, if api_name is a python api, code will not convert. Convert api_name in code to MindSpore api, if api_name is a python api, code will not convert.
...@@ -522,6 +550,26 @@ class AstEditVisitor(ast.NodeVisitor): ...@@ -522,6 +550,26 @@ class AstEditVisitor(ast.NodeVisitor):
if api_call_name.startswith('self.'): if api_call_name.startswith('self.'):
return code return code
new_code = self._mapping_api(call_node, check_context)
return new_code
def _mapping_api(self, call_node, check_context=True):
"""
Convert api_name in code to MindSpore api, if api_name is a python api, code will not convert.
If do not check context of the script, the code represented by the node must be written in the standard way.
Args:
call_node (ast.Call): The ast node to convert.
check_context (boolean): If True, the code context will be checked. Default is True.
Returns:
str, the converted code.
"""
code = pasta.dump(call_node)
api_call_name = pasta.dump(call_node.func)
# find full api expected to be converted. eg:expr="nn.Conv2d(1,2,3)" args_str="(1,2,3)" # find full api expected to be converted. eg:expr="nn.Conv2d(1,2,3)" args_str="(1,2,3)"
args_str = '(' + self._get_call_parameters_str(call_node) + ')' args_str = '(' + self._get_call_parameters_str(call_node) + ')'
...@@ -551,26 +599,37 @@ class AstEditVisitor(ast.NodeVisitor): ...@@ -551,26 +599,37 @@ class AstEditVisitor(ast.NodeVisitor):
code = pasta.dump(node) code = pasta.dump(node)
api_name = pasta.dump(node.func) api_name = pasta.dump(node.func)
# parent node first call is equal to this node, skip when parent node is replaced. # The parent node first call is equal to this node, skip when parent node is replaced.
for parent_node in self._stack[:-1]: # This scenario occurs, for example, when out.view(out.size(0), -1) is first converted to
# P.Reshape()(out, (out.size(0), -1)); the P.Reshape() call is then skipped in the following visit.
# Access from the penultimate element in reverse order.
for parent_node in self._stack[-2::-1]:
if parent_node in self._new_call_nodes and pasta.dump(parent_node).startswith(api_name): if parent_node in self._new_call_nodes and pasta.dump(parent_node).startswith(api_name):
return return
parent = self._stack[-2] parent = self._stack[-2]
new_node = None new_node = None
new_code = code
matched_api_name, match_case = self.match_api(node.func, self._is_forward_function) matched_api_name, match_case = self.match_api(node.func, self._is_forward_function)
if match_case in [ApiMatchingEnum.API_INFER, ApiMatchingEnum.API_MATCHED]: if match_case in [ApiMatchingEnum.API_INFER, ApiMatchingEnum.API_MATCHED]:
warning_info = get_prompt_info(matched_api_name)
if warning_info is None:
warning_info = ''
if matched_api_name in ALL_MAPPING: if matched_api_name in ALL_MAPPING:
logger.info("Line %3d start converting API: %s", node.lineno, api_name) logger.info("Line %3d start converting API: %s", node.lineno, api_name)
new_code = self.mapping_api(node) new_code = self.mapping_api(node)
if new_code != code: if new_code != code:
new_node = pasta.parse(new_code).body[0].value try:
# find the first call name new_node = pasta.parse(new_code).body[0].value
new_api_name = new_code[:new_code.find('(')] # find the first call name
self._process_log.info(node.lineno, node.col_offset, LOG_FMT_CONVERT % (api_name, new_api_name)) new_api_name = new_code[:new_code.find('(')]
if matched_api_name in ALL_UNSUPPORTED: except AttributeError:
warn_info = UNSUPPORTED_WARN_INFOS.get(api_name, '') new_node = pasta.parse(new_code).body[0]
logger.warning("Line %3d: found unsupported API: %s%s", node.lineno, api_name, warn_info) new_api_name = new_code
self._process_log.warning(node.lineno, node.col_offset, LOG_FMT_NOT_CONVERT % (api_name, warn_info)) self._process_log.info(node.lineno, node.col_offset,
LOG_FMT_CONVERT_WITH_TIPS % (api_name, new_api_name, warning_info))
else:
logger.warning("Line %3d: found unsupported API: %s%s", node.lineno, api_name, warning_info)
self._process_log.warning(node.lineno, node.col_offset, LOG_FMT_NOT_CONVERT % (api_name, warning_info))
elif match_case in [ApiMatchingEnum.API_STANDARD, ApiMatchingEnum.API_FOUND]: elif match_case in [ApiMatchingEnum.API_STANDARD, ApiMatchingEnum.API_FOUND]:
self._process_log.warning(node.lineno, node.col_offset, LOG_FMT_NOT_CONVERT % (api_name, '')) self._process_log.warning(node.lineno, node.col_offset, LOG_FMT_NOT_CONVERT % (api_name, ''))
...@@ -602,3 +661,57 @@ class AstEditVisitor(ast.NodeVisitor): ...@@ -602,3 +661,57 @@ class AstEditVisitor(ast.NodeVisitor):
elif ref_name != 'F' and external_ref_info.name == 'torch.nn.functional': elif ref_name != 'F' and external_ref_info.name == 'torch.nn.functional':
renames[ref_name] = 'F' renames[ref_name] = 'F'
return renames return renames
def _get_external_ref_whole_name(self, ref_name):
"""
Find out external reference whole name.
For example:
In the parsed source code, there is import statement
import torch.nn as new_name
_get_external_ref_whole_name('new_name') will return 'torch.nn' string.
"""
external_refs = self._code_analyzer.external_references
for external_ref_name, ref_info in external_refs.items():
external_ref_info = ref_info['external_ref_info']
if external_ref_name == ref_name:
return external_ref_info.name
return None
def _check_isinstance_parameter(self, node):
"""
Check whether the second parameter of isinstance function contains the torch type.

The ancestors in ``self._stack`` are searched for a call whose func dumps to
``isinstance``. When ``node`` is the second argument of that call (or one element
of a tuple second argument), the first dotted component of the node source is
resolved through ``_get_external_ref_whole_name``; a whole name starting with
``torch`` marks the node as a torch type. For such a node, ``match_api`` is
consulted and a "not converted" warning is written to ``self._process_log`` unless
the match result is ``ApiMatchingEnum.NOT_API``.

Args:
node (AST): The node to check. Its ``lineno`` and ``col_offset`` are used for the warning.

Returns:
bool, False if node is not a type argument of an isinstance call, otherwise
whether the node refers to a torch type.
"""
is_isinstance_arg = False
# Check whether node is the second parameter of the isinstance function call.
# Access from the penultimate element in reverse order.
# NOTE(review): the indentation of this block is not visible here, so the exact
# nesting of the `break` below should be confirmed against the real file.
for parent_node in self._stack[-2::-1]:
if isinstance(parent_node, ast.Call) and pasta.dump(parent_node.func) == 'isinstance':
isinstance_node = parent_node
# Candidate type nodes: the elements of a tuple second argument, or the argument itself.
seconde_arg_type_nodes = []
if isinstance(isinstance_node.args[1], ast.Tuple):
seconde_arg_type_nodes.extend(isinstance_node.args[1].elts)
else:
seconde_arg_type_nodes.append(isinstance_node.args[1])
if node in seconde_arg_type_nodes:
is_isinstance_arg = True
break
# Nothing to report when node is not used as an isinstance type argument.
if not is_isinstance_arg:
return False
isinstance_type_arg = pasta.dump(node)
check_torch_type = False
if isinstance_type_arg:
# Resolve the leading name (e.g. an import alias) to the whole external reference name.
type_splits = isinstance_type_arg.split('.')
whole_name = self._get_external_ref_whole_name(type_splits[0])
if whole_name and whole_name.startswith('torch'):
check_torch_type = True
if check_torch_type:
# is_forward is passed as False here; only the match result is used, the name is discarded.
_, match_case = self.match_api(node, False)
if match_case != ApiMatchingEnum.NOT_API:
warn_info = 'Manually determine the conversion type.'
self._process_log.warning(node.lineno, node.col_offset,
LOG_FMT_NOT_CONVERT % (isinstance_type_arg, warn_info))
return check_torch_type
def visit_Attribute(self, node):
"""
Callback function when visiting an ast.Attribute node of the AST tree.

Delegates to ``_check_isinstance_parameter``; its boolean result is not used here.

Args:
node (ast.Attribute): The attribute node being visited.
"""
# NOTE(review): no generic_visit call is visible in this block — confirm that
# child nodes of the attribute are intentionally not traversed.
self._check_isinstance_parameter(node)
...@@ -24,6 +24,7 @@ class ConverterErrors(ScriptConverterErrors): ...@@ -24,6 +24,7 @@ class ConverterErrors(ScriptConverterErrors):
"""Converter error codes.""" """Converter error codes."""
SCRIPT_NOT_SUPPORT = 1 SCRIPT_NOT_SUPPORT = 1
NODE_TYPE_NOT_SUPPORT = 2 NODE_TYPE_NOT_SUPPORT = 2
CODE_SYNTAX_ERROR = 3
class ScriptNotSupport(MindInsightException): class ScriptNotSupport(MindInsightException):
...@@ -42,3 +43,12 @@ class NodeTypeNotSupport(MindInsightException): ...@@ -42,3 +43,12 @@ class NodeTypeNotSupport(MindInsightException):
super(NodeTypeNotSupport, self).__init__(ConverterErrors.NODE_TYPE_NOT_SUPPORT, super(NodeTypeNotSupport, self).__init__(ConverterErrors.NODE_TYPE_NOT_SUPPORT,
msg, msg,
http_code=400) http_code=400)
class CodeSyntaxError(MindInsightException):
    """Converter exception carrying the CODE_SYNTAX_ERROR error code and http_code 400."""

    def __init__(self, msg):
        error_code = ConverterErrors.CODE_SYNTAX_ERROR
        super().__init__(error_code, msg, http_code=400)
...@@ -22,7 +22,7 @@ import os ...@@ -22,7 +22,7 @@ import os
import pasta import pasta
from mindinsight.mindconverter.common.log import logger from mindinsight.mindconverter.common.log import logger
from mindinsight.mindconverter.common.exceptions import CodeSyntaxError
REQUIRED = 'REQUIRED' REQUIRED = 'REQUIRED'
UNREQUIRED = 'UNREQUIRED' UNREQUIRED = 'UNREQUIRED'
...@@ -31,6 +31,7 @@ FUNC_MODULE = 'mindinsight.mindconverter.funcs' ...@@ -31,6 +31,7 @@ FUNC_MODULE = 'mindinsight.mindconverter.funcs'
class APIPt: class APIPt:
"""Base API for args parse, and API for one frame.""" """Base API for args parse, and API for one frame."""
def __init__(self, name: str, params: OrderedDict): def __init__(self, name: str, params: OrderedDict):
self.name = name self.name = name
self.params = OrderedDict() self.params = OrderedDict()
...@@ -77,10 +78,8 @@ class APIPt: ...@@ -77,10 +78,8 @@ class APIPt:
try: try:
ast_node = ast.parse("whatever_call_name" + args_str) ast_node = ast.parse("whatever_call_name" + args_str)
call_node = ast_node.body[0].value call_node = ast_node.body[0].value
if not isinstance(call_node, ast.Call): except SyntaxError as parse_error:
raise ValueError('call name with args str [{}] not instance of ast.Call'.format(args_str)) raise CodeSyntaxError("can't parse code:\n{}".format(args_str)) from parse_error
except:
raise ValueError("can't parse code:\n{}".format(args_str))
# regard all actual parameter as one parameter # regard all actual parameter as one parameter
if len(self.params) == 1: if len(self.params) == 1:
...@@ -118,6 +117,7 @@ class APIPt: ...@@ -118,6 +117,7 @@ class APIPt:
class APIMs(APIPt): class APIMs(APIPt):
"""API for MindSpore""" """API for MindSpore"""
def __init__(self, name: str, params: OrderedDict, p_attrs=None): def __init__(self, name: str, params: OrderedDict, p_attrs=None):
self.is_primitive = name.startswith('P.') self.is_primitive = name.startswith('P.')
if self.is_primitive: if self.is_primitive:
...@@ -167,6 +167,7 @@ class APIMs(APIPt): ...@@ -167,6 +167,7 @@ class APIMs(APIPt):
class MappingHelper: class MappingHelper:
"""Mapping from one frame to another frame""" """Mapping from one frame to another frame"""
def __init__(self, ms_api: APIMs, pt_api: APIPt, **kwargs): def __init__(self, ms_api: APIMs, pt_api: APIPt, **kwargs):
ms2pt_mapping = kwargs.get('ms2pt_mapping') ms2pt_mapping = kwargs.get('ms2pt_mapping')
gen_explicit_map = kwargs.get('gen_explicit_map') gen_explicit_map = kwargs.get('gen_explicit_map')
...@@ -392,7 +393,6 @@ TENSOR_DOT_MAPPING = get_mapping_from_file(TENSOR_DOT_MAPPING_PATH) ...@@ -392,7 +393,6 @@ TENSOR_DOT_MAPPING = get_mapping_from_file(TENSOR_DOT_MAPPING_PATH)
ALL_MAPPING = {**NN_MAPPING, **F_MAPPING, **TORCH_DOT_MAPPING, **TENSOR_DOT_MAPPING} ALL_MAPPING = {**NN_MAPPING, **F_MAPPING, **TORCH_DOT_MAPPING, **TENSOR_DOT_MAPPING}
# ---------------------------- api list support or not support ---------------------------- # ---------------------------- api list support or not support ----------------------------
NN_LIST_PATH = os.path.realpath(os.path.join(os.path.dirname(__file__), 'ops', 'nn_list.json')) NN_LIST_PATH = os.path.realpath(os.path.join(os.path.dirname(__file__), 'ops', 'nn_list.json'))
NN_LIST = load_json_file(NN_LIST_PATH) NN_LIST = load_json_file(NN_LIST_PATH)
...@@ -400,7 +400,6 @@ NN_LIST += ["torch." + name for name in NN_LIST] ...@@ -400,7 +400,6 @@ NN_LIST += ["torch." + name for name in NN_LIST]
NN_SUPPORTED = [x for x in NN_LIST if x in ALL_MAPPING] NN_SUPPORTED = [x for x in NN_LIST if x in ALL_MAPPING]
NN_UNSUPPORTED = [x for x in NN_LIST if x not in ALL_MAPPING] NN_UNSUPPORTED = [x for x in NN_LIST if x not in ALL_MAPPING]
F_LIST_PATH = os.path.realpath(os.path.join(os.path.dirname(__file__), 'ops', 'f_list.json')) F_LIST_PATH = os.path.realpath(os.path.join(os.path.dirname(__file__), 'ops', 'f_list.json'))
F_LIST = load_json_file(F_LIST_PATH) F_LIST = load_json_file(F_LIST_PATH)
F_LIST += ["F." + name[len("torch.nn.functional."):] for name in F_LIST] + \ F_LIST += ["F." + name[len("torch.nn.functional."):] for name in F_LIST] + \
...@@ -408,29 +407,23 @@ F_LIST += ["F." + name[len("torch.nn.functional."):] for name in F_LIST] + \ ...@@ -408,29 +407,23 @@ F_LIST += ["F." + name[len("torch.nn.functional."):] for name in F_LIST] + \
F_SUPPORTED = [x for x in F_LIST if x in ALL_MAPPING] F_SUPPORTED = [x for x in F_LIST if x in ALL_MAPPING]
F_UNSUPPORTED = [x for x in F_LIST if x not in ALL_MAPPING] F_UNSUPPORTED = [x for x in F_LIST if x not in ALL_MAPPING]
TORCH_DOT_LIST_PATH = os.path.realpath(os.path.join(os.path.dirname(__file__), 'ops', 'torch_dot_list.json')) TORCH_DOT_LIST_PATH = os.path.realpath(os.path.join(os.path.dirname(__file__), 'ops', 'torch_dot_list.json'))
TORCH_DOT_LIST = load_json_file(TORCH_DOT_LIST_PATH) TORCH_DOT_LIST = load_json_file(TORCH_DOT_LIST_PATH)
TORCH_DOT_SUPPORTED = [x for x in TORCH_DOT_LIST if x in ALL_MAPPING] TORCH_DOT_SUPPORTED = [x for x in TORCH_DOT_LIST if x in ALL_MAPPING]
TORCH_DOT_UNSUPPORTED = [x for x in TORCH_DOT_LIST if x not in ALL_MAPPING] TORCH_DOT_UNSUPPORTED = [x for x in TORCH_DOT_LIST if x not in ALL_MAPPING]
TENSOR_DOT_LIST_PATH = os.path.realpath(os.path.join(os.path.dirname(__file__), 'ops', 'tensor_dot_list.json')) TENSOR_DOT_LIST_PATH = os.path.realpath(os.path.join(os.path.dirname(__file__), 'ops', 'tensor_dot_list.json'))
TENSOR_DOT_LIST = load_json_file(TENSOR_DOT_LIST_PATH) TENSOR_DOT_LIST = load_json_file(TENSOR_DOT_LIST_PATH)
TENSOR_DOT_SUPPORTED = [x for x in TENSOR_DOT_LIST if x in ALL_MAPPING] TENSOR_DOT_SUPPORTED = [x for x in TENSOR_DOT_LIST if x in ALL_MAPPING]
TENSOR_DOT_UNSUPPORTED = [x for x in TENSOR_DOT_LIST if x not in ALL_MAPPING] TENSOR_DOT_UNSUPPORTED = [x for x in TENSOR_DOT_LIST if x not in ALL_MAPPING]
ALL_2P_LIST = F_LIST + TORCH_DOT_LIST + TENSOR_DOT_LIST ALL_2P_LIST = F_LIST + TORCH_DOT_LIST + TENSOR_DOT_LIST
ALL_TORCH_APIS = NN_LIST + F_LIST + TORCH_DOT_LIST + TENSOR_DOT_LIST ALL_TORCH_APIS = NN_LIST + F_LIST + TORCH_DOT_LIST + TENSOR_DOT_LIST
ALL_SUPPORTED = NN_SUPPORTED + F_SUPPORTED + TORCH_DOT_SUPPORTED + TENSOR_DOT_SUPPORTED ALL_SUPPORTED = NN_SUPPORTED + F_SUPPORTED + TORCH_DOT_SUPPORTED + TENSOR_DOT_SUPPORTED
ALL_UNSUPPORTED = NN_UNSUPPORTED + F_UNSUPPORTED + TORCH_DOT_UNSUPPORTED + TENSOR_DOT_UNSUPPORTED ALL_UNSUPPORTED = NN_UNSUPPORTED + F_UNSUPPORTED + TORCH_DOT_UNSUPPORTED + TENSOR_DOT_UNSUPPORTED
UNSUPPORTED_WARN_INFOS = { UNSUPPORTED_WARN_INFOS = {
"nn.AdaptiveAvgPool2d": "Maybe could convert to mindspore.ops.operations.ReduceMean.", "nn.AdaptiveAvgPool2d": "Maybe could convert to mindspore.ops.operations.ReduceMean.",
"nn.AvgPool1d": "Maybe could convert to mindspore.nn.AvgPool1d.", "nn.AvgPool1d": "Maybe could convert to mindspore.nn.AvgPool1d.",
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册