提交 60ea09c4 编写于 作者: G ggpolar

The conversion report is adjusted to make the the report more reasonable and accurate.

Fix:
1. Inaccurate line number for function name.
2. Duplicate report logs.
3. The code newly inserted must be displayed in the report.
上级 6f5d4f70
......@@ -27,11 +27,13 @@ from mindinsight.mindconverter.config import ALL_MAPPING
from mindinsight.mindconverter.config import NN_LIST
from mindinsight.mindconverter.config import ALL_TORCH_APIS
from mindinsight.mindconverter.config import ALL_2P_LIST
from mindinsight.mindconverter.config import get_corresponding_ms_name
from mindinsight.mindconverter.config import get_prompt_info
from mindinsight.mindconverter.common.log import logger
from mindinsight.mindconverter.common.exceptions import NodeTypeNotSupport
from mindinsight.mindconverter.forward_call import ForwardCall
LOG_FMT_INSERT = "[Insert] '%s' is inserted to the converted file."
LOG_FMT_CONVERT = "[Convert] '%s' is converted to '%s'."
LOG_FMT_CONVERT_WITH_TIPS = "[Convert] '%s' is converted to '%s'. %s"
LOG_FMT_NOT_CONVERT = "[UnConvert] '%s' didn't convert. %s"
......@@ -54,16 +56,22 @@ class _ConvertReport:
def __init__(self, is_stub=False):
self._is_stub = is_stub
self._max_line = 0
self._log = [] # report log, type is (severity, line, col, msg)
self._log_head = []
self._log_body = [] # report log, type is (severity, line, col, msg)
def _add_log(self, severity, line, col, msg):
"""Add log."""
if self._is_stub:
return
if line is None and col is None:
self._log_head.append(msg)
return
if isinstance(line, int) and isinstance(col, int):
self._log.append((severity, line, col, msg))
self._log_body.append((severity, line, col, msg))
if self._max_line < line:
self._max_line = line
else:
raise TypeError('The parameter type is incorrect.')
def info(self, line, col, msg):
"""Interface to add infer log"""
......@@ -73,13 +81,23 @@ class _ConvertReport:
"""Interface to add warning log"""
self._add_log(logging.WARNING, line, col, msg)
def header_msg(self, msg):
"""Interface to add header message log"""
self._add_log(logging.INFO, None, None, msg)
def get_logs(self):
"""Get convert logs"""
logs = []
logs.extend(self._log_head)
# sort rule: line * self._max_line + col
self._log.sort(key=lambda log: log[1] * self._max_line + log[2])
for log_info in self._log:
self._log_body.sort(key=lambda log: log[1] * self._max_line + log[2])
for log_info in self._log_body:
log_info = "line %d:%d: %s" % (log_info[1], log_info[2], log_info[3])
if logs:
# Deduplication for logs
if logs[-1] != log_info:
logs.append(log_info)
else:
logs.append(log_info)
return logs
......@@ -262,7 +280,8 @@ class AstEditVisitor(ast.NodeVisitor):
new_func_name = 'construct'
if func_ast_node.name == old_func_name:
func_ast_node.name = new_func_name
self._process_log.info(func_ast_node.lineno, func_ast_node.col_offset,
real_line_number = self._get_real_line_number(func_ast_node)
self._process_log.info(real_line_number, func_ast_node.col_offset,
LOG_FMT_CONVERT % (old_func_name, new_func_name))
def _convert_api(self):
......@@ -299,6 +318,15 @@ class AstEditVisitor(ast.NodeVisitor):
source_code = pasta.dump(node)
return source_code[pos:]
@staticmethod
def _get_real_line_number(node):
"""Get the real line number of the node."""
try:
line_number = node.lineno + len(node.decorator_list)
except AttributeError:
line_number = node.lineno
return line_number
def _replace_external_reference(self):
"""
Replace external reference statements.
......@@ -349,6 +377,7 @@ class AstEditVisitor(ast.NodeVisitor):
insert_pos += 1
else:
try:
# insert pos after the last one, if last one name is replaced.
replaced_with_node = names_replaced_with[src_name]
insert_pos = self._tree.body.index(replaced_with_node) + 1
except ValueError:
......@@ -359,6 +388,8 @@ class AstEditVisitor(ast.NodeVisitor):
for insert_pos, new_node in new_import_node.items():
# Insert the node into the module
self._tree.body.insert(insert_pos + insert_cnt, new_node)
new_code = self._dump_without_prefix(new_node)
self._process_log.header_msg(LOG_FMT_INSERT % new_code.strip())
insert_cnt += 1
@staticmethod
......@@ -445,8 +476,10 @@ class AstEditVisitor(ast.NodeVisitor):
is_include_sub_call = self._is_include_sub_call(call_func_node)
if is_include_sub_call:
# x.y().z splits to ['x.y()', 'z']
name_attributes = call_name.rsplit('.', 1)
else:
# x.y.z splits to ['x', 'y', 'z']
name_attributes = call_name.split('.')
# rewritten external module name
......@@ -665,7 +698,7 @@ class AstEditVisitor(ast.NodeVisitor):
try:
new_node = pasta.parse(new_code).body[0].value
# find the first call name
new_api_name = new_code[:new_code.find('(')]
new_api_name = get_corresponding_ms_name(matched_api_name)
except AttributeError:
new_node = pasta.parse(new_code).body[0]
new_api_name = new_code
......
......@@ -32,7 +32,7 @@ FUNC_MODULE = 'mindinsight.mindconverter.funcs'
class APIPt:
"""Base API for args parse, and API for one frame."""
def __init__(self, name: str, params: OrderedDict):
def __init__(self, name: str, params: dict):
self.name = name
self.params = OrderedDict()
......@@ -45,7 +45,7 @@ class APIPt:
Trans value to str.
Args:
value (Union[str,Number,int]): Each value for params of OrderedDict.
value (Union[str,Number,int]): The value to convert.
Returns:
str, str type of value.
......@@ -118,7 +118,7 @@ class APIPt:
class APIMs(APIPt):
"""API for MindSpore"""
def __init__(self, name: str, params: OrderedDict, p_attrs=None):
def __init__(self, name: str, params: dict, p_attrs=None):
self.is_primitive = name.startswith('P.')
if self.is_primitive:
self.p_attrs = p_attrs if p_attrs else set()
......@@ -450,7 +450,6 @@ UNSUPPORTED_WARN_INFOS = {
"F.one_hot": "Maybe could convert to mindspore.ops.operations.OneHot.",
"torch.bmm": "Maybe could convert to mindspore.ops.operations.BatchMatMul.",
"torch.cumsum": "Maybe could convert to mindspore.ops.operations.CumSum.",
"F.relu": "Maybe could convert to mindspore.ops.operations.ReLU.",
"F.pad": "Maybe could convert to mindspore.ops.operations.Pad.",
"F.softmax": "Maybe could convert to mindspore.ops.operations.Softmax.",
"torch.clamp": "Maybe could convert to mindspore.ops.composite.clip_by_value.",
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册