提交 dde197ce 编写于 作者: G ggpolar

Modify the prompt message and parse more statements.

1. More detailed reports are added to the conversion report.
2. The conversion prompt is provided for the 'isinstance' statement in the conversion script.
上级 0da8d2a8
......@@ -28,13 +28,13 @@ from mindinsight.mindconverter.config import ALL_MAPPING
from mindinsight.mindconverter.config import NN_LIST
from mindinsight.mindconverter.config import ALL_TORCH_APIS
from mindinsight.mindconverter.config import ALL_2P_LIST
from mindinsight.mindconverter.config import UNSUPPORTED_WARN_INFOS
from mindinsight.mindconverter.config import ALL_UNSUPPORTED
from mindinsight.mindconverter.config import get_prompt_info
from mindinsight.mindconverter.common.log import logger
from mindinsight.mindconverter.common.exceptions import NodeTypeNotSupport
from mindinsight.mindconverter.forward_call import ForwardCall
LOG_FMT_CONVERT = "[Convert] '%s' is converted to '%s'."
LOG_FMT_CONVERT_WITH_TIPS = "[Convert] '%s' is converted to '%s'. %s"
LOG_FMT_NOT_CONVERT = "[UnConvert] '%s' didn't convert. %s"
LOG_FMT_PROMPT_INFO = "[INFO] %s"
LOG_SUGGESTION_MANUAL_CONVERT = "Please manual convert the code, along with the code associated with it."
......@@ -95,6 +95,7 @@ class _LineColEditVisitor(ast.NodeVisitor):
class _NodeInfo:
"""NodeInfo class definition."""
def __init__(self, node):
self.node = node
self.call_list = [] # Used to save all ast.Call node in self._node
......@@ -444,19 +445,25 @@ class AstEditVisitor(ast.NodeVisitor):
is_include_call = False
return is_include_call
def match_api(self, call_func_node, is_forward):
def match_api(self, call_func_node, is_forward, check_context=True):
"""
Check api name to convert, check api name ok with a is_forward condition.
Args:
call_func_node (ast.Attribute): The call.func node.
is_forward (bool): whether api belong to forward.
check_context (boolean): If True, the code context will be checked. Default is True.
Returns:
str, the standard api name used to match.
ApiMappingEnum, the match result.
"""
api_name, match_case = self._infer_api_name(call_func_node)
match_case = ApiMatchingEnum.NOT_API
api_call_name = pasta.dump(call_func_node)
if api_call_name.startswith('self.'):
return api_call_name, match_case
api_name, match_case = self._infer_api_name(call_func_node, check_context)
api_call_name = pasta.dump(call_func_node)
is_tensor_obj_call = False
if api_name != api_call_name:
......@@ -466,15 +473,17 @@ class AstEditVisitor(ast.NodeVisitor):
# rewritten external module name
# e.g., mm.ReLU will be written to nn.ReLU if 'import torch.nn as mm' in script.
if not is_tensor_obj_call and not self._code_analyzer.is_standard_external_ref:
standard_api_call_name = self._mapping_standard_api_name(api_name)
if not is_tensor_obj_call:
standard_api_call_name = self._get_api_whole_name(call_func_node, check_context)
if standard_api_call_name in ALL_TORCH_APIS:
match_case = ApiMatchingEnum.API_FOUND
if (not is_forward and standard_api_call_name in NN_LIST) or \
(is_forward and standard_api_call_name in ALL_2P_LIST):
match_case = ApiMatchingEnum.API_MATCHED
else:
if standard_api_call_name and standard_api_call_name.startswith('torch.nn.init'):
match_case = ApiMatchingEnum.API_MATCHED
return standard_api_call_name, match_case
@staticmethod
......@@ -502,6 +511,25 @@ class AstEditVisitor(ast.NodeVisitor):
parameters_str = call_str[left_parenthesis_pos + 1:right_parenthesis_pos]
return parameters_str
def _get_api_whole_name(self, call_func_node, check_context=True):
"""
Get the whole name for the call node.
Args:
call_func_node (AST): The func attribute of ast.Call.
check_context (boolean): If True, the code context will be checked. Default is True.
Returns:
str, the whole name.
"""
api_name, match_case = self._infer_api_name(call_func_node, check_context)
if match_case == ApiMatchingEnum.API_STANDARD:
api_name_splits = api_name.split('.')
api_name_splits[0] = self._get_external_ref_whole_name(api_name_splits[0])
if api_name_splits[0]:
api_name = '.'.join(api_name_splits)
return api_name
def mapping_api(self, call_node, check_context=True):
"""
Convert api_name in code to MindSpore api, if api_name is a python api, code will not convert.
......@@ -522,6 +550,26 @@ class AstEditVisitor(ast.NodeVisitor):
if api_call_name.startswith('self.'):
return code
new_code = self._mapping_api(call_node, check_context)
return new_code
def _mapping_api(self, call_node, check_context=True):
"""
Convert api_name in code to MindSpore api, if api_name is a python api, code will not convert.
If do not check context of the script, the code represented by the node must be written in the standard way.
Args:
call_node (ast.Call): The ast node to convert.
check_context (boolean): If True, the code context will be checked. Default is True.
Returns:
str, the converted code.
"""
code = pasta.dump(call_node)
api_call_name = pasta.dump(call_node.func)
# find full api expected to be converted. eg:expr="nn.Conv2d(1,2,3)" args_str="(1,2,3)"
args_str = '(' + self._get_call_parameters_str(call_node) + ')'
......@@ -551,26 +599,37 @@ class AstEditVisitor(ast.NodeVisitor):
code = pasta.dump(node)
api_name = pasta.dump(node.func)
# parent node first call is equal to this node, skip when parent node is replaced.
for parent_node in self._stack[:-1]:
# The parent node first call is equal to this node, skip when parent node is replaced.
# This scenario occurs, for example, when out.view(out.size(0), -1) is first converted to
# P.Reshape()(out, (out.size(0). -1)), will skip P.Reshape() in following visiting.
# Access from the penultimate element in reverse order.
for parent_node in self._stack[-2::-1]:
if parent_node in self._new_call_nodes and pasta.dump(parent_node).startswith(api_name):
return
parent = self._stack[-2]
new_node = None
new_code = code
matched_api_name, match_case = self.match_api(node.func, self._is_forward_function)
if match_case in [ApiMatchingEnum.API_INFER, ApiMatchingEnum.API_MATCHED]:
warning_info = get_prompt_info(matched_api_name)
if warning_info is None:
warning_info = ''
if matched_api_name in ALL_MAPPING:
logger.info("Line %3d start converting API: %s", node.lineno, api_name)
new_code = self.mapping_api(node)
if new_code != code:
new_node = pasta.parse(new_code).body[0].value
# find the first call name
new_api_name = new_code[:new_code.find('(')]
self._process_log.info(node.lineno, node.col_offset, LOG_FMT_CONVERT % (api_name, new_api_name))
if matched_api_name in ALL_UNSUPPORTED:
warn_info = UNSUPPORTED_WARN_INFOS.get(api_name, '')
logger.warning("Line %3d: found unsupported API: %s%s", node.lineno, api_name, warn_info)
self._process_log.warning(node.lineno, node.col_offset, LOG_FMT_NOT_CONVERT % (api_name, warn_info))
try:
new_node = pasta.parse(new_code).body[0].value
# find the first call name
new_api_name = new_code[:new_code.find('(')]
except AttributeError:
new_node = pasta.parse(new_code).body[0]
new_api_name = new_code
self._process_log.info(node.lineno, node.col_offset,
LOG_FMT_CONVERT_WITH_TIPS % (api_name, new_api_name, warning_info))
else:
logger.warning("Line %3d: found unsupported API: %s%s", node.lineno, api_name, warning_info)
self._process_log.warning(node.lineno, node.col_offset, LOG_FMT_NOT_CONVERT % (api_name, warning_info))
elif match_case in [ApiMatchingEnum.API_STANDARD, ApiMatchingEnum.API_FOUND]:
self._process_log.warning(node.lineno, node.col_offset, LOG_FMT_NOT_CONVERT % (api_name, ''))
......@@ -602,3 +661,57 @@ class AstEditVisitor(ast.NodeVisitor):
elif ref_name != 'F' and external_ref_info.name == 'torch.nn.functional':
renames[ref_name] = 'F'
return renames
def _get_external_ref_whole_name(self, ref_name):
"""
Find out external reference whole name.
For example:
In the parsed source code, there is import statement
import torch.nn as new_name
_get_external_ref_whole_name('new_name') will return 'torch.nn' string.
"""
external_refs = self._code_analyzer.external_references
for external_ref_name, ref_info in external_refs.items():
external_ref_info = ref_info['external_ref_info']
if external_ref_name == ref_name:
return external_ref_info.name
return None
def _check_isinstance_parameter(self, node):
"""Check whether the second parameter of isinstance function contains the torch type."""
is_isinstance_arg = False
# Check whether node is the second parameter of the isinstance function call.
# Access from the penultimate element in reverse order.
for parent_node in self._stack[-2::-1]:
if isinstance(parent_node, ast.Call) and pasta.dump(parent_node.func) == 'isinstance':
isinstance_node = parent_node
seconde_arg_type_nodes = []
if isinstance(isinstance_node.args[1], ast.Tuple):
seconde_arg_type_nodes.extend(isinstance_node.args[1].elts)
else:
seconde_arg_type_nodes.append(isinstance_node.args[1])
if node in seconde_arg_type_nodes:
is_isinstance_arg = True
break
if not is_isinstance_arg:
return False
isinstance_type_arg = pasta.dump(node)
check_torch_type = False
if isinstance_type_arg:
type_splits = isinstance_type_arg.split('.')
whole_name = self._get_external_ref_whole_name(type_splits[0])
if whole_name and whole_name.startswith('torch'):
check_torch_type = True
if check_torch_type:
_, match_case = self.match_api(node, False)
if match_case != ApiMatchingEnum.NOT_API:
warn_info = 'Manually determine the conversion type.'
self._process_log.warning(node.lineno, node.col_offset,
LOG_FMT_NOT_CONVERT % (isinstance_type_arg, warn_info))
return check_torch_type
def visit_Attribute(self, node):
"""Callback function when visit AST tree"""
self._check_isinstance_parameter(node)
......@@ -24,6 +24,7 @@ class ConverterErrors(ScriptConverterErrors):
"""Converter error codes."""
SCRIPT_NOT_SUPPORT = 1
NODE_TYPE_NOT_SUPPORT = 2
CODE_SYNTAX_ERROR = 3
class ScriptNotSupport(MindInsightException):
......@@ -42,3 +43,12 @@ class NodeTypeNotSupport(MindInsightException):
super(NodeTypeNotSupport, self).__init__(ConverterErrors.NODE_TYPE_NOT_SUPPORT,
msg,
http_code=400)
class CodeSyntaxError(MindInsightException):
"""The CodeSyntaxError class definition."""
def __init__(self, msg):
super(CodeSyntaxError, self).__init__(ConverterErrors.CODE_SYNTAX_ERROR,
msg,
http_code=400)
......@@ -22,7 +22,7 @@ import os
import pasta
from mindinsight.mindconverter.common.log import logger
from mindinsight.mindconverter.common.exceptions import CodeSyntaxError
REQUIRED = 'REQUIRED'
UNREQUIRED = 'UNREQUIRED'
......@@ -31,6 +31,7 @@ FUNC_MODULE = 'mindinsight.mindconverter.funcs'
class APIPt:
"""Base API for args parse, and API for one frame."""
def __init__(self, name: str, params: OrderedDict):
self.name = name
self.params = OrderedDict()
......@@ -77,10 +78,8 @@ class APIPt:
try:
ast_node = ast.parse("whatever_call_name" + args_str)
call_node = ast_node.body[0].value
if not isinstance(call_node, ast.Call):
raise ValueError('call name with args str [{}] not instance of ast.Call'.format(args_str))
except:
raise ValueError("can't parse code:\n{}".format(args_str))
except SyntaxError as parse_error:
raise CodeSyntaxError("can't parse code:\n{}".format(args_str)) from parse_error
# regard all actual parameter as one parameter
if len(self.params) == 1:
......@@ -118,6 +117,7 @@ class APIPt:
class APIMs(APIPt):
"""API for MindSpore"""
def __init__(self, name: str, params: OrderedDict, p_attrs=None):
self.is_primitive = name.startswith('P.')
if self.is_primitive:
......@@ -167,6 +167,7 @@ class APIMs(APIPt):
class MappingHelper:
"""Mapping from one frame to another frame"""
def __init__(self, ms_api: APIMs, pt_api: APIPt, **kwargs):
ms2pt_mapping = kwargs.get('ms2pt_mapping')
gen_explicit_map = kwargs.get('gen_explicit_map')
......@@ -392,7 +393,6 @@ TENSOR_DOT_MAPPING = get_mapping_from_file(TENSOR_DOT_MAPPING_PATH)
ALL_MAPPING = {**NN_MAPPING, **F_MAPPING, **TORCH_DOT_MAPPING, **TENSOR_DOT_MAPPING}
# ---------------------------- api list support or not support ----------------------------
NN_LIST_PATH = os.path.realpath(os.path.join(os.path.dirname(__file__), 'ops', 'nn_list.json'))
NN_LIST = load_json_file(NN_LIST_PATH)
......@@ -400,7 +400,6 @@ NN_LIST += ["torch." + name for name in NN_LIST]
NN_SUPPORTED = [x for x in NN_LIST if x in ALL_MAPPING]
NN_UNSUPPORTED = [x for x in NN_LIST if x not in ALL_MAPPING]
F_LIST_PATH = os.path.realpath(os.path.join(os.path.dirname(__file__), 'ops', 'f_list.json'))
F_LIST = load_json_file(F_LIST_PATH)
F_LIST += ["F." + name[len("torch.nn.functional."):] for name in F_LIST] + \
......@@ -408,29 +407,23 @@ F_LIST += ["F." + name[len("torch.nn.functional."):] for name in F_LIST] + \
F_SUPPORTED = [x for x in F_LIST if x in ALL_MAPPING]
F_UNSUPPORTED = [x for x in F_LIST if x not in ALL_MAPPING]
TORCH_DOT_LIST_PATH = os.path.realpath(os.path.join(os.path.dirname(__file__), 'ops', 'torch_dot_list.json'))
TORCH_DOT_LIST = load_json_file(TORCH_DOT_LIST_PATH)
TORCH_DOT_SUPPORTED = [x for x in TORCH_DOT_LIST if x in ALL_MAPPING]
TORCH_DOT_UNSUPPORTED = [x for x in TORCH_DOT_LIST if x not in ALL_MAPPING]
TENSOR_DOT_LIST_PATH = os.path.realpath(os.path.join(os.path.dirname(__file__), 'ops', 'tensor_dot_list.json'))
TENSOR_DOT_LIST = load_json_file(TENSOR_DOT_LIST_PATH)
TENSOR_DOT_SUPPORTED = [x for x in TENSOR_DOT_LIST if x in ALL_MAPPING]
TENSOR_DOT_UNSUPPORTED = [x for x in TENSOR_DOT_LIST if x not in ALL_MAPPING]
ALL_2P_LIST = F_LIST + TORCH_DOT_LIST + TENSOR_DOT_LIST
ALL_TORCH_APIS = NN_LIST + F_LIST + TORCH_DOT_LIST + TENSOR_DOT_LIST
ALL_SUPPORTED = NN_SUPPORTED + F_SUPPORTED + TORCH_DOT_SUPPORTED + TENSOR_DOT_SUPPORTED
ALL_UNSUPPORTED = NN_UNSUPPORTED + F_UNSUPPORTED + TORCH_DOT_UNSUPPORTED + TENSOR_DOT_UNSUPPORTED
UNSUPPORTED_WARN_INFOS = {
"nn.AdaptiveAvgPool2d": "Maybe could convert to mindspore.ops.operations.ReduceMean.",
"nn.AvgPool1d": "Maybe could convert to mindspore.nn.AvgPool1d.",
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册