提交 ac2ad193 编写于 作者: G ggpolar

Improve the accuracy of the converted report.

1. Enhances the recognition of tensor object.
Rule 1: The API must be within the supported tensor range.
Rule 2: Excluding external reference object.

2. The modified information should reflect the modification difference.
If the API name does not changed, indicating whether the parameter is modified.
上级 c4378c37
......@@ -20,6 +20,7 @@ import re
from enum import Enum
import pasta
from pasta.base import formatting as fmt
from mindinsight.mindconverter.code_analysis import CodeAnalyzer
from mindinsight.mindconverter.code_analysis import APIAnalysisSpec
......@@ -27,7 +28,7 @@ from mindinsight.mindconverter.config import ALL_MAPPING
from mindinsight.mindconverter.config import NN_LIST
from mindinsight.mindconverter.config import ALL_TORCH_APIS
from mindinsight.mindconverter.config import ALL_2P_LIST
from mindinsight.mindconverter.config import get_corresponding_ms_name
from mindinsight.mindconverter.config import TENSOR_DOT_LIST
from mindinsight.mindconverter.config import get_prompt_info
from mindinsight.mindconverter.common.log import logger
from mindinsight.mindconverter.common.exceptions import NodeTypeNotSupport
......@@ -263,6 +264,22 @@ class AstEditVisitor(ast.NodeVisitor):
self._process_log.info(base_class_node.lineno, base_class_node.col_offset, LOG_FMT_NOT_CONVERT %
(old_code, ''))
@staticmethod
def _modify_function_name(func_def_node, new_func_name):
"""Modify function name"""
if not isinstance(func_def_node, ast.FunctionDef):
raise NodeTypeNotSupport('It is not ast.FunctionDef node type.')
old_func_name = func_def_node.name
func_def_node.name = new_func_name
# Modify formatting information stored by pasta
old_function_def = fmt.get(func_def_node, 'function_def')
if old_function_def:
new_function_def = old_function_def.replace(old_func_name, new_func_name)
fmt.set(func_def_node, 'function_def', new_function_def)
fmt.set(func_def_node, 'name__src', new_func_name)
def _update_function_def(self, func_scope):
"""
Convert a PyTorch function into MindSpore function.
......@@ -279,7 +296,7 @@ class AstEditVisitor(ast.NodeVisitor):
old_func_name = 'forward'
new_func_name = 'construct'
if func_ast_node.name == old_func_name:
func_ast_node.name = new_func_name
self._modify_function_name(func_ast_node, new_func_name)
real_line_number = self._get_real_line_number(func_ast_node)
self._process_log.info(real_line_number, func_ast_node.col_offset,
LOG_FMT_CONVERT % (old_func_name, new_func_name))
......@@ -496,12 +513,33 @@ class AstEditVisitor(ast.NodeVisitor):
# only infer function for tensor object.
# e.g., api_call_name is out.view, .view is an api name for out which is maybe a tensor object.
# e.g., 'xxxx'.size can be not inferred to .size, because string is not a tensor object.
first_name = standard_name.split('.')[0]
if not re.search(r'\W', first_name) and len(name_attributes) > 1:
if self._check_tensor_object(call_func_node):
api_name = '.' + name_attributes[-1]
match_case = ApiMatchingEnum.API_INFER
return api_name, match_case
def _check_tensor_object(self, node):
"""Check whether the reference object of the node is a tensor object."""
if not isinstance(node, (ast.Attribute, ast.Name)):
return False
name_attributes = self._dump_without_prefix(node).split('.')
node_ref_name = name_attributes[0]
if re.search(r'\W', node_ref_name) or len(name_attributes) == 1:
return False
func_name = '.' + name_attributes[-1]
if func_name not in TENSOR_DOT_LIST:
return False
is_tensor_object = True
if self._code_analyzer:
# Check whether the object is external reference.
for ref_name in self._code_analyzer.external_references:
if node_ref_name == ref_name:
is_tensor_object = False
break
return is_tensor_object
@staticmethod
def _is_include_sub_call(call_func_node):
""""Inspect a sub call in call expression.
......@@ -671,6 +709,55 @@ class AstEditVisitor(ast.NodeVisitor):
return new_code
@staticmethod
def _get_detail_prompt_msg(old_node, new_node):
"""Get detail converted prompt information."""
msg = None
if isinstance(old_node, ast.Call) and isinstance(new_node, ast.Call):
old_api_name = pasta.dump(old_node.func)
new_api_name = pasta.dump(new_node.func)
if new_api_name == old_api_name:
old_parameter_num = len(old_node.args) + len(old_node.keywords)
new_parameter_num = len(new_node.args) + len(new_node.keywords)
if old_parameter_num > 1:
msg = 'Parameters are converted.'
else:
if old_parameter_num == 0 and new_parameter_num == 0:
msg = 'The API name is converted to mindspore API'
else:
msg = 'Parameter is converted.'
return msg
def _convert_call(self, node, matched_api_name):
""""Convert the call node."""
new_node = None
code = pasta.dump(node)
api_name = pasta.dump(node.func)
warning_info = get_prompt_info(matched_api_name)
if warning_info is None:
warning_info = ''
if matched_api_name in ALL_MAPPING:
logger.info("Line %3d start converting API: %s", node.lineno, api_name)
new_code = self.mapping_api(node)
if new_code != code:
try:
new_node = pasta.parse(new_code).body[0].value
# find the first call name
new_api_name = new_code[:new_code.find('(')]
detail_msg = self._get_detail_prompt_msg(node, new_node)
if detail_msg:
warning_info = detail_msg + ' ' + warning_info
except AttributeError:
new_node = pasta.parse(new_code).body[0]
new_api_name = new_code
self._process_log.info(node.lineno, node.col_offset,
LOG_FMT_CONVERT_WITH_TIPS % (api_name, new_api_name, warning_info))
else:
logger.warning("Line %3d: found unsupported API: %s%s", node.lineno, api_name, warning_info)
self._process_log.warning(node.lineno, node.col_offset, LOG_FMT_NOT_CONVERT % (api_name, warning_info))
return new_node
def visit_Call(self, node):
"""Callback function when visit AST tree"""
code = pasta.dump(node)
......@@ -688,26 +775,7 @@ class AstEditVisitor(ast.NodeVisitor):
new_code = code
matched_api_name, match_case = self.match_api(node.func, self._is_forward_function)
if match_case in [ApiMatchingEnum.API_INFER, ApiMatchingEnum.API_MATCHED]:
warning_info = get_prompt_info(matched_api_name)
if warning_info is None:
warning_info = ''
if matched_api_name in ALL_MAPPING:
logger.info("Line %3d start converting API: %s", node.lineno, api_name)
new_code = self.mapping_api(node)
if new_code != code:
try:
new_node = pasta.parse(new_code).body[0].value
# find the first call name
new_api_name = get_corresponding_ms_name(matched_api_name)
except AttributeError:
new_node = pasta.parse(new_code).body[0]
new_api_name = new_code
self._process_log.info(node.lineno, node.col_offset,
LOG_FMT_CONVERT_WITH_TIPS % (api_name, new_api_name, warning_info))
else:
logger.warning("Line %3d: found unsupported API: %s%s", node.lineno, api_name, warning_info)
self._process_log.warning(node.lineno, node.col_offset, LOG_FMT_NOT_CONVERT % (api_name, warning_info))
new_node = self._convert_call(node, matched_api_name)
elif match_case in [ApiMatchingEnum.API_STANDARD, ApiMatchingEnum.API_FOUND]:
self._process_log.warning(node.lineno, node.col_offset, LOG_FMT_NOT_CONVERT % (api_name, ''))
else:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册