提交 97ca18d2 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!328 The annotate maybe lost when delete the import statement.

Merge pull request !328 from ggpolar/br_wzk_dev
......@@ -20,7 +20,6 @@ import re
from enum import Enum
import pasta
from pasta.augment import import_utils
from mindinsight.mindconverter.code_analysis import CodeAnalyzer
from mindinsight.mindconverter.code_analysis import APIAnalysisSpec
......@@ -290,58 +289,93 @@ class AstEditVisitor(ast.NodeVisitor):
for convert_fun, args in tasks:
convert_fun(*args)
def _convert_external_reference(self):
"""Convert import statements."""
name_replace = APIAnalysisSpec.import_name_mapping
replace_imports = list(name_replace.values())
@staticmethod
def _dump_without_prefix(node):
"""Get the python source for an AST."""
pos = 0
source_prefix = pasta.base.formatting.get(node, 'prefix')
if source_prefix:
pos = len(source_prefix)
source_code = pasta.dump(node)
return source_code[pos]
def _replace_external_reference(self):
"""
Replace external reference statements.
Returns:
dict, key is external name, value is the new replaced node.
"""
all_name_mappings = APIAnalysisSpec.import_name_mapping
names_replaced_with = dict()
for ref_info in self._code_analyzer.external_references.values():
external_ref_info = ref_info['external_ref_info']
parent_node = ref_info['parent_node']
if parent_node is None:
import_node = ref_info['parent_node']
if import_node is None:
continue
code = pasta.dump(parent_node)
code = self._dump_without_prefix(import_node)
import_parent_node = self._code_analyzer.root_scope.parent(import_node)
# replace import with new name
if external_ref_info.name in APIAnalysisSpec.get_convertible_external_names():
external_ref_info = ref_info['external_ref_info']
if external_ref_info.name in name_replace.keys():
import_utils.remove_import_alias_node(self._code_analyzer.root_scope, external_ref_info.node)
replace_info = name_replace[external_ref_info.name]
new_ref_name = replace_info[1]
new_external_name = replace_info[0]
if new_ref_name:
new_code = f'import {new_external_name} as {new_ref_name}'
else:
new_code = f'import {new_external_name}'
self._process_log.info(parent_node.lineno, parent_node.col_offset, LOG_FMT_CONVERT %
if external_ref_info.name in all_name_mappings.keys():
replace_info = all_name_mappings[external_ref_info.name]
new_node = self._make_import(name_to_import=replace_info[0], as_name=replace_info[1])
new_code = pasta.dump(new_node)
pasta.ast_utils.replace_child(import_parent_node, import_node, new_node)
names_replaced_with.update({external_ref_info.name: new_node})
self._process_log.info(import_node.lineno, import_node.col_offset, LOG_FMT_CONVERT %
(code.strip(), new_code.strip()))
elif external_ref_info.name.startswith('torch.'):
self._process_log.warning(parent_node.lineno, parent_node.col_offset, LOG_FMT_NOT_CONVERT %
self._process_log.warning(import_node.lineno, import_node.col_offset, LOG_FMT_NOT_CONVERT %
(code.strip(), LOG_SUGGESTION_MANUAL_CONVERT))
else:
pass
return names_replaced_with
# Insert import in reverse order, display in forward order.
for idx in range(len(replace_imports) - 1, -1, -1):
replace_import = replace_imports[idx]
if replace_import[1]:
self._add_import(name_to_import=replace_import[0], as_name=replace_import[1])
def _convert_external_reference(self):
"""Convert import statements."""
all_name_mappings = APIAnalysisSpec.import_name_mapping
# Step1. Replace external reference first.
names_replaced_with = self._replace_external_reference()
new_import_node = dict()
insert_pos = 0
# Step2. Find out remaining mapping name which not found in script.
for src_name, new_import_name in all_name_mappings.items():
if src_name not in names_replaced_with:
new_node = self._make_import(name_to_import=new_import_name[0], as_name=new_import_name[1])
new_import_node.update({insert_pos: new_node})
insert_pos += 1
else:
self._add_import(name_to_import=replace_import[0])
try:
replaced_with_node = names_replaced_with[src_name]
insert_pos = self._tree.body.index(replaced_with_node) + 1
except ValueError:
pass
# Step3. Insert import reference in order.
insert_cnt = 0
for insert_pos, new_node in new_import_node.items():
# Insert the node into the module
self._tree.body.insert(insert_pos + insert_cnt, new_node)
insert_cnt += 1
def _add_import(self, name_to_import, as_name=None):
@staticmethod
def _make_import(name_to_import, as_name=None):
"""
Adds an import to the ast tree.
Create an import to the ast tree.
Args:
name_to_import: (string) The absolute name to import.
as_name: (string) The alias for the import ("import name_to_import as asname")
Returns:
ast.Import, a new ast.Import node.
"""
new_alias = ast.alias(name=name_to_import, asname=as_name)
import_node = ast.Import(names=[new_alias])
# Insert the node at the top of the module
self._tree.body.insert(1 if pasta.base.ast_utils.has_docstring(self._tree) else 0, import_node)
return import_node
def _convert_function(self, func_scope, is_forward):
"""
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册