提交 5e49d578 编写于 作者: Q quyongxiu1

convert report fix r0.3

verify if convert ok

fix report and annote

verify 11 scripts

fix pylint

format fix

format more clear

use strip to define start info of line
上级 c4fc9bfb
...@@ -28,6 +28,8 @@ from mindinsight.mindconverter.config import ALL_UNSUPPORTED ...@@ -28,6 +28,8 @@ from mindinsight.mindconverter.config import ALL_UNSUPPORTED
from mindinsight.mindconverter.common.log import logger from mindinsight.mindconverter.common.log import logger
from mindinsight.mindconverter.forward_call import ForwardCall from mindinsight.mindconverter.forward_call import ForwardCall
LINE_NO_INDEX_DIFF = 1
class Converter: class Converter:
"""Convert class""" """Convert class"""
...@@ -197,6 +199,7 @@ class Converter: ...@@ -197,6 +199,7 @@ class Converter:
raise ValueError('"(" not found, {} should work with "("'.format(call_name)) raise ValueError('"(" not found, {} should work with "("'.format(call_name))
right = self.find_right_parentheses(code, left) right = self.find_right_parentheses(code, left)
end = right end = right
expr = code[start:end + 1] expr = code[start:end + 1]
args_str = code[left:right + 1] args_str = code[left:right + 1]
...@@ -336,6 +339,96 @@ class Converter: ...@@ -336,6 +339,96 @@ class Converter:
mapping.update(convert_fun(*args)) mapping.update(convert_fun(*args))
return mapping return mapping
@staticmethod
def get_code_start_line_num(source_lines):
"""
Get the start code line number exclude comments.
Args:
source_lines (list[str]): Split results of original code.
Returns:
int, the start line number.
"""
stack = []
index = 0
for i, line in enumerate(source_lines):
if line.strip().startswith('#'):
continue
if line.strip().startswith('"""'):
if not line.endswith('"""\n'):
stack.append('"""')
continue
if line.strip().startswith("'''"):
if not line.endswith("'''\n"):
stack.append("'''")
continue
if line.endswith('"""\n') or line.endswith("'''\n"):
stack.pop()
continue
if line.strip() != '' and not stack:
index = i
break
return index
def update_code_and_convert_info(self, code, mapping):
"""
Replace code according to mapping, and update convert info.
Args:
code (str): The code to replace.
mapping (dict): Mapping for original code and the replaced code.
Returns:
str, the replaced code.
"""
for key, value in mapping.items():
code = code.replace(key, value)
source_lines = code.splitlines(keepends=True)
start_line_number = self.get_code_start_line_num(source_lines)
add_import_infos = ['import mindspore\n',
'import mindspore.nn as nn\n',
'import mindspore.ops.operations as P\n']
for i, add_import_info in enumerate(add_import_infos):
source_lines.insert(start_line_number + i, add_import_info)
self.convert_info += '[Add Import] {}.\n'.format(add_import_info.strip())
insert_count = len(add_import_infos)
line_diff = insert_count - LINE_NO_INDEX_DIFF
for i in range(start_line_number + insert_count, len(source_lines)):
line = source_lines[i]
if (line.startswith('from torch') and 'import' in line) or line.startswith('import torch'):
new_line = '# ' + line
source_lines[i] = new_line
self.convert_info += '[Annotate][Line{:3d}] {} is annotated.\n'.format(i - line_diff, line.strip())
if line.strip().startswith('class') and '(nn.Module)' in line:
new_line = line.replace('nn.Module', 'nn.Cell')
source_lines[i] = new_line
self.convert_info += '[Convert][Line{:3d}] nn.Module is converted.\n'.format(i - line_diff)
if line.strip().startswith('def forward('):
new_line = line.replace('forward', 'construct')
source_lines[i] = new_line
self.convert_info += '[Convert][Line{:3d}] forward is converted.\n'.format(i - line_diff)
if 'nn.Linear' in line:
new_line = line.replace('nn.Linear', 'nn.Dense')
source_lines[i] = new_line
self.convert_info += '[Convert][Line{:3d}] nn.Linear is converted.\n'.format(i - line_diff)
if '(nn.Sequential)' in line:
new_line = line.replace('nn.Sequential', 'nn.SequentialCell')
source_lines[i] = new_line
self.convert_info += '[Convert][Line{:3d}] nn.Sequential is converted.\n'.format(i - line_diff)
if 'nn.init.' in line:
new_line = line.replace('nn.init', 'pass # nn.init')
source_lines[i] = new_line
self.convert_info += '[Annotate][Line{:3d}] {} is annotated.\n'.format(i - line_diff, 'nn.init')
code = ''.join(source_lines)
return code
def convert(self, import_name, output_dir, report_dir): def convert(self, import_name, output_dir, report_dir):
""" """
Convert a module's code, code converted will be save in output_dir, and a report will be save in report_dir. Convert a module's code, code converted will be save in output_dir, and a report will be save in report_dir.
...@@ -346,10 +439,10 @@ class Converter: ...@@ -346,10 +439,10 @@ class Converter:
report_dir (str): The path to save report file. report_dir (str): The path to save report file.
""" """
logger.info("Start converting %s", import_name) logger.info("Start converting %s", import_name)
self.convert_info += '[Start Convert]\nThe module is {}\n'.format(import_name) start_info = '[Start Convert]\n'
module_info = 'The module is {}.\n'.format(import_name)
import_mod = importlib.import_module(import_name) import_mod = importlib.import_module(import_name)
srcfile = inspect.getsourcefile(import_mod) srcfile = inspect.getsourcefile(import_mod)
logger.info("Script file is %s", srcfile) logger.info("Script file is %s", srcfile)
...@@ -358,40 +451,14 @@ class Converter: ...@@ -358,40 +451,14 @@ class Converter:
# replace python function under nn.Module # replace python function under nn.Module
mapping = self.get_mapping(import_mod, forward_list) mapping = self.get_mapping(import_mod, forward_list)
code = inspect.getsource(import_mod) code = inspect.getsource(import_mod)
for key, value in mapping.items(): code = self.update_code_and_convert_info(code, mapping)
code = code.replace(key, value) convert_info_split = self.convert_info.splitlines(keepends=True)
convert_info_split = sorted(convert_info_split)
code = 'import mindspore.ops.operations as P\n' + code convert_info_split.insert(0, start_info)
code = 'import mindspore.nn as nn\n' + code convert_info_split.insert(1, module_info)
code = 'import mindspore\n' + code convert_info_split.append('[Convert Over]')
self.convert_info = ''.join(convert_info_split)
self.convert_info += '||[Import Add] Add follow import sentences:\n'
self.convert_info += 'import mindspore.ops.operations as P\n'
self.convert_info += 'import mindspore.nn as nn\n'
self.convert_info += 'import mindspore\n\n'
code = code.replace('import torch', '# import torch')
code = code.replace('from torch', '# from torch')
code = code.replace('(nn.Module):', '(nn.Cell):')
code = code.replace('forward(', 'construct(')
code = code.replace('nn.Linear', 'nn.Dense')
code = code.replace('(nn.Sequential)', '(nn.SequentialCell)')
code = code.replace('nn.init.', 'pass # nn.init.')
self.convert_info += '||[Import Annotated] Annotated follow import sentences:\n'
self.convert_info += 'import sentence on torch as follows are annotated:\n'
self.convert_info += 'import torch\n'
self.convert_info += 'from torch ...\n'
self.convert_info += '||[Explicit Convert] Module or function are explicitly converted as follows:\n'
self.convert_info += '[nn.Module] is converted to [nn.Cell]\n'
self.convert_info += '[forward] is converted to [construct]\n'
self.convert_info += '[nn.Linear] is converted to [nn.Dense]\n'
self.convert_info += '[nn.Sequential] is converted to [nn.SequentialCell]\n'
self.convert_info += '[nn.init] is not converted and annotated\n'
self.convert_info += '[Convert over]'
dest_file = os.path.join(output_dir, os.path.basename(srcfile)) dest_file = os.path.join(output_dir, os.path.basename(srcfile))
with os.fdopen(os.open(dest_file, self.flags, self.modes), 'w') as file: with os.fdopen(os.open(dest_file, self.flags, self.modes), 'w') as file:
...@@ -428,7 +495,6 @@ def _path_split(file): ...@@ -428,7 +495,6 @@ def _path_split(file):
Returns: Returns:
list[str], list of file tail list[str], list of file tail
""" """
file_dir, name = os.path.split(file) file_dir, name = os.path.split(file)
if file_dir: if file_dir:
...@@ -456,6 +522,6 @@ def main(files_config): ...@@ -456,6 +522,6 @@ def main(files_config):
module_name = '.'.join(in_file_split) module_name = '.'.join(in_file_split)
convert_ins.convert(module_name, files_config['outfile_dir'], files_config['report_dir']) convert_ins.convert(module_name, files_config['outfile_dir'], files_config['report_dir'])
in_module = files_config['in_module'] in_module = files_config.get('in_module')
if in_module: if in_module:
convert_ins.convert(in_module, files_config['outfile_dir'], files_config['report_dir']) convert_ins.convert(in_module, files_config['outfile_dir'], files_config['report_dir'])
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册