提交 5c5ea91a 编写于 作者: G ggpolar

Fix issue that the annotate maybe lost when delete the import statement.

The solution of replacing import nodes not to delete and then add, can be retain the annotate.
上级 03a32ebc
......@@ -20,7 +20,6 @@ import re
from enum import Enum
import pasta
from pasta.augment import import_utils
from mindinsight.mindconverter.code_analysis import CodeAnalyzer
from mindinsight.mindconverter.code_analysis import APIAnalysisSpec
......@@ -290,58 +289,93 @@ class AstEditVisitor(ast.NodeVisitor):
for convert_fun, args in tasks:
convert_fun(*args)
def _convert_external_reference(self):
"""Convert import statements."""
name_replace = APIAnalysisSpec.import_name_mapping
replace_imports = list(name_replace.values())
@staticmethod
def _dump_without_prefix(node):
"""Get the python source for an AST."""
pos = 0
source_prefix = pasta.base.formatting.get(node, 'prefix')
if source_prefix:
pos = len(source_prefix)
source_code = pasta.dump(node)
return source_code[pos]
def _replace_external_reference(self):
"""
Replace external reference statements.
Returns:
dict, key is external name, value is the new replaced node.
"""
all_name_mappings = APIAnalysisSpec.import_name_mapping
names_replaced_with = dict()
for ref_info in self._code_analyzer.external_references.values():
external_ref_info = ref_info['external_ref_info']
parent_node = ref_info['parent_node']
if parent_node is None:
import_node = ref_info['parent_node']
if import_node is None:
continue
code = pasta.dump(parent_node)
code = self._dump_without_prefix(import_node)
import_parent_node = self._code_analyzer.root_scope.parent(import_node)
# replace import with new name
if external_ref_info.name in APIAnalysisSpec.get_convertible_external_names():
external_ref_info = ref_info['external_ref_info']
if external_ref_info.name in name_replace.keys():
import_utils.remove_import_alias_node(self._code_analyzer.root_scope, external_ref_info.node)
replace_info = name_replace[external_ref_info.name]
new_ref_name = replace_info[1]
new_external_name = replace_info[0]
if new_ref_name:
new_code = f'import {new_external_name} as {new_ref_name}'
else:
new_code = f'import {new_external_name}'
self._process_log.info(parent_node.lineno, parent_node.col_offset, LOG_FMT_CONVERT %
if external_ref_info.name in all_name_mappings.keys():
replace_info = all_name_mappings[external_ref_info.name]
new_node = self._make_import(name_to_import=replace_info[0], as_name=replace_info[1])
new_code = pasta.dump(new_node)
pasta.ast_utils.replace_child(import_parent_node, import_node, new_node)
names_replaced_with.update({external_ref_info.name: new_node})
self._process_log.info(import_node.lineno, import_node.col_offset, LOG_FMT_CONVERT %
(code.strip(), new_code.strip()))
elif external_ref_info.name.startswith('torch.'):
self._process_log.warning(parent_node.lineno, parent_node.col_offset, LOG_FMT_NOT_CONVERT %
self._process_log.warning(import_node.lineno, import_node.col_offset, LOG_FMT_NOT_CONVERT %
(code.strip(), LOG_SUGGESTION_MANUAL_CONVERT))
else:
pass
return names_replaced_with
# Insert import in reverse order, display in forward order.
for idx in range(len(replace_imports) - 1, -1, -1):
replace_import = replace_imports[idx]
if replace_import[1]:
self._add_import(name_to_import=replace_import[0], as_name=replace_import[1])
def _convert_external_reference(self):
"""Convert import statements."""
all_name_mappings = APIAnalysisSpec.import_name_mapping
# Step1. Replace external reference first.
names_replaced_with = self._replace_external_reference()
new_import_node = dict()
insert_pos = 0
# Step2. Find out remaining mapping name which not found in script.
for src_name, new_import_name in all_name_mappings.items():
if src_name not in names_replaced_with:
new_node = self._make_import(name_to_import=new_import_name[0], as_name=new_import_name[1])
new_import_node.update({insert_pos: new_node})
insert_pos += 1
else:
self._add_import(name_to_import=replace_import[0])
try:
replaced_with_node = names_replaced_with[src_name]
insert_pos = self._tree.body.index(replaced_with_node) + 1
except ValueError:
pass
# Step3. Insert import reference in order.
insert_cnt = 0
for insert_pos, new_node in new_import_node.items():
# Insert the node into the module
self._tree.body.insert(insert_pos + insert_cnt, new_node)
insert_cnt += 1
def _add_import(self, name_to_import, as_name=None):
@staticmethod
def _make_import(name_to_import, as_name=None):
"""
Adds an import to the ast tree.
Create an import to the ast tree.
Args:
name_to_import: (string) The absolute name to import.
as_name: (string) The alias for the import ("import name_to_import as asname")
Returns:
ast.Import, a new ast.Import node.
"""
new_alias = ast.alias(name=name_to_import, asname=as_name)
import_node = ast.Import(names=[new_alias])
# Insert the node at the top of the module
self._tree.body.insert(1 if pasta.base.ast_utils.has_docstring(self._tree) else 0, import_node)
return import_node
def _convert_function(self, func_scope, is_forward):
"""
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册