diff --git a/mindinsight/mindconverter/converter.py b/mindinsight/mindconverter/converter.py index 43da14bb25f9f0bb0c07cfc2fd6adac708b967be..09a550f4f5ffae25af0403f5371e83218c433379 100644 --- a/mindinsight/mindconverter/converter.py +++ b/mindinsight/mindconverter/converter.py @@ -64,7 +64,9 @@ class Converter: self._report.append('[Convert Over]') dest_file = os.path.join(output_dir, os.path.basename(infile)) with os.fdopen(os.open(dest_file, self.flags, self.modes), 'w') as file: - file.write(pasta.dump(self._tree)) + script = pasta.dump(self._tree) + script = adjust_mindspore_import_position(script) + file.write(script) logger.info("Convert success. Result is wrote to %s.", dest_file) except ScriptNotSupport as error: self._report.append('[ScriptNotSupport] ' + error.message) @@ -97,6 +99,61 @@ class Converter: return replaced_code +def get_code_start_line_num(source_lines): + """ + Get the start code line number exclude comments. + + Args: + source_lines (list[str]): Split results of code. + + Returns: + int, the start line number. + """ + stack = [] + index = 0 + for i, line in enumerate(source_lines): + line_strip = line.strip() + if line_strip.startswith('#'): + continue + if line_strip.startswith('"""'): + if not line_strip.endswith('"""'): + stack.append('"""') + continue + if line_strip.startswith("'''"): + if not line_strip.endswith("'''"): + stack.append("'''") + continue + if line_strip.endswith('"""') or line_strip.endswith("'''"): + stack.pop() + continue + if line_strip != '' and not stack: + index = i + break + return index + + +def adjust_mindspore_import_position(script): + """ + Adjust code sentence `import mindspore` in script to a proper position if the sentence is set before a comment. + + Args: + script (str): code script before adjust. + + Returns: + str, code script adjusted. + """ + script_list = script.split('\n') + import_ms_sentence = 'import mindspore' + if import_ms_sentence in script_list: + import_index = script_list.index(import_ms_sentence) + if script_list[import_index + 1].startswith('"""') or script_list[import_index + 1].startswith("'''"): + script_list.pop(import_index) + new_index = get_code_start_line_num(script_list) + script_list.insert(new_index, import_ms_sentence) + script = '\n'.join(script_list) + return script + + def _get_name_ext(file): """ Split a file name in name and extension.