diff --git a/mindinsight/mindconverter/ast_edits.py b/mindinsight/mindconverter/ast_edits.py index 6dfa7618e80d695648eeb1b1106821229d065280..36de6e7f5eb388ec333a3da01fbf726f71cf8bee 100644 --- a/mindinsight/mindconverter/ast_edits.py +++ b/mindinsight/mindconverter/ast_edits.py @@ -20,7 +20,6 @@ import re from enum import Enum import pasta -from pasta.augment import import_utils from mindinsight.mindconverter.code_analysis import CodeAnalyzer from mindinsight.mindconverter.code_analysis import APIAnalysisSpec @@ -290,58 +289,93 @@ class AstEditVisitor(ast.NodeVisitor): for convert_fun, args in tasks: convert_fun(*args) - def _convert_external_reference(self): - """Convert import statements.""" - name_replace = APIAnalysisSpec.import_name_mapping - replace_imports = list(name_replace.values()) + @staticmethod + def _dump_without_prefix(node): + """Get the python source for an AST.""" + pos = 0 + source_prefix = pasta.base.formatting.get(node, 'prefix') + if source_prefix: + pos = len(source_prefix) + source_code = pasta.dump(node) + return source_code[pos] + + def _replace_external_reference(self): + """ + Replace external reference statements. + Returns: + dict, key is external name, value is the new replaced node. + """ + all_name_mappings = APIAnalysisSpec.import_name_mapping + names_replaced_with = dict() for ref_info in self._code_analyzer.external_references.values(): external_ref_info = ref_info['external_ref_info'] - parent_node = ref_info['parent_node'] - if parent_node is None: + import_node = ref_info['parent_node'] + if import_node is None: continue - code = pasta.dump(parent_node) + code = self._dump_without_prefix(import_node) + import_parent_node = self._code_analyzer.root_scope.parent(import_node) + # replace import with new name if external_ref_info.name in APIAnalysisSpec.get_convertible_external_names(): external_ref_info = ref_info['external_ref_info'] - if external_ref_info.name in name_replace.keys(): - import_utils.remove_import_alias_node(self._code_analyzer.root_scope, external_ref_info.node) - replace_info = name_replace[external_ref_info.name] - new_ref_name = replace_info[1] - new_external_name = replace_info[0] - if new_ref_name: - new_code = f'import {new_external_name} as {new_ref_name}' - else: - new_code = f'import {new_external_name}' - - self._process_log.info(parent_node.lineno, parent_node.col_offset, LOG_FMT_CONVERT % + if external_ref_info.name in all_name_mappings.keys(): + replace_info = all_name_mappings[external_ref_info.name] + new_node = self._make_import(name_to_import=replace_info[0], as_name=replace_info[1]) + new_code = pasta.dump(new_node) + pasta.ast_utils.replace_child(import_parent_node, import_node, new_node) + names_replaced_with.update({external_ref_info.name: new_node}) + self._process_log.info(import_node.lineno, import_node.col_offset, LOG_FMT_CONVERT % (code.strip(), new_code.strip())) elif external_ref_info.name.startswith('torch.'): - self._process_log.warning(parent_node.lineno, parent_node.col_offset, LOG_FMT_NOT_CONVERT % + self._process_log.warning(import_node.lineno, import_node.col_offset, LOG_FMT_NOT_CONVERT % (code.strip(), LOG_SUGGESTION_MANUAL_CONVERT)) else: pass + return names_replaced_with - # Insert import in reverse order, display in forward order. - for idx in range(len(replace_imports) - 1, -1, -1): - replace_import = replace_imports[idx] - if replace_import[1]: - self._add_import(name_to_import=replace_import[0], as_name=replace_import[1]) + def _convert_external_reference(self): + """Convert import statements.""" + all_name_mappings = APIAnalysisSpec.import_name_mapping + + # Step1. Replace external reference first. + names_replaced_with = self._replace_external_reference() + new_import_node = dict() + insert_pos = 0 + # Step2. Find out remaining mapping name which not found in script. + for src_name, new_import_name in all_name_mappings.items(): + if src_name not in names_replaced_with: + new_node = self._make_import(name_to_import=new_import_name[0], as_name=new_import_name[1]) + new_import_node.update({insert_pos: new_node}) + insert_pos += 1 else: - self._add_import(name_to_import=replace_import[0]) + try: + replaced_with_node = names_replaced_with[src_name] + insert_pos = self._tree.body.index(replaced_with_node) + 1 + except ValueError: + pass + + # Step3. Insert import reference in order. + insert_cnt = 0 + for insert_pos, new_node in new_import_node.items(): + # Insert the node into the module + self._tree.body.insert(insert_pos + insert_cnt, new_node) + insert_cnt += 1 - def _add_import(self, name_to_import, as_name=None): + @staticmethod + def _make_import(name_to_import, as_name=None): """ - Adds an import to the ast tree. + Create an import to the ast tree. Args: name_to_import: (string) The absolute name to import. as_name: (string) The alias for the import ("import name_to_import as asname") + + Returns: + ast.Import, a new ast.Import node. """ new_alias = ast.alias(name=name_to_import, asname=as_name) import_node = ast.Import(names=[new_alias]) - - # Insert the node at the top of the module - self._tree.body.insert(1 if pasta.base.ast_utils.has_docstring(self._tree) else 0, import_node) + return import_node def _convert_function(self, func_scope, is_forward): """