From 60ea09c449e7363db7f9eb386c7b84791987974d Mon Sep 17 00:00:00 2001 From: ggpolar Date: Wed, 24 Jun 2020 11:57:55 +0800 Subject: [PATCH] The conversion report is adjusted to make the the report more reasonable and accurate. Fix: 1. Inaccurate line number for function name. 2. Duplicate report logs. 3. The code newly inserted must be displayed in the report. --- mindinsight/mindconverter/ast_edits.py | 47 ++++++++++++++++++++++---- mindinsight/mindconverter/config.py | 7 ++-- 2 files changed, 43 insertions(+), 11 deletions(-) diff --git a/mindinsight/mindconverter/ast_edits.py b/mindinsight/mindconverter/ast_edits.py index 4e2e3e3..0a71521 100644 --- a/mindinsight/mindconverter/ast_edits.py +++ b/mindinsight/mindconverter/ast_edits.py @@ -27,11 +27,13 @@ from mindinsight.mindconverter.config import ALL_MAPPING from mindinsight.mindconverter.config import NN_LIST from mindinsight.mindconverter.config import ALL_TORCH_APIS from mindinsight.mindconverter.config import ALL_2P_LIST +from mindinsight.mindconverter.config import get_corresponding_ms_name from mindinsight.mindconverter.config import get_prompt_info from mindinsight.mindconverter.common.log import logger from mindinsight.mindconverter.common.exceptions import NodeTypeNotSupport from mindinsight.mindconverter.forward_call import ForwardCall +LOG_FMT_INSERT = "[Insert] '%s' is inserted to the converted file." LOG_FMT_CONVERT = "[Convert] '%s' is converted to '%s'." LOG_FMT_CONVERT_WITH_TIPS = "[Convert] '%s' is converted to '%s'. %s" LOG_FMT_NOT_CONVERT = "[UnConvert] '%s' didn't convert. %s" @@ -54,16 +56,22 @@ class _ConvertReport: def __init__(self, is_stub=False): self._is_stub = is_stub self._max_line = 0 - self._log = [] # report log, type is (severity, line, col, msg) + self._log_head = [] + self._log_body = [] # report log, type is (severity, line, col, msg) def _add_log(self, severity, line, col, msg): """Add log.""" if self._is_stub: return + if line is None and col is None: + self._log_head.append(msg) + return if isinstance(line, int) and isinstance(col, int): - self._log.append((severity, line, col, msg)) + self._log_body.append((severity, line, col, msg)) if self._max_line < line: self._max_line = line + else: + raise TypeError('The parameter type is incorrect.') def info(self, line, col, msg): """Interface to add infer log""" @@ -73,14 +81,24 @@ class _ConvertReport: """Interface to add warning log""" self._add_log(logging.WARNING, line, col, msg) + def header_msg(self, msg): + """Interface to add header message log""" + self._add_log(logging.INFO, None, None, msg) + def get_logs(self): """Get convert logs""" logs = [] + logs.extend(self._log_head) # sort rule: line * self._max_line + col - self._log.sort(key=lambda log: log[1] * self._max_line + log[2]) - for log_info in self._log: + self._log_body.sort(key=lambda log: log[1] * self._max_line + log[2]) + for log_info in self._log_body: log_info = "line %d:%d: %s" % (log_info[1], log_info[2], log_info[3]) - logs.append(log_info) + if logs: + # Deduplication for logs + if logs[-1] != log_info: + logs.append(log_info) + else: + logs.append(log_info) return logs @@ -262,7 +280,8 @@ class AstEditVisitor(ast.NodeVisitor): new_func_name = 'construct' if func_ast_node.name == old_func_name: func_ast_node.name = new_func_name - self._process_log.info(func_ast_node.lineno, func_ast_node.col_offset, + real_line_number = self._get_real_line_number(func_ast_node) + self._process_log.info(real_line_number, func_ast_node.col_offset, LOG_FMT_CONVERT % (old_func_name, new_func_name)) def _convert_api(self): @@ -299,6 +318,15 @@ class AstEditVisitor(ast.NodeVisitor): source_code = pasta.dump(node) return source_code[pos:] + @staticmethod + def _get_real_line_number(node): + """Get the real line number of the node.""" + try: + line_number = node.lineno + len(node.decorator_list) + except AttributeError: + line_number = node.lineno + return line_number + def _replace_external_reference(self): """ Replace external reference statements. @@ -349,6 +377,7 @@ class AstEditVisitor(ast.NodeVisitor): insert_pos += 1 else: try: + # insert pos after the last one, if last one name is replaced. replaced_with_node = names_replaced_with[src_name] insert_pos = self._tree.body.index(replaced_with_node) + 1 except ValueError: @@ -359,6 +388,8 @@ class AstEditVisitor(ast.NodeVisitor): for insert_pos, new_node in new_import_node.items(): # Insert the node into the module self._tree.body.insert(insert_pos + insert_cnt, new_node) + new_code = self._dump_without_prefix(new_node) + self._process_log.header_msg(LOG_FMT_INSERT % new_code.strip()) insert_cnt += 1 @staticmethod @@ -445,8 +476,10 @@ class AstEditVisitor(ast.NodeVisitor): is_include_sub_call = self._is_include_sub_call(call_func_node) if is_include_sub_call: + # x.y().z splits to ['x.y()', 'z'] name_attributes = call_name.rsplit('.', 1) else: + # x.y.z splits to ['x', 'y', 'z'] name_attributes = call_name.split('.') # rewritten external module name @@ -665,7 +698,7 @@ class AstEditVisitor(ast.NodeVisitor): try: new_node = pasta.parse(new_code).body[0].value # find the first call name - new_api_name = new_code[:new_code.find('(')] + new_api_name = get_corresponding_ms_name(matched_api_name) except AttributeError: new_node = pasta.parse(new_code).body[0] new_api_name = new_code diff --git a/mindinsight/mindconverter/config.py b/mindinsight/mindconverter/config.py index 807561b..3ace06c 100644 --- a/mindinsight/mindconverter/config.py +++ b/mindinsight/mindconverter/config.py @@ -32,7 +32,7 @@ FUNC_MODULE = 'mindinsight.mindconverter.funcs' class APIPt: """Base API for args parse, and API for one frame.""" - def __init__(self, name: str, params: OrderedDict): + def __init__(self, name: str, params: dict): self.name = name self.params = OrderedDict() @@ -45,7 +45,7 @@ class APIPt: Trans value to str. Args: - value (Union[str,Number,int]): Each value for params of OrderedDict. + value (Union[str,Number,int]): The value to convert. Returns: str, str type of value. @@ -118,7 +118,7 @@ class APIPt: class APIMs(APIPt): """API for MindSpore""" - def __init__(self, name: str, params: OrderedDict, p_attrs=None): + def __init__(self, name: str, params: dict, p_attrs=None): self.is_primitive = name.startswith('P.') if self.is_primitive: self.p_attrs = p_attrs if p_attrs else set() @@ -450,7 +450,6 @@ UNSUPPORTED_WARN_INFOS = { "F.one_hot": "Maybe could convert to mindspore.ops.operations.OneHot.", "torch.bmm": "Maybe could convert to mindspore.ops.operations.BatchMatMul.", "torch.cumsum": "Maybe could convert to mindspore.ops.operations.CumSum.", - "F.relu": "Maybe could convert to mindspore.ops.operations.ReLU.", "F.pad": "Maybe could convert to mindspore.ops.operations.Pad.", "F.softmax": "Maybe could convert to mindspore.ops.operations.Softmax.", "torch.clamp": "Maybe could convert to mindspore.ops.composite.clip_by_value.", -- GitLab