diff --git a/mindinsight/mindconverter/converter.py b/mindinsight/mindconverter/converter.py index 0addae6cd59cc4c64c295072d1952bc0be0bf6ef..575f6a2ed7e0ac53bcad294c5a07a3e587fe9ccd 100644 --- a/mindinsight/mindconverter/converter.py +++ b/mindinsight/mindconverter/converter.py @@ -28,6 +28,8 @@ from mindinsight.mindconverter.config import ALL_UNSUPPORTED from mindinsight.mindconverter.common.log import logger from mindinsight.mindconverter.forward_call import ForwardCall +LINE_NO_INDEX_DIFF = 1 + class Converter: """Convert class""" @@ -197,6 +199,7 @@ class Converter: raise ValueError('"(" not found, {} should work with "("'.format(call_name)) right = self.find_right_parentheses(code, left) end = right + expr = code[start:end + 1] args_str = code[left:right + 1] @@ -336,6 +339,96 @@ class Converter: mapping.update(convert_fun(*args)) return mapping + @staticmethod + def get_code_start_line_num(source_lines): + """ + Get the start code line number exclude comments. + + Args: + source_lines (list[str]): Split results of original code. + + Returns: + int, the start line number. + """ + stack = [] + index = 0 + for i, line in enumerate(source_lines): + if line.strip().startswith('#'): + continue + if line.strip().startswith('"""'): + if not line.endswith('"""\n'): + stack.append('"""') + continue + if line.strip().startswith("'''"): + if not line.endswith("'''\n"): + stack.append("'''") + continue + if line.endswith('"""\n') or line.endswith("'''\n"): + stack.pop() + continue + if line.strip() != '' and not stack: + index = i + break + return index + + def update_code_and_convert_info(self, code, mapping): + """ + Replace code according to mapping, and update convert info. + + Args: + code (str): The code to replace. + mapping (dict): Mapping for original code and the replaced code. + + Returns: + str, the replaced code. + """ + + for key, value in mapping.items(): + code = code.replace(key, value) + + source_lines = code.splitlines(keepends=True) + start_line_number = self.get_code_start_line_num(source_lines) + add_import_infos = ['import mindspore\n', + 'import mindspore.nn as nn\n', + 'import mindspore.ops.operations as P\n'] + for i, add_import_info in enumerate(add_import_infos): + source_lines.insert(start_line_number + i, add_import_info) + self.convert_info += '[Add Import] {}.\n'.format(add_import_info.strip()) + + insert_count = len(add_import_infos) + line_diff = insert_count - LINE_NO_INDEX_DIFF + + for i in range(start_line_number + insert_count, len(source_lines)): + line = source_lines[i] + + if (line.startswith('from torch') and 'import' in line) or line.startswith('import torch'): + new_line = '# ' + line + source_lines[i] = new_line + self.convert_info += '[Annotate][Line{:3d}] {} is annotated.\n'.format(i - line_diff, line.strip()) + if line.strip().startswith('class') and '(nn.Module)' in line: + new_line = line.replace('nn.Module', 'nn.Cell') + source_lines[i] = new_line + self.convert_info += '[Convert][Line{:3d}] nn.Module is converted.\n'.format(i - line_diff) + if line.strip().startswith('def forward('): + new_line = line.replace('forward', 'construct') + source_lines[i] = new_line + self.convert_info += '[Convert][Line{:3d}] forward is converted.\n'.format(i - line_diff) + if 'nn.Linear' in line: + new_line = line.replace('nn.Linear', 'nn.Dense') + source_lines[i] = new_line + self.convert_info += '[Convert][Line{:3d}] nn.Linear is converted.\n'.format(i - line_diff) + if '(nn.Sequential)' in line: + new_line = line.replace('nn.Sequential', 'nn.SequentialCell') + source_lines[i] = new_line + self.convert_info += '[Convert][Line{:3d}] nn.Sequential is converted.\n'.format(i - line_diff) + if 'nn.init.' in line: + new_line = line.replace('nn.init', 'pass # nn.init') + source_lines[i] = new_line + self.convert_info += '[Annotate][Line{:3d}] {} is annotated.\n'.format(i - line_diff, 'nn.init') + + code = ''.join(source_lines) + return code + def convert(self, import_name, output_dir, report_dir): """ Convert a module's code, code converted will be save in output_dir, and a report will be save in report_dir. @@ -346,10 +439,10 @@ class Converter: report_dir (str): The path to save report file. """ logger.info("Start converting %s", import_name) - self.convert_info += '[Start Convert]\nThe module is {}\n'.format(import_name) + start_info = '[Start Convert]\n' + module_info = 'The module is {}.\n'.format(import_name) import_mod = importlib.import_module(import_name) - srcfile = inspect.getsourcefile(import_mod) logger.info("Script file is %s", srcfile) @@ -358,40 +451,14 @@ class Converter: # replace python function under nn.Module mapping = self.get_mapping(import_mod, forward_list) - code = inspect.getsource(import_mod) - for key, value in mapping.items(): - code = code.replace(key, value) - - code = 'import mindspore.ops.operations as P\n' + code - code = 'import mindspore.nn as nn\n' + code - code = 'import mindspore\n' + code - - self.convert_info += '||[Import Add] Add follow import sentences:\n' - self.convert_info += 'import mindspore.ops.operations as P\n' - self.convert_info += 'import mindspore.nn as nn\n' - self.convert_info += 'import mindspore\n\n' - - code = code.replace('import torch', '# import torch') - code = code.replace('from torch', '# from torch') - code = code.replace('(nn.Module):', '(nn.Cell):') - code = code.replace('forward(', 'construct(') - code = code.replace('nn.Linear', 'nn.Dense') - code = code.replace('(nn.Sequential)', '(nn.SequentialCell)') - code = code.replace('nn.init.', 'pass # nn.init.') - - self.convert_info += '||[Import Annotated] Annotated follow import sentences:\n' - self.convert_info += 'import sentence on torch as follows are annotated:\n' - self.convert_info += 'import torch\n' - self.convert_info += 'from torch ...\n' - - self.convert_info += '||[Explicit Convert] Module or function are explicitly converted as follows:\n' - self.convert_info += '[nn.Module] is converted to [nn.Cell]\n' - self.convert_info += '[forward] is converted to [construct]\n' - self.convert_info += '[nn.Linear] is converted to [nn.Dense]\n' - self.convert_info += '[nn.Sequential] is converted to [nn.SequentialCell]\n' - self.convert_info += '[nn.init] is not converted and annotated\n' - self.convert_info += '[Convert over]' + code = self.update_code_and_convert_info(code, mapping) + convert_info_split = self.convert_info.splitlines(keepends=True) + convert_info_split = sorted(convert_info_split) + convert_info_split.insert(0, start_info) + convert_info_split.insert(1, module_info) + convert_info_split.append('[Convert Over]') + self.convert_info = ''.join(convert_info_split) dest_file = os.path.join(output_dir, os.path.basename(srcfile)) with os.fdopen(os.open(dest_file, self.flags, self.modes), 'w') as file: @@ -428,7 +495,6 @@ def _path_split(file): Returns: list[str], list of file tail - """ file_dir, name = os.path.split(file) if file_dir: @@ -456,6 +522,6 @@ def main(files_config): module_name = '.'.join(in_file_split) convert_ins.convert(module_name, files_config['outfile_dir'], files_config['report_dir']) - in_module = files_config['in_module'] + in_module = files_config.get('in_module') if in_module: convert_ins.convert(in_module, files_config['outfile_dir'], files_config['report_dir'])