提交 16fcb034 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!187 more clear report on mindconverter

Merge pull request !187 from quyongxiu1/mindconvert_report_fix_0.3
......@@ -28,6 +28,8 @@ from mindinsight.mindconverter.config import ALL_UNSUPPORTED
from mindinsight.mindconverter.common.log import logger
from mindinsight.mindconverter.forward_call import ForwardCall
LINE_NO_INDEX_DIFF = 1
class Converter:
"""Convert class"""
......@@ -197,6 +199,7 @@ class Converter:
raise ValueError('"(" not found, {} should work with "("'.format(call_name))
right = self.find_right_parentheses(code, left)
end = right
expr = code[start:end + 1]
args_str = code[left:right + 1]
......@@ -336,6 +339,96 @@ class Converter:
mapping.update(convert_fun(*args))
return mapping
@staticmethod
def get_code_start_line_num(source_lines):
"""
Get the start code line number exclude comments.
Args:
source_lines (list[str]): Split results of original code.
Returns:
int, the start line number.
"""
stack = []
index = 0
for i, line in enumerate(source_lines):
if line.strip().startswith('#'):
continue
if line.strip().startswith('"""'):
if not line.endswith('"""\n'):
stack.append('"""')
continue
if line.strip().startswith("'''"):
if not line.endswith("'''\n"):
stack.append("'''")
continue
if line.endswith('"""\n') or line.endswith("'''\n"):
stack.pop()
continue
if line.strip() != '' and not stack:
index = i
break
return index
def update_code_and_convert_info(self, code, mapping):
"""
Replace code according to mapping, and update convert info.
Args:
code (str): The code to replace.
mapping (dict): Mapping for original code and the replaced code.
Returns:
str, the replaced code.
"""
for key, value in mapping.items():
code = code.replace(key, value)
source_lines = code.splitlines(keepends=True)
start_line_number = self.get_code_start_line_num(source_lines)
add_import_infos = ['import mindspore\n',
'import mindspore.nn as nn\n',
'import mindspore.ops.operations as P\n']
for i, add_import_info in enumerate(add_import_infos):
source_lines.insert(start_line_number + i, add_import_info)
self.convert_info += '[Add Import] {}.\n'.format(add_import_info.strip())
insert_count = len(add_import_infos)
line_diff = insert_count - LINE_NO_INDEX_DIFF
for i in range(start_line_number + insert_count, len(source_lines)):
line = source_lines[i]
if (line.startswith('from torch') and 'import' in line) or line.startswith('import torch'):
new_line = '# ' + line
source_lines[i] = new_line
self.convert_info += '[Annotate][Line{:3d}] {} is annotated.\n'.format(i - line_diff, line.strip())
if line.strip().startswith('class') and '(nn.Module)' in line:
new_line = line.replace('nn.Module', 'nn.Cell')
source_lines[i] = new_line
self.convert_info += '[Convert][Line{:3d}] nn.Module is converted.\n'.format(i - line_diff)
if line.strip().startswith('def forward('):
new_line = line.replace('forward', 'construct')
source_lines[i] = new_line
self.convert_info += '[Convert][Line{:3d}] forward is converted.\n'.format(i - line_diff)
if 'nn.Linear' in line:
new_line = line.replace('nn.Linear', 'nn.Dense')
source_lines[i] = new_line
self.convert_info += '[Convert][Line{:3d}] nn.Linear is converted.\n'.format(i - line_diff)
if '(nn.Sequential)' in line:
new_line = line.replace('nn.Sequential', 'nn.SequentialCell')
source_lines[i] = new_line
self.convert_info += '[Convert][Line{:3d}] nn.Sequential is converted.\n'.format(i - line_diff)
if 'nn.init.' in line:
new_line = line.replace('nn.init', 'pass # nn.init')
source_lines[i] = new_line
self.convert_info += '[Annotate][Line{:3d}] {} is annotated.\n'.format(i - line_diff, 'nn.init')
code = ''.join(source_lines)
return code
def convert(self, import_name, output_dir, report_dir):
"""
Convert a module's code, code converted will be save in output_dir, and a report will be save in report_dir.
......@@ -346,10 +439,10 @@ class Converter:
report_dir (str): The path to save report file.
"""
logger.info("Start converting %s", import_name)
self.convert_info += '[Start Convert]\nThe module is {}\n'.format(import_name)
start_info = '[Start Convert]\n'
module_info = 'The module is {}.\n'.format(import_name)
import_mod = importlib.import_module(import_name)
srcfile = inspect.getsourcefile(import_mod)
logger.info("Script file is %s", srcfile)
......@@ -358,40 +451,14 @@ class Converter:
# replace python function under nn.Module
mapping = self.get_mapping(import_mod, forward_list)
code = inspect.getsource(import_mod)
for key, value in mapping.items():
code = code.replace(key, value)
code = 'import mindspore.ops.operations as P\n' + code
code = 'import mindspore.nn as nn\n' + code
code = 'import mindspore\n' + code
self.convert_info += '||[Import Add] Add follow import sentences:\n'
self.convert_info += 'import mindspore.ops.operations as P\n'
self.convert_info += 'import mindspore.nn as nn\n'
self.convert_info += 'import mindspore\n\n'
code = code.replace('import torch', '# import torch')
code = code.replace('from torch', '# from torch')
code = code.replace('(nn.Module):', '(nn.Cell):')
code = code.replace('forward(', 'construct(')
code = code.replace('nn.Linear', 'nn.Dense')
code = code.replace('(nn.Sequential)', '(nn.SequentialCell)')
code = code.replace('nn.init.', 'pass # nn.init.')
self.convert_info += '||[Import Annotated] Annotated follow import sentences:\n'
self.convert_info += 'import sentence on torch as follows are annotated:\n'
self.convert_info += 'import torch\n'
self.convert_info += 'from torch ...\n'
self.convert_info += '||[Explicit Convert] Module or function are explicitly converted as follows:\n'
self.convert_info += '[nn.Module] is converted to [nn.Cell]\n'
self.convert_info += '[forward] is converted to [construct]\n'
self.convert_info += '[nn.Linear] is converted to [nn.Dense]\n'
self.convert_info += '[nn.Sequential] is converted to [nn.SequentialCell]\n'
self.convert_info += '[nn.init] is not converted and annotated\n'
self.convert_info += '[Convert over]'
code = self.update_code_and_convert_info(code, mapping)
convert_info_split = self.convert_info.splitlines(keepends=True)
convert_info_split = sorted(convert_info_split)
convert_info_split.insert(0, start_info)
convert_info_split.insert(1, module_info)
convert_info_split.append('[Convert Over]')
self.convert_info = ''.join(convert_info_split)
dest_file = os.path.join(output_dir, os.path.basename(srcfile))
with os.fdopen(os.open(dest_file, self.flags, self.modes), 'w') as file:
......@@ -428,7 +495,6 @@ def _path_split(file):
Returns:
list[str], list of file tail
"""
file_dir, name = os.path.split(file)
if file_dir:
......@@ -456,6 +522,6 @@ def main(files_config):
module_name = '.'.join(in_file_split)
convert_ins.convert(module_name, files_config['outfile_dir'], files_config['report_dir'])
in_module = files_config['in_module']
in_module = files_config.get('in_module')
if in_module:
convert_ins.convert(in_module, files_config['outfile_dir'], files_config['report_dir'])
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册