diff --git a/mindinsight/mindconverter/ast_edits.py b/mindinsight/mindconverter/ast_edits.py index 0a715211b41dc6734d55be5ee135fa09d7232bea..af4b0df1dab3d9db6d360f9dcf61c07a6a07e907 100644 --- a/mindinsight/mindconverter/ast_edits.py +++ b/mindinsight/mindconverter/ast_edits.py @@ -20,6 +20,7 @@ import re from enum import Enum import pasta +from pasta.base import formatting as fmt from mindinsight.mindconverter.code_analysis import CodeAnalyzer from mindinsight.mindconverter.code_analysis import APIAnalysisSpec @@ -27,7 +28,7 @@ from mindinsight.mindconverter.config import ALL_MAPPING from mindinsight.mindconverter.config import NN_LIST from mindinsight.mindconverter.config import ALL_TORCH_APIS from mindinsight.mindconverter.config import ALL_2P_LIST -from mindinsight.mindconverter.config import get_corresponding_ms_name +from mindinsight.mindconverter.config import TENSOR_DOT_LIST from mindinsight.mindconverter.config import get_prompt_info from mindinsight.mindconverter.common.log import logger from mindinsight.mindconverter.common.exceptions import NodeTypeNotSupport @@ -263,6 +264,22 @@ class AstEditVisitor(ast.NodeVisitor): self._process_log.info(base_class_node.lineno, base_class_node.col_offset, LOG_FMT_NOT_CONVERT % (old_code, '')) + @staticmethod + def _modify_function_name(func_def_node, new_func_name): + """Modify function name""" + if not isinstance(func_def_node, ast.FunctionDef): + raise NodeTypeNotSupport('It is not ast.FunctionDef node type.') + + old_func_name = func_def_node.name + func_def_node.name = new_func_name + + # Modify formatting information stored by pasta + old_function_def = fmt.get(func_def_node, 'function_def') + if old_function_def: + new_function_def = old_function_def.replace(old_func_name, new_func_name) + fmt.set(func_def_node, 'function_def', new_function_def) + fmt.set(func_def_node, 'name__src', new_func_name) + def _update_function_def(self, func_scope): """ Convert a PyTorch function into MindSpore function. @@ -279,7 +296,7 @@ class AstEditVisitor(ast.NodeVisitor): old_func_name = 'forward' new_func_name = 'construct' if func_ast_node.name == old_func_name: - func_ast_node.name = new_func_name + self._modify_function_name(func_ast_node, new_func_name) real_line_number = self._get_real_line_number(func_ast_node) self._process_log.info(real_line_number, func_ast_node.col_offset, LOG_FMT_CONVERT % (old_func_name, new_func_name)) @@ -496,12 +513,33 @@ class AstEditVisitor(ast.NodeVisitor): # only infer function for tensor object. # e.g., api_call_name is out.view, .view is an api name for out which is maybe a tensor object. # e.g., 'xxxx'.size can be not inferred to .size, because string is not a tensor object. - first_name = standard_name.split('.')[0] - if not re.search(r'\W', first_name) and len(name_attributes) > 1: + if self._check_tensor_object(call_func_node): api_name = '.' + name_attributes[-1] match_case = ApiMatchingEnum.API_INFER return api_name, match_case + def _check_tensor_object(self, node): + """Check whether the reference object of the node is a tensor object.""" + if not isinstance(node, (ast.Attribute, ast.Name)): + return False + name_attributes = self._dump_without_prefix(node).split('.') + node_ref_name = name_attributes[0] + if re.search(r'\W', node_ref_name) or len(name_attributes) == 1: + return False + + func_name = '.' + name_attributes[-1] + if func_name not in TENSOR_DOT_LIST: + return False + + is_tensor_object = True + if self._code_analyzer: + # Check whether the object is external reference. + for ref_name in self._code_analyzer.external_references: + if node_ref_name == ref_name: + is_tensor_object = False + break + return is_tensor_object + @staticmethod def _is_include_sub_call(call_func_node): """"Inspect a sub call in call expression. @@ -671,6 +709,55 @@ class AstEditVisitor(ast.NodeVisitor): return new_code + @staticmethod + def _get_detail_prompt_msg(old_node, new_node): + """Get detail converted prompt information.""" + msg = None + if isinstance(old_node, ast.Call) and isinstance(new_node, ast.Call): + old_api_name = pasta.dump(old_node.func) + new_api_name = pasta.dump(new_node.func) + if new_api_name == old_api_name: + old_parameter_num = len(old_node.args) + len(old_node.keywords) + new_parameter_num = len(new_node.args) + len(new_node.keywords) + if old_parameter_num > 1: + msg = 'Parameters are converted.' + else: + if old_parameter_num == 0 and new_parameter_num == 0: + msg = 'The API name is converted to mindspore API' + else: + msg = 'Parameter is converted.' + return msg + + def _convert_call(self, node, matched_api_name): + """"Convert the call node.""" + new_node = None + code = pasta.dump(node) + api_name = pasta.dump(node.func) + warning_info = get_prompt_info(matched_api_name) + if warning_info is None: + warning_info = '' + if matched_api_name in ALL_MAPPING: + logger.info("Line %3d start converting API: %s", node.lineno, api_name) + new_code = self.mapping_api(node) + if new_code != code: + try: + new_node = pasta.parse(new_code).body[0].value + # find the first call name + new_api_name = new_code[:new_code.find('(')] + detail_msg = self._get_detail_prompt_msg(node, new_node) + if detail_msg: + warning_info = detail_msg + ' ' + warning_info + except AttributeError: + new_node = pasta.parse(new_code).body[0] + new_api_name = new_code + self._process_log.info(node.lineno, node.col_offset, + LOG_FMT_CONVERT_WITH_TIPS % (api_name, new_api_name, warning_info)) + else: + logger.warning("Line %3d: found unsupported API: %s%s", node.lineno, api_name, warning_info) + self._process_log.warning(node.lineno, node.col_offset, LOG_FMT_NOT_CONVERT % (api_name, warning_info)) + + return new_node + def visit_Call(self, node): """Callback function when visit AST tree""" code = pasta.dump(node) @@ -688,26 +775,7 @@ class AstEditVisitor(ast.NodeVisitor): new_code = code matched_api_name, match_case = self.match_api(node.func, self._is_forward_function) if match_case in [ApiMatchingEnum.API_INFER, ApiMatchingEnum.API_MATCHED]: - warning_info = get_prompt_info(matched_api_name) - if warning_info is None: - warning_info = '' - if matched_api_name in ALL_MAPPING: - logger.info("Line %3d start converting API: %s", node.lineno, api_name) - new_code = self.mapping_api(node) - if new_code != code: - try: - new_node = pasta.parse(new_code).body[0].value - # find the first call name - new_api_name = get_corresponding_ms_name(matched_api_name) - except AttributeError: - new_node = pasta.parse(new_code).body[0] - new_api_name = new_code - self._process_log.info(node.lineno, node.col_offset, - LOG_FMT_CONVERT_WITH_TIPS % (api_name, new_api_name, warning_info)) - else: - logger.warning("Line %3d: found unsupported API: %s%s", node.lineno, api_name, warning_info) - self._process_log.warning(node.lineno, node.col_offset, LOG_FMT_NOT_CONVERT % (api_name, warning_info)) - + new_node = self._convert_call(node, matched_api_name) elif match_case in [ApiMatchingEnum.API_STANDARD, ApiMatchingEnum.API_FOUND]: self._process_log.warning(node.lineno, node.col_offset, LOG_FMT_NOT_CONVERT % (api_name, '')) else: