提交 5e49d578 编写于 作者: Q quyongxiu1

convert report fix r0.3

verify if convert ok

fix report and annote

verify 11 scripts

fix pylint

format fix

format more clear

use strip to define start info of line
上级 c4fc9bfb
......@@ -28,6 +28,8 @@ from mindinsight.mindconverter.config import ALL_UNSUPPORTED
from mindinsight.mindconverter.common.log import logger
from mindinsight.mindconverter.forward_call import ForwardCall
LINE_NO_INDEX_DIFF = 1
class Converter:
"""Convert class"""
......@@ -197,6 +199,7 @@ class Converter:
raise ValueError('"(" not found, {} should work with "("'.format(call_name))
right = self.find_right_parentheses(code, left)
end = right
expr = code[start:end + 1]
args_str = code[left:right + 1]
......@@ -336,6 +339,96 @@ class Converter:
mapping.update(convert_fun(*args))
return mapping
@staticmethod
def get_code_start_line_num(source_lines):
"""
Get the start code line number exclude comments.
Args:
source_lines (list[str]): Split results of original code.
Returns:
int, the start line number.
"""
stack = []
index = 0
for i, line in enumerate(source_lines):
if line.strip().startswith('#'):
continue
if line.strip().startswith('"""'):
if not line.endswith('"""\n'):
stack.append('"""')
continue
if line.strip().startswith("'''"):
if not line.endswith("'''\n"):
stack.append("'''")
continue
if line.endswith('"""\n') or line.endswith("'''\n"):
stack.pop()
continue
if line.strip() != '' and not stack:
index = i
break
return index
def update_code_and_convert_info(self, code, mapping):
"""
Replace code according to mapping, and update convert info.
Args:
code (str): The code to replace.
mapping (dict): Mapping for original code and the replaced code.
Returns:
str, the replaced code.
"""
for key, value in mapping.items():
code = code.replace(key, value)
source_lines = code.splitlines(keepends=True)
start_line_number = self.get_code_start_line_num(source_lines)
add_import_infos = ['import mindspore\n',
'import mindspore.nn as nn\n',
'import mindspore.ops.operations as P\n']
for i, add_import_info in enumerate(add_import_infos):
source_lines.insert(start_line_number + i, add_import_info)
self.convert_info += '[Add Import] {}.\n'.format(add_import_info.strip())
insert_count = len(add_import_infos)
line_diff = insert_count - LINE_NO_INDEX_DIFF
for i in range(start_line_number + insert_count, len(source_lines)):
line = source_lines[i]
if (line.startswith('from torch') and 'import' in line) or line.startswith('import torch'):
new_line = '# ' + line
source_lines[i] = new_line
self.convert_info += '[Annotate][Line{:3d}] {} is annotated.\n'.format(i - line_diff, line.strip())
if line.strip().startswith('class') and '(nn.Module)' in line:
new_line = line.replace('nn.Module', 'nn.Cell')
source_lines[i] = new_line
self.convert_info += '[Convert][Line{:3d}] nn.Module is converted.\n'.format(i - line_diff)
if line.strip().startswith('def forward('):
new_line = line.replace('forward', 'construct')
source_lines[i] = new_line
self.convert_info += '[Convert][Line{:3d}] forward is converted.\n'.format(i - line_diff)
if 'nn.Linear' in line:
new_line = line.replace('nn.Linear', 'nn.Dense')
source_lines[i] = new_line
self.convert_info += '[Convert][Line{:3d}] nn.Linear is converted.\n'.format(i - line_diff)
if '(nn.Sequential)' in line:
new_line = line.replace('nn.Sequential', 'nn.SequentialCell')
source_lines[i] = new_line
self.convert_info += '[Convert][Line{:3d}] nn.Sequential is converted.\n'.format(i - line_diff)
if 'nn.init.' in line:
new_line = line.replace('nn.init', 'pass # nn.init')
source_lines[i] = new_line
self.convert_info += '[Annotate][Line{:3d}] {} is annotated.\n'.format(i - line_diff, 'nn.init')
code = ''.join(source_lines)
return code
def convert(self, import_name, output_dir, report_dir):
"""
Convert a module's code, code converted will be save in output_dir, and a report will be save in report_dir.
......@@ -346,10 +439,10 @@ class Converter:
report_dir (str): The path to save report file.
"""
logger.info("Start converting %s", import_name)
self.convert_info += '[Start Convert]\nThe module is {}\n'.format(import_name)
start_info = '[Start Convert]\n'
module_info = 'The module is {}.\n'.format(import_name)
import_mod = importlib.import_module(import_name)
srcfile = inspect.getsourcefile(import_mod)
logger.info("Script file is %s", srcfile)
......@@ -358,40 +451,14 @@ class Converter:
# replace python function under nn.Module
mapping = self.get_mapping(import_mod, forward_list)
code = inspect.getsource(import_mod)
for key, value in mapping.items():
code = code.replace(key, value)
code = 'import mindspore.ops.operations as P\n' + code
code = 'import mindspore.nn as nn\n' + code
code = 'import mindspore\n' + code
self.convert_info += '||[Import Add] Add follow import sentences:\n'
self.convert_info += 'import mindspore.ops.operations as P\n'
self.convert_info += 'import mindspore.nn as nn\n'
self.convert_info += 'import mindspore\n\n'
code = code.replace('import torch', '# import torch')
code = code.replace('from torch', '# from torch')
code = code.replace('(nn.Module):', '(nn.Cell):')
code = code.replace('forward(', 'construct(')
code = code.replace('nn.Linear', 'nn.Dense')
code = code.replace('(nn.Sequential)', '(nn.SequentialCell)')
code = code.replace('nn.init.', 'pass # nn.init.')
self.convert_info += '||[Import Annotated] Annotated follow import sentences:\n'
self.convert_info += 'import sentence on torch as follows are annotated:\n'
self.convert_info += 'import torch\n'
self.convert_info += 'from torch ...\n'
self.convert_info += '||[Explicit Convert] Module or function are explicitly converted as follows:\n'
self.convert_info += '[nn.Module] is converted to [nn.Cell]\n'
self.convert_info += '[forward] is converted to [construct]\n'
self.convert_info += '[nn.Linear] is converted to [nn.Dense]\n'
self.convert_info += '[nn.Sequential] is converted to [nn.SequentialCell]\n'
self.convert_info += '[nn.init] is not converted and annotated\n'
self.convert_info += '[Convert over]'
code = self.update_code_and_convert_info(code, mapping)
convert_info_split = self.convert_info.splitlines(keepends=True)
convert_info_split = sorted(convert_info_split)
convert_info_split.insert(0, start_info)
convert_info_split.insert(1, module_info)
convert_info_split.append('[Convert Over]')
self.convert_info = ''.join(convert_info_split)
dest_file = os.path.join(output_dir, os.path.basename(srcfile))
with os.fdopen(os.open(dest_file, self.flags, self.modes), 'w') as file:
......@@ -428,7 +495,6 @@ def _path_split(file):
Returns:
list[str], list of file tail
"""
file_dir, name = os.path.split(file)
if file_dir:
......@@ -456,6 +522,6 @@ def main(files_config):
module_name = '.'.join(in_file_split)
convert_ins.convert(module_name, files_config['outfile_dir'], files_config['report_dir'])
in_module = files_config['in_module']
in_module = files_config.get('in_module')
if in_module:
convert_ins.convert(in_module, files_config['outfile_dir'], files_config['report_dir'])
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册