diff --git a/mindinsight/mindconverter/ast_edits.py b/mindinsight/mindconverter/ast_edits.py index 0a715211b41dc6734d55be5ee135fa09d7232bea..40b79d70562fc61fcf9e2482102b02fed484f76b 100644 --- a/mindinsight/mindconverter/ast_edits.py +++ b/mindinsight/mindconverter/ast_edits.py @@ -27,7 +27,6 @@ from mindinsight.mindconverter.config import ALL_MAPPING from mindinsight.mindconverter.config import NN_LIST from mindinsight.mindconverter.config import ALL_TORCH_APIS from mindinsight.mindconverter.config import ALL_2P_LIST -from mindinsight.mindconverter.config import get_corresponding_ms_name from mindinsight.mindconverter.config import get_prompt_info from mindinsight.mindconverter.common.log import logger from mindinsight.mindconverter.common.exceptions import NodeTypeNotSupport @@ -671,6 +670,55 @@ class AstEditVisitor(ast.NodeVisitor): return new_code + @staticmethod + def _get_detail_prompt_msg(old_node, new_node): + """Get detail converted prompt information.""" + msg = None + if isinstance(old_node, ast.Call) and isinstance(new_node, ast.Call): + old_api_name = pasta.dump(old_node.func) + new_api_name = pasta.dump(new_node.func) + if new_api_name == old_api_name: + old_parameter_num = len(old_node.args) + len(old_node.keywords) + new_parameter_num = len(new_node.args) + len(new_node.keywords) + if old_parameter_num > 1: + msg = 'Parameters are converted.' + else: + if old_parameter_num == 0 and new_parameter_num == 0: + msg = 'The API name is converted to mindspore API' + else: + msg = 'Parameter is converted.' + return msg + + def _convert_call(self, node, matched_api_name): + """"Convert the call node.""" + new_node = None + code = pasta.dump(node) + api_name = pasta.dump(node.func) + warning_info = get_prompt_info(matched_api_name) + if warning_info is None: + warning_info = '' + if matched_api_name in ALL_MAPPING: + logger.info("Line %3d start converting API: %s", node.lineno, api_name) + new_code = self.mapping_api(node) + if new_code != code: + try: + new_node = pasta.parse(new_code).body[0].value + # find the first call name + new_api_name = new_code[:new_code.find('(')] + detail_msg = self._get_detail_prompt_msg(node, new_node) + if detail_msg: + warning_info = detail_msg + ' ' + warning_info + except AttributeError: + new_node = pasta.parse(new_code).body[0] + new_api_name = new_code + self._process_log.info(node.lineno, node.col_offset, + LOG_FMT_CONVERT_WITH_TIPS % (api_name, new_api_name, warning_info)) + else: + logger.warning("Line %3d: found unsupported API: %s%s", node.lineno, api_name, warning_info) + self._process_log.warning(node.lineno, node.col_offset, LOG_FMT_NOT_CONVERT % (api_name, warning_info)) + + return new_node + def visit_Call(self, node): """Callback function when visit AST tree""" code = pasta.dump(node) @@ -688,26 +736,7 @@ class AstEditVisitor(ast.NodeVisitor): new_code = code matched_api_name, match_case = self.match_api(node.func, self._is_forward_function) if match_case in [ApiMatchingEnum.API_INFER, ApiMatchingEnum.API_MATCHED]: - warning_info = get_prompt_info(matched_api_name) - if warning_info is None: - warning_info = '' - if matched_api_name in ALL_MAPPING: - logger.info("Line %3d start converting API: %s", node.lineno, api_name) - new_code = self.mapping_api(node) - if new_code != code: - try: - new_node = pasta.parse(new_code).body[0].value - # find the first call name - new_api_name = get_corresponding_ms_name(matched_api_name) - except AttributeError: - new_node = pasta.parse(new_code).body[0] - new_api_name = new_code - self._process_log.info(node.lineno, node.col_offset, - LOG_FMT_CONVERT_WITH_TIPS % (api_name, new_api_name, warning_info)) - else: - logger.warning("Line %3d: found unsupported API: %s%s", node.lineno, api_name, warning_info) - self._process_log.warning(node.lineno, node.col_offset, LOG_FMT_NOT_CONVERT % (api_name, warning_info)) - + new_node = self._convert_call(node, matched_api_name) elif match_case in [ApiMatchingEnum.API_STANDARD, ApiMatchingEnum.API_FOUND]: self._process_log.warning(node.lineno, node.col_offset, LOG_FMT_NOT_CONVERT % (api_name, '')) else: