提交 60ea09c4 编写于 作者: G ggpolar

The conversion report is adjusted to make the the report more reasonable and accurate.

Fix:
1. Inaccurate line number for function name.
2. Duplicate report logs.
3. The code newly inserted must be displayed in the report.
上级 6f5d4f70
...@@ -27,11 +27,13 @@ from mindinsight.mindconverter.config import ALL_MAPPING ...@@ -27,11 +27,13 @@ from mindinsight.mindconverter.config import ALL_MAPPING
from mindinsight.mindconverter.config import NN_LIST from mindinsight.mindconverter.config import NN_LIST
from mindinsight.mindconverter.config import ALL_TORCH_APIS from mindinsight.mindconverter.config import ALL_TORCH_APIS
from mindinsight.mindconverter.config import ALL_2P_LIST from mindinsight.mindconverter.config import ALL_2P_LIST
from mindinsight.mindconverter.config import get_corresponding_ms_name
from mindinsight.mindconverter.config import get_prompt_info from mindinsight.mindconverter.config import get_prompt_info
from mindinsight.mindconverter.common.log import logger from mindinsight.mindconverter.common.log import logger
from mindinsight.mindconverter.common.exceptions import NodeTypeNotSupport from mindinsight.mindconverter.common.exceptions import NodeTypeNotSupport
from mindinsight.mindconverter.forward_call import ForwardCall from mindinsight.mindconverter.forward_call import ForwardCall
LOG_FMT_INSERT = "[Insert] '%s' is inserted to the converted file."
LOG_FMT_CONVERT = "[Convert] '%s' is converted to '%s'." LOG_FMT_CONVERT = "[Convert] '%s' is converted to '%s'."
LOG_FMT_CONVERT_WITH_TIPS = "[Convert] '%s' is converted to '%s'. %s" LOG_FMT_CONVERT_WITH_TIPS = "[Convert] '%s' is converted to '%s'. %s"
LOG_FMT_NOT_CONVERT = "[UnConvert] '%s' didn't convert. %s" LOG_FMT_NOT_CONVERT = "[UnConvert] '%s' didn't convert. %s"
...@@ -54,16 +56,22 @@ class _ConvertReport: ...@@ -54,16 +56,22 @@ class _ConvertReport:
def __init__(self, is_stub=False): def __init__(self, is_stub=False):
self._is_stub = is_stub self._is_stub = is_stub
self._max_line = 0 self._max_line = 0
self._log = [] # report log, type is (severity, line, col, msg) self._log_head = []
self._log_body = [] # report log, type is (severity, line, col, msg)
def _add_log(self, severity, line, col, msg): def _add_log(self, severity, line, col, msg):
"""Add log.""" """Add log."""
if self._is_stub: if self._is_stub:
return return
if line is None and col is None:
self._log_head.append(msg)
return
if isinstance(line, int) and isinstance(col, int): if isinstance(line, int) and isinstance(col, int):
self._log.append((severity, line, col, msg)) self._log_body.append((severity, line, col, msg))
if self._max_line < line: if self._max_line < line:
self._max_line = line self._max_line = line
else:
raise TypeError('The parameter type is incorrect.')
def info(self, line, col, msg): def info(self, line, col, msg):
"""Interface to add infer log""" """Interface to add infer log"""
...@@ -73,14 +81,24 @@ class _ConvertReport: ...@@ -73,14 +81,24 @@ class _ConvertReport:
"""Interface to add warning log""" """Interface to add warning log"""
self._add_log(logging.WARNING, line, col, msg) self._add_log(logging.WARNING, line, col, msg)
def header_msg(self, msg):
"""Interface to add header message log"""
self._add_log(logging.INFO, None, None, msg)
def get_logs(self): def get_logs(self):
"""Get convert logs""" """Get convert logs"""
logs = [] logs = []
logs.extend(self._log_head)
# sort rule: line * self._max_line + col # sort rule: line * self._max_line + col
self._log.sort(key=lambda log: log[1] * self._max_line + log[2]) self._log_body.sort(key=lambda log: log[1] * self._max_line + log[2])
for log_info in self._log: for log_info in self._log_body:
log_info = "line %d:%d: %s" % (log_info[1], log_info[2], log_info[3]) log_info = "line %d:%d: %s" % (log_info[1], log_info[2], log_info[3])
logs.append(log_info) if logs:
# Deduplication for logs
if logs[-1] != log_info:
logs.append(log_info)
else:
logs.append(log_info)
return logs return logs
...@@ -262,7 +280,8 @@ class AstEditVisitor(ast.NodeVisitor): ...@@ -262,7 +280,8 @@ class AstEditVisitor(ast.NodeVisitor):
new_func_name = 'construct' new_func_name = 'construct'
if func_ast_node.name == old_func_name: if func_ast_node.name == old_func_name:
func_ast_node.name = new_func_name func_ast_node.name = new_func_name
self._process_log.info(func_ast_node.lineno, func_ast_node.col_offset, real_line_number = self._get_real_line_number(func_ast_node)
self._process_log.info(real_line_number, func_ast_node.col_offset,
LOG_FMT_CONVERT % (old_func_name, new_func_name)) LOG_FMT_CONVERT % (old_func_name, new_func_name))
def _convert_api(self): def _convert_api(self):
...@@ -299,6 +318,15 @@ class AstEditVisitor(ast.NodeVisitor): ...@@ -299,6 +318,15 @@ class AstEditVisitor(ast.NodeVisitor):
source_code = pasta.dump(node) source_code = pasta.dump(node)
return source_code[pos:] return source_code[pos:]
@staticmethod
def _get_real_line_number(node):
"""Get the real line number of the node."""
try:
line_number = node.lineno + len(node.decorator_list)
except AttributeError:
line_number = node.lineno
return line_number
def _replace_external_reference(self): def _replace_external_reference(self):
""" """
Replace external reference statements. Replace external reference statements.
...@@ -349,6 +377,7 @@ class AstEditVisitor(ast.NodeVisitor): ...@@ -349,6 +377,7 @@ class AstEditVisitor(ast.NodeVisitor):
insert_pos += 1 insert_pos += 1
else: else:
try: try:
# insert pos after the last one, if last one name is replaced.
replaced_with_node = names_replaced_with[src_name] replaced_with_node = names_replaced_with[src_name]
insert_pos = self._tree.body.index(replaced_with_node) + 1 insert_pos = self._tree.body.index(replaced_with_node) + 1
except ValueError: except ValueError:
...@@ -359,6 +388,8 @@ class AstEditVisitor(ast.NodeVisitor): ...@@ -359,6 +388,8 @@ class AstEditVisitor(ast.NodeVisitor):
for insert_pos, new_node in new_import_node.items(): for insert_pos, new_node in new_import_node.items():
# Insert the node into the module # Insert the node into the module
self._tree.body.insert(insert_pos + insert_cnt, new_node) self._tree.body.insert(insert_pos + insert_cnt, new_node)
new_code = self._dump_without_prefix(new_node)
self._process_log.header_msg(LOG_FMT_INSERT % new_code.strip())
insert_cnt += 1 insert_cnt += 1
@staticmethod @staticmethod
...@@ -445,8 +476,10 @@ class AstEditVisitor(ast.NodeVisitor): ...@@ -445,8 +476,10 @@ class AstEditVisitor(ast.NodeVisitor):
is_include_sub_call = self._is_include_sub_call(call_func_node) is_include_sub_call = self._is_include_sub_call(call_func_node)
if is_include_sub_call: if is_include_sub_call:
# x.y().z splits to ['x.y()', 'z']
name_attributes = call_name.rsplit('.', 1) name_attributes = call_name.rsplit('.', 1)
else: else:
# x.y.z splits to ['x', 'y', 'z']
name_attributes = call_name.split('.') name_attributes = call_name.split('.')
# rewritten external module name # rewritten external module name
...@@ -665,7 +698,7 @@ class AstEditVisitor(ast.NodeVisitor): ...@@ -665,7 +698,7 @@ class AstEditVisitor(ast.NodeVisitor):
try: try:
new_node = pasta.parse(new_code).body[0].value new_node = pasta.parse(new_code).body[0].value
# find the first call name # find the first call name
new_api_name = new_code[:new_code.find('(')] new_api_name = get_corresponding_ms_name(matched_api_name)
except AttributeError: except AttributeError:
new_node = pasta.parse(new_code).body[0] new_node = pasta.parse(new_code).body[0]
new_api_name = new_code new_api_name = new_code
......
...@@ -32,7 +32,7 @@ FUNC_MODULE = 'mindinsight.mindconverter.funcs' ...@@ -32,7 +32,7 @@ FUNC_MODULE = 'mindinsight.mindconverter.funcs'
class APIPt: class APIPt:
"""Base API for args parse, and API for one frame.""" """Base API for args parse, and API for one frame."""
def __init__(self, name: str, params: OrderedDict): def __init__(self, name: str, params: dict):
self.name = name self.name = name
self.params = OrderedDict() self.params = OrderedDict()
...@@ -45,7 +45,7 @@ class APIPt: ...@@ -45,7 +45,7 @@ class APIPt:
Trans value to str. Trans value to str.
Args: Args:
value (Union[str,Number,int]): Each value for params of OrderedDict. value (Union[str,Number,int]): The value to convert.
Returns: Returns:
str, str type of value. str, str type of value.
...@@ -118,7 +118,7 @@ class APIPt: ...@@ -118,7 +118,7 @@ class APIPt:
class APIMs(APIPt): class APIMs(APIPt):
"""API for MindSpore""" """API for MindSpore"""
def __init__(self, name: str, params: OrderedDict, p_attrs=None): def __init__(self, name: str, params: dict, p_attrs=None):
self.is_primitive = name.startswith('P.') self.is_primitive = name.startswith('P.')
if self.is_primitive: if self.is_primitive:
self.p_attrs = p_attrs if p_attrs else set() self.p_attrs = p_attrs if p_attrs else set()
...@@ -450,7 +450,6 @@ UNSUPPORTED_WARN_INFOS = { ...@@ -450,7 +450,6 @@ UNSUPPORTED_WARN_INFOS = {
"F.one_hot": "Maybe could convert to mindspore.ops.operations.OneHot.", "F.one_hot": "Maybe could convert to mindspore.ops.operations.OneHot.",
"torch.bmm": "Maybe could convert to mindspore.ops.operations.BatchMatMul.", "torch.bmm": "Maybe could convert to mindspore.ops.operations.BatchMatMul.",
"torch.cumsum": "Maybe could convert to mindspore.ops.operations.CumSum.", "torch.cumsum": "Maybe could convert to mindspore.ops.operations.CumSum.",
"F.relu": "Maybe could convert to mindspore.ops.operations.ReLU.",
"F.pad": "Maybe could convert to mindspore.ops.operations.Pad.", "F.pad": "Maybe could convert to mindspore.ops.operations.Pad.",
"F.softmax": "Maybe could convert to mindspore.ops.operations.Softmax.", "F.softmax": "Maybe could convert to mindspore.ops.operations.Softmax.",
"torch.clamp": "Maybe could convert to mindspore.ops.composite.clip_by_value.", "torch.clamp": "Maybe could convert to mindspore.ops.composite.clip_by_value.",
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册