From cc64d2eac86f8dc1c14af9da33d71f394e07b6e0 Mon Sep 17 00:00:00 2001 From: quyongxiu1 Date: Wed, 27 May 2020 22:05:00 +0800 Subject: [PATCH] report do not contain if not convert add for run test add for test use another script line endswith \n continue sorted convert info annotated import startswith add info in list func ok report more clear and pylint fix delete devil figure comment format more legal use strip to define start info of line --- mindinsight/mindconverter/converter.py | 149 +++++++++++++++++-------- 1 file changed, 102 insertions(+), 47 deletions(-) diff --git a/mindinsight/mindconverter/converter.py b/mindinsight/mindconverter/converter.py index 606fd8b..575f6a2 100644 --- a/mindinsight/mindconverter/converter.py +++ b/mindinsight/mindconverter/converter.py @@ -17,7 +17,6 @@ import copy import importlib import inspect import os -import re import stat from mindinsight.mindconverter.config import ALL_MAPPING @@ -29,6 +28,8 @@ from mindinsight.mindconverter.config import ALL_UNSUPPORTED from mindinsight.mindconverter.common.log import logger from mindinsight.mindconverter.forward_call import ForwardCall +LINE_NO_INDEX_DIFF = 1 + class Converter: """Convert class""" @@ -198,6 +199,7 @@ class Converter: raise ValueError('"(" not found, {} should work with "("'.format(call_name)) right = self.find_right_parentheses(code, left) end = right + expr = code[start:end + 1] args_str = code[left:right + 1] @@ -337,6 +339,96 @@ class Converter: mapping.update(convert_fun(*args)) return mapping + @staticmethod + def get_code_start_line_num(source_lines): + """ + Get the start code line number exclude comments. + + Args: + source_lines (list[str]): Split results of original code. + + Returns: + int, the start line number. + """ + stack = [] + index = 0 + for i, line in enumerate(source_lines): + if line.strip().startswith('#'): + continue + if line.strip().startswith('"""'): + if not line.endswith('"""\n'): + stack.append('"""') + continue + if line.strip().startswith("'''"): + if not line.endswith("'''\n"): + stack.append("'''") + continue + if line.endswith('"""\n') or line.endswith("'''\n"): + stack.pop() + continue + if line.strip() != '' and not stack: + index = i + break + return index + + def update_code_and_convert_info(self, code, mapping): + """ + Replace code according to mapping, and update convert info. + + Args: + code (str): The code to replace. + mapping (dict): Mapping for original code and the replaced code. + + Returns: + str, the replaced code. + """ + + for key, value in mapping.items(): + code = code.replace(key, value) + + source_lines = code.splitlines(keepends=True) + start_line_number = self.get_code_start_line_num(source_lines) + add_import_infos = ['import mindspore\n', + 'import mindspore.nn as nn\n', + 'import mindspore.ops.operations as P\n'] + for i, add_import_info in enumerate(add_import_infos): + source_lines.insert(start_line_number + i, add_import_info) + self.convert_info += '[Add Import] {}.\n'.format(add_import_info.strip()) + + insert_count = len(add_import_infos) + line_diff = insert_count - LINE_NO_INDEX_DIFF + + for i in range(start_line_number + insert_count, len(source_lines)): + line = source_lines[i] + + if (line.startswith('from torch') and 'import' in line) or line.startswith('import torch'): + new_line = '# ' + line + source_lines[i] = new_line + self.convert_info += '[Annotate][Line{:3d}] {} is annotated.\n'.format(i - line_diff, line.strip()) + if line.strip().startswith('class') and '(nn.Module)' in line: + new_line = line.replace('nn.Module', 'nn.Cell') + source_lines[i] = new_line + self.convert_info += '[Convert][Line{:3d}] nn.Module is converted.\n'.format(i - line_diff) + if line.strip().startswith('def forward('): + new_line = line.replace('forward', 'construct') + source_lines[i] = new_line + self.convert_info += '[Convert][Line{:3d}] forward is converted.\n'.format(i - line_diff) + if 'nn.Linear' in line: + new_line = line.replace('nn.Linear', 'nn.Dense') + source_lines[i] = new_line + self.convert_info += '[Convert][Line{:3d}] nn.Linear is converted.\n'.format(i - line_diff) + if '(nn.Sequential)' in line: + new_line = line.replace('nn.Sequential', 'nn.SequentialCell') + source_lines[i] = new_line + self.convert_info += '[Convert][Line{:3d}] nn.Sequential is converted.\n'.format(i - line_diff) + if 'nn.init.' in line: + new_line = line.replace('nn.init', 'pass # nn.init') + source_lines[i] = new_line + self.convert_info += '[Annotate][Line{:3d}] {} is annotated.\n'.format(i - line_diff, 'nn.init') + + code = ''.join(source_lines) + return code + def convert(self, import_name, output_dir, report_dir): """ Convert a module's code, code converted will be save in output_dir, and a report will be save in report_dir. @@ -347,10 +439,10 @@ class Converter: report_dir (str): The path to save report file. """ logger.info("Start converting %s", import_name) - self.convert_info += '[Start Convert]\nThe module is {}\n'.format(import_name) + start_info = '[Start Convert]\n' + module_info = 'The module is {}.\n'.format(import_name) import_mod = importlib.import_module(import_name) - srcfile = inspect.getsourcefile(import_mod) logger.info("Script file is %s", srcfile) @@ -359,50 +451,14 @@ class Converter: # replace python function under nn.Module mapping = self.get_mapping(import_mod, forward_list) - code = inspect.getsource(import_mod) - for key, value in mapping.items(): - code = code.replace(key, value) - - source_lines = code.splitlines(keepends=True) - valid_line_num = 0 - - # find the first valid code line of the source - for num, source in enumerate(source_lines): - if re.search(r'^[a-z]\w+', source): - valid_line_num = num - break - source_lines.insert(valid_line_num, 'import mindspore.ops.operations as P\n') - source_lines.insert(valid_line_num, 'import mindspore.nn as nn\n') - source_lines.insert(valid_line_num, 'import mindspore\n') - - code = ''.join(source_lines) - - self.convert_info += '||[Import Add] Add follow import sentences:\n' - self.convert_info += 'import mindspore.ops.operations as P\n' - self.convert_info += 'import mindspore.nn as nn\n' - self.convert_info += 'import mindspore\n\n' - - code = code.replace('import torch', '# import torch') - code = code.replace('from torch', '# from torch') - code = code.replace('(nn.Module):', '(nn.Cell):') - code = code.replace('forward(', 'construct(') - code = code.replace('nn.Linear', 'nn.Dense') - code = code.replace('(nn.Sequential)', '(nn.SequentialCell)') - code = code.replace('nn.init.', 'pass # nn.init.') - - self.convert_info += '||[Import Annotated] Annotated follow import sentences:\n' - self.convert_info += 'import sentence on torch as follows are annotated:\n' - self.convert_info += 'import torch\n' - self.convert_info += 'from torch ...\n' - - self.convert_info += '||[Explicit Convert] Module or function are explicitly converted as follows:\n' - self.convert_info += '[nn.Module] is converted to [nn.Cell]\n' - self.convert_info += '[forward] is converted to [construct]\n' - self.convert_info += '[nn.Linear] is converted to [nn.Dense]\n' - self.convert_info += '[nn.Sequential] is converted to [nn.SequentialCell]\n' - self.convert_info += '[nn.init] is not converted and annotated\n' - self.convert_info += '[Convert over]' + code = self.update_code_and_convert_info(code, mapping) + convert_info_split = self.convert_info.splitlines(keepends=True) + convert_info_split = sorted(convert_info_split) + convert_info_split.insert(0, start_info) + convert_info_split.insert(1, module_info) + convert_info_split.append('[Convert Over]') + self.convert_info = ''.join(convert_info_split) dest_file = os.path.join(output_dir, os.path.basename(srcfile)) with os.fdopen(os.open(dest_file, self.flags, self.modes), 'w') as file: @@ -439,7 +495,6 @@ def _path_split(file): Returns: list[str], list of file tail - """ file_dir, name = os.path.split(file) if file_dir: -- GitLab