提交 b4709a0b 编写于 作者: G ggpolar

The modified information should reflect the modification difference.

If the API name does not changed, indicating whether the parameter is modified.
上级 70f854b2
...@@ -27,7 +27,6 @@ from mindinsight.mindconverter.config import ALL_MAPPING ...@@ -27,7 +27,6 @@ from mindinsight.mindconverter.config import ALL_MAPPING
from mindinsight.mindconverter.config import NN_LIST from mindinsight.mindconverter.config import NN_LIST
from mindinsight.mindconverter.config import ALL_TORCH_APIS from mindinsight.mindconverter.config import ALL_TORCH_APIS
from mindinsight.mindconverter.config import ALL_2P_LIST from mindinsight.mindconverter.config import ALL_2P_LIST
from mindinsight.mindconverter.config import get_corresponding_ms_name
from mindinsight.mindconverter.config import get_prompt_info from mindinsight.mindconverter.config import get_prompt_info
from mindinsight.mindconverter.common.log import logger from mindinsight.mindconverter.common.log import logger
from mindinsight.mindconverter.common.exceptions import NodeTypeNotSupport from mindinsight.mindconverter.common.exceptions import NodeTypeNotSupport
...@@ -671,6 +670,55 @@ class AstEditVisitor(ast.NodeVisitor): ...@@ -671,6 +670,55 @@ class AstEditVisitor(ast.NodeVisitor):
return new_code return new_code
@staticmethod
def _get_detail_prompt_msg(old_node, new_node):
"""Get detail converted prompt information."""
msg = None
if isinstance(old_node, ast.Call) and isinstance(new_node, ast.Call):
old_api_name = pasta.dump(old_node.func)
new_api_name = pasta.dump(new_node.func)
if new_api_name == old_api_name:
old_parameter_num = len(old_node.args) + len(old_node.keywords)
new_parameter_num = len(new_node.args) + len(new_node.keywords)
if old_parameter_num > 1:
msg = 'Parameters are converted.'
else:
if old_parameter_num == 0 and new_parameter_num == 0:
msg = 'The API name is converted to mindspore API'
else:
msg = 'Parameter is converted.'
return msg
def _convert_call(self, node, matched_api_name):
""""Convert the call node."""
new_node = None
code = pasta.dump(node)
api_name = pasta.dump(node.func)
warning_info = get_prompt_info(matched_api_name)
if warning_info is None:
warning_info = ''
if matched_api_name in ALL_MAPPING:
logger.info("Line %3d start converting API: %s", node.lineno, api_name)
new_code = self.mapping_api(node)
if new_code != code:
try:
new_node = pasta.parse(new_code).body[0].value
# find the first call name
new_api_name = new_code[:new_code.find('(')]
detail_msg = self._get_detail_prompt_msg(node, new_node)
if detail_msg:
warning_info = detail_msg + ' ' + warning_info
except AttributeError:
new_node = pasta.parse(new_code).body[0]
new_api_name = new_code
self._process_log.info(node.lineno, node.col_offset,
LOG_FMT_CONVERT_WITH_TIPS % (api_name, new_api_name, warning_info))
else:
logger.warning("Line %3d: found unsupported API: %s%s", node.lineno, api_name, warning_info)
self._process_log.warning(node.lineno, node.col_offset, LOG_FMT_NOT_CONVERT % (api_name, warning_info))
return new_node
def visit_Call(self, node): def visit_Call(self, node):
"""Callback function when visit AST tree""" """Callback function when visit AST tree"""
code = pasta.dump(node) code = pasta.dump(node)
...@@ -688,26 +736,7 @@ class AstEditVisitor(ast.NodeVisitor): ...@@ -688,26 +736,7 @@ class AstEditVisitor(ast.NodeVisitor):
new_code = code new_code = code
matched_api_name, match_case = self.match_api(node.func, self._is_forward_function) matched_api_name, match_case = self.match_api(node.func, self._is_forward_function)
if match_case in [ApiMatchingEnum.API_INFER, ApiMatchingEnum.API_MATCHED]: if match_case in [ApiMatchingEnum.API_INFER, ApiMatchingEnum.API_MATCHED]:
warning_info = get_prompt_info(matched_api_name) new_node = self._convert_call(node, matched_api_name)
if warning_info is None:
warning_info = ''
if matched_api_name in ALL_MAPPING:
logger.info("Line %3d start converting API: %s", node.lineno, api_name)
new_code = self.mapping_api(node)
if new_code != code:
try:
new_node = pasta.parse(new_code).body[0].value
# find the first call name
new_api_name = get_corresponding_ms_name(matched_api_name)
except AttributeError:
new_node = pasta.parse(new_code).body[0]
new_api_name = new_code
self._process_log.info(node.lineno, node.col_offset,
LOG_FMT_CONVERT_WITH_TIPS % (api_name, new_api_name, warning_info))
else:
logger.warning("Line %3d: found unsupported API: %s%s", node.lineno, api_name, warning_info)
self._process_log.warning(node.lineno, node.col_offset, LOG_FMT_NOT_CONVERT % (api_name, warning_info))
elif match_case in [ApiMatchingEnum.API_STANDARD, ApiMatchingEnum.API_FOUND]: elif match_case in [ApiMatchingEnum.API_STANDARD, ApiMatchingEnum.API_FOUND]:
self._process_log.warning(node.lineno, node.col_offset, LOG_FMT_NOT_CONVERT % (api_name, '')) self._process_log.warning(node.lineno, node.col_offset, LOG_FMT_NOT_CONVERT % (api_name, ''))
else: else:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册