提交 f1f3dbc4 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!306 Converter: Modify the prompt message and parse more statements.

Merge pull request !306 from ggpolar/br_wzk_dev
......@@ -28,13 +28,13 @@ from mindinsight.mindconverter.config import ALL_MAPPING
from mindinsight.mindconverter.config import NN_LIST
from mindinsight.mindconverter.config import ALL_TORCH_APIS
from mindinsight.mindconverter.config import ALL_2P_LIST
from mindinsight.mindconverter.config import UNSUPPORTED_WARN_INFOS
from mindinsight.mindconverter.config import ALL_UNSUPPORTED
from mindinsight.mindconverter.config import get_prompt_info
from mindinsight.mindconverter.common.log import logger
from mindinsight.mindconverter.common.exceptions import NodeTypeNotSupport
from mindinsight.mindconverter.forward_call import ForwardCall
LOG_FMT_CONVERT = "[Convert] '%s' is converted to '%s'."
LOG_FMT_CONVERT_WITH_TIPS = "[Convert] '%s' is converted to '%s'. %s"
LOG_FMT_NOT_CONVERT = "[UnConvert] '%s' didn't convert. %s"
LOG_FMT_PROMPT_INFO = "[INFO] %s"
LOG_SUGGESTION_MANUAL_CONVERT = "Please manual convert the code, along with the code associated with it."
......@@ -95,6 +95,7 @@ class _LineColEditVisitor(ast.NodeVisitor):
class _NodeInfo:
"""NodeInfo class definition."""
def __init__(self, node):
self.node = node
self.call_list = [] # Used to save all ast.Call node in self._node
......@@ -444,19 +445,25 @@ class AstEditVisitor(ast.NodeVisitor):
is_include_call = False
return is_include_call
def match_api(self, call_func_node, is_forward):
def match_api(self, call_func_node, is_forward, check_context=True):
"""
Check api name to convert, check api name ok with a is_forward condition.
Args:
call_func_node (ast.Attribute): The call.func node.
is_forward (bool): whether api belong to forward.
check_context (boolean): If True, the code context will be checked. Default is True.
Returns:
str, the standard api name used to match.
ApiMappingEnum, the match result.
"""
api_name, match_case = self._infer_api_name(call_func_node)
match_case = ApiMatchingEnum.NOT_API
api_call_name = pasta.dump(call_func_node)
if api_call_name.startswith('self.'):
return api_call_name, match_case
api_name, match_case = self._infer_api_name(call_func_node, check_context)
api_call_name = pasta.dump(call_func_node)
is_tensor_obj_call = False
if api_name != api_call_name:
......@@ -466,15 +473,17 @@ class AstEditVisitor(ast.NodeVisitor):
# rewritten external module name
# e.g., mm.ReLU will be written to nn.ReLU if 'import torch.nn as mm' in script.
if not is_tensor_obj_call and not self._code_analyzer.is_standard_external_ref:
standard_api_call_name = self._mapping_standard_api_name(api_name)
if not is_tensor_obj_call:
standard_api_call_name = self._get_api_whole_name(call_func_node, check_context)
if standard_api_call_name in ALL_TORCH_APIS:
match_case = ApiMatchingEnum.API_FOUND
if (not is_forward and standard_api_call_name in NN_LIST) or \
(is_forward and standard_api_call_name in ALL_2P_LIST):
match_case = ApiMatchingEnum.API_MATCHED
else:
if standard_api_call_name and standard_api_call_name.startswith('torch.nn.init'):
match_case = ApiMatchingEnum.API_MATCHED
return standard_api_call_name, match_case
@staticmethod
......@@ -502,6 +511,25 @@ class AstEditVisitor(ast.NodeVisitor):
parameters_str = call_str[left_parenthesis_pos + 1:right_parenthesis_pos]
return parameters_str
def _get_api_whole_name(self, call_func_node, check_context=True):
"""
Get the whole name for the call node.
Args:
call_func_node (AST): The func attribute of ast.Call.
check_context (boolean): If True, the code context will be checked. Default is True.
Returns:
str, the whole name.
"""
api_name, match_case = self._infer_api_name(call_func_node, check_context)
if match_case == ApiMatchingEnum.API_STANDARD:
api_name_splits = api_name.split('.')
api_name_splits[0] = self._get_external_ref_whole_name(api_name_splits[0])
if api_name_splits[0]:
api_name = '.'.join(api_name_splits)
return api_name
def mapping_api(self, call_node, check_context=True):
"""
Convert api_name in code to MindSpore api, if api_name is a python api, code will not convert.
......@@ -522,6 +550,26 @@ class AstEditVisitor(ast.NodeVisitor):
if api_call_name.startswith('self.'):
return code
new_code = self._mapping_api(call_node, check_context)
return new_code
def _mapping_api(self, call_node, check_context=True):
"""
Convert api_name in code to MindSpore api, if api_name is a python api, code will not convert.
If do not check context of the script, the code represented by the node must be written in the standard way.
Args:
call_node (ast.Call): The ast node to convert.
check_context (boolean): If True, the code context will be checked. Default is True.
Returns:
str, the converted code.
"""
code = pasta.dump(call_node)
api_call_name = pasta.dump(call_node.func)
# find full api expected to be converted. eg:expr="nn.Conv2d(1,2,3)" args_str="(1,2,3)"
args_str = '(' + self._get_call_parameters_str(call_node) + ')'
......@@ -551,26 +599,37 @@ class AstEditVisitor(ast.NodeVisitor):
code = pasta.dump(node)
api_name = pasta.dump(node.func)
# parent node first call is equal to this node, skip when parent node is replaced.
for parent_node in self._stack[:-1]:
# The parent node first call is equal to this node, skip when parent node is replaced.
# This scenario occurs, for example, when out.view(out.size(0), -1) is first converted to
# P.Reshape()(out, (out.size(0). -1)), will skip P.Reshape() in following visiting.
# Access from the penultimate element in reverse order.
for parent_node in self._stack[-2::-1]:
if parent_node in self._new_call_nodes and pasta.dump(parent_node).startswith(api_name):
return
parent = self._stack[-2]
new_node = None
new_code = code
matched_api_name, match_case = self.match_api(node.func, self._is_forward_function)
if match_case in [ApiMatchingEnum.API_INFER, ApiMatchingEnum.API_MATCHED]:
warning_info = get_prompt_info(matched_api_name)
if warning_info is None:
warning_info = ''
if matched_api_name in ALL_MAPPING:
logger.info("Line %3d start converting API: %s", node.lineno, api_name)
new_code = self.mapping_api(node)
if new_code != code:
try:
new_node = pasta.parse(new_code).body[0].value
# find the first call name
new_api_name = new_code[:new_code.find('(')]
self._process_log.info(node.lineno, node.col_offset, LOG_FMT_CONVERT % (api_name, new_api_name))
if matched_api_name in ALL_UNSUPPORTED:
warn_info = UNSUPPORTED_WARN_INFOS.get(api_name, '')
logger.warning("Line %3d: found unsupported API: %s%s", node.lineno, api_name, warn_info)
self._process_log.warning(node.lineno, node.col_offset, LOG_FMT_NOT_CONVERT % (api_name, warn_info))
except AttributeError:
new_node = pasta.parse(new_code).body[0]
new_api_name = new_code
self._process_log.info(node.lineno, node.col_offset,
LOG_FMT_CONVERT_WITH_TIPS % (api_name, new_api_name, warning_info))
else:
logger.warning("Line %3d: found unsupported API: %s%s", node.lineno, api_name, warning_info)
self._process_log.warning(node.lineno, node.col_offset, LOG_FMT_NOT_CONVERT % (api_name, warning_info))
elif match_case in [ApiMatchingEnum.API_STANDARD, ApiMatchingEnum.API_FOUND]:
self._process_log.warning(node.lineno, node.col_offset, LOG_FMT_NOT_CONVERT % (api_name, ''))
......@@ -602,3 +661,57 @@ class AstEditVisitor(ast.NodeVisitor):
elif ref_name != 'F' and external_ref_info.name == 'torch.nn.functional':
renames[ref_name] = 'F'
return renames
def _get_external_ref_whole_name(self, ref_name):
"""
Find out external reference whole name.
For example:
In the parsed source code, there is import statement
import torch.nn as new_name
_get_external_ref_whole_name('new_name') will return 'torch.nn' string.
"""
external_refs = self._code_analyzer.external_references
for external_ref_name, ref_info in external_refs.items():
external_ref_info = ref_info['external_ref_info']
if external_ref_name == ref_name:
return external_ref_info.name
return None
def _check_isinstance_parameter(self, node):
"""Check whether the second parameter of isinstance function contains the torch type."""
is_isinstance_arg = False
# Check whether node is the second parameter of the isinstance function call.
# Access from the penultimate element in reverse order.
for parent_node in self._stack[-2::-1]:
if isinstance(parent_node, ast.Call) and pasta.dump(parent_node.func) == 'isinstance':
isinstance_node = parent_node
seconde_arg_type_nodes = []
if isinstance(isinstance_node.args[1], ast.Tuple):
seconde_arg_type_nodes.extend(isinstance_node.args[1].elts)
else:
seconde_arg_type_nodes.append(isinstance_node.args[1])
if node in seconde_arg_type_nodes:
is_isinstance_arg = True
break
if not is_isinstance_arg:
return False
isinstance_type_arg = pasta.dump(node)
check_torch_type = False
if isinstance_type_arg:
type_splits = isinstance_type_arg.split('.')
whole_name = self._get_external_ref_whole_name(type_splits[0])
if whole_name and whole_name.startswith('torch'):
check_torch_type = True
if check_torch_type:
_, match_case = self.match_api(node, False)
if match_case != ApiMatchingEnum.NOT_API:
warn_info = 'Manually determine the conversion type.'
self._process_log.warning(node.lineno, node.col_offset,
LOG_FMT_NOT_CONVERT % (isinstance_type_arg, warn_info))
return check_torch_type
def visit_Attribute(self, node):
"""Callback function when visit AST tree"""
self._check_isinstance_parameter(node)
......@@ -24,6 +24,7 @@ class ConverterErrors(ScriptConverterErrors):
"""Converter error codes."""
SCRIPT_NOT_SUPPORT = 1
NODE_TYPE_NOT_SUPPORT = 2
CODE_SYNTAX_ERROR = 3
class ScriptNotSupport(MindInsightException):
......@@ -42,3 +43,12 @@ class NodeTypeNotSupport(MindInsightException):
super(NodeTypeNotSupport, self).__init__(ConverterErrors.NODE_TYPE_NOT_SUPPORT,
msg,
http_code=400)
class CodeSyntaxError(MindInsightException):
"""The CodeSyntaxError class definition."""
def __init__(self, msg):
super(CodeSyntaxError, self).__init__(ConverterErrors.CODE_SYNTAX_ERROR,
msg,
http_code=400)
......@@ -22,7 +22,7 @@ import os
import pasta
from mindinsight.mindconverter.common.log import logger
from mindinsight.mindconverter.common.exceptions import CodeSyntaxError
REQUIRED = 'REQUIRED'
UNREQUIRED = 'UNREQUIRED'
......@@ -31,6 +31,7 @@ FUNC_MODULE = 'mindinsight.mindconverter.funcs'
class APIPt:
"""Base API for args parse, and API for one frame."""
def __init__(self, name: str, params: OrderedDict):
self.name = name
self.params = OrderedDict()
......@@ -77,10 +78,8 @@ class APIPt:
try:
ast_node = ast.parse("whatever_call_name" + args_str)
call_node = ast_node.body[0].value
if not isinstance(call_node, ast.Call):
raise ValueError('call name with args str [{}] not instance of ast.Call'.format(args_str))
except:
raise ValueError("can't parse code:\n{}".format(args_str))
except SyntaxError as parse_error:
raise CodeSyntaxError("can't parse code:\n{}".format(args_str)) from parse_error
# regard all actual parameter as one parameter
if len(self.params) == 1:
......@@ -118,6 +117,7 @@ class APIPt:
class APIMs(APIPt):
"""API for MindSpore"""
def __init__(self, name: str, params: OrderedDict, p_attrs=None):
self.is_primitive = name.startswith('P.')
if self.is_primitive:
......@@ -167,6 +167,7 @@ class APIMs(APIPt):
class MappingHelper:
"""Mapping from one frame to another frame"""
def __init__(self, ms_api: APIMs, pt_api: APIPt, **kwargs):
ms2pt_mapping = kwargs.get('ms2pt_mapping')
gen_explicit_map = kwargs.get('gen_explicit_map')
......@@ -392,7 +393,6 @@ TENSOR_DOT_MAPPING = get_mapping_from_file(TENSOR_DOT_MAPPING_PATH)
ALL_MAPPING = {**NN_MAPPING, **F_MAPPING, **TORCH_DOT_MAPPING, **TENSOR_DOT_MAPPING}
# ---------------------------- api list support or not support ----------------------------
NN_LIST_PATH = os.path.realpath(os.path.join(os.path.dirname(__file__), 'ops', 'nn_list.json'))
NN_LIST = load_json_file(NN_LIST_PATH)
......@@ -400,7 +400,6 @@ NN_LIST += ["torch." + name for name in NN_LIST]
NN_SUPPORTED = [x for x in NN_LIST if x in ALL_MAPPING]
NN_UNSUPPORTED = [x for x in NN_LIST if x not in ALL_MAPPING]
F_LIST_PATH = os.path.realpath(os.path.join(os.path.dirname(__file__), 'ops', 'f_list.json'))
F_LIST = load_json_file(F_LIST_PATH)
F_LIST += ["F." + name[len("torch.nn.functional."):] for name in F_LIST] + \
......@@ -408,29 +407,23 @@ F_LIST += ["F." + name[len("torch.nn.functional."):] for name in F_LIST] + \
F_SUPPORTED = [x for x in F_LIST if x in ALL_MAPPING]
F_UNSUPPORTED = [x for x in F_LIST if x not in ALL_MAPPING]
TORCH_DOT_LIST_PATH = os.path.realpath(os.path.join(os.path.dirname(__file__), 'ops', 'torch_dot_list.json'))
TORCH_DOT_LIST = load_json_file(TORCH_DOT_LIST_PATH)
TORCH_DOT_SUPPORTED = [x for x in TORCH_DOT_LIST if x in ALL_MAPPING]
TORCH_DOT_UNSUPPORTED = [x for x in TORCH_DOT_LIST if x not in ALL_MAPPING]
TENSOR_DOT_LIST_PATH = os.path.realpath(os.path.join(os.path.dirname(__file__), 'ops', 'tensor_dot_list.json'))
TENSOR_DOT_LIST = load_json_file(TENSOR_DOT_LIST_PATH)
TENSOR_DOT_SUPPORTED = [x for x in TENSOR_DOT_LIST if x in ALL_MAPPING]
TENSOR_DOT_UNSUPPORTED = [x for x in TENSOR_DOT_LIST if x not in ALL_MAPPING]
ALL_2P_LIST = F_LIST + TORCH_DOT_LIST + TENSOR_DOT_LIST
ALL_TORCH_APIS = NN_LIST + F_LIST + TORCH_DOT_LIST + TENSOR_DOT_LIST
ALL_SUPPORTED = NN_SUPPORTED + F_SUPPORTED + TORCH_DOT_SUPPORTED + TENSOR_DOT_SUPPORTED
ALL_UNSUPPORTED = NN_UNSUPPORTED + F_UNSUPPORTED + TORCH_DOT_UNSUPPORTED + TENSOR_DOT_UNSUPPORTED
UNSUPPORTED_WARN_INFOS = {
"nn.AdaptiveAvgPool2d": "Maybe could convert to mindspore.ops.operations.ReduceMean.",
"nn.AvgPool1d": "Maybe could convert to mindspore.nn.AvgPool1d.",
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册